Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Reimplement ReLu6 in the backends and frontends #10587

Merged
merged 13 commits into from
Feb 25, 2023
Merged

Reimplement ReLu6 in the backends and frontends #10587

merged 13 commits into from
Feb 25, 2023

Conversation

MahmoudAshraf97
Copy link
Contributor

closes #10567 , #10584, and #10586

some code is the modified files was reformatted by mistake by black which is irrelevant to the PR

@MahmoudAshraf97 MahmoudAshraf97 added PyTorch Frontend Developing the PyTorch Frontend, checklist triggered by commenting add_frontend_checklist JAX Frontend Developing the JAX Frontend, checklist triggered by commenting add_frontend_checklist TensorFlow Frontend Developing the TensorFlow Frontend, checklist triggered by commenting add_frontend_checklist Function Reformatting Reformat all Ivy functions in accordance with the latest coding style in the contributor guide labels Feb 15, 2023
Copy link
Contributor

@CatB1t CatB1t left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hello @MahmoudAshraf97, great work! few points to address then we're good to merge!

  • We should move all the new implementations to the experimental subpackage in all backends, and also in Ivy.
  • We should also implement the test for the backend function.
  • You can skip the changes made the black or flake8 by passing --no-verify when comitting, another approach is to fix those linting issues on master branch, merging the master with your branch and you wouldn't have these extra unrelated linting commits in your PR.

Thanks!

@MahmoudAshraf97
Copy link
Contributor Author

Changes implemented, backend test implemented and passing for all backends

@ivy-leaves ivy-leaves added Array API Conform to the Array API Standard, created by The Consortium for Python Data API Standards Ivy API Experimental Run CI for testing API experimental/New feature or request labels Feb 17, 2023
Copy link
Contributor

@CatB1t CatB1t left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hello @MahmoudAshraf97 , just a few points to address, also if the tests for the function that was not using ivy.relu6 weren't passing before the refactor, please feel free to ignore those.

@@ -487,7 +487,7 @@ def Pow(*, x, y, name="Pow"):

@to_ivy_arrays_and_back
def Relu6(features, name="Relu6"):
return ivy.clip(features, 0, 6)
return ivy.relu6(features)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The frontend tests fail here, could you have a look?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed, althought the tests still fail with jax backend due to different return Dtype, this is a problem with all activation functions that use any Dtype other than float such as numerical
also sometimes fails with torch backend when float16 is used although the @with_unsupported_dtypes dtype is used

@@ -172,7 +172,7 @@ def threshold_(input, threshold, value):


def relu6(input, inplace=False):
ret = ivy.minimum(ivy.maximum(input, 0), 6)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Similarly, this one also fails for JAX backend.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same problem as other two tests

@handle_test(
fn_tree="functional.ivy.experimental.relu6",
dtype_and_x=helpers.dtype_and_values(
available_dtypes=helpers.get_dtypes("float"),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should ideally be using valid here.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I changed this to numeric since valid results in 'bool' datatype which is an invalid input, the result passes in all backends except jax where it always returns float64 instead of input datatype

@MahmoudAshraf97
Copy link
Contributor Author

So basically, jax backend returns float when int is passed, this causes all the tests to fail when int is passed

Copy link
Contributor

@CatB1t CatB1t left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In that case, we should handle that in the logic of JAX backend for relu6 we should adhere the superset behavior, all backends should return the same dtype for various inputs. If that's the case in the frontends, we would handle it in a similar way but only implement that logic in the frontend.

@MahmoudAshraf97
Copy link
Contributor Author

jax backend passes all the tests now, torch backend float16 problem still persists although this dtype is explicitly set as non-supported
ivy/ivy/functional/backends/torch/experimental/activations.py

Copy link
Contributor

@CatB1t CatB1t left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've updated the safety factor to be more sensible, and also updated the get_dtype to have the value numeric, ideally, all Ivy API tests should have get_dtypes("valid").

JAX backend starts to fail with integer inputs, looking at the minimal case: 0.5 != 0.0, I'd assume this has to do with casting, we should handle this in the backend for integer inputs. feel free to merge after getting this one working! thanks.

@MahmoudAshraf97
Copy link
Contributor Author

The error was caused by jax backend gradients at boundary conditions, it has no relation with input dtype, this was fixed through a custom gradient function and opened a PR in jax repo to fix the issue, but for now all the tests are passing with no issues and will proceed to merge

@MahmoudAshraf97 MahmoudAshraf97 merged commit e7df138 into ivy-llc:master Feb 25, 2023
@MahmoudAshraf97 MahmoudAshraf97 deleted the tensorflow.nn.relu6 branch February 25, 2023 19:25
vedpatwardhan pushed a commit to vedpatwardhan/ivy that referenced this pull request Feb 26, 2023
Co-authored-by: CatB1t <skytedits@gmail.com>
vedpatwardhan pushed a commit to vedpatwardhan/ivy that referenced this pull request Feb 26, 2023
Co-authored-by: CatB1t <skytedits@gmail.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Array API Conform to the Array API Standard, created by The Consortium for Python Data API Standards Function Reformatting Reformat all Ivy functions in accordance with the latest coding style in the contributor guide Ivy API Experimental Run CI for testing API experimental/New feature or request Ivy Functional API JAX Frontend Developing the JAX Frontend, checklist triggered by commenting add_frontend_checklist PyTorch Frontend Developing the PyTorch Frontend, checklist triggered by commenting add_frontend_checklist TensorFlow Frontend Developing the TensorFlow Frontend, checklist triggered by commenting add_frontend_checklist
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Relu6 [Bug]: torch frontend inconsistent gradient behaviour of torch.nn.relu6 tensorflow.nn.relu6
3 participants