Skip to content

Commit

Permalink
ENH: Merge pull request #485 from creare-com/feature/get_source_data
Browse files Browse the repository at this point in the history
Feature: general get_source_data for datasources and tile compositors.
  • Loading branch information
jmilloy authored Aug 3, 2021
2 parents dce22c5 + 91be103 commit 180af3d
Show file tree
Hide file tree
Showing 6 changed files with 116 additions and 0 deletions.
10 changes: 10 additions & 0 deletions doc/source/overview.md
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,16 @@ node = podpac.datalib.TerrainTiles(tile_format='geotiff', zoom=8)
# ... and more each release
```

Retrieve the raw source data array at full/native resolution. **Note**: Some data source are too large to fit in RAM, and calling this function can crash Python.

```python
# retrieve full source data
node.get_source_data()

# retrieve bounded source data
node.get_source_data(bounds={'lat': (40, 45), 'lon': (-70, -75)})
```

## Coordinates

Define geospatial and temporal dataset coordinates.
Expand Down
24 changes: 24 additions & 0 deletions podpac/core/compositor/test/test_tiled_compositor.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,3 +47,27 @@ def test_composition_stacked_multiindex_names(self):
np.testing.assert_array_equal(output["lat"], [3, 4, 5, 6])
np.testing.assert_array_equal(output["lon"], [3, 4, 5, 6])
np.testing.assert_array_equal(output, [103, 104, 200, 201])

def test_get_source_data(self):
a = ArrayRaw(source=np.arange(5) + 100, coordinates=podpac.Coordinates([[0, 1, 2, 3, 4]], dims=["lat"]))
b = ArrayRaw(source=np.arange(5) + 200, coordinates=podpac.Coordinates([[5, 6, 7, 8, 9]], dims=["lat"]))
c = ArrayRaw(source=np.arange(5) + 300, coordinates=podpac.Coordinates([[10, 11, 12, 13, 14]], dims=["lat"]))

node = TileCompositorRaw(sources=[a, b, c])

data = node.get_source_data()
np.testing.assert_array_equal(data["lat"], np.arange(15))
np.testing.assert_array_equal(data, np.hstack([source.source for source in node.sources]))

# with bounds
data = node.get_source_data({"lat": (2.5, 6.5)})
np.testing.assert_array_equal(data["lat"], [3, 4, 5, 6])
np.testing.assert_array_equal(data, [103, 104, 200, 201])

# error
with podpac.settings:
podpac.settings.set_unsafe_eval(True)
d = podpac.algorithm.Arithmetic(eqn="a+2", a=a)
node = TileCompositorRaw(sources=[a, b, c, d])
with pytest.raises(ValueError, match="Cannot get composited source data"):
node.get_source_data()
25 changes: 25 additions & 0 deletions podpac/core/compositor/tile_compositor.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,31 @@ def composite(self, coordinates, data_arrays, result=None):
return result
return res

def get_source_data(self, bounds={}):
"""
Get composited source data, without interpolation.
Arguments
---------
bounds : dict
Dictionary of bounds by dimension, optional.
Keys must be dimension names, and values are (min, max) tuples, e.g. ``{'lat': (10, 20)}``.
Returns
-------
data : UnitsDataArray
Source data
"""

if any(not hasattr(source, "get_source_data") for source in self.sources):
raise ValueError(
"Cannot get composited source data; all sources must have `get_source_data` implemented (such as nodes derived from a DataSource or TileCompositor node)."
)

coords = None # n/a
source_data_arrays = (source.get_source_data(bounds) for source in self.sources) # generator
return self.composite(coords, source_data_arrays)


class TileCompositor(InterpolationMixin, TileCompositorRaw):
pass
19 changes: 19 additions & 0 deletions podpac/core/data/datasource.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,6 +281,25 @@ def _get_data(self, rc, rci):
# Methods
# ------------------------------------------------------------------------------------------------------------------

def get_source_data(self, bounds={}):
"""
Get source data, without interpolation.
Arguments
---------
bounds : dict
Dictionary of bounds by dimension, optional.
Keys must be dimension names, and values are (min, max) tuples, e.g. ``{'lat': (10, 20)}``.
Returns
-------
data : UnitsDataArray
Source data
"""

coords, I = self.coordinates.select(bounds, return_index=True)
return self._get_data(coords, I)

def eval(self, coordinates, **kwargs):
"""
Wraps the super Node.eval method in order to cache with the correct coordinates.
Expand Down
20 changes: 20 additions & 0 deletions podpac/core/data/ogr.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,26 @@ def extents(self):
layer = self.datasource.GetLayerByName(self.layer)
return layer.GetExtent()

def get_source_data(self, bounds={}):
"""
Raise a user-friendly exception when calling get_source_data for this node.
Arguments
---------
bounds : dict
Dictionary of bounds by dimension, optional.
Keys must be dimension names, and values are (min, max) tuples, e.g. ``{'lat': (10, 20)}``.
raises
------
AttributeError : Cannot get source data for OGR datasources
"""

raise AttributeError(
"Cannot get source data for OGR datasources. "
"The source data is a vector-based shapefile without a native resolution."
)

@common_doc(COMMON_NODE_DOC)
def _eval(self, coordinates, output=None, _selector=None):
if "lat" not in coordinates.udims or "lon" not in coordinates.udims:
Expand Down
18 changes: 18 additions & 0 deletions podpac/core/data/test/test_datasource.py
Original file line number Diff line number Diff line change
Expand Up @@ -564,6 +564,24 @@ def test_eval_get_cache_transform_crs(self):
node.eval(node.coordinates.transform("EPSG:4326"))
assert node._from_cache

def test_get_source_data(self):
node = podpac.data.Array(
source=np.ones((3, 4)),
coordinates=podpac.Coordinates([range(3), range(4)], ["lat", "lon"]),
)

data = node.get_source_data()
np.testing.assert_array_equal(data, node.source)

def test_get_source_data_with_bounds(self):
node = podpac.data.Array(
source=np.ones((3, 4)),
coordinates=podpac.Coordinates([range(3), range(4)], ["lat", "lon"]),
)

data = node.get_source_data({"lon": (1.5, 4.5)})
np.testing.assert_array_equal(data, node.source[:, 2:])


class TestDataSourceWithMultipleOutputs(object):
def test_evaluate_no_overlap_with_output_extract_output(self):
Expand Down

0 comments on commit 180af3d

Please sign in to comment.