数据集:
- ILSVRC/imagenet-1k
许可证: 其他
许可证名称: nvclv1
许可证链接: LICENSE
任务类型: 图像分类
库名称: transformers
MambaVision:混合型Mamba-Transformer视觉骨干网络
代码仓库: https://github.com/NVlabs/MambaVision
模型概述
我们开发了首个融合Mamba与Transformer优势的计算机视觉混合模型。核心创新包括重构Mamba公式以增强其视觉特征建模能力,并通过系统实验验证了视觉Transformer(ViT)与Mamba结合的可行性。研究表明,在Mamba架构最后几层加入自注意力模块能显著提升长程空间依赖的建模能力。基于此发现,我们推出具有分层架构的MambaVision系列模型,满足不同设计需求。
模型性能
MambaVision在Top-1准确率与计算吞吐量方面创造了新的SOTA帕累托前沿。
使用说明
建议通过以下命令安装MambaVision所需环境:
pip install mambavision
每个模型均提供图像分类和特征提取两种变体,一行代码即可导入。
图像分类
以下示例展示MambaVision用于图像分类:
输入COCO数据集验证集的示例图像:
分类代码片段:
from transformers import AutoModelForImageClassification
from PIL import Image
from timm.data.transforms_factory import create_transform
import requests
model = AutoModelForImageClassification.from_pretrained("nvidia/MambaVision-L2-1K", trust_remote_code=True)
model.cuda().eval()
url = 'http://images.cocodataset.org/val2017/000000020247.jpg'
image = Image.open(requests.get(url, stream=True).raw)
input_resolution = (3, 224, 224)
transform = create_transform(input_size=input_resolution,
is_training=False,
mean=model.config.mean,
std=model.config.std,
crop_mode=model.config.crop_pct,
crop_pct=model.config.crop_pct)
inputs = transform(image).unsqueeze(0).cuda()
outputs = model(inputs)
logits = outputs['logits']
predicted_class_idx = logits.argmax(-1).item()
print("预测类别:", model.config.id2label[predicted_class_idx])
输出结果为棕熊(学名:Ursus arctos)
。
特征提取
MambaVision也可作为通用特征提取器使用,支持提取:
特征提取代码片段:
from transformers import AutoModel
from PIL import Image
from timm.data.transforms_factory import create_transform
import requests
model = AutoModel.from_pretrained("nvidia/MambaVision-L2-1K", trust_remote_code=True)
model.cuda().eval()
url = 'http://images.cocodataset.org/val2017/000000020247.jpg'
image = Image.open(requests.get(url, stream=True).raw)
input_resolution = (3, 224, 224)
transform = create_transform(input_size=input_resolution,
is_training=False,
mean=model.config.mean,
std=model.config.std,
crop_mode=model.config.crop_pct,
crop_pct=model.config.crop_pct)
inputs = transform(image).unsqueeze(0).cuda()
out_avg_pool, features = model(inputs)
print("平均池特征尺寸:", out_avg_pool.size())
print("特征阶段数量:", len(features))
print("第一阶段特征尺寸:", features[0].size())
print("第四阶段特征尺寸:", features[3].size())
许可证
NVIDIA受限源代码许可证