From 28144996310f5303a09355d55f238786737dc346 Mon Sep 17 00:00:00 2001 From: Pablo Gonzalez Date: Fri, 19 Jul 2024 11:36:14 -0500 Subject: [PATCH] Add eos padding to postprocessing (#1783) --- language/mixtral-8x7b/dataset.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/language/mixtral-8x7b/dataset.py b/language/mixtral-8x7b/dataset.py index d2cafac63..4bc12d84a 100644 --- a/language/mixtral-8x7b/dataset.py +++ b/language/mixtral-8x7b/dataset.py @@ -90,7 +90,14 @@ def postProcess(self, out_tokens, input_seq_lens=None, # Everything is padded to max_len (1024), so prune the input and parse # to numpy output_seq = out_tokens[:, 1024:].cpu().numpy() + aux_seq = [] assert len(query_id_list) == output_seq.shape[0] + for i in range(len(output_seq)): + aux = output_seq[i] + while(len(output_seq[i]) <= 1): + aux = np.append(aux, self.tokenizer.eos_token_id) + aux_seq.append(aux) + output_seq = np.stack(aux_seq) # Save outputs if not os.path.exists("run_outputs"):