Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Jax devicearray sum #19723

Closed
wants to merge 0 commits into from
Closed

Conversation

VictorOdede
Copy link
Contributor

Close #19418

@ivy-leaves ivy-leaves added the JAX Frontend Developing the JAX Frontend, checklist triggered by commenting add_frontend_checklist label Jul 20, 2023
Copy link
Contributor

@fnhirwa fnhirwa left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hey @VictorOdede

The changes are good, some requested changes are below

init_tree="jax.numpy.array",
method_name="sum",
dtype_and_x=helpers.dtype_values_axis(
available_dtypes=["int64"],
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The tests should also cover all supported dtypes

@VictorOdede VictorOdede requested a review from fnhirwa July 31, 2023 08:30
@VictorOdede
Copy link
Contributor Author

Hey @hirwa-nshuti , I've added tests to include all numeric dtypes

Copy link
Contributor

@fnhirwa fnhirwa left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hey

Some other changes needed😊

@@ -138,6 +138,41 @@ def sort(self, axis=-1, order=None):
order=order,
)

def sum(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The promotion here will be done with the same approach as the one mentioned on the prod PR.

@VictorOdede
Copy link
Contributor Author

Hey @hirwa-nshuti , the most recent changes to main are causing all Jax tests to fail with AttributeError: module 'jax.numpy' has no attribute 'DeviceArray'. Not sure what's causing this problem

@fnhirwa
Copy link
Contributor

fnhirwa commented Aug 2, 2023

Hey @hirwa-nshuti , the most recent changes to main are causing all Jax tests to fail with AttributeError: module 'jax.numpy' has no attribute 'DeviceArray'. Not sure what's causing this problem

Working on a fix for this issue😊
Just an update you can now run the tests successfully again

@VictorOdede
Copy link
Contributor Author

Hey @hirwa-nshuti , the most recent changes to main are causing all Jax tests to fail with AttributeError: module 'jax.numpy' has no attribute 'DeviceArray'. Not sure what's causing this problem

Working on a fix for this issue😊 Just an update you can now run the tests successfully again

Thanks for the prompt response. Will run the tests again 😊

@VictorOdede VictorOdede closed this Aug 2, 2023
@VictorOdede VictorOdede force-pushed the jax_devicearray_sum branch 2 times, most recently from 4077bc2 to b8fb6f5 Compare August 2, 2023 23:33
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
JAX Frontend Developing the JAX Frontend, checklist triggered by commenting add_frontend_checklist
Projects
None yet
Development

Successfully merging this pull request may close these issues.

sum
3 participants