Skip to content

Commit

Permalink
Add mesh smoothing
Browse files Browse the repository at this point in the history
  • Loading branch information
FranzBangar committed Jul 10, 2024
1 parent 2aeea3e commit 02aeeef
Show file tree
Hide file tree
Showing 9 changed files with 236 additions and 140 deletions.
16 changes: 13 additions & 3 deletions src/classy_blocks/optimize/grid.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
from typing import List, Type

import numpy as np

from classy_blocks.mesh import Mesh
from classy_blocks.optimize.cell import CellBase, HexCell, QuadCell
from classy_blocks.optimize.clamps.clamp import ClampBase
from classy_blocks.optimize.junction import Junction
Expand Down Expand Up @@ -32,7 +35,7 @@ def __init__(self, points: NPPointListType, addressing: List[IndexType]):

self._bind_cell_neighbours()
self._bind_junction_cells()
self._bind_junction_connections()
self._bind_junction_neighbours()

def _bind_cell_neighbours(self) -> None:
"""Adds neighbours to cells"""
Expand All @@ -46,11 +49,11 @@ def _bind_junction_cells(self) -> None:
for junction in self.junctions:
junction.add_cell(cell)

def _bind_junction_connections(self) -> None:
def _bind_junction_neighbours(self) -> None:
"""Adds connections to junctions"""
for junction_1 in self.junctions:
for junction_2 in self.junctions:
junction_1.add_connection(junction_2)
junction_1.add_neighbour(junction_2)

def get_junction_from_clamp(self, clamp: ClampBase) -> Junction:
for junction in self.junctions:
Expand Down Expand Up @@ -109,5 +112,12 @@ class QuadGrid(GridBase):
class HexGrid(GridBase):
cell_class = HexCell

@classmethod
def from_mesh(cls, mesh: Mesh) -> "HexGrid":
points = np.array([vertex.position for vertex in mesh.vertices])
addresses = [block.indexes for block in mesh.blocks]

return cls(points, addresses)


Grid = HexGrid
7 changes: 5 additions & 2 deletions src/classy_blocks/optimize/junction.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from classy_blocks.optimize.cell import CellBase
from classy_blocks.optimize.clamps.clamp import ClampBase
from classy_blocks.optimize.links import LinkBase
from classy_blocks.types import NPPointListType, NPPointType
from classy_blocks.types import NPPointListType, NPPointType, PointType


class ClampExistsError(Exception):
Expand Down Expand Up @@ -45,7 +45,7 @@ def add_cell(self, cell: CellBase) -> None:
self.cells.add(cell)
return

def add_connection(self, to: "Junction") -> bool:
def add_neighbour(self, to: "Junction") -> bool:
"""Returns True if this Junction is connected to passed one"""
if to == self:
return False
Expand Down Expand Up @@ -82,6 +82,9 @@ def is_boundary(self) -> bool:

return False

def move_to(self, position: PointType) -> None:
self.points[self.index] = position

@property
def quality(self) -> float:
"""Returns average quality of all cells at this junction;
Expand Down
39 changes: 39 additions & 0 deletions src/classy_blocks/optimize/smoother.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
from typing import List

import numpy as np

from classy_blocks.mesh import Mesh
from classy_blocks.optimize.grid import Grid
from classy_blocks.optimize.junction import Junction


class SmootherBase:
def __init__(self, grid: Grid):
self.grid = grid

self.inner: List[Junction] = []
for junction in self.grid.junctions:
if not junction.is_boundary:
self.inner.append(junction)

def smooth_iteration(self) -> None:
for junction in self.inner:
near_points = [j.point for j in junction.neighbours]
junction.move_to(np.average(near_points, axis=0))

def smooth(self, iterations: int = 5) -> None:
for _ in range(iterations):
self.smooth_iteration()


class MeshSmoother(SmootherBase):
def __init__(self, mesh: Mesh):
self.mesh = mesh

super().__init__(Grid.from_mesh(self.mesh))

def smooth(self, iterations: int = 5) -> None:
super().smooth(iterations)

for i, point in enumerate(self.grid.points):
self.mesh.vertices[i].move_to(point)
34 changes: 34 additions & 0 deletions tests/test_optimize/optimize_fixtures.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
import unittest
from typing import get_args

from classy_blocks.construct.operations.box import Box
from classy_blocks.mesh import Mesh
from classy_blocks.modify.find.geometric import GeometricFinder
from classy_blocks.optimize.grid import Grid
from classy_blocks.types import AxisType


class BoxTestsBase(unittest.TestCase):
def setUp(self):
self.mesh = Mesh()

# generate a cube, consisting of 2x2x2 smaller cubes
for x in (-1, 0):
for y in (-1, 0):
for z in (-1, 0):
box = Box([x, y, z], [x + 1, y + 1, z + 1])

for axis in get_args(AxisType):
box.chop(axis, count=10)

self.mesh.add(box)

self.mesh.assemble()

self.finder = GeometricFinder(self.mesh)

def get_vertex(self, position):
return next(iter(self.finder.find_in_sphere(position)))

def get_grid(self, mesh: Mesh) -> Grid:
return Grid.from_mesh(mesh)
File renamed without changes.
101 changes: 101 additions & 0 deletions tests/test_optimize/test_driver.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
import unittest

from classy_blocks.optimize.iteration import IterationDriver
from classy_blocks.util.constants import VBIG


class IterationDriverTests(unittest.TestCase):
def setUp(self):
self.max_iterations = 20
self.tolerance = 0.1

@property
def driver(self) -> IterationDriver:
return IterationDriver(self.max_iterations, self.tolerance)

def test_initial_improvement_empty(self):
self.assertEqual(self.driver.initial_improvement, VBIG)

def test_initial_improvement(self):
driver = self.driver

driver.begin_iteration(1000)
driver.end_iteration(900)

self.assertEqual(driver.initial_improvement, 100)

def test_end_last_improvement_single(self):
driver = self.driver

driver.begin_iteration(1000)
driver.end_iteration(900)

self.assertEqual(driver.last_improvement, 100)

def test_initial_improvement_multi(self):
driver = self.driver

driver.begin_iteration(1000)
driver.end_iteration(900)

driver.begin_iteration(900)
driver.end_iteration(899)

self.assertEqual(driver.initial_improvement, 100)

def test_last_improvement_multi(self):
driver = self.driver

driver.begin_iteration(1000)
driver.end_iteration(900)

driver.begin_iteration(900)
driver.end_iteration(899)

self.assertEqual(driver.last_improvement, 1)

def test_converged_empty(self):
self.assertFalse(self.driver.converged)

def test_converged_iter_limit(self):
driver = self.driver

for i in range(1, driver.max_iterations + 2):
driver.begin_iteration(1000 / i)
driver.end_iteration(1100 / i)

self.assertTrue(driver.converged)

def test_converged_first(self):
"""Cannot converge in the first iteration"""
driver = self.driver

driver.begin_iteration(1000)
driver.end_iteration(999)

self.assertFalse(driver.converged)

def test_converged_inadequate(self):
"""No improvement, no convergence"""
driver = self.driver
driver.tolerance = 0

for _ in range(1, driver.max_iterations - 1):
driver.begin_iteration(1000)
driver.end_iteration(1000)

self.assertFalse(driver.converged)

def test_converged_improvement(self):
driver = self.driver

driver.begin_iteration(1000)
driver.end_iteration(900)

driver.begin_iteration(900)
driver.end_iteration(890)

driver.begin_iteration(890)
driver.end_iteration(889)

self.assertTrue(driver.converged)
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import numpy as np
from parameterized import parameterized

from classy_blocks.base.exceptions import UndefinedGradingsError
from classy_blocks.construct.flat.sketches.grid import Grid as GridSketch
from classy_blocks.construct.stack import ExtrudedStack
from classy_blocks.mesh import Mesh
Expand Down Expand Up @@ -40,7 +41,7 @@ def test_junction_internal(self):

try:
mesh.assemble() # will fail because there are no chops
except:
except UndefinedGradingsError:
pass

grid = self.get_grid(mesh)
Expand All @@ -62,5 +63,5 @@ def test_cell_neighbours(self, parent, orient, neighbour):
self.assertEqual(self.grid.cells[parent].neighbours[orient], self.grid.cells[neighbour])

@parameterized.expand(((0, 3), (1, 4), (2, 5), (3, 3)))
def test_connections(self, junction, count):
self.assertEqual(len(self.grid.junctions[junction].connections), count)
def test_neighbours(self, junction, count):
self.assertEqual(len(self.grid.junctions[junction].neighbours), count)
Loading

0 comments on commit 02aeeef

Please sign in to comment.