diff --git a/tests/crew_test.py b/tests/crew_test.py index 141ecfb7c..f3fb27872 100644 --- a/tests/crew_test.py +++ b/tests/crew_test.py @@ -18,6 +18,7 @@ from crewai.task import Task from crewai.tasks.conditional_task import ConditionalTask from crewai.tasks.output_format import OutputFormat from crewai.tasks.task_output import TaskOutput +from crewai.types.usage_metrics import UsageMetrics from crewai.utilities import Logger, RPMController from crewai.utilities.task_output_storage_handler import TaskOutputStorageHandler @@ -597,14 +598,10 @@ def test_crew_kickoff_usage_metrics(): assert len(results) == len(inputs) for result in results: # Assert that all required keys are in usage_metrics and their values are not None - for key in [ - "total_tokens", - "prompt_tokens", - "completion_tokens", - "successful_requests", - ]: - assert key in result.token_usage - assert result.token_usage[key] > 0 + assert result.token_usage.total_tokens > 0 + assert result.token_usage.prompt_tokens > 0 + assert result.token_usage.completion_tokens > 0 + assert result.token_usage.successful_requests > 0 def test_agents_rpm_is_never_set_if_crew_max_RPM_is_not_set(): @@ -1318,12 +1315,12 @@ def test_agent_usage_metrics_are_captured_for_hierarchical_process(): print(crew.usage_metrics) - assert crew.usage_metrics == { - "total_tokens": 219, - "prompt_tokens": 201, - "completion_tokens": 18, - "successful_requests": 1, - } + assert crew.usage_metrics == UsageMetrics( + total_tokens=219, + prompt_tokens=201, + completion_tokens=18, + successful_requests=1, + ) @pytest.mark.vcr(filter_headers=["authorization"]) diff --git a/tests/pipeline/test_pipeline.py b/tests/pipeline/test_pipeline.py index e7b7a745c..7f3cb4bf9 100644 --- a/tests/pipeline/test_pipeline.py +++ b/tests/pipeline/test_pipeline.py @@ -10,13 +10,12 @@ from crewai.pipeline.pipeline_run_result import PipelineRunResult from crewai.process import Process from crewai.task import Task from crewai.tasks.task_output import TaskOutput +from crewai.types.usage_metrics import UsageMetrics from pydantic import BaseModel, ValidationError -DEFAULT_TOKEN_USAGE = { - "total_tokens": 100, - "prompt_tokens": 50, - "completion_tokens": 50, -} +DEFAULT_TOKEN_USAGE = UsageMetrics( + total_tokens=100, prompt_tokens=50, completion_tokens=50, successful_requests=3 +) @pytest.fixture @@ -443,6 +442,7 @@ Options: - Should the final output include the accumulation of previous stages' outputs? """ + @pytest.mark.asyncio async def test_pipeline_data_accumulation(mock_crew_factory): crew1 = mock_crew_factory(name="Crew 1", output_json_dict={"key1": "value1"})