diff --git a/luisa_lang/parse.py b/luisa_lang/parse.py index 658cc32..2a6fcdb 100644 --- a/luisa_lang/parse.py +++ b/luisa_lang/parse.py @@ -534,7 +534,24 @@ def collect_args() -> List[hir.Value | hir.Ref]: raise hir.ParsingError(expr, ret.message) return ret - def parse_binop(self, expr: ast.BinOp) -> hir.Value: + # def parse_compare(self, expr: ast.Compare) -> hir.Value | ComptimeValue: + # cmpop_to_str: Dict[type, str] = { + # ast.Eq: "==", + # ast.NotEq: "!=", + # ast.Lt: "<", + # ast.LtE: "<=", + # ast.Gt: ">", + # ast.GtE: ">=" + # } + # if len(expr.ops) != 1: + # raise hir.ParsingError(expr, "only one comparison operator is allowed") + # op = expr.ops[0] + # if type(op) not in cmpop_to_str: + # raise hir.ParsingError(expr, f"unsupported comparison operator {type(op)}") + # op_str = cmpop_to_str[type(op)] + # method_name = BINOP_TO_METHOD_NAMES[type(op)] + + def parse_binop(self, expr: ast.BinOp | ast.Compare) -> hir.Value: binop_to_op_str: Dict[type, str] = { ast.Add: "+", ast.Sub: "-", @@ -556,20 +573,32 @@ def parse_binop(self, expr: ast.BinOp) -> hir.Value: ast.GtE: ">=", } - op_str = binop_to_op_str[type(expr.op)] - lhs = self.parse_expr(expr.left) + op: ast.AST + if isinstance(expr, ast.Compare): + if len(expr.ops) != 1: + raise hir.ParsingError( + expr, "only one comparison operator is allowed") + op = expr.ops[0] + left = expr.left + right = expr.comparators[0] + else: + op = expr.op + left = expr.left + right = expr.right + op_str = binop_to_op_str[type(op)] + lhs = self.parse_expr(left) if isinstance(lhs, ComptimeValue): lhs = self.try_convert_comptime_value(lhs, hir.Span.from_ast(expr)) if not lhs.type: raise hir.ParsingError( - expr.left, f"unable to infer type of left operand of binary operation {op_str}") - rhs = self.parse_expr(expr.right) + left, f"unable to infer type of left operand of binary operation {op_str}") + rhs = self.parse_expr(right) if isinstance(rhs, ComptimeValue): rhs = self.try_convert_comptime_value(rhs, hir.Span.from_ast(expr)) if not rhs.type: raise hir.ParsingError( - expr.right, f"unable to infer type of right operand of binary operation {op_str}") - ops = BINOP_TO_METHOD_NAMES[type(expr.op)] + right, f"unable to infer type of right operand of binary operation {op_str}") + ops = BINOP_TO_METHOD_NAMES[type(op)] def infer_binop(name: str, rname: str) -> hir.Value: assert lhs.type and rhs.type @@ -712,6 +741,30 @@ def check(i: int, val_type: hir.Type) -> None: raise hir.ParsingError( targets[0], f"unsupported type for unpacking: {values.type}") + def parse_unary(self, expr: ast.UnaryOp) -> hir.Value: + op = expr.op + if type(op) not in UNARY_OP_TO_METHOD_NAMES: + raise hir.ParsingError( + expr, f"unsupported unary operator {type(op)}") + op_str = UNARY_OP_TO_METHOD_NAMES[type(op)] + operand = self.parse_expr(expr.operand) + if isinstance(operand, ComptimeValue): + operand = self.try_convert_comptime_value( + operand, hir.Span.from_ast(expr)) + if not operand.type: + raise hir.ParsingError( + expr.operand, f"unable to infer type of operand of unary operation {op_str}") + method_name = UNARY_OP_TO_METHOD_NAMES[type(op)] + if (method := operand.type.method(method_name)) and method: + ret = self.parse_call_impl( + hir.Span.from_ast(expr), method, [operand]) + if isinstance(ret, hir.TemplateMatchingError): + raise hir.ParsingError(expr, ret.message) + return ret + else: + raise hir.ParsingError( + expr, f"operator {type(op)} not defined for type {operand.type}") + def parse_expr(self, expr: ast.expr) -> hir.Value | ComptimeValue: match expr: case ast.Constant(): @@ -723,8 +776,10 @@ def parse_expr(self, expr: ast.expr) -> hir.Value | ComptimeValue: return ret case ast.Subscript() | ast.Attribute(): return self.parse_access(expr) - case ast.BinOp(): + case ast.BinOp() | ast.Compare(): return self.parse_binop(expr) + case ast.UnaryOp(): + return self.parse_unary(expr) case ast.Call(): return self.parse_call(expr) case ast.Tuple(): @@ -970,40 +1025,6 @@ def parse_anno_ty() -> hir.Type: if stmt.value: self.parse_multi_assignment( [stmt.target], [parse_anno_ty], self.parse_expr(stmt.value)) - # value = self.parse_expr(stmt.value) - # if isinstance(value, ComptimeValue): - # var = self.parse_ref( - # stmt.target, new_var_hint='comptime') - # else: - # var = self.parse_ref(stmt.target, new_var_hint='dsl') - # if isinstance(var, ComptimeValue): - # if isinstance(value, ComptimeValue): - # try: - # var.update(value.value) - # except Exception as e: - # raise hir.ParsingError( - # stmt, f"error updating comptime value: {e}") from e - # return - # else: - # raise hir.ParsingError( - # stmt, f"comptime value cannot be assigned with DSL value") - # else: - # if isinstance(value, ComptimeValue): - # value = self.try_convert_comptime_value( - # value, span) - # assert value.type - # anno_ty = parse_anno_ty() - # if not var.type: - # var.type = value.type - # if not var.type.is_concrete(): - # raise hir.ParsingError( - # stmt, "only concrete type can be assigned, please annotate the variable with concrete types") - # if not hir.is_type_compatible_to(value.type, anno_ty): - # raise hir.ParsingError( - # stmt, f"expected {anno_ty}, got {value.type}") - # if not value.type.is_concrete(): - # value.type = var.type - # self.cur_bb().append(hir.Assign(var, value, span)) else: var = self.parse_ref(stmt.target, new_var_hint='dsl') anno_ty = parse_anno_ty()