许可证:apache-2.0
基础模型:google/efficientnet-b2
评估指标:
- 准确率
任务标签:图像分类
标签:
- 生物学
- efficientnet-b2
- 图像分类
- 视觉
鸟类分类器 EfficientNet-B2
模型描述
你是否曾看到一只鸟,心想:“唉,要是我知道那是什么鸟就好了。”
除非你是狂热的观鸟爱好者(或者单纯喜欢鸟类),否则很难区分某些鸟类的品种。
不过你很幸运,现在可以使用图像分类器来识别鸟类品种了!
该模型是基于 google/efficientnet-b2 在 gpiosenka/100-bird-species 数据集(来自 Kaggle)上微调的版本。训练模型所用的数据集采集于 2023 年 9 月 24 日。
原始模型本身是在 ImageNet-1K 上训练的,因此可能仍保留了一些识别鸟类等生物的有用特征。
理论上,在该数据集上随机猜测的准确率为 0.0019047619(即 1/525)。模型在三个数据集上的表现均非常出色,结果如下:
- 训练集:0.999480
- 验证集:0.985904
- 测试集:0.991238
用途
你可以直接使用该模型进行图像分类。以下是使用鸟类图片运行模型的示例:
import torch
import urllib.request
from PIL import Image
from transformers import EfficientNetImageProcessor, EfficientNetForImageClassification
url = '某 URL'
img = Image.open(urllib.request.urlretrieve(url)[0])
preprocessor = EfficientNetImageProcessor.from_pretrained("dennisjooo/Birds-Classifier-EfficientNetB2")
model = EfficientNetForImageClassification.from_pretrained("dennisjooo/Birds-Classifier-EfficientNetB2")
inputs = preprocessor(img, return_tensors="pt")
with torch.no_grad():
logits = model(**inputs).logits
predicted_label = logits.argmax(-1).item()
print(model.config.id2label[predicted_label])
或者,你也可以使用 Huggingface 的 Pipeline 简化流程:
import torch
import urllib.request
from PIL import Image
from transformers import pipeline
url = '某 URL'
img = Image.open(urllib.request.urlretrieve(url)[0])
pipe = pipeline("image-classification", model="dennisjooo/Birds-Classifier-EfficientNetB2")
result = pipe(img)[0]
print(result['label'])
训练与评估
数据
数据集来自 Kaggle 的 gpiosenka/100-bird-species。
它包含 525 种鸟类,训练集有 84,635 张图像,验证集和测试集各有 2,625 张图像。
数据集中的每张图像均为 224×224 的 RGB 图像。
训练过程使用了作者提供的相同划分方式。
更多细节请参考 作者的 Kaggle 页面。
训练过程
训练使用 PyTorch 在 Kaggle 的免费 P100 GPU 上完成,过程中还使用了 Lightning 和 Torchmetrics 库。
预处理
每张图像均按照原始作者的 配置 进行预处理。
训练集还通过以下方式进行了数据增强:
- 随机旋转 10 度,概率为 50%
- 随机水平翻转,概率为 50%
训练超参数
训练使用的超参数如下:
- 训练模式:fp32
- 损失函数:交叉熵
- 优化器:Adam(默认 beta 值为 0.99 和 0.999)
- 学习率:1e-3
- 学习率调度器:监控验证损失的 Reduce on Plateau,耐心值为 2,衰减率为 0.1
- 批量大小:64
- 早停机制:监控验证准确率,耐心值为 10
结果
下图展示了训练集和验证集上的训练结果:
