Skip to content

Commit

Permalink
Updated api_keys to be a dict
Browse files Browse the repository at this point in the history
  • Loading branch information
Yiannis128 committed Feb 4, 2025
1 parent 0a5cdaf commit 74d7be4
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 32 deletions.
28 changes: 11 additions & 17 deletions esbmc_ai/ai_models.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Author: Yiannis Charalambous

from abc import abstractmethod
from typing import Any, Iterable, Optional, Union
from typing import Any, Iterable, Union
from enum import Enum
from langchain_core.language_models import BaseChatModel
from pydantic.types import SecretStr
Expand All @@ -17,9 +17,6 @@
)


from esbmc_ai.api_key_collection import APIKeyCollection


class AIModel(object):
"""This base class represents an abstract AI model."""

Expand All @@ -37,7 +34,7 @@ def __init__(
@abstractmethod
def create_llm(
self,
api_keys: APIKeyCollection,
api_keys: dict[str, str],
temperature: float = 1.0,
requests_max_tries: int = 5,
requests_timeout: float = 60,
Expand Down Expand Up @@ -143,15 +140,15 @@ class AIModelOpenAI(AIModel):
@override
def create_llm(
self,
api_keys: APIKeyCollection,
api_keys: dict[str, str],
temperature: float = 1.0,
requests_max_tries: int = 5,
requests_timeout: float = 60,
) -> BaseChatModel:
assert api_keys.openai, "No OpenAI api key has been specified..."
assert "openai" in api_keys, "No OpenAI api key has been specified..."
return ChatOpenAI(
model=self.name,
api_key=SecretStr(api_keys.openai),
api_key=SecretStr(api_keys["openai"]),
max_tokens=None,
temperature=temperature,
max_retries=requests_max_tries,
Expand Down Expand Up @@ -199,7 +196,7 @@ def __init__(self, name: str, tokens: int, url: str) -> None:
@override
def create_llm(
self,
api_keys: APIKeyCollection,
api_keys: dict[str, str],
temperature: float = 1,
requests_max_tries: int = 5,
requests_timeout: float = 60,
Expand All @@ -222,7 +219,6 @@ class _AIModels(Enum):
defined because they are fetched from the API."""

# FALCON_7B = OllamaAIModel(...)
pass


_custom_ai_models: list[AIModel] = []
Expand Down Expand Up @@ -254,25 +250,23 @@ def add_custom_ai_model(ai_model: AIModel) -> None:
_custom_ai_models.append(ai_model)


def download_openai_model_names(api_keys: dict[str, str]) -> list[str]:
    """Fetch the list of available OpenAI model names from the API service.

    Args:
        api_keys: Mapping of provider name to API key. Must contain an
            ``"openai"`` entry.

    Returns:
        The model IDs reported by the OpenAI API, or an empty list when the
        ``openai`` package is not installed.
    """
    assert "openai" in api_keys, "No OpenAI api key has been specified..."

    try:
        # Import inside the try block: the original imported before the
        # try, which made the ``except ImportError`` fallback unreachable.
        # A missing ``openai`` package now degrades to an empty list as
        # the handler clearly intended.
        from openai import Client

        return [
            str(model.id)
            for model in Client(api_key=api_keys["openai"]).models.list().data
        ]
    except ImportError:
        return []


def is_valid_ai_model(
ai_model: Union[str, AIModel], api_keys: Optional[APIKeyCollection] = None
) -> bool:
def is_valid_ai_model(ai_model: Union[str, AIModel]) -> bool:
"""Accepts both the AIModel object and the name as parameter. It checks the
openai servers to see if a model is defined on their servers, if not, then
it checks the internally defined AI models list."""
Expand Down Expand Up @@ -301,4 +295,4 @@ def get_ai_model_by_name(name: str) -> AIModel:
if name == custom_ai.name:
return custom_ai

raise Exception(f'The AI "{name}" was not found...')
raise ValueError(f'The AI "{name}" was not found...')
38 changes: 23 additions & 15 deletions esbmc_ai/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@
OllamaAIModel,
download_openai_model_names,
)
from .api_key_collection import APIKeyCollection


@dataclass
Expand Down Expand Up @@ -81,6 +80,14 @@ def __new__(cls):
cls.instance = super(Config, cls).__new__(cls)
return cls.instance

def __init__(self) -> None:
    """Declare the config singleton's attributes and their defaults.

    NOTE(review): ``__new__`` returns the shared instance, but Python still
    calls ``__init__`` on every ``Config()`` construction, so these default
    assignments re-apply each time (e.g. ``api_keys`` resets to ``{}``) —
    confirm this is intended.
    """
    super().__init__()
    # Parsed command-line arguments; annotation only — assigned in init().
    self._args: argparse.Namespace
    # Provider name (e.g. "openai") -> API key; filled from environment
    # variables during _load_envs().
    self.api_keys: dict[str, str] = {}
    # Flag controlling raw-conversation output; defaults to off.
    self.raw_conversation: bool = False
    # Annotation only, no default — presumably assigned during init();
    # verify against the loading code.
    self.generate_patches: bool
    # Optional output directory; None until configured.
    self.output_dir: Optional[Path] = None

# Define some shortcuts for the values here (instead of having to use get_value)

def get_ai_model(self) -> AIModel:
Expand All @@ -92,7 +99,7 @@ def get_llm_requests_max_tries(self) -> int:
return self.get_value("llm_requests.max_tries")

def get_llm_requests_timeout(self) -> float:
    """Maximum number of seconds to wait on a single LLM request."""
    timeout: float = self.get_value("llm_requests.timeout")
    return timeout

def get_user_chat_initial(self) -> BaseMessage:
Expand Down Expand Up @@ -120,11 +127,7 @@ def init(self, args: Any) -> None:
"""Will load the config from the args, the env file and then from config file.
Call once to initialize."""

self._args: argparse.Namespace = args
self.api_keys: APIKeyCollection
self.raw_conversation: bool = False
self.generate_patches: bool
self.output_dir: Optional[Path] = None
self._args = args

self._load_envs()

Expand All @@ -141,7 +144,7 @@ def init(self, args: Any) -> None:
# Default is to refresh once a day
default_value=self._load_openai_model_names(86400),
validate=lambda v: isinstance(v, int),
on_load=lambda v: self._load_openai_model_names(v),
on_load=self._load_openai_model_names,
error_message="Invalid value, needs to be an int in seconds",
),
# This needs to be processed after ai_custom
Expand All @@ -151,7 +154,7 @@ def init(self, args: Any) -> None:
# Api keys are loaded from system env so they are already
# available
validate=lambda v: isinstance(v, str) and is_valid_ai_model(v),
on_load=lambda v: get_ai_model_by_name(v),
on_load=get_ai_model_by_name,
),
ConfigField(
name="temp_auto_clean",
Expand Down Expand Up @@ -181,6 +184,12 @@ def init(self, args: Any) -> None:
validate=lambda v: isinstance(v, str) and v in ["full", "single"],
error_message="source_code_format can only be 'full' or 'single'",
),
# API Keys is a pseudo-entry, the value is fetched from the class
# itself rather config.
ConfigField(
name="api_keys",
default_value=self.api_keys,
),
ConfigField(
name="solution.filenames",
default_value=[],
Expand Down Expand Up @@ -279,7 +288,8 @@ def init(self, args: Any) -> None:
name="fix_code.message_history",
default_value="normal",
validate=lambda v: v in ["normal", "latest_only", "reverse"],
error_message='fix_code.message_history can only be "normal", "latest_only", "reverse"',
error_message='fix_code.message_history can only be "normal", '
+ '"latest_only", "reverse"',
),
ConfigField(
name="prompt_templates.user_chat.initial",
Expand Down Expand Up @@ -385,9 +395,7 @@ def get_env_vars() -> None:
print(f"Error: No ${key} in environment.")
sys.exit(1)

self.api_keys = APIKeyCollection(
openai=str(os.getenv("OPENAI_API_KEY")),
)
self.api_keys["openai"] = str(os.getenv("OPENAI_API_KEY"))

self.cfg_path: Path = Path(
os.path.expanduser(os.path.expandvars(str(os.getenv(config_env_name))))
Expand All @@ -400,7 +408,7 @@ def _load_args(self) -> None:

# AI Model -m
if args.ai_model != "":
if is_valid_ai_model(args.ai_model, self.api_keys):
if is_valid_ai_model(args.ai_model):
ai_model = get_ai_model_by_name(args.ai_model)
self.set_value("ai_model", ai_model)
else:
Expand Down Expand Up @@ -527,7 +535,7 @@ def write_cache(cache: Path) -> list[str]:
return models_list

duration = timedelta(seconds=refresh_duration_seconds)
if self.api_keys and self.api_keys.openai:
if "openai" in self.api_keys:
print("Loading OpenAI models list")
models_list: list[str] = []
cache: Path = Path(user_cache_dir("esbmc-ai", "Yiannis Charalambous"))
Expand Down

0 comments on commit 74d7be4

Please sign in to comment.