Skip to content

Commit

Permalink
✨ add model_name passable for azure models
Browse files Browse the repository at this point in the history
  • Loading branch information
rchan26 committed Sep 20, 2023
1 parent 2594d85 commit 77d0857
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 4 deletions.
24 changes: 24 additions & 0 deletions slack_bot/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
"/main/llama-2-13b-chat.Q6_K.gguf"
)
DEFAULT_HF_MODEL = "StabilityAI/stablelm-tuned-alpha-3b"
DEFAULT_LLAMA_INDEX_AZURE_DEPLOYMENT = "reginald-gpt35-turbo"
DEFAULT_CHAT_COMPLETION_AZURE_DEPLOYMENT = "reginald-curie"


async def main():
Expand Down Expand Up @@ -161,6 +163,12 @@ async def main():
# Initialise LLM reponse model
logging.info(f"Initialising bot with model: {args.model}")

logging.info(
f"args.model_name or os.environ.get('LLAMA_INDEX_MODEL_NAME'): {args.model_name or os.environ.get('LLAMA_INDEX_MODEL_NAME')}"
)
logging.info(
f"args.model in ['chat-completion-azure', 'llama-index-gpt-azure']: {args.model in ['chat-completion-azure', 'llama-index-gpt-azure']}"
)
# Set up any model args that are required
if args.model == "llama-index-llama-cpp":
# try to obtain model name from env var
Expand Down Expand Up @@ -193,6 +201,22 @@ async def main():
"device": args.device,
"max_input_size": args.max_input_size,
}
elif args.model in ["chat-completion-azure", "llama-index-gpt-azure"]:
# try to obtain model name from env var
# if model name is provided via command line, override env var
model_name = args.model_name or os.environ.get("LLAMA_INDEX_MODEL_NAME")

# if no model name is provided by command line or env var,
# default to DEFAULT_HF_MODEL
if model_name is None:
if args.model == "chat-completion-azure":
model_name = DEFAULT_CHAT_COMPLETION_AZURE_DEPLOYMENT
elif args.model == "llama-index-gpt-azure":
model_name = DEFAULT_LLAMA_INDEX_AZURE_DEPLOYMENT

model_args = {
"deployment_name": model_name,
}
else:
model_args = {}

Expand Down
8 changes: 6 additions & 2 deletions slack_bot/slack_bot/models/chat_completion.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# Standard library imports
import logging
import os
from typing import Any

Expand All @@ -15,14 +16,17 @@ def __init__(self, *args, **kwargs) -> None:


class ChatCompletionAzure(ChatCompletionBase):
def __init__(self, *args: Any, **kwargs: Any) -> None:
def __init__(
self, deployment_name: str = "reginald-curie", *args: Any, **kwargs: Any
) -> None:
logging.info(f"Setting up AzureOpenAI LLM (model {deployment_name})")
super().__init__(*args, **kwargs)
self.api_base = os.getenv("OPENAI_AZURE_API_BASE")
self.api_key = os.getenv("OPENAI_AZURE_API_KEY")
self.api_type = "azure"
self.api_version = "2023-03-15-preview"
self.best_of = 1
self.engine = "reginald-curie"
self.engine = deployment_name
self.frequency_penalty = 0
self.max_tokens = 100
self.presence_penalty = 0
Expand Down
7 changes: 5 additions & 2 deletions slack_bot/slack_bot/models/llama_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -650,7 +650,9 @@ def _prep_llm(self) -> LLM:


class LlamaIndexGPTAzure(LlamaIndex):
def __init__(self, *args: Any, **kwargs: Any) -> None:
def __init__(
self, deployment_name: str = "reginald-gpt35-turbo", *args: Any, **kwargs: Any
) -> None:
"""
`LlamaIndexGPTAzure` is a subclass of `LlamaIndex` that uses Azure's
instance of OpenAI's LLMs to implement the LLM.
Expand All @@ -669,7 +671,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
raise ValueError("You must set OPENAI_AZURE_API_KEY for Azure OpenAI.")

# deployment name can be found in the Azure AI Studio portal
self.deployment_name = "reginald-gpt35-turbo"
self.deployment_name = deployment_name
self.openai_api_base = os.getenv("OPENAI_AZURE_API_BASE")
self.openai_api_key = os.getenv("OPENAI_AZURE_API_KEY")
self.openai_api_version = "2023-03-15-preview"
Expand All @@ -679,6 +681,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
)

def _prep_llm(self) -> LLM:
logging.info(f"Setting up AzureOpenAI LLM (model {self.deployment_name})")
return AzureOpenAI(
model=self.model_name,
engine=self.deployment_name,
Expand Down

0 comments on commit 77d0857

Please sign in to comment.