🚀 Distilbert-base-uncased-emotion
本模型基于Distilbert,在情感数据集上进行微调,可用于文本情感分类任务。它在保留较高语言理解能力的同时,模型体积更小、运行速度更快。
🚀 快速开始
若你想使用该模型进行文本情感分类,可参考以下代码示例:
from transformers import pipeline
classifier = pipeline("text-classification",model='bhadresh-savani/distilbert-base-uncased-emotion', return_all_scores=True)
prediction = classifier("I love using transformers. The best part is wide range of support and its easy to use", )
print(prediction)
"""
Output:
[[
{'label': 'sadness', 'score': 0.0006792712374590337},
{'label': 'joy', 'score': 0.9959300756454468},
{'label': 'love', 'score': 0.0009452480007894337},
{'label': 'anger', 'score': 0.0018055217806249857},
{'label': 'fear', 'score': 0.00041110432357527316},
{'label': 'surprise', 'score': 0.0002288572577526793}
]]
"""
✨ 主要特性
- 轻量级:Distilbert 在预训练阶段采用知识蒸馏技术,将BERT模型的大小缩小了40%,同时保留了97%的语言理解能力。
- 高性能:在情感数据集上进行微调,在准确率和F1分数等指标上表现出色。
- 速度快:相较于其他基于BERT的模型,Distilbert-base-uncased-emotion运行速度更快。
📦 安装指南
使用该模型前,你需要安装transformers
库,可使用以下命令进行安装:
pip install transformers
💻 使用示例
基础用法
from transformers import pipeline
classifier = pipeline("text-classification",model='bhadresh-savani/distilbert-base-uncased-emotion', return_all_scores=True)
prediction = classifier("I love using transformers. The best part is wide range of support and its easy to use", )
print(prediction)
"""
Output:
[[
{'label': 'sadness', 'score': 0.0006792712374590337},
{'label': 'joy', 'score': 0.9959300756454468},
{'label': 'love', 'score': 0.0009452480007894337},
{'label': 'anger', 'score': 0.0018055217806249857},
{'label': 'fear', 'score': 0.00041110432357527316},
{'label': 'surprise', 'score': 0.0002288572577526793}
]]
"""
📚 详细文档
模型描述
Distilbert 在预训练阶段使用知识蒸馏技术,将BERT模型的大小缩小了40%,同时保留了97%的语言理解能力。它比BERT和其他基于BERT的模型更小、更快。
Distilbert-base-uncased 在情感数据集上进行了微调,使用HuggingFace Trainer和以下超参数:
learning rate 2e-5,
batch size 64,
num_train_epochs=8,
模型性能比较
以下是在Twitter情感数据集上的模型性能比较:
数据集
本模型使用的数据集为 Twitter-Sentiment-Analysis。
训练过程
训练过程可参考 Colab Notebook。
评估结果
{
"test_accuracy": 0.938,
"test_f1": 0.937932884041714,
"test_loss": 0.1472451239824295,
"test_mem_cpu_alloc_delta": 0,
"test_mem_cpu_peaked_delta": 0,
"test_mem_gpu_alloc_delta": 0,
"test_mem_gpu_peaked_delta": 163454464,
"test_runtime": 5.0164,
"test_samples_per_second": 398.69
}
参考资料
📄 许可证
本模型采用Apache-2.0许可证。