Skip to content

Commit d68a961

Browse files
committed
IActionRunner.run_action_iter and its implementation. Sending partial results using yield in run action and using explicit partial result sender
1 parent 4850991 commit d68a961

6 files changed

Lines changed: 152 additions & 15 deletions

File tree

finecode_builtin_handlers/src/finecode_builtin_handlers/format.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ async def run(
2929
self,
3030
payload: format_action.FormatRunPayload,
3131
run_context: format_action.FormatRunContext,
32-
) -> format_action.FormatRunResult:
32+
):
3333
run_meta = run_context.meta
3434
file_uris: list[ResourceUri]
3535

@@ -66,14 +66,14 @@ async def run(
6666
format_files_action_instance = self.action_runner.get_action_by_source(
6767
format_files_action.FormatFilesAction
6868
)
69-
format_result = await self.action_runner.run_action(
69+
async for partial in self.action_runner.run_action_iter(
7070
action=format_files_action_instance,
7171
payload=format_files_action.FormatFilesRunPayload(
7272
file_paths=file_uris,
7373
save=payload.save,
7474
),
7575
meta=run_meta,
76-
)
77-
return format_action.FormatRunResult(
78-
result_by_file_path=format_result.result_by_file_path
79-
)
76+
):
77+
yield format_action.FormatRunResult(
78+
result_by_file_path=partial.result_by_file_path
79+
)

finecode_builtin_handlers/src/finecode_builtin_handlers/lint.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ async def run(
3131
self,
3232
payload: lint_action.LintRunPayload,
3333
run_context: lint_action.LintRunContext,
34-
) -> lint_action.LintRunResult:
34+
):
3535
run_meta = run_context.meta
3636
file_uris: list[ResourceUri]
3737

@@ -68,9 +68,9 @@ async def run(
6868
lint_files_action_instance = self.action_runner.get_action_by_source(
6969
lint_files_action.LintFilesAction
7070
)
71-
lint_result = await self.action_runner.run_action(
71+
async for partial in self.action_runner.run_action_iter(
7272
action=lint_files_action_instance,
7373
payload=lint_files_action.LintFilesRunPayload(file_paths=file_uris),
7474
meta=run_meta,
75-
)
76-
return lint_action.LintRunResult(messages=lint_result.messages)
75+
):
76+
yield lint_action.LintRunResult(messages=partial.messages)

finecode_extension_api/src/finecode_extension_api/code_action.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,20 @@ def update(self, result: RunActionResult) -> None:
8888
self._current_result.update(result)
8989

9090

91+
class PartialResultSender(typing.Protocol):
92+
"""Handler-facing interface for sending partial results to the client."""
93+
94+
async def send(self, result: RunActionResult) -> None: ...
95+
96+
97+
class _NoOpPartialResultSender:
98+
async def send(self, result: RunActionResult) -> None:
99+
pass
100+
101+
102+
_NOOP_SENDER = _NoOpPartialResultSender()
103+
104+
91105
class RunActionContext(typing.Generic[RunPayloadType]):
92106
# data object to save data between action steps(only during one run, after run data
93107
# is removed). Keep it simple, without business logic, just data storage, but you
@@ -101,12 +115,14 @@ def __init__(
101115
initial_payload: RunPayloadType,
102116
meta: RunActionMeta,
103117
info_provider: RunContextInfoProvider,
118+
partial_result_sender: PartialResultSender = _NOOP_SENDER,
104119
) -> None:
105120
self.run_id = run_id
106121
self.initial_payload = initial_payload
107122
self.meta = meta
108123
self.exit_stack = contextlib.AsyncExitStack()
109124
self._info_provider = info_provider
125+
self.partial_result_sender = partial_result_sender
110126

111127
@property
112128
def current_result(self) -> RunActionResult | None:
@@ -148,12 +164,14 @@ def __init__(
148164
initial_payload: RunPayloadType,
149165
meta: RunActionMeta,
150166
info_provider: RunContextInfoProvider,
167+
partial_result_sender: PartialResultSender = _NOOP_SENDER,
151168
) -> None:
152169
super().__init__(
153170
run_id=run_id,
154171
initial_payload=initial_payload,
155172
meta=meta,
156173
info_provider=info_provider,
174+
partial_result_sender=partial_result_sender,
157175
)
158176
self.partial_result_scheduler = partialresultscheduler.PartialResultScheduler()
159177

finecode_extension_api/src/finecode_extension_api/interfaces/iactionrunner.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import collections.abc
12
import typing
23

34
from finecode_extension_api import code_action, service
@@ -36,6 +37,13 @@ async def run_action(
3637
meta: code_action.RunActionMeta,
3738
) -> ResultT: ...
3839

40+
def run_action_iter(
41+
self,
42+
action: ActionDeclaration[code_action.Action[PayloadT, typing.Any, ResultT]],
43+
payload: PayloadT,
44+
meta: code_action.RunActionMeta,
45+
) -> collections.abc.AsyncIterator[ResultT]: ...
46+
3947
def get_actions_names(self) -> list[str]: ...
4048

4149

finecode_extension_runner/src/finecode_extension_runner/_services/run_action.py

Lines changed: 83 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,25 @@ def __init__(self, response: schemas.RunActionResponse) -> None:
4040
self.response = response
4141

4242

43+
class _TrackingPartialResultSender:
44+
"""Wraps partial_result_sender.schedule_sending with state tracking."""
45+
46+
def __init__(
47+
self,
48+
token: int | str,
49+
send_func: collections.abc.Callable[
50+
[int | str, code_action.RunActionResult], collections.abc.Awaitable[None]
51+
],
52+
) -> None:
53+
self._token = token
54+
self._send_func = send_func
55+
self.has_sent = False
56+
57+
async def send(self, result: code_action.RunActionResult) -> None:
58+
self.has_sent = True
59+
await self._send_func(self._token, result)
60+
61+
4362
def set_partial_result_sender(send_func: typing.Callable) -> None:
4463
global partial_result_sender
4564
partial_result_sender = partial_result_sender_module.PartialResultSender(
@@ -60,6 +79,7 @@ async def run_action(
6079
meta: code_action.RunActionMeta,
6180
partial_result_token: int | str | None = None,
6281
run_id: int | None = None,
82+
partial_result_queue: asyncio.Queue | None = None,
6383
) -> code_action.RunActionResult | None:
6484
# design decisions:
6585
# - keep payload unchanged between all subaction runs.
@@ -105,14 +125,22 @@ async def run_action(
105125

106126
run_context: code_action.RunActionContext | AsyncPlaceholderContext
107127
run_context_info = code_action.RunContextInfoProvider(is_concurrent_execution=execute_handlers_concurrently)
128+
if partial_result_token is not None:
129+
tracking_sender = _TrackingPartialResultSender(
130+
token=partial_result_token,
131+
send_func=partial_result_sender.schedule_sending,
132+
)
133+
else:
134+
tracking_sender = None
108135
if action_exec_info.run_context_type is not None:
109136
constructor_args = await resolve_func_args_with_di(
110137
action_exec_info.run_context_type.__init__,
111138
known_args={
112139
"run_id": lambda _: run_id,
113140
"initial_payload": lambda _: payload,
114141
"meta": lambda _: meta,
115-
"info_provider": lambda _: run_context_info
142+
"info_provider": lambda _: run_context_info,
143+
"partial_result_sender": lambda _: tracking_sender or code_action._NOOP_SENDER,
116144
},
117145
params_to_ignore=["self"],
118146
)
@@ -163,6 +191,8 @@ async def run_action(
163191
action_cache=action_cache,
164192
action_exec_info=action_exec_info,
165193
runner_context=runner_context,
194+
partial_result_token=partial_result_token,
195+
tracking_sender=tracking_sender,
166196
)
167197

168198
parts = [part async for part in payload]
@@ -181,6 +211,12 @@ async def run_action(
181211
try:
182212
async with asyncio.TaskGroup() as tg:
183213
for part in parts:
214+
if part not in run_context.partial_result_scheduler.coroutines_by_key:
215+
logger.warning(
216+
f"R{run_id} | No coroutines scheduled for part {part} "
217+
f"of action '{action_def.name}', skipping"
218+
)
219+
continue
184220
part_coros = (
185221
run_context.partial_result_scheduler.coroutines_by_key[part]
186222
)
@@ -193,6 +229,7 @@ async def run_action(
193229
partial_result_sender,
194230
action_def.name,
195231
run_id,
232+
partial_result_queue=partial_result_queue,
196233
)
197234
else:
198235
coro = run_subresult_coros_sequentially(
@@ -202,6 +239,7 @@ async def run_action(
202239
partial_result_sender,
203240
action_def.name,
204241
run_id,
242+
partial_result_queue=partial_result_queue,
205243
)
206244
subresult_task = tg.create_task(coro)
207245
subresults_tasks.append(subresult_task)
@@ -246,6 +284,8 @@ async def run_action(
246284
action_cache=action_cache,
247285
action_exec_info=action_exec_info,
248286
runner_context=runner_context,
287+
partial_result_token=partial_result_token,
288+
tracking_sender=tracking_sender,
249289
)
250290
)
251291
handlers_tasks.append(handler_task)
@@ -276,6 +316,8 @@ async def run_action(
276316
action_cache=action_cache,
277317
action_exec_info=action_exec_info,
278318
runner_context=runner_context,
319+
partial_result_token=partial_result_token,
320+
tracking_sender=tracking_sender,
279321
)
280322
except ActionFailedException as exception:
281323
raise exception
@@ -314,6 +356,10 @@ async def run_action(
314356
f"Unexpected result type: {type(action_result).__name__}"
315357
)
316358

359+
if partial_result_queue is not None and action_result is not None:
360+
await partial_result_queue.put(action_result)
361+
return None
362+
317363
return action_result
318364

319365

@@ -592,6 +638,8 @@ async def execute_action_handler(
592638
action_exec_info: domain.ActionExecInfo,
593639
action_cache: domain.ActionCache,
594640
runner_context: context.RunnerContext,
641+
partial_result_token: int | str | None = None,
642+
tracking_sender: _TrackingPartialResultSender | None = None,
595643
) -> code_action.RunActionResult:
596644
logger.trace(f"R{run_id} | Run {handler.name} on {str(payload)[:100]}...")
597645
if handler.name in action_cache.handler_cache_by_name:
@@ -647,8 +695,30 @@ def get_run_context(param_type):
647695
# there is also `inspect.iscoroutinefunction` but it cannot recognize coroutine
648696
# functions which are class methods. Use `isawaitable` on result instead.
649697
call_result = handler_run_func(**args)
650-
if inspect.isawaitable(call_result):
651-
execution_result = await call_result
698+
if inspect.isasyncgen(call_result):
699+
execution_result = None
700+
async for partial_result in call_result:
701+
if partial_result_token is not None:
702+
await partial_result_sender.schedule_sending(
703+
partial_result_token, partial_result
704+
)
705+
if execution_result is None:
706+
result_type_pydantic = pydantic_dataclass(type(partial_result))
707+
execution_result = result_type_pydantic(
708+
**dataclasses.asdict(partial_result)
709+
)
710+
else:
711+
execution_result.update(partial_result)
712+
if partial_result_token is not None:
713+
await partial_result_sender.send_all_immediately()
714+
execution_result = None # partials already sent
715+
elif inspect.isawaitable(call_result):
716+
handler_result = await call_result
717+
if tracking_sender is not None and tracking_sender.has_sent:
718+
await partial_result_sender.send_all_immediately()
719+
execution_result = None
720+
else:
721+
execution_result = handler_result
652722
else:
653723
execution_result = call_result
654724
except Exception as exception:
@@ -684,6 +754,7 @@ async def run_subresult_coros_concurrently(
684754
partial_result_sender: partial_result_sender_module.PartialResultSender,
685755
action_name: str,
686756
run_id: int,
757+
partial_result_queue: asyncio.Queue | None = None,
687758
) -> code_action.RunActionResult | None:
688759
coros_tasks: list[asyncio.Task] = []
689760
try:
@@ -725,7 +796,10 @@ async def run_subresult_coros_concurrently(
725796
else:
726797
action_subresult.update(coro_result)
727798

728-
if send_partial_results:
799+
if partial_result_queue is not None:
800+
await partial_result_queue.put(action_subresult)
801+
return None
802+
elif send_partial_results:
729803
await partial_result_sender.schedule_sending(
730804
partial_result_token, action_subresult
731805
)
@@ -741,6 +815,7 @@ async def run_subresult_coros_sequentially(
741815
partial_result_sender: partial_result_sender_module.PartialResultSender,
742816
action_name: str,
743817
run_id: int,
818+
partial_result_queue: asyncio.Queue | None = None,
744819
) -> code_action.RunActionResult | None:
745820
action_subresult: code_action.RunActionResult | None = None
746821
for coro in coros:
@@ -761,7 +836,10 @@ async def run_subresult_coros_sequentially(
761836
else:
762837
action_subresult.update(coro_result)
763838

764-
if send_partial_results:
839+
if partial_result_queue is not None:
840+
await partial_result_queue.put(action_subresult)
841+
return None
842+
elif send_partial_results:
765843
await partial_result_sender.schedule_sending(
766844
partial_result_token, action_subresult
767845
)

finecode_extension_runner/src/finecode_extension_runner/impls/action_runner.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,13 @@
1+
import asyncio
12
import collections.abc
23
import typing
34
from finecode_extension_api import code_action
45
from finecode_extension_api.interfaces import iactionrunner
56

67
from finecode_extension_runner import domain, run_utils
78

9+
_SENTINEL = object()
10+
811

912
class ActionRunner(iactionrunner.IActionRunner):
1013
def __init__(self, run_action_func: typing.Callable[[domain.ActionDeclaration, code_action.RunActionPayload, code_action.RunActionMeta], collections.abc.Coroutine[None, None, code_action.RunActionResult]],
@@ -30,6 +33,36 @@ async def run_action(
3033
except Exception as exception:
3134
raise iactionrunner.ActionRunFailed(str(exception)) from exception
3235

36+
@typing.override
37+
async def run_action_iter(
38+
self,
39+
action: iactionrunner.ActionDeclaration[iactionrunner.ActionT],
40+
payload: code_action.RunActionPayload,
41+
meta: code_action.RunActionMeta,
42+
) -> collections.abc.AsyncIterator[code_action.RunActionResult]:
43+
queue: asyncio.Queue = asyncio.Queue()
44+
45+
async def producer():
46+
try:
47+
await self._run_action_func(action, payload, meta, partial_result_queue=queue)
48+
finally:
49+
await queue.put(_SENTINEL)
50+
51+
task = asyncio.create_task(producer())
52+
try:
53+
while True:
54+
item = await queue.get()
55+
if item is _SENTINEL:
56+
break
57+
yield item
58+
finally:
59+
if not task.done():
60+
task.cancel()
61+
try:
62+
await task
63+
except (asyncio.CancelledError, Exception):
64+
pass
65+
3366
@typing.override
3467
def get_actions_names(self) -> list[str]:
3568
return list(self._actions_getter().keys())

0 commit comments

Comments
 (0)