mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-04-14 15:02:37 +00:00
Compare commits
8 Commits
bugfix/tes
...
devin/1735
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
4aaa37d7af | ||
|
|
4604ac618d | ||
|
|
b0d1d86c26 | ||
|
|
454ab55a26 | ||
|
|
613dd175ee | ||
|
|
1dc8ce2674 | ||
|
|
e0590e516f | ||
|
|
8c6883e5ee |
141
src/crewai/flow/core_flow_utils.py
Normal file
141
src/crewai/flow/core_flow_utils.py
Normal file
@@ -0,0 +1,141 @@
|
||||
"""Core utility functions for Flow class operations.
|
||||
|
||||
This module contains utility functions that are specifically designed to work
|
||||
with the Flow class and require direct access to Flow class internals. These
|
||||
utilities are separated from general-purpose utilities to maintain a clean
|
||||
dependency structure and avoid circular imports.
|
||||
|
||||
Functions in this module are core to Flow functionality and are not related
|
||||
to visualization or other optional features.
|
||||
"""
|
||||
|
||||
import ast
|
||||
import inspect
|
||||
import textwrap
|
||||
from typing import Any, Callable, Dict, List, Optional, Set, Union
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
def get_possible_return_constants(function: Callable[..., Any]) -> Optional[List[str]]:
|
||||
"""Extract possible string return values from a function by analyzing its source code.
|
||||
|
||||
Analyzes the function's source code using AST to identify string constants that
|
||||
could be returned, including strings stored in dictionaries and direct returns.
|
||||
|
||||
Args:
|
||||
function: The function to analyze for possible return values
|
||||
|
||||
Returns:
|
||||
list[str] | None: List of possible string return values, or None if:
|
||||
- Source code cannot be retrieved
|
||||
- Source code has syntax/indentation errors
|
||||
- No string return values are found
|
||||
|
||||
Raises:
|
||||
OSError: If source code cannot be retrieved
|
||||
IndentationError: If source code has invalid indentation
|
||||
SyntaxError: If source code has syntax errors
|
||||
|
||||
Example:
|
||||
>>> def get_status():
|
||||
... paths = {"success": "completed", "error": "failed"}
|
||||
... return paths["success"]
|
||||
>>> get_possible_return_constants(get_status)
|
||||
['completed', 'failed']
|
||||
"""
|
||||
try:
|
||||
source = inspect.getsource(function)
|
||||
except OSError:
|
||||
# Can't get source code
|
||||
return None
|
||||
except Exception as e:
|
||||
print(f"Error retrieving source code for function {function.__name__}: {e}")
|
||||
return None
|
||||
|
||||
try:
|
||||
# Remove leading indentation
|
||||
source = textwrap.dedent(source)
|
||||
# Parse the source code into an AST
|
||||
code_ast = ast.parse(source)
|
||||
except IndentationError as e:
|
||||
print(f"IndentationError while parsing source code of {function.__name__}: {e}")
|
||||
print(f"Source code:\n{source}")
|
||||
return None
|
||||
except SyntaxError as e:
|
||||
print(f"SyntaxError while parsing source code of {function.__name__}: {e}")
|
||||
print(f"Source code:\n{source}")
|
||||
return None
|
||||
except Exception as e:
|
||||
print(f"Unexpected error while parsing source code of {function.__name__}: {e}")
|
||||
print(f"Source code:\n{source}")
|
||||
return None
|
||||
|
||||
return_values = set()
|
||||
dict_definitions = {}
|
||||
|
||||
class DictionaryAssignmentVisitor(ast.NodeVisitor):
|
||||
def visit_Assign(self, node):
|
||||
# Check if this assignment is assigning a dictionary literal to a variable
|
||||
if isinstance(node.value, ast.Dict) and len(node.targets) == 1:
|
||||
target = node.targets[0]
|
||||
if isinstance(target, ast.Name):
|
||||
var_name = target.id
|
||||
dict_values = []
|
||||
# Extract string values from the dictionary
|
||||
for val in node.value.values:
|
||||
if isinstance(val, ast.Constant) and isinstance(val.value, str):
|
||||
dict_values.append(val.value)
|
||||
# If non-string, skip or just ignore
|
||||
if dict_values:
|
||||
dict_definitions[var_name] = dict_values
|
||||
self.generic_visit(node)
|
||||
|
||||
class ReturnVisitor(ast.NodeVisitor):
|
||||
def visit_Return(self, node):
|
||||
# Direct string return
|
||||
if isinstance(node.value, ast.Constant) and isinstance(
|
||||
node.value.value, str
|
||||
):
|
||||
return_values.add(node.value.value)
|
||||
# Dictionary-based return, like return paths[result]
|
||||
elif isinstance(node.value, ast.Subscript):
|
||||
# Check if we're subscripting a known dictionary variable
|
||||
if isinstance(node.value.value, ast.Name):
|
||||
var_name = node.value.value.id
|
||||
if var_name in dict_definitions:
|
||||
# Add all possible dictionary values
|
||||
for v in dict_definitions[var_name]:
|
||||
return_values.add(v)
|
||||
self.generic_visit(node)
|
||||
|
||||
# First pass: identify dictionary assignments
|
||||
DictionaryAssignmentVisitor().visit(code_ast)
|
||||
# Second pass: identify returns
|
||||
ReturnVisitor().visit(code_ast)
|
||||
|
||||
return list(return_values) if return_values else None
|
||||
|
||||
|
||||
def is_ancestor(node: str, ancestor_candidate: str, ancestors: Dict[str, Set[str]]) -> bool:
|
||||
"""Check if one node is an ancestor of another in the flow graph.
|
||||
|
||||
Args:
|
||||
node: Target node to check ancestors for
|
||||
ancestor_candidate: Node to check if it's an ancestor
|
||||
ancestors: Dictionary mapping nodes to their ancestor sets
|
||||
|
||||
Returns:
|
||||
bool: True if ancestor_candidate is an ancestor of node
|
||||
|
||||
Raises:
|
||||
TypeError: If any argument has an invalid type
|
||||
"""
|
||||
if not isinstance(node, str):
|
||||
raise TypeError("Argument 'node' must be a string")
|
||||
if not isinstance(ancestor_candidate, str):
|
||||
raise TypeError("Argument 'ancestor_candidate' must be a string")
|
||||
if not isinstance(ancestors, dict):
|
||||
raise TypeError("Argument 'ancestors' must be a dictionary")
|
||||
|
||||
return ancestor_candidate in ancestors.get(node, set())
|
||||
@@ -17,59 +17,42 @@ from typing import (
|
||||
from blinker import Signal
|
||||
from pydantic import BaseModel, ValidationError
|
||||
|
||||
from crewai.flow.core_flow_utils import get_possible_return_constants
|
||||
from crewai.flow.flow_events import (
|
||||
FlowFinishedEvent,
|
||||
FlowStartedEvent,
|
||||
MethodExecutionFinishedEvent,
|
||||
MethodExecutionStartedEvent,
|
||||
)
|
||||
from crewai.flow.flow_visualizer import plot_flow
|
||||
from crewai.flow.utils import get_possible_return_constants
|
||||
from crewai.telemetry import Telemetry
|
||||
|
||||
T = TypeVar("T", bound=Union[BaseModel, Dict[str, Any]])
|
||||
|
||||
|
||||
def start(condition: Optional[Union[str, dict, Callable]] = None) -> Callable:
|
||||
"""
|
||||
Marks a method as a flow's starting point.
|
||||
|
||||
This decorator designates a method as an entry point for the flow execution.
|
||||
It can optionally specify conditions that trigger the start based on other
|
||||
method executions.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
condition : Optional[Union[str, dict, Callable]], optional
|
||||
Defines when the start method should execute. Can be:
|
||||
- str: Name of a method that triggers this start
|
||||
- dict: Contains "type" ("AND"/"OR") and "methods" (list of triggers)
|
||||
- Callable: A method reference that triggers this start
|
||||
Default is None, meaning unconditional start.
|
||||
|
||||
Returns
|
||||
-------
|
||||
Callable
|
||||
A decorator function that marks the method as a flow start point.
|
||||
|
||||
Raises
|
||||
------
|
||||
ValueError
|
||||
If the condition format is invalid.
|
||||
|
||||
Examples
|
||||
--------
|
||||
>>> @start() # Unconditional start
|
||||
>>> def begin_flow(self):
|
||||
... pass
|
||||
|
||||
>>> @start("method_name") # Start after specific method
|
||||
>>> def conditional_start(self):
|
||||
... pass
|
||||
|
||||
>>> @start(and_("method1", "method2")) # Start after multiple methods
|
||||
>>> def complex_start(self):
|
||||
... pass
|
||||
"""Marks a method as a flow starting point, optionally triggered by other methods.
|
||||
|
||||
Args:
|
||||
condition: The condition that triggers this method. Can be:
|
||||
- str: Name of the triggering method
|
||||
- dict: Dictionary with 'type' and 'methods' keys for complex conditions
|
||||
- Callable: A function reference
|
||||
- None: No trigger condition (default)
|
||||
|
||||
Returns:
|
||||
Callable: The decorated function that will serve as a flow starting point.
|
||||
|
||||
Raises:
|
||||
ValueError: If the condition format is invalid.
|
||||
|
||||
Example:
|
||||
>>> @start() # No condition
|
||||
>>> def begin_flow():
|
||||
>>> pass
|
||||
>>>
|
||||
>>> @start("method_name") # Triggered by specific method
|
||||
>>> def conditional_start():
|
||||
>>> pass
|
||||
"""
|
||||
def decorator(func):
|
||||
func.__is_start_method__ = True
|
||||
@@ -95,41 +78,30 @@ def start(condition: Optional[Union[str, dict, Callable]] = None) -> Callable:
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def listen(condition: Union[str, dict, Callable]) -> Callable:
|
||||
"""
|
||||
Creates a listener that executes when specified conditions are met.
|
||||
|
||||
This decorator sets up a method to execute in response to other method
|
||||
executions in the flow. It supports both simple and complex triggering
|
||||
conditions.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
condition : Union[str, dict, Callable]
|
||||
Specifies when the listener should execute. Can be:
|
||||
- str: Name of a method that triggers this listener
|
||||
- dict: Contains "type" ("AND"/"OR") and "methods" (list of triggers)
|
||||
- Callable: A method reference that triggers this listener
|
||||
|
||||
Returns
|
||||
-------
|
||||
Callable
|
||||
A decorator function that sets up the method as a listener.
|
||||
|
||||
Raises
|
||||
------
|
||||
ValueError
|
||||
If the condition format is invalid.
|
||||
|
||||
Examples
|
||||
--------
|
||||
>>> @listen("process_data") # Listen to single method
|
||||
>>> def handle_processed_data(self):
|
||||
... pass
|
||||
|
||||
>>> @listen(or_("success", "failure")) # Listen to multiple methods
|
||||
>>> def handle_completion(self):
|
||||
... pass
|
||||
"""Marks a method to execute when specified conditions/methods complete.
|
||||
|
||||
Args:
|
||||
condition: The condition that triggers this method. Can be:
|
||||
- str: Name of the triggering method
|
||||
- dict: Dictionary with 'type' and 'methods' keys for complex conditions
|
||||
- Callable: A function reference
|
||||
|
||||
Returns:
|
||||
Callable: The decorated function that will execute when conditions are met.
|
||||
|
||||
Raises:
|
||||
ValueError: If the condition format is invalid.
|
||||
|
||||
Example:
|
||||
>>> @listen("start_method") # Listen to single method
|
||||
>>> def on_start():
|
||||
>>> pass
|
||||
>>>
|
||||
>>> @listen(and_("method1", "method2")) # Listen with AND condition
|
||||
>>> def on_both_complete():
|
||||
>>> pass
|
||||
"""
|
||||
def decorator(func):
|
||||
if isinstance(condition, str):
|
||||
@@ -155,45 +127,29 @@ def listen(condition: Union[str, dict, Callable]) -> Callable:
|
||||
|
||||
|
||||
def router(condition: Union[str, dict, Callable]) -> Callable:
|
||||
"""
|
||||
Creates a routing method that directs flow execution based on conditions.
|
||||
|
||||
This decorator marks a method as a router, which can dynamically determine
|
||||
the next steps in the flow based on its return value. Routers are triggered
|
||||
by specified conditions and can return constants that determine which path
|
||||
the flow should take.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
condition : Union[str, dict, Callable]
|
||||
Specifies when the router should execute. Can be:
|
||||
- str: Name of a method that triggers this router
|
||||
- dict: Contains "type" ("AND"/"OR") and "methods" (list of triggers)
|
||||
- Callable: A method reference that triggers this router
|
||||
|
||||
Returns
|
||||
-------
|
||||
Callable
|
||||
A decorator function that sets up the method as a router.
|
||||
|
||||
Raises
|
||||
------
|
||||
ValueError
|
||||
If the condition format is invalid.
|
||||
|
||||
Examples
|
||||
--------
|
||||
>>> @router("check_status")
|
||||
>>> def route_based_on_status(self):
|
||||
... if self.state.status == "success":
|
||||
... return SUCCESS
|
||||
... return FAILURE
|
||||
|
||||
>>> @router(and_("validate", "process"))
|
||||
>>> def complex_routing(self):
|
||||
... if all([self.state.valid, self.state.processed]):
|
||||
... return CONTINUE
|
||||
... return STOP
|
||||
"""Marks a method as a router to direct flow based on its return value.
|
||||
|
||||
A router method can return different string values that trigger different
|
||||
subsequent methods, allowing for dynamic flow control.
|
||||
|
||||
Args:
|
||||
condition: The condition that triggers this router. Can be:
|
||||
- str: Name of the triggering method
|
||||
- dict: Dictionary with 'type' and 'methods' keys for complex conditions
|
||||
- Callable: A function reference
|
||||
|
||||
Returns:
|
||||
Callable: The decorated function that will serve as a router.
|
||||
|
||||
Raises:
|
||||
ValueError: If the condition format is invalid.
|
||||
|
||||
Example:
|
||||
>>> @router("process_data")
|
||||
>>> def route_result(result):
|
||||
>>> if result.success:
|
||||
>>> return "handle_success"
|
||||
>>> return "handle_error"
|
||||
"""
|
||||
def decorator(func):
|
||||
func.__is_router__ = True
|
||||
@@ -218,38 +174,27 @@ def router(condition: Union[str, dict, Callable]) -> Callable:
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def or_(*conditions: Union[str, dict, Callable]) -> dict:
|
||||
"""
|
||||
Combines multiple conditions with OR logic for flow control.
|
||||
|
||||
Creates a condition that is satisfied when any of the specified conditions
|
||||
are met. This is used with @start, @listen, or @router decorators to create
|
||||
complex triggering conditions.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
*conditions : Union[str, dict, Callable]
|
||||
Variable number of conditions that can be:
|
||||
- str: Method names
|
||||
- dict: Existing condition dictionaries
|
||||
- Callable: Method references
|
||||
|
||||
Returns
|
||||
-------
|
||||
dict
|
||||
A condition dictionary with format:
|
||||
{"type": "OR", "methods": list_of_method_names}
|
||||
|
||||
Raises
|
||||
------
|
||||
ValueError
|
||||
If any condition is invalid.
|
||||
|
||||
Examples
|
||||
--------
|
||||
>>> @listen(or_("success", "timeout"))
|
||||
>>> def handle_completion(self):
|
||||
... pass
|
||||
"""Combines multiple conditions with OR logic for flow control.
|
||||
|
||||
Args:
|
||||
*conditions: Variable number of conditions. Each can be:
|
||||
- str: Name of a method
|
||||
- dict: Dictionary with 'type' and 'methods' keys
|
||||
- Callable: A function reference
|
||||
|
||||
Returns:
|
||||
dict: A dictionary with 'type': 'OR' and 'methods' list.
|
||||
|
||||
Raises:
|
||||
ValueError: If any condition is invalid.
|
||||
|
||||
Example:
|
||||
>>> @listen(or_("method1", "method2"))
|
||||
>>> def on_either():
|
||||
>>> # Executes when either method1 OR method2 completes
|
||||
>>> pass
|
||||
"""
|
||||
methods = []
|
||||
for condition in conditions:
|
||||
@@ -265,37 +210,25 @@ def or_(*conditions: Union[str, dict, Callable]) -> dict:
|
||||
|
||||
|
||||
def and_(*conditions: Union[str, dict, Callable]) -> dict:
|
||||
"""
|
||||
Combines multiple conditions with AND logic for flow control.
|
||||
|
||||
Creates a condition that is satisfied only when all specified conditions
|
||||
are met. This is used with @start, @listen, or @router decorators to create
|
||||
complex triggering conditions.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
*conditions : Union[str, dict, Callable]
|
||||
Variable number of conditions that can be:
|
||||
- str: Method names
|
||||
- dict: Existing condition dictionaries
|
||||
- Callable: Method references
|
||||
|
||||
Returns
|
||||
-------
|
||||
dict
|
||||
A condition dictionary with format:
|
||||
{"type": "AND", "methods": list_of_method_names}
|
||||
|
||||
Raises
|
||||
------
|
||||
ValueError
|
||||
If any condition is invalid.
|
||||
|
||||
Examples
|
||||
--------
|
||||
>>> @listen(and_("validated", "processed"))
|
||||
>>> def handle_complete_data(self):
|
||||
... pass
|
||||
"""Combines multiple conditions with AND logic for flow control.
|
||||
|
||||
Args:
|
||||
*conditions: Variable number of conditions. Each can be:
|
||||
- str: Name of a method
|
||||
- dict: Dictionary with 'type' and 'methods' keys
|
||||
- Callable: A function reference
|
||||
|
||||
Returns:
|
||||
dict: A dictionary with 'type': 'AND' and 'methods' list.
|
||||
|
||||
Raises:
|
||||
ValueError: If any condition is invalid.
|
||||
|
||||
Example:
|
||||
>>> @listen(and_("method1", "method2"))
|
||||
>>> def on_both():
|
||||
>>> # Executes when BOTH method1 AND method2 complete
|
||||
>>> pass
|
||||
"""
|
||||
methods = []
|
||||
for condition in conditions:
|
||||
@@ -355,6 +288,22 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
||||
event_emitter = Signal("event_emitter")
|
||||
|
||||
def __class_getitem__(cls: Type["Flow"], item: Type[T]) -> Type["Flow"]:
|
||||
"""Create a generic version of Flow with specified state type.
|
||||
|
||||
Args:
|
||||
cls: The Flow class
|
||||
item: The type parameter for the flow's state
|
||||
|
||||
Returns:
|
||||
Type["Flow"]: A new Flow class with the specified state type
|
||||
|
||||
Example:
|
||||
>>> class MyState(BaseModel):
|
||||
>>> value: int
|
||||
>>>
|
||||
>>> class MyFlow(Flow[MyState]):
|
||||
>>> pass
|
||||
"""
|
||||
class _FlowGeneric(cls): # type: ignore
|
||||
_initial_state_T = item # type: ignore
|
||||
|
||||
@@ -362,11 +311,23 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
||||
return _FlowGeneric
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Initialize a new Flow instance.
|
||||
|
||||
Sets up internal state tracking, method registration, and telemetry.
|
||||
The flow's methods are automatically discovered and registered during initialization.
|
||||
|
||||
Attributes initialized:
|
||||
_methods: Dictionary mapping method names to their callable objects
|
||||
_state: The flow's state object of type T
|
||||
_method_execution_counts: Tracks how many times each method has executed
|
||||
_pending_and_listeners: Tracks methods waiting for AND conditions
|
||||
_method_outputs: List of all outputs from executed methods
|
||||
"""
|
||||
self._methods: Dict[str, Callable] = {}
|
||||
self._state: T = self._create_initial_state()
|
||||
self._method_execution_counts: Dict[str, int] = {}
|
||||
self._pending_and_listeners: Dict[str, Set[str]] = {}
|
||||
self._method_outputs: List[Any] = [] # List to store all method outputs
|
||||
self._method_outputs: List[Any] = []
|
||||
|
||||
self._telemetry.flow_creation_span(self.__class__.__name__)
|
||||
|
||||
@@ -377,6 +338,20 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
||||
self._methods[method_name] = getattr(self, method_name)
|
||||
|
||||
def _create_initial_state(self) -> T:
|
||||
"""Create the initial state for the flow.
|
||||
|
||||
The state is created based on the following priority:
|
||||
1. If initial_state is None and _initial_state_T exists (generic type), use that
|
||||
2. If initial_state is None, return empty dict
|
||||
3. If initial_state is a type, instantiate it
|
||||
4. Otherwise, use initial_state as-is
|
||||
|
||||
Returns:
|
||||
T: The initial state object of type T
|
||||
|
||||
Note:
|
||||
The type T can be either a Pydantic BaseModel or a dictionary.
|
||||
"""
|
||||
if self.initial_state is None and hasattr(self, "_initial_state_T"):
|
||||
return self._initial_state_T() # type: ignore
|
||||
if self.initial_state is None:
|
||||
@@ -388,11 +363,21 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
||||
|
||||
@property
|
||||
def state(self) -> T:
|
||||
"""Get the current state of the flow.
|
||||
|
||||
Returns:
|
||||
T: The current state object, either a Pydantic model or dictionary
|
||||
"""
|
||||
return self._state
|
||||
|
||||
@property
|
||||
def method_outputs(self) -> List[Any]:
|
||||
"""Returns the list of all outputs from executed methods."""
|
||||
"""Get the list of all outputs from executed methods.
|
||||
|
||||
Returns:
|
||||
List[Any]: A list containing the output values from all executed flow methods,
|
||||
in order of execution.
|
||||
"""
|
||||
return self._method_outputs
|
||||
|
||||
def _initialize_state(self, inputs: Dict[str, Any]) -> None:
|
||||
@@ -462,23 +447,6 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
||||
return final_output
|
||||
|
||||
async def _execute_start_method(self, start_method_name: str) -> None:
|
||||
"""
|
||||
Executes a flow's start method and its triggered listeners.
|
||||
|
||||
This internal method handles the execution of methods marked with @start
|
||||
decorator and manages the subsequent chain of listener executions.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
start_method_name : str
|
||||
The name of the start method to execute.
|
||||
|
||||
Notes
|
||||
-----
|
||||
- Executes the start method and captures its result
|
||||
- Triggers execution of any listeners waiting on this start method
|
||||
- Part of the flow's initialization sequence
|
||||
"""
|
||||
result = await self._execute_method(
|
||||
start_method_name, self._methods[start_method_name]
|
||||
)
|
||||
@@ -499,27 +467,22 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
||||
return result
|
||||
|
||||
async def _execute_listeners(self, trigger_method: str, result: Any) -> None:
|
||||
"""
|
||||
Executes all listeners and routers triggered by a method completion.
|
||||
|
||||
This internal method manages the execution flow by:
|
||||
1. First executing all triggered routers sequentially
|
||||
2. Then executing all triggered listeners in parallel
|
||||
|
||||
Parameters
|
||||
----------
|
||||
trigger_method : str
|
||||
The name of the method that triggered these listeners.
|
||||
result : Any
|
||||
The result from the triggering method, passed to listeners
|
||||
that accept parameters.
|
||||
|
||||
Notes
|
||||
-----
|
||||
- Routers are executed sequentially to maintain flow control
|
||||
- Each router's result becomes the new trigger_method
|
||||
- Normal listeners are executed in parallel for efficiency
|
||||
- Listeners can receive the trigger method's result as a parameter
|
||||
"""Execute all listener methods triggered by a completed method.
|
||||
|
||||
This method handles both router and non-router listeners in a specific order:
|
||||
1. First executes all triggered router methods sequentially until no more routers
|
||||
are triggered
|
||||
2. Then executes all regular listeners in parallel
|
||||
|
||||
Args:
|
||||
trigger_method: The name of the method that completed execution
|
||||
result: The result value from the triggering method
|
||||
|
||||
Note:
|
||||
Router methods are executed sequentially to ensure proper flow control,
|
||||
while regular listeners are executed concurrently for better performance.
|
||||
This provides fine-grained control over the execution flow while
|
||||
maintaining efficiency.
|
||||
"""
|
||||
# First, handle routers repeatedly until no router triggers anymore
|
||||
while True:
|
||||
@@ -550,32 +513,26 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
||||
def _find_triggered_methods(
|
||||
self, trigger_method: str, router_only: bool
|
||||
) -> List[str]:
|
||||
"""
|
||||
Finds all methods that should be triggered based on conditions.
|
||||
|
||||
This internal method evaluates both OR and AND conditions to determine
|
||||
which methods should be executed next in the flow.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
trigger_method : str
|
||||
The name of the method that just completed execution.
|
||||
router_only : bool
|
||||
If True, only consider router methods.
|
||||
If False, only consider non-router methods.
|
||||
|
||||
Returns
|
||||
-------
|
||||
List[str]
|
||||
Names of methods that should be triggered.
|
||||
|
||||
Notes
|
||||
-----
|
||||
- Handles both OR and AND conditions:
|
||||
* OR: Triggers if any condition is met
|
||||
* AND: Triggers only when all conditions are met
|
||||
- Maintains state for AND conditions using _pending_and_listeners
|
||||
- Separates router and normal listener evaluation
|
||||
"""Find all methods that should be triggered based on completed method and type.
|
||||
|
||||
Provides precise control over method triggering by handling both OR and AND
|
||||
conditions separately for router and non-router methods.
|
||||
|
||||
Args:
|
||||
trigger_method: The name of the method that completed execution
|
||||
router_only: If True, only find router methods; if False, only regular
|
||||
listeners
|
||||
|
||||
Returns:
|
||||
List[str]: Names of methods that should be executed next
|
||||
|
||||
Note:
|
||||
This method implements sophisticated flow control by:
|
||||
1. Filtering methods based on their router/non-router status
|
||||
2. Handling OR conditions for immediate triggering
|
||||
3. Managing AND conditions with state tracking for complex dependencies
|
||||
|
||||
This ensures predictable and consistent execution order in complex flows.
|
||||
"""
|
||||
triggered = []
|
||||
for listener_name, (condition_type, methods) in self._listeners.items():
|
||||
@@ -605,32 +562,26 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
||||
return triggered
|
||||
|
||||
async def _execute_single_listener(self, listener_name: str, result: Any) -> None:
|
||||
"""
|
||||
Executes a single listener method with proper event handling.
|
||||
|
||||
This internal method manages the execution of an individual listener,
|
||||
including parameter inspection, event emission, and error handling.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
listener_name : str
|
||||
The name of the listener method to execute.
|
||||
result : Any
|
||||
The result from the triggering method, which may be passed
|
||||
to the listener if it accepts parameters.
|
||||
|
||||
Notes
|
||||
-----
|
||||
- Inspects method signature to determine if it accepts the trigger result
|
||||
- Emits events for method execution start and finish
|
||||
- Handles errors gracefully with detailed logging
|
||||
- Recursively triggers listeners of this listener
|
||||
- Supports both parameterized and parameter-less listeners
|
||||
|
||||
Error Handling
|
||||
-------------
|
||||
Catches and logs any exceptions during execution, preventing
|
||||
individual listener failures from breaking the entire flow.
|
||||
"""Execute a single listener method with precise parameter handling and error tracking.
|
||||
|
||||
Provides fine-grained control over method execution through:
|
||||
1. Automatic parameter inspection to determine if the method accepts results
|
||||
2. Event emission for execution tracking
|
||||
3. Comprehensive error handling
|
||||
4. Recursive listener execution
|
||||
|
||||
Args:
|
||||
listener_name: The name of the listener method to execute
|
||||
result: The result from the triggering method, passed to the listener
|
||||
if its signature accepts parameters
|
||||
|
||||
Note:
|
||||
This method ensures precise execution control by:
|
||||
- Inspecting method signatures to handle parameters correctly
|
||||
- Emitting events for execution tracking
|
||||
- Providing comprehensive error handling
|
||||
- Supporting both parameterized and parameter-less methods
|
||||
- Maintaining execution chain through recursive listener calls
|
||||
"""
|
||||
try:
|
||||
method = self._methods[listener_name]
|
||||
@@ -675,8 +626,32 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
||||
|
||||
traceback.print_exc()
|
||||
|
||||
def plot(self, filename: str = "crewai_flow") -> None:
|
||||
def plot(self, *args, **kwargs):
|
||||
"""Generate an interactive visualization of the flow's execution graph.
|
||||
|
||||
Creates a detailed HTML visualization showing the relationships between
|
||||
methods, including start points, listeners, routers, and their
|
||||
connections. Includes telemetry tracking for flow analysis.
|
||||
|
||||
Args:
|
||||
*args: Variable length argument list passed to plot_flow
|
||||
**kwargs: Arbitrary keyword arguments passed to plot_flow
|
||||
|
||||
Note:
|
||||
The visualization provides:
|
||||
- Clear representation of method relationships
|
||||
- Visual distinction between different method types
|
||||
- Interactive exploration capabilities
|
||||
- Execution path tracing
|
||||
- Telemetry tracking for flow analysis
|
||||
|
||||
Example:
|
||||
>>> flow = MyFlow()
|
||||
>>> flow.plot("my_workflow") # Creates my_workflow.html
|
||||
"""
|
||||
from crewai.flow.flow_visualizer import plot_flow
|
||||
|
||||
self._telemetry.flow_plotting_span(
|
||||
self.__class__.__name__, list(self._methods.keys())
|
||||
)
|
||||
plot_flow(self, filename)
|
||||
return plot_flow(self, *args, **kwargs)
|
||||
|
||||
240
src/crewai/flow/flow_visual_utils.py
Normal file
240
src/crewai/flow/flow_visual_utils.py
Normal file
@@ -0,0 +1,240 @@
|
||||
"""Utility functions for Flow visualization.
|
||||
|
||||
This module contains utility functions specifically designed for visualizing
|
||||
Flow graphs and calculating layout information. These utilities are separated
|
||||
from general-purpose utilities to maintain a clean dependency structure.
|
||||
"""
|
||||
|
||||
from typing import Any, Dict, List, Set, TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from crewai.flow.flow import Flow
|
||||
|
||||
|
||||
def calculate_node_levels(flow: Flow[Any]) -> Dict[str, int]:
|
||||
"""Calculate the hierarchical level of each node in the flow graph.
|
||||
|
||||
Uses breadth-first traversal to assign levels to nodes, starting with
|
||||
start methods at level 0. Handles both OR and AND conditions for listeners,
|
||||
and considers router paths when calculating levels.
|
||||
|
||||
Args:
|
||||
flow: Flow instance containing methods, listeners, and router configurations
|
||||
|
||||
Returns:
|
||||
dict[str, int]: Dictionary mapping method names to their hierarchical levels,
|
||||
where level 0 contains start methods and each subsequent level contains
|
||||
methods triggered by the previous level
|
||||
|
||||
Example:
|
||||
>>> flow = Flow()
|
||||
>>> @flow.start
|
||||
... def start(): pass
|
||||
>>> @flow.on("start")
|
||||
... def second(): pass
|
||||
>>> calculate_node_levels(flow)
|
||||
{'start': 0, 'second': 1}
|
||||
"""
|
||||
levels: Dict[str, int] = {}
|
||||
queue: List[str] = []
|
||||
visited: Set[str] = set()
|
||||
pending_and_listeners: Dict[str, Set[str]] = {}
|
||||
|
||||
# Make all start methods at level 0
|
||||
for method_name, method in flow._methods.items():
|
||||
if hasattr(method, "__is_start_method__"):
|
||||
levels[method_name] = 0
|
||||
queue.append(method_name)
|
||||
|
||||
# Breadth-first traversal to assign levels
|
||||
while queue:
|
||||
current = queue.pop(0)
|
||||
current_level = levels[current]
|
||||
visited.add(current)
|
||||
|
||||
for listener_name, (condition_type, trigger_methods) in flow._listeners.items():
|
||||
if condition_type == "OR":
|
||||
if current in trigger_methods:
|
||||
if (
|
||||
listener_name not in levels
|
||||
or levels[listener_name] > current_level + 1
|
||||
):
|
||||
levels[listener_name] = current_level + 1
|
||||
if listener_name not in visited:
|
||||
queue.append(listener_name)
|
||||
elif condition_type == "AND":
|
||||
if listener_name not in pending_and_listeners:
|
||||
pending_and_listeners[listener_name] = set()
|
||||
if current in trigger_methods:
|
||||
pending_and_listeners[listener_name].add(current)
|
||||
if set(trigger_methods) == pending_and_listeners[listener_name]:
|
||||
if (
|
||||
listener_name not in levels
|
||||
or levels[listener_name] > current_level + 1
|
||||
):
|
||||
levels[listener_name] = current_level + 1
|
||||
if listener_name not in visited:
|
||||
queue.append(listener_name)
|
||||
|
||||
# Handle router connections
|
||||
if current in flow._routers:
|
||||
router_method_name = current
|
||||
paths = flow._router_paths.get(router_method_name, [])
|
||||
for path in paths:
|
||||
for listener_name, (
|
||||
condition_type,
|
||||
trigger_methods,
|
||||
) in flow._listeners.items():
|
||||
if path in trigger_methods:
|
||||
if (
|
||||
listener_name not in levels
|
||||
or levels[listener_name] > current_level + 1
|
||||
):
|
||||
levels[listener_name] = current_level + 1
|
||||
if listener_name not in visited:
|
||||
queue.append(listener_name)
|
||||
|
||||
return levels
|
||||
|
||||
|
||||
def count_outgoing_edges(flow: Flow[Any]) -> Dict[str, int]:
|
||||
"""Count the number of outgoing edges for each node in the flow graph.
|
||||
|
||||
An outgoing edge represents a connection from a method to a listener
|
||||
that it triggers. This is useful for visualization and analysis of
|
||||
flow structure.
|
||||
|
||||
Args:
|
||||
flow: Flow instance containing methods and their connections
|
||||
|
||||
Returns:
|
||||
dict[str, int]: Dictionary mapping method names to their number
|
||||
of outgoing connections
|
||||
"""
|
||||
counts: Dict[str, int] = {}
|
||||
for method_name in flow._methods:
|
||||
counts[method_name] = 0
|
||||
for method_name in flow._listeners:
|
||||
_, trigger_methods = flow._listeners[method_name]
|
||||
for trigger in trigger_methods:
|
||||
if trigger in flow._methods:
|
||||
counts[trigger] += 1
|
||||
return counts
|
||||
|
||||
|
||||
def build_ancestor_dict(flow: Flow[Any]) -> Dict[str, Set[str]]:
|
||||
"""Build a dictionary mapping each node to its set of ancestor nodes.
|
||||
|
||||
Uses depth-first search to identify all ancestors (direct and indirect
|
||||
trigger methods) for each node in the flow graph. Handles both regular
|
||||
listeners and router paths.
|
||||
|
||||
Args:
|
||||
flow: Flow instance containing methods and their relationships
|
||||
|
||||
Returns:
|
||||
dict[str, set[str]]: Dictionary mapping each method name to a set
|
||||
of its ancestor method names
|
||||
"""
|
||||
ancestors: Dict[str, Set[str]] = {node: set() for node in flow._methods}
|
||||
visited: Set[str] = set()
|
||||
for node in flow._methods:
|
||||
if node not in visited:
|
||||
dfs_ancestors(node, ancestors, visited, flow)
|
||||
return ancestors
|
||||
|
||||
|
||||
|
||||
|
||||
def dfs_ancestors(node: str, ancestors: Dict[str, Set[str]],
|
||||
visited: Set[str], flow: Flow[Any]) -> None:
|
||||
"""Perform depth-first search to populate the ancestors dictionary.
|
||||
|
||||
Helper function for build_ancestor_dict that recursively traverses
|
||||
the flow graph to identify ancestors of each node.
|
||||
|
||||
Args:
|
||||
node: Current node being processed
|
||||
ancestors: Dictionary mapping nodes to their ancestor sets
|
||||
visited: Set of already visited nodes
|
||||
flow: Flow instance containing the graph structure
|
||||
"""
|
||||
if node in visited:
|
||||
return
|
||||
visited.add(node)
|
||||
|
||||
# Handle regular listeners
|
||||
for listener_name, (_, trigger_methods) in flow._listeners.items():
|
||||
if node in trigger_methods:
|
||||
ancestors[listener_name].add(node)
|
||||
ancestors[listener_name].update(ancestors[node])
|
||||
dfs_ancestors(listener_name, ancestors, visited, flow)
|
||||
|
||||
# Handle router methods separately
|
||||
if node in flow._routers:
|
||||
router_method_name = node
|
||||
paths = flow._router_paths.get(router_method_name, [])
|
||||
for path in paths:
|
||||
for listener_name, (_, trigger_methods) in flow._listeners.items():
|
||||
if path in trigger_methods:
|
||||
# Only propagate the ancestors of the router method, not the router method itself
|
||||
ancestors[listener_name].update(ancestors[node])
|
||||
dfs_ancestors(listener_name, ancestors, visited, flow)
|
||||
|
||||
|
||||
def build_parent_children_dict(flow: Flow[Any]) -> Dict[str, List[str]]:
|
||||
"""Build a dictionary mapping each node to its list of child nodes.
|
||||
|
||||
Maps both regular trigger methods to their listeners and router
|
||||
methods to their path listeners. Useful for visualization and
|
||||
traversal of the flow graph structure.
|
||||
|
||||
Args:
|
||||
flow: Flow instance containing methods and their relationships
|
||||
|
||||
Returns:
|
||||
dict[str, list[str]]: Dictionary mapping each method name to a
|
||||
sorted list of its child method names
|
||||
"""
|
||||
parent_children: Dict[str, List[str]] = {}
|
||||
|
||||
# Map listeners to their trigger methods
|
||||
for listener_name, (_, trigger_methods) in flow._listeners.items():
|
||||
for trigger in trigger_methods:
|
||||
if trigger not in parent_children:
|
||||
parent_children[trigger] = []
|
||||
if listener_name not in parent_children[trigger]:
|
||||
parent_children[trigger].append(listener_name)
|
||||
|
||||
# Map router methods to their paths and to listeners
|
||||
for router_method_name, paths in flow._router_paths.items():
|
||||
for path in paths:
|
||||
# Map router method to listeners of each path
|
||||
for listener_name, (_, trigger_methods) in flow._listeners.items():
|
||||
if path in trigger_methods:
|
||||
if router_method_name not in parent_children:
|
||||
parent_children[router_method_name] = []
|
||||
if listener_name not in parent_children[router_method_name]:
|
||||
parent_children[router_method_name].append(listener_name)
|
||||
|
||||
return parent_children
|
||||
|
||||
|
||||
def get_child_index(parent: str, child: str,
|
||||
parent_children: Dict[str, List[str]]) -> int:
|
||||
"""Get the index of a child node in its parent's sorted children list.
|
||||
|
||||
Args:
|
||||
parent: Parent node name
|
||||
child: Child node name to find index for
|
||||
parent_children: Dictionary mapping parents to their children lists
|
||||
|
||||
Returns:
|
||||
int: Zero-based index of the child in parent's sorted children list
|
||||
|
||||
Raises:
|
||||
ValueError: If child is not found in parent's children list
|
||||
"""
|
||||
children = parent_children.get(parent, [])
|
||||
children.sort()
|
||||
return children.index(child)
|
||||
@@ -6,10 +6,10 @@ from pathlib import Path
|
||||
from pyvis.network import Network
|
||||
|
||||
from crewai.flow.config import COLORS, NODE_STYLES
|
||||
from crewai.flow.flow_visual_utils import calculate_node_levels
|
||||
from crewai.flow.html_template_handler import HTMLTemplateHandler
|
||||
from crewai.flow.legend_generator import generate_legend_items_html, get_legend_items
|
||||
from crewai.flow.path_utils import safe_path_join, validate_path_exists
|
||||
from crewai.flow.utils import calculate_node_levels
|
||||
from crewai.flow.path_utils import safe_path_join, validate_file_path
|
||||
from crewai.flow.visualization_utils import (
|
||||
add_edges,
|
||||
add_nodes_to_network,
|
||||
@@ -21,206 +21,102 @@ class FlowPlot:
|
||||
"""Handles the creation and rendering of flow visualization diagrams."""
|
||||
|
||||
def __init__(self, flow):
|
||||
"""
|
||||
Initialize FlowPlot with a flow object.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
flow : Flow
|
||||
A Flow instance to visualize.
|
||||
|
||||
Raises
|
||||
------
|
||||
ValueError
|
||||
If flow object is invalid or missing required attributes.
|
||||
"""Initialize flow plot with flow instance and styling configuration.
|
||||
|
||||
Args:
|
||||
flow: A Flow instance with required attributes for visualization
|
||||
|
||||
Raises:
|
||||
ValueError: If flow object is invalid or missing required attributes
|
||||
"""
|
||||
if not hasattr(flow, '_methods'):
|
||||
raise ValueError("Invalid flow object: missing '_methods' attribute")
|
||||
if not hasattr(flow, '_listeners'):
|
||||
raise ValueError("Invalid flow object: missing '_listeners' attribute")
|
||||
raise ValueError("Invalid flow object: Missing '_methods' attribute")
|
||||
if not hasattr(flow, '_start_methods'):
|
||||
raise ValueError("Invalid flow object: missing '_start_methods' attribute")
|
||||
raise ValueError("Invalid flow object: Missing '_start_methods' attribute")
|
||||
if not hasattr(flow, '_listeners'):
|
||||
raise ValueError("Invalid flow object: Missing '_listeners' attribute")
|
||||
|
||||
self.flow = flow
|
||||
self.colors = COLORS
|
||||
self.node_styles = NODE_STYLES
|
||||
|
||||
def plot(self, filename):
|
||||
"""
|
||||
Generate and save an HTML visualization of the flow.
|
||||
"""Generate and save interactive flow visualization to HTML file."""
|
||||
net = Network(
|
||||
directed=True,
|
||||
height="750px",
|
||||
width="100%",
|
||||
bgcolor=self.colors["bg"],
|
||||
layout=None,
|
||||
)
|
||||
|
||||
Parameters
|
||||
----------
|
||||
filename : str
|
||||
Name of the output file (without extension).
|
||||
|
||||
Raises
|
||||
------
|
||||
ValueError
|
||||
If filename is invalid or network generation fails.
|
||||
IOError
|
||||
If file operations fail or visualization cannot be generated.
|
||||
RuntimeError
|
||||
If network visualization generation fails.
|
||||
"""
|
||||
if not filename or not isinstance(filename, str):
|
||||
raise ValueError("Filename must be a non-empty string")
|
||||
|
||||
try:
|
||||
# Initialize network
|
||||
net = Network(
|
||||
directed=True,
|
||||
height="750px",
|
||||
width="100%",
|
||||
bgcolor=self.colors["bg"],
|
||||
layout=None,
|
||||
)
|
||||
|
||||
# Set options to disable physics
|
||||
net.set_options(
|
||||
"""
|
||||
var options = {
|
||||
"nodes": {
|
||||
"font": {
|
||||
"multi": "html"
|
||||
}
|
||||
},
|
||||
"physics": {
|
||||
"enabled": false
|
||||
}
|
||||
}
|
||||
# Set options to disable physics
|
||||
net.set_options(
|
||||
"""
|
||||
)
|
||||
var options = {
|
||||
"nodes": {
|
||||
"font": {
|
||||
"multi": "html"
|
||||
}
|
||||
},
|
||||
"physics": {
|
||||
"enabled": false
|
||||
}
|
||||
}
|
||||
"""
|
||||
)
|
||||
|
||||
# Calculate levels for nodes
|
||||
try:
|
||||
node_levels = calculate_node_levels(self.flow)
|
||||
except Exception as e:
|
||||
raise ValueError(f"Failed to calculate node levels: {str(e)}")
|
||||
node_levels = calculate_node_levels(self.flow)
|
||||
node_positions = compute_positions(self.flow, node_levels)
|
||||
add_nodes_to_network(net, self.flow, node_positions, self.node_styles)
|
||||
add_edges(net, self.flow, node_positions, self.colors)
|
||||
|
||||
# Compute positions
|
||||
try:
|
||||
node_positions = compute_positions(self.flow, node_levels)
|
||||
except Exception as e:
|
||||
raise ValueError(f"Failed to compute node positions: {str(e)}")
|
||||
network_html = net.generate_html()
|
||||
final_html_content = self._generate_final_html(network_html)
|
||||
|
||||
# Add nodes to the network
|
||||
try:
|
||||
add_nodes_to_network(net, self.flow, node_positions, self.node_styles)
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Failed to add nodes to network: {str(e)}")
|
||||
try:
|
||||
# Ensure the output path is safe
|
||||
output_dir = os.getcwd()
|
||||
output_path = safe_path_join(output_dir, f"{filename}.html")
|
||||
|
||||
with open(output_path, "w", encoding="utf-8") as f:
|
||||
f.write(final_html_content)
|
||||
print(f"Plot saved as {output_path}")
|
||||
except (IOError, ValueError) as e:
|
||||
raise IOError(f"Failed to save flow visualization: {str(e)}")
|
||||
|
||||
# Add edges to the network
|
||||
try:
|
||||
add_edges(net, self.flow, node_positions, self.colors)
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Failed to add edges to network: {str(e)}")
|
||||
|
||||
# Generate HTML
|
||||
try:
|
||||
network_html = net.generate_html()
|
||||
final_html_content = self._generate_final_html(network_html)
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Failed to generate network visualization: {str(e)}")
|
||||
|
||||
# Save the final HTML content to the file
|
||||
try:
|
||||
with open(f"{filename}.html", "w", encoding="utf-8") as f:
|
||||
f.write(final_html_content)
|
||||
print(f"Plot saved as {filename}.html")
|
||||
except IOError as e:
|
||||
raise IOError(f"Failed to save flow visualization to {filename}.html: {str(e)}")
|
||||
|
||||
except (ValueError, RuntimeError, IOError) as e:
|
||||
raise e
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Unexpected error during flow visualization: {str(e)}")
|
||||
finally:
|
||||
self._cleanup_pyvis_lib()
|
||||
self._cleanup_pyvis_lib()
|
||||
|
||||
def _generate_final_html(self, network_html):
|
||||
"""
|
||||
Generate the final HTML content with network visualization and legend.
|
||||
"""Generate final HTML content with network visualization and legend."""
|
||||
current_dir = os.path.dirname(__file__)
|
||||
template_path = os.path.join(
|
||||
current_dir, "assets", "crewai_flow_visual_template.html"
|
||||
)
|
||||
logo_path = os.path.join(current_dir, "assets", "crewai_logo.svg")
|
||||
|
||||
Parameters
|
||||
----------
|
||||
network_html : str
|
||||
HTML content generated by pyvis Network.
|
||||
html_handler = HTMLTemplateHandler(template_path, logo_path)
|
||||
network_body = html_handler.extract_body_content(network_html)
|
||||
|
||||
Returns
|
||||
-------
|
||||
str
|
||||
Complete HTML content with styling and legend.
|
||||
|
||||
Raises
|
||||
------
|
||||
IOError
|
||||
If template or logo files cannot be accessed.
|
||||
ValueError
|
||||
If network_html is invalid.
|
||||
"""
|
||||
if not network_html:
|
||||
raise ValueError("Invalid network HTML content")
|
||||
|
||||
try:
|
||||
# Extract just the body content from the generated HTML
|
||||
current_dir = os.path.dirname(__file__)
|
||||
template_path = safe_path_join("assets", "crewai_flow_visual_template.html", root=current_dir)
|
||||
logo_path = safe_path_join("assets", "crewai_logo.svg", root=current_dir)
|
||||
|
||||
if not os.path.exists(template_path):
|
||||
raise IOError(f"Template file not found: {template_path}")
|
||||
if not os.path.exists(logo_path):
|
||||
raise IOError(f"Logo file not found: {logo_path}")
|
||||
|
||||
html_handler = HTMLTemplateHandler(template_path, logo_path)
|
||||
network_body = html_handler.extract_body_content(network_html)
|
||||
|
||||
# Generate the legend items HTML
|
||||
legend_items = get_legend_items(self.colors)
|
||||
legend_items_html = generate_legend_items_html(legend_items)
|
||||
final_html_content = html_handler.generate_final_html(
|
||||
network_body, legend_items_html
|
||||
)
|
||||
return final_html_content
|
||||
except Exception as e:
|
||||
raise IOError(f"Failed to generate visualization HTML: {str(e)}")
|
||||
legend_items = get_legend_items(self.colors)
|
||||
legend_items_html = generate_legend_items_html(legend_items)
|
||||
final_html_content = html_handler.generate_final_html(
|
||||
network_body, legend_items_html
|
||||
)
|
||||
return final_html_content
|
||||
|
||||
def _cleanup_pyvis_lib(self):
|
||||
"""
|
||||
Clean up the generated lib folder from pyvis.
|
||||
|
||||
This method safely removes the temporary lib directory created by pyvis
|
||||
during network visualization generation.
|
||||
"""
|
||||
"""Clean up temporary files generated by pyvis library."""
|
||||
lib_folder = os.path.join(os.getcwd(), "lib")
|
||||
try:
|
||||
lib_folder = safe_path_join("lib", root=os.getcwd())
|
||||
if os.path.exists(lib_folder) and os.path.isdir(lib_folder):
|
||||
import shutil
|
||||
shutil.rmtree(lib_folder)
|
||||
except ValueError as e:
|
||||
print(f"Error validating lib folder path: {e}")
|
||||
except Exception as e:
|
||||
print(f"Error cleaning up lib folder: {e}")
|
||||
print(f"Error cleaning up {lib_folder}: {e}")
|
||||
|
||||
|
||||
def plot_flow(flow, filename="flow_plot"):
|
||||
"""
|
||||
Convenience function to create and save a flow visualization.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
flow : Flow
|
||||
Flow instance to visualize.
|
||||
filename : str, optional
|
||||
Output filename without extension, by default "flow_plot".
|
||||
|
||||
Raises
|
||||
------
|
||||
ValueError
|
||||
If flow object or filename is invalid.
|
||||
IOError
|
||||
If file operations fail.
|
||||
"""
|
||||
"""Create and save a visualization of the given flow."""
|
||||
visualizer = FlowPlot(flow)
|
||||
visualizer.plot(filename)
|
||||
|
||||
@@ -1,55 +1,107 @@
|
||||
import base64
|
||||
import os
|
||||
import re
|
||||
from pathlib import Path
|
||||
|
||||
from crewai.flow.path_utils import safe_path_join, validate_path_exists
|
||||
from crewai.flow.path_utils import safe_path_join, validate_file_path
|
||||
|
||||
|
||||
class HTMLTemplateHandler:
|
||||
"""Handles HTML template processing and generation for flow visualization diagrams."""
|
||||
|
||||
def __init__(self, template_path, logo_path):
|
||||
"""
|
||||
Initialize HTMLTemplateHandler with validated template and logo paths.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
template_path : str
|
||||
Path to the HTML template file.
|
||||
logo_path : str
|
||||
Path to the logo image file.
|
||||
|
||||
Raises
|
||||
------
|
||||
ValueError
|
||||
If template or logo paths are invalid or files don't exist.
|
||||
"""Initialize template handler with template and logo file paths.
|
||||
|
||||
Args:
|
||||
template_path: Path to the HTML template file
|
||||
logo_path: Path to the logo SVG file
|
||||
|
||||
Raises:
|
||||
ValueError: If template_path or logo_path is invalid or files don't exist
|
||||
"""
|
||||
try:
|
||||
self.template_path = validate_path_exists(template_path, "file")
|
||||
self.logo_path = validate_path_exists(logo_path, "file")
|
||||
except ValueError as e:
|
||||
raise ValueError(f"Invalid template or logo path: {e}")
|
||||
self.template_path = validate_file_path(template_path)
|
||||
self.logo_path = validate_file_path(logo_path)
|
||||
except (ValueError, TypeError) as e:
|
||||
raise ValueError(f"Invalid file path: {str(e)}")
|
||||
|
||||
def read_template(self):
|
||||
"""Read and return the HTML template file contents."""
|
||||
with open(self.template_path, "r", encoding="utf-8") as f:
|
||||
return f.read()
|
||||
"""Read and return the HTML template file contents.
|
||||
|
||||
Returns:
|
||||
str: The contents of the template file
|
||||
|
||||
Raises:
|
||||
IOError: If template file cannot be read
|
||||
"""
|
||||
try:
|
||||
with open(self.template_path, "r", encoding="utf-8") as f:
|
||||
return f.read()
|
||||
except IOError as e:
|
||||
raise IOError(f"Failed to read template file {self.template_path}: {str(e)}")
|
||||
|
||||
def encode_logo(self):
|
||||
"""Convert the logo SVG file to base64 encoded string."""
|
||||
with open(self.logo_path, "rb") as logo_file:
|
||||
logo_svg_data = logo_file.read()
|
||||
return base64.b64encode(logo_svg_data).decode("utf-8")
|
||||
"""Convert the logo SVG file to base64 encoded string.
|
||||
|
||||
Returns:
|
||||
str: Base64 encoded logo data
|
||||
|
||||
Raises:
|
||||
IOError: If logo file cannot be read
|
||||
ValueError: If logo data cannot be encoded
|
||||
"""
|
||||
try:
|
||||
with open(self.logo_path, "rb") as logo_file:
|
||||
logo_svg_data = logo_file.read()
|
||||
try:
|
||||
return base64.b64encode(logo_svg_data).decode("utf-8")
|
||||
except Exception as e:
|
||||
raise ValueError(f"Failed to encode logo data: {str(e)}")
|
||||
except IOError as e:
|
||||
raise IOError(f"Failed to read logo file {self.logo_path}: {str(e)}")
|
||||
|
||||
def extract_body_content(self, html):
|
||||
"""Extract and return content between body tags from HTML string."""
|
||||
"""Extract and return content between body tags from HTML string.
|
||||
|
||||
Args:
|
||||
html: HTML string to extract body content from
|
||||
|
||||
Returns:
|
||||
str: Content between body tags, or empty string if not found
|
||||
|
||||
Raises:
|
||||
ValueError: If input HTML is invalid
|
||||
"""
|
||||
if not html or not isinstance(html, str):
|
||||
raise ValueError("Input HTML must be a non-empty string")
|
||||
|
||||
match = re.search("<body.*?>(.*?)</body>", html, re.DOTALL)
|
||||
return match.group(1) if match else ""
|
||||
|
||||
def generate_legend_items_html(self, legend_items):
|
||||
"""Generate HTML markup for the legend items."""
|
||||
"""Generate HTML markup for the legend items.
|
||||
|
||||
Args:
|
||||
legend_items: List of dictionaries containing legend item properties
|
||||
|
||||
Returns:
|
||||
str: Generated HTML markup for legend items
|
||||
|
||||
Raises:
|
||||
ValueError: If legend_items is invalid or missing required properties
|
||||
"""
|
||||
if not isinstance(legend_items, list):
|
||||
raise ValueError("legend_items must be a list")
|
||||
|
||||
legend_items_html = ""
|
||||
for item in legend_items:
|
||||
if not isinstance(item, dict):
|
||||
raise ValueError("Each legend item must be a dictionary")
|
||||
if "color" not in item:
|
||||
raise ValueError("Each legend item must have a 'color' property")
|
||||
if "label" not in item:
|
||||
raise ValueError("Each legend item must have a 'label' property")
|
||||
|
||||
if "border" in item:
|
||||
legend_items_html += f"""
|
||||
<div class="legend-item">
|
||||
@@ -75,19 +127,42 @@ class HTMLTemplateHandler:
|
||||
return legend_items_html
|
||||
|
||||
def generate_final_html(self, network_body, legend_items_html, title="Flow Plot"):
|
||||
"""Combine all components into final HTML document with network visualization."""
|
||||
html_template = self.read_template()
|
||||
logo_svg_base64 = self.encode_logo()
|
||||
"""Combine all components into final HTML document with network visualization.
|
||||
|
||||
Args:
|
||||
network_body: HTML string containing network visualization
|
||||
legend_items_html: HTML string containing legend items markup
|
||||
title: Title for the visualization page (default: "Flow Plot")
|
||||
|
||||
Returns:
|
||||
str: Complete HTML document with all components integrated
|
||||
|
||||
Raises:
|
||||
ValueError: If any input parameters are invalid
|
||||
IOError: If template or logo files cannot be read
|
||||
"""
|
||||
if not isinstance(network_body, str):
|
||||
raise ValueError("network_body must be a string")
|
||||
if not isinstance(legend_items_html, str):
|
||||
raise ValueError("legend_items_html must be a string")
|
||||
if not isinstance(title, str):
|
||||
raise ValueError("title must be a string")
|
||||
|
||||
try:
|
||||
html_template = self.read_template()
|
||||
logo_svg_base64 = self.encode_logo()
|
||||
|
||||
final_html_content = html_template.replace("{{ title }}", title)
|
||||
final_html_content = final_html_content.replace(
|
||||
"{{ network_content }}", network_body
|
||||
)
|
||||
final_html_content = final_html_content.replace(
|
||||
"{{ logo_svg_base64 }}", logo_svg_base64
|
||||
)
|
||||
final_html_content = final_html_content.replace(
|
||||
"<!-- LEGEND_ITEMS_PLACEHOLDER -->", legend_items_html
|
||||
)
|
||||
final_html_content = html_template.replace("{{ title }}", title)
|
||||
final_html_content = final_html_content.replace(
|
||||
"{{ network_content }}", network_body
|
||||
)
|
||||
final_html_content = final_html_content.replace(
|
||||
"{{ logo_svg_base64 }}", logo_svg_base64
|
||||
)
|
||||
final_html_content = final_html_content.replace(
|
||||
"<!-- LEGEND_ITEMS_PLACEHOLDER -->", legend_items_html
|
||||
)
|
||||
|
||||
return final_html_content
|
||||
return final_html_content
|
||||
except Exception as e:
|
||||
raise ValueError(f"Failed to generate final HTML: {str(e)}")
|
||||
|
||||
@@ -1,135 +1,123 @@
|
||||
"""
|
||||
Path utilities for secure file operations in CrewAI flow module.
|
||||
"""Utilities for safe path handling in flow visualization.
|
||||
|
||||
This module provides utilities for secure path handling to prevent directory
|
||||
traversal attacks and ensure paths remain within allowed boundaries.
|
||||
This module provides a comprehensive set of utilities for secure path handling,
|
||||
including path joining, validation, and normalization. It helps prevent common
|
||||
security issues like directory traversal attacks while providing a consistent
|
||||
interface for path operations.
|
||||
"""
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import List, Union
|
||||
from typing import List, Optional, Union
|
||||
|
||||
|
||||
def safe_path_join(*parts: str, root: Union[str, Path, None] = None) -> str:
|
||||
def safe_path_join(base_dir: Union[str, Path], filename: str) -> str:
|
||||
"""Safely join base directory with filename, preventing directory traversal.
|
||||
|
||||
Args:
|
||||
base_dir: Base directory path
|
||||
filename: Filename or path to join with base_dir
|
||||
|
||||
Returns:
|
||||
str: Safely joined absolute path
|
||||
|
||||
Raises:
|
||||
ValueError: If resulting path would escape base_dir or contains dangerous patterns
|
||||
TypeError: If inputs are not strings or Path objects
|
||||
OSError: If path resolution fails
|
||||
"""
|
||||
Safely join path components and ensure the result is within allowed boundaries.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
*parts : str
|
||||
Variable number of path components to join.
|
||||
root : Union[str, Path, None], optional
|
||||
Root directory to use as base. If None, uses current working directory.
|
||||
|
||||
Returns
|
||||
-------
|
||||
str
|
||||
String representation of the resolved path.
|
||||
|
||||
Raises
|
||||
------
|
||||
ValueError
|
||||
If the resulting path would be outside the root directory
|
||||
or if any path component is invalid.
|
||||
"""
|
||||
if not parts:
|
||||
raise ValueError("No path components provided")
|
||||
|
||||
if not isinstance(base_dir, (str, Path)):
|
||||
raise TypeError("base_dir must be a string or Path object")
|
||||
if not isinstance(filename, str):
|
||||
raise TypeError("filename must be a string")
|
||||
|
||||
# Check for dangerous patterns
|
||||
dangerous_patterns = ['..', '~', '*', '?', '|', '>', '<', '$', '&', '`']
|
||||
if any(pattern in filename for pattern in dangerous_patterns):
|
||||
raise ValueError(f"Invalid filename: Contains dangerous pattern")
|
||||
|
||||
try:
|
||||
# Convert all parts to strings and clean them
|
||||
clean_parts = [str(part).strip() for part in parts if part]
|
||||
if not clean_parts:
|
||||
raise ValueError("No valid path components provided")
|
||||
|
||||
# Establish root directory
|
||||
root_path = Path(root).resolve() if root else Path.cwd()
|
||||
base_path = Path(base_dir).resolve(strict=True)
|
||||
full_path = Path(base_path, filename).resolve(strict=True)
|
||||
|
||||
# Join and resolve the full path
|
||||
full_path = Path(root_path, *clean_parts).resolve()
|
||||
|
||||
# Check if the resolved path is within root
|
||||
if not str(full_path).startswith(str(root_path)):
|
||||
if not str(full_path).startswith(str(base_path)):
|
||||
raise ValueError(
|
||||
f"Invalid path: Potential directory traversal. Path must be within {root_path}"
|
||||
f"Invalid path: {filename} would escape base directory {base_dir}"
|
||||
)
|
||||
|
||||
return str(full_path)
|
||||
|
||||
except OSError as e:
|
||||
raise OSError(f"Failed to resolve path: {str(e)}")
|
||||
except Exception as e:
|
||||
if isinstance(e, ValueError):
|
||||
raise
|
||||
raise ValueError(f"Invalid path components: {str(e)}")
|
||||
raise ValueError(f"Failed to process paths: {str(e)}")
|
||||
|
||||
|
||||
def validate_path_exists(path: Union[str, Path], file_type: str = "file") -> str:
|
||||
"""
|
||||
Validate that a path exists and is of the expected type.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
path : Union[str, Path]
|
||||
Path to validate.
|
||||
file_type : str, optional
|
||||
Expected type ('file' or 'directory'), by default 'file'.
|
||||
|
||||
Returns
|
||||
-------
|
||||
str
|
||||
Validated path as string.
|
||||
|
||||
Raises
|
||||
------
|
||||
ValueError
|
||||
If path doesn't exist or is not of expected type.
|
||||
def normalize_path(path: Union[str, Path]) -> str:
|
||||
"""Normalize a path by resolving symlinks and removing redundant separators.
|
||||
|
||||
Args:
|
||||
path: Path to normalize
|
||||
|
||||
Returns:
|
||||
str: Normalized absolute path
|
||||
|
||||
Raises:
|
||||
TypeError: If path is not a string or Path object
|
||||
OSError: If path resolution fails
|
||||
"""
|
||||
if not isinstance(path, (str, Path)):
|
||||
raise TypeError("path must be a string or Path object")
|
||||
|
||||
try:
|
||||
path_obj = Path(path).resolve()
|
||||
return str(Path(path).resolve(strict=True))
|
||||
except OSError as e:
|
||||
raise OSError(f"Failed to normalize path: {str(e)}")
|
||||
|
||||
|
||||
def validate_path_components(components: List[str]) -> None:
|
||||
"""Validate path components for potentially dangerous patterns.
|
||||
|
||||
Args:
|
||||
components: List of path components to validate
|
||||
|
||||
if not path_obj.exists():
|
||||
raise ValueError(f"Path does not exist: {path}")
|
||||
|
||||
if file_type == "file" and not path_obj.is_file():
|
||||
raise ValueError(f"Path is not a file: {path}")
|
||||
elif file_type == "directory" and not path_obj.is_dir():
|
||||
raise ValueError(f"Path is not a directory: {path}")
|
||||
|
||||
return str(path_obj)
|
||||
Raises:
|
||||
TypeError: If components is not a list or contains non-string items
|
||||
ValueError: If any component contains dangerous patterns
|
||||
"""
|
||||
if not isinstance(components, list):
|
||||
raise TypeError("components must be a list")
|
||||
|
||||
except Exception as e:
|
||||
if isinstance(e, ValueError):
|
||||
raise
|
||||
raise ValueError(f"Invalid path: {str(e)}")
|
||||
dangerous_patterns = ['..', '~', '*', '?', '|', '>', '<', '$', '&', '`']
|
||||
for component in components:
|
||||
if not isinstance(component, str):
|
||||
raise TypeError(f"Path component '{component}' must be a string")
|
||||
if any(pattern in component for pattern in dangerous_patterns):
|
||||
raise ValueError(f"Invalid path component '{component}': Contains dangerous pattern")
|
||||
|
||||
|
||||
def list_files(directory: Union[str, Path], pattern: str = "*") -> List[str]:
|
||||
"""
|
||||
Safely list files in a directory matching a pattern.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
directory : Union[str, Path]
|
||||
Directory to search in.
|
||||
pattern : str, optional
|
||||
Glob pattern to match files against, by default "*".
|
||||
|
||||
Returns
|
||||
-------
|
||||
List[str]
|
||||
List of matching file paths.
|
||||
|
||||
Raises
|
||||
------
|
||||
ValueError
|
||||
If directory is invalid or inaccessible.
|
||||
def validate_file_path(path: Union[str, Path], must_exist: bool = True) -> str:
|
||||
"""Validate a file path for security and existence.
|
||||
|
||||
Args:
|
||||
path: File path to validate
|
||||
must_exist: Whether the file must exist (default: True)
|
||||
|
||||
Returns:
|
||||
str: Validated absolute path
|
||||
|
||||
Raises:
|
||||
ValueError: If path is invalid or file doesn't exist when required
|
||||
TypeError: If path is not a string or Path object
|
||||
"""
|
||||
if not isinstance(path, (str, Path)):
|
||||
raise TypeError("path must be a string or Path object")
|
||||
|
||||
try:
|
||||
dir_path = Path(directory).resolve()
|
||||
if not dir_path.is_dir():
|
||||
raise ValueError(f"Not a directory: {directory}")
|
||||
|
||||
return [str(p) for p in dir_path.glob(pattern) if p.is_file()]
|
||||
resolved_path = Path(path).resolve()
|
||||
|
||||
if must_exist and not resolved_path.is_file():
|
||||
raise ValueError(f"File not found: {path}")
|
||||
|
||||
return str(resolved_path)
|
||||
except Exception as e:
|
||||
if isinstance(e, ValueError):
|
||||
raise
|
||||
raise ValueError(f"Error listing files: {str(e)}")
|
||||
raise ValueError(f"Invalid file path {path}: {str(e)}")
|
||||
|
||||
@@ -1,362 +1,35 @@
|
||||
"""
|
||||
Utility functions for flow visualization and dependency analysis.
|
||||
"""General utility functions for flow execution.
|
||||
|
||||
This module provides core functionality for analyzing and manipulating flow structures,
|
||||
including node level calculation, ancestor tracking, and return value analysis.
|
||||
Functions in this module are primarily used by the visualization system to create
|
||||
accurate and informative flow diagrams.
|
||||
This module has been deprecated. All functionality has been moved to:
|
||||
- core_flow_utils.py: Core flow execution utilities
|
||||
- flow_visual_utils.py: Visualization-related utilities
|
||||
|
||||
Example
|
||||
-------
|
||||
>>> flow = Flow()
|
||||
>>> node_levels = calculate_node_levels(flow)
|
||||
>>> ancestors = build_ancestor_dict(flow)
|
||||
This module is kept as a temporary redirect to maintain backwards compatibility.
|
||||
New code should import from the appropriate new modules directly.
|
||||
"""
|
||||
|
||||
import ast
|
||||
import inspect
|
||||
import textwrap
|
||||
from typing import Any, Dict, List, Optional, Set, Union
|
||||
from typing import Any, Dict, List, Optional, Set
|
||||
|
||||
from .core_flow_utils import get_possible_return_constants, is_ancestor
|
||||
from .flow_visual_utils import (
|
||||
build_ancestor_dict,
|
||||
build_parent_children_dict,
|
||||
calculate_node_levels,
|
||||
count_outgoing_edges,
|
||||
dfs_ancestors,
|
||||
get_child_index,
|
||||
)
|
||||
|
||||
def get_possible_return_constants(function: Any) -> Optional[List[str]]:
|
||||
try:
|
||||
source = inspect.getsource(function)
|
||||
except OSError:
|
||||
# Can't get source code
|
||||
return None
|
||||
except Exception as e:
|
||||
print(f"Error retrieving source code for function {function.__name__}: {e}")
|
||||
return None
|
||||
# Re-export all functions for backwards compatibility
|
||||
__all__ = [
|
||||
'get_possible_return_constants',
|
||||
'calculate_node_levels',
|
||||
'count_outgoing_edges',
|
||||
'build_ancestor_dict',
|
||||
'dfs_ancestors',
|
||||
'is_ancestor',
|
||||
'build_parent_children_dict',
|
||||
'get_child_index',
|
||||
]
|
||||
|
||||
try:
|
||||
# Remove leading indentation
|
||||
source = textwrap.dedent(source)
|
||||
# Parse the source code into an AST
|
||||
code_ast = ast.parse(source)
|
||||
except IndentationError as e:
|
||||
print(f"IndentationError while parsing source code of {function.__name__}: {e}")
|
||||
print(f"Source code:\n{source}")
|
||||
return None
|
||||
except SyntaxError as e:
|
||||
print(f"SyntaxError while parsing source code of {function.__name__}: {e}")
|
||||
print(f"Source code:\n{source}")
|
||||
return None
|
||||
except Exception as e:
|
||||
print(f"Unexpected error while parsing source code of {function.__name__}: {e}")
|
||||
print(f"Source code:\n{source}")
|
||||
return None
|
||||
|
||||
return_values = set()
|
||||
dict_definitions = {}
|
||||
|
||||
class DictionaryAssignmentVisitor(ast.NodeVisitor):
|
||||
def visit_Assign(self, node):
|
||||
# Check if this assignment is assigning a dictionary literal to a variable
|
||||
if isinstance(node.value, ast.Dict) and len(node.targets) == 1:
|
||||
target = node.targets[0]
|
||||
if isinstance(target, ast.Name):
|
||||
var_name = target.id
|
||||
dict_values = []
|
||||
# Extract string values from the dictionary
|
||||
for val in node.value.values:
|
||||
if isinstance(val, ast.Constant) and isinstance(val.value, str):
|
||||
dict_values.append(val.value)
|
||||
# If non-string, skip or just ignore
|
||||
if dict_values:
|
||||
dict_definitions[var_name] = dict_values
|
||||
self.generic_visit(node)
|
||||
|
||||
class ReturnVisitor(ast.NodeVisitor):
|
||||
def visit_Return(self, node):
|
||||
# Direct string return
|
||||
if isinstance(node.value, ast.Constant) and isinstance(
|
||||
node.value.value, str
|
||||
):
|
||||
return_values.add(node.value.value)
|
||||
# Dictionary-based return, like return paths[result]
|
||||
elif isinstance(node.value, ast.Subscript):
|
||||
# Check if we're subscripting a known dictionary variable
|
||||
if isinstance(node.value.value, ast.Name):
|
||||
var_name = node.value.value.id
|
||||
if var_name in dict_definitions:
|
||||
# Add all possible dictionary values
|
||||
for v in dict_definitions[var_name]:
|
||||
return_values.add(v)
|
||||
self.generic_visit(node)
|
||||
|
||||
# First pass: identify dictionary assignments
|
||||
DictionaryAssignmentVisitor().visit(code_ast)
|
||||
# Second pass: identify returns
|
||||
ReturnVisitor().visit(code_ast)
|
||||
|
||||
return list(return_values) if return_values else None
|
||||
|
||||
|
||||
def calculate_node_levels(flow: Any) -> Dict[str, int]:
|
||||
"""
|
||||
Calculate the hierarchical level of each node in the flow.
|
||||
|
||||
Performs a breadth-first traversal of the flow graph to assign levels
|
||||
to nodes, starting with start methods at level 0.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
flow : Any
|
||||
The flow instance containing methods, listeners, and router configurations.
|
||||
|
||||
Returns
|
||||
-------
|
||||
Dict[str, int]
|
||||
Dictionary mapping method names to their hierarchical levels.
|
||||
|
||||
Notes
|
||||
-----
|
||||
- Start methods are assigned level 0
|
||||
- Each subsequent connected node is assigned level = parent_level + 1
|
||||
- Handles both OR and AND conditions for listeners
|
||||
- Processes router paths separately
|
||||
"""
|
||||
levels: Dict[str, int] = {}
|
||||
queue: List[str] = []
|
||||
visited: Set[str] = set()
|
||||
pending_and_listeners: Dict[str, Set[str]] = {}
|
||||
|
||||
# Make all start methods at level 0
|
||||
for method_name, method in flow._methods.items():
|
||||
if hasattr(method, "__is_start_method__"):
|
||||
levels[method_name] = 0
|
||||
queue.append(method_name)
|
||||
|
||||
# Breadth-first traversal to assign levels
|
||||
while queue:
|
||||
current = queue.pop(0)
|
||||
current_level = levels[current]
|
||||
visited.add(current)
|
||||
|
||||
for listener_name, (condition_type, trigger_methods) in flow._listeners.items():
|
||||
if condition_type == "OR":
|
||||
if current in trigger_methods:
|
||||
if (
|
||||
listener_name not in levels
|
||||
or levels[listener_name] > current_level + 1
|
||||
):
|
||||
levels[listener_name] = current_level + 1
|
||||
if listener_name not in visited:
|
||||
queue.append(listener_name)
|
||||
elif condition_type == "AND":
|
||||
if listener_name not in pending_and_listeners:
|
||||
pending_and_listeners[listener_name] = set()
|
||||
if current in trigger_methods:
|
||||
pending_and_listeners[listener_name].add(current)
|
||||
if set(trigger_methods) == pending_and_listeners[listener_name]:
|
||||
if (
|
||||
listener_name not in levels
|
||||
or levels[listener_name] > current_level + 1
|
||||
):
|
||||
levels[listener_name] = current_level + 1
|
||||
if listener_name not in visited:
|
||||
queue.append(listener_name)
|
||||
|
||||
# Handle router connections
|
||||
if current in flow._routers:
|
||||
router_method_name = current
|
||||
paths = flow._router_paths.get(router_method_name, [])
|
||||
for path in paths:
|
||||
for listener_name, (
|
||||
condition_type,
|
||||
trigger_methods,
|
||||
) in flow._listeners.items():
|
||||
if path in trigger_methods:
|
||||
if (
|
||||
listener_name not in levels
|
||||
or levels[listener_name] > current_level + 1
|
||||
):
|
||||
levels[listener_name] = current_level + 1
|
||||
if listener_name not in visited:
|
||||
queue.append(listener_name)
|
||||
|
||||
return levels
|
||||
|
||||
|
||||
def count_outgoing_edges(flow: Any) -> Dict[str, int]:
|
||||
"""
|
||||
Count the number of outgoing edges for each method in the flow.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
flow : Any
|
||||
The flow instance to analyze.
|
||||
|
||||
Returns
|
||||
-------
|
||||
Dict[str, int]
|
||||
Dictionary mapping method names to their outgoing edge count.
|
||||
"""
|
||||
counts = {}
|
||||
for method_name in flow._methods:
|
||||
counts[method_name] = 0
|
||||
for method_name in flow._listeners:
|
||||
_, trigger_methods = flow._listeners[method_name]
|
||||
for trigger in trigger_methods:
|
||||
if trigger in flow._methods:
|
||||
counts[trigger] += 1
|
||||
return counts
|
||||
|
||||
|
||||
def build_ancestor_dict(flow: Any) -> Dict[str, Set[str]]:
|
||||
"""
|
||||
Build a dictionary mapping each node to its ancestor nodes.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
flow : Any
|
||||
The flow instance to analyze.
|
||||
|
||||
Returns
|
||||
-------
|
||||
Dict[str, Set[str]]
|
||||
Dictionary mapping each node to a set of its ancestor nodes.
|
||||
"""
|
||||
ancestors: Dict[str, Set[str]] = {node: set() for node in flow._methods}
|
||||
visited: Set[str] = set()
|
||||
for node in flow._methods:
|
||||
if node not in visited:
|
||||
dfs_ancestors(node, ancestors, visited, flow)
|
||||
return ancestors
|
||||
|
||||
|
||||
def dfs_ancestors(
|
||||
node: str,
|
||||
ancestors: Dict[str, Set[str]],
|
||||
visited: Set[str],
|
||||
flow: Any
|
||||
) -> None:
|
||||
"""
|
||||
Perform depth-first search to build ancestor relationships.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
node : str
|
||||
Current node being processed.
|
||||
ancestors : Dict[str, Set[str]]
|
||||
Dictionary tracking ancestor relationships.
|
||||
visited : Set[str]
|
||||
Set of already visited nodes.
|
||||
flow : Any
|
||||
The flow instance being analyzed.
|
||||
|
||||
Notes
|
||||
-----
|
||||
This function modifies the ancestors dictionary in-place to build
|
||||
the complete ancestor graph.
|
||||
"""
|
||||
if node in visited:
|
||||
return
|
||||
visited.add(node)
|
||||
|
||||
# Handle regular listeners
|
||||
for listener_name, (_, trigger_methods) in flow._listeners.items():
|
||||
if node in trigger_methods:
|
||||
ancestors[listener_name].add(node)
|
||||
ancestors[listener_name].update(ancestors[node])
|
||||
dfs_ancestors(listener_name, ancestors, visited, flow)
|
||||
|
||||
# Handle router methods separately
|
||||
if node in flow._routers:
|
||||
router_method_name = node
|
||||
paths = flow._router_paths.get(router_method_name, [])
|
||||
for path in paths:
|
||||
for listener_name, (_, trigger_methods) in flow._listeners.items():
|
||||
if path in trigger_methods:
|
||||
# Only propagate the ancestors of the router method, not the router method itself
|
||||
ancestors[listener_name].update(ancestors[node])
|
||||
dfs_ancestors(listener_name, ancestors, visited, flow)
|
||||
|
||||
|
||||
def is_ancestor(node: str, ancestor_candidate: str, ancestors: Dict[str, Set[str]]) -> bool:
|
||||
"""
|
||||
Check if one node is an ancestor of another.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
node : str
|
||||
The node to check ancestors for.
|
||||
ancestor_candidate : str
|
||||
The potential ancestor node.
|
||||
ancestors : Dict[str, Set[str]]
|
||||
Dictionary containing ancestor relationships.
|
||||
|
||||
Returns
|
||||
-------
|
||||
bool
|
||||
True if ancestor_candidate is an ancestor of node, False otherwise.
|
||||
"""
|
||||
return ancestor_candidate in ancestors.get(node, set())
|
||||
|
||||
|
||||
def build_parent_children_dict(flow: Any) -> Dict[str, List[str]]:
|
||||
"""
|
||||
Build a dictionary mapping parent nodes to their children.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
flow : Any
|
||||
The flow instance to analyze.
|
||||
|
||||
Returns
|
||||
-------
|
||||
Dict[str, List[str]]
|
||||
Dictionary mapping parent method names to lists of their child method names.
|
||||
|
||||
Notes
|
||||
-----
|
||||
- Maps listeners to their trigger methods
|
||||
- Maps router methods to their paths and listeners
|
||||
- Children lists are sorted for consistent ordering
|
||||
"""
|
||||
parent_children: Dict[str, List[str]] = {}
|
||||
|
||||
# Map listeners to their trigger methods
|
||||
for listener_name, (_, trigger_methods) in flow._listeners.items():
|
||||
for trigger in trigger_methods:
|
||||
if trigger not in parent_children:
|
||||
parent_children[trigger] = []
|
||||
if listener_name not in parent_children[trigger]:
|
||||
parent_children[trigger].append(listener_name)
|
||||
|
||||
# Map router methods to their paths and to listeners
|
||||
for router_method_name, paths in flow._router_paths.items():
|
||||
for path in paths:
|
||||
# Map router method to listeners of each path
|
||||
for listener_name, (_, trigger_methods) in flow._listeners.items():
|
||||
if path in trigger_methods:
|
||||
if router_method_name not in parent_children:
|
||||
parent_children[router_method_name] = []
|
||||
if listener_name not in parent_children[router_method_name]:
|
||||
parent_children[router_method_name].append(listener_name)
|
||||
|
||||
return parent_children
|
||||
|
||||
|
||||
def get_child_index(parent: str, child: str, parent_children: Dict[str, List[str]]) -> int:
|
||||
"""
|
||||
Get the index of a child node in its parent's sorted children list.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
parent : str
|
||||
The parent node name.
|
||||
child : str
|
||||
The child node name to find the index for.
|
||||
parent_children : Dict[str, List[str]]
|
||||
Dictionary mapping parents to their children lists.
|
||||
|
||||
Returns
|
||||
-------
|
||||
int
|
||||
Zero-based index of the child in its parent's sorted children list.
|
||||
"""
|
||||
children = parent_children.get(parent, [])
|
||||
children.sort()
|
||||
return children.index(child)
|
||||
# Function implementations have been moved to core_flow_utils.py and flow_visual_utils.py
|
||||
|
||||
@@ -1,61 +1,63 @@
|
||||
"""
|
||||
Utilities for creating visual representations of flow structures.
|
||||
|
||||
This module provides functions for generating network visualizations of flows,
|
||||
including node placement, edge creation, and visual styling. It handles the
|
||||
conversion of flow structures into visual network graphs with appropriate
|
||||
styling and layout.
|
||||
|
||||
Example
|
||||
-------
|
||||
>>> flow = Flow()
|
||||
>>> net = Network(directed=True)
|
||||
>>> node_positions = compute_positions(flow, node_levels)
|
||||
>>> add_nodes_to_network(net, flow, node_positions, node_styles)
|
||||
>>> add_edges(net, flow, node_positions, colors)
|
||||
"""
|
||||
|
||||
import ast
|
||||
import inspect
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, cast
|
||||
|
||||
from .utils import (
|
||||
from crewai.flow.flow import Flow
|
||||
from pyvis.network import Network
|
||||
|
||||
from .core_flow_utils import is_ancestor
|
||||
from .flow_visual_utils import (
|
||||
build_ancestor_dict,
|
||||
build_parent_children_dict,
|
||||
get_child_index,
|
||||
is_ancestor,
|
||||
)
|
||||
from .path_utils import safe_path_join, validate_file_path
|
||||
|
||||
|
||||
def method_calls_crew(method: Any) -> bool:
|
||||
"""
|
||||
Check if the method contains a call to `.crew()`.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
method : Any
|
||||
The method to analyze for crew() calls.
|
||||
|
||||
Returns
|
||||
-------
|
||||
bool
|
||||
True if the method calls .crew(), False otherwise.
|
||||
|
||||
Notes
|
||||
-----
|
||||
Uses AST analysis to detect method calls, specifically looking for
|
||||
attribute access of 'crew'.
|
||||
def method_calls_crew(method: Optional[Callable[..., Any]]) -> bool:
|
||||
"""Check if the method contains a .crew() call in its implementation.
|
||||
|
||||
Analyzes the method's source code using AST to detect if it makes any
|
||||
calls to the .crew() method, which indicates crew involvement in the
|
||||
flow execution.
|
||||
|
||||
Args:
|
||||
method: The method to analyze for crew calls, can be None
|
||||
|
||||
Returns:
|
||||
bool: True if the method contains a .crew() call, False otherwise
|
||||
|
||||
Raises:
|
||||
TypeError: If input is not None and not a callable method
|
||||
ValueError: If method source code cannot be parsed
|
||||
RuntimeError: If unexpected error occurs during parsing
|
||||
"""
|
||||
if method is None:
|
||||
return False
|
||||
if not callable(method):
|
||||
raise TypeError("Input must be a callable method")
|
||||
|
||||
try:
|
||||
source = inspect.getsource(method)
|
||||
source = inspect.cleandoc(source)
|
||||
tree = ast.parse(source)
|
||||
except (TypeError, ValueError, OSError) as e:
|
||||
raise ValueError(f"Could not parse method {getattr(method, '__name__', str(method))}: {e}")
|
||||
except Exception as e:
|
||||
print(f"Could not parse method {method.__name__}: {e}")
|
||||
return False
|
||||
raise RuntimeError(f"Unexpected error parsing method: {e}")
|
||||
|
||||
class CrewCallVisitor(ast.NodeVisitor):
|
||||
"""AST visitor to detect .crew() method calls."""
|
||||
"""AST visitor to detect .crew() method calls in source code.
|
||||
|
||||
A specialized AST visitor that analyzes Python source code to precisely
|
||||
identify calls to the .crew() method, enabling accurate detection of
|
||||
crew involvement in flow methods.
|
||||
|
||||
Attributes:
|
||||
found (bool): Indicates whether a .crew() call was found
|
||||
"""
|
||||
def __init__(self):
|
||||
self.found = False
|
||||
|
||||
@@ -70,35 +72,64 @@ def method_calls_crew(method: Any) -> bool:
|
||||
return visitor.found
|
||||
|
||||
|
||||
def add_nodes_to_network(
|
||||
net: Any,
|
||||
flow: Any,
|
||||
node_positions: Dict[str, Tuple[float, float]],
|
||||
node_styles: Dict[str, Dict[str, Any]]
|
||||
) -> None:
|
||||
def add_nodes_to_network(net: Network, flow: Flow[Any],
|
||||
node_positions: Dict[str, Tuple[float, float]],
|
||||
node_styles: Dict[str, dict],
|
||||
output_dir: Optional[str] = None) -> None:
|
||||
"""Add nodes to the network visualization with precise styling and positioning.
|
||||
|
||||
Creates and styles nodes in the visualization network based on their type
|
||||
(start, router, crew, or regular method) with fine-grained control over
|
||||
appearance and positioning.
|
||||
|
||||
Args:
|
||||
net: The network visualization object to add nodes to
|
||||
flow: Flow object containing method definitions and relationships
|
||||
node_positions: Dictionary mapping method names to (x,y) coordinates
|
||||
node_styles: Dictionary mapping node types to their visual styles
|
||||
output_dir: Optional directory path for saving visualization assets
|
||||
|
||||
Returns:
|
||||
None
|
||||
|
||||
Raises:
|
||||
ValueError: If flow object is invalid or required styles are missing
|
||||
TypeError: If input arguments have incorrect types
|
||||
OSError: If output directory operations fail
|
||||
|
||||
Note:
|
||||
Node styles are applied with precise control over shape, font, color,
|
||||
and positioning to ensure accurate visual representation of the flow.
|
||||
If output_dir is provided, it will be validated and created if needed.
|
||||
"""
|
||||
Add nodes to the network visualization with appropriate styling.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
net : Any
|
||||
The pyvis Network instance to add nodes to.
|
||||
flow : Any
|
||||
The flow instance containing method information.
|
||||
node_positions : Dict[str, Tuple[float, float]]
|
||||
Dictionary mapping node names to their (x, y) positions.
|
||||
node_styles : Dict[str, Dict[str, Any]]
|
||||
Dictionary containing style configurations for different node types.
|
||||
|
||||
Notes
|
||||
-----
|
||||
Node types include:
|
||||
- Start methods
|
||||
- Router methods
|
||||
- Crew methods
|
||||
- Regular methods
|
||||
"""
|
||||
def human_friendly_label(method_name):
|
||||
if not hasattr(flow, '_methods'):
|
||||
raise ValueError("Invalid flow object: missing '_methods' attribute")
|
||||
if not isinstance(node_positions, dict):
|
||||
raise TypeError("node_positions must be a dictionary")
|
||||
if not isinstance(node_styles, dict):
|
||||
raise TypeError("node_styles must be a dictionary")
|
||||
|
||||
required_styles = {'start', 'router', 'crew', 'method'}
|
||||
missing_styles = required_styles - set(node_styles.keys())
|
||||
if missing_styles:
|
||||
raise ValueError(f"Missing required node styles: {missing_styles}")
|
||||
|
||||
# Validate and create output directory if specified
|
||||
if output_dir:
|
||||
try:
|
||||
output_dir = validate_file_path(output_dir, must_exist=False)
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
except (ValueError, OSError) as e:
|
||||
raise OSError(f"Failed to create or validate output directory: {e}")
|
||||
def human_friendly_label(method_name: str) -> str:
|
||||
"""Convert method name to human-readable format.
|
||||
|
||||
Args:
|
||||
method_name: Original method name with underscores
|
||||
|
||||
Returns:
|
||||
str: Formatted method name with spaces and title case
|
||||
"""
|
||||
return method_name.replace("_", " ").title()
|
||||
|
||||
for method_name, (x, y) in node_positions.items():
|
||||
@@ -115,6 +146,15 @@ def add_nodes_to_network(
|
||||
node_style = node_style.copy()
|
||||
label = human_friendly_label(method_name)
|
||||
|
||||
# Handle file-based assets if output directory is provided
|
||||
if output_dir and node_style.get("image"):
|
||||
try:
|
||||
image_path = node_style["image"]
|
||||
safe_image_path = safe_path_join(output_dir, Path(image_path).name)
|
||||
node_style["image"] = str(safe_image_path)
|
||||
except (ValueError, OSError) as e:
|
||||
raise OSError(f"Failed to process node image path: {e}")
|
||||
|
||||
node_style.update(
|
||||
{
|
||||
"label": label,
|
||||
@@ -136,31 +176,39 @@ def add_nodes_to_network(
|
||||
)
|
||||
|
||||
|
||||
def compute_positions(
|
||||
flow: Any,
|
||||
node_levels: Dict[str, int],
|
||||
y_spacing: float = 150,
|
||||
x_spacing: float = 150
|
||||
) -> Dict[str, Tuple[float, float]]:
|
||||
"""
|
||||
Compute the (x, y) positions for each node in the flow graph.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
flow : Any
|
||||
The flow instance to compute positions for.
|
||||
node_levels : Dict[str, int]
|
||||
Dictionary mapping node names to their hierarchical levels.
|
||||
y_spacing : float, optional
|
||||
Vertical spacing between levels, by default 150.
|
||||
x_spacing : float, optional
|
||||
Horizontal spacing between nodes, by default 150.
|
||||
|
||||
Returns
|
||||
-------
|
||||
Dict[str, Tuple[float, float]]
|
||||
Dictionary mapping node names to their (x, y) coordinates.
|
||||
def compute_positions(flow: Flow[Any], node_levels: Dict[str, int],
|
||||
y_spacing: float = 150, x_spacing: float = 150) -> Dict[str, Tuple[float, float]]:
|
||||
"""Calculate precise x,y coordinates for each node in the flow diagram.
|
||||
|
||||
Computes optimal node positions with fine-grained control over spacing
|
||||
and alignment, ensuring clear visualization of flow hierarchy and
|
||||
relationships.
|
||||
|
||||
Args:
|
||||
flow: Flow object containing method definitions
|
||||
node_levels: Dictionary mapping method names to their hierarchy levels
|
||||
y_spacing: Vertical spacing between hierarchy levels (default: 150)
|
||||
x_spacing: Horizontal spacing between nodes at same level (default: 150)
|
||||
|
||||
Returns:
|
||||
dict[str, tuple[float, float]]: Dictionary mapping method names to
|
||||
their calculated (x,y) coordinates in the visualization
|
||||
|
||||
Note:
|
||||
Positions are calculated to maintain clear hierarchical structure while
|
||||
ensuring optimal spacing and readability of the flow diagram.
|
||||
"""
|
||||
if not hasattr(flow, '_methods'):
|
||||
raise ValueError("Invalid flow object: missing '_methods' attribute")
|
||||
if not isinstance(node_levels, dict):
|
||||
raise TypeError("node_levels must be a dictionary")
|
||||
if not isinstance(y_spacing, (int, float)) or y_spacing <= 0:
|
||||
raise ValueError("y_spacing must be a positive number")
|
||||
if not isinstance(x_spacing, (int, float)) or x_spacing <= 0:
|
||||
raise ValueError("x_spacing must be a positive number")
|
||||
|
||||
if not node_levels:
|
||||
raise ValueError("node_levels dictionary cannot be empty")
|
||||
level_nodes: Dict[int, List[str]] = {}
|
||||
node_positions: Dict[str, Tuple[float, float]] = {}
|
||||
|
||||
@@ -177,33 +225,34 @@ def compute_positions(
|
||||
return node_positions
|
||||
|
||||
|
||||
def add_edges(
|
||||
net: Any,
|
||||
flow: Any,
|
||||
node_positions: Dict[str, Tuple[float, float]],
|
||||
colors: Dict[str, str]
|
||||
) -> None:
|
||||
edge_smooth: Dict[str, Union[str, float]] = {"type": "continuous"} # Default value
|
||||
"""
|
||||
Add edges to the network visualization with appropriate styling.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
net : Any
|
||||
The pyvis Network instance to add edges to.
|
||||
flow : Any
|
||||
The flow instance containing edge information.
|
||||
node_positions : Dict[str, Tuple[float, float]]
|
||||
Dictionary mapping node names to their positions.
|
||||
colors : Dict[str, str]
|
||||
Dictionary mapping edge types to their colors.
|
||||
|
||||
Notes
|
||||
-----
|
||||
- Handles both normal listener edges and router edges
|
||||
- Applies appropriate styling (color, dashes) based on edge type
|
||||
- Adds curvature to edges when needed (cycles or multiple children)
|
||||
"""
|
||||
def add_edges(net: Network, flow: Flow[Any],
|
||||
node_positions: Dict[str, Tuple[float, float]],
|
||||
colors: Dict[str, str],
|
||||
asset_dir: Optional[str] = None) -> None:
|
||||
if not hasattr(flow, '_methods'):
|
||||
raise ValueError("Invalid flow object: missing '_methods' attribute")
|
||||
if not hasattr(flow, '_listeners'):
|
||||
raise ValueError("Invalid flow object: missing '_listeners' attribute")
|
||||
if not hasattr(flow, '_router_paths'):
|
||||
raise ValueError("Invalid flow object: missing '_router_paths' attribute")
|
||||
|
||||
if not isinstance(node_positions, dict):
|
||||
raise TypeError("node_positions must be a dictionary")
|
||||
if not isinstance(colors, dict):
|
||||
raise TypeError("colors must be a dictionary")
|
||||
|
||||
required_colors = {'edge', 'router_edge'}
|
||||
missing_colors = required_colors - set(colors.keys())
|
||||
if missing_colors:
|
||||
raise ValueError(f"Missing required edge colors: {missing_colors}")
|
||||
|
||||
# Validate asset directory if provided
|
||||
if asset_dir:
|
||||
try:
|
||||
asset_dir = validate_file_path(asset_dir, must_exist=False)
|
||||
os.makedirs(asset_dir, exist_ok=True)
|
||||
except (ValueError, OSError) as e:
|
||||
raise OSError(f"Failed to create or validate asset directory: {e}")
|
||||
ancestors = build_ancestor_dict(flow)
|
||||
parent_children = build_parent_children_dict(flow)
|
||||
|
||||
@@ -232,24 +281,24 @@ def add_edges(
|
||||
dx = target_pos[0] - source_pos[0]
|
||||
smooth_type = "curvedCCW" if dx <= 0 else "curvedCW"
|
||||
index = get_child_index(trigger, method_name, parent_children)
|
||||
edge_smooth = {
|
||||
edge_config = {
|
||||
"type": smooth_type,
|
||||
"roundness": 0.2 + (0.1 * index),
|
||||
}
|
||||
else:
|
||||
edge_smooth = {"type": "cubicBezier"}
|
||||
edge_config = {"type": "cubicBezier"}
|
||||
else:
|
||||
edge_smooth.update({"type": "continuous"})
|
||||
edge_config = {"type": "straight"}
|
||||
|
||||
edge_style = {
|
||||
edge_props: Dict[str, Any] = {
|
||||
"color": edge_color,
|
||||
"width": 2,
|
||||
"arrows": "to",
|
||||
"dashes": True if is_router_edge or is_and_condition else False,
|
||||
"smooth": edge_smooth,
|
||||
"smooth": edge_config,
|
||||
}
|
||||
|
||||
net.add_edge(trigger, method_name, **edge_style)
|
||||
net.add_edge(trigger, method_name, **edge_props)
|
||||
else:
|
||||
# Nodes not found in node_positions. Check if it's a known router outcome and a known method.
|
||||
is_router_edge = any(
|
||||
@@ -295,23 +344,23 @@ def add_edges(
|
||||
index = get_child_index(
|
||||
router_method_name, listener_name, parent_children
|
||||
)
|
||||
edge_smooth = {
|
||||
edge_config = {
|
||||
"type": smooth_type,
|
||||
"roundness": 0.2 + (0.1 * index),
|
||||
}
|
||||
else:
|
||||
edge_smooth = {"type": "cubicBezier"}
|
||||
edge_config = {"type": "cubicBezier"}
|
||||
else:
|
||||
edge_smooth.update({"type": "continuous"})
|
||||
edge_config = {"type": "straight"}
|
||||
|
||||
edge_style = {
|
||||
router_edge_props: Dict[str, Any] = {
|
||||
"color": colors["router_edge"],
|
||||
"width": 2,
|
||||
"arrows": "to",
|
||||
"dashes": True,
|
||||
"smooth": edge_smooth,
|
||||
"smooth": edge_config,
|
||||
}
|
||||
net.add_edge(router_method_name, listener_name, **edge_style)
|
||||
net.add_edge(router_method_name, listener_name, **router_edge_props)
|
||||
else:
|
||||
# Same check here: known router edge and known method?
|
||||
method_known = listener_name in flow._methods
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
import json
|
||||
import logging
|
||||
from typing import Any, List, Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
@@ -7,8 +5,6 @@ from pydantic import BaseModel, Field
|
||||
from crewai.agent import Agent
|
||||
from crewai.task import Task
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class PlanPerTask(BaseModel):
|
||||
task: str = Field(..., description="The task for which the plan is created")
|
||||
@@ -72,39 +68,19 @@ class CrewPlanner:
|
||||
output_pydantic=PlannerTaskPydanticOutput,
|
||||
)
|
||||
|
||||
def _get_agent_knowledge(self, task: Task) -> List[str]:
|
||||
"""
|
||||
Safely retrieve knowledge source content from the task's agent.
|
||||
|
||||
Args:
|
||||
task: The task containing an agent with potential knowledge sources
|
||||
|
||||
Returns:
|
||||
List[str]: A list of knowledge source strings
|
||||
"""
|
||||
try:
|
||||
if task.agent and task.agent.knowledge_sources:
|
||||
return [source.content for source in task.agent.knowledge_sources]
|
||||
except AttributeError:
|
||||
logger.warning("Error accessing agent knowledge sources")
|
||||
return []
|
||||
|
||||
def _create_tasks_summary(self) -> str:
|
||||
"""Creates a summary of all tasks."""
|
||||
tasks_summary = []
|
||||
for idx, task in enumerate(self.tasks):
|
||||
knowledge_list = self._get_agent_knowledge(task)
|
||||
task_summary = f"""
|
||||
tasks_summary.append(
|
||||
f"""
|
||||
Task Number {idx + 1} - {task.description}
|
||||
"task_description": {task.description}
|
||||
"task_expected_output": {task.expected_output}
|
||||
"agent": {task.agent.role if task.agent else "None"}
|
||||
"agent_goal": {task.agent.goal if task.agent else "None"}
|
||||
"task_tools": {task.tools}
|
||||
"agent_tools": %s%s""" % (
|
||||
f"[{', '.join(str(tool) for tool in task.agent.tools)}]" if task.agent and task.agent.tools else '"agent has no tools"',
|
||||
f',\n "agent_knowledge": "[\\"{knowledge_list[0]}\\"]"' if knowledge_list and str(knowledge_list) != "None" else ""
|
||||
)
|
||||
|
||||
tasks_summary.append(task_summary)
|
||||
"agent_tools": {task.agent.tools if task.agent else "None"}
|
||||
"""
|
||||
)
|
||||
return " ".join(tasks_summary)
|
||||
|
||||
@@ -1,84 +0,0 @@
|
||||
"""
|
||||
Tests for verifying the integration of knowledge sources in the planning process.
|
||||
This module ensures that agent knowledge is properly included during task planning.
|
||||
"""
|
||||
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from crewai.agent import Agent
|
||||
from crewai.knowledge.source.string_knowledge_source import StringKnowledgeSource
|
||||
from crewai.task import Task
|
||||
from crewai.utilities.planning_handler import CrewPlanner
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_knowledge_source():
|
||||
"""
|
||||
Create a mock knowledge source with test content.
|
||||
Returns:
|
||||
StringKnowledgeSource:
|
||||
A knowledge source containing AI-related test content
|
||||
"""
|
||||
content = """
|
||||
Important context about AI:
|
||||
1. AI systems use machine learning algorithms
|
||||
2. Neural networks are a key component
|
||||
3. Training data is essential for good performance
|
||||
"""
|
||||
return StringKnowledgeSource(content=content)
|
||||
|
||||
@patch('crewai.knowledge.storage.knowledge_storage.chromadb')
|
||||
def test_knowledge_included_in_planning(mock_chroma):
|
||||
"""Test that verifies knowledge sources are properly included in planning."""
|
||||
# Mock ChromaDB collection
|
||||
mock_collection = mock_chroma.return_value.get_or_create_collection.return_value
|
||||
mock_collection.add.return_value = None
|
||||
|
||||
# Create an agent with knowledge
|
||||
agent = Agent(
|
||||
role="AI Researcher",
|
||||
goal="Research and explain AI concepts",
|
||||
backstory="Expert in artificial intelligence",
|
||||
knowledge_sources=[
|
||||
StringKnowledgeSource(
|
||||
content="AI systems require careful training and validation."
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
# Create a task for the agent
|
||||
task = Task(
|
||||
description="Explain the basics of AI systems",
|
||||
expected_output="A clear explanation of AI fundamentals",
|
||||
agent=agent
|
||||
)
|
||||
|
||||
# Create a crew planner
|
||||
planner = CrewPlanner([task], None)
|
||||
|
||||
# Get the task summary
|
||||
task_summary = planner._create_tasks_summary()
|
||||
|
||||
# Verify that knowledge is included in planning when present
|
||||
assert "AI systems require careful training" in task_summary, \
|
||||
"Knowledge content should be present in task summary when knowledge exists"
|
||||
assert '"agent_knowledge"' in task_summary, \
|
||||
"agent_knowledge field should be present in task summary when knowledge exists"
|
||||
|
||||
# Verify that knowledge is properly formatted
|
||||
assert isinstance(task.agent.knowledge_sources, list), \
|
||||
"Knowledge sources should be stored in a list"
|
||||
assert len(task.agent.knowledge_sources) > 0, \
|
||||
"At least one knowledge source should be present"
|
||||
assert task.agent.knowledge_sources[0].content in task_summary, \
|
||||
"Knowledge source content should be included in task summary"
|
||||
|
||||
# Verify that other expected components are still present
|
||||
assert task.description in task_summary, \
|
||||
"Task description should be present in task summary"
|
||||
assert task.expected_output in task_summary, \
|
||||
"Expected output should be present in task summary"
|
||||
assert agent.role in task_summary, \
|
||||
"Agent role should be present in task summary"
|
||||
@@ -1,14 +1,10 @@
|
||||
from typing import Optional
|
||||
from unittest.mock import MagicMock, patch
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from pydantic import BaseModel
|
||||
|
||||
from crewai.agent import Agent
|
||||
from crewai.knowledge.source.string_knowledge_source import StringKnowledgeSource
|
||||
from crewai.task import Task
|
||||
from crewai.tasks.task_output import TaskOutput
|
||||
from crewai.tools.base_tool import BaseTool
|
||||
from crewai.utilities.planning_handler import (
|
||||
CrewPlanner,
|
||||
PlannerTaskPydanticOutput,
|
||||
@@ -96,72 +92,7 @@ class TestCrewPlanner:
|
||||
tasks_summary = crew_planner._create_tasks_summary()
|
||||
assert isinstance(tasks_summary, str)
|
||||
assert tasks_summary.startswith("\n Task Number 1 - Task 1")
|
||||
assert '"agent_tools": "agent has no tools"' in tasks_summary
|
||||
# Knowledge field should not be present when empty
|
||||
assert '"agent_knowledge"' not in tasks_summary
|
||||
|
||||
@patch('crewai.knowledge.storage.knowledge_storage.chromadb')
|
||||
def test_create_tasks_summary_with_knowledge_and_tools(self, mock_chroma):
|
||||
"""Test task summary generation with both knowledge and tools present."""
|
||||
# Mock ChromaDB collection
|
||||
mock_collection = mock_chroma.return_value.get_or_create_collection.return_value
|
||||
mock_collection.add.return_value = None
|
||||
|
||||
# Create mock tools with proper string descriptions and structured tool support
|
||||
class MockTool(BaseTool):
|
||||
name: str
|
||||
description: str
|
||||
|
||||
def __init__(self, name: str, description: str):
|
||||
tool_data = {"name": name, "description": description}
|
||||
super().__init__(**tool_data)
|
||||
|
||||
def __str__(self):
|
||||
return self.name
|
||||
|
||||
def __repr__(self):
|
||||
return self.name
|
||||
|
||||
def to_structured_tool(self):
|
||||
return self
|
||||
|
||||
def _run(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
def _generate_description(self) -> str:
|
||||
"""Override _generate_description to avoid args_schema handling."""
|
||||
return self.description
|
||||
|
||||
tool1 = MockTool("tool1", "Tool 1 description")
|
||||
tool2 = MockTool("tool2", "Tool 2 description")
|
||||
|
||||
# Create a task with knowledge and tools
|
||||
task = Task(
|
||||
description="Task with knowledge and tools",
|
||||
expected_output="Expected output",
|
||||
agent=Agent(
|
||||
role="Test Agent",
|
||||
goal="Test Goal",
|
||||
backstory="Test Backstory",
|
||||
tools=[tool1, tool2],
|
||||
knowledge_sources=[
|
||||
StringKnowledgeSource(content="Test knowledge content")
|
||||
]
|
||||
)
|
||||
)
|
||||
|
||||
# Create planner with the new task
|
||||
planner = CrewPlanner([task], None)
|
||||
tasks_summary = planner._create_tasks_summary()
|
||||
|
||||
# Verify task summary content
|
||||
assert isinstance(tasks_summary, str)
|
||||
assert task.description in tasks_summary
|
||||
assert task.expected_output in tasks_summary
|
||||
assert '"agent_tools": [tool1, tool2]' in tasks_summary
|
||||
assert '"agent_knowledge": "[\\"Test knowledge content\\"]"' in tasks_summary
|
||||
assert task.agent.role in tasks_summary
|
||||
assert task.agent.goal in tasks_summary
|
||||
assert tasks_summary.endswith('"agent_tools": []\n ')
|
||||
|
||||
def test_handle_crew_planning_different_llm(self, crew_planner_different_llm):
|
||||
with patch.object(Task, "execute_sync") as execute:
|
||||
|
||||
68
uv.lock
generated
68
uv.lock
generated
@@ -1,10 +1,18 @@
|
||||
version = 1
|
||||
requires-python = ">=3.10, <3.13"
|
||||
resolution-markers = [
|
||||
"python_full_version < '3.11'",
|
||||
"python_full_version == '3.11.*'",
|
||||
"python_full_version >= '3.12' and python_full_version < '3.12.4'",
|
||||
"python_full_version >= '3.12.4'",
|
||||
"python_full_version < '3.11' and sys_platform == 'darwin'",
|
||||
"python_full_version < '3.11' and platform_machine == 'aarch64' and sys_platform == 'linux'",
|
||||
"(python_full_version < '3.11' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version < '3.11' and sys_platform != 'darwin' and sys_platform != 'linux')",
|
||||
"python_full_version == '3.11.*' and sys_platform == 'darwin'",
|
||||
"python_full_version == '3.11.*' and platform_machine == 'aarch64' and sys_platform == 'linux'",
|
||||
"(python_full_version == '3.11.*' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version == '3.11.*' and sys_platform != 'darwin' and sys_platform != 'linux')",
|
||||
"python_full_version >= '3.12' and python_full_version < '3.12.4' and sys_platform == 'darwin'",
|
||||
"python_full_version >= '3.12' and python_full_version < '3.12.4' and platform_machine == 'aarch64' and sys_platform == 'linux'",
|
||||
"(python_full_version >= '3.12' and python_full_version < '3.12.4' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version >= '3.12' and python_full_version < '3.12.4' and sys_platform != 'darwin' and sys_platform != 'linux')",
|
||||
"python_full_version >= '3.12.4' and sys_platform == 'darwin'",
|
||||
"python_full_version >= '3.12.4' and platform_machine == 'aarch64' and sys_platform == 'linux'",
|
||||
"(python_full_version >= '3.12.4' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version >= '3.12.4' and sys_platform != 'darwin' and sys_platform != 'linux')",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -300,7 +308,7 @@ name = "build"
|
||||
version = "1.2.2.post1"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "colorama", marker = "os_name == 'nt'" },
|
||||
{ name = "colorama", marker = "(os_name == 'nt' and platform_machine != 'aarch64' and sys_platform == 'linux') or (os_name == 'nt' and sys_platform != 'darwin' and sys_platform != 'linux')" },
|
||||
{ name = "importlib-metadata", marker = "python_full_version < '3.10.2'" },
|
||||
{ name = "packaging" },
|
||||
{ name = "pyproject-hooks" },
|
||||
@@ -535,7 +543,7 @@ name = "click"
|
||||
version = "8.1.7"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "colorama", marker = "platform_system == 'Windows'" },
|
||||
{ name = "colorama", marker = "sys_platform == 'win32'" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/96/d3/f04c7bfcf5c1862a2a5b845c6b2b360488cf47af55dfa79c98f6a6bf98b5/click-8.1.7.tar.gz", hash = "sha256:ca9853ad459e787e2192211578cc907e7594e294c7ccc834310722b41b9ca6de", size = 336121 }
|
||||
wheels = [
|
||||
@@ -642,7 +650,6 @@ tools = [
|
||||
[package.dev-dependencies]
|
||||
dev = [
|
||||
{ name = "cairosvg" },
|
||||
{ name = "crewai-tools" },
|
||||
{ name = "mkdocs" },
|
||||
{ name = "mkdocs-material" },
|
||||
{ name = "mkdocs-material-extensions" },
|
||||
@@ -696,7 +703,6 @@ requires-dist = [
|
||||
[package.metadata.requires-dev]
|
||||
dev = [
|
||||
{ name = "cairosvg", specifier = ">=2.7.1" },
|
||||
{ name = "crewai-tools", specifier = ">=0.17.0" },
|
||||
{ name = "mkdocs", specifier = ">=1.4.3" },
|
||||
{ name = "mkdocs-material", specifier = ">=9.5.7" },
|
||||
{ name = "mkdocs-material-extensions", specifier = ">=1.3.1" },
|
||||
@@ -2462,7 +2468,7 @@ version = "1.6.1"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "click" },
|
||||
{ name = "colorama", marker = "platform_system == 'Windows'" },
|
||||
{ name = "colorama", marker = "sys_platform == 'win32'" },
|
||||
{ name = "ghp-import" },
|
||||
{ name = "jinja2" },
|
||||
{ name = "markdown" },
|
||||
@@ -2643,7 +2649,7 @@ version = "2.10.2"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "pygments" },
|
||||
{ name = "pywin32", marker = "platform_system == 'Windows'" },
|
||||
{ name = "pywin32", marker = "sys_platform == 'win32'" },
|
||||
{ name = "tqdm" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/3a/93/80ac75c20ce54c785648b4ed363c88f148bf22637e10c9863db4fbe73e74/mpire-2.10.2.tar.gz", hash = "sha256:f66a321e93fadff34585a4bfa05e95bd946cf714b442f51c529038eb45773d97", size = 271270 }
|
||||
@@ -2890,7 +2896,7 @@ name = "nvidia-cudnn-cu12"
|
||||
version = "9.1.0.70"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "nvidia-cublas-cu12", marker = "(platform_machine != 'aarch64' and platform_system != 'Darwin') or (platform_system != 'Darwin' and platform_system != 'Linux')" },
|
||||
{ name = "nvidia-cublas-cu12", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" },
|
||||
]
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/9f/fd/713452cd72343f682b1c7b9321e23829f00b842ceaedcda96e742ea0b0b3/nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl", hash = "sha256:165764f44ef8c61fcdfdfdbe769d687e06374059fbb388b6c89ecb0e28793a6f", size = 664752741 },
|
||||
@@ -2917,9 +2923,9 @@ name = "nvidia-cusolver-cu12"
|
||||
version = "11.4.5.107"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "nvidia-cublas-cu12", marker = "(platform_machine != 'aarch64' and platform_system != 'Darwin') or (platform_system != 'Darwin' and platform_system != 'Linux')" },
|
||||
{ name = "nvidia-cusparse-cu12", marker = "(platform_machine != 'aarch64' and platform_system != 'Darwin') or (platform_system != 'Darwin' and platform_system != 'Linux')" },
|
||||
{ name = "nvidia-nvjitlink-cu12", marker = "(platform_machine != 'aarch64' and platform_system != 'Darwin') or (platform_system != 'Darwin' and platform_system != 'Linux')" },
|
||||
{ name = "nvidia-cublas-cu12", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" },
|
||||
{ name = "nvidia-cusparse-cu12", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" },
|
||||
{ name = "nvidia-nvjitlink-cu12", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" },
|
||||
]
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/bc/1d/8de1e5c67099015c834315e333911273a8c6aaba78923dd1d1e25fc5f217/nvidia_cusolver_cu12-11.4.5.107-py3-none-manylinux1_x86_64.whl", hash = "sha256:8a7ec542f0412294b15072fa7dab71d31334014a69f953004ea7a118206fe0dd", size = 124161928 },
|
||||
@@ -2930,7 +2936,7 @@ name = "nvidia-cusparse-cu12"
|
||||
version = "12.1.0.106"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "nvidia-nvjitlink-cu12", marker = "(platform_machine != 'aarch64' and platform_system != 'Darwin') or (platform_system != 'Darwin' and platform_system != 'Linux')" },
|
||||
{ name = "nvidia-nvjitlink-cu12", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" },
|
||||
]
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/65/5b/cfaeebf25cd9fdec14338ccb16f6b2c4c7fa9163aefcf057d86b9cc248bb/nvidia_cusparse_cu12-12.1.0.106-py3-none-manylinux1_x86_64.whl", hash = "sha256:f3b50f42cf363f86ab21f720998517a659a48131e8d538dc02f8768237bd884c", size = 195958278 },
|
||||
@@ -3480,7 +3486,7 @@ name = "portalocker"
|
||||
version = "2.10.1"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "pywin32", marker = "platform_system == 'Windows'" },
|
||||
{ name = "pywin32", marker = "sys_platform == 'win32'" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/ed/d3/c6c64067759e87af98cc668c1cc75171347d0f1577fab7ca3749134e3cd4/portalocker-2.10.1.tar.gz", hash = "sha256:ef1bf844e878ab08aee7e40184156e1151f228f103aa5c6bd0724cc330960f8f", size = 40891 }
|
||||
wheels = [
|
||||
@@ -5022,19 +5028,19 @@ dependencies = [
|
||||
{ name = "fsspec" },
|
||||
{ name = "jinja2" },
|
||||
{ name = "networkx" },
|
||||
{ name = "nvidia-cublas-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" },
|
||||
{ name = "nvidia-cuda-cupti-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" },
|
||||
{ name = "nvidia-cuda-nvrtc-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" },
|
||||
{ name = "nvidia-cuda-runtime-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" },
|
||||
{ name = "nvidia-cudnn-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" },
|
||||
{ name = "nvidia-cufft-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" },
|
||||
{ name = "nvidia-curand-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" },
|
||||
{ name = "nvidia-cusolver-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" },
|
||||
{ name = "nvidia-cusparse-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" },
|
||||
{ name = "nvidia-nccl-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" },
|
||||
{ name = "nvidia-nvtx-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" },
|
||||
{ name = "nvidia-cublas-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
|
||||
{ name = "nvidia-cuda-cupti-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
|
||||
{ name = "nvidia-cuda-nvrtc-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
|
||||
{ name = "nvidia-cuda-runtime-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
|
||||
{ name = "nvidia-cudnn-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
|
||||
{ name = "nvidia-cufft-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
|
||||
{ name = "nvidia-curand-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
|
||||
{ name = "nvidia-cusolver-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
|
||||
{ name = "nvidia-cusparse-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
|
||||
{ name = "nvidia-nccl-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
|
||||
{ name = "nvidia-nvtx-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
|
||||
{ name = "sympy" },
|
||||
{ name = "triton", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" },
|
||||
{ name = "triton", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
|
||||
{ name = "typing-extensions" },
|
||||
]
|
||||
wheels = [
|
||||
@@ -5081,7 +5087,7 @@ name = "tqdm"
|
||||
version = "4.66.5"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "colorama", marker = "platform_system == 'Windows'" },
|
||||
{ name = "colorama", marker = "sys_platform == 'win32'" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/58/83/6ba9844a41128c62e810fddddd72473201f3eacde02046066142a2d96cc5/tqdm-4.66.5.tar.gz", hash = "sha256:e1020aef2e5096702d8a025ac7d16b1577279c9d63f8375b63083e9a5f0fcbad", size = 169504 }
|
||||
wheels = [
|
||||
@@ -5124,7 +5130,7 @@ version = "0.27.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "attrs" },
|
||||
{ name = "cffi", marker = "implementation_name != 'pypy' and os_name == 'nt'" },
|
||||
{ name = "cffi", marker = "(implementation_name != 'pypy' and os_name == 'nt' and platform_machine != 'aarch64' and sys_platform == 'linux') or (implementation_name != 'pypy' and os_name == 'nt' and sys_platform != 'darwin' and sys_platform != 'linux')" },
|
||||
{ name = "exceptiongroup", marker = "python_full_version < '3.11'" },
|
||||
{ name = "idna" },
|
||||
{ name = "outcome" },
|
||||
@@ -5155,7 +5161,7 @@ name = "triton"
|
||||
version = "3.0.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "filelock", marker = "(platform_machine != 'aarch64' and platform_system != 'Darwin') or (platform_system != 'Darwin' and platform_system != 'Linux')" },
|
||||
{ name = "filelock", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" },
|
||||
]
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/45/27/14cc3101409b9b4b9241d2ba7deaa93535a217a211c86c4cc7151fb12181/triton-3.0.0-1-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:e1efef76935b2febc365bfadf74bcb65a6f959a9872e5bddf44cc9e0adce1e1a", size = 209376304 },
|
||||
|
||||
Reference in New Issue
Block a user