eriklindernoren/Keras-GAN

[Pix2Pix] Use fit_generator to speed up training process

wave-transmitter opened this issue · 0 comments

Hello GAN-comrades,

In this implementation, train_on_batch is used to train the networks, which seems to be the best option since each iteration requires a two-step process: training the generator and the discriminator separately.

            # ---------------------
            #  Train Discriminator
            # ---------------------

            # Condition on B and generate a translated version
            fake_A = self.generator.predict(imgs_B)

            # Train the discriminators (original images = real / generated = Fake)
            d_loss_real = self.discriminator.train_on_batch([imgs_A, imgs_B], valid)
            d_loss_fake = self.discriminator.train_on_batch([fake_A, imgs_B], fake)
            d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

            # -----------------
            #  Train Generator
            # -----------------

            # Train the generators
            g_loss = self.combined.train_on_batch([imgs_A, imgs_B], [valid, imgs_A])

On the other hand, fit_generator allows speeding up the process and dealing with the CPU bottleneck of data preprocessing via its workers argument. I was wondering if fit_generator could somehow be used in our case.
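For context, this is roughly the kind of call I have in mind. It is just an illustrative sketch, not code from this repo: `model` stands for an ordinary Keras model and `data_sequence` for a hypothetical keras.utils.Sequence that yields (inputs, targets) batches.

    # Hypothetical standard fit_generator usage (not from this repo):
    # `data_sequence` would be a keras.utils.Sequence yielding (inputs, targets).
    model.fit_generator(
        data_sequence,
        epochs=200,
        workers=4,                 # preprocess batches in parallel CPU workers
        use_multiprocessing=True,
        max_queue_size=10)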

To be honest, I cannot imagine how this could be done, since the generator and discriminator must be trained alternately on each batch. Any ideas or tips on how to implement a fit_generator approach, or otherwise employ more CPU workers for training? A sketch of the closest workaround I can think of follows below.
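The idea would be to keep the alternating train_on_batch updates exactly as they are, but move the CPU-heavy data loading/preprocessing into a keras.utils.Sequence and prefetch batches with keras.utils.OrderedEnqueuer, which also takes a workers argument. A rough sketch, assuming the usual Keras enqueuer API; Pix2PixSequence, load_batch_by_index and n_batches are hypothetical placeholders, and generator/discriminator/combined/data_loader stand for the corresponding attributes in pix2pix.py:

    import numpy as np
    from keras.utils import Sequence, OrderedEnqueuer

    class Pix2PixSequence(Sequence):
        """Hypothetical Sequence doing the CPU-heavy image loading/preprocessing."""
        def __init__(self, data_loader, batch_size):
            self.data_loader = data_loader
            self.batch_size = batch_size

        def __len__(self):
            return self.data_loader.n_batches  # hypothetical attribute

        def __getitem__(self, idx):
            # Load and preprocess one batch of paired images (imgs_A, imgs_B)
            return self.data_loader.load_batch_by_index(idx, self.batch_size)  # hypothetical helper

    batch_size = 1
    disc_patch = (16, 16, 1)  # PatchGAN output shape for 256x256 inputs (illustrative)
    valid = np.ones((batch_size,) + disc_patch)
    fake = np.zeros((batch_size,) + disc_patch)

    # Prefetch batches with several CPU workers, but keep the alternating
    # train_on_batch updates exactly as in the current training loop.
    sequence = Pix2PixSequence(data_loader, batch_size)
    enqueuer = OrderedEnqueuer(sequence, use_multiprocessing=True)
    enqueuer.start(workers=4, max_queue_size=10)
    batches = enqueuer.get()  # generator yielding preprocessed batches

    for step in range(steps):
        imgs_A, imgs_B = next(batches)

        # Train the discriminator on real and translated pairs
        fake_A = generator.predict(imgs_B)
        d_loss_real = discriminator.train_on_batch([imgs_A, imgs_B], valid)
        d_loss_fake = discriminator.train_on_batch([fake_A, imgs_B], fake)
        d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

        # Train the generator through the combined model
        g_loss = combined.train_on_batch([imgs_A, imgs_B], [valid, imgs_A])

    enqueuer.stop()

Would something like this be a sensible replacement for the current loop, or is there a cleaner way to get fit_generator-style worker parallelism here?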