Skip to content

Commit

Permalink
Allow MeshSource to take a 1D array of sources (openmc-dev#2980)
Browse files Browse the repository at this point in the history
  • Loading branch information
paulromano authored and church89 committed Jul 18, 2024
1 parent c01d012 commit bdc4a46
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 32 deletions.
51 changes: 25 additions & 26 deletions openmc/source.py
Original file line number Diff line number Diff line change
Expand Up @@ -399,20 +399,19 @@ class MeshSource(SourceBase):
----------
mesh : openmc.MeshBase
The mesh over which source sites will be generated.
sources : iterable of openmc.SourceBase
Sources for each element in the mesh. If spatial distributions are set
on any of the source objects, they will be ignored during source site
sampling.
sources : sequence of openmc.SourceBase
Sources for each element in the mesh. Sources must be specified as
either a 1-D array in the order of the mesh indices or a
multidimensional array whose shape matches the mesh shape. If spatial
distributions are set on any of the source objects, they will be ignored
during source site sampling.
Attributes
----------
mesh : openmc.MeshBase
The mesh over which source sites will be generated.
sources : numpy.ndarray or iterable of openmc.SourceBase
The set of sources to apply to each element. The shape of this array
must match the shape of the mesh with and exception in the case of
unstructured mesh, which allows for application of 1-D array or
iterable.
sources : numpy.ndarray of openmc.SourceBase
Sources to apply to each element
strength : float
Strength of the source
type : str
Expand All @@ -433,7 +432,7 @@ def mesh(self) -> MeshBase:

@property
def strength(self) -> float:
return sum(s.strength for s in self.sources.flat)
return sum(s.strength for s in self.sources)

@property
def sources(self) -> np.ndarray:
Expand All @@ -450,16 +449,23 @@ def sources(self, s):

s = np.asarray(s)

if isinstance(self.mesh, StructuredMesh) and s.shape != self.mesh.dimension:
raise ValueError('The shape of the source array'
f'({s.shape}) does not match the '
f'dimensions of the structured mesh ({self.mesh.dimension})')
if isinstance(self.mesh, StructuredMesh):
if s.size != self.mesh.num_mesh_cells:
raise ValueError(
f'The length of the source array ({s.size}) does not match '
f'the number of mesh elements ({self.mesh.num_mesh_cells}).')

# If user gave a multidimensional array, flatten in the order
# of the mesh indices
if s.ndim > 1:
s = s.ravel(order='F')

elif isinstance(self.mesh, UnstructuredMesh):
if len(s.shape) > 1:
if s.ndim > 1:
raise ValueError('Sources must be a 1-D array for unstructured mesh')

self._sources = s
for src in self._sources.flat:
for src in self._sources:
if isinstance(src, IndependentSource) and src.space is not None:
warnings.warn('Some sources on the mesh have spatial '
'distributions that will be ignored at runtime.')
Expand All @@ -481,7 +487,7 @@ def set_total_strength(self, strength: float):
"""
current_strength = self.strength if self.strength != 0.0 else 1.0

for s in self.sources.flat:
for s in self.sources:
s.strength *= strength / current_strength

def normalize_source_strengths(self):
Expand All @@ -500,13 +506,8 @@ def populate_xml_element(self, elem: ET.Element):
elem.set("mesh", str(self.mesh.id))

# write in the order of mesh indices
if isinstance(self.mesh, openmc.UnstructuredMesh):
for s in self.sources:
elem.append(s.to_xml_element())
else:
for idx in self.mesh.indices:
idx = tuple(i - 1 for i in idx)
elem.append(self.sources[idx].to_xml_element())
for s in self.sources:
elem.append(s.to_xml_element())

@classmethod
def from_xml_element(cls, elem: ET.Element, meshes) -> openmc.MeshSource:
Expand All @@ -527,11 +528,9 @@ def from_xml_element(cls, elem: ET.Element, meshes) -> openmc.MeshSource:
MeshSource generated from the XML element
"""
mesh_id = int(get_text(elem, 'mesh'))

mesh = meshes[mesh_id]

sources = [SourceBase.from_xml_element(e) for e in elem.iterchildren('source')]
sources = np.asarray(sources).reshape(mesh.dimension, order='F')
return cls(mesh, sources)


Expand Down
9 changes: 3 additions & 6 deletions tests/unit_tests/test_source_mesh.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,12 +276,12 @@ def test_mesh_source_independent(run_in_tmpdir, void_model, mesh_type):
# for each element, set a single-non zero source with particles
# traveling out of the mesh (and geometry) w/o crossing any other
# mesh elements
for i, j, k in mesh.indices:
for flat_index, (i, j, k) in enumerate(mesh.indices):
ijk = (i-1, j-1, k-1)
# zero-out all source strengths and set the strength
# on the element of interest
mesh_source.strength = 0.0
mesh_source.sources[ijk].strength = 1.0
mesh_source.sources[flat_index].strength = 1.0

sp_file = model.run()

Expand Down Expand Up @@ -375,10 +375,7 @@ def test_mesh_source_file(run_in_tmpdir):
mesh.upper_right = (2, 3, 4)
mesh.dimension = (1, 1, 1)

mesh_source_arr = np.asarray([file_source]).reshape(mesh.dimension)
source = openmc.MeshSource(mesh, mesh_source_arr)

model.settings.source = source
model.settings.source = openmc.MeshSource(mesh, [file_source])

model.export_to_model_xml()

Expand Down

0 comments on commit bdc4a46

Please sign in to comment.