diff --git a/python/mxnet/gluon/parameter.py b/python/mxnet/gluon/parameter.py
index 0c6aae921352..1f6b86c978c6 100644
--- a/python/mxnet/gluon/parameter.py
+++ b/python/mxnet/gluon/parameter.py
@@ -319,7 +319,7 @@ def _reduce(self):
             # fetch all rows for 'row_sparse' param
             all_row_ids = ndarray.arange(0, self.shape[0], dtype='int64', ctx=ctx)
             data = ndarray.zeros(self.shape, stype='row_sparse', ctx=ctx)
-            self._trainer._row_sparse_pull(self, data, all_row_ids)
+            self._trainer._row_sparse_pull(self, data, all_row_ids, full_idx=True)
         return data
 
     def initialize(self, init=None, ctx=None, default_init=initializer.Uniform(),
diff --git a/python/mxnet/gluon/trainer.py b/python/mxnet/gluon/trainer.py
index 98a6878b94ba..028e66075100 100644
--- a/python/mxnet/gluon/trainer.py
+++ b/python/mxnet/gluon/trainer.py
@@ -235,14 +235,21 @@ def set_learning_rate(self, lr):
         else:
             self._optimizer.set_learning_rate(lr)
 
-    def _row_sparse_pull(self, parameter, out, row_id):
+    def _row_sparse_pull(self, parameter, out, row_id, full_idx=False):
+        """Internal method to invoke pull operations on KVStore. If `full_idx` is set to True,
+        `kv.pull` is preferred instead of `kv.row_sparse_pull`.
+        """
         # initialize kv and params if not already
         if not self._kv_initialized:
             self._init_kvstore()
         if self._params_to_init:
             self._init_params()
         idx = self._param2idx[parameter.name]
-        self._kvstore.row_sparse_pull(idx, out=out, row_ids=row_id, priority=-idx)
+        if full_idx and 'dist' not in self._kvstore.type:
+            assert row_id.size == out.shape[0]
+            self._kvstore.pull(idx, out=out, priority=-idx, ignore_sparse=False)
+        else:
+            self._kvstore.row_sparse_pull(idx, out=out, row_ids=row_id, priority=-idx)
 
     def step(self, batch_size, ignore_stale_grad=False):
         """Makes one step of parameter update. Should be called after
diff --git a/tests/python/unittest/test_gluon_trainer.py b/tests/python/unittest/test_gluon_trainer.py
index 2a34400d60ab..72c01acb2652 100644
--- a/tests/python/unittest/test_gluon_trainer.py
+++ b/tests/python/unittest/test_gluon_trainer.py
@@ -113,6 +113,24 @@ def test_trainer_save_load():
     # check if parameter dict is correctly associated with optimizer after load_state
     assert trainer._kvstore._updater.optimizer._get_lr(0) == 0.2
 
+@with_seed()
+def test_trainer_sparse_save_load():
+    x = gluon.Parameter('x', shape=(10, 1), lr_mult=1.0, stype='row_sparse')
+    x.initialize(ctx=[mx.cpu(0)], init='zeros')
+    trainer = gluon.Trainer([x], 'sgd', {'learning_rate': 0.1})
+    all_rows = mx.nd.arange(0, 10, ctx=mx.cpu(0))
+    with mx.autograd.record():
+        for w in x.list_row_sparse_data(all_rows):
+            y = w * 1
+            y.backward()
+    trainer.step(1)
+    assert trainer._kvstore._updater.optimizer._get_lr(0) == 0.1
+    trainer.save_states('test_trainer_sparse_save_load.states')
+    trainer.load_states('test_trainer_sparse_save_load.states')
+    x.lr_mult = 2.0
+    # check if parameter dict is correctly associated with optimizer after load_state
+    assert trainer._kvstore._updater.optimizer._get_lr(0) == 0.2
+
 @with_seed()
 def test_trainer_multi_layer_init():
     class Net(gluon.Block):
@@ -158,23 +176,6 @@ def check_init(ctxes):
     check_init([mx.cpu(1), mx.cpu(2)])
     check_init([mx.cpu(1)])
 
-@with_seed()
-def test_trainer_save_load():
-    x = gluon.Parameter('x', shape=(10,), lr_mult=1.0)
-    x.initialize(ctx=[mx.cpu(0), mx.cpu(1)], init='zeros')
-    trainer = gluon.Trainer([x], 'sgd', {'learning_rate': 0.1})
-    with mx.autograd.record():
-        for w in x.list_data():
-            y = w + 1
-            y.backward()
-    trainer.step(1)
-    assert trainer._kvstore._updater.optimizer._get_lr(0) == 0.1
-    trainer.save_states('test_trainer_save_load.states')
-    trainer.load_states('test_trainer_save_load.states')
-    x.lr_mult = 2.0
-    # check if parameter dict is correctly associated with optimizer after load_state
-    assert trainer._kvstore._updater.optimizer._get_lr(0) == 0.2
-
 @with_seed()
 def test_trainer_reset_kv():
     def check_trainer_reset_kv(kv):
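
For context, a minimal sketch (not part of the patch) of the code path the new `full_idx` flag selects. It drives `Trainer._row_sparse_pull` directly, the same way `Parameter._reduce` now does; the parameter name 'x', the (10, 1) shape, and the direct call to a private method are illustrative assumptions, not something the patch prescribes:

    import mxnet as mx
    from mxnet import gluon

    # A row_sparse parameter managed by a Trainer on a single machine,
    # so the default (non-'dist') kvstore is used.
    x = gluon.Parameter('x', shape=(10, 1), stype='row_sparse')
    x.initialize(ctx=[mx.cpu(0)], init='zeros')
    trainer = gluon.Trainer([x], 'sgd', {'learning_rate': 0.1})

    all_rows = mx.nd.arange(0, 10, dtype='int64', ctx=mx.cpu(0))
    out = mx.nd.zeros((10, 1), stype='row_sparse', ctx=mx.cpu(0))

    # With full_idx=True and a non-distributed kvstore, the trainer now
    # issues a dense kv.pull (the assert checks row_id really covers every
    # row); otherwise it falls back to kv.row_sparse_pull as before.
    trainer._row_sparse_pull(x, out, all_rows, full_idx=True)

This mirrors what `Parameter._reduce` does after the patch: since `_reduce` always fetches every row of a 'row_sparse' parameter, it passes `full_idx=True` and lets the kvstore take the dense pull path when it is not distributed.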