
Commit

[Core][Build] Move build module transformations and utilities to C++ (apache#9103)

* Initial investigation

* More progress!

* More progress / notes

* Rewrite build_for_device mostly in C++

* More progress

* Initial split of transformations applied to device and host as a post-split action on the mixed module

* Combine duplicate passes after splitting mod on AOT and VM flows

* Minor cleanup

* Move target mangling to driver_api.cc

* Move more build utilities to C++ driver API

* [Build][WIP] Moving build utilities to C++ from Python

* [Build] Remove comments

* [lint] Pass black

* More formatting

* Move more build functionality into C++

* Remove comments

* Remove unused defs and imports

* Address PR comments

* More PR comments

* More comments

* More comments

* Add comments on the new split function

* Fix PR comments on clarity

* Test CI

* Fix format

* Refactor build

* Expose split composite passes to Python

* Format files

* Test fix

* Fix for annotating entry funcs on code targeting CPU

* Prevent entry funcs from being annotated when compiling for CPU with the C runtime enabled

* Guard for AOT executor entry

* Sphinx format

* Sanity fix

* Sphinx fix

Co-authored-by: electriclilies <lilyorthsmith@gmail.com>
2 people authored and ylc committed Jan 7, 2022
1 parent 9ad6004 commit 2870a82
Showing 5 changed files with 238 additions and 205 deletions.
30 changes: 30 additions & 0 deletions include/tvm/driver/driver_api.h
@@ -30,6 +30,7 @@
#define TVM_DRIVER_DRIVER_API_H_

#include <tvm/ir/module.h>
#include <tvm/ir/transform.h>
#include <tvm/runtime/packed_func.h>
#include <tvm/support/with.h>
#include <tvm/target/target.h>
@@ -43,6 +44,34 @@
#include <vector>

namespace tvm {
using tvm::transform::Pass;

/*!
 * \brief Configures and returns the composite Pass for the fused module (pre-split) that contains
 * device and host code.
 * \param mixed_mod The original mixed module.
 * \param target The device Target.
 * \return The composite Pass for the fused module.
 */
TVM_DLL transform::Sequential MixedModulePassManager(IRModule mixed_mod, Target target);

/*!
 * \brief Configures and returns the composite Pass for the device Target after the device/host
 * split of the mixed module.
 * \param mixed_mod The optimized mixed module.
 * \param target The device Target.
 * \return The composite Pass for the device module.
 */
TVM_DLL transform::Sequential DeviceModulePassManager(IRModule mixed_mod, Target target);

/*!
 * \brief Configures and returns the composite Pass for the host Target after the device/host
 * split of the mixed module.
 * \param mixed_mod The optimized mixed module.
 * \param target_host The host Target.
 * \return The composite Pass for the host module.
 */
TVM_DLL transform::Sequential HostModulePassManager(IRModule mixed_mod, Target target_host);

/*!
 * \brief Lower an IRModule (optimize it with the pass list defined in CreatePassList)
@@ -136,6 +165,7 @@ TVM_DLL runtime::Module build(const Map<Target, IRModule>& input, const Target&
* \return The built module that contains code for different processors.
*/
TVM_DLL runtime::Module build(const Map<String, IRModule>& input, const Target& target_host);

} // namespace tvm

#endif // TVM_DRIVER_DRIVER_API_H_
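
The three declarations above describe a pipeline: the mixed-module passes run first on the fused module, and the device and host pass managers then filter and lower their respective halves of the same optimized module. A minimal C++ sketch of that composition follows; it is not part of this patch, the helper name and variables are illustrative, and it assumes the input IRModule already carries per-function "target" attributes.

#include <utility>

#include <tvm/driver/driver_api.h>
#include <tvm/ir/module.h>
#include <tvm/target/target.h>

// Illustrative sketch: run the shared (pre-split) passes, then derive the
// device and host modules from the same optimized mixed module.
std::pair<tvm::IRModule, tvm::IRModule> SplitAndLower(tvm::IRModule mixed_mod,
                                                      tvm::Target target,
                                                      tvm::Target target_host) {
  tvm::IRModule opt_mixed = tvm::MixedModulePassManager(mixed_mod, target)(mixed_mod);
  tvm::IRModule device_mod = tvm::DeviceModulePassManager(opt_mixed, target)(opt_mixed);
  tvm::IRModule host_mod = tvm::HostModulePassManager(opt_mixed, target_host)(opt_mixed);
  return {device_mod, host_mod};
}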
125 changes: 10 additions & 115 deletions python/tvm/driver/build_module.py
@@ -16,27 +16,23 @@
# under the License.

# pylint: disable=invalid-name
"""The build utils in python.
"""
"""The build utils in python."""

from typing import Union, Optional, List, Mapping
import warnings

import tvm.tir

from tvm.runtime import Module
from tvm.runtime import ndarray
from tvm.ir import container
from tvm.ir import CallingConv
from tvm.tir import PrimFunc
from tvm.ir.module import IRModule
from tvm.ir.transform import PassContext
from tvm.target import codegen
from tvm.te import tensor
from tvm.te import schedule
from tvm.target import Target
from tvm.tir.buffer import Buffer
from tvm.tir.expr import Var
from tvm.driver import _ffi_api as _driver_ffi

from . import _ffi_api as ffi

@@ -104,8 +100,8 @@ def lower(
args : Optional[List[Union[tvm.tir.Buffer, tensor.Tensor, Var]]]
The argument lists to the function for TE schedule.
It should be None if we want to lower TensorIR.
name : str
The name of the result function.
@@ -132,98 +128,6 @@
raise ValueError("Expected input to be an IRModule, PrimFunc or Schedule, but got, ", type(inp))


def _build_for_device(input_mod, target, target_host):
    """Build the lowered functions for a device with the given compilation
    target.

    Parameters
    ----------
    input_mod : IRModule
        The schedule to be built.

    target : str or :any:`tvm.target.Target`
        The target and option of the compilation.

    target_host : str or :any:`tvm.target.Target`
        The host compilation target.

    Returns
    -------
    fhost : IRModule
        The host IRModule.

    mdev : tvm.module
        A module that contains device code.
    """
    target, target_host = Target.check_and_update_host_consist(target, target_host)
    device_type = ndarray.device(target.kind.name, 0).device_type

    mod_mixed = input_mod
    mod_mixed = tvm.tir.transform.Apply(lambda f: f.with_attr("target", target))(mod_mixed)

    opt_mixed = [
        tvm.tir.transform.VerifyMemory(),
        tvm.tir.transform.MergeDynamicSharedMemoryAllocations(),
    ]
    if len(mod_mixed.functions) == 1:
        opt_mixed += [tvm.tir.transform.Apply(lambda f: f.with_attr("tir.is_entry_func", True))]

    if PassContext.current().config.get("tir.detect_global_barrier", False):
        opt_mixed += [tvm.tir.transform.ThreadSync("global")]
    opt_mixed += [
        tvm.tir.transform.ThreadSync("shared"),
        tvm.tir.transform.ThreadSync("warp"),
        tvm.tir.transform.InferFragment(),
        tvm.tir.transform.LowerThreadAllreduce(),
        tvm.tir.transform.MakePackedAPI(),
        tvm.tir.transform.SplitHostDevice(),
    ]
    mod_mixed = tvm.transform.Sequential(opt_mixed)(mod_mixed)

    # device optimizations
    opt_device = tvm.transform.Sequential(
        [
            tvm.tir.transform.Filter(
                lambda f: "calling_conv" in f.attrs
                and f.attrs["calling_conv"].value == CallingConv.DEVICE_KERNEL_LAUNCH
            ),
            tvm.tir.transform.LowerWarpMemory(),
            tvm.tir.transform.Simplify(),
            tvm.tir.transform.LowerDeviceStorageAccessInfo(),
            tvm.tir.transform.LowerCustomDatatypes(),
            tvm.tir.transform.LowerIntrin(),
        ]
    )
    mod_dev = opt_device(mod_mixed)

    # host optimizations
    opt_host = tvm.transform.Sequential(
        [
            tvm.tir.transform.Filter(
                lambda f: "calling_conv" not in f.attrs
                or f.attrs["calling_conv"].value != CallingConv.DEVICE_KERNEL_LAUNCH
            ),
            tvm.tir.transform.Apply(lambda f: f.with_attr("target", target_host)),
            tvm.tir.transform.LowerTVMBuiltin(),
            tvm.tir.transform.LowerDeviceStorageAccessInfo(),
            tvm.tir.transform.LowerCustomDatatypes(),
            tvm.tir.transform.LowerIntrin(),
            tvm.tir.transform.CombineContextCall(),
        ]
    )
    mod_host = opt_host(mod_mixed)

    if device_type == ndarray.cpu(0).device_type and target_host == target:
        assert len(mod_dev.functions) == 0
    if "gpu" in target.keys and len(mod_dev.functions) == 0:
        warnings.warn(
            "Specified target %s, but cannot find device code, did you do " "bind?" % target
        )

    rt_mod_dev = codegen.build_module(mod_dev, target) if len(mod_dev.functions) != 0 else None
    return mod_host, rt_mod_dev


def build(
    inputs: Union[schedule.Schedule, PrimFunc, IRModule, Mapping[str, IRModule]],
    args: Optional[List[Union[Buffer, tensor.Tensor, Var]]] = None,
@@ -237,7 +141,8 @@ def build(
Parameters
----------
inputs : Union[tvm.te.schedule.Schedule, tvm.tir.PrimFunc, IRModule, Mapping[str, IRModule]]
inputs : Union[tvm.te.schedule.Schedule,
tvm.tir.PrimFunc, IRModule, Mapping[str, IRModule]]
The input to be built
args : Optional[List[Union[tvm.tir.Buffer, tensor.Tensor, Var]]]
@@ -253,7 +158,7 @@
setup the dimensions and parameters correctly.
target_host is used to specify the host side codegen target.
By default, llvm is used if it is enabled,
otherwise a stackvm intepreter is used.
otherwise a stackvm interpreter is used.
name : Optional[str]
The name of result function.
@@ -350,21 +255,11 @@ def build(
        target_input_mod, target_host
    )

    mod_host_all = tvm.IRModule({})

    device_modules = []
    for tar, input_mod in target_input_mod.items():
        mod_host, mdev = _build_for_device(input_mod, tar, target_host)
        mod_host_all.update(mod_host)
        device_modules.append(mdev)

    # Generate a unified host module.
    rt_mod_host = codegen.build_module(mod_host_all, target_host)
    rt_mod_host = _driver_ffi.finalize_module(target_input_mod, target_host)

    # Import all modules.
    for mdev in device_modules:
        if mdev:
            rt_mod_host.import_module(mdev)
    target_input_mod, target_host = Target.check_and_update_host_consist(
        target_input_mod, target_host
    )

    if not isinstance(target_host, Target):
        target_host = Target(target_host)
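
In the new flow above, the per-target host codegen and device-module import loop collapses into a single `_driver_ffi.finalize_module(target_input_mod, target_host)` call, so that work now happens on the C++ side. For orientation, a hypothetical C++ caller holding already-lowered modules could go through the `build` overload declared in include/tvm/driver/driver_api.h; the wrapper below is a sketch with illustrative names, not the function backing `finalize_module`.

#include <tvm/driver/driver_api.h>
#include <tvm/ir/module.h>
#include <tvm/runtime/module.h>
#include <tvm/target/target.h>

// Illustrative sketch: build host code plus imported device modules for a set
// of targets; `per_target_mods` maps each Target to its lowered IRModule.
tvm::runtime::Module BuildAll(const tvm::Map<tvm::Target, tvm::IRModule>& per_target_mods,
                              const tvm::Target& target_host) {
  return tvm::build(per_target_mods, target_host);
}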
6 changes: 3 additions & 3 deletions python/tvm/relay/build_module.py
@@ -123,7 +123,7 @@ def build(
to setup the dimensions and parameters correctly.
target_host is used to specify the host side codegen target.
By default, llvm is used if it is enabled,
otherwise a stackvm intepreter is used.
otherwise a stackvm interpreter is used.
params : dict of str to NDArray
Input parameters to the graph that do not change
@@ -303,7 +303,7 @@ def build(ir_mod, target=None, target_host=None, params=None, mod_name="default"
setup the dimensions and parameters correctly.
target_host is used to specify the host side codegen target.
By default, llvm is used if it is enabled,
otherwise a stackvm intepreter is used.
otherwise a stackvm interpreter is used.
params : dict of str to NDArray
Input parameters to the graph that do not change
@@ -452,7 +452,7 @@ def bind_params_by_name(func, params):
class GraphExecutor(_interpreter.Executor):
"""Wrapper around Executor interface.
This executor is used for debug and testing purpoes.
This executor is used for debug and testing purposes.
Parameters
----------
