datasets:
- snli
- anli
- multi_nli
- multi_nli_mismatch
- fever
license: mit
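This is a strong pre-trained RoBERTa-Large natural language inference (NLI) model.
The training data combines several well-known NLI datasets: SNLI, MNLI, FEVER-NLI, and ANLI (R1, R2, R3).
Other pre-trained NLI models, including RoBERTa, ALBert, BART, ELECTRA, and XLNet, are also available.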
Trained by Yixin Nie; see the original source.
Try the code snippet below:
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch

if __name__ == '__main__':
    max_length = 256

    premise = "Two women are embracing while holding to go packages."
    hypothesis = "The men are fighting outside a deli."

    hg_model_hub_name = "ynie/roberta-large-snli_mnli_fever_anli_R1_R2_R3-nli"
    # hg_model_hub_name = "ynie/albert-xxlarge-v2-snli_mnli_fever_anli_R1_R2_R3-nli"
    # hg_model_hub_name = "ynie/bart-large-snli_mnli_fever_anli_R1_R2_R3-nli"
    # hg_model_hub_name = "ynie/electra-large-discriminator-snli_mnli_fever_anli_R1_R2_R3-nli"
    # hg_model_hub_name = "ynie/xlnet-large-cased-snli_mnli_fever_anli_R1_R2_R3-nli"

    tokenizer = AutoTokenizer.from_pretrained(hg_model_hub_name)
    model = AutoModelForSequenceClassification.from_pretrained(hg_model_hub_name)

    # Encode the premise and hypothesis as a single sequence pair.
    tokenized_input_seq_pair = tokenizer.encode_plus(premise, hypothesis,
                                                     max_length=max_length,
                                                     return_token_type_ids=True, truncation=True)

    # Convert each field to a LongTensor with a batch dimension of 1.
    input_ids = torch.tensor(tokenized_input_seq_pair['input_ids'], dtype=torch.long).unsqueeze(0)
    # Note: BART has no 'token_type_ids'. If you use BART, delete the next line
    # and also remove the token_type_ids argument from the model call below.
    token_type_ids = torch.tensor(tokenized_input_seq_pair['token_type_ids'], dtype=torch.long).unsqueeze(0)
    attention_mask = torch.tensor(tokenized_input_seq_pair['attention_mask'], dtype=torch.long).unsqueeze(0)

    with torch.no_grad():  # inference only, so gradients are not needed
        outputs = model(input_ids,
                        attention_mask=attention_mask,
                        token_type_ids=token_type_ids,
                        labels=None)
    # Note:
    # "id2label": {
    #     "0": "entailment",
    #     "1": "neutral",
    #     "2": "contradiction"
    # },

    predicted_probability = torch.softmax(outputs[0], dim=1)[0].tolist()  # batch size is only 1

    print("Premise:", premise)
    print("Hypothesis:", hypothesis)
    print("Entailment:", predicted_probability[0])
    print("Neutral:", predicted_probability[1])
    print("Contradiction:", predicted_probability[2])
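The last step of the snippet turns the model's three logits into probabilities with a softmax and reads them in the order given by the `id2label` note. The sketch below reproduces only that post-processing step in plain Python, so it runs without downloading the model. The helper names `softmax` and `predict_label` and the logit values are hypothetical and for illustration only; they are not part of the model's API or real model output.

```python
import math

# Label order taken from the model's "id2label" config noted in the snippet.
ID2LABEL = {0: "entailment", 1: "neutral", 2: "contradiction"}


def softmax(logits):
    """Numerically stable softmax over a list of floats (hypothetical helper)."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def predict_label(logits):
    """Return (label, probabilities) for one premise-hypothesis pair's logits."""
    probs = softmax(logits)
    best = max(range(len(probs)), key=probs.__getitem__)
    return ID2LABEL[best], probs


if __name__ == "__main__":
    # Hypothetical logits for illustration only; not real model output.
    label, probs = predict_label([-2.1, 0.3, 3.4])
    print(label)  # prints: contradiction
    print([round(p, 3) for p in probs])
```

Subtracting the maximum logit before exponentiating does not change the result but avoids overflow for large logits.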
See here for more.
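Citation:
@inproceedings{nie-etal-2020-adversarial,
    title = "Adversarial {NLI}: A New Benchmark for Natural Language Understanding",
    author = "Nie, Yixin and
      Williams, Adina and
      Dinan, Emily and
      Bansal, Mohit and
      Weston, Jason and
      Kiela, Douwe",
    booktitle = "Proceedings of the 58th Annual Meeting of the Association for Computational Linguistics",
    year = "2020",
    publisher = "Association for Computational Linguistics",
}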