🚀 胸部X光肺炎分类视觉变换器模型
本模型用于胸部X光图像的肺炎分类,基于预训练模型微调而来,在评估集上表现出色,能高效准确地识别肺炎与正常的胸部X光影像,为医疗诊断提供有力支持。
🚀 快速开始
本模型是 google/vit-base-patch16-224-in21k 在胸部X光分类数据集上的微调版本。
它在评估集上取得了以下结果:
💻 使用示例
基础用法
from transformers import pipeline
classifier = pipeline(model="lxyuan/vit-xray-pneumonia-classification")
classifier("https://d2jx2rerrg6sh3.cloudfront.net/image-handler/ts/20200618040600/ri/650/picture/2020/6/shutterstock_786937069.jpg")
>>>
[{'score': 0.990334689617157, 'label': 'PNEUMONIA'},
{'score': 0.009665317833423615, 'label': 'NORMAL'}]
📚 详细文档
训练过程
笔记本链接:点击查看
训练超参数
训练过程中使用了以下超参数:
- 学习率:5e-05
- 训练批次大小:16
- 评估批次大小:16
- 随机种子:42
- 梯度累积步数:4
- 总训练批次大小:64
- 优化器:Adam(β1 = 0.9,β2 = 0.999,ε = 1e-08)
- 学习率调度器类型:线性
- 学习率调度器预热比例:0.1
- 训练轮数:15
from transformers import EarlyStoppingCallback
training_args = TrainingArguments(
output_dir="vit-xray-pneumonia-classification",
remove_unused_columns=False,
evaluation_strategy="epoch",
save_strategy="epoch",
logging_strategy="epoch",
learning_rate=5e-5,
per_device_train_batch_size=16,
gradient_accumulation_steps=4,
per_device_eval_batch_size=16,
num_train_epochs=15,
save_total_limit=2,
warmup_ratio=0.1,
load_best_model_at_end=True,
metric_for_best_model="eval_loss",
greater_is_better=False,
fp16=True,
push_to_hub=True,
report_to="tensorboard"
)
early_stopping = EarlyStoppingCallback(early_stopping_patience=3)
trainer = Trainer(
model=model,
args=training_args,
data_collator=data_collator,
train_dataset=train_ds,
eval_dataset=val_ds,
tokenizer=processor,
compute_metrics=compute_metrics,
callbacks=[early_stopping],
)
训练结果
训练损失 |
轮数 |
步数 |
验证损失 |
准确率 |
0.5152 |
0.99 |
63 |
0.2507 |
0.9245 |
0.2334 |
1.99 |
127 |
0.1766 |
0.9382 |
0.1647 |
3.0 |
191 |
0.1218 |
0.9588 |
0.144 |
4.0 |
255 |
0.1222 |
0.9502 |
0.1348 |
4.99 |
318 |
0.1293 |
0.9571 |
0.1276 |
5.99 |
382 |
0.1000 |
0.9665 |
0.1175 |
7.0 |
446 |
0.1177 |
0.9502 |
0.109 |
8.0 |
510 |
0.1079 |
0.9665 |
0.0914 |
8.99 |
573 |
0.0804 |
0.9717 |
0.0872 |
9.99 |
637 |
0.0800 |
0.9717 |
0.0804 |
11.0 |
701 |
0.0862 |
0.9682 |
0.0935 |
12.0 |
765 |
0.0883 |
0.9657 |
0.0686 |
12.99 |
828 |
0.0868 |
0.9742 |
框架版本
- Transformers 4.30.2
- Pytorch 1.9.0+cu102
- Datasets 2.12.0
- Tokenizers 0.13.3
📄 许可证
本模型采用 Apache-2.0 许可证。
属性 |
详情 |
模型类型 |
图像分类 |
训练数据 |
chest-xray-classification、keremberke/chest-xray-classification |
评估指标 |
准确率 |
基础模型 |
google/vit-base-patch16-224-in21k |