From 6b5a3e213c973bbea4139466bf3b7e831331c87c Mon Sep 17 00:00:00 2001 From: sigoden Date: Wed, 4 Sep 2024 19:09:33 +0800 Subject: [PATCH] fix: render stream failed due to read cursor position timeout --- src/client/common.rs | 14 +++++++------- src/client/stream.rs | 31 +++++++++++++++---------------- src/render/mod.rs | 5 +++-- src/render/stream.rs | 12 +++++++++++- src/serve.rs | 2 +- 5 files changed, 37 insertions(+), 27 deletions(-) diff --git a/src/client/common.rs b/src/client/common.rs index 9d6480a7..9232a2be 100644 --- a/src/client/common.rs +++ b/src/client/common.rs @@ -3,7 +3,7 @@ use super::*; use crate::{ config::{GlobalConfig, Input}, function::{eval_tool_calls, FunctionDeclaration, ToolCall, ToolResult}, - render::{render_error, render_stream}, + render::render_stream, utils::*, }; @@ -84,11 +84,11 @@ pub trait Client: Sync + Send { let data = input.prepare_completion_data(self.model(), true)?; self.chat_completions_streaming_inner(&client, handler, data).await } => { - handler.done()?; + handler.done(); ret.with_context(|| "Failed to call chat-completions api") } _ = watch_abort_signal(abort_signal) => { - handler.done()?; + handler.done(); Ok(()) }, } @@ -433,9 +433,9 @@ pub async fn call_chat_completions_streaming( client.chat_completions_streaming(input, &mut handler), render_stream(rx, config, abort.clone()), ); - if let Err(err) = render_ret { - render_error(err, config.read().highlight); - } + + render_ret?; + let (text, tool_calls) = handler.take(); match send_ret { Ok(_) => { @@ -465,7 +465,7 @@ where { let text = f(builder).await?; handler.text(&text)?; - handler.done()?; + handler.done(); Ok(()) } diff --git a/src/client/stream.rs b/src/client/stream.rs index 9913f63d..907efee0 100644 --- a/src/client/stream.rs +++ b/src/client/stream.rs @@ -34,19 +34,25 @@ impl SseHandler { let ret = self .sender .send(SseEvent::Text(text.to_string())) - .with_context(|| "Failed to send ReplyEvent:Text"); - self.safe_ret(ret)?; + .with_context(|| "Failed to send SseEvent:Text"); + if let Err(err) = ret { + if self.abort.aborted() { + return Ok(()); + } + return Err(err); + } Ok(()) } - pub fn done(&mut self) -> Result<()> { + pub fn done(&mut self) { // debug!("HandleDone"); - let ret = self - .sender - .send(SseEvent::Done) - .with_context(|| "Failed to send ReplyEvent::Done"); - self.safe_ret(ret)?; - Ok(()) + let ret = self.sender.send(SseEvent::Done); + if ret.is_err() { + if self.abort.aborted() { + return; + } + warn!("Failed to send SseEvent:Done"); + } } pub fn tool_call(&mut self, call: ToolCall) -> Result<()> { @@ -65,13 +71,6 @@ impl SseHandler { } = self; (buffer, tool_calls) } - - fn safe_ret(&self, ret: Result<()>) -> Result<()> { - if self.abort.aborted() { - return Ok(()); - } - ret - } } #[derive(Debug)] diff --git a/src/render/mod.rs b/src/render/mod.rs index ce7191f6..2c87e789 100644 --- a/src/render/mod.rs +++ b/src/render/mod.rs @@ -15,13 +15,14 @@ pub async fn render_stream( config: &GlobalConfig, abort: AbortSignal, ) -> Result<()> { - if *IS_STDOUT_TERMINAL { + let ret = if *IS_STDOUT_TERMINAL { let render_options = config.read().render_options()?; let mut render = MarkdownRender::init(render_options)?; markdown_stream(rx, &mut render, &abort).await } else { raw_stream(rx, &abort).await - } + }; + ret.map_err(|err| err.context("Failed to reader stream")) } pub fn render_error(err: anyhow::Error, highlight: bool) { diff --git a/src/render/stream.rs b/src/render/stream.rs index 042bfe53..2264c167 100644 --- a/src/render/stream.rs +++ b/src/render/stream.rs @@ -28,6 +28,9 @@ pub async fn markdown_stream( disable_raw_mode()?; + if ret.is_err() { + println!(); + } ret } @@ -78,7 +81,14 @@ async fn markdown_stream_inner( // tab width hacking text = text.replace('\t', " "); - let (col, mut row) = cursor::position()?; + let mut attempts = 0; + let (col, mut row) = loop { + match cursor::position() { + Ok(pos) => break pos, + Err(_) if attempts < 3 => attempts += 1, + Err(e) => return Err(e.into()), + } + }; // Fix unexpected duplicate lines on kitty, see https://github.com/sigoden/aichat/issues/105 if col == 0 && row > 0 && display_width(&buffer) == columns as usize { diff --git a/src/serve.rs b/src/serve.rs index 95b14029..dc16d845 100644 --- a/src/serve.rs +++ b/src/serve.rs @@ -358,7 +358,7 @@ impl Server { is_first.store(false, Ordering::SeqCst) } } - let _ = handler.done(); + handler.done(); } tokio::join!( map_event(sse_rx, &tx, is_first.clone()),