NotFoundError: Graph execution error: TPU
innat opened this issue · 8 comments
The following code fails when run on a TPU VM.
tf: 2.15
keras: 3.0.5
import tensorflow as tf
import keras

tpu = tf.distribute.cluster_resolver.TPUClusterResolver(tpu="local")
strategy = tf.distribute.TPUStrategy(tpu)

def get_compiled_model():
    # Make a simple 2-layer densely-connected neural network.
    inputs = keras.Input(shape=(784,))
    x = keras.layers.Dense(256, activation="relu")(inputs)
    x = keras.layers.Dense(256, activation="relu")(x)
    outputs = keras.layers.Dense(10)(x)
    model = keras.Model(inputs, outputs)
    model.compile(
        optimizer=keras.optimizers.Adam(),
        loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        metrics=[keras.metrics.SparseCategoricalAccuracy()],
    )
    return model

def get_dataset():
    batch_size = 32
    num_val_samples = 10000
    # Return the MNIST dataset in the form of a tf.data.Dataset.
    (x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
    # Preprocess the data (these are NumPy arrays).
    x_train = x_train.reshape(-1, 784).astype("float32") / 255
    x_test = x_test.reshape(-1, 784).astype("float32") / 255
    y_train = y_train.astype("float32")
    y_test = y_test.astype("float32")
    # Reserve num_val_samples samples for validation.
    x_val = x_train[-num_val_samples:]
    y_val = y_train[-num_val_samples:]
    x_train = x_train[:-num_val_samples]
    y_train = y_train[:-num_val_samples]
    return (
        tf.data.Dataset.from_tensor_slices((x_train, y_train)).batch(batch_size),
        tf.data.Dataset.from_tensor_slices((x_val, y_val)).batch(batch_size),
        tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(batch_size),
    )

with strategy.scope():
    model_ = get_compiled_model()

train_dataset, val_dataset, test_dataset = get_dataset()
model_.fit(train_dataset, epochs=2, validation_data=val_dataset)
---------------------------------------------------------------------------
NotFoundError Traceback (most recent call last)
Cell In[5], line 1
----> 1 model_.fit(train_dataset, epochs=2, validation_data=val_dataset)
File /usr/local/lib/python3.10/site-packages/keras/src/utils/traceback_utils.py:123, in filter_traceback.<locals>.error_handler(*args, **kwargs)
120 filtered_tb = _process_traceback_frames(e.__traceback__)
121 # To get the full stack trace, call:
122 # `keras.config.disable_traceback_filtering()`
--> 123 raise e.with_traceback(filtered_tb) from None
124 finally:
125 del filtered_tb
File /usr/local/lib/python3.10/site-packages/tensorflow/python/eager/execute.py:53, in quick_execute(op_name, num_outputs, inputs, attrs, ctx, name)
51 try:
52 ctx.ensure_initialized()
---> 53 tensors = pywrap_tfe.TFE_Py_Execute(ctx._handle, device_name, op_name,
54 inputs, attrs, num_outputs)
55 except core._NotOkStatusException as e:
56 if name is not None:
NotFoundError: Graph execution error:
Detected at node TPUReplicate/_compile/_9074053372847989778/_4 defined at (most recent call last):
<stack traces unavailable>
Hi @djherbis, could you please provide any information regarding this issue? Are there any blockers to using TPU VMs at the moment?
@innat Could you share a public notebook with the complete code? That makes it a bit easier to debug, thanks!
Hey, have you confirmed that Keras is using TensorFlow under the hood?
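A quick way to check (assuming the Keras 3 config API):

import keras
print(keras.config.backend())  # prints "tensorflow", "jax", or "torch"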
I took a quick try at this: I switched to tf-cpu, removed the TPU VM + TensorFlow-related code, and switched the Keras backend to JAX, and then I think it works?
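A rough sketch of that environment, in case it helps (assuming tensorflow-cpu replaces the TPU build of TensorFlow, so only JAX touches the TPU; exact package pins may vary):

# pip install tensorflow-cpu "jax[tpu]" keras
import os
os.environ["KERAS_BACKEND"] = "jax"  # must be set before the first keras import

import keras  # Keras now runs on the JAX backend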
I don't fully get your point. However, I was able to run Keras with every backend (tf, torch, jax) on CPU and GPU. But as shown in the above gist, on the TPU VM it didn't work.

I ran the above gist again with both the keras+tensorflow and the keras+jax setup on TPU, and both fail to run the program.
I meant that when I ran it with JAX, without TensorFlow, on a TPU VM, it worked:
https://www.kaggle.com/code/herbison/keras-jax-tpu-vm-model-build-test
It's not too uncommon for something to work on CPU/GPU but not on TPU, since the actual underlying systems are different.
If possible, using the JAX example might be a path forward.
Ah, I see.
I also tried the following without installing tf-cpu; it didn't work though.

import tensorflow as tf
tf.config.set_visible_devices([], "TPU")

import keras, jax

devices = jax.devices("tpu")
data_parallel = keras.distribution.DataParallel(devices=devices)
keras.distribution.set_distribution(data_parallel)
---------------------------------------------------------------------------
AttributeError Traceback (most recent call last)
Cell In[5], line 4
1 import keras, jax
----> 4 data_parallel = keras.distribution.DataParallel(devices=devices)
5 keras.distribution.set_distribution(data_parallel)
File /usr/local/lib/python3.10/site-packages/keras/src/distribution/distribution_lib.py:400, in DataParallel.__init__(self, device_mesh, devices)
398 self._batch_dim_name = self.device_mesh.axis_names[0]
399 # Those following attributes might get convert to public methods.
--> 400 self._num_process = distribution_lib.num_processes()
401 self._process_id = distribution_lib.process_id()
402 self._is_multi_process = self._num_process > 1
AttributeError: module 'keras.src.backend.tensorflow.distribution_lib' has no attribute 'num_processes'
Yeah, it's impossible to use a tensorflow (TPU) install alongside JAX or PyTorch, and since Keras is calling TensorFlow here, that loads the TPU twice (once for JAX, once for TensorFlow), which breaks things.
Installing tensorflow-cpu and then using JAX (TPU) works, though.
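For reference, a minimal sketch of that working combination (assumptions: tensorflow-cpu is installed so TensorFlow never initializes the TPU, and jax[tpu] provides the TPU runtime; exact pins/extras may vary):

# pip install tensorflow-cpu "jax[tpu]" keras
import os
os.environ["KERAS_BACKEND"] = "jax"  # before importing keras

import jax
import keras

# JAX is now the only framework touching the TPU.
devices = jax.devices("tpu")
data_parallel = keras.distribution.DataParallel(devices=devices)
keras.distribution.set_distribution(data_parallel)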