Error when using Distrax: `Subscripted generics cannot be used with class and instance checks`
Closed this issue · 3 comments
mhinne commented
Hey there,
I'm trying to use Distrax together with libraries such as GPJax and BlackJax. However, every call to a Distrax distribution leads to the following error: `Subscripted generics cannot be used with class and instance checks`.
Here is an MWE, run on Colab:
```python
!pip install distrax  # Colab/IPython shell command
import distrax as dx
import jax.random as jrnd

key = jrnd.PRNGKey(42)
normal = dx.Normal(loc=0., scale=1.0).sample(seed=key)
print(normal)
```
which returns:
```
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-25-753b0aebd6e0> in <module>
      1 key = jrnd.PRNGKey(42)
      2
----> 3 normal = dx.Normal(loc=0., scale=1.0).sample(seed=key)
      4 print(normal)

3 frames

/usr/lib/python3.8/typing.py in __subclasscheck__(self, cls)
    775 if cls._special:
    776 return issubclass(cls.__origin__, self.__origin__)
--> 777 raise TypeError("Subscripted generics cannot be used with"
    778 " class and instance checks")
    779

TypeError: Subscripted generics cannot be used with class and instance checks
```
Is there a way to avoid this error?
Thanks!
hbq1 commented
Hi @mhinne, thanks for raising this issue. Which versions of Python and the libraries are you using? Could you also share the full stack trace for this error?
mhinne commented
Thanks for the prompt reply! Here they are:
The versions:
```
Last updated: Wed Feb 08 2023

Python implementation: CPython
Python version       : 3.8.10
IPython version      : 7.9.0

numpy     : 1.21.6
matplotlib: 3.2.2
distrax   : 0.1.2
jax       : 0.3.25

Watermark: 2.3.1
```
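The version list above looks like output from the watermark IPython extension. Where that extension is not installed, a similar report can be collected with the standard library alone. This is a minimal sketch assuming Python 3.8+ (for `importlib.metadata`); `report_versions` is a hypothetical helper name, and the package list is only an example that adds jaxlib (jax's compiled backend) and chex (which distrax builds on), since neither appears in the list above.

```python
import platform
from importlib import metadata  # in the standard library since Python 3.8


def report_versions(packages):
    """Map "python" and each package name to its installed version string."""
    versions = {"python": platform.python_version()}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            # Record missing packages instead of failing the whole report.
            versions[name] = "not installed"
    return versions


for name, version in report_versions(
    ["numpy", "jax", "jaxlib", "chex", "distrax"]
).items():
    print(f"{name:10}: {version}")
```

Pasting this output into an issue gives maintainers the interpreter version plus the versions of the transitive dependencies that are often involved in type-related breakages.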
And the stack trace:
```
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-35-753b0aebd6e0> in <module>
      1 key = jrnd.PRNGKey(42)
      2
----> 3 normal = dx.Normal(loc=0., scale=1.0).sample(seed=key)
      4 print(normal)

3 frames

/usr/local/lib/python3.8/dist-packages/distrax/_src/distributions/normal.py in __init__(self, loc, scale)
     47 """
     48 super().__init__()
---> 49 self._loc = conversion.as_float_array(loc)
     50 self._scale = conversion.as_float_array(scale)
     51 self._batch_shape = jax.lax.broadcast_shapes(

/usr/local/lib/python3.8/dist-packages/distrax/_src/utils/conversion.py in as_float_array(x)
    136 An array with floating-point dtype.
    137 """
--> 138 if not isinstance(x, Array):
    139 x = jnp.asarray(x)
    140 if jnp.issubdtype(x.dtype, jnp.floating):

/usr/lib/python3.8/typing.py in __instancecheck__(self, obj)
    767
    768 def __instancecheck__(self, obj):
--> 769 return self.__subclasscheck__(type(obj))
    770
    771 def __subclasscheck__(self, cls):

/usr/lib/python3.8/typing.py in __subclasscheck__(self, cls)
    775 if cls._special:
    776 return issubclass(cls.__origin__, self.__origin__)
--> 777 raise TypeError("Subscripted generics cannot be used with"
    778 " class and instance checks")
    779

TypeError: Subscripted generics cannot be used with class and instance checks
```
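The last frames show where the error originates: `conversion.as_float_array` calls `isinstance(x, Array)`, and `Array` resolves to a subscripted generic from `typing` (for example a `typing.Union[...]` alias), which Python 3.8's `typing` refuses to use in instance checks. The sketch below reproduces that failure mode with the standard library only. `Number` is a hypothetical stand-in for the `Array` alias, whose actual definition is not shown in this thread, and this is an illustration of the mechanism, not necessarily the change that fixed distrax. `isinstance` accepts a `typing.Union` only from Python 3.10 onward, while unpacking the alias with `typing.get_args` yields a plain tuple of classes that works on every version.

```python
import sys
import typing

# Hypothetical stand-in for a Union-style alias like the `Array` type that
# distrax's conversion.py checks against; the real alias is not shown here.
Number = typing.Union[int, float, complex]


def naive_is_number(x) -> bool:
    # Raises TypeError("Subscripted generics cannot be used with class and
    # instance checks") on Python 3.8/3.9; Python 3.10+ accepts a Union here.
    return isinstance(x, Number)


def portable_is_number(x) -> bool:
    # typing.get_args (Python 3.8+) returns the Union members as a tuple of
    # plain classes, which isinstance accepts on every Python version.
    return isinstance(x, typing.get_args(Number))


print(portable_is_number(0.0))  # True
print(portable_is_number("0"))  # False
try:
    print(naive_is_number(0.0))  # True on Python 3.10+
except TypeError as err:
    print(f"Python {sys.version_info.major}.{sys.version_info.minor}: {err}")
```

The `get_args` approach only helps when every member of the Union is a concrete class; members that are themselves subscripted generics (such as `typing.List[int]`) can still make `isinstance` raise.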
hbq1 commented
Thanks for the additional info! I believe it should be fixed now.