dome272/Diffusion-Models-pytorch

model generating bad random images

NITHISHM2410 opened this issue · 2 comments

I trained my diffusion model in TensorFlow based on this implementation, and after training for 450 epochs (on a landscape dataset), my loss was around 0.015 (MSE). I generated a few samples, and they were very bad or looked random. Below are the images generated with 1000 time steps.

I just want to know: is this a training issue, i.e. does my model need more training to reduce the loss further (currently 0.015), or is the problem in the sampling technique? Can anyone help, please?

[Image: samples generated with 1000 time steps]
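One way to tell a training problem from a sampling bug is to run the sampler with a "perfect" noise predictor. The sketch below is plain Python (no framework) so it can be compared line by line with a TensorFlow port. It mirrors the update in the sample() method quoted in the reply below and assumes a linear beta schedule from 1e-4 to 0.02 over 1000 steps (assumed values; substitute whatever your model was trained with). The names oracle_noise and sample are hypothetical. If the whole dataset were a single image x0, the noise inside x_t could be solved for exactly, so a correct sampling loop should return values very close to x0. If a port fails this check, the sampler or schedule is at fault; if it passes, the remaining problem more likely lies in training or data preprocessing.

```python
import math
import random

# Assumed linear schedule (hypothetical defaults; use your own training config).
NOISE_STEPS = 1000
BETA_START, BETA_END = 1e-4, 0.02

beta = [BETA_START + (BETA_END - BETA_START) * i / (NOISE_STEPS - 1)
        for i in range(NOISE_STEPS)]
alpha = [1.0 - b for b in beta]
alpha_hat = []
running = 1.0
for a in alpha:
    running *= a  # cumulative product of alpha, like torch.cumprod
    alpha_hat.append(running)


def oracle_noise(x0):
    """Hypothetical 'perfect model' for a dataset that contains only x0.

    Since x_t = sqrt(alpha_hat[t]) * x0 + sqrt(1 - alpha_hat[t]) * eps,
    the noise eps can be recovered exactly from x_t.
    """
    def predict(x, t):
        return [(xi - math.sqrt(alpha_hat[t]) * x0i) / math.sqrt(1.0 - alpha_hat[t])
                for xi, x0i in zip(x, x0)]
    return predict


def sample(predict_noise, n, seed=0):
    """Same reverse loop as the PyTorch sample(), on a flat list of n pixels."""
    rng = random.Random(seed)
    x = [rng.gauss(0.0, 1.0) for _ in range(n)]  # start from pure noise x_T
    for i in reversed(range(1, NOISE_STEPS)):
        eps = predict_noise(x, i)
        coef = (1.0 - alpha[i]) / math.sqrt(1.0 - alpha_hat[i])
        scale = math.sqrt(beta[i]) if i > 1 else 0.0  # no noise on the final step
        x = [(xi - coef * ei) / math.sqrt(alpha[i]) + scale * rng.gauss(0.0, 1.0)
             for xi, ei in zip(x, eps)]
    return x
```

For example, sample(oracle_noise([0.5, -0.25, 0.9, -0.8]), n=4) should return values close to those four "pixels". Porting this check to TensorFlow with your own schedule tensors tests the sampling loop without involving the trained network at all.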

Hi, I encountered the same issue.

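import logging

import torch
from tqdm import tqdm


# Method of the repo's Diffusion class: self.alpha, self.alpha_hat and self.beta
# are the precomputed noise-schedule tensors of length self.noise_steps.
def sample(self, model, n):
    logging.info(f"Sampling {n} new images....")
    model.eval()
    with torch.no_grad():
        # Start from pure Gaussian noise x_T.
        x = torch.randn((n, 3, self.img_size, self.img_size)).to(self.device)
        # Walk the chain backwards: t = noise_steps - 1, ..., 1.
        for i in tqdm(reversed(range(1, self.noise_steps)), position=0):
            t = (torch.ones(n) * i).long().to(self.device)
            predicted_noise = model(x, t)
            # Index the schedule at t and reshape to (n, 1, 1, 1) for broadcasting.
            alpha = self.alpha[t][:, None, None, None]
            alpha_hat = self.alpha_hat[t][:, None, None, None]
            beta = self.beta[t][:, None, None, None]
            # No noise is added on the final step.
            if i > 1:
                noise = torch.randn_like(x)
            else:
                noise = torch.zeros_like(x)
            # Reverse step; the result must overwrite x so the next iteration uses it.
            x = 1 / torch.sqrt(alpha) * (x - ((1 - alpha) / (torch.sqrt(1 - alpha_hat))) * predicted_noise) + torch.sqrt(beta) * noise
    model.train()
    # Map from [-1, 1] to uint8 pixels in [0, 255].
    x = (x.clamp(-1, 1) + 1) / 2
    x = (x * 255).type(torch.uint8)
    return x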
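For reference, the update line inside the loop is the DDPM ancestral step below, where alpha, alpha_hat and beta hold α_t, ᾱ_t and β_t, with α_t = 1 − β_t. The variance is taken as σ_t² = β_t, and z is zero on the last iteration (i = 1):

```latex
x_{t-1} = \frac{1}{\sqrt{\alpha_t}}\left(x_t - \frac{1-\alpha_t}{\sqrt{1-\bar{\alpha}_t}}\,\epsilon_\theta(x_t, t)\right) + \sqrt{\beta_t}\,z,
\qquad
z \sim \mathcal{N}(0, I) \text{ for } t > 1, \quad z = 0 \text{ for } t = 1,
\qquad
\bar{\alpha}_t = \prod_{s \le t} \alpha_s .
```

Because x_{t-1} is computed from x_t, the result has to overwrite the tensor that the next iteration passes to the model. The last two lines map the clamped result from [-1, 1] to [0, 255], which presumes the training images were scaled to [-1, 1]; a TensorFlow port that trains on images in [0, 1] or [0, 255] would need a matching change.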

Make sure you use the same variable name, x, in all of these lines:

x = torch.randn((n, 3, self.img_size, self.img_size)).to(self.device)
predicted_noise = model(x, t)
noise = torch.randn_like(x)
noise = torch.zeros_like(x)
x = 1 / torch.sqrt(alpha) * (x - ((1 - alpha) / (torch.sqrt(1 - alpha_hat))) * predicted_noise) + torch.sqrt(beta) * noise
x = (x.clamp(-1, 1) + 1) / 2
x = (x * 255).type(torch.uint8)
return x

If you don't, the result of each reverse step is not passed on to the next step (or returned), and you get noise instead of meaningful images.
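This failure mode can be reproduced without a GPU. The sketch below is an illustration under assumptions: it uses plain Python lists instead of tensors, a hypothetical 10-step schedule instead of 1000 steps, and a stand-in zero_model instead of a trained network; the buggy flag exists only for this demo. When the update is stored under a different name, the loop keeps feeding the original noise to the model and finally returns it unchanged.

```python
import math
import random

# Hypothetical 10-step linear schedule, kept tiny so the demo runs instantly.
STEPS = 10
beta = [1e-4 + (0.02 - 1e-4) * i / (STEPS - 1) for i in range(STEPS)]
alpha = [1.0 - b for b in beta]
alpha_hat = [math.prod(alpha[: i + 1]) for i in range(STEPS)]


def zero_model(x, t):
    """Hypothetical stand-in for the network: always predicts zero noise."""
    return [0.0] * len(x)


def sample(model, n, buggy=False, seed=0):
    rng = random.Random(seed)
    x = [rng.gauss(0.0, 1.0) for _ in range(n)]  # x_T: pure noise
    for i in reversed(range(1, STEPS)):
        predicted_noise = model(x, i)
        coef = (1.0 - alpha[i]) / math.sqrt(1.0 - alpha_hat[i])
        scale = math.sqrt(beta[i]) if i > 1 else 0.0  # no noise on the last step
        updated = [(xi - coef * ei) / math.sqrt(alpha[i]) + scale * rng.gauss(0.0, 1.0)
                   for xi, ei in zip(x, predicted_noise)]
        if buggy:
            x_new = updated  # BUG: stored under another name, so x never changes
        else:
            x = updated      # correct: the next iteration sees the updated sample
    return x
```

With buggy=True the returned list is exactly the starting Gaussian noise, which the final clamp-and-scale lines would then turn into a noise image.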