Skip to content

Commit

Permalink
[Onnx] Add SoftmaxCrossEntropyLoss (apache#8906)
Browse files Browse the repository at this point in the history
* nll loss v1

* add converter

* decode strings in byte form

* decode variable length inputs

* make shapes correct

* unsqueeze

* proper weight handling

* simplify if statement

* fix tests

* add comment about tests

* delete extra file

* lint

* so cool

* Update CI Lint Image Version (apache#8841)

* Update CI Lint Image Version

* trigger

* [BUG] ToBasicBlockNormalForm immutability (apache#8778)

* ToBasicBlockNormalForm immutability

* better comment on ToBasicBlock

* refine comment of ToBasicBlockForm

* [GRAPH EXECUTOR,VM] Add benchmarking function to graph executor and vm (apache#8807)

* [GRAPH EXECUTOR,VM] Add benchmarking function to graph executor and vm

This new benchmarking function is just a convenience function for
calling time_evaluator on the underlying module. Hopefully this should
make it easier for users to get good benchmarks of their code.

* formatting

* import order

* more test, more comments, more precision

* fix tests

* add seconds descriptions to doc

* Apply CPPLint to CRT Tests (apache#8844)

This one was a bit trickier as there was more usage of dynamic arrays and less safe casts. I've tried to minimise the changes to just those required to passing linting.

* [Relay][TOPI] Support of depthwise conv2d NHWC for Mali/Bifrost. (apache#8584)

* [Relay][TOPI] Support of depthwise conv2d NHWC for Mali/Bifrost.

Added initial tunable autotvm templates for depthwise conv2d with
NHWC layout for Mali and Bifrost.

* [Relay][TOPI] Misc fixes for depthwise conv2d Mali/Bifrost.

- Fix assert for Bifrost.
- Set reasonable default axis splits to avoid using tophub for NHWC.
- Fixed typo: arm cpu -> Mali.

* [Relay][TOPI] Fixed formatting in depthwise conv2d Mali/Bifrost.

* Support for CMSIS-NN in Corstone300 Makefile (apache#8831)

Change-Id: Ifc2305db4e11d1d15d45407287f8f0bea469100a

* [microtvm][Zephyr] Increase timeout to fix flaky tests (apache#8846)

* increase timeout

* trigger

* [AMP] Bump up tolerance on flaky test (apache#8850)

* bumpy up tol

* bumped tolerance up even more

* jostle ci

* [Hexagon] Rework tvm.target.hexagon() interface (apache#8823)

* [Hexagon] Rework tvm.target.hexagon() interface

Make the tvm.target.hexagon() function take most options as keyword
parameters. This will allow adding additional parameters without changing
the interface.

No changes are required to existing code, except for changing positional
parameters following the CPU version to keyword parameters, and updating
the names of the keyword parameters:
  sim_args  -> sim_options,
  llvm_args -> llvm_options,
although the old names will be accepted for the time being.

* formatting

* change ' to "

* Rename 'args' to 'config' for clarity

* Use 'strip' instad of 'replace'

* Restart build

* [Pattern matching] Add an option to rewrite the graph only once (apache#8843)

* [Pattern matching] Add an option to rewrite the graph only once

If the graph returned from the callback consists of the original
pattern, the rewriter will run in the loop, which is not always desired.
So this patch proposes an option to run the rewriter only once.

Change-Id: I85cf0a055b8961d52394f21c1e4d7aad0a7e1d06

* Make rewrite_once default to false

Change-Id: Idf6f01f254c403158883681e75c2a5978efbd2d0

* update gpu and cpu (apache#8853)

* VTA cmake change to include Verilator header for building tsim library (apache#8797)

* VTA cmake file require Verilator include for tsim target. VTA module.cc uses svOpenArrayHandle to send wide data through DPI

* Refactor Verialtor check conditions

* Build TSIM only for CPU target. CPU target don't use -Werror to compile with Verilator. Jenkinsfile to have tvm_multilib_tsim defined for CPU build target.

* remove build/libvta_tsim.so from non tsim targeting builds

* Revert to enable TSIM build i386. Revert to -Werror in CPU config. Remove verilator CPP objects from cmake config for tsim and put them as include into vta module.cc to avoid Verilator compilation warnings

* [FIX] Bug fix for a floormod rewrite simplify rule (apache#8852)

* Update rewrite_simplify.cc

* Update test_arith_rewrite_simplify.py

* Update test_arith_rewrite_simplify.py

* Update test_arith_rewrite_simplify.py

* move rust lint script (apache#8726)

* [AMP] Disallow fp16 conversion for summation-like ops (apache#8810)

* [AMP] Disallow fp16 conversion for summation-like ops

* test only structural equality

* [TOPI] [Relay] Sparse Conv2d Implementation for 3x3 kernels (apache#8605)

* [topi] add spconv2d_3x3 nhwc

* [relay] sparse_conv2d: add kernel_size attr

* [relay] add strategy for spconv2d_3x3 nhwc

* [relay] pass to convert spconv2d with const args

* [relay] convert sparse conv2d pass fixes

* use array for sparse conv2d attr

* fixup 1x1 tests; new 3x3 tests

* extend repeat_interleave op for relay.Expr (apache#8839)

Co-authored-by: Valery Chernov <valery.chernov@deelvin.com>

* Change AOT from ExprVisitor to MixedModeVisitor (apache#8856)

This should allow better scale-ability for AOT when targeting larger networks.

* Add a PaddlePaddle Frontend (apache#8645)

* fix some problems for matmul

* fix some problems for matmul

* add alpha parameter for matmul

* remove unnecessary condition

* add TranslatedLayer which support model loaded by jit.load

* add mul operator support

* Add padding mode support for conv/pool2d

* support 4 two-tuples

* add paddle test case

* add paddle conv2d  case

* update test_forward.py

* fix paddle convert_matmul

* add paddle multiply and matmul op test case

* add test case and fix bug

* delete import pandas

* add paddlepaddle tests

* modify the variable name of convert_reshape

* formatting

* formatting

* use black to format python code

* pylint check

* Remove fluid api

* black format

Co-authored-by: root <root@bjyz-sys-gpu-kongming3.bjyz.baidu.com>
Co-authored-by: wjj19950828 <wjjisloser@163.com>
Co-authored-by: heliqi <1101791222@qq.com>
Co-authored-by: Junru Shao <junrushao1994@gmail.com>

* [Runtime] add set_output_zero_copy (apache#8497)

* Update graph_executor.h

* Update graph_executor.cc

* modify zero copy UT add set input zero copy

* modify C style

* add runtime test

* realy build  generatr the json

Co-authored-by: hwstaff <hwstaff@hwstaffdeMacBook-Pro.local>

* [Hexagon] Change declaration order of unique_ptr objects to fix crash (apache#8859)

A crash occurs when automatically deleting an instance of
CodeGenHexagon because the LLVMContext object has already been
freed. Objects of both types are created using unique_ptr, but
the object managed by the LLVMContext unique_ptr is passed to
CodeGenHexagon object (not as a unique_ptr).

This crash is fixed by moving the declaration of the LLVMContext
object before the CodeGenHexagon object. I'm not sure if this
is the best way to fix this, but it does fix the crash. Also,
in other files, the LLVMContext object is always created first.

Co-authored-by: Cahoon, Brendon <bcahoon@quicinc.com>

* [Graph Executor, VM] Add end to end benchmarking of models (apache#8858)

Add benchmarking that includes ovearhead of transfering inputs and
outputs to and from the device. This should give an accurate measurement
of the runtime a user would see when using the model. This is
accomplished by adding functions that run from inputs to return values
into the graph executor and the VM.

* [UnitTests] Expose TVM pytest helpers as plugin (apache#8532)

* [UnitTests] Expose TVM pytest helpers as plugin

Previously, pytest helper utilities such as automatic parametrization
of `target`/`dev`, or `tvm.testing.parameter` were only available for
tests within the `${TVM_HOME}/tests` directory.  This PR extracts the
helper utilities into an importable plugin, which can be used in
external tests (e.g. one-off debugging).

* [UnitTests] Refactor the plugin-specific logic out into plugin.py.

* [UnitTests] Moved marker definition out to global variable.

* Remove AOT Executor header from Arduino project (apache#8857)

* [Community] @mdw-octoml -> Reviewer (apache#8868)

* [TIR] Fix opaque access in buffer locator pass and match_buffer in region detector (apache#8855)

* init

* fix

* Update src/tir/transforms/plan_update_buffer_allocation_location.cc

Co-authored-by: Ruihang Lai <lairuihangdongdong@qq.com>

* Update src/tir/transforms/plan_update_buffer_allocation_location.cc

Co-authored-by: Ruihang Lai <lairuihangdongdong@qq.com>

* address

Co-authored-by: Junru Shao <junrushao1994@gmail.com>
Co-authored-by: Ruihang Lai <lairuihangdongdong@qq.com>

* [Autoscheduler] Configurable workload keys (apache#8862)

* change workload keys

* remove binary string comparison

* append the tuple not every integer

* clean up

* lint

* dump workload keys to dags

* fix things

* change some strings

* misc fixes, add tests

* jostle ci

* [Tutorial][Executor] Fix the usage of executors in tutorials (apache#8586)

* fix: executor usage for keras tutorial

* fix: executor usage for onnx tutorial

* [Tutorial][Executor] Fix executors in tutorials

* [Frontend][Onnx] Simplify onnx input since name accesses are not reliable. (apache#8867)

* Simplify onnx input since name accesses are no longer supported.

* move Celu importer.

* [TIR] GetBlockReadWriteRegion (apache#8875)

* [TIR] GetBlockReadWriteRegion

* Fix black issue

* Use constant reference for the interface

* Fix lint issue

* [RISCV] Add support for llvm parameter -mabi (-target-abi) (apache#8860)

* [Community] @manupa-arm -> Committer (apache#8870)

* adding Manupa to the contributors list

* re-trigger CI

* [RPC] Fix ios_rpc build (apache#8864)

* [Vulkan][Target] Added the driver name to the vulkan target string. (apache#8882)

Driver name (e.g. "NVIDIA", "radv", "AMD open-source driver") is read
from the `driverName` property in
[VkPhysicalDeviceDriverProperties](https://www.khronos.org/registry/vulkan/specs/1.2-extensions/man/html/VkPhysicalDeviceDriverProperties.html),
or is left as `"unknown_driver_name"` if the driver does not support
querying the driver name.

* [ONNX][TOPI] Support select_last_index for argmin/max (apache#8816)

* support select_last_index for argmin/max

* reverse conditions which made on accident

* forward args in reduce.py

* make proper nodes for reduction ops

* remove complicated nested lambdas

* fix lambda capture for conversion

* forward more arguments

* forward more args

* enable onnx tests

* wrapping casts to remove ambiguity

* revert changes extraneous

* correct incorrect attrs being used for ops

* change attributes

* remove old impl

* register new attribute node

* clean up test

* reformat

* reformat

* coolio

* stable comparison

* casts to avoid ambiguity

* casting more

* correct arg passing

* support select_last_index for argmin/max

* reverse conditions which made on accident

* forward args in reduce.py

* make proper nodes for reduction ops

* remove complicated nested lambdas

* fix lambda capture for conversion

* forward more arguments

* forward more args

* enable onnx tests

* wrapping casts to remove ambiguity

* revert changes extraneous

* correct incorrect attrs being used for ops

* change attributes

* remove old impl

* register new attribute node

* clean up test

* reformat

* reformat

* coolio

* stable comparison

* casts to avoid ambiguity

* casting more

* correct arg passing

* fix broken input

* OneElementReduceAttrs-->ArgReduceAttrs"

* reduce boilerplate

* change names

* remove log statement

* jostle ci

Co-authored-by: Andrew Zhao Luo <andrewzhaoluo@system76-pc.localdomain>

* refactor optimize GEMM on CPU tutorial (apache#8825)

* refactor optimize GEMM on CPU tutorial

* fix lint errors

* fix more lint errors

* fix typo

* fix problem with redefinition of `k`
add TODO and comments around loop unrolling
clarify note on the array packing figure

* reword general description of array packing

* grap kaxis from compute definition

* remove duplicate comments on unrolling

* Change target string to Target object in the TE compiler and interpreter (apache#8835)

* # This is a combination of 2 commits.
# This is the 1st commit message:

Initial changes

# This is the commit message apache#2:

Ftarget string -> Target object works!

* Fix remaining target strings

* fix bad rebase

* Fix typo

* 1 more bad rebase fix

* Lint

* typo

* Forgot to commit this

* Add TargetStrHash and Map<Target... to std::unordered_map<Target... conversion fn

* Passing most tests, yay

* remove some comments

* lint

* target-str-to-target-object

* Respond to change requests

Co-authored-by: Jared Roesch <roeschinc@gmail.com>

* [TensorIR][M2a] CacheRead/Write (apache#8863)

Co-authored-by: Junru Shao <junrushao1994@gmail.com>
Co-authored-by: Wuwei Lin <wuwei@apache.org>
Co-authored-by: Ruihang Lai <lairuihangdongdong@qq.com>
Co-authored-by: Hongyi Jin <3231950289@qq.com>
Co-authored-by: Siyuan Feng <Hzfengsy@sjtu.edu.cn>
Co-authored-by: Bohan Hou <32121147+spectrometerHBH@users.noreply.github.com>

* [CI] make pre-commit hooks to run on every push instead of every commit (apache#8888)

* [TVMScript] Fix printing ForNode annotations (apache#8891)

* [1/10] CMSIS-NN graph partitioner for softmax (apache#8653)

* cmsis graph partitioner for softmax

Change-Id: I80ecd7bc5351f241b4674ef53b36e4398c8adb83

* Updated docstring in the partioning function

Change-Id: Ieb4b623e5929cfdb6aa0235db64c825fac8d7055

* [microTVM][RVM] Add Arduino RVM (apache#8748)

* Functioning Arduino Vagrant VM

Begin building Arduino Vagrant VM

Mostly working Vagrant VM

Changes for debugging

Add ignored json file

Fix venv path

* Generalize parts of RVM for multiple platforms

cwd hack

Add unit tests from apps directory to task_python_microtvm.sh

Generalize parts of RVM for multiple platforms

* Add Vagrantfile lint exceptions

* Address PR comments

Address Mehrdad's PR comments

More PR comments

Documentation tweaks

Add dialout group to user

* Rerun tests

* Spresense fix

* Rerun CI tests

* Rerun tests

* sce loss example

* add comments, remove other tests

* lint

* lint

* jostle

* lint up

* jostle

* uncomment some tests

* proper return

* clean up

* lint

* minor merge errors

Co-authored-by: Andrew Zhao Luo <andrewzhaoluo@system76-pc.localdomain>
Co-authored-by: Mehrdad Hessar <mhessar@octoml.ai>
Co-authored-by: Jiawei Liu <jaway.liu@gmail.com>
Co-authored-by: Tristan Konolige <tkonolige@octoml.ai>
Co-authored-by: Christopher Sidebottom <chris.sidebottom@arm.com>
Co-authored-by: Anastasia Stulova <38433336+AnastasiaStulova@users.noreply.github.com>
Co-authored-by: Ashutosh Parkhi <86472128+ashutosh-arm@users.noreply.github.com>
Co-authored-by: Krzysztof Parzyszek <kparzysz@quicinc.com>
Co-authored-by: Elen Kalda <elen.kalda@arm.com>
Co-authored-by: Anton Sorokin <anton.a.sorokin@intel.com>
Co-authored-by: Chenfan <jcf94@outlook.com>
Co-authored-by: masahi <masahi129@gmail.com>
Co-authored-by: Tantalus13A98B5F <jsl_713@live.com>
Co-authored-by: Valery Chernov <black.chervi@gmail.com>
Co-authored-by: Valery Chernov <valery.chernov@deelvin.com>
Co-authored-by: Jason <928090362@qq.com>
Co-authored-by: root <root@bjyz-sys-gpu-kongming3.bjyz.baidu.com>
Co-authored-by: wjj19950828 <wjjisloser@163.com>
Co-authored-by: heliqi <1101791222@qq.com>
Co-authored-by: Junru Shao <junrushao1994@gmail.com>
Co-authored-by: Swift.Sun <sunjiwei@yeah.net>
Co-authored-by: hwstaff <hwstaff@hwstaffdeMacBook-Pro.local>
Co-authored-by: Cahoon, Brendon <bcahoon@quicinc.com>
Co-authored-by: Lunderberg <Lunderberg@users.noreply.github.com>
Co-authored-by: Yizhi Liu <liuyizhi@apache.org>
Co-authored-by: Siyuan Feng <Hzfengsy@vip.qq.com>
Co-authored-by: Ruihang Lai <lairuihangdongdong@qq.com>
Co-authored-by: Josh Fromm <jwfromm@octoml.ai>
Co-authored-by: Alexander Pivovarov <pivovaa@amazon.com>
Co-authored-by: Thierry Moreau <tmoreau@octoml.ai>
Co-authored-by: Egor Churaev <egor.churaev@gmail.com>
Co-authored-by: Adam Straw <astraw@octoml.ai>
Co-authored-by: Lily Orth-Smith <lilyorthsmith@gmail.com>
Co-authored-by: Jared Roesch <roeschinc@gmail.com>
Co-authored-by: Siyuan Feng <Hzfengsy@sjtu.edu.cn>
Co-authored-by: Wuwei Lin <wuwei@apache.org>
Co-authored-by: Hongyi Jin <3231950289@qq.com>
Co-authored-by: Bohan Hou <32121147+spectrometerHBH@users.noreply.github.com>
Co-authored-by: Michalis Papadimitriou <mikepapadim@users.noreply.github.com>
Co-authored-by: Gavin Uberti <guberti@users.noreply.github.com>
  • Loading branch information
Show file tree
Hide file tree
Showing 2 changed files with 88 additions and 60 deletions.
107 changes: 81 additions & 26 deletions python/tvm/relay/frontend/onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
"""ONNX: Open Neural Network Exchange frontend for Relay."""
import copy
import warnings
from typing import Optional

import numpy as np
import tvm
Expand Down Expand Up @@ -1926,18 +1927,22 @@ def _impl_v13(cls, inputs, attr, params):
class LogSoftmax(OnnxOpConverter):
"""Operator converter for Softmax."""

@classmethod
def run_calculation(cls, x, axes):
"""Run the calculation for Log Softmax calculation."""
m = _op.max(x, axes, keepdims=True)
e = _op.exp(x - m)
s = _op.sum(e, axes, keepdims=True)
return x - m - _op.log(s)

@classmethod
def _impl_v1(cls, inputs, attr, params):
axis = attr.get("axis", 1)
ndim = len(infer_shape(inputs[0]))
if axis < 0:
axis += ndim
axes = list(range(axis, ndim))
x = inputs[0]
m = _op.max(x, axes, keepdims=True)
e = _op.exp(x - m)
s = _op.sum(e, axes, keepdims=True)
return x - m - _op.log(s)
return cls.run_calculation(inputs[0], axes)

@classmethod
def _impl_v13(cls, inputs, attr, params):
Expand All @@ -1946,11 +1951,7 @@ def _impl_v13(cls, inputs, attr, params):
if axis < 0:
axis += ndim
axes = [axis]
x = inputs[0]
m = _op.max(x, axes, keepdims=True)
e = _op.exp(x - m)
s = _op.sum(e, axes, keepdims=True)
return x - m - _op.log(s)
return cls.run_calculation(inputs[0], axes)


class Hardmax(OnnxOpConverter):
Expand Down Expand Up @@ -3611,33 +3612,30 @@ def _impl_v1(cls, inputs, attr, params):


class NegativeLogLikelihoodLoss(OnnxOpConverter):
"""Operator converter for random_uniform"""
"""Operator converter for NegativeLogLikehoodLoss"""

VALID_REDUCTIONS = {"mean", "sum", "none"}

@classmethod
def _impl_v13(cls, inputs, attr, params):
ignore_index = attr.get("ignore_index", None)
reduction = attr.get("reduction", b"mean").decode("utf-8")

if reduction not in cls.VALID_REDUCTIONS:
raise ValueError(
f"Unknown reduction type {reduction}, choices are {cls.VALID_REDUCTIONS}"
)

input_tensor, target_tensor = inputs[0], inputs[1]

def run_calculation(
cls: "NegativeLogLikelihoodLoss",
input_tensor: relay.Expr,
target_tensor: relay.Expr,
weight_tensor: Optional[relay.Expr],
ignore_index: int,
):
"""Run calculation for NegativeLogLikelihood, returning output tensor and
weight tensor used for mean-style reductions.
"""
# Convert negative indices --> positive indices for gather ops, note we have to
# use the original target tensor to interact with ignore_index to have proper behavior.
normalized_target_tensor = normalize_gather_indices(input_tensor, target_tensor, 1)

if len(inputs) == 3:
weight_tensor = inputs[2]
else:
if weight_tensor is None:
channels = infer_shape(input_tensor)[1]
weight_tensor = relay.ones(
[channels],
dtype=input_tensor.type_annotation.dtype,
dtype=infer_type(input_tensor).checked_type.dtype,
)

loss = -relay.gather(
Expand Down Expand Up @@ -3670,7 +3668,30 @@ def _impl_v13(cls, inputs, attr, params):
select_weights *= relay.cast_like(mask_tensor, select_weights)

weight_total = relay.sum(select_weights)
return loss, weight_total

@classmethod
def _impl_v13(cls, inputs, attr, params):
ignore_index = attr.get("ignore_index", None)
reduction = attr.get("reduction", b"mean").decode("utf-8")

if reduction not in cls.VALID_REDUCTIONS:
raise ValueError(
f"Unknown reduction type {reduction}, choices are {cls.VALID_REDUCTIONS}"
)

input_tensor, target_tensor = inputs[0], inputs[1]
if len(inputs) == 3:
weight_tensor = inputs[2]
else:
weight_tensor = None

loss, weight_total = cls.run_calculation(
input_tensor,
target_tensor,
weight_tensor=weight_tensor,
ignore_index=ignore_index,
)
if reduction == "mean":
return relay.sum(loss) / weight_total
if reduction == "sum":
Expand All @@ -3679,6 +3700,39 @@ def _impl_v13(cls, inputs, attr, params):
return loss


class SoftmaxCrossEntropyLoss(OnnxOpConverter):
"""Operator converter for SCE_loss"""

@classmethod
def _impl_v13(cls, inputs, attr, params):
ignore_index = attr.get("ignore_index", None)
reduction = attr.get("reduction", b"mean").decode("utf-8")
input_tensor, target_tensor = inputs[0], inputs[1]
if len(inputs) == 3:
weight_tensor = inputs[2]
else:
weight_tensor = None

get_log_prob = attr["tvm_custom"]["num_outputs"] == 2
log_softmax_tensor = LogSoftmax.run_calculation(input_tensor, axes=[1])

loss, weight_total = NegativeLogLikelihoodLoss.run_calculation(
log_softmax_tensor,
target_tensor,
weight_tensor,
ignore_index=ignore_index,
)

if reduction == "mean":
loss = relay.sum(loss) / weight_total
elif reduction == "sum":
loss = relay.sum(loss)

if get_log_prob:
return relay.TupleWrapper(relay.Tuple((loss, log_softmax_tensor)), 2)
return loss


class Adagrad(OnnxOpConverter):
"""Operator converter for adagrad op."""

Expand Down Expand Up @@ -4037,6 +4091,7 @@ def _get_convert_map(opset):
"RandomUniform": RandomUniform.get_converter(opset),
# Loss functions / training
"NegativeLogLikelihoodLoss": NegativeLogLikelihoodLoss.get_converter(opset),
"SoftmaxCrossEntropyLoss": SoftmaxCrossEntropyLoss.get_converter(opset),
"Adagrad": Adagrad.get_converter(opset),
"Adam": Adam.get_converter(opset),
"Momentum": Momentum.get_converter(opset),
Expand Down
41 changes: 7 additions & 34 deletions tests/python/frontend/onnx/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -4944,73 +4944,40 @@ def verify_eyelike(indata):
"test_round",
"test_scan9_sum",
"test_scan_sum",
"test_sce_NCd1_mean_weight_negative_ii",
# With reduce_sum supported fully, these expanded tests should pass
"test_sce_NCd1_mean_weight_negative_ii_expanded",
"test_sce_NCd1_mean_weight_negative_ii_log_prob",
"test_sce_NCd1_mean_weight_negative_ii_log_prob_expanded",
"test_sce_NCd1d2d3_none_no_weight_negative_ii",
"test_sce_NCd1d2d3_none_no_weight_negative_ii_expanded",
"test_sce_NCd1d2d3_none_no_weight_negative_ii_log_prob",
"test_sce_NCd1d2d3_none_no_weight_negative_ii_log_prob_expanded",
"test_sce_NCd1d2d3_sum_weight_high_ii",
"test_sce_NCd1d2d3_sum_weight_high_ii_expanded",
"test_sce_NCd1d2d3_sum_weight_high_ii_log_prob",
"test_sce_NCd1d2d3_sum_weight_high_ii_log_prob_expanded",
"test_sce_NCd1d2d3d4d5_mean_weight",
"test_sce_NCd1d2d3d4d5_mean_weight_expanded",
"test_sce_NCd1d2d3d4d5_mean_weight_log_prob",
"test_sce_NCd1d2d3d4d5_mean_weight_log_prob_expanded",
"test_sce_NCd1d2d3d4d5_none_no_weight",
"test_sce_NCd1d2d3d4d5_none_no_weight_expanded",
"test_sce_NCd1d2d3d4d5_none_no_weight_log_prob",
"test_sce_NCd1d2d3d4d5_none_no_weight_log_prob_expanded",
"test_sce_mean",
"test_sce_mean_3d",
"test_sce_mean_3d_expanded",
"test_sce_mean_3d_log_prob",
"test_sce_mean_3d_log_prob_expanded",
"test_sce_mean_expanded",
"test_sce_mean_log_prob",
"test_sce_mean_log_prob_expanded",
"test_sce_mean_no_weight_ii",
"test_sce_mean_no_weight_ii_3d",
"test_sce_mean_no_weight_ii_3d_expanded",
"test_sce_mean_no_weight_ii_3d_log_prob",
"test_sce_mean_no_weight_ii_3d_log_prob_expanded",
"test_sce_mean_no_weight_ii_4d",
"test_sce_mean_no_weight_ii_4d_expanded",
"test_sce_mean_no_weight_ii_4d_log_prob",
"test_sce_mean_no_weight_ii_4d_log_prob_expanded",
"test_sce_mean_no_weight_ii_expanded",
"test_sce_mean_no_weight_ii_log_prob",
"test_sce_mean_no_weight_ii_log_prob_expanded",
"test_sce_mean_weight",
"test_sce_mean_weight_expanded",
"test_sce_mean_weight_ii",
"test_sce_mean_weight_ii_3d",
"test_sce_mean_weight_ii_3d_expanded",
"test_sce_mean_weight_ii_3d_log_prob",
"test_sce_mean_weight_ii_3d_log_prob_expanded",
"test_sce_mean_weight_ii_4d",
"test_sce_mean_weight_ii_4d_expanded",
"test_sce_mean_weight_ii_4d_log_prob",
"test_sce_mean_weight_ii_4d_log_prob_expanded",
"test_sce_mean_weight_ii_expanded",
"test_sce_mean_weight_ii_log_prob",
"test_sce_mean_weight_ii_log_prob_expanded",
"test_sce_mean_weight_log_prob",
"test_sce_mean_weight_log_prob_expanded",
"test_sce_none",
"test_sce_none_expanded",
"test_sce_none_log_prob",
"test_sce_none_log_prob_expanded",
"test_sce_none_weights",
"test_sce_none_weights_expanded",
"test_sce_none_weights_log_prob",
"test_sce_none_weights_log_prob_expanded",
"test_sce_sum",
"test_sce_sum_expanded",
"test_sce_sum_log_prob",
"test_sce_sum_log_prob_expanded",
"test_sequence_insert_at_back",
"test_sequence_insert_at_front",
Expand Down Expand Up @@ -5093,6 +5060,12 @@ def test_onnx_nodes(target, dev, onnx_test):
# for some reason the ONNX test crops the
# roialign results to 4 decimal places
atol = 1e-4

if "_sce_" in test_dir:
# complicated loss functions like SoftmaxCrossEntropy can have minor variations
# in accuracy depending on implementation
atol = 1e-4

onnx_model = onnx.load(test_dir + "/model.onnx")
inputs = []
outputs = []
Expand Down

0 comments on commit d590349

Please sign in to comment.