Skip to content

Commit

Permalink
Merge pull request #344 from xylar/improve_interp
Browse files Browse the repository at this point in the history
Use an internal bilinear interpolation instead of scipy for meshDensity and bathymetry

The `scipy` interpolators are quite slow and inefficient.  For bilinear interpolation on a regular grid, it is much more efficient to use `numpy.interp` in 1D to set up coefficients for doing 2D, bilinear interpolation.  A new function `mpas_tools.mesh.interpolation.interp_bilin()` is added to do this.
  • Loading branch information
xylar authored Sep 4, 2020
2 parents 345692b + eae5e28 commit bfd3f9f
Show file tree
Hide file tree
Showing 3 changed files with 86 additions and 35 deletions.
70 changes: 70 additions & 0 deletions conda_package/mpas_tools/mesh/interpolation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
import numpy as np


def interp_bilin(x, y, field, xCell, yCell):
"""
Perform bilinear interpolation of ``field`` on a tensor grid to cell centers
on an MPAS mesh. ``xCell`` and ``yCell`` must be bounded by ``x`` and ``y``,
respectively.
If x and y coordinates are longitude and latitude, respectively, it is
recommended that they be passed in degrees to avoid round-off problems at
the north and south poles and at the date line.
Parameters
----------
x : ndarray
x coordinate of the input field (length n)
y : ndarray
y coordinate fo the input field (length m)
field : ndarray
a field of size m x n
xCell : ndarray
x coordinate of MPAS cell centers
yCell : ndarray
y coordinate of MPAS cell centers
Returns
-------
mpasField : ndarray
``field`` interpoyed to MPAS cell centers
"""

assert np.all(xCell >= x[0])
assert np.all(xCell <= x[-1])
assert np.all(yCell >= y[0])
assert np.all(yCell <= y[-1])

# find float indices into the x and y arrays of cells on the MPAS mesh
xFrac = np.interp(xCell, x, np.arange(len(x)))
yFrac = np.interp(yCell, y, np.arange(len(y)))

# xIndices/yIndices are the integer indices of the lower bound for bilinear
# interpoyion; xFrac/yFrac are the fraction of the way ot the next index
xIndices = np.array(xFrac, dtype=int)
xFrac -= xIndices
yIndices = np.array(yFrac, dtype=int)
yFrac -= yIndices

# If points are exactly at the upper index, this is going to give us a bit
# of trouble so we'll move them down one index and adjust the fraction
# accordingly
mask = xIndices == len(x)
xIndices[mask] -= 1
xFrac[mask] += 1.

mask = yIndices == len(y)
yIndices[mask] -= 1
yFrac[mask] += 1.

mpasField = \
(1. - xFrac) * (1. - yFrac) * field[yIndices, xIndices] + \
xFrac * (1. - yFrac) * field[yIndices, xIndices + 1] + \
(1. - xFrac) * yFrac * field[yIndices + 1, xIndices] + \
xFrac * yFrac * field[yIndices + 1, xIndices + 1]

return mpasField
8 changes: 2 additions & 6 deletions conda_package/mpas_tools/ocean/inject_bathymetry.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
unicode_literals

from mpas_tools.mesh.creation.open_msh import readmsh
from mpas_tools.mesh.interpolation import interp_bilin
import numpy as np
from scipy import interpolate
import netCDF4 as nc4
Expand Down Expand Up @@ -84,13 +85,8 @@ def interpolate_SRTM(lon_pts, lat_pts):
idx = np.intersect1d(lon_idx, lat_idx)
xpts = lon_pts[idx]
ypts = lat_pts[idx]
xy_pts = np.vstack((xpts, ypts)).T

# Interpolate bathymetry onto points
bathy = interpolate.RegularGridInterpolator(
(xdata, ydata), zdata.T, bounds_error=False, fill_value=np.nan)
bathy_int = bathy(xy_pts)
bathymetry[idx] = bathy_int
bathymetry[idx] = interp_bilin(xdata, ydata, zdata, xpts, ypts)

end = timeit.default_timer()
print(end - start, " seconds")
Expand Down
43 changes: 14 additions & 29 deletions conda_package/mpas_tools/ocean/inject_meshDensity.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,11 @@
unicode_literals

import numpy as np
from scipy import interpolate
import netCDF4 as nc4
import sys

from mpas_tools.mesh.interpolation import interp_bilin


def inject_meshDensity_from_file(cw_filename, mesh_filename, on_sphere=True):
"""
Expand Down Expand Up @@ -74,37 +75,26 @@ def inject_spherical_meshDensity(cellWidth, lon, lat, mesh_filename):
The mesh file to add ``meshDensity`` to
"""

# Add extra column in longitude to interpolate over the Date Line
cellWidth = np.concatenate(
(cellWidth, cellWidth[:, 0:1]), axis=1)
LonPos = np.deg2rad(np.concatenate(
(lon.T, lon.T[0:1] + 360)))
LatPos = np.deg2rad(lat.T)
# set max lat position to be exactly at North Pole to avoid interpolation
# errors
LatPos[np.argmax(LatPos)] = np.pi / 2.0
minCellWidth = cellWidth.min()
meshDensityVsXY = (minCellWidth / cellWidth)**4
print(' minimum cell width in grid definition: {0:.0f} km'.format(
minCellWidth))
print(' maximum cell width in grid definition: {0:.0f} km'.format(
cellWidth.max()))

X, Y = np.meshgrid(LonPos, LatPos)

print('Open unstructured MPAS mesh file...')
ds = nc4.Dataset(mesh_filename, 'r+')
meshDensity = ds.variables['meshDensity']
lonCell = ds.variables['lonCell'][:]
latCell = ds.variables['latCell'][:]

print('Preparing interpolation of meshDensity from native coordinates to mesh...')
meshDensityInterp = interpolate.LinearNDInterpolator(
np.vstack((X.ravel(), Y.ravel())).T, meshDensityVsXY.ravel())
lonCell = np.mod(np.rad2deg(lonCell) + 180., 360.) - 180.
latCell = np.rad2deg(latCell)

print('Interpolating and writing meshDensity...')
meshDensity[:] = meshDensityInterp(
np.vstack((np.mod(ds.variables['lonCell'][:] + np.pi,
2*np.pi) - np.pi,
ds.variables['latCell'][:])).T)
mpasMeshDensity = interp_bilin(lon, lat, meshDensityVsXY, lonCell, latCell)

meshDensity[:] = mpasMeshDensity

ds.close()

Expand All @@ -131,21 +121,16 @@ def inject_planar_meshDensity(cellWidth, x, y, mesh_filename):
print(' minimum cell width in grid definition: {0:.0f} km'.format(minCellWidth))
print(' maximum cell width in grid definition: {0:.0f} km'.format(cellWidth.max()))

X, Y = np.meshgrid(x, y)

print('Open unstructured MPAS mesh file...')
ds = nc4.Dataset(mesh_filename, 'r+')
meshDensity = ds.variables['meshDensity']

print('Preparing interpolation of meshDensity from native coordinates to mesh...')
meshDensityInterp = interpolate.LinearNDInterpolator(
np.vstack((X.ravel(), Y.ravel())).T, meshDensityVsXY.ravel())
xCell = ds.variables['xCell'][:]
yCell = ds.variables['xCell'][:]

print('Interpolating and writing meshDensity...')
meshDensity[:] = meshDensityInterp(
np.vstack(
(ds.variables['xCell'][:],
ds.variables['yCell'][:])).T)
mpasMeshDensity = interp_bilin(x, y, meshDensityVsXY, xCell, yCell)

meshDensity[:] = mpasMeshDensity

ds.close()

Expand Down

0 comments on commit bfd3f9f

Please sign in to comment.