diff --git a/llm/predict/predictor.py b/llm/predict/predictor.py index 13eb89d906b6..d6dbbcc2c72c 100644 --- a/llm/predict/predictor.py +++ b/llm/predict/predictor.py @@ -993,7 +993,8 @@ def predict(self, input_texts: list[str], return_tokens=False): output_tensor = paddle.full(shape=[MAX_BSZ + 2, 1], fill_value=2, dtype="int64").cpu() tensor_queue.put(output_tensor) - done_event.wait() + if self.tensor_parallel_rank == 0: + done_event.wait() s_time = time.time() while self.model_inputs["not_need_stop"]: self._infer(self.model_inputs) @@ -1119,7 +1120,8 @@ def predict(self, input_texts: list[str], return_tokens=False): read_res_process.start() output_tensor = paddle.full(shape=[MAX_BSZ + 2, 1], fill_value=2, dtype="int64").cpu() tensor_queue.put(output_tensor) - done_event.wait() + if self.tensor_parallel_rank == 0: + done_event.wait() s_time = time.time() while self.model_inputs["not_need_stop"]: self.predictor.run(list(self.model_inputs.values()))