Skip to content

Commit

Permalink
feat(rules): add MultiplicativeInverse rule
Browse files Browse the repository at this point in the history
 - for handling / in expressions
  • Loading branch information
justindujardin committed Feb 5, 2024
1 parent ce56d37 commit 06f3341
Show file tree
Hide file tree
Showing 4 changed files with 170 additions and 0 deletions.
2 changes: 2 additions & 0 deletions mathy_core/rules/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from .constants_simplify import ConstantsSimplifyRule # noqa
from .distributive_factor_out import DistributiveFactorOutRule # noqa
from .distributive_multiply_across import DistributiveMultiplyRule # noqa
from .multiplicative_inverse import MultiplicativeInverseRule # noqa
from .restate_subtraction import RestateSubtractionRule # noqa
from .variable_multiply import VariableMultiplyRule # noqa

Expand All @@ -14,6 +15,7 @@
"ConstantsSimplifyRule",
"DistributiveFactorOutRule",
"DistributiveMultiplyRule",
"MultiplicativeInverseRule",
"RestateSubtractionRule",
"VariableMultiplyRule",
)
105 changes: 105 additions & 0 deletions mathy_core/rules/multiplicative_inverse.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
from typing import Optional, cast

from ..expressions import (
AddExpression,
ConstantExpression,
DivideExpression,
EqualExpression,
MathExpression,
MultiplyExpression,
NegateExpression,
PowerExpression,
SubtractExpression,
VariableExpression,
)
from ..rule import BaseRule, ExpressionChangeRule

_OP_DIVISION_EXPRESSION = "division-expression"
_OP_DIVISION_VARIABLE = "division-variable"
_OP_DIVISION_COMPLEX_DENOMINATOR = "division-complex-denominator"
_OP_DIVISION_NEGATIVE_DENOMINATOR = "division-negative-denominator"


class MultiplicativeInverseRule(BaseRule):
"""Convert division operations to multiplication by the reciprocal."""

@property
def name(self) -> str:
return "Multiplicative Inverse"

@property
def code(self) -> str:
return "MI"

def get_type(self, node: MathExpression) -> Optional[str]:
"""Determine the configuration of the tree for this transformation.
Support different types of tree configurations based on the division operation:
- DivisionExpression is a division to be restated as multiplication by reciprocal
- DivisionVariable is a division by a variable
- DivisionComplexDenominator is a division by a complex expression
- DivisionNegativeDenominator is a division by a negative term
"""
is_division = isinstance(node, DivideExpression)
if not is_division:
return None

# Division by a variable (e.g., (2 + 3z) / z)
if isinstance(node.right, VariableExpression):
return _OP_DIVISION_VARIABLE

# Division where the denominator is a complex expression (e.g., (x^2 + 4x + 4) / (2x - 2))
if isinstance(node.right, AddExpression) or isinstance(
node.right, SubtractExpression
):
return _OP_DIVISION_COMPLEX_DENOMINATOR

# Division where the denominator is negative (e.g., (2 + 3z) / -z)
if isinstance(node.right, NegateExpression):
return _OP_DIVISION_NEGATIVE_DENOMINATOR

# If none of the above, it's a general division expression
return _OP_DIVISION_EXPRESSION

def can_apply_to(self, node: MathExpression) -> bool:
tree_type = self.get_type(node)
return tree_type is not None

def apply_to(self, node: MathExpression) -> ExpressionChangeRule:
change = super().apply_to(node)
tree_type = self.get_type(node)
assert tree_type is not None, "call can_apply_to before applying a rule"
change.save_parent() # connect result to node.parent

# Handle the division based on the tree type
if tree_type == _OP_DIVISION_EXPRESSION:
result = MultiplyExpression(
node.left.clone(),
DivideExpression(ConstantExpression(1), node.right.clone()),
)

elif tree_type == _OP_DIVISION_VARIABLE:
# For division by a single variable, treat it the same as a general expression
reciprocal = DivideExpression(node.right.clone(), ConstantExpression(-1))
result = MultiplyExpression(node.left.clone(), reciprocal)

elif tree_type == _OP_DIVISION_COMPLEX_DENOMINATOR:
result = MultiplyExpression(
node.left.clone(),
DivideExpression(ConstantExpression(1), node.right.clone()),
)

elif tree_type == _OP_DIVISION_NEGATIVE_DENOMINATOR:
# For division by a negative denominator, negate the numerator and use the positive reciprocal
result = MultiplyExpression(
node.left.clone(),
DivideExpression(ConstantExpression(-1), node.right.get_child().clone()),
)

else:
raise NotImplementedError(
"Unsupported tree configuration for MultiplicativeInverseRule"
)

result.set_changed() # mark this node as changed for visualization
return change.done(result)
21 changes: 21 additions & 0 deletions mathy_core/rules/multiplicative_inverse.test.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
{
"valid": [
{
"input": "(21x^3 - 35x^2) / 7x",
"output": "(21x^3 - 35x^2) * 1 / 7x"
},
{
"input": "(x^2 + 4x + 4) / (2x - 2)",
"output": "(x^2 + 4x + 4) * 1 / (2x - 2)"
},
{
"input": "(2 + 3x) / 2x",
"output": "(2 + 3x) * 1 / 2x"
},
{
"input": "((x + 1) / -(y + 2))",
"output": "(x + 1) * -1 / (y + 2)"
}
],
"invalid": []
}
42 changes: 42 additions & 0 deletions tests/test_rules.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from mathy_core import MathExpression
from mathy_core.parser import ExpressionParser
from mathy_core.rules import (
AssociativeSwapRule,
Expand All @@ -8,6 +9,7 @@
DistributiveMultiplyRule,
RestateSubtractionRule,
VariableMultiplyRule,
MultiplicativeInverseRule,
)
from mathy_core.testing import run_rule_tests

Expand Down Expand Up @@ -54,6 +56,13 @@ def debug(ex):
run_rule_tests("restate_subtraction", RestateSubtractionRule, debug)


def test_rules_multiplicative_inverse():
def debug(ex):
pass

run_rule_tests("multiplicative_inverse", MultiplicativeInverseRule, debug)


def test_rules_variable_multiply():
def debug(ex):
pass
Expand All @@ -80,3 +89,36 @@ def test_rules_rule_can_apply_to():
]
for action in available_actions:
assert type(action.can_apply_to(expression)) == bool


def debug_expressions(one: MathExpression, two: MathExpression):
one_inputs = [f"{e.__class__.__name__}" for e in one.to_list()]
two_inputs = [f"{e.__class__.__name__}" for e in two.to_list()]
print("one: ", one.raw, one_inputs)
print("two: ", two.raw, two_inputs)


def test_rules_rule_restate_subtraction_corner_case_1():
parser = ExpressionParser()
expression = parser.parse("4x - 3y + 3x")

restate = RestateSubtractionRule()
dfo = DistributiveFactorOutRule()
commute = CommutativeSwapRule(preferred=False)

node = restate.find_node(expression)
assert node is not None, "should find node"
assert restate.can_apply_to(node), "should be able to apply"
change = restate.apply_to(node)
assert change.result is not None, "should get change"
assert change.result.get_root().raw == "4x + -3y + 3x"

change = commute.apply_to(change.result.get_root())
assert change.result is not None, "should get change"
node = dfo.find_node(change.result.get_root())
assert node is not None, "should find node"
assert dfo.can_apply_to(node), "should be able to apply"
change = dfo.apply_to(node)
assert change.result is not None, "should get change"
node = change.result.get_root()
assert node.raw == "(4 + 3) * x + -3y"

0 comments on commit 06f3341

Please sign in to comment.