kingoflolz/mesh-transformer-jax

Google Colab Error: optax is throwing an attribute error.

prajjwalgeek opened this issue · 2 comments

Importing optax throws an `AttributeError` when running the attached Google Colab inference demo:

---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
[<ipython-input-14-a22d9a83aa66>](https://localhost:8080/#) in <module>()
      4 from jax.experimental import maps
      5 import numpy as np
----> 6 import optax
      7 import transformers
      8 

6 frames
[/usr/local/lib/python3.7/dist-packages/optax/__init__.py](https://localhost:8080/#) in <module>()
     15 """Optax: composable gradient processing and optimization, in JAX."""
     16 
---> 17 from optax._src.alias import adabelief
     18 from optax._src.alias import adafactor
     19 from optax._src.alias import adagrad

[/usr/local/lib/python3.7/dist-packages/optax/_src/alias.py](https://localhost:8080/#) in <module>()
     19 import jax.numpy as jnp
     20 
---> 21 from optax._src import base
     22 from optax._src import clipping
     23 from optax._src import combine

[/usr/local/lib/python3.7/dist-packages/optax/_src/base.py](https://localhost:8080/#) in <module>()
     16 
     17 from typing import Any, Callable, NamedTuple, Optional, Sequence, Tuple, Union
---> 18 import chex
     19 
     20 # pylint:disable=no-value-for-parameter

[/usr/local/lib/python3.7/dist-packages/chex/__init__.py](https://localhost:8080/#) in <module>()
     15 """Chex: Testing made fun, in JAX!"""
     16 
---> 17 from chex._src.asserts import assert_axis_dimension
     18 from chex._src.asserts import assert_axis_dimension_comparator
     19 from chex._src.asserts import assert_axis_dimension_gt

[/usr/local/lib/python3.7/dist-packages/chex/_src/asserts.py](https://localhost:8080/#) in <module>()
     24 from unittest import mock
     25 
---> 26 from chex._src import asserts_internal as _ai
     27 from chex._src import pytypes
     28 import jax

[/usr/local/lib/python3.7/dist-packages/chex/_src/asserts_internal.py](https://localhost:8080/#) in <module>()
     30 
     31 from absl import logging
---> 32 from chex._src import pytypes
     33 import jax
     34 import jax.numpy as jnp

[/usr/local/lib/python3.7/dist-packages/chex/_src/pytypes.py](https://localhost:8080/#) in <module>()
     34 Scalar = Union[float, int]
     35 Numeric = Union[Array, Scalar]
---> 36 PRNGKey = jax.random.KeyArray
     37 PyTreeDef = type(jax.tree_structure(None))
     38 Shape = jax.core.Shape

AttributeError: module 'jax.random' has no attribute 'KeyArray'

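I think I finally figured it out. The installed chex release references `jax.random.KeyArray`, which the installed jax does not provide. The fix is to pin compatible versions and patch two chex files, as shown below.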

  1. Install the pinned versions, then apply the two patches below:

     !pip install mesh-transformer-jax/ jax==0.2.12 tensorflow==2.5.0 chex==0.0.6 jaxlib==0.3.7

#@title Patch 1
%%file /usr/local/lib/python3.7/dist-packages/chex/_src/pytypes.py
# Lint as: python3
# Copyright 2020 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Pytypes for arrays and scalars."""

from typing import Any, Iterable, Mapping, Tuple, Union
import jax
import jax.numpy as jnp
import numpy as np

# Array-like type aliases.
Array = jnp.ndarray
ArrayNumpy = np.ndarray
ArrayBatched = jax.interpreters.batching.BatchTracer
ArraySharded = jax.interpreters.pxla.ShardedDeviceArray

# `ArrayDevice` is meant for type annotations; for runtime instance checks use
# `isinstance(x, jax.DeviceArray)` instead. jax > 0.2.5 exposes the concrete
# class as `jax.interpreters.xla._DeviceArray`; older releases only provide
# `jax.interpreters.xla.DeviceArray`.
_xla_module = jax.interpreters.xla
ArrayDevice = (
    _xla_module._DeviceArray  # pylint:disable=protected-access
    if hasattr(_xla_module, '_DeviceArray')
    else _xla_module.DeviceArray)
del _xla_module

# Scalar, numeric, PRNG-key and shape aliases. PRNG keys are plain arrays in
# this patched copy; newer chex releases alias `jax.random.KeyArray`, which
# the pinned jax does not provide.
Scalar = Union[float, int]
Numeric = Union[Array, Scalar]
PRNGKey = Array
Shape = Tuple[int, ...]

# Device types. The upstream `CpuDevice` alias is disabled in this patched
# copy (presumably unavailable in the pinned jaxlib -- unconfirmed), so
# `Device` covers GPU and TPU devices only.
_xla_ext = jax.lib.xla_extension
GpuDevice = _xla_ext.GpuDevice
TpuDevice = _xla_ext.TpuDevice
Device = Union[GpuDevice, TpuDevice]
del _xla_ext

# Arbitrarily nested containers of arrays. The alias is recursive, which
# pytype did not support as of 06/2020 (b/109648354); hence the directive.
# pytype: disable=not-supported-yet
ArrayTree = Union[Array, Iterable['ArrayTree'], Mapping[Any, 'ArrayTree']]
#@title Patch 2
%%file /usr/local/lib/python3.7/dist-packages/chex/__init__.py
# Lint as: python3
# Copyright 2020 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Chex: Testing made fun, in JAX!"""

from chex._src.asserts import assert_axis_dimension
from chex._src.asserts import assert_axis_dimension_gt
from chex._src.asserts import assert_devices_available
from chex._src.asserts import assert_equal
from chex._src.asserts import assert_equal_rank
from chex._src.asserts import assert_equal_shape
from chex._src.asserts import assert_equal_shape_prefix
from chex._src.asserts import assert_equal_shape_suffix
from chex._src.asserts import assert_exactly_one_is_none
from chex._src.asserts import assert_gpu_available
from chex._src.asserts import assert_is_broadcastable
from chex._src.asserts import assert_max_traces
from chex._src.asserts import assert_not_both_none
from chex._src.asserts import assert_numerical_grads
from chex._src.asserts import assert_rank
from chex._src.asserts import assert_scalar
from chex._src.asserts import assert_scalar_in
from chex._src.asserts import assert_scalar_negative
from chex._src.asserts import assert_scalar_non_negative
from chex._src.asserts import assert_scalar_positive
from chex._src.asserts import assert_shape
from chex._src.asserts import assert_tpu_available
from chex._src.asserts import assert_tree_all_close
from chex._src.asserts import assert_tree_all_equal_comparator
from chex._src.asserts import assert_tree_all_equal_shapes
from chex._src.asserts import assert_tree_all_equal_structs
from chex._src.asserts import assert_tree_all_finite
from chex._src.asserts import assert_tree_no_nones
from chex._src.asserts import assert_tree_shape_prefix
from chex._src.asserts import assert_type
from chex._src.asserts import clear_trace_counter
from chex._src.asserts import if_args_not_none
from chex._src.dataclass import dataclass
from chex._src.dataclass import mappable_dataclass
from chex._src.fake import fake_jit
from chex._src.fake import fake_pmap
from chex._src.fake import fake_pmap_and_jit
from chex._src.fake import set_n_cpu_devices
from chex._src.pytypes import Array
from chex._src.pytypes import ArrayBatched
from chex._src.pytypes import ArrayDevice
from chex._src.pytypes import ArrayNumpy
from chex._src.pytypes import ArraySharded
from chex._src.pytypes import ArrayTree
# from chex._src.pytypes import CpuDevice
from chex._src.pytypes import Device
from chex._src.pytypes import GpuDevice
from chex._src.pytypes import Numeric
from chex._src.pytypes import PRNGKey
from chex._src.pytypes import Scalar
from chex._src.pytypes import Shape
from chex._src.pytypes import TpuDevice
from chex._src.variants import all_variants
from chex._src.variants import ChexVariantType
from chex._src.variants import params_product
from chex._src.variants import TestCase
from chex._src.variants import variants


__version__ = "0.0.6"

# Public API re-exported by `from chex import *`. Every entry must be bound in
# this module, otherwise the star-import fails with AttributeError. The
# `CpuDevice` import above is disabled in this patch, so `CpuDevice` is
# intentionally NOT listed here.
__all__ = (
    "all_variants",
    "Array",
    "ArrayBatched",
    "ArrayDevice",
    "ArrayNumpy",
    "ArraySharded",
    "ArrayTree",
    "assert_axis_dimension",
    "assert_axis_dimension_gt",
    "assert_devices_available",
    "assert_equal",
    "assert_equal_rank",
    "assert_equal_shape",
    "assert_equal_shape_prefix",
    "assert_equal_shape_suffix",
    "assert_exactly_one_is_none",
    "assert_gpu_available",
    "assert_is_broadcastable",
    "assert_max_traces",
    "assert_not_both_none",
    "assert_numerical_grads",
    "assert_rank",
    "assert_scalar",
    "assert_scalar_in",
    "assert_scalar_negative",
    "assert_scalar_non_negative",
    "assert_scalar_positive",
    "assert_shape",
    "assert_tpu_available",
    "assert_tree_all_close",
    "assert_tree_all_equal_comparator",
    "assert_tree_all_equal_shapes",
    "assert_tree_all_equal_structs",
    "assert_tree_all_finite",
    "assert_tree_no_nones",
    "assert_tree_shape_prefix",
    "assert_type",
    "ChexVariantType",
    "clear_trace_counter",
    "dataclass",
    "Device",
    "fake_jit",
    "fake_pmap",
    "fake_pmap_and_jit",
    "GpuDevice",
    "if_args_not_none",
    "mappable_dataclass",
    "Numeric",
    "params_product",
    "PRNGKey",
    "Scalar",
    "set_n_cpu_devices",
    "Shape",
    "TestCase",
    "TpuDevice",
    "variants",
)

With these patches applied, it seems to work on a Colab TPU runtime.