NVlabs/I2SB

Image 2 Image Translation

JigneshChowdary opened this issue · 19 comments

Hi, can you provide the training code for performing paired image-to-image translation?

Hi, they have already provided paired img2img translation! If you want to train a general img2img translation, pass the option --cond-x1.

Thank you.

Thanks, what should the corrupt option be then?

If you already have a paired dataset, you can set the corrupt option to mixture and then skip the corruption process. For more information, please check issue #2. It will be helpful!
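For example, the training invocation would look roughly like this (only --corrupt and --cond-x1 are confirmed in this thread; any other flags depend on your setup):

python train.py --corrupt mixture --cond-x1 [other options...]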

Thanks, but if we use corrupt, doesn't it pick a random method from the following (see mixture.py)?

JPEG_5 = 0
JPEG_10 = 1
BLUR_UNI = 2
BLUR_GAUSS = 3
SR4X_POOL = 4
SR4X_BICUBIC = 5
INPAINT_CENTER = 6
INPAINT_FREE1020 = 7
INPAINT_FREE2030 = 8

Yeah, but it can be ignored.

In train.py, you can see

if opt.corrupt == "mixture":
    import corruption.mixture as mix
    train_dataset = mix.MixtureCorruptDatasetTrain(opt, train_dataset)
    val_dataset = mix.MixtureCorruptDatasetVal(opt, val_dataset)

The random mixture method is only used in the MixtureCorruptDatasetTrain and MixtureCorruptDatasetVal classes.
Just skip this code, and the random method you mentioned will not be applied. And below,

corrupt_method = build_corruption(opt, log)

the build_corruption function has a mixture branch but returns nothing for it.
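Roughly, that branch behaves like this (paraphrased, not the exact repo code; check corruption/__init__.py for the real implementation):

def build_corruption(opt, log):
    if opt.corrupt == "mixture":
        # corruption is already baked into the Mixture*Dataset wrappers,
        # so no standalone corruption method is built here
        return None
    ...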

Thanks, can we do the same for an unpaired dataset?

The authors state in the paper that their method does not support unpaired datasets; they list it as one of its limitations. If you want to train on an unpaired dataset, you can try the Dual Diffusion Implicit Bridges (DDIB) code.

Thanks.

@cychoi97 Sorry to bother you, but did you get the translation results and visualize them? I already have a paired dataset, and the folder structure is as follows, where folder A contains the images to be translated (input) and folder B contains the translated images (target).
dataset/
├── train/
│   ├── A/
│   └── B/
├── val/
│   ├── A/
│   └── B/
I converted the dataset/ folder to lmdb through this method and got the files 'train_faster_imagefolder.lmdb.pt', 'train.lmdb', 'val_faster_imagefolder.lmdb.pt', and 'val.lmdb'.
When executing sample.py, I found the program takes files under both folder A and folder B as input (i.e. corrupt_img), while I just want A to be the input and B to be the output.
Do you have any idea how to solve this problem?

@Badw0lf613 Hi. Unfortunately, I trained on medical images that must be protected for patient privacy, so I can't show the results immediately. Also, I didn't use lmdb because I wasn't used to it; I made a custom dataloader using torch.utils.data.Dataset, inspired by stylegan2-ADA. But whichever dataloader you use, it doesn't matter: when using mixture as the corrupt method, the dataloader returns clean_img, corrupt_img, and a label.

In sample.py,

def compute_batch(ckpt_opt, corrupt_type, corrupt_method, out):
    if "inpaint" in corrupt_type:
        clean_img, y, mask = out
        corrupt_img = clean_img * (1. - mask) + mask
        x1          = clean_img * (1. - mask) + mask * torch.randn_like(clean_img)
    elif corrupt_type == "mixture":
        clean_img, corrupt_img, y = out
        x1 = corrupt_img.to(opt.device)
        mask = None
    else:
        clean_img, y = out
        mask = None
        corrupt_img = corrupt_method(clean_img.to(opt.device))
        x1 = corrupt_img.to(opt.device)

    cond = x1.detach() if ckpt_opt.cond_x1 else None
    if ckpt_opt.add_x1_noise: # only for decolor
        x1 = x1 + torch.randn_like(x1)

    return corrupt_img, x1, mask, cond, y

compute_batch doesn't return clean_img at all. Even if you use another corrupt_method, clean_img is only used for creating corrupt_img. So you don't have to worry about it, I guess.

Thanks for the info so far @cychoi97! I am currently working on writing a custom dataloader for I2SB too. Could you share a bit more about which sections you changed and altered so that a folder of images in pix2pix format loads correctly?

The training code relies on the build_lmdb_dataset function, which is of course the thing we want to avoid.

@jurriandoornbos Sorry for the late reply!

I changed train.py as follows:

# before
train_dataset = imagenet.build_lmdb_dataset(opt, log, train=True)
val_dataset   = imagenet.build_lmdb_dataset(opt, log, train=False)

# after (your own dataset)
train_dataset = dataset.Dataset(opt, log, train=True)
val_dataset   = dataset.Dataset(opt, log, train=False)

and just skip this code if you already have corrupted images:

if opt.corrupt == "mixture":
    import corruption.mixture as mix
    train_dataset = mix.MixtureCorruptDatasetTrain(opt, train_dataset)
    val_dataset = mix.MixtureCorruptDatasetVal(opt, val_dataset)

Also, I added one line in network.py as follows:

# before
kwargs["use_fp16"] = use_fp16

# after
kwargs["use_fp16"] = use_fp16
kwargs["in_channels"] = 4 # change in_channels with your input (input_channel + condition_channel)

When using --cond-x1 for img2img translation, the input is concatenated with the condition x1 along the channel dimension.
If you use 3-channel images, in_channels becomes 6.
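As a quick illustration of the channel arithmetic (the tensor names here are just for illustration):

import torch

xt = torch.randn(1, 3, 256, 256)    # noisy input image at step t
x1 = torch.randn(1, 3, 256, 256)    # condition (the corrupt/source image)
net_in = torch.cat([xt, x1], dim=1) # channel-wise concatenation
print(net_in.shape)                 # torch.Size([1, 6, 256, 256]) -> in_channels = 6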

I hope it will be helpful!

Thanks for these details! Could you perhaps also share what the code of dataset.Dataset looks like in your case?

Additionally, I understand that mix.MixtureCorruptDatasetTrain|Val returns clean_img, corrupt_img, and y. In the context of style transfer/img2img, what do clean_img and y mean? I am used to styleA and styleB from pix2pix, and corrupt_img is created by the mixture function.

In my case, I already have corrupt images, so I wrote the dataset code as follows:

import os
import cv2
import pydicom
import numpy as np

from PIL import Image
from pathlib import Path
from functools import partial

import torch
import torch.nn as nn
from torch.utils.data import Dataset
from torchvision import transforms as T


EXTENSION = ['.jpg', '.jpeg', '.png', '.tiff', '.dcm', '.dicom', '.nii']


class Dataset(Dataset):
    def __init__(
        self,
        opt,
        log,
        train
    ):
        super().__init__()
        self.dataset_dir = opt.dataset_dir / ('train' if train else 'test')
        self.corrupt_dir = self.dataset_dir / 'corrupt'
        self.clean_dir = self.dataset_dir / 'clean'
        self.image_size = opt.image_size

        if os.path.isdir(self.corrupt_dir):
            self.corrupt_fnames = {
                os.path.relpath(os.path.join(root, fname), start=self.corrupt_dir)
                for root, _dirs, files in os.walk(self.corrupt_dir)
                for fname in files
            }
        else:
            raise IOError('corrupt path must point to a directory')

        if os.path.isdir(self.clean_dir):
            self.clean_fnames = {
                os.path.relpath(os.path.join(root, fname), start=self.clean_dir)
                for root, _dirs, files in os.walk(self.clean_dir)
                for fname in files
            }
        else:
            raise IOError('clean path must point to a directory')
        
        self.corrupt_image_fnames = sorted(fname for fname in self.corrupt_fnames if self._file_ext(fname) in EXTENSION)
        if len(self.corrupt_image_fnames) == 0:
            raise IOError('No corrupt image files found in the specified path')
        
        self.clean_image_fnames = sorted(fname for fname in self.clean_fnames if self._file_ext(fname) in EXTENSION)
        if len(self.clean_image_fnames) == 0:
            raise IOError('No clean image files found in the specified path')

        self.transform = T.Compose([
            T.ToTensor(),
            T.Lambda(lambda t: (t * 2) - 1) # [0,1] --> [-1, 1]
        ])

        log.info(f"[Dataset] Built dataset {self.corrupt_dir=}, size={len(self.corrupt_image_fnames)}!")
        log.info(f"[Dataset] Built dataset {self.clean_dir=}, size={len(self.clean_image_fnames)}!")

    @staticmethod
    def _file_ext(fname):
        return os.path.splitext(fname)[1].lower()
    
    def _open_file(self, fname):
        return open(os.path.join(self.dataset_dir, fname), 'rb')

    def _padding(self, img):
        # zero-pad the shorter axis so the image becomes square;
        # pad sides separately so odd differences don't lose a pixel
        if img.shape[0] > img.shape[1]:
            diff = img.shape[0] - img.shape[1]
            pad_a = np.zeros((img.shape[0], diff // 2), np.float32)
            pad_b = np.zeros((img.shape[0], diff - diff // 2), np.float32)
            img = np.concatenate([pad_a, img, pad_b], 1)
        elif img.shape[0] < img.shape[1]:
            diff = img.shape[1] - img.shape[0]
            pad_a = np.zeros((diff // 2, img.shape[1]), np.float32)
            pad_b = np.zeros((diff - diff // 2, img.shape[1]), np.float32)
            img = np.concatenate([pad_a, img, pad_b], 0)
        return img
    
    def _resize(self, img):
        if img.shape != (self.image_size, self.image_size):
            # interpolation must be passed as a keyword argument;
            # the third positional argument of cv2.resize is dst
            img = cv2.resize(img, (self.image_size, self.image_size), interpolation=cv2.INTER_AREA)
        return img

    def _clip_and_normalize(self, img, vmin, vmax):
        img = np.clip(img, vmin, vmax)
        img = (img - vmin) / (vmax - vmin)
        return img

    def _CT_preprocess(self, dcm, img, window_width=None, window_level=None):
        # convert raw pixel values to Hounsfield units
        intercept = dcm.RescaleIntercept
        slope = dcm.RescaleSlope
        img = img * slope + intercept

        if window_width is not None and window_level is not None:
            vmin = window_level - (window_width / 2.0)
            vmax = window_level + (window_width / 2.0)
        else: # full 12-bit range
            vmin = -1024.0
            vmax = 3071.0

        img = self._padding(img)
        img = self._resize(img)
        img = self._clip_and_normalize(img, vmin, vmax)
        return img

    def __len__(self):
        return len(self.clean_image_fnames)

    def __getitem__(self, index):
        corrupt_fname = self.corrupt_image_fnames[index]
        clean_fname = self.clean_image_fnames[index]

        with self._open_file(os.path.join('corrupt', corrupt_fname)) as f:
            # note: `ext == '.dcm' or '.dicom'` is always truthy; use membership instead
            if self._file_ext(corrupt_fname) in ('.dcm', '.dicom'):
                corrupt_dcm = pydicom.dcmread(f, force=True)
                corrupt_img = corrupt_dcm.pixel_array.astype(np.float32)
                corrupt_img = self._CT_preprocess(corrupt_dcm, corrupt_img, window_width=924, window_level=-562)
            else: # jpg, jpeg, tiff, png, etc.
                corrupt_img = np.array(Image.open(f))

        with self._open_file(os.path.join('clean', clean_fname)) as f:
            if self._file_ext(clean_fname) in ('.dcm', '.dicom'):
                clean_dcm = pydicom.dcmread(f, force=True)
                clean_img = clean_dcm.pixel_array.astype(np.float32)
                clean_img = self._CT_preprocess(clean_dcm, clean_img, window_width=50, window_level=25)
            else: # jpg, jpeg, tiff, png, etc.
                clean_img = np.array(Image.open(f))

        corrupt_img = self.transform(corrupt_img[:, :, np.newaxis])
        clean_img = self.transform(clean_img[:, :, np.newaxis])

        # y is unused in my setup, so clean_img is returned as a placeholder for it
        return clean_img, corrupt_img, clean_img

and I just return clean_img as y, because I don't use y at all. y is the class label, used for accuracy evaluation with a ResNet. clean_img is the "target" domain and corrupt_img is the "source" domain in img2img translation.
Also, I used pydicom for medical data, so please take that into account when preprocessing your own data. Thanks!
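A quick way to sanity-check the dataset (opt is stubbed with only the two attributes the class reads; the path is hypothetical):

from pathlib import Path
from types import SimpleNamespace

opt = SimpleNamespace(dataset_dir=Path('dataset'), image_size=256)

class Log:  # minimal stand-in for the repo's logger
    def info(self, msg):
        print(msg)

ds = Dataset(opt, Log(), train=True)
clean_img, corrupt_img, y = ds[0]
print(clean_img.shape, corrupt_img.shape)  # both (1, 256, 256), values in [-1, 1]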

You are a saint! This helps me out tremendously.

Getting closer step by step with your help! Although I am facing a new issue when the model tries to evaluate after the first iteration:
it fails with an assertion error in runner.py line 283: assert y.shape == (batch,)

Oh, I'm sorry, I forgot to tell you about that. I didn't calculate accuracy, so I just skipped (commented out) the following code:

# line 283
assert y.shape == (batch,)

# line 289~292
def log_accuracy(tag, img):
    pred = self.resnet(img.to(opt.device)) # input range [-1,1]
    accu = self.accuracy(pred, y.to(opt.device))
    self.writer.add_scalar(it, tag, accu)
    
# line 313~316
log.info("Logging accuracies ...")
log_accuracy("accuracy/clean",   img_clean)
log_accuracy("accuracy/corrupt", img_corrupt)
log_accuracy("accuracy/recon",   img_recon)

If you want to calculate accuracy, you must give the correct label as y!
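For example, __getitem__ could return an integer class index instead of the clean_img placeholder (the snippet below is just a sketch; 0 stands in for your real class index):

# inside Dataset.__getitem__, replacing the placeholder return:
label = torch.tensor(0, dtype=torch.long)  # put your real class index here
return clean_img, corrupt_img, label       # after batching, y.shape == (batch,)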

@jurriandoornbos
Hello, I am doing the same work as you, also based on the pix2pix data format. May I ask about your experimental results with this method? Were they good?
Looking forward to your answer. Thank you.

OK, thank you~