[CLML] Version compatibility and various test cases (#13670)
* [CLML][TEST] Codegen test cases for ops

Codegen verification test cases for all the ops (convolution, concat, pad, pool, etc.)
that are supported by the CLML BYOC path.

Fix a depthwise conv2d layout issue.

* Fix lint errors

* Version compatibility changes

* Address review comments

* Make the Adreno container compatible with and without CLML SDK availability

Co-authored-by: Siva Rama Krishna Reddy B <sivb@qti.qualcomm.com>
srkreddy1238 and Siva Rama Krishna Reddy B authored Jan 3, 2023
1 parent e5a7f5f commit b6851f3
Showing 9 changed files with 482 additions and 92 deletions.
16 changes: 15 additions & 1 deletion cmake/modules/contrib/CLML.cmake
@@ -22,7 +22,21 @@ if(USE_CLML)
if(NOT USE_CLML_GRAPH_EXECUTOR)
list(APPEND COMPILER_SRCS ${CLML_RUNTIME_MODULE})
endif()
message(STATUS "Build with CLML support...")
message(STATUS "Build with CLML support : " ${USE_CLML})
if (NOT USE_CLML STREQUAL "ON")
set(CLML_VERSION_HEADER "${USE_CLML}/CL/cl_qcom_ml_ops.h")
if(EXISTS ${CLML_VERSION_HEADER})
file(READ ${CLML_VERSION_HEADER} ver)
string(REGEX MATCH "CL_QCOM_ML_OPS_H_MAJOR_VERSION ([0-9]*)" _ ${ver})
set(CLML_VERSION_MAJOR ${CMAKE_MATCH_1})
else()
set(CLML_VERSION_MAJOR "2")
endif()
else()
set(CLML_VERSION_MAJOR "2")
endif()
add_definitions(-DTVM_CLML_VERSION=${CLML_VERSION_MAJOR})
message(STATUS "CLML SDK Version :" ${CLML_VERSION_MAJOR})
endif()

if(USE_CLML_GRAPH_EXECUTOR)
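
A side note on consuming the new flag: the `TVM_CLML_VERSION` definition added above is surfaced through `tvm.support.libinfo()`, which is what the new `clml_sdk_version()` helper below relies on. A minimal sketch of reading it back by hand, assuming TVM was configured with a CLML SDK path such as `-DUSE_CLML=/path/to/clml-sdk` (hypothetical path):

    # Sketch: query the CLML SDK major version TVM was compiled against.
    # Falls back to 2, mirroring the CMake default above.
    import tvm

    version = int(tvm.support.libinfo().get("TVM_CLML_VERSION", 2))
    print("CLML SDK major version:", version)
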
58 changes: 32 additions & 26 deletions python/tvm/relay/op/contrib/clml.py
@@ -28,6 +28,12 @@
from ..strategy.generic import is_depthwise_conv2d


def clml_sdk_version():
"""Utility function to get clml version version"""

return tvm.support.libinfo().get("TVM_CLML_VERSION", 2)


def is_clml_runtime_enabled():
"""Check if the CLML graph runtime is present.
@@ -92,38 +98,35 @@ def preprocess_module(mod):
preprocessed_mod : The processed module.
"""

def convert_layout_conv2d(conv2d_function):
def convert_conv(attrs, inputs, tinfos, desired_layouts):
new_attrs = dict(attrs)
data_info = tinfos[0]
weight_info = tinfos[1]
desired_data_layout, desired_kernel_layout = map(str, desired_layouts)
new_attrs["data_layout"] = desired_data_layout
new_attrs["kernel_layout"] = desired_kernel_layout

if is_depthwise_conv2d(
data_info.shape,
attrs["data_layout"],
weight_info.shape,
attrs["kernel_layout"],
attrs["groups"],
):
dkl = desired_kernel_layout
new_attrs["kernel_layout"] = dkl[1] + dkl[0] + dkl[2] + dkl[3]
return conv2d_function(*inputs, **new_attrs)

return convert_conv

with OpAttrContext(
"nn.conv2d", "FTVMConvertOpLayout", convert_layout_conv2d(tvm.relay.nn.conv2d)
):
def alter_conv(attrs, inputs, tinfos, out_type):
new_attrs = dict(attrs)
data_info = tinfos[0]
weight_info = tinfos[1]
(desired_data_layout, desired_kernel_layout) = ("NCHW", "OIHW")
new_attrs["data_layout"] = desired_data_layout
new_attrs["kernel_layout"] = desired_kernel_layout

if is_depthwise_conv2d(
data_info.shape,
attrs["data_layout"],
weight_info.shape,
attrs["kernel_layout"],
attrs["groups"],
):
dkl = desired_kernel_layout
new_attrs["kernel_layout"] = dkl[1] + dkl[0] + dkl[2] + dkl[3]
return relay.nn.conv2d(*inputs, **new_attrs)

with OpAttrContext("nn.conv2d", "FTVMAlterOpLayout", alter_conv):
seq = tvm.transform.Sequential(
[
transform.ConvertLayout({"nn.conv2d": ["NCHW", "OIHW"]}),
transform.AlterOpLayout(),
transform.FoldConstant(),
]
)
preprocessed_mod = seq(mod)
with tvm.transform.PassContext(opt_level=3):
preprocessed_mod = seq(mod)
return preprocessed_mod


@@ -275,6 +278,9 @@ def check_default_op(extract):
("clml.add", is_op("add")(wildcard(), wildcard()), check_binary_op),
("clml.subtract", is_op("subtract")(wildcard(), wildcard()), check_binary_op),
("clml.multiply", is_op("multiply")(wildcard(), wildcard()), check_binary_op),
("clml.divide", is_op("divide")(wildcard(), wildcard()), check_binary_op),
("clml.minimum", is_op("minimum")(wildcard(), wildcard()), check_binary_op),
("clml.maximum", is_op("maximum")(wildcard(), wildcard()), check_binary_op),
("clml.softmax", is_op("nn.softmax")(wildcard()), check_softmax_op),
("clml.reshape", is_op("reshape")(wildcard()), check_default_op),
("clml.avg_pool2d", is_op("nn.avg_pool2d")(wildcard()), check_default_op),
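
For illustration, a minimal sketch that pushes one of the newly whitelisted binary ops (`divide`) through the CLML partitioner; shapes and names are arbitrary. The `alter_conv` hook above is also where the depthwise fix lives: for depthwise conv2d the desired kernel layout string is permuted from OIHW to IOHW via `dkl[1] + dkl[0] + dkl[2] + dkl[3]`.

    # Sketch: partition a lone divide op for CLML (arbitrary shapes).
    import tvm
    from tvm import relay
    from tvm.relay.op.contrib import clml

    a = relay.var("a", shape=(1, 3, 8, 8), dtype="float32")
    b = relay.var("b", shape=(1, 3, 8, 8), dtype="float32")
    mod = tvm.IRModule.from_expr(relay.divide(a, b))

    mod = clml.partition_for_clml(mod)
    print(mod)  # divide should now sit inside a clml-tagged composite function
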
2 changes: 1 addition & 1 deletion src/relay/backend/contrib/clml/codegen.cc
@@ -328,7 +328,7 @@ class CLMLJSONSerializer : public backend::contrib::JSONSerializer {
const auto* dense = fn->body.as<CallNode>();
const CallNode* bias = nullptr;

if (backend::IsOp(dense, "add")) {
if (backend::IsOp(dense, "add") || backend::IsOp(dense, "nn.bias_add")) {
bias = dense;
dense = dense->args[0].as<CallNode>();
}
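
The change above lets the serializer fold an `nn.bias_add` as readily as a plain `add` when it trails a dense op. A hedged Relay sketch of the pattern now accepted:

    # Sketch: dense followed by nn.bias_add, the shape the CLML
    # serializer now fuses as bias (previously only `add` matched).
    from tvm import relay

    x = relay.var("x", shape=(1, 16), dtype="float32")
    w = relay.var("w", shape=(8, 16), dtype="float32")
    b = relay.var("b", shape=(8,), dtype="float32")
    y = relay.nn.bias_add(relay.nn.dense(x, w), b)
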
38 changes: 26 additions & 12 deletions src/runtime/contrib/clml/clml_runtime.cc
@@ -153,13 +153,25 @@ class CLMLRuntime : public JSONRuntimeBase {
ICHECK(result == CL_SUCCESS) << "clQueryMLInterfaceVersionsQCOM:" << result;

for (cl_uint i = 0; i < numVersions; ++i) {
#if CL_QCOM_ML_OPS_H_MAJOR_VERSION == 2
if (majorVersions[i] == 2) {
LOG(WARNING) << "CLML Version Selected:" << majorVersions[i] << " : " << majorVersions[i];
h_ClmlIntf = clGetMLInterfaceV2QCOM(0);
ICHECK(h_ClmlIntf != NULL) << "clGetMLInterfaceV2QCOM:" << result;
LOG(WARNING) << "CLML Target version:" << majorVersions[i];
break;
}
#endif
#if CL_QCOM_ML_OPS_H_MAJOR_VERSION == 3
if (majorVersions[i] == 3) {
h_ClmlIntf = clGetMLInterfaceV3QCOM(0);
LOG(WARNING) << "CLML Target version:" << majorVersions[i];
break;
}
#endif
}
ICHECK(h_ClmlIntf != NULL)
<< "clGetMLInterfaceVxQCOM:" << result
<< " Perhaps there is mispatch between CLML SDK version to target supported version:"
<< majorVersions[numVersions - 1];
char* tune_flag;
if ((tune_flag = getenv("CLML_IS_TUNNING_RUN")))
this->is_tuning_run = std::stoi(tune_flag);
@@ -400,7 +412,7 @@ class CLMLRuntime : public JSONRuntimeBase {
this->layer_.storage_map.insert({nid, std::make_pair(out, node)});
this->layer_.func_outs.push_back(out);
} else if ("add" == op_name || "subtract" == op_name || "multiply" == op_name ||
"minimum" == op_name || "maximum" == op_name) {
"minimum" == op_name || "maximum" == op_name || "divide" == op_name) {
auto out = CreateBinaryLayer(&layer_, node);
this->layer_.storage_map.insert({nid, std::make_pair(out, node)});
this->layer_.func_outs.push_back(out);
@@ -523,16 +535,15 @@
}

cl_ml_tensor_qcom DeviceMakeCLMLTensor(
void* pClmlIntf, cl_context context, tensor_dims_t dims,
cl_context context, tensor_dims_t dims,
cl_ml_tensor_layout_qcom layout = CL_TENSOR_LAYOUT_OPTIMAL_QCOM,
cl_channel_type dtype = CL_FLOAT) {
cl_ml_tensor_qcom tensor;
cl_int result = CL_OUT_OF_RESOURCES;

cl_ml_tensor_desc_qcom desc = {
dtype, layout, dims.n, dims.c, dims.h, dims.w, 0, CL_TENSOR_DIMENSIONS_4D_QCOM, { 0 }};
CLMLInterfaceV2QCOM* clmlIntf = reinterpret_cast<CLMLInterfaceV2QCOM*>(pClmlIntf);
result = clmlIntf->clCreateMLTensorQCOM(workspace->context, NULL, &desc, &tensor);
result = h_ClmlIntf->clCreateMLTensorQCOM(workspace->context, NULL, &desc, &tensor);
ICHECK(tensor && result == CL_SUCCESS) << "clCreateMLTensorQCOM:" << result;
(void)result;
return tensor;
@@ -544,9 +555,8 @@
cl_int result = CL_OUT_OF_HOST_MEMORY;
cl_mem buffer = NULL;

CLMLInterfaceV2QCOM* clmlIntf = reinterpret_cast<CLMLInterfaceV2QCOM*>(pClmlIntf);
result =
clmlIntf->clGetMLTensorMemorySizeQCOM(workspace->context, pTensorMemDesc->tensor, &size);
h_ClmlIntf->clGetMLTensorMemorySizeQCOM(workspace->context, pTensorMemDesc->tensor, &size);
ICHECK(result == CL_SUCCESS) << "clGetMLTensorMemorySizeQCOM:" << result;

buffer = clCreateBuffer(workspace->context, CL_MEM_READ_WRITE, size, NULL, &result);
@@ -612,8 +622,7 @@
cl_channel_type cl_dtype = MakeCLDataType(tvm_dtype);

auto tensor_dsc = std::make_shared<cl_ml_tensor_memory_desc_qcom>();
tensor_dsc->tensor =
DeviceMakeCLMLTensor(h_ClmlIntf, workspace->context, dims, layout, cl_dtype);
tensor_dsc->tensor = DeviceMakeCLMLTensor(workspace->context, dims, layout, cl_dtype);
return tensor_dsc;
}

@@ -901,7 +910,6 @@
auto input = MakeCLMLTensorFromJSONEntry(node.GetInputs()[0], {}, CL_TENSOR_LAYOUT_OPTIMAL_QCOM,
cl_dtype);
auto output = MakeCLMLTensorFromJSONNode(node, CL_TENSOR_LAYOUT_OPTIMAL_QCOM, cl_dtype);
auto in_dims = get_tensor_dims(nodes_[node.GetInputs()[0].id_]);

std::vector<std::string> windows = node.GetAttr<std::vector<std::string>>("pool_size");
std::vector<std::string> strides = node.GetAttr<std::vector<std::string>>("strides");
@@ -1103,7 +1111,6 @@
cl_channel_type cl_dtype = MakeCLDataType(tvm_dtype);
cl_arithmetic_mode_qcom cl_arithmetic_mode = MakeCLArithMode(cl_dtype);
int inputSize = input_.size();
int axis = std::stoi(node.GetAttr<std::vector<std::string>>("axis")[0]);
auto output = MakeCLMLTensorFromJSONNode(node, CL_TENSOR_LAYOUT_OPTIMAL_QCOM, cl_dtype);
cl_ml_tensor_qcom* concatInputs = new cl_ml_tensor_qcom[inputSize];
for (int i = 0; i < inputSize; i++) {
@@ -1236,6 +1243,8 @@
binary_op = CL_TENSOR_OP_SUB_QCOM;
else if (op_name == "multiply")
binary_op = CL_TENSOR_OP_MUL_QCOM;
else if (op_name == "divide")
binary_op = CL_TENSOR_OP_DIV_QCOM;
else if (op_name == "minimum")
binary_op = CL_TENSOR_OP_MIN_QCOM;
else if (op_name == "maximum")
@@ -1260,7 +1269,12 @@

CachedLayer layer_;
// CLML Context
#if CL_QCOM_ML_OPS_H_MAJOR_VERSION == 2
CLMLInterfaceV2QCOM* h_ClmlIntf = NULL;
#endif
#if CL_QCOM_ML_OPS_H_MAJOR_VERSION == 3
CLMLInterfaceV3QCOM* h_ClmlIntf = NULL;
#endif
cl::OpenCLWorkspace* workspace = NULL;
cl::OpenCLThreadEntry* tentry = NULL;
cl_ml_tuningcache_qcom tuning_cache = NULL;
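
Separate from the version-gated interface selection, the runtime code shown in the first hunk keys tuning mode off an environment variable; a small usage sketch:

    # Sketch: request a CLML tuning run before executing a module.
    # The variable name (including its spelling) comes from clml_runtime.cc above.
    import os

    os.environ["CLML_IS_TUNNING_RUN"] = "1"
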
58 changes: 47 additions & 11 deletions tests/python/contrib/test_clml/infrastructure.py
@@ -39,9 +39,9 @@ class Device:
Configuration for CLML tests.
Check tests/python/contrib/clml/ for the presence of a test_config.json file.
This file can be used to override the default configuration here which will attempt to run the Arm
Compute Library runtime tests locally if the runtime is available. Changing the configuration
will allow these runtime tests to be offloaded to a remote Arm device via a tracker for example.
This file can be used to override the default configuration here which will attempt to run the
Open CLML runtime tests locally if the runtime is available. Changing the configuration
will allow these runtime tests to be offloaded to a remote Snapdragon device via a tracker for example.
Notes
-----
@@ -101,6 +101,25 @@ def _get_remote(cls):
return device


def get_cpu_op_count(mod):
"""Traverse graph counting ops offloaded to TVM."""

class Counter(tvm.relay.ExprVisitor):
def __init__(self):
super().__init__()
self.count = 0

def visit_call(self, call):
if isinstance(call.op, tvm.ir.Op):
self.count += 1

super().visit_call(call)

c = Counter()
c.visit(mod["main"])
return c.count


def skip_codegen_test():
"""Skip test if it requires the CLML codegen and it's not present."""
if not tvm.get_global_func("relay.ext.clml", True):
@@ -130,7 +149,6 @@ def build_and_run(

try:
libm = build_module(mod, device.target, device.target_host, params, enable_clml, tune_log)

clml_modules = extract_clml_modules(libm)
for mod in clml_modules:
source = mod.get_source("json")
@@ -155,9 +173,9 @@
for _ in range(no_runs):
gen_module.run()
out.append([gen_module.get_output(i) for i in range(outputs)])
time_f = gen_module.module.time_evaluator("run", device.device.cl(0), number=1)
cost = time_f().mean
print("%g secs/iteration\n" % cost)
# time_f = gen_module.module.time_evaluator("run", device.device.cl(0), number=1)
# cost = time_f().mean
# print("%g secs/iteration\n" % cost)
return out


Expand All @@ -181,16 +199,34 @@ def extract_clml_modules(module):


def verify_codegen(
module,
mod,
known_good_codegen,
device,
params,
num_clml_modules=1,
tvm_ops=0,
target="llvm -mtriple=aarch64-linux-gnu",
):
"""Check clml codegen against a known good output."""
module = build_module(module, target, tvm_ops=tvm_ops, clml_partitions=num_clml_modules)
clml_modules = extract_clml_modules(module)
if isinstance(mod, tvm.relay.expr.Call):
mod = tvm.IRModule.from_expr(mod)
with tvm.transform.PassContext(opt_level=3, disabled_pass=["AlterOpLayout"]):
mod = clml.partition_for_clml(mod, params)
tvm_op_count = get_cpu_op_count(mod)
assert tvm_op_count == tvm_ops, "Got {} TVM operators, expected {}".format(
tvm_op_count, tvm_ops
)
partition_count = 0
for global_var in mod.get_global_vars():
if "clml" in global_var.name_hint:
partition_count += 1

assert (
num_clml_modules == partition_count
), "Got {} Open CLML partitions, expected {}".format(partition_count, num_clml_modules)
relay.backend.te_compiler.get().clear()

module = relay.build(mod, target=device.target, target_host=device.target_host, params=params)
clml_modules = extract_clml_modules(module)
assert len(clml_modules) == num_clml_modules, (
f"The number of CLML modules produced ({len(clml_modules)}) does not "
f"match the expected value ({num_clml_modules})."
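
A hedged sketch of calling the reworked `verify_codegen` entry point, which now takes the raw Relay module plus device and params and partitions for CLML internally. `Device()` comes from this infrastructure module, `known_good` stands in for a previously captured codegen JSON string, and the example assumes a lone reshape (which is in the CLML pattern table) is offloaded whole:

    # Sketch: verify codegen for a single reshape op.
    # known_good is a placeholder for the expected codegen JSON.
    import tvm
    from tvm import relay

    x = relay.var("x", shape=(1, 3, 8, 8), dtype="float32")
    mod = tvm.IRModule.from_expr(relay.reshape(x, (1, -1)))

    device = Device()   # configured locally or via test_config.json
    known_good = "..."  # placeholder: expected codegen JSON
    verify_codegen(mod, known_good, device, params={}, num_clml_modules=1, tvm_ops=0)
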
15 changes: 4 additions & 11 deletions tests/python/contrib/test_clml/test_network.py
@@ -91,13 +91,8 @@ def get_model():
mod, params, inputs, input_data, device=device, atol=1e-5, rtol=1e-5
)

# test
print("OpenCL:", outputs[0].asnumpy().shape)
print("CLML:", outputs[1].asnumpy().shape)

opencl_sort = np.argsort(outputs[1].asnumpy()).flatten()
clml_sort = np.argsort(outputs[0].asnumpy()).flatten()

tvm.testing.assert_allclose(opencl_sort[:10], clml_sort[:10], rtol=1e-5, atol=1e-5)


@@ -134,7 +129,6 @@ def get_model():

opencl_sort = np.argsort(outputs[1].asnumpy()).flatten()
clml_sort = np.argsort(outputs[0].asnumpy()).flatten()

tvm.testing.assert_allclose(opencl_sort[:5], clml_sort[:5], rtol=1e-5, atol=1e-5)


@@ -176,11 +170,10 @@ def get_model():
mod, params, inputs, input_data, device=device, atol=1e-5, rtol=1e-5
)

# test
print("OpenCL:", outputs[0].asnumpy().shape)
print("CLML:", outputs[1].asnumpy().shape)

opencl_sort = np.argsort(outputs[1].asnumpy()).flatten()
clml_sort = np.argsort(outputs[0].asnumpy()).flatten()

tvm.testing.assert_allclose(opencl_sort[:10], clml_sort[:10], rtol=1e-5, atol=1e-5)


if __name__ == "__main__":
tvm.testing.main()