This commit is contained in:
lorenzejay
2025-10-15 13:37:53 -07:00
parent 884236f41c
commit d40846e6af

View File

@@ -1,7 +1,8 @@
import asyncio import asyncio
import threading import threading
from collections.abc import Callable
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
from typing import Dict, Any, Callable from typing import Any
from unittest.mock import patch from unittest.mock import patch
import pytest import pytest
@@ -24,9 +25,12 @@ def simple_agent_factory():
@pytest.fixture @pytest.fixture
def simple_task_factory(): def simple_task_factory():
def create_task(name: str, callback: Callable = None) -> Task: def create_task(name: str, agent: Agent, callback: Callable | None = None) -> Task:
return Task( return Task(
description=f"Task for {name}", expected_output="Done", callback=callback description=f"Task for {name}",
expected_output="Done",
agent=agent,
callback=callback,
) )
return create_task return create_task
@@ -34,10 +38,9 @@ def simple_task_factory():
@pytest.fixture @pytest.fixture
def crew_factory(simple_agent_factory, simple_task_factory): def crew_factory(simple_agent_factory, simple_task_factory):
def create_crew(name: str, task_callback: Callable = None) -> Crew: def create_crew(name: str, task_callback: Callable | None = None) -> Crew:
agent = simple_agent_factory(name) agent = simple_agent_factory(name)
task = simple_task_factory(name, callback=task_callback) task = simple_task_factory(name, agent=agent, callback=task_callback)
task.agent = agent
return Crew(agents=[agent], tasks=[task], verbose=False) return Crew(agents=[agent], tasks=[task], verbose=False)
@@ -50,7 +53,7 @@ class TestCrewThreadSafety:
mock_execute_task.return_value = "Task completed" mock_execute_task.return_value = "Task completed"
num_crews = 5 num_crews = 5
def run_crew_with_context_check(crew_id: str) -> Dict[str, Any]: def run_crew_with_context_check(crew_id: str) -> dict[str, Any]:
results = {"crew_id": crew_id, "contexts": []} results = {"crew_id": crew_id, "contexts": []}
def check_context_task(output): def check_context_task(output):
@@ -105,28 +108,28 @@ class TestCrewThreadSafety:
before_ctx = next( before_ctx = next(
ctx for ctx in result["contexts"] if ctx["stage"] == "before_kickoff" ctx for ctx in result["contexts"] if ctx["stage"] == "before_kickoff"
) )
assert ( assert before_ctx["crew_id"] is None, (
before_ctx["crew_id"] is None f"Context should be None before kickoff for {result['crew_id']}"
), f"Context should be None before kickoff for {result['crew_id']}" )
task_ctx = next( task_ctx = next(
ctx for ctx in result["contexts"] if ctx["stage"] == "task_callback" ctx for ctx in result["contexts"] if ctx["stage"] == "task_callback"
) )
assert ( assert task_ctx["crew_id"] == crew_uuid, (
task_ctx["crew_id"] == crew_uuid f"Context mismatch during task for {result['crew_id']}"
), f"Context mismatch during task for {result['crew_id']}" )
after_ctx = next( after_ctx = next(
ctx for ctx in result["contexts"] if ctx["stage"] == "after_kickoff" ctx for ctx in result["contexts"] if ctx["stage"] == "after_kickoff"
) )
assert ( assert after_ctx["crew_id"] is None, (
after_ctx["crew_id"] is None f"Context should be None after kickoff for {result['crew_id']}"
), f"Context should be None after kickoff for {result['crew_id']}" )
thread_name = before_ctx["thread"] thread_name = before_ctx["thread"]
assert ( assert "ThreadPoolExecutor" in thread_name, (
"ThreadPoolExecutor" in thread_name f"Should run in thread pool for {result['crew_id']}"
), f"Should run in thread pool for {result['crew_id']}" )
@pytest.mark.asyncio @pytest.mark.asyncio
@patch("crewai.Agent.execute_task") @patch("crewai.Agent.execute_task")
@@ -134,7 +137,7 @@ class TestCrewThreadSafety:
mock_execute_task.return_value = "Task completed" mock_execute_task.return_value = "Task completed"
num_crews = 5 num_crews = 5
async def run_crew_async(crew_id: str) -> Dict[str, Any]: async def run_crew_async(crew_id: str) -> dict[str, Any]:
task_context = {"crew_id": crew_id, "context": None} task_context = {"crew_id": crew_id, "context": None}
def capture_context(output): def capture_context(output):
@@ -162,12 +165,12 @@ class TestCrewThreadSafety:
crew_uuid = result["crew_uuid"] crew_uuid = result["crew_uuid"]
task_ctx = result["task_context"]["context"] task_ctx = result["task_context"]["context"]
assert ( assert task_ctx is not None, (
task_ctx is not None f"Context should exist during task for {result['crew_id']}"
), f"Context should exist during task for {result['crew_id']}" )
assert ( assert task_ctx["crew_id"] == crew_uuid, (
task_ctx["crew_id"] == crew_uuid f"Context mismatch for {result['crew_id']}"
), f"Context mismatch for {result['crew_id']}" )
@patch("crewai.Agent.execute_task") @patch("crewai.Agent.execute_task")
def test_concurrent_kickoff_for_each(self, mock_execute_task, crew_factory): def test_concurrent_kickoff_for_each(self, mock_execute_task, crew_factory):
@@ -193,9 +196,9 @@ class TestCrewThreadSafety:
assert len(contexts_captured) == len(inputs) assert len(contexts_captured) == len(inputs)
context_ids = [ctx["context_id"] for ctx in contexts_captured] context_ids = [ctx["context_id"] for ctx in contexts_captured]
assert len(set(context_ids)) == len( assert len(set(context_ids)) == len(inputs), (
inputs "Each execution should have unique context"
), "Each execution should have unique context" )
@patch("crewai.Agent.execute_task") @patch("crewai.Agent.execute_task")
def test_no_context_leakage_between_crews(self, mock_execute_task, crew_factory): def test_no_context_leakage_between_crews(self, mock_execute_task, crew_factory):