From 39d6a9a643c399fe76e495b4f498ab0469c8e5d3 Mon Sep 17 00:00:00 2001 From: Brandon Hancock Date: Wed, 10 Jul 2024 13:51:54 -0400 Subject: [PATCH] Update validators and tests --- src/crewai/crew.py | 40 ++++++++++++++++++++++++------ tests/crew_test.py | 61 +++++++++++++++++++++++++++++++++++++++++++--- 2 files changed, 90 insertions(+), 11 deletions(-) diff --git a/src/crewai/crew.py b/src/crewai/crew.py index 8256c53dc..954a8f583 100644 --- a/src/crewai/crew.py +++ b/src/crewai/crew.py @@ -276,15 +276,39 @@ class Crew(BaseModel): return self @model_validator(mode="after") - def validate_async_task_cannot_include_async_tasks_in_context(self): - """Validates that if a task is set to be executed asynchronously, it cannot include other asynchronous tasks in its context.""" - for task in self.tasks: + def validate_async_task_cannot_include_sequential_async_tasks_in_context(self): + """ + Validates that if a task is set to be executed asynchronously, + it cannot include other asynchronous tasks in its context unless + separated by a synchronous task. + """ + for i, task in enumerate(self.tasks): if task.async_execution and task.context: - async_tasks_in_context = [t for t in task.context if t.async_execution] - if async_tasks_in_context: - raise ValueError( - f"Task '{task.description}' is asynchronous and cannot include other asynchronous tasks in its context." - ) + for context_task in task.context: + if context_task.async_execution: + for j in range(i - 1, -1, -1): + if self.tasks[j] == context_task: + raise ValueError( + f"Task '{task.description}' is asynchronous and cannot include other sequential asynchronous tasks in its context." + ) + if not self.tasks[j].async_execution: + break + return self + + @model_validator(mode="after") + def validate_context_no_future_tasks(self): + """Validates that a task's context does not include future tasks.""" + task_indices = {id(task): i for i, task in enumerate(self.tasks)} + + for task in self.tasks: + if task.context: + for context_task in task.context: + if id(context_task) not in task_indices: + continue # Skip context tasks not in the main tasks list + if task_indices[id(context_task)] > task_indices[id(task)]: + raise ValueError( + f"Task '{task.description}' has a context dependency on a future task '{context_task.description}', which is not allowed." + ) return self def _setup_from_config(self): diff --git a/tests/crew_test.py b/tests/crew_test.py index 1d28bf0ce..d7d15e117 100644 --- a/tests/crew_test.py +++ b/tests/crew_test.py @@ -7,6 +7,7 @@ from unittest.mock import patch import pydantic_core import pytest + from crewai.agent import Agent from crewai.agents.cache import CacheHandler from crewai.crew import Crew @@ -86,7 +87,7 @@ def test_crew_config_conditional_requirement(): ] -def test_async_task_context_validation(): +def test_async_task_cannot_include_sequential_async_tasks_in_context(): task1 = Task( description="Task 1", async_execution=True, @@ -102,15 +103,69 @@ def test_async_task_context_validation(): ) task3 = Task( description="Task 3", + async_execution=True, + expected_output="output", + agent=researcher, + context=[task2], + ) + task4 = Task( + description="Task 4", + expected_output="output", + agent=writer, + ) + task5 = Task( + description="Task 5", + async_execution=True, + expected_output="output", + agent=researcher, + context=[task4], + ) + + # This should raise an error because task2 is async and has task1 in its context without a sync task in between + with pytest.raises( + ValueError, + match="Task 'Task 2' is asynchronous and cannot include other sequential asynchronous tasks in its context.", + ): + Crew(tasks=[task1, task2, task3, task4, task5], agents=[researcher, writer]) + + # This should not raise an error because task5 has a sync task (task4) in its context + try: + Crew(tasks=[task1, task4, task5], agents=[researcher, writer]) + except ValueError: + pytest.fail("Unexpected ValidationError raised") + + +def test_context_no_future_tasks(): + + task2 = Task( + description="Task 2", expected_output="output", agent=researcher, ) + task3 = Task( + description="Task 3", + expected_output="output", + agent=researcher, + context=[task2], + ) + task4 = Task( + description="Task 4", + expected_output="output", + agent=researcher, + ) + task1 = Task( + description="Task 1", + expected_output="output", + agent=researcher, + context=[task4], + ) + # This should raise an error because task1 has a context dependency on a future task (task4) with pytest.raises( ValueError, - match="Task 'Task 2' is asynchronous and cannot include other asynchronous tasks in its context.", + match="Task 'Task 1' has a context dependency on a future task 'Task 4', which is not allowed.", ): - Crew(tasks=[task1, task2, task3], agents=[researcher, writer]) + Crew(tasks=[task1, task2, task3, task4], agents=[researcher, writer]) def test_crew_config_with_wrong_keys():