This commit is contained in:
lorenzejay
2025-10-15 13:37:53 -07:00
parent 884236f41c
commit d40846e6af

View File

@@ -1,7 +1,8 @@
import asyncio
import threading
from collections.abc import Callable
from concurrent.futures import ThreadPoolExecutor
from typing import Dict, Any, Callable
from typing import Any
from unittest.mock import patch
import pytest
@@ -24,9 +25,12 @@ def simple_agent_factory():
@pytest.fixture
def simple_task_factory():
def create_task(name: str, callback: Callable = None) -> Task:
def create_task(name: str, agent: Agent, callback: Callable | None = None) -> Task:
return Task(
description=f"Task for {name}", expected_output="Done", callback=callback
description=f"Task for {name}",
expected_output="Done",
agent=agent,
callback=callback,
)
return create_task
@@ -34,10 +38,9 @@ def simple_task_factory():
@pytest.fixture
def crew_factory(simple_agent_factory, simple_task_factory):
def create_crew(name: str, task_callback: Callable = None) -> Crew:
def create_crew(name: str, task_callback: Callable | None = None) -> Crew:
agent = simple_agent_factory(name)
task = simple_task_factory(name, callback=task_callback)
task.agent = agent
task = simple_task_factory(name, agent=agent, callback=task_callback)
return Crew(agents=[agent], tasks=[task], verbose=False)
@@ -50,7 +53,7 @@ class TestCrewThreadSafety:
mock_execute_task.return_value = "Task completed"
num_crews = 5
def run_crew_with_context_check(crew_id: str) -> Dict[str, Any]:
def run_crew_with_context_check(crew_id: str) -> dict[str, Any]:
results = {"crew_id": crew_id, "contexts": []}
def check_context_task(output):
@@ -105,28 +108,28 @@ class TestCrewThreadSafety:
before_ctx = next(
ctx for ctx in result["contexts"] if ctx["stage"] == "before_kickoff"
)
assert (
before_ctx["crew_id"] is None
), f"Context should be None before kickoff for {result['crew_id']}"
assert before_ctx["crew_id"] is None, (
f"Context should be None before kickoff for {result['crew_id']}"
)
task_ctx = next(
ctx for ctx in result["contexts"] if ctx["stage"] == "task_callback"
)
assert (
task_ctx["crew_id"] == crew_uuid
), f"Context mismatch during task for {result['crew_id']}"
assert task_ctx["crew_id"] == crew_uuid, (
f"Context mismatch during task for {result['crew_id']}"
)
after_ctx = next(
ctx for ctx in result["contexts"] if ctx["stage"] == "after_kickoff"
)
assert (
after_ctx["crew_id"] is None
), f"Context should be None after kickoff for {result['crew_id']}"
assert after_ctx["crew_id"] is None, (
f"Context should be None after kickoff for {result['crew_id']}"
)
thread_name = before_ctx["thread"]
assert (
"ThreadPoolExecutor" in thread_name
), f"Should run in thread pool for {result['crew_id']}"
assert "ThreadPoolExecutor" in thread_name, (
f"Should run in thread pool for {result['crew_id']}"
)
@pytest.mark.asyncio
@patch("crewai.Agent.execute_task")
@@ -134,7 +137,7 @@ class TestCrewThreadSafety:
mock_execute_task.return_value = "Task completed"
num_crews = 5
async def run_crew_async(crew_id: str) -> Dict[str, Any]:
async def run_crew_async(crew_id: str) -> dict[str, Any]:
task_context = {"crew_id": crew_id, "context": None}
def capture_context(output):
@@ -162,12 +165,12 @@ class TestCrewThreadSafety:
crew_uuid = result["crew_uuid"]
task_ctx = result["task_context"]["context"]
assert (
task_ctx is not None
), f"Context should exist during task for {result['crew_id']}"
assert (
task_ctx["crew_id"] == crew_uuid
), f"Context mismatch for {result['crew_id']}"
assert task_ctx is not None, (
f"Context should exist during task for {result['crew_id']}"
)
assert task_ctx["crew_id"] == crew_uuid, (
f"Context mismatch for {result['crew_id']}"
)
@patch("crewai.Agent.execute_task")
def test_concurrent_kickoff_for_each(self, mock_execute_task, crew_factory):
@@ -193,9 +196,9 @@ class TestCrewThreadSafety:
assert len(contexts_captured) == len(inputs)
context_ids = [ctx["context_id"] for ctx in contexts_captured]
assert len(set(context_ids)) == len(
inputs
), "Each execution should have unique context"
assert len(set(context_ids)) == len(inputs), (
"Each execution should have unique context"
)
@patch("crewai.Agent.execute_task")
def test_no_context_leakage_between_crews(self, mock_execute_task, crew_factory):