Skip to content

Commit

Permalink
[ENH] - add support for Cohere assistants (#307)
Browse files Browse the repository at this point in the history
Co-authored-by: Philip Meier <github.pmeier@posteo.de>
  • Loading branch information
smokestacklightnin and pmeier authored Feb 9, 2024
1 parent 8aa6bd0 commit dc39e01
Show file tree
Hide file tree
Showing 3 changed files with 102 additions and 0 deletions.
11 changes: 11 additions & 0 deletions docs/references/faq.md
Original file line number Diff line number Diff line change
Expand Up @@ -50,3 +50,14 @@
```bash
export GOOGLE_API_KEY="XXXXX"
```

### [Cohere](https://cohere.com/)

1. > To use the API, you need an API key. You can create a key in the Cohere dashboard
> using your Cohere account.
>
> ~ [Cohere Dashboard](https://dashboard.cohere.com/api-keys)
2. Set the `COHERE_API_KEY` environment variable with your Cohere API key:
```bash
export COHERE_API_KEY="XXXXX"
```
3 changes: 3 additions & 0 deletions ragna/assistants/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
__all__ = [
"Claude",
"ClaudeInstant",
"Command",
"CommandLight",
"GeminiPro",
"GeminiUltra",
"Gpt35Turbo16k",
Expand All @@ -11,6 +13,7 @@
]

from ._anthropic import Claude, ClaudeInstant
from ._cohere import Command, CommandLight
from ._demo import RagnaDemoAssistant
from ._google import GeminiPro, GeminiUltra
from ._mosaicml import Mpt7bInstruct, Mpt30bInstruct
Expand Down
88 changes: 88 additions & 0 deletions ragna/assistants/_cohere.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
import json
from typing import AsyncIterator, cast

from ragna.core import RagnaException, Source

from ._api import ApiAssistant


class CohereApiAssistant(ApiAssistant):
_API_KEY_ENV_VAR = "COHERE_API_KEY"
_MODEL: str
_CONTEXT_SIZE: int = 4_000
# See https://docs.cohere.com/docs/models#command

@classmethod
def display_name(cls) -> str:
return f"Cohere/{cls._MODEL}"

@property
def max_input_size(self) -> int:
return self._CONTEXT_SIZE

def _make_preamble(self) -> str:
return (
"You are a helpful assistant that answers user questions given the included context. "
"If you don't know the answer, just say so. Don't try to make up an answer. "
"Only use the included documents below to generate the answer."
)

def _make_source_documents(self, sources: list[Source]) -> list[dict[str, str]]:
return [{"title": source.id, "snippet": source.content} for source in sources]

async def _call_api(
self, prompt: str, sources: list[Source], *, max_new_tokens: int
) -> AsyncIterator[str]:
# See https://docs.cohere.com/docs/cochat-beta
# See https://docs.cohere.com/reference/chat
# See https://docs.cohere.com/docs/retrieval-augmented-generation-rag
async with self._client.stream(
"POST",
"https://api.cohere.ai/v1/chat",
headers={
"accept": "application/json",
"content-type": "application/json",
"authorization": f"Bearer {self._api_key}",
},
json={
"preamble_override": self._make_preamble(),
"message": prompt,
"model": self._MODEL,
"stream": True,
"temperature": 0.0,
"max_tokens": max_new_tokens,
"documents": self._make_source_documents(sources),
},
) as response:
if response.is_error:
raise RagnaException(status_code=response.status_code)
async for chunk in response.aiter_lines():
event = json.loads(chunk)
if event["event_type"] == "stream-end":
break
if "text" in event:
yield cast(str, event["text"])


class Command(CohereApiAssistant):
"""
[Cohere Command](https://docs.cohere.com/docs/models#command)
!!! info "Required environment variables"
- `COHERE_API_KEY`
"""

_MODEL = "command"


class CommandLight(CohereApiAssistant):
"""
[Cohere Command-Light](https://docs.cohere.com/docs/models#command)
!!! info "Required environment variables"
- `COHERE_API_KEY`
"""

_MODEL = "command-light"

0 comments on commit dc39e01

Please sign in to comment.