diff --git a/language/mixtral-8x7b/dataset.py b/language/mixtral-8x7b/dataset.py index d2cafac63..4bc12d84a 100644 --- a/language/mixtral-8x7b/dataset.py +++ b/language/mixtral-8x7b/dataset.py @@ -90,7 +90,14 @@ def postProcess(self, out_tokens, input_seq_lens=None, # Everything is padded to max_len (1024), so prune the input and parse # to numpy output_seq = out_tokens[:, 1024:].cpu().numpy() + aux_seq = [] assert len(query_id_list) == output_seq.shape[0] + for i in range(len(output_seq)): + aux = output_seq[i] + while(len(output_seq[i]) <= 1): + aux = np.append(aux, self.tokenizer.eos_token_id) + aux_seq.append(aux) + output_seq = np.stack(aux_seq) # Save outputs if not os.path.exists("run_outputs"):