Google Colab Error: optax is throwing an attribute error.
prajjwalgeek opened this issue · 2 comments
prajjwalgeek commented
Optax throws an AttributeError when running the attached Google Colab inference demo:
---------------------------------------------------------------------------
AttributeError Traceback (most recent call last)
[<ipython-input-14-a22d9a83aa66>](https://localhost:8080/#) in <module>()
4 from jax.experimental import maps
5 import numpy as np
----> 6 import optax
7 import transformers
8
6 frames
[/usr/local/lib/python3.7/dist-packages/optax/__init__.py](https://localhost:8080/#) in <module>()
15 """Optax: composable gradient processing and optimization, in JAX."""
16
---> 17 from optax._src.alias import adabelief
18 from optax._src.alias import adafactor
19 from optax._src.alias import adagrad
[/usr/local/lib/python3.7/dist-packages/optax/_src/alias.py](https://localhost:8080/#) in <module>()
19 import jax.numpy as jnp
20
---> 21 from optax._src import base
22 from optax._src import clipping
23 from optax._src import combine
[/usr/local/lib/python3.7/dist-packages/optax/_src/base.py](https://localhost:8080/#) in <module>()
16
17 from typing import Any, Callable, NamedTuple, Optional, Sequence, Tuple, Union
---> 18 import chex
19
20 # pylint:disable=no-value-for-parameter
[/usr/local/lib/python3.7/dist-packages/chex/__init__.py](https://localhost:8080/#) in <module>()
15 """Chex: Testing made fun, in JAX!"""
16
---> 17 from chex._src.asserts import assert_axis_dimension
18 from chex._src.asserts import assert_axis_dimension_comparator
19 from chex._src.asserts import assert_axis_dimension_gt
[/usr/local/lib/python3.7/dist-packages/chex/_src/asserts.py](https://localhost:8080/#) in <module>()
24 from unittest import mock
25
---> 26 from chex._src import asserts_internal as _ai
27 from chex._src import pytypes
28 import jax
[/usr/local/lib/python3.7/dist-packages/chex/_src/asserts_internal.py](https://localhost:8080/#) in <module>()
30
31 from absl import logging
---> 32 from chex._src import pytypes
33 import jax
34 import jax.numpy as jnp
[/usr/local/lib/python3.7/dist-packages/chex/_src/pytypes.py](https://localhost:8080/#) in <module>()
34 Scalar = Union[float, int]
35 Numeric = Union[Array, Scalar]
---> 36 PRNGKey = jax.random.KeyArray
37 PyTreeDef = type(jax.tree_structure(None))
38 Shape = jax.core.Shape
AttributeError: module 'jax.random' has no attribute 'KeyArray'
neverix commented
I think I finally figured it out.
-
!pip install mesh-transformer-jax/ jax==0.2.12 tensorflow==2.5.0 chex==0.0.6 jaxlib==0.3.7
#@title Patch 1
%%file /usr/local/lib/python3.7/dist-packages/chex/_src/pytypes.py
# Lint as: python3
# Copyright 2020 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Pytypes for arrays and scalars."""
from typing import Any, Iterable, Mapping, Tuple, Union
import jax
import jax.numpy as jnp
import numpy as np
# Type aliases exported by chex, patched to import against the jax/jaxlib
# versions pinned in the accompanying `pip install` line.
Array = jnp.ndarray
ArrayBatched = jax.interpreters.batching.BatchTracer
ArrayNumpy = np.ndarray
ArraySharded = jax.interpreters.pxla.ShardedDeviceArray
# Use this type for type annotation. For instance checking, use
# `isinstance(x, jax.DeviceArray)`.
# `jax.interpreters.xla._DeviceArray` appears in jax > 0.2.5
if hasattr(jax.interpreters.xla, '_DeviceArray'):
  ArrayDevice = jax.interpreters.xla._DeviceArray  # pylint:disable=protected-access
else:
  ArrayDevice = jax.interpreters.xla.DeviceArray
Scalar = Union[float, int]
Numeric = Union[Array, Scalar]
# PATCH: newer chex defines `PRNGKey = jax.random.KeyArray`, but that attribute
# does not exist in the targeted jax version (importing it raised
# `AttributeError: module 'jax.random' has no attribute 'KeyArray'`), so a PRNG
# key is annotated as a plain array instead.
PRNGKey = Array
Shape = Tuple[int, ...]
# PATCH: CpuDevice is disabled. NOTE(review): presumably
# `jax.lib.xla_extension.CpuDevice` is unavailable in the installed jaxlib;
# confirm.
# CpuDevice = jax.lib.xla_extension.CpuDevice
GpuDevice = jax.lib.xla_extension.GpuDevice
TpuDevice = jax.lib.xla_extension.TpuDevice
# With CpuDevice disabled above, only GPU and TPU devices are covered here.
Device = Union[GpuDevice, TpuDevice]
# As of 06/2020 pytype doesn't support recursive types (see b/109648354)
# pytype: disable=not-supported-yet
ArrayTree = Union[Array, Iterable['ArrayTree'], Mapping[Any, 'ArrayTree']]
#@title Patch 2
%%file /usr/local/lib/python3.7/dist-packages/chex/__init__.py
# Lint as: python3
# Copyright 2020 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Chex: Testing made fun, in JAX!"""
from chex._src.asserts import assert_axis_dimension
from chex._src.asserts import assert_axis_dimension_gt
from chex._src.asserts import assert_devices_available
from chex._src.asserts import assert_equal
from chex._src.asserts import assert_equal_rank
from chex._src.asserts import assert_equal_shape
from chex._src.asserts import assert_equal_shape_prefix
from chex._src.asserts import assert_equal_shape_suffix
from chex._src.asserts import assert_exactly_one_is_none
from chex._src.asserts import assert_gpu_available
from chex._src.asserts import assert_is_broadcastable
from chex._src.asserts import assert_max_traces
from chex._src.asserts import assert_not_both_none
from chex._src.asserts import assert_numerical_grads
from chex._src.asserts import assert_rank
from chex._src.asserts import assert_scalar
from chex._src.asserts import assert_scalar_in
from chex._src.asserts import assert_scalar_negative
from chex._src.asserts import assert_scalar_non_negative
from chex._src.asserts import assert_scalar_positive
from chex._src.asserts import assert_shape
from chex._src.asserts import assert_tpu_available
from chex._src.asserts import assert_tree_all_close
from chex._src.asserts import assert_tree_all_equal_comparator
from chex._src.asserts import assert_tree_all_equal_shapes
from chex._src.asserts import assert_tree_all_equal_structs
from chex._src.asserts import assert_tree_all_finite
from chex._src.asserts import assert_tree_no_nones
from chex._src.asserts import assert_tree_shape_prefix
from chex._src.asserts import assert_type
from chex._src.asserts import clear_trace_counter
from chex._src.asserts import if_args_not_none
from chex._src.dataclass import dataclass
from chex._src.dataclass import mappable_dataclass
from chex._src.fake import fake_jit
from chex._src.fake import fake_pmap
from chex._src.fake import fake_pmap_and_jit
from chex._src.fake import set_n_cpu_devices
from chex._src.pytypes import Array
from chex._src.pytypes import ArrayBatched
from chex._src.pytypes import ArrayDevice
from chex._src.pytypes import ArrayNumpy
from chex._src.pytypes import ArraySharded
from chex._src.pytypes import ArrayTree
# from chex._src.pytypes import CpuDevice
from chex._src.pytypes import Device
from chex._src.pytypes import GpuDevice
from chex._src.pytypes import Numeric
from chex._src.pytypes import PRNGKey
from chex._src.pytypes import Scalar
from chex._src.pytypes import Shape
from chex._src.pytypes import TpuDevice
from chex._src.variants import all_variants
from chex._src.variants import ChexVariantType
from chex._src.variants import params_product
from chex._src.variants import TestCase
from chex._src.variants import variants
__version__ = "0.0.6"

# Public API for `from chex import *`. Every name listed here must be imported
# above; otherwise a star-import raises AttributeError.
__all__ = (
    "all_variants",
    "Array",
    "ArrayBatched",
    "ArrayDevice",
    "ArrayNumpy",
    "ArraySharded",
    "ArrayTree",
    "assert_axis_dimension",
    "assert_axis_dimension_gt",
    "assert_devices_available",
    "assert_equal",
    "assert_equal_rank",
    "assert_equal_shape",
    "assert_equal_shape_prefix",
    "assert_equal_shape_suffix",
    "assert_exactly_one_is_none",
    "assert_gpu_available",
    "assert_is_broadcastable",
    "assert_max_traces",
    "assert_not_both_none",
    "assert_numerical_grads",
    "assert_rank",
    "assert_scalar",
    "assert_scalar_in",
    "assert_scalar_negative",
    "assert_scalar_non_negative",
    "assert_scalar_positive",
    "assert_shape",
    "assert_tpu_available",
    "assert_tree_all_close",
    "assert_tree_all_equal_comparator",
    "assert_tree_all_equal_shapes",
    "assert_tree_all_equal_structs",
    "assert_tree_all_finite",
    "assert_tree_no_nones",
    "assert_tree_shape_prefix",
    "assert_type",
    "ChexVariantType",
    "clear_trace_counter",
    # "CpuDevice" is omitted: its import is commented out in this patch.
    "dataclass",
    "Device",
    "fake_jit",
    "fake_pmap",
    "fake_pmap_and_jit",
    "GpuDevice",
    "if_args_not_none",
    "mappable_dataclass",
    "Numeric",
    "params_product",
    "PRNGKey",
    "Scalar",
    "set_n_cpu_devices",
    "Shape",
    "TestCase",
    "TpuDevice",
    "variants",
)
With these few patches, it seems to work on Colab's TPU.