Sklearn线性回归

news/2024/7/9 8:44:15 标签: sklearn, 线性回归, 机器学习

Scikit-learn 中的线性回归是一个用于监督学习的算法,它用于拟合数据集中的特征和目标变量之间的线性关系。以下是使用 Scikit-learn 实现线性回归的基本步骤:

1. 导入所需库

首先,你需要导入所需的库和模块。

import numpy as np
import matplotlib.pyplot as plt
from sklearn.linear_model import LinearRegression
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error

2. 准备数据

接下来,你需要准备数据集,通常包括特征和目标变量。

# 假设 x 是特征集,y 是目标变量
x = np.array([[1], [2], [3], [4], [5]])
y = np.array([1, 2, 3, 4, 5])

3. 划分训练集和测试集

为了评估模型的性能,通常需要将数据集划分为训练集和测试集。

x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.2, random_state=0)

4. 创建线性回归模型

然后,你需要创建一个线性回归模型实例。

linear_regression = LinearRegression()

5. 训练模型

使用训练集数据训练模型。

linear_regression.fit(x_train, y_train)

6. 预测

使用训练好的模型对测试集进行预测。

y_pred = linear_regression.predict(x_test)

7. 评估模型

评估模型的性能,通常使用均方误差(MSE)作为评估指标。

mse = mean_squared_error(y_test, y_pred)
print(f'Mean Squared Error: {mse}')

8. 可视化

可选步骤,使用散点图可视化实际值和预测值。

plt.scatter(x_test, y_test, color='blue')
plt.plot(x_test, y_pred, color='red')
plt.title('Linear Regression')
plt.xlabel('X')
plt.ylabel('Y')
plt.show()

9. 模型持久化(可选)

如果你需要保存训练好的模型,可以使用 joblib 库将其保存到文件,以后可以重新加载。

import joblib
# 保存模型
joblib.dump(linear_regression, 'linear_regression_model.joblib')
# 加载模型
loaded_model = joblib.load('linear_regression_model.joblib')

以上就是使用 Scikit-learn 进行线性回归分析的基本步骤。需要注意的是,线性回归假设特征和目标变量之间存在线性关系,实际应用中需要根据数据特点进行适当的预处理和特征选择。


http://www.niftyadmin.cn/n/5424155.html

相关文章

P1216 [USACO1.5] [IOI1994]数字三角形 Number Triangles

题目链接:P1216 [USACO1.5] [IOI1994]数字三角形 Number Triangles - 洛谷 | 计算机科学教育新生态 (luogu.com.cn) 解题思路: 最优路径问题,首先想到dfs深度优先搜索,一直往下走再回溯上一格换个方向走 下面是c代码&#xff1a…

深度强化学习(三)(DQN)

深度强化学习(三)DQN与Q学习 一.DQN 通过神经网络来近似最优动作价值函数 Q ∗ ( a t , s t ) Q_*(a_t,s_t) Q∗​(at​,st​),在实践中, 近似学习“先知” Q ⋆ Q_{\star} Q⋆​ 最有效的办法是深度 Q \mathrm{Q} Q网络 (deep Q network, 缩写 DQN)…

前缀和----指定区间求和

思想:前n项和 一维数据 方法一: 暴力求解,时间复杂度:O(n * m) 代码: 输入一个长度为n的整数序列。接下来再输入m个询问,每个询问输入一对l, r。对于每个询问,输出原序列中从第l个数到第r个…

逆向案例七——中国天气质量参数搜不到加密,以及应对禁止打开开发者工具和反debuger技巧

进入相关城市数据页面,发现不能调试 应对方法,再另一个页面,打开开发者工具,选择取消停靠到单独页面 接着,复制链接在该页面打开。接着会遇到debugger 再debugger处打上断点,一律不在此处暂停。 然后点击继…

蓝桥杯(3.11)

1233. 全球变暖 import java.util.Deque; import java.util.LinkedList; import java.util.Scanner;public class Main{static int n;static final int N 1010;static char[][] g new char[N][N];static boolean[][] st new boolean[N][N];public static boolean bfs(int s…

Tictoc3例子

在tictoc3中,实现了让 tic 和 toc 这两个简单模块之间传递消息,传递十次后结束仿真。 首先来介绍一下程序中用到的两个函数: 1.omnetpp中获取模块名称的函数 virtual const char *getName() const override {return name ? name : "&q…

wait 和 notify方法

目录 1.1 wait()方法 wait 做的事情: wait 结束等待的条件: 1.2 notify()方法 1.3notifyAll方法 1.4wait()和sleep()对比 由于线程之间是抢占式执行的, 因此线程之间执行的先后顺序难以预知. 但是实际开发中有时候我们希望合理的协调多个线程之间的执行先后顺序. 完成这个协调…

【Mysql】事务与索引

目录 MySQL事务 事务的特性 并发事务的问题? 事务隔离级别? MySQL索引 数据结构 索引类型 聚簇索引与非聚簇索引 聚集索引的优点 聚集索引的缺点 非聚集索引的优点 非聚集索引的缺点 非聚集索引一定回表查询吗(覆盖索引)? 覆盖索引 联合索…