fashn-AI/tryondiffusion

Training on the VITON-HD dataset using test_tryon_imagen_trainer.py

Closed this issue · 9 comments

I am using the VITON-HD dataset as a synthetic dataset to train with the test_tryon_imagen_trainer.py Python file. After training, it produces noise images as output. Can anyone suggest why this is happening and how to train the diffusion model with the VITON-HD dataset?

Hi @Teja414, if you share more about your code, maybe I can help.

First, right off the bat: did you change the number of training iterations? The number of sampling steps? The parameters in the example scripts are just for quick testing, not full-fledged training.

Another thing to mention is that this architecture is not optimal for VITON-HD. It is meant to take the garment from another person (hence the garment_pose input), so you will need to encode something in place of the garment pose (e.g., all 0.5) or remove that part of the network.
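For illustration, a minimal sketch of that constant-placeholder idea (the (18, 2) keypoint shape is an assumption taken from the dataset class further down):

import torch

# Hypothetical placeholder: feed a constant tensor in place of a real garment
# pose so the garment-pose branch still receives a well-formed input.
POSE_SIZE = (18, 2)  # (num_keypoints, xy) -- assumed, matching the script below
garment_pose = torch.full(POSE_SIZE, 0.5)

This keeps the network's input shapes intact while carrying no garment-pose information.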

This is the test_tryon_imagen_trainer.py code that I have updated and used for training on just a sample of 10, and for that sample of 10 I am attaching the output.

[attached output image: image_0]

import torch
from torch.utils.data import DataLoader, Dataset
import os
import json
from torchvision import transforms
from PIL import Image
import matplotlib.pyplot as plt

from tryondiffusion import TryOnImagen, get_unet_by_name
from tryon_imagen_trainer import TryOnImagenTrainer

TRAIN_UNET_NUMBER = 1
BASE_UNET_IMAGE_SIZE = (128, 128)
SR_UNET_IMAGE_SIZE = (256, 256)
BATCH_SIZE = 10
GRADIENT_ACCUMULATION_STEPS = 2
NUM_ITERATIONS = 20
TIMESTEPS = (2, 2)
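# NOTE: (2, 2) sampling timesteps and 20 iterations are smoke-test values only;
# real training needs far larger settings (see the recommendations further down).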

# Define the transformation to resize images to 128x128

image_transform = transforms.Compose([
transforms.Resize((128, 128)),
transforms.ToTensor()
])

# Define your dataset class

class SyntheticTryonDataset(Dataset):
    def __init__(self, data_dir, num_samples, image_size, pose_size=(18, 2)):
        self.data_dir = data_dir
        self.num_samples = num_samples
        self.image_size = image_size
        self.pose_size = pose_size

        # Get the list of image filenames
        self.person_image_files = sorted(os.listdir(os.path.join(data_dir, 'person_image')))
        self.ca_image_files = sorted(os.listdir(os.path.join(data_dir, 'ca_image')))
        self.garment_image_files = sorted(os.listdir(os.path.join(data_dir, 'garment_image')))
        self.person_pose_files = sorted(os.listdir(os.path.join(data_dir, 'person_pose')))
        self.garment_pose_files = sorted(os.listdir(os.path.join(data_dir, 'garment_pose')))

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        person_image_filename = self.person_image_files[idx]
        ca_image_filename = self.ca_image_files[idx]
        garment_image_filename = self.garment_image_files[idx]
        person_pose_filename = self.person_pose_files[idx]
        garment_pose_filename = self.garment_pose_files[idx]

        person_image_path = os.path.join(self.data_dir, 'person_image', person_image_filename)
        ca_image_path = os.path.join(self.data_dir, 'ca_image', ca_image_filename)
        garment_image_path = os.path.join(self.data_dir, 'garment_image', garment_image_filename)
        person_pose_path = os.path.join(self.data_dir, 'person_pose', person_pose_filename)
        garment_pose_path = os.path.join(self.data_dir, 'garment_pose', garment_pose_filename)

        # Load the images, forcing 3-channel RGB
        person_image = Image.open(person_image_path).convert("RGB")
        ca_image = Image.open(ca_image_path).convert("RGB")
        garment_image = Image.open(garment_image_path).convert("RGB")

        # Apply the image transformation to resize to 128x128
        person_image = image_transform(person_image)
        ca_image = image_transform(ca_image)
        garment_image = image_transform(garment_image)

        with open(person_pose_path, 'r') as f:
            pose_data = json.load(f)
            long_pose = pose_data.get("long")[:self.pose_size[0]]  # Extract the first 18 points from the "long" key
            person_pose = torch.tensor(long_pose).view(self.pose_size[0], self.pose_size[1])

        # Placeholder for garment pose data. NOTE: random noise changes on every
        # access; a fixed constant (e.g. torch.full(self.pose_size, 0.5)) would
        # be a more stable placeholder, as suggested above.
        garment_pose = torch.randn(*self.pose_size)

        sample = {
            "person_images": person_image,
            "ca_images": ca_image,
            "garment_images": garment_image,
            "person_poses": person_pose,
            "garment_poses": garment_pose,
        }

        return sample

def tryondiffusion_collate_fn(batch):
    return {
        "person_images": torch.stack([item["person_images"] for item in batch]),
        "ca_images": torch.stack([item["ca_images"] for item in batch]),
        "garment_images": torch.stack([item["garment_images"] for item in batch]),
        "person_poses": torch.stack([item["person_poses"] for item in batch]),
        "garment_poses": torch.stack([item["garment_poses"] for item in batch]),
    }

def main():
    print("Instantiating the dataset and dataloader...")
    dataset = SyntheticTryonDataset(
        data_dir="/home/try_on_diffusion/tryondiffusion/data/syn_data",  # path to the data folder
        num_samples=10,  # Adjust this according to your dataset size
        image_size=SR_UNET_IMAGE_SIZE if TRAIN_UNET_NUMBER == 2 else BASE_UNET_IMAGE_SIZE,
    )
    train_dataloader = DataLoader(
        dataset,
        batch_size=BATCH_SIZE,
        shuffle=True,
        collate_fn=tryondiffusion_collate_fn,
    )
    validation_dataloader = DataLoader(
        dataset,
        batch_size=BATCH_SIZE,
        shuffle=False,
        collate_fn=tryondiffusion_collate_fn,
    )

    print("Checking the dataset and dataloader...")
    for i, batch in enumerate(train_dataloader):
        print(f"Batch {i+1}:")
        # Print the shape of each tensor in the batch
        for k, v in batch.items():
            if isinstance(v, torch.Tensor):
                print(f"{k}: {v.shape}")
            else:
                print(f"{k}: List of {len(v)} tensors with shapes: {v[0].shape}")

        # Display the images (CHW tensors transposed to HWC for Matplotlib)
        for j in range(len(batch["person_images"])):
            person_image = batch["person_images"][j].permute(1, 2, 0)
            ca_image = batch["ca_images"][j].permute(1, 2, 0)
            garment_image = batch["garment_images"][j].permute(1, 2, 0)

            plt.figure(figsize=(10, 5))
            plt.subplot(1, 3, 1)
            plt.imshow(person_image)
            plt.title("Person Image")
            plt.axis("off")

            plt.subplot(1, 3, 2)
            plt.imshow(ca_image)
            plt.title("CA Image")
            plt.axis("off")

            plt.subplot(1, 3, 3)
            plt.imshow(garment_image)
            plt.title("Garment Image")
            plt.axis("off")

            plt.show()

        # Break after the first few batches
        if i >= 2:
            break

    # Instantiate the unets
    print("Instantiating U-Nets...")
    base_unet = get_unet_by_name("base")
    sr_unet = get_unet_by_name("sr")

    # Instantiate the Imagen model
    imagen = TryOnImagen(
        unets=(base_unet, sr_unet),
        image_sizes=(BASE_UNET_IMAGE_SIZE, SR_UNET_IMAGE_SIZE),
        timesteps=TIMESTEPS,
    )

    print("Instantiating the trainer...")
    trainer = TryOnImagenTrainer(
        imagen=imagen,
        max_grad_norm=1.0,
        accelerate_cpu=True,
        accelerate_gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS,
    )

    trainer.add_train_dataloader(train_dataloader)
    trainer.add_valid_dataloader(validation_dataloader)

    print("Starting training loop...")
    # Training loop
    for i in range(NUM_ITERATIONS):
        loss = trainer.train_step(unet_number=TRAIN_UNET_NUMBER)
        print(f"loss: {loss}")
        valid_loss = trainer.valid_step(unet_number=TRAIN_UNET_NUMBER)
        print(f"valid loss: {valid_loss}")

    # Sampling
    print("Starting sampling loop...")
    validation_sample = next(trainer.valid_dl_iter)
    _ = validation_sample.pop("person_images")
    imagen_sample_kwargs = dict(
        **validation_sample,
        batch_size=BATCH_SIZE,
        cond_scale=2.0,
        start_at_unet_number=1,
        return_all_unet_outputs=True,
        return_pil_images=True,
        use_tqdm=True,
        use_one_unet_in_gpu=True,
    )

    images = trainer.sample(**imagen_sample_kwargs)  # returns one list of PIL images per UNet
    assert len(images) == 2
    assert len(images[0]) == BATCH_SIZE and len(images[1]) == BATCH_SIZE

    # Create the output folder if it doesn't exist
    output_folder = 'output/'
    os.makedirs(output_folder, exist_ok=True)

    # Save each image with a unique filename
    for i, unet_output in enumerate(images):
        for j, image in enumerate(unet_output):
            image.save(os.path.join(output_folder, f'image_{i * BATCH_SIZE + j}.jpg'))

    print("Images saved to 'output/' folder.")

    for unet_output in images:
        for image in unet_output:
            image.show()


if __name__ == "__main__":
    # python ./examples/test_tryon_imagen_trainer.py
    main()

It takes a lot more training before you stop seeing noise (see the sketch after this list):

  • Sample for 256 timesteps in the 1st UNet, and 128 timesteps if you also use the 2nd UNet
  • You need many more data samples (10k minimum)
  • Train for at least 10k iterations (to stop seeing noise) with batch size 256
  • Use a GPU (the CPU is used there just for testing)
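Put together, a minimal sketch of those recommendations applied to the constants at the top of your script (the values are the ones suggested above; whether batch size 256 fits directly or needs gradient accumulation depends on your GPU):

# Adjusted constants for real training (per the recommendations above)
BATCH_SIZE = 256           # may require gradient accumulation on smaller GPUs
NUM_ITERATIONS = 10_000    # train for at least 10k iterations
TIMESTEPS = (256, 128)     # 256 sampling steps for the base UNet, 128 for the SR UNet

# ...and run the trainer on the GPU instead of the CPU:
trainer = TryOnImagenTrainer(
    imagen=imagen,
    max_grad_norm=1.0,
    accelerate_cpu=False,  # CPU mode is only for quick tests
    accelerate_gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS,
)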

Can you please let me know how you got the cloth landmark keypoints? Can you guide me on that task?

The pose keypoints are taken from the person wearing the cloth. This architecture is not suited as-is to work with VITON-HD, where the garment images are flat clothing-only photos.

You can modify the architecture to remove any part that is related to garment pose, or find another way to extract landmarks from garment-only pictures.
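For the second option, here is a crude illustrative sketch, not the authors' method: it assumes the near-white background of flat garment photos and simply samples evenly spaced points along the garment silhouette. A real solution would use a trained fashion-landmark detector.

import cv2
import numpy as np

def rough_garment_keypoints(image_path, num_points=18):
    # Segment the garment from a near-white background via thresholding.
    img = cv2.imread(image_path)
    gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    _, mask = cv2.threshold(gray, 240, 255, cv2.THRESH_BINARY_INV)
    # Take the largest external contour as the garment outline.
    contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    if not contours:
        raise ValueError(f"No garment silhouette found in {image_path}")
    outline = max(contours, key=cv2.contourArea).squeeze(1)  # (N, 2) pixel coords
    # Sample num_points evenly spaced points along the outline.
    idx = np.linspace(0, len(outline) - 1, num_points).astype(int)
    points = outline[idx].astype(np.float32)
    # Normalize to [0, 1], like pose keypoints.
    h, w = gray.shape
    return points / np.array([w, h], dtype=np.float32)

These are silhouette samples, not semantic landmarks, so treat them as a rough stand-in at best.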

@Teja414, closing this issue as you have opened a new one.