mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-07-02 05:38:12 +00:00
fix: preserve task outputs across async batch flush
This commit is contained in:
@@ -1283,8 +1283,8 @@ class Crew(FlowTrackable, BaseModel):
|
||||
pending_tasks.append((task, async_task, task_index))
|
||||
else:
|
||||
if pending_tasks:
|
||||
task_outputs = await self._aprocess_async_tasks(
|
||||
pending_tasks, was_replayed
|
||||
task_outputs.extend(
|
||||
await self._aprocess_async_tasks(pending_tasks, was_replayed)
|
||||
)
|
||||
pending_tasks.clear()
|
||||
|
||||
@@ -1299,7 +1299,9 @@ class Crew(FlowTrackable, BaseModel):
|
||||
self._store_execution_log(task, task_output, task_index, was_replayed)
|
||||
|
||||
if pending_tasks:
|
||||
task_outputs = await self._aprocess_async_tasks(pending_tasks, was_replayed)
|
||||
task_outputs.extend(
|
||||
await self._aprocess_async_tasks(pending_tasks, was_replayed)
|
||||
)
|
||||
|
||||
return self._create_crew_output(task_outputs)
|
||||
|
||||
@@ -1313,7 +1315,9 @@ class Crew(FlowTrackable, BaseModel):
|
||||
) -> TaskOutput | None:
|
||||
"""Handle conditional task evaluation using native async."""
|
||||
if pending_tasks:
|
||||
task_outputs = await self._aprocess_async_tasks(pending_tasks, was_replayed)
|
||||
task_outputs.extend(
|
||||
await self._aprocess_async_tasks(pending_tasks, was_replayed)
|
||||
)
|
||||
pending_tasks.clear()
|
||||
|
||||
return check_conditional_skip(
|
||||
@@ -1489,7 +1493,9 @@ class Crew(FlowTrackable, BaseModel):
|
||||
futures.append((task, future, task_index))
|
||||
else:
|
||||
if futures:
|
||||
task_outputs = self._process_async_tasks(futures, was_replayed)
|
||||
task_outputs.extend(
|
||||
self._process_async_tasks(futures, was_replayed)
|
||||
)
|
||||
futures.clear()
|
||||
|
||||
context = self._get_context(task, task_outputs)
|
||||
@@ -1503,7 +1509,7 @@ class Crew(FlowTrackable, BaseModel):
|
||||
self._store_execution_log(task, task_output, task_index, was_replayed)
|
||||
|
||||
if futures:
|
||||
task_outputs = self._process_async_tasks(futures, was_replayed)
|
||||
task_outputs.extend(self._process_async_tasks(futures, was_replayed))
|
||||
|
||||
return self._create_crew_output(task_outputs)
|
||||
|
||||
@@ -1516,7 +1522,7 @@ class Crew(FlowTrackable, BaseModel):
|
||||
was_replayed: bool,
|
||||
) -> TaskOutput | None:
|
||||
if futures:
|
||||
task_outputs = self._process_async_tasks(futures, was_replayed)
|
||||
task_outputs.extend(self._process_async_tasks(futures, was_replayed))
|
||||
futures.clear()
|
||||
|
||||
return check_conditional_skip(
|
||||
|
||||
@@ -1254,6 +1254,119 @@ async def test_async_task_execution_call_count(researcher, writer):
|
||||
assert mock_execute_sync.call_count == 1
|
||||
|
||||
|
||||
def test_mixed_sync_async_task_outputs_not_dropped(researcher, writer):
|
||||
"""Sync outputs accumulated before a pending async batch must survive the flush."""
|
||||
sync1_output = TaskOutput(description="sync1", raw="s1", agent="researcher")
|
||||
async1_output = TaskOutput(description="async1", raw="a1", agent="researcher")
|
||||
sync2_output = TaskOutput(description="sync2", raw="s2", agent="writer")
|
||||
|
||||
sync1 = Task(description="sync1", expected_output="x", agent=researcher)
|
||||
async1 = Task(
|
||||
description="async1",
|
||||
expected_output="x",
|
||||
agent=researcher,
|
||||
async_execution=True,
|
||||
)
|
||||
sync2 = Task(description="sync2", expected_output="x", agent=writer)
|
||||
|
||||
sync1.output = sync1_output
|
||||
async1.output = async1_output
|
||||
sync2.output = sync2_output
|
||||
|
||||
crew = Crew(agents=[researcher, writer], tasks=[sync1, async1, sync2])
|
||||
|
||||
mock_future = MagicMock(spec=Future)
|
||||
mock_future.result.return_value = async1_output
|
||||
|
||||
with (
|
||||
patch.object(
|
||||
Task, "execute_sync", side_effect=[sync1_output, sync2_output]
|
||||
),
|
||||
patch.object(Task, "execute_async", return_value=mock_future),
|
||||
):
|
||||
result = crew.kickoff()
|
||||
|
||||
assert [o.raw for o in result.tasks_output] == ["s1", "a1", "s2"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_mixed_sync_async_task_outputs_not_dropped_native_async(
|
||||
researcher, writer
|
||||
):
|
||||
"""Same regression as the sync path, exercised via akickoff (native async)."""
|
||||
sync1_output = TaskOutput(description="sync1", raw="s1", agent="researcher")
|
||||
async1_output = TaskOutput(description="async1", raw="a1", agent="researcher")
|
||||
sync2_output = TaskOutput(description="sync2", raw="s2", agent="writer")
|
||||
|
||||
sync1 = Task(description="sync1", expected_output="x", agent=researcher)
|
||||
async1 = Task(
|
||||
description="async1",
|
||||
expected_output="x",
|
||||
agent=researcher,
|
||||
async_execution=True,
|
||||
)
|
||||
sync2 = Task(description="sync2", expected_output="x", agent=writer)
|
||||
|
||||
sync1.output = sync1_output
|
||||
async1.output = async1_output
|
||||
sync2.output = sync2_output
|
||||
|
||||
crew = Crew(agents=[researcher, writer], tasks=[sync1, async1, sync2])
|
||||
|
||||
aexecute_outputs = iter([sync1_output, async1_output, sync2_output])
|
||||
|
||||
async def fake_aexecute_sync(*_args: Any, **_kwargs: Any) -> TaskOutput:
|
||||
return next(aexecute_outputs)
|
||||
|
||||
with patch.object(Task, "aexecute_sync", side_effect=fake_aexecute_sync):
|
||||
result = await crew.akickoff()
|
||||
|
||||
assert [o.raw for o in result.tasks_output] == ["s1", "a1", "s2"]
|
||||
|
||||
|
||||
def test_pending_async_outputs_preserved_through_conditional_task(researcher, writer):
|
||||
"""A conditional task encountered after a pending async batch must not silently drop the async output."""
|
||||
sync1_output = TaskOutput(description="sync1", raw="s1", agent="researcher")
|
||||
async1_output = TaskOutput(description="async1", raw="a1", agent="researcher")
|
||||
|
||||
def always_skip(_: TaskOutput) -> bool:
|
||||
return False
|
||||
|
||||
sync1 = Task(description="sync1", expected_output="x", agent=researcher)
|
||||
async1 = Task(
|
||||
description="async1",
|
||||
expected_output="x",
|
||||
agent=researcher,
|
||||
async_execution=True,
|
||||
)
|
||||
conditional = ConditionalTask(
|
||||
description="conditional",
|
||||
expected_output="x",
|
||||
agent=writer,
|
||||
condition=always_skip,
|
||||
)
|
||||
|
||||
sync1.output = sync1_output
|
||||
async1.output = async1_output
|
||||
|
||||
crew = Crew(
|
||||
agents=[researcher, writer], tasks=[sync1, async1, conditional]
|
||||
)
|
||||
|
||||
mock_future = MagicMock(spec=Future)
|
||||
mock_future.result.return_value = async1_output
|
||||
|
||||
with (
|
||||
patch.object(Task, "execute_sync", return_value=sync1_output),
|
||||
patch.object(Task, "execute_async", return_value=mock_future),
|
||||
):
|
||||
result = crew.kickoff()
|
||||
|
||||
raws = [o.raw for o in result.tasks_output]
|
||||
assert raws[:2] == ["s1", "a1"]
|
||||
assert len(result.tasks_output) == 3
|
||||
|
||||
|
||||
@pytest.mark.vcr()
|
||||
def test_kickoff_for_each_single_input():
|
||||
"""Tests if kickoff_for_each works with a single input."""
|
||||
|
||||
Reference in New Issue
Block a user