This repository has been archived by the owner on Sep 12, 2024. It is now read-only.

Commit

fix bug on service_context and add init to callbacks
SeeknnDestroy committed Oct 20, 2023
1 parent 2758c2e commit b25333e
Showing 2 changed files with 8 additions and 2 deletions.
10 changes: 8 additions & 2 deletions autollm/auto/service_context.py
@@ -1,10 +1,12 @@
 import logging
+from typing import Union

 from llama_index import OpenAIEmbedding, ServiceContext
 from llama_index.callbacks import CallbackManager
 from llama_index.embeddings.base import BaseEmbedding
 from llama_index.llms.utils import LLMType
 from llama_index.prompts import ChatMessage, ChatPromptTemplate, MessageRole
+from llama_index.prompts.base import BasePromptTemplate

 from autollm.callbacks.cost_calculating import CostCalculatingHandler
 from autollm.utils.llm_utils import set_default_prompt_template
@@ -22,7 +24,7 @@ def from_defaults(
         llm: LLMType = "default",
         embed_model: BaseEmbedding = None,
         system_prompt: str = None,
-        query_wrapper_prompt: str = None,
+        query_wrapper_prompt: Union[str, BasePromptTemplate] = None,
         enable_cost_calculator: bool = False,
         **kwargs) -> ServiceContext:
         """
@@ -33,7 +35,7 @@ def from_defaults(
             llm (LLM): The LLM to use for the query engine. Defaults to gpt-3.5-turbo.
             embed_model (BaseEmbedding): The embedding model to use for the query engine.
             system_prompt (str): The system prompt to use for the query engine.
-            query_wrapper_prompt (str): The query wrapper prompt to use for the query engine.
+            query_wrapper_prompt (Union[str, BasePromptTemplate]): The query wrapper prompt to use for the query engine.
             enable_cost_calculator (bool): Flag to enable cost calculator logging.
             *args: Variable length argument list.
             **kwargs: Arbitrary keyword arguments.
@@ -44,13 +46,17 @@ def from_defaults(
         if not system_prompt or not query_wrapper_prompt:
             logger.info('System prompt and query wrapper prompt not provided. Using default prompts.')
             system_prompt, query_wrapper_prompt = set_default_prompt_template()
+        # Convert query_wrapper_prompt to a ChatPromptTemplate if it is a string
         elif isinstance(query_wrapper_prompt, str):
             query_wrapper_prompt = ChatPromptTemplate([
                 ChatMessage(
                     role=MessageRole.USER,
                     content=query_wrapper_prompt,
                 ),
             ])
+        # Use the provided query wrapper prompt as is if it is a BasePromptTemplate
+        elif isinstance(query_wrapper_prompt, BasePromptTemplate):
+            pass
         else:
             raise ValueError(f'Invalid query_wrapper_prompt type: {type(query_wrapper_prompt)}')
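
For context, a minimal runnable sketch of the behavior this hunk adds (the prompt text below is a hypothetical placeholder, not from the commit): a plain string passed as query_wrapper_prompt is wrapped into a ChatPromptTemplate with a single USER message, while a BasePromptTemplate instance is passed through unchanged.

from llama_index.prompts import ChatMessage, ChatPromptTemplate, MessageRole
from llama_index.prompts.base import BasePromptTemplate

# Hypothetical prompt text, for illustration only.
query_wrapper_prompt = "Answer the question using the context: {query_str}"

# Mirrors the new str branch: wrap the raw string in a single USER message.
if isinstance(query_wrapper_prompt, str):
    query_wrapper_prompt = ChatPromptTemplate([
        ChatMessage(role=MessageRole.USER, content=query_wrapper_prompt),
    ])
# Mirrors the new BasePromptTemplate branch: use the template as is.
elif isinstance(query_wrapper_prompt, BasePromptTemplate):
    pass

Wrapping the raw string this way lets downstream code treat both input forms uniformly as a prompt template.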
Empty file added autollm/callbacks/__init__.py
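
The empty __init__.py marks autollm/callbacks as a regular Python package (and helps packaging tools such as setuptools' find_packages pick it up), so the import already used in service_context.py resolves:

# With autollm/callbacks/__init__.py in place, this package import resolves.
from autollm.callbacks.cost_calculating import CostCalculatingHandler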
