Skip to content

Commit

Permalink
[Fix] fastspeech2 0d (#3951)
Browse files Browse the repository at this point in the history
  • Loading branch information
megemini authored Dec 16, 2024
1 parent 73beb18 commit 8ee3a7e
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions paddlespeech/t2s/models/fastspeech2/fastspeech2.py
Original file line number Diff line number Diff line change
@@ -903,14 +903,14 @@ def _reset_parameters(self, init_enc_alpha: float, init_dec_alpha: float):

# initialize alpha in scaled positional encoding
if self.encoder_type == "transformer" and self.use_scaled_pos_enc:
init_enc_alpha = paddle.to_tensor(init_enc_alpha)
init_enc_alpha = paddle.to_tensor(init_enc_alpha).reshape([1])
self.encoder.embed[-1].alpha = paddle.create_parameter(
shape=init_enc_alpha.shape,
dtype=str(init_enc_alpha.numpy().dtype),
default_initializer=paddle.nn.initializer.Assign(
init_enc_alpha))
if self.decoder_type == "transformer" and self.use_scaled_pos_enc:
init_dec_alpha = paddle.to_tensor(init_dec_alpha)
init_dec_alpha = paddle.to_tensor(init_dec_alpha).reshape([1])
self.decoder.embed[-1].alpha = paddle.create_parameter(
shape=init_dec_alpha.shape,
dtype=str(init_dec_alpha.numpy().dtype),

0 comments on commit 8ee3a7e

Please sign in to comment.