Skip to content

Commit

Permalink
Fix imports and move mpa.jax.* to mpa.* namespace, and fix Makefile.am (
Browse files Browse the repository at this point in the history
  • Loading branch information
ianwilliamson authored Jun 3, 2021
1 parent 4978324 commit cde677a
Show file tree
Hide file tree
Showing 6 changed files with 26 additions and 34 deletions.
5 changes: 2 additions & 3 deletions python/Makefile.am
Original file line number Diff line number Diff line change
Expand Up @@ -220,9 +220,8 @@ adjoint_PYTHON = $(srcdir)/adjoint/__init__.py \
$(srcdir)/adjoint/optimization_problem.py \
$(srcdir)/adjoint/filters.py \
$(srcdir)/adjoint/filter_source.py \
$(srcdir)/adjoint/jax/__init__.py \
$(srcdir)/adjoint/jax/wrapper.py \
$(srcdir)/adjoint/jax/utils.py
$(srcdir)/adjoint/wrapper.py \
$(srcdir)/adjoint/utils.py

######################################################################
# finally, specification of what gets installed in the meep python
Expand Down
11 changes: 5 additions & 6 deletions python/adjoint/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,7 @@
"""
Adjoint-based sensitivity-analysis module for pymeep.
Authors: Homer Reid <homer@homerreid.com>, Alec Hammond <alec.hammond@gatech.edu>
Authors: Homer Reid <homer@homerreid.com>, Alec Hammond <alec.hammond@gatech.edu>, Ian Williamson <iwill@google.com>
"""
import sys

import meep as mp

from .objective import *

Expand All @@ -16,7 +13,9 @@

from .filters import *

from . import utils

try:
from . import jax
except ImportError as error:
from .wrapper import MeepJaxWrapper
except ModuleNotFoundError as _:
pass
7 changes: 0 additions & 7 deletions python/adjoint/jax/__init__.py

This file was deleted.

15 changes: 8 additions & 7 deletions python/adjoint/jax/utils.py → python/adjoint/utils.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from typing import List, Iterable, Tuple

import meep as mp
import meep.adjoint as mpa
import numpy as onp

from . import ObjectiveQuantitiy, DesignRegion

# Meep field components used to compute adjoint sensitivities
_ADJOINT_FIELD_COMPONENTS = [mp.Ex, mp.Ey, mp.Ez]

Expand All @@ -18,7 +19,7 @@ def _make_at_least_nd(x: onp.ndarray, dims: int = 3) -> onp.ndarray:

def calculate_vjps(
simulation: mp.Simulation,
design_regions: List[mpa.DesignRegion],
design_regions: List[DesignRegion],
frequencies: List[float],
fwd_fields: List[List[onp.ndarray]],
adj_fields: List[List[onp.ndarray]],
Expand All @@ -42,7 +43,7 @@ def calculate_vjps(


def register_monitors(
monitors: List[mpa.ObjectiveQuantitiy],
monitors: List[ObjectiveQuantitiy],
frequencies: List[float],
) -> None:
"""Registers a list of monitors."""
Expand All @@ -52,7 +53,7 @@ def register_monitors(

def install_design_region_monitors(
simulation: mp.Simulation,
design_regions: List[mpa.DesignRegion],
design_regions: List[DesignRegion],
frequencies: List[float],
) -> List[mp.DftFields]:
"""Installs DFT field monitors at the design regions of the simulation."""
Expand All @@ -67,7 +68,7 @@ def install_design_region_monitors(
return design_region_monitors


def gather_monitor_values(monitors: List[mpa.ObjectiveQuantitiy]) -> onp.ndarray:
def gather_monitor_values(monitors: List[ObjectiveQuantitiy]) -> onp.ndarray:
"""Gathers the mode monitor overlap values as a rank 2 ndarray.
Args:
Expand Down Expand Up @@ -122,7 +123,7 @@ def gather_design_region_fields(
return fwd_fields


def validate_and_update_design(design_regions: List[mpa.DesignRegion], design_variables: Iterable[onp.ndarray]) -> None:
def validate_and_update_design(design_regions: List[DesignRegion], design_variables: Iterable[onp.ndarray]) -> None:
"""Validate the design regions and variables.
In particular the design variable should be 1,2,3-D and the design region
Expand Down Expand Up @@ -154,7 +155,7 @@ def validate_and_update_design(design_regions: List[mpa.DesignRegion], design_va
design_region.update_design_parameters(design_variable.flatten())


def create_adjoint_sources(monitors: mpa.ObjectiveQuantitiy, monitor_values_grad: onp.ndarray) -> List[mp.Source]:
def create_adjoint_sources(monitors: ObjectiveQuantitiy, monitor_values_grad: onp.ndarray) -> List[mp.Source]:
monitor_values_grad = onp.asarray(monitor_values_grad, dtype=onp.complex128)
if not onp.any(monitor_values_grad):
raise RuntimeError('The gradient of all monitor values is zero, which '
Expand Down
6 changes: 3 additions & 3 deletions python/adjoint/jax/wrapper.py → python/adjoint/wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,10 +52,10 @@ def loss(x):
import jax
import jax.numpy as jnp
import meep as mp
import meep.adjoint as mpa
import numpy as onp

from . import utils
from . import DesignRegion, EigenmodeCoefficient

_norm_fn = onp.linalg.norm
_reduce_fn = onp.max
Expand Down Expand Up @@ -95,8 +95,8 @@ class MeepJaxWrapper:
def __init__(self,
simulation: mp.Simulation,
sources: List[mp.Source],
monitors: List[mpa.EigenmodeCoefficient],
design_regions: List[mpa.DesignRegion],
monitors: List[EigenmodeCoefficient],
design_regions: List[DesignRegion],
frequencies: List[float],
measurement_interval: float = 50.0,
dft_field_components: Tuple[int, ...] = (mp.Ez,),
Expand Down
16 changes: 8 additions & 8 deletions python/tests/adjoint_jax.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import unittest
import parameterized

from . import utils
from utils import VectorComparisonMixin

import jax
import jax.numpy as jnp
Expand Down Expand Up @@ -134,21 +134,21 @@ def setUp(self):
) = build_straight_wg_simulation()

def test_mode_monitor_helpers(self):
mpa.jax.utils.register_monitors(self.monitors, self.frequencies)
mpa.utils.register_monitors(self.monitors, self.frequencies)
self.simulation.run(until=100)
monitor_values = mpa.jax.utils.gather_monitor_values(self.monitors)
monitor_values = mpa.utils.gather_monitor_values(self.monitors)
self.assertEqual(monitor_values.dtype, onp.complex128)
self.assertEqual(monitor_values.shape,
(len(self.monitors), len(self.frequencies)))

def test_design_region_monitor_helpers(self):
design_region_monitors = mpa.jax.utils.install_design_region_monitors(
design_region_monitors = mpa.utils.install_design_region_monitors(
self.simulation,
self.design_regions,
self.frequencies,
)
self.simulation.run(until=100)
design_region_fields = mpa.jax.utils.gather_design_region_fields(
design_region_fields = mpa.utils.gather_design_region_fields(
self.simulation,
design_region_monitors,
self.frequencies,
Expand All @@ -158,15 +158,15 @@ def test_design_region_monitor_helpers(self):
self.assertEqual(len(design_region_fields), len(self.design_regions))

self.assertIsInstance(design_region_fields[0], list)
self.assertEqual(len(design_region_fields[0]), len(mpa.jax.utils._ADJOINT_FIELD_COMPONENTS))
self.assertEqual(len(design_region_fields[0]), len(mpa.utils._ADJOINT_FIELD_COMPONENTS))

for value in design_region_fields[0]:
self.assertIsInstance(value, onp.ndarray)
self.assertEqual(value.ndim, 4) # dims: freq, x, y, pad
self.assertEqual(value.dtype, onp.complex128)


class WrapperTest(utils.VectorComparisonMixin, unittest.TestCase):
class WrapperTest(VectorComparisonMixin, unittest.TestCase):

@parameterized.parameterized.expand([
('1500_1550bw_01relative_gaussian', onp.linspace(1 / 1.50, 1 / 1.55, 3).tolist(), 0.1, 1.0),
Expand All @@ -183,7 +183,7 @@ def test_wrapper_gradients(self, _, frequencies, gaussian_rel_width, design_vari
frequencies,
) = build_straight_wg_simulation(frequencies=frequencies, gaussian_rel_width=gaussian_rel_width)

wrapped_meep = mpa.jax.MeepJaxWrapper(
wrapped_meep = mpa.MeepJaxWrapper(
simulation,
sources,
monitors,
Expand Down

0 comments on commit cde677a

Please sign in to comment.