google-deepmind/dm-haiku

hk.BatchNorm with jax.vmap


Is there a workaround that lets me use jax.vmap with hk.BatchNorm? Should I use hk.vmap instead, or should I write a custom BatchNorm?
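To make the "custom BatchNorm" option concrete, here is a minimal sketch in plain JAX of what I have in mind. It assumes the batch statistics can be computed with a collective (`jax.lax.pmean`) over a named `jax.vmap` axis. The function name `batch_norm_per_example` and the axis name `"batch"` are made up for illustration and are not part of Haiku. It only covers the training-time normalization and keeps no running averages.

```python
import jax
import jax.numpy as jnp


def batch_norm_per_example(x, scale, offset, eps=1e-5):
    """Normalizes a single example of shape [features].

    Hypothetical helper: it must be called under jax.vmap with
    axis_name="batch", so that pmean averages over the mapped batch axis.
    """
    mean = jax.lax.pmean(x, axis_name="batch")
    mean_of_squares = jax.lax.pmean(jnp.square(x), axis_name="batch")
    var = mean_of_squares - jnp.square(mean)
    return scale * (x - mean) * jax.lax.rsqrt(var + eps) + offset


# Map over the leading axis of x only; scale and offset are shared.
batched_norm = jax.vmap(
    batch_norm_per_example, in_axes=(0, None, None), axis_name="batch"
)

x = jnp.arange(12.0).reshape(4, 3)  # 4 examples, 3 features
y = batched_norm(x, jnp.ones(3), jnp.zeros(3))
```

Here `y` has the same shape as `x`, and each feature column is normalized with statistics taken across the vmapped axis, which is the behaviour I would like to get from hk.BatchNorm.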