diff --git a/dbgpt/rag/retriever/rewrite.py b/dbgpt/rag/retriever/rewrite.py index 223b52b61..f18b6c576 100644 --- a/dbgpt/rag/retriever/rewrite.py +++ b/dbgpt/rag/retriever/rewrite.py @@ -89,14 +89,15 @@ async def rewrite( Returns: queries: List[str] """ + from dbgpt.util.chat_util import run_async_tasks + prompt = self._prompt_template.format( context=context, original_query=origin_query, nums=nums ) messages = [ModelMessage(role=ModelMessageRoleType.SYSTEM, content=prompt)] request = ModelRequest(model=self._model_name, messages=messages) - # tasks = [self._llm_client.generate(request)] - # queries = await run_async_tasks(tasks=tasks, concurrency_limit=1) - queries = await self._llm_client.generate(request) + tasks = [self._llm_client.generate(request)] + queries = await run_async_tasks(tasks=tasks, concurrency_limit=1) queries = [model_out.text for model_out in queries] queries = list( filter(