Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

propagate indexes in to_dataset, from_dataset #3519

Merged
merged 14 commits into from
Nov 22, 2019
44 changes: 31 additions & 13 deletions xarray/core/dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,8 @@
)
from .dataset import Dataset, merge_indexes, split_indexes
from .formatting import format_item
from .indexes import Indexes, default_indexes
from .merge import PANDAS_TYPES
from .indexes import Indexes, copy_indexes, default_indexes
from .merge import PANDAS_TYPES, _extract_indexes_from_coords
from .options import OPTIONS
from .utils import Default, ReprObject, _check_inplace, _default, either_dict_or_kwargs
from .variable import (
Expand Down Expand Up @@ -367,6 +367,9 @@ def __init__(
data = as_compatible_data(data)
coords, dims = _infer_coords_and_dims(data.shape, coords, dims)
variable = Variable(dims, data, attrs, encoding, fastpath=True)
indexes = dict(
_extract_indexes_from_coords(coords)
) # needed for to_dataset

# These fully describe a DataArray
self._variable = variable
Expand Down Expand Up @@ -401,6 +404,7 @@ def _replace_maybe_drop_dims(
) -> "DataArray":
if variable.dims == self.dims and variable.shape == self.shape:
coords = self._coords.copy()
indexes = copy_indexes(self._indexes)
elif variable.dims == self.dims:
# Shape has changed (e.g. from reduce(..., keepdims=True)
new_sizes = dict(zip(self.dims, variable.shape))
Expand All @@ -409,12 +413,19 @@ def _replace_maybe_drop_dims(
for k, v in self._coords.items()
if v.shape == tuple(new_sizes[d] for d in v.dims)
}
changed_dims = [
k for k in variable.dims if variable.sizes[k] != self.sizes[k]
]
indexes = copy_indexes(self._indexes, exclude=changed_dims)
else:
allowed_dims = set(variable.dims)
coords = {
k: v for k, v in self._coords.items() if set(v.dims) <= allowed_dims
}
return self._replace(variable, coords, name)
indexes = copy_indexes(
self._indexes, exclude=(set(self.dims) - allowed_dims)
)
return self._replace(variable, coords, name, indexes=indexes)

def _overwrite_indexes(self, indexes: Mapping[Hashable, Any]) -> "DataArray":
if not len(indexes):
Expand Down Expand Up @@ -445,19 +456,21 @@ def _from_temp_dataset(
return self._replace(variable, coords, name, indexes=indexes)

def _to_dataset_split(self, dim: Hashable) -> Dataset:
""" splits dataarray along dimension 'dim' """

def subset(dim, label):
array = self.loc[{dim: label}]
if dim in array.coords:
del array.coords[dim]
array.attrs = {}
return array
return as_variable(array)

variables = {label: subset(dim, label) for label in self.get_index(dim)}

coords = self.coords.to_dataset()
if dim in coords:
del coords[dim]
return Dataset(variables, coords, self.attrs)
variables.update({k: v for k, v in self._coords.items() if k != dim})
indexes = copy_indexes(self._indexes, exclude=dim)
coord_names = set(self._coords) - set([dim])
dataset = Dataset._from_vars_and_coord_names(
variables, coord_names, indexes=indexes, attrs=self.attrs
)
return dataset

def _to_dataset_whole(
self, name: Hashable = None, shallow_copy: bool = True
Expand All @@ -481,8 +494,12 @@ def _to_dataset_whole(
if shallow_copy:
for k in variables:
variables[k] = variables[k].copy(deep=False)
indexes = copy_indexes(self._indexes, deep=(not shallow_copy))

coord_names = set(self._coords)
dataset = Dataset._from_vars_and_coord_names(variables, coord_names)
dataset = Dataset._from_vars_and_coord_names(
variables, coord_names, indexes=indexes
)
return dataset

def to_dataset(self, dim: Hashable = None, *, name: Hashable = None) -> Dataset:
Expand Down Expand Up @@ -926,7 +943,8 @@ def copy(self, deep: bool = True, data: Any = None) -> "DataArray":
"""
variable = self.variable.copy(deep=deep, data=data)
coords = {k: v.copy(deep=deep) for k, v in self._coords.items()}
return self._replace(variable, coords)
indexes = copy_indexes(self._indexes, deep=deep)
return self._replace(variable, coords, indexes=indexes)

def __copy__(self) -> "DataArray":
return self.copy(deep=False)
Expand Down
21 changes: 17 additions & 4 deletions xarray/core/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,13 @@
remap_label_indexers,
)
from .duck_array_ops import datetime_to_numeric
from .indexes import Indexes, default_indexes, isel_variable_and_index, roll_index
from .indexes import (
Indexes,
copy_indexes,
default_indexes,
isel_variable_and_index,
roll_index,
)
from .merge import (
dataset_merge_method,
dataset_update_method,
Expand Down Expand Up @@ -862,8 +868,12 @@ def _construct_direct(
return obj

@classmethod
def _from_vars_and_coord_names(cls, variables, coord_names, attrs=None):
return cls._construct_direct(variables, coord_names, attrs=attrs)
def _from_vars_and_coord_names(
cls, variables, coord_names, indexes=None, attrs=None
):
return cls._construct_direct(
variables, coord_names, indexes=indexes, attrs=attrs
)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I realize this isn't part of this PR, so maybe for another day; does calling _from_vars_and_coord_names do anything different from calling _construct_direct?


def _replace(
self,
Expand Down Expand Up @@ -4308,10 +4318,13 @@ def to_array(self, dim="variable", name=None):

coords = dict(self.coords)
coords[dim] = list(self.data_vars)
indexes = copy_indexes(self._indexes)

dims = (dim,) + broadcast_vars[0].dims

return DataArray(data, coords, dims, attrs=self.attrs, name=name)
return DataArray(
data, coords, dims, attrs=self.attrs, name=name, indexes=indexes
)

def _to_dataframe(self, ordered_dims):
columns = [k for k in self.variables if k not in self.dims]
Expand Down
6 changes: 4 additions & 2 deletions xarray/core/groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from .common import ImplementsArrayReduce, ImplementsDatasetReduce
from .concat import concat
from .formatting import format_array_flat
from .indexes import copy_indexes
from .options import _get_keep_attrs
from .pycompat import integer_types
from .utils import (
Expand Down Expand Up @@ -529,7 +530,7 @@ def _maybe_unstack(self, obj):
for dim in self._inserted_dims:
if dim in obj.coords:
del obj.coords[dim]
del obj.indexes[dim]
obj._indexes = copy_indexes(obj._indexes, exclude=self._inserted_dims)
return obj

def fillna(self, value):
Expand Down Expand Up @@ -732,7 +733,8 @@ def _combine(self, applied, restore_coord_dims=False, shortcut=False):
combined = self._restore_dim_order(combined)
if coord is not None:
if shortcut:
combined._coords[coord.name] = as_variable(coord)
coord_var = as_variable(coord)
combined._coords[coord.name] = coord_var
else:
combined.coords[coord.name] = coord
combined = self._maybe_restore_empty_groups(combined)
Expand Down
25 changes: 22 additions & 3 deletions xarray/core/indexes.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import pandas as pd

from . import formatting
from .utils import is_scalar
from .variable import Variable


Expand Down Expand Up @@ -35,9 +36,6 @@ def __contains__(self, key):
def __getitem__(self, key):
return self._indexes[key]

def __delitem__(self, key):
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

undoes my previous change.

I think I understand the indexes model better now. We only use ._indexes dict internally and Indexes is intended for external use. Is that right?

del self._indexes[key]

def __repr__(self):
return formatting.indexes_repr(self)

Expand Down Expand Up @@ -100,3 +98,24 @@ def roll_index(index: pd.Index, count: int, axis: int = 0) -> pd.Index:
return index[-count:].append(index[:-count])
else:
return index[:]


def copy_indexes(
dcherian marked this conversation as resolved.
Show resolved Hide resolved
indexes: Optional[Dict[Hashable, pd.Index]],
deep: bool = True,
exclude: Optional[Any] = None,
) -> Optional[Dict[Hashable, pd.Index]]:
if exclude is None:
exclude = ()

if is_scalar(exclude):
exclude = (exclude,)

if indexes is not None:
new_indexes = {
k: v.copy(deep=deep) for k, v in indexes.items() if k not in exclude
}
else:
new_indexes = None # type: ignore

return new_indexes
2 changes: 2 additions & 0 deletions xarray/tests/test_dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from xarray.convert import from_cdms2
from xarray.core import dtypes
from xarray.core.common import full_like
from xarray.core.indexes import copy_indexes
from xarray.tests import (
LooseVersion,
ReturnItem,
Expand Down Expand Up @@ -1229,6 +1230,7 @@ def test_coords(self):
assert expected == actual

del da.coords["x"]
da._indexes = copy_indexes(da._indexes, exclude="x")
expected = DataArray(da.values, {"y": [0, 1, 2]}, dims=["x", "y"], name="foo")
assert_identical(da, expected)

Expand Down