[FFTConvolve] Reduce pad_width for "constant" boundary conditions
Also adds comments on padding and TODOs.
joanrue authored and SepandKashani committed Aug 17, 2024
1 parent 11d4b4b commit 7f1120b
Showing 2 changed files with 28 additions and 12 deletions.
6 changes: 6 additions & 0 deletions dev/todo.txt
@@ -6,6 +6,12 @@ NumPy 2.0 Migration
 doc/
     api/ has/is-being updated to ND API interactively.

+pyxu.operator.linop.fft.filter
+    FFTConvolve/FFTCorrelate overpad for boundary conditions != "constant".
+    Currently there are two padding actions applied to the input array:
+    1. op._pad() pads the N-sized input array to N_pad with the desired boundary conditions (reflect, periodic, etc.),
+    2. the FFTs in the _stencil_chain used by apply/adjoint further zero-pad to shape (N_pad+K-1).
+
 pyxu.abc.solver
     Allow not writing data to disk at end.
         writeback_rate = None -> don't write (ever)
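
To make the two padding steps noted for `pyxu.operator.linop.fft.filter` concrete, here is a minimal sketch of the per-axis size arithmetic. The helper name `padded_sizes` and its return fields are hypothetical and not part of Pyxu; `N` is the input length along one axis and `K` the kernel length along that axis.

```python
def padded_sizes(N: int, K: int, mode: str) -> dict:
    """Per-axis sizes after the two padding steps (hypothetical helper, not Pyxu API)."""
    # Step 1: explicit Pad() with the requested boundary condition.
    # After this commit, "constant" mode skips this step entirely.
    p = 0 if mode == "constant" else K - 1
    N_pad = N + 2 * p
    # Step 2: the FFTs zero-pad the (already padded) array to N_pad + K - 1.
    N_fft = N_pad + K - 1
    return {"pad": (p, p), "N_pad": N_pad, "N_fft": N_fft}


print(padded_sizes(100, 5, "constant"))
print(padded_sizes(100, 5, "reflect"))
```

With `N = 100` and `K = 5`, "constant" mode transforms arrays of length 104, whereas "reflect" first pads to 108 and then transforms arrays of length 112.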
34 changes: 22 additions & 12 deletions src/pyxu/operator/linop/fft/filter.py
@@ -31,6 +31,7 @@ class FFTCorrelate(Stencil):
     .. rubric:: Implementation Notes

     * :py:class:`~pyxu.operator.FFTCorrelate` can scale to much larger kernels than :py:class:`~pyxu.operator.Stencil`.
+    * This implementation is most efficient with "constant" boundary conditions (default).
     * Kernels must be small enough to fit in memory, i.e. unbounded kernels are not allowed.
     * Kernels should be supplied as NUMPY/CUPY arrays. DASK arrays will be evaluated if provided.
     * :py:class:`~pyxu.operator.FFTCorrelate` instances are **not arraymodule-agnostic**: they will only work with
@@ -117,15 +118,19 @@ def _compute_pad_width(_kernel, _center, _mode) -> Pad.WidthSpec:
         N = _kernel[0].ndim
         pad_width = [None] * N
         for i in range(N):
-            if len(_kernel) == 1:  # non-separable filter
-                n = _kernel[0].shape[i]
-            else:  # separable filter(s)
-                n = _kernel[i].size
-
-            # 1. Pad/Trim operators are shared amongst [apply,adjoint]():
-            #    lhs/rhs are thus padded equally.
-            # 2. Pad width must match kernel dimensions to retain border effects.
-            pad_width[i] = (n - 1, n - 1)
+            if _mode[i] == "constant":
+                # FFT already implements padding with zeros to size N+K-1.
+                pad_width[i] = (0, 0)
+            else:
+                if len(_kernel) == 1:  # non-separable filter
+                    n = _kernel[0].shape[i]
+                else:  # separable filter(s)
+                    n = _kernel[i].size
+
+                # 1. Pad/Trim operators are shared amongst [apply,adjoint]():
+                #    lhs/rhs are thus padded equally.
+                # 2. Pad width must match kernel dimensions to retain border effects.
+                pad_width[i] = (n - 1, n - 1)
         return tuple(pad_width)

     @staticmethod
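
The comment "FFT already implements padding with zeros to size N+K-1" can be checked on a 1-D toy example. The sketch below is an illustration only: it uses a naive O(n^2) DFT built on the standard library instead of Pyxu's FFT backend, it shows convolution (correlation follows by flipping the kernel), and every function name in it is hypothetical.

```python
import cmath


def dft(x, n, inverse=False):
    """Naive O(n^2) DFT of `x`, zero-padded to length `n` (illustrative only)."""
    x = list(x) + [0.0] * (n - len(x))
    sign = 1 if inverse else -1
    out = [
        sum(x[m] * cmath.exp(sign * 2j * cmath.pi * k * m / n) for m in range(n))
        for k in range(n)
    ]
    return [v / n for v in out] if inverse else out


def fft_convolve_full(x, h):
    """Linear convolution via DFTs of length N+K-1; the zero-padding is implicit."""
    L = len(x) + len(h) - 1
    X, H = dft(x, L), dft(h, L)
    y = dft([a * b for a, b in zip(X, H)], L, inverse=True)
    return [v.real for v in y]


def direct_convolve_full(x, h):
    """Reference: linear convolution assuming zeros outside the signal support."""
    L = len(x) + len(h) - 1
    return [
        sum(h[k] * x[n - k] for k in range(len(h)) if 0 <= n - k < len(x))
        for n in range(L)
    ]


x = [1.0, 2.0, 3.0, 4.0]  # N = 4
h = [1.0, -1.0, 0.5]      # K = 3
y_fft = fft_convolve_full(x, h)
y_ref = direct_convolve_full(x, h)
print(max(abs(a - b) for a, b in zip(y_fft, y_ref)))  # tiny round-off error
```

Because the length-(N+K-1) transform leaves no room for circular wrap-around, the result matches the zero-boundary reference up to round-off, which is why an explicit zero `Pad()` in front of the FFT is redundant for "constant" mode.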
@@ -205,9 +210,14 @@ def _chain(x, stencils, fft_kwargs):
             # Compute (depth,boundary) values for [overlap,trim_internal]()
             N_stack = x.ndim - self.dim_rank
             depth = {ax: 0 for ax in range(x.ndim)}
-            for ax, (p_lhs, p_rhs) in enumerate(self._pad._pad_width):
-                assert p_lhs == p_rhs, "top-level Pad() should be symmetric."
-                depth[N_stack + ax] = p_lhs
+            for ax in range(self.dim_rank):
+                if len(stencils) == 1:  # non-separable filter
+                    n = stencils[0]._kernel.shape[ax]
+                else:  # separable filter(s)
+                    n = stencils[ax]._kernel.size
+                c = stencils[ax]._center[ax]
+                max_dist = max(c, n - c)
+                depth[N_stack + ax] = max_dist
             boundary = 0

             xp = ndi.module()
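
Since the pad width is now `(0, 0)` in "constant" mode, the overlap depth can presumably no longer be read off `self._pad._pad_width`, and the hunk above derives it from the kernel size and center instead. A minimal sketch of that rule follows, assuming plain tuples in place of stencil objects; the helper `overlap_depth` is hypothetical and not part of Pyxu.

```python
def overlap_depth(kernel_shape, center):
    """Per-axis overlap depth: max(c, n - c), as in the hunk above.

    kernel_shape: kernel length n along each axis (assumed plain tuple).
    center: kernel center index c along each axis (assumed plain tuple).
    """
    return tuple(max(c, n - c) for n, c in zip(kernel_shape, center))


depth = overlap_depth(kernel_shape=(3, 5), center=(1, 2))
print(depth)  # (2, 3)
```

Taking the larger of the two distances keeps the depth symmetric per axis, which matches the single-integer-per-axis `depth` dictionary built in `_chain()`.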