From 04de7730fa202e4094f69b2bef1884023a33e89c Mon Sep 17 00:00:00 2001 From: Brandon Hancock Date: Mon, 29 Jul 2024 12:44:10 -0400 Subject: [PATCH] Add doc strings --- src/crewai/pipeline/pipeline.py | 119 +++++++++++++++++++++++++++++++- 1 file changed, 118 insertions(+), 1 deletion(-) diff --git a/src/crewai/pipeline/pipeline.py b/src/crewai/pipeline/pipeline.py index a83e438de..a0efdbb62 100644 --- a/src/crewai/pipeline/pipeline.py +++ b/src/crewai/pipeline/pipeline.py @@ -83,6 +83,15 @@ class Pipeline(BaseModel): @model_validator(mode="before") @classmethod def validate_stages(cls, values): + """ + Validates the stages to ensure correct nesting and types. + + Args: + values (dict): Dictionary containing the pipeline stages. + + Returns: + dict: Validated stages. + """ stages = values.get("stages", []) def check_nesting_and_type(item, depth=0): @@ -104,7 +113,13 @@ class Pipeline(BaseModel): self, run_inputs: List[Dict[str, Any]] ) -> List[PipelineRunResult]: """ - Process multiple runs in parallel, with each run going through all stages. + Processes multiple runs in parallel, each going through all pipeline stages. + + Args: + run_inputs (List[Dict[str, Any]]): List of inputs for each run. + + Returns: + List[PipelineRunResult]: List of results from each run. """ pipeline_results = [] @@ -123,6 +138,15 @@ class Pipeline(BaseModel): async def process_single_run( self, run_input: Dict[str, Any] ) -> List[PipelineRunResult]: + """ + Processes a single run through all pipeline stages. + + Args: + run_input (Dict[str, Any]): The input for the run. + + Returns: + List[PipelineRunResult]: The results of processing the run. + """ initial_input = copy.deepcopy(run_input) current_input = copy.deepcopy(run_input) usage_metrics = {} @@ -146,6 +170,16 @@ class Pipeline(BaseModel): async def _process_stage( self, stage: Union[Crew, List[Crew]], current_input: Dict[str, Any] ) -> Tuple[List[CrewOutput], List[Union[str, Dict[str, Any]]]]: + """ + Processes a single stage of the pipeline, which can be either sequential or parallel. + + Args: + stage (Union[Crew, List[Crew]]): The stage to process. + current_input (Dict[str, Any]): The input for the stage. + + Returns: + Tuple[List[CrewOutput], List[Union[str, Dict[str, Any]]]]: The outputs and trace of the stage. + """ if isinstance(stage, Crew): return await self._process_single_crew(stage, current_input) else: @@ -154,12 +188,32 @@ class Pipeline(BaseModel): async def _process_single_crew( self, crew: Crew, current_input: Dict[str, Any] ) -> Tuple[List[CrewOutput], List[Union[str, Dict[str, Any]]]]: + """ + Processes a single crew. + + Args: + crew (Crew): The crew to process. + current_input (Dict[str, Any]): The input for the crew. + + Returns: + Tuple[List[CrewOutput], List[Union[str, Dict[str, Any]]]]: The output and trace of the crew. + """ output = await crew.kickoff_async(inputs=current_input) return [output], [crew.name or str(crew.id)] async def _process_parallel_crews( self, crews: List[Crew], current_input: Dict[str, Any] ) -> Tuple[List[CrewOutput], List[Union[str, Dict[str, Any]]]]: + """ + Processes multiple crews in parallel. + + Args: + crews (List[Crew]): The list of crews to process in parallel. + current_input (Dict[str, Any]): The input for the crews. + + Returns: + Tuple[List[CrewOutput], List[Union[str, Dict[str, Any]]]]: The outputs and traces of the crews. + """ parallel_outputs = await asyncio.gather( *[crew.kickoff_async(inputs=current_input) for crew in crews] ) @@ -172,6 +226,15 @@ class Pipeline(BaseModel): stage: Union[Crew, List[Crew]], outputs: List[CrewOutput], ) -> None: + """ + Updates metrics and current input with the outputs of a stage. + + Args: + usage_metrics (Dict[str, Any]): The usage metrics to update. + current_input (Dict[str, Any]): The current input to update. + stage (Union[Crew, List[Crew]]): The stage that was processed. + outputs (List[CrewOutput]): The outputs of the stage. + """ for crew, output in zip([stage] if isinstance(stage, Crew) else stage, outputs): usage_metrics[crew.name or str(crew.id)] = output.token_usage current_input.update(output.to_dict()) @@ -182,6 +245,17 @@ class Pipeline(BaseModel): traces: List[List[Union[str, Dict[str, Any]]]], token_usage: Dict[str, Any], ) -> List[PipelineRunResult]: + """ + Builds the results of a pipeline run. + + Args: + all_stage_outputs (List[List[CrewOutput]]): All stage outputs. + traces (List[List[Union[str, Dict[str, Any]]]]): All traces. + token_usage (Dict[str, Any]): Token usage metrics. + + Returns: + List[PipelineRunResult]: The results of the pipeline run. + """ formatted_traces = self._format_traces(traces) formatted_crew_outputs = self._format_crew_outputs(all_stage_outputs) @@ -202,12 +276,30 @@ class Pipeline(BaseModel): def _format_traces( self, traces: List[List[Union[str, Dict[str, Any]]]] ) -> List[List[Trace]]: + """ + Formats the traces of a pipeline run. + + Args: + traces (List[List[Union[str, Dict[str, Any]]]]): The traces to format. + + Returns: + List[List[Trace]]: The formatted traces. + """ formatted_traces: List[Trace] = self._format_single_trace(traces[:-1]) return self._format_multiple_traces(formatted_traces, traces[-1]) def _format_single_trace( self, traces: List[List[Union[str, Dict[str, Any]]]] ) -> List[Trace]: + """ + Formats single traces. + + Args: + traces (List[List[Union[str, Dict[str, Any]]]]): The traces to format. + + Returns: + List[Trace]: The formatted single traces. + """ formatted_traces: List[Trace] = [] for trace in traces: formatted_traces.append(trace[0] if len(trace) == 1 else trace) @@ -218,6 +310,16 @@ class Pipeline(BaseModel): formatted_traces: List[Trace], final_trace: List[Union[str, Dict[str, Any]]], ) -> List[List[Trace]]: + """ + Formats multiple traces. + + Args: + formatted_traces (List[Trace]): The formatted single traces. + final_trace (List[Union[str, Dict[str, Any]]]): The final trace to format. + + Returns: + List[List[Trace]]: The formatted multiple traces. + """ traces_to_return: List[List[Trace]] = [] if len(final_trace) == 1: formatted_traces.append(final_trace[0]) @@ -232,6 +334,15 @@ class Pipeline(BaseModel): def _format_crew_outputs( self, all_stage_outputs: List[List[CrewOutput]] ) -> List[List[CrewOutput]]: + """ + Formats the outputs of all stages into a list of crew outputs. + + Args: + all_stage_outputs (List[List[CrewOutput]]): All stage outputs. + + Returns: + List[List[CrewOutput]]: Formatted crew outputs. + """ crew_outputs: List[CrewOutput] = [ output for stage_outputs in all_stage_outputs[:-1] @@ -242,6 +353,12 @@ class Pipeline(BaseModel): def __rshift__(self, other: Any) -> "Pipeline": """ Implements the >> operator to add another Stage (Crew or List[Crew]) to an existing Pipeline. + + Args: + other (Any): The stage to add. + + Returns: + Pipeline: A new pipeline with the added stage. """ if isinstance(other, Crew): return type(self)(stages=self.stages + [other])