指标:
- 准确率
任务标签: 图像分类
基础模型: google/vit-base-patch16-384
模型索引:
- 名称: AdamCodd/vit-base-nsfw-detector
结果:
- 任务:
类型: 图像分类
名称: 图像分类
指标:
- 类型: 准确率
值: 0.9654
名称: 准确率
- 类型: AUC
值: 0.9948
- 类型: 损失值
值: 0.0937
名称: 损失值
许可证: apache-2.0
标签:
- transformers.js
- transformers
- nlp
vit-base-nsfw-detector
该模型是基于约25,000张图像(绘画、照片等)对vit-base-patch16-384进行微调的版本。在评估集上取得了以下结果:
最新动态 [07/30]:我专门为稳定扩散使用创建了一个新的ViT模型来检测NSFW/SFW图像(原因见下方免责声明):AdamCodd/vit-nsfw-stable-diffusion。
免责声明:该模型并非为生成图像设计!所用数据集中不包含任何生成图像,且对生成图像的表现显著较差,这需要另一个专门针对生成图像训练的ViT模型。以下是该模型在生成图像上的实际评分供参考:
- 损失值: 0.3682 (↑ 292.95%)
- 准确率: 0.8600 (↓ 10.91%)
- F1分数: 0.8654
- AUC: 0.9376 (↓ 5.75%)
- 精确率: 0.8350
- 召回率: 0.8980
模型描述
Vision Transformer (ViT) 是一种类似BERT的Transformer编码器模型,通过监督学习在大量图像(ImageNet-21k,分辨率为224x224像素)上进行预训练。随后,模型在更高分辨率(384x384)的ImageNet(又称ILSVRC2012)数据集上进行了微调,该数据集包含100万张图像和1,000个类别。
用途与限制
模型分为两类:SFW(安全)和NSFW(不安全)。模型训练时较为保守,因此会将“性感”图像归类为NSFW。也就是说,如果图像展示过多皮肤或暴露,将被分类为NSFW。这是正常现象。
本地图像使用方法:
from transformers import pipeline
from PIL import Image
img = Image.open("<图像文件路径>")
predict = pipeline("image-classification", model="AdamCodd/vit-base-nsfw-detector")
predict(img)
远程图像使用方法:
from transformers import ViTImageProcessor, AutoModelForImageClassification
from PIL import Image
import requests
url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
image = Image.open(requests.get(url, stream=True).raw)
processor = ViTImageProcessor.from_pretrained('AdamCodd/vit-base-nsfw-detector')
model = AutoModelForImageClassification.from_pretrained('AdamCodd/vit-base-nsfw-detector')
inputs = processor(images=image, return_tensors="pt")
outputs = model(**inputs)
logits = outputs.logits
predicted_class_idx = logits.argmax(-1).item()
print("预测类别:", model.config.id2label[predicted_class_idx])
使用Transformers.js(原生JavaScript):
import { pipeline, env } from 'https://cdn.jsdelivr.net/npm/@xenova/transformers@2.17.1';
env.allowLocalModels = false;
const classifier = await pipeline('image-classification', 'AdamCodd/vit-base-nsfw-detector');
async function classifyImage(url) {
try {
const response = await fetch(url);
if (!response.ok) throw new Error('加载图像失败');
const blob = await response.blob();
const image = new Image();
const imagePromise = new Promise((resolve, reject) => {
image.onload = () => resolve(image);
image.onerror = reject;
image.src = URL.createObjectURL(blob);
});
const img = await imagePromise;
const classificationResults = await classifier([img.src]);
console.log('预测类别: ', classificationResults[0].label);
} catch (error) {
console.error('分类图像时出错:', error);
}
}
classifyImage('https://example.com/path/to/image.jpg');
该模型已在多种图像(写实、3D、绘画)上进行训练,但仍不完美,某些图像可能会被错误分类为NSFW。此外,请注意在transformers.js管道中使用量化ONNX模型会略微降低模型的准确性。您可以在此处找到使用Transformers.js的该模型简易实现。
训练与评估数据
需要更多信息
训练过程
训练超参数
训练过程中使用了以下超参数:
- 学习率: 3e-05
- 训练批次大小: 32
- 评估批次大小: 32
- 随机种子: 42
- 优化器: Adam,参数为betas=(0.9,0.999)和epsilon=1e-08
- 训练轮数: 1
训练结果
- 验证损失: 0.0937
- 准确率: 0.9654
- AUC: 0.9948
混淆矩阵(评估集):
[1076 37]
[ 60 1627]
框架版本
- Transformers 4.36.2
- Evaluate 0.4.1
如果您想支持我,可以点击这里。