From 018364902047864767a9f0deae0c09cebdf714f2 Mon Sep 17 00:00:00 2001 From: Wang Yi <53533850+marigoold@users.noreply.github.com> Date: Sat, 2 Apr 2022 19:06:03 +0800 Subject: [PATCH] Add module information in error message when type mismatch (#7918) * add module name in error message * fix bug for pyobject without __module__ attr * add unittest, refine code * format code * remove module name in error message * remove unittest for oneflow input * move test_error_msg.py to test/exceptions from test/misc Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com> --- oneflow/api/python/functional/py_function.cpp | 7 ++-- .../oneflow/test/exceptions/test_error_msg.py | 39 +++++++++++++++++++ 2 files changed, 43 insertions(+), 3 deletions(-) create mode 100644 python/oneflow/test/exceptions/test_error_msg.py diff --git a/oneflow/api/python/functional/py_function.cpp b/oneflow/api/python/functional/py_function.cpp index 7df18262e8d..4c64448d930 100644 --- a/oneflow/api/python/functional/py_function.cpp +++ b/oneflow/api/python/functional/py_function.cpp @@ -95,9 +95,10 @@ bool ParseArgs(const py::args& args, const py::kwargs& kwargs, std::vectorat(i) = std::move(arg); } else { if (raise_exception) { - THROW(TypeError) << function.name << "(): argument '" << param.name << "' must be " - << ValueTypeName(param.type).GetOrThrow() << ", not " - << Py_TYPE(obj.ptr())->tp_name; + THROW(TypeError) + << function.name << "(): argument '" << param.name << "' must be " + << ValueTypeName(param.type).GetOrThrow() << ", not " + << PyStringAsString(PyObject_Str((PyObject*)Py_TYPE(obj.ptr()))).GetOrThrow(); } return false; } diff --git a/python/oneflow/test/exceptions/test_error_msg.py b/python/oneflow/test/exceptions/test_error_msg.py new file mode 100644 index 00000000000..3e9cbff99c4 --- /dev/null +++ b/python/oneflow/test/exceptions/test_error_msg.py @@ -0,0 +1,39 @@ +""" +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" +import unittest +import oneflow as flow +import oneflow.unittest +import oneflow.nn.functional as F +import torch + + +@flow.unittest.skip_unless_1n1d() +class TestErrorMsg(flow.unittest.TestCase): + def test_torch_error_msg(test_case): + with test_case.assertRaises(flow._oneflow_internal.exception.Exception) as exp: + F.pad(torch.randn(2, 2)) + test_case.assertTrue("torch.Tensor" in str(exp.exception)) + + def test_numpy_error_msg(test_case): + import numpy as np + + with test_case.assertRaises(flow._oneflow_internal.exception.Exception) as exp: + F.pad(np.random.randn(2, 2)) + test_case.assertTrue("numpy" in str(exp.exception)) + + +if __name__ == "__main__": + unittest.main()