
A from-scratch, hand-written reproduction of the Diffusion Transformer (DiT) model on the MNIST handwritten-digit dataset

Primary language: Python

Diffusion_transformer_from_scratch

Introduction

Diffusion Transformers trained on the MNIST dataset.

Replaces the U-Net backbone with a Transformer backbone to implement a Stable Diffusion-style diffusion model.

Preliminary

  • Training and inference processes of diffusion models

Figure: diffusion process
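Both training and inference rest on the closed-form forward (noising) process of DDPM, where x_t = sqrt(ᾱ_t)·x_0 + sqrt(1 − ᾱ_t)·ε. The following is a minimal NumPy sketch of that process, assuming a linear β schedule with T = 1000 and β from 1e-4 to 0.02 (the DDPM paper's defaults, not confirmed for this repo). The helper names `linear_beta_schedule` and `q_sample` are hypothetical, and timesteps are 0-indexed.

```python
import numpy as np

def linear_beta_schedule(T: int = 1000, beta_start: float = 1e-4, beta_end: float = 0.02) -> np.ndarray:
    """Linearly spaced noise variances beta_0..beta_{T-1} (assumed DDPM defaults)."""
    return np.linspace(beta_start, beta_end, T)

def q_sample(x0: np.ndarray, t: int, alpha_bar: np.ndarray, noise: np.ndarray) -> np.ndarray:
    """Sample x_t ~ q(x_t | x_0) in closed form:
    x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps
    """
    return np.sqrt(alpha_bar[t]) * x0 + np.sqrt(1.0 - alpha_bar[t]) * noise

betas = linear_beta_schedule()
alphas = 1.0 - betas
alpha_bar = np.cumprod(alphas)                    # alpha_bar_t = prod_{s<=t} alpha_s

rng = np.random.default_rng(0)
x0 = rng.uniform(-1.0, 1.0, size=(1, 28, 28))     # stand-in for an MNIST image scaled to [-1, 1]
eps = rng.standard_normal(x0.shape)               # Gaussian noise eps ~ N(0, I)
x_early = q_sample(x0, 10, alpha_bar, eps)        # small t: still close to x0
x_late = q_sample(x0, 999, alpha_bar, eps)        # last step: signal coefficient is below 0.01
```

Training samples a random t per image and uses this closed form directly, so no step-by-step noising loop is needed.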

Architecture

Figure: Diffusion Transformer architecture
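A DiT first cuts each image into non-overlapping p×p patches and flattens them into a token sequence, which is where the `[b, L, c]` token layout used by `modulate` comes from. This NumPy sketch assumes a patch size of 4 (not confirmed for this repo), so a 28×28 MNIST image becomes 7×7 = 49 tokens of 16 values each; a linear patch-embedding layer (omitted here) would then project them to the hidden size. `patchify` and `unpatchify` are hypothetical helper names.

```python
import numpy as np

def patchify(imgs: np.ndarray, p: int) -> np.ndarray:
    """[b, c, H, W] -> [b, L, p*p*c] with L = (H // p) * (W // p)."""
    b, c, h, w = imgs.shape
    assert h % p == 0 and w % p == 0, "image size must be divisible by the patch size"
    x = imgs.reshape(b, c, h // p, p, w // p, p)
    x = x.transpose(0, 2, 4, 3, 5, 1)           # [b, h/p, w/p, p, p, c]
    return x.reshape(b, (h // p) * (w // p), p * p * c)

def unpatchify(tokens: np.ndarray, p: int, c: int, h: int, w: int) -> np.ndarray:
    """Inverse of patchify: [b, L, p*p*c] -> [b, c, H, W]."""
    b = tokens.shape[0]
    x = tokens.reshape(b, h // p, w // p, p, p, c)
    x = x.transpose(0, 5, 1, 3, 2, 4)           # [b, c, h/p, p, w/p, p]
    return x.reshape(b, c, h, w)

rng = np.random.default_rng(0)
imgs = rng.standard_normal((2, 1, 28, 28))      # a batch of 2 single-channel 28x28 images
tokens = patchify(imgs, p=4)
print(tokens.shape)                             # (2, 49, 16)
restored = unpatchify(tokens, p=4, c=1, h=28, w=28)
print(np.array_equal(restored, imgs))           # True
```

`unpatchify` is what turns the final layer's per-token outputs back into an image-shaped noise prediction.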

Key code

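    import torch

    def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
        r"""
        Apply the adaptive LayerNorm (adaLN) shift and scale of a DiT block.

        Args:
            x:      torch.Tensor, [b, L, c], token sequence (typically the LayerNorm output)
            shift:  torch.Tensor, [b, c], per-sample shift computed from the conditioning
            scale:  torch.Tensor, [b, c], per-sample scale computed from the conditioning
        Returns:
            torch.Tensor, [b, L, c]
        """
        # unsqueeze(1) turns [b, c] into [b, 1, c], so each sample's shift and
        # scale are broadcast over all L of its tokens
        return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)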
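To show where `modulate` fits, here is a NumPy sketch of one adaLN-Zero-conditioned sublayer in the style of the original DiT paper: a sinusoidal timestep embedding passes through SiLU and a zero-initialised linear layer that outputs `shift`, `scale` and `gate`. Assumptions: hidden size 64, a single random matrix standing in for the attention/MLP sublayer, and only one (shift, scale, gate) triple, whereas a full DiT block produces two (one for attention, one for the MLP). All names other than `modulate` are hypothetical.

```python
import numpy as np

def timestep_embedding(t: np.ndarray, dim: int, max_period: float = 10000.0) -> np.ndarray:
    """Sinusoidal embedding of integer timesteps: [b] -> [b, dim] (dim assumed even)."""
    half = dim // 2
    freqs = np.exp(-np.log(max_period) * np.arange(half) / half)   # [half]
    args = t[:, None].astype(np.float64) * freqs[None, :]          # [b, half]
    return np.concatenate([np.cos(args), np.sin(args)], axis=-1)   # [b, dim]

def modulate(x, shift, scale):
    # Same broadcast as the PyTorch version: [b, c] -> [b, 1, c]
    return x * (1 + scale[:, None, :]) + shift[:, None, :]

def silu(z):
    return z / (1.0 + np.exp(-z))

def layer_norm(x, eps=1e-6):
    # LayerNorm without learnable affine parameters; adaLN supplies shift/scale instead
    mu = x.mean(-1, keepdims=True)
    var = x.var(-1, keepdims=True)
    return (x - mu) / np.sqrt(var + eps)

b, L, c = 2, 49, 64
rng = np.random.default_rng(0)
x = rng.standard_normal((b, L, c))                    # token sequence [b, L, c]
cond = timestep_embedding(np.array([10, 500]), c)     # conditioning vector (timestep only here)

# adaLN-Zero: a linear layer maps cond to (shift, scale, gate); weights and bias start at zero
W_ada = np.zeros((c, 3 * c))
b_ada = np.zeros(3 * c)
shift, scale, gate = np.split(silu(cond) @ W_ada + b_ada, 3, axis=-1)   # each [b, c]

W_mlp = rng.standard_normal((c, c)) * 0.02            # stand-in for the attention/MLP sublayer
h = modulate(layer_norm(x), shift, scale) @ W_mlp
out = x + gate[:, None, :] * h                        # gated residual connection

print(np.allclose(out, x))  # True: with zero-initialised adaLN the sublayer starts as the identity
```

Because the gate starts at zero, every block initially reduces to the identity function, which is the motivation for the zero initialisation in adaLN-Zero; training then learns non-zero shift, scale and gate values.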

Loss

Figure: loss
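Assuming the objective is the standard simplified DDPM loss (predict the added noise ε and take the mean squared error), one batch's loss can be sketched as follows. `diffusion_loss`, the `model(x_t, t, y)` call signature and the digit label `y` are assumptions for illustration; in the repo the model would be the DiT itself.

```python
import numpy as np

def diffusion_loss(model, x0, y, alpha_bar, rng):
    """Simplified DDPM objective: E_{t, eps} || eps - eps_theta(x_t, t, y) ||^2."""
    b = x0.shape[0]
    T = alpha_bar.shape[0]
    t = rng.integers(0, T, size=b)                            # one random timestep per sample
    eps = rng.standard_normal(x0.shape)                       # target noise
    ab = alpha_bar[t].reshape(b, *([1] * (x0.ndim - 1)))      # broadcast over [c, H, W]
    x_t = np.sqrt(ab) * x0 + np.sqrt(1.0 - ab) * eps          # closed-form forward process
    eps_pred = model(x_t, t, y)                               # network's noise prediction
    return np.mean((eps - eps_pred) ** 2)

T = 1000
betas = np.linspace(1e-4, 0.02, T)                            # assumed linear schedule
alpha_bar = np.cumprod(1.0 - betas)

rng = np.random.default_rng(0)
x0 = rng.uniform(-1.0, 1.0, size=(64, 1, 28, 28))             # stand-in MNIST batch in [-1, 1]
y = rng.integers(0, 10, size=64)                              # stand-in digit labels (hypothetical)

def zero_model(x_t, t, y):
    # Placeholder that always predicts zero noise; the real model would be the DiT
    return np.zeros_like(x_t)

loss = diffusion_loss(zero_model, x0, y, alpha_bar, rng)
print(loss)  # expected to be close to 1.0, since E[eps^2] = 1 for a model that predicts nothing
```

A model that predicts ε perfectly would drive this loss to zero, so values well below 1.0 indicate the network has learned something about the noise.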

Inference

Figure: loss
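For inference, a common choice is DDPM ancestral sampling: start from Gaussian noise x_T and step backwards, removing the predicted noise at each t. The NumPy sketch below assumes ε-prediction, the posterior variance σ_t² = β_t, and the same linear schedule as training; `ddpm_sample` and the `model(x, t, y)` signature are hypothetical, and the repo may use a different sampler.

```python
import numpy as np

def ddpm_sample(model, shape, y, betas, rng):
    """Ancestral DDPM sampling from x_T ~ N(0, I) down to x_0."""
    alphas = 1.0 - betas
    alpha_bar = np.cumprod(alphas)
    x = rng.standard_normal(shape)                            # x_T
    for t in reversed(range(len(betas))):
        t_batch = np.full(shape[0], t)                        # same timestep for the whole batch
        eps = model(x, t_batch, y)                            # predicted noise
        # mean of p(x_{t-1} | x_t): (x_t - beta_t / sqrt(1 - alpha_bar_t) * eps) / sqrt(alpha_t)
        mean = (x - betas[t] / np.sqrt(1.0 - alpha_bar[t]) * eps) / np.sqrt(alphas[t])
        if t > 0:
            x = mean + np.sqrt(betas[t]) * rng.standard_normal(shape)   # sigma_t^2 = beta_t
        else:
            x = mean                                          # no noise is added at the final step
    return x

betas = np.linspace(1e-4, 0.02, 1000)                         # assumed linear schedule

def zero_model(x, t, y):
    # Stand-in for the trained DiT; replace with the real noise predictor
    return np.zeros_like(x)

samples = ddpm_sample(zero_model, (4, 1, 28, 28), np.array([0, 1, 2, 3]), betas, np.random.default_rng(0))
print(samples.shape)  # (4, 1, 28, 28)
```

With a trained model, the returned x_0 lies roughly in the training range (for example [-1, 1]) and would be rescaled to pixel values for display.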

Todo

Acknowledgements