bryandlee/FreezeG

how to generate pair data?

SpartacusIn21 opened this issue · 4 comments

Thank you for your nice work. I tried to generate pair data with the pretrained models (the FFHQ model and the ffhq2met model downloaded from #3) by the following steps:

  1. Get the latent vector of a picture with projector.py (the output file's layout is sketched after these steps):
    python projector.py --ckpt checkpoint/550000.pt ffhq_sample/sample/000000.png

  2. Generate a picture with generate.py from the latent vector of step 1:
    python generate_pair_data.py --size 256 --ckpt checkpoint/face2met_10k.pt --project_latent_file 000000.pt
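
For context, the loading loop in the modified script below assumes projector.py saves a dict keyed by image path, where each entry holds the optimized "latent" vector and the per-layer "noise" maps. A quick, minimal way to check that layout (the file name is the one from step 1; the exact keys depend on the projector.py version):

import torch

# inspect the file written by projector.py in step 1
result = torch.load("000000.pt")
for name, entry in result.items():
    print(name)                    # original image path
    print(entry["latent"].shape)   # optimized latent, e.g. [512] without --w_plus
    print(len(entry["noise"]))     # one noise map per synthesis layer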

My generate.py is modified as follows:

import argparse
import os

import torch
from torchvision import utils
from model import Generator
from tqdm import tqdm
import glob
import numpy as np
from PIL import Image

# helper (unused below): convert generator output in [-1, 1] to a uint8 HWC numpy image
def make_image(tensor):
    return (
        tensor.detach()
            .clamp_(min=-1, max=1)
            .add(1)
            .div_(2)
            .mul(255)
            .type(torch.uint8)
            .permute(0, 2, 3, 1)
            .to("cpu")
            .numpy()
    )

def generate(args, g_ema, device, sample_z_s, mean_latent, sample_noise_s):

    with torch.no_grad():
        g_ema.eval()
        for k, v in tqdm(sample_z_s.items()):
            # feed the projected latent directly (input_is_latent=True) together
            # with the per-image noise maps recovered by projector.py
            noise = sample_noise_s[k]
            sample, _ = g_ema(
                [v],
                truncation=args.truncation,
                truncation_latent=mean_latent,
                input_is_latent=True,
                noise=noise,
            )

            utils.save_image(
                sample,
                f"metface_dir/{os.path.basename(k)}.png",
                nrow=1,
                normalize=True,
                range=(-1, 1),
            )


if __name__ == "__main__":
    device = "cuda"

    parser = argparse.ArgumentParser(description="Generate samples from the generator")

    parser.add_argument(
        "--size", type=int, default=1024, help="output image size of the generator"
    )
    parser.add_argument(
        "--sample",
        type=int,
        default=1,
        help="number of samples to be generated for each image",
    )
    parser.add_argument(
        "--pics", type=int, default=20, help="number of images to be generated"
    )
    parser.add_argument("--truncation", type=float, default=1, help="truncation ratio")
    parser.add_argument(
        "--truncation_mean",
        type=int,
        default=4096,
        help="number of vectors to calculate mean for the truncation",
    )
    parser.add_argument(
        "--ckpt",
        type=str,
        default="stylegan2-ffhq-config-f.pt",
        help="path to the model checkpoint",
    )
    parser.add_argument(
        "--channel_multiplier",
        type=int,
        default=2,
        help="channel multiplier of the generator. config-f = 2, else = 1",
    )
    parser.add_argument(
        "--project_latent_file",
        type=str,
        default="000000.pt",
        help="path to the latent file",)
    args = parser.parse_args()

    args.latent = 512
    args.n_mlp = 8

    # channel_multiplier defaults to 2 in this Generator, matching config-f;
    # it must match how the checkpoint was trained
    g_ema = Generator(
        args.size, args.latent, args.n_mlp).to(device)
    checkpoint = torch.load(args.ckpt)

    # note: strict=False silently skips mismatched or missing weights
    g_ema.load_state_dict(checkpoint["g_ema"], strict=False)

    if args.truncation < 1:
        with torch.no_grad():
            mean_latent = g_ema.mean_latent(args.truncation_mean)
    else:
        mean_latent = None

    # load the projector output: one entry per image, holding its latent and noise
    sample_z_s = {}
    sample_noise_s = {}
    result_file = torch.load(args.project_latent_file)
    for k, v in result_file.items():
        sample_z_s[k] = torch.unsqueeze(v["latent"], 0)
        sample_noise_s[k] = v["noise"]


    generate(args, g_ema, device, sample_z_s, mean_latent, sample_noise_s)
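
(As a sanity check, the same script can be pointed at the source checkpoint first, to confirm the projected latent reconstructs the input with the model it was projected against:

    python generate_pair_data.py --size 256 --ckpt checkpoint/550000.pt --project_latent_file 000000.pt
)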

but the result I get is different from your face2art samples:
[images: generated results that do not match the face2art samples]

What am I doing wrong? Thank you.

Hi, I added a script for paired image generation: code

python generate_pair.py --sample 5

gives this:

[sample image: generated pairs]
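
For readers following along, paired generation presumably amounts to pushing one latent through both the source and the finetuned generator with shared noise; a minimal sketch under that assumption (checkpoint paths are examples, and make_noise() is assumed from the rosinality-style Generator this repo builds on):

import torch
from torchvision import utils
from model import Generator

device = "cuda"
size, latent_dim, n_mlp = 256, 512, 8

# source (FFHQ) and finetuned (metfaces) generators; paths are examples
g_src = Generator(size, latent_dim, n_mlp).to(device)
g_src.load_state_dict(torch.load("checkpoint/550000.pt")["g_ema"], strict=False)
g_tgt = Generator(size, latent_dim, n_mlp).to(device)
g_tgt.load_state_dict(torch.load("checkpoint/face2met_10k.pt")["g_ema"], strict=False)

with torch.no_grad():
    g_src.eval()
    g_tgt.eval()
    z = torch.randn(1, latent_dim, device=device)
    noise = g_src.make_noise()  # share noise so both outputs stay aligned
    # FreezeG keeps the early layers frozen during finetuning, so the same z
    # should give structurally matching faces in both domains
    img_src, _ = g_src([z], noise=noise)
    img_tgt, _ = g_tgt([z], noise=noise)
    utils.save_image(
        torch.cat([img_src, img_tgt]), "pair.png",
        nrow=2, normalize=True, range=(-1, 1),
    )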

Thanks for your answer; your generate_pair.py script generates pair data perfectly. But if I want to generate pair data from a specific picture, what can I do to achieve that? As mentioned in the question, I get a latent vector by step 1, then generate the face2art picture with the face2met_10k.pt model (with input_is_latent=True), but the art picture's style is wrong. To verify the correctness of the latent vector, I used it to generate a picture with the 550000.pt model and the result is correct. I can't find the difference between your script and mine, so why is my face2art picture wrong in style? (From left to right: FFHQ picture / projected picture / face2art picture.)
[images, left to right: FFHQ picture / projected picture / face2art picture]

I guess it's related to the domain of the latent code. Latent codes obtained from the default projection method are not very editable. You may want to use in-domain inversion methods or image2stylegan variants. Or you can try different latent codes for the finetuned layers, as in gradio_app.py; that way you can explore the latent space and see what is going on.
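
To make that last suggestion concrete: the rosinality-style Generator this repo builds on accepts two latent codes plus an inject_index, so the early (frozen) layers can follow the projected latent while the later finetuned layers get a different code. A minimal sketch under that assumption (the split point inject_index=4 and all paths are placeholders; the projection is assumed to have been run without --w_plus):

import torch
from torchvision import utils
from model import Generator

device = "cuda"
g_tgt = Generator(256, 512, 8).to(device)
g_tgt.load_state_dict(torch.load("checkpoint/face2met_10k.pt")["g_ema"], strict=False)
g_tgt.eval()

# projected latent from step 1; shape [512] when --w_plus was not used
projected = torch.load("000000.pt")
w_proj = next(iter(projected.values()))["latent"].unsqueeze(0).to(device)

with torch.no_grad():
    for i in range(5):
        # map a fresh z to W with the generator's own mapping network
        w_rand = g_tgt.style(torch.randn(1, 512, device=device))
        # layers before inject_index follow w_proj, later layers follow w_rand
        img, _ = g_tgt([w_proj, w_rand], input_is_latent=True, inject_index=4)
        utils.save_image(img, f"mix_{i}.png", normalize=True, range=(-1, 1))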

Thanks for your advice, I will try!