
dadapt_adamw, combine and prodigy raise "TypeError: 'type' object is not subscriptable"

I installed Optax with pip on Ubuntu 20.04 and Python 3.8.7. Importing it fails with the following error:
TypeError                                 Traceback (most recent call last)
Input In [1], in <cell line: 1>()
----> 1 import optax

File ~/anaconda3/envs/MyEnv/lib/python3.8/site-packages/optax/, in <module>
      1 # Copyright 2019 DeepMind Technologies Limited. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Optax: composable gradient processing and optimization, in JAX."""
---> 17 from optax import contrib
     18 from optax import losses
     19 from optax import monte_carlo

File ~/anaconda3/envs/MyEnv/lib/python3.8/site-packages/optax/contrib/, in <module>
     19 from optax.contrib.complex_valued import split_real_and_imaginary
     20 from optax.contrib.complex_valued import SplitRealAndImaginaryState
---> 21 from optax.contrib.dadapt_adamw import dadapt_adamw
     22 from optax.contrib.dadapt_adamw import DAdaptAdamWState
     23 from optax.contrib.mechanic import MechanicState

File ~/anaconda3/envs/MyEnv/lib/python3.8/site-packages/optax/contrib/, in <module>
     39   numerator_weighted: chex.Array  # shape=(), dtype=jnp.float32.
     40   count: chex.Array  # shape=(), dtype=jnp.int32.
     43 def dadapt_adamw(
     44     learning_rate: base.ScalarOrSchedule = 1.0,
---> 45     betas: tuple[float, float] = (0.9, 0.999),
     46     eps: float = 1e-8,
     47     estim_lr0: float = 1e-6,
     48     weight_decay: float = 0.,
     49 ) -> base.GradientTransformation:
     50   """Learning rate free AdamW by D-Adaptation.
     52   Adapts the baseline learning rate of AdamW automatically by estimating the
     69     A `GradientTransformation` object.
     70   """
     72   def init_fn(params: base.Params) -> DAdaptAdamWState:

TypeError: 'type' object is not subscriptable
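The failing line is the annotation betas: tuple[float, float]. Subscripting the built-in tuple type only became legal at runtime in Python 3.9 (PEP 585). On 3.8 the annotation is evaluated when the def statement runs, so the import fails before any optimizer is used. The sketch below shows the difference. The function names are hypothetical and only illustrate the behaviour; they are not part of optax.

```python
import sys
from typing import Tuple


def builtin_tuple_subscriptable() -> bool:
    """Return True if tuple[...] can be evaluated at runtime (Python 3.9+, PEP 585)."""
    try:
        tuple[float, float]  # on 3.8 this raises "'type' object is not subscriptable"
    except TypeError:
        return False
    return True


# Hypothetical stand-in for an optax signature: typing.Tuple is subscriptable
# on Python 3.8 as well, so this definition does not fail there.
def default_betas(betas: Tuple[float, float] = (0.9, 0.999)) -> Tuple[float, float]:
    return betas
```

On Python 3.8, builtin_tuple_subscriptable() returns False. On 3.9 and later it returns True.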

[1] Open the three files below in your installed optax package. These paths are from my environment, so yours may differ.
File1: ~/anaconda3/envs/MyEnv/lib/python3.8/site-packages/optax/contrib/
File2: ~/anaconda3/envs/MyEnv/lib/python3.8/site-packages/optax/_src/
File3: ~/anaconda3/envs/MyEnv/lib/python3.8/site-packages/optax/contrib/

[2] Find every occurrence of tuple[ in the three files.

[3] Replace tuple[ with Tuple[. On Python 3.8, these annotations must use the Tuple generic from the typing module.

[4] In File2, find the line from typing import Callable, NamedTuple, Union, Mapping, Hashable and add Tuple to the end of it.

[5] Go back to your .py or .ipynb file and restart the kernel. You should now be able to import optax.
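Steps [2] to [4] can be sketched as a string transformation, which is easier to check before you edit anything in site-packages. This is a hypothetical helper, not part of optax. It assumes each file uses a single-line from typing import ... statement, and it refuses to guess when no such line exists.

```python
import re

# Matches "tuple[" as a whole word, so names like "my_tuple[" are left alone.
_BUILTIN_TUPLE = re.compile(r"\btuple\[")
# Single-line "from typing import A, B, C" (parenthesized imports are not handled).
_TYPING_IMPORT = re.compile(r"^from typing import (?P<names>[^(\n]+)$", re.MULTILINE)


def patch_tuple_annotations(source: str) -> str:
    """Hypothetical helper: rewrite tuple[...] to Tuple[...] and import Tuple."""
    patched = _BUILTIN_TUPLE.sub("Tuple[", source)  # steps [2] and [3]
    if patched == source:
        return source  # no built-in tuple[...] annotations, nothing to do
    match = _TYPING_IMPORT.search(patched)
    if match is None:
        # Assumption: prepending an import could break docstrings or
        # __future__ imports, so ask the user to add it by hand instead.
        raise ValueError("no single-line 'from typing import' found; add Tuple manually")
    names = [name.strip() for name in match.group("names").split(",")]
    if "Tuple" not in names:  # step [4]
        names.append("Tuple")
        patched = (
            patched[: match.start("names")]
            + ", ".join(names)
            + patched[match.end("names"):]
        )
    return patched
```

Compare the returned text against the original before you write it back. Downgrading to an older optax or upgrading Python avoids patching installed files altogether.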

I hope this works for you. ^^


Hi @sbb2002, and thanks for the report. Optax requires Python >= 3.9, so I suspect the error comes from there.
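Given that requirement, a small guard run before import optax reports the real cause instead of a TypeError from deep inside the package. This is a minimal sketch. python_supported and MIN_PYTHON are hypothetical names, and the (3, 9) minimum is taken from the reply above.

```python
import sys
from typing import Sequence, Tuple

# Minimum Python version stated in the maintainer's reply.
MIN_PYTHON: Tuple[int, int] = (3, 9)


def python_supported(
    version: Sequence[int] = tuple(sys.version_info),
    minimum: Tuple[int, int] = MIN_PYTHON,
) -> bool:
    """Compare only (major, minor) of the interpreter against the minimum."""
    return tuple(version[:2]) >= minimum
```

For example, python_supported((3, 8, 7)) is False for the interpreter in this report. A script can check python_supported() and exit with a clear message before it imports optax.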