Skip to content

Commit

Permalink
finish the implementation of stats module
Browse files Browse the repository at this point in the history
  • Loading branch information
unixzii authored and ktiays committed Mar 5, 2023
1 parent 7779fb4 commit 52f3411
Show file tree
Hide file tree
Showing 6 changed files with 213 additions and 63 deletions.
95 changes: 38 additions & 57 deletions src/modules/chat/mod.rs
Original file line number Diff line number Diff line change
@@ -1,20 +1,19 @@
mod braille;
mod openai_client;
mod session;
mod session_mgr;

use std::error::Error;
use std::time::Duration;

use async_openai::types::{
ChatCompletionRequestMessage, ChatCompletionRequestMessageArgs,
CreateChatCompletionRequestArgs, Role,
};
use async_openai::types::{ChatCompletionRequestMessageArgs, Role};
use async_openai::Client as OpenAIClient;
use teloxide::dispatching::DpHandlerDescription;
use teloxide::prelude::*;
use teloxide::types::{BotCommand, InlineKeyboardButton, InlineKeyboardMarkup, Me};

use crate::module_mgr::Module;
use crate::utils::dptree_ext;
use crate::StatsManager;
use crate::{noop_handler, HandlerResult};
pub(crate) use session::Session;
pub(crate) use session_mgr::SessionManager;
Expand All @@ -29,6 +28,7 @@ async fn handle_chat_message(
text: MessageText,
chat_id: ChatId,
session_mgr: SessionManager,
stats_mgr: StatsManager,
openai_client: OpenAIClient,
) -> bool {
let mut text = text.0;
Expand All @@ -50,8 +50,16 @@ async fn handle_chat_message(
}
text = text.trim().to_owned();

match actually_handle_chat_message(bot, Some(msg), text, chat_id, session_mgr, openai_client)
.await
match actually_handle_chat_message(
bot,
Some(msg),
text,
chat_id,
session_mgr,
stats_mgr,
openai_client,
)
.await
{
Err(err) => {
error!("Failed to handle chat message: {}", err);
Expand All @@ -66,6 +74,7 @@ async fn handle_retry_action(
bot: Bot,
query: CallbackQuery,
session_mgr: SessionManager,
stats_mgr: StatsManager,
openai_client: OpenAIClient,
) -> bool {
if !query.data.map(|data| data == "/retry").unwrap_or(false) {
Expand Down Expand Up @@ -100,6 +109,7 @@ async fn handle_retry_action(
last_message.content,
chat_id,
session_mgr,
stats_mgr,
openai_client,
)
.await
Expand All @@ -119,12 +129,13 @@ async fn actually_handle_chat_message(
content: String,
chat_id: String,
session_mgr: SessionManager,
stats_mgr: StatsManager,
openai_client: OpenAIClient,
) -> HandlerResult {
// Send a progress indicator message first.
let progress_bar = braille::BrailleProgress::new(1, 1, 3, Some("Thinking... 🤔".to_owned()));
let mut send_progress_msg = bot.send_message(chat_id.clone(), progress_bar.current_string());
send_progress_msg.reply_to_message_id = reply_to_msg.map(|m| m.id);
send_progress_msg.reply_to_message_id = reply_to_msg.as_ref().map(|m| m.id);
let sent_progress_msg = send_progress_msg.await?;

// Construct the request messages.
Expand All @@ -150,7 +161,7 @@ async fn actually_handle_chat_message(
).await;
}
} => { unreachable!() },
reply_result = request_chat_model(&openai_client, msgs) => {
reply_result = openai_client::request_chat_model(&openai_client, msgs) => {
reply_result.map_err(|err| anyhow!("API error: {}", err))
},
_ = tokio::time::sleep(Duration::from_secs(30)) => {
Expand All @@ -160,18 +171,29 @@ async fn actually_handle_chat_message(

// Reply to the user and add to history.
let reply_result = match req_result {
Ok(text) => {
Ok(res) => {
let reply_text = res.message.content;
session_mgr.add_message_to_session(chat_id.clone(), user_msg);
session_mgr.add_message_to_session(
chat_id.clone(),
ChatCompletionRequestMessageArgs::default()
.role(Role::Assistant)
.content(text.clone())
.content(reply_text.clone())
.build()
.unwrap(),
);
// TODO: maybe we need to handle the case that `reply_to_msg` is `None`.
if let Some(from_username) = reply_to_msg
.as_ref()
.and_then(|m| m.from())
.and_then(|u| u.username.as_ref())
{
stats_mgr
.add_usage(from_username.to_owned(), res.token_usage as _)
.await;
}
// TODO: add retry for edit failures.
bot.edit_message_text(chat_id, sent_progress_msg.id, text)
bot.edit_message_text(chat_id, sent_progress_msg.id, reply_text)
.await
}
Err(err) => {
Expand Down Expand Up @@ -202,55 +224,12 @@ async fn reset_session(bot: Bot, chat_id: ChatId, session_mgr: SessionManager) -
Ok(())
}

/// Sends `msgs` to the OpenAI chat-completion API and returns the
/// assistant's reply text.
///
/// Uses the `gpt-3.5-turbo` model with a fixed temperature of 0.6.
///
/// # Errors
/// Fails when the request cannot be built, the API call fails, or the
/// server responds with no choices.
async fn request_chat_model(
    client: &OpenAIClient,
    msgs: Vec<ChatCompletionRequestMessage>,
) -> Result<String, Box<dyn Error>> {
    let req = CreateChatCompletionRequestArgs::default()
        .model("gpt-3.5-turbo")
        .temperature(0.6)
        .messages(msgs)
        .build()?;

    let resp = client.chat().create(req).await?;

    // An empty `choices` array means the server returned nothing usable;
    // surface that as an error instead of silently replying with an empty
    // string (resolves the previous TODO).
    let first_choice = resp
        .choices
        .into_iter()
        .next()
        .ok_or("server responded with no choices")?;

    Ok(first_choice.message.content)
}

/// Builds a predicate that tells whether a message text invokes the bot
/// command `cmd` (e.g. `cmd = "reset"` matches "/reset").
///
/// In groups, Telegram may append a mention suffix to a command
/// ("/reset@xxxx_bot"); such a suffix matches only when it names this bot.
fn filter_command(cmd: &str) -> impl Fn(Me, MessageText) -> bool {
    let pat = format!("/{}", cmd);
    move |me, text| {
        let rest = match text.0.strip_prefix(&pat) {
            Some(rest) => rest,
            None => return false,
        };

        // Bare command, e.g. "/reset".
        if rest.is_empty() {
            return true;
        }

        // Anything after the command must be a "@bot_username" mention.
        // The previous byte-slicing version accepted any single trailing
        // character (so "/reset2" matched "reset") and could panic when
        // slicing at a multi-byte character boundary.
        match rest.strip_prefix('@') {
            Some(mention) => me
                .username
                .as_ref()
                .map(|name| name.as_str() == mention)
                .unwrap_or(false),
            None => false,
        }
    }
}

pub(crate) struct Chat;

impl Module for Chat {
fn register_dependency(&mut self, dep_map: &mut DependencyMap) {
dep_map.insert(SessionManager::new());
dep_map.insert(OpenAIClient::new());
dep_map.insert(openai_client::new_client());
}

fn handler_chain(
Expand All @@ -261,7 +240,9 @@ impl Module for Chat {
Update::filter_message()
.filter_map(|msg: Message| msg.text().map(|text| MessageText(text.to_owned())))
.map(|msg: Message| msg.chat.id)
.branch(dptree::filter(filter_command("reset")).endpoint(reset_session))
.branch(
dptree::filter(dptree_ext::command_filter("reset")).endpoint(reset_session),
)
.branch(dptree::filter_async(handle_chat_message).endpoint(noop_handler)),
)
.branch(
Expand Down
37 changes: 37 additions & 0 deletions src/modules/chat/openai_client.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
use anyhow::Error;
use async_openai::types::{
ChatCompletionRequestMessage, ChatCompletionResponseMessage, CreateChatCompletionRequestArgs,
};
use async_openai::Client as OpenAIClient;

/// Outcome of a single chat-completion request.
pub(crate) struct ChatModelResult {
    /// The assistant message returned by the model (first choice).
    pub message: ChatCompletionResponseMessage,
    /// Total tokens consumed by the request; 0 when the server reports
    /// no usage data.
    pub token_usage: u32,
}

/// Creates a new OpenAI API client with the crate's default configuration.
// NOTE(review): credentials/endpoint presumably come from the async-openai
// defaults (environment) — confirm against deployment setup.
pub(crate) fn new_client() -> OpenAIClient {
    OpenAIClient::new()
}

/// Performs a chat-completion request against the OpenAI API.
///
/// Sends `msgs` to the `gpt-3.5-turbo` model (temperature 0.6) and
/// returns the first choice together with the reported token usage.
///
/// # Errors
/// Fails when the request cannot be built, the API call fails, or the
/// server returns no choices.
pub(crate) async fn request_chat_model(
    client: &OpenAIClient,
    msgs: Vec<ChatCompletionRequestMessage>,
) -> Result<ChatModelResult, Error> {
    let request = CreateChatCompletionRequestArgs::default()
        .model("gpt-3.5-turbo")
        .temperature(0.6)
        .messages(msgs)
        .build()?;

    let response = client.chat().create(request).await?;

    // Token accounting may be absent from the response; treat that as zero.
    let token_usage = response.usage.map(|u| u.total_tokens).unwrap_or(0);

    match response.choices.into_iter().next() {
        Some(choice) => Ok(ChatModelResult {
            message: choice.message,
            token_usage,
        }),
        None => Err(anyhow!("Server responds with empty data")),
    }
}
26 changes: 25 additions & 1 deletion src/modules/stats/mod.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
pub(crate) mod db_provider;
mod stats_mgr;

use std::fmt::Write;

use teloxide::dispatching::DpHandlerDescription;
use teloxide::prelude::*;
use teloxide::types::BotCommand;

use crate::module_mgr::Module;
use crate::utils::dptree_ext;
use crate::HandlerResult;
pub(crate) use db_provider::DatabaseProvider;
pub(crate) use stats_mgr::StatsManager;
Expand All @@ -22,6 +25,23 @@ impl Stats {
}
}

/// Handles the `/stats` command: replies with the invoking user's own
/// token usage (when their username is known) followed by the total
/// usage across all users.
async fn handle_show_stats(bot: Bot, msg: Message, stats_mgr: StatsManager) -> HandlerResult {
    let mut reply_text = String::new();

    // Per-user stats are keyed by username; senders without one only
    // receive the total line.
    if let Some(from_username) = msg.from().and_then(|u| u.username.as_ref()) {
        let user_usage = stats_mgr.query_usage(Some(from_username.to_owned())).await;
        // `writeln!` instead of a manual "\n" in the format string
        // (clippy::write_with_newline); output is byte-identical.
        writeln!(&mut reply_text, "Your token usage: {}", user_usage)?;
    }
    let total_usage = stats_mgr.query_usage(None).await;
    write!(&mut reply_text, "Total token usage: {}", total_usage)?;

    bot.send_message(msg.chat.id, reply_text)
        .reply_to_message_id(msg.id)
        .send()
        .await?;

    Ok(())
}

impl Module for Stats {
fn register_dependency(&mut self, dep_map: &mut DependencyMap) {
dep_map.insert(self.stats_mgr.take().unwrap());
Expand All @@ -30,7 +50,11 @@ impl Module for Stats {
fn handler_chain(
&self,
) -> Handler<'static, DependencyMap, HandlerResult, DpHandlerDescription> {
dptree::entry()
dptree::entry().branch(
Update::filter_message()
.filter(dptree_ext::command_filter("stats"))
.endpoint(handle_show_stats),
)
}

fn commands(&self) -> Vec<BotCommand> {
Expand Down
Loading

0 comments on commit 52f3411

Please sign in to comment.