Skip to content

Commit

Permalink
Merge pull request #6068 from jakevdp:fix-result-type
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 363282198
  • Loading branch information
jax authors committed Mar 16, 2021
2 parents 6e1cd39 + b0c5fba commit 0a84db5
Show file tree
Hide file tree
Showing 3 changed files with 65 additions and 27 deletions.
12 changes: 5 additions & 7 deletions jax/_src/numpy/lax_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,20 +272,18 @@ def _promote_dtypes(*args):
if len(args) < 2:
return args
else:
to_dtype_raw = dtypes._result_type_raw(*args)
weak_type = to_dtype_raw in set(dtypes._weak_types)
to_dtype = dtypes.canonicalize_dtype(to_dtype_raw)
to_dtype, weak_type = dtypes._lattice_result_type(*args)
to_dtype = dtypes.canonicalize_dtype(to_dtype)
return [lax.convert_element_type(x, to_dtype, weak_type) for x in args]

def _promote_dtypes_inexact(*args):
"""Convenience function to apply Numpy argument dtype promotion.
Promotes arguments to an inexact type."""
to_dtype_raw = dtypes._result_type_raw(*args)
to_dtype = dtypes.canonicalize_dtype(to_dtype_raw)
to_dtype, weak_type = dtypes._lattice_result_type(*args)
to_dtype = dtypes.canonicalize_dtype(to_dtype)
to_dtype_inexact = _to_inexact_dtype(to_dtype)
weak_type = (to_dtype == to_dtype_inexact
and to_dtype_raw in set(dtypes._weak_types))
weak_type = (weak_type and to_dtype == to_dtype_inexact)
return [lax.convert_element_type(x, to_dtype_inexact, weak_type) for x in args]

def _to_inexact_dtype(dtype):
Expand Down
48 changes: 31 additions & 17 deletions jax/dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,17 +222,13 @@ def dtype_real(typ):
np.dtype('complex128'),
] + _weak_types # type: ignore[operator]

def _jax_type(value):
"""Return the jax type for a value or type."""
# Note: `x in _weak_types` can return false positives due to dtype comparator overloading.
if any(value is typ for typ in _weak_types):
return value
dtype_ = dtype(value)
if is_weakly_typed(value):
pytype = type(dtype_.type(0).item())
if pytype in _weak_types:
return pytype
return dtype_
def _jax_type(dtype, weak_type):
"""Return the jax type for a dtype and weak type."""
return type(dtype.type(0).item()) if (weak_type and dtype != bool) else dtype

def _dtype_and_weaktype(value):
"""Return a (dtype, weak_type) tuple for the given input."""
return dtype(value), any(value is typ for typ in _weak_types) or is_weakly_typed(value)

def _type_promotion_lattice():
"""
Expand Down Expand Up @@ -264,6 +260,14 @@ def _make_lattice_upper_bounds():

@functools.lru_cache(512) # don't use util.memoize because there is no X64 dependence.
def _least_upper_bound(*nodes):
"""Compute the least upper bound of a set of nodes.
Args:
nodes: sequence of entries from _jax_types
Returns:
the _jax_type representing the least upper bound of the input nodes
on the promotion lattice.
"""
# This function computes the least upper bound of a set of nodes N within a partially
# ordered set defined by the lattice generated above.
# Given a partially ordered set S, let the set of upper bounds of n ∈ S be
Expand Down Expand Up @@ -323,13 +327,23 @@ def dtype(x):
return python_scalar_dtypes[type(x)]
return np.result_type(x)

def _result_type_raw(*args):
if len(args) == 1:
return _jax_type(args[0])
return _least_upper_bound(*{_jax_type(arg) for arg in args})
def _lattice_result_type(*args):
dtypes, weak_types = zip(*(_dtype_and_weaktype(arg) for arg in args))
if len(dtypes) == 1:
return dtypes[0], weak_types[0]

# If all inputs are weakly typed, we compute the bound of the strongly-typed
# counterparts and apply the weak type at the end. This avoids returning the
# incorrect result with non-canonical weak types (e.g. weak int16).
if all(weak_types):
result_type = _least_upper_bound(*{_jax_type(dtype, False) for dtype in dtypes})
return dtype(result_type), True
else:
result_type = _least_upper_bound(*{_jax_type(d, w) for d, w in zip(dtypes, weak_types)})
return dtype(result_type), any(result_type is t for t in _weak_types)

def result_type(*args):
"""Convenience function to apply Numpy argument dtype promotion."""
"""Convenience function to apply JAX argument dtype promotion."""
if len(args) == 0:
raise ValueError("at least one array or dtype is required")
return canonicalize_dtype(_result_type_raw(*args))
return canonicalize_dtype(_lattice_result_type(*args)[0])
32 changes: 29 additions & 3 deletions tests/dtypes_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@

import jax
from jax import dtypes
from jax import lax
from jax import numpy as jnp
from jax import test_util as jtu
from jax.interpreters import xla
Expand All @@ -34,7 +35,7 @@
bool_dtypes = [np.dtype('bool')]

signed_dtypes = [np.dtype('int8'), np.dtype('int16'), np.dtype('int32'),
np.dtype('int64'), np.dtype('longlong'), np.dtype('intc')]
np.dtype('int64')]

unsigned_dtypes = [np.dtype('uint8'), np.dtype('uint16'), np.dtype('uint32'),
np.dtype('uint64')]
Expand Down Expand Up @@ -210,7 +211,7 @@ class TestPromotionTables(jtu.JaxTestCase):
"jaxtype": jaxtype}
for jaxtype in dtypes._jax_types)
def testJaxTypeFromType(self, jaxtype):
self.assertIs(dtypes._jax_type(jaxtype), jaxtype)
self.assertIs(dtypes._jax_type(*dtypes._dtype_and_weaktype(jaxtype)), jaxtype)

@parameterized.named_parameters(
{"testcase_name": "_jaxtype={}".format(jaxtype),
Expand All @@ -221,7 +222,7 @@ def testJaxTypeFromVal(self, jaxtype):
val = jaxtype(0)
except TypeError:
val = jaxtype.type(0)
self.assertIs(dtypes._jax_type(val), jaxtype)
self.assertIs(dtypes._jax_type(*dtypes._dtype_and_weaktype(val)), jaxtype)

@jtu.ignore_warning(category=UserWarning,
message="Explicitly requested dtype.*")
Expand Down Expand Up @@ -327,5 +328,30 @@ def testBinaryPromotionJitInvariance(self, xtype, ytype, xfun, yfun):
args_maker = lambda: [xtype(1), ytype(1)]
self._CompileAndCheck(f, args_maker, check_dtypes=True)

@parameterized.named_parameters(
{"testcase_name": "_dtype={}_weak_type={}".format(dtype, weak_type),
"dtype": dtype, "weak_type": weak_type}
for dtype in all_dtypes
for weak_type in [True, False]
)
def testUnaryPromotion(self, dtype, weak_type):
# Regression test for https://github.com/google/jax/issues/6051
x = lax.convert_element_type(0, dtype, weak_type=weak_type)
y = jnp.array(0, dtype=dtypes.result_type(x))
assert x.dtype == y.dtype

@parameterized.named_parameters(
{"testcase_name": "_dtype={}_weak_type={}".format(dtype, weak_type),
"dtype": dtype, "weak_type": weak_type}
for dtype in all_dtypes
for weak_type in [True, False]
)
def testBinaryNonPromotion(self, dtype, weak_type):
# Regression test for https://github.com/google/jax/issues/6051
x = lax.convert_element_type(0, dtype, weak_type=weak_type)
y = (x + x)
assert x.dtype == y.dtype
assert dtypes.is_weakly_typed(y) == dtypes.is_weakly_typed(x)

if __name__ == "__main__":
absltest.main(testLoader=jtu.JaxTestLoader())

0 comments on commit 0a84db5

Please sign in to comment.