模型简介
模型特点
模型能力
使用案例
许可证:apache-2.0
基础模型:distilbert-base-cased
标签:
- generated_from_trainer
- news_classification
- multi_label
数据集: - reuters21578
评估指标: - f1
- accuracy
模型索引:
- 名称:distilbert-finetuned-reuters21578-multilabel
结果:- 任务:
名称:文本分类
类型:text-classification
数据集:
名称:reuters21578
类型:reuters21578
配置:ModApte
分割:test
参数:ModApte
评估指标:- 名称:F1
类型:f1
值:0.8628858578607322 - 名称:准确率
类型:accuracy
值:0.8195625759416768
- 名称:F1
- 任务:
语言:
- en
流水线标签:text-classification
小部件示例:
- 文本:"JAPAN TO REVISE LONG-TERM ENERGY DEMAND DOWNWARDS..."
示例标题:"示例-1"
模型来源
此模型派生自 https://huggingface.co/lxyuan/distilbert-finetuned-reuters21578-multilabel —— 我只是在 /onnx 中生成了 ONNX 版本。
动机
在 Reuters-21578 多标签数据集上进行微调是一项有价值的练习,尤其是在面试的家庭测试中经常使用。该数据集的复杂性非常适合在有限时间内测试多标签分类技能,而其现实相关性有助于模拟实际挑战。通过此数据集进行实验不仅帮助候选人准备面试,还磨练了包括预处理、特征提取和模型评估在内的各种技能。
此模型是在 reuters21578 数据集上对 distilbert-base-cased 进行微调的版本。
推理示例
from transformers import pipeline
pipe = pipeline("text-classification", model="lxyuan/distilbert-finetuned-reuters21578-multilabel", return_all_scores=True)
# dataset["test"]["text"][2]
news_article = (
"JAPAN TO REVISE LONG-TERM ENERGY DEMAND DOWNWARDS The Ministry of International Trade and "
"Industry (MITI) will revise its long-term energy supply/demand "
"outlook by August to meet a forecast downtrend in Japanese "
"energy demand, ministry officials said. "
"MITI is expected to lower the projection for primary energy "
"supplies in the year 2000 to 550 mln kilolitres (kl) from 600 "
"mln, they said. "
"The decision follows the emergence of structural changes in "
"Japanese industry following the rise in the value of the yen "
"and a decline in domestic electric power demand. "
"MITI is planning to work out a revised energy supply/demand "
"outlook through deliberations of committee meetings of the "
"Agency of Natural Resources and Energy, the officials said. "
"They said MITI will also review the breakdown of energy "
"supply sources, including oil, nuclear, coal and natural gas. "
"Nuclear energy provided the bulk of Japan's electric power "
"in the fiscal year ended March 31, supplying an estimated 27 "
"pct on a kilowatt/hour basis, followed by oil (23 pct) and "
"liquefied natural gas (21 pct), they noted. "
"REUTER"
)
# dataset["test"]["topics"][2]
target_topics = ['crude', 'nat-gas']
fn_kwargs={"padding": "max_length", "truncation": True, "max_length": 512}
output = pipe(example, function_to_apply="sigmoid", **fn_kwargs)
for item in output[0]:
if item["score"]>=0.5:
print(item["label"], item["score"])
>>> crude 0.7355073690414429
nat-gas 0.8600426316261292
总体总结与对比表
指标 | 基线(Scikit-learn) | Transformer 模型 |
---|---|---|
微平均 F1 | 0.77 | 0.86 |
宏平均 F1 | 0.29 | 0.33 |
加权平均 F1 | 0.70 | 0.84 |
样本平均 F1 | 0.75 | 0.80 |
精确率 vs 召回率:两个模型都优先考虑高精确率而非召回率。在我们的面向客户的新闻分类模型中,精确率优先于召回率。这是因为假阳性的后果比假阴性更严重,且更难向客户解释。当模型错误地为新闻项标记主题时,解释这一错误具有挑战性。另一方面,如果模型遗漏了主题,可以通过说明新闻文章中未充分强调该主题来更容易地辩护。
类别不平衡处理:两个模型都存在对少数类别表现不佳的普遍问题,这在低宏平均 F1 分数中有所反映。然而,Transformer 模型在宏平均 F1 分数上略有改善(0.33 vs 0.29)。
零支持标签问题:两个模型都存在多个标签的零支持问题,意味着这些标签未出现在测试集中。这种“支持”的缺乏可能会显著扭曲性能指标,并可能表明模型未针对这些少数类别进行良好调整,或者数据集本身缺乏这些类别的足够示例。鉴于两个模型在低宏平均 F1 分数上都表现不佳,这一问题进一步强调了改进模型对少数类别处理的必要性。
总体性能:Transformer 模型在加权和样本平均 F1 分数上超过了 scikit-learn 基线,表明整体性能更好,对标签不平衡的处理也更优。
结论:虽然两个模型都表现出高精确率(这是业务需求),但 Transformer 模型在所有考虑的指标上略优于 scikit-learn 基线模型。它在精确率和召回率之间提供了更好的权衡,并且在处理少数类别方面也有一些小的改进。因此,尽管与基线模型存在相似的弱点,Transformer 模型展示了在生产环境中可能具有重要意义的增量改进。
训练与评估数据
我们使用以下代码从训练集和测试集中移除仅出现一次的标签:
# 查找仅出现一次的标签
def find_single_appearance_labels(y):
"""查找数据集中仅出现一次的标签。"""
all_labels = list(chain.from_iterable(y))
label_count = Counter(all_labels)
single_appearance_labels = [label for label, count in label_count.items() if count == 1]
return single_appearance_labels
# 从数据集中移除仅出现一次的标签
def remove_single_appearance_labels(dataset, single_appearance_labels):
"""从训练集和测试集中移除带有仅出现一次标签的样本。"""
for split in ['train', 'test']:
dataset[split] = dataset[split].filter(lambda x: all(label not in single_appearance_labels for label in x['topics']))
return dataset
dataset = load_dataset("reuters21578", "ModApte")
# 查找并移除仅出现一次的标签
y_train = [item['topics'] for item in dataset['train']]
single_appearance_labels = find_single_appearance_labels(y_train)
print(f"仅出现一次的标签:{single_appearance_labels}")
>>> 仅出现一次的标签:['lin-oil', 'rye', 'red-bean', 'groundnut-oil', 'citruspulp', 'rape-meal', 'corn-oil', 'peseta', 'cotton-oil', 'ringgit', 'castorseed', 'castor-oil', 'lit', 'rupiah', 'skr', 'nkr', 'dkr', 'sun-meal', 'lin-meal', 'cruzado']
print("正在移除带有仅出现一次标签的样本...")
dataset = remove_single_appearance_labels(dataset, single_appearance_labels)
unique_labels = set(chain.from_iterable(dataset['train']["topics"]))
print(f"我们有 {len(unique_labels)} 个唯一标签:\n{unique_labels}")
>>> 我们有 95 个唯一标签:
{'veg-oil', 'gold', 'platinum', 'ipi', 'acq', 'carcass', 'wool', 'coconut-oil', 'linseed', 'copper', 'soy-meal', 'jet', 'dlr', 'copra-cake', 'hog', 'rand', 'strategic-metal', 'can', 'tea', 'sorghum', 'livestock', 'barley', 'lumber', 'earn', 'wheat', 'trade', 'soy-oil', 'cocoa', 'inventories', 'income', 'rubber', 'tin', 'iron-steel', 'ship', 'rapeseed', 'wpi', 'sun-oil', 'pet-chem', 'palmkernel', 'nat-gas', 'gnp', 'l-cattle', 'propane', 'rice', 'lead', 'alum', 'instal-debt', 'saudriyal', 'cpu', 'jobs', 'meal-feed', 'oilseed', 'dmk', 'plywood', 'zinc', 'retail', 'dfl', 'cpi', 'crude', 'pork-belly', 'gas', 'money-fx', 'corn', 'tapioca', 'palladium', 'lei', 'cornglutenfeed', 'sunseed', 'potato', 'silver', 'sugar', 'grain', 'groundnut', 'naphtha', 'orange', 'soybean', 'coconut', 'stg', 'cotton', 'yen', 'rape-oil', 'palm-oil', 'oat', 'reserves', 'housing', 'interest', 'coffee', 'fuel', 'austdlr', 'money-supply', 'heat', 'fishmeal', 'bop', 'nickel', 'nzdlr'}
训练过程
Reuters-21578 数据集的探索性数据分析(EDA):
此笔记本提供了 Reuters-21578 数据集的探索性数据分析(EDA)。它包括可视化和统计摘要,提供了对数据集结构、标签分布和文本特征的深入理解。
Reuters 基线 Scikit-Learn 模型:
此笔记本使用 scikit-learn 为 Reuters-21578 数据集建立了文本分类的基线模型。它指导您完成数据预处理、特征提取、模型训练和评估。
Reuters Transformer 模型:
此笔记本深入探讨了使用 Transformer 模型对 Reuters-21578 数据集进行高级文本分类。它涵盖了针对此特定任务使用基于 Transformer 的模型的实现细节、训练过程和性能指标。
Reuters 数据集上的多标签分层抽样和超参数搜索:
在此笔记本中,我们通过 Hugging Face Trainer API 的视角探索了高级机器学习技术,特别是针对多标签迭代分层分割和超参数搜索。前者旨在在 k 折交叉验证中公平分配不平衡数据集中的多个标签,保持与完整数据集相似的分布。后者引导用户通过结构化的超参数搜索来微调模型性能以获得最佳结果。
评估结果
Transformer 模型评估结果
分类报告:
precision recall f1-score support
acq 0.97 0.93 0.95 719
alum 1.00 0.70 0.82 23
austdlr 0.00 0.00 0.00 0
barley 1.00 0.50 0.67 12
bop 0.79 0.50 0.61 30
can 0.00 0.00 0.00 0
carcass 0.67 0.67 0.67 18
cocoa 1.00 1.00 1.00 18
coconut 0.00 0.00 0.00 2
coconut-oil 0.00 0.00 0.00 2
coffee 0.86 0.89 0.87 27
copper 1.00 0.78 0.88 18
copra-cake 0.00 0.00 0.00 1
corn 0.84 0.87 0.86 55
cornglutenfeed 0.00 0.00 0.00 0
cotton 0.92 0.67 0.77 18
cpi 0.86 0.43 0.57 28
cpu 0.00 0.00 0.00 1
crude 0.87 0.93 0.90 189
dfl 0.00 0.00 0.00 1
dlr 0.72 0.64 0.67 44
dmk 0.00 0.00 0.00 4
earn 0.98 0.99 0.98 1087
fishmeal 0.00 0.00 0.00 0
fuel 0.00 0.00 0.00 10
gas 0.80 0.71 0.75 17
gnp 0.79 0.66 0.72 35
gold 0.95 0.67 0.78 30
grain 0.94 0.92 0.93 146
groundnut 0.00 0.00 0.00 4
heat 0.00 0.00 0.00 5
hog 1.00 0.33 0.50 6
housing 0.00 0.00 0.00 4
income 0.00 0.00 0.00 7
instal-debt 0.00 0.00 0.00 1
interest 0.89 0.67 0.77 131
inventories 0.00 0.00 0.00 0
ipi 1.00 0.58 0.74 12
iron-steel 0.90 0.64 0.75 14
jet 0.00 0.00 0.00 1
jobs 0.92 0.57 0.71 21
l-cattle 0.00 0.00 0.00 2
lead 0.00 0.00 0.00 14
lei 0.00 0.00 0.00 3
linseed 0.00 0.00 0.00 0
livestock 0.63 0.79 0.70 24
lumber 0.00 0.00 0.00 6
meal-feed 0.00 0.00 0.00 17
money-fx 0.78 0.81 0.80 177
money-supply 0.80 0.71 0.75 34
naphtha 0.00 0.00 0.00 4
nat-gas 0.82 0.60 0.69 30
nickel 0.00 0.00 0.00 1
nzdlr 0.00 0.00 0.00 2
oat 0.00 0.00 0.00 4
oilseed 0.64 0.61 0.63 44
orange 1.00 0.36 0.53 11
palladium 0.00 0.00 0.00 1
palm-oil 1.00 0.56 0.71 9
palmkernel 0.00 0.00 0.00 1
pet-chem 0.00 0.00 0.00 12
platinum 0.00 0.00 0.00 7
plywood 0.00 0.00 0.00 0
pork-belly 0.00 0.00 0.00 0
potato 0.00 0.00 0.00 3
propane 0.00 0.00 0.00 3
rand 0.00 0.00 0.00 1
rape-oil 0.00 0.00 0.00 1
rapeseed 0.00 0.00 0.00 8
reserves 0.83 0.56 0.67 18
retail 0.00 0.00 0.00 2
rice 1.00 0.57 0.72 23
rubber 0.82 0.75 0.78 12
saudriyal 0.00 0.00 0.00 0
ship 0.95 0.81 0.87 89
silver 1.00 0.12 0.22 8
sorghum 1.00 0.12 0.22 8
soy-meal 0.00 0.00 0.00 12
soy-oil 0.00 0.00 0.00 8
soybean 0.72 0.56 0.63 32
stg 0.00 0.00 0.00 0
strategic-metal 0.00 0.00 0.00 11
sugar 1.00 0.80 0.89 35
sun-oil 0.00 0.00 0.00 0
sunseed 0.00 0.00 0.00 5
tapioca 0.00 0.00 0.00 0
tea 0.00 0.00 0.00 3
tin 1.00 0.42 0.59 12
trade 0.78 0.79 0.79 116
veg-oil 0.91 0.59 0.71 34
wheat 0.83 0.83 0.83 69
wool 0.00 0.00 0.00 0
wpi 0.00 0.00 0.00 10
yen 0.57 0.29 0.38 14
zinc 1.00 0.69 0.82 13
微平均 0.92 0.81 0.86 3694
宏平均 0.41 0.30 0.33 3694
加权平均 0.87 0.81 0.84 3694
样本平均 0.81 0.80 0.80 3694
Scikit-learn 基线模型评估结果
分类报告: precision recall f1-score support acq 0.98 0.87 0.92 719
alum 1.00 0.00 0.00 23
austdlr 1.00 1.00 1.00 0
barley 1.00 0.00 0.00 12
bop 1.00 0.30 0.46 30
can 1.00 1.00 1.00 0
carcass 1.00 0.06 0.11 18
cocoa 1.00 0.61 0.76 18
coconut 1.00 0.00 0.00 2
coconut-oil 1.00 0.00 0.00 2
coffee 0.94 0.59 0.73 27
copper 1.00 0.22 0.36 18
copra-cake 1.00 0.00 0.00 1
corn 0.97 0.51 0.67 55
cornglutenfeed 1.00 1.00 1.00 0
cotton 1.00 0.06 0.11 18
cpi 1.00 0.14 0.25 28
cpu 1.00 0.00 0.00 1
crude 0.94 0.69 0.80 189
dfl 1.00 0.00 0.00 1
dlr 0.86 0.43 0.58 44
dmk 1.00 0.00 0.00 4
earn 0.99 0.97 0.98 1087
fishmeal 1.00 1.00 1.00 0
fuel 1.00 0.00 0.00 10
gas 1.00 0.00 0.00 17
gnp 1.00 0.31 0.48 35
gold 0.83 0.17 0.28 30
grain 1.00 0.65 0.79 146
groundnut 1.00 0.00 0.00 4
heat 1.00 0.00 0.00 5
hog 1.00 0.00 0.00 6
housing 1.00 0.00 0.00 4
income 1.00 0.00 0.00 7
instal-debt 1.00 0.00 0.00 1
interest 0.88 0.40 0.55 131
inventories 1.00 1.00 1.00 0
ipi 1.00 0.00 0.00 12
iron-steel 1.00 0.00 0.00 14
jet 1.00 0.00 0.00 1
jobs 1.00 0.14 0.25 21
l-cattle 1.00 0.00 0.00 2
lead 1.00 0.00 0.00 14
lei 1.00 0.00 0.00 3
linseed 1.00 1.00 1.00 0
livestock 0.67 0.08 0.15 24
lumber 1.00 0.00 0.00 6
meal-feed 1.00 0.00 0.00 17
money-fx 0.80 0.50 0.62 177
money-supply 0.88 0.41 0.56 34
naphtha 1.00 0.00 0.00 4
nat-gas 1.00 0.27 0.42 30
nickel 1.00 0.00 0.00 1
nzdlr 1.00 0.00 0.00 2
oat 1.00 0.00 0.00 4
oilseed 0.62 0.11 0.19 44
orange 1.00 0.00 0.00 11
palladium 1.00 0.00 0.00 1
palm-oil 1.00 0.22 0.36 9
palmkernel 1.00 0.00 0.00 1
pet-chem 1.00 0.00 0.00 12
platinum 1.00 0.00 0.00 7
plywood 1.00 1.00 1.00 0
pork-belly 1.00 1.00 1.00 0
potato 1.00 0.00 0.00 3
propane 1.00 0.00 0.00 3
rand 1.00 0.00 0.00 1
rape-oil 1.00 0.00 0.00 1
rapeseed 1.00 0.00 0.00 8
reserves 1.00 0.00 0.00 18
retail 1.00 0.00 0.00 2
rice 1.00 0.00 0.00 23
rubber 1.00 0.17 0.29 12
saudriyal 1.00 1.00 1.00 0
ship 0.92 0.26 0.40 89
silver 1.00 0.00 0.00 8
sorghum 1.00 0.00 0.00 8
soy-meal 1.00 0.00 0.00 12
soy-oil 1.00 0.00 0.00 8
soybean 1.00 0.16 0.27 32
stg 1.00 1.00 1.00 0
strategic-metal 1.00 0.00 0.00 11
sugar 1.00 0.60 0.75 35
sun-oil 1.00 1.00 1.00 0
sunseed 1.00 0.00 0.00 5
tapioca 1.00 1.00 1.00 0
tea 1.00 0.00 0.00 3
tin 1.00 0.00 0.00 12
trade 0.92 0.61 0.74 116
veg-oil 1.00 0.12 0.21 34
wheat 0.97 0.55 0.70 69
wool 1.00 1.00 1.00 0
wpi 1.00 0.00 0.00 10
yen 1.00 0.00 0.00 14
zinc 1.00 0.00 0.00 13
微平均 0.97 0.64 0.77 3694
宏平均 0.98 0.25 0.29 3694
加权平均 0.96 0.64 0.70 3694
样本平均 0.98 0.74 0.75 3694
训练超参数
训练期间使用了以下超参数:
- 学习率:2e-05
- 训练批次大小:32
- 评估批次大小:32
- 随机种子:42
- 优化器:Adam,betas=(0.9,0.999),epsilon=1e-08
- 学习率调度器类型:linear
- 训练轮数:20
训练结果
训练损失 | 轮数 | 步数 | 验证损失 | F1 | Roc Auc | 准确率 |
---|---|---|---|---|---|---|
0.1801 | 1.0 | 300 | 0.0439 | 0.3896 | 0.6210 | 0.3566 |
0.0345 | 2.0 | 600 | 0.0287 | 0.6289 | 0.7318 | 0.5954 |
0.0243 | 3.0 | 900 | 0.0219 | 0.6721 | 0.7579 | 0.6084 |
0.0178 | 4.0 | 1200 | 0.0177 | 0.7505 | 0.8128 | 0.6908 |
0.014 | 5.0 | 1500 | 0.0151 | 0.7905 | 0.8376 | 0.7278 |
0.0115 | 6.0 | 1800 | 0.0135 | 0.8132 | 0.8589 | 0.7555 |
0.0096 | 7.0 | 2100 | 0.0124 | 0.8291 | 0.8727 | 0.7725 |
0.0082 | 8.0 | 2400 | 0.0124 | 0.8335 | 0.8757 | 0.7822 |
0.0071 | 9.0 | 2700 | 0.0119 | 0.8392 | 0.8847 | 0.7883 |
0.0064 | 10.0 | 3000 | 0.0123 | 0.8339 | 0.8810 | 0.7828 |
0.0058 | 11.0 | 3300 | 0.0114 | 0.8538 | 0.8999 | 0.8047 |
0.0053 | 12.0 | 3600 | 0.0113 | 0.8525 | 0.8967 | 0.8044 |
0.0048 | 13.0 | 3900 | 0.0115 | 0.8520 | 0.8982 | 0.8029 |
0.0045 | 14.0 | 4200 | 0.0111 | 0.8566 | 0.8962 | 0.8104 |
0.0042 | 15.0 | 4500 | 0.0110 | 0.8610 | 0.9060 | 0.8165 |
0.0039 | 16.0 | 4800 | 0.0112 | 0.8583 | 0.9021 | 0.8138 |
0.0037 | 17.0 | 5100 | 0.0110 | 0.8620 | 0.9055 | 0.8196 |
0.0035 | 18.0 | 5400 | 0.0110 | 0.8629 | 0.9063 | 0.8196 |
0.0035 | 19.0 | 5700 | 0.0111 | 0.8624 | 0.9062 | 0.8180 |
0.0034 | 20.0 | 6000 | 0.0111 | 0.8626 | 0.9055 | 0.8177 |
框架版本
- Transformers 4.33.0.dev0
- Pytorch 2.0.1+cu117
- Datasets 2.14.3
- Tokenizers 0.13.3








