Skip to content

Commit

Permalink
Rework Arithmetic Expressions
Browse files Browse the repository at this point in the history
  • Loading branch information
ilyannn committed Jan 12, 2024
1 parent 2e68ffd commit 3eab363
Show file tree
Hide file tree
Showing 2 changed files with 178 additions and 101 deletions.
125 changes: 24 additions & 101 deletions arithmetic_expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@
from math import factorial


## METHOD 1


class Reader:
def __init__(self, text):
self.text = text
Expand Down Expand Up @@ -103,8 +106,7 @@ def consume_or():
return int(consume_brackets())


def ascending(seq) -> bool:
return all(x <= y for x, y in pairwise(seq))
## METHOD 2


class PostfixConverter:
Expand All @@ -113,23 +115,21 @@ def __init__(
operator_order: list = None,
brackets: str = "()[]{}",
whitespace: str = " ",
unary: str = "",
):
self.whitespace = whitespace
self.open_brackets = {
brackets[i]: brackets[i + 1] for i in range(0, len(brackets), 2)
}
self.close_brackets = {
brackets[i + 1]: brackets[i] for i in range(0, len(brackets), 2)
}

self.open = {brackets[i]: brackets[i + 1] for i in range(0, len(brackets), 2)}
self.close = {brackets[i + 1]: brackets[i] for i in range(0, len(brackets), 2)}

operator_order = operator_order or []
precedence = {}
for p, ops in enumerate(reversed(operator_order)):
for op in ops:
precedence[op] = p
for op in self.open_brackets:
for op in self.open:
precedence[op] = -1
self.precedence = precedence
self.precedence_unary = {u: len(operator_order) for u in unary}

def convert(self, expression) -> list:
stack = []
Expand All @@ -138,26 +138,29 @@ def convert(self, expression) -> list:
if ch in self.whitespace:
continue

if ch in self.open_brackets:
if ch in self.open:
stack.append(ch)
prefix_context = True
continue

if ch in self.close_brackets:
if ch in self.close:
if stack and stack[-1] is None:
stack.pop()
while stack[-1] not in self.open_brackets:
while stack[-1] not in self.open:
yield stack.pop()
if self.open_brackets[stack.pop()] != ch:
if self.open[stack.pop()] != ch:
raise ValueError("Mismatched bracket")
prefix_context = False
continue

if prefix_context:
yield None

p = self.precedence.get(ch)
if p is None:
if prefix_context and ch in self.precedence_unary:
p = self.precedence_unary[ch]
elif ch in self.precedence:
p = self.precedence[ch]
else:
yield ch
prefix_context = False
continue
Expand All @@ -182,7 +185,7 @@ def eval_postfix(seq, ops: dict):
for ch in seq:
op = ops[ch]
arity = sum(
p.default == inspect._empty
p.default == inspect.Parameter.empty
for p in inspect.signature(op).parameters.values()
)
if arity:
Expand All @@ -192,111 +195,31 @@ def eval_postfix(seq, ops: dict):

return stack.pop()
finally:
pass


# if stack:
# raise ValueError("Unexpected values left in stack")
if stack:
raise ValueError("Unexpected values left in stack")


def check_conversions(converter, examples: dict, ops: dict = None):
for example, expected in examples.items():
result = "".join(ch or "_" for ch in converter.convert(example))
assert result == expected, f"Conversion failed for example: {example}"
if ops:
value = eval_postfix(converter.convert(example), ops)
print(example, "=", value)


empty = PostfixConverter()
plus = PostfixConverter(["+"])
boolean = PostfixConverter(["!", "&", "^|"])
arith = PostfixConverter(["!", "*/", "+-"])

DIGITS = {str(n): lambda x, y=n: 10 * x + y for n in range(10)}
DIGITS[None] = lambda: 0

if True:
check_conversions(
empty,
{
# "": "",
"2": "_2",
" 2 ": "_2",
" 235 ": "_235",
},
ops=DIGITS,
)

check_conversions(
plus,
{
"+5": "__5+",
"3+5": "_3_5+",
"2 + 34 + 5": "_2_34+_5+",
},
ops={
**DIGITS,
"+": operator.add,
},
)

ALL_OPS = {
**DIGITS,
"+": operator.add,
"-": operator.sub,
"*": operator.mul,
"/": operator.truediv,
"%": operator.mod,
"**": operator.pow,
"<<": operator.lshift,
">>": operator.rshift,
"&": operator.and_,
"|": operator.or_,
"^": operator.xor,
"~": operator.invert,
"==": operator.eq,
"!=": operator.ne,
"<": operator.lt,
"<=": operator.le,
">": operator.gt,
">=": operator.ge,
"!": lambda x, y: factorial(x) if x else (1 - y),
}

check_conversions(
boolean,
{
# "1|0": "_1_0|",
# "!0": "__0!",
# "1|!0": "_1__0!|",
"1|0|1": "_1_0|_1|",
"1|!0|!1": "_1__0!|__1!|",
"1|!!0": "_1___0!!|",
" 1 & 0 | 1 & 1 ": "_1_0&_1_1&|",
" !1 & 0 | !1 & !!0 ": "__1!_0&__1!___0!!&|",
" 1 & (0 | 1) & 1 ": "_1_0_1|&_1&",
},
ALL_OPS,
)

check_conversions(
arith,
{
"+5": "__5+",
"3+5": "_3_5+",
"3!": "_3_!",
"!0": "__0!",
"3! + !0": "_3_!__0!+",
"2 + 34 + -5": "_2_34+__5-+",
"2 - 3 + 4 - 5 + 6 * 7": "_2_3-_4+_5-_6_7*+",
},
ALL_OPS,
)


def analyze2(expression):
return eval_postfix(arith.convert(expression), ALL_OPS)
arithmetics = PostfixConverter(["!", "*/&|%", "+-"])
return eval_postfix(arithmetics.convert(expression), ALL_OPS)


if __name__ == "__main__":
Expand Down
154 changes: 154 additions & 0 deletions test_arithmetic_expression.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
import unittest
from arithmetic_expression import *


class ReaderTests(unittest.TestCase):
def setUp(self):
self.reader = Reader("example_input")

def test_is_end(self):
for _ in range(13):
self.reader.read()
self.assertTrue(self.reader.is_end())

def test_read_and_peek(self):
for letter in ["e", "x", "a", "m", "p", "l", "e", "_", "i", "n", "p", "u", "t"]:
self.assertEqual(self.reader.peek(), letter)
self.assertEqual(self.reader.read(), letter)

def tearDown(self):
self.reader = None


class LoggerTests(unittest.TestCase):
def setUp(self):
reader = Reader("example_input")
self.logger = Logger(reader=reader)

def test_call(self):
with self.logger("test"):
self.assertIn((0, "test"), self.logger.contexts)

def tearDown(self):
self.logger = None


class Analyze1Tests(unittest.TestCase):
def test_analyze1(self):
self.assertEqual(analyze1("1&!0|!!!1"), 1)
self.assertEqual(analyze1("1&0|!1&1"), 0)


class PostfixConverterTests(unittest.TestCase):
def check_conversions(
self, converter: PostfixConverter, examples: dict, ops: dict = None
):
for example, expected in examples.items():
result = "".join(ch or "_" for ch in converter.convert(example))
self.assertEqual(
result, expected, f"Conversion failed for example: {example}"
)
if ops:
value = eval_postfix(converter.convert(example), ops)
try:
true_value = eval(example)
self.assertEqual(value, true_value)
except:
pass

def test_empty(self):
empty = PostfixConverter()
self.check_conversions(
empty,
{
# "": "", -- not a valid expression
"2": "_2",
" 2 ": "_2",
" 235 ": "_235",
},
ops=DIGITS,
)

def test_plus(self):
plus = PostfixConverter(["+"])
self.check_conversions(
plus,
{
"+5": "__5+",
"3+5": "_3_5+",
"2 + 34 + 5": "_2_34+_5+",
},
ops={
**DIGITS,
"+": operator.add,
},
)

def test_boolean(self):
boolean = PostfixConverter(["!", "&", "^|"])
self.check_conversions(
boolean,
{
"1|0": "_1_0|",
"!0": "__0!",
"1|!0": "_1__0!|",
"1|0|1": "_1_0|_1|",
"1|!0|!1": "_1__0!|__1!|",
"1|!!0": "_1___0!!|",
" 1 & 0 | 1 & 1 ": "_1_0&_1_1&|",
" !1 & 0 | !1 & !!0 ": "__1!_0&__1!___0!!&|",
" 1 & (0 | 1) & 1 ": "_1_0_1|&_1&",
},
ALL_OPS,
)

def test_arithmetics(self):
arith = PostfixConverter(["!", "*/", "+-"])
self.check_conversions(
arith,
{
"+5": "__5+",
"3+5": "_3_5+",
"3!": "_3_!",
"!0": "__0!",
"3! + !0": "_3_!__0!+",
"2 + 34 + -5": "_2_34+__5-+",
"2 - 3 + 4 - 5 + 6 * 7": "_2_3-_4+_5-_6_7*+",
},
ALL_OPS,
)

def test_unary(self):
arith = PostfixConverter(["!", "*/", "+-"], unary="!+-")
self.check_conversions(
arith,
{
"4/-2": "_4__2-/",
"--2": "___2--",
"2 - -1": "_2__1--",
},
ALL_OPS,
)

def test_brackets(self):
arith = PostfixConverter(["!", "*/", "+-"])
self.check_conversions(
arith,
{
"1 * 2 - 3 * 4 / 2 + 5": "_1_2*_3_4*_2/-_5+",
"1 * (2 - 3) * 4 / 2 + 5": "_1_2_3-*_4*_2/_5+",
"({1 * 2} - [(3 * 4) / 2] + 5)": "_1_2*_3_4*_2/-_5+",
"((1 * 2) - 3 * (4 / 2 + 5))": "_1_2*_3_4_2/_5+*-",
},
ALL_OPS,
)


class Analyze2Tests(unittest.TestCase):
def test_analyze2(self):
self.assertEqual(analyze2("2+34+-5"), 31)
self.assertEqual(analyze2("2-3+4-5+6*7"), 40)


if __name__ == "__main__":
unittest.main()

0 comments on commit 3eab363

Please sign in to comment.