Skip to content

Commit

Permalink
bugfix #66: apply proper check for complex
Browse files Browse the repository at this point in the history
  • Loading branch information
flaport committed Nov 17, 2023
1 parent 9be8304 commit a14fe68
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 20 deletions.
23 changes: 14 additions & 9 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,10 +1,15 @@
repos:
- repo: https://github.com/psf/black
rev: 20.8b1
hooks:
- id: black
language_version: python3
- repo: https://github.com/kynan/nbstripout
rev: 0.3.9
hooks:
- id: nbstripout
- repo: https://github.com/kynan/nbstripout
rev: 0.6.0
hooks:
- id: nbstripout
- repo: https://github.com/psf/black-pre-commit-mirror
rev: 23.9.1
hooks:
- id: black
language_version: python3.11
- repo: https://github.com/psf/black-pre-commit-mirror
rev: 23.9.1
hooks:
- id: black-jupyter
language_version: python3.11
15 changes: 14 additions & 1 deletion fdtd/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,17 @@ class Backend:
def __repr__(self):
return self.__class__.__name__

def is_complex(x):
"""check if an object is a `ComplexFloat`"""
return (
isinstance(x, complex)
or (isinstance(x, np.ndarray) and x.dtype in (np.complex64, np.complex128))
or (
isinstance(x, torch.Tensor)
and x.dtype in (torch.complex64, torch.complex128)
)
)


def _replace_float(func):
"""replace the default dtype a function is called with"""
Expand All @@ -99,7 +110,7 @@ class NumpyBackend(Backend):

float = numpy.float64
""" floating type for array """

complex = numpy.complex128
""" complex type for array """

Expand Down Expand Up @@ -305,6 +316,8 @@ def numpy(self, arr):
else:
return numpy.asarray(arr)

is_complex = staticmethod(torch.is_complex)

# Torch Cuda Backend
if TORCH_CUDA_AVAILABLE:

Expand Down
18 changes: 8 additions & 10 deletions fdtd/objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@

## Object
class Object:
""" An object to place in the grid """
"""An object to place in the grid"""

def __init__(self, permittivity: Tensorlike, name: str = None):
"""
Expand All @@ -48,7 +48,7 @@ def _register_grid(
self.grid = grid
self.grid.objects.append(self)

if self.permittivity.dtype is bd.complex().dtype:
if bd.is_complex(self.permittivity):
self.grid.promote_dtypes_to_complex()

if self.name is not None:
Expand All @@ -70,7 +70,8 @@ def _register_grid(
if bd.is_array(self.permittivity) and len(self.permittivity.shape) == 3:
self.permittivity = self.permittivity[:, :, :, None]
self.inverse_permittivity = (
bd.ones((self.Nx, self.Ny, self.Nz, 3),dtype=self.permittivity.dtype) / self.permittivity
bd.ones((self.Nx, self.Ny, self.Nz, 3), dtype=self.permittivity.dtype)
/ self.permittivity
)

# set the permittivity values of the object at its border to be equal
Expand Down Expand Up @@ -123,9 +124,9 @@ def update_E(self, curl_H):
"""
loc = (self.x, self.y, self.z)

self.grid.E[loc] = self.grid.E[loc] +(
self.grid.E[loc] = self.grid.E[loc] + (
self.grid.courant_number * self.inverse_permittivity * curl_H[loc]
)
)

def update_H(self, curl_E):
"""custom update equations for inside the object
Expand All @@ -134,9 +135,6 @@ def update_H(self, curl_E):
curl_E: the curl of electric field in the grid.
"""
# def promote_dtypes_to_complex(self):
# self.E = self.E.astype(bd.complex)
# self.H = self.H.astype(bd.complex)

def __repr__(self):
return f"{self.__class__.__name__}(name={repr(self.name)})"
Expand All @@ -163,7 +161,7 @@ def _handle_slice(s):


class AbsorbingObject(Object):
""" An absorbing object takes conductivity into account """
"""An absorbing object takes conductivity into account"""

def __init__(
self, permittivity: Tensorlike, conductivity: Tensorlike, name: str = None
Expand Down Expand Up @@ -232,7 +230,7 @@ def update_H(self, curl_E):


class AnisotropicObject(Object):
""" An object with anisotropic permittivity tensor """
"""An object with anisotropic permittivity tensor"""

def _register_grid(
self, grid: Grid, x: slice = None, y: slice = None, z: slice = None
Expand Down

0 comments on commit a14fe68

Please sign in to comment.