vit-base-nsfw-detector

AdamCodd

Introduction

vit-base-nsfw-detector is a fine-tuned version of the Vision Transformer (ViT) that classifies images as safe for work (SFW) or not safe for work (NSFW). Fine-tuned on a dataset of around 25,000 images, it reaches about 96.5% accuracy on the evaluation set (see Training Results below).

Architecture

The model is based on the Vision Transformer (ViT) architecture: a BERT-like transformer encoder applied to image patches, pretrained on ImageNet-21k and fine-tuned on ImageNet-1k. This checkpoint operates at a resolution of 384x384 pixels.
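
For a quick sanity check of these details, the checkpoint's processor and config can be inspected directly. The sketch below is illustrative; the commented values reflect the 384x384 resolution stated above and the standard ViT-Base encoder shape, not anything verified here.

from transformers import ViTImageProcessor, AutoModelForImageClassification

# Illustrative sketch: read the input resolution and encoder shape
# from the published checkpoint's configuration.
processor = ViTImageProcessor.from_pretrained("AdamCodd/vit-base-nsfw-detector")
model = AutoModelForImageClassification.from_pretrained("AdamCodd/vit-base-nsfw-detector")

print(processor.size)                  # expected: {"height": 384, "width": 384}
print(model.config.image_size)         # expected: 384
print(model.config.hidden_size)        # 768 for a ViT-Base encoder
print(model.config.num_hidden_layers)  # 12 transformer encoder blocks
print(model.config.id2label)           # class labels (e.g. sfw / nsfw)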

Training

Training and Evaluation Data

The model was trained on a diverse set of images, including realistic images, 3D renders, and drawings, so that it classifies reliably across image styles.

Training Procedure

  • Learning Rate: 3e-05
  • Train Batch Size: 32
  • Eval Batch Size: 32
  • Seed: 42
  • Optimizer: Adam with betas=(0.9, 0.999) and epsilon=1e-08
  • Number of Epochs: 1
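
For reference, here is a minimal sketch of how these hyperparameters would map onto Hugging Face TrainingArguments. This is illustrative only: the output directory is a placeholder, and it is not the author's actual training script.

from transformers import TrainingArguments

# Illustrative mapping of the reported hyperparameters onto TrainingArguments.
# output_dir is a placeholder; dataset and Trainer wiring are omitted.
training_args = TrainingArguments(
    output_dir="vit-base-nsfw-detector-finetune",  # placeholder
    learning_rate=3e-5,
    per_device_train_batch_size=32,
    per_device_eval_batch_size=32,
    num_train_epochs=1,
    seed=42,
    adam_beta1=0.9,
    adam_beta2=0.999,
    adam_epsilon=1e-8,
)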

Training Results

The model achieved the following results on the evaluation set:

  • Loss: 0.0937
  • Accuracy: 0.9654
  • AUC: 0.9948

Guide: Running Locally

Local Image Classification

from transformers import pipeline
from PIL import Image

img = Image.open("<path_to_image_file>")
predict = pipeline("image-classification", model="AdamCodd/vit-base-nsfw-detector")
predict(img)
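
The pipeline call returns a list of label/score dictionaries sorted by confidence, typically one entry per class, e.g. [{'label': 'sfw', 'score': 0.98}, {'label': 'nsfw', 'score': 0.02}] (the scores here are illustrative; the label names come from the checkpoint's id2label mapping).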

Remote Image Classification (from a URL)

from transformers import ViTImageProcessor, AutoModelForImageClassification
from PIL import Image
import requests

# Fetch an image over HTTP and open it with PIL.
url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
image = Image.open(requests.get(url, stream=True).raw)

# Load the image processor (resizing/normalization) and the fine-tuned model.
processor = ViTImageProcessor.from_pretrained('AdamCodd/vit-base-nsfw-detector')
model = AutoModelForImageClassification.from_pretrained('AdamCodd/vit-base-nsfw-detector')

# Preprocess the image, run a forward pass, and read out the class logits.
inputs = processor(images=image, return_tensors="pt")
outputs = model(**inputs)
logits = outputs.logits

# The highest-scoring logit gives the predicted class.
predicted_class_idx = logits.argmax(-1).item()
print("Predicted class:", model.config.id2label[predicted_class_idx])

Using Transformers.js

import { pipeline, env } from 'https://cdn.jsdelivr.net/npm/@xenova/transformers@2.17.1';

env.allowLocalModels = false;

const classifier = await pipeline('image-classification', 'AdamCodd/vit-base-nsfw-detector');

async function classifyImage(url) {
  try {
    const response = await fetch(url);
    if (!response.ok) throw new Error('Failed to load image');

    const blob = await response.blob();
    const image = new Image();
    const imagePromise = new Promise((resolve, reject) => {
      image.onload = () => resolve(image);
      image.onerror = reject;
      image.src = URL.createObjectURL(blob);
    });

    const img = await imagePromise;
    const classificationResults = await classifier([img.src]);
    console.log('Predicted class: ', classificationResults[0].label);
  } catch (error) {
    console.error('Error classifying image:', error);
  }
}

classifyImage('https://example.com/path/to/image.jpg');
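
A note on the design: the classifier accepts image URLs, including the object URL created from the fetched blob above, so the intermediate Image element mainly ensures the blob decodes before classification. Passing the remote URL to classifier(...) directly should also work when the server's CORS policy allows it.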

Cloud GPUs

For enhanced performance and faster processing, consider using cloud GPU services such as AWS EC2, Google Cloud Platform, or Azure.

License

The model is licensed under the Apache License 2.0.
