Skip to content

Commit

Permalink
Merge pull request #26 from Nyamort/main
Browse files Browse the repository at this point in the history
new: mistral integration
  • Loading branch information
evilsocket authored Dec 4, 2024
2 parents e023ac4 + 066605c commit 950d5db
Show file tree
Hide file tree
Showing 2 changed files with 73 additions and 0 deletions.
66 changes: 66 additions & 0 deletions src/agent/generator/mistral.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
use anyhow::Result;
use async_trait::async_trait;

use crate::agent::state::SharedState;

use super::{openai::OpenAIClient, ChatOptions, ChatResponse, Client};

pub struct MistralClient {
client: OpenAIClient,
}

#[async_trait]
impl Client for MistralClient {
fn new(_: &str, _: u16, model_name: &str, _: u32) -> anyhow::Result<Self>
where
Self: Sized,
{
let client = OpenAIClient::custom(model_name, "MISTRAL_API_KEY", "https://api.mistral.ai/v1/")?;

Ok(Self { client })
}

async fn check_native_tools_support(&self) -> Result<bool> {
self.client.check_native_tools_support().await
}

async fn chat(
&self,
state: SharedState,
options: &ChatOptions,
) -> anyhow::Result<ChatResponse> {
let response = self.client.chat(state.clone(), options).await;

if let Err(error) = &response {
if self.check_rate_limit(&error.to_string()).await {
return self.chat(state, options).await;
}
}

response
}

async fn check_rate_limit(&self, error: &str) -> bool {
// if message contains "Requests rate limit exceeded" return true
if error.contains("Requests rate limit exceeded") {
let retry_time = std::time::Duration::from_secs(5);
log::warn!(
"rate limit reached for this model, retrying in {:?} ...",
&retry_time,
);

tokio::time::sleep(retry_time).await;

return true;
}

false
}
}

#[async_trait]
impl mini_rag::Embedder for MistralClient {
async fn embed(&self, text: &str) -> Result<mini_rag::Embeddings> {
self.client.embed(text).await
}
}
7 changes: 7 additions & 0 deletions src/agent/generator/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ mod ollama;
mod openai;
mod openai_compatible;
mod xai;
mod mistral;

mod options;

Expand Down Expand Up @@ -208,6 +209,12 @@ macro_rules! factory_body {
$model_name,
$context_window,
)?)),
"mistral" => Ok(Box::new(mistral::MistralClient::new(
$url,
$port,
$model_name,
$context_window,
)?)),
"http" => Ok(Box::new(openai_compatible::OpenAiCompatibleClient::new(
$url,
$port,
Expand Down

0 comments on commit 950d5db

Please sign in to comment.