Skip to content

Commit

Permalink
Merge pull request #95 from rcjackson/vorticity_fix
Browse files Browse the repository at this point in the history
FIX: Unit testing for Jax/TensorFlow cost functions.
  • Loading branch information
rcjackson authored Aug 30, 2023
2 parents fb9e9f9 + 594011d commit c42663a
Show file tree
Hide file tree
Showing 12 changed files with 835 additions and 249 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/python-package-conda.yml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ jobs:
strategy:
fail-fast: false
matrix:
python-version: ["3.9", "3.10"]
python-version: ["3.9", "3.10", "3.11"]
os: [macOS, ubuntu]
inlcude:
- os: macos-latest
Expand Down
34 changes: 18 additions & 16 deletions examples/plot_fun_with_constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,37 +34,36 @@
berr_grid = pydda.initialization.make_constant_wind_field(
berr_grid, (0.0, 0.0, 0.0))

# Let's make a plot on a map
fig = plt.figure(figsize=(7, 3))

pydda.vis.plot_xz_xsection_streamlines(
[cpol_grid, berr_grid], bg_grid_no=-1, level=50, w_vel_contours=[1, 3, 5, 8])
plt.show()

# Let's provide an initial state from the sounding
u_back = sounding[1].u_wind
v_back = sounding[1].v_wind
z_back = sounding[1].height
cpol_grid = pydda.initialization.make_wind_field_from_profile(cpol_grid, sounding[1])

new_grids, _ = pydda.retrieval.get_dd_wind_field([cpol_grid, berr_grid],

new_grids, _ = pydda.retrieval.get_dd_wind_field([cpol_grid, berr_grid],
u_back=u_back, v_back=v_back, z_back=z_back,
Co=10.0, Cm=4096.0, frz=5000.0, Cb=1e-6,
mask_outside_opt=False, wind_tol=0.2,
Co=1.0, Cm=64.0, frz=5000.0, Cb=1e-5,
Cx=1e2, Cy=1e2, Cz=1e2,
mask_outside_opt=False, wind_tol=0.1,
engine="tensorflow")
fig = plt.figure(figsize=(7, 7))

pydda.vis.plot_xz_xsection_streamlines(
new_grids, bg_grid_no=-1, level=50, w_vel_contours=[1, 3, 5, 8])
plt.show()

# Let's see what happens when we use a zero initialization
# This causes there to be convergence in the cone of silence
# This is an artifact that we want to avoid!
# Prescribing winds inside the background through either a constraint
# Or through the initial state will help mitigate this issue.
cpol_grid = pydda.initialization.make_constant_wind_field(
berr_grid, (0.0, 0.0, 0.0))
cpol_grid, (0.0, 0.0, 0.0))
new_grids, _ = pydda.retrieval.get_dd_wind_field([cpol_grid, berr_grid],
u_back=u_back, v_back=v_back, z_back=z_back,
Co=1.0, Cm=128.0, frz=5000.0, Cb=1e-6,
mask_outside_opt=False, wind_tol=0.2,
Co=1.0, Cm=64.0, frz=5000.0, Cb=1e-5,
Cx=1e2, Cy=1e2, Cz=1e2,
mask_outside_opt=False, wind_tol=0.5,
engine="tensorflow")

fig = plt.figure(figsize=(7, 7))
Expand All @@ -74,9 +73,12 @@
plt.show()

# Or, let's make the radar data more important!
cpol_grid = pydda.initialization.make_wind_field_from_profile(cpol_grid, sounding[1])
new_grids, _ = pydda.retrieval.get_dd_wind_field([cpol_grid, berr_grid],
Co=100.0, Cm=128.0, frz=5000.0,
mask_outside_opt=False, wind_tol=0.2,
Co=10.0, Cm=64.0, frz=5000.0,
u_back=u_back, v_back=v_back, z_back=z_back, Cb=1e-5,
Cx=1e2, Cy=1e2, Cz=1e2,
mask_outside_opt=False, wind_tol=0.1,
engine="tensorflow")
fig = plt.figure(figsize=(7, 7))

Expand Down
307 changes: 203 additions & 104 deletions notebooks/PyDDA example notebook.ipynb

Large diffs are not rendered by default.

4 changes: 3 additions & 1 deletion pydda/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,11 @@
from . import tests
from . import constraints

__version__ = '1.3.1'
__version__ = '1.4.0'

print("Welcome to PyDDA %s" % __version__)
print("If you are using PyDDA in your publications, please cite:")
print("Jackson et al. (2020) Journal of Open Research Science")
print("Detecting Jax...")
try:
import jax
Expand Down
4 changes: 3 additions & 1 deletion pydda/cost_functions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,9 @@
calculate_point_gradient
"""


from . import _cost_functions_tensorflow as tf
from . import _cost_functions_numpy as np
from . import _cost_functions_jax as jax
from ._cost_functions_numpy import calculate_radial_vel_cost_function
from ._cost_functions_numpy import calculate_fall_speed
from ._cost_functions_numpy import calculate_grad_radial_vel
Expand Down
68 changes: 39 additions & 29 deletions pydda/cost_functions/_cost_functions_jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,17 +174,17 @@ def calculate_smoothness_cost(u, v, w, dx, dy, dz, Cx=1e-5, Cy=1e-5, Cz=1e-5):
dwdz = jnp.gradient(w, dz, axis=0)

x_term = Cx * (
jnp.gradient(dudx, dx, axis=2) ** 2 +
jnp.gradient(dvdx, dx, axis=1) ** 2 +
jnp.gradient(dwdx, dx, axis=2) ** 2)
jnp.gradient(dudx, dx, axis=2) +
jnp.gradient(dvdx, dx, axis=1) +
jnp.gradient(dwdx, dx, axis=2)) ** 2
y_term = Cy * (
jnp.gradient(dudy, dy, axis=2) ** 2 +
jnp.gradient(dvdy, dy, axis=1) ** 2 +
jnp.gradient(dwdy, dy, axis=2) ** 2)
jnp.gradient(dudy, dy, axis=2) +
jnp.gradient(dvdy, dy, axis=1) +
jnp.gradient(dwdy, dy, axis=2)) ** 2
z_term = Cz * (
jnp.gradient(dudz, dz, axis=2) ** 2 +
jnp.gradient(dvdz, dz, axis=1) ** 2 +
jnp.gradient(dwdz, dz, axis=2) ** 2)
jnp.gradient(dudz, dz, axis=2) +
jnp.gradient(dvdz, dz, axis=1) +
jnp.gradient(dwdz, dz, axis=2)) ** 2
return np.asanyarray(jnp.sum(x_term + y_term + z_term))


Expand Down Expand Up @@ -229,15 +229,15 @@ def calculate_smoothness_gradient(u, v, w, dx, dy, dz, Cx=1e-5, Cy=1e-5, Cz=1e-5
grad_u = np.zeros(w.shape)
grad_v = np.zeros(w.shape)
grad_w = np.zeros(w.shape)
scipy.ndimage.filters.laplace(u, du, mode='wrap')
scipy.ndimage.filters.laplace(v, dv, mode='wrap')
scipy.ndimage.filters.laplace(w, dw, mode='wrap')
scipy.ndimage.laplace(u, du, mode='wrap')
scipy.ndimage.laplace(v, dv, mode='wrap')
scipy.ndimage.laplace(w, dw, mode='wrap')
du = du / dx
dv = dv / dy
dw = dw / dz
scipy.ndimage.filters.laplace(du, grad_u, mode='wrap')
scipy.ndimage.filters.laplace(dv, grad_v, mode='wrap')
scipy.ndimage.filters.laplace(dw, grad_w, mode='wrap')
scipy.ndimage.laplace(du, grad_u, mode='wrap')
scipy.ndimage.laplace(dv, grad_v, mode='wrap')
scipy.ndimage.laplace(dw, grad_w, mode='wrap')

grad_u = grad_u / du
grad_v = grad_v / dy
Expand Down Expand Up @@ -289,10 +289,9 @@ def calculate_point_cost(u, v, x, y, z, point_list, Cp=1e-3, roi=500.0):
"""
J = 0.0
for the_point in point_list:
the_box = jnp.where(np.logical_and.reduce(
(jnp.abs(x - the_point["x"]) < roi,
jnp.abs(y - the_point["y"]) < roi,
jnp.abs(z - the_point["z"]) < roi)))
the_box = jnp.logical_and(
jnp.logical_and(jnp.abs(x - the_point["x"]) < roi,
jnp.abs(y - the_point["y"]) < roi), jnp.abs(z - the_point["z"]) < roi)
J += jnp.sum(
((u[the_box] - the_point["u"]) ** 2 +
(v[the_box] - the_point["v"]) ** 2))
Expand Down Expand Up @@ -336,13 +335,19 @@ def calculate_point_gradient(u, v, x, y, z, point_list, Cp=1e-3, roi=500.0):
"""

primals, fun_vjp = jax.vjp(
calculate_point_gradient, u, v, x, y, z, point_list, Cp, roi)
grad_u, grad_v, _, _, _, _, _, _, _ = fun_vjp(1.0)
gradJ_u = jnp.zeros_like(u)
gradJ_v = jnp.zeros_like(v)
gradJ_w = jnp.zeros_like(u)

for the_point in point_list:
the_box = jnp.where(jnp.logical_and(jnp.logical_and(
np.abs(x - the_point["x"]) < roi, np.abs(y - the_point["y"]) < roi),
np.abs(z - the_point["z"]) < roi), 1., 0.)
gradJ_u += 2 * (u - the_point["u"]) * the_box
gradJ_v += 2 * (v - the_point["v"]) * the_box

gradJ_w = jnp.zeros_like(grad_u)
gradJ = jnp.stack([grad_u, grad_v, gradJ_w], axis=0)
return np.copy(gradJ.flatten()) * Cp
gradJ = jnp.stack([gradJ_u, gradJ_v, gradJ_w], axis=0).flatten()
return gradJ * Cp


def calculate_mass_continuity(u, v, w, z, dx, dy, dz, coeff=1500.0, anel=1):
Expand Down Expand Up @@ -582,11 +587,12 @@ def calculate_vertical_vorticity_cost(u, v, w, dx, dy, dz, Ut, Vt,
jv_array = ((u - Ut) * dzeta_dx + (v - Vt) * dzeta_dy +
w * dzeta_dz + (dvdz * dwdx - dudz * dwdy) +
zeta * (dudx + dvdy))

return jnp.sum(coeff * jv_array ** 2)


def calculate_vertical_vorticity_gradient(u, v, w, dx, dy, dz, Ut, Vt,
coeff=1e-5):
coeff=1e-5, upper_bc=True):
"""
Calculates the gradient of the cost function due to deviance from vertical
vorticity equation. This is done by taking the functional derivative of
Expand Down Expand Up @@ -637,6 +643,10 @@ def calculate_vertical_vorticity_gradient(u, v, w, dx, dy, dz, Ut, Vt,
calculate_vertical_vorticity_cost, u, v, w, dx, dy,
dz, Ut, Vt, coeff)
u_grad, v_grad, w_grad, _, _, _, _, _, _ = fun_vjp(1.0)
# Impermeability condition
w_grad.at[0, :, :].set(0)
if(upper_bc is True):
w_grad.at[-1, :, :].set(0)
y = np.stack([u_grad, v_grad, w_grad], axis=0)
return y.flatten().copy()

Expand Down Expand Up @@ -715,10 +725,10 @@ def calculate_model_gradient(u, v, w, weights, u_model,
Returns
-------
y: float array
value of gradient of background cost function
value of gradient of model cost function
"""
primals, fun_vjp = jax.vjp(
calculate_model_cost, u, v, w, u_model, v_model, w_model, coeff)
u_grad, v_grad, w_grad, _, _, _, _, _, _ = fun_vjp(1.0)
calculate_model_cost, u, v, w, weights, u_model, v_model, w_model, coeff)
u_grad, v_grad, w_grad, _, _, _, _, _ = fun_vjp(1.0)
y = np.stack([u_grad, v_grad, w_grad], axis=0)
return y.flatten().copy()
20 changes: 12 additions & 8 deletions pydda/cost_functions/_cost_functions_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,15 +214,15 @@ def calculate_smoothness_gradient(u, v, w, dx, dy, dz, Cx=1e-5, Cy=1e-5, Cz=1e-5
grad_u = np.zeros(w.shape)
grad_v = np.zeros(w.shape)
grad_w = np.zeros(w.shape)
scipy.ndimage.filters.laplace(u, du, mode='wrap')
scipy.ndimage.filters.laplace(v, dv, mode='wrap')
scipy.ndimage.filters.laplace(w, dw, mode='wrap')
scipy.ndimage.laplace(u, du, mode='wrap')
scipy.ndimage.laplace(v, dv, mode='wrap')
scipy.ndimage.laplace(w, dw, mode='wrap')
du = du / dx
dv = dv / dy
dw = dw / dz
scipy.ndimage.filters.laplace(du, grad_u, mode='wrap')
scipy.ndimage.filters.laplace(dv, grad_v, mode='wrap')
scipy.ndimage.filters.laplace(dw, grad_w, mode='wrap')
scipy.ndimage.laplace(du, grad_u, mode='wrap')
scipy.ndimage.laplace(dv, grad_v, mode='wrap')
scipy.ndimage.laplace(dw, grad_w, mode='wrap')
grad_u = grad_u / dx
grad_v = grad_v / dy
grad_w = grad_w / dz
Expand Down Expand Up @@ -611,7 +611,7 @@ def calculate_vertical_vorticity_cost(u, v, w, dx, dy, dz, Ut, Vt,


def calculate_vertical_vorticity_gradient(u, v, w, dx, dy, dz, Ut, Vt,
coeff=1e-5):
coeff=1e-5, upper_bc=True):
"""
Calculates the gradient of the cost function due to deviance from vertical
vorticity equation. This is done by taking the functional derivative of
Expand Down Expand Up @@ -702,7 +702,11 @@ def calculate_vertical_vorticity_gradient(u, v, w, dx, dy, dz, Ut, Vt,
u_grad = u_grad * 2 * dzeta_dt * coeff
v_grad = v_grad * 2 * dzeta_dt * coeff
w_grad = w_grad * 2 * dzeta_dt * coeff


# Impermeability condition
w_grad[0, :, :] = 0
if(upper_bc is True):
w_grad[-1, :, :] = 0
y = np.stack([u_grad, v_grad, w_grad], axis=0)
return y.flatten()

Expand Down
Loading

0 comments on commit c42663a

Please sign in to comment.