FinnBehrendt/patched-Diffusion-Models-UAD

Training time

fujistoo opened this issue · 8 comments

What is the estimated training time? It seems that the bbox approach is pretty time-consuming both during training and testing. Also, I wanted to make sure: during testing, is only one image processed at a time?

During training, there shouldn't be much difference. During testing, the reconstruction of individual images will indeed take 4 times longer (~40% increased inference time overall). Right now, only one image is processed at a time and the code is not runtime-optimized. The patching could also be done in parallel to speed up inference.
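Something along these lines could batch the four patch reconstructions at test time instead of looping over them. This is only an untested sketch: reconstruct_patches_batched is an illustrative helper name, and it assumes GaussianDiffusion handles a per-sample box tensor here the same way it does for a batch in training_step.

# untested sketch: run all grid patches of one test image in a single forward pass
import torch

def reconstruct_patches_batched(diffusion, bbox, input, t, noise=None):
    # bbox: [1, n_patches, 4] from BoxSampler.sample_grid, input: [1, C, H, W]
    n_patches = bbox.shape[1]
    batched_input = input.repeat(n_patches, 1, 1, 1)            # one copy of the image per patch
    batched_boxes = bbox[0]                                     # [n_patches, 4], one box per copy
    batched_noise = noise.repeat(n_patches, 1, 1, 1) if noise is not None else None
    loss, reco = diffusion(batched_input, t=t, box=batched_boxes, noise=batched_noise)
    return loss, reco                                           # reco[k] is the k-th patch reconstruction

The stitching into reco_patched would then index reco[k] instead of calling the diffusion once per patch.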

Thanks! I made some changes myself for non-medical images, but wasn't sure which part went wrong. None of the reconstructed images get denoised...?

The resulting image comes from reco = reco_patched.clone() after 99 epochs.

[attached screenshot of the resulting reconstruction]

That is very hard to tell without knowing the changes you made. 99 epochs should be enough to get a somewhat meaningful reconstruction for the IXI dataset. What parts did you change?

# DDPM_2D_patched.py
import torch

import numpy as np
import torchio as tio
import torch.optim as optim
import pytorch_lightning as pl

from typing import Any, List
from src.models.diffusionmodules.cond_DDPM import GaussianDiffusion
from src.models.diffusionmodules.OpenAI_Unet import UNetModel as OpenAI_UNet
from src.utils.diffusionmodules.patch_sampling import BoxSampler
from src.utils.diffusionmodules.generate_noise import gen_noise
from src.utils.diffusionmodules.utils_eval import _test_step, _test_end, get_eval_dictionary, get_eval_metrics_dictionary

import lightning as L 
# metrics 
from sklearn.metrics import confusion_matrix, roc_curve, accuracy_score, precision_recall_fscore_support, auc, precision_recall_curve, average_precision_score

from customized.pre_processing import Tiler 

import wandb

def compute_roc(predictions, labels):
    _fpr, _tpr, _ = roc_curve(labels.astype(int), predictions,pos_label=1)
    roc_auc = auc(_fpr, _tpr)
    return roc_auc, _fpr, _tpr, _


def compute_prc(predictions, labels):
    precisions, recalls, thresholds = precision_recall_curve(labels.astype(int), predictions)
    auprc = average_precision_score(labels.astype(int), predictions)
    return auprc, precisions, recalls, thresholds   

class DDPM_2D(L.LightningModule):
    def __init__(self,cfg,prefix=None):
        super().__init__()
        
        self.cfg = cfg

        # Model
        image_size = (int(cfg.get('image_size',400)),)*2 # default 400 or from config. 

        model = OpenAI_UNet(
            image_size =  image_size,
            in_channels = 3,
            model_channels = cfg.get('unet_dim',64),
            out_channels = 3,
            num_res_blocks = cfg.get('num_res_blocks',3),
            # attention_resolutions = (int(cfg.imageDim[0])/int(32),int(cfg.imageDim[0])/int(16),int(cfg.imageDim[0])/int(8)),
            attention_resolutions = tuple(cfg.get('att_res',[int(image_size[0]/32),int(image_size[0]/16), int(image_size[0]/8)])), # 32, 16, 8
            dropout=cfg.get('dropout_unet',0), # default is 0.1
            channel_mult=cfg.get('dim_mults',[1, 2, 4, 8]),
            conv_resample=True,
            dims=2,
            num_classes=None,
            use_checkpoint=True,
            use_fp16=True,
            num_heads=cfg.get('num_heads',1),
            num_head_channels=64,
            num_heads_upsample=-1,
            use_scale_shift_norm=True,
            resblock_updown=True,
            use_new_attention_order=True,
            use_spatial_transformer=False,    
            transformer_depth=1,              
        )
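        # cast the UNet weights to half precision (matches use_fp16=True above)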
        model.convert_to_fp16()
        

        timesteps = cfg.get('timesteps',1000)
        self.test_timesteps = cfg.get('test_timesteps',150) 
        sampling_timesteps = cfg.get('sampling_timesteps',self.test_timesteps)
    
        self.diffusion = GaussianDiffusion(
            model,
            image_size = image_size, # only important when sampling
            timesteps = timesteps,   # number of steps
            sampling_timesteps = sampling_timesteps,
            objective = cfg.get('objective','pred_x0'), # pred_noise or pred_x0
            channels = 1,
            loss_type = cfg.get('loss','l1'),    # L1 or L2
            p2_loss_weight_gamma = cfg.get('p2_gamma',0),
            inpaint = cfg.get('inpaint',False),
            cfg=cfg
        )
        
        self.boxes = BoxSampler(cfg) # initialize box sampler

        self.prefix = prefix
        
        self.save_hyperparameters()


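    # forward() is unused; the training/validation/test steps call self.diffusion directly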
    def forward(self):
        return None


    def training_step(self, batch, batch_idx: int):
        # process batch
        input = batch["image"]

        # generate bboxes for DDPM 
        if self.cfg.get('grid_boxes',True): # sample boxes from a grid
            bbox = torch.zeros([input.shape[0],4],dtype=int)
            bboxes = self.boxes.sample_grid(input)
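            # pick one random grid cell per sample for this training step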
            ind = torch.randint(0,bboxes.shape[1],(input.shape[0],))
            for j in range(input.shape[0]):
                bbox[j] = bboxes[j,ind[j]]
            bbox = bbox.unsqueeze(-1)
        else: # sample boxes randomly
            bbox = self.boxes.sample_single_box(input)

        # generate noise
        if self.cfg.get('noisetype') is not None:
            noise = gen_noise(self.cfg, input.shape).to(self.device)
        else: 
            noise = None
        # reconstruct
        loss, reco = self.diffusion(input, box=bbox,noise=noise)

        self.log(f'train/loss', loss, prog_bar=False, on_step=False, on_epoch=True, batch_size=input.shape[0],sync_dist=True)
        return {"loss": loss}
    
    def validation_step(self, batch: Any, batch_idx: int):
        # input = batch['vol'][tio.DATA].squeeze(-1) 
        input = batch["image"]
        # generate bboxes for DDPM 
        if self.cfg.get('grid_boxes',False): # sample boxes from a grid
            bbox = torch.zeros([input.shape[0],4],dtype=int)
            bboxes = self.boxes.sample_grid(input)
            ind = torch.randint(0,bboxes.shape[1],(input.shape[0],))
            for j in range(input.shape[0]):
                bbox[j] = bboxes[j,ind[j]]
            bbox = bbox.unsqueeze(-1)
        else:  # sample boxes randomly
            bbox = self.boxes.sample_single_box(input)

        # generate noise
        if self.cfg.get('noisetype') is not None:
            noise = gen_noise(self.cfg, input.shape).to(self.device)
        else: 
            noise = None

        loss, reco = self.diffusion(input, box=bbox, noise=noise)


        self.log(f'val/loss_comb', loss, prog_bar=False, on_step=False, on_epoch=True, batch_size=input.shape[0],sync_dist=True)
        return {"loss": loss}

    def on_test_start(self):
        self.metrics_dict = get_eval_metrics_dictionary()
        # self.eval_dict = get_eval_dictionary()
        # self.inds = []
        # self.latentSpace_slice = []
        # self.new_size = [160,190,160]
        # self.diffs_list = []
        # self.seg_list = []
        # if not hasattr(self,'threshold'):
            # self.threshold = {}

    def test_step(self, batch: Any, batch_idx: int):
        # self.dataset = batch['Dataset']
        input = batch["image"] # 1chw 
        mask = batch["mask"].unsqueeze(1).expand(-1,3,-1,-1) # 1hw


        # if self.cfg.get('num_eval_slices', input.size(4)) != input.size(4):
            # num_slices = self.cfg.get('num_eval_slices', input.size(4))  # number of center slices to evaluate. If not set, the whole Volume is evaluated
            # start_slice = int((input.size(4) - num_slices) / 2)
            # input = input[...,start_slice:start_slice+num_slices]
        #     # data_orig = data_orig[...,start_slice:start_slice+num_slices] 
        #     # data_seg = data_seg[...,start_slice:start_slice+num_slices]
        #     # data_mask = data_mask[...,start_slice:start_slice+num_slices]
        #     ind_offset = start_slice
        # else: 
        #     ind_offset = 0 

        # final_volume = torch.zeros([input.size(2), input.size(3), input.size(4)], device = self.device)

        # reorder depth to batch dimension
        assert input.shape[0] == 1, "Batch size must be 1"
        # input = input.squeeze(0).permute(3,0,1,2) # [B,C,H,W,D] -> [D,C,H,W]
        # input = input.squeeze(0).permute(1,2,0)
        
        # latentSpace.append(torch.tensor([0],dtype=float).repeat(input.shape[0])) # dummy latent space 

        # generate bboxes for DDPM 
        bbox = self.boxes.sample_grid(input)
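        # buffer for the stitched reconstruction assembled from the grid patches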
        reco_patched = torch.zeros_like(input)

        
        # generate noise
        if self.cfg.get('noisetype') is not None:
            noise = gen_noise(self.cfg, input.shape).to(self.device)
        else: 
            noise = None

        # use tiler 
        # tiles = self.tiler.tile(input)
        # loss, reco = self.diffusion(tiles, box=bbox, noise=noise)

        # reconstruct each of the 4 grid patches separately and stitch them into reco_patched
        for k in range(bbox.shape[1]):
            box = bbox[:,k]
            # reconstruct
            loss_diff, reco = self.diffusion(input,t=self.test_timesteps-1, box=box,noise=noise)

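            # if the diffusion output has two channels, keep only the first (reconstruction) channel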
            if reco.shape[1] == 2:
                reco = reco[:,0:1,:,:]
    
            for j in range(reco_patched.shape[0]): 
                if self.cfg.get('overlap',False): # cut out the overlap region
                    grid = self.boxes.sample_grid_cut(input)
                    box_cut = grid[:,k]
                    if self.cfg.get('agg_overlap','cut') == 'cut': # cut out the overlap region
                        reco_patched[j,:,box_cut[j,1]:box_cut[j,3],box_cut[j,0]:box_cut[j,2]] = reco[j,:,box_cut[j,1]:box_cut[j,3],box_cut[j,0]:box_cut[j,2]]
                    elif self.cfg.get('agg_overlap','cut') == 'avg': # average the overlap region
                        reco_patched[j,:,box[j,1]:box[j,3],box[j,0]:box[j,2]] = reco_patched[j,:,box[j,1]:box[j,3],box[j,0]:box[j,2]] + reco[j,:,box[j,1]:box[j,3],box[j,0]:box[j,2]]
                else:
                    reco_patched[j,:,box[j,1]:box[j,3],box[j,0]:box[j,2]] = reco[j,:,box[j,1]:box[j,3],box[j,0]:box[j,2]]


            if self.cfg.get('overlap',False) and self.cfg.get('agg_overlap','cut') == 'avg': # average the intersection of all patches
                # count how often each pixel is covered (separate tensor so the ground-truth mask is not overwritten)
                overlap_count = torch.zeros_like(reco_patched)
                for m in range(bbox.shape[1]):
                    box_m = bbox[:,m]
                    for l in range(overlap_count.shape[0]):
                        overlap_count[l,:,box_m[l,1]:box_m[l,3],box_m[l,0]:box_m[l,2]] = overlap_count[l,:,box_m[l,1]:box_m[l,3],box_m[l,0]:box_m[l,2]] + 1
                # divide by the coverage count to average the intersection of all patches
                reco_patched = reco_patched/overlap_count

            reco = reco_patched.clone()

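        # per-pixel anomaly map: absolute difference between input and reconstruction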
        recon = reco.clone().squeeze().permute(1,2,0)
        input = input.squeeze().permute(1,2,0)
        diff = torch.abs(input - recon)
        
        recon = recon.cpu().numpy() # H,W,C after the squeeze/permute above
        recon = (recon*255).astype("uint8")

        # flatten the anomaly map in C,H,W order so it lines up with the flattened mask
        AUC, _fpr, _tpr, _threshs = compute_roc(diff.permute(2,0,1).cpu().flatten(), np.array(mask[0].cpu().flatten()).astype(bool))
        AUPRC, _precisions, _recalls, _threshs = compute_prc(diff.permute(2,0,1).cpu().flatten(), np.array(mask[0].cpu().flatten()).astype(bool))

        self.metrics_dict['AUROC'].append(AUC)
        self.metrics_dict['AUPRC'].append(AUPRC)
        self.logger.experiment.log({"test/recon": wandb.Image(recon)})

        
        # AnomalyScoreComb.append(loss_diff.cpu())
        # AnomalyScoreReg.append(AnomalyScoreComb) # dummy
        # AnomalyScoreReco.append(AnomalyScoreComb) # dummy

        # # reassamble the reconstruction volume
        # final_volume = reco.clone().squeeze()
        # final_volume = final_volume.permute(1,2,0) # to HxWxD
       

        # # average across slices to get volume-based scores
        # self.latentSpace_slice.extend(latentSpace)
        # self.eval_dict['latentSpace'].append(torch.mean(torch.stack(latentSpace),0))

        # AnomalyScoreComb_vol = np.mean(AnomalyScoreComb) 
        # AnomalyScoreReg_vol = AnomalyScoreComb_vol # dummy
        # AnomalyScoreReco_vol = AnomalyScoreComb_vol # dummy

        # self.eval_dict['AnomalyScoreRegPerVol'].append(AnomalyScoreReg_vol)


        # if not self.cfg.get('use_postprocessed_score', True):
        #     self.eval_dict['AnomalyScoreRecoPerVol'].append(AnomalyScoreReco_vol)
        #     self.eval_dict['AnomalyScoreCombPerVol'].append(AnomalyScoreComb_vol)


        # final_volume = final_volume.unsqueeze(0)
        # final_volume = final_volume.unsqueeze(0)

        # # calculate metrics
        # _test_step(self, final_volume, data_orig, data_seg, data_mask, batch_idx, ID, label) 

           
    # def on_test_end(self) :
    #     # calculate metrics
    #     _test_end(self) # everything that is independent of the model choice 
    def on_test_end(self):
        self.metrics_dict['AUROC'] = np.mean(self.metrics_dict['AUROC'])
        self.metrics_dict['AUPRC'] = np.mean(self.metrics_dict['AUPRC'])
        print('AUROC:', self.metrics_dict['AUROC'])
        print('AUPRC:', self.metrics_dict['AUPRC'])

    def configure_optimizers(self):
        return optim.Adam(self.parameters(), lr=self.cfg.lr)
    
    def update_prefix(self, prefix):
        self.prefix = prefix 

This is the DDPM_2D_patched.py. The changes I have made so far are to image_size (I didn't rescale the images), and in test_step the num_eval_slices part is commented out so the entire image gets passed in as-is, because the images I deal with are ordinary BCHW tensors, unlike the medical volumes with an additional slice dimension. The underlying UNet and GaussianDiffusion classes and configs are left untouched. The overall idea was to make the codebase accommodate ordinary non-medical ("normal") images.

At first glance, I cannot see an error. Does the training work, i.e. is the loss decreasing?

Yes, it did, which is why the resulting image baffles me a bit. But it makes sense to comment out the num_eval_slices part, right? Since that section is specific to the volumetric medical data, where the extra slice dimension has to be handled? (I haven't worked with that kind of data before, newbie on that.)

Have you checked whether the model checkpoint gets loaded properly (given that you are re-evaluating)? Also, maybe debug the evaluation and look at the input and output directly at the reconstruction step.
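For the checkpoint check, something along these lines could help. It is only a rough sketch: the checkpoint path is a placeholder, cfg is your existing config object, and it assumes a standard Lightning checkpoint that stores the weights under a 'state_dict' key.

# rough sanity check that the trained weights actually end up in the model
import torch

ckpt = torch.load("path/to/your_checkpoint.ckpt", map_location="cpu")
model = DDPM_2D(cfg)                                     # freshly initialized weights
before = {k: v.clone() for k, v in model.state_dict().items()}
missing, unexpected = model.load_state_dict(ckpt["state_dict"], strict=False)
print("missing keys:", missing)
print("unexpected keys:", unexpected)
changed = sum(not torch.equal(before[k], v) for k, v in model.state_dict().items())
print(f"{changed} of {len(before)} parameter tensors changed after loading")

If few or no tensors change, the checkpoint is not being applied. Inside test_step you could also dump the tensors right after the diffusion call, e.g. torchvision.utils.save_image(reco, 'reco_patch.png') and the same for input, to see whether the denoising itself fails or only the stitching/plotting.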

closing this (stale)