sklearn-6算法链与管道

news/2024/7/9 9:11:00 标签: sklearn, 人工智能, python

思想类似于pipeline,将多个处理步骤连接起来。

看个例子,如果用MinMaxScaler和训练模型,需要反复执行fit和tranform方法,很繁琐,然后还要网格搜索,交叉验证

1 预处理进行参数选择

对于放缩的数据,一定要最先开始放缩,否则会导致数据放缩不一致,比如SVM+网格交叉,网格需要放缩数据,数据放缩需要带上测试集,否则性能下降,准确率打折扣

2 构造管道

注意 管道每次会调用scaler的fit方法,注意,可以对同一个scaler调多次fit,但不可以用两个或多个scaler单独放缩数据!!!

只需要几行代码,很简便

python">    def test_chain_basic(self):
        xtr, xte, ytr, yte = train_test_split(self.cancer.data, self.cancer.target, random_state=0)
        pipe = Pipeline([('scaler', MinMaxScaler()), ('svm', SVC())]).fit(xtr, ytr)
        print(f'predict score: {pipe.score(xte, yte)}')

注意,给pipeline传参是个列表,列表项是长度为2的元组,元组第一个是字串,自定义,类似于一个名字,元组第二个参数是模型对象

3 网格搜索中使用管道

用法

1类似于上一节的scaler+监督模型的用法

2有个注意点是网格搜索需要给训练的模型传参,需要改下给grid对象传参字典的键名

3注意网格搜索是pipe作为参数传给GridSearchCV,二补数把pipe作为参数传给pipe

例子

python">    def test_chain_scale_train_grid(self):
        xtr, xte, ytr, yte = train_test_split(self.cancer.data, self.cancer.target, random_state=0)
        pipe = Pipeline([('scaler', MinMaxScaler()), ('svm', SVC())]).fit(xtr, ytr)
        print(f'predict score: {pipe.score(xte, yte)}')
        params_grid = {'svm__C': [0.001, 0.01, 0.1, 1, 10, 100],
                       'svm__gamma': [0.001, 0.01, 0.1, 1, 10, 100]}
        grid = GridSearchCV(pipe, params_grid, cv=5).fit(xtr, ytr).fit(xtr, ytr)
        print(f'best cross-validation accuracy: {grid.best_score_}')
        print(f'test score: {grid.score(xte, yte)}')
        print(f'best params: {grid.best_params_}')

4 通用管道接口

pipeline还支持特征提取和特征选择,pipeline可以和估计器连接在一起,还可以和缩放和分类器连接在一起

对估计器的要求是需要有transform方法

调pipeline.fit的过程中,会依次调用每个对象的fit和transform方法,对于pipeline最后一个对象只调fit方法不调tranform方法

调pipeline.predict流程是先调每个估计器的transform方法最后调分类器的predict方法

4.1 用make_pipeline创建管道

sklearn.pipeline.Pipeline初始化创建pipe对象比较繁琐,因为输入了每个步骤自定义的名称,有一种更简洁的方法,即调用make_pipeline方法创建pipe。这两种方法创建的pipe功能完全相同,但make_pipeline创建的对象的每个步骤命名是自动的

python">    def test_make_pipeline(self):
        pipe_long = Pipeline([('scaler', MinMaxScaler()), ('svm', SVC(C=100))])
        pipe_short = make_pipeline(MinMaxScaler(), SVC(C=100))
        print(f'show pipe step name via make_pipeline: {pipe_short.steps}')

其实有时如果需要自定义每个名称,用Pipeline初始化方法也可以

4.2 访问步骤属性

pipe还可以访问串联对象中某个对象的属性:

python">    def test_show_pipe_step_attrs(self):
        pipe = make_pipeline(StandardScaler(), PCA(n_components=2), StandardScaler()).fit(self.cancer.data, self.cancer.target)
        print(f'show pipe PCA main component shape: {pipe.named_steps["pca"].components_.shape}')

4.3 访问网格搜索管道中属性

任务 访问网格搜索的pipe的某个对象或属性

python">    def test_show_pipe_step_attrs(self):
        pipe = make_pipeline(StandardScaler(), PCA(n_components=2), StandardScaler()).fit(self.cancer.data, self.cancer.target)
        print(f'show pipe PCA main component shape: {pipe.named_steps["pca"].components_.shape}')
        xtr, xte, ytr, yte = train_test_split(self.cancer.data, self.cancer.target, random_state=0)
        pipe = make_pipeline(StandardScaler(), LogisticRegression())
        params_pipe = {'logisticregression__C': [0.01, 0.1, 1, 10, 100]}
        grid = GridSearchCV(pipe, params_pipe, cv=5).fit(xtr, ytr)
        print(f'best estimators: {grid.best_estimator_}')
        print(f'logistic regression best estimator: {grid.best_estimator_.named_steps["logisticregression"]}')
        print(f'best model coef: {grid.best_estimator_.named_steps["logisticregression"].coef_}')

5 网格搜索预处理于模型参数(综合应用)

python">    def test_chain_comprehensive(self):
        xtr, xte, ytr, yte = train_test_split(*self.boston, random_state=0)
        pipe = make_pipeline(StandardScaler(), PolynomialFeatures(), Ridge())
        params_grid = {"polynomialfeatures__degree": [1, 2, 3], "ridge__alpha": [0.001, 0.01, 0.1, 1, 10, 100]}
        grid = GridSearchCV(pipe, params_grid, cv=5, n_jobs=-1).fit(xtr, ytr)
        print(f'grid best params: {grid.best_params_}')
        print(f'grid best scores: {grid.score(xte, yte)}')

        # normal grid
        params_grida = {'ridge__alpha': [0.001, 0.01, 0.1, 1, 10, 100]}
        pipea = make_pipeline(StandardScaler(), Ridge())
        grida = GridSearchCV(pipea, params_grida, cv=5).fit(xtr, ytr)
        print(f'normal ridge without polynomial features scores: {grida.score(xte, yte)}')

        plot.matshow(grid.cv_results_['mean_test_score'].reshape(3, -1), vmin=0, cmap='viridis')
        plot.xlabel('ridge__alpha')
        plot.ylabel('polynomialfeatures__degree')
        plot.xticks(range(len(params_grid['ridge__alpha'])), params_grid['ridge__alpha'])
        plot.yticks(range(len(params_grid['polynomialfeatures__degree'])), params_grid['polynomialfeatures__degree'])
        plot.colorbar()
        plot.show()

6 网格搜索使用哪个模型

参考非网格空间,给不同模型设置网格参数,然后一次grid可以测很多模型,最后给出最高分


http://www.niftyadmin.cn/n/5116687.html

相关文章

英语——分享篇——每日200词——3201-3400

3201——air-conditioning——[eərkəndɪʃnɪŋ]——n.空调设备;vt.给…装上空调——air-conditioning——air-condition空调(熟词)ing鹰(谐音)——空调设备的噪音让鹰不得安宁——The trains dont even have proper air-conditioning, grumbles Mr So. ——地铁…

点云从入门到精通技术详解100篇-基于点云数据的机器人动态分拣(续)

目录 点云数据处理 3.1 点云数据采集 3.2 点云数据实时聚类 3.2.1 帧内聚类 3.2.2 帧间聚类

Django: 自动清理 PostgreSQL 数据

1. 写在前面 在实际项目开发过程中,有时需要考虑数据库或表大小,以避免如:日志记录等数据大量填充,导致数据库臃肿。本文以 PostgreSQL 数据库为例,简单演示在 Django 中如何监控数据库大小及自动清理数据&#xff1b…

计算机网络——理论知识总结(上)

开新番,因为博主备考的学校计网只考察1/6的分值,而且定位偏向于送分题,因此在备考时并没有很高强度的复习。本帖基于王道考研的教辅总结归纳,虽然是408的教材,但忽略其中有难度的部分,如计算题、画图题等&a…

FreeRTOS基础(如何学好FreeRTOS?)

目录 基础知识 进阶内容 后期“摆烂” 基础知识 实时操作系统 (RTOS):FreeRTOS是一个实时操作系统,它提供了任务管理、调度和同步等功能,在嵌入式系统中有效地管理多个任务。 任务(Task):任务是在RTOS…

html web前端,登录,post请求提交 json带参

html web前端&#xff0c;登录&#xff0c;post请求提交 json带参 3ca9855b3fd279fa17d46f01dc652030.jpg <!DOCTYPE html> <html><head><meta http-equiv"Content-Type" content"text/html; charsetutf-8" /><title></t…

YOLOv5独家最新改进《新颖高效AsDDet检测头》VisDrone数据集mAP涨点1.4%,即插即用|检测头新颖改进,性能高效涨点

💡本篇内容:YOLOv5独家最新改进《新颖高效AsDDet检测头》VisDrone数据集mAP涨点1.4%,即插即用|检测头新颖改进,性能高效涨点 💡🚀🚀🚀本博客 YOLO系列 + 全新新颖原创高效AsDDet检测头 改进创新点改进源代码改进 适用于 YOLOv5 按步骤操作运行改进后的代码即可…

注解方式对常见参数进行校验 java

概述 在进行接口请求时,需要对入参进行校验,如下 import org.springframework.validation.annotation.Validated; import org.springframework.web.bind.annotation.RequestBody;public void test(@RequestBody @Validated Param param){// ... }这时候便需要使用下面的这些…