Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[IR] Refactor IR Printer: PrimExpr to Doc #197

Merged
merged 1 commit into from
Mar 17, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions include/matxscript/ir/printer/ir_docsifier.h
Original file line number Diff line number Diff line change
Expand Up @@ -333,6 +333,12 @@ inline void FrameNode::AddDispatchToken(const IRDocsifier& d, const StringRef& t
this->AddExitCallback([doc = d.get()]() { doc->dispatch_tokens.pop_back(); });
}

/*! \brief Creates the matx common prefix, which is by default `T` */
inline ExprDoc Dialect(const IRDocsifier& d, const StringRef& attr) {
d->ir_usage.insert("matx");
return IdDoc(d->cfg->dialect_prefix)->Attr(attr);
}

} // namespace printer
} // namespace ir
} // namespace matxscript
8 changes: 3 additions & 5 deletions include/matxscript/ir/printer/text_printer.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,10 +45,8 @@ class PrinterConfigNode : public Object {
public:
/*! \brief A stack that tracks the names of the binding hierarchy */
Array<StringRef> binding_names = {};
/*! \brief The prefix of IR nodes */
StringRef ir_prefix = "I";
/*! \brief The prefix of TIR nodes */
StringRef tir_prefix = "T";
/*! \brief The prefix of module */
StringRef dialect_prefix = "matx";
/*! \brief Number of spaces used for indentation*/
int indent_spaces = 4;
/*! \brief Whether to print line numbers */
Expand All @@ -66,7 +64,7 @@ class PrinterConfigNode : public Object {

void VisitAttrs(AttrVisitor* v) {
v->Visit("binding_names", &binding_names);
v->Visit("ir_prefix", &ir_prefix);
v->Visit("dialect_prefix", &dialect_prefix);
v->Visit("indent_spaces", &indent_spaces);
v->Visit("print_line_numbers", &print_line_numbers);
v->Visit("num_context_lines", &num_context_lines);
Expand Down
136 changes: 134 additions & 2 deletions src/ir/prim_expr.cc
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,11 @@ MATXSCRIPT_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
}
});

MATXSCRIPT_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
.set_dispatch<FloatImm>("", [](FloatImm s, ObjectPath p, IRDocsifier d) -> Doc {
return LiteralDoc::Float(s->value, p->Attr("value"));
});

// PrimCast
PrimCast::PrimCast(DataType t, PrimExpr value, Span span) {
MXCHECK(value.defined());
Expand Down Expand Up @@ -128,6 +133,13 @@ MATXSCRIPT_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
p->stream << ')';
});

MATXSCRIPT_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
.set_dispatch<PrimCast>("", [](PrimCast s, ObjectPath p, IRDocsifier d) -> Doc {
ExprDoc dtype = LiteralDoc::DataType(s->dtype, p->Attr("dtype"));
ExprDoc value = d->AsDoc<ExprDoc>(s->value, p->Attr("value"));
return Dialect(d, "PrimCast")->Call({dtype, value});
});

// HLOCastPrim
HLOCastPrim::HLOCastPrim(DataType t, BaseExpr value, Span span) {
MXCHECK(value.defined());
Expand All @@ -154,6 +166,13 @@ MATXSCRIPT_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
p->stream << ')';
});

MATXSCRIPT_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
.set_dispatch<HLOCastPrim>("", [](HLOCastPrim s, ObjectPath p, IRDocsifier d) -> Doc {
ExprDoc dtype = LiteralDoc::DataType(s->dtype, p->Attr("dtype"));
ExprDoc value = d->AsDoc<ExprDoc>(s->value, p->Attr("value"));
return Dialect(d, "HLOCastPrim")->Call({dtype, value});
});

#define MATXSCRIPT_DEFINE_BINOP_CONSTRUCTOR(Name) \
Name::Name(PrimExpr a, PrimExpr b, Span span) { \
using T = Name::ContainerType; \
Expand Down Expand Up @@ -184,6 +203,28 @@ MATXSCRIPT_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
data_ = std::move(node); \
}

#define MATXSCRIPT_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR( \
NodeType, NodeObj, NodeFunc, OpString, OpKind) \
MATXSCRIPT_STATIC_IR_FUNCTOR(IRDocsifier, vtable) \
.set_dispatch<ir::NodeType>("", [](ir::NodeType node, ObjectPath p, IRDocsifier d) -> Doc { \
ExprDoc a = d->AsDoc<ExprDoc>(node->a, p->Attr("a")); \
ExprDoc b = d->AsDoc<ExprDoc>(node->b, p->Attr("b")); \
PrimExpr ret = matxscript::ir::NodeFunc(node->a, node->b); \
if (!ret->IsInstance<ir::NodeObj>() && ret->IsInstance<ir::IntImmNode>() && \
ret->IsInstance<ir::FloatImmNode>()) { \
return Dialect(d, OpString)->Call({a, b}); \
} \
return OperationDoc(OperationDocNode::Kind::OpKind, {a, b}); \
});

#define MATXSCRIPT_SCRIPT_PRINTER_DEF_BINARY(NodeType, OpString) \
MATXSCRIPT_STATIC_IR_FUNCTOR(IRDocsifier, vtable) \
.set_dispatch<ir::NodeType>("", [](ir::NodeType node, ObjectPath p, IRDocsifier d) -> Doc { \
ExprDoc a = d->AsDoc<ExprDoc>(node->a, p->Attr("a")); \
ExprDoc b = d->AsDoc<ExprDoc>(node->b, p->Attr("b")); \
return Dialect(d, OpString)->Call({a, b}); \
});

// PrimAdd
MATXSCRIPT_DEFINE_BINOP_CONSTRUCTOR(PrimAdd);

Expand All @@ -203,6 +244,8 @@ MATXSCRIPT_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
p->stream << ')';
});

MATXSCRIPT_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(PrimAdd, PrimAddNode, add, "Add", kAdd);

// PrimSub
MATXSCRIPT_DEFINE_BINOP_CONSTRUCTOR(PrimSub);

Expand All @@ -222,6 +265,8 @@ MATXSCRIPT_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
p->stream << ')';
});

MATXSCRIPT_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(PrimSub, PrimSubNode, sub, "Sub", kSub);

// PrimMul
MATXSCRIPT_DEFINE_BINOP_CONSTRUCTOR(PrimMul);

Expand All @@ -241,6 +286,8 @@ MATXSCRIPT_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
p->stream << ')';
});

MATXSCRIPT_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(PrimMul, PrimMulNode, mul, "Mul", kMult);

// PrimDiv
MATXSCRIPT_DEFINE_BINOP_CONSTRUCTOR(PrimDiv);

Expand All @@ -260,6 +307,21 @@ MATXSCRIPT_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
p->stream << ')';
});

MATXSCRIPT_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
.set_dispatch<PrimDiv>("", [](PrimDiv node, ObjectPath p, IRDocsifier d) -> Doc {
ExprDoc a = d->AsDoc<ExprDoc>(node->a, p->Attr("a"));
ExprDoc b = d->AsDoc<ExprDoc>(node->b, p->Attr("b"));
PrimExpr ret = matxscript::ir::div(node->a, node->b);
if (!ret->IsInstance<PrimDivNode>()) {
return Dialect(d, "PrimDiv")->Call({a, b});
}
if ((node->a->dtype.is_int() || node->a->dtype.is_uint()) &&
(node->b->dtype.is_int() || node->b->dtype.is_uint())) {
return Dialect(d, "PrimDiv")->Call({a, b});
}
return OperationDoc(OperationDocNode::Kind::kDiv, {a, b});
});

// PrimMod
MATXSCRIPT_DEFINE_BINOP_CONSTRUCTOR(PrimMod);

Expand All @@ -279,6 +341,8 @@ MATXSCRIPT_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
p->stream << ')';
});

MATXSCRIPT_SCRIPT_PRINTER_DEF_BINARY(PrimMod, "truncmod");

// PrimFloorDiv
MATXSCRIPT_DEFINE_BINOP_CONSTRUCTOR(PrimFloorDiv);

Expand All @@ -294,6 +358,9 @@ MATXSCRIPT_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
p->stream << "floordiv(" << op->a << ", " << op->b << ")";
});

MATXSCRIPT_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(
PrimFloorDiv, PrimFloorDivNode, floordiv, "FloorDiv", kFloorDiv);

// PrimFloorMod
MATXSCRIPT_DEFINE_BINOP_CONSTRUCTOR(PrimFloorMod);

Expand All @@ -309,6 +376,9 @@ MATXSCRIPT_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
p->stream << "floormod(" << op->a << ", " << op->b << ")";
});

MATXSCRIPT_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(
PrimFloorMod, PrimFloorModNode, floormod, "FloorMod", kMod);

// PrimMin
MATXSCRIPT_DEFINE_BINOP_CONSTRUCTOR(PrimMin);

Expand All @@ -328,6 +398,8 @@ MATXSCRIPT_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
p->stream << ")";
});

MATXSCRIPT_SCRIPT_PRINTER_DEF_BINARY(PrimMin, "min");

// PrimMax
MATXSCRIPT_DEFINE_BINOP_CONSTRUCTOR(PrimMax);

Expand All @@ -347,6 +419,8 @@ MATXSCRIPT_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
p->stream << ")";
});

MATXSCRIPT_SCRIPT_PRINTER_DEF_BINARY(PrimMax, "max");

// PrimEQ
MATXSCRIPT_DEFINE_CMPOP_CONSTRUCTOR(PrimEQ);

Expand All @@ -366,6 +440,8 @@ MATXSCRIPT_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
p->stream << ')';
});

MATXSCRIPT_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(PrimEQ, PrimEQNode, equal, "EQ", kEq);

// PrimNE
MATXSCRIPT_DEFINE_CMPOP_CONSTRUCTOR(PrimNE);

Expand All @@ -385,6 +461,8 @@ MATXSCRIPT_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
p->stream << ')';
});

MATXSCRIPT_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(PrimNE, PrimNENode, not_equal, "NE", kNotEq);

// PrimLT
MATXSCRIPT_DEFINE_CMPOP_CONSTRUCTOR(PrimLT);

Expand All @@ -404,6 +482,8 @@ MATXSCRIPT_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
p->stream << ')';
});

MATXSCRIPT_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(PrimLT, PrimLTNode, less_than, "LT", kLt);

// PrimLE
MATXSCRIPT_DEFINE_CMPOP_CONSTRUCTOR(PrimLE);

Expand All @@ -423,6 +503,8 @@ MATXSCRIPT_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
p->stream << ')';
});

MATXSCRIPT_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(PrimLE, PrimLENode, less_or_equal, "LE", kLtE);

// PrimGT
MATXSCRIPT_DEFINE_CMPOP_CONSTRUCTOR(PrimGT);

Expand All @@ -442,6 +524,8 @@ MATXSCRIPT_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
p->stream << ')';
});

MATXSCRIPT_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(PrimGT, PrimGTNode, greater_than, "GT", kGt);

// PrimGE
MATXSCRIPT_DEFINE_CMPOP_CONSTRUCTOR(PrimGE);

Expand All @@ -461,6 +545,8 @@ MATXSCRIPT_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
p->stream << ')';
});

MATXSCRIPT_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(PrimGE, PrimGENode, greater_or_equal, "GE", kGtE);

// PrimAnd
PrimAnd::PrimAnd(PrimExpr a, PrimExpr b, Span span) {
MXCHECK(a.defined()) << "ValueError: a is undefined";
Expand Down Expand Up @@ -493,6 +579,8 @@ MATXSCRIPT_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
p->stream << ')';
});

MATXSCRIPT_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(PrimAnd, PrimAndNode, logic_and, "And", kAnd);

// PrimOr
PrimOr::PrimOr(PrimExpr a, PrimExpr b, Span span) {
MXCHECK(a.defined()) << "ValueError: a is undefined";
Expand Down Expand Up @@ -525,6 +613,8 @@ MATXSCRIPT_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
p->stream << ')';
});

MATXSCRIPT_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(PrimOr, PrimOrNode, logic_or, "Or", kOr);

// PrimNot
PrimNot::PrimNot(PrimExpr a, Span span) {
MXCHECK(a.defined()) << "ValueError: a is undefined";
Expand Down Expand Up @@ -590,6 +680,17 @@ MATXSCRIPT_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
p->stream << ")";
});

MATXSCRIPT_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
.set_dispatch<ir::PrimSelect>(
"", [](ir::PrimSelect select, ObjectPath p, IRDocsifier d) -> Doc {
return Dialect(d, "PrimSelect")
->Call({
d->AsDoc<ExprDoc>(select->condition, p->Attr("condition")),
d->AsDoc<ExprDoc>(select->true_value, p->Attr("true_value")),
d->AsDoc<ExprDoc>(select->false_value, p->Attr("false_value")),
});
});

// Let
PrimLet::PrimLet(PrimVar var, PrimExpr value, PrimExpr body, Span span) {
MXCHECK(value.defined());
Expand All @@ -607,8 +708,8 @@ PrimLet::PrimLet(PrimVar var, PrimExpr value, PrimExpr body, Span span) {
data_ = std::move(node);
}

MATXSCRIPT_REGISTER_GLOBAL("ir.Let").set_body_typed(
[](PrimVar var, PrimExpr value, PrimExpr body, Span span) {
MATXSCRIPT_REGISTER_GLOBAL("ir.PrimLet")
.set_body_typed([](PrimVar var, PrimExpr value, PrimExpr body, Span span) {
return PrimLet(var, value, body, span);
});

Expand All @@ -624,6 +725,16 @@ MATXSCRIPT_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
p->stream << ")";
});

MATXSCRIPT_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
.set_dispatch<ir::PrimLet>("", [](ir::PrimLet let, ObjectPath p, IRDocsifier d) -> Doc {
DictDoc where({d->AsDoc<ExprDoc>(let->var, p->Attr("var"))},
{d->AsDoc<ExprDoc>(let->value, p->Attr("value"))});
return Dialect(d, "PrimLet")
->Call({d->AsDoc<ExprDoc>(let->body, p->Attr("body"))}, //
{"where"},
{where});
});

// Call
PrimCall::PrimCall(DataType dtype, HLOExpr op, Array<PrimExpr> args, Span span) {
for (size_t i = 0; i < args.size(); ++i) {
Expand Down Expand Up @@ -671,6 +782,27 @@ MATXSCRIPT_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
// p->stream << ")";
});

MATXSCRIPT_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
.set_dispatch<ir::PrimCall>("", [](ir::PrimCall call, ObjectPath call_p, IRDocsifier d) -> Doc {
ExprDoc prefix{nullptr};
if (const auto* op = call->op.as<OpNode>()) {
// TODO: fix prim op name
StringRef name = op->name;
prefix = Dialect(d, name);
} else if (const auto* gv = call->op.as<GlobalVarNode>()) {
prefix = LiteralDoc::Str(gv->name_hint, call_p->Attr("op"));
} else {
MXLOG(FATAL) << "call: " << call;
}
Array<ExprDoc> args;
int n_args = call->args.size();
args.reserve(n_args + 1);
for (int i = 0; i < n_args; ++i) {
args.push_back(d->AsDoc<ExprDoc>(call->args[i], call_p->Attr("args")->ArrayIndex(i)));
}
return prefix->Call(args);
});

MATXSCRIPT_REGISTER_GLOBAL("runtime.GetIntImm").set_body_typed([](IntImm i) { return i->value; });

MATXSCRIPT_REGISTER_GLOBAL("runtime.GetFloatImm").set_body_typed([](FloatImm f) {
Expand Down
7 changes: 2 additions & 5 deletions src/ir/printer/text_printer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -99,11 +99,8 @@ PrinterConfig::PrinterConfig(Map<StringRef, ObjectRef> config_dict) {
if (auto v = config_dict.Get("name")) {
n->binding_names.push_back(Downcast<StringRef>(v));
}
if (auto v = config_dict.Get("ir_prefix")) {
n->ir_prefix = Downcast<StringRef>(v);
}
if (auto v = config_dict.Get("tir_prefix")) {
n->tir_prefix = Downcast<StringRef>(v);
if (auto v = config_dict.Get("dialect_prefix")) {
n->dialect_prefix = Downcast<StringRef>(v);
}
if (auto v = config_dict.Get("indent_spaces")) {
n->indent_spaces = Downcast<IntImm>(v)->value;
Expand Down
15 changes: 14 additions & 1 deletion test/cc/test_ir_new_printer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

#include <gtest/gtest.h>

#include <matxscript/ir/prim_builtin.h>
#include <matxscript/ir/prim_expr.h>
#include <matxscript/ir/printer/text_printer.h>
#include <matxscript/ir/stmt.h>
Expand All @@ -30,8 +31,20 @@ namespace matxscript {
namespace ir {

TEST(IRTextPrinter, PrintAllocaVar) {
PrimExpr a(3);
PrimExpr b(4);

PrimAdd c(a, b);
PrimMul d(c, a);

Bool cond(true);
PrimCall if_expr(d.dtype(), builtin::if_then_else(), {cond, d, c});

PrimCast cast_expr(runtime::DataType::Int(32), if_expr);

runtime::DataType int_ty = runtime::DataType::Int(64);
AllocaVarStmt alloca_stmt("b", PrimType(int_ty), IntImm(int_ty, 0));
AllocaVarStmt alloca_stmt("b", PrimType(int_ty), cast_expr);

auto ir_text = printer::IRTextPrinter::Print(alloca_stmt, printer::PrinterConfig());
// b: "int" = 0
std::cout << ir_text << std::endl;
Expand Down