language:
- zh
license: mit
tags:
- 文本分类
- 零样本分类
datasets:
- multi_nli
- facebook/anli
- fever
- lingnli
- alisawuffles/WANLI
metrics:
- 准确率
pipeline_tag: 零样本分类
model-index:
- name: DeBERTa-v3-large-mnli-fever-anli-ling-wanli
results:
- task:
type: 文本分类
name: 自然语言推理
dataset:
name: MultiNLI-matched
type: multi_nli
split: validation_matched
metrics:
- type: 准确率
value: 0,912
verified: false
- task:
type: 文本分类
name: 自然语言推理
dataset:
name: MultiNLI-mismatched
type: multi_nli
split: validation_mismatched
metrics:
- type: 准确率
value: 0,908
verified: false
- task:
type: 文本分类
name: 自然语言推理
dataset:
name: ANLI-all
type: anli
split: test_r1+test_r2+test_r3
metrics:
- type: 准确率
value: 0,702
verified: false
- task:
type: 文本分类
name: 自然语言推理
dataset:
name: ANLI-r3
type: anli
split: test_r3
metrics:
- type: 准确率
value: 0,64
verified: false
- task:
type: 文本分类
name: 自然语言推理
dataset:
name: WANLI
type: alisawuffles/WANLI
split: test
metrics:
- type: 准确率
value: 0,77
verified: false
- task:
type: 文本分类
name: 自然语言推理
dataset:
name: LingNLI
type: lingnli
split: test
metrics:
- type: 准确率
value: 0,87
verified: false
DeBERTa-v3-large-mnli-fever-anli-ling-wanli
模型描述
该模型在MultiNLI、Fever-NLI、Adversarial-NLI (ANLI)、LingNLI和WANLI数据集上进行了微调,共包含885,242个NLI假设-前提对。截至2022年6月6日,该模型是Hugging Face Hub上性能最好的NLI模型,可用于零样本分类。它在ANLI基准测试上显著优于所有其他大型模型。
基础模型是微软的DeBERTa-v3-large。与BERT、RoBERTa等经典掩码语言模型相比,DeBERTa-v3结合了几项创新技术,详见论文。
如何使用该模型
简单的零样本分类流程
from transformers import pipeline
classifier = pipeline("zero-shot-classification", model="MoritzLaurer/DeBERTa-v3-large-mnli-fever-anli-ling-wanli")
sequence_to_classify = "安格拉·默克尔是德国政治家,也是基民盟的领导人"
candidate_labels = ["政治", "经济", "娱乐", "环境"]
output = classifier(sequence_to_classify, candidate_labels, multi_label=False)
print(output)
NLI用例
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model_name = "MoritzLaurer/DeBERTa-v3-large-mnli-fever-anli-ling-wanli"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name)
premise = "我最初以为我喜欢这部电影,但后来觉得它其实很令人失望。"
hypothesis = "这部电影不好。"
input = tokenizer(premise, hypothesis, truncation=True, return_tensors="pt")
output = model(input["input_ids"].to(device))
prediction = torch.softmax(output["logits"][0], -1).tolist()
label_names = ["蕴含", "中立", "矛盾"]
prediction = {name: round(float(pred) * 100, 1) for pred, name in zip(prediction, label_names)}
print(prediction)
训练数据
DeBERTa-v3-large-mnli-fever-anli-ling-wanli在MultiNLI、Fever-NLI、Adversarial-NLI (ANLI)、LingNLI和WANLI数据集上进行了训练,共包含885,242个NLI假设-前提对。请注意,由于数据集质量问题,明确排除了SNLI。更多的数据并不一定意味着更好的NLI模型。
训练过程
DeBERTa-v3-large-mnli-fever-anli-ling-wanli使用Hugging Face训练器进行训练,超参数如下。请注意,在测试中,更长时间的训练和更多的轮次会损害性能(过拟合)。
training_args = TrainingArguments(
num_train_epochs=4, # 训练总轮次
learning_rate=5e-06,
per_device_train_batch_size=16, # 训练时每个设备的批量大小
gradient_accumulation_steps=2, # 将有效批量大小翻倍至32,同时降低内存需求
per_device_eval_batch_size=64, # 评估时的批量大小
warmup_ratio=0.06, # 学习率调度器的预热步数比例
weight_decay=0.01, # 权重衰减强度
fp16=True # 混合精度训练
)
评估结果
该模型使用MultiNLI、ANLI、LingNLI、WANLI的测试集和Fever-NLI的开发集进行评估。使用的指标是准确率。
该模型在每个数据集上都达到了最先进的性能。令人惊讶的是,它在ANLI基准测试上比之前的ALBERT-XXL模型高出8.3%。我认为这是因为ANLI是为了欺骗像RoBERTa(或ALBERT)这样的掩码语言模型而创建的,而DeBERTa-v3使用了更好的预训练目标(RTD)、解耦注意力机制,并且我在更高质量的NLI数据上对其进行了微调。
数据集 |
mnli_test_m |
mnli_test_mm |
anli_test |
anli_test_r3 |
ling_test |
wanli_test |
准确率 |
0.912 |
0.908 |
0.702 |
0.64 |
0.87 |
0.77 |
速度(文本/秒,A100 GPU) |
696.0 |
697.0 |
488.0 |
425.0 |
828.0 |
980.0 |
局限性和偏差
请参考原始DeBERTa-v3论文和有关不同NLI数据集的文献,以获取有关训练数据和潜在偏差的更多信息。该模型将重现训练数据中的统计模式。
引用
如果您使用该模型,请引用:Laurer, Moritz, Wouter van Atteveldt, Andreu Salleras Casas, and Kasper Welbers. 2022. ‘Less Annotating, More Classifying – Addressing the Data Scarcity Issue of Supervised Machine Learning with Deep Transfer Learning and BERT - NLI’. 预印本,6月。Open Science Framework。https://osf.io/74b8k。
合作想法或问题?
如果您有合作想法或问题,请通过m{dot}laurer{at}vu{dot}nl或LinkedIn联系我。
调试和问题
请注意,DeBERTa-v3于2021年12月6日发布,旧版本的HF Transformers似乎无法正常运行该模型(例如导致分词器问题)。使用Transformers>=4.13可能会解决一些问题。