diff --git a/include/matxscript/ir/printer/ir_docsifier.h b/include/matxscript/ir/printer/ir_docsifier.h index cdfac31d..515dc67e 100644 --- a/include/matxscript/ir/printer/ir_docsifier.h +++ b/include/matxscript/ir/printer/ir_docsifier.h @@ -333,6 +333,12 @@ inline void FrameNode::AddDispatchToken(const IRDocsifier& d, const StringRef& t this->AddExitCallback([doc = d.get()]() { doc->dispatch_tokens.pop_back(); }); } +/*! \brief Creates the matx common prefix, which is by default `T` */ +inline ExprDoc Dialect(const IRDocsifier& d, const StringRef& attr) { + d->ir_usage.insert("matx"); + return IdDoc(d->cfg->dialect_prefix)->Attr(attr); +} + } // namespace printer } // namespace ir } // namespace matxscript diff --git a/include/matxscript/ir/printer/text_printer.h b/include/matxscript/ir/printer/text_printer.h index 5f3b544b..3a024cbe 100644 --- a/include/matxscript/ir/printer/text_printer.h +++ b/include/matxscript/ir/printer/text_printer.h @@ -45,10 +45,8 @@ class PrinterConfigNode : public Object { public: /*! \brief A stack that tracks the names of the binding hierarchy */ Array binding_names = {}; - /*! \brief The prefix of IR nodes */ - StringRef ir_prefix = "I"; - /*! \brief The prefix of TIR nodes */ - StringRef tir_prefix = "T"; + /*! \brief The prefix of module */ + StringRef dialect_prefix = "matx"; /*! \brief Number of spaces used for indentation*/ int indent_spaces = 4; /*! \brief Whether to print line numbers */ @@ -66,7 +64,7 @@ class PrinterConfigNode : public Object { void VisitAttrs(AttrVisitor* v) { v->Visit("binding_names", &binding_names); - v->Visit("ir_prefix", &ir_prefix); + v->Visit("dialect_prefix", &dialect_prefix); v->Visit("indent_spaces", &indent_spaces); v->Visit("print_line_numbers", &print_line_numbers); v->Visit("num_context_lines", &num_context_lines); diff --git a/src/ir/prim_expr.cc b/src/ir/prim_expr.cc index d1b93873..a319f903 100644 --- a/src/ir/prim_expr.cc +++ b/src/ir/prim_expr.cc @@ -101,6 +101,11 @@ MATXSCRIPT_STATIC_IR_FUNCTOR(ReprPrinter, vtable) } }); +MATXSCRIPT_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch("", [](FloatImm s, ObjectPath p, IRDocsifier d) -> Doc { + return LiteralDoc::Float(s->value, p->Attr("value")); + }); + // PrimCast PrimCast::PrimCast(DataType t, PrimExpr value, Span span) { MXCHECK(value.defined()); @@ -128,6 +133,13 @@ MATXSCRIPT_STATIC_IR_FUNCTOR(ReprPrinter, vtable) p->stream << ')'; }); +MATXSCRIPT_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch("", [](PrimCast s, ObjectPath p, IRDocsifier d) -> Doc { + ExprDoc dtype = LiteralDoc::DataType(s->dtype, p->Attr("dtype")); + ExprDoc value = d->AsDoc(s->value, p->Attr("value")); + return Dialect(d, "PrimCast")->Call({dtype, value}); + }); + // HLOCastPrim HLOCastPrim::HLOCastPrim(DataType t, BaseExpr value, Span span) { MXCHECK(value.defined()); @@ -154,6 +166,13 @@ MATXSCRIPT_STATIC_IR_FUNCTOR(ReprPrinter, vtable) p->stream << ')'; }); +MATXSCRIPT_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch("", [](HLOCastPrim s, ObjectPath p, IRDocsifier d) -> Doc { + ExprDoc dtype = LiteralDoc::DataType(s->dtype, p->Attr("dtype")); + ExprDoc value = d->AsDoc(s->value, p->Attr("value")); + return Dialect(d, "HLOCastPrim")->Call({dtype, value}); + }); + #define MATXSCRIPT_DEFINE_BINOP_CONSTRUCTOR(Name) \ Name::Name(PrimExpr a, PrimExpr b, Span span) { \ using T = Name::ContainerType; \ @@ -184,6 +203,28 @@ MATXSCRIPT_STATIC_IR_FUNCTOR(ReprPrinter, vtable) data_ = std::move(node); \ } +#define MATXSCRIPT_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR( \ + NodeType, NodeObj, NodeFunc, OpString, OpKind) \ + MATXSCRIPT_STATIC_IR_FUNCTOR(IRDocsifier, vtable) \ + .set_dispatch("", [](ir::NodeType node, ObjectPath p, IRDocsifier d) -> Doc { \ + ExprDoc a = d->AsDoc(node->a, p->Attr("a")); \ + ExprDoc b = d->AsDoc(node->b, p->Attr("b")); \ + PrimExpr ret = matxscript::ir::NodeFunc(node->a, node->b); \ + if (!ret->IsInstance() && ret->IsInstance() && \ + ret->IsInstance()) { \ + return Dialect(d, OpString)->Call({a, b}); \ + } \ + return OperationDoc(OperationDocNode::Kind::OpKind, {a, b}); \ + }); + +#define MATXSCRIPT_SCRIPT_PRINTER_DEF_BINARY(NodeType, OpString) \ + MATXSCRIPT_STATIC_IR_FUNCTOR(IRDocsifier, vtable) \ + .set_dispatch("", [](ir::NodeType node, ObjectPath p, IRDocsifier d) -> Doc { \ + ExprDoc a = d->AsDoc(node->a, p->Attr("a")); \ + ExprDoc b = d->AsDoc(node->b, p->Attr("b")); \ + return Dialect(d, OpString)->Call({a, b}); \ + }); + // PrimAdd MATXSCRIPT_DEFINE_BINOP_CONSTRUCTOR(PrimAdd); @@ -203,6 +244,8 @@ MATXSCRIPT_STATIC_IR_FUNCTOR(ReprPrinter, vtable) p->stream << ')'; }); +MATXSCRIPT_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(PrimAdd, PrimAddNode, add, "Add", kAdd); + // PrimSub MATXSCRIPT_DEFINE_BINOP_CONSTRUCTOR(PrimSub); @@ -222,6 +265,8 @@ MATXSCRIPT_STATIC_IR_FUNCTOR(ReprPrinter, vtable) p->stream << ')'; }); +MATXSCRIPT_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(PrimSub, PrimSubNode, sub, "Sub", kSub); + // PrimMul MATXSCRIPT_DEFINE_BINOP_CONSTRUCTOR(PrimMul); @@ -241,6 +286,8 @@ MATXSCRIPT_STATIC_IR_FUNCTOR(ReprPrinter, vtable) p->stream << ')'; }); +MATXSCRIPT_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(PrimMul, PrimMulNode, mul, "Mul", kMult); + // PrimDiv MATXSCRIPT_DEFINE_BINOP_CONSTRUCTOR(PrimDiv); @@ -260,6 +307,21 @@ MATXSCRIPT_STATIC_IR_FUNCTOR(ReprPrinter, vtable) p->stream << ')'; }); +MATXSCRIPT_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch("", [](PrimDiv node, ObjectPath p, IRDocsifier d) -> Doc { + ExprDoc a = d->AsDoc(node->a, p->Attr("a")); + ExprDoc b = d->AsDoc(node->b, p->Attr("b")); + PrimExpr ret = matxscript::ir::div(node->a, node->b); + if (!ret->IsInstance()) { + return Dialect(d, "PrimDiv")->Call({a, b}); + } + if ((node->a->dtype.is_int() || node->a->dtype.is_uint()) && + (node->b->dtype.is_int() || node->b->dtype.is_uint())) { + return Dialect(d, "PrimDiv")->Call({a, b}); + } + return OperationDoc(OperationDocNode::Kind::kDiv, {a, b}); + }); + // PrimMod MATXSCRIPT_DEFINE_BINOP_CONSTRUCTOR(PrimMod); @@ -279,6 +341,8 @@ MATXSCRIPT_STATIC_IR_FUNCTOR(ReprPrinter, vtable) p->stream << ')'; }); +MATXSCRIPT_SCRIPT_PRINTER_DEF_BINARY(PrimMod, "truncmod"); + // PrimFloorDiv MATXSCRIPT_DEFINE_BINOP_CONSTRUCTOR(PrimFloorDiv); @@ -294,6 +358,9 @@ MATXSCRIPT_STATIC_IR_FUNCTOR(ReprPrinter, vtable) p->stream << "floordiv(" << op->a << ", " << op->b << ")"; }); +MATXSCRIPT_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR( + PrimFloorDiv, PrimFloorDivNode, floordiv, "FloorDiv", kFloorDiv); + // PrimFloorMod MATXSCRIPT_DEFINE_BINOP_CONSTRUCTOR(PrimFloorMod); @@ -309,6 +376,9 @@ MATXSCRIPT_STATIC_IR_FUNCTOR(ReprPrinter, vtable) p->stream << "floormod(" << op->a << ", " << op->b << ")"; }); +MATXSCRIPT_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR( + PrimFloorMod, PrimFloorModNode, floormod, "FloorMod", kMod); + // PrimMin MATXSCRIPT_DEFINE_BINOP_CONSTRUCTOR(PrimMin); @@ -328,6 +398,8 @@ MATXSCRIPT_STATIC_IR_FUNCTOR(ReprPrinter, vtable) p->stream << ")"; }); +MATXSCRIPT_SCRIPT_PRINTER_DEF_BINARY(PrimMin, "min"); + // PrimMax MATXSCRIPT_DEFINE_BINOP_CONSTRUCTOR(PrimMax); @@ -347,6 +419,8 @@ MATXSCRIPT_STATIC_IR_FUNCTOR(ReprPrinter, vtable) p->stream << ")"; }); +MATXSCRIPT_SCRIPT_PRINTER_DEF_BINARY(PrimMax, "max"); + // PrimEQ MATXSCRIPT_DEFINE_CMPOP_CONSTRUCTOR(PrimEQ); @@ -366,6 +440,8 @@ MATXSCRIPT_STATIC_IR_FUNCTOR(ReprPrinter, vtable) p->stream << ')'; }); +MATXSCRIPT_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(PrimEQ, PrimEQNode, equal, "EQ", kEq); + // PrimNE MATXSCRIPT_DEFINE_CMPOP_CONSTRUCTOR(PrimNE); @@ -385,6 +461,8 @@ MATXSCRIPT_STATIC_IR_FUNCTOR(ReprPrinter, vtable) p->stream << ')'; }); +MATXSCRIPT_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(PrimNE, PrimNENode, not_equal, "NE", kNotEq); + // PrimLT MATXSCRIPT_DEFINE_CMPOP_CONSTRUCTOR(PrimLT); @@ -404,6 +482,8 @@ MATXSCRIPT_STATIC_IR_FUNCTOR(ReprPrinter, vtable) p->stream << ')'; }); +MATXSCRIPT_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(PrimLT, PrimLTNode, less_than, "LT", kLt); + // PrimLE MATXSCRIPT_DEFINE_CMPOP_CONSTRUCTOR(PrimLE); @@ -423,6 +503,8 @@ MATXSCRIPT_STATIC_IR_FUNCTOR(ReprPrinter, vtable) p->stream << ')'; }); +MATXSCRIPT_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(PrimLE, PrimLENode, less_or_equal, "LE", kLtE); + // PrimGT MATXSCRIPT_DEFINE_CMPOP_CONSTRUCTOR(PrimGT); @@ -442,6 +524,8 @@ MATXSCRIPT_STATIC_IR_FUNCTOR(ReprPrinter, vtable) p->stream << ')'; }); +MATXSCRIPT_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(PrimGT, PrimGTNode, greater_than, "GT", kGt); + // PrimGE MATXSCRIPT_DEFINE_CMPOP_CONSTRUCTOR(PrimGE); @@ -461,6 +545,8 @@ MATXSCRIPT_STATIC_IR_FUNCTOR(ReprPrinter, vtable) p->stream << ')'; }); +MATXSCRIPT_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(PrimGE, PrimGENode, greater_or_equal, "GE", kGtE); + // PrimAnd PrimAnd::PrimAnd(PrimExpr a, PrimExpr b, Span span) { MXCHECK(a.defined()) << "ValueError: a is undefined"; @@ -493,6 +579,8 @@ MATXSCRIPT_STATIC_IR_FUNCTOR(ReprPrinter, vtable) p->stream << ')'; }); +MATXSCRIPT_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(PrimAnd, PrimAndNode, logic_and, "And", kAnd); + // PrimOr PrimOr::PrimOr(PrimExpr a, PrimExpr b, Span span) { MXCHECK(a.defined()) << "ValueError: a is undefined"; @@ -525,6 +613,8 @@ MATXSCRIPT_STATIC_IR_FUNCTOR(ReprPrinter, vtable) p->stream << ')'; }); +MATXSCRIPT_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(PrimOr, PrimOrNode, logic_or, "Or", kOr); + // PrimNot PrimNot::PrimNot(PrimExpr a, Span span) { MXCHECK(a.defined()) << "ValueError: a is undefined"; @@ -590,6 +680,17 @@ MATXSCRIPT_STATIC_IR_FUNCTOR(ReprPrinter, vtable) p->stream << ")"; }); +MATXSCRIPT_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch( + "", [](ir::PrimSelect select, ObjectPath p, IRDocsifier d) -> Doc { + return Dialect(d, "PrimSelect") + ->Call({ + d->AsDoc(select->condition, p->Attr("condition")), + d->AsDoc(select->true_value, p->Attr("true_value")), + d->AsDoc(select->false_value, p->Attr("false_value")), + }); + }); + // Let PrimLet::PrimLet(PrimVar var, PrimExpr value, PrimExpr body, Span span) { MXCHECK(value.defined()); @@ -607,8 +708,8 @@ PrimLet::PrimLet(PrimVar var, PrimExpr value, PrimExpr body, Span span) { data_ = std::move(node); } -MATXSCRIPT_REGISTER_GLOBAL("ir.Let").set_body_typed( - [](PrimVar var, PrimExpr value, PrimExpr body, Span span) { +MATXSCRIPT_REGISTER_GLOBAL("ir.PrimLet") + .set_body_typed([](PrimVar var, PrimExpr value, PrimExpr body, Span span) { return PrimLet(var, value, body, span); }); @@ -624,6 +725,16 @@ MATXSCRIPT_STATIC_IR_FUNCTOR(ReprPrinter, vtable) p->stream << ")"; }); +MATXSCRIPT_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch("", [](ir::PrimLet let, ObjectPath p, IRDocsifier d) -> Doc { + DictDoc where({d->AsDoc(let->var, p->Attr("var"))}, + {d->AsDoc(let->value, p->Attr("value"))}); + return Dialect(d, "PrimLet") + ->Call({d->AsDoc(let->body, p->Attr("body"))}, // + {"where"}, + {where}); + }); + // Call PrimCall::PrimCall(DataType dtype, HLOExpr op, Array args, Span span) { for (size_t i = 0; i < args.size(); ++i) { @@ -671,6 +782,27 @@ MATXSCRIPT_STATIC_IR_FUNCTOR(ReprPrinter, vtable) // p->stream << ")"; }); +MATXSCRIPT_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch("", [](ir::PrimCall call, ObjectPath call_p, IRDocsifier d) -> Doc { + ExprDoc prefix{nullptr}; + if (const auto* op = call->op.as()) { + // TODO: fix prim op name + StringRef name = op->name; + prefix = Dialect(d, name); + } else if (const auto* gv = call->op.as()) { + prefix = LiteralDoc::Str(gv->name_hint, call_p->Attr("op")); + } else { + MXLOG(FATAL) << "call: " << call; + } + Array args; + int n_args = call->args.size(); + args.reserve(n_args + 1); + for (int i = 0; i < n_args; ++i) { + args.push_back(d->AsDoc(call->args[i], call_p->Attr("args")->ArrayIndex(i))); + } + return prefix->Call(args); + }); + MATXSCRIPT_REGISTER_GLOBAL("runtime.GetIntImm").set_body_typed([](IntImm i) { return i->value; }); MATXSCRIPT_REGISTER_GLOBAL("runtime.GetFloatImm").set_body_typed([](FloatImm f) { diff --git a/src/ir/printer/text_printer.cc b/src/ir/printer/text_printer.cc index 68d76394..0cd8ad01 100644 --- a/src/ir/printer/text_printer.cc +++ b/src/ir/printer/text_printer.cc @@ -99,11 +99,8 @@ PrinterConfig::PrinterConfig(Map config_dict) { if (auto v = config_dict.Get("name")) { n->binding_names.push_back(Downcast(v)); } - if (auto v = config_dict.Get("ir_prefix")) { - n->ir_prefix = Downcast(v); - } - if (auto v = config_dict.Get("tir_prefix")) { - n->tir_prefix = Downcast(v); + if (auto v = config_dict.Get("dialect_prefix")) { + n->dialect_prefix = Downcast(v); } if (auto v = config_dict.Get("indent_spaces")) { n->indent_spaces = Downcast(v)->value; diff --git a/test/cc/test_ir_new_printer.cc b/test/cc/test_ir_new_printer.cc index d8e64076..ba23b01f 100644 --- a/test/cc/test_ir_new_printer.cc +++ b/test/cc/test_ir_new_printer.cc @@ -21,6 +21,7 @@ #include +#include #include #include #include @@ -30,8 +31,20 @@ namespace matxscript { namespace ir { TEST(IRTextPrinter, PrintAllocaVar) { + PrimExpr a(3); + PrimExpr b(4); + + PrimAdd c(a, b); + PrimMul d(c, a); + + Bool cond(true); + PrimCall if_expr(d.dtype(), builtin::if_then_else(), {cond, d, c}); + + PrimCast cast_expr(runtime::DataType::Int(32), if_expr); + runtime::DataType int_ty = runtime::DataType::Int(64); - AllocaVarStmt alloca_stmt("b", PrimType(int_ty), IntImm(int_ty, 0)); + AllocaVarStmt alloca_stmt("b", PrimType(int_ty), cast_expr); + auto ir_text = printer::IRTextPrinter::Print(alloca_stmt, printer::PrinterConfig()); // b: "int" = 0 std::cout << ir_text << std::endl;