Graph definition forces input to have batchsize number of images
hannah-rae opened this issue · 1 comment
hannah-rae commented
Suppose you want to restore a trained/saved model and test it on new inputs. The way the model is written makes it inflexible to the number of inputs passed in: you must always feed a number of examples equal to the batch size the model was trained with. You can't, for example, run the model on just one test image, which would be nice.
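A minimal sketch of the symptom (TF 1.x; the names here are illustrative, not the repo's): because the reshape ops bake the training batch size into the graph, feeding a different number of examples fails at run time.

```python
import numpy as np
import tensorflow as tf  # TF 1.x

batchsize = 100
images = tf.placeholder(tf.float32, [None, 784])
# The graph hard-codes the training batch size in the reshape:
flat = tf.reshape(images, [batchsize, 784])

with tf.Session() as sess:
    one_image = np.zeros((1, 784), dtype=np.float32)
    sess.run(flat, feed_dict={images: one_image})
    # Fails with roughly: InvalidArgumentError: Input to reshape is a tensor
    # with 784 values, but the requested shape has 78400
```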
hannah-rae commented
Fixed this by replacing `self.batchsize` in the model definition with `tf.shape(self.images)[0]` and passing an extra argument to `conv_transpose()`:
```python
# Methods of the VAE model class; lrelu/conv2d/dense are the repo's helpers from ops.py.
def __init__(self):
    self.writer = tf.summary.FileWriter('./log')
    self.mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
    self.n_samples = self.mnist.train.num_examples

    self.n_hidden = 500
    self.n_z = 20
    self.batchsize = 100  # only the training loop uses this now; the graph no longer does

    self.images = tf.placeholder(tf.float32, [None, 784], name='input')
    image_matrix = tf.reshape(self.images, [-1, 28, 28, 1], name='reshaped_input')
    z_mean, z_stddev = self.recognition(image_matrix)

    # Reparameterization: sample from N(0, 1) with a dynamic batch dimension
    samples = tf.random_normal([tf.shape(self.images)[0], self.n_z], 0, 1, dtype=tf.float32)
    guessed_z = z_mean + (z_stddev * samples)

    self.generated_images = self.generation(guessed_z)
    generated_flat = tf.reshape(self.generated_images, [tf.shape(self.images)[0], 28*28])

    # Bernoulli reconstruction loss plus KL divergence to the unit Gaussian prior
    self.generation_loss = -tf.reduce_sum(
        self.images * tf.log(1e-6 + generated_flat)
        + (1 - self.images) * tf.log(1e-6 + 1 - generated_flat), 1)
    self.latent_loss = 0.5 * tf.reduce_sum(
        tf.square(z_mean) + tf.square(z_stddev) - tf.log(tf.square(z_stddev)) - 1, 1)
    self.cost = tf.reduce_mean(self.generation_loss + self.latent_loss)
    self.optimizer = tf.train.AdamOptimizer(0.001).minimize(self.cost)

# encoder
def recognition(self, input_images):
    with tf.variable_scope("recognition"):
        h1 = lrelu(conv2d(input_images, 1, 16, "d_h1"))  # 28x28x1 -> 14x14x16
        h2 = lrelu(conv2d(h1, 16, 32, "d_h2"))           # 14x14x16 -> 7x7x32
        h2_flat = tf.reshape(h2, [tf.shape(self.images)[0], 7*7*32])

        w_mean = dense(h2_flat, 7*7*32, self.n_z, "w_mean")
        w_stddev = dense(h2_flat, 7*7*32, self.n_z, "w_stddev")
    return w_mean, w_stddev

# decoder
def generation(self, z):
    with tf.variable_scope("generation"):
        z_develop = dense(z, self.n_z, 7*7*32, scope='z_matrix')
        z_matrix = tf.nn.relu(tf.reshape(z_develop, [tf.shape(self.images)[0], 7, 7, 32]))
        h1 = tf.nn.relu(conv_transpose(z_matrix, 32, [tf.shape(self.images)[0], 14, 14, 16], name="g_h1"))
        h2 = conv_transpose(h1, 16, [tf.shape(self.images)[0], 28, 28, 1], name="g_h2")
        h2 = tf.nn.sigmoid(h2)
    return h2
```
And the updated `conv_transpose()` helper, which now takes the number of input channels (`prev_size`) as an extra argument and builds the dynamic `output_shape` with `tf.stack()`:

```python
def conv_transpose(x, prev_size, outputShape, name):
    with tf.variable_scope(name):
        # filter shape: [height, width, output_channels, input_channels]
        w = tf.get_variable("w", [5, 5, outputShape[-1], prev_size],
                            initializer=tf.truncated_normal_initializer(stddev=0.02))
        b = tf.get_variable("b", [outputShape[-1]], initializer=tf.constant_initializer(0.0))
        # tf.stack turns the mixed list (symbolic batch dim + static dims) into one shape tensor
        convt = tf.nn.conv2d_transpose(x, w, output_shape=tf.stack(outputShape), strides=[1, 2, 2, 1])
        return tf.nn.bias_add(convt, b)  # apply the bias; previously b was created but never used
```
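With that change, the restored graph accepts any leading batch dimension. A hedged usage sketch (`VAE` stands in for whatever class the methods above belong to, and the checkpoint path is an assumption):

```python
import tensorflow as tf

model = VAE()      # hypothetical name for the model class defined above
saver = tf.train.Saver()

with tf.Session() as sess:
    saver.restore(sess, "checkpoints/model.ckpt")  # assumed checkpoint path
    one_image = model.mnist.test.images[:1]        # shape (1, 784), not (100, 784)
    recon = sess.run(model.generated_images,
                     feed_dict={model.images: one_image})
    print(recon.shape)  # (1, 28, 28, 1)
```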