Skip to content

Commit

Permalink
[Frontend]Make onnx gemm tensor C optional (apache#7489)
Browse files Browse the repository at this point in the history
* Make onnx gemm tensor C optional

* fix codestyle

* add tests

* fix codestyle
  • Loading branch information
xutianming authored and Lokiiiiii committed Mar 1, 2021
1 parent e71615d commit 1b4bcbe
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 4 deletions.
13 changes: 9 additions & 4 deletions python/tvm/relay/frontend/onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -513,7 +513,9 @@ class Gemm(OnnxOpConverter):

@classmethod
def _impl_v1(cls, inputs, attr, params):
assert len(inputs) == 3, "Gemm op take 3 inputs, {} given".format(len(inputs))
assert len(inputs) == 3 or len(inputs) == 2, "Gemm op take 2 or 3 inputs, {} given".format(
len(inputs)
)
# Y = alpha * A * B + beta * C
alpha = float(attr.get("alpha", 1.0))
beta = float(attr.get("beta", 1.0))
Expand All @@ -531,9 +533,12 @@ def _impl_v1(cls, inputs, attr, params):
inputs[0] *= _expr.const(alpha)
out = _op.nn.dense(inputs[0], inputs[1], units=channels)

# skip (beta * C) if zero
C_array = params[inputs[2].name_hint].asnumpy()
if (beta == 0.0) or np.array_equal(C_array, np.array([0])):
if len(inputs) == 3:
# skip (beta * C) if zero
C_array = params[inputs[2].name_hint].asnumpy()
if (beta == 0.0) or np.array_equal(C_array, np.array([0])):
return out
else:
return out
return _op.nn.bias_add(out, _expr.const(beta) * inputs[2])

Expand Down
26 changes: 26 additions & 0 deletions tests/python/frontend/onnx/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -1008,6 +1008,31 @@ def test_onehot():
tvm.testing.assert_allclose(out_np, tvm_out, rtol=1e-5, atol=1e-5)


@tvm.testing.uses_gpu
def test_gemm():
    """Verify ONNX Gemm conversion when the optional bias input C is omitted.

    Builds a two-input Gemm node (A @ B only, no C) and checks TVM's output
    against ONNX Runtime via verify_with_ort_with_inputs.
    """
    shape_a = (4, 3)
    shape_b = (3, 4)
    shape_out = [shape_a[0], shape_b[1]]

    # Random float32 operands matching the declared graph input shapes.
    input_a = np.random.uniform(size=shape_a).astype("float32")
    input_b = np.random.uniform(size=shape_b).astype("float32")

    # Only two inputs — exercises the optional-C code path in the frontend.
    node = helper.make_node("Gemm", ["a", "b"], ["out"])

    graph = helper.make_graph(
        [node],
        "gemm_test",
        inputs=[
            helper.make_tensor_value_info("a", TensorProto.FLOAT, list(shape_a)),
            helper.make_tensor_value_info("b", TensorProto.FLOAT, list(shape_b)),
        ],
        outputs=[helper.make_tensor_value_info("out", TensorProto.FLOAT, list(shape_out))],
    )

    model = helper.make_model(graph, producer_name="gemm_test")
    verify_with_ort_with_inputs(model, [input_a, input_b])


@tvm.testing.uses_gpu
def test_matmul():
a_shape = (4, 3)
Expand Down Expand Up @@ -4065,6 +4090,7 @@ def verify_cumsum(indata, axis, exclusive=0, reverse=0, type="float32"):
test_clip()
test_clip_min_max_as_inputs()
test_onehot()
test_gemm()
test_matmul()
test_gather()
test_gatherelements()
Expand Down

0 comments on commit 1b4bcbe

Please sign in to comment.