Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add fori_loop op to all 3 backends #462

Merged
merged 11 commits into from
Jul 21, 2023
Merged

Conversation

GallagherCommaJack
Copy link
Contributor

fori_loop is extremely useful for keeping compile times down in jax by defining single layers as loops with a repeated body (instead of unrolling & compiling separately for each block in a sequence).

adding the primitive allows doing this in a low-level way, though eventually something more similar to flax's scan might be a nicer API for users to consume

Copy link
Contributor

@fchollet fchollet left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the PR. We can add this op. Please add a corresponding Keras op in ops/core.py and add unit tests.

@@ -163,5 +163,13 @@ def while_loop(
)


def fori_loop(lower, upper, body_fun, init_val):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

All backend functions should have the same signature across backends.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this should be fixed as of most recent commit

Copy link
Contributor

@fchollet fchollet left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the update!



@keras_core_export("keras_core.ops.fori_loop")
def fori_loop(lower, upper, body_fun, init_val):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is public-facing, so please add a complete docstring.

@@ -204,6 +204,15 @@ def body(x, y):
self.assertAllClose(x, np.ones((2, 3)) * 6)
self.assertAllClose(y, np.ones((3, 2)) * 6)

def test_fori_loop(self):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also test shape inference correctness.

@GallagherCommaJack
Copy link
Contributor Author

added shape inference test and a docstring, tests passed

Copy link
Contributor

@fchollet fchollet left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the updates! Can you also add it to the numpy backend, now that we have one?

@keras_core_export("keras_core.ops.fori_loop")
def fori_loop(lower, upper, body_fun, init_val):
"""
For loop implementation.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Move one-liner to the line above.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't understand...
Do you mean like this?

""" For loop implementation.

or something else?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

oh I see that's how the other docstrings are written, will do

Returns:
The final state after the loop.

Examples:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Singular (there's only one example)

Copy link
Contributor Author

@GallagherCommaJack GallagherCommaJack Jul 21, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

will do, note that this exact same typo appears on some other ops (e.g. stop_gradient directly above)

>>> keras_core.ops.fori_loop(lower, upper, body_fun, init_val)
45
"""

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Remove blank line

result = core.fori_loop(0, 10, body_fun, initial_value)
self.assertAllClose(result, 45)

def test_fori_loop_shape_inference(self):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Move the shape inference to the test class for shape inference (they're all together)

Copy link
Contributor

@fchollet fchollet left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, thanks!

@fchollet fchollet merged commit 8e9827e into keras-team:main Jul 21, 2023
@GallagherCommaJack
Copy link
Contributor Author

the symbolic call for fori_loop could probably use some work, in particular right now it looks like the graph for the loop body is going to get completely ignored? I'm curious what you think the right way to handle it would be.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants