Skip to content

Commit

Permalink
FIX #320: update interpolation definition
Browse files Browse the repository at this point in the history
  • Loading branch information
mlshapiro committed Nov 1, 2019
1 parent e11b29a commit 05163b4
Show file tree
Hide file tree
Showing 4 changed files with 94 additions and 84 deletions.
38 changes: 18 additions & 20 deletions podpac/core/data/datasource.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,26 +99,24 @@
:attr:`podpac.data.INTERPOLATION_SHORTCUTS`. The interpolation method associated
with this string will be applied to all dimensions at the same time.
If input is a dict, the dict must contain one of two definitions:
1. A dictionary which contains the key ``'method'`` defining the interpolation method name.
If the interpolation method is not one of :attr:`podpac.data.INTERPOLATION_SHORTCUTS`, a
second key ``'interpolators'`` must be defined with a list of
:class:`podpac.interpolators.Interpolator` classes to use in order of uages.
The dictionary may contain an option ``'params'`` key which contains a dict of parameters to pass along to
the :class:`podpac.interpolators.Interpolator` classes associated with the interpolation method.
This interpolation definition will be applied to all dimensions.
2. A dictionary containing an ordered set of keys defining dimensions and values
defining the interpolation method to use with the dimensions.
The key must be a string or tuple of dimension names (i.e. ``'time'`` or ``('lat', 'lon')`` ).
The value can either be a string matching one of the interpolation shortcuts defined in
:attr:`podpac.data.INTERPOLATION_SHORTCUTS` or a dictionary meeting the previous requirements
(1). If the dictionary does not contain a key for all unstacked dimensions of the source coordinates, the
:attr:`podpac.data.INTERPOLATION_DEFAULT` value will be used.
All dimension keys must be unstacked even if the underlying coordinate dimensions are stacked.
Any extra dimensions included but not found in the source coordinates will be ignored.
If input is a :class:`podpac.data.Interpolation` class, this interpolation
If input is a dict or list of dict, the dict or dict elements must adhere to the following format:
The key ``'method'`` defining the interpolation method name.
If the interpolation method is not one of :attr:`podpac.data.INTERPOLATION_SHORTCUTS`, a
second key ``'interpolators'`` must be defined with a list of
:class:`podpac.interpolators.Interpolator` classes to use in order of uages.
The dictionary may contain an option ``'params'`` key which contains a dict of parameters to pass along to
the :class:`podpac.interpolators.Interpolator` classes associated with the interpolation method.
The dict may contain the key ``'dims'`` which specifies dimension names (i.e. ``'time'`` or ``('lat', 'lon')`` ).
If the dictionary does not contain a key for all unstacked dimensions of the source coordinates, the
:attr:`podpac.data.INTERPOLATION_DEFAULT` value will be used.
All dimension keys must be unstacked even if the underlying coordinate dimensions are stacked.
Any extra dimensions included but not found in the source coordinates will be ignored.
The dict may contain a key ``'params'`` that can be used to configure the :class:`podpac.interpolators.Interpolator` classes associated with the interpolation method.
If input is a :class:`podpac.data.Interpolation` class, this Interpolation
class will be used without modification.
""",
}
Expand Down
39 changes: 18 additions & 21 deletions podpac/core/data/interpolation.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ class Interpolation(object):

def __init__(self, definition=INTERPOLATION_DEFAULT):

self.definition = definition
self.definition = deepcopy(definition)
self.config = OrderedDict()

# if definition is None, set to default
Expand All @@ -126,37 +126,34 @@ def __init__(self, definition=INTERPOLATION_DEFAULT):
self.definition = INTERPOLATION_DEFAULT

# set each dim to interpolator definition
if isinstance(definition, dict):
if isinstance(definition, (dict, list)):

# covert input to an ordered dict to preserve order of dimensions
definition = OrderedDict(definition)
# convert dict to list
if isinstance(definition, dict):
definition = [definition]

for key in iter(definition):
for interp_definition in definition:

# if dict is a default definition, skip the rest of the handling
if not isinstance(key, tuple):
if key in ["method", "params", "interpolators"]:
method = self._parse_interpolation_method(definition)
self._set_interpolation_method(("default",), method)
break
# get interpolation method dict
method = self._parse_interpolation_method(interp_definition)

# if key is not a tuple, convert it to one and set it to the udims key
if not isinstance(key, tuple):
udims = (key,)
# specify dims
if "dims" in interp_definition:
if isinstance(interp_definition["dims"], list):
interp_definition["dims"].sort() # make sure the dims are always in the same order
udims = tuple(interp_definition["dims"])
else:
raise TypeError('The "dims" key of an interpolation definition must be a list')
else:
udims = key
udims = ("default",)

# make sure udims are not already specified in config
for config_dims in iter(self.config):
if set(config_dims) & set(udims):
raise InterpolationException(
'Dimensions "{}" cannot be defined '.format(udims)
+ "multiple times in interpolation definition {}".format(definition)
+ "multiple times in interpolation definition {}".format(interp_definition)
)

# get interpolation method
method = self._parse_interpolation_method(definition[key])

# add all udims to definition
self._set_interpolation_method(udims, method)

Expand All @@ -173,7 +170,7 @@ def __init__(self, definition=INTERPOLATION_DEFAULT):
else:
raise TypeError(
'"{}" is not a valid interpolation definition type. '.format(definition)
+ "Interpolation definiton must be a string or dict"
+ "Interpolation definiton must be a string or list of dicts"
)

# make sure ('default',) is always the last entry in config dictionary
Expand Down
98 changes: 55 additions & 43 deletions podpac/core/data/test/test_interpolate.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,20 +64,16 @@ def test_str_definition(self):

def test_dict_definition(self):

# should handle a default definition without any dimensions specified as keys
# should handle a default definition without any dimensions
interp = Interpolation({"method": "nearest", "params": {"spatial_tolerance": 1}})
assert isinstance(interp.config[("default",)], dict)
assert interp.config[("default",)]["method"] == "nearest"
assert isinstance(interp.config[("default",)]["interpolators"][0], Interpolator)
assert interp.config[("default",)]["params"] == {"spatial_tolerance": 1}

# should throw an error on _parse_interpolation_method(definition)
# if definition is not in INTERPOLATION_METHODS
with pytest.raises(InterpolationException):
Interpolation({("lat", "lon"): "test"})

# handle string methods
interp = Interpolation({("lat", "lon"): "nearest"})
interp = Interpolation({"method": "nearest", "dims": ["lat", "lon"]})
print (interp.config)
assert isinstance(interp.config[("lat", "lon")], dict)
assert interp.config[("lat", "lon")]["method"] == "nearest"
assert isinstance(interp.config[("default",)]["interpolators"][0], Interpolator)
Expand All @@ -87,34 +83,38 @@ def test_dict_definition(self):

# should throw an error if method is not in dict
with pytest.raises(InterpolationException):
Interpolation({("lat", "lon"): {"test": "test"}})
Interpolation([{"test": "test", "dims": ["lat", "lon"]}])

# should throw an error if method is not a string
with pytest.raises(InterpolationException):
Interpolation({("lat", "lon"): {"method": 5}})
Interpolation([{"method": 5, "dims": ["lat", "lon"]}])

# should throw an error if method is not one of the INTERPOLATION_METHODS and no interpolators defined
with pytest.raises(InterpolationException):
Interpolation({("lat", "lon"): {"method": "myinter"}})
Interpolation([{"method": "myinter", "dims": ["lat", "lon"]}])

# should throw an error if params is not a dict
with pytest.raises(TypeError):
Interpolation({("lat", "lon"): {"method": "nearest", "params": "test"}})
Interpolation([{"method": "nearest", "dims": ["lat", "lon"], "params": "test"}])

# should throw an error if interpolators is not a list
with pytest.raises(TypeError):
Interpolation({("lat", "lon"): {"method": "nearest", "interpolators": "test"}})
Interpolation([{"method": "nearest", "interpolators": "test", "dims": ["lat", "lon"]}])

# should throw an error if interpolators are not Interpolator classes
with pytest.raises(TypeError):
Interpolation({("lat", "lon"): {"method": "nearest", "interpolators": [NearestNeighbor, "test"]}})
Interpolation([{"method": "nearest", "interpolators": [NearestNeighbor, "test"], "dims": ["lat", "lon"]}])

# should throw an error if dimension is defined twice
with pytest.raises(InterpolationException):
Interpolation({("lat", "lon"): "nearest", "lat": "bilinear"})
Interpolation([{"method": "nearest", "dims": ["lat", "lon"]}, {"method": "bilinear", "dims": ["lat"]}])

# should throw an error if dimension is not a list
with pytest.raises(TypeError):
Interpolation([{"method": "nearest", "dims": "lat"}])

# should handle standard INTEPROLATION_SHORTCUTS
interp = Interpolation({("lat", "lon"): {"method": "nearest"}})
interp = Interpolation([{"method": "nearest", "dims": ["lat", "lon"]}])
assert isinstance(interp.config[("lat", "lon")], dict)
assert interp.config[("lat", "lon")]["method"] == "nearest"
assert isinstance(interp.config[("lat", "lon")]["interpolators"][0], Interpolator)
Expand All @@ -123,52 +123,64 @@ def test_dict_definition(self):
# should not allow custom methods if interpolators can't support
with pytest.raises(InterpolatorException):
interp = Interpolation(
{("lat", "lon"): {"method": "myinter", "interpolators": [NearestNeighbor, NearestPreview]}}
[{"method": "myinter", "interpolators": [NearestNeighbor, NearestPreview], "dims": ["lat", "lon"]}]
)

# should allow custom methods if interpolators can support
class MyInterp(Interpolator):
methods_supported = ["myinter"]

interp = Interpolation({("lat", "lon"): {"method": "myinter", "interpolators": [MyInterp]}})
interp = Interpolation([{"method": "myinter", "interpolators": [MyInterp], "dims": ["lat", "lon"]}])
assert interp.config[("lat", "lon")]["method"] == "myinter"
assert isinstance(interp.config[("lat", "lon")]["interpolators"][0], MyInterp)

# should allow params to be set
interp = Interpolation(
{("lat", "lon"): {"method": "myinter", "interpolators": [MyInterp], "params": {"spatial_tolerance": 5}}}
[
{
"method": "myinter",
"interpolators": [MyInterp],
"params": {"spatial_tolerance": 5},
"dims": ["lat", "lon"],
}
]
)

assert interp.config[("lat", "lon")]["params"] == {"spatial_tolerance": 5}

# set default equal to empty tuple
interp = Interpolation({"lat": "bilinear"})
interp = Interpolation([{"method": "bilinear", "dims": ["lat"]}])
assert interp.config[("default",)]["method"] == INTERPOLATION_DEFAULT

# use default with override if not all dimensions are supplied
interp = Interpolation({"lat": "bilinear", "default": "nearest"})
interp = Interpolation([{"method": "bilinear", "dims": ["lat"]}, "nearest"])
assert interp.config[("default",)]["method"] == "nearest"

# make sure default is always the last key in the ordered config dict
interp = Interpolation({"default": "nearest", "lat": "bilinear"})
interp = Interpolation(["nearest", {"method": "bilinear", "dims": ["lat"]}])
assert list(interp.config.keys())[-1] == ("default",)

# should sort the dims keys
interp = Interpolation(["nearest", {"method": "bilinear", "dims": ["lon", "lat"]}])
assert interp.config[("lat", "lon")]["method"] == "bilinear"

def test_init_interpolators(self):

# should set method
interp = Interpolation("nearest")
assert interp.config[("default",)]["interpolators"][0].method == "nearest"

# Interpolation init should init all interpolators in the list
interp = Interpolation({"default": {"method": "nearest", "params": {"spatial_tolerance": 1}}})
interp = Interpolation([{"method": "nearest", "params": {"spatial_tolerance": 1}}])
assert interp.config[("default",)]["interpolators"][0].spatial_tolerance == 1

# should throw TraitErrors defined by Interpolator
with pytest.raises(tl.TraitError):
Interpolation({"default": {"method": "nearest", "params": {"spatial_tolerance": "tol"}}})
Interpolation([{"method": "nearest", "params": {"spatial_tolerance": "tol"}}])

# should not allow undefined params
with pytest.warns(DeprecationWarning): # eventually, Traitlets will raise an exception here
interp = Interpolation({"default": {"method": "nearest", "params": {"myarg": 1}}})
interp = Interpolation([{"method": "nearest", "params": {"myarg": 1}}])
with pytest.raises(AttributeError):
assert interp.config[("default",)]["interpolators"][0].myarg == "tol"

Expand Down Expand Up @@ -212,10 +224,10 @@ def can_interpolate(self, udims, source_coordinates, eval_coordinates):
# set up a strange interpolation definition
# we want to interpolate (lat, lon) first, then after (time, alt)
interp = Interpolation(
{
("lat", "lon"): {"method": "myinterp", "interpolators": [LatLon, TimeLat]},
("time", "alt"): {"method": "myinterp", "interpolators": [TimeLat, Lon]},
}
[
{"method": "myinterp", "interpolators": [LatLon, TimeLat], "dims": ["lat", "lon"]},
{"method": "myinterp", "interpolators": [TimeLat, Lon], "dims": ["time", "alt"]},
]
)

# default = 'nearest', which will return NearestPreview for can_select
Expand Down Expand Up @@ -271,10 +283,10 @@ def select_coordinates(self, udims, srccoords, srccoords_idx, reqcoords):
# set up a strange interpolation definition
# we want to interpolate (lat, lon) first, then after (time, alt)
interp = Interpolation(
{
("lat", "lon"): {"method": "myinterp", "interpolators": [LatLon, TimeLat]},
("time", "alt"): {"method": "myinterp", "interpolators": [TimeLat, Lon]},
}
[
{"method": "myinterp", "interpolators": [LatLon, TimeLat], "dims": ["lat", "lon"]},
{"method": "myinterp", "interpolators": [TimeLat, Lon], "dims": ["time", "alt"]},
]
)

coords, cidx = interp.select_coordinates(srccoords, [], reqcoords)
Expand All @@ -301,7 +313,7 @@ def interpolate(self, udims, source_coordinates, source_data, eval_coordinates,
np.zeros(srcdata.shape), coords=[reqcoords[c].coordinates for c in reqcoords], dims=reqcoords.dims
)

interp = Interpolation({("lat", "lon"): {"method": "myinterp", "interpolators": [TestInterp]}})
interp = Interpolation({"method": "myinterp", "interpolators": [TestInterp], "dims": ["lat", "lon"]})
outdata = interp.interpolate(srccoords, srcdata, reqcoords, outdata)

assert np.all(outdata == srcdata)
Expand All @@ -322,7 +334,7 @@ def interpolate(self, udims, source_coordinates, source_data, eval_coordinates,
np.zeros(srcdata.shape), coords=[reqcoords[c].coordinates for c in reqcoords], dims=reqcoords.dims
)

interp = Interpolation({("lat", "lon"): {"method": "myinterp", "interpolators": [TestFakeInterp]}})
interp = Interpolation({"method": "myinterp", "interpolators": [TestFakeInterp], "dims": ["lat", "lon"]})
outdata = interp.interpolate(srccoords, srcdata, reqcoords, outdata)

assert np.all(outdata == srcdata)
Expand Down Expand Up @@ -387,7 +399,9 @@ def test_nearest_preview_select(self):
reqcoords = Coordinates([[-0.5, 1.5, 3.5], [0.5, 2.5, 4.5]], dims=["lat", "lon"])
srccoords = Coordinates([[0, 1, 2, 3, 4, 5], [0, 1, 2, 3, 4, 5]], dims=["lat", "lon"])

interp = Interpolation({"lat": "nearest_preview", "lon": "nearest_preview"})
interp = Interpolation(
[{"method": "nearest_preview", "dims": ["lat"]}, {"method": "nearest_preview", "dims": ["lon"]}]
)

srccoords, srccoords_index = srccoords.intersect(reqcoords, outer=True, return_indices=True)
coords, cidx = interp.select_coordinates(srccoords, srccoords_index, reqcoords)
Expand Down Expand Up @@ -481,14 +495,14 @@ def test_spatial_tolerance(self):
node = MockArrayDataSource(
source=source,
native_coordinates=coords_src,
interpolation={"default": {"method": "nearest", "params": {"spatial_tolerance": 1.1}}},
interpolation={"method": "nearest", "params": {"spatial_tolerance": 1.1}},
)

coords_dst = Coordinates([[1, 1.2, 1.5, 5, 9]], dims=["lat"])
output = node.eval(coords_dst)

print(output)
print(source)
print (output)
print (source)
assert isinstance(output, UnitsDataArray)
assert np.all(output.lat.values == coords_dst.coords["lat"])
assert output.values[0] == source[0] and np.isnan(output.values[1]) and output.values[2] == source[1]
Expand All @@ -504,10 +518,8 @@ def test_time_tolerance(self):
source=source,
native_coordinates=coords_src,
interpolation={
"default": {
"method": "nearest",
"params": {"spatial_tolerance": 1.1, "time_tolerance": np.timedelta64(1, "D")},
}
"method": "nearest",
"params": {"spatial_tolerance": 1.1, "time_tolerance": np.timedelta64(1, "D")},
},
)

Expand Down Expand Up @@ -601,7 +613,7 @@ def test_interpolate_scipy_grid(self):

assert isinstance(output, UnitsDataArray)
assert np.all(output.lat.values == coords_dst.coords["lat"])
print(output)
print (output)
assert output.data[0, 0] == 0.0
assert output.data[0, 3] == 3.0
assert output.data[1, 3] == 8.0
Expand Down
3 changes: 3 additions & 0 deletions podpac/core/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,9 @@ def default(self, obj):
elif isinstance(obj, podpac.core.style.Style):
return obj.definition

elif isinstance(obj, podpac.data.Interpolation):
return obj.definition

# pint Units
elif isinstance(obj, podpac.core.units.ureg.Unit):
return str(obj)
Expand Down

0 comments on commit 05163b4

Please sign in to comment.