License: MIT

SegVol is a universal and interactive model for volumetric medical image segmentation. It supports volumetric segmentation driven by point, box, and text prompts. Trained on 90k unlabeled CT volumes and 6k labeled CT volumes, this foundation model can segment over 200 anatomical categories.
The paper and code are open source.
Keywords: 3D medical SAM, volumetric image segmentation
Quickstart
Environment setup
```bash
conda create -n segvol_transformers python=3.8
conda activate segvol_transformers
```
PyTorch v1.11.0 or later is required. Install the key dependencies with the following commands:
```bash
pip install 'monai[all]==0.9.0'
pip install einops==0.6.1
pip install transformers==4.18.0
pip install matplotlib
```
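A quick sanity check such as the following (a minimal sketch, not part of the original setup) can confirm that the pinned versions installed correctly and that a GPU is visible:

```python
import torch, monai, transformers, einops

# Verify the installed versions match the pinned requirements above.
print('torch:', torch.__version__)                 # expect >= 1.11.0
print('monai:', monai.__version__)                 # expect 0.9.0
print('transformers:', transformers.__version__)   # expect 4.18.0
print('einops:', einops.__version__)               # expect 0.6.1

# Volumetric inference is much faster on GPU, although CPU also works.
print('CUDA available:', torch.cuda.is_available())
```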
Test script
```python
from transformers import AutoModel, AutoTokenizer
import torch

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Load the tokenizer and the SegVol model in test mode.
clip_tokenizer = AutoTokenizer.from_pretrained("yuxindu/segvol")
model = AutoModel.from_pretrained("yuxindu/segvol", trust_remote_code=True, test_mode=True)
model.model.text_encoder.tokenizer = clip_tokenizer
model.eval()
model.to(device)
print('model load done')

# Paths to a CT volume and its ground-truth label (NIfTI format).
ct_path = 'path/to/Case_image_00001_0000.nii.gz'
gt_path = 'path/to/Case_label_00001.nii.gz'
categories = ["liver", "kidney", "spleen", "pancreas"]

# Preprocess the CT and ground truth, then build the zoom-in/zoom-out views.
ct_npy, gt_npy = model.processor.preprocess_ct_gt(ct_path, gt_path, category=categories)
data_item = model.processor.zoom_transform(ct_npy, gt_npy)

# Add a batch dimension and move everything to the target device.
data_item['image'], data_item['label'], data_item['zoom_out_image'], data_item['zoom_out_label'] = \
    data_item['image'].unsqueeze(0).to(device), \
    data_item['label'].unsqueeze(0).to(device), \
    data_item['zoom_out_image'].unsqueeze(0).to(device), \
    data_item['zoom_out_label'].unsqueeze(0).to(device)

# Build text, point, and box prompts for one category.
cls_idx = 0
text_prompt = [categories[cls_idx]]
point_prompt, point_prompt_map = model.processor.point_prompt_b(data_item['zoom_out_label'][0][cls_idx], device=device)
bbox_prompt, bbox_prompt_map = model.processor.bbox_prompt_b(data_item['zoom_out_label'][0][cls_idx], device=device)
print('prompt done')

# Run inference with the box and text prompts (the point prompt built above
# is unused here; see the note after this script).
logits_mask = model.forward_test(image=data_item['image'],
                                 zoomed_image=data_item['zoom_out_image'],
                                 bbox_prompt_group=[bbox_prompt, bbox_prompt_map],
                                 text_prompt=text_prompt,
                                 use_zoom=False)

# Evaluate the prediction against the ground truth.
dice = model.processor.dice_score(logits_mask[0][0], data_item['label'][0][cls_idx])
print(dice)

# Save the predicted mask back into the original CT's coordinate frame.
save_path = './Case_preds_00001.nii.gz'
model.processor.save_preds(ct_path, save_path, logits_mask[0][0],
                           start_coord=data_item['foreground_start_coord'],
                           end_coord=data_item['foreground_end_coord'])
print('done')
```
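The script above prompts the model with a box plus text, while the point prompt returned by `point_prompt_b` goes unused. A hedged sketch of the point-prompt variant, assuming `forward_test` accepts a `point_prompt_group` argument symmetric to `bbox_prompt_group` (verify against the model's remote code):

```python
# Sketch: point + text prompting. Assumes forward_test takes point_prompt_group,
# mirroring the bbox_prompt_group call above.
logits_mask = model.forward_test(image=data_item['image'],
                                 zoomed_image=data_item['zoom_out_image'],
                                 point_prompt_group=[point_prompt, point_prompt_map],
                                 text_prompt=text_prompt,
                                 use_zoom=False)
```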
Training script
```python
from transformers import AutoModel, AutoTokenizer
import torch

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Load the tokenizer and the SegVol model in training mode (test_mode=False).
clip_tokenizer = AutoTokenizer.from_pretrained("yuxindu/segvol")
model = AutoModel.from_pretrained("yuxindu/segvol", trust_remote_code=True, test_mode=False)
model.model.text_encoder.tokenizer = clip_tokenizer
model.train()
model.to(device)
print('model load done')

# Paths to a CT volume and its ground-truth label (NIfTI format).
ct_path = 'path/to/Case_image_00001_0000.nii.gz'
gt_path = 'path/to/Case_label_00001.nii.gz'
categories = ["liver", "kidney", "spleen", "pancreas"]

# Preprocess and apply the training-time transforms.
ct_npy, gt_npy = model.processor.preprocess_ct_gt(ct_path, gt_path, category=categories)
data_item = model.processor.train_transform(ct_npy, gt_npy)

image, gt3D = data_item["image"].unsqueeze(0).to(device), data_item["label"].unsqueeze(0).to(device)

# One forward/backward pass per category, averaging the loss over the step.
loss_step_avg = 0
for cls_idx in range(len(categories)):
    organs_cls = categories[cls_idx]
    labels_cls = gt3D[:, cls_idx]
    print(image.shape, organs_cls, labels_cls.shape)
    loss = model.forward_train(image, train_organs=organs_cls, train_labels=labels_cls)
    loss_step_avg += loss.item()
    loss.backward()

loss_step_avg /= len(categories)
print(f'AVG loss {loss_step_avg}')
model.save_pretrained('./ckpt')
```
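Note that the loop above only calls `loss.backward()` and never steps an optimizer, so the weights saved to `./ckpt` are unchanged. A minimal sketch of wiring one in (the AdamW choice and learning rate are illustrative assumptions, not values from the original):

```python
# Hypothetical optimizer setup; hyperparameters are illustrative only.
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)

for cls_idx in range(len(categories)):
    loss = model.forward_train(image,
                               train_organs=categories[cls_idx],
                               train_labels=gt3D[:, cls_idx])
    loss.backward()
    optimizer.step()       # apply gradients from this category
    optimizer.zero_grad()  # reset before the next one
```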