Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: CohereGenerator #6395

Merged
merged 32 commits into from
Nov 23, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
0cbf196
added CohereGenerator with unit tests
sunilkumardash9 Oct 12, 2023
3576a4a
Merge branch 'main' into add-cohere-generator
sunilkumardash9 Oct 12, 2023
2f48f60
Merge branch 'main' into add-cohere-generator
masci Oct 13, 2023
bdaf1c2
1. added releasenote
sunilkumardash9 Oct 13, 2023
b518339
Merge branch 'main' into add-cohere-generator
sunilkumardash9 Oct 13, 2023
e2547aa
Merge remote-tracking branch 'origin/add-cohere-generator' into add-c…
sunilkumardash9 Oct 13, 2023
127e2c7
1. move client creation to __init__
sunilkumardash9 Oct 13, 2023
bcd855e
few fixes
sunilkumardash9 Oct 13, 2023
3fa2f0d
add cohere to git workflows
sunilkumardash9 Oct 13, 2023
f58e602
1. CohereGenerator as top level import in generators
sunilkumardash9 Oct 13, 2023
e074a5e
1. corrected git workflow files for cohere import
sunilkumardash9 Oct 16, 2023
44a3f14
Merge branch 'main' into add-cohere-generator
ZanSara Oct 16, 2023
106e7fc
added cohere in missed out workflow installs
sunilkumardash9 Oct 16, 2023
3b4574d
Merge branch 'main' into add-cohere-generator
ZanSara Oct 16, 2023
4ece817
1. Removed default_streaming_callback from cohere.py and added in test.
sunilkumardash9 Oct 19, 2023
bb63179
Update haystack/preview/components/generators/cohere/cohere.py
sunilkumardash9 Oct 20, 2023
35a1304
Update haystack/preview/components/generators/cohere/cohere.py
sunilkumardash9 Oct 20, 2023
24359a0
Update haystack/preview/components/generators/cohere/cohere.py
sunilkumardash9 Oct 20, 2023
6449bd2
Update haystack/preview/components/generators/cohere/cohere.py
sunilkumardash9 Oct 20, 2023
b43c698
Update haystack/preview/components/generators/cohere/cohere.py
sunilkumardash9 Oct 20, 2023
282ebab
Merge branch 'main' into add-cohere-generator
ZanSara Nov 23, 2023
e51d787
move out of folder
ZanSara Nov 23, 2023
2378802
Merge branch 'main' into add-cohere-generator
ZanSara Nov 23, 2023
1d8c849
black
ZanSara Nov 23, 2023
b789855
fix tests
ZanSara Nov 23, 2023
99f9a78
feedback
ZanSara Nov 23, 2023
4146edb
black
ZanSara Nov 23, 2023
36838ad
remove api key from tests
ZanSara Nov 23, 2023
f0b0596
read api key from env var if missing
ZanSara Nov 23, 2023
9d1a46a
typo
ZanSara Nov 23, 2023
40f2e20
black
ZanSara Nov 23, 2023
2e9821a
missing import
ZanSara Nov 23, 2023
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .github/workflows/linting_preview.yml
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ jobs:
python-version: ${{ env.PYTHON_VERSION }}

- name: Install Haystack
run: pip install .[dev,preview,audio] langdetect transformers[torch,sentencepiece]==4.35.2 'sentence-transformers>=2.2.0' pypdf tika 'azure-ai-formrecognizer>=3.2.0b2'
run: pip install .[dev,preview,audio] langdetect transformers[torch,sentencepiece]==4.35.2 'sentence-transformers>=2.2.0' pypdf tika 'azure-ai-formrecognizer>=3.2.0b2' cohere

- name: Mypy
if: steps.files.outputs.any_changed == 'true'
Expand Down Expand Up @@ -69,7 +69,7 @@ jobs:

- name: Install Haystack
run: |
pip install .[dev,preview,audio] langdetect transformers[torch,sentencepiece]==4.35.2 'sentence-transformers>=2.2.0' pypdf markdown-it-py mdit_plain tika 'azure-ai-formrecognizer>=3.2.0b2'
pip install .[dev,preview,audio] langdetect transformers[torch,sentencepiece]==4.35.2 'sentence-transformers>=2.2.0' pypdf markdown-it-py mdit_plain tika 'azure-ai-formrecognizer>=3.2.0b2' cohere
pip install ./haystack-linter

- name: Pylint
Expand Down
8 changes: 4 additions & 4 deletions .github/workflows/tests_preview.yml
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ jobs:
python-version: ${{ env.PYTHON_VERSION }}

- name: Install Haystack
run: pip install .[dev,preview,audio] langdetect transformers[torch,sentencepiece]==4.35.2 'sentence-transformers>=2.2.0' pypdf markdown-it-py mdit_plain tika 'azure-ai-formrecognizer>=3.2.0b2'
run: pip install .[dev,preview,audio] langdetect transformers[torch,sentencepiece]==4.35.2 'sentence-transformers>=2.2.0' pypdf markdown-it-py mdit_plain tika 'azure-ai-formrecognizer>=3.2.0b2' cohere

- name: Run
run: pytest -m "not integration" test/preview
Expand Down Expand Up @@ -174,7 +174,7 @@ jobs:
sudo apt install ffmpeg # for local Whisper tests

- name: Install Haystack
run: pip install .[dev,preview,audio] langdetect transformers[torch,sentencepiece]==4.35.2 'sentence-transformers>=2.2.0' pypdf markdown-it-py mdit_plain tika 'azure-ai-formrecognizer>=3.2.0b2'
run: pip install .[dev,preview,audio] langdetect transformers[torch,sentencepiece]==4.35.2 'sentence-transformers>=2.2.0' pypdf markdown-it-py mdit_plain tika 'azure-ai-formrecognizer>=3.2.0b2' cohere

- name: Run
run: pytest --maxfail=5 -m "integration" test/preview
Expand Down Expand Up @@ -230,7 +230,7 @@ jobs:
colima start

- name: Install Haystack
run: pip install .[dev,preview,audio] langdetect transformers[torch,sentencepiece]==4.35.2 'sentence-transformers>=2.2.0' pypdf markdown-it-py mdit_plain tika 'azure-ai-formrecognizer>=3.2.0b2'
run: pip install .[dev,preview,audio] langdetect transformers[torch,sentencepiece]==4.35.2 'sentence-transformers>=2.2.0' pypdf markdown-it-py mdit_plain tika 'azure-ai-formrecognizer>=3.2.0b2' cohere

- name: Run Tika
run: docker run -d -p 9998:9998 apache/tika:2.9.0.0
Expand Down Expand Up @@ -281,7 +281,7 @@ jobs:
python-version: ${{ env.PYTHON_VERSION }}

- name: Install Haystack
run: pip install .[dev,preview,audio] langdetect transformers[torch,sentencepiece]==4.35.2 'sentence-transformers>=2.2.0' pypdf markdown-it-py mdit_plain tika 'azure-ai-formrecognizer>=3.2.0b2'
run: pip install .[dev,preview,audio] langdetect transformers[torch,sentencepiece]==4.35.2 'sentence-transformers>=2.2.0' pypdf markdown-it-py mdit_plain tika 'azure-ai-formrecognizer>=3.2.0b2' cohere

- name: Run
run: pytest --maxfail=5 -m "integration" test/preview -k 'not tika'
Expand Down
3 changes: 2 additions & 1 deletion haystack/preview/components/generators/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from haystack.preview.components.generators.cohere import CohereGenerator
from haystack.preview.components.generators.hugging_face_local import HuggingFaceLocalGenerator
from haystack.preview.components.generators.hugging_face_tgi import HuggingFaceTGIGenerator
from haystack.preview.components.generators.openai import GPTGenerator

__all__ = ["HuggingFaceLocalGenerator", "HuggingFaceTGIGenerator", "GPTGenerator"]
__all__ = ["HuggingFaceLocalGenerator", "HuggingFaceTGIGenerator", "GPTGenerator", "CohereGenerator"]
154 changes: 154 additions & 0 deletions haystack/preview/components/generators/cohere.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
import logging
import os
import sys
from typing import Any, Callable, Dict, List, Optional

from haystack.lazy_imports import LazyImport
from haystack.preview import DeserializationError, component, default_from_dict, default_to_dict

with LazyImport(message="Run 'pip install cohere'") as cohere_import:
from cohere import Client, COHERE_API_URL

logger = logging.getLogger(__name__)


@component
class CohereGenerator:
"""LLM Generator compatible with Cohere's generate endpoint.

Queries the LLM using Cohere's API. Invocations are made using 'cohere' package.
See [Cohere API](https://docs.cohere.com/reference/generate) for more details.

Example usage:

```python
from haystack.preview.generators import CohereGenerator
generator = CohereGenerator(api_key="test-api-key")
generator.run(prompt="What's the capital of France?")
```
"""

def __init__(
self,
api_key: Optional[str] = None,
model: str = "command",
streaming_callback: Optional[Callable] = None,
api_base_url: str = COHERE_API_URL,
**kwargs,
):
"""
Instantiates a `CohereGenerator` component.
:param api_key: The API key for the Cohere API. If not set, it will be read from the COHERE_API_KEY env var.
:param model_name: The name of the model to use. Available models are: [command, command-light, command-nightly, command-nightly-light]. Defaults to "command".
:param streaming_callback: A callback function to be called with the streaming response. Defaults to None.
:param api_base_url: The base URL of the Cohere API. Defaults to "https://api.cohere.ai".
:param kwargs: Additional model parameters. These will be used during generation. Refer to https://docs.cohere.com/reference/generate for more details.
Some of the parameters are:
- 'max_tokens': The maximum number of tokens to be generated. Defaults to 1024.
- 'truncate': One of NONE|START|END to specify how the API will handle inputs longer than the maximum token length. Defaults to END.
- 'temperature': A non-negative float that tunes the degree of randomness in generation. Lower temperatures mean less random generations.
- 'preset': Identifier of a custom preset. A preset is a combination of parameters, such as prompt, temperature etc. You can create presets in the playground.
- 'end_sequences': The generated text will be cut at the beginning of the earliest occurrence of an end sequence. The sequence will be excluded from the text.
- 'stop_sequences': The generated text will be cut at the end of the earliest occurrence of a stop sequence. The sequence will be included the text.
- 'k': Defaults to 0, min value of 0.01, max value of 0.99.
- 'p': Ensures that only the most likely tokens, with total probability mass of `p`, are considered for generation at each step. If both `k` and `p` are enabled, `p` acts after `k`.
- 'frequency_penalty': Used to reduce repetitiveness of generated tokens. The higher the value, the stronger a penalty is applied to previously present tokens,
proportional to how many times they have already appeared in the prompt or prior generation.'
- 'presence_penalty': Defaults to 0.0, min value of 0.0, max value of 1.0. Can be used to reduce repetitiveness of generated tokens.
Similar to `frequency_penalty`, except that this penalty is applied equally to all tokens that have already appeared, regardless of their exact frequencies.
- 'return_likelihoods': One of GENERATION|ALL|NONE to specify how and if the token likelihoods are returned with the response. Defaults to NONE.
- 'logit_bias': Used to prevent the model from generating unwanted tokens or to incentivize it to include desired tokens.
The format is {token_id: bias} where bias is a float between -10 and 10.

ZanSara marked this conversation as resolved.
Show resolved Hide resolved
"""
if not api_key:
api_key = os.environ.get("COHERE_API_KEY")
if not api_key:
raise ValueError(
"CohereGenerator needs an API key to run. Either provide it as init parameter or set the env var COHERE_API_KEY."
)

self.api_key = api_key
self.model = model
self.streaming_callback = streaming_callback
self.api_base_url = api_base_url
self.model_parameters = kwargs
self.client = Client(api_key=self.api_key, api_url=self.api_base_url)

def to_dict(self) -> Dict[str, Any]:
"""
Serialize this component to a dictionary.
"""
if self.streaming_callback:
module = self.streaming_callback.__module__
if module == "builtins":
callback_name = self.streaming_callback.__name__
else:
callback_name = f"{module}.{self.streaming_callback.__name__}"
else:
callback_name = None

return default_to_dict(
self,
model=self.model,
streaming_callback=callback_name,
api_base_url=self.api_base_url,
**self.model_parameters,
)

@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "CohereGenerator":
"""
Deserialize this component from a dictionary.
"""
init_params = data.get("init_parameters", {})
streaming_callback = None
if "streaming_callback" in init_params and init_params["streaming_callback"]:
parts = init_params["streaming_callback"].split(".")
module_name = ".".join(parts[:-1])
function_name = parts[-1]
module = sys.modules.get(module_name, None)
if not module:
raise DeserializationError(f"Could not locate the module of the streaming callback: {module_name}")
streaming_callback = getattr(module, function_name, None)
if not streaming_callback:
raise DeserializationError(f"Could not locate the streaming callback: {function_name}")
data["init_parameters"]["streaming_callback"] = streaming_callback
return default_from_dict(cls, data)

@component.output_types(replies=List[str], metadata=List[Dict[str, Any]])
def run(self, prompt: str):
"""
Queries the LLM with the prompts to produce replies.
:param prompt: The prompt to be sent to the generative model.
"""
response = self.client.generate(
model=self.model, prompt=prompt, stream=self.streaming_callback is not None, **self.model_parameters
)
if self.streaming_callback:
metadata_dict: Dict[str, Any] = {}
for chunk in response:
self.streaming_callback(chunk)
metadata_dict["index"] = chunk.index
replies = response.texts
metadata_dict["finish_reason"] = response.finish_reason
metadata = [metadata_dict]
self._check_truncated_answers(metadata)
return {"replies": replies, "metadata": metadata}

metadata = [{"finish_reason": resp.finish_reason} for resp in response]
replies = [resp.text for resp in response]
self._check_truncated_answers(metadata)
return {"replies": replies, "metadata": metadata}

def _check_truncated_answers(self, metadata: List[Dict[str, Any]]):
"""
Check the `finish_reason` returned with the Cohere response.
If the `finish_reason` is `MAX_TOKEN`, log a warning to the user.
:param metadata: The metadata returned by the Cohere API.
"""
if metadata[0]["finish_reason"] == "MAX_TOKENS":
logger.warning(
"Responses have been truncated before reaching a natural stopping point. "
"Increase the max_tokens parameter to allow for longer completions."
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
---
preview:
- |
Add CohereGenerator compatible with Cohere generate endpoint
168 changes: 168 additions & 0 deletions test/preview/components/generators/test_cohere_generators.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,168 @@
import os

import pytest
import cohere

from haystack.preview.components.generators import CohereGenerator


def default_streaming_callback(chunk):
"""
Default callback function for streaming responses from Cohere API.
Prints the tokens of the first completion to stdout as soon as they are received and returns the chunk unchanged.
"""
print(chunk.text, flush=True, end="")


class TestGPTGenerator:
def test_init_default(self):
component = CohereGenerator(api_key="test-api-key")
assert component.api_key == "test-api-key"
assert component.model == "command"
assert component.streaming_callback is None
assert component.api_base_url == cohere.COHERE_API_URL
assert component.model_parameters == {}

def test_init_with_parameters(self):
callback = lambda x: x
component = CohereGenerator(
api_key="test-api-key",
model="command-light",
max_tokens=10,
some_test_param="test-params",
streaming_callback=callback,
api_base_url="test-base-url",
)
assert component.api_key == "test-api-key"
assert component.model == "command-light"
assert component.streaming_callback == callback
assert component.api_base_url == "test-base-url"
assert component.model_parameters == {"max_tokens": 10, "some_test_param": "test-params"}

def test_to_dict_default(self):
component = CohereGenerator(api_key="test-api-key")
data = component.to_dict()
assert data == {
"type": "haystack.preview.components.generators.cohere.CohereGenerator",
"init_parameters": {"model": "command", "streaming_callback": None, "api_base_url": cohere.COHERE_API_URL},
}

def test_to_dict_with_parameters(self):
component = CohereGenerator(
api_key="test-api-key",
model="command-light",
max_tokens=10,
some_test_param="test-params",
streaming_callback=default_streaming_callback,
api_base_url="test-base-url",
)
data = component.to_dict()
assert data == {
"type": "haystack.preview.components.generators.cohere.CohereGenerator",
"init_parameters": {
"model": "command-light",
"max_tokens": 10,
"some_test_param": "test-params",
"api_base_url": "test-base-url",
"streaming_callback": "test_cohere_generators.default_streaming_callback",
},
}

def test_to_dict_with_lambda_streaming_callback(self):
component = CohereGenerator(
api_key="test-api-key",
model="command",
max_tokens=10,
some_test_param="test-params",
streaming_callback=lambda x: x,
api_base_url="test-base-url",
)
data = component.to_dict()
assert data == {
"type": "haystack.preview.components.generators.cohere.CohereGenerator",
"init_parameters": {
"model": "command",
"streaming_callback": "test_cohere_generators.<lambda>",
"api_base_url": "test-base-url",
"max_tokens": 10,
"some_test_param": "test-params",
},
}

def test_from_dict(self, monkeypatch):
monkeypatch.setenv("COHERE_API_KEY", "test-key")
data = {
"type": "haystack.preview.components.generators.cohere.CohereGenerator",
"init_parameters": {
"model": "command",
"max_tokens": 10,
"some_test_param": "test-params",
"api_base_url": "test-base-url",
"streaming_callback": "test_cohere_generators.default_streaming_callback",
},
}
component = CohereGenerator.from_dict(data)
assert component.api_key == "test-key"
assert component.model == "command"
assert component.streaming_callback == default_streaming_callback
assert component.api_base_url == "test-base-url"
assert component.model_parameters == {"max_tokens": 10, "some_test_param": "test-params"}

def test_check_truncated_answers(self, caplog):
component = CohereGenerator(api_key="test-api-key")
metadata = [{"finish_reason": "MAX_TOKENS"}]
component._check_truncated_answers(metadata)
assert caplog.records[0].message == (
"Responses have been truncated before reaching a natural stopping point. "
"Increase the max_tokens parameter to allow for longer completions."
)

@pytest.mark.skipif(
not os.environ.get("COHERE_API_KEY", None),
reason="Export an env var called CO_API_KEY containing the Cohere API key to run this test.",
)
@pytest.mark.integration
def test_cohere_generator_run(self):
component = CohereGenerator(api_key=os.environ.get("COHERE_API_KEY"))
results = component.run(prompt="What's the capital of France?")
assert len(results["replies"]) == 1
assert "Paris" in results["replies"][0]
assert len(results["metadata"]) == 1
assert results["metadata"][0]["finish_reason"] == "COMPLETE"

@pytest.mark.skipif(
not os.environ.get("COHERE_API_KEY", None),
reason="Export an env var called COHERE_API_KEY containing the Cohere API key to run this test.",
)
@pytest.mark.integration
def test_cohere_generator_run_wrong_model_name(self):
component = CohereGenerator(model="something-obviously-wrong", api_key=os.environ.get("COHERE_API_KEY"))
with pytest.raises(
cohere.CohereAPIError,
match="model not found, make sure the correct model ID was used and that you have access to the model.",
):
component.run(prompt="What's the capital of France?")

@pytest.mark.skipif(
not os.environ.get("COHERE_API_KEY", None),
reason="Export an env var called COHERE_API_KEY containing the Cohere API key to run this test.",
)
@pytest.mark.integration
def test_cohere_generator_run_streaming(self):
class Callback:
def __init__(self):
self.responses = ""

def __call__(self, chunk):
self.responses += chunk.text
return chunk

callback = Callback()
component = CohereGenerator(os.environ.get("COHERE_API_KEY"), streaming_callback=callback)
results = component.run(prompt="What's the capital of France?")

assert len(results["replies"]) == 1
assert "Paris" in results["replies"][0]
assert len(results["metadata"]) == 1
assert results["metadata"][0]["finish_reason"] == "COMPLETE"
assert callback.responses == results["replies"][0]