Misleading scope filter in tf.get_collection call
CuriousCat-7 opened this issue · 0 comments
Line 255 in 2cb1b2b
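According to the official documentation, tf.get_collection filters by scope using re.match, which only requires the scope to match at the start of a variable name. So the code
dvars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, "discriminator")
already returns the trainable variables under all of "discriminator/", "discriminator_pair/", "discriminator_xx/" and "discriminator_zz/", not only those under "discriminator/".
There is therefore no need to write
train_disc_op = opt.minimize(disc_loss, var_list=dvars + dvars_xz + dvars_xx + dvars_zz)
because
train_disc_op = opt.minimize(disc_loss, var_list=dvars)
already covers the same variables. Alternatively, rename the "discriminator" scope to something like "discriminator_orig" so that the prefix match no longer picks up the other discriminators.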
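The prefix matching can be checked without TensorFlow. This is a minimal sketch, assuming tf.get_collection filters items by applying re.match(scope, item.name), as its documentation describes. The helper filter_by_scope and the variable names are hypothetical, chosen only for illustration.

```python
import re


def filter_by_scope(names, scope):
    """Keep the names that re.match(scope, name) accepts.

    Hypothetical helper mimicking the documented scope filtering of
    tf.get_collection: re.match anchors at the start of the string but
    not at the end, so it behaves like a prefix (regex) match.
    """
    pattern = re.compile(scope)
    return [n for n in names if pattern.match(n)]


# Hypothetical trainable-variable names; real names depend on the model.
names = [
    "discriminator/dense/kernel:0",
    "discriminator_pair/dense/kernel:0",
    "discriminator_xx/dense/kernel:0",
    "discriminator_zz/dense/kernel:0",
    "generator/dense/kernel:0",
]

# "discriminator" is a prefix of every discriminator_* scope,
# so all four discriminator variables are returned.
print(filter_by_scope(names, "discriminator"))

# With a trailing slash, only the "discriminator/" scope matches.
print(filter_by_scope(names, "discriminator/"))
```

Passing the scope with a trailing slash is another way to restrict the match, besides renaming the scope.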
I know this does not cause any problems when the code runs, but it is misleading to other readers.
Best regards