Keras使用sklearn中的交叉验证和网格搜索

news/2024/7/9 8:34:28 标签: keras, sklearn, 人工智能

Keras是Python在深度学习领域非常受欢迎的第三方库,但Keras的侧重点是深度学习,而不是所以的机器学习。事实上,Keras力求极简主义,只专注于快速、简单地定义和构建深度学习模型所需要的内容。Python中的scikit-learn是非常受欢迎的机器学习库,它基于Scipy,用于高效的数值计算。scikit-learn是一个功能齐全的通用机器学习库,并提供了许多在开发深度学习过程中非常有帮助的方法。例如scikit-learn提供了很多用于选择模型和对模型调参的方法,这些方法同样适用于深度学习。

Keras提供了一个Wrapper,将Keras的深度学习模型包装成scikit-learn中的分类模型或回归模型,以便于使用scikit-learn中的方法和函数。对于深度学习模型的包装是通过KerasClassifier(分类模型)和KerasRegressor(回归模型)来实现的。KerasClassifier和KerasRegressor类使用参数build_fn,指定用来创建模型的函数的名称。

Keras的一般构建流程:

model = Sequential() # 定义模型
model.add(Dense(units=64, activation='relu', input_dim=100)) # 定义网络结构
#第一层网络:输出尺寸64,输入尺寸100,activation激活函数relu
model.add(Dense(units=10, activation='softmax')) # 定义网络结构
#第二层网络:输出尺寸10,输入是上一层的输出尺寸64,activation激活函数softmax
model.compile(loss='categorical_crossentropy', # 定义loss函数、优化方法、评估标准
              optimizer='sgd',
              metrics=['accuracy'])
#输入训练样本和标签,迭代5次,每次迭代32个数据
model.fit(x_train, y_train, epochs=5, batch_size=32) # 训练模型
loss_and_metrics = model.evaluate(x_test, y_test, batch_size=128) # 评估模型
classes = model.predict(x_test, batch_size=128) # 使用训练好的数据进行预测

参数意义:

keras.layers.Dense(units, activation=None, use_bias=True, kernel_initializer='glorot_uniform', bias_initializer='zeros', kernel_regularizer=None, bias_regularizer=None, activity_regularizer=None, kernel_constraint=None, bias_constraint=None)

units: 正整数,输出空间维度。
activation: 激活函数。 若不指定,则不使用激活函数 (即,「线性」激活: a(x) = x)。
use_bias: 布尔值,该层是否使用偏置向量。
kernel_initializer: kernel 权值矩阵的初始化器。
bias_initializer: 偏置向量的初始化器。
kernel_regularizer: 运用到 kernel 权值矩阵的正则化函数 。
bias_regularizer: 运用到偏置向的的正则化函数 。
activity_regularizer: 运用到层的输出的正则化函数 。
kernel_constraint: 运用到 kernel 权值矩阵的约束函数 。
bias_constraint: 运用到偏置向量的约束函数。

Keras调用scikit-learn实现交叉验证:

from keras.models import Sequential
from keras.layers import Dense
import numpy as np
import pandas as pd
from sklearn.model_selection import cross_val_score, KFold
from keras.wrappers.scikit_learn import KerasClassifier


def creat_model():
    # 构建模型
    model = Sequential()
    model.add(Dense(units=12, input_dim=11, activation='relu'))
    model.add(Dense(units=8, activation='relu'))
    model.add(Dense(units=1, activation='sigmoid'))
    # 模型编译
    model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])
    return model

# 导入数据
data = pd.read_csv('data.csv',encoding='gbk')

# 删除id列
data.drop('客户编号',axis=1,inplace=True)

X, Y = data.values[:,:-1], data.values[:,-1] 

# Keras调用sklearn
model = KerasClassifier(build_fn=creat_model, epochs=150, batch_size=10, verbose=0)

# 10折交叉验证
kfold = KFold(n_splits=10, shuffle=True, random_state=10)
result = cross_val_score(model, X, Y, cv=kfold)

 Keras调用scikit-learn实现模型调参

在构建深度学习模型时,如何配置一个最优模型一直是进行一个项目的重点。在机器学习中,可以通过算法自动调优这些配置参数,在这里将通过Keras的包装类,借助scikit-learn的网格搜索算法评估神经网络模型的不同配置,并找到最佳评估性能的参数组合。creat_model()函数被定义为具有两个默认值的参数(optimizer和init)的函数,创建模型后,定义要搜索的参数的数值数组,包括优化器(optimizer)、权重初始化方案(init)、epochs和batch_size。

在scikit-learn中的GridSearchCV需要一个字典类型的字段作为需要调整的参数,默认采用3折交叉验证来评估算法,由于4个参数需要进行调参,因此将会产生4✖️3个模型。
Keras调用scikit-learn实现GridSearchCV网格搜索:

from keras.models import Sequential
from keras.layers import Dense
import numpy as np
import pandas as pd
from sklearn.model_selection import GridSearchCV
from keras.wrappers.scikit_learn import KerasClassifier


def creat_model(optimizer='adam,init='glorot_uniform'):
    # 构建模型
    model = Sequential()
    model.add(Dense(units=12, input_dim=11,kernel_initializer=init, activation='relu'))
    model.add(Dense(units=8, kernel_initializer=init, activation='relu'))
    model.add(Dense(units=1, kernel_initializer=init, activation='sigmoid'))
    # 模型编译
    model.compile(loss='binary_crossentropy', optimizer=optimizer, metrics=['accuracy'])
    return model

# 导入数据
data = pd.read_csv('data.csv',encoding='gbk')

# 删除id列
data.drop('客户编号',axis=1,inplace=True)

X, Y = data.values[:,:-1], data.values[:,-1] 

# Keras调用sklearn
model = KerasClassifier(build_fn=creat_model, verbose=0)

# 构建需要调整的参数
param_gird = {}
param_grid['optimizer'] = ['rmsprop','adam']
param_grid['init'] = ['glorot_uniform', 'normal', 'uniform']
param_gird['epochs'] = [50, 100, 150, 200]
param_gird['batch_size'] = [5, 10, 20]

# 调参
grid = GridSearchCV(estimator=model, param_gird=param_grid)
result = grid.fit(X, Y)

# 输出结果
print('Best: %f using %s' % (result.best_score_, result.best_params_))

关于Epochs和batch_size的解释?

Epochs是神经网络训练过程中的一个重要超参数,定义为向前和向后传播中所有批次的单次训练迭代。简单说,一个Epoch是将所有的数据输入网络完成一次向前计算及反向传播。在训练过程中,数据会被“轮”多少次,即应当完整遍历数据集多少次(一次为一个Epoch)。如果Epoch数量太少,网络有可能发生欠拟合(即对于定型数据的学习不够充分);如果Epoch数量太多,则有可能发生过拟合(即网络对定型数据中的“噪声”而非信号拟合)。所以,选择适当的Epoch数量需要在充分训练和避免过拟合之间找到平衡。
 

假设我们有1000个数据样本,每次我们送入10个数据进行训练(也就是batch_size为10)。那么完成一个Epoch,我们需要进行100次迭代(也就是100次前向传播和100次反向传播)。具体来说,我们需要将所有的数据都送入神经网络进行一次前向传播和反向传播,所以一次Epoch相当于所有数据集/batch size=N次迭代。
 


http://www.niftyadmin.cn/n/5276651.html

相关文章

详解数据科学自动化与机器学习自动化

过去十年里,人工智能(AI)构建自动化发展迅速并取得了多项成就。在关于AI未来的讨论中,您可能会经常听到人们交替使用数据科学自动化与机器学习自动化这两个术语。事实上,这些术语有着不同的定义:如今的自动…

npm run dev 与npm run serve的区别

npm run serve 和 npm run dev 是在开发阶段使用 npm 运行脚本的两种常见命令,它们的区别主要在于脚本的配置和执行方式。 npm run serve:通常与 Vue.js 相关的项目中使用。这个命令是在 package.json 文件中定义的一个脚本命令,用来启动开发…

嵌入式串口输入详细实例

学习目标 掌握串口初始化流程掌握串口输出单个字符掌握串口输出字符串掌握通过串口printf熟练掌握串口开发流程学习内容 需求 串口循环输出内容到PC机。 串口数据发送 添加Usart功能。 首先,选中Firmware,鼠标右键,点击Manage Project Items 接着,将gd32f4xx_usart.c添…

多门店自助点餐+外卖二合一小程序系统源码:自助点餐+外卖配送 带完整搭建教程

互联网的普及和移动支付的便捷,餐饮行业也在经历着数字化转型。小编来给大家介绍一款多门店自助点餐外卖二合一小程序,带完整的搭建教程。 以下是部分代码示例: 系统特色功能一览: 1.多门店管理:支持一个平台管理多个…

UE5 runtime模式下自定义视口大小和位置并跟随分辨率自适应缩放

本文旨在解决因UI问题导致屏幕中心位置不对的问题 处理前的现象:如果四周UI透明度都为1,那么方块的位置就不太对,没在中心 处理后的现象: 解决办法:自定义大小和视口偏移 创建一个基于子系统的类或者蓝图函数库(什么类…

分布式全局ID之雪花算法

系列文章目录 提示:这里可以添加系列文章的所有文章的目录,目录需要自己手动添加 雪花算法 提示:写完文章后,目录可以自动生成,如何生成可参考右边的帮助文档 文章目录 系列文章目录前言一、什么是雪花算法&#xff1f…

【JavaWeb】Session JSP(学习笔记)

Session 一、Session概述 Session:在一次会话的多次请求间共享数据,将数据保存在服务器端的对象HttpSession中 原理:Session的实现是依赖于Cookie的 二、Session使用 1、获取HttpSession对象 HttpSession session request.getSession(…

WSL移动ubuntu到其他盘的几个问题记录

ubuntu移动到其他盘可以参考WSL Ubuntu子系统迁移到非系统盘 下面问题是我安装时遇到的,不代表所有人都是这个问题。 执行ubuntu.exe config --default-user username时,会创建一个新的ubuntu系统出来,而不是直接链接到我们import进来的ubu…