keras-team/keras

Multi-process batch size not calculated correctly

natbprice opened this issue · 6 comments

Describe the bug
I opened a related issue in keras-nlp (see keras-team/keras-nlp#1630), but I believe this is best addressed in keras.

Currently, the batch size is not computed correctly during multi-process distributed training with the JAX backend if the dataset has been preprocessed with a mapping function after it was batched.

ValueError                                Traceback (most recent call last)
<ipython-input-9-639e39591e79> in <cell line: 14>()
     12 
     13 model.compile(loss="mse")
---> 14 model.fit(ds, epochs=3)

1 frames
/usr/local/lib/python3.10/dist-packages/keras/src/utils/traceback_utils.py in error_handler(*args, **kwargs)
    120             # To get the full stack trace, call:
    121             # `keras.config.disable_traceback_filtering()`
--> 122             raise e.with_traceback(filtered_tb) from None
    123         finally:
    124             del filtered_tb

/usr/local/lib/python3.10/dist-packages/keras/src/distribution/distribution_lib.py in distribute_dataset(self, dataset)
    465         batch_size = tf_data_distribute.compute_batch_size(dataset)
    466         if batch_size.numpy() < 0:
--> 467             raise ValueError(
    468                 "The batch size of the input dataset is "
    469                 "unknown. Please config the batch size for "

ValueError: The batch size of the input dataset is unknown. Please config the batch size for the input dataset, e.g via `dataset.batch(batch_size)`

To Reproduce
See https://colab.research.google.com/drive/1IxVNDcNoIK4SiX2wuDQKfqR_6Or9P40I?usp=sharing

import os

os.environ['KERAS_BACKEND'] = 'jax'

import keras
import tensorflow as tf
import numpy as np

print("keras", keras.__version__)
print("tf", tf.__version__)

data_parallel = keras.distribution.DataParallel()

# Mock multi-process environment
data_parallel._is_multi_process = True

keras.distribution.set_distribution(data_parallel)

inputs = np.random.normal(size=(128, 28, 28, 1))
labels = np.random.normal(size=(128, 10))
ds = tf.data.Dataset.from_tensor_slices((inputs, labels)).batch(16)
ds = ds.map(lambda x,y: (x,y))

inputs = keras.layers.Input(shape=(28, 28, 1))
y = keras.layers.Flatten()(inputs)
y = keras.layers.Dense(units=200, use_bias=False, activation="relu")(y)
y = keras.layers.Dropout(0.4)(y)
y = keras.layers.Dense(units=10, activation="softmax")(y)
model = keras.Model(inputs=inputs, outputs=y)

model.compile(loss="mse")
model.fit(ds, epochs=3)

Expected behavior
A batched tf.data.Dataset is still recognized as batched after a map is applied to it.

Would you like to help us fix it?
I would like to try to fix this if it is not too complex. Maybe we can just replace the call to tensorflow.python.data.experimental.ops.distribute.compute_batch_size() with dataset._input_dataset._batch_size?

Hi @natbprice ,

Thanks for reporting. I have tested the code snippet and reproduced the reported behaviour. Attached gist for reference.

@natbprice ,

Thanks for the report and the investigation. After looking into it in detail, I concluded that this works as expected.

I saw your proposed fix:

  if isinstance(ds, _MapDataset) or isinstance(ds, _ParallelMapDataset):
    return ds._input_dataset._batch_size

But that's the batch size of the input dataset. There is no constraint on what the function passed to map is allowed to do, so there is no guarantee that what comes out of map has the same batch size as what went in.
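
The point can be made concrete with a pure-Python toy model (no TensorFlow involved; batch and map_batches are hypothetical helpers, not tf.data APIs). The batch size recorded on the batch stage (16) says nothing about what a later map produces:

```python
# Toy model of a batched pipeline: each element is a list (one "batch").
# Hypothetical helpers; this is not how tf.data is implemented.

def batch(examples, batch_size):
    """Group a flat list of examples into lists of up to batch_size items."""
    return [examples[i:i + batch_size] for i in range(0, len(examples), batch_size)]


def map_batches(batches, fn):
    """Apply fn to every batch, like calling ds.map(fn) after ds.batch(...)."""
    return [fn(b) for b in batches]


examples = list(range(128))
batched = batch(examples, 16)

# Identity map: the output batch size still matches the input (16).
same = map_batches(batched, lambda b: b)

# A map that keeps only the first half of each batch: 16 in, 8 out.
halved = map_batches(batched, lambda b: b[:8])

print(len(same[0]), len(halved[0]))  # 16 8
```

In tf.data terms, ds.batch(16).map(lambda x: x[:8]) is the analogous case: reading the batch size of the map's input dataset would report 16, while each element actually holds 8 rows.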

Now, why does this only happen when using multi-process distribution? That's because Keras is able to train with an unknown batch size in the normal case and only tries to determine the batch size if distribution is turned on.
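
One plausible reason the batch size matters only under distribution: with several processes, each one needs its share of every global batch, and that share cannot be computed from an unknown size. A minimal sketch, assuming an even split with any remainder given to the first processes (per_process_batch_sizes is hypothetical, and Keras's actual splitting policy may differ); the negative-value check mirrors the batch_size.numpy() < 0 guard in the traceback above:

```python
def per_process_batch_sizes(global_batch_size, num_processes):
    """Hypothetical helper: split one global batch across processes.

    A negative value stands for "batch size could not be determined",
    in which case there is nothing meaningful to split.
    """
    if global_batch_size < 0:
        raise ValueError(
            "The batch size of the input dataset is unknown. "
            "Batch the dataset, e.g. via dataset.batch(batch_size)."
        )
    base, remainder = divmod(global_batch_size, num_processes)
    # Give the leftover examples to the first `remainder` processes so the
    # per-process sizes always add up to the global batch size.
    return [base + (1 if i < remainder else 0) for i in range(num_processes)]


print(per_process_batch_sizes(16, 4))  # [4, 4, 4, 4]
print(per_process_batch_sizes(10, 4))  # [3, 3, 2, 2]

try:
    per_process_batch_sizes(-1, 4)  # -1 plays the role of "unknown"
except ValueError as err:
    print("ValueError:", err)
```

Without distribution there is no split to compute, which is consistent with Keras accepting an unknown batch size in the single-process case.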

What's the fix? Well, the standard pattern I've seen used is to batch last, after map, shuffle, etc.

ds = tf.data.Dataset.from_tensor_slices((inputs, labels))
ds = ds.map(lambda x,y: (x,y))
ds = ds.batch(16)
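
Why batching last helps can be seen with a rough model of static batch-size inference. TensorFlow's docstring for compute_batch_size describes it as walking back from the final dataset to the first batching transformation and returning -1 when a transformation after the batch may modify its size. The toy infer_batch_size below is hypothetical, and the set of stages it treats as size-preserving is an assumption, far smaller than what TensorFlow actually recognizes:

```python
# Toy model of static batch-size inference over a linear pipeline.
# Each stage is (op_name, arg). Hypothetical; not TensorFlow's code.

# Assumption: stages treated as leaving the batch dimension untouched.
SIZE_PRESERVING = {"prefetch"}


def infer_batch_size(stages):
    """Return the size of the last batch stage, or -1 if it can't be trusted."""
    for op, arg in reversed(stages):
        if op == "batch":
            return arg
        if op not in SIZE_PRESERVING:
            # e.g. "map": the function may change the batch dimension,
            # so the upstream batch size can no longer be relied on.
            return -1
    return -1  # no batch stage at all


map_last = [("from_tensor_slices", None), ("batch", 16), ("map", None)]
batch_last = [("from_tensor_slices", None), ("map", None), ("batch", 16), ("prefetch", None)]

print(infer_batch_size(map_last))    # -1
print(infer_batch_size(batch_last))  # 16
```

With map before batch, the walk back reaches the batch stage without crossing anything that could resize it, so the size is known.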

Let me know if you have further questions.

@hertschuh thanks for investigating this. Based on your conclusion, it sounds like this issue should instead be resolved in keras-team/keras-nlp#1630? In that case, a preprocessor is being mapped over the data internally so there doesn't appear to be an easy workaround.

Sorry if I created extra work. I probably should not have opened the related issue here.

@natbprice ,

Yes, I think the fix should be in keras-nlp. One should simply apply the batching after the map rather than in _convert_inputs_to_dataset. Do you want me to follow up in the keras-nlp bug?

@hertschuh if you don't mind following up in keras-nlp, that would be great! I think I understand the solution you are proposing, but I can't quite figure out how the keras-nlp API should best work. In particular, there seem to be several combinations of (1) distribution strategy, (2) input types (e.g., tf.data.Dataset, NumPy arrays), and (3) batching (e.g., pre-batched dataset, explicit batch_size).

Currently, _convert_inputs_to_dataset raises an error if you pass a tf.data.Dataset together with an explicit batch_size argument. There also appears to be error handling to prevent you from passing unbatched inputs, but it relies on string matching against the error message, which may be outdated and no longer working.
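
The combinations above can be tabulated with a small hypothetical helper. It only encodes the ordering rule suggested in this thread (preprocessor map before batch) and the tf.data.Dataset plus batch_size error described above; plan_stages, batch_is_last, and the stage strings are illustrative names, not keras-nlp APIs. It also shows the case the rule alone does not cover: a dataset that arrives already batched.

```python
def plan_stages(is_dataset, batch_size=None):
    """Hypothetical: order of stages when a preprocessor is applied to fit() inputs."""
    if is_dataset and batch_size is not None:
        # Behavior described above: an explicit batch_size is rejected
        # for tf.data.Dataset inputs. (Message text is illustrative.)
        raise ValueError("batch_size cannot be passed with a tf.data.Dataset.")
    if is_dataset:
        # Assumption: the dataset arrives already batched, so the
        # preprocessor map necessarily runs after the user's batch stage.
        return ["<user batch>", "map(preprocessor)"]
    # In-memory inputs (batch_size assumed given): preprocess first,
    # batch last, as suggested in this thread.
    return ["from_tensor_slices", "map(preprocessor)", f"batch({batch_size})"]


def batch_is_last(stages):
    """True when the final stage is a batch, i.e. its size stays inferable."""
    return stages[-1].startswith("batch")


print(batch_is_last(plan_stages(is_dataset=False, batch_size=16)))  # True
print(batch_is_last(plan_stages(is_dataset=True)))                  # False
```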