Allow changing the default dtypes
asmeurer opened this issue · 1 comment
See https://data-apis.org/array-api/latest/API_specification/data_types.html#default-data-types and https://data-apis.org/array-api/latest/API_specification/generated/array_api.info.default_dtypes.html#array_api.info.default_dtypes.
We should add flags to `set_array_api_strict_flags` that allow configuring these defaults away from the NumPy ones.
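A minimal sketch of what a configurable default could look like, assuming a module-level table keyed by the same kinds that `default_dtypes()` reports (`"real floating"`, `"complex floating"`, `"integral"`, `"indexing"`). The helpers `set_default_dtypes`, `default_dtypes`, `zeros` and `arange` below are hypothetical stand-ins, not the array-api-strict API, and the starting values assume NumPy's usual 64-bit defaults.

```python
import numpy as np

# Hypothetical table of defaults; starting values assume NumPy's usual
# 64-bit defaults (the default integer can differ on some platforms).
_default_dtypes = {
    "real floating": np.dtype(np.float64),
    "complex floating": np.dtype(np.complex128),
    "integral": np.dtype(np.int64),
    "indexing": np.dtype(np.int64),
}


def set_default_dtypes(overrides):
    """Hypothetical setter: replace some default dtypes, keyed by kind."""
    for kind, dtype in overrides.items():
        if kind not in _default_dtypes:
            raise ValueError(f"unknown dtype kind: {kind!r}")
        _default_dtypes[kind] = np.dtype(dtype)


def default_dtypes():
    """Return a copy so callers cannot mutate the table directly."""
    return dict(_default_dtypes)


def zeros(shape, dtype=None):
    # Fall back to the configured real floating default, not NumPy's own.
    if dtype is None:
        dtype = _default_dtypes["real floating"]
    return np.zeros(shape, dtype=dtype)


def arange(stop, dtype=None):
    # Integer ranges fall back to the configured integral default.
    if dtype is None:
        dtype = _default_dtypes["integral"]
    return np.arange(stop, dtype=dtype)
```

For example, after `set_default_dtypes({"real floating": np.float32})`, `zeros(3)` returns a float32 array, while `arange(3)` still uses the integral default.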
One concern is that, in some cases of moving from float64 to float32, we may have to simply downcast the result from NumPy. The computation would then still happen in float64, producing a result that could differ from that of a library that actually does everything in float32. Wherever possible, this should be worked around by downcasting the inputs before computing rather than downcasting the output.
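A small illustration of the concern, assuming plain NumPy and an arbitrary rounding-sensitive expression `(a + b) - a` chosen only for this example. Computing in float64 and downcasting the result keeps a value near `1e-8`, whereas downcasting the inputs first makes the addition round in float32, where `1 + 1e-8` is exactly `1.0`, so the result is `0.0`.

```python
import numpy as np


def f(a, b):
    # Mathematically equal to b, but sensitive to the precision of a + b.
    return (a + b) - a


a = np.array([1.0])
b = np.array([1e-8])

# Strategy 1: compute in float64, downcast only the result.
out_downcast = f(a, b).astype(np.float32)

# Strategy 2: downcast the inputs, so the computation itself runs in float32.
out_float32 = f(a.astype(np.float32), b.astype(np.float32))

print(out_downcast, out_float32)
```

Here the second strategy agrees with what a library doing native float32 arithmetic would typically produce, which is why downcasting inputs is the preferable workaround when it is possible.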
We could also add behavior to emulate missing dtypes. This would require rewriting the existing code a little, so I'm only really inclined to implement it if people ask for it. It would help map to libraries like PyTorch, but people will test against those libraries directly anyway, so it isn't strictly necessary for array-api-strict to provide this behavior.
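A rough sketch of how missing dtypes might be emulated, assuming a hypothetical `disable_dtypes` switch and illustrative wrapper functions (`asarray`, `add`); none of these are existing array-api-strict APIs. The check has to run on results as well as inputs, because type promotion can produce a disabled dtype (NumPy promotes `int8` and `uint8` to `int16`, for instance), which is part of why this would touch much of the existing code.

```python
import numpy as np

# Hypothetical set of dtypes the emulated library does not provide.
_disabled_dtypes = set()


def disable_dtypes(*dtypes):
    """Hypothetical switch: treat the given dtypes as unavailable."""
    _disabled_dtypes.update(np.dtype(d) for d in dtypes)


def _check(arr):
    # Reject any array whose dtype the emulated library lacks.
    if arr.dtype in _disabled_dtypes:
        raise TypeError(f"dtype {arr.dtype} is not supported")
    return arr


def asarray(obj, dtype=None):
    return _check(np.asarray(obj, dtype=dtype))


def add(x1, x2):
    # The inputs may be allowed while the promoted result dtype is not.
    return _check(np.add(_check(x1), _check(x2)))
```

With `int16` disabled, `add` on an `int8` array and a `uint8` array raises `TypeError` even though both inputs are allowed.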