Merge non-concatenated variables.
dcherian committed Aug 29, 2019
1 parent 820463b commit f294cc6
Showing 1 changed file with 68 additions and 50 deletions.
xarray/core/concat.py
@@ -5,6 +5,12 @@

from . import utils, dtypes
from .alignment import align
+from .merge import (
+    determine_coords,
+    merge_variables,
+    expand_variable_dicts,
+    _VALID_COMPAT,
+)
from .variable import IndexVariable, Variable, as_variable
from .variable import concat as concat_vars

@@ -65,7 +71,7 @@ def concat(
in addition to the 'minimal' coordinates.
compat : {'equals', 'identical', 'override'}, optional
String indicating how to compare non-concatenated variables and
-        dataset global attributes for potential conflicts.
+        dataset global attributes for potential conflicts. This is passed down to merge.
* 'equals' means that all variable values and dimensions must be the same;
* 'identical' means that variable attributes and global attributes
must also be equal.
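
For intuition, a minimal sketch of how compat plays out for a variable that is not concatenated (assumes xarray >= 0.13, where compat='override' is available; the datasets and the variable name "scalar" are illustrative, not from this commit):

import xarray as xr

ds1 = xr.Dataset({"x": ("t", [1, 2]), "scalar": 1.0}, coords={"t": [0, 1]})
ds2 = xr.Dataset({"x": ("t", [3, 4]), "scalar": 2.0}, coords={"t": [2, 3]})

# "scalar" has no "t" dimension, so with data_vars="minimal" it is compared
# rather than concatenated; compat="equals" rejects the conflicting values.
try:
    xr.concat([ds1, ds2], dim="t", data_vars="minimal", compat="equals")
except ValueError as err:
    print(err)

# compat="override" skips the comparison and keeps the first dataset's value.
out = xr.concat([ds1, ds2], dim="t", data_vars="minimal", compat="override")
print(float(out["scalar"]))  # 1.0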
@@ -145,7 +151,7 @@ def concat(
"`data_vars` and `coords` arguments"
)

-    if compat not in ["equals", "identical", "override", "no_conflicts"]:
+    if compat not in _VALID_COMPAT:
raise ValueError(
"compat=%r invalid: must be 'equals', 'identical or 'override'" % compat
)
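
For illustration, a sketch of the membership-check pattern this change moves to; _VALID_COMPAT is imported from xarray/core/merge.py above, and its exact contents at this commit are assumed here, not taken from the diff.

# Hypothetical stand-in for merge._VALID_COMPAT (illustrative values only).
_VALID_COMPAT = ("broadcast_equals", "equals", "identical", "no_conflicts", "override")

def check_compat(compat):
    # Centralized validation: one container, one error message.
    if compat not in _VALID_COMPAT:
        raise ValueError(
            "compat=%r invalid: must be one of %s" % (compat, ", ".join(_VALID_COMPAT))
        )

check_compat("equals")       # passes silently
# check_compat("different")  # would raise ValueError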
@@ -186,22 +192,28 @@ def _calc_concat_dim_coord(dim):
return dim, coord


-def _calc_concat_over(datasets, dim, data_vars, coords):
+def _calc_concat_over(datasets, dim, dim_names, data_vars, coords, compat):
"""
Determine which dataset variables need to be concatenated in the result,
and which can simply be taken from the first dataset.
"""
# Return values
concat_over = set()
equals = {}

-    if dim in datasets[0]:
+    if dim in dim_names:
concat_over_existing_dim = True
concat_over.add(dim)
else:
concat_over_existing_dim = False

for ds in datasets:
+        if concat_over_existing_dim:
+            if dim not in ds.dims:
+                # TODO: why did I do this
+                if dim in ds:
+                    ds = ds.set_coords(dim)
+                else:
+                    raise ValueError("%r is not a dimension in all datasets" % dim)
concat_over.update(k for k, v in ds.variables.items() if dim in v.dims)
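
The set_coords branch above targets the case where the concat dimension exists in a dataset only as a variable, not as a dimension, e.g. a scalar "t" in each time slice. A sketch of that situation at the user level (dataset contents are illustrative):

import xarray as xr

ds1 = xr.Dataset({"x": ("p", [1, 2]), "t": 0})
ds2 = xr.Dataset({"x": ("p", [3, 4]), "t": 1})

# Promoting the scalar "t" to a coordinate lets concat build a new "t" axis.
out = xr.concat([ds1.set_coords("t"), ds2.set_coords("t")], dim="t")
print(out["x"].shape)   # (2, 2)
print(out["t"].values)  # [0 1]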

@@ -225,7 +237,7 @@ def process_subset_opt(opt, subset):
for ds_rhs in datasets[1:]:
v_rhs = ds_rhs.variables[k].compute()
computed.append(v_rhs)
-                            if not v_lhs.equals(v_rhs):
+                            if not getattr(v_lhs, compat)(v_rhs):
concat_over.add(k)
equals[k] = False
# computed variables are not to be re-computed
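
The getattr call above works because each compat string is also the name of a comparison method on xarray.Variable: compat="equals" dispatches to Variable.equals, compat="identical" to Variable.identical, and so on. A small demonstration:

import numpy as np
import xarray as xr

v1 = xr.Variable("x", np.arange(3), attrs={"units": "m"})
v2 = xr.Variable("x", np.arange(3), attrs={"units": "km"})

print(v1.equals(v2))              # True: dims and values match
print(v1.identical(v2))           # False: attrs differ
print(getattr(v1, "equals")(v2))  # the dynamic form used in the diff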
@@ -291,68 +303,74 @@ def _dataset_concat(
*datasets, join=join, copy=False, exclude=[dim], fill_value=fill_value
)

-    concat_over, equals = _calc_concat_over(datasets, dim, data_vars, coords)
+    # determine dimension coordinate names and a dict mapping name to Variable
+    def determine_dims(datasets, result_coord_names):
+        dims = set()
+        coords = dict()
+        for ds in datasets:
+            for dim in set(ds.dims) - dims:
+                if dim not in coords:
+                    coords[dim] = ds.coords[dim].variable
+            dims = dims | set(ds.dims)
+        return dims, coords
+
+    result_coord_names, noncoord_names = determine_coords(datasets)
+    both_data_and_coords = result_coord_names & noncoord_names
+    if both_data_and_coords:
+        raise ValueError(
+            "%r is a coordinate in some datasets and not in others."
+            % list(both_data_and_coords)[0]  # preserve previous behaviour
+        )
+    dim_names, result_coords = determine_dims(datasets, result_coord_names)
+    # we don't want the concat dimension in the result dataset yet
+    result_coords.pop(dim, None)

-    def insert_result_variable(k, v):
-        assert isinstance(v, Variable)
-        if k in datasets[0].coords:
-            result_coord_names.add(k)
-        result_vars[k] = v
+    # determine which variables to concatenate
+    concat_over, equals = _calc_concat_over(
+        datasets, dim, dim_names, data_vars, coords, compat
+    )
+
+    # determine which variables to merge
+    variables_to_merge = (result_coord_names | noncoord_names) - concat_over - dim_names
+    if variables_to_merge:
+        to_merge = []
+        for ds in datasets:
+            to_merge.append(ds.reset_coords()[list(variables_to_merge)])
+        # TODO: Provide equals as an argument and thread that down to merge.unique_variable
+        result_vars = merge_variables(
+            expand_variable_dicts(to_merge), priority_vars=None, compat=compat
+        )
+    else:
+        result_vars = OrderedDict()
+    result_vars.update(result_coords)
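
The reset_coords call in the loop above demotes coordinates to plain data variables, so conflicting coordinates and conflicting data variables flow through the same merge machinery. A sketch of what that subsetting produces (names are illustrative):

import xarray as xr

ds = xr.Dataset({"x": ("t", [1, 2])}, coords={"t": [0, 1], "station": "A"})

flat = ds.reset_coords()     # "station" becomes a data variable; index "t" stays
print(list(flat.data_vars))  # ['x', 'station']

# Subsetting mirrors ds.reset_coords()[list(variables_to_merge)] above.
print(list(flat[["station"]].data_vars))  # ['station']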

-    # create the new dataset and add constant variables
-    result_vars = OrderedDict()
-    result_coord_names = set(datasets[0].coords)
+    # assign attrs and encoding from first dataset
result_attrs = datasets[0].attrs
result_encoding = datasets[0].encoding

-    for k, v in datasets[0].variables.items():
-        if k not in concat_over:
-            insert_result_variable(k, v)
+    def insert_result_variable(k, v):
+        assert isinstance(v, Variable)
+        result_vars[k] = v

-    # check that global attributes and non-concatenated variables are fixed
-    # across all datasets
+    # check that global attributes are fixed across all datasets if necessary
for ds in datasets[1:]:
if compat == "identical" and not utils.dict_equiv(ds.attrs, result_attrs):
raise ValueError("Dataset global attributes are not equal.")
-        for k, v in ds.variables.items():
-            if k not in result_vars and k not in concat_over:
-                raise ValueError("Encountered unexpected variable %r" % k)
-            elif (k in result_coord_names) != (k in ds.coords):
-                raise ValueError(
-                    "%r is a coordinate in some datasets but not others." % k
-                )
-            elif compat != "override" and k in result_vars and k != dim:
-                # Don't use Variable.identical as it internally invokes
-                # Variable.equals, and we may already know the answer
-                if compat == "identical" and not utils.dict_equiv(
-                    v.attrs, result_vars[k].attrs
-                ):
-                    raise ValueError(
-                        "Variable '%s' is not identical across datasets. "
-                        "You can skip this check by specifying compat='override'." % k
-                    )
-
-                # Proceed with equals()
-                try:
-                    # May be populated when using the "different" method
-                    is_equal = equals[k]
-                except KeyError:
-                    result_vars[k].load()
-                    is_equal = v.equals(result_vars[k])
-                if not is_equal:
-                    raise ValueError(
-                        "Variable '%s' is not equal across datasets. "
-                        "You can skip this check by specifying compat='override'." % k
-                    )

+    ##############
+    # TODO: do this stuff earlier so we loop over datasets only once
+    #############
# we've already verified everything is consistent; now, calculate
# shared dimension sizes so we can expand the necessary variables
dim_lengths = [ds.dims.get(dim, 1) for ds in datasets]
+    # non_concat_dims = dim_names - concat_over
non_concat_dims = {}
for ds in datasets:
non_concat_dims.update(ds.dims)
non_concat_dims.pop(dim, None)
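
dim_lengths above records how many elements each input contributes along the concat dimension; a dataset lacking the dimension counts as length 1, so scalar slices can be stacked against full arrays. A toy illustration (mirroring the ds.dims.get idiom as written at this commit's vintage):

import xarray as xr

dsa = xr.Dataset({"x": (("t", "p"), [[1, 2], [3, 4]])})  # has "t", length 2
dsb = xr.Dataset({"x": ("p", [5, 6])})                   # no "t" dimension

print([ds.dims.get("t", 1) for ds in (dsa, dsb)])  # [2, 1]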

+    # seems like there should be a helper function for this. We would need to add
+    # an exclude kwarg to exclude comparing along concat_dim
def ensure_common_dims(vars):
# ensure each variable with the given name shares the same
# dimensions and the same shape for all of them except along the