数据集:
- ILSVRC/imagenet-1k
许可证: 其他
许可证名称: nvclv1
许可证链接: LICENSE
库名称: transformers
管道标签: 图像分类
标签:
- 图像特征提取
MambaVision:混合Mamba-Transformer视觉骨干网络
模型概述
我们开发了首个结合Mamba和Transformer优势的计算机视觉混合模型。具体而言,我们的核心贡献包括重新设计Mamba公式以增强其对视觉特征的高效建模能力。此外,我们还对将视觉Transformer(ViT)与Mamba集成的可行性进行了全面消融研究。结果表明,在Mamba架构的最后一层加入多个自注意力模块,能显著提升其捕捉长距离空间依赖关系的建模能力。基于这些发现,我们推出了一系列具有层次化架构的MambaVision模型,以满足不同的设计需求。
模型性能
MambaVision在Top-1准确率和吞吐量方面达到了新的SOTA Pareto前沿,展现出强大的性能。
模型使用
强烈建议通过以下命令安装MambaVision所需环境:
pip install mambavision
对于每个模型,我们提供了用于图像分类和特征提取的两种变体,只需一行代码即可导入。
图像分类
以下示例展示了如何使用MambaVision进行图像分类。
给定来自COCO数据集验证集的输入图像:
以下代码片段可用于图像分类:
from transformers import AutoModelForImageClassification
from PIL import Image
from timm.data.transforms_factory import create_transform
import requests
model = AutoModelForImageClassification.from_pretrained("nvidia/MambaVision-T-1K", trust_remote_code=True)
model.cuda().eval()
url = 'http://images.cocodataset.org/val2017/000000020247.jpg'
image = Image.open(requests.get(url, stream=True).raw)
input_resolution = (3, 224, 224)
transform = create_transform(input_size=input_resolution,
is_training=False,
mean=model.config.mean,
std=model.config.std,
crop_mode=model.config.crop_mode,
crop_pct=model.config.crop_pct)
inputs = transform(image).unsqueeze(0).cuda()
outputs = model(inputs)
logits = outputs['logits']
predicted_class_idx = logits.argmax(-1).item()
print("预测类别:", model.config.id2label[predicted_class_idx])
预测标签为棕熊(brown bear, bruin, Ursus arctos)
。
特征提取
MambaVision还可作为通用特征提取器使用。
具体而言,我们可以提取模型每个阶段(共4个阶段)的输出,以及最终经过平均池化并展平的特征。
以下代码片段可用于特征提取:
from transformers import AutoModel
from PIL import Image
from timm.data.transforms_factory import create_transform
import requests
model = AutoModel.from_pretrained("nvidia/MambaVision-T-1K", trust_remote_code=True)
model.cuda().eval()
url = 'http://images.cocodataset.org/val2017/000000020247.jpg'
image = Image.open(requests.get(url, stream=True).raw)
input_resolution = (3, 224, 224)
transform = create_transform(input_size=input_resolution,
is_training=False,
mean=model.config.mean,
std=model.config.std,
crop_mode=model.config.crop_mode,
crop_pct=model.config.crop_pct)
inputs = transform(image).unsqueeze(0).cuda()
out_avg_pool, features = model(inputs)
print("平均池化特征尺寸:", out_avg_pool.size())
print("提取特征的阶段数量:", len(features))
print("第1阶段提取特征尺寸:", features[0].size())
print("第4阶段提取特征尺寸:", features[3].size())
许可证:
NVIDIA源代码许可证-NC