feature: configurable python (OpenAI) client (#102)
ibolmo authored Dec 13, 2024
1 parent 40a051f commit a3ca1cf
Showing 11 changed files with 584 additions and 144 deletions.
Makefile: 2 changes (1 addition & 1 deletion)

@@ -19,7 +19,7 @@ VENV_PYTHON_PACKAGES := venv/.python_packages

${VENV_PYTHON_PACKAGES}: ${VENV_INITIALIZED}
	bash -c 'source venv/bin/activate && python -m pip install --upgrade pip setuptools build twine openai'
-	bash -c 'source venv/bin/activate && python -m pip install -e .[dev]'
+	bash -c 'source venv/bin/activate && python -m pip install -e ".[dev]"'
	@touch $@

${VENV_PRE_COMMIT}: ${VENV_PYTHON_PACKAGES}
README.md: 31 changes (31 additions & 0 deletions)

@@ -60,6 +60,37 @@ print(f"Factuality score: {result.score}")
print(f"Factuality metadata: {result.metadata['rationale']}")
```

#### Custom Client

If you need to use a custom OpenAI client, you can initialize the library with it.

```python
import openai

from autoevals import init
from autoevals.oai import LLMClient

openai_client = openai.OpenAI(base_url="https://api.openai.com/v1/")


class CustomClient(LLMClient):
    # Resolve RateLimitError from the module before the assignment below
    # shadows the `openai` name inside the class body.
    RateLimitError = openai.RateLimitError

    openai = openai_client  # you can also pass in the `openai` module and we will instantiate a client for you
    embed = openai_client.embeddings.create
    moderation = openai_client.moderations.create

    def complete(self, **kwargs):
        # make adjustments as needed
        return self.openai.chat.completions.create(**kwargs)


# Autoevals will now use your custom client
client = init(client=CustomClient)
```

If you only need a custom client for a specific evaluator, you can pass the client to that evaluator directly.

```python
evaluator = Factuality(client=CustomClient)
```
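
For instance, here is a minimal sketch of scoring with such an evaluator (the question-and-answer strings are illustrative, not from the library):

```python
from autoevals import Factuality

evaluator = Factuality(client=CustomClient)

# Score a hypothetical question/answer pair for factual consistency
result = evaluator(
    input="Which country is Paris the capital of?",
    output="Paris is the capital of France.",
    expected="France",
)
print(result.score)  # e.g. 1 when the output is factually consistent with `expected`
```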

### Node.js

```javascript
// …
```
py/autoevals/__init__.py: 1 change (1 addition & 0 deletions)

@@ -36,6 +36,7 @@
from .llm import *
from .moderation import *
from .number import *
+from .oai import init
from .ragas import *
from .string import *
from .value import ExactMatch
py/autoevals/llm.py: 42 changes (30 additions & 12 deletions)

@@ -1,7 +1,7 @@
import abc
import json
import os
-import re
+from collections import defaultdict
from dataclasses import dataclass
from typing import Dict, List, Optional

@@ -11,7 +11,7 @@

from autoevals.partial import ScorerWithPartial

-from .oai import arun_cached_request, run_cached_request
+from .oai import LLMClient, arun_cached_request, run_cached_request

SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))

@@ -79,24 +79,29 @@ def __init__(
        self,
        api_key=None,
        base_url=None,
+        client: Optional[LLMClient] = None,
    ):
        self.extra_args = {}
        if api_key:
            self.extra_args["api_key"] = api_key
        if base_url:
            self.extra_args["base_url"] = base_url
+
+        self.client = client


class OpenAILLMScorer(OpenAIScorer):
    def __init__(
        self,
        temperature=None,
        api_key=None,
        base_url=None,
+        client: Optional[LLMClient] = None,
    ):
        super().__init__(
            api_key=api_key,
            base_url=base_url,
+            client=client,
        )
        self.extra_args["temperature"] = temperature or 0

@@ -115,8 +120,10 @@ def __init__(
        engine=None,
        api_key=None,
        base_url=None,
+        client: Optional[LLMClient] = None,
    ):
        super().__init__(
+            client=client,
            api_key=api_key,
            base_url=base_url,
        )
@@ -162,6 +169,7 @@ def _render_messages(self, **kwargs):

    def _request_args(self, output, expected, **kwargs):
        ret = {
+            "client": self.client,
            **self.extra_args,
            **self._build_args(output, expected, **kwargs),
        }
@@ -219,7 +227,7 @@ class LLMClassifier(OpenAILLMClassifier):
    An LLM-based classifier that wraps `OpenAILLMClassifier` and provides a standard way to
    apply chain of thought, parse the output, and score the result."""

-    _SPEC_FILE_CONTENTS: Optional[str] = None
+    _SPEC_FILE_CONTENTS: Dict[str, str] = defaultdict(str)

    def __init__(
        self,
@@ -233,6 +241,7 @@ def __init__(
        engine=None,
        api_key=None,
        base_url=None,
+        client: Optional[LLMClient] = None,
        **extra_render_args,
    ):
        choice_strings = list(choice_scores.keys())
@@ -257,24 +266,33 @@ def __init__(
            api_key=api_key,
            base_url=base_url,
            render_args={"__choices": choice_strings, **extra_render_args},
+            client=client,
        )

    @classmethod
-    def from_spec(cls, name: str, spec: ModelGradedSpec, **kwargs):
-        return cls(name, spec.prompt, spec.choice_scores, **kwargs)
+    def from_spec(cls, name: str, spec: ModelGradedSpec, client: Optional[LLMClient] = None, **kwargs):
+        return cls(name, spec.prompt, spec.choice_scores, client=client, **kwargs)

    @classmethod
-    def from_spec_file(cls, name: str, path: str, **kwargs):
-        if cls._SPEC_FILE_CONTENTS is None:
+    def from_spec_file(cls, name: str, path: str, client: Optional[LLMClient] = None, **kwargs):
+        if cls._SPEC_FILE_CONTENTS[name] == "":
            with open(path) as f:
-                cls._SPEC_FILE_CONTENTS = f.read()
-        spec = yaml.safe_load(cls._SPEC_FILE_CONTENTS)
-        return cls.from_spec(name, ModelGradedSpec(**spec), **kwargs)
+                cls._SPEC_FILE_CONTENTS[name] = f.read()
+        spec = yaml.safe_load(cls._SPEC_FILE_CONTENTS[name])
+        return cls.from_spec(name, ModelGradedSpec(**spec), client=client, **kwargs)


class SpecFileClassifier(LLMClassifier):
    def __new__(
-        cls, model=None, engine=None, use_cot=None, max_tokens=None, temperature=None, api_key=None, base_url=None
+        cls,
+        model=None,
+        engine=None,
+        use_cot=None,
+        max_tokens=None,
+        temperature=None,
+        api_key=None,
+        base_url=None,
+        client: Optional[LLMClient] = None,
    ):
        kwargs = {}
        if model is not None:
@@ -302,7 +320,7 @@ def __new__(

        extra_render_args = cls._partial_args() if hasattr(cls, "_partial_args") else {}

-        return LLMClassifier.from_spec_file(cls_name, template_path, **kwargs, **extra_render_args)
+        return LLMClassifier.from_spec_file(cls_name, template_path, client=client, **kwargs, **extra_render_args)


class Battle(SpecFileClassifier):
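A note on the `from_spec_file` change above: `_SPEC_FILE_CONTENTS` moves from a single shared string to a `defaultdict(str)` keyed by classifier name, so loading one spec file no longer overwrites the cached contents of another. A minimal sketch of the caching pattern in isolation (names here are illustrative, not part of the library):

```python
from collections import defaultdict

# One cached spec string per classifier name; missing keys default to ""
_spec_cache: "defaultdict[str, str]" = defaultdict(str)


def read_spec(name: str, path: str) -> str:
    # Hit the filesystem only the first time each named spec is requested;
    # an empty string means "not loaded yet" since defaultdict(str) returns "".
    if _spec_cache[name] == "":
        with open(path) as f:
            _spec_cache[name] = f.read()
    return _spec_cache[name]
```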
py/autoevals/moderation.py: 24 changes (19 additions & 5 deletions)

@@ -1,8 +1,10 @@
+from typing import Optional
+
from braintrust_core.score import Score

from autoevals.llm import OpenAIScorer

-from .oai import arun_cached_request, run_cached_request
+from .oai import LLMClient, arun_cached_request, run_cached_request

REQUEST_TYPE = "moderation"

@@ -15,7 +17,13 @@ class Moderation(OpenAIScorer):
    threshold = None
    extra_args = {}

-    def __init__(self, threshold=None, api_key=None, base_url=None):
+    def __init__(
+        self,
+        threshold=None,
+        api_key=None,
+        base_url=None,
+        client: Optional[LLMClient] = None,
+    ):
        """
        Create a new Moderation scorer.
@@ -24,11 +32,13 @@ def __init__(self, threshold=None, api_key=None, base_url=None):
        :param api_key: OpenAI key
        :param base_url: Base URL to be used to reach OpenAI moderation endpoint.
        """
-        super().__init__(api_key=api_key, base_url=base_url)
+        super().__init__(api_key=api_key, base_url=base_url, client=client)
        self.threshold = threshold

    def _run_eval_sync(self, output, __expected=None):
-        moderation_response = run_cached_request(REQUEST_TYPE, input=output, **self.extra_args)["results"][0]
+        moderation_response = run_cached_request(
+            client=self.client, request_type=REQUEST_TYPE, input=output, **self.extra_args
+        )["results"][0]
        return self.__postprocess_response(moderation_response)

    def __postprocess_response(self, moderation_response) -> Score:
@@ -42,7 +52,9 @@ def __postprocess_response(self, moderation_response) -> Score:
        )

    async def _run_eval_async(self, output, expected=None, **kwargs) -> Score:
-        moderation_response = (await arun_cached_request(REQUEST_TYPE, input=output, **self.extra_args))["results"][0]
+        moderation_response = (
+            await arun_cached_request(client=self.client, request_type=REQUEST_TYPE, input=output, **self.extra_args)
+        )["results"][0]
        return self.__postprocess_response(moderation_response)

    @staticmethod
@@ -59,3 +71,5 @@ def compute_score(moderation_result, threshold):


__all__ = ["Moderation"]
+__all__ = ["Moderation"]
+__all__ = ["Moderation"]
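
With these changes, a custom client threads through the moderation scorer as well. A minimal usage sketch, assuming the `CustomClient` from the README example above:

```python
from autoevals import Moderation

# threshold=0.5 flags output when any moderation category score exceeds 0.5;
# client=CustomClient routes requests through the custom client.
moderator = Moderation(threshold=0.5, client=CustomClient)

result = moderator(output="Some model output to screen.")
print(result.score)  # 1 if the output passes moderation, 0 if flagged
```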