google-deepmind/distrax

Error when using Distrax: `Subscripted generics cannot be used with class and instance checks`

Closed this issue · 3 comments

Hey there,

I'm trying to use Distrax together with libraries such as GPJax and BlackJax. However, every call to a Distrax distribution fails with the following error: `Subscripted generics cannot be used with class and instance checks`.

Here is a minimal working example (MWE), run on Colab:

!pip install distrax
import distrax as dx
import jax.random as jrnd

key = jrnd.PRNGKey(42)

normal = dx.Normal(loc=0., scale=1.0).sample(seed=key)
print(normal)

which fails with:

---------------------------------------------------------------------------

TypeError                                 Traceback (most recent call last)

<ipython-input-25-753b0aebd6e0> in <module>
      1 key = jrnd.PRNGKey(42)
      2 
----> 3 normal = dx.Normal(loc=0., scale=1.0).sample(seed=key)
      4 print(normal)

[3 frames hidden]

/usr/lib/python3.8/typing.py in __subclasscheck__(self, cls)
    775             if cls._special:
    776                 return issubclass(cls.__origin__, self.__origin__)
--> 777         raise TypeError("Subscripted generics cannot be used with"
    778                         " class and instance checks")
    779 

TypeError: Subscripted generics cannot be used with class and instance checks

Is there a way to avoid this error?

Thanks!

hbq1 commented

Hi @mhinne, thanks for raising this issue. Which versions of Python and of the libraries do you use? Could you also share a full stack trace for this error?

mhinne commented

Thanks for the prompt reply! Here they are:

The versions:

Last updated: Wed Feb 08 2023

Python implementation: CPython
Python version       : 3.8.10
IPython version      : 7.9.0

numpy     : 1.21.6
matplotlib: 3.2.2
distrax   : 0.1.2
jax       : 0.3.25

Watermark: 2.3.1
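The report above comes from the IPython `watermark` extension. As a rough stdlib-only alternative, the same kind of report can be produced with `importlib.metadata` (available since Python 3.8, the version used here). This is a minimal sketch: the function name `report_versions` is hypothetical, and the package list is an assumption that adds `jaxlib` (jax's compiled backend) and `chex` (a distrax dependency) to the packages listed above.

```python
import platform
from importlib import metadata


def report_versions(packages=("distrax", "jax", "jaxlib", "chex", "numpy")):
    """Hypothetical helper: returns one 'name: version' line per package."""
    lines = [f"Python: {platform.python_version()}"]
    for name in packages:
        try:
            lines.append(f"{name}: {metadata.version(name)}")
        except metadata.PackageNotFoundError:
            # The package is not installed in the current environment.
            lines.append(f"{name}: not installed")
    return "\n".join(lines)


print(report_versions())
```

Unlike `watermark`, this needs no extra install, so it also works in a plain script outside a notebook.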

And the stack trace:

---------------------------------------------------------------------------

TypeError                                 Traceback (most recent call last)

<ipython-input-35-753b0aebd6e0> in <module>
      1 key = jrnd.PRNGKey(42)
      2 
----> 3 normal = dx.Normal(loc=0., scale=1.0).sample(seed=key)
      4 print(normal)


/usr/local/lib/python3.8/dist-packages/distrax/_src/distributions/normal.py in __init__(self, loc, scale)
     47     """
     48     super().__init__()
---> 49     self._loc = conversion.as_float_array(loc)
     50     self._scale = conversion.as_float_array(scale)
     51     self._batch_shape = jax.lax.broadcast_shapes(

/usr/local/lib/python3.8/dist-packages/distrax/_src/utils/conversion.py in as_float_array(x)
    136     An array with floating-point dtype.
    137   """
--> 138   if not isinstance(x, Array):
    139     x = jnp.asarray(x)
    140   if jnp.issubdtype(x.dtype, jnp.floating):

/usr/lib/python3.8/typing.py in __instancecheck__(self, obj)
    767 
    768     def __instancecheck__(self, obj):
--> 769         return self.__subclasscheck__(type(obj))
    770 
    771     def __subclasscheck__(self, cls):

/usr/lib/python3.8/typing.py in __subclasscheck__(self, cls)
    775             if cls._special:
    776                 return issubclass(cls.__origin__, self.__origin__)
--> 777         raise TypeError("Subscripted generics cannot be used with"
    778                         " class and instance checks")
    779 

TypeError: Subscripted generics cannot be used with class and instance checks

hbq1 commented

Thanks for the additional info! I believe it should be fixed now.
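The full stack trace shows where the error comes from: `as_float_array` in `distrax/_src/utils/conversion.py` calls `isinstance(x, Array)`, and `typing` on Python 3.8 raises this `TypeError` whenever `isinstance` is given a subscripted generic. The sketch below reproduces that pattern with a hypothetical `ArrayLike = Union[...]` alias built from plain built-in classes; it assumes, without confirmation from the thread, that distrax's `Array` is a `Union`-style alias. It also shows one version-independent way around the problem: expanding the alias with `typing.get_args` into a tuple of classes, which `isinstance` accepts. This is an illustration only, not necessarily how distrax itself was fixed.

```python
import typing
from typing import Union

# Hypothetical stand-in for an "array-like" alias. This is NOT distrax's or
# chex's real definition, just a Union of plain classes for illustration.
ArrayLike = Union[list, tuple, float, int]


def naive_check(x) -> bool:
    """Mirrors the failing pattern `isinstance(x, Array)` from the traceback."""
    return isinstance(x, ArrayLike)


def safe_check(x) -> bool:
    """Expands the Union into a tuple of plain classes before isinstance."""
    # typing.get_args(ArrayLike) == (list, tuple, float, int)
    return isinstance(x, typing.get_args(ArrayLike))


if __name__ == "__main__":
    try:
        print("naive:", naive_check(1.0))
    except TypeError as err:
        # Raised on Python 3.8 and 3.9; Python 3.10+ accepts simple Unions here.
        print("naive check failed:", err)
    print("safe:", safe_check(1.0), safe_check("abc"))
```

The `get_args` approach only works while every member of the `Union` is a plain class; a nested generic such as `List[int]` inside the alias would still be rejected by `isinstance`.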