Correct way to integrate tf2jax output with a hk.Module
rdilip opened this issue · 1 comment
I'm looking at the tf2jax project. Being able to take pretrained TensorFlow models and convert them to Haiku would be really useful, since few pretrained Haiku checkpoints are available. A typical use case looks like:
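```python
import tf2jax
import tensorflow as tf
import jax.numpy as jnp

# Convert a pretrained Keras ResNet50 (ImageNet weights by default) into a
# pure JAX function plus its parameters, traced with a dummy input batch.
jax_func, jax_params = tf2jax.convert(
    tf.function(tf.keras.applications.resnet50.ResNet50()),
    jnp.zeros((1, 224, 224, 3)))
```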
So now I have a function and its parameters, but I need to use them inside a Haiku module. How should I do this? Ideally, I'd like to be able to write something like
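```python
import haiku as hk

class MyModule(hk.Module):
  def __call__(self, x):
    x = ResNet50Jax()(x)  # hypothetical: a Haiku module wrapping the converted ResNet50
    # ... some other module-specific layers ...
    return x
```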
which I could then use with `hk.transform` as usual. I couldn't find an obvious way to do this. Any thoughts?
More broadly, is it a bad idea to rely on tf2jax for checkpoints, rather than building the model directly in Haiku and manually copying over the weights from PyTorch/TensorFlow?
Hi @rdilip, here is an example of integrating Haiku and tf2jax: https://colab.research.google.com/gist/tomhennigan/5a6a264bccbbe8ecac1b475ad8049c72/example-of-using-tf2jax-with-haiku.ipynb
tf2jax does introduce quite a bit of complexity and indirection (for example, if you want to fine-tune the model and take gradients through the converted TF code). I think it might be worth taking an existing checkpoint and adapting it to work with Haiku instead.
Alternatively, if you have access to GPUs/TPUs for training, we provide a training script for a ResNet-50 model on ImageNet.