diff --git a/include/mxnet/c_api.h b/include/mxnet/c_api.h
index 20b2aa2d5c9b..27a134cb5050 100644
--- a/include/mxnet/c_api.h
+++ b/include/mxnet/c_api.h
@@ -111,6 +111,11 @@ typedef void (*EngineFuncParamDeleter)(void*);
 typedef void (*ExecutorMonitorCallback)(const char*,
                                         NDArrayHandle,
                                         void*);
+/*! \brief Monitor callback called at operator level for cached op */
+typedef void (*CachedOpMonitorCallback)(const char*,
+                                        const char*,
+                                        NDArrayHandle);
+
 struct NativeOpInfo {
   void (*forward)(int, float**, int*, unsigned**, int*, void*);
@@ -1222,6 +1227,13 @@ MXNET_DLL int MXInvokeCachedOpEx(CachedOpHandle handle,
                                  NDArrayHandle **outputs,
                                  const int** out_stypes);
 
+/*!
+ * \brief Set a monitor callback on a cached op.
+ */
+MXNET_DLL int MXCachedOpRegisterOpHook(CachedOpHandle handle,
+                                       CachedOpMonitorCallback callback,
+                                       bool monitor_all);
+
 //--------------------------------------------
 // Part 3: symbolic configuration generation
 //--------------------------------------------
diff --git a/python/mxnet/_ctypes/ndarray.py b/python/mxnet/_ctypes/ndarray.py
index b1a38c1d2621..0d5dade2f163 100644
--- a/python/mxnet/_ctypes/ndarray.py
+++ b/python/mxnet/_ctypes/ndarray.py
@@ -29,6 +29,13 @@
 from ..base import check_call
 
 
+def _monitor_callback_wrapper(callback):
+    """A wrapper for the user-defined monitor callback."""
+    def callback_handle(name, opr_name, array, _):
+        """ ctypes callback function """
+        callback(name, opr_name, array)
+    return callback_handle
+
+
 class NDArrayBase(object):
     """Base data structure for ndarray"""
     __slots__ = ["handle", "writable"]
@@ -112,10 +119,11 @@ def _imperative_invoke(handle, ndargs, keys, vals, out, is_np_op):
 
 class CachedOp(object):
     """Cached operator handle."""
-    __slots__ = ["handle", "is_np_sym"]
+    __slots__ = ["handle", "is_np_sym", "_monitor_callback"]
 
     def __init__(self, sym, flags=()):
         self.handle = CachedOpHandle()
+        self._monitor_callback = None
 
         from ..symbol.numpy._symbol import _Symbol
         self.is_np_sym = bool(isinstance(sym, _Symbol))
@@ -170,3 +178,21 @@ def __call__(self, *args, **kwargs):
         else:
             return [create_ndarray_fn(ctypes.cast(output_vars[i], NDArrayHandle),
                                       stype=out_stypes[i]) for i in range(num_output.value)]
+
+    def _register_op_hook(self, callback, monitor_all=False):
+        """Install a callback for monitoring.
+
+        Parameters
+        ----------
+        callback : function
+            Takes a string for node_name, a string for op_name and an NDArrayHandle.
+        monitor_all : bool, default False
+            If true, monitor both input and output, otherwise monitor output only.
+        """
+        cb_type = ctypes.CFUNCTYPE(None, ctypes.c_char_p, ctypes.c_char_p, NDArrayHandle, ctypes.c_void_p)
+        if callback:
+            self._monitor_callback = cb_type(_monitor_callback_wrapper(callback))
+        check_call(_LIB.MXCachedOpRegisterOpHook(
+            self.handle,
+            self._monitor_callback,
+            ctypes.c_int(monitor_all)))
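
For reference, a minimal sketch of driving the hook at the CachedOp level. This snippet is not part of the patch; it assumes the CachedOp class re-exported by mxnet.ndarray, and the FullyConnected symbol and shapes are purely illustrative:

    import mxnet as mx
    from mxnet.base import py_str

    def mon_callback(node_name, opr_name, arr):
        # node_name and opr_name arrive as C strings; arr is the raw NDArrayHandle
        print(py_str(node_name), py_str(opr_name))

    data = mx.sym.var('data')
    fc = mx.sym.FullyConnected(data, num_hidden=2, name='fc')
    cached = mx.ndarray.CachedOp(fc)
    cached._register_op_hook(mon_callback, monitor_all=True)
    # inputs follow the symbol's argument order: data, fc_weight, fc_bias
    cached(mx.nd.ones((2, 4)), mx.nd.ones((2, 4)), mx.nd.ones((2,)))
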
+ """ + for cld in self._children.values(): + cld.register_op_hook(callback, monitor_all) + def summary(self, *inputs): """Print the summary of the model's output and parameters. @@ -754,6 +767,8 @@ def __init__(self, prefix=None, params=None): self._in_format = None self._active = False self._flags = [] + self._callback = None + self._monitor_all = False def __setattr__(self, name, value): """Registers parameters.""" @@ -833,6 +848,12 @@ def _deferred_infer_shape(self, *args): def _call_cached_op(self, *args): if self._cached_op is None: self._build_cache(*args) + assert self._cached_op, "cached op is not None" + if self._callback: + self._cached_op._register_op_hook(self._callback, self._monitor_all) + if len(self._flags) >= 2 and self._flags[1]: + warnings.warn("Callback is not supported when static_shape=True " + " and is likely to not work correctly") args, fmt = _flatten(args, "input") assert fmt == self._in_format, "Invalid input format" @@ -938,6 +959,22 @@ def export(self, path, epoch=0, remove_amp_cast=True): save_fn = _mx_npx.save if is_np_array() else ndarray.save save_fn('%s-%04d.params'%(path, epoch), arg_dict) + def register_op_hook(self, callback, monitor_all=False): + """Install op hook for block recursively. + + Parameters + ---------- + callback : function + Takes a string and a NDArrayHandle. + monitor_all : bool, default False + If true, monitor both input and output, otherwise monitor output only. + """ + self._callback = callback + self._monitor_all = monitor_all + for cld in self._children.values(): + cld._callback = callback + cld._monitor_all = monitor_all + def forward(self, x, *args): """Defines the forward computation. Arguments can be either :py:class:`NDArray` or :py:class:`Symbol`.""" diff --git a/src/c_api/c_api_ndarray.cc b/src/c_api/c_api_ndarray.cc index c9c6000e2f6f..51301d70cbff 100644 --- a/src/c_api/c_api_ndarray.cc +++ b/src/c_api/c_api_ndarray.cc @@ -378,3 +378,23 @@ int MXAutogradGetSymbol(NDArrayHandle handle, SymbolHandle *out) { *out = reinterpret_cast(sym); API_END(); } + +int MXCachedOpRegisterOpHook(NDArrayHandle handle, + CachedOpMonitorCallback callback, + bool monitor_all) { + API_BEGIN(); + CachedOpMonitorCallback callback_temp = nullptr; + std::function clbk; + if (callback) { + callback_temp = callback; + clbk = [callback_temp](const char *name, const char *opr_name, + void *handle) { + callback_temp(name, opr_name, handle); + }; + } else { + clbk = nullptr; + } + CachedOpPtr op = *static_cast(handle); + op->RegisterOpHook(clbk, monitor_all); + API_END(); +} diff --git a/src/common/utils.cc b/src/common/utils.cc index 9fe46d94d036..032a324c96b0 100644 --- a/src/common/utils.cc +++ b/src/common/utils.cc @@ -51,5 +51,62 @@ void CastStorageDispatch(const OpContext& ctx, mxnet::op::CastStorageComputeImpl(ctx, input, output); } +void ExecuteMonInputCallback( + const nnvm::IndexedGraph &idx, const std::vector &state_arrays, + size_t nid, const std::function + &monitor_callback) { + static const auto &flist_inputs = + nnvm::Op::GetAttr("FListInputNames"); + std::vector input_names; + const nnvm::IndexedGraph::Node &inode = idx[nid]; + const nnvm::Node *node = inode.source; + if (flist_inputs.count(node->op())) { + input_names = flist_inputs[node->op()](node->attrs); + } else { + for (size_t i = 0; i < node->num_inputs(); ++i) { + input_names.emplace_back("input" + std::to_string(i)); + } + } + + for (size_t i = 0; i < node->num_inputs(); ++i) { + const nnvm::NodeEntry &input = node->inputs[i]; + if 
diff --git a/src/c_api/c_api_ndarray.cc b/src/c_api/c_api_ndarray.cc
index c9c6000e2f6f..51301d70cbff 100644
--- a/src/c_api/c_api_ndarray.cc
+++ b/src/c_api/c_api_ndarray.cc
@@ -378,3 +378,23 @@ int MXAutogradGetSymbol(NDArrayHandle handle, SymbolHandle *out) {
   *out = reinterpret_cast<SymbolHandle>(sym);
   API_END();
 }
+
+int MXCachedOpRegisterOpHook(CachedOpHandle handle,
+                             CachedOpMonitorCallback callback,
+                             bool monitor_all) {
+  API_BEGIN();
+  CachedOpMonitorCallback callback_temp = nullptr;
+  std::function<void(const char *, const char *, void *)> clbk;
+  if (callback) {
+    callback_temp = callback;
+    clbk = [callback_temp](const char *name, const char *opr_name,
+                           void *handle) {
+      callback_temp(name, opr_name, handle);
+    };
+  } else {
+    clbk = nullptr;
+  }
+  CachedOpPtr op = *static_cast<CachedOpPtr *>(handle);
+  op->RegisterOpHook(clbk, monitor_all);
+  API_END();
+}
diff --git a/src/common/utils.cc b/src/common/utils.cc
index 9fe46d94d036..032a324c96b0 100644
--- a/src/common/utils.cc
+++ b/src/common/utils.cc
@@ -51,5 +51,62 @@ void CastStorageDispatch(const OpContext& ctx,
   mxnet::op::CastStorageComputeImpl(ctx, input, output);
 }
 
+void ExecuteMonInputCallback(
+    const nnvm::IndexedGraph &idx, const std::vector<NDArray *> &state_arrays,
+    size_t nid, const std::function<void(const char *, const char *, void *)>
+                    &monitor_callback) {
+  static const auto &flist_inputs =
+      nnvm::Op::GetAttr<nnvm::FListInputNames>("FListInputNames");
+  std::vector<std::string> input_names;
+  const nnvm::IndexedGraph::Node &inode = idx[nid];
+  const nnvm::Node *node = inode.source;
+  if (flist_inputs.count(node->op())) {
+    input_names = flist_inputs[node->op()](node->attrs);
+  } else {
+    for (size_t i = 0; i < node->num_inputs(); ++i) {
+      input_names.emplace_back("input" + std::to_string(i));
+    }
+  }
+
+  for (size_t i = 0; i < node->num_inputs(); ++i) {
+    const nnvm::NodeEntry &input = node->inputs[i];
+    if (state_arrays[idx.entry_id(input)]->is_none()) {
+      continue;
+    }
+    NDArray *cpy = new NDArray(*state_arrays[idx.entry_id(input)]);
+    std::string name = inode.source->attrs.name + "_" + input_names[i];
+    monitor_callback(name.c_str(), inode.source->op()->name.c_str(),
+                     reinterpret_cast<void *>(cpy));
+  }
+}
+
+void ExecuteMonOutputCallback(
+    const nnvm::IndexedGraph &idx, const std::vector<NDArray *> &state_arrays,
+    size_t nid, const std::function<void(const char *, const char *, void *)>
+                    &monitor_callback) {
+  static const auto &flist_outputs =
+      nnvm::Op::GetAttr<nnvm::FListOutputNames>("FListOutputNames");
+  std::vector<std::string> output_names;
+  const nnvm::IndexedGraph::Node &inode = idx[nid];
+  const nnvm::Node *node = inode.source;
+  if (flist_outputs.count(node->op())) {
+    output_names = flist_outputs[node->op()](node->attrs);
+  } else {
+    for (size_t i = 0; i < node->num_outputs(); ++i) {
+      output_names.emplace_back(std::to_string(i));
+    }
+  }
+
+  for (size_t i = 0; i < node->num_outputs(); ++i) {
+    if (state_arrays[idx.entry_id(nid, i)]->is_none()) {
+      continue;
+    }
+    NDArray *cpy = new NDArray(*state_arrays[idx.entry_id(nid, i)]);
+    std::string name = inode.source->attrs.name + "_" + output_names[i];
+    monitor_callback(name.c_str(), inode.source->op()->name.c_str(),
+                     reinterpret_cast<void *>(cpy));
+  }
+}
+
 }  // namespace common
 }  // namespace mxnet
diff --git a/src/common/utils.h b/src/common/utils.h
index 251a8fe3c190..3b0d49729c38 100644
--- a/src/common/utils.h
+++ b/src/common/utils.h
@@ -791,6 +791,15 @@ inline void ConvertToLegacyShape(mxnet::ShapeVector* shapes) {
     ConvertToLegacyShape(&(shapes->at(i)));
   }
 }
+void ExecuteMonInputCallback(
+    const nnvm::IndexedGraph &idx, const std::vector<NDArray *> &state_arrays,
+    size_t nid, const std::function<void(const char *, const char *, void *)>
+                    &monitor_callback);
+
+void ExecuteMonOutputCallback(
+    const nnvm::IndexedGraph &idx, const std::vector<NDArray *> &state_arrays,
+    size_t nid, const std::function<void(const char *, const char *, void *)>
+                    &monitor_callback);
 
 }  // namespace common
 }  // namespace mxnet
diff --git a/src/imperative/cached_op.cc b/src/imperative/cached_op.cc
index efe38019cfda..6818d757ab79 100644
--- a/src/imperative/cached_op.cc
+++ b/src/imperative/cached_op.cc
@@ -697,6 +697,9 @@ void CachedOp::StaticRunOps(
       ndinputs.emplace_back(state_arrays[idx.entry_id(j)]);
       CHECK(!ndinputs.back()->is_none());
     }
+    if (monitor_callback_ && monitor_all_) {
+      mxnet::common::ExecuteMonInputCallback(idx, state_arrays, i, monitor_callback_);
+    }
     ndoutputs.clear();
     ndoutputs.reserve(num_outputs);
     req.clear();
@@ -708,6 +711,7 @@ void CachedOp::StaticRunOps(
       CHECK(req.back() == kNullOp || !ndoutputs.back()->is_none());
     }
     const DispatchMode dispatch_mode = dispatch_modes[i];
+
     if (createop.count(node.source->op())) {
       arg_shapes.clear();
       arg_dtypes.clear();
@@ -735,6 +739,9 @@ void CachedOp::StaticRunOps(
            default_ctx, node.source->attrs, ndinputs, ndoutputs, req,
            dispatch_mode);
      }
+     if (monitor_callback_) {
+       mxnet::common::ExecuteMonOutputCallback(idx, state_arrays, i, monitor_callback_);
+     }
    }
  }
}
@@ -883,12 +890,12 @@ OpStatePtr CachedOp::DynamicForward(
     // So if it's not the inline mode, we disable recording.
     RunGraph(false, idx, arrays, 0, idx.num_nodes(), std::move(array_reqs),
              std::move(ref_count), &states, dispatch_modes,
-             recording && inlining_);
+             recording && inlining_, nullptr, monitor_callback_, monitor_all_);
   } else {
     mxnet::ShapeVector shapes = g.GetAttr<mxnet::ShapeVector>("shape");
     NaiveRunGraph(false, default_ctx, idx, arrays, 0, idx.num_nodes(),
                   std::move(array_reqs), std::move(ref_count), &states,
-                  dispatch_modes, recording && inlining_, &shapes);
+                  dispatch_modes, recording && inlining_, &shapes, monitor_callback_, monitor_all_);
     {
       auto state_ptr = GetCachedOpState(default_ctx);
       auto& state = state_ptr.get_state<CachedOpState>();
@@ -1028,7 +1035,7 @@ void CachedOp::DynamicBackward(
 
   RunGraph(retain_graph, idx, arrays, num_forward_nodes, idx.num_nodes(),
            std::move(array_reqs), std::move(ref_count), &states, dispatch_modes,
-           Imperative::Get()->is_recording());
+           Imperative::Get()->is_recording(), nullptr, monitor_callback_);
 
   if (retain_graph) {
     buff.resize(num_forward_entries);
@@ -1295,6 +1302,16 @@ void CachedOpBackward(const OpStatePtr& state_ptr,
     CopyFromTo(out_bufs[i], outputs[i]);
 }
 
+/*
+ * Register the callback to be called when the operator is executed
+ */
+void CachedOp::RegisterOpHook(const CachedOp::CachedOpMonCallback& callback,
+                              bool monitor_all) {
+  CHECK(callback) << "invalid callback";
+  monitor_callback_ = callback;
+  monitor_all_ = monitor_all;
+}
+
 OpStatePtr CreateCachedOpState(const NodeAttrs& attrs,
                                Context ctx,
                                const mxnet::ShapeVector& in_shapes,
diff --git a/src/imperative/cached_op.h b/src/imperative/cached_op.h
index c45f137b2d63..db049d59ed80 100644
--- a/src/imperative/cached_op.h
+++ b/src/imperative/cached_op.h
@@ -74,6 +74,9 @@ struct CachedOpConfig : public dmlc::Parameter<CachedOpConfig> {
 };
 
 class CachedOp {
+  using CachedOpMonCallback =
+      std::function<void(const char *, const char *, void *)>;
+
  public:
   CachedOp(
       const nnvm::Symbol& sym,
@@ -134,6 +137,8 @@ class CachedOp {
     sym.outputs = fwd_graph_.outputs;
     return sym;
   }
+  void RegisterOpHook(const CachedOp::CachedOpMonCallback& callback,
+                      bool monitor_all = false);
 
  private:
   struct GraphInfo;
@@ -203,6 +208,9 @@ class CachedOp {
   std::vector<bool> save_inputs_, save_outputs_;
   std::vector<OpReqType> bwd_output_reqs_;
 
+  std::function<void(const char *, const char *, void *)> monitor_callback_{nullptr};
+  bool monitor_all_{false};
+
   std::mutex mutex_;
   std::unordered_map<Context, std::vector<OpStatePtr> > cached_op_states_;
 };
diff --git a/src/imperative/imperative_utils.cc b/src/imperative/imperative_utils.cc
index 568d39fc8043..5491457b188f 100644
--- a/src/imperative/imperative_utils.cc
+++ b/src/imperative/imperative_utils.cc
@@ -137,7 +137,9 @@ void RunGraph(
     std::vector<OpStatePtr> *p_states,
     const DispatchModeVector &dispatch_modes,
     bool recording,
-    mxnet::ShapeVector *shapes) {
+    mxnet::ShapeVector *shapes,
+    const imperative::CachedOpMonCallback& callback,
+    const bool monitor_all) {
   CHECK(shapes == nullptr);
   for (size_t i = node_start; i < node_end; ++i) {
     const nnvm::IndexedGraph::Node& node = idx[i];
@@ -148,6 +150,9 @@ void RunGraph(
     std::vector<NDArray*> ndoutputs = NodeOutputs(idx, i, arrays);
     std::vector<OpReqType> req = NodeReq(idx, i, array_reqs);
     Context ctx = ndoutputs[0]->ctx();
+    if (callback && monitor_all) {
+      mxnet::common::ExecuteMonInputCallback(idx, arrays, i, callback);
+    }
     auto invoke = [&](const OpStatePtr &state) {
       const nnvm::IndexedGraph::Node& node = idx[i];
       DispatchMode dispatch_mode = dispatch_modes[i];
@@ -159,6 +164,9 @@ void RunGraph(
     };
     InvokeOperator(idx, i, retain_graph, arrays, ctx, p_states, ndinputs, ndoutputs,
                    &req, &ref_count, invoke);
+    if (callback) {
+      mxnet::common::ExecuteMonOutputCallback(idx, arrays, i, callback);
+    }
   }
 }
 
@@ -173,7 +181,9 @@ void NaiveRunGraph(
     std::vector<OpStatePtr> *p_states,
     const DispatchModeVector &dispatch_modes,
     bool recording,
-    mxnet::ShapeVector *shapes) {
+    mxnet::ShapeVector *shapes,
+    const imperative::CachedOpMonCallback& callback,
+    const bool monitor_all) {
   for (size_t i = node_start; i < node_end; ++i) {
     const nnvm::IndexedGraph::Node& node = idx[i];
     if (node.source->op() == nullptr) {
@@ -183,6 +193,9 @@ void NaiveRunGraph(
     std::vector<NDArray*> ndoutputs = NodeOutputs(idx, i, arrays);
     std::vector<OpReqType> req;
     Context ctx = GetContext(node.source->attrs, ndinputs, ndoutputs, default_ctx);
+    if (callback && monitor_all) {
+      mxnet::common::ExecuteMonInputCallback(idx, arrays, i, callback);
+    }
     auto invoke = [&](const OpStatePtr &state) {
       const nnvm::IndexedGraph::Node& node = idx[i];
       DispatchMode dispatch_mode = DispatchMode::kUndefined;
@@ -205,6 +218,9 @@ void NaiveRunGraph(
     };
     InvokeOperator(idx, i, retain_graph, arrays, ctx, p_states, ndinputs, ndoutputs,
                    &req, &ref_count, invoke);
+    if (callback) {
+      mxnet::common::ExecuteMonOutputCallback(idx, arrays, i, callback);
+    }
   }
 }
 
diff --git a/src/imperative/imperative_utils.h b/src/imperative/imperative_utils.h
index 477139fd84b8..7ae54e55da7c 100644
--- a/src/imperative/imperative_utils.h
+++ b/src/imperative/imperative_utils.h
@@ -59,6 +59,7 @@ struct EngineOprSeg {
 };
 
 using MemoryPlanVector = std::vector<MemoryPlanInfo>;
+using CachedOpMonCallback = std::function<void(const char *, const char *, void *)>;
 
 inline Context GetContext(const nnvm::NodeAttrs& attrs,
                           const std::vector<NDArray*>& inputs,
@@ -1056,7 +1057,9 @@ void RunGraph(const bool retain_graph,
               std::vector<OpStatePtr> *p_states,
               const DispatchModeVector &dispatch_modes,
               bool recording,
-              mxnet::ShapeVector *shapes = nullptr);
+              mxnet::ShapeVector *shapes = nullptr,
+              const CachedOpMonCallback& callback = nullptr,
+              const bool monitor_all = false);
 
 void NaiveRunGraph(const bool retain_graph,
                    const Context& default_ctx,
@@ -1068,7 +1071,9 @@ void NaiveRunGraph(const bool retain_graph,
                    std::vector<OpStatePtr> *p_states,
                    const DispatchModeVector &dispatch_modes,
                    bool recording,
-                   mxnet::ShapeVector *shapes);
+                   mxnet::ShapeVector *shapes,
+                   const CachedOpMonCallback& callback = nullptr,
+                   const bool monitor_all = false);
 
 }  // namespace imperative
 }  // namespace mxnet
diff --git a/tests/python/unittest/test_gluon.py b/tests/python/unittest/test_gluon.py
index af30980b10ea..203721f6bd8f 100644
--- a/tests/python/unittest/test_gluon.py
+++ b/tests/python/unittest/test_gluon.py
@@ -21,6 +21,7 @@
 import mxnet as mx
 from mxnet import gluon
 from mxnet.gluon import nn
+from mxnet.base import py_str
 from mxnet.test_utils import assert_almost_equal
 from mxnet.ndarray.ndarray import _STORAGE_TYPE_STR_TO_ID
 from common import (setup_module, with_seed, assertRaises, teardown,
@@ -1503,6 +1504,70 @@ def call_pre_hook(block, x):
 
     assert hook_call_count == 1
     assert pre_hook_call_count == 2
 
+@with_seed()
+def test_op_hook_output_names():
+    def check_name(block, expected_names, inputs=None, expected_opr_names=None, monitor_all=False):
+        opr_names = []
+        output_names = []
+
+        def mon_callback(node_name, opr_name, arr):
+            output_names.append(py_str(node_name))
+            opr_names.append(py_str(opr_name))
+
+        block.register_op_hook(mon_callback, monitor_all)
+        if not inputs:
+            block(mx.nd.ones((2, 3, 4)))
+        else:
+            block(inputs)
+
+        for output_name, expected_name in zip(output_names, expected_names):
+            print(output_name)
+            assert output_name == expected_name
+
+        if expected_opr_names:
+            for opr_name, expected_opr_name in zip(opr_names, expected_opr_names):
+                assert opr_name == expected_opr_name
+
+    # Test with Dense layer
+    model = mx.gluon.nn.HybridSequential()
+    model.add(mx.gluon.nn.Dense(2))
+    model.initialize()
+    model.hybridize()
+    check_name(model, ["dense0_fwd_output"])
+
+    # Test with Activation: FListInputNames is not registered for it, so
+    # monitored input names fall back to input0, input1, ...
+    model = mx.gluon.nn.HybridSequential()
+    model.add(mx.gluon.nn.Activation("relu"))
+    model.initialize()
+    model.hybridize()
+    check_name(model, ["relu0_fwd_output"])
+
+    # Test with Pooling, monitor_all is set to True
+    model = mx.gluon.nn.HybridSequential()
+    model.add(mx.gluon.nn.AvgPool1D())
+    model.initialize()
+    model.hybridize()
+    check_name(model, ['pool0_fwd_data', 'pool0_fwd_output'], expected_opr_names=["Pooling"],
+               monitor_all=True)
+
+    # Stack two layers and test
+    model = mx.gluon.nn.HybridSequential()
+    model.add(mx.gluon.nn.Dense(2))
+    model.add(mx.gluon.nn.Activation("relu"))
+    model.initialize()
+    model.hybridize()
+    check_name(model,
+               ['dense1_fwd_data', 'dense1_fwd_weight',
+                'dense1_fwd_bias', 'dense1_fwd_output',
+                'relu1_fwd_input0', 'relu1_fwd_output'], monitor_all=True)
+
+    # Check with different hybridize modes
+    model.hybridize(static_alloc=True)
+    check_name(model,
+               ['dense1_fwd_data', 'dense1_fwd_weight',
+                'dense1_fwd_bias', 'dense1_fwd_output',
+                'relu1_fwd_input0', 'relu1_fwd_output'], monitor_all=True)
+
 
 @with_seed()
 def test_apply():
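
The test above only checks the reported names. A callback can also inspect the monitored values by wrapping the handle it receives, following the same pattern mxnet.monitor.Monitor uses for executor callbacks; this sketch assumes the handle passed to the hook can be wrapped into an NDArray and that the wrapper takes ownership of the copied array:

    import ctypes
    from mxnet.base import NDArrayHandle, py_str
    from mxnet.ndarray import NDArray

    def stat_callback(node_name, opr_name, arr):
        # Wrap the raw handle; writable=False since this is a monitoring copy
        array = NDArray(ctypes.cast(arr, NDArrayHandle), writable=False)
        print(py_str(node_name), array.shape, array.abs().max().asscalar())
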