Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 48 additions & 20 deletions crates/rmcp/src/transport/streamable_http_server/tower.rs
Original file line number Diff line number Diff line change
Expand Up @@ -602,28 +602,56 @@ where
// JSON-direct mode: await the single response and return as
// application/json, eliminating SSE framing overhead.
// Allowed by MCP Streamable HTTP spec (2025-06-18).
//
// Tools may emit progress notifications before their
// final response. In JSON-direct mode there is no
// secondary channel for those notifications, so keep
// draining until we receive the terminal response/error
// message that should satisfy the HTTP request.
let cancel = self.config.cancellation_token.child_token();
match tokio::select! {
res = receiver.recv() => res,
_ = cancel.cancelled() => None,
} {
Some(message) => {
tracing::trace!(?message);
let body = serde_json::to_vec(&message).map_err(|e| {
internal_error_response("serialize json response")(e)
})?;
Ok(Response::builder()
.status(http::StatusCode::OK)
.header(http::header::CONTENT_TYPE, JSON_MIME_TYPE)
.body(Full::new(Bytes::from(body)).boxed())
.expect("valid response"))
loop {
match tokio::select! {
res = receiver.recv() => res,
_ = cancel.cancelled() => None,
} {
Some(
message @ (crate::model::ServerJsonRpcMessage::Response(_)
| crate::model::ServerJsonRpcMessage::Error(_)),
) => {
tracing::trace!(?message);
let body = serde_json::to_vec(&message).map_err(|e| {
internal_error_response("serialize json response")(e)
})?;
break Ok(Response::builder()
.status(http::StatusCode::OK)
.header(http::header::CONTENT_TYPE, JSON_MIME_TYPE)
.body(Full::new(Bytes::from(body)).boxed())
.expect("valid response"));
}
Some(crate::model::ServerJsonRpcMessage::Notification(
notification,
)) => {
tracing::debug!(
?notification,
"dropping server notification while awaiting JSON response"
);
}
Some(crate::model::ServerJsonRpcMessage::Request(request)) => {
tracing::warn!(
?request,
"cannot deliver server request over JSON-direct response"
);
break Err(unexpected_message_response("response or error"));
}
None => {
break Err(internal_error_response("empty response")(
std::io::Error::new(
std::io::ErrorKind::UnexpectedEof,
"no response message received from handler",
),
));
}
}
None => Err(internal_error_response("empty response")(
std::io::Error::new(
std::io::ErrorKind::UnexpectedEof,
"no response message received from handler",
),
)),
}
} else {
// SSE mode (default): original behaviour preserved unchanged
Expand Down
131 changes: 129 additions & 2 deletions crates/rmcp/tests/test_streamable_http_json_response.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,25 @@
#![cfg(not(feature = "local"))]
use rmcp::transport::streamable_http_server::{
StreamableHttpServerConfig, StreamableHttpService, session::local::LocalSessionManager,
use std::time::Duration;

use futures::future::BoxFuture;
use rmcp::{
ServerHandler, ServiceExt,
handler::server::{
router::tool::ToolRoute,
tool::{ToolCallContext, ToolRouter, schema_for_type},
},
model::{
CallToolRequestParams, CallToolResult, Content, ProgressNotificationParam,
ServerCapabilities, ServerInfo, Tool,
},
tool_handler,
transport::{
StreamableHttpClientTransport,
streamable_http_client::StreamableHttpClientTransportConfig,
streamable_http_server::{
StreamableHttpServerConfig, StreamableHttpService, session::local::LocalSessionManager,
},
},
};
use tokio_util::sync::CancellationToken;

Expand Down Expand Up @@ -76,6 +95,114 @@ async fn stateless_json_response_returns_application_json() -> anyhow::Result<()
Ok(())
}

#[derive(Debug, Default, serde::Deserialize, schemars::JsonSchema)]
struct EmptyArgs {}

#[derive(Debug, Clone)]
struct ProgressToolServer {
tool_router: ToolRouter<Self>,
}

impl ProgressToolServer {
fn new() -> Self {
Self {
tool_router: ToolRouter::new().with_route(ToolRoute::new_dyn(
Tool::new(
"progress_then_result",
"Emit a progress notification before returning",
schema_for_type::<EmptyArgs>(),
),
|context: ToolCallContext<'_, Self>| -> BoxFuture<'_, _> {
Box::pin(async move {
let Some(progress_token) =
context.request_context.meta.get_progress_token()
else {
return Err(rmcp::ErrorData::invalid_params(
"missing progress token",
None,
));
};

context
.request_context
.peer
.notify_progress(ProgressNotificationParam::new(progress_token, 1.0))
.await
.map_err(|err| {
rmcp::ErrorData::internal_error(
format!("failed to send progress notification: {err}"),
None,
)
})?;

Ok(CallToolResult::success(vec![Content::text("done")]))
})
},
)),
}
}
}

#[tool_handler]
impl ServerHandler for ProgressToolServer {
fn get_info(&self) -> ServerInfo {
ServerInfo::new(ServerCapabilities::builder().enable_tools().build())
}
}

#[tokio::test]
async fn stateless_json_response_waits_for_terminal_tool_response() -> anyhow::Result<()> {
let ct = CancellationToken::new();
let service: StreamableHttpService<ProgressToolServer, LocalSessionManager> =
StreamableHttpService::new(
|| Ok(ProgressToolServer::new()),
Default::default(),
StreamableHttpServerConfig {
stateful_mode: false,
json_response: true,
sse_keep_alive: None,
cancellation_token: ct.child_token(),
..Default::default()
},
);

let router = axum::Router::new().nest_service("/mcp", service);
let tcp_listener = tokio::net::TcpListener::bind("127.0.0.1:0").await?;
let addr = tcp_listener.local_addr()?;

let handle = tokio::spawn({
let ct = ct.clone();
async move {
let _ = axum::serve(tcp_listener, router)
.with_graceful_shutdown(async move { ct.cancelled_owned().await })
.await;
}
});

let transport = StreamableHttpClientTransport::from_config(
StreamableHttpClientTransportConfig::with_uri(format!("http://{addr}/mcp")),
);
let client = ().serve(transport).await?;

let result = tokio::time::timeout(
Duration::from_secs(3),
client.call_tool(CallToolRequestParams::new("progress_then_result")),
)
.await??;

let text = result
.content
.first()
.and_then(|content| content.raw.as_text())
.map(|text| text.text.as_str());
assert_eq!(text, Some("done"));

let _ = client.cancel().await;
ct.cancel();
handle.await?;
Ok(())
}

#[tokio::test]
async fn stateless_sse_mode_default_unchanged() -> anyhow::Result<()> {
let ct = CancellationToken::new();
Expand Down