Compare commits

..

8 Commits

Author SHA1 Message Date
João Moura
4aaa37d7af Merge branch 'main' into devin/1735591359-circular-import-fix 2024-12-30 22:15:32 -03:00
Devin AI
4604ac618d fix: rename edge_props to router_edge_props to fix type checker error
Co-Authored-By: Joe Moura <joao@crewai.com>
2024-12-30 21:18:29 +00:00
Devin AI
b0d1d86c26 fix: type checker errors and docstrings in flow modules
Co-Authored-By: Joe Moura <joao@crewai.com>
2024-12-30 21:15:51 +00:00
Devin AI
454ab55a26 fix: type annotations and import sorting in flow modules
Co-Authored-By: Joe Moura <joao@crewai.com>
2024-12-30 20:59:06 +00:00
Devin AI
613dd175ee fix: type checker errors in flow modules
Co-Authored-By: Joe Moura <joao@crewai.com>
2024-12-30 20:53:59 +00:00
Devin AI
1dc8ce2674 style: fix import sorting in flow modules
Co-Authored-By: Joe Moura <joao@crewai.com>
2024-12-30 20:47:54 +00:00
Devin AI
e0590e516f fix: break circular import by refactoring flow/visualizer/utils import structure
- Split utils.py into specialized modules (core_flow_utils.py, flow_visual_utils.py)
- Add path_utils.py for secure file path handling
- Update imports to prevent circular dependencies
- Use TYPE_CHECKING for type hints
- Fix import sorting issues

Co-Authored-By: Joe Moura <joao@crewai.com>
2024-12-30 20:42:39 +00:00
Marco Vinciguerra
8c6883e5ee feat: add docstring 2024-12-30 16:44:11 +01:00
12 changed files with 1162 additions and 1296 deletions

View File

@@ -0,0 +1,141 @@
"""Core utility functions for Flow class operations.
This module contains utility functions that are specifically designed to work
with the Flow class and require direct access to Flow class internals. These
utilities are separated from general-purpose utilities to maintain a clean
dependency structure and avoid circular imports.
Functions in this module are core to Flow functionality and are not related
to visualization or other optional features.
"""
import ast
import inspect
import textwrap
from typing import Any, Callable, Dict, List, Optional, Set, Union
from pydantic import BaseModel
def get_possible_return_constants(function: Callable[..., Any]) -> Optional[List[str]]:
"""Extract possible string return values from a function by analyzing its source code.
Analyzes the function's source code using AST to identify string constants that
could be returned, including strings stored in dictionaries and direct returns.
Args:
function: The function to analyze for possible return values
Returns:
list[str] | None: List of possible string return values, or None if:
- Source code cannot be retrieved
- Source code has syntax/indentation errors
- No string return values are found
Raises:
OSError: If source code cannot be retrieved
IndentationError: If source code has invalid indentation
SyntaxError: If source code has syntax errors
Example:
>>> def get_status():
... paths = {"success": "completed", "error": "failed"}
... return paths["success"]
>>> get_possible_return_constants(get_status)
['completed', 'failed']
"""
try:
source = inspect.getsource(function)
except OSError:
# Can't get source code
return None
except Exception as e:
print(f"Error retrieving source code for function {function.__name__}: {e}")
return None
try:
# Remove leading indentation
source = textwrap.dedent(source)
# Parse the source code into an AST
code_ast = ast.parse(source)
except IndentationError as e:
print(f"IndentationError while parsing source code of {function.__name__}: {e}")
print(f"Source code:\n{source}")
return None
except SyntaxError as e:
print(f"SyntaxError while parsing source code of {function.__name__}: {e}")
print(f"Source code:\n{source}")
return None
except Exception as e:
print(f"Unexpected error while parsing source code of {function.__name__}: {e}")
print(f"Source code:\n{source}")
return None
return_values = set()
dict_definitions = {}
class DictionaryAssignmentVisitor(ast.NodeVisitor):
def visit_Assign(self, node):
# Check if this assignment is assigning a dictionary literal to a variable
if isinstance(node.value, ast.Dict) and len(node.targets) == 1:
target = node.targets[0]
if isinstance(target, ast.Name):
var_name = target.id
dict_values = []
# Extract string values from the dictionary
for val in node.value.values:
if isinstance(val, ast.Constant) and isinstance(val.value, str):
dict_values.append(val.value)
# If non-string, skip or just ignore
if dict_values:
dict_definitions[var_name] = dict_values
self.generic_visit(node)
class ReturnVisitor(ast.NodeVisitor):
def visit_Return(self, node):
# Direct string return
if isinstance(node.value, ast.Constant) and isinstance(
node.value.value, str
):
return_values.add(node.value.value)
# Dictionary-based return, like return paths[result]
elif isinstance(node.value, ast.Subscript):
# Check if we're subscripting a known dictionary variable
if isinstance(node.value.value, ast.Name):
var_name = node.value.value.id
if var_name in dict_definitions:
# Add all possible dictionary values
for v in dict_definitions[var_name]:
return_values.add(v)
self.generic_visit(node)
# First pass: identify dictionary assignments
DictionaryAssignmentVisitor().visit(code_ast)
# Second pass: identify returns
ReturnVisitor().visit(code_ast)
return list(return_values) if return_values else None
def is_ancestor(node: str, ancestor_candidate: str, ancestors: Dict[str, Set[str]]) -> bool:
"""Check if one node is an ancestor of another in the flow graph.
Args:
node: Target node to check ancestors for
ancestor_candidate: Node to check if it's an ancestor
ancestors: Dictionary mapping nodes to their ancestor sets
Returns:
bool: True if ancestor_candidate is an ancestor of node
Raises:
TypeError: If any argument has an invalid type
"""
if not isinstance(node, str):
raise TypeError("Argument 'node' must be a string")
if not isinstance(ancestor_candidate, str):
raise TypeError("Argument 'ancestor_candidate' must be a string")
if not isinstance(ancestors, dict):
raise TypeError("Argument 'ancestors' must be a dictionary")
return ancestor_candidate in ancestors.get(node, set())

View File

@@ -17,59 +17,42 @@ from typing import (
from blinker import Signal
from pydantic import BaseModel, ValidationError
from crewai.flow.core_flow_utils import get_possible_return_constants
from crewai.flow.flow_events import (
FlowFinishedEvent,
FlowStartedEvent,
MethodExecutionFinishedEvent,
MethodExecutionStartedEvent,
)
from crewai.flow.flow_visualizer import plot_flow
from crewai.flow.utils import get_possible_return_constants
from crewai.telemetry import Telemetry
T = TypeVar("T", bound=Union[BaseModel, Dict[str, Any]])
def start(condition: Optional[Union[str, dict, Callable]] = None) -> Callable:
"""
Marks a method as a flow's starting point.
This decorator designates a method as an entry point for the flow execution.
It can optionally specify conditions that trigger the start based on other
method executions.
Parameters
----------
condition : Optional[Union[str, dict, Callable]], optional
Defines when the start method should execute. Can be:
- str: Name of a method that triggers this start
- dict: Contains "type" ("AND"/"OR") and "methods" (list of triggers)
- Callable: A method reference that triggers this start
Default is None, meaning unconditional start.
Returns
-------
Callable
A decorator function that marks the method as a flow start point.
Raises
------
ValueError
If the condition format is invalid.
Examples
--------
>>> @start() # Unconditional start
>>> def begin_flow(self):
... pass
>>> @start("method_name") # Start after specific method
>>> def conditional_start(self):
... pass
>>> @start(and_("method1", "method2")) # Start after multiple methods
>>> def complex_start(self):
... pass
"""Marks a method as a flow starting point, optionally triggered by other methods.
Args:
condition: The condition that triggers this method. Can be:
- str: Name of the triggering method
- dict: Dictionary with 'type' and 'methods' keys for complex conditions
- Callable: A function reference
- None: No trigger condition (default)
Returns:
Callable: The decorated function that will serve as a flow starting point.
Raises:
ValueError: If the condition format is invalid.
Example:
>>> @start() # No condition
>>> def begin_flow():
>>> pass
>>>
>>> @start("method_name") # Triggered by specific method
>>> def conditional_start():
>>> pass
"""
def decorator(func):
func.__is_start_method__ = True
@@ -95,41 +78,30 @@ def start(condition: Optional[Union[str, dict, Callable]] = None) -> Callable:
return decorator
def listen(condition: Union[str, dict, Callable]) -> Callable:
"""
Creates a listener that executes when specified conditions are met.
This decorator sets up a method to execute in response to other method
executions in the flow. It supports both simple and complex triggering
conditions.
Parameters
----------
condition : Union[str, dict, Callable]
Specifies when the listener should execute. Can be:
- str: Name of a method that triggers this listener
- dict: Contains "type" ("AND"/"OR") and "methods" (list of triggers)
- Callable: A method reference that triggers this listener
Returns
-------
Callable
A decorator function that sets up the method as a listener.
Raises
------
ValueError
If the condition format is invalid.
Examples
--------
>>> @listen("process_data") # Listen to single method
>>> def handle_processed_data(self):
... pass
>>> @listen(or_("success", "failure")) # Listen to multiple methods
>>> def handle_completion(self):
... pass
"""Marks a method to execute when specified conditions/methods complete.
Args:
condition: The condition that triggers this method. Can be:
- str: Name of the triggering method
- dict: Dictionary with 'type' and 'methods' keys for complex conditions
- Callable: A function reference
Returns:
Callable: The decorated function that will execute when conditions are met.
Raises:
ValueError: If the condition format is invalid.
Example:
>>> @listen("start_method") # Listen to single method
>>> def on_start():
>>> pass
>>>
>>> @listen(and_("method1", "method2")) # Listen with AND condition
>>> def on_both_complete():
>>> pass
"""
def decorator(func):
if isinstance(condition, str):
@@ -155,45 +127,29 @@ def listen(condition: Union[str, dict, Callable]) -> Callable:
def router(condition: Union[str, dict, Callable]) -> Callable:
"""
Creates a routing method that directs flow execution based on conditions.
This decorator marks a method as a router, which can dynamically determine
the next steps in the flow based on its return value. Routers are triggered
by specified conditions and can return constants that determine which path
the flow should take.
Parameters
----------
condition : Union[str, dict, Callable]
Specifies when the router should execute. Can be:
- str: Name of a method that triggers this router
- dict: Contains "type" ("AND"/"OR") and "methods" (list of triggers)
- Callable: A method reference that triggers this router
Returns
-------
Callable
A decorator function that sets up the method as a router.
Raises
------
ValueError
If the condition format is invalid.
Examples
--------
>>> @router("check_status")
>>> def route_based_on_status(self):
... if self.state.status == "success":
... return SUCCESS
... return FAILURE
>>> @router(and_("validate", "process"))
>>> def complex_routing(self):
... if all([self.state.valid, self.state.processed]):
... return CONTINUE
... return STOP
"""Marks a method as a router to direct flow based on its return value.
A router method can return different string values that trigger different
subsequent methods, allowing for dynamic flow control.
Args:
condition: The condition that triggers this router. Can be:
- str: Name of the triggering method
- dict: Dictionary with 'type' and 'methods' keys for complex conditions
- Callable: A function reference
Returns:
Callable: The decorated function that will serve as a router.
Raises:
ValueError: If the condition format is invalid.
Example:
>>> @router("process_data")
>>> def route_result(result):
>>> if result.success:
>>> return "handle_success"
>>> return "handle_error"
"""
def decorator(func):
func.__is_router__ = True
@@ -218,38 +174,27 @@ def router(condition: Union[str, dict, Callable]) -> Callable:
return decorator
def or_(*conditions: Union[str, dict, Callable]) -> dict:
"""
Combines multiple conditions with OR logic for flow control.
Creates a condition that is satisfied when any of the specified conditions
are met. This is used with @start, @listen, or @router decorators to create
complex triggering conditions.
Parameters
----------
*conditions : Union[str, dict, Callable]
Variable number of conditions that can be:
- str: Method names
- dict: Existing condition dictionaries
- Callable: Method references
Returns
-------
dict
A condition dictionary with format:
{"type": "OR", "methods": list_of_method_names}
Raises
------
ValueError
If any condition is invalid.
Examples
--------
>>> @listen(or_("success", "timeout"))
>>> def handle_completion(self):
... pass
"""Combines multiple conditions with OR logic for flow control.
Args:
*conditions: Variable number of conditions. Each can be:
- str: Name of a method
- dict: Dictionary with 'type' and 'methods' keys
- Callable: A function reference
Returns:
dict: A dictionary with 'type': 'OR' and 'methods' list.
Raises:
ValueError: If any condition is invalid.
Example:
>>> @listen(or_("method1", "method2"))
>>> def on_either():
>>> # Executes when either method1 OR method2 completes
>>> pass
"""
methods = []
for condition in conditions:
@@ -265,37 +210,25 @@ def or_(*conditions: Union[str, dict, Callable]) -> dict:
def and_(*conditions: Union[str, dict, Callable]) -> dict:
"""
Combines multiple conditions with AND logic for flow control.
Creates a condition that is satisfied only when all specified conditions
are met. This is used with @start, @listen, or @router decorators to create
complex triggering conditions.
Parameters
----------
*conditions : Union[str, dict, Callable]
Variable number of conditions that can be:
- str: Method names
- dict: Existing condition dictionaries
- Callable: Method references
Returns
-------
dict
A condition dictionary with format:
{"type": "AND", "methods": list_of_method_names}
Raises
------
ValueError
If any condition is invalid.
Examples
--------
>>> @listen(and_("validated", "processed"))
>>> def handle_complete_data(self):
... pass
"""Combines multiple conditions with AND logic for flow control.
Args:
*conditions: Variable number of conditions. Each can be:
- str: Name of a method
- dict: Dictionary with 'type' and 'methods' keys
- Callable: A function reference
Returns:
dict: A dictionary with 'type': 'AND' and 'methods' list.
Raises:
ValueError: If any condition is invalid.
Example:
>>> @listen(and_("method1", "method2"))
>>> def on_both():
>>> # Executes when BOTH method1 AND method2 complete
>>> pass
"""
methods = []
for condition in conditions:
@@ -355,6 +288,22 @@ class Flow(Generic[T], metaclass=FlowMeta):
event_emitter = Signal("event_emitter")
def __class_getitem__(cls: Type["Flow"], item: Type[T]) -> Type["Flow"]:
"""Create a generic version of Flow with specified state type.
Args:
cls: The Flow class
item: The type parameter for the flow's state
Returns:
Type["Flow"]: A new Flow class with the specified state type
Example:
>>> class MyState(BaseModel):
>>> value: int
>>>
>>> class MyFlow(Flow[MyState]):
>>> pass
"""
class _FlowGeneric(cls): # type: ignore
_initial_state_T = item # type: ignore
@@ -362,11 +311,23 @@ class Flow(Generic[T], metaclass=FlowMeta):
return _FlowGeneric
def __init__(self) -> None:
"""Initialize a new Flow instance.
Sets up internal state tracking, method registration, and telemetry.
The flow's methods are automatically discovered and registered during initialization.
Attributes initialized:
_methods: Dictionary mapping method names to their callable objects
_state: The flow's state object of type T
_method_execution_counts: Tracks how many times each method has executed
_pending_and_listeners: Tracks methods waiting for AND conditions
_method_outputs: List of all outputs from executed methods
"""
self._methods: Dict[str, Callable] = {}
self._state: T = self._create_initial_state()
self._method_execution_counts: Dict[str, int] = {}
self._pending_and_listeners: Dict[str, Set[str]] = {}
self._method_outputs: List[Any] = [] # List to store all method outputs
self._method_outputs: List[Any] = []
self._telemetry.flow_creation_span(self.__class__.__name__)
@@ -377,6 +338,20 @@ class Flow(Generic[T], metaclass=FlowMeta):
self._methods[method_name] = getattr(self, method_name)
def _create_initial_state(self) -> T:
"""Create the initial state for the flow.
The state is created based on the following priority:
1. If initial_state is None and _initial_state_T exists (generic type), use that
2. If initial_state is None, return empty dict
3. If initial_state is a type, instantiate it
4. Otherwise, use initial_state as-is
Returns:
T: The initial state object of type T
Note:
The type T can be either a Pydantic BaseModel or a dictionary.
"""
if self.initial_state is None and hasattr(self, "_initial_state_T"):
return self._initial_state_T() # type: ignore
if self.initial_state is None:
@@ -388,11 +363,21 @@ class Flow(Generic[T], metaclass=FlowMeta):
@property
def state(self) -> T:
"""Get the current state of the flow.
Returns:
T: The current state object, either a Pydantic model or dictionary
"""
return self._state
@property
def method_outputs(self) -> List[Any]:
"""Returns the list of all outputs from executed methods."""
"""Get the list of all outputs from executed methods.
Returns:
List[Any]: A list containing the output values from all executed flow methods,
in order of execution.
"""
return self._method_outputs
def _initialize_state(self, inputs: Dict[str, Any]) -> None:
@@ -462,23 +447,6 @@ class Flow(Generic[T], metaclass=FlowMeta):
return final_output
async def _execute_start_method(self, start_method_name: str) -> None:
"""
Executes a flow's start method and its triggered listeners.
This internal method handles the execution of methods marked with @start
decorator and manages the subsequent chain of listener executions.
Parameters
----------
start_method_name : str
The name of the start method to execute.
Notes
-----
- Executes the start method and captures its result
- Triggers execution of any listeners waiting on this start method
- Part of the flow's initialization sequence
"""
result = await self._execute_method(
start_method_name, self._methods[start_method_name]
)
@@ -499,27 +467,22 @@ class Flow(Generic[T], metaclass=FlowMeta):
return result
async def _execute_listeners(self, trigger_method: str, result: Any) -> None:
"""
Executes all listeners and routers triggered by a method completion.
This internal method manages the execution flow by:
1. First executing all triggered routers sequentially
2. Then executing all triggered listeners in parallel
Parameters
----------
trigger_method : str
The name of the method that triggered these listeners.
result : Any
The result from the triggering method, passed to listeners
that accept parameters.
Notes
-----
- Routers are executed sequentially to maintain flow control
- Each router's result becomes the new trigger_method
- Normal listeners are executed in parallel for efficiency
- Listeners can receive the trigger method's result as a parameter
"""Execute all listener methods triggered by a completed method.
This method handles both router and non-router listeners in a specific order:
1. First executes all triggered router methods sequentially until no more routers
are triggered
2. Then executes all regular listeners in parallel
Args:
trigger_method: The name of the method that completed execution
result: The result value from the triggering method
Note:
Router methods are executed sequentially to ensure proper flow control,
while regular listeners are executed concurrently for better performance.
This provides fine-grained control over the execution flow while
maintaining efficiency.
"""
# First, handle routers repeatedly until no router triggers anymore
while True:
@@ -550,32 +513,26 @@ class Flow(Generic[T], metaclass=FlowMeta):
def _find_triggered_methods(
self, trigger_method: str, router_only: bool
) -> List[str]:
"""
Finds all methods that should be triggered based on conditions.
This internal method evaluates both OR and AND conditions to determine
which methods should be executed next in the flow.
Parameters
----------
trigger_method : str
The name of the method that just completed execution.
router_only : bool
If True, only consider router methods.
If False, only consider non-router methods.
Returns
-------
List[str]
Names of methods that should be triggered.
Notes
-----
- Handles both OR and AND conditions:
* OR: Triggers if any condition is met
* AND: Triggers only when all conditions are met
- Maintains state for AND conditions using _pending_and_listeners
- Separates router and normal listener evaluation
"""Find all methods that should be triggered based on completed method and type.
Provides precise control over method triggering by handling both OR and AND
conditions separately for router and non-router methods.
Args:
trigger_method: The name of the method that completed execution
router_only: If True, only find router methods; if False, only regular
listeners
Returns:
List[str]: Names of methods that should be executed next
Note:
This method implements sophisticated flow control by:
1. Filtering methods based on their router/non-router status
2. Handling OR conditions for immediate triggering
3. Managing AND conditions with state tracking for complex dependencies
This ensures predictable and consistent execution order in complex flows.
"""
triggered = []
for listener_name, (condition_type, methods) in self._listeners.items():
@@ -605,32 +562,26 @@ class Flow(Generic[T], metaclass=FlowMeta):
return triggered
async def _execute_single_listener(self, listener_name: str, result: Any) -> None:
"""
Executes a single listener method with proper event handling.
This internal method manages the execution of an individual listener,
including parameter inspection, event emission, and error handling.
Parameters
----------
listener_name : str
The name of the listener method to execute.
result : Any
The result from the triggering method, which may be passed
to the listener if it accepts parameters.
Notes
-----
- Inspects method signature to determine if it accepts the trigger result
- Emits events for method execution start and finish
- Handles errors gracefully with detailed logging
- Recursively triggers listeners of this listener
- Supports both parameterized and parameter-less listeners
Error Handling
-------------
Catches and logs any exceptions during execution, preventing
individual listener failures from breaking the entire flow.
"""Execute a single listener method with precise parameter handling and error tracking.
Provides fine-grained control over method execution through:
1. Automatic parameter inspection to determine if the method accepts results
2. Event emission for execution tracking
3. Comprehensive error handling
4. Recursive listener execution
Args:
listener_name: The name of the listener method to execute
result: The result from the triggering method, passed to the listener
if its signature accepts parameters
Note:
This method ensures precise execution control by:
- Inspecting method signatures to handle parameters correctly
- Emitting events for execution tracking
- Providing comprehensive error handling
- Supporting both parameterized and parameter-less methods
- Maintaining execution chain through recursive listener calls
"""
try:
method = self._methods[listener_name]
@@ -675,8 +626,32 @@ class Flow(Generic[T], metaclass=FlowMeta):
traceback.print_exc()
def plot(self, filename: str = "crewai_flow") -> None:
def plot(self, *args, **kwargs):
"""Generate an interactive visualization of the flow's execution graph.
Creates a detailed HTML visualization showing the relationships between
methods, including start points, listeners, routers, and their
connections. Includes telemetry tracking for flow analysis.
Args:
*args: Variable length argument list passed to plot_flow
**kwargs: Arbitrary keyword arguments passed to plot_flow
Note:
The visualization provides:
- Clear representation of method relationships
- Visual distinction between different method types
- Interactive exploration capabilities
- Execution path tracing
- Telemetry tracking for flow analysis
Example:
>>> flow = MyFlow()
>>> flow.plot("my_workflow") # Creates my_workflow.html
"""
from crewai.flow.flow_visualizer import plot_flow
self._telemetry.flow_plotting_span(
self.__class__.__name__, list(self._methods.keys())
)
plot_flow(self, filename)
return plot_flow(self, *args, **kwargs)

View File

@@ -0,0 +1,240 @@
"""Utility functions for Flow visualization.
This module contains utility functions specifically designed for visualizing
Flow graphs and calculating layout information. These utilities are separated
from general-purpose utilities to maintain a clean dependency structure.
"""
from typing import Any, Dict, List, Set, TYPE_CHECKING
if TYPE_CHECKING:
from crewai.flow.flow import Flow
def calculate_node_levels(flow: Flow[Any]) -> Dict[str, int]:
"""Calculate the hierarchical level of each node in the flow graph.
Uses breadth-first traversal to assign levels to nodes, starting with
start methods at level 0. Handles both OR and AND conditions for listeners,
and considers router paths when calculating levels.
Args:
flow: Flow instance containing methods, listeners, and router configurations
Returns:
dict[str, int]: Dictionary mapping method names to their hierarchical levels,
where level 0 contains start methods and each subsequent level contains
methods triggered by the previous level
Example:
>>> flow = Flow()
>>> @flow.start
... def start(): pass
>>> @flow.on("start")
... def second(): pass
>>> calculate_node_levels(flow)
{'start': 0, 'second': 1}
"""
levels: Dict[str, int] = {}
queue: List[str] = []
visited: Set[str] = set()
pending_and_listeners: Dict[str, Set[str]] = {}
# Make all start methods at level 0
for method_name, method in flow._methods.items():
if hasattr(method, "__is_start_method__"):
levels[method_name] = 0
queue.append(method_name)
# Breadth-first traversal to assign levels
while queue:
current = queue.pop(0)
current_level = levels[current]
visited.add(current)
for listener_name, (condition_type, trigger_methods) in flow._listeners.items():
if condition_type == "OR":
if current in trigger_methods:
if (
listener_name not in levels
or levels[listener_name] > current_level + 1
):
levels[listener_name] = current_level + 1
if listener_name not in visited:
queue.append(listener_name)
elif condition_type == "AND":
if listener_name not in pending_and_listeners:
pending_and_listeners[listener_name] = set()
if current in trigger_methods:
pending_and_listeners[listener_name].add(current)
if set(trigger_methods) == pending_and_listeners[listener_name]:
if (
listener_name not in levels
or levels[listener_name] > current_level + 1
):
levels[listener_name] = current_level + 1
if listener_name not in visited:
queue.append(listener_name)
# Handle router connections
if current in flow._routers:
router_method_name = current
paths = flow._router_paths.get(router_method_name, [])
for path in paths:
for listener_name, (
condition_type,
trigger_methods,
) in flow._listeners.items():
if path in trigger_methods:
if (
listener_name not in levels
or levels[listener_name] > current_level + 1
):
levels[listener_name] = current_level + 1
if listener_name not in visited:
queue.append(listener_name)
return levels
def count_outgoing_edges(flow: Flow[Any]) -> Dict[str, int]:
"""Count the number of outgoing edges for each node in the flow graph.
An outgoing edge represents a connection from a method to a listener
that it triggers. This is useful for visualization and analysis of
flow structure.
Args:
flow: Flow instance containing methods and their connections
Returns:
dict[str, int]: Dictionary mapping method names to their number
of outgoing connections
"""
counts: Dict[str, int] = {}
for method_name in flow._methods:
counts[method_name] = 0
for method_name in flow._listeners:
_, trigger_methods = flow._listeners[method_name]
for trigger in trigger_methods:
if trigger in flow._methods:
counts[trigger] += 1
return counts
def build_ancestor_dict(flow: Flow[Any]) -> Dict[str, Set[str]]:
"""Build a dictionary mapping each node to its set of ancestor nodes.
Uses depth-first search to identify all ancestors (direct and indirect
trigger methods) for each node in the flow graph. Handles both regular
listeners and router paths.
Args:
flow: Flow instance containing methods and their relationships
Returns:
dict[str, set[str]]: Dictionary mapping each method name to a set
of its ancestor method names
"""
ancestors: Dict[str, Set[str]] = {node: set() for node in flow._methods}
visited: Set[str] = set()
for node in flow._methods:
if node not in visited:
dfs_ancestors(node, ancestors, visited, flow)
return ancestors
def dfs_ancestors(node: str, ancestors: Dict[str, Set[str]],
visited: Set[str], flow: Flow[Any]) -> None:
"""Perform depth-first search to populate the ancestors dictionary.
Helper function for build_ancestor_dict that recursively traverses
the flow graph to identify ancestors of each node.
Args:
node: Current node being processed
ancestors: Dictionary mapping nodes to their ancestor sets
visited: Set of already visited nodes
flow: Flow instance containing the graph structure
"""
if node in visited:
return
visited.add(node)
# Handle regular listeners
for listener_name, (_, trigger_methods) in flow._listeners.items():
if node in trigger_methods:
ancestors[listener_name].add(node)
ancestors[listener_name].update(ancestors[node])
dfs_ancestors(listener_name, ancestors, visited, flow)
# Handle router methods separately
if node in flow._routers:
router_method_name = node
paths = flow._router_paths.get(router_method_name, [])
for path in paths:
for listener_name, (_, trigger_methods) in flow._listeners.items():
if path in trigger_methods:
# Only propagate the ancestors of the router method, not the router method itself
ancestors[listener_name].update(ancestors[node])
dfs_ancestors(listener_name, ancestors, visited, flow)
def build_parent_children_dict(flow: Flow[Any]) -> Dict[str, List[str]]:
"""Build a dictionary mapping each node to its list of child nodes.
Maps both regular trigger methods to their listeners and router
methods to their path listeners. Useful for visualization and
traversal of the flow graph structure.
Args:
flow: Flow instance containing methods and their relationships
Returns:
dict[str, list[str]]: Dictionary mapping each method name to a
sorted list of its child method names
"""
parent_children: Dict[str, List[str]] = {}
# Map listeners to their trigger methods
for listener_name, (_, trigger_methods) in flow._listeners.items():
for trigger in trigger_methods:
if trigger not in parent_children:
parent_children[trigger] = []
if listener_name not in parent_children[trigger]:
parent_children[trigger].append(listener_name)
# Map router methods to their paths and to listeners
for router_method_name, paths in flow._router_paths.items():
for path in paths:
# Map router method to listeners of each path
for listener_name, (_, trigger_methods) in flow._listeners.items():
if path in trigger_methods:
if router_method_name not in parent_children:
parent_children[router_method_name] = []
if listener_name not in parent_children[router_method_name]:
parent_children[router_method_name].append(listener_name)
return parent_children
def get_child_index(parent: str, child: str,
parent_children: Dict[str, List[str]]) -> int:
"""Get the index of a child node in its parent's sorted children list.
Args:
parent: Parent node name
child: Child node name to find index for
parent_children: Dictionary mapping parents to their children lists
Returns:
int: Zero-based index of the child in parent's sorted children list
Raises:
ValueError: If child is not found in parent's children list
"""
children = parent_children.get(parent, [])
children.sort()
return children.index(child)

View File

@@ -6,10 +6,10 @@ from pathlib import Path
from pyvis.network import Network
from crewai.flow.config import COLORS, NODE_STYLES
from crewai.flow.flow_visual_utils import calculate_node_levels
from crewai.flow.html_template_handler import HTMLTemplateHandler
from crewai.flow.legend_generator import generate_legend_items_html, get_legend_items
from crewai.flow.path_utils import safe_path_join, validate_path_exists
from crewai.flow.utils import calculate_node_levels
from crewai.flow.path_utils import safe_path_join, validate_file_path
from crewai.flow.visualization_utils import (
add_edges,
add_nodes_to_network,
@@ -21,206 +21,102 @@ class FlowPlot:
"""Handles the creation and rendering of flow visualization diagrams."""
def __init__(self, flow):
"""
Initialize FlowPlot with a flow object.
Parameters
----------
flow : Flow
A Flow instance to visualize.
Raises
------
ValueError
If flow object is invalid or missing required attributes.
"""Initialize flow plot with flow instance and styling configuration.
Args:
flow: A Flow instance with required attributes for visualization
Raises:
ValueError: If flow object is invalid or missing required attributes
"""
if not hasattr(flow, '_methods'):
raise ValueError("Invalid flow object: missing '_methods' attribute")
if not hasattr(flow, '_listeners'):
raise ValueError("Invalid flow object: missing '_listeners' attribute")
raise ValueError("Invalid flow object: Missing '_methods' attribute")
if not hasattr(flow, '_start_methods'):
raise ValueError("Invalid flow object: missing '_start_methods' attribute")
raise ValueError("Invalid flow object: Missing '_start_methods' attribute")
if not hasattr(flow, '_listeners'):
raise ValueError("Invalid flow object: Missing '_listeners' attribute")
self.flow = flow
self.colors = COLORS
self.node_styles = NODE_STYLES
def plot(self, filename):
"""
Generate and save an HTML visualization of the flow.
"""Generate and save interactive flow visualization to HTML file."""
net = Network(
directed=True,
height="750px",
width="100%",
bgcolor=self.colors["bg"],
layout=None,
)
Parameters
----------
filename : str
Name of the output file (without extension).
Raises
------
ValueError
If filename is invalid or network generation fails.
IOError
If file operations fail or visualization cannot be generated.
RuntimeError
If network visualization generation fails.
"""
if not filename or not isinstance(filename, str):
raise ValueError("Filename must be a non-empty string")
try:
# Initialize network
net = Network(
directed=True,
height="750px",
width="100%",
bgcolor=self.colors["bg"],
layout=None,
)
# Set options to disable physics
net.set_options(
"""
var options = {
"nodes": {
"font": {
"multi": "html"
}
},
"physics": {
"enabled": false
}
}
# Set options to disable physics
net.set_options(
"""
)
var options = {
"nodes": {
"font": {
"multi": "html"
}
},
"physics": {
"enabled": false
}
}
"""
)
# Calculate levels for nodes
try:
node_levels = calculate_node_levels(self.flow)
except Exception as e:
raise ValueError(f"Failed to calculate node levels: {str(e)}")
node_levels = calculate_node_levels(self.flow)
node_positions = compute_positions(self.flow, node_levels)
add_nodes_to_network(net, self.flow, node_positions, self.node_styles)
add_edges(net, self.flow, node_positions, self.colors)
# Compute positions
try:
node_positions = compute_positions(self.flow, node_levels)
except Exception as e:
raise ValueError(f"Failed to compute node positions: {str(e)}")
network_html = net.generate_html()
final_html_content = self._generate_final_html(network_html)
# Add nodes to the network
try:
add_nodes_to_network(net, self.flow, node_positions, self.node_styles)
except Exception as e:
raise RuntimeError(f"Failed to add nodes to network: {str(e)}")
try:
# Ensure the output path is safe
output_dir = os.getcwd()
output_path = safe_path_join(output_dir, f"{filename}.html")
with open(output_path, "w", encoding="utf-8") as f:
f.write(final_html_content)
print(f"Plot saved as {output_path}")
except (IOError, ValueError) as e:
raise IOError(f"Failed to save flow visualization: {str(e)}")
# Add edges to the network
try:
add_edges(net, self.flow, node_positions, self.colors)
except Exception as e:
raise RuntimeError(f"Failed to add edges to network: {str(e)}")
# Generate HTML
try:
network_html = net.generate_html()
final_html_content = self._generate_final_html(network_html)
except Exception as e:
raise RuntimeError(f"Failed to generate network visualization: {str(e)}")
# Save the final HTML content to the file
try:
with open(f"{filename}.html", "w", encoding="utf-8") as f:
f.write(final_html_content)
print(f"Plot saved as {filename}.html")
except IOError as e:
raise IOError(f"Failed to save flow visualization to {filename}.html: {str(e)}")
except (ValueError, RuntimeError, IOError) as e:
raise e
except Exception as e:
raise RuntimeError(f"Unexpected error during flow visualization: {str(e)}")
finally:
self._cleanup_pyvis_lib()
self._cleanup_pyvis_lib()
def _generate_final_html(self, network_html):
"""
Generate the final HTML content with network visualization and legend.
"""Generate final HTML content with network visualization and legend."""
current_dir = os.path.dirname(__file__)
template_path = os.path.join(
current_dir, "assets", "crewai_flow_visual_template.html"
)
logo_path = os.path.join(current_dir, "assets", "crewai_logo.svg")
Parameters
----------
network_html : str
HTML content generated by pyvis Network.
html_handler = HTMLTemplateHandler(template_path, logo_path)
network_body = html_handler.extract_body_content(network_html)
Returns
-------
str
Complete HTML content with styling and legend.
Raises
------
IOError
If template or logo files cannot be accessed.
ValueError
If network_html is invalid.
"""
if not network_html:
raise ValueError("Invalid network HTML content")
try:
# Extract just the body content from the generated HTML
current_dir = os.path.dirname(__file__)
template_path = safe_path_join("assets", "crewai_flow_visual_template.html", root=current_dir)
logo_path = safe_path_join("assets", "crewai_logo.svg", root=current_dir)
if not os.path.exists(template_path):
raise IOError(f"Template file not found: {template_path}")
if not os.path.exists(logo_path):
raise IOError(f"Logo file not found: {logo_path}")
html_handler = HTMLTemplateHandler(template_path, logo_path)
network_body = html_handler.extract_body_content(network_html)
# Generate the legend items HTML
legend_items = get_legend_items(self.colors)
legend_items_html = generate_legend_items_html(legend_items)
final_html_content = html_handler.generate_final_html(
network_body, legend_items_html
)
return final_html_content
except Exception as e:
raise IOError(f"Failed to generate visualization HTML: {str(e)}")
legend_items = get_legend_items(self.colors)
legend_items_html = generate_legend_items_html(legend_items)
final_html_content = html_handler.generate_final_html(
network_body, legend_items_html
)
return final_html_content
def _cleanup_pyvis_lib(self):
"""
Clean up the generated lib folder from pyvis.
This method safely removes the temporary lib directory created by pyvis
during network visualization generation.
"""
"""Clean up temporary files generated by pyvis library."""
lib_folder = os.path.join(os.getcwd(), "lib")
try:
lib_folder = safe_path_join("lib", root=os.getcwd())
if os.path.exists(lib_folder) and os.path.isdir(lib_folder):
import shutil
shutil.rmtree(lib_folder)
except ValueError as e:
print(f"Error validating lib folder path: {e}")
except Exception as e:
print(f"Error cleaning up lib folder: {e}")
print(f"Error cleaning up {lib_folder}: {e}")
def plot_flow(flow, filename="flow_plot"):
"""
Convenience function to create and save a flow visualization.
Parameters
----------
flow : Flow
Flow instance to visualize.
filename : str, optional
Output filename without extension, by default "flow_plot".
Raises
------
ValueError
If flow object or filename is invalid.
IOError
If file operations fail.
"""
"""Create and save a visualization of the given flow."""
visualizer = FlowPlot(flow)
visualizer.plot(filename)

View File

@@ -1,55 +1,107 @@
import base64
import os
import re
from pathlib import Path
from crewai.flow.path_utils import safe_path_join, validate_path_exists
from crewai.flow.path_utils import safe_path_join, validate_file_path
class HTMLTemplateHandler:
"""Handles HTML template processing and generation for flow visualization diagrams."""
def __init__(self, template_path, logo_path):
"""
Initialize HTMLTemplateHandler with validated template and logo paths.
Parameters
----------
template_path : str
Path to the HTML template file.
logo_path : str
Path to the logo image file.
Raises
------
ValueError
If template or logo paths are invalid or files don't exist.
"""Initialize template handler with template and logo file paths.
Args:
template_path: Path to the HTML template file
logo_path: Path to the logo SVG file
Raises:
ValueError: If template_path or logo_path is invalid or files don't exist
"""
try:
self.template_path = validate_path_exists(template_path, "file")
self.logo_path = validate_path_exists(logo_path, "file")
except ValueError as e:
raise ValueError(f"Invalid template or logo path: {e}")
self.template_path = validate_file_path(template_path)
self.logo_path = validate_file_path(logo_path)
except (ValueError, TypeError) as e:
raise ValueError(f"Invalid file path: {str(e)}")
def read_template(self):
"""Read and return the HTML template file contents."""
with open(self.template_path, "r", encoding="utf-8") as f:
return f.read()
"""Read and return the HTML template file contents.
Returns:
str: The contents of the template file
Raises:
IOError: If template file cannot be read
"""
try:
with open(self.template_path, "r", encoding="utf-8") as f:
return f.read()
except IOError as e:
raise IOError(f"Failed to read template file {self.template_path}: {str(e)}")
def encode_logo(self):
"""Convert the logo SVG file to base64 encoded string."""
with open(self.logo_path, "rb") as logo_file:
logo_svg_data = logo_file.read()
return base64.b64encode(logo_svg_data).decode("utf-8")
"""Convert the logo SVG file to base64 encoded string.
Returns:
str: Base64 encoded logo data
Raises:
IOError: If logo file cannot be read
ValueError: If logo data cannot be encoded
"""
try:
with open(self.logo_path, "rb") as logo_file:
logo_svg_data = logo_file.read()
try:
return base64.b64encode(logo_svg_data).decode("utf-8")
except Exception as e:
raise ValueError(f"Failed to encode logo data: {str(e)}")
except IOError as e:
raise IOError(f"Failed to read logo file {self.logo_path}: {str(e)}")
def extract_body_content(self, html):
"""Extract and return content between body tags from HTML string."""
"""Extract and return content between body tags from HTML string.
Args:
html: HTML string to extract body content from
Returns:
str: Content between body tags, or empty string if not found
Raises:
ValueError: If input HTML is invalid
"""
if not html or not isinstance(html, str):
raise ValueError("Input HTML must be a non-empty string")
match = re.search("<body.*?>(.*?)</body>", html, re.DOTALL)
return match.group(1) if match else ""
def generate_legend_items_html(self, legend_items):
"""Generate HTML markup for the legend items."""
"""Generate HTML markup for the legend items.
Args:
legend_items: List of dictionaries containing legend item properties
Returns:
str: Generated HTML markup for legend items
Raises:
ValueError: If legend_items is invalid or missing required properties
"""
if not isinstance(legend_items, list):
raise ValueError("legend_items must be a list")
legend_items_html = ""
for item in legend_items:
if not isinstance(item, dict):
raise ValueError("Each legend item must be a dictionary")
if "color" not in item:
raise ValueError("Each legend item must have a 'color' property")
if "label" not in item:
raise ValueError("Each legend item must have a 'label' property")
if "border" in item:
legend_items_html += f"""
<div class="legend-item">
@@ -75,19 +127,42 @@ class HTMLTemplateHandler:
return legend_items_html
def generate_final_html(self, network_body, legend_items_html, title="Flow Plot"):
"""Combine all components into final HTML document with network visualization."""
html_template = self.read_template()
logo_svg_base64 = self.encode_logo()
"""Combine all components into final HTML document with network visualization.
Args:
network_body: HTML string containing network visualization
legend_items_html: HTML string containing legend items markup
title: Title for the visualization page (default: "Flow Plot")
Returns:
str: Complete HTML document with all components integrated
Raises:
ValueError: If any input parameters are invalid
IOError: If template or logo files cannot be read
"""
if not isinstance(network_body, str):
raise ValueError("network_body must be a string")
if not isinstance(legend_items_html, str):
raise ValueError("legend_items_html must be a string")
if not isinstance(title, str):
raise ValueError("title must be a string")
try:
html_template = self.read_template()
logo_svg_base64 = self.encode_logo()
final_html_content = html_template.replace("{{ title }}", title)
final_html_content = final_html_content.replace(
"{{ network_content }}", network_body
)
final_html_content = final_html_content.replace(
"{{ logo_svg_base64 }}", logo_svg_base64
)
final_html_content = final_html_content.replace(
"<!-- LEGEND_ITEMS_PLACEHOLDER -->", legend_items_html
)
final_html_content = html_template.replace("{{ title }}", title)
final_html_content = final_html_content.replace(
"{{ network_content }}", network_body
)
final_html_content = final_html_content.replace(
"{{ logo_svg_base64 }}", logo_svg_base64
)
final_html_content = final_html_content.replace(
"<!-- LEGEND_ITEMS_PLACEHOLDER -->", legend_items_html
)
return final_html_content
return final_html_content
except Exception as e:
raise ValueError(f"Failed to generate final HTML: {str(e)}")

View File

@@ -1,135 +1,123 @@
"""
Path utilities for secure file operations in CrewAI flow module.
"""Utilities for safe path handling in flow visualization.
This module provides utilities for secure path handling to prevent directory
traversal attacks and ensure paths remain within allowed boundaries.
This module provides a comprehensive set of utilities for secure path handling,
including path joining, validation, and normalization. It helps prevent common
security issues like directory traversal attacks while providing a consistent
interface for path operations.
"""
import os
from pathlib import Path
from typing import List, Union
from typing import List, Optional, Union
def safe_path_join(*parts: str, root: Union[str, Path, None] = None) -> str:
def safe_path_join(base_dir: Union[str, Path], filename: str) -> str:
"""Safely join base directory with filename, preventing directory traversal.
Args:
base_dir: Base directory path
filename: Filename or path to join with base_dir
Returns:
str: Safely joined absolute path
Raises:
ValueError: If resulting path would escape base_dir or contains dangerous patterns
TypeError: If inputs are not strings or Path objects
OSError: If path resolution fails
"""
Safely join path components and ensure the result is within allowed boundaries.
Parameters
----------
*parts : str
Variable number of path components to join.
root : Union[str, Path, None], optional
Root directory to use as base. If None, uses current working directory.
Returns
-------
str
String representation of the resolved path.
Raises
------
ValueError
If the resulting path would be outside the root directory
or if any path component is invalid.
"""
if not parts:
raise ValueError("No path components provided")
if not isinstance(base_dir, (str, Path)):
raise TypeError("base_dir must be a string or Path object")
if not isinstance(filename, str):
raise TypeError("filename must be a string")
# Check for dangerous patterns
dangerous_patterns = ['..', '~', '*', '?', '|', '>', '<', '$', '&', '`']
if any(pattern in filename for pattern in dangerous_patterns):
raise ValueError(f"Invalid filename: Contains dangerous pattern")
try:
# Convert all parts to strings and clean them
clean_parts = [str(part).strip() for part in parts if part]
if not clean_parts:
raise ValueError("No valid path components provided")
# Establish root directory
root_path = Path(root).resolve() if root else Path.cwd()
base_path = Path(base_dir).resolve(strict=True)
full_path = Path(base_path, filename).resolve(strict=True)
# Join and resolve the full path
full_path = Path(root_path, *clean_parts).resolve()
# Check if the resolved path is within root
if not str(full_path).startswith(str(root_path)):
if not str(full_path).startswith(str(base_path)):
raise ValueError(
f"Invalid path: Potential directory traversal. Path must be within {root_path}"
f"Invalid path: {filename} would escape base directory {base_dir}"
)
return str(full_path)
except OSError as e:
raise OSError(f"Failed to resolve path: {str(e)}")
except Exception as e:
if isinstance(e, ValueError):
raise
raise ValueError(f"Invalid path components: {str(e)}")
raise ValueError(f"Failed to process paths: {str(e)}")
def validate_path_exists(path: Union[str, Path], file_type: str = "file") -> str:
"""
Validate that a path exists and is of the expected type.
Parameters
----------
path : Union[str, Path]
Path to validate.
file_type : str, optional
Expected type ('file' or 'directory'), by default 'file'.
Returns
-------
str
Validated path as string.
Raises
------
ValueError
If path doesn't exist or is not of expected type.
def normalize_path(path: Union[str, Path]) -> str:
"""Normalize a path by resolving symlinks and removing redundant separators.
Args:
path: Path to normalize
Returns:
str: Normalized absolute path
Raises:
TypeError: If path is not a string or Path object
OSError: If path resolution fails
"""
if not isinstance(path, (str, Path)):
raise TypeError("path must be a string or Path object")
try:
path_obj = Path(path).resolve()
return str(Path(path).resolve(strict=True))
except OSError as e:
raise OSError(f"Failed to normalize path: {str(e)}")
def validate_path_components(components: List[str]) -> None:
"""Validate path components for potentially dangerous patterns.
Args:
components: List of path components to validate
if not path_obj.exists():
raise ValueError(f"Path does not exist: {path}")
if file_type == "file" and not path_obj.is_file():
raise ValueError(f"Path is not a file: {path}")
elif file_type == "directory" and not path_obj.is_dir():
raise ValueError(f"Path is not a directory: {path}")
return str(path_obj)
Raises:
TypeError: If components is not a list or contains non-string items
ValueError: If any component contains dangerous patterns
"""
if not isinstance(components, list):
raise TypeError("components must be a list")
except Exception as e:
if isinstance(e, ValueError):
raise
raise ValueError(f"Invalid path: {str(e)}")
dangerous_patterns = ['..', '~', '*', '?', '|', '>', '<', '$', '&', '`']
for component in components:
if not isinstance(component, str):
raise TypeError(f"Path component '{component}' must be a string")
if any(pattern in component for pattern in dangerous_patterns):
raise ValueError(f"Invalid path component '{component}': Contains dangerous pattern")
def list_files(directory: Union[str, Path], pattern: str = "*") -> List[str]:
"""
Safely list files in a directory matching a pattern.
Parameters
----------
directory : Union[str, Path]
Directory to search in.
pattern : str, optional
Glob pattern to match files against, by default "*".
Returns
-------
List[str]
List of matching file paths.
Raises
------
ValueError
If directory is invalid or inaccessible.
def validate_file_path(path: Union[str, Path], must_exist: bool = True) -> str:
"""Validate a file path for security and existence.
Args:
path: File path to validate
must_exist: Whether the file must exist (default: True)
Returns:
str: Validated absolute path
Raises:
ValueError: If path is invalid or file doesn't exist when required
TypeError: If path is not a string or Path object
"""
if not isinstance(path, (str, Path)):
raise TypeError("path must be a string or Path object")
try:
dir_path = Path(directory).resolve()
if not dir_path.is_dir():
raise ValueError(f"Not a directory: {directory}")
return [str(p) for p in dir_path.glob(pattern) if p.is_file()]
resolved_path = Path(path).resolve()
if must_exist and not resolved_path.is_file():
raise ValueError(f"File not found: {path}")
return str(resolved_path)
except Exception as e:
if isinstance(e, ValueError):
raise
raise ValueError(f"Error listing files: {str(e)}")
raise ValueError(f"Invalid file path {path}: {str(e)}")

View File

@@ -1,362 +1,35 @@
"""
Utility functions for flow visualization and dependency analysis.
"""General utility functions for flow execution.
This module provides core functionality for analyzing and manipulating flow structures,
including node level calculation, ancestor tracking, and return value analysis.
Functions in this module are primarily used by the visualization system to create
accurate and informative flow diagrams.
This module has been deprecated. All functionality has been moved to:
- core_flow_utils.py: Core flow execution utilities
- flow_visual_utils.py: Visualization-related utilities
Example
-------
>>> flow = Flow()
>>> node_levels = calculate_node_levels(flow)
>>> ancestors = build_ancestor_dict(flow)
This module is kept as a temporary redirect to maintain backwards compatibility.
New code should import from the appropriate new modules directly.
"""
import ast
import inspect
import textwrap
from typing import Any, Dict, List, Optional, Set, Union
from typing import Any, Dict, List, Optional, Set
from .core_flow_utils import get_possible_return_constants, is_ancestor
from .flow_visual_utils import (
build_ancestor_dict,
build_parent_children_dict,
calculate_node_levels,
count_outgoing_edges,
dfs_ancestors,
get_child_index,
)
def get_possible_return_constants(function: Any) -> Optional[List[str]]:
try:
source = inspect.getsource(function)
except OSError:
# Can't get source code
return None
except Exception as e:
print(f"Error retrieving source code for function {function.__name__}: {e}")
return None
# Re-export all functions for backwards compatibility
__all__ = [
'get_possible_return_constants',
'calculate_node_levels',
'count_outgoing_edges',
'build_ancestor_dict',
'dfs_ancestors',
'is_ancestor',
'build_parent_children_dict',
'get_child_index',
]
try:
# Remove leading indentation
source = textwrap.dedent(source)
# Parse the source code into an AST
code_ast = ast.parse(source)
except IndentationError as e:
print(f"IndentationError while parsing source code of {function.__name__}: {e}")
print(f"Source code:\n{source}")
return None
except SyntaxError as e:
print(f"SyntaxError while parsing source code of {function.__name__}: {e}")
print(f"Source code:\n{source}")
return None
except Exception as e:
print(f"Unexpected error while parsing source code of {function.__name__}: {e}")
print(f"Source code:\n{source}")
return None
return_values = set()
dict_definitions = {}
class DictionaryAssignmentVisitor(ast.NodeVisitor):
def visit_Assign(self, node):
# Check if this assignment is assigning a dictionary literal to a variable
if isinstance(node.value, ast.Dict) and len(node.targets) == 1:
target = node.targets[0]
if isinstance(target, ast.Name):
var_name = target.id
dict_values = []
# Extract string values from the dictionary
for val in node.value.values:
if isinstance(val, ast.Constant) and isinstance(val.value, str):
dict_values.append(val.value)
# If non-string, skip or just ignore
if dict_values:
dict_definitions[var_name] = dict_values
self.generic_visit(node)
class ReturnVisitor(ast.NodeVisitor):
def visit_Return(self, node):
# Direct string return
if isinstance(node.value, ast.Constant) and isinstance(
node.value.value, str
):
return_values.add(node.value.value)
# Dictionary-based return, like return paths[result]
elif isinstance(node.value, ast.Subscript):
# Check if we're subscripting a known dictionary variable
if isinstance(node.value.value, ast.Name):
var_name = node.value.value.id
if var_name in dict_definitions:
# Add all possible dictionary values
for v in dict_definitions[var_name]:
return_values.add(v)
self.generic_visit(node)
# First pass: identify dictionary assignments
DictionaryAssignmentVisitor().visit(code_ast)
# Second pass: identify returns
ReturnVisitor().visit(code_ast)
return list(return_values) if return_values else None
def calculate_node_levels(flow: Any) -> Dict[str, int]:
"""
Calculate the hierarchical level of each node in the flow.
Performs a breadth-first traversal of the flow graph to assign levels
to nodes, starting with start methods at level 0.
Parameters
----------
flow : Any
The flow instance containing methods, listeners, and router configurations.
Returns
-------
Dict[str, int]
Dictionary mapping method names to their hierarchical levels.
Notes
-----
- Start methods are assigned level 0
- Each subsequent connected node is assigned level = parent_level + 1
- Handles both OR and AND conditions for listeners
- Processes router paths separately
"""
levels: Dict[str, int] = {}
queue: List[str] = []
visited: Set[str] = set()
pending_and_listeners: Dict[str, Set[str]] = {}
# Make all start methods at level 0
for method_name, method in flow._methods.items():
if hasattr(method, "__is_start_method__"):
levels[method_name] = 0
queue.append(method_name)
# Breadth-first traversal to assign levels
while queue:
current = queue.pop(0)
current_level = levels[current]
visited.add(current)
for listener_name, (condition_type, trigger_methods) in flow._listeners.items():
if condition_type == "OR":
if current in trigger_methods:
if (
listener_name not in levels
or levels[listener_name] > current_level + 1
):
levels[listener_name] = current_level + 1
if listener_name not in visited:
queue.append(listener_name)
elif condition_type == "AND":
if listener_name not in pending_and_listeners:
pending_and_listeners[listener_name] = set()
if current in trigger_methods:
pending_and_listeners[listener_name].add(current)
if set(trigger_methods) == pending_and_listeners[listener_name]:
if (
listener_name not in levels
or levels[listener_name] > current_level + 1
):
levels[listener_name] = current_level + 1
if listener_name not in visited:
queue.append(listener_name)
# Handle router connections
if current in flow._routers:
router_method_name = current
paths = flow._router_paths.get(router_method_name, [])
for path in paths:
for listener_name, (
condition_type,
trigger_methods,
) in flow._listeners.items():
if path in trigger_methods:
if (
listener_name not in levels
or levels[listener_name] > current_level + 1
):
levels[listener_name] = current_level + 1
if listener_name not in visited:
queue.append(listener_name)
return levels
def count_outgoing_edges(flow: Any) -> Dict[str, int]:
"""
Count the number of outgoing edges for each method in the flow.
Parameters
----------
flow : Any
The flow instance to analyze.
Returns
-------
Dict[str, int]
Dictionary mapping method names to their outgoing edge count.
"""
counts = {}
for method_name in flow._methods:
counts[method_name] = 0
for method_name in flow._listeners:
_, trigger_methods = flow._listeners[method_name]
for trigger in trigger_methods:
if trigger in flow._methods:
counts[trigger] += 1
return counts
def build_ancestor_dict(flow: Any) -> Dict[str, Set[str]]:
"""
Build a dictionary mapping each node to its ancestor nodes.
Parameters
----------
flow : Any
The flow instance to analyze.
Returns
-------
Dict[str, Set[str]]
Dictionary mapping each node to a set of its ancestor nodes.
"""
ancestors: Dict[str, Set[str]] = {node: set() for node in flow._methods}
visited: Set[str] = set()
for node in flow._methods:
if node not in visited:
dfs_ancestors(node, ancestors, visited, flow)
return ancestors
def dfs_ancestors(
node: str,
ancestors: Dict[str, Set[str]],
visited: Set[str],
flow: Any
) -> None:
"""
Perform depth-first search to build ancestor relationships.
Parameters
----------
node : str
Current node being processed.
ancestors : Dict[str, Set[str]]
Dictionary tracking ancestor relationships.
visited : Set[str]
Set of already visited nodes.
flow : Any
The flow instance being analyzed.
Notes
-----
This function modifies the ancestors dictionary in-place to build
the complete ancestor graph.
"""
if node in visited:
return
visited.add(node)
# Handle regular listeners
for listener_name, (_, trigger_methods) in flow._listeners.items():
if node in trigger_methods:
ancestors[listener_name].add(node)
ancestors[listener_name].update(ancestors[node])
dfs_ancestors(listener_name, ancestors, visited, flow)
# Handle router methods separately
if node in flow._routers:
router_method_name = node
paths = flow._router_paths.get(router_method_name, [])
for path in paths:
for listener_name, (_, trigger_methods) in flow._listeners.items():
if path in trigger_methods:
# Only propagate the ancestors of the router method, not the router method itself
ancestors[listener_name].update(ancestors[node])
dfs_ancestors(listener_name, ancestors, visited, flow)
def is_ancestor(node: str, ancestor_candidate: str, ancestors: Dict[str, Set[str]]) -> bool:
"""
Check if one node is an ancestor of another.
Parameters
----------
node : str
The node to check ancestors for.
ancestor_candidate : str
The potential ancestor node.
ancestors : Dict[str, Set[str]]
Dictionary containing ancestor relationships.
Returns
-------
bool
True if ancestor_candidate is an ancestor of node, False otherwise.
"""
return ancestor_candidate in ancestors.get(node, set())
def build_parent_children_dict(flow: Any) -> Dict[str, List[str]]:
"""
Build a dictionary mapping parent nodes to their children.
Parameters
----------
flow : Any
The flow instance to analyze.
Returns
-------
Dict[str, List[str]]
Dictionary mapping parent method names to lists of their child method names.
Notes
-----
- Maps listeners to their trigger methods
- Maps router methods to their paths and listeners
- Children lists are sorted for consistent ordering
"""
parent_children: Dict[str, List[str]] = {}
# Map listeners to their trigger methods
for listener_name, (_, trigger_methods) in flow._listeners.items():
for trigger in trigger_methods:
if trigger not in parent_children:
parent_children[trigger] = []
if listener_name not in parent_children[trigger]:
parent_children[trigger].append(listener_name)
# Map router methods to their paths and to listeners
for router_method_name, paths in flow._router_paths.items():
for path in paths:
# Map router method to listeners of each path
for listener_name, (_, trigger_methods) in flow._listeners.items():
if path in trigger_methods:
if router_method_name not in parent_children:
parent_children[router_method_name] = []
if listener_name not in parent_children[router_method_name]:
parent_children[router_method_name].append(listener_name)
return parent_children
def get_child_index(parent: str, child: str, parent_children: Dict[str, List[str]]) -> int:
"""
Get the index of a child node in its parent's sorted children list.
Parameters
----------
parent : str
The parent node name.
child : str
The child node name to find the index for.
parent_children : Dict[str, List[str]]
Dictionary mapping parents to their children lists.
Returns
-------
int
Zero-based index of the child in its parent's sorted children list.
"""
children = parent_children.get(parent, [])
children.sort()
return children.index(child)
# Function implementations have been moved to core_flow_utils.py and flow_visual_utils.py

View File

@@ -1,61 +1,63 @@
"""
Utilities for creating visual representations of flow structures.
This module provides functions for generating network visualizations of flows,
including node placement, edge creation, and visual styling. It handles the
conversion of flow structures into visual network graphs with appropriate
styling and layout.
Example
-------
>>> flow = Flow()
>>> net = Network(directed=True)
>>> node_positions = compute_positions(flow, node_levels)
>>> add_nodes_to_network(net, flow, node_positions, node_styles)
>>> add_edges(net, flow, node_positions, colors)
"""
import ast
import inspect
from typing import Any, Dict, List, Optional, Tuple, Union
import os
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, cast
from .utils import (
from crewai.flow.flow import Flow
from pyvis.network import Network
from .core_flow_utils import is_ancestor
from .flow_visual_utils import (
build_ancestor_dict,
build_parent_children_dict,
get_child_index,
is_ancestor,
)
from .path_utils import safe_path_join, validate_file_path
def method_calls_crew(method: Any) -> bool:
"""
Check if the method contains a call to `.crew()`.
Parameters
----------
method : Any
The method to analyze for crew() calls.
Returns
-------
bool
True if the method calls .crew(), False otherwise.
Notes
-----
Uses AST analysis to detect method calls, specifically looking for
attribute access of 'crew'.
def method_calls_crew(method: Optional[Callable[..., Any]]) -> bool:
"""Check if the method contains a .crew() call in its implementation.
Analyzes the method's source code using AST to detect if it makes any
calls to the .crew() method, which indicates crew involvement in the
flow execution.
Args:
method: The method to analyze for crew calls, can be None
Returns:
bool: True if the method contains a .crew() call, False otherwise
Raises:
TypeError: If input is not None and not a callable method
ValueError: If method source code cannot be parsed
RuntimeError: If unexpected error occurs during parsing
"""
if method is None:
return False
if not callable(method):
raise TypeError("Input must be a callable method")
try:
source = inspect.getsource(method)
source = inspect.cleandoc(source)
tree = ast.parse(source)
except (TypeError, ValueError, OSError) as e:
raise ValueError(f"Could not parse method {getattr(method, '__name__', str(method))}: {e}")
except Exception as e:
print(f"Could not parse method {method.__name__}: {e}")
return False
raise RuntimeError(f"Unexpected error parsing method: {e}")
class CrewCallVisitor(ast.NodeVisitor):
"""AST visitor to detect .crew() method calls."""
"""AST visitor to detect .crew() method calls in source code.
A specialized AST visitor that analyzes Python source code to precisely
identify calls to the .crew() method, enabling accurate detection of
crew involvement in flow methods.
Attributes:
found (bool): Indicates whether a .crew() call was found
"""
def __init__(self):
self.found = False
@@ -70,35 +72,64 @@ def method_calls_crew(method: Any) -> bool:
return visitor.found
def add_nodes_to_network(
net: Any,
flow: Any,
node_positions: Dict[str, Tuple[float, float]],
node_styles: Dict[str, Dict[str, Any]]
) -> None:
def add_nodes_to_network(net: Network, flow: Flow[Any],
node_positions: Dict[str, Tuple[float, float]],
node_styles: Dict[str, dict],
output_dir: Optional[str] = None) -> None:
"""Add nodes to the network visualization with precise styling and positioning.
Creates and styles nodes in the visualization network based on their type
(start, router, crew, or regular method) with fine-grained control over
appearance and positioning.
Args:
net: The network visualization object to add nodes to
flow: Flow object containing method definitions and relationships
node_positions: Dictionary mapping method names to (x,y) coordinates
node_styles: Dictionary mapping node types to their visual styles
output_dir: Optional directory path for saving visualization assets
Returns:
None
Raises:
ValueError: If flow object is invalid or required styles are missing
TypeError: If input arguments have incorrect types
OSError: If output directory operations fail
Note:
Node styles are applied with precise control over shape, font, color,
and positioning to ensure accurate visual representation of the flow.
If output_dir is provided, it will be validated and created if needed.
"""
Add nodes to the network visualization with appropriate styling.
Parameters
----------
net : Any
The pyvis Network instance to add nodes to.
flow : Any
The flow instance containing method information.
node_positions : Dict[str, Tuple[float, float]]
Dictionary mapping node names to their (x, y) positions.
node_styles : Dict[str, Dict[str, Any]]
Dictionary containing style configurations for different node types.
Notes
-----
Node types include:
- Start methods
- Router methods
- Crew methods
- Regular methods
"""
def human_friendly_label(method_name):
if not hasattr(flow, '_methods'):
raise ValueError("Invalid flow object: missing '_methods' attribute")
if not isinstance(node_positions, dict):
raise TypeError("node_positions must be a dictionary")
if not isinstance(node_styles, dict):
raise TypeError("node_styles must be a dictionary")
required_styles = {'start', 'router', 'crew', 'method'}
missing_styles = required_styles - set(node_styles.keys())
if missing_styles:
raise ValueError(f"Missing required node styles: {missing_styles}")
# Validate and create output directory if specified
if output_dir:
try:
output_dir = validate_file_path(output_dir, must_exist=False)
os.makedirs(output_dir, exist_ok=True)
except (ValueError, OSError) as e:
raise OSError(f"Failed to create or validate output directory: {e}")
def human_friendly_label(method_name: str) -> str:
"""Convert method name to human-readable format.
Args:
method_name: Original method name with underscores
Returns:
str: Formatted method name with spaces and title case
"""
return method_name.replace("_", " ").title()
for method_name, (x, y) in node_positions.items():
@@ -115,6 +146,15 @@ def add_nodes_to_network(
node_style = node_style.copy()
label = human_friendly_label(method_name)
# Handle file-based assets if output directory is provided
if output_dir and node_style.get("image"):
try:
image_path = node_style["image"]
safe_image_path = safe_path_join(output_dir, Path(image_path).name)
node_style["image"] = str(safe_image_path)
except (ValueError, OSError) as e:
raise OSError(f"Failed to process node image path: {e}")
node_style.update(
{
"label": label,
@@ -136,31 +176,39 @@ def add_nodes_to_network(
)
def compute_positions(
flow: Any,
node_levels: Dict[str, int],
y_spacing: float = 150,
x_spacing: float = 150
) -> Dict[str, Tuple[float, float]]:
"""
Compute the (x, y) positions for each node in the flow graph.
Parameters
----------
flow : Any
The flow instance to compute positions for.
node_levels : Dict[str, int]
Dictionary mapping node names to their hierarchical levels.
y_spacing : float, optional
Vertical spacing between levels, by default 150.
x_spacing : float, optional
Horizontal spacing between nodes, by default 150.
Returns
-------
Dict[str, Tuple[float, float]]
Dictionary mapping node names to their (x, y) coordinates.
def compute_positions(flow: Flow[Any], node_levels: Dict[str, int],
y_spacing: float = 150, x_spacing: float = 150) -> Dict[str, Tuple[float, float]]:
"""Calculate precise x,y coordinates for each node in the flow diagram.
Computes optimal node positions with fine-grained control over spacing
and alignment, ensuring clear visualization of flow hierarchy and
relationships.
Args:
flow: Flow object containing method definitions
node_levels: Dictionary mapping method names to their hierarchy levels
y_spacing: Vertical spacing between hierarchy levels (default: 150)
x_spacing: Horizontal spacing between nodes at same level (default: 150)
Returns:
dict[str, tuple[float, float]]: Dictionary mapping method names to
their calculated (x,y) coordinates in the visualization
Note:
Positions are calculated to maintain clear hierarchical structure while
ensuring optimal spacing and readability of the flow diagram.
"""
if not hasattr(flow, '_methods'):
raise ValueError("Invalid flow object: missing '_methods' attribute")
if not isinstance(node_levels, dict):
raise TypeError("node_levels must be a dictionary")
if not isinstance(y_spacing, (int, float)) or y_spacing <= 0:
raise ValueError("y_spacing must be a positive number")
if not isinstance(x_spacing, (int, float)) or x_spacing <= 0:
raise ValueError("x_spacing must be a positive number")
if not node_levels:
raise ValueError("node_levels dictionary cannot be empty")
level_nodes: Dict[int, List[str]] = {}
node_positions: Dict[str, Tuple[float, float]] = {}
@@ -177,33 +225,34 @@ def compute_positions(
return node_positions
def add_edges(
net: Any,
flow: Any,
node_positions: Dict[str, Tuple[float, float]],
colors: Dict[str, str]
) -> None:
edge_smooth: Dict[str, Union[str, float]] = {"type": "continuous"} # Default value
"""
Add edges to the network visualization with appropriate styling.
Parameters
----------
net : Any
The pyvis Network instance to add edges to.
flow : Any
The flow instance containing edge information.
node_positions : Dict[str, Tuple[float, float]]
Dictionary mapping node names to their positions.
colors : Dict[str, str]
Dictionary mapping edge types to their colors.
Notes
-----
- Handles both normal listener edges and router edges
- Applies appropriate styling (color, dashes) based on edge type
- Adds curvature to edges when needed (cycles or multiple children)
"""
def add_edges(net: Network, flow: Flow[Any],
node_positions: Dict[str, Tuple[float, float]],
colors: Dict[str, str],
asset_dir: Optional[str] = None) -> None:
if not hasattr(flow, '_methods'):
raise ValueError("Invalid flow object: missing '_methods' attribute")
if not hasattr(flow, '_listeners'):
raise ValueError("Invalid flow object: missing '_listeners' attribute")
if not hasattr(flow, '_router_paths'):
raise ValueError("Invalid flow object: missing '_router_paths' attribute")
if not isinstance(node_positions, dict):
raise TypeError("node_positions must be a dictionary")
if not isinstance(colors, dict):
raise TypeError("colors must be a dictionary")
required_colors = {'edge', 'router_edge'}
missing_colors = required_colors - set(colors.keys())
if missing_colors:
raise ValueError(f"Missing required edge colors: {missing_colors}")
# Validate asset directory if provided
if asset_dir:
try:
asset_dir = validate_file_path(asset_dir, must_exist=False)
os.makedirs(asset_dir, exist_ok=True)
except (ValueError, OSError) as e:
raise OSError(f"Failed to create or validate asset directory: {e}")
ancestors = build_ancestor_dict(flow)
parent_children = build_parent_children_dict(flow)
@@ -232,24 +281,24 @@ def add_edges(
dx = target_pos[0] - source_pos[0]
smooth_type = "curvedCCW" if dx <= 0 else "curvedCW"
index = get_child_index(trigger, method_name, parent_children)
edge_smooth = {
edge_config = {
"type": smooth_type,
"roundness": 0.2 + (0.1 * index),
}
else:
edge_smooth = {"type": "cubicBezier"}
edge_config = {"type": "cubicBezier"}
else:
edge_smooth.update({"type": "continuous"})
edge_config = {"type": "straight"}
edge_style = {
edge_props: Dict[str, Any] = {
"color": edge_color,
"width": 2,
"arrows": "to",
"dashes": True if is_router_edge or is_and_condition else False,
"smooth": edge_smooth,
"smooth": edge_config,
}
net.add_edge(trigger, method_name, **edge_style)
net.add_edge(trigger, method_name, **edge_props)
else:
# Nodes not found in node_positions. Check if it's a known router outcome and a known method.
is_router_edge = any(
@@ -295,23 +344,23 @@ def add_edges(
index = get_child_index(
router_method_name, listener_name, parent_children
)
edge_smooth = {
edge_config = {
"type": smooth_type,
"roundness": 0.2 + (0.1 * index),
}
else:
edge_smooth = {"type": "cubicBezier"}
edge_config = {"type": "cubicBezier"}
else:
edge_smooth.update({"type": "continuous"})
edge_config = {"type": "straight"}
edge_style = {
router_edge_props: Dict[str, Any] = {
"color": colors["router_edge"],
"width": 2,
"arrows": "to",
"dashes": True,
"smooth": edge_smooth,
"smooth": edge_config,
}
net.add_edge(router_method_name, listener_name, **edge_style)
net.add_edge(router_method_name, listener_name, **router_edge_props)
else:
# Same check here: known router edge and known method?
method_known = listener_name in flow._methods

View File

@@ -1,5 +1,3 @@
import json
import logging
from typing import Any, List, Optional
from pydantic import BaseModel, Field
@@ -7,8 +5,6 @@ from pydantic import BaseModel, Field
from crewai.agent import Agent
from crewai.task import Task
logger = logging.getLogger(__name__)
class PlanPerTask(BaseModel):
task: str = Field(..., description="The task for which the plan is created")
@@ -72,39 +68,19 @@ class CrewPlanner:
output_pydantic=PlannerTaskPydanticOutput,
)
def _get_agent_knowledge(self, task: Task) -> List[str]:
"""
Safely retrieve knowledge source content from the task's agent.
Args:
task: The task containing an agent with potential knowledge sources
Returns:
List[str]: A list of knowledge source strings
"""
try:
if task.agent and task.agent.knowledge_sources:
return [source.content for source in task.agent.knowledge_sources]
except AttributeError:
logger.warning("Error accessing agent knowledge sources")
return []
def _create_tasks_summary(self) -> str:
"""Creates a summary of all tasks."""
tasks_summary = []
for idx, task in enumerate(self.tasks):
knowledge_list = self._get_agent_knowledge(task)
task_summary = f"""
tasks_summary.append(
f"""
Task Number {idx + 1} - {task.description}
"task_description": {task.description}
"task_expected_output": {task.expected_output}
"agent": {task.agent.role if task.agent else "None"}
"agent_goal": {task.agent.goal if task.agent else "None"}
"task_tools": {task.tools}
"agent_tools": %s%s""" % (
f"[{', '.join(str(tool) for tool in task.agent.tools)}]" if task.agent and task.agent.tools else '"agent has no tools"',
f',\n "agent_knowledge": "[\\"{knowledge_list[0]}\\"]"' if knowledge_list and str(knowledge_list) != "None" else ""
)
tasks_summary.append(task_summary)
"agent_tools": {task.agent.tools if task.agent else "None"}
"""
)
return " ".join(tasks_summary)

View File

@@ -1,84 +0,0 @@
"""
Tests for verifying the integration of knowledge sources in the planning process.
This module ensures that agent knowledge is properly included during task planning.
"""
from unittest.mock import patch
import pytest
from crewai.agent import Agent
from crewai.knowledge.source.string_knowledge_source import StringKnowledgeSource
from crewai.task import Task
from crewai.utilities.planning_handler import CrewPlanner
@pytest.fixture
def mock_knowledge_source():
"""
Create a mock knowledge source with test content.
Returns:
StringKnowledgeSource:
A knowledge source containing AI-related test content
"""
content = """
Important context about AI:
1. AI systems use machine learning algorithms
2. Neural networks are a key component
3. Training data is essential for good performance
"""
return StringKnowledgeSource(content=content)
@patch('crewai.knowledge.storage.knowledge_storage.chromadb')
def test_knowledge_included_in_planning(mock_chroma):
"""Test that verifies knowledge sources are properly included in planning."""
# Mock ChromaDB collection
mock_collection = mock_chroma.return_value.get_or_create_collection.return_value
mock_collection.add.return_value = None
# Create an agent with knowledge
agent = Agent(
role="AI Researcher",
goal="Research and explain AI concepts",
backstory="Expert in artificial intelligence",
knowledge_sources=[
StringKnowledgeSource(
content="AI systems require careful training and validation."
)
]
)
# Create a task for the agent
task = Task(
description="Explain the basics of AI systems",
expected_output="A clear explanation of AI fundamentals",
agent=agent
)
# Create a crew planner
planner = CrewPlanner([task], None)
# Get the task summary
task_summary = planner._create_tasks_summary()
# Verify that knowledge is included in planning when present
assert "AI systems require careful training" in task_summary, \
"Knowledge content should be present in task summary when knowledge exists"
assert '"agent_knowledge"' in task_summary, \
"agent_knowledge field should be present in task summary when knowledge exists"
# Verify that knowledge is properly formatted
assert isinstance(task.agent.knowledge_sources, list), \
"Knowledge sources should be stored in a list"
assert len(task.agent.knowledge_sources) > 0, \
"At least one knowledge source should be present"
assert task.agent.knowledge_sources[0].content in task_summary, \
"Knowledge source content should be included in task summary"
# Verify that other expected components are still present
assert task.description in task_summary, \
"Task description should be present in task summary"
assert task.expected_output in task_summary, \
"Expected output should be present in task summary"
assert agent.role in task_summary, \
"Agent role should be present in task summary"

View File

@@ -1,14 +1,10 @@
from typing import Optional
from unittest.mock import MagicMock, patch
from unittest.mock import patch
import pytest
from pydantic import BaseModel
from crewai.agent import Agent
from crewai.knowledge.source.string_knowledge_source import StringKnowledgeSource
from crewai.task import Task
from crewai.tasks.task_output import TaskOutput
from crewai.tools.base_tool import BaseTool
from crewai.utilities.planning_handler import (
CrewPlanner,
PlannerTaskPydanticOutput,
@@ -96,72 +92,7 @@ class TestCrewPlanner:
tasks_summary = crew_planner._create_tasks_summary()
assert isinstance(tasks_summary, str)
assert tasks_summary.startswith("\n Task Number 1 - Task 1")
assert '"agent_tools": "agent has no tools"' in tasks_summary
# Knowledge field should not be present when empty
assert '"agent_knowledge"' not in tasks_summary
@patch('crewai.knowledge.storage.knowledge_storage.chromadb')
def test_create_tasks_summary_with_knowledge_and_tools(self, mock_chroma):
"""Test task summary generation with both knowledge and tools present."""
# Mock ChromaDB collection
mock_collection = mock_chroma.return_value.get_or_create_collection.return_value
mock_collection.add.return_value = None
# Create mock tools with proper string descriptions and structured tool support
class MockTool(BaseTool):
name: str
description: str
def __init__(self, name: str, description: str):
tool_data = {"name": name, "description": description}
super().__init__(**tool_data)
def __str__(self):
return self.name
def __repr__(self):
return self.name
def to_structured_tool(self):
return self
def _run(self, *args, **kwargs):
pass
def _generate_description(self) -> str:
"""Override _generate_description to avoid args_schema handling."""
return self.description
tool1 = MockTool("tool1", "Tool 1 description")
tool2 = MockTool("tool2", "Tool 2 description")
# Create a task with knowledge and tools
task = Task(
description="Task with knowledge and tools",
expected_output="Expected output",
agent=Agent(
role="Test Agent",
goal="Test Goal",
backstory="Test Backstory",
tools=[tool1, tool2],
knowledge_sources=[
StringKnowledgeSource(content="Test knowledge content")
]
)
)
# Create planner with the new task
planner = CrewPlanner([task], None)
tasks_summary = planner._create_tasks_summary()
# Verify task summary content
assert isinstance(tasks_summary, str)
assert task.description in tasks_summary
assert task.expected_output in tasks_summary
assert '"agent_tools": [tool1, tool2]' in tasks_summary
assert '"agent_knowledge": "[\\"Test knowledge content\\"]"' in tasks_summary
assert task.agent.role in tasks_summary
assert task.agent.goal in tasks_summary
assert tasks_summary.endswith('"agent_tools": []\n ')
def test_handle_crew_planning_different_llm(self, crew_planner_different_llm):
with patch.object(Task, "execute_sync") as execute:

68
uv.lock generated
View File

@@ -1,10 +1,18 @@
version = 1
requires-python = ">=3.10, <3.13"
resolution-markers = [
"python_full_version < '3.11'",
"python_full_version == '3.11.*'",
"python_full_version >= '3.12' and python_full_version < '3.12.4'",
"python_full_version >= '3.12.4'",
"python_full_version < '3.11' and sys_platform == 'darwin'",
"python_full_version < '3.11' and platform_machine == 'aarch64' and sys_platform == 'linux'",
"(python_full_version < '3.11' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version < '3.11' and sys_platform != 'darwin' and sys_platform != 'linux')",
"python_full_version == '3.11.*' and sys_platform == 'darwin'",
"python_full_version == '3.11.*' and platform_machine == 'aarch64' and sys_platform == 'linux'",
"(python_full_version == '3.11.*' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version == '3.11.*' and sys_platform != 'darwin' and sys_platform != 'linux')",
"python_full_version >= '3.12' and python_full_version < '3.12.4' and sys_platform == 'darwin'",
"python_full_version >= '3.12' and python_full_version < '3.12.4' and platform_machine == 'aarch64' and sys_platform == 'linux'",
"(python_full_version >= '3.12' and python_full_version < '3.12.4' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version >= '3.12' and python_full_version < '3.12.4' and sys_platform != 'darwin' and sys_platform != 'linux')",
"python_full_version >= '3.12.4' and sys_platform == 'darwin'",
"python_full_version >= '3.12.4' and platform_machine == 'aarch64' and sys_platform == 'linux'",
"(python_full_version >= '3.12.4' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version >= '3.12.4' and sys_platform != 'darwin' and sys_platform != 'linux')",
]
[[package]]
@@ -300,7 +308,7 @@ name = "build"
version = "1.2.2.post1"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "colorama", marker = "os_name == 'nt'" },
{ name = "colorama", marker = "(os_name == 'nt' and platform_machine != 'aarch64' and sys_platform == 'linux') or (os_name == 'nt' and sys_platform != 'darwin' and sys_platform != 'linux')" },
{ name = "importlib-metadata", marker = "python_full_version < '3.10.2'" },
{ name = "packaging" },
{ name = "pyproject-hooks" },
@@ -535,7 +543,7 @@ name = "click"
version = "8.1.7"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "colorama", marker = "platform_system == 'Windows'" },
{ name = "colorama", marker = "sys_platform == 'win32'" },
]
sdist = { url = "https://files.pythonhosted.org/packages/96/d3/f04c7bfcf5c1862a2a5b845c6b2b360488cf47af55dfa79c98f6a6bf98b5/click-8.1.7.tar.gz", hash = "sha256:ca9853ad459e787e2192211578cc907e7594e294c7ccc834310722b41b9ca6de", size = 336121 }
wheels = [
@@ -642,7 +650,6 @@ tools = [
[package.dev-dependencies]
dev = [
{ name = "cairosvg" },
{ name = "crewai-tools" },
{ name = "mkdocs" },
{ name = "mkdocs-material" },
{ name = "mkdocs-material-extensions" },
@@ -696,7 +703,6 @@ requires-dist = [
[package.metadata.requires-dev]
dev = [
{ name = "cairosvg", specifier = ">=2.7.1" },
{ name = "crewai-tools", specifier = ">=0.17.0" },
{ name = "mkdocs", specifier = ">=1.4.3" },
{ name = "mkdocs-material", specifier = ">=9.5.7" },
{ name = "mkdocs-material-extensions", specifier = ">=1.3.1" },
@@ -2462,7 +2468,7 @@ version = "1.6.1"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "click" },
{ name = "colorama", marker = "platform_system == 'Windows'" },
{ name = "colorama", marker = "sys_platform == 'win32'" },
{ name = "ghp-import" },
{ name = "jinja2" },
{ name = "markdown" },
@@ -2643,7 +2649,7 @@ version = "2.10.2"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "pygments" },
{ name = "pywin32", marker = "platform_system == 'Windows'" },
{ name = "pywin32", marker = "sys_platform == 'win32'" },
{ name = "tqdm" },
]
sdist = { url = "https://files.pythonhosted.org/packages/3a/93/80ac75c20ce54c785648b4ed363c88f148bf22637e10c9863db4fbe73e74/mpire-2.10.2.tar.gz", hash = "sha256:f66a321e93fadff34585a4bfa05e95bd946cf714b442f51c529038eb45773d97", size = 271270 }
@@ -2890,7 +2896,7 @@ name = "nvidia-cudnn-cu12"
version = "9.1.0.70"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "nvidia-cublas-cu12", marker = "(platform_machine != 'aarch64' and platform_system != 'Darwin') or (platform_system != 'Darwin' and platform_system != 'Linux')" },
{ name = "nvidia-cublas-cu12", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" },
]
wheels = [
{ url = "https://files.pythonhosted.org/packages/9f/fd/713452cd72343f682b1c7b9321e23829f00b842ceaedcda96e742ea0b0b3/nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl", hash = "sha256:165764f44ef8c61fcdfdfdbe769d687e06374059fbb388b6c89ecb0e28793a6f", size = 664752741 },
@@ -2917,9 +2923,9 @@ name = "nvidia-cusolver-cu12"
version = "11.4.5.107"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "nvidia-cublas-cu12", marker = "(platform_machine != 'aarch64' and platform_system != 'Darwin') or (platform_system != 'Darwin' and platform_system != 'Linux')" },
{ name = "nvidia-cusparse-cu12", marker = "(platform_machine != 'aarch64' and platform_system != 'Darwin') or (platform_system != 'Darwin' and platform_system != 'Linux')" },
{ name = "nvidia-nvjitlink-cu12", marker = "(platform_machine != 'aarch64' and platform_system != 'Darwin') or (platform_system != 'Darwin' and platform_system != 'Linux')" },
{ name = "nvidia-cublas-cu12", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" },
{ name = "nvidia-cusparse-cu12", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" },
{ name = "nvidia-nvjitlink-cu12", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" },
]
wheels = [
{ url = "https://files.pythonhosted.org/packages/bc/1d/8de1e5c67099015c834315e333911273a8c6aaba78923dd1d1e25fc5f217/nvidia_cusolver_cu12-11.4.5.107-py3-none-manylinux1_x86_64.whl", hash = "sha256:8a7ec542f0412294b15072fa7dab71d31334014a69f953004ea7a118206fe0dd", size = 124161928 },
@@ -2930,7 +2936,7 @@ name = "nvidia-cusparse-cu12"
version = "12.1.0.106"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "nvidia-nvjitlink-cu12", marker = "(platform_machine != 'aarch64' and platform_system != 'Darwin') or (platform_system != 'Darwin' and platform_system != 'Linux')" },
{ name = "nvidia-nvjitlink-cu12", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" },
]
wheels = [
{ url = "https://files.pythonhosted.org/packages/65/5b/cfaeebf25cd9fdec14338ccb16f6b2c4c7fa9163aefcf057d86b9cc248bb/nvidia_cusparse_cu12-12.1.0.106-py3-none-manylinux1_x86_64.whl", hash = "sha256:f3b50f42cf363f86ab21f720998517a659a48131e8d538dc02f8768237bd884c", size = 195958278 },
@@ -3480,7 +3486,7 @@ name = "portalocker"
version = "2.10.1"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "pywin32", marker = "platform_system == 'Windows'" },
{ name = "pywin32", marker = "sys_platform == 'win32'" },
]
sdist = { url = "https://files.pythonhosted.org/packages/ed/d3/c6c64067759e87af98cc668c1cc75171347d0f1577fab7ca3749134e3cd4/portalocker-2.10.1.tar.gz", hash = "sha256:ef1bf844e878ab08aee7e40184156e1151f228f103aa5c6bd0724cc330960f8f", size = 40891 }
wheels = [
@@ -5022,19 +5028,19 @@ dependencies = [
{ name = "fsspec" },
{ name = "jinja2" },
{ name = "networkx" },
{ name = "nvidia-cublas-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" },
{ name = "nvidia-cuda-cupti-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" },
{ name = "nvidia-cuda-nvrtc-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" },
{ name = "nvidia-cuda-runtime-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" },
{ name = "nvidia-cudnn-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" },
{ name = "nvidia-cufft-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" },
{ name = "nvidia-curand-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" },
{ name = "nvidia-cusolver-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" },
{ name = "nvidia-cusparse-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" },
{ name = "nvidia-nccl-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" },
{ name = "nvidia-nvtx-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" },
{ name = "nvidia-cublas-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
{ name = "nvidia-cuda-cupti-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
{ name = "nvidia-cuda-nvrtc-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
{ name = "nvidia-cuda-runtime-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
{ name = "nvidia-cudnn-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
{ name = "nvidia-cufft-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
{ name = "nvidia-curand-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
{ name = "nvidia-cusolver-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
{ name = "nvidia-cusparse-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
{ name = "nvidia-nccl-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
{ name = "nvidia-nvtx-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
{ name = "sympy" },
{ name = "triton", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" },
{ name = "triton", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
{ name = "typing-extensions" },
]
wheels = [
@@ -5081,7 +5087,7 @@ name = "tqdm"
version = "4.66.5"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "colorama", marker = "platform_system == 'Windows'" },
{ name = "colorama", marker = "sys_platform == 'win32'" },
]
sdist = { url = "https://files.pythonhosted.org/packages/58/83/6ba9844a41128c62e810fddddd72473201f3eacde02046066142a2d96cc5/tqdm-4.66.5.tar.gz", hash = "sha256:e1020aef2e5096702d8a025ac7d16b1577279c9d63f8375b63083e9a5f0fcbad", size = 169504 }
wheels = [
@@ -5124,7 +5130,7 @@ version = "0.27.0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "attrs" },
{ name = "cffi", marker = "implementation_name != 'pypy' and os_name == 'nt'" },
{ name = "cffi", marker = "(implementation_name != 'pypy' and os_name == 'nt' and platform_machine != 'aarch64' and sys_platform == 'linux') or (implementation_name != 'pypy' and os_name == 'nt' and sys_platform != 'darwin' and sys_platform != 'linux')" },
{ name = "exceptiongroup", marker = "python_full_version < '3.11'" },
{ name = "idna" },
{ name = "outcome" },
@@ -5155,7 +5161,7 @@ name = "triton"
version = "3.0.0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "filelock", marker = "(platform_machine != 'aarch64' and platform_system != 'Darwin') or (platform_system != 'Darwin' and platform_system != 'Linux')" },
{ name = "filelock", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" },
]
wheels = [
{ url = "https://files.pythonhosted.org/packages/45/27/14cc3101409b9b4b9241d2ba7deaa93535a217a211c86c4cc7151fb12181/triton-3.0.0-1-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:e1efef76935b2febc365bfadf74bcb65a6f959a9872e5bddf44cc9e0adce1e1a", size = 209376304 },