Consistent preprocessing output on all backends #1777

Merged

Conversation

@mattdangerw (Member) commented on August 15, 2024

Old behavior:

  • On the TF backend, ragged and string outputs were returned as tf tensors.
  • On the Jax/Torch backends, ragged and string outputs were returned as lists.
  • Preprocessing functions outside of __call__, like tokenize(), detokenize(), and generate_preprocess(), always returned tf tensors on all backends.

This made it hard to write backend-agnostic code. TF showed up in unexpected places, and anyone flipping from tf to jax (or vice versa) had to switch between handling tensors and lists.

New behavior:

  • On all backends, for all preprocessing functions, ragged and string outputs are returned as Python lists.
  • Inside a tf.data pipeline or a tf compiled function, preprocessing layers still always output tf tensors (see the sketch below).
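For illustration, here is a minimal sketch of the intended behavior. The tokenizer class and preset name are just placeholders for any KerasNLP preprocessing layer, not something specific to this change:

```python
import tensorflow as tf
import keras_nlp

# Any KerasNLP tokenizer works the same way; BERT is just an example here.
tokenizer = keras_nlp.models.BertTokenizer.from_preset("bert_base_en_uncased")

# Eager call on any backend (tf, jax, or torch): ragged output comes back
# as plain Python lists of ints rather than tf tensors.
token_ids = tokenizer(["the quick brown fox", "hello"])
print(type(token_ids))     # <class 'list'>
print(type(token_ids[0]))  # <class 'list'>

# Inside a tf.data pipeline the same layer still produces tf tensors, so
# graph-mode preprocessing keeps working unchanged.
ds = tf.data.Dataset.from_tensor_slices(["the quick brown fox", "hello"])
ds = ds.batch(2).map(tokenizer)  # elements are tf.RaggedTensor in the graph
```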

This requires a little extra complexity to avoid over-converting back and forth between tf and python in nested calls, but thankfully we can hide most of that complexity in a decorator.
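The decorator itself isn't reproduced in this description. As a rough, hypothetical sketch of the idea (the names `convert_preprocessing_outputs` and `_to_python` are made up for illustration; the actual decorator in this PR is its own implementation, later renamed to `tf_preprocessing_function` per the commit log), it could look something like:

```python
import functools

import tensorflow as tf

# Module-level flag so nested decorated calls skip the tf -> python
# conversion and hand tf tensors straight through to the outer call.
_IN_PREPROCESSING = False


def convert_preprocessing_outputs(fn):
    """Run `fn` with tf ops, then convert ragged/string outputs to lists."""

    @functools.wraps(fn)
    def wrapper(self, *args, **kwargs):
        global _IN_PREPROCESSING
        if _IN_PREPROCESSING or not tf.executing_eagerly():
            # Nested call, or tracing inside tf.data / tf.function:
            # keep everything as tf tensors.
            return fn(self, *args, **kwargs)
        _IN_PREPROCESSING = True
        try:
            outputs = fn(self, *args, **kwargs)
        finally:
            _IN_PREPROCESSING = False
        return _to_python(outputs)

    return wrapper


def _to_python(x):
    """Recursively convert ragged and string tensors to Python lists."""
    if isinstance(x, dict):
        return {k: _to_python(v) for k, v in x.items()}
    if isinstance(x, (list, tuple)):
        return type(x)(_to_python(v) for v in x)
    if isinstance(x, tf.RaggedTensor):
        return x.to_list()
    if isinstance(x, tf.Tensor) and x.dtype == tf.string:
        # Only scalar and 1D string tensors are handled in this sketch.
        if x.shape.rank == 0:
            return x.numpy().decode("utf-8")
        return [s.decode("utf-8") for s in x.numpy().tolist()]
    return x
```

Dense numeric tensors are left alone here; the point of the sketch is only the guard that decides when to convert and when to pass tf tensors through untouched.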

@mattdangerw force-pushed the consistent-preprocessing-outputs branch 7 times, most recently from 2008b84 to 5de16cd on August 16, 2024 at 03:04
@mattdangerw changed the title from "[DRAFT] Consistent preprocessing output on all backends" to "Consistent preprocessing output on all backends" on August 16, 2024
@mattdangerw marked this pull request as ready for review on August 16, 2024 at 03:05
@mattdangerw force-pushed the consistent-preprocessing-outputs branch from c582212 to 331f6a1 on August 16, 2024 at 20:56
@SamanehSaadat (Member) left a comment


LGTM! Thanks, Matt! Just left a couple of nit comments!

keras_nlp/src/models/bart/bart_preprocessor.py (review comment; outdated, resolved)
keras_nlp/src/utils/tensor_utils.py (review comment; outdated, resolved)
@mattdangerw added the kokoro:force-run (Runs Tests on GPU) label on August 19, 2024
@kokoro-team removed the kokoro:force-run (Runs Tests on GPU) label on August 19, 2024
@mattdangerw merged commit 180c7ec into keras-team:master on August 19, 2024
10 checks passed
pkgoogle pushed a commit to pkgoogle/keras-hub that referenced this pull request on August 22, 2024
* Consistent preprocessing output on all backends

* Rename preprocessing_function -> tf_preprocessing_function

* address comments