iis服务器助手广告广告
返回顶部
首页 > 资讯 > 精选 >pytorch分类模型绘制混淆矩阵及可视化的方法
  • 915
分享到

pytorch分类模型绘制混淆矩阵及可视化的方法

2023-06-29 22:06:33 915人浏览 泡泡鱼
摘要

本文小编为大家详细介绍“PyTorch分类模型绘制混淆矩阵及可视化的方法”,内容详细,步骤清晰,细节处理妥当,希望这篇“pytorch分类模型绘制混淆矩阵及可视化的方法”文章能帮助大家解决疑惑,下面跟着小编的思路慢慢深入,一起来学习新知识吧

本文小编为大家详细介绍“PyTorch分类模型绘制混淆矩阵及可视化的方法”,内容详细,步骤清晰,细节处理妥当,希望这篇“pytorch分类模型绘制混淆矩阵及可视化的方法”文章能帮助大家解决疑惑,下面跟着小编的思路慢慢深入,一起来学习新知识吧。

Step 1. 获取混淆矩阵

#首先定义一个 分类数*分类数 的空混淆矩阵 conf_matrix = torch.zeros(Emotion_kinds, Emotion_kinds) # 使用torch.no_grad()可以显著降低测试用例的GPU占用    with torch.no_grad():        for step, (imgs, targets) in enumerate(test_loader):            # imgs:     torch.Size([50, 3, 200, 200])   torch.FloatTensor            # targets:  torch.Size([50, 1]),     torch.LongTensor  多了一维,所以我们要把其去掉            targets = targets.squeeze()  # [50,1] ----->  [50]            # 将变量转为gpu            targets = targets.cuda()            imgs = imgs.cuda()            # print(step,imgs.shape,imgs.type(),targets.shape,targets.type())                        out = model(imgs)            #记录混淆矩阵参数            conf_matrix = confusion_matrix(out, targets, conf_matrix)            conf_matrix=conf_matrix.cpu()

混淆矩阵的求取用到了confusion_matrix函数,其定义如下:

def confusion_matrix(preds, labels, conf_matrix):    preds = torch.argmax(preds, 1)    for p, t in zip(preds, labels):        conf_matrix[p, t] += 1    return conf_matrix

在当我们的程序执行结束 test_loader 后,我们可以得到本次数据的 混淆矩阵,接下来就要计算其 识别正确的个数以及混淆矩阵可视化:

conf_matrix=np.array(conf_matrix.cpu())# 将混淆矩阵从gpu转到cpu再转到npcorrects=conf_matrix.diaGonal(offset=0)#抽取对角线的每种分类的识别正确个数per_kinds=conf_matrix.sum(axis=1)#抽取每个分类数据总的测试条数 print("混淆矩阵总元素个数:{0},测试集总个数:{1}".fORMat(int(np.sum(conf_matrix)),test_num)) print(conf_matrix) # 获取每种Emotion的识别准确率 print("每种情感总个数:",per_kinds) print("每种情感预测正确的个数:",corrects) print("每种情感的识别准确率为:{0}".format([rate*100 for rate in corrects/per_kinds]))

执行此步的输出结果如下所示:

pytorch分类模型绘制混淆矩阵及可视化的方法

Step 2. 混淆矩阵可视化

对上边求得的混淆矩阵可视化

# 绘制混淆矩阵Emotion=8#这个数值是具体的分类数,大家可以自行修改labels = ['neutral', 'calm', 'happy', 'sad', 'angry', 'fearful', 'disgust', 'surprised']#每种类别的标签# 显示数据plt.imshow(conf_matrix, cmap=plt.cm.Blues)# 在图中标注数量/概率信息thresh = conf_matrix.max() / 2#数值颜色阈值,如果数值超过这个,就颜色加深。for x in range(Emotion_kinds):    for y in range(Emotion_kinds):        # 注意这里的matrix[y, x]不是matrix[x, y]        info = int(conf_matrix[y, x])        plt.text(x, y, info,                 verticalalignment='center',                 horizontalalignment='center',                 color="white" if info > thresh else "black")                 plt.tight_layout()#保证图不重叠plt.yticks(range(Emotion_kinds), labels)plt.xticks(range(Emotion_kinds), labels,rotation=45)#X轴字体倾斜45°plt.show()plt.close()

好了,以下就是最终的可视化的混淆矩阵啦:

pytorch分类模型绘制混淆矩阵及可视化的方法

其它分类指标的获取

例如 F1分数、TP、TN、FP、FN、精确率、召回率 等指标, 待补充哈(因为暂时还没用到)~

pytorch分类模型绘制混淆矩阵及可视化的方法

读到这里,这篇“pytorch分类模型绘制混淆矩阵及可视化的方法”文章已经介绍完毕,想要掌握这篇文章的知识点还需要大家自己动手实践使用过才能领会,如果想了解更多相关内容的文章,欢迎关注编程网精选频道。

--结束END--

本文标题: pytorch分类模型绘制混淆矩阵及可视化的方法

本文链接: https://www.lsjlt.com/news/326503.html(转载时请注明来源链接)

有问题或投稿请发送至: 邮箱/279061341@qq.com    QQ/279061341

本篇文章演示代码以及资料文档资料下载

下载Word文档到电脑,方便收藏和打印~

下载Word文档
猜你喜欢
  • pytorch分类模型绘制混淆矩阵及可视化的方法
    本文小编为大家详细介绍“pytorch分类模型绘制混淆矩阵及可视化的方法”,内容详细,步骤清晰,细节处理妥当,希望这篇“pytorch分类模型绘制混淆矩阵及可视化的方法”文章能帮助大家解决疑惑,下面跟着小编的思路慢慢深入,一起来学习新知识吧...
    99+
    2023-06-29
  • pytorch分类模型绘制混淆矩阵以及可视化详解
    目录Step 1. 获取混淆矩阵Step 2. 混淆矩阵可视化其它分类指标的获取总结Step 1. 获取混淆矩阵 #首先定义一个 分类数*分类数 的空混淆矩阵 conf_matri...
    99+
    2024-04-02
软考高级职称资格查询
编程网,编程工程师的家园,是目前国内优秀的开源技术社区之一,形成了由开源软件库、代码分享、资讯、协作翻译、讨论区和博客等几大频道内容,为IT开发者提供了一个发现、使用、并交流开源技术的平台。
  • 官方手机版

  • 微信公众号

  • 商务合作