Skip to content

Commit

Permalink
Merge pull request #488 from creare-com/feature/stack2dinterp
Browse files Browse the repository at this point in the history
ENH: Enabling NN interpolation for nD stacked coordinates.
  • Loading branch information
mpu-creare authored Oct 11, 2021
2 parents 75670f6 + ffd0d13 commit 3a2a5b3
Show file tree
Hide file tree
Showing 5 changed files with 228 additions and 14 deletions.
40 changes: 37 additions & 3 deletions podpac/core/interpolation/interpolation_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from collections import OrderedDict
from six import string_types
import numpy as np
import xarray as xr
import traitlets as tl

from podpac.core import settings
Expand Down Expand Up @@ -502,9 +503,42 @@ def select_coordinates(self, source_coordinates, eval_coordinates, index_type="n
validate_crs=False,
)
if index_type == "numpy":
selected_coords_idx2 = np.ix_(*[np.ravel(selected_coords_idx[k]) for k in source_coordinates.dims])
elif index_type in ["slice", "xarray"]:
selected_coords_idx2 = tuple([selected_coords_idx[d] for d in source_coordinates.dims])
npcoords = []
has_stacked = False
for k in source_coordinates.dims:
# Deal with nD stacked source coords (marked by coords being in tuple)
if isinstance(selected_coords_idx[k], tuple):
has_stacked = True
npcoords.extend([sci for sci in selected_coords_idx[k]])
else:
npcoords.append(selected_coords_idx[k])
if has_stacked:
# When stacked coordinates are nD we cannot use the catchall of the next branch
selected_coords_idx2 = npcoords
else:
# This would not be needed if everything went as planned in
# interpolator.select_coordinates, but this is a catchall that works
# for 90% of the cases
selected_coords_idx2 = np.ix_(*[np.ravel(npc) for npc in npcoords])
elif index_type == "xarray":
selected_coords_idx2 = []
for i in selected_coords.dims:
# Deal with nD stacked source coords (marked by coords being in tuple)
if isinstance(selected_coords_idx[i], tuple):
selected_coords_idx2.extend([xr.DataArray(sci, dims=[i]) for sci in selected_coords_idx[i]])
else:
selected_coords_idx2.append(selected_coords_idx[i])
selected_coords_idx2 = tuple(selected_coords_idx2)
elif index_type == "slice":
selected_coords_idx2 = []
for i in selected_coords.dims:
# Deal with nD stacked source coords (marked by coords being in tuple)
if isinstance(selected_coords_idx[i], tuple):
selected_coords_idx2.extend(selected_coords_idx[i])
else:
selected_coords_idx2.append(selected_coords_idx[i])

selected_coords_idx2 = tuple(selected_coords_idx2)
else:
raise ValueError("Unknown index_type '%s'" % index_type)
return selected_coords, tuple(selected_coords_idx2)
Expand Down
41 changes: 35 additions & 6 deletions podpac/core/interpolation/nearest_neighbor_interpolator.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def is_stacked(d):
]

else:
bounds = {d: None for d in source_coordinates.udims}
bounds = None

if self.remove_nan:
# Eliminate nans from the source data. Note, this could turn a uniform gridded dataset into a stacked one
Expand All @@ -110,16 +110,38 @@ def is_stacked(d):
continue
source = source_coordinates[d]
if is_stacked(d):
bound = np.stack([bounds[dd] for dd in d.split("_")], axis=1)
if bounds is not None:
bound = np.stack([bounds[dd] for dd in d.split("_")], axis=1)
else:
bound = None
index = self._get_stacked_index(d, source, eval_coordinates, bound)

if len(source.shape) == 2: # Handle case of 2D-stacked coordinates
ncols = source.shape[1]
index1 = index // ncols
index1 = self._resize_stacked_index(index1, d, eval_coordinates)
# With nD stacked coordinates, there are 'n' indices in the tuple
# All of these need to get into the data_index, and in the right order
data_index.append(index1) # This is a hack
index = index % ncols # The second half can go through the usual machinery
elif len(source.shape) > 2: # Handle case of nD-stacked coordinates
raise NotImplementedError
index = self._resize_stacked_index(index, d, eval_coordinates)
elif source_coordinates[d].is_uniform:
request = eval_coordinates[d]
index = self._get_uniform_index(d, source, request, bounds[d])
if bounds is not None:
bound = bounds[d]
else:
bound = None
index = self._get_uniform_index(d, source, request, bound)
index = self._resize_unstacked_index(index, d, eval_coordinates)
else: # non-uniform coordinates... probably an optimization here
request = eval_coordinates[d]
index = self._get_nonuniform_index(d, source, request, bounds[d])
if bounds is not None:
bound = bounds[d]
else:
bound = None
index = self._get_nonuniform_index(d, source, request, bound)
index = self._resize_unstacked_index(index, d, eval_coordinates)

data_index.append(index)
Expand Down Expand Up @@ -219,7 +241,8 @@ def _get_stacked_index(self, dim, source, request, bounds=None):
scales = np.array([self._get_scale(d, time_source, time_request) for d in udims])[None, :]
tol = np.linalg.norm((tols * scales).squeeze())
src_coords, req_coords_diag = _higher_precision_time_stack(source, request, udims)
ckdtree_source = cKDTree(src_coords.T * scales)
# We need to unravel the nD stacked coordinates
ckdtree_source = cKDTree(src_coords.reshape(src_coords.shape[0], -1).T * scales)

# if the udims are all stacked in the same stack as part of the request coordinates, then we're done.
# Otherwise we have to evaluate each unstacked set of dimensions independently
Expand Down Expand Up @@ -252,7 +275,13 @@ def _get_stacked_index(self, dim, source, request, bounds=None):

if self.respect_bounds:
if bounds is None:
bounds = [src_coords.min(0), src_coords.max(0)]
bounds = np.stack(
[
src_coords.reshape(src_coords.shape[0], -1).T.min(0),
src_coords.reshape(src_coords.shape[0], -1).T.max(0),
],
axis=1,
)
# Fix order of bounds
bounds = bounds[:, [source.udims.index(dim) for dim in udims]]
index[np.any((req_coords > bounds[1]), axis=1) | np.any((req_coords < bounds[0]), axis=1)] = -1
Expand Down
27 changes: 22 additions & 5 deletions podpac/core/interpolation/selector.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,16 +100,25 @@ def select(self, source_coords, request_coords, index_type="numpy"):
for coord1d in source_coords._coords.values():
ci = self._select1d(coord1d, request_coords, index_type)
ci = np.sort(np.unique(ci))
if index_type == "slice":
if len(coord1d.shape) == 2: # Handle case of 2D-stacked coordinates
ncols = coord1d.shape[1]
ci = (ci // ncols, ci % ncols)
if index_type == "slice":
ci = tuple([_index2slice(cii) for cii in ci])
elif index_type == "slice":
ci = _index2slice(ci)
if len(coord1d.shape) == 3: # Handle case of 3D-stacked coordinates
raise NotImplementedError
c = coord1d[ci]
coords.append(c)
coords_inds.append(ci)
coords = Coordinates(coords)
if index_type == "numpy":
coords_inds = self._merge_indices(coords_inds, source_coords.dims, request_coords.dims)
elif index_type == "xarray":
pass # unlike numpy, xarray assumes indexes are orthogonal by default, so the 1d coordinates are already correct
# unlike numpy, xarray assumes indexes are orthogonal by default, so the 1d coordinates are already correct
# unless there are tuple coordinates (nD stacked coords) but those are handled in interpolation_manager
pass
return coords, tuple(coords_inds)

def _select1d(self, source, request, index_type):
Expand All @@ -127,11 +136,18 @@ def _select1d(self, source, request, index_type):
def _merge_indices(self, indices, source_dims, request_dims):
    """Reshape per-dimension indices so that numpy broadcasts them orthogonally.

    The index at position ``i`` is reshaped to have length 1 along every axis
    except axis ``i`` (which is left as ``-1``).

    Parameters
    ----------
    indices : sequence
        One entry per dimension. Each entry is either an array-like object with a
        ``reshape`` method, or a tuple of such objects (nD stacked coordinates).
    source_dims, request_dims
        Not used by this method.

    Returns
    -------
    tuple
        The reshaped indices, in the same order as ``indices``. Tuple entries stay
        tuples with every member reshaped the same way. The input sequence is not
        modified.
    """
    ndim = len(indices)
    new_indices = []
    for axis, index in enumerate(indices):
        # For numpy to broadcast correctly, each index only varies along its own axis
        shape = [1] * ndim
        shape[axis] = -1
        if isinstance(index, tuple):
            # nD stacked coordinates
            # This means the source has shape (N, M, ...)
            # But the coordinates are stacked (i.e. lat_lon with shape N, M for the lon and lat parts)
            new_indices.append(tuple(part.reshape(*shape) for part in index))
        else:
            new_indices.append(index.reshape(*shape))
    return tuple(new_indices)

def _select_uniform(self, source, request, index_type):
crds = request[source.name]
Expand Down Expand Up @@ -186,7 +202,8 @@ def _select_stacked(self, source, request, index_type):
inds = np.array([])
# Parts of the code below are duplicated in the NearestNeighbor interpolator
src_coords, req_coords_diag = _higher_precision_time_stack(source, request, udims)
ckdtree_source = cKDTree(src_coords.T)
# For nD stacked coordinates we need to unravel the stacked dimension
ckdtree_source = cKDTree(src_coords.reshape(src_coords.shape[0], -1).T)
if (len(indep_evals) + len(stacked)) <= 1:
req_coords = req_coords_diag.T
elif (len(stacked) == 0) | (len(indep_evals) == 0 and len(stacked) == len(udims)):
Expand Down
131 changes: 131 additions & 0 deletions podpac/core/interpolation/test/test_interpolators.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,15 @@ def get_data(self, coordinates, coordinates_index):
return self.create_output_array(coordinates, data=self.data[coordinates_index])


class MockArrayDataSourceXR(InterpolationMixin, DataSource):
    """Mock array data source that indexes a wrapped copy of its full data.

    ``get_data`` first wraps the complete ``data`` array with
    ``create_output_array`` using the node's own coordinates, then applies
    ``coordinates_index`` to that wrapped array rather than to the raw ``data``.
    """

    # Raw array holding the source values
    data = ArrayTrait().tag(attr=True)
    # Coordinates describing ``data``
    coordinates = tl.Instance(Coordinates).tag(attr=True)

    def get_data(self, coordinates, coordinates_index):
        """Return the output array for ``coordinates`` selected via ``coordinates_index``."""
        # Wrap the full source first so the index is applied to the wrapped array
        wrapped_source = self.create_output_array(self.coordinates, data=self.data)
        selection = wrapped_source[coordinates_index]
        return self.create_output_array(coordinates, data=selection.data)


class TestNone(object):
def test_none_select(self):
reqcoords = Coordinates([[-0.5, 1.5, 3.5], [0.5, 2.5, 4.5]], dims=["lat", "lon"])
Expand Down Expand Up @@ -564,6 +573,128 @@ def test_respect_bounds(self):
np.testing.assert_array_equal(output.data[1:], source[[0, 2]])
assert np.isnan(output.data[0])

def test_2Dstacked(self):
    """Check nearest-neighbor interpolation from a 2D-stacked ``lat_lon`` source.

    The source is evaluated on unstacked lat/lon/time coordinates with the default,
    'xarray' and 'slice' coordinate index types, and once more after dropping
    the time dimension from the source.
    """

    def nearest():
        # Build a fresh interpolation definition for every node
        return {
            "method": "nearest",
            "interpolators": [NearestNeighbor],
        }

    # With Time
    source = np.random.rand(5, 4, 2)
    offset = 0.1 * np.ones((5, 4))
    stacked_lat = np.arange(5)[:, None] + offset
    stacked_lon = np.arange(4)[None, :] + offset
    coords_src = Coordinates([[stacked_lat, stacked_lon], [0.4, 0.7]], ["lat_lon", "time"])
    coords_dst = Coordinates([np.arange(4) + 0.2, np.arange(1, 4) - 0.2, [0.5]], ["lat", "lon", "time"])

    # Default index type, then the 'xarray' and 'slice' coordinate index types
    cases = [
        (MockArrayDataSource, {}),
        (MockArrayDataSourceXR, {"coordinate_index_type": "xarray"}),
        (MockArrayDataSource, {"coordinate_index_type": "slice"}),
    ]
    for node_cls, extra_kwargs in cases:
        node = node_cls(data=source, coordinates=coords_src, interpolation=nearest(), **extra_kwargs)
        output = node.eval(coords_dst)
        np.testing.assert_array_equal(output, source[:4, 1:, :1])

    # Without Time
    source = np.random.rand(5, 4)
    node = MockArrayDataSource(
        data=source,
        coordinates=coords_src.drop("time"),
        interpolation=nearest(),
    )
    output = node.eval(coords_dst)
    np.testing.assert_array_equal(output, source[:4, 1:])

# def test_3Dstacked(self):
# # With Time
# source = np.random.rand(5, 4, 2)
# coords_src = Coordinates([[
# np.arange(5)[:, None, None] + 0.1 * np.ones((5, 4, 2)),
# np.arange(4)[None, :, None] + 0.1 * np.ones((5, 4, 2)),
# np.arange(2)[None, None, :] + 0.1 * np.ones((5, 4, 2))]], ["lat_lon_time"])
# coords_dst = Coordinates([np.arange(4)+0.2, np.arange(1, 4)-0.2, [0.5]], ["lat", "lon", "time"])
# node = MockArrayDataSource(
# data=source,
# coordinates=coords_src,
# interpolation={
# "method": "nearest",
# "interpolators": [NearestNeighbor],
# },
# )
# output = node.eval(coords_dst)
# np.testing.assert_array_equal(output, source[:4, 1:, :1])

# # Using 'xarray' coordinates type
# node = MockArrayDataSourceXR(
# data=source,
# coordinates=coords_src,
# coordinate_index_type='xarray',
# interpolation={
# "method": "nearest",
# "interpolators": [NearestNeighbor],
# },
# )
# output = node.eval(coords_dst)
# np.testing.assert_array_equal(output, source[:4, 1:, :1])

# # Using 'slice' coordinates type
# node = MockArrayDataSource(
# data=source,
# coordinates=coords_src,
# coordinate_index_type='slice',
# interpolation={
# "method": "nearest",
# "interpolators": [NearestNeighbor],
# },
# )
# output = node.eval(coords_dst)
# np.testing.assert_array_equal(output, source[:4, 1:, :1])

# # Without Time
# source = np.random.rand(5, 4)
# node = MockArrayDataSource(
# data=source,
# coordinates=coords_src.drop('time'),
# interpolation={
# "method": "nearest",
# "interpolators": [NearestNeighbor],
# },
# )
# output = node.eval(coords_dst)
# np.testing.assert_array_equal(output, source[:4, 1:])


class TestInterpolateRasterioInterpolator(object):
"""test interpolation functions"""
Expand Down
3 changes: 3 additions & 0 deletions podpac/datalib/egi.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,11 +105,14 @@ class EGI(InterpolationMixin, DataSource):
lat_key = tl.Unicode(allow_none=True).tag(attr=True)
lon_key = tl.Unicode(allow_none=True).tag(attr=True)
time_key = tl.Unicode(allow_none=True).tag(attr=True)
udims_overwrite = tl.List()

min_bounds_span = tl.Dict(allow_none=True).tag(attr=True)

@property
def udims(self):
    """Return ``udims_overwrite`` when it is non-empty; otherwise not implemented.

    This needs to be implemented so this node will cache properly. See Datasource.eval.

    Raises
    ------
    NotImplementedError
        If ``udims_overwrite`` is empty.
    """
    if self.udims_overwrite:
        return self.udims_overwrite
    raise NotImplementedError

Expand Down

0 comments on commit 3a2a5b3

Please sign in to comment.