-
Notifications
You must be signed in to change notification settings - Fork 3.4k
Expand file tree
/
Copy pathtest_stdio.py
More file actions
127 lines (103 loc) · 5.04 KB
/
test_stdio.py
File metadata and controls
127 lines (103 loc) · 5.04 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
import io
import os
import sys
import tempfile
from io import TextIOWrapper
import anyio
import pytest
from mcp.server.stdio import stdio_server
from mcp.shared.message import SessionMessage
from mcp.types import JSONRPCMessage, JSONRPCRequest, JSONRPCResponse, jsonrpc_message_adapter
@pytest.mark.anyio
async def test_stdio_server():
    """Round-trip JSON-RPC messages through stdio_server() over in-memory streams.

    Two messages pre-loaded into a fake stdin must be parsed and delivered on
    the read stream. Two messages sent on the write stream must appear on the
    fake stdout as one JSON document per line.
    """
    stdin = io.StringIO()
    stdout = io.StringIO()

    messages = [
        JSONRPCRequest(jsonrpc="2.0", id=1, method="ping"),
        JSONRPCResponse(jsonrpc="2.0", id=2, result={}),
    ]
    for message in messages:
        stdin.write(message.model_dump_json(by_alias=True, exclude_none=True) + "\n")
    stdin.seek(0)

    # Bound the whole exchange, as the other stdio tests do, so a regression
    # that stalls the reader or writer fails fast instead of hanging the suite.
    with anyio.fail_after(5):
        async with stdio_server(stdin=anyio.AsyncFile(stdin), stdout=anyio.AsyncFile(stdout)) as (
            read_stream,
            write_stream,
        ):
            received_messages: list[JSONRPCMessage] = []
            async with read_stream:
                # Stream items are SessionMessage or an in-stream Exception;
                # use a distinct name from the outgoing `message` loop above.
                async for item in read_stream:
                    if isinstance(item, Exception):  # pragma: no cover
                        raise item
                    received_messages.append(item.message)
                    if len(received_messages) == 2:
                        break

            # Verify received messages
            assert len(received_messages) == 2
            assert received_messages[0] == JSONRPCRequest(jsonrpc="2.0", id=1, method="ping")
            assert received_messages[1] == JSONRPCResponse(jsonrpc="2.0", id=2, result={})

            # Test sending responses from the server
            responses = [
                JSONRPCRequest(jsonrpc="2.0", id=3, method="ping"),
                JSONRPCResponse(jsonrpc="2.0", id=4, result={}),
            ]
            async with write_stream:
                for response in responses:
                    await write_stream.send(SessionMessage(response))

    # Read stdout back only after leaving stdio_server(). This assumes the
    # context waits for its writer before exiting (confirm in stdio_server),
    # so every sent message has been written by this point.
    stdout.seek(0)
    output_lines = stdout.readlines()
    assert len(output_lines) == 2

    received_responses = [jsonrpc_message_adapter.validate_json(line.strip()) for line in output_lines]
    assert len(received_responses) == 2
    assert received_responses[0] == JSONRPCRequest(jsonrpc="2.0", id=3, method="ping")
    assert received_responses[1] == JSONRPCResponse(jsonrpc="2.0", id=4, result={})
@pytest.mark.anyio
async def test_stdio_server_does_not_close_real_stdio(monkeypatch: pytest.MonkeyPatch):
    """Exiting stdio_server() must leave the process's sys.stdin/sys.stdout open.

    Regression test for https://github.com/modelcontextprotocol/python-sdk/issues/1933.
    """
    # Back sys.stdin/sys.stdout with genuine temporary files so that fileno()
    # yields real OS descriptors whose liveness can be probed afterwards.
    with tempfile.TemporaryFile() as in_backing, tempfile.TemporaryFile() as out_backing:
        monkeypatch.setattr(sys, "stdin", TextIOWrapper(in_backing, encoding="utf-8"))
        monkeypatch.setattr(sys, "stdout", TextIOWrapper(out_backing, encoding="utf-8"))
        stdin_fd, stdout_fd = sys.stdin.fileno(), sys.stdout.fileno()

        with anyio.fail_after(5):
            async with stdio_server() as (read_stream, write_stream):
                await write_stream.aclose()
                await read_stream.aclose()

        # os.fstat() raises OSError on a closed descriptor, so both calls
        # succeeding shows stdio_server() left the descriptors open.
        for fd in (stdin_fd, stdout_fd):
            os.fstat(fd)

        # The Python-level wrappers installed above must survive as well.
        assert not sys.stdin.closed
        assert not sys.stdout.closed
@pytest.mark.anyio
async def test_stdio_server_invalid_utf8(monkeypatch: pytest.MonkeyPatch):
    """A stdin line that is not valid UTF-8 must not bring the server down.

    Undecodable bytes become U+FFFD. The resulting line then fails JSON parsing,
    and the failure is surfaced as an exception item on the read stream. A
    valid message that follows must still be delivered normally.
    """
    ping = JSONRPCRequest(jsonrpc="2.0", id=1, method="ping")
    # \xff and \xfe can never start a UTF-8 sequence.
    payload = b"\xff\xfe\n" + ping.model_dump_json(by_alias=True, exclude_none=True).encode() + b"\n"

    # stdio_server()'s default path wraps sys.stdin's .buffer itself (with
    # errors='replace'), so expose the raw bytes through a wrapper's .buffer.
    monkeypatch.setattr(sys, "stdin", TextIOWrapper(io.BytesIO(payload), encoding="utf-8"))
    monkeypatch.setattr(sys, "stdout", TextIOWrapper(io.BytesIO(), encoding="utf-8"))

    with anyio.fail_after(5):
        async with stdio_server() as (read_stream, write_stream):
            await write_stream.aclose()
            async with read_stream:  # pragma: no branch
                # Line 1 decodes to U+FFFD U+FFFD, which is not JSON -> exception item.
                bad_line = await read_stream.receive()
                assert isinstance(bad_line, Exception)

                # Line 2 is well-formed and must still come through.
                good_line = await read_stream.receive()
                assert isinstance(good_line, SessionMessage)
                assert good_line.message == ping