@@ -511,12 +511,14 @@ async def _run_consumer(self) -> None: # noqa: PLR0915, PLR0912
511511 )
512512 except Exception as e :
513513 logger .exception ('Consumer[%s]: Failed' , self ._task_id )
514+ # TODO: Make the task in database as failed.
514515 async with self ._lock :
515516 await self ._mark_task_as_failed (e )
516517 finally :
517518 # The consumer is dead. The ActiveTask is permanently finished.
518519 self ._is_finished .set ()
519520 self ._request_queue .shutdown (immediate = True )
521+ await self ._event_queue_agent .close (immediate = True )
520522
521523 logger .debug ('Consumer[%s]: Finishing' , self ._task_id )
522524 await self ._maybe_cleanup ()
@@ -574,53 +576,42 @@ async def subscribe( # noqa: PLR0912, PLR0915
574576 if self ._exception :
575577 raise self ._exception
576578
577- # Wait for next event or task completion
578- try :
579- dequeued = await asyncio .wait_for (
580- tapped_queue .dequeue_event (), timeout = 0.1
581- )
582- event , updated_task = cast ('Any' , dequeued )
579+ dequeued = await tapped_queue .dequeue_event ()
580+ event , updated_task = cast ('Any' , dequeued )
581+ logger .debug (
582+ 'Subscriber[%s]\n Dequeued event %s\n Updated task %s\n ' ,
583+ self ._task_id ,
584+ event ,
585+ updated_task ,
586+ )
587+ if replace_status_update_with_task and isinstance (
588+ event , TaskStatusUpdateEvent
589+ ):
583590 logger .debug (
584- 'Subscriber[%s]\n Dequeued event %s \n Updated task %s \n ' ,
591+ 'Subscriber[%s]: Replacing TaskStatusUpdateEvent with Task: %s ' ,
585592 self ._task_id ,
586- event ,
587593 updated_task ,
588594 )
589- if replace_status_update_with_task and isinstance (
590- event , TaskStatusUpdateEvent
595+ event = updated_task
596+ if self ._exception :
597+ raise self ._exception from None
598+ if isinstance (event , _RequestCompleted ):
599+ if (
600+ request_id is not None
601+ and event .request_id == request_id
591602 ):
592603 logger .debug (
593- 'Subscriber[%s]: Replacing TaskStatusUpdateEvent with Task: %s ' ,
604+ 'Subscriber[%s]: Request completed ' ,
594605 self ._task_id ,
595- updated_task ,
596606 )
597- event = updated_task
598- if self ._exception :
599- raise self ._exception from None
600- if isinstance (event , _RequestCompleted ):
601- if (
602- request_id is not None
603- and event .request_id == request_id
604- ):
605- logger .debug (
606- 'Subscriber[%s]: Request completed' ,
607- self ._task_id ,
608- )
609- return
610- continue
611- elif isinstance (event , _RequestStarted ):
612- logger .debug (
613- 'Subscriber[%s]: Request started' ,
614- self ._task_id ,
615- )
616- continue
617- except (asyncio .TimeoutError , TimeoutError ):
618- if self ._is_finished .is_set ():
619- if self ._exception :
620- raise self ._exception from None
621- break
607+ return
608+ continue
609+ elif isinstance (event , _RequestStarted ):
610+ logger .debug (
611+ 'Subscriber[%s]: Request started' ,
612+ self ._task_id ,
613+ )
622614 continue
623-
624615 try :
625616 yield event
626617 finally :
@@ -715,17 +706,20 @@ async def _mark_task_as_failed(self, exception: Exception) -> None:
715706 if self ._exception is None :
716707 self ._exception = exception
717708 if self ._task_created .is_set ():
718- task = await self ._task_manager .get_task ()
719- if task is not None :
720- await self ._event_queue_agent .enqueue_event (
721- TaskStatusUpdateEvent (
722- task_id = task .id ,
723- context_id = task .context_id ,
724- status = TaskStatus (
725- state = TaskState .TASK_STATE_FAILED ,
726- ),
709+ try :
710+ task = await self ._task_manager .get_task ()
711+ if task is not None :
712+ await self ._event_queue_agent .enqueue_event (
713+ TaskStatusUpdateEvent (
714+ task_id = task .id ,
715+ context_id = task .context_id ,
716+ status = TaskStatus (
717+ state = TaskState .TASK_STATE_FAILED ,
718+ ),
719+ )
727720 )
728- )
721+ except QueueShutDown :
722+ pass
729723
730724 async def get_task (self ) -> Task :
731725 """Get task from db."""
0 commit comments