StyleGAN 2 doesn't work with Colab TPU despite successful TPU connection/init?
josephrocca opened this issue · 9 comments
When trying to run the StyleGAN 2 training code on Google Colab, I'm getting:
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
But that's after confirming that the TPU is set up correctly.
Here's a minimal example: https://colab.research.google.com/gist/josephrocca/5e64c9906db96f27b583f0a577ef9b4a/debugging-matthias-wright-s-stylegan2-jax-tpu-not-detected.ipynb
If I set TF_CPP_MIN_LOG_LEVEL=0, I get:
```
2021-10-08 16:05:10.421297: I tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcudart.so.11.0
2021-10-08 16:05:12.352286: I external/org_tensorflow/tensorflow/compiler/xla/service/service.cc:171] XLA service 0x55a8d14dddc0 initialized for platform Interpreter (this does not guarantee that XLA will be used). Devices:
2021-10-08 16:05:12.352348: I external/org_tensorflow/tensorflow/compiler/xla/service/service.cc:179] StreamExecutor device (0): Interpreter, <undefined>
2021-10-08 16:05:12.358082: I external/org_tensorflow/tensorflow/compiler/xla/pjrt/tfrt_cpu_pjrt_client.cc:163] TfrtCpuClient created.
2021-10-08 16:05:12.371498: E external/org_tensorflow/tensorflow/stream_executor/cuda/cuda_driver.cc:271] failed call to cuInit: CUDA_ERROR_NO_DEVICE: no CUDA-capable device is detected
2021-10-08 16:05:12.371542: I external/org_tensorflow/tensorflow/stream_executor/cuda/cuda_diagnostics.cc:156] kernel driver does not appear to be running on this host (4b26ff987b6b): /proc/driver/nvidia/version does not exist
2021-10-08 16:05:12.371984: I external/org_tensorflow/tensorflow/stream_executor/tpu/tpu_platform_interface.cc:74] No TPU platform found.
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
2021-10-08 16:05:12.508418: I tensorflow/core/platform/cloud/google_auth_provider.cc:180] Attempting an empty bearer token since no token was retrieved from files, and GCE metadata check was skipped.
2021-10-08 16:05:12.545218: I tensorflow/core/platform/cloud/google_auth_provider.cc:180] Attempting an empty bearer token since no token was retrieved from files, and GCE metadata check was skipped.
2021-10-08 16:05:12.586517: I tensorflow/compiler/jit/xla_cpu_device.cc:41] Not creating XLA devices, tf_xla_enable_xla_devices not set
2021-10-08 16:05:12.586652: I tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcuda.so.1
2021-10-08 16:05:12.595212: E tensorflow/stream_executor/cuda/cuda_driver.cc:328] failed call to cuInit: CUDA_ERROR_NO_DEVICE: no CUDA-capable device is detected
2021-10-08 16:05:12.595243: I tensorflow/stream_executor/cuda/cuda_diagnostics.cc:156] kernel driver does not appear to be running on this host (4b26ff987b6b): /proc/driver/nvidia/version does not exist
2021-10-08 16:05:12.595544: I tensorflow/core/platform/cpu_feature_guard.cc:142] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations: AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2021-10-08 16:05:12.597344: I tensorflow/compiler/jit/xla_gpu_device.cc:99] Not creating XLA devices, tf_xla_enable_xla_devices not set
```
I'm not sure whether this problem is specific to the StyleGAN 2 training code, since I haven't tried any of the other models. I'm going to continue debugging this tomorrow and will update this post if I find out what's going on.
Hi @josephrocca, thanks for pointing this out! Have you tried running a small JAX example (without flaxmodels) on Colab with TPUs?
Hey @matthias-wright, yep, in general JAX works fine on Colab. For example, in the notebook I linked, if you replace the final `!python3 main.py ...` notebook cell with:
```python
import jax

def f(x):
    return x ** 2.0

df = jax.jit(jax.grad(f))
df(5.0)
```
then it has no problem detecting the TPU. The fact that `jax.devices()` correctly returns the TPU cores makes me think that the StyleGAN 2 code is somehow using a different "instance" of `jax` that hasn't been connected to the TPU devices via `jax.tools.colab_tpu.setup_tpu()`. Being new to Python, though, this is a wild guess: I'm not sure how Python's imports work, e.g. whether it's even possible to have two different instances of the `jax` package.
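The "different instance" hunch can be illustrated without JAX. The following is an assumption, not something confirmed in this thread: `!python3 main.py` starts a new Python process, and each process imports its own copy of every module, so in-process state set in the notebook kernel (such as the JAX config that `jax.tools.colab_tpu.setup_tpu()` is assumed to modify) would not be visible to `main.py`, whereas environment variables like `COLAB_TPU_ADDR` are inherited by child processes. This stdlib-only sketch shows that difference; `demo_backend` and `DEMO_TPU_ADDR` are hypothetical names used purely for illustration.

```python
import os
import subprocess
import sys


def run_in_child(code: str) -> str:
    """Run `code` in a fresh Python interpreter (like `!python3 main.py`) and return its stdout."""
    result = subprocess.run(
        [sys.executable, "-c", code],
        capture_output=True,
        text=True,
        check=True,
    )
    return result.stdout.strip()


# In-process module state, standing in for config set by setup_tpu().
# `demo_backend` is a hypothetical attribute that exists only in this process.
sys.demo_backend = "tpu_driver"

# Environment variables, by contrast, are copied to child processes by default.
# DEMO_TPU_ADDR is hypothetical; its value is a placeholder, not a real address.
os.environ["DEMO_TPU_ADDR"] = "placeholder:8470"

child_flag = run_in_child("import sys; print(getattr(sys, 'demo_backend', 'unset'))")
child_addr = run_in_child("import os; print(os.environ.get('DEMO_TPU_ADDR', 'unset'))")

print("parent flag:", sys.demo_backend)  # tpu_driver
print("child flag: ", child_flag)        # unset: the child imported its own `sys`
print("child env:  ", child_addr)        # placeholder:8470: inherited from the parent
```

If this is what happens on Colab, it would also explain why `jax.devices()` looks correct in the notebook while the training script falls back to CPU.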
This might be a bit late, but here's a working notebook for Colab running StyleGAN2 on the TPU. Setting these parameters is crucial to getting JAX to connect to the TPU; otherwise the TPU will not be utilized, because JAX cannot communicate with the backend:
```python
import os
from jax.config import config
config.FLAGS.jax_xla_backend = "tpu_driver"
config.FLAGS.jax_backend_target = "grpc://" + os.environ['COLAB_TPU_ADDR']
```
Check that the TPU is being utilized by running:
```python
from jax.lib import xla_bridge
xla_bridge.get_backend().platform
```
It should output 'tpu' if everything is working correctly.
Link to my Colab notebook: https://colab.research.google.com/drive/11JDP0xbKV249UB2kvIfn6_Y9WDewJkfw?usp=sharing#scrollTo=uGp2rhZP40NE
Hope this helps.
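If the training script does run as its own process, one way to act on that (a sketch under that assumption, not a confirmed fix) is to repeat the TPU setup at the top of `main.py` itself, before any JAX computation, guarded so it only runs on a Colab TPU runtime. The helper name `setup_colab_tpu_if_present` is hypothetical; `jax.tools.colab_tpu.setup_tpu()` is the call used in this thread, and it only exists in older JAX releases from the `tpu_driver` era.

```python
import os


def setup_colab_tpu_if_present() -> bool:
    """Apply the Colab TPU setup inside *this* Python process.

    Assumption: setup done in the notebook kernel does not carry over to a
    separate `python3 main.py` process, so the script repeats it itself.
    Returns True if the setup ran, False if no Colab TPU address was found.
    """
    if "COLAB_TPU_ADDR" not in os.environ:
        # Not a Colab TPU runtime (CPU/GPU or a local machine): leave JAX alone.
        return False
    # jax.tools.colab_tpu only exists in older JAX releases (tpu_driver era).
    import jax.tools.colab_tpu

    jax.tools.colab_tpu.setup_tpu()
    return True


if __name__ == "__main__":
    # Hypothetical placement: the very top of main.py, before any JAX computation
    # initializes a backend.
    ran_setup = setup_colab_tpu_if_present()
    print("setup_tpu() called in this process:", ran_setup)
```

After this runs inside the script, the `xla_bridge.get_backend().platform` check (or `jax.default_backend()`) executed from within `main.py`, rather than from the notebook, would be the way to confirm whether the script itself sees `'tpu'`.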
@JohnnyRacer I'd assume that those options are set automatically by this code:
```python
# https://github.com/google/jax#pip-installation-google-cloud-tpu
import jax.tools.colab_tpu
jax.tools.colab_tpu.setup_tpu()
```
There's no mention in the google/jax README of having to set the flags you mentioned manually. Perhaps your approach is just another way to do what `jax.tools.colab_tpu.setup_tpu()` does?
I think the issue I describe in the original post is something different, since JAX does appear to be connected to the TPU correctly (see the screenshot in the original post). I just checked, and `xla_bridge.get_backend().platform` returns `'tpu'` as expected.
@josephrocca Yes, the commands simply set the backend for JAX manually, since I thought the problem was that `jax.tools.colab_tpu.setup_tpu()` wasn't registering the TPU properly. After looking at the notebook again, I don't think that's the issue; it's something else, like you mentioned.
@JohnnyRacer Thanks for taking the time to try to help here anyway! :)
Thanks @JohnnyRacer!
Hey @matthias-wright, this didn't actually solve the issue, to be clear (no worries if you wanted to close this anyway).
Has anyone found a solution to this yet? I ran into the same issue.