1. 引言
最小二乘法是用来做函数拟合或者求函数极值的方法。在机器学习,尤其是回归模型中,经常可以看到最小二乘法的身影。
最小平方法是十九世纪统计学的主题曲。
从许多方面来看, 它之于统计学就相当于十八世纪的微积分之于数学。
---- 乔治·斯蒂格勒的《The History of Statistics》
2. 线性回归
如下图,要对样本点进行线性拟合,求得使预测尽可能准确的函数,这个过程就是线性回归。
上升到两个变量即为如下形式:
多个变量即可以写成:
我们可以使用误差平方和(几何中可视为欧式距离的平方)来度量回归任务的性能。
那么我们如何找到一条直线,使所有样本到直线地距离之和最小呢?这就是最小二乘法(least square method)——基于误差平方和最小化来进行模型求解.
3. 最小二乘法的原理
最小二乘法则是一种统计学习优化技术,它的目标是最小化误差平方之和来作为目标,从而找到最优模型,这个模型可以拟合(fit)训练数据。
回归学习最常用的损失函数是平方损失函数,在此情况下,回归问题可以用著名的最小二乘法来解决。
最小二乘法就是曲线拟合的一种解决方法。
如果误差的分布是正态分布,那么最小二乘法得到的就是最有可能的值。
最小二乘法是由勒让德在19世纪发现的,形式如下:
最小二乘法的目标就是最小化以上公式,f 则是模型(取自假设空间),y 则是观察值。(观测值就是样本值)
通俗来讲,就是样本值和预测值(模型给出)之间的距离平方和最小化作为目标来优化。
4. 求解方法
4.1 矩阵求导方法
思想就是把目标函数划归为矩阵运算问题,然后求导后等于0,从而得到极值。矩阵法比代数法要简洁,下面主要讲解下矩阵法解法,这里用多元线性回归例子来描述:
假设函数:
矩阵表达方式:
损失函数定义为:
1/2要是为了求导后系数为1,方便计算。
根据最小二乘法的原理,我们要对这个损失函数对向量求导取0。结果如下式:
对上述求导等式整理后可得:
4.2 迭代法随机梯度下降
梯度下降是迭代法的一种,可以用于求解最小二乘问题(线性和非线性都可以)
思路:对参数向量求导,使其梯度为0,然后得到参数变量的迭代更新公式。
具体请参考:
5.局限性
从上面可以看出,最小二乘法的求解方法--矩阵求导简洁高效,但是这里我们就聊聊矩阵求导的局限性。
第一,最小二乘法需要计算逆矩阵,有可能它的逆矩阵不存在,这样就不能用矩阵求导了,此时梯度下降法仍然可以使用;
第二,当样本特征 n 非常的大的时候,计算的逆矩阵是一个非常耗时的工作( n * n 的矩阵求逆),此时以梯度下降为代表的迭代法仍然可以使用。
建议超过10000个特征就用迭代法,或者通过主成分分析降低特征的维度后再用。
第三,如果拟合函数不是线性的,通常用迭代法求解,比如梯度下降。
所以如果把最小二乘法看做是优化问题的话,那么梯度下降是最小二乘求解方法的一种。
6.案例python实现
举例:我们用目标函数
, 加上一个正太分布的噪音干扰,用多项式去拟合:
importnumpyasnp
importscipyassp
fromscipy.optimizeimportleastsq
importmatplotlib.pyplotasplt
%matplotlib inline
# 目标函数,需要拟合的函数func
defreal_func(x):
returnnp.sin(2*np.pi*x)
# 多项式 模型函数
# ps: numpy.poly1d([1,2,3]) 生成 $1x^2+2x^1+3x^0$*
deffit_func(p, x):
f = np.poly1d(p)
returnf(x)
# 自己定义的一个计算误差的函数
defresiduals_func(p, x, y):
ret = fit_func(p, x) - y
returnret
x = np.linspace(,1,10)# 十个点
x_points = np.linspace(,1,1000)# 1000个数 从0-1
y_ = real_func(x)
# 加上正态分布噪音的目标函数的值
y = [np.random.normal(,0.1)+y1fory1iny_]
deffitting(M=):
"""
M 为 多项式的次数
"""
# 随机初始化多项式参数
p_init = np.random.rand(M+1)
# 最小二乘法 python的科学计算包scipy的里面提供了一个函数,可以求出任意的想要拟合的函数的参数
p_lsq = leastsq(residuals_func, p_init, args=(x, y))
print('Fitting Parameters:', p_lsq[])
# 可视化
plt.plot(x_points, real_func(x_points), label='real')
plt.plot(x_points, fit_func(p_lsq[], x_points), label='fitted curve')
plt.plot(x, y,'bo', label='noise')
plt.legend()
returnp_lsq
# M=0
p_lsq_0 = fitting(M=)
# M=1
p_lsq_0 = fitting(M=1)
# M=3
p_lsq_0 = fitting(M=3)
# M=9
p_lsq_0 = fitting(M=9)
当M=9时,多项式曲线通过了每个数据点,但是造成了过拟合。
从可视化的图形可以看出结果显示过拟合, 可以引入正则化项(regularizer),来降低过拟合,具体步骤后续讲详细介绍。
6.参考文献
1. http://www.cnblogs.com/pinard/p/5976811.html
2. 李航. 统计学习方法[M]. 北京:清华大学出版社,2012
领取专属 10元无门槛券
私享最新 技术干货