dtypes.result_type: ignore weak_type for single argument #6000

Closed
jakevdp wanted to merge 1 commit from the unary-result-type branch

@jakevdp jakevdp commented Mar 10, 2021

Within the promotion table, all weak dtypes are canonicalized to jnp.dtype(int), jnp.dtype(float), or jnp.dtype(complex). This means, e.g., that if you have a weakly-typed int32 in X64 mode, you get some strange results:

from jax import config, dtypes, lax
config.update('jax_enable_x64', True)

x = lax.convert_element_type(2, 'int32', weak_type=True)
print(dtypes.result_type(x))
# int64

This is particularly surprising because dtypes.result_type is used throughout the codebase to find the dtype of single arguments.

This change special-cases the single-argument version of dtypes.result_type to return the dtype of the value.
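As a rough illustration of the behavior described above, here is a toy model in plain Python. It is a sketch, not the JAX implementation: dtypes are modeled as strings, a value is a (dtype, weak_type) pair, and the names X64_ENABLED, INT_DTYPES, default_int, and promote are hypothetical. It also assumes, as a simplification, that weak ints defer to any strong int in the same call.

```python
# Toy model only: dtypes are strings and a value is a (dtype, weak_type) pair.
# Nothing here is a JAX API; all names are hypothetical.

X64_ENABLED = True  # stands in for config.update('jax_enable_x64', True)
INT_DTYPES = ("int8", "int16", "int32", "int64")  # ordered narrowest to widest


def default_int():
    """Stand-in for jnp.dtype(int): int64 in X64 mode, int32 otherwise."""
    return "int64" if X64_ENABLED else "int32"


def promote(*values):
    """Promote through the table: weak ints are canonicalized to the default int."""
    strong = [dtype for dtype, weak in values if not weak]
    if strong:
        # Simplifying assumption: weak ints defer to the widest strong int.
        return max(strong, key=INT_DTYPES.index)
    return default_int()


def result_type(*values):
    """Proposed behavior: a single argument bypasses the promotion table."""
    if len(values) == 1:
        dtype, _weak = values[0]
        return dtype
    return promote(*values)


x = ("int32", True)  # models a weakly-typed int32
print(promote(x))         # int64: the surprising result before this change
print(result_type(x))     # int32: single-argument special case
print(result_type(x, x))  # int64: still goes through the promotion table
```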

@jakevdp jakevdp requested a review from mattjj March 10, 2021 00:33
@google-cla google-cla bot added the cla: yes label Mar 10, 2021

@mattjj mattjj left a comment

Something weird about this special-casing is that these print different things:

from jax import config, dtypes, lax
config.update('jax_enable_x64', True)

x = lax.convert_element_type(2, 'int32', weak_type=True)
print(dtypes.result_type(x))     # int32
print(dtypes.result_type(x, x))  # int64

An alternative would be to require that dtypes.result_type always take at least two arguments, and to have a separate function for getting the dtype of a single argument. That way dtypes.result_type could be used for promotion and would factor in weak type, while the dtype-getting function would ignore weak type (since the caller is asking for just a dtype). That seems like a good solution for internal callers of dtypes.result_type, and we could separately consider how we want jnp.result_type to behave. WDYT?
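One way to picture that split, as a hedged sketch rather than a proposal for the actual signatures: the Value class, dtype_of, and the string-based promotion below are all hypothetical stand-ins, and the default int is assumed to be int64 (X64 mode).

```python
from dataclasses import dataclass

# Hypothetical stand-ins; none of these names exist in JAX.
_INT_DTYPES = ("int8", "int16", "int32", "int64")  # narrowest to widest
_DEFAULT_INT = "int64"  # assumes X64 mode


@dataclass(frozen=True)
class Value:
    dtype: str
    weak_type: bool = False


def dtype_of(value):
    """Single-argument query: report the dtype and ignore weak_type."""
    return value.dtype


def result_type(first, second, *rest):
    """Promotion only: at least two arguments, and weak_type is factored in."""
    values = (first, second, *rest)
    strong = [v.dtype for v in values if not v.weak_type]
    if strong:
        # Simplifying assumption: weak ints defer to the widest strong int.
        return max(strong, key=_INT_DTYPES.index)
    return _DEFAULT_INT  # all-weak inputs canonicalize to the default int


x = Value("int32", weak_type=True)
print(dtype_of(x))        # int32
print(result_type(x, x))  # int64
```

With this shape, calling result_type(x) with one argument fails with a TypeError, so the asymmetry between the one-argument and two-argument calls cannot arise silently.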

I guess overall it seems like we could be more explicit with this weak-type-related logic.

"""
Returns either
- a numpy dtype, indicating weak_type=False
- a Python type (int, float, complex), indicating weak_type=True

For both this function and the one above, maybe it'd be better to encode this directly, by making this function return a (dtype, bool) pair? That seems like it would reduce risk of surprising the caller, as the caller would have to explicitly handle the weak type tag. WDYT?
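A minimal sketch of what returning an explicit pair could look like. DTypeInfo, FakeArray, and dtype_and_weak_type are hypothetical names, dtypes are modeled as strings, and the defaults for Python scalars assume X64 mode.

```python
from typing import Any, NamedTuple


class DTypeInfo(NamedTuple):
    dtype: str
    weak_type: bool


class FakeArray(NamedTuple):
    """Hypothetical stand-in for an array carrying a dtype and a weak_type flag."""
    dtype: str
    weak_type: bool = False


# Assumed defaults for Python scalars in X64 mode.
_SCALAR_DEFAULTS = {int: "int64", float: "float64", complex: "complex128"}


def dtype_and_weak_type(value: Any) -> DTypeInfo:
    """Return an explicit (dtype, weak_type) pair instead of overloading the return type."""
    if isinstance(value, FakeArray):
        return DTypeInfo(value.dtype, value.weak_type)
    if type(value) in _SCALAR_DEFAULTS:
        # Python scalars are treated as weakly typed in this sketch.
        return DTypeInfo(_SCALAR_DEFAULTS[type(value)], True)
    raise TypeError(f"unsupported value: {value!r}")


# The caller has to unpack the weak-type tag explicitly.
dtype, weak = dtype_and_weak_type(FakeArray("int32", weak_type=True))
print(dtype, weak)  # int32 True
```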

@jakevdp (Collaborator, Author) replied:

That's probably a good idea, given how this surprised us!

That said, these return values are indices to the type promotion lattice, so the weak versions of various ints will all have to be normalized to a single identifier somewhere before they're used in the lattice code, and it's not clear where that would be done if not here.

@mattjj (Collaborator) replied:

How hard would it be to make the lattice operate directly on (dtype, bool) pairs?


@jakevdp jakevdp Mar 10, 2021

We couldn't use multiple aliases of each weak type in the lattice, because they are unorderable.


@jakevdp jakevdp Mar 10, 2021

We could normalize all (np.[u]int*, True) entries to a single weak int indicator within the lattice, but that's essentially what this function does currently.
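The normalization discussed in this thread can be pictured roughly as below. This is a sketch under stated assumptions, not the function under review: lattice_key is a hypothetical name, dtypes are modeled as strings, and only int, float, and complex families are handled.

```python
def lattice_key(dtype_name: str, weak_type: bool):
    """Map a (dtype, weak_type) pair to a single lattice identifier.

    Strong dtypes keep their own identity; every weak int, float, or complex
    collapses to the Python type int, float, or complex respectively.
    """
    if not weak_type:
        return dtype_name
    if dtype_name.startswith(("int", "uint")):
        return int
    if dtype_name.startswith("float") or dtype_name == "bfloat16":
        return float
    if dtype_name.startswith("complex"):
        return complex
    raise ValueError(f"no weak identifier for {dtype_name!r}")


# Weak int32 and weak int64 become the same key, so the original width is lost.
print(lattice_key("int32", True) is lattice_key("int64", True))  # True
print(lattice_key("int32", False))                               # int32
```

Because every weak integer maps to the same key, the width of a weak int32 is no longer available by the time the lattice is consulted, which is consistent with the int64 result reported at the top of this PR.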

jakevdp commented Mar 15, 2021

#6068 has the better fix.

@jakevdp jakevdp closed this Mar 15, 2021
@jakevdp jakevdp deleted the unary-result-type branch March 15, 2021 22:25
2 participants