ChunyuanLI/ALICE

Wrong use of tf.get_collection scope

CuriousCat-7 opened this issue · 0 comments

dvars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, "discriminator")

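According to the official documentation, tf.get_collection filters by scope with re.match, which only matches at the start of the name, so a plain scope string acts as a prefix filter. So the code dvars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, "discriminator")
returns the variables under "discriminator/", "discriminator_pair/", "discriminator_xx/" and "discriminator_zz/", not only those under "discriminator/".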
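To illustrate the prefix behaviour without building a TensorFlow graph, here is a minimal sketch that applies the same re.match-based filter to a plain list of strings. The variable names below are hypothetical examples of what the scopes might contain, and get_collection_like is an illustrative helper, not a TensorFlow function.

```python
import re


def get_collection_like(names, scope):
    """Hypothetical helper: keep names where re.match(scope, name) succeeds,
    i.e. the same prefix-style scope filter tf.get_collection applies."""
    regex = re.compile(scope)
    return [n for n in names if regex.match(n)]


# Hypothetical trainable-variable names, one per scope.
names = [
    "discriminator/dense/kernel:0",
    "discriminator_pair/dense/kernel:0",
    "discriminator_xx/dense/kernel:0",
    "discriminator_zz/dense/kernel:0",
    "generator/dense/kernel:0",
]

# "discriminator" is a prefix of every discriminator_* scope,
# so all four discriminator variables are selected.
dvars = get_collection_like(names, "discriminator")
print(dvars)
```

Only the generator variable is left out, because re.match anchors at the start of the name but not at the end of the scope string.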

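So there is no need to write train_disc_op = opt.minimize(disc_loss, var_list=dvars + dvars_xz + dvars_xx + dvars_zz); train_disc_op = opt.minimize(disc_loss, var_list=dvars) is enough, because dvars already contains the variables of all four discriminators. Alternatively, rename the "discriminator" scope to "discriminator_orig" so that each tf.get_collection call selects only its own variables.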
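A small sketch of both points, again using plain re.match on hypothetical variable names instead of a real graph. It assumes that dvars_xz, dvars_xx and dvars_zz are collected from the "discriminator_pair", "discriminator_xx" and "discriminator_zz" scopes; get_collection_like is an illustrative helper, not part of TensorFlow.

```python
import re


def get_collection_like(names, scope):
    # Hypothetical helper mimicking the re.match scope filter of tf.get_collection.
    regex = re.compile(scope)
    return [n for n in names if regex.match(n)]


# Hypothetical trainable-variable names.
names = [
    "discriminator/dense/kernel:0",
    "discriminator_pair/dense/kernel:0",
    "discriminator_xx/dense/kernel:0",
    "discriminator_zz/dense/kernel:0",
]

dvars = get_collection_like(names, "discriminator")           # all 4 variables
dvars_xz = get_collection_like(names, "discriminator_pair")   # assumed scope
dvars_xx = get_collection_like(names, "discriminator_xx")
dvars_zz = get_collection_like(names, "discriminator_zz")

# The concatenated var_list repeats every variable from the *_pair/_xx/_zz scopes.
combined = dvars + dvars_xz + dvars_xx + dvars_zz
print(len(combined), len(set(combined)))

# After renaming the first scope, "discriminator_orig" matches only its own variable.
renamed = [n.replace("discriminator/", "discriminator_orig/") for n in names]
orig_only = get_collection_like(renamed, "discriminator_orig")
print(orig_only)
```

The concatenated list has 7 entries but only 4 distinct variables, which is why var_list=dvars alone covers everything.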

I know this does not cause any problems in the current code, but it is misleading to others reading it.

Best regards