您的位置:

Python实现混淆矩阵热力图

混淆矩阵是分类模型预测结果的可视化工具,通过混淆矩阵可以更好的评估预测模型的性能。混淆矩阵通常用于衡量二分类模型的预测效果,但是也可以扩展到多分类问题。

在本篇文章中,我们会介绍如何使用Python来实现混淆矩阵的可视化工具——热力图。首先,我们会介绍热力图的基本知识和使用场景,然后详细讲解如何使用Python绘制混淆矩阵热力图。

一、热力图的基本知识和使用场景

热力图是一种用不同颜色来表示数值大小的二维图表。在混淆矩阵中,我们可以使用热力图来表示模型分类的预测结果,其中每个方格的颜色代表该分类模型在对应真实标签与预测结果下的数量。

热力图可以让我们更加直观地观察混淆矩阵中每种预测情况的比例,同时也可以发现模型预测结果的不足之处,进而对模型进行调整和改进。

二、如何使用Python绘制混淆矩阵热力图

1. 混淆矩阵数据的准备

首先,我们需要从分类模型中获取混淆矩阵的各种类别的预测结果和真实标签,然后将其整理成一个二维矩阵。

假设我们的混淆矩阵如下:

[[32, 21, 8],
 [12, 44, 6],
 [2, 3, 32]]

每一行代表真实标签,每一列代表模型的预测结果,在这个矩阵中,行列数都为3。

2. 热力图的绘制

接下来,我们将使用Python库matplotlib和seaborn来绘制混淆矩阵热力图。

首先,我们需要导入需要使用的Python库。

import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np

然后,我们可以使用seaborn的heatmap函数来绘制热力图。

def plot_confusion_matrix(cm, classes,
                          normalize=False,
                          title='Confusion matrix',
                          cmap=plt.cm.Blues):
    
    plt.figure(figsize=(8, 6))
    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    plt.title(title)
    plt.colorbar()
    tick_marks = np.arange(len(classes))
    plt.xticks(tick_marks, classes, rotation=45)
    plt.yticks(tick_marks, classes)

    fmt = '.2f' if normalize else 'd'
    thresh = cm.max() / 2.
    for i in range(cm.shape[0]):
        for j in range(cm.shape[1]):
            plt.text(j, i, format(cm[i, j], fmt),
                     ha="center", va="center",
                     color="white" if cm[i, j] > thresh else "black")

    plt.tight_layout()
    plt.ylabel('True label')
    plt.xlabel('Predicted label')
    plt.show()

上述代码中,cm为混淆矩阵数据,classes为分类模型的标签类别,normalize为是否进行标准化,title为热力图标题,cmap为热力图的颜色映射。

我们可以通过使用如下代码来生成热力图:

plot_confusion_matrix(cm, classes=['1', '2', '3'])

最终的热力图如下所示:

三、总结

本篇文章我们介绍了混淆矩阵在分类模型中的重要性和使用场景,并详细介绍了使用Python来绘制混淆矩阵热力图的方法。通过热力图,我们可以更加直观地观察分类模型的分类情况,并进一步优化和改进模型。