vlievin/biva-pytorch

Problem when sampling from the posterior


Hello,

I would like to compare the quality of the reconstructed images with that of the generated images; in other words, to compare the images generated by sampling from the prior with those generated using the estimated posteriors.

Once BIVA is trained on the CIFAR-10 dataset, I used the following code (starting after https://github.com/vlievin/biva-pytorch/blob/master/run_deepvae.py#L192)

import math

import matplotlib.image
from torchvision.utils import make_grid


def build_and_save_grid(data, filename, N=100):
    # arrange the N images on a (roughly) square grid
    nrow = math.floor(math.sqrt(N))
    grid = make_grid(data, nrow=nrow)

    # normalize to [0, 1]
    grid -= grid.min()
    grid /= grid.max()

    # save the raw image
    img = grid.data.permute(1, 2, 0).cpu().numpy()
    matplotlib.image.imsave(f"./output/{filename}.png", img)

# load the pretrained weights
load_model(ema.model, logdir)

with torch.no_grad():
    # a batch of training images
    x = next(iter(train_loader)).to(opt.device)
    build_and_save_grid(x, "original")

    # reconstructions: x_ is sampled from the posterior-conditioned likelihood
    x_ = ema.model(x).get('x_')
    x_ = likelihood(logits=x_).sample()
    build_and_save_grid(x_, "reconstruction")

    # generations: 100 samples from the prior
    x_ = ema.model.sample_from_prior(100).get('x_')
    x_ = likelihood(logits=x_).sample()
    build_and_save_grid(x_, "generation")

I was quite surprised to observe that the reconstructed images are of much lower quality than the generated images:
Reconstruction:
[image: reconstructed samples]
Generation:
[image: generated samples]

I assume that there is an error somewhere, but I couldn't find it. Do you have any idea?

Thanks,
Reuben

As we discussed outside this channel, this issue is caused by the data-dependent initialisation, which modifies the parameters during the first forward pass. Hence, 1. instantiating the model, 2. loading the pretrained weights, and then 3. performing a forward pass will alter the loaded weights.
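For context, data-dependent initialisation (as used with weight normalisation) sets some parameters from the statistics of the first batch that passes through the model. Below is a minimal sketch of the mechanism, not the exact layer used in this repository; the class name and initialisation details are illustrative:

import torch
from torch import nn

class DataDependentLinear(nn.Module):
    """Illustrative weight-normalised layer with data-dependent init."""
    def __init__(self, in_features, out_features):
        super().__init__()
        self.linear = nn.utils.weight_norm(nn.Linear(in_features, out_features))
        self.initialized = False  # plain Python attribute: NOT in the state dict

    def forward(self, x):
        if not self.initialized:
            # rescale the parameters using the statistics of the first batch,
            # so the initial outputs are roughly zero-mean, unit-variance
            with torch.no_grad():
                out = self.linear(x)
                m, s = out.mean(0), out.std(0) + 1e-8
                self.linear.weight_g.data /= s.unsqueeze(-1)
                self.linear.bias.data = (self.linear.bias.data - m) / s
            self.initialized = True
        return self.linear(x)

Because `initialized` is a plain attribute, it is not stored in the checkpoint: a freshly instantiated model that has just been loaded with trained weights will still run the initialisation branch on its first forward pass and overwrite them.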

Using the current code, one can work around this issue by performing the steps in the order 1, 3, 2, and finally 3 again: trigger the data-dependent initialisation with a throwaway forward pass, then load the weights. The above snippet becomes:

# perform the data-dependent init with a throwaway forward pass
# before loading the weights
x = next(iter(train_loader)).to(opt.device)
ema.model(x)

# then load the pretrained weights (this overwrites the init above)
load_model(ema.model, logdir)

with torch.no_grad():
    x = next(iter(train_loader)).to(opt.device)
    build_and_save_grid(x, "original")

    # reconstructions: x_ is sampled from the posterior-conditioned likelihood
    x_ = ema.model(x).get('x_')
    x_ = likelihood(logits=x_).sample()
    build_and_save_grid(x_, "reconstruction")

    # generations: 100 samples from the prior
    x_ = ema.model.sample_from_prior(100).get('x_')
    x_ = likelihood(logits=x_).sample()
    build_and_save_grid(x_, "generation")

I am preparing a patch such that the boolean `initialized` in each layer is part of the module state, so that it can be saved and loaded automatically, hence solving this issue.
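A hedged sketch of such a fix, building on the illustrative layer above: in PyTorch, registering the flag as a buffer makes it part of state_dict(), so loading a checkpoint also restores the fact that the initialisation has already been performed:

import torch
from torch import nn

class DataDependentLinear(nn.Module):
    def __init__(self, in_features, out_features):
        super().__init__()
        self.linear = nn.utils.weight_norm(nn.Linear(in_features, out_features))
        # registered as a buffer: saved by state_dict() and restored by
        # load_state_dict(), unlike a plain Python attribute
        self.register_buffer('initialized', torch.tensor(False))

    def forward(self, x):
        if not bool(self.initialized):
            # data-dependent init, as in the sketch above
            with torch.no_grad():
                out = self.linear(x)
                m, s = out.mean(0), out.std(0) + 1e-8
                self.linear.weight_g.data /= s.unsqueeze(-1)
                self.linear.bias.data = (self.linear.bias.data - m) / s
            self.initialized.fill_(True)
        return self.linear(x)

With such a change, the workaround above would become unnecessary: instantiating the model, loading the weights, and performing a forward pass would work in the natural order.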