This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Partition API adding and deleting new params to Block and Symbol #18405

Merged (4 commits) on Jul 13, 2020
105 changes: 76 additions & 29 deletions python/mxnet/gluon/block.py
@@ -1040,41 +1040,69 @@ def _build_cache(self, *args):
warnings.warn("Parameter %s is not used by any computation. "
"Is this intended?"%unused, stacklevel=4)

-        data_indices = []
-        param_indices = []
-        self._cached_op_args = []
-        for i, name in enumerate(input_names):
-            if name in data_names:
-                data_indices.append(i)
-                self._cached_op_args.append((True, data_names[name]))
-            else:
-                param_indices.append(i)
-                self._cached_op_args.append((False, params[name]))
-        flags = [('data_indices', data_indices), ('param_indices', param_indices)] + \
-                self._flags

args, _ = _flatten(args, "input")
try:
-            for is_arg, i in self._cached_op_args:
-                if not is_arg:
-                    i.data()
+            for name in input_names:
+                if name in params:
+                    params[name].data()
except DeferredInitializationError:
self._deferred_infer_shape(*args)
-            for is_arg, i in self._cached_op_args:
-                if not is_arg:
-                    i._finish_deferred_init()
+            for name in input_names:
+                if name in params:
+                    params[name]._finish_deferred_init()

+        arg_dict, aux_dict = dict(), dict()
if self._backend:
ctx = args[0].context
# get list of params in the order of out.list_arguments
-            arg_dict = {name:args[data_names[name]] if name in data_names.keys() else params[name].data()
-                        for name in out.list_arguments()}
-            aux_dict = {name:args[data_names[name]] if name in data_names.keys() else params[name].data()
-                        for name in out.list_auxiliary_states()}
+            arg_dict.update({name:args[data_names[name]] if name in data_names.keys() else params[name].data()
+                             for name in out.list_arguments()})
@samskalicky (Contributor), Jul 5, 2020:

Why do you initialize with dict() and then call update on an empty dictionary instead of just assigning?

Contributor:

arg_dict and aux_dict could otherwise be undefined below the assert on line 1083. This could be a NameError or a linter error?

Contributor (Author):

That's right, and Python scoping rules are sometimes a bit unsettling, so I thought that this would make it clearer.

Contributor:

Makes sense, thanks!
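As an editorial aside, here is a minimal standalone sketch of the scoping hazard this thread discusses (an illustration, not code from this PR, with hypothetical function names): a name assigned only inside a conditional branch is unbound on the other path, which is exactly why `arg_dict`/`aux_dict` are pre-assigned.

```python
def build(backend=None):
    # Without pre-assignment, arg_dict is only bound when backend is truthy.
    if backend:
        arg_dict = {'w': 1}
    return arg_dict  # UnboundLocalError when backend is None


def build_fixed(backend=None):
    arg_dict = {}  # pre-assigning keeps the name bound on every path
    if backend:
        arg_dict.update({'w': 1})
    return arg_dict
```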

+            aux_dict.update({name:args[data_names[name]] if name in data_names.keys() else params[name].data()
+                             for name in out.list_auxiliary_states()})
            # Partition the graph.
+            out = out.optimize_for(self._backend, arg_dict, aux_dict, ctx, **self._backend_opts)

+            # update cached graph with partitioned graph
+            self._cached_graph = data, out

+        input_names = out.list_inputs()
+        data_indices = []
+        param_indices = []
+
+        # In the default case, _cached_op_args contains all the parameters from params (the sets are identical).
+        # In the case of a Partition API optimized graph, _cached_op_args might contain some parameters from params,
+        # might contain some new parameters created during optimization and added to `arg_dict/aux_dict`,
+        # and might not contain some parameters that were deleted during optimization.
+        self._cached_op_args = []
+        for i, name in enumerate(input_names):
+            pair = None
+            if name in data_names:
+                data_indices.append(i)
+                pair = (True, data_names[name])
+            else:
+                param_indices.append(i)
+                if name in params:
+                    param = params[name]
+                else:
+                    # The param is missing from the original params dictionary, which means the param must have
+                    # been added by the Partition API backend
+                    if name in arg_dict:
+                        param_data = arg_dict[name]
+                    elif name in aux_dict:
+                        param_data = aux_dict[name]
+                    else:
+                        raise RuntimeError('A parameter was added to the graph during optimization but it was not '
+                                           'added to the parameter dicts.\n'
+                                           'Please check the backend.')
+
+                    param = Parameter(name)
+                    param._load_init(param_data, args[0].context)
+                pair = (False, param)
+
+            self._cached_op_args.append(pair)
+
+        flags = [('data_indices', data_indices), ('param_indices', param_indices)] + \
+                self._flags
self._cached_op = ndarray.CachedOp(out, flags)
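To make the data flow above easier to follow, here is a hedged sketch (an editor's illustration with a hypothetical helper name, not MXNet source) of how the `(is_arg, value)` pairs in `_cached_op_args` are consumed: a data entry stores an index into the flattened call arguments, while a parameter entry stores the `Parameter` itself, which may be one created by the partitioning backend.

```python
def gather_cached_op_inputs(cached_op_args, flat_args):
    """Hypothetical helper assembling CachedOp inputs from (is_arg, value) pairs."""
    inputs = []
    for is_arg, value in cached_op_args:
        if is_arg:
            # value is an index into the flattened data arguments
            inputs.append(flat_args[value])
        else:
            # value is a Parameter, possibly created by the partitioning backend
            inputs.append(value.data())
    return inputs
```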


@@ -1321,12 +1349,14 @@ def export(self, path, epoch=0, remove_amp_cast=True):
arg_names = set(sym.list_arguments())
aux_names = set(sym.list_auxiliary_states())
arg_dict = {}
-        for param in self.collect_params().values():
-            if param.name in arg_names:
-                arg_dict['arg:%s'%param.name] = param._reduce()
-            else:
-                assert param.name in aux_names
-                arg_dict['aux:%s'%param.name] = param._reduce()
+        for is_arg, param in self._cached_op_args:
+            if not is_arg:
+                name = param.name
+                if name in arg_names:
+                    arg_dict['arg:{}'.format(name)] = param._reduce()
+                else:
+                    assert name in aux_names
+                    arg_dict['aux:{}'.format(name)] = param._reduce()
save_fn = _mx_npx.save if is_np_array() else ndarray.save
params_filename = '%s-%04d.params'%(path, epoch)
save_fn(params_filename, arg_dict)
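A hedged usage sketch of the export path this hunk changes; 'myBackend' is a placeholder name that assumes a partitioning backend registered under it, and the model choice is arbitrary:

```python
import mxnet as mx
from mxnet.gluon.model_zoo import vision

net = vision.resnet18_v1(pretrained=True)
net.hybridize(backend='myBackend')   # placeholder backend name
net(mx.nd.ones((1, 3, 224, 224)))    # first call builds and partitions the cached graph
# export now iterates _cached_op_args, so parameters added by the backend are
# saved and parameters deleted by the backend are skipped
net.export('partitioned_resnet18', epoch=0)
```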
@@ -1437,6 +1467,23 @@ def hybrid_forward(self, F, x, *args, **kwargs):
# pylint: disable= invalid-name
raise NotImplementedError

+    def reset_ctx(self, ctx):
+        """Re-assign all Parameters to other contexts. If the Block is hybridized, it will reset the _cached_op_args.
+
+        Parameters
+        ----------
+        ctx : Context or list of Context, default :py:meth:`context.current_context()`.
+            Assign Parameter to given context. If ctx is a list of Context, a
+            copy will be made for each context.
+        """
+        params = self.collect_params()
+        if self._cached_op:
+            for is_arg, p in self._cached_op_args:
+                # reset parameters created by the partitioning backend; they are
+                # referenced only by _cached_op_args, not by collect_params()
+                if not is_arg and p.name not in params:
+                    p.reset_ctx(ctx)
+        for p in params.values():
Contributor:

Can you explain why we need to loop over _cached_op_args and only reset the params not in params, and then loop again over params and reset them there? Is it possible to do the work of the second loop in the first loop?

Contributor:

Although I guess if we delete a param, then it will still be in params but not in _cached_op_args. And the context check will fail if we don't do all the params, so I guess this makes sense. Maybe we should have a comment noting that params and _cached_op_args might each contain unique params (i.e., there is no superset)?

Contributor (Author):

> Although I guess if we delete a param, then it will still be in params but not in _cached_op_args.

That is the reason why. We don't want to do any additional reset_ctx calls: resets are costly because they copy NDArrays.

I will add some comments to clarify.

+            p.reset_ctx(ctx)
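And a hedged sketch of the reset_ctx behavior in the partitioned case (same placeholder backend name as above): both the user-visible parameters and any backend-created parameters move to the new context.

```python
import mxnet as mx
from mxnet.gluon.model_zoo import vision

net = vision.resnet18_v1(pretrained=True)
net.hybridize(backend='myBackend')   # placeholder backend name
net(mx.nd.ones((1, 3, 224, 224)))    # builds the partitioned cached graph
net.reset_ctx(mx.gpu(0))             # also resets parameters created by the backend
```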

class SymbolBlock(HybridBlock):
"""Construct block from symbol. This is useful for using pre-trained models
32 changes: 30 additions & 2 deletions python/mxnet/symbol/symbol.py
@@ -1544,8 +1544,36 @@ def optimize_for(self, backend, args=None, aux=None, ctx=None, **kwargs):
raise RuntimeError('Cannot add new aux in optimize_for since aux is None\n' +
'Provide a dictionary to the aux argument to optimize_for')

-        # return modified symbol
-        return Symbol(out)
+        new_sym = Symbol(out)
+
+        arg_names = self.list_arguments()
+        new_arg_names = new_sym.list_arguments()
+        deleted_arg_names = set(arg_names) - set(new_arg_names)
+
+        if len(deleted_arg_names) > 0:
+            if args is not None:
+                for a_n in deleted_arg_names:
+                    if a_n in args:
+                        args.pop(a_n)
+            else:
+                warnings.warn('A param was deleted during optimization, but no args dictionary was provided.\n' +
+                              'Please ensure that your model weights match the newly optimized model.')
+
+        aux_names = self.list_auxiliary_states()
+        new_aux_names = new_sym.list_auxiliary_states()
+        deleted_aux_names = set(aux_names) - set(new_aux_names)
+
+        if len(deleted_aux_names) > 0:
+            if aux is not None:
+                for a_n in deleted_aux_names:
+                    if a_n in aux:
+                        aux.pop(a_n)
+            else:
+                warnings.warn('A param was deleted during optimization, but no aux dictionary was provided.\n' +
+                              'Please ensure that your model weights match the newly optimized model.')
+
+        return new_sym
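Finally, a hedged sketch of the symbol-level flow, assuming a checkpoint saved under the prefix 'model' and the same placeholder backend name: because optimize_for now pops deleted parameters from args/aux in place (and the backend can insert new ones through the C API), the dictionaries stay consistent with the returned symbol's inputs.

```python
import mxnet as mx

# assumes model-symbol.json and model-0000.params exist on disk
sym, args, aux = mx.model.load_checkpoint('model', 0)
optimized = sym.optimize_for('myBackend', args, aux, ctx=mx.cpu())
# parameters deleted by the backend have been popped from args/aux, so every
# remaining key still names an input of the optimized symbol
assert set(args).issubset(optimized.list_arguments())
assert set(aux).issubset(optimized.list_auxiliary_states())
```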

# pylint: disable=too-many-locals
def _simple_bind(self, ctx, grad_req='write', type_dict=None, stype_dict=None,