一、GridSearchCV介绍
GridSearchCV是sklearn提供的一种自动化调参工具,能够遍历给定的参数组合,使用交叉验证的方式找出最优的参数组合。此外,GridSearchCV还可以并行处理多组参数,加快搜索速度。GridSearchCV包含以下几个重要参数:
class sklearn.model_selection.GridSearchCV
param_grid:要优化的参数组合,一个字典或列表,其中字典的键为想要调整的模型参数名(字符串),值为对应的想要调整的参数值的列表。
scoring:模型评估标准。
cv:cross-validation,交叉验证生成器或可迭代的迭代器,例如KFold。
n_jobs:并行运行的作业数。
二、使用GridSearchCV优化模型参数选择的示例
在下面的示例中,我们将介绍如何使用GridSearchCV优化逻辑回归模型的参数选择。
首先,我们加载需要的库和数据集:
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import GridSearchCV
from sklearn.datasets import load_iris
iris = load_iris()
X = iris.data
y = iris.target
接下来,我们将定义一个逻辑回归模型,并将其参数放在param_grid字典中。
lr = LogisticRegression()
param_grid = {"penalty": ['l1', 'l2'], "C": [0.01, 0.1, 1, 10]}
这里需要注意的是,penalty参数表示正则化的方式,‘l1’表示L1正则化,‘l2’表示L2正则化,而C参数表示正则化的强度。
接下来,我们将使用GridSearchCV来寻找最佳参数组合,并对模型进行拟合和预测:
grid_search = GridSearchCV(lr, param_grid=param_grid, cv=5, scoring='accuracy', n_jobs=-1)
grid_search.fit(X, y)
print("Best parameters: ", grid_search.best_params_)
print("Best cross-validation score: {:.2f}".format(grid_search.best_score_))
print("Best estimator:\n{}".format(grid_search.best_estimator_))
最终,我们将得到最佳参数组合、在最佳参数下的交叉验证得分和最佳模型的信息。
三、如何使用GridSearchCV针对多个模型进行参数选择
在实际应用中,我们可能需要针对多个模型进行参数选择。在此情况下,我们可以使用for循环来遍历多个模型并进行参数选择。
下面的示例介绍了如何遍历多个模型并使用GridSearchCV进行参数选择。我们将使用逻辑回归和支持向量机两种模型,对于每个模型,我们将定义不同的参数组合。
from sklearn.svm import SVC
# 定义逻辑回归和支持向量机的参数组合
lr_param_grid = {"penalty": ['l1', 'l2'], "C": [0.01, 0.1, 1, 10]}
svm_param_grid = {"kernel": ['linear', 'rbf', 'poly'], "C": [0.1, 1, 10], "gamma": [0.1, 1, 10]}
# 定义模型列表和对应的参数组合字典列表
models = [
{
'name': 'LogisticRegression',
'model': LogisticRegression(),
'params': lr_param_grid
},
{
'name': 'SVC',
'model': SVC(),
'params': svm_param_grid
}
]
# 遍历每个模型并进行参数选择
for model in models:
print(model['name'])
grid_search = GridSearchCV(model['model'], model['params'], cv=5, scoring="accuracy", n_jobs=-1)
grid_search.fit(X, y)
print("Best parameters: ", grid_search.best_params_)
print("Best cross-validation score: {:.2f}".format(grid_search.best_score_))
print("Best estimator:\n{}".format(grid_search.best_estimator_))
print("\n")
最终,我们将对每个模型得到最佳参数组合、在最佳参数下的交叉验证得分和最佳模型的信息。