diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py
index c92fcb956b1..36452db8b7b 100644
--- a/xarray/core/dataarray.py
+++ b/xarray/core/dataarray.py
@@ -50,8 +50,8 @@
 )
 from .dataset import Dataset, split_indexes
 from .formatting import format_item
-from .indexes import Indexes, default_indexes
-from .merge import PANDAS_TYPES
+from .indexes import Indexes, propagate_indexes, default_indexes
+from .merge import PANDAS_TYPES, _extract_indexes_from_coords
 from .options import OPTIONS
 from .utils import Default, ReprObject, _check_inplace, _default, either_dict_or_kwargs
 from .variable import (
@@ -367,6 +367,9 @@ def __init__(
             data = as_compatible_data(data)
             coords, dims = _infer_coords_and_dims(data.shape, coords, dims)
             variable = Variable(dims, data, attrs, encoding, fastpath=True)
+            indexes = dict(
+                _extract_indexes_from_coords(coords)
+            )  # needed for to_dataset
 
         # These fully describe a DataArray
         self._variable = variable
@@ -400,6 +403,7 @@ def _replace_maybe_drop_dims(
     ) -> "DataArray":
         if variable.dims == self.dims and variable.shape == self.shape:
             coords = self._coords.copy()
+            indexes = self._indexes
         elif variable.dims == self.dims:
             # Shape has changed (e.g. from reduce(..., keepdims=True)
             new_sizes = dict(zip(self.dims, variable.shape))
@@ -408,12 +412,19 @@ def _replace_maybe_drop_dims(
                 for k, v in self._coords.items()
                 if v.shape == tuple(new_sizes[d] for d in v.dims)
             }
+            changed_dims = [
+                k for k in variable.dims if variable.sizes[k] != self.sizes[k]
+            ]
+            indexes = propagate_indexes(self._indexes, exclude=changed_dims)
         else:
            allowed_dims = set(variable.dims)
            coords = {
                k: v for k, v in self._coords.items() if set(v.dims) <= allowed_dims
            }
-        return self._replace(variable, coords, name)
+            indexes = propagate_indexes(
+                self._indexes, exclude=(set(self.dims) - allowed_dims)
+            )
+        return self._replace(variable, coords, name, indexes=indexes)
 
     def _overwrite_indexes(self, indexes: Mapping[Hashable, Any]) -> "DataArray":
         if not len(indexes):
@@ -444,19 +455,21 @@ def _from_temp_dataset(
         return self._replace(variable, coords, name, indexes=indexes)
 
     def _to_dataset_split(self, dim: Hashable) -> Dataset:
+        """ splits dataarray along dimension 'dim' """
+
         def subset(dim, label):
             array = self.loc[{dim: label}]
-            if dim in array.coords:
-                del array.coords[dim]
             array.attrs = {}
-            return array
+            return as_variable(array)
 
         variables = {label: subset(dim, label) for label in self.get_index(dim)}
-
-        coords = self.coords.to_dataset()
-        if dim in coords:
-            del coords[dim]
-        return Dataset(variables, coords, self.attrs)
+        variables.update({k: v for k, v in self._coords.items() if k != dim})
+        indexes = propagate_indexes(self._indexes, exclude=dim)
+        coord_names = set(self._coords) - set([dim])
+        dataset = Dataset._from_vars_and_coord_names(
+            variables, coord_names, indexes=indexes, attrs=self.attrs
+        )
+        return dataset
 
     def _to_dataset_whole(
         self, name: Hashable = None, shallow_copy: bool = True
@@ -480,8 +493,12 @@ def _to_dataset_whole(
         if shallow_copy:
             for k in variables:
                 variables[k] = variables[k].copy(deep=False)
+        indexes = self._indexes
+
         coord_names = set(self._coords)
-        dataset = Dataset._from_vars_and_coord_names(variables, coord_names)
+        dataset = Dataset._from_vars_and_coord_names(
+            variables, coord_names, indexes=indexes
+        )
         return dataset
 
     def to_dataset(self, dim: Hashable = None, *, name: Hashable = None) -> Dataset:
@@ -927,7 +944,8 @@ def copy(self, deep: bool = True, data: Any = None) -> "DataArray":
         """
         variable = self.variable.copy(deep=deep, data=data)
         coords = {k: v.copy(deep=deep) for k, v in self._coords.items()}
-        return self._replace(variable, coords)
+        indexes = self._indexes
+        return self._replace(variable, coords, indexes=indexes)
 
     def __copy__(self) -> "DataArray":
         return self.copy(deep=False)
diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py
index 206f2f55b3c..089cda06b0d 100644
--- a/xarray/core/dataset.py
+++ b/xarray/core/dataset.py
@@ -59,7 +59,13 @@
     remap_label_indexers,
 )
 from .duck_array_ops import datetime_to_numeric
-from .indexes import Indexes, default_indexes, isel_variable_and_index, roll_index
+from .indexes import (
+    Indexes,
+    default_indexes,
+    isel_variable_and_index,
+    propagate_indexes,
+    roll_index,
+)
 from .merge import (
     dataset_merge_method,
     dataset_update_method,
@@ -872,8 +878,12 @@ def _construct_direct(
         return obj
 
     @classmethod
-    def _from_vars_and_coord_names(cls, variables, coord_names, attrs=None):
-        return cls._construct_direct(variables, coord_names, attrs=attrs)
+    def _from_vars_and_coord_names(
+        cls, variables, coord_names, indexes=None, attrs=None
+    ):
+        return cls._construct_direct(
+            variables, coord_names, indexes=indexes, attrs=attrs
+        )
 
     def _replace(
         self,
@@ -4375,10 +4385,13 @@ def to_array(self, dim="variable", name=None):
 
         coords = dict(self.coords)
         coords[dim] = list(self.data_vars)
+        indexes = propagate_indexes(self._indexes)
 
         dims = (dim,) + broadcast_vars[0].dims
 
-        return DataArray(data, coords, dims, attrs=self.attrs, name=name)
+        return DataArray(
+            data, coords, dims, attrs=self.attrs, name=name, indexes=indexes
+        )
 
     def _to_dataframe(self, ordered_dims):
         columns = [k for k in self.variables if k not in self.dims]
diff --git a/xarray/core/groupby.py b/xarray/core/groupby.py
index ec752721781..7e872c74d72 100644
--- a/xarray/core/groupby.py
+++ b/xarray/core/groupby.py
@@ -10,6 +10,7 @@
 from .common import ImplementsArrayReduce, ImplementsDatasetReduce
 from .concat import concat
 from .formatting import format_array_flat
+from .indexes import propagate_indexes
 from .options import _get_keep_attrs
 from .pycompat import integer_types
 from .utils import (
@@ -529,7 +530,7 @@ def _maybe_unstack(self, obj):
             for dim in self._inserted_dims:
                 if dim in obj.coords:
                     del obj.coords[dim]
-                    del obj.indexes[dim]
+            obj._indexes = propagate_indexes(obj._indexes, exclude=self._inserted_dims)
         return obj
 
     def fillna(self, value):
@@ -786,7 +787,8 @@ def _combine(self, applied, restore_coord_dims=False, shortcut=False):
             combined = self._restore_dim_order(combined)
         if coord is not None:
             if shortcut:
-                combined._coords[coord.name] = as_variable(coord)
+                coord_var = as_variable(coord)
+                combined._coords[coord.name] = coord_var
             else:
                 combined.coords[coord.name] = coord
         combined = self._maybe_restore_empty_groups(combined)
diff --git a/xarray/core/indexes.py b/xarray/core/indexes.py
index 1574f4f18df..8337a0f082a 100644
--- a/xarray/core/indexes.py
+++ b/xarray/core/indexes.py
@@ -5,6 +5,7 @@
 import pandas as pd
 
 from . import formatting
+from .utils import is_scalar
 from .variable import Variable
 
 
@@ -35,9 +36,6 @@ def __contains__(self, key):
     def __getitem__(self, key):
         return self._indexes[key]
 
-    def __delitem__(self, key):
-        del self._indexes[key]
-
     def __repr__(self):
         return formatting.indexes_repr(self)
 
@@ -100,3 +98,22 @@ def roll_index(index: pd.Index, count: int, axis: int = 0) -> pd.Index:
         return index[-count:].append(index[:-count])
     else:
         return index[:]
+
+
+def propagate_indexes(
+    indexes: Optional[Dict[Hashable, pd.Index]], exclude: Optional[Any] = None
+) -> Optional[Dict[Hashable, pd.Index]]:
+    """ Creates new indexes dict from existing dict optionally excluding some dimensions.
+    """
+    if exclude is None:
+        exclude = ()
+
+    if is_scalar(exclude):
+        exclude = (exclude,)
+
+    if indexes is not None:
+        new_indexes = {k: v for k, v in indexes.items() if k not in exclude}
+    else:
+        new_indexes = None  # type: ignore
+
+    return new_indexes
diff --git a/xarray/tests/test_dataarray.py b/xarray/tests/test_dataarray.py
index 4c3553c867e..b5397525a22 100644
--- a/xarray/tests/test_dataarray.py
+++ b/xarray/tests/test_dataarray.py
@@ -14,6 +14,7 @@
 from xarray.convert import from_cdms2
 from xarray.core import dtypes
 from xarray.core.common import full_like
+from xarray.core.indexes import propagate_indexes
 from xarray.tests import (
     LooseVersion,
     ReturnItem,
@@ -1239,6 +1240,7 @@ def test_coords(self):
         assert expected == actual
 
         del da.coords["x"]
+        da._indexes = propagate_indexes(da._indexes, exclude="x")
         expected = DataArray(da.values, {"y": [0, 1, 2]}, dims=["x", "y"], name="foo")
         assert_identical(da, expected)
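
For reference, a minimal standalone sketch of how the new propagate_indexes helper behaves. The body of propagate_indexes is taken from the patch above; the _is_scalar helper is a simplified stand-in for xarray.core.utils.is_scalar, and the sample index dictionary is made up for illustration only.

from typing import Any, Dict, Hashable, Optional

import numpy as np
import pandas as pd


def _is_scalar(value: Any) -> bool:
    # Simplified stand-in for xarray.core.utils.is_scalar: treat strings and
    # other non-collection values as scalars so a single dimension name can
    # be passed directly as `exclude`.
    return isinstance(value, str) or not isinstance(
        value, (list, tuple, set, np.ndarray)
    )


def propagate_indexes(
    indexes: Optional[Dict[Hashable, pd.Index]], exclude: Optional[Any] = None
) -> Optional[Dict[Hashable, pd.Index]]:
    # Copy the indexes dict, dropping any entries whose dimension name is in
    # `exclude`; a None input is passed through unchanged.
    if exclude is None:
        exclude = ()
    if _is_scalar(exclude):
        exclude = (exclude,)
    if indexes is None:
        return None
    return {k: v for k, v in indexes.items() if k not in exclude}


indexes = {"x": pd.Index([10, 20]), "y": pd.Index(["a", "b", "c"])}
print(propagate_indexes(indexes, exclude="x"))         # keeps only "y"
print(propagate_indexes(indexes, exclude=["x", "y"]))  # empty dict
print(propagate_indexes(None))                         # None is passed through

Passing None through rather than returning an empty dict appears deliberate: callers such as _replace_maybe_drop_dims can keep deferring index creation when no indexes have been materialized yet.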