Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add support for mcp tools with rig agents + add anthropic prompt caching #213

Open
wants to merge 9 commits into
base: main
Choose a base branch
from
1,152 changes: 643 additions & 509 deletions Cargo.lock

Large diffs are not rendered by default.

5 changes: 3 additions & 2 deletions rig-core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,16 +26,17 @@ thiserror = "1.0.61"
rig-derive = { version = "0.1.0", path = "./rig-core-derive", optional = true }
glob = "0.3.1"
lopdf = { version = "0.34.0", optional = true }
rayon = { version = "1.10.0", optional = true }
rayon = { version = "1.10.0", optional = true}
mcp_client_rs = { git = "https://github.com/edisontim/mcp_client_rust", branch = "ref/cleanup", default-features = false }
worker = { version = "0.5", optional = true }
bytes = "1.9.0"
async-stream = "0.3.6"

[dev-dependencies]
anyhow = "1.0.75"
assert_fs = "1.1.2"
tokio = { version = "1.34.0", features = ["full"] }
tracing-subscriber = "0.3.18"
tokio = { version = "1.34.0", features = ["full"] }
tokio-test = "0.4.4"

[features]
Expand Down
108 changes: 97 additions & 11 deletions rig-core/src/agent.rs
Original file line number Diff line number Diff line change
Expand Up @@ -106,20 +106,21 @@
//! let response = agent.prompt("What does \"glarb-glarb\" mean?").await
//! .expect("Failed to prompt the agent");
//! ```
use std::collections::HashMap;
use std::{collections::HashMap, future::Future, pin::Pin, sync::Arc};

use futures::{stream, StreamExt, TryStreamExt};
use mcp_client_rs::{client::Client, CallToolResult, MessageContent};

use crate::{
completion::{
Chat, Completion, CompletionError, CompletionModel, CompletionRequestBuilder,
CompletionResponse, Document, Message, ModelChoice, Prompt, PromptError,
CompletionResponse, Document, Message, ModelChoice, Prompt, PromptError, ToolDefinition,
},
streaming::{
StreamingChat, StreamingCompletion, StreamingCompletionModel, StreamingPrompt,
StreamingResult,
},
tool::{Tool, ToolSet},
tool::{Tool, ToolDyn, ToolError, ToolSet},
vector_store::{VectorStoreError, VectorStoreIndexDyn},
};

Expand All @@ -146,8 +147,10 @@ use crate::{
pub struct Agent<M: CompletionModel> {
/// Completion model (e.g.: OpenAI's gpt-3.5-turbo-1106, Cohere's command-r)
model: M,
/// Cached preamble
cached_preamble: Option<Vec<String>>,
/// System prompt
preamble: String,
preamble: Vec<String>,
/// Context documents always available to the agent
static_context: Vec<Document>,
/// Tools that are always available to the agent (identified by their name)
Expand Down Expand Up @@ -239,6 +242,7 @@ impl<M: CompletionModel> Completion<M> for Agent<M> {
Ok(self
.model
.completion_request(prompt)
.cached_preamble(self.cached_preamble.clone())
.preamble(self.preamble.clone())
.messages(chat_history)
.documents([self.static_context.clone(), dynamic_context].concat())
Expand Down Expand Up @@ -300,8 +304,10 @@ impl<M: CompletionModel> Chat for Agent<M> {
pub struct AgentBuilder<M: CompletionModel> {
/// Completion model (e.g.: OpenAI's gpt-3.5-turbo-1106, Cohere's command-r)
model: M,
/// Cached preamble
cached_preamble: Option<Vec<String>>,
/// System prompt
preamble: Option<String>,
preamble: Option<Vec<String>>,
/// Context documents always available to the agent
static_context: Vec<Document>,
/// Tools that are always available to the agent (by name)
Expand All @@ -324,6 +330,7 @@ impl<M: CompletionModel> AgentBuilder<M> {
pub fn new(model: M) -> Self {
Self {
model,
cached_preamble: None,
preamble: None,
static_context: vec![],
static_tools: vec![],
Expand All @@ -336,19 +343,25 @@ impl<M: CompletionModel> AgentBuilder<M> {
}
}

pub fn cached_preamble(mut self, cached_preamble: Vec<String>) -> Self {
self.cached_preamble = Some(cached_preamble);
self
}

/// Set the system prompt.
///
/// Replaces any previously set preamble with a single-element list
/// containing `preamble`; use [`append_preamble`] to add further blocks.
pub fn preamble(mut self, preamble: &str) -> Self {
    // The preamble is stored as a Vec<String> so providers that support
    // multiple system blocks (e.g. Anthropic) can emit one block per entry.
    self.preamble = Some(vec![preamble.into()]);
    self
}

/// Append to the preamble of the agent.
///
/// Pushes `doc` as an additional system-prompt block, creating the
/// preamble list if none was set yet.
pub fn append_preamble(mut self, doc: &str) -> Self {
    // Push in place instead of rebuilding the Vec: the previous
    // `if let … to_vec()` form cloned the whole vector on every append.
    self.preamble.get_or_insert_with(Vec::new).push(doc.into());
    self
}

Expand Down Expand Up @@ -395,6 +408,13 @@ impl<M: CompletionModel> AgentBuilder<M> {
self
}

/// Register a tool advertised by an MCP server, making it statically
/// available to the agent (its name is added to the static tool list).
pub fn mcp_tool(mut self, tool: mcp_client_rs::Tool, client: Arc<Client>) -> Self {
    // Record the name before `tool` is moved into the adapter.
    self.static_tools.push(tool.name.clone());
    self.tools.add_tool(MCPTool::from_mcp_server(tool, client));
    self
}

/// Set the temperature of the model
pub fn temperature(mut self, temperature: f64) -> Self {
self.temperature = Some(temperature);
Expand All @@ -417,6 +437,7 @@ impl<M: CompletionModel> AgentBuilder<M> {
pub fn build(self) -> Agent<M> {
Agent {
model: self.model,
cached_preamble: self.cached_preamble,
preamble: self.preamble.unwrap_or_default(),
static_context: self.static_context,
static_tools: self.static_tools,
Expand Down Expand Up @@ -460,3 +481,68 @@ impl<M: StreamingCompletionModel> StreamingChat for Agent<M> {
.await
}
}
/// Adapter exposing a tool served by an MCP (Model Context Protocol)
/// server as a rig tool; invocations are forwarded to the server.
pub struct MCPTool {
    // Shared handle to the MCP client used to call the tool remotely.
    client: Arc<Client>,
    // The tool's schema (name, description, input schema) as advertised
    // by the MCP server.
    definition: mcp_client_rs::Tool,
}

impl MCPTool {
pub fn from_mcp_server(definition: mcp_client_rs::Tool, client: Arc<Client>) -> Self {
Self { client, definition }
}
}

#[derive(Debug, thiserror::Error)]
#[error("MCP tool error")]
pub struct MCPToolError(String);

impl ToolDyn for MCPTool {
    /// The tool's name as advertised by the MCP server.
    fn name(&self) -> String {
        self.definition.name.clone()
    }

    /// Translate the MCP tool schema into rig's [`ToolDefinition`].
    /// The `_prompt` argument is ignored: the definition is static.
    fn definition(
        &self,
        _prompt: String,
    ) -> Pin<Box<dyn Future<Output = ToolDefinition> + Send + Sync + '_>> {
        Box::pin(async move {
            ToolDefinition {
                name: self.definition.name.clone(),
                description: self.definition.description.clone(),
                parameters: self.definition.input_schema.clone(),
            }
        })
    }

    /// Forward a call to the MCP server and flatten the textual content of
    /// the result into a single string (non-text content is dropped).
    fn call(
        &self,
        args: String,
    ) -> Pin<Box<dyn Future<Output = Result<String, ToolError>> + Send + '_>> {
        let client = self.client.clone();
        let name = self.definition.name.clone();
        // Parse the JSON arguments directly — the previous version cloned
        // `args` into a temporary for no reason. Malformed input silently
        // degrades to JSON null; NOTE(review): consider surfacing a
        // ToolError here instead — confirm desired behavior.
        let args: serde_json::Value = serde_json::from_str(&args).unwrap_or_default();
        Box::pin(async move {
            let result: CallToolResult = client.call_tool(&name, args).await.map_err(|e| {
                ToolError::ToolCallError(Box::new(MCPToolError(format!(
                    "Tool returned an error: {}",
                    e
                ))))
            })?;
            // The MCP result carries an explicit error flag alongside content.
            if result.is_error {
                return Err(ToolError::ToolCallError(Box::new(MCPToolError(
                    "Tool returned an error".to_string(),
                ))));
            }
            Ok(result
                .content
                .into_iter()
                .map(|c| match c {
                    MessageContent::Text { text } => text,
                    // Non-text content (images, resources, …) is ignored.
                    _ => "".to_string(),
                })
                .collect::<Vec<_>>()
                .join(""))
        })
    }
}
20 changes: 16 additions & 4 deletions rig-core/src/completion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -250,8 +250,10 @@ pub trait CompletionModel: Clone + Send + Sync {
pub struct CompletionRequest {
/// The prompt to be sent to the completion model provider
pub prompt: String,
/// The preamble to be sent to the completion model provider
pub preamble: Option<String>,
/// The preambles to be sent to the completion model provider
pub preamble: Option<Vec<String>>,
/// The cached preamble to be sent to the completion model provider
pub cached_preamble: Option<Vec<String>>,
/// The chat history to be sent to the completion model provider
pub chat_history: Vec<Message>,
/// The documents to be sent to the completion model provider
Expand Down Expand Up @@ -331,7 +333,8 @@ impl CompletionRequest {
pub struct CompletionRequestBuilder<M: CompletionModel> {
model: M,
prompt: String,
preamble: Option<String>,
preamble: Option<Vec<String>>,
cached_preamble: Option<Vec<String>>,
chat_history: Vec<Message>,
documents: Vec<Document>,
tools: Vec<ToolDefinition>,
Expand All @@ -346,6 +349,7 @@ impl<M: CompletionModel> CompletionRequestBuilder<M> {
model,
prompt,
preamble: None,
cached_preamble: None,
chat_history: Vec::new(),
documents: Vec::new(),
tools: Vec::new(),
Expand All @@ -356,7 +360,7 @@ impl<M: CompletionModel> CompletionRequestBuilder<M> {
}

/// Sets the preamble for the completion request.
pub fn preamble(mut self, preamble: String) -> Self {
pub fn preamble(mut self, preamble: Vec<String>) -> Self {
self.preamble = Some(preamble);
self
}
Expand Down Expand Up @@ -453,11 +457,18 @@ impl<M: CompletionModel> CompletionRequestBuilder<M> {
self
}

/// Sets the cached preamble for the completion request.
pub fn cached_preamble(mut self, cached_preamble: Option<Vec<String>>) -> Self {
self.cached_preamble = cached_preamble;
self
}

/// Builds the completion request.
pub fn build(self) -> CompletionRequest {
CompletionRequest {
prompt: self.prompt,
preamble: self.preamble,
cached_preamble: self.cached_preamble,
chat_history: self.chat_history,
documents: self.documents,
tools: self.tools,
Expand Down Expand Up @@ -536,6 +547,7 @@ mod tests {
let request = CompletionRequest {
prompt: "What is the capital of France?".to_string(),
preamble: None,
cached_preamble: None,
chat_history: Vec::new(),
documents: vec![doc1, doc2],
tools: Vec::new(),
Expand Down
69 changes: 56 additions & 13 deletions rig-core/src/providers/anthropic/completion.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
//! Anthropic completion api implementation

use std::iter;

use crate::{
completion::{self, CompletionError},
json_utils,
Expand Down Expand Up @@ -34,7 +32,7 @@ pub const ANTHROPIC_VERSION_2023_01_01: &str = "2023-01-01";
pub const ANTHROPIC_VERSION_2023_06_01: &str = "2023-06-01";
pub const ANTHROPIC_VERSION_LATEST: &str = ANTHROPIC_VERSION_2023_06_01;

#[derive(Debug, Deserialize)]
#[derive(Debug, Deserialize, Clone)]
pub struct CompletionResponse {
pub content: Vec<Content>,
pub id: String,
Expand All @@ -45,7 +43,7 @@ pub struct CompletionResponse {
pub usage: Usage,
}

#[derive(Debug, Deserialize, Serialize)]
#[derive(Debug, Deserialize, Serialize, Clone)]
#[serde(untagged)]
pub enum Content {
String(String),
Expand All @@ -61,7 +59,7 @@ pub enum Content {
},
}

#[derive(Debug, Deserialize, Serialize)]
#[derive(Debug, Deserialize, Serialize, Clone)]
pub struct Usage {
pub input_tokens: u64,
pub cache_read_input_tokens: Option<u64>,
Expand Down Expand Up @@ -221,20 +219,50 @@ impl completion::CompletionModel for CompletionModel {
"`max_tokens` must be set for Anthropic".into(),
));
};
let mut system = Vec::new();

if let Some(cached_preamble) = completion_request.cached_preamble {
for (index, preamble) in cached_preamble.iter().enumerate() {
let mut preamble_json = json!({
"type": "text",
"text": preamble,
});
if index == cached_preamble.len() - 1 {
json_utils::merge_inplace(
&mut preamble_json,
json!({
"cache_control": {
"type": "ephemeral"
}
}),
);
}
system.push(preamble_json);
}
}

if let Some(preamble) = completion_request.preamble {
for preamble in preamble {
system.push(json!({
"type": "text",
"text": preamble
}));
}
}

let mut request = json!({
"model": self.model,
"messages": completion_request
.chat_history
.into_iter()
.map(Message::from)
.chain(iter::once(Message {
.chain((!prompt_with_context.is_empty()).then(|| Message {
role: "user".to_owned(),
content: prompt_with_context,
}))
.collect::<Vec<_>>(),
"max_tokens": max_tokens,
"system": completion_request.preamble.unwrap_or("".to_string()),
"system": system,
});

if let Some(temperature) = completion_request.temperature {
Expand All @@ -247,13 +275,28 @@ impl completion::CompletionModel for CompletionModel {
json!({
"tools": completion_request
.tools
.clone()
.into_iter()
.map(|tool| ToolDefinition {
name: tool.name,
description: Some(tool.description),
input_schema: tool.parameters,
})
.collect::<Vec<_>>(),
.enumerate()
.map(|(index, tool)| {
let mut tool_json = json!({
"name": tool.name,
"description": tool.description,
"input_schema": tool.parameters,
});
if index == completion_request.tools.len() - 1 {
json_utils::merge_inplace(
&mut tool_json,
json!({
"cache_control": {
"type": "ephemeral"
}
}),
);
}
tool_json
})
.collect::<Vec<_>>(),
"tool_choice": ToolChoice::Auto,
}),
);
Expand Down
Loading
Loading