openai/imagegpt-large
Introduction
ImageGPT (large-sized model) is a transformer decoder model for image generation and feature extraction. It was pre-trained on the ImageNet-21k dataset and was introduced in the paper "Generative Pretraining from Pixels" by Chen et al. The model is trained to predict the next pixel value given the pixels that precede it, which lets it generate images or provide features for downstream tasks.
Architecture
ImageGPT uses a GPT-like transformer decoder that generates images by predicting pixel values one at a time. The model operates on 32x32-pixel images. Color-clustering converts each image into a sequence of 1,024 cluster values (one per pixel) instead of 3,072 separate RGB channel values, which shortens the input sequence for the transformer.
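The next-pixel objective described above can be written as the standard autoregressive factorization. The notation is introduced here for illustration only: x_i is assumed to be the cluster value of the i-th pixel in raster order, and n is the sequence length.

```latex
p(x) = \prod_{i=1}^{n} p\left(x_i \mid x_1, \ldots, x_{i-1}\right),
\qquad n = 32 \times 32 = 1024
```

Under this reading, generation amounts to sampling x_1, then x_2 given x_1, and so on until all n positions are filled.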
Training
Training Data
The model was pre-trained on the ImageNet-21k dataset, which consists of 14 million images across 21,843 classes.
Training Procedure
- Preprocessing: Images are resized to 32x32 pixels and normalized. Each pixel is converted to one of 512 possible cluster values through color-clustering.
- Pretraining: Detailed training procedures are discussed in section 3.4 of the paper "Generative Pretraining from Pixels."
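The color-clustering step above can be sketched as a nearest-centroid lookup. This is a minimal illustration, not the processor's actual implementation: the function name `color_quantize` is hypothetical, the centroids are random placeholders rather than the model's real 512 clusters, and scaling pixels to [-1, 1] is an assumption chosen to match the `127.5 * (clusters + 1.0)` decoding used in the guide's script.

```python
import numpy as np

def color_quantize(image, clusters):
    """Map each RGB pixel to the index of its nearest cluster centroid.

    image:    (H, W, 3) float array scaled to [-1, 1]
    clusters: (K, 3) float array of centroids in the same range
    returns:  (H * W,) int array of cluster indices in raster order
    """
    pixels = image.reshape(-1, 3)
    # Squared Euclidean distance between every pixel and every centroid.
    dists = ((pixels[:, None, :] - clusters[None, :, :]) ** 2).sum(axis=-1)
    return dists.argmin(axis=1)

rng = np.random.default_rng(0)
# Hypothetical centroids; the real ones ship with the pretrained processor.
clusters = rng.uniform(-1.0, 1.0, size=(512, 3))
raw = rng.integers(0, 256, size=(32, 32, 3))   # stand-in for a resized image
image = raw / 127.5 - 1.0                       # scale [0, 255] to [-1, 1]
tokens = color_quantize(image, clusters)        # one token per pixel
```

Each 32x32 image becomes a sequence of 1,024 integers, each below 512, which is the form of input the transformer consumes.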
Guide: Running Locally
To use ImageGPT for unconditional image generation, follow these steps:
- Install the necessary libraries:

      pip install torch transformers matplotlib numpy

- Set up the model in a Python script:

      from transformers import ImageGPTImageProcessor, ImageGPTForCausalImageModeling
      import torch
      import matplotlib.pyplot as plt
      import numpy as np

      processor = ImageGPTImageProcessor.from_pretrained('openai/imagegpt-large')
      model = ImageGPTForCausalImageModeling.from_pretrained('openai/imagegpt-large')
      device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
      model.to(device)

      # unconditional generation of 8 images
      batch_size = 8
      # initialize every sequence with the SOS token (the last vocabulary index)
      context = torch.full((batch_size, 1), model.config.vocab_size - 1)
      context = context.to(device)
      output = model.generate(input_ids=context, max_length=model.config.n_positions + 1,
                              temperature=1.0, do_sample=True, top_k=40)

      clusters = np.asarray(processor.clusters)
      # processor.size is an int in older transformers releases and a dict in newer ones
      size = processor.size
      n_px = size["height"] if isinstance(size, dict) else size

      # drop the SOS token, then map cluster indices back to RGB pixels
      samples = output[:, 1:].cpu().detach().numpy()
      samples_img = [np.reshape(np.rint(127.5 * (clusters[s] + 1.0)), [n_px, n_px, 3]).astype(np.uint8)
                     for s in samples]

      f, axes = plt.subplots(1, batch_size, dpi=300)
      for img, ax in zip(samples_img, axes):
          ax.axis('off')
          ax.imshow(img)
      plt.show()

- Execute the script to generate and visualize images.
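The last step of the script, turning generated cluster indices back into an image, can be tried in isolation without downloading the model. The sketch below assumes a tiny three-color palette in place of the real 512 clusters; `decode_tokens` is a hypothetical helper name, not part of the transformers library.

```python
import numpy as np

def decode_tokens(tokens, clusters, n_px):
    """Convert a sequence of cluster indices into an (n_px, n_px, 3) uint8 image."""
    # Look up each token's centroid, then map [-1, 1] back to [0, 255].
    rgb = np.rint(127.5 * (clusters[tokens] + 1.0))
    return rgb.reshape(n_px, n_px, 3).astype(np.uint8)

# Hypothetical palette: black, white, and mid-gray in the [-1, 1] range.
clusters = np.array([
    [-1.0, -1.0, -1.0],
    [1.0, 1.0, 1.0],
    [0.0, 0.0, 0.0],
])
tokens = np.array([0, 1, 2, 1])          # a 2x2 "image" in raster order
img = decode_tokens(tokens, clusters, n_px=2)
```

With the real model, `tokens` would be one row of `output[:, 1:]` and `clusters` would come from `processor.clusters`.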
Cloud GPUs
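Sampling is autoregressive, with each pixel generated in its own forward pass, so generation is slow on a CPU. For faster execution, consider running the script on a GPU instance from a cloud provider such as AWS EC2, Google Cloud Platform, or Azure.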
License
ImageGPT is released under the Apache-2.0 license.