diff --git a/src/classy_blocks/construct/curves/curve.py b/src/classy_blocks/construct/curves/curve.py index a2e2c179..d0aa6102 100644 --- a/src/classy_blocks/construct/curves/curve.py +++ b/src/classy_blocks/construct/curves/curve.py @@ -7,9 +7,9 @@ from classy_blocks.base.element import ElementBase from classy_blocks.construct.point import Point -from classy_blocks.types import NPPointListType, NPPointType, ParamCurveFuncType, PointListType, PointType +from classy_blocks.types import NPPointListType, NPPointType, ParamCurveFuncType, PointListType, PointType, VectorType from classy_blocks.util import functions as f -from classy_blocks.util.constants import DTYPE +from classy_blocks.util.constants import DTYPE, TOL class CurveBase(ElementBase): @@ -79,6 +79,13 @@ def get_closest_param(self, point: PointType) -> float: i_distance = np.argmin(distances) return params[i_distance] + def get_tangent(self, param: float, delta: float = TOL) -> VectorType: + """Returns a normalized tangent to the curve at given parameter""" + prev_point = self.get_point(param - delta / 2) + next_point = self.get_point(param + delta / 2) + + return f.unit_vector(next_point - prev_point) + class PointCurveBase(CurveBase): """A base object for curves, defined by a list of points""" diff --git a/tests/test_construct/test_curves/test_analytic.py b/tests/test_construct/test_curves/test_analytic.py index 903300ee..26d139bf 100644 --- a/tests/test_construct/test_curves/test_analytic.py +++ b/tests/test_construct/test_curves/test_analytic.py @@ -55,6 +55,16 @@ def test_transform(self): with self.assertRaises(NotImplementedError): self.curve.translate([1, 1, 1]) + @parameterized.expand( + ( + (0, [0, 1, 0]), + (np.pi / 2, [-1, 0, 0]), + (np.pi, [0, -1, 0]), + ) + ) + def test_tangent(self, param, tangent): + np.testing.assert_almost_equal(self.curve.get_tangent(param), tangent) + class LineCurveTests(unittest.TestCase): def setUp(self):