-
Notifications
You must be signed in to change notification settings - Fork 38
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
2aeea3e
commit 02aeeef
Showing
9 changed files
with
236 additions
and
140 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
from typing import List | ||
|
||
import numpy as np | ||
|
||
from classy_blocks.mesh import Mesh | ||
from classy_blocks.optimize.grid import Grid | ||
from classy_blocks.optimize.junction import Junction | ||
|
||
|
||
class SmootherBase: | ||
def __init__(self, grid: Grid): | ||
self.grid = grid | ||
|
||
self.inner: List[Junction] = [] | ||
for junction in self.grid.junctions: | ||
if not junction.is_boundary: | ||
self.inner.append(junction) | ||
|
||
def smooth_iteration(self) -> None: | ||
for junction in self.inner: | ||
near_points = [j.point for j in junction.neighbours] | ||
junction.move_to(np.average(near_points, axis=0)) | ||
|
||
def smooth(self, iterations: int = 5) -> None: | ||
for _ in range(iterations): | ||
self.smooth_iteration() | ||
|
||
|
||
class MeshSmoother(SmootherBase): | ||
def __init__(self, mesh: Mesh): | ||
self.mesh = mesh | ||
|
||
super().__init__(Grid.from_mesh(self.mesh)) | ||
|
||
def smooth(self, iterations: int = 5) -> None: | ||
super().smooth(iterations) | ||
|
||
for i, point in enumerate(self.grid.points): | ||
self.mesh.vertices[i].move_to(point) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
import unittest | ||
from typing import get_args | ||
|
||
from classy_blocks.construct.operations.box import Box | ||
from classy_blocks.mesh import Mesh | ||
from classy_blocks.modify.find.geometric import GeometricFinder | ||
from classy_blocks.optimize.grid import Grid | ||
from classy_blocks.types import AxisType | ||
|
||
|
||
class BoxTestsBase(unittest.TestCase): | ||
def setUp(self): | ||
self.mesh = Mesh() | ||
|
||
# generate a cube, consisting of 2x2x2 smaller cubes | ||
for x in (-1, 0): | ||
for y in (-1, 0): | ||
for z in (-1, 0): | ||
box = Box([x, y, z], [x + 1, y + 1, z + 1]) | ||
|
||
for axis in get_args(AxisType): | ||
box.chop(axis, count=10) | ||
|
||
self.mesh.add(box) | ||
|
||
self.mesh.assemble() | ||
|
||
self.finder = GeometricFinder(self.mesh) | ||
|
||
def get_vertex(self, position): | ||
return next(iter(self.finder.find_in_sphere(position))) | ||
|
||
def get_grid(self, mesh: Mesh) -> Grid: | ||
return Grid.from_mesh(mesh) |
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,101 @@ | ||
import unittest | ||
|
||
from classy_blocks.optimize.iteration import IterationDriver | ||
from classy_blocks.util.constants import VBIG | ||
|
||
|
||
class IterationDriverTests(unittest.TestCase): | ||
def setUp(self): | ||
self.max_iterations = 20 | ||
self.tolerance = 0.1 | ||
|
||
@property | ||
def driver(self) -> IterationDriver: | ||
return IterationDriver(self.max_iterations, self.tolerance) | ||
|
||
def test_initial_improvement_empty(self): | ||
self.assertEqual(self.driver.initial_improvement, VBIG) | ||
|
||
def test_initial_improvement(self): | ||
driver = self.driver | ||
|
||
driver.begin_iteration(1000) | ||
driver.end_iteration(900) | ||
|
||
self.assertEqual(driver.initial_improvement, 100) | ||
|
||
def test_end_last_improvement_single(self): | ||
driver = self.driver | ||
|
||
driver.begin_iteration(1000) | ||
driver.end_iteration(900) | ||
|
||
self.assertEqual(driver.last_improvement, 100) | ||
|
||
def test_initial_improvement_multi(self): | ||
driver = self.driver | ||
|
||
driver.begin_iteration(1000) | ||
driver.end_iteration(900) | ||
|
||
driver.begin_iteration(900) | ||
driver.end_iteration(899) | ||
|
||
self.assertEqual(driver.initial_improvement, 100) | ||
|
||
def test_last_improvement_multi(self): | ||
driver = self.driver | ||
|
||
driver.begin_iteration(1000) | ||
driver.end_iteration(900) | ||
|
||
driver.begin_iteration(900) | ||
driver.end_iteration(899) | ||
|
||
self.assertEqual(driver.last_improvement, 1) | ||
|
||
def test_converged_empty(self): | ||
self.assertFalse(self.driver.converged) | ||
|
||
def test_converged_iter_limit(self): | ||
driver = self.driver | ||
|
||
for i in range(1, driver.max_iterations + 2): | ||
driver.begin_iteration(1000 / i) | ||
driver.end_iteration(1100 / i) | ||
|
||
self.assertTrue(driver.converged) | ||
|
||
def test_converged_first(self): | ||
"""Cannot converge in the first iteration""" | ||
driver = self.driver | ||
|
||
driver.begin_iteration(1000) | ||
driver.end_iteration(999) | ||
|
||
self.assertFalse(driver.converged) | ||
|
||
def test_converged_inadequate(self): | ||
"""No improvement, no convergence""" | ||
driver = self.driver | ||
driver.tolerance = 0 | ||
|
||
for _ in range(1, driver.max_iterations - 1): | ||
driver.begin_iteration(1000) | ||
driver.end_iteration(1000) | ||
|
||
self.assertFalse(driver.converged) | ||
|
||
def test_converged_improvement(self): | ||
driver = self.driver | ||
|
||
driver.begin_iteration(1000) | ||
driver.end_iteration(900) | ||
|
||
driver.begin_iteration(900) | ||
driver.end_iteration(890) | ||
|
||
driver.begin_iteration(890) | ||
driver.end_iteration(889) | ||
|
||
self.assertTrue(driver.converged) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.