-
Notifications
You must be signed in to change notification settings - Fork 2.8k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
dtypes.result_type: ignore weak_type for single argument #6000
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Something weird about this special-casing is that these print different things:
from jax import config, dtypes, lax
config.update('jax_enable_x64', True)
x = lax.convert_element_type(2, 'int32', weak_type=True)
print(dtypes.result_type(x)) # int32
print(dtypes.result_type(x, x)) # int64
An alternative would be to require dtypes.result_type
always take at least two arguments, and have a separate function for getting the dtype of a single argument. That way dtypes.result_type
could be used for promotion and would factor in weak type, while the dtype-getting function would ignore weak type (since the caller is asking for just a dtype). That seems like a good solution for internal use cases for callers of dtypes.result_type
, and we could separately consider how we want jnp.result_type
to behave. WDYT?
I guess overall it seems like we could be more explicit with this weak-type-related logic.
""" | ||
Returns either | ||
- a numpy dtype, indicating weak_type=False | ||
- a Python type (int, float, complex), indicating weak_type=True |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
For both this function and the one above, maybe it'd be better to encode this directly, by making this function return a (dtype, bool)
pair? That seems like it would reduce risk of surprising the caller, as the caller would have to explicitly handle the weak type tag. WDYT?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That's probably a good idea, given how this surprised us!
That said, these return values are indices to the type promotion lattice, so the weak versions of various ints will all have to be normalized to a single identifier somewhere before they're used in the lattice code, and it's not clear where that would be done if not here.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
How hard would it be to make the lattice operate directly on (dtype, bool)
pairs?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We couldn't use multiple aliases of each weak type in the lattice, because they are unorderable.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We could normalize all (np.[u]int*, True)
entries to a single weak int indicator within the lattice, but that's essentially what this function does currently.
#6068 has the better fix. |
Within the promotion table, all weak dtypes are canonicalized to
jnp.dtype(int)
,jnp.dtype(float)
, orjnp.dtype(complex)
. This means, e.g., that if you have a weakly-typed int32 in X64 mode, you get some strange results:This is particularly surprising becuase
dtypes.result_type
is used throughout the codebase to find the dtype of single arguments.This change special-cases the single-argument version of
dtypes.result_type
to return the dtype of the value.