Skip to content

Commit

Permalink
Address comments
Browse files Browse the repository at this point in the history
Change-Id: I041fda14f3bf9975f3518ba8a4e3ab43ba98403d
  • Loading branch information
lhutton1 committed Jul 14, 2020
1 parent 1ccd355 commit 699b943
Show file tree
Hide file tree
Showing 11 changed files with 54 additions and 74 deletions.
9 changes: 5 additions & 4 deletions docs/deploy/arm_compute_lib.rst
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ a performance boost on such devices.

Building with ACL support
-------------------------

The current implementation has two separate build options in cmake. The reason for this split is
because ACL cannot be used on an x86 machine. However, we still want to be able compile an ACL
runtime module on an x86 machine.
Expand All @@ -51,7 +52,7 @@ relay graph can be input. The ACL integration will only pick supported operators
whilst the rest will be computed via TVM. (For this example we will use a single
max_pool2d operator).

..code:: python
.. code:: python
import tvm
from tvm import relay
Expand Down Expand Up @@ -79,7 +80,7 @@ Annotate and partition the graph for ACL.

Build the Relay graph.

..code:: python
.. code:: python
target = "llvm -mtriple=aarch64-linux-gnu -mattr=+neon"
with tvm.transform.PassContext(opt_level=3, disabled_pass=["AlterOpLayout"]):
Expand All @@ -88,7 +89,7 @@ Build the Relay graph.
Export the module.

..code:: python
.. code:: python
lib_path = '~/lib_acl.so'
cross_compile = 'aarch64-linux-gnu-c++'
Expand All @@ -98,7 +99,7 @@ Export the module.
Run Inference. This must be on an Arm device. If compiling on x86 device and running on aarch64
consider using the RPC mechanism.

..code:: python
.. code:: python
tvm.runtime.load_module('lib_acl.so')
gen_module = tvm.contrib.graph_runtime.create(json, lib, ctx)
Expand Down
1 change: 1 addition & 0 deletions docs/deploy/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -68,3 +68,4 @@ target device without relying on RPC. see the following resources on how to do s
android
integrate
hls
arm_compute_lib
8 changes: 4 additions & 4 deletions python/tvm/relay/op/contrib/arm_compute_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from .register import register_pattern_table


def is_arm_compute_runtime_present():
def is_arm_compute_runtime_enabled():
"""Check if the ACL graph runtime is present.
Returns
Expand All @@ -43,7 +43,7 @@ def partition_for_arm_compute_lib(mod, params=None):
----------
mod : Module
The module to run passes on.
params : dict[str, NDArray]
params : Optional[Dict[str, NDArray]]
Constant input parameters.
Returns
Expand All @@ -53,15 +53,15 @@ def partition_for_arm_compute_lib(mod, params=None):
if params:
mod['main'] = bind_params_by_name(mod['main'], params)

seq = tvm.transform.Sequential([transform.MergeComposite(pattern_table()),
seq = tvm.transform.Sequential([transform.MergeComposite(arm_compute_lib_pattern_table()),
transform.AnnotateTarget('arm_compute_lib'),
transform.PartitionGraph()])

return seq(mod)


@register_pattern_table("arm_compute_lib")
def pattern_table():
def arm_compute_lib_pattern_table():
"""Get the ACL pattern table."""

def conv_pattern():
Expand Down
22 changes: 17 additions & 5 deletions src/relay/backend/contrib/arm_compute_lib/codegen.cc
Original file line number Diff line number Diff line change
Expand Up @@ -42,13 +42,13 @@ std::vector<JSONGraphNodeEntry> ACLJSONSerializer::VisitExpr_(const CallNode* cn
std::shared_ptr<JSONGraphNode> json_node;

if (cn->op.as<OpNode>()) {
json_node = CreateOp(cn);
json_node = CreateOpJSONNode(cn);
} else if (const auto* fn = cn->op.as<FunctionNode>()) {
auto comp = fn->GetAttr<String>(attr::kComposite);
CHECK(comp.defined()) << "Arm Compute Library JSON runtime only supports composite functions.";
name = comp.value();
if (name == "arm_compute_lib.conv2d") {
json_node = CreateCompositeConvolution(cn);
json_node = CreateCompositeConvJSONNode(cn);
} else {
LOG(FATAL) << "Unrecognized Arm Compute Library pattern: " << name;
}
Expand All @@ -65,7 +65,7 @@ std::vector<JSONGraphNodeEntry> ACLJSONSerializer::VisitExpr_(const ConstantNode
return JSONSerializer::VisitExpr_(cn);
}

std::shared_ptr<JSONGraphNode> ACLJSONSerializer::CreateOp(const CallNode* cn) {
std::shared_ptr<JSONGraphNode> ACLJSONSerializer::CreateOpJSONNode(const CallNode* cn) {
const auto* op = cn->op.as<OpNode>();
CHECK(op);
const std::string name = op->name;
Expand All @@ -81,7 +81,7 @@ std::shared_ptr<JSONGraphNode> ACLJSONSerializer::CreateOp(const CallNode* cn) {
return json_node;
}

std::shared_ptr<JSONGraphNode> ACLJSONSerializer::CreateCompositeConvolution(const CallNode* cn) {
std::shared_ptr<JSONGraphNode> ACLJSONSerializer::CreateCompositeConvJSONNode(const CallNode* cn) {
const std::string name = "arm_compute_lib.conv2d";
const CallNode* pad = nullptr;
const CallNode* conv;
Expand Down Expand Up @@ -162,7 +162,7 @@ IRModule PreProcessModule(const IRModule& mod) {
runtime::Module ACLCompiler(const ObjectRef& ref) {
CHECK(ref->IsInstance<FunctionNode>()) << "The input ref is expected to be a Relay function.";
Function func = Downcast<Function>(ref);
std::string func_name = GetExtSymbol(func);
std::string func_name = backend::GetExtSymbol(func);

IRModule mod;
mod->Add(GlobalVar(func_name), func);
Expand All @@ -182,6 +182,18 @@ runtime::Module ACLCompiler(const ObjectRef& ref) {
return lib;
}

TVM_REGISTER_GLOBAL("relay.ext.arm_compute_lib").set_body_typed(ACLCompiler);

inline constexpr bool IsACLRuntimeEnabled() {
#if TVM_GRAPH_RUNTIME_ARM_COMPUTE_LIB
return true;
#else
return false;
#endif
}

TVM_REGISTER_GLOBAL("relay.op.is_arm_compute_runtime_enabled").set_body_typed(IsACLRuntimeEnabled);

} // namespace arm_compute_lib
} // namespace contrib
} // namespace relay
Expand Down
29 changes: 3 additions & 26 deletions src/relay/backend/contrib/arm_compute_lib/codegen_acl.h
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,8 @@ class ACLJSONSerializer : public backend::contrib::JSONSerializer {
* \param call The call to be represented.
* \return A JSON representation of a specific operator.
*/
std::shared_ptr<JSONGraphNode> CreateOp(const CallNode* cn);
std::shared_ptr<JSONGraphNode> CreateCompositeConvolution(const CallNode* cn);
std::shared_ptr<JSONGraphNode> CreateOpJSONNode(const CallNode* cn);
std::shared_ptr<JSONGraphNode> CreateCompositeConvJSONNode(const CallNode* cn);

/* \brief Transposed constant tensors to serialize. Arm Compute Library expects constant tensors
* in OHWI format. */
Expand Down Expand Up @@ -106,34 +106,11 @@ IRModule PreProcessModule(const IRModule& mod);
*/
runtime::Module ACLCompiler(const ObjectRef& ref);

/*!
* \brief Get the external symbol of the Relay function name.
*
* \param func The provided function.
*
* \return An external symbol.
*/
std::string GetExtSymbol(const Function& func) {
const auto name_node = func->GetAttr<String>(tvm::attr::kGlobalSymbol);
CHECK(name_node.defined()) << "Fail to retrieve external symbol.";
return std::string(name_node.value());
}

TVM_REGISTER_GLOBAL("relay.ext.arm_compute_lib").set_body_typed(ACLCompiler);

/*!
* \brief Check whether ACL graph runtime is used.
* \return True if ACL graph runtime is enabled, False if not.
*/
inline constexpr bool IsACLRuntimeEnabled() {
#if TVM_GRAPH_RUNTIME_ACL
return true;
#else
return false;
#endif
}

TVM_REGISTER_GLOBAL("relay.op.is_arm_compute_runtime_enabled").set_body_typed(IsACLRuntimeEnabled);
inline constexpr bool IsACLRuntimeEnabled();

} // namespace arm_compute_lib
} // namespace contrib
Expand Down
13 changes: 0 additions & 13 deletions src/relay/backend/contrib/codegen_c/codegen_c.h
Original file line number Diff line number Diff line change
Expand Up @@ -61,19 +61,6 @@ class CSourceModuleCodegenBase {
* \return A runtime module.
*/
virtual runtime::Module CreateCSourceModule(const ObjectRef& ref) = 0;

/*!
* \brief Get the external symbol of the Relay function name.
*
* \param func The provided function.
*
* \return An external symbol.
*/
std::string GetExtSymbol(const Function& func) const {
const auto name_node = func->GetAttr<String>(tvm::attr::kGlobalSymbol);
CHECK(name_node.defined()) << "Fail to retrieve external symbol.";
return std::string(name_node.value());
}
};

// The base class to generate the declaration functions in C.
Expand Down
13 changes: 0 additions & 13 deletions src/relay/backend/contrib/dnnl/codegen.cc
Original file line number Diff line number Diff line change
Expand Up @@ -468,19 +468,6 @@ class DNNLJSONSerializer : public backend::contrib::JSONSerializer {
return AddNode(node, GetRef<Expr>(cn));
}
};

/*!
* \brief Get the external symbol of the Relay function name.
*
* \param func The provided function.
*
* \return An external symbol.
*/
std::string GetExtSymbol(const Function& func) {
const auto name_node = func->GetAttr<String>(tvm::attr::kGlobalSymbol);
CHECK(name_node.defined()) << "Fail to retrieve external symbol.";
return std::string(name_node.value());
}
#endif

/*!
Expand Down
12 changes: 12 additions & 0 deletions src/relay/backend/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,18 @@ inline const CallNode* GetRootCall(const CallNode* current_call, int depth,
return GetRootCall(next_call, depth - 1, expected_op_names);
}

/*!
* \brief Get the external symbol of the Relay function name.
*
* \param func The provided function.
* \return An external symbol.
*/
inline std::string GetExtSymbol(const Function& func) {
const auto name_node = func->GetAttr<String>(tvm::attr::kGlobalSymbol);
CHECK(name_node.defined()) << "Fail to retrieve external symbol.";
return std::string(name_node.value());
}

} // namespace backend
} // namespace relay
} // namespace tvm
Expand Down
9 changes: 6 additions & 3 deletions src/runtime/contrib/arm_compute_lib/acl_runtime.cc
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@ class ACLRuntime : public JSONRuntimeBase {
this->layer_.function->run();
#else
LOG(FATAL) << "Cannot call run on Arm Compute Library module without runtime enabled. "
<< "Please build with USE_ACL_GRAPH_RUNTIME.";
<< "Please build with USE_ARM_COMPUTE_LIB_GRAPH_RUNTIME.";
#endif
}

Expand Down Expand Up @@ -260,9 +260,14 @@ class ACLRuntime : public JSONRuntimeBase {
}
}

bool found_kernel_node = false;
for (size_t nid = 0; nid < nodes_.size(); ++nid) {
const auto& node = nodes_[nid];
if (found_kernel_node) {
LOG(FATAL) << "Arm Compute Library runtime module only supports one kernel node per function.";
}
if (node.GetOpType() == "kernel") {
found_kernel_node = true;
auto op_name = node.GetOpName();
if ("nn.conv2d" == op_name || "arm_compute_lib.conv2d" == op_name) {
CreateConvolution2DLayer(&layer_, node, mm);
Expand All @@ -274,8 +279,6 @@ class ACLRuntime : public JSONRuntimeBase {
} else {
LOG(FATAL) << "Unsupported op: " << op_name;
}
// Only expect one op for the time being
break;
}
}

Expand Down
10 changes: 5 additions & 5 deletions src/runtime/contrib/arm_compute_lib/acl_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -78,8 +78,8 @@ std::shared_ptr<arm_compute::MemoryManagerOnDemand> MakeMemoryManager() {

arm_compute::PadStrideInfo ToACLPadStride(const std::vector<std::string>& pad,
const std::vector<std::string>& stride) {
int pad_0, pad_1, pad_2, pad_3;

int pad_0 = 0, pad_1 = 0, pad_2 = 0, pad_3 = 0;
int stride_0 = std::stoi(stride[0]), stride_1 = std::stoi(stride[1]);
size_t size = pad.size();
if (size == 1) {
int pad_v = std::stoi(pad[0]);
Expand All @@ -103,11 +103,11 @@ arm_compute::PadStrideInfo ToACLPadStride(const std::vector<std::string>& pad,
pad_3 = std::stoi(pad[2]);
} else {
LOG(FATAL) << "Unsupported padding dimensions";
return arm_compute::PadStrideInfo();
}

return arm_compute::PadStrideInfo(std::stoi(stride[0]), std::stoi(stride[1]), pad_0, pad_1, pad_2,
pad_3, arm_compute::DimensionRoundingType::FLOOR);
return arm_compute::PadStrideInfo(stride_0, stride_1,
pad_0, pad_1, pad_2, pad_3,
arm_compute::DimensionRoundingType::FLOOR);
}

} // namespace arm_compute_lib
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def skip_runtime_test():
return True

# Remote device is in use or ACL runtime not present
if not Device.use_remote and not arm_compute_lib.is_arm_compute_runtime_present():
if not Device.use_remote and not arm_compute_lib.is_arm_compute_runtime_enabled():
print("Skip because runtime isn't present or a remote device isn't being used.")
return True

Expand Down

0 comments on commit 699b943

Please sign in to comment.