nousr/conditioned-prior
Introduction
A text-conditioned diffusion prior translates embeddings from one space into another. This is particularly useful with models like CLIP, whose image and text embeddings occupy distinct (though aligned) spaces: given a text embedding, the prior predicts the corresponding image embedding, which a decoder can then turn into pixels. Bridging the two spaces this way makes generating images from text more effective than conditioning a decoder on the text embedding alone.
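To make that bridge concrete, the sketch below compares the image embedding the prior predicts for a prompt with the embedding CLIP assigns to a real image. It assumes a trained prior loaded with the `load_diffusion_model` helper from the guide further down; the prompt, image file, and checkpoint path are placeholders.

```python
import clip
import torch
import torch.nn.functional as F
from PIL import Image

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Trained prior, loaded with the helper defined in the guide below (path is a placeholder).
prior = load_diffusion_model("<path/to/prior.pth>")

# CLIP's own encoders (ViT-L/14, matching the prior).
clip_model, preprocess = clip.load("ViT-L/14", device=device)

# Translate a prompt into the *image* embedding space via the prior.
tokens = clip.tokenize(["a photograph of a corgi"]).to(device)
predicted_image_embed = prior.sample(tokens, cond_scale=1.0)  # shape (1, 768)

# Embedding of an actual image for comparison (image path is a placeholder).
with torch.no_grad():
    image = preprocess(Image.open("corgi.jpg")).unsqueeze(0).to(device)
    true_image_embed = clip_model.encode_image(image)         # shape (1, 768)

# The prior's prediction should land close to the real image's embedding.
print(F.cosine_similarity(predicted_image_embed.float(), true_image_embed.float()).item())
```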
Architecture
The model uses a DiffusionPriorNetwork with a dimension of 768, a depth of 24, 32 attention heads of dimension 64, and 5% dropout on both the attention and feed-forward layers. It is wrapped in a DiffusionPrior that integrates an OpenAI CLIP adapter (ViT-L/14), is trained over 1000 timesteps with an l2 loss, and conditions on the CLIP text encodings in addition to the pooled text embedding; the full configuration appears in the loading example below.
Training
Overview
Training the prior uses precomputed embeddings to improve efficiency. Preparing a dataset in the format expected by the EmbeddingReader is the crucial first step. Once the dataset is ready, training can proceed with the Trainer base class.
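As a rough sketch of that workflow (this is not the repository's official training script; the checkpoint path, folder names, caption column, batch size, and keyword arguments are assumptions to check against your versions of dalle2_pytorch and embedding-reader), precomputed image embeddings and their captions can be streamed and fed to the trainer like this:

```python
import clip
import torch
from embedding_reader import EmbeddingReader

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Build the trainer (and resume from a checkpoint) with the helper from the guide below.
trainer = load_diffusion_model("<path/to/prior.pth>")

# Stream precomputed CLIP image embeddings together with their captions.
# The folder names and the "caption" metadata column are assumptions about
# how the embeddings were produced (see the Dataset section below).
reader = EmbeddingReader(
    embeddings_folder="embeddings/img_emb",
    metadata_folder="embeddings/metadata",
    meta_columns=["caption"],
    file_format="parquet_npy",
)

batch_size = 512
for image_np, meta in reader(batch_size=batch_size, start=0, end=reader.count):
    image_embed = torch.from_numpy(image_np).float().to(device)

    # Tokenize captions on the fly so the prior can derive both the text
    # embedding and the per-token text encodings it conditions on.
    text_tokens = clip.tokenize(meta["caption"].tolist(), truncate=True).to(device)

    loss = trainer(text=text_tokens, image_embed=image_embed)
    trainer.update()  # optimizer step, gradient clipping, EMA update
    print(loss)
```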
Dataset
Precomputed embeddings are recommended for efficiency. Tools like img2dataset and clip_retrieval can be used to download image-text pairs and compute their CLIP embeddings, which the dataloader then consumes when training the prior.
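One possible pipeline for producing those embeddings is sketched below; the URL list, folder names, and CLI flags are illustrative assumptions, so consult the img2dataset and clip-retrieval documentation for the options that fit your data.

```python
import subprocess
from img2dataset import download

# 1. Download images and their captions into webdataset shards
#    (the parquet file and column names are placeholders).
download(
    url_list="laion_subset.parquet",
    input_format="parquet",
    url_col="url",
    caption_col="caption",
    output_folder="dataset",
    output_format="webdataset",
    processes_count=8,
    thread_count=32,
)

# 2. Compute CLIP ViT-L/14 image and text embeddings with clip-retrieval's inference CLI.
subprocess.run(
    [
        "clip-retrieval", "inference",
        "--input_dataset", "dataset/{00000..00099}.tar",
        "--input_format", "webdataset",
        "--clip_model", "ViT-L/14",
        "--enable_metadata", "True",
        "--output_folder", "embeddings",
    ],
    check=True,
)
# This typically produces img_emb/, text_emb/, and metadata/ subfolders under
# "embeddings", which is what EmbeddingReader consumes in the training sketch above.
```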
Guide: Running Locally
- Installation: Ensure you have PyTorch and the required dependencies installed, then install the package with `pip install dalle2-pytorch` or clone the DALLE2-pytorch repository.
- Loading Checkpoints:

```python
import torch
from dalle2_pytorch import DiffusionPrior, DiffusionPriorNetwork, OpenAIClipAdapter
from dalle2_pytorch.trainer import DiffusionPriorTrainer

def load_diffusion_model(dprior_path):
    # network configuration matching the architecture described above
    prior_network = DiffusionPriorNetwork(
        dim=768,
        depth=24,
        dim_head=64,
        heads=32,
        normformer=True,
        attn_dropout=5e-2,
        ff_dropout=5e-2,
        num_time_embeds=1,
        num_image_embeds=1,
        num_text_embeds=1,
        num_timesteps=1000,
        ff_mult=4,
    )

    # wrap the network in a diffusion prior that conditions on CLIP ViT-L/14 text encodings
    diffusion_prior = DiffusionPrior(
        net=prior_network,
        clip=OpenAIClipAdapter("ViT-L/14"),
        image_embed_dim=768,
        timesteps=1000,
        cond_drop_prob=0.1,
        loss_type="l2",
        condition_on_text_encodings=True,
    )

    # run on GPU when available
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    trainer = DiffusionPriorTrainer(
        diffusion_prior=diffusion_prior,
        lr=1.1e-4,
        wd=6.02e-2,
        max_grad_norm=0.5,
        amp=False,
        group_wd_params=True,
        use_ema=True,
        device=device,
        accelerator=None,
    )

    # restore the weights (and EMA state) from the checkpoint
    trainer.load(dprior_path)

    return trainer
```
- Sampling:

```python
import clip
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# load the trained prior with the helper defined above (the path is a placeholder)
prior = load_diffusion_model("<path/to/prior.pth>")

tokenized_text = clip.tokenize("<your amazing prompt>").to(device)

# sample a CLIP image embedding conditioned on the prompt; several candidates
# are drawn per prompt and the best one is kept, as judged by CLIP similarity
predicted_embedding = prior.sample(tokenized_text, num_samples_per_batch=2, cond_scale=1.0)
```
- Cloud GPUs: For efficient execution, consider using cloud-based GPUs such as those offered by AWS, GCP, or Azure.
License
This project is licensed under the MIT License.