This is a simple PyTorch implementation of *Your Diffusion Model is Secretly a Zero-Shot Classifier*.
`model.py` is a minimal implementation of a conditional diffusion model capable of Bayesian inference via Monte Carlo sampling. During training, it learns to generate MNIST digits conditioned on a class label. During inference, it samples pairs of timesteps and noises $(t, \epsilon)$ to estimate the noise-prediction error for each candidate class.
The conditioning roughly follows the method described in *Classifier-Free Diffusion Guidance* (also used in Imagen). The model infuses timestep embeddings and class-label embeddings into the denoising network. At training time, the class label is randomly dropped (replaced with a null token) with some probability, so the model learns both conditional and unconditional denoising. Increasing the guidance weight at sampling time produces samples that are more faithful to the class condition but less diverse.
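The label-dropping step during training can be sketched as follows. This is a minimal illustration, not `model.py`'s actual code: `drop_labels`, the `null_class` index, and `p_drop` are hypothetical names.

```python
import torch

def drop_labels(labels: torch.Tensor, null_class: int, p_drop: float = 0.1) -> torch.Tensor:
    """Replace each class label with the null token with probability p_drop.

    Training on a mix of real labels and the null token teaches the model
    both conditional and unconditional denoising, as in classifier-free
    guidance. (Illustrative sketch; names are not from model.py.)
    """
    mask = torch.rand(labels.shape, device=labels.device) < p_drop
    return torch.where(mask, torch.full_like(labels, null_class), labels)

# usage: MNIST labels are 0-9, so index 10 can serve as the null class
labels = torch.randint(0, 10, (8,))
dropped = drop_labels(labels, null_class=10, p_drop=0.5)
```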
The basic idea of a diffusion classifier is Bayesian inference, that is,
$$ p_\theta(\mathbf{c}_i \mid \mathbf{x}) = \frac{p(\mathbf{c}_i)\, p_\theta(\mathbf{x} \mid \mathbf{c}_i)}{\sum_j p(\mathbf{c}_j)\, p_\theta(\mathbf{x} \mid \mathbf{c}_j)} $$
A uniform prior over the classes is assumed, so the $p(\mathbf{c}_i)$ terms cancel and the posterior reduces to
$$ \begin{aligned} p_\theta(\mathbf{c}_i \mid \mathbf{x}) & \approx \frac{\exp\{-\mathbb{E}_{t, \epsilon}[\|\epsilon-\epsilon_\theta(\mathbf{x}_t, \mathbf{c}_i)\|^2]+C\}}{\sum_j \exp\{-\mathbb{E}_{t, \epsilon}[\|\epsilon-\epsilon_\theta(\mathbf{x}_t, \mathbf{c}_j)\|^2]+C\}} \\ & = \frac{\exp\{-\mathbb{E}_{t, \epsilon}[\|\epsilon-\epsilon_\theta(\mathbf{x}_t, \mathbf{c}_i)\|^2]\}}{\sum_j \exp\{-\mathbb{E}_{t, \epsilon}[\|\epsilon-\epsilon_\theta(\mathbf{x}_t, \mathbf{c}_j)\|^2]\}} \end{aligned} $$
which can be estimated by Monte Carlo sampling (see *Your Diffusion Model is Secretly a Zero-Shot Classifier*).
During Bayesian inference, we no longer drop the class label: each forward pass is conditioned on one candidate class $\mathbf{c}_j$.
We trained the conditional diffusion model for 50 epochs, and evaluated classification accuracy with varying numbers of Monte Carlo samples per class:
Samples | Accuracy (%) | Time
---|---|---
20 | 98.27 | ~14 min
50 | 98.98 | ~35 min
100 | 99.23 | ~70 min