diff --git a/Argcfile.sh b/Argcfile.sh index 138f4228..ac8f5da6 100755 --- a/Argcfile.sh +++ b/Argcfile.sh @@ -81,6 +81,7 @@ test-server() { OPENAI_COMPATIBLE_PLATFORMS=( \ openai,gpt-4o-mini,https://api.openai.com/v1 \ ai21,jamba-1.5-mini,https://api.ai21.com/studio/v1 \ + cloudflare,@cf/meta/llama-3.1-8b-instruct, \ deepinfra,meta-llama/Meta-Llama-3.1-8B-Instruct,https://api.deepinfra.com/v1/openai \ deepseek,deepseek-chat,https://api.deepseek.com \ fireworks,accounts/fireworks/models/llama-v3p1-8b-instruct,https://api.fireworks.ai/inference/v1 \ @@ -111,7 +112,7 @@ chat() { fi for platform_config in "${OPENAI_COMPATIBLE_PLATFORMS[@]}"; do if [[ "$argc_platform" == "${platform_config%%,*}" ]]; then - api_base="${platform_config##*,}" + _retrieve_api_base break fi done @@ -141,7 +142,7 @@ chat() { models() { for platform_config in "${OPENAI_COMPATIBLE_PLATFORMS[@]}"; do if [[ "$argc_platform" == "${platform_config%%,*}" ]]; then - api_base="${platform_config##*,}" + _retrieve_api_base break fi done @@ -149,7 +150,7 @@ models() { env_prefix="$(echo "$argc_platform" | tr '[:lower:]' '[:upper:]')" api_key_env="${env_prefix}_API_KEY" api_key="${!api_key_env}" - _openai_models + _retrieve_models else argc models-$argc_platform fi @@ -173,7 +174,7 @@ chat-openai-compatible() { # @option --api-base! $$ # @option --api-key! $$ models-openai-compatible() { - _openai_models + _retrieve_models } # @cmd Chat with azure-openai api @@ -271,19 +272,6 @@ chat-vertexai() { -d "$(_build_body vertexai "$@")" } -# @cmd Chat with cloudflare api -# @env CLOUDFLARE_API_KEY! -# @option -m --model=@cf/meta/llama-3-8b-instruct $CLOUDFLARE_MODEL -# @flag -S --no-stream -# @arg text~ -chat-cloudflare() { - url="https://api.cloudflare.com/client/v4/accounts/$CLOUDFLARE_ACCOUNT_ID/ai/run/$argc_model" - _wrapper curl -i "$url" \ --X POST \ --H "Authorization: Bearer $CLOUDFLARE_API_KEY" \ --d "$(_build_body cloudflare "$@")" -} - # @cmd Chat with replicate api # @env REPLICATE_API_KEY! # @option -m --model=meta/meta-llama-3-8b-instruct $REPLICATE_MODEL @@ -336,7 +324,6 @@ chat-ernie() { -d "$(_build_body ernie "$@")" } - _argc_before() { stream="true" if [[ -n "$argc_no_stream" ]]; then @@ -344,7 +331,7 @@ _argc_before() { fi } -_openai_models() { +_retrieve_models() { api_base="${api_base:-"$argc_api_base"}" api_key="${api_key:-"$argc_api_key"}" _wrapper curl "$api_base/models" \ @@ -352,6 +339,17 @@ _openai_models() { } +_retrieve_api_base() { + api_base="${platform_config##*,}" + if [[ -z "$api_base" ]]; then + key="$(echo $argc_platform | tr '[:lower:]' '[:upper:]')_API_BASE" + api_base="${!key}" + if [[ -z "$api_base" ]]; then + _die "Miss api_base for $argc_platform; please set $key" + fi + fi +} + _choice_model() { aichat --list-models } @@ -436,7 +434,7 @@ _build_body() { "safetySettings":[{"category":"HARM_CATEGORY_HARASSMENT","threshold":"BLOCK_ONLY_HIGH"},{"category":"HARM_CATEGORY_HATE_SPEECH","threshold":"BLOCK_ONLY_HIGH"},{"category":"HARM_CATEGORY_SEXUALLY_EXPLICIT","threshold":"BLOCK_ONLY_HIGH"},{"category":"HARM_CATEGORY_DANGEROUS_CONTENT","threshold":"BLOCK_ONLY_HIGH"}] }' ;; - ernie|cloudflare) + ernie) echo '{ "messages": [ { diff --git a/config.example.yaml b/config.example.yaml index f3a550bd..f415d38b 100644 --- a/config.example.yaml +++ b/config.example.yaml @@ -223,10 +223,10 @@ clients: region: xxx # See https://developers.cloudflare.com/workers-ai/ - - type: cloudflare - account_id: xxx + - type: openai-compatible + name: cloudflare + api_base: https://api.cloudflare.com/client/v4/accounts/{ACCOUNT_ID}/ai/v1 api_key: xxx - api_base: https://api.cloudflare.com/client/v4 # Optional # See https://replicate.com/docs - type: replicate diff --git a/models.yaml b/models.yaml index 8ef8e7b6..3b305838 100644 --- a/models.yaml +++ b/models.yaml @@ -520,6 +520,7 @@ # Links: # - https://developers.cloudflare.com/workers-ai/models/ +# - https://developers.cloudflare.com/workers-ai/configuration/open-ai-compatibility/ - platform: cloudflare models: - name: '@cf/meta/llama-3.1-8b-instruct' diff --git a/src/client/cloudflare.rs b/src/client/cloudflare.rs deleted file mode 100644 index 3626c730..00000000 --- a/src/client/cloudflare.rs +++ /dev/null @@ -1,181 +0,0 @@ -use super::*; - -use anyhow::{anyhow, Context, Result}; -use reqwest::RequestBuilder; -use serde::Deserialize; -use serde_json::{json, Value}; - -const API_BASE: &str = "https://api.cloudflare.com/client/v4"; - -#[derive(Debug, Clone, Deserialize, Default)] -pub struct CloudflareConfig { - pub name: Option, - pub account_id: Option, - pub api_base: Option, - pub api_key: Option, - #[serde(default)] - pub models: Vec, - pub patch: Option, - pub extra: Option, -} - -impl CloudflareClient { - config_get_fn!(account_id, get_account_id); - config_get_fn!(api_key, get_api_key); - config_get_fn!(api_base, get_api_base); - - pub const PROMPTS: [PromptAction<'static>; 2] = [ - ("account_id", "Account ID:", true, PromptKind::String), - ("api_key", "API Key:", true, PromptKind::String), - ]; -} - -impl_client_trait!( - CloudflareClient, - ( - prepare_chat_completions, - chat_completions, - chat_completions_streaming - ), - (prepare_embeddings, embeddings), - (noop_prepare_rerank, noop_rerank), -); - -fn prepare_chat_completions( - self_: &CloudflareClient, - data: ChatCompletionsData, -) -> Result { - let account_id = self_.get_account_id()?; - let api_key = self_.get_api_key()?; - let api_base = self_ - .get_api_base() - .unwrap_or_else(|_| API_BASE.to_string()); - - let url = format!( - "{}/accounts/{account_id}/ai/run/{}", - api_base.trim_end_matches('/'), - self_.model.name() - ); - - let body = build_chat_completions_body(data, &self_.model)?; - - let mut request_data = RequestData::new(url, body); - - request_data.bearer_auth(api_key); - - Ok(request_data) -} - -fn prepare_embeddings(self_: &CloudflareClient, data: EmbeddingsData) -> Result { - let account_id = self_.get_account_id()?; - let api_key = self_.get_api_key()?; - - let url = format!( - "{API_BASE}/accounts/{account_id}/ai/run/{}", - self_.model.name() - ); - - let body = json!({ - "text": data.texts, - }); - - let mut request_data = RequestData::new(url, body); - - request_data.bearer_auth(api_key); - - Ok(request_data) -} - -async fn chat_completions( - builder: RequestBuilder, - _model: &Model, -) -> Result { - let res = builder.send().await?; - let status = res.status(); - let data: Value = res.json().await?; - if !status.is_success() { - catch_error(&data, status.as_u16())?; - } - - debug!("non-stream-data: {data}"); - extract_chat_completions(&data) -} - -async fn chat_completions_streaming( - builder: RequestBuilder, - handler: &mut SseHandler, - _model: &Model, -) -> Result<()> { - let handle = |message: SseMmessage| -> Result { - if message.data == "[DONE]" { - return Ok(true); - } - let data: Value = serde_json::from_str(&message.data)?; - debug!("stream-data: {data}"); - if let Some(text) = data["response"].as_str() { - handler.text(text)?; - } - Ok(false) - }; - sse_stream(builder, handle).await -} - -async fn embeddings(builder: RequestBuilder, _model: &Model) -> Result { - let res = builder.send().await?; - let status = res.status(); - let data: Value = res.json().await?; - if !status.is_success() { - catch_error(&data, status.as_u16())?; - } - let res_body: EmbeddingsResBody = - serde_json::from_value(data).context("Invalid embeddings data")?; - Ok(res_body.result.data) -} - -#[derive(Deserialize)] -struct EmbeddingsResBody { - result: EmbeddingsResBodyResult, -} - -#[derive(Deserialize)] -struct EmbeddingsResBodyResult { - data: Vec>, -} - -fn build_chat_completions_body(data: ChatCompletionsData, model: &Model) -> Result { - let ChatCompletionsData { - messages, - temperature, - top_p, - functions: _, - stream, - } = data; - - let mut body = json!({ - "model": &model.name(), - "messages": messages, - }); - - if let Some(v) = model.max_tokens_param() { - body["max_tokens"] = v.into(); - } - if let Some(v) = temperature { - body["temperature"] = v.into(); - } - if let Some(v) = top_p { - body["top_p"] = v.into(); - } - if stream { - body["stream"] = true.into(); - } - - Ok(body) -} - -fn extract_chat_completions(data: &Value) -> Result { - let text = data["result"]["response"] - .as_str() - .ok_or_else(|| anyhow!("Invalid response data: {data}"))?; - - Ok(ChatCompletionsOutput::new(text)) -} diff --git a/src/client/common.rs b/src/client/common.rs index 684d10e0..9d6480a7 100644 --- a/src/client/common.rs +++ b/src/client/common.rs @@ -364,7 +364,7 @@ pub fn create_config(prompts: &[PromptAction], client: &str) -> Result<(String, pub fn create_openai_compatible_client_config(client: &str) -> Result> { match super::OPENAI_COMPATIBLE_PLATFORMS - .iter() + .into_iter() .find(|(name, _)| client == *name) { None => Ok(None), @@ -372,13 +372,16 @@ pub fn create_openai_compatible_client_config(client: &str) -> Result Result Result Result { +async fn embeddings(builder: RequestBuilder, _model: &Model) -> Result { let res = builder.send().await?; let status = res.status(); let data: Value = res.json().await?; diff --git a/src/client/mod.rs b/src/client/mod.rs index 8ab7c783..ecafd3cb 100644 --- a/src/client/mod.rs +++ b/src/client/mod.rs @@ -33,13 +33,13 @@ register_client!( ), (vertexai, "vertexai", VertexAIConfig, VertexAIClient), (bedrock, "bedrock", BedrockConfig, BedrockClient), - (cloudflare, "cloudflare", CloudflareConfig, CloudflareClient), (replicate, "replicate", ReplicateConfig, ReplicateClient), (ernie, "ernie", ErnieConfig, ErnieClient), ); -pub const OPENAI_COMPATIBLE_PLATFORMS: [(&str, &str); 18] = [ +pub const OPENAI_COMPATIBLE_PLATFORMS: [(&str, &str); 19] = [ ("ai21", "https://api.ai21.com/studio/v1"), + ("cloudflare", ""), ("deepinfra", "https://api.deepinfra.com/v1/openai"), ("deepseek", "https://api.deepseek.com"), ("fireworks", "https://api.fireworks.ai/inference/v1"),