zsyzzsoft/co-mod-gan

Issue with run_generator.py

Opened this issue · 3 comments

I have trained my own dataset by using the dataset_tools/create_from_images.py and when running the run_generator.py, I get the issue

when using fake = Gs.run(latent, None, real, mask, truncation_psi=truncation)[0]
Error:
return tf_session.TF_SessionRun_wrapper(self._session, options, feed_dict, tensorflow.python.framework.errors_impl.InvalidArgumentError: 2 root error(s) found. (0) Invalid argument: input must be 4-dimensional[1,512,256] [[{{node Gs/_Run/Gs/G_synthesis/E_256x256/FromRGB/Conv2D}}]] [[Gs/_Run/Gs/images_out/_1407]] (1) Invalid argument: input must be 4-dimensional[1,512,256] [[{{node Gs/_Run/Gs/G_synthesis/E_256x256/FromRGB/Conv2D}}]] 0 successful operations. 0 derived errors ignored.

when using fake = Gs.run(latent, None, real[np.newaxis], mask, truncation_psi=truncation)[0]
Error:
return tf_session.TF_SessionRun_wrapper(self._session, options, feed_dict, tensorflow.python.framework.errors_impl.InvalidArgumentError: 2 root error(s) found. (0) Invalid argument: ConcatOp : Ranks of all input tensors should match: shape[0] = [1,256,256] vs. shape[1] = [1,3,256,256] [[{{node Gs/_Run/Gs/G_synthesis/concat}}]] [[Gs/_Run/Gs/images_out/_1407]] (1) Invalid argument: ConcatOp : Ranks of all input tensors should match: shape[0] = [1,256,256] vs. shape[1] = [1,3,256,256] [[{{node Gs/_Run/Gs/G_synthesis/concat}}]] 0 successful operations. 0 derived errors ignored.

Details:
Real image shape : [3,256,256]
Mask shape : [1,256,256]
Dynamic range applied on real image

Please provide your code if you have modified it

Please provide your code if you have modified it

Here is my changed run_generator script

import argparse
import numpy as np
import PIL.Image
from numpy.testing._private.utils import import_nose

from dnnlib import tflib
from training import misc
import os
import glob
import cv2
import numpy as np
from dataset_tools.create_from_images import create_from_image
import tensorflow as tf
from tensorflow.keras.preprocessing.image import img_to_array
from PIL import Image

if not os.path.exists('co-mod-gan_imgs2'):
    os.mkdir('co-mod-gan_imgs2')

def create_from_images(checkpoint, image, mask, truncation=None):
    real_img = PIL.Image.open(image).convert('RGB')
    real_img = real_img.resize((256, 256))
    f_name = image.split('/')[-1]
    real_img.save('co-mod-gan_imgs2/o_'+f_name)
    real = np.array(real_img).transpose([2, 0, 1])
    real = misc.adjust_dynamic_range(real, [0, 255], [-1, 1])
    #real_save = Image.fromarray(real)
    #real_save.save('co-mod-gan_imgs2/n_'+f_name)
    mask = Image.fromarray(mask).convert('1')
    mask.save('co-mod-gan_imgs2/m_'+f_name)
    mask = np.asarray(mask, dtype=np.float32)[np.newaxis]
    print(mask.shape)
    
    tflib.init_tf()
    _, _, Gs = misc.load_pkl(checkpoint)
    print('loaded network from pkl...')
    latent = np.random.randn(1, *Gs.input_shape[1:])
    #print((latent.shape, real[np.newaxis].shape, mask[np.newaxis].shape))
    fake = Gs.run(latent, None, real[np.newaxis], mask, truncation_psi=truncation)[0]
    fake = misc.adjust_dynamic_range(fake, [-1, 1], [0, 255])
    fake = fake.clip(0, 255).astype(np.uint8).transpose([1, 2, 0])
    fake = PIL.Image.fromarray(fake).convert('RGB')
    fake.save('co-mod-gan_imgs2/out_'+f_name)

def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('-c', '--checkpoint', help='Network checkpoint path', required=True)
    parser.add_argument('-i', '--image', help='Original image path', required=True)
    parser.add_argument('-m', '--mask', help='Mask path', required=True)
    parser.add_argument('-t', '--truncation', help='Truncation psi for the trade-off between quality and diversity. Defaults to 1.', default=None)
    args = parser.parse_args()
    create_from_images(**vars(args))

def apply_predefined():
    pkl_path = '/home/sulugodu/ma_gan/MA/Anonymization/co-mod-gan/results/00016-co-mod-gan-celeba-1gpu/network-snapshot-010815.pkl'
    data_path = '/home/sulugodu/ma_gan/MA/Anonymization/Datasets/faces_bbox/images/test'
    imgs = glob.glob(data_path+'/*')[:10]
    for i, img in enumerate(imgs):
        mask_path = img.replace('Datasets/faces_bbox/images/test', 'u2net_seg_imgs')
        mask = PIL.Image.open(mask_path).resize((256, 256)).convert('RGB')
        mask = cv2.bitwise_not(np.array(mask))
        mask = cv2.threshold(mask, 15, 255, cv2.THRESH_BINARY)
        #print(mask[1].shape)
        #cv2.imwrite('co-mod-gan_imgs2/mask_'+mask_path.split('/')[-1]+'.jpg', mask[1])
        create_from_images(pkl_path, img, mask[1])

if __name__ == "__main__":
    #main()
    apply_predefined()

You need to make the image's shape to [1, 3, 256, 256] and the mask's shape to [1, 1, 256, 256]. In your code you should call img[np.newaxis] and mask[1][np.newaxis]