From 90f1bee6025a2bfdc2f4d7ef5531805c727f090e Mon Sep 17 00:00:00 2001
From: Eduardo Chiarotti
Date: Wed, 19 Feb 2025 08:52:30 -0300
Subject: [PATCH] feat: add prompt observability code (#2027)

* feat: add prompt observability code
* feat: improve logic for llm call
* feat: add tests for traces
* feat: remove unused import
* feat: add function to clear and add task traces
* feat: fix import
* feat: change time
* feat: fix type checking issues
* feat: add fixed time to fix test
* feat: fix datetime test issue
* feat: add add task traces function
* feat: add same logic as entp
* feat: add start_time as reference for duplication of tool call
* feat: add max_depth
* feat: add protocols file to properly import on LLM

---------

Co-authored-by: Brandon Hancock (bhancock_ai) <109994880+bhancockio@users.noreply.github.com>
---
 src/crewai/crew.py                            |   2 +
 src/crewai/flow/flow.py                       |  11 +-
 src/crewai/llm.py                             | 154 ++++-
 src/crewai/tools/tool_usage.py                |  20 +-
 src/crewai/traces/__init__.py                 |   0
 src/crewai/traces/context.py                  |  39 ++
 src/crewai/traces/enums.py                    |  19 +
 src/crewai/traces/models.py                   |  89 +++
 src/crewai/traces/unified_trace_controller.py | 543 ++++++++++++++++++
 src/crewai/utilities/protocols.py             |  12 +
 tests/agent_test.py                           |  53 +-
 tests/traces/test_unified_trace_controller.py | 360 ++++++++++++
 12 files changed, 1254 insertions(+), 48 deletions(-)
 create mode 100644 src/crewai/traces/__init__.py
 create mode 100644 src/crewai/traces/context.py
 create mode 100644 src/crewai/traces/enums.py
 create mode 100644 src/crewai/traces/models.py
 create mode 100644 src/crewai/traces/unified_trace_controller.py
 create mode 100644 src/crewai/utilities/protocols.py
 create mode 100644 tests/traces/test_unified_trace_controller.py

diff --git a/src/crewai/crew.py b/src/crewai/crew.py
index d331599b5..682d5d60b 100644
--- a/src/crewai/crew.py
+++ b/src/crewai/crew.py
@@ -38,6 +38,7 @@ from crewai.tasks.task_output import TaskOutput
 from crewai.telemetry import Telemetry
 from crewai.tools.agent_tools.agent_tools import AgentTools
 from crewai.tools.base_tool import Tool
+from crewai.traces.unified_trace_controller import init_crew_main_trace
 from crewai.types.usage_metrics import UsageMetrics
 from crewai.utilities import I18N, FileHandler, Logger, RPMController
 from crewai.utilities.constants import TRAINING_DATA_FILE
@@ -545,6 +546,7 @@ class Crew(BaseModel):
             CrewTrainingHandler(filename).clear()
             raise
 
+    @init_crew_main_trace
     def kickoff(
         self,
         inputs: Optional[Dict[str, Any]] = None,
diff --git a/src/crewai/flow/flow.py b/src/crewai/flow/flow.py
index f1242a2bf..f0d0b1093 100644
--- a/src/crewai/flow/flow.py
+++ b/src/crewai/flow/flow.py
@@ -30,6 +30,10 @@ from crewai.flow.flow_visualizer import plot_flow
 from crewai.flow.persistence.base import FlowPersistence
 from crewai.flow.utils import get_possible_return_constants
 from crewai.telemetry import Telemetry
+from crewai.traces.unified_trace_controller import (
+    init_flow_main_trace,
+    trace_flow_step,
+)
 from crewai.utilities.printer import Printer
 
 logger = logging.getLogger(__name__)
@@ -753,8 +757,12 @@ class Flow(Generic[T], metaclass=FlowMeta):
         if inputs is not None and "id" not in inputs:
             self._initialize_state(inputs)
 
-        return asyncio.run(self.kickoff_async())
+        async def run_flow():
+            return await self.kickoff_async()
 
+        return asyncio.run(run_flow())
+
+    @init_flow_main_trace
     async def kickoff_async(self, inputs: Optional[Dict[str, Any]] = None) -> Any:
         if not self._start_methods:
             raise ValueError("No start method defined")
@@ -804,6 +812,7 @@ class
Flow(Generic[T], metaclass=FlowMeta): ) await self._execute_listeners(start_method_name, result) + @trace_flow_step async def _execute_method( self, method_name: str, method: Callable, *args: Any, **kwargs: Any ) -> Any: diff --git a/src/crewai/llm.py b/src/crewai/llm.py index ada5c9bf3..43391951e 100644 --- a/src/crewai/llm.py +++ b/src/crewai/llm.py @@ -1,3 +1,4 @@ +import inspect import json import logging import os @@ -5,7 +6,17 @@ import sys import threading import warnings from contextlib import contextmanager -from typing import Any, Dict, List, Literal, Optional, Type, Union, cast +from typing import ( + Any, + Dict, + List, + Literal, + Optional, + Tuple, + Type, + Union, + cast, +) from dotenv import load_dotenv from pydantic import BaseModel @@ -18,9 +29,11 @@ with warnings.catch_warnings(): from litellm.utils import supports_response_schema +from crewai.traces.unified_trace_controller import trace_llm_call from crewai.utilities.exceptions.context_window_exceeding_exception import ( LLMContextLengthExceededException, ) +from crewai.utilities.protocols import AgentExecutorProtocol load_dotenv() @@ -164,6 +177,7 @@ class LLM: self.context_window_size = 0 self.reasoning_effort = reasoning_effort self.additional_params = kwargs + self._message_history: List[Dict[str, str]] = [] self.is_anthropic = self._is_anthropic_model(model) litellm.drop_params = True @@ -179,16 +193,22 @@ class LLM: self.set_callbacks(callbacks) self.set_env_callbacks() + @trace_llm_call + def _call_llm(self, params: Dict[str, Any]) -> Any: + with suppress_warnings(): + response = litellm.completion(**params) + return response + def _is_anthropic_model(self, model: str) -> bool: """Determine if the model is from Anthropic provider. - + Args: model: The model identifier string. - + Returns: bool: True if the model is from Anthropic, False otherwise. """ - ANTHROPIC_PREFIXES = ('anthropic/', 'claude-', 'claude/') + ANTHROPIC_PREFIXES = ("anthropic/", "claude-", "claude/") return any(prefix in model.lower() for prefix in ANTHROPIC_PREFIXES) def call( @@ -199,7 +219,7 @@ class LLM: available_functions: Optional[Dict[str, Any]] = None, ) -> Union[str, Any]: """High-level LLM call method. - + Args: messages: Input messages for the LLM. Can be a string or list of message dictionaries. @@ -211,22 +231,22 @@ class LLM: during and after the LLM call. available_functions: Optional dict mapping function names to callables that can be invoked by the LLM. - + Returns: Union[str, Any]: Either a text response from the LLM (str) or the result of a tool function call (Any). - + Raises: TypeError: If messages format is invalid ValueError: If response format is not supported LLMContextLengthExceededException: If input exceeds model's context limit - + Examples: # Example 1: Simple string input >>> response = llm.call("Return the name of a random city.") >>> print(response) "Paris" - + # Example 2: Message list with system and user messages >>> messages = [ ... 
{"role": "system", "content": "You are a geography expert"}, @@ -288,7 +308,7 @@ class LLM: params = {k: v for k, v in params.items() if v is not None} # --- 2) Make the completion call - response = litellm.completion(**params) + response = self._call_llm(params) response_message = cast(Choices, cast(ModelResponse, response).choices)[ 0 ].message @@ -348,36 +368,40 @@ class LLM: logging.error(f"LiteLLM call failed: {str(e)}") raise - def _format_messages_for_provider(self, messages: List[Dict[str, str]]) -> List[Dict[str, str]]: + def _format_messages_for_provider( + self, messages: List[Dict[str, str]] + ) -> List[Dict[str, str]]: """Format messages according to provider requirements. - + Args: messages: List of message dictionaries with 'role' and 'content' keys. Can be empty or None. - + Returns: List of formatted messages according to provider requirements. For Anthropic models, ensures first message has 'user' role. - + Raises: TypeError: If messages is None or contains invalid message format. """ if messages is None: raise TypeError("Messages cannot be None") - + # Validate message format first for msg in messages: if not isinstance(msg, dict) or "role" not in msg or "content" not in msg: - raise TypeError("Invalid message format. Each message must be a dict with 'role' and 'content' keys") - + raise TypeError( + "Invalid message format. Each message must be a dict with 'role' and 'content' keys" + ) + if not self.is_anthropic: return messages - + # Anthropic requires messages to start with 'user' role if not messages or messages[0]["role"] == "system": # If first message is system or empty, add a placeholder user message return [{"role": "user", "content": "."}, *messages] - + return messages def _get_custom_llm_provider(self) -> str: @@ -495,3 +519,95 @@ class LLM: litellm.success_callback = success_callbacks litellm.failure_callback = failure_callbacks + + def _get_execution_context(self) -> Tuple[Optional[Any], Optional[Any]]: + """Get the agent and task from the execution context. 
+ + Returns: + tuple: (agent, task) from any AgentExecutor context, or (None, None) if not found + """ + frame = inspect.currentframe() + caller_frame = frame.f_back if frame else None + agent = None + task = None + + # Add a maximum depth to prevent infinite loops + max_depth = 100 # Reasonable limit for call stack depth + current_depth = 0 + + while caller_frame and current_depth < max_depth: + if "self" in caller_frame.f_locals: + caller_self = caller_frame.f_locals["self"] + if isinstance(caller_self, AgentExecutorProtocol): + agent = caller_self.agent + task = caller_self.task + break + caller_frame = caller_frame.f_back + current_depth += 1 + + return agent, task + + def _get_new_messages(self, messages: List[Dict[str, str]]) -> List[Dict[str, str]]: + """Get only the new messages that haven't been processed before.""" + if not hasattr(self, "_message_history"): + self._message_history = [] + + new_messages = [] + for message in messages: + message_key = (message["role"], message["content"]) + if message_key not in [ + (m["role"], m["content"]) for m in self._message_history + ]: + new_messages.append(message) + self._message_history.append(message) + return new_messages + + def _get_new_tool_results(self, agent) -> List[Dict]: + """Get only the new tool results that haven't been processed before.""" + if not agent or not agent.tools_results: + return [] + + if not hasattr(self, "_tool_results_history"): + self._tool_results_history: List[Dict] = [] + + new_tool_results = [] + + for result in agent.tools_results: + # Process tool arguments to extract actual values + processed_args = {} + if isinstance(result["tool_args"], dict): + for key, value in result["tool_args"].items(): + if isinstance(value, dict) and "type" in value: + # Skip metadata and just store the actual value + continue + processed_args[key] = value + + # Create a clean result with processed arguments + clean_result = { + "tool_name": result["tool_name"], + "tool_args": processed_args, + "result": result["result"], + "content": result.get("content", ""), + "start_time": result.get("start_time", ""), + } + + # Check if this exact tool execution exists in history + is_duplicate = False + for history_result in self._tool_results_history: + if ( + clean_result["tool_name"] == history_result["tool_name"] + and str(clean_result["tool_args"]) + == str(history_result["tool_args"]) + and str(clean_result["result"]) == str(history_result["result"]) + and clean_result["content"] == history_result.get("content", "") + and clean_result["start_time"] + == history_result.get("start_time", "") + ): + is_duplicate = True + break + + if not is_duplicate: + new_tool_results.append(clean_result) + self._tool_results_history.append(clean_result) + + return new_tool_results diff --git a/src/crewai/tools/tool_usage.py b/src/crewai/tools/tool_usage.py index 218410ef7..fa821bebd 100644 --- a/src/crewai/tools/tool_usage.py +++ b/src/crewai/tools/tool_usage.py @@ -2,6 +2,7 @@ import ast import datetime import json import time +from datetime import UTC from difflib import SequenceMatcher from json import JSONDecodeError from textwrap import dedent @@ -116,7 +117,10 @@ class ToolUsage: self._printer.print(content=f"\n\n{error}\n", color="red") return error - if isinstance(tool, CrewStructuredTool) and tool.name == self._i18n.tools("add_image")["name"]: # type: ignore + if ( + isinstance(tool, CrewStructuredTool) + and tool.name == self._i18n.tools("add_image")["name"] # type: ignore + ): try: result = self._use(tool_string=tool_string, tool=tool, 
calling=calling) return result @@ -154,6 +158,7 @@ class ToolUsage: self.task.increment_tools_errors() started_at = time.time() + started_at_trace = datetime.datetime.now(UTC) from_cache = False result = None # type: ignore # Incompatible types in assignment (expression has type "None", variable has type "str") @@ -181,7 +186,9 @@ class ToolUsage: if calling.arguments: try: - acceptable_args = tool.args_schema.model_json_schema()["properties"].keys() # type: ignore + acceptable_args = tool.args_schema.model_json_schema()[ + "properties" + ].keys() # type: ignore arguments = { k: v for k, v in calling.arguments.items() @@ -202,7 +209,7 @@ class ToolUsage: error=e, tool=tool.name, tool_inputs=tool.description ) error = ToolUsageErrorException( - f'\n{error_message}.\nMoving on then. {self._i18n.slice("format").format(tool_names=self.tools_names)}' + f"\n{error_message}.\nMoving on then. {self._i18n.slice('format').format(tool_names=self.tools_names)}" ).message self.task.increment_tools_errors() if self.agent.verbose: @@ -244,6 +251,7 @@ class ToolUsage: "result": result, "tool_name": tool.name, "tool_args": calling.arguments, + "start_time": started_at_trace, } self.on_tool_use_finished( @@ -368,7 +376,7 @@ class ToolUsage: raise else: return ToolUsageErrorException( - f'{self._i18n.errors("tool_arguments_error")}' + f"{self._i18n.errors('tool_arguments_error')}" ) if not isinstance(arguments, dict): @@ -376,7 +384,7 @@ class ToolUsage: raise else: return ToolUsageErrorException( - f'{self._i18n.errors("tool_arguments_error")}' + f"{self._i18n.errors('tool_arguments_error')}" ) return ToolCalling( @@ -404,7 +412,7 @@ class ToolUsage: if self.agent.verbose: self._printer.print(content=f"\n\n{e}\n", color="red") return ToolUsageErrorException( # type: ignore # Incompatible return value type (got "ToolUsageErrorException", expected "ToolCalling | InstructorToolCalling") - f'{self._i18n.errors("tool_usage_error").format(error=e)}\nMoving on then. {self._i18n.slice("format").format(tool_names=self.tools_names)}' + f"{self._i18n.errors('tool_usage_error').format(error=e)}\nMoving on then. {self._i18n.slice('format').format(tool_names=self.tools_names)}" ) return self._tool_calling(tool_string) diff --git a/src/crewai/traces/__init__.py b/src/crewai/traces/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/crewai/traces/context.py b/src/crewai/traces/context.py new file mode 100644 index 000000000..dd1cf144e --- /dev/null +++ b/src/crewai/traces/context.py @@ -0,0 +1,39 @@ +from contextlib import contextmanager +from contextvars import ContextVar +from typing import Generator + + +class TraceContext: + """Maintains the current trace context throughout the execution stack. + + This class provides a context manager for tracking trace execution across + async and sync code paths using ContextVars. + """ + + _context: ContextVar = ContextVar("trace_context", default=None) + + @classmethod + def get_current(cls): + """Get the current trace context. + + Returns: + Optional[UnifiedTraceController]: The current trace controller or None if not set. + """ + return cls._context.get() + + @classmethod + @contextmanager + def set_current(cls, trace): + """Set the current trace context within a context manager. + + Args: + trace: The trace controller to set as current. + + Yields: + UnifiedTraceController: The current trace controller. 
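+
+        Example (illustrative):
+            >>> with TraceContext.set_current(trace) as current:
+            ...     assert TraceContext.get_current() is current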
+ """ + token = cls._context.set(trace) + try: + yield trace + finally: + cls._context.reset(token) diff --git a/src/crewai/traces/enums.py b/src/crewai/traces/enums.py new file mode 100644 index 000000000..392f46ea4 --- /dev/null +++ b/src/crewai/traces/enums.py @@ -0,0 +1,19 @@ +from enum import Enum + + +class TraceType(Enum): + LLM_CALL = "llm_call" + TOOL_CALL = "tool_call" + FLOW_STEP = "flow_step" + START_CALL = "start_call" + + +class RunType(Enum): + KICKOFF = "kickoff" + TRAIN = "train" + TEST = "test" + + +class CrewType(Enum): + CREW = "crew" + FLOW = "flow" diff --git a/src/crewai/traces/models.py b/src/crewai/traces/models.py new file mode 100644 index 000000000..254da957e --- /dev/null +++ b/src/crewai/traces/models.py @@ -0,0 +1,89 @@ +from datetime import datetime +from typing import Any, Dict, List, Optional + +from pydantic import BaseModel, Field + + +class ToolCall(BaseModel): + """Model representing a tool call during execution""" + + name: str + arguments: Dict[str, Any] + output: str + start_time: datetime + end_time: Optional[datetime] = None + latency_ms: Optional[int] = None + error: Optional[str] = None + + +class LLMRequest(BaseModel): + """Model representing the LLM request details""" + + model: str + messages: List[Dict[str, str]] + temperature: Optional[float] = None + max_tokens: Optional[int] = None + stop_sequences: Optional[List[str]] = None + additional_params: Dict[str, Any] = Field(default_factory=dict) + + +class LLMResponse(BaseModel): + """Model representing the LLM response details""" + + content: str + finish_reason: Optional[str] = None + + +class FlowStepIO(BaseModel): + """Model representing flow step input/output details""" + + function_name: str + inputs: Dict[str, Any] = Field(default_factory=dict) + outputs: Any + metadata: Dict[str, Any] = Field(default_factory=dict) + + +class CrewTrace(BaseModel): + """Model for tracking detailed information about LLM interactions and Flow steps""" + + deployment_instance_id: Optional[str] = Field( + description="ID of the deployment instance" + ) + trace_id: str = Field(description="Unique identifier for this trace") + run_id: str = Field(description="Identifier for the execution run") + agent_role: Optional[str] = Field(description="Role of the agent") + task_id: Optional[str] = Field(description="ID of the current task being executed") + task_name: Optional[str] = Field(description="Name of the current task") + task_description: Optional[str] = Field( + description="Description of the current task" + ) + trace_type: str = Field(description="Type of the trace") + crew_type: str = Field(description="Type of the crew") + run_type: str = Field(description="Type of the run") + + # Timing information + start_time: Optional[datetime] = None + end_time: Optional[datetime] = None + latency_ms: Optional[int] = None + + # Request/Response for LLM calls + request: Optional[LLMRequest] = None + response: Optional[LLMResponse] = None + + # Input/Output for Flow steps + flow_step: Optional[FlowStepIO] = None + + # Tool usage + tool_calls: List[ToolCall] = Field(default_factory=list) + + # Metrics + tokens_used: Optional[int] = None + prompt_tokens: Optional[int] = None + completion_tokens: Optional[int] = None + cost: Optional[float] = None + + # Additional metadata + status: str = "running" # running, completed, error + error: Optional[str] = None + metadata: Dict[str, Any] = Field(default_factory=dict) + tags: List[str] = Field(default_factory=list) diff --git a/src/crewai/traces/unified_trace_controller.py 
b/src/crewai/traces/unified_trace_controller.py
new file mode 100644
index 000000000..986a0a174
--- /dev/null
+++ b/src/crewai/traces/unified_trace_controller.py
@@ -0,0 +1,543 @@
+import inspect
+import os
+from datetime import UTC, datetime
+from functools import wraps
+from typing import Any, Awaitable, Callable, Dict, List, Optional
+from uuid import uuid4
+
+from crewai.traces.context import TraceContext
+from crewai.traces.enums import CrewType, RunType, TraceType
+from crewai.traces.models import (
+    CrewTrace,
+    FlowStepIO,
+    LLMRequest,
+    LLMResponse,
+    ToolCall,
+)
+
+
+class UnifiedTraceController:
+    """Controls and manages trace execution and recording.
+
+    This class handles the lifecycle of traces including creation, execution tracking,
+    and recording of results for various types of operations (LLM calls, tool calls, flow steps).
+    """
+
+    _task_traces: Dict[str, List["UnifiedTraceController"]] = {}
+
+    def __init__(
+        self,
+        trace_type: TraceType,
+        run_type: RunType,
+        crew_type: CrewType,
+        run_id: str,
+        deployment_instance_id: Optional[str] = None,
+        parent_trace_id: Optional[str] = None,
+        agent_role: Optional[str] = "unknown",
+        task_name: Optional[str] = None,
+        task_description: Optional[str] = None,
+        task_id: Optional[str] = None,
+        flow_step: Optional[Dict[str, Any]] = None,
+        tool_calls: Optional[List[ToolCall]] = None,
+        **context: Any,
+    ) -> None:
+        """Initialize a new trace controller.
+
+        Args:
+            trace_type: Type of trace being recorded.
+            run_type: Type of run being executed.
+            crew_type: Type of crew executing the trace.
+            run_id: Unique identifier for the run.
+            deployment_instance_id: Optional deployment instance identifier.
+                Defaults to the CREWAI_DEPLOYMENT_INSTANCE_ID environment variable.
+            parent_trace_id: Optional parent trace identifier for nested traces.
+            agent_role: Role of the agent executing the trace.
+            task_name: Optional name of the task being executed.
+            task_description: Optional description of the task.
+            task_id: Optional unique identifier for the task.
+            flow_step: Optional flow step information.
+            tool_calls: Optional list of tool calls made during execution.
+            **context: Additional context parameters.
+        """
+        self.trace_id = str(uuid4())
+        self.run_id = run_id
+        self.parent_trace_id = parent_trace_id
+        self.trace_type = trace_type
+        self.run_type = run_type
+        self.crew_type = crew_type
+        self.context = context
+        self.agent_role = agent_role
+        self.task_name = task_name
+        self.task_description = task_description
+        self.task_id = task_id
+        # Resolve defaults at call time rather than at function definition
+        # time, and avoid sharing mutable default containers across instances.
+        self.deployment_instance_id = (
+            deployment_instance_id
+            if deployment_instance_id is not None
+            else os.environ.get("CREWAI_DEPLOYMENT_INSTANCE_ID", "")
+        )
+        self.children: List[Dict[str, Any]] = []
+        self.start_time: Optional[datetime] = None
+        self.end_time: Optional[datetime] = None
+        self.error: Optional[str] = None
+        self.tool_calls = tool_calls if tool_calls is not None else []
+        self.flow_step = flow_step if flow_step is not None else {}
+        self.status: str = "running"
+
+        # Add trace to task's trace collection if task_id is present
+        if task_id:
+            self._add_to_task_traces()
+
+    def _add_to_task_traces(self) -> None:
+        """Add this trace to the task's trace collection."""
+        if not hasattr(UnifiedTraceController, "_task_traces"):
+            UnifiedTraceController._task_traces = {}
+
+        if self.task_id is None:
+            return
+
+        if self.task_id not in UnifiedTraceController._task_traces:
+            UnifiedTraceController._task_traces[self.task_id] = []
+
+        # Guard against double registration: both __init__ and _record_trace
+        # can register the same trace object.
+        if self not in UnifiedTraceController._task_traces[self.task_id]:
+            UnifiedTraceController._task_traces[self.task_id].append(self)
+
+    @classmethod
+    def get_task_traces(cls, task_id: str) -> List["UnifiedTraceController"]:
+        """Get all traces for a specific task.
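+
+        Traces accumulate in a class-level registry for the lifetime of the
+        process, so callers should pair this with ``clear_task_traces``.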
+ + Args: + task_id: The ID of the task to get traces for + + Returns: + List of traces associated with the task + """ + return cls._task_traces.get(task_id, []) + + @classmethod + def clear_task_traces(cls, task_id: str) -> None: + """Clear traces for a specific task. + + Args: + task_id: The ID of the task to clear traces for + """ + if hasattr(cls, "_task_traces") and task_id in cls._task_traces: + del cls._task_traces[task_id] + + def _get_current_trace(self) -> "UnifiedTraceController": + return TraceContext.get_current() + + def start_trace(self) -> "UnifiedTraceController": + """Start the trace execution. + + Returns: + UnifiedTraceController: Self for method chaining. + """ + self.start_time = datetime.now(UTC) + return self + + def end_trace(self, result: Any = None, error: Optional[str] = None) -> None: + """End the trace execution and record results. + + Args: + result: Optional result from the trace execution. + error: Optional error message if the trace failed. + """ + self.end_time = datetime.now(UTC) + self.status = "error" if error else "completed" + self.error = error + self._record_trace(result) + + def add_child_trace(self, child_trace: Dict[str, Any]) -> None: + """Add a child trace to this trace's execution history. + + Args: + child_trace: The child trace information to add. + """ + self.children.append(child_trace) + + def to_crew_trace(self) -> CrewTrace: + """Convert to CrewTrace format for storage. + + Returns: + CrewTrace: The trace data in CrewTrace format. + """ + latency_ms = None + + if self.tool_calls and hasattr(self.tool_calls[0], "start_time"): + self.start_time = self.tool_calls[0].start_time + + if self.start_time and self.end_time: + latency_ms = int((self.end_time - self.start_time).total_seconds() * 1000) + + request = None + response = None + flow_step_obj = None + + if self.trace_type in [TraceType.LLM_CALL, TraceType.TOOL_CALL]: + request = LLMRequest( + model=self.context.get("model", "unknown"), + messages=self.context.get("messages", []), + temperature=self.context.get("temperature"), + max_tokens=self.context.get("max_tokens"), + stop_sequences=self.context.get("stop_sequences"), + ) + if "response" in self.context: + response = LLMResponse( + content=self.context["response"].get("content", ""), + finish_reason=self.context["response"].get("finish_reason"), + ) + + elif self.trace_type == TraceType.FLOW_STEP: + flow_step_obj = FlowStepIO( + function_name=self.flow_step.get("function_name", "unknown"), + inputs=self.flow_step.get("inputs", {}), + outputs={"result": self.context.get("response")}, + metadata=self.flow_step.get("metadata", {}), + ) + + return CrewTrace( + deployment_instance_id=self.deployment_instance_id, + trace_id=self.trace_id, + task_id=self.task_id, + run_id=self.run_id, + agent_role=self.agent_role, + task_name=self.task_name, + task_description=self.task_description, + trace_type=self.trace_type.value, + crew_type=self.crew_type.value, + run_type=self.run_type.value, + start_time=self.start_time, + end_time=self.end_time, + latency_ms=latency_ms, + request=request, + response=response, + flow_step=flow_step_obj, + tool_calls=self.tool_calls, + tokens_used=self.context.get("tokens_used"), + prompt_tokens=self.context.get("prompt_tokens"), + completion_tokens=self.context.get("completion_tokens"), + status=self.status, + error=self.error, + ) + + def _record_trace(self, result: Any = None) -> None: + """Record the trace. + + This method is called when a trace is completed. 
It ensures the trace
+        is properly recorded and associated with its task if applicable.
+
+        Args:
+            result: Optional result to include in the trace
+        """
+        if result:
+            self.context["response"] = result
+
+        # Add to task traces if this trace belongs to a task
+        if self.task_id:
+            self._add_to_task_traces()
+
+
+def should_trace() -> bool:
+    """Check if tracing is enabled via environment variable."""
+    return os.getenv("CREWAI_ENABLE_TRACING", "false").lower() == "true"
+
+
+# Crew main trace
+def init_crew_main_trace(func: Callable[..., Any]) -> Callable[..., Any]:
+    """Decorator to initialize and track the main crew execution trace.
+
+    This decorator sets up the trace context for the main crew execution,
+    handling both synchronous and asynchronous crew operations.
+
+    Args:
+        func: The crew function to be traced.
+
+    Returns:
+        Wrapped function that creates and manages the main crew trace context.
+    """
+
+    @wraps(func)
+    def wrapper(self: Any, *args: Any, **kwargs: Any) -> Any:
+        if not should_trace():
+            return func(self, *args, **kwargs)
+
+        trace = build_crew_main_trace(self)
+        with TraceContext.set_current(trace):
+            try:
+                return func(self, *args, **kwargs)
+            except Exception as e:
+                trace.end_trace(error=str(e))
+                raise
+
+    return wrapper
+
+
+def build_crew_main_trace(self: Any) -> "UnifiedTraceController":
+    """Build the main trace controller for a crew execution.
+
+    This function creates a trace controller configured for the main crew execution,
+    handling different run types (kickoff, test, train) and maintaining context.
+
+    Args:
+        self: The crew instance.
+
+    Returns:
+        UnifiedTraceController: The configured trace controller for the crew.
+    """
+    run_type = RunType.KICKOFF
+    if hasattr(self, "_test") and self._test:
+        run_type = RunType.TEST
+    elif hasattr(self, "_train") and self._train:
+        run_type = RunType.TRAIN
+
+    current_trace = TraceContext.get_current()
+
+    trace = UnifiedTraceController(
+        trace_type=TraceType.LLM_CALL,
+        run_type=run_type,
+        crew_type=current_trace.crew_type if current_trace else CrewType.CREW,
+        run_id=current_trace.run_id if current_trace else str(self.id),
+        parent_trace_id=current_trace.trace_id if current_trace else None,
+    )
+    return trace
+
+
+# Flow main trace
+def init_flow_main_trace(
+    func: Callable[..., Awaitable[Any]],
+) -> Callable[..., Awaitable[Any]]:
+    """Decorator to initialize and track the main flow execution trace.
+
+    Args:
+        func: The async flow function to be traced.
+
+    Returns:
+        Wrapped async function that creates and manages the main flow trace context.
+    """
+
+    @wraps(func)
+    async def wrapper(self: Any, *args: Any, **kwargs: Any) -> Any:
+        if not should_trace():
+            return await func(self, *args, **kwargs)
+
+        trace = build_flow_main_trace(self, *args, **kwargs)
+        with TraceContext.set_current(trace):
+            try:
+                return await func(self, *args, **kwargs)
+            except Exception as e:
+                # Mirror the crew wrapper: record the failure before re-raising
+                trace.end_trace(error=str(e))
+                raise
+
+    return wrapper
+
+
+def build_flow_main_trace(
+    self: Any, *args: Any, **kwargs: Any
+) -> "UnifiedTraceController":
+    """Build the main trace controller for a flow execution.
+
+    Args:
+        self: The flow instance.
+        *args: Variable positional arguments.
+        **kwargs: Variable keyword arguments.
+
+    Returns:
+        UnifiedTraceController: The configured trace controller for the flow.
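+
+    Note:
+        The ``context`` dict passed below is collected by the constructor's
+        ``**context`` parameter, so it is stored nested as
+        ``trace.context["context"]`` (the tests assert exactly this shape).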
+ """ + current_trace = TraceContext.get_current() + trace = UnifiedTraceController( + trace_type=TraceType.FLOW_STEP, + run_id=current_trace.run_id if current_trace else str(self.flow_id), + parent_trace_id=current_trace.trace_id if current_trace else None, + crew_type=CrewType.FLOW, + run_type=RunType.KICKOFF, + context={ + "crew_name": self.__class__.__name__, + "inputs": kwargs.get("inputs", {}), + "agents": [], + "tasks": [], + }, + ) + return trace + + +# Flow step trace +def trace_flow_step( + func: Callable[..., Awaitable[Any]], +) -> Callable[..., Awaitable[Any]]: + """Decorator to trace individual flow step executions. + + Args: + func: The async flow step function to be traced. + + Returns: + Wrapped async function that creates and manages the flow step trace context. + """ + + @wraps(func) + async def wrapper( + self: Any, + method_name: str, + method: Callable[..., Any], + *args: Any, + **kwargs: Any, + ) -> Any: + if not should_trace(): + return await func(self, method_name, method, *args, **kwargs) + + trace = build_flow_step_trace(self, method_name, method, *args, **kwargs) + with TraceContext.set_current(trace): + trace.start_trace() + try: + result = await func(self, method_name, method, *args, **kwargs) + trace.end_trace(result=result) + return result + except Exception as e: + trace.end_trace(error=str(e)) + raise + + return wrapper + + +def build_flow_step_trace( + self: Any, method_name: str, method: Callable[..., Any], *args: Any, **kwargs: Any +) -> "UnifiedTraceController": + """Build a trace controller for an individual flow step. + + Args: + self: The flow instance. + method_name: Name of the method being executed. + method: The actual method being executed. + *args: Variable positional arguments. + **kwargs: Variable keyword arguments. + + Returns: + UnifiedTraceController: The configured trace controller for the flow step. + """ + current_trace = TraceContext.get_current() + + # Get method signature + sig = inspect.signature(method) + params = list(sig.parameters.values()) + + # Create inputs dictionary mapping parameter names to values + method_params = [p for p in params if p.name != "self"] + inputs: Dict[str, Any] = {} + + # Map positional args to their parameter names + for i, param in enumerate(method_params): + if i < len(args): + inputs[param.name] = args[i] + + # Add keyword arguments + inputs.update(kwargs) + + trace = UnifiedTraceController( + trace_type=TraceType.FLOW_STEP, + run_type=current_trace.run_type if current_trace else RunType.KICKOFF, + crew_type=current_trace.crew_type if current_trace else CrewType.FLOW, + run_id=current_trace.run_id if current_trace else str(self.flow_id), + parent_trace_id=current_trace.trace_id if current_trace else None, + flow_step={ + "function_name": method_name, + "inputs": inputs, + "metadata": { + "crew_name": self.__class__.__name__, + }, + }, + ) + return trace + + +# LLM trace +def trace_llm_call(func: Callable[..., Any]) -> Callable[..., Any]: + """Decorator to trace LLM calls. + + Args: + func: The function to trace. + + Returns: + Wrapped function that creates and manages the LLM call trace context. 
+ """ + + @wraps(func) + def wrapper(self: Any, *args: Any, **kwargs: Any) -> Any: + if not should_trace(): + return func(self, *args, **kwargs) + + trace = build_llm_trace(self, *args, **kwargs) + with TraceContext.set_current(trace): + trace.start_trace() + try: + response = func(self, *args, **kwargs) + # Extract relevant data from response + trace_response = { + "content": response["choices"][0]["message"]["content"], + "finish_reason": response["choices"][0].get("finish_reason"), + } + + # Add usage metrics to context + if "usage" in response: + trace.context["tokens_used"] = response["usage"].get( + "total_tokens", 0 + ) + trace.context["prompt_tokens"] = response["usage"].get( + "prompt_tokens", 0 + ) + trace.context["completion_tokens"] = response["usage"].get( + "completion_tokens", 0 + ) + + trace.end_trace(trace_response) + return response + except Exception as e: + trace.end_trace(error=str(e)) + raise + + return wrapper + + +def build_llm_trace( + self: Any, params: Dict[str, Any], *args: Any, **kwargs: Any +) -> Any: + """Build a trace controller for an LLM call. + + Args: + self: The LLM instance. + params: The parameters for the LLM call. + *args: Variable positional arguments. + **kwargs: Variable keyword arguments. + + Returns: + UnifiedTraceController: The configured trace controller for the LLM call. + """ + current_trace = TraceContext.get_current() + agent, task = self._get_execution_context() + + # Get new messages and tool results + new_messages = self._get_new_messages(params.get("messages", [])) + new_tool_results = self._get_new_tool_results(agent) + + # Create trace context + trace = UnifiedTraceController( + trace_type=TraceType.TOOL_CALL if new_tool_results else TraceType.LLM_CALL, + crew_type=current_trace.crew_type if current_trace else CrewType.CREW, + run_type=current_trace.run_type if current_trace else RunType.KICKOFF, + run_id=current_trace.run_id if current_trace else str(uuid4()), + parent_trace_id=current_trace.trace_id if current_trace else None, + agent_role=agent.role if agent else "unknown", + task_id=str(task.id) if task else None, + task_name=task.name if task else None, + task_description=task.description if task else None, + model=self.model, + messages=new_messages, + temperature=self.temperature, + max_tokens=self.max_tokens, + stop_sequences=self.stop, + tool_calls=[ + ToolCall( + name=result["tool_name"], + arguments=result["tool_args"], + output=str(result["result"]), + start_time=result.get("start_time", ""), + end_time=datetime.now(UTC), + ) + for result in new_tool_results + ], + ) + return trace diff --git a/src/crewai/utilities/protocols.py b/src/crewai/utilities/protocols.py new file mode 100644 index 000000000..83ebf58e9 --- /dev/null +++ b/src/crewai/utilities/protocols.py @@ -0,0 +1,12 @@ +from typing import Any, Protocol, runtime_checkable + + +@runtime_checkable +class AgentExecutorProtocol(Protocol): + """Protocol defining the expected interface for an agent executor.""" + + @property + def agent(self) -> Any: ... + + @property + def task(self) -> Any: ... 
diff --git a/tests/agent_test.py b/tests/agent_test.py index e67a7454a..d429a3c60 100644 --- a/tests/agent_test.py +++ b/tests/agent_test.py @@ -1,6 +1,7 @@ """Test Agent creation and execution basic functionality.""" import os +from datetime import UTC, datetime, timezone from unittest import mock from unittest.mock import patch @@ -908,6 +909,8 @@ def test_tool_result_as_answer_is_the_final_answer_for_the_agent(): @pytest.mark.vcr(filter_headers=["authorization"]) def test_tool_usage_information_is_appended_to_agent(): + from datetime import UTC, datetime + from crewai.tools import BaseTool class MyCustomTool(BaseTool): @@ -917,30 +920,36 @@ def test_tool_usage_information_is_appended_to_agent(): def _run(self) -> str: return "Howdy!" - agent1 = Agent( - role="Friendly Neighbor", - goal="Make everyone feel welcome", - backstory="You are the friendly neighbor", - tools=[MyCustomTool(result_as_answer=True)], - ) + fixed_datetime = datetime(2025, 2, 10, 12, 0, 0, tzinfo=UTC) + with patch("datetime.datetime") as mock_datetime: + mock_datetime.now.return_value = fixed_datetime + mock_datetime.side_effect = lambda *args, **kw: datetime(*args, **kw) - greeting = Task( - description="Say an appropriate greeting.", - expected_output="The greeting.", - agent=agent1, - ) - tasks = [greeting] - crew = Crew(agents=[agent1], tasks=tasks) + agent1 = Agent( + role="Friendly Neighbor", + goal="Make everyone feel welcome", + backstory="You are the friendly neighbor", + tools=[MyCustomTool(result_as_answer=True)], + ) - crew.kickoff() - assert agent1.tools_results == [ - { - "result": "Howdy!", - "tool_name": "Decide Greetings", - "tool_args": {}, - "result_as_answer": True, - } - ] + greeting = Task( + description="Say an appropriate greeting.", + expected_output="The greeting.", + agent=agent1, + ) + tasks = [greeting] + crew = Crew(agents=[agent1], tasks=tasks) + + crew.kickoff() + assert agent1.tools_results == [ + { + "result": "Howdy!", + "tool_name": "Decide Greetings", + "tool_args": {}, + "result_as_answer": True, + "start_time": fixed_datetime, + } + ] def test_agent_definition_based_on_dict(): diff --git a/tests/traces/test_unified_trace_controller.py b/tests/traces/test_unified_trace_controller.py new file mode 100644 index 000000000..b14fb5d2d --- /dev/null +++ b/tests/traces/test_unified_trace_controller.py @@ -0,0 +1,360 @@ +import os +from datetime import UTC, datetime +from unittest.mock import MagicMock, patch +from uuid import UUID + +import pytest + +from crewai.traces.context import TraceContext +from crewai.traces.enums import CrewType, RunType, TraceType +from crewai.traces.models import ( + CrewTrace, + FlowStepIO, + LLMRequest, + LLMResponse, +) +from crewai.traces.unified_trace_controller import ( + UnifiedTraceController, + init_crew_main_trace, + init_flow_main_trace, + should_trace, + trace_flow_step, + trace_llm_call, +) + + +class TestUnifiedTraceController: + @pytest.fixture + def basic_trace_controller(self): + return UnifiedTraceController( + trace_type=TraceType.LLM_CALL, + run_type=RunType.KICKOFF, + crew_type=CrewType.CREW, + run_id="test-run-id", + agent_role="test-agent", + task_name="test-task", + task_description="test description", + task_id="test-task-id", + ) + + def test_initialization(self, basic_trace_controller): + """Test basic initialization of UnifiedTraceController""" + assert basic_trace_controller.trace_type == TraceType.LLM_CALL + assert basic_trace_controller.run_type == RunType.KICKOFF + assert basic_trace_controller.crew_type == CrewType.CREW + 
assert basic_trace_controller.run_id == "test-run-id" + assert basic_trace_controller.agent_role == "test-agent" + assert basic_trace_controller.task_name == "test-task" + assert basic_trace_controller.task_description == "test description" + assert basic_trace_controller.task_id == "test-task-id" + assert basic_trace_controller.status == "running" + assert isinstance(UUID(basic_trace_controller.trace_id), UUID) + + def test_start_trace(self, basic_trace_controller): + """Test starting a trace""" + result = basic_trace_controller.start_trace() + assert result == basic_trace_controller + assert basic_trace_controller.start_time is not None + assert isinstance(basic_trace_controller.start_time, datetime) + + def test_end_trace_success(self, basic_trace_controller): + """Test ending a trace successfully""" + basic_trace_controller.start_trace() + basic_trace_controller.end_trace(result={"test": "result"}) + + assert basic_trace_controller.end_time is not None + assert basic_trace_controller.status == "completed" + assert basic_trace_controller.error is None + assert basic_trace_controller.context.get("response") == {"test": "result"} + + def test_end_trace_with_error(self, basic_trace_controller): + """Test ending a trace with an error""" + basic_trace_controller.start_trace() + basic_trace_controller.end_trace(error="Test error occurred") + + assert basic_trace_controller.end_time is not None + assert basic_trace_controller.status == "error" + assert basic_trace_controller.error == "Test error occurred" + + def test_add_child_trace(self, basic_trace_controller): + """Test adding a child trace""" + child_trace = {"id": "child-1", "type": "test"} + basic_trace_controller.add_child_trace(child_trace) + assert len(basic_trace_controller.children) == 1 + assert basic_trace_controller.children[0] == child_trace + + def test_to_crew_trace_llm_call(self): + """Test converting to CrewTrace for LLM call""" + test_messages = [{"role": "user", "content": "test"}] + test_response = { + "content": "test response", + "finish_reason": "stop", + } + + controller = UnifiedTraceController( + trace_type=TraceType.LLM_CALL, + run_type=RunType.KICKOFF, + crew_type=CrewType.CREW, + run_id="test-run-id", + context={ + "messages": test_messages, + "temperature": 0.7, + "max_tokens": 100, + }, + ) + + # Set model and messages in the context + controller.context["model"] = "gpt-4" + controller.context["messages"] = test_messages + + controller.start_trace() + controller.end_trace(result=test_response) + + crew_trace = controller.to_crew_trace() + assert isinstance(crew_trace, CrewTrace) + assert isinstance(crew_trace.request, LLMRequest) + assert isinstance(crew_trace.response, LLMResponse) + assert crew_trace.request.model == "gpt-4" + assert crew_trace.request.messages == test_messages + assert crew_trace.response.content == test_response["content"] + assert crew_trace.response.finish_reason == test_response["finish_reason"] + + def test_to_crew_trace_flow_step(self): + """Test converting to CrewTrace for flow step""" + flow_step_data = { + "function_name": "test_function", + "inputs": {"param1": "value1"}, + "metadata": {"meta": "data"}, + } + + controller = UnifiedTraceController( + trace_type=TraceType.FLOW_STEP, + run_type=RunType.KICKOFF, + crew_type=CrewType.FLOW, + run_id="test-run-id", + flow_step=flow_step_data, + ) + + controller.start_trace() + controller.end_trace(result="test result") + + crew_trace = controller.to_crew_trace() + assert isinstance(crew_trace, CrewTrace) + assert 
isinstance(crew_trace.flow_step, FlowStepIO)
+        assert crew_trace.flow_step.function_name == "test_function"
+        assert crew_trace.flow_step.inputs == {"param1": "value1"}
+        assert crew_trace.flow_step.outputs == {"result": "test result"}
+
+    def test_should_trace(self):
+        """Test should_trace function"""
+        with patch.dict(os.environ, {"CREWAI_ENABLE_TRACING": "true"}):
+            assert should_trace() is True
+
+        with patch.dict(os.environ, {"CREWAI_ENABLE_TRACING": "false"}):
+            assert should_trace() is False
+
+        with patch.dict(os.environ, clear=True):
+            assert should_trace() is False
+
+    @pytest.mark.asyncio
+    async def test_trace_flow_step_decorator(self):
+        """Test trace_flow_step decorator"""
+
+        class TestFlow:
+            flow_id = "test-flow-id"
+
+            @trace_flow_step
+            async def test_method(self, method_name, method, *args, **kwargs):
+                return "test result"
+
+        with patch.dict(os.environ, {"CREWAI_ENABLE_TRACING": "true"}):
+            flow = TestFlow()
+            result = await flow.test_method("test_method", lambda x: x, arg1="value1")
+            assert result == "test result"
+
+    def test_trace_llm_call_decorator(self):
+        """Test trace_llm_call decorator"""
+
+        class TestLLM:
+            model = "gpt-4"
+            temperature = 0.7
+            max_tokens = 100
+            stop = None
+
+            def _get_execution_context(self):
+                return MagicMock(), MagicMock()
+
+            def _get_new_messages(self, messages):
+                return messages
+
+            def _get_new_tool_results(self, agent):
+                return []
+
+            @trace_llm_call
+            def test_method(self, params):
+                return {
+                    "choices": [
+                        {
+                            "message": {"content": "test response"},
+                            "finish_reason": "stop",
+                        }
+                    ],
+                    "usage": {
+                        "total_tokens": 50,
+                        "prompt_tokens": 20,
+                        "completion_tokens": 30,
+                    },
+                }
+
+        with patch.dict(os.environ, {"CREWAI_ENABLE_TRACING": "true"}):
+            llm = TestLLM()
+            result = llm.test_method({"messages": []})
+            assert result["choices"][0]["message"]["content"] == "test response"
+
+    def test_init_crew_main_trace_kickoff(self):
+        """Test init_crew_main_trace in kickoff mode"""
+        trace_context = None
+
+        class TestCrew:
+            id = "test-crew-id"
+            _test = False
+            _train = False
+
+            @init_crew_main_trace
+            def test_method(self):
+                nonlocal trace_context
+                trace_context = TraceContext.get_current()
+                return "test result"
+
+        with patch.dict(os.environ, {"CREWAI_ENABLE_TRACING": "true"}):
+            crew = TestCrew()
+            result = crew.test_method()
+            assert result == "test result"
+            assert trace_context is not None
+            assert trace_context.trace_type == TraceType.LLM_CALL
+            assert trace_context.run_type == RunType.KICKOFF
+            assert trace_context.crew_type == CrewType.CREW
+            assert trace_context.run_id == str(crew.id)
+
+    def test_init_crew_main_trace_test_mode(self):
+        """Test init_crew_main_trace in test mode"""
+        trace_context = None
+
+        class TestCrew:
+            id = "test-crew-id"
+            _test = True
+            _train = False
+
+            @init_crew_main_trace
+            def test_method(self):
+                nonlocal trace_context
+                trace_context = TraceContext.get_current()
+                return "test result"
+
+        with patch.dict(os.environ, {"CREWAI_ENABLE_TRACING": "true"}):
+            crew = TestCrew()
+            result = crew.test_method()
+            assert result == "test result"
+            assert trace_context is not None
+            assert trace_context.run_type == RunType.TEST
+
+    def test_init_crew_main_trace_train_mode(self):
+        """Test init_crew_main_trace in train mode"""
+        trace_context = None
+
+        class TestCrew:
+            id = "test-crew-id"
+            _test = False
+            _train = True
+
+            @init_crew_main_trace
+            def test_method(self):
+                nonlocal trace_context
+                trace_context = TraceContext.get_current()
+                return "test result"
+
+        with patch.dict(os.environ, {"CREWAI_ENABLE_TRACING": "true"}):
+            crew = TestCrew()
+            result = crew.test_method()
+            assert result == "test result"
+            assert trace_context is not None
+            assert trace_context.run_type == RunType.TRAIN
+
+    @pytest.mark.asyncio
+    async def test_init_flow_main_trace(self):
+        """Test init_flow_main_trace decorator"""
+        trace_context = None
+        test_inputs = {"test": "input"}
+
+        class TestFlow:
+            flow_id = "test-flow-id"
+
+            @init_flow_main_trace
+            async def test_method(self, **kwargs):
+                nonlocal trace_context
+                trace_context = TraceContext.get_current()
+                # Verify the context is set during execution
+                assert trace_context.context["context"]["inputs"] == test_inputs
+                return "test result"
+
+        with patch.dict(os.environ, {"CREWAI_ENABLE_TRACING": "true"}):
+            flow = TestFlow()
+            result = await flow.test_method(inputs=test_inputs)
+            assert result == "test result"
+            assert trace_context is not None
+            assert trace_context.trace_type == TraceType.FLOW_STEP
+            assert trace_context.crew_type == CrewType.FLOW
+            assert trace_context.run_type == RunType.KICKOFF
+            assert trace_context.run_id == str(flow.flow_id)
+            assert trace_context.context["context"]["inputs"] == test_inputs
+
+    def test_trace_context_management(self):
+        """Test TraceContext management"""
+        trace1 = UnifiedTraceController(
+            trace_type=TraceType.LLM_CALL,
+            run_type=RunType.KICKOFF,
+            crew_type=CrewType.CREW,
+            run_id="test-run-1",
+        )
+
+        trace2 = UnifiedTraceController(
+            trace_type=TraceType.FLOW_STEP,
+            run_type=RunType.TEST,
+            crew_type=CrewType.FLOW,
+            run_id="test-run-2",
+        )
+
+        # Test that context is initially empty
+        assert TraceContext.get_current() is None
+
+        # Test setting and getting context
+        with TraceContext.set_current(trace1):
+            assert TraceContext.get_current() == trace1
+
+            # Test nested context
+            with TraceContext.set_current(trace2):
+                assert TraceContext.get_current() == trace2
+
+            # Test context restoration after nested block
+            assert TraceContext.get_current() == trace1
+
+        # Test context cleanup after with block
+        assert TraceContext.get_current() is None
+
+    def test_trace_context_error_handling(self):
+        """Test TraceContext error handling"""
+        trace = UnifiedTraceController(
+            trace_type=TraceType.LLM_CALL,
+            run_type=RunType.KICKOFF,
+            crew_type=CrewType.CREW,
+            run_id="test-run",
+        )
+
+        # Test that context is properly cleaned up even if an error occurs
+        try:
+            with TraceContext.set_current(trace):
+                raise ValueError("Test error")
+        except ValueError:
+            pass
+
+        assert TraceContext.get_current() is None