Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[Gluon] Don't serialize shared parameters twice #16582

Merged
merged 7 commits into from
Oct 26, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 29 additions & 6 deletions python/mxnet/gluon/block.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
import copy
import warnings
import re
from collections import OrderedDict
from collections import OrderedDict, defaultdict

from ..base import mx_real_t, MXNetError
from .. import symbol, ndarray, initializer
Expand Down Expand Up @@ -413,7 +413,7 @@ def _collect_params_with_prefix(self, prefix=''):
ret.update(child._collect_params_with_prefix(prefix + name))
return ret

def save_parameters(self, filename):
def save_parameters(self, filename, deduplicate=False):
"""Save parameters to file.

Saved parameters can only be loaded with `load_parameters`. Note that this
Expand All @@ -424,14 +424,28 @@ def save_parameters(self, filename):
----------
filename : str
Path to file.
deduplicate : bool, default False
If True, save shared parameters only once. Otherwise, if a Block
contains multiple sub-blocks that share parameters, each of the
shared parameters will be separately saved for every sub-block.

References
----------
`Saving and Loading Gluon Models \
<https://mxnet.apache.org/api/python/docs/tutorials/packages/gluon/blocks/save_load_params.html>`_
"""
params = self._collect_params_with_prefix()
arg_dict = {key : val._reduce() for key, val in params.items()}

if deduplicate:
# With deduplicate=True (available as of MXNet 1.6), each shared
# parameter is stored only once. _collect_params_with_prefix returns a
# shared parameter under several prefixes; inverting the mapping twice
# keeps exactly one prefix per parameter. In load_parameters it is
# sufficient for a shared parameter to be present under a single prefix.
reverse_params = {v: k for k, v in params.items()}
params = {v: k for k, v in reverse_params.items()}

arg_dict = {key: val._reduce() for key, val in params.items()}
save_fn = _mx_npx.save if is_np_array() else ndarray.save
save_fn(filename, arg_dict)

Expand Down Expand Up @@ -510,15 +524,24 @@ def load_parameters(self, filename, ctx=None, allow_missing=False,

if not any('.' in i for i in loaded.keys()):
# legacy loading
del loaded
loaded = None # This should be changed to `del loaded` when dropping Python 2
self.collect_params().load(
filename, ctx, allow_missing, ignore_extra, self.prefix,
cast_dtype=cast_dtype, dtype_source=dtype_source)
return

if not allow_missing:
for name in params.keys():
assert name in loaded, \
# A shared parameter is registered under several prefixes (see
# _collect_params_with_prefix), but a file written with
# save_parameters(..., deduplicate=True) (available as of MXNet 1.6)
# stores it under only one of them. Group all prefixes by parameter so
# that a parameter counts as missing only if none of its prefixes is
# present in the loaded file.
params_inv = defaultdict(list)
for k, v in params.items():
params_inv[v].append(k)

for name, param in params.items():
assert any(p in loaded for p in params_inv[param]), \
"Parameter '%s' is missing in file '%s', which contains parameters: %s. " \
"Set allow_missing=True to ignore missing parameters."%(
name, filename, _brief_print_list(loaded.keys()))
Expand Down
40 changes: 40 additions & 0 deletions tests/python/unittest/test_gluon.py
Original file line number Diff line number Diff line change
Expand Up @@ -1511,6 +1511,46 @@ def forward(self, x):
net2 = Network()
net2.load_parameters('tmp.params')

@with_seed()
def test_save_load_deduplicate_with_shared_params():
    """Round-trip a Block whose two children share a single Parameter.

    Checks that ``save_parameters(..., deduplicate=True)`` writes the shared
    parameter once, that the default (``deduplicate=False``) writes it once
    per sub-block prefix, and that both files can be loaded back into a
    freshly built network with the original values.
    """
    # Local imports keep this test self-contained; the parameter files are
    # written to a private temporary directory so that they neither collide
    # with other tests using 'tmp.params' nor linger in the working directory.
    import os
    import shutil
    import tempfile

    class B(mx.gluon.Block):
        def __init__(self, params=None):
            super(B, self).__init__(params=params)

            with self.name_scope():
                self.weight = self.params.get('weight', shape=(10, 10))

    class C(mx.gluon.Block):
        def __init__(self, b1, b2):
            super(C, self).__init__()
            self.b1 = b1
            self.b2 = b2

    def make_net():
        # b2 is built from b1's parameters, so both sub-blocks share 'weight'.
        b1 = B()
        b2 = B(b1.collect_params())
        return C(b1, b2)

    tmpdir = tempfile.mkdtemp()
    try:
        dedup_file = os.path.join(tmpdir, 'tmp.params')
        full_file = os.path.join(tmpdir, 'tmp2.params')

        c = make_net()
        c.initialize()
        expected = c.b1.weight.data().asnumpy()
        c.save_parameters(dedup_file, deduplicate=True)

        params = mx.nd.load(dedup_file)
        assert len(params) == 1  # Only a single copy of the shared parameter is saved

        c = make_net()
        c.load_parameters(dedup_file)
        # The single stored copy must restore the weight for both sub-blocks.
        assert (c.b1.weight.data().asnumpy() == expected).all()
        assert (c.b2.weight.data().asnumpy() == expected).all()

        # Test default behavior
        c.save_parameters(full_file, deduplicate=False)

        params = mx.nd.load(full_file)
        assert len(params) == 2  # One copy of the shared parameter per sub-block

        c = make_net()
        c.load_parameters(full_file)
        assert (c.b1.weight.data().asnumpy() == expected).all()
        assert (c.b2.weight.data().asnumpy() == expected).all()
    finally:
        shutil.rmtree(tmpdir, ignore_errors=True)

@with_seed()
def test_symbol_block_save_load():
class Net(gluon.HybridBlock):
Expand Down