Skip to content

Commit

Permalink
fix unidirectional model's parameter format (apache#12055)
Browse files Browse the repository at this point in the history
* fix unidirectional model's parameter format

* Update rnn_layer.py
  • Loading branch information
szha authored and eric-haibin-lin committed Aug 8, 2018
1 parent 2d2fab9 commit 31992ef
Showing 1 changed file with 13 additions and 4 deletions.
17 changes: 13 additions & 4 deletions python/mxnet/gluon/rnn/rnn_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
# pylint: disable=too-many-lines, arguments-differ
"""Definition of various recurrent neural network layers."""
from __future__ import print_function
import re

__all__ = ['RNN', 'LSTM', 'GRU']

from ... import ndarray, symbol
Expand Down Expand Up @@ -92,10 +94,17 @@ def __repr__(self):
def _collect_params_with_prefix(self, prefix=''):
if prefix:
prefix += '.'
def convert_key(key): # for compatibility with old parameter format
key = key.split('_')
return '_unfused.{}.{}_cell.{}'.format(key[0][1:], key[0][0], '_'.join(key[1:]))
ret = {prefix + convert_key(key) : val for key, val in self._reg_params.items()}
pattern = re.compile(r'(l|r)(\d)_(i2h|h2h)_(weight|bias)\Z')
def convert_key(m, bidirectional): # for compatibility with old parameter format
d, l, g, t = [m.group(i) for i in range(1, 5)]
if bidirectional:
return '_unfused.{}.{}_cell.{}_{}'.format(l, d, g, t)
else:
return '_unfused.{}.{}_{}'.format(l, g, t)
bidirectional = any(pattern.match(k).group(1) == 'r' for k in self._reg_params)

ret = {prefix + convert_key(pattern.match(key), bidirectional) : val
for key, val in self._reg_params.items()}
for name, child in self._children.items():
ret.update(child._collect_params_with_prefix(prefix + name))
return ret
Expand Down

0 comments on commit 31992ef

Please sign in to comment.