-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* initial draft of the numpy padding algorithm * skeleton for the padding machinery * use the right import for the neighbours search function * don't show cell ids, pre-computed indices and the grid info in the repr * implement the constant-mode padding * add `ring` to the members of the linear ramp padding object * comments on the padding algorithms * depend on `xdggs` * add `xdggs` to the environment * replace the manual set difference with `numpy.setdiff1d` * pin `scipy<1.14`
- Loading branch information
Showing
5 changed files
with
318 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,166 @@ | ||
from dataclasses import dataclass, field | ||
from functools import partial | ||
|
||
import numpy as np | ||
from xarray.namedarray._typing import _arrayfunction_or_api as _ArrayLike | ||
from xarray.namedarray._typing import _ScalarType | ||
from xdggs.grid import DGGSInfo | ||
|
||
from healpix_convolution.neighbours import neighbours as search_neighbours | ||
|
||
|
||
@dataclass
class Padding:
    """Base class for precomputed padding operations.

    A padding object stores everything needed to pad data along its last
    axis, so that it can be applied to several arrays. Subclasses implement
    :meth:`apply`.
    """

    # None of the fields below is shown in the generated repr.

    # cell ids associated with the padding (the padded domain in constant mode)
    cell_ids: _ArrayLike = field(repr=False)
    # positions handed to ``np.insert`` when applying the padding
    insert_indices: _ArrayLike = field(repr=False)
    # parameters of the grid the cell ids belong to
    grid_info: DGGSInfo = field(repr=False)

    def apply(self, data):
        """Pad ``data``; must be overridden by subclasses.

        Raises
        ------
        NotImplementedError
            Always, the base class does not define a padding strategy.
        """
        raise NotImplementedError
|
||
|
||
@dataclass
class ConstantPadding(Padding):
    """Padding that fills every inserted cell with a single constant."""

    # the value written into each inserted position
    constant_value: _ScalarType

    def apply(self, data):
        """Insert ``constant_value`` into ``data`` along the last axis.

        Parameters
        ----------
        data : array-like
            The data to pad. Must provide ``astype``.

        Returns
        -------
        padded : array-like
            A new array with ``constant_value`` inserted at ``insert_indices``.
        """
        fill = self.constant_value
        # Promote first so the fill value is representable in the result
        # (e.g. a NaN fill for integer data).
        promoted = data.astype(np.result_type(data, fill))

        return np.insert(promoted, self.insert_indices, fill, axis=-1)
|
||
|
||
@dataclass
class LinearRampPadding(Padding):
    """Padding that ramps linearly from border values towards ``end_value``."""

    # the value the ramp ends at
    end_value: _ScalarType
    # the pad width; the ramp is divided into this many steps
    ring: int

    # indices (along the last axis of the data) of the border value used for
    # each inserted cell
    border_indices: _ArrayLike
    # multiplier of the per-step ramp for each inserted cell
    # NOTE(review): presumably the distance in rings from the border -- confirm
    distance: _ArrayLike

    def apply(self, data):
        """Insert linearly interpolated values into ``data`` along the last axis.

        Parameters
        ----------
        data : array-like
            The data to pad. Must provide ``astype``.

        Returns
        -------
        padded : array-like
            A new array with the ramp values inserted at ``insert_indices``.
            The dtype is the common dtype of ``data`` and the ramp values.
        """
        offsets = data[..., self.border_indices]
        ramp = (self.end_value - offsets) / self.ring
        pad_values = offsets + ramp * self.distance

        # ``np.insert`` casts the inserted values to the dtype of the array,
        # which would truncate the (floating point) ramp for integer data.
        common_dtype = np.result_type(data.dtype, pad_values.dtype)

        return np.insert(
            data.astype(common_dtype), self.insert_indices, pad_values, axis=-1
        )
|
||
|
||
@dataclass
class DataPadding(Padding):
    """Padding that fills the inserted cells with existing data values."""

    # indices (along the last axis of the data) of the value copied into each
    # inserted cell
    data_indices: _ArrayLike

    def apply(self, data):
        """Insert values taken from ``data`` itself along the last axis.

        Parameters
        ----------
        data : array-like
            The data to pad.

        Returns
        -------
        padded : array-like
            A new array with ``data[..., data_indices]`` inserted at
            ``insert_indices``.
        """
        return np.insert(
            data, self.insert_indices, data[..., self.data_indices], axis=-1
        )
|
||
|
||
def constant_mode(cell_ids, neighbours, grid_info, constant_value):
    """Build the padding object for constant-mode padding.

    Parameters
    ----------
    cell_ids : array-like
        The cell ids of the input domain.
        NOTE(review): ``np.searchsorted`` requires these to be sorted in
        ascending order -- confirm that callers guarantee this.
    neighbours : array-like
        The cell ids of the neighbourhoods of ``cell_ids``. Entries equal to
        ``-1`` are treated as "no cell" and ignored.
        NOTE(review): presumably this includes the input cells themselves,
        as the padded cell ids are derived from it alone -- confirm.
    grid_info : xdggs.DGGSInfo
        The grid parameters.
    constant_value : scalar
        The value to fill the padded cells with.

    Returns
    -------
    padding : ConstantPadding
        The padding object.
    """
    all_cell_ids = np.unique(neighbours)
    # -1 is a sentinel, not a cell id: drop it so that the padded cell ids
    # have the same length as the padded data
    all_cell_ids = all_cell_ids[all_cell_ids != -1]

    new_cell_ids = np.setdiff1d(all_cell_ids, cell_ids)

    # position of each new cell within the existing cells
    insert_indices = np.searchsorted(cell_ids, new_cell_ids)

    return ConstantPadding(
        cell_ids=all_cell_ids,
        insert_indices=insert_indices,
        grid_info=grid_info,
        constant_value=constant_value,
    )
|
||
|
||
def linear_ramp_mode(cell_ids, neighbours, grid_info, end_value, ring=None):
    """Build the padding object for linear-ramp padding (not implemented yet).

    Parameters
    ----------
    cell_ids : array-like
        The cell ids of the input domain.
    neighbours : array-like
        The cell ids of the neighbourhoods of ``cell_ids``.
    grid_info : xdggs.DGGSInfo
        The grid parameters.
    end_value : scalar
        The value the ramp should end at.
    ring : int, optional
        The pad width in rings. Accepted so that ``pad`` can forward it.

    Raises
    ------
    NotImplementedError
        Always, instead of silently returning ``None``.
    """
    # algorithm: for each padded cell find the closest edge cell and the distance
    raise NotImplementedError("the 'linear_ramp' padding mode is not implemented yet")
|
||
|
||
def edge_mode(cell_ids, neighbours, grid_info):
    """Build the padding object for edge-mode padding (not implemented yet).

    Parameters
    ----------
    cell_ids : array-like
        The cell ids of the input domain.
    neighbours : array-like
        The cell ids of the neighbourhoods of ``cell_ids``.
    grid_info : xdggs.DGGSInfo
        The grid parameters.

    Raises
    ------
    NotImplementedError
        Always, instead of silently returning ``None``.
    """
    # algorithm: for each padded cell find the closest edge cell
    raise NotImplementedError("the 'edge' padding mode is not implemented yet")
|
||
|
||
def reflect_mode(cell_ids, neighbours, grid_info, reflect_type="even"):
    """Build the padding object for reflect-mode padding (not implemented yet).

    Parameters
    ----------
    cell_ids : array-like
        The cell ids of the input domain.
    neighbours : array-like
        The cell ids of the neighbourhoods of ``cell_ids``.
    grid_info : xdggs.DGGSInfo
        The grid parameters.
    reflect_type : {"even", "odd"}, default: "even"
        The reflect type. Accepted so that ``pad`` can forward it.

    Raises
    ------
    NotImplementedError
        Always, instead of silently returning ``None``.
    """
    # algorithm: for each padded cell, find the closest edge cell and the
    # distance, then take the index of the cell that is at the same distance
    # and in the same direction from the edge cell as the edge cell is from
    # the padded cell
    raise NotImplementedError("the 'reflect' padding mode is not implemented yet")
|
||
|
||
def pad(
    cell_ids,
    *,
    grid_info,
    ring,
    mode="constant",
    constant_value=0,
    end_value=0,
    reflect_type="even",
):
    """Construct a reusable padding object for a set of cells.

    Parameters
    ----------
    cell_ids : array-like
        The cell ids.
    grid_info : xdggs.DGGSInfo
        The grid parameters.
    ring : int
        The pad width in rings around the input domain. Must be 0 or positive.
    mode : str, default: "constant"
        The padding mode. Can be one of:

        - "constant": fill the padded cells with ``constant_value``.
        - "linear_ramp": linearly interpolate the padded cells from the edge of
          the array to ``end_value``. For ring 1, this is the same as
          ``mode="constant"`` with ``constant_value=end_value``.
        - "edge": fill the padded cells with the values at the edge of the array.
        - "reflect": pad with the reflected values.
    constant_value : scalar, default: 0
        The constant value used in constant mode.
    end_value : scalar, default: 0
        The outermost value to interpolate to. Only used in linear ramp mode.
    reflect_type : {"even", "odd"}, default: "even"
        The reflect type. Only used in reflect mode.

    Returns
    -------
    padding_object : Padding
        The padding object. Can be used to apply the same padding operation for
        different arrays with the same geometry.

    Raises
    ------
    ValueError
        If ``mode`` is not one of the supported modes.
    """
    # To allow reusing the computed indices, each mode returns a dataclass that
    # encapsulates what is needed to apply the padding to data:
    # - in constant mode:
    #   * an array of new cell ids
    #   * an array of indices that indicate where to insert them
    #   * and the constant value
    # - in the case of linear ramp:
    #   * the new cell ids
    #   * an array of indices that indicate where to insert them
    #   * the value of where the ramp should end
    #   * the distance of each cell from the edge of the array
    # - in all other cases:
    #   * an array of new cell ids
    #   * an array of indices that indicate where to insert them
    #   * an array of indices that map existing values to the new cell ids
    modes = {
        "constant": partial(constant_mode, constant_value=constant_value),
        "linear_ramp": partial(linear_ramp_mode, end_value=end_value, ring=ring),
        "edge": edge_mode,
        # keyed by the documented mode name
        "reflect": partial(reflect_mode, reflect_type=reflect_type),
    }

    # validate the mode first so that a typo fails before the neighbour search
    mode_func = modes.get(mode)
    if mode_func is None:
        raise ValueError(f"unknown mode: {mode}")

    neighbours = search_neighbours(
        cell_ids,
        resolution=grid_info.resolution,
        indexing_scheme=grid_info.indexing_scheme,
        ring=ring,
    )

    return mode_func(cell_ids, neighbours, grid_info)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,145 @@ | ||
import numpy as np | ||
import pytest | ||
import xdggs | ||
|
||
from healpix_convolution import padding | ||
|
||
# dask is an optional dependency: record whether it can be imported so that
# the dask-specific test parameters can be skipped without it
try:
    import dask.array as da
except ImportError:
    dask_available = False

    da = None
    dask_array_type = ()
else:
    dask_available = True

    dask_array_type = da.Array

# marker for test parameters that need dask
requires_dask = pytest.mark.skipif(not dask_available, reason="requires dask")
|
||
|
||
class TestArray:
    """Tests for padding plain (numpy or dask) arrays."""

    @pytest.mark.parametrize("dask", (False, pytest.param(True, marks=requires_dask)))
    @pytest.mark.parametrize(
        ["ring", "mode", "kwargs", "expected_cell_ids", "expected_data"],
        (
            pytest.param(
                1,
                "constant",
                {"constant_value": np.nan},
                np.array([163, 166, 167, 169, 171, 172, 173, 174, 175, 178, 184, 186]),
                # the two input cells (172, 173) keep their value of 1
                np.array([np.nan] * 5 + [1, 1] + [np.nan] * 5),
                id="constant-ring1-nan",
            ),
            pytest.param(
                2,
                "constant",
                {"constant_value": 0},
                np.array(
                    [
                        *range(160, 180),
                        *(184, 185, 186, 187),
                        *(853, 855, 861, 863, 885, 887),
                    ]
                ),
                # the two input cells (172, 173) keep their value of 1
                np.array([0] * 12 + [1, 1] + [0] * 16),
                id="constant-ring2-0",
            ),
        ),
    )
    def test_pad(self, dask, ring, mode, kwargs, expected_cell_ids, expected_data):
        """Pad two neighbouring cells and compare cell ids and data."""
        cell_ids = np.array([172, 173])
        grid_info = xdggs.healpix.HealpixInfo(resolution=4, indexing_scheme="nested")

        if dask:
            data = da.full_like(cell_ids, fill_value=1, chunks=(1,))
        else:
            data = np.full_like(cell_ids, fill_value=1)

        padder = padding.pad(
            cell_ids, grid_info=grid_info, ring=ring, mode=mode, **kwargs
        )
        actual = padder.apply(data)

        if dask:
            # padding must not compute the dask array
            assert isinstance(actual, da.Array)

        np.testing.assert_equal(padder.cell_ids, expected_cell_ids)
        np.testing.assert_equal(actual, expected_data)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters