Vectorize the edges and attributes of the networks
LegrandNico opened this issue · 0 comments
The data format we are currently using to represent `attributes` and `edges` is suboptimal with respect to JAX transformations. In JAX, the structure of a PyTree is only known at compile (trace) time, and a PyTree container such as a tuple or dict cannot be indexed with a Tracer, whose value is only known at run time. For this reason, update functions need to declare the `node_idx` and `edges` variables as static arguments, which caches a separately compiled function for each node. This loses the advantages provided by the modularity of the implementation, and large models would definitely benefit from having a single cached update function.
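A minimal sketch of the problem, assuming node attributes are stored as a tuple of per-node dicts. The names `attributes` and `update` and the `mean`/`precision` keys are hypothetical and not the library's actual layout. Because a tuple can only be indexed with a concrete Python integer, `node_idx` has to be marked static, and `jax.jit` then traces and compiles the function once per distinct node index:

```python
import jax
import jax.numpy as jnp

# Hypothetical layout: one dict of parameters per node, stored in a tuple (a PyTree).
attributes = (
    {"mean": jnp.array(0.0), "precision": jnp.array(1.0)},
    {"mean": jnp.array(1.0), "precision": jnp.array(2.0)},
    {"mean": jnp.array(2.0), "precision": jnp.array(4.0)},
)

trace_count = 0  # Counts how many times JAX traces (and compiles) `update`.


def update(attributes, node_idx):
    global trace_count
    trace_count += 1  # Python side effect: runs only during tracing.
    # Indexing a tuple needs a concrete int, so node_idx cannot be a Tracer.
    node = attributes[node_idx]
    return node["mean"] * node["precision"]


# node_idx (argument 1) must be static; each new value triggers a new trace.
update_static = jax.jit(update, static_argnums=1)
for idx in range(len(attributes)):
    update_static(attributes, idx)

static_traces = trace_count
print(static_traces)  # 3: one compiled function per node

# Without static_argnums, node_idx becomes a Tracer and tuple indexing fails.
tracer_index_failed = False
try:
    jax.jit(update)(attributes, 0)
except TypeError:
    tracer_index_failed = True
print(tracer_index_failed)  # True
```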
The solution I see would be:
- [x] Use a dictionary of arrays to store the `edges` using a connectivity matrix representation.
- [x] Use a dictionary of arrays for each node parameter.
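A sketch of the two items above, under the assumption that the connectivity matrix stores `edges[i, j] = 1` when node `j` is a parent of node `i`. This convention, the parameter names, and the toy update rule are hypothetical. Since `node_idx` now indexes a JAX array rather than a Python container, it can stay a traced argument, and one compiled function serves every node:

```python
import jax
import jax.numpy as jnp

# Hypothetical layout: one array per parameter, indexed by node.
attributes = {
    "mean": jnp.array([0.0, 1.0, 2.0]),
    "precision": jnp.array([1.0, 2.0, 4.0]),
}

# Hypothetical connectivity matrix: edges[i, j] == 1 if node j is a parent of node i.
edges = jnp.array([
    [0, 1, 0],  # node 0 has parent 1
    [0, 0, 1],  # node 1 has parent 2
    [0, 0, 0],  # node 2 is a root
])

trace_count = 0  # Counts how many times JAX traces `update`.


@jax.jit
def update(attributes, edges, node_idx):
    global trace_count
    trace_count += 1  # Python side effect: runs only during tracing.
    # node_idx is a Tracer here; indexing a JAX array with it is allowed.
    parent_mask = edges[node_idx]
    # Toy rule (not an HGF update): precision-weighted sum of the parents' means.
    return jnp.sum(parent_mask * attributes["precision"] * attributes["mean"])


results = [update(attributes, edges, idx) for idx in range(3)]
print(trace_count)  # 1: a single cached function serves every node
print([float(r) for r in results])  # [2.0, 8.0, 0.0]
```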
Update: At present, it is (very) difficult to write readable update functions that pass messages across a variable number of nodes without something like dynamic shapes, which are under development in JAX but not yet available. Until such a feature is available, it seems unreasonable to move the code to this implementation. We have a working example for the two-level binary HGF, but its total execution time is longer than that of the default implementation, so it is unclear whether we would really benefit from this change, apart from compilation time.
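To illustrate the dynamic-shape limitation, again assuming the hypothetical convention `edges[i, j] = 1` when node `j` is a parent of node `i`: gathering only a node's parents produces an array whose length depends on the *value* of `node_idx`, which `jax.jit` cannot trace. One fixed-shape alternative is to compute over all nodes and mask out the non-parents. This is a sketch of that workaround, not the two-level binary HGF implementation mentioned above:

```python
import jax
import jax.numpy as jnp

# Hypothetical graph: edges[i, j] == 1 if node j is a parent of node i.
edges = jnp.array([
    [0, 1, 1],  # node 0 has two parents
    [0, 0, 1],  # node 1 has one parent
    [0, 0, 0],  # node 2 has no parent
])
means = jnp.array([0.0, 1.0, 2.0])


@jax.jit
def parent_means_dynamic(edges, node_idx):
    # The number of parents depends on the value of node_idx, so the output
    # shape is unknown at trace time and jnp.nonzero raises an error here.
    (parents,) = jnp.nonzero(edges[node_idx])
    return means[parents]


@jax.jit
def parent_sum_masked(edges, node_idx):
    # Fixed-shape alternative: operate on every node, zero out the non-parents.
    mask = edges[node_idx]
    return jnp.sum(mask * means)


try:
    parent_means_dynamic(edges, 0)
    dynamic_failed = False
except jax.errors.ConcretizationTypeError:
    dynamic_failed = True

print(dynamic_failed)  # True
print([float(parent_sum_masked(edges, i)) for i in range(3)])  # [3.0, 2.0, 0.0]
```

The masked version keeps a single compiled function, but every update now does work proportional to the total number of nodes, which is consistent with the slower execution time reported in the update.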