数据集:
- ILSVRC/imagenet-21k
许可证: other
许可证名称: nvclv1
许可证链接: LICENSE
管道标签: image-classification
库名称: transformers
MambaVision:混合曼巴-Transformer视觉骨干网络
项目主页
模型概述
我们开发了首个结合曼巴(Mamba)与Transformer优势的计算机视觉混合模型。核心创新包括:重新设计曼巴公式以增强视觉特征建模能力,并通过全面消融实验验证了视觉Transformer(ViT)与曼巴架构融合的可行性。结果表明,在曼巴架构最后几层加入自注意力模块能显著提升长距离空间依赖的建模能力。基于此,我们推出具有层级结构的MambaVision模型系列以满足不同设计需求。
模型性能
MambaVision-L3-512-21K在ImageNet-21K数据集预训练后,以512×512分辨率在ImageNet-1K上微调。
名称 |
Top1准确率(%) |
Top5准确率(%) |
参数量(M) |
计算量(G) |
分辨率 |
MambaVision-L3-512-21K |
88.1 |
98.6 |
739.6 |
489.1 |
512×512 |
该系列模型在Top1准确率与吞吐量方面实现了新的SOTA帕累托前沿。
使用指南
建议通过以下命令安装依赖:
pip install mambavision
图像分类示例
以下代码展示如何使用MambaVision进行图像分类(以COCO数据集图片为例):
from transformers import AutoModelForImageClassification
from PIL import Image
from timm.data.transforms_factory import create_transform
import requests
model = AutoModelForImageClassification.from_pretrained("nvidia/MambaVision-L3-512-21K", trust_remote_code=True)
model.cuda().eval()
url = 'http://images.cocodataset.org/val2017/000000020247.jpg'
image = Image.open(requests.get(url, stream=True).raw)
transform = create_transform(input_size=(3,512,512), is_training=False,
mean=model.config.mean, std=model.config.std,
crop_mode=model.config.crop_mode, crop_pct=model.config.crop_pct)
inputs = transform(image).unsqueeze(0).cuda()
outputs = model(inputs)
print("预测类别:", model.config.id2label[outputs['logits'].argmax(-1).item()])
输出示例:棕熊(Ursus arctos)
特征提取示例
以下代码展示如何获取四阶段特征图及全局池化特征:
from transformers import AutoModel
model = AutoModel.from_pretrained("nvidia/MambaVision-L3-512-21K", trust_remote_code=True)
model.cuda().eval()
out_avg_pool, features = model(inputs)
print("全局池化特征尺寸:", out_avg_pool.size())
print("阶段特征数量:", len(features))
print("第一阶段特征尺寸:", features[0].size())
print("第四阶段特征尺寸:", features[3].size())
许可证
NVIDIA受限源代码许可证
预训练模型
ImageNet-21K
模型 | Top1(%) | Top5(%) | 参数量(M) | 计算量(G) | 分辨率 | HuggingFace | 下载 |
MambaVision-B-21K | 84.9 | 97.5 | 97.7 | 15.0 | 224×224 | 链接 | 模型 |
ImageNet-1K
模型 | Top1(%) | Top5(%) | 吞吐量(图/秒) | 分辨率 | 参数量(M) | 计算量(G) | HuggingFace | 下载 |
MambaVision-T | 82.3 | 96.2 | 6298 | 224×224 | 31.8 | 4.4 | 链接 | 模型 |
安装
除提供Dockerfile外,若已安装PyTorch,可通过以下命令安装依赖:
pip install -r requirements.txt