keras-team/keras-nlp

Distributed batch size not calculated correctly


Describe the bug
This is an issue I am having with keras-nlp, but I am not sure whether it can be solved here or should be reported under keras or tensorflow.

Currently, the batch size is not calculated correctly when performing multi-worker distributed training with the JAX backend:

Traceback (most recent call last):
  File "mycode.py", line 293, in <module>
    history = classifier.fit(
  File "/usr/local/lib/python3.10/dist-packages/keras_nlp/src/utils/pipeline_model.py", line 194, in fit
    return super().fit(
  File "/usr/local/lib/python3.10/dist-packages/keras/src/utils/traceback_utils.py", line 122, in error_handler
    raise e.with_traceback(filtered_tb) from None
  File "/usr/local/lib/python3.10/dist-packages/keras/src/distribution/distribution_lib.py", line 467, in distribute_dataset
    raise ValueError(
ValueError: The batch size of the input dataset is unknown. Please config the batch size for the input dataset, e.g via `dataset.batch(batch_size)`

To Reproduce
Run (multi-worker?) distributed training with the JAX backend; a rough sketch of my setup is included below.
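For context, my training setup looks roughly like the following sketch. The model, preset, and data are illustrative placeholders rather than the contents of mycode.py, and the jax.distributed / multi-worker bootstrapping is omitted:

import os

os.environ["KERAS_BACKEND"] = "jax"

import keras
import keras_nlp
import tensorflow as tf

# Data-parallel distribution over the available devices.
keras.distribution.set_distribution(keras.distribution.DataParallel())

# Illustrative model and data, not the actual ones from mycode.py.
classifier = keras_nlp.models.BertClassifier.from_preset(
    "bert_tiny_en_uncased", num_classes=2
)
ds = tf.data.Dataset.from_tensor_slices(
    (["an example sentence", "another example"], [0, 1])
).batch(2)

# Fails in keras.distribution's distribute_dataset() with
# "The batch size of the input dataset is unknown ...".
history = classifier.fit(ds, epochs=1)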

The issue seems to stem from the dataset preprocessing in keras_nlp/src/utils/pipeline_model.py (see the traceback above), where mapping the preprocessor over the dataset leaves the batch size undetectable and leads to the failure at https://github.com/keras-team/keras/blob/3105247028bb0a7e6d2f05f5daa44c9cfafd3e67/keras/src/distribution/distribution_lib.py#L465
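
For reference, the failing check seems to boil down to the following (my paraphrase of the linked code, using a hypothetical helper name, not the exact source):

from tensorflow.python.data.experimental.ops import distribute as tf_data_distribute

def check_batch_size(dataset):
    # Paraphrase of the check in keras/src/distribution/distribution_lib.py:
    # compute_batch_size() inspects the tf.data op graph and yields -1 once a
    # .map() hides the batch op.
    batch_size = tf_data_distribute.compute_batch_size(dataset)
    if batch_size.numpy() < 0:
        raise ValueError(
            "The batch size of the input dataset is unknown. Please config "
            "the batch size for the input dataset, e.g via "
            "`dataset.batch(batch_size)`"
        )
    return int(batch_size.numpy())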

Here is a minimal example where tensorflow.python.data.experimental.ops.distribute.compute_batch_size() returns -1 after mapping:

import tensorflow as tf
from tensorflow.python.data.experimental.ops import distribute as tf_data_distribute
from keras_nlp.src.utils.keras_utils import pack_x_y_sample_weight

# A dataset with an explicit batch size of 3.
ds = tf.data.Dataset.range(8)
ds = ds.batch(3)

# Note: len(list(...)) counts batches rather than the batch size; it is only a
# sanity check that the data is batched. compute_batch_size() correctly finds 3.
print(f"Number of batches (before): {len(list(ds.as_numpy_iterator()))}")
print(f"Calculated batch size (before): {tf_data_distribute.compute_batch_size(ds)}")

# Mapping a function over the batched dataset (as pipeline_model.py does with
# its preprocessor) hides the batch op from compute_batch_size().
ds = ds.map(pack_x_y_sample_weight, num_parallel_calls=tf.data.AUTOTUNE).prefetch(tf.data.AUTOTUNE)

# The data is still batched with batch size 3, but compute_batch_size() now
# returns -1 (unknown).
print(f"Number of batches (after): {len(list(ds.as_numpy_iterator()))}")
print(f"Calculated batch size (after): {tf_data_distribute.compute_batch_size(ds)}")

Expected behavior
A batched tf.data.Dataset is still recognized as batched after mapping a function over it.

I can reproduce the error using just keras, so maybe I should open an issue there instead? Or maybe it should be fixed in tensorflow? However, the documentation for tensorflow.python.data.experimental.ops.distribute.compute_batch_size() describes its limitations, so I am not sure it is technically a bug in tensorflow.

https://colab.research.google.com/drive/1IxVNDcNoIK4SiX2wuDQKfqR_6Or9P40I#scrollTo=0Hf6qJOxXsqI