许可协议: apache-2.0
标签:
- 训练生成
数据集:
- 图像文件夹
评估指标:
- 准确率
基础模型: google/vit-base-patch16-224-in21k
模型索引:
- 名称: vit-artworkclassifier
结果:
- 任务:
类型: 图像分类
名称: 图像分类
数据集:
名称: 图像文件夹
类型: 图像文件夹
配置: artbench10-vit
拆分: 测试集
参数: artbench10-vit
评估指标:
- 类型: 准确率
值: 0.5947786606129398
名称: 准确率
vit-artworkclassifier
该模型可返回任意输入图像的艺术风格类别。
本模型是基于google/vit-base-patch16-224-in21k在图像文件夹数据集上微调的版本。该数据集是artbench-10数据集(https://www.kaggle.com/datasets/alexanderliao/artbench10)的子集,包含每类别1000张训练图像和100张验证图像。
在评估集上取得如下结果:
模型描述
关于本模型训练项目的详细说明可参阅: https://medium.com/@oliverpj.schamp/training-and-evaluating-stable-diffusion-for-artwork-generation-b099d1f5b7a6
使用范围与限制
本模型仅包含artbench-10中的9个艺术类别——未包含浮世绘风格(ukiyo_e),这是由于数据可用性和格式问题所致。
训练与评估数据
训练集: 每类别从artbench-10随机选取1000张图像
验证集: 每类别从artbench-10随机选取100张图像
训练流程
训练超参数
训练过程中使用以下超参数:
- 学习率: 0.0001
- 训练批大小: 32
- 评估批大小: 8
- 随机种子: 42
- 优化器: Adam (β1=0.9, β2=0.999, ε=1e-08)
- 学习率调度器类型: 线性
- 训练轮次: 4
- 混合精度训练: 原生AMP
训练结果
训练损失 |
训练轮次 |
步数 |
验证损失 |
准确率 |
1.5906 |
0.36 |
100 |
1.4709 |
0.4847 |
1.3395 |
0.72 |
200 |
1.3208 |
0.5074 |
1.1461 |
1.08 |
300 |
1.3363 |
0.5165 |
0.9593 |
1.44 |
400 |
1.1790 |
0.5846 |
0.8761 |
1.8 |
500 |
1.1252 |
0.5902 |
0.5922 |
2.16 |
600 |
1.1392 |
0.5948 |
0.4803 |
2.52 |
700 |
1.1560 |
0.5936 |
0.4454 |
2.88 |
800 |
1.1545 |
0.6118 |
0.2271 |
3.24 |
900 |
1.2284 |
0.6039 |
0.207 |
3.6 |
1000 |
1.2625 |
0.5959 |
0.1958 |
3.96 |
1100 |
1.2621 |
0.6005 |
框架版本
- Transformers 4.26.1
- Pytorch 1.13.1+cu117
- Datasets 2.9.0
- Tokenizers 0.13.2
运行代码
def vit_classify(image):
vit = ViTForImageClassification.from_pretrained("oschamp/vit-artworkclassifier")
vit.eval()
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
vit.to(device)
model_name_or_path = 'google/vit-base-patch16-224-in21k'
feature_extractor = ViTFeatureExtractor.from_pretrained(model_name_or_path)
encoding = feature_extractor(images=image, return_tensors="pt")
encoding.keys()
pixel_values = encoding['pixel_values'].to(device)
outputs = vit(pixel_values)
logits = outputs.logits
prediction = logits.argmax(-1)
return prediction.item()