Skip to content

Commit

Permalink
[VTA] Fix vta rpc server, refactor launch cond to not depend on sys.a…
Browse files Browse the repository at this point in the history
…rgv (apache#8671)
  • Loading branch information
tqchen authored and mehrdadh committed Aug 11, 2021
1 parent 4222a3f commit 154ba7f
Show file tree
Hide file tree
Showing 3 changed files with 68 additions and 67 deletions.
6 changes: 4 additions & 2 deletions vta/python/vta/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
configure the hardware environment and access remote device through RPC.
"""
import sys
import tvm._ffi.base

from .autotvm import module_loader
from .bitstream import get_bitstream_path, download_bitstream
Expand All @@ -29,8 +30,9 @@

__version__ = "0.1.0"


# do not from tvm import topi when running vta.exec.rpc_server
# to maintain minimum dependency on the board
if sys.argv[0] not in ("-c", "-m"):
# in lib tvm runtime only mode
if not tvm._ffi.base._RUNTIME_ONLY:
from . import top
from .build_module import build_config, lower, build
65 changes: 64 additions & 1 deletion vta/python/vta/build_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,9 @@
# pylint: disable=unused-argument, invalid-name
"""VTA specific buildin for runtime."""
import tvm
from tvm.ir import register_intrin_lowering
from . import transform
from .environment import get_env
from .environment import get_env, Environment


def EarlyRewrite():
Expand Down Expand Up @@ -134,3 +135,65 @@ def build(*args, **kwargs):

tvm.ir.register_op_attr("tir.vta.command_handle", "TGlobalSymbol", "VTATLSCommandHandle")
tvm.ir.register_op_attr("tir.vta.command_handle", "TCallEffectKind", tvm.tir.CallEffectKind.Opaque)

# The memory information for the compiler
@tvm.register_func("tvm.info.mem.%s" % Environment.inp_scope)
def mem_info_inp_buffer():
spec = get_env()
return tvm.ir.make_node(
"MemoryInfo",
unit_bits=spec.INP_ELEM_BITS,
max_simd_bits=spec.INP_ELEM_BITS,
max_num_bits=spec.INP_BUFF_SIZE * 8,
head_address=None,
)


@tvm.register_func("tvm.info.mem.%s" % Environment.wgt_scope)
def mem_info_wgt_buffer():
spec = get_env()
return tvm.ir.make_node(
"MemoryInfo",
unit_bits=spec.WGT_ELEM_BITS,
max_simd_bits=spec.WGT_ELEM_BITS,
max_num_bits=spec.WGT_BUFF_SIZE * 8,
head_address=None,
)


@tvm.register_func("tvm.info.mem.%s" % Environment.acc_scope)
def mem_info_acc_buffer():
spec = get_env()
return tvm.ir.make_node(
"MemoryInfo",
unit_bits=spec.ACC_ELEM_BITS,
max_simd_bits=spec.ACC_ELEM_BITS,
max_num_bits=spec.ACC_BUFF_SIZE * 8,
head_address=None,
)


# TVM Op related registration
@register_intrin_lowering("tir.vta.coproc_sync", "default")
def coproc_sync(op):
_ = op
return tvm.tir.call_extern(
"int32",
"VTASynchronize",
get_env().dev.command_handle,
tvm.runtime.const(1 << 31, dtype="uint32"),
)


@register_intrin_lowering("tir.vta.coproc_dep_push", "default")
def coproc_dep_push(op):
return tvm.tir.call_extern(
"int32", "VTADepPush", get_env().dev.command_handle, op.args[0], op.args[1]
)


@register_intrin_lowering("tir.vta.coproc_dep_pop", "default")
def coproc_dep_pop(op):
return tvm.tir.call_extern(
"int32", "VTADepPop", get_env().dev.command_handle, op.args[0], op.args[1]
)
64 changes: 0 additions & 64 deletions vta/python/vta/environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
import copy
import tvm
from tvm import te
from tvm.ir import register_intrin_lowering
from . import intrin


Expand Down Expand Up @@ -255,69 +254,6 @@ def get_env():
return Environment.current


# The memory information for the compiler
@tvm.register_func("tvm.info.mem.%s" % Environment.inp_scope)
def mem_info_inp_buffer():
spec = get_env()
return tvm.ir.make_node(
"MemoryInfo",
unit_bits=spec.INP_ELEM_BITS,
max_simd_bits=spec.INP_ELEM_BITS,
max_num_bits=spec.INP_BUFF_SIZE * 8,
head_address=None,
)


@tvm.register_func("tvm.info.mem.%s" % Environment.wgt_scope)
def mem_info_wgt_buffer():
spec = get_env()
return tvm.ir.make_node(
"MemoryInfo",
unit_bits=spec.WGT_ELEM_BITS,
max_simd_bits=spec.WGT_ELEM_BITS,
max_num_bits=spec.WGT_BUFF_SIZE * 8,
head_address=None,
)


@tvm.register_func("tvm.info.mem.%s" % Environment.acc_scope)
def mem_info_acc_buffer():
spec = get_env()
return tvm.ir.make_node(
"MemoryInfo",
unit_bits=spec.ACC_ELEM_BITS,
max_simd_bits=spec.ACC_ELEM_BITS,
max_num_bits=spec.ACC_BUFF_SIZE * 8,
head_address=None,
)


# TVM Op related registration
@register_intrin_lowering("tir.vta.coproc_sync", "default")
def coproc_sync(op):
_ = op
return tvm.tir.call_extern(
"int32",
"VTASynchronize",
get_env().dev.command_handle,
tvm.runtime.const(1 << 31, dtype="uint32"),
)


@register_intrin_lowering("tir.vta.coproc_dep_push", "default")
def coproc_dep_push(op):
return tvm.tir.call_extern(
"int32", "VTADepPush", get_env().dev.command_handle, op.args[0], op.args[1]
)


@register_intrin_lowering("tir.vta.coproc_dep_pop", "default")
def coproc_dep_pop(op):
return tvm.tir.call_extern(
"int32", "VTADepPop", get_env().dev.command_handle, op.args[0], op.args[1]
)


def _init_env():
"""Initialize the default global env"""
config_path = os.path.join(get_vta_hw_path(), "config/vta_config.json")
Expand Down

0 comments on commit 154ba7f

Please sign in to comment.