前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >Paddle 2.1 拟合线性函数

Paddle 2.1 拟合线性函数

原创
作者头像
8菠萝
修改2021-10-29 09:52:22
5550
修改2021-10-29 09:52:22
举报
文章被收录于专栏:菠萝上市没有

背景

最近在用百度的飞桨paddlepaddle

做了一些nlp的研究性的小项目,写点总结。

概念

个人理念里的人工智能,最终是对某种“函数”的拟合,这种函数可能是一维的,二维的,多维的。但这个“函数”不是推导出来的公式,而是一个黑盒子,有点类是图灵机的感觉。通过一系列的输入,训练这个盒子,不断调参修正得到正确的拟合函数。

实践

最简单的线性函数开始, 拟合 y= k*x + b

代码语言:txt
复制
import numpy as np
import paddle
import paddle.nn as  nn
import matplotlib
import matplotlib.pyplot as plt

from paddle.io import DataLoader, Dataset

k = 3.2
b = 10


# 生成测试数据
x = [np.random.randn() for i in range(5000)]
y = [k * i + b for i in x]


# 定义数据集
class LineDataSet(Dataset):
    def __init__(self, x, y):
        self.x = x
        self.y = y

    def __getitem__(self, idx):
        data =  paddle.to_tensor(self.x[idx], dtype='float32')
        label = paddle.to_tensor(self.y[idx], dtype='float32')
        return data, label 

    def __len__(self):
        return len(self.x)
    

# 定义网架
class LineNet(nn.Layer):
    def __init__(self):
        super(LineNet, self).__init__()
        self.fc = nn.Linear(1, 1)

    def forward(self, image, label=None):
        return self.fc(image)

    def show(self):
        print("w:%.2f, b:%.2f" % (self.fc.weight, self.fc.bias))


net = LineNet()
# 优化器

opt = paddle.optimizer.SGD(learning_rate=1e-3,
                            parameters=net.parameters())


dataset = LineDataSet(x, y)
loader = DataLoader(dataset,
                    shuffle=True,
                    batch_size= 20)


# 训练
for e in range(10):
    for i, (data, label) in enumerate(loader()):
        out = net(data)
        loss = nn.functional.mse_loss(out, label)
        loss.backward()
        opt.step()
        opt.clear_grad()
        print("Epoch {} batch {}: loss = {}".format(
            e, i, np.mean(loss.numpy())))
     

# 评估展示
net.eval()
x = np.array([np.random.randn() * 10 for i in range(5000)])
y = np.array([k * i + b for i in x])
z = np.array([net(paddle.to_tensor(i)).numpy()[0] for i in x])


plt.plot(x, y, color='blue', label="act")
plt.plot(x, z, color='red', label="eval")
plt.legend()
plt.show()

net.show()

结果

Figure_1.png
Figure_1.png

原创声明:本文系作者授权腾讯云开发者社区发表,未经许可,不得转载。

如有侵权,请联系 cloudcommunity@tencent.com 删除。

原创声明:本文系作者授权腾讯云开发者社区发表,未经许可,不得转载。

如有侵权,请联系 cloudcommunity@tencent.com 删除。

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 背景
  • 概念
  • 实践
  • 结果
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档