Skip to content

Commit

Permalink
cleaning up and fixing code when destination is passed (EOA-team#64)
Browse files Browse the repository at this point in the history
  • Loading branch information
lukasValentin committed Jun 2, 2023
1 parent 96224a9 commit e6c6e50
Showing 1 changed file with 23 additions and 21 deletions.
44 changes: 23 additions & 21 deletions eodal/core/band.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ class GeoInfo(object):

def __init__(
self,
epsg: int,
epsg: int | CRS,
ulx: Union[int, float],
uly: Union[int, float],
pixres_x: Union[int, float],
Expand Down Expand Up @@ -205,10 +205,13 @@ def __init__(
system.
"""
# make sure the EPSG code is valid
try:
CRS.from_epsg(epsg)
except Exception as e:
raise ValueError(e)
if isinstance(epsg, CRS):
epsg = epsg.to_epsg()
else:
try:
CRS.from_epsg(epsg)
except Exception as e:
raise ValueError(e)

object.__setattr__(self, "epsg", epsg)
object.__setattr__(self, "ulx", ulx)
Expand Down Expand Up @@ -1918,9 +1921,7 @@ def resample(
def reproject(
self,
target_crs: Union[int, CRS],
dst_transform: Optional[Affine] = None,
interpolation_method: Optional[int] = Resampling.nearest,
num_threads: Optional[int] = 1,
inplace: Optional[bool] = False,
**kwargs,
):
Expand All @@ -1936,13 +1937,12 @@ def reproject(
:param interpolation_method:
interpolation method to use for interpolating grid cells after
reprojection. Default is neares neighbor interpolation.
:param num_threads:
number of threads to use for the operation. Uses a single thread by
default.
:param inplace:
if False (default) returns a copy of the ``Band`` instance
with the changes applied. If True overwrites the values
in the current instance.
:param kwargs:
optional keyword arguments to pass to ``rasterio.warp.reproject``.
:returns:
``Band`` instance if `inplace` is False, None instead.
"""
Expand All @@ -1953,9 +1953,7 @@ def reproject(
"src_transform": self.transform,
"dst_crs": target_crs,
"src_nodata": self.nodata,
"resampling": interpolation_method,
"num_threads": num_threads,
"dst_transform": dst_transform,
"resampling": interpolation_method
}
reprojection_options.update(kwargs)

Expand All @@ -1971,13 +1969,11 @@ def reproject(

try:
# set destination array in case dst_transfrom is provided
if (
"dst_transform" in reprojection_options.keys()
and reprojection_options.get("dst_transfrom") is not None
):
if reprojection_options.get("dst_transfrom") is not None:
if "destination" not in reprojection_options.keys():
dst = np.zeros_like(band_data)
reprojection_options.update({"destination": dst})
raise ValueError(
'"destination" must be provided ' +
'alongside "dst_transform"')

out_data, out_transform = reproject_raster_dataset(
raster=band_data, **reprojection_options
Expand All @@ -1986,14 +1982,20 @@ def reproject(
raise ReprojectionError(f"Could not re-project band {self.band_name}: {e}")

# cast array back to original dtype
out_data = out_data[0, :, :].astype(self.values.dtype)
if len(out_data.shape) == 2:
out_data = out_data.astype(self.values.dtype)
elif len(out_data.shape) == 3:
out_data = out_data[0, :, :].astype(self.values.dtype)

# reproject the mask separately
if self.is_masked_array:
out_mask, _ = reproject_raster_dataset(
raster=band_mask, **reprojection_options
)
out_mask = out_mask[0, :, :].astype(bool)
if len(out_mask.shape) == 2:
out_mask = out_mask.astype(bool)
elif len(out_mask.shape) == 3:
out_mask = out_mask[0, :, :].astype(bool)
# mask also those pixels which were set to nodata after reprojection
# due to the raster alignment
nodata = reprojection_options.get("src_nodata", 0)
Expand Down

0 comments on commit e6c6e50

Please sign in to comment.