使用sklearn.linear_model.SGDClassifier增量模型进行学习的记录

news/2024/7/9 11:03:53 标签: sklearn, 机器学习

数据集下载链接是Human Activity Recognition Using Smartphones

 

train、test文件夹中分别包含训练和测试的文件,这里使用train中的数据进行增量学习模型,test中的数据用来测试
首先读取数据:

import numpy as np

from sklearn.linear_model import SGDClassifier

import matplotlib.pyplot as plt

plt.rcParams['font.sans-serif'] = ['SimHei']  # 指定默认字体

plt.rcParams['axes.unicode_minus'] = False  # 解决保存图像是负号'-'显示为方块的问题

x_train=[]

f=open(r'./X_train.txt','r')

for i in f.readlines():

    i=i.strip()

    temp=i.split(' ')

    while '' in temp:

        temp.remove('')

    x_train.append(temp)

f.close()

x_train=np.array(x_train)

x_test=[]

f=open(r'./X_test.txt','r')

for i in f.readlines():

    i=i.strip()

    temp=i.split(' ')

    while '' in temp:

        temp.remove('')

    x_test.append(temp)

f.close()

x_test=np.array(x_test)

y_train=[]

f = open(r'./y_train.txt', 'r')

for i in f.readlines():

    i = i.strip()

    y_train.append(i)

f.close()

y_test=[]

f = open(r'./y_test.txt', 'r')

for i in f.readlines():

    i = i.strip()

    y_test.append(i)

f.close()

print(x_train.shape,end=' ')

print(x_test.shape,end=' ')

print(len(y_train),len(y_test),set(y_train+y_test))

输出结果为:

(7352, 561) (2947, 561) 7352 2947 {'4', '5', '3', '1', '2', '6'}

开始增量训练:

x_train=x_train.astype(np.float32)

x_test=x_test.astype(np.float32)

classes=np.unique(y_train+y_test)

interval=100

start=0

sgd_clf = SGDClassifier()

x_axis=[]

y_axis=[]

for i in np.arange(1,(x_train.shape[0]//interval+1),1):

    end=min([i*interval,x_train.shape[0]])

    X=x_train[start:end]

    Y=y_train[start:end]

    sgd_clf.partial_fit(X,Y,classes=classes) #

    start=end

    # print("{} time".format(i))  # 当前次数

    score=sgd_clf.score(x_test, y_test)

    # print("{} score".format(score))  # 在测试集上看效果

    x_axis.append(i)

    y_axis.append(score)

0.3505259586019681 score
0.49338310145911096 score
0.4832032575500509 score
0.497794367153037 score
0.46623685103495077 score
0.4696301323379708 score
...

绘制迭代次数-score图:

plt.figure()

plt.plot(x_axis,y_axis)

plt.xlabel('迭代的次数')

plt.ylabel('score')

plt.tight_layout()

plt.savefig('./score.png',bbox_inches='tight')

plt.show()


参考内容:
使用sklearn进行增量学习
sklearn.linear_model.SGDClassifier — scikit-learn 1.0 documentation


http://www.niftyadmin.cn/n/1325330.html

相关文章

第12届嵌入式蓝桥杯真题-停车场管理系统的设计与实现

目录 实验要求: 实验思路: 核心代码: (1)主函数 (2)lcd显示 (3)按键函数 (4)LED显示函数 (5)业务处理函数 &…

YOLO论文阅读记录

YOLO:You Only Look Once: Unified, Real-Time Object Detection。you only look once,仅仅看一眼就能检测出来结果,说明速度很快而且是单阶段的。 论文的链接:https://arxiv.org/abs/1506.02640 YOLO将目标检测看做是回归问题&am…

YOLOv3论文阅读记录

YOLOv3:An Incremental Improvement 论文的链接:https://arxiv.org/pdf/1804.02767.pdf 这个版本的更新有很多改动尤其是网络方面的,网络变得更大也更准确。320 * 320的图片作为YOLOv3的输入,每张图片的平均运行时间是22ms&#x…

YOLOv2、YOLO9000论文阅读记录

YOLO9000:Better, Faster, Stronger 论文的链接:https://arxiv.org/abs/1612.08242 YOLO9000是基于YOLOv2架构的,文中先介绍YOLOv2,然后再介绍的YOLO9000 YOLO9000是实时的检测系统,能检测约9000类物体。YOLOv2是基于Y…

YOLOv2、v3使用K-means聚类计算anchor boxes的具体方法

k-means需要有数据,中心点个数是需要人为指定的,位置可以随机初始化,但是还需要度量到聚类中心的距离。这里怎么度量这个距离是很关键的。 距离度量如果使用标准的欧氏距离,大盒子会比小盒子产生更多的错误。例 ​。因此这里使用其…

模式识别-期末复习简答题(87个知识点、问题集锦|已完结)

单选题、判断题、简答题、计算题、综合题 ① 课前测的题目 ② 87个知识点 1.什么是模式?监督模式识别和非监督模式识别的典型过程分别是什么? 模式:指需要识别且可测量的对象的描述 监督模式识别:分析问题→原始特征提取→特征提取与选择→分类器…

Windows10安装Linux子系统-不使用虚拟机安装ubuntu16.04

使用Win键打开程序面板,然后进入Windows系统,打开控制面板——点击卸载程序——点击左侧的启用或关闭Windows功能,勾选适用于Linux的Windows子系统,然后点击确定,等待下载完成,然后需要重启电脑。 ​ ​ …

物联网通信技术原理第5章 移动通信技术

目录 5.1 移动通信的基本概念及发展历史 5.1.1 移动通信的基本概念 5.1.2 移动通信的发展历史(理解) 1.第一代移动通信系统(1G) 2.第二代移动通信系统(2G) 3.第三代移动通信系统(3G) 5.1.3 移动通信的发展趋势与展望 5.2 无线传播与移动信道 5.2…