ViT Base NSFW Detector
AdamCodd
Introduction
The VIT-BASE-NSFW-DETECTOR is a fine-tuned version of the Vision Transformer (ViT) designed to classify images as Safe for Work (SFW) or Not Safe for Work (NSFW). Fine-tuned on a dataset of around 25,000 images, it distinguishes reliably between the two categories, reaching 96.54% accuracy on its evaluation set.
Architecture
The model is based on the Vision Transformer (ViT) architecture: a BERT-like transformer encoder pretrained on ImageNet-21k and fine-tuned on ImageNet-1k. It operates at a resolution of 384x384 pixels.
Training
Training and Evaluation Data
The model was trained on a diverse set of images, including realistic, 3D, and drawings, to enhance its classification capabilities.
Training Procedure
- Learning Rate: 3e-05
- Train Batch Size: 32
- Eval Batch Size: 32
- Seed: 42
- Optimizer: Adam with betas=(0.9, 0.999) and epsilon=1e-08
- Number of Epochs: 1
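The hyperparameters above can be gathered into a configuration sketch. This is a hypothetical reconstruction, since the original training script is not published; the keys follow Hugging Face TrainingArguments naming, so they could be splatted into that class.

```python
# Hypothetical reconstruction of the reported hyperparameters; the original
# training script is not published, so treat this as a sketch. The keys
# follow Hugging Face TrainingArguments naming.
training_config = {
    "learning_rate": 3e-05,
    "per_device_train_batch_size": 32,
    "per_device_eval_batch_size": 32,
    "seed": 42,
    "adam_beta1": 0.9,     # Adam betas=(0.9, 0.999)
    "adam_beta2": 0.999,
    "adam_epsilon": 1e-08,
    "num_train_epochs": 1,
}

# e.g. TrainingArguments(output_dir="vit-nsfw-finetune", **training_config)
```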
Training Results
The model achieved the following results on the evaluation set:
- Loss: 0.0937
- Accuracy: 0.9654
- AUC: 0.9948
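For reference, the two metrics can be computed as below. The labels and scores in this toy example are invented for illustration and bear no relation to the model's actual evaluation data.

```python
# Toy illustration of accuracy and AUC; these labels/scores are invented.
y_true = [0, 0, 1, 1, 1, 0]               # 0 = sfw, 1 = nsfw
y_score = [0.1, 0.6, 0.9, 0.8, 0.3, 0.2]  # predicted NSFW probability

# Accuracy: fraction of correct predictions at a 0.5 threshold.
y_pred = [1 if s >= 0.5 else 0 for s in y_score]
accuracy = sum(p == t for p, t in zip(y_pred, y_true)) / len(y_true)

# AUC: probability that a random NSFW image scores above a random SFW image.
pos = [s for s, t in zip(y_score, y_true) if t == 1]
neg = [s for s, t in zip(y_score, y_true) if t == 0]
auc = sum(p > n for p in pos for n in neg) / (len(pos) * len(neg))

print(round(accuracy, 4), round(auc, 4))  # 0.6667 0.8889
```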
Guide: Running Locally
Local Image Classification
from transformers import pipeline
from PIL import Image

# Load the image to classify
img = Image.open("<path_to_image_file>")

# Build an image-classification pipeline with the fine-tuned model
predict = pipeline("image-classification", model="AdamCodd/vit-base-nsfw-detector")
predict(img)
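The pipeline returns a list of label/score dictionaries. A minimal post-processing sketch, assuming the model's labels are "sfw" and "nsfw"; the 0.8 threshold is an illustrative choice, not a value recommended by the model card:

```python
# Example pipeline output shape; the scores here are made up for illustration.
results = [{"label": "nsfw", "score": 0.97}, {"label": "sfw", "score": 0.03}]

# Flag the image if the NSFW score exceeds a chosen threshold (0.8 is an
# arbitrary example, not a recommendation from the model card).
nsfw_score = next(r["score"] for r in results if r["label"] == "nsfw")
is_nsfw = nsfw_score >= 0.8
print(is_nsfw)  # True
```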
Remote Image Classification
from transformers import ViTImageProcessor, AutoModelForImageClassification
from PIL import Image
import requests

# Fetch a remote image
url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
image = Image.open(requests.get(url, stream=True).raw)

# Preprocess the image and run the classifier
processor = ViTImageProcessor.from_pretrained('AdamCodd/vit-base-nsfw-detector')
model = AutoModelForImageClassification.from_pretrained('AdamCodd/vit-base-nsfw-detector')
inputs = processor(images=image, return_tensors="pt")
outputs = model(**inputs)
logits = outputs.logits

# The predicted class is the label with the highest logit
predicted_class_idx = logits.argmax(-1).item()
print("Predicted class:", model.config.id2label[predicted_class_idx])
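The raw logits can also be turned into probabilities with a softmax. A pure-Python sketch with illustrative logit values (real values come from outputs.logits), assuming the label order id2label = {0: "sfw", 1: "nsfw"}:

```python
import math

# Illustrative logits in [sfw, nsfw] order; real values come from outputs.logits.
logits = [-1.2, 2.3]

# Softmax: exponentiate, then normalize so the scores sum to 1.
exps = [math.exp(x) for x in logits]
probs = [e / sum(exps) for e in exps]
print({"sfw": round(probs[0], 4), "nsfw": round(probs[1], 4)})
```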
Using Transformers.js
import { pipeline, env } from 'https://cdn.jsdelivr.net/npm/@xenova/transformers@2.17.1';
env.allowLocalModels = false;
const classifier = await pipeline('image-classification', 'AdamCodd/vit-base-nsfw-detector');
async function classifyImage(url) {
    try {
        const response = await fetch(url);
        if (!response.ok) throw new Error('Failed to load image');
        const blob = await response.blob();
        const image = new Image();
        const imagePromise = new Promise((resolve, reject) => {
            image.onload = () => resolve(image);
            image.onerror = reject;
            image.src = URL.createObjectURL(blob);
        });
        const img = await imagePromise;
        // Pass a single URL (not an array) so the result is a flat
        // list of { label, score } objects.
        const classificationResults = await classifier(img.src);
        console.log('Predicted class:', classificationResults[0].label);
    } catch (error) {
        console.error('Error classifying image:', error);
    }
}
classifyImage('https://example.com/path/to/image.jpg');
Cloud GPUs
For enhanced performance and faster processing, consider using cloud GPU services such as AWS EC2, Google Cloud Platform, or Azure.
License
The model is licensed under the Apache License 2.0.