Skip to content

Commit

Permalink
refactor: extract json stream handling (#398)
Browse files Browse the repository at this point in the history
  • Loading branch information
sigoden authored Apr 10, 2024
1 parent 5915bc2 commit a0bd6e1
Show file tree
Hide file tree
Showing 3 changed files with 80 additions and 101 deletions.
60 changes: 9 additions & 51 deletions src/client/cohere.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
use super::{
message::*, patch_system_message, Client, CohereClient, ExtraConfig, Model, PromptType,
SendData, TokensCountFactors,
json_stream, message::*, patch_system_message, Client, CohereClient, ExtraConfig, Model,
PromptType, SendData, TokensCountFactors,
};

use crate::{render::ReplyHandler, utils::PromptKind};

use anyhow::{bail, Result};
use async_trait::async_trait;
use futures_util::StreamExt;
use reqwest::{Client as ReqwestClient, RequestBuilder};
use serde::Deserialize;
use serde_json::{json, Value};
Expand Down Expand Up @@ -107,55 +106,14 @@ pub(crate) async fn send_message_streaming(
let data: Value = res.json().await?;
check_error(&data)?;
} else {
let mut buffer = vec![];
let mut cursor = 0;
let mut start = 0;
let mut balances = vec![];
let mut quoting = false;
let mut stream = res.bytes_stream();
while let Some(chunk) = stream.next().await {
let chunk = chunk?;
let chunk = std::str::from_utf8(&chunk)?;
buffer.extend(chunk.chars());
for i in cursor..buffer.len() {
let ch = buffer[i];
if quoting {
if ch == '"' && buffer[i - 1] != '\\' {
quoting = false;
}
continue;
}
match ch {
'"' => quoting = true,
'{' => {
if balances.is_empty() {
start = i;
}
balances.push(ch);
}
'[' => {
if start != 0 {
balances.push(ch);
}
}
'}' => {
balances.pop();
if balances.is_empty() {
let value: String = buffer[start..=i].iter().collect();
let value: Value = serde_json::from_str(&value)?;
if let Some("text-generation") = value["event_type"].as_str() {
handler.text(extract_text(&value)?)?;
}
}
}
']' => {
balances.pop();
}
_ => {}
}
let handle = |value: &str| -> Result<()> {
let value: Value = serde_json::from_str(value)?;
if let Some("text-generation") = value["event_type"].as_str() {
handler.text(extract_text(&value)?)?;
}
cursor = buffer.len();
}
Ok(())
};
json_stream(res.bytes_stream(), handle).await?;
}
Ok(())
}
Expand Down
63 changes: 63 additions & 0 deletions src/client/common.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ use crate::{

use anyhow::{Context, Result};
use async_trait::async_trait;
use futures_util::{Stream, StreamExt};
use reqwest::{Client as ReqwestClient, ClientBuilder, Proxy, RequestBuilder};
use serde::Deserialize;
use serde_json::{json, Value};
Expand Down Expand Up @@ -365,6 +366,68 @@ pub fn patch_system_message(messages: &mut Vec<Message>) {
}
}

/// Splits a streamed JSON payload into its top-level objects and passes each
/// one to `handle` as soon as its closing brace arrives.
///
/// The payload may be a single outer array of objects (`[{..},{..}]`) or a
/// bare sequence of objects; anything between objects (the outer brackets,
/// commas, whitespace) is skipped.
///
/// The scan works on raw bytes rather than decoded characters. Every JSON
/// structural character is ASCII, and ASCII bytes never occur inside a UTF-8
/// multi-byte sequence, so a character that is split across two network
/// chunks cannot confuse the scanner. UTF-8 is validated once per complete
/// object, just before it is handed to `handle`.
///
/// # Errors
///
/// Returns an error if the stream yields an error, if a completed object is
/// not valid UTF-8, or if `handle` fails. An object that is still open when
/// the stream ends is dropped without calling `handle`.
pub async fn json_stream<S, F>(mut stream: S, mut handle: F) -> Result<()>
where
    S: Stream<Item = Result<bytes::Bytes, reqwest::Error>> + Unpin,
    F: FnMut(&str) -> Result<()>,
{
    // Bytes of the object currently being assembled; empty between objects,
    // so memory use is bounded by the largest single object.
    let mut object: Vec<u8> = Vec::new();
    // Number of `{` / `[` currently open inside `object`.
    let mut depth: usize = 0;
    // Scanner is inside a string literal.
    let mut quoting = false;
    // Previous byte was an unescaped backslash inside a string literal.
    let mut escape = false;
    while let Some(chunk) = stream.next().await {
        let chunk = chunk?;
        for &byte in chunk.iter() {
            let inside = !object.is_empty();
            if inside {
                object.push(byte);
            }
            if quoting {
                match byte {
                    // Toggle so that `\\` followed by `"` still closes the string.
                    b'\\' => escape = !escape,
                    b'"' if !escape => quoting = false,
                    _ => escape = false,
                }
                continue;
            }
            match byte {
                b'"' => {
                    quoting = true;
                    escape = false;
                }
                b'{' => {
                    if inside {
                        depth += 1;
                    } else {
                        // First byte of a new top-level object.
                        object.push(byte);
                        depth = 1;
                    }
                }
                b'[' => {
                    // Brackets outside an object (the outer array) are not
                    // tracked; only arrays nested in an object affect depth.
                    if inside {
                        depth += 1;
                    }
                }
                b'}' => {
                    // A stray `}` outside any object is ignored.
                    if inside {
                        depth = depth.saturating_sub(1);
                        if depth == 0 {
                            handle(std::str::from_utf8(&object)?)?;
                            object.clear();
                        }
                    }
                }
                b']' => {
                    if inside {
                        depth = depth.saturating_sub(1);
                    }
                }
                _ => {}
            }
        }
    }
    Ok(())
}

fn set_config_value(json: &mut Value, path: &str, kind: &PromptKind, value: &str) {
let segs: Vec<&str> = path.split('.').collect();
match segs.as_slice() {
Expand Down
58 changes: 8 additions & 50 deletions src/client/vertexai.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,13 @@
use super::{
message::*, patch_system_message, Client, ExtraConfig, Model, PromptType, SendData,
TokensCountFactors, VertexAIClient,
json_stream, message::*, patch_system_message, Client, ExtraConfig, Model, PromptType,
SendData, TokensCountFactors, VertexAIClient,
};

use crate::{render::ReplyHandler, utils::PromptKind};

use anyhow::{anyhow, bail, Context, Result};
use async_trait::async_trait;
use chrono::{Duration, Utc};
use futures_util::StreamExt;
use reqwest::{Client as ReqwestClient, RequestBuilder};
use serde::Deserialize;
use serde_json::{json, Value};
Expand Down Expand Up @@ -136,53 +135,12 @@ pub(crate) async fn send_message_streaming(
let data: Value = res.json().await?;
check_error(&data)?;
} else {
let mut buffer = vec![];
let mut cursor = 0;
let mut start = 0;
let mut balances = vec![];
let mut quoting = false;
let mut stream = res.bytes_stream();
while let Some(chunk) = stream.next().await {
let chunk = chunk?;
let chunk = std::str::from_utf8(&chunk)?;
buffer.extend(chunk.chars());
for i in cursor..buffer.len() {
let ch = buffer[i];
if quoting {
if ch == '"' && buffer[i - 1] != '\\' {
quoting = false;
}
continue;
}
match ch {
'"' => quoting = true,
'{' => {
if balances.is_empty() {
start = i;
}
balances.push(ch);
}
'[' => {
if start != 0 {
balances.push(ch);
}
}
'}' => {
balances.pop();
if balances.is_empty() {
let value: String = buffer[start..=i].iter().collect();
let value: Value = serde_json::from_str(&value)?;
handler.text(extract_text(&value)?)?;
}
}
']' => {
balances.pop();
}
_ => {}
}
}
cursor = buffer.len();
}
let handle = |value: &str| -> Result<()> {
let value: Value = serde_json::from_str(value)?;
handler.text(extract_text(&value)?)?;
Ok(())
};
json_stream(res.bytes_stream(), handle).await?;
}
Ok(())
}
Expand Down

0 comments on commit a0bd6e1

Please sign in to comment.