Skip to content

Commit

Permalink
[microNPU][ETHOSU] MatMul legalization support (#15780)
Browse files Browse the repository at this point in the history
The NPU has a restriction that weights must be constant, so the matrix multiplication operation is expressed using split, elementwise multiplication, reduce sum, and concatenation operations.
  • Loading branch information
Aleksei-grovety authored Oct 13, 2023
1 parent 71caa19 commit fdfd16c
Show file tree
Hide file tree
Showing 4 changed files with 281 additions and 1 deletion.
98 changes: 97 additions & 1 deletion python/tvm/relay/backend/contrib/ethosu/legalize.py
Original file line number Diff line number Diff line change
Expand Up @@ -1401,6 +1401,101 @@ def callback(self, pre, post, node_map):
return ethosu_fc


class MatMulRewriter(DFPatternCallback):
    """Legalize an ethos-u.matmul composite function to NPU operators.

    The product is assembled from split, elementwise multiplication, sum
    pooling and concatenation operators rather than a single operator.
    """

    def __init__(self):
        super().__init__(require_type=True)
        composite_func = wildcard().has_attr(
            {"Composite": ethosu_patterns.MatMulParams.composite_name}
        )
        # The composite function is called with the two matrix operands.
        self.pattern = composite_func(wildcard(), wildcard())

    def callback(self, pre, post, node_map):
        """Build the legalized expression for a matched composite call.

        Parameters
        ----------
        pre : expression matched before rewriting (unused here).
        post : matched call; its two arguments are the matrix operands and
            ``post.op.body`` is parsed with ``MatMulParams``.
        node_map : pattern-to-expression map (unused here).

        Returns
        -------
        A reshape of the concatenated per-slice results to the OFM shape.
        """
        params = ethosu_patterns.MatMulParams(post.op.body)
        lhs, rhs = post.args[0], post.args[1]
        lut = relay.const([], dtype="int8")

        # Start from "no fused activation"; a parsed activation overrides it.
        activation, clip_min, clip_max = "NONE", 0, 0
        if params.activation:
            activation = {"clip": "CLIP"}[params.activation.op.name]
            clip_min = int(params.activation.attrs.a_min)
            clip_max = int(params.activation.attrs.a_max)

        # Values that are the same for every slice of the second operand.
        depth = params.ifm.shape[-1]
        ifm_zero_point = int(params.ifm.q_params.zero_point)
        weights_zero_point = int(params.weights.q_params.zero_point)
        ofm_zero_point = int(params.ofm.q_params.zero_point)
        accumulator_scale = float(params.weights.q_params.scale_f32) * float(
            params.ifm.q_params.scale_f32
        )
        ofm_scale = float(params.ofm.q_params.scale_f32)

        # Reshape ifm to NHWC
        lhs_nhwc = relay.reshape(lhs, (1, 1, *params.ifm.shape))
        # Split the second matrix along axis 0, one piece per output channel
        rhs_slices = list(relay.op.split(rhs, params.ofm.shape[-1], axis=0))

        partial_results = []
        for rhs_slice in rhs_slices:
            rhs_nhwc = relay.reshape(rhs_slice, (1, 1, 1, depth))

            # Multiply the first matrix by the current slice
            product = ethosu_ops.ethosu_binary_elementwise(
                ifm=lhs_nhwc,
                ifm2=rhs_nhwc,
                lut=lut,
                operator_type="MUL",
                ifm_zero_point=ifm_zero_point,
                ifm_scale=0.0,
                ifm2_zero_point=weights_zero_point,
                ifm2_scale=0.0,
                ofm_scale=0.0,
                ofm_zero_point=0,
                ifm_channels=depth,
                ifm2_channels=depth,
                reversed_operands=False,
                ofm_dtype="int32",
            )

            # Sum the products to obtain one slice of the result; the fused
            # activation and the rescale to the OFM scale are applied here.
            summed = ethosu_ops.ethosu_pooling(
                ifm=product,
                lut=lut,
                pooling_type="SUM",
                ifm_zero_point=0,
                ifm_scale=accumulator_scale,
                ofm_scale=ofm_scale,
                ofm_zero_point=0,
                pool_shape=(1, 1),
                ofm_channels=1,
                ofm_dtype="int32",
                activation=activation,
                clip_min=clip_min,
                clip_max=clip_max,
                rounding_mode="NATURAL",
            )

            # Convert tensor dtype from int32 to int8 by multiplying with a
            # constant one; the OFM zero point is set on this operator.
            one = relay.const(np.ones([1, 1, 1, 1], dtype="int32"), dtype="int32")
            narrowed = ethosu_ops.ethosu_binary_elementwise(
                ifm=summed,
                ifm2=one,
                lut=lut,
                operator_type="MUL",
                ifm_scale=0.0,
                ifm_zero_point=0,
                ifm2_scale=0.0,
                ifm2_zero_point=0,
                ofm_scale=0.0,
                ofm_zero_point=ofm_zero_point,
                ifm_channels=1,
                ifm2_channels=1,
                reversed_operands=False,
                ofm_dtype="int8",
            )

            partial_results.append(narrowed)

        # Concatenate the per-slice results along the channel axis
        joined = relay.op.concatenate(relay.Tuple(partial_results), axis=3)
        return relay.reshape(joined, params.ofm.shape)


class PadRewriter(DFPatternCallback):
"""Convert ethos-u.pad2d composite function to ethosu_depthwise_conv2d
operator"""
Expand Down Expand Up @@ -1546,12 +1641,13 @@ def transform_npu_function(self, _, func: relay.Function) -> relay.Function:
"""
rewriters = [
PartitionedSplitRewriter(),
FullyConnectedRewriter(),
MatMulRewriter(),
SplitRewriter(),
ChannelPadRewriter(),
Conv2DRewriter(),
Conv2DTransposeRewriter(),
DepthwiseConv2DRewriter(),
FullyConnectedRewriter(),
MaxPoolingRewriter(),
AvgPoolingRewriter(),
PadRewriter(),
Expand Down
43 changes: 43 additions & 0 deletions python/tvm/relay/op/contrib/ethosu.py
Original file line number Diff line number Diff line change
Expand Up @@ -1900,6 +1900,44 @@ def qnn_fc_pattern():
return optional_clip


class MatMulParams(FullyConnectedParams):
    """
    Parser for a call to an ethos-u.matmul composite function.

    Parameter extraction is delegated to FullyConnectedParams; this class
    supplies the composite name and the matmul-specific validity check.
    """

    composite_name = "ethos-u.matmul"

    @requires_vela
    def __init__(self, func_body):
        super().__init__(func_body)

    def is_valid(self) -> bool:
        """
        Checks whether matrix multiplication has compatible attributes with HW

        Returns True only when IFM and OFM are int8, both are 2D, and the
        second dimension of the IFM equals the second dimension of the weights.
        """
        if not check_valid_dtypes([self.ifm, self.ofm], supported_dtypes=[np.int8]):
            return False
        # IFM is checked before OFM; both must be matrices.
        if any(len(tensor.shape) != 2 for tensor in (self.ifm, self.ofm)):
            return False
        # The weights must be transposed
        inner_dims_differ = self.ifm.shape[1] != self.weights.shape[1]
        return not inner_dims_differ


def matmul_pattern():
    """
    Create the pattern for matrix multiplication: a qnn.dense whose two data
    operands may be any expression, followed by qnn.requantize and an optional
    clip. All remaining arguments of both operators must be constants.
    """
    lhs, rhs = wildcard(), wildcard()
    dense = is_op("qnn.dense")(
        lhs, rhs, is_constant(), is_constant(), is_constant(), is_constant()
    )
    requantize = is_op("qnn.requantize")(
        dense, is_constant(), is_constant(), is_constant(), is_constant()
    )
    return requantize.optional(is_op("clip"))


class HardSwishParams:
"""
This class will parse a call to a ethos-u.hard_swish composite function
Expand Down Expand Up @@ -2185,6 +2223,11 @@ def pattern_table() -> List[Tuple[str, tvm.relay.dataflow_pattern.DFPattern, Cal
qnn_fc_pattern(),
lambda pat: FullyConnectedParams(pat).is_valid(),
),
(
MatMulParams.composite_name,
matmul_pattern(),
lambda pat: MatMulParams(pat).is_valid(),
),
(
MaxPool2DParams.composite_name,
qnn_maxpool2d_pattern(),
Expand Down
24 changes: 24 additions & 0 deletions tests/python/contrib/test_ethosu/test_codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -1564,6 +1564,30 @@ def fully_connected(x):
)


@pytest.mark.parametrize("accel_type", ["ethos-u55-256", "ethos-u65-256"])
@pytest.mark.parametrize("ifm_shape", [(1, 16), (4, 8)])
@pytest.mark.parametrize("ofm_channels", [8, 32])
@pytest.mark.parametrize("activation_function", ["NONE", "RELU"])
def test_tflite_matmul(accel_type, ifm_shape, ofm_channels, activation_function):
    """Compare TVM against TFLite for a matmul with a transposed second operand."""
    np.random.seed(0)

    apply_relu = activation_function == "RELU"
    # The second operand is given transposed, so it shares the IFM's last dimension.
    ifm2_shape = [ofm_channels, ifm_shape[-1]]

    @tf.function
    def matmul(x, y):
        product = tf.matmul(x, y, transpose_b=True)
        return tf.nn.relu(product) if apply_relu else product

    infra.compare_tvm_with_tflite(
        matmul, [ifm_shape, ifm2_shape], accel_type, enable_cascader=False
    )


@pytest.mark.parametrize("accel_type", ["ethos-u55-256", "ethos-u65-256"])
def test_tflite_subtract_sigmoid(accel_type):
np.random.seed(0)
Expand Down
117 changes: 117 additions & 0 deletions tests/python/contrib/test_ethosu/test_legalize.py
Original file line number Diff line number Diff line change
Expand Up @@ -3806,5 +3806,122 @@ def representative_dataset():
assert mod["tvmgen_default_ethos_u_main_1"].body.op.name == "contrib.ethosu.conv2d"


def test_tflite_matmul():
    """Legalize a quantized TFLite matmul and check the resulting operators."""
    ifm_shape = [1, 4]
    ifm2_shape = [2, 4]
    ifm_shapes = [ifm_shape, ifm2_shape]
    ofm_shape = [ifm_shape[0], ifm2_shape[0]]
    dtype = "int8"

    def create_tflite_graph():
        class Model(tf.Module):
            @tf.function
            def matmul(self, x, y):
                return tf.matmul(x, y, transpose_b=True)

        # Keep the module bound to a local name while it is traced and converted.
        model = Model()
        input_specs = [tf.TensorSpec(shape, tf.float32) for shape in ifm_shapes]
        concrete_func = model.matmul.get_concrete_function(*input_specs)

        # Calibration data for the conversion
        def representative_dataset():
            for _ in range(100):
                yield [np.random.rand(*shape).astype(np.float32) for shape in ifm_shapes]

        converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])
        converter.optimizations = [tf.lite.Optimize.DEFAULT]
        converter.representative_dataset = representative_dataset
        converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
        converter.inference_input_type = tf.int8
        converter.inference_output_type = tf.int8
        return converter.convert()

    def verify(ext_func):
        ofm = ext_func.body
        ops = []

        def _collect_calls(expr):
            if isinstance(expr, relay.expr.Call):
                ops.append(expr)

        # Gather every call in post order; the indices below rely on that order.
        relay.analysis.post_order_visit(ofm, _collect_calls)
        ofm_checked_type = ofm.checked_type
        ofm_channels = ofm_shape[-1]

        # check IFM
        ifm_type = ops[1].checked_type
        assert list(ifm_type.shape) == ifm_shape
        assert str(ifm_type.dtype) == dtype

        # check IFM2
        ifm2_type = ops[3].checked_type
        assert list(ifm2_type.shape) == ifm2_shape
        assert str(ifm2_type.dtype) == dtype

        # check split
        split = ops[4]
        split_field_types = list(split.checked_type.fields)
        assert split.op.name == "split"
        assert split.attrs.axis == 0
        assert int(split.attrs.indices_or_sections) == ofm_channels
        for field_type in split_field_types:
            assert list(field_type.shape) == ifm_shape
            assert str(field_type.dtype) == dtype

        # check MUL
        for mul_op in (ops[6], ops[10]):
            assert mul_op.op.name == "contrib.ethosu.binary_elementwise"
            assert mul_op.attrs.operator_type == "MUL"
            assert mul_op.attrs.ofm_dtype == "int32"

        # check reduce sum
        for reduce_sum_op in (ops[7], ops[11]):
            assert reduce_sum_op.op.name == "contrib.ethosu.pooling"
            assert reduce_sum_op.attrs.pooling_type == "SUM"
            assert list(reduce_sum_op.checked_type.shape) == [1, 1, 1, 1]

        # check concatenation
        concatenation = ofm.args[0]
        assert concatenation.op.name == "concatenate"
        assert list(concatenation.checked_type.shape) == [1, 1, 1, ofm_channels]

        # check OFM
        assert ofm.op.name == "reshape"
        assert list(ofm_checked_type.shape) == ofm_shape
        assert str(ofm_checked_type.dtype) == dtype

    matmul_pattern_table = [
        (
            ethosu.MatMulParams.composite_name,
            ethosu.matmul_pattern(),
            lambda pat: ethosu.MatMulParams(pat).is_valid(),
        )
    ]

    tflite_graph = create_tflite_graph()
    tflite_model = tflite.Model.Model.GetRootAsModel(tflite_graph, 0)

    input_names = ["ifm" + str(i) for i in range(len(ifm_shapes))]
    mod, params = relay.frontend.from_tflite(
        tflite_model,
        shape_dict=dict(zip(input_names, ifm_shapes)),
        dtype_dict=dict.fromkeys(input_names, dtype),
    )

    mod["main"] = bind_params_by_name(mod["main"], params)
    mod = partition_ethosu_by_table(mod, matmul_pattern_table)

    npu_func_name = "tvmgen_default_ethos_u_main_0"
    mod[npu_func_name] = dataflow_pattern.rewrite(legalize.MatMulRewriter(), mod[npu_func_name])

    verify(mod[npu_func_name])


# Allow this test file to be executed directly as a script.
if __name__ == "__main__":
    tvm.testing.main()

0 comments on commit fdfd16c

Please sign in to comment.