license: apache-2.0
tags:
SAM模型卡(ViT基础版)
Segment Anything Model (SAM) 详细架构图
目录
- 摘要
- 模型详情
- 使用方式
- 引用
摘要
原始代码库链接
Segment Anything Model (SAM) 能够通过输入提示(如点或框)生成高质量的对象掩码,并可为图像中所有对象创建分割掩码。该模型在包含1100万张图像和11亿个掩码的数据集上训练,在多种分割任务上展现出强大的零样本性能。
论文摘要指出:
我们推出Segment Anything(SA)项目:包含新任务、模型和数据集的全新图像分割方案。通过高效模型的数据收集循环,我们构建了迄今为止最大规模的分割数据集(远超现有水平),包含11M经过授权且尊重隐私的图像及超过10亿掩码。该模型设计为可提示式,能够零样本迁移到新图像分布和任务。我们在多项任务上评估其能力,发现其零样本性能令人印象深刻——常与全监督结果媲美甚至更优。我们将发布Segment Anything模型(SAM)及对应数据集(SA-1B,含10亿掩码和1100万图像)以促进计算机视觉基础模型研究,详见https://segment-anything.com。
免责声明:本模型卡内容由Hugging Face团队编写,部分内容复制自原始SAM模型卡。
模型详情
SAM模型由3个模块组成:
VisionEncoder
:基于VIT的图像编码器。通过图像块注意力计算嵌入表示,使用相对位置编码。
PromptEncoder
:为点和边界框生成嵌入表示
MaskDecoder
:双向Transformer,执行图像嵌入与点嵌入的交叉注意力(->)以及点嵌入与图像嵌入的交叉注意力。输出将馈入:
Neck
:基于MaskDecoder
生成的上下文掩码预测最终输出掩码。
使用方式
提示式掩码生成
from PIL import Image
import requests
from transformers import SamModel, SamProcessor
model = SamModel.from_pretrained("facebook/sam-vit-base")
processor = SamProcessor.from_pretrained("facebook/sam-vit-base")
img_url = "https://huggingface.co/ybelkada/segment-anything/resolve/main/assets/car.png"
raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB")
input_points = [[[450, 600]]]
inputs = processor(raw_image, input_points=input_points, return_tensors="pt").to("cuda")
outputs = model(**inputs)
masks = processor.image_processor.post_process_masks(outputs.pred_masks.cpu(), inputs["original_sizes"].cpu(), inputs["reshaped_input_sizes"].cpu())
scores = outputs.iou_scores
除生成掩码的基本参数外,您还可以传入:对象大致位置的2D坐标、包围对象的边界框(格式应为左上和右下点的x,y坐标)、分割掩码。截至本文撰写时,官方模型暂不支持文本输入(参见官方仓库说明)。更多细节可参考演示完整流程的示例笔记本。
自动掩码生成
该模型可实现"零样本"风格的自动分割掩码生成。模型会自动使用1024个网格点作为提示输入。
以下代码片段展示如何轻松运行自动掩码生成流程(支持任何设备!只需调整points_per_batch
参数):
from transformers import pipeline
generator = pipeline("mask-generation", device=0, points_per_batch=256)
image_url = "https://huggingface.co/ybelkada/segment-anything/resolve/main/assets/car.png"
outputs = generator(image_url, points_per_batch=256)
可视化生成结果:
import matplotlib.pyplot as plt
from PIL import Image
import numpy as np
def show_mask(mask, ax, random_color=False):
if random_color:
color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
else:
color = np.array([30/255, 144/255, 255/255, 0.6])
h, w = mask.shape[-2:]
mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
ax.imshow(mask_image)
plt.imshow(np.array(raw_image))
ax = plt.gca()
for mask in outputs["masks"]:
show_mask(mask, ax=ax, random_color=True)
plt.axis("off")
plt.show()
引用
如需使用本模型,请引用以下文献:
@article{kirillov2023segany,
title={Segment Anything},
author={Kirillov, Alexander and Mintun, Eric and Ravi, Nikhila and Mao, Hanzi and Rolland, Chloe and Gustafson, Laura and Xiao, Tete and Whitehead, Spencer and Berg, Alexander C. and Lo, Wan-Yen and Doll{\'a}r, Piotr and Girshick, Ross},
journal={arXiv:2304.02643},
year={2023}
}