NVIDIA/JAX-Toolbox

Use the unit testing framework instead of the MGMN testing framework for `_test_te.yaml`

yhtang opened this issue · 1 comment

yhtang commented

Currently, the TE tests implemented in `_test_te.yaml` consist of two parts:

  1. unit testing on V100 only
  2. multi-GPU testing on A100 via SLURM jobs.

This seems to make the testing setup unnecessarily complex.

We already have a V100/A100 unit testing framework, as exemplified in `_test_jax.yaml`, which lets the same unit and multi-GPU test logic run as a matrix over GPU types and scale from 1 to 8 GPUs.
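For illustration, here is a minimal sketch of how a single GitHub Actions job can be expanded into a matrix over GPU types and GPU counts. The job name `te-unit-tests`, the runner labels, and the `tests/` path are hypothetical placeholders, not the actual contents of `_test_jax.yaml` or `_test_te.yaml`.

```yaml
jobs:
  te-unit-tests:                      # hypothetical job name
    strategy:
      fail-fast: false                # let every GPU/count combination report independently
      matrix:
        gpu-arch: [V100, A100]        # GPU types to test on
        n-gpu: [1, 2, 4, 8]           # number of GPUs exposed to the tests
    # Hypothetical self-hosted runner labels, one per GPU type
    runs-on: [self-hosted, "${{ matrix.gpu-arch }}"]
    steps:
      - name: Run TE tests on ${{ matrix.n-gpu }} x ${{ matrix.gpu-arch }}
        run: |
          # Expose only the first n GPUs, e.g. n-gpu=4 gives "0,1,2,3"
          export CUDA_VISIBLE_DEVICES=$(seq -s, 0 $(( ${{ matrix.n-gpu }} - 1 )))
          pytest tests/ -v            # hypothetical test path
```

Because every matrix entry runs the same steps, adding a GPU type or count is a one-line change to the matrix. Combinations a runner cannot host (for example, a node with fewer than 8 GPUs) can be pruned with a `matrix.exclude` entry.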

@terrykong @ashors1, would you be able to refactor the TE tests to follow the JAX unit testing framework?

Completed in #636