Enabling PyTorch as backend for the GAN training step
mbarbetti commented
Starting from the v0.2.0 release, PIDGAN is compatible with the new multi-backend Keras 3.
Keras 3 is a full rewrite of Keras that enables you to run your Keras workflows on top of either JAX, TensorFlow, or PyTorch, and that unlocks brand new large-scale model training and deployment capabilities.
At the moment, GAN models can only be trained with the TensorFlow backend. For example, lines 173-183 of the Keras 3-based GAN class read:
```python
def train_step(self, *args, **kwargs):
    if keras.backend.backend() == "tensorflow":
        return self._tf_train_step(*args, **kwargs)
    elif keras.backend.backend() == "torch":
        raise NotImplementedError("`train_step()` not implemented for the PyTorch backend")
    elif keras.backend.backend() == "jax":
        raise NotImplementedError("`train_step()` not implemented for the Jax backend")
```
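A minimal sketch of how this dispatch could look once a PyTorch branch exists. `GANDispatchSketch`, `_torch_train_step`, and passing the backend name to the constructor are illustrative assumptions, not PIDGAN code (the real class queries `keras.backend.backend()`), and both step methods are placeholders that only echo their input.

```python
class GANDispatchSketch:
    """Hypothetical stand-in for the GAN class: only the backend dispatch is shown."""

    def __init__(self, backend):
        # Assumption: the real class would call keras.backend.backend() instead
        self.backend = backend

    def train_step(self, *args, **kwargs):
        # Map each supported backend name to its backend-specific step
        steps = {
            "tensorflow": self._tf_train_step,
            "torch": self._torch_train_step,  # the branch this issue asks for
        }
        if self.backend not in steps:
            raise NotImplementedError(
                f"`train_step()` not implemented for the {self.backend} backend"
            )
        return steps[self.backend](*args, **kwargs)

    def _tf_train_step(self, data):
        # Placeholder for the existing TensorFlow implementation
        return {"backend": "tensorflow", "batch": data}

    def _torch_train_step(self, data):
        # Placeholder: the PyTorch-specific generator/discriminator updates go here
        return {"backend": "torch", "batch": data}
```

Keeping the supported backends in one dictionary means any backend without an implementation (JAX, for now) still ends up at a single `NotImplementedError`.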
The goal of this issue is to implement `train_step()` for the PyTorch backend as well. In addition to the "plain" training step, the Lipschitz regularization functions should also be adapted to work with the PyTorch backend.
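As a rough reference for what a PyTorch step has to do, here is a minimal sketch in plain PyTorch (outside Keras). It assumes a Wasserstein-style critic loss and uses a WGAN-GP gradient penalty as the Lipschitz regularizer; PIDGAN's actual GAN variants may use different losses and penalties. The functions `gradient_penalty` and `train_step` and the parameters `gp_weight` and `lipschitz_target` are hypothetical names. The pattern is: zero the gradients, call `backward()`, then step the optimizer, alternating between the critic and the generator. The penalty needs `create_graph=True` so that it stays differentiable with respect to the critic weights.

```python
import torch
from torch import nn


def gradient_penalty(critic, real, fake, lipschitz_target=1.0):
    """WGAN-GP style penalty: push ||grad_x D(x_hat)||_2 towards lipschitz_target."""
    # One interpolation coefficient per sample, broadcast over the feature dims
    eps_shape = (real.shape[0],) + (1,) * (real.dim() - 1)
    eps = torch.rand(eps_shape, device=real.device)
    x_hat = (eps * real.detach() + (1.0 - eps) * fake.detach()).requires_grad_(True)
    # create_graph=True keeps the penalty differentiable w.r.t. the critic weights
    (grads,) = torch.autograd.grad(critic(x_hat).sum(), x_hat, create_graph=True)
    grad_norm = grads.flatten(start_dim=1).norm(2, dim=1)
    return ((grad_norm - lipschitz_target) ** 2).mean()


def train_step(generator, critic, g_opt, d_opt, real, latent_dim, gp_weight=10.0):
    batch_size = real.shape[0]

    # Critic (discriminator) update; generated samples are detached
    z = torch.randn(batch_size, latent_dim, device=real.device)
    fake = generator(z).detach()
    d_opt.zero_grad()
    d_loss = critic(fake).mean() - critic(real).mean()
    gp = torch.zeros((), device=real.device)
    if gp_weight > 0.0:  # gp_weight = 0 gives the "plain" step
        gp = gradient_penalty(critic, real, fake)
    (d_loss + gp_weight * gp).backward()
    d_opt.step()

    # Generator update; only g_opt steps, so critic weights are left unchanged
    z = torch.randn(batch_size, latent_dim, device=real.device)
    g_opt.zero_grad()
    g_loss = -critic(generator(z)).mean()
    g_loss.backward()
    g_opt.step()
    return {"d_loss": d_loss.item(), "gp": gp.item(), "g_loss": g_loss.item()}


# Usage on toy tabular data (shapes and sizes are arbitrary)
torch.manual_seed(0)
latent_dim, n_features = 4, 3
generator = nn.Sequential(nn.Linear(latent_dim, 16), nn.ReLU(), nn.Linear(16, n_features))
critic = nn.Sequential(nn.Linear(n_features, 16), nn.ReLU(), nn.Linear(16, 1))
g_opt = torch.optim.Adam(generator.parameters(), lr=1e-3)
d_opt = torch.optim.Adam(critic.parameters(), lr=1e-3)
real = torch.randn(32, n_features)
logs = train_step(generator, critic, g_opt, d_opt, real, latent_dim)
print(logs)
```

Inside Keras 3 with the torch backend, the same sequence would run within the model's `train_step()`, and the computed gradients would be handed to the Keras optimizers instead of calling `torch.optim` directly.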