
Fast and controllable text-to-image model.

Primary language: Python. License: Apache License 2.0.

PaintMind

- 2023/5/8
Released the new vit-s-vqgan.
- 2023/4/19
Note: I am preparing a new release with a better vit-vqgan and generation model. Pretrained weights from the old version will no longer be available.

Inspired by MaskGIT: Masked Generative Image Transformer and Muse: Text-To-Image Generation via Masked Generative Transformers.
The code is still being tested. Feel free to test it, improve it, and add more interesting things.

Features

  • xformers support, which accelerates both training and inference
  • accelerate support, for mixed precision and multi-GPU training

To Do

  • test end-to-end: stage1 vqtokenizer and stage2 text-to-image
  • final test and release of the stable version V1

Installation

To install the latest version:

pip install git+https://github.com/Qiyuan-Ge/PaintMind.git

Recommended installation:

  • xformers can accelerate both training and inference
  • requires PyTorch 2.0.0 on Windows or PyTorch 1.12.1, 1.13.1 or 2.0.0 on Linux
pip install -U xformers

Import

import paintmind as pm

Stage1: Reconstruction

Play with Colab Notebook.

Usage

If you set pretrained=True, the code will try to download the pretrained weights.

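import torch
from PIL import Image
import paintmind as pm

# placeholder: path to the image you want to reconstruct
img_path = 'your/image/path'
img = Image.open(img_path).convert('RGB')
img = pm.stage1_transform(is_train=False)(img)
# load pretrained vit-vqgan
model = pm.create_model(arch='vqgan', version='vit-s-vqgan', pretrained=True)
# encode image to latent (add a batch dimension first)
z, _, _ = model.encode(img.unsqueeze(0))
# decode latent to image and drop the batch dimension
rec = model.decode(z).squeeze(0)
# keep pixel values inside the [-1, 1] range
rec = torch.clamp(rec, -1., 1.)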

You can also download the pretrained vit-vqgan weights from https://huggingface.co/RootYuan/vit-s-vqgan.
To load the pretrained weights from a local path:

model = pm.create_model(arch='vqgan', version='vit-s-vqgan', pretrained=True, checkpoint_path='your/model/path')

Training

import paintmind as pm
from paintmind.utils import datasets

data_path = 'your/data/path'
transform = pm.stage1_transform(img_size=256, is_train=True, scale=0.66)
dataset = datasets.ImageNet(root=data_path, transform=transform)
# or your own dataset, the output format should be image: torch.Tensor or (image: torch.Tensor, _)

model = pm.create_model(arch='vqgan', version='vit-s-vqgan', pretrained=False)

trainer = pm.VQGANTrainer(
    vqvae                    = model,
    dataset                  = dataset,
    num_epoch                = 100,
    valid_size               = 64,
    lr                       = 1e-4,
    lr_min                   = 5e-5,
    warmup_steps             = 50000,
    warmup_lr_init           = 1e-6,
    decay_steps              = 100000,
    batch_size               = 16,
    num_workers              = 2,
    pin_memory               = True,
    grad_accum_steps         = 8,
    mixed_precision          = 'bf16',
    max_grad_norm            = 1.0,
    save_every               = 5000,
    sample_every             = 5000,
    result_folder            = "your/result/folder",
    log_dir                  = "your/log/dir",
)
trainer.train()
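
The lr, lr_min, warmup_steps, warmup_lr_init and decay_steps arguments suggest a warmup phase followed by a decay phase. The sketch below is only an illustration of one common shape (linear warmup, then cosine decay). It is not taken from the PaintMind source. The function name lr_at_step is hypothetical, and it assumes that decay_steps is counted from the end of warmup. The trainer's real schedule may differ.

```python
import math


def lr_at_step(step, lr=1e-4, lr_min=5e-5, warmup_steps=50000,
               warmup_lr_init=1e-6, decay_steps=100000):
    """Hypothetical schedule: linear warmup, then cosine decay to lr_min.

    Assumption: decay_steps is counted from the end of the warmup phase.
    """
    if step < warmup_steps:
        # linear ramp from warmup_lr_init up to lr
        frac = step / warmup_steps
        return warmup_lr_init + frac * (lr - warmup_lr_init)
    # cosine decay from lr down to lr_min, then hold at lr_min
    progress = min((step - warmup_steps) / decay_steps, 1.0)
    cosine = 0.5 * (1.0 + math.cos(math.pi * progress))
    return lr_min + (lr - lr_min) * cosine


if __name__ == "__main__":
    for s in (0, 25000, 50000, 100000, 150000, 200000):
        print(f"step {s:>6}: lr = {lr_at_step(s):.2e}")
```

Printing a few values like this before a long run is a cheap way to check that the warmup and decay lengths fit the total number of training steps.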

Performance

Below are reconstructions from vit-s-vqgan after training on 3M images with batch size 16 and a constant learning rate for 200000 steps. Because of limited time and computing resources, I only trained the model for one epoch. The results are quite good, but human faces (especially the eyes) still need improvement. I plan to release a better version after trying other techniques (warmup, cosine lr decay, larger batch size, more face images...).

1.

pm.reconstruction(img_path='https://cdn.pixabay.com/photo/2014/10/22/15/47/squirrel-498139_960_720.jpg')

2.

pm.reconstruction(img_path='https://cdn.pixabay.com/photo/2017/04/09/10/44/sea-shells-2215408_960_720.jpg')

3.

pm.reconstruction(img_path='https://cdn.pixabay.com/photo/2015/06/19/21/24/avenue-815297_960_720.jpg')

4.

pm.reconstruction(img_path='https://cdn.pixabay.com/photo/2017/03/30/18/17/girl-2189247_960_720.jpg')

5.

pm.reconstruction(img_path='https://cdn.pixabay.com/photo/2017/10/28/07/47/woman-2896389_960_720.jpg')
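
As a quick check on the training budget quoted at the start of this section, the arithmetic below relates steps, batch size and dataset size. It assumes each step consumes exactly one batch of 16 images (no gradient accumulation, single process), which the text does not state. The helper name passes_over_dataset is hypothetical.

```python
def passes_over_dataset(num_steps, batch_size, dataset_size):
    """Number of passes over the dataset (epochs) for a given step budget.

    Assumption: every step reads one batch of `batch_size` images.
    """
    return num_steps * batch_size / dataset_size


num_steps = 200_000       # steps quoted in the text
batch_size = 16           # batch size quoted in the text
dataset_size = 3_000_000  # "3M images", taken as a round number

samples_seen = num_steps * batch_size
passes = passes_over_dataset(num_steps, batch_size, dataset_size)

print(f"samples seen: {samples_seen}")
print(f"passes over the dataset: {passes:.2f}")
```

Under these assumptions the model sees about 3.2M samples, which is roughly one pass over the data and agrees with the "one epoch" statement.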

Stage2: Text2Image

Not finished yet, but the code is ready.

Training

import paintmind as pm
from paintmind.utils import datasets

data_path = 'your/data/path'
transform = pm.stage2_transform(img_size=256, is_train=True, scale=0.8)
dataset = datasets.CoCo(root=data_path, transform=transform) 
# or your own dataset, the output format should be (image: torch.Tensor, caption: str)

# load the pretrained weights I uploaded to Hugging Face (not finished yet)
model = pm.create_pipeline_for_train(version='paintmindv1', stage1_pretrained=True)
# or load your pretrained weights
model = pm.create_pipeline_for_train(version='paintmindv1', stage1_pretrained=True, stage1_checkpoint_path='your/pretrained/vitvqgan')

trainer = pm.PaintMindTrainer(
    model                       = model,
    dataset                     = dataset,
    num_epoch                   = 40,
    valid_size                  = 64,
    optim                       = 'adamw',
    lr                          = 1e-4,
    lr_min                      = 1e-5,
    warmup_steps                = 10000,
    weight_decay                = 0.05,
    warmup_lr_init              = 1e-6,
    decay_steps                 = 80000,
    batch_size                  = 16,
    num_workers                 = 2,
    pin_memory                  = True,
    grad_accum_steps            = 8,
    mixed_precision             = 'bf16',
    max_grad_norm               = 1.0,
    save_every                  = 5000,
    sample_every                = 5000,
    result_folder               = "your/result/folder",
    log_dir                     = "your/log/dir",
    )
trainer.train()
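
With batch_size = 16 and grad_accum_steps = 8, the number of samples behind each optimizer update is larger than the loader batch. The sketch below shows the usual arithmetic. It assumes that batch_size is per process and that gradients are accumulated over grad_accum_steps batches before each update, as is typical with accelerate. Neither point is confirmed by this README, and num_processes and the function name are hypothetical.

```python
def effective_batch_size(batch_size, grad_accum_steps, num_processes=1):
    """Samples contributing to one optimizer update.

    Assumptions: batch_size is per process, and gradients are accumulated
    over grad_accum_steps batches before each optimizer update.
    """
    return batch_size * grad_accum_steps * num_processes


if __name__ == "__main__":
    # values from the trainer configuration above, on a single GPU
    print(effective_batch_size(batch_size=16, grad_accum_steps=8))
    # the same configuration on a hypothetical 4-GPU machine
    print(effective_batch_size(batch_size=16, grad_accum_steps=8, num_processes=4))
```

If you change the number of GPUs, adjusting grad_accum_steps is one way to keep this product (and so the optimization behaviour) roughly constant.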

Acknowledgments

  • taming-transformers
  • maskgit
  • muse-maskgit-pytorch
  • img2dataset
  • accelerate
  • xformers