TDD
RED-AIGC
Introduction
Target-Driven Distillation (TDD) introduces three innovative techniques that enhance training efficiency and flexibility compared to traditional consistency distillation methods:
- Target Timestep Selection: TDD employs a strategic selection of target timesteps from predefined equidistant denoising schedules (e.g., 4-8 steps), incorporating stochastic offsets for non-deterministic sampling.
- Decoupled Guidance: During training, TDD decouples the guidance scale from the distilled model, so the scale can still be tuned at inference time. It achieves this by replacing some text prompts with unconditional (empty) prompts, in line with standard CFG training.
- Flexible Sampling Options: TDD supports non-equidistant sampling and x0 clipping, providing a more flexible and accurate image sampling method.
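The first technique can be sketched in a few lines of plain Python. This is a minimal illustration, not the authors' implementation: the function names, the 1000-step training range, the random choice among candidates below the current timestep, and the multiplicative form of the offset controlled by `eta` are all assumptions made for the example.

```python
import random


def equidistant_schedule(num_train_timesteps, num_steps):
    """Descending timesteps of an equidistant num_steps-step denoising schedule."""
    stride = num_train_timesteps // num_steps
    return [num_train_timesteps - 1 - i * stride for i in range(num_steps)]


def candidate_targets(num_train_timesteps=1000, min_steps=4, max_steps=8):
    """Union of the equidistant schedules for min_steps..max_steps, plus t=0."""
    targets = {0}
    for k in range(min_steps, max_steps + 1):
        targets.update(equidistant_schedule(num_train_timesteps, k))
    return sorted(targets)


def select_target(t, candidates, eta=0.3, rng=random):
    """Pick a target timestep below t, then shift it by a stochastic offset."""
    below = [c for c in candidates if c < t]
    if not below:
        return 0
    tau = rng.choice(below)
    # Assumed offset form: shrink tau by a random fraction drawn from [0, eta].
    return int(tau * (1.0 - rng.uniform(0.0, eta)))
```

Restricting targets to timesteps that a 4-8 step sampler will actually visit is what concentrates training on the few-step regime; the offset then widens each target into a small interval so that stochastic samplers are covered as well.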
Architecture
The TDD architecture involves a training process that incorporates target timestep selection and decoupled guidance. During inference, it can optionally utilize non-equidistant denoising schedules, enhancing image complexity and clarity compared to other methods.
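To make the inference-time options concrete, here is a small sketch of a non-equidistant schedule and of x0 clipping. Both helpers are hypothetical: the power-law spacing rule, the `power` parameter, and the clipping bound of 1.0 are assumptions chosen for illustration, not values taken from TDD.

```python
def non_equidistant_schedule(num_train_timesteps=1000, num_steps=4, power=2.0):
    """Descending timesteps whose gaps shrink toward t=0 when power > 1.

    The power-law spacing is an illustrative choice, not TDD's rule.
    """
    top = num_train_timesteps - 1
    return [round(top * ((num_steps - i) / num_steps) ** power) for i in range(num_steps)]


def clip_x0(x0_pred, bound=1.0):
    """Clamp each predicted clean-sample value into [-bound, bound]."""
    return [max(-bound, min(bound, v)) for v in x0_pred]


schedule = non_equidistant_schedule()      # [999, 562, 250, 62]
clipped = clip_x0([-1.5, 0.2, 3.0])        # [-1.0, 0.2, 1.0]
```

With `power=2.0` the gaps between consecutive timesteps are 437, 312, and 188, so more of the step budget is spent at low noise levels.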
Training
TDD's training strategy focuses on selecting effective target timesteps and utilizing decoupled guidance to improve efficiency and allow for flexible post-training adjustments. This approach results in higher-quality image generation, surpassing other consistency distillation methods in both complexity and clarity.
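The prompt replacement behind decoupled guidance can be sketched as follows. The helper names and the 10% drop probability are assumptions for illustration; `cfg_combine` is the standard classifier-free guidance formula, shown here on plain lists rather than tensors.

```python
import random


def drop_prompts(prompts, drop_prob=0.1, rng=random):
    """Replace each prompt with the empty string with probability drop_prob."""
    return ["" if rng.random() < drop_prob else p for p in prompts]


def cfg_combine(eps_uncond, eps_cond, guidance_scale):
    """Standard classifier-free guidance, applied elementwise."""
    return [u + guidance_scale * (c - u) for u, c in zip(eps_uncond, eps_cond)]
```

Because the model sees empty prompts during training, it keeps a usable unconditional prediction, which is what lets `guidance_scale` be varied at inference instead of being fixed by the distillation.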
Guide: Running Locally
To run the TDD model locally, follow these steps:
- Setup Environment: Ensure you have Python installed and set up a virtual environment.
- Install Dependencies: Run
    pip install opencv-python transformers accelerate torch diffusers
- Download Model Weights: Use the hf_hub_download function from huggingface_hub to download the necessary model weights.
- Load and Run the Model:
- For FLUX:
    import torch
    from huggingface_hub import hf_hub_download
    from diffusers import FluxPipeline

    # Load the base FLUX.1-dev pipeline in bfloat16.
    pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
    # Download the TDD LoRA weights from the Hub and fuse them into the base model.
    pipe.load_lora_weights(hf_hub_download("RED-AIGC/TDD", "TDD-FLUX.1-dev-lora-beta.safetensors"))
    pipe.fuse_lora(lora_scale=0.125)
    pipe.to("cuda")

    prompt = "A photo of a cat made of water."  # example prompt; replace with your own
    image_flux = pipe(
        prompt=[prompt],
        generator=torch.Generator().manual_seed(3413),
        num_inference_steps=8,
        guidance_scale=2.0,
        height=1024,
        width=1024,
        max_sequence_length=256,
    ).images[0]
- For SDXL:
    import torch
    from diffusers import StableDiffusionXLPipeline
    # tdd_scheduler is a local module, not part of diffusers; it must be on the import path.
    from tdd_scheduler import TDDScheduler

    device = "cuda"
    # Local path to the downloaded TDD LoRA weights for SDXL.
    tdd_lora_path = "tdd_lora/sdxl_tdd_lora_weights.safetensors"

    pipe = StableDiffusionXLPipeline.from_pretrained(
        "stabilityai/stable-diffusion-xl-base-1.0",
        torch_dtype=torch.float16,
        variant="fp16",
    ).to(device)

    # Swap in the TDD scheduler (the class name must match the import above).
    pipe.scheduler = TDDScheduler.from_config(pipe.scheduler.config)
    pipe.load_lora_weights(tdd_lora_path, adapter_name="accelerate")
    pipe.fuse_lora()

    prompt = "A photo of a cat made of water."
    image = pipe(
        prompt=prompt,
        num_inference_steps=4,
        guidance_scale=1.7,
        eta=0.2,
        generator=torch.Generator(device=device).manual_seed(546237),
    ).images[0]
    image.save("tdd.png")
Both examples assume a CUDA-capable GPU. If you do not have one locally, consider a cloud GPU service such as AWS, Google Cloud, or Azure.
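Before loading either pipeline, it can help to confirm that the packages from the install step are importable. The helper below is a hypothetical convenience, not part of TDD; it uses only the standard library and checks import names (`opencv-python` is imported as `cv2`).

```python
import importlib.util

# Import names for the packages listed in the install step, plus huggingface_hub.
REQUIRED = ["cv2", "transformers", "accelerate", "torch", "diffusers", "huggingface_hub"]


def missing_packages(modules=REQUIRED):
    """Return the top-level module names that cannot be found in this environment."""
    return [name for name in modules if importlib.util.find_spec(name) is None]
```

Calling `missing_packages()` returns an empty list when everything is installed; any names it does return can be installed with pip before running the examples.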
License
The TDD model is licensed under the Apache-2.0 License.