Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WaveNetの損失関数を計算する時の出力のシフト方向 #21

Closed
zzxiang opened this issue Sep 25, 2021 · 5 comments
Closed
Labels

Comments

@zzxiang
Copy link
Contributor

zzxiang commented Sep 25, 2021

お世話になっております!

WaveNetの損失関数に関して、一つ間違っているかと思うところがあって、ご確認いただきたいです。

7.7節の最後により、WaveNetの損失関数を計算する時に、

自己回帰モデルとしての制約を満たすために、出力を時間方向に一つシフトしていることに注意します。シフトしないまま損失を計算すると、WaveNetは時刻 t までの入力を元に時刻 t の音声サンプルを予測するという、本来の目的に沿わない動作をしてしまいます。時刻 t までの入力を元に、時刻 t + 1 の音声サンプルを予測することが、学習の目的であることに注意します。この問題は、WaveNetのみならず、teacher forcingを利用するその他の自己回帰モデルにも共通するため、実装の際に十分に注意する必要があります。

該当のソースコードcode 7.16とcode 7.17は下記です。

"# 自己回帰性を保つため、出力を時間方向に1つシフトする\n",
"nll = nn.NLLLoss()(log_prob[:, :, 1:], x[:, :-1])"

"ce_loss = nn.CrossEntropyLoss()(x_hat[:, :, 1:], x[:, :-1])\n",

しかし、第8章のcode 8.11に、出力のシフト方向は正反対です。

" # 負の対数尤度の計算\n",
" loss = nn.CrossEntropyLoss()(x_hat[:, :, :-1], x[:, 1:]).mean()\n",

レシピソースコードのシフト方向も同じ正反対です。

# 損失 (負の対数尤度) の計算
# y_hat: (B, C, T)
# x: (B, T)
loss = criterion(y_hat[:, :, :-1], x[:, 1:]).mean()

もしかして片方が間違っているかと思っています。

自分の認識として、teacher forcingの場合、x_hat[:, :, t]は因果的な畳み込みで、x[:, t]までの音声サンプルから予測され、x[:, t + 1]と比較するのが正しいかと思います。なので、第8章以降の方(x_hat[:, :, :-1], x[:, 1:])が正しいように思います。

しかしcode 8.11を編集して実際に実行してみた結果、x_hat[:, :, :-1], x[:, 1:]の損失値はx_hat[:, :, 1:], x[:, :-1]より大きかったです。

前者の損失値は

5.5439348220825195
5.494748115539551
5.402365684509277
5.309176921844482
5.262940883636475
...

で、後者の損失値は

5.043774604797363
4.923819541931152
4.949016094207764
4.854518413543701
4.862161636352539
...

です。

なので、どちらが正しいかはよくわからなくなります。ご確認いただけないでしょうか?

@zzxiang
Copy link
Contributor Author

zzxiang commented Sep 25, 2021

申し訳ございません!間違えて内容を書く前にissueをサブミットしました。内容を補足します!

@zzxiang
Copy link
Contributor Author

zzxiang commented Sep 25, 2021

内容を補足させていただきました!
お手数をおかけしますが、
ご確認のほど、よろしくお願いいたします!

@r9y9
Copy link
Owner

r9y9 commented Sep 25, 2021

ご指摘ありがとうございます。おっしゃるとおり、code 7.16とcode 7.17 が誤りであり、一方でcode 8.11は正しいです。

自分の認識として、teacher forcingの場合、x_hat[:, :, t]は因果的な畳み込みで、x[:, t]までの音声サンプルから予測され、x[:, t + 1]と比較するのが正しいかと思います。なので、第8章以降の方(x_hat[:, :, :-1], x[:, 1:])が正しいように思います。

そのとおりです。

損失関数について、誤った計算において値が大きくなってしまうのは、計算が誤っているかどうかとは関係ないと思います。気にしなくて問題ありません。

ご指摘いただいた誤りについては、修正し、正誤表を更新します。修正が完了しましたら、このissueで報告します。

r9y9 added a commit that referenced this issue Sep 25, 2021
@r9y9
Copy link
Owner

r9y9 commented Sep 25, 2021

ソースコードを修正し、また正誤表をアップデートしました https://docs.google.com/spreadsheets/d/185pTXTzCI3l4kkJTXVa4fsu6yhAwd8aury2PnLol55Q/edit?usp=sharing

@zzxiang
Copy link
Contributor Author

zzxiang commented Sep 25, 2021

ありがとうございました!

@zzxiang zzxiang closed this as completed Sep 25, 2021
@r9y9 r9y9 added the 誤植 label Sep 26, 2021
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

No branches or pull requests

2 participants