From fe5f6163e60406f63144b2a13ca902eb8ee528d5 Mon Sep 17 00:00:00 2001 From: Archermmt Date: Fri, 5 Jan 2024 10:05:24 +0800 Subject: [PATCH] [Unity][MSC][Legalize] legalize codes and mute logging (#16325) --- .../contrib/msc/core/gym/agent/base_agent.py | 8 +- .../msc/core/gym/control/controller.py | 4 +- .../contrib/msc/core/gym/control/service.py | 8 +- .../msc/core/gym/environment/base_env.py | 8 +- python/tvm/contrib/msc/core/utils/log.py | 4 + python/tvm/contrib/msc/pipeline/manager.py | 32 +- src/contrib/msc/core/codegen/code_stack.cc | 198 +++++-- src/contrib/msc/core/codegen/code_stack.h | 487 +++++++++++------- src/contrib/msc/core/codegen/py_codegen.h | 50 +- src/contrib/msc/core/printer/cpp_printer.cc | 139 ++++- src/contrib/msc/core/printer/cpp_printer.h | 47 +- .../msc/core/printer/msc_base_printer.cc | 8 + .../msc/core/printer/msc_base_printer.h | 14 + src/contrib/msc/core/printer/msc_doc.cc | 34 ++ src/contrib/msc/core/printer/msc_doc.h | 186 ++++++- src/contrib/msc/core/printer/print_utils.cc | 43 +- src/contrib/msc/core/printer/print_utils.h | 134 +++-- .../msc/core/printer/python_printer.cc | 68 ++- src/contrib/msc/core/printer/python_printer.h | 11 +- .../msc/core/transform/layout_utils.cc | 26 + src/contrib/msc/core/transform/layout_utils.h | 21 + .../msc/core/transform/set_expr_layout.cc | 78 +-- src/contrib/msc/core/utils.cc | 15 + src/contrib/msc/core/utils.h | 38 ++ .../msc/framework/tensorflow/codegen.cc | 29 +- .../msc/framework/tensorflow/tf_v1_opcode.cc | 49 +- src/contrib/msc/framework/tensorrt/codegen.cc | 127 ++--- .../msc/framework/tensorrt/tensorrt_opcode.cc | 35 +- src/contrib/msc/framework/torch/codegen.cc | 12 +- .../msc/framework/torch/torch_opcode.cc | 35 +- src/contrib/msc/framework/tvm/codegen.cc | 54 +- src/contrib/msc/framework/tvm/relax_opcode.cc | 34 +- tests/python/contrib/test_msc/test_manager.py | 48 +- tests/python/contrib/test_msc/test_runner.py | 33 +- tests/python/contrib/test_msc/test_tools.py | 108 ++-- .../python/contrib/test_msc/test_transform.py | 156 ++++++ .../test_transform_set_expr_layout.py | 73 --- .../test_msc/test_transform_set_expr_name.py | 108 ---- 38 files changed, 1706 insertions(+), 856 deletions(-) create mode 100644 tests/python/contrib/test_msc/test_transform.py delete mode 100644 tests/python/contrib/test_msc/test_transform_set_expr_layout.py delete mode 100644 tests/python/contrib/test_msc/test_transform_set_expr_name.py diff --git a/python/tvm/contrib/msc/core/gym/agent/base_agent.py b/python/tvm/contrib/msc/core/gym/agent/base_agent.py index 07c0b53597a7..801f3f82b430 100644 --- a/python/tvm/contrib/msc/core/gym/agent/base_agent.py +++ b/python/tvm/contrib/msc/core/gym/agent/base_agent.py @@ -37,8 +37,8 @@ class BaseAgent(object): The extra options for the agent. debug_level: int The debug level. - verbose_task: int - The verbose interval task. + verbose: str + The verbose level. logger: logging.Logger The logger """ @@ -50,6 +50,7 @@ def __init__( executors: dict, options: dict = None, debug_level: int = 0, + verbose: str = None, logger: logging.Logger = None, ): self._name = name @@ -60,7 +61,8 @@ def __init__( if logger: self._logger = logger else: - verbose = "debug" if debug_level > 0 else "info" + if not verbose: + verbose = "debug" if debug_level > 0 else "info" self._logger = msc_utils.create_file_logger(verbose, workspace.relpath("AGENT_LOG")) self._logger.info( msc_utils.msg_block("AGENT.SETUP({})".format(self.agent_type()), self.setup()) diff --git a/python/tvm/contrib/msc/core/gym/control/controller.py b/python/tvm/contrib/msc/core/gym/control/controller.py index 5716169e7678..17ca5edb1c0a 100644 --- a/python/tvm/contrib/msc/core/gym/control/controller.py +++ b/python/tvm/contrib/msc/core/gym/control/controller.py @@ -31,8 +31,8 @@ class BaseController(object): The worksapce. config: dict The config for service. - is_master: bool - Whether the node is master node + is_main: bool + Whether the node is main node """ def __init__( diff --git a/python/tvm/contrib/msc/core/gym/control/service.py b/python/tvm/contrib/msc/core/gym/control/service.py index 6bb97def2545..f8fbdd31ddf6 100644 --- a/python/tvm/contrib/msc/core/gym/control/service.py +++ b/python/tvm/contrib/msc/core/gym/control/service.py @@ -151,6 +151,8 @@ class BaseService(object): The record step. debug_level: int The debug level + verbose: str + The verbose level. """ def __init__( @@ -164,15 +166,19 @@ def __init__( max_iter: int = 1, record_step: int = 5, debug_level: int = 0, + verbose: str = None, ): self._workspace = workspace tasks = tasks or [GYMObject.ENV + ":0", GYMObject.AGENT + ":0"] - verbose = "debug" if debug_level > 0 else "info" + if not verbose: + verbose = "debug" if debug_level > 0 else "info" self._logger = msc_utils.create_file_logger(verbose, self._workspace.relpath("SERVICE_LOG")) def _create_workers(config: dict, obj_type: str) -> List[BaseWorker]: if "debug_level" not in config: config["debug_level"] = debug_level + if "verbose" not in config: + config["verbose"] = verbose if "logger" not in config: config["logger"] = self._logger return [ diff --git a/python/tvm/contrib/msc/core/gym/environment/base_env.py b/python/tvm/contrib/msc/core/gym/environment/base_env.py index edaea5120556..86f1bff7be89 100644 --- a/python/tvm/contrib/msc/core/gym/environment/base_env.py +++ b/python/tvm/contrib/msc/core/gym/environment/base_env.py @@ -43,8 +43,8 @@ class BaseEnv(object): The extra options for the environment. debug_level: int The debug level. - verbose_step: int - The verbose interval step. + verbose: str + The verbose level. logger: logging.Logger The logger """ @@ -60,6 +60,7 @@ def __init__( options: dict = None, max_tasks: int = -1, debug_level: int = 0, + verbose: str = None, logger: logging.Logger = None, ): self._name = name @@ -74,7 +75,8 @@ def __init__( if logger: self._logger = logger else: - verbose = "debug" if debug_level > 0 else "info" + if not verbose: + verbose = "debug" if debug_level > 0 else "info" self._logger = msc_utils.create_file_logger(verbose, workspace.relpath("ENV_LOG")) self._logger.info( msc_utils.msg_block("ENV.SETUP({})".format(self.env_type()), self.setup()) diff --git a/python/tvm/contrib/msc/core/utils/log.py b/python/tvm/contrib/msc/core/utils/log.py index a406db806a4e..f208ecde95db 100644 --- a/python/tvm/contrib/msc/core/utils/log.py +++ b/python/tvm/contrib/msc/core/utils/log.py @@ -75,6 +75,10 @@ def create_file_logger(level: Union[str, int] = logging.INFO, path: str = None) level = logging.INFO elif level == "warn": level = logging.WARN + elif level == "error": + level = logging.ERROR + elif level == "critical": + level = logging.CRITICAL else: raise Exception("Unexcept verbose {}, should be debug| info| warn") diff --git a/python/tvm/contrib/msc/pipeline/manager.py b/python/tvm/contrib/msc/pipeline/manager.py index 9a45ee84b394..337503de5ba5 100644 --- a/python/tvm/contrib/msc/pipeline/manager.py +++ b/python/tvm/contrib/msc/pipeline/manager.py @@ -53,10 +53,10 @@ def __init__(self, model, config): self._workspace = msc_utils.set_workspace(config.get("workspace")) log_path = config.get("log_path") or self._workspace.relpath("MSC_LOG", keep_history=False) if config.get("debug_level", 0) > 0 and "verbose" not in config: - verbose = "debug" + self._verbose = "debug" else: - verbose = config.get("verbose", "info") - self._logger = msc_utils.set_global_logger(verbose, log_path) + self._verbose = config.get("verbose", "info") + self._logger = msc_utils.set_global_logger(self._verbose, log_path) msc_utils.time_stamp(MSCStage.SETUP) self._logger.info(msc_utils.msg_block("SETUP", self.setup(config))) @@ -78,6 +78,9 @@ def setup(self, config: dict) -> dict: self._tools_config = {} self._relax_mod, self._runner = None, None self._data_loader, self._sample_inputs = None, None + self._model_type = self._config["model_type"] + self._optimize_type = self._config.get("optimize", {}).get("run_type", self._model_type) + self._compile_type = self._config.get("compile", {}).get("run_type", self._model_type) self._report = { "success": False, "info": { @@ -172,9 +175,9 @@ def run_pipe(self) -> dict: self._runner = self.compile(self._config["compile"], use_cache) except Exception as exc: # pylint: disable=broad-exception-caught err_msg = "Pipeline failed:{}\nTrace: {}".format(exc, traceback.format_exc()) - report = self.summary(err_msg) - self._logger.info(msc_utils.msg_block("SUMMARY", report, 0)) - return report + self.summary(err_msg) + self._logger.info(msc_utils.msg_block("SUMMARY", self._report, 0)) + return self._report def prepare(self, stage_config: dict, use_cache: bool = False) -> Dict[str, np.ndarray]: """Prepare datas for the pipeline. @@ -573,6 +576,7 @@ def _apply_tool(self, tool_type: str, stage_config: dict, add_tool: bool = True) "knowledge": knowledge, }, "debug_level": runner.debug_level, + "verbose": self._verbose, } controller = create_controller(runner.stage, config, extra_config) knowledge = controller.run() @@ -868,6 +872,22 @@ def _get_runner_cls(self, run_type: str) -> BaseRunner: def runner(self): return self._runner + @property + def report(self): + return self._report + + @property + def model_type(self): + return self._model_type + + @property + def optimize_type(self): + return self._optimize_type + + @property + def compile_type(self): + return self._compile_type + class MSCManager(BaseManager): """Normal manager in MSC""" diff --git a/src/contrib/msc/core/codegen/code_stack.cc b/src/contrib/msc/core/codegen/code_stack.cc index 6d8ced83dee6..4b1c1850e76e 100644 --- a/src/contrib/msc/core/codegen/code_stack.cc +++ b/src/contrib/msc/core/codegen/code_stack.cc @@ -36,19 +36,20 @@ void BaseStack::Line(const Doc& doc) { PushDoc(doc); } void BaseStack::Line(const String& line) { Line(IdDoc(line)); } -void BaseStack::Comment(const String& comment) { PushDoc(CommentDoc(comment)); } - -void BaseStack::AssignBase(const String& lhs, const ExprDoc& rhs, const String& annotation) { - if (annotation.size() == 0) { - PushDoc(AssignDoc(IdDoc(lhs), rhs, NullOpt)); +void BaseStack::Comment(const String& comment, bool attach) { + if (attach) { + const auto& doc = TopDoc(); + ICHECK(doc->IsInstance()) << "Only stmt doc support attach comments"; + const auto& stmt = Downcast(doc); + stmt->comment = comment; } else { - PushDoc(AssignDoc(IdDoc(lhs), rhs, IdDoc(annotation))); + PushDoc(CommentDoc(comment)); } } void BaseStack::Declare(const String& type, const String& variable, size_t len, bool use_constructor) { - PushDoc(DocUtils::ToDeclareDoc(type, variable, len, use_constructor)); + PushDoc(DocUtils::ToDeclare(type, variable, len, use_constructor)); } void BaseStack::DeclareArgBase(const ExprDoc& value) { @@ -70,20 +71,8 @@ void BaseStack::FuncDef(const String& func_name, const String& ret_type) { void BaseStack::FuncArg(const String& arg, const String& annotation, const String& value) { const auto& func = PopCheckedDoc(); - Optional value_doc; - if (value.size() > 0) { - value_doc = IdDoc(value); - } else { - value_doc = NullOpt; - } - Optional annotation_doc; - if (annotation.size() > 0) { - annotation_doc = IdDoc(annotation); - } else { - annotation_doc = NullOpt; - } Array args = func->args; - args.push_back(AssignDoc(IdDoc(arg), value_doc, annotation_doc)); + args.push_back(DocUtils::ToAssign(arg, value, annotation)); PushDoc(FunctionDoc(func->name, args, func->decorators, func->return_type, func->body)); } @@ -99,10 +88,7 @@ void BaseStack::FuncStart() { BlockStart(); } -void BaseStack::FuncEnd(const String& ret_val) { - if (ret_val.size() > 0) { - PushDoc(ReturnDoc(IdDoc(ret_val))); - } +void BaseStack::FuncEnd() { const auto& block = PopBlock(); const auto& func = PopCheckedDoc(); const auto& body = DocUtils::ToStmts(block); @@ -132,7 +118,80 @@ void BaseStack::ClassEnd() { PushDoc(ClassDoc(class_doc->name, class_doc->decorators, body)); } -void BaseStack::FuncCall(const String& callee, Optional assign_to, +void BaseStack::StructStart(const String& struct_name) { + PushDoc(StructDoc(IdDoc(struct_name), Array(), Array())); + BlockStart(); +} + +void BaseStack::StructEnd() { + const auto& block = PopBlock(); + const auto& struct_doc = PopCheckedDoc(); + const auto& body = DocUtils::ToStmts(block); + PushDoc(StructDoc(struct_doc->name, struct_doc->decorators, body)); +} + +void BaseStack::ConstructorDef(const String& constructor_name) { + PushDoc(ConstructorDoc(IdDoc(constructor_name), Array(), Array())); +} + +void BaseStack::ConstructorArg(const String& arg, const String& annotation, const String& value) { + const auto& func = PopCheckedDoc(); + Array args = func->args; + args.push_back(DocUtils::ToAssign(arg, value, annotation)); + PushDoc(ConstructorDoc(func->name, args, func->body)); +} + +void BaseStack::ConstructorStart() { + ICHECK(TopDoc()->IsInstance()) << "ConstructorDoc is not saved"; + BlockStart(); +} + +void BaseStack::ConstructorEnd() { + const auto& block = PopBlock(); + const auto& func = PopCheckedDoc(); + const auto& body = DocUtils::ToStmts(block); + PushDoc(ConstructorDoc(func->name, func->args, body)); +} + +void BaseStack::LambdaDef(const String& lambda_name) { + PushDoc(LambdaDoc(IdDoc(lambda_name), Array(), Array(), Array())); +} + +void BaseStack::LambdaArg(const String& arg, const String& annotation, const String& value) { + const auto& lambda = PopCheckedDoc(); + Array args = lambda->args; + args.push_back(DocUtils::ToAssign(arg, value, annotation)); + PushDoc(LambdaDoc(lambda->name, args, lambda->refs, lambda->body)); +} + +void BaseStack::LambdaRef(const String& ref) { + const auto& lambda = PopCheckedDoc(); + Array refs = lambda->refs; + refs.push_back(IdDoc(ref)); + PushDoc(LambdaDoc(lambda->name, lambda->args, refs, lambda->body)); +} + +void BaseStack::LambdaStart() { + ICHECK(TopDoc()->IsInstance()) << "LambdaDoc is not saved"; + BlockStart(); +} + +void BaseStack::LambdaEnd(const String& ret_val) { + if (ret_val.size() > 0) { + PushDoc(ReturnDoc(IdDoc(ret_val))); + } + const auto& block = PopBlock(); + const auto& lambda = PopCheckedDoc(); + const auto& body = DocUtils::ToStmts(block); + PushDoc(LambdaDoc(lambda->name, lambda->args, lambda->refs, body)); +} + +void BaseStack::LambdaEnd(const ExprDoc& ret_val) { + PushDoc(ReturnDoc(ret_val)); + LambdaEnd(""); +} + +void BaseStack::FuncCall(const String& callee, Optional assign_to, Optional caller) { if (!caller.defined()) { PushDoc(CallDoc(IdDoc(callee), Array(), Array(), Array())); @@ -142,17 +201,22 @@ void BaseStack::FuncCall(const String& callee, Optional assign_to, } if (assign_to.defined()) { const auto& last_call = PopCheckedDoc(); - const auto& declare = Downcast(assign_to.value()); - PushDoc(AssignDoc(declare->variable, last_call, declare->type)); + if (assign_to.value()->IsInstance()) { + const auto& declare = Downcast(assign_to.value()); + PushDoc(AssignDoc(declare->variable, last_call, declare->type)); + } else { + const auto& declare = DocUtils::ToDeclare("", assign_to.value()); + PushDoc(AssignDoc(declare->variable, last_call, declare->type)); + } } } void BaseStack::FuncCall(const String& callee, const String& assign_to, const String& caller) { - Optional assign_doc; + Optional assign_doc; if (assign_to.size() == 0) { assign_doc = NullOpt; } else { - assign_doc = DocUtils::ToDeclareDoc("", assign_to); + assign_doc = IdDoc(assign_to); } Optional caller_doc; if (caller.size() == 0) { @@ -163,10 +227,11 @@ void BaseStack::FuncCall(const String& callee, const String& assign_to, const St FuncCall(callee, assign_doc, caller_doc); } -void BaseStack::MethodCall(const String& callee) { +void BaseStack::MethodCall(const String& callee, bool new_line) { const auto& host = PopDoc(); if (host->IsInstance()) { - FuncCall(callee, NullOpt, Downcast(host)); + const auto& v_callee = callee + (new_line ? DocSymbol::NextLine() : ""); + FuncCall(v_callee, NullOpt, Downcast(host)); } else if (const auto* a_node = host.as()) { ICHECK(a_node->rhs.defined()) << "Can not find rhs for inplace host"; FuncCall(callee, DeclareDoc(a_node->annotation, a_node->lhs, Array(), true), @@ -176,6 +241,31 @@ void BaseStack::MethodCall(const String& callee) { } } +void BaseStack::InplaceStart(const String& callee, Optional assign_to, + Optional caller) { + FuncCall(callee, assign_to, caller); +} + +void BaseStack::InplaceStart(const String& callee, const String& assign_to, const String& caller) { + FuncCall(callee, assign_to, caller); +} + +void BaseStack::InplaceEnd() { + const auto& last = PopDoc(); + // get args and kwargs + if (last->IsInstance()) { + CallArgBase(Downcast(last)); + } else if (const auto* assign = last.as()) { + const auto& call = Downcast(assign->rhs); + ICHECK(assign->lhs->IsInstance()) + << "assign lhs should be IdDoc, get " << assign->lhs->GetTypeKey(); + const auto& key = Downcast(assign->lhs)->name; + CallArgBase(call, key); + } else { + LOG(FATAL) << "Unexpected last type for call arg " << last->GetTypeKey(); + } +} + void BaseStack::PopNest(const String& key) { const auto& last = PopDoc(); if (last->IsInstance()) { @@ -247,17 +337,6 @@ void BaseStack::ConditionEnd() { } } -void BaseStack::ForStart(const String& lhs, const String& rhs) { - PushDoc(ForDoc(IdDoc(lhs), IdDoc(rhs), Array())); - BlockStart(); -} - -void BaseStack::ForStart(const String& lhs, size_t start, size_t end) { - Array range{DocUtils::ToDoc(start), DocUtils::ToDoc(end)}; - PushDoc(ForDoc(IdDoc(lhs), TupleDoc(range), Array())); - BlockStart(); -} - void BaseStack::ForEnd() { const auto& block = PopBlock(); const auto& for_doc = PopCheckedDoc(); @@ -277,6 +356,41 @@ void BaseStack::WhileEnd() { PushDoc(WhileDoc(while_doc->predicate, body)); } +void BaseStack::SwitchStart(const String& predicate) { + Array predicates; + predicates.push_back(IdDoc(predicate)); + PushDoc(SwitchDoc(predicates, Array>(), Array())); + BlockStart(); +} + +void BaseStack::SwitchCase(const String& predicate) { + const auto& block = PopBlock(); + const auto& switch_doc = PopCheckedDoc(); + auto branchs = switch_doc->branchs; + branchs.push_back(DocUtils::ToStmts(block)); + if (predicate.size() == 0) { + Array default_branch{ExprStmtDoc(IdDoc("pass"))}; + PushDoc(SwitchDoc(switch_doc->predicates, branchs, default_branch)); + } else { + auto predicates = switch_doc->predicates; + predicates.push_back(IdDoc(predicate)); + PushDoc(SwitchDoc(predicates, branchs, switch_doc->default_branch)); + } + BlockStart(); +} + +void BaseStack::SwitchEnd() { + const auto& block = PopBlock(); + const auto& switch_doc = PopCheckedDoc(); + if (switch_doc->default_branch.size() > 0) { + PushDoc(SwitchDoc(switch_doc->predicates, switch_doc->branchs, DocUtils::ToStmts(block))); + } else { + auto branchs = switch_doc->branchs; + branchs.push_back(DocUtils::ToStmts(block)); + PushDoc(SwitchDoc(switch_doc->predicates, branchs, switch_doc->default_branch)); + } +} + void BaseStack::BlockStart() { Array block; blocks_.push(block); diff --git a/src/contrib/msc/core/codegen/code_stack.h b/src/contrib/msc/core/codegen/code_stack.h index 1ddc21aff5f0..e348bcdab1bf 100644 --- a/src/contrib/msc/core/codegen/code_stack.h +++ b/src/contrib/msc/core/codegen/code_stack.h @@ -67,32 +67,25 @@ class BaseStack { void Line(const String& line = ""); /*! \brief Push Comment Doc*/ - void Comment(const String& comment = ""); + void Comment(const String& comment, bool attach = false); - /*! \brief Push typed assign Doc*/ - void AssignBase(const String& lhs, const ExprDoc& rhs, const String& annotation = ""); - - template - inline void Assign(const String& lhs, const T& rhs, const String& annotation = "") { - const auto& doc_rhs = DocUtils::ToDoc(rhs); - if (doc_rhs.defined()) { - AssignBase(lhs, doc_rhs, annotation); - } + /*! \brief Push Assign Doc*/ + template + inline void Assign(const LT& lhs, const RT& rhs, const String& annotation = "") { + PushDoc(DocUtils::ToAssign(lhs, rhs, annotation)); } - /*! \brief Push declare for variable Doc*/ + /*! \brief Push declare Doc*/ void Declare(const String& type, const String& variable, size_t len = 0, bool use_constructor = true); - /*! \brief Cache declare typed argument*/ + /*! \brief Cache declare argument*/ void DeclareArgBase(const ExprDoc& value); + /*! \brief Cache declare typed argument*/ template inline void DeclareArg(const T& value) { - const auto& doc_value = DocUtils::ToDoc(value); - if (doc_value.defined()) { - DeclareArgBase(doc_value); - } + DeclareArgBase(DocUtils::ToDoc(value)); } /*! \brief Cache class Doc*/ @@ -107,30 +100,76 @@ class BaseStack { /*! \brief End class body block*/ void ClassEnd(); + /*! \brief Start struct body block*/ + void StructStart(const String& struct_name); + + /*! \brief End struct body block*/ + void StructEnd(); + /*! \brief Cache function Doc*/ void FuncDef(const String& func_name, const String& ret_type = ""); - /*! \brief Cache func argument*/ + /*! \brief Cache function argument*/ void FuncArg(const String& arg, const String& annotation = "", const String& value = ""); - /*! \brief Cache func decorator*/ + /*! \brief Cache function decorator*/ void FuncDecorator(const String& decorator); /*! \brief Start function body block*/ void FuncStart(); /*! \brief End function body block*/ - void FuncEnd(const String& ret_val = ""); + void FuncEnd(); - /*! \brief Push call and maybe assign Doc*/ - void FuncCall(const String& callee, Optional assign_to, - Optional caller = NullOpt); + template + void FuncEnd(const T& ret_val) { + PushDoc(ReturnDoc(DocUtils::ToDoc(ret_val))); + FuncEnd(); + } + + /*! \brief Cache constructor Doc*/ + void ConstructorDef(const String& constructor_name); + + /*! \brief Cache constructor argument*/ + void ConstructorArg(const String& arg, const String& annotation = "", const String& value = ""); + + /*! \brief Start constructor body block*/ + void ConstructorStart(); + + /*! \brief End constructor body block*/ + void ConstructorEnd(); + + /*! \brief Cache lambda Doc*/ + void LambdaDef(const String& lambda_name); + + /*! \brief Cache lambda argument*/ + void LambdaArg(const String& arg, const String& annotation = "", const String& value = ""); + + /*! \brief Cache lambda reference*/ + void LambdaRef(const String& ref); + + /*! \brief Start lambda body block*/ + void LambdaStart(); + + /*! \brief End lambda body block*/ + void LambdaEnd(const String& ret_val = ""); + void LambdaEnd(const ExprDoc& ret_val); /*! \brief Push call and maybe assign Doc*/ + void FuncCall(const String& callee, Optional assign_to, + Optional caller = NullOpt); void FuncCall(const String& callee, const String& assign_to = "", const String& caller = ""); /*! \brief Push method call Doc*/ - void MethodCall(const String& callee); + void MethodCall(const String& callee, bool new_line = false); + + /*! \brief Push inplace call and maybe assign Doc*/ + void InplaceStart(const String& callee, Optional assign_to, + Optional caller = NullOpt); + void InplaceStart(const String& callee, const String& assign_to = "", const String& caller = ""); + + /*! \brief End inplace call*/ + void InplaceEnd(); /*! \brief Push nested expr to last Doc*/ void PopNest(const String& key = ""); @@ -164,10 +203,19 @@ class BaseStack { void ConditionEnd(); /*! \brief Push for to cache and start for block*/ - void ForStart(const String& lhs, const String& rhs); + template + void ForStart(const LT& lhs, const RT& rhs) { + PushDoc(ForDoc(DocUtils::ToDoc(lhs), DocUtils::ToDoc(rhs), Array())); + BlockStart(); + } /*! \brief Push for range to cache and start for block*/ - void ForStart(const String& lhs, size_t start, size_t end); + template + void ForStart(const String& lhs, const ST& start, const ET& end) { + Array range{DocUtils::ToDoc(start), DocUtils::ToDoc(end)}; + PushDoc(ForDoc(IdDoc(lhs), TupleDoc(range), Array())); + BlockStart(); + } /*! \brief End a for block*/ void ForEnd(); @@ -178,6 +226,15 @@ class BaseStack { /*! \brief End a while block*/ void WhileEnd(); + /*! \brief Push switch to cache and start switch block*/ + void SwitchStart(const String& predicate); + + /*! \brief Add new case to switch*/ + void SwitchCase(const String& predicate = ""); + + /*! \brief Push switch to cached*/ + void SwitchEnd(); + /*! \brief Start a new block*/ void BlockStart(); @@ -185,7 +242,7 @@ class BaseStack { void BlockEnd(bool block_docs = true); /*! \brief Start a new scope*/ - void ScopeStart(const String& scope_def, const String& scope_ref = ""); + void ScopeStart(const String& scope_def = "", const String& scope_ref = ""); /*! \brief End a scope*/ void ScopeEnd(); @@ -220,148 +277,234 @@ class BaseStack { std::stack> blocks_; }; -#define COMMON_WRAPPERS(Stack) \ - Stack& line(const Doc& doc) { \ - Line(doc); \ - return *this; \ - } \ - Stack& line(const String& line = "") { \ - Line(line); \ - return *this; \ - } \ - Stack& comment(const String& comment) { \ - Comment(comment); \ - return *this; \ - } \ - template \ - Stack& assign(const String& lhs, const T& rhs, const String& annotation = "") { \ - Assign(lhs, rhs, annotation); \ - return *this; \ - } \ - Stack& declare(const String& type, const String& variable, size_t len = 0, \ - bool use_constructor = true) { \ - Declare(type, variable, len, use_constructor); \ - return *this; \ - } \ - template \ - Stack& declare_arg(const T& value) { \ - DeclareArg(value); \ - return *this; \ - } \ - Stack& class_def(const String& class_name) { \ - ClassDef(class_name); \ - return *this; \ - } \ - Stack& class_decorator(const String& decorator) { \ - ClassDecorator(decorator); \ - return *this; \ - } \ - Stack& class_start() { \ - ClassStart(); \ - return *this; \ - } \ - Stack& class_end() { \ - ClassEnd(); \ - return *this; \ - } \ - Stack& func_def(const String& func_name, const String& ret_type = "") { \ - FuncDef(func_name, ret_type); \ - return *this; \ - } \ - Stack& func_arg(const String& arg, const String& annotation = "", const String& value = "") { \ - FuncArg(arg, annotation, value); \ - return *this; \ - } \ - Stack& func_decorator(const String& decorator) { \ - FuncDecorator(decorator); \ - return *this; \ - } \ - Stack& func_start() { \ - FuncStart(); \ - return *this; \ - } \ - Stack& func_end(const String& ret_val = "") { \ - FuncEnd(ret_val); \ - return *this; \ - } \ - Stack& func_call(const String& callee, Optional assign_to, \ - Optional caller = NullOpt) { \ - FuncCall(callee, assign_to, caller); \ - return *this; \ - } \ - Stack& func_call(const String& callee, const String& assign_to = "", \ - const String& caller = "") { \ - FuncCall(callee, assign_to, caller); \ - return *this; \ - } \ - Stack& method_call(const String& callee) { \ - MethodCall(callee); \ - return *this; \ - } \ - Stack& pop_nest(const String& key = "") { \ - PopNest(key); \ - return *this; \ - } \ - template \ - Stack& call_arg(T value, const String& key = "") { \ - CallArg(value, key); \ - return *this; \ - } \ - Stack& call_arg(const ExprDoc& value, const String& key = "") { \ - CallArg(value, key); \ - return *this; \ - } \ - Stack& call_arg(const Array& values) { \ - CallArg(values); \ - return *this; \ - } \ - Stack& cond_if(const String& predicate) { \ - ConditionIf(predicate); \ - return *this; \ - } \ - Stack& cond_else() { \ - ConditionElse(); \ - return *this; \ - } \ - Stack& cond_end() { \ - ConditionEnd(); \ - return *this; \ - } \ - Stack& for_start(const String& lhs, const String& rhs) { \ - ForStart(lhs, rhs); \ - return *this; \ - } \ - Stack& for_start(const String& lhs, size_t start, size_t end) { \ - ForStart(lhs, start, end); \ - return *this; \ - } \ - Stack& for_end() { \ - ForEnd(); \ - return *this; \ - } \ - Stack& while_start(const String& predicate) { \ - WhileStart(predicate); \ - return *this; \ - } \ - Stack& while_end() { \ - WhileEnd(); \ - return *this; \ - } \ - Stack& block_start() { \ - BlockStart(); \ - return *this; \ - } \ - Stack& block_end(bool block_docs = true) { \ - BlockEnd(block_docs); \ - return *this; \ - } \ - Stack& scope_start(const String& scope_def, const String& scope_ref = "") { \ - ScopeStart(scope_def, scope_ref); \ - return *this; \ - } \ - Stack& scope_end() { \ - ScopeEnd(); \ - return *this; \ +#define COMMON_WRAPPERS(Stack) \ + Stack& line(const Doc& doc) { \ + Line(doc); \ + return *this; \ + } \ + Stack& line(const String& line = "") { \ + Line(line); \ + return *this; \ + } \ + Stack& comment(const String& comment, bool attach = false) { \ + Comment(comment, attach); \ + return *this; \ + } \ + template \ + Stack& assign(const LT& lhs, const RT& rhs, const String& annotation = "") { \ + Assign(lhs, rhs, annotation); \ + return *this; \ + } \ + Stack& declare(const String& type, const String& variable, size_t len = 0, \ + bool use_constructor = true) { \ + Declare(type, variable, len, use_constructor); \ + return *this; \ + } \ + template \ + Stack& declare_arg(const T& value) { \ + DeclareArg(value); \ + return *this; \ + } \ + Stack& class_def(const String& class_name) { \ + ClassDef(class_name); \ + return *this; \ + } \ + Stack& class_decorator(const String& decorator) { \ + ClassDecorator(decorator); \ + return *this; \ + } \ + Stack& class_start() { \ + ClassStart(); \ + return *this; \ + } \ + Stack& class_end() { \ + ClassEnd(); \ + return *this; \ + } \ + Stack& struct_start(const String& struct_name) { \ + StructStart(struct_name); \ + return *this; \ + } \ + Stack& struct_end() { \ + StructEnd(); \ + return *this; \ + } \ + Stack& func_def(const String& func_name, const String& ret_type = "") { \ + FuncDef(func_name, ret_type); \ + return *this; \ + } \ + Stack& func_arg(const String& arg, const String& annotation = "", const String& value = "") { \ + FuncArg(arg, annotation, value); \ + return *this; \ + } \ + Stack& func_decorator(const String& decorator) { \ + FuncDecorator(decorator); \ + return *this; \ + } \ + Stack& func_start() { \ + FuncStart(); \ + return *this; \ + } \ + Stack& func_end() { \ + FuncEnd(); \ + return *this; \ + } \ + template \ + Stack& func_end(const T& ret_val) { \ + FuncEnd(ret_val); \ + return *this; \ + } \ + Stack& func_call(const String& callee, Optional assign_to, \ + Optional caller = NullOpt) { \ + FuncCall(callee, assign_to, caller); \ + return *this; \ + } \ + Stack& func_call(const String& callee, const String& assign_to = "", \ + const String& caller = "") { \ + FuncCall(callee, assign_to, caller); \ + return *this; \ + } \ + Stack& method_call(const String& callee, bool new_line = false) { \ + MethodCall(callee, new_line); \ + return *this; \ + } \ + Stack& inplace_start(const String& callee, Optional assign_to, \ + Optional caller = NullOpt) { \ + InplaceStart(callee, assign_to, caller); \ + return *this; \ + } \ + Stack& inplace_start(const String& callee, const String& assign_to = "", \ + const String& caller = "") { \ + InplaceStart(callee, assign_to, caller); \ + return *this; \ + } \ + Stack& inplace_end() { \ + InplaceEnd(); \ + return *this; \ + } \ + Stack& constructor_def(const String& func_name) { \ + ConstructorDef(func_name); \ + return *this; \ + } \ + Stack& constructor_arg(const String& arg, const String& annotation = "", \ + const String& value = "") { \ + ConstructorArg(arg, annotation, value); \ + return *this; \ + } \ + Stack& constructor_start() { \ + ConstructorStart(); \ + return *this; \ + } \ + Stack& constructor_end() { \ + ConstructorEnd(); \ + return *this; \ + } \ + Stack& lambda_def(const String& lambda_name) { \ + LambdaDef(lambda_name); \ + return *this; \ + } \ + Stack& lambda_arg(const String& arg, const String& annotation = "", const String& value = "") { \ + LambdaArg(arg, annotation, value); \ + return *this; \ + } \ + Stack& lambda_ref(const String& ref) { \ + LambdaRef(ref); \ + return *this; \ + } \ + Stack& lambda_start() { \ + LambdaStart(); \ + return *this; \ + } \ + Stack& lambda_end(const String& ret_val = "") { \ + LambdaEnd(ret_val); \ + return *this; \ + } \ + Stack& lambda_end(const ExprDoc& ret_val) { \ + LambdaEnd(ret_val); \ + return *this; \ + } \ + Stack& pop_nest(const String& key = "") { \ + PopNest(key); \ + return *this; \ + } \ + template \ + Stack& call_arg(T value, const String& key = "") { \ + CallArg(value, key); \ + return *this; \ + } \ + Stack& call_arg(const ExprDoc& value, const String& key = "") { \ + CallArg(value, key); \ + return *this; \ + } \ + Stack& call_arg(const Array& values) { \ + CallArg(values); \ + return *this; \ + } \ + Stack& cond_if(const String& predicate) { \ + ConditionIf(predicate); \ + return *this; \ + } \ + Stack& cond_else() { \ + ConditionElse(); \ + return *this; \ + } \ + Stack& cond_end() { \ + ConditionEnd(); \ + return *this; \ + } \ + template \ + Stack& for_start(const LT& lhs, const RT& rhs) { \ + ForStart(lhs, rhs); \ + return *this; \ + } \ + template \ + Stack& for_start(const String& lhs, const ST& start, const ET& end) { \ + ForStart(lhs, start, end); \ + return *this; \ + } \ + Stack& for_start(const String& lhs, const String& start, const String& end) { \ + ForStart(lhs, start, end); \ + return *this; \ + } \ + Stack& for_end() { \ + ForEnd(); \ + return *this; \ + } \ + Stack& while_start(const String& predicate) { \ + WhileStart(predicate); \ + return *this; \ + } \ + Stack& while_end() { \ + WhileEnd(); \ + return *this; \ + } \ + Stack& switch_start(const String& predicate) { \ + SwitchStart(predicate); \ + return *this; \ + } \ + Stack& switch_case(const String& predicate = "") { \ + SwitchCase(predicate); \ + return *this; \ + } \ + Stack& switch_end() { \ + SwitchEnd(); \ + return *this; \ + } \ + Stack& block_start() { \ + BlockStart(); \ + return *this; \ + } \ + Stack& block_end(bool block_docs = true) { \ + BlockEnd(block_docs); \ + return *this; \ + } \ + Stack& scope_start(const String& scope_def = "", const String& scope_ref = "") { \ + ScopeStart(scope_def, scope_ref); \ + return *this; \ + } \ + Stack& scope_end() { \ + ScopeEnd(); \ + return *this; \ } /*! @@ -428,7 +571,7 @@ class OpCodeStack : public BaseStack { std::string attr_val; if (codegen_->node()->GetAttr(attr_key, &attr_val)) { const String& valid_key = key == "msc::auto" ? attr_key : key; - return call_arg(DocUtils::ToStrDoc(attr_val), valid_key); + return call_arg(DocUtils::ToStr(attr_val), valid_key); } return *this; } @@ -440,7 +583,7 @@ class OpCodeStack : public BaseStack { std::vector attr_val; if (codegen_->node()->GetAttr(attr_key, &attr_val)) { const String& valid_key = key == "msc::auto" ? attr_key : key; - return call_arg(DocUtils::ToListDoc(attr_val, allow_empty), valid_key); + return call_arg(DocUtils::ToList(attr_val, allow_empty), valid_key); } return *this; } @@ -457,7 +600,7 @@ class OpCodeStack : public BaseStack { inputs.push_back(codegen_->IdxInput(i, true)); } if (as_list) { - return call_arg(DocUtils::ToListDoc(inputs), key); + return call_arg(DocUtils::ToList(inputs), key); } else { return call_arg(DocUtils::ToDocList(inputs)); } @@ -481,7 +624,7 @@ class OpCodeStack : public BaseStack { const String& name = "msc::auto") { const String& valid_key = key == "msc::auto" ? "name" : key; const String& valid_name = name == "msc::auto" ? codegen_->node()->name : name; - return call_arg(DocUtils::ToStrDoc(valid_name), valid_key); + return call_arg(DocUtils::ToStr(valid_name), valid_key); return *this; } diff --git a/src/contrib/msc/core/codegen/py_codegen.h b/src/contrib/msc/core/codegen/py_codegen.h index 040b6df11ff5..6b3120affe32 100644 --- a/src/contrib/msc/core/codegen/py_codegen.h +++ b/src/contrib/msc/core/codegen/py_codegen.h @@ -103,7 +103,7 @@ class PyCodeGen : public BaseCodeGen { .func_arg("dtype", "str") .func_start() .func_call("os.path.join", "path") - .call_arg(DocUtils::ToStrDoc(this->config()->baseline_folder)) + .call_arg(DocUtils::ToStr(this->config()->baseline_folder)) .call_arg("name + \".bin\"") .cond_if("os.path.isfile(path)") .func_call("np.fromfile", "data") @@ -129,21 +129,27 @@ class PyCodeGen : public BaseCodeGen { .assign("golden", "{}"); for (const auto& i : this->graph()->input_names) { const auto& input = this->graph()->FindTensor(i); - this->stack_.func_call("load_data", "inputs[\"" + input->alias + "\"]") - .call_arg(DocUtils::ToStrDoc(input->alias)) - .call_arg(DocUtils::ToListDoc(input->shape, true)) - .call_arg(DocUtils::ToStrDoc(runtime::DLDataType2String(input->dtype))); + this->stack_ + .func_call("load_data", DocUtils::ToIndex("inputs", DocUtils::ToStr(input->alias))) + .call_arg(DocUtils::ToStr(input->alias)) + .call_arg(DocUtils::ToList(input->shape, true)) + .call_arg(DocUtils::ToStr(runtime::DLDataType2String(input->dtype))); } for (const auto& o : this->graph()->output_names) { const auto& output = this->graph()->FindTensor(o); - this->stack_.func_call("load_data", "golden[\"" + output->alias + "\"]") - .call_arg(DocUtils::ToStrDoc(output->alias)) - .call_arg(DocUtils::ToListDoc(output->shape, true)) - .call_arg(DocUtils::ToStrDoc(runtime::DLDataType2String(output->dtype))); + this->stack_ + .func_call("load_data", DocUtils::ToIndex("golden", DocUtils::ToStr(output->alias))) + .call_arg(DocUtils::ToStr(output->alias)) + .call_arg(DocUtils::ToList(output->shape, true)) + .call_arg(DocUtils::ToStr(runtime::DLDataType2String(output->dtype))); } this->stack_.comment("Build and inference the graph"); CodeGenInference(); - this->stack_.line("msc_utils.compare_arrays(golden, outputs, verbose=\"detail\")").cond_end(); + this->stack_.func_call("msc_utils.compare_arrays") + .call_arg("golden") + .call_arg("outputs") + .call_arg(DocUtils::ToStr("detail"), "verbose") + .cond_end(); } /*! \brief Stack the docs for the node*/ @@ -155,19 +161,19 @@ class PyCodeGen : public BaseCodeGen { const auto& input = node->InputAt(i); this->stack_.func_call("msc_tools.process_tensor", this->IdxInputBase(node, i, true)) .call_arg(this->IdxInputBase(node, i, false)) - .call_arg(DocUtils::ToStrDoc(input->name)) - .call_arg(DocUtils::ToStrDoc(node->name)) - .call_arg(DocUtils::ToStrDoc(this->config()->tools_scope)) - .call_arg(DocUtils::ToStrDoc(this->config()->tools_tag)); + .call_arg(DocUtils::ToStr(input->name)) + .call_arg(DocUtils::ToStr(node->name)) + .call_arg(DocUtils::ToStr(this->config()->tools_scope)) + .call_arg(DocUtils::ToStr(this->config()->tools_tag)); } for (const auto& pair : node->weights) { this->stack_ .func_call("msc_tools.process_tensor", this->IdxWeightBase(node, pair.first, true)) .call_arg(this->IdxWeightBase(node, pair.first, false)) - .call_arg(DocUtils::ToStrDoc(pair.second->name)) - .call_arg(DocUtils::ToStrDoc(node->name)) - .call_arg(DocUtils::ToStrDoc(this->config()->tools_scope)) - .call_arg(DocUtils::ToStrDoc(this->config()->tools_tag)); + .call_arg(DocUtils::ToStr(pair.second->name)) + .call_arg(DocUtils::ToStr(node->name)) + .call_arg(DocUtils::ToStr(this->config()->tools_scope)) + .call_arg(DocUtils::ToStr(this->config()->tools_tag)); } } for (const auto& d : this->GetOpCodes(node)) { @@ -180,10 +186,10 @@ class PyCodeGen : public BaseCodeGen { if (graph_outputs_.count(node->OutputAt(index))) { this->stack_.func_call("msc_tools.process_tensor", this->IdxOutputBase(node, index, true)) .call_arg(this->IdxOutputBase(node, index, false)) - .call_arg(DocUtils::ToStrDoc(node->OutputAt(index)->name)) - .call_arg(DocUtils::ToStrDoc("exit")) - .call_arg(DocUtils::ToStrDoc(this->config()->tools_scope)) - .call_arg(DocUtils::ToStrDoc(this->config()->tools_tag)); + .call_arg(DocUtils::ToStr(node->OutputAt(index)->name)) + .call_arg(DocUtils::ToStr("exit")) + .call_arg(DocUtils::ToStr(this->config()->tools_scope)) + .call_arg(DocUtils::ToStr(this->config()->tools_tag)); } } } diff --git a/src/contrib/msc/core/printer/cpp_printer.cc b/src/contrib/msc/core/printer/cpp_printer.cc index 728f609c00df..f162f5db1e85 100644 --- a/src/contrib/msc/core/printer/cpp_printer.cc +++ b/src/contrib/msc/core/printer/cpp_printer.cc @@ -54,10 +54,20 @@ void CppPrinter::PrintTypedDoc(const IndexDoc& doc) { void CppPrinter::PrintTypedDoc(const AttrAccessDoc& doc) { PrintDoc(doc->value, false); - if (!doc->value->IsInstance()) { - output_ << "."; + if (StringUtils::EndsWith(doc->name, DocSymbol::NextLine())) { + const auto& v_name = StringUtils::Replace(doc->name, DocSymbol::NextLine(), ""); + if (!doc->value->IsInstance()) { + IncreaseIndent(); + PrintDoc(IdDoc(".")); + DecreaseIndent(); + } + output_ << v_name; + } else { + if (!doc->value->IsInstance()) { + output_ << "."; + } + output_ << doc->name; } - output_ << doc->name; } void CppPrinter::PrintTypedDoc(const CallDoc& doc) { @@ -79,8 +89,10 @@ void CppPrinter::PrintTypedDoc(const CallDoc& doc) { void CppPrinter::PrintTypedDoc(const AssignDoc& doc) { ICHECK(doc->lhs.defined()) << "lhs should be given for assign"; if (doc->annotation.defined()) { - PrintDoc(doc->annotation.value(), false); - output_ << " "; + if (!IsEmptyDoc(doc->annotation.value())) { + PrintDoc(doc->annotation.value(), false); + output_ << " "; + } } PrintDoc(doc->lhs, false); if (doc->rhs.defined()) { @@ -148,7 +160,7 @@ void CppPrinter::PrintTypedDoc(const ForDoc& doc) { void CppPrinter::PrintTypedDoc(const ScopeDoc& doc) { MaybePrintComment(doc, true); ICHECK(doc->rhs.defined()) << "rhs should be given for scope"; - PrintDoc(doc->rhs); + PrintDoc(doc->rhs, false); PrintIndentedBlock(doc->body); } @@ -158,20 +170,28 @@ void CppPrinter::PrintTypedDoc(const FunctionDoc& doc) { ICHECK(arg_doc->comment == nullptr) << "Function arg cannot have comment attached to them."; } if (doc->return_type.defined()) { - PrintDoc(doc->return_type.value(), false); + if (!IsEmptyDoc(doc->return_type.value())) { + PrintDoc(doc->return_type.value(), false); + output_ << " "; + } } else { - output_ << "void"; + output_ << "void "; } - output_ << " "; PrintDoc(doc->name, false); output_ << "("; PrintJoinedDocs(doc->args, ", "); output_ << ")"; + if (doc->decorators.size() > 0) { + output_ << " "; + PrintJoinedDocs(doc->decorators, " "); + } if (doc->body.size() > 0) { output_ << " {"; PrintIndentedBlock(doc->body); if (doc->return_type.defined()) { - Endline(); + if (!IsEmptyDoc(doc->return_type.value())) { + Endline(); + } } NewLine(); output_ << "}"; @@ -189,9 +209,11 @@ void CppPrinter::PrintTypedDoc(const ClassDoc& doc) { for (const StmtDoc& d : doc->body) { PrintDoc(d); } - NewLine(false); output_ << "}"; Endline(); + output_ << " // class "; + PrintDoc(doc->name, false); + NewLine(false); } void CppPrinter::PrintTypedDoc(const CommentDoc& doc) { @@ -230,6 +252,101 @@ void CppPrinter::PrintTypedDoc(const StrictListDoc& doc) { } } +void CppPrinter::PrintTypedDoc(const StructDoc& doc) { + MaybePrintComment(doc, true); + output_ << "struct "; + PrintDoc(doc->name, false); + output_ << " {"; + IncreaseIndent(); + for (const StmtDoc& d : doc->body) { + PrintDoc(d); + } + DecreaseIndent(); + NewLine(false); + output_ << "}"; + Endline(); + output_ << " // struct "; + PrintDoc(doc->name, false); + NewLine(false); +} + +void CppPrinter::PrintTypedDoc(const ConstructorDoc& doc) { + MaybePrintComment(doc, true); + for (const AssignDoc& arg_doc : doc->args) { + ICHECK(arg_doc->comment == nullptr) << "Constructor arg cannot have comment attached to them."; + } + PrintDoc(doc->name, false); + output_ << "("; + PrintJoinedDocs(doc->args, ", "); + output_ << ")"; + if (doc->body.size() > 0) { + output_ << " {"; + PrintIndentedBlock(doc->body); + NewLine(); + output_ << "}"; + } else { + Endline(); + } + NewLine(false); +} + +void CppPrinter::PrintTypedDoc(const LambdaDoc& doc) { + MaybePrintComment(doc, true); + for (const AssignDoc& arg_doc : doc->args) { + ICHECK(arg_doc->comment == nullptr) << "Function arg cannot have comment attached to them."; + } + output_ << "auto "; + PrintDoc(doc->name, false); + output_ << " = ["; + PrintJoinedDocs(doc->refs, ", "); + output_ << "]("; + PrintJoinedDocs(doc->args, ", "); + output_ << ")"; + if (doc->body.size() > 0) { + output_ << " {"; + PrintIndentedBlock(doc->body); + Endline(); + NewLine(); + output_ << "};"; + } else { + Endline(); + } + NewLine(false); +} + +void CppPrinter::PrintTypedDoc(const SwitchDoc& doc) { + MaybePrintComment(doc, true); + ICHECK_EQ(doc->predicates.size(), doc->branchs.size()) + << "predicates " << doc->predicates.size() << " mismatch with branchs " + << doc->branchs.size(); + for (size_t i = 0; i < doc->predicates.size(); i++) { + if (i == 0) { + output_ << "if ("; + } else { + NewLine(); + output_ << "} else if ("; + } + PrintDoc(doc->predicates[i], false); + output_ << ") {"; + PrintIndentedBlock(doc->branchs[i]); + } + if (!doc->default_branch.empty()) { + NewLine(); + output_ << "} else {"; + PrintIndentedBlock(doc->default_branch); + } + NewLine(); + output_ << "}"; +} + +bool CppPrinter::IsEmptyDoc(const ExprDoc& doc) { + if (!doc->IsInstance()) { + return false; + } + const auto& id_doc = Downcast(doc); + return id_doc->name == DocSymbol::Empty(); +} + void CppPrinter::PrintIndentedBlock(const Array& docs) { IncreaseIndent(); for (const StmtDoc& d : docs) { diff --git a/src/contrib/msc/core/printer/cpp_printer.h b/src/contrib/msc/core/printer/cpp_printer.h index 870ff517f61a..bdd25acdebed 100644 --- a/src/contrib/msc/core/printer/cpp_printer.h +++ b/src/contrib/msc/core/printer/cpp_printer.h @@ -28,7 +28,9 @@ #include #include +#include "../utils.h" #include "msc_base_printer.h" +#include "print_utils.h" namespace tvm { namespace contrib { @@ -51,51 +53,63 @@ class CppPrinter : public MSCBasePrinter { } protected: - /*! * \brief Print a LiteralDoc to python format*/ + /*! * \brief Print a LiteralDoc to cpp format*/ void PrintTypedDoc(const LiteralDoc& doc) final; - /*! \brief Virtual method to print an IndexDoc*/ + /*! * \brief Print a IndexDoc to cpp format*/ void PrintTypedDoc(const IndexDoc& doc) final; - /*! * \brief Print a AttrAccessDoc to python format*/ + /*! * \brief Print a AttrAccessDoc to cpp format*/ void PrintTypedDoc(const AttrAccessDoc& doc) final; - /*! * \brief Print a CallDoc to python format*/ + /*! * \brief Print a CallDoc to cpp format*/ void PrintTypedDoc(const CallDoc& doc) final; - /*! * \brief Print a AssignDoc to python format*/ + /*! * \brief Print a AssignDoc to cpp format*/ void PrintTypedDoc(const AssignDoc& doc) final; - /*! * \brief Print a IfDoc to python format*/ + /*! * \brief Print a IfDoc to cpp format*/ void PrintTypedDoc(const IfDoc& doc) final; - /*! * \brief Print a WhileDoc to python format*/ + /*! * \brief Print a WhileDoc to cpp format*/ void PrintTypedDoc(const WhileDoc& doc) final; - /*! \brief Virtual method to print a ForDoc*/ + /*! * \brief Print a ForDoc to cpp format*/ void PrintTypedDoc(const ForDoc& doc) final; - /*! * \brief Print a ScopeDoc to python format*/ + /*! * \brief Print a ScopeDoc to cpp format*/ void PrintTypedDoc(const ScopeDoc& doc) final; - /*! * \brief Print a FunctionDoc to python format*/ + /*! * \brief Print a FunctionDoc to cpp format*/ void PrintTypedDoc(const FunctionDoc& doc) final; - /*! * \brief Print a ClassDoc to python format*/ + /*! * \brief Print a ClassDoc to cpp format*/ void PrintTypedDoc(const ClassDoc& doc) final; - /*! * \brief Print a CommentDoc to python format*/ + /*! * \brief Print a CommentDoc to cpp format*/ void PrintTypedDoc(const CommentDoc& doc) final; - /*! \brief Virtual method to print a DeclareDoc*/ + /*! * \brief Print a DeclareDoc to cpp format*/ void PrintTypedDoc(const DeclareDoc& doc) final; - /*! \brief Virtual method to print a PointerDoc*/ + /*! * \brief Print a PointerDoc to cpp format*/ void PrintTypedDoc(const PointerDoc& doc) final; - /*! \brief Virtual method to print a StrictListDoc*/ + /*! * \brief Print a StrictListDoc to cpp format*/ void PrintTypedDoc(const StrictListDoc& doc) final; + /*! * \brief Print a StructDoc to cpp format*/ + void PrintTypedDoc(const StructDoc& doc) final; + + /*! * \brief Print a ConstructorDoc to cpp format*/ + void PrintTypedDoc(const ConstructorDoc& doc) final; + + /*! * \brief Print a LambdaDoc to cpp format*/ + void PrintTypedDoc(const LambdaDoc& doc) final; + + /*! * \brief Print a SwitchDoc to cpp format*/ + void PrintTypedDoc(const SwitchDoc& doc) final; + private: /*! \brief endline scopes*/ std::vector endlines_; @@ -129,6 +143,9 @@ class CppPrinter : public MSCBasePrinter { } } + /*! \brief Check if the doc is empty doc*/ + bool IsEmptyDoc(const ExprDoc& doc); + /*! \brief Print block with indent*/ void PrintIndentedBlock(const Array& docs); }; diff --git a/src/contrib/msc/core/printer/msc_base_printer.cc b/src/contrib/msc/core/printer/msc_base_printer.cc index ba434976ffa1..289c1b79fd66 100644 --- a/src/contrib/msc/core/printer/msc_base_printer.cc +++ b/src/contrib/msc/core/printer/msc_base_printer.cc @@ -82,6 +82,14 @@ void MSCBasePrinter::PrintDoc(const Doc& doc, bool new_line) { PrintTypedDoc(doc_node.value()); } else if (auto doc_node = doc.as()) { PrintTypedDoc(doc_node.value()); + } else if (auto doc_node = doc.as()) { + PrintTypedDoc(doc_node.value()); + } else if (auto doc_node = doc.as()) { + PrintTypedDoc(doc_node.value()); + } else if (auto doc_node = doc.as()) { + PrintTypedDoc(doc_node.value()); + } else if (auto doc_node = doc.as()) { + PrintTypedDoc(doc_node.value()); } else { LOG(FATAL) << "Do not know how to print " << doc->GetTypeKey(); throw; diff --git a/src/contrib/msc/core/printer/msc_base_printer.h b/src/contrib/msc/core/printer/msc_base_printer.h index 17cb218bfcba..af369a530dae 100644 --- a/src/contrib/msc/core/printer/msc_base_printer.h +++ b/src/contrib/msc/core/printer/msc_base_printer.h @@ -183,6 +183,20 @@ class MSCBasePrinter { LOG(FATAL) << "PointerDoc is not implemented"; } + /*! \brief Virtual method to print a StructDoc*/ + virtual void PrintTypedDoc(const StructDoc& doc) { LOG(FATAL) << "StructDoc is not implemented"; } + + /*! \brief Virtual method to print a ConstructorDoc*/ + virtual void PrintTypedDoc(const ConstructorDoc& doc) { + LOG(FATAL) << "ConstructorDoc is not implemented"; + } + + /*! \brief Virtual method to print a SwitchDoc*/ + virtual void PrintTypedDoc(const SwitchDoc& doc) { LOG(FATAL) << "SwitchDoc is not implemented"; } + + /*! \brief Virtual method to print a LambdaDoc*/ + virtual void PrintTypedDoc(const LambdaDoc& doc) { LOG(FATAL) << "LambdaDoc is not implemented"; } + /*! \brief Print docs to joined doc */ template void PrintJoinedDocs(const Array& docs, const String& separator = ", ") { diff --git a/src/contrib/msc/core/printer/msc_doc.cc b/src/contrib/msc/core/printer/msc_doc.cc index d5c94675266e..5497b7f9fe0a 100644 --- a/src/contrib/msc/core/printer/msc_doc.cc +++ b/src/contrib/msc/core/printer/msc_doc.cc @@ -52,6 +52,40 @@ PointerDoc::PointerDoc(String name) { this->data_ = std::move(n); } +StructDoc::StructDoc(IdDoc name, Array decorators, Array body) { + ObjectPtr n = make_object(); + n->name = name; + n->decorators = decorators; + n->body = body; + this->data_ = std::move(n); +} + +ConstructorDoc::ConstructorDoc(IdDoc name, Array args, Array body) { + ObjectPtr n = make_object(); + n->name = name; + n->args = args; + n->body = body; + this->data_ = std::move(n); +} + +SwitchDoc::SwitchDoc(Array predicates, Array> branchs, + Array default_branch) { + ObjectPtr n = make_object(); + n->predicates = predicates; + n->branchs = branchs; + n->default_branch = default_branch; + this->data_ = std::move(n); +} + +LambdaDoc::LambdaDoc(IdDoc name, Array args, Array refs, Array body) { + ObjectPtr n = make_object(); + n->name = name; + n->args = args; + n->refs = refs; + n->body = body; + this->data_ = std::move(n); +} + } // namespace msc } // namespace contrib } // namespace tvm diff --git a/src/contrib/msc/core/printer/msc_doc.h b/src/contrib/msc/core/printer/msc_doc.h index 440d2079ffca..3b83d6e22e0a 100644 --- a/src/contrib/msc/core/printer/msc_doc.h +++ b/src/contrib/msc/core/printer/msc_doc.h @@ -153,8 +153,192 @@ class PointerDoc : public ExprDoc { TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(PointerDoc, ExprDoc, PointerDocNode); }; +/*! + * \brief Doc that represents struct definition. + * + * \sa StructDoc + */ +class StructDocNode : public StmtDocNode { + public: + /*! \brief The name of class. */ + IdDoc name{nullptr}; + /*! \brief Decorators of class. */ + Array decorators; + /*! \brief The body of class. */ + Array body; + + void VisitAttrs(AttrVisitor* v) { + StmtDocNode::VisitAttrs(v); + v->Visit("name", &name); + v->Visit("decorators", &decorators); + v->Visit("body", &body); + } + + static constexpr const char* _type_key = "script.printer.StructDoc"; + TVM_DECLARE_FINAL_OBJECT_INFO(StructDocNode, StmtDocNode); +}; + +/*! + * \brief Reference type of StructDocNode. + * + * \sa StructDocNode + */ +class StructDoc : public StmtDoc { + public: + /*! + * \brief Constructor of StructDoc. + * \param name The name of class. + * \param decorators The decorator of class. + * \param body The body of class. + */ + explicit StructDoc(IdDoc name, Array decorators, Array body); + TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(StructDoc, StmtDoc, StructDocNode); +}; + +/*! + * \brief Doc that represents constructor definition. + * + * \sa ConstructorDoc + */ +class ConstructorDocNode : public StmtDocNode { + public: + /*! \brief The name of function. */ + IdDoc name{nullptr}; + /*! + * \brief The arguments of function. + * + * The `lhs` means argument name, + * `annotation` means argument type, + * and `rhs` means default value. + */ + Array args; + /*! \brief The body of function. */ + Array body; + + void VisitAttrs(AttrVisitor* v) { + StmtDocNode::VisitAttrs(v); + v->Visit("name", &name); + v->Visit("args", &args); + v->Visit("body", &body); + } + + static constexpr const char* _type_key = "script.printer.ConstructorDoc"; + TVM_DECLARE_FINAL_OBJECT_INFO(ConstructorDocNode, StmtDocNode); +}; + +/*! + * \brief Reference type of ConstructorDocNode. + * + * \sa ConstructorDocNode + */ +class ConstructorDoc : public StmtDoc { + public: + /*! + * \brief Constructor of ConstructorDoc. + * \param name The name of function.. + * \param args The arguments of function. + * \param body The body of function. + */ + explicit ConstructorDoc(IdDoc name, Array args, Array body); + TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(ConstructorDoc, StmtDoc, ConstructorDocNode); +}; + +/*! + * \brief Doc that represent switch statement. + * + * \sa SwitchDoc + */ +class SwitchDocNode : public StmtDocNode { + public: + /*! \brief The predicates of the switch statement. */ + Array predicates; + /*! \brief The branchs of the switch statement. */ + Array> branchs; + /*! \brief The default_branch of the switch statement. */ + Array default_branch; + + void VisitAttrs(AttrVisitor* v) { + StmtDocNode::VisitAttrs(v); + v->Visit("predicates", &predicates); + v->Visit("branchs", &branchs); + v->Visit("default_branch", &default_branch); + } + + static constexpr const char* _type_key = "script.printer.SwitchDoc"; + TVM_DECLARE_FINAL_OBJECT_INFO(SwitchDocNode, StmtDocNode); +}; + +/*! + * \brief Reference type of SwitchDocNode. + * + * \sa SwitchDocNode + */ +class SwitchDoc : public StmtDoc { + public: + /*! + * \brief Constructor of SwitchDoc. + * \param predicates The predicates of the switch statement. + * \param branchs The branchs of the switch statement. + * \param default_branch The default_branch of the switch statement. + */ + explicit SwitchDoc(Array predicates, Array> branchs, + Array default_branch); + TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(SwitchDoc, StmtDoc, SwitchDocNode); +}; + +/*! + * \brief Doc that represents lambda definition. + * + * \sa LambdaDoc + */ +class LambdaDocNode : public StmtDocNode { + public: + /*! \brief The name of lambda. */ + IdDoc name{nullptr}; + /*! + * \brief The arguments of lambda. + * + * The `lhs` means argument name, + * `annotation` means argument type, + * and `rhs` means default value. + */ + Array args; + /*! \brief References of lambda. */ + Array refs; + /*! \brief The body of lambda. */ + Array body; + + void VisitAttrs(AttrVisitor* v) { + StmtDocNode::VisitAttrs(v); + v->Visit("name", &name); + v->Visit("args", &args); + v->Visit("refs", &refs); + v->Visit("body", &body); + } + + static constexpr const char* _type_key = "script.printer.LambdaDoc"; + TVM_DECLARE_FINAL_OBJECT_INFO(LambdaDocNode, StmtDocNode); +}; + +/*! + * \brief Reference type of LambdaDocNode. + * + * \sa LambdaDoc + */ +class LambdaDoc : public StmtDoc { + public: + /*! + * \brief Constructor of LambdaDoc. + * \param name The name of lambda. + * \param args The arguments of lambda. + * \param refs The references of lambda. + * \param body The body of lambda. + */ + explicit LambdaDoc(IdDoc name, Array args, Array refs, Array body); + TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(LambdaDoc, StmtDoc, LambdaDocNode); +}; + } // namespace msc } // namespace contrib } // namespace tvm - #endif // TVM_CONTRIB_MSC_CORE_PRINTER_MSC_DOC_H_ diff --git a/src/contrib/msc/core/printer/print_utils.cc b/src/contrib/msc/core/printer/print_utils.cc index 6086e3ffa3c0..95df8da85cb2 100644 --- a/src/contrib/msc/core/printer/print_utils.cc +++ b/src/contrib/msc/core/printer/print_utils.cc @@ -28,6 +28,10 @@ namespace tvm { namespace contrib { namespace msc { +const String DocSymbol::Empty() { return "::EMPTY"; } + +const String DocSymbol::NextLine() { return "::NEXT_LINE"; } + const ExprDoc DocUtils::ToDoc(int64_t val) { return LiteralDoc::Int(val, NullOpt); } const ExprDoc DocUtils::ToDoc(int val) { return ToDoc(static_cast(val)); } @@ -52,28 +56,35 @@ const ExprDoc DocUtils::ToDoc(bool val) { return LiteralDoc::Boolean(val, NullOp const ExprDoc DocUtils::ToDoc(const ExprDoc& val) { return val; } -const ExprDoc DocUtils::ToStrDoc(const String& val) { return LiteralDoc::Str(val, NullOpt); } +const ExprDoc DocUtils::ToStr(const String& val) { return LiteralDoc::Str(val, NullOpt); } -const PointerDoc DocUtils::ToPtrDoc(const String& val) { return PointerDoc(val); } +const PointerDoc DocUtils::ToPtr(const String& val) { return PointerDoc(val); } -const DeclareDoc DocUtils::ToDeclareDoc(const String& type, const String& variable, size_t len, - bool use_constructor) { - Optional type_doc; - if (type.size() == 0) { - type_doc = NullOpt; - } else { - type_doc = IdDoc(type); +const StrictListDoc DocUtils::ToStrList(const std::vector& values, bool allow_empty) { + if (values.size() > 0 || allow_empty) { + Array elements; + for (const auto& v : values) { + elements.push_back(ToStr(v)); + } + return StrictListDoc(ListDoc(elements), allow_empty); } - if (len == 0) { - return DeclareDoc(type_doc, IdDoc(variable), Array(), use_constructor); + return StrictListDoc(ListDoc(), false); +} + +const StrictListDoc DocUtils::ToStrList(const std::vector& values, bool allow_empty) { + std::vector v_values; + for (const auto& v : values) { + v_values.push_back(v); } - Array doc_indices{DocUtils::ToDoc(len)}; - return DeclareDoc(type_doc, IndexDoc(IdDoc(variable), doc_indices), Array(), - use_constructor); + return ToStrList(v_values, allow_empty); } -const AttrAccessDoc DocUtils::ToAttrAccessDoc(const String& value, const String& name) { - return AttrAccessDoc(IdDoc(value), name); +const StrictListDoc DocUtils::ToStrList(const Array& values, bool allow_empty) { + std::vector v_values; + for (const auto& v : values) { + v_values.push_back(v); + } + return ToStrList(v_values, allow_empty); } const Array DocUtils::ToStmts(const Array& docs) { diff --git a/src/contrib/msc/core/printer/print_utils.h b/src/contrib/msc/core/printer/print_utils.h index acae1dbd4d7f..17d4b0b077a6 100644 --- a/src/contrib/msc/core/printer/print_utils.h +++ b/src/contrib/msc/core/printer/print_utils.h @@ -26,6 +26,7 @@ #include +#include #include #include "msc_doc.h" @@ -36,6 +37,19 @@ namespace msc { using namespace tvm::script::printer; +/*! + * \brief Symbols for Doc. + */ + +class DocSymbol { + public: + /*! * \brief The empty symbol*/ + TVM_DLL static const String Empty(); + + /*! * \brief The next line symbol*/ + TVM_DLL static const String NextLine(); +}; + /*! * \brief Utils for Doc. */ @@ -57,21 +71,68 @@ class DocUtils { TVM_DLL static const ExprDoc ToDoc(const String& val); TVM_DLL static const ExprDoc ToDoc(bool val); TVM_DLL static const ExprDoc ToDoc(const ExprDoc& val); - TVM_DLL static const ExprDoc ToStrDoc(const String& val); - TVM_DLL static const PointerDoc ToPtrDoc(const String& val); + TVM_DLL static const ExprDoc ToStr(const String& val); + TVM_DLL static const PointerDoc ToPtr(const String& val); /*! * \brief Change object to DeclareDoc. * \return The DeclareDoc. */ - TVM_DLL static const DeclareDoc ToDeclareDoc(const String& type, const String& variable, - size_t len = 0, bool use_constructor = true); + template + TVM_DLL static const DeclareDoc ToDeclare(const String& type, const T& variable, size_t len = 0, + bool use_constructor = true) { + Optional type_doc; + if (type.size() == 0) { + type_doc = NullOpt; + } else { + type_doc = IdDoc(type); + } + if (len == 0) { + return DeclareDoc(type_doc, ToDoc(variable), Array(), use_constructor); + } + Array doc_indices{DocUtils::ToDoc(len)}; + return DeclareDoc(type_doc, IndexDoc(ToDoc(variable), doc_indices), Array(), + use_constructor); + } + + /*! + * \brief Change object to AssignDoc. + * \return The AssignDoc. + */ + template + TVM_DLL static const AssignDoc ToAssign(const LT& lhs, const RT& rhs, + const String& annotation = "") { + if (annotation.size() == 0) { + return AssignDoc(ToDoc(lhs), ToDoc(rhs), NullOpt); + } + return AssignDoc(ToDoc(lhs), ToDoc(rhs), IdDoc(annotation)); + } + template + TVM_DLL static const AssignDoc ToAssign(const T& lhs, const String& rhs, + const String& annotation = "") { + Optional rhs_doc; + if (rhs.size() > 0) { + rhs_doc = IdDoc(rhs); + } else { + rhs_doc = NullOpt; + } + Optional annotation_doc; + if (annotation.size() > 0) { + annotation_doc = IdDoc(annotation); + } else { + annotation_doc = NullOpt; + } + return AssignDoc(ToDoc(lhs), rhs_doc, annotation_doc); + } /*! * \brief Change object to AttrAccessDoc. * \return The AttrAccessDoc. */ - TVM_DLL static const AttrAccessDoc ToAttrAccessDoc(const String& value, const String& name); + template + TVM_DLL static const AttrAccessDoc ToAttrAccess(const T& value, const String& name) { + return AttrAccessDoc(ToDoc(value), name); + } /*! * \brief Change object to List of Docs. @@ -87,11 +148,11 @@ class DocUtils { } template TVM_DLL static const Array ToDocList(const Array& values) { - Array elements; + std::vector v_values; for (const auto& v : values) { - elements.push_back(ToDoc(v)); + v_values.push_back(v); } - return elements; + return ToDocList(v_values); } /*! @@ -99,53 +160,58 @@ class DocUtils { * \return The ListDoc. */ template - TVM_DLL static const StrictListDoc ToListDoc(const std::vector& values, - bool allow_empty = false) { + TVM_DLL static const StrictListDoc ToList(const std::vector& values, + bool allow_empty = false) { if (values.size() > 0 || allow_empty) { return StrictListDoc(ListDoc(ToDocList(values)), allow_empty); } return StrictListDoc(ListDoc(), false); } template - TVM_DLL static const StrictListDoc ToListDoc(const Array& values, bool allow_empty = false) { - if (values.size() > 0 || allow_empty) { - return StrictListDoc(ListDoc(ToDocList(values)), allow_empty); + TVM_DLL static const StrictListDoc ToList(const Array& values, bool allow_empty = false) { + std::vector v_values; + for (const auto& v : values) { + v_values.push_back(v); } - return StrictListDoc(ListDoc(), false); + return ToList(v_values, allow_empty); } /*! - * \brief Change object to IndexDoc. + * \brief Change object to ListDoc for string elemenets. * \return The ListDoc. */ - template - TVM_DLL static const IndexDoc ToIndexDoc(const String& value, const std::vector& indices) { + TVM_DLL static const StrictListDoc ToStrList(const std::vector& values, + bool allow_empty = false); + TVM_DLL static const StrictListDoc ToStrList(const std::vector& values, + bool allow_empty = false); + TVM_DLL static const StrictListDoc ToStrList(const Array& values, + bool allow_empty = false); + + /*! + * \brief Change object to IndexDoc. + * \return The IndexDoc. + */ + template + TVM_DLL static const IndexDoc ToIndex(const VT& value, const IT& index) { + Array doc_indices; + doc_indices.push_back(ToDoc(index)); + return IndexDoc(ToDoc(value), doc_indices); + } + template + TVM_DLL static const IndexDoc ToIndices(const VT& value, const std::vector& indices) { Array doc_indices; for (const auto& i : indices) { doc_indices.push_back(ToDoc(i)); } - return IndexDoc(IdDoc(value), doc_indices); + return IndexDoc(ToDoc(value), doc_indices); } - template - TVM_DLL static const IndexDoc ToIndexDoc(const String& value, const Array& indices) { + template + TVM_DLL static const IndexDoc ToIndices(const VT& value, const Array& indices) { Array doc_indices; for (const auto& i : indices) { doc_indices.push_back(ToDoc(i)); } - return IndexDoc(IdDoc(value), doc_indices); - } - - /*! - * \brief Change object to AssignDoc. - * \return The AssignDoc. - */ - template - TVM_DLL static const AssignDoc ToAssignDoc(const String& lhs, const T& rhs, - const String& annotation = "") { - if (annotation.size() == 0) { - return AssignDoc(IdDoc(lhs), ToDoc(rhs), NullOpt); - } - return AssignDoc(IdDoc(lhs), ToDoc(rhs), IdDoc(annotation)); + return IndexDoc(ToDoc(value), doc_indices); } /*! diff --git a/src/contrib/msc/core/printer/python_printer.cc b/src/contrib/msc/core/printer/python_printer.cc index e272c04e984b..f1a13c7fd04f 100644 --- a/src/contrib/msc/core/printer/python_printer.cc +++ b/src/contrib/msc/core/printer/python_printer.cc @@ -23,6 +23,8 @@ #include "python_printer.h" +#include "../utils.h" + namespace tvm { namespace contrib { namespace msc { @@ -118,6 +120,28 @@ void PythonPrinter::PrintTypedDoc(const IfDoc& doc) { } } +void PythonPrinter::PrintTypedDoc(const ForDoc& doc) { + MaybePrintComment(doc, true); + if (doc->rhs->IsInstance()) { + const auto& tuple = Downcast(doc->rhs); + ICHECK_EQ(tuple->elements.size(), 2) << "For with tuple should has 2 elements"; + output_ << "for "; + PrintDoc(doc->lhs, false); + output_ << " in range("; + PrintDoc(tuple->elements[0], false); + output_ << ", "; + PrintDoc(tuple->elements[1], false); + output_ << "):"; + } else { + output_ << "for "; + PrintDoc(doc->lhs, false); + output_ << " in "; + PrintDoc(doc->rhs, false); + output_ << ":"; + } + PrintIndentedBlock(doc->body); +} + void PythonPrinter::PrintTypedDoc(const ScopeDoc& doc) { MaybePrintComment(doc, true); output_ << "with "; @@ -152,7 +176,11 @@ void PythonPrinter::PrintTypedDoc(const FunctionDoc& doc) { output_ << ":"; - MaybePrintComment(doc, true); + if (doc->comment.defined()) { + IncreaseIndent(); + MaybePrintComment(doc, true); + DecreaseIndent(); + } PrintIndentedBlock(doc->body); NewLine(false); } @@ -182,6 +210,44 @@ void PythonPrinter::PrintTypedDoc(const StrictListDoc& doc) { } } +void PythonPrinter::PrintTypedDoc(const SwitchDoc& doc) { + MaybePrintComment(doc, true); + ICHECK_EQ(doc->predicates.size(), doc->branchs.size()) + << "predicates " << doc->predicates.size() << " mismatch with branchs " + << doc->branchs.size(); + for (size_t i = 0; i < doc->predicates.size(); i++) { + if (i == 0) { + output_ << "if "; + } else { + NewLine(); + output_ << "elif "; + } + PrintDoc(doc->predicates[i], false); + output_ << ":"; + PrintIndentedBlock(doc->branchs[i]); + } + if (!doc->default_branch.empty()) { + NewLine(); + output_ << "else:"; + PrintIndentedBlock(doc->default_branch); + } +} + +void PythonPrinter::MaybePrintComment(const StmtDoc& stmt, bool multi_lines) { + if (stmt->comment.defined() && multi_lines) { + NewLine(); + output_ << "\"\"\""; + for (const auto& l : StringUtils::Split(stmt->comment.value(), "\n")) { + PrintDoc(IdDoc(l)); + } + NewLine(); + output_ << "\"\"\""; + NewLine(); + } else { + MSCBasePrinter::MaybePrintComment(stmt, multi_lines); + } +} + void PythonPrinter::PrintIndentedBlock(const Array& docs) { IncreaseIndent(); for (const StmtDoc& d : docs) { diff --git a/src/contrib/msc/core/printer/python_printer.h b/src/contrib/msc/core/printer/python_printer.h index ac053ee276fb..31f380bc87be 100644 --- a/src/contrib/msc/core/printer/python_printer.h +++ b/src/contrib/msc/core/printer/python_printer.h @@ -66,6 +66,9 @@ class PythonPrinter : public MSCBasePrinter { /*! * \brief Print a IfDoc to python format*/ void PrintTypedDoc(const IfDoc& doc) final; + /*! * \brief Print a ForDoc to python format*/ + void PrintTypedDoc(const ForDoc& doc) final; + /*! * \brief Print a ScopeDoc to python format*/ void PrintTypedDoc(const ScopeDoc& doc) final; @@ -78,9 +81,15 @@ class PythonPrinter : public MSCBasePrinter { /*! * \brief Print a CommentDoc to python format*/ void PrintTypedDoc(const CommentDoc& doc) final; - /*! \brief Virtual method to print a StrictListDoc*/ + /*! * \brief Print a StrictListDoc to python format*/ void PrintTypedDoc(const StrictListDoc& doc) final; + /*! * \brief Print a SwitchDoc to python format*/ + void PrintTypedDoc(const SwitchDoc& doc) final; + + /*! \brief Print comment for stmt in python format*/ + void MaybePrintComment(const StmtDoc& stmt, bool multi_lines = false) final; + private: /*! \brief Print block with indent*/ void PrintIndentedBlock(const Array& docs); diff --git a/src/contrib/msc/core/transform/layout_utils.cc b/src/contrib/msc/core/transform/layout_utils.cc index 4ec2a60314d6..4c9eeb2ec176 100644 --- a/src/contrib/msc/core/transform/layout_utils.cc +++ b/src/contrib/msc/core/transform/layout_utils.cc @@ -30,6 +30,32 @@ namespace tvm { namespace contrib { namespace msc { +NLayout LayoutUtils::InferNLayout(const Expr& expr, const VarLayoutMap& var_layout_map) { + if (expr->IsInstance() && var_layout_map.count(Downcast(expr))) { + return tvm::relax::GetNLayout(var_layout_map, expr); + } + return GetNLayout(expr); +} + +LayoutDecision LayoutUtils::InferLayoutDecision(const Expr& expr, + const VarLayoutMap& var_layout_map) { + const auto& nlayout = InferNLayout(expr, var_layout_map); + ICHECK(nlayout.IsLeaf()) << "Cannot get layout for " << expr; + return nlayout.LeafValue(); +} + +LayoutDecision LayoutUtils::InferLayoutDecisionAt(const Expr& expr, + const VarLayoutMap& var_layout_map, + size_t index) { + const auto& nlayouts = InferNLayout(expr, var_layout_map); + if (nlayouts.IsLeaf()) { + return index == 0 ? nlayouts.LeafValue() : LayoutDecision(""); + } + const auto& nlayout = nlayouts.NestedArray()[0]; + ICHECK(nlayout.IsLeaf()) << "Cannot get output layout for " << expr; + return nlayout.LeafValue(); +} + bool LayoutUtils::LayoutInfered(const Expr& expr) { const String& layout = SpanUtils::GetAttr(expr->span, "layout"); return layout.size() > 0; diff --git a/src/contrib/msc/core/transform/layout_utils.h b/src/contrib/msc/core/transform/layout_utils.h index b9de832838c5..7748f217d6ec 100644 --- a/src/contrib/msc/core/transform/layout_utils.h +++ b/src/contrib/msc/core/transform/layout_utils.h @@ -45,6 +45,27 @@ using namespace tvm::relax; */ class LayoutUtils { public: + /*! + * \brief Infer NLayout. + * \return The NLayout. + */ + TVM_DLL static NLayout InferNLayout(const Expr& expr, const VarLayoutMap& var_layout_map); + + /*! + * \brief Infer LayoutDecision. + * \return The LayoutDecision. + */ + TVM_DLL static LayoutDecision InferLayoutDecision(const Expr& expr, + const VarLayoutMap& var_layout_map); + + /*! + * \brief Infer LayoutDecision at given pos. + * \return The LayoutDecision. + */ + TVM_DLL static LayoutDecision InferLayoutDecisionAt(const Expr& expr, + const VarLayoutMap& var_layout_map, + size_t index = 0); + /*! * \brief Check if the layout is infered. * \return Whether the layout is infered. diff --git a/src/contrib/msc/core/transform/set_expr_layout.cc b/src/contrib/msc/core/transform/set_expr_layout.cc index d2023b886a0f..affb653a230d 100644 --- a/src/contrib/msc/core/transform/set_expr_layout.cc +++ b/src/contrib/msc/core/transform/set_expr_layout.cc @@ -34,30 +34,6 @@ namespace relax { using namespace tvm::contrib::msc; -NLayout InferNLayout(const Expr& expr, const VarLayoutMap& var_layout_map) { - if (expr->IsInstance() && var_layout_map.count(Downcast(expr))) { - return GetNLayout(var_layout_map, expr); - } - return LayoutUtils::GetNLayout(expr); -} - -LayoutDecision InferLayoutDecision(const Expr& expr, const VarLayoutMap& var_layout_map) { - const auto& nlayout = InferNLayout(expr, var_layout_map); - ICHECK(nlayout.IsLeaf()) << "Cannot get layout for " << expr; - return nlayout.LeafValue(); -} - -LayoutDecision InferLayoutDecisionAt(const Expr& expr, const VarLayoutMap& var_layout_map, - size_t index = 0) { - const auto& nlayouts = InferNLayout(expr, var_layout_map); - if (nlayouts.IsLeaf()) { - return index == 0 ? nlayouts.LeafValue() : LayoutDecision(""); - } - const auto& nlayout = nlayouts.NestedArray()[0]; - ICHECK(nlayout.IsLeaf()) << "Cannot get output layout for " << expr; - return nlayout.LeafValue(); -} - std::tuple AccumulateMatch(const std::vector& in_shape, const std::vector& out_shape, size_t in_start, size_t out_start) { @@ -228,7 +204,7 @@ InferLayoutOutput ForwardInferLayoutCommon(const Call& call, Array input_layouts; LayoutDecision layout_hint; for (const auto& arg : call->args) { - const auto& in_layout = InferLayoutDecision(arg, var_layout_map); + const auto& in_layout = LayoutUtils::InferLayoutDecision(arg, var_layout_map); if (in_layout->layout.defined()) { layout_hint = in_layout; } @@ -286,7 +262,7 @@ InferLayoutOutput ForwardInferLayoutInplace(const Call& call, InferLayoutOutput ForwardInferLayoutArgMaxMin(const Call& call, const Map>& desired_layouts, const VarLayoutMap& var_layout_map) { - LayoutDecision input_layout = InferLayoutDecision(call->args[0], var_layout_map); + LayoutDecision input_layout = LayoutUtils::InferLayoutDecision(call->args[0], var_layout_map); if (!input_layout->layout.defined()) { return InferLayoutOutput(); } @@ -318,7 +294,7 @@ InferLayoutOutput ForwardInferLayoutBatchNorm(const Call& call, if (input_shape.size() == 0) { return InferLayoutOutput(); } - LayoutDecision in_layout = InferLayoutDecision(call->args[0], var_layout_map); + LayoutDecision in_layout = LayoutUtils::InferLayoutDecision(call->args[0], var_layout_map); if (!in_layout->layout.defined()) { if (input_shape.size() == 4) { in_layout = LayoutDecision("NCHW"); @@ -334,7 +310,7 @@ InferLayoutOutput ForwardInferLayoutBatchNorm(const Call& call, InferLayoutOutput ForkwardInferLayoutExpandDims(const Call& call, const Map>& desired_layouts, const VarLayoutMap& var_layout_map) { - LayoutDecision input_layout = InferLayoutDecision(call->args[0], var_layout_map); + LayoutDecision input_layout = LayoutUtils::InferLayoutDecision(call->args[0], var_layout_map); if (!input_layout->layout.defined()) { return InferLayoutOutput(); } @@ -362,7 +338,7 @@ InferLayoutOutput ForwardInferLayoutNormalize(const Call& call, if (input_shape.size() == 0) { return InferLayoutOutput(); } - LayoutDecision in_layout = InferLayoutDecision(call->args[0], var_layout_map); + LayoutDecision in_layout = LayoutUtils::InferLayoutDecision(call->args[0], var_layout_map); if (!in_layout->layout.defined()) { if (input_shape.size() == 4) { in_layout = LayoutDecision("NCHW"); @@ -386,7 +362,7 @@ InferLayoutOutput ForwardInferLayoutMatmul(const Call& call, if (a_shape.size() == 0) { return InferLayoutOutput(); } - LayoutDecision a_layout = InferLayoutDecision(call->args[0], var_layout_map); + LayoutDecision a_layout = LayoutUtils::InferLayoutDecision(call->args[0], var_layout_map); if (!a_layout->layout.defined()) { if (a_shape.size() == 4) { a_layout = LayoutDecision("NCHW"); @@ -408,7 +384,7 @@ InferLayoutOutput ForwardInferLayoutMatmul(const Call& call, InferLayoutOutput ForwardInferLayoutPermute(const Call& call, const Map>& desired_layouts, const VarLayoutMap& var_layout_map) { - LayoutDecision input_layout = InferLayoutDecision(call->args[0], var_layout_map); + LayoutDecision input_layout = LayoutUtils::InferLayoutDecision(call->args[0], var_layout_map); if (!input_layout->layout.defined()) { return InferLayoutOutput(); } @@ -430,7 +406,7 @@ InferLayoutOutput ForwardInferLayoutPermute(const Call& call, InferLayoutOutput ForwardInferLayoutReduceAxis(const Call& call, const Map>& desired_layouts, const VarLayoutMap& var_layout_map) { - LayoutDecision input_layout = InferLayoutDecision(call->args[0], var_layout_map); + LayoutDecision input_layout = LayoutUtils::InferLayoutDecision(call->args[0], var_layout_map); if (!input_layout->layout.defined()) { return InferLayoutOutput(); } @@ -458,7 +434,7 @@ InferLayoutOutput ForwardInferLayoutReduceAxis(const Call& call, InferLayoutOutput ForwardInferLayoutReshape(const Call& call, const Map>& desired_layouts, const VarLayoutMap& var_layout_map) { - LayoutDecision input_layout = InferLayoutDecision(call->args[0], var_layout_map); + LayoutDecision input_layout = LayoutUtils::InferLayoutDecision(call->args[0], var_layout_map); if (!input_layout->layout.defined()) { return InferLayoutOutput(); } @@ -492,7 +468,7 @@ InferLayoutOutput ForwardInferLayoutReshape(const Call& call, InferLayoutOutput ForwardInferLayoutSqueeze(const Call& call, const Map>& desired_layouts, const VarLayoutMap& var_layout_map) { - LayoutDecision input_layout = InferLayoutDecision(call->args[0], var_layout_map); + LayoutDecision input_layout = LayoutUtils::InferLayoutDecision(call->args[0], var_layout_map); if (!input_layout->layout.defined()) { return InferLayoutOutput(); } @@ -525,7 +501,7 @@ InferLayoutOutput ForwardInferLayoutSqueeze(const Call& call, InferLayoutOutput ForwardInferLayoutTake(const Call& call, const Map>& desired_layouts, const VarLayoutMap& var_layout_map) { - LayoutDecision input_layout = InferLayoutDecision(call->args[1], var_layout_map); + LayoutDecision input_layout = LayoutUtils::InferLayoutDecision(call->args[1], var_layout_map); if (!input_layout->layout.defined()) { return InferLayoutOutput(); } @@ -631,7 +607,7 @@ TVM_REGISTER_OP("relax.nn.layer_norm") InferLayoutOutput BackwardInferLayoutCommon(const Call& call, const Map>& desired_layouts, const VarLayoutMap& var_layout_map) { - NLayout output_layout = InferNLayout(call, var_layout_map); + NLayout output_layout = LayoutUtils::InferNLayout(call, var_layout_map); LayoutDecision layout_hint; if (output_layout.IsLeaf()) { layout_hint = output_layout.LeafValue(); @@ -647,7 +623,7 @@ InferLayoutOutput BackwardInferLayoutCommon(const Call& call, } Array input_layouts; for (const auto& arg : call->args) { - const auto& saved_layout = InferLayoutDecision(arg, var_layout_map); + const auto& saved_layout = LayoutUtils::InferLayoutDecision(arg, var_layout_map); if (saved_layout->layout.defined()) { input_layouts.push_back(saved_layout); } else { @@ -692,7 +668,7 @@ InferLayoutOutput BackwardInferLayoutInplace(const Call& call, InferLayoutOutput BackwardInferLayoutArgMaxMin(const Call& call, const Map>& desired_layouts, const VarLayoutMap& var_layout_map) { - LayoutDecision output_layout = InferLayoutDecision(call, var_layout_map); + LayoutDecision output_layout = LayoutUtils::InferLayoutDecision(call, var_layout_map); if (!output_layout->layout.defined()) { return InferLayoutOutput(); } @@ -715,7 +691,7 @@ InferLayoutOutput BackwardInferLayoutArgMaxMin(const Call& call, InferLayoutOutput BackwardInferLayoutBatchNorm(const Call& call, const Map>& desired_layouts, const VarLayoutMap& var_layout_map) { - LayoutDecision output_layout = InferLayoutDecisionAt(call, var_layout_map, 0); + LayoutDecision output_layout = LayoutUtils::InferLayoutDecisionAt(call, var_layout_map, 0); if (!output_layout->layout.defined()) { return InferLayoutOutput(); } @@ -727,7 +703,7 @@ InferLayoutOutput BackwardInferLayoutBatchNorm(const Call& call, InferLayoutOutput BackwardInferLayoutExpandDims(const Call& call, const Map>& desired_layouts, const VarLayoutMap& var_layout_map) { - LayoutDecision output_layout = InferLayoutDecision(call, var_layout_map); + LayoutDecision output_layout = LayoutUtils::InferLayoutDecision(call, var_layout_map); if (!output_layout->layout.defined()) { return InferLayoutOutput(); } @@ -749,7 +725,7 @@ InferLayoutOutput BackwardInferLayoutExpandDims(const Call& call, InferLayoutOutput BackwardInferLayoutNormalize(const Call& call, const Map>& desired_layouts, const VarLayoutMap& var_layout_map) { - LayoutDecision output_layout = InferLayoutDecisionAt(call, var_layout_map, 0); + LayoutDecision output_layout = LayoutUtils::InferLayoutDecisionAt(call, var_layout_map, 0); if (!output_layout->layout.defined()) { return InferLayoutOutput(); } @@ -760,7 +736,7 @@ InferLayoutOutput BackwardInferLayoutNormalize(const Call& call, InferLayoutOutput BackwardInferLayoutMatmul(const Call& call, const Map>& desired_layouts, const VarLayoutMap& var_layout_map) { - LayoutDecision output_layout = InferLayoutDecision(call, var_layout_map); + LayoutDecision output_layout = LayoutUtils::InferLayoutDecision(call, var_layout_map); if (!output_layout->layout.defined()) { return InferLayoutOutput(); } @@ -782,7 +758,7 @@ InferLayoutOutput BackwardInferLayoutMatmul(const Call& call, InferLayoutOutput BackwardInferLayoutPermute(const Call& call, const Map>& desired_layouts, const VarLayoutMap& var_layout_map) { - LayoutDecision output_layout = InferLayoutDecision(call, var_layout_map); + LayoutDecision output_layout = LayoutUtils::InferLayoutDecision(call, var_layout_map); if (!output_layout->layout.defined()) { return InferLayoutOutput(); } @@ -813,7 +789,7 @@ InferLayoutOutput BackwardInferLayoutPermute(const Call& call, InferLayoutOutput BackwardInferLayoutReduceAxis(const Call& call, const Map>& desired_layouts, const VarLayoutMap& var_layout_map) { - LayoutDecision output_layout = InferLayoutDecision(call, var_layout_map); + LayoutDecision output_layout = LayoutUtils::InferLayoutDecision(call, var_layout_map); if (!output_layout->layout.defined()) { return InferLayoutOutput(); } @@ -838,7 +814,7 @@ InferLayoutOutput BackwardInferLayoutReduceAxis(const Call& call, InferLayoutOutput BackwardInferLayoutReshape(const Call& call, const Map>& desired_layouts, const VarLayoutMap& var_layout_map) { - LayoutDecision output_layout = InferLayoutDecision(call, var_layout_map); + LayoutDecision output_layout = LayoutUtils::InferLayoutDecision(call, var_layout_map); if (!output_layout->layout.defined()) { return InferLayoutOutput(); } @@ -872,7 +848,7 @@ InferLayoutOutput BackwardInferLayoutReshape(const Call& call, InferLayoutOutput BackwardInferLayoutSqueeze(const Call& call, const Map>& desired_layouts, const VarLayoutMap& var_layout_map) { - LayoutDecision output_layout = InferLayoutDecision(call, var_layout_map); + LayoutDecision output_layout = LayoutUtils::InferLayoutDecision(call, var_layout_map); if (!output_layout->layout.defined()) { return InferLayoutOutput(); } @@ -905,7 +881,7 @@ InferLayoutOutput BackwardInferLayoutSqueeze(const Call& call, InferLayoutOutput BackwardInferLayoutTake(const Call& call, const Map>& desired_layouts, const VarLayoutMap& var_layout_map) { - LayoutDecision output_layout = InferLayoutDecision(call, var_layout_map); + LayoutDecision output_layout = LayoutUtils::InferLayoutDecision(call, var_layout_map); if (!output_layout->layout.defined()) { return InferLayoutOutput(); } @@ -915,7 +891,7 @@ InferLayoutOutput BackwardInferLayoutTake(const Call& call, InferLayoutOutput BackwardInferLayoutTupleInputs(const Call& call, const Map>& desired_layouts, const VarLayoutMap& var_layout_map) { - LayoutDecision output_layout = InferLayoutDecision(call, var_layout_map); + LayoutDecision output_layout = LayoutUtils::InferLayoutDecision(call, var_layout_map); if (!output_layout->layout.defined()) { return InferLayoutOutput(); } @@ -1196,7 +1172,7 @@ class LayoutInfer : public ExprVisitor { if (IsNestedTensor(binding->var)) { Array input_layouts; for (const auto& field : val->fields) { - input_layouts.push_back(InferLayoutDecision(field, var_layout_map_)); + input_layouts.push_back(LayoutUtils::InferLayoutDecision(field, var_layout_map_)); } SetExprLayout(binding->var, input_layouts); } @@ -1205,8 +1181,8 @@ class LayoutInfer : public ExprVisitor { void VisitBinding_(const VarBindingNode* binding, const TupleGetItemNode* val) final { ExprVisitor::VisitBinding_(binding, val); RecordExpr(binding->var, GetRef(val)); - const auto& out_layout = - InferLayoutDecisionAt(GetRef(val)->tuple, var_layout_map_, val->index); + const auto& out_layout = LayoutUtils::InferLayoutDecisionAt(GetRef(val)->tuple, + var_layout_map_, val->index); SetExprLayout(binding->var, out_layout); } diff --git a/src/contrib/msc/core/utils.cc b/src/contrib/msc/core/utils.cc index 663de50b0b8b..bf5f36b4605a 100644 --- a/src/contrib/msc/core/utils.cc +++ b/src/contrib/msc/core/utils.cc @@ -23,6 +23,7 @@ #include "utils.h" +#include #include namespace tvm { namespace contrib { @@ -207,6 +208,18 @@ const String StringUtils::GetClosureOnce(const String& src_string, const String& return val; } +const String StringUtils::Upper(const String& src_string) { + std::string str = std::string(src_string); + std::transform(str.begin(), str.end(), str.begin(), ::toupper); + return str; +} + +const String StringUtils::Lower(const String& src_string) { + std::string str = std::string(src_string); + std::transform(str.begin(), str.end(), str.begin(), ::tolower); + return str; +} + const String StringUtils::ToString(const runtime::ObjectRef& obj) { String obj_string; if (!obj.defined()) { @@ -224,6 +237,8 @@ const String StringUtils::ToString(const runtime::ObjectRef& obj) { obj_string = obj_string + ","; } } + } else if (const auto* n = obj.as()) { + obj_string = ToString(n->value); } else { std::ostringstream obj_des; obj_des << obj; diff --git a/src/contrib/msc/core/utils.h b/src/contrib/msc/core/utils.h index 95444c354d13..43eb7bd33d9e 100644 --- a/src/contrib/msc/core/utils.h +++ b/src/contrib/msc/core/utils.h @@ -39,6 +39,9 @@ using Expr = tvm::RelayExpr; using RelaxCall = tvm::relax::Call; using RelayCall = tvm::relay::Call; +/*! + * \brief Utils for Common. + */ class CommonUtils { public: /*! @@ -127,6 +130,18 @@ class StringUtils { TVM_DLL static const String GetClosureOnce(const String& src_string, const String& left, const String& right, bool from_left = true); + /*! + * \brief Change string to upper. + * \return The String. + */ + TVM_DLL static const String Upper(const String& src_string); + + /*! + * \brief Change string to lower. + * \return The String. + */ + TVM_DLL static const String Lower(const String& src_string); + /*! * \brief Change Object to String. * \return The String. @@ -194,6 +209,29 @@ class ArrayUtils { } return new_array; } + + template + TVM_DLL static const Array> Product(const Array>& arrays) { + Array> p_arrays; + if (arrays.size() == 1) { + for (const auto& a : arrays[0]) { + p_arrays.push_back(Array{a}); + } + return p_arrays; + } + Array> sub_arrays; + for (size_t i = 0; i < arrays.size() - 1; i++) { + sub_arrays.push_back(arrays[i]); + } + for (const auto& p_array : Product(sub_arrays)) { + for (const auto& a : arrays[arrays.size() - 1]) { + Array sub_array = p_array; + sub_array.push_back(a); + p_arrays.push_back(sub_array); + } + } + return p_arrays; + } }; /*! diff --git a/src/contrib/msc/framework/tensorflow/codegen.cc b/src/contrib/msc/framework/tensorflow/codegen.cc index b99b1a136334..9e437f705c34 100644 --- a/src/contrib/msc/framework/tensorflow/codegen.cc +++ b/src/contrib/msc/framework/tensorflow/codegen.cc @@ -43,7 +43,9 @@ void TensorflowCodeGen::CodeGenHelper() { .cond_if("name in weights") .func_call("tf_v1.get_variable", "var") .call_arg("name") - .call_arg("weights[name].asnumpy()", "initializer") + .inplace_start("asnumpy", DocUtils::ToDoc("initializer"), + DocUtils::ToIndex("weights", "name")) + .inplace_end() .cond_else() .func_call("tf_v1.get_variable", "var") .call_arg("name") @@ -69,9 +71,9 @@ void TensorflowCodeGen::CodeGenGraph() { } for (const auto& pair : node->weights) { stack_.func_call("get_variable", IdxWeightBase(node, pair.first, false)) - .call_arg(DocUtils::ToStrDoc(pair.second->name)) - .call_arg(DocUtils::ToListDoc(pair.second->shape, true)) - .call_arg(DocUtils::ToStrDoc(pair.second->DTypeName())) + .call_arg(DocUtils::ToStr(pair.second->name)) + .call_arg(DocUtils::ToList(pair.second->shape, true)) + .call_arg(DocUtils::ToStr(pair.second->DTypeName())) .call_arg("weights"); } } @@ -92,7 +94,7 @@ void TensorflowCodeGen::CodeGenGraph() { if (idx_outputs.size() == 1) { stack_.assign("outputs", idx_outputs[0]); } else { - stack_.assign("outputs", DocUtils::ToListDoc(idx_outputs)); + stack_.assign("outputs", DocUtils::ToList(idx_outputs)); } stack_.func_end("outputs"); } @@ -101,7 +103,7 @@ void TensorflowCodeGen::CodeGenInference() { stack_.comment("Load weights") .scope_start("open(\"" + graph()->name + "_params.bin\", \"rb\")", "f") .func_call("tvm.runtime.load_param_dict", "params") - .func_call("f.read") + .func_call("read", "", "f") .pop_nest() .scope_end() .comment("Build Graph") @@ -109,9 +111,9 @@ void TensorflowCodeGen::CodeGenInference() { for (const auto& i : graph()->GetInputs()) { const auto& producer = graph()->FindProducer(i); stack_.func_call("tf_v1.placeholder", IdxNodeBase(producer)) - .call_arg(DocUtils::ToStrDoc(i->DTypeName())) - .call_arg(DocUtils::ToListDoc(i->shape)) - .call_arg(DocUtils::ToStrDoc(i->alias)); + .call_arg(DocUtils::ToStr(i->DTypeName())) + .call_arg(DocUtils::ToList(i->shape)) + .call_arg(DocUtils::ToStr(i->alias)); } stack_.func_call(graph()->name, "outs"); for (const auto& i : graph()->GetInputs()) { @@ -121,13 +123,14 @@ void TensorflowCodeGen::CodeGenInference() { stack_.call_arg("params").assign("feed_dict", "{}"); for (const auto& i : graph()->GetInputs()) { const auto& producer = graph()->FindProducer(i); - stack_.assign("feed_dict[" + IdxNodeBase(producer) + "]", "inputs[\"" + i->alias + "\"]"); + stack_.assign(DocUtils::ToIndex("feed_dict", IdxNodeBase(producer)), + DocUtils::ToIndex("inputs", DocUtils::ToStr(i->alias))); } stack_.scope_start("tf_v1.Session()", "sess") - .func_call("sess.run") + .func_call("run", "", "sess") .func_call("ops.variables.global_variables_initializer") .pop_nest() - .func_call("sess.run", "outputs") + .func_call("run", "outputs", "sess") .call_arg("outs") .call_arg("feed_dict", "feed_dict") .scope_end() @@ -149,7 +152,7 @@ const Array TensorflowCodeGen::GetOpCodes(const MSCJoint& node) { TVM_REGISTER_GLOBAL("msc.framework.tensorflow.GetTensorflowSources") .set_body_typed([](const MSCGraph& graph, const String& codegen_config, - const String print_config) -> Map { + const String& print_config) -> Map { TensorflowCodeGen codegen = TensorflowCodeGen(graph, codegen_config); return codegen.GetSources(print_config); }); diff --git a/src/contrib/msc/framework/tensorflow/tf_v1_opcode.cc b/src/contrib/msc/framework/tensorflow/tf_v1_opcode.cc index 56dba21ac8a2..9c8f0aeb860a 100644 --- a/src/contrib/msc/framework/tensorflow/tf_v1_opcode.cc +++ b/src/contrib/msc/framework/tensorflow/tf_v1_opcode.cc @@ -167,20 +167,17 @@ class TFV1BatchnormCodeGen : public TFV1OpCode { .op_arg("scale") .op_arg("center") .op_arg("momentum") - .op_arg("epsilon") - .call_arg("tf_v1.constant_initializer(weights[\"" + node()->WeightAt("gamma")->name + - "\"].asnumpy())", - "gamma_initializer") - .call_arg("tf_v1.constant_initializer(weights[\"" + node()->WeightAt("beta")->name + - "\"].asnumpy())", - "beta_initializer") - .call_arg("tf_v1.constant_initializer(weights[\"" + node()->WeightAt("mean")->name + - "\"].asnumpy())", - "moving_mean_initializer") - .call_arg("tf_v1.constant_initializer(weights[\"" + node()->WeightAt("var")->name + - "\"].asnumpy())", - "moving_variance_initializer") - .op_name_arg(); + .op_arg("epsilon"); + Array weight_names{"gamma", "beta", "mean", "var"}; + Array init_names{"gamma", "beta", "moving_mean", "moving_variance"}; + for (size_t i = 0; i < weight_names.size(); i++) { + const auto& w_doc = DocUtils::ToStr(node()->WeightAt(weight_names[i])->name); + stack_.inplace_start("tf_v1.constant_initializer", init_names[i] + "_initializer") + .inplace_start("asnumpy", NullOpt, DocUtils::ToIndex("weights", w_doc)) + .inplace_end() + .inplace_end(); + } + stack_.op_name_arg(); } }; @@ -252,13 +249,13 @@ class TFV1ConvCodeGen : public TFV1OpCode { } stack_.op_input_arg() .op_weight_arg("weight") - .call_arg(DocUtils::ToListDoc(strides), "strides") - .call_arg(DocUtils::ToListDoc(dilation), "dilations") + .call_arg(DocUtils::ToList(strides), "strides") + .call_arg(DocUtils::ToList(dilation), "dilations") .op_str_arg("data_layout", "data_format"); if (pair.first.size() > 0) { - stack_.call_arg(DocUtils::ToStrDoc(pair.first), "padding"); + stack_.call_arg(DocUtils::ToStr(pair.first), "padding"); } else if (pair.second.size() > 0) { - stack_.call_arg(DocUtils::ToListDoc(pair.second), "padding"); + stack_.call_arg(DocUtils::ToList(pair.second), "padding"); } else { LOG_FATAL << "Can not parse padding for " << node(); } @@ -290,7 +287,7 @@ class TFV1EinsumCodeGen : public TFV1OpCode { const auto& producer = node()->ProducerOf(0); stack_.op_call().op_str_arg("subscripts", ""); if (node()->inputs.size() == 1 && producer->optype == "tuple") { - stack_.call_arg(DocUtils::ToIndexDoc(IdxInput(), std::vector{0})); + stack_.call_arg(DocUtils::ToIndex(IdxInput(), 0)); } else { stack_.op_inputs_arg(false); } @@ -340,8 +337,8 @@ class TFV1PadCodeGen : public TFV1OpCode { ICHECK(val_producer->optype == "constant" && val_producer->HasAttr("scalar")); stack_.op_call() .op_input_arg() - .call_arg(DocUtils::ToListDoc(pad_width), "paddings") - .call_arg(DocUtils::ToStrDoc(mode), "mode") + .call_arg(DocUtils::ToList(pad_width), "paddings") + .call_arg(DocUtils::ToStr(mode), "mode") .call_arg(val_producer->GetTypeAttr("scalar"), "constant_values") .op_name_arg(); } @@ -364,13 +361,13 @@ class TFV1Pool2dCodeGen : public TFV1OpCode { stack_.op_call() .op_input_arg() .op_list_arg("pool_size", "window_shape") - .call_arg(DocUtils::ToStrDoc(pooling_type), "pooling_type") + .call_arg(DocUtils::ToStr(pooling_type), "pooling_type") .op_list_arg("dilation", "dilation_rate") .op_list_arg("strides"); if (pair.first.size() > 0) { - stack_.call_arg(DocUtils::ToStrDoc(pair.first), "padding"); + stack_.call_arg(DocUtils::ToStr(pair.first), "padding"); } else if (pair.second.size() > 0) { - stack_.call_arg(DocUtils::ToListDoc(pair.second), "padding"); + stack_.call_arg(DocUtils::ToList(pair.second), "padding"); } else { LOG_FATAL << "Can not parse padding for " << node(); } @@ -389,7 +386,7 @@ class TFV1PermuteDimsCodeGen : public TFV1OpCode { axes.push_back(i - 1); } } - stack_.op_call().op_input_arg().call_arg(DocUtils::ToListDoc(axes)).op_name_arg(); + stack_.op_call().op_input_arg().call_arg(DocUtils::ToList(axes)).op_name_arg(); } }; @@ -454,7 +451,7 @@ class TFV1SplitCodeGen : public TFV1OpCode { for (size_t i = 0; i < node()->outputs.size(); i++) { indices.push_back(node()->OutputAt(i)->DimAt(axis)->value); } - stack_.call_arg(DocUtils::ToListDoc(indices), "num_or_size_splits") + stack_.call_arg(DocUtils::ToList(indices), "num_or_size_splits") .op_arg("axis") .op_name_arg(); } diff --git a/src/contrib/msc/framework/tensorrt/codegen.cc b/src/contrib/msc/framework/tensorrt/codegen.cc index 4cba1bdce3d8..c3d659015c18 100644 --- a/src/contrib/msc/framework/tensorrt/codegen.cc +++ b/src/contrib/msc/framework/tensorrt/codegen.cc @@ -68,7 +68,7 @@ void TensorRTCodeGen::CodeGenClassDeclare() { // private scope stack_.scope_start("private:").declare("std::map", "mWeights").scope_end(); // end class declare - stack_.class_end().line(); + stack_.class_end(); // declare test function stack_.func_def("test_" + graph()->name, "bool") .func_arg("engine", "std::shared_ptr&") @@ -83,15 +83,15 @@ void TensorRTCodeGen::CodeGenClassDefine() { auto malloc_buffer = [this](const MSCTensor& tensor) { const String& idx_var = "idx_" + IdxTensor(tensor); this->stack_ - .func_call("getBindingIndex", DocUtils::ToDeclareDoc("int", idx_var), - DocUtils::ToPtrDoc("engine")) - .call_arg(DocUtils::ToStrDoc(tensor->name)) + .func_call("getBindingIndex", DocUtils::ToDeclare("int", idx_var), + DocUtils::ToPtr("engine")) + .call_arg(DocUtils::ToStr(tensor->name)) .func_call("CHECK") .func_call("cudaMalloc") - .call_arg("&gpu_buffers[" + idx_var + "]") + .call_arg(DocUtils::ToIndex("&gpu_buffers", idx_var)) .call_arg(GetTensorBytes(tensor)) .pop_nest() - .func_call("malloc", "cpu_buffers[" + idx_var + "]") + .func_call("malloc", DocUtils::ToIndex("cpu_buffers", idx_var)) .call_arg(GetTensorBytes(tensor)); }; stack_.line("#include \"" + graph()->name + ".h\"").line(); @@ -111,7 +111,8 @@ void TensorRTCodeGen::CodeGenClassDefine() { before_build_codes_ = (*pf)(GetStepCtx(), "before_build", graph()->name, config()->tools_tag); } if (graph()->weight_holders.size() > 0) { - stack_.assign("mWeights", "TRTUtils::LoadWeights(\"" + graph()->name + ".wts\")"); + stack_.func_call("TRTUtils::LoadWeights", "mWeights") + .call_arg(DocUtils::ToStr(graph()->name + ".wts")); } // build layers for (const auto& n : graph()->node_names) { @@ -122,18 +123,18 @@ void TensorRTCodeGen::CodeGenClassDefine() { stack_.comment("Mark outputs"); for (const auto& o : graph()->GetOutputs()) { const auto& pair = graph()->FindProducerAndIdx(o); - stack_.func_call("markOutput", NullOpt, DocUtils::ToPtrDoc("network")) + stack_.func_call("markOutput", NullOpt, DocUtils::ToPtr("network")) .call_arg("*" + IdxOutputBase(pair.first, pair.second)); } // mark batch_size stack_.comment("Mark batch size"); - stack_.func_call("createOptimizationProfile", DocUtils::ToDeclareDoc("auto", "profile"), - DocUtils::ToPtrDoc("builder")); + stack_.func_call("createOptimizationProfile", DocUtils::ToDeclare("auto", "profile"), + DocUtils::ToPtr("builder")); Array batch_flags{"MIN", "MAX", "OPT"}; for (const auto& i : graph()->GetInputs()) { for (const auto& f : batch_flags) { - stack_.func_call("setDimensions", NullOpt, DocUtils::ToPtrDoc("profile")) - .call_arg(DocUtils::ToStrDoc(i->name)) + stack_.func_call("setDimensions", NullOpt, DocUtils::ToPtr("profile")) + .call_arg(DocUtils::ToStr(i->name)) .call_arg("OptProfileSelector::k" + f) .call_arg(ToDims(i->shape)); } @@ -141,52 +142,52 @@ void TensorRTCodeGen::CodeGenClassDefine() { // set max workspace stack_.comment("Set max worksapce"); if (CompareVersion(6, 0, 0) >= 0) { - stack_.func_call("setMaxWorkspaceSize", NullOpt, DocUtils::ToPtrDoc("config")) + stack_.func_call("setMaxWorkspaceSize", NullOpt, DocUtils::ToPtr("config")) .call_arg(config()->max_workspace); } else { - stack_.func_call("setMaxWorkspaceSize", NullOpt, DocUtils::ToPtrDoc("builder")) + stack_.func_call("setMaxWorkspaceSize", NullOpt, DocUtils::ToPtr("builder")) .call_arg(config()->max_workspace); } // set data type if (config()->precision == "float16") { stack_.comment("Set network precision") .cond_if("!builder->platformHasFastFp16()") - .func_call("logger.log") + .func_call("log", "", "logger") .call_arg("ILogger::Severity::kINTERNAL_ERROR") - .call_arg(DocUtils::ToStrDoc("platform do not support float16, fallback to float32")) + .call_arg(DocUtils::ToStr("platform do not support float16, fallback to float32")) .cond_else() - .func_call("setFlag", NullOpt, DocUtils::ToPtrDoc("config")) + .func_call("setFlag", NullOpt, DocUtils::ToPtr("config")) .call_arg("BuilderFlag::kFP16"); if (config()->precision_mode == "strict") { - stack_.func_call("setFlag", NullOpt, DocUtils::ToPtrDoc("config")) + stack_.func_call("setFlag", NullOpt, DocUtils::ToPtr("config")) .call_arg("BuilderFlag::kSTRICT_TYPES"); } - stack_.func_call("logger.log") + stack_.func_call("log", "", "logger") .call_arg("ILogger::Severity::kINFO") - .call_arg(DocUtils::ToStrDoc("use float16 to build the engine")) + .call_arg(DocUtils::ToStr("use float16 to build the engine")) .cond_end(); } else if (config()->precision == "int8") { stack_.comment("Set network precision") .cond_if("!builder->platformHasFastInt8()") - .func_call("logger.log") + .func_call("log", "", "logger") .call_arg("ILogger::Severity::kINTERNAL_ERROR") - .call_arg(DocUtils::ToStrDoc("platform do not support int8, fallback to float32")) + .call_arg(DocUtils::ToStr("platform do not support int8, fallback to float32")) .cond_else() - .func_call("setFlag", NullOpt, DocUtils::ToPtrDoc("config")) + .func_call("setFlag", NullOpt, DocUtils::ToPtr("config")) .call_arg("BuilderFlag::kINT8"); if (config()->precision_mode == "strict") { - stack_.func_call("setFlag", NullOpt, DocUtils::ToPtrDoc("config")) + stack_.func_call("setFlag", NullOpt, DocUtils::ToPtr("config")) .call_arg("BuilderFlag::kSTRICT_TYPES"); } else if (config()->precision_mode == "prefer") { - stack_.func_call("setFlag", NullOpt, DocUtils::ToPtrDoc("config")) + stack_.func_call("setFlag", NullOpt, DocUtils::ToPtr("config")) .call_arg("BuilderFlag::kPREFER_PRECISION_CONSTRAINTS"); } else if (config()->precision_mode == "obey") { - stack_.func_call("setFlag", NullOpt, DocUtils::ToPtrDoc("config")) + stack_.func_call("setFlag", NullOpt, DocUtils::ToPtr("config")) .call_arg("BuilderFlag::kOBEY_PRECISION_CONSTRAINTS"); } - stack_.func_call("logger.log") + stack_.func_call("log", "", "logger") .call_arg("ILogger::Severity::kINFO") - .call_arg(DocUtils::ToStrDoc("use int8 to build the engine")) + .call_arg(DocUtils::ToStr("use int8 to build the engine")) .cond_end(); } // save codegen after build @@ -204,8 +205,8 @@ void TensorRTCodeGen::CodeGenClassDefine() { .func_arg("logger", "TRTLogger&") .func_start(); stack_.comment("Create context") - .func_call("TRTPtr", DocUtils::ToDeclareDoc("auto", "context")) - .func_call("createExecutionContext", NullOpt, DocUtils::ToPtrDoc("engine")) + .func_call("TRTPtr", DocUtils::ToDeclare("auto", "context")) + .func_call("createExecutionContext", NullOpt, DocUtils::ToPtr("engine")) .pop_nest(); ReturnOnFail("context", "Failed to create the context"); // prepare variables @@ -237,8 +238,8 @@ void TensorRTCodeGen::CodeGenClassDefine() { for (const auto& i : graph()->GetInputs()) { stack_.func_call("CHECK") .func_call("cudaMemcpyAsync") - .call_arg("gpu_buffers[idx_" + IdxTensor(i) + "]") - .call_arg("cpu_buffers[idx_" + IdxTensor(i) + "]") + .call_arg(DocUtils::ToIndex("gpu_buffers", "idx_" + IdxTensor(i))) + .call_arg(DocUtils::ToIndex("cpu_buffers", "idx_" + IdxTensor(i))) .call_arg(GetTensorBytes(i)) .call_arg("cudaMemcpyHostToDevice") .call_arg("stream") @@ -248,7 +249,7 @@ void TensorRTCodeGen::CodeGenClassDefine() { stack_.func_call("cudaStreamSynchronize") .call_arg("stream") .comment("enquque with gpu buffers") - .func_call("enqueueV2", NullOpt, DocUtils::ToPtrDoc("context")) + .func_call("enqueueV2", NullOpt, DocUtils::ToPtr("context")) .call_arg("gpu_buffers") .call_arg("stream") .call_arg("nullptr") @@ -258,7 +259,7 @@ void TensorRTCodeGen::CodeGenClassDefine() { stack_.func_call("CHECK") .func_call("cudaMemcpyAsync") .call_arg("output_" + IdxTensor(o)) - .call_arg("gpu_buffers[idx_" + IdxTensor(o) + "]") + .call_arg(DocUtils::ToIndex("gpu_buffers", "idx_" + IdxTensor(o))) .call_arg(GetTensorBytes(o)) .call_arg("cudaMemcpyDeviceToHost") .call_arg("stream") @@ -281,10 +282,10 @@ void TensorRTCodeGen::CodeGenClassDefine() { .for_start("i", 0, binding_num) .func_call("CHECK") .func_call("cudaFree") - .call_arg("gpu_buffers[i]") + .call_arg(DocUtils::ToIndex("gpu_buffers", "i")) .pop_nest() .func_call("free") - .call_arg("cpu_buffers[i]") + .call_arg(DocUtils::ToIndex("cpu_buffers", "i")) .for_end(); // end define test method stack_.func_end("true"); @@ -302,7 +303,7 @@ void TensorRTCodeGen::CodeGenMain() { .func_arg("argv", "char**") .func_start() .declare("TRTLogger", "logger") - .func_call("logger.setLogSeverity"); + .func_call("setLogSeverity", "", "logger"); if (config()->log_level == 0) { stack_.call_arg("ILogger::Severity::kINFO"); } else if (config()->log_level == 1) { @@ -324,7 +325,7 @@ void TensorRTCodeGen::CodeGenMain() { .cond_if("!FileUtils::FileExist(\"" + graph()->name + ".trt\")"); // create builder stack_.comment("Create TensorRT tools") - .func_call("TRTPtr", DocUtils::ToDeclareDoc("auto", "builder")) + .func_call("TRTPtr", DocUtils::ToDeclare("auto", "builder")) .func_call("createInferBuilder") .call_arg("logger") .pop_nest(); @@ -335,19 +336,19 @@ void TensorRTCodeGen::CodeGenMain() { .assign("flags", "1U << static_cast(NetworkDefinitionCreationFlag::kEXPLICIT_BATCH)", "uint32_t") - .func_call("TRTPtr", DocUtils::ToDeclareDoc("auto", "network")) - .func_call("createNetworkV2", NullOpt, DocUtils::ToPtrDoc("builder")) + .func_call("TRTPtr", DocUtils::ToDeclare("auto", "network")) + .func_call("createNetworkV2", NullOpt, DocUtils::ToPtr("builder")) .call_arg("flags") .pop_nest(); } else { - stack_.func_call("TRTPtr", DocUtils::ToDeclareDoc("auto", "network")) - .func_call("createNetwork", NullOpt, DocUtils::ToPtrDoc("builder")) + stack_.func_call("TRTPtr", DocUtils::ToDeclare("auto", "network")) + .func_call("createNetwork", NullOpt, DocUtils::ToPtr("builder")) .pop_nest(); } ReturnOnFail("network", "Failed to create network"); // create config - stack_.func_call("TRTPtr", DocUtils::ToDeclareDoc("auto", "config")) - .func_call("createBuilderConfig", NullOpt, DocUtils::ToPtrDoc("builder")) + stack_.func_call("TRTPtr", DocUtils::ToDeclare("auto", "config")) + .func_call("createBuilderConfig", NullOpt, DocUtils::ToPtr("builder")) .pop_nest(); ReturnOnFail("config", "Failed to create config"); // add codegen before build @@ -357,7 +358,7 @@ void TensorRTCodeGen::CodeGenMain() { // build model stack_.comment("Build model") .declare(graph()->name, "model") - .func_call("model.Build", "pass") + .func_call("Build", "pass", "model") .call_arg("builder") .call_arg("network"); if (CompareVersion(6, 0, 0) >= 0) { @@ -381,12 +382,12 @@ void TensorRTCodeGen::CodeGenMain() { .assign("profile_verbose", "ProfilingVerbosity::kNONE") .cond_end() .cond_end() - .func_call("setProfilingVerbosity", NullOpt, DocUtils::ToPtrDoc("config")) + .func_call("setProfilingVerbosity", NullOpt, DocUtils::ToPtr("config")) .call_arg("profile_verbose"); // Serialize engine stack_.comment("Serialize engine") .func_call("TRTUtils::SerializeEngineToFile", "pass") - .call_arg(DocUtils::ToStrDoc(graph()->name + ".trt")) + .call_arg(DocUtils::ToStr(graph()->name + ".trt")) .call_arg("builder") .call_arg("network"); if (CompareVersion(6, 0, 0) >= 0) { @@ -400,21 +401,21 @@ void TensorRTCodeGen::CodeGenMain() { stack_.comment("Deserialize engine") .declare("std::shared_ptr", "engine") .func_call("TRTUtils::DeserializeEngineFromFile", "pass") - .call_arg(DocUtils::ToStrDoc(graph()->name + ".trt")) + .call_arg(DocUtils::ToStr(graph()->name + ".trt")) .call_arg("engine") .call_arg("logger"); ReturnOnFail("pass", "Failed to deserialize the engine"); // dump info by inspector stack_.comment("Dump info by inspector") .cond_if("profile_level > 0") - .func_call("TRTPtr", DocUtils::ToDeclareDoc("auto", "inspector")) - .func_call("createEngineInspector", NullOpt, DocUtils::ToPtrDoc("engine")) + .func_call("TRTPtr", DocUtils::ToDeclare("auto", "inspector")) + .func_call("createEngineInspector", NullOpt, DocUtils::ToPtr("engine")) .pop_nest() - .func_call("getEngineInformation", DocUtils::ToDeclareDoc("std::string", "result"), - DocUtils::ToPtrDoc("inspector")) + .func_call("getEngineInformation", DocUtils::ToDeclare("std::string", "result"), + DocUtils::ToPtr("inspector")) .call_arg("LayerInformationFormat::kJSON") .declare("std::ofstream", "os") - .declare_arg(DocUtils::ToStrDoc(graph()->name + "_info.json")) + .declare_arg(DocUtils::ToStr(graph()->name + "_info.json")) .declare_arg("std::ofstream::trunc") .line("os << result << std::flush;") .cond_end(); @@ -422,7 +423,7 @@ void TensorRTCodeGen::CodeGenMain() { if (config()->test_iter > 0) { stack_.comment("Prepare dataset") .declare("DatasetReader", "reader") - .declare_arg(DocUtils::ToStrDoc(config()->dataset)) + .declare_arg(DocUtils::ToStr(config()->dataset)) .declare_arg(config()->test_iter); stack_.comment("Test engine by datas") .func_call("test_" + graph()->name, "pass") @@ -438,21 +439,21 @@ void TensorRTCodeGen::CodeGenCmake() { stack_.line("cmake_minimum_required(VERSION " + config()->cmake_version + " FATAL_ERROR)") .line("project(" + graph()->name + ")") .line("find_package(CUDA)") - .line("find_path(TENSORRT_INCLUDE_DIR NvInfer.h HINTS " + config()->tensorrt_root + + .line("find_path(TRT_INCLUDE_DIR NvInfer.h HINTS " + config()->tensorrt_root + " PATH_SUFFIXES include)") - .line("find_library(TENSORRT_LIB_DIR nvinfer HINTS " + config()->tensorrt_root + + .line("find_library(TRT_LIBS nvinfer HINTS " + config()->tensorrt_root + " PATH_SUFFIXES lib)") .line( - "message(STATUS \"Build project with TENSORRT_INCLUDE_DIR ${TENSORRT_INCLUDE_DIR} and " - "TENSORRT_LIB_DIR " - "${TENSORRT_LIB_DIR}\")") + "message(STATUS \"Build project with TRT_INCLUDE_DIR ${TRT_INCLUDE_DIR} and " + "TRT_LIBS " + "${TRT_LIBS}\")") .line("add_definitions(-DTRT_MAJOR=" + std::to_string(config()->version[0]) + ")") .line("add_definitions(-DTRT_MINOR=" + std::to_string(config()->version[1]) + ")") .line("add_definitions(-DTRT_PATCH=" + std::to_string(config()->version[2]) + ")") .line("file(GLOB_RECURSE TRT_SRCS *.cc)") .line("cuda_add_executable(" + graph()->name + " ${TRT_SRCS})") - .line("target_include_directories(" + graph()->name + " PUBLIC ${TENSORRT_INCLUDE_DIR})") - .line("target_link_libraries(" + graph()->name + " ${TENSORRT_LIB_DIR})"); + .line("target_include_directories(" + graph()->name + " PUBLIC ${TRT_INCLUDE_DIR})") + .line("target_link_libraries(" + graph()->name + " ${TRT_LIBS})"); } const String TensorRTCodeGen::IdxTensor(const MSCTensor& tensor) { @@ -489,7 +490,7 @@ void TensorRTCodeGen::ReturnOnFail(const String& flag, const String& err) { stack_.cond_if("!" + flag) .func_call("logger.log") .call_arg("ILogger::Severity::kERROR") - .call_arg(DocUtils::ToStrDoc(err)) + .call_arg(DocUtils::ToStr(err)) .line("return -1;") .cond_end(); } @@ -551,7 +552,7 @@ const Map TensorRTCodeGen::GetStepCtx() { TVM_REGISTER_GLOBAL("msc.framework.tensorrt.GetTensorRTSources") .set_body_typed([](const MSCGraph& graph, const String& codegen_config, - const String print_config) -> Map { + const String& print_config) -> Map { TensorRTCodeGen codegen = TensorRTCodeGen(graph, codegen_config); return codegen.GetSources(print_config); }); diff --git a/src/contrib/msc/framework/tensorrt/tensorrt_opcode.cc b/src/contrib/msc/framework/tensorrt/tensorrt_opcode.cc index fc6217c31de1..2cacd11907ff 100644 --- a/src/contrib/msc/framework/tensorrt/tensorrt_opcode.cc +++ b/src/contrib/msc/framework/tensorrt/tensorrt_opcode.cc @@ -36,17 +36,17 @@ const Array TensorRTOpCode::GetDocs() { CodeGenBuild(); if (node()->optype == "tuple") { for (size_t i = 0; i < node()->outputs.size(); i++) { - stack_.func_call("setName", NullOpt, DocUtils::ToPtrDoc(IdxOutput(i))) - .call_arg(DocUtils::ToStrDoc(node()->OutputAt(i)->name)); + stack_.func_call("setName", NullOpt, DocUtils::ToPtr(IdxOutput(i))) + .call_arg(DocUtils::ToStr(node()->OutputAt(i)->name)); } } else if (node()->optype == "get_item") { - stack_.func_call("setName", NullOpt, DocUtils::ToPtrDoc(IdxNode())) - .call_arg(DocUtils::ToStrDoc(node()->OutputAt(0)->name)); + stack_.func_call("setName", NullOpt, DocUtils::ToPtr(IdxNode())) + .call_arg(DocUtils::ToStr(node()->OutputAt(0)->name)); } else if (node()->optype != "input") { - SetLayerByValue("Name", DocUtils::ToStrDoc(node()->name)); + SetLayerByValue("Name", DocUtils::ToStr(node()->name)); for (size_t i = 0; i < node()->outputs.size(); i++) { - stack_.func_call("setName", NullOpt, DocUtils::ToPtrDoc(IdxOutput(i))) - .call_arg(DocUtils::ToStrDoc(node()->OutputAt(i)->name)); + stack_.func_call("setName", NullOpt, DocUtils::ToPtr(IdxOutput(i))) + .call_arg(DocUtils::ToStr(node()->OutputAt(i)->name)); } } return stack_.GetDocs(); @@ -155,29 +155,29 @@ const size_t TensorRTOpCode::AttrToAxis(const String& key, size_t ndim) { template void TensorRTOpCode::SetLayerByAttr(const String& method, const String& key) { - stack_.func_call("set" + method, NullOpt, DocUtils::ToPtrDoc(IdxNode())).op_arg(key, ""); + stack_.func_call("set" + method, NullOpt, DocUtils::ToPtr(IdxNode())).op_arg(key, ""); } template void TensorRTOpCode::SetLayerByValue(const String& method, const T& value) { - stack_.func_call("set" + method, NullOpt, DocUtils::ToPtrDoc(IdxNode())).call_arg(value); + stack_.func_call("set" + method, NullOpt, DocUtils::ToPtr(IdxNode())).call_arg(value); } void TensorRTOpCode::SetLayerByDimsAttr(const String& method, const String& key, bool use_ndim) { - stack_.func_call("set" + method, NullOpt, DocUtils::ToPtrDoc(IdxNode())) + stack_.func_call("set" + method, NullOpt, DocUtils::ToPtr(IdxNode())) .call_arg(AttrToDims(key, use_ndim)); } template void TensorRTOpCode::SetLayerByDimsValue(const String& method, const std::vector& value, bool use_ndim) { - stack_.func_call("set" + method, NullOpt, DocUtils::ToPtrDoc(IdxNode())) + stack_.func_call("set" + method, NullOpt, DocUtils::ToPtr(IdxNode())) .call_arg(ToDims(value, use_ndim)); } void TensorRTOpCode::SetLayerByDimsValue(const String& method, const Array& value, bool use_ndim) { - stack_.func_call("set" + method, NullOpt, DocUtils::ToPtrDoc(IdxNode())) + stack_.func_call("set" + method, NullOpt, DocUtils::ToPtr(IdxNode())) .call_arg(ToDims(value, use_ndim)); } @@ -267,7 +267,7 @@ class TensorRTAstypeCodeGen : public TensorRTOpCode { void CodeGenBuild() final { stack_.op_call() .op_input_arg() - .func_call("setOutput", NullOpt, DocUtils::ToPtrDoc(IdxNode())) + .func_call("setOutput", NullOpt, DocUtils::ToPtr(IdxNode())) .call_arg(0) .op_dtype_arg(node()->OutputAt(0)->dtype); } @@ -298,7 +298,10 @@ class TensorRTConcatCodeGen : public TensorRTOpCode { const auto& producer = node()->ProducerOf(0); ICHECK(node()->parents.size() == 1 && producer->optype == "tuple") << "Concat expect parent as tuple, get " << node(); - stack_.op_call().call_arg(IdxNodeBase(producer) + ".data()").call_arg(producer->inputs.size()); + stack_.op_call() + .inplace_start("data", "", IdxNodeBase(producer)) + .inplace_end() + .call_arg(producer->inputs.size()); SetLayerByValue("Axis", AttrToAxis()); } }; @@ -386,7 +389,7 @@ class TensorRTInputCodeGen : public TensorRTOpCode { void CodeGenBuild() final { const auto& output = node()->OutputAt(0); stack_.op_call() - .call_arg(DocUtils::ToStrDoc(output->name)) + .call_arg(DocUtils::ToStr(output->name)) .op_dtype_arg(output->dtype) .call_arg(ToDims(output->shape)); } @@ -405,7 +408,7 @@ class TensorRTLinearCodeGen : public TensorRTOpCode { if (use_bias_) { stack_.op_weight_arg("bias"); } else { - stack_.call_arg("mWeights[\"" + node()->name + ".bias\"]"); + stack_.call_arg(DocUtils::ToIndex("mWeights", DocUtils::ToStr(node()->name + ".bias"))); } } diff --git a/src/contrib/msc/framework/torch/codegen.cc b/src/contrib/msc/framework/torch/codegen.cc index 6d21f590d0ba..012f0311b26d 100644 --- a/src/contrib/msc/framework/torch/codegen.cc +++ b/src/contrib/msc/framework/torch/codegen.cc @@ -92,7 +92,7 @@ void TorchCodeGen::CodeGenGraph() { if (idx_outputs.size() == 1) { stack_.assign("outputs", idx_outputs[0]); } else { - stack_.assign("outputs", DocUtils::ToListDoc(idx_outputs)); + stack_.assign("outputs", DocUtils::ToList(idx_outputs)); } stack_.func_end("outputs"); stack_.class_end(); @@ -103,16 +103,16 @@ void TorchCodeGen::CodeGenInference() { .func_call(graph()->name, "model") .comment("Load weights") .func_call("torch.load", "weights") - .call_arg(DocUtils::ToStrDoc(graph()->name + ".pth")) - .func_call("model.load_state_dict") + .call_arg(DocUtils::ToStr(graph()->name + ".pth")) + .func_call("load_state_dict", "", "model") .call_arg("weights"); if (config()->test_device == "gpu") { - stack_.func_call("model.to").func_call("torch.device").call_arg("cuda").pop_nest(); + stack_.func_call("to", "", "model").func_call("torch.device").call_arg("cuda").pop_nest(); } for (const auto& i : graph()->GetInputs()) { const auto& producer = graph()->FindProducer(i); stack_.func_call("torch.from_numpy", IdxNodeBase(producer)) - .call_arg("inputs[\"" + i->alias + "\"]"); + .call_arg(DocUtils::ToIndex("inputs", DocUtils::ToStr(i->alias))); } stack_.func_call("model", "outputs"); for (const auto& i : graph()->GetInputs()) { @@ -139,7 +139,7 @@ const Array TorchCodeGen::GetOpCodes(const MSCJoint& node) { TVM_REGISTER_GLOBAL("msc.framework.torch.GetTorchSources") .set_body_typed([](const MSCGraph& graph, const String& codegen_config, - const String print_config) -> Map { + const String& print_config) -> Map { TorchCodeGen codegen = TorchCodeGen(graph, codegen_config); return codegen.GetSources(print_config); }); diff --git a/src/contrib/msc/framework/torch/torch_opcode.cc b/src/contrib/msc/framework/torch/torch_opcode.cc index 341181bc2afe..5086678758f2 100644 --- a/src/contrib/msc/framework/torch/torch_opcode.cc +++ b/src/contrib/msc/framework/torch/torch_opcode.cc @@ -73,7 +73,7 @@ const StrictListDoc TorchOpCode::GetPadding(const String& key) { } else { LOG_FATAL << "Unexpected padding node" << node(); } - return DocUtils::ToListDoc(padding); + return DocUtils::ToList(padding); } #define TORCH_OP_CODEGEN_METHODS(TypeName) \ @@ -175,9 +175,9 @@ class TorchBatchNormCodeGen : public TorchOpCode { .op_weight_arg("var") .op_weight_arg("gamma") .op_weight_arg("beta") - .call_arg(DocUtils::ToAttrAccessDoc(module_ref(), "training")) - .call_arg(DocUtils::ToAttrAccessDoc(module_ref(), "momentum")) - .call_arg(DocUtils::ToAttrAccessDoc(module_ref(), "eps")); + .call_arg(DocUtils::ToAttrAccess(module_ref(), "training")) + .call_arg(DocUtils::ToAttrAccess(module_ref(), "momentum")) + .call_arg(DocUtils::ToAttrAccess(module_ref(), "eps")); } else { TorchOpCode::CodeGenForward(); } @@ -244,7 +244,7 @@ class TorchConvCodeGen : public TorchOpCode { stack_.op_call() .call_arg(weight->DimAt("I"), "in_channels") .call_arg(weight->DimAt("O"), "out_channels") - .call_arg(DocUtils::ToListDoc(kernel_size), "kernel_size") + .call_arg(DocUtils::ToList(kernel_size), "kernel_size") .op_list_arg("strides", "stride") .call_arg(GetPadding(), "padding") .op_list_arg("dilation") @@ -260,10 +260,10 @@ class TorchConvCodeGen : public TorchOpCode { } else { stack_.call_arg("None"); } - stack_.call_arg(DocUtils::ToAttrAccessDoc(module_ref(), "stride")) - .call_arg(DocUtils::ToAttrAccessDoc(module_ref(), "padding")) - .call_arg(DocUtils::ToAttrAccessDoc(module_ref(), "dilation")) - .call_arg(DocUtils::ToAttrAccessDoc(module_ref(), "groups")); + stack_.call_arg(DocUtils::ToAttrAccess(module_ref(), "stride")) + .call_arg(DocUtils::ToAttrAccess(module_ref(), "padding")) + .call_arg(DocUtils::ToAttrAccess(module_ref(), "dilation")) + .call_arg(DocUtils::ToAttrAccess(module_ref(), "groups")); } else { TorchOpCode::CodeGenForward(); } @@ -365,7 +365,7 @@ class TorchLayerNormCodeGen : public TorchOpCode { normalized_shape.push_back(node()->InputAt(0)->DimAt(a)); } stack_.op_call() - .call_arg(DocUtils::ToListDoc(normalized_shape), "normalized_shape") + .call_arg(DocUtils::ToList(normalized_shape), "normalized_shape") .op_arg("epsilon", "eps"); } }; @@ -437,7 +437,7 @@ class TorchPermuteDimsCodeGen : public TorchOpCode { axes.push_back(i - 1); } } - stack_.op_call().op_input_arg().call_arg(DocUtils::ToListDoc(axes)); + stack_.op_call().op_input_arg().call_arg(DocUtils::ToList(axes)); } }; @@ -475,7 +475,7 @@ class TorchReduceAxesCodeGen : public TorchOpCode { has_axes = true; } if (has_axes) { - stack_.call_arg(DocUtils::ToListDoc(axes), "dim").op_arg("keepdims", "keepdim"); + stack_.call_arg(DocUtils::ToList(axes), "dim").op_arg("keepdims", "keepdim"); } } }; @@ -497,7 +497,7 @@ class TorchRepeatCodeGen : public TorchOpCode { } stack_.assign(IdxNode(), IdxInput()) .method_call("repeat") - .call_arg(DocUtils::ToListDoc(repeats), ""); + .call_arg(DocUtils::ToList(repeats), ""); } }; @@ -514,7 +514,7 @@ class TorchReshapeCodeGen : public TorchOpCode { shape.Set(batch_dim, Integer(-1)); } } - stack_.op_call().op_input_arg().call_arg(DocUtils::ToListDoc(shape)); + stack_.op_call().op_input_arg().call_arg(DocUtils::ToList(shape)); } }; @@ -530,7 +530,7 @@ class TorchResize2dCodeGen : public TorchOpCode { } else { LOG(FATAL) << "Unexpected resize2d method " << method; } - stack_.op_call().op_input_arg().op_list_arg("size").call_arg(DocUtils::ToStrDoc(v_method), + stack_.op_call().op_input_arg().op_list_arg("size").call_arg(DocUtils::ToStr(v_method), "mode"); } }; @@ -563,8 +563,7 @@ class TorchSplitCodeGen : public TorchOpCode { for (size_t i = 0; i < node()->outputs.size(); i++) { indices.push_back(node()->OutputAt(i)->DimAt(axis)->value); } - stack_.call_arg(DocUtils::ToListDoc(indices), "split_size_or_sections") - .op_arg("axis", "dim"); + stack_.call_arg(DocUtils::ToList(indices), "split_size_or_sections").op_arg("axis", "dim"); } }; @@ -591,7 +590,7 @@ class TorchStridedSliceCodeGen : public TorchOpCode { slice.push_back(":"); } } - stack_.assign(IdxNode(), DocUtils::ToIndexDoc(IdxInput(), slice)); + stack_.assign(IdxNode(), DocUtils::ToIndices(IdxInput(), slice)); } }; diff --git a/src/contrib/msc/framework/tvm/codegen.cc b/src/contrib/msc/framework/tvm/codegen.cc index c8956ca399ff..0b65240401fc 100644 --- a/src/contrib/msc/framework/tvm/codegen.cc +++ b/src/contrib/msc/framework/tvm/codegen.cc @@ -40,7 +40,7 @@ void RelaxCodeGen::CodeGenGraph() { stack_.func_arg(idx_input, "relax.Var"); idx_inputs.push_back(idx_input); } - stack_.func_start().assign("inputs", DocUtils::ToListDoc(idx_inputs, true)); + stack_.func_start().assign("inputs", DocUtils::ToList(idx_inputs, true)); // define weights stack_.comment("Define the weights"); for (const auto& n : graph()->node_names) { @@ -48,21 +48,21 @@ void RelaxCodeGen::CodeGenGraph() { for (const auto& pair : node->weights) { const auto& idx_weight = IdxWeightBase(node, pair.first, false); stack_.func_call("relax.Var", idx_weight) - .call_arg(DocUtils::ToStrDoc(pair.second->name)) + .call_arg(DocUtils::ToStr(pair.second->name)) .func_call("relax.TensorStructInfo") - .call_arg(DocUtils::ToListDoc(pair.second->shape, true), "") - .call_arg(DocUtils::ToStrDoc(pair.second->DTypeName())) + .call_arg(DocUtils::ToList(pair.second->shape, true), "") + .call_arg(DocUtils::ToStr(pair.second->DTypeName())) .pop_nest() - .func_call("inputs.append") + .func_call("append", "", "inputs") .call_arg(idx_weight); } } stack_.comment("Define the module"); - stack_.assign("block_builder", "relax.BlockBuilder()") + stack_.func_call("relax.BlockBuilder", "block_builder") .scope_start("block_builder.function(name=\"" + graph()->name + "\", params=inputs.copy())"); if (config()->use_tools) { stack_.func_call("msc_tools.execute_step") - .call_arg(DocUtils::ToStrDoc("before_build")) + .call_arg(DocUtils::ToStr("before_build")) .call_arg("block_builder"); } for (const auto& n : graph()->node_names) { @@ -99,44 +99,42 @@ void RelaxCodeGen::CodeGenGraph() { const auto& t_output = IdxOutputBase(e, o_idx, true); tuple_outputs.push_back(t_output); } - stack_.func_call("relax.Tuple", idx_exit).call_arg(DocUtils::ToListDoc(tuple_outputs)); - stack_.func_call("block_builder.emit", idx_exit).call_arg(idx_exit); - stack_.call_arg(DocUtils::ToStrDoc(e->name + "_exit"), "name_hint"); + stack_.func_call("relax.Tuple", idx_exit).call_arg(DocUtils::ToList(tuple_outputs)); + stack_.func_call("emit", idx_exit, "block_builder").call_arg(idx_exit); + stack_.call_arg(DocUtils::ToStr(e->name + "_exit"), "name_hint"); } } - stack_.func_call("block_builder.emit_output", idx_exit).call_arg(idx_exit); + stack_.func_call("emit_output", idx_exit, "block_builder").call_arg(idx_exit); idx_exits.push_back(idx_exit); } stack_.scope_end(); if (config()->use_tools) { - stack_.func_call("msc_tools.execute_step", "output") - .call_arg(DocUtils::ToStrDoc("after_build")); + stack_.func_call("msc_tools.execute_step", "output").call_arg(DocUtils::ToStr("after_build")); if (idx_exits.size() == 1) { stack_.call_arg(idx_exits[0]); } else { - stack_.call_arg(DocUtils::ToListDoc(idx_exits)); + stack_.call_arg(DocUtils::ToList(idx_exits)); } } - stack_.func_call("block_builder.emit_func_output"); + stack_.func_call("emit_func_output", "", "block_builder"); if (config()->use_tools) { stack_.call_arg("output"); } else if (idx_exits.size() == 1) { stack_.call_arg(idx_exits[0]); } else { - stack_.call_arg(DocUtils::ToListDoc(idx_exits)); + stack_.call_arg(DocUtils::ToList(idx_exits)); } - - stack_.scope_end().assign("mod", "block_builder.get()").func_end("mod"); + stack_.scope_end().func_call("finalize", "mod", "block_builder").func_end("mod"); } void RelaxCodeGen::CodeGenInference() { for (const auto& i : graph()->GetInputs()) { const auto& producer = graph()->FindProducer(i); stack_.func_call("relax.Var", IdxNodeBase(producer)) - .call_arg(DocUtils::ToStrDoc(i->alias)) + .call_arg(DocUtils::ToStr(i->alias)) .func_call("relax.TensorStructInfo") - .call_arg(DocUtils::ToListDoc(i->shape)) - .call_arg(DocUtils::ToStrDoc(i->DTypeName())) + .call_arg(DocUtils::ToList(i->shape)) + .call_arg(DocUtils::ToStr(i->DTypeName())) .pop_nest(); } stack_.comment("Build Module").func_call(graph()->name, "mod"); @@ -155,15 +153,16 @@ void RelaxCodeGen::CodeGenInference() { stack_.comment("Load weights") .scope_start("open(\"" + graph()->name + "_params.bin\", \"rb\")", "f") .func_call("tvm.runtime.load_param_dict", "params") - .call_arg("f.read()") + .inplace_start("read", "", "f") + .inplace_end() .scope_end() .func_call("tvm.relax.transform.BindParams", "bind_params") - .call_arg(DocUtils::ToStrDoc("main")) + .call_arg(DocUtils::ToStr("main")) .call_arg("params") .func_call("bind_params", "mod") .call_arg("mod") .func_call("tvm.target.Target", "target") - .call_arg(DocUtils::ToStrDoc(target)) + .call_arg(DocUtils::ToStr(target)) .func_call("tvm.relax.transform.LegalizeOps()", "mod") .call_arg("mod") .scope_start("tvm.transform.PassContext(opt_level=3)") @@ -174,9 +173,10 @@ void RelaxCodeGen::CodeGenInference() { .call_arg("ex") .call_arg(device) .scope_end() - .func_call("vm[\"main\"]", "outputs"); + .assign("f_main", DocUtils::ToIndex("vm", DocUtils::ToStr("main"))) + .func_call("f_main", "outputs"); for (const auto& i : graph()->GetInputs()) { - stack_.call_arg("inputs[\"" + i->alias + "\"]"); + stack_.call_arg(DocUtils::ToIndex("inputs", DocUtils::ToStr(i->alias))); } } @@ -195,7 +195,7 @@ const Array RelaxCodeGen::GetOpCodes(const MSCJoint& node) { TVM_REGISTER_GLOBAL("msc.framework.tvm.GetRelaxSources") .set_body_typed([](const MSCGraph& graph, const String& codegen_config, - const String print_config) -> Map { + const String& print_config) -> Map { RelaxCodeGen codegen = RelaxCodeGen(graph, codegen_config); return codegen.GetSources(print_config); }); diff --git a/src/contrib/msc/framework/tvm/relax_opcode.cc b/src/contrib/msc/framework/tvm/relax_opcode.cc index 3bd40cbfd79f..3867083b90ca 100644 --- a/src/contrib/msc/framework/tvm/relax_opcode.cc +++ b/src/contrib/msc/framework/tvm/relax_opcode.cc @@ -49,7 +49,7 @@ const Array RelaxOpCode::GetDocs() { void RelaxOpCode::BuilderEmit(const String& ret, const String& name) { stack_.func_call("block_builder.emit", ret).call_arg(ret); if (name.size() > 0) { - stack_.call_arg(DocUtils::ToStrDoc(name), "name_hint"); + stack_.call_arg(DocUtils::ToStr(name), "name_hint"); } } @@ -60,9 +60,9 @@ const ExprDoc RelaxOpCode::GetOutDtype(const String& key, int input_idx) { } std::string out_dtype; if (!node()->GetAttr(key, &out_dtype) && config()->from_relay) { - return DocUtils::ToStrDoc(node()->OutputAt(0)->DTypeName()); + return DocUtils::ToStr(node()->OutputAt(0)->DTypeName()); } - return DocUtils::ToStrDoc(out_dtype); + return DocUtils::ToStr(out_dtype); } const std::vector RelaxOpCode::GetAxes(const String& key) { @@ -132,7 +132,7 @@ class RelaxAxesCodeGen : public RelaxOpCode { protected: void CodeGenBuild() final { const String& key = node()->HasAttr("axes") ? "axes" : "axis"; - stack_.op_call().op_input_arg().call_arg(DocUtils::ToListDoc(GetAxes(key)), key); + stack_.op_call().op_input_arg().call_arg(DocUtils::ToList(GetAxes(key)), key); } }; @@ -154,7 +154,7 @@ class RelaxBatchMatmulCodeGen : public RelaxOpCode { axes.push_back(node()->InputAt(0)->Ndim() - 2); stack_.op_call("relax.op.permute_dims", IdxInput(0)) .op_input_arg() - .call_arg(DocUtils::ToListDoc(axes)); + .call_arg(DocUtils::ToList(axes)); BuilderEmit(IdxInput(0)); stack_.op_call().op_inputs_arg(false).op_str_arg("out_dtype"); } else if (!transpose_a && transpose_b) { @@ -166,7 +166,7 @@ class RelaxBatchMatmulCodeGen : public RelaxOpCode { axes.push_back(node()->InputAt(1)->Ndim() - 2); stack_.op_call("relax.op.permute_dims", IdxInput(1)) .op_input_arg(1) - .call_arg(DocUtils::ToListDoc(axes)); + .call_arg(DocUtils::ToList(axes)); BuilderEmit(IdxInput(1)); stack_.op_call().op_inputs_arg(false).op_str_arg("out_dtype"); } else { @@ -179,7 +179,7 @@ class RelaxBatchMatmulCodeGen : public RelaxOpCode { axes.push_back(node()->InputAt(idx)->Ndim() - 2); stack_.op_call("relax.op.permute_dims", IdxInput(idx)) .op_input_arg(idx) - .call_arg(DocUtils::ToListDoc(axes)); + .call_arg(DocUtils::ToList(axes)); BuilderEmit(IdxInput(idx)); } stack_.op_call().op_inputs_arg(false).op_str_arg("out_dtype"); @@ -222,7 +222,7 @@ class RelaxBiasAddCodeGen : public RelaxOpCode { } stack_.op_call("relax.op.reshape", IdxInput(1)) .op_input_arg(1) - .call_arg(DocUtils::ToListDoc(expand_shape), "shape"); + .call_arg(DocUtils::ToList(expand_shape), "shape"); BuilderEmit(IdxInput(1)); stack_.op_call().op_inputs_arg(false); } @@ -299,7 +299,7 @@ class RelaxConvCodeGen : public RelaxOpCode { BuilderEmit(IdxNode()); stack_.func_call("relax.op.reshape", "expand_bias") .op_weight_arg("bias") - .call_arg(DocUtils::ToListDoc(expand_shape), "shape"); + .call_arg(DocUtils::ToList(expand_shape), "shape"); BuilderEmit("expand_bias"); stack_.func_call("relax.op.add", IdxNode()).call_arg(IdxNode()).call_arg("expand_bias"); } @@ -355,7 +355,7 @@ class RelaxStridedSliceCodeGen : public RelaxOpCode { } stack_.op_call() .op_input_arg() - .call_arg(DocUtils::ToListDoc(axes), "axes") + .call_arg(DocUtils::ToList(axes), "axes") .op_list_arg("begin") .op_list_arg("end") .op_list_arg("strides"); @@ -371,13 +371,13 @@ class RelaxEmbeddingCodeGen : public RelaxOpCode { if (input->DTypeName() != "int32") { stack_.op_call("relax.op.astype", IdxInput()) .op_input_arg() - .call_arg(DocUtils::ToStrDoc("int32")); + .call_arg(DocUtils::ToStr("int32")); BuilderEmit(IdxInput()); } if (input->Ndim() > 1) { stack_.op_call("relax.op.reshape", IdxInput()) .op_input_arg() - .call_arg(DocUtils::ToListDoc(std::vector{-1}), "shape"); + .call_arg(DocUtils::ToList(std::vector{-1}), "shape"); BuilderEmit(IdxInput()); } stack_.op_call().op_weight_arg("weight").op_input_arg().op_arg("axis"); @@ -385,7 +385,7 @@ class RelaxEmbeddingCodeGen : public RelaxOpCode { BuilderEmit(IdxNode()); stack_.op_call("relax.op.reshape") .op_output_arg() - .call_arg(DocUtils::ToListDoc(node()->OutputAt(0)->shape)); + .call_arg(DocUtils::ToList(node()->OutputAt(0)->shape)); } } }; @@ -421,7 +421,7 @@ class RelaxGroupNormCodeGen : public RelaxOpCode { for (size_t i = 2; i < node()->InputAt(0)->Ndim(); i++) { axes.push_back(i); } - stack_.op_arg("axis", "channel_axis").call_arg(DocUtils::ToListDoc(axes), "axes"); + stack_.op_arg("axis", "channel_axis").call_arg(DocUtils::ToList(axes), "axes"); } else { stack_.op_arg("channel_axis").op_list_arg("axes"); } @@ -526,7 +526,7 @@ class RelaxPermuteDimsCodeGen : public RelaxOpCode { axes.push_back(i - 1); } } - stack_.op_call().op_input_arg().call_arg(DocUtils::ToListDoc(axes), "axes"); + stack_.op_call().op_input_arg().call_arg(DocUtils::ToList(axes), "axes"); } }; @@ -540,7 +540,7 @@ class RelaxReduceAxisCodeGen : public RelaxOpCode { stack_.op_call().op_input_arg(); std::vector axes = GetAxes("axis"); if (as_list_) { - stack_.call_arg(DocUtils::ToListDoc(axes), "axis"); + stack_.call_arg(DocUtils::ToList(axes), "axis"); } else if (axes.size() > 0) { stack_.call_arg(axes[0], "axis"); } @@ -590,7 +590,7 @@ class RelaxResize2dCodeGen : public RelaxOpCode { .func_call("relax.ShapeExpr") .op_list_arg("size", "values") .pop_nest() - .call_arg(DocUtils::ToListDoc(roi_list)) + .call_arg(DocUtils::ToList(roi_list)) .op_str_arg("layout") .op_str_arg("method") .op_str_arg("coordinate_transformation_mode") diff --git a/tests/python/contrib/test_msc/test_manager.py b/tests/python/contrib/test_msc/test_manager.py index c3e1583ef291..04379af89a20 100644 --- a/tests/python/contrib/test_msc/test_manager.py +++ b/tests/python/contrib/test_msc/test_manager.py @@ -17,6 +17,7 @@ """ Test Managers in MSC. """ +import json import pytest import torch @@ -31,10 +32,13 @@ ) -def _get_config(model_type, compile_type, inputs, outputs, atol=1e-2, rtol=1e-2): +def _get_config(model_type, compile_type, inputs, outputs, atol=1e-1, rtol=1e-1): """Get msc config""" + + path = "test_manager_{}_{}".format(model_type, compile_type) return { - "workspace": msc_utils.msc_dir(), + "workspace": msc_utils.msc_dir(path), + "verbose": "critical", "model_type": model_type, "inputs": inputs, "outputs": outputs, @@ -53,11 +57,12 @@ def _get_config(model_type, compile_type, inputs, outputs, atol=1e-2, rtol=1e-2) def _get_torch_model(name, is_training=False): """Get model from torch vision""" + # pylint: disable=import-outside-toplevel try: import torchvision - model = getattr(torchvision.models, name)(pretrained=True) + model = getattr(torchvision.models, name)() if is_training: model = model.train() else: @@ -70,6 +75,7 @@ def _get_torch_model(name, is_training=False): def _get_tf_graph(): """Get graph from tensorflow""" + # pylint: disable=import-outside-toplevel try: from tvm.contrib.msc.framework.tensorflow import tf_v1 @@ -89,7 +95,23 @@ def _get_tf_graph(): return None -def _test_from_torch(compile_type, expected_info, is_training=False, atol=1e-2, rtol=1e-2): +def _check_manager(manager, expected_info): + """Check the manager results""" + + model_info = manager.runner.model_info + passed, err = True, "" + if not manager.report["success"]: + passed = False + err = "Failed to run pipe for {} -> {}".format(manager.model_type, manager.compile_type) + if not msc_utils.dict_equal(model_info, expected_info): + passed = False + err = "Model info {} mismatch with expected {}".format(model_info, expected_info) + manager.destory() + if not passed: + raise Exception("{}\nReport:{}".format(err, json.dumps(manager.report, indent=2))) + + +def _test_from_torch(compile_type, expected_info, is_training=False, atol=1e-1, rtol=1e-1): torch_model = _get_torch_model("resnet50", is_training) if torch_model: if torch.cuda.is_available(): @@ -103,13 +125,8 @@ def _test_from_torch(compile_type, expected_info, is_training=False, atol=1e-2, rtol=rtol, ) manager = MSCManager(torch_model, config) - report = manager.run_pipe() - assert report["success"], "Failed to run pipe for torch -> {}".format(compile_type) - model_info = manager.runner.model_info - assert msc_utils.dict_equal( - model_info, expected_info - ), "Model info {} mismatch with expected {}".format(model_info, expected_info) - manager.destory() + manager.run_pipe() + _check_manager(manager, expected_info) def _test_from_tf(compile_type, expected_info, atol=1e-2, rtol=1e-2): @@ -125,13 +142,8 @@ def _test_from_tf(compile_type, expected_info, atol=1e-2, rtol=1e-2): ) config["compile"]["profile"]["check"]["err_rate"] = -1 manager = MSCManager(graphdef, config) - report = manager.run_pipe() - assert report["success"], "Failed to run pipe for tensorflow -> {}".format(compile_type) - model_info = manager.runner.model_info - assert msc_utils.dict_equal( - model_info, expected_info - ), "Model info {} mismatch with expected {}".format(model_info, expected_info) - manager.destory() + manager.run_pipe() + _check_manager(manager, expected_info) def test_tvm_manager(): diff --git a/tests/python/contrib/test_msc/test_runner.py b/tests/python/contrib/test_msc/test_runner.py index cbc33d452846..e3d5bcf24503 100644 --- a/tests/python/contrib/test_msc/test_runner.py +++ b/tests/python/contrib/test_msc/test_runner.py @@ -41,11 +41,12 @@ def _get_torch_model(name, is_training=False): """Get model from torch vision""" + # pylint: disable=import-outside-toplevel try: import torchvision - model = getattr(torchvision.models, name)(pretrained=True) + model = getattr(torchvision.models, name)() if is_training: model = model.train() else: @@ -77,14 +78,15 @@ def _get_tf_graph(): return None, None -def _test_from_torch(runner_cls, device, is_training=False, atol=1e-3, rtol=1e-3): +def _test_from_torch(runner_cls, device, is_training=False, atol=1e-1, rtol=1e-1): """Test runner from torch model""" torch_model = _get_torch_model("resnet50", is_training) if torch_model: - workspace = msc_utils.set_workspace(msc_utils.msc_dir()) + path = "test_runner_torch_{}_{}".format(runner_cls.__name__, device) + workspace = msc_utils.set_workspace(msc_utils.msc_dir(path)) log_path = workspace.relpath("MSC_LOG", keep_history=False) - msc_utils.set_global_logger("info", log_path) + msc_utils.set_global_logger("critical", log_path) input_info = [([1, 3, 224, 224], "float32")] datas = [np.random.rand(*i[0]).astype(i[1]) for i in input_info] torch_datas = [torch.from_numpy(d) for d in datas] @@ -96,9 +98,9 @@ def _test_from_torch(runner_cls, device, is_training=False, atol=1e-3, rtol=1e-3 runner.build() outputs = runner.run(datas, ret_type="list") golden = [msc_utils.cast_array(golden)] + workspace.destory() for gol_r, out_r in zip(golden, outputs): tvm.testing.assert_allclose(gol_r, out_r, atol=atol, rtol=rtol) - workspace.destory() def test_tvm_runner_cpu(): @@ -107,9 +109,9 @@ def test_tvm_runner_cpu(): _test_from_torch(TVMRunner, "cpu", is_training=True) -@tvm.testing.requires_gpu -def test_tvm_runner_gpu(): - """Test runner for tvm on gpu""" +@tvm.testing.requires_cuda +def test_tvm_runner_cuda(): + """Test runner for tvm on cuda""" _test_from_torch(TVMRunner, "cuda", is_training=True) @@ -120,18 +122,18 @@ def test_torch_runner_cpu(): _test_from_torch(TorchRunner, "cpu") -@tvm.testing.requires_gpu -def test_torch_runner_gpu(): +@tvm.testing.requires_cuda +def test_torch_runner_cuda(): """Test runner for torch on cuda""" - _test_from_torch(TorchRunner, "cuda", atol=1e-2, rtol=1e-2) + _test_from_torch(TorchRunner, "cuda", atol=1e-1, rtol=1e-1) @requires_tensorrt def test_tensorrt_runner(): """Test runner for tensorrt""" - _test_from_torch(TensorRTRunner, "cuda", atol=1e-2, rtol=1e-2) + _test_from_torch(TensorRTRunner, "cuda", atol=1e-1, rtol=1e-1) def test_tensorflow_runner(): @@ -139,9 +141,10 @@ def test_tensorflow_runner(): tf_graph, graph_def = _get_tf_graph() if tf_graph and graph_def: - workspace = msc_utils.set_workspace(msc_utils.msc_dir()) + path = "test_runner_tf" + workspace = msc_utils.set_workspace(msc_utils.msc_dir(path)) log_path = workspace.relpath("MSC_LOG", keep_history=False) - msc_utils.set_global_logger("info", log_path) + msc_utils.set_global_logger("critical", log_path) data = np.random.uniform(size=(1, 224, 224, 3)).astype("float32") out_name = "MobilenetV2/Predictions/Reshape_1:0" # get golden @@ -153,9 +156,9 @@ def test_tensorflow_runner(): runner = TensorflowRunner(mod) runner.build() outputs = runner.run([data], ret_type="list") + workspace.destory() for gol_r, out_r in zip(golden, outputs): tvm.testing.assert_allclose(gol_r, out_r, atol=1e-3, rtol=1e-3) - workspace.destory() if __name__ == "__main__": diff --git a/tests/python/contrib/test_msc/test_tools.py b/tests/python/contrib/test_msc/test_tools.py index 6adf70605bfd..8fa9e5cf10cc 100644 --- a/tests/python/contrib/test_msc/test_tools.py +++ b/tests/python/contrib/test_msc/test_tools.py @@ -17,9 +17,8 @@ """ Test Tools in MSC. """ -import os +import json import pytest - import torch import tvm.testing @@ -45,8 +44,11 @@ def _get_config( optimize_type=None, ): """Get msc config""" + + path = "_".join(["test_tools", model_type, compile_type] + list(tools_config.keys())) return { - "workspace": msc_utils.msc_dir(), + "workspace": msc_utils.msc_dir(path), + "verbose": "critical", "model_type": model_type, "inputs": inputs, "outputs": outputs, @@ -68,31 +70,15 @@ def _get_config( } -def get_tool_config(tool_type, use_distill=False, use_gym=False): +def get_tool_config(tool_type, use_distill=False): """Get config for the tool""" + config = {} if tool_type == ToolType.PRUNER: config = { "plan_file": "msc_pruner.json", "strategys": [{"method": "per_channel", "density": 0.8}], } - if use_gym: - config["gym_configs"] = [ - { - "env": { - "executors": { - "action_space": { - "method": "action_prune_density", - "start": 0.4, - "end": 0.8, - "step": 0.4, - } - }, - "max_tasks": 3, - }, - "agent": {"agent_type": "search.grid", "executors": {}}, - } - ] elif tool_type == ToolType.QUANTIZER: # pylint: disable=import-outside-toplevel from tvm.contrib.msc.core.tools.quantize import QuantizeStage @@ -130,23 +116,6 @@ def get_tool_config(tool_type, use_distill=False, use_gym=False): }, ], } - if use_gym: - config["gym_configs"] = [ - { - "env": { - "executors": { - "action_space": { - "method": "action_quantize_scale", - "start": 0.8, - "end": 1.2, - "step": 0.2, - } - }, - "max_tasks": 3, - }, - "agent": {"agent_type": "search.grid", "executors": {}}, - } - ] elif tool_type == ToolType.TRACKER: config = { "plan_file": "msc_tracker.json", @@ -178,11 +147,12 @@ def get_tool_config(tool_type, use_distill=False, use_gym=False): def _get_torch_model(name, is_training=False): """Get model from torch vision""" + # pylint: disable=import-outside-toplevel try: import torchvision - model = getattr(torchvision.models, name)(pretrained=True) + model = getattr(torchvision.models, name)() if is_training: model = model.train() else: @@ -193,13 +163,29 @@ def _get_torch_model(name, is_training=False): return None +def _check_manager(manager, expected_info): + """Check the manager results""" + + model_info = manager.runner.model_info + passed, err = True, "" + if not manager.report["success"]: + passed = False + err = "Failed to run pipe for {} -> {}".format(manager.model_type, manager.compile_type) + if not msc_utils.dict_equal(model_info, expected_info): + passed = False + err = "Model info {} mismatch with expected {}".format(model_info, expected_info) + manager.destory() + if not passed: + raise Exception("{}\nReport:{}".format(err, json.dumps(manager.report, indent=2))) + + def _test_from_torch( compile_type, tools_config, expected_info, is_training=False, - atol=1e-2, - rtol=1e-2, + atol=1e-1, + rtol=1e-1, optimize_type=None, ): torch_model = _get_torch_model("resnet50", is_training) @@ -217,21 +203,13 @@ def _test_from_torch( optimize_type=optimize_type, ) manager = MSCManager(torch_model, config) - report = manager.run_pipe() - model_info = manager.runner.model_info - for t_type, config in tools_config.items(): - assert os.path.isfile( - msc_utils.get_config_dir().relpath(config["plan_file"]) - ), "Failed to find plan of " + str(t_type) - manager.destory() - assert report["success"], "Failed to run pipe for torch -> {}".format(compile_type) - assert msc_utils.dict_equal( - model_info, expected_info - ), "Model info {} mismatch with expected {}".format(model_info, expected_info) + manager.run_pipe() + _check_manager(manager, expected_info) def get_model_info(compile_type): """Get the model info""" + if compile_type == MSCFramework.TVM: return { "inputs": [ @@ -273,6 +251,7 @@ def test_tvm_tool(tool_type): ) +@tvm.testing.requires_cuda @pytest.mark.parametrize("tool_type", [ToolType.PRUNER, ToolType.QUANTIZER]) def test_tvm_distill(tool_type): """Test tools for tvm with distiller""" @@ -283,16 +262,6 @@ def test_tvm_distill(tool_type): ) -@pytest.mark.parametrize("tool_type", [ToolType.PRUNER, ToolType.QUANTIZER]) -def test_tvm_gym(tool_type): - """Test tools for tvm with distiller""" - - tool_config = get_tool_config(tool_type, use_gym=True) - _test_from_torch( - MSCFramework.TVM, tool_config, get_model_info(MSCFramework.TVM), is_training=True - ) - - @requires_tensorrt @pytest.mark.parametrize( "tool_type", @@ -312,8 +281,8 @@ def test_tensorrt_tool(tool_type): tool_config, get_model_info(MSCFramework.TENSORRT), is_training=False, - atol=5e-2, - rtol=5e-2, + atol=1e-1, + rtol=1e-1, optimize_type=optimize_type, ) @@ -329,16 +298,5 @@ def test_tensorrt_distill(tool_type): ) -@requires_tensorrt -@pytest.mark.parametrize("tool_type", [ToolType.PRUNER]) -def test_tensorrt_gym(tool_type): - """Test tools for tensorrt with gym""" - - tool_config = get_tool_config(tool_type, use_gym=True) - _test_from_torch( - MSCFramework.TENSORRT, tool_config, get_model_info(MSCFramework.TENSORRT), is_training=False - ) - - if __name__ == "__main__": tvm.testing.main() diff --git a/tests/python/contrib/test_msc/test_transform.py b/tests/python/contrib/test_msc/test_transform.py new file mode 100644 index 000000000000..37a3ad3fcd7e --- /dev/null +++ b/tests/python/contrib/test_msc/test_transform.py @@ -0,0 +1,156 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" Test MSC basic Pass. """ + +import tvm.testing +from tvm.relax.frontend.torch import from_fx +from tvm.relax import PyExprVisitor + +from tvm.relay import testing +from tvm.relay.expr_functor import ExprVisitor +from tvm.relay.build_module import bind_params_by_name + +from tvm.contrib.msc.core import _ffi_api +from tvm.contrib.msc.core import transform as msc_transform + + +def test_relax_layout(): + """Test SetExprLayout for relax""" + + # pylint: disable=import-outside-toplevel + try: + import torch + import torchvision + from torch import fx + except: # pylint: disable=bare-except + print("please install pytorch python package") + return + + class RelaxLayoutChecker(PyExprVisitor): + """Check if name as span attribute is setted.""" + + def check(self, expr): + self._missing_exprs = [] + if isinstance(expr, tvm.relax.Expr): + self.visit_expr(expr) + elif isinstance(expr, tvm.relax.BindingBlock): + self.visit_binding_block(expr) + assert len(self._missing_exprs) == 0, "Missing {} layouts".format( + len(self._missing_exprs) + ) + + def visit_var_binding_(self, binding) -> None: + super().visit_var_binding_(binding) + layout = _ffi_api.SpanGetAttr(binding.value.span, "layout") + if not layout: + self._missing_exprs.append(binding.value) + + def visit_constant_(self, op) -> None: + super().visit_constant_(op) + layout = _ffi_api.SpanGetAttr(op.span, "layout") + if not layout: + self._missing_exprs.append(op) + + torch_model = torchvision.models.resnet50() + graph_model = fx.symbolic_trace(torch_model) + input_info = [([1, 3, 224, 224], "float32")] + with torch.no_grad(): + mod = from_fx(graph_model, input_info) + mod = msc_transform.SetExprLayout()(mod) + RelaxLayoutChecker().check(mod) + + +def test_relay_name(): + """Test SetExprName for relay""" + + class RelayNameChecker(ExprVisitor): + """Check if name as span attribute is setted.""" + + def check(self, expr): + self._missing_exprs = [] + super().visit(expr) + assert len(self._missing_exprs) == 0, "Missing {} names".format( + len(self._missing_exprs) + ) + + def visit_constant(self, expr): + super().visit_constant(expr) + name = _ffi_api.SpanGetAttr(expr.span, "name") + if not name: + self._missing_exprs.append(expr) + + def visit_call(self, expr): + super().visit_call(expr) + name = _ffi_api.SpanGetAttr(expr.span, "name") + if not name: + self._missing_exprs.append(expr) + + mod, params = testing.resnet.get_workload(num_layers=50, batch_size=1, dtype="float32") + mod["main"] = bind_params_by_name(mod["main"], params) + mod = msc_transform.SetExprName(as_relax=False)(mod) + RelayNameChecker().check(mod["main"]) + + +def test_relax(): + """Test SetExprName for relax""" + + # pylint: disable=import-outside-toplevel + try: + import torch + import torchvision + from torch import fx + except: # pylint: disable=bare-except + print("please install pytorch python package") + return + + class RelaxNameChecker(PyExprVisitor): + """Check if name as span attribute is setted.""" + + def check(self, expr): + self._missing_exprs = [] + if isinstance(expr, tvm.relax.Expr): + self.visit_expr(expr) + elif isinstance(expr, tvm.relax.BindingBlock): + self.visit_binding_block(expr) + assert len(self._missing_exprs) == 0, "Missing {} names".format( + len(self._missing_exprs) + ) + + def visit_var_binding_(self, binding) -> None: + super().visit_var_binding_(binding) + name = _ffi_api.SpanGetAttr(binding.value.span, "name") + if not name: + self._missing_exprs.append(binding.value) + + def visit_constant_(self, op) -> None: + super().visit_constant_(op) + name = _ffi_api.SpanGetAttr(op.span, "name") + if not name: + self._missing_exprs.append(op) + + torch_model = torchvision.models.resnet50() + graph_model = fx.symbolic_trace(torch_model) + input_info = [([1, 3, 224, 224], "float32")] + with torch.no_grad(): + mod = from_fx(graph_model, input_info) + mod = msc_transform.SetExprName()(mod) + RelaxNameChecker().check(mod) + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/contrib/test_msc/test_transform_set_expr_layout.py b/tests/python/contrib/test_msc/test_transform_set_expr_layout.py deleted file mode 100644 index 523dc930abf9..000000000000 --- a/tests/python/contrib/test_msc/test_transform_set_expr_layout.py +++ /dev/null @@ -1,73 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -""" Test SetExprLayout Pass. """ - -import tvm.testing -from tvm.relax.frontend.torch import from_fx -from tvm.relax import PyExprVisitor -from tvm.contrib.msc.core import _ffi_api -from tvm.contrib.msc.core import transform as msc_transform - - -class RelaxChecker(PyExprVisitor): - """Check if name as span attribute is setted.""" - - def check(self, expr): - self._missing_exprs = [] - if isinstance(expr, tvm.relax.Expr): - self.visit_expr(expr) - elif isinstance(expr, tvm.relax.BindingBlock): - self.visit_binding_block(expr) - assert len(self._missing_exprs) == 0, "Missing {} layouts".format(len(self._missing_exprs)) - - def visit_var_binding_(self, binding) -> None: - super().visit_var_binding_(binding) - layout = _ffi_api.SpanGetAttr(binding.value.span, "layout") - if not layout: - self._missing_exprs.append(binding.value) - - def visit_constant_(self, op) -> None: - super().visit_constant_(op) - layout = _ffi_api.SpanGetAttr(op.span, "layout") - if not layout: - self._missing_exprs.append(op) - - -def test_relax(): - """Test SetExprLayout for relax""" - - # pylint: disable=import-outside-toplevel - try: - import torch - import torchvision - from torch import fx - except: # pylint: disable=bare-except - print("please install pytorch python package") - return - - torch_model = torchvision.models.resnet50() - graph_model = fx.symbolic_trace(torch_model) - input_info = [([1, 3, 224, 224], "float32")] - with torch.no_grad(): - mod = from_fx(graph_model, input_info) - mod = msc_transform.SetExprLayout()(mod) - RelaxChecker().check(mod) - - -if __name__ == "__main__": - tvm.testing.main() diff --git a/tests/python/contrib/test_msc/test_transform_set_expr_name.py b/tests/python/contrib/test_msc/test_transform_set_expr_name.py deleted file mode 100644 index 426860145c9b..000000000000 --- a/tests/python/contrib/test_msc/test_transform_set_expr_name.py +++ /dev/null @@ -1,108 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -""" Test SetExprName Pass. """ - -import tvm.testing -from tvm.relay import testing -from tvm.relay.expr_functor import ExprVisitor -from tvm.relay.build_module import bind_params_by_name - -from tvm.relax.frontend.torch import from_fx -from tvm.relax import PyExprVisitor - -from tvm.contrib.msc.core import _ffi_api -from tvm.contrib.msc.core import transform as msc_transform - - -class RelayChecker(ExprVisitor): - """Check if name as span attribute is setted.""" - - def check(self, expr): - self._missing_exprs = [] - super().visit(expr) - assert len(self._missing_exprs) == 0, "Missing {} names".format(len(self._missing_exprs)) - - def visit_constant(self, expr): - super().visit_constant(expr) - name = _ffi_api.SpanGetAttr(expr.span, "name") - if not name: - self._missing_exprs.append(expr) - - def visit_call(self, expr): - super().visit_call(expr) - name = _ffi_api.SpanGetAttr(expr.span, "name") - if not name: - self._missing_exprs.append(expr) - - -class RelaxChecker(PyExprVisitor): - """Check if name as span attribute is setted.""" - - def check(self, expr): - self._missing_exprs = [] - if isinstance(expr, tvm.relax.Expr): - self.visit_expr(expr) - elif isinstance(expr, tvm.relax.BindingBlock): - self.visit_binding_block(expr) - assert len(self._missing_exprs) == 0, "Missing {} names".format(len(self._missing_exprs)) - - def visit_var_binding_(self, binding) -> None: - super().visit_var_binding_(binding) - name = _ffi_api.SpanGetAttr(binding.value.span, "name") - if not name: - self._missing_exprs.append(binding.value) - - def visit_constant_(self, op) -> None: - super().visit_constant_(op) - name = _ffi_api.SpanGetAttr(op.span, "name") - if not name: - self._missing_exprs.append(op) - - -def test_relay(): - """Test SetExprName for relay""" - - mod, params = testing.resnet.get_workload(num_layers=50, batch_size=1, dtype="float32") - mod["main"] = bind_params_by_name(mod["main"], params) - mod = msc_transform.SetExprName(as_relax=False)(mod) - RelayChecker().check(mod["main"]) - - -def test_relax(): - """Test SetExprName for relax""" - - # pylint: disable=import-outside-toplevel - try: - import torch - import torchvision - from torch import fx - except: # pylint: disable=bare-except - print("please install pytorch python package") - return - - torch_model = torchvision.models.resnet50() - graph_model = fx.symbolic_trace(torch_model) - input_info = [([1, 3, 224, 224], "float32")] - with torch.no_grad(): - mod = from_fx(graph_model, input_info) - mod = msc_transform.SetExprName()(mod) - RelaxChecker().check(mod) - - -if __name__ == "__main__": - tvm.testing.main()