From d57b017e7b55f9010e8d383cfe84425530bd1b40 Mon Sep 17 00:00:00 2001 From: Eduardo Chiarotti Date: Fri, 7 Feb 2025 16:55:58 -0300 Subject: [PATCH] feat: add tests for traces --- tests/traces/test_unified_trace_controller.py | 360 ++++++++++++++++++ 1 file changed, 360 insertions(+) create mode 100644 tests/traces/test_unified_trace_controller.py diff --git a/tests/traces/test_unified_trace_controller.py b/tests/traces/test_unified_trace_controller.py new file mode 100644 index 000000000..80a35f11a --- /dev/null +++ b/tests/traces/test_unified_trace_controller.py @@ -0,0 +1,360 @@ +import os +from datetime import UTC, datetime +from unittest.mock import MagicMock, patch +from uuid import UUID + +import pytest +from crewai.traces.context import TraceContext +from crewai.traces.enums import CrewType, RunType, TraceType +from crewai.traces.models import ( + CrewTrace, + FlowStepIO, + LLMRequest, + LLMResponse, + ToolCall, +) +from crewai.traces.unified_trace_controller import ( + UnifiedTraceController, + init_crew_main_trace, + init_flow_main_trace, + should_trace, + trace_flow_step, + trace_llm_call, +) + + +class TestUnifiedTraceController: + @pytest.fixture + def basic_trace_controller(self): + return UnifiedTraceController( + trace_type=TraceType.LLM_CALL, + run_type=RunType.KICKOFF, + crew_type=CrewType.CREW, + run_id="test-run-id", + agent_role="test-agent", + task_name="test-task", + task_description="test description", + task_id="test-task-id", + ) + + def test_initialization(self, basic_trace_controller): + """Test basic initialization of UnifiedTraceController""" + assert basic_trace_controller.trace_type == TraceType.LLM_CALL + assert basic_trace_controller.run_type == RunType.KICKOFF + assert basic_trace_controller.crew_type == CrewType.CREW + assert basic_trace_controller.run_id == "test-run-id" + assert basic_trace_controller.agent_role == "test-agent" + assert basic_trace_controller.task_name == "test-task" + assert basic_trace_controller.task_description == "test description" + assert basic_trace_controller.task_id == "test-task-id" + assert basic_trace_controller.status == "running" + assert isinstance(UUID(basic_trace_controller.trace_id), UUID) + + def test_start_trace(self, basic_trace_controller): + """Test starting a trace""" + result = basic_trace_controller.start_trace() + assert result == basic_trace_controller + assert basic_trace_controller.start_time is not None + assert isinstance(basic_trace_controller.start_time, datetime) + + def test_end_trace_success(self, basic_trace_controller): + """Test ending a trace successfully""" + basic_trace_controller.start_trace() + basic_trace_controller.end_trace(result={"test": "result"}) + + assert basic_trace_controller.end_time is not None + assert basic_trace_controller.status == "completed" + assert basic_trace_controller.error is None + assert basic_trace_controller.context.get("response") == {"test": "result"} + + def test_end_trace_with_error(self, basic_trace_controller): + """Test ending a trace with an error""" + basic_trace_controller.start_trace() + basic_trace_controller.end_trace(error="Test error occurred") + + assert basic_trace_controller.end_time is not None + assert basic_trace_controller.status == "error" + assert basic_trace_controller.error == "Test error occurred" + + def test_add_child_trace(self, basic_trace_controller): + """Test adding a child trace""" + child_trace = {"id": "child-1", "type": "test"} + basic_trace_controller.add_child_trace(child_trace) + assert len(basic_trace_controller.children) == 1 + assert basic_trace_controller.children[0] == child_trace + + def test_to_crew_trace_llm_call(self): + """Test converting to CrewTrace for LLM call""" + test_messages = [{"role": "user", "content": "test"}] + test_response = { + "content": "test response", + "finish_reason": "stop", + } + + controller = UnifiedTraceController( + trace_type=TraceType.LLM_CALL, + run_type=RunType.KICKOFF, + crew_type=CrewType.CREW, + run_id="test-run-id", + context={ + "messages": test_messages, + "temperature": 0.7, + "max_tokens": 100, + }, + ) + + # Set model and messages in the context + controller.context["model"] = "gpt-4" + controller.context["messages"] = test_messages + + controller.start_trace() + controller.end_trace(result=test_response) + + crew_trace = controller.to_crew_trace() + assert isinstance(crew_trace, CrewTrace) + assert isinstance(crew_trace.request, LLMRequest) + assert isinstance(crew_trace.response, LLMResponse) + assert crew_trace.request.model == "gpt-4" + assert crew_trace.request.messages == test_messages + assert crew_trace.response.content == test_response["content"] + assert crew_trace.response.finish_reason == test_response["finish_reason"] + + def test_to_crew_trace_flow_step(self): + """Test converting to CrewTrace for flow step""" + flow_step_data = { + "function_name": "test_function", + "inputs": {"param1": "value1"}, + "metadata": {"meta": "data"}, + } + + controller = UnifiedTraceController( + trace_type=TraceType.FLOW_STEP, + run_type=RunType.KICKOFF, + crew_type=CrewType.FLOW, + run_id="test-run-id", + flow_step=flow_step_data, + ) + + controller.start_trace() + controller.end_trace(result="test result") + + crew_trace = controller.to_crew_trace() + assert isinstance(crew_trace, CrewTrace) + assert isinstance(crew_trace.flow_step, FlowStepIO) + assert crew_trace.flow_step.function_name == "test_function" + assert crew_trace.flow_step.inputs == {"param1": "value1"} + assert crew_trace.flow_step.outputs == {"result": "test result"} + + def test_should_trace(self): + """Test should_trace function""" + with patch.dict(os.environ, {"CREWAI_ENABLE_TRACING": "true"}): + assert should_trace() is True + + with patch.dict(os.environ, {"CREWAI_ENABLE_TRACING": "false"}): + assert should_trace() is False + + with patch.dict(os.environ, clear=True): + assert should_trace() is False + + @pytest.mark.asyncio + async def test_trace_flow_step_decorator(self): + """Test trace_flow_step decorator""" + + class TestFlow: + flow_id = "test-flow-id" + + @trace_flow_step + async def test_method(self, method_name, method, *args, **kwargs): + return "test result" + + with patch.dict(os.environ, {"CREWAI_ENABLE_TRACING": "true"}): + flow = TestFlow() + result = await flow.test_method("test_method", lambda x: x, arg1="value1") + assert result == "test result" + + def test_trace_llm_call_decorator(self): + """Test trace_llm_call decorator""" + + class TestLLM: + model = "gpt-4" + temperature = 0.7 + max_tokens = 100 + stop = None + + def _get_execution_context(self): + return MagicMock(), MagicMock() + + def _get_new_messages(self, messages): + return messages + + def _get_new_tool_results(self, agent): + return [] + + @trace_llm_call + def test_method(self, params): + return { + "choices": [ + { + "message": {"content": "test response"}, + "finish_reason": "stop", + } + ], + "usage": { + "total_tokens": 50, + "prompt_tokens": 20, + "completion_tokens": 30, + }, + } + + with patch.dict(os.environ, {"CREWAI_ENABLE_TRACING": "true"}): + llm = TestLLM() + result = llm.test_method({"messages": []}) + assert result["choices"][0]["message"]["content"] == "test response" + + def test_init_crew_main_trace_kickoff(self): + """Test init_crew_main_trace in kickoff mode""" + trace_context = None + + class TestCrew: + id = "test-crew-id" + _test = False + _train = False + + @init_crew_main_trace + def test_method(self): + nonlocal trace_context + trace_context = TraceContext.get_current() + return "test result" + + with patch.dict(os.environ, {"CREWAI_ENABLE_TRACING": "true"}): + crew = TestCrew() + result = test_method(crew) + assert result == "test result" + assert trace_context is not None + assert trace_context.trace_type == TraceType.LLM_CALL + assert trace_context.run_type == RunType.KICKOFF + assert trace_context.crew_type == CrewType.CREW + assert trace_context.run_id == str(crew.id) + + def test_init_crew_main_trace_test_mode(self): + """Test init_crew_main_trace in test mode""" + trace_context = None + + class TestCrew: + id = "test-crew-id" + _test = True + _train = False + + @init_crew_main_trace + def test_method(self): + nonlocal trace_context + trace_context = TraceContext.get_current() + return "test result" + + with patch.dict(os.environ, {"CREWAI_ENABLE_TRACING": "true"}): + crew = TestCrew() + result = test_method(crew) + assert result == "test result" + assert trace_context is not None + assert trace_context.run_type == RunType.TEST + + def test_init_crew_main_trace_train_mode(self): + """Test init_crew_main_trace in train mode""" + trace_context = None + + class TestCrew: + id = "test-crew-id" + _test = False + _train = True + + @init_crew_main_trace + def test_method(self): + nonlocal trace_context + trace_context = TraceContext.get_current() + return "test result" + + with patch.dict(os.environ, {"CREWAI_ENABLE_TRACING": "true"}): + crew = TestCrew() + result = test_method(crew) + assert result == "test result" + assert trace_context is not None + assert trace_context.run_type == RunType.TRAIN + + @pytest.mark.asyncio + async def test_init_flow_main_trace(self): + """Test init_flow_main_trace decorator""" + trace_context = None + test_inputs = {"test": "input"} + + class TestFlow: + flow_id = "test-flow-id" + + @init_flow_main_trace + async def test_method(self, **kwargs): + nonlocal trace_context + trace_context = TraceContext.get_current() + # Verify the context is set during execution + assert trace_context.context["context"]["inputs"] == test_inputs + return "test result" + + with patch.dict(os.environ, {"CREWAI_ENABLE_TRACING": "true"}): + flow = TestFlow() + result = await flow.test_method(inputs=test_inputs) + assert result == "test result" + assert trace_context is not None + assert trace_context.trace_type == TraceType.FLOW_STEP + assert trace_context.crew_type == CrewType.FLOW + assert trace_context.run_type == RunType.KICKOFF + assert trace_context.run_id == str(flow.flow_id) + assert trace_context.context["context"]["inputs"] == test_inputs + + def test_trace_context_management(self): + """Test TraceContext management""" + trace1 = UnifiedTraceController( + trace_type=TraceType.LLM_CALL, + run_type=RunType.KICKOFF, + crew_type=CrewType.CREW, + run_id="test-run-1", + ) + + trace2 = UnifiedTraceController( + trace_type=TraceType.FLOW_STEP, + run_type=RunType.TEST, + crew_type=CrewType.FLOW, + run_id="test-run-2", + ) + + # Test that context is initially empty + assert TraceContext.get_current() is None + + # Test setting and getting context + with TraceContext.set_current(trace1): + assert TraceContext.get_current() == trace1 + + # Test nested context + with TraceContext.set_current(trace2): + assert TraceContext.get_current() == trace2 + + # Test context restoration after nested block + assert TraceContext.get_current() == trace1 + + # Test context cleanup after with block + assert TraceContext.get_current() is None + + def test_trace_context_error_handling(self): + """Test TraceContext error handling""" + trace = UnifiedTraceController( + trace_type=TraceType.LLM_CALL, + run_type=RunType.KICKOFF, + crew_type=CrewType.CREW, + run_id="test-run", + ) + + # Test that context is properly cleaned up even if an error occurs + try: + with TraceContext.set_current(trace): + raise ValueError("Test error") + except ValueError: + pass + + assert TraceContext.get_current() is None