前往小程序,Get更优阅读体验!
立即前往
发布
社区首页 >专栏 >pytorch中的线性回归

pytorch中的线性回归

作者头像
GeekLiHua
发布2025-01-21 14:23:15
发布2025-01-21 14:23:15
4100
代码可运行
举报
文章被收录于专栏:JavaJava
运行总次数:0
代码可运行

pytorch中的线性回归

简介:

线性回归是一种基本的机器学习模型,用于建立输入特征与连续输出之间的关系。它假设输入特征与输出之间的关系是线性的,并且尝试找到最佳的线性拟合,以最小化预测值与真实值之间的差距。

线性回归原理

在线性回归中,我们假设输入特征

X

与输出

Y

之间的关系可以表示为:

Y = WX + b

其中,

W

是特征的权重(系数),

b

是偏置项,用于调整输出值。我们的目标是找到最佳的

W

b

,使得预测值

\hat{Y}

与真实值

Y

之间的误差最小化。通常使用最小化均方误差(Mean Squared Error,MSE)来衡量预测值与真实值之间的差距。

实现线性回归

在 PyTorch 中,我们可以利用自动求导功能和优化器来实现线性回归模型。下面是一个简单的线性回归示例代码:

我们的目的是:预测输入特征X与对应的真实标签Y之间的关系。

代码语言:javascript
代码运行次数:0
复制
import torch
import matplotlib.pyplot as plt

# 输入数据
x_data = torch.tensor([[1.0], [2.0], [3.0], [4.0]])
y_data = torch.tensor([[2.0], [4.0], [6.0], [8.0]])

# 定义线性回归模型
class LinearRegressionModel(torch.nn.Module):
    def __init__(self):
        super(LinearRegressionModel, self).__init__()
        self.linear = torch.nn.Linear(1, 1)  # 输入维度为1,输出维度为1

    def forward(self, x):
        return self.linear(x)

model = LinearRegressionModel()

# 定义损失函数和优化器
criterion = torch.nn.MSELoss()  # 均方误差损失函数
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)  # 随机梯度下降优化器

# 训练模型
num_epochs = 1000
for epoch in range(num_epochs):
    # 前向传播
    y_pred = model(x_data)
    
    # 计算损失
    loss = criterion(y_pred, y_data)
    
    # 反向传播与优化
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

# 输出训练后的参数
print("训练后的参数:")
print("W =", model.linear.weight.item())
print("b =", model.linear.bias.item())

# 绘制数据点
plt.scatter(x_data, y_data)
# 绘制回归线
plt.plot(x_data, model(x_data).detach().numpy(), 'r-', label='Regression Line')
plt.xlabel('X')
plt.ylabel('Y')
plt.title('Linear Regression')
plt.legend()
plt.show()
  • 运行结果

根据训练得到的参数,线性回归模型的方程为:

Y = 1.9862X + 0.0405

其中:

Y

是预测的因变量值, -

X

是自变量的值。

这意味着自变量

X

的变化每增加 1 单位,因变量

Y

的变化大约为 1.9862单位。此外,即使自变量

X

为 0 时,因变量

Y

也会接近 0.0405。

本文参与 腾讯云自媒体同步曝光计划,分享自作者个人站点/博客。
原始发表:2024-03-12,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 作者个人站点/博客 前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • pytorch中的线性回归
    • 线性回归原理
    • 实现线性回归
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档