Multiple functions sharing some gradients
ludvigk opened this issue · 1 comments
Hi,
I was wondering whether there is a way to track gradients for multiple functions that share part of their computation. For example, in GANs the generator loss and the discriminator loss both depend on the output of the discriminator, so that output should not need to be computed twice per update. In TensorFlow, this can be done with multiple gradient tapes.
This example is taken from the TensorFlow DCGAN tutorial: https://www.tensorflow.org/tutorials/generative/dcgan
def train_step(images):
    noise = tf.random.normal([BATCH_SIZE, noise_dim])

    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        generated_images = generator(noise, training=True)

        real_output = discriminator(images, training=True)
        fake_output = discriminator(generated_images, training=True)

        gen_loss = generator_loss(fake_output)
        disc_loss = discriminator_loss(real_output, fake_output)

    gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables)
    gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables)

    generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables))
    discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables))
I just realized that in a simple case like this, one can take the gradient of the discriminator output separately and use the chain rule to avoid the extra computation. It won't be pretty when the functions are more entangled, but I suppose it's the best option.
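To make the chain rule idea concrete, here is a minimal sketch in plain Python with hand-written derivatives. It is not tied to any particular autodiff library, and the names `shared`, `loss_a`, `loss_b` and `both_gradients_scalar` are hypothetical stand-ins rather than anything from the tutorial above. The shared scalar `s` and its gradient are computed once, and each loss only contributes its own scalar derivative with respect to `s`.

```python
import math


def shared(theta):
    """Shared scalar computation s = tanh(t0 * t1) and its gradient w.r.t. theta."""
    s = math.tanh(theta[0] * theta[1])
    ds_dz = 1.0 - s * s  # derivative of tanh at z = t0 * t1
    grad_s = [ds_dz * theta[1], ds_dz * theta[0]]
    return s, grad_s


def loss_a(s):
    """First loss on the shared value; returns (value, d loss / d s)."""
    return s * s, 2.0 * s


def loss_b(s):
    """Second loss on the shared value; returns (value, d loss / d s)."""
    return math.exp(s), math.exp(s)


def both_gradients_scalar(theta):
    # The shared part (forward and gradient) is evaluated exactly once.
    s, grad_s = shared(theta)
    la, dla_ds = loss_a(s)
    lb, dlb_ds = loss_b(s)
    # Chain rule: d loss / d theta = (d loss / d s) * (d s / d theta).
    grad_a = [dla_ds * g for g in grad_s]
    grad_b = [dlb_ds * g for g in grad_s]
    return (la, grad_a), (lb, grad_b)
```

Calling `both_gradients_scalar([0.3, -0.7])` returns both loss values with their gradients while evaluating `shared` only once.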
I just want to add that the chain rule method doesn't work when the shared computation does not produce a scalar. I am wondering whether there is a more general way of doing these calculations.
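One possible way to handle the non-scalar case, sketched here with NumPy and hand-written derivatives, is to run the shared forward pass once and then pull each loss's gradient with respect to the shared output back through the shared part as a vector-Jacobian product. This is only an illustration under assumptions: the shared function `tanh(W @ x)`, the two losses, and the names `shared_forward`, `shared_vjp` and `both_gradients_vector` are all hypothetical, and whether a given autodiff library exposes an equivalent operation is not something this thread confirms.

```python
import numpy as np


def shared_forward(W, x):
    """Shared vector-valued computation y = tanh(W @ x), evaluated once."""
    return np.tanh(W @ x)


def shared_vjp(x, y, cotangent):
    """Pull a cotangent dL/dy back to dL/dW without forming the full Jacobian."""
    # tanh'(z) = 1 - tanh(z)^2 elementwise, then an outer product with the input x.
    return np.outer(cotangent * (1.0 - y ** 2), x)


def both_gradients_vector(W, x):
    y = shared_forward(W, x)         # shared forward pass, done once
    loss_a = np.sum(y ** 2)          # dL_a/dy = 2 * y
    loss_b = np.sum(np.exp(y))       # dL_b/dy = exp(y)
    # One backward pass through the shared part per loss, reusing the stored y.
    grad_a = shared_vjp(x, y, 2.0 * y)
    grad_b = shared_vjp(x, y, np.exp(y))
    return (loss_a, grad_a), (loss_b, grad_b)
```

In this sketch only the forward pass is shared. Each loss still needs its own backward pass through the shared part, because each loss supplies a different cotangent vector.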