From fe4b85cf1778b7fe95b6a5611c7767879ef112ef Mon Sep 17 00:00:00 2001 From: Mark Shields <87091372+mbs-octoml@users.noreply.github.com> Date: Tue, 17 Aug 2021 16:41:42 -0700 Subject: [PATCH] [Relay] Refactor Interpreter to treat lowering as IRModule->IRModule rewrite. (#8597) * This continues the work outlined in the RFC https://discuss.tvm.apache.org/t/rfc-relay-tecompiler-rewrite-existing-compile-engine-to-match-updated-compiler-flow/9233 This gets about halfway there for the Interpreter: * Remove direct access to TECompiler from interpreter, and instead call tec::LowerTEExpr when 'preparing' a module and expression for evaluation. * Make clear there's no phase distinction between create_interpreter and evaluate on the Python side -- both must be prepared together as a single IRModule. * But in return make sure the result of evaluate on the Python side is a packed func ready to directly apply 'simple' arguments to an already interpreted closure. * The interpreter builds and caches primitive TIR functions (and their corresponding dynamic shape functions) as packed funcs as they are encountered. * Cleanup uses of interpreter for constant folding on the C++ side. Future work: * Fold LoweredModule into IRModule so tec::LowerTEExpr is just another pass. * Get rid of the implicit caching of lowered functions in TECompiler. * Make calling convention from Relay to TIR explicit, and remove all the function attribute hackery currently needed so the interpreter can correctly invoke lowered functions as it encounters them. * Make TECompiler private. Though could do this now it will make migrating the VM and AOT uses of CompilerEngine harder. Force a gc between sphinx-gallery items to reclaim GPU memory. (#8722) GPU memory is only released once the PackedFunc for evaling the model is gced by Python. In CI we're noticing intermittent 'CUDA: Out of memory' failures while processing the tutorials, and tracing showed there was no gc happening between items. Not confident this will solve the problem but worth a try. * Get rid of logs spam. --- docs/conf.py | 6 +- docs/langref/relay_pattern.rst | 2 +- include/tvm/ir/module.h | 43 +- include/tvm/relay/interpreter.h | 70 +- python/tvm/relay/analysis/analysis.py | 3 +- python/tvm/relay/backend/interpreter.py | 81 +- python/tvm/relay/build_module.py | 3 + python/tvm/relay/frontend/common.py | 5 +- python/tvm/relay/testing/__init__.py | 8 +- src/ir/module.cc | 51 +- src/ir/transform.cc | 2 +- src/relay/backend/interpreter.cc | 822 ++++++++++++------ src/relay/backend/te_compiler.cc | 320 +++++-- src/relay/backend/te_compiler_cache.cc | 31 +- src/relay/backend/te_compiler_cache.h | 1 + src/relay/transforms/fold_constant.cc | 27 +- src/relay/transforms/partial_eval.cc | 24 +- src/runtime/cuda/cuda_device_api.cc | 18 + src/tir/analysis/verify_memory.cc | 3 + tests/crt/aot_executor_test.cc | 10 +- tests/crt/framing_test.cc | 12 +- tests/crt/memory_test.cc | 2 +- tests/crt/session_test.cc | 14 +- tests/python/contrib/test_onnx.py | 5 +- tests/python/contrib/test_onnx_model.py | 7 +- tests/python/contrib/test_tensorrt.py | 63 +- .../test_vitis_ai_runtime_cpu_part.py | 6 +- tests/python/frontend/mxnet/test_forward.py | 296 ++++--- tests/python/frontend/onnx/test_forward.py | 20 +- tests/python/frontend/pytorch/test_forward.py | 3 +- tests/python/frontend/pytorch/test_lstm.py | 6 +- .../frontend/tensorflow/test_control_flow.py | 3 +- .../frontend/tensorflow/test_debugging.py | 3 +- .../frontend/tensorflow/test_forward.py | 5 +- .../python/frontend/tensorflow/test_no_op.py | 3 +- tests/python/frontend/tflite/test_forward.py | 5 +- .../relay/dyn/test_dynamic_op_level10.py | 22 +- .../relay/dyn/test_dynamic_op_level2.py | 6 +- .../relay/dyn/test_dynamic_op_level3.py | 5 +- .../relay/dyn/test_dynamic_op_level4.py | 5 +- .../relay/dyn/test_dynamic_op_level5.py | 5 +- .../relay/dyn/test_dynamic_op_level6.py | 5 +- tests/python/relay/test_adt.py | 77 +- tests/python/relay/test_any.py | 13 +- .../relay/test_backend_graph_executor.py | 11 +- .../python/relay/test_backend_interpreter.py | 112 ++- tests/python/relay/test_debug.py | 6 +- tests/python/relay/test_external_codegen.py | 8 +- tests/python/relay/test_memory_passes.py | 8 +- tests/python/relay/test_op_grad_level1.py | 10 +- tests/python/relay/test_op_grad_level2.py | 15 +- tests/python/relay/test_op_grad_level3.py | 10 +- tests/python/relay/test_op_level1.py | 61 +- tests/python/relay/test_op_level10.py | 65 +- tests/python/relay/test_op_level2.py | 136 +-- tests/python/relay/test_op_level3.py | 163 ++-- tests/python/relay/test_op_level4.py | 52 +- tests/python/relay/test_op_level5.py | 135 +-- tests/python/relay/test_op_level6.py | 15 +- tests/python/relay/test_op_qnn_add.py | 30 +- tests/python/relay/test_op_qnn_concatenate.py | 23 +- tests/python/relay/test_op_qnn_mul.py | 25 +- tests/python/relay/test_op_qnn_subtract.py | 5 +- .../python/relay/test_pass_alter_op_layout.py | 12 +- .../python/relay/test_pass_annotate_target.py | 10 +- tests/python/relay/test_pass_auto_quantize.py | 5 +- .../relay/test_pass_defunctionalization.py | 25 +- .../relay/test_pass_dynamic_to_static.py | 10 +- .../test_pass_fake_quantization_to_integer.py | 16 +- .../relay/test_pass_fold_explicit_padding.py | 10 +- tests/python/relay/test_pass_fuse_ops.py | 6 +- tests/python/relay/test_pass_gradient.py | 42 +- .../relay/test_pass_lazy_gradient_init.py | 39 +- tests/python/relay/test_pass_manager.py | 32 +- tests/python/relay/test_pass_partial_eval.py | 14 +- .../python/relay/test_pass_partition_graph.py | 15 +- .../relay/test_pass_to_a_normal_form.py | 5 +- .../test_pass_to_basic_block_normal_form.py | 14 +- tests/python/relay/test_pass_to_cps.py | 3 +- .../relay/test_pass_to_graph_normal_form.py | 4 +- tests/python/relay/test_tensor_array.py | 5 +- tests/python/relay/test_to_mixed_precision.py | 3 +- tests/python/relay/test_vm.py | 9 +- tests/python/relay/test_vm_serialization.py | 3 +- .../python/topi/python/test_topi_transform.py | 14 +- .../python/unittest/test_custom_datatypes.py | 8 +- .../python/unittest/test_runtime_container.py | 3 +- .../unittest/test_target_codegen_vulkan.py | 3 +- tutorials/dev/bring_your_own_datatypes.py | 22 +- tutorials/frontend/deploy_quantized.py | 4 +- tutorials/frontend/from_keras.py | 12 +- tutorials/frontend/from_onnx.py | 4 +- 92 files changed, 2033 insertions(+), 1330 deletions(-) diff --git a/docs/conf.py b/docs/conf.py index 67e4f2c4098a..eaa17abef5de 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -318,10 +318,8 @@ def __call__(self, filename): # collecting TVM packed function closures for any device memory to also be released. This # is not a good setup for machines with lots of CPU ram but constrained GPU ram, so force # a gc after each example. -def force_gc(gallery_cong, fname): - print("(Forcing Python gc after '{}' to avoid lag in reclaiming CUDA memory)".format(fname)) +def force_gc(gallery_conf, fname): gc.collect() - print("(Remaining garbage: {})".format(gc.garbage)) sphinx_gallery_conf = { @@ -341,7 +339,7 @@ def force_gc(gallery_cong, fname): "download_all_examples": False, "min_reported_time": 60, "expected_failing_examples": [], - "reset_modules": (force_gc, "matplotlib", "seaborn"), + "reset_modules": ("matplotlib", "seaborn", force_gc), } autodoc_default_options = { diff --git a/docs/langref/relay_pattern.rst b/docs/langref/relay_pattern.rst index b74c58921d3f..68e77ecfa43e 100644 --- a/docs/langref/relay_pattern.rst +++ b/docs/langref/relay_pattern.rst @@ -406,7 +406,7 @@ Either match the first pattern or the second pattern. Domination ********** -Match child pattern, find a match for the parent pattern, insuring that the child ultimately dominates the parrent (i.e., no nodes outside the pattern use outputs of the parent), and that ever node betwen the child and the pattern matches the path pattern. +Match child pattern, find a match for the parent pattern, insuring that the child ultimately dominates the parent (i.e., no nodes outside the pattern use outputs of the parent), and that ever node between the child and the pattern matches the path pattern. Function Pattern **************** diff --git a/include/tvm/ir/module.h b/include/tvm/ir/module.h index 9ca27ec3b661..fefb08f878ef 100644 --- a/include/tvm/ir/module.h +++ b/include/tvm/ir/module.h @@ -36,6 +36,7 @@ #include #include #include +#include #include namespace tvm { @@ -307,6 +308,14 @@ class IRModuleNode : public Object { /*! \brief Helper function for registering a typedef's constructors */ void RegisterConstructors(const GlobalTypeVar& var, const TypeData& type); + /*! + * \brief Returns a version of \p name which is unique amongst all function definitions in module. + * + * \param name The original name. + * \return Updated name which is unique. + */ + String GetUniqueName(const String& name); + /*! \brief A map from string names to global variables that * ensures global uniqueness. */ @@ -361,16 +370,38 @@ class IRModule : public ObjectRef { } /*! - * \brief Construct a module from a standalone expression. + * \brief Constructs a module from a standalone expression \p expr. + * + * If \p expr is a function it will be bound directly. Otherwise a function over the free + * variables of \p expr (possibly none) with \p expr as body is created and bound. + * + * The function is bound to, in preference order: + * - The "global_symbol" attribute of \p expr, if it is a function with that attribute. + * - 'main' + * - A unique name derived from 'main' if 'main' is already bound in \p global_funcs. * - * Allows one to optionally pass a global function map and - * map of type definitions as well. + * Additional global functions and type definitions may be included in the result module. + * + * See also \p FromExpr. * * \param expr The expression to set as the main function to the module. - * \param global_funcs The global function map. - * \param type_definitions Map of global type definitions + * \param global_funcs The global function map. Default empty. + * \param type_definitions The global type definition map. Default empty. + * \param import_set Set of external modules already imported. Default empty. + * + * \returns A module with \p expr set as the main function, and the global var to which + * \p expr was bound (typcially 'main'). * - * \returns A module with expr set as the main function. + * TODO(mbs): Does import_set and the bound global var need to be exposed via ffi? + */ + static std::pair FromExprInContext( + const RelayExpr& expr, const Map& global_funcs = {}, + const Map& type_definitions = {}, + std::unordered_set import_set = {}); + + /*! + * \brief As for \p FromExprInContext, but assuming \p expr is bound to 'main' and no + * imports. */ TVM_DLL static IRModule FromExpr(const RelayExpr& expr, const Map& global_funcs = {}, diff --git a/include/tvm/relay/interpreter.h b/include/tvm/relay/interpreter.h index 93a56cede77b..eed6d0ffc1e4 100644 --- a/include/tvm/relay/interpreter.h +++ b/include/tvm/relay/interpreter.h @@ -40,31 +40,11 @@ #include #include +#include + namespace tvm { namespace relay { -/*! - *\brief Create a Interpreter function that can - * evaluate an expression and produce a value. - * - * The resulting value can be passed to Python, making it easy to use - * for testing and debugging. - * - * The interpreter interprets the program fragments not supported by the - * TVM runtime, although the interpreter is naively implemented it uses - * TVM operators for evaluating all operators. - * - * Our intent is that this will never be the most efficient implementation of - * Relay's semantics, but a readable and clear one. - * - * \param mod The function module. - * \param device The primary device that the interepreter runs on. - * \param target Compiler target flag to compile the functions on the context. - * \return A function that takes in an expression and returns a value. - */ -runtime::TypedPackedFunc CreateInterpreter(IRModule mod, Device device, - Target target); - /*! \brief The container type of Closures used by the interpreter. */ class InterpreterClosureObj : public runtime::ClosureObj { public: @@ -164,6 +144,52 @@ class ConstructorValue : public ObjectRef { TVM_DEFINE_OBJECT_REF_METHODS(ConstructorValue, ObjectRef, ConstructorValueObj); }; +/*! + * \brief Returns a packed function over Relay expressions which will evaluate \p expr + * applied to those arguments, where \p expr is w.r.t. the definitions in \p mod. + * + * This function is intended to support the Python 'debug' executor. + * + * The given \p expr should have function type. The given \p mod may be empty or + * undefined if \p expr is self-contained. Relay arguments passed to the result + * packed function must be constants, references, or constructors/tuples over such. + * As much work as possible is done while constructing the result packed function, and + * that function may be reasonably efficiently applied multiple times without redoing + * unnecessary work. + * + * Primitives are lowered and compiled to packed functions for execution on \p device + * with properties given by \p target. All other Relay constructs are interpreted. + * + * The interpreter is intended to be a 'reference' implementation of the Relay semantics + * for testing and interactive use. It is not intended to be particularly efficient. + * + * \param mod A module containing definitions which can be referenced from + * \p expr. May be empty or undefined. + * \param expr An expression of function type to evaluate. May reference definitions from \p mod. + * \param device The device on which all primitives will be executed. + * \param target The compiler target flag for compiling primitives. + * \return A packed function that takes an array of Relay expressions and returns the + * result of applying \p expr to those arguments. + */ +TypedPackedFunc)> EvalFunction(IRModule mod, Expr expr, Device device, + Target target); + +/*! + * \brief Evaluates \p expr and returns its result. + * + * This function is intended to support TVM constant evaluation. + * + * \param expr An expression to evaluate. + * \param type_definitions Global type definitions which \p expr may references. + * \param import_set Already imported external modules. + * \param device The device on which all primitives will be executed. + * \param target The compiler target flag for compiling primitives. + * @return The object representing the result. + */ +ObjectRef Eval(Expr expr, Map type_definitions, + std::unordered_set import_set, Device device, Target target); + } // namespace relay } // namespace tvm + #endif // TVM_RELAY_INTERPRETER_H_ diff --git a/python/tvm/relay/analysis/analysis.py b/python/tvm/relay/analysis/analysis.py index a8f1a993552e..c7b6c60849a1 100644 --- a/python/tvm/relay/analysis/analysis.py +++ b/python/tvm/relay/analysis/analysis.py @@ -433,8 +433,7 @@ def get_calibration_data(mod, data): mod = _ffi_api.get_calibrate_module(mod) mod = transform.Inline()(mod) - ref_ex = build_module.create_executor("graph", mod=mod, device=cpu(0)) - ref_res = ref_ex.evaluate()(**data) + ref_res = build_module.create_executor("graph", mod=mod, device=cpu(0)).evaluate()(**data) calib_data = {} for gvar, indices in output_map.items(): diff --git a/python/tvm/relay/backend/interpreter.py b/python/tvm/relay/backend/interpreter.py index 81edf74a0a03..819e5eda41f5 100644 --- a/python/tvm/relay/backend/interpreter.py +++ b/python/tvm/relay/backend/interpreter.py @@ -22,10 +22,9 @@ import tvm._ffi from tvm.runtime import container, Object -from tvm.ir import IRModule from . import _backend -from .. import _make, analysis, transform +from .. import _make, analysis from ... import nd from ..expr import Tuple, RefCreate, Call, Constant, GlobalVar, const from ..function import Function @@ -178,6 +177,7 @@ def evaluate(self, expr=None, binds=None): return self._make_executor(expr) # normal expression evaluated by running a function. + # TODO(mbs): This should really be type rather than syntax driven. func = Function([], expr) return self._make_executor(func)() @@ -196,6 +196,23 @@ class Interpreter(Executor): target : tvm.Target The target option to build the function. + + CAUTION: Despite the API the module is prepared upon each call to evaluate + rather than once in create_executor. + That is: + .. code-block:: python + + executor = relay.create_executor(kind="debug", mod=module) + a = executor.evaluate(expr)(args1) + b = executor.evaluate(expr)(args2) + + will prepare all the bindings in module twice. For efficiency, try to hoist + calls to evaluate as high as possible, preferably immediately after create_executor: + .. code-block:: python + + func = relay.create_executor(kind="debug", mod=module).evaluate(expr) + a = func(args1) + b = func(args2) """ def __init__(self, mod, device, target): @@ -203,58 +220,30 @@ def __init__(self, mod, device, target): self.device = device self.target = target - def optimize(self): - """Optimize functions in a module. - - Returns - ------- - opt_mod : tvm.IRModule - The optimized module. - """ - seq = tvm.transform.Sequential( - [ - # tvm.parser.AnnotateSpans(), - transform.SimplifyInference(), - transform.FuseOps(0), - transform.ToANormalForm(), - transform.InferType(), - ] - ) - mod = seq(self.mod) - return mod - def _make_executor(self, expr=None): if expr is None or isinstance(expr, GlobalVar): assert self.mod is not None - _intrp = _backend.CreateInterpreter(self.optimize(), self.device, self.target) + if expr is None: + # A missing expr denotes 'main' in the given module. + expr = self.mod.get_global_var("main") - def _interp_wrapper(*args, **kwargs): - if expr is None: - args = self._convert_args(self.mod["main"], args, kwargs) + # Evaluate expr to a packed function we can efficiently re-apply + # to Relay arguments. + func = _backend.EvalFunction(self.mod, expr, self.device, self.target) + + def _apply_args(*args, **kwargs): + if isinstance(expr, GlobalVar): + # When expanding args, look inside the actual global definition so kwargs + # can be matched. + args = self._convert_args(self.mod[expr.name_hint], args, kwargs) else: args = self._convert_args(expr, args, kwargs) - + # Reflect python arguments up into Relay. relay_args = [] for arg in args: relay_args.append(_arg_to_ast(self.mod, arg)) + # Apply func to Relay args + return func(relay_args) - # Set the entry function for the module. - if expr is None: - pass - elif isinstance(expr, GlobalVar): - self.mod["main"] = self.mod[expr] - else: - assert isinstance(expr, Function) - func = Function([], Call(expr, relay_args)) - relay_args = [] - if self.mod: - self.mod["main"] = func - else: - self.mod = IRModule.from_expr(func) - - mod = self.optimize() - opt_expr = Call(mod["main"], relay_args) - return _intrp(opt_expr) - - return _interp_wrapper + return _apply_args diff --git a/python/tvm/relay/build_module.py b/python/tvm/relay/build_module.py index d1cf1c9bea2f..c67ac1dc423d 100644 --- a/python/tvm/relay/build_module.py +++ b/python/tvm/relay/build_module.py @@ -511,6 +511,9 @@ def _graph_wrapper(*args, **kwargs): return _graph_wrapper +# TODO(mbs): Collapse the create_executor/evaluate phases together since a) most callers don't +# reuse the executor for multiple expressions and b) any preparation necessary for the expression +# evaluation needs to (currently) be done along with preparation for the module. def create_executor(kind="debug", mod=None, device=None, target="llvm", params=None): """Factory function to create an executor. diff --git a/python/tvm/relay/frontend/common.py b/python/tvm/relay/frontend/common.py index 7f67ed404de9..077b942ddf01 100755 --- a/python/tvm/relay/frontend/common.py +++ b/python/tvm/relay/frontend/common.py @@ -545,11 +545,12 @@ def infer_value(input_val, params, mod=None): mod["main"] = _function.Function(analysis.free_vars(input_val), input_val) else: mod = IRModule.from_expr(input_val) - exc = tvm.relay.create_executor("debug", mod=mod, device=tvm.cpu(), target="llvm") inputs = [] for param in mod["main"].params: inputs.append(params[param.name_hint]) - result = exc.evaluate()(*inputs) + result = tvm.relay.create_executor( + "debug", mod=mod, device=tvm.cpu(), target="llvm" + ).evaluate()(*inputs) return result diff --git a/python/tvm/relay/testing/__init__.py b/python/tvm/relay/testing/__init__.py index de85ed69238a..8eb07d7b583b 100644 --- a/python/tvm/relay/testing/__init__.py +++ b/python/tvm/relay/testing/__init__.py @@ -134,10 +134,13 @@ def check_grad( test_inputs = inputs for target, dev in enabled_targets(): - intrp = relay.create_executor(device=dev, target=target) + # Eval the backward and forward functions + # TODO(mbs): Evaluate a pair of functions so can share preparation between them. + bwd_func_compiled = relay.create_executor(device=dev, target=target).evaluate(bwd_func) + fwd_func_compiled = relay.create_executor(device=dev, target=target).evaluate(fwd_func) # Get analytic gradients. - _, grads = intrp.evaluate(bwd_func)(*inputs) + _, grads = bwd_func_compiled(*inputs) grads = [grad.numpy().astype("float64") for grad in grads] # Throw out gradients we aren't testing @@ -154,7 +157,6 @@ def check_grad( assert len(grads) > 0, "You must test at least one gradient." # Get numeric gradients for each dimension of each param, using two-sided approximation. - fwd_func_compiled = intrp.evaluate(fwd_func) approx_grads = [] for x in test_inputs: approx_grad = np.zeros(x.shape) diff --git a/src/ir/module.cc b/src/ir/module.cc index 7990b281fb04..d4129c84ccf5 100644 --- a/src/ir/module.cc +++ b/src/ir/module.cc @@ -284,6 +284,20 @@ Constructor IRModuleNode::LookupTag(const int32_t tag) { return (*it).second; } +String IRModuleNode::GetUniqueName(const String& name) { + String result = name; + int suffix = 0; + while (true) { + auto it = global_var_map_.find(result); + if (it == global_var_map_.end()) { + return result; + } + std::ostringstream os; + os << name << "_" << ++suffix; + result = os.str(); + } +} + struct Renamer : relay::ExprMutator, TypeMutator { Map defs; Map types; @@ -347,25 +361,38 @@ void IRModuleNode::Update(const IRModule& mod) { } } -IRModule IRModule::FromExpr(const RelayExpr& expr, - const tvm::Map& global_funcs, - const tvm::Map& type_definitions) { - auto mod = IRModule(global_funcs, type_definitions); - BaseFunc func; - std::string gv_name = "main"; +std::pair IRModule::FromExprInContext( + const RelayExpr& expr, const tvm::Map& global_funcs, + const tvm::Map& type_definitions, + std::unordered_set import_set) { + auto mod = IRModule(global_funcs, type_definitions, std::move(import_set)); + String gv_name; + // All global definitions must be functions. + BaseFunc func; if (auto* func_node = expr.as()) { func = GetRef(func_node); if (auto opt = func->GetAttr(tvm::attr::kGlobalSymbol)) { + // Function literal has been annotated with it's required global symbol. gv_name = opt.value(); } - } else { func = relay::Function(relay::FreeVars(expr), expr, Type(), relay::FreeTypeVars(expr, mod), {}); } - auto main_gv = GlobalVar(gv_name); + + if (gv_name.empty()) { + // Bind function to 'main' (though rename if would clash with existing 'main'). + gv_name = mod->GetUniqueName("main"); + } + + GlobalVar main_gv(gv_name); mod->Add(main_gv, func); - return mod; + return {mod, main_gv}; +} + +IRModule IRModule::FromExpr(const RelayExpr& expr, const Map& global_funcs, + const Map& type_definitions) { + return FromExprInContext(expr, global_funcs, type_definitions).first; } void IRModuleNode::Import(const String& path) { @@ -465,11 +492,7 @@ TVM_REGISTER_GLOBAL("ir.Module_LookupTag").set_body_typed([](IRModule mod, int32 return mod->LookupTag(tag); }); -TVM_REGISTER_GLOBAL("ir.Module_FromExpr") - .set_body_typed([](RelayExpr e, tvm::Map funcs, - tvm::Map type_defs) { - return IRModule::FromExpr(e, funcs, type_defs); - }); +TVM_REGISTER_GLOBAL("ir.Module_FromExpr").set_body_typed(&IRModule::FromExpr); TVM_REGISTER_GLOBAL("ir.Module_Update").set_body_typed([](IRModule mod, IRModule from) { mod->Update(from); diff --git a/src/ir/transform.cc b/src/ir/transform.cc index 426bdc9c1800..4c37f0f1a6e9 100644 --- a/src/ir/transform.cc +++ b/src/ir/transform.cc @@ -466,7 +466,7 @@ Pass GetPass(const String& pass_name) { return (*f)(); } -// TODO(zhiics): we currenlty only sequentially execute each pass in +// TODO(zhiics): we currently only sequentially execute each pass in // a Sequential without the consideration of their orders. The phase // ordering problem needs to be handled in the future. IRModule SequentialNode::operator()(IRModule mod, const PassContext& pass_ctx) const { diff --git a/src/relay/backend/interpreter.cc b/src/relay/backend/interpreter.cc index 6ebb17e93eca..af2cbae1f72d 100644 --- a/src/relay/backend/interpreter.cc +++ b/src/relay/backend/interpreter.cc @@ -21,14 +21,17 @@ * \file src/relay/interpreter.cc * \brief An interpreter for the Relay IR. */ + #include #include +#include #include #include #include #include #include #include +#include #include #include @@ -39,9 +42,81 @@ namespace tvm { namespace relay { -using namespace runtime; +using runtime::ADT; +using runtime::ADTObj; +using runtime::NDArray; +using runtime::TVMArgsSetter; +using runtime::operator<<; + +namespace { +// TODO(mbs): Centralize. +struct PairHash { + template + std::size_t operator()(const std::pair& k) const { + return std::hash()(k.first) ^ std::hash()(k.second); + } +}; + +// Analogue of FlattenTupleType for runtime ADT vs NDArray values. +// TODO(mbs): Hoist somewhere sensible, maybe op/memory.h? +void FlattenADTAux(const ObjectRef& object_ref, std::vector* out) { + if (const NDArray::ContainerType* ndarray = object_ref.as()) { + out->push_back(GetRef(ndarray)); + } else if (const ADTObj* adt = object_ref.as()) { + for (size_t i = 0; i < adt->size; ++i) { + FlattenADTAux((*adt)[i], out); + } + } else { + LOG(FATAL) << "unsupported " << object_ref; + } +} + +std::vector FlattenADT(const ObjectRef& object_ref) { + std::vector out; + FlattenADTAux(object_ref, &out); + return out; +} -InterpreterClosure::InterpreterClosure(tvm::Map env, Function func) { +std::vector FlattenADTs(const std::vector& object_refs) { + std::vector out; + for (const auto& object_ref : object_refs) { + FlattenADTAux(object_ref, &out); + } + return out; +} + +// Analogue of ToTupleType for runtime ADT vs NDArray values. +// TODO(mbs): Hoist somewhere sensible, maybe op/memory.h? +void ToADTOrNDArrayAux(const Type& type, const std::vector& nd_arrays, int* index, + std::vector* out) { + if (type.as()) { + out->push_back(nd_arrays[*index]); + *index += 1; + } else if (const TupleTypeNode* ttn = type.as()) { + std::vector tuple_out; + for (size_t i = 0; i < ttn->fields.size(); i++) { + ToADTOrNDArrayAux(ttn->fields[i], nd_arrays, index, &tuple_out); + } + out->push_back(ADT::Tuple(tuple_out)); + } else { + LOG(FATAL) << "unsupported " << type; + } +} + +ObjectRef ToADTOrNDArray(const Type& type, const std::vector& nd_arrays) { + if (type.as() && nd_arrays.size() == 1) { + return nd_arrays[0]; + } else { + std::vector out; + int index = 0; + ToADTOrNDArrayAux(type, nd_arrays, &index, &out); + return out[0]; + } +} + +} // namespace + +InterpreterClosure::InterpreterClosure(Map env, Function func) { ObjectPtr n = make_object(); n->env = std::move(env); n->func = std::move(func); @@ -55,7 +130,7 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) }); inline const PackedFunc& GetPackedFunc(const std::string& name) { - const PackedFunc* pf = tvm::runtime::Registry::Get(name); + const PackedFunc* pf = runtime::Registry::Get(name); ICHECK(pf != nullptr) << "Cannot find function " << name << " in registry"; return *pf; } @@ -93,8 +168,7 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) p->stream << "RefValueObj(" << node->value << ")"; }); -ConstructorValue::ConstructorValue(int32_t tag, tvm::Array fields, - Constructor constructor) { +ConstructorValue::ConstructorValue(int32_t tag, Array fields, Constructor constructor) { ObjectPtr n = make_object(); n->tag = tag; n->fields = fields; @@ -103,7 +177,7 @@ ConstructorValue::ConstructorValue(int32_t tag, tvm::Array fields, } TVM_REGISTER_GLOBAL("relay._make.ConstructorValue") - .set_body_typed([](int32_t tag, tvm::Array fields, Constructor constructor) { + .set_body_typed([](int32_t tag, Array fields, Constructor constructor) { return ConstructorValue(tag, fields, constructor); }); @@ -122,9 +196,9 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) */ struct Frame { /*! \brief The set of local variables and arguments for the frame. */ - tvm::Map locals; + Map locals; - explicit Frame(tvm::Map locals) : locals(locals) {} + explicit Frame(Map locals) : locals(locals) {} }; /*! @@ -169,8 +243,8 @@ class InterpreterState; /*! \brief A container capturing the state of the interpreter. */ class InterpreterStateObj : public Object { public: - using Frame = tvm::Map; - using Stack = tvm::Array; + using Frame = Map; + using Stack = Array; /*! \brief The current expression under evaluation. */ Expr current_expr; @@ -178,7 +252,7 @@ class InterpreterStateObj : public Object { /*! \brief The call stack of the interpreter. */ Stack stack; - void VisitAttrs(tvm::AttrVisitor* v) { + void VisitAttrs(AttrVisitor* v) { v->Visit("current_expr", ¤t_expr); v->Visit("stack", &stack); } @@ -189,8 +263,8 @@ class InterpreterStateObj : public Object { class InterpreterState : public ObjectRef { public: - using Frame = tvm::Map; - using Stack = tvm::Array; + using Frame = Map; + using Stack = Array; InterpreterState(Expr current_expr, Stack stack); @@ -214,8 +288,13 @@ InterpreterState::InterpreterState(Expr current_expr, InterpreterState::Stack st class Interpreter : public ExprFunctor, PatternFunctor { public: - Interpreter(IRModule mod, Device device, Target target) - : mod_(mod), device_(device), target_(target), debug_op_(Op::Get("debug")) {} + // TODO(mbs): Collapse mod and per_target_module once IRModule subsumes LoweredModule. + Interpreter(IRModule mod, Map per_target_module, Device device, Target target) + : mod_(mod), + per_target_module_(per_target_module), + device_(device), + target_(target), + debug_op_(Op::Get("debug")) {} template T WithFrame(const Frame& fr, const std::function& f) { @@ -238,8 +317,7 @@ class Interpreter : public ExprFunctor, ObjectRef VisitExpr_(const OpNode* id) override { // TODO(@jroesch): Eta-expand and return in this case. LOG(FATAL) << "internal error, need to wrap intrinsic into call synthetic call node " - << "in " - << "this case, eta expand"; + << "in this case, eta expand"; return ObjectRef(); } @@ -257,7 +335,7 @@ class Interpreter : public ExprFunctor, } ObjectRef MakeClosure(const Function& func, Var letrec_name = Var()) { - tvm::Map captured_mod; + Map captured_mod; Array free_vars = FreeVars(func); for (const auto& var : free_vars) { @@ -283,251 +361,301 @@ class Interpreter : public ExprFunctor, return MakeClosure(func); } - Array ComputeDynamicShape(const Function& func, const Array& args) { - CCacheKey key(func, Target("llvm")); - auto cfunc = compiler_->LowerShapeFunc(key); - size_t arity = cfunc->inputs.size() + cfunc->outputs.size(); + /*! + * \brief Returns the packed function implementing the TIR function bound to \p tir_fn_var. + * + * \param tir_fn_var Global var for the already lowered TIR function. + * \param all_tir_fn_vars Global vars for all lowered TIR functions the above + * may reference, plus \p tir_fn_var itself. + * \param target Target for which the TIR function should be compiled. For primitives this + * will be the interpreter's target_. However for shape functions this will be the generic + * 'cpu' target, since shape functions are always executed on the host cpu. + */ + PackedFunc TIRToPackedFunc(const GlobalVar& tir_fn_var, const Array& all_tir_fn_vars, + Target target) { + std::pair packed_func_key(target->str(), tir_fn_var->name_hint); + auto packed_itr = compiled_packed_funcs_.find(packed_func_key); + if (packed_itr != compiled_packed_funcs_.end()) { + // Already compiled. + return packed_itr->second; + } + + // Project out just the function(s) we need. + IRModule lowered_projected_mod; + auto mod_itr = per_target_module_.find(target->str()); + ICHECK(mod_itr != per_target_module_.end()) + << "No target module for target '" << target->str() << "'"; + const IRModule& target_module = (*mod_itr).second; + for (const auto& var : all_tir_fn_vars) { + ICHECK(target_module->ContainGlobalVar(var->name_hint)) + << "No global var for '" << var->name_hint << "' in module for target '" << target->str() + << "'"; + lowered_projected_mod->Add(var, target_module->Lookup(var->name_hint)); + } + + // Compile (aka 'build') the projected module into a runtime module of packed functions. + runtime::Module runtime_module; + if (const auto* f = runtime::Registry::Get("relay.backend.build")) { + // TODO(mbs): Cleanup hooks. + runtime_module = (*f)(lowered_projected_mod, target); + } else { + runtime_module = build(lowered_projected_mod, target, /*target_host=*/Target(nullptr)); + } + + // Extract all the packed functions. + for (const auto& var : all_tir_fn_vars) { + PackedFunc packed_func = runtime_module.GetFunction(var->name_hint); + ICHECK(packed_func != nullptr) << "No packed function for global var '" << var->name_hint + << "' in compiled module for target '" << target->str() << "'"; + compiled_packed_funcs_.emplace(std::make_pair(target->str(), var->name_hint), packed_func); + } + + // Return just what we need for this call. + packed_itr = compiled_packed_funcs_.find(packed_func_key); + ICHECK(packed_itr != compiled_packed_funcs_.end()) << " " << tir_fn_var->name_hint; + ICHECK_NOTNULL(packed_itr->second); + return packed_itr->second; + } + /*! + * \brief Call the dynamic shape function bound to \p prim_shape_fn_var passing the + * shapes of args, and return the resulting shapes. + * + * \param prim_shape_fn_var Global var bound to lowered shape function. + * \param all_prim_shape_fn_vars All the global vars needed to build the above, including + * the shape function itself. + * \param prim_shape_fn_states For each primitive arg, indicate whether the primitive shape + * function requires the shape of the argument and/or the actual argument tensor. + * \param num_shape_inputs The number of inputs, after accounting for both shapes vs data + * inputs and unfolding of tuple types. + * \param num_shape_outputs The number of outputs, after accounting for flattening of + * tuple types. + * \param args Arguments to the primitive this shape function is for. + * \return Expected shapes of the underlying primitive's flattened outputs. + */ + Array ComputeDynamicShape(const GlobalVar& prim_shape_fn_var, + const Array& all_prim_shape_fn_vars, + const Array& prim_shape_fn_states, + size_t num_shape_inputs, size_t num_shape_outputs, + const std::vector& args) { + ICHECK(prim_shape_fn_var.defined()); + ICHECK(prim_shape_fn_states.defined()); + ICHECK(prim_shape_fn_var->checked_type().defined()); + // The function type is that of the original primitive rather than the shape function + // itself. We currently can't express shape function types in Relay. + const FuncTypeNode* ftn = prim_shape_fn_var->checked_type().as(); + ICHECK(ftn); + // The primitive shape function states are w.r.t. the primitive's arguments in + // non-flattened form. + // TODO(mbs): Clean this up so we don't mix flattened vs original conventions. + ICHECK_EQ(prim_shape_fn_states.size(), ftn->arg_types.size()); + ICHECK_EQ(args.size(), ftn->arg_types.size()); + // num_shape_inputs will account for which primitive function arguments are dynamic, + // whether the shape and or data needs to be passed, and flattening of tuples. + // Similarly, num_shape_outputs will account for flattening of tuples. + + // Shape functions always run on the cpu + Device shape_device; + shape_device.device_type = kDLCPU; + shape_device.device_id = 0; + Target shape_target("llvm"); + + // 'Compile' the TIR shape function to appropriate callable form. + PackedFunc packed_shape_func = + TIRToPackedFunc(prim_shape_fn_var, all_prim_shape_fn_vars, shape_target); + + size_t arity = num_shape_inputs + num_shape_outputs; std::vector values(arity); std::vector codes(arity); TVMArgsSetter setter(values.data(), codes.data()); - std::vector inputs(cfunc->inputs.size()); - std::vector outputs(cfunc->outputs.size()); - - Device cpu_dev; - cpu_dev.device_type = kDLCPU; - cpu_dev.device_id = 0; - - auto fset_input = [&](size_t i, ObjectRef val, bool need_shape) { - auto nd_array = Downcast(val); - if (need_shape) { - int64_t ndim = nd_array.Shape().size(); - NDArray shape_arr; - if (ndim == 0) { - shape_arr = NDArray::Empty({}, DataType::Int(64), cpu_dev); - } else { - shape_arr = NDArray::Empty({ndim}, DataType::Int(64), cpu_dev); - int64_t* data = reinterpret_cast(shape_arr->data); - for (auto j = 0; j < ndim; ++j) { - data[j] = nd_array.Shape()[j]; - } - } - inputs[i] = shape_arr; - setter(i, shape_arr); - } else { - auto arr = nd_array.CopyTo(cpu_dev); - inputs[i] = arr; - setter(i, arr); - } - }; + std::vector inputs(num_shape_inputs); + std::vector outputs(num_shape_outputs); + // Collect the shapes and/or data needed by the shape function from + // the primitive's arguments. size_t arg_counter = 0; for (size_t i = 0; i < args.size(); ++i) { - auto arg = args[i]; - auto param = func->params[i]; - int state = cfunc->shape_func_param_states[i]->value; - if (arg->IsInstance()) { - if (state & kNeedInputData) { - fset_input(arg_counter++, arg, false); - } - if (state & kNeedInputShape) { - fset_input(arg_counter++, arg, true); - } - } else { - const ADT adt = Downcast(arg); + // TODO(mbs): The same need data/need shape arg state applies to everything in the + // flattened form of this arg. Does that match what lowering actually does? + int64_t state = prim_shape_fn_states[i]->value; + for (const auto& nd_array : FlattenADT(args[i])) { if (state & kNeedInputData) { - for (size_t i = 0; i < adt.size(); ++i) { - fset_input(arg_counter++, adt[i], false); - } + auto arr = nd_array.CopyTo(shape_device); + inputs[arg_counter] = arr; + setter(arg_counter, arr); + ++arg_counter; } if (state & kNeedInputShape) { - for (size_t i = 0; i < adt.size(); ++i) { - fset_input(arg_counter++, adt[i], true); + int64_t ndim = nd_array.Shape().size(); + NDArray shape_arr; + if (ndim == 0) { + shape_arr = NDArray::Empty({}, DataType::Int(64), shape_device); + } else { + shape_arr = NDArray::Empty({ndim}, DataType::Int(64), shape_device); + int64_t* data = reinterpret_cast(shape_arr->data); + for (auto j = 0; j < ndim; ++j) { + data[j] = nd_array.Shape()[j]; + } } + inputs[arg_counter] = shape_arr; + setter(arg_counter, shape_arr); + ++arg_counter; } } } - ICHECK_EQ(arg_counter, cfunc->inputs.size()) << "Shape function input sizes mismatch"; + ICHECK_EQ(arg_counter, num_shape_inputs) << "Shape function input sizes mismatch"; - auto fset_shape_output = [&](size_t i, Type val_type) { - // TODO(@icemelon): allow recursive tuple - const TensorTypeNode* rtype = val_type.as(); - ICHECK(rtype != nullptr); - int64_t ndim = rtype->shape.size(); - auto arr = NDArray::Empty({ndim}, DataType::Int(64), cpu_dev); - outputs[i] = arr; - setter(arg_counter + i, arr); - }; - - auto ret_type = func->body->checked_type(); + // Prepare NDArrays to hold the output shapes. size_t out_cnt = 0; - if (auto rtype = ret_type.as()) { - out_cnt = rtype->fields.size(); - for (size_t i = 0; i < out_cnt; ++i) { - fset_shape_output(i, rtype->fields[i]); - } - } else { - out_cnt = 1; - auto tt = Downcast(ret_type); - fset_shape_output(0, tt); + for (const auto& ttype : FlattenTupleType(ftn->ret_type)) { + ICHECK(out_cnt < num_shape_outputs); + int64_t ndim = ttype->shape.size(); + auto arr = NDArray::Empty({ndim}, DataType::Int(64), shape_device); + outputs[out_cnt] = arr; + setter(arg_counter + out_cnt, arr); + ++out_cnt; } - ICHECK_EQ(cfunc->outputs.size(), out_cnt) << "Shape function output sizes mismatch"; + ICHECK_EQ(out_cnt, num_shape_outputs) << "Shape function output sizes mismatch"; - PackedFunc shape_func; - Module m; - TVMRetValue rv; - if (const auto* f = runtime::Registry::Get("relay.backend.build")) { - m = (*f)(cfunc->funcs, cfunc->target); - } else { - m = build(cfunc->funcs, cfunc->target, Target(nullptr)); - } - shape_func = m.GetFunction(cfunc->prim_fn_var->name_hint); - shape_func.CallPacked(TVMArgs(values.data(), codes.data(), arity), &rv); + // Call the dynamic shape function. + TVMRetValue rv; // ignored + packed_shape_func.CallPacked(TVMArgs(values.data(), codes.data(), arity), &rv); - // Get output shapes + // Convert result tensors back to shapes. Array out_shapes; for (auto out_tensor : outputs) { int64_t* shape_data = reinterpret_cast(out_tensor->data); Shape out_shape; for (int i = 0; i < out_tensor->shape[0]; ++i) { - out_shape.push_back(tvm::Integer(shape_data[i])); + out_shape.push_back(Integer(shape_data[i])); } out_shapes.push_back(out_shape); } return out_shapes; } - ObjectRef InvokePrimitiveOp(const Function& func, const Array& args) { - const auto* call_node = func->body.as(); - - if (call_node && call_node->op == debug_op_) { - auto dattrs = call_node->attrs.as(); - auto interp_state = this->get_state(call_node->args[0]); - - if (dattrs->debug_func.defined()) { - dattrs->debug_func(interp_state); - } else { - RELAY_DEBUG_INTERP(interp_state); - } - - return args[0]; - } + /*! + * \brief Call primitive op bound to \p prim_fn_var with \p args. If necessary, evaluate dynamic + * shape function bound to \p prim_shape_fn_var to calculate shapes of result tensors. + * + * @param prim_fn_var Global bound to lowered primitive. + * @param all_prim_fn_vars All globals references by lowered primitive, plus prim_fn_var itself. + * @param prim_shape_fn_var Global bound to lowered shape function for primitive, if neeeded. + * @param all_prim_shape_fn_vars All globals references by lowered shape function, plus + * prim_shape_fn_var itself. + * @param prim_shape_fn_states Records whether shape and/or data is needed by the dynamic + * shape function (if any) for each (flattened) argument. + * @param num_shape_inputs Number of arguments to the dynamic shape function (if any). + * @param num_shape_outputs Number of outputs from the dynamic shape function (if any). + * @param args Already evaluated arguments to primitive. + * @return Result of primitive. + */ + ObjectRef InvokePrimitiveOp(const GlobalVar& prim_fn_var, const Array all_prim_fn_vars, + const GlobalVar& prim_shape_fn_var, + const Array& all_prim_shape_fn_vars, + const Array& prim_shape_fn_states, size_t num_shape_inputs, + size_t num_shape_outputs, const std::vector& args) { + ICHECK(prim_fn_var->checked_type().defined()); + const FuncTypeNode* ftn = prim_fn_var->checked_type().as(); + ICHECK(ftn); + + // 'Compile' the TIR primitive to appropriate callable form (on the desired target). + PackedFunc packed_func = TIRToPackedFunc(prim_fn_var, all_prim_fn_vars, target_); + + // Argument tuples are flattened. + std::vector arg_nd_arrays = FlattenADTs(args); + const size_t num_inputs = arg_nd_arrays.size(); + // num_inputs should equal size(concat(map(FlattenTupleType, function arg types))) + + // TVM's primitive calling convention is for the final arguments to be for output + // buffers. We must allocate space for those buffers based on the return type. + std::vector result_tensor_types = FlattenTupleType(ftn->ret_type); + const size_t arg_len = num_inputs + result_tensor_types.size(); - // Marshal the arguments. - // Handle adt input/output by flattening them. - size_t arg_len = 0; - for (size_t i = 0; i < args.size(); ++i) { - if (args[i]->IsInstance()) { - ++arg_len; - } else { - auto adt = Downcast(args[i]); - arg_len += adt.size(); - } - } - size_t num_inputs = arg_len; - if (const auto* tuple_type = func->body->checked_type().as()) { - arg_len += tuple_type->fields.size(); - } else { - ICHECK(func->body->checked_type().as()) << func->body->checked_type(); - arg_len += 1; - } std::vector values(arg_len); std::vector codes(arg_len); TVMArgsSetter setter(values.data(), codes.data()); - auto fset_input = [&](size_t i, ObjectRef val) { - const auto nd_array = Downcast(val); - setter(i, nd_array); + // Marshall the call's arguments in flattened form. + int arg_counter = 0; + for (const auto& nd_array : arg_nd_arrays) { + setter(arg_counter++, nd_array); Device arg_dev = nd_array->device; ICHECK(arg_dev.device_type == device_.device_type && arg_dev.device_id == device_.device_id) - << "Interpreter expect device to be " << device_ << ", but get " << arg_dev; - }; + << "Interpreter expect device to be " << device_ << ", but got " << arg_dev; + } - int arg_counter = 0; - for (ObjectRef arg : args) { - if (arg->IsInstance()) { - fset_input(arg_counter++, arg); - } else { - auto adt = Downcast(arg); - for (size_t i = 0; i < adt.size(); ++i) { - fset_input(arg_counter++, adt[i]); - } - } + // If necessary, retrieve concrete shapes for outputs from shape function rather + // than relying on TensorType shapes. + Array runtime_shapes; + bool is_dyn = IsDynamic(ftn->ret_type); + if (is_dyn) { + ICHECK(prim_shape_fn_var.defined()); + ICHECK(prim_shape_fn_states.defined()); + runtime_shapes = + ComputeDynamicShape(prim_shape_fn_var, all_prim_shape_fn_vars, prim_shape_fn_states, + num_shape_inputs, num_shape_outputs, args); + ICHECK_EQ(runtime_shapes.size(), result_tensor_types.size()); } - // TVM's calling convention is that the final argument is the output - // buffer. To preserve the illusion of being a functional language - // we need to allocate space for the output buffer based on the - // return type. - auto fset_output = [&](size_t i, Type val_type) { - const TensorTypeNode* rtype = val_type.as(); - ICHECK(rtype != nullptr); - // Allocate output tensor. - std::vector shape; - for (auto dim : rtype->shape) { + // Prepare the result tensors for the call. + TVMRetValue rv; // ignored + std::vector result_nd_arrays; + for (size_t i = 0; i < result_tensor_types.size(); ++i) { + const auto& ttype = result_tensor_types[i]; + const Shape& shape = is_dyn ? runtime_shapes[i] : ttype->shape; + // Allocate output tensor of appropriate shape. + std::vector concrete_shape; + for (const auto& dim : shape) { const auto* ivalue = tir::as_const_int(dim); ICHECK(ivalue) << "expected concrete dimensions"; - shape.push_back(ivalue[0]); + concrete_shape.push_back(ivalue[0]); } - DLDataType dtype = rtype->dtype; - NDArray nd_array = NDArray::Empty(shape, dtype, device_); + NDArray nd_array = NDArray::Empty(concrete_shape, ttype->dtype, device_); setter(num_inputs + i, nd_array); - return nd_array; - }; + result_nd_arrays.emplace_back(nd_array); + } - Array out_shapes; - auto ret_type = func->body->checked_type(); - bool is_dyn = IsDynamic(ret_type); + // Call the primitive. + packed_func.CallPacked(TVMArgs(values.data(), codes.data(), static_cast(arg_len)), &rv); - if (is_dyn) { - ICHECK(func->HasNonzeroAttr(attr::kPrimitive)); - out_shapes = ComputeDynamicShape(func, args); - } - - PackedFunc packed_func = compiler_->JIT(CCacheKey(func, target_)); - TVMRetValue rv; - if (const TupleTypeNode* rtype = func->body->checked_type().as()) { - ICHECK(!is_dyn || out_shapes.size() == rtype->fields.size()); - std::vector fields; - for (size_t i = 0; i < rtype->fields.size(); ++i) { - if (is_dyn) { - auto sh = out_shapes[i]; - auto tt = Downcast(rtype->fields[i]); - fields.push_back(fset_output(i, TensorType(sh, tt->dtype))); - } else { - fields.push_back(fset_output(i, rtype->fields[i])); - } - } - packed_func.CallPacked(TVMArgs(values.data(), codes.data(), arg_len), &rv); - return ADT::Tuple(fields); - } else { - ObjectRef out_tensor; - if (is_dyn) { - ICHECK_EQ(out_shapes.size(), 1); - auto sh = out_shapes[0]; - auto tt = Downcast(ret_type); - out_tensor = fset_output(0, TensorType(sh, tt->dtype)); - } else { - out_tensor = fset_output(0, ret_type); - } - packed_func.CallPacked(TVMArgs(values.data(), codes.data(), arg_len), &rv); - return out_tensor; - } + // Unflatten the results. + return ToADTOrNDArray(ftn->ret_type, result_nd_arrays); } - // Invoke the closure - ObjectRef Invoke(const InterpreterClosure& closure, const tvm::Array& args, + /*! + * \brief Invoke \p closure with \p args. If \p bind is defined then this is a recursive + * closure and \p bind should refer to itself. + */ + ObjectRef Invoke(const InterpreterClosure& closure, const Array& args, const Var& bind = Var()) { // Get a reference to the function inside the closure. - if (closure->func->HasNonzeroAttr(attr::kPrimitive)) { - return InvokePrimitiveOp(closure->func, args); + Function func = closure->func; + ICHECK_EQ(func->params.size(), args.size()); + + if (func->HasNonzeroAttr(attr::kPrimitive)) { + if (const CallNode* call_node = closure->func->body.as()) { + if (call_node->op == debug_op_) { + // Special case: Calling the debug tracing function. + auto dattrs = call_node->attrs.as(); + auto interp_state = get_state(call_node->args[0]); + + if (dattrs->debug_func.defined()) { + dattrs->debug_func(interp_state); + } else { + RELAY_DEBUG_INTERP(interp_state); + } + + return args[0]; + } + } } - auto func = closure->func; - // Allocate a frame with the parameters and free variables. - tvm::Map locals; - ICHECK_EQ(func->params.size(), args.size()); + ICHECK(!func->HasNonzeroAttr(attr::kPrimitive)) + << "Calls to primitive functions should have been removed by lowering"; + // Allocate a frame with the parameters and free variables. + Map locals; for (size_t i = 0; i < func->params.size(); i++) { ICHECK_EQ(locals.count(func->params[i]), 0); locals.Set(func->params[i], args[i]); @@ -547,23 +675,63 @@ class Interpreter : public ExprFunctor, } ObjectRef VisitExpr_(const CallNode* call) final { - tvm::Array args; + std::vector args; for (auto arg : call->args) { args.push_back(Eval(arg)); } - // We should not find operators after running fusion, - // and operator lowering. - // - // We have some functions containing chunks of operators - // which will be loaded into operator map. - if (const auto* op_node = call->op.as()) { + + // We should not find calls to operators after running fusion and lowering. + if (const OpNode* op_node = call->op.as()) { LOG(FATAL) << "found " << op_node->name << "; operators should have been removed by previous passes; try " "fusing and lowering"; } - if (auto con = call->op.as()) { + + if (const ConstructorNode* con = call->op.as()) { + // Special case: ADT constructor return ConstructorValue(con->tag, args, GetRef(con)); } + + if (const GlobalVarNode* gvn = call->op.as()) { + if (const TIRCallAttrs* attrs = call->attrs.as()) { + // Special case: Call a lowered TIR function. + // TODO(mbs): Make calling convention first-class in Relay. + Array all_prim_fn_vars; + if (attrs->metadata.count("all_prim_fn_vars")) { + all_prim_fn_vars = Downcast>(attrs->metadata.at("all_prim_fn_vars")); + } + GlobalVar prim_shape_fn_var; + if (attrs->metadata.count("prim_shape_fn_var")) { + prim_shape_fn_var = Downcast(attrs->metadata.at("prim_shape_fn_var")); + } + Array all_prim_shape_fn_vars; + if (attrs->metadata.count("all_prim_shape_fn_vars")) { + all_prim_shape_fn_vars = + Downcast>(attrs->metadata.at("all_prim_shape_fn_vars")); + } + Array prim_shape_fn_states; + if (attrs->metadata.count("prim_shape_fn_states")) { + prim_shape_fn_states = + Downcast>(attrs->metadata.at("prim_shape_fn_states")); + } + size_t num_shape_inputs = 0; + if (attrs->metadata.count("prim_shape_fn_num_inputs")) { + num_shape_inputs = static_cast( + Downcast(attrs->metadata.at("prim_shape_fn_num_inputs"))->value); + } + size_t num_shape_outputs = 0; + if (attrs->metadata.count("prim_shape_fn_num_outputs")) { + num_shape_outputs = static_cast( + Downcast(attrs->metadata.at("prim_shape_fn_num_outputs"))->value); + } + + // Special case: Call TIR primitive. + return InvokePrimitiveOp(GetRef(gvn), all_prim_fn_vars, prim_shape_fn_var, + all_prim_shape_fn_vars, prim_shape_fn_states, num_shape_inputs, + num_shape_outputs, args); + } + } + // Now we just evaluate and expect to find a closure. ObjectRef fn_val = Eval(call->op); if (const InterpreterClosureObj* closure_node = fn_val.as()) { @@ -700,43 +868,211 @@ class Interpreter : public ExprFunctor, } private: - // Module + // Main module. All expressions are eval'ed w.r.t. the definitions in this module. This module + // may contain calls to TIR functions bound in a per_target_module_ below. IRModule mod_; - // For simplicity we only run the interpreter on a single context. - // Context to run the interpreter on. + // Map from target key to lowered TIR functions derived from mod_. + // Note that primitives are implicitly executed on target_, while shape functions are implicitly + // executed on the default 'cpu' host. Thus this map has at most two entries. + Map per_target_module_; + // Cached packed functions for the primitives and shape functions, keyed by target and + // global var name. + std::unordered_map, PackedFunc, PairHash> + compiled_packed_funcs_; + // Unique device on which primitives (but not shape functions) will be executed. + // (For simplicity we only run the interpreter on a single device.) Device device_; - // Target parameter being used by the interpreter. + // Unique target describing how to compile for primitives (but not shape functions). Target target_; - // Object stack. + // Call stack. Stack stack_; - // TE-to-TIR lowerer (compiler). - TECompiler compiler_; - // Cache ops that need to be frequently used later to reduce lookup overhead. + // The distinguished 'debug' operator, which is handled specially. const Op& debug_op_; }; -TypedPackedFunc CreateInterpreter(IRModule mod, Device device, Target target) { - if (mod.defined()) { - transform::Sequential seq({// eta expand to support constructors in argument position - transform::EtaExpand( - /*expand_constructor=*/true, /*expand_global_var=*/false), - transform::InferType()}); +/*! + * Lowers all calls to primitives in \p mod appropriate for device and target. Returns the + * rewritten \p mod and target-specific modules containing bindings for all TIR primitive + * functions needed by the rewritten module. + */ +std::pair> Prepare(IRModule mod, Device device, Target target) { + // Run minimal transforms on module to establish invariants needed by interpreter. + transform::Sequential seq({transform::SimplifyInference(), + // FuseOps will mark wrapped calls to prim-ops with the 'Primitive' + // attribute. + transform::FuseOps(/*fuse_opt_level=*/0), transform::ToANormalForm(), + // eta expand to support constructors in argument position + transform::EtaExpand( + /*expand_constructor=*/true, /*expand_global_var=*/false), + transform::InferType()}); + + transform::PassContext pass_ctx = transform::PassContext::Current(); + With ctx(pass_ctx); + mod = seq(mod); + + // We only have one device-specific target. + tec::TargetMap targets = {{device.device_type, target}}; + + // All calls to primitives will use the unique target. + tec::DeviceMap device_map; + + // No need for a memory plan. + backend::StaticMemoryPlan memory_plan; /*=nullptr*/ + + // Lower all primitive functions reachable from expr. + // TODO(mbs): This should be just another pass in seq above, which requires LoweredModule to + // be merged into IRModule. + LoweredModule lowered_module = + tec::LowerTE(mod, targets, device_map, memory_plan, /*module_name=*/"intrp", + [](Function func) { /* no-op */ }); + return {lowered_module.main_module, lowered_module.per_target_module}; +} + +/*! \brief Check if an expression could be changed by \p Prepare. + * + * If not we can evaluate it directly and don't need to bind it into a fresh module. + */ +class NeedsPreparationVisitor : public ExprVisitor { + public: + bool needs_preparation = false; + + private: + void VisitExpr_(const VarNode* vn) override { + // Could be prim. + needs_preparation = true; + } + // ConstantNode ok + // GlobalVarNode ok + void VisitExpr_(const OpNode* op) override { + // Could be prim. + needs_preparation = true; + } + // TupleNode recurse + void VisitExpr_(const FunctionNode* op) override { + // Could be prim. + needs_preparation = true; + } + // CallNode recurse + void VisitExpr_(const LetNode* ln) override { + // May bind prim. + needs_preparation = true; + } + // IfNode recurse + // TupleGetItemNode recurse + // RefCreateNode recurse + // RefReadNode recurse + // RefWriteNode recurse + // ConstructorNode ok + void VisitExpr_(const MatchNode* op) override { + // Needs eta-expansion. + needs_preparation = true; + } +}; - transform::PassContext pass_ctx = transform::PassContext::Current(); - tvm::With ctx(pass_ctx); - mod = seq(mod); +TypedPackedFunc)> EvalFunction(IRModule mod, Expr expr, Device device, + Target target) { + // + // Step 1: Prepare mod. + // + + // If expr is simple enough we can avoid binding it into the module and + // just eval it directly. + NeedsPreparationVisitor visitor; + visitor.VisitExpr(expr); + + Expr expr_to_eval; + IRModule mod_with_expr; // default empty + if (visitor.needs_preparation) { + GlobalVar main; + // Bind expr to a new zero-argument function so it can be prepared along with the module + // (if any). + std::pair mod_and_global; + if (mod.defined()) { + // TODO(mbs): Type inference currently assumes all global functions in modules have + // known result types, and so each global function has it's body types inferred independently + // and in arbitrary order. However, the interpreter may be called with an expression relative + // to a 'main' which has no result type annotation, and that expressions will be bound into a + // fresh global below. Type inference then fails since 'main' has unknown type. We should + // allow inference on mutually recursive global functions. To workaround, infer the type + // of mod now. Obviously that won't work if 'main' itself calls other global functions of + // partial type, but it at least maintains legacy behavior. + transform::PassContext pass_ctx = transform::PassContext::Current(); + With ctx(pass_ctx); + mod = transform::InferType()(mod); + mod_and_global = + IRModule::FromExprInContext(expr, mod->functions, mod->type_definitions, mod->Imports()); + } else { + mod_and_global = IRModule::FromExprInContext(expr); + } + mod_with_expr = mod_and_global.first; + expr_to_eval = mod_and_global.second; + } else { + if (mod.defined()) { + mod_with_expr = mod; + } + // Prepare won't change expr, so we don't need to worry about binding it into a module + // and can just eval it directly. + expr_to_eval = expr; + } + std::pair> main_and_lowered = + Prepare(mod_with_expr, device, target); + std::shared_ptr intrp = std::make_shared( + /*mod=*/main_and_lowered.first, /*per_target_module=*/main_and_lowered.second, device, + target); + + // + // Step 2: Evaluate target function to a closure. + // + ObjectRef object_ref = intrp->Eval(expr_to_eval); + if (const InterpreterClosureObj* closure_obj = object_ref.as()) { + InterpreterClosure closure = GetRef(closure_obj); + ICHECK(closure.defined()); + ICHECK(closure->func.defined()); + + return TypedPackedFunc)>([intrp, closure](Array args) { + // + // Step 3: Apply closure to arguments. + // + ICHECK_NOTNULL(intrp); + ICHECK(closure.defined()); + ICHECK(closure->func.defined()); + Array evaled_args; + for (auto arg : args) { + NeedsPreparationVisitor visitor; + visitor.VisitExpr(arg); + ICHECK(!visitor.needs_preparation) + << "attempting to apply closure to expression which needs preparation: " + << PrettyPrint(arg); + evaled_args.push_back(intrp->Eval(arg)); + } + return intrp->Invoke(closure, evaled_args); + }); + } else { + LOG(FATAL) << "expecting expression to have function type and evaluate to a closure"; + return nullptr; } +} - auto intrp = std::make_shared(mod, device, target); - auto packed = [intrp](Expr expr) { - auto f = DetectFeature(expr); - ICHECK(f.is_subset_of(FeatureSet::All() - fGraph)); - return intrp->Eval(expr); - }; - return TypedPackedFunc(packed); +ObjectRef Eval(Expr expr, Map type_definitions, + std::unordered_set import_set, Device device, Target target) { + std::pair mod_and_global = + IRModule::FromExprInContext(expr, /*global_funcs=*/{}, type_definitions, import_set); + std::pair> main_and_lowered = + Prepare(mod_and_global.first, device, target); + Interpreter intrp( + /*mod=*/main_and_lowered.first, /*per_target_module=*/main_and_lowered.second, device, + target); + Expr expr_to_eval = main_and_lowered.first->GetGlobalVar(mod_and_global.second->name_hint); + if (expr.as() == nullptr) { + // TODO(mbs): IRModule::FromExpr will implicitly close over the free vars of expr + // unless it is a function, so we must reverse that in the expression to eval. + // This should done more systematically. + expr_to_eval = Call(expr_to_eval, {}); + } + return intrp.Eval(expr_to_eval); } -TVM_REGISTER_GLOBAL("relay.backend.CreateInterpreter").set_body_typed(CreateInterpreter); +TVM_REGISTER_GLOBAL("relay.backend.EvalFunction").set_body_typed(EvalFunction); } // namespace relay } // namespace tvm diff --git a/src/relay/backend/te_compiler.cc b/src/relay/backend/te_compiler.cc index 7840960ec268..93fcf73b17a2 100644 --- a/src/relay/backend/te_compiler.cc +++ b/src/relay/backend/te_compiler.cc @@ -20,17 +20,15 @@ #include "te_compiler.h" #include -#include +#include #include #include #include #include #include #include -#include #include #include -#include #include #include #include @@ -43,8 +41,6 @@ #include #include -#include "../transforms/pass_utils.h" -#include "te_compiler.h" #include "te_compiler_cache.h" #include "utils.h" @@ -101,6 +97,18 @@ class TECompilerImpl : public TECompilerNode { lowered_functions[target->str()]->Update(lowered_func->cached_func->funcs); } + + for (const auto& it : shape_func_cache_) { + auto source_func = it.first; + auto lowered_func = it.second; + auto target = source_func->target; + + if (!lowered_functions.count(target->str())) { + lowered_functions.Set(target->str(), IRModule(Map({}))); + } + + lowered_functions[target->str()]->Update(lowered_func->cached_func->funcs); + } return lowered_functions; } @@ -304,44 +312,94 @@ std::tuple IsDeviceCopy(const Function& func) { return std::tuple(false, -1, -1); } -class LowerTensorExpr : public ExprMutator { +/*! + * \brief Rewrites call expressions to Relay functions marked as 'primitive' + * to calls to the corresponding TIR primitive for the appropriate target. + * + * \code + * let %p = fn(...) { prim_op(...) } + * ... %p(...) ... + * ==> + * (in target-specific module) def @p' = fn (...) { } + * let %p = fn(...) { prim_op(...) } + * ... @p'(...) ... + * \endcode + * + * Requires FuseOps, ToANormalForm, EtaExpand and InferType to have run. + * + * FuseOps is needed to identify and lift all prim op calls: + * \code + * ... prim_op(...) ... + * ==> + * let %p = fn(...) { prim_op(...) } + * ... %p(...) ... + * \endcode + * + * ToANormalForm is needed so we only need to consider vars as the call target. + * (However we'll also allow function literals.) + * + * EtaExpand is needed to ensures all calls to primitives are direct: + * \code + * let %p1 = fn(...) { prim_op1(...) } + * let %p2 = fn(...) { prim_op2(...) } + * let %p = if (...) { %p1 } else { %p2 } + * ... %p(...) ... + * ==> + * let %p1 = fn(...) { prim_op1(...) } + * let %p2 = fn(...) { prim_op2(...) } + * let %p = fn(...) { if (...) { %p1(...) } else { %p2(...) } } + * ... %p(...) ... + * \endcode + */ +class LowerTensorExprMutator : public ExprMutator { public: - LowerTensorExpr(const IRModule& module, const TargetMap& targets, const DeviceMap& device_ctx_map, - ProcessFn process_fn, const String& module_name, TECompiler compiler) + LowerTensorExprMutator(const IRModule& module, const TargetMap& targets, + const DeviceMap& device_ctx_map, ProcessFn process_fn, + const String& module_name, TECompiler compiler) : module_(module), targets_(targets), device_context_map_(device_ctx_map), - process_fn(process_fn), + process_fn_(process_fn), module_name_(module_name), - compiler_(compiler) {} + compiler_(compiler), + debug_op_(Op::Get("debug")) {} - Expr VisitExpr_(const CallNode* call) override { - Call expr = GetRef(call); - Function func; - - if (call->op.as()) { - func = GetRef(call->op.as()); - } else { - return ExprMutator::VisitExpr_(call); - } - - if (!func->HasNonzeroAttr(attr::kPrimitive)) { - // Provide a callback hook which allows one-level up code generators to - // act when we process a function. - this->process_fn(func); - return ExprMutator::VisitExpr_(call); - } - - // Process inputs. - Array args; - for (size_t i = 0; i < expr->args.size(); i++) { - args.push_back(VisitExpr(expr->args[i])); + /*! + * \brief Returns the primitive function associated with \p expr, or + * nullptr if none. + */ + Function ResolveToPrimitive(Expr expr) { + if (const GlobalVarNode* gvn = expr.as()) { + BaseFunc base_func = module_->Lookup(GetRef(gvn)); + return ResolveToPrimitive(base_func); + } else if (const VarNode* vn = expr.as()) { + auto itr = primitive_functions_.find(GetRef(vn)); + return itr == primitive_functions_.end() ? Function() : itr->second; + } else if (const FunctionNode* fn = expr.as()) { + if (!fn->HasNonzeroAttr(attr::kPrimitive)) { + // Not marked as primitive by FuseOps. + return Function(); + } + if (const CallNode* cn = fn->body.as()) { + if (cn->op == debug_op_) { + // Debug 'primitives' are not lowered. + return Function(); + } + } + return GetRef(fn); } + return Function(); + } - Target target; - + /*! + * \brief Lowers the primitive function \p func to TIR for ultimate execution + * on a device with configuration \p target. Returns the global var bound + * to the TIR implementation, and attributes to attach to the call to identify it as + * a TIR call. + */ + std::pair LowerFunction(Function func, Target target) { if (func->GetAttr(attr::kCompiler).defined()) { - target = Target("ext_dev"); + // BYOC flow. CCacheKey key = CCacheKey(func, target); CachedFunc ext_func = compiler_->Lower(key, module_name_); ICHECK(ext_func.defined()) << "Lowering returned undefined function for " @@ -351,43 +409,44 @@ class LowerTensorExpr : public ExprMutator { relay::Function func_with_metadata = func; func_with_metadata = WithAttr(func_with_metadata, "prim_fn_var", ext_func->prim_fn_var); func_with_metadata = WithAttr(func_with_metadata, "prim_funcs", prim_fns); - func_with_metadata = WithAttr(func_with_metadata, "target", ext_func->target); + func_with_metadata = WithAttr(func_with_metadata, tvm::attr::kTarget, ext_func->target); // Provide a callback hook which allows one-level up code generators to // act when we process a function. - this->process_fn(func_with_metadata); + this->process_fn_(func_with_metadata); - auto ret_call = Call(ext_func->prim_fn_var, args, {}); - return std::move(ret_call); + // TODO(mbs): Need TIRCallAttrs or equiv so targets know this is an extern. + // TODO(mbs): Dynamic shapes? + return {ext_func->prim_fn_var, Attrs()}; } - ICHECK_GE(device_context_map_.count(expr), 0) - << "Could not find an entry in the device context map for " << PrettyPrint(expr) - << "The memory planning was either not performed for this precise node, or there is bug " - "in the memory planner."; - - auto& device_context = this->device_context_map_[expr]; - target = GetTargetFromInteger(device_context.device_type, targets_); // Non-External Relay Function + DLOG(INFO) << "lowering to target '" << target->str() << "' for primitive:\n" + << PrettyPrint(func); CCacheKey key = CCacheKey(func, target); CachedFunc lowered_func = compiler_->Lower(key, module_name_); + DLOG(INFO) << "lowered primitive bound to '" << PrettyPrint(lowered_func->prim_fn_var) << "'"; + // Collect all the lowered functions produced for this primitive function. Map prim_fns; - + Array all_prim_fn_vars; for (auto prim_fn : lowered_func->funcs->functions) { CHECK(prim_fn.second.as()) << "must be a prim fn"; prim_fns.Set(prim_fn.first, Downcast(prim_fn.second)); + all_prim_fn_vars.push_back(prim_fn.first); + DLOG(INFO) << "lowered primitive includes bindings for '" << PrettyPrint(prim_fn.first) + << "'"; } // TODO(@areusch, @jroesch): this metadata is for AOT, this should be our interface for AOT relay::Function func_with_metadata = func; func_with_metadata = WithAttr(func_with_metadata, "prim_fn_var", lowered_func->prim_fn_var); func_with_metadata = WithAttr(func_with_metadata, "prim_funcs", prim_fns); - func_with_metadata = WithAttr(func_with_metadata, "target", lowered_func->target); + func_with_metadata = WithAttr(func_with_metadata, tvm::attr::kTarget, lowered_func->target); // Provide a callback hook which allows one-level up code generators to // act when we process a function. - this->process_fn(func_with_metadata); + this->process_fn_(func_with_metadata); auto tir_call_attrs = make_object(); if (func->HasNonzeroAttr(attr::kReshapeOnly)) { @@ -403,27 +462,137 @@ class LowerTensorExpr : public ExprMutator { } tir_call_attrs->metadata.Set("relay_attrs", func->attrs); + tir_call_attrs->metadata.Set("all_prim_fn_vars", all_prim_fn_vars); + + if (IsDynamic(func->ret_type)) { + // Also lower the dynamic shape function. + // Shape function keys use the underlying primitive function as their 'function', + // but the generic 'cpu' target as the target since all shape functions run + // on the host cpu irrespective of where the primitive runs. + // TODO(mbs): Cleanup target handling. + Target shape_target("llvm"); + DLOG(INFO) << "lowering to target '" << shape_target->str() + << "' for dynamic shape function for primitive"; + CCacheKey shape_key(func, shape_target); + CachedFunc lowered_shape_func = compiler_->LowerShapeFunc(shape_key); + // Capture the shape function's global var and parameters 'states' in call + // annotations so calling convention can be recovered. + // TODO(mbs): Capture all this as part of a 'call into TIR' construct once available. + // The way the shape function calling convention is derived and passed to call sites + // via the 'parameter states' could be improved. + tir_call_attrs->metadata.Set("prim_shape_fn_var", lowered_shape_func->prim_fn_var); + tir_call_attrs->metadata.Set("prim_shape_fn_states", + lowered_shape_func->shape_func_param_states); + tir_call_attrs->metadata.Set("prim_shape_fn_num_inputs", + Integer(static_cast(lowered_shape_func->inputs.size()))); + tir_call_attrs->metadata.Set("prim_shape_fn_num_outputs", + Integer(static_cast(lowered_shape_func->outputs.size()))); + Array all_prim_shape_fn_vars; + for (auto prim_shape_fn : lowered_shape_func->funcs->functions) { + CHECK(prim_shape_fn.second.as()) << "must be a prim fn"; + all_prim_shape_fn_vars.push_back(prim_shape_fn.first); + } + tir_call_attrs->metadata.Set("all_prim_shape_fn_vars", all_prim_shape_fn_vars); + } - Expr ret_call = Call(lowered_func->prim_fn_var, args, Attrs(tir_call_attrs)); - return std::move(ret_call); + return {lowered_func->prim_fn_var, Attrs(tir_call_attrs)}; + } + + Expr VisitExpr_(const LetNode* let) override { + Var var = Downcast(Mutate(let->var)); + Expr value = Mutate(let->value); + Function prim_func = ResolveToPrimitive(value); + if (prim_func.defined()) { + // Remember let var is bound to (possibly indirectly) to a primitive. + primitive_functions_.emplace(let->var, prim_func); + } + Expr body = Mutate(let->body); + if (prim_func.defined()) { + // Leaving let var scope. + primitive_functions_.erase(let->var); + } + if (var.same_as(let->var) && value.same_as(let->value) && body.same_as(let->body)) { + return GetRef(let); + } else { + return Let(var, value, body, let->span); + } + } + + Expr VisitExpr_(const CallNode* call) override { + Call expr = GetRef(call); + + // Look for (indirect) calls to primitives. + Function prim_func = ResolveToPrimitive(call->op); + if (!prim_func.defined()) { + // Not a call to a primitive function. + if (const FunctionNode* fn = call->op.as()) { + this->process_fn_(GetRef(fn)); + } + return ExprMutator::VisitExpr_(call); + } + + // Find the desired target device. + Target target; + if (prim_func->GetAttr(attr::kCompiler).defined()) { + // The generic 'external device' target. + target = Target("ext_dev"); + } else if (device_context_map_.empty() && targets_.size() == 1) { + // The unique target. + target = GetTargetFromInteger(kDLCPU, targets_); + } else { + // The target corresponding to the call expression's annotation. + auto itr = device_context_map_.find(expr); + ICHECK(itr != device_context_map_.end()) + << "Could not find an entry in the device context map for " << expr + << "The memory planning was either not performed for this precise node, or there is " + "bug in the memory planner."; + target = GetTargetFromInteger(itr->second.device_type, targets_); + } + + // Lower the primitive function for that target. + std::pair pair = LowerFunction(prim_func, target); + + // Similarly transform arguments. + Array args; + for (const auto& arg : call->args) { + args.push_back(VisitExpr(arg)); + } + + // Replace with direct call to lowered primitive, and attach annotations to record calling + // convention. + return Call(pair.first, args, pair.second); } IRModule module_; TargetMap targets_; DeviceMap device_context_map_; - ProcessFn process_fn; + ProcessFn process_fn_; + // Map from in-scope let-bound variables to Relay functions known to be + // primitive. We'll rewrite these to the fresh global vars bound to the lowered + // primitive function as we go. Those vars will be bound in the + // target device-type specific module we'll ultimately emit for each required + // device-type. Note that a primitive may be lowered for multiple device + // types, each which will be assigned a fresh var. + std::unordered_map + primitive_functions_; String module_name_; TECompiler compiler_; + // Cache ops that need to be frequently used later to reduce lookup overhead. + const Op& debug_op_; }; -/*! - * \brief Obtain the Target from the device type. - * If homogenous compilation, this will return the only target. - * If heteregenous compilation, this will select associated using the targets_ Map. - * - * \param dev_type - * \return Target - */ +Pass LowerTensorExpr(TargetMap targets, DeviceMap device_context_map, + backend::StaticMemoryPlan memory_plan, const String& module_name, + TECompiler compiler, std::function process_fn) { + runtime::TypedPackedFunc pass_func = + [=](Function func, IRModule module, PassContext ctx) { + LowerTensorExprMutator lower_te(module, targets, device_context_map, process_fn, + module_name, compiler); + return Downcast(lower_te.Mutate(func)); + }; + return CreateFunctionPass(pass_func, 0, "LowerTensorExpr", {}); +} + Target GetTargetFromInteger(DLDeviceType dev_type, TargetMap targets) { if (targets.size() == 1) { // The homogeneous execution case, return the only target. @@ -610,7 +779,7 @@ void UpdateFunctionMetadata(Function relay_func, Optional prim_fn_var = relay_func->GetAttr("prim_fn_var"); CHECK(prim_fn_var) << "prim_fn_var must be set on Relay functions by TECompiler."; - Optional relay_target = relay_func->GetAttr("target"); + Optional relay_target = relay_func->GetAttr(tvm::attr::kTarget); CHECK(relay_target) << "target must be set on Relay functions by the TECompiler."; for (const auto& kv : prim_fns.value()) { @@ -624,8 +793,8 @@ void UpdateFunctionMetadata(Function relay_func, // Workspace sizes Target prim_fn_target; - if (prim_fn->attrs->dict.count("target")) { - prim_fn_target = Downcast(prim_fn->attrs->dict["target"]); + if (prim_fn->attrs->dict.count(tvm::attr::kTarget)) { + prim_fn_target = Downcast(prim_fn->attrs->dict[tvm::attr::kTarget]); } else { prim_fn_target = relay_target.value(); } @@ -661,27 +830,24 @@ void UpdateFunctionMetadata(Function relay_func, function_metadata.Set(prim_fn_var.value()->name_hint, fi); } +// TODO(mbs): Make this an IRModule->IRModule pass by folding LoweredModule back into IRModule. +// Currently we rely on accumulating bindings inside the local TECompiler which we then +// host into the LoweredModule result. LoweredModule LowerTE(const IRModule& module, TargetMap targets, DeviceMap device_context_map, backend::StaticMemoryPlan memory_plan, const String& module_name, std::function process_fn) { - TECompiler compiler; - - CHECK_EQ(module->functions.size(), 1) - << "There should only be one function in the module passed to LowerTE"; + DLOG(INFO) << "lowering module:\n" << PrettyPrint(module); - auto pass = CreateFunctionPass( - [=](Function func, IRModule module, PassContext ctx) { - LowerTensorExpr lower_te(module, targets, device_context_map, process_fn, module_name, - compiler); - return Downcast(lower_te.VisitExpr(func)); - }, - 0, "LowerTensorExpr", {}); + TECompiler compiler; - // TODO(@electriclilies, @jroesch): remove UpdateMainWorkspaceSize - backend::FunctionInfo func_info = - UpdateMainWorkspaceSize(module, targets, memory_plan->expr_to_storage_info); + backend::FunctionInfo func_info; + if (memory_plan.defined()) { + // TODO(@electriclilies, @jroesch): remove UpdateMainWorkspaceSize + func_info = UpdateMainWorkspaceSize(module, targets, memory_plan->expr_to_storage_info); + } - auto updated_module = pass(module); + auto updated_module = LowerTensorExpr(targets, device_context_map, memory_plan, module_name, + compiler, process_fn)(module); // A temporary solution until we can rewrite the auto-scheduler task extraction code to work // in a more reasonable way. diff --git a/src/relay/backend/te_compiler_cache.cc b/src/relay/backend/te_compiler_cache.cc index 7cae7fcd4b09..d0e83765928a 100644 --- a/src/relay/backend/te_compiler_cache.cc +++ b/src/relay/backend/te_compiler_cache.cc @@ -41,6 +41,7 @@ #include #include +#include "../op/memory/memory.h" #include "../transforms/pass_utils.h" #include "utils.h" @@ -120,21 +121,10 @@ class ScheduleBuilder : public backend::MemoizedExprTranslator Array fn_inputs; for (Var param : prim_func->params) { Array inputs; - if (const auto* ttype = param->checked_type().as()) { + for (const auto& ttype : FlattenTupleType(param->checked_type())) { tvm::te::Tensor tensor = tvm::te::placeholder(GetShape(ttype->shape), ttype->dtype); fn_inputs.push_back(tensor); inputs.push_back(tensor); - } else { - // flatten tuple of tensor type. - const auto* tuple_type = param->type_as(); - for (Type field : tuple_type->fields) { - const auto* ttype = field.as(); - // TODO(@icemelon): Allow recursive tuple - ICHECK(ttype != nullptr); - tvm::te::Tensor tensor = tvm::te::placeholder(GetShape(ttype->shape), ttype->dtype); - fn_inputs.push_back(tensor); - inputs.push_back(tensor); - } } memo_[param] = inputs; } @@ -314,6 +304,7 @@ class ScheduleBuilder : public backend::MemoizedExprTranslator Array VisitExpr_(const TupleNode* op) final { Array fields; for (Expr field : op->fields) { + // TODO(mbs): Generalize to be equivalent to FlattenTupleType. ICHECK(field->checked_type().as()) << "Only allow Tuple of Tensor"; Array res = VisitExpr(field); ICHECK_EQ(res.size(), 1); @@ -372,7 +363,7 @@ class MakeShapeFunc : public backend::MemoizedExprTranslator> Array data_inputs; Array shape_inputs; - auto add_placeholder = [&data_inputs, &shape_inputs](const TensorTypeNode* ttype) { + for (const auto& ttype : FlattenTupleType(param->checked_type())) { // Add data placeholder Shape shape = GetShape(ttype->shape); tvm::te::Tensor data_tensor = tvm::te::placeholder(shape, ttype->dtype); @@ -385,20 +376,6 @@ class MakeShapeFunc : public backend::MemoizedExprTranslator> } tvm::te::Tensor shape_tensor = tvm::te::placeholder(sshape, DataType::Int(64)); shape_inputs.push_back(shape_tensor); - }; - - if (const auto* ttype = param->checked_type().as()) { - add_placeholder(ttype); - } else { - // flatten tuple of tensor type. - const auto* tuple_type = param->type_as(); - // TODO(@icemelon): Support recursive tuple - ICHECK(tuple_type); - for (Type field : tuple_type->fields) { - const auto* ttype = field.as(); - ICHECK(ttype); - add_placeholder(ttype); - } } param_data_[param] = data_inputs; param_shapes_[param] = shape_inputs; diff --git a/src/relay/backend/te_compiler_cache.h b/src/relay/backend/te_compiler_cache.h index 1c7511ffd7d2..47ba96b2c77e 100644 --- a/src/relay/backend/te_compiler_cache.h +++ b/src/relay/backend/te_compiler_cache.h @@ -213,6 +213,7 @@ CachedFunc PrimFuncFor(const Function& source_func, const Target& target, CachedFunc ShapeFuncFor(const Function& prim_func, const Target& target, std::function renamer); +// TODO(mbs): Bring name uniqification under control -- this is replicated in quite a few places. std::string GetUniqueName(std::string name, std::unordered_map* name_map); // implementations diff --git a/src/relay/transforms/fold_constant.cc b/src/relay/transforms/fold_constant.cc index 57603035b848..d545518c1c3c 100644 --- a/src/relay/transforms/fold_constant.cc +++ b/src/relay/transforms/fold_constant.cc @@ -229,35 +229,16 @@ class ConstantFolder : public MixedModeMutator { } // Constant evaluate an expression. Expr ConstEvaluate(Expr expr) { - std::vector passes = {transform::FuseOps(0), transform::ToANormalForm(), - transform::InferType()}; - Function func; - if (expr.as()) { - func = Downcast(expr); - } else { - // TODO(@jroesch): fix this - func = Function(FreeVars(expr), expr, Type(), FreeTypeVars(expr, module_), {}); - } - auto mod = IRModule({}, module_->type_definitions, module_->Imports()); - auto global = GlobalVar("main"); - mod->Add(global, func); - auto seq = transform::Sequential(passes); - mod = seq(mod); - auto entry_func = Downcast(mod->Lookup("main")); - expr = expr.as() == nullptr ? entry_func->body : entry_func; - - using tvm::transform::PassContext; Device dev; dev.device_type = kDLCPU; dev.device_id = 0; Target target = Target("llvm"); - // use a fresh build context - // in case we are already in a build context. + + // use a fresh build context in case we are already in a build context. // needed for both execution and creation(due to JIT) - With fresh_build_ctx(PassContext::Create()); + With fresh_build_ctx(transform::PassContext::Create()); - FInterpreter executor = CreateInterpreter(mod, dev, target); - return ObjectToExpr(executor(expr)); + return ObjectToExpr(Eval(expr, module_->type_definitions, module_->Imports(), dev, target)); } // Evaluate a call to the shape_of operator for tensors with constant diff --git a/src/relay/transforms/partial_eval.cc b/src/relay/transforms/partial_eval.cc index 9572faf08714..ccdd9c92cc27 100644 --- a/src/relay/transforms/partial_eval.cc +++ b/src/relay/transforms/partial_eval.cc @@ -526,6 +526,8 @@ bool StatefulOp(const Expr& e) { using FInterpreter = runtime::TypedPackedFunc; +Target CPUTarget() { return Target("llvm"); } + Device CPUDevice() { Device dev; dev.device_type = kDLCPU; @@ -533,17 +535,6 @@ Device CPUDevice() { return dev; } -FInterpreter CPUInterpreter() { - using tvm::transform::PassContext; - - Target target = Target("llvm"); - // use a fresh build context - // in case we are already in a build context. - With fresh_build_ctx(PassContext::Create()); - - return CreateInterpreter(IRModule(nullptr), CPUDevice(), target); -} - using FuncId = int; /*! @@ -904,13 +895,9 @@ class PartialEvaluator : public ExprFunctor // Constant evaluate an expression. PStatic ConstEvaluate(const Expr& expr, LetList* ll) { - std::vector passes = {transform::FuseOps(0), transform::InferType()}; - auto mod = IRModule::FromExpr(expr); - auto seq = transform::Sequential(passes); - mod = seq(mod); - auto entry_func = Downcast(mod->Lookup("main")); - auto fused_infered = expr.as() == nullptr ? entry_func->body : entry_func; - return Reify(executor_(fused_infered), ll); + // use a fresh build context in case we are already in a build context. + With fresh_build_ctx(transform::PassContext::Create()); + return Reify(Eval(expr, mod_->type_definitions, mod_->Imports(), CPUDevice(), CPUTarget()), ll); } Func ConstEvaluateFunc(const Expr& expr) { @@ -1137,7 +1124,6 @@ class PartialEvaluator : public ExprFunctor std::unordered_map fuel_map_; Store store_; Device device_ = CPUDevice(); - FInterpreter executor_ = CPUInterpreter(); }; /*! \brief Remap multiple Var sharing the same Id into the same Var. */ diff --git a/src/runtime/cuda/cuda_device_api.cc b/src/runtime/cuda/cuda_device_api.cc index 47f038b5c612..33a87c9a2be2 100644 --- a/src/runtime/cuda/cuda_device_api.cc +++ b/src/runtime/cuda/cuda_device_api.cc @@ -112,9 +112,14 @@ class CUDADeviceAPI final : public DeviceAPI { ICHECK_EQ(256 % alignment, 0U) << "CUDA space is aligned at 256 bytes"; void* ret; if (dev.device_type == kDLCUDAHost) { + DLOG(INFO) << "allocating " << nbytes << "bytes on host"; CUDA_CALL(cudaMallocHost(&ret, nbytes)); } else { CUDA_CALL(cudaSetDevice(dev.device_id)); + size_t free_mem, total_mem; + CUDA_CALL(cudaMemGetInfo(&free_mem, &total_mem)); + DLOG(INFO) << "allocating " << nbytes << " bytes on device, with " << free_mem + << " bytes currently free out of " << total_mem << " bytes available"; CUDA_CALL(cudaMalloc(&ret, nbytes)); } return ret; @@ -122,9 +127,11 @@ class CUDADeviceAPI final : public DeviceAPI { void FreeDataSpace(Device dev, void* ptr) final { if (dev.device_type == kDLCUDAHost) { + DLOG(INFO) << "freeing host memory"; CUDA_CALL(cudaFreeHost(ptr)); } else { CUDA_CALL(cudaSetDevice(dev.device_id)); + DLOG(INFO) << "freeing device memory"; CUDA_CALL(cudaFree(ptr)); } } @@ -280,5 +287,16 @@ TVM_REGISTER_GLOBAL("profiling.timer.gpu").set_body_typed([](Device dev) { return Timer(make_object()); }); +TVM_DLL String GetCudaFreeMemory() { + size_t free_mem, total_mem; + CUDA_CALL(cudaMemGetInfo(&free_mem, &total_mem)); + std::stringstream ss; + ss << "Current CUDA memory is " << free_mem << " bytes free out of " << total_mem + << " bytes on device"; + return ss.str(); +} + +TVM_REGISTER_GLOBAL("runtime.GetCudaFreeMemory").set_body_typed(GetCudaFreeMemory); + } // namespace runtime } // namespace tvm diff --git a/src/tir/analysis/verify_memory.cc b/src/tir/analysis/verify_memory.cc index 2089ead98168..0382b8071de7 100644 --- a/src/tir/analysis/verify_memory.cc +++ b/src/tir/analysis/verify_memory.cc @@ -172,6 +172,9 @@ std::vector VerifyMemory_(const PrimFunc& func) { auto target = func->GetAttr(tvm::attr::kTarget); ICHECK(target.defined()) << "VerifyMemory: Require the target attribute"; + DLOG(INFO) << "verifying memory for target '" << target.value()->str() << "' for primitive\n" + << PrettyPrint(func); + if (func->GetAttr(tvm::attr::kCallingConv, Integer(CallingConv::kDefault)) == CallingConv::kDefault) { MemoryAccessVerifier v(func, target.value()->kind->device_type); diff --git a/tests/crt/aot_executor_test.cc b/tests/crt/aot_executor_test.cc index ded6729d138b..e8afa133f42d 100644 --- a/tests/crt/aot_executor_test.cc +++ b/tests/crt/aot_executor_test.cc @@ -75,7 +75,7 @@ TEST(AOTRuntime, Identity) { void* outputs[] = {outputs1}; ASSERT_EQ(kTvmErrorNoError, tvm_runtime_run(&identity_model, inputs, outputs)); - ASSERT_EQ(outputs1[0], 404); + ASSERT_EQ(outputs1[0], 404U); } int32_t add_run_func(TVMValue* args, int* arg_type_ids, int32_t num_args, TVMValue* out_ret_value, @@ -103,7 +103,7 @@ TEST(AOTRuntime, Add) { void* outputs[] = {outputs1}; ASSERT_EQ(kTvmErrorNoError, tvm_runtime_run(&add_model, inputs, outputs)); - ASSERT_EQ(outputs1[0], 904); + ASSERT_EQ(outputs1[0], 904U); } int32_t multiple_inputs_run_func(TVMValue* args, int* arg_type_ids, int32_t num_args, @@ -135,7 +135,7 @@ TEST(AOTRuntime, MultipleInputs) { void* outputs[] = {outputs1}; ASSERT_EQ(kTvmErrorNoError, tvm_runtime_run(&multiple_inputs_model, inputs, outputs)); - ASSERT_EQ(outputs1[0], 1306); + ASSERT_EQ(outputs1[0], 1306U); } int32_t multiple_outputs_run_func(TVMValue* args, int* arg_type_ids, int32_t num_args, @@ -167,8 +167,8 @@ TEST(AOTRuntime, MultipleOutputs) { void* outputs[] = {outputs1, outputs2}; ASSERT_EQ(kTvmErrorNoError, tvm_runtime_run(&multiple_outputs_model, inputs, outputs)); - ASSERT_EQ(outputs1[0], 404); - ASSERT_EQ(outputs2[0], 500); + ASSERT_EQ(outputs1[0], 404U); + ASSERT_EQ(outputs2[0], 500U); } int main(int argc, char** argv) { diff --git a/tests/crt/framing_test.cc b/tests/crt/framing_test.cc index 5ee226dc5ee7..0161afe31696 100644 --- a/tests/crt/framing_test.cc +++ b/tests/crt/framing_test.cc @@ -174,7 +174,7 @@ TEST_F(UnframerTest, PacketTooLong) { EXPECT_EQ(write_stream_.capacity(), bytes_consumed); EXPECT_EQ(kTvmErrorNoError, unframer_.Write((uint8_t*)&crc, sizeof(crc), &bytes_consumed)); - EXPECT_EQ(2, bytes_consumed); // 2, because framer is now in kFindPacketStart. + EXPECT_EQ(2UL, bytes_consumed); // 2, because framer is now in kFindPacketStart. EXPECT_FALSE(write_stream_.packet_done()); EXPECT_FALSE(write_stream_.is_valid()); EXPECT_EQ(std::string((char*)long_payload, write_stream_.capacity()), @@ -210,7 +210,7 @@ TEST_P(UnframerTestParameterized, TestByteAtATime) { EXPECT_EQ(kTvmErrorNoError, unframer_.Write(reinterpret_cast(&GetParam()->wire[i]), 1, &bytes_consumed)); - EXPECT_EQ(1, bytes_consumed); + EXPECT_EQ(1UL, bytes_consumed); EXPECT_EQ(i == wire_size - 1, write_stream_.packet_done()); } EXPECT_TRUE(write_stream_.is_valid()); @@ -247,7 +247,7 @@ TEST_P(UnframerTestParameterized, TestArbitraryPacketReset) { unframer_.Reset(); write_stream_.Reset(); EXPECT_EQ(kTvmErrorNoError, unframer_.Write(GetParam()->wire_data(), 1, &bytes_consumed)); - EXPECT_EQ(1, bytes_consumed); + EXPECT_EQ(1UL, bytes_consumed); EXPECT_EQ(kTvmErrorNoError, unframer_.Write(GetParam()->wire_data(), wire_size, &bytes_consumed)); EXPECT_EQ(wire_size, bytes_consumed); EXPECT_TRUE(write_stream_.packet_done()); @@ -265,13 +265,13 @@ TEST_P(UnframerTestParameterized, TestArbitraryPacketReset) { // Interrupt the packet transmission. The first byte will return no error as it is the escape // byte. EXPECT_EQ(kTvmErrorNoError, unframer_.Write(GetParam()->wire_data(), 1, &bytes_consumed)); - EXPECT_EQ(1, bytes_consumed); + EXPECT_EQ(1UL, bytes_consumed); EXPECT_FALSE(write_stream_.packet_done()); // Secondt byte will return a short packet error. EXPECT_EQ(kTvmErrorFramingShortPacket, unframer_.Write(&GetParam()->wire_data()[1], 1, &bytes_consumed)); - EXPECT_EQ(0, bytes_consumed); + EXPECT_EQ(0UL, bytes_consumed); EXPECT_FALSE(write_stream_.packet_done()); EXPECT_EQ(kTvmErrorNoError, @@ -291,7 +291,7 @@ TEST_P(UnframerTestParameterized, TestArbitraryPacketReset) { // the internal state. EXPECT_EQ(kTvmErrorFramingShortPacket, unframer_.Write(GetParam()->wire_data(), wire_size, &bytes_consumed)); - EXPECT_EQ(1, bytes_consumed); + EXPECT_EQ(1UL, bytes_consumed); EXPECT_FALSE(write_stream_.packet_done()); EXPECT_EQ(kTvmErrorNoError, unframer_.Write(&GetParam()->wire_data()[1], wire_size - 1, &bytes_consumed)); diff --git a/tests/crt/memory_test.cc b/tests/crt/memory_test.cc index b531383058e6..d8465c8e8743 100644 --- a/tests/crt/memory_test.cc +++ b/tests/crt/memory_test.cc @@ -70,7 +70,7 @@ TEST_F(MemoryManagerTest, AllocFreeFifo) { } else { EXPECT_PAGE(kNumUsablePages - 1 - idx, a); } - EXPECT_EQ(interface->vleak_size, idx + 1); + EXPECT_EQ(static_cast(interface->vleak_size), idx + 1); ptrs[idx] = a; } diff --git a/tests/crt/session_test.cc b/tests/crt/session_test.cc index 9840f55dc685..14de3089fb34 100644 --- a/tests/crt/session_test.cc +++ b/tests/crt/session_test.cc @@ -158,7 +158,7 @@ TEST_F(SessionTest, NormalExchange) { bob_.WriteTo(&alice_); EXPECT_TRUE(alice_.sess.IsEstablished()); - ASSERT_EQ(alice_.messages_received.size(), 1); + ASSERT_EQ(alice_.messages_received.size(), 1UL); EXPECT_EQ(alice_.messages_received[0], ReceivedMessage(MessageType::kStartSessionReply, "")); alice_.ClearBuffers(); @@ -167,7 +167,7 @@ TEST_F(SessionTest, NormalExchange) { "\xFF\xFD\b\0\0\0\x82" "f\x10hello\x90("); alice_.WriteTo(&bob_); - ASSERT_EQ(bob_.messages_received.size(), 2); + ASSERT_EQ(bob_.messages_received.size(), 2UL); EXPECT_EQ(bob_.messages_received[0], ReceivedMessage(MessageType::kStartSessionReply, "")); EXPECT_EQ(bob_.messages_received[1], ReceivedMessage(MessageType::kNormal, "hello")); @@ -177,7 +177,7 @@ TEST_F(SessionTest, NormalExchange) { "\xff\xfd\b\0\0\0\x82" "f\x10ollehLv"); bob_.WriteTo(&alice_); - ASSERT_EQ(alice_.messages_received.size(), 1); + ASSERT_EQ(alice_.messages_received.size(), 1UL); EXPECT_EQ(alice_.messages_received[0], ReceivedMessage(MessageType::kNormal, "olleh")); alice_.ClearBuffers(); @@ -186,13 +186,13 @@ TEST_F(SessionTest, NormalExchange) { alice_.sess.SendMessage(MessageType::kLog, reinterpret_cast("log1"), 4); EXPECT_FRAMED_PACKET(alice_, "\xff\xfd\a\0\0\0\0\0\x03log1\xf0\xd4"); alice_.WriteTo(&bob_); - ASSERT_EQ(bob_.messages_received.size(), 1); + ASSERT_EQ(bob_.messages_received.size(), 1UL); EXPECT_EQ(bob_.messages_received[0], ReceivedMessage(MessageType::kLog, "log1")); bob_.sess.SendMessage(MessageType::kLog, reinterpret_cast("zero"), 4); EXPECT_FRAMED_PACKET(bob_, "\xff\xfd\a\0\0\0\0\0\x03zero\xb2h"); bob_.WriteTo(&alice_); - ASSERT_EQ(alice_.messages_received.size(), 1); + ASSERT_EQ(alice_.messages_received.size(), 1UL); EXPECT_EQ(alice_.messages_received[0], ReceivedMessage(MessageType::kLog, "zero")); } @@ -200,13 +200,13 @@ TEST_F(SessionTest, LogBeforeSessionStart) { alice_.sess.SendMessage(MessageType::kLog, reinterpret_cast("log1"), 4); EXPECT_FRAMED_PACKET(alice_, "\xfe\xff\xfd\a\0\0\0\0\0\x03log1\xf0\xd4"); alice_.WriteTo(&bob_); - ASSERT_EQ(bob_.messages_received.size(), 1); + ASSERT_EQ(bob_.messages_received.size(), 1UL); EXPECT_EQ(bob_.messages_received[0], ReceivedMessage(MessageType::kLog, "log1")); bob_.sess.SendMessage(MessageType::kLog, reinterpret_cast("zero"), 4); EXPECT_FRAMED_PACKET(bob_, "\xfe\xff\xfd\a\0\0\0\0\0\x03zero\xb2h"); bob_.WriteTo(&alice_); - ASSERT_EQ(alice_.messages_received.size(), 1); + ASSERT_EQ(alice_.messages_received.size(), 1UL); EXPECT_EQ(alice_.messages_received[0], ReceivedMessage(MessageType::kLog, "zero")); } diff --git a/tests/python/contrib/test_onnx.py b/tests/python/contrib/test_onnx.py index 1c2d00aed866..121edc4b8c60 100644 --- a/tests/python/contrib/test_onnx.py +++ b/tests/python/contrib/test_onnx.py @@ -50,8 +50,9 @@ def run_onnx(onnx_model, input_data): def run_relay(func, data_tuple): target = "llvm" dev = tvm.device("llvm", 0) - intrp = relay.create_executor("graph", device=dev, target=target) - relay_res = intrp.evaluate(func)(*data_tuple) + relay_res = relay.create_executor("graph", device=dev, target=target).evaluate(func)( + *data_tuple + ) result = [] relay_res = relay_res if isinstance(relay_res, list) else [relay_res] diff --git a/tests/python/contrib/test_onnx_model.py b/tests/python/contrib/test_onnx_model.py index 84cff57d1d94..075085ff8806 100644 --- a/tests/python/contrib/test_onnx_model.py +++ b/tests/python/contrib/test_onnx_model.py @@ -59,9 +59,12 @@ def get_data(in_data_shapes, dtype="float32"): def run_relay(mod, params, in_data): target = "llvm" dev = tvm.device("llvm", 0) - intrp = relay.create_executor("graph", mod, device=dev, target=target) in_data = [tvm.nd.array(value) for value in in_data.values()] - return intrp.evaluate()(*in_data, **params).numpy() + return ( + relay.create_executor("graph", mod, device=dev, target=target) + .evaluate()(*in_data, **params) + .numpy() + ) def _verify_results(mod, params, in_data): diff --git a/tests/python/contrib/test_tensorrt.py b/tests/python/contrib/test_tensorrt.py index 082ded704faa..f40b3368dc85 100644 --- a/tests/python/contrib/test_tensorrt.py +++ b/tests/python/contrib/test_tensorrt.py @@ -110,12 +110,16 @@ def run_and_verify_func(config, target="cuda"): with tvm.transform.PassContext( opt_level=3, config={"relay.ext.tensorrt.options": config} ): - exec = relay.create_executor(mode, mod=mod, device=dev, target=target) + func = relay.create_executor( + mode, mod=mod, device=dev, target=target + ).evaluate() else: with tvm.transform.PassContext(opt_level=3): - exec = relay.create_executor(mode, mod=mod, device=dev, target=target) + func = relay.create_executor( + mode, mod=mod, device=dev, target=target + ).evaluate() if not skip_runtime_test(): - result_dict[result_key] = exec.evaluate()(**input_dict, **params) + result_dict[result_key] = func(**input_dict, **params) if not skip_runtime_test(): assert_result_dict_holds(result_dict) @@ -143,12 +147,16 @@ def compile_and_run(mod, params, i_data, mode="vm", use_trt=True): with tvm.transform.PassContext( opt_level=3, config={"relay.ext.tensorrt.options": config} ): - exec = relay.create_executor(mode, mod=mod, device=tvm.cuda(0), target="cuda") + func = relay.create_executor( + mode, mod=mod, device=tvm.cuda(0), target="cuda" + ).evaluate() else: with tvm.transform.PassContext(opt_level=3): - exec = relay.create_executor(mode, mod=mod, device=tvm.cuda(0), target="cuda") + func = relay.create_executor( + mode, mod=mod, device=tvm.cuda(0), target="cuda" + ).evaluate() - res = exec.evaluate()(i_data, **params) if not skip_runtime_test() else None + res = func(i_data, **params) if not skip_runtime_test() else None return res dtype = "float32" @@ -198,16 +206,16 @@ def test_tensorrt_simple(): with tvm.transform.PassContext( opt_level=3, config={"relay.ext.tensorrt.options": config} ): - relay_exec = relay.create_executor( + func = relay.create_executor( mode, mod=mod, device=tvm.cuda(0), target="cuda" - ) + ).evaluate() else: with tvm.transform.PassContext(opt_level=3): - relay_exec = relay.create_executor( + func = relay.create_executor( mode, mod=mod, device=tvm.cuda(0), target="cuda" - ) + ).evaluate() if not skip_runtime_test(): - result_dict[result_key] = relay_exec.evaluate()(x_data, y_data, z_data) + result_dict[result_key] = func(x_data, y_data, z_data) if not skip_runtime_test(): assert_result_dict_holds(result_dict) @@ -247,9 +255,11 @@ def test_tensorrt_not_compatible(): mod, config = tensorrt.partition_for_tensorrt(mod) for mode in ["graph", "vm"]: with tvm.transform.PassContext(opt_level=3, config={"relay.ext.tensorrt.options": config}): - exec = relay.create_executor(mode, mod=mod, device=tvm.cuda(0), target="cuda") + func = relay.create_executor( + mode, mod=mod, device=tvm.cuda(0), target="cuda" + ).evaluate() if not skip_runtime_test(): - results = exec.evaluate()(x_data) + results = func(x_data) def test_tensorrt_serialize_graph_executor(): @@ -741,12 +751,12 @@ def test_run(x_data_list, x_shape, new_shape, should_offload_to_trt): assert are_ops_on_trt(mod, op_list=["reshape"]) == should_offload_to_trt if not skip_runtime_test(): with relay.build_config(opt_level=3): - relay_exec = relay.create_executor( + func = relay.create_executor( "vm", mod=mod, device=tvm.cpu(0), target="llvm" - ) + ).evaluate() for i, x_data in enumerate(x_data_list): - result_arr[i][use_trt] = relay_exec.evaluate()(x_data) + result_arr[i][use_trt] = func(x_data) if not skip_runtime_test(): for i in range(len(x_data_list)): @@ -1244,10 +1254,11 @@ def test_tensorrt_dynamic_batch(): if not skip_runtime_test(): with relay.build_config(opt_level=3): - relay_exec = relay.create_executor("vm", mod=mod, device=tvm.cpu(0), target="llvm") - + func = relay.create_executor( + "vm", mod=mod, device=tvm.cpu(0), target="llvm" + ).evaluate() for i, batch_size in enumerate(batches_to_test): - result_arr[i][use_trt] = relay_exec.evaluate()(x_data[:batch_size, ...]) + result_arr[i][use_trt] = func(x_data[:batch_size, ...]) if not skip_runtime_test(): for i in range(len(batches_to_test)): @@ -1280,13 +1291,11 @@ def test_tensorrt_dynamic_batch_conv(): with tvm.transform.PassContext( opt_level=3, config={"relay.ext.tensorrt.options": config} ): - relay_exec = relay.create_executor( + func = relay.create_executor( "vm", mod=mod, device=tvm.device(target), target=target - ) + ).evaluate() for i, batch_size in enumerate(batches_to_test): - result_arr[i][target][use_trt] = relay_exec.evaluate()( - x_data[:batch_size, ...], **params - ) + result_arr[i][target][use_trt] = func(x_data[:batch_size, ...], **params) if not skip_runtime_test(): for i in range(len(batches_to_test)): for target in ["llvm", "cuda"]: @@ -1434,9 +1443,11 @@ def test_empty_subgraph(): x_data = np.random.uniform(-1, 1, x_shape).astype("float32") for mode in ["graph", "vm"]: with tvm.transform.PassContext(opt_level=3): - exec = relay.create_executor(mode, mod=mod, device=tvm.cuda(0), target="cuda") + func = relay.create_executor( + mode, mod=mod, device=tvm.cuda(0), target="cuda" + ).evaluate() if not skip_runtime_test(): - results = exec.evaluate()(x_data) + results = func(x_data) if __name__ == "__main__": diff --git a/tests/python/contrib/test_vitis_ai/test_vitis_ai_runtime_cpu_part.py b/tests/python/contrib/test_vitis_ai/test_vitis_ai_runtime_cpu_part.py index db9552c8eab2..f414d7d71fcc 100644 --- a/tests/python/contrib/test_vitis_ai/test_vitis_ai_runtime_cpu_part.py +++ b/tests/python/contrib/test_vitis_ai/test_vitis_ai_runtime_cpu_part.py @@ -59,10 +59,12 @@ def test_extern_vitis_ai_resnet18(): mod, params = relay.testing.resnet.get_workload(num_layers=18, batch_size=1) ref_mod, params = relay.testing.resnet.get_workload(num_layers=18, batch_size=1) - ref_ex = relay.create_executor("graph", mod=ref_mod, device=tvm.cpu(0)) i_data = np.random.uniform(0, 1, ishape).astype(dtype) - ref_res = ref_ex.evaluate()(i_data, **params) + ref_res = relay.create_executor("graph", mod=ref_mod, device=tvm.cpu(0)).evaluate()( + i_data, **params + ) + verify_result( mod, {"data": i_data}, diff --git a/tests/python/frontend/mxnet/test_forward.py b/tests/python/frontend/mxnet/test_forward.py index 1b9ae38075d7..44aa93061a62 100644 --- a/tests/python/frontend/mxnet/test_forward.py +++ b/tests/python/frontend/mxnet/test_forward.py @@ -333,8 +333,9 @@ def test_forward_where(): mod, _ = relay.frontend.from_mxnet(mx_sym, shapes, args, auxs) for target, dev in tvm.testing.enabled_targets(): for kind in ["graph", "debug"]: - intrp = relay.create_executor(kind, mod=mod, device=dev, target=target) - op_res = intrp.evaluate()(np_cond, np_x, np_y) + op_res = relay.create_executor(kind, mod=mod, device=dev, target=target).evaluate()( + np_cond, np_x, np_y + ) tvm.testing.assert_allclose(op_res.numpy(), mx_out) @@ -357,8 +358,9 @@ def verify(start, stop, step): mod, _ = relay.frontend.from_mxnet(mx_sym, {}) for target, dev in tvm.testing.enabled_targets(): for kind in ["graph", "debug"]: - intrp = relay.create_executor(kind, mod=mod, device=dev, target=target) - op_res = intrp.evaluate()() + op_res = relay.create_executor( + kind, mod=mod, device=dev, target=target + ).evaluate()() tvm.testing.assert_allclose(op_res.numpy(), ref_res.asnumpy()) verify(0, 20, None) @@ -416,8 +418,9 @@ def test_forward_broadcast_ops(): mod, _ = relay.frontend.from_mxnet(mx_sym, shapes, dtype) for target, dev in tvm.testing.enabled_targets(): for kind in ["graph", "debug"]: - intrp = relay.create_executor(kind, mod=mod, device=dev, target=target) - op_res = intrp.evaluate()(a_np, b_np) + op_res = relay.create_executor(kind, mod=mod, device=dev, target=target).evaluate()( + a_np, b_np + ) tvm.testing.assert_allclose(op_res.numpy(), ref_res.asnumpy()) @@ -451,8 +454,9 @@ def test_forward_elemwise_ops(): mod, _ = relay.frontend.from_mxnet(mx_sym, shapes, dtype) for target, dev in tvm.testing.enabled_targets(): for kind in ["graph", "debug"]: - intrp = relay.create_executor(kind, mod=mod, device=dev, target=target) - op_res = intrp.evaluate()(a_np, b_np) + op_res = relay.create_executor(kind, mod=mod, device=dev, target=target).evaluate()( + a_np, b_np + ) tvm.testing.assert_allclose(op_res.numpy(), ref_res.asnumpy()) @@ -500,8 +504,9 @@ def test_forward_unary_ops(): mod, _ = relay.frontend.from_mxnet(mx_sym, shapes, dtype) for target, dev in tvm.testing.enabled_targets(): for kind in ["graph", "debug"]: - intrp = relay.create_executor(kind, mod=mod, device=dev, target=target) - op_res = intrp.evaluate()(a_np) + op_res = relay.create_executor(kind, mod=mod, device=dev, target=target).evaluate()( + a_np + ) tvm.testing.assert_allclose(op_res.numpy(), ref_res.asnumpy(), rtol=1e-5, atol=1e-5) @@ -530,8 +535,9 @@ def test_forward_scalar_ops(): mod, _ = relay.frontend.from_mxnet(mx_sym, shapes, dtype) for target, dev in tvm.testing.enabled_targets(): for kind in ["graph", "debug"]: - intrp = relay.create_executor(kind, mod=mod, device=dev, target=target) - op_res = intrp.evaluate()(a_np) + op_res = relay.create_executor(kind, mod=mod, device=dev, target=target).evaluate()( + a_np + ) tvm.testing.assert_allclose(op_res.numpy(), ref_res.asnumpy()) for op in ["maximum", "minimum"]: dtype = "float32" @@ -544,8 +550,9 @@ def test_forward_scalar_ops(): mod, _ = relay.frontend.from_mxnet(mx_sym, shapes, dtype) for target, dev in tvm.testing.enabled_targets(): for kind in ["graph", "debug"]: - intrp = relay.create_executor(kind, mod=mod, device=dev, target=target) - op_res = intrp.evaluate()(a_np) + op_res = relay.create_executor(kind, mod=mod, device=dev, target=target).evaluate()( + a_np + ) tvm.testing.assert_allclose(op_res.numpy(), ref_res.asnumpy()) @@ -558,8 +565,9 @@ def verify(shape, axis, begin, end): mod, _ = relay.frontend.from_mxnet(mx_sym, {"data": shape}) for target, dev in tvm.testing.enabled_targets(): for kind in ["graph", "debug"]: - intrp = relay.create_executor(kind, mod=mod, device=dev, target=target) - op_res = intrp.evaluate()(data_np) + op_res = relay.create_executor(kind, mod=mod, device=dev, target=target).evaluate()( + data_np + ) tvm.testing.assert_allclose(op_res.numpy(), ref_res.asnumpy()) verify((3, 4), 0, 1, 2) @@ -583,8 +591,9 @@ def verify(x_shape, y_shape, axes): mod, _ = relay.frontend.from_mxnet(mx_sym, {"x": x_shape, "y": y_shape}) for target, dev in tvm.testing.enabled_targets(): for kind in ["graph", "debug"]: - intrp = relay.create_executor(kind, mod=mod, device=dev, target=target) - op_res = intrp.evaluate()(x_np, y_np) + op_res = relay.create_executor(kind, mod=mod, device=dev, target=target).evaluate()( + x_np, y_np + ) tvm.testing.assert_allclose(op_res.numpy(), ref_res.asnumpy()) verify((3, 4), (2, 3), None) @@ -617,8 +626,9 @@ def verify(shape, seq_lengths, use_seq_lengths, seq_axis): for target, dev in tvm.testing.enabled_targets(): for kind in ["graph", "debug"]: - intrp = relay.create_executor(kind, mod=mod, device=dev, target=target) - op_res = intrp.evaluate()(*in_data) + op_res = relay.create_executor(kind, mod=mod, device=dev, target=target).evaluate()( + *in_data + ) tvm.testing.assert_allclose(op_res.numpy(), ref_res.asnumpy()) verify((3, 4), [1, 2, 3, 1], True, 0) @@ -653,8 +663,9 @@ def test_forward_logistic_regression_output(): mod, _ = relay.frontend.from_mxnet(mx_sym, shapes, dtype) for target, dev in tvm.testing.enabled_targets(): for kind in ["graph", "debug"]: - intrp = relay.create_executor(kind, mod=mod, device=dev, target=target) - op_res = intrp.evaluate()(data_np) + op_res = relay.create_executor(kind, mod=mod, device=dev, target=target).evaluate()( + data_np + ) tvm.testing.assert_allclose(op_res.numpy(), ref_res.asnumpy()) @@ -670,8 +681,9 @@ def verify(a_shape, b_shape, transpose_b=False): mod, _ = relay.frontend.from_mxnet(mx_sym, shapes, dtype) for target, dev in tvm.testing.enabled_targets(): for kind in ["graph", "debug"]: - intrp = relay.create_executor(kind, mod=mod, device=dev, target=target) - op_res = intrp.evaluate()(a_np, b_np) + op_res = relay.create_executor(kind, mod=mod, device=dev, target=target).evaluate()( + a_np, b_np + ) tvm.testing.assert_allclose( op_res.numpy(), ref_res.asnumpy(), rtol=1e-05, atol=1e-05 ) @@ -689,8 +701,9 @@ def verify(shape): mod, _ = relay.frontend.from_mxnet(mx_sym, {"x": shape}) for target, dev in tvm.testing.enabled_targets(): for kind in ["debug"]: - intrp = relay.create_executor(kind, mod=mod, device=dev, target=target) - op_res = intrp.evaluate()(x_np) + op_res = relay.create_executor(kind, mod=mod, device=dev, target=target).evaluate()( + x_np + ) tvm.testing.assert_allclose(op_res.numpy(), ref_res.asnumpy()) verify((1,)) @@ -711,8 +724,9 @@ def verify(shape, axis): mod, _ = relay.frontend.from_mxnet(mx_sym, {"x": shape}) for target, dev in tvm.testing.enabled_targets(): for kind in ["graph", "debug"]: - intrp = relay.create_executor(kind, mod=mod, device=dev, target=target) - op_res = intrp.evaluate()(x_np) + op_res = relay.create_executor(kind, mod=mod, device=dev, target=target).evaluate()( + x_np + ) tvm.testing.assert_allclose(op_res.numpy(), ref_res.asnumpy()) verify((1, 3, 1), None) @@ -731,8 +745,9 @@ def verify(shape, axis, size): mod, _ = relay.frontend.from_mxnet(mx_sym, {"x": shape}) for target, dev in tvm.testing.enabled_targets(): for kind in ["graph", "debug"]: - intrp = relay.create_executor(kind, mod=mod, device=dev, target=target) - op_res = intrp.evaluate()(x_np) + op_res = relay.create_executor( + kind, mod=mod, device=dev, target=target + ).evaluate()(x_np) tvm.testing.assert_allclose(op_res.numpy(), ref_res.asnumpy()) verify((1, 2, 1), 2, 3) @@ -748,8 +763,9 @@ def verify(input_shape, shape): mod, _ = relay.frontend.from_mxnet(mx_sym, {"x": input_shape}) for target, dev in tvm.testing.enabled_targets(): for kind in ["graph", "debug"]: - intrp = relay.create_executor(kind, mod=mod, device=dev, target=target) - op_res = intrp.evaluate()(x_np) + op_res = relay.create_executor(kind, mod=mod, device=dev, target=target).evaluate()( + x_np + ) tvm.testing.assert_allclose(op_res.numpy(), ref_res.asnumpy()) verify((1, 2, 3), (3, 2, 3)) @@ -766,8 +782,9 @@ def verify(input_shape, like_shape): mod, _ = relay.frontend.from_mxnet(mx_sym, {"x": input_shape, "y": like_shape}) for target, dev in tvm.testing.enabled_targets(): for kind in ["graph", "debug"]: - intrp = relay.create_executor(kind, mod=mod, device=dev, target=target) - op_res = intrp.evaluate()(x_np, y_np) + op_res = relay.create_executor(kind, mod=mod, device=dev, target=target).evaluate()( + x_np, y_np + ) tvm.testing.assert_allclose(op_res.numpy(), ref_res.asnumpy()) verify((1, 2, 3), (3, 2, 3)) @@ -785,8 +802,9 @@ def test_forward_logical_not(): mod, _ = relay.frontend.from_mxnet(mx_sym, shapes, dtype) for target, dev in tvm.testing.enabled_targets(): for kind in ["graph", "debug"]: - intrp = relay.create_executor(kind, mod=mod, device=dev, target=target) - op_res = intrp.evaluate()(a_np) + op_res = relay.create_executor(kind, mod=mod, device=dev, target=target).evaluate()( + a_np + ) tvm.testing.assert_allclose(op_res.numpy(), ref_res.asnumpy()) @@ -801,8 +819,9 @@ def verify(val, shape, dtype): # Skip testing graph executor because this op will be optimized out # by constant folding. for kind in ["debug"]: - intrp = relay.create_executor(kind, mod=mod, device=dev, target=target) - op_res = intrp.evaluate()() + op_res = relay.create_executor( + kind, mod=mod, device=dev, target=target + ).evaluate()() tvm.testing.assert_allclose(op_res.numpy(), ref_res.asnumpy()) verify(2, (3, 4), "float32") @@ -825,8 +844,9 @@ def verify(data_shape, weight_shape): mod, _ = relay.frontend.from_mxnet(mx_sym, {"x": data_shape, "w": weight_shape}) for target, dev in tvm.testing.enabled_targets(): for kind in ["graph", "debug"]: - intrp = relay.create_executor(kind, mod=mod, device=dev, target=target) - op_res = intrp.evaluate()(x=x_np, w=w_np) + op_res = relay.create_executor(kind, mod=mod, device=dev, target=target).evaluate()( + x=x_np, w=w_np + ) tvm.testing.assert_allclose(op_res.numpy(), ref_res.asnumpy()) verify((2, 2), (4, 5)) @@ -852,8 +872,9 @@ def verify(shape, indices_src, axis, mode="clip"): mod, _ = relay.frontend.from_mxnet(mx_sym, {"x": shape, "y": indices_np.shape}) for target, dev in tvm.testing.enabled_targets(): for kind in ["graph", "debug"]: - intrp = relay.create_executor(kind, mod=mod, device=dev, target=target) - op_res = intrp.evaluate()(x_np, indices_np) + op_res = relay.create_executor(kind, mod=mod, device=dev, target=target).evaluate()( + x_np, indices_np + ) tvm.testing.assert_allclose(op_res.numpy(), ref_res.asnumpy()) verify((2, 2), [[[1, 0], [0, 1]]], 0) @@ -876,8 +897,9 @@ def verify(xshape, yshape, y_data, error=False): ) for target, dev in tvm.testing.enabled_targets(): for kind in ["graph", "debug"]: - intrp = relay.create_executor(kind, mod=mod, device=dev, target=target) - op_res = intrp.evaluate()(x_data, y_data) + op_res = relay.create_executor(kind, mod=mod, device=dev, target=target).evaluate()( + x_data, y_data + ) tvm.testing.assert_allclose(op_res.numpy(), ref_res.asnumpy()) verify((2, 2), (2, 3), [[1, 1, 0], [0, 1, 0]]) @@ -905,8 +927,9 @@ def verify(shape, transform_type, target_shape): mod, _ = relay.frontend.from_mxnet(mx_sym, shape_dict) for target, dev in tvm.testing.enabled_targets(): for kind in ["graph", "debug"]: - intrp = relay.create_executor(kind, mod=mod, device=dev, target=target) - op_res = intrp.evaluate()(x) + op_res = relay.create_executor(kind, mod=mod, device=dev, target=target).evaluate()( + x + ) tvm.testing.assert_allclose(op_res.numpy(), ref_res.asnumpy(), rtol=1e-5, atol=1e-5) verify((4, 6), "affine", (16, 32)) @@ -925,8 +948,9 @@ def verify(data_shape, grid_shape): mod, _ = relay.frontend.from_mxnet(mx_sym, shape_dict) for target, dev in tvm.testing.enabled_targets(): for kind in ["graph", "debug"]: - intrp = relay.create_executor(kind, mod=mod, device=dev, target=target) - op_res = intrp.evaluate()(data, grid) + op_res = relay.create_executor(kind, mod=mod, device=dev, target=target).evaluate()( + data, grid + ) tvm.testing.assert_allclose(op_res.numpy(), ref_res.asnumpy(), rtol=1e-5, atol=1e-5) verify((4, 4, 16, 32), (4, 2, 8, 8)) @@ -988,8 +1012,9 @@ def verify( for target, dev in tvm.testing.enabled_targets(): # only test graph executor because debug runtime is too slow for kind in ["graph"]: - intrp = relay.create_executor(kind, mod=mod, device=dev, target=target) - op_res = intrp.evaluate()(**inputs, **params) + op_res = relay.create_executor(kind, mod=mod, device=dev, target=target).evaluate()( + **inputs, **params + ) if init_states: assert len(op_res) == len(mx_res) for i, val in enumerate(op_res): @@ -1022,11 +1047,11 @@ def verify(xshape, yshape, offset=None): mod, _ = relay.frontend.from_mxnet(mx_sym, {"x": xshape, "y": yshape}) for target, dev in tvm.testing.enabled_targets(): for kind in ["graph", "debug"]: - intrp = relay.create_executor(kind, mod=mod, device=dev, target=target) + func = relay.create_executor(kind, mod=mod, device=dev, target=target).evaluate() if offset is None or offset == (0, 0): - op_res = intrp.evaluate()(x_data, y_data) + op_res = func(x_data, y_data) else: - op_res = intrp.evaluate()(x_data) + op_res = func(x_data) tvm.testing.assert_allclose(op_res.numpy(), ref_res.asnumpy()) verify((1, 3, 40, 40), (1, 3, 20, 20)) @@ -1045,8 +1070,9 @@ def verify(shape, axis, is_ascend, dtype="float32"): mod, _ = relay.frontend.from_mxnet(mx_sym, {"x": shape}) for target, dev in tvm.testing.enabled_targets(): for kind in ["graph", "debug"]: - intrp = relay.create_executor(kind, mod=mod, device=dev, target=target) - op_res = intrp.evaluate()(x_np) + op_res = relay.create_executor(kind, mod=mod, device=dev, target=target).evaluate()( + x_np + ) tvm.testing.assert_allclose(op_res.numpy(), ref_res.asnumpy()) verify((2, 3, 4), axis=0, is_ascend=False) @@ -1076,8 +1102,9 @@ def verify(shape, k, axis, ret_type, is_ascend=None, dtype="float32"): mod, _ = relay.frontend.from_mxnet(mx_sym, {"x": shape}) for target, dev in tvm.testing.enabled_targets(): for kind in ["graph", "debug"]: - intrp = relay.create_executor(kind, mod=mod, device=dev, target=target) - op_res = intrp.evaluate()(x_np) + op_res = relay.create_executor(kind, mod=mod, device=dev, target=target).evaluate()( + x_np + ) if isinstance(ref_res, list): assert len(op_res) == len(ref_res) for i, t in enumerate(op_res): @@ -1133,11 +1160,11 @@ def verify(shape, use_sequence_length, value, axis, dtype, itype): if use_sequence_length is False and kind == "graph": # Disable the test for 'graph' when it's identity. continue - intrp = relay.create_executor(kind, mod=mod, device=dev, target=target) + func = relay.create_executor(kind, mod=mod, device=dev, target=target).evaluate() if use_sequence_length: - op_res = intrp.evaluate()(data_np, valid_length_np) + op_res = func(data_np, valid_length_np) else: - op_res = intrp.evaluate()(data_np) + op_res = func(data_np) tvm.testing.assert_allclose(op_res.numpy(), ref_res.asnumpy()) verify((5, 10), True, 0.0, 0, "float32", "float32") @@ -1155,8 +1182,9 @@ def verify(shape): mod, _ = relay.frontend.from_mxnet(mx_sym, {"x": shape}) for target, dev in tvm.testing.enabled_targets(): for kind in ["graph", "debug"]: - intrp = relay.create_executor(kind, mod=mod, device=dev, target=target) - op_res = intrp.evaluate()(x_np) + op_res = relay.create_executor(kind, mod=mod, device=dev, target=target).evaluate()( + x_np + ) tvm.testing.assert_allclose(op_res.numpy(), ref_res.asnumpy()) verify((3, 4)) @@ -1203,8 +1231,9 @@ def verify(shape, axis=1, fix_gamma=False): # print(mod) for target, dev in tvm.testing.enabled_targets(): for kind in ["graph", "debug"]: - intrp = relay.create_executor(kind, mod=mod, device=dev, target=target) - op_res = intrp.evaluate()(x, gamma, beta, moving_mean, moving_var) + op_res = relay.create_executor(kind, mod=mod, device=dev, target=target).evaluate()( + x, gamma, beta, moving_mean, moving_var + ) tvm.testing.assert_allclose(op_res.numpy(), ref_res.asnumpy(), rtol=1e-3) verify((2, 3, 4, 5)) @@ -1227,8 +1256,9 @@ def verify(shape, axis=1, epsilon=1e-5): mod, _ = relay.frontend.from_mxnet(mx_sym, shape_dict) for target, dev in tvm.testing.enabled_targets(): for kind in ["graph", "debug"]: - intrp = relay.create_executor(kind, mod=mod, device=dev, target=target) - op_res = intrp.evaluate()(x, gamma, beta) + op_res = relay.create_executor(kind, mod=mod, device=dev, target=target).evaluate()( + x, gamma, beta + ) tvm.testing.assert_allclose(op_res.numpy(), ref_res.asnumpy(), rtol=2e-5, atol=1e-5) verify((2, 3, 4, 5)) @@ -1251,8 +1281,9 @@ def verify(shape, axis=-1): mod, _ = relay.frontend.from_mxnet(mx_sym, shape_dict) for target, dev in tvm.testing.enabled_targets(): for kind in ["graph", "debug"]: - intrp = relay.create_executor(kind, mod=mod, device=dev, target=target) - op_res = intrp.evaluate()(x, gamma, beta) + op_res = relay.create_executor(kind, mod=mod, device=dev, target=target).evaluate()( + x, gamma, beta + ) tvm.testing.assert_allclose(op_res.numpy(), ref_res.asnumpy(), rtol=1e-3, atol=1e-5) verify((2, 5)) @@ -1279,8 +1310,9 @@ def verify(shape, num_groups=1): mod, _ = relay.frontend.from_mxnet(mx_sym, shape_dict) for target, dev in tvm.testing.enabled_targets(): for kind in ["graph", "debug"]: - intrp = relay.create_executor(kind, mod=mod, device=dev, target=target) - op_res = intrp.evaluate()(x, gamma, beta) + op_res = relay.create_executor(kind, mod=mod, device=dev, target=target).evaluate()( + x, gamma, beta + ) tvm.testing.assert_allclose(op_res.numpy(), ref_res.asnumpy(), rtol=1e-3, atol=1e-5) verify((1, 4, 2), num_groups=4) @@ -1300,8 +1332,9 @@ def verify(indices_shape, depth, on_value, off_value, dtype): mod, _ = relay.frontend.from_mxnet(mx_sym, shape_dict) for target, dev in tvm.testing.enabled_targets(): for kind in ["graph", "debug"]: - intrp = relay.create_executor(kind, mod=mod, device=dev, target=target) - op_res = intrp.evaluate()(x.astype("float32")) + op_res = relay.create_executor(kind, mod=mod, device=dev, target=target).evaluate()( + x.astype("float32") + ) tvm.testing.assert_allclose(op_res.numpy(), ref_res.asnumpy(), rtol=1e-3, atol=1e-5) verify((3,), 3, 1, 0, "int32") @@ -1426,8 +1459,9 @@ def verify(data_shape, kernel_size, stride, pad, num_filter, is_depthwise=False) mod, _ = relay.frontend.from_mxnet(mx_sym, shape_dict) for target, dev in tvm.testing.enabled_targets(): for kind in ["graph", "debug"]: - intrp = relay.create_executor(kind, mod=mod, device=dev, target=target) - op_res = intrp.evaluate()(x, weight, bias) + op_res = relay.create_executor(kind, mod=mod, device=dev, target=target).evaluate()( + x, weight, bias + ) tvm.testing.assert_allclose(op_res.numpy(), ref_res.asnumpy(), rtol=1e-3) verify(data_shape=(1, 1, 1024 * 16), kernel_size=(17,), stride=(2,), pad=(8,), num_filter=4) @@ -1507,8 +1541,9 @@ def verify(data_shape, kernel_size, stride, pad, num_filter): mod, _ = relay.frontend.from_mxnet(mx_sym, shape_dict) for target, dev in tvm.testing.enabled_targets(): for kind in ["graph", "debug"]: - intrp = relay.create_executor(kind, mod=mod, device=dev, target=target) - op_res = intrp.evaluate()(x, weight, bias) + op_res = relay.create_executor(kind, mod=mod, device=dev, target=target).evaluate()( + x, weight, bias + ) tvm.testing.assert_allclose(op_res.numpy(), ref_res.asnumpy(), rtol=1e-3, atol=1e-5) verify(data_shape=(1, 1, 1024 * 16), kernel_size=(17,), stride=(2,), pad=(8,), num_filter=4) @@ -1540,8 +1575,9 @@ def verify(a_np, b_np): mod, _ = relay.frontend.from_mxnet(mx_sym, shape_dict) for target, dev in tvm.testing.enabled_targets(): for kind in ["debug", "vm"]: - intrp = relay.create_executor(kind, mod=mod, device=dev, target=target) - op_res = intrp.evaluate()(a_np, b_np) + op_res = relay.create_executor(kind, mod=mod, device=dev, target=target).evaluate()( + a_np, b_np + ) tvm.testing.assert_allclose(op_res.numpy(), ref_res.asnumpy(), rtol=1e-3) verify(np.asarray([1.0], "float32"), np.asarray([2.0], "float32")) @@ -1559,8 +1595,9 @@ def verify(from_dtype, to_dtype): mod, _ = relay.frontend.from_mxnet(mx_sym, shape_dict, dtype_dict) for target, dev in tvm.testing.enabled_targets(): for kind in ["graph", "vm", "debug"]: - intrp = relay.create_executor(kind, mod=mod, device=dev, target=target) - op_res = intrp.evaluate()(from_np) + op_res = relay.create_executor(kind, mod=mod, device=dev, target=target).evaluate()( + from_np + ) assert op_res.dtype == to_dtype, op_res.dtype tvm.testing.assert_allclose(op_res.numpy(), from_np.astype(to_dtype)) @@ -1582,8 +1619,9 @@ def verify(dtypes, cast_narrow, expected_dtype): mod, _ = relay.frontend.from_mxnet(mx_sym, shape_dict, dtype_dict) for target, dev in tvm.testing.enabled_targets(): for kind in ["graph", "vm", "debug"]: - intrp = relay.create_executor(kind, mod=mod, device=dev, target=target) - op_res = intrp.evaluate()(*x_nps) + op_res = relay.create_executor(kind, mod=mod, device=dev, target=target).evaluate()( + *x_nps + ) for i, res in enumerate(op_res): assert res.dtype == expected_dtype, res.dtype tvm.testing.assert_allclose(res.numpy(), x_nps[i].astype(expected_dtype)) @@ -1607,8 +1645,9 @@ def verify(x, shape, dtype): for target, dev in tvm.testing.enabled_targets(): for kind in ["graph", "vm", "debug"]: - intrp = relay.create_executor(kind, mod=mod, device=dev, target=target) - op_res = intrp.evaluate()(a_np) + op_res = relay.create_executor(kind, mod=mod, device=dev, target=target).evaluate()( + a_np + ) tvm.testing.assert_allclose(op_res.numpy(), ref_res.asnumpy()) for dtype in ["int32", "int64"]: @@ -1648,8 +1687,9 @@ def verify(shape, blocksize=2): mod, _ = relay.frontend.from_mxnet(mx_sym, shape_dict) for target, dev in tvm.testing.enabled_targets(): for kind in ["graph", "debug"]: - intrp = relay.create_executor(kind, mod=mod, device=dev, target=target) - op_res = intrp.evaluate()(x) + op_res = relay.create_executor(kind, mod=mod, device=dev, target=target).evaluate()( + x + ) tvm.testing.assert_allclose(op_res.numpy(), ref_res.asnumpy(), rtol=1e-3, atol=1e-5) verify((1, 18, 3, 3), 3) @@ -1667,8 +1707,9 @@ def verify(shape, blocksize=2): mod, _ = relay.frontend.from_mxnet(mx_sym, shape_dict) for target, dev in tvm.testing.enabled_targets(): for kind in ["graph", "debug"]: - intrp = relay.create_executor(kind, mod=mod, device=dev, target=target) - op_res = intrp.evaluate()(x) + op_res = relay.create_executor(kind, mod=mod, device=dev, target=target).evaluate()( + x + ) tvm.testing.assert_allclose(op_res.numpy(), ref_res.asnumpy(), rtol=1e-3, atol=1e-5) verify((1, 1, 9, 9), 3) @@ -1703,8 +1744,9 @@ def verify(data_shape, kernel_size, max_displacement, stride1, stride2, pad_size mod, _ = relay.frontend.from_mxnet(mx_sym, shape_dict) for target, dev in tvm.testing.enabled_targets(): for kind in ["graph", "debug"]: - intrp = relay.create_executor(kind, mod=mod, device=dev, target=target) - op_res = intrp.evaluate()(data1, data2) + op_res = relay.create_executor(kind, mod=mod, device=dev, target=target).evaluate()( + data1, data2 + ) tvm.testing.assert_allclose(op_res.numpy(), ref_res.asnumpy(), rtol=1e-3, atol=1e-5) verify( @@ -1808,8 +1850,9 @@ def verify(data_shape, start=None, step=None, axis=None): mod, _ = relay.frontend.from_mxnet(mx_sym, {"data": data_shape}) for target, dev in tvm.testing.enabled_targets(): for kind in ["graph"]: - intrp = relay.create_executor(kind, mod=mod, device=dev, target=target) - op_res = intrp.evaluate()() + op_res = relay.create_executor( + kind, mod=mod, device=dev, target=target + ).evaluate()() tvm.testing.assert_allclose(op_res.numpy(), ref_res.asnumpy()) verify(data_shape=(3,), start=0.0, step=1.0) @@ -1830,8 +1873,9 @@ def verify(batch, seq_length, num_heads, head_dim): mod, _ = relay.frontend.from_mxnet(mx_sym, {"data": data_shape}) for target, dev in tvm.testing.enabled_targets(): for kind in ["graph"]: - intrp = relay.create_executor(kind, mod=mod, device=dev, target=target) - op_res = intrp.evaluate()(data_np) + op_res = relay.create_executor(kind, mod=mod, device=dev, target=target).evaluate()( + data_np + ) tvm.testing.assert_allclose(op_res.numpy(), ref_res.asnumpy(), rtol=1e-5) verify(1, 10, 3, 16) @@ -1855,8 +1899,9 @@ def verify(batch, seq_length, num_heads, head_dim): mod, _ = relay.frontend.from_mxnet(mx_sym, {"data": data_shape, "weight": weight_shape}) for target, dev in tvm.testing.enabled_targets(): for kind in ["graph"]: - intrp = relay.create_executor(kind, mod=mod, device=dev, target=target) - op_res = intrp.evaluate()(data=data_np, weight=weight_np) + op_res = relay.create_executor(kind, mod=mod, device=dev, target=target).evaluate()( + data=data_np, weight=weight_np + ) tvm.testing.assert_allclose(op_res.numpy(), ref_res.asnumpy(), rtol=1e-5) verify(1, 10, 4, 16) @@ -1912,8 +1957,9 @@ def verify( ): target += " -libs=thrust" for kind in ["graph", "debug"]: - intrp = relay.create_executor(kind, mod=mod, device=dev, target=target) - op_res = intrp.evaluate()(data) + op_res = relay.create_executor(kind, mod=mod, device=dev, target=target).evaluate()( + data + ) tvm.testing.assert_allclose(op_res.numpy(), ref_res.asnumpy(), rtol=1e-3, atol=1e-5) verify((1, 10, 6)) @@ -1951,8 +1997,9 @@ def verify(data_shape, anchor_shape, stds=[1, 1, 1, 1], clip=-1, in_format="corn mod, _ = relay.frontend.from_mxnet(mx_sym, shape_dict) for target, dev in tvm.testing.enabled_targets(): for kind in ["graph", "debug"]: - intrp = relay.create_executor(kind, mod=mod, device=dev, target=target) - op_res = intrp.evaluate()(data, anchors) + op_res = relay.create_executor(kind, mod=mod, device=dev, target=target).evaluate()( + data, anchors + ) tvm.testing.assert_allclose(op_res.numpy(), ref_res.asnumpy(), rtol=1e-3, atol=1e-5) verify((1, 10, 4), (1, 10, 4)) @@ -1991,11 +2038,11 @@ def verify(data_shape, axis, use_length, length): for target, dev in tvm.testing.enabled_targets(): for kind in ["graph", "debug"]: - intrp = relay.create_executor(kind, mod=mod, device=dev, target=target) + func = relay.create_executor(kind, mod=mod, device=dev, target=target).evaluate() if use_length: - op_res = intrp.evaluate()(x, length) + op_res = func(x, length) else: - op_res = intrp.evaluate()(x) + op_res = func(x) tvm.testing.assert_allclose(op_res.numpy(), ref_res.asnumpy(), rtol=1e-3, atol=1e-5) @@ -2031,8 +2078,7 @@ def test_forward_npi_pad(data_shape, pad_width, mode, dtype, constant_value, tar ref_res = np.pad(data_np, mode=mode, pad_width=pad_width) mx_sym = mx.sym.np.pad(data.as_np_ndarray(), mode=mode, pad_width=pad_width) mod, _ = relay.frontend.from_mxnet(mx_sym, {"data": data_shape}, dtype=dtype) - intrp = relay.create_executor(kind, mod=mod, device=dev, target=target) - op_res = intrp.evaluate()(data_np) + op_res = relay.create_executor(kind, mod=mod, device=dev, target=target).evaluate()(data_np) tvm.testing.assert_allclose(op_res.numpy(), ref_res, rtol=1e-5) @@ -2050,8 +2096,7 @@ def test_forward_npi_transpose(data_shape, axes, dtype, target, dev, kind): ref_res = mx.np.transpose(mx.np.array(data_np), axes=axes) mx_sym = mx.sym.np.transpose(data.as_np_ndarray(), axes=axes) mod, _ = relay.frontend.from_mxnet(mx_sym, {"data": data_shape}, dtype=dtype) - intrp = relay.create_executor(kind, mod=mod, device=dev, target=target) - op_res = intrp.evaluate()(data_np) + op_res = relay.create_executor(kind, mod=mod, device=dev, target=target).evaluate()(data_np) tvm.testing.assert_allclose(op_res.numpy(), ref_res.asnumpy(), rtol=1e-5) @@ -2078,8 +2123,9 @@ def test_forward_npi_concatenate(data_shape1, data_shape2, axis, dtype, target, mod, _ = relay.frontend.from_mxnet( mx_sym, shape={"data1": data_shape1, "data2": data_shape2}, dtype=dtype ) - intrp = relay.create_executor(kind, mod=mod, device=dev, target=target) - op_res = intrp.evaluate()(data_np1, data_np2) + op_res = relay.create_executor(kind, mod=mod, device=dev, target=target).evaluate()( + data_np1, data_np2 + ) tvm.testing.assert_allclose(op_res.numpy(), ref_res.asnumpy(), rtol=1e-5) @@ -2106,8 +2152,9 @@ def test_forward_npi_stack(data_shape1, data_shape2, axis, dtype, target, dev, k mod, _ = relay.frontend.from_mxnet( mx_sym, shape={"data1": data_shape1, "data2": data_shape2}, dtype=dtype ) - intrp = relay.create_executor(kind, mod=mod, device=dev, target=target) - op_res = intrp.evaluate()(data_np1, data_np2) + op_res = relay.create_executor(kind, mod=mod, device=dev, target=target).evaluate()( + data_np1, data_np2 + ) tvm.testing.assert_allclose(op_res.numpy(), ref_res.asnumpy(), rtol=1e-5) @@ -2121,8 +2168,7 @@ def test_forward_np_copy(data_shape, dtype, target, dev, kind): ref_res = mx.np.copy(mx.np.array(data_np)) mx_sym = mx.sym.np.copy(data.as_np_ndarray()) mod, _ = relay.frontend.from_mxnet(mx_sym, {"data": data_shape}, dtype=dtype) - intrp = relay.create_executor(kind, mod=mod, device=dev, target=target) - op_res = intrp.evaluate()(data_np) + op_res = relay.create_executor(kind, mod=mod, device=dev, target=target).evaluate()(data_np) tvm.testing.assert_allclose(op_res.numpy(), ref_res.asnumpy(), rtol=1e-5) @@ -2149,8 +2195,7 @@ def test_forward_npx_reshape(data_shape, out_shape, dtype, target, reverse, dev, ref_res = mx.npx.reshape(mx.np.array(data_np), newshape=out_shape, reverse=reverse) mx_sym = mx.sym.npx.reshape(data.as_np_ndarray(), newshape=out_shape, reverse=reverse) mod, _ = relay.frontend.from_mxnet(mx_sym, {"data": data_shape}, dtype=dtype) - intrp = relay.create_executor(kind, mod=mod, device=dev, target=target) - op_res = intrp.evaluate()(data_np) + op_res = relay.create_executor(kind, mod=mod, device=dev, target=target).evaluate()(data_np) tvm.testing.assert_allclose(op_res.numpy(), ref_res.asnumpy(), rtol=1e-5) @@ -2184,8 +2229,9 @@ def test_forward_npi_binary(data_shape, dtype, target, dev, kind): mod, _ = relay.frontend.from_mxnet( mx_sym, shape={"lhs": data_shape, "rhs": data_shape}, dtype=dtype ) - intrp = relay.create_executor(kind, mod=mod, device=dev, target=target) - op_res = intrp.evaluate()(data_np1, data_np2) + op_res = relay.create_executor(kind, mod=mod, device=dev, target=target).evaluate()( + data_np1, data_np2 + ) tvm.testing.assert_allclose(op_res.numpy(), ref_res.asnumpy(), rtol=1e-5) @@ -2216,8 +2262,9 @@ def test_forward_npi_binary_scalar(data_shape, dtype, scalar, target, dev, kind) ref_res = ref_op(mx.np.array(data_np1), scalar) mx_sym = mx_op(data1.as_np_ndarray(), scalar) mod, _ = relay.frontend.from_mxnet(mx_sym, shape={"lhs": data_shape}, dtype=dtype) - intrp = relay.create_executor(kind, mod=mod, device=dev, target=target) - op_res = intrp.evaluate()(data_np1) + op_res = relay.create_executor(kind, mod=mod, device=dev, target=target).evaluate()( + data_np1 + ) tvm.testing.assert_allclose(op_res.numpy(), ref_res.asnumpy(), rtol=1e-5) @@ -2233,8 +2280,7 @@ def test_forward_npi_tanh(data_shape, dtype, target, dev, kind): ref_res = mx.np.tanh(mx.np.array(data_np1)) mx_sym = mx.sym.np.tanh(data1.as_np_ndarray()) mod, _ = relay.frontend.from_mxnet(mx_sym, shape={"data": data_shape}, dtype=dtype) - intrp = relay.create_executor(kind, mod=mod, device=dev, target=target) - op_res = intrp.evaluate()(data_np1) + op_res = relay.create_executor(kind, mod=mod, device=dev, target=target).evaluate()(data_np1) tvm.testing.assert_allclose(op_res.numpy(), ref_res.asnumpy(), rtol=1e-5) @@ -2265,8 +2311,9 @@ def test_forward_npi_where_rscalar( mod, _ = relay.frontend.from_mxnet( mx_sym, shape={"condition": cond_shape, "x": data_shape}, dtype=dtypeDic ) - intrp = relay.create_executor(kind, mod=mod, device=dev, target=target) - op_res = intrp.evaluate()(cond_np, data_np) + op_res = relay.create_executor(kind, mod=mod, device=dev, target=target).evaluate()( + cond_np, data_np + ) tvm.testing.assert_allclose(op_res.numpy(), ref_res.asnumpy(), rtol=1e-5) @@ -2294,8 +2341,7 @@ def test_forward_split_v2( data.as_nd_ndarray(), indices_or_sections, axis=axis, squeeze_axis=squeeze_axis ) mod, _ = relay.frontend.from_mxnet(mx_sym, {"data": data_shape}, dtype=dtype) - intrp = relay.create_executor(kind, mod=mod, device=dev, target=target) - op_res = intrp.evaluate()(data_np) + op_res = relay.create_executor(kind, mod=mod, device=dev, target=target).evaluate()(data_np) op_res_ = [] for arr in op_res: op_res_.append(arr.numpy().tolist()) diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index 93b9cfa07464..f3270ab11daf 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -74,8 +74,9 @@ def get_tvm_output_with_vm( if convert_to_static: mod = relay.transform.DynamicToStatic()(mod) - ex = relay.create_executor("vm", mod=mod, device=dev, target=target) - result = ex.evaluate()(*input_data, **params) + result = relay.create_executor("vm", mod=mod, device=dev, target=target).evaluate()( + *input_data, **params + ) if isinstance(result, tvm.runtime.NDArray): return result.numpy() return [r.numpy() for r in result] @@ -656,8 +657,7 @@ def test_dynamic_gather(target, dev): mod, params = relay.frontend.from_onnx(model) - ex = relay.create_executor("vm", mod=mod, device=dev, target=target) - result = ex.evaluate()(x, **params) + result = relay.create_executor("vm", mod=mod, device=dev, target=target).evaluate()(x, **params) tvm.testing.assert_allclose(out_np, result.numpy(), rtol=1e-5, atol=1e-5) @@ -1249,7 +1249,7 @@ def verify_batch_matmul(a_shape, b_shape, out_shape, convert_config=None): def verify_simple_dynamic_model(a_shape, b_shape, target, dev): - def verify_model(ex, a_shape, b_shape): + def verify_model(model, a_shape, b_shape): a_array = np.random.uniform(size=a_shape).astype("float32") b_array = np.random.uniform(size=b_shape).astype("float32") # matmul @@ -1257,7 +1257,7 @@ def verify_model(ex, a_shape, b_shape): # relu out_np[out_np < 0] = 0 - tvm_out = ex.evaluate()(a_array, b_array).numpy() + tvm_out = model(a_array, b_array).numpy() tvm.testing.assert_allclose(out_np, tvm_out, rtol=1e-5, atol=1e-5) mul_node = helper.make_node("MatMul", ["a", "b"], ["out"]) @@ -1284,10 +1284,10 @@ def verify_model(ex, a_shape, b_shape): b_anys = [relay.Any()] * len(b_shape) mod, params = relay.frontend.from_onnx(model, {"a": a_anys, "b": b_anys}) - ex = relay.create_executor("vm", mod=mod, device=dev, target=target) - verify_model(ex, a_shape, b_shape) - verify_model(ex, [a * 2 for a in a_shape], [b * 2 for b in b_shape]) - verify_model(ex, [a * 3 for a in a_shape], [b * 3 for b in b_shape]) + model = relay.create_executor("vm", mod=mod, device=dev, target=target).evaluate() + verify_model(model, a_shape, b_shape) + verify_model(model, [a * 2 for a in a_shape], [b * 2 for b in b_shape]) + verify_model(model, [a * 3 for a in a_shape], [b * 3 for b in b_shape]) # TODO(mbrookhart, electriclilies): Add CUDA as a target once batch matmul is fixed diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py index e2cb51a9596a..bae7c1b5498c 100644 --- a/tests/python/frontend/pytorch/test_forward.py +++ b/tests/python/frontend/pytorch/test_forward.py @@ -2240,8 +2240,7 @@ def verify_model_vm(input_model, ishapes, idtype=None, idata=None, targets=["llv print("Running on target", tgt) dev = tvm.device(tgt, 0) - executor = relay.create_executor("vm", mod=mod, device=dev, target=tgt) - evaluator = executor.evaluate() + evaluator = relay.create_executor("vm", mod=mod, device=dev, target=tgt).evaluate() # Inference for name, inp in zip(input_names, input_data): diff --git a/tests/python/frontend/pytorch/test_lstm.py b/tests/python/frontend/pytorch/test_lstm.py index 1aa8bff4076e..25d4563ee64e 100644 --- a/tests/python/frontend/pytorch/test_lstm.py +++ b/tests/python/frontend/pytorch/test_lstm.py @@ -221,9 +221,9 @@ def assert_equal(tvm_result, torch_result): def run_and_compare(mod, params, pt_result, target, device): - executor = relay.create_executor("vm", mod=mod, device=device, target=target) - evaluator = executor.evaluate() - exec_res = evaluator(**params) + exec_res = relay.create_executor("vm", mod=mod, device=device, target=target).evaluate()( + **params + ) def flatten(nested): res = [] diff --git a/tests/python/frontend/tensorflow/test_control_flow.py b/tests/python/frontend/tensorflow/test_control_flow.py index c91661db7e36..49dc5170c52f 100644 --- a/tests/python/frontend/tensorflow/test_control_flow.py +++ b/tests/python/frontend/tensorflow/test_control_flow.py @@ -34,8 +34,7 @@ def check_equal(graph, tf_out, input_map=None): mod, params = from_tensorflow(graph.as_graph_def(add_shapes=True)) if input_map is not None: params.update(input_map) - ex = relay.create_executor("vm", mod=mod) - relay_out = ex.evaluate()(**params) + relay_out = relay.create_executor("vm", mod=mod).evaluate()(**params) if isinstance(relay_out, nd.NDArray): np.testing.assert_allclose(tf_out, relay_out.numpy()) else: diff --git a/tests/python/frontend/tensorflow/test_debugging.py b/tests/python/frontend/tensorflow/test_debugging.py index 26fe171fb789..0e08840e56ee 100644 --- a/tests/python/frontend/tensorflow/test_debugging.py +++ b/tests/python/frontend/tensorflow/test_debugging.py @@ -28,8 +28,7 @@ def run_relay(graph, shape_dict=None, *vars): mod, params = from_tensorflow(graph.as_graph_def(add_shapes=True), shape=shape_dict) - ex = relay.create_executor("debug", mod=mod) - return ex.evaluate()(*vars) + return relay.create_executor("debug", mod=mod).evaluate()(*vars) def test_assert_true(): diff --git a/tests/python/frontend/tensorflow/test_forward.py b/tests/python/frontend/tensorflow/test_forward.py index e1bc79b45503..655f20949d75 100644 --- a/tests/python/frontend/tensorflow/test_forward.py +++ b/tests/python/frontend/tensorflow/test_forward.py @@ -147,7 +147,6 @@ def run_tvm_graph( ) dev = tvm.device(target, 0) if mode == "debug": - ex = relay.create_executor(mode, mod=mod, device=tvm.cpu(), target="llvm") inputs = [] for param in mod["main"].params: found = False @@ -159,7 +158,9 @@ def run_tvm_graph( # Interpreter doesn't bind constants, so still need to find in params if not found: inputs.append(tvm.nd.array(params[param.name_hint])) - result = ex.evaluate()(*inputs) + result = relay.create_executor(mode, mod=mod, device=tvm.cpu(), target="llvm").evaluate()( + *inputs + ) return vmobj_to_list(result) elif mode == "vm": with tvm.transform.PassContext(opt_level=opt_level, disabled_pass=disabled_pass): diff --git a/tests/python/frontend/tensorflow/test_no_op.py b/tests/python/frontend/tensorflow/test_no_op.py index 38246ea5e14f..d8bfcee9673b 100644 --- a/tests/python/frontend/tensorflow/test_no_op.py +++ b/tests/python/frontend/tensorflow/test_no_op.py @@ -26,8 +26,7 @@ def run_relay(graph): mod, params = from_tensorflow(graph.as_graph_def(add_shapes=True)) - ex = relay.create_executor("debug", mod=mod) - return ex.evaluate()(**params) + return relay.create_executor("debug", mod=mod).evaluate()(**params) def test_no_op(): diff --git a/tests/python/frontend/tflite/test_forward.py b/tests/python/frontend/tflite/test_forward.py index ef03906db884..7b7f1b1c43b8 100644 --- a/tests/python/frontend/tflite/test_forward.py +++ b/tests/python/frontend/tflite/test_forward.py @@ -189,7 +189,6 @@ def run_tvm_graph( ) if mode in ["debug", "vm"]: - ex = relay.create_executor(mode, mod=mod, device=tvm.cpu(), target="llvm") inputs = [] for param in mod["main"].params: found = False @@ -201,7 +200,9 @@ def run_tvm_graph( # Interpreter doesn't bind constants, so still need to find in params if not found: inputs.append(tvm.nd.array(params[param.name_hint])) - result = ex.evaluate()(*inputs) + result = relay.create_executor(mode, mod=mod, device=tvm.cpu(), target="llvm").evaluate()( + *inputs + ) return vmobj_to_list(result) else: with tvm.transform.PassContext(opt_level=3): diff --git a/tests/python/relay/dyn/test_dynamic_op_level10.py b/tests/python/relay/dyn/test_dynamic_op_level10.py index ad9a0ecd4e59..0f47ce02db49 100644 --- a/tests/python/relay/dyn/test_dynamic_op_level10.py +++ b/tests/python/relay/dyn/test_dynamic_op_level10.py @@ -47,10 +47,9 @@ def verify_more_dynamic_broadcast_to(x_shape, out_shape): for target, dev in tvm.testing.enabled_targets(): for kind in ["vm", "debug"]: mod = tvm.ir.IRModule.from_expr(func) - intrp = relay.create_executor(kind, mod=mod, device=dev, target=target) - op_res = intrp.evaluate(func)( - x, np.array(x_shape).astype(shape_type), np.array(out_shape).astype(shape_type) - ) + op_res = relay.create_executor(kind, mod=mod, device=dev, target=target).evaluate( + func + )(x, np.array(x_shape).astype(shape_type), np.array(out_shape).astype(shape_type)) tvm.testing.assert_allclose(op_res.numpy(), ref_res, rtol=1e-5) verify_more_dynamic_broadcast_to((4, 3), (3, 4, 3)) @@ -73,8 +72,9 @@ def verify_broadcast_to(x_shape, out_shape): for target, dev in tvm.testing.enabled_targets(): for kind in ["vm", "debug"]: mod = tvm.ir.IRModule.from_expr(func) - intrp = relay.create_executor(kind, mod=mod, device=dev, target=target) - op_res = intrp.evaluate(func)(x, np.array(out_shape).astype(shape_type)) + op_res = relay.create_executor(kind, mod=mod, device=dev, target=target).evaluate( + func + )(x, np.array(out_shape).astype(shape_type)) tvm.testing.assert_allclose(op_res.numpy(), ref_res, rtol=1e-5) verify_broadcast_to((1,), (1, 1, 1)) @@ -103,8 +103,9 @@ def test_dyn_broadcast_to(): for target, dev in tvm.testing.enabled_targets(): for kind in ["vm", "debug"]: mod = tvm.ir.IRModule.from_expr(func) - intrp = relay.create_executor(kind, mod=mod, device=dev, target=target) - op_res = intrp.evaluate(func)(x, np.array(dyn_shape).astype(shape_type)) + op_res = relay.create_executor(kind, mod=mod, device=dev, target=target).evaluate(func)( + x, np.array(dyn_shape).astype(shape_type) + ) tvm.testing.assert_allclose(op_res.numpy(), ref_res, rtol=1e-5) @@ -136,8 +137,9 @@ def _verify(indices_shape, depth, on_value, off_value, axis, dtype): for target, dev in tvm.testing.enabled_targets(): for kind in ["vm", "debug"]: mod = tvm.ir.IRModule.from_expr(func) - intrp = relay.create_executor(kind, mod=mod, device=dev, target=target) - out_relay = intrp.evaluate()(indices_np, np.array(depth).astype("int32")) + out_relay = relay.create_executor( + kind, mod=mod, device=dev, target=target + ).evaluate()(indices_np, np.array(depth).astype("int32")) tvm.testing.assert_allclose(out_relay.numpy(), out_np) _verify((3,), 3, 1, 0, -1, "int32") diff --git a/tests/python/relay/dyn/test_dynamic_op_level2.py b/tests/python/relay/dyn/test_dynamic_op_level2.py index a6ea609be1e2..fd7ab7002806 100644 --- a/tests/python/relay/dyn/test_dynamic_op_level2.py +++ b/tests/python/relay/dyn/test_dynamic_op_level2.py @@ -60,8 +60,7 @@ def verify_upsampling(dshape, scale_h, scale_w, layout, method, align_corners=Fa for target, dev in tvm.testing.enabled_targets(): for kind in ["vm", "debug"]: mod = tvm.ir.IRModule.from_expr(func) - intrp = relay.create_executor(kind, mod=mod, device=dev, target=target) - op_res = intrp.evaluate()( + op_res = relay.create_executor(kind, mod=mod, device=dev, target=target).evaluate()( x_data, np.array(scale_h).astype("float32"), np.array(scale_w).astype("float32") ) tvm.testing.assert_allclose(op_res.numpy(), ref_res, rtol=1e-4, atol=1e-6) @@ -127,8 +126,7 @@ def verify_upsampling3d( for target, dev in enabled_targets(): for kind in ["vm", "debug"]: mod = tvm.ir.IRModule.from_expr(func) - intrp = relay.create_executor(kind, mod=mod, device=dev, target=target) - op_res = intrp.evaluate()( + op_res = relay.create_executor(kind, mod=mod, device=dev, target=target).evaluate()( x_data, np.array(scale_d).astype("float32"), np.array(scale_h).astype("float32"), diff --git a/tests/python/relay/dyn/test_dynamic_op_level3.py b/tests/python/relay/dyn/test_dynamic_op_level3.py index 3673f08cf8b2..d2ad5a47f15b 100644 --- a/tests/python/relay/dyn/test_dynamic_op_level3.py +++ b/tests/python/relay/dyn/test_dynamic_op_level3.py @@ -31,8 +31,9 @@ def verify_func(func, data, ref_res, target_device=tvm.testing.enabled_targets() for target, dev in target_device: for kind in ["vm", "debug"]: mod = tvm.ir.IRModule.from_expr(func) - intrp = relay.create_executor(kind, mod=mod, device=dev, target=target) - op_res = intrp.evaluate()(*data) + op_res = relay.create_executor(kind, mod=mod, device=dev, target=target).evaluate()( + *data + ) if isinstance(op_res, tvm.runtime.container.ADT): assert len(op_res) == len( ref_res diff --git a/tests/python/relay/dyn/test_dynamic_op_level4.py b/tests/python/relay/dyn/test_dynamic_op_level4.py index f5afbd7588fd..2a4606fcf93f 100644 --- a/tests/python/relay/dyn/test_dynamic_op_level4.py +++ b/tests/python/relay/dyn/test_dynamic_op_level4.py @@ -66,8 +66,9 @@ def verify(dshape, begin, end, strides, slice_mode="end", test_ref=True, dtype=" return for target, dev in tvm.testing.enabled_targets(): mod = tvm.ir.IRModule.from_expr(func) - intrp = relay.create_executor("vm", mod=mod, device=dev, target=target) - op_res = intrp.evaluate()(*data) + op_res = relay.create_executor("vm", mod=mod, device=dev, target=target).evaluate()( + *data + ) tvm.testing.assert_allclose(op_res.numpy(), ref_res) verify( diff --git a/tests/python/relay/dyn/test_dynamic_op_level5.py b/tests/python/relay/dyn/test_dynamic_op_level5.py index d3459afaab06..c29ea2cd392f 100644 --- a/tests/python/relay/dyn/test_dynamic_op_level5.py +++ b/tests/python/relay/dyn/test_dynamic_op_level5.py @@ -64,8 +64,9 @@ def verify_resize2d(dshape, scale, method, layout): for target, dev in tvm.testing.enabled_targets(): for kind in ["vm", "debug"]: mod = tvm.ir.IRModule.from_expr(func) - intrp = relay.create_executor(kind, mod=mod, device=dev, target=target) - op_res = intrp.evaluate()(x_data, size) + op_res = relay.create_executor(kind, mod=mod, device=dev, target=target).evaluate()( + x_data, size + ) tvm.testing.assert_allclose(op_res.numpy(), ref_res, rtol=1e-4, atol=1e-6) for method in ["linear", "nearest_neighbor"]: diff --git a/tests/python/relay/dyn/test_dynamic_op_level6.py b/tests/python/relay/dyn/test_dynamic_op_level6.py index 03823062eab7..530c402b2947 100644 --- a/tests/python/relay/dyn/test_dynamic_op_level6.py +++ b/tests/python/relay/dyn/test_dynamic_op_level6.py @@ -55,8 +55,9 @@ def verify_topk(k, axis, ret_type, is_ascend, dtype): for target, dev in tvm.testing.enabled_targets(): for kind in ["vm", "debug"]: mod = tvm.ir.IRModule.from_expr(func) - intrp = relay.create_executor(kind, mod=mod, device=dev, target=target) - op_res = intrp.evaluate()(np_data, np.array([k]).astype("float32")) + op_res = relay.create_executor(kind, mod=mod, device=dev, target=target).evaluate()( + np_data, np.array([k]).astype("float32") + ) if ret_type == "both": tvm.testing.assert_allclose(op_res[0].numpy(), np_values) tvm.testing.assert_allclose(op_res[1].numpy(), np_indices) diff --git a/tests/python/relay/test_adt.py b/tests/python/relay/test_adt.py index 51f46799e606..8cf31f94378e 100644 --- a/tests/python/relay/test_adt.py +++ b/tests/python/relay/test_adt.py @@ -34,7 +34,14 @@ def count(e): dev = tvm.device("llvm", 0) -intrp = create_executor(mod=prelude.mod, device=dev, target="llvm") + + +def eval(expr): + # CAUTION: These tests re-process the entire prelude for each test expression. + # Hoisting the create_executor won't improve that since preprocessing won't begin + # until the evaluate. + return create_executor(mod=prelude.mod, device=dev, target="llvm").evaluate(expr) + nat, z, s = prelude.mod.get_type("nat") @@ -139,7 +146,7 @@ def get_scalar(tv): # @tvm.testing.uses_gpu def test_nat_value(): assert count(make_nat_value(p, 10)) == 10 - assert count(intrp.evaluate(s(s(z())))) == 2 + assert count(eval(s(s(z())))) == 2 @tvm.testing.uses_gpu @@ -158,14 +165,14 @@ def test_nat_constructor(): @tvm.testing.uses_gpu def test_double(): assert prelude.mod[double].checked_type == relay.FuncType([nat()], nat()) - res = intrp.evaluate(double(s(z()))) + res = eval(double(s(z()))) assert count(res) == 2 @tvm.testing.uses_gpu def test_add(): assert prelude.mod[add].checked_type == relay.FuncType([nat(), nat()], nat()) - res = intrp.evaluate(add(s(z()), s(z()))) + res = eval(add(s(z()), s(z()))) assert count(res) == 2 @@ -187,7 +194,7 @@ def test_hd_tl(): got = [] for i in range(len(expected)): - got.append(count(intrp.evaluate(hd(l)))) + got.append(count(eval(hd(l)))) l = tl(l) assert got == expected @@ -202,7 +209,7 @@ def test_nth(): for i in range(len(expected)): nth = prelude.mod.get_global_var("nth") - item = intrp.evaluate(nth(l, relay.const(i))) + item = eval(nth(l, relay.const(i))) assert get_scalar(item) == i @@ -220,7 +227,7 @@ def test_update(): got = [] for i in range(len(expected)): - got.append(count(intrp.evaluate(nth(l, relay.const(i))))) + got.append(count(eval(nth(l, relay.const(i))))) assert got == expected @@ -231,7 +238,7 @@ def test_length(): assert prelude.mod[length].checked_type == relay.FuncType( [rlist(a)], relay.scalar_type("int32"), [a] ) - res = intrp.evaluate(length(cons(z(), cons(z(), cons(z(), nil()))))) + res = eval(length(cons(z(), cons(z(), cons(z(), nil()))))) assert get_scalar(res) == 3 @@ -245,7 +252,7 @@ def test_map(): x = relay.Var("x") add_one = relay.Function([x], s(x)) - res = intrp.evaluate(map(add_one, cons(z(), cons(z(), nil())))) + res = eval(map(add_one, cons(z(), cons(z(), nil())))) ones = to_list(res) assert len(ones) == 2 assert count(ones[0]) == 1 and count(ones[1]) == 1 @@ -263,7 +270,7 @@ def test_foldl(): x = relay.Var("x") y = relay.Var("y") rev_dup = relay.Function([y, x], cons(x, cons(x, y))) - res = intrp.evaluate( + res = eval( foldl( rev_dup, nil(), @@ -291,7 +298,7 @@ def test_foldr(): x = relay.Var("x") y = relay.Var("y") identity = relay.Function([x, y], cons(x, y)) - res = intrp.evaluate( + res = eval( foldr( identity, nil(), @@ -316,7 +323,7 @@ def test_foldr1(): x = relay.Var("x") y = relay.Var("y") f = relay.Function([x, y], add(x, y)) - res = intrp.evaluate( + res = eval( foldr1( f, cons( @@ -334,7 +341,7 @@ def test_sum(): assert prelude.mod[sum].checked_type == relay.FuncType( [rlist(relay.scalar_type("int32"))], relay.scalar_type("int32") ) - res = intrp.evaluate(sum(cons(relay.const(1), cons(relay.const(2), nil())))) + res = eval(sum(cons(relay.const(1), cons(relay.const(2), nil())))) assert get_scalar(res) == 3 @@ -345,7 +352,7 @@ def test_concat(): l1 = cons(make_nat_expr(prelude, 1), cons(make_nat_expr(prelude, 2), nil())) l2 = cons(make_nat_expr(prelude, 3), cons(make_nat_expr(prelude, 4), nil())) - res = intrp.evaluate(concat(l1, l2)) + res = eval(concat(l1, l2)) catted = to_list(res) assert len(catted) == 4 @@ -379,7 +386,7 @@ def test_filter(): ], ), ) - res = intrp.evaluate( + res = eval( filter( greater_than_one, cons( @@ -416,7 +423,7 @@ def test_zip(): ) l2 = cons(nil(), cons(cons(nil(), nil()), cons(cons(nil(), cons(nil(), nil())), nil()))) - res = intrp.evaluate(zip(l1, l2)) + res = eval(zip(l1, l2)) zipped = to_list(res) assert len(zipped) == 3 assert count(zipped[0][0]) == 1 @@ -428,7 +435,7 @@ def test_zip(): # test truncation l3 = cons(make_nat_expr(prelude, 4), cons(make_nat_expr(prelude, 5), nil())) - shorter_res = intrp.evaluate(zip(l3, l2)) + shorter_res = eval(zip(l3, l2)) truncated = to_list(shorter_res) assert len(truncated) == 2 assert count(truncated[0][0]) == 4 @@ -437,7 +444,7 @@ def test_zip(): assert len(to_list(truncated[1][1])) == 1 l4 = cons(nil(), nil()) - shortest_res = intrp.evaluate(zip(l3, l4)) + shortest_res = eval(zip(l3, l4)) singleton = to_list(shortest_res) assert len(singleton) == 1 assert count(singleton[0][0]) == 4 @@ -449,7 +456,7 @@ def test_rev(): a = relay.TypeVar("a") assert prelude.mod[rev].checked_type == relay.FuncType([rlist(a)], rlist(a), [a]) - res = intrp.evaluate( + res = eval( rev( cons( make_nat_expr(prelude, 1), @@ -488,7 +495,7 @@ def test_unfoldr(): ), ) - res = intrp.evaluate(unfoldr(count_down, make_nat_expr(prelude, 3))) + res = eval(unfoldr(count_down, make_nat_expr(prelude, 3))) unfolded = to_list(res) assert len(unfolded) == 3 @@ -520,7 +527,7 @@ def test_unfoldl(): ), ) - res = intrp.evaluate(unfoldl(count_down, make_nat_expr(prelude, 3))) + res = eval(unfoldl(count_down, make_nat_expr(prelude, 3))) unfolded = to_list(res) assert len(unfolded) == 3 @@ -549,7 +556,7 @@ def test_map_accumr(): make_nat_expr(prelude, 1), cons(make_nat_expr(prelude, 2), cons(make_nat_expr(prelude, 3), nil())), ) - res = intrp.evaluate(map_accumr(add_acc_to_each, z(), vals)) + res = eval(map_accumr(add_acc_to_each, z(), vals)) sum = count(res[0]) new_vals = to_list(res[1]) @@ -581,7 +588,7 @@ def test_map_accuml(): make_nat_expr(prelude, 1), cons(make_nat_expr(prelude, 2), cons(make_nat_expr(prelude, 3), nil())), ) - res = intrp.evaluate(map_accuml(add_to_acc, z(), vals)) + res = eval(map_accuml(add_to_acc, z(), vals)) sum = count(res[0]) new_vals = to_list(res[1]) @@ -609,7 +616,7 @@ def test_optional_matching(): ), ) - res = intrp.evaluate( + res = eval( foldr( condense, nil(), @@ -636,9 +643,7 @@ def test_tmap(): x = relay.Var("x") add_one = relay.Function([x], s(x)) - res = intrp.evaluate( - tmap(add_one, rose(z(), cons(rose(z(), nil()), cons(rose(z(), nil()), nil())))) - ) + res = eval(tmap(add_one, rose(z(), cons(rose(z(), nil()), cons(rose(z(), nil()), nil()))))) tree_dict = tree_to_dict(res) assert count(tree_dict["member"]) == 1 @@ -657,7 +662,7 @@ def test_size(): root = rose(z(), cons(rose(z(), nil()), cons(rose(z(), nil()), nil()))) t = rose(z(), cons(root, cons(root, cons(root, nil())))) - res = intrp.evaluate(size(t)) + res = eval(size(t)) assert get_scalar(res) == 10 @@ -666,7 +671,7 @@ def test_wildcard_match_solo(): x = relay.Var("x", nat()) copy = relay.Function([x], relay.Match(x, [relay.Clause(relay.PatternWildcard(), x)]), nat()) - res = intrp.evaluate(copy(s(s(s(z()))))) + res = eval(copy(s(s(s(z()))))) assert count(res) == 3 @@ -690,7 +695,7 @@ def test_wildcard_match_order(): nat(), ) - res = intrp.evaluate(return_zero(cons(s(z()), nil()))) + res = eval(return_zero(cons(s(z()), nil()))) # wildcard pattern is evaluated first assert count(res) == 0 @@ -744,7 +749,7 @@ def test_nested_matches(): ) final_list = cons(first_list, cons(second_list, nil())) - res = intrp.evaluate(flatten(final_list)) + res = eval(flatten(final_list)) flat = to_list(res) assert len(flat) == 6 @@ -758,8 +763,8 @@ def test_match_full_var(): v = relay.Var("v") id_func = relay.Function([x], relay.Match(x, [relay.Clause(relay.PatternVar(v), v)])) - res1 = intrp.evaluate(id_func(nil())) - res2 = intrp.evaluate(id_func(cons(z(), cons(z(), nil())))) + res1 = eval(id_func(nil())) + res2 = eval(id_func(cons(z(), cons(z(), nil())))) empty = to_list(res1) assert len(empty) == 0 @@ -794,7 +799,7 @@ def test_nested_pattern_match(): ) get_second = relay.Function([x], match) - res = intrp.evaluate(get_second(cons(s(z()), cons(s(s(z())), nil())))) + res = eval(get_second(cons(s(z()), cons(s(s(z())), nil())))) assert count(res) == 2 @@ -804,14 +809,14 @@ def test_compose(): n = relay.Var("n") inc = relay.Function([n], s(n)) x = relay.Var("x") - res = intrp.evaluate(relay.Call(compose(inc, double), [s(s(z()))])) + res = eval(relay.Call(compose(inc, double), [s(s(z()))])) assert count(res) == 5 @tvm.testing.uses_gpu def test_iterate(): expr = relay.Call(iterate(double, relay.const(2)), [make_nat_expr(prelude, 3)]) - res = intrp.evaluate(relay.Function([], expr)()) + res = eval(relay.Function([], expr)()) assert count(res) == 12 diff --git a/tests/python/relay/test_any.py b/tests/python/relay/test_any.py index b2b251862a21..6430e6aa2116 100644 --- a/tests/python/relay/test_any.py +++ b/tests/python/relay/test_any.py @@ -58,8 +58,7 @@ def check_result( continue if kind == "debug" and (only_vm or dev.device_type != tvm.cpu().device_type): continue - ex = relay.create_executor(kind, mod=mod, device=dev, target=tgt) - result = ex.evaluate()(*args) + result = relay.create_executor(kind, mod=mod, device=dev, target=tgt).evaluate()(*args) if isinstance(result, tvm.runtime.container.ADT): result = [r.numpy() for r in result] else: @@ -851,8 +850,9 @@ def verify_any_split(data_shape, indices_or_sections, axis, static_data_shape, r mod["main"] = relay.Function([data], y.astuple()) data_np = np.random.uniform(size=static_data_shape).astype(dtype) for kind in ["vm"]: - ex = relay.create_executor(kind, mod=mod, device=tvm.cpu(), target="llvm") - result = ex.evaluate()(data_np) + result = relay.create_executor(kind, mod=mod, device=tvm.cpu(), target="llvm").evaluate()( + data_np + ) for ret, ref_ret in zip(result, ref_out_shape): assert ret.numpy().shape == ref_ret, "Shape mismatch: expect %s but got %s." % ( str(ref_ret), @@ -941,8 +941,9 @@ def verify_any_batch_matmul( for target, dev in tvm.testing.enabled_targets(): for kind in ["vm", "debug"]: mod = tvm.ir.IRModule.from_expr(func) - intrp = relay.create_executor(kind, mod=mod, device=dev, target=target) - z = intrp.evaluate()(x_np, y_np) + z = relay.create_executor(kind, mod=mod, device=dev, target=target).evaluate()( + x_np, y_np + ) tvm.testing.assert_allclose(z.numpy(), z_np, rtol=1e-5) diff --git a/tests/python/relay/test_backend_graph_executor.py b/tests/python/relay/test_backend_graph_executor.py index 7beac197fb3a..c6f2748e9ec8 100644 --- a/tests/python/relay/test_backend_graph_executor.py +++ b/tests/python/relay/test_backend_graph_executor.py @@ -41,10 +41,8 @@ def check_rts(expr, args, expected_result, mod=None): expected_result: The expected result of running the expression. """ - intrp = relay.create_executor("debug", mod=mod) - graph = relay.create_executor("graph", mod=mod) - eval_result = intrp.evaluate(expr)(*args) - rts_result = graph.evaluate(expr)(*args) + eval_result = relay.create_executor("debug", mod=mod).evaluate(expr)(*args) + rts_result = relay.create_executor("graph", mod=mod).evaluate(expr)(*args) tvm.testing.assert_allclose(eval_result.numpy(), rts_result.numpy()) tvm.testing.assert_allclose(eval_result.numpy(), expected_result) @@ -295,10 +293,9 @@ def test_graph_executor_nested_tuples(): out = relay.Tuple([x, relay.Tuple([y, relay.Tuple([z, w])])]) func = relay.Function([x, y, z, w], out) - exe = relay.create_executor( + f = relay.create_executor( kind="graph", mod=tvm.IRModule.from_expr(func), device=tvm.cpu(0), target="llvm" - ) - f = exe.evaluate() + ).evaluate() data = [np.random.uniform(size=(2, 3)).astype("float32") for _ in "xyzw"] out = f(*data) diff --git a/tests/python/relay/test_backend_interpreter.py b/tests/python/relay/test_backend_interpreter.py index d65bcad3364d..af2dcf32c305 100644 --- a/tests/python/relay/test_backend_interpreter.py +++ b/tests/python/relay/test_backend_interpreter.py @@ -15,27 +15,26 @@ # specific language governing permissions and limitations # under the License. import numpy as np +import pytest import tvm -from tvm import te -import tvm.testing +from tvm import testing from tvm import nd from tvm import relay from tvm.runtime import container from tvm.relay.backend.interpreter import RefValue, ConstructorValue from tvm.relay.scope_builder import ScopeBuilder -from tvm.relay import testing, create_executor def check_eval(expr, args, expected_result, mod=None, rtol=1e-07): # TODO(tqchen) add more types once the schedule register is fixed. for target in ["llvm"]: dev = tvm.device(target, 0) - if not tvm.testing.device_enabled(target): + if not testing.device_enabled(target): return - intrp = create_executor(mod=mod, device=dev, target=target) - result = intrp.evaluate(expr)(*args) - # use tvm.testing which also set atol - tvm.testing.assert_allclose(result.numpy(), expected_result, rtol=rtol) + func = relay.create_executor(mod=mod, device=dev, target=target).evaluate(expr) + result = func if args is None else func(*args) + # use testing which also set atol + testing.assert_allclose(result.numpy(), expected_result, rtol=rtol) def test_tuple_value(): @@ -146,10 +145,9 @@ def test_ref(): def test_binds(): x = relay.var("x") y = relay.add(x, x) - intrp = create_executor("debug") xx = np.ones((10, 20)) - res = intrp.evaluate(y, binds={x: xx}).numpy() - tvm.testing.assert_allclose(xx + xx, res) + res = relay.create_executor().evaluate(y, binds={x: xx}).numpy() + testing.assert_allclose(xx + xx, res) def test_kwargs_params(): @@ -161,15 +159,13 @@ def test_kwargs_params(): y_data = np.random.rand(1, 10).astype("float32") z_data = np.random.rand(1, 10).astype("float32") params = {"y": y_data, "z": z_data} - intrp = create_executor("debug") - res = intrp.evaluate(f)(x_data, **params) - tvm.testing.assert_allclose(res.numpy(), x_data + y_data + z_data) + res = relay.create_executor().evaluate(f)(x_data, **params) + testing.assert_allclose(res.numpy(), x_data + y_data + z_data) def test_function_taking_adt_ref_tuple(): mod = tvm.IRModule() prelude = relay.prelude.Prelude(mod) - intrp = create_executor("debug", mod) _, cons, nil = prelude.mod.get_type("List") nil_value = ConstructorValue(nil.tag, [], nil) @@ -184,7 +180,7 @@ def test_function_taking_adt_ref_tuple(): [nd.array(np.random.rand(1, 10).astype("float32")) for _ in range(10)] ) - id_func = intrp.evaluate(prelude.id) + id_func = relay.create_executor(mod=mod).evaluate(prelude.id) res_nil = id_func(nil_value) assert res_nil.tag == nil_value.tag @@ -193,17 +189,17 @@ def test_function_taking_adt_ref_tuple(): res_cons = id_func(cons_value) assert res_cons.tag == cons_value.tag assert len(res_cons.fields) == len(cons_value.fields) - tvm.testing.assert_allclose(res_cons.fields[0].numpy(), cons_value.fields[0].numpy()) + testing.assert_allclose(res_cons.fields[0].numpy(), cons_value.fields[0].numpy()) assert isinstance(res_cons.fields[1], ConstructorValue) assert res_cons.fields[1].tag == nil.tag assert len(res_cons.fields[1].fields) == 0 res_ref = id_func(ref_value) - tvm.testing.assert_allclose(res_ref.value.numpy(), ref_value.value.numpy()) + testing.assert_allclose(res_ref.value.numpy(), ref_value.value.numpy()) res_tuple = id_func(tuple_value) for i in range(10): - tvm.testing.assert_allclose(res_tuple[i].numpy(), tuple_value[i].numpy()) + testing.assert_allclose(res_tuple[i].numpy(), tuple_value[i].numpy()) def test_tuple_passing(): @@ -222,28 +218,72 @@ def test_tuple_passing(): dev = tvm.cpu() target = tvm.target.Target("llvm") - exec = relay.create_executor(mod=mod, device=dev, target=target) - f = exec.evaluate(gv) + f = relay.create_executor(mod=mod, device=dev, target=target).evaluate(gv) # First use a Python tuple. out = f((10, 8)) - tvm.testing.assert_allclose(out.numpy(), np.array(10)) + testing.assert_allclose(out.numpy(), np.array(10)) # Second use a tuple value. value_tuple = container.tuple_object([nd.array(np.array(11)), nd.array(np.array(12))]) out = f(value_tuple) - tvm.testing.assert_allclose(out.numpy(), np.array(11)) + testing.assert_allclose(out.numpy(), np.array(11)) + + +def test_dynamic(): + n = 3 + m = 2 + x = relay.Var("x", relay.TensorType([relay.Any(), m], "float32")) + y = relay.Var("y", relay.TensorType([relay.Any(), m], "float32")) + xx = x - relay.expr.const(3.0) + yy = y * relay.expr.const(5.0) + z = relay.op.concatenate([xx, yy], axis=0) + mod = tvm.IRModule() + mod["main"] = relay.Function([x, y], z) + x_np = np.random.uniform(size=(n, m)).astype("float32") + y_np = np.random.uniform(size=(n, m)).astype("float32") + expected = np.concatenate([x_np - 3.0, y_np * 5.0], axis=0) + check_eval(None, [x_np, y_np], expected, mod) + + +def test_ref_global_from_expr(): + n = 3 + x = relay.Var("x", relay.TensorType([n], "float32")) + y = relay.Var("y", relay.TensorType([n], "float32")) + mod = tvm.IRModule() + mod["add"] = relay.Function([x, y], relay.add(x, y)) + x_np = np.random.uniform(size=(n,)).astype("float32") + y_np = np.random.uniform(size=(n,)).astype("float32") + expected = np.add(x_np, y_np) + expr = relay.Call(mod.get_global_var("add"), [relay.const(x_np), relay.const(y_np)]) + check_eval(expr, None, expected, mod) + + +def test_keyword_args(): + n = 3 + x = relay.Var("x", relay.TensorType([n], "float32")) + y = relay.Var("y", relay.TensorType([n], "float32")) + z = relay.add(x, y) + mod = tvm.IRModule() + mod["main"] = relay.Function([x, y], z) + x_np = np.random.uniform(size=(n,)).astype("float32") + y_np = np.random.uniform(size=(n,)).astype("float32") + expected = np.add(x_np, y_np) + actual = relay.create_executor(mod=mod).evaluate()(y=y_np, x=x_np) + testing.assert_allclose(actual.numpy(), expected) + + +# TODO(mbs): Support? Would help reduce wasted work when we need to prepare +# multiple functions w.r.t. the same module. +@pytest.mark.skip(reason="closures are currently not directly Python callable") +def test_functional_returns(): + n = 3 + x = relay.Var("x", relay.TensorType([n], "float32")) + f = relay.Function([x], x) + t = relay.Tuple([f, f]) + c = np.random.rand(n).astype("float32") + result1, result2 = relay.create_executor().evaluate(t) + testing.assert_allclose(result1(c).numpy(), c) + testing.assert_allclose(result2(c).numpy(), c) if __name__ == "__main__": - test_id() - test_add_const() - test_equal() - test_subtract() - test_simple_loop() - test_loop() - test_binds() - test_kwargs_params() - test_ref() - test_tuple_value() - test_tuple_getitem() - test_function_taking_adt_ref_tuple() - test_tuple_passing() + pytest.main([__file__]) diff --git a/tests/python/relay/test_debug.py b/tests/python/relay/test_debug.py index c4ed657701ae..61557867f070 100644 --- a/tests/python/relay/test_debug.py +++ b/tests/python/relay/test_debug.py @@ -23,7 +23,6 @@ def test_debug(): global _test_debug_hit - ex = create_executor() x = var("x", shape=(), dtype="int32") _test_debug_hit = False @@ -32,7 +31,7 @@ def did_exec(x): _test_debug_hit = True prog = debug(x, debug_func=did_exec) - result = ex.evaluate(prog, {x: const(1, "int32")}) + result = create_executor().evaluate(prog, {x: const(1, "int32")}) assert _test_debug_hit assert result.numpy() == 1 @@ -40,7 +39,6 @@ def did_exec(x): def test_debug_with_expr(): global _test_debug_hit _test_debug_hit = False - ex = create_executor() x = var("x", shape=(), dtype="int32") _test_debug_hit = False @@ -49,6 +47,6 @@ def did_exec(x): _test_debug_hit = True prog = debug(x + x * x, debug_func=did_exec) - result = ex.evaluate(prog, {x: const(2, "int32")}) + result = create_executor().evaluate(prog, {x: const(2, "int32")}) assert _test_debug_hit assert result.numpy() == 6 diff --git a/tests/python/relay/test_external_codegen.py b/tests/python/relay/test_external_codegen.py index 69e556791e5b..c05f39164531 100644 --- a/tests/python/relay/test_external_codegen.py +++ b/tests/python/relay/test_external_codegen.py @@ -322,8 +322,9 @@ def test_extern_dnnl(check_result): i_data = np.random.uniform(0, 1, ishape).astype(dtype) w_data = np.random.uniform(0, 1, w1shape).astype(dtype) - ref_ex = relay.create_executor("graph", mod=ref_mod, device=tvm.cpu()) - ref_res = ref_ex.evaluate()(i_data, w_data, w_data) + ref_res = relay.create_executor("graph", mod=ref_mod, device=tvm.cpu()).evaluate()( + i_data, w_data, w_data + ) check_result( mod, {"data0": i_data, "weight0": w_data}, (1, 32, 14, 14), ref_res.numpy(), tol=1e-5 ) @@ -363,8 +364,7 @@ def test_extern_dnnl_const(check_result): i_data = np.random.uniform(0, 1, ishape).astype(dtype) - ref_ex = relay.create_executor("graph", mod=ref_mod, device=tvm.cpu()) - ref_res = ref_ex.evaluate()(i_data) + ref_res = relay.create_executor("graph", mod=ref_mod, device=tvm.cpu()).evaluate()(i_data) check_result(mod, {"data0": i_data}, (1, 32, 14, 14), ref_res.numpy(), tol=1e-5) diff --git a/tests/python/relay/test_memory_passes.py b/tests/python/relay/test_memory_passes.py index 7ad72a35a1a0..bed17dbbd830 100644 --- a/tests/python/relay/test_memory_passes.py +++ b/tests/python/relay/test_memory_passes.py @@ -32,13 +32,15 @@ def check_memory_plan(func, check_fn): data = np.random.rand(*sh).astype(param.dtype) args.append(tvm.nd.array(data)) - # Compute without memory planning. + # TODO(mbs): Why does the executor need to be shared? Seems wrong. ex = relay.create_executor("vm", mod) - no_plan_result = ex.evaluate(mod["main"])(*args) + + # Compute without memory planning. + no_plan_result = ex.evaluate()(*args) # Compute with memory planning. with tvm.transform.PassContext(opt_level=1, disabled_pass=["MemoryPlan"]): - plan_result = ex.evaluate(mod["main"])(*args) + plan_result = ex.evaluate()(*args) # Compute Python result. py_res = check_fn(*[arg.numpy() for arg in args]) diff --git a/tests/python/relay/test_op_grad_level1.py b/tests/python/relay/test_op_grad_level1.py index 686c0ea556c3..11099ffe50ee 100644 --- a/tests/python/relay/test_op_grad_level1.py +++ b/tests/python/relay/test_op_grad_level1.py @@ -54,8 +54,9 @@ def check_single_op(opfunc, ref, dtype): bwd_func = run_infer_type(gradient(fwd_func)) for target, dev in tvm.testing.enabled_targets(): - intrp = relay.create_executor(device=dev, target=target) - op_res, (op_grad, _) = intrp.evaluate(bwd_func)(data, grad_in) + op_res, (op_grad, _) = relay.create_executor(device=dev, target=target).evaluate( + bwd_func + )(data, grad_in) np.testing.assert_allclose(op_grad.numpy(), ref_grad, rtol=0.01) for opfunc, ref in [ @@ -105,8 +106,9 @@ def check_binary_op(opfunc, ref, dtype): bwd_func = run_infer_type(gradient(fwd_func)) for target, dev in tvm.testing.enabled_targets(): - intrp = relay.create_executor(device=dev, target=target) - op_res, (op_grad0, op_grad1) = intrp.evaluate(bwd_func)(x_data, y_data) + op_res, (op_grad0, op_grad1) = relay.create_executor( + device=dev, target=target + ).evaluate(bwd_func)(x_data, y_data) np.testing.assert_allclose(op_grad0.numpy(), ref_grad0, rtol=0.01) np.testing.assert_allclose(op_grad1.numpy(), ref_grad1, rtol=0.01) diff --git a/tests/python/relay/test_op_grad_level2.py b/tests/python/relay/test_op_grad_level2.py index c8a94683eec4..115ed48d5888 100644 --- a/tests/python/relay/test_op_grad_level2.py +++ b/tests/python/relay/test_op_grad_level2.py @@ -51,8 +51,9 @@ def verify_max_pool2d_grad(x_shape, pool_size, strides, padding, ceil_mode): ) for target, dev in tvm.testing.enabled_targets(): - intrp = relay.create_executor(device=dev, target=target) - op_res, (op_grad,) = intrp.evaluate(bwd_func)(data) + op_res, (op_grad,) = relay.create_executor(device=dev, target=target).evaluate(bwd_func)( + data + ) np.testing.assert_allclose(op_grad.numpy(), ref_grad, rtol=0.01) @@ -100,8 +101,9 @@ def verify_avg_pool2d_grad( ) for target, dev in tvm.testing.enabled_targets(): - intrp = relay.create_executor(device=dev, target=target) - op_res, (op_grad,) = intrp.evaluate(bwd_func)(data) + op_res, (op_grad,) = relay.create_executor(device=dev, target=target).evaluate( + bwd_func + )(data) np.testing.assert_allclose(op_grad.numpy(), ref_grad, rtol=0.01) @@ -156,8 +158,9 @@ def verify_global_avg_pool2d_grad(x_shape): ) for target, dev in tvm.testing.enabled_targets(): - intrp = relay.create_executor(device=dev, target=target) - op_res, (op_grad,) = intrp.evaluate(bwd_func)(data) + op_res, (op_grad,) = relay.create_executor(device=dev, target=target).evaluate(bwd_func)( + data + ) np.testing.assert_allclose(op_grad.numpy(), ref_grad, rtol=0.01) diff --git a/tests/python/relay/test_op_grad_level3.py b/tests/python/relay/test_op_grad_level3.py index ae3fc2641a25..30d849853d87 100644 --- a/tests/python/relay/test_op_grad_level3.py +++ b/tests/python/relay/test_op_grad_level3.py @@ -41,8 +41,9 @@ def test_clip(): bwd_func = run_infer_type(gradient(fwd_func)) for target, dev in tvm.testing.enabled_targets(): - intrp = relay.create_executor(device=dev, target=target) - op_res, (op_grad,) = intrp.evaluate(bwd_func)(data) + op_res, (op_grad,) = relay.create_executor(device=dev, target=target).evaluate( + bwd_func + )(data) np.testing.assert_allclose(op_grad.numpy(), ref_grad, rtol=0.01) @@ -181,8 +182,9 @@ def test_zeros_ones_grad_dynamic(): bwd_func = run_infer_type(gradient(run_infer_type(fwd_func))) for target, dev in tvm.testing.enabled_targets(): - intrp = relay.create_executor(device=dev, target=target) - res, (grad,) = intrp.evaluate(bwd_func)(dyn_shape) + res, (grad,) = relay.create_executor(device=dev, target=target).evaluate(bwd_func)( + dyn_shape + ) tvm.testing.assert_allclose(res.numpy(), op_ref(dyn_shape, dtype="float32")) tvm.testing.assert_allclose(grad.numpy(), np.zeros((rank,), dtype="int32")) diff --git a/tests/python/relay/test_op_level1.py b/tests/python/relay/test_op_level1.py index cbc3e7fbd1e5..97e10eb25a95 100644 --- a/tests/python/relay/test_op_level1.py +++ b/tests/python/relay/test_op_level1.py @@ -70,8 +70,9 @@ def check_single_op(opfunc, ref, dtype): and not have_fp16(tvm.cuda(0).compute_version) ): continue - intrp = relay.create_executor("graph", device=dev, target=target) - op_res = intrp.evaluate(func)(data) + op_res = relay.create_executor("graph", device=dev, target=target).evaluate(func)( + data + ) np.testing.assert_allclose(op_res.numpy(), ref_res, rtol=0.01) for opfunc, ref in [ @@ -132,8 +133,9 @@ def check_binary_op(opfunc, ref, dtype): and not have_fp16(tvm.cuda(0).compute_version) ): continue - intrp = relay.create_executor("graph", device=dev, target=target) - op_res = intrp.evaluate(func)(x_data, y_data) + op_res = relay.create_executor("graph", device=dev, target=target).evaluate(func)( + x_data, y_data + ) np.testing.assert_allclose(op_res.numpy(), ref_res, rtol=0.01, atol=1e-3) for opfunc, ref in [ @@ -163,8 +165,7 @@ def verify_expand_dims(dshape, dtype, oshape, axis, num_newaxis): continue data = np.random.uniform(size=dshape).astype(dtype) ref_res = data.reshape(oshape) - intrp = relay.create_executor("graph", device=dev, target=target) - op_res = intrp.evaluate(func)(data) + op_res = relay.create_executor("graph", device=dev, target=target).evaluate(func)(data) np.testing.assert_allclose(op_res.numpy(), ref_res, rtol=0.01) for dtype in ["float16", "float32"]: @@ -196,8 +197,9 @@ def test_bias_add(): and not have_fp16(tvm.cuda(0).compute_version) ): continue - intrp = relay.create_executor("graph", device=dev, target=target) - op_res = intrp.evaluate(func)(x_data, y_data) + op_res = relay.create_executor("graph", device=dev, target=target).evaluate(func)( + x_data, y_data + ) np.testing.assert_allclose(op_res.numpy(), ref_res, rtol=rtol) @@ -240,8 +242,9 @@ def test_softmax(): x_data = np.random.uniform(size=shape).astype(dtype) ref_res = tvm.topi.testing.softmax_python(x_data) for target, dev in tvm.testing.enabled_targets(): - intrp = relay.create_executor("graph", device=dev, target=target) - op_res = intrp.evaluate(func)(x_data) + op_res = relay.create_executor("graph", device=dev, target=target).evaluate(func)( + x_data + ) np.testing.assert_allclose(op_res.numpy(), ref_res, rtol=1e-5) @@ -261,8 +264,9 @@ def test_log_softmax(): x_data = np.random.uniform(size=shape).astype(dtype) ref_res = tvm.topi.testing.log_softmax_python(x_data) for target, dev in tvm.testing.enabled_targets(): - intrp = relay.create_executor("graph", device=dev, target=target) - op_res = intrp.evaluate(func)(x_data) + op_res = relay.create_executor("graph", device=dev, target=target).evaluate(func)( + x_data + ) np.testing.assert_allclose(op_res.numpy(), ref_res, rtol=1e-5) @@ -317,11 +321,13 @@ def test_concatenate(): and not have_fp16(tvm.cuda(0).compute_version) ): continue - intrp1 = relay.create_executor("graph", device=dev, target=target) - intrp2 = relay.create_executor("debug", device=dev, target=target) - op_res1 = intrp1.evaluate(func)(x_data, y_data, t_data) + op_res1 = relay.create_executor("graph", device=dev, target=target).evaluate(func)( + x_data, y_data, t_data + ) tvm.testing.assert_allclose(op_res1.numpy(), ref_res, rtol=0.01) - op_res2 = intrp2.evaluate(func)(x_data, y_data, t_data) + op_res2 = relay.create_executor("debug", device=dev, target=target).evaluate(func)( + x_data, y_data, t_data + ) tvm.testing.assert_allclose(op_res2.numpy(), ref_res, rtol=0.01) @@ -341,8 +347,7 @@ def test_dropout(): func = relay.Function([], y) for target, dev in tvm.testing.enabled_targets(): for backend in ["debug", "graph"]: - intrp = relay.create_executor("debug", device=dev, target=target) - op_res = intrp.evaluate(func)() + op_res = relay.create_executor("debug", device=dev, target=target).evaluate(func)() tvm.testing.assert_allclose(op_res.numpy(), in_np, rtol=0.01) @@ -461,11 +466,13 @@ def test_matmul(): ref_res = np.dot(x_data.transpose(), w_data) for target, dev in tvm.testing.enabled_targets(): - intrp1 = relay.create_executor("graph", device=dev, target=target) - intrp2 = relay.create_executor("debug", device=dev, target=target) - op_res1 = intrp1.evaluate(func)(x_data, w_data) + op_res1 = relay.create_executor("graph", device=dev, target=target).evaluate(func)( + x_data, w_data + ) tvm.testing.assert_allclose(op_res1.numpy(), ref_res, rtol=1e-5) - op_res2 = intrp2.evaluate(func)(x_data, w_data) + op_res2 = relay.create_executor("debug", device=dev, target=target).evaluate(func)( + x_data, w_data + ) tvm.testing.assert_allclose(op_res2.numpy(), ref_res, rtol=1e-5) @@ -521,11 +528,13 @@ def test_dense(): ref_res = np.dot(x_data, w_data.T) for target, dev in tvm.testing.enabled_targets(): - intrp1 = relay.create_executor("graph", device=dev, target=target) - intrp2 = relay.create_executor("debug", device=dev, target=target) - op_res1 = intrp1.evaluate(func)(x_data, w_data) + op_res1 = relay.create_executor("graph", device=dev, target=target).evaluate(func)( + x_data, w_data + ) tvm.testing.assert_allclose(op_res1.numpy(), ref_res, rtol=1e-5) - op_res2 = intrp2.evaluate(func)(x_data, w_data) + op_res2 = relay.create_executor("debug", device=dev, target=target).evaluate(func)( + x_data, w_data + ) tvm.testing.assert_allclose(op_res2.numpy(), ref_res, rtol=1e-5) diff --git a/tests/python/relay/test_op_level10.py b/tests/python/relay/test_op_level10.py index 23012f5afc9b..f796abe5e7d7 100644 --- a/tests/python/relay/test_op_level10.py +++ b/tests/python/relay/test_op_level10.py @@ -40,9 +40,10 @@ def test_checkpoint(): inputs = [np.random.uniform() for _ in range(len(xs))] for target, dev in tvm.testing.enabled_targets(): for kind in ["graph", "debug"]: - intrp = relay.create_executor(kind, device=dev, target=target) - f_res = intrp.evaluate(f)(*inputs) - f_checkpoint_res = intrp.evaluate(f_checkpoint)(*inputs) + f_res = relay.create_executor(kind, device=dev, target=target).evaluate(f)(*inputs) + f_checkpoint_res = relay.create_executor(kind, device=dev, target=target).evaluate( + f_checkpoint + )(*inputs) tvm.testing.assert_allclose(f_res.numpy(), f_checkpoint_res.numpy(), 0, 0) @@ -172,8 +173,7 @@ def test_collapse_sum_like(): ref_res = np.sum(x, 0) for target, dev in tvm.testing.enabled_targets(): for kind in ["graph", "debug"]: - intrp = relay.create_executor(kind, device=dev, target=target) - op_res = intrp.evaluate(func)(x, y) + op_res = relay.create_executor(kind, device=dev, target=target).evaluate(func)(x, y) tvm.testing.assert_allclose(op_res.numpy(), ref_res, rtol=1e-5) @@ -192,8 +192,7 @@ def test_collapse_sum_to(): ref_res = np.sum(x, 0) for target, dev in tvm.testing.enabled_targets(): for kind in ["graph", "debug"]: - intrp = relay.create_executor(kind, device=dev, target=target) - op_res = intrp.evaluate(func)(x) + op_res = relay.create_executor(kind, device=dev, target=target).evaluate(func)(x) tvm.testing.assert_allclose(op_res.numpy(), ref_res, rtol=1e-5) @@ -212,8 +211,7 @@ def test_broadcast_to(): ref_res = np.broadcast_to(x, shape_like) for target, dev in tvm.testing.enabled_targets(): for kind in ["graph", "debug"]: - intrp = relay.create_executor(kind, device=dev, target=target) - op_res = intrp.evaluate(func)(x) + op_res = relay.create_executor(kind, device=dev, target=target).evaluate(func)(x) tvm.testing.assert_allclose(op_res.numpy(), ref_res, rtol=1e-5) @@ -236,8 +234,7 @@ def test_broadcast_to_like(): for target, dev in tvm.testing.enabled_targets(): for kind in ["graph", "debug"]: - intrp = relay.create_executor(kind, device=dev, target=target) - op_res = intrp.evaluate(func)(x, y) + op_res = relay.create_executor(kind, device=dev, target=target).evaluate(func)(x, y) tvm.testing.assert_allclose(op_res.numpy(), ref_res, rtol=1e-5) @@ -281,8 +278,9 @@ def verify_slice_like(data, slice_like, axes, output, dtype="float32"): for target, dev in tvm.testing.enabled_targets(): for kind in ["graph", "debug"]: - intrp = relay.create_executor(kind, device=dev, target=target) - op_res = intrp.evaluate(func)(x_data, y_data) + op_res = relay.create_executor(kind, device=dev, target=target).evaluate(func)( + x_data, y_data + ) tvm.testing.assert_allclose(op_res.numpy(), ref_res, rtol=1e-5) @@ -315,8 +313,9 @@ def verify_reverse_reshape(shape, newshape, oshape): ref_res = np.reshape(x_data, oshape) for target, dev in tvm.testing.enabled_targets(): for kind in ["graph", "debug"]: - intrp = relay.create_executor(kind, device=dev, target=target) - op_res = intrp.evaluate(func)(x_data) + op_res = relay.create_executor(kind, device=dev, target=target).evaluate(func)( + x_data + ) tvm.testing.assert_allclose(op_res.numpy(), ref_res, rtol=1e-5) verify_reverse_reshape((2, 3, 4), (4, 0, 2), (4, 3, 2)) @@ -340,8 +339,7 @@ def verify_batch_matmul(x_shape, y_shape, out_shape, dtype="float32", trans_x=Fa for target, dev in tvm.testing.enabled_targets(): for kind in ["graph", "debug"]: - intrp = relay.create_executor(kind, device=dev, target=target) - z = intrp.evaluate(func)(x_np, y_np) + z = relay.create_executor(kind, device=dev, target=target).evaluate(func)(x_np, y_np) tvm.testing.assert_allclose(z.numpy(), z_np, rtol=1e-5) @@ -374,8 +372,7 @@ def test_shape_of(): # Because using graph executor, this op will be optimized after # constant folding pass, here we only test with interpreter for kind in ["debug"]: - intrp = relay.create_executor(kind, device=dev, target=target) - op_res = intrp.evaluate(func)(x_data) + op_res = relay.create_executor(kind, device=dev, target=target).evaluate(func)(x_data) tvm.testing.assert_allclose(op_res.numpy(), np.array(shape).astype("int32")) @@ -390,8 +387,9 @@ def verify_ndarray_size(shape): ref_res = np.size(x_data) for target, dev in tvm.testing.enabled_targets(): for kind in ["graph", "debug"]: - intrp = relay.create_executor(kind, device=dev, target=target) - op_res = intrp.evaluate(func)(x_data) + op_res = relay.create_executor(kind, device=dev, target=target).evaluate(func)( + x_data + ) tvm.testing.assert_allclose(op_res.numpy(), ref_res) verify_ndarray_size((2, 3, 5)) @@ -408,8 +406,9 @@ def verify_adaptive_pool(dshape, out_size, pool_type, layout, dtype, opfunc): np_out = tvm.topi.testing.adaptive_pool(np_data, out_size, pool_type, layout) for target, dev in tvm.testing.enabled_targets(): - intrp1 = relay.create_executor("graph", device=dev, target=target) - relay_out = intrp1.evaluate(func)(np_data) + relay_out = relay.create_executor("graph", device=dev, target=target).evaluate(func)( + np_data + ) tvm.testing.assert_allclose(relay_out.numpy(), np_out, rtol=1e-5, atol=1e-5) @@ -469,8 +468,9 @@ def _verify(data_shape, mask_value, axis, dtype, itype): for target, dev in tvm.testing.enabled_targets(): for kind in ["graph", "debug"]: - intrp = relay.create_executor(kind, device=dev, target=target) - out_relay = intrp.evaluate(func)(data_np, valid_length_np) + out_relay = relay.create_executor(kind, device=dev, target=target).evaluate(func)( + data_np, valid_length_np + ) tvm.testing.assert_allclose(out_relay.numpy(), gt_out_np) _verify((5, 10), 0.0, 1, "float32", "int32") @@ -509,8 +509,9 @@ def _verify(indices_shape, depth, on_value, off_value, axis, dtype): for target, dev in tvm.testing.enabled_targets(): for kind in ["graph", "debug"]: - intrp = relay.create_executor(kind, device=dev, target=target) - out_relay = intrp.evaluate(func)(indices_np) + out_relay = relay.create_executor(kind, device=dev, target=target).evaluate(func)( + indices_np + ) tvm.testing.assert_allclose(out_relay.numpy(), out_np) _verify((3,), 3, 1, 0, -1, "int32") @@ -539,8 +540,9 @@ def _verify(input_shape, diagonal_shape, dtype, k=0, align="RIGHT_LEFT"): for target, dev in tvm.testing.enabled_targets(): for kind in ["graph", "debug"]: - intrp = relay.create_executor(kind, device=dev, target=target) - out_relay = intrp.evaluate(func)(input_np, diagonal_np) + out_relay = relay.create_executor(kind, device=dev, target=target).evaluate(func)( + input_np, diagonal_np + ) tvm.testing.assert_allclose(out_relay.numpy(), out_np) _verify((2, 2), (2,), "float32") @@ -580,8 +582,9 @@ def _verify(prediction_shape, reduction="mean", ignore_index=-100, dtype="float3 ) for kind in ["graph", "debug"]: - intrp = relay.create_executor(kind, device=dev, target=target) - out_relay = intrp.evaluate(func)(predictions_np, targets_np, weights_np) + out_relay = relay.create_executor(kind, device=dev, target=target).evaluate(func)( + predictions_np, targets_np, weights_np + ) tvm.testing.assert_allclose(out_relay.numpy(), out_np, rtol=1e-6, atol=1e-6) _verify((10, 5)) diff --git a/tests/python/relay/test_op_level2.py b/tests/python/relay/test_op_level2.py index f05c5054415d..87cdc41570d0 100644 --- a/tests/python/relay/test_op_level2.py +++ b/tests/python/relay/test_op_level2.py @@ -98,8 +98,9 @@ def run_test_conv1d( if target in except_targets: continue dev = tvm.device(target, 0) - intrp1 = relay.create_executor("graph", device=dev, target=target) - op_res1 = intrp1.evaluate(func)(data, kernel) + op_res1 = relay.create_executor("graph", device=dev, target=target).evaluate(func)( + data, kernel + ) tvm.testing.assert_allclose(op_res1.numpy(), ref_res, rtol=1e-5, atol=1e-5) # normal conv1d @@ -226,8 +227,9 @@ def run_test_conv2d( if target in except_targets: continue dev = tvm.device(target, 0) - intrp1 = relay.create_executor("graph", device=dev, target=target) - op_res1 = intrp1.evaluate(func)(data, kernel) + op_res1 = relay.create_executor("graph", device=dev, target=target).evaluate(func)( + data, kernel + ) tvm.testing.assert_allclose(op_res1.numpy(), ref_res, rtol=1e-4, atol=1e-4) def compile_test_conv2d_arm_cpu( @@ -513,8 +515,9 @@ def run_test_conv3d( continue dev = tvm.device(target, 0) - intrp1 = relay.create_executor("graph", device=dev, target=target) - op_res1 = intrp1.evaluate(func)(data, kernel) + op_res1 = relay.create_executor("graph", device=dev, target=target).evaluate(func)( + data, kernel + ) tvm.testing.assert_allclose(op_res1.numpy(), ref_res, rtol=1e-5, atol=1e-5) # normal conv3d @@ -578,8 +581,9 @@ def run_test_conv3d( continue dev = tvm.device(target, 0) - intrp1 = relay.create_executor("graph", device=dev, target=target) - op_res1 = intrp1.evaluate(func)(data, kernel) + op_res1 = relay.create_executor("graph", device=dev, target=target).evaluate(func)( + data, kernel + ) tvm.testing.assert_allclose(op_res1.numpy(), ref_res, rtol=1e-5, atol=1e-5) # normal conv3d @@ -761,8 +765,9 @@ def test_conv3d_transpose_ncdhw_run(): ref_res = tvm.topi.testing.conv3d_transpose_ncdhw_python(data, kernel, 1, 1, 0) for target, dev in tvm.testing.enabled_targets(): - intrp1 = relay.create_executor("graph", device=dev, target=target) - op_res1 = intrp1.evaluate(func)(data, kernel) + op_res1 = relay.create_executor("graph", device=dev, target=target).evaluate(func)( + data, kernel + ) tvm.testing.assert_allclose(op_res1.numpy(), ref_res, rtol=1e-5, atol=1e-5) @@ -804,8 +809,9 @@ def test_conv2d_transpose_nchw_run(): ref_res = tvm.topi.testing.conv2d_transpose_nchw_python(data, kernel, 2, 1, (1, 1)) for target, dev in tvm.testing.enabled_targets(): - intrp1 = relay.create_executor("graph", device=dev, target=target) - op_res1 = intrp1.evaluate(func)(data, kernel) + op_res1 = relay.create_executor("graph", device=dev, target=target).evaluate(func)( + data, kernel + ) tvm.testing.assert_allclose(op_res1.numpy(), ref_res, rtol=1e-5, atol=1e-5) @@ -840,8 +846,9 @@ def test_conv2d_transpose_nhwc_run(): ) for target, dev in tvm.testing.enabled_targets(): - intrp1 = relay.create_executor("graph", device=dev, target=target) - op_res1 = intrp1.evaluate(func)(data, kernel) + op_res1 = relay.create_executor("graph", device=dev, target=target).evaluate(func)( + data, kernel + ) tvm.testing.assert_allclose(op_res1.numpy(), ref_res, rtol=1e-5, atol=1e-5) @@ -862,8 +869,9 @@ def test_conv1d_transpose_ncw_run(): ref_res = tvm.topi.testing.conv1d_transpose_ncw_python(data, kernel, 2, 1, output_padding=(1,)) for target, dev in tvm.testing.enabled_targets(): - intrp1 = relay.create_executor("graph", device=dev, target=target) - op_res1 = intrp1.evaluate(func)(data, kernel) + op_res1 = relay.create_executor("graph", device=dev, target=target).evaluate(func)( + data, kernel + ) tvm.testing.assert_allclose(op_res1.numpy(), ref_res, rtol=1e-5, atol=1e-5) @@ -947,8 +955,7 @@ def _test_global_pool2d(opfunc, reffunc): data = np.random.uniform(size=dshape).astype(dtype) ref_res = reffunc(data, axis=(2, 3), keepdims=True) for target, dev in tvm.testing.enabled_targets(): - intrp1 = relay.create_executor("graph", device=dev, target=target) - op_res1 = intrp1.evaluate(func)(data) + op_res1 = relay.create_executor("graph", device=dev, target=target).evaluate(func)(data) tvm.testing.assert_allclose(op_res1.numpy(), ref_res, rtol=1e-5, atol=1e-5) @@ -980,8 +987,7 @@ def _test_pool2d(opfunc, pool_type, pool_size=2, strides=2, dilation=1, padding= ceil_mode=False, ) for target, dev in tvm.testing.enabled_targets(): - intrp1 = relay.create_executor("graph", device=dev, target=target) - op_res1 = intrp1.evaluate(func)(data) + op_res1 = relay.create_executor("graph", device=dev, target=target).evaluate(func)(data) tvm.testing.assert_allclose(op_res1.numpy(), ref_res, rtol=1e-5, atol=1e-5) def _test_pool2d_int(opfunc, reffunc, dtype): @@ -1001,8 +1007,9 @@ def _test_pool2d_int(opfunc, reffunc, dtype): data = np.random.randint(low=-128, high=128, size=dshape) ref_res = reffunc(data.reshape(1, 3, 14, 2, 14, 2), axis=(3, 5)).astype(dtype) for target, dev in tvm.testing.enabled_targets(): - intrp1 = relay.create_executor("graph", device=dev, target=target) - op_res1 = intrp1.evaluate(func)(data) + op_res1 = relay.create_executor("graph", device=dev, target=target).evaluate(func)( + data + ) tvm.testing.assert_allclose(op_res1.numpy(), ref_res, rtol=1e-5, atol=1e-5) _test_pool2d(relay.nn.max_pool2d, "max") @@ -1039,8 +1046,7 @@ def _test_global_pool1d(opfunc, reffunc): data = np.random.uniform(size=dshape).astype(dtype) ref_res = reffunc(data, axis=(2,), keepdims=True) for target, dev in tvm.testing.enabled_targets(): - intrp1 = relay.create_executor("graph", device=dev, target=target) - op_res1 = intrp1.evaluate(func)(data) + op_res1 = relay.create_executor("graph", device=dev, target=target).evaluate(func)(data) tvm.testing.assert_allclose(op_res1.numpy(), ref_res, rtol=1e-5, atol=1e-5) @@ -1075,8 +1081,9 @@ def _test_pool1d( ceil_mode=False, ) for target, dev in tvm.testing.enabled_targets(): - intrp1 = relay.create_executor("graph", device=dev, target=target) - op_res1 = intrp1.evaluate(func)(data) + op_res1 = relay.create_executor("graph", device=dev, target=target).evaluate(func)( + data + ) tvm.testing.assert_allclose(op_res1.numpy(), ref_res, rtol=1e-5, atol=1e-5) _test_pool1d(relay.nn.max_pool1d, "max") @@ -1135,8 +1142,9 @@ def _test_pool3d( ceil_mode=False, ) for target, dev in tvm.testing.enabled_targets(): - intrp1 = relay.create_executor("graph", device=dev, target=target) - op_res1 = intrp1.evaluate(func)(data) + op_res1 = relay.create_executor("graph", device=dev, target=target).evaluate(func)( + data + ) tvm.testing.assert_allclose(op_res1.numpy(), ref_res, rtol=1e-5, atol=1e-5) _test_pool3d(relay.nn.max_pool3d, "max") @@ -1187,8 +1195,7 @@ def test_avg_pool2d_no_count_pad(): data = a_np for target, dev in tvm.testing.enabled_targets(): - intrp1 = relay.create_executor("graph", device=dev, target=target) - op_res1 = intrp1.evaluate(func)(data) + op_res1 = relay.create_executor("graph", device=dev, target=target).evaluate(func)(data) tvm.testing.assert_allclose(op_res1.numpy(), ref_res, rtol=1e-5, atol=1e-5) @@ -1222,11 +1229,9 @@ def test_flatten_infer_type(): ref_res = x_data.flatten().reshape(o_shape) for target, dev in tvm.testing.enabled_targets(): - intrp1 = relay.create_executor("graph", device=dev, target=target) - intrp2 = relay.create_executor("debug", device=dev, target=target) - op_res1 = intrp1.evaluate(func)(x_data) + op_res1 = relay.create_executor("graph", device=dev, target=target).evaluate(func)(x_data) tvm.testing.assert_allclose(op_res1.numpy(), ref_res, rtol=1e-5) - op_res2 = intrp2.evaluate(func)(x_data) + op_res2 = relay.create_executor("debug", device=dev, target=target).evaluate(func)(x_data) tvm.testing.assert_allclose(op_res2.numpy(), ref_res, rtol=1e-5) @@ -1296,8 +1301,9 @@ def _test_run(dtype): data = np.random.uniform(size=dshape).astype(dtype) ref_res = _get_numpy_pad(dshape, data, pad) for target, dev in tvm.testing.enabled_targets(): - intrp1 = relay.create_executor("graph", device=dev, target=target) - op_res1 = intrp1.evaluate(func)(data) + op_res1 = relay.create_executor("graph", device=dev, target=target).evaluate(func)( + data + ) tvm.testing.assert_allclose(op_res1.numpy(), ref_res, rtol=1e-5, atol=1e-5) _test_run("float32") @@ -1320,8 +1326,9 @@ def _test_run(dtype): ref_res = _get_numpy_pad(dshape, data_arr, pad, pad_value=pad_value_arr) for target, dev in tvm.testing.enabled_targets(): - intrp = relay.create_executor(kind="graph", device=dev, target=target) - result = intrp.evaluate(f)(data_arr, pad_value_arr) + result = relay.create_executor(kind="graph", device=dev, target=target).evaluate(f)( + data_arr, pad_value_arr + ) tvm.testing.assert_allclose(result.numpy(), ref_res, rtol=1e-5, atol=1e-5) _test_run("float32") @@ -1353,11 +1360,9 @@ def test_lrn(): ref_res = tvm.topi.testing.lrn_python(x_data, size, axis, bias, alpha, beta) for target, dev in tvm.testing.enabled_targets(): - intrp1 = relay.create_executor("graph", device=dev, target=target) - intrp2 = relay.create_executor("debug", device=dev, target=target) - op_res1 = intrp1.evaluate(func)(x_data) + op_res1 = relay.create_executor("graph", device=dev, target=target).evaluate(func)(x_data) tvm.testing.assert_allclose(op_res1.numpy(), ref_res, rtol=1e-5) - op_res2 = intrp2.evaluate(func)(x_data) + op_res2 = relay.create_executor("debug", device=dev, target=target).evaluate(func)(x_data) tvm.testing.assert_allclose(op_res2.numpy(), ref_res, rtol=1e-5) @@ -1383,11 +1388,9 @@ def test_l2_normalize(): ref_res = tvm.topi.testing.l2_normalize_python(x_data, eps, axis) for target, dev in tvm.testing.enabled_targets(): - intrp1 = relay.create_executor("graph", device=dev, target=target) - intrp2 = relay.create_executor("debug", device=dev, target=target) - op_res1 = intrp1.evaluate(func)(x_data) + op_res1 = relay.create_executor("graph", device=dev, target=target).evaluate(func)(x_data) tvm.testing.assert_allclose(op_res1.numpy(), ref_res, rtol=1e-5) - op_res2 = intrp2.evaluate(func)(x_data) + op_res2 = relay.create_executor("debug", device=dev, target=target).evaluate(func)(x_data) tvm.testing.assert_allclose(op_res2.numpy(), ref_res, rtol=1e-5) @@ -1408,8 +1411,7 @@ def test_batch_flatten(): data = np.random.rand(5, 10, 5).astype(t1.dtype) ref_res = batch_flatten(data) for target, dev in tvm.testing.enabled_targets(): - intrp = relay.create_executor("graph", device=dev, target=target) - op_res = intrp.evaluate(func)(data) + op_res = relay.create_executor("graph", device=dev, target=target).evaluate(func)(data) np.testing.assert_allclose(op_res.numpy(), ref_res, rtol=0.01) @@ -1458,8 +1460,7 @@ def get_shape(): "align_corners" if align_corners else "asymmetric", ) for target, dev in tvm.testing.enabled_targets(): - executor = relay.create_executor("graph", device=dev, target=target) - out = executor.evaluate(func)(data) + out = relay.create_executor("graph", device=dev, target=target).evaluate(func)(data) tvm.testing.assert_allclose(out.numpy(), ref, rtol=1e-5, atol=1e-5) @@ -1530,8 +1531,7 @@ def get_shape(): coordinate_transformation_mode, ) for target, dev in tvm.testing.enabled_targets(): - executor = relay.create_executor("graph", device=dev, target=target) - out = executor.evaluate(func)(data) + out = relay.create_executor("graph", device=dev, target=target).evaluate(func)(data) tvm.testing.assert_allclose(out.numpy(), ref, rtol=1e-5, atol=1e-5) @@ -1602,7 +1602,7 @@ def _has_fast_int8_instructions(asm, target): targets = ["llvm -mcpu=skylake-avx512", "llvm -mcpu=cascadelake"] llvm_version = tvm.target.codegen.llvm_version_major() for target in targets: - if llvm_version >= 8: + if tvm.testing.device_enabled(target) and llvm_version >= 8: dtypes = ("uint8", "int8", "int32") # Sweep the input channels to check int8 robustness # Input channels should be a multiple of 4 internally. @@ -1654,7 +1654,7 @@ def _has_fast_int8_instructions(asm, target): # Check that int8 x int8 goes through legalization so that fast instructions can be picked up. for target in targets: - if llvm_version >= 8: + if tvm.testing.device_enabled(target) and llvm_version >= 8: dtypes = ("int8", "int8", "int32") # Check that both non-divisible oc and ic work asm = _compile( @@ -1676,17 +1676,18 @@ def _has_fast_int8_instructions(asm, target): # Check that a vectorized instruction is generated for older Intel # generations, because we default to NCHWc layout. target = "llvm -mcpu=core-avx2" - fast_int8_dtypes = ("uint8", "int8", "int32") - asm = _compile( - ic=16, - oc=32, - target=target, - data_layout="NCHW", - kernel_layout="OIHW", - dtypes=fast_int8_dtypes, - ) - # Check that vector int mult and add instructions are generated. - assert "vpmulld" in asm and "vpadd" in asm + if tvm.testing.device_enabled(target): + fast_int8_dtypes = ("uint8", "int8", "int32") + asm = _compile( + ic=16, + oc=32, + target=target, + data_layout="NCHW", + kernel_layout="OIHW", + dtypes=fast_int8_dtypes, + ) + # Check that vector int mult and add instructions are generated. + assert "vpmulld" in asm and "vpadd" in asm @tvm.testing.uses_gpu @@ -1797,8 +1798,9 @@ def _test_correlation( ) for target, dev in tvm.testing.enabled_targets(): - intrp1 = relay.create_executor("graph", device=dev, target=target) - op_res1 = intrp1.evaluate(func)(data1_np, data2_np) + op_res1 = relay.create_executor("graph", device=dev, target=target).evaluate(func)( + data1_np, data2_np + ) tvm.testing.assert_allclose(op_res1.numpy(), ref_res, rtol=1e-5, atol=1e-5) _test_correlation( diff --git a/tests/python/relay/test_op_level3.py b/tests/python/relay/test_op_level3.py index 1ec4e39083c5..0958f00fa5d6 100644 --- a/tests/python/relay/test_op_level3.py +++ b/tests/python/relay/test_op_level3.py @@ -34,8 +34,7 @@ def test_zeros_ones(): y = op(shape=(124, 50), dtype="float64") yy = run_infer_type(y) assert yy.checked_type == relay.TensorType((124, 50), "float64") - intrp = create_executor() - intrp_res = intrp.evaluate(y).numpy() + intrp_res = create_executor().evaluate(y).numpy() np.testing.assert_allclose(intrp_res, ref((124, 50), "float64")) @@ -60,8 +59,7 @@ def test_unary_identity(): if ref is not None: data = np.random.rand(*shape).astype("float32") - intrp = create_executor() - op_res = intrp.evaluate(y, {x: relay.const(data)}) + op_res = create_executor().evaluate(y, {x: relay.const(data)}) ref_res = ref(data) np.testing.assert_allclose(op_res.numpy(), ref_res, rtol=0.01) @@ -87,8 +85,7 @@ def test_clip(): assert yy.checked_type == relay.TensorType((10, 4), "float32") data = np.random.rand(10, 4).astype("float32") - intrp = create_executor() - op_res = intrp.evaluate(y, {a: relay.const(data)}) + op_res = create_executor().evaluate(y, {a: relay.const(data)}) ref_res = np.clip(data, 1.0, 4.0) np.testing.assert_allclose(op_res.numpy(), ref_res, rtol=0.01) @@ -105,8 +102,7 @@ def test_fixed_point_multiply(): assert yy.checked_type == relay.TensorType((10, 4), "int32") data = 23 * np.ones((10, 4)).astype("int32") - intrp = create_executor() - op_res = intrp.evaluate(y, {a: relay.const(data)}) + op_res = create_executor().evaluate(y, {a: relay.const(data)}) ref_res = np.ones((10, 4)).astype("int32") np.testing.assert_allclose(op_res.numpy(), ref_res, atol=1) @@ -118,8 +114,7 @@ def test_reinterpret(): assert yy.checked_type == relay.TensorType((1000, 4), "int32") data = np.random.randn(1000, 4).astype("float32") * 1000 - intrp = create_executor() - op_res = intrp.evaluate(y, {a: relay.const(data)}) + op_res = create_executor().evaluate(y, {a: relay.const(data)}) ref_res = data.view("int32") np.testing.assert_equal(op_res.numpy(), ref_res) @@ -155,8 +150,7 @@ def approximate_tanh(x): yy = run_infer_type(y) assert yy.checked_type == relay.TensorType((1000,), "float32") data = np.linspace(-5, 5, 1000).astype("float32") - intrp = create_executor() - op_res = intrp.evaluate(y, {a: relay.const(data)}) + op_res = create_executor().evaluate(y, {a: relay.const(data)}) def reference_sigmoid(x): return np.exp(-np.logaddexp(0, -x)) @@ -167,8 +161,7 @@ def reference_sigmoid(x): yy = run_infer_type(y) assert yy.checked_type == relay.TensorType((1000,), "float32") data = np.linspace(-5, 5, 1000).astype("float32") - intrp = create_executor() - op_res = intrp.evaluate(y, {a: relay.const(data)}) + op_res = create_executor().evaluate(y, {a: relay.const(data)}) def reference_tanh(x): return np.tanh(x) @@ -184,8 +177,7 @@ def verify_squeeze(shape, dtype, axis): np_axis = tuple(axis) if axis is not None else None data = np.random.random_sample(shape).astype(dtype) - intrp = create_executor() - op_res = intrp.evaluate(squeeze, {x: relay.const(data)}) + op_res = create_executor().evaluate(squeeze, {x: relay.const(data)}) ref_res = np.squeeze(data, axis=np_axis) np.testing.assert_allclose(op_res.numpy(), ref_res, rtol=0.01) @@ -220,8 +212,9 @@ def verify_transpose(dshape, axes): for target, dev in tvm.testing.enabled_targets(): for kind in ["graph", "debug"]: - intrp = relay.create_executor(kind, device=dev, target=target) - op_res = intrp.evaluate(func)(x_data) + op_res = relay.create_executor(kind, device=dev, target=target).evaluate(func)( + x_data + ) tvm.testing.assert_allclose(op_res.numpy(), ref_res, rtol=1e-5) verify_transpose((2, 3, 4), (0, 2, 1)) @@ -275,8 +268,9 @@ def verify_reshape(shape, newshape, oshape): ref_res = np.reshape(x_data, oshape) for target, dev in tvm.testing.enabled_targets(): for kind in ["graph", "debug"]: - intrp = relay.create_executor(kind, device=dev, target=target) - op_res = intrp.evaluate(func)(x_data) + op_res = relay.create_executor(kind, device=dev, target=target).evaluate(func)( + x_data + ) tvm.testing.assert_allclose(op_res.numpy(), ref_res, rtol=1e-5) verify_reshape((2, 3, 4), (8, 3), (8, 3)) @@ -365,8 +359,9 @@ def verify_reshape_like(shape, oshape, shape_like=None, reshape_like_kwargs={}): for target, dev in tvm.testing.enabled_targets(): for kind in ["graph", "debug"]: - intrp = relay.create_executor(kind, device=dev, target=target) - op_res = intrp.evaluate(func)(x_data, y_data) + op_res = relay.create_executor(kind, device=dev, target=target).evaluate(func)( + x_data, y_data + ) tvm.testing.assert_allclose(op_res.numpy(), ref_res, rtol=1e-5) verify_reshape_like((2, 3, 4), (1, 8, 3)) @@ -411,8 +406,9 @@ def verify_take(src_shape, indices_src, axis=None, mode="clip"): for target, dev in tvm.testing.enabled_targets(): for kind in ["graph", "debug"]: - intrp = relay.create_executor(kind, device=dev, target=target) - op_res = intrp.evaluate(func)(x_data, indices_src) + op_res = relay.create_executor(kind, device=dev, target=target).evaluate(func)( + x_data, indices_src + ) tvm.testing.assert_allclose(op_res.numpy(), ref_res, rtol=1e-5) verify_take((4,), [1]) @@ -546,8 +542,9 @@ def verify_full(fill_value, src_shape, dtype): ref_res = np.full(src_shape, fill_value) for target, dev in tvm.testing.enabled_targets(): for kind in ["graph", "debug"]: - intrp = relay.create_executor(kind, device=dev, target=target) - op_res = intrp.evaluate(func)(np.array(fill_value, dtype)) + op_res = relay.create_executor(kind, device=dev, target=target).evaluate(func)( + np.array(fill_value, dtype) + ) tvm.testing.assert_allclose(op_res.numpy(), ref_res, rtol=1e-5) verify_full(4, (1, 3, 4, 4), "int32") @@ -585,8 +582,9 @@ def verify_full_like(base, fill_value, dtype): for target, dev in tvm.testing.enabled_targets(): for kind in ["graph", "debug"]: - intrp = relay.create_executor(kind, device=dev, target=target) - op_res = intrp.evaluate(func)(x_data, np.array(fill_value, dtype)) + op_res = relay.create_executor(kind, device=dev, target=target).evaluate(func)( + x_data, np.array(fill_value, dtype) + ) tvm.testing.assert_allclose(op_res.numpy(), ref_res, rtol=1e-5) verify_full_like((1, 3, 4, 4), 4, "int32") @@ -614,11 +612,9 @@ def test_infer_type_leaky_relu(): ref_res = np.where(x_data > 0, x_data, x_data * 0.1) for target, dev in tvm.testing.enabled_targets(): - intrp1 = relay.create_executor("graph", device=dev, target=target) - intrp2 = relay.create_executor("debug", device=dev, target=target) - op_res1 = intrp1.evaluate(func)(x_data) + op_res1 = relay.create_executor("graph", device=dev, target=target).evaluate(func)(x_data) tvm.testing.assert_allclose(op_res1.numpy(), ref_res, rtol=1e-5) - op_res2 = intrp2.evaluate(func)(x_data) + op_res2 = relay.create_executor("debug", device=dev, target=target).evaluate(func)(x_data) tvm.testing.assert_allclose(op_res2.numpy(), ref_res, rtol=1e-5) @@ -651,11 +647,13 @@ def verify_infer_type_prelu(data, alpha, axis, output, dtype="float32"): ref_res = (x_data < 0) * (x_data * a_data.reshape(1, 1, 3)) + (x_data >= 0) * x_data for target, dev in tvm.testing.enabled_targets(): - intrp1 = relay.create_executor("graph", device=dev, target=target) - intrp2 = relay.create_executor("debug", device=dev, target=target) - op_res1 = intrp1.evaluate(func)(x_data, a_data) + op_res1 = relay.create_executor("graph", device=dev, target=target).evaluate(func)( + x_data, a_data + ) tvm.testing.assert_allclose(op_res1.numpy(), ref_res, rtol=1e-5) - op_res2 = intrp2.evaluate(func)(x_data, a_data) + op_res2 = relay.create_executor("debug", device=dev, target=target).evaluate(func)( + x_data, a_data + ) tvm.testing.assert_allclose(op_res2.numpy(), ref_res, rtol=1e-5) @@ -696,8 +694,7 @@ def verify_arange(start, stop, step): func = relay.Function([], x) for target, dev in tvm.testing.enabled_targets(): for kind in ["graph", "debug"]: - intrp = relay.create_executor(kind, device=dev, target=target) - op_res = intrp.evaluate(func)() + op_res = relay.create_executor(kind, device=dev, target=target).evaluate(func)() tvm.testing.assert_allclose(op_res.numpy(), ref_res, rtol=1e-5) verify_arange(None, 20, None) @@ -735,8 +732,9 @@ def verify_meshgrid(lengths, indexing="ij"): for target, dev in tvm.testing.enabled_targets(): for kind in ["graph", "debug"]: - intrp = relay.create_executor(kind, device=dev, target=target) - op_res = intrp.evaluate(func)(*input_data) + op_res = relay.create_executor(kind, device=dev, target=target).evaluate(func)( + *input_data + ) assert len(op_res) == len(ref_res) for i in range(len(op_res)): tvm.testing.assert_allclose(op_res[i].numpy(), ref_res[i], rtol=1e-5) @@ -793,8 +791,9 @@ def verify_tile(dshape, reps): for target, dev in tvm.testing.enabled_targets(): for kind in ["graph", "debug"]: - intrp = relay.create_executor(kind, device=dev, target=target) - op_res = intrp.evaluate(func)(x_data) + op_res = relay.create_executor(kind, device=dev, target=target).evaluate(func)( + x_data + ) tvm.testing.assert_allclose(op_res.numpy(), ref_res, rtol=1e-5) verify_tile((2, 3, 4), (3, 2, 1)) @@ -811,8 +810,7 @@ def verify_repeat(dshape, repeats, axis): ref_res = np.repeat(data, repeats, axis) for target, dev in tvm.testing.enabled_targets(): for kind in ["graph", "debug"]: - intrp = relay.create_executor(kind, device=dev, target=target) - op_res = intrp.evaluate(func)(data) + op_res = relay.create_executor(kind, device=dev, target=target).evaluate(func)(data) tvm.testing.assert_allclose(op_res.numpy(), ref_res, rtol=1e-5) verify_repeat((3,), 2, 0) @@ -836,8 +834,9 @@ def verify_stack(input_expr, relay_args, ref_res, axis): for target, dev in tvm.testing.enabled_targets(): for kind in ["graph", "debug"]: - intrp = relay.create_executor(kind, device=dev, target=target) - op_res = intrp.evaluate(func)(*relay_args) + op_res = relay.create_executor(kind, device=dev, target=target).evaluate(func)( + *relay_args + ) tvm.testing.assert_allclose(op_res.numpy(), ref_res, rtol=1e-5) def verify_tup_lit_stack(dshapes, axis): @@ -888,8 +887,9 @@ def verify_reverse(dshape, axis): ref_res = np.flip(x_data, axis) for target, dev in tvm.testing.enabled_targets(): for kind in ["graph", "debug"]: - intrp = relay.create_executor(kind, device=dev, target=target) - op_res = intrp.evaluate(func)(x_data) + op_res = relay.create_executor(kind, device=dev, target=target).evaluate(func)( + x_data + ) tvm.testing.assert_allclose(op_res.numpy(), ref_res, rtol=1e-5) verify_reverse((2, 3, 4), 1) @@ -909,8 +909,9 @@ def verify_reverse_sequence(x_data, seq_lengths, batch_axis, seq_axis, ref_res): for target, dev in tvm.testing.enabled_targets(): for kind in ["graph", "debug"]: - intrp = relay.create_executor(kind, device=dev, target=target) - op_res = intrp.evaluate(func)(x_data) + op_res = relay.create_executor(kind, device=dev, target=target).evaluate(func)( + x_data + ) tvm.testing.assert_allclose(op_res.numpy(), ref_res, rtol=1e-5) indata = np.array(np.arange(0, 16)).reshape([4, 4]).astype("int32") @@ -1003,8 +1004,9 @@ def verify_scatter(dshape, ishape, axis=0): for target, dev in tvm.testing.enabled_targets(): for kind in ["graph", "debug"]: - intrp = relay.create_executor(kind, device=dev, target=target) - op_res = intrp.evaluate(func)(data_np, indices_np, updates_np) + op_res = relay.create_executor(kind, device=dev, target=target).evaluate(func)( + data_np, indices_np, updates_np + ) tvm.testing.assert_allclose(op_res.numpy(), ref_res, rtol=1e-5) def verify_dynamic_scatter(dshape, ishape, axis=0): @@ -1024,8 +1026,9 @@ def verify_dynamic_scatter(dshape, ishape, axis=0): for target, dev in tvm.testing.enabled_targets(): for kind in ["vm", "debug"]: mod = tvm.ir.IRModule.from_expr(func) - intrp = relay.create_executor(kind, mod=mod, device=dev, target=target) - op_res = intrp.evaluate()(data_np, indices_np, updates_np) + op_res = relay.create_executor(kind, mod=mod, device=dev, target=target).evaluate()( + data_np, indices_np, updates_np + ) tvm.testing.assert_allclose(op_res.numpy(), ref_res, rtol=1e-5) verify_scatter((10,), (10,), 0) @@ -1277,8 +1280,9 @@ def verify_gather(data, axis, indices, ref_res): for target, dev in tvm.testing.enabled_targets(): for kind in ["graph", "debug"]: - intrp = relay.create_executor(kind, device=dev, target=target) - op_res = intrp.evaluate(func)(data, indices) + op_res = relay.create_executor(kind, device=dev, target=target).evaluate(func)( + data, indices + ) tvm.testing.assert_allclose(op_res.numpy(), ref_res, rtol=1e-5) verify_gather(data, axis, indices, ref_res) @@ -1304,8 +1308,9 @@ def verify_gather_nd(xshape, yshape, y_data, batch_dims=0): for target, dev in tvm.testing.enabled_targets(): for kind in ["graph", "debug"]: - intrp = relay.create_executor(kind, device=dev, target=target) - op_res = intrp.evaluate(func)(x_data, y_data) + op_res = relay.create_executor(kind, device=dev, target=target).evaluate(func)( + x_data, y_data + ) tvm.testing.assert_allclose(op_res.numpy(), ref_res, rtol=1e-5) verify_gather_nd((2, 2), (2, 3), [[1, 1, 0], [0, 1, 0]]) @@ -1352,8 +1357,7 @@ def _verify_infiniteness_ops(relay_op, ref_op): ] = np.infty data.ravel()[np.random.choice(data.size, int(data.size * 0.5), replace=False)] = np.nan - intrp = create_executor() - op_res = intrp.evaluate(y, {x: data}) + op_res = create_executor().evaluate(y, {x: data}) ref_res = ref_op(data) np.testing.assert_allclose(op_res.numpy(), ref_res, rtol=0.01) @@ -1387,8 +1391,9 @@ def verify_unravel_index(indices, shape, dtype): ref_res = np.unravel_index(x_data, y_data) for target, dev in tvm.testing.enabled_targets(): for kind in ["graph", "debug"]: - intrp = relay.create_executor(kind, device=dev, target=target) - op_res = intrp.evaluate(func)(x_data, y_data) + op_res = relay.create_executor(kind, device=dev, target=target).evaluate(func)( + x_data, y_data + ) tvm.testing.assert_allclose(op_res.numpy(), ref_res, rtol=1e-5) for dtype in ["int64", "int32"]: @@ -1433,13 +1438,11 @@ def verify_sparse_to_dense(sparse_indices, sparse_values, default_value, output_ func = relay.Function(args, d) for target, dev in tvm.testing.enabled_targets(): for kind in ["graph", "debug"]: - intrp = relay.create_executor(kind, device=dev, target=target) + f = relay.create_executor(kind, device=dev, target=target).evaluate(func) if default_value is None: - op_res = intrp.evaluate(func)(sparse_indices_data, sparse_values_data) + op_res = f(sparse_indices_data, sparse_values_data) else: - op_res = intrp.evaluate(func)( - sparse_indices_data, sparse_values_data, default_value_data - ) + op_res = f(sparse_indices_data, sparse_values_data, default_value_data) tvm.testing.assert_allclose(op_res.numpy(), xpected, rtol=1e-5) verify_sparse_to_dense(1, 3, 0, [5], [0, 3, 0, 0, 0]) # scalar @@ -1776,8 +1779,9 @@ def verify_func(func, data, ref_res, target_device=tvm.testing.enabled_targets() for target, dev in target_device: for kind in ["vm"]: mod = tvm.ir.IRModule.from_expr(func) - intrp = relay.create_executor(kind, mod=mod, device=dev, target=target) - op_res = intrp.evaluate()(*data) + op_res = relay.create_executor(kind, mod=mod, device=dev, target=target).evaluate()( + *data + ) if isinstance(op_res, tvm.runtime.container.ADT): assert len(op_res) == len( ref_res @@ -1808,8 +1812,9 @@ def verify_adv_index(data_shape, index_shapes): func = relay.Function(inputs, out) for target, dev in tvm.testing.enabled_targets(): for kind in ["graph", "debug"]: - intrp = relay.create_executor(kind, device=dev, target=target) - op_res = intrp.evaluate(func)(*np_args) + op_res = relay.create_executor(kind, device=dev, target=target).evaluate(func)( + *np_args + ) tvm.testing.assert_allclose(op_res.numpy(), np_out, rtol=1e-5) verify_adv_index((10, 5), [(3, 4), (3, 1)]) @@ -1846,8 +1851,7 @@ def assert_relay_scanop( func = relay.Function([inp], out) for kind in ["graph", "debug"]: - intrp = relay.create_executor(kind, device=dev, target=target) - op_res = intrp.evaluate(func)(data_np) + op_res = relay.create_executor(kind, device=dev, target=target).evaluate(func)(data_np) tvm.testing.assert_allclose(op_res.numpy(), np_out, rtol=rtol, atol=atol) data = np.array([2, 3, 0]) @@ -1908,8 +1912,9 @@ def verify_scatter_nd( func = relay.Function([data, indices, updates], out) for kind in ["graph", "debug"]: - intrp = relay.create_executor(kind, device=dev, target=target) - op_res = intrp.evaluate(func)(data_np, indices_np, updates_np) + op_res = relay.create_executor(kind, device=dev, target=target).evaluate(func)( + data_np, indices_np, updates_np + ) tvm.testing.assert_allclose(op_res.numpy(), ref_res, rtol=rtol, atol=atol) def verify_scatter_nd_with_stack( @@ -1934,8 +1939,7 @@ def verify_scatter_nd_with_stack( for a in indices_np: fargs.append(a) for kind in ["graph", "debug"]: - intrp = relay.create_executor(kind, device=dev, target=target) - op_res = intrp.evaluate(func)(*fargs) + op_res = relay.create_executor(kind, device=dev, target=target).evaluate(func)(*fargs) tvm.testing.assert_allclose(op_res.numpy(), ref_res, rtol=rtol, atol=atol) data = np.zeros((2, 2)).astype("int64") @@ -2020,8 +2024,9 @@ def verify_unique(n, dtype, is_dyn=False, is_sorted=False, return_counts=False): for target, dev in tvm.testing.enabled_targets(): for kind in backends: mod = tvm.ir.IRModule.from_expr(func) - intrp = relay.create_executor(kind, mod=mod, device=dev, target=target) - tvm_res = intrp.evaluate()( + tvm_res = relay.create_executor( + kind, mod=mod, device=dev, target=target + ).evaluate()( x_data ) # unique, indices, inverse_indices, num_unique, (counts) np_res = calc_numpy_unique( diff --git a/tests/python/relay/test_op_level4.py b/tests/python/relay/test_op_level4.py index b59325aea2f9..df77c33658de 100644 --- a/tests/python/relay/test_op_level4.py +++ b/tests/python/relay/test_op_level4.py @@ -50,8 +50,9 @@ def check_binary_op(opfunc, ref): func = relay.Function([x, y], z) for target, dev in tvm.testing.enabled_targets(): - intrp = relay.create_executor("graph", device=dev, target=target) - op_res = intrp.evaluate(func)(x_data, y_data) + op_res = relay.create_executor("graph", device=dev, target=target).evaluate(func)( + x_data, y_data + ) tvm.testing.assert_allclose(op_res.numpy(), ref_res) for opfunc, ref in [(relay.power, np.power)]: @@ -88,8 +89,9 @@ def test_cmp_type(): func = relay.Function([x, y], z) for target, dev in tvm.testing.enabled_targets(): - intrp = relay.create_executor("graph", device=dev, target=target) - op_res = intrp.evaluate(func)(x_data, y_data) + op_res = relay.create_executor("graph", device=dev, target=target).evaluate(func)( + x_data, y_data + ) tvm.testing.assert_allclose(op_res.numpy(), ref_res) @@ -113,8 +115,9 @@ def test_binary_int_broadcast_1(): ref_res = ref(x_data, y_data) for target, dev in tvm.testing.enabled_targets(): - intrp = relay.create_executor("graph", device=dev, target=target) - op_res = intrp.evaluate(func)(x_data, y_data) + op_res = relay.create_executor("graph", device=dev, target=target).evaluate(func)( + x_data, y_data + ) tvm.testing.assert_allclose(op_res.numpy(), ref_res) @@ -138,8 +141,9 @@ def test_binary_int_broadcast_2(): ref_res = ref(x_data, y_data) for target, dev in tvm.testing.enabled_targets(): - intrp = relay.create_executor("graph", device=dev, target=target) - op_res = intrp.evaluate(func)(x_data, y_data) + op_res = relay.create_executor("graph", device=dev, target=target).evaluate(func)( + x_data, y_data + ) tvm.testing.assert_allclose(op_res.numpy(), ref_res) @@ -148,8 +152,9 @@ def test_where(): def run(func, inputs, ref_res): for target, dev in tvm.testing.enabled_targets(): for kind in ["graph", "debug"]: - intrp = relay.create_executor(kind, device=dev, target=target) - op_res = intrp.evaluate(func)(*inputs) + op_res = relay.create_executor(kind, device=dev, target=target).evaluate(func)( + *inputs + ) tvm.testing.assert_allclose(op_res.numpy(), ref_res, rtol=1e-5) def verify(x_np, y_np, cond_np): @@ -258,11 +263,9 @@ def verify_reduce(funcs, data, axis, keepdims, exclude, output, dtype="float32") ref_res = ref_func(x_data + 0, axis=axis, keepdims=keepdims) for target, dev in tvm.testing.enabled_targets(): - intrp1 = relay.create_executor("graph", device=dev, target=target) - intrp2 = relay.create_executor("debug", device=dev, target=target) - op_res1 = intrp1.evaluate(func)(x_data) + op_res1 = relay.create_executor("graph", device=dev, target=target).evaluate(func)(x_data) tvm.testing.assert_allclose(op_res1.numpy(), ref_res, rtol=1e-5) - op_res2 = intrp2.evaluate(func)(x_data) + op_res2 = relay.create_executor("debug", device=dev, target=target).evaluate(func)(x_data) tvm.testing.assert_allclose(op_res2.numpy(), ref_res, rtol=1e-5) @@ -352,12 +355,10 @@ def verify_mean_var_std(funcs, shape, axis, keepdims): ref_res = ref_func(x_data, axis=axis, dtype=dtype, keepdims=keepdims) for target, dev in tvm.testing.enabled_targets(): - intrp1 = relay.create_executor("graph", device=dev, target=target) - intrp2 = relay.create_executor("debug", device=dev, target=target) - op_res1 = intrp1.evaluate(func)(x_data) + op_res1 = relay.create_executor("graph", device=dev, target=target).evaluate(func)(x_data) tvm.testing.assert_allclose(op_res1[0].numpy(), ref_mean, rtol=1e-5) tvm.testing.assert_allclose(op_res1[1].numpy(), ref_res, rtol=1e-5) - op_res2 = intrp2.evaluate(func)(x_data) + op_res2 = relay.create_executor("debug", device=dev, target=target).evaluate(func)(x_data) tvm.testing.assert_allclose(op_res2[0].numpy(), ref_mean, rtol=1e-5) tvm.testing.assert_allclose(op_res2[1].numpy(), ref_res, rtol=1e-5) @@ -425,8 +426,9 @@ def verify( if not test_ref: return for target, dev in tvm.testing.enabled_targets(): - intrp = relay.create_executor("graph", device=dev, target=target) - op_res = intrp.evaluate(func)(x_data) + op_res = relay.create_executor("graph", device=dev, target=target).evaluate(func)( + x_data + ) tvm.testing.assert_allclose(op_res.numpy(), ref_res) verify((1, 3, 10, 10), [0, 0, 0, 0], [-1, 3, 10, 10], [1], (0, 3, 10, 10), dtype="int64") @@ -503,8 +505,9 @@ def verify( return for target, dev in tvm.testing.enabled_targets(): mod = tvm.ir.IRModule.from_expr(func) - intrp = relay.create_executor("vm", mod=mod, device=dev, target=target) - op_res = intrp.evaluate()(x_data) + op_res = relay.create_executor("vm", mod=mod, device=dev, target=target).evaluate()( + x_data + ) tvm.testing.assert_allclose(op_res.numpy(), ref_res) verify( @@ -562,8 +565,9 @@ def verify(dshape, begin, end, strides, vshape, test_ref=True): v_data = np.random.uniform(size=vshape).astype("float32") ref_res = tvm.topi.testing.strided_set_python(x_data, v_data, begin, end, strides) for target, dev in tvm.testing.enabled_targets(): - intrp = relay.create_executor("graph", device=dev, target=target) - op_res = intrp.evaluate(func)(x_data, v_data) + op_res = relay.create_executor("graph", device=dev, target=target).evaluate(func)( + x_data, v_data + ) tvm.testing.assert_allclose(op_res.numpy(), ref_res) verify((3, 4, 16), [0, 0, 0], [4, -5, 4], [1, -1, 2], (3, 1, 2)) diff --git a/tests/python/relay/test_op_level5.py b/tests/python/relay/test_op_level5.py index d93de5419f56..c08b538d22e6 100644 --- a/tests/python/relay/test_op_level5.py +++ b/tests/python/relay/test_op_level5.py @@ -62,8 +62,9 @@ def verify_resize(dshape, scale, method, layout, coord_trans): func = relay.Function([x], z) for target, dev in tvm.testing.enabled_targets(): for kind in ["graph", "debug"]: - intrp = relay.create_executor(kind, device=dev, target=target) - op_res = intrp.evaluate(func)(x_data) + op_res = relay.create_executor(kind, device=dev, target=target).evaluate(func)( + x_data + ) tvm.testing.assert_allclose(op_res.numpy(), ref_res, rtol=1e-3, atol=1e-4) for method in ["nearest_neighbor", "linear", "cubic"]: @@ -113,8 +114,9 @@ def verify_resize(dshape, scale, method, layout, coord_trans): func = relay.Function([x], z) for target, dev in tvm.testing.enabled_targets(): for kind in ["graph", "debug"]: - intrp = relay.create_executor(kind, device=dev, target=target) - op_res = intrp.evaluate(func)(x_data) + op_res = relay.create_executor(kind, device=dev, target=target).evaluate(func)( + x_data + ) tvm.testing.assert_allclose(op_res.numpy(), ref_res, rtol=1e-3, atol=1e-4) for method in ["nearest_neighbor", "linear", "cubic"]: @@ -167,8 +169,7 @@ def verify_resize(dshape, scale, method, layout): func = relay.Function([x], z) for kind in ["graph", "debug"]: - intrp = relay.create_executor(kind, device=dev, target=target) - op_res = intrp.evaluate(func)(x_data) + op_res = relay.create_executor(kind, device=dev, target=target).evaluate(func)(x_data) tvm.testing.assert_allclose(op_res.numpy(), ref_res, rtol=1e-4, atol=1e-6) for method in ["nearest_neighbor", "linear", "cubic"]: @@ -202,8 +203,9 @@ def verify_crop_and_resize( for target, dev in tvm.testing.enabled_targets(): for kind in ["graph", "debug"]: - intrp = relay.create_executor(kind, device=dev, target=target) - op_res = intrp.evaluate(func)(image_data, boxes, box_indices) + op_res = relay.create_executor(kind, device=dev, target=target).evaluate(func)( + image_data, boxes, box_indices + ) tvm.testing.assert_allclose(op_res.numpy(), ref_res, rtol=1e-3, atol=1e-04) boxes_nhwc = np.array([[0.1, 0.2, 0.8, 0.7], [0.2, 0, 1, 0.6]]).astype("float32") @@ -302,11 +304,9 @@ def verify_multibox_prior( func = relay.Function([x], z) func = run_infer_type(func) for target, dev in tvm.testing.enabled_targets(): - intrp1 = relay.create_executor("graph", device=dev, target=target) - op_res1 = intrp1.evaluate(func)(data) + op_res1 = relay.create_executor("graph", device=dev, target=target).evaluate(func)(data) tvm.testing.assert_allclose(op_res1.numpy(), ref_res, rtol=1e-5) - intrp2 = relay.create_executor("debug", device=dev, target=target) - op_res2 = intrp2.evaluate(func)(data) + op_res2 = relay.create_executor("debug", device=dev, target=target).evaluate(func)(data) tvm.testing.assert_allclose(op_res2.numpy(), ref_res, rtol=1e-5) sizes = (0.3, 1.5, 0.7) @@ -361,8 +361,7 @@ def verify_get_valid_counts(dshape, score_threshold, id_index, score_index): func = relay.Function([x], z.astuple()) func = run_infer_type(func) for target, dev in tvm.testing.enabled_targets(): - intrp = relay.create_executor("debug", device=dev, target=target) - out = intrp.evaluate(func)(np_data) + out = relay.create_executor("debug", device=dev, target=target).evaluate(func)(np_data) tvm.testing.assert_allclose(out[0].numpy(), np_out1, rtol=1e-3, atol=1e-04) tvm.testing.assert_allclose(out[1].numpy(), np_out2, rtol=1e-3, atol=1e-04) @@ -433,15 +432,21 @@ def verify_nms( func_indices = relay.Function([x0, x1, x2, x3], z_indices) func_indices = run_infer_type(func_indices) for target, dev in tvm.testing.enabled_targets(): - intrp1 = relay.create_executor("graph", device=dev, target=target) - op_res1 = intrp1.evaluate(func)(x0_data, x1_data, x2_data, x3_data) + op_res1 = relay.create_executor("graph", device=dev, target=target).evaluate(func)( + x0_data, x1_data, x2_data, x3_data + ) tvm.testing.assert_allclose(op_res1.numpy(), ref_res, rtol=1e-5) - intrp2 = relay.create_executor("debug", device=dev, target=target) - op_res2 = intrp2.evaluate(func)(x0_data, x1_data, x2_data, x3_data) + op_res2 = relay.create_executor("debug", device=dev, target=target).evaluate(func)( + x0_data, x1_data, x2_data, x3_data + ) tvm.testing.assert_allclose(op_res2.numpy(), ref_res, rtol=1e-5) - op_indices_res1 = intrp1.evaluate(func_indices)(x0_data, x1_data, x2_data, x3_data) + op_indices_res1 = relay.create_executor("graph", device=dev, target=target).evaluate( + func_indices + )(x0_data, x1_data, x2_data, x3_data) tvm.testing.assert_allclose(op_indices_res1[0].numpy(), ref_indices_res, rtol=1e-5) - op_indices_res2 = intrp2.evaluate(func_indices)(x0_data, x1_data, x2_data, x3_data) + op_indices_res2 = relay.create_executor("debug", device=dev, target=target).evaluate( + func_indices + )(x0_data, x1_data, x2_data, x3_data) tvm.testing.assert_allclose(op_indices_res2[0].numpy(), ref_indices_res, rtol=1e-5) np_data = np.array( @@ -624,11 +629,13 @@ def test_default_value(): func = relay.Function([cls_prob, loc_pred, anchors], nms) func = run_infer_type(func) for target, dev in tvm.testing.enabled_targets(): - intrp1 = relay.create_executor("graph", device=dev, target=target) - op_res1 = intrp1.evaluate(func)(np_cls_prob, np_loc_preds, np_anchors) + op_res1 = relay.create_executor("graph", device=dev, target=target).evaluate(func)( + np_cls_prob, np_loc_preds, np_anchors + ) tvm.testing.assert_allclose(op_res1.numpy(), expected_np_out, rtol=1e-5) - intrp2 = relay.create_executor("debug", device=dev, target=target) - op_res2 = intrp2.evaluate(func)(np_cls_prob, np_loc_preds, np_anchors) + op_res2 = relay.create_executor("debug", device=dev, target=target).evaluate(func)( + np_cls_prob, np_loc_preds, np_anchors + ) tvm.testing.assert_allclose(op_res2.numpy(), expected_np_out, rtol=1e-5) def test_threshold(): @@ -718,11 +725,13 @@ def verify_roi_align( ) for target, dev in tvm.testing.enabled_targets(): print("test on", target) - intrp1 = relay.create_executor("graph", device=dev, target=target) - op_res1 = intrp1.evaluate(func)(np_data, np_rois) + op_res1 = relay.create_executor("graph", device=dev, target=target).evaluate(func)( + np_data, np_rois + ) tvm.testing.assert_allclose(op_res1.numpy(), ref_res, rtol=1e-4) - intrp2 = relay.create_executor("debug", device=dev, target=target) - op_res2 = intrp2.evaluate(func)(np_data, np_rois) + op_res2 = relay.create_executor("debug", device=dev, target=target).evaluate(func)( + np_data, np_rois + ) tvm.testing.assert_allclose(op_res2.numpy(), ref_res, rtol=1e-4) def verify_roi_align_nchw( @@ -813,11 +822,13 @@ def verify_roi_pool(data_shape, rois_shape, pooled_size, spatial_scale): np_data, np_rois, pooled_size=pooled_size, spatial_scale=spatial_scale ) for target, dev in tvm.testing.enabled_targets(): - intrp1 = relay.create_executor("graph", device=dev, target=target) - op_res1 = intrp1.evaluate(func)(np_data, np_rois) + op_res1 = relay.create_executor("graph", device=dev, target=target).evaluate(func)( + np_data, np_rois + ) tvm.testing.assert_allclose(op_res1.numpy(), ref_res, rtol=1e-4) - intrp2 = relay.create_executor("debug", device=dev, target=target) - op_res2 = intrp2.evaluate(func)(np_data, np_rois) + op_res2 = relay.create_executor("debug", device=dev, target=target).evaluate(func)( + np_data, np_rois + ) tvm.testing.assert_allclose(op_res2.numpy(), ref_res, rtol=1e-4) verify_roi_pool((1, 4, 16, 16), (32, 5), pooled_size=7, spatial_scale=1.0) @@ -841,11 +852,13 @@ def verify_proposal(np_cls_prob, np_bbox_pred, np_im_info, np_out, attrs): print("Skip test because %s is not enabled." % target) continue dev = tvm.device(target, 0) - intrp1 = relay.create_executor("graph", device=dev, target=target) - op_res1 = intrp1.evaluate(func)(np_cls_prob, np_bbox_pred, np_im_info) + op_res1 = relay.create_executor("graph", device=dev, target=target).evaluate(func)( + np_cls_prob, np_bbox_pred, np_im_info + ) tvm.testing.assert_allclose(op_res1.numpy(), np_out, rtol=1e-4) - intrp2 = relay.create_executor("debug", device=dev, target=target) - op_res2 = intrp2.evaluate(func)(np_cls_prob, np_bbox_pred, np_im_info) + op_res2 = relay.create_executor("debug", device=dev, target=target).evaluate(func)( + np_cls_prob, np_bbox_pred, np_im_info + ) tvm.testing.assert_allclose(op_res2.numpy(), np_out, rtol=1e-4) attrs = { @@ -935,8 +948,9 @@ def verify_yolo_reorg(shape, stride): for target, dev in tvm.testing.enabled_targets(): for kind in ["graph", "debug"]: - intrp = relay.create_executor(kind, device=dev, target=target) - op_res = intrp.evaluate(func)(x_data) + op_res = relay.create_executor(kind, device=dev, target=target).evaluate(func)( + x_data + ) tvm.testing.assert_allclose(op_res.numpy(), ref_res, rtol=1e-5) verify_yolo_reorg((1, 100, 20, 20), 10) @@ -1070,8 +1084,9 @@ def test_run(batch, in_channel, size, out_channel, deformable_groups, groups, la if target == "cuda" and layout == "NHWC": continue # Cannot run NHWC layout on cuda target, only on llvm for kind in ["graph", "debug"]: - intrp1 = relay.create_executor(kind, device=dev, target=target) - op_res1 = intrp1.evaluate(func)(data, offset, kernel) + op_res1 = relay.create_executor(kind, device=dev, target=target).evaluate(func)( + data, offset, kernel + ) tvm.testing.assert_allclose(op_res1.numpy(), ref_res, rtol=1e-5, atol=1e-5) test_run(1, 4, 16, 4, 1, 1, "NCHW") @@ -1115,8 +1130,9 @@ def verify_depth_to_space(dshape, block_size, layout, mode): for target, dev in tvm.testing.enabled_targets(): for kind in ["graph", "debug"]: - intrp = relay.create_executor(kind, device=dev, target=target) - op_res = intrp.evaluate(func)(x_data) + op_res = relay.create_executor(kind, device=dev, target=target).evaluate(func)( + x_data + ) tvm.testing.assert_allclose(op_res.numpy(), ref_res, rtol=1e-4) for layout in ["NHWC", "NCHW"]: @@ -1159,8 +1175,9 @@ def verify_space_to_depth(dshape, block_size, layout): for target, dev in tvm.testing.enabled_targets(): for kind in ["graph", "debug"]: - intrp = relay.create_executor(kind, device=dev, target=target) - op_res = intrp.evaluate(func)(x_data) + op_res = relay.create_executor(kind, device=dev, target=target).evaluate(func)( + x_data + ) tvm.testing.assert_allclose(op_res.numpy(), ref_res, rtol=1e-4) for layout in ["NHWC", "NCHW"]: @@ -1215,8 +1232,9 @@ def run_test_dilation2d( for target, dev in tvm.testing.enabled_targets(): if target in except_targets: continue - intrp = relay.create_executor("graph", device=dev, target=target) - op_res = intrp.evaluate(func)(indata, kernel) + op_res = relay.create_executor("graph", device=dev, target=target).evaluate(func)( + indata, kernel + ) tvm.testing.assert_allclose(op_res.numpy(), out, rtol=1e-5, atol=1e-5) def _convert_data(indata, kernel, out, layout=None): @@ -1317,8 +1335,9 @@ def verify_affine_grid(num_batch, target_shape): for target, dev in tvm.testing.enabled_targets(): for kind in ["graph", "debug"]: - intrp1 = relay.create_executor(kind, device=dev, target=target) - op_res1 = intrp1.evaluate(func)(data_np) + op_res1 = relay.create_executor(kind, device=dev, target=target).evaluate(func)( + data_np + ) tvm.testing.assert_allclose(op_res1.numpy(), ref_res, rtol=1e-5, atol=1e-5) verify_affine_grid(1, (16, 32)) @@ -1344,8 +1363,9 @@ def verify_grid_sample(data_shape, grid_shape): for target, dev in tvm.testing.enabled_targets(): for kind in ["graph", "debug"]: - intrp1 = relay.create_executor(kind, device=dev, target=target) - op_res1 = intrp1.evaluate(func)(data_np, grid_np) + op_res1 = relay.create_executor(kind, device=dev, target=target).evaluate(func)( + data_np, grid_np + ) tvm.testing.assert_allclose(op_res1.numpy(), ref_res, rtol=1e-5, atol=1e-5) verify_grid_sample((4, 4, 16, 32), (4, 2, 8, 8)) @@ -1371,8 +1391,9 @@ def verify_space_to_batch_nd(dshape, block_shape, paddings): for target, dev in tvm.testing.enabled_targets(): for kind in ["graph", "debug"]: - intrp = relay.create_executor(kind, device=dev, target=target) - op_res = intrp.evaluate(func)(x_data) + op_res = relay.create_executor(kind, device=dev, target=target).evaluate(func)( + x_data + ) tvm.testing.assert_allclose(op_res.numpy(), ref_res, rtol=1e-4) verify_space_to_batch_nd([3, 3, 2, 1], [3], [[0, 0]]) @@ -1398,8 +1419,9 @@ def verify_batch_to_space_nd(dshape, block_shape, crops): for target, dev in tvm.testing.enabled_targets(): for kind in ["graph", "debug"]: - intrp = relay.create_executor(kind, device=dev, target=target) - op_res = intrp.evaluate(func)(x_data) + op_res = relay.create_executor(kind, device=dev, target=target).evaluate(func)( + x_data + ) tvm.testing.assert_allclose(op_res.numpy(), ref_res, rtol=1e-4) verify_batch_to_space_nd([4, 1, 1, 3], [2, 2], [[0, 0], [0, 0]]) @@ -1432,8 +1454,9 @@ def verify_all_class_non_max_suppression( for target, dev in tvm.testing.enabled_targets(): for kind in ["graph", "debug"]: - intrp = relay.create_executor(kind, device=dev, target=target) - selected_indices, num_detections = intrp.evaluate(func)(boxes_np, scores_np) + selected_indices, num_detections = relay.create_executor( + kind, device=dev, target=target + ).evaluate(func)(boxes_np, scores_np) tvm_res = selected_indices.numpy()[: num_detections.numpy()[0]] np.testing.assert_equal(tvm_res, expected_indices) diff --git a/tests/python/relay/test_op_level6.py b/tests/python/relay/test_op_level6.py index f4a4dd4e6134..ea640c62dfeb 100644 --- a/tests/python/relay/test_op_level6.py +++ b/tests/python/relay/test_op_level6.py @@ -45,8 +45,9 @@ def verify_sort(shape, axis, is_ascend, is_dyn=False, in_dtype="float32"): for target, dev in tvm.testing.enabled_targets(): for kind in backends: mod = tvm.ir.IRModule.from_expr(func) - intrp = relay.create_executor(kind, mod=mod, device=dev, target=target) - op_res = intrp.evaluate()(x_data) + op_res = relay.create_executor(kind, mod=mod, device=dev, target=target).evaluate()( + x_data + ) tvm.testing.assert_allclose(op_res.numpy(), ref_res, rtol=1e-5) for is_dyn in [False, True]: @@ -80,8 +81,9 @@ def verify_argsort(shape, axis, is_ascend, dtype, is_dyn=False, in_dtype="float3 for target, dev in tvm.testing.enabled_targets(): for kind in backends: mod = tvm.ir.IRModule.from_expr(func) - intrp = relay.create_executor(kind, mod=mod, device=dev, target=target) - op_res = intrp.evaluate()(x_data) + op_res = relay.create_executor(kind, mod=mod, device=dev, target=target).evaluate()( + x_data + ) tvm.testing.assert_allclose(op_res.numpy(), ref_res.astype(dtype), rtol=1e-5) for is_dyn in [False, True]: @@ -127,8 +129,9 @@ def verify_topk(k, axis, ret_type, is_ascend, dtype, in_dtype="float32"): for target, dev in tvm.testing.enabled_targets(): for kind in ["graph", "debug"]: - intrp = relay.create_executor(kind, device=dev, target=target) - op_res = intrp.evaluate(func)(np_data) + op_res = relay.create_executor(kind, device=dev, target=target).evaluate(func)( + np_data + ) if ret_type == "both": tvm.testing.assert_allclose(op_res[0].numpy(), np_values) tvm.testing.assert_allclose(op_res[1].numpy(), np_indices) diff --git a/tests/python/relay/test_op_qnn_add.py b/tests/python/relay/test_op_qnn_add.py index d3a3b8ffca5f..b38ada718cc5 100644 --- a/tests/python/relay/test_op_qnn_add.py +++ b/tests/python/relay/test_op_qnn_add.py @@ -63,8 +63,9 @@ def test_tflite_same_io_qnn_params(): y_data = y_datas[i] golden_output = golden_outputs[i] - intrp = relay.create_executor("graph", device=tvm.cpu(0), target="llvm") - op_res = intrp.evaluate(func)(x_data, y_data) + op_res = relay.create_executor("graph", device=tvm.cpu(0), target="llvm").evaluate(func)( + x_data, y_data + ) np.testing.assert_equal(op_res.numpy(), golden_output) @@ -111,8 +112,9 @@ def test_tflite_different_io_qnn_params(): y_data = y_datas[i] golden_output = golden_outputs[i] - intrp = relay.create_executor("graph", device=tvm.cpu(0), target="llvm") - op_res = intrp.evaluate(func)(x_data, y_data) + op_res = relay.create_executor("graph", device=tvm.cpu(0), target="llvm").evaluate(func)( + x_data, y_data + ) np.testing.assert_equal(op_res.numpy(), golden_output) @@ -143,8 +145,9 @@ def test_saturation(): y_data = np.array((255, 255, 128, 0)).reshape((1, 4)) golden_output = np.array((255, 255, 129, 0)).reshape((1, 4)) - intrp = relay.create_executor("graph", device=tvm.cpu(0), target="llvm") - op_res = intrp.evaluate(func)(x_data, y_data) + op_res = relay.create_executor("graph", device=tvm.cpu(0), target="llvm").evaluate(func)( + x_data, y_data + ) np.testing.assert_equal(op_res.numpy(), golden_output) # Same params, different scale @@ -169,8 +172,9 @@ def test_saturation(): y_data = np.array((255, 255, 127, 0)).reshape((1, 4)) golden_output = np.array((255, 129, 65, 0)).reshape((1, 4)) - intrp = relay.create_executor("graph", device=tvm.cpu(0), target="llvm") - op_res = intrp.evaluate(func)(x_data, y_data) + op_res = relay.create_executor("graph", device=tvm.cpu(0), target="llvm").evaluate(func)( + x_data, y_data + ) np.testing.assert_equal(op_res.numpy(), golden_output) # Same io params, different output scale @@ -195,8 +199,9 @@ def test_saturation(): y_data = np.array((255, 255, 127, 0)).reshape((1, 4)) golden_output = np.array((255, 129, 65, 0)).reshape((1, 4)) - intrp = relay.create_executor("graph", device=tvm.cpu(0), target="llvm") - op_res = intrp.evaluate(func)(x_data, y_data) + op_res = relay.create_executor("graph", device=tvm.cpu(0), target="llvm").evaluate(func)( + x_data, y_data + ) np.testing.assert_equal(op_res.numpy(), golden_output) # All params different @@ -221,8 +226,9 @@ def test_saturation(): y_data = np.array((0, 128, 64, 0)).reshape((1, 4)) golden_output = np.array((255, 255, 132, 0)).reshape((1, 4)) - intrp = relay.create_executor("graph", device=tvm.cpu(0), target="llvm") - op_res = intrp.evaluate(func)(x_data, y_data) + op_res = relay.create_executor("graph", device=tvm.cpu(0), target="llvm").evaluate(func)( + x_data, y_data + ) np.testing.assert_equal(op_res.numpy(), golden_output) diff --git a/tests/python/relay/test_op_qnn_concatenate.py b/tests/python/relay/test_op_qnn_concatenate.py index 12571aad0822..c5f7bf1908ce 100644 --- a/tests/python/relay/test_op_qnn_concatenate.py +++ b/tests/python/relay/test_op_qnn_concatenate.py @@ -51,8 +51,9 @@ def test_same_io_qnn_params(): golden_output = np.concatenate((x_data, y_data), axis=axis) - intrp = relay.create_executor("graph", device=tvm.cpu(0), target="llvm") - op_res = intrp.evaluate(func)(x_data, y_data) + op_res = relay.create_executor("graph", device=tvm.cpu(0), target="llvm").evaluate(func)( + x_data, y_data + ) np.testing.assert_equal(op_res.numpy(), golden_output) @@ -86,8 +87,9 @@ def test_different_io_qnn_params(): golden_output = np.concatenate((x_data - 2, y_data - 3), axis=axis) - intrp = relay.create_executor("graph", device=tvm.cpu(0), target="llvm") - op_res = intrp.evaluate(func)(x_data, y_data) + op_res = relay.create_executor("graph", device=tvm.cpu(0), target="llvm").evaluate(func)( + x_data, y_data + ) np.testing.assert_equal(op_res.numpy(), golden_output) @@ -121,8 +123,9 @@ def test_few_same_io_qnn_params(): golden_output = np.concatenate((x_data + 1, y_data), axis=axis) - intrp = relay.create_executor("graph", device=tvm.cpu(0), target="llvm") - op_res = intrp.evaluate(func)(x_data, y_data) + op_res = relay.create_executor("graph", device=tvm.cpu(0), target="llvm").evaluate(func)( + x_data, y_data + ) np.testing.assert_equal(op_res.numpy(), golden_output) @@ -156,8 +159,9 @@ def test_same_i_qnn_params(): golden_output = np.concatenate((x_data + 1, y_data + 1), axis=axis) - intrp = relay.create_executor("graph", device=tvm.cpu(0), target="llvm") - op_res = intrp.evaluate(func)(x_data, y_data) + op_res = relay.create_executor("graph", device=tvm.cpu(0), target="llvm").evaluate(func)( + x_data, y_data + ) np.testing.assert_equal(op_res.numpy(), golden_output) @@ -183,8 +187,7 @@ def test_call_input(): ) func = relay.Function([x], z) - intrp = relay.create_executor("graph", device=tvm.cpu(0), target="llvm") - op_res = intrp.evaluate(func)(x_data) + op_res = relay.create_executor("graph", device=tvm.cpu(0), target="llvm").evaluate(func)(x_data) np.testing.assert_equal(op_res.numpy(), x_data) diff --git a/tests/python/relay/test_op_qnn_mul.py b/tests/python/relay/test_op_qnn_mul.py index c4cd3244c8fe..af84f9778638 100644 --- a/tests/python/relay/test_op_qnn_mul.py +++ b/tests/python/relay/test_op_qnn_mul.py @@ -80,8 +80,9 @@ def test_tflite_same_io_qnn_params(): y_rec = recover(y_data, rhs_scale, rhs_zero_point) golden = generate_golden_output(x_rec, y_rec, output_scale, output_zero_point) - intrp = relay.create_executor("graph", device=tvm.cpu(0), target="llvm") - op_res = intrp.evaluate(func)(x_data, y_data) + op_res = relay.create_executor("graph", device=tvm.cpu(0), target="llvm").evaluate(func)( + x_data, y_data + ) np.testing.assert_equal(op_res.numpy(), np.uint8(golden)) @@ -134,8 +135,9 @@ def test_tflite_different_io_qnn_params(): y_rec = recover(y_data, rhs_scale, rhs_zero_point) golden = generate_golden_output(x_rec, y_rec, output_scale, output_zero_point) - intrp = relay.create_executor("graph", device=tvm.cpu(0), target="llvm") - op_res = intrp.evaluate(func)(x_data, y_data) + op_res = relay.create_executor("graph", device=tvm.cpu(0), target="llvm").evaluate(func)( + x_data, y_data + ) np.testing.assert_equal(op_res.numpy(), np.uint8(golden)) @@ -172,8 +174,9 @@ def test_saturation(): golden = generate_golden_output(x_rec, y_rec, output_scale, output_zero_point) - intrp = relay.create_executor("graph", device=tvm.cpu(0), target="llvm") - op_res = intrp.evaluate(func)(x_data, y_data) + op_res = relay.create_executor("graph", device=tvm.cpu(0), target="llvm").evaluate(func)( + x_data, y_data + ) np.testing.assert_equal(op_res.numpy(), np.uint8(golden)) # Same params, different scale @@ -206,8 +209,9 @@ def test_saturation(): golden = generate_golden_output(x_rec, y_rec, output_scale, output_zero_point) - intrp = relay.create_executor("graph", device=tvm.cpu(0), target="llvm") - op_res = intrp.evaluate(func)(x_data, y_data) + op_res = relay.create_executor("graph", device=tvm.cpu(0), target="llvm").evaluate(func)( + x_data, y_data + ) np.testing.assert_equal(op_res.numpy(), np.uint8(golden)) # All params different @@ -241,8 +245,9 @@ def test_saturation(): golden = generate_golden_output(x_rec, y_rec, output_scale, output_zero_point) - intrp = relay.create_executor("graph", device=tvm.cpu(0), target="llvm") - op_res = intrp.evaluate(func)(x_data, y_data) + op_res = relay.create_executor("graph", device=tvm.cpu(0), target="llvm").evaluate(func)( + x_data, y_data + ) np.testing.assert_equal(op_res.numpy(), np.uint8(golden)) diff --git a/tests/python/relay/test_op_qnn_subtract.py b/tests/python/relay/test_op_qnn_subtract.py index 4f9a36757b81..f7117b559401 100644 --- a/tests/python/relay/test_op_qnn_subtract.py +++ b/tests/python/relay/test_op_qnn_subtract.py @@ -52,8 +52,9 @@ def qnn_subtract_driver(x_datas, y_datas, golden_outputs, scale_and_zp, data_dty x_data = x_datas[i] y_data = y_datas[i] golden_output = golden_outputs[i] - intrp = relay.create_executor("graph", device=tvm.cpu(0), target="llvm") - op_res = intrp.evaluate(func)(x_data, y_data) + op_res = relay.create_executor("graph", device=tvm.cpu(0), target="llvm").evaluate(func)( + x_data, y_data + ) np.testing.assert_equal(op_res.numpy(), golden_output) diff --git a/tests/python/relay/test_pass_alter_op_layout.py b/tests/python/relay/test_pass_alter_op_layout.py index 1607b7df4a1e..b5702a1542a9 100644 --- a/tests/python/relay/test_pass_alter_op_layout.py +++ b/tests/python/relay/test_pass_alter_op_layout.py @@ -758,12 +758,16 @@ def expected(): with relay.build_config(opt_level=3): for target, dev in tvm.testing.enabled_targets(): for kind in ["graph", "debug", "vm"]: - ex_before = relay.create_executor(kind, mod=mod_before, device=dev, target=target) - ex_new = relay.create_executor(kind, mod=mod_new, device=dev, target=target) np_data = np.random.uniform(size=(1, 32, 28, 28)).astype("float32") np_weight = np.random.uniform(size=(32, 32, 3, 3)).astype("float32") - result_before = ex_before.evaluate()(np_data, np_weight) - result_new = ex_new.evaluate()(np_data, np_weight) + f_before = relay.create_executor( + kind, mod=mod_before, device=dev, target=target + ).evaluate() + result_before = f_before(np_data, np_weight) + f_new = relay.create_executor( + kind, mod=mod_new, device=dev, target=target + ).evaluate() + result_new = f_new(np_data, np_weight) tvm.testing.assert_allclose( result_before.numpy(), result_new.numpy(), rtol=1e-5, atol=1e-5 ) diff --git a/tests/python/relay/test_pass_annotate_target.py b/tests/python/relay/test_pass_annotate_target.py index 098fb5c64e82..23ef9d11eb77 100644 --- a/tests/python/relay/test_pass_annotate_target.py +++ b/tests/python/relay/test_pass_annotate_target.py @@ -144,8 +144,9 @@ def test_run(): i_data = np.random.uniform(0, 1, ishape).astype(dtype) w1_data = np.random.uniform(0, 1, w1shape).astype(dtype) - ref_ex = relay.create_executor("graph", mod=ref_mod, device=tvm.cpu()) - ref_res = ref_ex.evaluate()(i_data, w1_data) + ref_res = relay.create_executor("graph", mod=ref_mod, device=tvm.cpu()).evaluate()( + i_data, w1_data + ) check_result( mod, {"data": i_data, "weight1": w1_data}, (1, 32, 14, 14), ref_res.numpy(), tol=1e-5 @@ -171,8 +172,9 @@ def test_extern_dnnl_mobilenet(): i_data = np.random.uniform(0, 1, ishape).astype(dtype) ref_mod, params = relay.testing.mobilenet.get_workload(batch_size=1, dtype="float32") - ref_ex = relay.create_executor("graph", mod=ref_mod, device=tvm.cpu(0)) - ref_res = ref_ex.evaluate()(i_data, **params) + ref_res = relay.create_executor("graph", mod=ref_mod, device=tvm.cpu(0)).evaluate()( + i_data, **params + ) check_result(mod, {"data": i_data}, (1, 1000), ref_res.numpy(), tol=1e-5, params=params) diff --git a/tests/python/relay/test_pass_auto_quantize.py b/tests/python/relay/test_pass_auto_quantize.py index 51b9f5f24d1d..030682148a5f 100644 --- a/tests/python/relay/test_pass_auto_quantize.py +++ b/tests/python/relay/test_pass_auto_quantize.py @@ -185,8 +185,9 @@ def verify_partition(mod, params): params = [gen_rand_tvm(param.type_annotation, 0, 1) for param in partitioned_mod["main"].params] def _eval_mod(mod): - vm = relay.create_executor("vm", device=tvm.cpu(0), target="llvm", mod=mod) - return vm.evaluate()(*params) + return relay.create_executor("vm", device=tvm.cpu(0), target="llvm", mod=mod).evaluate()( + *params + ) partitioned_mod_result = _eval_mod(partitioned_mod) unpartitioned_mod_result = _eval_mod(unpartitioned_mod) diff --git a/tests/python/relay/test_pass_defunctionalization.py b/tests/python/relay/test_pass_defunctionalization.py index 57dbb82c2d0d..30f2203be0b5 100644 --- a/tests/python/relay/test_pass_defunctionalization.py +++ b/tests/python/relay/test_pass_defunctionalization.py @@ -124,8 +124,7 @@ def to_adt_list(mod, arr): li = nil() for a in arr: li = cons(relay.const(a), li) - ex = relay.create_executor(mod=mod) - adt = ex.evaluate(li) + adt = relay.create_executor(mod=mod).evaluate(li) mod["main"] = expr return adt @@ -148,11 +147,9 @@ def @main(%l: Tensor[(5, 5), float32]) -> Tensor[(5, 5), float32] { input = np.random.rand(5, 5).astype("float32") - ex = relay.create_executor("debug", mod=mod) - defunc_ex = relay.create_executor("debug", mod=defunc_mod) + out = relay.create_executor("debug", mod=mod).evaluate()(input) - out = ex.evaluate()(input) - defunc_out = defunc_ex.evaluate()(input) + defunc_out = relay.create_executor("debug", mod=defunc_mod).evaluate()(input) np.testing.assert_equal(out.numpy(), defunc_out.numpy()) @@ -182,11 +179,11 @@ def @main(%l: List[float32]) -> List[float32] { input = np.random.rand(10).astype("float32") - ex = relay.create_executor("debug", mod=mod) - defunc_ex = relay.create_executor("debug", mod=defunc_mod) + out = relay.create_executor("debug", mod=mod).evaluate(mod["main"])(to_adt_list(mod, input)) - out = ex.evaluate(mod["main"])(to_adt_list(mod, input)) - defunc_out = defunc_ex.evaluate()(to_adt_list(defunc_mod, input)) + defunc_out = relay.create_executor("debug", mod=defunc_mod).evaluate()( + to_adt_list(defunc_mod, input) + ) np.testing.assert_array_equal(to_list(mod, out), to_list(defunc_mod, defunc_out)) @@ -220,11 +217,11 @@ def @main(%l: List[int32]) -> int32 { input = np.random.randint(1, 100, 10) - ex = relay.create_executor("debug", mod=mod) - defunc_ex = relay.create_executor("debug", mod=defunc_mod) + out = relay.create_executor("debug", mod=mod).evaluate(mod["main"])(to_adt_list(mod, input)) - out = ex.evaluate(mod["main"])(to_adt_list(mod, input)) - defunc_out = defunc_ex.evaluate()(to_adt_list(defunc_mod, input)) + defunc_out = relay.create_executor("debug", mod=defunc_mod).evaluate()( + to_adt_list(defunc_mod, input) + ) tvm.testing.assert_allclose(out.numpy(), defunc_out.numpy()) diff --git a/tests/python/relay/test_pass_dynamic_to_static.py b/tests/python/relay/test_pass_dynamic_to_static.py index 962b7bebb12b..836d49b3441b 100644 --- a/tests/python/relay/test_pass_dynamic_to_static.py +++ b/tests/python/relay/test_pass_dynamic_to_static.py @@ -40,8 +40,9 @@ def verify_func(func, data, ref_res, rtol=1e-5, atol=1e-7): for target, dev in tvm.testing.enabled_targets(): for kind in ["graph", "vm", "debug"]: mod = tvm.ir.IRModule.from_expr(func) - intrp = relay.create_executor(kind, mod=mod, device=dev, target=target) - op_res = intrp.evaluate()(*data) + op_res = relay.create_executor(kind, mod=mod, device=dev, target=target).evaluate()( + *data + ) tvm.testing.assert_allclose(op_res.numpy(), ref_res, rtol=rtol, atol=atol) @@ -181,8 +182,9 @@ def verify_topk(k, axis, ret_type, is_ascend, dtype): continue for kind in ["graph", "vm", "debug"]: mod = tvm.ir.IRModule.from_expr(func2) - intrp = relay.create_executor(kind, mod=mod, device=dev, target=target) - op_res = intrp.evaluate()(np_data) + op_res = relay.create_executor(kind, mod=mod, device=dev, target=target).evaluate()( + np_data + ) if ret_type == "both": tvm.testing.assert_allclose(op_res[0].numpy(), np_values) tvm.testing.assert_allclose(op_res[1].numpy(), np_indices) diff --git a/tests/python/relay/test_pass_fake_quantization_to_integer.py b/tests/python/relay/test_pass_fake_quantization_to_integer.py index ff97b25f7e88..2bc2e4e635f0 100644 --- a/tests/python/relay/test_pass_fake_quantization_to_integer.py +++ b/tests/python/relay/test_pass_fake_quantization_to_integer.py @@ -29,11 +29,17 @@ def compare_fq_to_int(expr, args, allow_rounding_error=False): mod_int = tvm.relay.transform.FakeQuantizationToInteger()(mod) assert not tvm.ir.structural_equal(mod, mod_int) - ex = relay.create_executor("vm", mod=mod, device=tvm.cpu(), target="llvm") - result = ex.evaluate()(*args).numpy() - - ex = relay.create_executor("vm", mod=mod_int, device=tvm.cpu(), target="llvm") - result_int = ex.evaluate()(*args).numpy() + result = ( + relay.create_executor("vm", mod=mod, device=tvm.cpu(), target="llvm") + .evaluate()(*args) + .numpy() + ) + + result_int = ( + relay.create_executor("vm", mod=mod_int, device=tvm.cpu(), target="llvm") + .evaluate()(*args) + .numpy() + ) if allow_rounding_error: assert np.all(np.abs(result - result_int) <= 1) diff --git a/tests/python/relay/test_pass_fold_explicit_padding.py b/tests/python/relay/test_pass_fold_explicit_padding.py index 58ba58aa06d3..effebaaf1e8b 100644 --- a/tests/python/relay/test_pass_fold_explicit_padding.py +++ b/tests/python/relay/test_pass_fold_explicit_padding.py @@ -70,12 +70,14 @@ def validate(ndim, pad_width, pad_value, pad_mode, orig_padding, layout): mod2 = tvm.IRModule.from_expr(zz) with tvm.transform.PassContext(): - ex1 = relay.create_executor("vm", mod=mod1, device=tvm.cpu(), target="llvm") - ex2 = relay.create_executor("vm", mod=mod2, device=tvm.cpu(), target="llvm") + func1 = relay.create_executor( + "vm", mod=mod1, device=tvm.cpu(), target="llvm" + ).evaluate() + func2 = relay.create_executor("vm", mod=mod2, device=tvm.cpu(), target="llvm").evaluate() x_np = np.random.rand(*shape).astype("float32") w_np = np.random.rand(*wshape).astype("float32") - result1 = ex1.evaluate()(x_np, w_np) - result2 = ex2.evaluate()(x_np, w_np) + result1 = func1(x_np, w_np) + result2 = func2(x_np, w_np) tvm.testing.assert_allclose(result1.numpy(), result2.numpy(), rtol=1e-5, atol=1e-5) diff --git a/tests/python/relay/test_pass_fuse_ops.py b/tests/python/relay/test_pass_fuse_ops.py index 931f453f9a6d..855650f810a5 100644 --- a/tests/python/relay/test_pass_fuse_ops.py +++ b/tests/python/relay/test_pass_fuse_ops.py @@ -775,9 +775,9 @@ def test_fuse_dynamic_squeeze_slice_take(): take = relay.op.take(strided_slice, take_val, axis=0) mod = tvm.IRModule.from_expr(take) - ex = relay.create_executor("vm", mod=mod, device=tvm.cpu(), target="llvm") - - result = ex.evaluate()(*input_data) + result = relay.create_executor("vm", mod=mod, device=tvm.cpu(), target="llvm").evaluate()( + *input_data + ) np_result = np.squeeze(input_data[0][:, input_data[1][0], :], axis=0) diff --git a/tests/python/relay/test_pass_gradient.py b/tests/python/relay/test_pass_gradient.py index cd0edf95aba7..126fcf22e823 100644 --- a/tests/python/relay/test_pass_gradient.py +++ b/tests/python/relay/test_pass_gradient.py @@ -45,9 +45,8 @@ def test_fo_id(): func = run_infer_type(func) back_func = run_infer_type(gradient(func, mode="first_order")) assert back_func.checked_type == relay.FuncType([t], relay.TupleType([t, relay.TupleType([t])])) - ex = create_executor() x = rand(dtype, *shape) - forward, (grad,) = ex.evaluate(back_func)(x) + forward, (grad,) = create_executor().evaluate(back_func)(x) tvm.testing.assert_allclose(forward.numpy(), x.numpy()) tvm.testing.assert_allclose(grad.numpy(), np.ones_like(x.numpy())) @@ -61,9 +60,8 @@ def test_id(): func = run_infer_type(func) back_func = run_infer_type(gradient(func)) assert back_func.checked_type == relay.FuncType([t], relay.TupleType([t, relay.TupleType([t])])) - ex = create_executor() x = rand(dtype, *shape) - forward, (grad,) = ex.evaluate(back_func)(x) + forward, (grad,) = create_executor().evaluate(back_func)(x) tvm.testing.assert_allclose(forward.numpy(), x.numpy()) tvm.testing.assert_allclose(grad.numpy(), np.ones_like(x.numpy())) @@ -89,9 +87,8 @@ def test_add(): func = run_infer_type(func) back_func = run_infer_type(gradient(func)) assert back_func.checked_type == relay.FuncType([t], relay.TupleType([t, relay.TupleType([t])])) - ex = create_executor() x = rand(dtype, *shape) - forward, (grad,) = ex.evaluate(back_func)(x) + forward, (grad,) = create_executor().evaluate(back_func)(x) tvm.testing.assert_allclose(forward.numpy(), 2 * x.numpy()) tvm.testing.assert_allclose(grad.numpy(), 2 * np.ones_like(x.numpy())) @@ -118,9 +115,8 @@ def test_temp_add(): func = run_infer_type(func) back_func = run_infer_type(gradient(func)) assert back_func.checked_type == relay.FuncType([t], relay.TupleType([t, relay.TupleType([t])])) - ex = create_executor() x = rand(dtype, *shape) - forward, (grad,) = ex.evaluate(back_func)(x) + forward, (grad,) = create_executor().evaluate(back_func)(x) tvm.testing.assert_allclose(forward.numpy(), 4 * x.numpy()) tvm.testing.assert_allclose(grad.numpy(), 4 * np.ones_like(x.numpy())) @@ -134,9 +130,8 @@ def test_sub(): func = run_infer_type(func) back_func = run_infer_type(gradient(func)) assert back_func.checked_type == relay.FuncType([t], relay.TupleType([t, relay.TupleType([t])])) - ex = create_executor() x = rand(dtype, *shape) - forward, (grad,) = ex.evaluate(back_func)(x) + forward, (grad,) = create_executor().evaluate(back_func)(x) tvm.testing.assert_allclose(forward.numpy(), np.zeros_like(x.numpy())) tvm.testing.assert_allclose(grad.numpy(), np.zeros_like(x.numpy())) @@ -163,8 +158,7 @@ def test_broadcast_add(): [relay.TensorType(expected_forward.shape, dtype), relay.TupleType([t1, t2])] ), ) - ex = create_executor() - forward, (grad_x, grad_y) = ex.evaluate(full_func)(x_nd, y_nd) + forward, (grad_x, grad_y) = create_executor().evaluate(full_func)(x_nd, y_nd) tvm.testing.assert_allclose(forward.numpy(), expected_forward) tvm.testing.assert_allclose( grad_x.numpy(), np.ones_like(expected_forward).sum(axis=2, keepdims=True) @@ -197,8 +191,7 @@ def test_broadcast_subtract(): [relay.TensorType(expected_forward.shape, dtype), relay.TupleType([t1, t2])] ), ) - ex = create_executor() - forward, (grad_x, grad_y) = ex.evaluate(full_func)(x_nd, y_nd) + forward, (grad_x, grad_y) = create_executor().evaluate(full_func)(x_nd, y_nd) tvm.testing.assert_allclose(forward.numpy(), expected_forward) tvm.testing.assert_allclose( grad_x.numpy(), np.ones_like(expected_forward).sum(axis=2, keepdims=True) @@ -247,8 +240,7 @@ def _test_tuple(mode): y_np = y_nd.numpy() z_np = z_nd.numpy() expected_forward = x_np + y_np - z_np - ex = create_executor() - forward, (grad_x, grad_y, grad_z) = ex.evaluate(back_func)(x_nd, y_nd, z_nd) + forward, (grad_x, grad_y, grad_z) = create_executor().evaluate(back_func)(x_nd, y_nd, z_nd) tvm.testing.assert_allclose(forward.numpy(), expected_forward) tvm.testing.assert_allclose(grad_x.numpy(), np.ones_like(grad_x.numpy())) tvm.testing.assert_allclose(grad_y.numpy(), np.ones_like(grad_y.numpy())) @@ -271,8 +263,7 @@ def _test_tuple_argument(mode): xs = [rand(dtype, *shape) for _ in range(fields)] xs_np = np.array([x.numpy() for x in xs]) expected_forward = np.sum(xs_np, axis=0) - ex = create_executor() - forward, grad = ex.evaluate(back_func)(tuple(xs)) + forward, grad = create_executor().evaluate(back_func)(tuple(xs)) tvm.testing.assert_allclose(forward.numpy(), expected_forward) for field in grad[0]: tvm.testing.assert_allclose(field.numpy(), np.ones_like(field.numpy())) @@ -315,8 +306,7 @@ def test_pow(): back_func = m["main"] assert back_func.checked_type == relay.FuncType([t], relay.TupleType([t, relay.TupleType([t])])) i_nd = rand(dtype, *shape) - ex = create_executor(mod=mod) - forward, (grad_i,) = ex.evaluate(back_func)(i_nd) + forward, (grad_i,) = create_executor(mod=mod).evaluate(back_func)(i_nd) tvm.testing.assert_allclose(forward.numpy(), 8 * i_nd.numpy()) tvm.testing.assert_allclose(grad_i.numpy(), 8 * np.ones_like(grad_i.numpy())) @@ -336,8 +326,7 @@ def test_ref(): back_func = run_infer_type(gradient(func)) assert back_func.checked_type == relay.FuncType([t], relay.TupleType([t, relay.TupleType([t])])) x_nd = rand(dtype, *shape) - ex = create_executor() - forward, (grad_x,) = ex.evaluate(back_func)(x_nd) + forward, (grad_x,) = create_executor().evaluate(back_func)(x_nd) tvm.testing.assert_allclose(forward.numpy(), 2 * x_nd.numpy()) tvm.testing.assert_allclose(grad_x.numpy(), 2 * np.ones_like(grad_x.numpy())) @@ -358,8 +347,7 @@ def test_square_second_order(): back_back_func = run_infer_type(gradient(back_func_adjusted)) assert back_func.checked_type == relay.FuncType([t], relay.TupleType([t, relay.TupleType([t])])) x_nd = rand(dtype, *shape) - ex = create_executor() - forward, (grad_x,) = ex.evaluate(back_back_func)(x_nd) + forward, (grad_x,) = create_executor().evaluate(back_back_func)(x_nd) tvm.testing.assert_allclose(forward.numpy(), 2 * x_nd.numpy()) tvm.testing.assert_allclose(grad_x.numpy(), 2 * np.ones_like(grad_x.numpy())) @@ -390,9 +378,8 @@ def test_grad_tuple(): assert back_func.checked_type == relay.FuncType( [t], relay.TupleType([relay.TupleType([t, t]), relay.TupleType([t])]) ) - ex = create_executor() x = rand(dtype, *shape) - (forward_four, forward_two), (grad,) = ex.evaluate(back_func)(x) + (forward_four, forward_two), (grad,) = create_executor().evaluate(back_func)(x) tvm.testing.assert_allclose(forward_four.numpy(), 4 * x.numpy()) tvm.testing.assert_allclose(forward_two.numpy(), 2 * x.numpy()) tvm.testing.assert_allclose(grad.numpy(), 4 * np.ones_like(x.numpy())) @@ -463,9 +450,8 @@ def test_global_function(): m = tvm.relay.transform.InferType()(m) back_func = m[g] assert back_func.checked_type == relay.FuncType([t], relay.TupleType([t, relay.TupleType([t])])) - ex = create_executor(mod=m) x = rand(dtype, *shape) - forward, (grad,) = ex.evaluate(back_func)(x) + forward, (grad,) = create_executor(mod=m).evaluate(back_func)(x) tvm.testing.assert_allclose(forward.numpy(), 4 * x.numpy()) tvm.testing.assert_allclose(grad.numpy(), 4 * np.ones_like(x.numpy())) diff --git a/tests/python/relay/test_pass_lazy_gradient_init.py b/tests/python/relay/test_pass_lazy_gradient_init.py index f37856669306..a0af2205a5d0 100644 --- a/tests/python/relay/test_pass_lazy_gradient_init.py +++ b/tests/python/relay/test_pass_lazy_gradient_init.py @@ -65,9 +65,8 @@ def test_add(): assert mod["main"].checked_type == relay.FuncType([t], t) - ex = create_executor(mod=mod) x = rand(dtype, *shape) - y = ex.evaluate(y)(x) + y = create_executor(mod=mod).evaluate(y)(x) assert_allclose(y.numpy(), x.numpy() + x.numpy()) @@ -92,9 +91,8 @@ def test_add_tuple(): assert mod["main"].checked_type == relay.FuncType([t], tensor_type) - ex = create_executor(mod=mod) x = (rand(dtype, *shape), rand(dtype, *shape)) - y = ex.evaluate(y)(x) + y = create_executor(mod=mod).evaluate(y)(x) assert_allclose(y.numpy(), x[0].numpy() + x[1].numpy()) @@ -117,9 +115,8 @@ def test_mult(): assert mod["main"].checked_type == relay.FuncType([t], t) - ex = create_executor(mod=mod) x = rand(dtype, *shape) - y = ex.evaluate(y)(x) + y = create_executor(mod=mod).evaluate(y)(x) assert_allclose(y.numpy(), x.numpy() * x.numpy()) @@ -143,9 +140,8 @@ def test_ret_tuple(): assert mod["main"].checked_type == relay.FuncType([t], relay.TupleType([t, t])) - ex = create_executor(mod=mod) x = rand(dtype, *shape) - y = ex.evaluate(func)(x) + y = create_executor(mod=mod).evaluate(func)(x) assert_allclose(y[0].numpy(), x.numpy()) assert_allclose(y[1].numpy(), x.numpy() * 2.0) @@ -177,8 +173,7 @@ def test_add_broadcast(): expected_forward_type = relay.TensorType(expected_forward.shape, dtype) assert mod["main"].checked_type == relay.FuncType([t1, t2], expected_forward_type) - ex = create_executor(mod=mod) - forward = ex.evaluate(func)(x1_np, x2_np) + forward = create_executor(mod=mod).evaluate(func)(x1_np, x2_np) assert_allclose(forward.numpy(), expected_forward) @@ -208,9 +203,8 @@ def test_reverse_ad_identity(): [t], relay.TupleType([t, relay.TupleType([t])]) ) - ex = create_executor(mod=mod) x = rand(dtype, *shape) - (forward), (grad,) = ex.evaluate(back_func)(x) + (forward), (grad,) = create_executor(mod=mod).evaluate(back_func)(x) assert_allclose(forward.numpy(), x.numpy()) assert_allclose(grad.numpy(), np.ones_like(x.numpy())) @@ -240,10 +234,9 @@ def test_multivar_reverse_ad(): [t, t], relay.TupleType([t, relay.TupleType([t, t])]) ) - ex = create_executor(mod=mod) x = rand(dtype, *shape) y = rand(dtype, *shape) - (forward), (grad_x, grad_y,) = ex.evaluate( + (forward), (grad_x, grad_y,) = create_executor(mod=mod).evaluate( back_func )(x, y) assert_allclose(forward.numpy(), x.numpy() * y.numpy()) @@ -305,10 +298,9 @@ def test_after_partial_eval(): [t, t], relay.TupleType([t, relay.TupleType([t, t])]) ) - ex = create_executor(mod=mod) x = rand(dtype, *shape) y = rand(dtype, *shape) - (forward), (grad_x, grad_y,) = ex.evaluate( + (forward), (grad_x, grad_y,) = create_executor(mod=mod).evaluate( back_func )(x, y) assert_allclose(forward.numpy(), x.numpy() * y.numpy()) @@ -343,10 +335,9 @@ def test_before_partial_eval(): [t, t], relay.TupleType([t, relay.TupleType([t, t])]) ) - ex = create_executor(mod=mod) x = rand(dtype, *shape) y = rand(dtype, *shape) - (forward), (grad_x, grad_y,) = ex.evaluate( + (forward), (grad_x, grad_y,) = create_executor(mod=mod).evaluate( back_func )(x, y) assert_allclose(forward.numpy(), x.numpy() * y.numpy()) @@ -372,9 +363,8 @@ def test_zeros(): assert mod["main"].checked_type == relay.FuncType([t], t) - ex = create_executor(mod=mod) x = rand(dtype, *shape) - y = ex.evaluate(y)(x) + y = create_executor(mod=mod).evaluate(y)(x) assert_allclose(y.numpy(), x.numpy()) @@ -396,9 +386,8 @@ def test_ones(): assert mod["main"].checked_type == relay.FuncType([t], t) - ex = create_executor(mod=mod) x = rand(dtype, *shape) - y = ex.evaluate(y)(x) + y = create_executor(mod=mod).evaluate(y)(x) assert_allclose(y.numpy(), x.numpy() + np.ones_like(x.numpy())) @@ -420,9 +409,8 @@ def test_zeros_like(): assert mod["main"].checked_type == relay.FuncType([t], t) - ex = create_executor(mod=mod) x = rand(dtype, *shape) - y = ex.evaluate(y)(x) + y = create_executor(mod=mod).evaluate(y)(x) assert_allclose(y.numpy(), x.numpy()) @@ -444,9 +432,8 @@ def test_ones_like(): assert mod["main"].checked_type == relay.FuncType([t], t) - ex = create_executor(mod=mod) x = rand(dtype, *shape) - y = ex.evaluate(y)(x) + y = create_executor(mod=mod).evaluate(y)(x) assert_allclose(y.numpy(), x.numpy() + np.ones_like(x.numpy())) diff --git a/tests/python/relay/test_pass_manager.py b/tests/python/relay/test_pass_manager.py index 7e3634f8b7db..c7926f7a3d79 100644 --- a/tests/python/relay/test_pass_manager.py +++ b/tests/python/relay/test_pass_manager.py @@ -180,11 +180,13 @@ def test_pass_run(): y_nd = get_rand(shape, dtype) ref_res = x_nd.numpy() + y_nd.numpy() for target, dev in tvm.testing.enabled_targets(): - exe1 = relay.create_executor("graph", device=dev, target=target) - exe2 = relay.create_executor("debug", device=dev, target=target) - res1 = exe1.evaluate(new_add)(x_nd, y_nd) + res1 = relay.create_executor("graph", device=dev, target=target).evaluate(new_add)( + x_nd, y_nd + ) tvm.testing.assert_allclose(res1.numpy(), ref_res, rtol=1e-5) - res2 = exe2.evaluate(new_add)(x_nd, y_nd) + res2 = relay.create_executor("debug", device=dev, target=target).evaluate(new_add)( + x_nd, y_nd + ) tvm.testing.assert_allclose(res2.numpy(), ref_res, rtol=1e-5) test_pass_registration() @@ -277,11 +279,9 @@ def test_pass_run(): x_nd = get_rand(shape, dtype) ref_res = np.log(x_nd.numpy() * 2) for target, dev in tvm.testing.enabled_targets(): - exe1 = relay.create_executor("graph", device=dev, target=target) - exe2 = relay.create_executor("debug", device=dev, target=target) - res1 = exe1.evaluate(new_log)(x_nd) + res1 = relay.create_executor("graph", device=dev, target=target).evaluate(new_log)(x_nd) tvm.testing.assert_allclose(res1.numpy(), ref_res, rtol=1e-5) - res2 = exe2.evaluate(new_log)(x_nd) + res2 = relay.create_executor("debug", device=dev, target=target).evaluate(new_log)(x_nd) tvm.testing.assert_allclose(res2.numpy(), ref_res, rtol=1e-5) test_pass_registration() @@ -439,22 +439,22 @@ def test_multiple_passes(): y_nd = get_rand(shape, dtype) ref_res = np.subtract(x_nd.numpy() * 2, y_nd.numpy() * 2) for target, dev in tvm.testing.enabled_targets(): - exe1 = relay.create_executor("graph", device=dev, target=target) - exe2 = relay.create_executor("debug", device=dev, target=target) - res1 = exe1.evaluate(new_sub)(x_nd, y_nd) + res1 = relay.create_executor("graph", device=dev, target=target).evaluate(new_sub)( + x_nd, y_nd + ) tvm.testing.assert_allclose(res1.numpy(), ref_res, rtol=1e-5) - res2 = exe2.evaluate(new_sub)(x_nd, y_nd) + res2 = relay.create_executor("debug", device=dev, target=target).evaluate(new_sub)( + x_nd, y_nd + ) tvm.testing.assert_allclose(res2.numpy(), ref_res, rtol=1e-5) # Execute the updated abs function. x_nd = get_rand((5, 10), dtype) ref_res = np.abs(x_nd.numpy() * 2) for target, dev in tvm.testing.enabled_targets(): - exe1 = relay.create_executor("graph", device=dev, target=target) - exe2 = relay.create_executor("debug", device=dev, target=target) - res1 = exe1.evaluate(new_abs)(x_nd) + res1 = relay.create_executor("graph", device=dev, target=target).evaluate(new_abs)(x_nd) tvm.testing.assert_allclose(res1.numpy(), ref_res, rtol=1e-5) - res2 = exe2.evaluate(new_abs)(x_nd) + res2 = relay.create_executor("debug", device=dev, target=target).evaluate(new_abs)(x_nd) tvm.testing.assert_allclose(res2.numpy(), ref_res, rtol=1e-5) test_pass_registration() diff --git a/tests/python/relay/test_pass_partial_eval.py b/tests/python/relay/test_pass_partial_eval.py index 129ac047cd89..ce36abd83c40 100644 --- a/tests/python/relay/test_pass_partial_eval.py +++ b/tests/python/relay/test_pass_partial_eval.py @@ -31,9 +31,7 @@ def check_eval(expr, expected_result, mod=None, rtol=1e-07): dev = tvm.device("llvm", 0) - intrp = create_executor(mod=mod, device=dev, target="llvm") - - result = intrp.evaluate(expr) + result = create_executor(mod=mod, device=dev, target="llvm").evaluate(expr) np.testing.assert_allclose(result.numpy(), expected_result, rtol=rtol) @@ -144,9 +142,8 @@ def test_if_ref(): body = Let(eff, body, RefRead(r)) f = Function([d], Let(r, RefCreate(const(1)), Let(u, update, body))) pe_f = tipe(f) - ex = create_executor() - f_res = ex.evaluate(f)(const(True)) - pe_f_res = ex.evaluate(pe_f)(const(True)) + f_res = create_executor().evaluate(f)(const(True)) + pe_f_res = create_executor().evaluate(pe_f)(const(True)) np.testing.assert_allclose(f_res.numpy(), 2 * np.ones_like(f_res.numpy())) np.testing.assert_allclose(pe_f_res.numpy(), 2 * np.ones_like(pe_f_res.numpy())) @@ -168,9 +165,8 @@ def test_function_invalidate(): body = Let(r, RefCreate(const(0)), body) f = Function([d], body) pe_f = tipe(f) - ex = create_executor() - f_res = ex.evaluate(f)(const(True)) - pe_f_res = ex.evaluate(pe_f)(const(True)) + f_res = create_executor().evaluate(f)(const(True)) + pe_f_res = create_executor().evaluate(pe_f)(const(True)) np.testing.assert_allclose(f_res.numpy(), np.ones_like(f_res.numpy())) np.testing.assert_allclose(pe_f_res.numpy(), np.ones_like(pe_f_res.numpy())) diff --git a/tests/python/relay/test_pass_partition_graph.py b/tests/python/relay/test_pass_partition_graph.py index 5467589f956b..93cd6f791765 100644 --- a/tests/python/relay/test_pass_partition_graph.py +++ b/tests/python/relay/test_pass_partition_graph.py @@ -484,8 +484,9 @@ def get_func(): i_data = np.random.uniform(0, 1, ishape).astype(dtype) w1_data = np.random.uniform(0, 1, w1shape).astype(dtype) - ref_ex = relay.create_executor("graph", mod=ref_mod, device=tvm.cpu()) - ref_res = ref_ex.evaluate()(i_data, w1_data) + ref_res = relay.create_executor("graph", mod=ref_mod, device=tvm.cpu()).evaluate()( + i_data, w1_data + ) check_result( mod, {"data": i_data, "weight1": w1_data}, (1, 32, 14, 14), ref_res.numpy(), tol=1e-5 ) @@ -504,8 +505,9 @@ def test_extern_dnnl_mobilenet(): mod = transform.PartitionGraph()(mod) i_data = np.random.uniform(0, 1, ishape).astype(dtype) - ref_ex = relay.create_executor("graph", mod=ref_mod, device=tvm.cpu(0)) - ref_res = ref_ex.evaluate()(i_data, **params) + ref_res = relay.create_executor("graph", mod=ref_mod, device=tvm.cpu(0)).evaluate()( + i_data, **params + ) compile_engine.get().clear() check_result(mod, {"data": i_data}, (1, 1000), ref_res.numpy(), tol=1e-5, params=params) @@ -945,8 +947,9 @@ def test_partition_mobilenet(): def test_exec(mod, params, ref_mod, ref_params, out_shape): ishape = (1, 3, 224, 224) i_data = np.random.randn(*ishape).astype(np.float32) - ref_ex = relay.create_executor("graph", mod=ref_mod, device=tvm.cpu(0)) - ref_res = ref_ex.evaluate()(i_data, **ref_params) + ref_res = relay.create_executor("graph", mod=ref_mod, device=tvm.cpu(0)).evaluate()( + i_data, **ref_params + ) compile_engine.get().clear() mod = get_partitoned_mod(mod, params, dnnl_patterns) diff --git a/tests/python/relay/test_pass_to_a_normal_form.py b/tests/python/relay/test_pass_to_a_normal_form.py index 61e5b8ea9407..cd2e5d2fd249 100644 --- a/tests/python/relay/test_pass_to_a_normal_form.py +++ b/tests/python/relay/test_pass_to_a_normal_form.py @@ -37,9 +37,7 @@ def run_opt_pass(expr, passes): def check_eval(expr, expected_result, mod=None, rtol=1e-07): dev = tvm.device("llvm", 0) - intrp = create_executor(mod=mod, device=dev, target="llvm") - - result = intrp.evaluate(expr) + result = create_executor(mod=mod, device=dev, target="llvm").evaluate(expr) np.testing.assert_allclose(result.numpy(), expected_result, rtol=rtol) @@ -151,6 +149,7 @@ def test_nat_add(): add = p.mod.get_global_var("nat_add") dev = tvm.device("llvm", 0) intrp = create_executor(mod=mod, device=dev, target="llvm") + # CAUTION: Following calls to intrp.evaluate(...) will re-prepare the prelude. assert mod[add].checked_type == relay.FuncType([nat(), nat()], nat()) assert count(p, intrp.evaluate(add(s(z()), s(z())))) == 2 expr = add(s(z()), s(z())) diff --git a/tests/python/relay/test_pass_to_basic_block_normal_form.py b/tests/python/relay/test_pass_to_basic_block_normal_form.py index d345d465c53e..642cab751b79 100644 --- a/tests/python/relay/test_pass_to_basic_block_normal_form.py +++ b/tests/python/relay/test_pass_to_basic_block_normal_form.py @@ -39,9 +39,7 @@ def run_opt_pass(expr, passes): def check_eval(expr, expected_result, mod=None, rtol=1e-07): dev = tvm.device("llvm", 0) - intrp = create_executor(mod=mod, device=dev, target="llvm") - - result = intrp.evaluate(expr) + result = create_executor(mod=mod, device=dev, target="llvm").evaluate(expr) np.testing.assert_allclose(result.numpy(), expected_result, rtol=rtol) @@ -267,16 +265,20 @@ def test_nat_add(): nat, z, s = p.mod.get_type("nat") add = p.mod.get_global_var("nat_add") dev = tvm.device("llvm", 0) - intrp = create_executor(mod=mod, device=dev, target="llvm") assert mod[add].checked_type == relay.FuncType([nat(), nat()], nat()) - assert count(p, intrp.evaluate(add(s(z()), s(z())))) == 2 + assert ( + count(p, create_executor(mod=mod, device=dev, target="llvm").evaluate(add(s(z()), s(z())))) + == 2 + ) expr = add(s(z()), s(z())) f = relay.GlobalVar("f") mod[f] = relay.Function([], expr) mod = transform.InferType()(mod) mod = transform.ToBasicBlockNormalForm()(mod) opt_expr = mod["f"] - assert count(p, intrp.evaluate(opt_expr.body)) == 2 + assert ( + count(p, create_executor(mod=mod, device=dev, target="llvm").evaluate(opt_expr.body)) == 2 + ) assert not Feature.fLet in detect_feature(mod[add]) check_basic_block_normal_form(opt_expr) diff --git a/tests/python/relay/test_pass_to_cps.py b/tests/python/relay/test_pass_to_cps.py index 0cde1d9ae492..4825cc29e6e4 100644 --- a/tests/python/relay/test_pass_to_cps.py +++ b/tests/python/relay/test_pass_to_cps.py @@ -58,9 +58,8 @@ def test_recursion(): mod["main"] = to_cps(mod["main"], mod=mod) mod = relay.transform.InferType()(mod) mod["main"] = un_cps(mod["main"]) - ex = create_executor(mod=mod) i_nd = rand(dtype, *shape) - forward = ex.evaluate()(i_nd) + forward = create_executor(mod=mod).evaluate()(i_nd) tvm.testing.assert_allclose(forward.numpy(), 8 * i_nd.numpy()) diff --git a/tests/python/relay/test_pass_to_graph_normal_form.py b/tests/python/relay/test_pass_to_graph_normal_form.py index 4f5084d83f9c..6a8c99d076e4 100644 --- a/tests/python/relay/test_pass_to_graph_normal_form.py +++ b/tests/python/relay/test_pass_to_graph_normal_form.py @@ -34,9 +34,7 @@ def check_eval(expr, args, expected_result, mod=None, rtol=1e-07): mod = tvm.IRModule() dev = tvm.device("llvm", 0) - intrp = create_executor(mod=mod, device=dev, target="llvm") - - result = intrp.evaluate(expr)(*args) + result = create_executor(mod=mod, device=dev, target="llvm").evaluate(expr)(*args) np.testing.assert_allclose(result.numpy(), expected_result, rtol=rtol) diff --git a/tests/python/relay/test_tensor_array.py b/tests/python/relay/test_tensor_array.py index e93831bef95f..21043abb3c84 100644 --- a/tests/python/relay/test_tensor_array.py +++ b/tests/python/relay/test_tensor_array.py @@ -63,8 +63,9 @@ def check_tensor_array(ta_mod, ref_res, *args, dtype="float32", rtol=1e-5): for target, dev in [("llvm", tvm.cpu(0))]: # testing.enabled_targets(): if kind == "debug" and dev.device_type != tvm.cpu().device_type: continue - ex = relay.create_executor(kind, mod=ta_mod, device=dev, target=target) - result = ex.evaluate()(*args) + result = relay.create_executor(kind, mod=ta_mod, device=dev, target=target).evaluate()( + *args + ) got = vmobj_to_list(ta_mod, result, dtype) tvm.testing.assert_allclose(ref_res, got, rtol=rtol, atol=rtol) diff --git a/tests/python/relay/test_to_mixed_precision.py b/tests/python/relay/test_to_mixed_precision.py index e9d576a333c1..5ab2eb346d8b 100644 --- a/tests/python/relay/test_to_mixed_precision.py +++ b/tests/python/relay/test_to_mixed_precision.py @@ -27,8 +27,7 @@ def run_module(mod: tvm.runtime.Module, mod_params: Dict[str, Any]) -> List: dev = tvm.device("llvm", 0) - intrp = relay.create_executor("debug", mod, device=dev, target="llvm") - result = intrp.evaluate()(**mod_params) + result = relay.create_executor("debug", mod, device=dev, target="llvm").evaluate()(**mod_params) if isinstance(result, tvm.runtime.container.ADT): result = [r.numpy() for r in result] return result diff --git a/tests/python/relay/test_vm.py b/tests/python/relay/test_vm.py index ad4b5c999fe7..7ae7e0eabeee 100644 --- a/tests/python/relay/test_vm.py +++ b/tests/python/relay/test_vm.py @@ -46,8 +46,9 @@ def check_result(args, expected_result, mod=None): The expected result of running the expression. """ for target, dev in tvm.testing.enabled_targets(): - vm = relay.create_executor("vm", device=dev, target=target, mod=mod) - rts_result = vm.evaluate()(*args) + rts_result = relay.create_executor("vm", device=dev, target=target, mod=mod).evaluate()( + *args + ) tvm.testing.assert_allclose(expected_result, rts_result.numpy()) @@ -182,8 +183,8 @@ def test_multiple_ifs(): fn = relay.Function([b], out) mod["main"] = fn dev = tvm.runtime.device("llvm", 0) - vm = relay.create_executor(device=dev, mod=mod, kind="vm") - res = vmobj_to_list(vm.evaluate()(False)) + func = relay.create_executor(device=dev, mod=mod, kind="vm").evaluate() + res = vmobj_to_list(func(False)) assert res == [1, 0] diff --git a/tests/python/relay/test_vm_serialization.py b/tests/python/relay/test_vm_serialization.py index ef7d9111b84c..f579f74a24ac 100644 --- a/tests/python/relay/test_vm_serialization.py +++ b/tests/python/relay/test_vm_serialization.py @@ -54,8 +54,7 @@ def get_serialized_output(mod, *data, params=None, target="llvm", device=tvm.cpu def run_network(mod, params, dtype="float32"): def get_vm_output(mod, data, params, target, device, dtype="float32"): - ex = relay.create_executor("vm", mod=mod, device=device) - result = ex.evaluate()(data, **params) + result = relay.create_executor("vm", mod=mod, device=device).evaluate()(data, **params) return result.numpy().astype(dtype) data_shape = [int(x) for x in mod["main"].checked_type.arg_types[0].shape] diff --git a/tests/python/topi/python/test_topi_transform.py b/tests/python/topi/python/test_topi_transform.py index 5f8030d53af9..42d2463b8952 100644 --- a/tests/python/topi/python/test_topi_transform.py +++ b/tests/python/topi/python/test_topi_transform.py @@ -879,24 +879,24 @@ def test_transpose_unfused_schedule(target, dev): shape = (100, tvm.target.Target(target).thread_warp_size + 3) x = relay.var("x", relay.TensorType(shape, "float32")) f = relay.transpose(x) - ex = relay.create_executor( - kind="graph", mod=tvm.IRModule.from_expr(relay.Function([x], f)), device=dev, target=target - ) r = np.random.rand(*shape) - tvm.testing.assert_allclose(ex.evaluate()(r).numpy(), np.transpose(r)) + func = relay.create_executor( + kind="graph", mod=tvm.IRModule.from_expr(relay.Function([x], f)), device=dev, target=target + ).evaluate() + tvm.testing.assert_allclose(func(r).numpy(), np.transpose(r)) # We want to make sure schedule does not fire here, but there is no way of # inspecting which schedules were used. x = relay.var("x", relay.TensorType(shape, "float32")) y = relay.var("y", relay.TensorType(shape, "float32")) f = relay.transpose(x + y) - ex = relay.create_executor( + func = relay.create_executor( kind="graph", mod=tvm.IRModule.from_expr(relay.Function([x, y], f)), device=dev, target=target, - ) - tvm.testing.assert_allclose(ex.evaluate()(r, r).numpy(), np.transpose(r + r)) + ).evaluate() + tvm.testing.assert_allclose(func(r, r).numpy(), np.transpose(r + r)) @tvm.testing.uses_gpu diff --git a/tests/python/unittest/test_custom_datatypes.py b/tests/python/unittest/test_custom_datatypes.py index 5f962ef7f74f..b135973718bc 100644 --- a/tests/python/unittest/test_custom_datatypes.py +++ b/tests/python/unittest/test_custom_datatypes.py @@ -90,17 +90,17 @@ def change_dtype(src, dst, module, params): def compare(module, input, src_dtype, dst_dtype, rtol, atol, params={}, target="llvm"): module = relay.transform.InferType()(module) module = relay.transform.SimplifyInference()(module) - ex = relay.create_executor("graph", mod=module) - correct = ex.evaluate()(*input, **params) + correct = relay.create_executor("graph", mod=module).evaluate()(*input, **params) module, converted_params = change_dtype(src_dtype, dst_dtype, module, params) - ex = relay.create_executor("graph", mod=module, target=target) # converts all inputs to dst_dtype x_converted = [convert_ndarray(dst_dtype, arr) for arr in input] # Vectorization is not implemented with custom datatypes with tvm.transform.PassContext(config={"tir.disable_vectorize": True}): - maybe_correct = ex.evaluate()(*x_converted, **converted_params) + maybe_correct = relay.create_executor("graph", mod=module, target=target).evaluate()( + *x_converted, **converted_params + ) # currently this only works for comparing single output maybe_correct_converted = convert_ndarray(src_dtype, maybe_correct) np.testing.assert_allclose( diff --git a/tests/python/unittest/test_runtime_container.py b/tests/python/unittest/test_runtime_container.py index 781fd7f93886..4c72f2c6083b 100644 --- a/tests/python/unittest/test_runtime_container.py +++ b/tests/python/unittest/test_runtime_container.py @@ -48,8 +48,7 @@ def test_tuple_object(): fn = relay.Function([x], relay.expr.TupleGetItem(x, 0)) mod = tvm.IRModule.from_expr(fn) - exe = relay.create_executor(kind="vm", mod=mod, device=nd.cpu(), target="llvm") - f = exe.evaluate() + f = relay.create_executor(kind="vm", mod=mod, device=nd.cpu(), target="llvm").evaluate() value_tuple = _container.tuple_object([nd.array(np.array(11)), nd.array(np.array(12))]) # pass an ADT object to evaluate out = f(value_tuple) diff --git a/tests/python/unittest/test_target_codegen_vulkan.py b/tests/python/unittest/test_target_codegen_vulkan.py index 01f734beb8fd..1edc5d311759 100644 --- a/tests/python/unittest/test_target_codegen_vulkan.py +++ b/tests/python/unittest/test_target_codegen_vulkan.py @@ -220,8 +220,7 @@ def do_copy(A, B, n): def check_mod(target, dev, mod, x_np, res_np): - ex = relay.create_executor("vm", mod=mod, device=dev, target=target) - res = ex.evaluate()(x_np).numpy() + res = relay.create_executor("vm", mod=mod, device=dev, target=target).evaluate()(x_np).numpy() tvm.testing.assert_allclose(res, res_np, atol=1e-5) diff --git a/tutorials/dev/bring_your_own_datatypes.py b/tutorials/dev/bring_your_own_datatypes.py index 06d96e14d28c..a5e8e2898d39 100644 --- a/tutorials/dev/bring_your_own_datatypes.py +++ b/tutorials/dev/bring_your_own_datatypes.py @@ -82,9 +82,7 @@ ###################################################################### # Finally, we're ready to run the program: -ex = relay.create_executor(mod=module) - -z_output = ex.evaluate()(x_input, y_input) +z_output = relay.create_executor(mod=module).evaluate()(x_input, y_input) print("z: {}".format(z_output)) ###################################################################### @@ -135,8 +133,7 @@ # Now that we can express our program without errors, let's try running it! try: with tvm.transform.PassContext(config={"tir.disable_vectorize": True}): - ex = relay.create_executor("graph", mod=module) - z_output_myfloat = ex.evaluate()(x_input, y_input) + z_output_myfloat = relay.create_executor("graph", mod=module).evaluate()(x_input, y_input) print("z: {}".format(y_myfloat)) except tvm.TVMError as e: # Print last line of error @@ -181,8 +178,7 @@ # We can now re-try running the program: try: with tvm.transform.PassContext(config={"tir.disable_vectorize": True}): - ex = relay.create_executor("graph", mod=module) - z_output_myfloat = ex.evaluate()(x_input, y_input) + z_output_myfloat = relay.create_executor("graph", mod=module).evaluate()(x_input, y_input) print("z: {}".format(z_output_myfloat)) except tvm.TVMError as e: # Print last line of error @@ -211,8 +207,7 @@ # Now, we can run our program without errors. with tvm.transform.PassContext(config={"tir.disable_vectorize": True}): - compiled = ex.evaluate(program) - z_output_myfloat = compiled(x_input, y_input) + z_output_myfloat = relay.create_executor(mod=module).evaluate()(x_input, y_input) print("z: {}".format(z_output_myfloat)) print("x:\t\t{}".format(x_input)) @@ -262,9 +257,8 @@ def get_cat_image(): ###################################################################### # It's easy to execute MobileNet with native TVM: -ex = tvm.relay.create_executor("graph", mod=module) input = get_cat_image() -result = ex.evaluate()(input, **params).numpy() +result = tvm.relay.create_executor("graph", mod=module).evaluate()(input, **params).numpy() # print first 10 elements print(result.flatten()[:10]) @@ -311,7 +305,9 @@ def convert_ndarray(dst_dtype, array): try: # Vectorization is not implemented with custom datatypes. with tvm.transform.PassContext(config={"tir.disable_vectorize": True}): - result_myfloat = ex.evaluate(expr)(input, **params) + result_myfloat = tvm.relay.create_executor("graph", mod=module).evaluate(expr)( + input, **params + ) except tvm.TVMError as e: print(str(e).split("\n")[-1]) @@ -401,7 +397,7 @@ def convert_ndarray(dst_dtype, array): # Vectorization is not implemented with custom datatypes. with tvm.transform.PassContext(config={"tir.disable_vectorize": True}): - result_myfloat = ex.evaluate(expr)(input, **params) + result_myfloat = relay.create_executor(mod=module).evaluate(expr)(input, **params) result_myfloat = convert_ndarray(src_dtype, result_myfloat).numpy() # print first 10 elements print(result_myfloat.flatten()[:10]) diff --git a/tutorials/frontend/deploy_quantized.py b/tutorials/frontend/deploy_quantized.py index b2210b8ab69b..2d9275796eb5 100644 --- a/tutorials/frontend/deploy_quantized.py +++ b/tutorials/frontend/deploy_quantized.py @@ -146,11 +146,11 @@ def quantize(mod, params, data_aware): # ------------- # We create a Relay VM to build and execute the model. def run_inference(mod): - executor = relay.create_executor("vm", mod, dev, target) + model = relay.create_executor("vm", mod, dev, target).evaluate() val_data, batch_fn = get_val_data() for i, batch in enumerate(val_data): data, label = batch_fn(batch) - prediction = executor.evaluate()(data) + prediction = model(data) if i > 10: # only run inference on a few samples in this tutorial break diff --git a/tutorials/frontend/from_keras.py b/tutorials/frontend/from_keras.py index 1c48aff799d4..e62836d2ccfe 100644 --- a/tutorials/frontend/from_keras.py +++ b/tutorials/frontend/from_keras.py @@ -97,14 +97,20 @@ # compile the model target = "cuda" dev = tvm.cuda(0) -with tvm.transform.PassContext(opt_level=3): - executor = relay.build_module.create_executor("graph", mod, dev, target) + +# TODO(mbs): opt_level=3 causes nn.contrib_conv2d_winograd_weight_transform +# to end up in the module which fails memory validation on cuda most likely +# due to a latent bug. Note that the pass context only has an effect within +# evaluate() and is not captured by create_executor(). +with tvm.transform.PassContext(opt_level=0): + model = relay.build_module.create_executor("graph", mod, dev, target).evaluate() + ###################################################################### # Execute on TVM # --------------- dtype = "float32" -tvm_out = executor.evaluate()(tvm.nd.array(data.astype(dtype)), **params) +tvm_out = model(tvm.nd.array(data.astype(dtype)), **params) top1_tvm = np.argmax(tvm_out.numpy()[0]) ##################################################################### diff --git a/tutorials/frontend/from_onnx.py b/tutorials/frontend/from_onnx.py index 26aeb6ecaf38..890bfbac4d8a 100644 --- a/tutorials/frontend/from_onnx.py +++ b/tutorials/frontend/from_onnx.py @@ -92,13 +92,13 @@ mod, params = relay.frontend.from_onnx(onnx_model, shape_dict) with tvm.transform.PassContext(opt_level=1): - intrp = relay.build_module.create_executor("graph", mod, tvm.cpu(0), target) + compiled = relay.build_module.create_executor("graph", mod, tvm.cpu(0), target).evaluate() ###################################################################### # Execute on TVM # --------------------------------------------- dtype = "float32" -tvm_output = intrp.evaluate()(tvm.nd.array(x.astype(dtype)), **params).numpy() +tvm_output = compiled(tvm.nd.array(x.astype(dtype)), **params).numpy() ###################################################################### # Display results