Skip to content

Commit

Permalink
fix(ws): 修复websocket的数据协议问题
Browse files Browse the repository at this point in the history
  • Loading branch information
liuhuapiaoyuan committed Feb 19, 2024
1 parent 23bd01a commit eb75473
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 14 deletions.
7 changes: 7 additions & 0 deletions crates/openai/src/serve/proxy/toapi/model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,13 @@ pub struct Delta<'a> {

#[derive(Deserialize, Default , Clone)]
pub struct WSStreamData {
pub data: Option<WSStreamDataBody>,
#[serde(rename = "type")]
pub msg_type: String,
}

#[derive(Deserialize, Default , Clone)]
pub struct WSStreamDataBody {
pub body: String,
pub conversation_id: String,
pub more_body: bool,
Expand Down
31 changes: 21 additions & 10 deletions crates/openai/src/serve/proxy/toapi/stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ use crate::serve::error::{ProxyError, ResponseError};
use crate::serve::ProxyResult;
use crate::warn;
use tokio_tungstenite::{connect_async, tungstenite::protocol::Message};
use tokio_tungstenite::tungstenite::client::IntoClientRequest;
use tokio_tungstenite::tungstenite::http::header::{HeaderValue, SEC_WEBSOCKET_PROTOCOL};

use super::model;

Expand Down Expand Up @@ -115,16 +117,20 @@ fn from_tungstenite(message: Message) -> String {
match message {
Message::Text(text) => {
let data = serde_json::from_str::<model::WSStreamData>(&text).unwrap();
let body = data.body;
let decoded = general_purpose::STANDARD.decode(&body).unwrap();
let result_data = String::from_utf8(decoded).unwrap() ;
if result_data.starts_with("data: ") {
let data_index = result_data.find("data: ").unwrap() + 6;
let data_end_index = result_data.find("\n\n").unwrap();
let data_str = result_data[data_index..data_end_index].to_string();
return data_str ;
if data.msg_type.eq("message"){
let body = data.data.unwrap().body;
let decoded = general_purpose::STANDARD.decode(&body).unwrap();
let result_data = String::from_utf8(decoded).unwrap() ;
if result_data.starts_with("data: ") {
let data_index = result_data.find("data: ").unwrap() + 6;
let data_end_index = result_data.find("\n\n").unwrap();
let data_str = result_data[data_index..data_end_index].to_string();
return data_str ;
}
return result_data ;

}
return result_data ;
return "".to_owned()

},
Message::Binary(_binary) => "".to_owned(),
Expand All @@ -149,7 +155,12 @@ pub(super) async fn ws_stream_handler(
) -> Result<impl Stream<Item = Result<Event, Infallible>>, ResponseError> {
let id = super::generate_id(29);
let timestamp = super::current_timestamp()?;
let (ws_stream, _) = connect_async(socket_url.clone()).await.expect( format!("Failed to connect to {}", socket_url.clone()).as_str());

let mut request = socket_url.into_client_request().unwrap();
request.headers_mut().insert(SEC_WEBSOCKET_PROTOCOL, HeaderValue::from_static("json.reliable.webpubsub.azure.v1")); // Or other modifications
let (ws_stream, _) = connect_async(request)
.await
.expect("Failed to connect");

let (mut _write, mut read) = ws_stream.split();

Expand Down
15 changes: 11 additions & 4 deletions crates/openai/src/serve/router/chat/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,10 @@ use axum_extra::extract::CookieJar;

use axum::extract::ws::{WebSocket, WebSocketUpgrade, Message , CloseFrame};
use futures_util::{stream::StreamExt, sink::SinkExt};
use std::net::SocketAddr;
use tokio_tungstenite::connect_async;

use tokio_tungstenite::tungstenite::client::IntoClientRequest;
use std::net::SocketAddr;
use tokio_tungstenite::tungstenite::http::header::{HeaderValue, SEC_WEBSOCKET_PROTOCOL};

use serde_json::{json, Value};
use std::collections::HashMap;
Expand Down Expand Up @@ -173,7 +174,12 @@ async fn proxy_ws(
}
async fn handle_socket(socket: WebSocket , host:String, access_token: String) {
let base_url = format!("wss://{}/client/hubs/conversations?access_token={}" ,host, access_token) ;
let (target_ws, _) = connect_async(base_url.clone()).await.expect( format!("Failed to connect to {}", base_url.clone().as_str() ).as_str());

let mut request = base_url.into_client_request().unwrap();
request.headers_mut().insert(SEC_WEBSOCKET_PROTOCOL, HeaderValue::from_static("json.reliable.webpubsub.azure.v1")); // Or other modifications
let (target_ws, _) = connect_async(request)
.await
.expect("Failed to connect");
let (mut client_sender, mut client_receiver) = socket.split();
let (mut server_sender, mut server_receiver) = target_ws.split();
let server_to_client = async move {
Expand Down Expand Up @@ -207,7 +213,8 @@ fn into_tungstenite(msg:Message) -> ts::Message {

fn from_tungstenite(message: ts::Message) -> Option<Message> {
match message {
ts::Message::Text(text) => Some(Message::Text( r#"{"type":"message","from":"server","dataType":"json","data":"#.to_string() + text.as_str() + "}")),
//ts::Message::Text(text) => Some(Message::Text( r#"{"type":"message","from":"server","dataType":"json","data":"#.to_string() + text.as_str() + "}")),
ts::Message::Text(text) => Some(Message::Text(text)),
ts::Message::Binary(binary) => Some(Message::Binary(binary)),
ts::Message::Ping(ping) => Some(Message::Ping(ping)),
ts::Message::Pong(pong) => Some(Message::Pong(pong)),
Expand Down

0 comments on commit eb75473

Please sign in to comment.