Skip to content

Commit

Permalink
[BYORTL][Verilator] update ops and add MobileNet (apache#7972)
Browse files Browse the repository at this point in the history
* update

* update vta submodule

* cpp fmt

* python fmt

* skip if tflite is not available

* fmt

* change assertion

* update comment
  • Loading branch information
vegaluisjose authored May 18, 2021
1 parent c510c2b commit dbd076a
Show file tree
Hide file tree
Showing 8 changed files with 554 additions and 103 deletions.
2 changes: 1 addition & 1 deletion 3rdparty/vta-hw
5 changes: 4 additions & 1 deletion src/runtime/contrib/verilator/verilator_kernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,12 @@ namespace tvm {
namespace runtime {
namespace contrib {

extern "C" TVM_DLL void verilator_add(VerilatorHandle handle, int* data, int* weight, int* out,
extern "C" TVM_DLL void verilator_add(VerilatorHandle handle, int* left, int* right, int* out,
int p_h_, int p_w_);

extern "C" TVM_DLL void verilator_bias_add(VerilatorHandle handle, int* data, int* bias, int* out,
int p_n_, int p_c_, int p_h_, int p_w_);

} // namespace contrib
} // namespace runtime
} // namespace tvm
Expand Down
19 changes: 12 additions & 7 deletions src/runtime/contrib/verilator/verilator_runtime.cc
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ VerilatorRuntime::~VerilatorRuntime() {
auto dealloc = reinterpret_cast<VerilatorDeallocFunc>(lib_->GetSymbol("VerilatorDealloc"));
ICHECK(dealloc != nullptr);
dealloc(device_);
delete lib_;
lib_->~VerilatorLibrary();
}

void VerilatorRuntime::SetLibrary(const std::string& lib_path) { lib_path_ = lib_path; }
Expand All @@ -100,15 +100,14 @@ void VerilatorRuntime::Init(const Array<NDArray>& consts) {
ICHECK(reset != nullptr);
read_ = reinterpret_cast<VerilatorReadFunc>(lib_->GetSymbol("VerilatorRead"));
ICHECK(read_ != nullptr);
add_op_ = reinterpret_cast<VerilatorAddFunc>(lib_->GetSymbol("verilator_add"));

// alloc verilator device
device_ = alloc();

// enable profiler
if (prof_enable_) prof_ = VerilatorProfiler::ThreadLocal();

// reset verilator device.
// reset verilator device
reset(device_, reset_cycles_);

CHECK_EQ(consts.size(), const_idx_.size())
Expand Down Expand Up @@ -136,11 +135,17 @@ void VerilatorRuntime::Run() {
if (node.GetOpType() == "kernel") {
CHECK_EQ(node.GetOpType(), "kernel");
auto op_name = node.GetOpName();
auto entry = node.GetInputs()[0];
auto shape = node.GetOpShape()[entry.index_];
if ("add" == op_name) {
auto entry = node.GetInputs()[0];
auto shape = nodes_[entry.id_].GetOpShape()[entry.index_];
ICHECK(add_op_ != nullptr);
add_op_(device_, in_ptr[0], in_ptr[1], out_ptr[0], shape[0], shape[1]);
auto add = reinterpret_cast<VerilatorAddFunc>(lib_->GetSymbol("verilator_add"));
ICHECK(add != nullptr);
add(device_, in_ptr[0], in_ptr[1], out_ptr[0], shape[0], shape[1]);
} else if ("nn.bias_add" == op_name) {
auto bias_add =
reinterpret_cast<VerilatorBiasAddFunc>(lib_->GetSymbol("verilator_bias_add"));
ICHECK(bias_add != nullptr);
bias_add(device_, in_ptr[0], in_ptr[1], out_ptr[0], shape[0], shape[3], shape[1], shape[2]);
} else {
LOG(FATAL) << "Unsupported op: " << op_name;
}
Expand Down
5 changes: 2 additions & 3 deletions src/runtime/contrib/verilator/verilator_runtime.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,9 @@ using namespace tvm::runtime::json;
typedef VerilatorHandle (*VerilatorAllocFunc)();
typedef void (*VerilatorDeallocFunc)(VerilatorHandle);
typedef void (*VerilatorResetFunc)(VerilatorHandle, int);
typedef void (*VerilatorAddFunc)(VerilatorHandle, int*, int*, int*, int, int);
typedef int (*VerilatorReadFunc)(VerilatorHandle, int, int);
typedef void (*VerilatorAddFunc)(VerilatorHandle, int*, int*, int*, int, int);
typedef void (*VerilatorBiasAddFunc)(VerilatorHandle, int*, int*, int*, int, int, int, int);

class VerilatorLibrary : public Library {
public:
Expand Down Expand Up @@ -122,8 +123,6 @@ class VerilatorRuntime : public JSONRuntimeBase {
VerilatorProfiler* prof_{nullptr};
/*! \brief the verilator read function */
VerilatorReadFunc read_{nullptr};
/*! \brief the verilator add op function */
VerilatorAddFunc add_op_{nullptr};
/*! \brief the verilator reset cycles */
int reset_cycles_{1};
/*! \brief the verilator profiler status */
Expand Down
128 changes: 104 additions & 24 deletions tests/python/contrib/test_verilator/infrastructure.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import os
import sys
import subprocess as sp
import json

import tvm
from tvm import relay
Expand Down Expand Up @@ -48,6 +49,10 @@ def _func_wrapper(expr):
return _func_wrapper


_register_verilator_op("add")
_register_verilator_op("nn.bias_add")


def skip_test():
"""Skip test if it requires the Verilator codegen and it's not present."""
if not tvm.get_global_func("relay.ext.verilator", True):
Expand All @@ -59,8 +64,33 @@ def skip_test():
return False


def clear_stats():
"""Clear profiler statistics."""
f = tvm.get_global_func("verilator.profiler_clear", True)
if f:
f()


def stats():
"""Get profiler statistics."""

x = tvm.get_global_func("verilator.profiler_status")()
return json.loads(x)


def offload(mod):
"""Offload ops based on the registered ops"""
"""Offload ops based on the registered ops
Paramters
---------
mod : Module
The input module.
Returns
-------
mod : Module
The output module with offloaded ops.
"""

backend = "verilator"
mod = transform.AnnotateTarget([backend])(mod)
Expand All @@ -69,7 +99,7 @@ def offload(mod):


def verilator_app_path():
"""Find verilator hardware app path"""
"""Create verilator hardware app path."""

cur_dir = os.path.dirname(os.path.realpath(__file__))
return os.path.join(
Expand All @@ -82,37 +112,87 @@ def verilator_app_path():
"vta-hw",
"apps",
"verilator",
"add",
)


def compile_hardware():
"""Compile hardware into shared library"""
def compile_hardware(lanes):
"""Compile hardware into shared library
Paramters
---------
lanes : Int
The number of vector lanes.
Returns
-------
path : Str
The path of the shared library.
"""
lib_name = "libverilator_{}".format(lanes)
lib_name_ext = "{}.so".format(lib_name)
lib = os.path.join(verilator_app_path(), lib_name_ext)
if not os.path.isfile(lib):
opt_lib_name = "LIB_NAME={}".format(lib_name)
opt_lanes = "LANES={}".format(lanes)
cmd = []
cmd.append("make")
cmd.append("--directory")
cmd.append(verilator_app_path())
cmd.append(opt_lib_name)
cmd.append(opt_lanes)
sp.run(cmd, check=True, stdout=sp.DEVNULL)
return lib


cmd = []
cmd.append("make")
cmd.append("--directory")
cmd.append(verilator_app_path())
sp.run(cmd, check=True)
def compiler_opts(lib):
"""Create compiler options
Paramters
---------
lib : Str
The path of the hardware shared library.
def compile_module(mod):
"""Compile Relay module and hardware library"""
Returns
-------
opts : Dict
The compiler options.
"""
opts = {
"lib_path": lib,
"profiler_enable": True,
"profiler_cycle_counter_id": 0,
}
return opts

lib = os.path.join(verilator_app_path(), "libverilator.so")
if not os.path.isfile(lib):
compile_hardware()

opts = {"lib_path": lib}
def run_module(inp, mod, params=None, opts=None):
"""Compile Relay module and hardware library
with tvm.transform.PassContext(opt_level=3, config={"relay.ext.verilator.options": opts}):
exe = relay.vm.compile(mod, target="llvm", params=None)
code, lib = exe.save()
return runtime.vm.Executable.load_exec(code, lib)
Paramters
---------
inp : Data
The input data.
mod : Module
The relay module.
def run_module(exe, inputs):
"""Run Relay module"""
params : Parameters
The model Parameters.
dev = tvm.cpu()
vm = runtime.vm.VirtualMachine(exe, dev)
return vm.run(**inputs)
opts : Dict
The compiler
Returns
-------
out : Data
The output data.
"""

with tvm.transform.PassContext(opt_level=3, config={"relay.ext.verilator.options": opts}):
lib = relay.vm.compile(mod, target="llvm", params=params)
code, lib = lib.save()
exe = runtime.vm.Executable.load_exec(code, lib)
vm = runtime.vm.VirtualMachine(exe, tvm.cpu())
out = vm.run(**inp)
return out
Loading

0 comments on commit dbd076a

Please sign in to comment.