diff --git a/podpac/core/interpolation/interpolation_manager.py b/podpac/core/interpolation/interpolation_manager.py index 0b22dded..7f0e2e92 100644 --- a/podpac/core/interpolation/interpolation_manager.py +++ b/podpac/core/interpolation/interpolation_manager.py @@ -6,6 +6,7 @@ from collections import OrderedDict from six import string_types import numpy as np +import xarray as xr import traitlets as tl from podpac.core import settings @@ -502,9 +503,42 @@ def select_coordinates(self, source_coordinates, eval_coordinates, index_type="n validate_crs=False, ) if index_type == "numpy": - selected_coords_idx2 = np.ix_(*[np.ravel(selected_coords_idx[k]) for k in source_coordinates.dims]) - elif index_type in ["slice", "xarray"]: - selected_coords_idx2 = tuple([selected_coords_idx[d] for d in source_coordinates.dims]) + npcoords = [] + has_stacked = False + for k in source_coordinates.dims: + # Deal with nD stacked source coords (marked by coords being in tuple) + if isinstance(selected_coords_idx[k], tuple): + has_stacked = True + npcoords.extend([sci for sci in selected_coords_idx[k]]) + else: + npcoords.append(selected_coords_idx[k]) + if has_stacked: + # When stacked coordinates are nD we cannot use the catchall of the next branch + selected_coords_idx2 = npcoords + else: + # This would not be needed if everything went as planned in + # interpolator.select_coordinates, but this is a catchall that works + # for 90% of the cases + selected_coords_idx2 = np.ix_(*[np.ravel(npc) for npc in npcoords]) + elif index_type == "xarray": + selected_coords_idx2 = [] + for i in selected_coords.dims: + # Deal with nD stacked source coords (marked by coords being in tuple) + if isinstance(selected_coords_idx[i], tuple): + selected_coords_idx2.extend([xr.DataArray(sci, dims=[i]) for sci in selected_coords_idx[i]]) + else: + selected_coords_idx2.append(selected_coords_idx[i]) + selected_coords_idx2 = tuple(selected_coords_idx2) + elif index_type == "slice": + selected_coords_idx2 = [] 
+ for i in selected_coords.dims: + # Deal with nD stacked source coords (marked by coords being in tuple) + if isinstance(selected_coords_idx[i], tuple): + selected_coords_idx2.extend(selected_coords_idx[i]) + else: + selected_coords_idx2.append(selected_coords_idx[i]) + + selected_coords_idx2 = tuple(selected_coords_idx2) else: raise ValueError("Unknown index_type '%s'" % index_type) return selected_coords, tuple(selected_coords_idx2) diff --git a/podpac/core/interpolation/nearest_neighbor_interpolator.py b/podpac/core/interpolation/nearest_neighbor_interpolator.py index e2fcf601..069aaefa 100644 --- a/podpac/core/interpolation/nearest_neighbor_interpolator.py +++ b/podpac/core/interpolation/nearest_neighbor_interpolator.py @@ -95,7 +95,7 @@ def is_stacked(d): ] else: - bounds = {d: None for d in source_coordinates.udims} + bounds = None if self.remove_nan: # Eliminate nans from the source data. Note, this could turn a uniform griddted dataset into a stacked one @@ -110,16 +110,38 @@ def is_stacked(d): continue source = source_coordinates[d] if is_stacked(d): - bound = np.stack([bounds[dd] for dd in d.split("_")], axis=1) + if bounds is not None: + bound = np.stack([bounds[dd] for dd in d.split("_")], axis=1) + else: + bound = None index = self._get_stacked_index(d, source, eval_coordinates, bound) + + if len(source.shape) == 2: # Handle case of 2D-stacked coordinates + ncols = source.shape[1] + index1 = index // ncols + index1 = self._resize_stacked_index(index1, d, eval_coordinates) + # With nD stacked coordinates, there are 'n' indices in the tuple + # All of these need to get into the data_index, and in the right order + data_index.append(index1) # This is a hack + index = index % ncols # The second half can go through the usual machinery + elif len(source.shape) > 2: # Handle case of nD-stacked coordinates + raise NotImplementedError index = self._resize_stacked_index(index, d, eval_coordinates) elif source_coordinates[d].is_uniform: request = 
eval_coordinates[d] - index = self._get_uniform_index(d, source, request, bounds[d]) + if bounds is not None: + bound = bounds[d] + else: + bound = None + index = self._get_uniform_index(d, source, request, bound) index = self._resize_unstacked_index(index, d, eval_coordinates) else: # non-uniform coordinates... probably an optimization here request = eval_coordinates[d] - index = self._get_nonuniform_index(d, source, request, bounds[d]) + if bounds is not None: + bound = bounds[d] + else: + bound = None + index = self._get_nonuniform_index(d, source, request, bound) index = self._resize_unstacked_index(index, d, eval_coordinates) data_index.append(index) @@ -219,7 +241,8 @@ def _get_stacked_index(self, dim, source, request, bounds=None): scales = np.array([self._get_scale(d, time_source, time_request) for d in udims])[None, :] tol = np.linalg.norm((tols * scales).squeeze()) src_coords, req_coords_diag = _higher_precision_time_stack(source, request, udims) - ckdtree_source = cKDTree(src_coords.T * scales) + # We need to unravel the nD stacked coordinates + ckdtree_source = cKDTree(src_coords.reshape(src_coords.shape[0], -1).T * scales) # if the udims are all stacked in the same stack as part of the request coordinates, then we're done. 
# Otherwise we have to evaluate each unstacked set of dimensions independently @@ -252,7 +275,13 @@ def _get_stacked_index(self, dim, source, request, bounds=None): if self.respect_bounds: if bounds is None: - bounds = [src_coords.min(0), src_coords.max(0)] + bounds = np.stack( + [ + src_coords.reshape(src_coords.shape[0], -1).T.min(0), + src_coords.reshape(src_coords.shape[0], -1).T.max(0), + ], + axis=1, + ) # Fix order of bounds bounds = bounds[:, [source.udims.index(dim) for dim in udims]] index[np.any((req_coords > bounds[1]), axis=1) | np.any((req_coords < bounds[0]), axis=1)] = -1 diff --git a/podpac/core/interpolation/selector.py b/podpac/core/interpolation/selector.py index 70ffb782..7cf3b484 100644 --- a/podpac/core/interpolation/selector.py +++ b/podpac/core/interpolation/selector.py @@ -100,8 +100,15 @@ def select(self, source_coords, request_coords, index_type="numpy"): for coord1d in source_coords._coords.values(): ci = self._select1d(coord1d, request_coords, index_type) ci = np.sort(np.unique(ci)) - if index_type == "slice": + if len(coord1d.shape) == 2: # Handle case of 2D-stacked coordinates + ncols = coord1d.shape[1] + ci = (ci // ncols, ci % ncols) + if index_type == "slice": + ci = tuple([_index2slice(cii) for cii in ci]) + elif index_type == "slice": ci = _index2slice(ci) + if len(coord1d.shape) == 3: # Handle case of 3D-stacked coordinates + raise NotImplementedError c = coord1d[ci] coords.append(c) coords_inds.append(ci) @@ -109,7 +116,9 @@ def select(self, source_coords, request_coords, index_type="numpy"): if index_type == "numpy": coords_inds = self._merge_indices(coords_inds, source_coords.dims, request_coords.dims) elif index_type == "xarray": - pass # unlike numpy, xarray assumes indexes are orthogonal by default, so the 1d coordinates are already correct + # unlike numpy, xarray assumes indexes are orthogonal by default, so the 1d coordinates are already correct + # unless there are tuple coordinates (nD stacked coords) but those are 
handled in interpolation_manager + pass return coords, tuple(coords_inds) def _select1d(self, source, request, index_type): @@ -127,11 +136,18 @@ def _select1d(self, source, request, index_type): def _merge_indices(self, indices, source_dims, request_dims): # For numpy to broadcast correctly, we have to reshape each of the indices reshape = np.ones(len(indices), int) + new_indices = [] for i in range(len(indices)): reshape[:] = 1 reshape[i] = -1 - indices[i] = indices[i].reshape(*reshape) - return tuple(indices) + if isinstance(indices[i], tuple): + # nD stacked coordinates + # This means the source has shape (N, M, ...) + # But the coordinates are stacked (i.e. lat_lon with shape N, M for the lon and lat parts) + new_indices.append(tuple([ind.reshape(*reshape) for ind in indices[i]])) + else: + new_indices.append(indices[i].reshape(*reshape)) + return tuple(new_indices) def _select_uniform(self, source, request, index_type): crds = request[source.name] @@ -186,7 +202,8 @@ def _select_stacked(self, source, request, index_type): inds = np.array([]) # Parts of the below code is duplicated in NearestNeighborInterpolotor src_coords, req_coords_diag = _higher_precision_time_stack(source, request, udims) - ckdtree_source = cKDTree(src_coords.T) + # For nD stacked coordinates we need to unravel the stacked dimension + ckdtree_source = cKDTree(src_coords.reshape(src_coords.shape[0], -1).T) if (len(indep_evals) + len(stacked)) <= 1: req_coords = req_coords_diag.T elif (len(stacked) == 0) | (len(indep_evals) == 0 and len(stacked) == len(udims)): diff --git a/podpac/core/interpolation/test/test_interpolators.py b/podpac/core/interpolation/test/test_interpolators.py index c6616845..b0e04093 100644 --- a/podpac/core/interpolation/test/test_interpolators.py +++ b/podpac/core/interpolation/test/test_interpolators.py @@ -31,6 +31,15 @@ def get_data(self, coordinates, coordinates_index): return self.create_output_array(coordinates, data=self.data[coordinates_index]) +class 
MockArrayDataSourceXR(InterpolationMixin, DataSource): + data = ArrayTrait().tag(attr=True) + coordinates = tl.Instance(Coordinates).tag(attr=True) + + def get_data(self, coordinates, coordinates_index): + dataxr = self.create_output_array(self.coordinates, data=self.data) + return self.create_output_array(coordinates, data=dataxr[coordinates_index].data) + + class TestNone(object): def test_none_select(self): reqcoords = Coordinates([[-0.5, 1.5, 3.5], [0.5, 2.5, 4.5]], dims=["lat", "lon"]) @@ -564,6 +573,128 @@ def test_respect_bounds(self): np.testing.assert_array_equal(output.data[1:], source[[0, 2]]) assert np.isnan(output.data[0]) + def test_2Dstacked(self): + # With Time + source = np.random.rand(5, 4, 2) + coords_src = Coordinates( + [ + [ + np.arange(5)[:, None] + 0.1 * np.ones((5, 4)), + np.arange(4)[None, :] + 0.1 * np.ones((5, 4)), + ], + [0.4, 0.7], + ], + ["lat_lon", "time"], + ) + coords_dst = Coordinates([np.arange(4) + 0.2, np.arange(1, 4) - 0.2, [0.5]], ["lat", "lon", "time"]) + node = MockArrayDataSource( + data=source, + coordinates=coords_src, + interpolation={ + "method": "nearest", + "interpolators": [NearestNeighbor], + }, + ) + output = node.eval(coords_dst) + np.testing.assert_array_equal(output, source[:4, 1:, :1]) + + # Using 'xarray' coordinates type + node = MockArrayDataSourceXR( + data=source, + coordinates=coords_src, + coordinate_index_type="xarray", + interpolation={ + "method": "nearest", + "interpolators": [NearestNeighbor], + }, + ) + output = node.eval(coords_dst) + np.testing.assert_array_equal(output, source[:4, 1:, :1]) + + # Using 'slice' coordinates type + node = MockArrayDataSource( + data=source, + coordinates=coords_src, + coordinate_index_type="slice", + interpolation={ + "method": "nearest", + "interpolators": [NearestNeighbor], + }, + ) + output = node.eval(coords_dst) + np.testing.assert_array_equal(output, source[:4, 1:, :1]) + + # Without Time + source = np.random.rand(5, 4) + node = MockArrayDataSource( + 
data=source, + coordinates=coords_src.drop("time"), + interpolation={ + "method": "nearest", + "interpolators": [NearestNeighbor], + }, + ) + output = node.eval(coords_dst) + np.testing.assert_array_equal(output, source[:4, 1:]) + + # def test_3Dstacked(self): + # # With Time + # source = np.random.rand(5, 4, 2) + # coords_src = Coordinates([[ + # np.arange(5)[:, None, None] + 0.1 * np.ones((5, 4, 2)), + # np.arange(4)[None, :, None] + 0.1 * np.ones((5, 4, 2)), + # np.arange(2)[None, None, :] + 0.1 * np.ones((5, 4, 2))]], ["lat_lon_time"]) + # coords_dst = Coordinates([np.arange(4)+0.2, np.arange(1, 4)-0.2, [0.5]], ["lat", "lon", "time"]) + # node = MockArrayDataSource( + # data=source, + # coordinates=coords_src, + # interpolation={ + # "method": "nearest", + # "interpolators": [NearestNeighbor], + # }, + # ) + # output = node.eval(coords_dst) + # np.testing.assert_array_equal(output, source[:4, 1:, :1]) + + # # Using 'xarray' coordinates type + # node = MockArrayDataSourceXR( + # data=source, + # coordinates=coords_src, + # coordinate_index_type='xarray', + # interpolation={ + # "method": "nearest", + # "interpolators": [NearestNeighbor], + # }, + # ) + # output = node.eval(coords_dst) + # np.testing.assert_array_equal(output, source[:4, 1:, :1]) + + # # Using 'slice' coordinates type + # node = MockArrayDataSource( + # data=source, + # coordinates=coords_src, + # coordinate_index_type='slice', + # interpolation={ + # "method": "nearest", + # "interpolators": [NearestNeighbor], + # }, + # ) + # output = node.eval(coords_dst) + # np.testing.assert_array_equal(output, source[:4, 1:, :1]) + + # # Without Time + # source = np.random.rand(5, 4) + # node = MockArrayDataSource( + # data=source, + # coordinates=coords_src.drop('time'), + # interpolation={ + # "method": "nearest", + # "interpolators": [NearestNeighbor], + # }, + # ) + # output = node.eval(coords_dst) + # np.testing.assert_array_equal(output, source[:4, 1:]) + class 
TestInterpolateRasterioInterpolator(object): """test interpolation functions""" diff --git a/podpac/datalib/egi.py b/podpac/datalib/egi.py index 97c6aedf..fe167ccd 100644 --- a/podpac/datalib/egi.py +++ b/podpac/datalib/egi.py @@ -105,11 +105,14 @@ class EGI(InterpolationMixin, DataSource): lat_key = tl.Unicode(allow_none=True).tag(attr=True) lon_key = tl.Unicode(allow_none=True).tag(attr=True) time_key = tl.Unicode(allow_none=True).tag(attr=True) + udims_overwrite = tl.List() min_bounds_span = tl.Dict(allow_none=True).tag(attr=True) @property def udims(self): + if self.udims_overwrite: + return self.udims_overwrite """ This needs to be implemented so this node will cache properly. See Datasource.eval.""" raise NotImplementedError