From c23e6bbd1b74eee7059b66dba02a314eb94cb0a3 Mon Sep 17 00:00:00 2001 From: HideBa Date: Tue, 5 Mar 2024 16:19:40 +0100 Subject: [PATCH] Refactor code for AHN CLI*** --- README.md | 2 - ahn_cli/fetcher/geotiles.py | 21 ++++- ahn_cli/fetcher/request.py | 62 +++++++++++++- .../{pipeline.py => ptc_handler.py} | 82 +++++++++++-------- ahn_cli/manipulator/rasterizer.py | 16 ++-- ahn_cli/process.py | 27 +++--- ahn_cli/validator.py | 3 +- tests/fetcher/test_geotiles.py | 29 ++++++- tests/test_pipeline.py | 61 +++++++------- 9 files changed, 206 insertions(+), 97 deletions(-) rename ahn_cli/manipulator/{pipeline.py => ptc_handler.py} (72%) diff --git a/README.md b/README.md index 6abcfed..0fa346a 100644 --- a/README.md +++ b/README.md @@ -10,8 +10,6 @@ AHN CLI is a command-line interface tool designed for the effortless downloading ## Installation -> **NOTE:** AHN CLI requires PDAL to be installed on your system. Follow the installation instructions in the [PDAL documentation](https://pdal.io/download.html) before proceeding. - Install AHN CLI using pip: ``` diff --git a/ahn_cli/fetcher/geotiles.py b/ahn_cli/fetcher/geotiles.py index d93c591..0bd897c 100644 --- a/ahn_cli/fetcher/geotiles.py +++ b/ahn_cli/fetcher/geotiles.py @@ -1,6 +1,7 @@ import os import geopandas as gpd +from pyproj import Transformer from ahn_cli.fetcher.municipality import city_polygon @@ -14,13 +15,29 @@ def geotiles() -> gpd.GeoDataFrame: return ahn_tile_gdf -def ahn_subunit_indicies_of_city(city_name: str) -> list[int]: +def ahn_subunit_indicies_of_city(city_name: str) -> list[str]: """Return a list of AHN tile indicies that intersect with the city's boundary.""" # noqa city_poly = city_polygon(city_name) geotiles_tile_gdf = geotiles() # Filter the DataFrame based on lowercase column values filtered_df = geotiles_tile_gdf.overlay(city_poly) - tile_indices: list[int] = filtered_df["AHN_subuni"].tolist() # noqa + tile_indices: list[str] = filtered_df["AHN_subuni"].tolist() # noqa + + return tile_indices + + +def ahn_subunit_indicies_of_bbox(bbox: list[float]) -> list[str]: + """Return a list of AHN tile indicies that intersect with the bbox.""" # noqa + geotiles_tile_gdf = geotiles() + + transformer = Transformer.from_crs( + "EPSG:28992", "EPSG:4326", always_xy=True + ) + minx, miny = transformer.transform(bbox[0], bbox[1]) + maxx, maxy = transformer.transform(bbox[2], bbox[3]) + # Filter the DataFrame based on lowercase column values + filtered_df = geotiles_tile_gdf.cx[minx:maxx, miny:maxy] + tile_indices: list[str] = filtered_df["AHN_subuni"].tolist() # noqa return tile_indices diff --git a/ahn_cli/fetcher/request.py b/ahn_cli/fetcher/request.py index bf9c954..9e03102 100644 --- a/ahn_cli/fetcher/request.py +++ b/ahn_cli/fetcher/request.py @@ -9,18 +9,54 @@ import requests from tqdm import tqdm -from ahn_cli.fetcher.geotiles import ahn_subunit_indicies_of_city +from ahn_cli.fetcher.geotiles import (ahn_subunit_indicies_of_bbox, + ahn_subunit_indicies_of_city) class Fetcher: - def __init__(self, base_url: str, city_name: str): + """ + Fetcher class for fetching AHN data. + + Args: + base_url (str): The base URL for fetching AHN data. + city_name (str): The name of the city for which to fetch AHN data. + bbox (list[float] | None, optional): The bounding box coordinates [minx, miny, maxx, maxy] + for a specific area of interest. Defaults to None. + + Raises: + ValueError: If the base URL is invalid. + + Attributes: + base_url (str): The base URL for fetching AHN data. + city_name (str): The name of the city for which to fetch AHN data. + bbox (list[float] | None): The bounding box coordinates [minx, miny, maxx, maxy] + for a specific area of interest. + urls (list[str]): The constructed URLs for fetching AHN data. + + Methods: + fetch: Fetches AHN data. + _check_valid_url: Checks if the base URL is valid. + _construct_urls: Constructs the URLs for fetching AHN data. + """ + + def __init__( + self, base_url: str, city_name: str, bbox: list[float] | None = None + ): if not self._check_valid_url(base_url): raise ValueError("Invalid URL") self.base_url = base_url self.city_name = city_name + self.bbox = bbox self.urls = self._construct_urls() def fetch(self) -> dict: + """ + Fetches AHN data. + + Returns: + dict: A dictionary containing the fetched AHN data, where the keys are the URLs + and the values are the temporary file names where the data is stored. + """ logging.info("Start fetching AHN data") logging.info(f"Fetching {len(self.urls)} tiles") @@ -50,6 +86,15 @@ def req( return results def _check_valid_url(self, url: str) -> bool: + """ + Checks if the base URL is valid. + + Args: + url (str): The base URL to check. + + Returns: + bool: True if the URL is valid, False otherwise. + """ try: result = urlparse(url) return all([result.scheme, result.netloc, result.path]) @@ -57,9 +102,18 @@ def _check_valid_url(self, url: str) -> bool: return False def _construct_urls(self) -> list[str]: - tiles_indices = ahn_subunit_indicies_of_city(self.city_name) + """ + Constructs the URLs for fetching AHN data. + + Returns: + list[str]: A list of URLs for fetching AHN data. + """ + tiles_indices = ( + ahn_subunit_indicies_of_bbox(self.bbox) + if self.bbox + else ahn_subunit_indicies_of_city(self.city_name) + ) urls = [] for tile_index in tiles_indices: urls.append(os.path.join(self.base_url + f"{tile_index}.LAZ")) - return urls diff --git a/ahn_cli/manipulator/pipeline.py b/ahn_cli/manipulator/ptc_handler.py similarity index 72% rename from ahn_cli/manipulator/pipeline.py rename to ahn_cli/manipulator/ptc_handler.py index a203671..8a9e66e 100644 --- a/ahn_cli/manipulator/pipeline.py +++ b/ahn_cli/manipulator/ptc_handler.py @@ -10,20 +10,26 @@ from ahn_cli.manipulator.transformer import tranform_polygon -class PntCPipeline: +class PntCHandler: """ - A class representing a data processing pipeline. - - Args: - input_path (str): The path to the input data. - output_path (str): The path to save the output data. - city_filepath (str): The path to the city data file. - city_name (str): The name of the city. - epsg (int): The EPSG code for the coordinate reference system (CRS). + A class for handling point clouds. Attributes: - pipeline_setting (list): The configuration settings for the pipeline. - + las (laspy.LasData): The point cloud data. + city_df (gpd.GeoDataFrame): The city data. + city_name (str): The name of the city. + raster_res (float): The raster resolution. + epsg (str | None): The EPSG code. + + Methods: + __init__: Initializes the PntCHandler object. + decimate: Decimates the point cloud by selecting every `step`-th point. + include: Filters the point cloud to include only the specified classes. + exclude: Exclude points with specific classification values from the pipeline. + clip: Clip the point cloud by a polygon. + clip_by_arbitrary_polygon: Clip the point cloud by an arbitrary polygon. + clip_by_bbox: Clips the point cloud by a bounding box. + points: Execute the pipeline and return the processed point cloud. """ las = laspy.LasData @@ -46,27 +52,27 @@ def __init__( def decimate(self, step: int) -> Self: """ - Decimate the point cloud by a given step. + Decimates the point cloud by selecting every `step`-th point. Args: - step (int): The step to decimate by. + step (int): The decimation step size. Returns: - Pipeline: The updated pipeline object. - + Self: The modified pipeline object. """ - self.las.points = self.las.points[::step] + valid_point_masks = np.arange(0, len(self.las.points), step) + self.las.points = self.las.points[valid_point_masks] return self def include(self, include_classes: list[int]) -> Self: """ - Filters the point cloud to include only the specified classes. + Filters the point cloud by including only the specified classes. Args: - include_classes (list[int]): List of class labels to include. + include_classes (list[int]): A list of class IDs to include. Returns: - Self: The modified pipeline object. + Self: The updated instance of the pipeline. """ mask = np.isin(self.las.classification, include_classes) self.las.points = self.las.points[mask] @@ -74,13 +80,14 @@ def include(self, include_classes: list[int]) -> Self: def exclude(self, exclude_classes: list[int]) -> Self: """ - Exclude points with specific classification values from the pipeline. + Exclude points from the point cloud based on their classification. Args: - exclude_classes (list[int]): List of classification values to exclude. + exclude_classes (list[int]): List of classification codes to exclude. Returns: Self: The modified pipeline object. + """ mask = np.isin(self.las.classification, exclude_classes, invert=True) self.las.points = self.las.points[mask] @@ -88,7 +95,10 @@ def exclude(self, exclude_classes: list[int]) -> Self: def clip(self) -> Self: """ - Clip the point cloud by a polygon. + Clips the point cloud to the extent of the city polygon. + + Returns: + Self: The modified pipeline object. """ rasterized_polygon, transform = rasterizer.polygon_to_raster( self._city_polygon(), self.raster_res @@ -118,15 +128,15 @@ def clip(self) -> Self: def clip_by_arbitrary_polygon(self, clip_file: str) -> Self: """ - Clip the point cloud by a polygon. + Clips the point cloud by an arbitrary polygon defined in a clip file. Args: - clip_file (str): The path to the polygon file. + clip_file (str): The path to the clip file containing the polygon. Returns: - Pipeline: The updated pipeline object. - + Self: The modified instance of the pipeline. """ + polygon = self._arbitrary_polygon(clip_file) rasterized_polygon, transform = rasterizer.polygon_to_raster( polygon, self.raster_res @@ -156,28 +166,28 @@ def clip_by_arbitrary_polygon(self, clip_file: str) -> Self: def clip_by_bbox(self, bbox: list[float]) -> Self: """ - Clips the point cloud by a bounding box. + Clips the point cloud by a given bounding box. Args: - bbox (list[float]): The bounding box to clip the point cloud. [xmin, ymin, xmax, ymax] + bbox (list[float]): The bounding box coordinates in the format [xmin, ymin, xmax, ymax]. Returns: - Self: The updated pipeline object. - + Self: The modified instance of the pipeline. """ - xyz = self.las.xyz - x_valid = (xyz[:, 0] >= bbox[0]) & (xyz[:, 0] <= bbox[2]) - y_valid = (xyz[:, 1] >= bbox[1]) & (xyz[:, 1] <= bbox[3]) - valid_points_mask = x_valid & y_valid + + x_valid = (self.las.x >= bbox[0]) & (self.las.x <= bbox[2]) + y_valid = (self.las.y >= bbox[1]) & (self.las.y <= bbox[3]) + valid_points_mask = np.where(x_valid & y_valid)[0] self.las.points = self.las.points[valid_points_mask] + return self def points(self) -> laspy.LasData: """ - Execute the pipeline. + Returns the point cloud data. Returns: - laspy.LasData: The processed point cloud. + laspy.LasData: The point cloud data. """ return self.las diff --git a/ahn_cli/manipulator/rasterizer.py b/ahn_cli/manipulator/rasterizer.py index 95862aa..5f8858f 100644 --- a/ahn_cli/manipulator/rasterizer.py +++ b/ahn_cli/manipulator/rasterizer.py @@ -1,7 +1,7 @@ from typing import Tuple import numpy as np -from rasterio import features +from rasterio.features import rasterize from rasterio.transform import Affine, from_origin from shapely import Polygon @@ -11,23 +11,25 @@ def polygon_to_raster( resolution: float, ) -> Tuple[np.ndarray, Affine]: """ - Convert a polygon to a raster file. + Converts a polygon into a rasterized numpy array. Args: - polygon (Polygon): The polygon to convert. - resolution (float): The resolution of the raster. + polygon (Polygon): The input polygon to be rasterized. + resolution (float): The desired resolution of the rasterized array. Returns: - None - + Tuple[np.ndarray, Affine]: A tuple containing the rasterized numpy array and the affine transformation matrix. """ + # Implementation code here + pass + bbox = polygon.bounds height = int((bbox[3] - bbox[1]) / resolution) width = int((bbox[2] - bbox[0]) / resolution) transform = from_origin(bbox[0], bbox[3], resolution, resolution) shape = (height, width) - rasterized = features.rasterize( + rasterized = rasterize( shapes=[polygon], out_shape=shape, transform=transform, diff --git a/ahn_cli/process.py b/ahn_cli/process.py index 08a7ce5..cfb6f9a 100644 --- a/ahn_cli/process.py +++ b/ahn_cli/process.py @@ -4,7 +4,7 @@ import numpy as np from tqdm import tqdm from ahn_cli.fetcher.request import Fetcher -from ahn_cli.manipulator.pipeline import PntCPipeline +from ahn_cli.manipulator.ptc_handler import PntCHandler from ahn_cli.manipulator.preview import previewer import laspy from laspy.lasappender import LasAppender @@ -24,7 +24,7 @@ def process( bbox: list[float] | None = None, preview: bool | None = False, ) -> None: - ahn_fetcher = Fetcher(base_url, city_name) + ahn_fetcher = Fetcher(base_url, city_name, bbox) fetched_files = ahn_fetcher.fetch() files = list(fetched_files.values()) @@ -43,7 +43,7 @@ def process( global_header.maxs = maxs global_header.mins = mins - pipeline = PntCPipeline( + p_handler = PntCHandler( las.read(), city_polygon_path, city_name, @@ -51,25 +51,28 @@ def process( ) if bbox is not None: - pipeline.clip_by_bbox(bbox) - if decimate is not None: - pipeline.decimate(decimate) + p_handler.clip_by_bbox(bbox) if include_classes is not None and len(include_classes) > 0: - pipeline.include(include_classes) + p_handler.include(include_classes) if exclude_classes is not None and len(exclude_classes) > 0: - pipeline.exclude(exclude_classes) - if not no_clip_city: - pipeline.clip() + p_handler.exclude(exclude_classes) + if not no_clip_city and city_name is not None: + p_handler.clip() if clip_file is not None: - pipeline.clip_by_arbitrary_polygon(clip_file) + p_handler.clip_by_arbitrary_polygon(clip_file) + if decimate is not None: + p_handler.decimate(decimate) with laspy.open( output_path, mode="w" if i == 0 else "a", header=global_header ) as writer: - points = pipeline.points().points + points = p_handler.points().points + if len(points) == 0: + continue points.x = points.x - offset[0] points.y = points.y - offset[1] points.z = points.z - offset[2] + if isinstance(writer, laspy.LasWriter): writer.write_points(points) if isinstance(writer, LasAppender): diff --git a/ahn_cli/validator.py b/ahn_cli/validator.py index a981dc1..0ff3bd9 100644 --- a/ahn_cli/validator.py +++ b/ahn_cli/validator.py @@ -100,7 +100,8 @@ def validate_all( bbox: list[float] | None = None, ) -> bool: validate_output(output_path) - validate_city(city_name, cfg.city_polygon_file) + if not bbox: + validate_city(city_name, cfg.city_polygon_file) validate_include_classes(include_classes) validate_exclude_classes(exclude_classes) validate_include_exclude(include_classes, exclude_classes) diff --git a/tests/fetcher/test_geotiles.py b/tests/fetcher/test_geotiles.py index 2ea234b..3de3629 100644 --- a/tests/fetcher/test_geotiles.py +++ b/tests/fetcher/test_geotiles.py @@ -1,9 +1,12 @@ import unittest -from ahn_cli.fetcher.geotiles import ahn_subunit_indicies_of_city +from ahn_cli.fetcher.geotiles import ( + ahn_subunit_indicies_of_city, + ahn_subunit_indicies_of_bbox, +) class TestGeoTile(unittest.TestCase): - def test_ahn(self) -> None: + def test_ahn_subunit_indicies_of_city(self) -> None: tiles = ahn_subunit_indicies_of_city("Delft") expected = [ "37EZ1_03", @@ -43,6 +46,28 @@ def test_ahn(self) -> None: ] self.assertEqual(tiles, expected) + def test_ahn_subunit_indicies_of_bbox(self) -> None: + bbox = [ + 84592.705048133007949, + 444443.127025160647463, + 86312.074818017281359, + 446712.346010794688482, + ] + tiles = ahn_subunit_indicies_of_bbox(bbox) + expected = [ + "37EN1_15", + "37EN1_20", + "37EN1_25", + "37EN2_11", + "37EN2_12", + "37EN2_16", + "37EN2_17", + "37EN2_21", + "37EN2_22", + ] + + self.assertListEqual(tiles, expected) + if __name__ == "__main__": unittest.main() diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index 1ec7f82..b083e9d 100644 --- a/tests/test_pipeline.py +++ b/tests/test_pipeline.py @@ -1,10 +1,9 @@ -import time import unittest import laspy import numpy as np -from ahn_cli.manipulator.pipeline import PntCPipeline +from ahn_cli.manipulator.ptc_handler import PntCHandler TEST_DATA0 = "./tests/testdata/westervoort0_thinned.las" TEST_DATA1 = "./tests/testdata/westervoort1_thinned.las" @@ -13,23 +12,23 @@ WESTERVOORT28992_FILE_PATH = "./tests/testdata/westervoort28992.geojson" -class TestPipeline(unittest.TestCase): +class TestPntCHandler(unittest.TestCase): def test_decimate(self) -> None: with laspy.open(TEST_DATA0) as reader: las = reader.read() - pipeline = PntCPipeline(las, CITY_FILE_PATH, "Westervoort") - points_before = len(pipeline.las.points) - pipeline.decimate(10) - points_after = len(pipeline.las.points) + p_handler = PntCHandler(las, CITY_FILE_PATH, "Westervoort") + points_before = len(p_handler.las.points) + p_handler.decimate(10) + points_after = len(p_handler.las.points) self.assertTrue(points_after < points_before) def test_include(self) -> None: with laspy.open(TEST_DATA0) as reader: las = reader.read() - pipeline = PntCPipeline(las, CITY_FILE_PATH, "Westervoort") - points_before = len(pipeline.las.points) - pipeline.include([2, 6]) - points_after = len(pipeline.las.points) + p_handler = PntCHandler(las, CITY_FILE_PATH, "Westervoort") + points_before = len(p_handler.las.points) + p_handler.include([2, 6]) + points_after = len(p_handler.las.points) self.assertTrue(points_after < points_before) classes2 = len(las.points[las.classification == 2]) classes6 = len(las.points[las.classification == 6]) @@ -47,10 +46,10 @@ def test_include(self) -> None: def test_exclude(self) -> None: with laspy.open(TEST_DATA0) as reader: las = reader.read() - pipeline = PntCPipeline(las, CITY_FILE_PATH, "Westervoort") - points_before = len(pipeline.las.points) - pipeline.exclude([2, 6]) - points_after = len(pipeline.las.points) + p_handler = PntCHandler(las, CITY_FILE_PATH, "Westervoort") + points_before = len(p_handler.las.points) + p_handler.exclude([2, 6]) + points_after = len(p_handler.las.points) self.assertTrue(points_after < points_before) classes2 = len(las.points[las.classification == 2]) classes6 = len(las.points[las.classification == 6]) @@ -62,38 +61,38 @@ def test_clip(self) -> None: las = reader.read() extra_dim = laspy.ExtraBytesParams(name="raster", type=np.uint8) las.add_extra_dim(extra_dim) - pipeline = PntCPipeline(las, CITY_FILE_PATH, "Westervoort") - points_before = len(pipeline.las.points) - pipeline.clip() - points_after = len(pipeline.las.points) + p_handler = PntCHandler(las, CITY_FILE_PATH, "Westervoort") + points_before = len(p_handler.las.points) + p_handler.clip() + points_after = len(p_handler.las.points) self.assertTrue(points_after < points_before) def test_clip_by_arbitrary_polygon(self) -> None: with laspy.open(TEST_DATA1) as reader: las = reader.read() - pipeline = PntCPipeline(las, CITY_FILE_PATH, "Westervoort") - points_before = len(pipeline.las.points) - pipeline.clip_by_arbitrary_polygon(WESTERVOORT_FILE_PATH) - points_after = len(pipeline.las.points) + p_handler = PntCHandler(las, CITY_FILE_PATH, "Westervoort") + points_before = len(p_handler.las.points) + p_handler.clip_by_arbitrary_polygon(WESTERVOORT_FILE_PATH) + points_after = len(p_handler.las.points) self.assertTrue(points_after < points_before) with laspy.open(TEST_DATA1) as reader: las = reader.read() - pipeline = PntCPipeline(las, CITY_FILE_PATH, "Westervoort", 28992) - points_before = len(pipeline.las.points) - pipeline.clip_by_arbitrary_polygon(WESTERVOORT28992_FILE_PATH) - points_after = len(pipeline.las.points) + p_handler = PntCHandler(las, CITY_FILE_PATH, "Westervoort", 28992) + points_before = len(p_handler.las.points) + p_handler.clip_by_arbitrary_polygon(WESTERVOORT28992_FILE_PATH) + points_after = len(p_handler.las.points) self.assertTrue(points_after < points_before) def test_clip_by_bbox(self) -> None: with laspy.open(TEST_DATA0) as reader: las = reader.read() - pipeline = PntCPipeline(las, CITY_FILE_PATH, "Westervoort") - points_before = len(pipeline.las.points) - pipeline.clip_by_bbox( + p_handler = PntCHandler(las, CITY_FILE_PATH, "Westervoort") + points_before = len(p_handler.las.points) + p_handler.clip_by_bbox( [194198.302994, 443461.343994, 194594.109009, 443694.838989] ) - points_after = len(pipeline.las.points) + points_after = len(p_handler.las.points) self.assertTrue(points_after < points_before)