aaraki/vit-base-patch16-224-in21k-finetuned-cifar10
Introduction
This model is a fine-tuned version of google/vit-base-patch16-224-in21k on the CIFAR-10 dataset. It is intended for image classification and achieves a loss of 0.2564 and an accuracy of 0.9788 on the evaluation set.
Architecture
The model is based on the Vision Transformer (ViT) architecture, which applies a standard transformer encoder to a sequence of fixed-size image patches for classification.
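The key dimensions of the ViT-Base/16 backbone can be read off the model configuration. Below is a minimal sketch; the commented values are the standard ViT-Base settings rather than values verified from this checkpoint:

from transformers import ViTConfig

config = ViTConfig.from_pretrained("aaraki/vit-base-patch16-224-in21k-finetuned-cifar10")

print(config.image_size)           # 224 -- input resolution
print(config.patch_size)           # 16  -- each image becomes a sequence of 16x16 patches
print(config.hidden_size)          # 768 -- transformer embedding dimension (ViT-Base)
print(config.num_hidden_layers)    # 12  -- encoder blocks
print(config.num_attention_heads)  # 12  -- attention heads per block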
Training
The training procedure used the following hyperparameters (a TrainingArguments sketch follows the results below):
- Learning Rate: 5e-05
- Training Batch Size: 32
- Evaluation Batch Size: 32
- Seed: 42
- Gradient Accumulation Steps: 4
- Total Training Batch Size: 128
- Optimizer: Adam with betas=(0.9, 0.999) and epsilon=1e-08
- LR Scheduler Type: Linear
- LR Scheduler Warmup Ratio: 0.1
- Number of Epochs: 1
Training reached a training loss of 0.4291 at epoch 1 (step 390), with a validation loss of 0.2564 and an accuracy of 0.9788.
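For orientation, these hyperparameters correspond roughly to the following Hugging Face TrainingArguments configuration. This is a minimal sketch, not the original training script; the output directory name and any arguments not listed above are assumptions:

from transformers import TrainingArguments

# Hypothetical output directory; the remaining values mirror the hyperparameter list above.
# The Adam betas/epsilon listed above match the Trainer's optimizer defaults.
training_args = TrainingArguments(
    output_dir="vit-base-patch16-224-in21k-finetuned-cifar10",
    learning_rate=5e-5,
    per_device_train_batch_size=32,
    per_device_eval_batch_size=32,
    seed=42,
    gradient_accumulation_steps=4,   # 32 x 4 = 128 effective training batch size
    lr_scheduler_type="linear",
    warmup_ratio=0.1,
    num_train_epochs=1,
)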
The model was trained using the following framework versions:
- Transformers: 4.17.0
- PyTorch: 1.10.0+cu111
- Datasets: 2.0.0
- Tokenizers: 0.11.6
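To check whether a local environment matches these versions, each library exposes a __version__ attribute (a small sketch; the version strings on your machine may differ):

import transformers, torch, datasets, tokenizers

# Compare against the versions listed above
print("transformers:", transformers.__version__)
print("torch:", torch.__version__)
print("datasets:", datasets.__version__)
print("tokenizers:", tokenizers.__version__)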
Guide: Running Locally
To run this model locally, follow these basic steps:
- Install the required libraries (the CUDA-specific PyTorch wheel is served from the PyTorch download index rather than PyPI):
pip install transformers==4.17.0 datasets==2.0.0 tokenizers==0.11.6
pip install torch==1.10.0+cu111 -f https://download.pytorch.org/whl/torch_stable.html
- Load the model and use it for inference (sketches for label names and GPU use follow these steps):
from transformers import ViTForImageClassification, ViTFeatureExtractor
from PIL import Image
import requests

# Load model and feature extractor
model = ViTForImageClassification.from_pretrained("aaraki/vit-base-patch16-224-in21k-finetuned-cifar10")
feature_extractor = ViTFeatureExtractor.from_pretrained("aaraki/vit-base-patch16-224-in21k-finetuned-cifar10")

# Load and preprocess image
url = "https://example.com/path/to/your/image.jpg"
image = Image.open(requests.get(url, stream=True).raw)
inputs = feature_extractor(images=image, return_tensors="pt")

# Perform inference
outputs = model(**inputs)
logits = outputs.logits
predicted_class_idx = logits.argmax(-1).item()
print("Predicted class index:", predicted_class_idx)
- Optionally, use cloud GPUs such as those provided by AWS, Google Cloud, or Azure for faster inference if you are running on large datasets or need real-time performance.
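If you want a human-readable class name rather than an index, the standard CIFAR-10 label order can be applied to the predicted index. This is a minimal sketch continuing from the inference code in step 2; whether this checkpoint's config also ships a usable id2label mapping is not confirmed here:

# CIFAR-10 class names in the dataset's canonical order
cifar10_labels = ["airplane", "automobile", "bird", "cat", "deer",
                  "dog", "frog", "horse", "ship", "truck"]
print("Predicted class:", cifar10_labels[predicted_class_idx])

When a GPU is available (locally or on a cloud instance as suggested above), inference can be moved onto it. A minimal sketch, reusing the model, feature_extractor, and image objects from step 2:

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device)

# Move the preprocessed tensors to the same device before the forward pass
inputs = feature_extractor(images=image, return_tensors="pt")
inputs = {k: v.to(device) for k, v in inputs.items()}

with torch.no_grad():
    outputs = model(**inputs)
predicted_class_idx = outputs.logits.argmax(-1).item()
print("Predicted class index:", predicted_class_idx)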
License
This model is licensed under the Apache-2.0 License.