Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
[DEBUG] enable custom error type (#17128)
Browse files Browse the repository at this point in the history
* enable custom error type

* fix test

* fix lint

* fix error for cython

* handle windows path in error message normalize
  • Loading branch information
szha authored Dec 31, 2019
1 parent e9c482e commit c020f37
Show file tree
Hide file tree
Showing 8 changed files with 437 additions and 56 deletions.
32 changes: 31 additions & 1 deletion include/mxnet/c_api_error.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
* \file c_api_error.h
* \brief Error handling for C API.
*/
#include <string>

#ifndef MXNET_C_API_ERROR_H_
#define MXNET_C_API_ERROR_H_

Expand Down Expand Up @@ -51,6 +53,34 @@
} \
on_exit_api(); \
return 0; // NOLINT(*)

//--------------------------------------------------------
// Error handling mechanism
// -------------------------------------------------------
// Standard error message format, {} means optional
//--------------------------------------------------------
// {error_type:} {message0}
// {message1}
// {message2}
// {Stack trace:} // stack traces follow by this line
// {trace 0} // two spaces in the begining.
// {trace 1}
// {trace 2}
//--------------------------------------------------------
/*!
* \brief Normalize error message
*
* Parse them header generated by by LOG(FATAL) and CHECK
* and reformat the message into the standard format.
*
* This function will also merge all the stack traces into
* one trace and trim them.
*
* \param err_msg The error message.
* \return normalized message.
*/
std::string NormalizeError(std::string err_msg);

/*!
* \brief Set the last error message needed by C API
* \param msg The error message to set.
Expand All @@ -62,7 +92,7 @@ void MXAPISetLastError(const char* msg);
* \return the return value of API after exception is handled
*/
inline int MXAPIHandleException(const std::exception &e) {
MXAPISetLastError(e.what());
MXAPISetLastError(NormalizeError(e.what()).c_str());
return -1;
}

Expand Down
2 changes: 1 addition & 1 deletion python/mxnet/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from __future__ import absolute_import

from .context import Context, current_context, cpu, gpu, cpu_pinned
from . import engine
from . import engine, error
from .base import MXNetError
from .util import is_np_shape, set_np_shape, np_shape, use_np_shape
from .util import is_np_array, np_array, use_np_array, use_np
Expand Down
162 changes: 144 additions & 18 deletions python/mxnet/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
integer_types = (int, long, _np.int32, _np.int64)
numeric_types = (float, int, long, _np.generic)
string_types = basestring,
error_types = {}

if sys.version_info[0] > 2:
# this function is needed for python3
Expand Down Expand Up @@ -109,11 +110,121 @@ def __repr__(self):
_Null = _NullType()


class MXNetError(Exception):
"""Error that will be thrown by all mxnet functions."""
pass
class MXNetError(RuntimeError):
"""Default error thrown by MXNet functions.
MXNetError will be raised if you do not give any error type specification,
"""

def register_error(func_name=None, cls=None):
"""Register an error class so it can be recognized by the ffi error handler.
Parameters
----------
func_name : str or function or class
The name of the error function.
cls : function
The function to create the class
Returns
-------
fregister : function
Register function if f is not specified.
Examples
--------
.. code-block:: python
@mxnet.error.register_error
class MyError(RuntimeError):
pass
err_inst = mxnet.error.create_ffi_error("MyError: xyz")
assert isinstance(err_inst, MyError)
"""
if callable(func_name):
cls = func_name
func_name = cls.__name__

def register(mycls):
"""internal register function"""
err_name = func_name if isinstance(func_name, str) else mycls.__name__
error_types[err_name] = mycls
return mycls
if cls is None:
return register
return register(cls)


def _valid_error_name(name):
"""Check whether name is a valid error name."""
return all(x.isalnum() or x in "_." for x in name)


def _find_error_type(line):
"""Find the error name given the first line of the error message.
Parameters
----------
line : str
The first line of error message.
Returns
-------
name : str The error name
"""
end_pos = line.find(":")
if end_pos == -1:
return None
err_name = line[:end_pos]
if _valid_error_name(err_name):
return err_name
return None


def c2pyerror(err_msg):
"""Translate C API error message to python style.
Parameters
----------
err_msg : str
The error message.
Returns
-------
new_msg : str
Translated message.
err_type : str
Detected error type.
"""
arr = err_msg.split("\n")
if arr[-1] == "":
arr.pop()
err_type = _find_error_type(arr[0])
trace_mode = False
stack_trace = []
message = []
for line in arr:
if trace_mode:
if line.startswith(" "):
stack_trace.append(line)
else:
trace_mode = False
if not trace_mode:
if line.startswith("Stack trace"):
trace_mode = True
else:
message.append(line)
out_msg = ""
if stack_trace:
out_msg += "Traceback (most recent call last):\n"
out_msg += "\n".join(reversed(stack_trace)) + "\n"
out_msg += "\n".join(message)
return out_msg, err_type

@register_error
class NotImplementedForSymbol(MXNetError):
"""Error: Not implemented for symbol"""
def __init__(self, function, alias, *args):
Expand All @@ -132,6 +243,36 @@ def __str__(self):
return msg


def get_last_ffi_error():
"""Create error object given result of MXGetLastError.
Returns
-------
err : object
The error object based on the err_msg
"""
c_err_msg = py_str(_LIB.MXGetLastError())
py_err_msg, err_type = c2pyerror(c_err_msg)
if err_type is not None and err_type.startswith("mxnet.error."):
err_type = err_type[10:]
return error_types.get(err_type, MXNetError)(py_err_msg)


def check_call(ret):
"""Check the return value of C API call.
This function will raise an exception when an error occurs.
Wrap every API call with this function.
Parameters
----------
ret : int
return value from API calls.
"""
if ret != 0:
raise get_last_ffi_error()


class NotSupportedForSparseNDArray(MXNetError):
"""Error: Not supported for SparseNDArray"""
def __init__(self, function, alias, *args):
Expand Down Expand Up @@ -263,21 +404,6 @@ def _load_lib():
#----------------------------
# helper function definition
#----------------------------
def check_call(ret):
"""Check the return value of C API call.
This function will raise an exception when an error occurs.
Wrap every API call with this function.
Parameters
----------
ret : int
return value from API calls.
"""
if ret != 0:
raise MXNetError(py_str(_LIB.MXGetLastError()))


if sys.version_info[0] < 3:
def c_str(string):
"""Create ctypes char * from a Python string.
Expand Down
4 changes: 2 additions & 2 deletions python/mxnet/cython/base.pyi
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from ..base import MXNetError
from ..base import get_last_ffi_error

from libcpp.vector cimport vector
from libcpp.string cimport string
Expand Down Expand Up @@ -38,7 +38,7 @@ cdef c_str(pystr):

cdef CALL(int ret):
if ret != 0:
raise MXNetError(NNGetLastError())
raise get_last_ffi_error()


cdef const char** CBeginPtr(vector[const char*]& vec):
Expand Down
57 changes: 57 additions & 0 deletions python/mxnet/error.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Structured error classes in MXNet.
Each error class takes an error message as its input.
See the example sections for for suggested message conventions.
To make the code more readable, we recommended developers to
copy the examples and raise errors with the same message convention.
"""
from .base import MXNetError, register_error

__all__ = ['MXNetError', 'register']

register = register_error

@register_error
class InternalError(MXNetError):
"""Internal error in the system.
Examples
--------
.. code :: c++
// Example code C++
LOG(FATAL) << "InternalError: internal error detail.";
.. code :: python
# Example code in python
raise InternalError("internal error detail")
"""
def __init__(self, msg):
# Patch up additional hint message.
if "MXNet hint:" not in msg:
msg += ("\nMXNet hint: You hit an internal error. Please open an issue in "
"https://github.com/apache/incubator-mxnet/issues/new/choose"
" to report it.")
super(InternalError, self).__init__(msg)


register_error("ValueError", ValueError)
register_error("TypeError", TypeError)
register_error("AttributeError", AttributeError)
Loading

0 comments on commit c020f37

Please sign in to comment.