-
Notifications
You must be signed in to change notification settings - Fork 118
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
add fori_loop op to all 3 backends #462
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the PR. We can add this op. Please add a corresponding Keras op in ops/core.py
and add unit tests.
@@ -163,5 +163,13 @@ def while_loop( | |||
) | |||
|
|||
|
|||
def fori_loop(lower, upper, body_fun, init_val): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
All backend functions should have the same signature across backends.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this should be fixed as of most recent commit
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the update!
|
||
|
||
@keras_core_export("keras_core.ops.fori_loop") | ||
def fori_loop(lower, upper, body_fun, init_val): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is public-facing, so please add a complete docstring.
@@ -204,6 +204,15 @@ def body(x, y): | |||
self.assertAllClose(x, np.ones((2, 3)) * 6) | |||
self.assertAllClose(y, np.ones((3, 2)) * 6) | |||
|
|||
def test_fori_loop(self): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Also test shape inference correctness.
used by pyenv to set a per-directory python version / virtualenv / etc
added shape inference test and a docstring, tests passed |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the updates! Can you also add it to the numpy backend, now that we have one?
keras_core/ops/core.py
Outdated
@keras_core_export("keras_core.ops.fori_loop") | ||
def fori_loop(lower, upper, body_fun, init_val): | ||
""" | ||
For loop implementation. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Move one-liner to the line above.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't understand...
Do you mean like this?
""" For loop implementation.
or something else?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
oh I see that's how the other docstrings are written, will do
keras_core/ops/core.py
Outdated
Returns: | ||
The final state after the loop. | ||
|
||
Examples: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Singular (there's only one example)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
will do, note that this exact same typo appears on some other ops (e.g. stop_gradient
directly above)
keras_core/ops/core.py
Outdated
>>> keras_core.ops.fori_loop(lower, upper, body_fun, init_val) | ||
45 | ||
""" | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Remove blank line
keras_core/ops/core_test.py
Outdated
result = core.fori_loop(0, 10, body_fun, initial_value) | ||
self.assertAllClose(result, 45) | ||
|
||
def test_fori_loop_shape_inference(self): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Move the shape inference to the test class for shape inference (they're all together)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, thanks!
the symbolic call for fori_loop could probably use some work, in particular right now it looks like the graph for the loop body is going to get completely ignored? I'm curious what you think the right way to handle it would be. |
fori_loop
is extremely useful for keeping compile times down injax
by defining single layers as loops with a repeated body (instead of unrolling & compiling separately for each block in a sequence).adding the primitive allows doing this in a low-level way, though eventually something more similar to flax's
scan
might be a nicer API for users to consume