license: apache-2.0
tags:
分段任意模型(SAM)卡片 - ViT大模型(ViT-L)版本
分段任意模型(SAM)的详细架构图
目录
- 摘要
- 模型详情
- 使用方式
- 引用文献
摘要
原始代码库链接
**分段任意模型(SAM)**能够通过输入提示点或边界框生成高质量物体掩膜,并可为图像中所有物体创建掩膜。该模型在包含1100万张图像和11亿个掩膜的数据集上训练,在各类分割任务中展现出强大的零样本性能。
论文摘要指出:
我们推出Segment Anything(SA)项目:包含全新任务、模型和数据集的全新图像分割方案。通过在数据收集循环中使用高效模型,我们构建了迄今为止最大规模的分割数据集(远超现有),包含基于1100万张经过授权且尊重隐私的图像的10亿以上掩膜。该模型专为提示式设计训练,可实现对新图像分布和任务的零样本迁移。我们在多项任务上评估其能力,发现其零样本性能令人印象深刻——往往能与先前全监督结果媲美甚至更优。我们将发布Segment Anything模型(SAM)及对应数据集(SA-1B)——包含10亿掩膜和1100万图像,访问地址https://segment-anything.com,以促进计算机视觉基础模型研究。
免责声明:本模型卡内容由Hugging Face团队编写,部分内容复制自原始SAM模型卡。
模型详情
SAM模型由3个模块组成:
视觉编码器(VisionEncoder)
:基于VIT架构的图像编码器。通过对图像分块进行注意力计算生成图像嵌入,采用相对位置编码。
提示编码器(PromptEncoder)
:为提示点和边界框生成嵌入表示
掩膜解码器(MaskDecoder)
:双向Transformer架构,在图像嵌入与提示点嵌入之间执行交叉注意力计算(双向),输出结果传递至
颈部网络(Neck)
:根据掩膜解码器生成的上下文掩膜预测最终输出掩膜。
使用方式
提示式掩膜生成
from PIL import Image
import requests
from transformers import SamModel, SamProcessor
model = SamModel.from_pretrained("facebook/sam-vit-large")
processor = SamProcessor.from_pretrained("facebook/sam-vit-large")
img_url = "https://huggingface.co/ybelkada/segment-anything/resolve/main/assets/car.png"
raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB")
input_points = [[[450, 600]]]
inputs = processor(raw_image, input_points=input_points, return_tensors="pt").to("cuda")
outputs = model(**inputs)
masks = processor.image_processor.post_process_masks(outputs.pred_masks.cpu(), inputs["original_sizes"].cpu(), inputs["reshaped_input_sizes"].cpu())
scores = outputs.iou_scores
除生成掩膜的基本参数外,您可以传入:目标物体大致位置的2D坐标、包围目标物体的边界框(格式应为边界框左上右下点的x,y坐标)、分割掩膜。截至本文撰写时,根据官方代码库说明,官方模型暂不支持文本输入。更多细节请参考演示笔记本,其中包含可视化示例的使用教程!
自动掩膜生成
该模型可实现"零样本"风格的自动图像分割掩膜生成。模型会自动接收1024个网格点作为提示输入。
以下代码片段展示了自动掩膜生成管道的简易使用方法(可在任何设备运行!只需调整合适的points_per_batch参数):
from transformers import pipeline
generator = pipeline("mask-generation", device = 0, points_per_batch = 256)
image_url = "https://huggingface.co/ybelkada/segment-anything/resolve/main/assets/car.png"
outputs = generator(image_url, points_per_batch = 256)
展示生成结果的代码:
import matplotlib.pyplot as plt
from PIL import Image
import numpy as np
def show_mask(mask, ax, random_color=False):
if random_color:
color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
else:
color = np.array([30 / 255, 144 / 255, 255 / 255, 0.6])
h, w = mask.shape[-2:]
mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
ax.imshow(mask_image)
plt.imshow(np.array(raw_image))
ax = plt.gca()
for mask in outputs["masks"]:
show_mask(mask, ax=ax, random_color=True)
plt.axis("off")
plt.show()
引用文献
若使用本模型,请引用以下BibTeX条目:
@article{kirillov2023segany,
title={Segment Anything},
author={Kirillov, Alexander and Mintun, Eric and Ravi, Nikhila and Mao, Hanzi and Rolland, Chloe and Gustafson, Laura and Xiao, Tete and Whitehead, Spencer and Berg, Alexander C. and Lo, Wan-Yen and Doll{\'a}r, Piotr and Girshick, Ross},
journal={arXiv:2304.02643},
year={2023}
}