首页
学习
活动
专区
工具
TVP
发布
精选内容/技术社群/优惠产品,尽在小程序
立即前往

深度学习入门笔记(2)-线性回归 Linear Regression with autograd

一同前行!

假设我们有一个曲线(或者平面)y=wx+b

我们给定它一个特定的w,和b

w = [2,51]

b = 21.2

即y=2x1+51x2+21.2

目标是通过数据训练使得w和b靠近w =[2,51],b = 21.2,换句话说就是通过训练得到一个平面能够跟实际的平面(y=2x1+51x2+21.2)一致。

-代码实现-

回顾深度学习的套路

准备数据集dataset

构建网络(激活函数activation function)

初始化

训练(epochs,更新权重)

预测

所用的深度学习框架为Mxnet

需要用到的库为 mxnet 中的nd,gluon,gutograd

还有用于图形化显示的matplotlib

from mxnet import nd,gluon,autograd

import matplotlib.pyplot as plt

1.准备数据集

正太分布随机得到300个样本数

每个样本由X1,X2及结果y组成。

形成300个数据集,给Y加了一点噪声(0.01倍的正太分布数据)来模拟真实数据。

2.构建网络(激活函数activation function)

激活函数为参数的输入到输出的关系。

这里是y=wx+b的关系

def net(x): #激活函数

return nd.dot(x,w)+b

3.初始化

初始化真实的参数为

true_w=nd.array([2,51])

true_b=nd.array([21.2])

变化的参数,需要迭代所求的参数初始化如下:

w=nd.random_normal(shape=(2))

b=nd.zeros((1))

params=[w,b]

w初值为随机的数[w1,w2]

b初值为[0].

初值本可以任意由自己定义,对结果都是没有影响的,可能会影响迭代收敛的速度。

4.训练(epochs,更新权重)

训练时,还是利用的梯度下降法

def SGD(params,eta): #梯度下降法

for param in params:

param[:]=param-eta*param.grad

param 为训练中的【w,b】

eta 为训练步长

损失函数定义为均方误差:

训练:

迭代epochs=10次

步长为eta=0.01

其中用到了autograd自动求导函数。

attach_grad()给参数附上梯度,

向系统申请空间

with autograd.record():

记录需要求导的函数

backward()

回传求导

举例如下图z=y*x,y=2*x:

继续

epochs=10

eta=0.01

for param in params:

param.attach_grad() #要求系统申请对应的空间

for e in range(epochs):

for x,y in data_iter:

with autograd.record():

yhat=net(x)

loss=square_loss(yhat,y)

loss.backward()

SGD(params,eta)

#break

plot(xs)

5.预测:

选取部分数据(50个点),以x2为横坐标,Y为纵坐标。

由预估曲线和真实曲线进行可视化对比。

def plot(xs,sample_size=50):

_,fig=plt.subplots()

plotxs=xs[:sample_size,:]

plotxn=xs[:sample_size,1].asnumpy()

#以x2为横坐标

yhatn=net(plotxs).asnumpy()

fig.plot(plotxn,yhatn,'or') #估计曲线

ys=nd.dot(plotxs,true_w)+true_b

fig.plot(plotxn,ys.asnumpy(),'*g') #实际数据曲线

plt.show()

为了更好的说明参数迭代接近目标,分步截图了如下过程。

结果

迭代后的参数如下:

还记得我们设定的真实参数吗,

true_w=nd.array([2,51])

true_b=nd.array([21.2])

迭代后的参数已经趋近了!

至此,一个线性回归算法就算是完成了!

关注公众号有更多惊喜等着你!

  • 发表于:
  • 原文链接http://kuaibao.qq.com/s/20171217G0EKJH00?refer=cp_1026
  • 腾讯「腾讯云开发者社区」是腾讯内容开放平台帐号(企鹅号)传播渠道之一,根据《腾讯内容开放平台服务协议》转载发布内容。
  • 如有侵权,请联系 cloudcommunity@tencent.com 删除。

扫码

添加站长 进交流群

领取专属 10元无门槛券

私享最新 技术干货

扫码加入开发者社群
领券