
Commit

Addressing comments - 3
Change-Id: Id25d1382c30d6d0a0013b5e8986fb8cd886666dc
Giuseppe Rossini committed Apr 19, 2021
1 parent 7fef775 commit d5f2e81
Showing 32 changed files with 342 additions and 433 deletions.
4 changes: 2 additions & 2 deletions include/tvm/runtime/crt/page_allocator.h
@@ -72,8 +72,8 @@ struct MemoryManagerInterface {
* \param page_size_bytes_log2 log2 of the page size, in bytes.
* \return kTvmErrorNoError on success.
*/
tvm_crt_error_t MemoryManagerCreate(MemoryManagerInterface** manager, uint8_t* memory_pool,
size_t memory_pool_size_bytes, size_t page_size_bytes_log2);
tvm_crt_error_t PageMemoryManagerCreate(MemoryManagerInterface** manager, uint8_t* memory_pool,
size_t memory_pool_size_bytes, size_t page_size_bytes_log2);

#ifdef __cplusplus
} // extern "C"
25 changes: 25 additions & 0 deletions include/tvm/tir/builtin.h
@@ -343,6 +343,16 @@ TVM_DLL const Op& tvm_stack_make_array();
*/
TVM_DLL const Op& tvm_call_packed();

/*!
 * \brief See pseudo code
 *
 *  int tvm_call_cpacked(fname, TVMValue* args) {
 *     (*fname)(args, type_code_of(args), len(args));
 *     return 0;
 *  }
 */
TVM_DLL const Op& tvm_call_cpacked();

/*!
 * \brief See pseudo code
*
@@ -386,6 +396,21 @@ TVM_DLL const Op& tvm_thread_context();
*/
TVM_DLL const Op& tvm_call_packed_lowered();

/*!
* \brief Lowered version of call c-packed, the space of value and
* type codes are explicitly allocated.
*
 *  int tvm_call_cpacked_lowered(fname,
* TVMValue* value_stack,
* int* tcode_stack,
* int begin,
* int end) {
* fname(TVMArgs(value_stack[begin:end], tcode_stack[begin:end]),
* TVMRetValue(value_stack + end, tcode_stack + end));
* }
*/
TVM_DLL const Op& tvm_call_cpacked_lowered();

/*!
* \brief Lowered version of trace intrinsic, the space of value and
* type codes are explicitly allocated. The return value is the
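As a rough illustration of the two builtins added above, the following Python sketch constructs a call to the new op from the TIR frontend. It is not part of this commit: the callee name "my_cpacked_fn", the int32 argument, and the assumption that the first operand is the callee's symbol name (mirroring the tvm_call_packed convention) are illustrative only, and a TVM build that registers "tir.tvm_call_cpacked" is assumed.

import tvm
from tvm import tir

# Hedged sketch: build a tir.Call to the tvm_call_cpacked builtin.
# The first operand is assumed to be the callee symbol name, followed by the
# actual call arguments, as in the tvm_call_packed pseudo code above.
call = tir.call_intrin(
    "int32",
    "tir.tvm_call_cpacked",
    tir.StringImm("my_cpacked_fn"),  # hypothetical callee
    tir.IntImm("int32", 0),          # hypothetical argument
)
print(call)  # prints the Call node referencing the Op registered as "tir.tvm_call_cpacked"

The lowered form, tvm_call_cpacked_lowered, is typically produced by the lowering passes rather than written by hand.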
24 changes: 4 additions & 20 deletions python/tvm/relay/backend/executor_factory.py
@@ -52,11 +52,6 @@ def get_lib(self):
""" Return the generated library"""
raise NotImplementedError

@abstractmethod
def get_internal_repr(self):
""" Return the internal representation used to execute the network"""
raise NotImplementedError

def __getitem__(self, item):
print(item)
return self.module.__getitem__(item)
@@ -85,9 +80,8 @@ def __next__(self):
class AOTExecutorFactoryModule(ExecutorFactoryModule):
"""AOT executor factory module.
Parameters
Attributes
----------
runner_function : the PrimFunc containing the TIR main executor function.
target : tvm.Target
The Target used to build this module.
libmod : tvm.Module
@@ -98,29 +92,19 @@ class AOTExecutorFactoryModule(ExecutorFactoryModule):
The parameters of module
"""

def __init__(self, ir_mod, target, runner_function, libmod, libmod_name, params):
assert isinstance(runner_function, tir.PrimFunc)
args = []
for k, v in params.items():
args.append(k)
args.append(ndarray.array(v))

def __init__(self, ir_mod, target, libmod, libmod_name, params):
self.ir_mod = ir_mod
self.target = target
self.runner_func = runner_function
self.lib = libmod
self.libmod_name = libmod_name
self.params = params
self.iter_cnt = 0

# Sometimes we want to get params explicitly.
# For example, we want to save its params value to
# an independent file.
def get_params(self):
return self.params

def get_internal_repr(self):
return self.runner_func
return None

def get_lib(self):
return self.lib
@@ -130,7 +114,7 @@ class GraphExecutorFactoryModule(ExecutorFactoryModule):
"""Graph executor factory module.
This is a module of graph executor factory
Parameters
Attributes
----------
graph_json_str : the graph to be deployed, in JSON format, as output by the graph compiler.
The graph can contain operator(tvm_op) that points to the name of
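To show how these factory classes are consumed, here is a minimal sketch (not part of this commit) that compiles a toy single-operator network for the default graph executor and queries the resulting factory. The network, the "llvm" target, and an LLVM-enabled TVM build are assumptions made purely for illustration.

import numpy as np
import tvm
from tvm import relay

# Toy network: a single dense layer, just so there is something to compile.
data = relay.var("data", shape=(1, 8), dtype="float32")
weight = relay.var("weight", shape=(4, 8), dtype="float32")
mod = tvm.IRModule.from_expr(relay.Function([data, weight], relay.nn.dense(data, weight)))
params = {"weight": np.random.rand(4, 8).astype("float32")}

with tvm.transform.PassContext(opt_level=3):
    # The default executor is "graph", so this yields a GraphExecutorFactoryModule.
    factory = relay.build(mod, target="llvm", params=params)

lib = factory.get_lib()                # compiled operators (tvm.runtime.Module)
bound_params = factory.get_params()    # parameters bound at build time
graph_json = factory.get_graph_json()  # graph-executor-specific artifact

An AOTExecutorFactoryModule obtained the same way still answers get_lib() and get_params(), but, with the change above, it no longer carries the TIR runner function as a separate artifact.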
22 changes: 9 additions & 13 deletions python/tvm/relay/build_module.py
@@ -78,7 +78,6 @@ class BuildModule(object):
def __init__(self):
self.mod = _build_module._BuildModule()
self._get_graph_json = self.mod["get_graph_json"]
self._get_runner_function = self.mod["get_runner_function"]
self._get_module = self.mod["get_module"]
self._build = self.mod["build"]
self._optimize = self.mod["optimize"]
@@ -114,7 +113,7 @@ def build(self, mod, target=None, target_host=None, params=None):
Returns
-------
factory_module : tvm.relay.backend.executor_factory.ExecutorFactoryModule
The runtime factory for the TVM graph executor.
The runtime factory for the TVM executor.
"""
target = _update_target(target)
target, target_host = Target.check_and_update_host_consist(
@@ -141,11 +140,7 @@ def build(self, mod, target=None, target_host=None, params=None):
# Get artifacts
mod = self.get_module()
params = self.get_params()
internal_repr = (
self._get_runner_function()
if self.get_executor_type() == "aot"
else self.get_graph_json()
)
internal_repr = self.get_graph_json() if self.get_executor_type() == "graph" else None

return internal_repr, mod, params

@@ -255,10 +250,11 @@ def build(ir_mod, target=None, target_host=None, params=None, mod_name="default"
Returns
-------
graph_json : str
The string representation of the graph. When using the
graph-executor this represents the json string that can be
accepted by the executor.
internal_repr : str or tir.PrimFunc
The internal representation the executor uses to execute the
network. Can be a string representing the json graph (when
building for the graph executor) or the PrimFunc representing the
AOT runner function.
mod : tvm.Module
The module containing necessary libraries.
@@ -303,14 +299,14 @@ def build(ir_mod, target=None, target_host=None, params=None, mod_name="default"

if bld_mod.get_executor_type() == "aot":
executor_factory = _executor_factory.AOTExecutorFactoryModule(
ir_mod, target, internal_repr, runtime_mod, mod_name, params
ir_mod, target, runtime_mod, mod_name, params
)
elif bld_mod.get_executor_type() == "graph":
executor_factory = _executor_factory.GraphExecutorFactoryModule(
ir_mod, target, internal_repr, runtime_mod, mod_name, params
)
else:
assert False, "Executor not supported"
assert False, "Executor " + bld_mod.get_executor_type() + " not supported"

return executor_factory

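For completeness, a hedged end-to-end sketch of how the factory returned by relay.build plugs into the graph executor at runtime; the small dense model, the random input, and the "llvm"/CPU setup are assumptions for this example only. The "aot" branch above instead returns an AOTExecutorFactoryModule, which carries no graph JSON to deploy.

import numpy as np
import tvm
from tvm import relay
from tvm.contrib import graph_executor

data = relay.var("data", shape=(1, 8), dtype="float32")
weight = relay.var("weight", shape=(4, 8), dtype="float32")
mod = tvm.IRModule.from_expr(relay.Function([data, weight], relay.nn.dense(data, weight)))
params = {"weight": np.random.rand(4, 8).astype("float32")}

with tvm.transform.PassContext(opt_level=3):
    factory = relay.build(mod, target="llvm", params=params, mod_name="default")

# Indexing the factory by module name yields the packed function that
# instantiates the graph executor for a given device.
dev = tvm.cpu()
gmod = graph_executor.GraphModule(factory["default"](dev))
gmod.set_input("data", np.random.rand(1, 8).astype("float32"))
gmod.run()
out = gmod.get_output(0).asnumpy()  # (1, 4) result of the dense layer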
