conditioned prior

nousr

Introduction

A text-conditioned diffusion prior translates between two embedding spaces: given CLIP text embeddings (and the corresponding text encodings), it predicts matching CLIP image embeddings. This is particularly useful with models like CLIP, where image and text embeddings occupy distinct regions of the latent space; bridging them lets a decoder that expects image embeddings generate images from text more effectively.

Architecture

The model uses a diffusion prior network with a dimensionality of 768, a depth of 24, and 5% dropout on both the attention and feed-forward layers. It wraps an OpenAI CLIP adapter (ViT-L/14) and is trained over 1000 timesteps with the l2 loss type, conditioning on the full CLIP text encodings in addition to the text embedding.
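
For reference, the 768-dimensional width matches the embedding size of OpenAI's CLIP ViT-L/14. The snippet below is a quick illustrative check using the openai clip package (not part of this repository); loading on the CPU keeps it runnable anywhere.

    import clip
    import torch
    
    # ViT-L/14 produces 768-dimensional embeddings, matching dim=768 and image_embed_dim=768
    model, preprocess = clip.load("ViT-L/14", device="cpu")
    with torch.no_grad():
        text_embedding = model.encode_text(clip.tokenize(["a photo of a dog"]))
    print(text_embedding.shape)  # torch.Size([1, 768])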

Training

Overview

Training the prior uses precomputed CLIP embeddings for efficiency. Preparing a dataset in the format expected by the EmbeddingReader is crucial. Once the dataset is ready, training can proceed with the DiffusionPriorTrainer class, as sketched below.
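
As a rough, self-contained sketch of a single optimization step with the DiffusionPriorTrainer (not the project's actual training script): the tiny network, the random stand-in embeddings, and the text_embed/image_embed keyword arguments are illustrative assumptions and may need adjusting to your dalle2_pytorch version.

    import torch
    from dalle2_pytorch import DiffusionPrior, DiffusionPriorNetwork
    from dalle2_pytorch.trainer import DiffusionPriorTrainer
    
    # deliberately small network just for illustration; the real hyperparameters are in the Guide below
    prior_network = DiffusionPriorNetwork(dim=768, depth=2, dim_head=64, heads=4)
    
    diffusion_prior = DiffusionPrior(
        net=prior_network,
        image_embed_dim=768,
        timesteps=100,
        cond_drop_prob=0.1,
        condition_on_text_encodings=False,  # keep the toy example embedding-only
    )
    
    trainer = DiffusionPriorTrainer(diffusion_prior=diffusion_prior, lr=1e-4, use_ema=True)
    
    # stand-ins for one batch of precomputed CLIP text and image embeddings
    text_embeds = torch.randn(4, 768)
    image_embeds = torch.randn(4, 768)
    
    loss = trainer(text_embed=text_embeds, image_embed=image_embeds)
    trainer.update()  # optimizer step and EMA update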

Dataset

Precomputed embeddings are recommended for efficiency. Tools like img2dataset and clip_retrieval can generate CLIP embeddings for images and their captions; the dataloader then feeds these embeddings to the prior during training.
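
As a minimal sketch of reading such a dataset, the snippet below streams precomputed embeddings with the embedding-reader package. The folder paths are placeholders, and it assumes the image and text embeddings were written as .npy shards in matching order; the resulting batches are what you would hand to the trainer.

    import torch
    from embedding_reader import EmbeddingReader
    
    # placeholder folders of .npy shards produced by the clip_retrieval inference step
    image_reader = EmbeddingReader(embeddings_folder="data/img_emb", file_format="npy")
    text_reader = EmbeddingReader(embeddings_folder="data/text_emb", file_format="npy")
    
    batch_size = 256
    for (img_emb, _), (txt_emb, _) in zip(
        image_reader(batch_size=batch_size, start=0, end=image_reader.count),
        text_reader(batch_size=batch_size, start=0, end=text_reader.count),
    ):
        # convert the numpy batches to tensors before handing them to the prior trainer
        image_embeds = torch.from_numpy(img_emb.copy())
        text_embeds = torch.from_numpy(txt_emb.copy())
        print(image_embeds.shape, text_embeds.shape)  # e.g. torch.Size([256, 768]) each for ViT-L/14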

Guide: Running Locally

  1. Installation: Install PyTorch and the dalle2-pytorch package (for example, pip install dalle2-pytorch), or clone the repository from DALLE2-pytorch and install its requirements.

  2. Loading Checkpoints:

    import torch
    from dalle2_pytorch import DiffusionPrior, DiffusionPriorNetwork, OpenAIClipAdapter
    from dalle2_pytorch.trainer import DiffusionPriorTrainer
    
    # pick the device up front; the trainer construction below expects it
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    
    def load_diffusion_model(dprior_path):
        # these hyperparameters must match the ones the checkpoint was trained with
        prior_network = DiffusionPriorNetwork(
            dim=768,
            depth=24,
            dim_head=64,
            heads=32,
            normformer=True,
            attn_dropout=5e-2,
            ff_dropout=5e-2,
            num_time_embeds=1,
            num_image_embeds=1,
            num_text_embeds=1,
            num_timesteps=1000,
            ff_mult=4
        )
    
        # pair the network with CLIP ViT-L/14; image_embed_dim matches its 768-d embeddings
        diffusion_prior = DiffusionPrior(
            net=prior_network,
            clip=OpenAIClipAdapter("ViT-L/14"),
            image_embed_dim=768,
            timesteps=1000,
            cond_drop_prob=0.1,
            loss_type="l2",
            condition_on_text_encodings=True,
        )
    
        # the trainer bundles the prior with its optimizer and EMA weights
        trainer = DiffusionPriorTrainer(
            diffusion_prior=diffusion_prior,
            lr=1.1e-4,
            wd=6.02e-2,
            max_grad_norm=0.5,
            amp=False,
            group_wd_params=True,
            use_ema=True,
            device=device,
            accelerator=None,
        )
    
        trainer.load(dprior_path)
        return trainer
    
  3. Sampling:

    import clip
    
    trainer = load_diffusion_model("<path/to/prior/checkpoint.pth>")  # helper from step 2; the path is a placeholder
    tokenized_text = clip.tokenize("<your amazing prompt>")
    predicted_embedding = trainer.sample(tokenized_text, n_samples_per_batch=2, cond_scale=1.0)
    
  4. Cloud GPUs: For efficient execution, consider using cloud-based GPUs like those offered by AWS, GCP, or Azure.

License

This project is licensed under the MIT License.
