Add `affine_transform` op to all backends #477

Conversation
Thanks for the PR. Is this really necessary? Couldn't you do the same with …?
Do you mean …?

For the 1st: if the future plan involves implementing KPLs using pure keras-core, the …

For the 2nd: …
What's the closest numpy/scipy op? Does this match https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.affine_transform.html ?
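For reference, the scipy op linked above takes a matrix-plus-offset form that maps each *output* coordinate to an *input* coordinate, rather than a flat parameter vector. A small sketch of that API (illustrative values, assuming scipy is installed):

```python
import numpy as np
from scipy import ndimage

image = np.arange(16.0).reshape(4, 4)

# scipy samples the input at `matrix @ output_coord + offset`.
matrix = np.eye(2)       # identity linear part
offset = [0.0, 1.0]      # sample one column to the right
shifted = ndimage.affine_transform(
    image, matrix, offset=offset, order=0, mode="constant", cval=0.0
)

assert np.array_equal(shifted[:, :3], image[:, 1:])  # content shifted left
assert np.array_equal(shifted[:, 3], np.zeros(4))    # out-of-bounds filled with cval
```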
The challenge is the absence of a ready-made affine op in some backends.

Among these 3 backends, JAX is the only one that can directly utilize numpy/scipy-like ops because it has `jax.scipy.ndimage`.

For TF, the closest op would be `tf.raw_ops.ImageProjectiveTransformV3`.

For Torch, the closest op is a combination of `F.affine_grid` and `F.grid_sample`.
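To make the 8-parameter transform semantics concrete, here is a minimal NumPy reference sketch (nearest-neighbor only; the function name and the projective form `k = c0*x + c1*y + 1` follow the TF-style parameterization and are not code from this PR):

```python
import numpy as np

def affine_transform_nearest(image, transform, fill_value=0):
    """Illustrative nearest-neighbor reference for the 8-parameter
    projective transform [a0, a1, a2, b0, b1, b2, c0, c1]: each output
    pixel (x, y) samples the input at
    ((a0*x + a1*y + a2) / k, (b0*x + b1*y + b2) / k),
    with k = c0*x + c1*y + 1."""
    a0, a1, a2, b0, b1, b2, c0, c1 = transform
    h, w = image.shape[:2]
    out = np.full_like(image, fill_value)
    ys, xs = np.mgrid[0:h, 0:w]          # per-pixel output coordinates
    k = c0 * xs + c1 * ys + 1.0
    src_x = np.round((a0 * xs + a1 * ys + a2) / k).astype(int)
    src_y = np.round((b0 * xs + b1 * ys + b2) / k).astype(int)
    inside = (src_x >= 0) & (src_x < w) & (src_y >= 0) & (src_y < h)
    out[inside] = image[src_y[inside], src_x[inside]]
    return out
```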
Usage:

```python
import matplotlib.cbook as cbook
import matplotlib.pyplot as plt
import numpy as np

from keras_core.backend.jax.image import affine as jax_affine
from keras_core.backend.tensorflow.image import affine as tf_affine
from keras_core.backend.torch.image import affine as torch_affine

with cbook.get_sample_data("grace_hopper.jpg") as image_file:
    image = plt.imread(image_file)

transform = np.array([1.2, 0.3, -100, 0.5, 1.2, -150, 0, 0])
jax_affined = jax_affine(image, transform)
tf_affined = tf_affine(image, transform)
torch_affined = torch_affine(image, transform)

fig, ax_dict = plt.subplot_mosaic([["A", "B", "C", "D"]], figsize=(12, 4))
ax_dict["A"].set_title("Original")
ax_dict["A"].imshow(image)
ax_dict["B"].set_title("JAX Affined")
ax_dict["B"].imshow(jax_affined)
ax_dict["C"].set_title("TF Affined")
ax_dict["C"].imshow(tf_affined)
ax_dict["D"].set_title("Torch Affined")
ax_dict["D"].imshow(torch_affined)
fig.tight_layout(h_pad=0.1, w_pad=0.1)
plt.savefig("affine.png")
```

The visualization:
@mihirparadkar is working on adding this exact op to Keras -- let's make sure he sees this.

My understanding was that torch's …

Also, if we have …
I believe …

Besides:

- As a user, I think it would be more convenient to have …
- As a developer, …
- Inversely, if we have …

For common 2D image processing, I believe that … A use case for medical image processing: …
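To illustrate why a generic affine op subsumes the rotation case, here is a hypothetical helper (not part of this PR) that builds a centered-rotation transform vector in the `[a0, a1, a2, b0, b1, b2, c0, c1]` layout used in the usage snippets above:

```python
import numpy as np

def rotation_as_affine_vector(angle_rad, height, width):
    """Centered rotation expressed in the [a0, a1, a2, b0, b1, b2, c0, c1]
    layout, mapping each output (x, y) to input
    (a0*x + a1*y + a2, b0*x + b1*y + b2)."""
    cos, sin = np.cos(angle_rad), np.sin(angle_rad)
    cx, cy = (width - 1) / 2.0, (height - 1) / 2.0
    # Offsets chosen so the rotation pivots on the image center.
    a2 = cx - cx * cos + cy * sin
    b2 = cy - cx * sin - cy * cos
    return np.array([cos, -sin, a2, sin, cos, b2, 0.0, 0.0])

# A zero angle degenerates to the identity transform.
assert np.allclose(rotation_as_affine_vector(0.0, 8, 8),
                   [1, 0, 0, 0, 1, 0, 0, 0])
```

A RandomRotation-style layer could then sample an angle and hand the resulting vector to the shared affine op, rather than carrying its own warping code.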
Understood. Let's add both …
Great work 👍
keras_core/backend/jax/image.py (Outdated)

```python
def affine(
```
Let's call it `affine_transform`.
Done
```python
method="bilinear",
fill_mode="constant",
fill_value=0,
data_format="channels_last",
```
The function signature looks good!
keras_core/backend/jax/image.py (Outdated)

```python
if method not in AFFINE_METHODS.keys():
    raise ValueError(
        "Invalid value for argument `method`. Expected of one "
        f"{AFFINE_METHODS.keys()}. Received: method={method}"
```
This will print "dict_keys()". Let's cast to `set()` first.
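A quick demonstration of the issue (plain Python; `AFFINE_METHODS` here is a stand-in dict, not the PR's actual table):

```python
AFFINE_METHODS = {"bilinear": "linear", "nearest": "nearest"}

bad = f"Expected one of {AFFINE_METHODS.keys()}."
good = f"Expected one of {set(AFFINE_METHODS.keys())}."

# The dict view leaks its repr into the error message:
assert "dict_keys(" in bad
# Casting to set() reads like a normal set of options:
assert "dict_keys(" not in good
```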
Done for all backends.
keras_core/ops/image.py (Outdated)

```python
fill_value=0,
data_format="channels_last",
):
    # TODO: add docstring
```
Please add the docstring
Added. I mainly borrowed the docstring from …
keras_core/backend/torch/image.py (Outdated)

```python
return img


def affine(
```
Are you able to add it for the NumPy backend as well? Fine to depend on scipy.
Added.
Hi @fchollet

Usage:

```python
import matplotlib.cbook as cbook
import matplotlib.pyplot as plt
import numpy as np

from keras_core.backend.jax.image import (
    affine_transform as jax_affine_transform,
)
from keras_core.backend.numpy.image import (
    affine_transform as np_affine_transform,
)
from keras_core.backend.tensorflow.image import (
    affine_transform as tf_affine_transform,
)
from keras_core.backend.torch.image import (
    affine_transform as torch_affine_transform,
)

with cbook.get_sample_data("grace_hopper.jpg") as image_file:
    image = plt.imread(image_file)

transform = np.array([1.2, 0.3, -100, 0.5, 1.2, -150, 0, 0])
jax_affined = jax_affine_transform(image, transform)
np_affined = np_affine_transform(image, transform)
tf_affined = tf_affine_transform(image, transform)
torch_affined = torch_affine_transform(image, transform)

fig, ax_dict = plt.subplot_mosaic([["A", "B", "C", "D", "E"]], figsize=(12, 4))
ax_dict["A"].set_title("Original")
ax_dict["A"].imshow(image)
ax_dict["B"].set_title("JAX Affined")
ax_dict["B"].imshow(jax_affined)
ax_dict["C"].set_title("Numpy Affined")
ax_dict["C"].imshow(np_affined)
ax_dict["D"].set_title("TF Affined")
ax_dict["D"].imshow(tf_affined)
ax_dict["E"].set_title("Torch Affined")
ax_dict["E"].imshow(torch_affined)
fig.tight_layout(h_pad=0.1, w_pad=0.1)
plt.savefig("affine.png")
```
Add `affine_transform` op to all backends
Excellent -- thanks for the polished docstring!
keras_core/ops/image_test.py (Outdated)

```python
("bilinear", "nearest", "channels_last"),
("nearest", "nearest", "channels_last"),
# ("bilinear", "wrap", "channels_last"), no wrap in torch
# ("nearest", "wrap", "channels_last"), no wrap in torch
```
You can use

```python
if backend.backend() == "torch":
    pytest.skip("wrap fill_mode not supported in torch")
```

so that at least it's tested when supported.
@fchollet

When using `wrap`, the output in TF shows inconsistency compared to JAX and NumPy. What makes this situation awkward is that the test sometimes passes but at other times it fails. As a solution, I have opted to skip the test to prevent CI failures.
The mismatched elements were about 0.6%:

```
Mismatched elements: 45 / 7500 (0.6%)   # one run for JAX
Mismatched elements: 48 / 7500 (0.64%)  # one run for NumPy
```
keras_core/ops/image.py (Outdated)

```python
    according to the given mode. Available methods are `"constant"`,
    `"nearest"` and `"reflect"`. Defaults to `"constant"`.
fill_value: Value used for points outside the boundaries of the input if
    fill_mode=`"constant"`. Defaults to `0`.
```
Move backticks: `fill_mode="constant"`
Done
keras_core/ops/image.py (Outdated)

```python
gradients are not backpropagated into transformation parameters.
Note that `c0` and `c1` are only effective when using TensorFlow
backend and will be considered as `0` when using other backends.
method: Interpolation method. Available methods are `"nearest"`,
```
I just noticed that we have a slight API inconsistency: this function as well as `resize` both use `method`, but our preprocessing layers use `interpolation`. Perhaps we should just switch everything (including `resize`) to `interpolation`, since the preprocessing layers are much older and more widely used.
I have replaced `method` with `interpolation` in both `affine_transform` and `resize`. Additionally, there have been some API changes in keras-core.
Awesome -- thanks for the great contribution! 👍
* Add affine op
* Sync import convention
* Use `np.random.random`
* Refactor jax implementation
* Fix
* Address fchollet's comments
* Update docstring
* Fix test
* Replace method with interpolation
* Replace method with interpolation
* Replace method with interpolation
* Update test
Related to #460

This PR introduces the `affine` operation in 3 backends. This op can be useful when implementing KPLs using pure keras-core, such as RandomRotation, RandomTranslation, etc.

It is worth noting that the implementations of `F.affine_grid` and `F.grid_sample` in the torch backend differ significantly from TF and JAX. As a result, the correctness tests are expected to fail.

Hope this helps!