Mirror of https://github.com/crewAIInc/crewAI.git (synced 2026-01-08 23:58:34 +00:00)
Add kickoff_for_each_parallel method using ThreadPoolExecutor to fix issue #2406
Co-Authored-By: Joe Moura <joao@crewai.com>
@@ -1,4 +1,5 @@
import asyncio
import concurrent.futures
import json
import re
import uuid
@@ -707,6 +708,62 @@ class Crew(BaseModel):
        self.usage_metrics = total_usage_metrics
        self._task_output_handler.reset()
        return results

    def kickoff_for_each_parallel(self, inputs: List[Dict[str, Any]], max_workers: Optional[int] = None) -> List[CrewOutput]:
        """Executes the Crew's workflow for each input in the list in parallel using ThreadPoolExecutor.

        Args:
            inputs: List of input dictionaries to be passed to each crew execution.
            max_workers: Maximum number of worker threads to use. If None, uses the default
                ThreadPoolExecutor behavior (typically min(32, os.cpu_count() + 4)).

        Returns:
            List of CrewOutput objects, one for each input. Outputs are collected in
            completion order, which may differ from the order of the inputs.
        """
        import concurrent.futures
        from concurrent.futures import ThreadPoolExecutor

        if not isinstance(inputs, list):
            raise TypeError("Inputs must be a list of dictionaries.")

        if not inputs:
            return []

        results: List[CrewOutput] = []

        # Initialize the parent crew's usage metrics
        total_usage_metrics = UsageMetrics()

        # Create a copy of the crew for each input to avoid state conflicts
        crew_copies = [self.copy() for _ in inputs]

        # Execute each crew in parallel
        with ThreadPoolExecutor(max_workers=max_workers) as executor:
            # Submit all tasks to the executor
            future_to_crew = {
                executor.submit(crew_copies[i].kickoff, inputs[i]): i
                for i in range(len(inputs))
            }

            # Process results as they complete
            for future in concurrent.futures.as_completed(future_to_crew):
                crew_index = future_to_crew[future]
                try:
                    output = future.result()
                    results.append(output)

                    # Aggregate usage metrics
                    if crew_copies[crew_index].usage_metrics:
                        total_usage_metrics.add_usage_metrics(crew_copies[crew_index].usage_metrics)
                except Exception as exc:
                    # Re-raise the exception to maintain consistent behavior with kickoff_for_each
                    raise exc

        # Set the aggregated metrics on the parent crew
        self.usage_metrics = total_usage_metrics
        self._task_output_handler.reset()

        return results

    def _handle_crew_planning(self):
        """Handles the Crew planning."""
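For reference, here is a minimal usage sketch of the new method (not part of the patch), mirroring the agent and task fixtures used in the tests below; running it for real would execute each crew copy against a configured LLM, so the definitions are illustrative only:

    from crewai.agent import Agent
    from crewai.crew import Crew
    from crewai.task import Task

    agent = Agent(
        role="{topic} Researcher",
        goal="Express hot takes on {topic}.",
        backstory="You have a lot of experience with {topic}.",
    )
    task = Task(
        description="Give me an analysis around {topic}.",
        expected_output="1 bullet point about {topic} that's under 15 words.",
        agent=agent,
    )
    crew = Crew(agents=[agent], tasks=[task])

    # One isolated crew copy per input, executed across at most two worker threads.
    outputs = crew.kickoff_for_each_parallel(
        inputs=[{"topic": "dog"}, {"topic": "cat"}, {"topic": "apple"}],
        max_workers=2,
    )
    for output in outputs:
        print(output.raw)  # printed in completion order, not input order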
tests/test_kickoff_for_each_parallel.py (new file, 208 lines)
@@ -0,0 +1,208 @@
"""Test for the kickoff_for_each_parallel method in Crew class."""

import concurrent.futures

import pytest
from unittest.mock import patch, MagicMock

from crewai.agent import Agent
from crewai.crew import Crew
from crewai.crews.crew_output import CrewOutput
from crewai.task import Task

def test_kickoff_for_each_parallel_single_input():
    """Tests if kickoff_for_each_parallel works with a single input."""
    inputs = [{"topic": "dog"}]

    agent = Agent(
        role="{topic} Researcher",
        goal="Express hot takes on {topic}.",
        backstory="You have a lot of experience with {topic}.",
    )

    task = Task(
        description="Give me an analysis around {topic}.",
        expected_output="1 bullet point about {topic} that's under 15 words.",
        agent=agent,
    )

    crew = Crew(agents=[agent], tasks=[task])

    # Mock the kickoff method to avoid API calls
    expected_output = CrewOutput(raw="Dogs are loyal companions.")
    with patch.object(Crew, "kickoff", return_value=expected_output):
        results = crew.kickoff_for_each_parallel(inputs=inputs)

    assert len(results) == 1
    assert results[0].raw == "Dogs are loyal companions."

def test_kickoff_for_each_parallel_multiple_inputs():
    """Tests if kickoff_for_each_parallel works with multiple inputs."""
    inputs = [
        {"topic": "dog"},
        {"topic": "cat"},
        {"topic": "apple"},
    ]

    agent = Agent(
        role="{topic} Researcher",
        goal="Express hot takes on {topic}.",
        backstory="You have a lot of experience with {topic}.",
    )

    task = Task(
        description="Give me an analysis around {topic}.",
        expected_output="1 bullet point about {topic} that's under 15 words.",
        agent=agent,
    )

    crew = Crew(agents=[agent], tasks=[task])

    # Mock the kickoff method to avoid API calls
    expected_outputs = [
        CrewOutput(raw="Dogs are loyal companions."),
        CrewOutput(raw="Cats are independent pets."),
        CrewOutput(raw="Apples are nutritious fruits."),
    ]

    with patch.object(Crew, "copy") as mock_copy:
        # Setup mock crew copies
        crew_copies = []
        for i in range(len(inputs)):
            crew_copy = MagicMock()
            crew_copy.kickoff.return_value = expected_outputs[i]
            crew_copies.append(crew_copy)
        mock_copy.side_effect = crew_copies

        results = crew.kickoff_for_each_parallel(inputs=inputs)

    assert len(results) == len(inputs)
    # Since ThreadPoolExecutor returns results in completion order, not input order,
    # we just check that all expected outputs are in the results
    result_texts = [result.raw for result in results]
    expected_texts = [output.raw for output in expected_outputs]
    for expected_text in expected_texts:
        assert expected_text in result_texts

def test_kickoff_for_each_parallel_empty_input():
    """Tests if kickoff_for_each_parallel handles an empty input list."""
    agent = Agent(
        role="{topic} Researcher",
        goal="Express hot takes on {topic}.",
        backstory="You have a lot of experience with {topic}.",
    )

    task = Task(
        description="Give me an analysis around {topic}.",
        expected_output="1 bullet point about {topic} that's under 15 words.",
        agent=agent,
    )

    crew = Crew(agents=[agent], tasks=[task])
    results = crew.kickoff_for_each_parallel(inputs=[])
    assert results == []

def test_kickoff_for_each_parallel_invalid_input():
    """Tests if kickoff_for_each_parallel raises TypeError for invalid input types."""
    agent = Agent(
        role="{topic} Researcher",
        goal="Express hot takes on {topic}.",
        backstory="You have a lot of experience with {topic}.",
    )

    task = Task(
        description="Give me an analysis around {topic}.",
        expected_output="1 bullet point about {topic} that's under 15 words.",
        agent=agent,
    )

    crew = Crew(agents=[agent], tasks=[task])

    # No need to mock here since we're testing input validation which happens before any API calls
    with pytest.raises(TypeError):
        # Pass a string instead of a list
        crew.kickoff_for_each_parallel("invalid input")

def test_kickoff_for_each_parallel_error_handling():
    """Tests error handling in kickoff_for_each_parallel when kickoff raises an error."""
    inputs = [
        {"topic": "dog"},
        {"topic": "cat"},
        {"topic": "apple"},
    ]

    agent = Agent(
        role="{topic} Researcher",
        goal="Express hot takes on {topic}.",
        backstory="You have a lot of experience with {topic}.",
    )

    task = Task(
        description="Give me an analysis around {topic}.",
        expected_output="1 bullet point about {topic} that's under 15 words.",
        agent=agent,
    )

    crew = Crew(agents=[agent], tasks=[task])

    with patch.object(Crew, "copy") as mock_copy:
        # Setup mock crew copies
        crew_copies = []
        for i in range(len(inputs)):
            crew_copy = MagicMock()
            # Make the third crew copy raise an exception
            if i == 2:
                crew_copy.kickoff.side_effect = Exception("Simulated kickoff error")
            else:
                crew_copy.kickoff.return_value = f"Output for {inputs[i]['topic']}"
            crew_copies.append(crew_copy)
        mock_copy.side_effect = crew_copies

        with pytest.raises(Exception, match="Simulated kickoff error"):
            crew.kickoff_for_each_parallel(inputs=inputs)

def test_kickoff_for_each_parallel_max_workers():
    """Tests if kickoff_for_each_parallel respects the max_workers parameter."""
    inputs = [
        {"topic": "dog"},
        {"topic": "cat"},
        {"topic": "apple"},
    ]

    agent = Agent(
        role="{topic} Researcher",
        goal="Express hot takes on {topic}.",
        backstory="You have a lot of experience with {topic}.",
    )

    task = Task(
        description="Give me an analysis around {topic}.",
        expected_output="1 bullet point about {topic} that's under 15 words.",
        agent=agent,
    )

    crew = Crew(agents=[agent], tasks=[task])

    # Mock both ThreadPoolExecutor and crew.copy to avoid API calls
    with patch.object(concurrent.futures, "ThreadPoolExecutor", wraps=concurrent.futures.ThreadPoolExecutor) as mock_executor:
        with patch.object(Crew, "copy") as mock_copy:
            # Setup mock crew copies
            crew_copies = []
            for _ in range(len(inputs)):
                crew_copy = MagicMock()
                crew_copy.kickoff.return_value = CrewOutput(raw="Test output")
                crew_copies.append(crew_copy)
            mock_copy.side_effect = crew_copies

            crew.kickoff_for_each_parallel(inputs=inputs, max_workers=2)
            mock_executor.assert_called_once_with(max_workers=2)
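Because the helper collects outputs in completion order (as the multiple-inputs test above notes), callers that need results aligned with their inputs must preserve that mapping themselves. Below is a minimal sketch of one way to do so, assuming only that Crew.copy() and Crew.kickoff(inputs) behave as in the patch above; kickoff_in_input_order is a hypothetical helper, not part of the library:

    from concurrent.futures import ThreadPoolExecutor

    def kickoff_in_input_order(crew, inputs, max_workers=None):
        # Hypothetical helper: one independent crew copy per input, as in
        # kickoff_for_each_parallel, but results come back in input order
        # because executor.map preserves submission order.
        copies = [crew.copy() for _ in inputs]
        with ThreadPoolExecutor(max_workers=max_workers) as executor:
            return list(
                executor.map(lambda pair: pair[0].kickoff(pair[1]), zip(copies, inputs))
            )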