Diffusion Transformers trained on the MNIST dataset
Replaces the U-Net backbone with a Transformer backbone to implement the Stable Diffusion diffusion model.
- Training and inference process of the diffusion model
def modulate(x, shift, scale):
    r"""
    Apply the DiT block's shift-and-scale modulation.
    Args:
        x: torch.Tensor, [b, L, c]
        shift: torch.Tensor, [b, c]
        scale: torch.Tensor, [b, c]
    Return:
        torch.Tensor, [b, L, c]
    """
    # unsqueeze(1) turns [b, c] into [b, 1, c] so it broadcasts over the L tokens
    return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
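
To make the broadcasting in `modulate` concrete, here is a minimal usage sketch. The sizes `b`, `L`, `c` and the constant `shift`/`scale` values are hypothetical, chosen only for illustration; in a real DiT block these tensors would presumably come from the conditioning embedding, which this snippet does not show.

```python
import torch


def modulate(x, shift, scale):
    # Same formula as above: per-sample shift/scale broadcast over the token axis.
    return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)


# Hypothetical sizes: batch of 2, 4 tokens per sample, 8 channels.
b, L, c = 2, 4, 8
x = torch.randn(b, L, c)

# With zero shift and zero scale, modulate returns x unchanged.
zeros = torch.zeros(b, c)
out_identity = modulate(x, zeros, zeros)

# With shift = 1 and scale = 0.5, every token becomes 1.5 * x + 1.
shift = torch.ones(b, c)
scale = torch.full((b, c), 0.5)
out = modulate(x, shift, scale)

print(out.shape)  # torch.Size([2, 4, 8])
```

The `1 + scale` form means a scale of zero leaves the input untouched, so the zero-valued case above acts as an identity mapping.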