Minimal code to reproduce our PyTorch/XLA error

On a multi-GPU machine:

export CUDA_VISIBLE_DEVICES=0,1
PJRT_DEVICE=CUDA GPU_NUM_DEVICES=2 python xla_entry.py

The program fails at line 42 of main: batch = next(iter(dataloader)).

If the xm.is_master_ordinal() call at line 7 of main is removed, the program runs fine.
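
Since xla_entry.py itself is not included here, the sketch below is a hypothetical reconstruction of the failing pattern, not the original script: a module-level xm.is_master_ordinal() call (the report's "line 7 of main") followed by a spawned worker that pulls one batch from a DataLoader (the report's "line 42"). The toy dataset, the worker body, and the xmp.spawn entry point are assumptions added for illustration.

import torch
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp
from torch.utils.data import DataLoader, TensorDataset

# "Line 7 of main" in the report: removing this single call makes the
# run succeed. Note that it executes in the parent process, before
# xmp.spawn() creates the per-GPU workers.
if xm.is_master_ordinal():
    print("master process starting")


def _mp_fn(index):
    device = xm.xla_device()
    # Toy stand-in for the real dataset and dataloader.
    dataset = TensorDataset(torch.randn(64, 8))
    dataloader = DataLoader(dataset, batch_size=16)
    # "Line 42 of main" in the report: the iteration that fails.
    batch = next(iter(dataloader))
    print(f"worker {index}: batch of {batch[0].size(0)} on {device}")


if __name__ == "__main__":
    xmp.spawn(_mp_fn)

One plausible reading, not confirmed above, is that touching xm before xmp.spawn() initializes the XLA runtime in the parent process and interferes with the spawned workers; the sketch only mirrors the structure the report describes.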