前往小程序,Get更优阅读体验!
立即前往
发布
社区首页 >专栏 >基于 BERT 的抽取式摘要

基于 BERT 的抽取式摘要

作者头像
小言从不摸鱼
发布2025-03-05 08:31:32
发布2025-03-05 08:31:32
5700
代码可运行
举报
文章被收录于专栏:机器学习入门机器学习入门
运行总次数:0
代码可运行

🍔 环境准备:

Python: 3.7 或更高版本(推荐 3.8 或 3.9,与 PyTorch 和 Transformers 兼容性更好)。

Anaconda/Miniconda (强烈推荐): 用于创建和管理虚拟环境,避免包版本冲突。

  • 创建环境: conda create -n textsum python=3.8
  • 激活环境: conda activate textsum

安装必要的库: Bash

代码语言:javascript
代码运行次数:0
复制
pip install torch  # 或 tensorflow,取决于你选择哪个深度学习框架
pip install transformers  # Hugging Face 的 Transformers 库
pip install jieba  # 中文分词
pip install scikit-learn  # 机器学习工具库
pip install rouge-score # ROUGE 评估 (或使用 py-rouge)
# 如果使用 py-rouge (更完整的 ROUGE 实现):
# pip install py-rouge  # 可能需要先安装 Perl

🍔 数据准备:

  • 数据集选择: 假设你选择 LCSTS 数据集(一个中文短文本摘要数据集)。你需要下载数据集并将其解压到你的项目目录中。LCSTS 数据集通常包含三个文件:
    • train.txt: 训练集
    • dev.txt: 验证集
    • test.txt: 测试集
    • 每个文件包含多行, 每行是一个json, 包含summarytext两个字段
  • 数据加载和预处理 (data_utils.py)
代码语言:javascript
代码运行次数:0
复制
import json
import jieba
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizer

class SummarizationDataset(Dataset):
    def __init__(self, data_file, tokenizer, max_length=512):
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.data = self.load_data(data_file)

    def load_data(self, data_file):
        data = []
        with open(data_file, 'r', encoding='utf-8') as f:
            for line in f:
                item = json.loads(line.strip())
                # text = item['text']
                text = item['title']  #根据具体情况选择title 还是 content
                # summary = item['summary']
                summary = item['content']
                data.append((text, summary))
        return data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        text, summary = self.data[idx]
        # 分词并添加特殊标记
        text_tokens = ['[CLS]'] + list(jieba.cut(text)) + ['[SEP]']
        summary_tokens = ['[CLS]'] + list(jieba.cut(summary)) + ['[SEP]']

        # 构建标签(抽取式摘要:0 或 1)
        labels = [0] * len(text_tokens)
        summary_token_ids = set(self.tokenizer.convert_tokens_to_ids(summary_tokens))
        for i, token in enumerate(text_tokens):
            if self.tokenizer.convert_tokens_to_ids(token) in summary_token_ids:
                labels[i] = 1 #如果需要更精细的label, 可以计算每个句子和summary的rouge值作为label

        # 截断或填充到最大长度
        text_tokens = text_tokens[:self.max_length]
        labels = labels[:self.max_length]

        text_ids = self.tokenizer.convert_tokens_to_ids(text_tokens)
        attention_mask = [1] * len(text_ids)

        padding_length = self.max_length - len(text_ids)
        text_ids += [self.tokenizer.pad_token_id] * padding_length
        attention_mask += [0] * padding_length
        labels += [0] * padding_length  # 标签也需要填充

        return {
            'input_ids': torch.tensor(text_ids, dtype=torch.long),
            'attention_mask': torch.tensor(attention_mask, dtype=torch.long),
            'labels': torch.tensor(labels, dtype=torch.long)
        }

# 使用示例
if __name__ == '__main__':
    tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')
    dataset = SummarizationDataset('train.txt', tokenizer)  # 替换为你的训练数据文件
    dataloader = DataLoader(dataset, batch_size=8, shuffle=True)

    for batch in dataloader:
        print(batch['input_ids'].shape)
        print(batch['attention_mask'].shape)
        print(batch['labels'].shape)
        break

🍔 模型定义 (model.py):

代码语言:javascript
代码运行次数:0
复制
import torch
import torch.nn as nn
from transformers import BertModel

class BertForExtractiveSummarization(nn.Module):
    def __init__(self, bert_model_name='bert-base-chinese'):
        super(BertForExtractiveSummarization, self).__init__()
        self.bert = BertModel.from_pretrained(bert_model_name)
        self.classifier = nn.Linear(self.bert.config.hidden_size, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, input_ids, attention_mask):
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        # 使用 [CLS] 标记的输出,或者所有 token 输出的平均池化
        # pooled_output = outputs.pooler_output  # [CLS] 标记
        last_hidden_state = outputs.last_hidden_state # 取最后一个hidden_state
        logits = self.classifier(last_hidden_state) #
        probs = self.sigmoid(logits).squeeze(-1)  # 转换为概率
        return probs

🍔 训练脚本 (train.py):

代码语言:javascript
代码运行次数:0
复制
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from transformers import BertTokenizer
from model import BertForExtractiveSummarization  # 从 model.py 导入模型
from data_utils import SummarizationDataset  # 从 data_utils.py 导入数据集
from tqdm import tqdm  # 导入 tqdm

# 超参数
BATCH_SIZE = 8
LEARNING_RATE = 2e-5
NUM_EPOCHS = 3
MAX_LENGTH = 512
BERT_MODEL_NAME = 'bert-base-chinese'

# 设备
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# 加载 tokenizer 和数据集
tokenizer = BertTokenizer.from_pretrained(BERT_MODEL_NAME)
train_dataset = SummarizationDataset('train.txt', tokenizer, MAX_LENGTH)  # 替换为你的训练数据文件
train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)

dev_dataset = SummarizationDataset('dev.txt', tokenizer, MAX_LENGTH)  # 替换为你的验证数据文件
dev_dataloader = DataLoader(dev_dataset, batch_size=BATCH_SIZE, shuffle=False)


# 初始化模型、优化器和损失函数
model = BertForExtractiveSummarization(BERT_MODEL_NAME).to(device)
optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE)  # 使用 AdamW 优化器
criterion = nn.BCELoss()  # 二元交叉熵损失

# 训练循环
for epoch in range(NUM_EPOCHS):
    model.train()  # 设置为训练模式
    total_loss = 0
    for batch in tqdm(train_dataloader, desc=f"Epoch {epoch + 1}/{NUM_EPOCHS}"):
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)

        optimizer.zero_grad()  # 清空梯度

        probs = model(input_ids, attention_mask)
        loss = criterion(probs, labels.float())

        loss.backward()  # 反向传播
        optimizer.step()  # 更新参数

        total_loss += loss.item()

    print(f"Epoch {epoch + 1}/{NUM_EPOCHS}, Loss: {total_loss / len(train_dataloader)}")

    # 验证 (可选,但强烈建议)
    model.eval()  # 设置为评估模式
    with torch.no_grad():  # 关闭梯度计算
        val_loss = 0
        for batch in dev_dataloader:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)

            probs = model(input_ids, attention_mask)
            loss = criterion(probs, labels.float())

            val_loss += loss.item()

        print(f"Validation Loss: {val_loss / len(dev_dataloader)}")

# 保存模型
torch.save(model.state_dict(), 'extractive_summarizer.pth')

🍔 预测/推理脚本 (predict.py):

代码语言:javascript
代码运行次数:0
复制
import torch
from transformers import BertTokenizer
from model import BertForExtractiveSummarization  # 导入模型
import jieba

def summarize(text, model, tokenizer, max_length=512, threshold=0.5):
    """
    使用训练好的模型进行抽取式摘要。

    Args:
        text: 要摘要的文本。
        model: 训练好的模型。
        tokenizer: 分词器。
        max_length: 最大序列长度。
        threshold: 概率阈值,用于决定是否选择句子。

    Returns:
        str: 抽取式摘要。
    """
    model.eval()  # 设置为评估模式
    tokens = ['[CLS]'] + list(jieba.cut(text)) + ['[SEP]']
    tokens = tokens[:max_length]
    input_ids = tokenizer.convert_tokens_to_ids(tokens)
    attention_mask = [1] * len(input_ids)
    padding_length = max_length - len(input_ids)
    input_ids += [tokenizer.pad_token_id] * padding_length
    attention_mask += [0] * padding_length

    input_ids = torch.tensor(input_ids, dtype=torch.long).unsqueeze(0).to(device)  # 添加 batch 维度
    attention_mask = torch.tensor(attention_mask, dtype=torch.long).unsqueeze(0).to(device)

    with torch.no_grad():
        probs = model(input_ids, attention_mask)

    probs = probs.squeeze(0).cpu().tolist()  # 移除 batch 维度,并移到 CPU
    selected_sentences = []
    current_sentence = ""
    for i, token in enumerate(tokens):

        if token == '[CLS]' or token == '[SEP]':
          continue
        if token in [',','。','?','!',';',';','!','?']:
          current_sentence += token
          if probs[i] > threshold:
            selected_sentences.append(current_sentence)
          current_sentence = ""
        else:
          current_sentence += token

    return "。".join(selected_sentences) + "。"

if __name__ == '__main__':
    # 加载模型和 tokenizer
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')
    model = BertForExtractiveSummarization()
    model.load_state_dict(torch.load('extractive_summarizer.pth', map_location=device))  # 加载训练好的模型
    model.to(device)

    # 示例文本
    text = """
        你的长文本。
    """
    summary = summarize(text, model, tokenizer)
    print(f"Original Text:\n{text}\n")
    print(f"Summary:\n{summary}")

🍔 运行和评估:

训练: 运行 train.py 进行模型训练。

预测/摘要: 运行 predict.py,将你要摘要的文本放入 predict.py 文件中。

评估: 使用 rouge-score 库(或 py-rouge)计算 ROUGE 分数。 * 准备一个测试集,包含原文和人工编写的参考摘要。 * 使用你的模型生成摘要。 * 计算生成的摘要和参考摘要之间的 ROUGE 分数。

代码文件整理:

建议将代码组织成以下结构:

代码语言:javascript
代码运行次数:0
复制
text_summarization_project/
├── data/
│   ├── train.txt     # 训练数据
│   ├── dev.txt       # 验证数据
│   └── test.txt      # 测试数据 (可选,用于最终评估)
├── data_utils.py    # 数据加载和预处理
├── model.py         # 模型定义
├── train.py         # 训练脚本
├── predict.py       # 预测/摘要脚本
└── extractive_summarizer.pth  # 保存的模型 (训练完成后)

关键改进和技巧:

  • data_utils.py:
    • 使用 torch.utils.data.DatasetDataLoader 来高效地加载和批处理数据。
    • 实现了分词、添加特殊标记、标签生成、截断/填充等预处理步骤。
    • 使用了 jieba 进行中文分词。
    • 使用了 transformers 库中的 BertTokenizer
  • model.py:
    • 定义了 BertForExtractiveSummarization 模型,它使用预训练的 BERT 模型,并在其上添加了一个线性分类层。
    • forward 方法中,可以使用 outputs.pooler_output[CLS] 标记的输出)或 outputs.last_hidden_state 的平均池化作为句子表示。
  • train.py:
    • 使用了 AdamW 优化器(通常比 Adam 效果更好)。
    • 包含了训练循环和验证循环(可选,但强烈建议)。
    • 使用了 tqdm 库来显示训练进度条。
    • 将模型和数据移动到 GPU(如果可用)。
    • 保存训练好的模型。
  • predict.py:
    • 加载训练好的模型和分词器。
    • 实现了 summarize 函数,用于对单个文本进行摘要。
    • 使用了阈值来选择句子。
  • 更精细的label: 可以不仅仅使用0,1作为label, 可以尝试使用rouge值作为label.

后期仍可以做的改进:

  • 尝试不同的预训练模型: 尝试其他中文预训练模型,如 RoBERTa-wwm-ext-Chinese、ERNIE 等。
  • 调整超参数: 使用不同的学习率、批大小、最大序列长度等。
  • 尝试不同的损失函数: 除了二元交叉熵损失,你还可以尝试 focal loss 等。
  • 添加 early stopping: 根据验证集上的性能,提前停止训练,以防止过拟合。
  • 使用 beam search 或其他解码策略: 在生成摘要时,可以使用 beam search 等解码策略来提高摘要的质量。
  • 实现 Web 界面: 使用 Flask 或 FastAPI 构建一个 Web 应用程序,让用户可以输入文本并查看摘要结果。
本文参与 腾讯云自媒体同步曝光计划,分享自作者个人站点/博客。
原始发表:2025-03-04,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 作者个人站点/博客 前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 🍔 环境准备:
  • 🍔 数据准备:
  • 🍔 模型定义 (model.py):
  • 🍔 训练脚本 (train.py):
  • 🍔 预测/推理脚本 (predict.py):
  • 🍔 运行和评估:
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档