Small typo in the example
netrunner-exe opened this issue · 1 comment
netrunner-exe commented
Hello @felixrosberg! I decided to test the example provided in README.md and found one typo in the import line "from network.layers import AdaptiveAttention, AdaIn": the class is AdaIN, not AdaIn. I also have a few small suggestions for other problems I ran into; I have described them in more detail in the comments in my example:
from networks.layers import AdaIN, AdaptiveAttention
from tensorflow_addons.layers import InstanceNormalization
from tensorflow.keras.models import load_model
from PIL import Image
import numpy as np
import cv2
# To hide "WARNING:root:The given value for groups will be overwritten."
import logging
logging.getLogger().setLevel(logging.ERROR)
# To hide the very long TensorFlow log, for example:
# Model: "model"
# __________________________________________________________________________________________________
# Layer (type) Output Shape Param # Connected to
# ==================================================================================================
# input_1 (InputLayer) [(None, 256, 256, 3 0 []
#
# These two lines could also be added directly to networks/layers.py
import tensorflow as tf
tf.keras.utils.disable_interactive_logging()
# Add compile=False to hide
# "WARNING:tensorflow:No training configuration found in the save file, so the model was *not* compiled. Compile it manually."
model = load_model("path/to/model.h5", compile=False, custom_objects={"AdaIN": AdaIN, "AdaptiveAttention": AdaptiveAttention, "InstanceNormalization": InstanceNormalization})
arcface = load_model("path/to/arcface.h5", compile=False)
# Target and source images need to be properly cropped and aligned
# .convert("RGB") drops a possible alpha channel so both arrays have 3 channels
target = np.asarray(Image.open("path/to/target_face.png").convert("RGB").resize((256, 256)))
source = np.asarray(Image.open("path/to/source_face.png").convert("RGB").resize((112, 112)))
source_z = arcface(np.expand_dims(source / 255.0, axis=0))
face_swap = model([np.expand_dims((target - 127.5) / 127.5, axis=0), source_z]).numpy()
face_swap = (face_swap[0] + 1) / 2
face_swap = np.clip(face_swap * 255, 0, 255).astype('uint8')
# OpenCV writes images in BGR channel order, so convert from RGB before saving
cv2.imwrite("./swapped_face.png", cv2.cvtColor(face_swap, cv2.COLOR_RGB2BGR))
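A minimal sketch of the pre- and post-processing arithmetic from the example above, pulled out into plain NumPy helpers so it can be checked without loading either model. The helper names (to_model_range, to_arcface_range, from_model_range, rgb_to_bgr) are hypothetical and not part of the repository; the scaling constants (127.5 for the target, 255.0 for the ArcFace input) are copied from the snippet.

```python
import numpy as np


def to_model_range(img_uint8: np.ndarray) -> np.ndarray:
    """Map uint8 pixels in [0, 255] to float32 in [-1, 1] and add a batch axis,
    mirroring np.expand_dims((target - 127.5) / 127.5, axis=0)."""
    return np.expand_dims((img_uint8.astype(np.float32) - 127.5) / 127.5, axis=0)


def to_arcface_range(img_uint8: np.ndarray) -> np.ndarray:
    """Map uint8 pixels to float32 in [0, 1] and add a batch axis,
    mirroring np.expand_dims(source / 255.0, axis=0)."""
    return np.expand_dims(img_uint8.astype(np.float32) / 255.0, axis=0)


def from_model_range(batch: np.ndarray) -> np.ndarray:
    """Take the first image of a batch in [-1, 1] and map it back to uint8.
    Rounds to the nearest level instead of truncating (a design choice here)."""
    img = (batch[0] + 1.0) / 2.0 * 255.0
    return np.clip(np.rint(img), 0, 255).astype(np.uint8)


def rgb_to_bgr(img: np.ndarray) -> np.ndarray:
    """Reverse the channel axis; for a 3-channel image this gives the same
    result as cv2.cvtColor(img, cv2.COLOR_RGB2BGR)."""
    return np.ascontiguousarray(img[..., ::-1])
```

Under these assumptions the helpers would slot into the example as model([to_model_range(target), arcface(to_arcface_range(source))]) and cv2.imwrite("./swapped_face.png", rgb_to_bgr(from_model_range(face_swap))). The only behavioral difference is the rounding: .astype('uint8') in the original truncates toward zero, which can put pixels one level lower than rounding would.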
felixrosberg commented
Hello @netrunner-exe
Thank you very much, I will update the README with this!