mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-07 15:18:29 +00:00
fix test
This commit is contained in:
@@ -1,7 +1,8 @@
|
||||
import asyncio
|
||||
import threading
|
||||
from collections.abc import Callable
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from typing import Dict, Any, Callable
|
||||
from typing import Any
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
@@ -24,9 +25,12 @@ def simple_agent_factory():
|
||||
|
||||
@pytest.fixture
|
||||
def simple_task_factory():
|
||||
def create_task(name: str, callback: Callable = None) -> Task:
|
||||
def create_task(name: str, agent: Agent, callback: Callable | None = None) -> Task:
|
||||
return Task(
|
||||
description=f"Task for {name}", expected_output="Done", callback=callback
|
||||
description=f"Task for {name}",
|
||||
expected_output="Done",
|
||||
agent=agent,
|
||||
callback=callback,
|
||||
)
|
||||
|
||||
return create_task
|
||||
@@ -34,10 +38,9 @@ def simple_task_factory():
|
||||
|
||||
@pytest.fixture
|
||||
def crew_factory(simple_agent_factory, simple_task_factory):
|
||||
def create_crew(name: str, task_callback: Callable = None) -> Crew:
|
||||
def create_crew(name: str, task_callback: Callable | None = None) -> Crew:
|
||||
agent = simple_agent_factory(name)
|
||||
task = simple_task_factory(name, callback=task_callback)
|
||||
task.agent = agent
|
||||
task = simple_task_factory(name, agent=agent, callback=task_callback)
|
||||
|
||||
return Crew(agents=[agent], tasks=[task], verbose=False)
|
||||
|
||||
@@ -50,7 +53,7 @@ class TestCrewThreadSafety:
|
||||
mock_execute_task.return_value = "Task completed"
|
||||
num_crews = 5
|
||||
|
||||
def run_crew_with_context_check(crew_id: str) -> Dict[str, Any]:
|
||||
def run_crew_with_context_check(crew_id: str) -> dict[str, Any]:
|
||||
results = {"crew_id": crew_id, "contexts": []}
|
||||
|
||||
def check_context_task(output):
|
||||
@@ -105,28 +108,28 @@ class TestCrewThreadSafety:
|
||||
before_ctx = next(
|
||||
ctx for ctx in result["contexts"] if ctx["stage"] == "before_kickoff"
|
||||
)
|
||||
assert (
|
||||
before_ctx["crew_id"] is None
|
||||
), f"Context should be None before kickoff for {result['crew_id']}"
|
||||
assert before_ctx["crew_id"] is None, (
|
||||
f"Context should be None before kickoff for {result['crew_id']}"
|
||||
)
|
||||
|
||||
task_ctx = next(
|
||||
ctx for ctx in result["contexts"] if ctx["stage"] == "task_callback"
|
||||
)
|
||||
assert (
|
||||
task_ctx["crew_id"] == crew_uuid
|
||||
), f"Context mismatch during task for {result['crew_id']}"
|
||||
assert task_ctx["crew_id"] == crew_uuid, (
|
||||
f"Context mismatch during task for {result['crew_id']}"
|
||||
)
|
||||
|
||||
after_ctx = next(
|
||||
ctx for ctx in result["contexts"] if ctx["stage"] == "after_kickoff"
|
||||
)
|
||||
assert (
|
||||
after_ctx["crew_id"] is None
|
||||
), f"Context should be None after kickoff for {result['crew_id']}"
|
||||
assert after_ctx["crew_id"] is None, (
|
||||
f"Context should be None after kickoff for {result['crew_id']}"
|
||||
)
|
||||
|
||||
thread_name = before_ctx["thread"]
|
||||
assert (
|
||||
"ThreadPoolExecutor" in thread_name
|
||||
), f"Should run in thread pool for {result['crew_id']}"
|
||||
assert "ThreadPoolExecutor" in thread_name, (
|
||||
f"Should run in thread pool for {result['crew_id']}"
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("crewai.Agent.execute_task")
|
||||
@@ -134,7 +137,7 @@ class TestCrewThreadSafety:
|
||||
mock_execute_task.return_value = "Task completed"
|
||||
num_crews = 5
|
||||
|
||||
async def run_crew_async(crew_id: str) -> Dict[str, Any]:
|
||||
async def run_crew_async(crew_id: str) -> dict[str, Any]:
|
||||
task_context = {"crew_id": crew_id, "context": None}
|
||||
|
||||
def capture_context(output):
|
||||
@@ -162,12 +165,12 @@ class TestCrewThreadSafety:
|
||||
crew_uuid = result["crew_uuid"]
|
||||
task_ctx = result["task_context"]["context"]
|
||||
|
||||
assert (
|
||||
task_ctx is not None
|
||||
), f"Context should exist during task for {result['crew_id']}"
|
||||
assert (
|
||||
task_ctx["crew_id"] == crew_uuid
|
||||
), f"Context mismatch for {result['crew_id']}"
|
||||
assert task_ctx is not None, (
|
||||
f"Context should exist during task for {result['crew_id']}"
|
||||
)
|
||||
assert task_ctx["crew_id"] == crew_uuid, (
|
||||
f"Context mismatch for {result['crew_id']}"
|
||||
)
|
||||
|
||||
@patch("crewai.Agent.execute_task")
|
||||
def test_concurrent_kickoff_for_each(self, mock_execute_task, crew_factory):
|
||||
@@ -193,9 +196,9 @@ class TestCrewThreadSafety:
|
||||
assert len(contexts_captured) == len(inputs)
|
||||
|
||||
context_ids = [ctx["context_id"] for ctx in contexts_captured]
|
||||
assert len(set(context_ids)) == len(
|
||||
inputs
|
||||
), "Each execution should have unique context"
|
||||
assert len(set(context_ids)) == len(inputs), (
|
||||
"Each execution should have unique context"
|
||||
)
|
||||
|
||||
@patch("crewai.Agent.execute_task")
|
||||
def test_no_context_leakage_between_crews(self, mock_execute_task, crew_factory):
|
||||
|
||||
Reference in New Issue
Block a user