[Core][Build] Move build module transformations and utilities to C++ #9103

Merged Oct 14, 2021 · 39 commits (changes shown from 30 commits)

Commits
c8eace9
Initial investigation
electriclilies Sep 10, 2021
8d6b228
More progress!
electriclilies Sep 15, 2021
ba8836e
More progress / notes
electriclilies Sep 15, 2021
b5bb9e8
rewrite build_for_device mostly in c++
electriclilies Sep 17, 2021
f9372db
More progress
electriclilies Sep 17, 2021
322f9f1
Initial split of transformations applied to device and host as post s…
mikepapadim Sep 20, 2021
0c4a01d
Combine duplicate passes after spliting mod on aot and vm flows
mikepapadim Sep 20, 2021
73640e8
Minor cleanup
mikepapadim Sep 20, 2021
6904123
Fix merge conflicts
mikepapadim Sep 21, 2021
d0ba8b8
Move target mangling to driver_api.cc
mikepapadim Sep 22, 2021
01b4ce3
Move more build utlities to cpp driver api
mikepapadim Sep 22, 2021
6176155
[Build][WIP] Moving build utilities to C++ from Python
mikepapadim Sep 24, 2021
df0c75d
Merge branch 'main' of https://github.com/apache/tvm into build_incre…
mikepapadim Sep 24, 2021
09aaf88
[Build] Remove comments
mikepapadim Sep 24, 2021
311632b
[lint] Pass black
mikepapadim Sep 24, 2021
0c28839
More formating
mikepapadim Sep 24, 2021
5008b75
Move more build functionality into cpp
mikepapadim Sep 24, 2021
f73791a
Remove comments
mikepapadim Sep 24, 2021
ba98e6f
Remove unused defs and imports
mikepapadim Sep 24, 2021
f515c6f
Address PR comments
mikepapadim Sep 25, 2021
6b366c3
More PR comments
mikepapadim Sep 25, 2021
1e24b25
More comments
mikepapadim Sep 25, 2021
af8c8e3
More comments
mikepapadim Sep 25, 2021
57b8039
Add comments on the new split function
mikepapadim Sep 26, 2021
0c4bf6d
Fix PR comments on clarity
mikepapadim Oct 4, 2021
41cf6f3
Test CI
mikepapadim Oct 7, 2021
6e4f751
Merge branch 'main' of https://github.com/apache/tvm into build_incre…
mikepapadim Oct 7, 2021
b7f27d0
Fix format
mikepapadim Oct 7, 2021
e1658b5
Refactor build
mikepapadim Oct 8, 2021
a71a0af
Expose splitted composite passes to python
mikepapadim Oct 8, 2021
0ca60e9
Format files
mikepapadim Oct 8, 2021
49cbe53
Test fix
mikepapadim Oct 12, 2021
cd33d95
Fix merge conflicts
mikepapadim Oct 12, 2021
6e1203e
Fix for annotating entry funcs on code targeting CPU
mikepapadim Oct 12, 2021
2c0305b
Prevent entry funcs to be annotated when compiling for CPU with C run…
mikepapadim Oct 12, 2021
b955799
Guard for aot executor entry
mikepapadim Oct 12, 2021
aed0765
Sphix format
mikepapadim Oct 12, 2021
c3505fa
Sanity fix
mikepapadim Oct 13, 2021
2162c85
Sphinx fix
mikepapadim Oct 13, 2021
30 changes: 30 additions & 0 deletions include/tvm/driver/driver_api.h
@@ -30,6 +30,7 @@
#define TVM_DRIVER_DRIVER_API_H_

#include <tvm/ir/module.h>
#include <tvm/ir/transform.h>
#include <tvm/runtime/packed_func.h>
#include <tvm/support/with.h>
#include <tvm/target/target.h>
@@ -43,6 +44,34 @@
#include <vector>

namespace tvm {
using tvm::transform::Pass;

/*!
* \brief Configures and returns the composite Pass for the fused module (pre split) that contains
* device and host code.
* \param mixed_mod The original mixed module.
* \param target The device Target.
* \return The composite Pass for the fused module.
*/
TVM_DLL transform::Sequential MixedModulePassManager(IRModule mixed_mod, Target target);

/*!
* \brief Configures and returns the composite Pass for the device Target after the device/host
* split of the mixed module.
* \param mixed_mod The optimized mixed module.
* \param target The device Target.
* \return The composite Pass for the device module.
*/
TVM_DLL transform::Sequential DeviceModulePassManager(IRModule mixed_mod, Target target);

/*!
* \brief Configures and returns the composite Pass for the host Target after the device/host
* split of the mixed module.
* \param mixed_mod The optimized mixed module.
* \param target_host The host Target.
* \return The composite Pass for the host module.
*/
TVM_DLL transform::Sequential HostModulePassManager(IRModule mixed_mod, Target target_host);

/*!
* \brief Lower an IRModule (optimize it with the pass list defined in CreatePassList)
@@ -136,6 +165,7 @@ TVM_DLL runtime::Module build(const Map<Target, IRModule>& input, const Target&
* \return The built module that contains code for different processors.
*/
TVM_DLL runtime::Module build(const Map<String, IRModule>& input, const Target& target_host);

} // namespace tvm

#endif // TVM_DRIVER_DRIVER_API_H_
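For orientation, the three pass managers declared above mirror the pipeline that the removed Python helper _build_for_device (see the build_module.py diff below) used to assemble by hand. The following is a condensed, illustrative Python sketch of that three-way composition using the same TIR passes; the function name split_mixed_module_sketch and the exact pass selection are simplifications for illustration, not the literal C++ implementation.

import tvm
from tvm.ir import CallingConv


def split_mixed_module_sketch(mixed_mod, target, target_host):
    """Illustrative only: mirrors MixedModule/DeviceModule/HostModulePassManager.

    `target` and `target_host` are expected to be tvm.target.Target objects.
    """
    # Mixed phase: passes that must see host and device code together,
    # ending with the host/device split.
    mixed = tvm.transform.Sequential(
        [
            tvm.tir.transform.Apply(lambda f: f.with_attr("target", target)),
            tvm.tir.transform.VerifyMemory(),
            tvm.tir.transform.ThreadSync("shared"),
            tvm.tir.transform.MakePackedAPI(),
            tvm.tir.transform.SplitHostDevice(),
        ]
    )(mixed_mod)

    # Device phase: keep only kernel-launch functions and lower device-side constructs.
    device_mod = tvm.transform.Sequential(
        [
            tvm.tir.transform.Filter(
                lambda f: "calling_conv" in f.attrs
                and f.attrs["calling_conv"].value == CallingConv.DEVICE_KERNEL_LAUNCH
            ),
            tvm.tir.transform.LowerWarpMemory(),
            tvm.tir.transform.LowerIntrin(),
        ]
    )(mixed)

    # Host phase: everything else, retargeted and lowered for the host.
    host_mod = tvm.transform.Sequential(
        [
            tvm.tir.transform.Filter(
                lambda f: "calling_conv" not in f.attrs
                or f.attrs["calling_conv"].value != CallingConv.DEVICE_KERNEL_LAUNCH
            ),
            tvm.tir.transform.Apply(lambda f: f.with_attr("target", target_host)),
            tvm.tir.transform.LowerTVMBuiltin(),
            tvm.tir.transform.LowerIntrin(),
        ]
    )(mixed)

    return host_mod, device_mod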
138 changes: 11 additions & 127 deletions python/tvm/driver/build_module.py
@@ -20,23 +20,20 @@
"""

from typing import Union, Optional, List, Mapping
import warnings

import tvm.tir

from tvm.runtime import Module
from tvm.runtime import ndarray
from tvm.ir import container
from tvm.ir import CallingConv
from tvm.tir import PrimFunc
from tvm.ir.module import IRModule
from tvm.ir.transform import PassContext
from tvm.target import codegen
from tvm.te import tensor
from tvm.te import schedule
from tvm.target import Target
from tvm.tir.buffer import Buffer
from tvm.tir.expr import Var
from tvm.driver import _ffi_api as _driver_ffi

from . import _ffi_api as ffi

@@ -123,6 +120,7 @@ def lower(
m : IRModule
The result IRModule
"""
# TODO(@mikepapadim) introduce ffi.relay.lower_te_pass()
if isinstance(inp, IRModule):
return ffi.lower_module(inp, simple_mode)
if isinstance(inp, PrimFunc):
@@ -132,98 +130,6 @@ def lower(
raise ValueError("Expected input to be an IRModule, PrimFunc or Schedule, but got, ", type(inp))


def _build_for_device(input_mod, target, target_host):
"""Build the lowered functions for a device with the given compilation
target.

Parameters
----------
input_mod : IRModule
The schedule to be built.

target : str or :any:`tvm.target.Target`
The target and option of the compilation.

target_host : str or :any:`tvm.target.Target`
The host compilation target.

Returns
-------
fhost : IRModule
The host IRModule.

mdev : tvm.module
A module that contains device code.
"""
target, target_host = Target.check_and_update_host_consist(target, target_host)
device_type = ndarray.device(target.kind.name, 0).device_type

mod_mixed = input_mod
mod_mixed = tvm.tir.transform.Apply(lambda f: f.with_attr("target", target))(mod_mixed)

opt_mixed = [
tvm.tir.transform.VerifyMemory(),
tvm.tir.transform.MergeDynamicSharedMemoryAllocations(),
]
if len(mod_mixed.functions) == 1:
opt_mixed += [tvm.tir.transform.Apply(lambda f: f.with_attr("tir.is_entry_func", True))]

if PassContext.current().config.get("tir.detect_global_barrier", False):
opt_mixed += [tvm.tir.transform.ThreadSync("global")]
opt_mixed += [
tvm.tir.transform.ThreadSync("shared"),
tvm.tir.transform.ThreadSync("warp"),
tvm.tir.transform.InferFragment(),
tvm.tir.transform.LowerThreadAllreduce(),
tvm.tir.transform.MakePackedAPI(),
tvm.tir.transform.SplitHostDevice(),
]
mod_mixed = tvm.transform.Sequential(opt_mixed)(mod_mixed)

# device optimizations
opt_device = tvm.transform.Sequential(
[
tvm.tir.transform.Filter(
lambda f: "calling_conv" in f.attrs
and f.attrs["calling_conv"].value == CallingConv.DEVICE_KERNEL_LAUNCH
),
tvm.tir.transform.LowerWarpMemory(),
tvm.tir.transform.Simplify(),
tvm.tir.transform.LowerDeviceStorageAccessInfo(),
tvm.tir.transform.LowerCustomDatatypes(),
tvm.tir.transform.LowerIntrin(),
]
)
mod_dev = opt_device(mod_mixed)

# host optimizations
opt_host = tvm.transform.Sequential(
[
tvm.tir.transform.Filter(
lambda f: "calling_conv" not in f.attrs
or f.attrs["calling_conv"].value != CallingConv.DEVICE_KERNEL_LAUNCH
),
tvm.tir.transform.Apply(lambda f: f.with_attr("target", target_host)),
tvm.tir.transform.LowerTVMBuiltin(),
tvm.tir.transform.LowerDeviceStorageAccessInfo(),
tvm.tir.transform.LowerCustomDatatypes(),
tvm.tir.transform.LowerIntrin(),
tvm.tir.transform.CombineContextCall(),
]
)
mod_host = opt_host(mod_mixed)

if device_type == ndarray.cpu(0).device_type and target_host == target:
assert len(mod_dev.functions) == 0
if "gpu" in target.keys and len(mod_dev.functions) == 0:
warnings.warn(
"Specified target %s, but cannot find device code, did you do " "bind?" % target
)

rt_mod_dev = codegen.build_module(mod_dev, target) if len(mod_dev.functions) != 0 else None
return mod_host, rt_mod_dev


def build(
inputs: Union[schedule.Schedule, PrimFunc, IRModule, Mapping[str, IRModule]],
args: Optional[List[Union[Buffer, tensor.Tensor, Var]]] = None,
@@ -234,59 +140,47 @@
):
"""Build a function with arguments as signature. Code will be generated
for devices coupled with target information.

Parameters
----------
inputs : Union[tvm.te.schedule.Schedule, tvm.tir.PrimFunc, IRModule, Mapping[str, IRModule]]
inputs : Union[tvm.te.schedule.Schedule,
tvm.tir.PrimFunc, IRModule, Mapping[str, IRModule]]
The input to be built

args : Optional[List[Union[tvm.tir.Buffer, tensor.Tensor, Var]]]
The argument lists to the function.

target : Optional[Union[str, Target]]
The target and option of the compilation.

target_host : Optional[Union[str, Target]]
Host compilation target, if target is device.
When TVM compiles device specific program such as CUDA,
we also need host(CPU) side code to interact with the driver
setup the dimensions and parameters correctly.
target_host is used to specify the host side codegen target.
By default, llvm is used if it is enabled,
otherwise a stackvm intepreter is used.

otherwise a stackvm interpreter is used.
name : Optional[str]
The name of result function.

binds : Optional[Mapping[tensor.Tensor, tvm.tir.Buffer]]
Dictionary that maps the binding of symbolic buffer to Tensor.
By default, a new buffer is created for each tensor in the argument.

Returns
-------
ret : tvm.module
A module that combines both host and device code.

Examples
________
There are two typical example uses of this function depending on the type
of the argument `inputs`:
1. it is an IRModule.

.. code-block:: python

n = 2
A = te.placeholder((n,), name='A')
B = te.placeholder((n,), name='B')
C = te.compute(A.shape, lambda *i: A(*i) + B(*i), name='C')
s = tvm.te.create_schedule(C.op)
m = tvm.lower(s, [A, B, C], name="test_add")
rt_mod = tvm.build(m, target="llvm")

2. it is a dict of compilation target to IRModule.

.. code-block:: python

n = 2
A = te.placeholder((n,), name='A')
B = te.placeholder((n,), name='B')
@@ -297,11 +191,11 @@ def build(
m1 = tvm.lower(s1, [A, B, C], name="test_add1")
m2 = tvm.lower(s2, [A, B, C], name="test_add2")
rt_mod = tvm.build({"llvm": m1, "cuda": m2}, target_host="llvm")

Note
----
See the note on :any:`tvm.target` on target string format.
"""

if isinstance(inputs, schedule.Schedule):
if args is None:
raise ValueError("args must be given for build from schedule")
@@ -318,7 +212,7 @@
f"Inputs must be Schedule, IRModule or dict of target to IRModule, "
f"but got {type(inputs)}."
)

# starts here
Reviewer comment (Contributor): nit: maybe sweep for leftover comments

if not isinstance(inputs, (dict, container.Map)):
target = Target.current() if target is None else target
target = target if target else "llvm"
@@ -350,21 +244,11 @@ def build(
target_input_mod, target_host
)

mod_host_all = tvm.IRModule({})

device_modules = []
for tar, input_mod in target_input_mod.items():
mod_host, mdev = _build_for_device(input_mod, tar, target_host)
mod_host_all.update(mod_host)
device_modules.append(mdev)
rt_mod_host = _driver_ffi.finalize_module(target_input_mod, target_host)

# Generate a unified host module.
rt_mod_host = codegen.build_module(mod_host_all, target_host)

# Import all modules.
for mdev in device_modules:
if mdev:
rt_mod_host.import_module(mdev)
target_input_mod, target_host = Target.check_and_update_host_consist(
target_input_mod, target_host
)

if not isinstance(target_host, Target):
target_host = Target(target_host)
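With _build_for_device gone, the host/device split, unified host codegen, and device-module import now happen behind _driver_ffi.finalize_module, as the diff above shows. A minimal end-to-end check of the user-facing path, assuming an LLVM-enabled TVM build; the kernel name vector_add is illustrative:

import tvm
from tvm import te

n = 1024
A = te.placeholder((n,), name="A")
B = te.placeholder((n,), name="B")
C = te.compute(A.shape, lambda i: A[i] + B[i], name="C")
s = te.create_schedule(C.op)

# tvm.build now hands the lowered module to the C++ driver API for the
# host/device split and codegen instead of the removed Python helper.
rt_mod = tvm.build(s, [A, B, C], target="llvm", name="vector_add")
print(type(rt_mod))  # tvm.runtime.Module combining host (and any device) code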
6 changes: 3 additions & 3 deletions python/tvm/relay/build_module.py
@@ -123,7 +123,7 @@ def build(
to setup the dimensions and parameters correctly.
target_host is used to specify the host side codegen target.
By default, llvm is used if it is enabled,
otherwise a stackvm intepreter is used.
otherwise a stackvm interpreter is used.

params : dict of str to NDArray
Input parameters to the graph that do not change
@@ -303,7 +303,7 @@ def build(ir_mod, target=None, target_host=None, params=None, mod_name="default"
setup the dimensions and parameters correctly.
target_host is used to specify the host side codegen target.
By default, llvm is used if it is enabled,
otherwise a stackvm intepreter is used.
otherwise a stackvm interpreter is used.

params : dict of str to NDArray
Input parameters to the graph that do not change
@@ -452,7 +452,7 @@ def bind_params_by_name(func, params):
class GraphExecutor(_interpreter.Executor):
"""Wrapper around Executor interface.

This executor is used for debug and testing purpoes.
This executor is used for debug and testing purposes.

Parameters
----------
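The docstrings above describe the target / target_host split. A short, hedged example of how a user typically passes both to relay.build, assuming CUDA- and LLVM-enabled builds and a Relay module mod with a parameter dict params defined elsewhere (both are placeholders here):

import tvm
from tvm import relay

# Device kernels are generated for CUDA, while the host-side glue that sets up
# launch dimensions and parameters is generated for LLVM.
with tvm.transform.PassContext(opt_level=3):
    lib = relay.build(mod, target="cuda", target_host="llvm", params=params)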
2 changes: 1 addition & 1 deletion python/tvm/target/codegen.py
@@ -36,7 +36,7 @@ def build_module(mod, target):
The corresponding module.
"""
target = Target(target) if isinstance(target, str) else target
return _ffi_api.Build(mod, target)
return _ffi_api.Codegen(mod, target)
Reviewer comment (Contributor): should we also rename the containing function, which is just a wrapper? That would mean renaming it to codegen and adding a backwards-compatible build_module that warns about deprecation (there are some examples using the warnings module, I believe).
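A minimal sketch of what that suggestion could look like in python/tvm/target/codegen.py; the codegen name and the deprecation message are illustrative and not part of this PR, which only switches the FFI call as shown above:

import warnings

from tvm.target import Target, _ffi_api


def codegen(mod, target):
    """Lower an IRModule to a runtime module via the C++ ``Codegen`` FFI entry."""
    target = Target(target) if isinstance(target, str) else target
    return _ffi_api.Codegen(mod, target)


def build_module(mod, target):
    """Deprecated alias kept for backwards compatibility."""
    warnings.warn(
        "build_module is deprecated, use codegen instead",
        DeprecationWarning,
        stacklevel=2,
    )
    return codegen(mod, target)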



def llvm_lookup_intrinsic_id(name):