mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-09 16:18:30 +00:00
Brandon/new release cleanup (#1918)
* WIP * fixes to match enterprise changes
This commit is contained in:
committed by
GitHub
parent
4a44245de9
commit
627bb3f5f6
@@ -2,26 +2,26 @@ from crewai.types.usage_metrics import UsageMetrics
|
|||||||
|
|
||||||
|
|
||||||
class TokenProcess:
|
class TokenProcess:
|
||||||
def __init__(self):
|
def __init__(self) -> None:
|
||||||
self.total_tokens: int = 0
|
self.total_tokens: int = 0
|
||||||
self.prompt_tokens: int = 0
|
self.prompt_tokens: int = 0
|
||||||
self.cached_prompt_tokens: int = 0
|
self.cached_prompt_tokens: int = 0
|
||||||
self.completion_tokens: int = 0
|
self.completion_tokens: int = 0
|
||||||
self.successful_requests: int = 0
|
self.successful_requests: int = 0
|
||||||
|
|
||||||
def sum_prompt_tokens(self, tokens: int):
|
def sum_prompt_tokens(self, tokens: int) -> None:
|
||||||
self.prompt_tokens = self.prompt_tokens + tokens
|
self.prompt_tokens += tokens
|
||||||
self.total_tokens = self.total_tokens + tokens
|
self.total_tokens += tokens
|
||||||
|
|
||||||
def sum_completion_tokens(self, tokens: int):
|
def sum_completion_tokens(self, tokens: int) -> None:
|
||||||
self.completion_tokens = self.completion_tokens + tokens
|
self.completion_tokens += tokens
|
||||||
self.total_tokens = self.total_tokens + tokens
|
self.total_tokens += tokens
|
||||||
|
|
||||||
def sum_cached_prompt_tokens(self, tokens: int):
|
def sum_cached_prompt_tokens(self, tokens: int) -> None:
|
||||||
self.cached_prompt_tokens = self.cached_prompt_tokens + tokens
|
self.cached_prompt_tokens += tokens
|
||||||
|
|
||||||
def sum_successful_requests(self, requests: int):
|
def sum_successful_requests(self, requests: int) -> None:
|
||||||
self.successful_requests = self.successful_requests + requests
|
self.successful_requests += requests
|
||||||
|
|
||||||
def get_summary(self) -> UsageMetrics:
|
def get_summary(self) -> UsageMetrics:
|
||||||
return UsageMetrics(
|
return UsageMetrics(
|
||||||
|
|||||||
@@ -1,6 +1,5 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import inspect
|
import inspect
|
||||||
import uuid
|
|
||||||
from typing import (
|
from typing import (
|
||||||
Any,
|
Any,
|
||||||
Callable,
|
Callable,
|
||||||
@@ -13,7 +12,6 @@ from typing import (
|
|||||||
TypeVar,
|
TypeVar,
|
||||||
Union,
|
Union,
|
||||||
cast,
|
cast,
|
||||||
overload,
|
|
||||||
)
|
)
|
||||||
from uuid import uuid4
|
from uuid import uuid4
|
||||||
|
|
||||||
@@ -27,7 +25,6 @@ from crewai.flow.flow_events import (
|
|||||||
MethodExecutionStartedEvent,
|
MethodExecutionStartedEvent,
|
||||||
)
|
)
|
||||||
from crewai.flow.flow_visualizer import plot_flow
|
from crewai.flow.flow_visualizer import plot_flow
|
||||||
from crewai.flow.persistence import FlowPersistence
|
|
||||||
from crewai.flow.persistence.base import FlowPersistence
|
from crewai.flow.persistence.base import FlowPersistence
|
||||||
from crewai.flow.utils import get_possible_return_constants
|
from crewai.flow.utils import get_possible_return_constants
|
||||||
from crewai.telemetry import Telemetry
|
from crewai.telemetry import Telemetry
|
||||||
@@ -35,22 +32,32 @@ from crewai.telemetry import Telemetry
|
|||||||
|
|
||||||
class FlowState(BaseModel):
|
class FlowState(BaseModel):
|
||||||
"""Base model for all flow states, ensuring each state has a unique ID."""
|
"""Base model for all flow states, ensuring each state has a unique ID."""
|
||||||
id: str = Field(default_factory=lambda: str(uuid4()), description="Unique identifier for the flow state")
|
|
||||||
|
id: str = Field(
|
||||||
|
default_factory=lambda: str(uuid4()),
|
||||||
|
description="Unique identifier for the flow state",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# Type variables with explicit bounds
|
# Type variables with explicit bounds
|
||||||
T = TypeVar("T", bound=Union[Dict[str, Any], BaseModel]) # Generic flow state type parameter
|
T = TypeVar(
|
||||||
StateT = TypeVar("StateT", bound=Union[Dict[str, Any], BaseModel]) # State validation type parameter
|
"T", bound=Union[Dict[str, Any], BaseModel]
|
||||||
|
) # Generic flow state type parameter
|
||||||
|
StateT = TypeVar(
|
||||||
|
"StateT", bound=Union[Dict[str, Any], BaseModel]
|
||||||
|
) # State validation type parameter
|
||||||
|
|
||||||
|
|
||||||
def ensure_state_type(state: Any, expected_type: Type[StateT]) -> StateT:
|
def ensure_state_type(state: Any, expected_type: Type[StateT]) -> StateT:
|
||||||
"""Ensure state matches expected type with proper validation.
|
"""Ensure state matches expected type with proper validation.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
state: State instance to validate
|
state: State instance to validate
|
||||||
expected_type: Expected type for the state
|
expected_type: Expected type for the state
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Validated state instance
|
Validated state instance
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
TypeError: If state doesn't match expected type
|
TypeError: If state doesn't match expected type
|
||||||
ValueError: If state validation fails
|
ValueError: If state validation fails
|
||||||
@@ -68,13 +75,15 @@ def ensure_state_type(state: Any, expected_type: Type[StateT]) -> StateT:
|
|||||||
TypeError: If state doesn't match expected type
|
TypeError: If state doesn't match expected type
|
||||||
ValueError: If state validation fails
|
ValueError: If state validation fails
|
||||||
"""
|
"""
|
||||||
if expected_type == dict:
|
if expected_type is dict:
|
||||||
if not isinstance(state, dict):
|
if not isinstance(state, dict):
|
||||||
raise TypeError(f"Expected dict, got {type(state).__name__}")
|
raise TypeError(f"Expected dict, got {type(state).__name__}")
|
||||||
return cast(StateT, state)
|
return cast(StateT, state)
|
||||||
if isinstance(expected_type, type) and issubclass(expected_type, BaseModel):
|
if isinstance(expected_type, type) and issubclass(expected_type, BaseModel):
|
||||||
if not isinstance(state, expected_type):
|
if not isinstance(state, expected_type):
|
||||||
raise TypeError(f"Expected {expected_type.__name__}, got {type(state).__name__}")
|
raise TypeError(
|
||||||
|
f"Expected {expected_type.__name__}, got {type(state).__name__}"
|
||||||
|
)
|
||||||
return cast(StateT, state)
|
return cast(StateT, state)
|
||||||
raise TypeError(f"Invalid expected_type: {expected_type}")
|
raise TypeError(f"Invalid expected_type: {expected_type}")
|
||||||
|
|
||||||
@@ -120,6 +129,7 @@ def start(condition: Optional[Union[str, dict, Callable]] = None) -> Callable:
|
|||||||
>>> def complex_start(self):
|
>>> def complex_start(self):
|
||||||
... pass
|
... pass
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def decorator(func):
|
def decorator(func):
|
||||||
func.__is_start_method__ = True
|
func.__is_start_method__ = True
|
||||||
if condition is not None:
|
if condition is not None:
|
||||||
@@ -144,6 +154,7 @@ def start(condition: Optional[Union[str, dict, Callable]] = None) -> Callable:
|
|||||||
|
|
||||||
return decorator
|
return decorator
|
||||||
|
|
||||||
|
|
||||||
def listen(condition: Union[str, dict, Callable]) -> Callable:
|
def listen(condition: Union[str, dict, Callable]) -> Callable:
|
||||||
"""
|
"""
|
||||||
Creates a listener that executes when specified conditions are met.
|
Creates a listener that executes when specified conditions are met.
|
||||||
@@ -180,6 +191,7 @@ def listen(condition: Union[str, dict, Callable]) -> Callable:
|
|||||||
>>> def handle_completion(self):
|
>>> def handle_completion(self):
|
||||||
... pass
|
... pass
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def decorator(func):
|
def decorator(func):
|
||||||
if isinstance(condition, str):
|
if isinstance(condition, str):
|
||||||
func.__trigger_methods__ = [condition]
|
func.__trigger_methods__ = [condition]
|
||||||
@@ -244,6 +256,7 @@ def router(condition: Union[str, dict, Callable]) -> Callable:
|
|||||||
... return CONTINUE
|
... return CONTINUE
|
||||||
... return STOP
|
... return STOP
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def decorator(func):
|
def decorator(func):
|
||||||
func.__is_router__ = True
|
func.__is_router__ = True
|
||||||
if isinstance(condition, str):
|
if isinstance(condition, str):
|
||||||
@@ -267,6 +280,7 @@ def router(condition: Union[str, dict, Callable]) -> Callable:
|
|||||||
|
|
||||||
return decorator
|
return decorator
|
||||||
|
|
||||||
|
|
||||||
def or_(*conditions: Union[str, dict, Callable]) -> dict:
|
def or_(*conditions: Union[str, dict, Callable]) -> dict:
|
||||||
"""
|
"""
|
||||||
Combines multiple conditions with OR logic for flow control.
|
Combines multiple conditions with OR logic for flow control.
|
||||||
@@ -370,22 +384,27 @@ class FlowMeta(type):
|
|||||||
|
|
||||||
for attr_name, attr_value in dct.items():
|
for attr_name, attr_value in dct.items():
|
||||||
# Check for any flow-related attributes
|
# Check for any flow-related attributes
|
||||||
if (hasattr(attr_value, "__is_flow_method__") or
|
if (
|
||||||
hasattr(attr_value, "__is_start_method__") or
|
hasattr(attr_value, "__is_flow_method__")
|
||||||
hasattr(attr_value, "__trigger_methods__") or
|
or hasattr(attr_value, "__is_start_method__")
|
||||||
hasattr(attr_value, "__is_router__")):
|
or hasattr(attr_value, "__trigger_methods__")
|
||||||
|
or hasattr(attr_value, "__is_router__")
|
||||||
|
):
|
||||||
|
|
||||||
# Register start methods
|
# Register start methods
|
||||||
if hasattr(attr_value, "__is_start_method__"):
|
if hasattr(attr_value, "__is_start_method__"):
|
||||||
start_methods.append(attr_name)
|
start_methods.append(attr_name)
|
||||||
|
|
||||||
# Register listeners and routers
|
# Register listeners and routers
|
||||||
if hasattr(attr_value, "__trigger_methods__"):
|
if hasattr(attr_value, "__trigger_methods__"):
|
||||||
methods = attr_value.__trigger_methods__
|
methods = attr_value.__trigger_methods__
|
||||||
condition_type = getattr(attr_value, "__condition_type__", "OR")
|
condition_type = getattr(attr_value, "__condition_type__", "OR")
|
||||||
listeners[attr_name] = (condition_type, methods)
|
listeners[attr_name] = (condition_type, methods)
|
||||||
|
|
||||||
if hasattr(attr_value, "__is_router__") and attr_value.__is_router__:
|
if (
|
||||||
|
hasattr(attr_value, "__is_router__")
|
||||||
|
and attr_value.__is_router__
|
||||||
|
):
|
||||||
routers.add(attr_name)
|
routers.add(attr_name)
|
||||||
possible_returns = get_possible_return_constants(attr_value)
|
possible_returns = get_possible_return_constants(attr_value)
|
||||||
if possible_returns:
|
if possible_returns:
|
||||||
@@ -401,8 +420,9 @@ class FlowMeta(type):
|
|||||||
|
|
||||||
class Flow(Generic[T], metaclass=FlowMeta):
|
class Flow(Generic[T], metaclass=FlowMeta):
|
||||||
"""Base class for all flows.
|
"""Base class for all flows.
|
||||||
|
|
||||||
Type parameter T must be either Dict[str, Any] or a subclass of BaseModel."""
|
Type parameter T must be either Dict[str, Any] or a subclass of BaseModel."""
|
||||||
|
|
||||||
_telemetry = Telemetry()
|
_telemetry = Telemetry()
|
||||||
|
|
||||||
_start_methods: List[str] = []
|
_start_methods: List[str] = []
|
||||||
@@ -426,7 +446,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
|||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Initialize a new Flow instance.
|
"""Initialize a new Flow instance.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
persistence: Optional persistence backend for storing flow states
|
persistence: Optional persistence backend for storing flow states
|
||||||
restore_uuid: Optional UUID to restore state from persistence
|
restore_uuid: Optional UUID to restore state from persistence
|
||||||
@@ -438,29 +458,38 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
|||||||
self._pending_and_listeners: Dict[str, Set[str]] = {}
|
self._pending_and_listeners: Dict[str, Set[str]] = {}
|
||||||
self._method_outputs: List[Any] = [] # List to store all method outputs
|
self._method_outputs: List[Any] = [] # List to store all method outputs
|
||||||
self._persistence: Optional[FlowPersistence] = persistence
|
self._persistence: Optional[FlowPersistence] = persistence
|
||||||
|
|
||||||
# Validate state model before initialization
|
# Validate state model before initialization
|
||||||
if isinstance(self.initial_state, type):
|
if isinstance(self.initial_state, type):
|
||||||
if issubclass(self.initial_state, BaseModel) and not issubclass(self.initial_state, FlowState):
|
if issubclass(self.initial_state, BaseModel) and not issubclass(
|
||||||
|
self.initial_state, FlowState
|
||||||
|
):
|
||||||
# Check if model has id field
|
# Check if model has id field
|
||||||
model_fields = getattr(self.initial_state, "model_fields", None)
|
model_fields = getattr(self.initial_state, "model_fields", None)
|
||||||
if not model_fields or "id" not in model_fields:
|
if not model_fields or "id" not in model_fields:
|
||||||
raise ValueError("Flow state model must have an 'id' field")
|
raise ValueError("Flow state model must have an 'id' field")
|
||||||
|
|
||||||
# Handle persistence and potential ID conflicts
|
# Handle persistence and potential ID conflicts
|
||||||
stored_state = None
|
stored_state = None
|
||||||
if self._persistence is not None:
|
if self._persistence is not None:
|
||||||
if restore_uuid and kwargs and "id" in kwargs and restore_uuid != kwargs["id"]:
|
if (
|
||||||
|
restore_uuid
|
||||||
|
and kwargs
|
||||||
|
and "id" in kwargs
|
||||||
|
and restore_uuid != kwargs["id"]
|
||||||
|
):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Conflicting IDs provided: restore_uuid='{restore_uuid}' "
|
f"Conflicting IDs provided: restore_uuid='{restore_uuid}' "
|
||||||
f"vs kwargs['id']='{kwargs['id']}'. Use only one ID for restoration."
|
f"vs kwargs['id']='{kwargs['id']}'. Use only one ID for restoration."
|
||||||
)
|
)
|
||||||
|
|
||||||
# Attempt to load state, prioritizing restore_uuid
|
# Attempt to load state, prioritizing restore_uuid
|
||||||
if restore_uuid:
|
if restore_uuid:
|
||||||
stored_state = self._persistence.load_state(restore_uuid)
|
stored_state = self._persistence.load_state(restore_uuid)
|
||||||
if not stored_state:
|
if not stored_state:
|
||||||
raise ValueError(f"No state found for restore_uuid='{restore_uuid}'")
|
raise ValueError(
|
||||||
|
f"No state found for restore_uuid='{restore_uuid}'"
|
||||||
|
)
|
||||||
elif kwargs and "id" in kwargs:
|
elif kwargs and "id" in kwargs:
|
||||||
stored_state = self._persistence.load_state(kwargs["id"])
|
stored_state = self._persistence.load_state(kwargs["id"])
|
||||||
if not stored_state:
|
if not stored_state:
|
||||||
@@ -469,7 +498,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
|||||||
if kwargs:
|
if kwargs:
|
||||||
self._initialize_state(kwargs)
|
self._initialize_state(kwargs)
|
||||||
return
|
return
|
||||||
|
|
||||||
# Initialize state based on persistence and kwargs
|
# Initialize state based on persistence and kwargs
|
||||||
if stored_state:
|
if stored_state:
|
||||||
# Create initial state and restore from persistence
|
# Create initial state and restore from persistence
|
||||||
@@ -494,23 +523,23 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
|||||||
if not method_name.startswith("_"):
|
if not method_name.startswith("_"):
|
||||||
method = getattr(self, method_name)
|
method = getattr(self, method_name)
|
||||||
# Check for any flow-related attributes
|
# Check for any flow-related attributes
|
||||||
if (hasattr(method, "__is_flow_method__") or
|
if (
|
||||||
hasattr(method, "__is_start_method__") or
|
hasattr(method, "__is_flow_method__")
|
||||||
hasattr(method, "__trigger_methods__") or
|
or hasattr(method, "__is_start_method__")
|
||||||
hasattr(method, "__is_router__")):
|
or hasattr(method, "__trigger_methods__")
|
||||||
|
or hasattr(method, "__is_router__")
|
||||||
|
):
|
||||||
# Ensure method is bound to this instance
|
# Ensure method is bound to this instance
|
||||||
if not hasattr(method, "__self__"):
|
if not hasattr(method, "__self__"):
|
||||||
method = method.__get__(self, self.__class__)
|
method = method.__get__(self, self.__class__)
|
||||||
self._methods[method_name] = method
|
self._methods[method_name] = method
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def _create_initial_state(self) -> T:
|
def _create_initial_state(self) -> T:
|
||||||
"""Create and initialize flow state with UUID and default values.
|
"""Create and initialize flow state with UUID and default values.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
New state instance with UUID and default values initialized
|
New state instance with UUID and default values initialized
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
ValueError: If structured state model lacks 'id' field
|
ValueError: If structured state model lacks 'id' field
|
||||||
TypeError: If state is neither BaseModel nor dictionary
|
TypeError: If state is neither BaseModel nor dictionary
|
||||||
@@ -522,24 +551,25 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
|||||||
if issubclass(state_type, FlowState):
|
if issubclass(state_type, FlowState):
|
||||||
# Create instance without id, then set it
|
# Create instance without id, then set it
|
||||||
instance = state_type()
|
instance = state_type()
|
||||||
if not hasattr(instance, 'id'):
|
if not hasattr(instance, "id"):
|
||||||
setattr(instance, 'id', str(uuid4()))
|
setattr(instance, "id", str(uuid4()))
|
||||||
return cast(T, instance)
|
return cast(T, instance)
|
||||||
elif issubclass(state_type, BaseModel):
|
elif issubclass(state_type, BaseModel):
|
||||||
# Create a new type that includes the ID field
|
# Create a new type that includes the ID field
|
||||||
class StateWithId(state_type, FlowState): # type: ignore
|
class StateWithId(state_type, FlowState): # type: ignore
|
||||||
pass
|
pass
|
||||||
|
|
||||||
instance = StateWithId()
|
instance = StateWithId()
|
||||||
if not hasattr(instance, 'id'):
|
if not hasattr(instance, "id"):
|
||||||
setattr(instance, 'id', str(uuid4()))
|
setattr(instance, "id", str(uuid4()))
|
||||||
return cast(T, instance)
|
return cast(T, instance)
|
||||||
elif state_type == dict:
|
elif state_type is dict:
|
||||||
return cast(T, {"id": str(uuid4())}) # Minimal dict state
|
return cast(T, {"id": str(uuid4())})
|
||||||
|
|
||||||
# Handle case where no initial state is provided
|
# Handle case where no initial state is provided
|
||||||
if self.initial_state is None:
|
if self.initial_state is None:
|
||||||
return cast(T, {"id": str(uuid4())})
|
return cast(T, {"id": str(uuid4())})
|
||||||
|
|
||||||
# Handle case where initial_state is a type (class)
|
# Handle case where initial_state is a type (class)
|
||||||
if isinstance(self.initial_state, type):
|
if isinstance(self.initial_state, type):
|
||||||
if issubclass(self.initial_state, FlowState):
|
if issubclass(self.initial_state, FlowState):
|
||||||
@@ -550,22 +580,22 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
|||||||
if not model_fields or "id" not in model_fields:
|
if not model_fields or "id" not in model_fields:
|
||||||
raise ValueError("Flow state model must have an 'id' field")
|
raise ValueError("Flow state model must have an 'id' field")
|
||||||
return cast(T, self.initial_state()) # Uses model defaults
|
return cast(T, self.initial_state()) # Uses model defaults
|
||||||
elif self.initial_state == dict:
|
elif self.initial_state is dict:
|
||||||
return cast(T, {"id": str(uuid4())})
|
return cast(T, {"id": str(uuid4())})
|
||||||
|
|
||||||
# Handle dictionary instance case
|
# Handle dictionary instance case
|
||||||
if isinstance(self.initial_state, dict):
|
if isinstance(self.initial_state, dict):
|
||||||
new_state = dict(self.initial_state) # Copy to avoid mutations
|
new_state = dict(self.initial_state) # Copy to avoid mutations
|
||||||
if "id" not in new_state:
|
if "id" not in new_state:
|
||||||
new_state["id"] = str(uuid4())
|
new_state["id"] = str(uuid4())
|
||||||
return cast(T, new_state)
|
return cast(T, new_state)
|
||||||
|
|
||||||
# Handle BaseModel instance case
|
# Handle BaseModel instance case
|
||||||
if isinstance(self.initial_state, BaseModel):
|
if isinstance(self.initial_state, BaseModel):
|
||||||
model = cast(BaseModel, self.initial_state)
|
model = cast(BaseModel, self.initial_state)
|
||||||
if not hasattr(model, "id"):
|
if not hasattr(model, "id"):
|
||||||
raise ValueError("Flow state model must have an 'id' field")
|
raise ValueError("Flow state model must have an 'id' field")
|
||||||
|
|
||||||
# Create new instance with same values to avoid mutations
|
# Create new instance with same values to avoid mutations
|
||||||
if hasattr(model, "model_dump"):
|
if hasattr(model, "model_dump"):
|
||||||
# Pydantic v2
|
# Pydantic v2
|
||||||
@@ -576,60 +606,13 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
|||||||
else:
|
else:
|
||||||
# Fallback for other BaseModel implementations
|
# Fallback for other BaseModel implementations
|
||||||
state_dict = {
|
state_dict = {
|
||||||
k: v for k, v in model.__dict__.items()
|
k: v for k, v in model.__dict__.items() if not k.startswith("_")
|
||||||
if not k.startswith("_")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
# Create new instance of the same class
|
# Create new instance of the same class
|
||||||
model_class = type(model)
|
model_class = type(model)
|
||||||
return cast(T, model_class(**state_dict))
|
return cast(T, model_class(**state_dict))
|
||||||
|
|
||||||
raise TypeError(
|
|
||||||
f"Initial state must be dict or BaseModel, got {type(self.initial_state)}"
|
|
||||||
)
|
|
||||||
# Handle case where initial_state is None but we have a type parameter
|
|
||||||
if self.initial_state is None and hasattr(self, "_initial_state_T"):
|
|
||||||
state_type = getattr(self, "_initial_state_T")
|
|
||||||
if isinstance(state_type, type):
|
|
||||||
if issubclass(state_type, FlowState):
|
|
||||||
return cast(T, state_type())
|
|
||||||
elif issubclass(state_type, BaseModel):
|
|
||||||
# Create a new type that includes the ID field
|
|
||||||
class StateWithId(state_type, FlowState): # type: ignore
|
|
||||||
pass
|
|
||||||
return cast(T, StateWithId())
|
|
||||||
elif state_type == dict:
|
|
||||||
return cast(T, {"id": str(uuid4())})
|
|
||||||
|
|
||||||
# Handle case where no initial state is provided
|
|
||||||
if self.initial_state is None:
|
|
||||||
return cast(T, {"id": str(uuid4())})
|
|
||||||
|
|
||||||
# Handle case where initial_state is a type (class)
|
|
||||||
if isinstance(self.initial_state, type):
|
|
||||||
if issubclass(self.initial_state, FlowState):
|
|
||||||
return cast(T, self.initial_state())
|
|
||||||
elif issubclass(self.initial_state, BaseModel):
|
|
||||||
# Validate that the model has an id field
|
|
||||||
model_fields = getattr(self.initial_state, "model_fields", None)
|
|
||||||
if not model_fields or "id" not in model_fields:
|
|
||||||
raise ValueError("Flow state model must have an 'id' field")
|
|
||||||
return cast(T, self.initial_state())
|
|
||||||
elif self.initial_state == dict:
|
|
||||||
return cast(T, {"id": str(uuid4())})
|
|
||||||
|
|
||||||
# Handle dictionary instance case
|
|
||||||
if isinstance(self.initial_state, dict):
|
|
||||||
if "id" not in self.initial_state:
|
|
||||||
self.initial_state["id"] = str(uuid4())
|
|
||||||
return cast(T, dict(self.initial_state)) # Create new dict to avoid mutations
|
|
||||||
|
|
||||||
# Handle BaseModel instance case
|
|
||||||
if isinstance(self.initial_state, BaseModel):
|
|
||||||
if not hasattr(self.initial_state, "id"):
|
|
||||||
raise ValueError("Flow state model must have an 'id' field")
|
|
||||||
return cast(T, self.initial_state)
|
|
||||||
|
|
||||||
raise TypeError(
|
raise TypeError(
|
||||||
f"Initial state must be dict or BaseModel, got {type(self.initial_state)}"
|
f"Initial state must be dict or BaseModel, got {type(self.initial_state)}"
|
||||||
)
|
)
|
||||||
@@ -645,10 +628,10 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
|||||||
|
|
||||||
def _initialize_state(self, inputs: Dict[str, Any]) -> None:
|
def _initialize_state(self, inputs: Dict[str, Any]) -> None:
|
||||||
"""Initialize or update flow state with new inputs.
|
"""Initialize or update flow state with new inputs.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
inputs: Dictionary of state values to set/update
|
inputs: Dictionary of state values to set/update
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
ValueError: If validation fails for structured state
|
ValueError: If validation fails for structured state
|
||||||
TypeError: If state is neither BaseModel nor dictionary
|
TypeError: If state is neither BaseModel nor dictionary
|
||||||
@@ -675,13 +658,12 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
|||||||
current_state = model.dict()
|
current_state = model.dict()
|
||||||
else:
|
else:
|
||||||
current_state = {
|
current_state = {
|
||||||
k: v for k, v in model.__dict__.items()
|
k: v for k, v in model.__dict__.items() if not k.startswith("_")
|
||||||
if not k.startswith("_")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
# Create new state with preserved fields and updates
|
# Create new state with preserved fields and updates
|
||||||
new_state = {**current_state, **inputs}
|
new_state = {**current_state, **inputs}
|
||||||
|
|
||||||
# Create new instance with merged state
|
# Create new instance with merged state
|
||||||
model_class = type(model)
|
model_class = type(model)
|
||||||
if hasattr(model_class, "model_validate"):
|
if hasattr(model_class, "model_validate"):
|
||||||
@@ -697,13 +679,13 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
|||||||
raise ValueError(f"Invalid inputs for structured state: {e}") from e
|
raise ValueError(f"Invalid inputs for structured state: {e}") from e
|
||||||
else:
|
else:
|
||||||
raise TypeError("State must be a BaseModel instance or a dictionary.")
|
raise TypeError("State must be a BaseModel instance or a dictionary.")
|
||||||
|
|
||||||
def _restore_state(self, stored_state: Dict[str, Any]) -> None:
|
def _restore_state(self, stored_state: Dict[str, Any]) -> None:
|
||||||
"""Restore flow state from persistence.
|
"""Restore flow state from persistence.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
stored_state: Previously stored state to restore
|
stored_state: Previously stored state to restore
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
ValueError: If validation fails for structured state
|
ValueError: If validation fails for structured state
|
||||||
TypeError: If state is neither BaseModel nor dictionary
|
TypeError: If state is neither BaseModel nor dictionary
|
||||||
@@ -712,7 +694,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
|||||||
stored_id = stored_state.get("id")
|
stored_id = stored_state.get("id")
|
||||||
if not stored_id:
|
if not stored_id:
|
||||||
raise ValueError("Stored state must have an 'id' field")
|
raise ValueError("Stored state must have an 'id' field")
|
||||||
|
|
||||||
if isinstance(self._state, dict):
|
if isinstance(self._state, dict):
|
||||||
# For dict states, update all fields from stored state
|
# For dict states, update all fields from stored state
|
||||||
self._state.clear()
|
self._state.clear()
|
||||||
@@ -730,9 +712,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
|||||||
# Fallback for other BaseModel implementations
|
# Fallback for other BaseModel implementations
|
||||||
self._state = cast(T, type(model)(**stored_state))
|
self._state = cast(T, type(model)(**stored_state))
|
||||||
else:
|
else:
|
||||||
raise TypeError(
|
raise TypeError(f"State must be dict or BaseModel, got {type(self._state)}")
|
||||||
f"State must be dict or BaseModel, got {type(self._state)}"
|
|
||||||
)
|
|
||||||
|
|
||||||
def kickoff(self, inputs: Optional[Dict[str, Any]] = None) -> Any:
|
def kickoff(self, inputs: Optional[Dict[str, Any]] = None) -> Any:
|
||||||
self.event_emitter.send(
|
self.event_emitter.send(
|
||||||
|
|||||||
@@ -24,10 +24,12 @@ def create_llm(
|
|||||||
|
|
||||||
# 1) If llm_value is already an LLM object, return it directly
|
# 1) If llm_value is already an LLM object, return it directly
|
||||||
if isinstance(llm_value, LLM):
|
if isinstance(llm_value, LLM):
|
||||||
|
print("LLM value is already an LLM object")
|
||||||
return llm_value
|
return llm_value
|
||||||
|
|
||||||
# 2) If llm_value is a string (model name)
|
# 2) If llm_value is a string (model name)
|
||||||
if isinstance(llm_value, str):
|
if isinstance(llm_value, str):
|
||||||
|
print("LLM value is a string")
|
||||||
try:
|
try:
|
||||||
created_llm = LLM(model=llm_value)
|
created_llm = LLM(model=llm_value)
|
||||||
return created_llm
|
return created_llm
|
||||||
@@ -37,10 +39,12 @@ def create_llm(
|
|||||||
|
|
||||||
# 3) If llm_value is None, parse environment variables or use default
|
# 3) If llm_value is None, parse environment variables or use default
|
||||||
if llm_value is None:
|
if llm_value is None:
|
||||||
|
print("LLM value is None")
|
||||||
return _llm_via_environment_or_fallback()
|
return _llm_via_environment_or_fallback()
|
||||||
|
|
||||||
# 4) Otherwise, attempt to extract relevant attributes from an unknown object
|
# 4) Otherwise, attempt to extract relevant attributes from an unknown object
|
||||||
try:
|
try:
|
||||||
|
print("LLM value is an unknown object")
|
||||||
# Extract attributes with explicit types
|
# Extract attributes with explicit types
|
||||||
model = (
|
model = (
|
||||||
getattr(llm_value, "model_name", None)
|
getattr(llm_value, "model_name", None)
|
||||||
|
|||||||
@@ -114,35 +114,6 @@ def test_custom_llm_temperature_preservation():
|
|||||||
assert agent.llm.temperature == 0.7
|
assert agent.llm.temperature == 0.7
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
|
||||||
def test_agent_execute_task():
|
|
||||||
from langchain_openai import ChatOpenAI
|
|
||||||
|
|
||||||
from crewai import Task
|
|
||||||
|
|
||||||
agent = Agent(
|
|
||||||
role="Math Tutor",
|
|
||||||
goal="Solve math problems accurately",
|
|
||||||
backstory="You are an experienced math tutor with a knack for explaining complex concepts simply.",
|
|
||||||
llm=ChatOpenAI(temperature=0.7, model="gpt-4o-mini"),
|
|
||||||
)
|
|
||||||
|
|
||||||
task = Task(
|
|
||||||
description="Calculate the area of a circle with radius 5 cm.",
|
|
||||||
expected_output="The calculated area of the circle in square centimeters.",
|
|
||||||
agent=agent,
|
|
||||||
)
|
|
||||||
|
|
||||||
result = agent.execute_task(task)
|
|
||||||
|
|
||||||
assert result is not None
|
|
||||||
assert (
|
|
||||||
result
|
|
||||||
== "The calculated area of the circle is approximately 78.5 square centimeters."
|
|
||||||
)
|
|
||||||
assert "square centimeters" in result.lower()
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||||
def test_agent_execution():
|
def test_agent_execution():
|
||||||
agent = Agent(
|
agent = Agent(
|
||||||
|
|||||||
@@ -1,121 +0,0 @@
|
|||||||
interactions:
|
|
||||||
- request:
|
|
||||||
body: '{"messages": [{"role": "system", "content": "You are Math Tutor. You are
|
|
||||||
an experienced math tutor with a knack for explaining complex concepts simply.\nYour
|
|
||||||
personal goal is: Solve math problems accurately\nTo give my best complete final
|
|
||||||
answer to the task use the exact following format:\n\nThought: I now can give
|
|
||||||
a great answer\nFinal Answer: Your final answer must be the great and the most
|
|
||||||
complete as possible, it must be outcome described.\n\nI MUST use these formats,
|
|
||||||
my job depends on it!"}, {"role": "user", "content": "\nCurrent Task: Calculate
|
|
||||||
the area of a circle with radius 5 cm.\n\nThis is the expect criteria for your
|
|
||||||
final answer: The calculated area of the circle in square centimeters.\nyou
|
|
||||||
MUST return the actual complete content as the final answer, not a summary.\n\nBegin!
|
|
||||||
This is VERY important to you, use the tools available and give your best Final
|
|
||||||
Answer, your job depends on it!\n\nThought:"}], "model": "gpt-4o-mini", "temperature":
|
|
||||||
0.7}'
|
|
||||||
headers:
|
|
||||||
accept:
|
|
||||||
- application/json
|
|
||||||
accept-encoding:
|
|
||||||
- gzip, deflate
|
|
||||||
connection:
|
|
||||||
- keep-alive
|
|
||||||
content-length:
|
|
||||||
- '969'
|
|
||||||
content-type:
|
|
||||||
- application/json
|
|
||||||
host:
|
|
||||||
- api.openai.com
|
|
||||||
user-agent:
|
|
||||||
- OpenAI/Python 1.47.0
|
|
||||||
x-stainless-arch:
|
|
||||||
- arm64
|
|
||||||
x-stainless-async:
|
|
||||||
- 'false'
|
|
||||||
x-stainless-lang:
|
|
||||||
- python
|
|
||||||
x-stainless-os:
|
|
||||||
- MacOS
|
|
||||||
x-stainless-package-version:
|
|
||||||
- 1.47.0
|
|
||||||
x-stainless-raw-response:
|
|
||||||
- 'true'
|
|
||||||
x-stainless-runtime:
|
|
||||||
- CPython
|
|
||||||
x-stainless-runtime-version:
|
|
||||||
- 3.11.7
|
|
||||||
method: POST
|
|
||||||
uri: https://api.openai.com/v1/chat/completions
|
|
||||||
response:
|
|
||||||
content: "{\n \"id\": \"chatcmpl-AB7LEfa5gX4cncpI4avsK0CJG8pCb\",\n \"object\":
|
|
||||||
\"chat.completion\",\n \"created\": 1727213192,\n \"model\": \"gpt-4o-mini-2024-07-18\",\n
|
|
||||||
\ \"choices\": [\n {\n \"index\": 0,\n \"message\": {\n \"role\":
|
|
||||||
\"assistant\",\n \"content\": \"I now can give a great answer\\n\\nTo
|
|
||||||
calculate the area of a circle, we use the formula:\\n\\n\\\\[ A = \\\\pi r^2
|
|
||||||
\\\\]\\n\\nwhere \\\\( A \\\\) is the area, \\\\( \\\\pi \\\\) (approximately
|
|
||||||
3.14), and \\\\( r \\\\) is the radius of the circle.\\n\\nGiven that the radius
|
|
||||||
\\\\( r \\\\) is 5 cm, we can substitute this value into the formula:\\n\\n\\\\[
|
|
||||||
A = \\\\pi (5 \\\\, \\\\text{cm})^2 \\\\]\\n\\nCalculating this step-by-step:\\n\\n1.
|
|
||||||
First, square the radius:\\n \\\\[ (5 \\\\, \\\\text{cm})^2 = 25 \\\\, \\\\text{cm}^2
|
|
||||||
\\\\]\\n\\n2. Then, multiply by \\\\( \\\\pi \\\\):\\n \\\\[ A = \\\\pi \\\\times
|
|
||||||
25 \\\\, \\\\text{cm}^2 \\\\]\\n\\nUsing the approximate value of \\\\( \\\\pi
|
|
||||||
\\\\):\\n \\\\[ A \\\\approx 3.14 \\\\times 25 \\\\, \\\\text{cm}^2 \\\\]\\n
|
|
||||||
\ \\\\[ A \\\\approx 78.5 \\\\, \\\\text{cm}^2 \\\\]\\n\\nThus, the area of
|
|
||||||
the circle is approximately 78.5 square centimeters.\\n\\nFinal Answer: The
|
|
||||||
calculated area of the circle is approximately 78.5 square centimeters.\",\n
|
|
||||||
\ \"refusal\": null\n },\n \"logprobs\": null,\n \"finish_reason\":
|
|
||||||
\"stop\"\n }\n ],\n \"usage\": {\n \"prompt_tokens\": 182,\n \"completion_tokens\":
|
|
||||||
270,\n \"total_tokens\": 452,\n \"completion_tokens_details\": {\n \"reasoning_tokens\":
|
|
||||||
0\n }\n },\n \"system_fingerprint\": \"fp_1bb46167f9\"\n}\n"
|
|
||||||
headers:
|
|
||||||
CF-Cache-Status:
|
|
||||||
- DYNAMIC
|
|
||||||
CF-RAY:
|
|
||||||
- 8c85da71fcac1cf3-GRU
|
|
||||||
Connection:
|
|
||||||
- keep-alive
|
|
||||||
Content-Encoding:
|
|
||||||
- gzip
|
|
||||||
Content-Type:
|
|
||||||
- application/json
|
|
||||||
Date:
|
|
||||||
- Tue, 24 Sep 2024 21:26:34 GMT
|
|
||||||
Server:
|
|
||||||
- cloudflare
|
|
||||||
Set-Cookie:
|
|
||||||
- __cf_bm=rb61BZH2ejzD5YPmLaEJqI7km71QqyNJGTVdNxBq6qk-1727213194-1.0.1.1-pJ49onmgX9IugEMuYQMralzD7oj_6W.CHbSu4Su1z3NyjTGYg.rhgJZWng8feFYah._oSnoYlkTjpK1Wd2C9FA;
|
|
||||||
path=/; expires=Tue, 24-Sep-24 21:56:34 GMT; domain=.api.openai.com; HttpOnly;
|
|
||||||
Secure; SameSite=None
|
|
||||||
- _cfuvid=lbRdAddVWV6W3f5Dm9SaOPWDUOxqtZBSPr_fTW26nEA-1727213194587-0.0.1.1-604800000;
|
|
||||||
path=/; domain=.api.openai.com; HttpOnly; Secure; SameSite=None
|
|
||||||
Transfer-Encoding:
|
|
||||||
- chunked
|
|
||||||
X-Content-Type-Options:
|
|
||||||
- nosniff
|
|
||||||
access-control-expose-headers:
|
|
||||||
- X-Request-ID
|
|
||||||
openai-organization:
|
|
||||||
- crewai-iuxna1
|
|
||||||
openai-processing-ms:
|
|
||||||
- '2244'
|
|
||||||
openai-version:
|
|
||||||
- '2020-10-01'
|
|
||||||
strict-transport-security:
|
|
||||||
- max-age=31536000; includeSubDomains; preload
|
|
||||||
x-ratelimit-limit-requests:
|
|
||||||
- '30000'
|
|
||||||
x-ratelimit-limit-tokens:
|
|
||||||
- '150000000'
|
|
||||||
x-ratelimit-remaining-requests:
|
|
||||||
- '29999'
|
|
||||||
x-ratelimit-remaining-tokens:
|
|
||||||
- '149999774'
|
|
||||||
x-ratelimit-reset-requests:
|
|
||||||
- 2ms
|
|
||||||
x-ratelimit-reset-tokens:
|
|
||||||
- 0s
|
|
||||||
x-request-id:
|
|
||||||
- req_2e565b5f24c38968e4e923a47ecc6233
|
|
||||||
http_version: HTTP/1.1
|
|
||||||
status_code: 200
|
|
||||||
version: 1
|
|
||||||
@@ -3480,10 +3480,12 @@ def test_crew_guardrail_feedback_in_context():
|
|||||||
|
|
||||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||||
def test_before_kickoff_callback():
|
def test_before_kickoff_callback():
|
||||||
from crewai.project import CrewBase, agent, before_kickoff, crew, task
|
from crewai.project import CrewBase, agent, before_kickoff, task
|
||||||
|
|
||||||
@CrewBase
|
@CrewBase
|
||||||
class TestCrewClass:
|
class TestCrewClass:
|
||||||
|
from crewai.project import crew
|
||||||
|
|
||||||
agents_config = None
|
agents_config = None
|
||||||
tasks_config = None
|
tasks_config = None
|
||||||
|
|
||||||
@@ -3510,7 +3512,7 @@ def test_before_kickoff_callback():
|
|||||||
task = Task(
|
task = Task(
|
||||||
description="Test task description",
|
description="Test task description",
|
||||||
expected_output="Test expected output",
|
expected_output="Test expected output",
|
||||||
agent=self.my_agent(), # Use the agent instance
|
agent=self.my_agent(),
|
||||||
)
|
)
|
||||||
return task
|
return task
|
||||||
|
|
||||||
@@ -3520,28 +3522,30 @@ def test_before_kickoff_callback():
|
|||||||
|
|
||||||
test_crew_instance = TestCrewClass()
|
test_crew_instance = TestCrewClass()
|
||||||
|
|
||||||
crew = test_crew_instance.crew()
|
test_crew = test_crew_instance.crew()
|
||||||
|
|
||||||
# Verify that the before_kickoff_callbacks are set
|
# Verify that the before_kickoff_callbacks are set
|
||||||
assert len(crew.before_kickoff_callbacks) == 1
|
assert len(test_crew.before_kickoff_callbacks) == 1
|
||||||
|
|
||||||
# Prepare inputs
|
# Prepare inputs
|
||||||
inputs = {"initial": True}
|
inputs = {"initial": True}
|
||||||
|
|
||||||
# Call kickoff
|
# Call kickoff
|
||||||
crew.kickoff(inputs=inputs)
|
test_crew.kickoff(inputs=inputs)
|
||||||
|
|
||||||
# Check that the before_kickoff function was called and modified inputs
|
# Check that the before_kickoff function was called and modified inputs
|
||||||
assert test_crew_instance.inputs_modified
|
assert test_crew_instance.inputs_modified
|
||||||
assert inputs.get("modified") == True
|
assert inputs.get("modified")
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||||
def test_before_kickoff_without_inputs():
|
def test_before_kickoff_without_inputs():
|
||||||
from crewai.project import CrewBase, agent, before_kickoff, crew, task
|
from crewai.project import CrewBase, agent, before_kickoff, task
|
||||||
|
|
||||||
@CrewBase
|
@CrewBase
|
||||||
class TestCrewClass:
|
class TestCrewClass:
|
||||||
|
from crewai.project import crew
|
||||||
|
|
||||||
agents_config = None
|
agents_config = None
|
||||||
tasks_config = None
|
tasks_config = None
|
||||||
|
|
||||||
@@ -3579,12 +3583,12 @@ def test_before_kickoff_without_inputs():
|
|||||||
# Instantiate the class
|
# Instantiate the class
|
||||||
test_crew_instance = TestCrewClass()
|
test_crew_instance = TestCrewClass()
|
||||||
# Build the crew
|
# Build the crew
|
||||||
crew = test_crew_instance.crew()
|
test_crew = test_crew_instance.crew()
|
||||||
# Verify that the before_kickoff_callback is registered
|
# Verify that the before_kickoff_callback is registered
|
||||||
assert len(crew.before_kickoff_callbacks) == 1
|
assert len(test_crew.before_kickoff_callbacks) == 1
|
||||||
|
|
||||||
# Call kickoff without passing inputs
|
# Call kickoff without passing inputs
|
||||||
output = crew.kickoff()
|
test_crew.kickoff()
|
||||||
|
|
||||||
# Check that the before_kickoff function was called
|
# Check that the before_kickoff function was called
|
||||||
assert test_crew_instance.inputs_modified
|
assert test_crew_instance.inputs_modified
|
||||||
|
|||||||
Reference in New Issue
Block a user