iris-hep/NSBI-workflow-tutorial

Make the JAX model more vectorized: avoid using dictionaries / pytrees for faster JIT-based computations


The model in nsbi_common_utils.model is currently a JAX-based model that relies on just-in-time (JIT) compilation. Its data are stored in dictionaries with a process:data structure. These dictionaries can be replaced by stacked arrays, with appropriate bookkeeping of which array entry belongs to which process, for faster computations.
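A minimal sketch of the difference, assuming a simple model where each process contributes a per-bin yield template scaled by a single normalization factor. The process names, templates, and function names below are hypothetical and not taken from nsbi_common_utils. In the dictionary version, the Python loop over processes is unrolled during tracing, so the traced computation grows with the number of processes. In the stacked version, the same result comes from one broadcast multiply and one reduction over the process axis.

```python
import jax
import jax.numpy as jnp

# Hypothetical per-process templates: expected counts per bin for each process.
yields_dict = {
    "signal": jnp.array([5.0, 10.0, 3.0]),
    "background": jnp.array([50.0, 40.0, 30.0]),
}

@jax.jit
def expected_dict(params, yields):
    # Python loop over dict keys: unrolled at trace time, one op chain per process.
    total = 0.0
    for name, nominal in yields.items():
        total = total + params[name] * nominal
    return total

# Stacked layout: rows are processes in a fixed order, columns are bins.
process_order = ["signal", "background"]
yields_stacked = jnp.stack([yields_dict[p] for p in process_order])  # (n_proc, n_bins)

@jax.jit
def expected_stacked(norms, yields):
    # norms has shape (n_proc,): broadcast over bins, then sum over processes.
    return jnp.sum(norms[:, None] * yields, axis=0)

params = {"signal": 1.5, "background": 1.0}
norms = jnp.array([params[p] for p in process_order])

print(expected_dict(params, yields_dict))
print(expected_stacked(norms, yields_stacked))
```

Both functions return the same per-bin expectation; only the stacked one keeps the process dimension as an array axis that JAX can vectorize over.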

The current model can be found at:

nsbi_common_utils.model: link

which reads the workspace and builds the model from dictionaries. This model building should be replaced with one based on jax.numpy stacked arrays, so that the contributions of all processes can be computed in parallel.
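A sketch of what stacked model building and parallel evaluation could look like, assuming the workspace can be reduced to a mapping from process name to nominal per-bin yields. The names workspace_samples, build_stacked_model, and poisson_nll are hypothetical, and the real workspace format may differ. The bookkeeping lives in a process_index dictionary that is used only at build time, so the JIT-compiled function sees plain arrays, and jax.vmap evaluates a whole batch of parameter points in one call.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import gammaln

# Hypothetical workspace content: process name -> nominal per-bin yields.
# The real workspace read by nsbi_common_utils.model may be structured differently.
workspace_samples = {
    "background": [50.0, 40.0, 30.0],
    "signal": [5.0, 10.0, 3.0],
}

def build_stacked_model(samples):
    """Stack per-process yields into one array and record the row of each process."""
    process_names = sorted(samples)  # fixed, reproducible ordering
    process_index = {name: i for i, name in enumerate(process_names)}
    yields = jnp.asarray([samples[name] for name in process_names])  # (n_proc, n_bins)
    return process_names, process_index, yields

process_names, process_index, yields = build_stacked_model(workspace_samples)

def poisson_nll(norms, observed):
    # norms: (n_proc,) normalization factors, one per row of `yields`.
    expected = jnp.sum(norms[:, None] * yields, axis=0)
    return jnp.sum(expected - observed * jnp.log(expected) + gammaln(observed + 1.0))

observed = jnp.array([55.0, 52.0, 33.0])

# Scan the signal normalization: build a (n_points, n_proc) batch of parameter vectors.
mu_values = jnp.linspace(0.0, 2.0, 5)
norm_batch = jnp.ones((mu_values.size, len(process_names)))
norm_batch = norm_batch.at[:, process_index["signal"]].set(mu_values)

# vmap maps over the batch axis of norms only; jit compiles the batched function once.
batched_nll = jax.jit(jax.vmap(poisson_nll, in_axes=(0, None)))
nll_values = batched_nll(norm_batch, observed)
print(nll_values.shape)  # (5,)
```

Because the process order is fixed when the model is built, parameter vectors can be assembled once through process_index and reused; the dictionaries never enter the traced function.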