Skip to content

Commit

Permalink
merge
Browse files Browse the repository at this point in the history
  • Loading branch information
evilsocket committed Dec 4, 2024
2 parents 1d38e4b + ffa12f3 commit 515d04b
Show file tree
Hide file tree
Showing 3 changed files with 74 additions and 0 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ Nerve features integrations for any model accessible via the following providers
| **Nvidia NIM** | `NIM_API_KEY` | `nim://nvidia/nemotron-4-340b-instruct` |
| **DeepSeek** | `DEEPSEEK_API_KEY` | `deepseek://deepseek-chat` |
| **xAI** | `XAI_API_KEY` | `xai://grok-beta` |
| **Mistral.ai** | `MISTRAL_API_KEY` | `mistral://mistral-large-latest` |
| **Novita** | `NOVITA_API_KEY` | `novita://meta-llama/llama-3.1-70b-instruct` |

¹ Refer to [this document](https://huggingface.co/blog/tgi-messages-api#using-inference-endpoints-with-openai-client-libraries) for how to configure a custom Huggingface endpoint.
Expand Down
66 changes: 66 additions & 0 deletions src/agent/generator/mistral.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
use anyhow::Result;
use async_trait::async_trait;

use crate::agent::state::SharedState;

use super::{openai::OpenAIClient, ChatOptions, ChatResponse, Client};

/// Chat client for the Mistral.ai platform.
///
/// Mistral exposes an OpenAI-compatible REST API, so this type is a thin
/// wrapper that delegates all work to an [`OpenAIClient`] configured for
/// the Mistral endpoint (see `Client::new` below).
pub struct MistralClient {
    // underlying OpenAI-compatible client pointed at https://api.mistral.ai/v1/
    client: OpenAIClient,
}

#[async_trait]
impl Client for MistralClient {
    /// Builds a Mistral client by pointing the OpenAI-compatible client at
    /// the Mistral REST endpoint. The URL, port and context-window arguments
    /// are ignored; authentication comes from the `MISTRAL_API_KEY`
    /// environment variable (resolved by `OpenAIClient::custom`).
    fn new(_: &str, _: u16, model_name: &str, _: u32) -> anyhow::Result<Self>
    where
        Self: Sized,
    {
        Ok(Self {
            client: OpenAIClient::custom(
                model_name,
                "MISTRAL_API_KEY",
                "https://api.mistral.ai/v1/",
            )?,
        })
    }

    /// Delegates the native tool-support probe to the wrapped client.
    async fn check_native_tools_support(&self) -> Result<bool> {
        self.client.check_native_tools_support().await
    }

    /// Sends a chat request via the wrapped client, retrying for as long as
    /// the provider keeps answering with a rate-limit error.
    ///
    /// NOTE(review): retries are unbounded — a persistently rate-limited
    /// model will loop here (sleeping 5s per attempt in `check_rate_limit`).
    async fn chat(
        &self,
        state: SharedState,
        options: &ChatOptions,
    ) -> anyhow::Result<ChatResponse> {
        loop {
            let result = self.client.chat(state.clone(), options).await;
            match &result {
                // rate-limited: check_rate_limit has already slept, try again
                Err(err) if self.check_rate_limit(&err.to_string()).await => continue,
                // success or any other error: hand it back to the caller
                _ => return result,
            }
        }
    }

    /// Returns true when `error` is Mistral's rate-limit message, after
    /// sleeping for a fixed back-off; returns false for any other error.
    async fn check_rate_limit(&self, error: &str) -> bool {
        if !error.contains("Requests rate limit exceeded") {
            return false;
        }

        // fixed 5-second back-off before signalling the caller to retry
        let retry_time = std::time::Duration::from_secs(5);
        log::warn!(
            "rate limit reached for this model, retrying in {:?} ...",
            &retry_time,
        );
        tokio::time::sleep(retry_time).await;

        true
    }
}

#[async_trait]
impl mini_rag::Embedder for MistralClient {
    /// Computes embeddings for `text` by delegating to the wrapped
    /// OpenAI-compatible client.
    async fn embed(&self, text: &str) -> Result<mini_rag::Embeddings> {
        self.client.embed(text).await
    }
}
7 changes: 7 additions & 0 deletions src/agent/generator/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ mod ollama;
mod openai;
mod openai_compatible;
mod xai;
mod mistral;

pub(crate) mod history;
mod options;
Expand Down Expand Up @@ -216,6 +217,12 @@ macro_rules! factory_body {
$model_name,
$context_window,
)?)),
"mistral" => Ok(Box::new(mistral::MistralClient::new(
$url,
$port,
$model_name,
$context_window,
)?)),
"http" => Ok(Box::new(openai_compatible::OpenAiCompatibleClient::new(
$url,
$port,
Expand Down

0 comments on commit 515d04b

Please sign in to comment.