BatchNorm No value for 'average' ValueError, perhaps set an init function?
noahadhikari opened this issue · 2 comments
I currently have the following module and am trying to get a simple forward pass working.
class TopBranch(hk.Module):
def __init__(self, name: Optional[str] = None):
super().__init__(name=name)
self.conv32_5_5 = hk.Conv2D(32, (5, 5), stride=(1, 3), name="conv32_5_5")
self.bn = hk.BatchNorm(create_scale=True, create_offset=True, decay_rate=0.9)
def __call__(self, x: jnp.ndarray, is_training: bool) -> jnp.ndarray:
x = self.conv32_5_5(x)
x = self.bn(x, is_training=is_training)
x = jax.nn.relu(x)
return x
When I remove the batch normalization submodule, the code works fine and produces output. However, when I run the code as defined above, I get the following error:
ValueError: No value for 'average' in 'posteriorgram_model/~/top_branch/~/batch_norm/~/mean_ema', perhaps set an init function?
I have tried my best to replicate example code and do not see where I'm going wrong.
Locally, I'm using dm-haiku=0.0.9, jax=0.3.25, jaxlib=0.3.25, on python=3.8.10, and have replicated this issue on Colab with the default package versions there.
Thank you for your guidance in advance!
Below is the stack trace (TopBranch is located inside PosteriorgramModel):
ValueError Traceback (most recent call last)
Cell In[25], line 15
13 model = hk.transform(f)
14 rng = jax.random.PRNGKey(0)
---> 15 params = model.init(rng, out, False)
16 model.apply(params, rng=rng, x=out, is_training=False)
17 print("hi")
File ~/.local/lib/python3.8/site-packages/haiku/_src/transform.py:114, in without_state.<locals>.init_fn(*args, **kwargs)
113 def init_fn(*args, **kwargs):
--> 114 params, state = f.init(*args, **kwargs)
115 if state:
116 raise ValueError("If your transformed function uses `hk.{get,set}_state` "
117 "then use `hk.transform_with_state`.")
File ~/.local/lib/python3.8/site-packages/haiku/_src/transform.py:338, in transform_with_state.<locals>.init_fn(rng, *args, **kwargs)
336 with base.new_context(rng=rng) as ctx:
337 try:
--> 338 f(*args, **kwargs)
339 except jax.errors.UnexpectedTracerError as e:
340 raise jax.errors.UnexpectedTracerError(unexpected_tracer_hint) from e
Cell In[25], line 12, in f(x, is_training)
10 a = PosteriorgramModel()
11 # a = hk.IdentityCore()
---> 12 return a(x, is_training)
File ~/.local/lib/python3.8/site-packages/haiku/_src/module.py:426, in wrap_method.<locals>.wrapped(self, *args, **kwargs)
423 if method_name != "__call__":
424 f = jax.named_call(f, name=method_name)
--> 426 out = f(*args, **kwargs)
428 # Module names are set in the constructor. If `f` is the constructor then
429 # its name will only be set **after** `f` has run. For methods other
430 # than `__init__` we need the name before running in order to wrap their
431 # execution with `named_call`.
432 if module_name is None:
File /usr/lib/python3.8/contextlib.py:75, in ContextDecorator.__call__.<locals>.inner(*args, **kwds)
72 @wraps(func)
73 def inner(*args, **kwds):
74 with self._recreate_cm():
---> 75 return func(*args, **kwds)
File ~/.local/lib/python3.8/site-packages/haiku/_src/module.py:272, in run_interceptors(bound_method, method_name, self, *args, **kwargs)
270 """Runs any method interceptors or the original method."""
271 if not interceptor_stack:
--> 272 return bound_method(*args, **kwargs)
274 ctx = MethodContext(module=self,
275 method_name=method_name,
276 orig_method=bound_method)
277 interceptor_stack_copy = interceptor_stack.clone()
File /mnt/c/Users/noaha/berkeley/berkeley-fa22/cs182/project/v3/new_model_in_jax.py:91, in PosteriorgramModel.__call__(self, audio, is_training)
89 yp = self.yp_branch(processed, is_training)
90 yn = self.yn_branch(yp)
---> 91 top = self.top_branch(processed, is_training)
92 concat = jax.numpy.concatenate([top, yn], axis=-1)
93 yo = self.yo_branch(concat)
File ~/.local/lib/python3.8/site-packages/haiku/_src/module.py:426, in wrap_method.<locals>.wrapped(self, *args, **kwargs)
423 if method_name != "__call__":
424 f = jax.named_call(f, name=method_name)
--> 426 out = f(*args, **kwargs)
428 # Module names are set in the constructor. If `f` is the constructor then
429 # its name will only be set **after** `f` has run. For methods other
430 # than `__init__` we need the name before running in order to wrap their
431 # execution with `named_call`.
432 if module_name is None:
File /usr/lib/python3.8/contextlib.py:75, in ContextDecorator.__call__.<locals>.inner(*args, **kwds)
72 @wraps(func)
73 def inner(*args, **kwds):
74 with self._recreate_cm():
---> 75 return func(*args, **kwds)
File ~/.local/lib/python3.8/site-packages/haiku/_src/module.py:272, in run_interceptors(bound_method, method_name, self, *args, **kwargs)
270 """Runs any method interceptors or the original method."""
271 if not interceptor_stack:
--> 272 return bound_method(*args, **kwargs)
274 ctx = MethodContext(module=self,
275 method_name=method_name,
276 orig_method=bound_method)
277 interceptor_stack_copy = interceptor_stack.clone()
File /mnt/c/Users/noaha/berkeley/berkeley-fa22/cs182/project/v3/new_model_in_jax.py:30, in TopBranch.__call__(self, x, is_training)
28 def __call__(self, x: jnp.ndarray, is_training: bool) -> jnp.ndarray:
29 x = self.conv32_5_5(x)
---> 30 x = self.bn(x, is_training=is_training)
31 x = jax.nn.relu(x)
32 return x
File ~/.local/lib/python3.8/site-packages/haiku/_src/module.py:426, in wrap_method.<locals>.wrapped(self, *args, **kwargs)
423 if method_name != "__call__":
424 f = jax.named_call(f, name=method_name)
--> 426 out = f(*args, **kwargs)
428 # Module names are set in the constructor. If `f` is the constructor then
429 # its name will only be set **after** `f` has run. For methods other
430 # than `__init__` we need the name before running in order to wrap their
431 # execution with `named_call`.
432 if module_name is None:
File /usr/lib/python3.8/contextlib.py:75, in ContextDecorator.__call__.<locals>.inner(*args, **kwds)
72 @wraps(func)
73 def inner(*args, **kwds):
74 with self._recreate_cm():
---> 75 return func(*args, **kwds)
File ~/.local/lib/python3.8/site-packages/haiku/_src/module.py:272, in run_interceptors(bound_method, method_name, self, *args, **kwargs)
270 """Runs any method interceptors or the original method."""
271 if not interceptor_stack:
--> 272 return bound_method(*args, **kwargs)
274 ctx = MethodContext(module=self,
275 method_name=method_name,
276 orig_method=bound_method)
277 interceptor_stack_copy = interceptor_stack.clone()
File ~/.local/lib/python3.8/site-packages/haiku/_src/batch_norm.py:181, in BatchNorm.__call__(self, inputs, is_training, test_local_stats, scale, offset)
179 var = mean_of_squares - jnp.square(mean)
180 else:
--> 181 mean = self.mean_ema.average.astype(inputs.dtype)
182 var = self.var_ema.average.astype(inputs.dtype)
184 if is_training:
File ~/.local/lib/python3.8/site-packages/haiku/_src/module.py:426, in wrap_method.<locals>.wrapped(self, *args, **kwargs)
423 if method_name != "__call__":
424 f = jax.named_call(f, name=method_name)
--> 426 out = f(*args, **kwargs)
428 # Module names are set in the constructor. If `f` is the constructor then
429 # its name will only be set **after** `f` has run. For methods other
430 # than `__init__` we need the name before running in order to wrap their
431 # execution with `named_call`.
432 if module_name is None:
File /usr/lib/python3.8/contextlib.py:75, in ContextDecorator.__call__.<locals>.inner(*args, **kwds)
72 @wraps(func)
73 def inner(*args, **kwds):
74 with self._recreate_cm():
---> 75 return func(*args, **kwds)
File /usr/lib/python3.8/contextlib.py:75, in ContextDecorator.__call__.<locals>.inner(*args, **kwds)
72 @wraps(func)
73 def inner(*args, **kwds):
74 with self._recreate_cm():
---> 75 return func(*args, **kwds)
File ~/.local/lib/python3.8/site-packages/haiku/_src/module.py:272, in run_interceptors(bound_method, method_name, self, *args, **kwargs)
270 """Runs any method interceptors or the original method."""
271 if not interceptor_stack:
--> 272 return bound_method(*args, **kwargs)
274 ctx = MethodContext(module=self,
275 method_name=method_name,
276 orig_method=bound_method)
277 interceptor_stack_copy = interceptor_stack.clone()
File ~/.local/lib/python3.8/site-packages/haiku/_src/moving_averages.py:137, in ExponentialMovingAverage.average(self)
135 @property
136 def average(self):
--> 137 return hk.get_state("average")
File ~/.local/lib/python3.8/site-packages/haiku/_src/base.py:448, in replaceable.<locals>.wrapped(*args, **kwargs)
446 @functools.wraps(f)
447 def wrapped(*args, **kwargs):
--> 448 return wrapped._current(*args, **kwargs)
File ~/.local/lib/python3.8/site-packages/haiku/_src/base.py:1099, in get_state(name, shape, dtype, init)
1097 if value is None:
1098 if init is None:
-> 1099 raise ValueError(f"No value for {name!r} in {bundle_name!r}, perhaps "
1100 "set an init function?")
1101 if shape is None or dtype is None:
1102 raise ValueError(f"Must provide shape and dtype to initialize {name!r} "
1103 f"in {bundle_name!r}.")
Resolved — I needed to pass `is_training=True` in the `model.init` call for my module. With `is_training=False` during initialization, `hk.BatchNorm` tries to read its moving-average state (`mean_ema` / `var_ema`) before that state has ever been created, which raises the "No value for 'average'" error. Initializing with `is_training=True` creates the EMA state; afterwards `apply` can be called with `is_training=False`.
Can you show how you fixed it? I ran into the same problem. T_T