Setup
VM Setup:
We ran all the code on a GCP Ubuntu 20.04 VM with 8 V100 GPUs, CUDA 11.7, and cuDNN 8.5. The setup script is available here: https://gist.githubusercontent.com/sourabh2k15/9dbadd0f5ca35568ca210ee4cb3b19c1/raw/be988210438f91eda85fc608a9c71744a4c8af89/vm_setup.sh
Download the script above with wget and run it to install the required dependencies.
Code Setup
git clone https://github.com/sourabh2k15/conformer-diffs3.git
cd conformer-diffs
pip3 install -e .
cd conformer_diffs
Data Setup
- Download a real batch using:
wget https://transfer.sh/oanN8q/sharded_padded_batch.npz
This gives you one sharded batch of real LibriSpeech data with a total batch size of 256 (8 x 32 in the dimensions below). The batch includes input_paddings and target_paddings.
dimensions:
inputs - (8, 32, 320000)
input_paddings - (8, 32, 320000)
targets - (8, 32, 256)
target_paddings - (8, 32, 256)
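To confirm that the download matches the dimensions listed above, a quick check along these lines may help. This is a minimal sketch, not part of the repository: it assumes the array names stored in the .npz file are exactly the four labels in the listing, and the helper names load_shapes and shape_mismatches are hypothetical.

```python
import numpy as np

# Shapes copied from the listing above. The key names are an assumption:
# the archive may store the arrays under different names.
EXPECTED_SHAPES = {
    "inputs": (8, 32, 320000),
    "input_paddings": (8, 32, 320000),
    "targets": (8, 32, 256),
    "target_paddings": (8, 32, 256),
}


def load_shapes(path):
    """Read an .npz archive and return {array_name: shape}."""
    with np.load(path) as archive:
        return {name: tuple(archive[name].shape) for name in archive.files}


def shape_mismatches(shapes, expected=EXPECTED_SHAPES):
    """Return {name: (found_shape_or_None, expected_shape)} for each difference."""
    mismatches = {}
    for name, want in expected.items():
        got = shapes.get(name)  # None when the array is missing
        if got != want:
            mismatches[name] = (got, want)
    return mismatches
```

Calling shape_mismatches(load_shapes("sharded_padded_batch.npz")) should return an empty dict if the file matches the listing. Note that load_shapes reads each array fully into memory to get its shape.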
- Create JAX and PyTorch checkpoints by running the following script:
cd pytorch
python3 dump_model_weights.py
This should dump the same model weights for JAX and PyTorch into the ckpts directory inside pytorch/.
Reproduction
JAX
Command to run the JAX example:
python3 jax/jax_e2e.py
PyTorch
Command to run the PyTorch example:
torchrun --standalone --nnodes 1 --nproc_per_node 8 pytorch/torch_e2e.py
Results:
JAX and PyTorch loss logs can be found here:
https://gist.github.com/sourabh2k15/7f363358801ec1e8ac770c61c7de5203