`jaxtyped` Annotation fails
dxm447 opened this issue · 1 comments
dxm447 commented
I am building a code using cupy
, and jaxtyping
for type-hinting to calculate the Laplacian of Gaussian of a function. Here is my code:
import cupy as cp
import numpy as np
import cupyx.scipy.ndimage as csnd
import cucim.skimage.exposure as cexpose
import cupyx.scipy.signal as csisig
from typing import Tuple
from jaxtyping import Float, jaxtyped
from beartype import beartype as typechecker
@jaxtyped(typechecker=typechecker)
def laplacian_gaussian(
image: Float[cp.ndarray, "dim1 dim2"],
standard_deviation: int = 3,
hist_stretch: bool = True,
sampling: Float = 1,
) -> Tuple[
Float[cp.ndarray, "dim3 dim4"],
Float[cp.ndarray, "dim3 dim4"],
]:
image: Float[cp.ndarray, "dim1 dim2"] = cp.asarray(image.astype(cp.float64))
if sampling != 1:
sampled_image: Float[cp.ndarray, "dim3 dim4"] = csnd.zoom(image, sampling)
else:
sampled_image: Float[cp.ndarray, "dim3 dim4"] = cp.copy(image)
if hist_stretch:
sampled_image: Float[cp.ndarray, "dim3 dim4"] = cexpose.equalize_hist(
sampled_image
)
gauss_image: Float[cp.ndarray, "dim3 dim4"] = csnd.gaussian_filter(
sampled_image, standard_deviation
)
positive_laplacian: Float[cp.ndarray, "3 3"] = cp.asarray(
(
(0.0, 1.0, 0.0),
(1.0, -4.0, 1.0),
(0.0, 1.0, 0.0),
),
dtype=np.float64,
)
negative_laplacian: Float[cp.ndarray, "3 3"] = cp.asarray(
(
(0.0, -1.0, 0.0),
(-1.0, 4.0, -1.0),
(0.0, -1.0, 0.0),
),
dtype=np.float64,
)
positive_filtered: Float[cp.ndarray, "dim3 dim4"] = csisig.convolve2d(
gauss_image, positive_laplacian, mode="same", boundary="symm", fillvalue=0
)
negative_filtered: Float[cp.ndarray, "dim3 dim4"] = csisig.convolve2d(
gauss_image, negative_laplacian, mode="same", boundary="symm", fillvalue=0
)
return (positive_filtered, negative_filtered)
Calling this raises the following error:
AnnotationError: Do not use isinstance(x, jaxtyping.Float)
. If you want to check just the dtype of an array, then use jaxtyping.Float[jnp.ndarray, "..."]
.
The error is from:
File ~/anaconda3/envs/arm/lib/python3.10/site-packages/jaxtyping/_array_types.py:561, in _MetaAbstractDtype.instancecheck(cls, obj)
560 def __instancecheck__(cls, obj: Any) -> NoReturn:
--> 561 raise AnnotationError(
562 f"Do not use `isinstance(x, jaxtyping.{cls.__name__})`. If you want to "
563 "check just the dtype of an array, then use "
564 f'`jaxtyping.{cls.__name__}[jnp.ndarray, "..."]`.'
565 )
patrick-kidger commented
Closing as the error message is informative.