orthax
orthax is a Python package for working with orthogonal (and other) polynomials in JAX.
It largely seeks to replicate the functionality of the numpy.polynomial package,
though there are some API differences due to limitations of JAX. The main one is that
trailing zeros are not automatically trimmed from series, so you should trim them
manually if that becomes a concern.
For full details of the various options, see the Documentation.
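numpy.polynomial drops trailing zero coefficients from its results, while orthax leaves them in place. The sketch below shows one way to trim them yourself. It assumes you already have a concrete coefficient array. The `untrimmed` literal stands in for an orthax result and is not a real orthax call, and `trim_trailing` is a hypothetical helper name. Only NumPy's own `numpy.polynomial.polyutils.trimcoef` does the actual trimming.

```python
import numpy as np
from numpy.polynomial import polynomial as P
from numpy.polynomial import polyutils as pu

# numpy.polynomial trims automatically:
# (1 + 2x + 3x^2) + (1 + 2x - 3x^2) = 2 + 4x, returned as [2., 4.]
np_sum = P.polyadd([1.0, 2.0, 3.0], [1.0, 2.0, -3.0])

# Stand-in for a fixed-shape result as a JAX-based routine would return it,
# with the cancelled x^2 coefficient still present (literal, not an orthax call).
untrimmed = np.array([2.0, 4.0, 0.0])


def trim_trailing(coeffs, tol=0.0):
    """Hypothetical helper: drop trailing coefficients with |c| <= tol.

    At least one coefficient is always kept (an all-zero series becomes [0.]).
    The output length depends on the values, so call this eagerly on concrete
    arrays (e.g. np.asarray of a JAX result), not inside jax.jit, where array
    shapes must be known at trace time.
    """
    return pu.trimcoef(np.asarray(coeffs), tol)


trimmed = trim_trailing(untrimmed)
```

Trimming is kept outside any jitted code because JAX requires static shapes under `jax.jit`, and a value-dependent trim changes the array's length.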
orthax is installable with pip:

pip install orthax