"dadapt_adamw, combine, prodigy" raise "TypeError: 'type' object is not subscriptable"
sbb2002 opened this issue · 1 comment
Hello. :)
I'm an optax user in South Korea.
I was building a model with flax, but when I imported optax, it raised 3 errors!
I read the error messages and managed to fix all of them.
So I'd like to share these tips with other optax users and, honestly, ask the optax dev team to fix these errors.
*Case:
I installed Optax using pip on Ubuntu 20.04 with Python 3.8.7.
But Python can't import Optax.
The detailed error messages are below...
[1st error]
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
Input In [1], in <cell line: 1>()
----> 1 import optax
File ~/anaconda3/envs/MyEnv/lib/python3.8/site-packages/optax/__init__.py:17, in <module>
1 # Copyright 2019 DeepMind Technologies Limited. All Rights Reserved.
2 #
3 # Licensed under the Apache License, Version 2.0 (the "License");
(...)
13 # limitations under the License.
14 # ==============================================================================
15 """Optax: composable gradient processing and optimization, in JAX."""
---> 17 from optax import contrib
18 from optax import losses
19 from optax import monte_carlo
File ~/anaconda3/envs/MyEnv/lib/python3.8/site-packages/optax/contrib/__init__.py:21, in <module>
19 from optax.contrib.complex_valued import split_real_and_imaginary
20 from optax.contrib.complex_valued import SplitRealAndImaginaryState
---> 21 from optax.contrib.dadapt_adamw import dadapt_adamw
22 from optax.contrib.dadapt_adamw import DAdaptAdamWState
23 from optax.contrib.mechanic import MechanicState
File ~/anaconda3/envs/MyEnv/lib/python3.8/site-packages/optax/contrib/dadapt_adamw.py:45, in <module>
39 numerator_weighted: chex.Array # shape=(), dtype=jnp.float32.
40 count: chex.Array # shape=(), dtype=jnp.int32.
43 def dadapt_adamw(
44 learning_rate: base.ScalarOrSchedule = 1.0,
---> 45 betas: tuple[float, float] = (0.9, 0.999),
46 eps: float = 1e-8,
47 estim_lr0: float = 1e-6,
48 weight_decay: float = 0.,
49 ) -> base.GradientTransformation:
50 """Learning rate free AdamW by D-Adaptation.
51
52 Adapts the baseline learning rate of AdamW automatically by estimating the
(...)
69 A `GradientTransformation` object.
70 """
72 def init_fn(params: base.Params) -> DAdaptAdamWState:
TypeError: 'type' object is not subscriptable
[2nd error]
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
Input In [2], in <cell line: 1>()
----> 1 import optax
File ~/anaconda3/envs/MyEnv/lib/python3.8/site-packages/optax/__init__.py:17, in <module>
1 # Copyright 2019 DeepMind Technologies Limited. All Rights Reserved.
2 #
3 # Licensed under the Apache License, Version 2.0 (the "License");
(...)
13 # limitations under the License.
14 # ==============================================================================
15 """Optax: composable gradient processing and optimization, in JAX."""
---> 17 from optax import contrib
18 from optax import losses
19 from optax import monte_carlo
File ~/anaconda3/envs/MyEnv/lib/python3.8/site-packages/optax/contrib/__init__.py:25, in <module>
23 from optax.contrib.mechanic import MechanicState
24 from optax.contrib.mechanic import mechanize
---> 25 from optax.contrib.privacy import differentially_private_aggregate
26 from optax.contrib.privacy import DifferentiallyPrivateAggregateState
27 from optax.contrib.privacy import dpsgd
File ~/anaconda3/envs/MyEnv/lib/python3.8/site-packages/optax/contrib/privacy.py:23, in <module>
21 from optax._src import base
22 from optax._src import clipping
---> 23 from optax._src import combine
24 from optax._src import transform
27 # pylint:disable=no-value-for-parameter
File ~/anaconda3/envs/MyEnv/lib/python3.8/site-packages/optax/_src/combine.py:72, in <module>
63 # We opt to always return the GradientTransformationExtraArgs type here,
64 # instead of selecting the return type based on the arguments, since it works
65 # much better with the currently available type checkers. It also means that
66 # users will not get unexpected signature errors if they remove all of the
67 # transformations in a chain accepting extra args.
68 return base.GradientTransformationExtraArgs(init_fn, update_fn)
71 def named_chain(
---> 72 *transforms: tuple[str, base.GradientTransformation]
73 ) -> base.GradientTransformationExtraArgs:
74 """Chains optax gradient transformations.
75
76 The `transforms` are `(name, transformation)` pairs, constituted of a string
(...)
102 A single (init_fn, update_fn) tuple.
103 """
105 names = [name for name, _ in transforms]
TypeError: 'type' object is not subscriptable
[3rd error]
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
Input In [3], in <cell line: 1>()
----> 1 import optax
File ~/anaconda3/envs/MyEnv/lib/python3.8/site-packages/optax/__init__.py:17, in <module>
1 # Copyright 2019 DeepMind Technologies Limited. All Rights Reserved.
2 #
3 # Licensed under the Apache License, Version 2.0 (the "License");
(...)
13 # limitations under the License.
14 # ==============================================================================
15 """Optax: composable gradient processing and optimization, in JAX."""
---> 17 from optax import contrib
18 from optax import losses
19 from optax import monte_carlo
File ~/anaconda3/envs/MyEnv/lib/python3.8/site-packages/optax/contrib/__init__.py:28, in <module>
26 from optax.contrib.privacy import DifferentiallyPrivateAggregateState
27 from optax.contrib.privacy import dpsgd
---> 28 from optax.contrib.prodigy import prodigy
29 from optax.contrib.prodigy import ProdigyState
30 from optax.contrib.reduce_on_plateau import reduce_on_plateau
File ~/anaconda3/envs/MyEnv/lib/python3.8/site-packages/optax/contrib/prodigy.py:48, in <module>
42 numerator_weighted: chex.Array # shape=(), dtype=jnp.float32.
43 count: chex.Array # shape=(), dtype=int32.
46 def prodigy(
47 learning_rate: base.ScalarOrSchedule = 0.1,
---> 48 betas: tuple[float, float] = (0.9, 0.999),
49 beta3: Optional[float] = None,
50 eps: float = 1e-8,
51 estim_lr0: float = 1e-6,
52 estim_lr_coef: float = 1.0,
53 weight_decay: float = 0.0,
54 ) -> base.GradientTransformation:
55 """Learning rate free AdamW with Prodigy.
56
57 Implementation of the Prodigy method from "Prodigy: An Expeditiously
(...)
78 A `GradientTransformation` object.
79 """
80 beta1, beta2 = betas
TypeError: 'type' object is not subscriptable
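All three tracebacks have the same cause. On Python 3.8 the built-in tuple type cannot be subscripted; writing tuple[float, float] only became valid in Python 3.9 (PEP 585). Parameter annotations are evaluated when the def statement runs, i.e. while optax is being imported, so the import itself fails. Here is a small check you can run on your own interpreter. It is my own sketch, not optax code, and split_betas is a hypothetical example function:

```python
import sys
from typing import Tuple


def builtin_generics_supported() -> bool:
    """Return True if `tuple[float, float]` can be evaluated on this interpreter."""
    try:
        tuple[float, float]  # the same expression optax uses in its annotations
    except TypeError:
        # Python 3.8 and older: "'type' object is not subscriptable"
        return False
    return True


# Hypothetical example: typing.Tuple is subscriptable on Python 3.8 too,
# so an annotation written this way imports fine on every version.
def split_betas(betas: Tuple[float, float] = (0.9, 0.999)) -> Tuple[float, float]:
    beta1, beta2 = betas
    return beta1, beta2


print(sys.version_info[:2], builtin_generics_supported())
```

On Python 3.8.7 the last line should print False, which matches the errors above.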
*Solution:
[1] Open the 3 files below. They are inside your installed optax package; the paths below are from my environment.
File1: ~/anaconda3/envs/MyEnv/lib/python3.8/site-packages/optax/contrib/dadapt_adamw.py
File2: ~/anaconda3/envs/MyEnv/lib/python3.8/site-packages/optax/_src/combine.py
File3: ~/anaconda3/envs/MyEnv/lib/python3.8/site-packages/optax/contrib/prodigy.py
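[2] Find every occurrence of "tuple[" in these 3 files.
[3] Replace "tuple[" with "Tuple[". On Python 3.8 the built-in tuple cannot be subscripted (that only works from Python 3.9), so these annotations have to use Tuple from typing instead!
[4] In File2 (combine.py), find the line from typing import Callable, NamedTuple, Union, Mapping, Hashable and add Tuple to the end of it, so it reads from typing import Callable, NamedTuple, Union, Mapping, Hashable, Tuple.
[5] Go back to your .py / .ipynb file and restart the kernel. Then you can import optax.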
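Steps [2] to [4] can also be scripted. This is a sketch of my own, not part of optax, and all function names in it are hypothetical. It locates the installed package with importlib.util.find_spec (which does not run optax/__init__.py, so it works even while the import is broken), rewrites every standalone tuple[ to Tuple[, and makes sure Tuple is imported from typing, keeping a .bak copy of each changed file. It assumes the typing import is a single line without parentheses or comments, as in combine.py; otherwise it adds a separate from typing import Tuple line. It uses only Python 3.8 compatible syntax.

```python
import importlib.util
import re
import shutil
from pathlib import Path
from typing import Iterable, List

# The three files from the tracebacks, relative to the optax package directory.
FILES = ("contrib/dadapt_adamw.py", "_src/combine.py", "contrib/prodigy.py")

# Standalone builtin `tuple[`; skips `my_tuple[` and the already-correct `Tuple[`.
_BUILTIN_TUPLE = re.compile(r"\btuple\[")
# A single-line `from typing import a, b, c` with no parentheses or comment.
_TYPING_IMPORT = re.compile(r"^from typing import ([^(#\n]+)$", re.MULTILINE)
# First top-level import that is not a `from __future__` import.
_FIRST_IMPORT = re.compile(r"^(?!from __future__)(?:import |from )", re.MULTILINE)


def patch_source(src: str) -> str:
    """Rewrite `tuple[...]` to `Tuple[...]` and make sure `Tuple` is imported."""
    patched, count = _BUILTIN_TUPLE.subn("Tuple[", src)
    if count == 0:
        return src  # nothing to fix in this file
    match = _TYPING_IMPORT.search(patched)
    if match is None:
        # No simple typing import to extend: add one before the first import.
        first = _FIRST_IMPORT.search(patched)
        pos = first.start() if first else 0
        return patched[:pos] + "from typing import Tuple\n" + patched[pos:]
    names = [name.strip() for name in match.group(1).split(",")]
    if "Tuple" not in names:
        names.append("Tuple")  # step [4]: add Tuple at the end of the line
        line = "from typing import " + ", ".join(names)
        patched = patched[: match.start()] + line + patched[match.end():]
    return patched


def patch_file(path: Path) -> bool:
    """Patch one file in place, keeping a `.bak` copy. Returns True if changed."""
    src = path.read_text(encoding="utf-8")
    patched = patch_source(src)
    if patched == src:
        return False
    shutil.copyfile(path, path.with_name(path.name + ".bak"))
    path.write_text(patched, encoding="utf-8")
    return True


def optax_dir() -> Path:
    """Locate the installed optax package without executing its __init__.py."""
    spec = importlib.util.find_spec("optax")
    if spec is None or not spec.submodule_search_locations:
        raise ModuleNotFoundError("optax is not installed in this environment")
    return Path(list(spec.submodule_search_locations)[0])


def patch_installed_optax(files: Iterable[str] = FILES) -> List[str]:
    """Apply steps [2]-[4] to the installed optax; return the files changed."""
    root = optax_dir()
    return [rel for rel in files if patch_file(root / rel)]
```

Call patch_installed_optax() once in the same environment (for me, MyEnv), then restart the kernel as in step [5]. Alternatively, running on Python 3.9 or newer, where tuple[...] is valid, avoids these three errors without patching.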
I hope your code works very well. ^^
Thanks!