Skip to content

Commit

Permalink
padding regions (#23)
Browse files Browse the repository at this point in the history
* initial draft of the numpy padding algorithm

* skeleton for the padding machinery

* use the right import for the neighbours search function

* don't show cell ids, pre-computed indices and the grid info in the repr

* implement the constant-mode padding

* add `ring` to the members of the linear ramp padding object

* comments on the padding algorithms

* depend on `xdggs`

* add `xdggs` to the environment

* replace the manual set difference with `numpy.setdiff1d`

* pin `scipy<1.14`
  • Loading branch information
keewis authored Jul 4, 2024
1 parent 5eae2bf commit 6a05e20
Show file tree
Hide file tree
Showing 5 changed files with 318 additions and 0 deletions.
4 changes: 4 additions & 0 deletions ci/requirements/environment.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ dependencies:
- numpy
- dask
- healpy
- scipy<1.14 # healpy uses the removed `scipy.integrate.trapz`
- geopandas
- pandas
- sparse
Expand All @@ -19,3 +20,6 @@ dependencies:
- shapely
- opt_einsum
- matplotlib
- pip
- pip:
- git+https://github.com/xarray-contrib/xdggs
2 changes: 2 additions & 0 deletions healpix_convolution/neighbours.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,8 @@ def neighbours(cell_ids, *, resolution, indexing_scheme, ring=1):
plus their immediate neighbours (a total of 24 cells), and so on.
"""
nside = 2**resolution
if ring < 0:
raise ValueError(f"ring must be a positive integer or 0, got {ring}")
if ring > nside:
raise ValueError(
"rings containing more than the neighbouring base pixels are not supported"
Expand Down
166 changes: 166 additions & 0 deletions healpix_convolution/padding.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,166 @@
from dataclasses import dataclass, field
from functools import partial

import numpy as np
from xarray.namedarray._typing import _arrayfunction_or_api as _ArrayLike
from xarray.namedarray._typing import _ScalarType
from xdggs.grid import DGGSInfo

from healpix_convolution.neighbours import neighbours as search_neighbours


@dataclass
class Padding:
cell_ids: _ArrayLike = field(repr=False)
insert_indices: _ArrayLike = field(repr=False)
grid_info: DGGSInfo = field(repr=False)

def apply(self, data):
raise NotImplementedError()


@dataclass
class ConstantPadding(Padding):
constant_value: _ScalarType

def apply(self, data):
common_dtype = np.result_type(data, self.constant_value)

return np.insert(
data.astype(common_dtype), self.insert_indices, self.constant_value, axis=-1
)


@dataclass
class LinearRampPadding(Padding):
end_value: _ScalarType
ring: int

border_indices: _ArrayLike
distance: _ArrayLike

def apply(self, data):
offsets = data[..., self.border_indices]
ramp = (self.end_value - offsets) / self.ring
pad_values = offsets + ramp * self.distance

return np.insert(data, self.insert_indices, pad_values, axis=-1)


@dataclass
class DataPadding(Padding):
data_indices: _ArrayLike

def apply(self, data):
pad_values = data[..., self.data_indices]

return np.insert(data, self.insert_indices, pad_values, axis=-1)


def constant_mode(cell_ids, neighbours, grid_info, constant_value):
all_cell_ids = np.unique(neighbours)
new_cell_ids = np.setdiff1d(
all_cell_ids, np.concatenate((np.array([-1]), cell_ids))
)

insert_indices = np.searchsorted(cell_ids, new_cell_ids)

return ConstantPadding(
cell_ids=all_cell_ids,
insert_indices=insert_indices,
grid_info=grid_info,
constant_value=constant_value,
)


def linear_ramp_mode(cell_ids, neighbours, grid_info, end_value):
# algorithm: for each padded cell find the closest edge cell and the distance
pass


def edge_mode(cell_ids, neighbours, grid_info):
# algorithm: for each padded cell find the closest edge cell
pass


def reflect_mode(cell_ids, neighbours, grid_info):
# algorithm: for each padded cell, find the closest edge cell and the distance, then take the index of the cell that in the same distance and direction as the edge cell from the padded cell
pass


def pad(
cell_ids,
*,
grid_info,
ring,
mode="constant",
constant_value=0,
end_value=0,
reflect_type="even",
):
"""pad an array
Parameters
----------
cell_ids : array-like
The cell ids.
grid_info : xdggs.DGGSInfo
The grid parameters.
ring : int
The pad width in rings around the input domain. Must be 0 or positive.
mode : str, default: "constant"
The padding mode. Can be one of:
- "constant": fill the padded cells with ``constant_value``.
- "linear_ramp": linearly interpolate the padded cells from the edge of the array
to ``end_value``. For ring 1, this is the same as ``mode="constant"``
- "edge": fill the padded cells with the values at the edge of the array.
- "reflect": pad with the reflected values.
constant_value : scalar, default: 0
The constant value used in constant mode.
end_value : scalar, default: 0
The othermost value to interpolate to. Only used in linear ramp mode.
reflect_type : {"even", "odd"}, default: "even"
The reflect type. Only used in reflect mode.
Returns
-------
padding_object : Padding
The padding object. Can be used to apply the same padding operation for different
arrays with the same geometry.
"""
# TODO: figure out how to allow reusing indices. How this works depends on the mode:
# - in constant mode, we have:
# * an array of new cell ids
# * an array of indices that indicate where to insert them
# * and the constant value
# - in the case of linear ramp, we have:
# * the new cell ids
# * an array of indices that indicate where to insert them
# * the value of where the ramp should end
# * the distance of each cell from the edge of the array
# - in all other cases, we have:
# * an array of new cell ids
# * an array of indices that indicate where to insert them
# * an array of indices that map existing values to the new cell ids
# To be able to reuse this, we need a set of dataclasses that can encapsulate that,
# plus a method to apply the padding to data.
neighbours = search_neighbours(
cell_ids,
resolution=grid_info.resolution,
indexing_scheme=grid_info.indexing_scheme,
ring=ring,
)

modes = {
"constant": partial(constant_mode, constant_value=constant_value),
"linear_ramp": partial(linear_ramp_mode, end_value=end_value, ring=ring),
"edge": edge_mode,
"reflect_mode": partial(reflect_mode, reflect_type=reflect_type),
}

mode_func = modes.get(mode)
if mode_func is None:
raise ValueError(f"unknown mode: {mode}")

return mode_func(cell_ids, neighbours, grid_info)
145 changes: 145 additions & 0 deletions healpix_convolution/tests/test_padding.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
import numpy as np
import pytest
import xdggs

from healpix_convolution import padding

try:
import dask.array as da
except ImportError:
da = None
dask_array_type = ()

dask_available = False
else:
dask_array_type = da.Array
dask_available = True

requires_dask = pytest.mark.skipif(not dask_available, reason="requires dask")


class TestArray:
@pytest.mark.parametrize("dask", (False, pytest.param(True, marks=requires_dask)))
@pytest.mark.parametrize(
["ring", "mode", "kwargs", "expected_cell_ids", "expected_data"],
(
pytest.param(
1,
"constant",
{"constant_value": np.nan},
np.array([163, 166, 167, 169, 171, 172, 173, 174, 175, 178, 184, 186]),
np.array(
[
np.nan,
np.nan,
np.nan,
np.nan,
np.nan,
1,
1,
np.nan,
np.nan,
np.nan,
np.nan,
np.nan,
]
),
id="constant-ring1-nan",
),
pytest.param(
2,
"constant",
{"constant_value": 0},
np.array(
[
160,
161,
162,
163,
164,
165,
166,
167,
168,
169,
170,
171,
172,
173,
174,
175,
176,
177,
178,
179,
184,
185,
186,
187,
853,
855,
861,
863,
885,
887,
]
),
np.array(
[
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
1,
1,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
]
),
id="constant-ring2-0",
),
),
)
def test_pad(self, dask, ring, mode, kwargs, expected_cell_ids, expected_data):
grid_info = xdggs.healpix.HealpixInfo(resolution=4, indexing_scheme="nested")
cell_ids = np.array([172, 173])

if not dask:
data = np.full_like(cell_ids, fill_value=1)
else:
import dask.array as da

data = da.full_like(cell_ids, fill_value=1, chunks=(1,))

padder = padding.pad(
cell_ids, grid_info=grid_info, ring=ring, mode=mode, **kwargs
)
actual = padder.apply(data)

if dask:
assert isinstance(actual, da.Array)

np.testing.assert_equal(padder.cell_ids, expected_cell_ids)
np.testing.assert_equal(actual, expected_data)
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ license = {text = "Apache-2.0"}
dependencies = [
"numpy",
"healpy",
"xdggs",
"sparse",
"numba",
"opt_einsum",
Expand Down

0 comments on commit 6a05e20

Please sign in to comment.