roboflow/template-python

[Repo] Should we use type hints and check them with mypy?

FrancescoSaverioZuppichini opened this issue · 3 comments

We can decide whether it makes sense for us to commit to using type hints and checking them with mypy.

I don't have super strong opinions about this but sometimes it feels like they can get in the way of rapid development. And since they aren't enforced in any way, they are essentially just code comments which can be hard to keep up to date when used extensively.
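A minimal sketch of the point that hints are not enforced at runtime. The function `add` is a hypothetical example, not code from this repository:

```python
def add(a: int, b: int) -> int:
    # The annotations are stored as metadata; Python does not check them when the function is called.
    return a + b


# Passing strings violates the hints, yet the call succeeds at runtime.
# A static checker such as mypy would report this call as an error.
result = add("type", "hints")
print(result)  # typehints
```

Unless a checker runs in CI or in the editor, nothing stops the annotations from drifting away from what the code actually accepts.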

On the flip side, I can see how it makes code a lot easier to read.

I feel the same. One thing that is not so nice is using packages that are not mypy-friendly: you then have to tell mypy not to check them, and so on.

We can probably agree on using them for function/method parameters and return values, to make it easier to share code between us.

Looking forward to hearing @Jacobsolawetz's opinion.

Hi,

I came across this issue while reading the contributors' guidelines. I have been using jaxtyping with beartype in my codebase, which has the following advantages:

  • Supports NumPy, Torch, JAX.
  • A wide variety of enforced checks that can replace manual assertions.
  • Both dtype and shape are stated in the annotation, making it instantly clear what the input and output shapes and types should be.

Disadvantages:

  • @jaxtyped(typechecker=beartype) has to be added to every function. Some tricks could potentially avoid this.
  • I am not sure about the impact on speed.

Minimum working example

Installation

pip install jaxtyping beartype

Example

from numpy import ndarray
from beartype import beartype
from jaxtyping import Float, Int, jaxtyped

# Current evaluate_detection_batch function in supervision
def evaluate_detection_batch(
        predictions: ndarray,
        targets: ndarray,
        num_classes: int,
        conf_threshold: float,
        iou_threshold: float,
    ) -> ndarray:
    ...

# How it would look after applying the type and shape checker
@jaxtyped(typechecker=beartype)
def evaluate_detection_batch(
        predictions: Float[ndarray, "n 6"],
        targets: Float[ndarray, "m 5"],
        num_classes: int,
        conf_threshold: float,
        iou_threshold: float,
    ) -> Float[ndarray, "{num_classes}+1 {num_classes}+1"]:
    ...

# Another indicative example
def iou(true_bbox: Float[ndarray, "n 4"], pred_bbox: Float[ndarray, "m 4"]) -> Float[ndarray, "n m"]:
    ...

A class implemented in a library, with all functions jaxtyped:

https://github.com/patel-zeel/garuda/blob/edcfb1943781301c9ab0d0a503fe47f8cbe6d94f/garuda/od.py#L176-L420