Skip to content

Commit

Permalink
Get rid of Grid.points property
Browse files Browse the repository at this point in the history
np.array() that was called in Cell.points property was slowing
optimization down by 10 times. Putting all positions of all mesh
vertices into a copy of numpy array would improve speed but introduce
confusion and errors and scatter the update() code all over the place.

Instead, a Cell.point_array is being held in memory and get updated from
vertices' positions, thus saving time by not creating new arrays and
lessening confusion by not having a duplicate list of points.
  • Loading branch information
FranzBangar committed Jun 3, 2024
1 parent f0dc174 commit 71e7796
Show file tree
Hide file tree
Showing 5 changed files with 44 additions and 65 deletions.
25 changes: 10 additions & 15 deletions src/classy_blocks/modify/cell.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from classy_blocks.items.vertex import Vertex
from classy_blocks.types import NPPointListType, NPPointType, OrientType
from classy_blocks.util import functions as f
from classy_blocks.util.constants import EDGE_PAIRS, FACE_MAP, VSMALL
from classy_blocks.util.constants import DTYPE, EDGE_PAIRS, FACE_MAP, VSMALL


class NoCommonSidesError(Exception):
Expand All @@ -19,9 +19,12 @@ class Cell:
its quality metrics can then be transcribed directly
from checkMesh."""

def __init__(self, block: Block, mesh_points: NPPointListType):
def __init__(self, block: Block):
self.block = block
self.mesh_points = mesh_points

# The slowest thing in numpy is calling np.array(); keep this array
# in memory and update positions from vertices dynamically when needed
self.point_array = np.zeros((8, 3), dtype=DTYPE)

self.neighbours: Dict[OrientType, Optional[Cell]] = {
"bottom": None,
Expand All @@ -47,12 +50,6 @@ def __init__(self, block: Block, mesh_points: NPPointListType):
self.side_indexes = [item[0] for item in q_map.items()]
self.face_indexes = [item[1] for item in q_map.items()]

self._quality: Optional[float] = None

def invalidate(self) -> None:
"""Returns True if a cached quality/center can be returned"""
self._quality = None

def get_common_vertices(self, candidate: "Cell") -> Set[int]:
"""Returns indexes of common vertices between this and provided cell"""
this_indexes = set(self.vertex_indexes)
Expand Down Expand Up @@ -103,7 +100,10 @@ def vertices(self) -> List[Vertex]:
@property
def points(self) -> NPPointListType:
"""A list of points defining this cell, as a numpy array"""
return np.take(self.mesh_points, self.vertex_indexes, axis=0)
for i, vertex in enumerate(self.block.vertices):
self.point_array[i] = vertex.position

return self.point_array

@property
def center(self) -> NPPointType:
Expand All @@ -121,9 +121,6 @@ def face_centers(self) -> NPPointListType:

@property
def quality(self) -> float:
if self._quality is not None:
return self._quality

quality = 0

center = self.center
Expand Down Expand Up @@ -187,8 +184,6 @@ def q_scale(base, exponent, factor, value):

quality += np.sum(q_scale(3, 2.5, 3, aspect_factor))

self._quality = quality

return quality

@property
Expand Down
29 changes: 2 additions & 27 deletions src/classy_blocks/modify/grid.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,9 @@
from typing import List

import numpy as np

from classy_blocks.mesh import Mesh
from classy_blocks.modify.cell import Cell
from classy_blocks.modify.clamps.clamp import ClampBase
from classy_blocks.modify.junction import Junction
from classy_blocks.util.constants import DTYPE


class NoJunctionError(Exception):
Expand All @@ -19,13 +16,8 @@ class Grid:
def __init__(self, mesh: Mesh):
self.mesh = mesh

# store all mesh points in a numpy array for faster
# calculations; when a vertex position is modified, update the
# array using update()
self.points = np.array([vertex.position for vertex in self.mesh.vertices], dtype=DTYPE)

self.cells = [Cell(block, self.points) for block in self.mesh.blocks]
self.junctions = [Junction(vertex, self.points) for vertex in self.mesh.vertices]
self.cells = [Cell(block) for block in self.mesh.blocks]
self.junctions = [Junction(vertex) for vertex in self.mesh.vertices]

self._bind_junctions()
self._bind_cell_neighbours()
Expand Down Expand Up @@ -56,23 +48,6 @@ def get_junction_from_clamp(self, clamp: ClampBase) -> Junction:

raise NoJunctionError

def update(self, junction: Junction) -> None:
self.points[junction.vertex.index] = junction.vertex.position
for cell in junction.cells:
cell.invalidate()

# also update linked stuff
if junction.clamp is not None:
if junction.clamp.is_linked:
linked_junction = self.get_junction_from_clamp(junction.clamp)
self.points[linked_junction.vertex.index] = linked_junction.vertex.position
for cell in linked_junction.cells:
cell.invalidate()

def clear_cache(self):
for cell in self.cells:
cell._quality = None

def add_clamp(self, clamp: ClampBase) -> None:
for junction in self.junctions:
if junction.vertex == clamp.vertex:
Expand Down
12 changes: 8 additions & 4 deletions src/classy_blocks/modify/junction.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,21 +3,22 @@
from classy_blocks.items.vertex import Vertex
from classy_blocks.modify.cell import Cell
from classy_blocks.modify.clamps.clamp import ClampBase
from classy_blocks.types import NPPointListType


class NoClampError(Exception):
"""Raised when this junction has no clamp defined but slope calculation is requested"""


class ClampExistsError(Exception):
"""Raised when a clamp is added to a junction that already has one defined"""


class Junction:
"""A class that collects Cells/Blocks that
share the same Vertex"""

def __init__(self, vertex: Vertex, points: NPPointListType):
def __init__(self, vertex: Vertex):
self.vertex = vertex
self.points = points
self.index = self.vertex.index

self.cells: Set[Cell] = set()
self.neighbours: Set[Junction] = set()
Expand Down Expand Up @@ -47,6 +48,9 @@ def add_neighbour(self, junction: "Junction") -> bool:
return False

def add_clamp(self, clamp: ClampBase) -> None:
if self.clamp is not None:
raise NoClampError(f"A clamp is already defined on junction {self.vertex.index}")

self.clamp = clamp

@property
Expand Down
25 changes: 15 additions & 10 deletions src/classy_blocks/modify/optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from classy_blocks.modify.clamps.clamp import ClampBase
from classy_blocks.modify.grid import Grid
from classy_blocks.modify.iteration import ClampOptimizationData, IterationDriver
from classy_blocks.modify.junction import Junction
from classy_blocks.util.constants import TOL

MinimizationMethodType = Literal["SLSQP", "L-BFGS-B", "Nelder-Mead", "Powell"]
Expand All @@ -27,19 +28,21 @@ def release_vertex(self, clamp: ClampBase) -> None:
"""Adds a clamp to optimization. Raises an exception if it already exists"""
self.grid.add_clamp(clamp)

def optimize_clamp(self, clamp: ClampBase, method: MinimizationMethodType) -> None:
def optimize_junction(self, junction: Junction, method: MinimizationMethodType) -> None:
"""Move clamp.vertex so that quality at junction is improved;
rollback changes if grid quality decreased after optimization"""
if junction.clamp is None:
raise ValueError(f"No clamp at this junction {junction.vertex.index}")

clamp = junction.clamp
initial_params = copy.copy(clamp.params)
junction = self.grid.get_junction_from_clamp(clamp)

reporter = ClampOptimizationData(clamp.vertex.index, self.grid.quality, junction.quality)
reporter.report_start()

def fquality(params):
# move all vertices according to X
clamp.update_params(params)
self.grid.update(junction)

if clamp.is_linked:
return self.grid.quality
Expand All @@ -53,30 +56,32 @@ def fquality(params):

if reporter.rollback:
clamp.update_params(initial_params)
self.grid.update(junction)

def _get_sensitivity(self, clamp):
"""Returns maximum partial derivative at current params"""
junction = self.grid.get_junction_from_clamp(clamp)

def fquality(clamp, junction, params):
clamp.update_params(params)
self.grid.update(junction)
return junction.quality

sensitivities = np.asarray(
scipy.optimize.approx_fprime(clamp.params, lambda p: fquality(clamp, junction, p), epsilon=10 * TOL)
)
return np.linalg.norm(sensitivities)
# return np.max(np.abs(sensitivities.flatten()))

def optimize_iteration(self, method: MinimizationMethodType) -> None:
self.grid.clear_cache()
junctions = []

for junction in self.grid.junctions:
if junction.clamp is None:
continue
junctions.append(junction)

clamps = sorted(self.grid.clamps, key=lambda c: self._get_sensitivity(c), reverse=True)
junctions.sort(key=lambda j: self._get_sensitivity(j.clamp), reverse=True)

for clamp in clamps:
self.optimize_clamp(clamp, method)
for junction in junctions:
self.optimize_junction(junction, method)

def optimize(
self, max_iterations: int = 20, tolerance: float = 0.1, method: MinimizationMethodType = "SLSQP"
Expand Down
18 changes: 9 additions & 9 deletions tests/test_modify/test_cell.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,40 +25,40 @@ def test_common_vertices(self, index_1, index_2, count):
block_1 = self.mesh.blocks[index_1]
block_2 = self.mesh.blocks[index_2]

cell_1 = Cell(block_1, self.mesh_points)
cell_2 = Cell(block_2, self.mesh_points)
cell_1 = Cell(block_1)
cell_2 = Cell(block_2)

self.assertEqual(len(cell_1.get_common_vertices(cell_2)), count)

@parameterized.expand(((0, 0, 0), (0, 1, 1), (1, 1, 0), (1, 8, 1)))
def test_get_corner(self, block, vertex, corner):
cell = Cell(self.mesh.blocks[block], self.mesh_points)
cell = Cell(self.mesh.blocks[block])

self.assertEqual(cell.get_corner(vertex), corner)

@parameterized.expand(((0, 1, "right"), (1, 0, "left"), (1, 2, "back")))
def test_get_common_side(self, index_1, index_2, orient):
cell_1 = Cell(self.mesh.blocks[index_1], self.mesh_points)
cell_2 = Cell(self.mesh.blocks[index_2], self.mesh_points)
cell_1 = Cell(self.mesh.blocks[index_1])
cell_2 = Cell(self.mesh.blocks[index_2])

self.assertEqual(cell_1.get_common_side(cell_2), orient)

def test_no_common_sides(self):
with self.assertRaises(NoCommonSidesError):
cell_1 = Cell(self.mesh.blocks[0], self.mesh_points)
cell_2 = Cell(self.mesh.blocks[2], self.mesh_points)
cell_1 = Cell(self.mesh.blocks[0])
cell_2 = Cell(self.mesh.blocks[2])

cell_1.get_common_side(cell_2)

def test_quality_good(self):
cell = Cell(self.mesh.blocks[0], self.mesh_points)
cell = Cell(self.mesh.blocks[0])

self.assertLess(cell.quality, 1)

def test_quality_bad(self):
block = self.mesh.blocks[0]
block.vertices[0].move_to([-10, -10, -10])

cell = Cell(block, self.mesh_points)
cell = Cell(block)

self.assertGreater(cell.quality, 100)

0 comments on commit 71e7796

Please sign in to comment.