diff --git a/teg/ir/instr/binary.py b/teg/ir/instr/binary.py index 168b555..637700f 100644 --- a/teg/ir/instr/binary.py +++ b/teg/ir/instr/binary.py @@ -36,6 +36,11 @@ def __init__(self, output, input1, input2): super(IR_CompareLT, self).__init__(output=output, input1=input1, input2=input2) -class IR_CompareGT(IR_Binary): +class IR_CompareLTE(IR_Binary): + def __init__(self, output, input1, input2): + super(IR_CompareLTE, self).__init__(output=output, input1=input1, input2=input2) + + +class IR_CompareGT(IR_Binary): # NOTE: This class is never instantiated? def __init__(self, output, input1, input2): super(IR_CompareGT, self).__init__(output=output, input1=input1, input2=input2) diff --git a/teg/ir/passes/to_c.py b/teg/ir/passes/to_c.py index b8d69a7..4f42a4e 100644 --- a/teg/ir/passes/to_c.py +++ b/teg/ir/passes/to_c.py @@ -9,6 +9,7 @@ IR_Binary, IR_CompareLT, + IR_CompareLTE, IR_CompareGT, IR_LAnd, IR_LOr, @@ -301,6 +302,12 @@ def __tegpass_c__(self, name_ctx, **kwargs): return self.__tegpass_c_binary__(name_ctx, op='<', **kwargs) +@overloads(IR_CompareLTE) +class CPass_CompareLTE: + def __tegpass_c__(self, name_ctx, **kwargs): + return self.__tegpass_c_binary__(name_ctx, op='<=', **kwargs) + + @overloads(IR_Add) class CPass_Add: def __tegpass_c__(self, name_ctx, **kwargs): diff --git a/teg/ir/passes/typing.py b/teg/ir/passes/typing.py index 75aade0..2f94806 100644 --- a/teg/ir/passes/typing.py +++ b/teg/ir/passes/typing.py @@ -6,6 +6,7 @@ IR_Binary, IR_IfElse, IR_CompareLT, + IR_CompareLTE, IR_CompareGT, IR_Variable, IR_Literal, @@ -183,6 +184,20 @@ def __tegpass_typing__(self): self.output.set_irtype(IR_Type(ctype=Types.BOOL, size=left_type.size)) +@overloads(IR_CompareLTE) +class TypingPass_CompareLTE: + def __tegpass_typing__(self): + left_symbol, right_symbol = self.inputs + left_type = left_symbol.irtype() + right_type = right_symbol.irtype() + + assert left_type.ctype == right_type.ctype, 'Binary operands have incompatible types' + assert ((left_type.size == right_type.size) or + (left_type.size == 1) or (right_type.size == 1)), 'Binary operands have incompatible sizes' + + self.output.set_irtype(IR_Type(ctype=Types.BOOL, size=left_type.size)) + + @overloads(IR_UnaryMath) class TypingPass_UnaryMath: def __tegpass_typing__(self): diff --git a/teg/lang/operator_overloads.py b/teg/lang/operator_overloads.py index e3e715f..8560f64 100644 --- a/teg/lang/operator_overloads.py +++ b/teg/lang/operator_overloads.py @@ -391,10 +391,10 @@ def __eq__(self, other): class BoolOverloads: def __str__(self): - return f'{self.left_expr} < {self.right_expr}' + return f'{self.left_expr} <{"=" if self.allow_eq else ""} {self.right_expr}' def __repr__(self): - return f'Bool({repr(self.left_expr)}, {repr(self.right_expr)})' + return f'Bool({repr(self.left_expr)}, {repr(self.right_expr)}, {repr(self.allow_eq)})' @overloads(And) diff --git a/teg/passes/compile.py b/teg/passes/compile.py index 5a7d7bd..9d1c7ef 100644 --- a/teg/passes/compile.py +++ b/teg/passes/compile.py @@ -34,6 +34,7 @@ IR_Integrate, IR_Pack, IR_CompareLT, + IR_CompareLTE, IR_LAnd, IR_LOr, IR_Assign @@ -226,7 +227,8 @@ def _to_ir(expr: ITeg, symbols: Dict[str, IR_Symbol]) -> (List[IR_Instruction], elif isinstance(expr, Or): code = [*ir_left, *ir_right, IR_LOr(out_var, left_var, right_var)] elif isinstance(expr, Bool): - code = [*ir_left, *ir_right, IR_CompareLT(out_var, left_var, right_var)] + ir_class = IR_CompareLTE if expr.allow_eq else IR_CompareLT + code = [*ir_left, *ir_right, ir_class(out_var, left_var, right_var)] return code, out_var, {**left_symbols, **right_symbols} diff --git a/teg/passes/simplify.py b/teg/passes/simplify.py index 542459d..626903b 100644 --- a/teg/passes/simplify.py +++ b/teg/passes/simplify.py @@ -263,8 +263,10 @@ def simplify(expr: ITeg) -> ITeg: elif isinstance(expr, Bool): left_expr, right_expr = simplify(expr.left_expr), simplify(expr.right_expr) if isinstance(left_expr, Const) and isinstance(right_expr, Const): - return false if evaluate(Bool(left_expr, right_expr)) == 0.0 else true + return false if evaluate(Bool(left_expr, right_expr, allow_eq=expr.allow_eq)) == 0.0 else true return Bool(left_expr, right_expr) + # NOTE: having the allow_eq flag be carried through appears to break evals of certain complicated expressions + # return Bool(left_expr, right_expr, allow_eq=expr.allow_eq) elif isinstance(expr, And): left_expr, right_expr = simplify(expr.left_expr), simplify(expr.right_expr) @@ -275,7 +277,7 @@ def simplify(expr: ITeg) -> ITeg: if left_expr == false or right_expr == false: return false if isinstance(left_expr, Const) and isinstance(right_expr, Const): - return Const(evaluate(And(simple1, simple2))) + return Const(evaluate(And(simple1, simple2))) # TODO: These variables are never defined; is this tested? return And(left_expr, right_expr) elif isinstance(expr, Or): @@ -287,7 +289,7 @@ def simplify(expr: ITeg) -> ITeg: if left_expr == true or right_expr == true: return true if isinstance(left_expr, Const) and isinstance(right_expr, Const): - return Const(evaluate(Or(simple1, simple2))) + return Const(evaluate(Or(simple1, simple2))) # TODO: These variables are never defined; is this tested? return Or(left_expr, right_expr) else: