语言:
- 英语
许可证: mit
标签:
- 文本分类
- 零样本分类
数据集:
- multi_nli
- facebook/anli
- fever
- lingnli
评估指标:
- 准确率
流水线标签: 零样本分类
DeBERTa-v3-xsmall-mnli-fever-anli-ling-binary
模型描述
该模型基于4个自然语言推理(NLI)数据集中的782,357个假设-前提对进行训练,这些数据集包括:MultiNLI、Fever-NLI、LingNLI和ANLI。
需要注意的是,该模型针对二元NLI任务训练,旨在预测“蕴含”或“不蕴含”。这一设计特别适用于零样本分类任务,其中“中立”与“矛盾”之间的差异无关紧要。
基础模型为微软的DeBERTa-v3-xsmall。DeBERTa的v3版本通过引入不同的预训练目标,显著超越了之前版本的性能,详见DeBERTa-V3论文。
若追求最高性能(而非速度),推荐使用MoritzLaurer/DeBERTa-v3-large-mnli-fever-anli-ling-wanli。
预期用途与限制
如何使用该模型
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model_name = "MoritzLaurer/DeBERTa-v3-xsmall-mnli-fever-anli-ling-binary"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name)
premise = "我最初以为我喜欢这部电影,但转念一想其实很令人失望。"
hypothesis = "这部电影不错。"
input = tokenizer(premise, hypothesis, truncation=True, return_tensors="pt")
output = model(input["input_ids"].to(device))
prediction = torch.softmax(output["logits"][0], -1).tolist()
label_names = ["蕴含", "不蕴含"]
prediction = {name: round(float(pred) * 100, 1) for pred, name in zip(prediction, label_names)}
print(prediction)
训练数据
该模型基于4个NLI数据集中的782,357个假设-前提对进行训练:MultiNLI、Fever-NLI、LingNLI和ANLI。
训练过程
DeBERTa-v3-xsmall-mnli-fever-anli-ling-binary使用Hugging Face训练器,并采用以下超参数进行训练:
training_args = TrainingArguments(
num_train_epochs=5, # 训练总轮数
learning_rate=2e-05,
per_device_train_batch_size=32, # 训练时每设备批次大小
per_device_eval_batch_size=32, # 评估时批次大小
warmup_ratio=0.1, # 学习率调度器的预热步数比例
weight_decay=0.06, # 权重衰减强度
fp16=True # 混合精度训练
)
评估结果
模型在MultiNLI、ANLI、LingNLI的二元测试集以及Fever-NLI的二元开发集(两类而非三类)上进行了评估。所用指标为准确率。
数据集 |
mnli-m-2c |
mnli-mm-2c |
fever-nli-2c |
anli-all-2c |
anli-r3-2c |
lingnli-2c |
准确率 |
0.925 |
0.922 |
0.892 |
0.676 |
0.665 |
0.888 |
速度(文本/秒,CPU,128批次) |
6.0 |
6.3 |
3.0 |
5.8 |
5.0 |
7.6 |
速度(文本/秒,GPU Tesla P100,128批次) |
473 |
487 |
230 |
390 |
340 |
586 |
限制与偏见
请参考原始DeBERTa论文及关于不同NLI数据集的文献以了解潜在的偏见。
引用
若使用此模型,请引用:Laurer, Moritz, Wouter van Atteveldt, Andreu Salleras Casas, 和 Kasper Welbers。2022年。“更少标注,更多分类——通过深度迁移学习和BERT-NLI解决监督机器学习的数据稀缺问题”。预印本,6月。Open Science Framework。https://osf.io/74b8k。
合作意向或问题?
如有合作意向或问题,请通过m{dot}laurer{at}vu{dot}nl或LinkedIn联系我。
调试与问题
请注意,DeBERTa-v3发布于2021年12月6日,旧版HF Transformers可能在运行模型时存在问题(例如导致分词器问题)。使用Transformers>=4.13版本可能解决部分问题。