Skip to content

Commit

Permalink
Fix tests.
Browse files Browse the repository at this point in the history
  • Loading branch information
dcherian committed Aug 29, 2019
1 parent 4cb7d02 commit 9dc340e
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 12 deletions.
17 changes: 15 additions & 2 deletions xarray/core/concat.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,13 +318,17 @@ def determine_dims(datasets, result_coord_names):
both_data_and_coords = result_coord_names & noncoord_names
if both_data_and_coords:
raise ValueError(
"%r is a coordinate in some datasets and not in others."
% list(both_data_and_coords)[0] # preserve previous behaviour
"%r is a coordinate in some datasets but not others."
% list(both_data_and_coords)[0] # preserve format of error message
)
dim_names, result_coords = determine_dims(datasets, result_coord_names)
# we don't want the concat dimension in the result dataset yet
result_coords.pop(dim, None)

# case where concat dimension is a coordinate but not a dimension
if dim in result_coord_names and dim not in dim_names:
datasets = [ds.expand_dims(dim) for ds in datasets]

# determine which variables to concatentate
concat_over, equals = _calc_concat_over(
datasets, dim, dim_names, data_vars, coords, compat
Expand All @@ -335,6 +339,10 @@ def determine_dims(datasets, result_coord_names):
if variables_to_merge:
to_merge = []
for ds in datasets:
if variables_to_merge - set(ds.variables):
raise ValueError(
"Encountered unexpected variables %r" % list(variables_to_merge)[0]
)
to_merge.append(ds.reset_coords()[list(variables_to_merge)])
# TODO: Provide equals as an argument and thread that down to merge.unique_variable
result_vars = merge_variables(
Expand Down Expand Up @@ -397,6 +405,11 @@ def ensure_common_dims(vars):
result = result.set_coords(result_coord_names)
result.encoding = result_encoding

# TODO: avoid this
unlabeled_dims = dim_names - result_coord_names
result = result.drop(unlabeled_dims, errors="ignore")

# need to pass test when provided dim is a DataArray
if coord is not None:
# add concat dimension last to ensure that its in the final Dataset
result[coord.name] = coord
Expand Down
27 changes: 17 additions & 10 deletions xarray/tests/test_concat.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,7 @@
import pytest

from xarray import DataArray, Dataset, Variable, concat
from xarray.core import dtypes

from xarray.core import dtypes, merge
from . import (
InaccessibleArray,
assert_array_equal,
Expand Down Expand Up @@ -36,11 +35,15 @@ def test_concat_compat():
coords={"x": [0, 1], "y": [1], "z": [-1, -2], "q": [0]},
)

result = concat([ds1, ds2], dim="y", compat="equals")
result = concat([ds1, ds2], dim="y", data_vars="minimal", compat="broadcast_equals")
assert_equal(ds2.no_x_y, result.no_x_y.transpose())

for var in ["has_x", "no_x_y", "const1"]:
for var in ["has_x", "no_x_y"]:
assert "y" not in result[var]

with raises_regex(ValueError, "'q' is not a dimension in all datasets"):
concat([ds1, ds2], dim="q", data_vars="all", compat="broadcast_equals")


class TestConcatDataset:
@pytest.fixture
Expand Down Expand Up @@ -116,7 +119,7 @@ def test_concat_coords(self):
actual = concat(objs, dim="x", coords=coords)
assert_identical(expected, actual)
for coords in ["minimal", []]:
with raises_regex(ValueError, "not equal across"):
with raises_regex(merge.MergeError, "conflicting values"):
concat(objs, dim="x", coords=coords)

def test_concat_constant_index(self):
Expand All @@ -127,8 +130,10 @@ def test_concat_constant_index(self):
for mode in ["different", "all", ["foo"]]:
actual = concat([ds1, ds2], "y", data_vars=mode)
assert_identical(expected, actual)
with raises_regex(ValueError, "not equal across datasets"):
concat([ds1, ds2], "y", data_vars="minimal")
with raises_regex(merge.MergeError, "conflicting values"):
# previously dim="y", and raised error which makes no sense.
# "foo" has dimension "y" so minimal should concatenate it?
concat([ds1, ds2], "new_dim", data_vars="minimal")

def test_concat_size0(self):
data = create_test_data()
Expand Down Expand Up @@ -372,10 +377,10 @@ def setUp(self):

def test_concat_sensible_compat_errors(self):

with raises_regex(ValueError, "'only_x' is not equal across datasets."):
with raises_regex(merge.MergeError, "conflicting values"):
concat(self.dsets, data_vars="sensible", dim="y")

with raises_regex(ValueError, "'coord1' is not equal across datasets."):
with raises_regex(merge.MergeError, "conflicting values"):
concat(self.dsets, coords="sensible", dim="y")

@pytest.mark.parametrize("concat_dim", ["y", "new_dim"])
Expand Down Expand Up @@ -431,7 +436,9 @@ def test_compat_override(self, data_vars, coords):

def test_compat_override_different_error(self):
with raises_regex(ValueError, "Cannot specify both .*='different'"):
concat(self.dsets, data_vars="different", compat="override")
concat(
self.dsets, dim="concat_dim", data_vars="different", compat="override"
)


class TestConcatDataArray:
Expand Down

0 comments on commit 9dc340e

Please sign in to comment.