Provide a way to get error diagnostics out of isinstance checks
reinerp opened this issue · 1 comment
The assert isinstance(...) pattern prints a mostly useless message: just "AssertionError" without explanation. Would it be possible to expose an assertIsInstance(x, ty) API that prints expected versus actual, like we get for errors in the function arguments?
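For reference, here is a minimal sketch of what such a helper could look like in plain Python. The name assert_isinstance is hypothetical (it is not an existing jaxtyping API), and this version only reports the expected type's name and the actual value's type, not which part of a jaxtyping shape or dtype check failed:

```python
def assert_isinstance(x: object, ty: type) -> None:
    """Hypothetical helper: like `assert isinstance(x, ty)`, but the error
    says what was expected and what was actually received."""
    if not isinstance(x, ty):
        expected = getattr(ty, "__qualname__", repr(ty))
        actual = type(x).__qualname__
        raise AssertionError(f"expected an instance of {expected}, got {actual}: {x!r}")


assert_isinstance(3, int)  # passes silently
try:
    assert_isinstance("3", int)
except AssertionError as e:
    print(e)  # expected an instance of int, got str: '3'
```

Unlike a bare assert statement, a helper like this is not stripped when Python runs with -O.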
So this actually dovetails well with another feature I would like to add.
Beartype now supports checking for an __instancecheck_str__ method. (Beartype release notes, relevant jaxtyping discussion thread.)
Once this is added, your use case could easily be supported via assert isinstance(x, ty), ty.__instancecheck_str__(x).
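To illustrate that pattern, here is a toy metaclass implementing __instancecheck_str__ alongside __instancecheck__. The names _MetaHasNDim, Matrix and FakeArray are hypothetical, and the check itself (a fixed number of dimensions on a .shape attribute) is only a stand-in for jaxtyping's real shape and dtype logic:

```python
class _MetaHasNDim(type):
    """Toy metaclass (hypothetical): instances are objects whose .shape has `ndim` entries."""

    def __instancecheck__(cls, obj):
        # The check passes exactly when no failure message is produced.
        return cls.__instancecheck_str__(obj) == ""

    def __instancecheck_str__(cls, obj):
        # Return "" on success, otherwise a human-readable reason for failure.
        shape = getattr(obj, "shape", None)
        if shape is None:
            return f"{type(obj).__name__} object has no .shape attribute"
        if len(shape) != cls.ndim:
            return f"expected {cls.ndim} dimensions, got shape {tuple(shape)}"
        return ""


class Matrix(metaclass=_MetaHasNDim):
    ndim = 2


class FakeArray:
    """Minimal array stand-in that only carries a shape tuple."""

    def __init__(self, *shape):
        self.shape = shape


x = FakeArray(3)
try:
    assert isinstance(x, Matrix), Matrix.__instancecheck_str__(x)
except AssertionError as e:
    print(e)  # expected 2 dimensions, got shape (3,)
```

Because the message expression of an assert is only evaluated when the condition is false, building the diagnostic string costs nothing when the check passes.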
This shouldn't be too much work to add. Discussing some jaxtyping internals briefly, the plan is basically to rewrite things from
class _MetaAbstractArray(type):
    def __instancecheck__(cls, obj):
        if something_bad:
            return False
        ...
        return True
to
class _MetaAbstractArray(type):
    def __instancecheck__(cls, obj):
        # An empty string means "no failure", i.e. the check passed.
        return cls.__instancecheck_str__(obj) == ""

    def __instancecheck_str__(cls, obj):
        if something_bad:
            return "something bad!" + _exc_shape_info(get_shape_memo())
        ...
        return ""
which would report both exactly how the check failed (more than we currently get under any circumstances!) and all the extra information about the current values of the bindings (via _exc_shape_info).
I'd be happy to guide a pull request on this; otherwise, I'm hoping to get around to it myself in the near future.