This repository has been archived by the owner on Sep 12, 2024. It is now read-only.

Commit

fix AutoQueryEngine bug causing qa_prompt_template not to be used when it is given (#177)
SeeknnDestroy authored Dec 8, 2023
1 parent 5be3831 commit 1f57b08
Showing 3 changed files with 17 additions and 9 deletions.
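
For context, the user-facing symptom behind #177 is that a QA prompt handed to AutoQueryEngine was silently dropped before answer synthesis. Below is a minimal sketch of the call this commit makes behave as expected; the top-level AutoQueryEngine import and the documents keyword are assumed from the 0.1.x README rather than taken from this diff:

from llama_index import Document
from autollm import AutoQueryEngine  # assumed top-level export

# A llama_index-style QA template; {context_str} and {query_str} are the
# placeholders expected by a text_qa_template.
qa_prompt = (
    "Context information is below.\n"
    "---------------------\n"
    "{context_str}\n"
    "---------------------\n"
    "Answer the query using only the context above.\n"
    "Query: {query_str}\n"
    "Answer: ")

query_engine = AutoQueryEngine.from_defaults(
    documents=[Document(text="autollm wraps llama_index query engines.")],  # 'documents' kwarg assumed
    query_wrapper_prompt=qa_prompt)  # before this commit, the string never reached the response synthesizer

print(query_engine.query("What does autollm wrap?"))
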
2 changes: 1 addition & 1 deletion autollm/__init__.py
@@ -4,7 +4,7 @@
 and vector databases, along with various utility functions.
 """
 
-__version__ = '0.1.3'
+__version__ = '0.1.4'
 __author__ = 'safevideo'
 __license__ = 'AGPL-3.0'
 
19 changes: 12 additions & 7 deletions autollm/auto/query_engine.py
@@ -3,7 +3,7 @@
 from llama_index import Document, ServiceContext, VectorStoreIndex
 from llama_index.embeddings.utils import EmbedType
 from llama_index.indices.query.base import BaseQueryEngine
-from llama_index.prompts.base import PromptTemplate
+from llama_index.prompts.base import BasePromptTemplate, PromptTemplate
 from llama_index.prompts.prompt_type import PromptType
 from llama_index.response_synthesizers import get_response_synthesizer
 from llama_index.schema import BaseNode
@@ -24,11 +24,11 @@ def create_query_engine(
         llm_api_base: Optional[str] = None,
         # service_context_params
         system_prompt: str = None,
-        query_wrapper_prompt: str = None,
+        query_wrapper_prompt: Union[str, BasePromptTemplate] = None,
         enable_cost_calculator: bool = True,
         embed_model: Union[str, EmbedType] = "default",  # ["default", "local"]
         chunk_size: Optional[int] = 512,
-        chunk_overlap: Optional[int] = 200,
+        chunk_overlap: Optional[int] = 100,
         context_window: Optional[int] = None,
         enable_title_extractor: bool = False,
         enable_summary_extractor: bool = False,
@@ -61,7 +61,7 @@ def create_query_engine(
         llm_temperature (float): The temperature to use for the LLM.
         llm_api_base (str): The API base to use for the LLM.
         system_prompt (str): The system prompt to use for the query engine.
-        query_wrapper_prompt (str): The query wrapper prompt to use for the query engine.
+        query_wrapper_prompt (Union[str, BasePromptTemplate]): The query wrapper prompt to use for the query engine.
         enable_cost_calculator (bool): Flag to enable cost calculator logging.
         embed_model (Union[str, EmbedType]): The embedding model to use for generating embeddings. "default" for OpenAI,
             "local" for HuggingFace or use full identifier (e.g., local:intfloat/multilingual-e5-large)
@@ -133,10 +133,15 @@ def create_query_engine(
         refine_prompt_template = PromptTemplate(refine_prompt, prompt_type=PromptType.REFINE)
     else:
         refine_prompt_template = None
+
+    # Convert query_wrapper_prompt to PromptTemplate if it is a string
+    if isinstance(query_wrapper_prompt, str):
+        query_wrapper_prompt = PromptTemplate(template=query_wrapper_prompt)
     response_synthesizer = get_response_synthesizer(
         service_context=service_context,
-        response_mode=response_mode,
+        text_qa_template=query_wrapper_prompt,
         refine_template=refine_prompt_template,
+        response_mode=response_mode,
         structured_answer_filtering=structured_answer_filtering)
 
     return vector_store_index.as_query_engine(
@@ -213,7 +218,7 @@ def from_defaults(
             llm_temperature: float = 0.1,
             # service_context_params
             system_prompt: str = None,
-            query_wrapper_prompt: str = None,
+            query_wrapper_prompt: Union[str, BasePromptTemplate] = None,
             enable_cost_calculator: bool = True,
             embed_model: Union[str, EmbedType] = "default",  # ["default", "local"]
             chunk_size: Optional[int] = 512,
@@ -246,7 +251,7 @@ def from_defaults(
             llm_temperature (float): The temperature to use for the LLM.
             llm_api_base (str): The API base to use for the LLM.
             system_prompt (str): The system prompt to use for the query engine.
-            query_wrapper_prompt (str): The query wrapper prompt to use for the query engine.
+            query_wrapper_prompt (Union[str, BasePromptTemplate]): The query wrapper prompt to use for the query engine.
             enable_cost_calculator (bool): Flag to enable cost calculator logging.
             embed_model (Union[str, EmbedType]): The embedding model to use for generating embeddings. "default" for OpenAI,
                 "local" for HuggingFace or use full identifier (e.g., local:intfloat/multilingual-e5-large)
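
The functional core of the change sits in the @@ -133 hunk above: a query_wrapper_prompt given as a plain string is wrapped in a PromptTemplate and forwarded to get_response_synthesizer as text_qa_template, so it actually participates in answer synthesis. A small standalone sketch of that normalization step, using only imports already present in this file (the helper name and the example prompt text are illustrative, not part of the library):

from typing import Optional, Union

from llama_index.prompts.base import BasePromptTemplate, PromptTemplate


def as_text_qa_template(
        query_wrapper_prompt: Optional[Union[str, BasePromptTemplate]]) -> Optional[BasePromptTemplate]:
    """Return a prompt object usable as text_qa_template, wrapping plain strings."""
    if isinstance(query_wrapper_prompt, str):
        return PromptTemplate(template=query_wrapper_prompt)
    return query_wrapper_prompt  # already a BasePromptTemplate, or None


# Mirrors the new keyword wiring in create_query_engine:
#     response_synthesizer = get_response_synthesizer(
#         service_context=service_context,
#         text_qa_template=as_text_qa_template(query_wrapper_prompt),
#         refine_template=refine_prompt_template,
#         response_mode=response_mode,
#         structured_answer_filtering=structured_answer_filtering)
print(as_text_qa_template("Q: {query_str}\nContext: {context_str}\nA:"))
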
5 changes: 4 additions & 1 deletion autollm/auto/service_context.py
@@ -65,11 +65,14 @@ def from_defaults(
         """
         if not system_prompt and not query_wrapper_prompt:
             system_prompt, query_wrapper_prompt = set_default_prompt_template()
-        # Convert system_prompt to ChatPromptTemplate if it is a string
+        # Convert query_wrapper_prompt to PromptTemplate if it is a string
+        if isinstance(query_wrapper_prompt, str):
+            query_wrapper_prompt = PromptTemplate(template=query_wrapper_prompt)
+
         callback_manager: CallbackManager = kwargs.get('callback_manager', CallbackManager())
         kwargs.pop(
             'callback_manager', None)  # Make sure callback_manager is not passed to ServiceContext twice
 
         if enable_cost_calculator:
             llm_model_name = llm.metadata.model_name if not "default" else "gpt-3.5-turbo"
             callback_manager.add_handler(CostCalculatingHandler(model_name=llm_model_name, verbose=True))
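
The unchanged context lines in this hunk also show the pattern used to keep callback_manager from being passed to ServiceContext twice: read it out of **kwargs, then pop it before forwarding the rest. A library-free sketch of that get-then-pop idiom (function and argument names here are illustrative only):

def consume_callback_manager(**kwargs):
    # Claim the keyword for local use (with a fallback default)...
    callback_manager = kwargs.get('callback_manager', 'default-callback-manager')
    # ...then drop it so **kwargs can be forwarded without a duplicate keyword argument error.
    kwargs.pop('callback_manager', None)
    return callback_manager, kwargs


manager, remaining_kwargs = consume_callback_manager(callback_manager='custom', chunk_size=512)
print(manager, remaining_kwargs)  # custom {'chunk_size': 512}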
