Distributed training not working (batch size calculation) #1630

Open
natbprice opened this issue May 16, 2024 · 7 comments

@natbprice

Describe the bug
This is an issue I am having with keras-nlp, but I am not sure if it can be solved here or should be reported under keras or tensorflow.

Currently, the batch size is not calculated correctly when performing multi-worker distributed training with the JAX backend:

Traceback (most recent call last):
  File "mycode.py", line 293, in <module>
    history = classifier.fit(
  File "/usr/local/lib/python3.10/dist-packages/keras_nlp/src/utils/pipeline_model.py", line 194, in fit
    return super().fit(
  File "/usr/local/lib/python3.10/dist-packages/keras/src/utils/traceback_utils.py", line 122, in error_handler
    raise e.with_traceback(filtered_tb) from None
  File "/usr/local/lib/python3.10/dist-packages/keras/src/distribution/distribution_lib.py", line 467, in distribute_dataset
    raise ValueError(
ValueError: The batch size of the input dataset is unknown. Please config the batch size for the input dataset, e.g via `dataset.batch(batch_size)`

To Reproduce
Run (multi-worker?) distributed training with the JAX backend.

The issue seems to stem from the preprocessing in keras-nlp's PipelineModel.fit (keras_nlp/src/utils/pipeline_model.py in the traceback above), where mapping a preprocessor over the dataset leads to failure at https://github.com/keras-team/keras/blob/3105247028bb0a7e6d2f05f5daa44c9cfafd3e67/keras/src/distribution/distribution_lib.py#L465

Here is a minimal example where tensorflow.python.data.experimental.ops.distribute.compute_batch_size() returns -1 after mapping:

import tensorflow as tf
from tensorflow.python.data.experimental.ops import distribute as tf_data_distribute
from keras_nlp.src.utils.keras_utils import pack_x_y_sample_weight

ds = tf.data.Dataset.range(8)
ds = ds.batch(3)

# Read the true batch size off the first batch itself
# (counting elements in the batch, not counting batches).
print(f"True batch size (before): {len(next(ds.as_numpy_iterator()))}")
print(f"Calculated batch size (before): {tf_data_distribute.compute_batch_size(ds)}")

ds = ds.map(pack_x_y_sample_weight, num_parallel_calls=tf.data.AUTOTUNE).prefetch(tf.data.AUTOTUNE)

# After the map, compute_batch_size() can no longer infer the batch size
# and returns the -1 sentinel.
print(f"True batch size (after): {len(next(ds.as_numpy_iterator()))}")
print(f"Calculated batch size (after): {tf_data_distribute.compute_batch_size(ds)}")

Expected behavior
A batched tf.data.Dataset object is recognized as being batched.

@natbprice (Author)

Here is a Colab notebook with a reproducible example:

https://colab.research.google.com/drive/1aCJVUNfro68fek-o0i_7Iojl6Qtix6NK?usp=sharing

@natbprice (Author)

I can reproduce the error using just keras, so maybe I should open an issue there? Or maybe it should be fixed in tensorflow? But the documentation for tensorflow.python.data.experimental.ops.distribute.compute_batch_size() describes its limitations, so I am not sure it is technically a bug in tensorflow.

https://colab.research.google.com/drive/1IxVNDcNoIK4SiX2wuDQKfqR_6Or9P40I#scrollTo=0Hf6qJOxXsqI

@natbprice changed the title from "Distributed batch size not calculated correctly" to "Distributed training not working (batch size calculation)" on Jun 3, 2024
@natbprice (Author)

Hi @SuryanarayanaY, the related ticket in keras was closed with the recommendation that this be fixed in keras-nlp. Per @hertschuh: "One should simply apply batch_size after the map and not in _convert_inputs_to_dataset".
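
As I understand it, the recommendation looks roughly like this (a sketch only; the lambda stands in for the real preprocessor):

import tensorflow as tf

ds = tf.data.Dataset.range(8)
# Preprocess unbatched elements first...
ds = ds.map(lambda x: x * 2, num_parallel_calls=tf.data.AUTOTUNE)
# ...then batch last, so the batch op sits at the end of the pipeline
# where its size is still visible to downstream consumers.
ds = ds.batch(4)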

I can't quite figure out the best way for this to work with the keras-nlp API. In particular, there are several combinations to support: (1) distribution strategy, (2) input type (e.g., tf.data.Dataset vs. NumPy arrays), and (3) batching (e.g., a pre-batched dataset vs. an explicit batch_size).

Currently, _convert_inputs_to_dataset raises an error if you attempt to pass a tf.data.Dataset together with an explicit batch_size argument, as sketched below. It also looks like there is error handling intended to prevent passing unbatched inputs, but the string matching on the error message may be outdated and no longer functioning.
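
For concreteness, a sketch of the rejected combination (the model, preset, and toy data here are illustrative, not taken from my actual code):

import tensorflow as tf
import keras_nlp

classifier = keras_nlp.models.BertClassifier.from_preset(
    "bert_tiny_en_uncased", num_classes=2
)

# A pre-batched dataset of (text, label) pairs.
ds = tf.data.Dataset.from_tensor_slices((["foo", "bar"], [0, 1])).batch(2)

# Passing the dataset together with an explicit batch_size is rejected:
classifier.fit(ds, batch_size=2)  # ValueError from _convert_inputs_to_dataset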

@hertschuh (Contributor)

@natbprice ,

Sorry for the delay, I'm still working on this. It turned out to be more complex to fix than I expected.

@hertschuh (Contributor)

@natbprice ,

I experimented with a few things, but I could not find a fix in keras-nlp that would work in all cases.

However, I do have an easy workaround: ds.batch(8, drop_remainder=True). By doing this, the dataset knows that the first dimension, the batch size, is always 8. Then, it can infer the first dimension of the result of other operations like map.

If you don't pass drop_remainder=True, the dataset assumes the last batch may be incomplete. And while you can still retrieve the batch size right after batching, it doesn't propagate through other operations like map.

If you're concerned about not using the last few examples, you can shuffle or repeat the dataset before batching.
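
For example (a minimal sketch; the lambda stands in for whatever map you apply):

import tensorflow as tf
from tensorflow.python.data.experimental.ops import distribute as tf_data_distribute

ds = tf.data.Dataset.range(8)
# drop_remainder=True fixes the static first dimension at 8, so the batch
# size survives later transformations such as map.
ds = ds.batch(8, drop_remainder=True)
ds = ds.map(lambda x: x + 1, num_parallel_calls=tf.data.AUTOTUNE)

print(tf_data_distribute.compute_batch_size(ds))  # 8 rather than -1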

@natbprice (Author) commented Jul 2, 2024

@hertschuh do you have a working example? This doesn't seem to work for me.

https://colab.research.google.com/drive/1aCJVUNfro68fek-o0i_7Iojl6Qtix6NK?usp=sharing#scrollTo=0Hf6qJOxXsqI

Edit:

It seems like the workaround works outside of keras-nlp. Maybe there is something specific to keras-nlp that still needs to be resolved?

https://colab.research.google.com/drive/1IxVNDcNoIK4SiX2wuDQKfqR_6Or9P40I#scrollTo=0Hf6qJOxXsqI

@natbprice (Author)

@SuryanarayanaY can we reopen this issue please?
