库名称:transformers
标签:[]
基础模型:microsoft/conditional-detr-resnet-50
模型概述
该模型是基于DETR的目标检测模型,专为医学图像分析训练,包含4个类别:0:肺炎、1:正常、2:细菌性肺炎、3:病毒性肺炎。
模型描述
模型架构:DEtection TRansformers (DETR)
训练数据:基于标注医学图像的自定义数据集训练
用途:设计用于分析胸部X光图像,检测肺炎的存在及类型,或分类为正常。
使用示例
测试模型的示例代码:
import os
from transformers import AutoImageProcessor, AutoModelForObjectDetection
import torch
from PIL import Image
import pandas as pd
folder_path = ""
processor = AutoImageProcessor.from_pretrained("0llheaven/CON-DETR-V5")
model = AutoModelForObjectDetection.from_pretrained("0llheaven/CON-DETR-V5")
results_list = []
for image_name in os.listdir(folder_path):
if image_name.endswith((".jpg", ".png", ".jpeg")):
image_path = os.path.join(folder_path, image_name)
image = Image.open(image_path)
if image.mode != "RGB":
image = image.convert("RGB")
inputs = processor(images=image, return_tensors="pt")
outputs = model(**inputs)
target_sizes = torch.tensor([image.size[::-1]])
results = processor.post_process_object_detection(outputs, target_sizes=target_sizes)
print(f"正在处理图像: {image_name}")
detected_any = False
for result in results:
scores = result["scores"]
labels = result["labels"]
boxes = result["boxes"]
filtered_data = [(score, label, box) for score, label, box in zip(scores, labels, boxes) if score > 0.5][:2]
for score, label, box in zip(scores, labels, boxes):
if score > 0.5:
if len(filtered_data) > 0:
detected_any = True
for score, label, box in filtered_data:
if label.item() == 0:
label_name = "肺炎"
elif label.item() == 1:
label_name = "正常"
elif label.item() == 2:
label_name = "细菌性肺炎"
else:
label_name = "病毒性肺炎"
xmin, ymin, xmax, ymax = [round(i, 2) for i in box.tolist()]
print(f" - 检测到 {label_name},置信度 {round(score.item(), 3)},位置 {xmin, ymin, xmax, ymax}")
results_list.append({
"图像名称": image_name,
"标签": label_name,
"xmin": xmin,
"ymin": ymin,
"xmax": xmax,
"ymax": ymax,
"置信度": round(score.item(), 3),
})
if not detected_any:
print(" - 未检测到目标")
results_list.append({
"图像名称": image_name,
"标签": "其他",
"xmin": 0,
"ymin": 0,
"xmax": 0,
"ymax": 0,
"置信度": 0,
})
results_df = pd.DataFrame(results_list)
print("\n最终结果:")
results_df.to_csv("testmodel.csv", index=False)