mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-10 00:28:31 +00:00
fix test
This commit is contained in:
@@ -1,7 +1,8 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import threading
|
import threading
|
||||||
|
from collections.abc import Callable
|
||||||
from concurrent.futures import ThreadPoolExecutor
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
from typing import Dict, Any, Callable
|
from typing import Any
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
@@ -24,9 +25,12 @@ def simple_agent_factory():
|
|||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def simple_task_factory():
|
def simple_task_factory():
|
||||||
def create_task(name: str, callback: Callable = None) -> Task:
|
def create_task(name: str, agent: Agent, callback: Callable | None = None) -> Task:
|
||||||
return Task(
|
return Task(
|
||||||
description=f"Task for {name}", expected_output="Done", callback=callback
|
description=f"Task for {name}",
|
||||||
|
expected_output="Done",
|
||||||
|
agent=agent,
|
||||||
|
callback=callback,
|
||||||
)
|
)
|
||||||
|
|
||||||
return create_task
|
return create_task
|
||||||
@@ -34,10 +38,9 @@ def simple_task_factory():
|
|||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def crew_factory(simple_agent_factory, simple_task_factory):
|
def crew_factory(simple_agent_factory, simple_task_factory):
|
||||||
def create_crew(name: str, task_callback: Callable = None) -> Crew:
|
def create_crew(name: str, task_callback: Callable | None = None) -> Crew:
|
||||||
agent = simple_agent_factory(name)
|
agent = simple_agent_factory(name)
|
||||||
task = simple_task_factory(name, callback=task_callback)
|
task = simple_task_factory(name, agent=agent, callback=task_callback)
|
||||||
task.agent = agent
|
|
||||||
|
|
||||||
return Crew(agents=[agent], tasks=[task], verbose=False)
|
return Crew(agents=[agent], tasks=[task], verbose=False)
|
||||||
|
|
||||||
@@ -50,7 +53,7 @@ class TestCrewThreadSafety:
|
|||||||
mock_execute_task.return_value = "Task completed"
|
mock_execute_task.return_value = "Task completed"
|
||||||
num_crews = 5
|
num_crews = 5
|
||||||
|
|
||||||
def run_crew_with_context_check(crew_id: str) -> Dict[str, Any]:
|
def run_crew_with_context_check(crew_id: str) -> dict[str, Any]:
|
||||||
results = {"crew_id": crew_id, "contexts": []}
|
results = {"crew_id": crew_id, "contexts": []}
|
||||||
|
|
||||||
def check_context_task(output):
|
def check_context_task(output):
|
||||||
@@ -105,28 +108,28 @@ class TestCrewThreadSafety:
|
|||||||
before_ctx = next(
|
before_ctx = next(
|
||||||
ctx for ctx in result["contexts"] if ctx["stage"] == "before_kickoff"
|
ctx for ctx in result["contexts"] if ctx["stage"] == "before_kickoff"
|
||||||
)
|
)
|
||||||
assert (
|
assert before_ctx["crew_id"] is None, (
|
||||||
before_ctx["crew_id"] is None
|
f"Context should be None before kickoff for {result['crew_id']}"
|
||||||
), f"Context should be None before kickoff for {result['crew_id']}"
|
)
|
||||||
|
|
||||||
task_ctx = next(
|
task_ctx = next(
|
||||||
ctx for ctx in result["contexts"] if ctx["stage"] == "task_callback"
|
ctx for ctx in result["contexts"] if ctx["stage"] == "task_callback"
|
||||||
)
|
)
|
||||||
assert (
|
assert task_ctx["crew_id"] == crew_uuid, (
|
||||||
task_ctx["crew_id"] == crew_uuid
|
f"Context mismatch during task for {result['crew_id']}"
|
||||||
), f"Context mismatch during task for {result['crew_id']}"
|
)
|
||||||
|
|
||||||
after_ctx = next(
|
after_ctx = next(
|
||||||
ctx for ctx in result["contexts"] if ctx["stage"] == "after_kickoff"
|
ctx for ctx in result["contexts"] if ctx["stage"] == "after_kickoff"
|
||||||
)
|
)
|
||||||
assert (
|
assert after_ctx["crew_id"] is None, (
|
||||||
after_ctx["crew_id"] is None
|
f"Context should be None after kickoff for {result['crew_id']}"
|
||||||
), f"Context should be None after kickoff for {result['crew_id']}"
|
)
|
||||||
|
|
||||||
thread_name = before_ctx["thread"]
|
thread_name = before_ctx["thread"]
|
||||||
assert (
|
assert "ThreadPoolExecutor" in thread_name, (
|
||||||
"ThreadPoolExecutor" in thread_name
|
f"Should run in thread pool for {result['crew_id']}"
|
||||||
), f"Should run in thread pool for {result['crew_id']}"
|
)
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@patch("crewai.Agent.execute_task")
|
@patch("crewai.Agent.execute_task")
|
||||||
@@ -134,7 +137,7 @@ class TestCrewThreadSafety:
|
|||||||
mock_execute_task.return_value = "Task completed"
|
mock_execute_task.return_value = "Task completed"
|
||||||
num_crews = 5
|
num_crews = 5
|
||||||
|
|
||||||
async def run_crew_async(crew_id: str) -> Dict[str, Any]:
|
async def run_crew_async(crew_id: str) -> dict[str, Any]:
|
||||||
task_context = {"crew_id": crew_id, "context": None}
|
task_context = {"crew_id": crew_id, "context": None}
|
||||||
|
|
||||||
def capture_context(output):
|
def capture_context(output):
|
||||||
@@ -162,12 +165,12 @@ class TestCrewThreadSafety:
|
|||||||
crew_uuid = result["crew_uuid"]
|
crew_uuid = result["crew_uuid"]
|
||||||
task_ctx = result["task_context"]["context"]
|
task_ctx = result["task_context"]["context"]
|
||||||
|
|
||||||
assert (
|
assert task_ctx is not None, (
|
||||||
task_ctx is not None
|
f"Context should exist during task for {result['crew_id']}"
|
||||||
), f"Context should exist during task for {result['crew_id']}"
|
)
|
||||||
assert (
|
assert task_ctx["crew_id"] == crew_uuid, (
|
||||||
task_ctx["crew_id"] == crew_uuid
|
f"Context mismatch for {result['crew_id']}"
|
||||||
), f"Context mismatch for {result['crew_id']}"
|
)
|
||||||
|
|
||||||
@patch("crewai.Agent.execute_task")
|
@patch("crewai.Agent.execute_task")
|
||||||
def test_concurrent_kickoff_for_each(self, mock_execute_task, crew_factory):
|
def test_concurrent_kickoff_for_each(self, mock_execute_task, crew_factory):
|
||||||
@@ -193,9 +196,9 @@ class TestCrewThreadSafety:
|
|||||||
assert len(contexts_captured) == len(inputs)
|
assert len(contexts_captured) == len(inputs)
|
||||||
|
|
||||||
context_ids = [ctx["context_id"] for ctx in contexts_captured]
|
context_ids = [ctx["context_id"] for ctx in contexts_captured]
|
||||||
assert len(set(context_ids)) == len(
|
assert len(set(context_ids)) == len(inputs), (
|
||||||
inputs
|
"Each execution should have unique context"
|
||||||
), "Each execution should have unique context"
|
)
|
||||||
|
|
||||||
@patch("crewai.Agent.execute_task")
|
@patch("crewai.Agent.execute_task")
|
||||||
def test_no_context_leakage_between_crews(self, mock_execute_task, crew_factory):
|
def test_no_context_leakage_between_crews(self, mock_execute_task, crew_factory):
|
||||||
|
|||||||
Reference in New Issue
Block a user