google-deepmind/dm-haiku

[Minor bug] error check in hk.transform_with_state

albertfgu opened this issue · 1 comment

On this line: https://github.com/deepmind/dm-haiku/blob/6f2769e8c8dd35b3fc0e66905c877debea7d525f/haiku/_src/transform.py#L441

The check should be `if state is None` instead of `if state`. Because an empty dict is falsy, the current version also triggers when an empty state is passed in (i.e. `state={}`), which produces a confusing and incorrect error message that quotes the wrong method signature:

ValueError: Apply must be called with an RNG as the second argument, the required signature is: `apply(params, rng, *a, **k)`
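To illustrate why a truthiness check misfires here, the sketch below is Haiku-independent. The function names and message strings are hypothetical, modeled on the error above; they are not Haiku's actual code. The only point it shows is that `{}` and `None` both fail `if state`, while only `None` satisfies `state is None`.

```python
# Hypothetical messages modeled on the error above; not Haiku's real constants.
STATELESS_MSG = "the required signature is: `apply(params, rng, *a, **k)`"
STATEFUL_MSG = "the required signature is: `apply(params, state, rng, *a, **k)`"


def pick_message_buggy(state):
    # Truthiness check: an empty dict is falsy, so `state={}` is treated
    # as if no state had been passed, and the stateless signature is quoted.
    return STATEFUL_MSG if state else STATELESS_MSG


def pick_message_fixed(state):
    # Identity check: only a genuinely missing state (None) selects the
    # stateless signature; an empty dict still counts as a passed-in state.
    return STATELESS_MSG if state is None else STATEFUL_MSG


print(pick_message_buggy({}))  # quotes apply(params, rng, ...), the wrong one
print(pick_message_fixed({}))  # quotes apply(params, state, rng, ...)
```

Non-empty state behaves the same under both checks, which is why the bug only surfaces for modules with no state, such as the `lambda: None` in the reproducer.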

Thanks for the report @albertfgu, I'll get a fix in for this. Here is a minimal reproducer:

```python
>>> import haiku as hk
>>> f = hk.transform_with_state(lambda: None)
>>> f.apply({}, {}, 'a')
ValueError: Apply must be called with an RNG as the second argument, the required signature is: `apply(params, rng, *a, **k)`. The object was of type <class 'str'>: a
```