前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
圈层
工具
发布
首页
学习
活动
专区
圈层
工具
社区首页 >专栏 >PyTorch 实现数据并行的 BERT

PyTorch 实现数据并行的 BERT

原创
作者头像
繁依Fanyi
发布2025-03-28 03:28:03
发布2025-03-28 03:28:03
9100
代码可运行
举报
运行总次数:0
代码可运行

在这篇文章里,我们要把 BERT(Bidirectional Encoder Representations from Transformers) 和 PyTorch 的数据并行(DataParallel) 这两位重量级选手拉到一起,手把手教你如何高效地在多张显卡上训练 BERT 模型。不管你是 PyTorch 小白,还是刚接触深度学习,这篇文章都能让你轻松理解它们的底层逻辑,并且学会如何在实际项目中使用它们。

1. 数据并行(DataParallel)到底是啥?

在深度学习的世界里,模型越大,训练所需的计算资源就越夸张。如果你曾在单张显卡上跑 BERT,估计你已经被显存爆炸劝退了。幸运的是,我们可以用 PyTorch 自带的 DataParallel 轻松让 BERT 训练时“并行作战”,充分利用多张 GPU,提高计算效率。

1.1 传统的单 GPU 训练方式

通常,我们的代码默认只会使用一张 GPU,比如这样:

代码语言:python
代码运行次数:0
运行
复制
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

这样训练虽然简单,但当你的 BERT 变得越来越大,或者 batch size 增加时,单张显卡很快就会吃不消。

1.2 多 GPU 训练的“升级版”

PyTorch 让我们可以用 torch.nn.DataParallel 轻松把训练任务分配到多个 GPU 上:

代码语言:python
代码运行次数:0
运行
复制
model = nn.DataParallel(model)
model.to(device)

它的原理也不复杂:

  1. 把 batch 数据拆分:假设你有 4 张 GPU,每个 batch 有 64 个样本,那 DataParallel 会把数据拆成 4 份,每个 GPU 处理 16 个样本。
  2. 并行计算:每张 GPU 用自己的参数计算前向传播和反向传播。
  3. 收集结果:主 GPU(默认是 cuda:0)会收集所有 GPU 的计算结果,并进行梯度更新。

这样,每个 GPU 负责一部分工作,整体训练速度就大大提高了!


2. PyTorch 里的 BERT 模型

2.1 BERT 的基本结构

BERT 由多个 Transformer 层堆叠而成,每层包括:

  • 自注意力(Self-Attention):让模型能看到输入序列中的所有单词,而不是只看前面的词(区别于 RNN)。
  • 前馈神经网络(Feed Forward Network, FFN):负责对注意力层输出的数据进行非线性变换,提升模型的表达能力。
  • 层归一化(Layer Normalization):让训练更稳定,防止梯度爆炸或消失。

BERT 有 BaseLarge 两个版本,其中 bert-base 由 12 层 Transformer 组成,而 bert-large 则有 24 层,训练成本更高。

2.2 加载预训练模型

直接用 transformers 库里的 BertModel 加载预训练的 BERT:

代码语言:python
代码运行次数:0
运行
复制
from transformers import BertModel

model = BertModel.from_pretrained("bert-base-uncased")

这个方法会自动下载 BERT 的预训练权重,并且帮你配置好模型结构,省去了自己搭建模型的麻烦。


3. 让 BERT 在多 GPU 上飞速训练

我们来看看如何利用 DataParallel 让 BERT 进行 多 GPU 训练

3.1 准备数据

首先,我们需要一些文本数据,通常 BERT 需要的是 tokenized 的数据,所以我们要用 BertTokenizer

代码语言:python
代码运行次数:0
运行
复制
from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

text = ["Hello, how are you?", "I am learning PyTorch!"]
inputs = tokenizer(text, padding=True, truncation=True, return_tensors="pt")

这样我们就得到了 PyTorch 格式的张量输入。

3.2 定义训练函数

训练过程中,我们要把 BERT 和数据都搬到 GPU 上,并使用 DataParallel 让多个 GPU 共同工作:

代码语言:python
代码运行次数:0
运行
复制
import torch
import torch.nn as nn
import torch.optim as optim
from transformers import BertForSequenceClassification

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 加载 BERT
model = BertForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=2)

# 使用 DataParallel
model = nn.DataParallel(model)
model.to(device)

# 定义优化器
optimizer = optim.AdamW(model.parameters(), lr=2e-5)

# 定义损失函数
criterion = nn.CrossEntropyLoss()

# 训练循环
def train(model, data, labels, optimizer, criterion):
    model.train()
    optimizer.zero_grad()

    # 把数据送到 GPU
    data, labels = data.to(device), labels.to(device)

    # 前向传播
    outputs = model(**data)
    loss = criterion(outputs.logits, labels)

    # 反向传播
    loss.backward()
    optimizer.step()

    return loss.item()

3.3 训练过程

代码语言:python
代码运行次数:0
运行
复制
inputs = {key: val.to(device) for key, val in inputs.items()}
labels = torch.tensor([0, 1]).to(device)

for epoch in range(3):
    loss = train(model, inputs, labels, optimizer, criterion)
    print(f"Epoch {epoch + 1}, Loss: {loss:.4f}")

4. 深入解析 DataParallel 的工作原理

DataParallel 在内部做了这些事:

  1. 自动复制模型:PyTorch 会把 BERT 复制到所有 GPU 上,每张 GPU 都有一个完整的 BERT 副本。
  2. 切分 batch 数据:每个 GPU 只处理 batch 的一部分,减少单卡负载。
  3. 收集结果:所有 GPU 计算完梯度后,主 GPU 会收集它们并更新参数。

不过,DataParallel 有个小问题:主 GPU 负担较重,因为它不仅要训练自己的数据,还要管理多个 GPU 之间的通讯。所以在更大规模的训练任务中,通常会选择 torch.nn.parallel.DistributedDataParallel(简称 DDP),它能进一步提升效率。

原创声明:本文系作者授权腾讯云开发者社区发表,未经许可,不得转载。

如有侵权,请联系 cloudcommunity@tencent.com 删除。

原创声明:本文系作者授权腾讯云开发者社区发表,未经许可,不得转载。

如有侵权,请联系 cloudcommunity@tencent.com 删除。

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 1. 数据并行(DataParallel)到底是啥?
    • 1.1 传统的单 GPU 训练方式
    • 1.2 多 GPU 训练的“升级版”
  • 2. PyTorch 里的 BERT 模型
    • 2.1 BERT 的基本结构
    • 2.2 加载预训练模型
  • 3. 让 BERT 在多 GPU 上飞速训练
    • 3.1 准备数据
    • 3.2 定义训练函数
    • 3.3 训练过程
  • 4. 深入解析 DataParallel 的工作原理
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档