Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

LLM and test refactor #623

Merged
merged 3 commits into from
Jun 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion src/codemodder/codemodder.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,9 @@
from codemodder.codemods.api import BaseCodemod
from codemodder.codemods.semgrep import SemgrepRuleDetector
from codemodder.codetf import CodeTF
from codemodder.context import CodemodExecutionContext, MisconfiguredAIClient
from codemodder.context import CodemodExecutionContext
from codemodder.dependency import Dependency
from codemodder.llm import MisconfiguredAIClient
from codemodder.logging import configure_logger, log_list, log_section, logger
from codemodder.project_analysis.file_parsers.package_store import PackageStore
from codemodder.project_analysis.python_repo_manager import PythonRepoManager
Expand Down
1 change: 1 addition & 0 deletions src/codemodder/codemods/test/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,5 @@
BaseDjangoCodemodTest,
BaseSASTCodemodTest,
BaseSemgrepCodemodTest,
DiffError,
)
30 changes: 24 additions & 6 deletions src/codemodder/codemods/test/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,19 @@
from codemodder.semgrep import run as semgrep_run


class DiffError(Exception):
"""Custom exception to raise when output code != expected output code."""

def __init__(self, expected, actual):
self.expected = expected
self.actual = actual

def __str__(self):
return (
f"\nExpected:\n\n{self.expected}\n does NOT match actual:\n\n{self.actual}"
)


class BaseCodemodTest:
codemod: ClassVar = NotImplemented

Expand Down Expand Up @@ -74,20 +87,25 @@ def run_and_assert(
)

def assert_changes(self, root, file_path, input_code, expected, changes):
assert os.path.relpath(file_path, root) == changes.path
assert all(change.description for change in changes.changes)

expected_diff = create_diff(
dedent(input_code).splitlines(keepends=True),
dedent(expected).splitlines(keepends=True),
)

assert expected_diff == changes.diff
assert os.path.relpath(file_path, root) == changes.path
try:
assert expected_diff == changes.diff
except AssertionError:
raise DiffError(expected_diff, changes.diff)

with open(file_path, "r", encoding="utf-8") as tmp_file:
output_code = tmp_file.read()

assert output_code == dedent(expected)
# All changes must have non-empty descriptions
assert all(change.description for change in changes.changes)
try:
assert output_code == (format_expected := dedent(expected))
except AssertionError:
raise DiffError(format_expected, output_code)

def run_and_assert_filepath(
self,
Expand Down
58 changes: 2 additions & 56 deletions src/codemodder/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

import itertools
import logging
import os
from pathlib import Path
from textwrap import indent
from typing import TYPE_CHECKING, Iterator, List
Expand All @@ -16,33 +15,19 @@
build_failed_dependency_notification,
)
from codemodder.file_context import FileContext
from codemodder.llm import setup_llm_client
from codemodder.logging import log_list, logger
from codemodder.project_analysis.file_parsers.package_store import PackageStore
from codemodder.project_analysis.python_repo_manager import PythonRepoManager
from codemodder.registry import CodemodRegistry
from codemodder.utils.timer import Timer

try:
from openai import AzureOpenAI, OpenAI
except ImportError:
OpenAI = None
AzureOpenAI = None


if TYPE_CHECKING:
from openai import OpenAI

from codemodder.codemods.base_codemod import BaseCodemod


class MisconfiguredAIClient(ValueError):
pass


MODELS = ["gpt-4-turbo-2024-04-09", "gpt-4o-2024-05-13"]
DEFAULT_AZURE_OPENAI_API_VERSION = "2024-02-01"


class CodemodExecutionContext:
_failures_by_codemod: dict[str, list[Path]] = {}
_dependency_update_by_codemod: dict[str, PackageStore | None] = {}
Expand Down Expand Up @@ -87,41 +72,7 @@ def __init__(
self.path_exclude = path_exclude
self.max_workers = max_workers
self.tool_result_files_map = tool_result_files_map or {}
self.llm_client = self._setup_llm_client()

def _setup_llm_client(self) -> OpenAI | None:
if not AzureOpenAI:
logger.info("Azure OpenAI API client not available")
return None

azure_openapi_key = os.getenv("CODEMODDER_AZURE_OPENAI_API_KEY")
azure_openapi_endpoint = os.getenv("CODEMODDER_AZURE_OPENAI_ENDPOINT")
if bool(azure_openapi_key) ^ bool(azure_openapi_endpoint):
raise MisconfiguredAIClient(
"Azure OpenAI API key and endpoint must both be set or unset"
)

if azure_openapi_key and azure_openapi_endpoint:
logger.info("Using Azure OpenAI API client")
return AzureOpenAI(
api_key=azure_openapi_key,
api_version=os.getenv(
"CODEMODDER_AZURE_OPENAI_API_VERSION",
DEFAULT_AZURE_OPENAI_API_VERSION,
),
azure_endpoint=azure_openapi_endpoint,
)

if not OpenAI:
logger.info("OpenAI API client not available")
return None

if not (api_key := os.getenv("CODEMODDER_OPENAI_API_KEY")):
logger.info("OpenAI API key not found")
return None

logger.info("Using OpenAI API client")
return OpenAI(api_key=api_key)
self.llm_client = setup_llm_client()

def add_changesets(self, codemod_name: str, change_sets: List[ChangeSet]):
self._changesets_by_codemod.setdefault(codemod_name, []).extend(change_sets)
Expand Down Expand Up @@ -244,8 +195,3 @@ def log_changes(self, codemod_id: str):
for change in changes:
logger.info(" - %s", change.path)
logger.debug(" diff:\n%s", indent(change.diff, " " * 6))

def __getattribute__(self, attr: str):
if (name := attr.replace("_", "-")) in MODELS:
return os.getenv(f"CODEMODDER_AZURE_OPENAI_{name.upper()}_DEPLOYMENT", name)
return super().__getattribute__(attr)
83 changes: 83 additions & 0 deletions src/codemodder/llm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
import os
from typing import TYPE_CHECKING

try:
from openai import AzureOpenAI, OpenAI
except ImportError:
OpenAI = None
AzureOpenAI = None


if TYPE_CHECKING:
from openai import OpenAI

from codemodder.logging import logger

__all__ = [
"MODELS",
"setup_llm_client",
"MisconfiguredAIClient",
]

models = ["gpt-4-turbo-2024-04-09", "gpt-4o-2024-05-13"]
DEFAULT_AZURE_OPENAI_API_VERSION = "2024-02-01"


class ModelRegistry(dict):
def __init__(self, models):
super().__init__()
self.models = models
for model in models:
attribute_name = model.replace("-", "_")
self[attribute_name] = model

def __getattr__(self, name):
if name in self:
return os.getenv(
f"CODEMODDER_AZURE_OPENAI_{self[name].upper()}_DEPLOYMENT", self[name]
)
raise AttributeError(
f"'{self.__class__.__name__}' object has no attribute '{name}'"
)


MODELS = ModelRegistry(models)


def setup_llm_client() -> OpenAI | None:
if not AzureOpenAI:
logger.info("Azure OpenAI API client not available")
return None

azure_openapi_key = os.getenv("CODEMODDER_AZURE_OPENAI_API_KEY")
azure_openapi_endpoint = os.getenv("CODEMODDER_AZURE_OPENAI_ENDPOINT")
if bool(azure_openapi_key) ^ bool(azure_openapi_endpoint):
raise MisconfiguredAIClient(
"Azure OpenAI API key and endpoint must both be set or unset"
)

if azure_openapi_key and azure_openapi_endpoint:
logger.info("Using Azure OpenAI API client")
return AzureOpenAI(
api_key=azure_openapi_key,
api_version=os.getenv(
"CODEMODDER_AZURE_OPENAI_API_VERSION",
DEFAULT_AZURE_OPENAI_API_VERSION,
),
azure_endpoint=azure_openapi_endpoint,
)

if not OpenAI:
logger.info("OpenAI API client not available")
return None

if not (api_key := os.getenv("CODEMODDER_OPENAI_API_KEY")):
logger.info("OpenAI API key not found")
return None

logger.info("Using OpenAI API client")
return OpenAI(api_key=api_key)


class MisconfiguredAIClient(ValueError):
pass
35 changes: 1 addition & 34 deletions tests/test_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,9 @@
import pytest
from openai import AzureOpenAI, OpenAI

from codemodder.context import DEFAULT_AZURE_OPENAI_API_VERSION
from codemodder.context import CodemodExecutionContext as Context
from codemodder.context import MisconfiguredAIClient
from codemodder.dependency import Security
from codemodder.llm import DEFAULT_AZURE_OPENAI_API_VERSION, MisconfiguredAIClient
from codemodder.project_analysis.python_repo_manager import PythonRepoManager
from codemodder.registry import load_registered_codemods

Expand Down Expand Up @@ -146,38 +145,6 @@ def test_setup_azure_llm_client_missing_one(self, mocker, env_var):
[],
)

def test_get_model_name(self, mocker):
context = Context(
mocker.Mock(),
True,
False,
load_registered_codemods(),
PythonRepoManager(mocker.Mock()),
[],
[],
)
assert context.gpt_4_turbo_2024_04_09 == "gpt-4-turbo-2024-04-09"

@pytest.mark.parametrize("model", ["gpt-4-turbo-2024-04-09", "gpt-4o-2024-05-13"])
def test_model_get_name_from_env(self, mocker, model):
name = "my-awesome-deployment"
mocker.patch.dict(
os.environ,
{
f"CODEMODDER_AZURE_OPENAI_{model.upper()}_DEPLOYMENT": name,
},
)
context = Context(
mocker.Mock(),
True,
False,
load_registered_codemods(),
PythonRepoManager(mocker.Mock()),
[],
[],
)
assert getattr(context, model.replace("-", "_")) == name

def test_get_api_version_from_env(self, mocker):
version = "fake-version"
mocker.patch.dict(
Expand Down
21 changes: 21 additions & 0 deletions tests/test_llm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
import os

import pytest

from codemodder.llm import MODELS


class TestModels:
def test_get_model_name(self):
assert MODELS.gpt_4_turbo_2024_04_09 == "gpt-4-turbo-2024-04-09"

@pytest.mark.parametrize("model", ["gpt-4-turbo-2024-04-09", "gpt-4o-2024-05-13"])
def test_model_get_name_from_env(self, mocker, model):
name = "my-awesome-deployment"
mocker.patch.dict(
os.environ,
{
f"CODEMODDER_AZURE_OPENAI_{model.upper()}_DEPLOYMENT": name,
},
)
assert getattr(MODELS, model.replace("-", "_")) == name
Loading