patrick-kidger/jaxtyping

numpy structured dtype support

alexfanqi opened this issue · 1 comment

Hey,

This is a really useful library that saves me a lot of debugging time. Thanks for maintaining this all along!

I am wondering whether it would be possible to support NumPy's structured arrays (https://numpy.org/doc/stable/user/basics.rec.html#structured-arrays). I mainly use them to store multiple labels per sample.
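For context, here is a minimal sketch of that use case, reusing the field names from the annotation dtype further down. A structured dtype packs several typed fields into each element, so a 1-D array holds one record per sample, and each field can still be pulled out as an ordinary homogeneous array:

```python
import numpy as np

# One record per sample; each record carries several labels of different types.
annotation_t = np.dtype(
    [("finger_count", np.uint8), ("lightness", np.int16), ("finger split", bool)]
)

labels = np.array([(1, 1, False), (5, -3, True)], dtype=annotation_t)

# Each field reads back as a plain, homogeneous array.
finger_counts = labels["finger_count"]  # uint8 array, one entry per sample
splits = labels["finger split"]         # bool array, one entry per sample

# The array is 1-D (one element per sample), and its scalar type is numpy.void,
# which is why a name-based dtype check only sees "void".
print(labels.shape)
print(labels.dtype.names)
print(labels.dtype.type.__name__)
```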

I made a small hack to get it working, but I am not sure whether it is safe:

```diff
@@ -166,6 +166,9 @@ class _MetaAbstractArray(type):
         if hasattr(obj.dtype, "type") and hasattr(obj.dtype.type, "__name__"):
             # JAX, numpy
             dtype = obj.dtype.type.__name__
+            # struct in numpy: every structured dtype reports "void", so use the full dtype string
+            if dtype == "void" and obj.dtype.names is not None:
+                dtype = str(obj.dtype)
         elif hasattr(obj.dtype, "as_numpy_dtype"):
             # TensorFlow
             dtype = obj.dtype.as_numpy_dtype.__name__
```
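The patch relies on the fact that every structured dtype has `numpy.void` as its scalar type, so the type name alone cannot tell two record layouts apart. Below is a standalone sketch of the same branching outside jaxtyping. The function name `dtype_string` is hypothetical and is not part of jaxtyping's API. It uses `dtype.names`, which is `None` for plain void dtypes such as `"V8"` and a tuple of field names for structured ones:

```python
import numpy as np


def dtype_string(arr: np.ndarray) -> str:
    """Hypothetical helper mirroring the patched branch: return the scalar type
    name, or the full dtype string when the dtype is structured."""
    name = arr.dtype.type.__name__
    # Structured dtypes have named fields; plain void dtypes ("V8") do not.
    if name == "void" and arr.dtype.names is not None:
        return str(arr.dtype)
    return name


annotation_t = np.dtype(
    [("finger_count", np.uint8), ("lightness", np.int16), ("finger split", bool)]
)

print(dtype_string(np.zeros(3, dtype=np.float32)))   # ordinary dtype: type name
print(dtype_string(np.zeros(3, dtype="V8")))         # unstructured void: "void"
print(dtype_string(np.zeros(3, dtype=annotation_t))) # structured: full field spec
```

The exact string for the structured case depends on the platform's byte order (for example `'<i2'` vs `'>i2'` for the `int16` field), so it is built from `str(annotation_t)` rather than written out by hand.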

Then I declare a new AbstractDtype:

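```python
import numpy as np
from jaxtyping import AbstractDtype

annotation_t = np.dtype([("finger_count", np.uint8), ("lightness", np.int16), ("finger split", bool)])

class AnnotationT(AbstractDtype):
    dtypes = [str(annotation_t)]  # list of accepted dtype strings

assert isinstance(np.array([(1, 1, False)], dtype=annotation_t), AnnotationT[np.ndarray, "n"])  # expected to pass with the patch
```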
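On the question of safety, one consequence of matching on `str(dtype)` is that the check is exact with respect to everything that string encodes: field names, field order, byte order and padding. The sketch below illustrates this with NumPy alone (no jaxtyping involved). Each variant is a genuinely different dtype and also produces a different string, so under the patch above an annotation built from `str(annotation_t)` would reject arrays of these dtypes. Whether that strictness is desirable is a design choice for the PR.

```python
import numpy as np

fields = [("finger_count", np.uint8), ("lightness", np.int16), ("finger split", bool)]
annotation_t = np.dtype(fields)

# Same fields in reverse order: different layout, different string.
reordered = np.dtype(fields[::-1])

# Same fields, but the int16 field uses the non-native byte order.
swapped = np.dtype([
    ("finger_count", np.uint8),
    ("lightness", np.dtype(np.int16).newbyteorder()),
    ("finger split", bool),
])

# Same fields with C-struct alignment: padding changes offsets and itemsize.
aligned = np.dtype(fields, align=True)

for other in (reordered, swapped, aligned):
    # First value: dtype equality; second: equality of the string representations.
    print(other == annotation_t, str(other) == str(annotation_t))
```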

Something like this looks reasonable to me! I'd be happy to take a PR adding support for this.