Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Dask for coregistration #525

Draft
wants to merge 23 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
c193b20
feat: preprocess_coreg for xarray input
ameliefroessl Apr 30, 2024
e54fa12
feat: subsample with dask array
ameliefroessl May 21, 2024
70d48e7
tests: test biascorr synthetic data for dask/ xarray
ameliefroessl May 22, 2024
85b54ef
fix: calculating mask on the correct input
ameliefroessl May 31, 2024
f5a75e5
refactor: clearer mask
ameliefroessl May 31, 2024
43bb825
feat: dask version of meshgrid
ameliefroessl May 31, 2024
f5778ed
feat: map_blocks for polyval2d
ameliefroessl May 31, 2024
74e0981
feat: apply() adapted to dask input
ameliefroessl Jun 4, 2024
6be8a15
fix: precommit
ameliefroessl Jun 4, 2024
044d6f0
refactor: rename postprocessing function
ameliefroessl Jun 5, 2024
d88f09e
fix: remove explicit cog paths and wrong default parameter
ameliefroessl Jun 5, 2024
9f5518c
tests: create raster mask on the fly from vector data in xarray test …
ameliefroessl Jun 6, 2024
110c6cb
fix: correct type for fit function parameters
ameliefroessl Jun 6, 2024
b46d20c
fix: error with weights
ameliefroessl Jun 6, 2024
0452bd8
fix: calculate valid_mask from bias_var
ameliefroessl Jun 7, 2024
1d74be6
refactor: generalized wrapper for fit functions to evaluate them chunked
ameliefroessl Jun 10, 2024
7ce44d1
feat: map logic for _postprocess_coreg_apply_xarray
ameliefroessl Jun 10, 2024
cb07a80
tests: remove unused mask_cog file reference
ameliefroessl Jun 10, 2024
1c94c72
docs: cleaning up a bit
ameliefroessl Jun 11, 2024
67a432a
docs: correct type hints
ameliefroessl Jun 11, 2024
624d69b
refactor: get_valid_data, mask_data and valid_data_darr for generaliz…
ameliefroessl Jun 11, 2024
3adb1b6
fix: fixing pytest
ameliefroessl Jun 14, 2024
1c57a38
refactor: moving dask check out of if for general biascorr dask
ameliefroessl Jun 14, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 80 additions & 8 deletions tests/test_coreg/test_biascorr.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,19 @@
"""Tests for the biascorr module (non-rigid coregistrations)."""

from __future__ import annotations

import re
import warnings

import dask.array as da
import geopandas as gpd
import geoutils as gu
import numpy as np
import pytest
import rasterio
import rioxarray
import scipy
from xarray.core.dataarray import DataArray

import xdem.terrain
from xdem import examples
Expand All @@ -28,6 +33,24 @@ def load_examples() -> tuple[gu.Raster, gu.Raster, gu.Vector]:
return reference_raster, to_be_aligned_raster, glacier_mask


def load_examples_xarray(chunk_size: int = 256) -> tuple[DataArray, DataArray, DataArray]:
    """
    Load COG example files as chunked (dask-backed) xarray DataArrays.

    Used to exercise the delayed / dask coregistration code paths.

    :param chunk_size: Chunk size along both x and y. Defaults to 256, the blocksize of the
        example COGs, so that dask chunks align with the on-disk tiles.

    :return: Tuple of (reference DEM, to-be-aligned DEM, inlier mask) as DataArrays.
    """
    # Fetch the reference-DEM path once; it is needed both for the DEM and the mask grid.
    ref_dem_path = examples.get_path("longyearbyen_ref_dem")
    reference_raster = rioxarray.open_rasterio(
        filename=ref_dem_path, chunks={"x": chunk_size, "y": chunk_size}
    ).squeeze()
    to_be_aligned_raster = rioxarray.open_rasterio(
        filename=examples.get_path("longyearbyen_tba_dem"), chunks={"x": chunk_size, "y": chunk_size}
    ).squeeze()

    # Create a raster mask on the fly from the vector data
    glacier_mask_vector = gu.Vector(examples.get_path("longyearbyen_glacier_outlines"))
    inlier_mask = glacier_mask_vector.create_mask(raster=gu.Raster(ref_dem_path))
    # Re-chunk the mask identically to the reference raster so per-chunk operations line up.
    inlier_mask = DataArray(da.from_array(inlier_mask.data.data, chunks=reference_raster.chunks))

    return reference_raster, to_be_aligned_raster, inlier_mask


class TestBiasCorr:
ref, tba, outlines = load_examples() # Load example reference, to-be-aligned and mask.
inlier_mask = ~outlines.create_mask(ref)
Expand All @@ -41,6 +64,15 @@ class TestBiasCorr:
verbose=True,
)

# Load Xarray - Xarray example data
ref_xarr, tba_xarr, mask_xarr = load_examples_xarray()
fit_args_xarr_xarr = dict(
reference_elev=ref_xarr,
to_be_aligned_elev=tba_xarr,
inlier_mask=mask_xarr,
verbose=True,
)

# Convert DEMs to points with a bit of subsampling for speed-up
tba_pts = tba.to_pointcloud(data_column_name="z", subsample=50000, random_state=42).ds

Expand All @@ -64,6 +96,12 @@ class TestBiasCorr:

all_fit_args = [fit_args_rst_rst, fit_args_rst_pts, fit_args_pts_rst]

# used to test the methods that have already been adapted to dask
# once all methods are adapted the fit_args_xarr_xarr can be added to all_fit_args
# without having to define them separately
all_fit_args_xaray = all_fit_args.copy()
all_fit_args_xaray.append(fit_args_xarr_xarr)

def test_biascorr(self) -> None:
"""Test the parent class BiasCorr instantiation."""

Expand Down Expand Up @@ -498,11 +536,14 @@ def test_deramp(self) -> None:
# Check that variable names are defined during instantiation
assert deramp.meta["bias_var_names"] == ["xx", "yy"]

@pytest.mark.parametrize("fit_args", all_fit_args) # type: ignore
@pytest.mark.parametrize("fit_args", all_fit_args_xaray) # type: ignore
@pytest.mark.parametrize("order", [1, 2, 3, 4]) # type: ignore
def test_deramp__synthetic(self, fit_args, order: int) -> None:
"""Run the deramp for varying polynomial orders using a synthetic elevation difference."""

# These warnings would cause pytest to fail, even though there is no issue with the data
warnings.filterwarnings("ignore", category=rasterio.errors.NotGeoreferencedWarning)

# Get coordinates
xx, yy = np.meshgrid(np.arange(0, self.ref.shape[1]), np.arange(0, self.ref.shape[0]))

Expand All @@ -515,28 +556,59 @@ def test_deramp__synthetic(self, fit_args, order: int) -> None:

# Create a synthetic bias and add to the DEM
synthetic_bias = polynomial_2d((xx, yy), *params)
bias_dem = self.ref - synthetic_bias

elev_fit_args = fit_args.copy()

if isinstance(elev_fit_args["reference_elev"], DataArray):
# Unfortunately subtracting two rioxarrays loses their geospatial properties. So we need to create
# a new output rioxarray DataArray
bias_dem = DataArray(
da.from_array(
elev_fit_args["reference_elev"].data.compute() - synthetic_bias,
chunks=elev_fit_args["reference_elev"].data.chunks,
)
)
# Reset properties. Order matters!!
bias_dem = bias_dem.rio.write_transform(elev_fit_args["reference_elev"].rio.transform())
bias_dem = bias_dem.rio.set_crs(elev_fit_args["reference_elev"].rio.crs)
bias_dem = bias_dem.rio.set_nodata(input_nodata=elev_fit_args["reference_elev"].rio.nodata)

else:
bias_dem = self.ref - synthetic_bias

# Fit
deramp = biascorr.Deramp(poly_order=order)
elev_fit_args = fit_args.copy()
if isinstance(elev_fit_args["to_be_aligned_elev"], gpd.GeoDataFrame):
bias_elev = bias_dem.to_pointcloud(data_column_name="z", subsample=30000, random_state=42).ds
else:
bias_elev = bias_dem
deramp.fit(elev_fit_args["reference_elev"], to_be_aligned_elev=bias_elev, subsample=20000, random_state=42)

deramp.fit(
elev_fit_args["reference_elev"],
to_be_aligned_elev=bias_elev,
inlier_mask=elev_fit_args["inlier_mask"],
subsample=20000,
random_state=42,
)

# Check high-order fit parameters are the same within 10%
fit_params = deramp.meta["fit_params"]
assert np.shape(fit_params) == np.shape(params)
assert np.allclose(
params.reshape(order + 1, order + 1)[-1:, -1:], fit_params.reshape(order + 1, order + 1)[-1:, -1:], rtol=0.1
params.reshape(order + 1, order + 1)[-1:, -1:],
fit_params.reshape(order + 1, order + 1)[-1:, -1:],
rtol=0.1,
)

# Run apply and check that 99% of the variance was corrected
corrected_dem = deramp.apply(bias_dem)
# Need to standardize by the synthetic bias spread to avoid huge/small values close to infinity
assert np.nanvar((corrected_dem - self.ref) / np.nanstd(synthetic_bias)) < 0.01
if isinstance(bias_dem, DataArray):
corrected_dem, _ = deramp.apply(bias_dem)
corrected_dem = corrected_dem.compute()
assert np.nanvar((corrected_dem - elev_fit_args["reference_elev"]) / np.nanstd(synthetic_bias)) < 0.01
else:
corrected_dem = deramp.apply(bias_dem)
# Need to standardize by the synthetic bias spread to avoid huge/small values close to infinity
assert np.nanvar((corrected_dem - self.ref) / np.nanstd(synthetic_bias)) < 0.01

def test_terrainbias(self) -> None:
"""Test the subclass TerrainBias."""
Expand Down
137 changes: 137 additions & 0 deletions tests/test_coreg/test_delayed.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
from unittest.mock import Mock

import dask.array as da
import numpy as np
import pytest
import rasterio as rio

from xdem._typing import NDArrayb, NDArrayf
from xdem.coreg.base import (
_select_transform_crs,
get_valid_data,
mask_data,
valid_data_darr,
)


@pytest.mark.filterwarnings("ignore::UserWarning")  # type: ignore [misc]
@pytest.mark.parametrize(  # type: ignore [misc]
    "epsg_ref, epsg_other, epsg, expected",
    [
        (3246, 4326, 3005, 3246),
        (None, 4326, 3005, 4326),
        (None, None, 3005, 3005),
    ],
)
def test__select_transform_crs_selects_correct_crs(
    epsg_ref: int | None, epsg_other: int | None, epsg: int, expected: int
) -> None:
    """
    Test that _select_transform_crs selects the correct CRS.

    Expected precedence (per the parametrized cases): the reference CRS wins when given,
    otherwise the other CRS, otherwise the explicitly passed fallback CRS.
    """
    # Only CRS selection is under test here, so any Affine stand-in works for the transforms.
    mock_transform = Mock(rio.transform.Affine)

    _, crs = _select_transform_crs(
        transform=mock_transform,
        crs=rio.crs.CRS.from_epsg(epsg),
        transform_reference=mock_transform,
        transform_other=mock_transform,
        crs_reference=rio.crs.CRS.from_epsg(epsg_ref) if epsg_ref is not None else None,
        crs_other=rio.crs.CRS.from_epsg(epsg_other) if epsg_other is not None else None,
    )
    assert crs.to_epsg() == expected


def test__select_transform_crs_selects_correct_transform() -> None:
    """Test _select_transform_crs selects the correct transform."""
    # TODO: implement once the expected transform-selection behaviour is pinned down.


@pytest.mark.parametrize(  # type: ignore[misc]
    "data,nodata,expected",
    [
        (
            np.array([np.nan, 1, -100, 1]),
            -100,
            np.array([np.nan, 1, np.nan, 1]),
        ),
        (
            np.array([1, 1, -100, 1]),
            -100,
            np.array([1, 1, np.nan, 1]),
        ),
        (
            np.array([np.nan, 1, 1, 1]),
            -100,
            np.array([np.nan, 1, 1, 1]),
        ),
    ],
)
def test_mask_data(data: NDArrayf, nodata: int, expected: NDArrayf) -> None:
    """
    Test that mask_data masks the correct values.

    Per the parametrized cases: nodata entries become NaN, while existing NaNs and
    valid values pass through unchanged. (Parameter renamed from `input`, which
    shadowed the builtin.)
    """
    output = mask_data(data=data, nodata=nodata)
    assert np.array_equal(output, expected, equal_nan=True)


@pytest.mark.parametrize(  # type: ignore [misc]
    "input_arrays,nodatas,expected",
    [
        (
            (np.array([np.nan, 1, -100, 1]),),
            (-100,),
            np.array([False, True, False, True]),
        ),
        (
            (np.array([np.nan, 1, -100, 1]), np.array([1, -200, 1, 1])),
            (-100, -200),
            np.array([False, False, False, True]),
        ),
        (
            (np.array([np.nan, 1, -100, 1]), np.array([1, -200, 1, 1]), np.array([1, 1, 1, -400])),
            (-100, -200, -400),
            np.array([False, False, False, False]),
        ),
    ],
)
def test_get_valid_data(input_arrays: tuple[NDArrayf], nodatas: tuple[int], expected: NDArrayb) -> None:
    """Check that get_valid_data flags exactly the entries that are valid in every input array."""
    valid = get_valid_data(*input_arrays, nodatas=nodatas)
    assert np.array_equal(valid, expected, equal_nan=True)


@pytest.mark.parametrize(  # type: ignore [misc]
    "input_arrays,mask,nodatas,expected",
    [
        (
            (
                da.from_array(np.array([1, 1, -100, 1]), chunks=2),
                da.from_array(np.array([1, 1, -200, 1]), chunks=2),
            ),
            None,
            (-100, -200),
            np.array([True, True, False, True]),
        ),
        (
            (
                da.from_array(np.array([1, 1, -100, 1]), chunks=2),
                da.from_array(np.array([1, 1, -200, 1]), chunks=2),
            ),
            da.from_array([False, True, True, True]),
            (-100, -200),
            np.array([False, True, False, True]),
        ),
    ],
)
def test_valid_data_darr(
    input_arrays: tuple[NDArrayf], mask: NDArrayb, nodatas: tuple[int], expected: NDArrayb
) -> None:
    """Check that valid_data_darr builds the expected validity mask, honouring an optional extra mask."""
    computed = valid_data_darr(*input_arrays, mask=mask, nodatas=nodatas).compute()
    assert np.array_equal(computed, expected, equal_nan=True)
Loading