Skip to content

Commit

Permalink
minor changes
Browse files Browse the repository at this point in the history
  • Loading branch information
Ig-dolci committed Sep 4, 2024
1 parent 803a60d commit 6d43c59
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 17 deletions.
16 changes: 8 additions & 8 deletions demos/with_automatic_differentiation/run_forward_ad.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,20 +66,20 @@
}


def make_vp_circle(vp_guess=False, plot_vp=False):
def make_c_camembert(c_guess=False, plot_c=False):
"""Acoustic velocity model"""
x, z = fire.SpatialCoordinate(mesh)
if vp_guess:
vp = fire.Function(V).interpolate(1.5 + 0.0 * x)
if c_guess:
c = fire.Function(V).interpolate(1.5 + 0.0 * x)
else:
vp = fire.Function(V).interpolate(
c = fire.Function(V).interpolate(
2.5
+ 1 * fire.tanh(100 * (0.125 - fire.sqrt((x - 0.5) ** 2 + (z - 0.5) ** 2)))
)
if plot_vp:
if plot_c:
outfile = fire.VTKFile("acoustic_cp.pvd")
outfile.write(vp)
return vp
outfile.write(c)
return c


# Use emsemble parallelism.
Expand All @@ -95,7 +95,7 @@ def make_vp_circle(vp_guess=False, plot_vp=False):

forward_solver = spyro.solvers.forward_ad.ForwardSolver(model, mesh, V)

c_true = make_vp_circle()
c_true = make_c_camembert()
# Ricker wavelet
wavelet = spyro.full_ricker_wavelet(
model["timeaxis"]["dt"], model["timeaxis"]["tf"],
Expand Down
18 changes: 9 additions & 9 deletions test/test_gradient_ad.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,20 +63,20 @@
}


def make_vp_circle(V, mesh, vp_guess=False, plot_vp=False):
def make_c_camembert(V, mesh, c_guess=False, plot_c=False):
"""Acoustic velocity model"""
x, z = fire.SpatialCoordinate(mesh)
if vp_guess:
vp = fire.Function(V).interpolate(1.5 + 0.0 * x)
if c_guess:
c = fire.Function(V).interpolate(1.5 + 0.0 * x)
else:
vp = fire.Function(V).interpolate(
c = fire.Function(V).interpolate(
2.5
+ 1 * fire.tanh(100 * (0.125 - fire.sqrt((x - 0.5) ** 2 + (z - 0.5) ** 2)))
)
if plot_vp:
if plot_c:
outfile = fire.VTKFile("acoustic_cp.pvd")
outfile.write(vp)
return vp
outfile.write(c)
return c


def forward(
Expand Down Expand Up @@ -112,11 +112,11 @@ def test_taylor():
model["timeaxis"]["dt"], model["timeaxis"]["tf"],
model["acquisition"]["frequency"],
)
c_true = make_vp_circle(V, mesh)
c_true = make_c_camembert(V, mesh)
true_rec, _ = forward(c_true, fwd_solver, wavelet, my_ensemble)

# --- Gradient with AD --- #
c_guess = make_vp_circle(V, mesh, vp_guess=True)
c_guess = make_c_camembert(V, mesh, c_guess=True)
_, J = forward(
c_guess, fwd_solver, wavelet, my_ensemble,
compute_functional=True,
Expand Down

0 comments on commit 6d43c59

Please sign in to comment.