自制机器学习工具库源码解释(KNN线性回归)

news/2024/7/9 8:53:50 标签: 机器学习, 线性回归, sklearn, python

欢迎大家提出宝贵意见

    • 源码:
    • 安装方法以及更新命令
      • 安装
      • 更新
    • 测试代码以及截图(部分)

源码:

由于刚开始写,我只写了一部分。

简易KNN

python">def GetKNNSoreByN(X, y, n_neighbors):
    """
    :param X: data 特征值
    :param y: aim 目标值
    :param n_neighbors K值
    :return: score 预测结果
    """
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)
    transform = StandardScaler()
    X_train = transform.fit_transform(X_train)
    X_test = transform.fit_transform(X_test)
    clf = KNeighborsClassifier(n_neighbors=n_neighbors)
    clf.fit(X_train, y_train)
    y_pre = clf.predict(X_test)
    return sum(y_pre == y_test) / y_pre.shape[0]

解释:这里数据标准化使用的是StandardScaler、对传进来的数据集进行了切割,实际上数据足够多的话没必要进行切割,但这里考虑到数据较少的情况进行了切割。

网格搜索版KNN

python">def GetKNNScoreByGridSearchCV(X, y, param_grid: dict={'n_neighbors': [i for i in range(1,10,1)]}):
    """
    :param X:data 特征值
    :param y:aim 目标值
    :param param_grid: GridSearchCV param 传递给GridSearchCV的参数
    :return: best params and best score for KNN最好的参数表以及最佳准确率
    """
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)
    transform = StandardScaler()
    X_train = transform.fit_transform(X_train)
    X_test = transform.fit_transform(X_test)
    estimator = KNeighborsClassifier()
    estimator = GridSearchCV(estimator, param_grid=param_grid, cv=5, verbose=0)
    estimator.fit(X_train, y_train)

    return estimator.best_params_, estimator.best_score_

说明:默认的CV为5,verbose=0,这个是经验判断,若有好的建议,可以一起讨论
补充:一般情况K值不超过10,所以默认值给的1到10

正则版线性回归

python">def linear_model_regular(data):
    """
    regular
    :param data: data 数据
    :return: coef 系数列表, intercept 截距, error 均方误差
    """
    x_train, x_test, y_train, y_test = train_test_split(data.data, data.target, random_state=0)
    transfer = StandardScaler()
    x_train = transfer.fit_transform(x_train)
    x_test = transfer.transform(x_test)
    estimator = LinearRegression()
    estimator.fit(x_train, y_train)
    y_predict = estimator.predict(x_test)
    error = mean_squared_error(y_test, y_predict)  # 均方误差

    return estimator.coef_, estimator.intercept_, error

梯度下降版线性回归

python">def linear_model_gradient(data):
	"""
	gradient descent
    :param data: data 数据
    :return: coef 系数列表, intercept 截距, error 均方误差
    """
    x_train, x_test, y_train, y_test = train_test_split(data.data, data.target, random_state=0)
    transfer = StandardScaler()
    x_train = transfer.fit_transform(x_train)
    x_test = transfer.transform(x_test)
    estimator = SGDRegressor(max_iter=1000)
    estimator.fit(x_train, y_train)
    y_predict = estimator.predict(x_test)
    error = mean_squared_error(y_test, y_predict)  # 均方误差

    return estimator.coef_, estimator.intercept_, error

防止重复库

python">def setup_module(module):
    """
	Prevent multiple uses of the same library
	"""
    # Check if a random seed exists in the environment, if not create one.
    _random_seed = os.environ.get('SKLEARN_SEED', None)
    if _random_seed is None:
        _random_seed = np.random.uniform() * np.iinfo(np.int32).max
    _random_seed = int(_random_seed)
    print("I: Seeding RNGs with %r" % _random_seed)
    np.random.seed(_random_seed)
    random.seed(_random_seed)

安装方法以及更新命令

安装

python">pip install techlearn

更新

python">pip3 install --upgrade techlearn

测试代码以及截图(部分)

KNN&鸢尾花

python">X, y = datasets.load_iris(return_X_y=True)
param_grid = {'n_neighbors': [1, 3, 5, 7]}
print(GetKNNSoreByN(X, y, 3))
print(GetKNNScoreByGridSearchCV(X, y, param_grid=param_grid))
print(GetKNNScoreByGridSearchCV(X, y))

在这里插入图片描述

线性回归&波士顿房价预测

python">    data = load_boston()
    coef_, intercept_, error = linear_model_gradient(data=data)
    print("模型中的系数为:\n", coef_)
    print("模型中的偏置为:\n", intercept_)
    print("误差为:\n", error)

在这里插入图片描述


http://www.niftyadmin.cn/n/838546.html

相关文章

快速构建Windows 8风格应用14-ShareContract概述及原理

本篇博文主要介绍Share Contract概述、Share Contract实现原理、实现Share Contract意义。 Share Contract概述 我们都知道Windows 8中包含3类不同的Contract:Search Contract、Share Contract、Setting Contract。这三种Application Contract为整合Windows 8体验提…

数据训练的时候出错:ValueError: Unknown label type: ‘continuous‘

解决方法: 使用.astype(‘int’) 将label转换为int型 建议:了解一下onehot

ORACLE 深入解析10053事件

新年新说:新年伊始,2012年过去了,我们又踏上了2013年的,回顾2012我们付出了很多,辛勤和汗水换来了知识和友谊,当我们技术成长的时候我才发现长路漫漫,唯心可敬。一份耕耘一份收获,走…

win mysql5.7安装_Win下Mysql5.7安装详解

安装mysql-5.7.10-win32.msi(解压)至安装目录(本人为C:\Program Files\MySQL\MySQL Server 5.7),新建my.ini,复制下面内容到my.ini文件中(注意:修改basedir,datadir为你的安装目录)# For advice on how to change settings please…

MariaDB 10审计日志去除记录select操作

默认情况下,记录select操作完全是蛋疼的功能,记录一些没有必要的操作。 去除以后,只会记录增删改、DDL操作。 安装使用: 把附件里的server_audit.so放入到/usr/local/mysql/lib/plugin/ 并chown mysql.mysql server_audit.so 123&…

自制库更新

欢迎大家提出宝贵意见更新代码Bunch转化Dataframe逻辑回归更新代码 Bunch转化Dataframe def TransformBunchToDataFrame(data):""":param data: sklearn‘s Bunch SKlearn的数据集格式:return: DataFrame pandas常用数据格式"""data, data[tar…

kettle mysql 时间_kettle中通过时间戳(timestamp)方式来实现数据库的增量同步_MySQL...

bitsCN.com这个实验主要思想是在创建数据库表的时候,通过增加一个额外的字段,也就是时间戳字段,例如在同步表 tt1 和表 tt2 的时候,通过检查那个表是最新更新的,那个表就作为新表,而另外的表最为旧表被新表…

GlusterFS 存储结构原理介绍

一、分布式文件系统 分布式文件系统(Distributed File System)是指文件系统管理的物理存储资源并不直接与本地节点相连,而是分布于计算网络中的一个或者多个节点的计算机上。目前意义上的分布式文件系统大多都是由多个节点计算机构成&#xf…