- SDK: client with `BatchTransport`, a `trace` decorator/context manager, `log_decision`, a thread-local context stack, and nested trace → span support.
- API: `POST /api/traces` (batch ingest), `GET /api/traces` (paginated list), `GET /api/traces/[id]` (full trace with relations), and `GET /api/health`.
- Tests: 8 unit tests for the SDK (all passing).
- Transport: thread-safe buffer with a background flush thread.
192 lines
5.4 KiB
Python
192 lines
5.4 KiB
Python
"""Trace decorator and context manager for instrumenting agent functions."""
|
|
|
|
import asyncio
import inspect
import logging
import threading
import time
from functools import wraps
from typing import Any, Callable, Dict, List, Optional, Union

from agentlens.models import (
    TraceData,
    Span,
    SpanType,
    TraceStatus,
    SpanStatus,
    Event,
    EventType,
    _now_iso,
)
|
|
|
|
# Logger shared by the agentlens SDK modules (name: "agentlens").
logger: logging.Logger = logging.getLogger("agentlens")

# Per-thread storage; ``_get_context_stack`` lazily attaches a ``stack`` list
# of active traces/spans to it the first time each thread asks for one.
_context: threading.local = threading.local()
|
|
|
|
|
|
def _get_context_stack() -> List[Union[TraceData, Span]]:
    """Return the calling thread's stack of active traces/spans.

    The list lives on the module-level ``threading.local`` object, so every
    thread gets its own stack. It is created empty on first access, and callers
    push/pop entries on the returned list in place.
    """
    try:
        return _context.stack
    except AttributeError:
        # First access from this thread: install an empty stack.
        fresh_stack = []
        _context.stack = fresh_stack
        return fresh_stack
|
|
|
|
|
|
def get_current_trace() -> Optional[TraceData]:
    """Return the outermost ``TraceData`` on this thread's context stack.

    The stack is scanned bottom-up, so the earliest-pushed trace wins.
    Returns ``None`` when the stack is empty or holds no ``TraceData``.
    """
    traces_outermost_first = (
        entry for entry in _get_context_stack() if isinstance(entry, TraceData)
    )
    return next(traces_outermost_first, None)
|
|
|
|
|
|
def get_current_span_id() -> Optional[str]:
    """Return the ``id`` of the innermost ``Span`` on this thread's context stack.

    The stack is scanned top-down, and non-``Span`` entries (e.g. the root
    ``TraceData``) are skipped. Returns ``None`` when no span is active.
    """
    spans_innermost_first = (
        entry for entry in reversed(_get_context_stack()) if isinstance(entry, Span)
    )
    innermost = next(spans_innermost_first, None)
    return None if innermost is None else innermost.id
|
|
|
|
|
|
class TraceContext:
    """Record a root trace, or a nested span, around a block or function call.

    Usable as a context manager (``with TraceContext("name"): ...``) or as a
    decorator. On ``__enter__``:

    * if this thread's context stack is empty, a root :class:`TraceData` with
      ``RUNNING`` status is created and pushed onto the stack;
    * otherwise an ``AGENT`` :class:`Span` is created, parented to the innermost
      active span (if any), appended to the current trace's ``spans`` and pushed.

    On ``__exit__`` the span or trace is finalized (status, duration, end time)
    and removed from the stack. Only a root trace is handed to the client via
    ``send_trace``; nested spans travel inside their enclosing trace.

    NOTE(review): the stack comes from ``_get_context_stack`` and is per
    *thread*, so coroutines interleaved on one event-loop thread share it and
    can be mis-parented. A ``contextvars``-based stack would isolate them — TODO.
    """

    def __init__(
        self,
        name: str,
        tags: Optional[List[str]] = None,
        session_id: Optional[str] = None,
        metadata: Optional[Dict[str, Any]] = None,
    ) -> None:
        """Store settings; nothing is recorded until the context is entered.

        Args:
            name: Trace or span name; an empty value falls back to ``"trace"``.
            tags: Tags for a root trace (not used for nested spans).
            session_id: Session identifier for a root trace (not used for
                nested spans).
            metadata: Metadata for a root trace (not used for nested spans).
        """
        self.name = name or "trace"
        self.tags = tags or []
        self.session_id = session_id
        self.metadata = metadata
        # Per-entry state, (re)initialized by __enter__.
        self._trace_data: Optional[TraceData] = None
        self._span: Optional[Span] = None
        self._start_time: float = 0  # time.perf_counter() value taken at __enter__
        self._is_nested: bool = False

    def __enter__(self) -> "TraceContext":
        """Push a new root trace or nested span onto this thread's context stack."""
        # Reset per-entry state so a reused instance never finalizes a stale
        # span or trace left over from a previous ``with`` block.
        self._trace_data = None
        self._span = None
        self._is_nested = False
        # perf_counter is monotonic: durations are unaffected by wall-clock changes.
        self._start_time = time.perf_counter()

        stack = _get_context_stack()
        if stack:
            self._is_nested = True
            parent_trace = get_current_trace()
            parent_span_id = get_current_span_id()
            self._span = Span(
                name=self.name,
                type=SpanType.AGENT.value,
                parent_span_id=parent_span_id,
                started_at=_now_iso(),
            )
            if parent_trace is not None:
                parent_trace.spans.append(self._span)
            stack.append(self._span)
        else:
            self._trace_data = TraceData(
                name=self.name,
                # Copy so traces never alias this instance's containers, which a
                # decorator shares across every call of the wrapped function.
                tags=list(self.tags),
                session_id=self.session_id,
                metadata=dict(self.metadata) if self.metadata is not None else None,
                status=TraceStatus.RUNNING.value,
                started_at=_now_iso(),
            )
            stack.append(self._trace_data)

        return self

    def __exit__(
        self,
        exc_type: Optional[type],
        exc_val: Optional[BaseException],
        exc_tb: Optional[Any],
    ) -> None:
        """Finalize and pop this context's span or trace; send a root trace.

        Returns ``None``, so an exception raised in the ``with`` body still
        propagates to the caller.
        """
        # Imported lazily, presumably to avoid an import cycle with the client module.
        from agentlens.client import get_client

        duration_ms = int((time.perf_counter() - self._start_time) * 1000)

        if self._is_nested and self._span is not None:
            span = self._span
            span.status = (
                SpanStatus.COMPLETED.value
                if exc_type is None
                else SpanStatus.ERROR.value
            )
            span.duration_ms = duration_ms
            span.ended_at = _now_iso()
            self._remove_from_stack(span)
            return

        trace_data = self._trace_data
        if trace_data is None:
            # Never entered: nothing to record.
            return

        if exc_type is None:
            trace_data.status = TraceStatus.COMPLETED.value
        else:
            trace_data.status = TraceStatus.ERROR.value
            message = str(exc_val) if exc_val is not None else ""
            trace_data.events.append(
                Event(
                    type=EventType.ERROR.value,
                    # Exceptions raised without a message (``raise ValueError()``)
                    # stringify to ""; name the event after the class instead.
                    name=message or getattr(exc_type, "__name__", "Unknown error"),
                )
            )

        trace_data.total_duration = duration_ms
        trace_data.ended_at = _now_iso()
        self._remove_from_stack(trace_data)

        client = get_client()
        if client:
            try:
                client.send_trace(trace_data)
            except Exception:
                # Telemetry is best-effort: a delivery failure must not fail the
                # traced code or mask the exception it raised.
                logger.exception("Failed to send trace %s", trace_data.id)

    @staticmethod
    def _remove_from_stack(item: object) -> None:
        """Remove *item* (compared by identity) from this thread's context stack.

        When contexts are properly nested, *item* is on top and this is a plain
        pop. If contexts exit out of order, only *item* is removed rather than
        whatever happens to be on top. A missing *item* is logged, not raised.
        """
        stack = _get_context_stack()
        for depth, entry in enumerate(reversed(stack), start=1):
            if entry is item:
                del stack[-depth]
                return
        logger.warning("Context stack out of sync: %s not found", type(item).__name__)

    def _fresh_copy(self) -> "TraceContext":
        """Return a new, not-yet-entered context with this instance's settings."""
        return TraceContext(
            name=self.name,
            tags=self.tags,
            session_id=self.session_id,
            metadata=self.metadata,
        )

    def __call__(self, func: Callable[..., Any]) -> Callable[..., Any]:
        """Decorate *func* so that every call runs inside a fresh ``TraceContext``.

        A new context per call keeps concurrent or recursive invocations from
        sharing this instance's per-entry state. Coroutine functions get an
        ``async`` wrapper so the trace spans the awaited execution.
        """
        # inspect.iscoroutinefunction replaces asyncio.iscoroutinefunction,
        # which is deprecated as of Python 3.14.
        if inspect.iscoroutinefunction(func):

            @wraps(func)
            async def async_wrapper(*args: Any, **kwargs: Any) -> Any:
                with self._fresh_copy():
                    return await func(*args, **kwargs)

            return async_wrapper

        @wraps(func)
        def sync_wrapper(*args: Any, **kwargs: Any) -> Any:
            with self._fresh_copy():
                return func(*args, **kwargs)

        return sync_wrapper

    @property
    def trace_id(self) -> Optional[str]:
        """Id of the root trace this context created, else ``None``.

        ``None`` is returned before entry and when the context created a nested
        span instead of a root trace.
        """
        if self._trace_data:
            return self._trace_data.id
        return None
|
|
|
|
|
|
def trace(
    name: Union[Callable[..., Any], str, None] = None,
    tags: Optional[List[str]] = None,
    session_id: Optional[str] = None,
    metadata: Optional[Dict[str, Any]] = None,
) -> Union[TraceContext, Callable[..., Any]]:
    """Create a trace context manager, or decorate a function with one.

    Supports three forms::

        @trace                      # bare decorator: trace named after the function
        @trace("name", tags=[...])  # parametrized decorator
        with trace("name"): ...     # context manager

    Args:
        name: Trace name, or the function being decorated when ``@trace`` is
            used without parentheses. An empty value or ``None`` falls back
            to ``"trace"``.
        tags: Tags attached to a root trace.
        session_id: Session identifier attached to a root trace.
        metadata: Extra metadata attached to a root trace.

    Returns:
        The wrapped function for the bare-decorator form; otherwise a
        :class:`TraceContext` usable as a context manager or decorator.
    """
    if callable(name):
        func = name
        # Callables such as functools.partial objects or instances defining
        # __call__ have no __name__; use the class name instead of raising.
        func_name = getattr(func, "__name__", None) or type(func).__name__
        ctx = TraceContext(
            name=func_name, tags=tags, session_id=session_id, metadata=metadata
        )
        return ctx(func)

    return TraceContext(
        name=name or "trace", tags=tags, session_id=session_id, metadata=metadata
    )
|