license: apache-2.0
pipeline_tag: image-classification
模型卡片:用于NSFW图像分类的微调视觉变换器(ViT)
模型描述
**微调视觉变换器(ViT)**是变换器编码器架构的一个变体,类似于BERT,但已适应图像分类任务。这个名为"google/vit-base-patch16-224-in21k"的特定模型,是在ImageNet-21k数据集上以监督方式预训练的,利用了大量的图像集合。预训练数据集中的图像被调整为224x224像素的分辨率,使其适用于广泛的图像识别任务。
在训练阶段,对超参数设置进行了细致的关注,以确保模型性能的最优化。模型以精心选择的16批次大小进行微调。这一选择不仅平衡了计算效率,还使模型能够有效地处理和学习多样化的图像。
为了促进这一微调过程,采用了5e-5的学习率。学习率是一个关键的调优参数,决定了训练过程中对模型参数调整的幅度。在这种情况下,选择5e-5的学习率是为了在快速收敛和稳定优化之间取得平衡,从而使得模型不仅学习迅速,还能在整个训练过程中稳步提升其能力。
这一训练阶段使用了一个包含80,000张图像的专有数据集执行,每张图像都具有高度的可变性。数据集经过精心策划,包含两个不同的类别,即"normal"(正常)和"nsfw"(不适宜工作场所)。这种多样性使模型能够掌握细微的视觉模式,使其具备准确区分安全和露骨内容的能力。
这一细致训练过程的总体目标是让模型深入理解视觉线索,确保其在处理NSFW图像分类这一特定任务时的鲁棒性和能力。最终得到的模型能够为内容安全和审核做出重要贡献,同时保持最高的准确性和可靠性标准。
预期用途与限制
预期用途
- NSFW图像分类:该模型的主要预期用途是对NSFW(不适宜工作场所)图像进行分类。它已针对此目的进行了微调,适用于在各种应用中过滤露骨或不适当的内容。
使用方法
以下是使用该模型对图像进行分类的方法(基于2个类别:normal和nsfw):
# 使用pipeline作为高级辅助工具
from PIL import Image
from transformers import pipeline
img = Image.open("<图片文件路径>")
classifier = pipeline("image-classification", model="Falconsai/nsfw_image_detection")
classifier(img)
# 直接加载模型
import torch
from PIL import Image
from transformers import AutoModelForImageClassification, ViTImageProcessor
img = Image.open("<图片文件路径>")
model = AutoModelForImageClassification.from_pretrained("Falconsai/nsfw_image_detection")
processor = ViTImageProcessor.from_pretrained('Falconsai/nsfw_image_detection')
with torch.no_grad():
inputs = processor(images=img, return_tensors="pt")
outputs = model(**inputs)
logits = outputs.logits
predicted_label = logits.argmax(-1).item()
model.config.id2label[predicted_label]
运行YOLO版本
import os
import matplotlib.pyplot as plt
from PIL import Image
import numpy as np
import onnxruntime as ort
import json
# 使用YOLOv9模型进行预测
def predict_with_yolov9(image_path, model_path, labels_path, input_size):
"""
使用转换后的YOLOv9模型对单张图像进行推理。
参数:
image_path (str): 输入图像文件路径。
model_path (str): ONNX模型文件路径。
labels_path (str): 包含类别标签的JSON文件路径。
input_size (tuple): 模型的预期输入尺寸(高度,宽度)。
返回:
str: 预测的类别标签。
PIL.Image.Image: 原始加载的图像。
"""
def load_json(file_path):
with open(file_path, "r") as f:
return json.load(f)
# 加载标签
labels = load_json(labels_path)
# 预处理图像
original_image = Image.open(image_path).convert("RGB")
image_resized = original_image.resize(input_size, Image.Resampling.BILINEAR)
image_np = np.array(image_resized, dtype=np.float32) / 255.0
image_np = np.transpose(image_np, (2, 0, 1)) # [C, H, W]
input_tensor = np.expand_dims(image_np, axis=0).astype(np.float32)
# 加载YOLOv9模型
session = ort.InferenceSession(model_path)
input_name = session.get_inputs()[0].name
output_name = session.get_outputs()[0].name
# 运行推理
outputs = session.run([output_name], {input_name: input_tensor})
predictions = outputs[0]
# 后处理预测结果
predicted_index = np.argmax(predictions)
predicted_label = labels[str(predicted_index)]
return predicted_label, original_image
# 显示单张图像的预测结果
def display_single_prediction(image_path, model_path, labels_path, input_size):
"""
预测单张图像的类别并显示图像及其预测结果。
参数:
image_path (str): 输入图像文件路径。
model_path (str): ONNX模型文件路径。
labels_path (str): 包含类别标签的JSON文件路径。
input_size (tuple): 模型的预期输入尺寸(高度,宽度)。
"""
try:
# 运行预测
prediction, img = predict_with_yolov9(image_path, model_path, labels_path, input_size)
# 显示图像和预测结果
fig, ax = plt.subplots(1, 1, figsize=(8, 8))
ax.imshow(img)
ax.set_title(f"预测结果: {prediction}", fontsize=14)
ax.axis("off")
plt.tight_layout()
plt.show()
except FileNotFoundError:
print(f"错误: 图像文件未找到: {image_path}")
except Exception as e:
print(f"发生错误: {e}")
# --- 主执行部分 ---
# 路径和参数 - **请修改这些**
single_image_path = "path/to/your/single_image.jpg" # <--- 替换为实际图像文件路径
model_path = "path/to/your/yolov9_model.onnx" # <--- 替换为实际ONNX模型路径
labels_path = "path/to/your/labels.json" # <--- 替换为实际标签JSON文件路径
input_size = (224, 224) # 标准输入尺寸,根据模型调整
# 检查图像文件是否存在(可选但推荐)
if os.path.exists(single_image_path):
# 运行预测并显示结果
display_single_prediction(single_image_path, model_path, labels_path, input_size)
else:
print(f"错误: 指定的图像文件不存在: {single_image_path}")
限制
- 特定任务微调:虽然该模型擅长NSFW图像分类,但在应用于其他任务时性能可能会有所不同。
- 有兴趣将该模型用于其他任务的用户应探索模型中心提供的微调版本以获得最佳结果。
训练数据
模型的训练数据包括一个包含约80,000张图像的专有数据集。该数据集具有高度的可变性,包含两个不同的类别:"normal"和"nsfw"。在这一数据上的训练过程旨在使模型能够有效区分安全和露骨内容。
训练统计
- 'eval_loss': 0.07463177293539047,
- 'eval_accuracy': 0.980375,
- 'eval_runtime': 304.9846,
- 'eval_samples_per_second': 52.462,
- 'eval_steps_per_second': 3.279
注意:在使用该模型时,必须负责任且符合道德规范,特别是在涉及潜在敏感内容的实际应用中,应遵守内容指南和相关法规。
有关模型微调和使用的更多详情,请参阅模型的文档和模型中心。
参考文献
免责声明:模型的性能可能受其微调数据的质量和代表性的影响。鼓励用户评估模型对其特定应用和数据集的适用性。