diff --git a/python/mxnet/gluon/block.py b/python/mxnet/gluon/block.py
index 2fda08067e0f..1f9cd43dd2ee 100644
--- a/python/mxnet/gluon/block.py
+++ b/python/mxnet/gluon/block.py
@@ -1040,41 +1040,69 @@ def _build_cache(self, *args):
             warnings.warn("Parameter %s is not used by any computation. "
                           "Is this intended?"%unused, stacklevel=4)
 
-        data_indices = []
-        param_indices = []
-        self._cached_op_args = []
-        for i, name in enumerate(input_names):
-            if name in data_names:
-                data_indices.append(i)
-                self._cached_op_args.append((True, data_names[name]))
-            else:
-                param_indices.append(i)
-                self._cached_op_args.append((False, params[name]))
-        flags = [('data_indices', data_indices), ('param_indices', param_indices)] + \
-                self._flags
-
         args, _ = _flatten(args, "input")
         try:
-            for is_arg, i in self._cached_op_args:
-                if not is_arg:
-                    i.data()
+            for name in input_names:
+                if name in params:
+                    params[name].data()
         except DeferredInitializationError:
             self._deferred_infer_shape(*args)
-            for is_arg, i in self._cached_op_args:
-                if not is_arg:
-                    i._finish_deferred_init()
+            for name in input_names:
+                if name in params:
+                    params[name]._finish_deferred_init()
 
+        arg_dict, aux_dict = dict(), dict()
         if self._backend:
             ctx = args[0].context
             # get list of params in the order of out.list_arguments
-            arg_dict = {name:args[data_names[name]] if name in data_names.keys() else params[name].data()
-                        for name in out.list_arguments()}
-            aux_dict = {name:args[data_names[name]] if name in data_names.keys() else params[name].data()
-                        for name in out.list_auxiliary_states()}
+            arg_dict.update({name:args[data_names[name]] if name in data_names.keys() else params[name].data()
+                             for name in out.list_arguments()})
+            aux_dict.update({name:args[data_names[name]] if name in data_names.keys() else params[name].data()
+                             for name in out.list_auxiliary_states()})
             # Partition the graph.
             out = out.optimize_for(self._backend, arg_dict, aux_dict, ctx, **self._backend_opts)
+            # update the cached graph with the partitioned graph
             self._cached_graph = data, out
+
+        input_names = out.list_inputs()
+        data_indices = []
+        param_indices = []
+
+        # In the default case, _cached_op_args contains all the parameters from params (the sets are identical).
+        # For a graph optimized by the Partition API, _cached_op_args may contain only some of the parameters
+        # from params, may contain new parameters created during optimization (and added to arg_dict/aux_dict),
+        # and may be missing parameters that were deleted during optimization.
+        self._cached_op_args = []
+        for i, name in enumerate(input_names):
+            pair = None
+            if name in data_names:
+                data_indices.append(i)
+                pair = (True, data_names[name])
+            else:
+                param_indices.append(i)
+                if name in params:
+                    param = params[name]
+                else:
+                    # The param is missing from the original params dictionary, which means it must have
+                    # been added by the Partition API backend
+                    if name in arg_dict:
+                        param_data = arg_dict[name]
+                    elif name in aux_dict:
+                        param_data = aux_dict[name]
+                    else:
+                        raise RuntimeError('A parameter was added to the graph during optimization but it was not '
+                                           'added to the parameter dicts.\n'
+                                           'Please check the backend.')
+
+                    param = Parameter(name)
+                    param._load_init(param_data, args[0].context)
+                pair = (False, param)
+
+            self._cached_op_args.append(pair)
+
+        flags = [('data_indices', data_indices), ('param_indices', param_indices)] + \
+                self._flags
         self._cached_op = ndarray.CachedOp(out, flags)
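For context, a minimal sketch of how this path is exercised: `_build_cache` runs on the first forward call of a hybridized block, and the rebuilt `_cached_op_args` only diverges from `params` when a partitioning backend is attached. The backend name `'myPass'` below is hypothetical; any backend registered through the subgraph API would do.

```python
import mxnet as mx
from mxnet.gluon import nn

net = nn.Dense(4)
net.initialize()
# 'myPass' is a hypothetical registered partitioning backend;
# hybridize() stores it in self._backend, so the first forward call
# below runs _build_cache, calls out.optimize_for(...), and rebuilds
# _cached_op_args from the partitioned graph
net.hybridize(backend='myPass')

out = net(mx.nd.ones((1, 8)))
```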
@@ -1321,12 +1349,14 @@ def export(self, path, epoch=0, remove_amp_cast=True):
         arg_names = set(sym.list_arguments())
         aux_names = set(sym.list_auxiliary_states())
         arg_dict = {}
-        for param in self.collect_params().values():
-            if param.name in arg_names:
-                arg_dict['arg:%s'%param.name] = param._reduce()
-            else:
-                assert param.name in aux_names
-                arg_dict['aux:%s'%param.name] = param._reduce()
+        for is_arg, param in self._cached_op_args:
+            if not is_arg:
+                name = param.name
+                if name in arg_names:
+                    arg_dict['arg:{}'.format(name)] = param._reduce()
+                else:
+                    assert name in aux_names
+                    arg_dict['aux:{}'.format(name)] = param._reduce()
         save_fn = _mx_npx.save if is_np_array() else ndarray.save
         params_filename = '%s-%04d.params'%(path, epoch)
         save_fn(params_filename, arg_dict)
@@ -1437,6 +1467,23 @@ def hybrid_forward(self, F, x, *args, **kwargs):
         # pylint: disable= invalid-name
         raise NotImplementedError
 
+    def reset_ctx(self, ctx):
+        """Re-assigns all Parameters to other contexts. If the Block is hybridized, the
+        parameters held in _cached_op_args are reset as well.
+
+        Parameters
+        ----------
+        ctx : Context or list of Context, default :py:meth:`context.current_context()`.
+            Assign Parameter to given context. If ctx is a list of Context, a
+            copy will be made for each context.
+        """
+        params = self.collect_params()
+        if self._cached_op:
+            for is_arg, p in self._cached_op_args:
+                # reset parameters created by the partitioning backend,
+                # which collect_params() does not know about
+                if not is_arg and p.name not in params:
+                    p.reset_ctx(ctx)
+        for p in params.values():
+            p.reset_ctx(ctx)
+
 
 class SymbolBlock(HybridBlock):
     """Construct block from symbol. This is useful for using pre-trained models
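A sketch of why `export` and the new `reset_ctx` iterate `_cached_op_args` rather than only `collect_params()`: parameters created by the backend never appear in the param dict, so they would otherwise be skipped. This assumes a hybridized net as above and, for the last line, an available GPU.

```python
import mxnet as mx
from mxnet.gluon import nn

net = nn.Dense(4)
net.initialize(ctx=mx.cpu())
net.hybridize(backend='myPass')         # hypothetical backend, as above
net(mx.nd.ones((1, 8), ctx=mx.cpu()))   # first call builds the cached op

# export now reads weights from _cached_op_args, so parameters the
# backend added are saved and parameters it deleted are not
net.export('partitioned-model', epoch=0)

# reset_ctx also moves backend-created parameters, which iterating
# collect_params() alone would miss (assumes a GPU is present)
net.reset_ctx(mx.gpu(0))
```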
diff --git a/python/mxnet/symbol/symbol.py b/python/mxnet/symbol/symbol.py
index 39b8799ce155..89ff6bfbd181 100644
--- a/python/mxnet/symbol/symbol.py
+++ b/python/mxnet/symbol/symbol.py
@@ -1544,8 +1544,36 @@ def optimize_for(self, backend, args=None, aux=None, ctx=None, **kwargs):
                 raise RuntimeError('Cannot add new aux in optimize_for since aux is None\n' +
                                    'Provide a dictionary to the aux argument to optimize_for')
 
-        # return modified symbol
-        return Symbol(out)
+        new_sym = Symbol(out)
+
+        # drop any args/aux entries for inputs the backend deleted during optimization
+        arg_names = self.list_arguments()
+        new_arg_names = new_sym.list_arguments()
+        deleted_arg_names = set(arg_names) - set(new_arg_names)
+
+        if len(deleted_arg_names) > 0:
+            if args is not None:
+                for a_n in deleted_arg_names:
+                    if a_n in args:
+                        args.pop(a_n)
+            else:
+                warnings.warn('A param was deleted during optimization, but no args dictionary was provided.\n' +
+                              'Please ensure that your model weights match the newly optimized model.')
+
+        aux_names = self.list_auxiliary_states()
+        new_aux_names = new_sym.list_auxiliary_states()
+        deleted_aux_names = set(aux_names) - set(new_aux_names)
+        if len(deleted_aux_names) > 0:
+            if aux is not None:
+                for a_n in deleted_aux_names:
+                    if a_n in aux:
+                        aux.pop(a_n)
+            else:
+                warnings.warn('A param was deleted during optimization, but no aux dictionary was provided.\n' +
+                              'Please ensure that your model weights match the newly optimized model.')
+
+        return new_sym
 
     # pylint: disable=too-many-locals
     def _simple_bind(self, ctx, grad_req='write', type_dict=None, stype_dict=None,
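At the symbol level, the observable effect of this change is that the caller's `args`/`aux` dictionaries are mutated in place so they stay consistent with the returned symbol. A minimal sketch, again with a hypothetical backend name:

```python
import mxnet as mx

data = mx.sym.Variable('data')
weight = mx.sym.Variable('weight')
sym = mx.sym.FullyConnected(data, weight=weight, num_hidden=4, no_bias=True)

args = {'weight': mx.nd.ones((4, 8))}
# 'myPass' is a hypothetical registered backend; if the pass deletes
# (e.g. fuses away) 'weight', its entry is now popped from `args`.
# Passing no dict instead raises the new warning.
opt_sym = sym.optimize_for('myPass', args=args, aux={})
```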