许可协议:apache-2.0
标签:
SlimSAM模型卡片(SAM压缩版=Segment Anything)
SlimSAM概览及其与替代方案的差异。
目录
- 摘要
- 模型详情
- 使用方法
- 引用
摘要
SlimSAM是Segment Anything (SAM)模型的压缩(剪枝)版本,能够根据点或框等输入提示生成高质量的对象掩码。
论文摘要指出:
Segment Anything Model (SAM)庞大的模型规模和苛刻的计算需求使其难以部署在资源受限的设备上。现有的SAM压缩方法通常需要从头训练新网络,面临压缩成本与模型性能之间的艰难权衡。为解决这一问题,本文提出SlimSAM,一种新颖的SAM压缩方法,以极低的训练成本实现卓越性能。这是通过统一的剪枝-蒸馏框架高效复用预训练SAM实现的。为增强从原始SAM的知识继承,我们采用创新的交替瘦身策略,将压缩过程划分为渐进步骤。不同于先前的剪枝技术,我们以交替方式精细剪枝并蒸馏解耦的模型结构。此外,还提出了一种无标签剪枝标准,使剪枝目标与优化目标对齐,从而提升剪枝后的蒸馏效果。SlimSAM在显著提升性能的同时,训练成本比现有方法低10倍以上。即使与原始SAM-H相比,SlimSAM在参数降至仅0.9%(570万)、MACs降至0.8%(210亿)且仅需SAM训练数据0.1%(1万)的情况下,仍能接近其性能。
原始仓库链接
免责声明:本模型卡片内容由Hugging Face团队撰写,部分内容复制自原始SAM模型卡片。
模型详情
SAM模型由3个模块组成:
VisionEncoder
:基于VIT的图像编码器。通过图像块注意力计算图像嵌入,使用相对位置嵌入。
PromptEncoder
:为点和边界框生成嵌入。
MaskDecoder
:双向变换器,执行图像嵌入与点嵌入之间的交叉注意力(->)及反向操作。输出送入Neck
模块。
Neck
:基于MaskDecoder
产生的上下文掩码预测输出掩码。
使用方法
提示掩码生成
from PIL import Image
import requests
from transformers import SamModel, SamProcessor
model = SamModel.from_pretrained("nielsr/slimsam-77-uniform")
processor = SamProcessor.from_pretrained("nielsr/slimsam-77-uniform")
img_url = "https://huggingface.co/ybelkada/segment-anything/resolve/main/assets/car.png"
raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB")
input_points = [[[450, 600]]]
inputs = processor(raw_image, input_points=input_points, return_tensors="pt").to("cuda")
outputs = model(**inputs)
masks = processor.image_processor.post_process_masks(outputs.pred_masks.cpu(), inputs["original_sizes"].cpu(), inputs["reshaped_input_sizes"].cpu())
scores = outputs.iou_scores
生成掩码时,除其他参数外,可传递对象大致位置的2D坐标、包围对象的边界框(格式应为左上和右下点坐标)或分割掩码。截至撰写时,官方模型不支持文本输入(参见官方仓库)。更多细节请参考展示使用流程及可视化示例的笔记本。
自动掩码生成
该模型可用于“零样本”方式生成分割掩码,给定输入图像。模型自动以1024个点的网格作为提示输入。
以下代码片段演示了如何轻松运行自动掩码生成管道(适用于任何设备!只需调整points_per_batch
参数):
from transformers import pipeline
generator = pipeline(task="mask-generation", model="nielsr/slimsam-77-uniform", device=0, points_per_batch=256)
image_url = "https://huggingface.co/ybelkada/segment-anything/resolve/main/assets/car.png"
outputs = generator(image_url, points_per_batch=256)
展示结果图像:
import matplotlib.pyplot as plt
from PIL import Image
import numpy as np
def show_mask(mask, ax, random_color=False):
if random_color:
color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
else:
color = np.array([30/255, 144/255, 255/255, 0.6])
h, w = mask.shape[-2:]
mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
ax.imshow(mask_image)
plt.imshow(np.array(raw_image))
ax = plt.gca()
for mask in outputs["masks"]:
show_mask(mask, ax=ax, random_color=True)
plt.axis("off")
plt.show()
引用
若使用此模型,请引用以下BibTeX条目。
@article{kirillov2023segany,
title={Segment Anything},
author={Kirillov, Alexander and Mintun, Eric and Ravi, Nikhila and Mao, Hanzi and Rolland, Chloe and Gustafson, Laura and Xiao, Tete and Whitehead, Spencer and Berg, Alexander C. and Lo, Wan-Yen and Doll{\'a}r, Piotr and Girshick, Ross},
journal={arXiv:2304.02643},
year={2023}
}
@misc{chen202301,
title={0.1% Data Makes Segment Anything Slim},
author={Zigeng Chen and Gongfan Fang and Xinyin Ma and Xinchao Wang},
year={2023},
eprint={2312.05284},
archivePrefix={arXiv},
primaryClass={cs.CV}
}