语言: 英语
许可证: MIT
任务标签: 文本分类
标签:
- 文本分类
- 变形金刚模型
- PyTorch
- 多标签分类
- 多类别分类
- 情感
- BERT
- go_emotions
- 情感分类
数据集:
- google-research-datasets/go_emotions
评估指标:
- F1分数
- 精确率
- 召回率
示例:
- 文本: 我今天就是放松一下。
示例标题: 中性示例
- 文本: 谢谢你救了我的命!
示例标题: 感激示例
- 文本: 我对明天的考试感到紧张。
示例标题: 紧张示例
基础模型:
- google-bert/bert-base-uncased
GoEmotions BERT分类器
基于bert-base-uncased微调,用于go_emotions的多标签分类(28种情感)。
模型详情
- 架构: BERT-base-uncased(1.1亿参数)
- 训练数据: GoEmotions(5.8万条Reddit评论,28种情感)
- 损失函数: 焦点损失(gamma=2)
- 优化器: AdamW(学习率=2e-5,权重衰减=0.01)
- 训练轮数: 5
- 硬件: Kaggle T4 x2 GPU
试用
为了获得使用优化阈值的准确预测,请访问Gradio演示。
性能
- 微平均F1: 0.6025(优化阈值)
- 宏平均F1: 0.5266
- 精确率: 0.5425
- 召回率: 0.6775
- 汉明损失: 0.0372
- 平均正预测数: 1.4564
分类性能
下表展示了使用优化阈值(见thresholds.json
)在测试集上的各类别指标:
情感 |
F1分数 |
精确率 |
召回率 |
支持数 |
钦佩 |
0.7022 |
0.6980 |
0.7063 |
504 |
娱乐 |
0.8171 |
0.7692 |
0.8712 |
264 |
愤怒 |
0.5123 |
0.5000 |
0.5253 |
198 |
恼怒 |
0.3820 |
0.2908 |
0.5563 |
320 |
赞同 |
0.4112 |
0.3485 |
0.5014 |
351 |
关心 |
0.4601 |
0.4045 |
0.5333 |
135 |
困惑 |
0.4488 |
0.4533 |
0.4444 |
153 |
好奇 |
0.5721 |
0.4402 |
0.8169 |
284 |
渴望 |
0.4068 |
0.6857 |
0.2892 |
83 |
失望 |
0.3476 |
0.3220 |
0.3775 |
151 |
不赞同 |
0.4126 |
0.3433 |
0.5169 |
267 |
厌恶 |
0.4950 |
0.6329 |
0.4065 |
123 |
尴尬 |
0.5000 |
0.7368 |
0.3784 |
37 |
兴奋 |
0.4084 |
0.4432 |
0.3786 |
103 |
恐惧 |
0.6311 |
0.5078 |
0.8333 |
78 |
感激 |
0.9173 |
0.9744 |
0.8665 |
352 |
悲伤 |
0.2500 |
0.5000 |
0.1667 |
6 |
快乐 |
0.6246 |
0.5798 |
0.6770 |
161 |
爱 |
0.8110 |
0.7630 |
0.8655 |
238 |
紧张 |
0.3830 |
0.3750 |
0.3913 |
23 |
乐观 |
0.5777 |
0.5856 |
0.5699 |
186 |
自豪 |
0.4138 |
0.4615 |
0.3750 |
16 |
领悟 |
0.2421 |
0.5111 |
0.1586 |
145 |
解脱 |
0.5385 |
0.4667 |
0.6364 |
11 |
悔恨 |
0.6797 |
0.5361 |
0.9286 |
56 |
悲伤 |
0.5391 |
0.6900 |
0.4423 |
156 |
惊讶 |
0.5724 |
0.5570 |
0.5887 |
141 |
中性 |
0.6895 |
0.5826 |
0.8444 |
1787 |
使用方法
该模型使用存储在thresholds.json
中的优化阈值进行预测。Python示例:
from transformers import BertForSequenceClassification, BertTokenizer
import torch
import json
import requests
repo_id = "logasanjeev/goemotions-bert"
model = BertForSequenceClassification.from_pretrained(repo_id)
tokenizer = BertTokenizer.from_pretrained(repo_id)
thresholds_url = f"https://huggingface.co/{repo_id}/raw/main/thresholds.json"
thresholds_data = json.loads(requests.get(thresholds_url).text)
emotion_labels = thresholds_data["emotion_labels"]
thresholds = thresholds_data["thresholds"]
text = "我今天就是放松一下。"
encodings = tokenizer(text, padding='max_length', truncation=True, max_length=128, return_tensors='pt')
with torch.no_grad():
logits = torch.sigmoid(model(**encodings).logits).numpy()[0]
predictions = [(emotion_labels[i], logit) for i, (logit, thresh) in enumerate(zip(logits, thresholds)) if logit >= thresh]
print(sorted(predictions, key=lambda x: x[1], reverse=True))