许可协议:apache-2.0
标签:
SlimSAM模型卡片(SAM压缩版=万物分割)
SlimSAM概览及其与替代方案的差异。
目录
- 摘要
- 模型详情
- 使用方式
- 引用
摘要
SlimSAM是Segment Anything (SAM)模型的压缩(剪枝)版本,能够根据点或框等输入提示生成高质量的对象掩码。
研究论文摘要指出:
万物分割模型(SAM)庞大的参数量和高计算需求使其难以部署在资源受限的设备上。现有的SAM压缩方法通常需要从头训练新网络,面临压缩成本与模型性能之间的艰难权衡。为解决这一问题,本文提出SlimSAM,一种新颖的SAM压缩方法,以极低的训练成本实现卓越性能。该方法通过统一的剪枝-蒸馏框架高效复用预训练SAM。为增强原始SAM的知识继承,我们采用创新的交替瘦身策略,将压缩过程分解为渐进步骤。不同于以往剪枝技术,我们精细地以交替方式对解耦的模型结构进行剪枝与蒸馏。此外,还提出了一种无标签剪枝标准,使剪枝目标与优化目标对齐,从而提升剪枝后的蒸馏效果。SlimSAM在性能显著提升的同时,训练成本比现有方法降低10倍以上。即使与原始SAM-H相比,SlimSAM在参数量仅0.9%(570万)、MAC运算量0.8%(210亿)、训练数据量0.1%(1万)的情况下,仍能达到接近的性能。
原始仓库链接
免责声明:本模型卡片内容由Hugging Face团队撰写,部分内容复制自原始SAM模型卡片。
模型详情
SAM模型由3个模块组成:
VisionEncoder
:基于VIT的图像编码器。通过图像分块的注意力计算图像嵌入,使用相对位置编码。
PromptEncoder
:为点和边界框生成嵌入表示
MaskDecoder
:双向transformer,执行图像嵌入与点嵌入的交叉注意力(->)以及点嵌入与图像嵌入的交叉注意力(<-)。输出馈入
Neck
:基于MaskDecoder
生成的上下文掩码预测输出掩码。
使用方式
提示式掩码生成
from PIL import Image
import requests
from transformers import SamModel, SamProcessor
model = SamModel.from_pretrained("nielsr/slimsam-50-uniform")
processor = SamProcessor.from_pretrained("nielsr/slimsam-50-uniform")
img_url = "https://huggingface.co/ybelkada/segment-anything/resolve/main/assets/car.png"
raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB")
input_points = [[[450, 600]]]
inputs = processor(raw_image, input_points=input_points, return_tensors="pt").to("cuda")
outputs = model(**inputs)
masks = processor.image_processor.post_process_masks(outputs.pred_masks.cpu(), inputs["original_sizes"].cpu(), inputs["reshaped_input_sizes"].cpu())
scores = outputs.iou_scores
生成掩码时,除其他参数外,可传递:感兴趣对象的近似二维位置坐标、包围对象的边界框(格式应为边界框左上右下点的x,y坐标)、分割掩码。截至撰写时,根据官方仓库,官方模型暂不支持文本输入。更多细节可参考展示使用流程的示例笔记本。
自动掩码生成
该模型可用于"零样本"方式生成分割掩码。模型会自动接收1024个网格点作为提示。
以下代码片段演示了如何轻松运行自动掩码生成流程(可在任何设备上运行!只需调整points_per_batch
参数):
from transformers import pipeline
generator = pipeline(task="mask-generation", model="nielsr/slimsam-50-uniform", device=0, points_per_batch=256)
image_url = "https://huggingface.co/ybelkada/segment-anything/resolve/main/assets/car.png"
outputs = generator(image_url, points_per_batch=256)
展示生成结果:
import matplotlib.pyplot as plt
from PIL import Image
import numpy as np
def show_mask(mask, ax, random_color=False):
if random_color:
color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
else:
color = np.array([30/255, 144/255, 255/255, 0.6])
h, w = mask.shape[-2:]
mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
ax.imshow(mask_image)
plt.imshow(np.array(raw_image))
ax = plt.gca()
for mask in outputs["masks"]:
show_mask(mask, ax=ax, random_color=True)
plt.axis("off")
plt.show()
引用
若使用本模型,请引用以下文献:
@article{kirillov2023segany,
title={Segment Anything},
author={Kirillov, Alexander and Mintun, Eric and Ravi, Nikhila and Mao, Hanzi and Rolland, Chloe and Gustafson, Laura and Xiao, Tete and Whitehead, Spencer and Berg, Alexander C. and Lo, Wan-Yen and Doll{\'a}r, Piotr and Girshick, Ross},
journal={arXiv:2304.02643},
year={2023}
}
@misc{chen202301,
title={0.1% Data Makes Segment Anything Slim},
author={Zigeng Chen and Gongfan Fang and Xinyin Ma and Xinchao Wang},
year={2023},
eprint={2312.05284},
archivePrefix={arXiv},
primaryClass={cs.CV}
}