diff --git a/crates/openai/src/serve/proxy/toapi/model.rs b/crates/openai/src/serve/proxy/toapi/model.rs index 80b10167d..87a57ffb9 100644 --- a/crates/openai/src/serve/proxy/toapi/model.rs +++ b/crates/openai/src/serve/proxy/toapi/model.rs @@ -64,6 +64,13 @@ pub struct Delta<'a> { #[derive(Deserialize, Default , Clone)] pub struct WSStreamData { + pub data: Option, + #[serde(rename = "type")] + pub msg_type: String, +} + +#[derive(Deserialize, Default , Clone)] +pub struct WSStreamDataBody { pub body: String, pub conversation_id: String, pub more_body: bool, diff --git a/crates/openai/src/serve/proxy/toapi/stream.rs b/crates/openai/src/serve/proxy/toapi/stream.rs index 09216df4d..814e508d5 100644 --- a/crates/openai/src/serve/proxy/toapi/stream.rs +++ b/crates/openai/src/serve/proxy/toapi/stream.rs @@ -15,6 +15,8 @@ use crate::serve::error::{ProxyError, ResponseError}; use crate::serve::ProxyResult; use crate::warn; use tokio_tungstenite::{connect_async, tungstenite::protocol::Message}; +use tokio_tungstenite::tungstenite::client::IntoClientRequest; +use tokio_tungstenite::tungstenite::http::header::{HeaderValue, SEC_WEBSOCKET_PROTOCOL}; use super::model; @@ -115,16 +117,20 @@ fn from_tungstenite(message: Message) -> String { match message { Message::Text(text) => { let data = serde_json::from_str::(&text).unwrap(); - let body = data.body; - let decoded = general_purpose::STANDARD.decode(&body).unwrap(); - let result_data = String::from_utf8(decoded).unwrap() ; - if result_data.starts_with("data: ") { - let data_index = result_data.find("data: ").unwrap() + 6; - let data_end_index = result_data.find("\n\n").unwrap(); - let data_str = result_data[data_index..data_end_index].to_string(); - return data_str ; + if data.msg_type.eq("message"){ + let body = data.data.unwrap().body; + let decoded = general_purpose::STANDARD.decode(&body).unwrap(); + let result_data = String::from_utf8(decoded).unwrap() ; + if result_data.starts_with("data: ") { + let data_index = result_data.find("data: ").unwrap() + 6; + let data_end_index = result_data.find("\n\n").unwrap(); + let data_str = result_data[data_index..data_end_index].to_string(); + return data_str ; + } + return result_data ; + } - return result_data ; + return "".to_owned() }, Message::Binary(_binary) => "".to_owned(), @@ -149,7 +155,12 @@ pub(super) async fn ws_stream_handler( ) -> Result>, ResponseError> { let id = super::generate_id(29); let timestamp = super::current_timestamp()?; - let (ws_stream, _) = connect_async(socket_url.clone()).await.expect( format!("Failed to connect to {}", socket_url.clone()).as_str()); + + let mut request = socket_url.into_client_request().unwrap(); + request.headers_mut().insert(SEC_WEBSOCKET_PROTOCOL, HeaderValue::from_static("json.reliable.webpubsub.azure.v1")); // Or other modifications + let (ws_stream, _) = connect_async(request) + .await + .expect("Failed to connect"); let (mut _write, mut read) = ws_stream.split(); diff --git a/crates/openai/src/serve/router/chat/mod.rs b/crates/openai/src/serve/router/chat/mod.rs index 827d17213..5ab98f2f3 100644 --- a/crates/openai/src/serve/router/chat/mod.rs +++ b/crates/openai/src/serve/router/chat/mod.rs @@ -29,9 +29,10 @@ use axum_extra::extract::CookieJar; use axum::extract::ws::{WebSocket, WebSocketUpgrade, Message , CloseFrame}; use futures_util::{stream::StreamExt, sink::SinkExt}; -use std::net::SocketAddr; use tokio_tungstenite::connect_async; - +use tokio_tungstenite::tungstenite::client::IntoClientRequest; +use std::net::SocketAddr; +use tokio_tungstenite::tungstenite::http::header::{HeaderValue, SEC_WEBSOCKET_PROTOCOL}; use serde_json::{json, Value}; use std::collections::HashMap; @@ -173,7 +174,12 @@ async fn proxy_ws( } async fn handle_socket(socket: WebSocket , host:String, access_token: String) { let base_url = format!("wss://{}/client/hubs/conversations?access_token={}" ,host, access_token) ; - let (target_ws, _) = connect_async(base_url.clone()).await.expect( format!("Failed to connect to {}", base_url.clone().as_str() ).as_str()); + + let mut request = base_url.into_client_request().unwrap(); + request.headers_mut().insert(SEC_WEBSOCKET_PROTOCOL, HeaderValue::from_static("json.reliable.webpubsub.azure.v1")); // Or other modifications + let (target_ws, _) = connect_async(request) + .await + .expect("Failed to connect"); let (mut client_sender, mut client_receiver) = socket.split(); let (mut server_sender, mut server_receiver) = target_ws.split(); let server_to_client = async move { @@ -207,7 +213,8 @@ fn into_tungstenite(msg:Message) -> ts::Message { fn from_tungstenite(message: ts::Message) -> Option { match message { - ts::Message::Text(text) => Some(Message::Text( r#"{"type":"message","from":"server","dataType":"json","data":"#.to_string() + text.as_str() + "}")), + //ts::Message::Text(text) => Some(Message::Text( r#"{"type":"message","from":"server","dataType":"json","data":"#.to_string() + text.as_str() + "}")), + ts::Message::Text(text) => Some(Message::Text(text)), ts::Message::Binary(binary) => Some(Message::Binary(binary)), ts::Message::Ping(ping) => Some(Message::Ping(ping)), ts::Message::Pong(pong) => Some(Message::Pong(pong)),