diff --git a/mathy_core/rules/__init__.py b/mathy_core/rules/__init__.py index c583ad7..da949db 100644 --- a/mathy_core/rules/__init__.py +++ b/mathy_core/rules/__init__.py @@ -4,6 +4,7 @@ from .constants_simplify import ConstantsSimplifyRule # noqa from .distributive_factor_out import DistributiveFactorOutRule # noqa from .distributive_multiply_across import DistributiveMultiplyRule # noqa +from .multiplicative_inverse import MultiplicativeInverseRule # noqa from .restate_subtraction import RestateSubtractionRule # noqa from .variable_multiply import VariableMultiplyRule # noqa @@ -14,6 +15,7 @@ "ConstantsSimplifyRule", "DistributiveFactorOutRule", "DistributiveMultiplyRule", + "MultiplicativeInverseRule", "RestateSubtractionRule", "VariableMultiplyRule", ) diff --git a/mathy_core/rules/multiplicative_inverse.py b/mathy_core/rules/multiplicative_inverse.py new file mode 100644 index 0000000..2bd3186 --- /dev/null +++ b/mathy_core/rules/multiplicative_inverse.py @@ -0,0 +1,105 @@ +from typing import Optional, cast + +from ..expressions import ( + AddExpression, + ConstantExpression, + DivideExpression, + EqualExpression, + MathExpression, + MultiplyExpression, + NegateExpression, + PowerExpression, + SubtractExpression, + VariableExpression, +) +from ..rule import BaseRule, ExpressionChangeRule + +_OP_DIVISION_EXPRESSION = "division-expression" +_OP_DIVISION_VARIABLE = "division-variable" +_OP_DIVISION_COMPLEX_DENOMINATOR = "division-complex-denominator" +_OP_DIVISION_NEGATIVE_DENOMINATOR = "division-negative-denominator" + + +class MultiplicativeInverseRule(BaseRule): + """Convert division operations to multiplication by the reciprocal.""" + + @property + def name(self) -> str: + return "Multiplicative Inverse" + + @property + def code(self) -> str: + return "MI" + + def get_type(self, node: MathExpression) -> Optional[str]: + """Determine the configuration of the tree for this transformation. + + Support different types of tree configurations based on the division operation: + - DivisionExpression is a division to be restated as multiplication by reciprocal + - DivisionVariable is a division by a variable + - DivisionComplexDenominator is a division by a complex expression + - DivisionNegativeDenominator is a division by a negative term + """ + is_division = isinstance(node, DivideExpression) + if not is_division: + return None + + # Division by a variable (e.g., (2 + 3z) / z) + if isinstance(node.right, VariableExpression): + return _OP_DIVISION_VARIABLE + + # Division where the denominator is a complex expression (e.g., (x^2 + 4x + 4) / (2x - 2)) + if isinstance(node.right, AddExpression) or isinstance( + node.right, SubtractExpression + ): + return _OP_DIVISION_COMPLEX_DENOMINATOR + + # Division where the denominator is negative (e.g., (2 + 3z) / -z) + if isinstance(node.right, NegateExpression): + return _OP_DIVISION_NEGATIVE_DENOMINATOR + + # If none of the above, it's a general division expression + return _OP_DIVISION_EXPRESSION + + def can_apply_to(self, node: MathExpression) -> bool: + tree_type = self.get_type(node) + return tree_type is not None + + def apply_to(self, node: MathExpression) -> ExpressionChangeRule: + change = super().apply_to(node) + tree_type = self.get_type(node) + assert tree_type is not None, "call can_apply_to before applying a rule" + change.save_parent() # connect result to node.parent + + # Handle the division based on the tree type + if tree_type == _OP_DIVISION_EXPRESSION: + result = MultiplyExpression( + node.left.clone(), + DivideExpression(ConstantExpression(1), node.right.clone()), + ) + + elif tree_type == _OP_DIVISION_VARIABLE: + # For division by a single variable, treat it the same as a general expression + reciprocal = DivideExpression(node.right.clone(), ConstantExpression(-1)) + result = MultiplyExpression(node.left.clone(), reciprocal) + + elif tree_type == _OP_DIVISION_COMPLEX_DENOMINATOR: + result = MultiplyExpression( + node.left.clone(), + DivideExpression(ConstantExpression(1), node.right.clone()), + ) + + elif tree_type == _OP_DIVISION_NEGATIVE_DENOMINATOR: + # For division by a negative denominator, negate the numerator and use the positive reciprocal + result = MultiplyExpression( + node.left.clone(), + DivideExpression(ConstantExpression(-1), node.right.get_child().clone()), + ) + + else: + raise NotImplementedError( + "Unsupported tree configuration for MultiplicativeInverseRule" + ) + + result.set_changed() # mark this node as changed for visualization + return change.done(result) diff --git a/mathy_core/rules/multiplicative_inverse.test.json b/mathy_core/rules/multiplicative_inverse.test.json new file mode 100644 index 0000000..03fc719 --- /dev/null +++ b/mathy_core/rules/multiplicative_inverse.test.json @@ -0,0 +1,21 @@ +{ + "valid": [ + { + "input": "(21x^3 - 35x^2) / 7x", + "output": "(21x^3 - 35x^2) * 1 / 7x" + }, + { + "input": "(x^2 + 4x + 4) / (2x - 2)", + "output": "(x^2 + 4x + 4) * 1 / (2x - 2)" + }, + { + "input": "(2 + 3x) / 2x", + "output": "(2 + 3x) * 1 / 2x" + }, + { + "input": "((x + 1) / -(y + 2))", + "output": "(x + 1) * -1 / (y + 2)" + } + ], + "invalid": [] +} diff --git a/tests/test_rules.py b/tests/test_rules.py index 8b43d87..011bcb4 100644 --- a/tests/test_rules.py +++ b/tests/test_rules.py @@ -1,3 +1,4 @@ +from mathy_core import MathExpression from mathy_core.parser import ExpressionParser from mathy_core.rules import ( AssociativeSwapRule, @@ -8,6 +9,7 @@ DistributiveMultiplyRule, RestateSubtractionRule, VariableMultiplyRule, + MultiplicativeInverseRule, ) from mathy_core.testing import run_rule_tests @@ -54,6 +56,13 @@ def debug(ex): run_rule_tests("restate_subtraction", RestateSubtractionRule, debug) +def test_rules_multiplicative_inverse(): + def debug(ex): + pass + + run_rule_tests("multiplicative_inverse", MultiplicativeInverseRule, debug) + + def test_rules_variable_multiply(): def debug(ex): pass @@ -80,3 +89,36 @@ def test_rules_rule_can_apply_to(): ] for action in available_actions: assert type(action.can_apply_to(expression)) == bool + + +def debug_expressions(one: MathExpression, two: MathExpression): + one_inputs = [f"{e.__class__.__name__}" for e in one.to_list()] + two_inputs = [f"{e.__class__.__name__}" for e in two.to_list()] + print("one: ", one.raw, one_inputs) + print("two: ", two.raw, two_inputs) + + +def test_rules_rule_restate_subtraction_corner_case_1(): + parser = ExpressionParser() + expression = parser.parse("4x - 3y + 3x") + + restate = RestateSubtractionRule() + dfo = DistributiveFactorOutRule() + commute = CommutativeSwapRule(preferred=False) + + node = restate.find_node(expression) + assert node is not None, "should find node" + assert restate.can_apply_to(node), "should be able to apply" + change = restate.apply_to(node) + assert change.result is not None, "should get change" + assert change.result.get_root().raw == "4x + -3y + 3x" + + change = commute.apply_to(change.result.get_root()) + assert change.result is not None, "should get change" + node = dfo.find_node(change.result.get_root()) + assert node is not None, "should find node" + assert dfo.can_apply_to(node), "should be able to apply" + change = dfo.apply_to(node) + assert change.result is not None, "should get change" + node = change.result.get_root() + assert node.raw == "(4 + 3) * x + -3y"