hk.BatchNorm with jax.vmap
Opened this issue · 0 comments
reemabdelrazek30 commented
Is there any workaround that would let me use jax.vmap with hk.BatchNorm? Should I use hk.vmap instead, or should I write a custom BatchNorm?
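For reference, here is a minimal sketch of what the "custom BatchNorm" option could look like in plain JAX. It assumes the batch statistics should be computed across the axis that jax.vmap maps over, and it does this by naming that axis and reducing over it with jax.lax.pmean. The names batch_norm_single and batch_norm_vmapped are hypothetical, nothing here uses Haiku, and the moving averages that hk.BatchNorm tracks are left out, so treat it as an illustration rather than a drop-in replacement.

```python
import jax
import jax.numpy as jnp


def batch_norm_single(x, scale, offset, axis_name="batch", eps=1e-5):
    """Normalize ONE example of shape [features].

    The batch axis does not exist inside this function; it is supplied by
    jax.vmap. Statistics are gathered across it through the named axis.
    """
    # Mean over the vmapped (batch) axis, per feature.
    mean = jax.lax.pmean(x, axis_name)
    # Biased variance over the same axis.
    var = jax.lax.pmean((x - mean) ** 2, axis_name)
    return scale * (x - mean) * jax.lax.rsqrt(var + eps) + offset


def batch_norm_vmapped(xs, scale, offset):
    """Apply batch_norm_single over the leading axis of xs: [batch, features]."""
    return jax.vmap(
        batch_norm_single,
        in_axes=(0, None, None),  # map over xs only; share scale and offset
        axis_name="batch",        # must match axis_name used in pmean
    )(xs, scale, offset)


xs = jnp.arange(12.0).reshape(4, 3)  # 4 examples, 3 features
scale = jnp.ones(3)
offset = jnp.zeros(3)
out = batch_norm_vmapped(xs, scale, offset)
print(out.shape)
```

The key detail is that the axis_name passed to jax.vmap has to match the one used in the reduction; without a named axis, code inside the mapped function only ever sees a single example and cannot compute batch statistics.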