kvfrans/variational-autoencoder

Graph definition forces input to have batchsize number of images

hannah-rae opened this issue · 1 comment

Suppose you want to restore the trained/saved model and test it on new inputs. As written, the model is inflexible about the number of inputs passed in: you must always feed a number of examples equal to the batch size the model was trained with. You can't, for example, run the model on just one test image, which would be nice.
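
For instance, here is a minimal repro sketch (it assumes the model class LatentAttention from the repo's main.py, and the checkpoint path is a placeholder). Restoring the model and feeding a single image fails because the graph hard-codes self.batchsize in its reshape and output-shape ops:

    import tensorflow as tf
    from main import LatentAttention  # the model class defined in this repo

    model = LatentAttention()  # builds the graph with self.batchsize = 100
    single_image = model.mnist.test.images[:1]  # shape (1, 784)

    with tf.Session() as sess:
        tf.train.Saver().restore(sess, "./checkpoint/model.ckpt")  # placeholder path
        # Fails with InvalidArgumentError: the hard-coded [100, ...] reshapes
        # cannot accept a 1-image batch.
        sess.run(model.generated_images, feed_dict={model.images: single_image})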

Fixed this by replacing self.batchsize in the model definition with tf.shape(self.images)[0] and passing an extra argument (the number of input channels, prev_size) to conv_transpose():

    def __init__(self):
        self.writer = tf.summary.FileWriter('./log')

        self.mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
        self.n_samples = self.mnist.train.num_examples

        self.n_hidden = 500
        self.n_z = 20
        self.batchsize = 100

        self.images = tf.placeholder(tf.float32, [None, 784], name='input')
        image_matrix = tf.reshape(self.images, [-1, 28, 28, 1], name='reshaped_input')
        z_mean, z_stddev = self.recognition(image_matrix)

        # Reparameterization trick: sample epsilon ~ N(0, 1), with the batch
        # dimension taken dynamically from the input tensor
        samples = tf.random_normal([tf.shape(self.images)[0], self.n_z], 0, 1, dtype=tf.float32)
        guessed_z = z_mean + (z_stddev * samples)

        self.generated_images = self.generation(guessed_z)
        generated_flat = tf.reshape(self.generated_images, [tf.shape(self.images)[0], 28*28])

        # Per-example binary cross-entropy between the input and its reconstruction
        self.generation_loss = -tf.reduce_sum(self.images * tf.log(1e-6 + generated_flat)
                                              + (1 - self.images) * tf.log(1e-6 + 1 - generated_flat), 1)

        # KL divergence between N(z_mean, z_stddev^2) and the standard normal prior
        self.latent_loss = 0.5 * tf.reduce_sum(tf.square(z_mean) + tf.square(z_stddev) - tf.log(tf.square(z_stddev)) - 1, 1)
        self.cost = tf.reduce_mean(self.generation_loss + self.latent_loss)
        self.optimizer = tf.train.AdamOptimizer(0.001).minimize(self.cost)


    # encoder
    def recognition(self, input_images):
        with tf.variable_scope("recognition"):
            h1 = lrelu(conv2d(input_images, 1, 16, "d_h1")) # 28x28x1 -> 14x14x16
            h2 = lrelu(conv2d(h1, 16, 32, "d_h2")) # 14x14x16 -> 7x7x32
            h2_flat = tf.reshape(h2, [tf.shape(self.images)[0], 7*7*32])

            w_mean = dense(h2_flat, 7*7*32, self.n_z, "w_mean")
            w_stddev = dense(h2_flat, 7*7*32, self.n_z, "w_stddev")

        return w_mean, w_stddev

    # decoder
    def generation(self, z):
        with tf.variable_scope("generation"):
            z_develop = dense(z, self.n_z, 7*7*32, scope='z_matrix')
            z_matrix = tf.nn.relu(tf.reshape(z_develop, [tf.shape(self.images)[0], 7, 7, 32]))
            h1 = tf.nn.relu(conv_transpose(z_matrix, 32, [tf.shape(self.images)[0], 14, 14, 16], name="g_h1"))
            h2 = conv_transpose(h1, 16, [tf.shape(self.images)[0], 28, 28, 1], name="g_h2")
            h2 = tf.nn.sigmoid(h2)

        return h2

    def conv_transpose(x, prev_size, outputShape, name):
        with tf.variable_scope(name):
            # conv2d_transpose filters are [height, width, out_channels, in_channels]
            w = tf.get_variable("w", [5, 5, outputShape[-1], prev_size], initializer=tf.truncated_normal_initializer(stddev=0.02))
            b = tf.get_variable("b", [outputShape[-1]], initializer=tf.constant_initializer(0.0))
            # tf.stack builds a shape tensor from the mixed list of the dynamic
            # batch dimension and the static spatial/channel dimensions
            convt = tf.nn.conv2d_transpose(x, w, output_shape=tf.stack(outputShape), strides=[1, 2, 2, 1])
            return tf.nn.bias_add(convt, b)
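
With these changes the batch dimension is resolved at run time from whatever tensor is fed to self.images, so the repro sketch above now runs on a single image (again, the checkpoint path is a placeholder):

    with tf.Session() as sess:
        tf.train.Saver().restore(sess, "./checkpoint/model.ckpt")
        reconstruction = sess.run(model.generated_images,
                                  feed_dict={model.images: single_image})
        print(reconstruction.shape)  # (1, 28, 28, 1)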