[OpenCLML] More ops and network coverage.
     * Enabled ResNet & InceptionV3
     * FP16 support added for CLML ops.

Co-authored-by: Krishna Raju <quic_kvegiraj@quicinc.com>
Co-authored-by: Shwetank Singh <quic_shwesing@quicinc.com>
3 people committed Sep 13, 2022
1 parent a23b71c commit 82aae33
Showing 6 changed files with 526 additions and 108 deletions.
35 changes: 33 additions & 2 deletions python/tvm/relay/op/contrib/clml.py
@@ -23,7 +23,7 @@
from tvm.relay import transform
from tvm.relay.build_module import bind_params_by_name

from ...dataflow_pattern import wildcard, is_op, is_constant, is_tuple_get_item
from ...dataflow_pattern import wildcard, is_op, is_constant, is_tuple_get_item, is_tuple
from .register import register_pattern_table
from ..strategy.generic import is_depthwise_conv2d

@@ -135,13 +135,15 @@ def conv_pattern():
"""Create a convolution pattern."""
pattern = is_op("nn.conv2d")(wildcard(), is_constant())
pattern = pattern.optional(lambda x: is_op("nn.bias_add")(x, is_constant()))
pattern = pattern.optional(lambda x: is_op("add")(x, is_constant()))
pattern = pattern.optional(
lambda x: is_op("nn.batch_norm")(
x, is_constant(), is_constant(), is_constant(), is_constant()
)
)
pattern = pattern.optional(is_tuple_get_item)
pattern = pattern.optional(is_op("nn.relu"))
pattern = pattern.optional(is_op("clip"))
return pattern

def batch_norm_pattern():
@@ -152,10 +154,24 @@ def batch_norm_pattern():
pattern = is_tuple_get_item(pattern)
return pattern

def concat_pattern():
"""Create a concat pattern.
Returns
-------
pattern : dataflow_pattern.AltPattern
Denotes the concat pattern.
"""
pattern = is_tuple(None)
pattern = is_op("concatenate")(pattern)

return pattern

def dense_pattern():
"""Create a dense pattern."""
pattern = is_op("nn.dense")(wildcard(), is_constant())
pattern = pattern.optional(lambda x: is_op("add")(x, is_constant()))
pattern = pattern.optional(lambda x: is_op("nn.bias_add")(x, is_constant()))
return pattern

def pad_pattern():
@@ -172,6 +188,13 @@ def check_conv(extract):
call = call.args[0]
if isinstance(call, tvm.relay.expr.TupleGetItem):
call = call.tuple_value
elif call.op.name == "clip":
if call.attrs["a_min"] != 0.0 or call.attrs["a_max"] != 6.0:
return False
call = call.args[0]
if isinstance(call, tvm.relay.expr.TupleGetItem):
call = call.tuple_value

while call.op.name != "nn.conv2d":
call = call.args[0]
attrs, args = call.attrs, call.args
@@ -194,6 +217,7 @@ def check_conv(extract):
("clml.conv2d", conv_pattern(), check_conv),
("clml.dense", dense_pattern()),
("clml.pad", pad_pattern()),
("clml.concat", concat_pattern()),
("clml.batch_norm", batch_norm_pattern()),
]

@@ -207,11 +231,18 @@ def _func_wrapper(expr):


_register_external_op_helper("clip")
_register_external_op_helper("relu")
_register_external_op_helper("nn.relu")
_register_external_op_helper("nn.global_avg_pool2d")
_register_external_op_helper("nn.global_max_pool2d")
_register_external_op_helper("nn.avg_pool2d")
_register_external_op_helper("nn.max_pool2d")
_register_external_op_helper("nn.softmax")
_register_external_op_helper("reshape")
_register_external_op_helper("add")
_register_external_op_helper("subtract")
_register_external_op_helper("multiply")
_register_external_op_helper("minimum")
_register_external_op_helper("maximum")


class OpAttrContext(object):
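
(Not part of this diff.) A minimal sketch of how the extended pattern table is exercised, assuming partition_for_clml, defined elsewhere in this file, is the entry point. Per check_conv above, a clip with a_min=0.0 and a_max=6.0 after the convolution should be absorbed into the clml.conv2d composite, and the float16 dtype ties in with the FP16 support added by this commit:

import numpy as np
import tvm
from tvm import relay
from tvm.relay.op.contrib.clml import partition_for_clml

# Build a small conv2d -> clip(0, 6) graph; check_conv treats this clip as relu6.
data = relay.var("data", shape=(1, 16, 56, 56), dtype="float16")
weight = relay.const(np.zeros((32, 16, 3, 3), dtype="float16"))
conv = relay.nn.conv2d(data, weight, kernel_size=(3, 3), padding=(1, 1))
out = relay.clip(conv, a_min=0.0, a_max=6.0)
mod = tvm.IRModule.from_expr(relay.Function([data], out))

# After partitioning, conv2d and clip should be fused into a single
# "clml.conv2d" composite function offloaded to the CLML runtime.
mod = partition_for_clml(mod)
print(mod)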
37 changes: 37 additions & 0 deletions src/relay/backend/contrib/clml/codegen.cc
@@ -91,6 +91,8 @@ class CLMLJSONSerializer : public backend::contrib::JSONSerializer {
json_node = CreateDenseJSONNode(cn);
} else if (name == "clml.pad") {
json_node = CreatePadJSONNode(cn);
} else if (name == "clml.concat") {
json_node = CreateConcatJSONNode(cn);
} else {
LOG(FATAL) << "Unrecognized CLML pattern: " << name;
}
@@ -148,6 +150,15 @@ class CLMLJSONSerializer : public backend::contrib::JSONSerializer {
} else {
current_call = current_call->args[0].as<CallNode>();
}
} else if (backend::IsOp(current_call, "clip")) {
nodes.activation = current_call;
nodes.act_type = "relu6";
if (current_call->args[0].as<TupleGetItemNode>()) {
auto tuple_item = current_call->args[0].as<TupleGetItemNode>();
current_call = tuple_item->tuple.as<CallNode>();
} else {
current_call = current_call->args[0].as<CallNode>();
}
}
if (backend::IsOp(current_call, "nn.batch_norm")) {
nodes.bn = current_call;
@@ -279,6 +290,32 @@ class CLMLJSONSerializer : public backend::contrib::JSONSerializer {
return json_node;
}

/*!
* \brief Create a JSON representation of a Concat operator.
*
* \param cn The call to be represented.
* \return A JSON representation of a specific operator.
*/
std::shared_ptr<JSONGraphNode> CreateConcatJSONNode(const CallNode* cn) {
const auto* fn = cn->op.as<FunctionNode>();
ICHECK(fn);
const auto* concat = fn->body.as<CallNode>();

ICHECK(backend::IsOp(concat, "concatenate"));
const auto* concat_op = concat->op.as<OpNode>();
ICHECK(concat_op);
const std::string name = concat_op->name;

std::vector<JSONGraphNodeEntry> inputs;
for (auto arg : cn->args) {
inputs.push_back(VisitExpr(arg)[0]);
}

auto json_node = std::make_shared<JSONGraphNode>(name, "kernel", inputs, 1);
SetCallNodeAttribute(json_node, concat);
return json_node;
}

/*!
* \brief Create a JSON representation of a Dense operator.
*
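
(Not part of this diff.) For reference, CreateConcatJSONNode above expects the partitioner to have produced a composite function tagged "clml.concat" whose body is a concatenate over the function's parameters. A rough Relay-level sketch of that shape, with illustrative variable names:

from tvm import relay

a = relay.var("a", shape=(1, 8, 28, 28))
b = relay.var("b", shape=(1, 8, 28, 28))
body = relay.concatenate([a, b], axis=1)
# The "Composite" attribute is what routes this function to CreateConcatJSONNode.
composite = relay.Function([a, b], body).with_attr("Composite", "clml.concat")
# The serializer visits each call argument as an input, emits one JSON "kernel"
# node named "concatenate", and copies the call's attributes (such as axis)
# onto it via SetCallNodeAttribute.
print(composite)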
(The remaining four changed files in this commit are not shown here.)
