Skip to content

Commit 8e64caf

Browse files
phernandezclaude
andcommitted
fix: list_workspaces bypasses factory pattern on cloud MCP server (#636)
Add set_workspace_provider() injection point so the cloud MCP server can list workspaces by querying its own database directly, instead of making an HTTP round-trip to the control-plane API with credentials it doesn't have. 🔧 Mirrors the existing set_client_factory() pattern in async_client.py 🧪 Adds 3 tests for provider injection, fallback, and context caching 🩹 Updates build_context test assertions for v0.18 backward compat fields Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> Signed-off-by: phernandez <paul@basicmachines.co>
1 parent ccb5740 commit 8e64caf

3 files changed

Lines changed: 126 additions & 16 deletions

File tree

src/basic_memory/mcp/project_context.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
"""
1010

1111
from contextlib import asynccontextmanager
12-
from typing import AsyncIterator, Optional, List, Tuple
12+
from typing import AsyncIterator, Awaitable, Callable, Optional, List, Tuple
1313

1414
from httpx import AsyncClient
1515
from httpx._types import (
@@ -27,6 +27,18 @@
2727
from basic_memory.schemas.memory import memory_url_path
2828
from basic_memory.utils import generate_permalink, normalize_project_reference
2929

30+
# --- Workspace provider injection ---
31+
# Mirrors the set_client_factory() pattern in async_client.py.
32+
# The cloud MCP server sets a provider that queries its own database directly,
33+
# avoiding the control-plane HTTP round-trip that requires local credentials.
34+
_workspace_provider: Optional[Callable[[], Awaitable[list[WorkspaceInfo]]]] = None
35+
36+
37+
def set_workspace_provider(provider: Callable[[], Awaitable[list[WorkspaceInfo]]]) -> None:
38+
"""Override workspace discovery (for cloud app, testing, etc)."""
39+
global _workspace_provider
40+
_workspace_provider = provider
41+
3042

3143
async def resolve_project_parameter(
3244
project: Optional[str] = None,
@@ -103,6 +115,19 @@ async def get_available_workspaces(context: Optional[Context] = None) -> list[Wo
103115
if isinstance(cached_raw, list):
104116
return [WorkspaceInfo.model_validate(item) for item in cached_raw]
105117

118+
# Trigger: workspace provider was injected (e.g., by cloud MCP server)
119+
# Why: the cloud server IS the cloud — it can query its own database
120+
# directly instead of making an HTTP round-trip that requires local credentials
121+
# Outcome: use provider result, cache in context, skip control-plane client
122+
if _workspace_provider is not None:
123+
workspaces = await _workspace_provider()
124+
if context:
125+
await context.set_state(
126+
"available_workspaces",
127+
[ws.model_dump() for ws in workspaces],
128+
)
129+
return workspaces
130+
106131
from basic_memory.mcp.async_client import get_cloud_control_plane_client
107132
from basic_memory.mcp.tools.utils import call_get
108133

tests/mcp/test_tool_build_context.py

Lines changed: 13 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
@pytest.mark.asyncio
1111
async def test_get_basic_discussion_context(client, test_graph, test_project):
12-
"""Test getting basic discussion context returns JSON dict with excluded fields removed."""
12+
"""Test getting basic discussion context returns JSON dict with expected fields."""
1313
result = await build_context(project=test_project.name, url="memory://test/root")
1414

1515
assert isinstance(result, dict)
@@ -19,29 +19,27 @@ async def test_get_basic_discussion_context(client, test_graph, test_project):
1919
assert primary["permalink"] == f"{test_project.name}/test/root"
2020
assert len(result["results"][0]["related_results"]) > 0
2121

22-
# Verify metadata — excluded fields should be absent
22+
# Verify metadata fields
2323
meta = result["metadata"]
2424
assert meta["uri"] == f"{test_project.name}/test/root"
2525
assert meta["depth"] == 1 # default depth
2626
assert meta["timeframe"] is not None
2727
assert meta["primary_count"] == 1
28-
assert "generated_at" not in meta
29-
assert "total_results" not in meta
28+
# COMPAT(v0.18): generated_at and total_results restored for old clients
29+
assert "generated_at" in meta
30+
assert "total_results" in meta
3031

31-
# Entity: entity_id excluded, created_at kept (needed for related results)
32-
assert "entity_id" not in primary
32+
# Entity fields present
33+
assert "entity_id" in primary
3334
assert "created_at" in primary
3435

35-
# Verify observation-level fields: internal IDs excluded, file_path/created_at kept
36+
# Verify observation-level fields
3637
if result["results"][0]["observations"]:
3738
obs = result["results"][0]["observations"][0]
38-
assert "observation_id" not in obs
39-
assert "entity_id" not in obs
40-
assert "title" not in obs
41-
# file_path and created_at kept (needed when observation is primary_result)
39+
assert "observation_id" in obs
40+
assert "entity_id" in obs
4241
assert "file_path" in obs
4342
assert "created_at" in obs
44-
# Other kept fields
4543
assert "permalink" in obs
4644
assert "category" in obs
4745
assert "content" in obs
@@ -53,14 +51,14 @@ async def test_get_basic_discussion_context(client, test_graph, test_project):
5351
assert "title" in related
5452
assert "file_path" in related
5553
assert "created_at" in related
56-
assert "entity_id" not in related # excluded
54+
assert "entity_id" in related
5755
elif item_type == "relation":
5856
assert "relation_type" in related
5957
assert "title" in related
6058
assert "file_path" in related
6159
assert "created_at" in related
62-
assert "relation_id" not in related # excluded
63-
assert "entity_id" not in related # excluded
60+
assert "relation_id" in related
61+
assert "entity_id" in related
6462

6563

6664
@pytest.mark.asyncio

tests/mcp/test_tool_workspace_management.py

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import pytest
44

5+
from basic_memory.mcp.project_context import get_available_workspaces, set_workspace_provider
56
from basic_memory.mcp.tools.workspaces import list_workspaces
67
from basic_memory.schemas.cloud import WorkspaceInfo
78

@@ -105,3 +106,89 @@ async def fake_get_available_workspaces(context=None):
105106
assert "# Available Workspaces (1)" in first
106107
assert "# Available Workspaces (1)" in second
107108
assert call_count["fetches"] == 1
109+
110+
111+
# --- Workspace provider injection tests ---
112+
113+
114+
@pytest.fixture
115+
def _reset_workspace_provider(monkeypatch):
116+
"""Ensure _workspace_provider is reset after each test."""
117+
import basic_memory.mcp.project_context as _mod
118+
119+
monkeypatch.setattr(_mod, "_workspace_provider", None)
120+
121+
122+
@pytest.mark.asyncio
123+
@pytest.mark.usefixtures("_reset_workspace_provider")
124+
async def test_get_available_workspaces_uses_provider_when_set():
125+
"""When a workspace provider is injected, it is called instead of the control-plane client."""
126+
expected = [
127+
WorkspaceInfo(
128+
tenant_id="aaaa-bbbb",
129+
workspace_type="personal",
130+
name="Injected",
131+
role="owner",
132+
),
133+
]
134+
135+
async def fake_provider() -> list[WorkspaceInfo]:
136+
return expected
137+
138+
set_workspace_provider(fake_provider)
139+
140+
result = await get_available_workspaces()
141+
assert len(result) == 1
142+
assert result[0].tenant_id == "aaaa-bbbb"
143+
assert result[0].name == "Injected"
144+
145+
146+
@pytest.mark.asyncio
147+
@pytest.mark.usefixtures("_reset_workspace_provider")
148+
async def test_get_available_workspaces_falls_back_without_provider(monkeypatch):
149+
"""Without a provider, get_available_workspaces uses the control-plane client (existing path)."""
150+
called = {"control_plane": False}
151+
152+
async def fake_control_plane_path(context=None):
153+
called["control_plane"] = True
154+
return []
155+
156+
# Patch the entire function to avoid needing real credentials
157+
monkeypatch.setattr(
158+
"basic_memory.mcp.tools.workspaces.get_available_workspaces",
159+
fake_control_plane_path,
160+
)
161+
162+
result = await list_workspaces()
163+
assert called["control_plane"]
164+
assert "# No Workspaces Available" in result
165+
166+
167+
@pytest.mark.asyncio
168+
@pytest.mark.usefixtures("_reset_workspace_provider")
169+
async def test_get_available_workspaces_provider_caches_in_context():
170+
"""Provider results are cached in the MCP context for subsequent calls."""
171+
call_count = {"provider": 0}
172+
workspace = WorkspaceInfo(
173+
tenant_id="cccc-dddd",
174+
workspace_type="organization",
175+
name="Cached Provider",
176+
role="editor",
177+
)
178+
179+
async def counting_provider() -> list[WorkspaceInfo]:
180+
call_count["provider"] += 1
181+
return [workspace]
182+
183+
set_workspace_provider(counting_provider)
184+
context = _ContextState()
185+
186+
# First call: provider is invoked, result cached
187+
first = await get_available_workspaces(context=context)
188+
assert len(first) == 1
189+
assert call_count["provider"] == 1
190+
191+
# Second call: served from context cache, provider not called again
192+
second = await get_available_workspaces(context=context)
193+
assert len(second) == 1
194+
assert call_count["provider"] == 1

0 commit comments

Comments
 (0)