diff --git a/eodal/core/band.py b/eodal/core/band.py index e0c6250..8473a45 100644 --- a/eodal/core/band.py +++ b/eodal/core/band.py @@ -176,7 +176,7 @@ class GeoInfo(object): def __init__( self, - epsg: int, + epsg: int | CRS, ulx: Union[int, float], uly: Union[int, float], pixres_x: Union[int, float], @@ -205,10 +205,13 @@ def __init__( system. """ # make sure the EPSG code is valid - try: - CRS.from_epsg(epsg) - except Exception as e: - raise ValueError(e) + if isinstance(epsg, CRS): + epsg = epsg.to_epsg() + else: + try: + CRS.from_epsg(epsg) + except Exception as e: + raise ValueError(e) object.__setattr__(self, "epsg", epsg) object.__setattr__(self, "ulx", ulx) @@ -1918,9 +1921,7 @@ def resample( def reproject( self, target_crs: Union[int, CRS], - dst_transform: Optional[Affine] = None, interpolation_method: Optional[int] = Resampling.nearest, - num_threads: Optional[int] = 1, inplace: Optional[bool] = False, **kwargs, ): @@ -1936,13 +1937,12 @@ def reproject( :param interpolation_method: interpolation method to use for interpolating grid cells after reprojection. Default is neares neighbor interpolation. - :param num_threads: - number of threads to use for the operation. Uses a single thread by - default. :param inplace: if False (default) returns a copy of the ``Band`` instance with the changes applied. If True overwrites the values in the current instance. + :param kwargs: + optional keyword arguments to pass to ``rasterio.warp.reproject``. :returns: ``Band`` instance if `inplace` is False, None instead. """ @@ -1953,9 +1953,7 @@ def reproject( "src_transform": self.transform, "dst_crs": target_crs, "src_nodata": self.nodata, - "resampling": interpolation_method, - "num_threads": num_threads, - "dst_transform": dst_transform, + "resampling": interpolation_method } reprojection_options.update(kwargs) @@ -1971,13 +1969,11 @@ def reproject( try: # set destination array in case dst_transfrom is provided - if ( - "dst_transform" in reprojection_options.keys() - and reprojection_options.get("dst_transfrom") is not None - ): + if reprojection_options.get("dst_transfrom") is not None: if "destination" not in reprojection_options.keys(): - dst = np.zeros_like(band_data) - reprojection_options.update({"destination": dst}) + raise ValueError( + '"destination" must be provided ' + + 'alongside "dst_transform"') out_data, out_transform = reproject_raster_dataset( raster=band_data, **reprojection_options @@ -1986,14 +1982,20 @@ def reproject( raise ReprojectionError(f"Could not re-project band {self.band_name}: {e}") # cast array back to original dtype - out_data = out_data[0, :, :].astype(self.values.dtype) + if len(out_data.shape) == 2: + out_data = out_data.astype(self.values.dtype) + elif len(out_data.shape) == 3: + out_data = out_data[0, :, :].astype(self.values.dtype) # reproject the mask separately if self.is_masked_array: out_mask, _ = reproject_raster_dataset( raster=band_mask, **reprojection_options ) - out_mask = out_mask[0, :, :].astype(bool) + if len(out_mask.shape) == 2: + out_mask = out_mask.astype(bool) + elif len(out_mask.shape) == 3: + out_mask = out_mask[0, :, :].astype(bool) # mask also those pixels which were set to nodata after reprojection # due to the raster alignment nodata = reprojection_options.get("src_nodata", 0)