mirror of
https://github.com/crewAIInc/crewAI.git
synced 2025-12-16 12:28:30 +00:00
361 lines
13 KiB
Python
361 lines
13 KiB
Python
import os
|
|
from datetime import UTC, datetime
|
|
from unittest.mock import MagicMock, patch
|
|
from uuid import UUID
|
|
|
|
import pytest
|
|
|
|
from crewai.traces.context import TraceContext
|
|
from crewai.traces.enums import CrewType, RunType, TraceType
|
|
from crewai.traces.models import (
|
|
CrewTrace,
|
|
FlowStepIO,
|
|
LLMRequest,
|
|
LLMResponse,
|
|
)
|
|
from crewai.traces.unified_trace_controller import (
|
|
UnifiedTraceController,
|
|
init_crew_main_trace,
|
|
init_flow_main_trace,
|
|
should_trace,
|
|
trace_flow_step,
|
|
trace_llm_call,
|
|
)
|
|
|
|
|
|
class TestUnifiedTraceController:
|
|
@pytest.fixture
|
|
def basic_trace_controller(self):
|
|
return UnifiedTraceController(
|
|
trace_type=TraceType.LLM_CALL,
|
|
run_type=RunType.KICKOFF,
|
|
crew_type=CrewType.CREW,
|
|
run_id="test-run-id",
|
|
agent_role="test-agent",
|
|
task_name="test-task",
|
|
task_description="test description",
|
|
task_id="test-task-id",
|
|
)
|
|
|
|
def test_initialization(self, basic_trace_controller):
|
|
"""Test basic initialization of UnifiedTraceController"""
|
|
assert basic_trace_controller.trace_type == TraceType.LLM_CALL
|
|
assert basic_trace_controller.run_type == RunType.KICKOFF
|
|
assert basic_trace_controller.crew_type == CrewType.CREW
|
|
assert basic_trace_controller.run_id == "test-run-id"
|
|
assert basic_trace_controller.agent_role == "test-agent"
|
|
assert basic_trace_controller.task_name == "test-task"
|
|
assert basic_trace_controller.task_description == "test description"
|
|
assert basic_trace_controller.task_id == "test-task-id"
|
|
assert basic_trace_controller.status == "running"
|
|
assert isinstance(UUID(basic_trace_controller.trace_id), UUID)
|
|
|
|
def test_start_trace(self, basic_trace_controller):
|
|
"""Test starting a trace"""
|
|
result = basic_trace_controller.start_trace()
|
|
assert result == basic_trace_controller
|
|
assert basic_trace_controller.start_time is not None
|
|
assert isinstance(basic_trace_controller.start_time, datetime)
|
|
|
|
def test_end_trace_success(self, basic_trace_controller):
|
|
"""Test ending a trace successfully"""
|
|
basic_trace_controller.start_trace()
|
|
basic_trace_controller.end_trace(result={"test": "result"})
|
|
|
|
assert basic_trace_controller.end_time is not None
|
|
assert basic_trace_controller.status == "completed"
|
|
assert basic_trace_controller.error is None
|
|
assert basic_trace_controller.context.get("response") == {"test": "result"}
|
|
|
|
def test_end_trace_with_error(self, basic_trace_controller):
|
|
"""Test ending a trace with an error"""
|
|
basic_trace_controller.start_trace()
|
|
basic_trace_controller.end_trace(error="Test error occurred")
|
|
|
|
assert basic_trace_controller.end_time is not None
|
|
assert basic_trace_controller.status == "error"
|
|
assert basic_trace_controller.error == "Test error occurred"
|
|
|
|
def test_add_child_trace(self, basic_trace_controller):
|
|
"""Test adding a child trace"""
|
|
child_trace = {"id": "child-1", "type": "test"}
|
|
basic_trace_controller.add_child_trace(child_trace)
|
|
assert len(basic_trace_controller.children) == 1
|
|
assert basic_trace_controller.children[0] == child_trace
|
|
|
|
def test_to_crew_trace_llm_call(self):
|
|
"""Test converting to CrewTrace for LLM call"""
|
|
test_messages = [{"role": "user", "content": "test"}]
|
|
test_response = {
|
|
"content": "test response",
|
|
"finish_reason": "stop",
|
|
}
|
|
|
|
controller = UnifiedTraceController(
|
|
trace_type=TraceType.LLM_CALL,
|
|
run_type=RunType.KICKOFF,
|
|
crew_type=CrewType.CREW,
|
|
run_id="test-run-id",
|
|
context={
|
|
"messages": test_messages,
|
|
"temperature": 0.7,
|
|
"max_tokens": 100,
|
|
},
|
|
)
|
|
|
|
# Set model and messages in the context
|
|
controller.context["model"] = "gpt-4"
|
|
controller.context["messages"] = test_messages
|
|
|
|
controller.start_trace()
|
|
controller.end_trace(result=test_response)
|
|
|
|
crew_trace = controller.to_crew_trace()
|
|
assert isinstance(crew_trace, CrewTrace)
|
|
assert isinstance(crew_trace.request, LLMRequest)
|
|
assert isinstance(crew_trace.response, LLMResponse)
|
|
assert crew_trace.request.model == "gpt-4"
|
|
assert crew_trace.request.messages == test_messages
|
|
assert crew_trace.response.content == test_response["content"]
|
|
assert crew_trace.response.finish_reason == test_response["finish_reason"]
|
|
|
|
def test_to_crew_trace_flow_step(self):
|
|
"""Test converting to CrewTrace for flow step"""
|
|
flow_step_data = {
|
|
"function_name": "test_function",
|
|
"inputs": {"param1": "value1"},
|
|
"metadata": {"meta": "data"},
|
|
}
|
|
|
|
controller = UnifiedTraceController(
|
|
trace_type=TraceType.FLOW_STEP,
|
|
run_type=RunType.KICKOFF,
|
|
crew_type=CrewType.FLOW,
|
|
run_id="test-run-id",
|
|
flow_step=flow_step_data,
|
|
)
|
|
|
|
controller.start_trace()
|
|
controller.end_trace(result="test result")
|
|
|
|
crew_trace = controller.to_crew_trace()
|
|
assert isinstance(crew_trace, CrewTrace)
|
|
assert isinstance(crew_trace.flow_step, FlowStepIO)
|
|
assert crew_trace.flow_step.function_name == "test_function"
|
|
assert crew_trace.flow_step.inputs == {"param1": "value1"}
|
|
assert crew_trace.flow_step.outputs == {"result": "test result"}
|
|
|
|
def test_should_trace(self):
|
|
"""Test should_trace function"""
|
|
with patch.dict(os.environ, {"CREWAI_ENABLE_TRACING": "true"}):
|
|
assert should_trace() is True
|
|
|
|
with patch.dict(os.environ, {"CREWAI_ENABLE_TRACING": "false"}):
|
|
assert should_trace() is False
|
|
|
|
with patch.dict(os.environ, clear=True):
|
|
assert should_trace() is False
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_trace_flow_step_decorator(self):
|
|
"""Test trace_flow_step decorator"""
|
|
|
|
class TestFlow:
|
|
flow_id = "test-flow-id"
|
|
|
|
@trace_flow_step
|
|
async def test_method(self, method_name, method, *args, **kwargs):
|
|
return "test result"
|
|
|
|
with patch.dict(os.environ, {"CREWAI_ENABLE_TRACING": "true"}):
|
|
flow = TestFlow()
|
|
result = await flow.test_method("test_method", lambda x: x, arg1="value1")
|
|
assert result == "test result"
|
|
|
|
def test_trace_llm_call_decorator(self):
|
|
"""Test trace_llm_call decorator"""
|
|
|
|
class TestLLM:
|
|
model = "gpt-4"
|
|
temperature = 0.7
|
|
max_tokens = 100
|
|
stop = None
|
|
|
|
def _get_execution_context(self):
|
|
return MagicMock(), MagicMock()
|
|
|
|
def _get_new_messages(self, messages):
|
|
return messages
|
|
|
|
def _get_new_tool_results(self, agent):
|
|
return []
|
|
|
|
@trace_llm_call
|
|
def test_method(self, params):
|
|
return {
|
|
"choices": [
|
|
{
|
|
"message": {"content": "test response"},
|
|
"finish_reason": "stop",
|
|
}
|
|
],
|
|
"usage": {
|
|
"total_tokens": 50,
|
|
"prompt_tokens": 20,
|
|
"completion_tokens": 30,
|
|
},
|
|
}
|
|
|
|
with patch.dict(os.environ, {"CREWAI_ENABLE_TRACING": "true"}):
|
|
llm = TestLLM()
|
|
result = llm.test_method({"messages": []})
|
|
assert result["choices"][0]["message"]["content"] == "test response"
|
|
|
|
def test_init_crew_main_trace_kickoff(self):
|
|
"""Test init_crew_main_trace in kickoff mode"""
|
|
trace_context = None
|
|
|
|
class TestCrew:
|
|
id = "test-crew-id"
|
|
_test = False
|
|
_train = False
|
|
|
|
@init_crew_main_trace
|
|
def test_method(self):
|
|
nonlocal trace_context
|
|
trace_context = TraceContext.get_current()
|
|
return "test result"
|
|
|
|
with patch.dict(os.environ, {"CREWAI_ENABLE_TRACING": "true"}):
|
|
crew = TestCrew()
|
|
result = test_method(crew)
|
|
assert result == "test result"
|
|
assert trace_context is not None
|
|
assert trace_context.trace_type == TraceType.LLM_CALL
|
|
assert trace_context.run_type == RunType.KICKOFF
|
|
assert trace_context.crew_type == CrewType.CREW
|
|
assert trace_context.run_id == str(crew.id)
|
|
|
|
def test_init_crew_main_trace_test_mode(self):
|
|
"""Test init_crew_main_trace in test mode"""
|
|
trace_context = None
|
|
|
|
class TestCrew:
|
|
id = "test-crew-id"
|
|
_test = True
|
|
_train = False
|
|
|
|
@init_crew_main_trace
|
|
def test_method(self):
|
|
nonlocal trace_context
|
|
trace_context = TraceContext.get_current()
|
|
return "test result"
|
|
|
|
with patch.dict(os.environ, {"CREWAI_ENABLE_TRACING": "true"}):
|
|
crew = TestCrew()
|
|
result = test_method(crew)
|
|
assert result == "test result"
|
|
assert trace_context is not None
|
|
assert trace_context.run_type == RunType.TEST
|
|
|
|
def test_init_crew_main_trace_train_mode(self):
|
|
"""Test init_crew_main_trace in train mode"""
|
|
trace_context = None
|
|
|
|
class TestCrew:
|
|
id = "test-crew-id"
|
|
_test = False
|
|
_train = True
|
|
|
|
@init_crew_main_trace
|
|
def test_method(self):
|
|
nonlocal trace_context
|
|
trace_context = TraceContext.get_current()
|
|
return "test result"
|
|
|
|
with patch.dict(os.environ, {"CREWAI_ENABLE_TRACING": "true"}):
|
|
crew = TestCrew()
|
|
result = test_method(crew)
|
|
assert result == "test result"
|
|
assert trace_context is not None
|
|
assert trace_context.run_type == RunType.TRAIN
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_init_flow_main_trace(self):
|
|
"""Test init_flow_main_trace decorator"""
|
|
trace_context = None
|
|
test_inputs = {"test": "input"}
|
|
|
|
class TestFlow:
|
|
flow_id = "test-flow-id"
|
|
|
|
@init_flow_main_trace
|
|
async def test_method(self, **kwargs):
|
|
nonlocal trace_context
|
|
trace_context = TraceContext.get_current()
|
|
# Verify the context is set during execution
|
|
assert trace_context.context["context"]["inputs"] == test_inputs
|
|
return "test result"
|
|
|
|
with patch.dict(os.environ, {"CREWAI_ENABLE_TRACING": "true"}):
|
|
flow = TestFlow()
|
|
result = await flow.test_method(inputs=test_inputs)
|
|
assert result == "test result"
|
|
assert trace_context is not None
|
|
assert trace_context.trace_type == TraceType.FLOW_STEP
|
|
assert trace_context.crew_type == CrewType.FLOW
|
|
assert trace_context.run_type == RunType.KICKOFF
|
|
assert trace_context.run_id == str(flow.flow_id)
|
|
assert trace_context.context["context"]["inputs"] == test_inputs
|
|
|
|
def test_trace_context_management(self):
|
|
"""Test TraceContext management"""
|
|
trace1 = UnifiedTraceController(
|
|
trace_type=TraceType.LLM_CALL,
|
|
run_type=RunType.KICKOFF,
|
|
crew_type=CrewType.CREW,
|
|
run_id="test-run-1",
|
|
)
|
|
|
|
trace2 = UnifiedTraceController(
|
|
trace_type=TraceType.FLOW_STEP,
|
|
run_type=RunType.TEST,
|
|
crew_type=CrewType.FLOW,
|
|
run_id="test-run-2",
|
|
)
|
|
|
|
# Test that context is initially empty
|
|
assert TraceContext.get_current() is None
|
|
|
|
# Test setting and getting context
|
|
with TraceContext.set_current(trace1):
|
|
assert TraceContext.get_current() == trace1
|
|
|
|
# Test nested context
|
|
with TraceContext.set_current(trace2):
|
|
assert TraceContext.get_current() == trace2
|
|
|
|
# Test context restoration after nested block
|
|
assert TraceContext.get_current() == trace1
|
|
|
|
# Test context cleanup after with block
|
|
assert TraceContext.get_current() is None
|
|
|
|
def test_trace_context_error_handling(self):
|
|
"""Test TraceContext error handling"""
|
|
trace = UnifiedTraceController(
|
|
trace_type=TraceType.LLM_CALL,
|
|
run_type=RunType.KICKOFF,
|
|
crew_type=CrewType.CREW,
|
|
run_id="test-run",
|
|
)
|
|
|
|
# Test that context is properly cleaned up even if an error occurs
|
|
try:
|
|
with TraceContext.set_current(trace):
|
|
raise ValueError("Test error")
|
|
except ValueError:
|
|
pass
|
|
|
|
assert TraceContext.get_current() is None
|