一文带你搞懂sklearn.metrics混淆矩阵

news/2024/7/9 9:08:02 标签: sklearn, python, 二分类, 混淆矩阵

一般的二分类任务需要的评价指标有4个

  • accuracy
  • precision
  • recall
  • f1-score

四个指标的计算公式如下

accuracy = \frac{TP+TN}{TP+TN+FP+FN}

precision=\frac{TP}{TP+FP}

recall=\frac{TP}{TP+FN}

 F_1-score=\frac{2*precision*recall}{precision+recall}

计算这些指标要涉及到下面这四个概念,而它们又构成了混淆矩阵

  • TP (True Positive)
  • FP (False Positive)
  • TN (True Negative)
  • FN (False Negative)
混淆矩阵实际值
01
预测值0TNFP
1FNTP

这里我给出的混淆矩阵是按照sklearn-metrics-confusion_matrix的形式绘制的。

Negative中文译作阴性,一般指标签0;Positive中文译作阳性,一般指标签1。

True中文译作预测正确;False中文译作预测错误。

TN 预测正确(True)并且实际为阴性(Negative)即实际值和预测值均为Negative

TP 预测正确(True)并且实际为阳性(Positive)即实际值和预测值均为Positive

FN 预测错误(False)并且实际为阴性(Negative)即实际值为Negative,预测值为Positive

FP 预测错误(False)并且实际为阳性(Positive)即实际值为Positive,预测值为Negative

下面以实际代码为例进行介绍

python">from sklearn import metrics
print(metrics.confusion_matrix(y_true=[0, 0, 0, 1, 1, 1],
    y_pred=[1, 1, 1, 0, 1, 0]))

这里的y_true是实际值,y_pred是预测值,可以观察到

TN=0,没有样本实际值和预测值同时为0

TP=1,只有第5个样本实际值和预测值均为1

FN=3,第1,2,3个样本实际值为0且预测值为1

FP=2,第4,6个样本实际值为1且预测值为0

输出结果也和我们观察的一致

[[0 3]
 [2 1]]

编写函数根据混淆矩阵计算 accuracy, precision, recall, f1-score

python">def cal(array):
    tp = array[1][1]
    tn = array[0][0]
    fp = array[0][1]
    fn = array[1][0]
    a = (tp+tn)/(tp+tn+fp+fn)
    p = tp/(tp+fp)
    r = tp/(tp+fn)
    f = 2*p*r/(p+r)
    print(a,p,r,f)

使用编写的函数cal计算该混淆矩阵的四项指标,并与metric自带的分类报告(classification_report)函数的结果进行比较,这里第三个参数digits=4表示保留4位小数

python">cal([[0, 3],[2, 1]])
print(metrics.classification_report(y_true=[0, 0, 0, 1, 1, 1], y_pred=[1, 1, 1, 0, 1, 0], digits=4))

运行结果如下,可以发现两者的计算结果一致。

0.16666666666666666 0.25 0.3333333333333333 0.28571428571428575
              precision    recall  f1-score   support

           0     0.0000    0.0000    0.0000         3
           1     0.2500    0.3333    0.2857         3

    accuracy                         0.1667         6
   macro avg     0.1250    0.1667    0.1429         6
weighted avg     0.1250    0.1667    0.1429         6

这里需要补充说明一下,为什么0那一行和1那一行都有precision, recall, f1-score。

一般来说,我们通常计算的这三项指标均是把1视为阳性,把0视为阴性,以1作为研究对象。所以1那一行的三项指标的值和cal函数计算的结果一致。而0那一行表示把0作为研究对象。


http://www.niftyadmin.cn/n/1371064.html

相关文章

完整mes代码(含客户端和server端_体系解读罗克韦尔MES平台FTPC-跟我入门MES/MOM系列特别篇...

写在面前前面我们介绍了西门子、罗克韦尔、施耐德、达索等巨头的MES/MOM平台:最全解读西门子MES/MOM平台Opcenter,100多亿美金的数字化之路Wonderware MES—施耐德MES/MOM平台解读关于罗克韦尔MES FTPC这个系列,今天是第三次了,先…

MongoDb 查看用户名列表 , 修改用户密码

修改用户密码:db.addUser("java","java");查看一下所有的用户 , 查看当前Db的用户名db.system.users.find();

python卡方CHI特征检验提取关键文本特征

理论 类别非类别包含单词的文档数AB不包含单词的文档数CD 卡方特征提取主要度量类别 和 单词之间的依赖关系。计算公式如下 其中N是文档总数,A是包含单词且属于的文档数,B是包含单词但不属的文档数,C是不包含单词但属于的文档数,…

joy to key 下载_索尼SIE公布2019年9月PS4北美欧洲游戏下载榜

索尼SIE公布PlayStation Store 北美/ 欧洲地区9月游戏下载排行榜:其中,NBA 2K20北美位居第一,而欧洲则是EA的FIFA 20高居首位。另外从排行榜中我们还看到了Fortnite堡垒之夜的身影,据报道称Fortnite堡垒之夜的付费收入规模相当可观…

MySQL 使用 MySQLDump 复制数据库

1.导出整个数据库mysqldump -u 用户名 -p 数据库名 > 导出的文件名 mysqldump -u wcnc -p smgp_apps_wcnc > wcnc.sql2.导出一个表mysqldump -u 用户名 -p 数据库名 表名> 导出的文件名mysqldump -u wcnc -p smgp_apps_wcnc users> wcnc_users.sql3.导出一个数据…

python将一个列表平均分为N份

输入列表list和切分后每个子列表的大小sub_list_size def split_list_to_nlist(list, sub_list_size):num 0tmp []nlist []for i in range(len(list)):if num sub_list_size:nlist.append(tmp)tmp []num 0tmp.append(list[i])num 1if tmp ! []:nlist.append(tmp)return…

蓝宝石rx470d原版bios_狼神矿卡烤机89°C!强刷蓝宝石RX570超白金显卡BIOS降温75°教程...

矿卡烤机89C劝退警告!笔者今天又收到一堆显卡,其中竟然有我最喜欢的狼神A1服务器标配的“3风扇信仰灯”——“大狼”Radeon RX570显卡!这张卡我在今年2月份就买过一款(不过当时是RX470 4GB版的)。笔者非常喜欢这张显卡,因为温度很…

[二分类模板]python对若干数据集重复10次实验取平均结果

这里以xgboost为例 from time import time import xgboost as xgb import utils def main():t time()projects [xxx1, xxx2, xxx3, xxx4,xxx5, xxx6, xxx7, xxx8, xxx9, xxx10]AA, PP, RR, FF 0, 0, 0, 0res {}repeat_times 10for project in projects:train_y, train_x,…