@@ -40,6 +40,25 @@ def __init__(self, response: schemas.RunActionResponse) -> None:
4040 self .response = response
4141
4242
43+ class _TrackingPartialResultSender :
44+ """Wraps partial_result_sender.schedule_sending with state tracking."""
45+
46+ def __init__ (
47+ self ,
48+ token : int | str ,
49+ send_func : collections .abc .Callable [
50+ [int | str , code_action .RunActionResult ], collections .abc .Awaitable [None ]
51+ ],
52+ ) -> None :
53+ self ._token = token
54+ self ._send_func = send_func
55+ self .has_sent = False
56+
57+ async def send (self , result : code_action .RunActionResult ) -> None :
58+ self .has_sent = True
59+ await self ._send_func (self ._token , result )
60+
61+
4362def set_partial_result_sender (send_func : typing .Callable ) -> None :
4463 global partial_result_sender
4564 partial_result_sender = partial_result_sender_module .PartialResultSender (
@@ -60,6 +79,7 @@ async def run_action(
6079 meta : code_action .RunActionMeta ,
6180 partial_result_token : int | str | None = None ,
6281 run_id : int | None = None ,
82+ partial_result_queue : asyncio .Queue | None = None ,
6383) -> code_action .RunActionResult | None :
6484 # design decisions:
6585 # - keep payload unchanged between all subaction runs.
@@ -105,14 +125,22 @@ async def run_action(
105125
106126 run_context : code_action .RunActionContext | AsyncPlaceholderContext
107127 run_context_info = code_action .RunContextInfoProvider (is_concurrent_execution = execute_handlers_concurrently )
128+ if partial_result_token is not None :
129+ tracking_sender = _TrackingPartialResultSender (
130+ token = partial_result_token ,
131+ send_func = partial_result_sender .schedule_sending ,
132+ )
133+ else :
134+ tracking_sender = None
108135 if action_exec_info .run_context_type is not None :
109136 constructor_args = await resolve_func_args_with_di (
110137 action_exec_info .run_context_type .__init__ ,
111138 known_args = {
112139 "run_id" : lambda _ : run_id ,
113140 "initial_payload" : lambda _ : payload ,
114141 "meta" : lambda _ : meta ,
115- "info_provider" : lambda _ : run_context_info
142+ "info_provider" : lambda _ : run_context_info ,
143+ "partial_result_sender" : lambda _ : tracking_sender or code_action ._NOOP_SENDER ,
116144 },
117145 params_to_ignore = ["self" ],
118146 )
@@ -163,6 +191,8 @@ async def run_action(
163191 action_cache = action_cache ,
164192 action_exec_info = action_exec_info ,
165193 runner_context = runner_context ,
194+ partial_result_token = partial_result_token ,
195+ tracking_sender = tracking_sender ,
166196 )
167197
168198 parts = [part async for part in payload ]
@@ -181,6 +211,12 @@ async def run_action(
181211 try :
182212 async with asyncio .TaskGroup () as tg :
183213 for part in parts :
214+ if part not in run_context .partial_result_scheduler .coroutines_by_key :
215+ logger .warning (
216+ f"R{ run_id } | No coroutines scheduled for part { part } "
217+ f"of action '{ action_def .name } ', skipping"
218+ )
219+ continue
184220 part_coros = (
185221 run_context .partial_result_scheduler .coroutines_by_key [part ]
186222 )
@@ -193,6 +229,7 @@ async def run_action(
193229 partial_result_sender ,
194230 action_def .name ,
195231 run_id ,
232+ partial_result_queue = partial_result_queue ,
196233 )
197234 else :
198235 coro = run_subresult_coros_sequentially (
@@ -202,6 +239,7 @@ async def run_action(
202239 partial_result_sender ,
203240 action_def .name ,
204241 run_id ,
242+ partial_result_queue = partial_result_queue ,
205243 )
206244 subresult_task = tg .create_task (coro )
207245 subresults_tasks .append (subresult_task )
@@ -246,6 +284,8 @@ async def run_action(
246284 action_cache = action_cache ,
247285 action_exec_info = action_exec_info ,
248286 runner_context = runner_context ,
287+ partial_result_token = partial_result_token ,
288+ tracking_sender = tracking_sender ,
249289 )
250290 )
251291 handlers_tasks .append (handler_task )
@@ -276,6 +316,8 @@ async def run_action(
276316 action_cache = action_cache ,
277317 action_exec_info = action_exec_info ,
278318 runner_context = runner_context ,
319+ partial_result_token = partial_result_token ,
320+ tracking_sender = tracking_sender ,
279321 )
280322 except ActionFailedException as exception :
281323 raise exception
@@ -314,6 +356,10 @@ async def run_action(
314356 f"Unexpected result type: { type (action_result ).__name__ } "
315357 )
316358
359+ if partial_result_queue is not None and action_result is not None :
360+ await partial_result_queue .put (action_result )
361+ return None
362+
317363 return action_result
318364
319365
@@ -592,6 +638,8 @@ async def execute_action_handler(
592638 action_exec_info : domain .ActionExecInfo ,
593639 action_cache : domain .ActionCache ,
594640 runner_context : context .RunnerContext ,
641+ partial_result_token : int | str | None = None ,
642+ tracking_sender : _TrackingPartialResultSender | None = None ,
595643) -> code_action .RunActionResult :
596644 logger .trace (f"R{ run_id } | Run { handler .name } on { str (payload )[:100 ]} ..." )
597645 if handler .name in action_cache .handler_cache_by_name :
@@ -647,8 +695,30 @@ def get_run_context(param_type):
647695 # there is also `inspect.iscoroutinefunction` but it cannot recognize coroutine
648696 # functions which are class methods. Use `isawaitable` on result instead.
649697 call_result = handler_run_func (** args )
650- if inspect .isawaitable (call_result ):
651- execution_result = await call_result
698+ if inspect .isasyncgen (call_result ):
699+ execution_result = None
700+ async for partial_result in call_result :
701+ if partial_result_token is not None :
702+ await partial_result_sender .schedule_sending (
703+ partial_result_token , partial_result
704+ )
705+ if execution_result is None :
706+ result_type_pydantic = pydantic_dataclass (type (partial_result ))
707+ execution_result = result_type_pydantic (
708+ ** dataclasses .asdict (partial_result )
709+ )
710+ else :
711+ execution_result .update (partial_result )
712+ if partial_result_token is not None :
713+ await partial_result_sender .send_all_immediately ()
714+ execution_result = None # partials already sent
715+ elif inspect .isawaitable (call_result ):
716+ handler_result = await call_result
717+ if tracking_sender is not None and tracking_sender .has_sent :
718+ await partial_result_sender .send_all_immediately ()
719+ execution_result = None
720+ else :
721+ execution_result = handler_result
652722 else :
653723 execution_result = call_result
654724 except Exception as exception :
@@ -684,6 +754,7 @@ async def run_subresult_coros_concurrently(
684754 partial_result_sender : partial_result_sender_module .PartialResultSender ,
685755 action_name : str ,
686756 run_id : int ,
757+ partial_result_queue : asyncio .Queue | None = None ,
687758) -> code_action .RunActionResult | None :
688759 coros_tasks : list [asyncio .Task ] = []
689760 try :
@@ -725,7 +796,10 @@ async def run_subresult_coros_concurrently(
725796 else :
726797 action_subresult .update (coro_result )
727798
728- if send_partial_results :
799+ if partial_result_queue is not None :
800+ await partial_result_queue .put (action_subresult )
801+ return None
802+ elif send_partial_results :
729803 await partial_result_sender .schedule_sending (
730804 partial_result_token , action_subresult
731805 )
@@ -741,6 +815,7 @@ async def run_subresult_coros_sequentially(
741815 partial_result_sender : partial_result_sender_module .PartialResultSender ,
742816 action_name : str ,
743817 run_id : int ,
818+ partial_result_queue : asyncio .Queue | None = None ,
744819) -> code_action .RunActionResult | None :
745820 action_subresult : code_action .RunActionResult | None = None
746821 for coro in coros :
@@ -761,7 +836,10 @@ async def run_subresult_coros_sequentially(
761836 else :
762837 action_subresult .update (coro_result )
763838
764- if send_partial_results :
839+ if partial_result_queue is not None :
840+ await partial_result_queue .put (action_subresult )
841+ return None
842+ elif send_partial_results :
765843 await partial_result_sender .schedule_sending (
766844 partial_result_token , action_subresult
767845 )
0 commit comments