Optimize dask array equality checks. #3453

Merged
11 commits merged on Nov 5, 2019
3 changes: 3 additions & 0 deletions doc/whats-new.rst
@@ -70,6 +70,9 @@ Bug fixes
but cloudpickle isn't (:issue:`3401`) by `Rhys Doyle <https://github.com/rdoyle45>`_
- Fix grouping over variables with NaNs. (:issue:`2383`, :pull:`3406`).
By `Deepak Cherian <https://github.com/dcherian>`_.
- Use dask names to compare dask objects before falling back to comparing computed values.
(:issue:`3068`, :issue:`3311`, :issue:`3454`, :pull:`3453`).
By `Deepak Cherian <https://github.com/dcherian>`_.
- Sync with cftime by removing `dayofwk=-1` for cftime>=1.0.4.
By `Anderson Banihirwe <https://github.com/andersy005>`_.
- Fix :py:meth:`xarray.core.groupby.DataArrayGroupBy.reduce` and
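
The changelog entry above relies on dask's graph names: arrays built from the same graph share a name, so equal names imply equal values without triggering a compute. A minimal sketch, not part of the diff, assuming dask is installed:

```python
# Illustration only (not part of this PR): dask arrays carry a graph name;
# two handles to the same graph share it.
import dask.array as da

x = da.zeros((10, 10), chunks=2)
y = x          # same graph, same name
z = x + 1      # new graph layer, new name

assert x.name == y.name   # equal names: values are equal without computing
assert x.name != z.name   # different names: equality is unknown until computed
```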
56 changes: 37 additions & 19 deletions xarray/core/concat.py
@@ -2,6 +2,7 @@

from . import dtypes, utils
from .alignment import align
from .duck_array_ops import lazy_array_equiv
from .merge import _VALID_COMPAT, unique_variable
from .variable import IndexVariable, Variable, as_variable
from .variable import concat as concat_vars
@@ -189,26 +190,43 @@ def process_subset_opt(opt, subset):
# all nonindexes that are not the same in each dataset
for k in getattr(datasets[0], subset):
if k not in concat_over:
# Compare the variable of all datasets vs. the one
# of the first dataset. Perform the minimum amount of
# loads in order to avoid multiple loads from disk
# while keeping the RAM footprint low.
v_lhs = datasets[0].variables[k].load()
# We'll need to know later on if variables are equal.
computed = []
for ds_rhs in datasets[1:]:
v_rhs = ds_rhs.variables[k].compute()
computed.append(v_rhs)
if not getattr(v_lhs, compat)(v_rhs):
concat_over.add(k)
equals[k] = False
# computed variables are not to be re-computed
# again in the future
for ds, v in zip(datasets[1:], computed):
ds.variables[k].data = v.data
equals[k] = None
variables = [ds.variables[k] for ds in datasets]
# first check without comparing values i.e. no computes
for var in variables[1:]:
equals[k] = getattr(variables[0], compat)(
var, equiv=lazy_array_equiv
)
if equals[k] is not True:
# exit early if we know these are not equal or that
# equality cannot be determined i.e. one or all of
# the variables wraps a numpy array
break
else:
equals[k] = True

if equals[k] is False:
concat_over.add(k)

elif equals[k] is None:
# Compare the variable of all datasets vs. the one
# of the first dataset. Perform the minimum amount of
# loads in order to avoid multiple loads from disk
# while keeping the RAM footprint low.
v_lhs = datasets[0].variables[k].load()
# We'll need to know later on if variables are equal.
computed = []
for ds_rhs in datasets[1:]:
v_rhs = ds_rhs.variables[k].compute()
computed.append(v_rhs)
if not getattr(v_lhs, compat)(v_rhs):
concat_over.add(k)
equals[k] = False
# computed variables are not to be re-computed
# again in the future
for ds, v in zip(datasets[1:], computed):
ds.variables[k].data = v.data
break
else:
equals[k] = True

elif opt == "all":
concat_over.update(
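
The block above introduces a three-valued `equals[k]` (True, False, or None). A hedged sketch of the same two-pass strategy in isolation, using a hypothetical helper name and assuming the inputs are xarray `Variable` objects:

```python
from xarray.core.duck_array_ops import lazy_array_equiv

def variables_equal(variables, compat="equals"):
    # Hypothetical helper (not in the PR) illustrating the pattern used in concat():
    # a cheap lazy pass first, a compute only when that pass is inconclusive.
    first, rest = variables[0], variables[1:]

    result = True
    for var in rest:
        # lazy pass: True/False from shapes or dask names, None when undecided
        result = getattr(first, compat)(var, equiv=lazy_array_equiv)
        if result is not True:
            break

    if result is None:
        # undecided lazily: load the reference variable once and compare values
        first = first.compute()
        result = all(getattr(first, compat)(var) for var in rest)

    return bool(result)
```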
62 changes: 47 additions & 15 deletions xarray/core/duck_array_ops.py
@@ -174,27 +174,57 @@ def as_shared_dtype(scalars_or_arrays):
return [x.astype(out_type, copy=False) for x in arrays]


def allclose_or_equiv(arr1, arr2, rtol=1e-5, atol=1e-8):
"""Like np.allclose, but also allows values to be NaN in both arrays
def lazy_array_equiv(arr1, arr2):
"""Like array_equal, but doesn't actually compare values.
Returns True when arr1 and arr2 are identical or their dask names are equal.
Returns False when shapes are not equal.
Returns None when equality cannot be determined: one or both of arr1, arr2 are numpy arrays,
or their dask names are not equal.
"""
if arr1 is arr2:
return True
arr1 = asarray(arr1)
arr2 = asarray(arr2)
if arr1.shape != arr2.shape:
return False
return bool(isclose(arr1, arr2, rtol=rtol, atol=atol, equal_nan=True).all())
if (
dask_array
and isinstance(arr1, dask_array.Array)
and isinstance(arr2, dask_array.Array)
):
# GH3068
if arr1.name == arr2.name:
return True
else:
return None
return None


def allclose_or_equiv(arr1, arr2, rtol=1e-5, atol=1e-8):
"""Like np.allclose, but also allows values to be NaN in both arrays
"""
arr1 = asarray(arr1)
arr2 = asarray(arr2)
lazy_equiv = lazy_array_equiv(arr1, arr2)
if lazy_equiv is None:
return bool(isclose(arr1, arr2, rtol=rtol, atol=atol, equal_nan=True).all())
else:
return lazy_equiv


def array_equiv(arr1, arr2):
"""Like np.array_equal, but also allows values to be NaN in both arrays
"""
arr1 = asarray(arr1)
arr2 = asarray(arr2)
if arr1.shape != arr2.shape:
return False
with warnings.catch_warnings():
warnings.filterwarnings("ignore", "In the future, 'NAT == x'")
flag_array = (arr1 == arr2) | (isnull(arr1) & isnull(arr2))
return bool(flag_array.all())
lazy_equiv = lazy_array_equiv(arr1, arr2)
if lazy_equiv is None:
with warnings.catch_warnings():
warnings.filterwarnings("ignore", "In the future, 'NAT == x'")
flag_array = (arr1 == arr2) | (isnull(arr1) & isnull(arr2))
return bool(flag_array.all())
else:
return lazy_equiv


def array_notnull_equiv(arr1, arr2):
Expand All @@ -203,12 +233,14 @@ def array_notnull_equiv(arr1, arr2):
"""
arr1 = asarray(arr1)
arr2 = asarray(arr2)
if arr1.shape != arr2.shape:
return False
with warnings.catch_warnings():
warnings.filterwarnings("ignore", "In the future, 'NAT == x'")
flag_array = (arr1 == arr2) | isnull(arr1) | isnull(arr2)
return bool(flag_array.all())
lazy_equiv = lazy_array_equiv(arr1, arr2)
if lazy_equiv is None:
with warnings.catch_warnings():
warnings.filterwarnings("ignore", "In the future, 'NAT == x'")
flag_array = (arr1 == arr2) | isnull(arr1) | isnull(arr2)
return bool(flag_array.all())
else:
return lazy_equiv


def count(data, axis=None):
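
A short usage sketch of the new helper, not part of the diff and assuming dask and numpy are installed; `lazy_array_equiv` is an internal helper in `xarray.core.duck_array_ops`, imported here the same way the tests below import it:

```python
import numpy as np
import dask.array as da
from xarray.core.duck_array_ops import lazy_array_equiv

x = da.zeros((10, 10), chunks=2)

assert lazy_array_equiv(x, x) is True                            # same object / same name
assert lazy_array_equiv(x, da.zeros((5, 5), chunks=2)) is False  # shapes differ
assert lazy_array_equiv(x, x + 0) is None                        # names differ, values unknown
assert lazy_array_equiv(x, np.zeros((10, 10))) is None           # numpy involved, must compute
```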
19 changes: 14 additions & 5 deletions xarray/core/merge.py
@@ -19,6 +19,7 @@

from . import dtypes, pdcompat
from .alignment import deep_align
from .duck_array_ops import lazy_array_equiv
from .utils import Frozen, dict_equiv
from .variable import Variable, as_variable, assert_unique_multiindex_level_names

@@ -123,16 +124,24 @@ def unique_variable(
combine_method = "fillna"

if equals is None:
out = out.compute()
# first check without comparing values i.e. no computes
for var in variables[1:]:
equals = getattr(out, compat)(var)
if not equals:
equals = getattr(out, compat)(var, equiv=lazy_array_equiv)
if equals is not True:
break

if equals is None:
# now compare values with minimum number of computes
out = out.compute()
for var in variables[1:]:
equals = getattr(out, compat)(var)
if not equals:
break

if not equals:
raise MergeError(
"conflicting values for variable {!r} on objects to be combined. "
"You can skip this check by specifying compat='override'.".format(name)
f"conflicting values for variable {name!r} on objects to be combined. "
"You can skip this check by specifying compat='override'."
)

if combine_method:
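
At the user level this means merge operations on variables that wrap the same dask graph no longer trigger a compute, which the tests added at the end of this diff exercise. A small hedged example, not part of the diff, assuming dask is installed:

```python
import dask.array as da
import xarray as xr

data = da.zeros((10, 10), chunks=2)
a = xr.DataArray(data, dims=("y", "x"), name="a")
b = xr.DataArray(data, dims=("y", "x"), name="a")

# both objects wrap the same dask graph, so the compatibility check can pass
# on dask names alone and the merged result stays lazy
merged = xr.merge([a, b], compat="no_conflicts")
assert isinstance(merged["a"].data, da.Array)
```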
14 changes: 9 additions & 5 deletions xarray/core/variable.py
@@ -1236,7 +1236,9 @@ def transpose(self, *dims) -> "Variable":
dims = self.dims[::-1]
dims = tuple(infix_dims(dims, self.dims))
axes = self.get_axis_num(dims)
if len(dims) < 2: # no need to transpose if only one dimension
if len(dims) < 2 or dims == self.dims:
# no need to transpose if only one dimension
# or dims are in same order
return self.copy(deep=False)

data = as_indexable(self._data).transpose(axes)
@@ -1595,22 +1597,24 @@ def broadcast_equals(self, other, equiv=duck_array_ops.array_equiv):
return False
return self.equals(other, equiv=equiv)

def identical(self, other):
def identical(self, other, equiv=duck_array_ops.array_equiv):
"""Like equals, but also checks attributes.
"""
try:
return utils.dict_equiv(self.attrs, other.attrs) and self.equals(other)
return utils.dict_equiv(self.attrs, other.attrs) and self.equals(
other, equiv=equiv
)
except (TypeError, AttributeError):
return False

def no_conflicts(self, other):
def no_conflicts(self, other, equiv=duck_array_ops.array_notnull_equiv):
"""True if the intersection of two Variable's non-null data is
equal; otherwise false.

Variables can thus still be equal if there are locations where either,
or both, contain NaN values.
"""
return self.broadcast_equals(other, equiv=duck_array_ops.array_notnull_equiv)
return self.broadcast_equals(other, equiv=equiv)

def quantile(self, q, dim=None, interpolation="linear", keep_attrs=None):
"""Compute the qth quantile of the data along the specified dimension.
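
The `transpose` shortcut matters for the lazy comparison: transposing to the order the dims already have now returns a shallow copy, so the wrapped dask array and its name are preserved. A sketch, not part of the diff, assuming dask is installed:

```python
import dask.array as da
import xarray as xr

var = xr.Variable(("y", "x"), da.zeros((10, 10), chunks=2))

# a no-op transpose keeps the same dask array, so name-based equality still works
same = var.transpose("y", "x")
assert same.data.name == var.data.name
assert var.equals(same)   # no compute is needed to answer this
```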
108 changes: 107 additions & 1 deletion xarray/tests/test_dask.py
@@ -24,6 +24,7 @@
raises_regex,
requires_scipy_or_netCDF4,
)
from ..core.duck_array_ops import lazy_array_equiv
from .test_backends import create_tmp_file

dask = pytest.importorskip("dask")
@@ -428,7 +429,53 @@ def test_concat_loads_variables(self):
out.compute()
assert kernel_call_count == 24

# Finally, test that riginals are unaltered
# Finally, test that originals are unaltered
assert ds1["d"].data is d1
assert ds1["c"].data is c1
assert ds2["d"].data is d2
assert ds2["c"].data is c2
assert ds3["d"].data is d3
assert ds3["c"].data is c3

# now check that concat() is correctly using dask name equality to skip loads
out = xr.concat(
[ds1, ds1, ds1], dim="n", data_vars="different", coords="different"
)
assert kernel_call_count == 24
# variables are not loaded in the output
assert isinstance(out["d"].data, dask.array.Array)
assert isinstance(out["c"].data, dask.array.Array)

out = xr.concat(
[ds1, ds1, ds1], dim="n", data_vars=[], coords=[], compat="identical"
)
assert kernel_call_count == 24
# variables are not loaded in the output
assert isinstance(out["d"].data, dask.array.Array)
assert isinstance(out["c"].data, dask.array.Array)

out = xr.concat(
[ds1, ds2.compute(), ds3],
dim="n",
data_vars="all",
coords="different",
compat="identical",
)
# c1,c3 must be computed for comparison since c2 is numpy;
# d2 is computed too
assert kernel_call_count == 28

out = xr.concat(
[ds1, ds2.compute(), ds3],
dim="n",
data_vars="all",
coords="all",
compat="identical",
)
# no extra computes
assert kernel_call_count == 30

# Finally, test that originals are unaltered
assert ds1["d"].data is d1
assert ds1["c"].data is c1
assert ds2["d"].data is d2
@@ -1142,6 +1189,19 @@ def test_make_meta(map_ds):
assert meta.data_vars[variable].shape == (0,) * meta.data_vars[variable].ndim


def test_identical_coords_no_computes():
lons2 = xr.DataArray(da.zeros((10, 10), chunks=2), dims=("y", "x"))
a = xr.DataArray(
da.zeros((10, 10), chunks=2), dims=("y", "x"), coords={"lons": lons2}
)
b = xr.DataArray(
da.zeros((10, 10), chunks=2), dims=("y", "x"), coords={"lons": lons2}
)
with raise_if_dask_computes():
c = a + b
assert_identical(c, a)


@pytest.mark.parametrize(
"obj", [make_da(), make_da().compute(), make_ds(), make_ds().compute()]
)
@@ -1229,3 +1289,49 @@ def test_normalize_token_with_backend(map_ds):
map_ds.to_netcdf(tmp_file)
read = xr.open_dataset(tmp_file)
assert not dask.base.tokenize(map_ds) == dask.base.tokenize(read)


@pytest.mark.parametrize(
"compat", ["broadcast_equals", "equals", "identical", "no_conflicts"]
)
def test_lazy_array_equiv_variables(compat):
var1 = xr.Variable(("y", "x"), da.zeros((10, 10), chunks=2))
var2 = xr.Variable(("y", "x"), da.zeros((10, 10), chunks=2))
var3 = xr.Variable(("y", "x"), da.zeros((20, 10), chunks=2))

with raise_if_dask_computes():
assert getattr(var1, compat)(var2, equiv=lazy_array_equiv)
# values are actually equal, but we don't know that until we compute, so return None
with raise_if_dask_computes():
assert getattr(var1, compat)(var2 / 2, equiv=lazy_array_equiv) is None

# shapes are not equal, return False without computes
with raise_if_dask_computes():
assert getattr(var1, compat)(var3, equiv=lazy_array_equiv) is False

# if one or both arrays are numpy, return None
assert getattr(var1, compat)(var2.compute(), equiv=lazy_array_equiv) is None
assert (
getattr(var1.compute(), compat)(var2.compute(), equiv=lazy_array_equiv) is None
)

with raise_if_dask_computes():
assert getattr(var1, compat)(var2.transpose("y", "x"))


@pytest.mark.parametrize(
"compat", ["broadcast_equals", "equals", "identical", "no_conflicts"]
)
def test_lazy_array_equiv_merge(compat):
da1 = xr.DataArray(da.zeros((10, 10), chunks=2), dims=("y", "x"))
da2 = xr.DataArray(da.zeros((10, 10), chunks=2), dims=("y", "x"))
da3 = xr.DataArray(da.ones((20, 10), chunks=2), dims=("y", "x"))

with raise_if_dask_computes():
xr.merge([da1, da2], compat=compat)
# shapes are not equal; no computes necessary
with raise_if_dask_computes(max_computes=0):
with pytest.raises(ValueError):
xr.merge([da1, da3], compat=compat)
with raise_if_dask_computes(max_computes=2):
xr.merge([da1, da2 / 2], compat=compat)