
Add register_op_hook for gluon (#15839)
* Set monitor callback basic support

* Trigger CI

* Add base.pyi and ndarray.pyx

* Change "not supported" to "experimental" and check for both static_shape and static_alloc
anirudh2290 committed Sep 18, 2019
1 parent 3dacabe commit b777a69
Showing 13 changed files with 307 additions and 9 deletions.
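For orientation, a minimal usage sketch of the Gluon-level API this commit adds (not part of the diff; layer sizes are illustrative). Per the docstrings below, the callback receives the node name and operator name as C strings plus an NDArrayHandle, and in this change it takes effect through the cached op, i.e. on hybridized blocks:

    import mxnet as mx
    from mxnet.gluon import nn

    def mon_callback(node_name, op_name, arr):
        # node_name / op_name arrive as bytes; arr is a raw NDArrayHandle
        print(node_name, op_name)

    net = nn.Dense(2)
    net.initialize()
    net.hybridize()                      # the hook is installed on the underlying CachedOp
    net.register_op_hook(mon_callback)   # monitor_all defaults to False: outputs only
    net(mx.nd.ones((1, 3)))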
12 changes: 12 additions & 0 deletions include/mxnet/c_api.h
@@ -111,6 +111,11 @@ typedef void (*EngineFuncParamDeleter)(void*);
typedef void (*ExecutorMonitorCallback)(const char*,
NDArrayHandle,
void*);
/*! \brief Monitor callback called at operator level for cached op */
typedef void (*CachedOpMonitorCallback)(const char*,
const char*,
NDArrayHandle);


struct NativeOpInfo {
void (*forward)(int, float**, int*, unsigned**, int*, void*);
@@ -1284,6 +1289,13 @@ MXNET_DLL int MXInvokeCachedOpEx(CachedOpHandle handle,
NDArrayHandle **outputs,
const int** out_stypes);

/*!
* \brief Set a monitor callback on a cached op, called at operator level
*/
MXNET_DLL int MXCachedOpRegisterOpHook(NDArrayHandle handle,
CachedOpMonitorCallback callback,
bool monitor_all);

//--------------------------------------------
// Part 3: symbolic configuration generation
//--------------------------------------------
28 changes: 27 additions & 1 deletion python/mxnet/_ctypes/ndarray.py
@@ -29,6 +29,13 @@
from ..base import check_call


def _monitor_callback_wrapper(callback):
"""A wrapper for the user-defined handle."""
def callback_handle(name, opr_name, array, _):
""" ctypes function """
callback(name, opr_name, array)
return callback_handle

class NDArrayBase(object):
"""Base data structure for ndarray"""
__slots__ = ["handle", "writable"]
@@ -112,10 +119,11 @@ def _imperative_invoke(handle, ndargs, keys, vals, out, is_np_op):

class CachedOp(object):
"""Cached operator handle."""
__slots__ = ["handle", "is_np_sym"]
__slots__ = ["handle", "is_np_sym", "_monitor_callback"]

def __init__(self, sym, flags=()):
self.handle = CachedOpHandle()
self._monitor_callback = None

from ..symbol.numpy._symbol import _Symbol
self.is_np_sym = bool(isinstance(sym, _Symbol))
@@ -170,3 +178,21 @@ def __call__(self, *args, **kwargs):
else:
return [create_ndarray_fn(ctypes.cast(output_vars[i], NDArrayHandle),
stype=out_stypes[i]) for i in range(num_output.value)]

def _register_op_hook(self, callback, monitor_all=False):
"""Install callback for monitor.
Parameters
----------
callback : function
Takes a string for node_name, string for op_name and a NDArrayHandle.
monitor_all : bool, default False
If true, monitor both input _imperative_invoked output, otherwise monitor output only.
"""
cb_type = ctypes.CFUNCTYPE(None, ctypes.c_char_p, ctypes.c_char_p, NDArrayHandle, ctypes.c_void_p)
if callback:
self._monitor_callback = cb_type(_monitor_callback_wrapper(callback))
check_call(_LIB.MXCachedOpRegisterOpHook(
self.handle,
self._monitor_callback,
ctypes.c_int(monitor_all)))
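For illustration only, a sketch of driving the private hook above directly, assuming CachedOp is reachable as mx.ndarray.CachedOp (as Gluon's cached-op path uses it); most users should go through gluon's register_op_hook instead:

    import mxnet as mx

    def cb(name, op_name, handle):
        print(name, op_name)        # names arrive as bytes; handle is an NDArrayHandle

    sym = mx.sym.sqrt(mx.sym.Variable('data'))
    op = mx.ndarray.CachedOp(sym)
    op._register_op_hook(cb)        # outputs only; pass monitor_all=True to also see inputs
    op(mx.nd.array([4.0]))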
8 changes: 8 additions & 0 deletions python/mxnet/cython/base.pyi
@@ -2,13 +2,18 @@ from ..base import MXNetError

from libcpp.vector cimport vector
from libcpp.string cimport string
from libcpp cimport bool as _bool
from cpython.version cimport PY_MAJOR_VERSION

ctypedef void* SymbolHandle
ctypedef void* NDArrayHandle
ctypedef void* OpHandle
ctypedef void* CachedOpHandle
ctypedef void* MonitorCallbackHandle
ctypedef unsigned nn_uint
ctypedef void (*CachedOpMonitorCallback)(const char*,
const char*,
NDArrayHandle)

cdef py_str(const char* x):
if PY_MAJOR_VERSION < 3:
@@ -112,3 +117,6 @@ cdef extern from "mxnet/c_api.h":
int *num_outputs,
NDArrayHandle **outputs,
const int **out_stypes);
int MXCachedOpRegisterOpHook(NDArrayHandle handle,
CachedOpMonitorCallback callback,
_bool monitor_all);
16 changes: 15 additions & 1 deletion python/mxnet/cython/ndarray.pyx
@@ -22,6 +22,7 @@ import ctypes as _ctypes
import numpy as np
from ..ndarray_doc import _build_doc
from libc.stdint cimport uint32_t, int64_t
from ..base import _LIB

include "./base.pyi"

@@ -47,7 +48,6 @@ cdef class NDArrayBase:
return _ctypes.cast(<unsigned long long>self.chandle, _ctypes.c_void_p)
def __set__(self, value):
self._set_handle(value)

property writable:
def __get__(self):
return bool(self.cwritable)
@@ -75,6 +75,10 @@ def _set_np_ndarray_class(cls):
global _np_ndarray_cls
_np_ndarray_cls = cls

def _monitor_callback_wrapper(callback):
def callback_handle(name, opr_name, arr, _):
callback(name, opr_name, arr)
return callback_handle

cdef NewArray(NDArrayHandle handle, int stype=-1, int is_np_array=0):
"""Create a new array given handle"""
@@ -103,6 +107,7 @@ cdef class CachedOp:
self._set_handle(value)

cdef int is_np_sym
cdef readonly object mhandle

def __init__(self, sym, flags=()):
cdef vector[string] s_flag_keys
@@ -169,6 +174,15 @@
else:
return [NewArray(p_output_vars[i], p_output_stypes[i], self.is_np_sym) for i in range(num_output)]

def _register_op_hook(self, callback, monitor_all=False):
cb_type = _ctypes.CFUNCTYPE(None, _ctypes.c_char_p, _ctypes.c_char_p, _ctypes.c_void_p, _ctypes.c_void_p)
if callback:
self.mhandle = cb_type(_monitor_callback_wrapper(callback))
chandle = _ctypes.cast(<unsigned long long>self.chandle, _ctypes.c_void_p)
CALL(_LIB.MXCachedOpRegisterOpHook(chandle,
self.mhandle,
_ctypes.c_int(monitor_all)))


def _imperative_invoke(handle, ndargs, keys, vals, out, is_np_op=0):
"""cython implementation of imperative invoke wrapper"""
37 changes: 37 additions & 0 deletions python/mxnet/gluon/block.py
@@ -590,6 +590,19 @@ def forward(self, *args):
# pylint: disable= invalid-name
raise NotImplementedError

def register_op_hook(self, callback, monitor_all=False):
"""Install callback monitor.
Parameters
----------
callback : function
Takes a string and a NDArrayHandle.
monitor_all : bool, default False
If true, monitor both input and output, otherwise monitor output only.
"""
for cld in self._children.values():
cld.register_op_hook(callback, monitor_all)

def summary(self, *inputs):
"""Print the summary of the model's output and parameters.
@@ -754,6 +767,8 @@ def __init__(self, prefix=None, params=None):
self._in_format = None
self._active = False
self._flags = []
self._callback = None
self._monitor_all = False

def __setattr__(self, name, value):
"""Registers parameters."""
@@ -833,6 +848,12 @@ def _deferred_infer_shape(self, *args):
def _call_cached_op(self, *args):
if self._cached_op is None:
self._build_cache(*args)
assert self._cached_op, "cached op creation failed"
if self._callback:
self._cached_op._register_op_hook(self._callback, self._monitor_all)
if len(self._flags) >= 2 and (self._flags[1] or self._flags[0]):
warnings.warn("register_op_hook is experimental when static_alloc=True / static_shape=True "
"and may not work correctly")

args, fmt = _flatten(args, "input")
assert fmt == self._in_format, "Invalid input format"
@@ -938,6 +959,22 @@ def export(self, path, epoch=0, remove_amp_cast=True):
save_fn = _mx_npx.save if is_np_array() else ndarray.save
save_fn('%s-%04d.params'%(path, epoch), arg_dict)

def register_op_hook(self, callback, monitor_all=False):
"""Install op hook for block recursively.
Parameters
----------
callback : function
Takes a string and a NDArrayHandle.
monitor_all : bool, default False
If true, monitor both input and output, otherwise monitor output only.
"""
self._callback = callback
self._monitor_all = monitor_all
for cld in self._children.values():
cld._callback = callback
cld._monitor_all = monitor_all

def forward(self, x, *args):
"""Defines the forward computation. Arguments can be either
:py:class:`NDArray` or :py:class:`Symbol`."""
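A hedged sketch of the block-level behavior (layer sizes illustrative): the callback set on a parent is propagated to its children, monitor_all=True additionally reports each operator's inputs (named "<node_name>_<input_name>" by the helpers added in src/common/utils.cc below), and combining the hook with static_alloc/static_shape triggers the experimental warning added to _call_cached_op above:

    import mxnet as mx
    from mxnet.gluon import nn

    def mon_callback(node_name, op_name, arr):
        print(node_name, op_name)

    net = nn.HybridSequential()
    net.add(nn.Dense(4), nn.Dense(2))
    net.initialize()
    net.register_op_hook(mon_callback, monitor_all=True)  # propagated to both Dense children
    net.hybridize()                                        # static_alloc/static_shape would warn
    net(mx.nd.ones((1, 8)))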
20 changes: 20 additions & 0 deletions src/c_api/c_api_ndarray.cc
@@ -378,3 +378,23 @@ int MXAutogradGetSymbol(NDArrayHandle handle, SymbolHandle *out) {
*out = reinterpret_cast<SymbolHandle>(sym);
API_END();
}

int MXCachedOpRegisterOpHook(NDArrayHandle handle,
CachedOpMonitorCallback callback,
bool monitor_all) {
API_BEGIN();
CachedOpMonitorCallback callback_temp = nullptr;
std::function<void(const char *, const char *, void*)> clbk;
if (callback) {
callback_temp = callback;
clbk = [callback_temp](const char *name, const char *opr_name,
void *handle) {
callback_temp(name, opr_name, handle);
};
} else {
clbk = nullptr;
}
CachedOpPtr op = *static_cast<CachedOpPtr *>(handle);
op->RegisterOpHook(clbk, monitor_all);
API_END();
}
57 changes: 57 additions & 0 deletions src/common/utils.cc
@@ -51,5 +51,62 @@ void CastStorageDispatch<cpu>(const OpContext& ctx,
mxnet::op::CastStorageComputeImpl<cpu>(ctx, input, output);
}

void ExecuteMonInputCallback(
const nnvm::IndexedGraph &idx, const std::vector<NDArray *> &state_arrays,
size_t nid, const std::function<void(const char *, const char *, void *)>
&monitor_callback) {
static const auto &flist_inputs =
nnvm::Op::GetAttr<nnvm::FListInputNames>("FListInputNames");
std::vector<std::string> input_names;
const nnvm::IndexedGraph::Node &inode = idx[nid];
const nnvm::Node *node = inode.source;
if (flist_inputs.count(node->op())) {
input_names = flist_inputs[node->op()](node->attrs);
} else {
for (size_t i = 0; i < node->num_inputs(); ++i) {
input_names.emplace_back("input" + std::to_string(i));
}
}

for (size_t i = 0; i < node->num_inputs(); ++i) {
const nnvm::NodeEntry &input = node->inputs[i];
if (state_arrays[idx.entry_id(input)]->is_none()) {
continue;
}
NDArray *cpy = new NDArray(*state_arrays[idx.entry_id(input)]);
std::string name = inode.source->attrs.name + "_" + input_names[i];
monitor_callback(name.c_str(), inode.source->op()->name.c_str(),
reinterpret_cast<void *>(cpy));
}
}

void ExecuteMonOutputCallback(
const nnvm::IndexedGraph &idx, const std::vector<NDArray *> &state_arrays,
size_t nid, const std::function<void(const char *, const char *, void *)>
&monitor_callback) {
static const auto &flist_outputs =
nnvm::Op::GetAttr<nnvm::FListOutputNames>("FListOutputNames");
std::vector<std::string> output_names;
const nnvm::IndexedGraph::Node &inode = idx[nid];
const nnvm::Node *node = inode.source;
if (flist_outputs.count(node->op())) {
output_names = flist_outputs[node->op()](node->attrs);
} else {
for (size_t i = 0; i < node->num_outputs(); ++i) {
output_names.emplace_back(std::to_string(i));
}
}

for (size_t i = 0; i < node->num_outputs(); ++i) {
if (state_arrays[idx.entry_id(nid, i)]->is_none()) {
continue;
}
NDArray *cpy = new NDArray(*state_arrays[idx.entry_id(nid, i)]);
std::string name = inode.source->attrs.name + "_" + output_names[i];
monitor_callback(name.c_str(), inode.source->op()->name.c_str(),
reinterpret_cast<void *>(cpy));
}
}

} // namespace common
} // namespace mxnet
9 changes: 9 additions & 0 deletions src/common/utils.h
@@ -804,6 +804,15 @@ inline void ConvertToLegacyShape(mxnet::ShapeVector* shapes) {
ConvertToLegacyShape(&(shapes->at(i)));
}
}
void ExecuteMonInputCallback(
const nnvm::IndexedGraph &idx, const std::vector<NDArray *> &state_arrays,
size_t nid, const std::function<void(const char *, const char *, void *)>
&monitor_callback);

void ExecuteMonOutputCallback(
const nnvm::IndexedGraph &idx, const std::vector<NDArray *> &state_arrays,
size_t nid, const std::function<void(const char *, const char *, void *)>
&monitor_callback);

/*!
* \brief This is function can return the output names of a NodeEntry.
23 changes: 20 additions & 3 deletions src/imperative/cached_op.cc
@@ -697,6 +697,9 @@ void CachedOp::StaticRunOps(
ndinputs.emplace_back(state_arrays[idx.entry_id(j)]);
CHECK(!ndinputs.back()->is_none());
}
if (monitor_callback_ && monitor_all_) {
mxnet::common::ExecuteMonInputCallback(idx, state_arrays, i, monitor_callback_);
}
ndoutputs.clear();
ndoutputs.reserve(num_outputs);
req.clear();
@@ -708,6 +711,7 @@
CHECK(req.back() == kNullOp || !ndoutputs.back()->is_none());
}
const DispatchMode dispatch_mode = dispatch_modes[i];

if (createop.count(node.source->op())) {
arg_shapes.clear();
arg_dtypes.clear();
@@ -735,6 +739,9 @@
default_ctx, node.source->attrs, ndinputs, ndoutputs, req,
dispatch_mode);
}
if (monitor_callback_) {
mxnet::common::ExecuteMonOutputCallback(idx, state_arrays, i, monitor_callback_);
}
}
}
}
@@ -883,12 +890,12 @@ OpStatePtr CachedOp::DynamicForward(
// So if it's not the inline mode, we disable recording.
RunGraph(false, idx, arrays, 0, idx.num_nodes(), std::move(array_reqs),
std::move(ref_count), &states, dispatch_modes,
recording && inlining_);
recording && inlining_, nullptr, monitor_callback_, monitor_all_);
} else {
mxnet::ShapeVector shapes = g.GetAttr<mxnet::ShapeVector>("shape");
NaiveRunGraph(false, default_ctx, idx, arrays, 0, idx.num_nodes(),
std::move(array_reqs), std::move(ref_count), &states,
dispatch_modes, recording && inlining_, &shapes);
dispatch_modes, recording && inlining_, &shapes, monitor_callback_, monitor_all_);
{
auto state_ptr = GetCachedOpState(default_ctx);
auto& state = state_ptr.get_state<CachedOpState>();
@@ -1028,7 +1035,7 @@ void CachedOp::DynamicBackward(

RunGraph(retain_graph, idx, arrays, num_forward_nodes, idx.num_nodes(),
std::move(array_reqs), std::move(ref_count), &states, dispatch_modes,
Imperative::Get()->is_recording());
Imperative::Get()->is_recording(), nullptr, monitor_callback_);

if (retain_graph) {
buff.resize(num_forward_entries);
@@ -1295,6 +1302,16 @@ void CachedOpBackward(const OpStatePtr& state_ptr,
CopyFromTo(out_bufs[i], outputs[i]);
}

/*
* Register the callback to be called when the operator is executed
*/
void CachedOp::RegisterOpHook(const CachedOp::CachedOpMonCallback& callback,
bool monitor_all) {
CHECK(callback) << "invalid callback";
monitor_callback_ = callback;
monitor_all_ = monitor_all;
}

OpStatePtr CreateCachedOpState(const NodeAttrs& attrs,
Context ctx,
const mxnet::ShapeVector& in_shapes,
(Diffs for the remaining changed files were not loaded in this view.)