
Text to Image Latent Diffusion using a Transformer core

Primary language: Python. License: MIT.

Transformer Latent Diffusion

Self Contained Text to Image Latent Diffusion using a Transformer core in PyTorch.

Try it with your own inputs: Open In Colab

Below are some random examples (at 256 resolution) from a 100MM model trained from scratch for 260k iterations (about 32 hours on 1 A100):


CLIP interpolation examples:

a photo of a cat → an anime drawing of a super saiyan cat, artstation:


a cute great gray owl → starry night by van gogh:


Note that the model has not converged yet and could use more training.

High(er) Resolution:

By upsampling the positional encoding, the model can also generate 512 or 1024 px images with minimal fine-tuning. I fine-tuned the model on an extra 100k 512 px images and 30k 1024 px images for about 2 hours on an A100. The images sometimes lack global coherence at 1024 px; more to come here.
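A minimal sketch of what upsampling a positional encoding can look like, assuming the denoiser keeps one positional vector per patch on a square grid (the repo's actual embedding type and interpolation mode may differ). Using the numbers from the Img+Text+Noise Encoding section (256 px images become 4*32*32 latents, patch size 2), a 256 px image yields a 16x16 grid of 256 tokens and a 512 px image a 32x32 grid of 1024 tokens. `upsample_pos_emb` is a hypothetical helper that bilinearly interpolates the grid (corners aligned) with NumPy:

```python
import numpy as np


def _lerp_axis(grid: np.ndarray, new_size: int, axis: int) -> np.ndarray:
    """Linearly resample `grid` along `axis` to `new_size` points (corners aligned)."""
    old_size = grid.shape[axis]
    src = np.linspace(0.0, old_size - 1, new_size)  # fractional source positions
    lo = np.floor(src).astype(int)
    hi = np.minimum(lo + 1, old_size - 1)
    w = src - lo  # weight given to the upper neighbour
    shape = [1] * grid.ndim
    shape[axis] = new_size
    w = w.reshape(shape)
    return np.take(grid, lo, axis=axis) * (1.0 - w) + np.take(grid, hi, axis=axis) * w


def upsample_pos_emb(pos_emb: np.ndarray, new_side: int) -> np.ndarray:
    """Hypothetical helper: (old_side**2, dim) -> (new_side**2, dim), bilinear."""
    n_tokens, dim = pos_emb.shape
    old_side = int(round(n_tokens ** 0.5))
    assert old_side * old_side == n_tokens, "expects a square token grid"
    grid = pos_emb.reshape(old_side, old_side, dim)  # (rows, cols, dim), row-major tokens
    grid = _lerp_axis(grid, new_side, axis=0)  # interpolate along rows
    grid = _lerp_axis(grid, new_side, axis=1)  # interpolate along columns
    return grid.reshape(new_side * new_side, dim)


# 256 px -> 16x16 token grid; 512 px -> 32x32 token grid.
pos_256 = np.random.default_rng(0).normal(size=(16 * 16, 768))
pos_512 = upsample_pos_emb(pos_256, new_side=32)
print(pos_512.shape)  # (1024, 768)
```

Interpolating (rather than re-initializing) keeps neighbouring tokens' positional vectors similar to what the model saw at 256 px, which is one plausible reason only a short fine-tune is needed.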


Intro:

The main goal of this repo is to build an accessible diffusion model in PyTorch that is:

  • fast (close to real time generation)
  • small (~100MM params)
  • reasonably good (of course not SOTA)
  • can be trained in a reasonable amount of time on a single GPU (under 50 hours on an A100 or equivalent)
  • has a simple, self-contained codebase (model + train loop is ~400 lines of PyTorch with few dependencies)
  • uses ~1 million images, with a focus on data quality over quantity, and includes code for downloading and processing the data


Codebase:

The code is written in pure PyTorch with as few dependencies as possible.

  • transformer_blocks.py: basic transformer building blocks used by the transformer denoiser.
  • denoiser.py: the architecture of the denoiser transformer.
  • train.py: the train loop. It uses accelerate, so training can scale to multiple GPUs if needed.
  • diffusion.py: class to generate images from noise using reverse diffusion. Short (~60 lines) and self-contained.
  • data.py: data utils to download images/text and compute the features the diffusion model needs.

Usage:

If you have your own dataset of URLs + captions, the process to train a model on the data consists of two steps:

  1. Use train.download_and_process_data to obtain the latent and text encodings as numpy files. See Open In Colab for a notebook example that downloads and processes 2000 images from this HuggingFace dataset.

  2. Use the train.main function in an accelerate notebook_launcher; see Open In Colab for a Colab notebook that trains a model on 100k images from scratch. Note that this notebook downloads already pre-processed latents and embeddings from here, but you could just as well use the .npy files you saved in step 1.

Install and Dependencies:

To install the package and its dependencies, run:

pip install git+https://github.com/apapiu/transformer_latent_diffusion.git

  • PyTorch, numpy, einops for model building
  • wandb, tqdm for logging + progress bars
  • accelerate for the train loop and multi-GPU support
  • img2dataset, webdataset, torchvision for data downloading and image processing
  • diffusers, clip for the pretrained VAE and CLIP text model

Basic Inference code:

from tld.configs import LTDConfig, DenoiserConfig
from tld.diffusion import DiffusionTransformer

denoiser_cfg = DenoiserConfig(n_channels=4)  # configure your model here
cfg = LTDConfig(denoiser_cfg=denoiser_cfg)

diffusion_transformer = DiffusionTransformer(cfg)

out = diffusion_transformer.generate_image_from_text(prompt="a cute cat")

Basic Training code:

from tld.train import main
from tld.configs import ModelConfig, DataConfig, TrainConfig

data_config = DataConfig(
    latent_path="latents.npy", text_emb_path="text_emb.npy", val_path="val_emb.npy"
)

model_cfg = ModelConfig(
    data_config=data_config,
    train_config=TrainConfig(n_epoch=100, save_model=False, compile=False, use_wandb=False),
)

main(model_cfg)

# Or, in a notebook, to run the training process on 2 GPUs:
# from accelerate import notebook_launcher
# notebook_launcher(main, (model_cfg,), num_processes=2)

Tests:

The tests in test_diffuser.py are a good place to start understanding the code. You can run all tests by running pytest -s.

GitHub Actions:

I have some GitHub Actions workflows configured to run tests, check linting, and build some Docker images. If you're just exploring the code, you can comment these out or delete the .github/workflows folder.

Configs:

Configs are in tld/configs.py in the form of dataclasses. The default values can always be overridden. For example, DenoiserConfig(n_layers=16) keeps all defaults except for n_layers. You can also save your configs as JSON and load them like so: DenoiserConfig(**json.load(file))
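The override and JSON round-trip pattern can be sketched with the standard library alone. `ToyDenoiserConfig` below is a hypothetical stand-in so the example runs without the package installed; the real DenoiserConfig's field names and defaults may differ (768 and 12 are taken from the Training section).

```python
import json
from dataclasses import asdict, dataclass, replace


@dataclass
class ToyDenoiserConfig:
    """Hypothetical stand-in for tld's DenoiserConfig; real fields/defaults may differ."""
    n_layers: int = 12
    embed_dim: int = 768
    n_channels: int = 4
    patch_size: int = 2


# Override one field; every other field keeps its default.
cfg = ToyDenoiserConfig(n_layers=16)

# Derive a variant from an existing config without mutating the original.
wide_cfg = replace(cfg, embed_dim=1024)

# Round-trip through JSON, mirroring DenoiserConfig(**json.load(file)).
serialized = json.dumps(asdict(cfg))
restored = ToyDenoiserConfig(**json.loads(serialized))
print(restored == cfg)  # True
```

Because dataclasses compare field by field, `restored == cfg` is a quick way to check that a saved config reloads unchanged.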

Speed:

I try to speed up training and inference as much as possible by:

  • using mixed precision for training, plus SDPA (scaled dot-product attention)
  • precomputing all latent and text embeddings
  • using float16 precision for inference
  • using SDPA for FlashAttention-2, plus torch.compile() on PyTorch 2.0+
  • using a highly performant sampler (DPM-Solver++(2M)) that gets good results in ~15 steps
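A minimal sketch of the SDPA call mentioned above, using PyTorch's `torch.nn.functional.scaled_dot_product_attention` (PyTorch 2.0+) and comparing it with the explicit softmax formula. The head split (12 heads of 64 dims for the 768-dim embedding) is an assumption for illustration. On a GPU with float16/bfloat16 inputs, for example under `torch.autocast`, PyTorch can dispatch this same call to a FlashAttention kernel; this CPU float32 example only checks that the math matches.

```python
import math

import torch
import torch.nn.functional as F

torch.manual_seed(0)
# Assumed shapes: 256 image tokens, 768-dim embedding split into 12 heads of 64 dims.
batch, heads, seq_len, head_dim = 2, 12, 256, 64
q, k, v = (torch.randn(batch, heads, seq_len, head_dim) for _ in range(3))

# Fused attention: PyTorch picks a backend (e.g. FlashAttention on supported GPUs
# with half-precision inputs, otherwise a fallback implementation).
out_sdpa = F.scaled_dot_product_attention(q, k, v)

# Reference: softmax(Q K^T / sqrt(d)) V computed explicitly.
scores = q @ k.transpose(-2, -1) / math.sqrt(head_dim)
out_ref = scores.softmax(dim=-1) @ v

print((out_sdpa - out_ref).abs().max().item())  # only float32 round-off
```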

The time to generate a batch of 36 images (15 iterations) on a:

  • T4: ~ 3.5 seconds
  • A100: ~0.6 seconds

In fact, on an A100 the VAE becomes the bottleneck even though it is only used once.


Examples:

More examples generated with the 100MM model are in the repository's image gallery; each image links to its prompt and other parameters such as the CFG scale and seed.

Outpainting model:

I also fine-tuned an outpainting model on top of the original 101MM model. I had to modify the original input Conv2d patch layer to take 8 channels and initialize the parameters of the new mask channels to zero. The rest of the architecture remained the same.
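The channel expansion with zero-initialized weights can be sketched in PyTorch. `old_patch` and `new_patch` are hypothetical stand-ins for the denoiser's patch-embedding conv; the sizes (768-dim embedding, patch size 2) come from the Training and Encoding sections, and treating the 4 extra channels as mask/masked-latent inputs is an assumption about how the outpainting conditioning is packed.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
embed_dim, patch = 768, 2  # sizes taken from the Training / Encoding sections

# Stand-in for the pretrained 4-channel patch-embedding conv.
old_patch = nn.Conv2d(4, embed_dim, kernel_size=patch, stride=patch)

# New conv taking 8 channels: 4 latent channels + 4 extra conditioning channels
# (hypothetically the mask / masked latent used for outpainting).
new_patch = nn.Conv2d(8, embed_dim, kernel_size=patch, stride=patch)
with torch.no_grad():
    new_patch.weight.zero_()                     # extra channels start at zero
    new_patch.weight[:, :4] = old_patch.weight   # copy pretrained latent weights
    new_patch.bias.copy_(old_patch.bias)

# At initialization the extra channels contribute nothing, so the new layer
# reproduces the pretrained layer and fine-tuning starts from the same function.
latent = torch.randn(1, 4, 32, 32)
extra = torch.randn(1, 4, 32, 32)
same = torch.allclose(new_patch(torch.cat([latent, extra], dim=1)), old_patch(latent), atol=1e-5)
print(same)
```

Zero-initializing only the new input channels means the model's outputs are unchanged at step 0, and gradients then learn how much to use the mask information.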

I applied the outpainting model repeatedly to generate a somewhat consistent scenery from the prompt "a cyberpunk marketplace".


Data Processing:

In data.py, I have some helper functions to process images and captions. The flow is as follows:

  • Use img2dataset to download images from a dataframe containing URLs and captions.
  • Use CLIP to encode the prompts and the VAE to encode images to latents on a webdataset data generator.
  • Save the latents and text embedding for future training.

There are two advantages to this approach. One is that the VAE encoding is somewhat expensive, so doing it every epoch would affect training times. The other is that we can discard the images after processing. For 3*256*256 images, the latent dimension is 4*32*32, so every latent is around 4KB (when quantized in uint8; see here). This means that 1 million latents will be "only" 4GB in size, which is easy to handle even in RAM. Storing the raw images would have been 48x larger in size.
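The storage figures above follow from simple arithmetic, sketched below. The quantization functions are a hypothetical min-max scheme (clip to [-5, 5], map onto 0..255) chosen only for illustration; the repo's actual uint8 scheme ("see here") may use a different range or scaling.

```python
import numpy as np

# Storage arithmetic at 1 byte per value (uint8).
image_bytes = 3 * 256 * 256  # raw RGB image: 196,608 bytes
latent_bytes = 4 * 32 * 32   # quantized latent: 4,096 bytes (~4 KB)
print(image_bytes // latent_bytes)     # 48
print(latent_bytes * 1_000_000 / 1e9)  # 4.096 (GB for 1 million latents)


def quantize(latents: np.ndarray, clip: float = 5.0) -> np.ndarray:
    """Hypothetical scheme: clip to [-clip, clip], then map linearly onto 0..255."""
    scaled = (np.clip(latents, -clip, clip) + clip) / (2 * clip) * 255
    return np.round(scaled).astype(np.uint8)


def dequantize(q: np.ndarray, clip: float = 5.0) -> np.ndarray:
    """Inverse map back to floats; in-range error is at most half a step (clip / 255)."""
    return q.astype(np.float32) / 255 * (2 * clip) - clip


latent = np.random.default_rng(0).uniform(-5, 5, size=(4, 32, 32)).astype(np.float32)
q = quantize(latent)
print(q.nbytes)  # 4096
```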

Architecture:

See here for the denoiser class.

The denoiser is a Transformer-based model built on the architecture in DiT and PixArt-Alpha, albeit with quite a few modifications and simplifications. Using a Transformer as the denoiser sets this model apart from most diffusion models, which use a CNN-based U-Net as the denoising backbone. I decided to use a Transformer for a few reasons. First, I wanted to experiment and learn how to build and train Transformers from the ground up. Second, Transformers are fast both to train and to run inference on, and they are likely to benefit most from future performance advances in both hardware and software.

Transformers are not natively built for spatial data, and at first I found many of the outputs to be very "patchy". To remedy that, I added a depth-wise convolution in the FFN layer of the transformer (introduced in the LocalViT paper). This lets the model mix neighboring tokens (patches) with very little added compute cost.
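A simplified sketch of a LocalViT-style feed-forward block in PyTorch: the token sequence is reshaped back into its 2D patch grid so a depth-wise 3x3 convolution (`groups=hidden`) can mix each token with its spatial neighbours. `LocalFFN`, the 4x hidden multiplier, the GELU placement, and the absence of normalization are assumptions; the repo's block may arrange these differently.

```python
import torch
import torch.nn as nn


class LocalFFN(nn.Module):
    """Sketch: Linear -> depthwise 3x3 conv on the patch grid -> GELU -> Linear.

    Tokens are assumed to come from a square patch grid in row-major order.
    """

    def __init__(self, embed_dim: int = 768, mlp_multiplier: int = 4) -> None:
        super().__init__()
        hidden = embed_dim * mlp_multiplier
        self.fc1 = nn.Linear(embed_dim, hidden)
        # groups=hidden makes the conv depthwise: each channel mixes only spatial neighbours.
        self.dwconv = nn.Conv2d(hidden, hidden, kernel_size=3, padding=1, groups=hidden)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(hidden, embed_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        b, n, _ = x.shape
        side = int(n ** 0.5)
        assert side * side == n, "expects a square token grid"
        h = self.fc1(x)                                   # (b, n, hidden)
        h = h.transpose(1, 2).reshape(b, -1, side, side)  # (b, hidden, side, side)
        h = self.act(self.dwconv(h))                      # local spatial mixing
        h = h.flatten(2).transpose(1, 2)                  # back to (b, n, hidden)
        return self.fc2(h)


tokens = torch.randn(2, 256, 768)  # 16x16 grid of patch tokens
out = LocalFFN()(tokens)
print(out.shape)  # torch.Size([2, 256, 768])
```

The depthwise conv adds only `hidden * 9 + hidden` parameters, which is why the extra cost is small compared with the two linear layers.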

Img+Text+Noise Encoding:

The image latent inputs are 4*32*32, and we use a patch size of 2 to build 256 flattened 4*2*2=16-dimensional input "pixels". These are then projected to the embedding dimension and fed through the transformer blocks.
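The patchify step can be sketched with NumPy reshapes. The (row offset, column offset, channel) order inside each 16-dim token is an assumption; the repo may use a different einops pattern or a strided Conv2d (kernel_size = stride = 2), which is equivalent up to how the projection weights are ordered. The random `proj` matrix is a placeholder for the learned projection.

```python
import numpy as np

latent = np.random.default_rng(0).normal(size=(4, 32, 32))  # (channels, height, width)
c, h, w = latent.shape
p = 2  # patch size

# Split each spatial axis into (grid position, offset inside the patch), move the
# grid positions to the front, and flatten each 2x2x4 patch into one 16-dim "pixel".
patches = latent.reshape(c, h // p, p, w // p, p)  # (c, gh, p, gw, p)
patches = patches.transpose(1, 3, 2, 4, 0)         # (gh, gw, p, p, c)
tokens = patches.reshape((h // p) * (w // p), p * p * c)
print(tokens.shape)  # (256, 16)

# Linear projection to the embedding dimension (random placeholder weights).
embed_dim = 768
proj = np.random.default_rng(1).normal(scale=0.02, size=(p * p * c, embed_dim))
embedded = tokens @ proj
print(embedded.shape)  # (256, 768)
```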

The text and noise conditioning is very simple: we concatenate a pooled CLIP text embedding (ViT-L/14, 768-dimensional) and the sinusoidal noise embedding and feed the result as input to the cross-attention layer in each transformer block. No unpooled CLIP embeddings are used.
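A simplified NumPy sketch of this conditioning path. It assumes the two 768-dim vectors are stacked as a two-token context sequence that every image token cross-attends to; the learned Q/K/V projections, multiple heads, and the repo's exact noise-embedding frequencies are omitted, and `scale=1000.0` in `sinusoidal_embedding` is a hypothetical choice.

```python
import numpy as np


def sinusoidal_embedding(noise_level: float, dim: int = 768, scale: float = 1000.0) -> np.ndarray:
    """Transformer-style sin/cos features of a scalar noise level in [0, 1]."""
    half = dim // 2
    freqs = np.exp(-np.log(10_000.0) * np.arange(half) / half)  # 1 down to ~1e-4
    args = noise_level * scale * freqs  # hypothetical scaling of the [0, 1] range
    return np.concatenate([np.sin(args), np.cos(args)])


rng = np.random.default_rng(0)
pooled_clip = rng.normal(size=768)    # placeholder for a pooled ViT-L/14 text embedding
noise_emb = sinusoidal_embedding(0.3)  # (768,)

# Assumed layout: a 2-token context sequence [noise, text].
context = np.stack([noise_emb, pooled_clip])  # (2, 768)

# Single-head cross-attention without learned projections: image tokens are queries,
# the two context tokens are keys and values.
image_tokens = rng.normal(size=(256, 768))
scores = image_tokens @ context.T / np.sqrt(768)               # (256, 2)
weights = np.exp(scores - scores.max(axis=1, keepdims=True))   # stable softmax
weights /= weights.sum(axis=1, keepdims=True)
attended = weights @ context                                   # (256, 768)
```

With only two context tokens, cross-attention is cheap: each image token computes just two attention scores per head.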

Training:

The base model has 101MM parameters, 12 layers, and an embedding dimension of 768. I train it with a batch size of 256 on an A100 with a learning rate of 3e-4 and 1000 warmup steps. Due to computational constraints, I did not do any ablations for this configuration.
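The 1000-step warmup to a learning rate of 3e-4 can be sketched as a plain function. The linear ramp and constant rate afterwards are assumptions; the section only states the step count and peak rate.

```python
def warmup_lr(step: int, base_lr: float = 3e-4, warmup_steps: int = 1000) -> float:
    """Linear warmup to base_lr over warmup_steps, then constant (shape assumed)."""
    return base_lr * min(1.0, (step + 1) / warmup_steps)


for step in (0, 499, 999, 5000):
    print(step, warmup_lr(step))

# In PyTorch the same multiplier can be attached to an optimizer created with lr=3e-4:
# scheduler = torch.optim.lr_scheduler.LambdaLR(
#     optimizer, lr_lambda=lambda s: min(1.0, (s + 1) / 1000)
# )
# ...then call scheduler.step() after each optimizer.step().
```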

Train and Diffusion Setup:

We train a denoising transformer that takes the following three inputs:

  • noise_level (sampled from 0 to 1 with more values concentrated close to 0 - I use a beta distribution)
  • Image latent (x) corrupted with a level of random noise
    • For a given noise_level between 0 and 1, the corruption is as follows:
      • x_noisy = x*(1-noise_level) + eps*noise_level where eps ~ np.random.normal(0, 1)
  • CLIP embeddings of a text prompt
    • You can think of this as a numerical representation of a text prompt.
    • We use the pooled text embedding here (768-dimensional for ViT-L/14)

The output is a prediction of the denoised image latent - call it f(x_noisy).

The model is trained to minimize the mean squared error between the prediction and the clean image latent, mean((f(x_noisy) - x)^2) (you can also use the absolute error here). Note that I don't reparameterize the loss in terms of the noise, to keep things simple.
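One training step's inputs and loss can be sketched in NumPy. The Beta(1, 2.5) parameters are hypothetical (any a < b concentrates noise levels near 0), and `identity_denoiser` is a placeholder for the transformer f so the example runs on its own.

```python
import numpy as np

rng = np.random.default_rng(0)


def sample_noise_levels(batch_size: int, a: float = 1.0, b: float = 2.5) -> np.ndarray:
    """Beta(a, b) with a < b puts more mass near 0; a and b here are hypothetical."""
    return rng.beta(a, b, size=batch_size)


def corrupt(x: np.ndarray, noise_level: np.ndarray) -> np.ndarray:
    """x_noisy = x * (1 - noise_level) + eps * noise_level, with eps ~ N(0, 1)."""
    eps = rng.normal(size=x.shape)
    t = noise_level.reshape(-1, *([1] * (x.ndim - 1)))  # broadcast over (C, H, W)
    return x * (1 - t) + eps * t


def mse_loss(pred: np.ndarray, target: np.ndarray) -> float:
    return float(np.mean((pred - target) ** 2))


def identity_denoiser(x_noisy, noise_level, text_emb):
    """Placeholder for f(x_noisy); the real model also uses noise_level and text_emb."""
    return x_noisy


x = rng.normal(size=(8, 4, 32, 32))  # batch of clean latents
noise_level = sample_noise_levels(len(x))
x_noisy = corrupt(x, noise_level)
loss = mse_loss(identity_denoiser(x_noisy, noise_level, None), x)
```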

Using this model, we then iteratively generate an image from random noise as follows:

# Excerpt from the sampler: self.noise_levels is a decreasing noise schedule,
# and new_img starts as Gaussian noise.
for i in range(len(self.noise_levels) - 1):
    curr_noise, next_noise = self.noise_levels[i], self.noise_levels[i + 1]

    # Predict the original denoised image:
    x0_pred = predict_x_zero(new_img, label, curr_noise)

    # New image at the next_noise level is a weighted average of the old image and predicted x0:
    new_img = ((curr_noise - next_noise) * x0_pred + next_noise * new_img) / curr_noise
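A self-contained NumPy version of the loop above. `sample` and the two-argument `predict_x_zero(img, noise_level)` signature are hypothetical simplifications (the loop above also passes `label`), and the schedule starting at 1 and ending at 0 is an assumption.

```python
import numpy as np


def sample(predict_x_zero, shape, noise_levels, seed=0):
    """Run the update above; noise_levels should decrease from 1 to 0 (assumed)."""
    rng = np.random.default_rng(seed)
    new_img = rng.normal(size=shape)  # pure Gaussian noise at noise level 1
    for curr_noise, next_noise in zip(noise_levels[:-1], noise_levels[1:]):
        # Predict the clean latent, then move to the next (lower) noise level.
        x0_pred = predict_x_zero(new_img, curr_noise)
        new_img = ((curr_noise - next_noise) * x0_pred + next_noise * new_img) / curr_noise
    return new_img


# Toy check with an "oracle" that always predicts the same clean latent.
clean = np.ones((4, 32, 32))
schedule = np.linspace(1.0, 0.0, 16)  # 15 denoising steps
result = sample(lambda img, noise_level: clean, clean.shape, schedule)
print(np.allclose(result, clean))  # True
```

With this oracle, every intermediate new_img equals (1 - next_noise) * clean + next_noise * initial_noise, so the sampler walks along the straight line between the starting noise and the predicted image. That determinism is what makes interpolating in the noise space meaningful.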

The predict_x_zero method uses classifier-free guidance by combining the conditional and unconditional predictions: x0_pred = class_guidance * x0_pred_conditional + (1 - class_guidance) * x0_pred_unconditional
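A NumPy sketch of this guidance rule. `predict_x_zero_cfg` and `toy_model` are hypothetical, and using an all-zeros text embedding as the "unconditional" input is an assumption about how the unconditional prediction is obtained. The formula is the same as uncond + g * (cond - uncond), so guidance scales above 1 extrapolate past the conditional prediction.

```python
import numpy as np


def guided_x0(x0_cond, x0_uncond, class_guidance):
    """Classifier-free guidance as written above; equal to uncond + g * (cond - uncond)."""
    return class_guidance * x0_cond + (1 - class_guidance) * x0_uncond


def predict_x_zero_cfg(model, x_noisy, text_emb, noise_level, class_guidance):
    """Hypothetical wrapper: run the model with and without the text conditioning."""
    null_emb = np.zeros_like(text_emb)  # assumption: a zero embedding means "no prompt"
    x0_cond = model(x_noisy, text_emb, noise_level)
    x0_uncond = model(x_noisy, null_emb, noise_level)
    return guided_x0(x0_cond, x0_uncond, class_guidance)


def toy_model(x_noisy, text_emb, noise_level):
    """Toy stand-in whose output shifts with the text embedding."""
    return x_noisy * (1 - noise_level) + text_emb.mean()


x = np.zeros((4, 8, 8))
text = np.full(768, 0.1)
x0 = predict_x_zero_cfg(toy_model, x, text, noise_level=0.5, class_guidance=6.0)
```

In practice the conditional and unconditional inputs are often concatenated into one batch so both predictions come from a single forward pass.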

A bit of math: the approach above falls within the VDM parametrization (see Section 3.1 in Kingma et al.):

$$z_t = \alpha_t x + \sigma_t \epsilon, \epsilon \sim \mathcal{N}(0,1)$$

Where $z_t$ is the noisy version of $x$ at time $t$.

Generally, $\alpha_t$ is chosen to be $\sqrt{1-\sigma_t^2}$ so that the process is variance preserving. Here, I chose $\alpha_t=1-\sigma_t$ so that we linearly interpolate between the image and random noise. Why? For one, it simplifies the update equation quite a bit, and it makes the noise-to-signal ratio easier to reason about. I also found that the model produces sharper images faster. The update equation above is the DDIM update for this parametrization, which reduces to a simple weighted average. Note that DDIM deterministically maps random normal noise to images. This has two benefits: we can interpolate in the random normal latent space, and it generally takes fewer steps to reach decent image quality.
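To see why the loop's update is the DDIM step for $\alpha_t = 1-\sigma_t$: treat the model's prediction $\hat{x}$ as the clean image, solve $z_t = (1-\sigma_t)x + \sigma_t\epsilon$ for the implied noise $\hat{\epsilon}$, and re-noise to the next level $\sigma_s < \sigma_t$ with that same noise (the deterministic DDIM choice):

$$\hat{\epsilon} = \frac{z_t - (1-\sigma_t)\,\hat{x}}{\sigma_t}$$

$$z_s = (1-\sigma_s)\,\hat{x} + \sigma_s\,\hat{\epsilon} = \left(1 - \sigma_s - \frac{\sigma_s(1-\sigma_t)}{\sigma_t}\right)\hat{x} + \frac{\sigma_s}{\sigma_t}\,z_t = \frac{(\sigma_t - \sigma_s)\,\hat{x} + \sigma_s\,z_t}{\sigma_t}$$

With $\sigma_t$ = curr_noise, $\sigma_s$ = next_noise, $\hat{x}$ = x0_pred and $z_t$ = new_img, this is exactly the line new_img = ((curr_noise - next_noise) * x0_pred + next_noise * new_img) / curr_noise.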

TODOS:

  • [] how to speed up generation even more - LCMs?
  • [] add script to compute FID
  • better config in the train file
  • faster sampling - DDPM