Skip to content

Commit

Permalink
Improve NaN replacement when mapping
Browse files Browse the repository at this point in the history
Non-finite values are replaced with the average of neighbouring pixels before mapping, rather than 0. This should reduce artefacts when using spline interpolation.

#382
  • Loading branch information
ortk95 committed Jul 16, 2024
1 parent de115d6 commit f7d5529
Showing 1 changed file with 29 additions and 5 deletions.
34 changes: 29 additions & 5 deletions planetmapper/body_xy.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import numpy as np
import pyproj
import scipy.interpolate
import scipy.ndimage
from matplotlib.axes import Axes
from matplotlib.collections import QuadMesh
from matplotlib.figure import Figure
Expand Down Expand Up @@ -1269,7 +1270,7 @@ def map_img(
interpolation = spline_k[interpolation]

if interpolation == 'nearest':
nan_sentinel = -999
nan_sentinel = -999 # x_map and y_map are always >= 0
x_map = np.asarray(
np.nan_to_num(np.round(x_map), nan=nan_sentinel), dtype=int
)
Expand All @@ -1288,10 +1289,7 @@ def map_img(
else:
kx, ky = interpolation
nans = np.isnan(img)
if np.any(np.isnan(img)):
if warn_nan:
print('Warning, image contains NaN values which will be set to 0')
img = np.nan_to_num(img)
img = self._replace_nans_with_interpolated_values(img, warn_nan)
interpolator = scipy.interpolate.RectBivariateSpline(
np.arange(img.shape[0]),
np.arange(img.shape[1]),
Expand Down Expand Up @@ -1339,6 +1337,32 @@ def _should_propagate_nan_to_map(
def _xy_in_image_frame(self, x: float, y: float) -> bool:
return (-0.5 < x < self._nx - 0.5) and (-0.5 < y < self._ny - 0.5)

def _replace_nans_with_interpolated_values(self, img: np.ndarray, warn_nan:bool) -> np.ndarray:
"""
Return a copy of the input image where NaNs are replaced with the mean of
surrounding non-NaN pixels (3x3 footprint). All other NaNs are replaced with the
median of the original data.
This is mainly useful for preparing an input image before passing it through
e.g. RectBivariateSpline (which doesn't accept NaNs, and may produce artefacts
if we just use nan_to_num).
"""
# XXX test
nans = ~np.isfinite(img)
if warn_nan and np.any(nans):
print('Warning, image contains NaN values which will be corrected')
median = np.nanmedian(img)
if not np.isfinite(median):
median = 0
cleaned = img.copy()
cleaned[nans] = median
to_fix = nans & ~scipy.ndimage.uniform_filter(nans, size=3) # type: ignore
for i, j in np.argwhere(to_fix):
cleaned[i, j] = np.nanmean(
img[max(i - 1, 0) : i + 2, max(j - 1, 0) : j + 2]
)
return cleaned

# Plotting
def plot_wireframe_xy(
self,
Expand Down

0 comments on commit f7d5529

Please sign in to comment.