google/jax

Running on multiple A100 gets stuck

Hatmm opened this issue · 13 comments

Hatmm commented

Running the following code on one A100 GPU card works fine. However, when switching to more than one GPU, the GPUs' utilization goes to 100% but their power consumption is as if they were idling.

import os
os.environ['CUDA_VISIBLE_DEVICES'] = FLAGS.devices  # comma-separated device ids, e.g. '0,1'

from functools import partial
from jax import lax, pmap
import jax.numpy as jnp

@partial(pmap, axis_name='i')
def normalize(x):
    return x / lax.psum(x, 'i')

print(normalize(jnp.arange(2.)))

[screenshot: nvidia-smi showing full GPU utilization but near-idle power draw]

I should add that the program cannot be manually killed anymore.

zhangqiaorjc commented

TL;DR: I suspect you are not setting CUDA_VISIBLE_DEVICES correctly. See my successful attempt below.

I took a clean GCP GPU VM and installed the latest jax

pip install --upgrade pip

# Installs the wheel compatible with Cuda 11 and cudnn 8.2 or newer.
pip install jax[cuda11_cudnn82] -f https://storage.googleapis.com/jax-releases/jax_releases.html

I verified I have 8 GPUs:

>>> import jax
>>> jax.devices()
[GpuDevice(id=0, process_index=0), GpuDevice(id=1, process_index=0), GpuDevice(id=2, process_index=0), GpuDevice(id=3, process_index=0), GpuDevice(id=4, process_index=0), GpuDevice(id=5, process_index=0), GpuDevice(id=6, process_index=0), GpuDevice(id=7, process_index=0)]
zhangqiaorjc@skyewm-gpu-vm2:~$ cat issue_8475.py
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0,1,2,3'
from functools import partial
from jax import lax, pmap
import jax.numpy as jnp
@partial(pmap, axis_name='i')
def normalize(x):
  return x / lax.psum(x, 'i')
print(normalize(jnp.arange(4.)))

It worked

zhangqiaorjc@skyewm-gpu-vm2:~$ python3 issue_8475.py 
[0.         0.16666667 0.33333334 0.5       ]

Hatmm commented

Hi @zhangqiaorjc,
I have tried your code in my environment (jaxlib 0.1.70+cuda110), and I still have the same bug.
Please find attached a screenshot where I used GPUs 0, 4, 5, and 6. As you can see, the GPU utilization is maximal while the power consumption is low. The program gets stuck and cannot be killed manually.

[screenshot: nvidia-smi output for GPUs 0, 4, 5, 6]

zhangqiaorjc commented

Can you please (a) use the current jax and jaxlib (0.1.74) and (b) provide self-contained instructions to reproduce? What precisely did you run?
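
As a quick sanity check before re-running, a minimal snippet along these lines confirms which jax and jaxlib versions are actually loaded (nothing here is specific to this issue):

# Print the jax/jaxlib versions and the devices JAX sees,
# to confirm the upgrade took effect in this environment.
import jax
import jaxlib

print('jax:', jax.__version__)
print('jaxlib:', jaxlib.__version__)
print('devices:', jax.devices())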

Hatmm commented

Hi!

(a) I will try that and get back to you!
(b) This is the code I ran, by calling python3 issue_8475.py:

import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0,1,2,3'
from functools import partial
from jax import lax, pmap
import jax.numpy as jnp
@partial(pmap, axis_name='i')
def normalize(x):
  return x / lax.psum(x, 'i')
print(normalize(jnp.arange(4.)))

Thanks!

zhangqiaorjc commented

@Hatmm are you able to get (a) to work? Basically, use the latest jax and jaxlib?

Hatmm commented

@zhangqiaorjc (a) did not solve the problem. I have tried running on 2 GPUs (devices 0 and 2); as you can see, the program looks idle. I am using cuDNN 8.4.2 and CUDA 11.4 with jax 0.2.25.
[screenshots: nvidia-smi output for the run on devices 0 and 2]

Hatmm commented

After installing jaxlib 0.1.74 https://storage.googleapis.com/jax-releases/cuda11/jaxlib-0.1.74+cuda11.cudnn82-cp38-none-manylinux2010_x86_64.whl

I get an error instead of the hang:
RuntimeError: INTERNAL: CudnnLegacyConvRunner cached across multiple StreamExecutors.: while running replica 1 and partition 0 of a replicated computation (other replicas may have failed as well).
I believe this issue is related to https://github.com/google/jax/issues/8654.

The CudnnLegacyConvRunner issue was fixed in XLA, and the fix will be available in the next release (see #8654 (comment)).

tlitfin commented

Is this issue resolved? I have found the same problem using V100s with cuda=11.4, jax=0.3.1, jaxlib=0.3.0+cuda11.cudnn82, and NVIDIA driver 470.94.

I find the program hangs (with 100% utilization) when using inter-device communication operations (e.g. psum). When I avoid communication, the output is as expected, but only if a regular numpy array is used as input.

from functools import partial
from jax import lax, pmap, vmap
import jax.numpy as jnp

@partial(pmap, axis_name='i')
def normalize(x):
  return x

print(normalize(jnp.array([10.,100.])))

Using a jnp.array returned [10., 0.] rather than [10., 100.], which is what is returned by vmap or by pmap with a regular numpy array as input.

Using a regular numpy array does not solve the hanging problem.
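
For concreteness, a minimal side-by-side of the two input types (the [10., 0.] result is what I observed on the affected machine, not expected JAX behaviour):

import numpy as np
from functools import partial
from jax import pmap
import jax.numpy as jnp

@partial(pmap, axis_name='i')
def identity(x):
    return x

# A plain NumPy array comes back intact...
print(identity(np.array([10., 100.])))   # [ 10. 100.]
# ...while a jnp array came back as [10., 0.] on the affected setup.
print(identity(jnp.array([10., 100.])))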

hawkinsp commented

@tlitfin I can't reproduce your problem. For example, I just created a GCP VM with 4xV100, with CUDA 11.4, driver 470.103.01, and the same jax and jaxlib versions.

$ nvidia-smi
Mon Mar  7 13:23:10 2022
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 470.103.01   Driver Version: 470.103.01   CUDA Version: 11.4     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  Tesla V100-SXM2...  Off  | 00000000:00:04.0 Off |                    0 |
| N/A   36C    P0    32W / 300W |     80MiB / 16160MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
|   1  Tesla V100-SXM2...  Off  | 00000000:00:05.0 Off |                    0 |
| N/A   37C    P0    43W / 300W |      4MiB / 16160MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
|   2  Tesla V100-SXM2...  Off  | 00000000:00:06.0 Off |                    0 |
| N/A   37C    P0    33W / 300W |      4MiB / 16160MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
|   3  Tesla V100-SXM2...  Off  | 00000000:00:07.0 Off |                    0 |
| N/A   36C    P0    42W / 300W |      4MiB / 16160MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+

The program works as expected for me and prints:

$ python t.py
[ 10. 100.]

I'm wondering if something isn't working with device->device copies on your machine.
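
Something along these lines would exercise device-to-device copies directly, without pmap (a rough diagnostic sketch, not an official test):

# Put an array on each GPU in turn, copy it to every other GPU,
# and check that the values survive the transfer.
import numpy as np
import jax

x = np.arange(8.0)
for src in jax.devices():
    x_src = jax.device_put(x, src)
    for dst in jax.devices():
        if dst == src:
            continue
        x_dst = jax.device_put(x_src, dst)
        ok = np.array_equal(np.asarray(x_dst), x)
        print(f'{src} -> {dst}:', 'ok' if ok else 'MISMATCH')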

Are you able to reproduce this problem on a cloud VM (i.e., a setup I could replicate exactly)?

What does nvidia-smi topo -m print?

tlitfin commented

@hawkinsp I am so sorry! I started a new session this morning and found that the problem was resolved. I was already using a fresh conda environment, but I must have had an environment variable set from previous debugging that made the issue persist. I am unable to re-create the faulty environment today to isolate the cause, but I suspect it was a conflict between the cuda/jax versions in my environment and the system-wide install. I am sorry again for wasting your time.

Woohoo, sounds fixed! 🎉

tlitfin commented

As a follow-up, I found that my problem returned and was not a simple environment conflict as I had suspected. The problem seems to depend on communication between specific GPUs allocated by our scheduling software. I have attached a screenshot to illustrate.

[screenshot: pmap reproducer results across different GPU allocations]

This may be a hardware configuration problem rather than a jax issue, but I am posting the info here for completeness.
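
In case it helps with narrowing this down later, here is a rough sketch (hypothetical, not something run in this thread) that executes the psum reproducer once per pair of physical GPUs, each in its own subprocess so that CUDA_VISIBLE_DEVICES takes effect before JAX initializes, and flags the pairings that hang:

import itertools
import os
import subprocess
import sys

# The reproducer from this issue, restricted to two devices.
REPRO = '''
from functools import partial
from jax import lax, pmap
import jax.numpy as jnp

@partial(pmap, axis_name='i')
def normalize(x):
    return x / lax.psum(x, 'i')

print(normalize(jnp.arange(2.)))
'''

NUM_GPUS = 8  # adjust to the machine

for a, b in itertools.combinations(range(NUM_GPUS), 2):
    env = {**os.environ, 'CUDA_VISIBLE_DEVICES': f'{a},{b}'}
    try:
        result = subprocess.run([sys.executable, '-c', REPRO],
                                env=env, capture_output=True, timeout=120)
        status = 'ok' if result.returncode == 0 else 'error'
    except subprocess.TimeoutExpired:
        status = 'HANG (timed out)'
    print(f'GPUs {a},{b}: {status}')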