Skip to content

Commit

Permalink
[Onnx] Support Negative Log Loss (#8872)
Browse files Browse the repository at this point in the history
* nll loss v1

* add converter

* decode strings in byte form

* decode variable length inputs

* make shapes correct

* unsqueeze

* proper weight handling

* simplify if statement

* fix tests

* add comment about tests

* delete extra file

* lint

* so cool

* jostle ci

Co-authored-by: Andrew Zhao Luo <andrewzhaoluo@system76-pc.localdomain>
  • Loading branch information
AndrewZhaoLuo and Andrew Zhao Luo authored Sep 2, 2021
1 parent eaf888c commit 910b73e
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 18 deletions.
59 changes: 59 additions & 0 deletions python/tvm/relay/frontend/onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

import numpy as np
import tvm
from tvm import relay
from tvm.ir import IRModule
from tvm.topi.utils import get_const_tuple

Expand Down Expand Up @@ -3454,6 +3455,62 @@ def _impl_v1(cls, inputs, attr, params):
return vals


class NegativeLogLikelihoodLoss(OnnxOpConverter):
"""Operator converter for random_uniform"""

VALID_REDUCTIONS = {"mean", "sum", "none"}

@classmethod
def _impl_v13(cls, inputs, attr, params):
ignore_index = attr.get("ignore_index", None)
reduction = attr.get("reduction", b"mean").decode("utf-8")

if reduction not in cls.VALID_REDUCTIONS:
raise ValueError(
f"Unknown reduction type {reduction}, choices are {cls.VALID_REDUCTIONS}"
)

input_tensor, target_tensor = inputs[0], inputs[1]
if len(inputs) == 3:
weight_tensor = inputs[2]
else:
channels = infer_shape(input_tensor)[1]
weight_tensor = relay.ones(
[channels],
dtype=input_tensor.type_annotation.dtype,
)

loss = -relay.gather(input_tensor, axis=1, indices=relay.expand_dims(target_tensor, 1))
loss = relay.squeeze(loss, axis=[1])

expanded_target_tensor = relay.expand_dims(target_tensor, 0)
expanded_target_tensor = relay.nn.batch_flatten(expanded_target_tensor)
flattened_weights = relay.gather_nd(weight_tensor, expanded_target_tensor)
select_weights = relay.reshape_like(flattened_weights, loss)
loss *= select_weights

if ignore_index is not None:
# "Ignore" values whose target is the ignore_index
mask_tensor = relay.equal(
target_tensor, relay.const(ignore_index, dtype=target_tensor.type_annotation.dtype)
)
mask_tensor = relay.const(1, dtype="int8") - relay.cast(mask_tensor, "int8")
loss *= relay.cast_like(mask_tensor, loss)

# This is not explained super clearly in the onnx spec, but masked values don't
# contribute toward the final value in reduction
select_weights *= relay.cast_like(mask_tensor, select_weights)

weight_total = relay.sum(select_weights)

if reduction == "mean":
return relay.sum(loss) / weight_total
if reduction == "sum":
return relay.sum(loss)
# Case reduction == 'none'
return loss


# compatible operators that do NOT require any conversion.
_identity_list = []

Expand Down Expand Up @@ -3636,6 +3693,8 @@ def _get_convert_map(opset):
"ConvInteger": ConvInteger.get_converter(opset),
# Random number generation.
"RandomUniform": RandomUniform.get_converter(opset),
# Loss functions
"NegativeLogLikelihoodLoss": NegativeLogLikelihoodLoss.get_converter(opset),
}


Expand Down
19 changes: 1 addition & 18 deletions tests/python/frontend/onnx/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -4735,41 +4735,24 @@ def verify_eyelike(indata):
"test_momentum_multiple",
"test_mvn",
"test_nesterov_momentum",
"test_nllloss_NC",
# When unsqueeze is fully supported, remaining nllloss tests should work:
"test_nllloss_NC_expanded",
"test_nllloss_NCd1",
"test_nllloss_NCd1_expanded",
"test_nllloss_NCd1_ii",
"test_nllloss_NCd1_ii_expanded",
"test_nllloss_NCd1_mean_weight_negative_ii",
"test_nllloss_NCd1_mean_weight_negative_ii_expanded",
"test_nllloss_NCd1_weight",
"test_nllloss_NCd1_weight_expanded",
"test_nllloss_NCd1_weight_ii",
"test_nllloss_NCd1_weight_ii_expanded",
"test_nllloss_NCd1d2",
"test_nllloss_NCd1d2_expanded",
"test_nllloss_NCd1d2_no_weight_reduction_mean_ii",
"test_nllloss_NCd1d2_no_weight_reduction_mean_ii_expanded",
"test_nllloss_NCd1d2_reduction_mean",
"test_nllloss_NCd1d2_reduction_mean_expanded",
"test_nllloss_NCd1d2_reduction_sum",
"test_nllloss_NCd1d2_reduction_sum_expanded",
"test_nllloss_NCd1d2_with_weight",
"test_nllloss_NCd1d2_with_weight_expanded",
"test_nllloss_NCd1d2_with_weight_reduction_mean",
"test_nllloss_NCd1d2_with_weight_reduction_mean_expanded",
"test_nllloss_NCd1d2_with_weight_reduction_sum",
"test_nllloss_NCd1d2_with_weight_reduction_sum_expanded",
"test_nllloss_NCd1d2_with_weight_reduction_sum_ii",
"test_nllloss_NCd1d2_with_weight_reduction_sum_ii_expanded",
"test_nllloss_NCd1d2d3_none_no_weight_negative_ii",
"test_nllloss_NCd1d2d3_none_no_weight_negative_ii_expanded",
"test_nllloss_NCd1d2d3_sum_weight_high_ii",
"test_nllloss_NCd1d2d3_sum_weight_high_ii_expanded",
"test_nllloss_NCd1d2d3d4d5_mean_weight",
"test_nllloss_NCd1d2d3d4d5_mean_weight_expanded",
"test_nllloss_NCd1d2d3d4d5_none_no_weight",
"test_nllloss_NCd1d2d3d4d5_none_no_weight_expanded",
"test_pow_types_float",
"test_pow_types_float32_int32",
Expand Down

0 comments on commit 910b73e

Please sign in to comment.