[fix] add mistral as new provider + new models
Daggx committed Nov 8, 2023
1 parent 5c15495 commit 7f19012
Showing 7 changed files with 136 additions and 0 deletions.
6 changes: 6 additions & 0 deletions edenai_apis/api_keys/mistral_settings_template.json
@@ -0,0 +1,6 @@
{
    "user_id": "",
    "app_id": "",
    "key": ""
}

1 change: 1 addition & 0 deletions edenai_apis/apis/__init__.py
@@ -57,5 +57,6 @@
from .vernai import VernaiApi
from .readyredact import ReadyRedactApi
from .senseloaf import SenseloafApi
from .mistral import MistralApi

# THIS NEEDS TO BE DONE AUTOMATICALLY
1 change: 1 addition & 0 deletions edenai_apis/apis/mistral/__init__.py
@@ -0,0 +1 @@
from .mistral_api import MistralApi
17 changes: 17 additions & 0 deletions edenai_apis/apis/mistral/errors.py
@@ -0,0 +1,17 @@
from edenai_apis.utils.exception import (
    ProviderErrorLists,
    ProviderInternalServerError,
    ProviderTimeoutError,
)

# NOTE: error messages should be regex patterns
ERRORS: ProviderErrorLists = {
    ProviderInternalServerError: [
        r"Error calling Clarifai",
        r"Failure",
    ],
    ProviderTimeoutError: [
        r"<[^<>]+debug_error_string = 'UNKNOWN:Error received from peer ipv4:\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}:\d+ {created_time:'\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}\.\d+[\+\-]\d{2}:\d{2}', grpc_status:14, grpc_message:'GOAWAY received'}'>",
        r"Model is deploying",
    ],
}
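
The entries above are plain regular expressions matched against provider error messages. Below is a minimal sketch of how such a pattern could be checked, assuming re.search semantics and a hypothetical sample message; the actual matching logic lives in edenai_apis.utils.exception and is not part of this diff.

import re

from edenai_apis.apis.mistral.errors import ERRORS

# Hypothetical provider error string, shaped like the ones the patterns target.
sample_message = "Model is deploying"

for exception_class, patterns in ERRORS.items():
    for pattern in patterns:
        if re.search(pattern, sample_message):
            print(f"'{sample_message}' maps to {exception_class.__name__}")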
11 changes: 11 additions & 0 deletions edenai_apis/apis/mistral/info.json
@@ -0,0 +1,11 @@
{
    "text": {
        "generation": {
            "constraints": {
                "models": ["mistral-7B-Instruct", "mistral-7B-OpenOrca", "openHermes-2-mistral-7B"],
                "default_model": "mistral-7B-Instruct"
            },
            "version": "v1"
        }
    }
}
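
info.json declares the models exposed for text generation and the fallback default. The following is a minimal sketch of how a caller might resolve a requested model against these constraints; the file path and the fallback-to-default behavior are illustrative assumptions, not edenai_apis internals.

import json

with open("edenai_apis/apis/mistral/info.json") as f:
    info = json.load(f)

constraints = info["text"]["generation"]["constraints"]
requested_model = "openHermes-2-mistral-7B"  # hypothetical request
if requested_model not in constraints["models"]:
    requested_model = constraints["default_model"]
print(requested_model)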
87 changes: 87 additions & 0 deletions edenai_apis/apis/mistral/mistral_api.py
@@ -0,0 +1,87 @@
from typing import Dict

from google.protobuf.json_format import MessageToDict
from clarifai_grpc.grpc.api import resources_pb2, service_pb2, service_pb2_grpc
from clarifai_grpc.channel.clarifai_channel import ClarifaiChannel
from clarifai_grpc.grpc.api.status import status_code_pb2

from edenai_apis.features import ProviderInterface, TextInterface
from edenai_apis.features.text.generation.generation_dataclass import (
    GenerationDataClass,
)
from edenai_apis.loaders.data_loader import ProviderDataEnum
from edenai_apis.loaders.loaders import load_provider
from edenai_apis.utils.exception import ProviderException
from edenai_apis.utils.types import ResponseType


class MistralApi(ProviderInterface, TextInterface):
    provider_name = "mistral"

    def __init__(self, api_keys: Dict = {}) -> None:
        self.api_settings = load_provider(
            ProviderDataEnum.KEY, self.provider_name, api_keys=api_keys
        )
        self.user_id = self.api_settings["user_id"]
        self.app_id = self.api_settings["app_id"]
        self.key = self.api_settings["key"]

    def __chat_markup_tokens(self, model):
        # mistral-7B-Instruct expects the [INST] ... [/INST] instruction format;
        # the other hosted models use ChatML-style <|im_start|> / <|im_end|> markers.
        if model == "mistral-7B-Instruct":
            return "[INST]", "[/INST]"
        else:
            return "<|im_start|>", "<|im_end|>"

    def text__generation(
        self, text: str, temperature: float, max_tokens: int, model: str
    ) -> ResponseType[GenerationDataClass]:
        # NOTE: temperature and max_tokens are accepted for interface
        # compatibility but are not forwarded to the request below.
        start, end = self.__chat_markup_tokens(model)

        text = f"{start} {text} {end}"

        # The Mistral models are served through Clarifai's hosted gRPC endpoint
        # (user "mistralai", app "completion").
        channel = ClarifaiChannel.get_grpc_channel()
        stub = service_pb2_grpc.V2Stub(channel)

        metadata = (("authorization", self.key),)
        user_data_object = resources_pb2.UserAppIDSet(
            user_id="mistralai", app_id="completion"
        )

        post_model_outputs_response = stub.PostModelOutputs(
            service_pb2.PostModelOutputsRequest(
                user_app_id=user_data_object,
                model_id=model,
                inputs=[
                    resources_pb2.Input(
                        data=resources_pb2.Data(text=resources_pb2.Text(raw=text))
                    )
                ],
            ),
            metadata=metadata,
        )

        if post_model_outputs_response.status.code != status_code_pb2.SUCCESS:
            raise ProviderException(
                post_model_outputs_response.status.description,
                code=post_model_outputs_response.status.code,
            )

        response = MessageToDict(
            post_model_outputs_response, preserving_proto_field_name=True
        )

        output = response.get("outputs", [])
        if len(output) == 0:
            raise ProviderException(
                "Mistral returned an empty response!",
                code=post_model_outputs_response.status.code,
            )

        original_response = output[0].get("data", {}) or {}

        return ResponseType[GenerationDataClass](
            original_response=original_response,
            standardized_response=GenerationDataClass(
                generated_text=(original_response.get("text", {}) or {}).get("raw", "")
            ),
        )
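
A minimal usage sketch of the new provider class, based only on the signatures in this diff; the credential values and prompt are placeholders, and in practice the keys would normally come from the provider settings file rather than be passed inline.

from edenai_apis.apis.mistral import MistralApi

# Placeholder credentials matching the fields in mistral_settings_template.json.
api = MistralApi(api_keys={"user_id": "...", "app_id": "...", "key": "..."})

result = api.text__generation(
    text="Write a short haiku about the sea.",
    temperature=0.7,
    max_tokens=256,
    model="mistral-7B-Instruct",
)
print(result.standardized_response.generated_text)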
13 changes: 13 additions & 0 deletions edenai_apis/apis/mistral/outputs/text/generation_output.json
@@ -0,0 +1,13 @@
{
    "original_response": {
        "text": {
            "raw": "party\n\nAI assistant: Hi there! I'm an AI language model, designed to assist and engage in conversations. My name is not \"who are you?\" but I'm here to help you with any questions or tasks you may have. How can I assist you today?",
            "text_info": {
                "encoding": "UnknownTextEnc"
            }
        }
    },
    "standardized_response": {
        "generated_text": "party\n\nAI assistant: Hi there! I'm an AI language model, designed to assist and engage in conversations. My name is not \"who are you?\" but I'm here to help you with any questions or tasks you may have. How can I assist you today?"
    }
}
