feat: add prompt observability code

This commit is contained in:
Eduardo Chiarotti
2025-02-04 11:47:58 -03:00
parent 23b9e10323
commit d263540325
9 changed files with 783 additions and 18 deletions

View File

@@ -38,6 +38,7 @@ from crewai.tasks.task_output import TaskOutput
from crewai.telemetry import Telemetry
from crewai.tools.agent_tools.agent_tools import AgentTools
from crewai.tools.base_tool import Tool
from crewai.traces.unified_trace_controller import init_crew_main_trace
from crewai.types.usage_metrics import UsageMetrics
from crewai.utilities import I18N, FileHandler, Logger, RPMController
from crewai.utilities.constants import TRAINING_DATA_FILE
@@ -515,6 +516,7 @@ class Crew(BaseModel):
CrewTrainingHandler(filename).clear()
raise
@init_crew_main_trace
def kickoff(
self,
inputs: Optional[Dict[str, Any]] = None,

View File

@@ -29,6 +29,10 @@ from crewai.flow.flow_visualizer import plot_flow
from crewai.flow.persistence.base import FlowPersistence
from crewai.flow.utils import get_possible_return_constants
from crewai.telemetry import Telemetry
from crewai.traces.unified_trace_controller import (
init_flow_main_trace,
trace_flow_step,
)
from crewai.utilities.printer import Printer
logger = logging.getLogger(__name__)
@@ -394,7 +398,6 @@ class FlowMeta(type):
or hasattr(attr_value, "__trigger_methods__")
or hasattr(attr_value, "__is_router__")
):
# Register start methods
if hasattr(attr_value, "__is_start_method__"):
start_methods.append(attr_name)
@@ -600,7 +603,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
```
"""
try:
if not hasattr(self, '_state'):
if not hasattr(self, "_state"):
return ""
if isinstance(self._state, dict):
@@ -706,26 +709,31 @@ class Flow(Generic[T], metaclass=FlowMeta):
inputs: Optional dictionary containing input values and potentially a state ID to restore
"""
# Handle state restoration if ID is provided in inputs
if inputs and 'id' in inputs and self._persistence is not None:
restore_uuid = inputs['id']
if inputs and "id" in inputs and self._persistence is not None:
restore_uuid = inputs["id"]
stored_state = self._persistence.load_state(restore_uuid)
# Override the id in the state if it exists in inputs
if 'id' in inputs:
if "id" in inputs:
if isinstance(self._state, dict):
self._state['id'] = inputs['id']
self._state["id"] = inputs["id"]
elif isinstance(self._state, BaseModel):
setattr(self._state, 'id', inputs['id'])
setattr(self._state, "id", inputs["id"])
if stored_state:
self._log_flow_event(f"Loading flow state from memory for UUID: {restore_uuid}", color="yellow")
self._log_flow_event(
f"Loading flow state from memory for UUID: {restore_uuid}",
color="yellow",
)
# Restore the state
self._restore_state(stored_state)
else:
self._log_flow_event(f"No flow state found for UUID: {restore_uuid}", color="red")
self._log_flow_event(
f"No flow state found for UUID: {restore_uuid}", color="red"
)
# Apply any additional inputs after restoration
filtered_inputs = {k: v for k, v in inputs.items() if k != 'id'}
filtered_inputs = {k: v for k, v in inputs.items() if k != "id"}
if filtered_inputs:
self._initialize_state(filtered_inputs)
@@ -737,13 +745,16 @@ class Flow(Generic[T], metaclass=FlowMeta):
flow_name=self.__class__.__name__,
),
)
self._log_flow_event(f"Flow started with ID: {self.flow_id}", color="bold_magenta")
self._log_flow_event(
f"Flow started with ID: {self.flow_id}", color="bold_magenta"
)
if inputs is not None and 'id' not in inputs:
if inputs is not None and "id" not in inputs:
self._initialize_state(inputs)
return asyncio.run(self.kickoff_async())
@init_flow_main_trace
async def kickoff_async(self, inputs: Optional[Dict[str, Any]] = None) -> Any:
if not self._start_methods:
raise ValueError("No start method defined")
@@ -793,6 +804,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
)
await self._execute_listeners(start_method_name, result)
@trace_flow_step
async def _execute_method(
self, method_name: str, method: Callable, *args: Any, **kwargs: Any
) -> Any:
@@ -984,7 +996,9 @@ class Flow(Generic[T], metaclass=FlowMeta):
traceback.print_exc()
def _log_flow_event(self, message: str, color: str = "yellow", level: str = "info") -> None:
def _log_flow_event(
self, message: str, color: str = "yellow", level: str = "info"
) -> None:
"""Centralized logging method for flow events.
This method provides a consistent interface for logging flow-related events,

View File

@@ -1,3 +1,4 @@
import inspect
import json
import logging
import os
@@ -5,7 +6,17 @@ import sys
import threading
import warnings
from contextlib import contextmanager
from typing import Any, Dict, List, Optional, Union, cast
from typing import (
Any,
Dict,
List,
Optional,
Protocol,
Tuple,
Union,
cast,
runtime_checkable,
)
from dotenv import load_dotenv
@@ -16,6 +27,7 @@ with warnings.catch_warnings():
from litellm.types.utils import ModelResponse
from crewai.traces.unified_trace_controller import trace_llm_call
from crewai.utilities.exceptions.context_window_exceeding_exception import (
LLMContextLengthExceededException,
)
@@ -114,6 +126,17 @@ def suppress_warnings():
sys.stderr = old_stderr
@runtime_checkable
class AgentExecutorProtocol(Protocol):
"""Protocol defining the expected interface for an agent executor."""
@property
def agent(self) -> Any: ...
@property
def task(self) -> Any: ...
class LLM:
def __init__(
self,
@@ -160,6 +183,7 @@ class LLM:
self.callbacks = callbacks
self.context_window_size = 0
self.additional_params = kwargs
self._message_history = []
litellm.drop_params = True
@@ -174,6 +198,12 @@ class LLM:
self.set_callbacks(callbacks)
self.set_env_callbacks()
@trace_llm_call
def _call_llm(self, params: Dict[str, Any]) -> Any:
with suppress_warnings():
response = litellm.completion(**params)
return response
def call(
self,
messages: Union[str, List[Dict[str, str]]],
@@ -249,7 +279,7 @@ class LLM:
params = {k: v for k, v in params.items() if v is not None}
# --- 2) Make the completion call
response = litellm.completion(**params)
response = self._call_llm(params)
response_message = cast(Choices, cast(ModelResponse, response).choices)[
0
].message
@@ -394,3 +424,88 @@ class LLM:
litellm.success_callback = success_callbacks
litellm.failure_callback = failure_callbacks
def _get_execution_context(self) -> Tuple[Optional[Any], Optional[Any]]:
"""Get the agent and task from the execution context.
Returns:
tuple: (agent, task) from any AgentExecutor context, or (None, None) if not found
"""
frame = inspect.currentframe()
caller_frame = frame.f_back if frame else None
agent = None
task = None
while caller_frame:
if "self" in caller_frame.f_locals:
caller_self = caller_frame.f_locals["self"]
if isinstance(caller_self, AgentExecutorProtocol):
agent = caller_self.agent
task = caller_self.task
break
caller_frame = caller_frame.f_back
return agent, task
def _get_new_messages(self, messages: List[Dict[str, str]]) -> List[Dict[str, str]]:
"""Get only the new messages that haven't been processed before."""
if not hasattr(self, "_message_history"):
self._message_history = []
new_messages = []
for message in messages:
message_key = (message["role"], message["content"])
if message_key not in [
(m["role"], m["content"]) for m in self._message_history
]:
new_messages.append(message)
self._message_history.append(message)
return new_messages
def _get_new_tool_results(self, agent) -> List[Dict]:
"""Get only the new tool results that haven't been processed before."""
if not agent or not agent.tools_results:
return []
if not hasattr(self, "_tool_results_history"):
self._tool_results_history = []
new_tool_results = []
for result in agent.tools_results:
# Process tool arguments to extract actual values
processed_args = {}
if isinstance(result["tool_args"], dict):
for key, value in result["tool_args"].items():
if isinstance(value, dict) and "type" in value:
# Skip metadata and just store the actual value
continue
processed_args[key] = value
# Create a clean result with processed arguments
clean_result = {
"tool_name": result["tool_name"],
"tool_args": processed_args,
"result": result["result"],
"content": result.get("content", ""),
"start_time": result.get("start_time", ""),
}
# Check if this exact tool execution exists in history
is_duplicate = False
for history_result in self._tool_results_history:
if (
clean_result["tool_name"] == history_result["tool_name"]
and str(clean_result["tool_args"])
== str(history_result["tool_args"])
and str(clean_result["result"]) == str(history_result["result"])
and clean_result["content"] == history_result.get("content", "")
):
is_duplicate = True
break
if not is_duplicate:
new_tool_results.append(clean_result)
self._tool_results_history.append(clean_result)
return new_tool_results

View File

@@ -2,6 +2,7 @@ import ast
import datetime
import json
import time
from datetime import UTC
from difflib import SequenceMatcher
from json import JSONDecodeError
from textwrap import dedent
@@ -116,7 +117,10 @@ class ToolUsage:
self._printer.print(content=f"\n\n{error}\n", color="red")
return error
if isinstance(tool, CrewStructuredTool) and tool.name == self._i18n.tools("add_image")["name"]: # type: ignore
if (
isinstance(tool, CrewStructuredTool)
and tool.name == self._i18n.tools("add_image")["name"]
): # type: ignore
try:
result = self._use(tool_string=tool_string, tool=tool, calling=calling)
return result
@@ -153,7 +157,7 @@ class ToolUsage:
except Exception:
self.task.increment_tools_errors()
started_at = time.time()
started_at = datetime.datetime.now(UTC)
from_cache = False
result = None # type: ignore # Incompatible types in assignment (expression has type "None", variable has type "str")
@@ -181,7 +185,9 @@ class ToolUsage:
if calling.arguments:
try:
acceptable_args = tool.args_schema.model_json_schema()["properties"].keys() # type: ignore
acceptable_args = tool.args_schema.model_json_schema()[
"properties"
].keys() # type: ignore
arguments = {
k: v
for k, v in calling.arguments.items()
@@ -244,6 +250,7 @@ class ToolUsage:
"result": result,
"tool_name": tool.name,
"tool_args": calling.arguments,
"start_time": started_at,
}
self.on_tool_use_finished(

View File

View File

@@ -0,0 +1,39 @@
from contextlib import contextmanager
from contextvars import ContextVar
from typing import Generator
class TraceContext:
"""Maintains the current trace context throughout the execution stack.
This class provides a context manager for tracking trace execution across
async and sync code paths using ContextVars.
"""
_context: ContextVar = ContextVar("trace_context", default=None)
@classmethod
def get_current(cls):
"""Get the current trace context.
Returns:
Optional[UnifiedTraceController]: The current trace controller or None if not set.
"""
return cls._context.get()
@classmethod
@contextmanager
def set_current(cls, trace):
"""Set the current trace context within a context manager.
Args:
trace: The trace controller to set as current.
Yields:
UnifiedTraceController: The current trace controller.
"""
token = cls._context.set(trace)
try:
yield trace
finally:
cls._context.reset(token)

View File

@@ -0,0 +1,19 @@
from enum import Enum
class TraceType(Enum):
LLM_CALL = "llm_call"
TOOL_CALL = "tool_call"
FLOW_STEP = "flow_step"
START_CALL = "start_call"
class RunType(Enum):
KICKOFF = "kickoff"
TRAIN = "train"
TEST = "test"
class CrewType(Enum):
CREW = "crew"
FLOW = "flow"

View File

@@ -0,0 +1,89 @@
from datetime import datetime
from typing import Any, Dict, List, Optional
from pydantic import BaseModel, Field
class ToolCall(BaseModel):
"""Model representing a tool call during execution"""
name: str
arguments: Dict[str, Any]
output: str
start_time: datetime
end_time: Optional[datetime] = None
latency_ms: Optional[int] = None
error: Optional[str] = None
class LLMRequest(BaseModel):
"""Model representing the LLM request details"""
model: str
messages: List[Dict[str, str]]
temperature: Optional[float] = None
max_tokens: Optional[int] = None
stop_sequences: Optional[List[str]] = None
additional_params: Dict[str, Any] = Field(default_factory=dict)
class LLMResponse(BaseModel):
"""Model representing the LLM response details"""
content: str
finish_reason: Optional[str] = None
class FlowStepIO(BaseModel):
"""Model representing flow step input/output details"""
function_name: str
inputs: Dict[str, Any] = Field(default_factory=dict)
outputs: Any
metadata: Dict[str, Any] = Field(default_factory=dict)
class CrewTrace(BaseModel):
"""Model for tracking detailed information about LLM interactions and Flow steps"""
deployment_instance_id: Optional[str] = Field(
description="ID of the deployment instance"
)
trace_id: str = Field(description="Unique identifier for this trace")
run_id: str = Field(description="Identifier for the execution run")
agent_role: Optional[str] = Field(description="Role of the agent")
task_id: Optional[str] = Field(description="ID of the current task being executed")
task_name: Optional[str] = Field(description="Name of the current task")
task_description: Optional[str] = Field(
description="Description of the current task"
)
trace_type: str = Field(description="Type of the trace")
crew_type: str = Field(description="Type of the crew")
run_type: str = Field(description="Type of the run")
# Timing information
start_time: Optional[datetime] = None
end_time: Optional[datetime] = None
latency_ms: Optional[int] = None
# Request/Response for LLM calls
request: Optional[LLMRequest] = None
response: Optional[LLMResponse] = None
# Input/Output for Flow steps
flow_step: Optional[FlowStepIO] = None
# Tool usage
tool_calls: List[ToolCall] = Field(default_factory=list)
# Metrics
tokens_used: Optional[int] = None
prompt_tokens: Optional[int] = None
completion_tokens: Optional[int] = None
cost: Optional[float] = None
# Additional metadata
status: str = "running" # running, completed, error
error: Optional[str] = None
metadata: Dict[str, Any] = Field(default_factory=dict)
tags: List[str] = Field(default_factory=list)

View File

@@ -0,0 +1,480 @@
import inspect
import os
from datetime import UTC, datetime
from enum import Enum
from functools import wraps
from typing import Any, Awaitable, Callable, Dict, Generator, List, Optional
from uuid import uuid4
from crewai.traces.context import TraceContext
from crewai.traces.enums import CrewType, RunType, TraceType
from crewai.traces.models import (
CrewTrace,
FlowStepIO,
LLMRequest,
LLMResponse,
ToolCall,
)
class UnifiedTraceController:
"""Controls and manages trace execution and recording.
This class handles the lifecycle of traces including creation, execution tracking,
and recording of results for various types of operations (LLM calls, tool calls, flow steps).
"""
def __init__(
self,
trace_type: TraceType,
run_type: RunType,
crew_type: CrewType,
run_id: str,
deployment_instance_id: str = os.environ.get(
"CREWAI_DEPLOYMENT_INSTANCE_ID", ""
),
parent_trace_id: Optional[str] = None,
agent_role: Optional[str] = None,
task_name: Optional[str] = None,
task_description: Optional[str] = None,
task_id: Optional[str] = None,
flow_step: Dict[str, Any] = {},
tool_calls: List[ToolCall] = [],
**context: Any,
) -> None:
"""Initialize a new trace controller.
Args:
trace_type: Type of trace being recorded.
run_type: Type of run being executed.
crew_type: Type of crew executing the trace.
run_id: Unique identifier for the run.
deployment_instance_id: Optional deployment instance identifier.
parent_trace_id: Optional parent trace identifier for nested traces.
agent_role: Role of the agent executing the trace.
task_name: Optional name of the task being executed.
task_description: Optional description of the task.
task_id: Optional unique identifier for the task.
flow_step: Optional flow step information.
tool_calls: Optional list of tool calls made during execution.
**context: Additional context parameters.
"""
self.trace_id = str(uuid4())
self.run_id = run_id
self.parent_trace_id = parent_trace_id
self.trace_type = trace_type
self.run_type = run_type
self.crew_type = crew_type
self.context = context
self.agent_role = agent_role
self.task_name = task_name
self.task_description = task_description
self.task_id = task_id
self.deployment_instance_id = deployment_instance_id
self.children: List[Dict[str, Any]] = []
self.start_time: Optional[datetime] = None
self.end_time: Optional[datetime] = None
self.error: Optional[str] = None
self.tool_calls = tool_calls
self.flow_step = flow_step
self.status: str = "running"
def start_trace(self) -> "UnifiedTraceController":
"""Start the trace execution.
Returns:
UnifiedTraceController: Self for method chaining.
"""
self.start_time = datetime.now(UTC)
return self
def end_trace(self, result: Any = None, error: Optional[str] = None) -> None:
"""End the trace execution and record results.
Args:
result: Optional result from the trace execution.
error: Optional error message if the trace failed.
"""
self.end_time = datetime.now(UTC)
self.status = "error" if error else "completed"
self.error = error
self._record_trace(result)
def add_child_trace(self, child_trace: Dict[str, Any]) -> None:
"""Add a child trace to this trace's execution history.
Args:
child_trace: The child trace information to add.
"""
self.children.append(child_trace)
def to_crew_trace(self) -> CrewTrace:
"""Convert to CrewTrace format for storage.
Returns:
CrewTrace: The trace data in CrewTrace format.
"""
latency_ms = None
if self.tool_calls and hasattr(self.tool_calls[0], "start_time"):
self.start_time = self.tool_calls[0].start_time
if self.start_time and self.end_time:
latency_ms = int((self.end_time - self.start_time).total_seconds() * 1000)
request = None
response = None
flow_step_obj = None
if self.trace_type in [TraceType.LLM_CALL, TraceType.TOOL_CALL]:
request = LLMRequest(
model=self.context.get("model", "unknown"),
messages=self.context.get("messages", []),
temperature=self.context.get("temperature"),
max_tokens=self.context.get("max_tokens"),
stop_sequences=self.context.get("stop_sequences"),
)
if "response" in self.context:
response = LLMResponse(
content=self.context["response"].get("content", ""),
finish_reason=self.context["response"].get("finish_reason"),
)
elif self.trace_type == TraceType.FLOW_STEP:
flow_step_obj = FlowStepIO(
function_name=self.flow_step.get("function_name", "unknown"),
inputs=self.flow_step.get("inputs", {}),
outputs={"result": self.context.get("response")},
metadata=self.flow_step.get("metadata", {}),
)
return CrewTrace(
deployment_instance_id=self.deployment_instance_id,
trace_id=self.trace_id,
task_id=self.task_id,
run_id=self.run_id,
agent_role=self.agent_role,
task_name=self.task_name,
task_description=self.task_description,
trace_type=self.trace_type.value,
crew_type=self.crew_type.value,
run_type=self.run_type.value,
start_time=self.start_time,
end_time=self.end_time,
latency_ms=latency_ms,
request=request,
response=response,
flow_step=flow_step_obj,
tool_calls=self.tool_calls,
tokens_used=self.context.get("tokens_used"),
prompt_tokens=self.context.get("prompt_tokens"),
completion_tokens=self.context.get("completion_tokens"),
status=self.status,
error=self.error,
)
def _record_trace(self, result: Any = None) -> None:
"""Record the trace using PlusClient.
Args:
result: Optional result to include in the trace.
"""
if result:
self.context["response"] = result
# TODO: Add trace to record_task_finished
def should_trace() -> bool:
"""Check if tracing is enabled via environment variable."""
return os.getenv("CREWAI_ENABLE_TRACING", "false").lower() == "true"
def trace_llm_call(func: Callable[..., Any]) -> Callable[..., Any]:
"""Decorator to trace LLM calls.
Args:
func: The function to trace.
Returns:
Wrapped function that creates and manages the LLM call trace context.
"""
@wraps(func)
def wrapper(self: Any, params: Dict[str, Any], *args: Any, **kwargs: Any) -> Any:
if not should_trace():
return func(self, params, *args, **kwargs)
current_trace = TraceContext.get_current()
agent, task = self._get_execution_context()
# Get new messages and tool results
new_messages = self._get_new_messages(params.get("messages", []))
new_tool_results = self._get_new_tool_results(agent)
# Create trace context
trace = UnifiedTraceController(
trace_type=TraceType.TOOL_CALL if new_tool_results else TraceType.LLM_CALL,
crew_type=current_trace.crew_type if current_trace else CrewType.CREW,
run_type=current_trace.run_type if current_trace else RunType.KICKOFF,
run_id=current_trace.run_id if current_trace else str(uuid4()),
parent_trace_id=current_trace.trace_id if current_trace else None,
agent_role=agent.role if agent else "unknown",
task_id=str(task.id) if task else None,
task_name=task.name if task else None,
task_description=task.description if task else None,
model=self.model,
messages=new_messages,
temperature=self.temperature,
max_tokens=self.max_tokens,
stop_sequences=self.stop,
tool_calls=[
ToolCall(
name=result["tool_name"],
arguments=result["tool_args"],
output=str(result["result"]),
start_time=result.get("start_time", ""),
end_time=datetime.now(UTC),
)
for result in new_tool_results
],
)
with TraceContext.set_current(trace):
trace.start_trace()
try:
response = func(self, params, *args, **kwargs)
# Extract relevant data from response
trace_response = {
"content": response["choices"][0]["message"]["content"],
"finish_reason": response["choices"][0].get("finish_reason"),
}
# Add usage metrics to context
if "usage" in response:
trace.context["tokens_used"] = response["usage"].get(
"total_tokens", 0
)
trace.context["prompt_tokens"] = response["usage"].get(
"prompt_tokens", 0
)
trace.context["completion_tokens"] = response["usage"].get(
"completion_tokens", 0
)
trace.end_trace(trace_response)
return response
except Exception as e:
trace.end_trace(error=str(e))
raise
return wrapper
# Crew main trace
def init_crew_main_trace(func: Callable[..., Any]) -> Callable[..., Any]:
"""Decorator to initialize and track the main crew execution trace.
This decorator sets up the trace context for the main crew execution,
handling both synchronous and asynchronous crew operations.
Args:
func: The crew function to be traced.
Returns:
Wrapped function that creates and manages the main crew trace context.
"""
@wraps(func)
def wrapper(self: Any, *args: Any, **kwargs: Any) -> Any:
if not should_trace():
return func(self, *args, **kwargs)
trace = build_crew_main_trace(self)
with TraceContext.set_current(trace):
try:
return func(self, *args, **kwargs)
except Exception as e:
trace.end_trace(error=str(e))
raise
return wrapper
def build_crew_main_trace(self: Any) -> "UnifiedTraceController":
"""Build the main trace controller for a crew execution.
This function creates a trace controller configured for the main crew execution,
handling different run types (kickoff, test, train) and maintaining context.
Args:
self: The crew instance.
Returns:
UnifiedTraceController: The configured trace controller for the crew.
"""
run_type = RunType.KICKOFF
if hasattr(self, "_test") and self._test:
run_type = RunType.TEST
elif hasattr(self, "_train") and self._train:
run_type = RunType.TRAIN
current_trace = TraceContext.get_current()
trace = UnifiedTraceController(
trace_type=TraceType.LLM_CALL,
run_type=run_type,
crew_type=current_trace.crew_type if current_trace else CrewType.CREW,
run_id=current_trace.run_id if current_trace else str(self.id),
parent_trace_id=current_trace.trace_id if current_trace else None,
)
return trace
# Flow main trace
def init_flow_main_trace(
func: Callable[..., Awaitable[Any]],
) -> Callable[..., Awaitable[Any]]:
"""Decorator to initialize and track the main flow execution trace.
Args:
func: The async flow function to be traced.
Returns:
Wrapped async function that creates and manages the main flow trace context.
"""
@wraps(func)
async def wrapper(self: Any, *args: Any, **kwargs: Any) -> Any:
if not should_trace():
return await func(self, *args, **kwargs)
trace = build_flow_main_trace(self, *args, **kwargs)
with TraceContext.set_current(trace):
try:
return await func(self, *args, **kwargs)
except Exception as e:
trace.end_trace(error=str(e))
raise
return wrapper
def build_flow_main_trace(
self: Any, *args: Any, **kwargs: Any
) -> "UnifiedTraceController":
"""Build the main trace controller for a flow execution.
Args:
self: The flow instance.
*args: Variable positional arguments.
**kwargs: Variable keyword arguments.
Returns:
UnifiedTraceController: The configured trace controller for the flow.
"""
current_trace = TraceContext.get_current()
trace = UnifiedTraceController(
trace_type=TraceType.FLOW_STEP,
run_id=current_trace.run_id if current_trace else str(self.flow_id),
parent_trace_id=current_trace.trace_id if current_trace else None,
crew_type=CrewType.FLOW,
run_type=RunType.KICKOFF,
context={
"crew_name": self.__class__.__name__,
"inputs": kwargs.get("inputs", {}),
"agents": [],
"tasks": [],
},
)
return trace
# Flow step trace
def trace_flow_step(
func: Callable[..., Awaitable[Any]],
) -> Callable[..., Awaitable[Any]]:
"""Decorator to trace individual flow step executions.
Args:
func: The async flow step function to be traced.
Returns:
Wrapped async function that creates and manages the flow step trace context.
"""
@wraps(func)
async def wrapper(
self: Any,
method_name: str,
method: Callable[..., Any],
*args: Any,
**kwargs: Any,
) -> Any:
if not should_trace():
return await func(self, method_name, method, *args, **kwargs)
trace = build_flow_step_trace(self, method_name, method, *args, **kwargs)
with TraceContext.set_current(trace):
trace.start_trace()
try:
result = await func(self, method_name, method, *args, **kwargs)
trace.end_trace(result=result)
return result
except Exception as e:
trace.end_trace(error=str(e))
raise
return wrapper
def build_flow_step_trace(
self: Any, method_name: str, method: Callable[..., Any], *args: Any, **kwargs: Any
) -> "UnifiedTraceController":
"""Build a trace controller for an individual flow step.
Args:
self: The flow instance.
method_name: Name of the method being executed.
method: The actual method being executed.
*args: Variable positional arguments.
**kwargs: Variable keyword arguments.
Returns:
UnifiedTraceController: The configured trace controller for the flow step.
"""
current_trace = TraceContext.get_current()
# Get method signature
sig = inspect.signature(method)
params = list(sig.parameters.values())
# Create inputs dictionary mapping parameter names to values
method_params = [p for p in params if p.name != "self"]
inputs: Dict[str, Any] = {}
# Map positional args to their parameter names
for i, param in enumerate(method_params):
if i < len(args):
inputs[param.name] = args[i]
# Add keyword arguments
inputs.update(kwargs)
trace = UnifiedTraceController(
trace_type=TraceType.FLOW_STEP,
run_type=current_trace.run_type if current_trace else RunType.KICKOFF,
crew_type=current_trace.crew_type if current_trace else CrewType.FLOW,
run_id=current_trace.run_id if current_trace else str(self.flow_id),
parent_trace_id=current_trace.trace_id if current_trace else None,
flow_step={
"function_name": method_name,
"inputs": inputs,
"metadata": {
"crew_name": self.__class__.__name__,
},
},
)
return trace