推理: false
许可证: mit
标签:
基于自监督调优的零样本文本分类模型(以albert-xxlarge-v2为基础)
该零样本文本分类模型通过自监督调优(SSTuning)训练而成。
其由论文《通过自监督调优实现零样本文本分类》(Zero-Shot Text Classification via Self-Supervised Tuning)首次提出,作者为刘超群、张文轩、陈桂珍、吴晓宝、Luu Anh Tuan、Chang Chip Hong及Bing Lidong,并发布于此代码库。
模型主干网络采用albert-xxlarge-v2。
模型描述
该模型使用名为首句预测(FSP)的学习目标对无标注数据进行调优。
FSP任务的设计综合考虑了无标注语料的特性以及分类任务的输入/输出格式。
训练集和验证集均通过FSP从无标注语料中构建。
调优过程中,采用类似BERT的预训练掩码语言模型(如RoBERTa和ALBERT)作为主干网络,并添加分类输出层。
FSP的学习目标是预测正确标签的索引,使用交叉熵损失函数进行模型调优。
模型变体
共发布三个版本模型,详情如下:
需注意,zero-shot-classify-SSTuning-base的训练数据量(20.48M)多于论文所述,此举旨在提升准确率。
使用场景与限制
该模型可直接用于情感分析、主题分类等零样本文本分类任务,无需额外微调。
标签数量建议控制在2~20个之间。
使用方法
可通过Colab笔记本体验模型。
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch, string, random
tokenizer = AutoTokenizer.from_pretrained("albert-xxlarge-v2")
model = AutoModelForSequenceClassification.from_pretrained("DAMO-NLP-SG/zero-shot-classify-SSTuning-ALBERT")
text = "I love this place! The food is always so fresh and delicious."
list_label = ["negative", "positive"]
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
list_ABC = [x for x in string.ascii_uppercase]
def check_text(model, text, list_label, shuffle=False):
list_label = [x+'.' if x[-1] != '.' else x for x in list_label]
list_label_new = list_label + [tokenizer.pad_token]* (20 - len(list_label))
if shuffle:
random.shuffle(list_label_new)
s_option = ' '.join(['('+list_ABC[i]+') '+list_label_new[i] for i in range(len(list_label_new))])
text = f'{s_option} {tokenizer.sep_token} {text}'
model.to(device).eval()
encoding = tokenizer([text],truncation=True, max_length=512,return_tensors='pt')
item = {key: val.to(device) for key, val in encoding.items()}
logits = model(**item).logits
logits = logits if shuffle else logits[:,0:len(list_label)]
probs = torch.nn.functional.softmax(logits, dim = -1).tolist()
predictions = torch.argmax(logits, dim=-1).item()
probabilities = [round(x,5) for x in probs[0]]
print(f'预测结果: {predictions} => ({list_ABC[predictions]}) {list_label_new[predictions]}')
print(f'置信度: {round(probabilities[predictions]*100,2)}%')
check_text(model, text, list_label)
BibTeX引用信息
@inproceedings{acl23/SSTuning,
author = {刘超群 and 张文轩 and 陈桂珍 and 吴晓宝 and Luu Anh Tuan and Chang Chip Hong and Bing Lidong},
title = {通过自监督调优实现零样本文本分类},
booktitle = {ACL 2023会议论文集},
year = {2023},
url = {https://arxiv.org/abs/2305.11442},
}