JAX reimplementation of Temporal Fusion Transformer


Unfortunately, despite my preference for JAX + Flax, I had to abandon this project in favour of existing, well-maintained (and working) PyTorch implementations (*sigh*).



Build .sif image

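Start an interactive shell on a compute node in the cpu-2h SLURM partition, then build the Apptainer image from its definition file (--fakeroot allows the build without root privileges):

srun --partition=cpu-2h --ntasks-per-node=2 --pty bash
apptainer build --fakeroot images/image.sif images/image.def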
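If an interactive session is inconvenient, the same build can probably be submitted as a non-interactive SLURM batch job instead. The script below is a minimal sketch that reuses the partition and task settings from the interactive command. The file name build_image.sbatch is hypothetical, and the script assumes that --fakeroot builds are permitted inside batch jobs on this cluster, just as they are in an interactive session.

```bash
#!/bin/bash
#SBATCH --partition=cpu-2h
#SBATCH --ntasks-per-node=2

# Hypothetical build_image.sbatch: non-interactive equivalent of the
# srun + apptainer build steps. Assumption: --fakeroot builds are
# allowed inside batch jobs on this cluster.
set -euo pipefail

# Build the image from its definition file (paths relative to the submit dir).
apptainer build --fakeroot images/image.sif images/image.def
```

Submit it from the repository root with `sbatch build_image.sbatch`. By default, SLURM runs the job in the submission directory and writes its output to slurm-<jobid>.out there.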