您的位置:

GridSearchCV详解

一、GridSearchCV概述

GridSearchCV是scikit-learn中一个重要的调参工具,用于系统地遍历多个参数组合,通过交叉验证确定最佳参数。在机器学习算法中,各个算法有很多超参数,超参数的优化对算法的性能至关重要。而GridSearchCV,正是通过遍历所有参数组合,找到最优参数从而提高模型在给定数据集上的性能。

二、逐步介绍GridSearchCV

1、如何使用GridSearchCV?

GridSearchCV需要提供一个估计器(estimator)和一组参数(dictionary),用于将参数遍历作为一个估计器的参数组合进行交叉验证。首先需要加载相关库,如下:

  
    from sklearn.model_selection import GridSearchCV
    from sklearn.svm import SVC
  

接下来,加载数据并建立模型:

  
    import numpy as np  
    from sklearn.datasets import load_iris  
    from sklearn.model_selection import train_test_split
    iris = load_iris()  
    X, y = iris.data, iris.target  
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3,random_state=0)
    svc = SVC()  
  

构建参数字典:

  
    param_grid = {'C': [0.1, 1, 10, 100, 1000],  
              'gamma': ['scale', 'auto', 0.001, 0.0001], 
              'kernel': ['rbf', 'linear', 'poly', 'sigmoid']} 
  

GridSearchCV的使用:

  
    clf = GridSearchCV(svc, param_grid, cv=5)  
    clf.fit(X_train, y_train)
  

其中,SVC的参数指定为clf,param_grid指定了要搜索的参数组合,cv表示使用的交叉验证的策略。5表示使用5折交叉验证法。

2、参数的选择

超参数的选择是机器学习中非常重要的一步。参数的组合可能非常多,这使得参数的选择非常困难。但是使用GridSearchCV可以一个一个地测试每种超参数的组合,从而找到最优的参数组合。在上述例子中,我们使用了3个参数:C、gamma和kernel。C是SVM中的参数,表示惩罚误分类点的权重,gamma是核函数中的参数,表示向样本中添加一个样本的影响程度。kernel是SVM中的核函数类型。在param_grid中,C的取值范围是[0.1, 1, 10, 100, 1000],gamma的取值是['scale', 'auto', 0.001, 0.0001],kernel的取值是['rbf', 'linear', 'poly', 'sigmoid']。使用GridSearchCV会尝试每种参数组合,最终得出最佳的参数组合。

3、结果分析与输出

GridSearchCV有很多输出信息,其中最重要的是best_params_和best_score_。best_params_是最佳参数的集合,best_score_是最佳参数组合的得分。另外,可以使用cv_results_将所有的参数集合及其得分统计出来。代码如下:

  
      print("The best parameters are %s with a score of %0.2f" % (clf.best_params_, clf.best_score_)) 
      print("The best parameters sets found on development set:")  
      print(clf.best_params_)  
      print("Grid scores on development set:")  
      means = clf.cv_results_['mean_test_score']  
      stds = clf.cv_results_['std_test_score']  
      for mean, std, params in zip(means, stds, clf.cv_results_['params']):  
          print("%0.3f (+/-%0.03f) for %r" % (mean, std * 2, params))  
  

三、注意事项

1、使用更易搜索的参数

当需要使用GridSearchCV时,保证参数越少越好。如果搜索的空间很大,搜索过程就会非常耗时。使用太多参数会使搜索过程变得非常缓慢。确保每个参数慎重再慎重的选择。

2、使用并行计算

sklearn提供了并行计算功能,这使得搜索过程更快。默认情况下,GridSearchCV使用的计算是单进程的,使用n_jobs参数将其更改为多进程的。

例如:

  
    clf = GridSearchCV(svc, param_grid, cv=5, n_jobs=-1)  
  

3、不要期望太高

网格搜索是一项非常强大的技术,但不要期望它能在所有数据集上表现良好。在某些情况下,它的表现可能不如其他优化技术。在实际应用时,最好记录不同模型及其参数的得分。

四、总结

GridSearchCV是机器学习中一个应该掌握的重要调参工具。在这篇文章中,我们了解了GridSearchCV的基本原理,如何使用它找到模型的最佳参数,如何分析最佳参数的结果。