vit-base-patch16-224-in21k-finetuned-cifar10

aaraki

Introduction

This model is a fine-tuned version of google/vit-base-patch16-224-in21k on the CIFAR-10 dataset, intended for image classification. On the evaluation set it reaches a loss of 0.2564 and an accuracy of 0.9788.

Architecture

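The model is based on the Vision Transformer (ViT) architecture, which splits an image into fixed-size patches and processes the resulting patch sequence with a transformer encoder. As the base checkpoint name indicates, this is the Base variant with 16×16-pixel patches and 224×224-pixel inputs, pre-trained on ImageNet-21k.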
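
To make the patch-based design concrete, the following sketch works out the token sequence implied by the checkpoint name (16×16 patches, 224×224 input). The function name is hypothetical, and a 3-channel RGB input with a single prepended [CLS] token is assumed, as in the standard ViT design.

```python
# Token arithmetic for a ViT-Base/16 model at 224x224 resolution.
IMAGE_SIZE = 224   # input height and width in pixels (from the model name)
PATCH_SIZE = 16    # patch height and width in pixels (from the model name)
CHANNELS = 3       # assumption: RGB input


def vit_sequence_shape(image_size, patch_size, channels):
    """Return (num_patches, sequence_length, flattened_patch_dim)."""
    if image_size % patch_size != 0:
        raise ValueError("image size must be divisible by patch size")
    patches_per_side = image_size // patch_size
    num_patches = patches_per_side ** 2
    sequence_length = num_patches + 1            # +1 for the [CLS] token
    patch_dim = patch_size * patch_size * channels
    return num_patches, sequence_length, patch_dim


num_patches, sequence_length, patch_dim = vit_sequence_shape(
    IMAGE_SIZE, PATCH_SIZE, CHANNELS
)
print(num_patches, sequence_length, patch_dim)  # 196 197 768
```

Each 224×224 image therefore becomes 196 patch tokens plus one classification token, and each flattened patch holds 768 values before the linear projection.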

Training

The training procedure utilized the following hyperparameters:

  • Learning Rate: 5e-05
  • Training Batch Size: 32
  • Evaluation Batch Size: 32
  • Seed: 42
  • Gradient Accumulation Steps: 4
  • Total Training Batch Size: 128
  • Optimizer: Adam with betas=(0.9, 0.999) and epsilon=1e-08
  • LR Scheduler Type: Linear
  • LR Scheduler Warmup Ratio: 0.1
  • Number of Epochs: 1
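
As a rough illustration of how these settings interact, the sketch below computes the total batch size per optimizer step and evaluates a linear schedule with 10% warmup. It is a plain-Python approximation, not the training library's code: the function name `linear_lr` is hypothetical, a single device is assumed, and the total of 390 optimizer steps is taken from the final step reported for this run.

```python
import math

LEARNING_RATE = 5e-5         # peak learning rate
PER_DEVICE_BATCH_SIZE = 32   # training batch size
GRAD_ACCUM_STEPS = 4         # gradient accumulation steps
WARMUP_RATIO = 0.1           # fraction of steps spent warming up
TOTAL_STEPS = 390            # assumption: final step reported for this run

# Total batch size per optimizer step, assuming a single device.
effective_batch_size = PER_DEVICE_BATCH_SIZE * GRAD_ACCUM_STEPS


def linear_lr(step, total_steps, peak_lr=LEARNING_RATE, warmup_ratio=WARMUP_RATIO):
    """Linear warmup from 0 to peak_lr, then linear decay back to 0."""
    warmup_steps = math.ceil(total_steps * warmup_ratio)
    if step < warmup_steps:
        return peak_lr * step / max(1, warmup_steps)
    remaining = total_steps - step
    return peak_lr * max(0.0, remaining / max(1, total_steps - warmup_steps))


print("effective batch size:", effective_batch_size)
for step in (0, 39, 195, 390):
    print(step, linear_lr(step, TOTAL_STEPS))
```

Under these assumptions the learning rate peaks at 5e-05 after 39 steps and decays to zero by step 390. The exact rounding of the warmup length in the actual training run may differ.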

After 1 epoch (step 390), the training loss was 0.4291, with a validation loss of 0.2564 and a validation accuracy of 0.9788.
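
The reported step count can be sanity-checked with a short calculation. The card does not state how many training images were used, so the sketch below assumes the standard 50,000-image CIFAR-10 training split and assumes that leftover micro-batches which do not fill a full accumulation cycle are dropped. The function name is hypothetical.

```python
import math


def optimizer_steps_per_epoch(num_examples, batch_size, grad_accum_steps):
    """Optimizer updates in one epoch under gradient accumulation."""
    # Number of micro-batches; the last, partially filled batch is kept.
    micro_batches = math.ceil(num_examples / batch_size)
    # Assumption: micro-batches that do not complete an accumulation
    # cycle are not counted as a step.
    return max(1, micro_batches // grad_accum_steps)


TRAIN_EXAMPLES = 50_000  # assumption: full CIFAR-10 training split
steps = optimizer_steps_per_epoch(TRAIN_EXAMPLES, batch_size=32, grad_accum_steps=4)
print(steps)  # 390
```

The result agrees with the reported step 390, which is consistent with one pass over the full training split, though it does not confirm which split was actually used.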

The model was trained using the following framework versions:

  • Transformers: 4.17.0
  • PyTorch: 1.10.0+cu111
  • Datasets: 2.0.0
  • Tokenizers: 0.11.6

Guide: Running Locally

To run this model locally, follow these basic steps:

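  1. Install the required libraries. The CUDA 11.1 build of PyTorch (+cu111) is not published on PyPI, so it is installed from the PyTorch wheel index; Pillow and requests are needed by the inference example:

    pip install transformers==4.17.0 datasets==2.0.0 tokenizers==0.11.6 pillow requests
    pip install torch==1.10.0+cu111 -f https://download.pytorch.org/whl/torch_stable.html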
    
  2. Load the model and use it for inference:

    import requests
    import torch
    from PIL import Image
    from transformers import ViTFeatureExtractor, ViTForImageClassification
    
    model_id = "aaraki/vit-base-patch16-224-in21k-finetuned-cifar10"
    
    # Load model and feature extractor
    model = ViTForImageClassification.from_pretrained(model_id)
    feature_extractor = ViTFeatureExtractor.from_pretrained(model_id)
    
    # Load and preprocess image (replace the placeholder URL with a real image)
    url = "https://example.com/path/to/your/image.jpg"
    response = requests.get(url, stream=True)
    response.raise_for_status()  # fail early on HTTP errors
    image = Image.open(response.raw).convert("RGB")  # the model expects 3 channels
    inputs = feature_extractor(images=image, return_tensors="pt")
    
    # Perform inference without tracking gradients
    with torch.no_grad():
        outputs = model(**inputs)
    logits = outputs.logits
    predicted_class_idx = logits.argmax(-1).item()
    
    print("Predicted class index:", predicted_class_idx)
    print("Predicted label:", model.config.id2label[predicted_class_idx])
    
  3. Optionally, run the model on a GPU, for example a cloud instance from AWS, Google Cloud, or Azure, when classifying large datasets or when low latency is required.
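
The inference example in step 2 yields raw logits and a class index. The sketch below shows one way to turn such logits into ranked class probabilities with a softmax. It uses plain Python and made-up logits so it runs without the model; the helper names are hypothetical, and the class list assumes the standard CIFAR-10 ordering, which should be checked against the model's own `id2label` mapping.

```python
import math

# Assumption: standard CIFAR-10 class order; verify against model.config.id2label.
CIFAR10_CLASSES = [
    "airplane", "automobile", "bird", "cat", "deer",
    "dog", "frog", "horse", "ship", "truck",
]


def softmax(logits):
    """Convert raw logits to probabilities (shifted for numerical stability)."""
    shift = max(logits)
    exps = [math.exp(x - shift) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def top_k(logits, k=3):
    """Return the k most probable (label, probability) pairs."""
    probs = softmax(logits)
    ranked = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    return [(CIFAR10_CLASSES[i], probs[i]) for i in ranked[:k]]


# Hypothetical logits for a single image, for illustration only.
example_logits = [0.1, -1.2, 0.3, 4.0, 0.0, 2.5, -0.5, 0.2, -2.0, -0.7]
for label, prob in top_k(example_logits):
    print(f"{label}: {prob:.3f}")
```

With the real model, `logits[0].tolist()` from the inference example could be passed to `top_k` in place of the hypothetical values.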

License

This model is licensed under the Apache-2.0 License.
