[MXNET-374] handle row_sparse weight in parameter and trainer #11001
Conversation
python/mxnet/gluon/block.py
Outdated
        if stype != 'default':
            raise ValueError("Cannot create a HybridBlock with Parameter '%s' " \
                             "because its storage type is %s. Please consider " \
                             "using a SparseBlock instead." % (param.name, stype))
PR for sparse block will be created separately after this one is merged.
"please consider using" -> "please use"
    p.reset_ctx(ctx=[mx.cpu(1), mx.cpu(2)])
    assert p.list_ctx() == [mx.cpu(1), mx.cpu(2)]

@with_seed()
def test_sparse_parameter():
-    p = gluon.Parameter('weight', shape=(10, 10), grad_stype='row_sparse')
+    p = gluon.Parameter('weight', shape=(10, 10), stype='row_sparse', grad_stype='row_sparse')
     p.initialize(init='xavier', ctx=[mx.cpu(0), mx.cpu(1)])
Seems like constraining the contexts to CPU is causing test failures on GPU. Is this necessary?
Updated.
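For reference, one way to avoid pinning the test to CPU contexts is to initialize on the default context; this is a sketch, and the actual change in the PR may differ:

```python
import mxnet as mx
from mxnet import gluon

p = gluon.Parameter('weight', shape=(10, 10),
                    stype='row_sparse', grad_stype='row_sparse')
# Use the default context so the test runs on GPU when one is available,
# instead of hard-coding mx.cpu(0)/mx.cpu(1).
p.initialize(init='xavier', ctx=mx.context.current_context())
```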
"grad_stype for Parameter '%s' must be one of 'default', 'row_sparse', or 'csr'," \ | ||
" but got '%s'" % (name, grad_stype) | ||
# sparse related storage type information | ||
valid_stypes = ['default', 'row_sparse', 'csr'] |
Might as well make it a set.
It only has 3 elements, so I don't think this makes any real difference.
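For reference, the suggested change is a one-line swap to a set literal:

```python
# Membership tests on a set are O(1); with three elements it hardly matters.
valid_stypes = {'default', 'row_sparse', 'csr'}
```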
python/mxnet/gluon/parameter.py
Outdated
""" Set the trainer this parameter is associated with. """ | ||
if self._trainer and self._trainer is not trainer: | ||
raise RuntimeError( | ||
"Failed to set the trainer for Parameter '%s' to %s because it was set to %s. " \ |
How can a user detach a parameter's association with a trainer without exiting Python?
Updated. Users can just call `_set_trainer(None)`. I don't think this will be used by common users, hence it remains private.
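A minimal usage sketch of detaching; note that `_set_trainer` is a private API introduced by this PR and subject to change:

```python
from mxnet import gluon

net = gluon.nn.Dense(10)
net.initialize()
trainer = gluon.Trainer(net.collect_params(), 'sgd')

# Detach every parameter from its trainer without exiting Python.
for param in net.collect_params().values():
    param._set_trainer(None)
```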
python/mxnet/gluon/block.py
Outdated
""" | ||
def __init__(self, prefix=None, params=None): | ||
# check if any parameter is row_sparse | ||
if isinstance(params, ParameterDict): |
This check shouldn't be done here.
Parameters are only added to the current block when self.params.get is called.
Removed. Will the checks in `param.list_data()` and `param.data()` be sufficient?
python/mxnet/gluon/parameter.py
Outdated
        raise RuntimeError(
            "Failed to set the trainer for Parameter '%s' to %s because it was set to %s. " \
            "More than one Trainer for a single Parameter is not supported." % (
                self.name, str(trainer), str(self._trainer)))
What does str(trainer) show? It's likely not meaningful to users.
This is a breaking change.
Suppose users want to train with SGD for 10 epochs and then switch to Adam; this would prevent that.
Now it only throws an exception for row_sparse params.
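With that relaxation, the SGD-then-Adam workflow keeps working for dense parameters; a sketch:

```python
from mxnet import gluon

net = gluon.nn.Dense(10)
net.initialize()
params = net.collect_params()

# Train with SGD first ...
trainer = gluon.Trainer(params, 'sgd', {'learning_rate': 0.1})
# ... then switch optimizers. After this change, re-associating a trainer
# only raises for row_sparse parameters, whose weights live on the kvstore.
trainer = gluon.Trainer(params, 'adam', {'learning_rate': 0.001})
```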
python/mxnet/gluon/parameter.py
Outdated
""" Get row_sparse data from row_sparse parameters based on row_id. """ | ||
# get row sparse params based on row ids | ||
if not isinstance(row_id, ndarray.NDArray): | ||
raise TypeError("Cannot get 'row_sparse' Parameter %s with %s type. " |
"row_id must have NDArray type, but %s is given"
python/mxnet/gluon/parameter.py
Outdated
"NDArray type is expected." % (self.name, type(row_id))) | ||
if not self._trainer: | ||
raise RuntimeError("Cannot get row_sparse data for Parameter '%s' when no " \ | ||
"Trainer is created with it."%self.name) |
What if the user wants to train with a single device?
For a single device, we encourage the user to use normal hybrid blocks with sparse_grad=True; there's no need to use a row_sparse weight.
Even if the user chooses to use a row_sparse weight, a kvstore is created for the row_sparse param and the code still works.
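A minimal single-device sketch of that recommendation, assuming the `sparse_grad` option on `nn.Embedding` (available in recent MXNet versions):

```python
import mxnet as mx
from mxnet import autograd, gluon

# Dense (default-storage) weight with a row_sparse gradient: no row_sparse
# weight and no explicit kvstore needed on a single device.
embedding = gluon.nn.Embedding(input_dim=10000, output_dim=128,
                               sparse_grad=True)
embedding.initialize()
trainer = gluon.Trainer(embedding.collect_params(), 'sgd')

with autograd.record():
    out = embedding(mx.nd.array([1, 2, 3]))
out.backward()
trainer.step(batch_size=3)
```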
python/mxnet/gluon/parameter.py
Outdated
"""(Re)initializes by loading from data.""" | ||
if self._trainer and self._trainer._kv_initialized and self._trainer._update_on_kvstore: | ||
raise RuntimeError("Cannot (Re)initialize Parameter '%s' when its Trainer " \ | ||
"already initialized the parameter on KVStore."%(self.name)) |
The message is cryptic. The reason is multi-device training with update_on_kvstore=True; the error message should describe the reason and suggest a solution.
Updated message.
python/mxnet/gluon/parameter.py
Outdated
@@ -396,11 +485,25 @@ def data(self, ctx=None):
        -------
        NDArray on ctx
        """
        if self._stype != 'default':
            raise ValueError("Cannot return a copy of Parameter '%s' on ctx %s via data() " \
These should be UserError?
Maybe I should change it to RuntimeError? There's a UserWarning, but I'm not aware of a UserError.
python/mxnet/gluon/trainer.py
Outdated
            self._params.append(param)
            self._params_to_init.append(param)
            param._set_trainer(self)
Do we need to set_trainer when stype='default' and update_on_kvstore=False?
python/mxnet/gluon/trainer.py
Outdated
@@ -109,38 +117,54 @@ def _init_optimizer(self, optimizer, optimizer_params):
        self._updaters = [opt.get_updater(self._optimizer) \
                          for _ in self._contexts]

    def _init_params(self):
        """ Initialize parameters in the KVStore. Parameters whose
Wrong format
python/mxnet/gluon/trainer.py
Outdated
"when KVStore is not initialized." | ||
params_to_init = [] | ||
if self._kvstore: | ||
params = [param for param in self._params_to_init \ |
Better to use a for loop and if/else here.
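A sketch of the suggested refactor with an explicit loop and if/else. The private names (`_param2idx`, `_check_and_get`, `_deferred_init`) are taken from the surrounding diff and codebase, and details may differ in the final code:

```python
def _init_params(self):
    """Initialize pending parameters in the KVStore (refactor sketch)."""
    params_to_init = []
    if self._kvstore:
        for param in self._params_to_init:
            if param._deferred_init:
                # Shape still unknown; keep it pending, retry on a later step.
                params_to_init.append(param)
            else:
                idx = self._param2idx[param.name]
                self._kvstore.init(idx, param._check_and_get(param._data, list)[0])
    self._params_to_init = params_to_init
```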
@@ -191,6 +224,8 @@ def step(self, batch_size, ignore_stale_grad=False):
        """
        if not self._kv_initialized:
            self._init_kvstore()
        if self._params_to_init:
I don't quite understand this. If there are uninitialized parameters, wouldn't step fail?
I moved the logic of kv.init(param) from `_init_kvstore` to `_init_params`. `_params_to_init` refers to params that are not yet initialized on the kvstore.
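So the resulting flow in `step` is roughly the following; a sketch based on the diff above, with the trailing update elided:

```python
def step(self, batch_size, ignore_stale_grad=False):
    if not self._kv_initialized:
        self._init_kvstore()   # create the kvstore and updaters once
    if self._params_to_init:
        # kv.init moved out of _init_kvstore: parameters whose shape was
        # still deferred at kvstore creation are initialized here instead.
        self._init_params()
    # ... push gradients / pull weights and apply the optimizer update
```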
…#11001)

* + rsp parameter
* draft
* Fix optimizer pickle
* refactor and document
* add test for save load with cast_stype
* refactor trainer tests
* add test
* add back test
* raise error for load params
* add comment
* remove print
* fix doc
* CR comments
* CR comments
* change error
* remove cast stype
* fix test
* add reset kvstore to trainer
* lint
* add test to CI
* add more checks
Description
@piiswrong @szha @ZiyueHuang @haojin2 @safrooze please review.