11import logging
22import os
33import sys
4+ from contextvars import ContextVar
45from typing import Optional , TextIO , Union , cast
56
7+ # Context variable to track current execution_id
8+ current_execution_id : ContextVar [Optional [str ]] = ContextVar (
9+ "current_execution_id" , default = None
10+ )
11+
612
713class PersistentLogsHandler (logging .FileHandler ):
814 """A simple log handler that always writes to a single file without rotation."""
@@ -20,6 +26,30 @@ def __init__(self, file: str):
2026 self .setFormatter (self .formatter )
2127
2228
29+ class ExecutionContextFilter (logging .Filter ):
30+ """Filter that only allows logs from a specific execution context."""
31+
32+ def __init__ (self , execution_id : str ):
33+ super ().__init__ ()
34+ self .execution_id = execution_id
35+
36+ def filter (self , record : logging .LogRecord ) -> bool :
37+ """Allow logs that have matching execution_id attribute or context."""
38+ # First check if record has execution_id attribute
39+ record_execution_id = getattr (record , "execution_id" , None )
40+ if record_execution_id == self .execution_id :
41+ return True
42+
43+ # Fall back to context variable
44+ ctx_execution_id = current_execution_id .get ()
45+ if ctx_execution_id == self .execution_id :
46+ # Inject execution_id into record for downstream handlers
47+ record .execution_id = self .execution_id
48+ return True
49+
50+ return False
51+
52+
2353class LogsInterceptor :
2454 """Intercepts all logging and stdout/stderr, routing to either persistent log files or stdout based on whether it's running as a job or not."""
2555
@@ -31,6 +61,7 @@ def __init__(
3161 job_id : Optional [str ] = None ,
3262 is_debug_run : bool = False ,
3363 log_handler : Optional [logging .Handler ] = None ,
64+ execution_id : Optional [str ] = None ,
3465 ):
3566 """Initialize the log interceptor.
3667
@@ -41,9 +72,11 @@ def __init__(
4172 job_id (str, optional): If provided, logs go to file; otherwise, to stdout.
4273 is_debug_run (bool, optional): If True, log the output to stdout/stderr.
4374 log_handler (logging.Handler, optional): Custom log handler to use.
75+ execution_id (str, optional): Unique identifier for this execution context.
4476 """
4577 min_level = min_level or "INFO"
4678 self .job_id = job_id
79+ self .execution_id = execution_id
4780
4881 # Convert to numeric level for consistent comparison
4982 self .numeric_min_level = getattr (logging , min_level .upper (), logging .INFO )
@@ -81,6 +114,12 @@ def __init__(
81114 self .log_handler = PersistentLogsHandler (file = log_file )
82115
83116 self .log_handler .setLevel (self .numeric_min_level )
117+
118+ # Add execution context filter if execution_id provided
119+ if execution_id :
120+ self .execution_filter = ExecutionContextFilter (execution_id )
121+ self .log_handler .addFilter (self .execution_filter )
122+
84123 self .logger = logging .getLogger ("runtime" )
85124 self .patched_loggers : set [str ] = set ()
86125
@@ -95,22 +134,37 @@ def _clean_all_handlers(self, logger: logging.Logger) -> None:
95134
96135 def setup (self ) -> None :
97136 """Configure logging to use our persistent handler."""
98- # Use global disable to prevent all logging below our minimum level
99- if self .numeric_min_level > logging .NOTSET :
137+ # Set the context variable for this execution
138+ if self .execution_id :
139+ current_execution_id .set (self .execution_id )
140+
141+ # Only use global disable if we're not in a parallel execution context
142+ if not self .execution_id and self .numeric_min_level > logging .NOTSET :
100143 logging .disable (self .numeric_min_level - 1 )
101144
102145 # Set root logger level
103146 self .root_logger .setLevel (self .numeric_min_level )
104147
105- # Remove ALL handlers from root logger and add only ours
106- self ._clean_all_handlers (self .root_logger )
148+ if self .execution_id :
149+ # Parallel execution mode: add our handler without removing others
150+ if self .log_handler not in self .root_logger .handlers :
151+ self .root_logger .addHandler (self .log_handler )
152+
153+ # Set up propagation for all existing loggers
154+ for logger_name in logging .root .manager .loggerDict :
155+ logger = logging .getLogger (logger_name )
156+ # Keep propagation enabled so logs flow to all handlers
157+ self .patched_loggers .add (logger_name )
158+ else :
159+ # Single execution mode: remove all handlers and add only ours
160+ self ._clean_all_handlers (self .root_logger )
107161
108- # Set up propagation for all existing loggers
109- for logger_name in logging .root .manager .loggerDict :
110- logger = logging .getLogger (logger_name )
111- logger .propagate = False # Prevent double-logging
112- self ._clean_all_handlers (logger )
113- self .patched_loggers .add (logger_name )
162+ # Set up propagation for all existing loggers
163+ for logger_name in logging .root .manager .loggerDict :
164+ logger = logging .getLogger (logger_name )
165+ logger .propagate = False # Prevent double-logging
166+ self ._clean_all_handlers (logger )
167+ self .patched_loggers .add (logger_name )
114168
115169 # Set up stdout/stderr redirection
116170 self ._redirect_stdout_stderr ()
@@ -130,15 +184,15 @@ def __init__(
130184 self .level = level
131185 self .min_level = min_level
132186 self .buffer = ""
133- self .sys_file = sys_file # Store reference to system stdout/stderr
187+ self .sys_file = sys_file
134188
135189 def write (self , message : str ) -> None :
136190 self .buffer += message
137191 while "\n " in self .buffer :
138192 line , self .buffer = self .buffer .split ("\n " , 1 )
139193 # Only log if the message is not empty and the level is sufficient
140194 if line and self .level >= self .min_level :
141- # Use _log to avoid potential recursive logging if logging methods are overridden
195+ # The context variable is automatically available here
142196 self .logger ._log (self .level , line , ())
143197
144198 def flush (self ) -> None :
@@ -160,14 +214,21 @@ def isatty(self) -> bool:
160214 def writable (self ) -> bool :
161215 return True
162216
163- # Set up stdout and stderr loggers with propagate=False
217+ # Set up stdout and stderr loggers
164218 stdout_logger = logging .getLogger ("stdout" )
165- stdout_logger .propagate = False
166- self ._clean_all_handlers (stdout_logger )
167-
168219 stderr_logger = logging .getLogger ("stderr" )
220+
221+ stdout_logger .propagate = False
169222 stderr_logger .propagate = False
170- self ._clean_all_handlers (stderr_logger )
223+
224+ if self .execution_id :
225+ if self .log_handler not in stdout_logger .handlers :
226+ stdout_logger .addHandler (self .log_handler )
227+ if self .log_handler not in stderr_logger .handlers :
228+ stderr_logger .addHandler (self .log_handler )
229+ else :
230+ self ._clean_all_handlers (stdout_logger )
231+ self ._clean_all_handlers (stderr_logger )
171232
172233 # Use the min_level in the LoggerWriter to filter messages
173234 sys .stdout = LoggerWriter (
@@ -179,21 +240,41 @@ def writable(self) -> bool:
179240
180241 def teardown (self ) -> None :
181242 """Restore original logging configuration."""
182- # Restore the original disable level
183- logging .disable (self .original_disable_level )
184-
185- if self .log_handler in self .root_logger .handlers :
186- self .root_logger .removeHandler (self .log_handler )
243+ # Clear the context variable
244+ if self .execution_id :
245+ current_execution_id .set (None )
187246
188- for logger_name in self .patched_loggers :
189- logger = logging .getLogger (logger_name )
190- if self .log_handler in logger .handlers :
191- logger .removeHandler (self .log_handler )
192-
193- self .root_logger .setLevel (self .original_level )
194- for handler in self .original_handlers :
195- if handler not in self .root_logger .handlers :
196- self .root_logger .addHandler (handler )
247+ # Restore the original disable level
248+ if not self .execution_id :
249+ logging .disable (self .original_disable_level )
250+
251+ # Remove our handler and filter
252+ if self .execution_id :
253+ if hasattr (self , "execution_filter" ):
254+ self .log_handler .removeFilter (self .execution_filter )
255+ if self .log_handler in self .root_logger .handlers :
256+ self .root_logger .removeHandler (self .log_handler )
257+
258+ # Remove from stdout/stderr loggers too
259+ stdout_logger = logging .getLogger ("stdout" )
260+ stderr_logger = logging .getLogger ("stderr" )
261+ if self .log_handler in stdout_logger .handlers :
262+ stdout_logger .removeHandler (self .log_handler )
263+ if self .log_handler in stderr_logger .handlers :
264+ stderr_logger .removeHandler (self .log_handler )
265+ else :
266+ if self .log_handler in self .root_logger .handlers :
267+ self .root_logger .removeHandler (self .log_handler )
268+
269+ for logger_name in self .patched_loggers :
270+ logger = logging .getLogger (logger_name )
271+ if self .log_handler in logger .handlers :
272+ logger .removeHandler (self .log_handler )
273+
274+ self .root_logger .setLevel (self .original_level )
275+ for handler in self .original_handlers :
276+ if handler not in self .root_logger .handlers :
277+ self .root_logger .addHandler (handler )
197278
198279 self .log_handler .close ()
199280
0 commit comments