---
license: apache-2.0
language:
pipeline_tag: zero-shot-object-detection
library_name: transformers
base_model:
- omlab/omdet-turbo-swin-tiny-hf
tags:
---
This repository implements a custom task for zero-shot object detection for 🤗 Inference Endpoints. The code for the custom handler lives in handler.py.

To deploy this model as an Inference Endpoint, select Custom as the task so that the handler.py file is used.

The repository includes a requirements.txt that installs the timm library.
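The handler code itself is in handler.py in this repository; as a rough sketch only, a custom handler for this task could look like the following. The class and method names follow the Inference Endpoints custom-handler convention, and the OmDet-Turbo post-processing argument names may differ between transformers versions, so treat the details as assumptions rather than the repository's actual implementation:

```python
import base64
import io
from typing import Any, Dict

from PIL import Image


def decode_image(b64_string: str) -> Image.Image:
    """Decode the base64-encoded image from the request payload into a PIL image."""
    return Image.open(io.BytesIO(base64.b64decode(b64_string))).convert("RGB")


class EndpointHandler:
    """Minimal custom handler: load the model once, then serve requests via __call__."""

    def __init__(self, path: str = ""):
        # Deferred import so the decoding helper above works without the model weights.
        from transformers import AutoProcessor, OmDetTurboForObjectDetection

        self.processor = AutoProcessor.from_pretrained(path)
        self.model = OmDetTurboForObjectDetection.from_pretrained(path)

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        inputs = data["inputs"]
        image = decode_image(inputs["image"])
        candidates = inputs["candidates"]

        model_inputs = self.processor(image, text=candidates, return_tensors="pt")
        outputs = self.model(**model_inputs)

        # Argument and result-key names (text_labels vs. classes) vary across
        # transformers releases; check the installed version's documentation.
        results = self.processor.post_process_grounded_object_detection(
            outputs,
            text_labels=candidates,
            target_sizes=[image.size[::-1]],
            threshold=0.3,
        )[0]
        return {
            "boxes": results["boxes"].tolist(),
            "scores": results["scores"].tolist(),
            "candidates": results["text_labels"],
        }
```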
## Expected request payload

```json
{
  "inputs": {
    "image": "/9j/4AAQSkZJRgABAQEBLAEsAAD/2wBDAAMCAgICAgMC....",
    "candidates": ["broken curb", "broken road", "broken road sign", "broken sidewalk"]
  }
}
```
Below is an example of how to run a request with Python and requests.
## Run the request

```python
import base64
import json
from typing import List

import requests as r

ENDPOINT_URL = ""
HF_TOKEN = ""


def predict(path_to_image: str = None, candidates: List[str] = None):
    # Read the image and base64-encode it for the JSON payload
    with open(path_to_image, "rb") as i:
        b64 = base64.b64encode(i.read())
    payload = {"inputs": {"image": b64.decode("utf-8"), "candidates": candidates}}
    response = r.post(
        ENDPOINT_URL, headers={"Authorization": f"Bearer {HF_TOKEN}"}, json=payload
    )
    return response.json()


prediction = predict(
    path_to_image="image/brokencurb.jpg",
    candidates=["broken curb", "broken road", "broken road sign", "broken sidewalk"],
)
print(json.dumps(prediction, indent=2))
```
## Expected output

```json
{
  "boxes": [
    [
      1.919342041015625,
      231.1556396484375,
      1011.4019775390625,
      680.3773193359375
    ],
    [
      610.9949951171875,
      397.6180419921875,
      1019.9259033203125,
      510.8144226074219
    ],
    [
      1.919342041015625,
      231.1556396484375,
      1011.4019775390625,
      680.3773193359375
    ],
    [
      786.1240234375,
      68.618896484375,
      916.1265869140625,
      225.0513458251953
    ]
  ],
  "scores": [
    0.4329715967178345,
    0.4215811491012573,
    0.3389397859573364,
    0.3133399784564972
  ],
  "candidates": [
    "broken sidewalk",
    "broken road sign",
    "broken road",
    "broken road sign"
  ]
}
```
The bounding boxes are in (x_min, y_min, x_max, y_max) format.
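Since matplotlib's Rectangle takes a corner plus width and height, converting from this corner format is a one-liner; the helpers below (hypothetical, not part of this repository) sketch that conversion plus a simple score filter over the parallel boxes/scores/candidates lists:

```python
from typing import Dict, List, Tuple


def corners_to_xywh(box: List[float]) -> Tuple[float, float, float, float]:
    """Convert (x_min, y_min, x_max, y_max) to (x, y, width, height)."""
    x_min, y_min, x_max, y_max = box
    return (x_min, y_min, x_max - x_min, y_max - y_min)


def filter_detections(prediction: Dict, threshold: float = 0.4) -> List[Dict]:
    """Keep only detections whose confidence reaches the threshold."""
    return [
        {"box": box, "score": score, "label": label}
        for box, score, label in zip(
            prediction["boxes"], prediction["scores"], prediction["candidates"]
        )
        if score >= threshold
    ]
```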
## Visualize the result

Input image:

The result of the request can be visualized with the following code:
```python
import matplotlib.patches as patches
import matplotlib.pyplot as plt

prediction = predict(
    path_to_image="image/cat_and_remote.jpg", candidates=["cat", "remote control", "pot hole"]
)

with open("image/cat_and_remote.jpg", "rb") as i:
    image = plt.imread(i)

fig, ax = plt.subplots(1)
ax.imshow(image)

# Draw one rectangle and score label per detection
for score, class_name, box in zip(
    prediction["scores"], prediction["candidates"], prediction["boxes"]
):
    rect = patches.Rectangle(
        (int(box[0]), int(box[1])),
        int(box[2] - box[0]),
        int(box[3] - box[1]),
        linewidth=1,
        edgecolor="r",
        facecolor="none",
    )
    ax.add_patch(rect)
    ax.text(
        int(box[0]),
        int(box[1]),
        f"{round(score, 2)} {class_name}",
        color="white",
        fontsize=6,
        bbox=dict(facecolor="red", alpha=0.5),
    )

plt.savefig("image_result/cat_and_remote_with_bboxes_zero_shot.jpeg")
```
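As an alternative to drawing boxes with matplotlib, Pillow can cut the detected regions out of the image directly. The `crop_detections` helper below is a sketch, not part of this repository; it saves one crop per (x_min, y_min, x_max, y_max) box:

```python
import os
from typing import List

from PIL import Image


def crop_detections(image_path: str, boxes: List[List[float]], out_dir: str = ".") -> List[str]:
    """Save one cropped JPEG per (x_min, y_min, x_max, y_max) box and return the paths."""
    image = Image.open(image_path).convert("RGB")
    paths = []
    for i, (x_min, y_min, x_max, y_max) in enumerate(boxes):
        crop = image.crop((int(x_min), int(y_min), int(x_max), int(y_max)))
        path = os.path.join(out_dir, f"crop_{i}.jpg")
        crop.save(path)
        paths.append(path)
    return paths
```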
## Result

Output image:
## Acknowledgements

This adaptation was inspired by the work of @philschmid at philschmid/clip-zero-shot-image-classification.