Image 2 Image Translation
JigneshChowdary opened this issue · 19 comments
Hi, can you provide the training code for paired image-to-image translation?
Hi, the authors already provide paired img2img translation! If you want to train general img2img translation, add the `--cond-x1` option.
Thank you.
Thanks. What should the corrupt option be, then?
If you already have a paired dataset, you can set the corrupt option to `mixture` and then skip the corruption process. For more information, please check issue #2. It will be helpful!
Thanks, but if we use `mixture`, doesn't it pick a random method from the following (see mixture.py)?
```python
JPEG_5 = 0
JPEG_10 = 1
BLUR_UNI = 2
BLUR_GAUSS = 3
SR4X_POOL = 4
SR4X_BICUBIC = 5
INPAINT_CENTER = 6
INPAINT_FREE1020 = 7
INPAINT_FREE2030 = 8
```
Yeah, but it can be ignored.
In train.py, you can see:
```python
if opt.corrupt == "mixture":
    import corruption.mixture as mix
    train_dataset = mix.MixtureCorruptDatasetTrain(opt, train_dataset)
    val_dataset = mix.MixtureCorruptDatasetVal(opt, val_dataset)
```
The `mixture` method is only used in the `MixtureCorruptDatasetTrain` and `MixtureCorruptDatasetVal` classes. Just skip this code, and the random methods you mentioned will not be applied. Further below, there is:
```python
corrupt_method = build_corruption(opt, log)
```
The `build_corruption` function does handle `mixture`, but it returns no corruption method for it.
Thanks. Can we do the same for an unpaired dataset?
The authors state in the paper that their method does not work on unpaired datasets; this is one of its listed limitations. If you want to train on an unpaired dataset, you can try the Dual Diffusion Implicit Bridges (DDIB) code.
Thanks.
@cychoi97 Sorry to bother you, but did you get translation results and visualize them? I already have a paired dataset with the folder structure below, where folder A contains the images to be translated (input) and folder B contains the translated images (target).
```
dataset/
├── train/
│   ├── A/
│   └── B/
└── val/
    ├── A/
    └── B/
```
I converted the dataset/ folder to LMDB through this method and got the files 'train_faster_imagefolder.lmdb.pt', 'train.lmdb', 'val_faster_imagefolder.lmdb.pt' and 'val.lmdb'.
When running sample.py, I found that the program takes the files in both folder A and folder B as input (i.e., corrupt_img), while I want only A to be the input and B to be the output.
Do you have any idea how to solve this problem?
@Badw0lf613 Hi. Unfortunately, I trained on medical images that are protected for patient privacy, so I can't show the results right now. Also, I didn't use LMDB because I'm not used to it. Instead, I wrote a custom dataloader based on `torch.utils.data.Dataset`, inspired by StyleGAN2-ADA. Whichever dataloader you use doesn't matter, though: with `mixture` as the corrupt method, the dataloader returns `clean_img`, `corrupt_img` and a label.
In sample.py:
```python
def compute_batch(ckpt_opt, corrupt_type, corrupt_method, out):
    if "inpaint" in corrupt_type:
        clean_img, y, mask = out
        corrupt_img = clean_img * (1. - mask) + mask
        x1 = clean_img * (1. - mask) + mask * torch.randn_like(clean_img)
    elif corrupt_type == "mixture":
        clean_img, corrupt_img, y = out
        x1 = corrupt_img.to(opt.device)
        mask = None
    else:
        clean_img, y = out
        mask = None
        corrupt_img = corrupt_method(clean_img.to(opt.device))
        x1 = corrupt_img.to(opt.device)

    cond = x1.detach() if ckpt_opt.cond_x1 else None
    if ckpt_opt.add_x1_noise:  # only for decolor
        x1 = x1 + torch.randn_like(x1)

    return corrupt_img, x1, mask, cond, y
```
`compute_batch` doesn't return `clean_img` at all. Even with the other corrupt methods, `clean_img` is only used to create `corrupt_img`. So I guess you don't have to worry about it.
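Since the `mixture` branch simply unpacks `(clean_img, corrupt_img, y)` from the dataloader, it can help to check one batch before training. The helper below is a hypothetical sketch, not part of I2SB. It assumes images scaled to [-1, 1] (as in the transforms used in this thread) and that source and target must have the same shape, because the bridge runs between `clean_img` (x0) and `corrupt_img` (x1).

```python
import torch


def check_mixture_batch(out):
    """Hypothetical helper: validate one DataLoader batch before it reaches
    compute_batch(..., corrupt_type="mixture", ...)."""
    if len(out) != 3:
        raise ValueError("expected a (clean_img, corrupt_img, y) tuple")
    clean_img, corrupt_img, y = out
    # The bridge runs between clean_img (x0) and corrupt_img (x1), so shapes must match
    if clean_img.shape != corrupt_img.shape:
        raise ValueError(f"shape mismatch: {tuple(clean_img.shape)} vs {tuple(corrupt_img.shape)}")
    # Both images are expected to be scaled to [-1, 1]
    for name, img in (("clean_img", clean_img), ("corrupt_img", corrupt_img)):
        if img.min() < -1 or img.max() > 1:
            raise ValueError(f"{name} is outside [-1, 1]")
    batch = clean_img.shape[0]
    # Accuracy evaluation in runner.py expects exactly one label per sample
    labels_ok = tuple(y.shape) == (batch,)
    return {"batch": batch, "channels": clean_img.shape[1], "labels_ok": labels_ok}
```

For example, `check_mixture_batch(next(iter(train_loader)))` on a hypothetical `train_loader` would flag mismatched pairs or unscaled images early.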
Thanks for the info so far @cychoi97! I am currently writing a custom dataloader for I2SB too. Could you share a bit more about which sections you changed so that a folder of images in pix2pix format loads correctly?
The training code relies on the `build_lmdb_dataset` function, which is exactly what we want to avoid.
@jurriandoornbos Sorry for the late reply!
I changed `train.py` as follows:
```python
# before
train_dataset = imagenet.build_lmdb_dataset(opt, log, train=True)
val_dataset = imagenet.build_lmdb_dataset(opt, log, train=False)

# after (your own dataset)
train_dataset = dataset.Dataset(opt, log, train=True)
val_dataset = dataset.Dataset(opt, log, train=False)
```
and simply skip this code if you already have corrupted images:
```python
if opt.corrupt == "mixture":
    import corruption.mixture as mix
    train_dataset = mix.MixtureCorruptDatasetTrain(opt, train_dataset)
    val_dataset = mix.MixtureCorruptDatasetVal(opt, val_dataset)
```
Also, I added one line of code to `network.py` as follows:
```python
# before
kwargs["use_fp16"] = use_fp16

# after
kwargs["use_fp16"] = use_fp16
kwargs["in_channels"] = 4  # set to (input channels + condition channels) for your data
```
When using `--cond-x1` for img2img translation, the input is concatenated with the condition x1.
So with 3-channel images, in_channels should be 6.
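The channel arithmetic can be checked with a small sketch. It assumes, as described above, that x1 is concatenated with the network input along the channel dimension (dim=1 in NCHW layout) and that x1 has the same number of channels as the input; `required_in_channels` is a hypothetical helper, not part of I2SB.

```python
import torch


def required_in_channels(image_channels, cond_x1):
    """Hypothetical helper: UNet in_channels when x1 may be concatenated as a condition.
    Assumes the condition x1 has the same channel count as the input image."""
    return image_channels * 2 if cond_x1 else image_channels


# Simulate the concatenation for a batch of 3-channel images
x_t = torch.randn(4, 3, 64, 64)   # network input
cond = torch.randn(4, 3, 64, 64)  # condition x1, same shape as the input
net_input = torch.cat([x_t, cond], dim=1)
print(net_input.shape)  # torch.Size([4, 6, 64, 64])
```

Setting `in_channels` to anything other than the concatenated channel count would make the first convolution fail with a shape mismatch.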
I hope it will be helpful!
Thanks for these details! Could you perhaps also share what the code of `dataset.Dataset` looks like in your case?
Additionally, I understand that `mix.MixtureCorruptDatasetTrain`/`Val` returns `clean_img`, `corrupt_img` and `y`. In the context of style transfer/img2img, what do `clean_img` and `y` mean? I am used to styleA and styleB from pix2pix. `corrupt_img` is created by the `mixture` function.
In my case, I already have the corrupted images, so I wrote the dataset code as follows:
```python
import os
from pathlib import Path

import cv2
import numpy as np
import pydicom
from PIL import Image
from torch.utils.data import Dataset as TorchDataset
from torchvision import transforms as T

EXTENSION = ['.jpg', '.jpeg', '.png', '.tiff', '.dcm', '.dicom', '.nii']
DICOM_EXTENSIONS = ('.dcm', '.dicom')


class Dataset(TorchDataset):
    """Paired dataset: <dataset_dir>/{train,test}/{corrupt,clean}/ with matching file names."""

    def __init__(self, opt, log, train):
        super().__init__()
        self.dataset_dir = Path(opt.dataset_dir) / ('train' if train else 'test')
        self.corrupt_dir = self.dataset_dir / 'corrupt'
        self.clean_dir = self.dataset_dir / 'clean'
        self.image_size = opt.image_size

        if not os.path.isdir(self.corrupt_dir):
            raise IOError('corrupt path must point to a directory')
        if not os.path.isdir(self.clean_dir):
            raise IOError('clean path must point to a directory')

        self.corrupt_image_fnames = self._list_images(self.corrupt_dir)
        if len(self.corrupt_image_fnames) == 0:
            raise IOError('No corrupt image files found in the specified path')
        self.clean_image_fnames = self._list_images(self.clean_dir)
        if len(self.clean_image_fnames) == 0:
            raise IOError('No clean image files found in the specified path')
        # Pairs are matched by sorted position, so both folders must hold the same number of files
        if len(self.corrupt_image_fnames) != len(self.clean_image_fnames):
            raise IOError('corrupt and clean folders contain different numbers of images')

        self.transform = T.Compose([
            T.ToTensor(),                    # (H, W, 1) float32 array -> (1, H, W) tensor, values unchanged
            T.Lambda(lambda t: (t * 2) - 1)  # [0, 1] --> [-1, 1]
        ])

        log.info(f"[Dataset] Built dataset {self.corrupt_dir=}, size={len(self.corrupt_image_fnames)}!")
        log.info(f"[Dataset] Built dataset {self.clean_dir=}, size={len(self.clean_image_fnames)}!")

    @staticmethod
    def _file_ext(fname):
        return os.path.splitext(fname)[1].lower()

    def _list_images(self, directory):
        # Sorted relative paths of all supported files under `directory`
        fnames = {os.path.relpath(os.path.join(root, fname), start=directory)
                  for root, _dirs, files in os.walk(directory) for fname in files}
        return sorted(f for f in fnames if self._file_ext(f) in EXTENSION)

    def _open_file(self, fname):
        return open(os.path.join(self.dataset_dir, fname), 'rb')

    def _padding(self, img):
        # Zero-pad a 2D (grayscale) image to a square, keeping the content centered
        h, w = img.shape
        if h == w:
            return img
        diff = abs(h - w)
        before, after = diff // 2, diff - diff // 2
        pad = ((0, 0), (before, after)) if h > w else ((before, after), (0, 0))
        return np.pad(img, pad, mode='constant')

    def _resize(self, img):
        if img.shape != (self.image_size, self.image_size):
            # Pass interpolation by keyword: the third positional argument of cv2.resize is `dst`
            img = cv2.resize(img, (self.image_size, self.image_size), interpolation=cv2.INTER_AREA)
        return img

    def _clip_and_normalize(self, img, vmin, vmax):
        img = np.clip(img, vmin, vmax)
        return (img - vmin) / (vmax - vmin)

    def _CT_preprocess(self, dcm, img, window_width=None, window_level=None):
        # Convert stored pixel values to Hounsfield units
        img = img * dcm.RescaleSlope + dcm.RescaleIntercept
        if window_width is not None and window_level is not None:
            vmin = window_level - (window_width / 2.0)
            vmax = window_level + (window_width / 2.0)
        else:  # full 12-bit range
            vmin, vmax = -1024.0, 3071.0
        img = self._padding(img)
        img = self._resize(img)
        return self._clip_and_normalize(img, vmin, vmax)

    def _load(self, subdir, fname, window_width, window_level):
        with self._open_file(os.path.join(subdir, fname)) as f:
            # Compare against both extensions explicitly (`== '.dcm' or '.dicom'` is always true)
            if self._file_ext(fname) in DICOM_EXTENSIONS:
                dcm = pydicom.dcmread(f, force=True)
                img = dcm.pixel_array.astype(np.float32)
                img = self._CT_preprocess(dcm, img, window_width, window_level)
            else:  # jpg, jpeg, tiff, png, etc.: grayscale, scaled to [0, 1]
                img = np.array(Image.open(f).convert('L'), dtype=np.float32) / 255.0
                img = self._resize(img)
        return img

    def __len__(self):
        return len(self.clean_image_fnames)

    def __getitem__(self, index):
        corrupt_fname = self.corrupt_image_fnames[index]
        clean_fname = self.clean_image_fnames[index]
        # CT window settings are specific to my data; adjust them for yours
        corrupt_img = self._load('corrupt', corrupt_fname, window_width=924, window_level=-562)
        clean_img = self._load('clean', clean_fname, window_width=50, window_level=25)
        # (H, W) -> (H, W, 1) so that ToTensor yields a (1, H, W) tensor
        corrupt_img = self.transform(corrupt_img[:, :, np.newaxis])
        clean_img = self.transform(clean_img[:, :, np.newaxis])
        # y is not used for img2img, so clean_img is returned in its place
        return clean_img, corrupt_img, clean_img
```
I just return `clean_img` as `y` because I don't use `y` at all. `y` is the class label, used for accuracy evaluation with a ResNet classifier. In img2img translation, `clean_img` is the "target" domain and `corrupt_img` is the "source" domain.
Also, I used pydicom for medical data, so please adapt the preprocessing to your own data. Thanks!
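The Dataset above pairs `corrupt_image_fnames[index]` with `clean_image_fnames[index]` by sorted position, so a single missing or extra file shifts every later pair. A possible alternative, sketched below under the assumption that each corrupt/clean pair shares the same relative path, is to pair files by name. `paired_relpaths` is a hypothetical helper, not part of I2SB.

```python
import os

IMAGE_EXTENSIONS = {'.jpg', '.jpeg', '.png', '.tiff', '.dcm', '.dicom', '.nii'}


def paired_relpaths(corrupt_dir, clean_dir, extensions=IMAGE_EXTENSIONS):
    """Hypothetical helper: sorted relative paths that exist in BOTH folders."""
    def scan(directory):
        found = set()
        for root, _dirs, files in os.walk(directory):
            for fname in files:
                if os.path.splitext(fname)[1].lower() in extensions:
                    found.add(os.path.relpath(os.path.join(root, fname), start=directory))
        return found

    corrupt, clean = scan(corrupt_dir), scan(clean_dir)
    unmatched = sorted(corrupt ^ clean)  # files present in only one of the folders
    if unmatched:
        print(f"warning: {len(unmatched)} unpaired file(s), e.g. {unmatched[:3]}")
    return sorted(corrupt & clean)
```

In `__init__`, one could then store `self.fnames = paired_relpaths(self.corrupt_dir, self.clean_dir)` and, in `__getitem__`, open `os.path.join('corrupt', name)` and `os.path.join('clean', name)` for the same `name`, so each pair stays aligned even if a file is missing.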
You are a saint! This helps me out tremendously.
Getting closer step by step with your help! However, I am now facing a new issue when the model runs evaluation after the first iteration.
It fails with an `AssertionError` in runner.py, line 283: `assert y.shape == (batch,)`
Oh, I'm sorry I didn't tell you about that. I didn't compute accuracy, so I simply skipped the following code:
```python
# line 283
assert y.shape == (batch,)

# line 289~292
def log_accuracy(tag, img):
    pred = self.resnet(img.to(opt.device))  # input range [-1,1]
    accu = self.accuracy(pred, y.to(opt.device))
    self.writer.add_scalar(it, tag, accu)

# line 313~316
log.info("Logging accuracies ...")
log_accuracy("accuracy/clean", img_clean)
log_accuracy("accuracy/corrupt", img_corrupt)
log_accuracy("accuracy/recon", img_recon)
```
If you want to calculate accuracy, you must pass the correct labels as `y`!
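As an alternative to deleting the assertion, a dataset could return a dummy integer label per sample instead of `clean_img`; PyTorch's default collate function stacks these 0-dim tensors into a tensor of shape `(batch,)`, so `assert y.shape == (batch,)` passes. This is a sketch under that assumption: `DummyLabelPairs` is a hypothetical stand-in for the real paired dataset, and because the labels carry no meaning, `log_accuracy` should still stay disabled.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class DummyLabelPairs(Dataset):
    """Hypothetical stand-in for a paired dataset that returns a scalar label as y."""

    def __init__(self, n=6, shape=(1, 16, 16)):
        self.n, self.shape = n, shape

    def __len__(self):
        return self.n

    def __getitem__(self, index):
        clean_img = torch.zeros(self.shape)
        corrupt_img = torch.ones(self.shape)
        y = torch.tensor(0, dtype=torch.long)  # 0-dim label; collates to shape (batch,)
        return clean_img, corrupt_img, y


loader = DataLoader(DummyLabelPairs(), batch_size=4)
clean_img, corrupt_img, y = next(iter(loader))
batch = clean_img.shape[0]
assert y.shape == (batch,)  # the same check as in runner.py
print(y.shape)  # torch.Size([4])
```

In the Dataset shared above, the corresponding change would be returning `clean_img, corrupt_img, torch.tensor(0)` from `__getitem__`.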
@jurriandoornbos Hello, I am working on the same task as you, also with data in pix2pix format. I would like to ask about your experimental results with this method: are they good?
Looking forward to your answer. Thank you.
OK, thank you~