Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Reduce a copy for rowsparse parameter during parameter.save #12039

Merged
merged 1 commit into from
Aug 10, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion python/mxnet/gluon/parameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,7 +319,7 @@ def _reduce(self):
# fetch all rows for 'row_sparse' param
all_row_ids = ndarray.arange(0, self.shape[0], dtype='int64', ctx=ctx)
data = ndarray.zeros(self.shape, stype='row_sparse', ctx=ctx)
self._trainer._row_sparse_pull(self, data, all_row_ids)
self._trainer._row_sparse_pull(self, data, all_row_ids, full_idx=True)
return data

def initialize(self, init=None, ctx=None, default_init=initializer.Uniform(),
Expand Down
11 changes: 9 additions & 2 deletions python/mxnet/gluon/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,14 +235,21 @@ def set_learning_rate(self, lr):
else:
self._optimizer.set_learning_rate(lr)

def _row_sparse_pull(self, parameter, out, row_id):
def _row_sparse_pull(self, parameter, out, row_id, full_idx=False):
"""Internal method to invoke pull operations on KVStore. If `full_idx` is set to True,
`kv.pull` is preferred instead of `kv.row_sparse_pull`.
"""
# initialize kv and params if not already
if not self._kv_initialized:
self._init_kvstore()
if self._params_to_init:
self._init_params()
idx = self._param2idx[parameter.name]
self._kvstore.row_sparse_pull(idx, out=out, row_ids=row_id, priority=-idx)
if full_idx and 'dist' not in self._kvstore.type:
assert row_id.size == out.shape[0]
self._kvstore.pull(idx, out=out, priority=-idx, ignore_sparse=False)
else:
self._kvstore.row_sparse_pull(idx, out=out, row_ids=row_id, priority=-idx)

def step(self, batch_size, ignore_stale_grad=False):
"""Makes one step of parameter update. Should be called after
Expand Down
35 changes: 18 additions & 17 deletions tests/python/unittest/test_gluon_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,24 @@ def test_trainer_save_load():
# check if parameter dict is correctly associated with optimizer after load_state
assert trainer._kvstore._updater.optimizer._get_lr(0) == 0.2

@with_seed()
def test_trainer_sparse_save_load():
x = gluon.Parameter('x', shape=(10, 1), lr_mult=1.0, stype='row_sparse')
x.initialize(ctx=[mx.cpu(0)], init='zeros')
trainer = gluon.Trainer([x], 'sgd', {'learning_rate': 0.1})
all_rows = mx.nd.arange(0, 10, ctx=mx.cpu(0))
with mx.autograd.record():
for w in x.list_row_sparse_data(all_rows):
y = w * 1
y.backward()
trainer.step(1)
assert trainer._kvstore._updater.optimizer._get_lr(0) == 0.1
trainer.save_states('test_trainer_sparse_save_load.states')
trainer.load_states('test_trainer_sparse_save_load.states')
x.lr_mult = 2.0
# check if parameter dict is correctly associated with optimizer after load_state
assert trainer._kvstore._updater.optimizer._get_lr(0) == 0.2

@with_seed()
def test_trainer_multi_layer_init():
class Net(gluon.Block):
Expand Down Expand Up @@ -158,23 +176,6 @@ def check_init(ctxes):
check_init([mx.cpu(1), mx.cpu(2)])
check_init([mx.cpu(1)])

@with_seed()
def test_trainer_save_load():
x = gluon.Parameter('x', shape=(10,), lr_mult=1.0)
x.initialize(ctx=[mx.cpu(0), mx.cpu(1)], init='zeros')
trainer = gluon.Trainer([x], 'sgd', {'learning_rate': 0.1})
with mx.autograd.record():
for w in x.list_data():
y = w + 1
y.backward()
trainer.step(1)
assert trainer._kvstore._updater.optimizer._get_lr(0) == 0.1
trainer.save_states('test_trainer_save_load.states')
trainer.load_states('test_trainer_save_load.states')
x.lr_mult = 2.0
# check if parameter dict is correctly associated with optimizer after load_state
assert trainer._kvstore._updater.optimizer._get_lr(0) == 0.2

@with_seed()
def test_trainer_reset_kv():
def check_trainer_reset_kv(kv):
Expand Down