Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix circular dependency in imports and change namespace of JAX adjoint components #1586

Merged
merged 1 commit into from
Jun 3, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions python/Makefile.am
Original file line number Diff line number Diff line change
Expand Up @@ -220,9 +220,8 @@ adjoint_PYTHON = $(srcdir)/adjoint/__init__.py \
$(srcdir)/adjoint/optimization_problem.py \
$(srcdir)/adjoint/filters.py \
$(srcdir)/adjoint/filter_source.py \
$(srcdir)/adjoint/jax/__init__.py \
$(srcdir)/adjoint/jax/wrapper.py \
$(srcdir)/adjoint/jax/utils.py
$(srcdir)/adjoint/wrapper.py \
$(srcdir)/adjoint/utils.py

######################################################################
# finally, specification of what gets installed in the meep python
Expand Down
11 changes: 5 additions & 6 deletions python/adjoint/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,7 @@
"""
Adjoint-based sensitivity-analysis module for pymeep.
Authors: Homer Reid <homer@homerreid.com>, Alec Hammond <alec.hammond@gatech.edu>
Authors: Homer Reid <homer@homerreid.com>, Alec Hammond <alec.hammond@gatech.edu>, Ian Williamson <iwill@google.com>
"""
import sys

import meep as mp

from .objective import *

Expand All @@ -16,7 +13,9 @@

from .filters import *

from . import utils

try:
from . import jax
except ImportError as error:
from .wrapper import MeepJaxWrapper
except ModuleNotFoundError as _:
pass
7 changes: 0 additions & 7 deletions python/adjoint/jax/__init__.py

This file was deleted.

15 changes: 8 additions & 7 deletions python/adjoint/jax/utils.py → python/adjoint/utils.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from typing import List, Iterable, Tuple

import meep as mp
import meep.adjoint as mpa
import numpy as onp

from . import ObjectiveQuantitiy, DesignRegion

# Meep field components used to compute adjoint sensitivities
_ADJOINT_FIELD_COMPONENTS = [mp.Ex, mp.Ey, mp.Ez]

Expand All @@ -18,7 +19,7 @@ def _make_at_least_nd(x: onp.ndarray, dims: int = 3) -> onp.ndarray:

def calculate_vjps(
simulation: mp.Simulation,
design_regions: List[mpa.DesignRegion],
design_regions: List[DesignRegion],
frequencies: List[float],
fwd_fields: List[List[onp.ndarray]],
adj_fields: List[List[onp.ndarray]],
Expand All @@ -42,7 +43,7 @@ def calculate_vjps(


def register_monitors(
monitors: List[mpa.ObjectiveQuantitiy],
monitors: List[ObjectiveQuantitiy],
frequencies: List[float],
) -> None:
"""Registers a list of monitors."""
Expand All @@ -52,7 +53,7 @@ def register_monitors(

def install_design_region_monitors(
simulation: mp.Simulation,
design_regions: List[mpa.DesignRegion],
design_regions: List[DesignRegion],
frequencies: List[float],
) -> List[mp.DftFields]:
"""Installs DFT field monitors at the design regions of the simulation."""
Expand All @@ -67,7 +68,7 @@ def install_design_region_monitors(
return design_region_monitors


def gather_monitor_values(monitors: List[mpa.ObjectiveQuantitiy]) -> onp.ndarray:
def gather_monitor_values(monitors: List[ObjectiveQuantitiy]) -> onp.ndarray:
"""Gathers the mode monitor overlap values as a rank 2 ndarray.
Args:
Expand Down Expand Up @@ -122,7 +123,7 @@ def gather_design_region_fields(
return fwd_fields


def validate_and_update_design(design_regions: List[mpa.DesignRegion], design_variables: Iterable[onp.ndarray]) -> None:
def validate_and_update_design(design_regions: List[DesignRegion], design_variables: Iterable[onp.ndarray]) -> None:
"""Validate the design regions and variables.
In particular the design variable should be 1,2,3-D and the design region
Expand Down Expand Up @@ -154,7 +155,7 @@ def validate_and_update_design(design_regions: List[mpa.DesignRegion], design_va
design_region.update_design_parameters(design_variable.flatten())


def create_adjoint_sources(monitors: mpa.ObjectiveQuantitiy, monitor_values_grad: onp.ndarray) -> List[mp.Source]:
def create_adjoint_sources(monitors: ObjectiveQuantitiy, monitor_values_grad: onp.ndarray) -> List[mp.Source]:
monitor_values_grad = onp.asarray(monitor_values_grad, dtype=onp.complex128)
if not onp.any(monitor_values_grad):
raise RuntimeError('The gradient of all monitor values is zero, which '
Expand Down
6 changes: 3 additions & 3 deletions python/adjoint/jax/wrapper.py → python/adjoint/wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,10 +52,10 @@ def loss(x):
import jax
import jax.numpy as jnp
import meep as mp
import meep.adjoint as mpa
import numpy as onp

from . import utils
from . import DesignRegion, EigenmodeCoefficient

_norm_fn = onp.linalg.norm
_reduce_fn = onp.max
Expand Down Expand Up @@ -95,8 +95,8 @@ class MeepJaxWrapper:
def __init__(self,
simulation: mp.Simulation,
sources: List[mp.Source],
monitors: List[mpa.EigenmodeCoefficient],
design_regions: List[mpa.DesignRegion],
monitors: List[EigenmodeCoefficient],
design_regions: List[DesignRegion],
frequencies: List[float],
measurement_interval: float = 50.0,
dft_field_components: Tuple[int, ...] = (mp.Ez,),
Expand Down
16 changes: 8 additions & 8 deletions python/tests/adjoint_jax.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import unittest
import parameterized

from . import utils
from utils import VectorComparisonMixin

import jax
import jax.numpy as jnp
Expand Down Expand Up @@ -134,21 +134,21 @@ def setUp(self):
) = build_straight_wg_simulation()

def test_mode_monitor_helpers(self):
mpa.jax.utils.register_monitors(self.monitors, self.frequencies)
mpa.utils.register_monitors(self.monitors, self.frequencies)
self.simulation.run(until=100)
monitor_values = mpa.jax.utils.gather_monitor_values(self.monitors)
monitor_values = mpa.utils.gather_monitor_values(self.monitors)
self.assertEqual(monitor_values.dtype, onp.complex128)
self.assertEqual(monitor_values.shape,
(len(self.monitors), len(self.frequencies)))

def test_design_region_monitor_helpers(self):
design_region_monitors = mpa.jax.utils.install_design_region_monitors(
design_region_monitors = mpa.utils.install_design_region_monitors(
self.simulation,
self.design_regions,
self.frequencies,
)
self.simulation.run(until=100)
design_region_fields = mpa.jax.utils.gather_design_region_fields(
design_region_fields = mpa.utils.gather_design_region_fields(
self.simulation,
design_region_monitors,
self.frequencies,
Expand All @@ -158,15 +158,15 @@ def test_design_region_monitor_helpers(self):
self.assertEqual(len(design_region_fields), len(self.design_regions))

self.assertIsInstance(design_region_fields[0], list)
self.assertEqual(len(design_region_fields[0]), len(mpa.jax.utils._ADJOINT_FIELD_COMPONENTS))
self.assertEqual(len(design_region_fields[0]), len(mpa.utils._ADJOINT_FIELD_COMPONENTS))

for value in design_region_fields[0]:
self.assertIsInstance(value, onp.ndarray)
self.assertEqual(value.ndim, 4) # dims: freq, x, y, pad
self.assertEqual(value.dtype, onp.complex128)


class WrapperTest(utils.VectorComparisonMixin, unittest.TestCase):
class WrapperTest(VectorComparisonMixin, unittest.TestCase):

@parameterized.parameterized.expand([
('1500_1550bw_01relative_gaussian', onp.linspace(1 / 1.50, 1 / 1.55, 3).tolist(), 0.1, 1.0),
Expand All @@ -183,7 +183,7 @@ def test_wrapper_gradients(self, _, frequencies, gaussian_rel_width, design_vari
frequencies,
) = build_straight_wg_simulation(frequencies=frequencies, gaussian_rel_width=gaussian_rel_width)

wrapped_meep = mpa.jax.MeepJaxWrapper(
wrapped_meep = mpa.MeepJaxWrapper(
simulation,
sources,
monitors,
Expand Down