library_name: transformers
license: mit
RobustSAM:在退化图像上稳健分割任意对象(CVPR 2024亮点)
ViT大模型(ViT-L)版本模型卡

RobustSAM官方仓库:在退化图像上稳健分割任意对象
项目主页 | 论文 | 数据集
简介
分割任意模型(SAM)已成为图像分割领域的革命性方法,以其强大的零样本分割能力和灵活的提示系统备受赞誉。然而,其在质量退化图像上的表现面临挑战。针对这一局限,我们提出了稳健分割任意模型(RobustSAM),该模型在保持提示性和零样本泛化能力的同时,提升了SAM在低质量图像上的性能。
我们的方法利用预训练的SAM模型,仅需少量参数增加和计算需求。RobustSAM的额外参数可在8块GPU上30小时内完成优化,展现了其在典型研究实验室中的可行性和实用性。我们还引入了Robust-Seg数据集,包含68.8万组不同退化程度的图像-掩码对,专为优化模型训练和评估而设计。跨多种分割任务和数据集的广泛实验证实了RobustSAM的卓越性能,尤其是在零样本条件下,凸显了其广泛实际应用的潜力。此外,我们的方法还能有效提升基于SAM的下游任务(如单图像去雾和去模糊)的性能。
免责声明:本模型卡内容由Hugging Face团队编写,部分内容复制自原始SAM模型卡。
模型详情
RobustSAM模型由3个模块组成:
VisionEncoder
:基于VIT的图像编码器。通过图像块上的注意力计算图像嵌入,使用相对位置嵌入。
PromptEncoder
:为点和边界框生成嵌入。
MaskDecoder
:双向变换器,在图像嵌入和点嵌入之间执行交叉注意力(->)以及在点嵌入和图像嵌入之间执行交叉注意力。输出被馈送到。
Neck
:基于MaskDecoder
产生的上下文掩码预测输出掩码。
使用方法
提示掩码生成
from PIL import Image
import requests
from transformers import AutoProcessor, AutoModelForMaskGeneration
processor = AutoProcessor.from_pretrained("jadechoghari/robustsam-vit-large")
model = AutoModelForMaskGeneration.from_pretrained("jadechoghari/robustsam-vit-large")
img_url = "https://huggingface.co/ybelkada/segment-anything/resolve/main/assets/car.png"
raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB")
input_points = [[[450, 600]]]
inputs = processor(raw_image, input_points=input_points, return_tensors="pt").to("cuda")
with torch.no_grad():
outputs = model(**inputs)
masks = processor.image_processor.post_process_masks(outputs.pred_masks.cpu(), inputs["original_sizes"].cpu(), inputs["reshaped_input_sizes"].cpu())
scores = outputs.iou_scores
在生成掩码时,除了其他参数,你可以传递对象大致位置的2D坐标、包围对象的边界框(格式应为边界框右上和左下点的x、y坐标)或分割掩码。根据官方仓库,目前官方模型不支持文本输入。
更多细节请参考此笔记本,其中展示了如何使用模型的逐步指南及可视化示例!
自动掩码生成
该模型可用于“零样本”方式生成分割掩码,给定输入图像。模型会自动提示1024个点的网格,这些点全部输入模型。
该流程专为自动掩码生成设计。以下代码片段演示了如何轻松运行(在任何设备上!只需提供适当的points_per_batch
参数)
from transformers import pipeline
generator = pipeline("mask-generation", model="jadechoghari/robustsam-vit-large", device=0, points_per_batch=256)
image_url = "https://huggingface.co/ybelkada/segment-anything/resolve/main/assets/car.png"
outputs = generator(image_url, points_per_batch=256)
现在要在图像上显示生成的掩码:
import matplotlib.pyplot as plt
from PIL import Image
import numpy as np
def show_mask(mask, ax, random_color=False):
if random_color:
color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
else:
color = np.array([30 / 255, 144 / 255, 255 / 255, 0.6])
h, w = mask.shape[-2:]
mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
ax.imshow(mask_image)
plt.imshow(np.array(raw_image))
ax = plt.gca()
for mask in outputs["masks"]:
show_mask(mask, ax=ax, random_color=True)
plt.axis("off")
plt.show()
视觉对比
引用
如果本工作对您有帮助,请考虑引用我们!
@inproceedings{chen2024robustsam,
title={RobustSAM: Segment Anything Robustly on Degraded Images},
author={Chen, Wei-Ting and Vong, Yu-Jiet and Kuo, Sy-Yen and Ma, Sizhou and Wang, Jian},
journal={CVPR},
year={2024}
}
致谢
我们感谢SAM的作者,我们的仓库基于他们的工作。