google-deepmind/optax

"dadapt_adamw, combine, prodigy" occur "TypeError: 'type' object is not subscriptable."

sbb2002 opened this issue · 1 comment

Hello. :)
I'm an Optax user from South Korea.

I was building a model with Flax, but importing Optax raised three errors. After reading through the messages, I managed to solve all of them.

So I'd like to share the fixes with other Optax users, and to ask the Optax dev team to fix these errors upstream.

*Case:
I installed Optax with pip on Ubuntu 20.04, using Python 3.8.7, but Python cannot import Optax.
The detailed error messages are below...

[1st error]

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Input In [1], in <cell line: 1>()
----> 1 import optax

File ~/anaconda3/envs/MyEnv/lib/python3.8/site-packages/optax/__init__.py:17, in <module>
      1 # Copyright 2019 DeepMind Technologies Limited. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
   (...)
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Optax: composable gradient processing and optimization, in JAX."""
---> 17 from optax import contrib
     18 from optax import losses
     19 from optax import monte_carlo

File ~/anaconda3/envs/MyEnv/lib/python3.8/site-packages/optax/contrib/__init__.py:21, in <module>
     19 from optax.contrib.complex_valued import split_real_and_imaginary
     20 from optax.contrib.complex_valued import SplitRealAndImaginaryState
---> 21 from optax.contrib.dadapt_adamw import dadapt_adamw
     22 from optax.contrib.dadapt_adamw import DAdaptAdamWState
     23 from optax.contrib.mechanic import MechanicState

File ~/anaconda3/envs/MyEnv/lib/python3.8/site-packages/optax/contrib/dadapt_adamw.py:45, in <module>
     39   numerator_weighted: chex.Array  # shape=(), dtype=jnp.float32.
     40   count: chex.Array  # shape=(), dtype=jnp.int32.
     43 def dadapt_adamw(
     44     learning_rate: base.ScalarOrSchedule = 1.0,
---> 45     betas: tuple[float, float] = (0.9, 0.999),
     46     eps: float = 1e-8,
     47     estim_lr0: float = 1e-6,
     48     weight_decay: float = 0.,
     49 ) -> base.GradientTransformation:
     50   """Learning rate free AdamW by D-Adaptation.
     51
     52   Adapts the baseline learning rate of AdamW automatically by estimating the
   (...)
     69     A `GradientTransformation` object.
     70   """
     72   def init_fn(params: base.Params) -> DAdaptAdamWState:

TypeError: 'type' object is not subscriptable

[2nd error]

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Input In [2], in <cell line: 1>()
----> 1 import optax

File ~/anaconda3/envs/MyEnv/lib/python3.8/site-packages/optax/__init__.py:17, in <module>
      1 # Copyright 2019 DeepMind Technologies Limited. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
   (...)
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Optax: composable gradient processing and optimization, in JAX."""
---> 17 from optax import contrib
     18 from optax import losses
     19 from optax import monte_carlo

File ~/anaconda3/envs/MyEnv/lib/python3.8/site-packages/optax/contrib/__init__.py:25, in <module>
     23 from optax.contrib.mechanic import MechanicState
     24 from optax.contrib.mechanic import mechanize
---> 25 from optax.contrib.privacy import differentially_private_aggregate
     26 from optax.contrib.privacy import DifferentiallyPrivateAggregateState
     27 from optax.contrib.privacy import dpsgd

File ~/anaconda3/envs/MyEnv/lib/python3.8/site-packages/optax/contrib/privacy.py:23, in <module>
     21 from optax._src import base
     22 from optax._src import clipping
---> 23 from optax._src import combine
     24 from optax._src import transform
     27 # pylint:disable=no-value-for-parameter

File ~/anaconda3/envs/MyEnv/lib/python3.8/site-packages/optax/_src/combine.py:72, in <module>
     63   # We opt to always return the GradientTransformationExtraArgs type here,
     64   # instead of selecting the return type based on the arguments, since it works
     65   # much better with the currently available type checkers. It also means that
     66   # users will not get unexpected signature errors if they remove all of the
     67   # transformations in a chain accepting extra args.
     68   return base.GradientTransformationExtraArgs(init_fn, update_fn)
     71 def named_chain(
---> 72     *transforms: tuple[str, base.GradientTransformation]
     73 ) -> base.GradientTransformationExtraArgs:
     74   """Chains optax gradient transformations.
     75
     76   The `transforms` are `(name, transformation)` pairs, constituted of a string
   (...)
    102     A single (init_fn, update_fn) tuple. 
    103   """
    105   names = [name for name, _ in transforms]

TypeError: 'type' object is not subscriptable

[3rd error]

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Input In [3], in <cell line: 1>()
----> 1 import optax

File ~/anaconda3/envs/MyEnv/lib/python3.8/site-packages/optax/__init__.py:17, in <module>
      1 # Copyright 2019 DeepMind Technologies Limited. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
   (...)
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Optax: composable gradient processing and optimization, in JAX."""
---> 17 from optax import contrib
     18 from optax import losses
     19 from optax import monte_carlo

File ~/anaconda3/envs/MyEnv/lib/python3.8/site-packages/optax/contrib/__init__.py:28, in <module>
     26 from optax.contrib.privacy import DifferentiallyPrivateAggregateState
     27 from optax.contrib.privacy import dpsgd
---> 28 from optax.contrib.prodigy import prodigy
     29 from optax.contrib.prodigy import ProdigyState
     30 from optax.contrib.reduce_on_plateau import reduce_on_plateau

File ~/anaconda3/envs/MyEnv/lib/python3.8/site-packages/optax/contrib/prodigy.py:48, in <module>
     42   numerator_weighted: chex.Array  # shape=(), dtype=jnp.float32.
     43   count: chex.Array  # shape=(), dtype=int32.
     46 def prodigy(
     47     learning_rate: base.ScalarOrSchedule = 0.1,
---> 48     betas: tuple[float, float] = (0.9, 0.999),
     49     beta3: Optional[float] = None,
     50     eps: float = 1e-8,
     51     estim_lr0: float = 1e-6,
     52     estim_lr_coef: float = 1.0,
     53     weight_decay: float = 0.0,
     54 ) -> base.GradientTransformation:
     55   """Learning rate free AdamW with Prodigy.
     56
     57   Implementation of the Prodigy method from "Prodigy: An Expeditiously
   (...)
     78     A `GradientTransformation` object.
     79   """
     80   beta1, beta2 = betas

TypeError: 'type' object is not subscriptable
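All three tracebacks point at the same construct: subscripting the builtin tuple class in a type annotation, e.g. tuple[float, float]. Builtin generics were only added in Python 3.9 (PEP 585), so on Python 3.8 the subscription itself raises the TypeError. Here is a minimal sketch of the version difference (the Betas alias is just an illustration, not optax code):

import sys

# PEP 585 (Python 3.9) made builtin classes like `tuple` subscriptable.
# On Python 3.8 the first branch would raise
# "TypeError: 'type' object is not subscriptable".
if sys.version_info >= (3, 9):
    Betas = tuple[float, float]    # OK on 3.9+
else:
    from typing import Tuple
    Betas = Tuple[float, float]    # the 3.8-compatible spelling

print(Betas)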

*Solution:
[1] Open the three files below inside your installed Optax package. The paths here are from my environment:
File1: ~/anaconda3/envs/MyEnv/lib/python3.8/site-packages/optax/contrib/dadapt_adamw.py
File2: ~/anaconda3/envs/MyEnv/lib/python3.8/site-packages/optax/_src/combine.py
File3: ~/anaconda3/envs/MyEnv/lib/python3.8/site-packages/optax/contrib/prodigy.py

[2] In each of the three files, search for the exact text tuple[.

[3] Replace every tuple[ with Tuple[, since these annotations should use Tuple from the typing module (see the sketch after these steps).

[4] In File2 (combine.py), find the line from typing import Callable, NamedTuple, Union, Mapping, Hashable and add Tuple to the end of it.

[5] Go back to your .py or .ipynb file and restart the kernel. Now you can import optax.
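As a concrete illustration of steps [2]-[4], here is roughly what the edited signature in File1 (dadapt_adamw.py) ends up looking like. This is a trimmed sketch based on the traceback above; the optax-specific annotations (base.ScalarOrSchedule, base.GradientTransformation) are omitted so the snippet stands alone:

from typing import Tuple  # `Tuple` must be imported, as in step [4]


def dadapt_adamw(
    learning_rate=1.0,
    betas: Tuple[float, float] = (0.9, 0.999),  # was: tuple[float, float]
    eps: float = 1e-8,
    estim_lr0: float = 1e-6,
    weight_decay: float = 0.,
):
  """Placeholder body; only the annotation change matters here."""
  ...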

I hope your code works well now. ^^

Thanks!

Hi @sbb2002, and thanks for the report. Optax requires Python >= 3.9, so I suspect the error comes from there.
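One way to confirm the mismatch before importing is a small version guard; a hedged sketch (the exact message and suggested command are just examples):

import sys

# Optax's annotations rely on PEP 585 builtin generics (Python 3.9+).
if sys.version_info < (3, 9):
    raise RuntimeError(
        f"Python {sys.version.split()[0]} found, but Optax needs >= 3.9; "
        "upgrade the environment, e.g. `conda install python=3.9`."
    )

import optax  # safe to import once the interpreter is new enough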