Add TogetherCompletionClient (stanford-crfm#2629)
yifanmai authored and weiqipedia committed May 20, 2024
1 parent cb50621 commit 6b3de7f
Showing 2 changed files with 166 additions and 7 deletions.
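
In short: the commit adds a TogetherCompletionClient for text completion models alongside the existing TogetherChatClient, and lets both clients substitute an explicit Together model name for the HELM model name. A minimal usage sketch based on the diff below; it assumes the together package is installed and that TOGETHER_API_KEY is set whenever api_key is None, and the import path is inferred from the file location under src/helm/clients/:

from helm.common.cache import BlackHoleCacheConfig
from helm.common.request import Request
from helm.clients.together_client import TogetherCompletionClient

client = TogetherCompletionClient(
    cache_config=BlackHoleCacheConfig(),  # no caching
    api_key=None,  # falls back to the TOGETHER_API_KEY environment variable
    together_model="meta-llama/Llama-3-8b-hf",  # sent to Together instead of request.model
)
request = Request(
    model="meta/llama-3-8b",
    model_deployment="together/llama-3-8b",
    prompt="Elephants are one of the most",
    temperature=0.0,
    max_tokens=10,
)
result = client.make_request(request)
if result.success:
    print(result.completions[0].text)  # e.g. " popular animals in the world. They are known for"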
74 changes: 72 additions & 2 deletions src/helm/clients/test_together_client.py
@@ -2,10 +2,10 @@
import pytest
import tempfile

-from helm.common.cache import SqliteCacheConfig
+from helm.common.cache import BlackHoleCacheConfig, SqliteCacheConfig
from helm.common.request import Request

-from .together_client import TogetherClient, TogetherClientError
+from .together_client import TogetherClient, TogetherChatClient, TogetherCompletionClient, TogetherClientError


class TestTogetherClient:
@@ -107,3 +107,73 @@ def test_api_key_error(self):
                    model_deployment="together/redpajama-incite-base-3b-v1",
                )
            )


+@pytest.mark.models
+def test_together_chat_client_make_request():
+    # Requires setting TOGETHER_API_KEY environment variable.
+    client = TogetherChatClient(
+        cache_config=BlackHoleCacheConfig(), api_key=None, together_model="meta-llama/Llama-3-8b-chat-hf"
+    )
+    request = Request(
+        model="meta/llama-3-8b-instruct",
+        model_deployment="together/llama-3-8b-instruct",
+        prompt="Elephants are one of the most",
+        temperature=0.0,
+        max_tokens=10,
+    )
+    result = client.make_request(request)
+    assert result.success
+    assert not result.cached
+    assert result.embedding == []
+    assert len(result.completions) == 1
+    assert result.completions[0].text == "...intelligent animals on Earth!assistant"
+    assert result.completions[0].logprob == 0.0
+    result_token_strings = [token.text for token in result.completions[0].tokens]
+    assert result_token_strings == [
+        "...",
+        "int",
+        "elligent",
+        " animals",
+        " on",
+        " Earth",
+        "!",
+        "<|eot_id|>",
+        "<|start_header_id|>",
+        "assistant",
+    ]


+@pytest.mark.models
+def test_together_completion_client_make_request():
+    # Requires setting TOGETHER_API_KEY environment variable.
+    client = TogetherCompletionClient(
+        cache_config=BlackHoleCacheConfig(), api_key=None, together_model="meta-llama/Llama-3-8b-hf"
+    )
+    request = Request(
+        model="meta/llama-3-8b",
+        model_deployment="together/llama-3-8b",
+        prompt="Elephants are one of the most",
+        temperature=0.0,
+        max_tokens=10,
+    )
+    result = client.make_request(request)
+    assert result.success
+    assert not result.cached
+    assert result.embedding == []
+    assert len(result.completions) == 1
+    assert result.completions[0].text == " popular animals in the world. They are known for"
+    assert result.completions[0].logprob == 0.0
+    result_token_strings = [token.text for token in result.completions[0].tokens]
+    assert result_token_strings == [
+        " popular",
+        " animals",
+        " in",
+        " the",
+        " world",
+        ".",
+        " They",
+        " are",
+        " known",
+        " for",
+    ]
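
The together_client.py diff below threads an optional together_model through both clients: when it is set, it replaces request.model in the payload sent to Together's API. A sketch of the new conversion helper applied to the test request above (fields the test leaves unset, such as top_p, top_k_per_token, and num_completions, fall back to Request defaults, assumed here to be 1.0, 1, and 1):

from helm.common.request import Request
from helm.clients.together_client import convert_to_raw_completion_request

request = Request(
    model="meta/llama-3-8b",
    model_deployment="together/llama-3-8b",
    prompt="Elephants are one of the most",
    temperature=0.0,
    max_tokens=10,
)
raw = convert_to_raw_completion_request(request, together_model="meta-llama/Llama-3-8b-hf")
assert raw["model"] == "meta-llama/Llama-3-8b-hf"  # together_model overrides request.model
assert raw["logprobs"] == 1  # min(top_k_per_token, 1) caps logprobs at 1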
99 changes: 94 additions & 5 deletions src/helm/clients/together_client.py
@@ -12,7 +12,7 @@

try:
    from together import Together
-    from together.types import ChatCompletionResponse
+    from together.types import ChatCompletionResponse, CompletionResponse
except ModuleNotFoundError as e:
    handle_module_not_found_error(e, ["together"])

@@ -295,14 +295,18 @@ class TogetherRawChatRequest(TypedDict):
    n: int


-def convert_to_raw_chat_request(request: Request) -> TogetherRawChatRequest:
+def convert_to_raw_chat_request(request: Request, together_model: Optional[str]) -> TogetherRawChatRequest:
    if request.messages:
        messages = request.messages
    else:
        messages = [{"role": "user", "content": request.prompt}]
+    if together_model is not None:
+        model = together_model
+    else:
+        model = request.model
    return {
        "messages": messages,
-        "model": request.model,
+        "model": model,
        "max_tokens": request.max_tokens,
        "stop": request.stop_sequences,
        "temperature": request.temperature,
@@ -317,12 +321,13 @@ def convert_to_raw_chat_request(request: Request) -> TogetherRawChatRequest:
class TogetherChatClient(CachingClient):
    """Client that uses the Python Together library for chat models."""

-    def __init__(self, cache_config: CacheConfig, api_key: str, together_model: Optional[str] = None):
+    def __init__(self, cache_config: CacheConfig, api_key: Optional[str], together_model: Optional[str] = None):
        super().__init__(cache_config=cache_config)
        self._client = Together(api_key=api_key)
+        self._together_model = together_model

    def make_request(self, request: Request) -> RequestResult:
-        raw_request = convert_to_raw_chat_request(request)
+        raw_request = convert_to_raw_chat_request(request, self._together_model)
        cache_key = CachingClient.make_cache_key(raw_request, request)

        def do_it() -> Dict[Any, Any]:
@@ -363,3 +368,87 @@ def do_it() -> Dict[Any, Any]:
            completions=generated_outputs,
            embedding=[],
        )


+class TogetherRawCompletionRequest(TypedDict):
+    prompt: str
+    model: str
+    max_tokens: int
+    stop: List[str]
+    temperature: float
+    top_p: float
+    top_k: int
+    logprobs: int
+    echo: bool
+    n: int


+def convert_to_raw_completion_request(request: Request, together_model: Optional[str]) -> TogetherRawCompletionRequest:
+    if together_model is not None:
+        model = together_model
+    else:
+        model = request.model
+    return {
+        "prompt": request.prompt,
+        "model": model,
+        "max_tokens": request.max_tokens,
+        "stop": request.stop_sequences,
+        "temperature": request.temperature,
+        "top_p": request.top_p,
+        "top_k": request.top_k_per_token,
+        "logprobs": min(request.top_k_per_token, 1),
+        "echo": request.echo_prompt,
+        "n": request.num_completions,
+    }


+class TogetherCompletionClient(CachingClient):
+    """Client that uses the Python Together library for text completion models."""

+    def __init__(self, cache_config: CacheConfig, api_key: Optional[str], together_model: Optional[str] = None):
+        super().__init__(cache_config=cache_config)
+        self._client = Together(api_key=api_key)
+        self._together_model = together_model

+    def make_request(self, request: Request) -> RequestResult:
+        raw_request = convert_to_raw_completion_request(request, self._together_model)
+        cache_key = CachingClient.make_cache_key(raw_request, request)

+        def do_it() -> Dict[Any, Any]:
+            response = self._client.completions.create(**raw_request)
+            return response.model_dump(mode="json")

+        try:
+            raw_response, cached = self.cache.get(cache_key, wrap_request_time(do_it))
+            response = CompletionResponse.model_validate(raw_response)
+        except Exception as error:
+            return RequestResult(
+                success=False,
+                cached=False,
+                error=str(error),
+                completions=[],
+                embedding=[],
+            )

+        generated_outputs: List[GeneratedOutput] = []
+        for choice in response.choices:
+            # NOTE: Together always returns None for choice.finish_reason
+            # NOTE: Together does not return logprobs for the whole generated output, only for individual tokens
+            tokens: List[Token] = []
+            if choice.logprobs:
+                for token_text, token_logprob in zip_longest(
+                    choice.logprobs.tokens or [], choice.logprobs.token_logprobs or []
+                ):
+                    if token_text is None:
+                        break
+                    tokens.append(Token(text=token_text, logprob=token_logprob or 0.0))
+            assert choice.text
+            generated_outputs.append(GeneratedOutput(text=choice.text, logprob=0.0, tokens=tokens))
+        return RequestResult(
+            success=True,
+            cached=cached,
+            request_time=raw_response["request_time"],
+            request_datetime=raw_response["request_datetime"],
+            completions=generated_outputs,
+            embedding=[],
+        )
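
A note on the token handling in the new make_request: the loop pairs choice.logprobs.tokens with choice.logprobs.token_logprobs via zip_longest, stops if the token list runs out first, and coalesces missing logprobs to 0.0. A standalone sketch of that pairing logic, with hypothetical values:

from itertools import zip_longest

# Hypothetical API output: one more token than logprobs, and a None logprob.
token_texts = [" popular", " animals", " in"]
token_logprobs = [None, -0.31]

tokens = []
for token_text, token_logprob in zip_longest(token_texts, token_logprobs):
    if token_text is None:  # logprobs list was longer than tokens: stop
        break
    tokens.append((token_text, token_logprob or 0.0))  # None -> 0.0

print(tokens)  # [(' popular', 0.0), (' animals', -0.31), (' in', 0.0)]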
