🚀 犬种多分类图像识别模型
本项目基于视觉变换器(Vision Transformer)模型,对犬类图像进行分类,可识别 120 种不同犬种。该模型使用了预训练的 Google Vision Transformer 模型,并在斯坦福犬类数据集上进行微调,具有较高的准确性和良好的泛化能力。
🚀 快速开始
模型背景
最近,有人问我是否可以将犬类图像分类为不同的犬种,而不是像我之前的 笔记本 那样仅仅区分猫和狗。答案是肯定的!
由于问题的复杂性,我们将使用 2020 年 Google 论文 中发布的最先进的计算机视觉架构——视觉变换器(Vision Transformer)。
模型原理
视觉变换器(Vision Transformer) 与传统的 卷积神经网络(CNN) 的区别在于对图像的处理方式。在 视觉变换器 中,我们将输入视为原始图像的一个补丁(例如 16 x 16),并将其作为带有位置嵌入和自注意力的序列输入到变换器中;而在 卷积神经网络(CNN) 中,我们使用相同的原始图像补丁作为输入,但使用卷积和池化层作为归纳偏置。这意味着 视觉变换器 可以使用其自注意力机制以“全局”方式关注图像的任何特定补丁,而无需像 CNN 那样通过“局部”居中/裁剪/边界框来引导神经网络进行卷积操作。
这使得 视觉变换器 架构在本质上更加灵活和可扩展,使我们能够在计算机视觉中创建 基础模型,类似于自然语言处理中的基础模型,如 BERT 和 GPT,通过在大量图像数据上进行预训练(自监督/监督),可以推广到不同的计算机视觉任务,如图像分类、识别、分割等。这种交叉融合有助于我们更接近通用人工智能的目标。
需要注意的是,与 卷积神经网络 相比,视觉变换器 的归纳偏置较弱,这使得它具有可扩展性和灵活性。但这一特点(或缺点,取决于你的看法)要求大多数表现良好的预训练模型需要更多的数据,尽管与 CNN 相比,它的参数更少。
幸运的是,在这个模型中,我们将使用 Google 托管在 HuggingFace 上的 视觉变换器,该模型在 ImageNet-21k 数据集(1400 万张图像,21000 个类别)上进行了预训练,补丁大小为 16x16,分辨率为 224x224,以绕过数据限制。我们将在来自 斯坦福犬类数据集 的约 20000 张图像的“小”犬种数据集上对该模型进行微调,以将犬类图像分类为 120 种不同的犬种!
✨ 主要特性
- 基于先进架构:采用视觉变换器(Vision Transformer)架构,具有更好的灵活性和可扩展性。
- 预训练模型微调:使用在 ImageNet-21k 数据集上预训练的 Google Vision Transformer 模型,在斯坦福犬类数据集上进行微调,提高模型性能。
- 多指标评估:使用 Top-1 准确率、Top-3 准确率、Top-5 准确率和 Macro F1 等多个指标对模型进行评估,确保模型的准确性和泛化能力。
📦 安装指南
本模型使用 Python 编写,依赖于 transformers
、PIL
和 requests
等库。可以使用以下命令安装所需的库:
pip install transformers pillow requests
💻 使用示例
基础用法
from transformers import AutoImageProcessor, AutoModelForImageClassification
import PIL
import requests
url = "https://upload.wikimedia.org/wikipedia/commons/5/55/Beagle_600.jpg"
image = PIL.Image.open(requests.get(url, stream=True).raw)
image_processor = AutoImageProcessor.from_pretrained("wesleyacheng/dog-breeds-multiclass-image-classification-with-vit")
model = AutoModelForImageClassification.from_pretrained("wesleyacheng/dog-breeds-multiclass-image-classification-with-vit")
inputs = image_processor(images=image, return_tensors="pt")
outputs = model(**inputs)
logits = outputs.logits
predicted_class_idx = logits.argmax(-1).item()
print("Predicted class:", model.config.id2label[predicted_class_idx])
📚 详细文档
模型描述
本模型使用 Google 视觉变换器(vit-base-patch16-224-in21k) 在 Kaggle 上的斯坦福犬类数据集 上进行微调,以将犬类图像分类为 120 种不同的犬种。
预期用途和限制
你可以使用这个微调后的模型仅对数据集中包含的犬类图像和犬种进行分类。
模型训练指标
轮数 |
Top-1 准确率 |
Top-3 准确率 |
Top-5 准确率 |
Macro F1 |
1 |
79.8% |
95.1% |
97.5% |
77.2% |
2 |
83.8% |
96.7% |
98.2% |
81.9% |
3 |
84.8% |
96.7% |
98.3% |
83.4% |
模型评估指标
Top-1 准确率 |
Top-3 准确率 |
Top-5 准确率 |
Macro F1 |
84.0% |
97.1% |
98.7% |
83.0% |
📄 许可证
本项目采用 MIT 许可证。