refactor: remove nested if
umbertogriffo committed Jun 28, 2024
1 parent aed9ae2 commit ab9a9df
Showing 1 changed file with 27 additions and 24 deletions.
51 changes: 27 additions & 24 deletions chatbot/bot/conversation/ctx_strategy.py
@@ -84,31 +84,34 @@ def generate_response(
         """
         cur_response = None
         fmt_prompts = []
 
+        if not retrieved_contents:
+            qa_prompt = self.llm.generate_qa_prompt(question=question)
+            logger.info("--- Generating a single response ... ---")
+            response = self.llm.start_answer_iterator_streamer(qa_prompt, max_new_tokens=max_new_tokens)
+            return response, qa_prompt
+
         num_of_contents = len(retrieved_contents)
-        if num_of_contents > 0:
-            for idx, node in enumerate(retrieved_contents, start=1):
-                logger.info(f"--- Generating an answer for the chunk {idx} ... ---")
-                context = node.page_content
-                logger.debug(f"--- Context: '{context}' ... ---")
-                if idx == 0:
-                    fmt_prompt = self.llm.generate_ctx_prompt(question=question, context=context)
-                else:
-                    fmt_prompt = self.llm.generate_refined_ctx_prompt(
-                        context=context,
-                        question=question,
-                        existing_answer=str(cur_response),
-                    )
-
-                if idx == num_of_contents:
-                    cur_response = self.llm.start_answer_iterator_streamer(fmt_prompt, max_new_tokens=max_new_tokens)
-
-                else:
-                    cur_response = self.llm.generate_answer(fmt_prompt, max_new_tokens=max_new_tokens)
-                logger.debug(f"--- Current response: '{cur_response}' ... ---")
-                fmt_prompts.append(fmt_prompt)
-        else:
-            fmt_prompt = self.llm.generate_qa_prompt(question=question)
-            cur_response = self.llm.start_answer_iterator_streamer(fmt_prompt, max_new_tokens=max_new_tokens)
+
+        for idx, node in enumerate(retrieved_contents, start=1):
+            logger.info(f"--- Generating an answer for the chunk {idx} ... ---")
+            context = node.page_content
+            logger.debug(f"--- Context: '{context}' ... ---")
+            if idx == 0:
+                fmt_prompt = self.llm.generate_ctx_prompt(question=question, context=context)
+            else:
+                fmt_prompt = self.llm.generate_refined_ctx_prompt(
+                    context=context,
+                    question=question,
+                    existing_answer=str(cur_response),
+                )
+
+            if idx == num_of_contents:
+                cur_response = self.llm.start_answer_iterator_streamer(fmt_prompt, max_new_tokens=max_new_tokens)
+
+            else:
+                cur_response = self.llm.generate_answer(fmt_prompt, max_new_tokens=max_new_tokens)
+            logger.debug(f"--- Current response: '{cur_response}' ... ---")
+            fmt_prompts.append(fmt_prompt)
 
         return cur_response, fmt_prompts
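
For readers skimming the diff, the change is the classic guard-clause refactor: the empty-retrieval case is handled up front with an early return, so the main loop no longer has to live inside a nested `if num_of_contents > 0: ... else: ...`. Below is a minimal, runnable sketch of that shape under stated assumptions: FakeLlm and its echo-style methods are invented stand-ins, not the repository's actual LLM client, and the `idx == 1` check is the sketch's own way of marking the first chunk rather than a claim about the real file's logic:

# A minimal, runnable sketch of the guard-clause shape adopted in this commit.
# FakeLlm and its methods are invented stand-ins for illustration only; they
# are not the repository's actual LLM client.

class FakeLlm:
    """Stand-in that echoes prompts so the control flow can be exercised."""

    def generate_qa_prompt(self, question: str) -> str:
        return f"Answer the question: {question}"

    def generate_ctx_prompt(self, question: str, context: str) -> str:
        return f"Context: {context}\nQuestion: {question}"

    def generate_refined_ctx_prompt(self, context: str, question: str, existing_answer: str) -> str:
        return f"Refine '{existing_answer}' with extra context: {context}\nQuestion: {question}"

    def generate_answer(self, prompt: str, max_new_tokens: int = 256) -> str:
        return f"[answer to: {prompt.splitlines()[0]}]"

    # In this sketch, streaming is treated like a plain answer call.
    start_answer_iterator_streamer = generate_answer


def generate_response(llm: FakeLlm, question: str, retrieved_contents: list[str], max_new_tokens: int = 256):
    cur_response = None
    fmt_prompts = []

    # Guard clause: no retrieved chunks -> answer directly and return early,
    # so nothing below has to be wrapped in an if/else.
    if not retrieved_contents:
        qa_prompt = llm.generate_qa_prompt(question=question)
        return llm.start_answer_iterator_streamer(qa_prompt, max_new_tokens=max_new_tokens), qa_prompt

    # Main path now sits at a single indentation level, as in the diff above.
    num_of_contents = len(retrieved_contents)
    for idx, context in enumerate(retrieved_contents, start=1):
        if idx == 1:  # sketch's choice: fresh context prompt for the first chunk
            fmt_prompt = llm.generate_ctx_prompt(question=question, context=context)
        else:  # later chunks refine the answer built so far
            fmt_prompt = llm.generate_refined_ctx_prompt(
                context=context, question=question, existing_answer=str(cur_response)
            )
        # Only the last chunk's answer goes through the streaming call.
        if idx == num_of_contents:
            cur_response = llm.start_answer_iterator_streamer(fmt_prompt, max_new_tokens=max_new_tokens)
        else:
            cur_response = llm.generate_answer(fmt_prompt, max_new_tokens=max_new_tokens)
        fmt_prompts.append(fmt_prompt)

    return cur_response, fmt_prompts


if __name__ == "__main__":
    answer, prompts = generate_response(FakeLlm(), "What does a guard clause buy us?", ["chunk one", "chunk two"])
    print(answer)

The benefit is purely structural: each special case gets its own early exit, and the happy path reads top to bottom at a single indentation level.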