patrick-kidger/jaxtyping

`jaxtyped` Annotation fails

dxm447 opened this issue · 1 comment

I am writing code that uses cupy, with jaxtyping for type hinting, to calculate the Laplacian of Gaussian of a function. Here is my code:

import cupy as cp
import numpy as np
import cupyx.scipy.ndimage as csnd
import cucim.skimage.exposure as cexpose
import cupyx.scipy.signal as csisig
from typing import Tuple
from jaxtyping import Float, jaxtyped
from beartype import beartype as typechecker


@jaxtyped(typechecker=typechecker)
def laplacian_gaussian(
    image: Float[cp.ndarray, "dim1 dim2"],
    standard_deviation: int = 3,
    hist_stretch: bool = True,
    # NOTE(review): this is the only bare `Float` in the signature -- every
    # other use is subscripted as `Float[cp.ndarray, "..."]`. A bare `Float`
    # is an abstract dtype marker, not a checkable type, which matches the
    # `isinstance(x, jaxtyping.Float)` AnnotationError reported for this
    # function. A scalar parameter presumably wants the builtin `float` here
    # (note the default `1` is an int) -- confirm against the typechecker used.
    sampling: Float = 1,
) -> Tuple[
    Float[cp.ndarray, "dim3 dim4"],
    Float[cp.ndarray, "dim3 dim4"],
]:
    """Filter a Gaussian-smoothed image with a 3x3 Laplacian kernel and its negation.

    The input is cast to float64, optionally resampled and histogram-equalized,
    smoothed with a Gaussian filter, and then convolved with two 3x3 kernels
    that differ only in sign.

    Args:
        image: 2-D floating-point cupy array.
        standard_deviation: Passed as the second argument to
            ``csnd.gaussian_filter`` (presumably the Gaussian sigma).
        hist_stretch: When true, ``cexpose.equalize_hist`` is applied to the
            (possibly resampled) image before smoothing.
        sampling: When not equal to 1, passed to ``csnd.zoom`` as the zoom
            argument; otherwise the image is copied unchanged.

    Returns:
        A tuple ``(positive_filtered, negative_filtered)`` holding the
        convolution of the smoothed image with the positive and the negative
        Laplacian kernel, respectively.
    """
    # Work in float64 regardless of the incoming float precision.
    image: Float[cp.ndarray, "dim1 dim2"] = cp.asarray(image.astype(cp.float64))
    # Resample only when a non-unit factor is requested; otherwise take a copy
    # so later steps never alias the caller's array.
    if sampling != 1:
        sampled_image: Float[cp.ndarray, "dim3 dim4"] = csnd.zoom(image, sampling)
    else:
        sampled_image: Float[cp.ndarray, "dim3 dim4"] = cp.copy(image)
    if hist_stretch:
        sampled_image: Float[cp.ndarray, "dim3 dim4"] = cexpose.equalize_hist(
            sampled_image
        )
    # Smooth before applying the Laplacian kernels.
    gauss_image: Float[cp.ndarray, "dim3 dim4"] = csnd.gaussian_filter(
        sampled_image, standard_deviation
    )
    # 3x3 kernel: -4 at the centre, +1 on the four edge-adjacent neighbours.
    positive_laplacian: Float[cp.ndarray, "3 3"] = cp.asarray(
        (
            (0.0, 1.0, 0.0),
            (1.0, -4.0, 1.0),
            (0.0, 1.0, 0.0),
        ),
        dtype=np.float64,
    )
    # Element-wise negation of the kernel above.
    negative_laplacian: Float[cp.ndarray, "3 3"] = cp.asarray(
        (
            (0.0, -1.0, 0.0),
            (-1.0, 4.0, -1.0),
            (0.0, -1.0, 0.0),
        ),
        dtype=np.float64,
    )
    # mode="same" keeps the output the same shape as `gauss_image`;
    # boundary="symm" is requested, so `fillvalue` is presumably unused.
    positive_filtered: Float[cp.ndarray, "dim3 dim4"] = csisig.convolve2d(
        gauss_image, positive_laplacian, mode="same", boundary="symm", fillvalue=0
    )
    negative_filtered: Float[cp.ndarray, "dim3 dim4"] = csisig.convolve2d(
        gauss_image, negative_laplacian, mode="same", boundary="symm", fillvalue=0
    )
    return (positive_filtered, negative_filtered)

Calling this raises the following error:
AnnotationError: Do not use `isinstance(x, jaxtyping.Float)`. If you want to check just the dtype of an array, then use `jaxtyping.Float[jnp.ndarray, "..."]`.

The error is from:

File ~/anaconda3/envs/arm/lib/python3.10/site-packages/jaxtyping/_array_types.py:561, in _MetaAbstractDtype.__instancecheck__(cls, obj)

    560 def __instancecheck__(cls, obj: Any) -> NoReturn:
--> 561     raise AnnotationError(
    562         f"Do not use `isinstance(x, jaxtyping.{cls.__name__})`. If you want to "
    563         "check just the dtype of an array, then use "
    564         f'`jaxtyping.{cls.__name__}[jnp.ndarray, "..."]`.'
    565     )

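Closing as the error message is informative: the bare `Float` annotation on the `sampling` parameter (`sampling: Float = 1`) is what triggers it. Annotate a scalar with the builtin `float`, or use `Float[cp.ndarray, "..."]` if an array is intended.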