From d40846e6afc99fa8cc0fc79cb637eeeb8fbe9eea Mon Sep 17 00:00:00 2001 From: lorenzejay Date: Wed, 15 Oct 2025 13:37:53 -0700 Subject: [PATCH] fix test --- tests/test_crew_thread_safety.py | 61 +++++++++++++++++--------------- 1 file changed, 32 insertions(+), 29 deletions(-) diff --git a/tests/test_crew_thread_safety.py b/tests/test_crew_thread_safety.py index 145a0405c..ac458e8ca 100644 --- a/tests/test_crew_thread_safety.py +++ b/tests/test_crew_thread_safety.py @@ -1,7 +1,8 @@ import asyncio import threading +from collections.abc import Callable from concurrent.futures import ThreadPoolExecutor -from typing import Dict, Any, Callable +from typing import Any from unittest.mock import patch import pytest @@ -24,9 +25,12 @@ def simple_agent_factory(): @pytest.fixture def simple_task_factory(): - def create_task(name: str, callback: Callable = None) -> Task: + def create_task(name: str, agent: Agent, callback: Callable | None = None) -> Task: return Task( - description=f"Task for {name}", expected_output="Done", callback=callback + description=f"Task for {name}", + expected_output="Done", + agent=agent, + callback=callback, ) return create_task @@ -34,10 +38,9 @@ def simple_task_factory(): @pytest.fixture def crew_factory(simple_agent_factory, simple_task_factory): - def create_crew(name: str, task_callback: Callable = None) -> Crew: + def create_crew(name: str, task_callback: Callable | None = None) -> Crew: agent = simple_agent_factory(name) - task = simple_task_factory(name, callback=task_callback) - task.agent = agent + task = simple_task_factory(name, agent=agent, callback=task_callback) return Crew(agents=[agent], tasks=[task], verbose=False) @@ -50,7 +53,7 @@ class TestCrewThreadSafety: mock_execute_task.return_value = "Task completed" num_crews = 5 - def run_crew_with_context_check(crew_id: str) -> Dict[str, Any]: + def run_crew_with_context_check(crew_id: str) -> dict[str, Any]: results = {"crew_id": crew_id, "contexts": []} def check_context_task(output): @@ -105,28 +108,28 @@ class TestCrewThreadSafety: before_ctx = next( ctx for ctx in result["contexts"] if ctx["stage"] == "before_kickoff" ) - assert ( - before_ctx["crew_id"] is None - ), f"Context should be None before kickoff for {result['crew_id']}" + assert before_ctx["crew_id"] is None, ( + f"Context should be None before kickoff for {result['crew_id']}" + ) task_ctx = next( ctx for ctx in result["contexts"] if ctx["stage"] == "task_callback" ) - assert ( - task_ctx["crew_id"] == crew_uuid - ), f"Context mismatch during task for {result['crew_id']}" + assert task_ctx["crew_id"] == crew_uuid, ( + f"Context mismatch during task for {result['crew_id']}" + ) after_ctx = next( ctx for ctx in result["contexts"] if ctx["stage"] == "after_kickoff" ) - assert ( - after_ctx["crew_id"] is None - ), f"Context should be None after kickoff for {result['crew_id']}" + assert after_ctx["crew_id"] is None, ( + f"Context should be None after kickoff for {result['crew_id']}" + ) thread_name = before_ctx["thread"] - assert ( - "ThreadPoolExecutor" in thread_name - ), f"Should run in thread pool for {result['crew_id']}" + assert "ThreadPoolExecutor" in thread_name, ( + f"Should run in thread pool for {result['crew_id']}" + ) @pytest.mark.asyncio @patch("crewai.Agent.execute_task") @@ -134,7 +137,7 @@ class TestCrewThreadSafety: mock_execute_task.return_value = "Task completed" num_crews = 5 - async def run_crew_async(crew_id: str) -> Dict[str, Any]: + async def run_crew_async(crew_id: str) -> dict[str, Any]: task_context = {"crew_id": crew_id, "context": None} def capture_context(output): @@ -162,12 +165,12 @@ class TestCrewThreadSafety: crew_uuid = result["crew_uuid"] task_ctx = result["task_context"]["context"] - assert ( - task_ctx is not None - ), f"Context should exist during task for {result['crew_id']}" - assert ( - task_ctx["crew_id"] == crew_uuid - ), f"Context mismatch for {result['crew_id']}" + assert task_ctx is not None, ( + f"Context should exist during task for {result['crew_id']}" + ) + assert task_ctx["crew_id"] == crew_uuid, ( + f"Context mismatch for {result['crew_id']}" + ) @patch("crewai.Agent.execute_task") def test_concurrent_kickoff_for_each(self, mock_execute_task, crew_factory): @@ -193,9 +196,9 @@ class TestCrewThreadSafety: assert len(contexts_captured) == len(inputs) context_ids = [ctx["context_id"] for ctx in contexts_captured] - assert len(set(context_ids)) == len( - inputs - ), "Each execution should have unique context" + assert len(set(context_ids)) == len(inputs), ( + "Each execution should have unique context" + ) @patch("crewai.Agent.execute_task") def test_no_context_leakage_between_crews(self, mock_execute_task, crew_factory):