diff --git a/podpac/core/coordinates/coordinates.py b/podpac/core/coordinates/coordinates.py index fe1da0a86..292247eec 100644 --- a/podpac/core/coordinates/coordinates.py +++ b/podpac/core/coordinates/coordinates.py @@ -20,6 +20,7 @@ import xarray.core.coordinates from six import string_types import pyproj +import logging import podpac from podpac.core.settings import settings @@ -32,6 +33,14 @@ from podpac.core.coordinates.dependent_coordinates import DependentCoordinates from podpac.core.coordinates.rotated_coordinates import RotatedCoordinates +# Optional dependencies +from lazy_import import lazy_module, lazy_class + +rasterio = lazy_module("rasterio") + +# Set up logging +_logger = logging.getLogger(__name__) + class Coordinates(tl.HasTraits): """ @@ -442,6 +451,42 @@ def from_url(cls, url): return cls.from_definition(coords) + @classmethod + def from_geotransform(cls, geotransform, shape, crs=None): + """ Creates Coordinates from GDAL Geotransform. + + """ + # Handle the case of rotated coordinates + try: + rcoords = RotatedCoordinates.from_geotransform(geotransform) + coords = Coordinates([rcoords], dims=["lat,lon"], crs=crs) + except: + rcoords = None + _logger.debug("Rasterio source dataset does not have Rotated Coordinates") + + if rcoords is not None and rcoords.theta != 0: + # These are Rotated coordinates and we can return + return coords + + # Handle the case of uniform coordinates (not rotated, but N-S E-W aligned) + affine = rasterio.Affine.from_gdal(*geotransform) + if affine.e == affine.a == 0: + order = -1 + step = np.array([affine.d, affine.b]) + else: + order = 1 + step = np.array([affine.e, affine.a]) + + origin = affine.f + step[0] / 2, affine.c + step[1] / 2 + end = origin[0] + step[0] * (shape[::order][0] - 1), origin[1] + step[1] * (shape[::order][1] - 1) + coords = Coordinates( + [ + podpac.clinspace(origin[0], end[0], shape[::order][0], "lat"), + podpac.clinspace(origin[1], end[1], shape[::order][1], "lon"), + ][::order] + ) + return coords + @classmethod def from_definition(cls, d): """ @@ -692,6 +737,10 @@ def shape(self): return tuple(size for c in self._coords.values() for size in c.shape) + @property + def ushape(self): + return tuple(self[dim].size for dim in self.udims) + @property def ndim(self): """:int: Number of dimensions. """ @@ -801,6 +850,53 @@ def hash(self): json_d = json.dumps(self.full_definition, separators=(",", ":"), cls=podpac.core.utils.JSONEncoder) return hash_alg(json_d.encode("utf-8")).hexdigest() + @property + def geotransform(self): + """ :tuple: GDAL geotransform. """ + # Make sure we only have 1 time and alt dimension + if "time" in self.udims and self["time"].size > 1: + raise TypeError( + 'Only 2-D coordinates have a GDAL transform. This array has a "time" dimension of {} > 1'.format( + self["time"].size + ) + ) + if "alt" in self.udims and self["alt"].size > 1: + raise TypeError( + 'Only 2-D coordinates have a GDAL transform. This array has a "alt" dimension of {} > 1'.format( + self["alt"].size + ) + ) + + # Do the uniform coordinates case + if ( + "lat" in self.dims + and isinstance(self._coords["lat"], UniformCoordinates1d) + and "lon" in self.dims + and isinstance(self._coords["lon"], UniformCoordinates1d) + ): + if self.dims.index("lon") < self.dims.index("lat"): + first, second = "lat", "lon" + else: + first, second = "lon", "lat" # This case will have the exact correct geotransform + transform = rasterio.transform.Affine.translation( + self[first].start - self[first].step / 2, self[second].start - self[second].step / 2 + ) * rasterio.transform.Affine.scale(self[first].step, self[second].step) + transform = transform.to_gdal() + elif "lat,lon" in self.dims and isinstance(self._coords["lat,lon"], RotatedCoordinates): + transform = self._coords["lat,lon"].geotransform + elif "lon,lat" in self.dims and isinstance(self._coords["lon,lat"], RotatedCoordinates): + transform = self._coords["lon,lat"].geotransform + else: + raise TypeError( + "Only 2-D coordinates that are uniform or rotated have a GDAL transform. These coordinates " + "{} do not.".format(self) + ) + if self.udims.index("lon") < self.udims.index("lat"): + # transform = (transform[3], transform[5], transform[4], transform[0], transform[2], transform[1]) + transform = transform[3:] + transform[:3] + + return transform + # ------------------------------------------------------------------------------------------------------------------ # Methods # ------------------------------------------------------------------------------------------------------------------ diff --git a/podpac/core/coordinates/rotated_coordinates.py b/podpac/core/coordinates/rotated_coordinates.py index 4b1dd5524..a911f400c 100644 --- a/podpac/core/coordinates/rotated_coordinates.py +++ b/podpac/core/coordinates/rotated_coordinates.py @@ -116,10 +116,11 @@ def _validate_step(self, d): @classmethod def from_geotransform(cls, geotransform, shape, dims=None, ctypes=None, segment_lengths=None): affine = rasterio.Affine.from_gdal(*geotransform) - origin = affine.c, affine.f + origin = affine.f, affine.c deg = affine.rotation_angle scale = ~affine.rotation(deg) * ~affine.translation(*origin) * affine - step = np.array([scale.a, scale.e]) + step = np.array([scale.e, scale.a]) + origin = affine.f + step[0] / 2, affine.c + step[1] / 2 return cls(shape, np.deg2rad(deg), origin, step, dims=dims, ctypes=ctypes, segment_lengths=segment_lengths) @classmethod @@ -237,8 +238,15 @@ def corner(self): @property def geotransform(self): - """ :tuple: GDAL geotransform. """ - return self.affine.to_gdal() + """ :tuple: GDAL geotransform. + Note: This property may not provide the correct order of lat/lon in the geotransform as this class does not + always have knowledge of the dimension order of the specified dataset. As such it always supplies + geotransforms assuming that dims = ['lat', 'lon'] + """ + t = rasterio.Affine.translation(self.origin[1] - self.step[1] / 2, self.origin[0] - self.step[0] / 2) + r = rasterio.Affine.rotation(self.deg) + s = rasterio.Affine.scale(*self.step[::-1]) + return (t * r * s).to_gdal() @property def coordinates(self): diff --git a/podpac/core/coordinates/test/test_coordinates.py b/podpac/core/coordinates/test/test_coordinates.py index 449e1c32c..673f91059 100644 --- a/podpac/core/coordinates/test/test_coordinates.py +++ b/podpac/core/coordinates/test/test_coordinates.py @@ -1618,3 +1618,115 @@ def test_concat_crs(self): with pytest.raises(ValueError, match="Cannot concat Coordinates"): concat([c1, c2]) + + +class TestCoordinatesGeoTransform(object): + def uniform_working(self): + # order: -lat, lon + c = Coordinates([clinspace(1.5, 0.5, 5, "lat"), clinspace(1, 2, 9, "lon")]) + tf = np.array(c.geotransform).reshape(2, 3) + np.testing.assert_almost_equal( + tf, np.array([[c["lon"].area_bounds[0], c["lon"].step, 0], [c["lat"].area_bounds[1], 0, c["lat"].step]]) + ) + # order: lon, lat + c = Coordinates([clinspace(0.5, 1.5, 5, "lon"), clinspace(1, 2, 9, "lat")]) + tf = np.array(c.geotransform).reshape(2, 3) + np.testing.assert_almost_equal( + tf, np.array([[c["lon"].area_bounds[0], 0, c["lon"].step], [c["lat"].area_bounds[0], c["lat"].step, 0]]) + ) + + # order: lon, -lat, time + c = Coordinates([clinspace(0.5, 1.5, 5, "lon"), clinspace(2, 1, 9, "lat"), crange(10, 11, 2, "time")]) + tf = np.array(c.geotransform).reshape(2, 3) + np.testing.assert_almost_equal( + tf, np.array([[c["lon"].area_bounds[0], 0, c["lon"].step], [c["lat"].area_bounds[1], c["lat"].step, 0]]) + ) + # order: -lon, -lat, time, alt + c = Coordinates( + [ + clinspace(1.5, 0.5, 5, "lon"), + clinspace(2, 1, 9, "lat"), + crange(10, 11, 2, "time"), + crange(10, 11, 2, "alt"), + ] + ) + tf = np.array(c.geotransform).reshape(2, 3) + np.testing.assert_almost_equal( + tf, np.array([[c["lon"].area_bounds[1], 0, c["lon"].step], [c["lat"].area_bounds[1], c["lat"].step, 0]]) + ) + + def error_time_alt_too_big(self): + # time + c = Coordinates( + [ + clinspace(1.5, 0.5, 5, "lon"), + clinspace(2, 1, 9, "lat"), + crange(1, 11, 2, "time"), + crange(1, 11, 2, "alt"), + ] + ) + with pytest.raises( + TypeError, match='Only 2-D coordinates have a GDAL transform. This array has a "time" dimension of' + ): + c.geotransform + # alt + c = Coordinates([clinspace(1.5, 0.5, 5, "lon"), clinspace(2, 1, 9, "lat"), crange(1, 11, 2, "alt")]) + with pytest.raises( + TypeError, match='Only 2-D coordinates have a GDAL transform. This array has a "alt" dimension of' + ): + c.geotransform + + def rot_coords_working(self): + # order -lat, lon + rc = RotatedCoordinates(shape=(4, 3), theta=np.pi / 8, origin=[10, 20], step=[-2.0, 1.0], dims=["lat", "lon"]) + c = Coordinates([rc], dims=["lat,lon"]) + tf = np.array(c.geotransform).reshape(2, 3) + np.testing.assert_almost_equal( + tf, + np.array( + [ + [rc.origin[1] - rc.step[1] / 2, rc.step[1] * np.cos(rc.theta), -rc.step[0] * np.sin(rc.theta)], + [rc.origin[0] - rc.step[0] / 2, rc.step[1] * np.sin(rc.theta), rc.step[0] * np.cos(rc.theta)], + ] + ), + ) + # order lon, lat + rc = RotatedCoordinates(shape=(4, 3), theta=np.pi / 8, origin=[10, 20], step=[2.0, 1.0], dims=["lon", "lat"]) + c = Coordinates([rc], dims=["lon,lat"]) + tf = np.array(c.geotransform).reshape(2, 3) + np.testing.assert_almost_equal( + tf, + np.array( + [ + [rc.origin[0] - rc.step[0] / 2, rc.step[1] * np.sin(rc.theta), rc.step[0] * np.cos(rc.theta)], + [rc.origin[1] - rc.step[1] / 2, rc.step[1] * np.cos(rc.theta), -rc.step[0] * np.sin(rc.theta)], + ] + ), + ) + + # order -lon, lat + rc = RotatedCoordinates(shape=(4, 3), theta=np.pi / 8, origin=[10, 20], step=[-2.0, 1.0], dims=["lon", "lat"]) + c = Coordinates([rc], dims=["lon,lat"]) + tf = np.array(c.geotransform).reshape(2, 3) + np.testing.assert_almost_equal( + tf, + np.array( + [ + [rc.origin[0] - rc.step[0] / 2, rc.step[1] * np.sin(rc.theta), rc.step[0] * np.cos(rc.theta)], + [rc.origin[1] - rc.step[1] / 2, rc.step[1] * np.cos(rc.theta), -rc.step[0] * np.sin(rc.theta)], + ] + ), + ) + # order -lat, -lon + rc = RotatedCoordinates(shape=(4, 3), theta=np.pi / 8, origin=[10, 20], step=[-2.0, -1.0], dims=["lat", "lon"]) + c = Coordinates([rc], dims=["lat,lon"]) + tf = np.array(c.geotransform).reshape(2, 3) + np.testing.assert_almost_equal( + tf, + np.array( + [ + [rc.origin[1] - rc.step[1] / 2, rc.step[1] * np.cos(rc.theta), -rc.step[0] * np.sin(rc.theta)], + [rc.origin[0] - rc.step[0] / 2, rc.step[1] * np.sin(rc.theta), rc.step[0] * np.cos(rc.theta)], + ] + ), + ) diff --git a/podpac/core/coordinates/test/test_rotated_coordinates.py b/podpac/core/coordinates/test/test_rotated_coordinates.py index 11b1c3036..e933f8478 100644 --- a/podpac/core/coordinates/test/test_rotated_coordinates.py +++ b/podpac/core/coordinates/test/test_rotated_coordinates.py @@ -129,11 +129,17 @@ def test_copy(self): class TestRotatedCoordinatesGeotransform(object): def test_geotransform(self): c = RotatedCoordinates(shape=(3, 4), theta=np.pi / 4, origin=[10, 20], step=[1.0, 2.0], dims=["lat", "lon"]) - assert_allclose(c.geotransform, (10.0, 0.7071068, -1.4142136, 20.0, 0.7071068, 1.4142136)) + assert_allclose(c.geotransform, (19.0, 1.4142136, -0.7071068, 9.5, 1.4142136, 0.7071068)) c2 = RotatedCoordinates.from_geotransform(c.geotransform, c.shape, dims=["lat", "lon"]) assert c == c2 + c = RotatedCoordinates(shape=(3, 4), theta=np.pi / 4, origin=[10, 20], step=[1.0, 2.0], dims=["lon", "lat"]) + assert_allclose(c.geotransform, (19.0, 1.4142136, -0.7071068, 9.5, 1.4142136, 0.7071068)) + + c2 = RotatedCoordinates.from_geotransform(c.geotransform, c.shape, dims=["lon", "lat"]) + assert c == c2 + class TestRotatedCoordinatesStandardMethods(object): def test_eq_type(self): diff --git a/podpac/core/data/file.py b/podpac/core/data/file.py index ca7563d58..c5830d3cf 100644 --- a/podpac/core/data/file.py +++ b/podpac/core/data/file.py @@ -12,11 +12,14 @@ import traitlets as tl import pandas as pd import xarray as xr +import pyproj +import logging from podpac.core.settings import settings from podpac.core.utils import common_doc, trait_is_defined from podpac.core.data.datasource import COMMON_DATA_DOC, DataSource from podpac.core.coordinates import Coordinates, UniformCoordinates1d, ArrayCoordinates1d, StackedCoordinates +from podpac.core.coordinates import RotatedCoordinates from podpac.core.coordinates.utils import Dimension, VALID_DIMENSION_NAMES # Optional dependencies @@ -30,6 +33,9 @@ zarrGroup = lazy_class("zarr.Group") s3fs = lazy_module("s3fs") +# Set up logging +_logger = logging.getLogger(__name__) + @common_doc(COMMON_DATA_DOC) class DatasetSource(DataSource): @@ -676,8 +682,22 @@ class Rasterio(DataSource): source = tl.Union([tl.Unicode(), tl.Instance(BytesIO)]).tag(readonly=True) dataset = tl.Any().tag(readonly=True) + @property + def nan_vals(self): + return list(self.dataset.nodatavals) + # node attrs - band = tl.CInt(1).tag(attr=True) + band = tl.CInt(allow_none=True).tag(attr=True) + + @tl.default("band") + def _band_default(self): + if (self.outputs is not None) and (self.output is not None): + band = self.outputs.index(self.output) + elif self.outputs is None: + band = 1 + else: + band = None # All bands + return band @tl.default("dataset") def _open_dataset(self): @@ -719,21 +739,18 @@ def get_native_coordinates(self): # check to see if the coordinates are rotated used affine affine = self.dataset.transform - if affine[1] != 0.0 or affine[3] != 0.0: - raise NotImplementedError("Rotated coordinates are not yet supported") - try: + if isinstance(self.dataset.crs, rasterio.crs.CRS): + crs = self.dataset.crs.wkt + elif isinstance(self.dataset.crs, dict) and "init" in self.dataset.crs: crs = self.dataset.crs["init"].upper() - except: - crs = None - - # get bounds - left, bottom, right, top = self.dataset.bounds + else: + try: + crs = pyproj.CRS(self.dataset.crs).to_wkt() + except: + raise RuntimeError("Unexpected rasterio crs '%s'" % self.dataset.crs) - # rasterio reads data upside-down from coordinate conventions, so lat goes from top to bottom - lat = UniformCoordinates1d(top, bottom, size=self.dataset.height, name="lat") - lon = UniformCoordinates1d(left, right, size=self.dataset.width, name="lon") - return Coordinates([lat, lon], dims=["lat", "lon"], crs=crs) + return Coordinates.from_geotransform(affine.to_gdal(), self.dataset.shape, crs) @common_doc(COMMON_DATA_DOC) def get_data(self, coordinates, coordinates_index): @@ -744,7 +761,12 @@ def get_data(self, coordinates, coordinates_index): # read data within coordinates_index window window = ((slc[0].start, slc[0].stop), (slc[1].start, slc[1].stop)) - raster_data = self.dataset.read(self.band, out_shape=tuple(coordinates.shape), window=window) + + if self.outputs is not None: # read all the bands + raster_data = self.dataset.read(out_shape=(len(self.outputs),) + tuple(coordinates.shape), window=window) + raster_data = np.moveaxis(raster_data, 0, 2) + else: # read the requested band + raster_data = self.dataset.read(self.band, out_shape=tuple(coordinates.shape), window=window) # set raster data to output array data.data.ravel()[:] = raster_data.ravel() diff --git a/podpac/core/node.py b/podpac/core/node.py index d7d69d0c8..28ab48818 100644 --- a/podpac/core/node.py +++ b/podpac/core/node.py @@ -275,6 +275,10 @@ def create_output_array(self, coords, data=np.nan, **kwargs): attrs["crs"] = coords.crs if self.units is not None: attrs["units"] = ureg.Unit(self.units) + try: + attrs["geotransform"] = coords.geotransform + except (TypeError, AttributeError): + pass return UnitsDataArray.create(coords, data=data, outputs=self.outputs, dtype=self.dtype, attrs=attrs, **kwargs) diff --git a/podpac/core/test/test_units.py b/podpac/core/test/test_units.py index 96c888655..a73f5e72e 100644 --- a/podpac/core/test/test_units.py +++ b/podpac/core/test/test_units.py @@ -7,7 +7,7 @@ import xarray as xr from pint.errors import DimensionalityError -from podpac.core.coordinates import Coordinates +from podpac.core.coordinates import Coordinates, clinspace, RotatedCoordinates from podpac.core.style import Style from podpac.core.units import ureg @@ -15,7 +15,7 @@ from podpac.core.units import to_image from podpac.core.units import create_dataarray # DEPRECATED -from podpac.data import Array +from podpac.data import Array, Rasterio class TestUnitDataArray(object): @@ -509,3 +509,135 @@ def test_to_image(self): def test_to_image_vmin_vmax(self): data = np.ones((10, 10)) assert isinstance(to_image(data, vmin=0, vmax=2, return_base64=True), bytes) + + +class TestToGeoTiff(object): + def make_square_array(self, order=1, bands=1): + # order = -1 + # bands = 3 + node = Array( + source=np.arange(8 * bands).reshape(3 - order, 3 + order, bands), + native_coordinates=Coordinates([clinspace(4, 0, 2, "lat"), clinspace(1, 4, 4, "lon")][::order]), + outputs=[str(s) for s in list(range(bands))], + ) + return node + + def make_rot_array(self, order=1, bands=1): + # order = -1 + # bands = 3 + rc = RotatedCoordinates( + shape=(2, 4), theta=np.pi / 8, origin=[10, 20], step=[-2.0, 1.0], dims=["lat", "lon"][::order] + ) + c = Coordinates([rc]) + node = Array( + source=np.arange(8 * bands).reshape(3 - order, 3 + order, bands), + native_coordinates=c, + outputs=[str(s) for s in list(range(bands))], + ) + return node + + def test_to_geotiff_rountrip_1band(self): + # lat/lon order, usual + node = self.make_square_array() + out = node.eval(node.native_coordinates) + fp = io.BytesIO() + out.to_geotiff(fp) + fp.write(b"a") # for some reason needed to get good comparison + fp.seek(0) + rnode = Rasterio(source=fp, outputs=node.outputs, mode="r") + + assert node.native_coordinates == rnode.native_coordinates + + rout = rnode.eval(rnode.native_coordinates) + np.testing.assert_almost_equal(out.data, rout.data) + + # lon/lat order, unsual + node = self.make_square_array(order=-1) + out = node.eval(node.native_coordinates) + fp = io.BytesIO() + out.to_geotiff(fp) + fp.write(b"a") # for some reason needed to get good comparison + fp.seek(0) + rnode = Rasterio(source=fp, outputs=node.outputs) + + assert node.native_coordinates == rnode.native_coordinates + + rout = rnode.eval(rnode.native_coordinates) + np.testing.assert_almost_equal(out.data, rout.data) + + def test_to_geotiff_rountrip_2band(self): + # lat/lon order, usual + node = self.make_square_array(bands=2) + out = node.eval(node.native_coordinates) + fp = io.BytesIO() + out.to_geotiff(fp) + fp.write(b"a") # for some reason needed to get good comparison + fp.seek(0) + rnode = Rasterio(source=fp, outputs=node.outputs, mode="r") + + assert node.native_coordinates == rnode.native_coordinates + + rout = rnode.eval(rnode.native_coordinates) + np.testing.assert_almost_equal(out.data, rout.data) + + # lon/lat order, unsual + node = self.make_square_array(order=-1, bands=2) + out = node.eval(node.native_coordinates) + fp = io.BytesIO() + out.to_geotiff(fp) + fp.write(b"a") # for some reason needed to get good comparison + fp.seek(0) + rnode = Rasterio(source=fp, outputs=node.outputs) + + assert node.native_coordinates == rnode.native_coordinates + + rout = rnode.eval(rnode.native_coordinates) + np.testing.assert_almost_equal(out.data, rout.data) + + # Check single output + fp.seek(0) + rnode = Rasterio(source=fp, outputs=node.outputs, output=node.outputs[1]) + rout = rnode.eval(rnode.native_coordinates) + np.testing.assert_almost_equal(out.data[..., 1], rout.data) + + # Check single band 1 + fp.seek(0) + rnode = Rasterio(source=fp, band=1) + rout = rnode.eval(rnode.native_coordinates) + np.testing.assert_almost_equal(out.data[..., 0], rout.data) + + # Check single band 2 + fp.seek(0) + rnode = Rasterio(source=fp, band=2) + rout = rnode.eval(rnode.native_coordinates) + np.testing.assert_almost_equal(out.data[..., 1], rout.data) + + @pytest.mark.skip("TODO: We can remove this skipped test after solving #363") + def test_to_geotiff_rountrip_rotcoords(self): + # lat/lon order, usual + node = self.make_rot_array() + out = node.eval(node.native_coordinates) + fp = io.BytesIO() + out.to_geotiff(fp) + fp.write(b"a") # for some reason needed to get good comparison + fp.seek(0) + rnode = Rasterio(source=fp, outputs=node.outputs, mode="r") + + assert node.native_coordinates == rnode.native_coordinates + + rout = rnode.eval(rnode.native_coordinates) + np.testing.assert_almost_equal(out.data, rout.data) + + # lon/lat order, unsual + node = self.make_square_array(order=-1) + out = node.eval(node.native_coordinates) + fp = io.BytesIO() + out.to_geotiff(fp) + fp.write(b"a") # for some reason needed to get good comparison + fp.seek(0) + rnode = Rasterio(source=fp, outputs=node.outputs) + + assert node.native_coordinates == rnode.native_coordinates + + rout = rnode.eval(rnode.native_coordinates) + np.testing.assert_almost_equal(out.data, rout.data) diff --git a/podpac/core/units.py b/podpac/core/units.py index 147aea670..7f5ff3712 100644 --- a/podpac/core/units.py +++ b/podpac/core/units.py @@ -15,6 +15,7 @@ from io import BytesIO import base64 +import logging try: import cPickle # Python 2.7 @@ -35,6 +36,14 @@ from podpac.core.utils import JSONEncoder from podpac.core.style import Style +# Optional dependencies +from lazy_import import lazy_module, lazy_class + +rasterio = lazy_module("rasterio") + +# Set up logging +_logger = logging.getLogger(__name__) + class UnitsDataArray(xr.DataArray): """Like xarray.DataArray, but transfers units @@ -172,56 +181,9 @@ def to_format(self, format, *args, **kwargs): elif format in ["png", "jpg", "jpeg"]: r = self.to_image(format, *args, **kwargs) elif format.upper() in ["TIFF", "TIF", "GEOTIFF"]: - # This only works for data that essentially has lat/lon only - dims = self.coords.dims - if "lat" not in dims or "lon" not in dims: - raise NotImplementedError("Cannot export GeoTIFF for dataset with lat/lon coordinates.") - if "time" in dims and len(self.coords["time"] > 1): - raise NotImplemented("Cannot export GeoTIFF for dataset with multiple times,") - if "alt" in dims and len(self.coords["alt"] > 1): - raise NotImplemented("Cannot export GeoTIFF for dataset with multiple altitudes.") - - # Get the crs and geotransform that describes the coordinates - crs = self.attrs.get("crs", "EPSG:4326") - coords = podpac.Coordinates.from_xarray(self, crs=crs) - - # TODO: add proper checks, etc. to make sure we handle edge cases and throw errors when we cannot support - # i.e. do work to remove this warning. - _logger.warning("GeoTIFF export assumes data is in a uniform, non-rotated coordinate system.") - - # Build the transform from a translation and scaling - transform = rasterio.transform.Affine.translate( - min(self.coords.area_bounds["lon"]), max(self.coords.area_bounds["lat"]) - ) * rasterio.transform.Affine.scale( - (max(self.coords.bounds["lon"]) - min(self.coords.bounds["lon"])) / coords["lon"].size, - (max(self.coords.bounds["lat"]) - min(self.coords.bounds["lat"])) / coords["lat"].size, - ) - - # Update the kwargs that rasterio will use. Anything added by the user will take priority. - kwargs2 = dict( - drive="GTiff", - height=self.coords["lat"].size, - width=self.coords["lon"].size, - count=1, - dtype=data.dtype, - crs=crs, - transform=transform, - ) - kwargs2.update(kwargs) - - # Get the data - dtype = kwargs.get("dtype", np.float32) - data = self.data.astype(dtype).squeeze() - if dims.index("lat") > dims.index("lon"): - data = data.T - - # Write the file - with rasterio.open(*args, **kwargs2) as dst: - r = dst.write(data, 1) - + r = self.to_geotiff(*args, **kwargs) elif format in ["pickle", "pkl"]: r = cPickle.dumps(self) - elif format == "zarr_part": if part in kwargs: part = [slice(*sss) for sss in kwargs.pop("part")] @@ -261,6 +223,12 @@ def to_image(self, format="png", vmin=None, vmax=None, return_base64=False): """ return to_image(self, format, vmin, vmax, return_base64) + def to_geotiff(self, fp, geotransform=None, crs=None, **kwargs): + """ + For documentation, see `core.units.to_geotiff` + """ + to_geotiff(fp, self, geotransform=geotransform, crs=crs, **kwargs) + def serialize(self): if self.attrs.get("units"): self.attrs["units"] = str(self.attrs["units"]) @@ -602,3 +570,99 @@ def to_image(data, format="png", vmin=None, vmax=None, return_base64=False): return base64.b64encode(im_data.getvalue()) else: return im_data + + +def to_geotiff(fp, data, geotransform=None, crs=None, **kwargs): + """ Export a UnitsDataArray to a Geotiff + + Params + ------- + fp: str, file object or pathlib.Path object + A filename or URL, a file object opened in binary ('rb') mode, or a Path object. + data: UnitsDataArray, xr.DataArray, np.ndarray + The data to be saved. If there is more than 1 band, this should be the last dimension of the array. + If given a np.ndarray, ensure that the 'lat' dimension is aligned with the rows of the data, with an appropriate + geotransform. + geotransform: tuple, optional + The geotransform that describes the input data. If not given, will look for data.attrs['geotransform'] + crs: str, optional + The coordinate reference system for the data + kwargs: **dict + Additional key-word arguments that overwrite defaults used in the `rasterio.open` function. This function + populates the following defaults: + drive="GTiff" + height=data.shape[0] + width=data.shape[1] + count=data.shape[2] + dtype=data.dtype + mode="w" + + """ + + # This only works for data that essentially has lat/lon only + dims = data.coords.dims + if "lat" not in dims or "lon" not in dims: + raise NotImplementedError("Cannot export GeoTIFF for dataset with lat/lon coordinates.") + if "time" in dims and len(data.coords["time"] > 1): + raise NotImplemented("Cannot export GeoTIFF for dataset with multiple times,") + if "alt" in dims and len(data.coords["alt"] > 1): + raise NotImplemented("Cannot export GeoTIFF for dataset with multiple altitudes.") + + # TODO: add proper checks, etc. to make sure we handle edge cases and throw errors when we cannot support + # i.e. do work to remove this warning. + _logger.warning("GeoTIFF export assumes data is in a uniform, non-rotated coordinate system.") + + # Get the crs and geotransform that describes the coordinates + if crs is None: + crs = data.attrs.get("crs") + if crs is None: + raise ValueError( + "The `crs` of the data needs to be provided to save as GeoTIFF. If supplying a UnitsDataArray, created " + " through a PODPAC Node, the crs should be automatically populated. If not, please file an issue." + ) + if geotransform is None: + geotransform = data.attrs.get("geotransform") + # Geotransform should ALWAYS be defined as (lon_origin, lon_dj, lon_di, lat_origin, lat_dj, lat_di) + # if isinstance(data, xr.DataArray) and data.dims.index('lat') > data.dims.index('lon'): + # geotransform = geotransform[3:] + geotransform[:3] + if geotransform is None: + raise ValueError( + "The `geotransform` of the data needs to be provided to save as GeoTIFF. If the geotransform attribute " + "wasn't automatically populated as part of the dataset, it means that the data is in a non-uniform " + "coordinate system. This can sometimes happen when the data is transformed to a different CRS than the " + "native CRS, which can cause the coordinates to seems non-uniform due to floating point precision. " + ) + + # Make all types into a numpy array + if isinstance(data, xr.DataArray): + data = data.data + + # Get the data + dtype = kwargs.get("dtype", np.float32) + data = data.astype(dtype).squeeze() + + if len(data.shape) == 2: + data = data[:, :, None] + + geotransform = rasterio.Affine.from_gdal(*geotransform) + + # Update the kwargs that rasterio will use. Anything added by the user will take priority. + kwargs2 = dict( + driver="GTiff", + height=data.shape[0], + width=data.shape[1], + count=data.shape[2], + dtype=data.dtype, + crs=crs, + transform=geotransform, + mode="w", + ) + kwargs2.update(kwargs) + + # Write the file + r = [] + with rasterio.open(fp, **kwargs2) as dst: + for i in range(data.shape[2]): + r.append(dst.write(data[..., i], i + 1)) + + return r