diff --git a/python/Makefile.am b/python/Makefile.am index 95aaa66ba..58b125832 100644 --- a/python/Makefile.am +++ b/python/Makefile.am @@ -220,9 +220,8 @@ adjoint_PYTHON = $(srcdir)/adjoint/__init__.py \ $(srcdir)/adjoint/optimization_problem.py \ $(srcdir)/adjoint/filters.py \ $(srcdir)/adjoint/filter_source.py \ - $(srcdir)/adjoint/jax/__init__.py \ - $(srcdir)/adjoint/jax/wrapper.py \ - $(srcdir)/adjoint/jax/utils.py + $(srcdir)/adjoint/wrapper.py \ + $(srcdir)/adjoint/utils.py ###################################################################### # finally, specification of what gets installed in the meep python diff --git a/python/adjoint/__init__.py b/python/adjoint/__init__.py index 4aad1400d..18b019323 100644 --- a/python/adjoint/__init__.py +++ b/python/adjoint/__init__.py @@ -1,10 +1,7 @@ """ Adjoint-based sensitivity-analysis module for pymeep. -Authors: Homer Reid , Alec Hammond +Authors: Homer Reid , Alec Hammond , Ian Williamson """ -import sys - -import meep as mp from .objective import * @@ -16,7 +13,9 @@ from .filters import * +from . import utils + try: - from . import jax -except ImportError as error: + from .wrapper import MeepJaxWrapper +except ModuleNotFoundError as _: pass \ No newline at end of file diff --git a/python/adjoint/jax/__init__.py b/python/adjoint/jax/__init__.py deleted file mode 100644 index f502f3395..000000000 --- a/python/adjoint/jax/__init__.py +++ /dev/null @@ -1,7 +0,0 @@ -"""Interface for composing Meep with JAX. - -""" - -from .wrapper import MeepJaxWrapper - -from . import utils diff --git a/python/adjoint/jax/utils.py b/python/adjoint/utils.py similarity index 93% rename from python/adjoint/jax/utils.py rename to python/adjoint/utils.py index 13919e813..1f1ef0ada 100644 --- a/python/adjoint/jax/utils.py +++ b/python/adjoint/utils.py @@ -1,9 +1,10 @@ from typing import List, Iterable, Tuple import meep as mp -import meep.adjoint as mpa import numpy as onp +from . import ObjectiveQuantitiy, DesignRegion + # Meep field components used to compute adjoint sensitivities _ADJOINT_FIELD_COMPONENTS = [mp.Ex, mp.Ey, mp.Ez] @@ -18,7 +19,7 @@ def _make_at_least_nd(x: onp.ndarray, dims: int = 3) -> onp.ndarray: def calculate_vjps( simulation: mp.Simulation, - design_regions: List[mpa.DesignRegion], + design_regions: List[DesignRegion], frequencies: List[float], fwd_fields: List[List[onp.ndarray]], adj_fields: List[List[onp.ndarray]], @@ -42,7 +43,7 @@ def calculate_vjps( def register_monitors( - monitors: List[mpa.ObjectiveQuantitiy], + monitors: List[ObjectiveQuantitiy], frequencies: List[float], ) -> None: """Registers a list of monitors.""" @@ -52,7 +53,7 @@ def register_monitors( def install_design_region_monitors( simulation: mp.Simulation, - design_regions: List[mpa.DesignRegion], + design_regions: List[DesignRegion], frequencies: List[float], ) -> List[mp.DftFields]: """Installs DFT field monitors at the design regions of the simulation.""" @@ -67,7 +68,7 @@ def install_design_region_monitors( return design_region_monitors -def gather_monitor_values(monitors: List[mpa.ObjectiveQuantitiy]) -> onp.ndarray: +def gather_monitor_values(monitors: List[ObjectiveQuantitiy]) -> onp.ndarray: """Gathers the mode monitor overlap values as a rank 2 ndarray. Args: @@ -122,7 +123,7 @@ def gather_design_region_fields( return fwd_fields -def validate_and_update_design(design_regions: List[mpa.DesignRegion], design_variables: Iterable[onp.ndarray]) -> None: +def validate_and_update_design(design_regions: List[DesignRegion], design_variables: Iterable[onp.ndarray]) -> None: """Validate the design regions and variables. In particular the design variable should be 1,2,3-D and the design region @@ -154,7 +155,7 @@ def validate_and_update_design(design_regions: List[mpa.DesignRegion], design_va design_region.update_design_parameters(design_variable.flatten()) -def create_adjoint_sources(monitors: mpa.ObjectiveQuantitiy, monitor_values_grad: onp.ndarray) -> List[mp.Source]: +def create_adjoint_sources(monitors: ObjectiveQuantitiy, monitor_values_grad: onp.ndarray) -> List[mp.Source]: monitor_values_grad = onp.asarray(monitor_values_grad, dtype=onp.complex128) if not onp.any(monitor_values_grad): raise RuntimeError('The gradient of all monitor values is zero, which ' diff --git a/python/adjoint/jax/wrapper.py b/python/adjoint/wrapper.py similarity index 98% rename from python/adjoint/jax/wrapper.py rename to python/adjoint/wrapper.py index 523c3e49b..282fbd657 100644 --- a/python/adjoint/jax/wrapper.py +++ b/python/adjoint/wrapper.py @@ -52,10 +52,10 @@ def loss(x): import jax import jax.numpy as jnp import meep as mp -import meep.adjoint as mpa import numpy as onp from . import utils +from . import DesignRegion, EigenmodeCoefficient _norm_fn = onp.linalg.norm _reduce_fn = onp.max @@ -95,8 +95,8 @@ class MeepJaxWrapper: def __init__(self, simulation: mp.Simulation, sources: List[mp.Source], - monitors: List[mpa.EigenmodeCoefficient], - design_regions: List[mpa.DesignRegion], + monitors: List[EigenmodeCoefficient], + design_regions: List[DesignRegion], frequencies: List[float], measurement_interval: float = 50.0, dft_field_components: Tuple[int, ...] = (mp.Ez,), diff --git a/python/tests/adjoint_jax.py b/python/tests/adjoint_jax.py index 188b27bf6..2e1b7a92f 100644 --- a/python/tests/adjoint_jax.py +++ b/python/tests/adjoint_jax.py @@ -1,7 +1,7 @@ import unittest import parameterized -from . import utils +from utils import VectorComparisonMixin import jax import jax.numpy as jnp @@ -134,21 +134,21 @@ def setUp(self): ) = build_straight_wg_simulation() def test_mode_monitor_helpers(self): - mpa.jax.utils.register_monitors(self.monitors, self.frequencies) + mpa.utils.register_monitors(self.monitors, self.frequencies) self.simulation.run(until=100) - monitor_values = mpa.jax.utils.gather_monitor_values(self.monitors) + monitor_values = mpa.utils.gather_monitor_values(self.monitors) self.assertEqual(monitor_values.dtype, onp.complex128) self.assertEqual(monitor_values.shape, (len(self.monitors), len(self.frequencies))) def test_design_region_monitor_helpers(self): - design_region_monitors = mpa.jax.utils.install_design_region_monitors( + design_region_monitors = mpa.utils.install_design_region_monitors( self.simulation, self.design_regions, self.frequencies, ) self.simulation.run(until=100) - design_region_fields = mpa.jax.utils.gather_design_region_fields( + design_region_fields = mpa.utils.gather_design_region_fields( self.simulation, design_region_monitors, self.frequencies, @@ -158,7 +158,7 @@ def test_design_region_monitor_helpers(self): self.assertEqual(len(design_region_fields), len(self.design_regions)) self.assertIsInstance(design_region_fields[0], list) - self.assertEqual(len(design_region_fields[0]), len(mpa.jax.utils._ADJOINT_FIELD_COMPONENTS)) + self.assertEqual(len(design_region_fields[0]), len(mpa.utils._ADJOINT_FIELD_COMPONENTS)) for value in design_region_fields[0]: self.assertIsInstance(value, onp.ndarray) @@ -166,7 +166,7 @@ def test_design_region_monitor_helpers(self): self.assertEqual(value.dtype, onp.complex128) -class WrapperTest(utils.VectorComparisonMixin, unittest.TestCase): +class WrapperTest(VectorComparisonMixin, unittest.TestCase): @parameterized.parameterized.expand([ ('1500_1550bw_01relative_gaussian', onp.linspace(1 / 1.50, 1 / 1.55, 3).tolist(), 0.1, 1.0), @@ -183,7 +183,7 @@ def test_wrapper_gradients(self, _, frequencies, gaussian_rel_width, design_vari frequencies, ) = build_straight_wg_simulation(frequencies=frequencies, gaussian_rel_width=gaussian_rel_width) - wrapped_meep = mpa.jax.MeepJaxWrapper( + wrapped_meep = mpa.MeepJaxWrapper( simulation, sources, monitors,