Skip to content

Commit

Permalink
Merge pull request #383 from ortk95/382-mapping-interpolation
Browse files Browse the repository at this point in the history
Improve mapping for images containing NaN values
  • Loading branch information
ortk95 authored Jul 18, 2024
2 parents de115d6 + b2544e0 commit c36b8e9
Show file tree
Hide file tree
Showing 16 changed files with 386 additions and 257 deletions.
2 changes: 2 additions & 0 deletions .github/workflows/checks.yml
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ jobs:
pylint:
runs-on: ubuntu-latest
strategy:
fail-fast: false
matrix:
python-version: ["3.10", "3.11", "3.12"]
steps:
Expand All @@ -61,6 +62,7 @@ jobs:
test:
runs-on: ${{ matrix.os }}
strategy:
fail-fast: false
matrix:
python-version: ["3.10", "3.11", "3.12"]
os: [ubuntu-latest, windows-latest, macos-latest]
Expand Down
201 changes: 134 additions & 67 deletions planetmapper/body_xy.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import numpy as np
import pyproj
import scipy.interpolate
import scipy.ndimage
from matplotlib.axes import Axes
from matplotlib.collections import QuadMesh
from matplotlib.figure import Figure
Expand Down Expand Up @@ -1202,49 +1203,50 @@ def map_img(
Project an observed image to a map. See :func:`generate_map_coordinates` for
details about customising the projection used.
If `interpolation` is `'linear'`, `'quadratic'` or `'cubic'`, the map projection
is performed using `scipy.interpolate.RectBivariateSpline` using the specified
degree of interpolation.
If `interpolation` is `'linear'` (the default), `'quadratic'` or `'cubic'`, the
map projection is performed using `scipy.interpolate.RectBivariateSpline` using
the specified degree of spline interpolation. This spline interpolation does not
accept NaN values in the input data, so any NaN pixels are automatically
replaced by the average value of the surrounding pixels (3x3 footprint) before
spline interpolation is performed, preventing or significantly reducing the
magnitude of any artefacts on the edge of NaN regions.
If `interpolation` is `'nearest'`, no interpolation is performed, and the mapped
image takes the value of the nearest pixel in the image to that location. This
can be useful to easily visualise the pixel scale for low spatial resolution
observations.
To map a cube, this function can be called repeatedly on each image in the cube:
::
mapped_cube = np.array([body.map_img(img) for img in cube])
See also :func:`Observation.get_mapped_data`.
Args:
img: Observed image where pixel coordinates correspond to the `xy` pixel
coordinates (e.g. those used in :func:`get_x0`).
degree_interval: Interval in degrees between the longitude/latitude points
in the mapped output. Passed to :func:`get_x_map` and :func:`get_y_map`
when generating the coordinates used for the projection.
interpolation: Interpolation used when mapping. This can be any of
`'nearest'`, `'linear'`, `'quadratic'` or `'cubic'`; the default is
`'linear'`. `'linear'`, `'quadratic'` and `'cubic'` are aliases for
spline interpolations of degree 1, 2 and 3 respectively. Alternatively,
the degree of spline interpolation can be specified manually by passing
an integer or tuple of integers. If an integer is passed, the same
coordinates (e.g. those used in :func:`get_x0`). If `img` is a data cube
(i.e. has 3 dimensions), then each image in the cube is mapped
and a mapped cube is returned, equivalent to
`np.array([body.map_img(img) for img in cube])`.
interpolation: Interpolation method used when mapping. This can be any of
`'nearest'`, `'linear'` (the default), `'quadratic'` or `'cubic'`.
`'linear'`, `'quadratic'` and `'cubic'` are aliases for spline
interpolations of degree 1, 2 and 3 respectively. Alternatively, the
degree of spline interpolation can be specified manually by passing an
integer or tuple of integers. If an integer is passed, the same
interpolation is used in both the x and y directions (i.e.
`RectBivariateSpline` with `kx = ky = interpolation`). If a tuple of
integers is passed, the first integer is used for the x direction and
the second integer is used for the y direction (i.e.
`RectBivariateSpline` with `kx, ky = interpolation`).
spline_smoothing: Smoothing factor passed to
`RectBivariateSpline(..., s=spline_smoothing)` when spline interpolation
is used. This parameter is ignored when `interpolation='nearest'`.
propagate_nan: If using spline interpolation, propagate NaN values from the
image to the mapped data. If `propagate_nan` is `True` (the default),
the interpolation is performed as normal (i.e. with NaN values in the
image set to 0), then any mapped locations where the nearest
corresponding image pixel is NaN are set to NaN. Note that there may
still be very small errors on the boundaries of NaN regions caused by
the interpolation.
is used. This can be useful to smooth over noisy data, though may hide
subtle structure if a large smoothing value is used. This parameter is
ignored if spline interpolation is not used.
propagate_nan: By default (`propagate_nan=True`) when performing spline
interpolation, areas of the map corresponding to NaN values in the image
are set to NaN, along with areas outside the convex hull of the image's
pixel centres. If `propagate_nan` is `False`, then areas corresponding
to NaN pixels will be filled with interpolated/extrapolated values where
possible, which can be useful to fill in regions of missing data.
This parameter has no effect if `interpolation='nearest'`.
warn_nan: Print warning if any values in `img` are NaN when any of the
spline interpolations are used.
**map_kwargs: Additional arguments are passed to
Expand All @@ -1255,21 +1257,49 @@ def map_img(
Array containing map of the values in `img` at each location on the surface
of the target body. Locations which are not visible or outside the
projection domain have a value of NaN.
"""
Raises:
ValueError: if the input `img` shape is inconsistent with the body's image
size.
.. versionchanged:: ?
Added more sophisticated replacement of NaN values in `img` before
performing spline interpolation, preventing or significantly reducing any
artefacts on the edge of NaN regions. NaN values are now replaced by the
average value of the surrounding non-NaN pixels, whereas previously NaN
values were always replaced with 0.
"""
img = np.asarray(img)
if img.ndim == 3:
return np.array(
[
self.map_img(
img_slice,
interpolation=interpolation,
spline_smoothing=spline_smoothing,
propagate_nan=propagate_nan,
warn_nan=warn_nan,
**map_kwargs,
)
for img_slice in img
]
)
if img.shape != (self._ny, self._nx):
raise ValueError(
f'The input `img` shape {img.shape!r} is inconsistent with '
f'the body\'s image size (ny={self._ny}, nx={self._nx})'
)

x_map = self.get_x_map(**map_kwargs)
y_map = self.get_y_map(**map_kwargs)
projected = self._make_empty_map(**map_kwargs)

spline_k = {
'linear': 1,
'quadratic': 2,
'cubic': 3,
}
spline_k = {'linear': 1, 'quadratic': 2, 'cubic': 3}
if interpolation in spline_k: # pylint: disable=consider-using-get
interpolation = spline_k[interpolation]

if interpolation == 'nearest':
nan_sentinel = -999
nan_sentinel = -999 # x_map and y_map are always >= 0
x_map = np.asarray(
np.nan_to_num(np.round(x_map), nan=nan_sentinel), dtype=int
)
Expand All @@ -1288,48 +1318,49 @@ def map_img(
else:
kx, ky = interpolation
nans = np.isnan(img)
if np.any(np.isnan(img)):
if warn_nan:
print('Warning, image contains NaN values which will be set to 0')
img = np.nan_to_num(img)
interpolator = scipy.interpolate.RectBivariateSpline(
np.arange(img.shape[0]),
np.arange(img.shape[1]),
img,
kx=kx,
ky=ky,
s=spline_smoothing, # type: ignore (docs say s is a float)
)
if not np.all(nans):
img = self._replace_nans_with_interpolated_values(img, warn_nan)
interpolator = scipy.interpolate.RectBivariateSpline(
np.arange(img.shape[0]),
np.arange(img.shape[1]),
img,
kx=kx,
ky=ky,
s=spline_smoothing, # type: ignore (docs say s is a float)
)

# Collect any coordinates to interpolate in these lists, then perform the
# interpolation at the end with a single call to interpolator.ev. This is
# directly equivalent to doing the interpolation inside the for loop with
# `projected[a, b] = interpolator(y, x).item()`, but can be much faster for
# large images.
a_vals: list[int] = []
b_vals: list[int] = []
x_vals: list[float] = []
y_vals: list[float] = []
for a, b in self._iterate_image(projected.shape):
x = x_map[a, b]
if math.isnan(x):
continue
y = y_map[a, b] # y should never be nan when x is not nan
if propagate_nan and self._should_propagate_nan_to_map(x, y, nans):
continue
a_vals.append(a)
b_vals.append(b)
x_vals.append(x)
y_vals.append(y)
projected[a_vals, b_vals] = interpolator.ev(y_vals, x_vals)
# Collect any coordinates to interpolate in these lists, then perform
# the interpolation at the end with a single call to interpolator.ev.
# This is directly equivalent to doing the interpolation inside the for
# loop with `projected[a, b] = interpolator(y, x).item()`, but can be
# much faster for large images.
a_vals: list[int] = []
b_vals: list[int] = []
x_vals: list[float] = []
y_vals: list[float] = []
for a, b in self._iterate_image(projected.shape):
x = x_map[a, b]
if math.isnan(x):
continue
y = y_map[a, b] # y should never be nan when x is not nan
if propagate_nan and self._should_propagate_nan_to_map(x, y, nans):
continue
a_vals.append(a)
b_vals.append(b)
x_vals.append(x)
y_vals.append(y)
projected[a_vals, b_vals] = interpolator.ev(y_vals, x_vals)
else:
raise ValueError(f'Unknown interpolation method {interpolation!r}')
return projected

def _should_propagate_nan_to_map(
self, x: float, y: float, nans: np.ndarray
) -> bool:
# Test if any of the four surrounding integer pixels in the image are NaN
# Test if any of the four surrounding integer pixels in the image are NaN or if
# outside the convex hull of pixel centres
if x < 0.0 or y < 0.0 or x > self._nx - 1 or y > self._ny - 1:
return True
x0 = max(math.floor(x), 0)
x1 = min(math.ceil(x), self._nx - 1)
y0 = max(math.floor(y), 0)
Expand All @@ -1339,6 +1370,42 @@ def _should_propagate_nan_to_map(
def _xy_in_image_frame(self, x: float, y: float) -> bool:
return (-0.5 < x < self._nx - 0.5) and (-0.5 < y < self._ny - 0.5)

def _replace_nans_with_interpolated_values(
self, img: np.ndarray, warn_nan: bool
) -> np.ndarray:
"""
Return a copy of the input image where NaNs are replaced with the mean of
surrounding non-NaN pixels (3x3 footprint). All other NaNs are replaced with the
median of the original data.
This is mainly useful for preparing an input image before passing it through
e.g. RectBivariateSpline (which doesn't accept NaNs, and may produce artefacts
if we just use nan_to_num).
"""
bad = ~np.isfinite(img)
if warn_nan and np.any(bad):
print('Warning, image contains NaN values which will be corrected')
cleaned = img.astype(float, copy=True)
if np.any(np.isinf(img)):
# Treat inf as nan in averaging calculations
img = np.nan_to_num(
img, nan=np.nan, posinf=np.nan, neginf=np.nan, copy=True
)
if np.all(bad):
median = 0.0
else:
median = np.nanmedian(img)
cleaned[bad] = median
# Fix bad pixels that have neighbouring good pixels by replacing them with the
# mean of the surrounding 3x3 good pixels.
# pylint: disable-next=invalid-unary-operand-type
to_fix = bad & ~scipy.ndimage.uniform_filter(bad, size=3) #  type: ignore
for i, j in np.argwhere(to_fix):
cleaned[i, j] = np.nanmean(
img[max(i - 1, 0) : i + 2, max(j - 1, 0) : j + 2]
)
return cleaned

# Plotting
def plot_wireframe_xy(
self,
Expand Down
Loading

0 comments on commit c36b8e9

Please sign in to comment.