Skip to content

Commit

Permalink
Original fix
Browse files Browse the repository at this point in the history
  • Loading branch information
ianwilliamson committed Sep 28, 2021
1 parent a5627e0 commit ee7e823
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 24 deletions.
30 changes: 19 additions & 11 deletions python/adjoint/objective.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,10 @@ def _adj_src_scale(self, include_resolution=True):
else:
# multi frequency simulations
scale = dV * iomega / adj_src_phase
# compensate for the fact that real fields take the real part of the current,
# which halves the Fourier amplitude at the positive frequency (Re[J] = (J + J*)/2)
if self.sim.using_real_fields():
scale *= 2
return scale

def _create_time_profile(self, fwidth_frac=0.1):
Expand All @@ -110,10 +114,6 @@ def _create_time_profile(self, fwidth_frac=0.1):
)


_KPOINT_POS_DIR = mp.Vector3(0.0, 0.0, 0.0)
_KPOINT_NEG_DIR = mp.Vector3(-0.0, -0.0, -0.0)


class EigenmodeCoefficient(ObjectiveQuantity):
"""A frequency-dependent eigenmode coefficient.
Attributes:
Expand All @@ -134,6 +134,7 @@ def __init__(self,
forward=True,
kpoint_func=None,
kpoint_func_overlap_idx=0,
decimation_factor=0,
**kwargs):
super().__init__(sim)
if kpoint_func_overlap_idx not in [0, 1]:
Expand All @@ -148,13 +149,15 @@ def __init__(self,
self.eigenmode_kwargs = kwargs
self._monitor = None
self._cscale = None
self.decimation_factor = decimation_factor

def register_monitors(self, frequencies):
self._frequencies = np.asarray(frequencies)
self._monitor = self.sim.add_mode_monitor(
frequencies,
mp.ModeRegion(center=self.volume.center, size=self.volume.size),
yee_grid=True,
decimation_factor=self.decimation_factor,
)
return self._monitor

Expand All @@ -169,7 +172,8 @@ def place_adjoint_source(self, dJ):
if self.kpoint_func:
eig_kpoint = -1 * self.kpoint_func(time_src.frequency, self.mode)
else:
eig_kpoint = _KPOINT_NEG_DIR if self.forward else _KPOINT_POS_DIR
direction = mp.Vector3(*[float(v == 0) for v in self.volume.size])
eig_kpoint = -1 * direction if self.forward else direction

if self._frequencies.size == 1:
amp = da_dE * dJ * scale
Expand Down Expand Up @@ -201,7 +205,8 @@ def __call__(self):
kpoint_func = self.kpoint_func
overlap_idx = self.kpoint_func_overlap_idx
else:
kpoint_func = lambda *not_used: _KPOINT_POS_DIR if self.forward else _KPOINT_NEG_DIR
direction = mp.Vector3(*[float(v == 0) for v in self.volume.size])
kpoint_func = lambda *not_used: direction if self.forward else -1 * direction
overlap_idx = 0
ob = self.sim.get_eigenmode_coefficients(
self._monitor,
Expand All @@ -218,10 +223,11 @@ def __call__(self):


class FourierFields(ObjectiveQuantity):
def __init__(self, sim, volume, component):
def __init__(self, sim, volume, component, decimation_factor=0):
super().__init__(sim)
self.volume = volume
self.component = component
self.decimation_factor = decimation_factor

def register_monitors(self, frequencies):
self._frequencies = np.asarray(frequencies)
Expand All @@ -230,6 +236,7 @@ def register_monitors(self, frequencies):
self._frequencies,
where=self.volume,
yee_grid=False,
decimation_factor=self.decimation_factor,
)
return self._monitor

Expand Down Expand Up @@ -281,7 +288,7 @@ def place_adjoint_source(self, dJ):
for yi in range(y_dim):
for xi in range(x_dim):
'''We only need to add a current source if the
jacobian is nonzero for all frequencies at
jacobian is nonzero for all frequencies at
that particular point. Otherwise, the fitting
algorithm is going to fail.
'''
Expand Down Expand Up @@ -316,18 +323,19 @@ def __call__(self):


class Near2FarFields(ObjectiveQuantity):
def __init__(self, sim, Near2FarRegions, far_pts):
def __init__(self, sim, Near2FarRegions, far_pts, decimation_factor=0):
super().__init__(sim)
self.Near2FarRegions = Near2FarRegions
self.far_pts = far_pts #list of far pts
self._nfar_pts = len(far_pts)
self.decimation_factor = decimation_factor

def register_monitors(self, frequencies):
self._frequencies = np.asarray(frequencies)
self._monitor = self.sim.add_near2far(
self._frequencies,
*self.Near2FarRegions,
yee_grid=True,
decimation_factor=self.decimation_factor,
)
return self._monitor

Expand Down Expand Up @@ -359,7 +367,7 @@ def place_adjoint_source(self, dJ):
time_src.frequency,
self._frequencies,
scale,
dt,
self.sim.fields.dt,
)
(num_basis, num_pts) = src.nodes.shape
for basis_i in range(num_basis):
Expand Down
19 changes: 10 additions & 9 deletions python/tests/test_adjoint_jax.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import unittest
import parameterized

from utils import VectorComparisonMixin
from utils import ApproxComparisonTestCase

import jax
import jax.numpy as jnp
Expand All @@ -16,7 +16,7 @@
_FD_STEP = 1e-4

# The tolerance for the adjoint and finite difference gradient comparison
_TOL = 2e-2
_TOL = 0.1 if mp.is_single_precision() else 0.02

mp.verbosity(0)

Expand Down Expand Up @@ -125,7 +125,8 @@ def build_straight_wg_simulation(
simulation,
mp.Volume(center=center, size=monitor_size),
mode=1,
forward=forward) for center in monitor_centers for forward in [True, False]
forward=True,
decimation_factor=1) for center in monitor_centers for forward in [True, False]
]
return simulation, sources, monitors, design_regions, frequencies

Expand Down Expand Up @@ -174,15 +175,15 @@ def test_design_region_monitor_helpers(self):
self.assertEqual(value.dtype, onp.complex128)


class WrapperTest(VectorComparisonMixin, unittest.TestCase):
class WrapperTest(ApproxComparisonTestCase):

@parameterized.parameterized.expand([
('1500_1550bw_01relative_gaussian_port1', onp.linspace(1 / 1.50, 1 / 1.55, 3).tolist(), 0.1, 1.0, 0),
# ('1550_1600bw_02relative_gaussian_port1', onp.linspace(1 / 1.55, 1 / 1.60, 3).tolist(), 0.2, 1.0, 0),
# ('1500_1600bw_03relative_gaussian_port1', onp.linspace(1 / 1.50, 1 / 1.60, 4).tolist(), 0.3, 1.0, 0),
('1550_1600bw_02relative_gaussian_port1', onp.linspace(1 / 1.55, 1 / 1.60, 3).tolist(), 0.2, 1.0, 0),
('1500_1600bw_03relative_gaussian_port1', onp.linspace(1 / 1.50, 1 / 1.60, 4).tolist(), 0.3, 1.0, 0),
('1500_1550bw_01relative_gaussian_port2', onp.linspace(1 / 1.50, 1 / 1.55, 3).tolist(), 0.1, 1.0, 1),
# ('1550_1600bw_02relative_gaussian_port2', onp.linspace(1 / 1.55, 1 / 1.60, 3).tolist(), 0.2, 1.0, 1),
# ('1500_1600bw_03relative_gaussian_port2', onp.linspace(1 / 1.50, 1 / 1.60, 4).tolist(), 0.3, 1.0, 1),
('1550_1600bw_02relative_gaussian_port2', onp.linspace(1 / 1.55, 1 / 1.60, 3).tolist(), 0.2, 1.0, 1),
('1500_1600bw_03relative_gaussian_port2', onp.linspace(1 / 1.50, 1 / 1.60, 4).tolist(), 0.3, 1.0, 1),
])
def test_wrapper_gradients(self, _, frequencies, gaussian_rel_width, design_variable_fill_value, excite_port_idx):
"""Tests gradient from the JAX-Meep wrapper against finite differences."""
Expand Down Expand Up @@ -246,7 +247,7 @@ def loss_fn(x, excite_port_idx=0):
fd_projection = onp.stack(fd_projection)

# Check that dp . ∇T ~ T(p + dp) - T(p)
self.assertVectorsClose(
self.assertClose(
projection,
fd_projection,
epsilon=_TOL,
Expand Down
5 changes: 1 addition & 4 deletions src/mpb.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -484,7 +484,7 @@ void *fields::get_eigenmode(double frequency, direction d, const volume where, c
// which we automatically pick if kmatch == 0.
if (match_frequency && kmatch == 0) {
vec cen = eig_vol.center();
kmatch = copysign(frequency * sqrt(real(get_eps(cen, frequency)) * real(get_mu(cen, frequency))), _kpoint.in_direction(d));
kmatch = frequency * sqrt(real(get_eps(cen, frequency)) * real(get_mu(cen, frequency)));
if (d == NO_DIRECTION) {
for (int i = 0; i < 3; ++i)
k[i] = dot_product(R[i], kdir) * kmatch; // kdir*kmatch in reciprocal basis
Expand Down Expand Up @@ -831,9 +831,6 @@ void fields::add_eigenmode_source(component c0, const src_time &src, direction d
// electric current K = nHat \times H */
// magnetic current N = -nHat \times E */
/*--------------------------------------------------------------*/
if (global_eigenmode_data->group_velocity < 0)
master_printf("vg is negative!\n");
amp *= -1; // equivalent to flipping the direction of nhat.
if (is_D(c0)) c0 = direction_component(Ex, component_direction(c0));
if (is_B(c0)) c0 = direction_component(Hx, component_direction(c0));
component cE[3] = {Ex, Ey, Ez}, cH[3] = {Hx, Hy, Hz};
Expand Down

0 comments on commit ee7e823

Please sign in to comment.