Skip to content

Commit

Permalink
Fix backward gradients in EigenmodeCoefficient and update JAX adjoint…
Browse files Browse the repository at this point in the history
… test to check gradients in both forward and backward directions
  • Loading branch information
ianwilliamson committed Jun 25, 2021
1 parent bcb90f3 commit 08d3e7f
Show file tree
Hide file tree
Showing 3 changed files with 80 additions and 50 deletions.
67 changes: 43 additions & 24 deletions python/adjoint/objective.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,22 +110,43 @@ def _create_time_profile(self, fwidth_frac=0.1):
)


# Sentinel k-points used to select a propagation direction without fixing a
# k-vector. Both vectors are numerically zero; they differ only in the sign
# of each zero component (+0.0 versus -0.0).
#
# When no user `kpoint_func` is given, the coefficient evaluation passes the
# positive-zero vector for the forward mode and the negative-zero vector for
# the backward mode, while the adjoint source uses the opposite one.
#
# NOTE(review): this appears to rely on the mode solver taking the sign of a
# zero k-point component (copysign-style) when it picks the frequency-matched
# k -- confirm that the sign of -0.0 survives the mp.Vector3 round trip.
_KPOINT_POS_DIR = mp.Vector3(0.0, 0.0, 0.0)
_KPOINT_NEG_DIR = mp.Vector3(-0.0, -0.0, -0.0)


class EigenmodeCoefficient(ObjectiveQuantity):
"""A frequency-dependent eigenmode coefficient.
Attributes:
volume: the volume over which the eigenmode coefficient is calculated.
mode: the eigenmode number.
forward: whether the forward or backward mode coefficient is returned as
the result of the evaluation.
kpoint_func: an optional k-point function to use when evaluating the eigenmode
coefficient. When specified, this overrides the effect of `forward`.
kpoint_func_overlap_idx: the index of the mode coefficient to return when
specifying `kpoint_func`. When specified, this overrides the effect of
`forward` and should have a value of either 0 or 1.
"""
def __init__(self,
sim,
volume,
mode,
forward=True,
kpoint_func=None,
kpoint_func_overlap_idx=0,
**kwargs):
super().__init__(sim)
if kpoint_func_overlap_idx not in [0, 1]:
raise ValueError(
'`kpoint_func_overlap_idx` should be either 0 or 1, but got %d'
% (kpoint_func_overlap_idx, ))
self.volume = volume
self.mode = mode
self.forward = forward
self.kpoint_func = kpoint_func
self.kpoint_func_overlap_idx = kpoint_func_overlap_idx
self.eigenmode_kwargs = kwargs
self._monitor = None
self._normal_direction = None
self._cscale = None

def register_monitors(self, frequencies):
Expand All @@ -135,28 +156,21 @@ def register_monitors(self, frequencies):
mp.ModeRegion(center=self.volume.center, size=self.volume.size),
yee_grid=True,
)
self._normal_direction = self._monitor.normal_direction
return self._monitor

def place_adjoint_source(self, dJ):
dJ = np.atleast_1d(dJ)
direction_scalar = -1 if self.forward else 1
time_src = self._create_time_profile()
if self.kpoint_func is None:
if self._normal_direction == 0:
k0 = direction_scalar * mp.Vector3(x=1)
elif self._normal_direction == 1:
k0 = direction_scalar * mp.Vector3(y=1)
elif self._normal_direction == 2:
k0 = direction_scalar * mp.Vector3(z=1)
else:
k0 = direction_scalar * self.kpoint_func(time_src.frequency, 1)
if dJ.ndim == 2:
dJ = np.sum(dJ, axis=1)
da_dE = 0.5 * self._cscale # scalar popping out of derivative

time_src = self._create_time_profile()
da_dE = 0.5 * self._cscale
scale = self._adj_src_scale()

if self.kpoint_func:
eig_kpoint = -1 * self.kpoint_func(time_src.frequency, self.mode)
else:
eig_kpoint = _KPOINT_NEG_DIR if self.forward else _KPOINT_POS_DIR

if self._frequencies.size == 1:
amp = da_dE * dJ * scale
src = time_src
Expand All @@ -169,12 +183,11 @@ def place_adjoint_source(self, dJ):
self.sim.fields.dt,
)
amp = 1

source = mp.EigenModeSource(
src,
eig_band=self.mode,
direction=mp.NO_DIRECTION,
eig_kpoint=k0,
direction=mp.AUTOMATIC,
eig_kpoint=eig_kpoint,
amplitude=amp,
eig_match_freq=True,
size=self.volume.size,
Expand All @@ -184,17 +197,23 @@ def place_adjoint_source(self, dJ):
return [source]

def __call__(self):
direction = mp.NO_DIRECTION if self.kpoint_func else mp.AUTOMATIC
if self.kpoint_func:
kpoint_func = self.kpoint_func
overlap_idx = self.kpoint_func_overlap_idx
else:
kpoint_func = lambda *not_used: _KPOINT_POS_DIR if self.forward else _KPOINT_NEG_DIR
overlap_idx = 0
ob = self.sim.get_eigenmode_coefficients(
self._monitor,
[self.mode],
direction=direction,
kpoint_func=self.kpoint_func,
direction=mp.AUTOMATIC,
kpoint_func=kpoint_func,
**self.eigenmode_kwargs,
)
# record eigenmode coefficients for scaling
self._eval = np.squeeze(ob.alpha[:, :, int(not self.forward)])
self._cscale = ob.cscale # pull scaling factor
overlaps = ob.alpha.squeeze(axis=0)
assert overlaps.ndim == 2
self._eval = overlaps[:, overlap_idx]
self._cscale = ob.cscale
return self._eval


Expand Down
59 changes: 34 additions & 25 deletions python/tests/test_adjoint_jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,15 @@ def build_straight_wg_simulation(
eig_kpoint=mp.Vector3(1, 0, 0),
size=mp.Vector3(0, wg_width + 2 * wg_padding, 0),
center=[-sx / 2 + pml_width + source_to_pml, 0, 0],
)
),
mp.EigenModeSource(
mp.GaussianSource(frequency=fmean, fwidth=fmean * gaussian_rel_width),
eig_band=1,
direction=mp.NO_DIRECTION,
eig_kpoint=mp.Vector3(-1, 0, 0),
size=mp.Vector3(0, wg_width + 2 * wg_padding, 0),
center=[sx / 2 - pml_width - source_to_pml, 0, 0],
),
]

nx = int(design_region_resolution * design_region_shape[0])
Expand Down Expand Up @@ -117,7 +125,7 @@ def build_straight_wg_simulation(
simulation,
mp.Volume(center=center, size=monitor_size),
mode=1,
forward=True) for center in monitor_centers
forward=forward) for center in monitor_centers for forward in [True, False]
]
return simulation, sources, monitors, design_regions, frequencies

Expand Down Expand Up @@ -169,11 +177,14 @@ def test_design_region_monitor_helpers(self):
class WrapperTest(VectorComparisonMixin, unittest.TestCase):

@parameterized.parameterized.expand([
('1500_1550bw_01relative_gaussian', onp.linspace(1 / 1.50, 1 / 1.55, 3).tolist(), 0.1, 1.0),
('1550_1600bw_02relative_gaussian', onp.linspace(1 / 1.55, 1 / 1.60, 3).tolist(), 0.2, 1.0),
('1500_1600bw_03relative_gaussian', onp.linspace(1 / 1.50, 1 / 1.60, 4).tolist(), 0.3, 1.0),
('1500_1550bw_01relative_gaussian_port1', onp.linspace(1 / 1.50, 1 / 1.55, 3).tolist(), 0.1, 1.0, 0),
# ('1550_1600bw_02relative_gaussian_port1', onp.linspace(1 / 1.55, 1 / 1.60, 3).tolist(), 0.2, 1.0, 0),
# ('1500_1600bw_03relative_gaussian_port1', onp.linspace(1 / 1.50, 1 / 1.60, 4).tolist(), 0.3, 1.0, 0),
('1500_1550bw_01relative_gaussian_port2', onp.linspace(1 / 1.50, 1 / 1.55, 3).tolist(), 0.1, 1.0, 1),
# ('1550_1600bw_02relative_gaussian_port2', onp.linspace(1 / 1.55, 1 / 1.60, 3).tolist(), 0.2, 1.0, 1),
# ('1500_1600bw_03relative_gaussian_port2', onp.linspace(1 / 1.50, 1 / 1.60, 4).tolist(), 0.3, 1.0, 1),
])
def test_wrapper_gradients(self, _, frequencies, gaussian_rel_width, design_variable_fill_value):
def test_wrapper_gradients(self, _, frequencies, gaussian_rel_width, design_variable_fill_value, excite_port_idx):
"""Tests gradient from the JAX-Meep wrapper against finite differences."""
(
simulation,
Expand All @@ -183,31 +194,29 @@ def test_wrapper_gradients(self, _, frequencies, gaussian_rel_width, design_vari
frequencies,
) = build_straight_wg_simulation(frequencies=frequencies, gaussian_rel_width=gaussian_rel_width)

wrapped_meep = mpa.MeepJaxWrapper(
simulation,
sources,
monitors,
design_regions,
frequencies,
measurement_interval=50.0,
dft_field_components=(mp.Ez,),
dft_threshold=1e-6,
minimum_run_time=0,
maximum_run_time=onp.inf,
until_after_sources=True
)

design_shape = tuple(int(i) for i in design_regions[0].design_parameters.grid_size)[:2]
x = onp.ones(design_shape) * design_variable_fill_value

# Define a loss function
def loss_fn(x):
def loss_fn(x, excite_port_idx=0):
wrapped_meep = mpa.MeepJaxWrapper(
simulation,
[sources[excite_port_idx]],
monitors,
design_regions,
frequencies,
)
monitor_values = wrapped_meep([x])
t = monitor_values[1, :] / monitor_values[0, :]
s1p, s1m, s2m, s2p = monitor_values
if excite_port_idx == 0:
t = s2m / s1p
else:
t = s1m / s2p
# Mean transmission vs wavelength
return jnp.mean(jnp.square(jnp.abs(t)))
t_mean = jnp.mean(jnp.square(jnp.abs(t)))
return t_mean

value, adjoint_grad = jax.value_and_grad(loss_fn)(x)
value, adjoint_grad = jax.value_and_grad(loss_fn)(x, excite_port_idx=excite_port_idx)

projection = []
fd_projection = []
Expand All @@ -224,7 +233,7 @@ def loss_fn(x):
x_perturbed = x + random_perturbation_vector

# Calculate T(p + dp)
value_perturbed = loss_fn(x_perturbed)
value_perturbed = loss_fn(x_perturbed, excite_port_idx=excite_port_idx)

projection.append(
onp.dot(
Expand Down
4 changes: 3 additions & 1 deletion src/mpb.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -484,7 +484,7 @@ void *fields::get_eigenmode(double frequency, direction d, const volume where, c
// which we automatically pick if kmatch == 0.
if (match_frequency && kmatch == 0) {
vec cen = eig_vol.center();
kmatch = frequency * sqrt(real(get_eps(cen, frequency)) * real(get_mu(cen, frequency)));
kmatch = copysign(frequency * sqrt(real(get_eps(cen, frequency)) * real(get_mu(cen, frequency))), _kpoint.in_direction(d));
if (d == NO_DIRECTION) {
for (int i = 0; i < 3; ++i)
k[i] = dot_product(R[i], kdir) * kmatch; // kdir*kmatch in reciprocal basis
Expand Down Expand Up @@ -829,6 +829,8 @@ void fields::add_eigenmode_source(component c0, const src_time &src, direction d
// electric current K = nHat \times H */
// magnetic current N = -nHat \times E */
/*--------------------------------------------------------------*/
if (global_eigenmode_data->group_velocity < 0)
amp *= -1; // equivalent to flipping the direction of nhat.
if (is_D(c0)) c0 = direction_component(Ex, component_direction(c0));
if (is_B(c0)) c0 = direction_component(Hx, component_direction(c0));
component cE[3] = {Ex, Ey, Ez}, cH[3] = {Hx, Hy, Hz};
Expand Down

0 comments on commit 08d3e7f

Please sign in to comment.