Improve kickoff_for_each_parallel based on PR feedback

Co-Authored-By: Joe Moura <joao@crewai.com>
Commit 15d59cfc34 (parent fd18bdfabb)
Author: Devin AI
Date: 2025-03-19 17:12:09 +00:00

View File

@@ -7,7 +7,7 @@ import warnings
 from concurrent.futures import Future
 from copy import copy as shallow_copy
 from hashlib import md5
-from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
+from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Tuple, Union

 from pydantic import (
     UUID4,
@@ -709,22 +709,21 @@ class Crew(BaseModel):
         self._task_output_handler.reset()
         return results

-    def kickoff_for_each_parallel(self, inputs: List[Dict[str, Any]], max_workers: Optional[int] = None) -> List[CrewOutput]:
+    def kickoff_for_each_parallel(self, inputs: Sequence[Dict[str, Any]], max_workers: Optional[int] = None) -> List[CrewOutput]:
         """Executes the Crew's workflow for each input in the list in parallel using ThreadPoolExecutor.

         Args:
-            inputs: List of input dictionaries to be passed to each crew execution.
+            inputs: Sequence of input dictionaries to be passed to each crew execution.
             max_workers: Maximum number of worker threads to use. If None, uses the default
                 ThreadPoolExecutor behavior (typically min(32, os.cpu_count() + 4)).

         Returns:
             List of CrewOutput objects, one for each input.
         """
-        import concurrent.futures
-        from concurrent.futures import ThreadPoolExecutor
+        from concurrent.futures import ThreadPoolExecutor, as_completed

-        if not isinstance(inputs, list):
-            raise TypeError("Inputs must be a list of dictionaries.")
+        if not isinstance(inputs, (list, tuple)):
+            raise TypeError(f"Inputs must be a list of dictionaries. Received {type(inputs).__name__} instead.")

         if not inputs:
             return []
@@ -738,6 +737,7 @@ class Crew(BaseModel):
         crew_copies = [self.copy() for _ in inputs]

         # Execute each crew in parallel
+        try:
             with ThreadPoolExecutor(max_workers=max_workers) as executor:
                 # Submit all tasks to the executor
                 future_to_crew = {
@@ -746,7 +746,7 @@ class Crew(BaseModel):
                 }

                 # Process results as they complete
-                for future in concurrent.futures.as_completed(future_to_crew):
+                for future in as_completed(future_to_crew):
                     crew_index = future_to_crew[future]
                     try:
                         output = future.result()
@@ -758,6 +758,9 @@ class Crew(BaseModel):
                     except Exception as exc:
                         # Re-raise the exception to maintain consistent behavior with kickoff_for_each
                         raise exc
+        finally:
+            # Clean up to assist garbage collection
+            crew_copies.clear()

         # Set the aggregated metrics on the parent crew
         self.usage_metrics = total_usage_metrics