模型简介
模型特点
模型能力
使用案例
库名称:transformers 流水线标签:掩码生成 许可证:apache-2.0 标签:
- 视觉
高质量分割万物模型(SAM-HQ)模型卡
SAM-HQ与原版SAM模型的架构对比,展示了HQ输出令牌和全局-局部特征融合组件。
目录
摘要
SAM-HQ(高质量分割万物)是Segment Anything Model(SAM)的增强版本,能够从点或框等输入提示生成更高质量的对象掩码。虽然SAM在1100万张图像和11亿个掩码的数据集上进行了训练,但其掩码预测质量在许多情况下仍显不足,尤其是在处理具有复杂结构的对象时。SAM-HQ以最少的额外参数和计算成本解决了这些限制。
该模型擅长生成高质量的分割掩码,即使对于具有复杂边界和细微结构的对象,原版SAM模型常常难以处理的情况也能应对。SAM-HQ保留了SAM原有的可提示设计、效率和零样本泛化能力,同时显著提高了掩码质量。
模型详情
SAM-HQ在原版SAM架构的基础上进行了两项关键创新,同时保留了SAM的预训练权重:
-
高质量输出令牌:一个可学习的令牌,注入到SAM的掩码解码器中,负责预测高质量掩码。与SAM原有的输出令牌不同,该令牌及其相关的MLP层专门训练用于生成高度准确的分割掩码。
-
全局-局部特征融合:SAM-HQ不仅将HQ输出令牌应用于掩码解码器特征,还首先将这些特征与早期和最终的ViT特征融合,以改善掩码细节。这结合了高级语义上下文和低级边界信息,实现更准确的分割。
SAM-HQ在精心挑选的44K精细掩码(HQSeg-44K)数据集上进行了训练,这些数据来自多个来源,具有极其准确的标注。训练过程仅需8个GPU上4小时,相比原版SAM模型增加了不到0.5%的参数。
该模型已在10个不同的分割数据集上进行了评估,涵盖多种下游任务,其中8个数据集采用零样本迁移协议进行评估。结果表明,SAM-HQ能够生成比原版SAM模型显著更好的掩码,同时保持其零样本泛化能力。
SAM-HQ解决了原版SAM模型的两个关键问题:
- 粗糙的掩码边界,常常忽略细微对象结构
- 在具有挑战性的情况下出现错误预测、断裂掩码或较大误差
这些改进使SAM-HQ特别适用于需要高度准确图像掩码的应用,如自动化标注和图像/视频编辑任务。
使用方法
提示掩码生成
from PIL import Image
import requests
from transformers import SamHQModel, SamHQProcessor
model = SamHQModel.from_pretrained("syscv-community/sam-hq-vit-large")
processor = SamHQProcessor.from_pretrained("syscv-community/sam-hq-vit-large")
img_url = "https://raw.githubusercontent.com/SysCV/sam-hq/refs/heads/main/demo/input_imgs/example1.png"
raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB")
input_boxes = [[[306, 132, 925, 893]]] # 图像的边界框
inputs = processor(raw_image, input_boxes=input_boxes, return_tensors="pt").to("cuda")
outputs = model(**inputs)
masks = processor.image_processor.post_process_masks(outputs.pred_masks.cpu(), inputs["original_sizes"].cpu(), inputs["reshaped_input_sizes"].cpu())
scores = outputs.iou_scores
在生成掩码的其他参数中,您可以传递对象大致位置的2D坐标、包围对象的边界框(格式应为边界框左上角和右下角的x、y坐标)、分割掩码。在撰写本文时,官方模型不支持传递文本作为输入,根据官方仓库。更多详情,请参考此笔记本,其中展示了如何使用模型的逐步指南,并附有可视化示例!
自动掩码生成
该模型可用于在给定输入图像的情况下以“零样本”方式生成分割掩码。模型会自动提示一个包含1024
个点的网格,所有这些点都会输入到模型中。
流水线专为自动掩码生成而设计。以下代码片段展示了如何轻松运行它(在任何设备上!只需提供适当的points_per_batch
参数)
from transformers import pipeline
generator = pipeline("mask-generation", model="syscv-community/sam-hq-vit-large", device=0, points_per_batch=256)
image_url = "https://raw.githubusercontent.com/SysCV/sam-hq/refs/heads/main/demo/input_imgs/example1.png"
outputs = generator(image_url, points_per_batch=256)
现在显示图像:
import matplotlib.pyplot as plt
from PIL import Image
import numpy as np
def show_mask(mask, ax, random_color=False):
if random_color:
color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
else:
color = np.array([30 / 255, 144 / 255, 255 / 255, 0.6])
h, w = mask.shape[-2:]
mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
ax.imshow(mask_image)
plt.imshow(np.array(raw_image))
ax = plt.gca()
for mask in outputs["masks"]:
show_mask(mask, ax=ax, random_color=True)
plt.axis("off")
plt.show()
完整示例与可视化
import numpy as np
import matplotlib.pyplot as plt
def show_mask(mask, ax, random_color=False):
if random_color:
color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
else:
color = np.array([30/255, 144/255, 255/255, 0.6])
h, w = mask.shape[-2:]
mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
ax.imshow(mask_image)
def show_box(box, ax):
x0, y0 = box[0], box[1]
w, h = box[2] - box[0], box[3] - box[1]
ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green', facecolor=(0,0,0,0), lw=2))
def show_boxes_on_image(raw_image, boxes):
plt.figure(figsize=(10,10))
plt.imshow(raw_image)
for box in boxes:
show_box(box, plt.gca())
plt.axis('on')
plt.show()
def show_points_on_image(raw_image, input_points, input_labels=None):
plt.figure(figsize=(10,10))
plt.imshow(raw_image)
input_points = np.array(input_points)
if input_labels is None:
labels = np.ones_like(input_points[:, 0])
else:
labels = np.array(input_labels)
show_points(input_points, labels, plt.gca())
plt.axis('on')
plt.show()
def show_points_and_boxes_on_image(raw_image, boxes, input_points, input_labels=None):
plt.figure(figsize=(10,10))
plt.imshow(raw_image)
input_points = np.array(input_points)
if input_labels is None:
labels = np.ones_like(input_points[:, 0])
else:
labels = np.array(input_labels)
show_points(input_points, labels, plt.gca())
for box in boxes:
show_box(box, plt.gca())
plt.axis('on')
plt.show()
def show_points_and_boxes_on_image(raw_image, boxes, input_points, input_labels=None):
plt.figure(figsize=(10,10))
plt.imshow(raw_image)
input_points = np.array(input_points)
if input_labels is None:
labels = np.ones_like(input_points[:, 0])
else:
labels = np.array(input_labels)
show_points(input_points, labels, plt.gca())
for box in boxes:
show_box(box, plt.gca())
plt.axis('on')
plt.show()
def show_points(coords, labels, ax, marker_size=375):
pos_points = coords[labels==1]
neg_points = coords[labels==0]
ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
ax.scatter(neg_points[:, 0], neg_points[:, 1], color='red', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
def show_masks_on_image(raw_image, masks, scores):
if len(masks.shape) == 4:
masks = masks.squeeze()
if scores.shape[0] == 1:
scores = scores.squeeze()
nb_predictions = scores.shape[-1]
fig, axes = plt.subplots(1, nb_predictions, figsize=(15, 15))
for i, (mask, score) in enumerate(zip(masks, scores)):
mask = mask.cpu().detach()
axes[i].imshow(np.array(raw_image))
show_mask(mask, axes[i])
axes[i].title.set_text(f"掩码 {i+1}, 分数: {score.item():.3f}")
axes[i].axis("off")
plt.show()
def show_masks_on_single_image(raw_image, masks, scores):
if len(masks.shape) == 4:
masks = masks.squeeze()
if scores.shape[0] == 1:
scores = scores.squeeze()
# 如果图像尚未转换为numpy数组,则进行转换
image_np = np.array(raw_image)
# 创建图形
fig, ax = plt.subplots(figsize=(8, 8))
ax.imshow(image_np)
# 在同一图像上叠加所有掩码
for i, (mask, score) in enumerate(zip(masks, scores)):
mask = mask.cpu().detach().numpy() # 转换为NumPy
show_mask(mask, ax) # 假设`show_mask`正确叠加掩码
ax.set_title(f"叠加掩码与分数")
ax.axis("off")
plt.show()
import torch
from transformers import SamHQModel, SamHQProcessor
device = "cuda" if torch.cuda.is_available() else "cpu"
model = SamHQModel.from_pretrained("syscv-community/sam-hq-vit-large").to(device)
processor = SamHQProcessor.from_pretrained("syscv-community/sam-hq-vit-large")
from PIL import Image
import requests
img_url = "https://raw.githubusercontent.com/SysCV/sam-hq/refs/heads/main/demo/input_imgs/example1.png"
raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB")
plt.imshow(raw_image)
inputs = processor(raw_image, return_tensors="pt").to(device)
image_embeddings, intermediate_embeddings = model.get_image_embeddings(inputs["pixel_values"])
input_boxes = [[[306, 132, 925, 893]]]
show_boxes_on_image(raw_image, input_boxes[0])
inputs.pop("pixel_values", None)
inputs.update({"image_embeddings": image_embeddings})
inputs.update({"intermediate_embeddings": intermediate_embeddings})
with torch.no_grad():
outputs = model(**inputs)
masks = processor.image_processor.post_process_masks(outputs.pred_masks.cpu(), inputs["original_sizes"].cpu(), inputs["reshaped_input_sizes"].cpu())
scores = outputs.iou_scores
show_masks_on_single_image(raw_image, masks[0], scores)
show_masks_on_image(raw_image, masks[0], scores)
引用
@misc{ke2023segmenthighquality,
title={高质量分割万物},
author={柯磊、叶明桥、Martin Danelljan、刘一凡、戴宇荣、唐志强、Fisher Yu},
year={2023},
eprint={2306.01567},
archivePrefix={arXiv},
primaryClass={cs.CV},
url={https://arxiv.org/abs/2306.01567},
}











