Skip to content

Commit 140ebec

Browse files
committed
Refactor of responses models dumping
1 parent f7b927a commit 140ebec

7 files changed

Lines changed: 202 additions & 33 deletions

File tree

src/app/endpoints/responses.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -598,12 +598,7 @@ async def responses_endpoint_handler(
598598
original_request.input, inline_rag_context.context_text
599599
)
600600

601-
api_params = ResponsesApiParams.model_validate(
602-
{
603-
**updated_request.model_dump(exclude={"tools"}),
604-
"tools": updated_request.tools,
605-
}
606-
)
601+
api_params = ResponsesApiParams.model_validate(updated_request.model_dump())
607602
context = ResponsesContext(
608603
client=client,
609604
auth=auth,

src/models/api/requests/responses_openai.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
from constants import RESPONSES_REQUEST_MAX_SIZE
2424
from models.common.query import SolrVectorSearchRequest
2525
from models.common.responses.types import IncludeParameter, ResponseInput
26+
from models.utils import add_mcp_authorizations
2627
from utils import suid
2728

2829

@@ -176,3 +177,14 @@ def check_previous_response_id(cls, value: Optional[str]) -> Optional[str]:
176177
if value is not None and value.startswith("modr"):
177178
raise ValueError("You cannot provide context by moderation response.")
178179
return value
180+
181+
def model_dump(self, *args: Any, **kwargs: Any) -> dict[str, Any]:
182+
"""Serialize to a request body dict.
183+
184+
Returns:
185+
Serializable dict with MCP authorizations preserved.
186+
"""
187+
result = super().model_dump(*args, **kwargs)
188+
if result.get("tools") is not None and self.tools is not None:
189+
result["tools"] = add_mcp_authorizations(result["tools"], self.tools)
190+
return result

src/models/common/responses/responses_api_params.py

Lines changed: 9 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
from pydantic import BaseModel, Field
2525

2626
from models.common.responses.types import IncludeParameter, ResponseInput
27+
from models.utils import add_mcp_authorizations
2728
from utils.tool_formatter import translate_vector_store_ids_to_user_facing
2829

2930
# Attribute names that are echoed back in the response.
@@ -126,28 +127,19 @@ class ResponsesApiParams(BaseModel):
126127
)
127128

128129
def model_dump(self, *args: Any, **kwargs: Any) -> dict[str, Any]:
129-
"""Serialize params, re-injecting MCP authorization stripped by exclude=True.
130+
"""Serialize to a request body dict.
130131
131-
llama-stack-api marks ``InputToolMCP.authorization`` with
132-
``Field(exclude=True)`` to prevent token leakage in API responses.
133-
The base ``model_dump()`` therefore strips the field, but we need it
134-
in the request payload so llama-stack server can authenticate with
135-
MCP servers. See LCORE-1414 / GitHub issue #1269.
132+
Omits conversation when previous_response_id is set; restores MCP
133+
authorization on dumped tool rows.
134+
135+
Returns:
136+
Serializable dict for the Responses API request body.
136137
"""
137138
result = super().model_dump(*args, **kwargs)
138-
# Only one context option is allowed, previous_response_id has priority
139-
# Turn is added to conversation manually if previous_response_id is used
140139
if self.previous_response_id:
141140
result.pop("conversation", None)
142-
dumped_tools = result.get("tools")
143-
if not self.tools or not isinstance(dumped_tools, list):
144-
return result
145-
if len(dumped_tools) != len(self.tools):
146-
return result
147-
for tool, dumped_tool in zip(self.tools, dumped_tools):
148-
authorization = getattr(tool, "authorization", None)
149-
if authorization is not None and isinstance(dumped_tool, dict):
150-
dumped_tool["authorization"] = authorization
141+
if self.tools is not None and result.get("tools") is not None:
142+
result["tools"] = add_mcp_authorizations(result["tools"], self.tools)
151143
return result
152144

153145
def echoed_params(self, rag_id_mapping: Mapping[str, str]) -> dict[str, Any]:

src/models/utils.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
"""Utility functions for models."""
2+
3+
from typing import Any
4+
5+
from llama_stack_api.openai_responses import OpenAIResponseInputTool as InputTool
6+
7+
8+
def add_mcp_authorizations(
9+
dumped_tools: list[dict[str, Any]],
10+
tools: list[InputTool],
11+
) -> list[dict[str, Any]]:
12+
"""Merge MCP authorization into serialized tool dicts keyed by server_label.
13+
14+
Args:
15+
dumped_tools: Serialized tools.
16+
tools: Live tool models. MCP entries with authorization are mapped by
17+
server_label.
18+
19+
Returns:
20+
A new list of dicts. For MCP rows, authorization is set only when a
21+
matching non-None token exists.
22+
"""
23+
authorizations = {
24+
tool.server_label: tool.authorization
25+
for tool in tools
26+
if tool.type == "mcp" and tool.authorization is not None
27+
} # server_labels are unique by design
28+
result: list[dict[str, Any]] = []
29+
for dumped in dumped_tools:
30+
row = dict(dumped)
31+
if (
32+
row.get("type") == "mcp"
33+
and (label := row.get("server_label")) is not None
34+
and (token := authorizations.get(label)) is not None
35+
):
36+
row["authorization"] = token
37+
38+
result.append(row)
39+
return result

tests/unit/app/endpoints/test_responses.py

Lines changed: 27 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,12 @@
1212
from llama_stack_api.openai_responses import (
1313
OpenAIResponseInputToolChoiceMode as ToolChoiceMode,
1414
)
15-
from llama_stack_api.openai_responses import OpenAIResponseMessage
15+
from llama_stack_api.openai_responses import (
16+
OpenAIResponseInputToolMCP as InputToolMCP,
17+
)
18+
from llama_stack_api.openai_responses import (
19+
OpenAIResponseMessage,
20+
)
1621
from llama_stack_client import APIConnectionError, APIStatusError, AsyncLlamaStackClient
1722
from pytest_mock import MockerFixture
1823

@@ -71,12 +76,7 @@ def build_api_params_and_context( # pylint: disable=too-many-arguments
7176
user_agent: Optional[str] = None,
7277
) -> tuple[ResponsesApiParams, ResponsesContext]:
7378
"""Build api_params/context for direct helper invocation tests."""
74-
api_params = ResponsesApiParams.model_validate(
75-
{
76-
**updated_request.model_dump(exclude={"tools"}),
77-
"tools": updated_request.tools,
78-
}
79-
)
79+
api_params = ResponsesApiParams.model_validate(updated_request.model_dump())
8080
context = ResponsesContext.model_construct(
8181
client=client,
8282
auth=auth,
@@ -94,6 +94,26 @@ def build_api_params_and_context( # pylint: disable=too-many-arguments
9494
return api_params, context
9595

9696

97+
def test_responses_api_params_preserves_mcp_authorization() -> None:
98+
"""After model_validate, MCP tool authorization from model_dump is kept on api_params.tools."""
99+
token = "secret-token"
100+
req = ResponsesRequest(
101+
input="x",
102+
model=MODEL,
103+
conversation=VALID_CONV_ID,
104+
tools=[
105+
InputToolMCP(
106+
server_label="alpha",
107+
server_url="http://alpha",
108+
require_approval="never",
109+
authorization=token,
110+
)
111+
],
112+
)
113+
api = ResponsesApiParams.model_validate(req.model_dump())
114+
assert api.tools is not None and api.tools[0].authorization == token
115+
116+
97117
def _patch_base(mocker: MockerFixture, config: AppConfig) -> None:
98118
"""Patch configuration and mandatory checks for responses endpoint."""
99119
mocker.patch(f"{MODULE}.configuration", config)

tests/unit/models/test_utils.py

Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
"""Unit tests for models.utils (mirrors src/models/utils.py)."""
2+
3+
from llama_stack_api.openai_responses import (
4+
OpenAIResponseInputToolFileSearch as InputToolFileSearch,
5+
)
6+
from llama_stack_api.openai_responses import (
7+
OpenAIResponseInputToolMCP as InputToolMCP,
8+
)
9+
10+
from models.utils import add_mcp_authorizations
11+
12+
13+
class TestAddMcpAuthorizations:
14+
"""Tests for add_mcp_authorizations with realistic MCP tool rows.
15+
16+
Assumes server_label is present on MCP dicts and unique across configured
17+
servers; see InputToolMCP in llama-stack-api.
18+
"""
19+
20+
def test_merges_authorization_by_server_label(self) -> None:
21+
"""MCP model_dump omits authorization; the helper restores it by server_label."""
22+
live = InputToolMCP(
23+
server_label="alpha",
24+
server_url="http://alpha",
25+
require_approval="never",
26+
authorization="secret-token",
27+
)
28+
dumped = [live.model_dump()]
29+
assert "authorization" not in dumped[0]
30+
31+
out = add_mcp_authorizations(dumped, [live])
32+
assert len(out) == 1
33+
assert out[0]["authorization"] == "secret-token"
34+
assert out[0]["server_label"] == "alpha"
35+
36+
def test_two_mcp_servers_distinct_tokens(self) -> None:
37+
"""Each server_label receives its own authorization."""
38+
a = InputToolMCP(
39+
server_label="srv-a",
40+
server_url="http://a",
41+
require_approval="never",
42+
authorization="token-a",
43+
)
44+
b = InputToolMCP(
45+
server_label="srv-b",
46+
server_url="http://b",
47+
require_approval="never",
48+
authorization="token-b",
49+
)
50+
dumped = [a.model_dump(), b.model_dump()]
51+
assert "authorization" not in dumped[0]
52+
assert "authorization" not in dumped[1]
53+
54+
out = add_mcp_authorizations(dumped, [a, b])
55+
assert out[0]["authorization"] == "token-a"
56+
assert out[1]["authorization"] == "token-b"
57+
58+
def test_file_search_row_unchanged_no_authorization_merge(self) -> None:
59+
"""Non-MCP rows are copied; MCP row still gets auth from live list."""
60+
mcp = InputToolMCP(
61+
server_label="m",
62+
server_url="http://m",
63+
require_approval="never",
64+
authorization="mcp-secret",
65+
)
66+
fs = InputToolFileSearch(type="file_search", vector_store_ids=["vs-1"])
67+
dumped = [fs.model_dump(), mcp.model_dump()]
68+
assert "authorization" not in dumped[1]
69+
70+
out = add_mcp_authorizations(dumped, [fs, mcp])
71+
assert out[0]["type"] == "file_search"
72+
assert "authorization" not in out[0]
73+
assert out[1]["authorization"] == "mcp-secret"
74+
75+
def test_subset_dumped_rows_still_match_live_by_label(self) -> None:
76+
"""When only some MCP tools appear in dumped_tools, labels still align."""
77+
first = InputToolMCP(
78+
server_label="one",
79+
server_url="http://one",
80+
require_approval="never",
81+
authorization="tok-one",
82+
)
83+
second = InputToolMCP(
84+
server_label="two",
85+
server_url="http://two",
86+
require_approval="never",
87+
authorization="tok-two",
88+
)
89+
dumped = [second.model_dump()]
90+
assert "authorization" not in dumped[0]
91+
92+
out = add_mcp_authorizations(dumped, [first, second])
93+
assert len(out) == 1
94+
assert out[0]["authorization"] == "tok-two"
95+
96+
def test_does_not_mutate_input_list_or_dicts(self) -> None:
97+
"""Output is new containers; inputs stay as provided."""
98+
live = InputToolMCP(
99+
server_label="s",
100+
server_url="http://s",
101+
require_approval="never",
102+
authorization="t",
103+
)
104+
dumped = [live.model_dump()]
105+
row = dumped[0]
106+
assert "authorization" not in row
107+
108+
out = add_mcp_authorizations(dumped, [live])
109+
assert out is not dumped
110+
assert out[0] is not row
111+
assert "authorization" not in row

tests/unit/utils/test_types.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -240,8 +240,8 @@ def test_multiple_mcp_tools_each_preserves_authorization(self) -> None:
240240
assert dumped["tools"][0]["authorization"] == "token-a"
241241
assert dumped["tools"][1]["authorization"] == "token-b"
242242

243-
def test_exclude_changing_tool_list_shape_skips_reinjection(self) -> None:
244-
"""Test that exclude removing tool indices does not mis-assign authorization."""
243+
def test_partial_tool_dump_reinjects_auth_by_server_label(self) -> None:
244+
"""When exclude drops some tools, remaining MCP rows still get auth by label."""
245245
tool_a = InputToolMCP(
246246
server_label="server-a",
247247
server_url="http://a:3000",
@@ -258,7 +258,7 @@ def test_exclude_changing_tool_list_shape_skips_reinjection(self) -> None:
258258
dumped = params.model_dump(exclude={"tools": {0}})
259259
assert len(dumped["tools"]) == 1
260260
assert dumped["tools"][0]["server_label"] == "server-b"
261-
assert "authorization" not in dumped["tools"][0]
261+
assert dumped["tools"][0]["authorization"] == "token-b"
262262

263263
def test_no_tools_does_not_error(self) -> None:
264264
"""Test that model_dump() works when tools is None."""

0 commit comments

Comments
 (0)