diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py b/python/tvm/relax/frontend/onnx/onnx_frontend.py index 1a063c2fa806..611f4348d55e 100644 --- a/python/tvm/relax/frontend/onnx/onnx_frontend.py +++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py @@ -1209,9 +1209,7 @@ def _impl_v13(cls, bb, inputs, attr, params): if isinstance(axis, (tuple, type(None))): out_data = _np.squeeze(data.data.numpy(), axis) else: - raise NotImplementedError( - "Squeeze with symbolic axes not supported" - ) + raise NotImplementedError("Squeeze with symbolic axes not supported") return relax.const(out_data, data.struct_info.dtype) diff --git a/tests/python/relax/test_frontend_onnx.py b/tests/python/relax/test_frontend_onnx.py index 050f6ca933aa..9faa441138fc 100644 --- a/tests/python/relax/test_frontend_onnx.py +++ b/tests/python/relax/test_frontend_onnx.py @@ -52,16 +52,14 @@ def generate_random_inputs( shape = [] for dim in i.type.tensor_type.shape.dim: shape.append(dim.dim_value) - + input_values[i.name] = generate_random_value(shape, i.type.tensor_type.elem_type) return input_values -def generate_random_value( - shape, elem_type -) -> np.ndarray: - +def generate_random_value(shape, elem_type) -> np.ndarray: + # Extract datatype for the input. if elem_type: dtype = str(onnx.mapping.TENSOR_TYPE_TO_NP_TYPE[elem_type]) @@ -81,6 +79,7 @@ def generate_random_value( return random_value + def check_correctness( model: ModelProto, inputs: Optional[Dict[str, np.ndarray]] = None, @@ -170,7 +169,7 @@ def _check_output(tvm_out, ort_out): # Check that number of outputs match. assert len(tvm_output) == len(ort_output), "Unequal number of outputs" - for (tvm_out, ort_out) in zip(tvm_output, ort_output): + for tvm_out, ort_out in zip(tvm_output, ort_output): # TODO Allow configurable tolerance. if ort_out is not None: _check_output(tvm_out, ort_out) @@ -227,6 +226,7 @@ def verify_unary( model = helper.make_model(graph, producer_name="elemwise_test") check_correctness(model, opset=opset) + def verify_unary_dynamic_shape( op_name, shape, @@ -246,7 +246,7 @@ def verify_unary_dynamic_shape( ], outputs=[helper.make_tensor_value_info("y", output_dtype, shape)], ) - + model = helper.make_model(graph, producer_name="elemwise_test") inputs = {"x": generate_random_value(shape_instance, input_dtype)} check_correctness(model, inputs, opset=opset) @@ -1045,11 +1045,14 @@ def test_squeeze(axis): model = helper.make_model(graph, producer_name="squeeze_test") check_correctness(model, opset=13) + @pytest.mark.parametrize("axis", [[0, 2], None]) def test_squeeze_constant(axis): shape = [1, 32, 1, 32] - constant= make_constant_node("x", onnx.TensorProto.FLOAT, shape, rg.standard_normal(size=shape).astype("float32")) - if axis: + constant = make_constant_node( + "x", onnx.TensorProto.FLOAT, shape, rg.standard_normal(size=shape).astype("float32") + ) + if axis: squeeze_node = helper.make_node("Squeeze", ["x", "axes"], ["y"]) else: squeeze_node = helper.make_node("Squeeze", ["x"], ["y"]) @@ -1069,11 +1072,12 @@ def test_squeeze_constant(axis): model = helper.make_model(graph, producer_name="squeeze_test") check_correctness(model, opset=13) + @pytest.mark.parametrize("axis", [[0]]) @pytest.mark.parametrize("A", [8, 16, 32]) @pytest.mark.parametrize("B", [8, 16, 32]) def test_dynamic_squeeze(axis, A, B): - + squeeze_node = helper.make_node("Squeeze", ["x", "axes"], ["y"]) shape = [1, "A", "B"] @@ -1092,13 +1096,14 @@ def test_dynamic_squeeze(axis, A, B): ) model = helper.make_model(graph, producer_name="squeeze_test") - inputs = {"x": rg.standard_normal(size=[1, A, B]).astype("float32")} + inputs = {"x": rg.standard_normal(size=[1, A, B]).astype("float32")} check_correctness(model, inputs, opset=13) + @pytest.mark.parametrize("axis", [[0]]) @pytest.mark.parametrize("A", [8, 16, 32]) def test_dynamic_shape_squeeze(axis, A): - + shape_node = helper.make_node("Shape", ["x"], ["y"]) squeeze_node = helper.make_node("Squeeze", ["y", "axes"], ["z"]) shape = ["A"] @@ -1118,9 +1123,10 @@ def test_dynamic_shape_squeeze(axis, A): ) model = helper.make_model(graph, producer_name="squeeze_test") - inputs = {"x": rg.standard_normal(size=[A]).astype("float32")} + inputs = {"x": rg.standard_normal(size=[A]).astype("float32")} check_correctness(model, inputs, opset=13) + def test_const(): shape = [32, 32] const_node = helper.make_node( @@ -1655,8 +1661,11 @@ def verify_slice(data_shape, output_shape, starts, ends, axes=None, steps=None): # steps=[-1, -3, -2], # ) + def test_slice_dynamic_shape(): - def verify_slice(data_shape, data_instance_shape, output_shape, starts, ends, axes=None, steps=None): + def verify_slice( + data_shape, data_instance_shape, output_shape, starts, ends, axes=None, steps=None + ): if isinstance(starts, list): starts = np.array(starts, "int64") if isinstance(ends, list): @@ -1678,10 +1687,10 @@ def verify_slice(data_shape, data_instance_shape, output_shape, starts, ends, ax if steps is not None: initializer.append(helper.make_tensor("steps", TensorProto.INT64, steps.shape, steps)) slice_inputs.append("steps") - + shape_node = helper.make_node("Shape", inputs=["x"], outputs=["y"]) slice_node = helper.make_node("Slice", inputs=slice_inputs, outputs=["z"]) - + graph = helper.make_graph( [shape_node, slice_node], "slice_test", @@ -1966,7 +1975,9 @@ def verify_split(indata_shape, outdata_shapes, split, axis=0, pass_split=True, o if pass_split: if opset >= 13: np_split = np.array(split).astype(np.int64) - split_constant= make_constant_node("split", onnx.TensorProto.INT64, list(np_split.shape), np_split) + split_constant = make_constant_node( + "split", onnx.TensorProto.INT64, list(np_split.shape), np_split + ) input_names.append("split") node = helper.make_node( @@ -2398,8 +2409,8 @@ def test_flatten(): def test_flatten_dynamic(): verify_unary_dynamic_shape("Flatten", [1, "A", "B", 32], [1, 3, 32, 32], attrs={"axis": 0}) - verify_unary_dynamic_shape("Flatten", [1, "A", "B", 32], [1, 3, 32, 32], attrs={"axis": -1}) - verify_unary_dynamic_shape("Flatten", [1, "A", "B", 32], [1, 3, 32, 32], attrs={"axis": 2}) + verify_unary_dynamic_shape("Flatten", [1, "A", "B", 32], [1, 3, 32, 32], attrs={"axis": -1}) + verify_unary_dynamic_shape("Flatten", [1, "A", "B", 32], [1, 3, 32, 32], attrs={"axis": 2}) def test_onehot():