From b794ddd3e6cd4fc523e775b0cf03b97a8e78c76b Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Tue, 5 May 2026 06:05:49 +0000 Subject: [PATCH 1/2] feat: implement MCP progress notifications in MCPClient - Add `progress_callback: ProgressFnT | None` parameter to `MCPClient.__init__` - Thread callback through `call_tool_sync`, `call_tool_async`, and `_create_call_tool_coroutine` to `ClientSession.call_tool()` - Support per-call override: per-call callback takes precedence over instance callback - Export `ProgressFnT` from `strands.tools.mcp` package - Update existing tests to include `progress_callback` in call assertions - Add new tests for instance-level callback, per-call override, and default None behavior Agent-Logs-Url: https://github.com/joshwand/strands-sdk-python/sessions/7ab253b3-c748-48ed-9a8a-ae5c8508e938 Co-authored-by: joshwand <22531+joshwand@users.noreply.github.com> --- src/strands/tools/mcp/__init__.py | 4 +- src/strands/tools/mcp/mcp_client.py | 26 +++++++- tests/strands/tools/mcp/test_mcp_client.py | 70 ++++++++++++++++++---- 3 files changed, 86 insertions(+), 14 deletions(-) diff --git a/src/strands/tools/mcp/__init__.py b/src/strands/tools/mcp/__init__.py index 8d2c1daa2..8bd621a22 100644 --- a/src/strands/tools/mcp/__init__.py +++ b/src/strands/tools/mcp/__init__.py @@ -6,9 +6,11 @@ - Docs: https://www.anthropic.com/news/model-context-protocol """ +from mcp.shared.session import ProgressFnT + from .mcp_agent_tool import MCPAgentTool from .mcp_client import MCPClient, ToolFilters from .mcp_tasks import TasksConfig from .mcp_types import MCPTransport -__all__ = ["MCPAgentTool", "MCPClient", "MCPTransport", "TasksConfig", "ToolFilters"] +__all__ = ["MCPAgentTool", "MCPClient", "MCPTransport", "ProgressFnT", "TasksConfig", "ToolFilters"] diff --git a/src/strands/tools/mcp/mcp_client.py b/src/strands/tools/mcp/mcp_client.py index 1884ce9bc..2d64ae4e4 100644 --- a/src/strands/tools/mcp/mcp_client.py +++ b/src/strands/tools/mcp/mcp_client.py @@ -27,6 +27,7 @@ from mcp import ClientSession, ListToolsResult from mcp.client.session import ElicitationFnT from mcp.shared.exceptions import McpError +from mcp.shared.session import ProgressFnT from mcp.types import ( BlobResourceContents, ElicitationRequiredErrorData, @@ -121,6 +122,7 @@ def __init__( tool_filters: ToolFilters | None = None, prefix: str | None = None, elicitation_callback: ElicitationFnT | None = None, + progress_callback: ProgressFnT | None = None, tasks_config: TasksConfig | None = None, ) -> None: """Initialize a new MCP Server connection. @@ -132,6 +134,9 @@ def __init__( tool_filters: Optional filters to apply to tools. prefix: Optional prefix for tool names. elicitation_callback: Optional callback function to handle elicitation requests from the MCP server. + progress_callback: Optional callback to receive progress notifications during tool execution. + Called with ``(progress, total, message)`` as the server reports progress. The ``total`` + and ``message`` parameters may be ``None`` if the server does not provide them. tasks_config: Configuration for MCP task-augmented execution for long-running tools. If provided (not None), enables task-augmented execution for tools that support it. See TasksConfig for details. This feature is experimental and subject to change. @@ -140,6 +145,7 @@ def __init__( self._tool_filters = tool_filters self._prefix = prefix self._elicitation_callback = elicitation_callback + self._progress_callback = progress_callback mcp_instrumentation() self._session_id = uuid.uuid4() @@ -579,6 +585,7 @@ def _create_call_tool_coroutine( arguments: dict[str, Any] | None, read_timeout_seconds: timedelta | None, meta: dict[str, Any] | None = None, + progress_callback: ProgressFnT | None = None, ) -> Coroutine[Any, Any, MCPCallToolResult]: """Create the appropriate coroutine for calling a tool. @@ -590,11 +597,14 @@ def _create_call_tool_coroutine( arguments: Optional arguments to pass to the tool. read_timeout_seconds: Optional timeout for the tool call. meta: Optional metadata to pass to the tool call per MCP spec (_meta). + progress_callback: Optional callback to receive progress notifications. + If None, falls back to the instance-level callback set at construction time. Returns: A coroutine that will execute the tool call. """ use_task = self._should_use_task(name) + effective_callback = progress_callback if progress_callback is not None else self._progress_callback if use_task: self._log_debug_with_thread("tool=<%s> | using task-augmented execution", name) @@ -612,7 +622,7 @@ async def _call_as_task() -> MCPCallToolResult: async def _call_tool_direct() -> MCPCallToolResult: return await cast(ClientSession, self._background_thread_session).call_tool( - name, arguments, read_timeout_seconds, meta=meta + name, arguments, read_timeout_seconds, progress_callback=effective_callback, meta=meta ) return _call_tool_direct() @@ -624,6 +634,7 @@ def call_tool_sync( arguments: dict[str, Any] | None = None, read_timeout_seconds: timedelta | None = None, meta: dict[str, Any] | None = None, + progress_callback: ProgressFnT | None = None, ) -> MCPToolResult: """Synchronously calls a tool on the MCP server. @@ -636,6 +647,8 @@ def call_tool_sync( arguments: Optional arguments to pass to the tool read_timeout_seconds: Optional timeout for the tool call meta: Optional metadata to pass to the tool call per MCP spec (_meta) + progress_callback: Optional callback to receive progress notifications for this + call. Overrides the instance-level callback set at construction time. Returns: MCPToolResult: The result of the tool call @@ -645,7 +658,9 @@ def call_tool_sync( raise MCPClientInitializationError(CLIENT_SESSION_NOT_RUNNING_ERROR_MESSAGE) try: - coro = self._create_call_tool_coroutine(name, arguments, read_timeout_seconds, meta=meta) + coro = self._create_call_tool_coroutine( + name, arguments, read_timeout_seconds, meta=meta, progress_callback=progress_callback + ) call_tool_result: MCPCallToolResult = self._invoke_on_background_thread(coro).result() return self._handle_tool_result(tool_use_id, call_tool_result) except Exception as e: @@ -659,6 +674,7 @@ async def call_tool_async( arguments: dict[str, Any] | None = None, read_timeout_seconds: timedelta | None = None, meta: dict[str, Any] | None = None, + progress_callback: ProgressFnT | None = None, ) -> MCPToolResult: """Asynchronously calls a tool on the MCP server. @@ -671,6 +687,8 @@ async def call_tool_async( arguments: Optional arguments to pass to the tool read_timeout_seconds: Optional timeout for the tool call meta: Optional metadata to pass to the tool call per MCP spec (_meta) + progress_callback: Optional callback to receive progress notifications for this + call. Overrides the instance-level callback set at construction time. Returns: MCPToolResult: The result of the tool call @@ -680,7 +698,9 @@ async def call_tool_async( raise MCPClientInitializationError(CLIENT_SESSION_NOT_RUNNING_ERROR_MESSAGE) try: - coro = self._create_call_tool_coroutine(name, arguments, read_timeout_seconds, meta=meta) + coro = self._create_call_tool_coroutine( + name, arguments, read_timeout_seconds, meta=meta, progress_callback=progress_callback + ) future = self._invoke_on_background_thread(coro) call_tool_result: MCPCallToolResult = await asyncio.wrap_future(future) return self._handle_tool_result(tool_use_id, call_tool_result) diff --git a/tests/strands/tools/mcp/test_mcp_client.py b/tests/strands/tools/mcp/test_mcp_client.py index f270fa6fc..fd24fd510 100644 --- a/tests/strands/tools/mcp/test_mcp_client.py +++ b/tests/strands/tools/mcp/test_mcp_client.py @@ -124,7 +124,7 @@ def test_call_tool_sync_status(mock_transport, mock_session, is_error, expected_ with MCPClient(mock_transport["transport_callable"]) as client: result = client.call_tool_sync(tool_use_id="test-123", name="test_tool", arguments={"param": "value"}) - mock_session.call_tool.assert_called_once_with("test_tool", {"param": "value"}, None, meta=None) + mock_session.call_tool.assert_called_once_with("test_tool", {"param": "value"}, None, progress_callback=None, meta=None) assert result["status"] == expected_status assert result["toolUseId"] == "test-123" @@ -155,7 +155,7 @@ def test_call_tool_sync_with_structured_content(mock_transport, mock_session): with MCPClient(mock_transport["transport_callable"]) as client: result = client.call_tool_sync(tool_use_id="test-123", name="test_tool", arguments={"param": "value"}) - mock_session.call_tool.assert_called_once_with("test_tool", {"param": "value"}, None, meta=None) + mock_session.call_tool.assert_called_once_with("test_tool", {"param": "value"}, None, progress_callback=None, meta=None) assert result["status"] == "success" assert result["toolUseId"] == "test-123" @@ -193,10 +193,60 @@ def test_call_tool_sync_forwards_meta(mock_transport, mock_session): tool_use_id="test-123", name="test_tool", arguments={"param": "value"}, meta=meta ) - mock_session.call_tool.assert_called_once_with("test_tool", {"param": "value"}, None, meta=meta) + mock_session.call_tool.assert_called_once_with("test_tool", {"param": "value"}, None, progress_callback=None, meta=meta) assert result["status"] == "success" +def test_call_tool_sync_forwards_instance_progress_callback(mock_transport, mock_session): + """Test that call_tool_sync uses the instance-level progress callback when no per-call callback is given.""" + from unittest.mock import AsyncMock + + mock_content = MCPTextContent(type="text", text="done") + mock_session.call_tool.return_value = MCPCallToolResult(isError=False, content=[mock_content]) + cb = AsyncMock() + + with MCPClient(mock_transport["transport_callable"], progress_callback=cb) as client: + result = client.call_tool_sync(tool_use_id="test-123", name="test_tool", arguments={}) + + mock_session.call_tool.assert_called_once_with( + "test_tool", {}, None, progress_callback=cb, meta=None + ) + assert result["status"] == "success" + + +def test_call_tool_sync_per_call_progress_callback_overrides_instance(mock_transport, mock_session): + """Test that a per-call progress callback overrides the instance-level one.""" + from unittest.mock import AsyncMock + + mock_content = MCPTextContent(type="text", text="done") + mock_session.call_tool.return_value = MCPCallToolResult(isError=False, content=[mock_content]) + instance_cb = AsyncMock() + per_call_cb = AsyncMock() + + with MCPClient(mock_transport["transport_callable"], progress_callback=instance_cb) as client: + result = client.call_tool_sync( + tool_use_id="test-123", name="test_tool", arguments={}, progress_callback=per_call_cb + ) + + mock_session.call_tool.assert_called_once_with( + "test_tool", {}, None, progress_callback=per_call_cb, meta=None + ) + assert result["status"] == "success" + + +def test_call_tool_sync_no_progress_callback_by_default(mock_transport, mock_session): + """Test that progress_callback defaults to None when not set on instance or per-call.""" + mock_content = MCPTextContent(type="text", text="done") + mock_session.call_tool.return_value = MCPCallToolResult(isError=False, content=[mock_content]) + + with MCPClient(mock_transport["transport_callable"]) as client: + client.call_tool_sync(tool_use_id="test-123", name="test_tool", arguments={}) + + mock_session.call_tool.assert_called_once_with( + "test_tool", {}, None, progress_callback=None, meta=None + ) + + @pytest.mark.asyncio async def test_call_tool_async_forwards_meta(mock_transport, mock_session): """Test that call_tool_async forwards meta to ClientSession.call_tool.""" @@ -672,7 +722,7 @@ def test_call_tool_sync_embedded_nested_text(mock_transport, mock_session): with MCPClient(mock_transport["transport_callable"]) as client: result = client.call_tool_sync(tool_use_id="er-text", name="get_file_contents", arguments={}) - mock_session.call_tool.assert_called_once_with("get_file_contents", {}, None, meta=None) + mock_session.call_tool.assert_called_once_with("get_file_contents", {}, None, progress_callback=None, meta=None) assert result["status"] == "success" assert len(result["content"]) == 1 assert result["content"][0]["text"] == "inner text" @@ -697,7 +747,7 @@ def test_call_tool_sync_embedded_nested_base64_textual_mime(mock_transport, mock with MCPClient(mock_transport["transport_callable"]) as client: result = client.call_tool_sync(tool_use_id="er-blob", name="get_file_contents", arguments={}) - mock_session.call_tool.assert_called_once_with("get_file_contents", {}, None, meta=None) + mock_session.call_tool.assert_called_once_with("get_file_contents", {}, None, progress_callback=None, meta=None) assert result["status"] == "success" assert len(result["content"]) == 1 assert result["content"][0]["text"] == '{"k":"v"}' @@ -723,7 +773,7 @@ def test_call_tool_sync_embedded_image_blob(mock_transport, mock_session): with MCPClient(mock_transport["transport_callable"]) as client: result = client.call_tool_sync(tool_use_id="er-image", name="get_file_contents", arguments={}) - mock_session.call_tool.assert_called_once_with("get_file_contents", {}, None, meta=None) + mock_session.call_tool.assert_called_once_with("get_file_contents", {}, None, progress_callback=None, meta=None) assert result["status"] == "success" assert len(result["content"]) == 1 assert "image" in result["content"][0] @@ -748,7 +798,7 @@ def test_call_tool_sync_embedded_non_textual_blob_dropped(mock_transport, mock_s with MCPClient(mock_transport["transport_callable"]) as client: result = client.call_tool_sync(tool_use_id="er-binary", name="get_file_contents", arguments={}) - mock_session.call_tool.assert_called_once_with("get_file_contents", {}, None, meta=None) + mock_session.call_tool.assert_called_once_with("get_file_contents", {}, None, progress_callback=None, meta=None) assert result["status"] == "success" assert len(result["content"]) == 0 # Content should be dropped @@ -771,7 +821,7 @@ def test_call_tool_sync_embedded_multiple_textual_mimes(mock_transport, mock_ses with MCPClient(mock_transport["transport_callable"]) as client: result = client.call_tool_sync(tool_use_id="er-yaml", name="get_file_contents", arguments={}) - mock_session.call_tool.assert_called_once_with("get_file_contents", {}, None, meta=None) + mock_session.call_tool.assert_called_once_with("get_file_contents", {}, None, progress_callback=None, meta=None) assert result["status"] == "success" assert len(result["content"]) == 1 assert "key: value" in result["content"][0]["text"] @@ -798,7 +848,7 @@ def __init__(self): with MCPClient(mock_transport["transport_callable"]) as client: result = client.call_tool_sync(tool_use_id="er-unknown", name="get_file_contents", arguments={}) - mock_session.call_tool.assert_called_once_with("get_file_contents", {}, None, meta=None) + mock_session.call_tool.assert_called_once_with("get_file_contents", {}, None, progress_callback=None, meta=None) assert result["status"] == "success" assert len(result["content"]) == 0 # Unknown resource type should be dropped @@ -850,7 +900,7 @@ def test_call_tool_sync_with_meta_and_structured_content(mock_transport, mock_se with MCPClient(mock_transport["transport_callable"]) as client: result = client.call_tool_sync(tool_use_id="test-123", name="test_tool", arguments={"param": "value"}) - mock_session.call_tool.assert_called_once_with("test_tool", {"param": "value"}, None, meta=None) + mock_session.call_tool.assert_called_once_with("test_tool", {"param": "value"}, None, progress_callback=None, meta=None) assert result["status"] == "success" assert result["toolUseId"] == "test-123" From 924240e78333d2cd808d6aea9226fa35a5664348 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Tue, 5 May 2026 06:07:44 +0000 Subject: [PATCH 2/2] fix: move AsyncMock imports to top of test file Agent-Logs-Url: https://github.com/joshwand/strands-sdk-python/sessions/7ab253b3-c748-48ed-9a8a-ae5c8508e938 Co-authored-by: joshwand <22531+joshwand@users.noreply.github.com> --- tests/strands/tools/mcp/test_mcp_client.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/tests/strands/tools/mcp/test_mcp_client.py b/tests/strands/tools/mcp/test_mcp_client.py index fd24fd510..e6d6032e9 100644 --- a/tests/strands/tools/mcp/test_mcp_client.py +++ b/tests/strands/tools/mcp/test_mcp_client.py @@ -1,6 +1,6 @@ import base64 import time -from unittest.mock import MagicMock, patch +from unittest.mock import AsyncMock, MagicMock, patch import pytest from mcp import ListToolsResult @@ -199,8 +199,6 @@ def test_call_tool_sync_forwards_meta(mock_transport, mock_session): def test_call_tool_sync_forwards_instance_progress_callback(mock_transport, mock_session): """Test that call_tool_sync uses the instance-level progress callback when no per-call callback is given.""" - from unittest.mock import AsyncMock - mock_content = MCPTextContent(type="text", text="done") mock_session.call_tool.return_value = MCPCallToolResult(isError=False, content=[mock_content]) cb = AsyncMock() @@ -216,8 +214,6 @@ def test_call_tool_sync_forwards_instance_progress_callback(mock_transport, mock def test_call_tool_sync_per_call_progress_callback_overrides_instance(mock_transport, mock_session): """Test that a per-call progress callback overrides the instance-level one.""" - from unittest.mock import AsyncMock - mock_content = MCPTextContent(type="text", text="done") mock_session.call_tool.return_value = MCPCallToolResult(isError=False, content=[mock_content]) instance_cb = AsyncMock()