From 05163b44ca7ab66729150991bf7a6995382e3c35 Mon Sep 17 00:00:00 2001 From: Marc Shapiro Date: Fri, 1 Nov 2019 15:02:24 -0400 Subject: [PATCH] FIX #320: update interpolation definition --- podpac/core/data/datasource.py | 38 +++++---- podpac/core/data/interpolation.py | 39 +++++---- podpac/core/data/test/test_interpolate.py | 98 +++++++++++++---------- podpac/core/utils.py | 3 + 4 files changed, 94 insertions(+), 84 deletions(-) diff --git a/podpac/core/data/datasource.py b/podpac/core/data/datasource.py index bb1bcd16f..cab0d03ec 100644 --- a/podpac/core/data/datasource.py +++ b/podpac/core/data/datasource.py @@ -99,26 +99,24 @@ :attr:`podpac.data.INTERPOLATION_SHORTCUTS`. The interpolation method associated with this string will be applied to all dimensions at the same time. - If input is a dict, the dict must contain one of two definitions: - - 1. A dictionary which contains the key ``'method'`` defining the interpolation method name. - If the interpolation method is not one of :attr:`podpac.data.INTERPOLATION_SHORTCUTS`, a - second key ``'interpolators'`` must be defined with a list of - :class:`podpac.interpolators.Interpolator` classes to use in order of uages. - The dictionary may contain an option ``'params'`` key which contains a dict of parameters to pass along to - the :class:`podpac.interpolators.Interpolator` classes associated with the interpolation method. - This interpolation definition will be applied to all dimensions. - 2. A dictionary containing an ordered set of keys defining dimensions and values - defining the interpolation method to use with the dimensions. - The key must be a string or tuple of dimension names (i.e. ``'time'`` or ``('lat', 'lon')`` ). - The value can either be a string matching one of the interpolation shortcuts defined in - :attr:`podpac.data.INTERPOLATION_SHORTCUTS` or a dictionary meeting the previous requirements - (1). 
If the dictionary does not contain a key for all unstacked dimensions of the source coordinates, the - :attr:`podpac.data.INTERPOLATION_DEFAULT` value will be used. - All dimension keys must be unstacked even if the underlying coordinate dimensions are stacked. - Any extra dimensions included but not found in the source coordinates will be ignored. - - If input is a :class:`podpac.data.Interpolation` class, this interpolation + If input is a dict or list of dicts, the dict or dict elements must adhere to the following format: + + The key ``'method'`` defining the interpolation method name. + If the interpolation method is not one of :attr:`podpac.data.INTERPOLATION_SHORTCUTS`, a + second key ``'interpolators'`` must be defined with a list of + :class:`podpac.interpolators.Interpolator` classes to use in order of usage. + The dictionary may contain an optional ``'params'`` key which contains a dict of parameters to pass along to + the :class:`podpac.interpolators.Interpolator` classes associated with the interpolation method. + + The dict may contain the key ``'dims'`` which specifies dimension names as a list (e.g. ``['time']`` or ``['lat', 'lon']``); any other type raises a ``TypeError``. + If the definition does not specify ``'dims'`` covering all unstacked dimensions of the source coordinates, the + :attr:`podpac.data.INTERPOLATION_DEFAULT` value will be used. + All dimension names must be unstacked even if the underlying coordinate dimensions are stacked. + Any extra dimensions included but not found in the source coordinates will be ignored. + + The dict may contain a key ``'params'`` that can be used to configure the :class:`podpac.interpolators.Interpolator` classes associated with the interpolation method. + + If input is a :class:`podpac.data.Interpolation` class, this Interpolation class will be used without modification. 
""", } diff --git a/podpac/core/data/interpolation.py b/podpac/core/data/interpolation.py index f02b495b6..bd00895b1 100644 --- a/podpac/core/data/interpolation.py +++ b/podpac/core/data/interpolation.py @@ -116,7 +116,7 @@ class Interpolation(object): def __init__(self, definition=INTERPOLATION_DEFAULT): - self.definition = definition + self.definition = deepcopy(definition) self.config = OrderedDict() # if definition is None, set to default @@ -126,37 +126,34 @@ def __init__(self, definition=INTERPOLATION_DEFAULT): self.definition = INTERPOLATION_DEFAULT # set each dim to interpolator definition - if isinstance(definition, dict): + if isinstance(definition, (dict, list)): - # covert input to an ordered dict to preserve order of dimensions - definition = OrderedDict(definition) + # convert dict to list + if isinstance(definition, dict): + definition = [definition] - for key in iter(definition): + for interp_definition in definition: - # if dict is a default definition, skip the rest of the handling - if not isinstance(key, tuple): - if key in ["method", "params", "interpolators"]: - method = self._parse_interpolation_method(definition) - self._set_interpolation_method(("default",), method) - break + # get interpolation method dict + method = self._parse_interpolation_method(interp_definition) - # if key is not a tuple, convert it to one and set it to the udims key - if not isinstance(key, tuple): - udims = (key,) + # specify dims + if "dims" in interp_definition: + if isinstance(interp_definition["dims"], list): + interp_definition["dims"].sort() # make sure the dims are always in the same order + udims = tuple(interp_definition["dims"]) + else: + raise TypeError('The "dims" key of an interpolation definition must be a list') else: - udims = key + udims = ("default",) # make sure udims are not already specified in config for config_dims in iter(self.config): if set(config_dims) & set(udims): raise InterpolationException( 'Dimensions "{}" cannot be defined 
'.format(udims) - + "multiple times in interpolation definition {}".format(definition) + + "multiple times in interpolation definition {}".format(interp_definition) ) - - # get interpolation method - method = self._parse_interpolation_method(definition[key]) - # add all udims to definition self._set_interpolation_method(udims, method) @@ -173,7 +170,7 @@ def __init__(self, definition=INTERPOLATION_DEFAULT): else: raise TypeError( '"{}" is not a valid interpolation definition type. '.format(definition) - + "Interpolation definiton must be a string or dict" + + "Interpolation definiton must be a string or list of dicts" ) # make sure ('default',) is always the last entry in config dictionary diff --git a/podpac/core/data/test/test_interpolate.py b/podpac/core/data/test/test_interpolate.py index df17d7874..8536e94fa 100644 --- a/podpac/core/data/test/test_interpolate.py +++ b/podpac/core/data/test/test_interpolate.py @@ -64,20 +64,16 @@ def test_str_definition(self): def test_dict_definition(self): - # should handle a default definition without any dimensions specified as keys + # should handle a default definition without any dimensions interp = Interpolation({"method": "nearest", "params": {"spatial_tolerance": 1}}) assert isinstance(interp.config[("default",)], dict) assert interp.config[("default",)]["method"] == "nearest" assert isinstance(interp.config[("default",)]["interpolators"][0], Interpolator) assert interp.config[("default",)]["params"] == {"spatial_tolerance": 1} - # should throw an error on _parse_interpolation_method(definition) - # if definition is not in INTERPOLATION_METHODS - with pytest.raises(InterpolationException): - Interpolation({("lat", "lon"): "test"}) - # handle string methods - interp = Interpolation({("lat", "lon"): "nearest"}) + interp = Interpolation({"method": "nearest", "dims": ["lat", "lon"]}) + print (interp.config) assert isinstance(interp.config[("lat", "lon")], dict) assert interp.config[("lat", "lon")]["method"] == "nearest" 
assert isinstance(interp.config[("default",)]["interpolators"][0], Interpolator) @@ -87,34 +83,38 @@ def test_dict_definition(self): # should throw an error if method is not in dict with pytest.raises(InterpolationException): - Interpolation({("lat", "lon"): {"test": "test"}}) + Interpolation([{"test": "test", "dims": ["lat", "lon"]}]) # should throw an error if method is not a string with pytest.raises(InterpolationException): - Interpolation({("lat", "lon"): {"method": 5}}) + Interpolation([{"method": 5, "dims": ["lat", "lon"]}]) # should throw an error if method is not one of the INTERPOLATION_METHODS and no interpolators defined with pytest.raises(InterpolationException): - Interpolation({("lat", "lon"): {"method": "myinter"}}) + Interpolation([{"method": "myinter", "dims": ["lat", "lon"]}]) # should throw an error if params is not a dict with pytest.raises(TypeError): - Interpolation({("lat", "lon"): {"method": "nearest", "params": "test"}}) + Interpolation([{"method": "nearest", "dims": ["lat", "lon"], "params": "test"}]) # should throw an error if interpolators is not a list with pytest.raises(TypeError): - Interpolation({("lat", "lon"): {"method": "nearest", "interpolators": "test"}}) + Interpolation([{"method": "nearest", "interpolators": "test", "dims": ["lat", "lon"]}]) # should throw an error if interpolators are not Interpolator classes with pytest.raises(TypeError): - Interpolation({("lat", "lon"): {"method": "nearest", "interpolators": [NearestNeighbor, "test"]}}) + Interpolation([{"method": "nearest", "interpolators": [NearestNeighbor, "test"], "dims": ["lat", "lon"]}]) # should throw an error if dimension is defined twice with pytest.raises(InterpolationException): - Interpolation({("lat", "lon"): "nearest", "lat": "bilinear"}) + Interpolation([{"method": "nearest", "dims": ["lat", "lon"]}, {"method": "bilinear", "dims": ["lat"]}]) + + # should throw an error if dimension is not a list + with pytest.raises(TypeError): + Interpolation([{"method": 
"nearest", "dims": "lat"}]) # should handle standard INTEPROLATION_SHORTCUTS - interp = Interpolation({("lat", "lon"): {"method": "nearest"}}) + interp = Interpolation([{"method": "nearest", "dims": ["lat", "lon"]}]) assert isinstance(interp.config[("lat", "lon")], dict) assert interp.config[("lat", "lon")]["method"] == "nearest" assert isinstance(interp.config[("lat", "lon")]["interpolators"][0], Interpolator) @@ -123,35 +123,47 @@ def test_dict_definition(self): # should not allow custom methods if interpolators can't support with pytest.raises(InterpolatorException): interp = Interpolation( - {("lat", "lon"): {"method": "myinter", "interpolators": [NearestNeighbor, NearestPreview]}} + [{"method": "myinter", "interpolators": [NearestNeighbor, NearestPreview], "dims": ["lat", "lon"]}] ) # should allow custom methods if interpolators can support class MyInterp(Interpolator): methods_supported = ["myinter"] - interp = Interpolation({("lat", "lon"): {"method": "myinter", "interpolators": [MyInterp]}}) + interp = Interpolation([{"method": "myinter", "interpolators": [MyInterp], "dims": ["lat", "lon"]}]) assert interp.config[("lat", "lon")]["method"] == "myinter" assert isinstance(interp.config[("lat", "lon")]["interpolators"][0], MyInterp) # should allow params to be set interp = Interpolation( - {("lat", "lon"): {"method": "myinter", "interpolators": [MyInterp], "params": {"spatial_tolerance": 5}}} + [ + { + "method": "myinter", + "interpolators": [MyInterp], + "params": {"spatial_tolerance": 5}, + "dims": ["lat", "lon"], + } + ] ) + assert interp.config[("lat", "lon")]["params"] == {"spatial_tolerance": 5} # set default equal to empty tuple - interp = Interpolation({"lat": "bilinear"}) + interp = Interpolation([{"method": "bilinear", "dims": ["lat"]}]) assert interp.config[("default",)]["method"] == INTERPOLATION_DEFAULT # use default with override if not all dimensions are supplied - interp = Interpolation({"lat": "bilinear", "default": "nearest"}) + interp = 
Interpolation([{"method": "bilinear", "dims": ["lat"]}, "nearest"]) assert interp.config[("default",)]["method"] == "nearest" # make sure default is always the last key in the ordered config dict - interp = Interpolation({"default": "nearest", "lat": "bilinear"}) + interp = Interpolation(["nearest", {"method": "bilinear", "dims": ["lat"]}]) assert list(interp.config.keys())[-1] == ("default",) + # should sort the dims keys + interp = Interpolation(["nearest", {"method": "bilinear", "dims": ["lon", "lat"]}]) + assert interp.config[("lat", "lon")]["method"] == "bilinear" + def test_init_interpolators(self): # should set method @@ -159,16 +171,16 @@ def test_init_interpolators(self): assert interp.config[("default",)]["interpolators"][0].method == "nearest" # Interpolation init should init all interpolators in the list - interp = Interpolation({"default": {"method": "nearest", "params": {"spatial_tolerance": 1}}}) + interp = Interpolation([{"method": "nearest", "params": {"spatial_tolerance": 1}}]) assert interp.config[("default",)]["interpolators"][0].spatial_tolerance == 1 # should throw TraitErrors defined by Interpolator with pytest.raises(tl.TraitError): - Interpolation({"default": {"method": "nearest", "params": {"spatial_tolerance": "tol"}}}) + Interpolation([{"method": "nearest", "params": {"spatial_tolerance": "tol"}}]) # should not allow undefined params with pytest.warns(DeprecationWarning): # eventually, Traitlets will raise an exception here - interp = Interpolation({"default": {"method": "nearest", "params": {"myarg": 1}}}) + interp = Interpolation([{"method": "nearest", "params": {"myarg": 1}}]) with pytest.raises(AttributeError): assert interp.config[("default",)]["interpolators"][0].myarg == "tol" @@ -212,10 +224,10 @@ def can_interpolate(self, udims, source_coordinates, eval_coordinates): # set up a strange interpolation definition # we want to interpolate (lat, lon) first, then after (time, alt) interp = Interpolation( - { - ("lat", "lon"): 
{"method": "myinterp", "interpolators": [LatLon, TimeLat]}, - ("time", "alt"): {"method": "myinterp", "interpolators": [TimeLat, Lon]}, - } + [ + {"method": "myinterp", "interpolators": [LatLon, TimeLat], "dims": ["lat", "lon"]}, + {"method": "myinterp", "interpolators": [TimeLat, Lon], "dims": ["time", "alt"]}, + ] ) # default = 'nearest', which will return NearestPreview for can_select @@ -271,10 +283,10 @@ def select_coordinates(self, udims, srccoords, srccoords_idx, reqcoords): # set up a strange interpolation definition # we want to interpolate (lat, lon) first, then after (time, alt) interp = Interpolation( - { - ("lat", "lon"): {"method": "myinterp", "interpolators": [LatLon, TimeLat]}, - ("time", "alt"): {"method": "myinterp", "interpolators": [TimeLat, Lon]}, - } + [ + {"method": "myinterp", "interpolators": [LatLon, TimeLat], "dims": ["lat", "lon"]}, + {"method": "myinterp", "interpolators": [TimeLat, Lon], "dims": ["time", "alt"]}, + ] ) coords, cidx = interp.select_coordinates(srccoords, [], reqcoords) @@ -301,7 +313,7 @@ def interpolate(self, udims, source_coordinates, source_data, eval_coordinates, np.zeros(srcdata.shape), coords=[reqcoords[c].coordinates for c in reqcoords], dims=reqcoords.dims ) - interp = Interpolation({("lat", "lon"): {"method": "myinterp", "interpolators": [TestInterp]}}) + interp = Interpolation({"method": "myinterp", "interpolators": [TestInterp], "dims": ["lat", "lon"]}) outdata = interp.interpolate(srccoords, srcdata, reqcoords, outdata) assert np.all(outdata == srcdata) @@ -322,7 +334,7 @@ def interpolate(self, udims, source_coordinates, source_data, eval_coordinates, np.zeros(srcdata.shape), coords=[reqcoords[c].coordinates for c in reqcoords], dims=reqcoords.dims ) - interp = Interpolation({("lat", "lon"): {"method": "myinterp", "interpolators": [TestFakeInterp]}}) + interp = Interpolation({"method": "myinterp", "interpolators": [TestFakeInterp], "dims": ["lat", "lon"]}) outdata = interp.interpolate(srccoords, srcdata, 
reqcoords, outdata) assert np.all(outdata == srcdata) @@ -387,7 +399,9 @@ def test_nearest_preview_select(self): reqcoords = Coordinates([[-0.5, 1.5, 3.5], [0.5, 2.5, 4.5]], dims=["lat", "lon"]) srccoords = Coordinates([[0, 1, 2, 3, 4, 5], [0, 1, 2, 3, 4, 5]], dims=["lat", "lon"]) - interp = Interpolation({"lat": "nearest_preview", "lon": "nearest_preview"}) + interp = Interpolation( + [{"method": "nearest_preview", "dims": ["lat"]}, {"method": "nearest_preview", "dims": ["lon"]}] + ) srccoords, srccoords_index = srccoords.intersect(reqcoords, outer=True, return_indices=True) coords, cidx = interp.select_coordinates(srccoords, srccoords_index, reqcoords) @@ -481,14 +495,14 @@ def test_spatial_tolerance(self): node = MockArrayDataSource( source=source, native_coordinates=coords_src, - interpolation={"default": {"method": "nearest", "params": {"spatial_tolerance": 1.1}}}, + interpolation={"method": "nearest", "params": {"spatial_tolerance": 1.1}}, ) coords_dst = Coordinates([[1, 1.2, 1.5, 5, 9]], dims=["lat"]) output = node.eval(coords_dst) - print(output) - print(source) + print (output) + print (source) assert isinstance(output, UnitsDataArray) assert np.all(output.lat.values == coords_dst.coords["lat"]) assert output.values[0] == source[0] and np.isnan(output.values[1]) and output.values[2] == source[1] @@ -504,10 +518,8 @@ def test_time_tolerance(self): source=source, native_coordinates=coords_src, interpolation={ - "default": { - "method": "nearest", - "params": {"spatial_tolerance": 1.1, "time_tolerance": np.timedelta64(1, "D")}, - } + "method": "nearest", + "params": {"spatial_tolerance": 1.1, "time_tolerance": np.timedelta64(1, "D")}, }, ) @@ -601,7 +613,7 @@ def test_interpolate_scipy_grid(self): assert isinstance(output, UnitsDataArray) assert np.all(output.lat.values == coords_dst.coords["lat"]) - print(output) + print (output) assert output.data[0, 0] == 0.0 assert output.data[0, 3] == 3.0 assert output.data[1, 3] == 8.0 diff --git a/podpac/core/utils.py 
b/podpac/core/utils.py index 6811de09d..bb0929ed3 100644 --- a/podpac/core/utils.py +++ b/podpac/core/utils.py @@ -233,6 +233,9 @@ def default(self, obj): elif isinstance(obj, podpac.core.style.Style): return obj.definition + elif isinstance(obj, podpac.data.Interpolation): + return obj.definition + # pint Units elif isinstance(obj, podpac.core.units.ureg.Unit): return str(obj)