许可证: cc-by-nc-4.0
数据集:
- issai/中亚食物数据集
语言:
- 英语
评估指标:
- 准确率
- F1值
基础模型:
- microsoft/resnet-50
任务类型: 图像分类
标签:
- 分类
- 图像
- pytorch
- 安全张量
- ResNet
库名称: transformers
中亚图像分类ResNet-50模型
模型描述
这是一个基于ResNet-50架构的预训练模型,经过中亚食物数据集的微调。该模型用于多类别图像分类任务。数据被划分为训练集、验证集和测试集。模型训练采用随机梯度下降优化器(SGD)和交叉熵损失函数。
训练参数
- 训练轮次: 25
- 批次大小: 32
- 学习率: 0.001
- 优化器: 带动量(0.9)的SGD
- 损失函数: 交叉熵损失
结果
训练与验证指标(F1值)
阶段 |
训练损失 |
训练准确率 |
验证损失 |
验证准确率 |
第1轮 |
2.1171 |
47.00% |
0.8727 |
75.00% |
第2轮 |
1.0462 |
69.00% |
0.6721 |
78.00% |
... |
... |
... |
... |
... |
第25轮 |
0.4286 |
86.00% |
0.4349 |
86.00% |
模型在Kaggle笔记本上使用两块T4 GPU训练,耗时36分7秒
最佳验证准确率: 86.54%
精确率 召回率 F1值 支持数
阿奇楚克 0.91 0.98 0.94 41
艾兰-卡提克 0.84 0.93 0.89 46
阿西普 0.78 0.57 0.66 37
包吾尔萨克 0.90 0.90 0.90 62
带马肉的别什巴尔马克 0.71 0.84 0.77 44
不带马肉的别什巴尔马克 0.86 0.69 0.76 61
恰克恰克 0.94 0.94 0.94 93
切布列克 0.92 0.88 0.90 94
多纳尔卷饼 0.77 1.00 0.87 20
馕包肉 0.86 0.82 0.84 22
霍罗斯特 0.98 0.86 0.91 141
伊林希克 0.96 0.94 0.95 175
卡塔玛馕 0.84 0.88 0.86 66
卡兹灌肠 0.72 0.78 0.75 46
库尔特 0.86 0.97 0.91 61
库尔达克 0.92 0.93 0.92 58
库梅兹发酵乳 0.93 0.82 0.87 49
炒拉条子 0.86 0.95 0.90 38
汤拉条子 0.90 0.80 0.85 75
干拌拉条子 0.58 0.86 0.69 22
馒头 0.91 0.95 0.93 63
纳仁 0.97 0.99 0.98 84
诺鲁孜粥 0.88 0.96 0.92 52
奥拉玛 0.68 0.84 0.75 38
抓饭 0.95 0.98 0.97 101
烤包子 0.91 0.93 0.92 106
鸡肉串 0.68 0.65 0.66 62
蔬菜鸡肉串 0.74 0.76 0.75 33
大块肉串 0.75 0.75 0.75 71
蔬菜大块肉串 0.53 0.79 0.64 29
碎肉串 0.74 0.69 0.72 42
羊头肉 0.75 0.94 0.83 16
谢丽派克 0.77 0.86 0.81 64
肖尔帕 0.95 0.88 0.91 80
清汤 0.96 0.94 0.95 71
苏什基 0.83 1.00 0.91 43
苏兹贝 0.89 0.82 0.86 62
塔巴馕 0.92 0.80 0.86 136
塔尔坎面粉 0.86 0.80 0.83 90
炒图什帕拉 0.79 0.74 0.76 46
汤图什帕拉 0.94 0.94 0.94 67
干图什帕拉 0.92 0.87 0.89 91
整体准确率 0.87 2698
宏平均指标 0.84 0.86 0.85 2698
加权平均指标 0.88 0.87 0.87 2698

测试结果
模型在测试集上的表现:
仓库结构
main.py
— 模型训练与测试代码
model/
— 以SafeTensors格式保存的模型
使用说明
from transformers import AutoModelForImageClassification
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
repo_id = "Eraly-ml/centraasia-ResNet-50"
filename = "model.safetensors"
model_path = hf_hub_download(repo_id=repo_id, filename=filename)
model = AutoModelForImageClassification.from_pretrained(repo_id)
model.load_state_dict(load_file(model_path))
我的Telegram @eralyf