语言:
- 英语
标签:
- 文本分类
- 零样本分类
评估指标:
- 准确率
流水线标签: 零样本分类
该模型基于MoritzLaurer/DeBERTa-v3-base-mnli仓库构建,并添加了handler.py文件,以便在推理端点中更轻松地使用该模型,对前提和假设进行零样本分类比较,判断其关系为蕴含、中性或矛盾。
DeBERTa-v3-base-mnli-fever-anli
模型描述
该模型在MultiNLI数据集上训练,包含392,702个NLI假设-前提对。基础模型是微软的DeBERTa-v3-base。DeBERTa的v3版本通过引入不同的预训练目标显著优于之前的模型版本,详见原始DeBERTa论文的附录11。如需更强大的模型,请查看DeBERTa-v3-base-mnli-fever-anli,该模型在更多数据上进行了训练。
预期用途与限制
如何使用模型
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch
model_name = "MoritzLaurer/DeBERTa-v3-base-mnli"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name)
premise = "我最初以为我喜欢这部电影,但仔细想想其实很失望。"
hypothesis = "这部电影很好。"
input = tokenizer(premise, hypothesis, truncation=True, return_tensors="pt")
output = model(input["input_ids"].to(device))
prediction = torch.softmax(output["logits"][0], -1).tolist()
label_names = ["蕴含", "中性", "矛盾"]
prediction = {name: round(float(pred) * 100, 1) for pred, name in zip(prediction, label_names)}
print(prediction)
示例cURL:
curl 你的推理端点URL \ -X POST \ -d '{"inputs": {"premise": "一个人在公园遛狗。", "hypothesis": "一个人和一只动物在外面。"}}' \ -H "Authorization: Bearer hf_你的令牌" \ -H "Content-Type: application/json
训练数据
该模型在MultiNLI数据集上训练,包含392,702个NLI假设-前提对。
训练过程
DeBERTa-v3-base-mnli使用Hugging Face训练器训练,超参数如下:
training_args = TrainingArguments(
num_train_epochs=5, # 训练总轮数
learning_rate=2e-05,
per_device_train_batch_size=32, # 训练时每设备批大小
per_device_eval_batch_size=32, # 评估批大小
warmup_ratio=0.1, # 学习率调度器的预热步数比例
weight_decay=0.06, # 权重衰减强度
fp16=True # 混合精度训练
)
评估结果
模型在匹配测试集上评估,准确率达到0.90。
限制与偏差
请参考原始DeBERTa论文及不同NLI数据集的相关文献以了解潜在偏差。
BibTeX引用信息
如需引用此模型,请引用原始DeBERTa论文、相关NLI数据集,并包含Hugging Face hub上该模型的链接。
合作意向或问题?
如有问题或合作意向,请联系m{dot}laurer{at}vu{dot}nl或LinkedIn
调试与问题
请注意,DeBERTa-v3近期发布,旧版HF Transformers可能在运行模型时存在问题(如分词器问题)。使用Transformers==4.13可能解决部分问题。
模型回收
在36个数据集上的评估显示,以MoritzLaurer/DeBERTa-v3-base-mnli为基础模型的平均得分为80.01,而microsoft/deberta-v3-base为79.04。
截至2023年9月1日,该模型在microsoft/deberta-v3-base架构的所有测试模型中排名第一。
结果:
20_newsgroup |
ag_news |
amazon_reviews_multi |
anli |
boolq |
cb |
cola |
copa |
dbpedia |
esnli |
financial_phrasebank |
imdb |
isear |
mnli |
mrpc |
multirc |
poem_sentiment |
qnli |
qqp |
rotten_tomatoes |
rte |
sst2 |
sst_5bins |
stsb |
trec_coarse |
trec_fine |
tweet_ev_emoji |
tweet_ev_emotion |
tweet_ev_hate |
tweet_ev_irony |
tweet_ev_offensive |
tweet_ev_sentiment |
wic |
wnli |
wsc |
yahoo_answers |
86.0196 |
90.6333 |
66.96 |
60.0938 |
83.792 |
83.9286 |
86.5772 |
72 |
79.2 |
91.419 |
85.1 |
94.232 |
71.5124 |
89.4426 |
90.4412 |
63.7583 |
86.5385 |
93.8129 |
91.9144 |
89.8687 |
85.9206 |
95.4128 |
57.3756 |
91.377 |
97.4 |
91 |
47.302 |
83.6031 |
57.6431 |
77.1684 |
83.3721 |
70.2947 |
71.7868 |
67.6056 |
74.0385 |
71.7 |
更多信息请参阅: 模型回收