fix(jax-backend): fix failing test for sum #27087

Merged (8 commits) on Nov 3, 2023

Conversation

Aaryan562
Contributor

fix failing test for sum

@Aaryan562 Aaryan562 requested a review from AnnaTz October 21, 2023 05:01
Contributor

@github-actions github-actions bot left a comment

PR Compliance Checks

Thank you for your Pull Request! We have run several checks on this pull request in order to make sure it's suitable for merging into this project. The results are listed in the following section.

Issue Reference

In order to be considered for merging, the pull request description must refer to a specific issue number. This is described in our contributing guide and our PR template.
This check is looking for a phrase similar to: "Fixes #XYZ" or "Resolves #XYZ" where XYZ is the issue number that this PR is meant to address.

@github-actions
Contributor

Thank you for this PR, here are the CI results:


This pull request does not result in any additional test failures. Congratulations!

@AnnaTz
Contributor

AnnaTz commented Oct 23, 2023

Hi @Aaryan562, I'm not sure we should declare complex as unsupported for the jax backend in general, because the error you are trying to fix only occurs when test_gradients=True.

@vedpatwardhan, what would you suggest in this case?
For the record, I ran the test and the full error is this:

ivy.utils.exceptions.IvyBackendException: jax: execute_with_gradients: grad requires real-valued outputs (output dtype that is a sub-dtype of np.floating), but got complex128. For holomorphic differentiation, pass holomorphic=True. For differentiation of non-holomorphic functions involving complex outputs, use jax.vjp directly.
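For context, the restriction behind this error can be reproduced outside Ivy. The following is a minimal sketch, assuming a standard JAX install; `real_loss` and `complex_loss` are illustrative names and not functions from this PR.

```python
import jax
import jax.numpy as jnp


def real_loss(x):
    # Real-valued scalar output: jax.grad accepts this.
    return jnp.sum(x ** 2)


def complex_loss(z):
    # Complex-valued scalar output: jax.grad rejects this by default.
    return jnp.sum(z)


real_grad = jax.grad(real_loss)(jnp.array([1.0, 2.0, 3.0]))

complex_input = jnp.array([1.0 + 2.0j, 3.0 - 1.0j])
try:
    jax.grad(complex_loss)(complex_input)
    complex_grad_error = None
except TypeError as err:
    # The message asks for real-valued outputs, as in the Ivy traceback.
    complex_grad_error = str(err)

# jnp.sum is holomorphic, so passing holomorphic=True is valid for this function.
holo_grad = jax.grad(complex_loss, holomorphic=True)(complex_input)
```

This only illustrates why a gradient test on a complex-valued output fails; it does not show how Ivy's `execute_with_gradients` is implemented.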

@vedpatwardhan
Contributor

> @vedpatwardhan, what would you suggest in this case? For the record, I ran the test and the full error is this:
>
> ivy.utils.exceptions.IvyBackendException: jax: execute_with_gradients: grad requires real-valued outputs (output dtype that is a sub-dtype of np.floating), but got complex128. For holomorphic differentiation, pass holomorphic=True. For differentiation of non-holomorphic functions involving complex outputs, use jax.vjp directly.

Ideally we shouldn't be running gradient tests when the input is a complex dtype. We also have a check to avoid that from happening, could you please look into why the check isn't working for this test @Aaryan562? Thanks 😄

@Aaryan562
Contributor Author

> > @vedpatwardhan, what would you suggest in this case? For the record, I ran the test and the full error is this:
> >
> > ivy.utils.exceptions.IvyBackendException: jax: execute_with_gradients: grad requires real-valued outputs (output dtype that is a sub-dtype of np.floating), but got complex128. For holomorphic differentiation, pass holomorphic=True. For differentiation of non-holomorphic functions involving complex outputs, use jax.vjp directly.
>
> Ideally we shouldn't be running gradient tests when the input is a complex dtype. We also have a check to avoid that from happening, could you please look into why the check isn't working for this test @Aaryan562? Thanks 😄

@vedpatwardhan @AnnaTz The paddle version has been changed to 2.5.2, which might cause errors in the paddle backend. The jax backend issue with complex dtypes has been resolved. Kindly check!

Contributor

@vedpatwardhan vedpatwardhan left a comment

Hey @Aaryan562, just requested a final change, thanks 😄

@@ -342,6 +342,9 @@ def test_sum(
    if "torch" in backend_fw:
        assume(not test_flags.as_variable[0])
        assume(not test_flags.test_gradients)
    if "jax" in backend_fw and castable_dtype in ["complex32", "complex64"]:
Contributor

I think this should be "complex32", "complex128" instead, what do you think?

Contributor Author

@Aaryan562 Aaryan562 Oct 30, 2023

> I think this should be "complex32", "complex128" instead, what do you think?

Ivy supports complex64 and complex128, so maybe we should cater to those?
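A sketch of how the skip condition might look with the complex dtype names Ivy supports, per the comment above. `should_skip_jax_gradients` is a hypothetical helper written for illustration; the diff shows the condition written inline in `test_sum`.

```python
# Complex dtype names Ivy supports, as stated in the thread.
IVY_COMPLEX_DTYPES = ("complex64", "complex128")


def should_skip_jax_gradients(backend_fw, castable_dtype, test_gradients):
    """Hypothetical helper: True when a jax gradient test would hit a complex dtype."""
    return bool(
        test_gradients
        and "jax" in backend_fw
        and castable_dtype in IVY_COMPLEX_DTYPES
    )
```

With this list, `"complex32"` never matches, so a condition that lists it only has an effect through the other entries.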

@@ -167,7 +166,10 @@ def std(
    return _std(x, axis, correction, keepdims).cast(x.dtype)


@with_unsupported_dtypes({"2.5.2 and below": ("int8", "uint8")}, backend_version)
@with_supported_dtypes(
Contributor Author

The paddle backend also gave this error, so I changed its decorator:

ivy.utils.exceptions.IvyBackendException: paddle: execute_with_gradients: paddle: nested_map: paddle: mean: (NotFound) The kernel with key (CPU, Undefined(AnyLayout), int32) of kernel mean is not registered. Selected wrong DataType int32. Paddle support following DataTypes: float64, complex128, float32, complex64, bool

I don't know why mean is being triggered when we are dealing with the sum function.

Contributor

mean is actually applied for the gradient tests here, so both mean and sum are called when running the gradient tests for sum. We should probably be fixing mean first if there's an issue with mean. Thanks @Aaryan562 😄
Also, does paddle not natively support int dtypes for its mean and sum functions (cc @MahmoudAshraf97)?

Contributor

[int8, int16, uint8] are not supported for mean and [int8, uint8] for sum

Contributor

oh okay, ig then int32 and int64 can also be marked as supported @Aaryan562 😄
(specifying the unsupported dtypes is probably more convenient than specifying the supported dtypes for this function)

Contributor Author

@Aaryan562 Aaryan562 Nov 1, 2023

ivy.utils.exceptions.IvyBackendException: paddle: execute_with_gradients: paddle: nested_map: paddle: mean: (NotFound) The kernel with key (CPU, Undefined(AnyLayout), int32) of kernel mean is not registered. Selected wrong DataType int32. Paddle support following DataTypes: float64, complex128, float32, complex64, bool

@vedpatwardhan But when I mark int32 and int64 as supported, I get the above error.

Contributor Author

@Aaryan562 Aaryan562 Nov 1, 2023

I debugged further and found that this only happens when test_gradients=True. When I changed test_sum to use test_gradients=False, all the test cases passed successfully. What could be the reason for that?

@ivy-leaves ivy-leaves added the PaddlePaddle Backend Developing the Paddle Paddle Backend. label Oct 30, 2023

Contributor

@vedpatwardhan vedpatwardhan left a comment

lgtm! Feel free to merge, thanks @Aaryan562 😄

@Aaryan562 Aaryan562 merged commit 7af38de into ivy-llc:main Nov 3, 2023
176 of 273 checks passed
@Aaryan562 Aaryan562 deleted the sum_stat branch November 11, 2023 09:29
Labels
PaddlePaddle Backend Developing the Paddle Paddle Backend.
6 participants