
How to train your inpainting model on my own dataset?

dreamlychina opened this issue · 1 comment

Thanks for sharing this amazing work! I want to train your inpainting model on my own dataset. At your convenience, could you show me a training script and explain how to prepare the data?

```python
import os

from imagen_pytorch import Unet, Imagen, ImagenTrainer
from imagen_pytorch.data import Dataset

# unets for unconditional imagen

unet = Unet(
    dim = 32,
    dim_mults = (1, 2, 4, 8),
    num_resnet_blocks = 1,
    layer_attns = (False, False, False, True),
    layer_cross_attns = False
)

# imagen, which contains the unet above

imagen = Imagen(
    condition_on_text = False,  # this must be set to False for unconditional Imagen
    unets = unet,
    image_sizes = 256,
    timesteps = 1000
)

trainer = ImagenTrainer(
    imagen = imagen,
    split_valid_from_train = True  # whether to split the validation dataset from the training
)

# instantiate your dataset, which returns the necessary inputs to the DDPM as a tuple in the order
# images, text embeddings, then text masks. in this case only images are returned, as it is unconditional training

dataset = Dataset('/content/drive/MyDrive/unconditional_generation/dataset_256', image_size = 256)

trainer.add_train_dataset(dataset, batch_size = 16)

# folder for sampled images (placeholder path, change as needed)
output_path = './samples'
os.makedirs(output_path, exist_ok = True)

# working training loop

for i in range(20000):
    loss = trainer.train_step(unet_number = 1, max_batch_size = 4)
    print(f'loss: {loss}')

    if not (i % 50):
        valid_loss = trainer.valid_step(unet_number = 1, max_batch_size = 4)
        print(f'valid loss: {valid_loss}')

    if not (i % 100) and trainer.is_main:  # is_main makes sure this can run in distributed
        images = trainer.sample(batch_size = 1, return_pil_images = True)  # returns List[Image]
        images[0].save(f'{output_path}/{i // 100}.png')
```
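The loop's `if not (i % 50)` and `if not (i % 100)` checks fire whenever the step index is a multiple of the interval, including step 0, so a validation pass and a sample both run before any real training has happened. A minimal pure-Python sketch that replays only those two checks (no model involved; `training_schedule` is a hypothetical helper, not part of imagen-pytorch) shows what a 20,000-step run produces:

```python
def training_schedule(num_steps, valid_every=50, sample_every=100):
    """Replay the loop's `not (i % n)` checks without touching a model."""
    # steps at which trainer.valid_step would run
    valid_steps = [i for i in range(num_steps) if not (i % valid_every)]
    # file names the sampling branch would write, using the same i // n naming
    sample_files = [f'{i // sample_every}.png'
                    for i in range(num_steps) if not (i % sample_every)]
    return valid_steps, sample_files


valid_steps, sample_files = training_schedule(20000)
print(len(valid_steps), valid_steps[:3])                     # 400 [0, 50, 100]
print(len(sample_files), sample_files[0], sample_files[-1])  # 200 0.png 199.png
```

So the run validates 400 times and writes 200 samples, `0.png` through `199.png`, into `output_path`.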

This is the training code for your custom dataset (unconditional, so only images are needed).
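On the data-preparation part of the question: the `Dataset` in the snippet takes a folder path and an `image_size`, and if the path is wrong or the files use an unexpected extension it can silently end up with no images. Below is a stdlib-only sanity check to run before training. It assumes the dataset collects files recursively by extension using `jpg`, `jpeg`, `png` and `tiff`, which matches the `imagen_pytorch.data.Dataset` version I have seen; confirm the list in your installed version. `find_training_images` and `ASSUMED_EXTS` are hypothetical names, not part of imagen-pytorch.

```python
from pathlib import Path

# ASSUMPTION: mirrors the extension list and recursive glob of the
# imagen_pytorch.data.Dataset version I have seen; verify against yours.
ASSUMED_EXTS = ('jpg', 'jpeg', 'png', 'tiff')


def find_training_images(folder, exts=ASSUMED_EXTS):
    """Return sorted image paths under `folder` that a glob-based dataset would pick up."""
    root = Path(folder)
    if not root.is_dir():
        raise FileNotFoundError(f'dataset folder not found: {folder}')
    # Matching is case-sensitive on Linux: 'photo.JPG' is NOT matched by '*.jpg'.
    return sorted(p for ext in exts for p in root.glob(f'**/*.{ext}'))
```

For example, `len(find_training_images('/content/drive/MyDrive/unconditional_generation/dataset_256'))` should be well above zero before you call `trainer.add_train_dataset`. If it is zero, check the path and whether your files use upper-case extensions.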