flaport/sax

error with `flax` < 0.8

Closed · 1 comment

Just a heads up: with jax==0.4.26 and flax==0.7.*, I was getting an error on `import sax` after recently installing sax with `pip install --upgrade sax`:

>>> import sax
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/opt/homebrew/lib/python3.11/site-packages/sax/__init__.py", line 15, in <module>
    from flax.core.frozen_dict import FrozenDict as FrozenDict
  File "/opt/homebrew/lib/python3.11/site-packages/flax/__init__.py", line 19, in <module>
    from .configurations import (
  File "/opt/homebrew/lib/python3.11/site-packages/flax/configurations.py", line 93, in <module>
    flax_filter_frames = define_bool_state(
                         ^^^^^^^^^^^^^^^^^^
  File "/opt/homebrew/lib/python3.11/site-packages/flax/configurations.py", line 42, in define_bool_state
    return jax_config.define_bool_state('flax_' + name, default, help)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AttributeError: 'Config' object has no attribute 'define_bool_state'

Running `pip install --upgrade flax`, which installed 0.8.*, fixed it.
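For anyone who wants to check their environment before importing sax, here is a minimal sketch. It assumes the 0.8 threshold reported above is the right cutoff, which I have only confirmed against jax==0.4.26. The helper names (`parse_major_minor`, `flax_is_new_enough`, `check_installed_flax`) are made up for illustration and are not part of sax or flax. It reads the installed version from package metadata, so it never imports flax itself and cannot hit the traceback above.

```python
from importlib import metadata

# Assumption: 0.8 is the first flax release line that works here,
# based only on the report in this issue (jax==0.4.26).
MIN_FLAX = (0, 8)


def parse_major_minor(version: str) -> tuple[int, int]:
    """Return (major, minor) from a plain version string such as '0.7.5'."""
    major, minor = version.split(".")[:2]
    return int(major), int(minor)


def flax_is_new_enough(version: str, minimum: tuple[int, int] = MIN_FLAX) -> bool:
    """Compare as integer tuples so that '0.10.0' sorts after '0.8.0'."""
    return parse_major_minor(version) >= minimum


def check_installed_flax() -> str:
    """Describe the installed flax version without importing flax."""
    try:
        installed = metadata.version("flax")
    except metadata.PackageNotFoundError:
        return "flax is not installed"
    if flax_is_new_enough(installed):
        return f"flax {installed} is at least 0.8"
    return f"flax {installed} is older than 0.8; try: pip install --upgrade flax"


if __name__ == "__main__":
    print(check_installed_flax())
```

The comparison uses integer tuples rather than strings because a string comparison would rank "0.10" below "0.8". The parser only handles plain numeric versions like "0.7.5", so anything more exotic would need a proper version parser.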

Not sure if this is something you need to know for setting requirements, but I thought I'd let you know and create a paper trail in case others see this. Feel free to close.

I'm having the same issue.

@flaport