diff --git a/src/crewai/traces/unified_trace_controller.py b/src/crewai/traces/unified_trace_controller.py
index 20ed590d6..cbfdbea9e 100644
--- a/src/crewai/traces/unified_trace_controller.py
+++ b/src/crewai/traces/unified_trace_controller.py
@@ -79,6 +79,9 @@ class UnifiedTraceController:
         self.flow_step = flow_step
         self.status: str = "running"
 
+    def _get_current_trace(self) -> "UnifiedTraceController":
+        return TraceContext.get_current()
+
     def start_trace(self) -> "UnifiedTraceController":
         """Start the trace execution.
 
@@ -190,89 +193,6 @@ def should_trace() -> bool:
     return os.getenv("CREWAI_ENABLE_TRACING", "false").lower() == "true"
 
 
-def trace_llm_call(func: Callable[..., Any]) -> Callable[..., Any]:
-    """Decorator to trace LLM calls.
-
-    Args:
-        func: The function to trace.
-
-    Returns:
-        Wrapped function that creates and manages the LLM call trace context.
-    """
-
-    @wraps(func)
-    def wrapper(self: Any, params: Dict[str, Any], *args: Any, **kwargs: Any) -> Any:
-        if not should_trace():
-            return func(self, params, *args, **kwargs)
-
-        current_trace = TraceContext.get_current()
-        agent, task = self._get_execution_context()
-
-        # Get new messages and tool results
-        new_messages = self._get_new_messages(params.get("messages", []))
-        new_tool_results = self._get_new_tool_results(agent)
-
-        # Create trace context
-        trace = UnifiedTraceController(
-            trace_type=TraceType.TOOL_CALL if new_tool_results else TraceType.LLM_CALL,
-            crew_type=current_trace.crew_type if current_trace else CrewType.CREW,
-            run_type=current_trace.run_type if current_trace else RunType.KICKOFF,
-            run_id=current_trace.run_id if current_trace else str(uuid4()),
-            parent_trace_id=current_trace.trace_id if current_trace else None,
-            agent_role=agent.role if agent else "unknown",
-            task_id=str(task.id) if task else None,
-            task_name=task.name if task else None,
-            task_description=task.description if task else None,
-            model=self.model,
-            messages=new_messages,
-            temperature=self.temperature,
-            max_tokens=self.max_tokens,
-            stop_sequences=self.stop,
-            tool_calls=[
-                ToolCall(
-                    name=result["tool_name"],
-                    arguments=result["tool_args"],
-                    output=str(result["result"]),
-                    start_time=result.get("start_time", ""),
-                    end_time=datetime.now(UTC),
-                )
-                for result in new_tool_results
-            ],
-        )
-
-        with TraceContext.set_current(trace):
-            trace.start_trace()
-            try:
-                response = func(self, params, *args, **kwargs)
-
-                # Extract relevant data from response
-                trace_response = {
-                    "content": response["choices"][0]["message"]["content"],
-                    "finish_reason": response["choices"][0].get("finish_reason"),
-                }
-
-                # Add usage metrics to context
-                if "usage" in response:
-                    trace.context["tokens_used"] = response["usage"].get(
-                        "total_tokens", 0
-                    )
-                    trace.context["prompt_tokens"] = response["usage"].get(
-                        "prompt_tokens", 0
-                    )
-                    trace.context["completion_tokens"] = response["usage"].get(
-                        "completion_tokens", 0
-                    )
-
-                trace.end_trace(trace_response)
-
-                return response
-            except Exception as e:
-                trace.end_trace(error=str(e))
-                raise
-
-    return wrapper
-
-
 # Crew main trace
 def init_crew_main_trace(func: Callable[..., Any]) -> Callable[..., Any]:
     """Decorator to initialize and track the main crew execution trace.
@@ -478,3 +398,102 @@ def build_flow_step_trace(
         },
     )
     return trace
+
+
+# LLM trace
+def trace_llm_call(func: Callable[..., Any]) -> Callable[..., Any]:
+    """Decorator to trace LLM calls.
+
+    Args:
+        func: The function to trace.
+
+    Returns:
+        Wrapped function that creates and manages the LLM call trace context.
+    """
+
+    @wraps(func)
+    def wrapper(self: Any, *args: Any, **kwargs: Any) -> Any:
+        if not should_trace():
+            return func(self, *args, **kwargs)
+
+        trace = build_llm_trace(self, *args, **kwargs)
+        with TraceContext.set_current(trace):
+            trace.start_trace()
+            try:
+                response = func(self, *args, **kwargs)
+                # Extract relevant data from response
+                trace_response = {
+                    "content": response["choices"][0]["message"]["content"],
+                    "finish_reason": response["choices"][0].get("finish_reason"),
+                }
+
+                # Add usage metrics to context
+                if "usage" in response:
+                    trace.context["tokens_used"] = response["usage"].get(
+                        "total_tokens", 0
+                    )
+                    trace.context["prompt_tokens"] = response["usage"].get(
+                        "prompt_tokens", 0
+                    )
+                    trace.context["completion_tokens"] = response["usage"].get(
+                        "completion_tokens", 0
+                    )
+
+                trace.end_trace(trace_response)
+                return response
+            except Exception as e:
+                trace.end_trace(error=str(e))
+                raise
+
+    return wrapper
+
+
+def build_llm_trace(
+    self: Any, params: Dict[str, Any], *args: Any, **kwargs: Any
+) -> "UnifiedTraceController":
+    """Build a trace controller for an LLM call.
+
+    Args:
+        self: The LLM instance.
+        params: The parameters for the LLM call.
+        *args: Variable positional arguments.
+        **kwargs: Variable keyword arguments.
+
+    Returns:
+        UnifiedTraceController: The configured trace controller for the LLM call.
+    """
+    current_trace = TraceContext.get_current()
+    agent, task = self._get_execution_context()
+
+    # Get new messages and tool results
+    new_messages = self._get_new_messages(params.get("messages", []))
+    new_tool_results = self._get_new_tool_results(agent)
+
+    # Create trace context
+    trace = UnifiedTraceController(
+        trace_type=TraceType.TOOL_CALL if new_tool_results else TraceType.LLM_CALL,
+        crew_type=current_trace.crew_type if current_trace else CrewType.CREW,
+        run_type=current_trace.run_type if current_trace else RunType.KICKOFF,
+        run_id=current_trace.run_id if current_trace else str(uuid4()),
+        parent_trace_id=current_trace.trace_id if current_trace else None,
+        agent_role=agent.role if agent else "unknown",
+        task_id=str(task.id) if task else None,
+        task_name=task.name if task else None,
+        task_description=task.description if task else None,
+        model=self.model,
+        messages=new_messages,
+        temperature=self.temperature,
+        max_tokens=self.max_tokens,
+        stop_sequences=self.stop,
+        tool_calls=[
+            ToolCall(
+                name=result["tool_name"],
+                arguments=result["tool_args"],
+                output=str(result["result"]),
+                start_time=result.get("start_time", ""),
+                end_time=datetime.now(UTC),
+            )
+            for result in new_tool_results
+        ],
+    )
+    return trace
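
Usage sketch (illustrative, not part of the patch): the relocated decorator assumes the decorated method takes an OpenAI-style params dict as its first argument, returns a dict with "choices" (and optionally "usage"), and that the instance exposes model/temperature/max_tokens/stop plus the _get_execution_context, _get_new_messages, and _get_new_tool_results hooks. The class and values below are hypothetical stand-ins for crewai's real LLM client.

import os
from typing import Any, Dict, List, Optional, Tuple

from crewai.traces.unified_trace_controller import trace_llm_call


class FakeTracedLLM:
    """Hypothetical client showing the contract trace_llm_call relies on."""

    def __init__(self) -> None:
        self.model = "gpt-4o-mini"  # illustrative model name
        self.temperature: Optional[float] = 0.0
        self.max_tokens: Optional[int] = None
        self.stop: Optional[List[str]] = None

    # Hooks used by build_llm_trace; real implementations live on crewai's LLM.
    def _get_execution_context(self) -> Tuple[Any, Any]:
        return None, None  # (agent, task); None falls back to "unknown"/no task

    def _get_new_messages(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        return messages  # deduplication against prior traces elided here

    def _get_new_tool_results(self, agent: Any) -> List[Dict[str, Any]]:
        return []  # empty list => trace_type resolves to TraceType.LLM_CALL

    @trace_llm_call
    def call(self, params: Dict[str, Any]) -> Dict[str, Any]:
        # A real client would forward params to the provider here; this stub
        # returns the OpenAI-style shape the decorator unpacks.
        return {
            "choices": [{"message": {"content": "hi"}, "finish_reason": "stop"}],
            "usage": {"total_tokens": 7, "prompt_tokens": 5, "completion_tokens": 2},
        }


os.environ["CREWAI_ENABLE_TRACING"] = "true"  # tracing is opt-in via this env var
FakeTracedLLM().call({"messages": [{"role": "user", "content": "hello"}]})

Extracting build_llm_trace out of the wrapper keeps trace construction separate from the start/end lifecycle, which is what lets the decorator body stay identical whether the trace resolves to an LLM_CALL or a TOOL_CALL.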