Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Implemented Jax front-end unwrap function #26345

Closed
wants to merge 15 commits into from

Conversation

rohitkg83
Copy link
Contributor

@rohitkg83 rohitkg83 commented Sep 30, 2023

Implemented Jax numpy unwrap function and associated tests.

#25767

Close #25767

Checklist

  • Did you add a function?
  • Did you add the tests?
  • Did you run your tests and are your tests passing?
  • Did pre-commit not fail on any check?
  • Did you follow the steps we provided?

Socials

@ivy-leaves ivy-leaves added the JAX Frontend Developing the JAX Frontend, checklist triggered by commenting add_frontend_checklist label Sep 30, 2023
)
@to_ivy_arrays_and_back
def unwrap(p, discont=None, axis=-1, period=2 * ivy.pi):
p = ivy.asarray(p)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a reason for the explicit conversion happening here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@bipinKrishnan - I implemented this as per the JAX source code here

def unwrap(p, discont=None, axis=-1, period=2 * ivy.pi):
p = ivy.asarray(p)
dtype_str = str(p.dtype)
if "int" in dtype_str:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Instead of multiple lines of code and if-elif statements, we could use a dict mapping similar to jax here.

Copy link
Contributor Author

@rohitkg83 rohitkg83 Oct 5, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@bipinKrishnan Done, replaced with dict mapping

@@ -148,6 +148,30 @@ def _get_dtype_input_and_vectors(draw):
return dtype, vec1, vec2


# unwrap
@st.composite
def _get_dtype_input_vector_axis(draw):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Helper functions should be in the respective helper modules in tests.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@bipinKrishnan - While I was adding the function to helper module, I found an existing function that I can use. So I have used that existing function in the test instead so new helper function isn't required.

period,
):
dtype, x, axis = dtype_x_axis
assume(not np.any(np.isclose(x[0], 0)))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should be handled while generating the values, so you could remove this line.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@bipinKrishnan - Done this has been removed now.

@github-actions
Copy link
Contributor

github-actions bot commented Oct 5, 2023

Thank you for this PR, here is the CI results:


This pull request does not result in any additional test failures. Congratulations!

@bipinKrishnan
Copy link
Contributor

There's a key error issue when testing locally, can you rebase on top of main so that the CI picks up the newly added tests?

@rohitkg83 rohitkg83 force-pushed the jax-numpy-unwrap branch 3 times, most recently from c5c354d to 980ceba Compare October 17, 2023 10:08
@rohitkg83
Copy link
Contributor Author

@bipinKrishnan - I did some blunders during rebasing this from main. So I have created a new branch on top of latest main branch code. The PR is now available here #27050. So I am closing this PR.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
JAX Frontend Developing the JAX Frontend, checklist triggered by commenting add_frontend_checklist
Projects
None yet
Development

Successfully merging this pull request may close these issues.

unwrap
3 participants