Include a factor of x8
Opened this issue · 1 comments
EmanuelCastanho commented
Hi,
I am trying to add an upscaling factor of 8 to SR4RS. To do this, I changed the following:
Inside `constants.py`:
# Factors used as keys of the multi-scale image dictionaries in network.py;
# 8 is the entry added for this change.
factors = [1, 2, 4, 8]
Inside `network.py`, in the `discriminator` function:
def discriminator(hr_images, scope, dim):
    """
    Discriminator

    Builds the network under the variable scope ``scope`` (opened with
    AUTO_REUSE). The entries 1, 2, 4 and 8 of ``hr_images`` are injected one
    after the other: each stage projects its image with a 1x1 convolution,
    concatenates the result with the features of the previous stage on the
    last axis, then applies 3x3 convolutions followed by
    ``conv2d_downscale2d``.

    NOTE(review): presumably ``hr_images[f]`` is the image resampled by the
    factor ``f`` so that it matches the feature size at that stage -- confirm
    against the caller.

    :param hr_images: mapping indexed with the keys 1, 2, 4 and 8
    :param scope: name of the enclosing variable scope
    :param dim: number of filters of the first stage, doubled at every
        following stage
    :return: output of the last 1x1 convolution, averaged over axes 1 and 2
    """
    variable_scope = tf.compat.v1.variable_scope
    activated_conv = partial(conv, activation_fn=lrelu)
    stack_names = ("conv1", "conv2", "conv3")

    def _combine(x, newdim, name, z=None):
        # 1x1 projection of the image, optionally merged with the features
        # handed over by the previous stage.
        merged = activated_conv(x, newdim, 1, 1, name)
        if z is not None:
            merged = tf.concat([merged, z], axis=-1)
        return minibatch_stddev_layer(merged)

    def _conv_downsample(x, dim, ksize, name):
        return lrelu(conv2d_downscale2d(x, dim, ksize, name=name))

    with variable_scope(scope, reuse=tf.compat.v1.AUTO_REUSE):
        # First stage: nothing to concatenate yet and no widening.
        with variable_scope("res_8x"):
            net = _combine(hr_images[1], newdim=dim, name="from_input")
            for layer_name in stack_names:
                net = activated_conv(net, dim, 3, 1, layer_name)
            net = _conv_downsample(net, dim, 3, "conv_down")

        # Intermediate stages share one layout; the projection uses the
        # current width and the width is doubled right after it.
        for stage_name, key in (("res_4x", 2), ("res_2x", 4)):
            with variable_scope(stage_name):
                net = _combine(hr_images[key], newdim=dim, name="from_input",
                               z=net)
                dim *= 2
                for layer_name in stack_names:
                    net = activated_conv(net, dim, 3, 1, layer_name)
                net = _conv_downsample(net, dim, 3, "conv_down")

        # Last image of the pyramid: a single convolution before downsampling.
        with variable_scope("res_1x"):
            net = _combine(hr_images[8], newdim=dim, name="from_input", z=net)
            dim *= 2
            net = activated_conv(net, dim, 3, 1, "conv")
            net = _conv_downsample(net, dim, 3, "conv_down")

        with variable_scope("bn"):
            dim *= 2
            net = activated_conv(net, dim, 3, 1, "conv1")
            net = _conv_downsample(net, dim, 3, "conv_down1")
            net = minibatch_stddev_layer(net)

        # dense (implemented with 1x1 convolutions)
        dim *= 2
        net = activated_conv(net, dim, 1, 1, "dense1")
        net = conv(net, 1, 1, 1, "dense2")
        return tf.reduce_mean(net, axis=[1, 2])
Inside `network.py`, in the `generator` function:
def generator(lr_image, scope, nchannels, nresblocks, dim):
    """
    Generator

    Builds the network under the variable scope ``scope`` (opened with
    AUTO_REUSE): an encoder made of a 9x9 convolution and ``nresblocks``
    residual blocks, a skip connection back to the first convolution, then
    three successive ``conv_upsample`` stages. After the skip connection and
    after every upsampling stage, a 1x1 convolution with ``nchannels`` filters
    produces one output image.

    NOTE(review): the keys 8, 4, 2, 1 presumably denote the resampling factor
    of each output relative to the final one -- confirm against the caller.

    :param lr_image: input of the first convolution
    :param scope: name of the enclosing variable scope
    :param nchannels: number of filters of every output convolution
    :param nresblocks: number of residual blocks in the encoder
    :param dim: number of filters of the encoder; the upsampling stages use
        ``4 * dim``
    :return: dict of the output images, stored under the keys 8, 4, 2 and 1
        in that order
    """
    variable_scope = tf.compat.v1.variable_scope
    hr_images = {}
    wide_dim = 4 * dim

    def conv_upsample(x, dim, ksize, name):
        upsampled = blur2d(upscale2d_conv2d(x, dim, ksize, name))
        return pixel_norm(lrelu(upsampled))

    def _residule_block(x, dim, name):
        with variable_scope(name):
            branch = pixel_norm(lrelu(conv(x, dim, 3, 1, "conv1")))
            branch = pixel_norm(conv(branch, dim, 3, 1, "conv2"))
            return branch + x

    def conv_bn(x, dim, ksize, name):
        return pixel_norm(lrelu(conv(x, dim, ksize, 1, name)))

    def _make_output(net, factor):
        # Called from inside each stage scope, so every "output" convolution
        # gets its own variables.
        hr_images[factor] = conv(net, nchannels, 1, 1, "output")

    # (scope name, kernel size of the last convolution, key of the output)
    upsampling_stages = (
        ("res_2x", 5, 4),
        ("res_4x", 9, 2),
        ("res_8x", 9, 1),
    )

    with variable_scope(scope, reuse=tf.compat.v1.AUTO_REUSE):
        with variable_scope("encoder"):
            net = lrelu(conv(lr_image, dim, 9, 1, "conv1_9x9"))
            conv1 = net  # kept for the skip connection below
            for block_index in range(nresblocks):
                net = _residule_block(
                    net, dim=dim, name="ResBlock{}".format(block_index)
                )

        with variable_scope("res_1x"):
            net = conv(net, dim, 3, 1, "conv1")
            net = pixel_norm(net)
            net += conv1
            _make_output(net, factor=8)

        for stage_name, last_ksize, key in upsampling_stages:
            with variable_scope(stage_name):
                net = conv_upsample(net, wide_dim, 3, "conv_upsample")
                net = conv_bn(net, wide_dim, 3, "conv1")
                net = conv_bn(net, wide_dim, 3, "conv2")
                net = conv_bn(net, wide_dim, last_ksize, "conv3")
                _make_output(net, factor=key)

    return hr_images
Any ideas or suggestions?
remicres commented
Hi @EmanuelCastanho, I guess you're on the right track. You also have to modify the main program to feed the various resampled inputs.