前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >梯度下降法Python实现

梯度下降法Python实现

原创
作者头像
_咯噔_
修改2020-04-13 11:09:39
9630
修改2020-04-13 11:09:39
举报
文章被收录于专栏:CS学习笔记
梯度下降算法
梯度下降算法

几点说明

给定数据集即样本点

求出拟合的直线,给定模型f(x)=kx+b,k,b为要求的参数

定义损失函数(Loss function),回归问题里常用的是平方损失函数

初始化模型f(x)=x+1,即k,b都为1

步长即学习率alpha

代码如下:

代码语言:txt
复制
import numpy as np
import matplotlib.pyplot as plt

# Size of the points dataset.
m = 20

# Points x-coordinate and dummy value (x0, x1).
X0 = np.ones((m, 1))
X1 = np.arange(1, m+1).reshape(m, 1)
X = np.hstack((X0, X1))

# Points y-coordinate
y = np.array([
    3, 4, 5, 5, 2, 4, 7, 8, 11, 8, 12,
    11, 13, 13, 16, 17, 18, 17, 19, 21
]).reshape(m, 1)

# The Learning Rate alpha.
alpha = 0.01

def plot_graph(theta):
    x = np.linspace(1, 20, 100)
    fx = theta[1, 0] * x + theta[0, 0]
    plt.plot(x, fx)

def error_function(theta, X, y):
    '''Error function'''
    diff = np.dot(X, theta) - y
    return (1./(2*m)) * np.dot(np.transpose(diff), diff)

def gradient_function(theta, X, y):
    '''Gradient function'''
    diff = np.dot(X, theta) - y
    return (1.0/m)* np.dot(np.transpose(X), diff)

def gradient_descent(X, y, alpha):
    '''Perform gradient descent.'''
    theta = np.array([1, 1]).reshape(2, 1)
    last_error = error_function(theta, X, y)[0, 0]
    while True:
        #plot_graph(theta)
        gradient = gradient_function(theta, X, y)
        theta = theta - alpha * gradient
        new_error = error_function(theta, X, y)[0, 0]
        if(np.absolute(last_error-new_error) <= 1e-5):
            break
        last_error = new_error
        #print(gradient)
    return theta

optimal = gradient_descent(X, y, alpha)
print('optimal:', optimal)
print('error function:', error_function(optimal, X, y)[0,0])

x=np.linspace(1,20,100)
fx=optimal[1,0]*x+optimal[0,0]
plt.plot(x,fx)

plt.scatter(np.transpose(X1),np.transpose(y))
plt.xlabel('x')
plt.ylabel('y')
plt.title('Graph')

plt.show()

拟合效果:

myplot.png
myplot.png

原创声明:本文系作者授权腾讯云开发者社区发表,未经许可,不得转载。

如有侵权,请联系 cloudcommunity@tencent.com 删除。

原创声明:本文系作者授权腾讯云开发者社区发表,未经许可,不得转载。

如有侵权,请联系 cloudcommunity@tencent.com 删除。

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 代码如下:
  • 拟合效果:
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档