From 4430ae18a377c2563c9f890d6834023c2153f206 Mon Sep 17 00:00:00 2001 From: tqchen Date: Tue, 22 Sep 2015 15:47:07 -0700 Subject: [PATCH] [MODEL] Allow extra params --- python/mxnet/model.py | 17 +++++++++++++++++ python/mxnet/symbol.py | 10 ++++++++++ tests/python/train/test_mlp.py | 12 +++++++++--- tests/python/unittest/test_symbol.py | 4 ++-- 4 files changed, 38 insertions(+), 5 deletions(-) diff --git a/python/mxnet/model.py b/python/mxnet/model.py index df450be4cb86..2da0e98ae8ca 100644 --- a/python/mxnet/model.py +++ b/python/mxnet/model.py @@ -420,15 +420,32 @@ class FeedForward(BASE_ESTIMATOR): aux_params : dict of str to NDArray, optional Model parameter, dict of name to NDArray of net's auxiliary states. + allow_extra_params : boolean, optional + Whether to allow extra parameters that are not needed by symbol + to be passed by aux_params and arg_params. + If this is True, no error will be thrown when aux_params and arg_params + contain more parameters than needed. + **kwargs : dict The additional keyword arguments passed to optimizer. """ def __init__(self, symbol, ctx=None, num_round=None, optimizer='sgd', initializer=Xavier(), arg_params=None, aux_params=None, + allow_extra_params=False, **kwargs): # check if symbol contain duplicated names. 
_check_arguments(symbol) + # rematch parameters to delete useless ones + if allow_extra_params: + if arg_params: + arg_names = set(symbol.list_arguments()) + arg_params = {k : v for k, v in arg_params.items() + if k in arg_names} + if aux_params: + aux_names = set(symbol.list_auxiliary_states()) + aux_params = {k : v for k, v in aux_params.items() + if k in aux_names} # basic configuration self.symbol = symbol if ctx is None: diff --git a/python/mxnet/symbol.py b/python/mxnet/symbol.py index e4f795ec1795..1c8841a74460 100644 --- a/python/mxnet/symbol.py +++ b/python/mxnet/symbol.py @@ -154,6 +154,16 @@ def _compose(self, *args, **kwargs): self.handle, name, num_args, keys, args)) def __getitem__(self, index): + if isinstance(index, string_types): + idx = None + for i, name in enumerate(self.list_outputs()): + if name == index: + if idx is not None: + raise ValueError('There are multiple outputs with name \"%s\"' % index) + idx = i + if idx is None: + raise ValueError('Cannot find output that matches name \"%s\"' % index) + index = idx if not isinstance(index, int): raise TypeError('Symbol only support integer index to fetch i-th output') handle = SymbolHandle() diff --git a/tests/python/train/test_mlp.py b/tests/python/train/test_mlp.py index 5ad44fe0350b..ca738128d717 100644 --- a/tests/python/train/test_mlp.py +++ b/tests/python/train/test_mlp.py @@ -67,6 +67,15 @@ def test_mlp(): logging.info('final accuracy = %f', acc1) assert(acc1 > 0.95) + # predict internal featuremaps + internals = softmax.get_internals() + fc2 = internals['fc2_output'] + mfeat = mx.model.FeedForward(symbol=fc2, + arg_params=model.arg_params, + aux_params=model.aux_params, + allow_extra_params=True) + feat = mfeat.predict(val_dataiter) + assert feat.shape == (10000, 64) # pickle the model smodel = pickle.dumps(model) model2 = pickle.loads(smodel) @@ -79,9 +88,6 @@ def test_mlp(): assert np.sum(np.abs(prob - prob3)) == 0 # save model explicitly - - - model.save(prefix, 128) model4 = 
mx.model.FeedForward.load(prefix, 128) prob4 = model4.predict(val_dataiter) diff --git a/tests/python/unittest/test_symbol.py b/tests/python/unittest/test_symbol.py index 199d3dfaf7cb..0298645f7285 100644 --- a/tests/python/unittest/test_symbol.py +++ b/tests/python/unittest/test_symbol.py @@ -39,8 +39,7 @@ def test_symbol_internal(): 'fc1_weight', 'fc1_bias', 'fc2_weight', 'fc2_bias'] internal = net1.get_internals() - nmap = {x: i for i, x in enumerate(internal.list_outputs())} - fc1 = internal[nmap['fc1_output']] + fc1 = internal['fc1_output'] assert fc1.list_arguments() == oldfc.list_arguments()