From c8037cd048d222b9728e45925431031521a8350c Mon Sep 17 00:00:00 2001
From: vinhtran
Date: Thu, 23 Feb 2023 10:16:57 +0700
Subject: [PATCH] fix len greedy decode

---
 AnnotatedTransformer.ipynb   | 6 ++++--
 the_annotated_transformer.py | 6 ++++--
 2 files changed, 8 insertions(+), 4 deletions(-)

diff --git a/AnnotatedTransformer.ipynb b/AnnotatedTransformer.ipynb
index 0f7da7d..f517d14 100644
--- a/AnnotatedTransformer.ipynb
+++ b/AnnotatedTransformer.ipynb
@@ -2497,7 +2497,7 @@
     "def greedy_decode(model, src, src_mask, max_len, start_symbol):\n",
     "    memory = model.encode(src, src_mask)\n",
     "    ys = torch.zeros(1, 1).fill_(start_symbol).type_as(src.data)\n",
-    "    for i in range(max_len - 1):\n",
+    "    for i in range(max_len):\n",
     "        out = model.decode(\n",
     "            memory, src_mask, ys, subsequent_mask(ys.size(1)).type_as(src.data)\n",
     "        )\n",
@@ -2507,7 +2507,9 @@
     "        ys = torch.cat(\n",
     "            [ys, torch.zeros(1, 1).type_as(src.data).fill_(next_word)], dim=1\n",
     "        )\n",
-    "    return ys"
+    "    \n",
+    "    # Return the target sequence without the start symbol\n",
+    "    return ys[:, 1:]"
    ]
   },
   {
diff --git a/the_annotated_transformer.py b/the_annotated_transformer.py
index 4aa1d46..f31cd3f 100644
--- a/the_annotated_transformer.py
+++ b/the_annotated_transformer.py
@@ -1313,7 +1313,7 @@ def __call__(self, x, y, norm):
 def greedy_decode(model, src, src_mask, max_len, start_symbol):
     memory = model.encode(src, src_mask)
     ys = torch.zeros(1, 1).fill_(start_symbol).type_as(src.data)
-    for i in range(max_len - 1):
+    for i in range(max_len):
         out = model.decode(
             memory, src_mask, ys, subsequent_mask(ys.size(1)).type_as(src.data)
         )
@@ -1323,7 +1323,9 @@ def greedy_decode(model, src, src_mask, max_len, start_symbol):
         ys = torch.cat(
             [ys, torch.zeros(1, 1).type_as(src.data).fill_(next_word)], dim=1
         )
-    return ys
+
+    # Return the target sequence without the start symbol
+    return ys[:, 1:]


 # %% id="qgIZ2yEtdYwe" tags=[]
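
Note (not part of the patch): a minimal sketch of the behavior this change gives greedy_decode, assuming make_model and the patched greedy_decode can be imported from the_annotated_transformer and that the module's heavier examples stay guarded on import. Before the change, the loop ran max_len - 1 times and the start symbol was kept, so the result held only max_len - 1 generated tokens; after it, the loop runs max_len times and the start symbol is stripped, so exactly max_len generated tokens are returned.

    import torch
    from the_annotated_transformer import make_model, greedy_decode

    # Small untrained model over a toy vocabulary of 11 symbols; the decoded
    # tokens are arbitrary here -- only the output shape matters for this check.
    model = make_model(11, 11, N=2)
    model.eval()

    src = torch.LongTensor([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]])
    src_mask = torch.ones(1, 1, src.size(1))

    ys = greedy_decode(model, src, src_mask, max_len=10, start_symbol=0)
    print(ys.shape)  # torch.Size([1, 10]): max_len tokens, start symbol excluded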