
[fix] extract json from cohere response
KyrianC committed Nov 21, 2023
1 parent 6940eae commit f042f4a
Showing 4 changed files with 58 additions and 38 deletions.
38 changes: 19 additions & 19 deletions edenai_apis/apis/cohere/cohere_api.py
@@ -2,6 +2,7 @@
from typing import Optional, List, Dict, Sequence, Literal

import requests
from edenai_apis.apis.cohere.helpers import extract_json_text

from edenai_apis.features import ProviderInterface, TextInterface
from edenai_apis.features.text.custom_classification import (
@@ -78,19 +79,21 @@ def _format_custom_ner_examples(example: Dict):
Text: {text}
Answer: [{', '.join([f'{{"entity":"{entity["entity"]}", "category":"{entity["category"]}"}}' for entity in extracted_entities])}]
Answer: ```json[{', '.join([f'{{"entity":"{entity["entity"]}", "category":"{entity["category"]}"}}' for entity in extracted_entities])}]```
"""

@staticmethod
def _format_spell_check_prompt(text: str) -> str:
return f"""
Given a text with spelling errors, identify the misspelled words and correct them.
Return the results as a list of dictionaries and only the json result, where each dictionary contains two keys: "word" and "correction".
The "word" key should contain the misspelled word, and the "correction" key should contain the corrected version of the word.
Given a text with spelling errors, identify the misspelled words and correct them.
Return the results as a json list of objects, where each object contains two keys: "word" and "correction".
The "word" key should contain the misspelled word, and the "correction" key should contain the corrected version of the word.
Return the json response between ```json and ```.
For example, if the misspelled word is 'halo', the corresponding dictionary should be: {{"word": "halo", "correction": "hello"}}.
Text: {text}
Examples of entry Text with misspelling: "Hallo my friend hw are you"
Examples of response: [{{"word": "Hallo", "correction": "hello"}}, {{"word": "hw", "correction": "how"}}]
Examples of response: ```json[{{"word": "Hallo", "correction": "hello"}}, {{"word": "hw", "correction": "how"}}]```
List of corrected words:
"""

@@ -237,7 +240,7 @@ def text__custom_named_entity_recognition(
prompt = f"""You act as a named entities recognition model.
Extract an exhaustive list of Entities from the given Text according to the specified Categories and return the list as a valid JSON.
ONLY return a valid JSON. DO NOT return any other form of text. The keys of each objects in the list are `entity` and `category`.
return the json response between ```json and ```. The keys of each objects in the list are `entity` and `category`.
`entity` value must be the extracted entity from the text, `category` value must be the category of the extracted entity.
The JSON MUST be valid and conform to the given description.
Be correct and concise. If no entities are found, return an empty list.
@@ -265,11 +268,9 @@ def text__custom_named_entity_recognition(
data = original_response.get("text")

try:
items = json.loads(data)
items = extract_json_text(data)
except json.JSONDecodeError as exc:
raise ProviderException(
"Cohere didn't return valid JSON object"
) from exc
raise ProviderException("Cohere didn't return valid JSON object") from exc

standardized_response = CustomNamedEntityRecognitionDataClass(items=items)

@@ -300,14 +301,13 @@ def text__spell_check(
)

try:
data = original_response["text"]
corrected_items = json.loads(data)
data = extract_json_text(original_response["text"])
except json.JSONDecodeError as exc:
raise ProviderException(
"An error occurred while parsing the response."
) from exc

corrections = construct_word_list(text, corrected_items)
corrections = construct_word_list(text, data)
items: List[SpellCheckItem] = []
for item in corrections:
items.append(
@@ -330,16 +330,16 @@ def text__embeddings(
self, texts: List[str], model: str
) -> ResponseType[EmbeddingsDataClass]:
url = f"{self.base_url}embed"
model = model.split("__")
payload = {"texts": texts, "model": model[1]}
model = model.split("__")[1]
payload = {"texts": texts, "model": model}
response = requests.post(url, json=payload, headers=self.headers)
original_response = response.json()
if "message" in original_response:
raise ProviderException(
original_response["message"], code=response.status_code
)

items: Sequence[EmbeddingsDataClass] = []
items: Sequence[EmbeddingDataClass] = []
for prediction in original_response["embeddings"]:
items.append(EmbeddingDataClass(embedding=prediction))
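
A quick sketch of what the revised model handling above does, assuming the `<dimension>__<model name>` naming convention used elsewhere in this file (e.g. the default `768__embed-multilingual-v2.0` in `text__search`); values are illustrative:

```python
# Illustrative only: drop the dimension prefix and keep the Cohere model name
model = "768__embed-multilingual-v2.0"
model_name = model.split("__")[1]  # -> "embed-multilingual-v2.0"
payload = {"texts": ["hello world"], "model": model_name}
```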

@@ -356,7 +356,7 @@ def text__search(
similarity_metric: Literal[
"cosine", "hamming", "manhattan", "euclidean"
] = "cosine",
model: str = None,
model: Optional[str] = None,
) -> ResponseType[SearchDataClass]:
if model is None:
model = "768__embed-multilingual-v2.0"
@@ -365,14 +365,14 @@

# Embed the texts & query
texts_embed_response = self.text__embeddings(
texts=texts, model=model
texts=texts, model=model
).original_response
query_embed_response = self.text__embeddings(
texts=[query], model=model
).original_response

# Extracts embeddings from texts & query
texts_embed = [item for item in texts_embed_response["embeddings"]]
texts_embed = list(texts_embed_response["embeddings"])
query_embed = query_embed_response["embeddings"][0]

items = []
14 changes: 14 additions & 0 deletions edenai_apis/apis/cohere/helpers.py
@@ -0,0 +1,14 @@
import json
import re
from typing import Optional, Union


def extract_json_text(input_string: str) -> Optional[Union[dict, list]]:
pattern = r"```json(.*?)```"
match = re.search(pattern, input_string, re.DOTALL)

if match:
json_text = match.group(1).strip()
return json.loads(json_text)
else:
return None
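
A minimal usage sketch of the new helper, assuming the module path introduced in this commit; the sample completion below is illustrative, not captured from the API:

```python
from edenai_apis.apis.cohere.helpers import extract_json_text

# Illustrative Cohere-style completion that wraps its JSON answer in a ```json fence
sample = (
    "Sure, here are the extracted entities:\n"
    "```json\n"
    '[{"entity": "John Smith", "category": "person"}]\n'
    "```"
)

print(extract_json_text(sample))
# -> [{'entity': 'John Smith', 'category': 'person'}]

# When no ```json fence is present the helper returns None instead of raising,
# so json.JSONDecodeError only surfaces if the fenced block itself is malformed.
print(extract_json_text("no fenced block here"))  # -> None
```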
@@ -1,20 +1,22 @@
{
"original_response": {
"id": "6325db70-ebb6-4b67-be39-b1f5b1f91631",
"generations": [
{
"id": "7d26632a-a5fd-4f6b-bc81-f1d0df6702f1",
"text": "[\n{\n \"entity\": \"John Smith\",\n \"category\": \"person\"\n},\n{\n \"entity\": \"Starbucks\",\n \"category\": \"location\"\n},\n{\n \"entity\": \"New York City\",\n \"category\": \"location\"\n},\n{\n \"entity\": \"IBM\",\n \"category\": \"organization\"\n}\n]"
}
],
"prompt": " Extract the specified entities (person,location,organization) from the text enclosed in hash symbols (#) and return a JSON List of dictionaries with two keys: \"entity\" and \"category\". The \"entity\" key represents the detected entity and the \"category\" key represents the category of the entity.\n\n If no entities are found, return an empty list.\n \n Text: #Yesterday, I met John Smith at Starbucks in New York City. He works for IBM.#\n Answer:\n",
"response_id": "1413f4b4-0dfc-4fd1-af7d-722cc1575643",
"text": "Sure, here is a JSON object that contains the extracted entities along with their categories from the given text:\n\n```json\n[\n {\n \"entity\": \"John Smith\",\n \"category\": \"person\"\n },\n {\n \"entity\": \"Starbucks\",\n \"category\": \"location\"\n },\n {\n \"entity\": \"New York City\",\n \"category\": \"location\"\n },\n {\n \"entity\": \"IBM\",\n \"category\": \"organization\"\n }\n]\n```\n\nWould you like me to extract entities from another text?",
"generation_id": "4fac970e-5a40-409a-8f9c-89d1676aad91",
"token_count": {
"prompt_tokens": 342,
"response_tokens": 120,
"total_tokens": 462,
"billed_tokens": 453
},
"meta": {
"api_version": {
"version": "2022-12-06"
},
"warnings": [
"Your text contains a trailing whitespace, which has been trimmed to ensure high quality generations."
]
"billed_units": {
"input_tokens": 333,
"output_tokens": 120
}
}
},
"standardized_response": {
20 changes: 12 additions & 8 deletions edenai_apis/apis/cohere/outputs/text/spell_check_output.json
@@ -1,17 +1,21 @@
{
"original_response": {
"response_id": "7044f995-27f0-44b2-a27b-c64a7d31684a",
"text": "[{\"word\": \"Hollo\", \"correction\": \"hello\"}, {\"word\": \"wrld\", \"correction\": \"world\"}, {\"word\": \"yu\", \"correction\": \"you\"}]",
"generation_id": "a2e46520-1f79-46a6-a3ea-0097b032624a",
"response_id": "72554fe7-7f63-4649-8aaf-886ea82aa0e7",
"text": "```json\n[\n {\n \"word\": \"Hollo\",\n \"correction\": \"Hello\"\n },\n {\n \"word\": \"wrld\",\n \"correction\": \"world\"\n },\n {\n \"word\": \"yu\",\n \"correction\": \"you\"\n }\n]\n```",
"generation_id": "efbfcdac-8915-4bda-875a-38992cf4875d",
"token_count": {
"prompt_tokens": 250,
"response_tokens": 42,
"total_tokens": 292,
"billed_tokens": 283
"prompt_tokens": 257,
"response_tokens": 67,
"total_tokens": 324,
"billed_tokens": 315
},
"meta": {
"api_version": {
"version": "2022-12-06"
},
"billed_units": {
"input_tokens": 248,
"output_tokens": 67
}
}
},
@@ -25,7 +29,7 @@
"length": 5,
"suggestions": [
{
"suggestion": "hello",
"suggestion": "Hello",
"score": 1.0
}
]
