Comparing Regression Algorithms (Linear Regression, Ridge Regression, Lasso Regression)
Published: 2019-06-22


Code:

```python
# -*- coding: utf-8 -*-
"""
Created on Mon Jul 16 09:08:09 2018

@author: zhen
"""

from sklearn.linear_model import LinearRegression, Ridge, Lasso
import mglearn
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
import numpy as np

# Linear regression
x, y = mglearn.datasets.load_extended_boston()
x_train, x_test, y_train, y_test = train_test_split(x, y, random_state=0)

linear_reg = LinearRegression()
lr = linear_reg.fit(x_train, y_train)

print("lr.coef_:{}".format(lr.coef_))  # slopes (one coefficient per feature)
print("lr.intercept_:{}".format(lr.intercept_))  # intercept

print("="*25+"Linear regression"+"="*25)
print("Training set score:{:.2f}".format(lr.score(x_train, y_train)))
print("Test set score:{:.2f}".format(lr.score(x_test, y_test)))

"""
    Summary:
        If the training and test scores are very close, the model is probably underfitting.
        A large gap between training and test performance is a clear sign of overfitting;
        one remedy is ridge regression!
"""
print("="*25+"Ridge regression (default alpha=1.0)"+"="*25)
# Ridge regression
ridge = Ridge().fit(x_train, y_train)

print("Training set score:{:.2f}".format(ridge.score(x_train, y_train)))
print("Test set score:{:.2f}".format(ridge.score(x_test, y_test)))

print("="*25+"Ridge regression (alpha=10)"+"="*25)
# Ridge regression with stronger regularization
ridge_10 = Ridge(alpha=10).fit(x_train, y_train)

print("Training set score:{:.2f}".format(ridge_10.score(x_train, y_train)))
print("Test set score:{:.2f}".format(ridge_10.score(x_test, y_test)))

print("="*25+"Ridge regression (alpha=0.1)"+"="*25)
# Ridge regression with weaker regularization
ridge_01 = Ridge(alpha=0.1).fit(x_train, y_train)

print("Training set score:{:.2f}".format(ridge_01.score(x_train, y_train)))
print("Test set score:{:.2f}".format(ridge_01.score(x_test, y_test)))


# Visualization: top panel compares ridge coefficients, bottom panel compares lasso
fig = plt.figure(10)
plt.subplots_adjust(wspace=0, hspace=0.6)  # adjust spacing between subplots
ax1 = plt.subplot(2, 1, 1)
ax2 = plt.subplot(2, 1, 2)

ax1.plot(ridge_01.coef_, 'v', label="Ridge alpha=0.1")
ax1.plot(ridge.coef_, 's', label="Ridge alpha=1")
ax1.plot(ridge_10.coef_, '^', label="Ridge alpha=10")
ax1.plot(lr.coef_, 'o', label="LinearRegression")

ax1.set_ylabel("Coefficient magnitude")
ax1.set_ylim(-25, 25)
ax1.hlines(0, 0, len(lr.coef_))
ax1.legend(ncol=2, loc=(0.1, 1.05))

print("="*25+"Lasso regression (default settings)"+"="*25)
lasso = Lasso().fit(x_train, y_train)

print("Training set score:{:.2f}".format(lasso.score(x_train, y_train)))
print("Test set score:{:.2f}".format(lasso.score(x_test, y_test)))
print("Number of features used:{}".format(np.sum(lasso.coef_ != 0)))

print("="*25+"Lasso regression (alpha=0.01)"+"="*25)
# Smaller alpha needs more coordinate-descent iterations to converge
# (max_iter=1000 is the default, which may trigger a ConvergenceWarning here)
lasso_001 = Lasso(alpha=0.01, max_iter=100000).fit(x_train, y_train)

print("Training set score:{:.2f}".format(lasso_001.score(x_train, y_train)))
print("Test set score:{:.2f}".format(lasso_001.score(x_test, y_test)))
print("Number of features used:{}".format(np.sum(lasso_001.coef_ != 0)))


print("="*15+"Lasso regression (alpha=0.0001): too small an alpha may overfit"+"="*15)
lasso_00001 = Lasso(alpha=0.0001, max_iter=100000).fit(x_train, y_train)

print("Training set score:{:.2f}".format(lasso_00001.score(x_train, y_train)))
print("Test set score:{:.2f}".format(lasso_00001.score(x_test, y_test)))
print("Number of features used:{}".format(np.sum(lasso_00001.coef_ != 0)))


# Visualization (labels now match the alpha values actually used)
ax2.plot(ridge_01.coef_, 'o', label="Ridge alpha=0.1")
ax2.plot(lasso.coef_, 's', label="Lasso alpha=1")
ax2.plot(lasso_001.coef_, '^', label="Lasso alpha=0.01")
ax2.plot(lasso_00001.coef_, 'v', label="Lasso alpha=0.0001")

ax2.set_ylabel("Coefficient magnitude")
ax2.set_xlabel("Coefficient index")
ax2.set_ylim(-25, 25)
ax2.legend(ncol=2, loc=(0.1, 1))

plt.show()
```
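The script above treats `Ridge` and `Lasso` as black boxes. As a rough illustration of what `alpha` does inside them, here is a minimal NumPy sketch, not scikit-learn's actual implementation. It assumes the objectives stated in the scikit-learn documentation: ridge minimizes ||y - Xw||^2 + alpha * ||w||^2, lasso minimizes (1/(2n)) * ||y - Xw||^2 + alpha * ||w||_1, and in both the intercept is left unpenalized by centering the data first. Ridge is solved in closed form; lasso is solved by cyclic coordinate descent, the same family of solver scikit-learn's `Lasso` uses. The function names `ridge_fit`, `lasso_fit`, `soft_threshold` and the synthetic data are hypothetical. The soft-thresholding step is what sets lasso coefficients to exactly zero, which is why the script only reports "Number of features used" for Lasso.

```python
import numpy as np


def center(X, y):
    """Center X and y so the intercept can be recovered afterwards (not penalized)."""
    x_mean, y_mean = X.mean(axis=0), y.mean()
    return X - x_mean, y - y_mean, x_mean, y_mean


def ridge_fit(X, y, alpha):
    """Ridge: minimize ||y - Xw - b||^2 + alpha * ||w||^2 in closed form."""
    Xc, yc, x_mean, y_mean = center(X, y)
    n_features = X.shape[1]
    # Normal equations with alpha added to the diagonal: (Xc^T Xc + alpha I) w = Xc^T yc
    w = np.linalg.solve(Xc.T @ Xc + alpha * np.eye(n_features), Xc.T @ yc)
    return w, y_mean - x_mean @ w


def soft_threshold(z, t):
    """Move z toward zero by t; anything in [-t, t] becomes exactly 0."""
    return np.sign(z) * max(abs(z) - t, 0.0)


def lasso_fit(X, y, alpha, n_iter=1000, tol=1e-8):
    """Lasso: minimize (1/(2n))||y - Xw - b||^2 + alpha * ||w||_1 by coordinate descent."""
    Xc, yc, x_mean, y_mean = center(X, y)
    n_samples, n_features = Xc.shape
    w = np.zeros(n_features)
    col_sq = (Xc ** 2).sum(axis=0)  # X_j^T X_j for every column j
    residual = yc.copy()            # residual = yc - Xc @ w (w starts at zero)
    for _ in range(n_iter):
        max_change = 0.0
        for j in range(n_features):
            if col_sq[j] == 0.0:    # constant column: leave its weight at 0
                continue
            # Correlation of column j with the residual that excludes feature j
            rho = Xc[:, j] @ residual + col_sq[j] * w[j]
            new_wj = soft_threshold(rho, n_samples * alpha) / col_sq[j]
            residual -= Xc[:, j] * (new_wj - w[j])  # keep residual in sync with w
            max_change = max(max_change, abs(new_wj - w[j]))
            w[j] = new_wj
        if max_change < tol:        # stop once a full sweep barely changes w
            break
    return w, y_mean - x_mean @ w


# Hypothetical synthetic data: only the first 3 of 10 features matter
rng = np.random.default_rng(0)
X = rng.normal(size=(100, 10))
true_w = np.array([3.0, -2.0, 1.5] + [0.0] * 7)
y = X @ true_w + 0.5 * rng.normal(size=100)

for alpha in (0.1, 1.0, 10.0):
    w_r, _ = ridge_fit(X, y, alpha)
    print(f"ridge alpha={alpha}: ||w||={np.linalg.norm(w_r):.3f}, nonzero={np.sum(w_r != 0)}")
for alpha in (0.01, 0.1, 1.0):
    w_l, _ = lasso_fit(X, y, alpha)
    print(f"lasso alpha={alpha}: nonzero={np.sum(w_l != 0)}")
```

Raising alpha shrinks the ridge coefficient vector but generally leaves every entry nonzero, whereas a large enough lasso alpha pushes uninformative coefficients to exactly zero. That is the same contrast the two panels of the plot are meant to show.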

Results:

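Summary: On the same test data, the regression algorithms perform very differently, and within each algorithm, adjusting its configuration parameters (such as the regularization strength alpha) also has a large effect on the results.

  Choosing a suitable algorithm and setting its parameters appropriately is therefore the key to using these algorithms well!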
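One common way to "configure suitable parameters" without peeking at the test set is k-fold cross-validation on the training data: try several alphas, score each on held-out folds, and keep the best. The sketch below is NumPy-only and hypothetical: `ridge_fit`, `r_squared`, `cv_score`, `choose_alpha` and the synthetic data are illustrative and do not come from the original post. It uses R^2 because that is what `.score()` reports for scikit-learn regressors. In real projects, scikit-learn's `RidgeCV`, `LassoCV` or `GridSearchCV` handle this for you.

```python
import numpy as np


def ridge_fit(X, y, alpha):
    """Closed-form ridge with an unpenalized intercept (data centered first)."""
    x_mean, y_mean = X.mean(axis=0), y.mean()
    Xc, yc = X - x_mean, y - y_mean
    w = np.linalg.solve(Xc.T @ Xc + alpha * np.eye(X.shape[1]), Xc.T @ yc)
    return w, y_mean - x_mean @ w


def r_squared(y_true, y_pred):
    """Coefficient of determination R^2 (the metric .score() reports for regressors)."""
    ss_res = np.sum((y_true - y_pred) ** 2)
    ss_tot = np.sum((y_true - y_true.mean()) ** 2)
    return 1.0 - ss_res / ss_tot


def cv_score(X, y, alpha, n_folds=5, seed=0):
    """Mean validation R^2 over k folds; a fixed seed gives every alpha the same folds."""
    rng = np.random.default_rng(seed)
    folds = np.array_split(rng.permutation(len(y)), n_folds)
    scores = []
    for k in range(n_folds):
        val = folds[k]
        train = np.concatenate([folds[i] for i in range(n_folds) if i != k])
        w, b = ridge_fit(X[train], y[train], alpha)
        scores.append(r_squared(y[val], X[val] @ w + b))
    return float(np.mean(scores))


def choose_alpha(X, y, alphas, n_folds=5):
    """Return the alpha with the highest mean validation R^2, plus every score."""
    scores = {a: cv_score(X, y, a, n_folds) for a in alphas}
    best = max(scores, key=scores.get)
    return best, scores


# Hypothetical data: 40 features but only 60 samples, so plain least squares tends to overfit
rng = np.random.default_rng(42)
X = rng.normal(size=(60, 40))
w_true = np.zeros(40)
w_true[:5] = [2.0, -1.5, 1.0, 0.5, -0.5]
y = X @ w_true + rng.normal(size=60)

best_alpha, scores = choose_alpha(X, y, [0.01, 0.1, 1.0, 10.0, 100.0])
for a, s in scores.items():
    print(f"alpha={a}: mean validation R^2 = {s:.3f}")
print("chosen alpha:", best_alpha)
```

Applied to the article's script, this would mean running the search on `x_train, y_train` only and reporting `x_test, y_test` once, with the chosen alpha.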

 

Reposted from: https://www.cnblogs.com/yszd/p/9317720.html
