This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Add register_op_hook for gluon #15839

Merged: 5 commits, Sep 18, 2019
12 changes: 12 additions & 0 deletions include/mxnet/c_api.h
@@ -113,6 +113,11 @@ typedef void (*EngineFuncParamDeleter)(void*);
typedef void (*ExecutorMonitorCallback)(const char*,
NDArrayHandle,
void*);
/*! \brief Monitor callback called at operator level for cached op */
typedef void (*CachedOpMonitorCallback)(const char*,
const char*,
NDArrayHandle);


struct NativeOpInfo {
void (*forward)(int, float**, int*, unsigned**, int*, void*);
@@ -1286,6 +1291,13 @@ MXNET_DLL int MXInvokeCachedOpEx(CachedOpHandle handle,
NDArrayHandle **outputs,
const int** out_stypes);

/*!
* \brief Set a monitor callback on a cached op
*/
MXNET_DLL int MXCachedOpRegisterOpHook(NDArrayHandle handle,
CachedOpMonitorCallback callback,
bool monitor_all);

//--------------------------------------------
// Part 3: symbolic configuration generation
//--------------------------------------------
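
For orientation, a frontend registers a hook by handing a function matching CachedOpMonitorCallback to the new MXCachedOpRegisterOpHook entry point. Below is a minimal ctypes-level sketch; it simply mirrors the Python wrapper added further down (including its trailing void* parameter), and cached_op is assumed to be an existing CachedOp instance whose .handle refers to the cached op.

import ctypes
from mxnet.base import _LIB, NDArrayHandle, check_call

# Callback prototype used by the Python frontend below:
# (entry name, operator name, array handle, unused void*)
cb_type = ctypes.CFUNCTYPE(None, ctypes.c_char_p, ctypes.c_char_p,
                           NDArrayHandle, ctypes.c_void_p)

def monitor(name, op_name, arr, _):
    # name and op_name arrive as C strings (bytes under Python 3)
    print(name.decode(), op_name.decode())

c_monitor = cb_type(monitor)  # keep a reference so the callback is not garbage-collected
check_call(_LIB.MXCachedOpRegisterOpHook(cached_op.handle, c_monitor,
                                         ctypes.c_int(True)))  # True = monitor_all
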
28 changes: 27 additions & 1 deletion python/mxnet/_ctypes/ndarray.py
@@ -29,6 +29,13 @@
from ..base import check_call


def _monitor_callback_wrapper(callback):
"""A wrapper for the user-defined handle."""
def callback_handle(name, opr_name, array, _):
""" ctypes function """
callback(name, opr_name, array)
return callback_handle

class NDArrayBase(object):
"""Base data structure for ndarray"""
__slots__ = ["handle", "writable"]
@@ -112,10 +119,11 @@ def _imperative_invoke(handle, ndargs, keys, vals, out, is_np_op):

class CachedOp(object):
"""Cached operator handle."""
__slots__ = ["handle", "is_np_sym"]
__slots__ = ["handle", "is_np_sym", "_monitor_callback"]

def __init__(self, sym, flags=()):
self.handle = CachedOpHandle()
self._monitor_callback = None

from ..symbol.numpy._symbol import _Symbol
self.is_np_sym = bool(isinstance(sym, _Symbol))
@@ -170,3 +178,21 @@ def __call__(self, *args, **kwargs):
else:
return [create_ndarray_fn(ctypes.cast(output_vars[i], NDArrayHandle),
stype=out_stypes[i]) for i in range(num_output.value)]

def _register_op_hook(self, callback, monitor_all=False):
"""Install callback for monitor.

Parameters
----------
callback : function
Takes a string for node_name, a string for op_name, and an NDArrayHandle.
monitor_all : bool, default False
If true, monitor both input and output, otherwise monitor output only.
"""
cb_type = ctypes.CFUNCTYPE(None, ctypes.c_char_p, ctypes.c_char_p, NDArrayHandle, ctypes.c_void_p)
if callback:
self._monitor_callback = cb_type(_monitor_callback_wrapper(callback))
check_call(_LIB.MXCachedOpRegisterOpHook(
self.handle,
self._monitor_callback,
ctypes.c_int(monitor_all)))
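
Design note: the constructed CFUNCTYPE object is kept on self._monitor_callback rather than in a local because ctypes callbacks are only safe to invoke while a Python reference keeps them alive; storing it on the CachedOp prevents the hook from being garbage-collected while the backend may still call it. The Cython path below keeps the same reference in self.mhandle for the same reason.
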
8 changes: 8 additions & 0 deletions python/mxnet/cython/base.pyi
@@ -2,13 +2,18 @@ from ..base import MXNetError

from libcpp.vector cimport vector
from libcpp.string cimport string
from libcpp cimport bool as _bool
from cpython.version cimport PY_MAJOR_VERSION

ctypedef void* SymbolHandle
ctypedef void* NDArrayHandle
ctypedef void* OpHandle
ctypedef void* CachedOpHandle
ctypedef void* MonitorCallbackHandle
ctypedef unsigned nn_uint
ctypedef void (*CachedOpMonitorCallback)(const char*,
const char*,
NDArrayHandle)

cdef py_str(const char* x):
if PY_MAJOR_VERSION < 3:
@@ -112,3 +117,6 @@ cdef extern from "mxnet/c_api.h":
int *num_outputs,
NDArrayHandle **outputs,
const int **out_stypes);
int MXCachedOpRegisterOpHook(NDArrayHandle handle,
CachedOpMonitorCallback callback,
_bool monitor_all);
16 changes: 15 additions & 1 deletion python/mxnet/cython/ndarray.pyx
@@ -22,6 +22,7 @@ import ctypes as _ctypes
import numpy as np
from ..ndarray_doc import _build_doc
from libc.stdint cimport uint32_t, int64_t
from ..base import _LIB

include "./base.pyi"

@@ -47,7 +48,6 @@ cdef class NDArrayBase:
return _ctypes.cast(<unsigned long long>self.chandle, _ctypes.c_void_p)
def __set__(self, value):
self._set_handle(value)

property writable:
def __get__(self):
return bool(self.cwritable)
@@ -75,6 +75,10 @@ def _set_np_ndarray_class(cls):
global _np_ndarray_cls
_np_ndarray_cls = cls

def _monitor_callback_wrapper(callback):
def callback_handle(name, opr_name, arr, _):
callback(name, opr_name, arr)
return callback_handle

cdef NewArray(NDArrayHandle handle, int stype=-1, int is_np_array=0):
"""Create a new array given handle"""
@@ -103,6 +107,7 @@ cdef class CachedOp:
self._set_handle(value)

cdef int is_np_sym
cdef readonly object mhandle

def __init__(self, sym, flags=()):
cdef vector[string] s_flag_keys
@@ -169,6 +174,15 @@
else:
return [NewArray(p_output_vars[i], p_output_stypes[i], self.is_np_sym) for i in range(num_output)]

def _register_op_hook(self, callback, monitor_all=False):
cb_type = _ctypes.CFUNCTYPE(None, _ctypes.c_char_p, _ctypes.c_char_p, _ctypes.c_void_p, _ctypes.c_void_p)
if callback:
self.mhandle = cb_type(_monitor_callback_wrapper(callback))
chandle = _ctypes.cast(<unsigned long long>self.chandle, _ctypes.c_void_p)
CALL(_LIB.MXCachedOpRegisterOpHook(chandle,
self.mhandle,
_ctypes.c_int(monitor_all)))


def _imperative_invoke(handle, ndargs, keys, vals, out, is_np_op=0):
"""cython implementation of imperative invoke wrapper"""
37 changes: 37 additions & 0 deletions python/mxnet/gluon/block.py
@@ -590,6 +590,19 @@ def forward(self, *args):
# pylint: disable= invalid-name
raise NotImplementedError

def register_op_hook(self, callback, monitor_all=False):
"""Install callback monitor.

Parameters
----------
callback : function
Takes a string for the node name, a string for the operator name, and an NDArrayHandle.
monitor_all : bool, default False
If true, monitor both input and output, otherwise monitor output only.
"""
for cld in self._children.values():
cld.register_op_hook(callback, monitor_all)

def summary(self, *inputs):
"""Print the summary of the model's output and parameters.

@@ -754,6 +767,8 @@ def __init__(self, prefix=None, params=None):
self._in_format = None
self._active = False
self._flags = []
self._callback = None
self._monitor_all = False

def __setattr__(self, name, value):
"""Registers parameters."""
@@ -833,6 +848,12 @@ def _deferred_infer_shape(self, *args):
def _call_cached_op(self, *args):
if self._cached_op is None:
self._build_cache(*args)
assert self._cached_op, "cached op should not be None"
if self._callback:
self._cached_op._register_op_hook(self._callback, self._monitor_all)
if len(self._flags) >= 2 and (self._flags[1] or self._flags[0]):
Review comment:
    self._flags = list(kwargs.items())
The condition is true even when it is not supposed to be. For example, if
self._flags = [('static_alloc', False), ('static_shape', False), ('inline_limit', 2)]
the condition still evaluates to True (a short illustration follows this hunk).

warnings.warn("register_op_hook is experimental when static_alloc=True / static_shape=True "
" and may not work correctly")

args, fmt = _flatten(args, "input")
assert fmt == self._in_format, "Invalid input format"
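
To make the review comment above concrete, a quick check (flag values are hypothetical) shows why the guard misfires: the flags are stored as (name, value) tuples, and any non-empty tuple is truthy.

flags = [('static_alloc', False), ('static_shape', False), ('inline_limit', 2)]
print(len(flags) >= 2 and (flags[1] or flags[0]))
# prints ('static_shape', False), which is truthy, so the warning fires
# even though both static_alloc and static_shape are False
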
@@ -938,6 +959,22 @@ def export(self, path, epoch=0, remove_amp_cast=True):
save_fn = _mx_npx.save if is_np_array() else ndarray.save
save_fn('%s-%04d.params'%(path, epoch), arg_dict)

def register_op_hook(self, callback, monitor_all=False):
"""Install op hook for block recursively.

Parameters
----------
callback : function
Takes a string for the node name, a string for the operator name, and an NDArrayHandle.
monitor_all : bool, default False
If true, monitor both input and output, otherwise monitor output only.
"""
self._callback = callback
self._monitor_all = monitor_all
for cld in self._children.values():
cld._callback = callback
cld._monitor_all = monitor_all
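
A minimal end-to-end sketch of the new Gluon API (illustrative only: the network, input shape, and print statement are arbitrary; since the hook is applied through the cached op, it takes effect once the block is hybridized):

import mxnet as mx
from mxnet.gluon import nn

def monitor(node_name, op_name, arr):
    # node_name / op_name arrive as C strings (bytes); arr is an NDArrayHandle
    print(node_name.decode(), op_name.decode())

net = nn.HybridSequential()
net.add(nn.Dense(4, activation='relu'), nn.Dense(2))
net.initialize()
net.hybridize()
net.register_op_hook(monitor, monitor_all=True)  # monitor inputs as well as outputs
out = net(mx.nd.ones((1, 8)))                    # callbacks fire during this call
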

def forward(self, x, *args):
"""Defines the forward computation. Arguments can be either
:py:class:`NDArray` or :py:class:`Symbol`."""
20 changes: 20 additions & 0 deletions src/c_api/c_api_ndarray.cc
@@ -378,3 +378,23 @@ int MXAutogradGetSymbol(NDArrayHandle handle, SymbolHandle *out) {
*out = reinterpret_cast<SymbolHandle>(sym);
API_END();
}

int MXCachedOpRegisterOpHook(NDArrayHandle handle,
CachedOpMonitorCallback callback,
bool monitor_all) {
API_BEGIN();
CachedOpMonitorCallback callback_temp = nullptr;
std::function<void(const char *, const char *, void*)> clbk;
if (callback) {
callback_temp = callback;
clbk = [callback_temp](const char *name, const char *opr_name,
void *handle) {
callback_temp(name, opr_name, handle);
};
} else {
clbk = nullptr;
}
CachedOpPtr op = *static_cast<CachedOpPtr *>(handle);
op->RegisterOpHook(clbk, monitor_all);
API_END();
}
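
Note that the first parameter is typed NDArrayHandle in the C API declaration even though it carries the cached op's handle; the implementation simply casts the opaque pointer to CachedOpPtr, matching how both frontends pass CachedOp.handle through.
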
57 changes: 57 additions & 0 deletions src/common/utils.cc
@@ -51,5 +51,62 @@ void CastStorageDispatch<cpu>(const OpContext& ctx,
mxnet::op::CastStorageComputeImpl<cpu>(ctx, input, output);
}

void ExecuteMonInputCallback(
const nnvm::IndexedGraph &idx, const std::vector<NDArray *> &state_arrays,
size_t nid, const std::function<void(const char *, const char *, void *)>
&monitor_callback) {
static const auto &flist_inputs =
nnvm::Op::GetAttr<nnvm::FListInputNames>("FListInputNames");
std::vector<std::string> input_names;
const nnvm::IndexedGraph::Node &inode = idx[nid];
const nnvm::Node *node = inode.source;
if (flist_inputs.count(node->op())) {
input_names = flist_inputs[node->op()](node->attrs);
} else {
for (size_t i = 0; i < node->num_inputs(); ++i) {
input_names.emplace_back("input" + std::to_string(i));
}
}

for (size_t i = 0; i < node->num_inputs(); ++i) {
const nnvm::NodeEntry &input = node->inputs[i];
if (state_arrays[idx.entry_id(input)]->is_none()) {
continue;
}
NDArray *cpy = new NDArray(*state_arrays[idx.entry_id(input)]);
std::string name = inode.source->attrs.name + "_" + input_names[i];
monitor_callback(name.c_str(), inode.source->op()->name.c_str(),
reinterpret_cast<void *>(cpy));
}
}

void ExecuteMonOutputCallback(
const nnvm::IndexedGraph &idx, const std::vector<NDArray *> &state_arrays,
size_t nid, const std::function<void(const char *, const char *, void *)>
&monitor_callback) {
static const auto &flist_outputs =
nnvm::Op::GetAttr<nnvm::FListOutputNames>("FListOutputNames");
std::vector<std::string> output_names;
const nnvm::IndexedGraph::Node &inode = idx[nid];
const nnvm::Node *node = inode.source;
if (flist_outputs.count(node->op())) {
output_names = flist_outputs[node->op()](node->attrs);
} else {
for (size_t i = 0; i < node->num_outputs(); ++i) {
output_names.emplace_back(std::to_string(i));
}
}

for (size_t i = 0; i < node->num_outputs(); ++i) {
if (state_arrays[idx.entry_id(nid, i)]->is_none()) {
continue;
}
NDArray *cpy = new NDArray(*state_arrays[idx.entry_id(nid, i)]);
std::string name = inode.source->attrs.name + "_" + output_names[i];
monitor_callback(name.c_str(), inode.source->op()->name.c_str(),
reinterpret_cast<void *>(cpy));
}
}

} // namespace common
} // namespace mxnet
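
Given the naming logic above, the string handed to the callback is the node name joined with the per-array name, where inputs fall back to "inputN" and outputs to the bare index when the operator does not register name lists. A small sketch of the resulting format (node and operator names here are hypothetical):

# Sketch of the names the monitor callback receives, following the logic above.
node_name, op_name = "dense0_fwd", "FullyConnected"
input_names = ["data", "weight", "bias"]  # from FListInputNames, else ["input0", "input1", ...]
output_names = ["output"]                 # from FListOutputNames, else ["0", "1", ...]
monitored_inputs = ["%s_%s" % (node_name, n) for n in input_names]   # reported only when monitor_all=True
monitored_outputs = ["%s_%s" % (node_name, n) for n in output_names]
print(monitored_inputs)   # ['dense0_fwd_data', 'dense0_fwd_weight', 'dense0_fwd_bias']
print(monitored_outputs)  # ['dense0_fwd_output']
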
9 changes: 9 additions & 0 deletions src/common/utils.h
@@ -803,6 +803,15 @@ inline void ConvertToLegacyShape(mxnet::ShapeVector* shapes) {
ConvertToLegacyShape(&(shapes->at(i)));
}
}
void ExecuteMonInputCallback(
const nnvm::IndexedGraph &idx, const std::vector<NDArray *> &state_arrays,
size_t nid, const std::function<void(const char *, const char *, void *)>
&monitor_callback);

void ExecuteMonOutputCallback(
const nnvm::IndexedGraph &idx, const std::vector<NDArray *> &state_arrays,
size_t nid, const std::function<void(const char *, const char *, void *)>
&monitor_callback);

} // namespace common
} // namespace mxnet
23 changes: 20 additions & 3 deletions src/imperative/cached_op.cc
@@ -697,6 +697,9 @@ void CachedOp::StaticRunOps(
ndinputs.emplace_back(state_arrays[idx.entry_id(j)]);
CHECK(!ndinputs.back()->is_none());
}
if (monitor_callback_ && monitor_all_) {
mxnet::common::ExecuteMonInputCallback(idx, state_arrays, i, monitor_callback_);
}
ndoutputs.clear();
ndoutputs.reserve(num_outputs);
req.clear();
@@ -708,6 +711,7 @@
CHECK(req.back() == kNullOp || !ndoutputs.back()->is_none());
}
const DispatchMode dispatch_mode = dispatch_modes[i];

if (createop.count(node.source->op())) {
arg_shapes.clear();
arg_dtypes.clear();
@@ -735,6 +739,9 @@
default_ctx, node.source->attrs, ndinputs, ndoutputs, req,
dispatch_mode);
}
if (monitor_callback_) {
mxnet::common::ExecuteMonOutputCallback(idx, state_arrays, i, monitor_callback_);
}
}
}
}
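
In this static execution path the input-side callback fires only when monitor_all_ is set, while the output-side callback fires whenever any callback is registered, which matches the monitor_all semantics documented in the Python docstrings. Note that ExecuteMonInputCallback and ExecuteMonOutputCallback above hand the callback a newly allocated NDArray object for each monitored entry.
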
@@ -883,12 +890,12 @@ OpStatePtr CachedOp::DynamicForward(
// So if it's not the inline mode, we disable recording.
RunGraph(false, idx, arrays, 0, idx.num_nodes(), std::move(array_reqs),
std::move(ref_count), &states, dispatch_modes,
recording && inlining_);
recording && inlining_, nullptr, monitor_callback_, monitor_all_);
} else {
mxnet::ShapeVector shapes = g.GetAttr<mxnet::ShapeVector>("shape");
NaiveRunGraph(false, default_ctx, idx, arrays, 0, idx.num_nodes(),
std::move(array_reqs), std::move(ref_count), &states,
dispatch_modes, recording && inlining_, &shapes);
dispatch_modes, recording && inlining_, &shapes, monitor_callback_, monitor_all_);
{
auto state_ptr = GetCachedOpState(default_ctx);
auto& state = state_ptr.get_state<CachedOpState>();
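
In the dynamic paths the callback and the monitor_all flag are simply threaded through RunGraph and NaiveRunGraph, whose signatures the rest of this PR evidently extends (those files fall outside this excerpt); the per-node monitoring itself presumably reuses the same helpers from common/utils.cc.
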
@@ -1028,7 +1035,7 @@ void CachedOp::DynamicBackward(

RunGraph(retain_graph, idx, arrays, num_forward_nodes, idx.num_nodes(),
std::move(array_reqs), std::move(ref_count), &states, dispatch_modes,
Imperative::Get()->is_recording());
Imperative::Get()->is_recording(), nullptr, monitor_callback_);

if (retain_graph) {
buff.resize(num_forward_entries);
@@ -1295,6 +1302,16 @@ void CachedOpBackward(const OpStatePtr& state_ptr,
CopyFromTo(out_bufs[i], outputs[i]);
}

/*
* Register the callback to be called when the operator is executed
*/
void CachedOp::RegisterOpHook(const CachedOp::CachedOpMonCallback& callback,
bool monitor_all) {
CHECK(callback) << "invalid callback";
monitor_callback_ = callback;
monitor_all_ = monitor_all;
}
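
RegisterOpHook requires a non-null callback (the CHECK above), so the nullptr branch in MXCachedOpRegisterOpHook cannot currently be used to clear a hook; registering again simply replaces the previously installed callback.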

OpStatePtr CreateCachedOpState(const NodeAttrs& attrs,
Context ctx,
const mxnet::ShapeVector& in_shapes,