-
Notifications
You must be signed in to change notification settings - Fork 5.7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Reimplement ReLu6 in the backends and frontends #10587
Reimplement ReLu6 in the backends and frontends #10587
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hello @MahmoudAshraf97, great work! few points to address then we're good to merge!
- We should move all the new implementations to the
experimental
subpackage in all backends, and also in Ivy. - We should also implement the test for the backend function.
- You can skip the changes made the
black
orflake8
by passing--no-verify
when comitting, another approach is to fix those linting issues on master branch, merging the master with your branch and you wouldn't have these extra unrelated linting commits in your PR.
Thanks!
Changes implemented, backend test implemented and passing for all backends |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hello @MahmoudAshraf97 , just a few points to address, also if the tests for the function that was not using ivy.relu6
weren't passing before the refactor, please feel free to ignore those.
@@ -487,7 +487,7 @@ def Pow(*, x, y, name="Pow"): | |||
|
|||
@to_ivy_arrays_and_back | |||
def Relu6(features, name="Relu6"): | |||
return ivy.clip(features, 0, 6) | |||
return ivy.relu6(features) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The frontend tests fail here, could you have a look?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fixed, althought the tests still fail with jax backend due to different return Dtype, this is a problem with all activation functions that use any Dtype other than float such as numerical
also sometimes fails with torch backend when float16 is used although the @with_unsupported_dtypes dtype is used
@@ -172,7 +172,7 @@ def threshold_(input, threshold, value): | |||
|
|||
|
|||
def relu6(input, inplace=False): | |||
ret = ivy.minimum(ivy.maximum(input, 0), 6) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Similarly, this one also fails for JAX backend.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
same problem as other two tests
@handle_test( | ||
fn_tree="functional.ivy.experimental.relu6", | ||
dtype_and_x=helpers.dtype_and_values( | ||
available_dtypes=helpers.get_dtypes("float"), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We should ideally be using valid
here.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I changed this to numeric since valid results in 'bool' datatype which is an invalid input, the result passes in all backends except jax where it always returns float64 instead of input datatype
So basically, jax backend returns float when int is passed, this causes all the tests to fail when int is passed |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In that case, we should handle that in the logic of JAX backend for relu6
we should adhere the superset behavior, all backends should return the same dtype for various inputs. If that's the case in the frontends, we would handle it in a similar way but only implement that logic in the frontend.
jax backend passes all the tests now, torch backend float16 problem still persists although this dtype is explicitly set as non-supported |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I've updated the safety factor to be more sensible, and also updated the get_dtype
to have the value numeric
, ideally, all Ivy API tests should have get_dtypes("valid")
.
JAX backend starts to fail with integer inputs, looking at the minimal case: 0.5 != 0.0
, I'd assume this has to do with casting, we should handle this in the backend for integer inputs. feel free to merge after getting this one working! thanks.
The error was caused by jax backend gradients at boundary conditions, it has no relation with input dtype, this was fixed through a custom gradient function and opened a PR in jax repo to fix the issue, but for now all the tests are passing with no issues and will proceed to merge |
Co-authored-by: CatB1t <skytedits@gmail.com>
Co-authored-by: CatB1t <skytedits@gmail.com>
closes #10567 , #10584, and #10586
some code is the modified files was reformatted by mistake by black which is irrelevant to the PR