发布
社区首页 >问答首页 >使用数据集、标记器和torch数据集和数据采集器进行动态令牌化

使用数据集、标记器和torch数据集和数据采集器进行动态令牌化
EN

Stack Overflow用户
提问于 2020-12-05 17:15:55
回答 1查看 1.7K关注 0票数 4

我有一个关于“飞行中”标记的问题。这个问题是通过阅读“如何从零开始使用变形金刚和托卡器来训练一个新的语言模型”这里提出的。最后有一句话:“如果您的数据集非常大,可以选择动态加载和标记示例,而不是作为预处理步骤”。我尝试了一种将datasetstokenizers结合起来的解决方案,但是没有找到一个好的模式。

我想这个解决方案需要将一个数据集包装到Pytorch数据集中。

作为文档的一个具体例子

代码语言:javascript
代码运行次数:0
复制
import torch

class SquadDataset(torch.utils.data.Dataset):
    def __init__(self, encodings):
        # instead of doing this beforehand, I'd like to do tokenization on the fly
        self.encodings = encodings 

    def __getitem__(self, idx):
        return {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}

    def __len__(self):
        return len(self.encodings.input_ids)

train_dataset = SquadDataset(train_encodings)

如何利用令牌程序的矢量化功能来实现“动态”令牌化?

EN

回答 1

Stack Overflow用户

发布于 2020-12-08 21:55:52

更新2021年2月

从v1.3.0开始,数据集支持通过set_transform方法对函数进行延迟评估。因此,您可以像显示的这里那样直接应用动态标记化。

旧答案

最后,我接受了这个解决方案。我不喜欢batch_size现在被控制在dataset级别。然而,它做好了自己的工作。

通过这种方式,我们利用了两种美好的东西:

  • 快速索引HuggingFace数据集
  • HuggingFace令牌程序的矢量化能力
代码语言:javascript
代码运行次数:0
复制
class CustomPytorchDataset(Dataset):
    """
    This class wraps the HuggingFace dataset and allows for 
    batch indexing into the dataset. This allows exploiting
    the capabilities of the tokenizer to work on batches.

    NOTE: now we control batch_size at the Dataset level, not
    in the DataLoader therefore the DataLoader should always be
    used with `batch_size=1`.
    """

    def __init__(self, batch_size: int):
        self.batch_size = batch_size
        self.dataset = train_ds          # HuggingFace dataset
        self.tokenizer = bert_tokenizer  # HuggingFace tokenizer

    def __getitem__(self, batch_idx: List[int]):
        instance = self.dataset[batch_idx]

        # tokenize on-the-fly
        tokenized_instance = self.tokenizer(
            instance[text_col], 
            truncation=True, 
            padding=True
        )
        
        return tokenized_instance

    def __len__(self):
        return len(self.dataset)

    def sampler(self):
        # shuffling can be controlled by the sampler, 
        # without touching the dataset
        return BatchSampler(
            SequentialSampler(self), 
            batch_size=self.batch_size, 
            drop_last=True
        )

    @staticmethod
    def collate_fn(batches: List[Dict[str, int]]):
        return {
            k: torch.tensor(v, dtype=torch.int64) 
            for k, v in batches[0].items()
        }
票数 3
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/65159768

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档