Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
fix unidirectional model's parameter format
Browse files Browse the repository at this point in the history
  • Loading branch information
szha committed Aug 6, 2018
1 parent 70efd32 commit 45f79f9
Showing 1 changed file with 12 additions and 4 deletions.
16 changes: 12 additions & 4 deletions python/mxnet/gluon/rnn/rnn_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
# pylint: disable=too-many-lines, arguments-differ
"""Definition of various recurrent neural network layers."""
from __future__ import print_function
import re
__all__ = ['RNN', 'LSTM', 'GRU']

from ... import ndarray, symbol
Expand Down Expand Up @@ -92,10 +93,17 @@ def __repr__(self):
def _collect_params_with_prefix(self, prefix=''):
if prefix:
prefix += '.'
def convert_key(key): # for compatibility with old parameter format
key = key.split('_')
return '_unfused.{}.{}_cell.{}'.format(key[0][1:], key[0][0], '_'.join(key[1:]))
ret = {prefix + convert_key(key) : val for key, val in self._reg_params.items()}
pattern = re.compile(r'(l|r)(\d)_(i2h|h2h)_(weight|bias)')
def convert_key(m, bidirectional): # for compatibility with old parameter format
d, l, g, t = [m.group(i) for i in range(1, 5)]
if bidirectional:
return '_unfused.{}.{}_cell.{}_{}'.format(l, d, g, t)
else:
return '_unfused.{}.{}_{}'.format(l, g, t)
bidirectional = any(pattern.fullmatch(k).group(1) == 'r' for k in self._reg_params)

ret = {prefix + convert_key(pattern.match(key), bidirectional) : val
for key, val in self._reg_params.items()}
for name, child in self._children.items():
ret.update(child._collect_params_with_prefix(prefix + name))
return ret
Expand Down

0 comments on commit 45f79f9

Please sign in to comment.