diff --git a/python/mxnet/gluon/block.py b/python/mxnet/gluon/block.py index 1f8fd0f4ec39..a1a21103a0e0 100644 --- a/python/mxnet/gluon/block.py +++ b/python/mxnet/gluon/block.py @@ -1060,6 +1060,18 @@ def _build_cache(self, *args): for name in out.list_auxiliary_states()} # Partition the graph. out = out.optimize_for(self._backend, arg_dict, aux_dict, ctx, **self._backend_opts) + # BFS to delete the delete args/aux from the block Params and its children's Params + input_names = out.list_inputs() + queue = [self] + while len(queue) > 0: + curr_block = queue.pop(0) + curr_params_names = list(curr_block.params._params.keys()) + for k in curr_params_names: + if k not in input_names: + curr_block.params._params.pop(k) + + queue.extend(curr_block._children.values()) + #update cached graph with partitioned graph self._cached_graph = data, out diff --git a/python/mxnet/symbol/symbol.py b/python/mxnet/symbol/symbol.py index 5ffa9ac9d010..4daa45a25934 100644 --- a/python/mxnet/symbol/symbol.py +++ b/python/mxnet/symbol/symbol.py @@ -1549,7 +1549,7 @@ def optimize_for(self, backend, args=None, aux=None, ctx=None, **kwargs): arg_names = self.list_arguments() new_arg_names = new_sym.list_arguments() deleted_arg_names = set([item for item in arg_names - if item not in set(new_arg_names)]) + if item not in set(new_arg_names)]) if len(deleted_arg_names) > 0: if args is not None: @@ -1562,7 +1562,7 @@ def optimize_for(self, backend, args=None, aux=None, ctx=None, **kwargs): aux_names = self.list_auxiliary_states() new_aux_names = new_sym.list_auxiliary_states() deleted_aux_names = set([item for item in aux_names - if item not in set(new_aux_names)]) + if item not in set(new_aux_names)]) if len(deleted_aux_names) > 0: if aux is not None: for a_n in deleted_aux_names: