一、GridSearchCV概述
GridSearchCV是scikit-learn中一个重要的调参工具,用于系统地遍历多个参数组合,通过交叉验证确定最佳参数。在机器学习算法中,各个算法有很多超参数,超参数的优化对算法的性能至关重要。而GridSearchCV,正是通过遍历所有参数组合,找到最优参数从而提高模型在给定数据集上的性能。
二、逐步介绍GridSearchCV
1、如何使用GridSearchCV?
GridSearchCV需要提供一个估计器(estimator)和一组参数(dictionary),用于将参数遍历作为一个估计器的参数组合进行交叉验证。首先需要加载相关库,如下:
from sklearn.model_selection import GridSearchCV
from sklearn.svm import SVC
接下来,加载数据并建立模型:
import numpy as np
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
iris = load_iris()
X, y = iris.data, iris.target
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3,random_state=0)
svc = SVC()
构建参数字典:
param_grid = {'C': [0.1, 1, 10, 100, 1000],
'gamma': ['scale', 'auto', 0.001, 0.0001],
'kernel': ['rbf', 'linear', 'poly', 'sigmoid']}
GridSearchCV的使用:
clf = GridSearchCV(svc, param_grid, cv=5)
clf.fit(X_train, y_train)
其中,SVC的参数指定为clf,param_grid指定了要搜索的参数组合,cv表示使用的交叉验证的策略。5表示使用5折交叉验证法。
2、参数的选择
超参数的选择是机器学习中非常重要的一步。参数的组合可能非常多,这使得参数的选择非常困难。但是使用GridSearchCV可以一个一个地测试每种超参数的组合,从而找到最优的参数组合。在上述例子中,我们使用了3个参数:C、gamma和kernel。C是SVM中的参数,表示惩罚误分类点的权重,gamma是核函数中的参数,表示向样本中添加一个样本的影响程度。kernel是SVM中的核函数类型。在param_grid中,C的取值范围是[0.1, 1, 10, 100, 1000],gamma的取值是['scale', 'auto', 0.001, 0.0001],kernel的取值是['rbf', 'linear', 'poly', 'sigmoid']。使用GridSearchCV会尝试每种参数组合,最终得出最佳的参数组合。
3、结果分析与输出
GridSearchCV有很多输出信息,其中最重要的是best_params_和best_score_。best_params_是最佳参数的集合,best_score_是最佳参数组合的得分。另外,可以使用cv_results_将所有的参数集合及其得分统计出来。代码如下:
print("The best parameters are %s with a score of %0.2f" % (clf.best_params_, clf.best_score_))
print("The best parameters sets found on development set:")
print(clf.best_params_)
print("Grid scores on development set:")
means = clf.cv_results_['mean_test_score']
stds = clf.cv_results_['std_test_score']
for mean, std, params in zip(means, stds, clf.cv_results_['params']):
print("%0.3f (+/-%0.03f) for %r" % (mean, std * 2, params))
三、注意事项
1、使用更易搜索的参数
当需要使用GridSearchCV时,保证参数越少越好。如果搜索的空间很大,搜索过程就会非常耗时。使用太多参数会使搜索过程变得非常缓慢。确保每个参数慎重再慎重的选择。
2、使用并行计算
sklearn提供了并行计算功能,这使得搜索过程更快。默认情况下,GridSearchCV使用的计算是单进程的,使用n_jobs参数将其更改为多进程的。
例如:
clf = GridSearchCV(svc, param_grid, cv=5, n_jobs=-1)
3、不要期望太高
网格搜索是一项非常强大的技术,但不要期望它能在所有数据集上表现良好。在某些情况下,它的表现可能不如其他优化技术。在实际应用时,最好记录不同模型及其参数的得分。
四、总结
GridSearchCV是机器学习中一个应该掌握的重要调参工具。在这篇文章中,我们了解了GridSearchCV的基本原理,如何使用它找到模型的最佳参数,如何分析最佳参数的结果。