w86763777/pytorch-ddpm

About memory usage

Arksyd96 opened this issue · 3 comments

Hello, I'm having issues with memory usage.
Is it normal that even with 48 GB of VRAM I cannot run the reverse process for generation with a small batch of 2?
What are your specs?

No, that is abnormal.
To train on CIFAR-10, a GPU with 11 GB of VRAM, such as the 2080 Ti, is sufficient. However, if you use a larger model, the VRAM requirements may increase.
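
You can check what your run actually allocates with PyTorch's built-in allocator statistics (plain torch.cuda calls, nothing specific to this repo):

    import torch

    torch.cuda.reset_peak_memory_stats()
    # ... run training or the reverse process here ...
    print(f'current: {torch.cuda.memory_allocated() / 1e9:.2f} GB')
    print(f'peak:    {torch.cuda.max_memory_allocated() / 1e9:.2f} GB')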

Yeah, problem fixed. I'm actually training on 1x128x128 BraTS images, and I forgot to wrap the reverse process in torch.no_grad(), so autograd was keeping the activations of every denoising step in memory.
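
For anyone hitting the same thing, the fix is just to run the whole sampling loop inside the no-grad context (a minimal sketch; sampler and x_T stand in for your own sampler module and initial noise):

    import torch

    with torch.no_grad():
        # No graph is recorded, so the activations of all T denoising
        # steps are freed immediately instead of accumulating in VRAM.
        x_0 = sampler(x_T)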

However, I still have an issue with the reverse process. During training, the MSE loss is optimized well, but sampling only generates noise. Here's my sampling code if you want to take a look and tell me if it's OK:

    def q_mean_variance(self, x_0, x_t, t):
        # Posterior q(x_{t-1} | x_t, x_0): the mean is a fixed linear
        # combination of x_0 and x_t with timestep-dependent coefficients.
        posterior_mean = (
            self.posterior_mean_c1[t, None, None, None].to(device) * x_0 +
            self.posterior_mean_c2[t, None, None, None].to(device) * x_t
        )
        posterior_log_var = self.posterior_log_var[t, None, None, None].to(device)
        return posterior_mean, posterior_log_var
    
    def p_mean_variance(self, x_t, t):
        # "Fixed large" variance: the betas, with the first entry replaced
        # by the posterior variance, following the original DDPM code.
        model_logvar = torch.log(torch.cat([self.posterior_var[1:2], self.betas[1:]])).to(device)
        model_logvar = model_logvar[t, None, None, None]

        # Predict the noise, reconstruct x_0 from it, then use the
        # posterior mean of q(x_{t-1} | x_t, x_0) as the model mean.
        eps = self.model(x_t, t.to(device))
        x_0 = self.predict_x_start_from_eps(x_t, t, eps)
        model_mean, _ = self.q_mean_variance(x_0, x_t, t)

        return model_mean, model_logvar
    
    def predict_x_start_from_eps(self, x_t, t, eps):
        # Invert the forward process x_t = sqrt(a_bar_t) * x_0 + sqrt(1 - a_bar_t) * eps:
        # x_0 = sqrt(1 / a_bar_t) * x_t - sqrt(1 / a_bar_t - 1) * eps
        return (
            torch.sqrt(1. / self.alpha_prods[t, None, None, None].to(device)) * x_t -
            torch.sqrt(1. / self.alpha_prods[t, None, None, None].to(device) - 1.) * eps
        )

    def forward(self, x_T):
        # Ancestral sampling: walk the reverse chain from t = T-1 down to t = 0.
        x_t = x_T
        for timestep in reversed(range(self.T)):
            t = torch.full((x_T.shape[0],), fill_value=timestep, dtype=torch.long)
            mean, logvar = self.p_mean_variance(x_t, t)
            # No noise is added at the final step, t = 0.
            if timestep > 0:
                noise = torch.randn_like(x_T)
            else:
                noise = 0
            x_t = mean + torch.exp(0.5 * logvar) * noise
        x_0 = x_t
        return torch.clip(x_0, -1, 1)
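
In case it matters, this is how I call it (a sketch of my setup; sampler is the module holding the methods above, and BraTS slices are single-channel):

    sampler = sampler.to(device).eval()
    with torch.no_grad():
        x_T = torch.randn(2, 1, 128, 128, device=device)  # batch of 2, 1x128x128
        samples = sampler(x_T)  # values clipped to [-1, 1]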

Apologies for the delayed response.

To the best of my recollection, you do not need to update the GaussianDiffusionTrainer and GaussianDiffusionSampler when training on images of a different size. Their operations are element-wise and broadcast over the spatial dimensions, so they adapt to any image shape.

However, you will need to modify the model and data-related code, including the UNet, dataset, and dataloader, to accommodate the new image size and the single-channel input.
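
For the data side, something along these lines should work for 1x128x128 inputs (a hedged sketch with torchvision; the dataset path and ImageFolder are placeholders, since BraTS volumes usually need a custom Dataset):

    from torch.utils.data import DataLoader
    from torchvision import transforms, datasets

    # Map single-channel 128x128 images into [-1, 1], mirroring the
    # CIFAR-10 preprocessing.
    transform = transforms.Compose([
        transforms.Grayscale(num_output_channels=1),
        transforms.Resize((128, 128)),
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),
    ])

    dataset = datasets.ImageFolder('./data/brats_slices', transform=transform)
    dataloader = DataLoader(dataset, batch_size=2, shuffle=True, num_workers=4)

On the model side, the UNet's first and last convolutions must be changed from 3 channels to 1 to match.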