Skip to content

Commit

Permalink
Replace obsolete routine compare_arrays with assertClose in MPB u…
Browse files Browse the repository at this point in the history
…nit tests (#2547)

* replace obsolete routine compare_arrays with assertClose in MPB unit tests

* replace numpy.testing_allclose with assertClose
  • Loading branch information
oskooi authored Jun 22, 2023
1 parent d23e342 commit 0aabdcc
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 36 deletions.
37 changes: 19 additions & 18 deletions python/tests/test_mpb.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import glob
import math
import os
import re
Expand All @@ -7,16 +6,18 @@
import unittest

import h5py
import numpy as np
from scipy.optimize import minimize_scalar, ridder
from utils import compare_arrays

import meep as mp
import numpy as np
from meep import mpb
from scipy.optimize import minimize_scalar, ridder
from utils import ApproxComparisonTestCase


@unittest.skipIf(os.getenv("MEEP_SKIP_LARGE_TESTS", False), "skipping large tests")
class TestModeSolver(unittest.TestCase):
@unittest.skipIf(
os.getenv("MEEP_SKIP_LARGE_TESTS", False),
"skipping large tests",
)
class TestModeSolver(ApproxComparisonTestCase):

data_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "data"))
examples_dir = os.path.abspath(
Expand Down Expand Up @@ -224,7 +225,7 @@ def check_fields_against_h5(self, ref_path, field, suffix=""):
ref_arr[1::3] = ref_y.ravel()
ref_arr[2::3] = ref_z.ravel()

compare_arrays(self, ref_arr, field)
self.assertClose(ref_arr, field, epsilon=1e-4)

def compare_h5_files(self, ref_path, res_path, tol=1e-3):
with h5py.File(ref_path) as ref:
Expand All @@ -233,7 +234,7 @@ def compare_h5_files(self, ref_path, res_path, tol=1e-3):
if k == "description":
self.assertEqual(ref[k][()], res[k][()])
else:
compare_arrays(self, ref[k][()], res[k][()], tol=tol)
self.assertClose(ref[k][()], res[k][()], epsilon=1e-3)

def test_update_band_range_data(self):
brd = []
Expand Down Expand Up @@ -526,7 +527,7 @@ def test_compute_field_energy(self):

expected_energy_in_dielectric = 0.6990769686037558

compare_arrays(self, np.array(expected_energy), np.array(energy))
self.assertClose(np.array(expected_energy), np.array(energy), epsilon=1e-4)
self.assertAlmostEqual(
expected_energy_in_dielectric, energy_in_dielectric, places=3
)
Expand Down Expand Up @@ -913,7 +914,7 @@ def get_dpwr(ms, band):
with h5py.File(ref_path, "r") as f:
expected = f["data-new"][()]

compare_arrays(self, expected, converted_dpwr[-1])
self.assertClose(expected, converted_dpwr[-1], epsilon=1e-3)

def test_hole_slab(self):
from mpb_hole_slab import ms
Expand Down Expand Up @@ -1389,7 +1390,7 @@ def test_tri_rods(self):
expected_re = f["z.r-new"][()]
expected_im = f["z.i-new"][()]
expected = np.vectorize(complex)(expected_re, expected_im)
compare_arrays(self, expected, new_efield)
self.assertClose(expected, new_efield, epsilon=1e-4)

ms.run_te()

Expand Down Expand Up @@ -1460,7 +1461,7 @@ def test_tri_rods(self):

with h5py.File(ref_path, "r") as f:
ref = f["data-new"][()]
compare_arrays(self, ref, new_eps, tol=1e-3)
self.assertClose(ref, new_eps, epsilon=1e-4)

def test_subpixel_averaging(self):
ms = self.init_solver()
Expand Down Expand Up @@ -1573,7 +1574,7 @@ def test_run_te_with_mu_material(self):
mu = ms.get_mu()
with h5py.File(data_path, "r") as f:
data = f["data"][()]
compare_arrays(self, data, mu)
self.assertClose(data, mu, epsilon=1e-4)

def test_output_tot_pwr(self):
ms = self.init_solver()
Expand All @@ -1591,7 +1592,7 @@ def test_output_tot_pwr(self):
with h5py.File(ref_path, "r") as f:
expected = f["data"][()]

compare_arrays(self, expected, arr)
self.assertClose(expected, arr, epsilon=1e-4)

def test_get_eigenvectors(self):
ms = self.init_solver()
Expand All @@ -1603,7 +1604,7 @@ def compare_eigenvectors(ref_fn, start, cols):
# Reshape the last dimension of 2 reals into one complex
expected = np.vectorize(complex)(expected[..., 0], expected[..., 1])
ev = ms.get_eigenvectors(start, cols)
np.testing.assert_allclose(expected, ev, rtol=1e-3)
self.assertClose(expected, ev, epsilon=1e-3)

# Get all columns
compare_eigenvectors("tutorial-te-eigenvectors.h5", 1, 8)
Expand Down Expand Up @@ -1738,7 +1739,7 @@ def test_epsilon_input_file(self):
]

self.check_band_range_data(expected_brd, ms.band_range_data)
compare_arrays(self, expected_freqs, ms.all_freqs[-1])
self.assertClose(expected_freqs, ms.all_freqs[-1], 1e-4)

self.check_gap_list(expected_gap_list, ms.gap_list)

Expand Down Expand Up @@ -2071,7 +2072,7 @@ def test_multiply_bloch_in_mpb_data(self):
md = mpb.MPBData(rectify=True, resolution=32, periods=3)
result2 = md.convert(efield_no_bloch)

compare_arrays(self, result1, result2, tol=1e-5)
self.assertClose(result1, result2, epsilon=1e-4)

def test_poynting(self):
ms = self.init_solver()
Expand Down
35 changes: 17 additions & 18 deletions python/tests/utils.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,26 @@
from typing import Union
import unittest

import numpy as np


def compare_arrays(test_instance, exp, res, tol=1e-3):
exp_1d = exp.ravel()
res_1d = res.ravel()

norm_exp = np.linalg.norm(exp_1d)
norm_res = np.linalg.norm(res_1d)

if norm_exp == 0:
test_instance.assertEqual(norm_res, 0)
else:
diff = np.linalg.norm(res_1d - exp_1d) / norm_exp
test_instance.assertLess(diff, tol)


class ApproxComparisonTestCase(unittest.TestCase):
"""A mixin for adding proper floating point value and vector comparison."""

def assertClose(self, x, y, epsilon=1e-2, msg=""):
"""Asserts that two values or vectors satisfy ‖x-y‖ ≤ ε * max(‖x‖, ‖y‖)."""
"""A mixin for adding correct scalar/vector comparison."""

def assertClose(
self,
x: Union[float, np.ndarray],
y: Union[float, np.ndarray],
epsilon: float = 1e-2,
msg: str = "",
):
"""Checks if two scalars or vectors satisfy ‖x-y‖ ≤ ε * max(‖x‖, ‖y‖).
Args:
x, y: two quantities to be compared (scalars or 1d arrays).
epsilon: threshold value (maximum) of the relative error.
msg: a string to display if the inequality is violated.
"""
x = np.atleast_1d(x).ravel()
y = np.atleast_1d(y).ravel()
x_norm = np.linalg.norm(x, ord=np.inf)
Expand Down

0 comments on commit 0aabdcc

Please sign in to comment.