Add Command R and Command R+ models #2548

Merged · 11 commits · May 20, 2024
4 changes: 4 additions & 0 deletions setup.cfg
@@ -127,6 +127,9 @@ anthropic =
anthropic~=0.17
websocket-client~=1.3.2 # For legacy stanford-online-all-v4-s3

cohere =
cohere~=5.3

mistral =
mistralai~=0.0.11

@@ -152,6 +155,7 @@ models =
crfm-helm[allenai]
crfm-helm[amazon]
crfm-helm[anthropic]
crfm-helm[cohere]
crfm-helm[google]
crfm-helm[mistral]
crfm-helm[openai]
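
With the new extra in place, the Cohere SDK can be pulled in on its own or via the umbrella models extra. A minimal sketch, assuming a standard pip install of this package:

pip install "crfm-helm[cohere]"
pip install "crfm-helm[models]"   # installs all model extras, including cohere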
101 changes: 98 additions & 3 deletions src/helm/clients/cohere_client.py
@@ -1,8 +1,9 @@
import json
import requests
-from typing import List
+from typing import List, Optional, Sequence, TypedDict

from helm.common.cache import CacheConfig
from helm.common.optional_dependencies import handle_module_not_found_error
from helm.common.request import (
wrap_request_time,
EMBEDDING_UNAVAILABLE_REQUEST_RESULT,
@@ -11,8 +12,13 @@
GeneratedOutput,
Token,
)
-from .client import CachingClient, truncate_sequence
-from .cohere_utils import get_cohere_url, DEFAULT_COHERE_API_VERSION
+from helm.clients.client import CachingClient, truncate_sequence
+from helm.clients.cohere_utils import get_cohere_url, DEFAULT_COHERE_API_VERSION

try:
import cohere
except ModuleNotFoundError as e:
handle_module_not_found_error(e, ["cohere"])


class CohereClient(CachingClient):
@@ -152,3 +158,92 @@ def do_it():
completions=completions,
embedding=[],
)


class CohereRawChatRequest(TypedDict):
message: str
model: Optional[str]
preamble: Optional[str]
chat_history: Optional[Sequence[cohere.ChatMessage]]
temperature: Optional[float]
max_tokens: Optional[int]
k: Optional[int]
p: Optional[float]
seed: Optional[float]
stop_sequences: Optional[Sequence[str]]
frequency_penalty: Optional[float]
presence_penalty: Optional[float]


def convert_to_raw_chat_request(request: Request) -> CohereRawChatRequest:
# TODO: Support chat
model = request.model.replace("cohere/", "")
return {
"message": request.prompt,
"model": model,
"preamble": None,
"chat_history": None,
"temperature": request.temperature,
"max_tokens": request.max_tokens,
"k": request.top_k_per_token,
"p": request.top_p,
"stop_sequences": request.stop_sequences,
"seed": float(request.random) if request.random is not None else None,
"frequency_penalty": request.frequency_penalty,
"presence_penalty": request.presence_penalty,
}
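
For illustration, a minimal sketch of how a HELM Request flows through convert_to_raw_chat_request (field names follow helm.common.request.Request; the prompt and sampling values here are hypothetical):

from helm.common.request import Request

request = Request(
    model="cohere/command-r",
    prompt="What is the capital of France?",  # hypothetical prompt
    temperature=0.3,
    max_tokens=50,
)
raw_request = convert_to_raw_chat_request(request)
# raw_request["message"] == "What is the capital of France?"
# raw_request["model"] == "command-r"  # the "cohere/" prefix is stripped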


class CohereChatClient(CachingClient):
"""
Leverages the chat endpoint: https://docs.cohere.com/reference/chat

Cohere models will soon support only chat: https://docs.cohere.com/docs/migrating-from-cogenerate-to-cochat
"""

def __init__(self, api_key: str, cache_config: CacheConfig):
super().__init__(cache_config=cache_config)
self.client = cohere.Client(api_key=api_key)

def make_request(self, request: Request) -> RequestResult:
if request.embedding:
return EMBEDDING_UNAVAILABLE_REQUEST_RESULT
# TODO: Support multiple completions
assert request.num_completions == 1, "CohereChatClient only supports num_completions=1"
# TODO: Support messages
assert not request.messages, "CohereChatClient currently does not support the messages API"

raw_request: CohereRawChatRequest = convert_to_raw_chat_request(request)

try:

def do_it():
"""
Send the request to the Cohere Chat API. Responses will be structured like this:
cohere.Chat {
message: What's up?
text: Hey there! How's it going? I'm doing well, thank you for asking 😊.
...
}
"""
raw_response = self.client.chat(**raw_request).dict()
assert "text" in raw_response, f"Response does not contain text: {raw_response}"
return raw_response

response, cached = self.cache.get(raw_request, wrap_request_time(do_it))
except (requests.exceptions.RequestException, AssertionError) as e:
error: str = f"CohereClient error: {e}"
return RequestResult(success=False, cached=False, error=error, completions=[], embedding=[])

completions: List[GeneratedOutput] = []
completion: GeneratedOutput = GeneratedOutput(text=response["text"], logprob=0.0, tokens=[])
completions.append(completion)

return RequestResult(
success=True,
cached=cached,
request_time=response["request_time"],
request_datetime=response["request_datetime"],
completions=completions,
embedding=[],
)
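
A rough end-to-end usage sketch of the new client (BlackHoleCacheConfig is assumed to be HELM's no-op cache config; the API key is a placeholder):

from helm.common.cache import BlackHoleCacheConfig
from helm.common.request import Request

client = CohereChatClient(api_key="YOUR_COHERE_API_KEY", cache_config=BlackHoleCacheConfig())
result = client.make_request(Request(model="cohere/command-r", prompt="Say hello."))
if result.success:
    print(result.completions[0].text)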
19 changes: 19 additions & 0 deletions src/helm/config/model_deployments.yaml
@@ -325,6 +325,25 @@ model_deployments:
window_service_spec:
class_name: "helm.benchmark.window_services.cohere_window_service.CohereWindowService"

- name: cohere/command-r
model_name: cohere/command-r
tokenizer_name: cohere/c4ai-command-r-v01
max_sequence_length: 128000
max_request_length: 128000
client_spec:
class_name: "helm.clients.cohere_client.CohereChatClient"

- name: cohere/command-r-plus
model_name: cohere/command-r-plus
tokenizer_name: cohere/c4ai-command-r-plus
# "We have a known issue where prompts between 112K - 128K in length
# result in bad generations."
# Source: https://docs.cohere.com/docs/command-r-plus
max_sequence_length: 110000
max_request_length: 110000
client_spec:
class_name: "helm.clients.cohere_client.CohereChatClient"

# Craiyon

- name: craiyon/dalle-mini
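Once these deployments are registered, the models should be addressable by name from the command line. Something along these lines, with the run entry and flag values shown as a typical (unverified) invocation:

helm-run --run-entries "mmlu:subject=anatomy,model=cohere/command-r" --suite my-suite --max-eval-instances 10
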
20 changes: 19 additions & 1 deletion src/helm/config/model_metadata.yaml
@@ -449,7 +449,25 @@ models:
creator_organization_name: Cohere
access: limited
release_date: 2023-09-29
-  tags: [TEXT_MODEL_TAG, FULL_FUNCTIONALITY_TEXT_MODEL_TAG, INSTRUCTION_FOLLOWING_MODEL_TAG]
+  tags: [TEXT_MODEL_TAG, PARTIAL_FUNCTIONALITY_TEXT_MODEL_TAG, INSTRUCTION_FOLLOWING_MODEL_TAG]

- name: cohere/command-r
display_name: Cohere Command R
description: Command R is a multilingual 35B parameter model with a context length of 128K that has been trained with conversational tool use capabilities.
creator_organization_name: Cohere
access: open
num_parameters: 35000000000
release_date: 2024-03-11
tags: [TEXT_MODEL_TAG, PARTIAL_FUNCTIONALITY_TEXT_MODEL_TAG, INSTRUCTION_FOLLOWING_MODEL_TAG]

- name: cohere/command-r-plus
display_name: Cohere Command R Plus
description: Command R+ is a multilingual 104B parameter model with a context length of 128K that has been trained with conversational tool use capabilities.
creator_organization_name: Cohere
access: open
num_parameters: 104000000000
release_date: 2024-04-04
tags: [TEXT_MODEL_TAG, PARTIAL_FUNCTIONALITY_TEXT_MODEL_TAG, INSTRUCTION_FOLLOWING_MODEL_TAG]

# Craiyon
- name: craiyon/dalle-mini
16 changes: 16 additions & 0 deletions src/helm/config/tokenizer_configs.yaml
@@ -83,6 +83,22 @@ tokenizer_configs:
end_of_text_token: ""
prefix_token: ":"

- name: cohere/c4ai-command-r-v01
tokenizer_spec:
class_name: "helm.tokenizers.huggingface_tokenizer.HuggingFaceTokenizer"
args:
pretrained_model_name_or_path: CohereForAI/c4ai-command-r-v01
end_of_text_token: "<|END_OF_TURN_TOKEN|>"
prefix_token: "<BOS_TOKEN>"

- name: cohere/c4ai-command-r-plus
tokenizer_spec:
class_name: "helm.tokenizers.huggingface_tokenizer.HuggingFaceTokenizer"
args:
pretrained_model_name_or_path: CohereForAI/c4ai-command-r-plus
end_of_text_token: "<|END_OF_TURN_TOKEN|>"
prefix_token: "<BOS_TOKEN>"
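
A quick sanity check for these special tokens is to load the tokenizer directly. A sketch assuming the transformers package and access to the CohereForAI repos on Hugging Face, with expected values taken from the config above:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("CohereForAI/c4ai-command-r-v01")
print(tokenizer.bos_token)  # expected: <BOS_TOKEN>
print(tokenizer.eos_token)  # expected: <|END_OF_TURN_TOKEN|>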

# Databricks
- name: databricks/dbrx-instruct
tokenizer_spec: