From b951b0247f09aa019da7ddbeb6480e8ec086dca6 Mon Sep 17 00:00:00 2001
From: rchan
Date: Fri, 15 Sep 2023 18:55:11 +0100
Subject: [PATCH] =?UTF-8?q?=F0=9F=8E=A8=20fix=20defaults=20for=20is=5Fpath?=
 =?UTF-8?q?=20and=20force=5Fnew=5Findex?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 slack_bot/run.py                          | 28 +++++++++++++----------
 slack_bot/slack_bot/models/llama_index.py | 21 ++++++++++-------
 2 files changed, 29 insertions(+), 20 deletions(-)

diff --git a/slack_bot/run.py b/slack_bot/run.py
index bdb331f7..19d8d623 100755
--- a/slack_bot/run.py
+++ b/slack_bot/run.py
@@ -59,7 +59,8 @@ async def main():
             "Whether or not the model_name passed is a path to the model "
             "(ignored if not using llama-index-llama-cpp)"
         ),
-        action="store_true",
+        action=argparse.BooleanOptionalAction,
+        default=None,
     )
     parser.add_argument(
         "--max-input-size",
@@ -95,7 +96,8 @@
         "--force-new-index",
         "-f",
         help="Recreate the index vector store or not",
-        action="store_true",
+        action=argparse.BooleanOptionalAction,
+        default=None,
     )
     parser.add_argument(
         "--data-dir",
@@ -116,7 +118,7 @@
             "files in the data directory, 'handbook' will "
             "only use 'handbook.csv' file."
         ),
-        default=os.environ.get("LLAMA_WHICH_INDEX") or "all_data",
+        default=os.environ.get("LLAMA_INDEX_WHICH_INDEX") or "all_data",
         choices=["all_data", "public", "handbook"],
     )
 
@@ -137,17 +139,17 @@
             os.environ.get("LLAMA_INDEX_FORCE_NEW_INDEX").lower() == "true"
         )
     # if force_new_index is provided via command line, override env var
-    if args.force_new_index:
-        force_new_index = True
+    if args.force_new_index is not None:
+        force_new_index = args.force_new_index
 
     # Set is_path bool (by default, False)
     is_path = False
     # try to obtain is_path from env var
-    if os.environ.get("LLAMA_INDEX_PATH_BOOL"):
-        is_path = os.environ.get("LLAMA_INDEX_PATH_BOOL").lower() == "true"
+    if os.environ.get("LLAMA_INDEX_IS_PATH"):
+        is_path = os.environ.get("LLAMA_INDEX_IS_PATH").lower() == "true"
     # if is_path bool is provided via command line, override env var
-    if args.is_path:
-        is_path = True
+    if args.is_path is not None:
+        is_path = args.is_path
 
     # Initialise a new Slack bot with the requested model
     try:
@@ -160,10 +162,11 @@
     logging.info(f"Initialising bot with model: {args.model}")
 
     # Set up any model args that are required
-    if model == "llama-index-llama-cpp":
+    if args.model == "llama-index-llama-cpp":
         # try to obtain model name from env var
         # if model name is provided via command line, override env var
         model_name = args.model_name or os.environ.get("LLAMA_INDEX_MODEL_NAME")
+
         # if no model name is provided by command line or env var,
         # default to DEFAULT_LLAMA_CPP_GGUF_MODEL
         if model_name is None:
@@ -175,17 +178,18 @@
             "n_gpu_layers": args.n_gpu_layers,
             "max_input_size": args.max_input_size,
         }
-    elif model == "llama-index-hf":
+    elif args.model == "llama-index-hf":
         # try to obtain model name from env var
         # if model name is provided via command line, override env var
         model_name = args.model_name or os.environ.get("LLAMA_INDEX_MODEL_NAME")
+
         # if no model name is provided by command line or env var,
         # default to DEFAULT_HF_MODEL
         if model_name is None:
             model_name = DEFAULT_HF_MODEL
 
         model_args = {
-            "model_name": args.model_name,
+            "model_name": model_name,
             "device": args.device,
             "max_input_size": args.max_input_size,
         }
diff --git a/slack_bot/slack_bot/models/llama_index.py b/slack_bot/slack_bot/models/llama_index.py
index ac0edd10..83eae921 100644
--- a/slack_bot/slack_bot/models/llama_index.py
+++ b/slack_bot/slack_bot/models/llama_index.py
@@ -68,7 +68,7 @@ def __init__(
             The type of engine to use when interacting with the data, options
             of "chat" or "query". Default is "chat".
         k : int, optional
-            `similarity_top_k` to use in query engine, by default 3
+            `similarity_top_k` to use in chat or query engine, by default 3
         chunk_overlap_ratio : float, optional
             Chunk overlap as a ratio of chunk size, by default 0.1
         force_new_index : bool, optional
@@ -79,6 +79,14 @@
         """
         super().__init__(emoji="llama")
         logging.info("Setting up Huggingface backend.")
+        if mode == "chat":
+            logging.info("Setting up chat engine.")
+        elif mode == "query":
+            logging.info("Setting up query engine.")
+        else:
+            logging.error("Mode must either be 'query' or 'chat'.")
+            sys.exit(1)
+
         self.max_input_size = max_input_size
         self.model_name = model_name
         self.num_output = num_output
@@ -138,17 +146,14 @@
                 storage_context=storage_context, service_context=service_context
             )
 
-        if self.mode == "query":
-            self.query_engine = self.index.as_query_engine(similarity_top_k=k)
-            logging.info("Done setting up Huggingface backend for query engine.")
-        elif self.mode == "chat":
+        if self.mode == "chat":
             self.chat_engine = self.index.as_chat_engine(
                 chat_mode="context", similarity_top_k=k
             )
             logging.info("Done setting up Huggingface backend for chat engine.")
-        else:
-            logging.error("Mode must either be 'query' or 'chat'.")
-            sys.exit(1)
+        elif self.mode == "query":
+            self.query_engine = self.index.as_query_engine(similarity_top_k=k)
+            logging.info("Done setting up Huggingface backend for query engine.")
 
         self.error_response_template = (
             "Oh no! When I tried to get a response to your prompt, "
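
Note: the argparse change is the heart of the run.py hunks. With
action="store_true" the flag defaults to False, so the old code could not
tell "flag omitted" apart from "flag explicitly off", and the environment
variables were silently ignored whenever the flag was absent.
argparse.BooleanOptionalAction (Python 3.9+) with default=None gives each
flag three states. A minimal sketch of the resulting precedence, reusing
the flag and env-var names from the diff:

    import argparse
    import os

    parser = argparse.ArgumentParser()
    # BooleanOptionalAction generates both --force-new-index and
    # --no-force-new-index; default=None marks "flag not given".
    parser.add_argument(
        "--force-new-index",
        "-f",
        action=argparse.BooleanOptionalAction,
        default=None,
    )
    args = parser.parse_args()

    # Start from the env var (falling back to False)...
    force_new_index = (
        os.environ.get("LLAMA_INDEX_FORCE_NEW_INDEX", "").lower() == "true"
    )
    # ...then let an explicit command-line choice win.
    if args.force_new_index is not None:
        force_new_index = args.force_new_index

With this in place:

    python run.py --force-new-index       # force_new_index is True
    python run.py --no-force-new-index    # force_new_index is False
    python run.py                         # falls back to LLAMA_INDEX_FORCE_NEW_INDEX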
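
Note: the llama_index.py hunks are a fail-fast reordering: mode is now
validated at the top of __init__, so an invalid value aborts before the
vector store is loaded or rebuilt rather than after. An equivalent guard
that raises instead of exiting the process (a hypothetical alternative,
not what this patch does) would look like:

    # Hypothetical helper, not part of the patch:
    VALID_MODES = ("chat", "query")

    def check_mode(mode: str) -> None:
        # Reject a bad mode before any expensive index work happens.
        if mode not in VALID_MODES:
            raise ValueError(f"mode must be one of {VALID_MODES}, got {mode!r}")

Raising keeps the error catchable by callers; sys.exit(1) is reasonable for
a class that is only ever constructed from a CLI entry point, which appears
to be the case here.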