🚀 ResNet50 v1.5 图像分类模型
ResNet50 v1.5 是一款用于图像分类的模型,它在原始 ResNet50 v1 模型基础上进行了改进,提升了一定的准确性,可利用 NVIDIA GPU 架构的 Tensor Cores 进行混合精度训练,还能部署在 NVIDIA Triton 推理服务器上进行推理。
🚀 快速开始
本模型可用于图像分类任务,下面将介绍如何使用预训练的 ResNet50 v1.5 模型对图像进行推理并展示结果。
✨ 主要特性
- 改进版本:ResNet50 v1.5 是 原始 ResNet50 v1 模型 的改进版本,在瓶颈块的下采样操作上与 v1 有所不同,使得其准确率比 v1 略高(约 0.5% top1),但性能略有下降(约 5% imgs/sec)。
- 混合精度训练:该模型使用 Volta、Turing 和 NVIDIA Ampere GPU 架构上的 Tensor Cores 进行混合精度训练,研究人员可以比不使用 Tensor Cores 时快 2 倍以上得到结果,同时体验混合精度训练的好处。
- 一致性测试:该模型针对每个 NGC 月度容器版本进行测试,以确保随着时间的推移保持一致的准确性和性能。
- 可部署性:ResNet50 v1.5 模型可以使用 TorchScript、ONNX Runtime 或 TensorRT 作为执行后端,部署在 NVIDIA Triton 推理服务器 上进行推理。
📦 安装指南
运行示例需要安装一些额外的 Python 包,用于图像预处理和可视化:
!pip install validators matplotlib
💻 使用示例
基础用法
以下是使用预训练的 ResNet50 v1.5 模型对图像进行推理的示例代码:
import torch
from PIL import Image
import torchvision.transforms as transforms
import numpy as np
import json
import requests
import matplotlib.pyplot as plt
import warnings
warnings.filterwarnings('ignore')
%matplotlib inline
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f'Using {device} for inference')
加载在 IMAGENET 数据集上预训练的模型:
resnet50 = torch.hub.load('NVIDIA/DeepLearningExamples:torchhub', 'nvidia_resnet50', pretrained=True)
utils = torch.hub.load('NVIDIA/DeepLearningExamples:torchhub', 'nvidia_convnets_processing_utils')
resnet50.eval().to(device)
准备样本输入数据:
uris = [
'http://images.cocodataset.org/test-stuff2017/000000024309.jpg',
'http://images.cocodataset.org/test-stuff2017/000000028117.jpg',
'http://images.cocodataset.org/test-stuff2017/000000006149.jpg',
'http://images.cocodataset.org/test-stuff2017/000000004954.jpg',
]
batch = torch.cat(
[utils.prepare_input_from_uri(uri) for uri in uris]
).to(device)
运行推理,使用 pick_n_best(predictions=output, n=topN)
辅助函数根据模型选择 N 个最可能的假设:
with torch.no_grad():
output = torch.nn.functional.softmax(resnet50(batch), dim=1)
results = utils.pick_n_best(predictions=output, n=5)
显示结果:
for uri, result in zip(uris, results):
img = Image.open(requests.get(uri, stream=True).raw)
img.thumbnail((256,256), Image.ANTIALIAS)
plt.imshow(img)
plt.show()
print(result)
📚 详细文档
有关模型输入和输出、训练方法、推理和性能的详细信息,请访问:
github
和/或 NGC
📄 许可证
本项目采用 Apache-2.0 许可证。
🔗 参考资料