From 832dd1a5774c8abdc60108395fdfc308e4c3ea91 Mon Sep 17 00:00:00 2001
From: Serge Panev
Date: Tue, 2 Jun 2020 00:17:41 -0700
Subject: [PATCH 1/4] Add deleting of args/aux to Partition API

Signed-off-by: Serge Panev
---
 python/mxnet/gluon/block.py   | 39 ++++++++++++++++++-----------------
 python/mxnet/symbol/symbol.py | 31 ++++++++++++++++++++++++++--
 2 files changed, 49 insertions(+), 21 deletions(-)

diff --git a/python/mxnet/gluon/block.py b/python/mxnet/gluon/block.py
index 2fda08067e0f..1f8fd0f4ec39 100644
--- a/python/mxnet/gluon/block.py
+++ b/python/mxnet/gluon/block.py
@@ -1040,29 +1040,16 @@ def _build_cache(self, *args):
             warnings.warn("Parameter %s is not used by any computation. "
                           "Is this intended?"%unused, stacklevel=4)
 
-        data_indices = []
-        param_indices = []
-        self._cached_op_args = []
-        for i, name in enumerate(input_names):
-            if name in data_names:
-                data_indices.append(i)
-                self._cached_op_args.append((True, data_names[name]))
-            else:
-                param_indices.append(i)
-                self._cached_op_args.append((False, params[name]))
-        flags = [('data_indices', data_indices), ('param_indices', param_indices)] + \
-            self._flags
-
         args, _ = _flatten(args, "input")
         try:
-            for is_arg, i in self._cached_op_args:
-                if not is_arg:
-                    i.data()
+            for name in input_names:
+                if name in params:
+                    params[name].data()
         except DeferredInitializationError:
             self._deferred_infer_shape(*args)
-            for is_arg, i in self._cached_op_args:
-                if not is_arg:
-                    i._finish_deferred_init()
+            for name in input_names:
+                if name in params:
+                    params[name]._finish_deferred_init()
 
         if self._backend:
             ctx = args[0].context
@@ -1075,6 +1062,20 @@ def _build_cache(self, *args):
             out = out.optimize_for(self._backend, arg_dict, aux_dict, ctx, **self._backend_opts)
             #update cached graph with partitioned graph
             self._cached_graph = data, out
+
+        input_names = out.list_inputs()
+        data_indices = []
+        param_indices = []
+        self._cached_op_args = []
+        for i, name in enumerate(input_names):
+            if name in data_names:
+                data_indices.append(i)
+                self._cached_op_args.append((True, data_names[name]))
+            else:
+                param_indices.append(i)
+                self._cached_op_args.append((False, params[name]))
+        flags = [('data_indices', data_indices), ('param_indices', param_indices)] + \
+            self._flags
 
         self._cached_op = ndarray.CachedOp(out, flags)
 
diff --git a/python/mxnet/symbol/symbol.py b/python/mxnet/symbol/symbol.py
index 39b8799ce155..5ffa9ac9d010 100644
--- a/python/mxnet/symbol/symbol.py
+++ b/python/mxnet/symbol/symbol.py
@@ -1544,8 +1544,35 @@ def optimize_for(self, backend, args=None, aux=None, ctx=None, **kwargs):
             raise RuntimeError('Cannot add new aux in optimize_for since aux is None\n' +
                                'Provide a dictionary to the aux argument to optimize_for')
 
-        # return modified symbol
-        return Symbol(out)
+        new_sym = Symbol(out)
+
+        arg_names = self.list_arguments()
+        new_arg_names = new_sym.list_arguments()
+        deleted_arg_names = set([item for item in arg_names
+                            if item not in set(new_arg_names)])
+
+        if len(deleted_arg_names) > 0:
+            if args is not None:
+                for a_n in deleted_arg_names:
+                    if a_n in args:
+                        args.pop(a_n)
+            else:
+                warnings.warn('optimize_for deleted some argument. \n' +
+                              'Provide a dictionary to the arg argument to optimize_for')
+        aux_names = self.list_auxiliary_states()
+        new_aux_names = new_sym.list_auxiliary_states()
+        deleted_aux_names = set([item for item in aux_names
+                            if item not in set(new_aux_names)])
+        if len(deleted_aux_names) > 0:
+            if aux is not None:
+                for a_n in deleted_aux_names:
+                    if a_n in aux:
+                        aux.pop(a_n)
+            else:
+                warnings.warn('optimize_for deleted some aux argument. \n' +
+                              'Provide a dictionary to the aux argument to optimize_for')
+
+        return new_sym
 
     # pylint: disable=too-many-locals
     def _simple_bind(self, ctx, grad_req='write', type_dict=None, stype_dict=None,

From 2b6bab7a67884559afe0824d9b7fc8b3fc4f99d0 Mon Sep 17 00:00:00 2001
From: Serge Panev
Date: Tue, 30 Jun 2020 18:36:41 -0700
Subject: [PATCH 2/4] Delete args from Block.params

Signed-off-by: Serge Panev
---
 python/mxnet/gluon/block.py   | 13 +++++++++++++
 python/mxnet/symbol/symbol.py |  4 ++--
 2 files changed, 15 insertions(+), 2 deletions(-)

diff --git a/python/mxnet/gluon/block.py b/python/mxnet/gluon/block.py
index 1f8fd0f4ec39..89406762a64e 100644
--- a/python/mxnet/gluon/block.py
+++ b/python/mxnet/gluon/block.py
@@ -1060,6 +1060,19 @@ def _build_cache(self, *args):
                         for name in out.list_auxiliary_states()}
             # Partition the graph.
             out = out.optimize_for(self._backend, arg_dict, aux_dict, ctx, **self._backend_opts)
+            # BFS to delete the deleted args/aux from the block Params and its children's Params
+            input_names = out.list_inputs()
+            queue = [self]
+            while len(queue) > 0:
+                curr_block = queue.pop(0)
+                curr_params = curr_block.params if isinstance(curr_block.params, dict) else curr_block.params._params
+                curr_params_names = list(curr_params.keys())
+                for k in curr_params_names:
+                    if k not in input_names:
+                        curr_params.pop(k)
+
+                queue.extend(curr_block._children.values())
+
             #update cached graph with partitioned graph
             self._cached_graph = data, out
 
diff --git a/python/mxnet/symbol/symbol.py b/python/mxnet/symbol/symbol.py
index 5ffa9ac9d010..4daa45a25934 100644
--- a/python/mxnet/symbol/symbol.py
+++ b/python/mxnet/symbol/symbol.py
@@ -1549,7 +1549,7 @@ def optimize_for(self, backend, args=None, aux=None, ctx=None, **kwargs):
         arg_names = self.list_arguments()
         new_arg_names = new_sym.list_arguments()
         deleted_arg_names = set([item for item in arg_names
-                            if item not in set(new_arg_names)])
+                                 if item not in set(new_arg_names)])
 
         if len(deleted_arg_names) > 0:
             if args is not None:
@@ -1562,7 +1562,7 @@ def optimize_for(self, backend, args=None, aux=None, ctx=None, **kwargs):
         aux_names = self.list_auxiliary_states()
         new_aux_names = new_sym.list_auxiliary_states()
         deleted_aux_names = set([item for item in aux_names
-                            if item not in set(new_aux_names)])
+                                 if item not in set(new_aux_names)])
         if len(deleted_aux_names) > 0:
             if aux is not None:
                 for a_n in deleted_aux_names:

From 11a7a5c62c8532b510079e770890faa2dfd4a8b6 Mon Sep 17 00:00:00 2001
From: Serge Panev
Date: Thu, 2 Jul 2020 21:54:20 -0700
Subject: [PATCH 3/4] Fix to use arg/aux dict when optimize_for is called in HybridBlock

Signed-off-by: Serge Panev
---
 python/mxnet/gluon/block.py | 74 +++++++++++++++++++++++++------------
 1 file changed, 50 insertions(+), 24 deletions(-)

diff --git a/python/mxnet/gluon/block.py b/python/mxnet/gluon/block.py
index 89406762a64e..fe44a848d582 100644
--- a/python/mxnet/gluon/block.py
+++ b/python/mxnet/gluon/block.py
@@ -1051,27 +1051,16 @@ def _build_cache(self, *args):
                 if name in params:
                     params[name]._finish_deferred_init()
 
+        arg_dict, aux_dict = dict(), dict()
         if self._backend:
             ctx = args[0].context
             # get list of params in the order of out.list_arguments
-            arg_dict = {name:args[data_names[name]] if name in data_names.keys() else params[name].data()
-                        for name in out.list_arguments()}
-            aux_dict = {name:args[data_names[name]] if name in data_names.keys() else params[name].data()
-                        for name in out.list_auxiliary_states()}
+            arg_dict.update({name:args[data_names[name]] if name in data_names.keys() else params[name].data()
+                             for name in out.list_arguments()})
+            aux_dict.update({name:args[data_names[name]] if name in data_names.keys() else params[name].data()
+                             for name in out.list_auxiliary_states()})
             # Partition the graph.
             out = out.optimize_for(self._backend, arg_dict, aux_dict, ctx, **self._backend_opts)
-            # BFS to delete the deleted args/aux from the block Params and its children's Params
-            input_names = out.list_inputs()
-            queue = [self]
-            while len(queue) > 0:
-                curr_block = queue.pop(0)
-                curr_params = curr_block.params if isinstance(curr_block.params, dict) else curr_block.params._params
-                curr_params_names = list(curr_params.keys())
-                for k in curr_params_names:
-                    if k not in input_names:
-                        curr_params.pop(k)
-
-                queue.extend(curr_block._children.values())
 
             #update cached graph with partitioned graph
             self._cached_graph = data, out
@@ -1081,12 +1070,30 @@ def _build_cache(self, *args):
         param_indices = []
         self._cached_op_args = []
         for i, name in enumerate(input_names):
+            pair = None
             if name in data_names:
                 data_indices.append(i)
-                self._cached_op_args.append((True, data_names[name]))
+                pair = (True, data_names[name])
             else:
                 param_indices.append(i)
-                self._cached_op_args.append((False, params[name]))
+                if name in params:
+                    param = params[name]
+                else:
+                    assert self._backend, "Parameter " + name + " is missing from block params"
+                    if name in arg_dict:
+                        param_data = arg_dict[name]
+                    elif name in aux_dict:
+                        param_data = aux_dict[name]
+                    else:
+                        raise RuntimeError('Expected inputs missing from arg and aux after partitioning. '
+                                           'Please check the backend.')
+
+                    param = Parameter(name)
+                    param._load_init(param_data, args[0].context)
+                pair = (False, param)
+
+            self._cached_op_args.append(pair)
+
         flags = [('data_indices', data_indices), ('param_indices', param_indices)] + \
             self._flags
         self._cached_op = ndarray.CachedOp(out, flags)
@@ -1335,12 +1342,14 @@ def export(self, path, epoch=0, remove_amp_cast=True):
         arg_names = set(sym.list_arguments())
         aux_names = set(sym.list_auxiliary_states())
         arg_dict = {}
-        for param in self.collect_params().values():
-            if param.name in arg_names:
-                arg_dict['arg:%s'%param.name] = param._reduce()
-            else:
-                assert param.name in aux_names
-                arg_dict['aux:%s'%param.name] = param._reduce()
+        for is_arg, param in self._cached_op_args:
+            if not is_arg:
+                name = param.name
+                if name in arg_names:
+                    arg_dict['arg:{}'.format(name)] = param._reduce()
+                else:
+                    assert name in aux_names
+                    arg_dict['aux:{}'.format(name)] = param._reduce()
         save_fn = _mx_npx.save if is_np_array() else ndarray.save
         params_filename = '%s-%04d.params'%(path, epoch)
         save_fn(params_filename, arg_dict)
@@ -1451,6 +1460,23 @@ def hybrid_forward(self, F, x, *args, **kwargs):
         # pylint: disable= invalid-name
         raise NotImplementedError
 
+    def reset_ctx(self, ctx):
+        """Re-assign all Parameters to other contexts. If the Block is hybridized, it will reset the _cached_op_args.
+
+        Parameters
+        ----------
+        ctx : Context or list of Context, default :py:meth:`context.current_context()`.
+            Assign Parameter to given context. If ctx is a list of Context, a
+            copy will be made for each context.
+        """
+        params = self.collect_params()
+        if self._cached_op:
+            for is_arg, p in self._cached_op_args:
+                # reset parameters created by the partitioning backend
+                if not is_arg and p.name not in params:
+                    p.reset_ctx(ctx)
+        for p in params.values():
+            p.reset_ctx(ctx)
 
 class SymbolBlock(HybridBlock):
     """Construct block from symbol. This is useful for using pre-trained models

From 80beb4b345450cd3f653ed32fa094a588db523ad Mon Sep 17 00:00:00 2001
From: Serge Panev
Date: Mon, 6 Jul 2020 11:27:13 -0700
Subject: [PATCH 4/4] Address PR comments

Signed-off-by: Serge Panev
---
 python/mxnet/gluon/block.py   | 11 +++++++++--
 python/mxnet/symbol/symbol.py |  9 +++++----
 2 files changed, 14 insertions(+), 6 deletions(-)

diff --git a/python/mxnet/gluon/block.py b/python/mxnet/gluon/block.py
index fe44a848d582..1f9cd43dd2ee 100644
--- a/python/mxnet/gluon/block.py
+++ b/python/mxnet/gluon/block.py
@@ -1068,6 +1068,11 @@ def _build_cache(self, *args):
         input_names = out.list_inputs()
         data_indices = []
         param_indices = []
+
+        # In the default case, _cached_op_args contains all the parameters from params (the sets are identical).
+        # In the case of a Partition API optimized graph, _cached_op_args might contain some parameters from params,
+        # might contain some new parameters created during optimization and added to `arg_dict/aux_dict`,
+        # and might not contain some parameters that were deleted during optimization.
         self._cached_op_args = []
         for i, name in enumerate(input_names):
             pair = None
@@ -1079,13 +1084,15 @@ def _build_cache(self, *args):
                 if name in params:
                     param = params[name]
                 else:
-                    assert self._backend, "Parameter " + name + " is missing from block params"
+                    # The param is missing from the original params dictionary, which means the param must have
+                    # been added by the Partition API backend.
                     if name in arg_dict:
                         param_data = arg_dict[name]
                     elif name in aux_dict:
                         param_data = aux_dict[name]
                     else:
-                        raise RuntimeError('Expected inputs missing from arg and aux after partitioning. '
+                        raise RuntimeError('A parameter was added to the graph during optimization but it was not '
+                                           'added to the parameter dicts.\n'
                                            'Please check the backend.')
 
                     param = Parameter(name)
diff --git a/python/mxnet/symbol/symbol.py b/python/mxnet/symbol/symbol.py
index 4daa45a25934..89ff6bfbd181 100644
--- a/python/mxnet/symbol/symbol.py
+++ b/python/mxnet/symbol/symbol.py
@@ -1557,8 +1557,9 @@ def optimize_for(self, backend, args=None, aux=None, ctx=None, **kwargs):
                 if a_n in args:
                     args.pop(a_n)
             else:
-                warnings.warn('optimize_for deleted some argument. \n' +
-                              'Provide a dictionary to the arg argument to optimize_for')
+                warnings.warn('A param was deleted during optimization, but no args dictionary was provided.\n' +
+                              'Please ensure that your model weights match the newly optimized model.')
+
         aux_names = self.list_auxiliary_states()
         new_aux_names = new_sym.list_auxiliary_states()
         deleted_aux_names = set([item for item in aux_names
@@ -1569,8 +1570,8 @@ def optimize_for(self, backend, args=None, aux=None, ctx=None, **kwargs):
                 if a_n in aux:
                     aux.pop(a_n)
             else:
-                warnings.warn('optimize_for deleted some aux argument. \n' +
-                              'Provide a dictionary to the aux argument to optimize_for')
+                warnings.warn('A param was deleted during optimization, but no aux dictionary was provided.\n' +
+                              'Please ensure that your model weights match the newly optimized model.')
 
         return new_sym
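
For reference, a minimal usage sketch of the updated optimize_for API follows. It is not part of the patch series: the backend name 'myProp' and the model file names are placeholders, and it assumes a partitioning backend with that name has already been registered through the subgraph API.

import mxnet as mx

# Load a previously exported model (file names are placeholders).
sym = mx.sym.load('model-symbol.json')
save_dict = mx.nd.load('model-0000.params')
args = {k[4:]: v for k, v in save_dict.items() if k.startswith('arg:')}
aux = {k[4:]: v for k, v in save_dict.items() if k.startswith('aux:')}

# 'myProp' stands for a registered partitioning backend. With these patches,
# optimize_for() pops the entries for any arguments or auxiliary states the
# backend deleted, so args/aux stay in sync with the returned symbol.
part_sym = sym.optimize_for('myProp', args, aux, ctx=mx.cpu())

On the Gluon side, the same bookkeeping happens inside _build_cache() when a block hybridized with net.hybridize(backend='myProp') runs its first forward pass.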