Skip to content

Commit

Permalink
Add ONNX LinearRegressor operator support (#10477)
Browse files Browse the repository at this point in the history
  • Loading branch information
Xingyu Zhou authored Mar 10, 2022
1 parent 1f60529 commit 48793f3
Show file tree
Hide file tree
Showing 2 changed files with 75 additions and 0 deletions.
31 changes: 31 additions & 0 deletions python/tvm/relay/frontend/onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -3444,6 +3444,35 @@ def body_fn(*loop_inputs):
return outputs


class LinearRegressor(OnnxOpConverter):
    """Operator converter for the ai.onnx.ml LinearRegressor.

    Computes ``out = X @ W.T (+ intercepts)`` where the flat ``coefficients``
    attribute is reshaped into a ``(targets, features)`` weight matrix.
    """

    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        data = inputs[0]
        num_targets = attr.get("targets", 1)

        # Flat coefficient attribute -> (targets, features) weight matrix.
        flat_coef = _expr.const(list(attr.get("coefficients", 0)), dtype="float32")
        num_features = infer_shape(flat_coef)[0] // num_targets
        weight = _op.reshape(flat_coef, (num_targets, num_features))

        # If fewer coefficients than input features were supplied, only the
        # leading features take part in the matmul.
        if num_features < infer_shape(data)[-1]:
            data = _op.split(data, [num_features], -1)[0]

        out = _op.nn.dense(data, weight)

        if "intercepts" not in attr:
            return out

        bias = _expr.const(list(attr.get("intercepts", 0)), dtype="float32")
        if num_targets == 1:
            return _op.nn.bias_add(out, bias, axis=-1)
        return get_relay_op("add")(out, bias)


class NonMaxSuppression(OnnxOpConverter):
"""Operator converter for NonMaxSuppression."""

Expand Down Expand Up @@ -4770,6 +4799,8 @@ def _get_convert_map(opset):
"Adam": Adam.get_converter(opset),
"Momentum": Momentum.get_converter(opset),
"Scan": Scan.get_converter(opset),
# ML
"LinearRegressor": LinearRegressor.get_converter(opset),
}


Expand Down
44 changes: 44 additions & 0 deletions tests/python/frontend/onnx/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -6275,6 +6275,49 @@ def verify_scan(
verify_scan(input_shapes, output_shapes, 3, [-4, -1, -2], [1] * 3, [-3, -2], [1] * 2, 9)


@tvm.testing.parametrize_targets
def test_LinearRegressor(target, dev):
    """Check the ai.onnx.ml LinearRegressor converter against ONNX Runtime."""

    def verify_LinearRegressor(a_shape, c_shape, i_shape, targets=1, batch=1):
        input_values = np.random.uniform(size=a_shape).astype("float32")
        coef_values = np.random.uniform(size=c_shape).astype("float32")
        bias_values = np.random.uniform(size=i_shape).astype("float32")

        node = helper.make_node(
            "LinearRegressor",
            inputs=["a"],
            outputs=["out"],
            coefficients=coef_values,
            intercepts=bias_values,
            targets=targets,
            domain="ai.onnx.ml",
        )

        graph = helper.make_graph(
            [node],
            "LinearRegressor_test",
            inputs=[helper.make_tensor_value_info("a", TensorProto.FLOAT, list(a_shape))],
            outputs=[
                helper.make_tensor_value_info("out", TensorProto.FLOAT, (batch, targets))
            ],
        )
        # The ML ops live in the "ai.onnx.ml" domain, so that opset must be imported.
        model = helper.make_model(
            graph,
            producer_name="LinearRegressor_test",
            opset_imports=[onnx.helper.make_opsetid("ai.onnx.ml", 1)],
        )
        verify_with_ort_with_inputs(model, [input_values], target=target, dev=dev)

    verify_LinearRegressor((1, 3), (3), (1))
    verify_LinearRegressor((2, 10), (10), (1), batch=2)
    verify_LinearRegressor((1, 3), (30), (10), targets=10)
    verify_LinearRegressor((10, 3), (30), (10), targets=10, batch=10)
    # More input features than coefficients: the extra feature is ignored.
    verify_LinearRegressor((1, 4), (3), (1))


if __name__ == "__main__":
test_flatten()
test_reshape()
Expand Down Expand Up @@ -6371,3 +6414,4 @@ def verify_scan(
test_random_uniform_like()
test_random_normal()
test_random_normal_like()
test_LinearRegressor()

0 comments on commit 48793f3

Please sign in to comment.