diff --git a/examples/README.md b/examples/README.md index 2b358aa10..b1985f147 100644 --- a/examples/README.md +++ b/examples/README.md @@ -70,7 +70,7 @@ see [servers/README.md](servers/README.md) # Integration -- [Rig](rig-integration) A stream chatbot with rig +- [Rig](https://github.com/0xPlaygrounds/rig/blob/main/rig/rig-core/examples/rmcp.rs) A stream chatbot with rig - [Simple Chat Client](simple-chat-client) A simple chat client implementation using the Model Context Protocol (MCP) SDK. # WASI diff --git a/examples/rig-integration/Cargo.toml b/examples/rig-integration/Cargo.toml deleted file mode 100644 index cfed3c4c1..000000000 --- a/examples/rig-integration/Cargo.toml +++ /dev/null @@ -1,34 +0,0 @@ -[package] -name = "rig-integration" -edition = { workspace = true } -version = { workspace = true } -authors = { workspace = true } -license = { workspace = true } -repository = { workspace = true } -description = { workspace = true } -keywords = { workspace = true } -homepage = { workspace = true } -categories = { workspace = true } -readme = { workspace = true } -publish = false - -[dependencies] -rig-core = "0.32.0" -tokio = { version = "1", features = ["full"] } -rmcp = { workspace = true, features = [ - "client", - "transport-child-process", - "transport-streamable-http-client-reqwest" -] } -anyhow = "1.0" -serde_json = "1" -serde = { version = "1", features = ["derive"] } -toml = "1.0" -futures = "0.3" -tracing = "0.1" -tracing-subscriber = { version = "0.3", features = [ - "env-filter", - "std", - "fmt", -] } -tracing-appender = "0.2" diff --git a/examples/rig-integration/config.toml b/examples/rig-integration/config.toml deleted file mode 100644 index 1affe7aaa..000000000 --- a/examples/rig-integration/config.toml +++ /dev/null @@ -1,10 +0,0 @@ -deepseek_key = "" -cohere_key = "" - -[mcp] - -[[mcp.server]] -name = "git" -protocol = "stdio" -command = "uvx" -args = ["mcp-server-git"] diff --git a/examples/rig-integration/src/chat.rs b/examples/rig-integration/src/chat.rs deleted file mode 100644 index 13d28ab56..000000000 --- a/examples/rig-integration/src/chat.rs +++ /dev/null @@ -1,134 +0,0 @@ -use futures::StreamExt; -use rig::{ - agent::{Agent, MultiTurnStreamItem}, - completion::CompletionModel, - message::{Message, Text}, - streaming::{StreamedAssistantContent, StreamingChat}, -}; -use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader, BufWriter}; - -pub async fn cli_chatbot(chatbot: Agent) -> anyhow::Result<()> -where - M: CompletionModel + 'static, - M::StreamingResponse: Send, -{ - let mut chat_log = vec![]; - - let mut output = BufWriter::new(tokio::io::stdout()); - let mut input = BufReader::new(tokio::io::stdin()); - output.write_all(b"Enter :q to quit\n").await?; - loop { - output.write_all(b"\x1b[32muser>\x1b[0m ").await?; - // Flush stdout to ensure the prompt appears before input - output.flush().await?; - let mut input_buf = String::new(); - input.read_line(&mut input_buf).await?; - // Remove the newline character from the input - let input = input_buf.trim(); - // Check for a command to exit - if input == ":q" { - break; - } - - tracing::info!(%input); - chat_log.push(Message::user(input)); - - let mut response = chatbot.stream_chat(input, chat_log.clone()).await; - stream_output_agent_start(&mut output).await?; - let mut message_buf = String::new(); - - while let Some(message) = response.next().await { - match message { - Ok(MultiTurnStreamItem::StreamAssistantItem(StreamedAssistantContent::Text( - Text { text }, - ))) => { - message_buf.push_str(&text); - output_agent(&text, &mut output).await?; - } - Ok(MultiTurnStreamItem::StreamAssistantItem( - StreamedAssistantContent::ToolCall { tool_call, .. }, - )) => { - let name = &tool_call.function.name; - let arguments = &tool_call.function.arguments; - stream_output_toolcall( - format!("Calling tool: {name} with args: {arguments}"), - &mut output, - ) - .await?; - } - Ok(MultiTurnStreamItem::StreamUserItem(user_content)) => { - // Tool results are streamed back as user items - stream_output_toolcall(format!("Tool result: {:?}", user_content), &mut output) - .await?; - } - Ok(MultiTurnStreamItem::FinalResponse(final_response)) => { - tracing::info!("Final response received: {:?}", final_response); - } - Ok(_) => { - // Handle other stream items (reasoning, deltas, etc.) - } - Err(error) => { - output_error(error, &mut output).await?; - } - } - } - - chat_log.push(Message::assistant(message_buf)); - stream_output_agent_finished(&mut output).await?; - } - - Ok(()) -} - -pub async fn output_error( - e: impl std::fmt::Display, - output: &mut BufWriter, -) -> std::io::Result<()> { - output - .write_all(b"\x1b[1;31m\xE2\x9D\x8C ERROR: \x1b[0m") - .await?; - output.write_all(e.to_string().as_bytes()).await?; - output.write_all(b"\n").await?; - output.flush().await?; - Ok(()) -} - -pub async fn output_agent( - content: impl std::fmt::Display, - output: &mut BufWriter, -) -> std::io::Result<()> { - output.write_all(content.to_string().as_bytes()).await?; - output.flush().await?; - Ok(()) -} - -pub async fn stream_output_toolcall( - content: impl std::fmt::Display, - output: &mut BufWriter, -) -> std::io::Result<()> { - output - .write_all(b"\x1b[1;33m\xF0\x9F\x9B\xA0 Tool Call: \x1b[0m") - .await?; - output.write_all(content.to_string().as_bytes()).await?; - output.write_all(b"\n").await?; - output.flush().await?; - Ok(()) -} - -pub async fn stream_output_agent_start( - output: &mut BufWriter, -) -> std::io::Result<()> { - output - .write_all(b"\x1b[1;34m\xF0\x9F\xA4\x96 Agent: \x1b[0m") - .await?; - output.flush().await?; - Ok(()) -} - -pub async fn stream_output_agent_finished( - output: &mut BufWriter, -) -> std::io::Result<()> { - output.write_all(b"\n").await?; - output.flush().await?; - Ok(()) -} diff --git a/examples/rig-integration/src/config.rs b/examples/rig-integration/src/config.rs deleted file mode 100644 index 387a4f686..000000000 --- a/examples/rig-integration/src/config.rs +++ /dev/null @@ -1,20 +0,0 @@ -use std::path::Path; - -use serde::{Deserialize, Serialize}; - -pub mod mcp; - -#[derive(Debug, Deserialize, Serialize)] -pub struct Config { - pub mcp: mcp::McpConfig, - pub deepseek_key: Option, - pub cohere_key: Option, -} - -impl Config { - pub async fn retrieve(path: impl AsRef) -> anyhow::Result { - let content = tokio::fs::read_to_string(path).await?; - let config: Self = toml::from_str(&content)?; - Ok(config) - } -} diff --git a/examples/rig-integration/src/config/mcp.rs b/examples/rig-integration/src/config/mcp.rs deleted file mode 100644 index 45e4c23ca..000000000 --- a/examples/rig-integration/src/config/mcp.rs +++ /dev/null @@ -1,83 +0,0 @@ -use std::{collections::HashMap, process::Stdio}; - -use rmcp::{RoleClient, ServiceExt, service::RunningService, transport::ConfigureCommandExt}; -use serde::{Deserialize, Serialize}; - -use crate::mcp_adaptor::McpManager; -#[derive(Debug, Serialize, Deserialize, Clone)] -pub struct McpServerConfig { - name: String, - #[serde(flatten)] - transport: McpServerTransportConfig, -} - -#[derive(Debug, Serialize, Deserialize, Clone)] -#[serde(tag = "protocol", rename_all = "lowercase")] -pub enum McpServerTransportConfig { - Streamable { - url: String, - }, - Stdio { - command: String, - #[serde(default)] - args: Vec, - #[serde(default)] - envs: HashMap, - }, -} - -#[derive(Debug, Serialize, Deserialize)] -pub struct McpConfig { - server: Vec, -} - -impl McpConfig { - pub async fn create_manager(&self) -> anyhow::Result { - let mut clients = HashMap::new(); - let mut task_set = tokio::task::JoinSet::>::new(); - for server in &self.server { - let server = server.clone(); - task_set.spawn(async move { - let client = server.transport.start().await?; - anyhow::Result::Ok((server.name.clone(), client)) - }); - } - let start_up_result = task_set.join_all().await; - for result in start_up_result { - match result { - Ok((name, client)) => { - clients.insert(name, client); - } - Err(e) => { - eprintln!("Failed to start server: {:?}", e); - } - } - } - Ok(McpManager { clients }) - } -} - -impl McpServerTransportConfig { - pub async fn start(&self) -> anyhow::Result> { - let client = match self { - McpServerTransportConfig::Streamable { url } => { - let transport = - rmcp::transport::StreamableHttpClientTransport::from_uri(url.to_string()); - ().serve(transport).await? - } - McpServerTransportConfig::Stdio { - command, - args, - envs, - } => { - let transport = rmcp::transport::TokioChildProcess::new( - tokio::process::Command::new(command).configure(|cmd| { - cmd.args(args).envs(envs).stderr(Stdio::null()); - }), - )?; - ().serve(transport).await? - } - }; - Ok(client) - } -} diff --git a/examples/rig-integration/src/main.rs b/examples/rig-integration/src/main.rs deleted file mode 100644 index c9fe81190..000000000 --- a/examples/rig-integration/src/main.rs +++ /dev/null @@ -1,69 +0,0 @@ -use rig::{ - client::{CompletionClient, ProviderClient}, - embeddings::EmbeddingsBuilder, - providers::{cohere, deepseek}, - vector_store::in_memory_store::InMemoryVectorStore, -}; -use tracing_appender::rolling::{RollingFileAppender, Rotation}; -pub mod chat; -pub mod config; -pub mod mcp_adaptor; - -#[tokio::main] -async fn main() -> anyhow::Result<()> { - let file_appender = RollingFileAppender::new( - Rotation::DAILY, - "logs", - format!("{}.log", env!("CARGO_CRATE_NAME")), - ); - tracing_subscriber::fmt() - .with_env_filter( - tracing_subscriber::EnvFilter::from_default_env() - .add_directive(tracing::Level::INFO.into()), - ) - .with_writer(file_appender) - .with_file(false) - .with_ansi(false) - .init(); - - let config = config::Config::retrieve("config.toml").await?; - let deepseek_client = { - if let Some(key) = config.deepseek_key { - deepseek::Client::new(&key)? - } else { - deepseek::Client::from_env() - } - }; - let cohere_client = { - if let Some(key) = config.cohere_key { - cohere::Client::new(&key)? - } else { - cohere::Client::from_env() - } - }; - let mcp_manager = config.mcp.create_manager().await?; - tracing::info!( - "MCP Manager created, {} servers started", - mcp_manager.clients.len() - ); - let tool_set = mcp_manager.get_tool_set().await?; - let embedding_model = - cohere_client.embedding_model(cohere::EMBED_MULTILINGUAL_V3, "search_document"); - let embeddings = EmbeddingsBuilder::new(embedding_model.clone()) - .documents(tool_set.schemas()?)? - .build() - .await?; - let store = InMemoryVectorStore::from_documents_with_id_f(embeddings, |f| { - tracing::info!("store tool {}", f.name); - f.name.clone() - }); - let index = store.index(embedding_model); - let dpsk = deepseek_client - .agent(deepseek::DEEPSEEK_CHAT) - .dynamic_tools(4, index, tool_set) - .build(); - - chat::cli_chatbot(dpsk).await?; - - Ok(()) -} diff --git a/examples/rig-integration/src/mcp_adaptor.rs b/examples/rig-integration/src/mcp_adaptor.rs deleted file mode 100644 index 41de15768..000000000 --- a/examples/rig-integration/src/mcp_adaptor.rs +++ /dev/null @@ -1,119 +0,0 @@ -use std::collections::HashMap; - -use rig::tool::{ToolDyn as RigTool, ToolEmbeddingDyn, ToolSet}; -use rmcp::{ - RoleClient, - model::{CallToolRequestParams, CallToolResult, Tool as McpTool}, - service::{RunningService, ServerSink}, -}; - -pub struct McpToolAdaptor { - tool: McpTool, - server: ServerSink, -} - -impl RigTool for McpToolAdaptor { - fn name(&self) -> String { - self.tool.name.to_string() - } - - fn definition( - &self, - _prompt: String, - ) -> std::pin::Pin + Send + '_>> { - Box::pin(std::future::ready(rig::completion::ToolDefinition { - name: self.name(), - description: self - .tool - .description - .as_deref() - .unwrap_or_default() - .to_string(), - parameters: self.tool.schema_as_json_value(), - })) - } - - fn call( - &self, - args: String, - ) -> std::pin::Pin> + Send + '_>> - { - let server = self.server.clone(); - Box::pin(async move { - let call_mcp_tool_result = server - .call_tool( - CallToolRequestParams::new(self.tool.name.clone()).with_arguments( - serde_json::from_str(&args).map_err(rig::tool::ToolError::JsonError)?, - ), - ) - .await - .inspect(|result| tracing::info!(?result)) - .inspect_err(|error| tracing::error!(%error)) - .map_err(|e| rig::tool::ToolError::ToolCallError(Box::new(e)))?; - - Ok(convert_mcp_call_tool_result_to_string(call_mcp_tool_result)) - }) - } -} - -impl ToolEmbeddingDyn for McpToolAdaptor { - fn context(&self) -> serde_json::Result { - serde_json::to_value(self.tool.clone()) - } - - fn embedding_docs(&self) -> Vec { - vec![ - self.tool - .description - .as_deref() - .unwrap_or_default() - .to_string(), - ] - } -} - -pub struct McpManager { - pub clients: HashMap>, -} - -impl McpManager { - pub async fn get_tool_set(&self) -> anyhow::Result { - let mut tool_set = ToolSet::default(); - let mut task = tokio::task::JoinSet::>::new(); - for client in self.clients.values() { - let server = client.peer().clone(); - task.spawn(get_tool_set(server)); - } - let results = task.join_all().await; - for result in results { - match result { - Err(e) => { - tracing::error!(error = %e, "Failed to get tool set"); - } - Ok(tools) => { - tool_set.add_tools(tools); - } - } - } - Ok(tool_set) - } -} - -pub fn convert_mcp_call_tool_result_to_string(result: CallToolResult) -> String { - serde_json::to_string(&result).unwrap() -} - -pub async fn get_tool_set(server: ServerSink) -> anyhow::Result { - let tools = server.list_all_tools().await?; - let mut tool_builder = ToolSet::builder(); - for tool in tools { - tracing::info!("get tool: {}", tool.name); - let adaptor = McpToolAdaptor { - tool: tool.clone(), - server: server.clone(), - }; - tool_builder = tool_builder.dynamic_tool(adaptor); - } - let tool_set = tool_builder.build(); - Ok(tool_set) -}