mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-10 00:28:31 +00:00
Update validators and tests
This commit is contained in:
@@ -276,15 +276,39 @@ class Crew(BaseModel):
|
|||||||
return self
|
return self
|
||||||
|
|
||||||
@model_validator(mode="after")
|
@model_validator(mode="after")
|
||||||
def validate_async_task_cannot_include_async_tasks_in_context(self):
|
def validate_async_task_cannot_include_sequential_async_tasks_in_context(self):
|
||||||
"""Validates that if a task is set to be executed asynchronously, it cannot include other asynchronous tasks in its context."""
|
"""
|
||||||
for task in self.tasks:
|
Validates that if a task is set to be executed asynchronously,
|
||||||
|
it cannot include other asynchronous tasks in its context unless
|
||||||
|
separated by a synchronous task.
|
||||||
|
"""
|
||||||
|
for i, task in enumerate(self.tasks):
|
||||||
if task.async_execution and task.context:
|
if task.async_execution and task.context:
|
||||||
async_tasks_in_context = [t for t in task.context if t.async_execution]
|
for context_task in task.context:
|
||||||
if async_tasks_in_context:
|
if context_task.async_execution:
|
||||||
raise ValueError(
|
for j in range(i - 1, -1, -1):
|
||||||
f"Task '{task.description}' is asynchronous and cannot include other asynchronous tasks in its context."
|
if self.tasks[j] == context_task:
|
||||||
)
|
raise ValueError(
|
||||||
|
f"Task '{task.description}' is asynchronous and cannot include other sequential asynchronous tasks in its context."
|
||||||
|
)
|
||||||
|
if not self.tasks[j].async_execution:
|
||||||
|
break
|
||||||
|
return self
|
||||||
|
|
||||||
|
@model_validator(mode="after")
|
||||||
|
def validate_context_no_future_tasks(self):
|
||||||
|
"""Validates that a task's context does not include future tasks."""
|
||||||
|
task_indices = {id(task): i for i, task in enumerate(self.tasks)}
|
||||||
|
|
||||||
|
for task in self.tasks:
|
||||||
|
if task.context:
|
||||||
|
for context_task in task.context:
|
||||||
|
if id(context_task) not in task_indices:
|
||||||
|
continue # Skip context tasks not in the main tasks list
|
||||||
|
if task_indices[id(context_task)] > task_indices[id(task)]:
|
||||||
|
raise ValueError(
|
||||||
|
f"Task '{task.description}' has a context dependency on a future task '{context_task.description}', which is not allowed."
|
||||||
|
)
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def _setup_from_config(self):
|
def _setup_from_config(self):
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ from unittest.mock import patch
|
|||||||
|
|
||||||
import pydantic_core
|
import pydantic_core
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from crewai.agent import Agent
|
from crewai.agent import Agent
|
||||||
from crewai.agents.cache import CacheHandler
|
from crewai.agents.cache import CacheHandler
|
||||||
from crewai.crew import Crew
|
from crewai.crew import Crew
|
||||||
@@ -86,7 +87,7 @@ def test_crew_config_conditional_requirement():
|
|||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
def test_async_task_context_validation():
|
def test_async_task_cannot_include_sequential_async_tasks_in_context():
|
||||||
task1 = Task(
|
task1 = Task(
|
||||||
description="Task 1",
|
description="Task 1",
|
||||||
async_execution=True,
|
async_execution=True,
|
||||||
@@ -102,15 +103,69 @@ def test_async_task_context_validation():
|
|||||||
)
|
)
|
||||||
task3 = Task(
|
task3 = Task(
|
||||||
description="Task 3",
|
description="Task 3",
|
||||||
|
async_execution=True,
|
||||||
|
expected_output="output",
|
||||||
|
agent=researcher,
|
||||||
|
context=[task2],
|
||||||
|
)
|
||||||
|
task4 = Task(
|
||||||
|
description="Task 4",
|
||||||
|
expected_output="output",
|
||||||
|
agent=writer,
|
||||||
|
)
|
||||||
|
task5 = Task(
|
||||||
|
description="Task 5",
|
||||||
|
async_execution=True,
|
||||||
|
expected_output="output",
|
||||||
|
agent=researcher,
|
||||||
|
context=[task4],
|
||||||
|
)
|
||||||
|
|
||||||
|
# This should raise an error because task2 is async and has task1 in its context without a sync task in between
|
||||||
|
with pytest.raises(
|
||||||
|
ValueError,
|
||||||
|
match="Task 'Task 2' is asynchronous and cannot include other sequential asynchronous tasks in its context.",
|
||||||
|
):
|
||||||
|
Crew(tasks=[task1, task2, task3, task4, task5], agents=[researcher, writer])
|
||||||
|
|
||||||
|
# This should not raise an error because task5 has a sync task (task4) in its context
|
||||||
|
try:
|
||||||
|
Crew(tasks=[task1, task4, task5], agents=[researcher, writer])
|
||||||
|
except ValueError:
|
||||||
|
pytest.fail("Unexpected ValidationError raised")
|
||||||
|
|
||||||
|
|
||||||
|
def test_context_no_future_tasks():
|
||||||
|
|
||||||
|
task2 = Task(
|
||||||
|
description="Task 2",
|
||||||
expected_output="output",
|
expected_output="output",
|
||||||
agent=researcher,
|
agent=researcher,
|
||||||
)
|
)
|
||||||
|
task3 = Task(
|
||||||
|
description="Task 3",
|
||||||
|
expected_output="output",
|
||||||
|
agent=researcher,
|
||||||
|
context=[task2],
|
||||||
|
)
|
||||||
|
task4 = Task(
|
||||||
|
description="Task 4",
|
||||||
|
expected_output="output",
|
||||||
|
agent=researcher,
|
||||||
|
)
|
||||||
|
task1 = Task(
|
||||||
|
description="Task 1",
|
||||||
|
expected_output="output",
|
||||||
|
agent=researcher,
|
||||||
|
context=[task4],
|
||||||
|
)
|
||||||
|
|
||||||
|
# This should raise an error because task1 has a context dependency on a future task (task4)
|
||||||
with pytest.raises(
|
with pytest.raises(
|
||||||
ValueError,
|
ValueError,
|
||||||
match="Task 'Task 2' is asynchronous and cannot include other asynchronous tasks in its context.",
|
match="Task 'Task 1' has a context dependency on a future task 'Task 4', which is not allowed.",
|
||||||
):
|
):
|
||||||
Crew(tasks=[task1, task2, task3], agents=[researcher, writer])
|
Crew(tasks=[task1, task2, task3, task4], agents=[researcher, writer])
|
||||||
|
|
||||||
|
|
||||||
def test_crew_config_with_wrong_keys():
|
def test_crew_config_with_wrong_keys():
|
||||||
|
|||||||
Reference in New Issue
Block a user