首页
学习
活动
专区
工具
TVP
发布
精选内容/技术社群/优惠产品,尽在小程序
立即前往

利用PyTorch的三元组损失Hard Triplet Loss进行嵌入模型微调

本文介绍如何使用 PyTorch 和三元组边缘损失 (Triplet Margin Loss) 微调嵌入模型,并重点阐述实现细节和代码示例。三元组损失是一种对比损失函数,通过缩小锚点与正例间的距离,同时扩大锚点与负例间的距离来优化模型。

数据集准备与处理

一般的嵌入模型都会使用Sentence Transformer ,其中的 encode() 方法可以直接处理文本输入。但是为了进行微调,我们需要采用 Transformer 库,所以就要将文本转换为模型可接受的 token IDs 和 attention masks。Token IDs 代表模型词汇表中的词或字符,attention masks 用于防止模型关注填充 tokens。

本文使用 thenlper/gte-base 模型,需要对应的 tokenizer 对文本进行预处理。该模型基于 BertModel 架构:

BertModel(

(embeddings): BertEmbeddings(

  (word_embeddings): Embedding(30522, 768, padding_idx=0)

  (position_embeddings): Embedding(512, 768)

  (token_type_embeddings): Embedding(2, 768)

  (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)

  (dropout): Dropout(p=0.1, inplace=False)

)

(encoder): BertEncoder(

  (layer): ModuleList(

    (0-11): 12 x BertLayer(

      (attention): BertAttention(

        (self): BertSdpaSelfAttention(

          (query): Linear(in_features=768, out_features=768, bias=True)

          (key): Linear(in_features=768, out_features=768, bias=True)

          (value): Linear(in_features=768, out_features=768, bias=True)

          (dropout): Dropout(p=0.1, inplace=False)

        )

        (output): BertSelfOutput(

          (dense): Linear(in_features=768, out_features=768, bias=True)

          (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)

          (dropout): Dropout(p=0.1, inplace=False)

        )

      )

      (intermediate): BertIntermediate(

        (dense): Linear(in_features=768, out_features=3072, bias=True)

        (intermediate_act_fn): GELUActivation()

      )

      (output): BertOutput(

        (dense): Linear(in_features=3072, out_features=768, bias=True)

        (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)

        (dropout): Dropout(p=0.1, inplace=False)

      )

    )

  )

)

(pooler): BertPooler(

  (dense): Linear(in_features=768, out_features=768, bias=True)

  (activation): Tanh()

)

)

利用 Transformers 库的 AutoTokenizer 和 AutoModel 可以简化模型加载过程,无需手动处理底层架构和配置细节。

from transformers import AutoTokenizer, AutoModel

from tqdm import tqdm

tokenizer = AutoTokenizer.from_pretrained("thenlper/gte-base")

# 获取文本并进行标记

train_texts = [df_train.loc[i]['content'] for i in range(df_train.shape[0])]

dev_texts = [df_dev.loc[i]['content'] for i in range(df_dev.shape[0])]

test_texts = [df_test.loc[i]['content'] for i in range(df_test.shape[0])]

train_tokens = []

train_attention_masks = []

dev_tokens = []

dev_attention_masks = []

test_tokens = []

test_attention_masks = []

for sent in tqdm(train_texts):

 encoding = tokenizer(sent, truncation=True, padding='max_length', return_tensors='pt')

 train_tokens.append(encoding['input_ids'].squeeze(0))

 train_attention_masks.append(encoding['attention_mask'].squeeze(0))

for sent in tqdm(dev_texts):

 encoding = tokenizer(sent, truncation=True, padding='max_length', return_tensors='pt')

 dev_tokens.append(encoding['input_ids'].squeeze(0))

 dev_attention_masks.append(encoding['attention_mask'].squeeze(0))

for sent in tqdm(test_texts):

 encoding = tokenizer(sent, truncation=True, padding='max_length', return_tensors='pt')

 test_tokens.append(encoding['input_ids'].squeeze(0))

 test_attention_masks.append(encoding['attention_mask'].squeeze(0))

获取 token IDs 和 attention masks 后,需要将其存储并创建一个自定义的 PyTorch 数据集。

import random

from collections import defaultdict

import torch

from torch.utils.data import Dataset, DataLoader, Sampler, SequentialSampler

class CustomTripletDataset(Dataset):

   def __init__(self, tokens, attention_masks, labels):

       self.tokens = tokens

       self.attention_masks = attention_masks

       self.labels = torch.Tensor(labels)

       self.label_dict = defaultdict(list)

       for i in range(len(tokens)):

           self.label_dict[int(self.labels[i])].append(i)

       self.unique_classes = list(self.label_dict.keys())

   def __len__(self):

       return len(self.tokens)

   def __getitem__(self, index):

       ids = self.tokens[index].to(device)

       ams = self.attention_masks[index].to(device)

       y = self.labels[index].to(device)

       return ids, ams, y

由于采用三元组损失,需要从数据集中采样正例和负例。label_dict 字典用于存储每个类别及其对应的数据索引,方便随机采样。DataLoader 用于加载数据集:

train_loader = DataLoader(train_dataset, batch_sampler=train_batch_sampler)

其中 train_batch_sampler 是自定义的批次采样器:

class CustomBatchSampler(SequentialSampler):

   def __init__(self, dataset, batch_size):

       self.dataset = dataset

       self.batch_size = batch_size

       self.unique_classes = sorted(dataset.unique_classes)

       self.label_dict = dataset.label_dict

       self.num_batches = len(self.dataset) // self.batch_size

       self.class_size = self.batch_size // 4

   def __iter__(self):

       total_samples_used = 0

       weights = np.repeat(1, len(self.unique_classes))

       while total_samples_used < len(self.dataset):

           batch = []

           classes = []

           for _ in range(4):

               next_selected_class = self._select_class(weights)

               while next_selected_class in classes:

                 next_selected_class = self._select_class(weights)

               weights[next_selected_class] += 1

               classes.append(next_selected_class)

               new_choices = self.label_dict[next_selected_class]

               remaining_samples = list(np.random.choice(new_choices, min(self.class_size, len(new_choices)), replace=False))

               batch.extend(remaining_samples)

           total_samples_used += len(batch)

           yield batch

   def _select_class(self, weights):

       dist = 1/weights

       dist = dist/np.sum(dist)

       selected = int(np.random.choice(self.unique_classes, p=dist))

       return selected

   def __len__(self):

       return self.num_batches

自定义批次采样器控制训练批次的构成,本文的实现确保每个批次包含 4 个类别,每个类别包含 8 个数据点。验证采样器则确保验证集批次在不同 epoch 间保持一致。

模型构建

嵌入模型通常基于 Transformer 架构,输出每个 token 的嵌入。为了获得句子嵌入,需要对 token 嵌入进行汇总。常用的方法包括 CLS 池化和平均池化。本文使用的 gte-base 模型采用平均池化,需要从模型输出中提取 token 嵌入并计算平均值。

import torch.nn.functional as F

import torch.nn as nn

class EmbeddingModel(nn.Module):

   def __init__(self, base_model):

       super().__init__()

       self.base_model = base_model

   def average_pool(self, last_hidden_states, attention_mask):

       # 平均 token 嵌入

       last_hidden = last_hidden_states.masked_fill(~attention_mask[..., None].bool(), 0.0)

       return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]

   def forward(self, input_ids, attention_mask):

       outputs = self.base_model(input_ids=input_ids, attention_mask=attention_mask)

       last_hidden_state = outputs.last_hidden_state

       pooled_output = self.average_pool(last_hidden_state, attention_mask)

       normalized_output = F.normalize(pooled_output, p=2, dim=1)

       return normalized_output

base_model = AutoModel.from_pretrained("thenlper/gte-base")

model = EmbeddingModel(base_model)

EmbeddingModel 类封装了 Hugging Face 模型,并实现了平均池化和嵌入归一化。

模型训练

训练循环中需要动态计算每个锚点的最难正例和最难负例。

import numpy as np

def train(model, train_loader, criterion, optimizer, scheduler):

   model.train()

   epoch_train_losses = []

   for idx, (ids, attention_masks, labels) in enumerate(train_loader):

       optimizer.zero_grad()

       embeddings = model(ids, attention_masks)

       distance_matrix = torch.cdist(embeddings, embeddings, p=2) # 创建方形距离矩阵

       anchors = []

       positives = []

       negatives = []

       for i in range(len(labels)):

           anchor_label = labels[i].item()

           anchor_distance = distance_matrix[i] # 锚点与所有其他点之间的距离

           # 最难的正例(同一类别中最远的)

           hardest_positive_idx = (labels == anchor_label).nonzero(as_tuple=True)[0] # 所有同类索引

           hardest_positive_idx = hardest_positive_idx[hardest_positive_idx != i] # 排除自己的标签

           hardest_positive = hardest_positive_idx[anchor_distance[hardest_positive_idx].argmax()] # 最远同类的标签

           # 最难的负例(不同类别中最近的)

           hardest_negative_idx = (labels != anchor_label).nonzero(as_tuple=True)[0] # 所有不同类索引

           hardest_negative = hardest_negative_idx[anchor_distance[hardest_negative_idx].argmin()] # 最近不同类的标签

           # 加载选择的

           anchors.append(embeddings[i])

           positives.append(embeddings[hardest_positive])

           negatives.append(embeddings[hardest_negative])

       # 将列表转换为张量

       anchors = torch.stack(anchors)

       positives = torch.stack(positives)

       negatives = torch.stack(negatives)

       # 计算损失

       loss = criterion(anchors, positives, negatives)

       epoch_train_losses.append(loss.item())

       # 反向传播和优化

       loss.backward()

       optimizer.step()

       # 更新学习率

       scheduler.step()

   return np.mean(epoch_train_losses)

训练过程中使用 torch.cdist() 计算嵌入间的距离矩阵,并根据距离选择最难正例和最难负例。PyTorch 的 TripletMarginLoss 用于计算损失。

结论与讨论

实践表明,Batch Hard Triplet Loss 在某些情况下并非最优选择。例如,当正例样本内部差异较大时,强制其嵌入相似可能适得其反。

本文的重点在于 PyTorch 中自定义批次采样和动态距离计算的实现。

对于某些任务,直接在分类任务上微调嵌入模型可能比使用三元组损失更有效。

喜欢就关注一下吧:

点个在看你最好看!

  • 发表于:
  • 原文链接https://page.om.qq.com/page/OLi-AsaWeYKV6J4pyC_zaAeg0
  • 腾讯「腾讯云开发者社区」是腾讯内容开放平台帐号(企鹅号)传播渠道之一,根据《腾讯内容开放平台服务协议》转载发布内容。
  • 如有侵权,请联系 cloudcommunity@tencent.com 删除。

扫码

添加站长 进交流群

领取专属 10元无门槛券

私享最新 技术干货

扫码加入开发者社群
领券