Commit 606f0fa

update

Asthestarsfalll committed Aug 21, 2023
1 parent 491a9ed commit 606f0fa
Showing 2 changed files with 19 additions and 5 deletions.
python/paddle/nn/layer/rnn.py (1 change: 0 additions & 1 deletion)

@@ -560,7 +560,6 @@ def get_initial_states(
         dtype=None,
         init_value=0.0,
         batch_dim_idx=0,
-        proj_size=None,
     ):
         r"""
         Generate initialized states according to provided shape, data type and
test/rnn/rnn_numpy.py (23 changes: 19 additions & 4 deletions)

@@ -144,7 +144,14 @@ def forward(self, inputs, hx=None):


 class LSTMCell(LayerMixin):
-    def __init__(self, input_size, hidden_size, bias=True, dtype="float64"):
+    def __init__(
+        self,
+        input_size,
+        hidden_size,
+        bias=True,
+        dtype="float64",
+        proj_size=None,
+    ):
         self.input_size = input_size
         self.hidden_size = hidden_size
         self.bias = bias
@@ -154,10 +161,16 @@ def __init__(self, input_size, hidden_size, bias=True, dtype="float64"):
             -std, std, (4 * hidden_size, input_size)
         ).astype(dtype)
         self.weight_hh = np.random.uniform(
-            -std, std, (4 * hidden_size, hidden_size)
+            -std, std, (4 * hidden_size, proj_size or hidden_size)
         ).astype(dtype)
         self.parameters['weight_ih'] = self.weight_ih
         self.parameters['weight_hh'] = self.weight_hh
+        self.proj_size = proj_size
+        if proj_size:
+            self.weight_ho = np.random.uniform(
+                -std, std, (proj_size, hidden_size)
+            ).astype(dtype)
+            self.parameters['weight_ho'] = self.weight_ho
         if bias:
             self.bias_ih = np.random.uniform(
                 -std, std, (4 * hidden_size)
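The shape changes above are the essence of an LSTM with projection (LSTMP): weight_hh now consumes the proj_size-wide recurrent state, and the new weight_ho maps each step's hidden_size output down to proj_size. A minimal standalone sketch of the resulting parameter shapes (illustrative sizes, not taken from the test code):

import numpy as np

input_size, hidden_size, proj_size = 8, 16, 4  # illustrative sizes
std = 1.0 / np.sqrt(hidden_size)

# Input-to-hidden weights are unaffected by projection.
weight_ih = np.random.uniform(-std, std, (4 * hidden_size, input_size))
# Hidden-to-hidden weights act on the projected state, hence proj_size columns.
weight_hh = np.random.uniform(-std, std, (4 * hidden_size, proj_size))
# The projection matrix maps hidden_size activations down to proj_size.
weight_ho = np.random.uniform(-std, std, (proj_size, hidden_size))

assert weight_ih.shape == (64, 8)
assert weight_hh.shape == (64, 4)
assert weight_ho.shape == (4, 16)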
@@ -195,6 +208,8 @@ def forward(self, inputs, hx=None):
         o = 1.0 / (1.0 + np.exp(-chunked_gates[3]))
         c = f * pre_cell + i * np.tanh(chunked_gates[2])
         h = o * np.tanh(c)
+        if self.proj_size:
+            h = np.matmul(h, self.weight_ho.T)

         return h, (h, c)
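For reference, the projected step can be written as one standalone function. This is an illustrative sketch assuming the gate layout implied by the chunked_gates indexing above (i, f, g, o), not the committed helper:

import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def lstmp_step(x, h_prev, c_prev, w_ih, w_hh, w_ho=None):
    # Gates are stacked as [i, f, g, o] along the first weight axis,
    # matching the chunked_gates indexing in the diff above.
    gates = np.matmul(x, w_ih.T) + np.matmul(h_prev, w_hh.T)
    i, f, g, o = np.split(gates, 4, axis=-1)
    c = sigmoid(f) * c_prev + sigmoid(i) * np.tanh(g)
    h = sigmoid(o) * np.tanh(c)
    if w_ho is not None:
        # Output projection: (batch, hidden) @ (hidden, proj) -> (batch, proj)
        h = np.matmul(h, w_ho.T)
    return h, (h, c)

Note that the projected h is returned both as the output and as the recurrent state, so the next step multiplies a proj_size-wide vector against w_hh, which is exactly why weight_hh's second dimension changes in this commit.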

@@ -401,10 +416,10 @@ def forward(self, inputs, initial_states=None, sequence_length=None):
         batch_size = inputs.shape[batch_index]
         dtype = inputs.dtype
         if initial_states is None:
             state_shape = (self.num_layers * self.num_directions, batch_size)
             proj_size = self.proj_size if hasattr(self, 'proj_size') else None

-            dims = ([proj_size or self.hidden_size], [self.hidden_size])
+            dims = ((proj_size or self.hidden_size), (self.hidden_size))
             if self.state_components == 1:
                 initial_states = np.zeros(state_shape, dtype)
             else:
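One consequence of the dims pair above: with proj_size set, the initial h state is proj_size-wide while the c state keeps the full hidden_size. A standalone sketch of the resulting zero-state shapes (assumed semantics inferred from the diff, not the test's exact code):

import numpy as np

num_layers, num_directions, batch_size = 2, 1, 3  # illustrative sizes
hidden_size, proj_size = 16, 4

state_shape = (num_layers * num_directions, batch_size)
# h takes the projected width; c keeps the full hidden width.
dims = (proj_size or hidden_size, hidden_size)
h0, c0 = (np.zeros((*state_shape, d), dtype="float64") for d in dims)

assert h0.shape == (2, 3, 4)   # (layers * directions, batch, proj_size)
assert c0.shape == (2, 3, 16)  # (layers * directions, batch, hidden_size)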
