
Add affine_transform op to all backends #477

Merged
merged 18 commits into keras-team:main on Jul 20, 2023

Conversation

james77777778
Contributor

Related to #460

This PR introduces the affine operation in 3 backends. This op can be useful when implementing KPLs using pure keras-core, such as RandomRotation, RandomTranslation, etc.

It is worth noting that the implementations of F.affine_grid and F.grid_sample in the torch backend differ significantly from TF and JAX. As a result, the correctness tests are expected to fail.

Hope this helps!

@fchollet
Contributor

Thanks for the PR. Is this really necessary? Couldn't you do the same with einsum?

@james77777778
Contributor Author

Thanks for the PR. Is this really necessary? Couldn't you do the same with einsum?

Do you mean that affine is not part of the plan, or should I implement torch's affine operation using einsum?

For the 1st:
It is not obvious to me how to implement affine directly with einsum.
(Actually, I have not seen an affine implementation in any of these backends that utilizes einsum.)

If the future plan involves implementing KPLs using pure keras-core, the affine op becomes necessary for many layers.

For the 2nd:
I can try to implement torch's map_coordinates using einsum.

@fchollet
Contributor

What's the closest numpy/scipy op? Does this match https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.affine_transform.html ?

@james77777778
Contributor Author

What's the closest numpy/scipy op? Does this match https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.affine_transform.html ?

The challenge is the absence of a map_coordinates op in TF and Torch. And yes, I want to add an affine_transform op to all backends!

In these 3 backends, JAX is the only one that can directly utilize numpy/scipy-like ops because it has map_coordinates op.

For TF, the closest op would be tf.raw_ops.ImageProjectiveTransformV3. This PR mainly follows the signature of tf.raw_ops.ImageProjectiveTransformV3 and its transform definition of [a0, a1, a2, b0, b1, b2, c0, c1].

For Torch, the closest op is a combination of tnn.affine_grid and tnn.grid_sample.
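To make the transform definition concrete, here is a minimal pure-NumPy sketch of how the 8-parameter vector maps output pixels back to input pixels. The helper name, nearest-neighbor sampling, and fill handling are illustrative only, not this PR's implementation:

```python
import numpy as np

def projective_transform(image, transform, fill_value=0):
    """Illustrative sketch of the [a0, a1, a2, b0, b1, b2, c0, c1] convention
    of tf.raw_ops.ImageProjectiveTransformV3: each output pixel (x, y) samples
    the input at ((a0*x + a1*y + a2)/k, (b0*x + b1*y + b2)/k),
    where k = c0*x + c1*y + 1. Nearest-neighbor only, for clarity."""
    a0, a1, a2, b0, b1, b2, c0, c1 = transform
    h, w = image.shape[:2]
    out = np.full_like(image, fill_value)
    for y in range(h):
        for x in range(w):
            k = c0 * x + c1 * y + 1.0
            sx = int(round((a0 * x + a1 * y + a2) / k))
            sy = int(round((b0 * x + b1 * y + b2) / k))
            if 0 <= sx < w and 0 <= sy < h:
                out[y, x] = image[sy, sx]
    return out

img = np.arange(16.0).reshape(4, 4)
identity = [1, 0, 0, 0, 1, 0, 0, 0]          # leaves the image unchanged
shifted = projective_transform(img, [1, 0, 1, 0, 1, 0, 0, 0])  # samples x+1
```

Note that the vector maps output coordinates to input coordinates, so a positive a2 shifts the image content left, not right.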

@james77777778
Contributor Author

Usage:

import matplotlib.cbook as cbook
import matplotlib.pyplot as plt
import numpy as np

from keras_core.backend.jax.image import affine as jax_affine
from keras_core.backend.tensorflow.image import affine as tf_affine
from keras_core.backend.torch.image import affine as torch_affine

with cbook.get_sample_data("grace_hopper.jpg") as image_file:
    image = plt.imread(image_file)
transform = np.array([1.2, 0.3, -100, 0.5, 1.2, -150, 0, 0])

jax_affined = jax_affine(image, transform)
tf_affined = tf_affine(image, transform)
torch_affined = torch_affine(image, transform)

fig, ax_dict = plt.subplot_mosaic([["A", "B", "C", "D"]], figsize=(12, 4))
ax_dict["A"].set_title("Original")
ax_dict["A"].imshow(image)
ax_dict["B"].set_title("JAX Affined")
ax_dict["B"].imshow(jax_affined)
ax_dict["C"].set_title("TF Affined")
ax_dict["C"].imshow(tf_affined)
ax_dict["D"].set_title("Torch Affined")
ax_dict["D"].imshow(torch_affined)
fig.tight_layout(h_pad=0.1, w_pad=0.1)
plt.savefig("affine.png")

The visualization:

(image: affine.png)

@fchollet
Contributor

The challenge is that the absence of map_coordinates op in TF and Torch

@mihirparadkar is working on adding this exact op to Keras -- let's make sure he sees this.

My understanding was that torch's grid_sample was equivalent -- is it not?

Also, if we have map_coordinates, would we still need affine?

@james77777778
Contributor Author

james77777778 commented Jul 18, 2023

My understanding was that torch's grid_sample was equivalent -- is it not?

I believe grid_sample provides similar functionality, but it is not precisely identical in terms of numerical behavior.
(I don't know where the differences come from, but the correctness tests failed.)

Besides:

Also, if we have map_coordinates, would we still need affine?

As a user, I think it would be more convenient to have affine or affine_transform in keras-core. It is not straightforward to apply an affine transform using map_coordinates, and otherwise there would be no affine-like APIs for these backends.

As a developer, affine is very useful for implementing RandomRotation, RandomAffine, etc.
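For reference, the torch pairing mentioned above looks roughly like this. Note that affine_grid uses torch's normalized [-1, 1] coordinate convention rather than the pixel-space [a0..c1] convention used by TF, which is one plausible source of the numerical mismatches. A minimal sketch, not the PR's implementation:

```python
import torch
import torch.nn.functional as F

# theta is a (N, 2, 3) batch of affine matrices in normalized coordinates.
theta = torch.tensor([[[1.0, 0.0, 0.0],
                       [0.0, 1.0, 0.0]]])  # identity transform
image = torch.rand(1, 3, 8, 8)  # NCHW layout expected by grid_sample

# affine_grid builds a sampling grid; grid_sample interpolates at it.
grid = F.affine_grid(theta, list(image.shape), align_corners=False)
out = F.grid_sample(image, grid, mode="bilinear",
                    padding_mode="zeros", align_corners=False)
```

With the identity theta, the output should reproduce the input up to floating-point error, which makes it a handy sanity check when comparing backends.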

@fchollet
Contributor

As a user, I think it would be more convenient to have affine or affine_transform in keras-core

Inversely, if we have affine_transform, do we still need map_coordinates?

@james77777778
Contributor Author

Inversely, if we have affine_transform, do we still need map_coordinates?

For common 2D image processing, I believe affine_transform is sufficient.
However, when it comes to multidimensional image processing, map_coordinates might be the only option.

A use case for medical image processing:
https://bic-berkeley.github.io/psych-214-fall-2016/map_coordinates.html
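To make the distinction concrete: map_coordinates works on arrays of any rank by taking an explicit coordinate array per output element. A minimal pure-NumPy nearest-neighbor stand-in (the helper is illustrative, not scipy's implementation, which also supports spline interpolation):

```python
import numpy as np

def map_coordinates_nearest(volume, coords):
    """Nearest-neighbor sketch of scipy.ndimage.map_coordinates (order=0):
    coords has shape (ndim, ...) giving, for each output element, the
    (possibly fractional) source position in `volume`."""
    idx = tuple(
        np.clip(np.round(c).astype(int), 0, s - 1)
        for c, s in zip(coords, volume.shape)
    )
    return volume[idx]

vol = np.arange(27.0).reshape(3, 3, 3)  # a tiny 3D volume
# Sample along the main diagonal: (0,0,0), (1,1,1), (2,2,2).
coords = np.array([[0.0, 1.0, 2.0]] * 3)
diag = map_coordinates_nearest(vol, coords)
```

An affine_transform can be expressed on top of this by generating coords from a matrix, but the reverse is not true: arbitrary (e.g. nonlinear) warps need map_coordinates.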

@fchollet
Contributor

Understood. Let's add both affine_transform (same API as scipy, though we don't need to implement all interpolation modes) and map_coordinates. @mihirparadkar is working on map_coordinates.

Contributor

@fchollet fchollet left a comment


Great work 👍

)


def affine(
Contributor


Let's call it affine_transform

Contributor Author


Done

    method="bilinear",
    fill_mode="constant",
    fill_value=0,
    data_format="channels_last",
Contributor


The function signature looks good!

if method not in AFFINE_METHODS.keys():
    raise ValueError(
        "Invalid value for argument `method`. Expected one of "
        f"{AFFINE_METHODS.keys()}. Received: method={method}"
Contributor


This will print "dict_keys()". Let's cast to set() first.
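The reason for the cast, illustrated with a made-up AFFINE_METHODS dict:

```python
AFFINE_METHODS = {"bilinear": 0, "nearest": 1}  # illustrative mapping

# Interpolating dict_keys directly leaks the wrapper type into the message:
raw = f"{AFFINE_METHODS.keys()}"         # "dict_keys(['bilinear', 'nearest'])"
clean = f"{set(AFFINE_METHODS.keys())}"  # e.g. "{'bilinear', 'nearest'}"
```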

Contributor Author


Done for all backends.

    fill_value=0,
    data_format="channels_last",
):
    # TODO: add docstring
Contributor


Please add the docstring


return img


def affine(
Contributor


Are you able to add it for the NumPy backend as well? Fine to depend on scipy.

Contributor Author


Added.

@james77777778
Contributor Author

james77777778 commented Jul 20, 2023

Hi @fchollet
I have added affine_transform to the NumPy backend and it passes the tests as expected.

Usage:

import matplotlib.cbook as cbook
import matplotlib.pyplot as plt
import numpy as np

from keras_core.backend.jax.image import (
    affine_transform as jax_affine_transform,
)
from keras_core.backend.numpy.image import (
    affine_transform as np_affine_transform,
)
from keras_core.backend.tensorflow.image import (
    affine_transform as tf_affine_transform,
)
from keras_core.backend.torch.image import (
    affine_transform as torch_affine_transform,
)

with cbook.get_sample_data("grace_hopper.jpg") as image_file:
    image = plt.imread(image_file)
transform = np.array([1.2, 0.3, -100, 0.5, 1.2, -150, 0, 0])

jax_affined = jax_affine_transform(image, transform)
np_affined = np_affine_transform(image, transform)
tf_affined = tf_affine_transform(image, transform)
torch_affined = torch_affine_transform(image, transform)

fig, ax_dict = plt.subplot_mosaic([["A", "B", "C", "D", "E"]], figsize=(12, 4))
ax_dict["A"].set_title("Original")
ax_dict["A"].imshow(image)
ax_dict["B"].set_title("JAX Affined")
ax_dict["B"].imshow(jax_affined)
ax_dict["C"].set_title("Numpy Affined")
ax_dict["C"].imshow(np_affined)
ax_dict["D"].set_title("TF Affined")
ax_dict["D"].imshow(tf_affined)
ax_dict["E"].set_title("Torch Affined")
ax_dict["E"].imshow(torch_affined)
fig.tight_layout(h_pad=0.1, w_pad=0.1)
plt.savefig("affine.png")

Visualization:
(image: affine.png)

@james77777778 james77777778 requested a review from fchollet July 20, 2023 03:07
@james77777778 james77777778 changed the title Add affine op to all backends Add affine_transform op to all backends Jul 20, 2023
Contributor

@fchollet fchollet left a comment


Excellent -- thanks for the polished docstring!

("bilinear", "nearest", "channels_last"),
("nearest", "nearest", "channels_last"),
# ("bilinear", "wrap", "channels_last"), no wrap in torch
# ("nearest", "wrap", "channels_last"), no wrap in torch
Contributor


You can use if backend.backend() == "torch": pytest.skip("wrap fill_mode not supported in torch") so that at least it's tested when supported

Contributor Author


@fchollet
When using wrap, the TF output is inconsistent with JAX and NumPy.
What makes this situation awkward is that the test sometimes passes and sometimes fails.

As a solution, I have opted to skip the test to prevent CI failures.

Contributor Author

@james77777778 james77777778 Jul 20, 2023


The fraction of mismatched elements was about 0.6%:

Mismatched elements: 45 / 7500 (0.6%)  # one run for JAX
Mismatched elements: 48 / 7500 (0.64%)  # one run for NumPy

        according to the given mode. Available methods are `"constant"`,
        `"nearest"` and `"reflect"`. Defaults to `"constant"`.
    fill_value: Value used for points outside the boundaries of the input if
        fill_mode=`"constant"`. Defaults to `0`.
Contributor


Move backticks: fill_mode="constant"

Contributor Author


Done

        gradients are not backpropagated into transformation parameters.
        Note that `c0` and `c1` are only effective when using TensorFlow
        backend and will be considered as `0` when using other backends.
    method: Interpolation method. Available methods are `"nearest"`,
Contributor


I just noticed that we have a slight API inconsistency: this function, as well as resize, uses method, but our preprocessing layers use interpolation. Perhaps we should switch everything (including resize) to interpolation, since the preprocessing layers are much older and more widely used.

Contributor Author


I have replaced method with interpolation in both affine_transform and resize.
Additionally, there have been some API changes in keras-core.

@james77777778 james77777778 requested a review from fchollet July 20, 2023 07:08
Contributor

@fchollet fchollet left a comment


Awesome -- thanks for the great contribution! 👍

@fchollet fchollet merged commit 2702048 into keras-team:main Jul 20, 2023
@james77777778 james77777778 deleted the add-affine branch July 21, 2023 02:58
adi-kmt pushed a commit to adi-kmt/keras-core that referenced this pull request Jul 21, 2023
* Add affine op

* Sync import convention

* Use `np.random.random`

* Refactor jax implementation

* Fix

* Address fchollet's comments

* Update docstring

* Fix test

* Replace method with interpolation

* Replace method with interpolation

* Replace method with interpolation

* Update test