许可协议:apache-2.0
标签:
- 目标检测
- 视觉
数据集:
- sku110k
示例:
- 图片链接:https://github.com/Isalia20/DETR-finetune/blob/main/IMG_3507.jpg?raw=true
示例标题:商店示例(非SKU110K数据集)
基于ResNet-101-DC5骨干网络并在SKU110K数据集上训练的DETR(端到端目标检测)模型(查询数为400)
该DEtection TRansformer(DETR)模型在SKU110K目标检测数据集(包含8千张标注图像)上进行了端到端训练。与原始模型的主要区别在于其查询数设置为400,并且是在SKU110K数据集上预训练的。
使用方法
以下是使用该模型的方法:
from transformers import DetrImageProcessor, DetrForObjectDetection
import torch
from PIL import Image, ImageOps
import requests
url = "https://github.com/Isalia20/DETR-finetune/blob/main/IMG_3507.jpg?raw=true"
image = Image.open(requests.get(url, stream=True).raw)
image = ImageOps.exif_transpose(image)
processor = DetrImageProcessor.from_pretrained("facebook/detr-resnet-101-dc5")
model = DetrForObjectDetection.from_pretrained("isalia99/detr-resnet-101-dc5-sku110k")
model = model.eval()
inputs = processor(images=image, return_tensors="pt")
outputs = model(**inputs)
target_sizes = torch.tensor([image.size[::-1]])
results = processor.post_process_object_detection(outputs, target_sizes=target_sizes, threshold=0.8)[0]
for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
box = [round(i, 2) for i in box.tolist()]
print(
f"检测到 {model.config.id2label[label.item()]},置信度为 "
f"{round(score.item(), 3)},位置为 {box}"
)
预期输出:
检测到 LABEL_1,置信度为 0.983,位置为 [665.49, 480.05, 708.15, 650.11]
检测到 LABEL_1,置信度为 0.938,位置为 [204.99, 1405.9, 239.9, 1546.5]
...
检测到 LABEL_1,置信度为 0.998,位置为 [772.85, 169.49, 829.67, 372.18]
检测到 LABEL_1,置信度为 0.999,位置为 [828.28, 1475.16, 874.37, 1593.43]
目前,特征提取器和模型均支持PyTorch。
训练数据
DETR模型在SKU110K数据集上训练,该数据集包含8,219/588/2,936张分别用于训练/验证/测试的标注图像。
训练过程
训练
模型在1块RTX 4060 Ti GPU上进行了60轮训练(仅微调解码器),批次大小为1,梯度累积步数为8;随后又进行了60轮训练(微调整个网络),批次大小为1,梯度累积步数为8。
评估结果
该模型在SKU110K验证集上的mAP达到59.8。结果通过torchmetrics的MeanAveragePrecision类计算得出。
训练代码
代码已发布于此仓库。虽然尚未完全定稿或充分测试,但主要功能已实现。