license: apache-2.0
tags:
Segment Anything Model (SAM) - ViT超大版(ViT-H)模型卡片
Segment Anything Model (SAM) 详细架构图
目录
- 摘要
- 模型详情
- 使用方式
- 引用
摘要
原始代码库链接
Segment Anything Model (SAM) 能够根据输入提示(如点或框)生成高质量的对象掩码,并可为图像中的所有对象生成掩码。该模型在包含1100万张图像和11亿个掩码的数据集上训练,在各种分割任务上展现出强大的零样本性能。论文摘要指出:
我们推出了Segment Anything (SA)项目:这是一个用于图像分割的新任务、模型和数据集。通过在数据收集循环中使用我们的高效模型,我们构建了迄今为止最大的分割数据集(远超以往),包含11M经过许可且尊重隐私的图像上的超过10亿个掩码。该模型设计为可提示式训练,因此能够零样本迁移到新的图像分布和任务。我们在多项任务上评估其能力,发现其零样本性能令人印象深刻——通常与之前完全监督的结果相当甚至更优。我们将发布Segment Anything Model (SAM)和相应的数据集(SA-1B),包含10亿掩码和1100万图像,以促进计算机视觉基础模型的研究。
免责声明:本模型卡片内容由Hugging Face团队编写,部分内容复制自原始SAM模型卡片。
模型详情
SAM模型由3个模块组成:
VisionEncoder
:基于VIT的图像编码器。通过图像块注意力计算图像嵌入,使用相对位置编码。
PromptEncoder
:为点和边界框生成嵌入表示
MaskDecoder
:双向transformer,在图像嵌入和点嵌入之间执行交叉注意力(->),以及在点嵌入和图像嵌入之间执行交叉注意力(<-)。输出被送入
Neck
:基于MaskDecoder
生成的上下文掩码预测输出掩码。
使用方式
提示式掩码生成
from PIL import Image
import requests
from transformers import SamModel, SamProcessor
model = SamModel.from_pretrained("facebook/sam-vit-huge")
processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")
img_url = "https://huggingface.co/ybelkada/segment-anything/resolve/main/assets/car.png"
raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB")
input_points = [[[450, 600]]]
inputs = processor(raw_image, input_points=input_points, return_tensors="pt").to("cuda")
outputs = model(**inputs)
masks = processor.image_processor.post_process_masks(outputs.pred_masks.cpu(), inputs["original_sizes"].cpu(), inputs["reshaped_input_sizes"].cpu())
scores = outputs.iou_scores
除生成掩码的其他参数外,您可以传递:
- 目标物体大致位置的2D坐标
- 包围目标物体的边界框(格式应为边界框左上和右下点的x,y坐标)
- 分割掩码
根据官方代码库,当前官方模型不支持文本输入。
自动掩码生成
该模型可用于"零样本"方式生成分割掩码。模型会自动使用1024个点组成的网格作为提示。
以下代码片段展示如何轻松运行自动掩码生成流程(可在任何设备上运行!只需调整points_per_batch
参数):
from transformers import pipeline
generator = pipeline("mask-generation", device=0, points_per_batch=256)
image_url = "https://huggingface.co/ybelkada/segment-anything/resolve/main/assets/car.png"
outputs = generator(image_url, points_per_batch=256)
展示结果图像:
import matplotlib.pyplot as plt
from PIL import Image
import numpy as np
def show_mask(mask, ax, random_color=False):
if random_color:
color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
else:
color = np.array([30/255, 144/255, 255/255, 0.6])
h, w = mask.shape[-2:]
mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
ax.imshow(mask_image)
plt.imshow(np.array(raw_image))
ax = plt.gca()
for mask in outputs["masks"]:
show_mask(mask, ax=ax, random_color=True)
plt.axis("off")
plt.show()
引用
如果使用本模型,请引用以下文献:
@article{kirillov2023segany,
title={Segment Anything},
author={Kirillov, Alexander and Mintun, Eric and Ravi, Nikhila and Mao, Hanzi and Rolland, Chloe and Gustafson, Laura and Xiao, Tete and Whitehead, Spencer and Berg, Alexander C. and Lo, Wan-Yen and Doll{\'a}r, Piotr and Girshick, Ross},
journal={arXiv:2304.02643},
year={2023}
}