Skip to content

Commit

Permalink
🎨 fix defaults for is_path and force_new_index
Browse files Browse the repository at this point in the history
  • Loading branch information
rchan26 committed Sep 15, 2023
1 parent 850c43d commit b951b02
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 20 deletions.
28 changes: 16 additions & 12 deletions slack_bot/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,8 @@ async def main():
"Whether or not the model_name passed is a path to the model "
"(ignored if not using llama-index-llama-cpp)"
),
action="store_true",
action=argparse.BooleanOptionalAction,
default=None,
)
parser.add_argument(
"--max-input-size",
Expand Down Expand Up @@ -95,7 +96,8 @@ async def main():
"--force-new-index",
"-f",
help="Recreate the index vector store or not",
action="store_true",
action=argparse.BooleanOptionalAction,
default=None,
)
parser.add_argument(
"--data-dir",
Expand All @@ -116,7 +118,7 @@ async def main():
"files in the data directory, 'handbook' will "
"only use 'handbook.csv' file."
),
default=os.environ.get("LLAMA_WHICH_INDEX") or "all_data",
default=os.environ.get("LLAMA_INDEX_WHICH_INDEX") or "all_data",
choices=["all_data", "public", "handbook"],
)

Expand All @@ -137,17 +139,17 @@ async def main():
os.environ.get("LLAMA_INDEX_FORCE_NEW_INDEX").lower() == "true"
)
# if force_new_index is provided via command line, override env var
if args.force_new_index:
force_new_index = True
if args.force_new_index is not None:
force_new_index = args.force_new_index

# Set is_path bool (by default, False)
is_path = False
# try to obtain is_path from env var
if os.environ.get("LLAMA_INDEX_PATH_BOOL"):
is_path = os.environ.get("LLAMA_INDEX_PATH_BOOL").lower() == "true"
if os.environ.get("LLAMA_INDEX_IS_PATH"):
is_path = os.environ.get("LLAMA_INDEX_IS_PATH").lower() == "true"
# if is_path bool is provided via command line, override env var
if args.is_path:
is_path = True
if args.is_path is not None:
is_path = args.is_path

# Initialise a new Slack bot with the requested model
try:
Expand All @@ -160,10 +162,11 @@ async def main():
logging.info(f"Initialising bot with model: {args.model}")

# Set up any model args that are required
if model == "llama-index-llama-cpp":
if args.model == "llama-index-llama-cpp":
# try to obtain model name from env var
# if model name is provided via command line, override env var
model_name = args.model_name or os.environ.get("LLAMA_INDEX_MODEL_NAME")

# if no model name is provided by command line or env var,
# default to DEFAULT_LLAMA_CPP_GGUF_MODEL
if model_name is None:
Expand All @@ -175,17 +178,18 @@ async def main():
"n_gpu_layers": args.n_gpu_layers,
"max_input_size": args.max_input_size,
}
elif model == "llama-index-hf":
elif args.model == "llama-index-hf":
# try to obtain model name from env var
# if model name is provided via command line, override env var
model_name = args.model_name or os.environ.get("LLAMA_INDEX_MODEL_NAME")

# if no model name is provided by command line or env var,
# default to DEFAULT_HF_MODEL
if model_name is None:
model_name = DEFAULT_HF_MODEL

model_args = {
"model_name": args.model_name,
"model_name": model_name,
"device": args.device,
"max_input_size": args.max_input_size,
}
Expand Down
21 changes: 13 additions & 8 deletions slack_bot/slack_bot/models/llama_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def __init__(
The type of engine to use when interacting with the data, options of "chat" or "query".
Default is "chat".
k : int, optional
`similarity_top_k` to use in query engine, by default 3
        `similarity_top_k` to use in chat or query engine, by default 3
chunk_overlap_ratio : float, optional
Chunk overlap as a ratio of chunk size, by default 0.1
force_new_index : bool, optional
Expand All @@ -79,6 +79,14 @@ def __init__(
"""
super().__init__(emoji="llama")
logging.info("Setting up Huggingface backend.")
if mode == "chat":
logging.info("Setting up chat engine.")
elif mode == "query":
logging.info("Setting up query engine.")
else:
logging.error("Mode must either be 'query' or 'chat'.")
sys.exit(1)

self.max_input_size = max_input_size
self.model_name = model_name
self.num_output = num_output
Expand Down Expand Up @@ -138,17 +146,14 @@ def __init__(
storage_context=storage_context, service_context=service_context
)

if self.mode == "query":
self.query_engine = self.index.as_query_engine(similarity_top_k=k)
logging.info("Done setting up Huggingface backend for query engine.")
elif self.mode == "chat":
if self.mode == "chat":
self.chat_engine = self.index.as_chat_engine(
chat_mode="context", similarity_top_k=k
)
logging.info("Done setting up Huggingface backend for chat engine.")
else:
logging.error("Mode must either be 'query' or 'chat'.")
sys.exit(1)
elif self.mode == "query":
self.query_engine = self.index.as_query_engine(similarity_top_k=k)
logging.info("Done setting up Huggingface backend for query engine.")

self.error_response_template = (
"Oh no! When I tried to get a response to your prompt, "
Expand Down

0 comments on commit b951b02

Please sign in to comment.