Skip to content

Commit

Permalink
- fix AotExecutor assert failure
Browse files Browse the repository at this point in the history
  • Loading branch information
mbs-octoml committed Jun 18, 2022
1 parent 5de3adf commit 19b33a4
Showing 1 changed file with 14 additions and 13 deletions.
27 changes: 14 additions & 13 deletions python/tvm/relay/build_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -547,15 +547,15 @@ class GraphExecutor(_interpreter.Executor):
device : :py:class:`Device`
The runtime device to run the code on.
target : :py:class:`Target`
The target option to build the function.
raw_targets : Array[tvm.target.Target]
The available targets.
"""

def __init__(self, mod, device, target):
def __init__(self, mod, device, raw_targets):
assert mod is not None
self.mod = mod
self.device = device
self.target = target
self.raw_targets = raw_targets

def _make_executor(self, expr=None):
if expr:
Expand All @@ -566,7 +566,7 @@ def _make_executor(self, expr=None):
raise ValueError(
"Graph Executor only supports static graphs, got output type", ret_type
)
mod = build(self.mod, target=self.target)
mod = build(self.mod, target=self.raw_targets)
gmodule = _graph_executor.GraphModule(mod["default"](self.device))

def _unflatten(flat_iter, cur_type):
Expand Down Expand Up @@ -607,16 +607,16 @@ class AotExecutor(_interpreter.Executor):
device : :py:class:`Device`
The runtime device to run the code on.
target : :py:class:`Target`
The target option to build the function.
raw_targets : Array[tvm.target.Target]
The available targets.
"""

def __init__(self, mod, device, target):
def __init__(self, mod, device, raw_targets):
assert mod is not None
self.mod = mod
self.device = device
self.target = target
assert target.attrs.get("executor", "graph") == "aot"
self.raw_targets = raw_targets
assert raw_targets[0].attrs.get("executor", "graph") == "aot"

def _make_executor(self, expr=None):
if expr:
Expand All @@ -625,7 +625,7 @@ def _make_executor(self, expr=None):
ret_type = self.mod["main"].checked_type.ret_type
if _ty.is_dynamic(ret_type):
raise ValueError("AOT Executor only supports static graphs, got output type", ret_type)
mod = build(self.mod, target=self.target)
mod = build(self.mod, target=self.raw_targets)

# NOTE: Given AOT requires use of the "c" backend, must export/import to compile the
# generated code.
Expand Down Expand Up @@ -696,8 +696,9 @@ def create_executor(kind="debug", mod=None, device=None, target="llvm", params=N
device : :py:class:`Device`
The device to execute the code.
target : :py:class:`tvm.Target`
The corresponding context
target : any multi-target like object, see Target.canon_multi_target
For homogeneous compilation, the unique build target.
For heterogeneous compilation, a dictionary or list of possible build targets.
params : dict of str to NDArray
Input parameters to the graph that do not change
Expand Down

0 comments on commit 19b33a4

Please sign in to comment.