Skip to content

Commit

Permalink
applied black formatting
Browse files Browse the repository at this point in the history
  • Loading branch information
PatrikPerssonInceptron committed Oct 29, 2024
1 parent 310026b commit 3d5edc9
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 22 deletions.
4 changes: 1 addition & 3 deletions python/tvm/relax/frontend/onnx/onnx_frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -1209,9 +1209,7 @@ def _impl_v13(cls, bb, inputs, attr, params):
if isinstance(axis, (tuple, type(None))):
out_data = _np.squeeze(data.data.numpy(), axis)
else:
raise NotImplementedError(
"Squeeze with symbolic axes not supported"
)
raise NotImplementedError("Squeeze with symbolic axes not supported")

return relax.const(out_data, data.struct_info.dtype)

Expand Down
49 changes: 30 additions & 19 deletions tests/python/relax/test_frontend_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,16 +52,14 @@ def generate_random_inputs(
shape = []
for dim in i.type.tensor_type.shape.dim:
shape.append(dim.dim_value)

input_values[i.name] = generate_random_value(shape, i.type.tensor_type.elem_type)

return input_values


def generate_random_value(
shape, elem_type
) -> np.ndarray:

def generate_random_value(shape, elem_type) -> np.ndarray:

# Extract datatype for the input.
if elem_type:
dtype = str(onnx.mapping.TENSOR_TYPE_TO_NP_TYPE[elem_type])
Expand All @@ -81,6 +79,7 @@ def generate_random_value(

return random_value


def check_correctness(
model: ModelProto,
inputs: Optional[Dict[str, np.ndarray]] = None,
Expand Down Expand Up @@ -170,7 +169,7 @@ def _check_output(tvm_out, ort_out):

# Check that number of outputs match.
assert len(tvm_output) == len(ort_output), "Unequal number of outputs"
for (tvm_out, ort_out) in zip(tvm_output, ort_output):
for tvm_out, ort_out in zip(tvm_output, ort_output):
# TODO Allow configurable tolerance.
if ort_out is not None:
_check_output(tvm_out, ort_out)
Expand Down Expand Up @@ -227,6 +226,7 @@ def verify_unary(
model = helper.make_model(graph, producer_name="elemwise_test")
check_correctness(model, opset=opset)


def verify_unary_dynamic_shape(
op_name,
shape,
Expand All @@ -246,7 +246,7 @@ def verify_unary_dynamic_shape(
],
outputs=[helper.make_tensor_value_info("y", output_dtype, shape)],
)

model = helper.make_model(graph, producer_name="elemwise_test")
inputs = {"x": generate_random_value(shape_instance, input_dtype)}
check_correctness(model, inputs, opset=opset)
Expand Down Expand Up @@ -1045,11 +1045,14 @@ def test_squeeze(axis):
model = helper.make_model(graph, producer_name="squeeze_test")
check_correctness(model, opset=13)


@pytest.mark.parametrize("axis", [[0, 2], None])
def test_squeeze_constant(axis):
shape = [1, 32, 1, 32]
constant= make_constant_node("x", onnx.TensorProto.FLOAT, shape, rg.standard_normal(size=shape).astype("float32"))
if axis:
constant = make_constant_node(
"x", onnx.TensorProto.FLOAT, shape, rg.standard_normal(size=shape).astype("float32")
)
if axis:
squeeze_node = helper.make_node("Squeeze", ["x", "axes"], ["y"])
else:
squeeze_node = helper.make_node("Squeeze", ["x"], ["y"])
Expand All @@ -1069,11 +1072,12 @@ def test_squeeze_constant(axis):
model = helper.make_model(graph, producer_name="squeeze_test")
check_correctness(model, opset=13)


@pytest.mark.parametrize("axis", [[0]])
@pytest.mark.parametrize("A", [8, 16, 32])
@pytest.mark.parametrize("B", [8, 16, 32])
def test_dynamic_squeeze(axis, A, B):

squeeze_node = helper.make_node("Squeeze", ["x", "axes"], ["y"])
shape = [1, "A", "B"]

Expand All @@ -1092,13 +1096,14 @@ def test_dynamic_squeeze(axis, A, B):
)

model = helper.make_model(graph, producer_name="squeeze_test")
inputs = {"x": rg.standard_normal(size=[1, A, B]).astype("float32")}
inputs = {"x": rg.standard_normal(size=[1, A, B]).astype("float32")}
check_correctness(model, inputs, opset=13)


@pytest.mark.parametrize("axis", [[0]])
@pytest.mark.parametrize("A", [8, 16, 32])
def test_dynamic_shape_squeeze(axis, A):

shape_node = helper.make_node("Shape", ["x"], ["y"])
squeeze_node = helper.make_node("Squeeze", ["y", "axes"], ["z"])
shape = ["A"]
Expand All @@ -1118,9 +1123,10 @@ def test_dynamic_shape_squeeze(axis, A):
)

model = helper.make_model(graph, producer_name="squeeze_test")
inputs = {"x": rg.standard_normal(size=[A]).astype("float32")}
inputs = {"x": rg.standard_normal(size=[A]).astype("float32")}
check_correctness(model, inputs, opset=13)


def test_const():
shape = [32, 32]
const_node = helper.make_node(
Expand Down Expand Up @@ -1655,8 +1661,11 @@ def verify_slice(data_shape, output_shape, starts, ends, axes=None, steps=None):
# steps=[-1, -3, -2],
# )


def test_slice_dynamic_shape():
def verify_slice(data_shape, data_instance_shape, output_shape, starts, ends, axes=None, steps=None):
def verify_slice(
data_shape, data_instance_shape, output_shape, starts, ends, axes=None, steps=None
):
if isinstance(starts, list):
starts = np.array(starts, "int64")
if isinstance(ends, list):
Expand All @@ -1678,10 +1687,10 @@ def verify_slice(data_shape, data_instance_shape, output_shape, starts, ends, ax
if steps is not None:
initializer.append(helper.make_tensor("steps", TensorProto.INT64, steps.shape, steps))
slice_inputs.append("steps")

shape_node = helper.make_node("Shape", inputs=["x"], outputs=["y"])
slice_node = helper.make_node("Slice", inputs=slice_inputs, outputs=["z"])

graph = helper.make_graph(
[shape_node, slice_node],
"slice_test",
Expand Down Expand Up @@ -1966,7 +1975,9 @@ def verify_split(indata_shape, outdata_shapes, split, axis=0, pass_split=True, o
if pass_split:
if opset >= 13:
np_split = np.array(split).astype(np.int64)
split_constant= make_constant_node("split", onnx.TensorProto.INT64, list(np_split.shape), np_split)
split_constant = make_constant_node(
"split", onnx.TensorProto.INT64, list(np_split.shape), np_split
)
input_names.append("split")

node = helper.make_node(
Expand Down Expand Up @@ -2398,8 +2409,8 @@ def test_flatten():

def test_flatten_dynamic():
verify_unary_dynamic_shape("Flatten", [1, "A", "B", 32], [1, 3, 32, 32], attrs={"axis": 0})
verify_unary_dynamic_shape("Flatten", [1, "A", "B", 32], [1, 3, 32, 32], attrs={"axis": -1})
verify_unary_dynamic_shape("Flatten", [1, "A", "B", 32], [1, 3, 32, 32], attrs={"axis": 2})
verify_unary_dynamic_shape("Flatten", [1, "A", "B", 32], [1, 3, 32, 32], attrs={"axis": -1})
verify_unary_dynamic_shape("Flatten", [1, "A", "B", 32], [1, 3, 32, 32], attrs={"axis": 2})


def test_onehot():
Expand Down

0 comments on commit 3d5edc9

Please sign in to comment.