Skip to content

Commit

Permalink
[Hexagon] Add mobilenet test (apache#11104)
Browse files Browse the repository at this point in the history
* Add mobilenet test on Hexagon

* Address comments

* fix import and remove extra function
  • Loading branch information
mehrdadh authored and Sergey Shtin committed May 17, 2022
1 parent 0c8549f commit f99ef9c
Show file tree
Hide file tree
Showing 6 changed files with 178 additions and 40 deletions.
29 changes: 29 additions & 0 deletions python/tvm/contrib/hexagon/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,6 +257,35 @@ def get_graph_executor(
graph_mod = self.load_module(module_name, session)
return tvm.contrib.graph_executor.create(graph_json, graph_mod, session.device)

def get_graph_debug_executor(
self,
graph_json: str,
module_name: Union[str, pathlib.Path],
session: Session,
dump_root: Union[str, pathlib.Path] = None,
):
"""Create a local GraphModuleDebug which consumes a remote libmod.
Parameters
----------
graph_json : str
The string with the graph JSON.
module_name : str or pathlib.Path
Remote module filename. Same restrictions apply as in load_module().
session : Session
Remote session. The session must be established (via __enter__)
prior to calling this function.
Returns
-------
GraphModuleDebug :
Runtime debug graph module that can be used to debug the graph.
"""
graph_mod = self.load_module(module_name, session)
return tvm.contrib.debugger.debug_executor.create(
graph_json, graph_mod, session.device, dump_root=str(dump_root)
)

def get_aot_executor(self, module_name: Union[str, pathlib.Path], session: Session):
"""Create a local AoTModule which consumes a remote libmod.
Expand Down
75 changes: 48 additions & 27 deletions python/tvm/relay/op/strategy/hexagon.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
from .generic import *
from .. import op as _op


# --- Op strategy registration


Expand All @@ -44,27 +43,49 @@ def conv2d_strategy_hexagon(attrs, inputs, out_type, target):
strategy = _op.OpStrategy()
data_layout = attrs.data_layout
kernel_layout = attrs.kernel_layout
groups = attrs.groups
data, kernel = inputs
layout = attrs.data_layout

if groups == 1:
if data_layout == "NHWC" and kernel_layout == "HWIO":
strategy.add_implementation(
wrap_compute_conv2d(topi.nn.conv2d_nhwc),
wrap_topi_schedule(topi.hexagon.schedule_conv2d_nhwc),
name="conv2d_nhwc.hexagon",
)
elif data_layout == "NCHW" and kernel_layout == "OIHW":
strategy.add_implementation(
wrap_compute_conv2d(topi.nn.conv2d_nchw),
wrap_topi_schedule(topi.hexagon.schedule_conv2d_nchw),
name="conv2d_nchw.hexagon",
)
else:
raise RuntimeError(
f"Unsupported layouts: data_layout:{data_layout}, kernel_layout:{kernel_layout}, "
f"groups:{attrs.groups}"
)
elif is_depthwise_conv2d(data.shape, layout, kernel.shape, kernel_layout, groups):
if layout == "NCHW":
assert kernel_layout == "OIHW"
strategy.add_implementation(
wrap_compute_conv2d(topi.nn.depthwise_conv2d_nchw),
wrap_topi_schedule(topi.hexagon.schedule_depthwise_conv2d_nchw),
name="depthwise_conv2d_nchw.generic",
)
elif layout == "NHWC":
assert kernel_layout == "HWOI"
strategy.add_implementation(
wrap_compute_conv2d(topi.nn.depthwise_conv2d_nhwc),
wrap_topi_schedule(topi.hexagon.schedule_depthwise_conv2d_nhwc),
name="depthwise_conv2d_nhwc.generic",
)
else:
raise RuntimeError("Unsupported depthwise_conv2d layout {}".format(layout))
else: # group_conv2d
raise RuntimeError(f"Unsupported group_conv2d layout {layout}")

if data_layout == "NHWC" and kernel_layout == "HWIO":
strategy.add_implementation(
wrap_compute_conv2d(topi.nn.conv2d_nhwc),
wrap_topi_schedule(topi.hexagon.schedule_conv2d_nhwc),
name="conv2d_nhwc.hexagon",
)
return strategy

if data_layout == "NCHW" and kernel_layout == "OIHW":
strategy.add_implementation(
wrap_compute_conv2d(topi.nn.conv2d_nchw),
wrap_topi_schedule(topi.hexagon.schedule_conv2d_nchw),
name="conv2d_nchw.hexagon",
)
return strategy

raise RuntimeError(
f"Unsupported layouts: data_layout:{data_layout}, kernel_layout:{kernel_layout}, "
f"groups:{attrs.groups}"
)
return strategy


@dense_strategy.register("hexagon")
Expand Down Expand Up @@ -101,16 +122,16 @@ def schedule_adaptive_pool_hexagon(attrs, outs, target):
return topi.hexagon.schedule_adaptive_pool(outs)


@schedule_concatenate.register("hexagon")
def schedule_concatenate_hexagon(attrs, outs, target):
"""Schedule concatenate ops for Hexagon"""
@schedule_injective.register("hexagon")
def schedule_injective_hexagon(attrs, outs, target):
"""Schedule injective ops for Hexagon"""
with target:
return topi.hexagon.schedule_injective(outs)


@schedule_injective.register("hexagon")
def schedule_injective_hexagon(attrs, outs, target):
"""Schedule injective ops for Hexagon"""
@schedule_concatenate.register("hexagon")
def schedule_concatenate_hexagon(attrs, outs, target):
"""Schedule concatenate ops for Hexagon"""
with target:
return topi.hexagon.schedule_injective(outs)

Expand Down
8 changes: 8 additions & 0 deletions python/tvm/topi/hexagon/conv2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,3 +52,11 @@ def schedule_conv2d(outs, layout="NHWC"):
return schedule_conv2d_nchw(outs)

raise ValueError(f"Unexpected layout={layout}")


def schedule_depthwise_conv2d_nchw(outs):
return schedule_conv2d_nchw(outs)


def schedule_depthwise_conv2d_nhwc(out):
return schedule_conv2d_nhwc(out)
8 changes: 8 additions & 0 deletions python/tvm/topi/hexagon/injective.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,3 +42,11 @@ def schedule_injective(outs):

def schedule_softmax(outs):
return schedule_injective(outs)


def schedule_elemwise(outs):
return schedule_injective(outs)


def schedule_broadcast(outs):
return schedule_injective(outs)
13 changes: 0 additions & 13 deletions tests/python/contrib/test_hexagon/test_launcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
# specific language governing permissions and limitations
# under the License.

import os
import sys
import pytest
import numpy as np
Expand All @@ -24,7 +23,6 @@
from tvm import te
from tvm import relay
from tvm.relay.backend import Executor, Runtime
import tvm.contrib.hexagon as hexagon

from .conftest import requires_hexagon_toolchain

Expand Down Expand Up @@ -256,17 +254,6 @@ def test_graph_executor_multiple_conv2d(hexagon_session):
tvm.testing.assert_allclose(hexagon_output, expected_output, rtol=1e-4, atol=1e-5)


def _workaround_create_aot_shared():
# The C codegen uses TVM/RT functions directly. On Hexagon it should use
# functions pointers via __TVMxyz variables. This workaround makes the
# runtime symbols visible to the compiled shared library.
extra_link_flags = os.environ.get("HEXAGON_SHARED_LINK_FLAGS")
extra_options = str(extra_link_flags).split() if extra_link_flags else []
return lambda so_name, files, hexagon_arch, options: hexagon.create_aot_shared(
so_name, files, hexagon_arch, options=extra_options + options
)


@requires_hexagon_toolchain
def test_aot_executor(hexagon_session, aot_host_target, aot_target):
dtype = "float32"
Expand Down
85 changes: 85 additions & 0 deletions tests/python/contrib/test_hexagon/test_models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

import os
import sys
import pytest
import numpy as np

import tvm.testing
from tvm import te
from tvm import relay
from tvm.relay.backend import Executor, Runtime

from .conftest import requires_hexagon_toolchain


@requires_hexagon_toolchain
def test_mobilenet(hexagon_session):
import onnx

dtype = "float32"
model_url = "https://github.com/onnx/models/raw/main/vision/classification/mobilenet/model/mobilenetv2-7.onnx"
model_path = tvm.contrib.download.download_testdata(
model_url, "mobilenetv2-7.onnx", module="onnx"
)
onnx_model = onnx.load(model_path)

target_hexagon = tvm.target.hexagon("v68")
target_llvm = tvm.target.Target("llvm")
runtime = Runtime("cpp")
executor = Executor("graph", {"link-params": True})

data_in = np.random.rand(1, 3, 224, 224).astype(dtype=dtype)

input_name = "input"
shape_dict = {input_name: data_in.shape}
relay_mod, params = relay.frontend.from_onnx(onnx_model, shape_dict, freeze_params=True)
inputs = {input_name: data_in}

with tvm.transform.PassContext(opt_level=3):
hexagon_lowered = tvm.relay.build(
relay_mod,
tvm.target.Target(target_hexagon, host=target_hexagon),
runtime=runtime,
executor=executor,
params=params,
)

llvm_lowered = tvm.relay.build(
relay_mod,
tvm.target.Target(target_llvm, host=target_llvm),
runtime=runtime,
executor=executor,
params=params,
)

graph_mod = hexagon_session.get_executor_from_factory(hexagon_lowered)
graph_mod.set_input(**inputs)
graph_mod.run()
hexagon_output = graph_mod.get_output(0).numpy()

llvm_graph_mod = tvm.contrib.graph_executor.GraphModule(llvm_lowered["default"](tvm.cpu(0)))
llvm_graph_mod.set_input(**inputs)
llvm_graph_mod.run()
expected_output = llvm_graph_mod.get_output(0).numpy()

tvm.testing.assert_allclose(hexagon_output, expected_output, rtol=1e-4, atol=1e-5)


if __name__ == "__main__":
sys.exit(pytest.main(sys.argv))

0 comments on commit f99ef9c

Please sign in to comment.