mirror of
https://github.com/crewAIInc/crewAI.git
synced 2025-12-17 04:48:30 +00:00
Compare commits
20 Commits
bugfix/flo
...
feat/add-p
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
ec43a6fae9 | ||
|
|
6025301205 | ||
|
|
2a80a2a611 | ||
|
|
e41e2c1210 | ||
|
|
1e140fc6d8 | ||
|
|
679bfce647 | ||
|
|
ba197ec8db | ||
|
|
8f3bf31339 | ||
|
|
e3026ebd56 | ||
|
|
2f846fc945 | ||
|
|
bd6e45b905 | ||
|
|
c20020e3fb | ||
|
|
16722925eb | ||
|
|
fa01dcb5dc | ||
|
|
e67f772a64 | ||
|
|
6a8ca951a7 | ||
|
|
c06e6e0021 | ||
|
|
d57b017e7b | ||
|
|
4957a9c20c | ||
|
|
d263540325 |
@@ -38,6 +38,7 @@ from crewai.tasks.task_output import TaskOutput
|
|||||||
from crewai.telemetry import Telemetry
|
from crewai.telemetry import Telemetry
|
||||||
from crewai.tools.agent_tools.agent_tools import AgentTools
|
from crewai.tools.agent_tools.agent_tools import AgentTools
|
||||||
from crewai.tools.base_tool import Tool
|
from crewai.tools.base_tool import Tool
|
||||||
|
from crewai.traces.unified_trace_controller import init_crew_main_trace
|
||||||
from crewai.types.usage_metrics import UsageMetrics
|
from crewai.types.usage_metrics import UsageMetrics
|
||||||
from crewai.utilities import I18N, FileHandler, Logger, RPMController
|
from crewai.utilities import I18N, FileHandler, Logger, RPMController
|
||||||
from crewai.utilities.constants import TRAINING_DATA_FILE
|
from crewai.utilities.constants import TRAINING_DATA_FILE
|
||||||
@@ -545,6 +546,7 @@ class Crew(BaseModel):
|
|||||||
CrewTrainingHandler(filename).clear()
|
CrewTrainingHandler(filename).clear()
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
@init_crew_main_trace
|
||||||
def kickoff(
|
def kickoff(
|
||||||
self,
|
self,
|
||||||
inputs: Optional[Dict[str, Any]] = None,
|
inputs: Optional[Dict[str, Any]] = None,
|
||||||
|
|||||||
@@ -30,6 +30,10 @@ from crewai.flow.flow_visualizer import plot_flow
|
|||||||
from crewai.flow.persistence.base import FlowPersistence
|
from crewai.flow.persistence.base import FlowPersistence
|
||||||
from crewai.flow.utils import get_possible_return_constants
|
from crewai.flow.utils import get_possible_return_constants
|
||||||
from crewai.telemetry import Telemetry
|
from crewai.telemetry import Telemetry
|
||||||
|
from crewai.traces.unified_trace_controller import (
|
||||||
|
init_flow_main_trace,
|
||||||
|
trace_flow_step,
|
||||||
|
)
|
||||||
from crewai.utilities.printer import Printer
|
from crewai.utilities.printer import Printer
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -753,8 +757,12 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
|||||||
if inputs is not None and "id" not in inputs:
|
if inputs is not None and "id" not in inputs:
|
||||||
self._initialize_state(inputs)
|
self._initialize_state(inputs)
|
||||||
|
|
||||||
return asyncio.run(self.kickoff_async())
|
async def run_flow():
|
||||||
|
return await self.kickoff_async()
|
||||||
|
|
||||||
|
return asyncio.run(run_flow())
|
||||||
|
|
||||||
|
@init_flow_main_trace
|
||||||
async def kickoff_async(self, inputs: Optional[Dict[str, Any]] = None) -> Any:
|
async def kickoff_async(self, inputs: Optional[Dict[str, Any]] = None) -> Any:
|
||||||
if not self._start_methods:
|
if not self._start_methods:
|
||||||
raise ValueError("No start method defined")
|
raise ValueError("No start method defined")
|
||||||
@@ -804,6 +812,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
|||||||
)
|
)
|
||||||
await self._execute_listeners(start_method_name, result)
|
await self._execute_listeners(start_method_name, result)
|
||||||
|
|
||||||
|
@trace_flow_step
|
||||||
async def _execute_method(
|
async def _execute_method(
|
||||||
self, method_name: str, method: Callable, *args: Any, **kwargs: Any
|
self, method_name: str, method: Callable, *args: Any, **kwargs: Any
|
||||||
) -> Any:
|
) -> Any:
|
||||||
|
|||||||
@@ -1,3 +1,4 @@
|
|||||||
|
import inspect
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
@@ -5,7 +6,17 @@ import sys
|
|||||||
import threading
|
import threading
|
||||||
import warnings
|
import warnings
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from typing import Any, Dict, List, Literal, Optional, Type, Union, cast
|
from typing import (
|
||||||
|
Any,
|
||||||
|
Dict,
|
||||||
|
List,
|
||||||
|
Literal,
|
||||||
|
Optional,
|
||||||
|
Tuple,
|
||||||
|
Type,
|
||||||
|
Union,
|
||||||
|
cast,
|
||||||
|
)
|
||||||
|
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
@@ -18,9 +29,11 @@ with warnings.catch_warnings():
|
|||||||
from litellm.utils import supports_response_schema
|
from litellm.utils import supports_response_schema
|
||||||
|
|
||||||
|
|
||||||
|
from crewai.traces.unified_trace_controller import trace_llm_call
|
||||||
from crewai.utilities.exceptions.context_window_exceeding_exception import (
|
from crewai.utilities.exceptions.context_window_exceeding_exception import (
|
||||||
LLMContextLengthExceededException,
|
LLMContextLengthExceededException,
|
||||||
)
|
)
|
||||||
|
from crewai.utilities.protocols import AgentExecutorProtocol
|
||||||
|
|
||||||
load_dotenv()
|
load_dotenv()
|
||||||
|
|
||||||
@@ -164,6 +177,7 @@ class LLM:
|
|||||||
self.context_window_size = 0
|
self.context_window_size = 0
|
||||||
self.reasoning_effort = reasoning_effort
|
self.reasoning_effort = reasoning_effort
|
||||||
self.additional_params = kwargs
|
self.additional_params = kwargs
|
||||||
|
self._message_history: List[Dict[str, str]] = []
|
||||||
self.is_anthropic = self._is_anthropic_model(model)
|
self.is_anthropic = self._is_anthropic_model(model)
|
||||||
|
|
||||||
litellm.drop_params = True
|
litellm.drop_params = True
|
||||||
@@ -179,6 +193,12 @@ class LLM:
|
|||||||
self.set_callbacks(callbacks)
|
self.set_callbacks(callbacks)
|
||||||
self.set_env_callbacks()
|
self.set_env_callbacks()
|
||||||
|
|
||||||
|
@trace_llm_call
|
||||||
|
def _call_llm(self, params: Dict[str, Any]) -> Any:
|
||||||
|
with suppress_warnings():
|
||||||
|
response = litellm.completion(**params)
|
||||||
|
return response
|
||||||
|
|
||||||
def _is_anthropic_model(self, model: str) -> bool:
|
def _is_anthropic_model(self, model: str) -> bool:
|
||||||
"""Determine if the model is from Anthropic provider.
|
"""Determine if the model is from Anthropic provider.
|
||||||
|
|
||||||
@@ -188,7 +208,7 @@ class LLM:
|
|||||||
Returns:
|
Returns:
|
||||||
bool: True if the model is from Anthropic, False otherwise.
|
bool: True if the model is from Anthropic, False otherwise.
|
||||||
"""
|
"""
|
||||||
ANTHROPIC_PREFIXES = ('anthropic/', 'claude-', 'claude/')
|
ANTHROPIC_PREFIXES = ("anthropic/", "claude-", "claude/")
|
||||||
return any(prefix in model.lower() for prefix in ANTHROPIC_PREFIXES)
|
return any(prefix in model.lower() for prefix in ANTHROPIC_PREFIXES)
|
||||||
|
|
||||||
def call(
|
def call(
|
||||||
@@ -288,7 +308,7 @@ class LLM:
|
|||||||
params = {k: v for k, v in params.items() if v is not None}
|
params = {k: v for k, v in params.items() if v is not None}
|
||||||
|
|
||||||
# --- 2) Make the completion call
|
# --- 2) Make the completion call
|
||||||
response = litellm.completion(**params)
|
response = self._call_llm(params)
|
||||||
response_message = cast(Choices, cast(ModelResponse, response).choices)[
|
response_message = cast(Choices, cast(ModelResponse, response).choices)[
|
||||||
0
|
0
|
||||||
].message
|
].message
|
||||||
@@ -348,7 +368,9 @@ class LLM:
|
|||||||
logging.error(f"LiteLLM call failed: {str(e)}")
|
logging.error(f"LiteLLM call failed: {str(e)}")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
def _format_messages_for_provider(self, messages: List[Dict[str, str]]) -> List[Dict[str, str]]:
|
def _format_messages_for_provider(
|
||||||
|
self, messages: List[Dict[str, str]]
|
||||||
|
) -> List[Dict[str, str]]:
|
||||||
"""Format messages according to provider requirements.
|
"""Format messages according to provider requirements.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -368,7 +390,9 @@ class LLM:
|
|||||||
# Validate message format first
|
# Validate message format first
|
||||||
for msg in messages:
|
for msg in messages:
|
||||||
if not isinstance(msg, dict) or "role" not in msg or "content" not in msg:
|
if not isinstance(msg, dict) or "role" not in msg or "content" not in msg:
|
||||||
raise TypeError("Invalid message format. Each message must be a dict with 'role' and 'content' keys")
|
raise TypeError(
|
||||||
|
"Invalid message format. Each message must be a dict with 'role' and 'content' keys"
|
||||||
|
)
|
||||||
|
|
||||||
if not self.is_anthropic:
|
if not self.is_anthropic:
|
||||||
return messages
|
return messages
|
||||||
@@ -495,3 +519,95 @@ class LLM:
|
|||||||
|
|
||||||
litellm.success_callback = success_callbacks
|
litellm.success_callback = success_callbacks
|
||||||
litellm.failure_callback = failure_callbacks
|
litellm.failure_callback = failure_callbacks
|
||||||
|
|
||||||
|
def _get_execution_context(self) -> Tuple[Optional[Any], Optional[Any]]:
|
||||||
|
"""Get the agent and task from the execution context.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
tuple: (agent, task) from any AgentExecutor context, or (None, None) if not found
|
||||||
|
"""
|
||||||
|
frame = inspect.currentframe()
|
||||||
|
caller_frame = frame.f_back if frame else None
|
||||||
|
agent = None
|
||||||
|
task = None
|
||||||
|
|
||||||
|
# Add a maximum depth to prevent infinite loops
|
||||||
|
max_depth = 100 # Reasonable limit for call stack depth
|
||||||
|
current_depth = 0
|
||||||
|
|
||||||
|
while caller_frame and current_depth < max_depth:
|
||||||
|
if "self" in caller_frame.f_locals:
|
||||||
|
caller_self = caller_frame.f_locals["self"]
|
||||||
|
if isinstance(caller_self, AgentExecutorProtocol):
|
||||||
|
agent = caller_self.agent
|
||||||
|
task = caller_self.task
|
||||||
|
break
|
||||||
|
caller_frame = caller_frame.f_back
|
||||||
|
current_depth += 1
|
||||||
|
|
||||||
|
return agent, task
|
||||||
|
|
||||||
|
def _get_new_messages(self, messages: List[Dict[str, str]]) -> List[Dict[str, str]]:
|
||||||
|
"""Get only the new messages that haven't been processed before."""
|
||||||
|
if not hasattr(self, "_message_history"):
|
||||||
|
self._message_history = []
|
||||||
|
|
||||||
|
new_messages = []
|
||||||
|
for message in messages:
|
||||||
|
message_key = (message["role"], message["content"])
|
||||||
|
if message_key not in [
|
||||||
|
(m["role"], m["content"]) for m in self._message_history
|
||||||
|
]:
|
||||||
|
new_messages.append(message)
|
||||||
|
self._message_history.append(message)
|
||||||
|
return new_messages
|
||||||
|
|
||||||
|
def _get_new_tool_results(self, agent) -> List[Dict]:
|
||||||
|
"""Get only the new tool results that haven't been processed before."""
|
||||||
|
if not agent or not agent.tools_results:
|
||||||
|
return []
|
||||||
|
|
||||||
|
if not hasattr(self, "_tool_results_history"):
|
||||||
|
self._tool_results_history: List[Dict] = []
|
||||||
|
|
||||||
|
new_tool_results = []
|
||||||
|
|
||||||
|
for result in agent.tools_results:
|
||||||
|
# Process tool arguments to extract actual values
|
||||||
|
processed_args = {}
|
||||||
|
if isinstance(result["tool_args"], dict):
|
||||||
|
for key, value in result["tool_args"].items():
|
||||||
|
if isinstance(value, dict) and "type" in value:
|
||||||
|
# Skip metadata and just store the actual value
|
||||||
|
continue
|
||||||
|
processed_args[key] = value
|
||||||
|
|
||||||
|
# Create a clean result with processed arguments
|
||||||
|
clean_result = {
|
||||||
|
"tool_name": result["tool_name"],
|
||||||
|
"tool_args": processed_args,
|
||||||
|
"result": result["result"],
|
||||||
|
"content": result.get("content", ""),
|
||||||
|
"start_time": result.get("start_time", ""),
|
||||||
|
}
|
||||||
|
|
||||||
|
# Check if this exact tool execution exists in history
|
||||||
|
is_duplicate = False
|
||||||
|
for history_result in self._tool_results_history:
|
||||||
|
if (
|
||||||
|
clean_result["tool_name"] == history_result["tool_name"]
|
||||||
|
and str(clean_result["tool_args"])
|
||||||
|
== str(history_result["tool_args"])
|
||||||
|
and str(clean_result["result"]) == str(history_result["result"])
|
||||||
|
and clean_result["content"] == history_result.get("content", "")
|
||||||
|
and clean_result["start_time"]
|
||||||
|
== history_result.get("start_time", "")
|
||||||
|
):
|
||||||
|
is_duplicate = True
|
||||||
|
break
|
||||||
|
|
||||||
|
if not is_duplicate:
|
||||||
|
new_tool_results.append(clean_result)
|
||||||
|
self._tool_results_history.append(clean_result)
|
||||||
|
|
||||||
|
return new_tool_results
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ import ast
|
|||||||
import datetime
|
import datetime
|
||||||
import json
|
import json
|
||||||
import time
|
import time
|
||||||
|
from datetime import UTC
|
||||||
from difflib import SequenceMatcher
|
from difflib import SequenceMatcher
|
||||||
from json import JSONDecodeError
|
from json import JSONDecodeError
|
||||||
from textwrap import dedent
|
from textwrap import dedent
|
||||||
@@ -116,7 +117,10 @@ class ToolUsage:
|
|||||||
self._printer.print(content=f"\n\n{error}\n", color="red")
|
self._printer.print(content=f"\n\n{error}\n", color="red")
|
||||||
return error
|
return error
|
||||||
|
|
||||||
if isinstance(tool, CrewStructuredTool) and tool.name == self._i18n.tools("add_image")["name"]: # type: ignore
|
if (
|
||||||
|
isinstance(tool, CrewStructuredTool)
|
||||||
|
and tool.name == self._i18n.tools("add_image")["name"] # type: ignore
|
||||||
|
):
|
||||||
try:
|
try:
|
||||||
result = self._use(tool_string=tool_string, tool=tool, calling=calling)
|
result = self._use(tool_string=tool_string, tool=tool, calling=calling)
|
||||||
return result
|
return result
|
||||||
@@ -154,6 +158,7 @@ class ToolUsage:
|
|||||||
self.task.increment_tools_errors()
|
self.task.increment_tools_errors()
|
||||||
|
|
||||||
started_at = time.time()
|
started_at = time.time()
|
||||||
|
started_at_trace = datetime.datetime.now(UTC)
|
||||||
from_cache = False
|
from_cache = False
|
||||||
|
|
||||||
result = None # type: ignore # Incompatible types in assignment (expression has type "None", variable has type "str")
|
result = None # type: ignore # Incompatible types in assignment (expression has type "None", variable has type "str")
|
||||||
@@ -181,7 +186,9 @@ class ToolUsage:
|
|||||||
|
|
||||||
if calling.arguments:
|
if calling.arguments:
|
||||||
try:
|
try:
|
||||||
acceptable_args = tool.args_schema.model_json_schema()["properties"].keys() # type: ignore
|
acceptable_args = tool.args_schema.model_json_schema()[
|
||||||
|
"properties"
|
||||||
|
].keys() # type: ignore
|
||||||
arguments = {
|
arguments = {
|
||||||
k: v
|
k: v
|
||||||
for k, v in calling.arguments.items()
|
for k, v in calling.arguments.items()
|
||||||
@@ -202,7 +209,7 @@ class ToolUsage:
|
|||||||
error=e, tool=tool.name, tool_inputs=tool.description
|
error=e, tool=tool.name, tool_inputs=tool.description
|
||||||
)
|
)
|
||||||
error = ToolUsageErrorException(
|
error = ToolUsageErrorException(
|
||||||
f'\n{error_message}.\nMoving on then. {self._i18n.slice("format").format(tool_names=self.tools_names)}'
|
f"\n{error_message}.\nMoving on then. {self._i18n.slice('format').format(tool_names=self.tools_names)}"
|
||||||
).message
|
).message
|
||||||
self.task.increment_tools_errors()
|
self.task.increment_tools_errors()
|
||||||
if self.agent.verbose:
|
if self.agent.verbose:
|
||||||
@@ -244,6 +251,7 @@ class ToolUsage:
|
|||||||
"result": result,
|
"result": result,
|
||||||
"tool_name": tool.name,
|
"tool_name": tool.name,
|
||||||
"tool_args": calling.arguments,
|
"tool_args": calling.arguments,
|
||||||
|
"start_time": started_at_trace,
|
||||||
}
|
}
|
||||||
|
|
||||||
self.on_tool_use_finished(
|
self.on_tool_use_finished(
|
||||||
@@ -368,7 +376,7 @@ class ToolUsage:
|
|||||||
raise
|
raise
|
||||||
else:
|
else:
|
||||||
return ToolUsageErrorException(
|
return ToolUsageErrorException(
|
||||||
f'{self._i18n.errors("tool_arguments_error")}'
|
f"{self._i18n.errors('tool_arguments_error')}"
|
||||||
)
|
)
|
||||||
|
|
||||||
if not isinstance(arguments, dict):
|
if not isinstance(arguments, dict):
|
||||||
@@ -376,7 +384,7 @@ class ToolUsage:
|
|||||||
raise
|
raise
|
||||||
else:
|
else:
|
||||||
return ToolUsageErrorException(
|
return ToolUsageErrorException(
|
||||||
f'{self._i18n.errors("tool_arguments_error")}'
|
f"{self._i18n.errors('tool_arguments_error')}"
|
||||||
)
|
)
|
||||||
|
|
||||||
return ToolCalling(
|
return ToolCalling(
|
||||||
@@ -404,7 +412,7 @@ class ToolUsage:
|
|||||||
if self.agent.verbose:
|
if self.agent.verbose:
|
||||||
self._printer.print(content=f"\n\n{e}\n", color="red")
|
self._printer.print(content=f"\n\n{e}\n", color="red")
|
||||||
return ToolUsageErrorException( # type: ignore # Incompatible return value type (got "ToolUsageErrorException", expected "ToolCalling | InstructorToolCalling")
|
return ToolUsageErrorException( # type: ignore # Incompatible return value type (got "ToolUsageErrorException", expected "ToolCalling | InstructorToolCalling")
|
||||||
f'{self._i18n.errors("tool_usage_error").format(error=e)}\nMoving on then. {self._i18n.slice("format").format(tool_names=self.tools_names)}'
|
f"{self._i18n.errors('tool_usage_error').format(error=e)}\nMoving on then. {self._i18n.slice('format').format(tool_names=self.tools_names)}"
|
||||||
)
|
)
|
||||||
return self._tool_calling(tool_string)
|
return self._tool_calling(tool_string)
|
||||||
|
|
||||||
|
|||||||
0
src/crewai/traces/__init__.py
Normal file
0
src/crewai/traces/__init__.py
Normal file
39
src/crewai/traces/context.py
Normal file
39
src/crewai/traces/context.py
Normal file
@@ -0,0 +1,39 @@
|
|||||||
|
from contextlib import contextmanager
|
||||||
|
from contextvars import ContextVar
|
||||||
|
from typing import Generator
|
||||||
|
|
||||||
|
|
||||||
|
class TraceContext:
|
||||||
|
"""Maintains the current trace context throughout the execution stack.
|
||||||
|
|
||||||
|
This class provides a context manager for tracking trace execution across
|
||||||
|
async and sync code paths using ContextVars.
|
||||||
|
"""
|
||||||
|
|
||||||
|
_context: ContextVar = ContextVar("trace_context", default=None)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_current(cls):
|
||||||
|
"""Get the current trace context.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Optional[UnifiedTraceController]: The current trace controller or None if not set.
|
||||||
|
"""
|
||||||
|
return cls._context.get()
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
@contextmanager
|
||||||
|
def set_current(cls, trace):
|
||||||
|
"""Set the current trace context within a context manager.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
trace: The trace controller to set as current.
|
||||||
|
|
||||||
|
Yields:
|
||||||
|
UnifiedTraceController: The current trace controller.
|
||||||
|
"""
|
||||||
|
token = cls._context.set(trace)
|
||||||
|
try:
|
||||||
|
yield trace
|
||||||
|
finally:
|
||||||
|
cls._context.reset(token)
|
||||||
19
src/crewai/traces/enums.py
Normal file
19
src/crewai/traces/enums.py
Normal file
@@ -0,0 +1,19 @@
|
|||||||
|
from enum import Enum
|
||||||
|
|
||||||
|
|
||||||
|
class TraceType(Enum):
|
||||||
|
LLM_CALL = "llm_call"
|
||||||
|
TOOL_CALL = "tool_call"
|
||||||
|
FLOW_STEP = "flow_step"
|
||||||
|
START_CALL = "start_call"
|
||||||
|
|
||||||
|
|
||||||
|
class RunType(Enum):
|
||||||
|
KICKOFF = "kickoff"
|
||||||
|
TRAIN = "train"
|
||||||
|
TEST = "test"
|
||||||
|
|
||||||
|
|
||||||
|
class CrewType(Enum):
|
||||||
|
CREW = "crew"
|
||||||
|
FLOW = "flow"
|
||||||
89
src/crewai/traces/models.py
Normal file
89
src/crewai/traces/models.py
Normal file
@@ -0,0 +1,89 @@
|
|||||||
|
from datetime import datetime
|
||||||
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
|
||||||
|
class ToolCall(BaseModel):
|
||||||
|
"""Model representing a tool call during execution"""
|
||||||
|
|
||||||
|
name: str
|
||||||
|
arguments: Dict[str, Any]
|
||||||
|
output: str
|
||||||
|
start_time: datetime
|
||||||
|
end_time: Optional[datetime] = None
|
||||||
|
latency_ms: Optional[int] = None
|
||||||
|
error: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
|
class LLMRequest(BaseModel):
|
||||||
|
"""Model representing the LLM request details"""
|
||||||
|
|
||||||
|
model: str
|
||||||
|
messages: List[Dict[str, str]]
|
||||||
|
temperature: Optional[float] = None
|
||||||
|
max_tokens: Optional[int] = None
|
||||||
|
stop_sequences: Optional[List[str]] = None
|
||||||
|
additional_params: Dict[str, Any] = Field(default_factory=dict)
|
||||||
|
|
||||||
|
|
||||||
|
class LLMResponse(BaseModel):
|
||||||
|
"""Model representing the LLM response details"""
|
||||||
|
|
||||||
|
content: str
|
||||||
|
finish_reason: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
|
class FlowStepIO(BaseModel):
|
||||||
|
"""Model representing flow step input/output details"""
|
||||||
|
|
||||||
|
function_name: str
|
||||||
|
inputs: Dict[str, Any] = Field(default_factory=dict)
|
||||||
|
outputs: Any
|
||||||
|
metadata: Dict[str, Any] = Field(default_factory=dict)
|
||||||
|
|
||||||
|
|
||||||
|
class CrewTrace(BaseModel):
|
||||||
|
"""Model for tracking detailed information about LLM interactions and Flow steps"""
|
||||||
|
|
||||||
|
deployment_instance_id: Optional[str] = Field(
|
||||||
|
description="ID of the deployment instance"
|
||||||
|
)
|
||||||
|
trace_id: str = Field(description="Unique identifier for this trace")
|
||||||
|
run_id: str = Field(description="Identifier for the execution run")
|
||||||
|
agent_role: Optional[str] = Field(description="Role of the agent")
|
||||||
|
task_id: Optional[str] = Field(description="ID of the current task being executed")
|
||||||
|
task_name: Optional[str] = Field(description="Name of the current task")
|
||||||
|
task_description: Optional[str] = Field(
|
||||||
|
description="Description of the current task"
|
||||||
|
)
|
||||||
|
trace_type: str = Field(description="Type of the trace")
|
||||||
|
crew_type: str = Field(description="Type of the crew")
|
||||||
|
run_type: str = Field(description="Type of the run")
|
||||||
|
|
||||||
|
# Timing information
|
||||||
|
start_time: Optional[datetime] = None
|
||||||
|
end_time: Optional[datetime] = None
|
||||||
|
latency_ms: Optional[int] = None
|
||||||
|
|
||||||
|
# Request/Response for LLM calls
|
||||||
|
request: Optional[LLMRequest] = None
|
||||||
|
response: Optional[LLMResponse] = None
|
||||||
|
|
||||||
|
# Input/Output for Flow steps
|
||||||
|
flow_step: Optional[FlowStepIO] = None
|
||||||
|
|
||||||
|
# Tool usage
|
||||||
|
tool_calls: List[ToolCall] = Field(default_factory=list)
|
||||||
|
|
||||||
|
# Metrics
|
||||||
|
tokens_used: Optional[int] = None
|
||||||
|
prompt_tokens: Optional[int] = None
|
||||||
|
completion_tokens: Optional[int] = None
|
||||||
|
cost: Optional[float] = None
|
||||||
|
|
||||||
|
# Additional metadata
|
||||||
|
status: str = "running" # running, completed, error
|
||||||
|
error: Optional[str] = None
|
||||||
|
metadata: Dict[str, Any] = Field(default_factory=dict)
|
||||||
|
tags: List[str] = Field(default_factory=list)
|
||||||
543
src/crewai/traces/unified_trace_controller.py
Normal file
543
src/crewai/traces/unified_trace_controller.py
Normal file
@@ -0,0 +1,543 @@
|
|||||||
|
import inspect
|
||||||
|
import os
|
||||||
|
from datetime import UTC, datetime
|
||||||
|
from functools import wraps
|
||||||
|
from typing import Any, Awaitable, Callable, Dict, List, Optional
|
||||||
|
from uuid import uuid4
|
||||||
|
|
||||||
|
from crewai.traces.context import TraceContext
|
||||||
|
from crewai.traces.enums import CrewType, RunType, TraceType
|
||||||
|
from crewai.traces.models import (
|
||||||
|
CrewTrace,
|
||||||
|
FlowStepIO,
|
||||||
|
LLMRequest,
|
||||||
|
LLMResponse,
|
||||||
|
ToolCall,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class UnifiedTraceController:
|
||||||
|
"""Controls and manages trace execution and recording.
|
||||||
|
|
||||||
|
This class handles the lifecycle of traces including creation, execution tracking,
|
||||||
|
and recording of results for various types of operations (LLM calls, tool calls, flow steps).
|
||||||
|
"""
|
||||||
|
|
||||||
|
_task_traces: Dict[str, List["UnifiedTraceController"]] = {}
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
trace_type: TraceType,
|
||||||
|
run_type: RunType,
|
||||||
|
crew_type: CrewType,
|
||||||
|
run_id: str,
|
||||||
|
deployment_instance_id: str = os.environ.get(
|
||||||
|
"CREWAI_DEPLOYMENT_INSTANCE_ID", ""
|
||||||
|
),
|
||||||
|
parent_trace_id: Optional[str] = None,
|
||||||
|
agent_role: Optional[str] = "unknown",
|
||||||
|
task_name: Optional[str] = None,
|
||||||
|
task_description: Optional[str] = None,
|
||||||
|
task_id: Optional[str] = None,
|
||||||
|
flow_step: Dict[str, Any] = {},
|
||||||
|
tool_calls: List[ToolCall] = [],
|
||||||
|
**context: Any,
|
||||||
|
) -> None:
|
||||||
|
"""Initialize a new trace controller.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
trace_type: Type of trace being recorded.
|
||||||
|
run_type: Type of run being executed.
|
||||||
|
crew_type: Type of crew executing the trace.
|
||||||
|
run_id: Unique identifier for the run.
|
||||||
|
deployment_instance_id: Optional deployment instance identifier.
|
||||||
|
parent_trace_id: Optional parent trace identifier for nested traces.
|
||||||
|
agent_role: Role of the agent executing the trace.
|
||||||
|
task_name: Optional name of the task being executed.
|
||||||
|
task_description: Optional description of the task.
|
||||||
|
task_id: Optional unique identifier for the task.
|
||||||
|
flow_step: Optional flow step information.
|
||||||
|
tool_calls: Optional list of tool calls made during execution.
|
||||||
|
**context: Additional context parameters.
|
||||||
|
"""
|
||||||
|
self.trace_id = str(uuid4())
|
||||||
|
self.run_id = run_id
|
||||||
|
self.parent_trace_id = parent_trace_id
|
||||||
|
self.trace_type = trace_type
|
||||||
|
self.run_type = run_type
|
||||||
|
self.crew_type = crew_type
|
||||||
|
self.context = context
|
||||||
|
self.agent_role = agent_role
|
||||||
|
self.task_name = task_name
|
||||||
|
self.task_description = task_description
|
||||||
|
self.task_id = task_id
|
||||||
|
self.deployment_instance_id = deployment_instance_id
|
||||||
|
self.children: List[Dict[str, Any]] = []
|
||||||
|
self.start_time: Optional[datetime] = None
|
||||||
|
self.end_time: Optional[datetime] = None
|
||||||
|
self.error: Optional[str] = None
|
||||||
|
self.tool_calls = tool_calls
|
||||||
|
self.flow_step = flow_step
|
||||||
|
self.status: str = "running"
|
||||||
|
|
||||||
|
# Add trace to task's trace collection if task_id is present
|
||||||
|
if task_id:
|
||||||
|
self._add_to_task_traces()
|
||||||
|
|
||||||
|
def _add_to_task_traces(self) -> None:
|
||||||
|
"""Add this trace to the task's trace collection."""
|
||||||
|
if not hasattr(UnifiedTraceController, "_task_traces"):
|
||||||
|
UnifiedTraceController._task_traces = {}
|
||||||
|
|
||||||
|
if self.task_id is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
if self.task_id not in UnifiedTraceController._task_traces:
|
||||||
|
UnifiedTraceController._task_traces[self.task_id] = []
|
||||||
|
|
||||||
|
UnifiedTraceController._task_traces[self.task_id].append(self)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_task_traces(cls, task_id: str) -> List["UnifiedTraceController"]:
|
||||||
|
"""Get all traces for a specific task.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
task_id: The ID of the task to get traces for
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of traces associated with the task
|
||||||
|
"""
|
||||||
|
return cls._task_traces.get(task_id, [])
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def clear_task_traces(cls, task_id: str) -> None:
|
||||||
|
"""Clear traces for a specific task.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
task_id: The ID of the task to clear traces for
|
||||||
|
"""
|
||||||
|
if hasattr(cls, "_task_traces") and task_id in cls._task_traces:
|
||||||
|
del cls._task_traces[task_id]
|
||||||
|
|
||||||
|
def _get_current_trace(self) -> "UnifiedTraceController":
|
||||||
|
return TraceContext.get_current()
|
||||||
|
|
||||||
|
def start_trace(self) -> "UnifiedTraceController":
|
||||||
|
"""Start the trace execution.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
UnifiedTraceController: Self for method chaining.
|
||||||
|
"""
|
||||||
|
self.start_time = datetime.now(UTC)
|
||||||
|
return self
|
||||||
|
|
||||||
|
def end_trace(self, result: Any = None, error: Optional[str] = None) -> None:
|
||||||
|
"""End the trace execution and record results.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
result: Optional result from the trace execution.
|
||||||
|
error: Optional error message if the trace failed.
|
||||||
|
"""
|
||||||
|
self.end_time = datetime.now(UTC)
|
||||||
|
self.status = "error" if error else "completed"
|
||||||
|
self.error = error
|
||||||
|
self._record_trace(result)
|
||||||
|
|
||||||
|
def add_child_trace(self, child_trace: Dict[str, Any]) -> None:
|
||||||
|
"""Add a child trace to this trace's execution history.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
child_trace: The child trace information to add.
|
||||||
|
"""
|
||||||
|
self.children.append(child_trace)
|
||||||
|
|
||||||
|
def to_crew_trace(self) -> CrewTrace:
|
||||||
|
"""Convert to CrewTrace format for storage.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
CrewTrace: The trace data in CrewTrace format.
|
||||||
|
"""
|
||||||
|
latency_ms = None
|
||||||
|
|
||||||
|
if self.tool_calls and hasattr(self.tool_calls[0], "start_time"):
|
||||||
|
self.start_time = self.tool_calls[0].start_time
|
||||||
|
|
||||||
|
if self.start_time and self.end_time:
|
||||||
|
latency_ms = int((self.end_time - self.start_time).total_seconds() * 1000)
|
||||||
|
|
||||||
|
request = None
|
||||||
|
response = None
|
||||||
|
flow_step_obj = None
|
||||||
|
|
||||||
|
if self.trace_type in [TraceType.LLM_CALL, TraceType.TOOL_CALL]:
|
||||||
|
request = LLMRequest(
|
||||||
|
model=self.context.get("model", "unknown"),
|
||||||
|
messages=self.context.get("messages", []),
|
||||||
|
temperature=self.context.get("temperature"),
|
||||||
|
max_tokens=self.context.get("max_tokens"),
|
||||||
|
stop_sequences=self.context.get("stop_sequences"),
|
||||||
|
)
|
||||||
|
if "response" in self.context:
|
||||||
|
response = LLMResponse(
|
||||||
|
content=self.context["response"].get("content", ""),
|
||||||
|
finish_reason=self.context["response"].get("finish_reason"),
|
||||||
|
)
|
||||||
|
|
||||||
|
elif self.trace_type == TraceType.FLOW_STEP:
|
||||||
|
flow_step_obj = FlowStepIO(
|
||||||
|
function_name=self.flow_step.get("function_name", "unknown"),
|
||||||
|
inputs=self.flow_step.get("inputs", {}),
|
||||||
|
outputs={"result": self.context.get("response")},
|
||||||
|
metadata=self.flow_step.get("metadata", {}),
|
||||||
|
)
|
||||||
|
|
||||||
|
return CrewTrace(
|
||||||
|
deployment_instance_id=self.deployment_instance_id,
|
||||||
|
trace_id=self.trace_id,
|
||||||
|
task_id=self.task_id,
|
||||||
|
run_id=self.run_id,
|
||||||
|
agent_role=self.agent_role,
|
||||||
|
task_name=self.task_name,
|
||||||
|
task_description=self.task_description,
|
||||||
|
trace_type=self.trace_type.value,
|
||||||
|
crew_type=self.crew_type.value,
|
||||||
|
run_type=self.run_type.value,
|
||||||
|
start_time=self.start_time,
|
||||||
|
end_time=self.end_time,
|
||||||
|
latency_ms=latency_ms,
|
||||||
|
request=request,
|
||||||
|
response=response,
|
||||||
|
flow_step=flow_step_obj,
|
||||||
|
tool_calls=self.tool_calls,
|
||||||
|
tokens_used=self.context.get("tokens_used"),
|
||||||
|
prompt_tokens=self.context.get("prompt_tokens"),
|
||||||
|
completion_tokens=self.context.get("completion_tokens"),
|
||||||
|
status=self.status,
|
||||||
|
error=self.error,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _record_trace(self, result: Any = None) -> None:
|
||||||
|
"""Record the trace.
|
||||||
|
|
||||||
|
This method is called when a trace is completed. It ensures the trace
|
||||||
|
is properly recorded and associated with its task if applicable.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
result: Optional result to include in the trace
|
||||||
|
"""
|
||||||
|
if result:
|
||||||
|
self.context["response"] = result
|
||||||
|
|
||||||
|
# Add to task traces if this trace belongs to a task
|
||||||
|
if self.task_id:
|
||||||
|
self._add_to_task_traces()
|
||||||
|
|
||||||
|
|
||||||
|
def should_trace() -> bool:
|
||||||
|
"""Check if tracing is enabled via environment variable."""
|
||||||
|
return os.getenv("CREWAI_ENABLE_TRACING", "false").lower() == "true"
|
||||||
|
|
||||||
|
|
||||||
|
# Crew main trace
|
||||||
|
def init_crew_main_trace(func: Callable[..., Any]) -> Callable[..., Any]:
|
||||||
|
"""Decorator to initialize and track the main crew execution trace.
|
||||||
|
|
||||||
|
This decorator sets up the trace context for the main crew execution,
|
||||||
|
handling both synchronous and asynchronous crew operations.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
func: The crew function to be traced.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Wrapped function that creates and manages the main crew trace context.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@wraps(func)
|
||||||
|
def wrapper(self: Any, *args: Any, **kwargs: Any) -> Any:
|
||||||
|
if not should_trace():
|
||||||
|
return func(self, *args, **kwargs)
|
||||||
|
|
||||||
|
trace = build_crew_main_trace(self)
|
||||||
|
with TraceContext.set_current(trace):
|
||||||
|
try:
|
||||||
|
return func(self, *args, **kwargs)
|
||||||
|
except Exception as e:
|
||||||
|
trace.end_trace(error=str(e))
|
||||||
|
raise
|
||||||
|
|
||||||
|
return wrapper
|
||||||
|
|
||||||
|
|
||||||
|
def build_crew_main_trace(self: Any) -> "UnifiedTraceController":
|
||||||
|
"""Build the main trace controller for a crew execution.
|
||||||
|
|
||||||
|
This function creates a trace controller configured for the main crew execution,
|
||||||
|
handling different run types (kickoff, test, train) and maintaining context.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
self: The crew instance.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
UnifiedTraceController: The configured trace controller for the crew.
|
||||||
|
"""
|
||||||
|
run_type = RunType.KICKOFF
|
||||||
|
if hasattr(self, "_test") and self._test:
|
||||||
|
run_type = RunType.TEST
|
||||||
|
elif hasattr(self, "_train") and self._train:
|
||||||
|
run_type = RunType.TRAIN
|
||||||
|
|
||||||
|
current_trace = TraceContext.get_current()
|
||||||
|
|
||||||
|
trace = UnifiedTraceController(
|
||||||
|
trace_type=TraceType.LLM_CALL,
|
||||||
|
run_type=run_type,
|
||||||
|
crew_type=current_trace.crew_type if current_trace else CrewType.CREW,
|
||||||
|
run_id=current_trace.run_id if current_trace else str(self.id),
|
||||||
|
parent_trace_id=current_trace.trace_id if current_trace else None,
|
||||||
|
)
|
||||||
|
return trace
|
||||||
|
|
||||||
|
|
||||||
|
# Flow main trace
|
||||||
|
def init_flow_main_trace(
|
||||||
|
func: Callable[..., Awaitable[Any]],
|
||||||
|
) -> Callable[..., Awaitable[Any]]:
|
||||||
|
"""Decorator to initialize and track the main flow execution trace.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
func: The async flow function to be traced.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Wrapped async function that creates and manages the main flow trace context.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@wraps(func)
|
||||||
|
async def wrapper(self: Any, *args: Any, **kwargs: Any) -> Any:
|
||||||
|
if not should_trace():
|
||||||
|
return await func(self, *args, **kwargs)
|
||||||
|
|
||||||
|
trace = build_flow_main_trace(self, *args, **kwargs)
|
||||||
|
with TraceContext.set_current(trace):
|
||||||
|
try:
|
||||||
|
return await func(self, *args, **kwargs)
|
||||||
|
except Exception:
|
||||||
|
raise
|
||||||
|
|
||||||
|
return wrapper
|
||||||
|
|
||||||
|
|
||||||
|
def build_flow_main_trace(
|
||||||
|
self: Any, *args: Any, **kwargs: Any
|
||||||
|
) -> "UnifiedTraceController":
|
||||||
|
"""Build the main trace controller for a flow execution.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
self: The flow instance.
|
||||||
|
*args: Variable positional arguments.
|
||||||
|
**kwargs: Variable keyword arguments.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
UnifiedTraceController: The configured trace controller for the flow.
|
||||||
|
"""
|
||||||
|
current_trace = TraceContext.get_current()
|
||||||
|
trace = UnifiedTraceController(
|
||||||
|
trace_type=TraceType.FLOW_STEP,
|
||||||
|
run_id=current_trace.run_id if current_trace else str(self.flow_id),
|
||||||
|
parent_trace_id=current_trace.trace_id if current_trace else None,
|
||||||
|
crew_type=CrewType.FLOW,
|
||||||
|
run_type=RunType.KICKOFF,
|
||||||
|
context={
|
||||||
|
"crew_name": self.__class__.__name__,
|
||||||
|
"inputs": kwargs.get("inputs", {}),
|
||||||
|
"agents": [],
|
||||||
|
"tasks": [],
|
||||||
|
},
|
||||||
|
)
|
||||||
|
return trace
|
||||||
|
|
||||||
|
|
||||||
|
# Flow step trace
|
||||||
|
def trace_flow_step(
|
||||||
|
func: Callable[..., Awaitable[Any]],
|
||||||
|
) -> Callable[..., Awaitable[Any]]:
|
||||||
|
"""Decorator to trace individual flow step executions.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
func: The async flow step function to be traced.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Wrapped async function that creates and manages the flow step trace context.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@wraps(func)
|
||||||
|
async def wrapper(
|
||||||
|
self: Any,
|
||||||
|
method_name: str,
|
||||||
|
method: Callable[..., Any],
|
||||||
|
*args: Any,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> Any:
|
||||||
|
if not should_trace():
|
||||||
|
return await func(self, method_name, method, *args, **kwargs)
|
||||||
|
|
||||||
|
trace = build_flow_step_trace(self, method_name, method, *args, **kwargs)
|
||||||
|
with TraceContext.set_current(trace):
|
||||||
|
trace.start_trace()
|
||||||
|
try:
|
||||||
|
result = await func(self, method_name, method, *args, **kwargs)
|
||||||
|
trace.end_trace(result=result)
|
||||||
|
return result
|
||||||
|
except Exception as e:
|
||||||
|
trace.end_trace(error=str(e))
|
||||||
|
raise
|
||||||
|
|
||||||
|
return wrapper
|
||||||
|
|
||||||
|
|
||||||
|
def build_flow_step_trace(
|
||||||
|
self: Any, method_name: str, method: Callable[..., Any], *args: Any, **kwargs: Any
|
||||||
|
) -> "UnifiedTraceController":
|
||||||
|
"""Build a trace controller for an individual flow step.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
self: The flow instance.
|
||||||
|
method_name: Name of the method being executed.
|
||||||
|
method: The actual method being executed.
|
||||||
|
*args: Variable positional arguments.
|
||||||
|
**kwargs: Variable keyword arguments.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
UnifiedTraceController: The configured trace controller for the flow step.
|
||||||
|
"""
|
||||||
|
current_trace = TraceContext.get_current()
|
||||||
|
|
||||||
|
# Get method signature
|
||||||
|
sig = inspect.signature(method)
|
||||||
|
params = list(sig.parameters.values())
|
||||||
|
|
||||||
|
# Create inputs dictionary mapping parameter names to values
|
||||||
|
method_params = [p for p in params if p.name != "self"]
|
||||||
|
inputs: Dict[str, Any] = {}
|
||||||
|
|
||||||
|
# Map positional args to their parameter names
|
||||||
|
for i, param in enumerate(method_params):
|
||||||
|
if i < len(args):
|
||||||
|
inputs[param.name] = args[i]
|
||||||
|
|
||||||
|
# Add keyword arguments
|
||||||
|
inputs.update(kwargs)
|
||||||
|
|
||||||
|
trace = UnifiedTraceController(
|
||||||
|
trace_type=TraceType.FLOW_STEP,
|
||||||
|
run_type=current_trace.run_type if current_trace else RunType.KICKOFF,
|
||||||
|
crew_type=current_trace.crew_type if current_trace else CrewType.FLOW,
|
||||||
|
run_id=current_trace.run_id if current_trace else str(self.flow_id),
|
||||||
|
parent_trace_id=current_trace.trace_id if current_trace else None,
|
||||||
|
flow_step={
|
||||||
|
"function_name": method_name,
|
||||||
|
"inputs": inputs,
|
||||||
|
"metadata": {
|
||||||
|
"crew_name": self.__class__.__name__,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
return trace
|
||||||
|
|
||||||
|
|
||||||
|
# LLM trace
|
||||||
|
def trace_llm_call(func: Callable[..., Any]) -> Callable[..., Any]:
|
||||||
|
"""Decorator to trace LLM calls.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
func: The function to trace.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Wrapped function that creates and manages the LLM call trace context.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@wraps(func)
|
||||||
|
def wrapper(self: Any, *args: Any, **kwargs: Any) -> Any:
|
||||||
|
if not should_trace():
|
||||||
|
return func(self, *args, **kwargs)
|
||||||
|
|
||||||
|
trace = build_llm_trace(self, *args, **kwargs)
|
||||||
|
with TraceContext.set_current(trace):
|
||||||
|
trace.start_trace()
|
||||||
|
try:
|
||||||
|
response = func(self, *args, **kwargs)
|
||||||
|
# Extract relevant data from response
|
||||||
|
trace_response = {
|
||||||
|
"content": response["choices"][0]["message"]["content"],
|
||||||
|
"finish_reason": response["choices"][0].get("finish_reason"),
|
||||||
|
}
|
||||||
|
|
||||||
|
# Add usage metrics to context
|
||||||
|
if "usage" in response:
|
||||||
|
trace.context["tokens_used"] = response["usage"].get(
|
||||||
|
"total_tokens", 0
|
||||||
|
)
|
||||||
|
trace.context["prompt_tokens"] = response["usage"].get(
|
||||||
|
"prompt_tokens", 0
|
||||||
|
)
|
||||||
|
trace.context["completion_tokens"] = response["usage"].get(
|
||||||
|
"completion_tokens", 0
|
||||||
|
)
|
||||||
|
|
||||||
|
trace.end_trace(trace_response)
|
||||||
|
return response
|
||||||
|
except Exception as e:
|
||||||
|
trace.end_trace(error=str(e))
|
||||||
|
raise
|
||||||
|
|
||||||
|
return wrapper
|
||||||
|
|
||||||
|
|
||||||
|
def build_llm_trace(
|
||||||
|
self: Any, params: Dict[str, Any], *args: Any, **kwargs: Any
|
||||||
|
) -> Any:
|
||||||
|
"""Build a trace controller for an LLM call.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
self: The LLM instance.
|
||||||
|
params: The parameters for the LLM call.
|
||||||
|
*args: Variable positional arguments.
|
||||||
|
**kwargs: Variable keyword arguments.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
UnifiedTraceController: The configured trace controller for the LLM call.
|
||||||
|
"""
|
||||||
|
current_trace = TraceContext.get_current()
|
||||||
|
agent, task = self._get_execution_context()
|
||||||
|
|
||||||
|
# Get new messages and tool results
|
||||||
|
new_messages = self._get_new_messages(params.get("messages", []))
|
||||||
|
new_tool_results = self._get_new_tool_results(agent)
|
||||||
|
|
||||||
|
# Create trace context
|
||||||
|
trace = UnifiedTraceController(
|
||||||
|
trace_type=TraceType.TOOL_CALL if new_tool_results else TraceType.LLM_CALL,
|
||||||
|
crew_type=current_trace.crew_type if current_trace else CrewType.CREW,
|
||||||
|
run_type=current_trace.run_type if current_trace else RunType.KICKOFF,
|
||||||
|
run_id=current_trace.run_id if current_trace else str(uuid4()),
|
||||||
|
parent_trace_id=current_trace.trace_id if current_trace else None,
|
||||||
|
agent_role=agent.role if agent else "unknown",
|
||||||
|
task_id=str(task.id) if task else None,
|
||||||
|
task_name=task.name if task else None,
|
||||||
|
task_description=task.description if task else None,
|
||||||
|
model=self.model,
|
||||||
|
messages=new_messages,
|
||||||
|
temperature=self.temperature,
|
||||||
|
max_tokens=self.max_tokens,
|
||||||
|
stop_sequences=self.stop,
|
||||||
|
tool_calls=[
|
||||||
|
ToolCall(
|
||||||
|
name=result["tool_name"],
|
||||||
|
arguments=result["tool_args"],
|
||||||
|
output=str(result["result"]),
|
||||||
|
start_time=result.get("start_time", ""),
|
||||||
|
end_time=datetime.now(UTC),
|
||||||
|
)
|
||||||
|
for result in new_tool_results
|
||||||
|
],
|
||||||
|
)
|
||||||
|
return trace
|
||||||
12
src/crewai/utilities/protocols.py
Normal file
12
src/crewai/utilities/protocols.py
Normal file
@@ -0,0 +1,12 @@
|
|||||||
|
from typing import Any, Protocol, runtime_checkable
|
||||||
|
|
||||||
|
|
||||||
|
@runtime_checkable
|
||||||
|
class AgentExecutorProtocol(Protocol):
|
||||||
|
"""Protocol defining the expected interface for an agent executor."""
|
||||||
|
|
||||||
|
@property
|
||||||
|
def agent(self) -> Any: ...
|
||||||
|
|
||||||
|
@property
|
||||||
|
def task(self) -> Any: ...
|
||||||
@@ -1,6 +1,7 @@
|
|||||||
"""Test Agent creation and execution basic functionality."""
|
"""Test Agent creation and execution basic functionality."""
|
||||||
|
|
||||||
import os
|
import os
|
||||||
|
from datetime import UTC, datetime, timezone
|
||||||
from unittest import mock
|
from unittest import mock
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
|
|
||||||
@@ -908,6 +909,8 @@ def test_tool_result_as_answer_is_the_final_answer_for_the_agent():
|
|||||||
|
|
||||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||||
def test_tool_usage_information_is_appended_to_agent():
|
def test_tool_usage_information_is_appended_to_agent():
|
||||||
|
from datetime import UTC, datetime
|
||||||
|
|
||||||
from crewai.tools import BaseTool
|
from crewai.tools import BaseTool
|
||||||
|
|
||||||
class MyCustomTool(BaseTool):
|
class MyCustomTool(BaseTool):
|
||||||
@@ -917,30 +920,36 @@ def test_tool_usage_information_is_appended_to_agent():
|
|||||||
def _run(self) -> str:
|
def _run(self) -> str:
|
||||||
return "Howdy!"
|
return "Howdy!"
|
||||||
|
|
||||||
agent1 = Agent(
|
fixed_datetime = datetime(2025, 2, 10, 12, 0, 0, tzinfo=UTC)
|
||||||
role="Friendly Neighbor",
|
with patch("datetime.datetime") as mock_datetime:
|
||||||
goal="Make everyone feel welcome",
|
mock_datetime.now.return_value = fixed_datetime
|
||||||
backstory="You are the friendly neighbor",
|
mock_datetime.side_effect = lambda *args, **kw: datetime(*args, **kw)
|
||||||
tools=[MyCustomTool(result_as_answer=True)],
|
|
||||||
)
|
|
||||||
|
|
||||||
greeting = Task(
|
agent1 = Agent(
|
||||||
description="Say an appropriate greeting.",
|
role="Friendly Neighbor",
|
||||||
expected_output="The greeting.",
|
goal="Make everyone feel welcome",
|
||||||
agent=agent1,
|
backstory="You are the friendly neighbor",
|
||||||
)
|
tools=[MyCustomTool(result_as_answer=True)],
|
||||||
tasks = [greeting]
|
)
|
||||||
crew = Crew(agents=[agent1], tasks=tasks)
|
|
||||||
|
|
||||||
crew.kickoff()
|
greeting = Task(
|
||||||
assert agent1.tools_results == [
|
description="Say an appropriate greeting.",
|
||||||
{
|
expected_output="The greeting.",
|
||||||
"result": "Howdy!",
|
agent=agent1,
|
||||||
"tool_name": "Decide Greetings",
|
)
|
||||||
"tool_args": {},
|
tasks = [greeting]
|
||||||
"result_as_answer": True,
|
crew = Crew(agents=[agent1], tasks=tasks)
|
||||||
}
|
|
||||||
]
|
crew.kickoff()
|
||||||
|
assert agent1.tools_results == [
|
||||||
|
{
|
||||||
|
"result": "Howdy!",
|
||||||
|
"tool_name": "Decide Greetings",
|
||||||
|
"tool_args": {},
|
||||||
|
"result_as_answer": True,
|
||||||
|
"start_time": fixed_datetime,
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
def test_agent_definition_based_on_dict():
|
def test_agent_definition_based_on_dict():
|
||||||
|
|||||||
360
tests/traces/test_unified_trace_controller.py
Normal file
360
tests/traces/test_unified_trace_controller.py
Normal file
@@ -0,0 +1,360 @@
|
|||||||
|
import os
|
||||||
|
from datetime import UTC, datetime
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from crewai.traces.context import TraceContext
|
||||||
|
from crewai.traces.enums import CrewType, RunType, TraceType
|
||||||
|
from crewai.traces.models import (
|
||||||
|
CrewTrace,
|
||||||
|
FlowStepIO,
|
||||||
|
LLMRequest,
|
||||||
|
LLMResponse,
|
||||||
|
)
|
||||||
|
from crewai.traces.unified_trace_controller import (
|
||||||
|
UnifiedTraceController,
|
||||||
|
init_crew_main_trace,
|
||||||
|
init_flow_main_trace,
|
||||||
|
should_trace,
|
||||||
|
trace_flow_step,
|
||||||
|
trace_llm_call,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestUnifiedTraceController:
|
||||||
|
@pytest.fixture
|
||||||
|
def basic_trace_controller(self):
|
||||||
|
return UnifiedTraceController(
|
||||||
|
trace_type=TraceType.LLM_CALL,
|
||||||
|
run_type=RunType.KICKOFF,
|
||||||
|
crew_type=CrewType.CREW,
|
||||||
|
run_id="test-run-id",
|
||||||
|
agent_role="test-agent",
|
||||||
|
task_name="test-task",
|
||||||
|
task_description="test description",
|
||||||
|
task_id="test-task-id",
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_initialization(self, basic_trace_controller):
|
||||||
|
"""Test basic initialization of UnifiedTraceController"""
|
||||||
|
assert basic_trace_controller.trace_type == TraceType.LLM_CALL
|
||||||
|
assert basic_trace_controller.run_type == RunType.KICKOFF
|
||||||
|
assert basic_trace_controller.crew_type == CrewType.CREW
|
||||||
|
assert basic_trace_controller.run_id == "test-run-id"
|
||||||
|
assert basic_trace_controller.agent_role == "test-agent"
|
||||||
|
assert basic_trace_controller.task_name == "test-task"
|
||||||
|
assert basic_trace_controller.task_description == "test description"
|
||||||
|
assert basic_trace_controller.task_id == "test-task-id"
|
||||||
|
assert basic_trace_controller.status == "running"
|
||||||
|
assert isinstance(UUID(basic_trace_controller.trace_id), UUID)
|
||||||
|
|
||||||
|
def test_start_trace(self, basic_trace_controller):
|
||||||
|
"""Test starting a trace"""
|
||||||
|
result = basic_trace_controller.start_trace()
|
||||||
|
assert result == basic_trace_controller
|
||||||
|
assert basic_trace_controller.start_time is not None
|
||||||
|
assert isinstance(basic_trace_controller.start_time, datetime)
|
||||||
|
|
||||||
|
def test_end_trace_success(self, basic_trace_controller):
|
||||||
|
"""Test ending a trace successfully"""
|
||||||
|
basic_trace_controller.start_trace()
|
||||||
|
basic_trace_controller.end_trace(result={"test": "result"})
|
||||||
|
|
||||||
|
assert basic_trace_controller.end_time is not None
|
||||||
|
assert basic_trace_controller.status == "completed"
|
||||||
|
assert basic_trace_controller.error is None
|
||||||
|
assert basic_trace_controller.context.get("response") == {"test": "result"}
|
||||||
|
|
||||||
|
def test_end_trace_with_error(self, basic_trace_controller):
|
||||||
|
"""Test ending a trace with an error"""
|
||||||
|
basic_trace_controller.start_trace()
|
||||||
|
basic_trace_controller.end_trace(error="Test error occurred")
|
||||||
|
|
||||||
|
assert basic_trace_controller.end_time is not None
|
||||||
|
assert basic_trace_controller.status == "error"
|
||||||
|
assert basic_trace_controller.error == "Test error occurred"
|
||||||
|
|
||||||
|
def test_add_child_trace(self, basic_trace_controller):
|
||||||
|
"""Test adding a child trace"""
|
||||||
|
child_trace = {"id": "child-1", "type": "test"}
|
||||||
|
basic_trace_controller.add_child_trace(child_trace)
|
||||||
|
assert len(basic_trace_controller.children) == 1
|
||||||
|
assert basic_trace_controller.children[0] == child_trace
|
||||||
|
|
||||||
|
def test_to_crew_trace_llm_call(self):
|
||||||
|
"""Test converting to CrewTrace for LLM call"""
|
||||||
|
test_messages = [{"role": "user", "content": "test"}]
|
||||||
|
test_response = {
|
||||||
|
"content": "test response",
|
||||||
|
"finish_reason": "stop",
|
||||||
|
}
|
||||||
|
|
||||||
|
controller = UnifiedTraceController(
|
||||||
|
trace_type=TraceType.LLM_CALL,
|
||||||
|
run_type=RunType.KICKOFF,
|
||||||
|
crew_type=CrewType.CREW,
|
||||||
|
run_id="test-run-id",
|
||||||
|
context={
|
||||||
|
"messages": test_messages,
|
||||||
|
"temperature": 0.7,
|
||||||
|
"max_tokens": 100,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Set model and messages in the context
|
||||||
|
controller.context["model"] = "gpt-4"
|
||||||
|
controller.context["messages"] = test_messages
|
||||||
|
|
||||||
|
controller.start_trace()
|
||||||
|
controller.end_trace(result=test_response)
|
||||||
|
|
||||||
|
crew_trace = controller.to_crew_trace()
|
||||||
|
assert isinstance(crew_trace, CrewTrace)
|
||||||
|
assert isinstance(crew_trace.request, LLMRequest)
|
||||||
|
assert isinstance(crew_trace.response, LLMResponse)
|
||||||
|
assert crew_trace.request.model == "gpt-4"
|
||||||
|
assert crew_trace.request.messages == test_messages
|
||||||
|
assert crew_trace.response.content == test_response["content"]
|
||||||
|
assert crew_trace.response.finish_reason == test_response["finish_reason"]
|
||||||
|
|
||||||
|
def test_to_crew_trace_flow_step(self):
|
||||||
|
"""Test converting to CrewTrace for flow step"""
|
||||||
|
flow_step_data = {
|
||||||
|
"function_name": "test_function",
|
||||||
|
"inputs": {"param1": "value1"},
|
||||||
|
"metadata": {"meta": "data"},
|
||||||
|
}
|
||||||
|
|
||||||
|
controller = UnifiedTraceController(
|
||||||
|
trace_type=TraceType.FLOW_STEP,
|
||||||
|
run_type=RunType.KICKOFF,
|
||||||
|
crew_type=CrewType.FLOW,
|
||||||
|
run_id="test-run-id",
|
||||||
|
flow_step=flow_step_data,
|
||||||
|
)
|
||||||
|
|
||||||
|
controller.start_trace()
|
||||||
|
controller.end_trace(result="test result")
|
||||||
|
|
||||||
|
crew_trace = controller.to_crew_trace()
|
||||||
|
assert isinstance(crew_trace, CrewTrace)
|
||||||
|
assert isinstance(crew_trace.flow_step, FlowStepIO)
|
||||||
|
assert crew_trace.flow_step.function_name == "test_function"
|
||||||
|
assert crew_trace.flow_step.inputs == {"param1": "value1"}
|
||||||
|
assert crew_trace.flow_step.outputs == {"result": "test result"}
|
||||||
|
|
||||||
|
def test_should_trace(self):
|
||||||
|
"""Test should_trace function"""
|
||||||
|
with patch.dict(os.environ, {"CREWAI_ENABLE_TRACING": "true"}):
|
||||||
|
assert should_trace() is True
|
||||||
|
|
||||||
|
with patch.dict(os.environ, {"CREWAI_ENABLE_TRACING": "false"}):
|
||||||
|
assert should_trace() is False
|
||||||
|
|
||||||
|
with patch.dict(os.environ, clear=True):
|
||||||
|
assert should_trace() is False
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_trace_flow_step_decorator(self):
|
||||||
|
"""Test trace_flow_step decorator"""
|
||||||
|
|
||||||
|
class TestFlow:
|
||||||
|
flow_id = "test-flow-id"
|
||||||
|
|
||||||
|
@trace_flow_step
|
||||||
|
async def test_method(self, method_name, method, *args, **kwargs):
|
||||||
|
return "test result"
|
||||||
|
|
||||||
|
with patch.dict(os.environ, {"CREWAI_ENABLE_TRACING": "true"}):
|
||||||
|
flow = TestFlow()
|
||||||
|
result = await flow.test_method("test_method", lambda x: x, arg1="value1")
|
||||||
|
assert result == "test result"
|
||||||
|
|
||||||
|
def test_trace_llm_call_decorator(self):
|
||||||
|
"""Test trace_llm_call decorator"""
|
||||||
|
|
||||||
|
class TestLLM:
|
||||||
|
model = "gpt-4"
|
||||||
|
temperature = 0.7
|
||||||
|
max_tokens = 100
|
||||||
|
stop = None
|
||||||
|
|
||||||
|
def _get_execution_context(self):
|
||||||
|
return MagicMock(), MagicMock()
|
||||||
|
|
||||||
|
def _get_new_messages(self, messages):
|
||||||
|
return messages
|
||||||
|
|
||||||
|
def _get_new_tool_results(self, agent):
|
||||||
|
return []
|
||||||
|
|
||||||
|
@trace_llm_call
|
||||||
|
def test_method(self, params):
|
||||||
|
return {
|
||||||
|
"choices": [
|
||||||
|
{
|
||||||
|
"message": {"content": "test response"},
|
||||||
|
"finish_reason": "stop",
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"usage": {
|
||||||
|
"total_tokens": 50,
|
||||||
|
"prompt_tokens": 20,
|
||||||
|
"completion_tokens": 30,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
with patch.dict(os.environ, {"CREWAI_ENABLE_TRACING": "true"}):
|
||||||
|
llm = TestLLM()
|
||||||
|
result = llm.test_method({"messages": []})
|
||||||
|
assert result["choices"][0]["message"]["content"] == "test response"
|
||||||
|
|
||||||
|
def test_init_crew_main_trace_kickoff(self):
|
||||||
|
"""Test init_crew_main_trace in kickoff mode"""
|
||||||
|
trace_context = None
|
||||||
|
|
||||||
|
class TestCrew:
|
||||||
|
id = "test-crew-id"
|
||||||
|
_test = False
|
||||||
|
_train = False
|
||||||
|
|
||||||
|
@init_crew_main_trace
|
||||||
|
def test_method(self):
|
||||||
|
nonlocal trace_context
|
||||||
|
trace_context = TraceContext.get_current()
|
||||||
|
return "test result"
|
||||||
|
|
||||||
|
with patch.dict(os.environ, {"CREWAI_ENABLE_TRACING": "true"}):
|
||||||
|
crew = TestCrew()
|
||||||
|
result = test_method(crew)
|
||||||
|
assert result == "test result"
|
||||||
|
assert trace_context is not None
|
||||||
|
assert trace_context.trace_type == TraceType.LLM_CALL
|
||||||
|
assert trace_context.run_type == RunType.KICKOFF
|
||||||
|
assert trace_context.crew_type == CrewType.CREW
|
||||||
|
assert trace_context.run_id == str(crew.id)
|
||||||
|
|
||||||
|
def test_init_crew_main_trace_test_mode(self):
|
||||||
|
"""Test init_crew_main_trace in test mode"""
|
||||||
|
trace_context = None
|
||||||
|
|
||||||
|
class TestCrew:
|
||||||
|
id = "test-crew-id"
|
||||||
|
_test = True
|
||||||
|
_train = False
|
||||||
|
|
||||||
|
@init_crew_main_trace
|
||||||
|
def test_method(self):
|
||||||
|
nonlocal trace_context
|
||||||
|
trace_context = TraceContext.get_current()
|
||||||
|
return "test result"
|
||||||
|
|
||||||
|
with patch.dict(os.environ, {"CREWAI_ENABLE_TRACING": "true"}):
|
||||||
|
crew = TestCrew()
|
||||||
|
result = test_method(crew)
|
||||||
|
assert result == "test result"
|
||||||
|
assert trace_context is not None
|
||||||
|
assert trace_context.run_type == RunType.TEST
|
||||||
|
|
||||||
|
def test_init_crew_main_trace_train_mode(self):
|
||||||
|
"""Test init_crew_main_trace in train mode"""
|
||||||
|
trace_context = None
|
||||||
|
|
||||||
|
class TestCrew:
|
||||||
|
id = "test-crew-id"
|
||||||
|
_test = False
|
||||||
|
_train = True
|
||||||
|
|
||||||
|
@init_crew_main_trace
|
||||||
|
def test_method(self):
|
||||||
|
nonlocal trace_context
|
||||||
|
trace_context = TraceContext.get_current()
|
||||||
|
return "test result"
|
||||||
|
|
||||||
|
with patch.dict(os.environ, {"CREWAI_ENABLE_TRACING": "true"}):
|
||||||
|
crew = TestCrew()
|
||||||
|
result = test_method(crew)
|
||||||
|
assert result == "test result"
|
||||||
|
assert trace_context is not None
|
||||||
|
assert trace_context.run_type == RunType.TRAIN
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_init_flow_main_trace(self):
|
||||||
|
"""Test init_flow_main_trace decorator"""
|
||||||
|
trace_context = None
|
||||||
|
test_inputs = {"test": "input"}
|
||||||
|
|
||||||
|
class TestFlow:
|
||||||
|
flow_id = "test-flow-id"
|
||||||
|
|
||||||
|
@init_flow_main_trace
|
||||||
|
async def test_method(self, **kwargs):
|
||||||
|
nonlocal trace_context
|
||||||
|
trace_context = TraceContext.get_current()
|
||||||
|
# Verify the context is set during execution
|
||||||
|
assert trace_context.context["context"]["inputs"] == test_inputs
|
||||||
|
return "test result"
|
||||||
|
|
||||||
|
with patch.dict(os.environ, {"CREWAI_ENABLE_TRACING": "true"}):
|
||||||
|
flow = TestFlow()
|
||||||
|
result = await flow.test_method(inputs=test_inputs)
|
||||||
|
assert result == "test result"
|
||||||
|
assert trace_context is not None
|
||||||
|
assert trace_context.trace_type == TraceType.FLOW_STEP
|
||||||
|
assert trace_context.crew_type == CrewType.FLOW
|
||||||
|
assert trace_context.run_type == RunType.KICKOFF
|
||||||
|
assert trace_context.run_id == str(flow.flow_id)
|
||||||
|
assert trace_context.context["context"]["inputs"] == test_inputs
|
||||||
|
|
||||||
|
def test_trace_context_management(self):
|
||||||
|
"""Test TraceContext management"""
|
||||||
|
trace1 = UnifiedTraceController(
|
||||||
|
trace_type=TraceType.LLM_CALL,
|
||||||
|
run_type=RunType.KICKOFF,
|
||||||
|
crew_type=CrewType.CREW,
|
||||||
|
run_id="test-run-1",
|
||||||
|
)
|
||||||
|
|
||||||
|
trace2 = UnifiedTraceController(
|
||||||
|
trace_type=TraceType.FLOW_STEP,
|
||||||
|
run_type=RunType.TEST,
|
||||||
|
crew_type=CrewType.FLOW,
|
||||||
|
run_id="test-run-2",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Test that context is initially empty
|
||||||
|
assert TraceContext.get_current() is None
|
||||||
|
|
||||||
|
# Test setting and getting context
|
||||||
|
with TraceContext.set_current(trace1):
|
||||||
|
assert TraceContext.get_current() == trace1
|
||||||
|
|
||||||
|
# Test nested context
|
||||||
|
with TraceContext.set_current(trace2):
|
||||||
|
assert TraceContext.get_current() == trace2
|
||||||
|
|
||||||
|
# Test context restoration after nested block
|
||||||
|
assert TraceContext.get_current() == trace1
|
||||||
|
|
||||||
|
# Test context cleanup after with block
|
||||||
|
assert TraceContext.get_current() is None
|
||||||
|
|
||||||
|
def test_trace_context_error_handling(self):
|
||||||
|
"""Test TraceContext error handling"""
|
||||||
|
trace = UnifiedTraceController(
|
||||||
|
trace_type=TraceType.LLM_CALL,
|
||||||
|
run_type=RunType.KICKOFF,
|
||||||
|
crew_type=CrewType.CREW,
|
||||||
|
run_id="test-run",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Test that context is properly cleaned up even if an error occurs
|
||||||
|
try:
|
||||||
|
with TraceContext.set_current(trace):
|
||||||
|
raise ValueError("Test error")
|
||||||
|
except ValueError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
assert TraceContext.get_current() is None
|
||||||
Reference in New Issue
Block a user