brainpy/BrainPy

How to view the computational graph in BrainPy

CloudyDory opened this issue.

Hi, is there a way to view the computational graph when training BrainPy models by backpropagation? I have found some tutorials on viewing the computational graph of JAX functions (https://bnikolic.co.uk/blog/python/jax/2022/02/22/jax-outputgraph-rev.html), but I am not sure how to do it in BrainPy, for both jitted and un-jitted functions.

Thanks!

@CloudyDory Thanks for the question. Actually, this works almost the same way as in the examples you linked.

Here is my example of visualizing the computation graph of a LIF neuron model:

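import jax
import numpy as np
import brainpy as bp
import brainpy.math as bm

# a group of 10 LIF neurons with a refractory period
hh = bp.dyn.LifRef(10)

def run_fun(inputs):
  # each iteration calls hh.step_run(i, x) with the step index i and the input x
  return bm.for_loop(hh.step_run, (np.arange(inputs.shape[0]), inputs))

# trace run_fun into an XLA computation
# (note: jax.xla_computation is deprecated in recent JAX releases)
z = jax.xla_computation(run_fun)(np.random.uniform(2., 6., 10000))
# dump the HLO graph in Graphviz DOT format
with open("lif.dot", "w") as f:
  f.write(z.as_hlo_dot_graph())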

Then, render the DOT file to a PNG image with Graphviz by running the following command in the terminal:

dot lif.dot -Tpng > lif.png
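On recent JAX versions, where jax.xla_computation is deprecated in favor of the ahead-of-time (AOT) lowering API, a textual view of the graph can still be obtained. The sketch below is plain JAX and does not use BrainPy: it assumes a hypothetical, simplified LIF neuron without a refractory period (the constants TAU, V_TH, V_RESET and DT are made up for illustration) stepped with jax.lax.scan. It shows two views that roughly match the "un-jitted vs. jitted" distinction in the question: jax.make_jaxpr prints the jaxpr traced from the Python function, and jax.jit(...).lower(...).as_text() prints the module that the jitted function hands to XLA. The same two calls could presumably be applied to the BrainPy run_fun above, though this sketch does not verify that.

```python
import jax
import jax.numpy as jnp

# Hypothetical, simplified LIF parameters (illustrative only, not BrainPy's LifRef defaults).
TAU = 10.0     # membrane time constant (ms)
V_TH = 1.0     # spike threshold
V_RESET = 0.0  # reset potential after a spike
DT = 0.1       # Euler integration step (ms)


def lif_step(v, current):
    """One Euler step of a LIF neuron group (no refractory period)."""
    v = v + DT / TAU * (-v + current)
    spike = v >= V_TH
    v = jnp.where(spike, V_RESET, v)  # reset the neurons that spiked
    return v, spike                   # (new carry, per-step output)


def run_fun(inputs, num_neurons=10):
    """Run the neurons over time; all neurons get the same scalar input per step."""
    v0 = jnp.zeros(num_neurons)
    _, spikes = jax.lax.scan(lif_step, v0, inputs)
    return spikes  # shape: (num_steps, num_neurons)


inputs = jnp.linspace(2.0, 6.0, 100)

# Un-jitted view: the jaxpr that JAX traces from the Python function.
jaxpr = jax.make_jaxpr(run_fun)(inputs)
print(jaxpr)

# Jitted view: the lowered module that jax.jit hands to XLA for compilation.
lowered = jax.jit(run_fun).lower(inputs)
hlo_text = lowered.as_text()
print(hlo_text)
```

The jaxpr is usually the easier of the two to read, since it keeps JAX primitive names such as scan; the lowered text is closer to what XLA actually compiles (the scan appears there as a while loop).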