您的位置:

如何使用GridSearchCV优化模型参数选择

一、GridSearchCV介绍

GridSearchCV是sklearn提供的一种自动化调参工具,能够遍历给定的参数组合,使用交叉验证的方式找出最优的参数组合。此外,GridSearchCV还可以并行处理多组参数,加快搜索速度。GridSearchCV包含以下几个重要参数:

class sklearn.model_selection.GridSearchCV

    param_grid:要优化的参数组合,一个字典或列表,其中字典的键为想要调整的模型参数名(字符串),值为对应的想要调整的参数值的列表。

    scoring:模型评估标准。

    cv:cross-validation,交叉验证生成器或可迭代的迭代器,例如KFold。

    n_jobs:并行运行的作业数。

二、使用GridSearchCV优化模型参数选择的示例

在下面的示例中,我们将介绍如何使用GridSearchCV优化逻辑回归模型的参数选择。

首先,我们加载需要的库和数据集:

from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import GridSearchCV
from sklearn.datasets import load_iris

iris = load_iris()
X = iris.data
y = iris.target

接下来,我们将定义一个逻辑回归模型,并将其参数放在param_grid字典中。

lr = LogisticRegression()
param_grid = {"penalty": ['l1', 'l2'], "C": [0.01, 0.1, 1, 10]}

这里需要注意的是,penalty参数表示正则化的方式,‘l1’表示L1正则化,‘l2’表示L2正则化,而C参数表示正则化的强度。

接下来,我们将使用GridSearchCV来寻找最佳参数组合,并对模型进行拟合和预测:

grid_search = GridSearchCV(lr, param_grid=param_grid, cv=5, scoring='accuracy', n_jobs=-1)
grid_search.fit(X, y)
print("Best parameters: ", grid_search.best_params_)
print("Best cross-validation score: {:.2f}".format(grid_search.best_score_))
print("Best estimator:\n{}".format(grid_search.best_estimator_))

最终,我们将得到最佳参数组合、在最佳参数下的交叉验证得分和最佳模型的信息。

三、如何使用GridSearchCV针对多个模型进行参数选择

在实际应用中,我们可能需要针对多个模型进行参数选择。在此情况下,我们可以使用for循环来遍历多个模型并进行参数选择。

下面的示例介绍了如何遍历多个模型并使用GridSearchCV进行参数选择。我们将使用逻辑回归和支持向量机两种模型,对于每个模型,我们将定义不同的参数组合。

from sklearn.svm import SVC

# 定义逻辑回归和支持向量机的参数组合
lr_param_grid = {"penalty": ['l1', 'l2'], "C": [0.01, 0.1, 1, 10]}
svm_param_grid = {"kernel": ['linear', 'rbf', 'poly'], "C": [0.1, 1, 10], "gamma": [0.1, 1, 10]}

# 定义模型列表和对应的参数组合字典列表
models = [
    {
        'name': 'LogisticRegression',
        'model': LogisticRegression(),
        'params': lr_param_grid
    },
    {
        'name': 'SVC',
        'model': SVC(),
        'params': svm_param_grid
    }
]

# 遍历每个模型并进行参数选择
for model in models:
    print(model['name'])
    grid_search = GridSearchCV(model['model'], model['params'], cv=5, scoring="accuracy", n_jobs=-1)
    grid_search.fit(X, y)
    print("Best parameters: ", grid_search.best_params_)
    print("Best cross-validation score: {:.2f}".format(grid_search.best_score_))
    print("Best estimator:\n{}".format(grid_search.best_estimator_))
    print("\n")

最终,我们将对每个模型得到最佳参数组合、在最佳参数下的交叉验证得分和最佳模型的信息。