Skip to content

Commit a76e7db

Browse files
committed
fix: #2873 preserve computer driver compatibility for modifier keys
1 parent 86739b1 commit a76e7db

4 files changed

Lines changed: 324 additions & 43 deletions

File tree

examples/tools/computer_use.py

Lines changed: 49 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
import asyncio
66
import base64
77
import sys
8+
from collections.abc import AsyncIterator
9+
from contextlib import asynccontextmanager
810
from typing import Any, Literal, Union
911

1012
from playwright.async_api import Browser, Page, Playwright, async_playwright
@@ -118,46 +120,77 @@ async def screenshot(self) -> str:
118120
png_bytes = await self.page.screenshot(full_page=False)
119121
return base64.b64encode(png_bytes).decode("utf-8")
120122

121-
async def click(self, x: int, y: int, button: Button = "left") -> None:
123+
def _normalize_keys(self, keys: list[str] | None) -> list[str]:
    """Translate incoming key names into the names used for Playwright calls.

    Each key is looked up by its lowercased form in ``CUA_KEY_TO_PLAYWRIGHT_KEY``;
    a key with no entry is passed through unchanged. ``None`` or an empty list
    produces an empty list.
    """
    normalized: list[str] = []
    for name in keys or ():
        normalized.append(CUA_KEY_TO_PLAYWRIGHT_KEY.get(name.lower(), name))
    return normalized
127+
128+
@asynccontextmanager
async def _hold_keys(self, keys: list[str] | None) -> AsyncIterator[None]:
    """Hold ``keys`` down for the duration of the ``async with`` body.

    The keys are normalized with ``_normalize_keys``, pressed in the given order,
    and released in reverse order when the body finishes or raises.

    Only keys whose ``keyboard.down`` call succeeded are released. If pressing
    fails part-way, the keys pressed so far are still let go, and the original
    error is not masked by a release attempt on a key that was never pressed.
    """
    pressed: list[str] = []
    try:
        for key in self._normalize_keys(keys):
            await self.page.keyboard.down(key)
            # Record only after a successful press so cleanup mirrors real state.
            pressed.append(key)
        yield
    finally:
        for key in reversed(pressed):
            await self.page.keyboard.up(key)
138+
139+
async def click(
140+
self, x: int, y: int, button: Button = "left", *, keys: list[str] | None = None
141+
) -> None:
122142
playwright_button: Literal["left", "middle", "right"] = "left"
123143

124144
# Playwright only supports left, middle, right buttons
125145
if button in ("left", "right", "middle"):
126146
playwright_button = button # type: ignore
127147

128-
await self.page.mouse.click(x, y, button=playwright_button)
148+
async with self._hold_keys(keys):
149+
await self.page.mouse.click(x, y, button=playwright_button)
129150

130-
async def double_click(self, x: int, y: int) -> None:
131-
await self.page.mouse.dblclick(x, y)
151+
async def double_click(self, x: int, y: int, *, keys: list[str] | None = None) -> None:
    """Double-click at ``(x, y)`` while ``keys`` are held via ``_hold_keys``."""
    held = self._hold_keys(keys)
    async with held:
        mouse = self.page.mouse
        await mouse.dblclick(x, y)
132154

133-
async def scroll(self, x: int, y: int, scroll_x: int, scroll_y: int) -> None:
134-
await self.page.mouse.move(x, y)
135-
await self.page.evaluate(f"window.scrollBy({scroll_x}, {scroll_y})")
155+
async def scroll(
    self,
    x: int,
    y: int,
    scroll_x: int,
    scroll_y: int,
    *,
    keys: list[str] | None = None,
) -> None:
    """Move the pointer to ``(x, y)``, then scroll the window by ``(scroll_x, scroll_y)``.

    The scroll is performed by evaluating ``window.scrollBy`` in the page. Both
    steps run while ``keys`` are held via ``_hold_keys``.
    """
    # NOTE(review): window.scrollBy is a script call rather than a wheel event, so
    # held modifiers presumably do not influence it -- confirm whether a wheel
    # event is what is wanted when keys are supplied.
    scroll_script = f"window.scrollBy({scroll_x}, {scroll_y})"
    async with self._hold_keys(keys):
        page = self.page
        await page.mouse.move(x, y)
        await page.evaluate(scroll_script)
136167

137168
async def type(self, text: str) -> None:
138169
await self.page.keyboard.type(text)
139170

140171
async def wait(self) -> None:
141172
await asyncio.sleep(1)
142173

143-
async def move(self, x: int, y: int) -> None:
144-
await self.page.mouse.move(x, y)
174+
async def move(self, x: int, y: int, *, keys: list[str] | None = None) -> None:
    """Move the mouse pointer to ``(x, y)`` while ``keys`` are held via ``_hold_keys``."""
    held = self._hold_keys(keys)
    async with held:
        mouse = self.page.mouse
        await mouse.move(x, y)
145177

146178
async def keypress(self, keys: list[str]) -> None:
147-
mapped_keys = [CUA_KEY_TO_PLAYWRIGHT_KEY.get(key.lower(), key) for key in keys]
179+
mapped_keys = self._normalize_keys(keys)
148180
for key in mapped_keys:
149181
await self.page.keyboard.down(key)
150182
for key in reversed(mapped_keys):
151183
await self.page.keyboard.up(key)
152184

153-
async def drag(self, path: list[tuple[int, int]]) -> None:
185+
async def drag(self, path: list[tuple[int, int]], *, keys: list[str] | None = None) -> None:
154186
if not path:
155187
return
156-
await self.page.mouse.move(path[0][0], path[0][1])
157-
await self.page.mouse.down()
158-
for px, py in path[1:]:
159-
await self.page.mouse.move(px, py)
160-
await self.page.mouse.up()
188+
async with self._hold_keys(keys):
189+
await self.page.mouse.move(path[0][0], path[0][1])
190+
await self.page.mouse.down()
191+
for px, py in path[1:]:
192+
await self.page.mouse.move(px, py)
193+
await self.page.mouse.up()
161194

162195

163196
async def run_agent(

src/agents/computer.py

Lines changed: 30 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,12 @@
66

77

88
class Computer(abc.ABC):
9-
"""A computer implemented with sync operations. The Computer interface abstracts the
10-
operations needed to control a computer or browser."""
9+
"""A computer implemented with sync operations.
10+
11+
Subclasses provide the local runtime behind `ComputerTool`. Mouse action methods may
12+
also accept a keyword-only `keys` argument to receive held modifier keys when the
13+
driver supports them.
14+
"""
1115

1216
@property
1317
def environment(self) -> Environment | None:
@@ -21,44 +25,57 @@ def dimensions(self) -> tuple[int, int] | None:
2125

2226
@abc.abstractmethod
2327
def screenshot(self) -> str:
28+
"""Return a base64-encoded PNG screenshot of the current display."""
2429
pass
2530

2631
@abc.abstractmethod
2732
def click(self, x: int, y: int, button: Button) -> None:
33+
"""Click `button` at the given `(x, y)` screen coordinates."""
2834
pass
2935

3036
@abc.abstractmethod
3137
def double_click(self, x: int, y: int) -> None:
38+
"""Double-click at the given `(x, y)` screen coordinates."""
3239
pass
3340

3441
@abc.abstractmethod
3542
def scroll(self, x: int, y: int, scroll_x: int, scroll_y: int) -> None:
43+
"""Scroll at `(x, y)` by `(scroll_x, scroll_y)` units."""
3644
pass
3745

3846
@abc.abstractmethod
3947
def type(self, text: str) -> None:
48+
"""Type `text` into the currently focused target."""
4049
pass
4150

4251
@abc.abstractmethod
4352
def wait(self) -> None:
53+
"""Wait until the computer is ready for the next action."""
4454
pass
4555

4656
@abc.abstractmethod
4757
def move(self, x: int, y: int) -> None:
58+
"""Move the mouse cursor to the given `(x, y)` screen coordinates."""
4859
pass
4960

5061
@abc.abstractmethod
5162
def keypress(self, keys: list[str]) -> None:
63+
"""Press the provided keys, such as `["ctrl", "c"]`."""
5264
pass
5365

5466
@abc.abstractmethod
5567
def drag(self, path: list[tuple[int, int]]) -> None:
68+
"""Click-and-drag the mouse along the given sequence of `(x, y)` waypoints."""
5669
pass
5770

5871

5972
class AsyncComputer(abc.ABC):
60-
"""A computer implemented with async operations. The Computer interface abstracts the
61-
operations needed to control a computer or browser."""
73+
"""A computer implemented with async operations.
74+
75+
Subclasses provide the local runtime behind `ComputerTool`. Mouse action methods may
76+
also accept a keyword-only `keys` argument to receive held modifier keys when the
77+
driver supports them.
78+
"""
6279

6380
@property
6481
def environment(self) -> Environment | None:
@@ -72,36 +89,45 @@ def dimensions(self) -> tuple[int, int] | None:
7289

7390
@abc.abstractmethod
7491
async def screenshot(self) -> str:
92+
"""Return a base64-encoded PNG screenshot of the current display."""
7593
pass
7694

7795
@abc.abstractmethod
7896
async def click(self, x: int, y: int, button: Button) -> None:
97+
"""Click `button` at the given `(x, y)` screen coordinates."""
7998
pass
8099

81100
@abc.abstractmethod
82101
async def double_click(self, x: int, y: int) -> None:
102+
"""Double-click at the given `(x, y)` screen coordinates."""
83103
pass
84104

85105
@abc.abstractmethod
86106
async def scroll(self, x: int, y: int, scroll_x: int, scroll_y: int) -> None:
107+
"""Scroll at `(x, y)` by `(scroll_x, scroll_y)` units."""
87108
pass
88109

89110
@abc.abstractmethod
90111
async def type(self, text: str) -> None:
112+
"""Type `text` into the currently focused target."""
91113
pass
92114

93115
@abc.abstractmethod
94116
async def wait(self) -> None:
117+
"""Wait until the computer is ready for the next action."""
95118
pass
96119

97120
@abc.abstractmethod
98121
async def move(self, x: int, y: int) -> None:
122+
"""Move the mouse cursor to the given `(x, y)` screen coordinates."""
99123
pass
100124

101125
@abc.abstractmethod
102126
async def keypress(self, keys: list[str]) -> None:
127+
"""Press the provided keys, such as `["ctrl", "c"]`."""
103128
pass
104129

105130
@abc.abstractmethod
106131
async def drag(self, path: list[tuple[int, int]]) -> None:
132+
"""Click-and-drag the mouse along the given sequence of `(x, y)` waypoints."""
107133
pass

src/agents/run_internal/tool_actions.py

Lines changed: 68 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -185,30 +185,38 @@ async def _execute_action_and_capture(
185185
) -> str:
186186
"""Execute computer actions (sync or async drivers) and return the final screenshot."""
187187

188-
async def maybe_call(method_name: str, *args: Any) -> Any:
188+
async def maybe_call(method_name: str, *args: Any, **kwargs: Any) -> Any:
189189
method = getattr(computer, method_name, None)
190190
if method is None or not callable(method):
191191
raise ModelBehaviorError(f"Computer driver missing method {method_name}")
192-
result = method(*args)
192+
filtered_kwargs = cls._filter_supported_kwargs(
193+
method_name=method_name,
194+
method=method,
195+
kwargs=kwargs,
196+
)
197+
result = method(*args, **filtered_kwargs)
193198
return await result if inspect.isawaitable(result) else result
194199

195200
last_action_was_screenshot = False
196201
last_screenshot_result: Any = None
197202
for action in cls._iter_actions(tool_call):
198203
action_type = get_mapping_or_attr(action, "type")
204+
action_keys = cls._normalize_modifier_keys(get_mapping_or_attr(action, "keys"))
199205
last_action_was_screenshot = False
200206
if action_type == "click":
201207
await maybe_call(
202208
"click",
203209
get_mapping_or_attr(action, "x"),
204210
get_mapping_or_attr(action, "y"),
205211
get_mapping_or_attr(action, "button"),
212+
keys=action_keys,
206213
)
207214
elif action_type == "double_click":
208215
await maybe_call(
209216
"double_click",
210217
get_mapping_or_attr(action, "x"),
211218
get_mapping_or_attr(action, "y"),
219+
keys=action_keys,
212220
)
213221
elif action_type == "drag":
214222
path = get_mapping_or_attr(action, "path") or []
@@ -221,6 +229,7 @@ async def maybe_call(method_name: str, *args: Any) -> Any:
221229
)
222230
for point in path
223231
],
232+
keys=action_keys,
224233
)
225234
elif action_type == "keypress":
226235
await maybe_call("keypress", get_mapping_or_attr(action, "keys"))
@@ -229,6 +238,7 @@ async def maybe_call(method_name: str, *args: Any) -> Any:
229238
"move",
230239
get_mapping_or_attr(action, "x"),
231240
get_mapping_or_attr(action, "y"),
241+
keys=action_keys,
232242
)
233243
elif action_type == "screenshot":
234244
last_screenshot_result = await maybe_call("screenshot")
@@ -240,6 +250,7 @@ async def maybe_call(method_name: str, *args: Any) -> Any:
240250
get_mapping_or_attr(action, "y"),
241251
get_mapping_or_attr(action, "scroll_x"),
242252
get_mapping_or_attr(action, "scroll_y"),
253+
keys=action_keys,
243254
)
244255
elif action_type == "type":
245256
await maybe_call("type", get_mapping_or_attr(action, "text"))
@@ -285,6 +296,61 @@ def _serialize_action_payload(action: Any) -> Any:
285296
return dataclasses.asdict(action)
286297
return action
287298

299+
@staticmethod
def _normalize_modifier_keys(keys: Any) -> list[str] | None:
    """Turn an action's ``keys`` payload into ``list[str]`` or ``None``.

    Any falsy payload (missing, ``None``, empty) becomes ``None``. Anything else
    is returned as the same object, re-typed for the type checker only -- no
    copy and no validation of the contents is performed.
    """
    return cast(list[str], keys) if keys else None
304+
305+
@classmethod
def _filter_supported_kwargs(
    cls,
    *,
    method_name: str,
    method: Any,
    kwargs: dict[str, Any],
) -> dict[str, Any]:
    """Return the entries of ``kwargs`` that ``method`` is able to receive.

    Entries whose value is ``None`` are always omitted. Of the remainder, any
    name that ``method`` does not declare is dropped with a logged warning,
    unless ``method`` takes ``**kwargs``, in which case everything is kept.

    Args:
        method_name: Driver method name, used only in the warning message.
        method: The driver callable whose signature is inspected.
        kwargs: Optional keyword arguments the caller would like to pass.

    Returns:
        A new dict that is safe to splat into ``method``.
    """
    provided = {name: value for name, value in kwargs.items() if value is not None}
    if not provided:
        return {}

    accepted = cls._supported_keyword_arguments(method)
    if None in accepted:
        # The ``None`` marker stands for a ``**kwargs`` catch-all: nothing to drop.
        return provided

    rejected = [name for name in provided if name not in accepted]
    if rejected:
        logger.warning(
            "Computer driver method %r does not accept keyword argument(s) %s; "
            "dropping them and continuing.",
            method_name,
            ", ".join(sorted(rejected)),
        )
        for name in rejected:
            del provided[name]

    return provided
334+
335+
@staticmethod
def _supported_keyword_arguments(method: Any) -> set[str | None]:
    """Return the keyword names that ``method`` accepts.

    The result contains the name of every keyword-only and positional-or-keyword
    parameter. If ``method`` also declares a ``**kwargs`` catch-all, the set
    additionally contains ``None`` as a marker meaning "any keyword is accepted".

    If the signature cannot be introspected (``inspect.signature`` raises
    ``ValueError`` or ``TypeError``), an empty set is returned, so callers treat
    every optional keyword as unsupported instead of failing the whole action.
    """
    try:
        signature = inspect.signature(method)
    except (TypeError, ValueError):
        # No usable signature: report no keyword support so the caller falls back
        # to a plain positional call rather than crashing.
        return set()

    named_kinds = (
        inspect.Parameter.KEYWORD_ONLY,
        inspect.Parameter.POSITIONAL_OR_KEYWORD,
    )
    supported: set[str | None] = set()
    for parameter in signature.parameters.values():
        if parameter.kind in named_kinds:
            supported.add(parameter.name)
        elif parameter.kind == inspect.Parameter.VAR_KEYWORD:
            supported.add(None)
    return supported
353+
288354

289355
class LocalShellAction:
290356
"""Execute local shell commands via the LocalShellTool with lifecycle hooks."""

0 commit comments

Comments
 (0)