Multi-process batch size not calculated correctly
natbprice opened this issue · 6 comments
Describe the bug
I opened a related issue in keras-nlp, but I believe the issue is likely best addressed in keras. See related issue: keras-team/keras-nlp#1630
Currently, the batch size is not calculated correctly when performing multi-process distributed training with the JAX backend if the dataset has been pre-processed with a mapping function.
ValueError                                Traceback (most recent call last)
<ipython-input-9-639e39591e79> in <cell line: 14>()
     12
     13 model.compile(loss="mse")
---> 14 model.fit(ds, epochs=3)

1 frames
/usr/local/lib/python3.10/dist-packages/keras/src/utils/traceback_utils.py in error_handler(*args, **kwargs)
    120     # To get the full stack trace, call:
    121     # `keras.config.disable_traceback_filtering()`
--> 122     raise e.with_traceback(filtered_tb) from None
    123     finally:
    124         del filtered_tb

/usr/local/lib/python3.10/dist-packages/keras/src/distribution/distribution_lib.py in distribute_dataset(self, dataset)
    465     batch_size = tf_data_distribute.compute_batch_size(dataset)
    466     if batch_size.numpy() < 0:
--> 467         raise ValueError(
    468             "The batch size of the input dataset is "
    469             "unknown. Please config the batch size for "

ValueError: The batch size of the input dataset is unknown. Please config the batch size for the input dataset, e.g via `dataset.batch(batch_size)`
To Reproduce
See https://colab.research.google.com/drive/1IxVNDcNoIK4SiX2wuDQKfqR_6Or9P40I?usp=sharing
import os
os.environ['KERAS_BACKEND'] = 'jax'
import keras
import tensorflow as tf
import numpy as np
print("keras", keras.__version__)
print("tf", tf.__version__)
data_parallel = keras.distribution.DataParallel()
# Mock multi-process environment
data_parallel._is_multi_process = True
keras.distribution.set_distribution(data_parallel)
inputs = np.random.normal(size=(128, 28, 28, 1))
labels = np.random.normal(size=(128, 10))
ds = tf.data.Dataset.from_tensor_slices((inputs, labels)).batch(16)
ds = ds.map(lambda x,y: (x,y))
inputs = keras.layers.Input(shape=(28, 28, 1))
y = keras.layers.Flatten()(inputs)
y = keras.layers.Dense(units=200, use_bias=False, activation="relu")(y)
y = keras.layers.Dropout(0.4)(y)
y = keras.layers.Dense(units=10, activation="softmax")(y)
model = keras.Model(inputs=inputs, outputs=y)
model.compile(loss="mse")
model.fit(ds, epochs=3)
Expected behavior
A batched `tf.data.Dataset` object is recognized as being batched.
Would you like to help us fix it?
I would like to try to fix this if it is not too complex. Maybe we can just replace the call to `tensorflow.python.data.experimental.ops.distribute.compute_batch_size()` with `dataset._input_dataset._batch_size`?
Hi @natbprice ,
Thanks for reporting. I have tested the code snippet and reproduced the reported behaviour. Attached gist for reference.
Thanks for the report and the investigation. After looking into it in detail, I came to the conclusion that this works as expected.
I saw your proposed fix:
if isinstance(ds, _MapDataset) or isinstance(ds, _ParallelMapDataset):
return ds._input_dataset._batch_size
But that's the batch size of the input dataset. The issue is that there is no constraint on what the function passed to `map` is allowed to do, so there is no guarantee that what comes out of `map` has the same batch size as what came in.
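To illustrate the point, here is a minimal sketch (not from the issue) of a mapped dataset whose actual batch size differs from the one set by `batch()`, which is why the pre-`map` batch size cannot be trusted:

```python
import tensorflow as tf

# Batch to 16, then map with a function that shrinks each batch to 8 rows.
ds = tf.data.Dataset.from_tensor_slices(tf.zeros((128, 4))).batch(16)
ds = ds.map(lambda x: x[:8])

# The element spec reports an unknown leading dimension after `map` ...
print(ds.element_spec.shape)  # (None, 4)

# ... and iterating shows the actual batch size is now 8, not 16, so
# reading the pre-map `_input_dataset._batch_size` (16) would be wrong.
first = next(iter(ds))
print(int(first.shape[0]))  # 8
```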
Now, why does this only happen when using multi-process distribution? That's because Keras is able to train with an unknown batch size in the normal case and only tries to determine the batch size if distribution is turned on.
What's the fix? Well, the standard pattern I've seen used is to batch last, after `map`, `shuffle`, etc.
ds = tf.data.Dataset.from_tensor_slices((inputs, labels))
ds = ds.map(lambda x,y: (x,y))
ds = ds.batch(16)
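For reference, the private TensorFlow helper that `distribute_dataset` calls shows the difference between the two orderings. This is only a sketch: the module path below is private TF internals and may change between versions.

```python
import tensorflow as tf
# Private TF helper used by Keras' distribution code (path may change
# across TF versions).
from tensorflow.python.data.experimental.ops import distribute as tf_data_distribute

inputs = tf.zeros((128, 28, 28, 1))
labels = tf.zeros((128, 10))

# batch -> map: the trailing MapDataset hides the batch size,
# so the helper reports it as unknown (-1), triggering the ValueError.
batched_then_mapped = tf.data.Dataset.from_tensor_slices((inputs, labels)).batch(16)
batched_then_mapped = batched_then_mapped.map(lambda x, y: (x, y))
print(tf_data_distribute.compute_batch_size(batched_then_mapped).numpy())  # -1

# map -> batch: batching last keeps the batch size discoverable.
mapped_then_batched = tf.data.Dataset.from_tensor_slices((inputs, labels))
mapped_then_batched = mapped_then_batched.map(lambda x, y: (x, y))
mapped_then_batched = mapped_then_batched.batch(16)
print(tf_data_distribute.compute_batch_size(mapped_then_batched).numpy())
```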
Let me know if you have further questions.
@hertschuh thanks for investigating this. Based on your conclusion, it sounds like this issue should instead be resolved in keras-team/keras-nlp#1630? In that case, a preprocessor is being mapped over the data internally so there doesn't appear to be an easy workaround.
Sorry if I created extra work. I guess I should not have opened a related issue here.
Yes, I think the fix should be in keras-nlp. One should simply apply `batch_size` after the `map` and not in `_convert_inputs_to_dataset`. Do you want me to follow up in the keras-nlp bug?
@hertschuh if you don't mind following up in keras-nlp, that would be great! I think I understand the solution you are proposing, but I can't quite figure out the best way for the keras-nlp API to function. In particular, it seems like there are several combinations of (1) distribution strategy, (2) input types (e.g., tf.data.Dataset, NumPy arrays), and (3) batching (e.g., pre-batched dataset, explicit `batch_size`).
Currently, `_convert_inputs_to_dataset` will raise an error if you attempt to pass a `tf.data.Dataset` with an explicit `batch_size` argument. It also looks like there is error handling to prevent you from passing unbatched inputs, but the string matching on the error message may be outdated and not functioning.