Address comments
* correct mistakes in tutorial
* reshuffle runtime to use fewer macro blocks
* preprocess module using "optimize" functionality
* use new module api

Change-Id: I219488e617e5767edd7489b43b8bfce876cd24b8
lhutton1 committed Jul 16, 2020
1 parent 1af38dd commit 04399f3
Showing 14 changed files with 78 additions and 199 deletions.
19 changes: 11 additions & 8 deletions docs/deploy/arm_compute_lib.rst
@@ -39,13 +39,15 @@ runtime module on an x86 machine.

These flags can be used in different scenarios depending on your setup. For example, if you want
to compile ACL on an x86 machine and then run the module on a remote Arm device via RPC, you will
-need to use USE_ACL=ON on the x86 machine and USE_GRAPH_RUNTIME_ACL=ON on the remote AArch64
-device.
+need to use USE_ARM_COMPUTE_LIB=ON on the x86 machine and USE_ARM_COMPUTE_LIB_GRAPH_RUNTIME=ON on the remote
+AArch64 device.
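As a quick sanity check of which pieces a given build registered, the packed functions referenced elsewhere in this commit can be probed from Python. This is a small sketch, not part of the change; it assumes the two global names below (both taken from this diff) are only registered when the corresponding cmake flag was enabled.

.. code:: python

    import tvm

    # "relay.ext.arm_compute_lib.optimize" is registered by the ACL codegen
    # (USE_ARM_COMPUTE_LIB=ON); "runtime.arm_compute_lib_runtime_create" is
    # registered by the ACL runtime (USE_ARM_COMPUTE_LIB_GRAPH_RUNTIME=ON).
    has_codegen = tvm.get_global_func("relay.ext.arm_compute_lib.optimize", True) is not None
    has_runtime = tvm.get_global_func("runtime.arm_compute_lib_runtime_create", True) is not None
    print("ACL codegen available:", has_codegen)
    print("ACL runtime available:", has_runtime)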

Usage
-----

-*Note:* this section may not stay up-to-date with changes to the API.
+.. note::
+
+   This section may not stay up-to-date with changes to the API.

Create a relay graph. This may be a single operator or a whole graph. The intention is that any
relay graph can be input. The ACL integration will only pick supported operators to be offloaded
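For reference, a minimal sketch of this step (not part of the diff): it builds a single NHWC conv2d and partitions it with the helper from the contrib module touched below; the helper name and the exact shapes are assumptions here.

.. code:: python

    import numpy as np
    import tvm
    from tvm import relay
    from tvm.relay.op.contrib.arm_compute_lib import partition_for_arm_compute_lib

    # A single NHWC conv2d, the kind of operator the ACL integration can offload.
    data = relay.var("data", shape=(1, 14, 14, 32), dtype="float32")
    weight = relay.const(np.random.uniform(size=(3, 3, 32, 32)).astype("float32"))
    out = relay.nn.conv2d(data, weight, kernel_size=(3, 3), padding=(1, 1),
                          data_layout="NHWC", kernel_layout="HWIO")
    module = tvm.IRModule.from_expr(relay.Function([data], out))

    # Annotate and partition supported operators for ACL; anything unsupported
    # stays in the default TVM flow.
    module = partition_for_arm_compute_lib(module)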
@@ -84,7 +86,7 @@ Build the Relay graph.
target = "llvm -mtriple=aarch64-linux-gnu -mattr=+neon"
with tvm.transform.PassContext(opt_level=3, disabled_pass=["AlterOpLayout"]):
-json, lib, params = relay.build(module, target=target)
+lib = relay.build(module, target=target)
Export the module.
@@ -96,16 +98,17 @@ Export the module.
lib.export_library(lib_path, cc=cross_compile)
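For context (not part of the diff), the two names used above are placeholders along these lines, assuming an AArch64 GNU toolchain is installed for cross-compilation:

.. code:: python

    lib_path = "lib_acl.so"                  # placeholder output file name
    cross_compile = "aarch64-linux-gnu-gcc"  # placeholder cross compiler
    lib.export_library(lib_path, cc=cross_compile)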
-Run Inference. This must be on an Arm device. If compiling on x86 device and running on aarch64
+Run Inference. This must be on an Arm device. If compiling on x86 device and running on aarch64,
consider using the RPC mechanism.

.. code:: python
-tvm.runtime.load_module('lib_acl.so')
-gen_module = tvm.contrib.graph_runtime.create(json, lib, ctx)
+ctx = tvm.cpu(0)
+loaded_lib = tvm.runtime.load_module('lib_acl.so')
+gen_module = tvm.contrib.graph_runtime.GraphModule(loaded_lib['default'](ctx))
d_data = np.random.uniform(0, 1, data_shape).astype(data_type)
map_inputs = {'data': d_data}
-gen_module.map_inputs(**map_inputs)
+gen_module.set_input(**map_inputs)
gen_module.run()
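Since the text recommends the RPC mechanism for the x86-to-AArch64 case, here is a hedged sketch of that path (not part of the diff); the address, port, and input shape are placeholders and assume an RPC server is already running on the Arm device.

.. code:: python

    import numpy as np
    from tvm import rpc
    from tvm.contrib import graph_runtime

    # Placeholder address of a device running `python -m tvm.exec.rpc_server`.
    remote = rpc.connect("192.168.1.10", 9090)
    remote.upload("lib_acl.so")
    loaded_lib = remote.load_module("lib_acl.so")

    ctx = remote.cpu(0)
    gen_module = graph_runtime.GraphModule(loaded_lib["default"](ctx))
    d_data = np.random.uniform(0, 1, (1, 14, 14, 32)).astype("float32")  # placeholder shape
    gen_module.set_input("data", d_data)
    gen_module.run()
    out = gen_module.get_output(0)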
2 changes: 1 addition & 1 deletion python/tvm/relay/op/contrib/arm_compute_lib.py
@@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name, unused-argument
"""ACL library supported operators."""
"""Arm Compute Library supported operators."""
import tvm
from tvm.relay import transform
from tvm.relay.build_module import bind_params_by_name
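For readers unfamiliar with this module: operator support for a BYOC target is usually declared by attaching a "target.<name>" attribute to each supported op. The snippet below is an illustrative sketch of that pattern only, not the file's actual contents; the predicate body is an assumption.

.. code:: python

    import tvm

    @tvm.ir.register_op_attr("nn.conv2d", "target.arm_compute_lib")
    def conv2d(attrs, args):
        """Hypothetical predicate: offload conv2d only for float32 NHWC data."""
        return attrs.data_layout == "NHWC" and args[0].checked_type.dtype == "float32"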
24 changes: 5 additions & 19 deletions src/relay/backend/contrib/arm_compute_lib/codegen.cc
@@ -60,11 +60,6 @@ std::vector<JSONGraphNodeEntry> ACLJSONSerializer::VisitExpr_(const CallNode* cn
return AddNode(json_node, GetRef<Expr>(cn));
}

-std::vector<JSONGraphNodeEntry> ACLJSONSerializer::VisitExpr_(const ConstantNode* cn) {
-  this->constants_.push_back(cn->data);
-  return JSONSerializer::VisitExpr_(cn);
-}

std::shared_ptr<JSONGraphNode> ACLJSONSerializer::CreateOpJSONNode(const CallNode* cn) {
const auto* op = cn->op.as<OpNode>();
CHECK(op);
@@ -148,37 +143,28 @@ std::shared_ptr<JSONGraphNode> ACLJSONSerializer::CreateCompositeConvJSONNode(co
return json_node;
}

-Array<runtime::NDArray> ACLJSONSerializer::GetParamsData() { return constants_; }

IRModule PreProcessModule(const IRModule& mod) {
IRModule preprocessed_module;
-tvm::Map<String, Array<String>> desired_layouts = {
-    {"nn.conv2d", {String("NHWC"), String("OHWI")}}};
+tvm::Map<String, Array<String>> desired_layouts = {{"nn.conv2d", {"NHWC", "OHWI"}}};
preprocessed_module = transform::ConvertLayout(desired_layouts)(mod);
preprocessed_module = transform::FoldConstant()(preprocessed_module);
return preprocessed_module;
}

+TVM_REGISTER_GLOBAL("relay.ext.arm_compute_lib.optimize").set_body_typed(PreProcessModule);
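The same preprocessing can be written against the public Python pass API; a rough equivalent sketch (assuming the standard relay.transform passes) is:

.. code:: python

    import tvm
    from tvm import relay

    def preprocess_module(mod):
        """Rough Python analogue of PreProcessModule: convert conv2d to NHWC
        data / OHWI kernels, then fold the inserted layout transforms into the
        constants."""
        seq = tvm.transform.Sequential([
            relay.transform.ConvertLayout({"nn.conv2d": ["NHWC", "OHWI"]}),
            relay.transform.FoldConstant(),
        ])
        with tvm.transform.PassContext(opt_level=3):
            return seq(mod)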

runtime::Module ACLCompiler(const ObjectRef& ref) {
CHECK(ref->IsInstance<FunctionNode>()) << "The input ref is expected to be a Relay function.";
Function func = Downcast<Function>(ref);
std::string func_name = backend::GetExtSymbol(func);

-IRModule mod;
-mod->Add(GlobalVar(func_name), func);
-mod = PreProcessModule(mod);
-
-CHECK(mod->functions.size() == 1) << "Module should only contain single function";
-Function processed_func = Downcast<Function>(mod->functions.begin().operator*().second);
-
-ACLJSONSerializer serializer(func_name, processed_func);
+ACLJSONSerializer serializer(func_name, func);
serializer.serialize();
std::string graph_json = serializer.GetJSON();
auto param_names = serializer.GetParams();
-auto param_data = serializer.GetParamsData();
const auto* pf = runtime::Registry::Get("runtime.arm_compute_lib_runtime_create");
CHECK(pf != nullptr) << "Cannot find JSON runtime module to create";
-runtime::Module lib = (*pf)(func_name, graph_json, param_names, param_data);
+runtime::Module lib = (*pf)(func_name, graph_json, param_names);
return lib;
}

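Because PreProcessModule is now exposed through the registry (the TVM_REGISTER_GLOBAL above), it can also be called directly from Python; a small sketch, assuming TVM was built with the ACL codegen enabled:

.. code:: python

    import tvm
    from tvm import relay

    # A trivial module just to demonstrate the call; in practice this would be
    # the partitioned module produced by the tutorial above.
    x = relay.var("x", shape=(1, 4), dtype="float32")
    mod = tvm.IRModule.from_expr(relay.Function([x], relay.nn.relu(x)))

    # Registered by codegen.cc; returns None if the ACL codegen was not compiled in.
    optimize = tvm.get_global_func("relay.ext.arm_compute_lib.optimize", True)
    if optimize is not None:
        mod = optimize(mod)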
15 changes: 1 addition & 14 deletions src/relay/backend/contrib/arm_compute_lib/codegen_acl.h
@@ -55,15 +55,6 @@ class ACLJSONSerializer : public backend::contrib::JSONSerializer {
ACLJSONSerializer(const std::string& symbol, const Expr& expr) : JSONSerializer(symbol, expr) {}

std::vector<JSONGraphNodeEntry> VisitExpr_(const CallNode* cn) override;
-std::vector<JSONGraphNodeEntry> VisitExpr_(const ConstantNode* cn) override;
-
-/*!
- * \brief Get the constant data transposed when pre-processing the
- * input function.
- *
- * \return An array of constants
- */
-Array<runtime::NDArray> GetParamsData();

private:
/*!
@@ -74,10 +65,6 @@ class ACLJSONSerializer : public backend::contrib::JSONSerializer {
*/
std::shared_ptr<JSONGraphNode> CreateOpJSONNode(const CallNode* cn);
std::shared_ptr<JSONGraphNode> CreateCompositeConvJSONNode(const CallNode* cn);

-/* \brief Transposed constant tensors to serialize. Arm Compute Library expects constant tensors
- * in OHWI format. */
-Array<runtime::NDArray> constants_;
};

/*!
@@ -98,7 +85,7 @@ IRModule PreProcessModule(const IRModule& mod);
* one another. Each function consists of serialized JSON describing the sub-graph
* and serialized constant tensors.
*
* \note The ACL runtime module only currently supports a single operator per
* \note The ACL runtime module only supports a single operator per
* sub-graph currently.
*
* \param ref The ext_func Relay expression/module to be executed using extern ops.
3 changes: 1 addition & 2 deletions src/runtime/contrib/arm_compute_lib/acl_allocator.cc
@@ -64,8 +64,7 @@ ACLMemoryRegion::~ACLMemoryRegion() {
std::unique_ptr<arm_compute::IMemoryRegion> ACLMemoryRegion::extract_subregion(size_t offset,
size_t size) {
if (this->ptr_ != nullptr && (offset < _size) && (_size - offset >= size)) {
-return arm_compute::support::cpp14::make_unique<ACLMemoryRegion>(
-    static_cast<uint8_t*>(this->ptr_) + offset, size);
+return std::make_unique<ACLMemoryRegion>(static_cast<uint8_t*>(this->ptr_) + offset, size);
} else {
return nullptr;
}
1 change: 0 additions & 1 deletion src/runtime/contrib/arm_compute_lib/acl_allocator.h
@@ -28,7 +28,6 @@
#include <arm_compute/runtime/IAllocator.h>
#include <arm_compute/runtime/IMemoryRegion.h>
#include <arm_compute/runtime/MemoryRegion.h>
-#include <support/ToolchainSupport.h>
#include <tvm/runtime/data_type.h>
#include <tvm/runtime/device_api.h>
#include <tvm/runtime/registry.h>