Skip to content

Commit

Permalink
suffel and save tile function implemented
Browse files Browse the repository at this point in the history
  • Loading branch information
iamtekson committed Sep 12, 2023
1 parent 572764c commit c47103c
Showing 1 changed file with 135 additions and 8 deletions.
143 changes: 135 additions & 8 deletions geotile/GeoTile.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from rasterio.warp import calculate_default_transform, reproject
from rasterio.enums import Resampling
from rasterio.features import rasterize
from rasterio.transform import Affine

# geopandas library
import geopandas as gpd
Expand Down Expand Up @@ -104,9 +105,55 @@ def _calculate_offset(self, stride_x: Optional[int] = None, stride_y: Optional[i
X = [x for x in range(0, self.width, stride_x)]
Y = [y for y in range(0, self.height, stride_y)]
offsets = list(itertools.product(X, Y))
return offsets

self.offsets = offsets

def _windows_transform_to_affine(self, window_transform: Optional[tuple]):
"""Convert the window transform to affine transform
Parameters
----------
window_transform: tuple
tuple of window transform
Returns
-------
tuple: tuple of affine transform
"""
a, b, c, d, e, f, _, _, _ = window_transform
return Affine(a, b, c, d, e, f)


def suffel_tiles(self, random_state: Optional[int] = None):
"""Shuffle the tiles
Parameters
----------
random_state: int
Random state for shuffling the tiles
Returns
-------
None: Shuffle the tiles. The offsets will be shuffled in place
Examples
--------
>>> from geotile import GeoTile
>>> tiler = GeoTile('/path/to/raster/file.tif')
>>> tiler.shuffle_tiles()
"""
# check if random_state is not None
if random_state is not None:
self.random_state = random_state
np.random.seed(self.random_state)

assert len(self.offsets) == len(self.window_data) == len(self.window_transform), "The number of offsets and window data should be same"

# shuffle the offsets and window data
p = np.random.permutation(len(self.offsets))
self.offsets = np.array(self.offsets)[p]
self.window_data = np.array(self.window_data)[p]
self.window_transform = np.array(self.window_transform)[p]

def tile_info(self):
"""Get the information of the tiles
Expand All @@ -126,13 +173,14 @@ def tile_info(self):
def generate_tiles(
self,
output_folder: str,
save_tiles: Optional[bool] = True,
out_bands: Optional[list] = None,
image_format: Optional[str] = None,
dtype: Optional[str] = None,
tile_x: Optional[int] = 256,
tile_y: Optional[int] = 256,
stride_x: Optional[int] = 128,
stride_y: Optional[int] = 128
stride_y: Optional[int] = 128,
):
"""
Save the tiles to the output folder
Expand All @@ -141,6 +189,8 @@ def generate_tiles(
----------
output_folder : str
Path to the output folder
save_tiles : bool
If True, the tiles will be saved to the output folder else the tiles will be stored in the class
out_bands : list
The bands to save (eg. [3, 2, 1]), if None, the output bands will be same as the input raster bands
image_format : str
Expand Down Expand Up @@ -184,13 +234,25 @@ def generate_tiles(
os.makedirs(output_folder)

# offset calculation
offsets = self._calculate_offset(self.stride_x, self.stride_y)
self._calculate_offset(self.stride_x, self.stride_y)

#store all the windows data as a list, windows shape: (band, tile_y, tile_x)
self.window_data = []

# store all the transform data as a list
self.window_transform = []

# iterate through the offsets
for col_off, row_off in offsets:
# iterate through the offsets and save the tiles
for col_off, row_off in self.offsets:
window = windows.Window(
col_off=col_off, row_off=row_off, width=self.tile_x, height=self.tile_y)
transform = windows.transform(window, self.ds.transform)

# convert the window transform to affine transform and append to the list
transform = self._windows_transform_to_affine(transform)
self.window_transform.append(transform)

# copy the meta data
meta = self.ds.meta.copy()
nodata = meta['nodata']

Expand All @@ -206,6 +268,10 @@ def generate_tiles(
else:
meta.update({"count": len(out_bands)})

# read the window data and append to the list
single_window_data = self.ds.read(out_bands, window=window, fill_value=nodata, boundless=True)
self.window_data.append(single_window_data)

# if data_type, update the meta
if dtype:
meta.update({"dtype": dtype})
Expand All @@ -218,10 +284,71 @@ def generate_tiles(
str(row_off) + '.' + image_format
tile_path = os.path.join(output_folder, tile_name)

if save_tiles:
# save the tiles with new metadata
with rio.open(tile_path, 'w', **meta) as outds:
outds.write(self.ds.read(
out_bands, window=window, fill_value=nodata, boundless=True).astype(dtype))

def save_tiles(self, output_folder: str, image_format: Optional[str] = None, dtype: Optional[str] = None):
"""Save the tiles to the output folder
Parameters
----------
output_folder : str
Path to the output folder
image_format : str
The image format (eg. tif), if None, the image format will be the same as the input raster format (eg. tif)
dtype : str, np.dtype
The output dtype (eg. uint8, float32), if None, the dtype will be the same as the input raster
Returns
-------
None: save the tiles to the output folder
Examples
--------
>>> from geotile import GeoTile
>>> tiler = GeoTile('/path/to/raster/file.tif')
>>> tiler.save_tiles('/path/to/output/folder')
"""
# create the output folder if it doesn't exist
if not os.path.exists(output_folder):
os.makedirs(output_folder)

# meta data
meta = self.meta.copy()
# nodata = meta['nodata']
meta.update({
"width": self.tile_x,
"height": self.tile_y,
})

# check if image_format is None
if image_format is None:
image_format = pathlib.Path(self.path).suffix[1:]

# if data_type, update the meta
if dtype:
meta.update({"dtype": dtype})
dtype=dtype

# iterate through the offsets and windows_data and save the tiles
for i, ((col_off, row_off), wd, wt) in enumerate(zip(self.offsets, self.window_data, self.window_transform)):

# update meta data with transform
meta.update({"transform": tuple(wt)})

# tile name and path
tile_name = 'tile_' + str(col_off) + '_' + \
str(row_off) + '.' + image_format
tile_path = os.path.join(output_folder, tile_name)

# save the tiles with new metadata
with rio.open(tile_path, 'w', **meta) as outds:
outds.write(self.ds.read(
out_bands, window=window, fill_value=nodata, boundless=True).astype(dtype))
outds.write(wd.astype(dtype))


def mask(self, input_vector: str, out_path: str, crop=False, invert=False, **kwargs):
"""Generate a mask raster from a vector
Expand Down

0 comments on commit c47103c

Please sign in to comment.