[Pix2Pix] Use fit_generator to speed up training process
wave-transmitter opened this issue · 0 comments
Hello GAN-comrades,
In this implementation, `train_on_batch` is used to train the network. That seems the best option, since each iteration requires a two-step process that trains the generator and the discriminator separately:
```python
# ---------------------
#  Train Discriminator
# ---------------------

# Condition on B and generate a translated version
fake_A = self.generator.predict(imgs_B)

# Train the discriminators (original images = real / generated = Fake)
d_loss_real = self.discriminator.train_on_batch([imgs_A, imgs_B], valid)
d_loss_fake = self.discriminator.train_on_batch([fake_A, imgs_B], fake)
d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

# -----------------
#  Train Generator
# -----------------

# Train the generators
g_loss = self.combined.train_on_batch([imgs_A, imgs_B], [valid, imgs_A])
```
On the other hand, `fit_generator` can speed up training and relieve the CPU bottleneck of data preprocessing through its `workers` argument. I was wondering whether `fit_generator` could somehow be used in our case.
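As a rough illustration of what the `workers` argument buys: the trainer keeps a bounded queue of upcoming batches that background workers fill while the model trains on the current batch. The sketch below reproduces that idea with the standard library only; it is not Keras' actual implementation, and `load_fn(index)` is a hypothetical function that reads and preprocesses one batch. Threads only help when the preprocessing releases the GIL (file I/O and many NumPy operations do); for pure-Python preprocessing a process pool would be needed instead.

```python
from collections import deque
from concurrent.futures import ThreadPoolExecutor


def prefetch_batches(load_fn, num_batches, workers=4, max_queue_size=10):
    """Yield load_fn(0), ..., load_fn(num_batches - 1) in order.

    Up to `max_queue_size` batches are prepared ahead of time by a pool of
    `workers` threads, so loading overlaps with whatever the caller does
    (e.g. a train_on_batch step) between iterations.
    """
    if max_queue_size < 1:
        raise ValueError("max_queue_size must be at least 1")
    with ThreadPoolExecutor(max_workers=workers) as pool:
        pending = deque()  # futures for upcoming batches, oldest first
        next_index = 0

        def submit_next():
            nonlocal next_index
            pending.append(pool.submit(load_fn, next_index))
            next_index += 1

        # Fill the look-ahead window before yielding anything.
        while next_index < num_batches and len(pending) < max_queue_size:
            submit_next()
        while pending:
            batch = pending.popleft().result()  # blocks until the oldest batch is ready
            if next_index < num_batches:
                submit_next()  # keep the window full while the caller trains
            yield batch


# Usage with a dummy loader: results come back in index order even though
# several threads load them concurrently.
def dummy_load(index):
    return index * 10


print(list(prefetch_batches(dummy_load, num_batches=5, workers=2, max_queue_size=3)))
# → [0, 10, 20, 30, 40]
```

Collecting futures in submission order (rather than as they complete) keeps the batch order deterministic, which matters when imgs_A and imgs_B must stay paired.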
To be honest, I cannot see how this could be done, since the generator and the discriminator must be trained alternately on every batch. Any ideas or tips on how to implement a `fit_generator` approach, or some other way to use more CPU workers for training?
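Because the alternating updates only ever need one batch at a time, one possible route is to keep `train_on_batch` and parallelise only the data feeding. The sketch below repeats the loop from the issue but accepts any iterable of `(imgs_A, imgs_B)` pairs; `RecordingModel` and `train_epoch` are hypothetical names, and the models are stubs standing in for the Keras generator, discriminator and combined models. In Keras 2.x, such an iterable could, for instance, come from `keras.utils.OrderedEnqueuer` wrapped around a `keras.utils.Sequence`, constructed with `use_multiprocessing`, started with `workers`, and read through its `get()` method; check the documentation of your Keras version before relying on that API.

```python
import numpy as np


class RecordingModel:
    """Hypothetical stand-in for a compiled Keras model: records each call
    and returns a fixed loss instead of training anything."""

    def __init__(self, name, log, loss=1.0):
        self.name, self.log, self.loss = name, log, loss

    def predict(self, x):
        self.log.append((self.name, "predict"))
        return x  # identity "translation", enough for the sketch

    def train_on_batch(self, x, y):
        self.log.append((self.name, "train_on_batch"))
        return self.loss


def train_epoch(batches, generator, discriminator, combined, valid, fake):
    """Run the issue's alternating D/G updates on every (imgs_A, imgs_B) pair.

    `batches` can be any iterable, including one whose batches are produced
    by background workers; the training logic does not change.
    """
    history = []
    for imgs_A, imgs_B in batches:
        # Discriminator step: real pairs labelled `valid`, generated pairs `fake`.
        fake_A = generator.predict(imgs_B)
        d_loss_real = discriminator.train_on_batch([imgs_A, imgs_B], valid)
        d_loss_fake = discriminator.train_on_batch([fake_A, imgs_B], fake)
        # np.add also averages element-wise when train_on_batch returns [loss, metric].
        d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)
        # Generator step through the combined model.
        g_loss = combined.train_on_batch([imgs_A, imgs_B], [valid, imgs_A])
        history.append((d_loss, g_loss))
    return history


# Usage with dummy data: two batches, each holding one 4x4 RGB image pair.
# The label shapes are arbitrary here; the stubs ignore them.
log = []
G = RecordingModel("G", log)
D = RecordingModel("D", log, loss=0.4)
C = RecordingModel("combined", log, loss=2.0)
batches = [(np.zeros((1, 4, 4, 3)), np.ones((1, 4, 4, 3)))] * 2
hist = train_epoch(batches, G, D, C, valid=np.ones((1, 1)), fake=np.zeros((1, 1)))
```

With real Keras models, only the `batches` argument would change; the two discriminator updates and the combined-model update stay exactly as in the original loop.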