Skip to content

Commit

Permalink
rebase and add tests for expand dims and squeeze
Browse files Browse the repository at this point in the history
Change-Id: Ic6a9fd77b61368720328bfe82032490bcc66152c
  • Loading branch information
lhutton1 committed Feb 8, 2022
1 parent 4b87076 commit 9adacaa
Show file tree
Hide file tree
Showing 2 changed files with 167 additions and 1 deletion.
27 changes: 26 additions & 1 deletion tests/python/contrib/test_ethosu/test_codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -1020,6 +1020,28 @@ def create_model():
_compare_ethosu_with_reference(ethosu_mod, input_data, output_data, accel_type)


@pytest.mark.parametrize("accel_type", ACCEL_TYPES)
@pytest.mark.parametrize("ifm_shape,axis", [((2,), 0), ((1, 3, 3), 2)])
def test_tflite_expand_dims(accel_type, ifm_shape, axis):
@tf.function
def expand_dims_func(x):
return tf.expand_dims(x, axis=axis)

_compare_tvm_with_tflite(expand_dims_func, [ifm_shape], accel_type)


@pytest.mark.parametrize("accel_type", ACCEL_TYPES)
@pytest.mark.parametrize(
"ifm_shape,axis", [((1, 1, 2, 1), 0), ((1, 3, 3, 1), 3), ((1, 1, 2, 1), None)]
)
def test_tflite_squeeze(accel_type, ifm_shape, axis):
@tf.function
def squeeze_func(x):
return tf.squeeze(x, axis=axis)

_compare_tvm_with_tflite(squeeze_func, [ifm_shape], accel_type)


@pytest.mark.parametrize("accel_type", ACCEL_TYPES)
@pytest.mark.parametrize(
"ifm_shape,size",
Expand Down Expand Up @@ -1115,7 +1137,10 @@ def test_tflite_pack(accel_type, ifm_shapes, axis):
def pack_func(*inputs):
return tf.stack(inputs, axis=axis)

_compare_tvm_with_tflite(pack_func, ifm_shapes, accel_type)
# TODO(lhutton1) For now output is not bit exact with TFLite.
# This is because TFLite reference kernels are not being used.
# For this, TFLite will need upgrading to 2.6.
_compare_tvm_with_tflite(pack_func, ifm_shapes, accel_type, output_tolerance=1)


@pytest.mark.parametrize("accel_type", ACCEL_TYPES)
Expand Down
141 changes: 141 additions & 0 deletions tests/python/contrib/test_ethosu/test_legalize.py
Original file line number Diff line number Diff line change
Expand Up @@ -1604,6 +1604,147 @@ def verify(ext_func):
verify(mod["tvmgen_default_ethos_u_main_0"])


@pytest.mark.parametrize("ifm_shape,axis", [((2,), 0), ((1, 3, 3), 2)])
def test_tflite_expand_dims(ifm_shape, axis):
dtype = "int8"

def create_tflite_graph():
class Model(tf.Module):
@tf.function
def tf_function(self, x):
return tf.expand_dims(x, axis=axis)

model = Model()
concrete_func = model.tf_function.get_concrete_function(
tf.TensorSpec(ifm_shape, tf.float32)
)

def representative_dataset():
for _ in range(100):
data = np.random.rand(*tuple(ifm_shape))
yield [data.astype(np.float32)]

converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.representative_dataset = representative_dataset
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
converter.inference_input_type = tf.int8
converter.inference_output_type = tf.int8
tflite_model = converter.convert()

return tflite_model

def verify(ext_func):
op = ext_func.body
expected_shape = list(ifm_shape)
expected_shape.insert(axis, 1)

# Check IFM
assert list(op.args[0].checked_type.shape) == list(ifm_shape)
assert op.args[0].checked_type.dtype == dtype

# Check OFM
assert list(op.checked_type.shape) == expected_shape
assert op.checked_type.dtype == dtype

# Check op
assert op.op.name == "reshape"

tflite_graph = create_tflite_graph()
tflite_model = tflite.Model.Model.GetRootAsModel(tflite_graph, 0)

mod, _ = relay.frontend.from_tflite(
tflite_model,
shape_dict={"input": ifm_shape},
dtype_dict={"input": dtype},
)
mod = ethosu.partition_for_ethosu(mod)

mod["tvmgen_default_ethos_u_main_0"] = dataflow_pattern.rewrite(
legalize.ExpandDimsRewriter(), mod["tvmgen_default_ethos_u_main_0"]
)
mod["tvmgen_default_ethos_u_main_0"] = dataflow_pattern.rewrite(
legalize.ReshapeRewriter(), mod["tvmgen_default_ethos_u_main_0"]
)
mod["tvmgen_default_ethos_u_main_0"] = relay.transform.InferType()(mod)[
"tvmgen_default_ethos_u_main_0"
]
verify(mod["tvmgen_default_ethos_u_main_0"])


@pytest.mark.parametrize(
"ifm_shape,axis", [((1, 1, 2, 1), 0), ((1, 3, 3, 1), 3), ((1, 1, 2, 1), None)]
)
def test_tflite_squeeze(ifm_shape, axis):
dtype = "int8"

def create_tflite_graph():
class Model(tf.Module):
@tf.function
def tf_function(self, x):
return tf.squeeze(x, axis=axis)

model = Model()
concrete_func = model.tf_function.get_concrete_function(
tf.TensorSpec(ifm_shape, tf.float32)
)

def representative_dataset():
for _ in range(100):
data = np.random.rand(*tuple(ifm_shape))
yield [data.astype(np.float32)]

converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.representative_dataset = representative_dataset
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
converter.inference_input_type = tf.int8
converter.inference_output_type = tf.int8
tflite_model = converter.convert()

return tflite_model

def verify(ext_func):
op = ext_func.body
expected_shape = list(ifm_shape)
if isinstance(axis, int):
expected_shape = ifm_shape[:axis] + ifm_shape[axis + 1 :]
else:
expected_shape = list(filter(lambda a: a != 1, expected_shape))

# Check IFM
assert list(op.args[0].checked_type.shape) == list(ifm_shape)
assert op.args[0].checked_type.dtype == dtype

# Check OFM
assert list(op.checked_type.shape) == list(expected_shape)
assert op.checked_type.dtype == dtype

# Check op
assert op.op.name == "reshape"

tflite_graph = create_tflite_graph()
tflite_model = tflite.Model.Model.GetRootAsModel(tflite_graph, 0)

mod, _ = relay.frontend.from_tflite(
tflite_model,
shape_dict={"input": ifm_shape},
dtype_dict={"input": dtype},
)
mod = ethosu.partition_for_ethosu(mod)

mod["tvmgen_default_ethos_u_main_0"] = dataflow_pattern.rewrite(
legalize.SqueezeRewriter(), mod["tvmgen_default_ethos_u_main_0"]
)
mod["tvmgen_default_ethos_u_main_0"] = dataflow_pattern.rewrite(
legalize.ReshapeRewriter(), mod["tvmgen_default_ethos_u_main_0"]
)
mod["tvmgen_default_ethos_u_main_0"] = relay.transform.InferType()(mod)[
"tvmgen_default_ethos_u_main_0"
]
verify(mod["tvmgen_default_ethos_u_main_0"])


@pytest.mark.parametrize(
"ifm_shape,size",
[
Expand Down

0 comments on commit 9adacaa

Please sign in to comment.