This package contains the core implementation of the AlphaGrad reinforcement learning algorithm.
The following packages are required to run the algorithms:
- jax GPU installation from https://github.com/google/jax
- mctx package for tree search from https://github.com/google-deepmind/mctx
- optax
- equinox
- numpy, scipy, matplotlib
- distrax from https://github.com/google-deepmind/distrax
- flashbax from https://instadeepai.github.io/flashbax/
- tqdm
Finally, the Graphax package needs to be installed as well since it contains the
implementation of the sample task. The package is contained in the ZIP file of
this project. It can be installed by executing
pip install -e .
in the root directory of the package. Similarly, the AlphaGrad package itself has
to be installed by executing
pip install -e .
in the root directory of this package.
To start a run of the RL algorithm, use the following command:
CUDA_VISIBLE_DEVICES=0,1,2,3 python vertex_A0.py --task RoeFlux_1d --name test --seed 123
The config subfolder contains .yaml files that configure the hyperparameters of
the experiments; a loading sketch is given below. Similarly, you can run the
experiments defined in separate_models_vertex_ppo.py and vertex_A0_joint.py.
Note that it is necessary to set up wandb to log the experiments; use
--wandb disabled
to deactivate it.
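As a small illustration of how such a config might be consumed, the sketch below loads a hypothetical .yaml file with Python; the file name and keys are assumptions for illustration, not the actual schema of this repository.

```python
# Minimal sketch: loading a hypothetical hyperparameter file from the
# config subfolder. The file name and keys are illustrative assumptions.
import yaml

with open("config/RoeFlux_1d.yaml") as cfg_file:
    config = yaml.safe_load(cfg_file)

print(config)  # e.g. {"lr": 3e-4, "num_simulations": 100, ...}
```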
The project structure is described below:

- alphagrad
  - src
    - alphagrad
      - alphazero
        - environment_interaction.py: Contains the implementation of the Monte-Carlo Tree Search.
        - vertex_A0.py: Run this script for a single-task experiment with AlphaZero.
        - vertex_A0_joint.py: Run this script for a joint experiment with AlphaZero.
      - eval: Contains the evaluation of the elimination orders found by the RL algorithm, as well as the reward curves and the Jupyter notebooks used to create the figures from the paper.
      - ppo
        - runtime_vertex_ppo.py: Run this script for a single-task experiment with PPO and the actual runtime as the reward for the model. Not tested yet.
        - separate_models_vertex_ppo.py: Run this script for a single-task experiment with PPO where we use separate models for the policy and value networks.
        - vertex_ppo_joint.py: Run this script for joint experiments with PPO. This script is decommissioned.
        - vertex_ppo.py: Run this script for a single-task experiment with PPO where we use a joint model for the policy and value functions.
      - transformer: Implementation of the transformer model used in this work.
      - vertexgame
        - interpreter: Contains the functions that trace a Python function to create the computational graph representation used as the state for the RL algorithm.
        - codegeneration: Contains the source code that was used to generate the random functions f and g.
        - transforms: Contains a set of transformations similar to image augmentations in computer vision.
        - vertex_game.py: Implementation of the VertexGame reinforcement learning game using the number of multiplications as the reward.
        - runtime_game.py: Implementation of the VertexGame reinforcement learning game using the runtime as the reward.
        - core.py: Core implementation of the environment dynamics model of cross-country elimination with sparsity types, Jacobian shapes, etc.
  - docs
  - tests
The eval subfolder contains the code used to evaluate the elimination orders.
For every experiment, the folder contains an appropriate subfolder with a
.ipynb notebook.
The graphax.jacve(f, order=order, argnums=argnums) command computes the Jacobian
with Graphax for a given elimination order order. The syntax is similar to that
of jax.jacfwd and jax.jacrev.
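A minimal usage sketch, assuming a toy function f and a hand-picked elimination order (both hypothetical; only graphax.jacve, jax.jacfwd, and jax.jacrev are taken from the description above):

```python
# Minimal sketch: computing a Jacobian with Graphax for a given
# elimination order. The function f and the order are hypothetical
# examples for illustration.
import jax
import jax.numpy as jnp
import graphax

def f(x, y):
    z = x * y
    return jnp.sin(z + y)

order = [1, 2]  # hypothetical vertex elimination order

# Like jax.jacfwd/jax.jacrev, jacve returns a function that evaluates
# the Jacobian of f with respect to the arguments in argnums.
jac_f = graphax.jacve(f, order=order, argnums=(0, 1))
jacs = jac_f(2.0, 3.0)

# For comparison: the same Jacobian with reverse-mode JAX.
jacs_rev = jax.jacrev(f, argnums=(0, 1))(2.0, 3.0)
```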
Runtime performance is measured with the graphax.perf package.
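The exact graphax.perf API is not reproduced here; as a library-agnostic sketch under the same assumptions as above, a jitted Jacobian function can also be timed by hand:

```python
# Rough timing sketch using plain Python instead of the graphax.perf
# package (whose API is not shown here). f and order are the
# hypothetical examples from the previous snippet.
import timeit
import jax
import jax.numpy as jnp
import graphax

def f(x, y):
    return jnp.sin(x * y + y)

order = [1, 2]  # hypothetical elimination order
jac_f = jax.jit(graphax.jacve(f, order=order, argnums=(0, 1)))
jac_f(2.0, 3.0)  # warm-up call to trigger JIT compilation

runtime = timeit.timeit(
    lambda: jax.block_until_ready(jac_f(2.0, 3.0)), number=1000
)
print(f"mean runtime per call: {runtime / 1000:.2e} s")
```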