google-deepmind/dm-haiku

BatchNorm No value for 'average' ValueError, perhaps set an init function?

noahadhikari opened this issue · 2 comments

I currently have the following module and am trying to get a simple forward pass working.

class TopBranch(hk.Module):
    """Branch that applies a 2-D convolution, batch normalization, then ReLU.

    NOTE(review): per the traceback in this report, calling ``hk.BatchNorm``
    with ``is_training=False`` reads its moving averages through
    ``hk.get_state("average")``, which raises ``ValueError`` when no value has
    been stored yet. That is the situation when ``init`` is run with
    ``is_training=False``.
    """

    def __init__(self, name: Optional[str] = None):
        """Build the convolution and batch-norm submodules.

        Args:
            name: Optional module name, forwarded to ``hk.Module.__init__``.
        """
        super().__init__(name=name)
        # Presumably 32 output channels with a 5x5 kernel (as the explicit
        # name suggests) and a stride of (1, 3) -- confirm against hk.Conv2D.
        self.conv32_5_5 = hk.Conv2D(32, (5, 5), stride=(1, 3), name="conv32_5_5")
        # No explicit name is given; the error message in this report shows
        # the submodule ends up as "batch_norm" under "top_branch".
        # decay_rate=0.9 presumably controls the mean/var moving averages.
        self.bn = hk.BatchNorm(create_scale=True, create_offset=True, decay_rate=0.9)
        
    def __call__(self, x: jnp.ndarray, is_training: bool) -> jnp.ndarray:
        """Run the branch on ``x``.

        Args:
            x: Input array passed to the convolution.
            is_training: Forwarded unchanged to the batch-norm layer.

        Returns:
            The result of ``jax.nn.relu`` applied to the normalized
            convolution output.
        """
        x = self.conv32_5_5(x)
        # With is_training=False this line needs batch-norm state to already
        # exist (see the class-level note); it is the line the traceback
        # points at.
        x = self.bn(x, is_training=is_training)
        x = jax.nn.relu(x)
        return x

When I remove the batch normalization submodule, the code works fine, and is able to output something. However, when I run the code as defined above, I get the following error:
ValueError: No value for 'average' in 'posteriorgram_model/~/top_branch/~/batch_norm/~/mean_ema', perhaps set an init function?

I have tried my best to replicate example code and do not see where I'm going wrong.
Locally, I'm using dm-haiku=0.0.9, jax=0.3.25, jaxlib=0.3.25, on python=3.8.10, and have replicated this issue on Colab with the default package versions there.
Thank you for your guidance in advance!

Below is the stack trace (TopBranch is located inside PosteriorgramModel):

ValueError                                Traceback (most recent call last)
Cell In[25], line 15
     13 model = hk.transform(f)
     14 rng = jax.random.PRNGKey(0)
---> 15 params = model.init(rng, out, False)
     16 model.apply(params, rng=rng, x=out, is_training=False)
     17 print("hi")

File ~/.local/lib/python3.8/site-packages/haiku/_src/transform.py:114, in without_state.<locals>.init_fn(*args, **kwargs)
    113 def init_fn(*args, **kwargs):
--> 114   params, state = f.init(*args, **kwargs)
    115   if state:
    116     raise ValueError("If your transformed function uses `hk.{get,set}_state` "
    117                      "then use `hk.transform_with_state`.")

File ~/.local/lib/python3.8/site-packages/haiku/_src/transform.py:338, in transform_with_state.<locals>.init_fn(rng, *args, **kwargs)
    336 with base.new_context(rng=rng) as ctx:
    337   try:
--> 338     f(*args, **kwargs)
    339   except jax.errors.UnexpectedTracerError as e:
    340     raise jax.errors.UnexpectedTracerError(unexpected_tracer_hint) from e

Cell In[25], line 12, in f(x, is_training)
     10 a = PosteriorgramModel()
     11 # a = hk.IdentityCore()
---> 12 return a(x, is_training)

File ~/.local/lib/python3.8/site-packages/haiku/_src/module.py:426, in wrap_method.<locals>.wrapped(self, *args, **kwargs)
    423   if method_name != "__call__":
    424     f = jax.named_call(f, name=method_name)
--> 426 out = f(*args, **kwargs)
    428 # Module names are set in the constructor. If `f` is the constructor then
    429 # its name will only be set **after** `f` has run. For methods other
    430 # than `__init__` we need the name before running in order to wrap their
    431 # execution with `named_call`.
    432 if module_name is None:

File /usr/lib/python3.8/contextlib.py:75, in ContextDecorator.__call__.<locals>.inner(*args, **kwds)
     72 @wraps(func)
     73 def inner(*args, **kwds):
     74     with self._recreate_cm():
---> 75         return func(*args, **kwds)

File ~/.local/lib/python3.8/site-packages/haiku/_src/module.py:272, in run_interceptors(bound_method, method_name, self, *args, **kwargs)
    270 """Runs any method interceptors or the original method."""
    271 if not interceptor_stack:
--> 272   return bound_method(*args, **kwargs)
    274 ctx = MethodContext(module=self,
    275                     method_name=method_name,
    276                     orig_method=bound_method)
    277 interceptor_stack_copy = interceptor_stack.clone()

File /mnt/c/Users/noaha/berkeley/berkeley-fa22/cs182/project/v3/new_model_in_jax.py:91, in PosteriorgramModel.__call__(self, audio, is_training)
     89 yp = self.yp_branch(processed, is_training)
     90 yn = self.yn_branch(yp)
---> 91 top = self.top_branch(processed, is_training)
     92 concat = jax.numpy.concatenate([top, yn], axis=-1) 
     93 yo = self.yo_branch(concat)

File ~/.local/lib/python3.8/site-packages/haiku/_src/module.py:426, in wrap_method.<locals>.wrapped(self, *args, **kwargs)
    423   if method_name != "__call__":
    424     f = jax.named_call(f, name=method_name)
--> 426 out = f(*args, **kwargs)
    428 # Module names are set in the constructor. If `f` is the constructor then
    429 # its name will only be set **after** `f` has run. For methods other
    430 # than `__init__` we need the name before running in order to wrap their
    431 # execution with `named_call`.
    432 if module_name is None:

File /usr/lib/python3.8/contextlib.py:75, in ContextDecorator.__call__.<locals>.inner(*args, **kwds)
     72 @wraps(func)
     73 def inner(*args, **kwds):
     74     with self._recreate_cm():
---> 75         return func(*args, **kwds)

File ~/.local/lib/python3.8/site-packages/haiku/_src/module.py:272, in run_interceptors(bound_method, method_name, self, *args, **kwargs)
    270 """Runs any method interceptors or the original method."""
    271 if not interceptor_stack:
--> 272   return bound_method(*args, **kwargs)
    274 ctx = MethodContext(module=self,
    275                     method_name=method_name,
    276                     orig_method=bound_method)
    277 interceptor_stack_copy = interceptor_stack.clone()

File /mnt/c/Users/noaha/berkeley/berkeley-fa22/cs182/project/v3/new_model_in_jax.py:30, in TopBranch.__call__(self, x, is_training)
     28 def __call__(self, x: jnp.ndarray, is_training: bool) -> jnp.ndarray:
     29     x = self.conv32_5_5(x)
---> 30     x = self.bn(x, is_training=is_training)
     31     x = jax.nn.relu(x)
     32     return x

File ~/.local/lib/python3.8/site-packages/haiku/_src/module.py:426, in wrap_method.<locals>.wrapped(self, *args, **kwargs)
    423   if method_name != "__call__":
    424     f = jax.named_call(f, name=method_name)
--> 426 out = f(*args, **kwargs)
    428 # Module names are set in the constructor. If `f` is the constructor then
    429 # its name will only be set **after** `f` has run. For methods other
    430 # than `__init__` we need the name before running in order to wrap their
    431 # execution with `named_call`.
    432 if module_name is None:

File /usr/lib/python3.8/contextlib.py:75, in ContextDecorator.__call__.<locals>.inner(*args, **kwds)
     72 @wraps(func)
     73 def inner(*args, **kwds):
     74     with self._recreate_cm():
---> 75         return func(*args, **kwds)

File ~/.local/lib/python3.8/site-packages/haiku/_src/module.py:272, in run_interceptors(bound_method, method_name, self, *args, **kwargs)
    270 """Runs any method interceptors or the original method."""
    271 if not interceptor_stack:
--> 272   return bound_method(*args, **kwargs)
    274 ctx = MethodContext(module=self,
    275                     method_name=method_name,
    276                     orig_method=bound_method)
    277 interceptor_stack_copy = interceptor_stack.clone()

File ~/.local/lib/python3.8/site-packages/haiku/_src/batch_norm.py:181, in BatchNorm.__call__(self, inputs, is_training, test_local_stats, scale, offset)
    179   var = mean_of_squares - jnp.square(mean)
    180 else:
--> 181   mean = self.mean_ema.average.astype(inputs.dtype)
    182   var = self.var_ema.average.astype(inputs.dtype)
    184 if is_training:

File ~/.local/lib/python3.8/site-packages/haiku/_src/module.py:426, in wrap_method.<locals>.wrapped(self, *args, **kwargs)
    423   if method_name != "__call__":
    424     f = jax.named_call(f, name=method_name)
--> 426 out = f(*args, **kwargs)
    428 # Module names are set in the constructor. If `f` is the constructor then
    429 # its name will only be set **after** `f` has run. For methods other
    430 # than `__init__` we need the name before running in order to wrap their
    431 # execution with `named_call`.
    432 if module_name is None:

File /usr/lib/python3.8/contextlib.py:75, in ContextDecorator.__call__.<locals>.inner(*args, **kwds)
     72 @wraps(func)
     73 def inner(*args, **kwds):
     74     with self._recreate_cm():
---> 75         return func(*args, **kwds)

File /usr/lib/python3.8/contextlib.py:75, in ContextDecorator.__call__.<locals>.inner(*args, **kwds)
     72 @wraps(func)
     73 def inner(*args, **kwds):
     74     with self._recreate_cm():
---> 75         return func(*args, **kwds)

File ~/.local/lib/python3.8/site-packages/haiku/_src/module.py:272, in run_interceptors(bound_method, method_name, self, *args, **kwargs)
    270 """Runs any method interceptors or the original method."""
    271 if not interceptor_stack:
--> 272   return bound_method(*args, **kwargs)
    274 ctx = MethodContext(module=self,
    275                     method_name=method_name,
    276                     orig_method=bound_method)
    277 interceptor_stack_copy = interceptor_stack.clone()

File ~/.local/lib/python3.8/site-packages/haiku/_src/moving_averages.py:137, in ExponentialMovingAverage.average(self)
    135 @property
    136 def average(self):
--> 137   return hk.get_state("average")

File ~/.local/lib/python3.8/site-packages/haiku/_src/base.py:448, in replaceable.<locals>.wrapped(*args, **kwargs)
    446 @functools.wraps(f)
    447 def wrapped(*args, **kwargs):
--> 448   return wrapped._current(*args, **kwargs)

File ~/.local/lib/python3.8/site-packages/haiku/_src/base.py:1099, in get_state(name, shape, dtype, init)
   1097 if value is None:
   1098   if init is None:
-> 1099     raise ValueError(f"No value for {name!r} in {bundle_name!r}, perhaps "
   1100                      "set an init function?")
   1101   if shape is None or dtype is None:
   1102     raise ValueError(f"Must provide shape and dtype to initialize {name!r} "
   1103                      f"in {bundle_name!r}.")

Resolved: I needed to pass is_training=True in the init call for my module, i.e. `model.init(rng, out, True)` instead of `model.init(rng, out, False)`.

Can you show how you fixed it? I ran into the same problem T_T.