google/vit-base-patch16-224-in21k

Introduction
The Vision Transformer (ViT) is a transformer-based model designed for image recognition tasks. Pre-trained on the extensive ImageNet-21k dataset, ViT processes images as sequences of fixed-size patches, allowing it to learn meaningful image representations. The raw model can be used for image classification or feature extraction, and it can be fine-tuned for a variety of vision tasks.
Architecture
The ViT model uses a transformer encoder architecture similar to BERT. It divides each image into 16x16-pixel patches, which are linearly embedded and prepended with a [CLS] token used for classification. Absolute position embeddings are added before the sequence is fed into the transformer layers. The checkpoint does not include any fine-tuned heads, but it retains the pre-trained pooler, which can be reused for downstream tasks.
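To make the patch arithmetic concrete, the sketch below (plain PyTorch, with zero-initialized tensors standing in for the model's learned parameters) shows how a 224x224 image becomes (224/16)^2 = 196 patch embeddings, which with the prepended [CLS] token gives a sequence length of 197:

```python
import torch

# Illustrative sketch of ViT's patch embedding, not the library's internal code.
image = torch.randn(1, 3, 224, 224)      # one RGB image at the training resolution
patch_size, hidden_dim = 16, 768          # ViT-Base configuration

# Cut the image into non-overlapping 16x16 patches: (224 / 16)^2 = 196 patches.
patches = image.unfold(2, patch_size, patch_size).unfold(3, patch_size, patch_size)
patches = patches.permute(0, 2, 3, 1, 4, 5).reshape(1, 196, 3 * patch_size ** 2)

# Linearly project each flattened patch (16 * 16 * 3 = 768 values) to the hidden size.
projection = torch.nn.Linear(3 * patch_size ** 2, hidden_dim)
embeddings = projection(patches)          # shape: (1, 196, 768)

# Prepend the [CLS] token and add absolute position embeddings
# (both are learned parameters in the real model; zeros here for illustration).
cls_token = torch.zeros(1, 1, hidden_dim)
position_embeddings = torch.zeros(1, 197, hidden_dim)
sequence = torch.cat([cls_token, embeddings], dim=1) + position_embeddings
print(sequence.shape)                     # torch.Size([1, 197, 768])
```

This sequence is what the transformer encoder layers consume.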
Training
The ViT model is pre-trained on ImageNet-21k, a dataset of 14 million images spanning 21,843 classes. Preprocessing resizes images to 224x224 pixels and normalizes them across the RGB channels. Training was conducted on TPUv3 hardware with a batch size of 4096, a learning-rate warm-up of 10k steps, and gradient clipping at global norm 1.
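For reference, the preprocessing described above can be reproduced by hand with torchvision; this is a sketch assuming the per-channel mean and standard deviation of 0.5 commonly used with this checkpoint (the ViTImageProcessor used later applies the equivalent steps automatically):

```python
from torchvision import transforms

# Hand-rolled equivalent of the checkpoint's preprocessing (a sketch, not internals).
preprocess = transforms.Compose([
    transforms.Resize((224, 224)),              # resize to the training resolution
    transforms.ToTensor(),                      # HWC uint8 -> CHW float in [0, 1]
    transforms.Normalize(mean=[0.5, 0.5, 0.5],  # per-channel normalization,
                         std=[0.5, 0.5, 0.5]),  # mapping values to roughly [-1, 1]
])
```

Applying `preprocess(image).unsqueeze(0)` to a PIL image yields a tensor analogous to the `pixel_values` that ViTImageProcessor returns in the examples below.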
Guide: Running Locally
To run the ViT model locally, follow these steps:
- Install Required Libraries: Ensure you have the transformers library installed:

```bash
pip install transformers
```
- Load and Process Image: Use PyTorch or JAX/Flax to load and process images (a fine-tuning sketch follows this list).
- PyTorch Example:

```python
from transformers import ViTImageProcessor, ViTModel
from PIL import Image
import requests

# Fetch a sample image from the COCO validation set.
url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
image = Image.open(requests.get(url, stream=True).raw)

# Preprocess the image and run the pre-trained backbone.
processor = ViTImageProcessor.from_pretrained('google/vit-base-patch16-224-in21k')
model = ViTModel.from_pretrained('google/vit-base-patch16-224-in21k')

inputs = processor(images=image, return_tensors="pt")
outputs = model(**inputs)
last_hidden_states = outputs.last_hidden_state
```
- JAX/Flax Example:

```python
from transformers import ViTImageProcessor, FlaxViTModel
from PIL import Image
import requests

# Fetch the same sample image from the COCO validation set.
url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
image = Image.open(requests.get(url, stream=True).raw)

# Preprocess as NumPy arrays and run the Flax variant of the model.
processor = ViTImageProcessor.from_pretrained('google/vit-base-patch16-224-in21k')
model = FlaxViTModel.from_pretrained('google/vit-base-patch16-224-in21k')

inputs = processor(images=image, return_tensors="np")
outputs = model(**inputs)
last_hidden_states = outputs.last_hidden_state
```
- Cloud GPUs: For intensive tasks, consider using cloud GPU services such as AWS, Google Cloud, or Azure to speed up processing.
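Both examples return hidden states rather than class predictions; the first position of `last_hidden_states` corresponds to the [CLS] token and can serve as a pooled image feature. To fine-tune the checkpoint for classification, a minimal sketch follows (`num_labels=10` is a placeholder for your dataset's class count, not a property of the model):

```python
from transformers import ViTForImageClassification

# Load the pre-trained backbone and attach a new, randomly initialized
# classification head. num_labels=10 is a placeholder for your own class count.
model = ViTForImageClassification.from_pretrained(
    'google/vit-base-patch16-224-in21k',
    num_labels=10,
)

# Reusing `inputs` from the PyTorch example above:
# logits = model(**inputs).logits   # shape: (batch_size, num_labels)
```

From here, a standard PyTorch training loop (or the transformers Trainer) can fine-tune the head and backbone on labeled data.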
License
The ViT model is released under the Apache 2.0 license, allowing for both personal and commercial use.