许可协议:apache-2.0
任务类型:图像分类
标签:
- 安全检测器
- 敏感内容过滤器
评估指标:
- 准确率
库名称:timm
SafeSearch v3.1 已发布
Google SafeSearch Mini V2 是一款超高精度的多类图像分类器,可准确检测敏感内容
Google SafeSearch Mini V2 的训练方式与 V1 不同,它采用了 InceptionResNetV2 架构,并使用了约 3,400,000 张 从互联网随机采集的图像数据集,其中部分数据通过数据增强生成。训练和验证数据来自 Google 图片、Reddit、Kaggle 和 Imgur,并由公司、Google SafeSearch 和内容审核员分类为安全或敏感内容。
在通过交叉熵损失训练模型 5 个周期后,对训练集和验证集进行评估,筛选出预测概率低于 0.90 的图像,对精选数据集进行必要修正后,再额外训练了 8 个周期。接着,我在模型可能难以分类的多种案例上测试,发现它会将棕色猫的毛发误判为人类皮肤。为提高准确率,我用 Kaggle 的 15 个额外数据集 对模型进行了 1 个周期的微调,最后用训练集和测试集的组合数据完成了最终周期的训练。最终在训练和验证数据上均达到了 97% 的准确率。
安全搜索过滤器不仅是社交媒体内容审核的强大工具,还可用于过滤数据集。与稳定扩散安全检测器相比,该模型具有显著优势——用户可节省 1.0 GB 的内存和磁盘空间。
PyTorch
pip install --upgrade torchvision
import torch, os
from torchvision import transforms
from PIL import Image
import urllib.request
import timm
image_path = "https://www.allaboutcats.ca/wp-content/uploads/sites/235/2022/03/shutterstock_320462102-2500-e1647917149997.jpg"
device = "cuda"
def preprocess_image(image_path):
transform = transforms.Compose([
transforms.Resize(299),
transforms.CenterCrop(299),
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
if image_path.startswith('http://') or image_path.startswith('https://'):
import requests
from io import BytesIO
response = requests.get(image_path)
img = Image.open(BytesIO(response.content)).convert('RGB')
else:
img = Image.open(image_path).convert('RGB')
img = transform(img).unsqueeze(0)
img = img.cuda() if device.lower() == "cuda" else img.cpu()
return img
def eval():
model = timm.create_model("hf_hub:FredZhang7/google-safesearch-mini-v2", pretrained=True)
model.to(device)
img = preprocess_image(image_path)
with torch.no_grad():
out = model(img)
_, predicted = torch.max(out.data, 1)
classes = {
0: 'nsfw_gore',
1: 'nsfw_suggestive',
2: 'safe'
}
print('\n\033[1;31m' + classes[predicted.item()] + '\033[0m' if predicted.item() != 2 else '\033[1;32m' + classes[predicted.item()] + '\033[0m\n')
if __name__ == '__main__':
eval()