Skip to content

Commit

Permalink
Fix backward gradients in EigenmodeCoefficient and update JAX adjoint test to check gradients in both forward and backward directions
Browse files Browse the repository at this point in the history
  • Loading branch information
ianwilliamson committed Sep 28, 2021
1 parent 553e6e7 commit 27ec197
Show file tree
Hide file tree
Showing 3 changed files with 89 additions and 70 deletions.
87 changes: 48 additions & 39 deletions python/adjoint/objective.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,10 +92,6 @@ def _adj_src_scale(self, include_resolution=True):
else:
# multi frequency simulations
scale = dV * iomega / adj_src_phase
# compensate for the fact that real fields take the real part of the current,
# which halves the Fourier amplitude at the positive frequency (Re[J] = (J + J*)/2)
if self.sim.using_real_fields():
scale *= 2
return scale

def _create_time_profile(self, fwidth_frac=0.1):
Expand All @@ -114,56 +110,67 @@ def _create_time_profile(self, fwidth_frac=0.1):
)


_KPOINT_POS_DIR = mp.Vector3(0.0, 0.0, 0.0)
_KPOINT_NEG_DIR = mp.Vector3(-0.0, -0.0, -0.0)


class EigenmodeCoefficient(ObjectiveQuantity):
"""A frequency-dependent eigenmode coefficient.
Attributes:
volume: the volume over which the eigenmode coefficient is calculated.
mode: the eigenmode number.
forward: whether the forward or backward mode coefficient is returned as
the result of the evaluation.
kpoint_func: an optional k-point function to use when evaluating the eigenmode
coefficient. When specified, this overrides the effect of `forward`.
kpoint_func_overlap_idx: the index of the mode coefficient to return when
specifying `kpoint_func`. When specified, this overrides the effect of
`forward` and should have a value of either 0 or 1.
"""
def __init__(self,
sim,
volume,
mode,
forward=True,
kpoint_func=None,
decimation_factor=0,
kpoint_func_overlap_idx=0,
**kwargs):
super().__init__(sim)
if kpoint_func_overlap_idx not in [0, 1]:
raise ValueError(
'`kpoint_func_overlap_idx` should be either 0 or 1, but got %d'
% (kpoint_func_overlap_idx, ))
self.volume = volume
self.mode = mode
self.forward = forward
self.kpoint_func = kpoint_func
self.kpoint_func_overlap_idx = kpoint_func_overlap_idx
self.eigenmode_kwargs = kwargs
self._monitor = None
self._normal_direction = None
self._cscale = None
self.decimation_factor = decimation_factor

def register_monitors(self, frequencies):
self._frequencies = np.asarray(frequencies)
self._monitor = self.sim.add_mode_monitor(
frequencies,
mp.ModeRegion(center=self.volume.center, size=self.volume.size),
yee_grid=True,
decimation_factor=self.decimation_factor,
)
self._normal_direction = self._monitor.normal_direction
return self._monitor

def place_adjoint_source(self, dJ):
dJ = np.atleast_1d(dJ)
direction_scalar = -1 if self.forward else 1
time_src = self._create_time_profile()
if self.kpoint_func is None:
if self._normal_direction == 0:
k0 = direction_scalar * mp.Vector3(x=1)
elif self._normal_direction == 1:
k0 = direction_scalar * mp.Vector3(y=1)
elif self._normal_direction == 2:
k0 = direction_scalar * mp.Vector3(z=1)
else:
k0 = direction_scalar * self.kpoint_func(time_src.frequency, 1)
if dJ.ndim == 2:
dJ = np.sum(dJ, axis=1)
da_dE = 0.5 * self._cscale # scalar popping out of derivative

time_src = self._create_time_profile()
da_dE = 0.5 * self._cscale
scale = self._adj_src_scale()

if self.kpoint_func:
eig_kpoint = -1 * self.kpoint_func(time_src.frequency, self.mode)
else:
eig_kpoint = _KPOINT_NEG_DIR if self.forward else _KPOINT_POS_DIR

if self._frequencies.size == 1:
amp = da_dE * dJ * scale
src = time_src
Expand All @@ -176,12 +183,11 @@ def place_adjoint_source(self, dJ):
self.sim.fields.dt,
)
amp = 1

source = mp.EigenModeSource(
src,
eig_band=self.mode,
direction=mp.NO_DIRECTION,
eig_kpoint=k0,
direction=mp.AUTOMATIC,
eig_kpoint=eig_kpoint,
amplitude=amp,
eig_match_freq=True,
size=self.volume.size,
Expand All @@ -191,26 +197,31 @@ def place_adjoint_source(self, dJ):
return [source]

def __call__(self):
direction = mp.NO_DIRECTION if self.kpoint_func else mp.AUTOMATIC
if self.kpoint_func:
kpoint_func = self.kpoint_func
overlap_idx = self.kpoint_func_overlap_idx
else:
kpoint_func = lambda *not_used: _KPOINT_POS_DIR if self.forward else _KPOINT_NEG_DIR
overlap_idx = 0
ob = self.sim.get_eigenmode_coefficients(
self._monitor,
[self.mode],
direction=direction,
kpoint_func=self.kpoint_func,
direction=mp.AUTOMATIC,
kpoint_func=kpoint_func,
**self.eigenmode_kwargs,
)
# record eigenmode coefficients for scaling
self._eval = np.squeeze(ob.alpha[:, :, int(not self.forward)])
self._cscale = ob.cscale # pull scaling factor
overlaps = ob.alpha.squeeze(axis=0)
assert overlaps.ndim == 2
self._eval = overlaps[:, overlap_idx]
self._cscale = ob.cscale
return self._eval


class FourierFields(ObjectiveQuantity):
def __init__(self, sim, volume, component, decimation_factor=0):
def __init__(self, sim, volume, component):
super().__init__(sim)
self.volume = volume
self.component = component
self.decimation_factor = decimation_factor

def register_monitors(self, frequencies):
self._frequencies = np.asarray(frequencies)
Expand All @@ -219,7 +230,6 @@ def register_monitors(self, frequencies):
self._frequencies,
where=self.volume,
yee_grid=False,
decimation_factor=self.decimation_factor,
)
return self._monitor

Expand Down Expand Up @@ -271,7 +281,7 @@ def place_adjoint_source(self, dJ):
for yi in range(y_dim):
for xi in range(x_dim):
'''We only need to add a current source if the
jacobian is nonzero for all frequencies at
jacobian is nonzero for all frequencies at
that particular point. Otherwise, the fitting
algorithm is going to fail.
'''
Expand Down Expand Up @@ -306,19 +316,18 @@ def __call__(self):


class Near2FarFields(ObjectiveQuantity):
def __init__(self, sim, Near2FarRegions, far_pts, decimation_factor=0):
def __init__(self, sim, Near2FarRegions, far_pts):
super().__init__(sim)
self.Near2FarRegions = Near2FarRegions
self.far_pts = far_pts #list of far pts
self._nfar_pts = len(far_pts)
self.decimation_factor = decimation_factor

def register_monitors(self, frequencies):
self._frequencies = np.asarray(frequencies)
self._monitor = self.sim.add_near2far(
self._frequencies,
*self.Near2FarRegions,
decimation_factor=self.decimation_factor,
yee_grid=True,
)
return self._monitor

Expand Down Expand Up @@ -350,7 +359,7 @@ def place_adjoint_source(self, dJ):
time_src.frequency,
self._frequencies,
scale,
self.sim.fields.dt,
dt,
)
(num_basis, num_pts) = src.nodes.shape
for basis_i in range(num_basis):
Expand Down
68 changes: 38 additions & 30 deletions python/tests/test_adjoint_jax.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import unittest
import parameterized

from utils import ApproxComparisonTestCase
from utils import VectorComparisonMixin

import jax
import jax.numpy as jnp
Expand All @@ -16,7 +16,7 @@
_FD_STEP = 1e-4

# The tolerance for the adjoint and finite difference gradient comparison
_TOL = 0.1 if mp.is_single_precision() else 0.02
_TOL = 2e-2

mp.verbosity(0)

Expand Down Expand Up @@ -57,7 +57,15 @@ def build_straight_wg_simulation(
eig_kpoint=mp.Vector3(1, 0, 0),
size=mp.Vector3(0, wg_width + 2 * wg_padding, 0),
center=[-sx / 2 + pml_width + source_to_pml, 0, 0],
)
),
mp.EigenModeSource(
mp.GaussianSource(frequency=fmean, fwidth=fmean * gaussian_rel_width),
eig_band=1,
direction=mp.NO_DIRECTION,
eig_kpoint=mp.Vector3(-1, 0, 0),
size=mp.Vector3(0, wg_width + 2 * wg_padding, 0),
center=[sx / 2 - pml_width - source_to_pml, 0, 0],
),
]

nx = int(design_region_resolution * design_region_shape[0])
Expand Down Expand Up @@ -117,8 +125,7 @@ def build_straight_wg_simulation(
simulation,
mp.Volume(center=center, size=monitor_size),
mode=1,
forward=True,
decimation_factor=1) for center in monitor_centers
forward=forward) for center in monitor_centers for forward in [True, False]
]
return simulation, sources, monitors, design_regions, frequencies

Expand Down Expand Up @@ -167,14 +174,17 @@ def test_design_region_monitor_helpers(self):
self.assertEqual(value.dtype, onp.complex128)


class WrapperTest(ApproxComparisonTestCase):
class WrapperTest(VectorComparisonMixin, unittest.TestCase):

@parameterized.parameterized.expand([
('1500_1550bw_01relative_gaussian', onp.linspace(1 / 1.50, 1 / 1.55, 3).tolist(), 0.1, 1.0),
('1550_1600bw_02relative_gaussian', onp.linspace(1 / 1.55, 1 / 1.60, 3).tolist(), 0.2, 1.0),
('1500_1600bw_03relative_gaussian', onp.linspace(1 / 1.50, 1 / 1.60, 4).tolist(), 0.3, 1.0),
('1500_1550bw_01relative_gaussian_port1', onp.linspace(1 / 1.50, 1 / 1.55, 3).tolist(), 0.1, 1.0, 0),
# ('1550_1600bw_02relative_gaussian_port1', onp.linspace(1 / 1.55, 1 / 1.60, 3).tolist(), 0.2, 1.0, 0),
# ('1500_1600bw_03relative_gaussian_port1', onp.linspace(1 / 1.50, 1 / 1.60, 4).tolist(), 0.3, 1.0, 0),
('1500_1550bw_01relative_gaussian_port2', onp.linspace(1 / 1.50, 1 / 1.55, 3).tolist(), 0.1, 1.0, 1),
# ('1550_1600bw_02relative_gaussian_port2', onp.linspace(1 / 1.55, 1 / 1.60, 3).tolist(), 0.2, 1.0, 1),
# ('1500_1600bw_03relative_gaussian_port2', onp.linspace(1 / 1.50, 1 / 1.60, 4).tolist(), 0.3, 1.0, 1),
])
def test_wrapper_gradients(self, _, frequencies, gaussian_rel_width, design_variable_fill_value):
def test_wrapper_gradients(self, _, frequencies, gaussian_rel_width, design_variable_fill_value, excite_port_idx):
"""Tests gradient from the JAX-Meep wrapper against finite differences."""
(
simulation,
Expand All @@ -184,31 +194,29 @@ def test_wrapper_gradients(self, _, frequencies, gaussian_rel_width, design_vari
frequencies,
) = build_straight_wg_simulation(frequencies=frequencies, gaussian_rel_width=gaussian_rel_width)

wrapped_meep = mpa.MeepJaxWrapper(
simulation,
sources,
monitors,
design_regions,
frequencies,
measurement_interval=50.0,
dft_field_components=(mp.Ez,),
dft_threshold=1e-6,
minimum_run_time=0,
maximum_run_time=onp.inf,
until_after_sources=True
)

design_shape = tuple(int(i) for i in design_regions[0].design_parameters.grid_size)[:2]
x = onp.ones(design_shape) * design_variable_fill_value

# Define a loss function
def loss_fn(x):
def loss_fn(x, excite_port_idx=0):
wrapped_meep = mpa.MeepJaxWrapper(
simulation,
[sources[excite_port_idx]],
monitors,
design_regions,
frequencies,
)
monitor_values = wrapped_meep([x])
t = monitor_values[1, :] / monitor_values[0, :]
s1p, s1m, s2m, s2p = monitor_values
if excite_port_idx == 0:
t = s2m / s1p
else:
t = s1m / s2p
# Mean transmission vs wavelength
return jnp.mean(jnp.square(jnp.abs(t)))
t_mean = jnp.mean(jnp.square(jnp.abs(t)))
return t_mean

value, adjoint_grad = jax.value_and_grad(loss_fn)(x)
value, adjoint_grad = jax.value_and_grad(loss_fn)(x, excite_port_idx=excite_port_idx)

projection = []
fd_projection = []
Expand All @@ -225,7 +233,7 @@ def loss_fn(x):
x_perturbed = x + random_perturbation_vector

# Calculate T(p + dp)
value_perturbed = loss_fn(x_perturbed)
value_perturbed = loss_fn(x_perturbed, excite_port_idx=excite_port_idx)

projection.append(
onp.dot(
Expand All @@ -238,7 +246,7 @@ def loss_fn(x):
fd_projection = onp.stack(fd_projection)

# Check that dp . ∇T ~ T(p + dp) - T(p)
self.assertClose(
self.assertVectorsClose(
projection,
fd_projection,
epsilon=_TOL,
Expand Down
4 changes: 3 additions & 1 deletion src/mpb.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -484,7 +484,7 @@ void *fields::get_eigenmode(double frequency, direction d, const volume where, c
// which we automatically pick if kmatch == 0.
if (match_frequency && kmatch == 0) {
vec cen = eig_vol.center();
kmatch = frequency * sqrt(real(get_eps(cen, frequency)) * real(get_mu(cen, frequency)));
kmatch = copysign(frequency * sqrt(real(get_eps(cen, frequency)) * real(get_mu(cen, frequency))), _kpoint.in_direction(d));
if (d == NO_DIRECTION) {
for (int i = 0; i < 3; ++i)
k[i] = dot_product(R[i], kdir) * kmatch; // kdir*kmatch in reciprocal basis
Expand Down Expand Up @@ -831,6 +831,8 @@ void fields::add_eigenmode_source(component c0, const src_time &src, direction d
// electric current K = nHat \times H */
// magnetic current N = -nHat \times E */
/*--------------------------------------------------------------*/
if (global_eigenmode_data->group_velocity < 0)
amp *= -1; // equivalent to flipping the direction of nhat.
if (is_D(c0)) c0 = direction_component(Ex, component_direction(c0));
if (is_B(c0)) c0 = direction_component(Hx, component_direction(c0));
component cE[3] = {Ex, Ey, Ez}, cH[3] = {Hx, Hy, Hz};
Expand Down

0 comments on commit 27ec197

Please sign in to comment.