[Onnx] Add SoftmaxCrossEntropyLoss #8906

Merged · 72 commits · Sep 17, 2021
Changes from 62 commits
Commits (72)
6963dc9
nll loss v1
AndrewZhaoLuo Aug 25, 2021
d6f420f
add converter
AndrewZhaoLuo Aug 25, 2021
83998d0
decode strings in byte form
AndrewZhaoLuo Aug 25, 2021
6c7ec71
decode variable length inputs
AndrewZhaoLuo Aug 25, 2021
1fbc3b7
make shapes correct
Aug 25, 2021
0cec344
unsqueeze
AndrewZhaoLuo Aug 25, 2021
1bd6573
fix
Aug 26, 2021
173054d
proper weight handling
Aug 26, 2021
e69997f
simplify if statement
AndrewZhaoLuo Aug 27, 2021
8949e5f
fix tests
Aug 31, 2021
ed36b75
add comment about tests
Aug 31, 2021
409b8a3
delete extra file
Aug 31, 2021
86229c2
lint
Aug 31, 2021
2a7b2b7
so cool
Aug 31, 2021
0ddfd30
Update CI Lint Image Version (#8841)
mehrdadh Aug 25, 2021
9f501cc
[BUG] ToBasicBlockNormalForm immutability (#8778)
ganler Aug 25, 2021
c757bd2
[GRAPH EXECUTOR,VM] Add benchmarking function to graph executor and v…
Aug 26, 2021
39ca27f
Apply CPPLint to CRT Tests (#8844)
Mousius Aug 26, 2021
3e1187d
[Relay][TOPI] Support of depthwise conv2d NHWC for Mali/Bifrost. (#8584)
AnastasiaStulova Aug 26, 2021
a221761
Support for CMSIS-NN in Corstone300 Makefile (#8831)
ashutosh-arm Aug 26, 2021
c2d1940
[microtvm][Zephyr] Increase timeout to fix flaky tests (#8846)
mehrdadh Aug 26, 2021
73c9f07
[AMP] Bump up tolerance on flaky test (#8850)
AndrewZhaoLuo Aug 26, 2021
2231f0d
[Hexagon] Rework tvm.target.hexagon() interface (#8823)
Aug 26, 2021
9447946
[Pattern matching] Add an option to rewrite the graph only once (#8843)
ekalda Aug 26, 2021
8ca142b
update gpu and cpu (#8853)
mehrdadh Aug 26, 2021
9a68712
VTA cmake change to include Verilator header for building tsim librar…
aasorokiin Aug 26, 2021
c105a1b
[FIX] Bug fix for a floormod rewrite simplify rule (#8852)
jcf94 Aug 26, 2021
d833218
move rust lint script (#8726)
mehrdadh Aug 26, 2021
3862ce6
[AMP] Disallow fp16 conversion for summation-like ops (#8810)
masahi Aug 26, 2021
934b4e5
[TOPI] [Relay] Sparse Conv2d Implementation for 3x3 kernels (#8605)
Tantalus13A98B5F Aug 27, 2021
382b194
extend repeat_interleave op for relay.Expr (#8839)
vvchernov Aug 27, 2021
b5aaa39
Change AOT from ExprVisitor to MixedModeVisitor (#8856)
Mousius Aug 27, 2021
8633399
Add a PaddlePaddle Frontend (#8645)
jiangjiajun Aug 27, 2021
7cce940
[Runtime] add set_output_zero_copy (#8497)
sunjiweiswift Aug 27, 2021
b3d6d78
[Hexagon] Change declaration order of unique_ptr objects to fix crash…
Aug 27, 2021
2501640
[Graph Executor, VM] Add end to end benchmarking of models (#8858)
Aug 27, 2021
d19b66c
[UnitTests] Expose TVM pytest helpers as plugin (#8532)
Lunderberg Aug 27, 2021
5328acb
Remove AOT Executor header from Arduino project (#8857)
Mousius Aug 27, 2021
fb29996
[Community] @mdw-octoml -> Reviewer (#8868)
yzhliu Aug 28, 2021
bbe2998
[TIR] Fix opaque access in buffer locator pass and match_buffer in re…
Hzfengsy Aug 28, 2021
3d43446
[Autoscheduler] Configurable workload keys (#8862)
AndrewZhaoLuo Aug 28, 2021
a5cb1a9
[Tutorial][Executor] Fix the usage of executors in tutorials (#8586)
ganler Aug 28, 2021
d5c699c
[Frontend][Onnx] Simplify onnx input since name accesses are not reli…
Aug 28, 2021
91effea
[TIR] GetBlockReadWriteRegion (#8875)
MasterJH5574 Aug 29, 2021
aad9a88
[RISCV] Add support for llvm parameter -mabi (-target-abi) (#8860)
apivovarov Aug 30, 2021
0ee4c63
[Community] @manupa-arm -> Committer (#8870)
tmoreau89 Aug 30, 2021
6b87fb5
[RPC] Fix ios_rpc build (#8864)
echuraev Aug 31, 2021
b1d9d11
[Vulkan][Target] Added the driver name to the vulkan target string. (…
Lunderberg Aug 31, 2021
dc9b2d7
[ONNX][TOPI] Support select_last_index for argmin/max (#8816)
AndrewZhaoLuo Aug 31, 2021
f6d6229
refactor optimize GEMM on CPU tutorial (#8825)
adstraw Aug 31, 2021
aca2844
Change target string to Target object in the TE compiler and interpre…
electriclilies Aug 31, 2021
6a417c2
[TensorIR][M2a] CacheRead/Write (#8863)
Hzfengsy Aug 31, 2021
1fbb74a
[CI] make pre-commit hooks to run on every push instead of every comm…
mikepapadim Aug 31, 2021
ff8e138
[TVMScript] Fix printing ForNode annotations (#8891)
vinx13 Sep 1, 2021
3655e8e
[1/10] CMSIS-NN graph partitioner for softmax (#8653)
ashutosh-arm Sep 1, 2021
60014c8
[microTVM][RVM] Add Arduino RVM (#8748)
guberti Sep 1, 2021
d5c4113
sce loss example
Sep 1, 2021
3371073
Merge branch 'main' into aluo/onnx/sceloss
Sep 1, 2021
4e43297
Merge branch 'main' into aluo/onnx/sceloss
AndrewZhaoLuo Sep 2, 2021
7d70d63
add comments, remove other tests
Sep 3, 2021
4f138f3
lint
AndrewZhaoLuo Sep 3, 2021
ccc57d4
lint
AndrewZhaoLuo Sep 3, 2021
75da5cc
jostle
Sep 6, 2021
95bdb82
lint up
AndrewZhaoLuo Sep 6, 2021
1c2806c
Merge branch 'aluo/onnx/sceloss' of github.com:AndrewZhaoLuo/tvm into…
AndrewZhaoLuo Sep 6, 2021
d677eb8
jostle
AndrewZhaoLuo Sep 6, 2021
6d3aeb8
Merge branch 'main' into aluo/onnx/sceloss
AndrewZhaoLuo Sep 16, 2021
812acf3
uncomment some tests
AndrewZhaoLuo Sep 16, 2021
4418709
proper return
AndrewZhaoLuo Sep 16, 2021
09e6069
clean up
AndrewZhaoLuo Sep 16, 2021
7d6a90a
lint
AndrewZhaoLuo Sep 16, 2021
d6eace4
minor merge errors
Sep 16, 2021
99 changes: 75 additions & 24 deletions python/tvm/relay/frontend/onnx.py
@@ -24,6 +24,7 @@
import tvm
from tvm import relay
from tvm.ir import IRModule
from tvm.relay.op.tensor import log
from tvm.topi.utils import get_const_tuple

from ... import nd as _nd
@@ -1839,6 +1840,13 @@ def _impl_v13(cls, inputs, attr, params):
class LogSoftmax(OnnxOpConverter):
"""Operator converter for LogSoftmax."""

@classmethod
def run_calculation(cls, x, axes):
m = _op.max(x, axes, keepdims=True)
e = _op.exp(x - m)
s = _op.sum(e, axes, keepdims=True)
return x - m - _op.log(s)

@classmethod
def _impl_v1(cls, inputs, attr, params):
axis = attr.get("axis", 1)
@@ -1847,10 +1855,7 @@ def _impl_v1(cls, inputs, attr, params):
axis += ndim
axes = list(range(axis, ndim))
x = inputs[0]
m = _op.max(x, axes, keepdims=True)
e = _op.exp(x - m)
s = _op.sum(e, axes, keepdims=True)
return x - m - _op.log(s)
return cls.run_calculation(inputs[0], axes)

@classmethod
def _impl_v13(cls, inputs, attr, params):
@@ -1859,11 +1864,7 @@ def _impl_v13(cls, inputs, attr, params):
if axis < 0:
axis += ndim
axes = [axis]
x = inputs[0]
m = _op.max(x, axes, keepdims=True)
e = _op.exp(x - m)
s = _op.sum(e, axes, keepdims=True)
return x - m - _op.log(s)
return cls.run_calculation(inputs[0], axes)
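Not part of the diff: a minimal NumPy sketch (the names `log_softmax`, `x`, and `out` are illustrative, not from the PR) of the same max/exp/sum/log sequence the shared `run_calculation` helper uses. Subtracting the row maximum first keeps `exp()` from overflowing on large logits while leaving the result mathematically unchanged.

```python
import numpy as np

def log_softmax(x, axis):
    # Subtract the max first so exp() never overflows; this mirrors
    # the max/exp/sum/log sequence in LogSoftmax.run_calculation.
    m = np.max(x, axis=axis, keepdims=True)
    e = np.exp(x - m)
    s = np.sum(e, axis=axis, keepdims=True)
    return x - m - np.log(s)

# Naive exp(x) would overflow for these logits; the shifted form is finite.
x = np.array([[1000.0, 1001.0, 1002.0]])
out = log_softmax(x, axis=1)
```

Exponentiating the output recovers a proper softmax, i.e. each row sums to 1.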


class Hardmax(OnnxOpConverter):
@@ -3462,23 +3463,14 @@ class NegativeLogLikelihoodLoss(OnnxOpConverter):
VALID_REDUCTIONS = {"mean", "sum", "none"}

@classmethod
def _impl_v13(cls, inputs, attr, params):
ignore_index = attr.get("ignore_index", None)
reduction = attr.get("reduction", b"mean").decode("utf-8")

if reduction not in cls.VALID_REDUCTIONS:
raise ValueError(
f"Unknown reduction type {reduction}, choices are {cls.VALID_REDUCTIONS}"
)

input_tensor, target_tensor = inputs[0], inputs[1]
if len(inputs) == 3:
weight_tensor = inputs[2]
else:
def run_calculation(
cls, input_tensor, target_tensor, weight_tensor=None, ignore_index=None, reduction="none"
):
if weight_tensor is None:
channels = infer_shape(input_tensor)[1]
weight_tensor = relay.ones(
[channels],
dtype=input_tensor.type_annotation.dtype,
dtype=infer_type(input_tensor).checked_type.dtype,
)

loss = -relay.gather(input_tensor, axis=1, indices=relay.expand_dims(target_tensor, 1))
@@ -3503,15 +3495,73 @@ def _impl_v13(cls, inputs, attr, params):
select_weights *= relay.cast_like(mask_tensor, select_weights)

weight_total = relay.sum(select_weights)
return loss, weight_total

@classmethod
def _impl_v13(cls, inputs, attr, params):
ignore_index = attr.get("ignore_index", None)
reduction = attr.get("reduction", b"mean").decode("utf-8")

if reduction not in cls.VALID_REDUCTIONS:
raise ValueError(
f"Unknown reduction type {reduction}, choices are {cls.VALID_REDUCTIONS}"
)

input_tensor, target_tensor = inputs[0], inputs[1]
if len(inputs) == 3:
weight_tensor = inputs[2]
else:
weight_tensor = None

loss, weight_total = cls.run_calculation(
input_tensor,
target_tensor,
weight_tensor=weight_tensor,
ignore_index=ignore_index,
reduction=reduction,
)
if reduction == "mean":
return relay.sum(loss) / weight_total
if reduction == "sum":
return relay.sum(loss)
return relay.sum(loss), weight_total
# Case reduction == 'none'
return loss
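Not part of the diff: a hypothetical NumPy reference (`nll_loss` and its arguments are illustrative names, shown only for the 2-D `(N, C)` case) sketching the semantics the refactored converter implements — per-class weights, `ignore_index` masking, and a `mean` reduction that divides by the summed selected weights rather than by `N`.

```python
import numpy as np

def nll_loss(log_probs, target, weight=None, ignore_index=None, reduction="mean"):
    # log_probs: (N, C) log-probabilities; target: (N,) class indices.
    n, c = log_probs.shape
    if weight is None:
        weight = np.ones(c, dtype=log_probs.dtype)
    w = weight[target]                               # per-sample weight
    loss = -log_probs[np.arange(n), target] * w      # weighted NLL per sample
    if ignore_index is not None:
        mask = (target != ignore_index)
        loss = loss * mask                           # zero out ignored samples
        w = w * mask                                 # ...and their weights
    if reduction == "mean":
        return loss.sum() / w.sum()                  # weighted mean, per ONNX spec
    if reduction == "sum":
        return loss.sum()
    return loss                                      # reduction == "none"

# Uniform log-probabilities over 3 classes: every sample's loss is log(3).
lp = np.log(np.full((2, 3), 1.0 / 3.0))
mean_loss = nll_loss(lp, np.array([0, 2]))
```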


class SoftmaxCrossEntropyLoss(OnnxOpConverter):
"""Operator converter for SoftmaxCrossEntropyLoss."""

@classmethod
def _impl_v13(cls, inputs, attr, params):
ignore_index = attr.get("ignore_index", None)
reduction = attr.get("reduction", b"mean").decode("utf-8")
input_tensor, target_tensor = inputs[0], inputs[1]
if len(inputs) == 3:
weight_tensor = inputs[2]
else:
weight_tensor = None

get_log_prob = attr["tvm_custom"]["num_outputs"] == 2
log_softmax_tensor = LogSoftmax.run_calculation(input_tensor, axes=[1])

loss, weight_total = NegativeLogLikelihoodLoss.run_calculation(
log_softmax_tensor,
target_tensor,
weight_tensor,
ignore_index=ignore_index,
reduction="none",
)

if reduction == "mean":
loss = relay.sum(loss) / weight_total
elif reduction == "sum":
loss = relay.sum(loss)

if get_log_prob:
return relay.TupleWrapper(relay.Tuple((loss, log_softmax_tensor)), 2)
return loss
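Not part of the diff: a minimal NumPy sketch (function and variable names are illustrative) of the composition the converter performs — log-softmax over the class axis (`axis=1`) followed by negative log likelihood, matching how `SoftmaxCrossEntropyLoss` reuses `LogSoftmax.run_calculation` and `NegativeLogLikelihoodLoss.run_calculation`.

```python
import numpy as np

def softmax_cross_entropy(x, target, reduction="mean"):
    # Log-softmax over the class axis, numerically stabilized.
    m = np.max(x, axis=1, keepdims=True)
    log_probs = x - m - np.log(np.exp(x - m).sum(axis=1, keepdims=True))
    # Negative log likelihood of the target class per sample.
    loss = -log_probs[np.arange(x.shape[0]), target]
    if reduction == "mean":
        return loss.mean()
    if reduction == "sum":
        return loss.sum()
    return loss

# All-zero logits give a uniform softmax over 4 classes,
# so every sample's cross-entropy is log(4).
uniform_loss = softmax_cross_entropy(np.zeros((2, 4)), np.array([1, 3]))
```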


# compatible operators that do NOT require any conversion.
_identity_list = []

@@ -3696,6 +3746,7 @@ def _get_convert_map(opset):
"RandomUniform": RandomUniform.get_converter(opset),
# Loss functions
"NegativeLogLikelihoodLoss": NegativeLogLikelihoodLoss.get_converter(opset),
"SoftmaxCrossEntropyLoss": SoftmaxCrossEntropyLoss.get_converter(opset),
}


47 changes: 13 additions & 34 deletions tests/python/frontend/onnx/test_forward.py
@@ -4803,74 +4803,47 @@ def verify_eyelike(indata):
"test_round",
"test_scan9_sum",
"test_scan_sum",
"test_sce_NCd1_mean_weight_negative_ii",
# With reduce_sum supported fully, these expanded tests should pass
"test_sce_NCd1_mean_weight_negative_ii_expanded",
"test_sce_NCd1_mean_weight_negative_ii_log_prob",
"test_sce_NCd1_mean_weight_negative_ii_log_prob_expanded",
"test_sce_NCd1d2d3_none_no_weight_negative_ii",
"test_sce_NCd1d2d3_none_no_weight_negative_ii_expanded",
"test_sce_NCd1d2d3_none_no_weight_negative_ii_log_prob",
"test_sce_NCd1d2d3_none_no_weight_negative_ii_log_prob_expanded",
"test_sce_NCd1d2d3_sum_weight_high_ii",
"test_sce_NCd1d2d3_sum_weight_high_ii_expanded",
"test_sce_NCd1d2d3_sum_weight_high_ii_log_prob",
"test_sce_NCd1d2d3_sum_weight_high_ii_log_prob_expanded",
"test_sce_NCd1d2d3d4d5_mean_weight",
"test_sce_NCd1d2d3d4d5_mean_weight_expanded",
"test_sce_NCd1d2d3d4d5_mean_weight_log_prob",
"test_sce_NCd1d2d3d4d5_mean_weight_log_prob_expanded",
"test_sce_NCd1d2d3d4d5_none_no_weight",
"test_sce_NCd1d2d3d4d5_none_no_weight_expanded",
"test_sce_NCd1d2d3d4d5_none_no_weight_log_prob",
"test_sce_NCd1d2d3d4d5_none_no_weight_log_prob_expanded",
"test_sce_mean",
"test_sce_mean_3d",
"test_sce_mean_3d_expanded",
"test_sce_mean_3d_log_prob",
"test_sce_mean_3d_log_prob_expanded",
"test_sce_mean_expanded",
"test_sce_mean_log_prob",
"test_sce_mean_log_prob_expanded",
"test_sce_mean_no_weight_ii",
"test_sce_mean_no_weight_ii_3d",
"test_sce_mean_no_weight_ii_3d_expanded",
"test_sce_mean_no_weight_ii_3d_log_prob",
"test_sce_mean_no_weight_ii_3d_log_prob_expanded",
"test_sce_mean_no_weight_ii_4d",
"test_sce_mean_no_weight_ii_4d_expanded",
"test_sce_mean_no_weight_ii_4d_log_prob",
"test_sce_mean_no_weight_ii_4d_log_prob_expanded",
"test_sce_mean_no_weight_ii_expanded",
"test_sce_mean_no_weight_ii_log_prob",
"test_sce_mean_no_weight_ii_log_prob_expanded",
"test_sce_mean_weight",
"test_sce_mean_weight_expanded",
"test_sce_mean_weight_ii",
"test_sce_mean_weight_ii_3d",
"test_sce_mean_weight_ii_3d_expanded",
"test_sce_mean_weight_ii_3d_log_prob",
"test_sce_mean_weight_ii_3d_log_prob_expanded",
"test_sce_mean_weight_ii_4d",
"test_sce_mean_weight_ii_4d_expanded",
"test_sce_mean_weight_ii_4d_log_prob",
"test_sce_mean_weight_ii_4d_log_prob_expanded",
"test_sce_mean_weight_ii_expanded",
"test_sce_mean_weight_ii_log_prob",
"test_sce_mean_weight_ii_log_prob_expanded",
"test_sce_mean_weight_log_prob",
"test_sce_mean_weight_log_prob_expanded",
"test_sce_none",
"test_sce_none_expanded",
"test_sce_none_log_prob",
"test_sce_none_log_prob_expanded",
"test_sce_none_weights",
"test_sce_none_weights_expanded",
"test_sce_none_weights_log_prob",
"test_sce_none_weights_log_prob_expanded",
"test_sce_sum",
"test_sce_sum_expanded",
"test_sce_sum_log_prob",
"test_sce_sum_log_prob_expanded",
# These sce tests seem to have the same issue as the nll_loss test
# referenced here: https://github.com/apache/tvm/issues/8918 and produce NaNs sometimes
"test_sce_NCd1d2d3_none_no_weight_negative_ii",
"test_sce_NCd1d2d3_none_no_weight_negative_ii_log_prob",
"test_sce_NCd1d2d3_sum_weight_high_ii",
"test_sce_NCd1d2d3_sum_weight_high_ii_log_prob",
"test_sequence_insert_at_back",
"test_sequence_insert_at_front",
"test_simple_rnn_defaults",
@@ -4952,6 +4925,12 @@ def test_onnx_nodes(target, dev, onnx_test):
# for some reason the ONNX test crops the
# roialign results to 4 decimal places
atol = 1e-4

if "_sce_" in test_dir:
# complicated loss functions like SoftmaxCrossEntropy can have minor variations
# in accuracy depending on implementation
atol = 1e-4

onnx_model = onnx.load(test_dir + "/model.onnx")
inputs = []
outputs = []