sklearn中一些简单机器学习算法的使用

news/2024/7/9 9:11:00 标签: 机器学习, sklearn, 算法

目录

前言

KNN算法

决策树算法

朴素贝叶斯算法

岭回归算法

线性优化算法


前言

本篇文章会介绍一些sklearn库中简单的机器学习算法如何使用,一些注释已经写在代码中,帮助一些小伙伴入门sklearn库的使用。

注意:本篇文章只涉及到如何使用,并不会讲解原理,如果想了解原理的小伙伴请自行搜索其他技术博客或者查看官方文档。

KNN算法

from sklearn.datasets import load_iris  # 导入莺尾花数据集的模块
from sklearn.model_selection import train_test_split # 导入划分数据集的模块
from sklearn.preprocessing import StandardScaler  # 导入标准化的模块
from sklearn.neighbors import KNeighborsClassifier  # 导入KNN算法的模块
from sklearn.model_selection import GridSearchCV  # 导入网格搜索和交叉验证的模块(判断k取几的时候KNN算法的准确率最高)

iris = load_iris()  # 引入数据集

x_train, x_test, y_train, y_test = train_test_split(iris.data, iris.target)  # 进行训练集和测试集的划分

transfer = StandardScaler()  # 标准化操作
x_train = transfer.fit_transform(x_train)
x_test = transfer.transform(x_test)

estimator = KNeighborsClassifier()  # KNN算法
param_dict = {'n_neighbors': [1, 3, 5, 7, 9, 11]}  # 以字典的形式传入
estimator = GridSearchCV(estimator, param_grid=param_dict,cv=10)  # 网格搜索
estimator.fit(x_train, y_train)
y_predict = estimator.predict(x_test)
print(y_predict)
print(y_predict == y_test)
r = estimator.score(x_test, y_test)
print('准确率:', r)
print('最佳参数:', estimator.best_params_)
print('最佳结果:', estimator.best_score_)
print('最佳估计器:', estimator.best_estimator_)
print('交叉验证结果:', estimator.cv_results_)

决策树算法

from sklearn.datasets import load_iris  # 导入莺尾花数据集的模块
from sklearn.model_selection import train_test_split  # 导入划分数据集的模块
from sklearn.tree import DecisionTreeClassifier  # 导入决策树算法的模块
from sklearn import tree  # 导入决策树可视化的模块
import matplotlib.pyplot as plt

iris = load_iris()  # 引入数据集

x_train, x_test, y_train, y_test = train_test_split(iris.data, iris.target)  # 进行训练集和测试集的划分

estimator = DecisionTreeClassifier(criterion='entropy')  # 按照信息增益决定特征分别位于树的那层
estimator.fit(x_train, y_train)
y_predict = estimator.predict(x_test)
print(y_predict)
print(y_predict == y_test)
r = estimator.score(x_test, y_test)
print('准确率:', r)

plt.figure(figsize=(10, 10))
tree.plot_tree(estimator, feature_names=iris.feature_names)  # 决策树可视化
plt.show()

 

 朴素贝叶斯算法

# 计算概率,那种的概率大就把它划分为那种
from sklearn.datasets import load_iris  # 导入莺尾花数据集的模块
from sklearn.model_selection import train_test_split  # 导入划分数据集的模块
from sklearn.naive_bayes import MultinomialNB  # 导入朴素贝叶斯算法的模块

iris = load_iris()  # 引入数据集

x_train, x_test, y_train, y_test = train_test_split(iris.data, iris.target)  # 进行训练集和测试集的划分

estimator = MultinomialNB()  # 朴素贝叶斯算法
estimator.fit(x_train, y_train)
y_predict = estimator.predict(x_test)
print(y_predict)
print(y_predict == y_test)
r = estimator.score(x_test, y_test)
print('准确率:', r)

 

岭回归算法

# 用岭回归对波士顿房价进行预测
from sklearn.datasets import load_boston  # 导入波士顿房价的模块
from sklearn.model_selection import train_test_split  # 导入数据集划分的模块
from sklearn.preprocessing import StandardScaler  # 导入标准化的模块
from sklearn.linear_model import Ridge  # 导入岭回归算法的模块
from sklearn.metrics import mean_squared_error  # 导入均方误差的模块

boston = load_boston()
print('特征数量:', boston.data.shape)

x_train, x_test, y_train, y_test = train_test_split(boston.data, boston.target, random_state=22)  # 进行数据集划分,最后一个参数是设定随机数种子

transfer = StandardScaler()
x_train = transfer.fit_transform(x_train)
x_test = transfer.transform(x_test)

estimator = Ridge()
estimator.fit(x_train, y_train)
y_predict = estimator.predict(x_test)
error = mean_squared_error(y_test, y_predict)
print('岭回归-权重系数(k)为:', estimator.coef_)
print('岭回归-偏置(b)为:', estimator.intercept_)
print('岭回归-均方误差为:', error)

线性优化算法

# 几个特征对应几个权重系数:y=k1x1+k2x2+k3x3+k4x4+.....+knxn+b
# 对波士顿房价进行预测
# 正规方程优化算法和梯度下降优化算法
from sklearn.datasets import load_boston  # 导入波士顿房价的模块
from sklearn.model_selection import train_test_split  # 导入数据集划分的模块
from sklearn.preprocessing import StandardScaler # 导入标准化的模块
from sklearn.linear_model import LinearRegression, SGDRegressor  # 导入正规方程算法和梯度下降算法的模块
from sklearn.metrics import mean_squared_error  # 导入均方误差的模块(判断两个算法那个更优,均方误差越小的算法越优)

boston = load_boston()
print('特征数量:', boston.data.shape)

x_train, x_test, y_train, y_test = train_test_split(boston.data, boston.target, random_state=22)  # 进行数据集划分,最后一个参数是设定随机数种子

transfer = StandardScaler()
x_train = transfer.fit_transform(x_train)
x_test = transfer.transform(x_test)


estimator = LinearRegression()
estimator.fit(x_train, y_train)
y_predict = estimator.predict(x_test)
error=mean_squared_error(y_test, y_predict)
print('正规方程-权重系数(k)为:', estimator.coef_)
print('正规方程-偏置(b)为:', estimator.intercept_)
print('正规方程-均方误差为:', error)


estimator = SGDRegressor()
estimator.fit(x_train, y_train)
y_predict = estimator.predict(x_test)
error = mean_squared_error(y_test, y_predict)
print('梯度下降-权重系数(k)为:', estimator.coef_)
print('梯度下降-偏置(b)为:', estimator.intercept_)
print('梯度下降-均方误差为:', error)


http://www.niftyadmin.cn/n/5372273.html

相关文章

javaEE - 22( 5000 字 Tomcat 和 HTTP 协议入门 -3)

一:Tomcat 1.1 Tomcat 是什么 谈到 “汤姆猫”, 大家可能更多想到的是大名鼎鼎的这个: 事实上, Java 世界中的 “汤姆猫” 完全不是一回事, 但是同样大名鼎鼎. Tomcat 是一个 HTTP 服务器. 前面我们已经学习了 HTTP 协议, 知道了 HTTP 协议就是 HTTP 客户端和…

Backtrader 文档学习- Observers - Reference

Backtrader 文档学习- Observers - Reference 1.Benchmark class backtrader.observers.Benchmark() 观察器存储策略的回报和参考资产的回报,参考资产是传递给系统的数据之一。 参数: timeframe (default: None) ,如果None,则将…

已解决org.springframework.web.HttpMediaTypeNotAcceptableException异常的正确解决方法,亲测有效!!!

已解决org.springframework.web.HttpMediaTypeNotAcceptableException异常的正确解决方法,亲测有效!!! 文章目录 问题分析 报错原因 解决思路 解决方法 总结 问题分析 在Spring MVC应用中处理HTTP请求时,我们有…

Blazor Wasm Google 登录

目录: OpenID 与 OAuth2 基础知识Blazor wasm Google 登录Blazor wasm Gitee 码云登录Blazor SSR/WASM IDS/OIDC 单点登录授权实例1-建立和配置IDS身份验证服务Blazor SSR/WASM IDS/OIDC 单点登录授权实例2-登录信息组件wasmBlazor SSR/WASM IDS/OIDC 单点登录授权实例3-服务端…

itextpdf使用:使用PdfReader添加图片水印

gitee参考代码地址:https://gitee.com/wangtianwen1996/cento-practice/tree/master/src/test/java/com/xiaobai/itextpdf 参考文章:https://www.cnblogs.com/wuxu/p/17371780.html 1、生成带有文字的图片 使用java.awt包的相关类生成带文字的图片&…

数据结构第十一天(栈)

目录 前言 概述 源码: 主函数: 运行结果: ​编辑 前言 今天简单的实现了栈,主要还是指针操作,soeasy! 友友们如果想存储其他内容,只需修改结构体中的内容即可。 哈哈,要是感觉不错&…

Amazon Dynamo学习总结

目录 一、Amazon Dynamo的问世 二、Amazon Dynamo主要技术概要 三、数据划分算法 四、数据复制 五、版本控制 六、故障处理 七、成员和故障检测 一、Amazon Dynamo的问世 Amazon Dynamo是由亚马逊在2007年开发的一种高度可扩展和分布式的键值存储系统,旨在解…

Vue3中路由配置Catch all routes (“*“) must .....问题

Vue3中路由配置Catch all routes (“*”) must …问题 文章目录 Vue3中路由配置Catch all routes ("*") must .....问题1. 业务场景描述1. 加载并添加异步路由场景2. vue2中加载并添加异步路由(OK)3. 转vue3后不好使(Error)1. 代码2. 错误 2. 处理方式1. 修改前2. 修…