-
Notifications
You must be signed in to change notification settings - Fork 5.7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Jax devicearray sum #19723
Jax devicearray sum #19723
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hey @VictorOdede
The changes are good, some requested changes are below
init_tree="jax.numpy.array", | ||
method_name="sum", | ||
dtype_and_x=helpers.dtype_values_axis( | ||
available_dtypes=["int64"], |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The tests should also cover all supported dtypes
Hey @hirwa-nshuti , I've added tests to include all numeric dtypes |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hey
Some other changes needed😊
@@ -138,6 +138,41 @@ def sort(self, axis=-1, order=None): | |||
order=order, | |||
) | |||
|
|||
def sum( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The promotion here will be done with the same approach as the one mentioned on the prod
PR.
Hey @hirwa-nshuti , the most recent changes to |
Working on a fix for this issue😊 |
Thanks for the prompt response. Will run the tests again 😊 |
4077bc2
to
b8fb6f5
Compare
Close #19418