|
2 | 2 | import inspect |
3 | 3 | import json |
4 | 4 | import logging |
| 5 | +from contextvars import ContextVar |
5 | 6 | from functools import wraps |
6 | 7 | from typing import Any, Callable, List, Optional, Tuple |
7 | 8 |
|
8 | | -from opentelemetry import trace |
| 9 | +from opentelemetry import context, trace |
| 10 | +from opentelemetry.trace import set_span_in_context |
9 | 11 |
|
10 | 12 | from ._utils import _SpanUtils |
11 | 13 |
|
12 | 14 | logger = logging.getLogger(__name__) |
13 | 15 |
|
14 | 16 | _tracer_instance: Optional[trace.Tracer] = None |
| 17 | +# ContextVar to track the currently active span for nesting |
| 18 | +_active_traced_span: ContextVar[Optional[trace.Span]] = ContextVar( |
| 19 | + "_active_traced_span", default=None |
| 20 | +) |
15 | 21 |
|
16 | 22 |
|
17 | 23 | def get_tracer() -> trace.Tracer: |
18 | 24 | """Lazily initializes and returns the tracer instance.""" |
19 | 25 | global _tracer_instance |
20 | 26 | if _tracer_instance is None: |
21 | | - logger.warning( |
22 | | - "Initializing tracer instance. This should only be done once per process." |
23 | | - ) |
24 | 27 | _tracer_instance = trace.get_tracer(__name__) |
25 | 28 | return _tracer_instance |
26 | 29 |
|
@@ -53,24 +56,24 @@ def register_current_span_provider( |
53 | 56 | """ |
54 | 57 | cls._current_span_provider = current_span_provider |
55 | 58 |
|
56 | | - @classmethod |
57 | | - def get_parent_context(cls): |
58 | | - """Get the parent context using the registered current span provider. |
| 59 | + @staticmethod |
| 60 | + def get_parent_context(): |
| 61 | + # Always use the currently active OTel span if valid (recursion / children) |
| 62 | + current_span = trace.get_current_span() |
| 63 | + if current_span is not None and current_span.get_span_context().is_valid: |
| 64 | + return set_span_in_context(current_span) |
59 | 65 |
|
60 | | - Returns: |
61 | | - Context object with the current span set, or None if no provider is registered. |
62 | | - """ |
63 | | - if cls._current_span_provider is not None: |
| 66 | + # Only for the very top-level call, fallback to LangGraph span |
| 67 | + if TracingManager._current_span_provider is not None: |
64 | 68 | try: |
65 | | - current_span = cls._current_span_provider() |
66 | | - if current_span is not None: |
67 | | - from opentelemetry.trace import set_span_in_context |
68 | | - |
69 | | - return set_span_in_context(current_span) |
| 69 | + external_span = TracingManager._current_span_provider() |
| 70 | + if external_span is not None: |
| 71 | + return set_span_in_context(external_span) |
70 | 72 | except Exception as e: |
71 | 73 | logger.warning(f"Error getting current span from provider: {e}") |
72 | | - return None |
73 | | - return None |
| 74 | + |
| 75 | + # Last fallback |
| 76 | + return context.get_current() |
74 | 77 |
|
75 | 78 | @classmethod |
76 | 79 | def register_traced_function(cls, original_func, decorated_func, params): |
@@ -176,176 +179,169 @@ def _opentelemetry_traced( |
176 | 179 | """Default tracer implementation using OpenTelemetry.""" |
177 | 180 |
|
178 | 181 | def decorator(func): |
179 | | - trace_name = name if name is not None else func.__name__ |
| 182 | + trace_name = name or func.__name__ |
| 183 | + |
| 184 | + def get_parent_context(): |
| 185 | + """Return a context object for starting the new span.""" |
| 186 | + current_span = _active_traced_span.get() |
| 187 | + if current_span is not None and current_span.get_span_context().is_valid: |
| 188 | + return set_span_in_context(current_span) |
180 | 189 |
|
| 190 | + if TracingManager._current_span_provider is not None: |
| 191 | + try: |
| 192 | + external_span = TracingManager._current_span_provider() |
| 193 | + if external_span is not None: |
| 194 | + return set_span_in_context(external_span) |
| 195 | + except Exception as e: |
| 196 | + logger.warning(f"Error getting current span from provider: {e}") |
| 197 | + |
| 198 | + return context.get_current() |
| 199 | + |
| 200 | + # --------- Sync wrapper --------- |
181 | 201 | @wraps(func) |
182 | 202 | def sync_wrapper(*args, **kwargs): |
183 | | - context = TracingManager.get_parent_context() |
184 | | - |
185 | | - with get_tracer().start_as_current_span( |
186 | | - trace_name, context=context |
187 | | - ) as span: |
188 | | - default_span_type = "function_call_sync" |
189 | | - span.set_attribute( |
190 | | - "span_type", |
191 | | - span_type if span_type is not None else default_span_type, |
192 | | - ) |
| 203 | + ctx = get_parent_context() |
| 204 | + span_cm = get_tracer().start_as_current_span(trace_name, context=ctx) |
| 205 | + span = span_cm.__enter__() |
| 206 | + token = _active_traced_span.set(span) |
| 207 | + try: |
| 208 | + span.set_attribute("span_type", span_type or "function_call_sync") |
193 | 209 | if run_type is not None: |
194 | 210 | span.set_attribute("run_type", run_type) |
195 | 211 |
|
196 | | - # Format arguments for tracing |
197 | 212 | inputs = _SpanUtils.format_args_for_trace_json( |
198 | 213 | inspect.signature(func), *args, **kwargs |
199 | 214 | ) |
200 | | - # Apply input processor if provided |
201 | | - if input_processor is not None: |
| 215 | + if input_processor: |
202 | 216 | processed_inputs = input_processor(json.loads(inputs)) |
203 | 217 | inputs = json.dumps(processed_inputs, default=str) |
204 | 218 | span.set_attribute("inputs", inputs) |
205 | | - try: |
206 | | - result = func(*args, **kwargs) |
207 | | - # Process output if processor is provided |
208 | | - output = result |
209 | | - if output_processor is not None: |
210 | | - output = output_processor(result) |
211 | | - span.set_attribute("output", json.dumps(output, default=str)) |
212 | | - return result |
213 | | - except Exception as e: |
214 | | - span.record_exception(e) |
215 | | - span.set_status( |
216 | | - trace.status.Status(trace.status.StatusCode.ERROR, str(e)) |
217 | | - ) |
218 | | - raise |
219 | 219 |
|
| 220 | + result = func(*args, **kwargs) |
| 221 | + output = output_processor(result) if output_processor else result |
| 222 | + span.set_attribute("output", json.dumps(output, default=str)) |
| 223 | + return result |
| 224 | + except Exception as e: |
| 225 | + span.record_exception(e) |
| 226 | + span.set_status( |
| 227 | + trace.status.Status(trace.status.StatusCode.ERROR, str(e)) |
| 228 | + ) |
| 229 | + raise |
| 230 | + finally: |
| 231 | + _active_traced_span.reset(token) |
| 232 | + span_cm.__exit__(None, None, None) |
| 233 | + |
| 234 | + # --------- Async wrapper --------- |
220 | 235 | @wraps(func) |
221 | 236 | async def async_wrapper(*args, **kwargs): |
222 | | - context = TracingManager.get_parent_context() |
223 | | - |
224 | | - with get_tracer().start_as_current_span( |
225 | | - trace_name, context=context |
226 | | - ) as span: |
227 | | - default_span_type = "function_call_async" |
228 | | - span.set_attribute( |
229 | | - "span_type", |
230 | | - span_type if span_type is not None else default_span_type, |
231 | | - ) |
| 237 | + ctx = get_parent_context() |
| 238 | + span_cm = get_tracer().start_as_current_span(trace_name, context=ctx) |
| 239 | + span = span_cm.__enter__() |
| 240 | + token = _active_traced_span.set(span) |
| 241 | + try: |
| 242 | + span.set_attribute("span_type", span_type or "function_call_async") |
232 | 243 | if run_type is not None: |
233 | 244 | span.set_attribute("run_type", run_type) |
234 | 245 |
|
235 | | - # Format arguments for tracing |
236 | 246 | inputs = _SpanUtils.format_args_for_trace_json( |
237 | 247 | inspect.signature(func), *args, **kwargs |
238 | 248 | ) |
239 | | - # Apply input processor if provided |
240 | | - if input_processor is not None: |
| 249 | + if input_processor: |
241 | 250 | processed_inputs = input_processor(json.loads(inputs)) |
242 | 251 | inputs = json.dumps(processed_inputs, default=str) |
243 | 252 | span.set_attribute("inputs", inputs) |
244 | | - try: |
245 | | - result = await func(*args, **kwargs) |
246 | | - # Process output if processor is provided |
247 | | - output = result |
248 | | - if output_processor is not None: |
249 | | - output = output_processor(result) |
250 | | - span.set_attribute("output", json.dumps(output, default=str)) |
251 | | - return result |
252 | | - except Exception as e: |
253 | | - span.record_exception(e) |
254 | | - span.set_status( |
255 | | - trace.status.Status(trace.status.StatusCode.ERROR, str(e)) |
256 | | - ) |
257 | | - raise |
258 | 253 |
|
| 254 | + result = await func(*args, **kwargs) |
| 255 | + output = output_processor(result) if output_processor else result |
| 256 | + span.set_attribute("output", json.dumps(output, default=str)) |
| 257 | + return result |
| 258 | + except Exception as e: |
| 259 | + span.record_exception(e) |
| 260 | + span.set_status( |
| 261 | + trace.status.Status(trace.status.StatusCode.ERROR, str(e)) |
| 262 | + ) |
| 263 | + raise |
| 264 | + finally: |
| 265 | + _active_traced_span.reset(token) |
| 266 | + span_cm.__exit__(None, None, None) |
| 267 | + |
| 268 | + # --------- Generator wrapper --------- |
259 | 269 | @wraps(func) |
260 | 270 | def generator_wrapper(*args, **kwargs): |
261 | | - context = TracingManager.get_parent_context() |
262 | | - |
263 | | - with get_tracer().start_as_current_span( |
264 | | - trace_name, context=context |
265 | | - ) as span: |
266 | | - span.get_span_context() |
267 | | - default_span_type = "function_call_generator_sync" |
| 271 | + ctx = get_parent_context() |
| 272 | + span_cm = get_tracer().start_as_current_span(trace_name, context=ctx) |
| 273 | + span = span_cm.__enter__() |
| 274 | + token = _active_traced_span.set(span) |
| 275 | + try: |
268 | 276 | span.set_attribute( |
269 | | - "span_type", |
270 | | - span_type if span_type is not None else default_span_type, |
| 277 | + "span_type", span_type or "function_call_generator_sync" |
271 | 278 | ) |
272 | 279 | if run_type is not None: |
273 | 280 | span.set_attribute("run_type", run_type) |
274 | 281 |
|
275 | | - # Format arguments for tracing |
276 | 282 | inputs = _SpanUtils.format_args_for_trace_json( |
277 | 283 | inspect.signature(func), *args, **kwargs |
278 | 284 | ) |
279 | | - # Apply input processor if provided |
280 | | - if input_processor is not None: |
| 285 | + if input_processor: |
281 | 286 | processed_inputs = input_processor(json.loads(inputs)) |
282 | 287 | inputs = json.dumps(processed_inputs, default=str) |
283 | 288 | span.set_attribute("inputs", inputs) |
| 289 | + |
284 | 290 | outputs = [] |
285 | | - try: |
286 | | - for item in func(*args, **kwargs): |
287 | | - outputs.append(item) |
288 | | - span.add_event(f"Yielded: {item}") # Add event for each yield |
289 | | - yield item |
290 | | - |
291 | | - # Process output if processor is provided |
292 | | - output_to_record = outputs |
293 | | - if output_processor is not None: |
294 | | - output_to_record = output_processor(outputs) |
295 | | - span.set_attribute( |
296 | | - "output", json.dumps(output_to_record, default=str) |
297 | | - ) |
298 | | - except Exception as e: |
299 | | - span.record_exception(e) |
300 | | - span.set_status( |
301 | | - trace.status.Status(trace.status.StatusCode.ERROR, str(e)) |
302 | | - ) |
303 | | - raise |
| 291 | + for item in func(*args, **kwargs): |
| 292 | + outputs.append(item) |
| 293 | + span.add_event(f"Yielded: {item}") |
| 294 | + yield item |
| 295 | + output = output_processor(outputs) if output_processor else outputs |
| 296 | + span.set_attribute("output", json.dumps(output, default=str)) |
| 297 | + except Exception as e: |
| 298 | + span.record_exception(e) |
| 299 | + span.set_status( |
| 300 | + trace.status.Status(trace.status.StatusCode.ERROR, str(e)) |
| 301 | + ) |
| 302 | + raise |
| 303 | + finally: |
| 304 | + _active_traced_span.reset(token) |
| 305 | + span_cm.__exit__(None, None, None) |
304 | 306 |
|
| 307 | + # --------- Async generator wrapper --------- |
305 | 308 | @wraps(func) |
306 | 309 | async def async_generator_wrapper(*args, **kwargs): |
307 | | - context = TracingManager.get_parent_context() |
308 | | - |
309 | | - with get_tracer().start_as_current_span( |
310 | | - trace_name, context=context |
311 | | - ) as span: |
312 | | - default_span_type = "function_call_generator_async" |
| 310 | + ctx = get_parent_context() |
| 311 | + span_cm = get_tracer().start_as_current_span(trace_name, context=ctx) |
| 312 | + span = span_cm.__enter__() |
| 313 | + token = _active_traced_span.set(span) |
| 314 | + try: |
313 | 315 | span.set_attribute( |
314 | | - "span_type", |
315 | | - span_type if span_type is not None else default_span_type, |
| 316 | + "span_type", span_type or "function_call_generator_async" |
316 | 317 | ) |
317 | 318 | if run_type is not None: |
318 | 319 | span.set_attribute("run_type", run_type) |
319 | 320 |
|
320 | | - # Format arguments for tracing |
321 | 321 | inputs = _SpanUtils.format_args_for_trace_json( |
322 | 322 | inspect.signature(func), *args, **kwargs |
323 | 323 | ) |
324 | | - # Apply input processor if provided |
325 | | - if input_processor is not None: |
| 324 | + if input_processor: |
326 | 325 | processed_inputs = input_processor(json.loads(inputs)) |
327 | 326 | inputs = json.dumps(processed_inputs, default=str) |
328 | 327 | span.set_attribute("inputs", inputs) |
| 328 | + |
329 | 329 | outputs = [] |
330 | | - try: |
331 | | - async for item in func(*args, **kwargs): |
332 | | - outputs.append(item) |
333 | | - span.add_event(f"Yielded: {item}") # Add event for each yield |
334 | | - yield item |
335 | | - |
336 | | - # Process output if processor is provided |
337 | | - output_to_record = outputs |
338 | | - if output_processor is not None: |
339 | | - output_to_record = output_processor(outputs) |
340 | | - span.set_attribute( |
341 | | - "output", json.dumps(output_to_record, default=str) |
342 | | - ) |
343 | | - except Exception as e: |
344 | | - span.record_exception(e) |
345 | | - span.set_status( |
346 | | - trace.status.Status(trace.status.StatusCode.ERROR, str(e)) |
347 | | - ) |
348 | | - raise |
| 330 | + async for item in func(*args, **kwargs): |
| 331 | + outputs.append(item) |
| 332 | + span.add_event(f"Yielded: {item}") |
| 333 | + yield item |
| 334 | + output = output_processor(outputs) if output_processor else outputs |
| 335 | + span.set_attribute("output", json.dumps(output, default=str)) |
| 336 | + except Exception as e: |
| 337 | + span.record_exception(e) |
| 338 | + span.set_status( |
| 339 | + trace.status.Status(trace.status.StatusCode.ERROR, str(e)) |
| 340 | + ) |
| 341 | + raise |
| 342 | + finally: |
| 343 | + _active_traced_span.reset(token) |
| 344 | + span_cm.__exit__(None, None, None) |
349 | 345 |
|
350 | 346 | if inspect.iscoroutinefunction(func): |
351 | 347 | return async_wrapper |
|
0 commit comments