Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: abandon rag_dedicated client and improve #757

Merged
merged 1 commit into from
Jul 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions config.example.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -292,13 +292,13 @@ clients:
api_key: xxx # ENV: {client}_API_KEY

# See https://jina.ai
- type: rag-dedicated
- type: openai-compatible
name: jina
api_base: https://api.jina.ai/v1
api_key: xxx # ENV: {client}_API_KEY

# See https://docs.voyageai.com/docs/introduction
- type: rag-dedicated
- type: openai-compatible
name: voyageai
api_base: https://api.voyageai.ai/v1
api_key: xxx # ENV: {client}_API_KEY
73 changes: 40 additions & 33 deletions src/client/azure_openai.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,49 +30,56 @@ impl AzureOpenAIClient {
PromptKind::Integer,
),
];
}

impl_client_trait!(
AzureOpenAIClient,
(
prepare_chat_completions,
openai_chat_completions,
openai_chat_completions_streaming
),
(prepare_embeddings, openai_embeddings),
(noop_prepare_rerank, noop_rerank),
);

fn prepare_chat_completions(&self, data: ChatCompletionsData) -> Result<RequestData> {
let api_base = self.get_api_base()?;
let api_key = self.get_api_key()?;
fn prepare_chat_completions(
self_: &AzureOpenAIClient,
data: ChatCompletionsData,
) -> Result<RequestData> {
let api_base = self_.get_api_base()?;
let api_key = self_.get_api_key()?;

let url = format!(
"{}/openai/deployments/{}/chat/completions?api-version=2024-02-01",
&api_base,
self.model.name()
);
let url = format!(
"{}/openai/deployments/{}/chat/completions?api-version=2024-02-01",
&api_base,
self_.model.name()
);

let body = openai_build_chat_completions_body(data, &self.model);
let body = openai_build_chat_completions_body(data, &self_.model);

let mut request_data = RequestData::new(url, body);
let mut request_data = RequestData::new(url, body);

request_data.header("api-key", api_key);
request_data.header("api-key", api_key);

Ok(request_data)
}
Ok(request_data)
}

fn prepare_embeddings(&self, data: EmbeddingsData) -> Result<RequestData> {
let api_base = self.get_api_base()?;
let api_key = self.get_api_key()?;
fn prepare_embeddings(self_: &AzureOpenAIClient, data: EmbeddingsData) -> Result<RequestData> {
let api_base = self_.get_api_base()?;
let api_key = self_.get_api_key()?;

let url = format!(
"{}/openai/deployments/{}/embeddings?api-version=2024-02-01",
&api_base,
self.model.name()
);
let url = format!(
"{}/openai/deployments/{}/embeddings?api-version=2024-02-01",
&api_base,
self_.model.name()
);

let body = openai_build_embeddings_body(data, &self.model);
let body = openai_build_embeddings_body(data, &self_.model);

let mut request_data = RequestData::new(url, body);
let mut request_data = RequestData::new(url, body);

request_data.header("api-key", api_key);
request_data.header("api-key", api_key);

Ok(request_data)
}
Ok(request_data)
}

impl_client_trait!(
AzureOpenAIClient,
openai_chat_completions,
openai_chat_completions_streaming,
openai_embeddings
);
3 changes: 1 addition & 2 deletions src/client/bedrock.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ use super::*;
use crate::utils::{base64_decode, encode_uri, hex_encode, hmac_sha256, sha256};

use anyhow::{bail, Context, Result};
use async_trait::async_trait;
use aws_smithy_eventstream::frame::{DecodedFrame, MessageFrameDecoder};
use aws_smithy_eventstream::smithy::parse_response_headers;
use bytes::BytesMut;
Expand Down Expand Up @@ -148,7 +147,7 @@ impl BedrockClient {
}
}

#[async_trait]
#[async_trait::async_trait]
impl Client for BedrockClient {
client_common_fns!();

Expand Down
44 changes: 28 additions & 16 deletions src/client/claude.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,30 +22,41 @@ impl ClaudeClient {

pub const PROMPTS: [PromptAction<'static>; 1] =
[("api_key", "API Key:", true, PromptKind::String)];
}

fn prepare_chat_completions(&self, data: ChatCompletionsData) -> Result<RequestData> {
let api_key = self.get_api_key().ok();
impl_client_trait!(
ClaudeClient,
(
prepare_chat_completions,
claude_chat_completions,
claude_chat_completions_streaming
),
(noop_prepare_embeddings, noop_embeddings),
(noop_prepare_rerank, noop_rerank),
);

let body = claude_build_chat_completions_body(data, &self.model)?;
fn prepare_chat_completions(
self_: &ClaudeClient,
data: ChatCompletionsData,
) -> Result<RequestData> {
let api_key = self_.get_api_key().ok();

let mut request_data = RequestData::new(API_BASE, body);
let body = claude_build_chat_completions_body(data, &self_.model)?;

request_data.header("anthropic-version", "2023-06-01");
if let Some(api_key) = api_key {
request_data.header("x-api-key", api_key)
}
let mut request_data = RequestData::new(API_BASE, body);

Ok(request_data)
request_data.header("anthropic-version", "2023-06-01");
if let Some(api_key) = api_key {
request_data.header("x-api-key", api_key)
}
}

impl_client_trait!(
ClaudeClient,
claude_chat_completions,
claude_chat_completions_streaming
);
Ok(request_data)
}

pub async fn claude_chat_completions(builder: RequestBuilder) -> Result<ChatCompletionsOutput> {
pub async fn claude_chat_completions(
builder: RequestBuilder,
_model: &Model,
) -> Result<ChatCompletionsOutput> {
let res = builder.send().await?;
let status = res.status();
let data: Value = res.json().await?;
Expand All @@ -59,6 +70,7 @@ pub async fn claude_chat_completions(builder: RequestBuilder) -> Result<ChatComp
pub async fn claude_chat_completions_streaming(
builder: RequestBuilder,
handler: &mut SseHandler,
_model: &Model,
) -> Result<()> {
let mut function_name = String::new();
let mut function_arguments = String::new();
Expand Down
81 changes: 46 additions & 35 deletions src/client/cloudflare.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,54 +26,64 @@ impl CloudflareClient {
("account_id", "Account ID:", true, PromptKind::String),
("api_key", "API Key:", true, PromptKind::String),
];
}

fn prepare_chat_completions(&self, data: ChatCompletionsData) -> Result<RequestData> {
let account_id = self.get_account_id()?;
let api_key = self.get_api_key()?;
impl_client_trait!(
CloudflareClient,
(
prepare_chat_completions,
chat_completions,
chat_completions_streaming
),
(prepare_embeddings, embeddings),
(noop_prepare_rerank, noop_rerank),
);

let url = format!(
"{API_BASE}/accounts/{account_id}/ai/run/{}",
self.model.name()
);
fn prepare_chat_completions(
self_: &CloudflareClient,
data: ChatCompletionsData,
) -> Result<RequestData> {
let account_id = self_.get_account_id()?;
let api_key = self_.get_api_key()?;

let body = build_chat_completions_body(data, &self.model)?;
let url = format!(
"{API_BASE}/accounts/{account_id}/ai/run/{}",
self_.model.name()
);

let mut request_data = RequestData::new(url, body);
let body = build_chat_completions_body(data, &self_.model)?;

request_data.bearer_auth(api_key);
let mut request_data = RequestData::new(url, body);

Ok(request_data)
}
request_data.bearer_auth(api_key);

Ok(request_data)
}

fn prepare_embeddings(&self, data: EmbeddingsData) -> Result<RequestData> {
let account_id = self.get_account_id()?;
let api_key = self.get_api_key()?;
fn prepare_embeddings(self_: &CloudflareClient, data: EmbeddingsData) -> Result<RequestData> {
let account_id = self_.get_account_id()?;
let api_key = self_.get_api_key()?;

let url = format!(
"{API_BASE}/accounts/{account_id}/ai/run/{}",
self.model.name()
);
let url = format!(
"{API_BASE}/accounts/{account_id}/ai/run/{}",
self_.model.name()
);

let body = json!({
"text": data.texts,
});
let body = json!({
"text": data.texts,
});

let mut request_data = RequestData::new(url, body);
let mut request_data = RequestData::new(url, body);

request_data.bearer_auth(api_key);
request_data.bearer_auth(api_key);

Ok(request_data)
}
Ok(request_data)
}

impl_client_trait!(
CloudflareClient,
chat_completions,
chat_completions_streaming,
embeddings
);

async fn chat_completions(builder: RequestBuilder) -> Result<ChatCompletionsOutput> {
async fn chat_completions(
builder: RequestBuilder,
_model: &Model,
) -> Result<ChatCompletionsOutput> {
let res = builder.send().await?;
let status = res.status();
let data: Value = res.json().await?;
Expand All @@ -88,6 +98,7 @@ async fn chat_completions(builder: RequestBuilder) -> Result<ChatCompletionsOutp
async fn chat_completions_streaming(
builder: RequestBuilder,
handler: &mut SseHandler,
_model: &Model,
) -> Result<()> {
let handle = |message: SseMmessage| -> Result<bool> {
if message.data == "[DONE]" {
Expand All @@ -103,7 +114,7 @@ async fn chat_completions_streaming(
sse_stream(builder, handle).await
}

async fn embeddings(builder: RequestBuilder) -> Result<EmbeddingsOutput> {
async fn embeddings(builder: RequestBuilder, _model: &Model) -> Result<EmbeddingsOutput> {
let res = builder.send().await?;
let status = res.status();
let data: Value = res.json().await?;
Expand Down
Loading