diff --git a/crates/rmcp/src/transport/streamable_http_server/tower.rs b/crates/rmcp/src/transport/streamable_http_server/tower.rs index 0130467d..53a65c87 100644 --- a/crates/rmcp/src/transport/streamable_http_server/tower.rs +++ b/crates/rmcp/src/transport/streamable_http_server/tower.rs @@ -602,28 +602,56 @@ where // JSON-direct mode: await the single response and return as // application/json, eliminating SSE framing overhead. // Allowed by MCP Streamable HTTP spec (2025-06-18). + // + // Tools may emit progress notifications before their + // final response. In JSON-direct mode there is no + // secondary channel for those notifications, so keep + // draining until we receive the terminal response/error + // message that should satisfy the HTTP request. let cancel = self.config.cancellation_token.child_token(); - match tokio::select! { - res = receiver.recv() => res, - _ = cancel.cancelled() => None, - } { - Some(message) => { - tracing::trace!(?message); - let body = serde_json::to_vec(&message).map_err(|e| { - internal_error_response("serialize json response")(e) - })?; - Ok(Response::builder() - .status(http::StatusCode::OK) - .header(http::header::CONTENT_TYPE, JSON_MIME_TYPE) - .body(Full::new(Bytes::from(body)).boxed()) - .expect("valid response")) + loop { + match tokio::select! { + res = receiver.recv() => res, + _ = cancel.cancelled() => None, + } { + Some( + message @ (crate::model::ServerJsonRpcMessage::Response(_) + | crate::model::ServerJsonRpcMessage::Error(_)), + ) => { + tracing::trace!(?message); + let body = serde_json::to_vec(&message).map_err(|e| { + internal_error_response("serialize json response")(e) + })?; + break Ok(Response::builder() + .status(http::StatusCode::OK) + .header(http::header::CONTENT_TYPE, JSON_MIME_TYPE) + .body(Full::new(Bytes::from(body)).boxed()) + .expect("valid response")); + } + Some(crate::model::ServerJsonRpcMessage::Notification( + notification, + )) => { + tracing::debug!( + ?notification, + "dropping server notification while awaiting JSON response" + ); + } + Some(crate::model::ServerJsonRpcMessage::Request(request)) => { + tracing::warn!( + ?request, + "cannot deliver server request over JSON-direct response" + ); + break Err(unexpected_message_response("response or error")); + } + None => { + break Err(internal_error_response("empty response")( + std::io::Error::new( + std::io::ErrorKind::UnexpectedEof, + "no response message received from handler", + ), + )); + } } - None => Err(internal_error_response("empty response")( - std::io::Error::new( - std::io::ErrorKind::UnexpectedEof, - "no response message received from handler", - ), - )), } } else { // SSE mode (default): original behaviour preserved unchanged diff --git a/crates/rmcp/tests/test_streamable_http_json_response.rs b/crates/rmcp/tests/test_streamable_http_json_response.rs index b023acd0..ec5be224 100644 --- a/crates/rmcp/tests/test_streamable_http_json_response.rs +++ b/crates/rmcp/tests/test_streamable_http_json_response.rs @@ -1,6 +1,25 @@ #![cfg(not(feature = "local"))] -use rmcp::transport::streamable_http_server::{ - StreamableHttpServerConfig, StreamableHttpService, session::local::LocalSessionManager, +use std::time::Duration; + +use futures::future::BoxFuture; +use rmcp::{ + ServerHandler, ServiceExt, + handler::server::{ + router::tool::ToolRoute, + tool::{ToolCallContext, ToolRouter, schema_for_type}, + }, + model::{ + CallToolRequestParams, CallToolResult, Content, ProgressNotificationParam, + ServerCapabilities, ServerInfo, Tool, + }, + tool_handler, + transport::{ + StreamableHttpClientTransport, + streamable_http_client::StreamableHttpClientTransportConfig, + streamable_http_server::{ + StreamableHttpServerConfig, StreamableHttpService, session::local::LocalSessionManager, + }, + }, }; use tokio_util::sync::CancellationToken; @@ -76,6 +95,114 @@ async fn stateless_json_response_returns_application_json() -> anyhow::Result<() Ok(()) } +#[derive(Debug, Default, serde::Deserialize, schemars::JsonSchema)] +struct EmptyArgs {} + +#[derive(Debug, Clone)] +struct ProgressToolServer { + tool_router: ToolRouter, +} + +impl ProgressToolServer { + fn new() -> Self { + Self { + tool_router: ToolRouter::new().with_route(ToolRoute::new_dyn( + Tool::new( + "progress_then_result", + "Emit a progress notification before returning", + schema_for_type::(), + ), + |context: ToolCallContext<'_, Self>| -> BoxFuture<'_, _> { + Box::pin(async move { + let Some(progress_token) = + context.request_context.meta.get_progress_token() + else { + return Err(rmcp::ErrorData::invalid_params( + "missing progress token", + None, + )); + }; + + context + .request_context + .peer + .notify_progress(ProgressNotificationParam::new(progress_token, 1.0)) + .await + .map_err(|err| { + rmcp::ErrorData::internal_error( + format!("failed to send progress notification: {err}"), + None, + ) + })?; + + Ok(CallToolResult::success(vec![Content::text("done")])) + }) + }, + )), + } + } +} + +#[tool_handler] +impl ServerHandler for ProgressToolServer { + fn get_info(&self) -> ServerInfo { + ServerInfo::new(ServerCapabilities::builder().enable_tools().build()) + } +} + +#[tokio::test] +async fn stateless_json_response_waits_for_terminal_tool_response() -> anyhow::Result<()> { + let ct = CancellationToken::new(); + let service: StreamableHttpService = + StreamableHttpService::new( + || Ok(ProgressToolServer::new()), + Default::default(), + StreamableHttpServerConfig { + stateful_mode: false, + json_response: true, + sse_keep_alive: None, + cancellation_token: ct.child_token(), + ..Default::default() + }, + ); + + let router = axum::Router::new().nest_service("/mcp", service); + let tcp_listener = tokio::net::TcpListener::bind("127.0.0.1:0").await?; + let addr = tcp_listener.local_addr()?; + + let handle = tokio::spawn({ + let ct = ct.clone(); + async move { + let _ = axum::serve(tcp_listener, router) + .with_graceful_shutdown(async move { ct.cancelled_owned().await }) + .await; + } + }); + + let transport = StreamableHttpClientTransport::from_config( + StreamableHttpClientTransportConfig::with_uri(format!("http://{addr}/mcp")), + ); + let client = ().serve(transport).await?; + + let result = tokio::time::timeout( + Duration::from_secs(3), + client.call_tool(CallToolRequestParams::new("progress_then_result")), + ) + .await??; + + let text = result + .content + .first() + .and_then(|content| content.raw.as_text()) + .map(|text| text.text.as_str()); + assert_eq!(text, Some("done")); + + let _ = client.cancel().await; + ct.cancel(); + handle.await?; + Ok(()) +} + #[tokio::test] async fn stateless_sse_mode_default_unchanged() -> anyhow::Result<()> { let ct = CancellationToken::new();