A clean implementation of DEIS and iPNDM.
import torch as th
from th_deis import DisVPSDE, get_sampler
# Discrete-time VP-SDE built from the model's discrete alpha schedule.
# Assumes t_start = 0 and t_end = len(discrete_alpha) - 1.
vpsde = DisVPSDE(discrete_alpha)
def eps_fn(x, scalar_t):
    """Evaluate ``eps_model`` at one shared time for a whole batch.

    Args:
        x: Batch of inputs; its leading dimension is treated as the batch size.
        scalar_t: Time at which the model is evaluated, shared by every sample.

    Returns:
        ``eps_model(x, vec_t)``, computed under ``torch.no_grad()``.
    """
    # Broadcast the scalar time to one entry per sample. ``.to(x)`` puts the
    # vector on x's device with x's dtype.
    vec_t = th.ones(x.shape[0]).to(x) * scalar_t
    with th.no_grad():
        # NOTE: some models expect the time index shifted by one. Check the
        # training setup of your model; if so, call eps_model(x, vec_t - 1).
        return eps_model(x, vec_t)
# `method` picks the solver: "deis" or "iPNDM".
# DEIS accepts `order` 0, 1, 2 or 3; iPNDM ignores `order`.
sampler_fn = get_sampler(vpsde, num_step, eps_fn, order=3, method="deis")
# Apply the sampler to `noise` to obtain a sample.
sample = sampler_fn(noise)
Based on the PNDM codebase.
# Before running, download the checkpoint and set its path in run.sh.
cd demo/dis_celeba
bash run.sh
Not yet tested with PyTorch! See the JAX version for tested usage.
from th_deis import CntVPSDE, get_sampler
# Continuous-time VP-SDE defined by `alpha_fn` on [t_start, t_end].
vpsde = CntVPSDE(alpha_fn, t_start, t_end)
# `method` picks the solver: "deis" or "iPNDM".
# DEIS accepts `order` 0, 1, 2 or 3; iPNDM ignores `order`.
sampler_fn = get_sampler(vpsde, num_step, eps_fn, order=3, method="deis")
from jax_deis import CntVPSDE, get_sampler
# Continuous-time VP-SDE defined by `alpha_fn` on [t_start, t_end].
vpsde = CntVPSDE(alpha_fn, t_start, t_end)
# `method` picks the solver: "deis" or "iPNDM".
# DEIS accepts `order` 0, 1, 2 or 3; iPNDM ignores `order`.
sampler_fn = get_sampler(vpsde, num_step, eps_fn, order=3, method="deis")
Based on the score_sde codebase.
# Before running, download the checkpoint and set its path in deis.ipynb.
cd demo/cnt_cifar
jupyter lab
# Then open and run deis.ipynb in JupyterLab.
@article{zhang2022fast,
title={Fast Sampling of Diffusion Models with Exponential Integrator},
author={Zhang, Qinsheng and Chen, Yongxin},
journal={arXiv preprint arXiv:2204.13902},
year={2022}
}