I am trying to train a HuggingFace Trainer, but I get the following error:
RuntimeError: "mse_cuda" not implemented for 'Long'
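I have tried this in several cloud environments (CPU and GPU) without success. The dataset (tok_dds) has the following shape and types, and I have made sure it contains no null values.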
Dataset({
    features: ['label', 'title', 'text', 'input', 'input_ids', 'token_type_ids', 'attention_mask'],
    num_rows: 5000
})
{'label': int,
'title': str,
'text': str,
'input': str,
'input_ids': list,
'token_type_ids': list,
'attention_mask': list}
I have defined my metric functions (passed to the Trainer as compute_metrics) as follows:
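import numpy as np
# Pearson correlation between predictions and labels, reported as an evaluation metric
def corr(x, y): return np.corrcoef(x, y)[0][1]
def corr_d(eval_pred): return {'pearson': corr(*eval_pred)}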
However, when I try to train model_nm = 'microsoft/deberta-v3-small' on a train/test split of my dataset, I see the following error:
dds = tok_ds.train_test_split(0.25, seed=42)
tokz = AutoTokenizer.from_pretrained(model_nm)
model = AutoModelForSequenceClassification.from_pretrained(model_nm, num_labels=1)
trainer = Trainer(model, args, train_dataset=dds['train'], eval_dataset=dds['test'],
tokenizer=tokz, compute_metrics=corr_d)
...
...
File /shared-libs/python3.9/py/lib/python3.9/site-packages/torch/nn/functional.py:3280, in mse_loss(input, target, size_average, reduce, reduction)
3277 reduction = _Reduction.legacy_get_string(size_average, reduce)
3279 expanded_input, expanded_target = torch.broadcast_tensors(input, target)
-> 3280 return torch._C._nn.mse_loss(expanded_input, expanded_target, _Reduction.get_enum(reduction))
RuntimeError: "mse_cuda" not implemented for 'Long'
Here are the args passed to the Trainer, in case they are relevant:
args = TrainingArguments('outputs', learning_rate=lr, warmup_ratio=0.1, lr_scheduler_type='cosine', fp16=True,
evaluation_strategy="epoch", per_device_train_batch_size=bs, per_device_eval_batch_size=bs*2,
num_train_epochs=epochs, weight_decay=0.01, report_to='none')
Here is the environment information that I think may be relevant:
!python --version
Python 3.9.13
!pip list
Package Version
----------------------------- ------------
...
transformers 4.21.1
huggingface-hub 0.8.1
pandas 1.2.5
protobuf 3.19.4
scikit-learn 1.1.1
tensorflow 2.9.1
torch 1.12.0
Can anyone point me in the right direction to solve this?
Posted on 2022-08-23 13:43:13
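Changing the data type of the label column from int to float solved this problem. With num_labels=1, AutoModelForSequenceClassification treats the task as regression and computes an MSE loss, which is not implemented for integer (Long) tensors, so the labels must be floating point. If the dataset is built from a pandas DataFrame, you can change the column's dtype before passing the DataFrame to Dataset.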
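For the pandas route, a minimal sketch might look like the following. The helper name labels_to_float and the toy DataFrame are made up for illustration, and float32 is just one reasonable choice of float dtype; only the column name 'label' comes from the question.

```python
import pandas as pd

def labels_to_float(df: pd.DataFrame, label_col: str = "label") -> pd.DataFrame:
    """Return a copy of df whose label column is float32 instead of int.

    Hypothetical helper, not part of pandas or datasets.
    """
    out = df.copy()
    out[label_col] = out[label_col].astype("float32")
    return out

# Toy stand-in for the real data (assumption: the real frame has the same 'label' column).
df = pd.DataFrame({"label": [0, 1, 2], "text": ["a", "b", "c"]})
df = labels_to_float(df)
print(df.dtypes)

# Next step, which needs the `datasets` package installed:
# from datasets import Dataset
# ds = Dataset.from_pandas(df)
```

After this conversion the 'label' feature should be a float type rather than int, which is what the regression loss expects.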
https://stackoverflow.com/questions/73428120