library_name: transformers
tags:
- 乳腺X光摄影
- 癌症
- 乳腺癌
- 放射学
- 乳腺密度
license: apache-2.0
base_model:
- timm/tf_efficientnetv2_s.in21k_ft_in1k
pipeline_tag: image-classification
这是一个基于筛查性乳腺X光摄影预测乳腺癌和乳腺密度的集成模型。
该模型使用3个基础CNN网络(tf_efficientnetv2_s
主干),对每张提供的图像(即CC和MLO视图)进行推理。
集成中每个网络采用不同分辨率:2048×1024、1920×1280和1536×1536。
最终输出会在提供的视图和神经网络之间进行平均。
模型也支持单视图(图像)推理,但性能会有所下降。
首先在"数字乳腺X光筛查数据库精选子集"(CBIS-DDSM)上预训练了一个分类-分割混合模型。该数据集包含胶片乳腺X光检查(非数字化)及良恶性肿块和钙化的ROI标注。
随后使用RSNA乳腺X光筛查乳腺癌检测挑战赛数据进一步训练。数据按80%/10%/10%划分为训练集/验证集/测试集,评估在10%的测试集上进行。该过程重复3次以更好评估模型性能,提供的权重来自第一次数据划分。
训练中采用指数移动平均技术提升性能。
注意:模型训练使用裁剪后的图像,建议推理前先裁剪。裁剪模型见:https://huggingface.co/ianpan/mammo-crop
主要评估指标是受试者工作特征曲线下面积(AUC/AUROC)。以下是3次划分的平均值和标准差:
划分1: 0.9464
划分2: 0.9467
划分3: 0.9422
均值(标准差): 0.9451 (0.002)
作为筛查测试,高灵敏度至关重要。我们在不同灵敏度下的特异性表现如下(3次划分平均值):
灵敏度98.1%时:特异性65.4% ±7.2%,阈值0.0072±0.0021
灵敏度94.3%时:特异性78.7% ±0.9%,阈值0.0127±0.0011
灵敏度90.5%时:特异性84.8% ±2.7%,阈值0.0184±0.0027
使用示例:
import cv2
import torch
from transformers import AutoModel
def 裁剪乳腺图(img, model, device):
img_shape = torch.tensor([img.shape[:2]]).to(device)
x = model.preprocess(img)
x = torch.from_numpy(x).expand(1, 1, -1, -1).float().to(device)
with torch.inference_mode():
coords = model(x, img_shape)
coords = coords[0].cpu().numpy()
x, y, w, h = coords
return img[y: y + h, x: x + w]
device = "cuda:0"
crop_model = AutoModel.from_pretrained("ianpan/mammo-crop", trust_remote_code=True)
crop_model = crop_model.eval().to(device)
model = AutoModel.from_pretrained("ianpan/mammoscreen", trust_remote_code=True)
model = model.eval().to(device)
cc_img = cv2.imread("mammo_cc.png", cv2.IMREAD_GRAYSCALE)
mlo_img = cv2.imread("mammo_mlo.png", cv2.IMREAD_GRAYSCALE)
cc_img = 裁剪乳腺图(cc_img, crop_model, device)
mlo_img = 裁剪乳腺图(mlo_img, crop_model, device)
with torch.inference_mode():
output = model({"cc": cc_img, "mlo": mlo_img}, device=device)
注意:模型在forward
函数内部完成数据预处理。output
是包含cancer
和density
两个键的字典。output['cancer']
是形状(N,1)的张量,output['density']
是形状(N,4)的张量。获取密度分类可使用output['density'].argmax(1)
。单次检查时N=1。
也可单独访问每个神经网络model.net{i}
,但需在forward
外手动预处理:
input_dict = model.net0.preprocess({"cc": cc_img, "mlo": mlo_img}, device=device)
with torch.inference_mode():
out = model.net0(input_dict)
支持批量推理。为每个乳房构建字典并传入字典列表。例如对2名患者(pt1
,pt2
)的双乳检查:
cc_images = ["rt_pt1_cc.png", "lt_pt1_cc.png", "rt_pt2_cc.png", "lt_pt2_cc.png"]
mlo_images = ["rt_pt1_mlo.png", "lt_pt1_mlo.png", "rt_pt2_mlo.png", "lt_pt2_mlo.png"]
cc_images = [cv2.imread(_, cv2.IMREAD_GRAYSCALE) for _ in cc_images]
mlo_images = [cv2.imread(_, cv2.IMREAD_GRAYSCALE) for _ in mlo_images]
cc_images = [裁剪乳腺图(_, crop_model, device) for _ in cc_images]
mlo_images = [裁剪乳腺图(_, crop_model, device) for _ in mlo_images]
input_dict = [{"cc": cc_img, "mlo": mlo_img} for cc_img, mlo_img in zip(cc_images, mlo_images)]
with torch.inference_mode():
output = model(input_dict, device=device)
注意:从DICOM转为8位PNG/JPEG时,需使用pydicom.pixels.apply_voi_lut
处理像素值。安装pydicom
后可直接加载DICOM图像:
img = model.load_image_from_dicom(dicom文件路径)