fix: preserve task outputs across async batch flush

This commit is contained in:
Greyson LaLonde
2026-05-04 20:24:24 +08:00
committed by GitHub
parent a23e118b11
commit f579aa53ae
2 changed files with 126 additions and 7 deletions

View File

@@ -1283,8 +1283,8 @@ class Crew(FlowTrackable, BaseModel):
pending_tasks.append((task, async_task, task_index))
else:
if pending_tasks:
task_outputs = await self._aprocess_async_tasks(
pending_tasks, was_replayed
task_outputs.extend(
await self._aprocess_async_tasks(pending_tasks, was_replayed)
)
pending_tasks.clear()
@@ -1299,7 +1299,9 @@ class Crew(FlowTrackable, BaseModel):
self._store_execution_log(task, task_output, task_index, was_replayed)
if pending_tasks:
task_outputs = await self._aprocess_async_tasks(pending_tasks, was_replayed)
task_outputs.extend(
await self._aprocess_async_tasks(pending_tasks, was_replayed)
)
return self._create_crew_output(task_outputs)
@@ -1313,7 +1315,9 @@ class Crew(FlowTrackable, BaseModel):
) -> TaskOutput | None:
"""Handle conditional task evaluation using native async."""
if pending_tasks:
task_outputs = await self._aprocess_async_tasks(pending_tasks, was_replayed)
task_outputs.extend(
await self._aprocess_async_tasks(pending_tasks, was_replayed)
)
pending_tasks.clear()
return check_conditional_skip(
@@ -1489,7 +1493,9 @@ class Crew(FlowTrackable, BaseModel):
futures.append((task, future, task_index))
else:
if futures:
task_outputs = self._process_async_tasks(futures, was_replayed)
task_outputs.extend(
self._process_async_tasks(futures, was_replayed)
)
futures.clear()
context = self._get_context(task, task_outputs)
@@ -1503,7 +1509,7 @@ class Crew(FlowTrackable, BaseModel):
self._store_execution_log(task, task_output, task_index, was_replayed)
if futures:
task_outputs = self._process_async_tasks(futures, was_replayed)
task_outputs.extend(self._process_async_tasks(futures, was_replayed))
return self._create_crew_output(task_outputs)
@@ -1516,7 +1522,7 @@ class Crew(FlowTrackable, BaseModel):
was_replayed: bool,
) -> TaskOutput | None:
if futures:
task_outputs = self._process_async_tasks(futures, was_replayed)
task_outputs.extend(self._process_async_tasks(futures, was_replayed))
futures.clear()
return check_conditional_skip(

View File

@@ -1254,6 +1254,119 @@ async def test_async_task_execution_call_count(researcher, writer):
assert mock_execute_sync.call_count == 1
def test_mixed_sync_async_task_outputs_not_dropped(researcher, writer):
"""Sync outputs accumulated before a pending async batch must survive the flush."""
sync1_output = TaskOutput(description="sync1", raw="s1", agent="researcher")
async1_output = TaskOutput(description="async1", raw="a1", agent="researcher")
sync2_output = TaskOutput(description="sync2", raw="s2", agent="writer")
sync1 = Task(description="sync1", expected_output="x", agent=researcher)
async1 = Task(
description="async1",
expected_output="x",
agent=researcher,
async_execution=True,
)
sync2 = Task(description="sync2", expected_output="x", agent=writer)
sync1.output = sync1_output
async1.output = async1_output
sync2.output = sync2_output
crew = Crew(agents=[researcher, writer], tasks=[sync1, async1, sync2])
mock_future = MagicMock(spec=Future)
mock_future.result.return_value = async1_output
with (
patch.object(
Task, "execute_sync", side_effect=[sync1_output, sync2_output]
),
patch.object(Task, "execute_async", return_value=mock_future),
):
result = crew.kickoff()
assert [o.raw for o in result.tasks_output] == ["s1", "a1", "s2"]
@pytest.mark.asyncio
async def test_mixed_sync_async_task_outputs_not_dropped_native_async(
researcher, writer
):
"""Same regression as the sync path, exercised via akickoff (native async)."""
sync1_output = TaskOutput(description="sync1", raw="s1", agent="researcher")
async1_output = TaskOutput(description="async1", raw="a1", agent="researcher")
sync2_output = TaskOutput(description="sync2", raw="s2", agent="writer")
sync1 = Task(description="sync1", expected_output="x", agent=researcher)
async1 = Task(
description="async1",
expected_output="x",
agent=researcher,
async_execution=True,
)
sync2 = Task(description="sync2", expected_output="x", agent=writer)
sync1.output = sync1_output
async1.output = async1_output
sync2.output = sync2_output
crew = Crew(agents=[researcher, writer], tasks=[sync1, async1, sync2])
aexecute_outputs = iter([sync1_output, async1_output, sync2_output])
async def fake_aexecute_sync(*_args: Any, **_kwargs: Any) -> TaskOutput:
return next(aexecute_outputs)
with patch.object(Task, "aexecute_sync", side_effect=fake_aexecute_sync):
result = await crew.akickoff()
assert [o.raw for o in result.tasks_output] == ["s1", "a1", "s2"]
def test_pending_async_outputs_preserved_through_conditional_task(researcher, writer):
"""A conditional task encountered after a pending async batch must not silently drop the async output."""
sync1_output = TaskOutput(description="sync1", raw="s1", agent="researcher")
async1_output = TaskOutput(description="async1", raw="a1", agent="researcher")
def always_skip(_: TaskOutput) -> bool:
return False
sync1 = Task(description="sync1", expected_output="x", agent=researcher)
async1 = Task(
description="async1",
expected_output="x",
agent=researcher,
async_execution=True,
)
conditional = ConditionalTask(
description="conditional",
expected_output="x",
agent=writer,
condition=always_skip,
)
sync1.output = sync1_output
async1.output = async1_output
crew = Crew(
agents=[researcher, writer], tasks=[sync1, async1, conditional]
)
mock_future = MagicMock(spec=Future)
mock_future.result.return_value = async1_output
with (
patch.object(Task, "execute_sync", return_value=sync1_output),
patch.object(Task, "execute_async", return_value=mock_future),
):
result = crew.kickoff()
raws = [o.raw for o in result.tasks_output]
assert raws[:2] == ["s1", "a1"]
assert len(result.tasks_output) == 3
@pytest.mark.vcr()
def test_kickoff_for_each_single_input():
"""Tests if kickoff_for_each works with a single input."""