库名称: transformers
许可证: mit
数据集:
- google-research-datasets/go_emotions
基础模型: answerdotai/ModernBERT-large
语言:
- en
标签:
- 文本分类
- 情感分类
- 情感识别
- 情感检测
- 情感
- 多标签
评估指标:
- f1
- 精确率
- 召回率
概述
这是基于go_emotions数据集微调的ModernBERT-large模型,用于多标签分类任务。该模型可用于从英文文本中提取所有情感或检测特定情感。阈值是通过在验证集上最大化所有标签的宏观f1分数来选择的。
您可以使用Flash Attention 2来加速推理。
模型在不同类别上的表现差异很大(参见下方的评估指标表)。有些类别如钦佩、娱乐、乐观、恐惧、悔恨等表现出较高的识别质量,而有些类别如失望、领悟等由于训练数据中样本较少,模型识别较为困难。
使用方法
使用Huggingface Transformers可以轻松调用该模型。
ModernBERT架构在transformers 4.48.0及更高版本中支持,因此需要安装:
pip install "transformers>=4.48.0"
以下是如何提取文本中包含的情感:
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
tokenizer = AutoTokenizer.from_pretrained('fyaronskiy/ModernBERT-large-go-emotions')
model = AutoModelForSequenceClassification.from_pretrained('fyaronskiy/ModernBERT-large-go-emotions')
best_thresholds = [0.5510204081632653, 0.26530612244897955, 0.14285714285714285, 0.12244897959183673, 0.44897959183673464, 0.22448979591836732, 0.2040816326530612, 0.4081632653061224, 0.5306122448979591, 0.22448979591836732, 0.2857142857142857, 0.3061224489795918, 0.2040816326530612, 0.14285714285714285, 0.1020408163265306, 0.4693877551020408, 0.24489795918367346, 0.3061224489795918, 0.2040816326530612, 0.36734693877551017, 0.2857142857142857, 0.04081632653061224, 0.3061224489795918, 0.16326530612244897, 0.26530612244897955, 0.32653061224489793, 0.12244897959183673, 0.2040816326530612]
LABELS = ['admiration', 'amusement', 'anger', 'annoyance', 'approval', 'caring', 'confusion', 'curiosity', 'desire', 'disappointment', 'disapproval', 'disgust', 'embarrassment', 'excitement', 'fear', 'gratitude', 'grief', 'joy', 'love', 'nervousness', 'optimism', 'pride', 'realization', 'relief', 'remorse', 'sadness', 'surprise', 'neutral']
ID2LABEL = dict(enumerate(LABELS))
def detect_emotions(text):
inputs = tokenizer(text, truncation=True, add_special_tokens=True, max_length=128, return_tensors='pt')
with torch.no_grad():
logits = model(**inputs).logits
probas = torch.sigmoid(logits).squeeze(dim=0)
class_binary_labels = (probas > torch.tensor(best_thresholds)).int()
return [ID2LABEL[label_id] for label_id, value in enumerate(class_binary_labels) if value == 1]
print(detect_emotions('You have excellent service and the best coffee in the city, I love your coffee shop!'))
以下方法可以获取所有情感及其得分:
def predict(text):
inputs = tokenizer(text, truncation=True, add_special_tokens=True, max_length=128, return_tensors='pt')
with torch.no_grad():
logits = model(**inputs).logits
probas = torch.sigmoid(logits).squeeze(dim=0).tolist()
probas = [round(proba, 3) for proba in probas]
labels2probas = dict(zip(LABELS, probas))
probas_dict_sorted = dict(sorted(labels2probas.items(), key=lambda x: x[1], reverse=True))
return probas_dict_sorted
print(predict('You have excellent service and the best coffee in the city, I love your coffee shop!'))
在go-emotions测试集上的评估结果
|
精确率 |
召回率 |
f1分数 |
支持数 |
阈值 |
admiration |
0.68 |
0.72 |
0.7 |
504 |
0.55 |
amusement |
0.76 |
0.91 |
0.83 |
264 |
0.27 |
anger |
0.44 |
0.53 |
0.48 |
198 |
0.14 |
annoyance |
0.27 |
0.46 |
0.34 |
320 |
0.12 |
approval |
0.41 |
0.38 |
0.4 |
351 |
0.45 |
caring |
0.37 |
0.46 |
0.41 |
135 |
0.22 |
confusion |
0.36 |
0.51 |
0.42 |
153 |
0.2 |
curiosity |
0.45 |
0.77 |
0.57 |
284 |
0.41 |
desire |
0.66 |
0.46 |
0.54 |
83 |
0.53 |
disappointment |
0.41 |
0.26 |
0.32 |
151 |
0.22 |
disapproval |
0.39 |
0.54 |
0.45 |
267 |
0.29 |
disgust |
0.52 |
0.41 |
0.46 |
123 |
0.31 |
embarrassment |
0.52 |
0.41 |
0.45 |
37 |
0.2 |
excitement |
0.29 |
0.59 |
0.39 |
103 |
0.14 |
fear |
0.55 |
0.78 |
0.65 |
78 |
0.1 |
gratitude |
0.96 |
0.88 |
0.92 |
352 |
0.47 |
grief |
0.29 |
0.67 |
0.4 |
6 |
0.24 |
joy |
0.57 |
0.66 |
0.61 |
161 |
0.31 |
love |
0.74 |
0.87 |
0.8 |
238 |
0.2 |
nervousness |
0.37 |
0.43 |
0.4 |
23 |
0.37 |
optimism |
0.6 |
0.58 |
0.59 |
186 |
0.29 |
pride |
0.28 |
0.44 |
0.34 |
16 |
0.04 |
realization |
0.36 |
0.19 |
0.24 |
145 |
0.31 |
relief |
0.62 |
0.45 |
0.53 |
11 |
0.16 |
remorse |
0.51 |
0.84 |
0.63 |
56 |
0.27 |
sadness |
0.54 |
0.56 |
0.55 |
156 |
0.33 |
surprise |
0.47 |
0.63 |
0.54 |
141 |
0.12 |
neutral |
0.58 |
0.82 |
0.68 |
1787 |
0.2 |
微观平均 |
0.54 |
0.67 |
0.6 |
6329 |
|
宏观平均 |
0.5 |
0.58 |
0.52 |
6329 |
|
加权平均 |
0.55 |
0.67 |
0.6 |
6329 |
|
样本平均 |
0.59 |
0.69 |
0.61 |
6329 |
|