库名称:transformers
许可证:mit
RobustSAM:在退化图像上稳健分割一切(CVPR 2024亮点)
ViT Base(ViT-B)版本模型卡


RobustSAM官方仓库:在退化图像上稳健分割一切
项目页面 | 论文 | 数据集
简介
分割一切模型(SAM)已成为图像分割领域的革命性方法,以其强大的零样本分割能力和灵活的提示系统备受赞誉。然而,其在图像质量退化时的表现面临挑战。为解决这一局限,我们提出了稳健分割一切模型(RobustSAM),在保持提示性和零样本泛化能力的同时,提升了SAM在低质量图像上的性能。
我们的方法利用预训练的SAM模型,仅需少量参数增加和计算需求。RobustSAM的额外参数可在8块GPU上30小时内完成优化,展示了其在典型研究实验室中的可行性和实用性。我们还引入了Robust-Seg数据集,包含68.8万张带不同退化类型的图像-掩码对,专为优化模型训练和评估而设计。在多种分割任务和数据集上的广泛实验证实了RobustSAM的卓越性能,尤其是在零样本条件下,凸显了其广泛实际应用的潜力。此外,我们的方法还能有效提升基于SAM的下游任务(如单图像去雾和去模糊)的性能。
免责声明:本模型卡内容由Hugging Face团队撰写,部分内容复制自原始SAM模型卡。
模型详情
RobustSAM模型由3个模块组成:
VisionEncoder
:基于VIT的图像编码器。通过图像块注意力计算图像嵌入,使用相对位置编码。
PromptEncoder
:为点和边界框生成嵌入。
MaskDecoder
:双向Transformer,在图像嵌入和点嵌入之间执行交叉注意力(->)及反向操作。输出送入
Neck
:基于MaskDecoder
生成的上下文掩码预测输出掩码。
使用方式
提示掩码生成
from PIL import Image
import requests
from transformers import AutoProcessor, AutoModelForMaskGeneration
processor = AutoProcessor.from_pretrained("jadechoghari/robustsam-vit-base")
model = AutoModelForMaskGeneration.from_pretrained("jadechoghari/robustsam-vit-base")
img_url = "https://huggingface.co/ybelkada/segment-anything/resolve/main/assets/car.png"
raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB")
input_points = [[[450, 600]]]
inputs = processor(raw_image, input_points=input_points, return_tensors="pt").to("cuda")
with torch.no_grad():
outputs = model(**inputs)
masks = processor.image_processor.post_process_masks(outputs.pred_masks.cpu(), inputs["original_sizes"].cpu(), inputs["reshaped_input_sizes"].cpu())
scores = outputs.iou_scores
生成掩码时,除其他参数外,可传递物体大致位置的2D坐标、包围物体的边界框(格式应为边界框左上和右下点的x、y坐标)或分割掩码。截至撰写时,官方模型不支持文本输入(参见官方仓库)。
更多细节可参考此笔记本,其中展示了使用模型的完整流程及可视化示例!
自动掩码生成
该模型可用于“零样本”方式生成分割掩码,给定输入图像。模型会自动提示1024个网格点作为输入。
以下代码片段演示了如何轻松运行自动掩码生成管道(适用于任何设备!只需传递适当的points_per_batch
参数):
from transformers import pipeline
generator = pipeline("mask-generation", model="jadechoghari/robustsam-vit-base", device=0, points_per_batch=256)
image_url = "https://huggingface.co/ybelkada/segment-anything/resolve/main/assets/car.png"
outputs = generator(image_url, points_per_batch=256)
现在展示生成的掩码:
import matplotlib.pyplot as plt
from PIL import Image
import numpy as np
def show_mask(mask, ax, random_color=False):
if random_color:
color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
else:
color = np.array([30 / 255, 144 / 255, 255 / 255, 0.6])
h, w = mask.shape[-2:]
mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
ax.imshow(mask_image)
plt.imshow(np.array(raw_image))
ax = plt.gca()
for mask in outputs["masks"]:
show_mask(mask, ax=ax, random_color=True)
plt.axis("off")
plt.show()
视觉对比
引用
如果本工作对您有帮助,请考虑引用我们!
@inproceedings{chen2024robustsam,
title={RobustSAM: Segment Anything Robustly on Degraded Images},
author={Chen, Wei-Ting and Vong, Yu-Jiet and Kuo, Sy-Yen and Ma, Sizhou and Wang, Jian},
journal={CVPR},
year={2024}
}
致谢
我们感谢SAM的作者,我们的仓库基于其构建。