mirror of https://github.com/crewAIInc/crewAI.git (synced 2026-01-21 13:58:15 +00:00)
Merge branch 'main' of github.com:crewAIInc/crewAI into lorenze/imp/native-tool-calling
@@ -246,6 +246,7 @@ class Agent(BaseAgent):
            Can be a single A2AConfig/A2AClientConfig/A2AServerConfig, or a list of any number of A2AConfig/A2AClientConfig with a single A2AServerConfig.
            """,
    )
    executor_class: type[CrewAgentExecutor] | type[AgentExecutor] = Field(
        default=CrewAgentExecutor,
        description="Class to use for the agent executor. Defaults to CrewAgentExecutor, can optionally use AgentExecutor.",
@@ -1615,6 +1616,7 @@ class Agent(BaseAgent):
        )
        return None

    def _prepare_kickoff(
        self,
        messages: str | list[LLMMessage],
@@ -1,8 +1,6 @@
from __future__ import annotations

from collections.abc import Callable, Coroutine
from datetime import datetime
import json
import threading
import time
from typing import TYPE_CHECKING, Any, Literal, cast
@@ -88,7 +86,7 @@ class AgentReActState(BaseModel):


class AgentExecutor(Flow[AgentReActState], CrewAgentExecutorMixin):
    """Flow-based agent executor for both standalone and crew-bound execution.
    """Agent Executor for both standalone agents and crew-bound agents.

    Inherits from:
    - Flow[AgentReActState]: Provides flow orchestration capabilities
@@ -199,11 +197,6 @@ class AgentExecutor(Flow[AgentReActState], CrewAgentExecutorMixin):
                else self.stop
            )
        )

        # Native tool calling support
        self._openai_tools: list[dict[str, Any]] = []
        self._available_functions: dict[str, Callable[..., Any]] = {}

        self._state = AgentReActState()

    def _ensure_flow_initialized(self) -> None:
@@ -949,6 +942,91 @@ class AgentExecutor(Flow[AgentReActState], CrewAgentExecutorMixin):
        finally:
            self._is_executing = False

    async def invoke_async(self, inputs: dict[str, Any]) -> dict[str, Any]:
        """Execute agent asynchronously with given inputs.

        This method is designed for use within async contexts, such as when
        the agent is called from within an async Flow method. It uses
        kickoff_async() directly instead of running in a separate thread.

        Args:
            inputs: Input dictionary containing prompt variables.

        Returns:
            Dictionary with agent output, or a coroutine if inside an event loop.
        """
        # Magic auto-async: if inside event loop, return coroutine for Flow to await
        if is_inside_event_loop():
            return self.invoke_async(inputs)

        self._ensure_flow_initialized()

        with self._execution_lock:
            if self._is_executing:
                raise RuntimeError(
                    "Executor is already running. "
                    "Cannot invoke the same executor instance concurrently."
                )
            self._is_executing = True
            self._has_been_invoked = True

        try:
            # Reset state for fresh execution
            self.state.messages.clear()
            self.state.iterations = 0
            self.state.current_answer = None
            self.state.is_finished = False

            if "system" in self.prompt:
                prompt = cast("SystemPromptResult", self.prompt)
                system_prompt = self._format_prompt(prompt["system"], inputs)
                user_prompt = self._format_prompt(prompt["user"], inputs)
                self.state.messages.append(
                    format_message_for_llm(system_prompt, role="system")
                )
                self.state.messages.append(format_message_for_llm(user_prompt))
            else:
                user_prompt = self._format_prompt(self.prompt["prompt"], inputs)
                self.state.messages.append(format_message_for_llm(user_prompt))

            self.state.ask_for_human_input = bool(
                inputs.get("ask_for_human_input", False)
            )

            # Use async kickoff directly since we're already in an async context
            await self.kickoff_async()

            formatted_answer = self.state.current_answer

            if not isinstance(formatted_answer, AgentFinish):
                raise RuntimeError(
                    "Agent execution ended without reaching a final answer."
                )

            if self.state.ask_for_human_input:
                formatted_answer = self._handle_human_feedback(formatted_answer)

            self._create_short_term_memory(formatted_answer)
            self._create_long_term_memory(formatted_answer)
            self._create_external_memory(formatted_answer)

            return {"output": formatted_answer.output}

        except AssertionError:
            fail_text = Text()
            fail_text.append("❌ ", style="red bold")
            fail_text.append(
                "Agent failed to reach a final answer. This is likely a bug - please report it.",
                style="red",
            )
            self._console.print(fail_text)
            raise
        except Exception as e:
            handle_unknown_error(self._printer, e)
            raise
        finally:
            self._is_executing = False
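The "magic auto-async" guard above hinges on detecting a running event loop: when one exists, blocking would deadlock, so the executor hands back a coroutine for the caller to await. A minimal, self-contained sketch of that pattern, not crewAI's implementation; the helper is assumed to wrap asyncio.get_running_loop(), as its name suggests:

import asyncio
from typing import Any, Coroutine


def is_inside_event_loop() -> bool:
    # Assumption: "inside an event loop" means get_running_loop() succeeds.
    try:
        asyncio.get_running_loop()
        return True
    except RuntimeError:
        return False


class Executor:
    async def _run(self, inputs: dict[str, Any]) -> dict[str, Any]:
        await asyncio.sleep(0)  # stand-in for kickoff_async()
        return {"output": inputs["q"].upper()}

    def invoke(self, inputs: dict[str, Any]) -> dict[str, Any] | Coroutine[Any, Any, dict[str, Any]]:
        if is_inside_event_loop():
            # Already async: return the coroutine so the caller can await it.
            return self._run(inputs)
        # Fully synchronous caller: safe to spin up a loop and block.
        return asyncio.run(self._run(inputs))


print(Executor().invoke({"q": "hi"}))  # -> {'output': 'HI'}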
@@ -12,6 +12,7 @@ from concurrent.futures import Future
import copy
import inspect
import logging
import threading
from typing import (
    TYPE_CHECKING,
    Any,
@@ -64,6 +65,7 @@ from crewai.flow.persistence.base import FlowPersistence
from crewai.flow.types import FlowExecutionData, FlowMethodName, PendingListenerKey
from crewai.flow.utils import (
    _extract_all_methods,
    _extract_all_methods_recursive,
    _normalize_condition,
    get_possible_return_constants,
    is_flow_condition_dict,
@@ -397,6 +399,62 @@ def and_(*conditions: str | FlowCondition | Callable[..., Any]) -> FlowCondition
    return {"type": AND_CONDITION, "conditions": processed_conditions}


class StateProxy(Generic[T]):
    """Proxy that provides thread-safe access to flow state.

    Wraps state objects (dict or BaseModel) and uses a lock for all write
    operations to prevent race conditions when parallel listeners modify state.
    """

    __slots__ = ("_proxy_lock", "_proxy_state")

    def __init__(self, state: T, lock: threading.Lock) -> None:
        object.__setattr__(self, "_proxy_state", state)
        object.__setattr__(self, "_proxy_lock", lock)

    def __getattr__(self, name: str) -> Any:
        return getattr(object.__getattribute__(self, "_proxy_state"), name)

    def __setattr__(self, name: str, value: Any) -> None:
        if name in ("_proxy_state", "_proxy_lock"):
            object.__setattr__(self, name, value)
        else:
            with object.__getattribute__(self, "_proxy_lock"):
                setattr(object.__getattribute__(self, "_proxy_state"), name, value)

    def __getitem__(self, key: str) -> Any:
        return object.__getattribute__(self, "_proxy_state")[key]

    def __setitem__(self, key: str, value: Any) -> None:
        with object.__getattribute__(self, "_proxy_lock"):
            object.__getattribute__(self, "_proxy_state")[key] = value

    def __delitem__(self, key: str) -> None:
        with object.__getattribute__(self, "_proxy_lock"):
            del object.__getattribute__(self, "_proxy_state")[key]

    def __contains__(self, key: str) -> bool:
        return key in object.__getattribute__(self, "_proxy_state")

    def __repr__(self) -> str:
        return repr(object.__getattribute__(self, "_proxy_state"))

    def _unwrap(self) -> T:
        """Return the underlying state object."""
        return cast(T, object.__getattribute__(self, "_proxy_state"))

    def model_dump(self) -> dict[str, Any]:
        """Return state as a dictionary.

        Works for both dict and BaseModel underlying states.
        """
        state = object.__getattribute__(self, "_proxy_state")
        if isinstance(state, dict):
            return state
        result: dict[str, Any] = state.model_dump()
        return result


class FlowMeta(type):
    def __new__(
        mcs,
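To see what the proxy buys you, here is a standalone illustration of the same write-locking discipline StateProxy applies in __setattr__ and __setitem__; the Counter model and thread counts are illustrative stand-ins, not part of crewAI:

import threading

from pydantic import BaseModel


class Counter(BaseModel):
    value: int = 0


state = Counter()
lock = threading.Lock()


def bump(times: int) -> None:
    for _ in range(times):
        # Same discipline as StateProxy.__setattr__: take the lock for the
        # read-modify-write so parallel writers cannot interleave.
        with lock:
            state.value += 1


threads = [threading.Thread(target=bump, args=(1_000,)) for _ in range(8)]
for t in threads:
    t.start()
for t in threads:
    t.join()

print(state.value)  # 8000: no lost updates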
@@ -524,6 +582,8 @@ class Flow(Generic[T], metaclass=FlowMeta):
            set()
        )  # Track OR listeners that already fired
        self._method_outputs: list[Any] = []  # list to store all method outputs
        self._state_lock = threading.Lock()
        self._or_listeners_lock = threading.Lock()
        self._completed_methods: set[FlowMethodName] = (
            set()
        )  # Track completed methods for reload
@@ -568,6 +628,175 @@ class Flow(Generic[T], metaclass=FlowMeta):
                method = method.__get__(self, self.__class__)
            self._methods[method.__name__] = method

    def _mark_or_listener_fired(self, listener_name: FlowMethodName) -> bool:
        """Mark an OR listener as fired atomically.

        Args:
            listener_name: The name of the OR listener to mark.

        Returns:
            True if this call was the first to fire the listener.
            False if the listener was already fired.
        """
        with self._or_listeners_lock:
            if listener_name in self._fired_or_listeners:
                return False
            self._fired_or_listeners.add(listener_name)
            return True

    def _clear_or_listeners(self) -> None:
        """Clear fired OR listeners for cyclic flows."""
        with self._or_listeners_lock:
            self._fired_or_listeners.clear()

    def _discard_or_listener(self, listener_name: FlowMethodName) -> None:
        """Discard a single OR listener from the fired set."""
        with self._or_listeners_lock:
            self._fired_or_listeners.discard(listener_name)

    def _build_racing_groups(self) -> dict[frozenset[FlowMethodName], FlowMethodName]:
        """Identify groups of methods that race for the same OR listener.

        Analyzes the flow graph to find listeners with OR conditions that have
        multiple trigger methods. These trigger methods form a "racing group"
        where only the first to complete should trigger the OR listener.

        Only methods that are EXCLUSIVELY sources for the OR listener are included
        in the racing group. Methods that are also triggers for other listeners
        (e.g., AND conditions) are not cancelled when another racing source wins.

        Returns:
            Dictionary mapping frozensets of racing method names to their
            shared OR listener name.

        Example:
            If we have @listen(or_(method_a, method_b)) on handler,
            and method_a/method_b aren't used elsewhere,
            this returns: {frozenset({'method_a', 'method_b'}): 'handler'}
        """
        racing_groups: dict[frozenset[FlowMethodName], FlowMethodName] = {}

        method_to_listeners: dict[FlowMethodName, set[FlowMethodName]] = {}
        for listener_name, condition_data in self._listeners.items():
            if is_simple_flow_condition(condition_data):
                _, methods = condition_data
                for m in methods:
                    method_to_listeners.setdefault(m, set()).add(listener_name)
            elif is_flow_condition_dict(condition_data):
                all_methods = _extract_all_methods_recursive(condition_data)
                for m in all_methods:
                    method_name = FlowMethodName(m) if isinstance(m, str) else m
                    method_to_listeners.setdefault(method_name, set()).add(
                        listener_name
                    )

        for listener_name, condition_data in self._listeners.items():
            if listener_name in self._routers:
                continue

            trigger_methods: set[FlowMethodName] = set()

            if is_simple_flow_condition(condition_data):
                condition_type, methods = condition_data
                if condition_type == OR_CONDITION and len(methods) > 1:
                    trigger_methods = set(methods)

            elif is_flow_condition_dict(condition_data):
                top_level_type = condition_data.get("type", OR_CONDITION)
                if top_level_type == OR_CONDITION:
                    all_methods = _extract_all_methods_recursive(condition_data)
                    if len(all_methods) > 1:
                        trigger_methods = set(
                            FlowMethodName(m) if isinstance(m, str) else m
                            for m in all_methods
                        )

            if trigger_methods:
                exclusive_methods = {
                    m
                    for m in trigger_methods
                    if method_to_listeners.get(m, set()) == {listener_name}
                }
                if len(exclusive_methods) > 1:
                    racing_groups[frozenset(exclusive_methods)] = listener_name

        return racing_groups

    def _get_racing_group_for_listeners(
        self,
        listener_names: list[FlowMethodName],
    ) -> tuple[frozenset[FlowMethodName], FlowMethodName] | None:
        """Check if the given listeners form a racing group.

        Args:
            listener_names: List of listener method names being executed.

        Returns:
            Tuple of (racing_members, or_listener_name) if these listeners race,
            None otherwise.
        """
        if not hasattr(self, "_racing_groups_cache"):
            self._racing_groups_cache = self._build_racing_groups()

        listener_set = set(listener_names)

        for racing_members, or_listener in self._racing_groups_cache.items():
            if racing_members & listener_set:
                racing_subset = racing_members & listener_set
                if len(racing_subset) > 1:
                    return (frozenset(racing_subset), or_listener)

        return None

    async def _execute_racing_listeners(
        self,
        racing_listeners: frozenset[FlowMethodName],
        other_listeners: list[FlowMethodName],
        result: Any,
    ) -> None:
        """Execute racing listeners with first-wins semantics.

        Racing listeners are executed in parallel, but once the first one
        completes, the others are cancelled. Non-racing listeners in the
        same batch are executed normally in parallel.

        Args:
            racing_listeners: Set of listener names that race for an OR condition.
            other_listeners: Other listeners to execute in parallel (not racing).
            result: The result from the triggering method.
        """
        racing_tasks = [
            asyncio.create_task(
                self._execute_single_listener(name, result),
                name=str(name),
            )
            for name in racing_listeners
        ]

        other_tasks = [
            asyncio.create_task(
                self._execute_single_listener(name, result),
                name=str(name),
            )
            for name in other_listeners
        ]

        if racing_tasks:
            for coro in asyncio.as_completed(racing_tasks):
                try:
                    await coro
                except Exception as e:
                    logger.debug(f"Racing listener failed: {e}")
                    continue
                break

            for task in racing_tasks:
                if not task.done():
                    task.cancel()

        if other_tasks:
            await asyncio.gather(*other_tasks, return_exceptions=True)

    @classmethod
    def from_pending(
        cls,
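The first-wins behaviour implemented above can be reproduced with plain asyncio primitives: schedule every racer as a task, take the first successful completion from asyncio.as_completed, then cancel the stragglers. A hedged sketch; the names and delays are invented for illustration:

import asyncio


async def racer(name: str, delay: float) -> str:
    await asyncio.sleep(delay)
    return name


async def main() -> None:
    tasks = [
        asyncio.create_task(racer("method_a", 0.05)),
        asyncio.create_task(racer("method_b", 0.20)),
    ]
    winner = None
    # Mirror the executor's loop: skip failed racers, stop at the first win.
    for coro in asyncio.as_completed(tasks):
        try:
            winner = await coro
        except Exception:
            continue
        break
    # Losers are cancelled so they can never fire the OR listener a second time.
    for task in tasks:
        if not task.done():
            task.cancel()
    print(f"OR listener triggered by: {winner}")


asyncio.run(main())  # OR listener triggered by: method_a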
@@ -745,12 +974,14 @@ class Flow(Generic[T], metaclass=FlowMeta):
            # No default and no feedback - use first outcome
            collapsed_outcome = emit[0]
        elif emit:
            # Collapse feedback to outcome using LLM
            collapsed_outcome = self._collapse_to_outcome(
                feedback=feedback,
                outcomes=emit,
                llm=llm,
            )
            if llm is not None:
                collapsed_outcome = self._collapse_to_outcome(
                    feedback=feedback,
                    outcomes=emit,
                    llm=llm,
                )
            else:
                collapsed_outcome = emit[0]

        # Create result
        result = HumanFeedbackResult(
@@ -789,21 +1020,16 @@ class Flow(Generic[T], metaclass=FlowMeta):
        # This allows methods to re-execute in loops (e.g., implement_changes → suggest_changes → implement_changes)
        self._is_execution_resuming = False

        # Determine what to pass to listeners
        final_result: Any = result
        try:
            if emit and collapsed_outcome:
                # Router behavior - the outcome itself triggers listeners
                # First, add the outcome to method outputs as a router would
                self._method_outputs.append(collapsed_outcome)

                # Then trigger listeners for the outcome (e.g., "approved" triggers @listen("approved"))
                final_result = await self._execute_listeners(
                    FlowMethodName(collapsed_outcome),  # Use outcome as trigger
                    result,  # Pass HumanFeedbackResult to listeners
                await self._execute_listeners(
                    FlowMethodName(collapsed_outcome),
                    result,
                )
            else:
                # Normal behavior - pass the HumanFeedbackResult
                final_result = await self._execute_listeners(
                await self._execute_listeners(
                    FlowMethodName(context.method_name),
                    result,
                )
@@ -899,18 +1125,17 @@ class Flow(Generic[T], metaclass=FlowMeta):

        # Handle case where initial_state is a type (class)
        if isinstance(self.initial_state, type):
            if issubclass(self.initial_state, FlowState):
                return self.initial_state()  # Uses model defaults
            if issubclass(self.initial_state, BaseModel):
                # Validate that the model has an id field
                model_fields = getattr(self.initial_state, "model_fields", None)
            state_class: type[T] = self.initial_state
            if issubclass(state_class, FlowState):
                return state_class()
            if issubclass(state_class, BaseModel):
                model_fields = getattr(state_class, "model_fields", None)
                if not model_fields or "id" not in model_fields:
                    raise ValueError("Flow state model must have an 'id' field")
                instance = self.initial_state()
                # Ensure id is set - generate UUID if empty
                if not getattr(instance, "id", None):
                    object.__setattr__(instance, "id", str(uuid4()))
                return instance
                model_instance = state_class()
                if not getattr(model_instance, "id", None):
                    object.__setattr__(model_instance, "id", str(uuid4()))
                return model_instance
        if self.initial_state is dict:
            return cast(T, {"id": str(uuid4())})
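The 'id' requirement enforced above is easy to satisfy with a default factory. A hedged example of a structured state that passes the validation; every field name other than id is invented for illustration:

from uuid import uuid4

from pydantic import BaseModel, Field


class ReviewState(BaseModel):
    # Required by Flow: identifies this execution, e.g. for persistence.
    id: str = Field(default_factory=lambda: str(uuid4()))
    draft: str = ""
    approved: bool = False


# A flow would then be parameterized as Flow[ReviewState], and
# _create_initial_state() above instantiates it with these defaults.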
@@ -975,7 +1200,7 @@ class Flow(Generic[T], metaclass=FlowMeta):

    @property
    def state(self) -> T:
        return self._state
        return StateProxy(self._state, self._state_lock)  # type: ignore[return-value]

    @property
    def method_outputs(self) -> list[Any]:
@@ -1300,7 +1525,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
            self._completed_methods.clear()
            self._method_outputs.clear()
            self._pending_and_listeners.clear()
            self._fired_or_listeners.clear()
            self._clear_or_listeners()
        else:
            # We're restoring from persistence, set the flag
            self._is_execution_resuming = True
@@ -1506,7 +1731,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
            # For cyclic flows, clear from completed to allow re-execution
            self._completed_methods.discard(start_method_name)
            # Also clear fired OR listeners to allow them to fire again in new cycle
            self._fired_or_listeners.clear()
            self._clear_or_listeners()

        method = self._methods[start_method_name]
        enhanced_method = self._inject_trigger_payload_for_start_method(method)
@@ -1529,9 +1754,25 @@ class Flow(Generic[T], metaclass=FlowMeta):
                    if self.last_human_feedback is not None
                    else result
                )
                # Execute listeners sequentially to prevent race conditions on shared state
                for listener_name in listeners_for_result:
                    await self._execute_single_listener(listener_name, listener_result)
                racing_group = self._get_racing_group_for_listeners(
                    listeners_for_result
                )
                if racing_group:
                    racing_members, _ = racing_group
                    other_listeners = [
                        name
                        for name in listeners_for_result
                        if name not in racing_members
                    ]
                    await self._execute_racing_listeners(
                        racing_members, other_listeners, listener_result
                    )
                else:
                    tasks = [
                        self._execute_single_listener(listener_name, listener_result)
                        for listener_name in listeners_for_result
                    ]
                    await asyncio.gather(*tasks)
            else:
                await self._execute_listeners(start_method_name, result)
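For reference, the racing path exercised here corresponds to flows of this shape; a sketch under the assumption that Flow, start, listen, and or_ are importable from crewai.flow.flow, matching the Example in _build_racing_groups above:

from crewai.flow.flow import Flow, listen, or_, start


class RaceFlow(Flow):
    @start()
    def method_a(self) -> str:
        return "a"

    @start()
    def method_b(self) -> str:
        return "b"

    @listen(or_(method_a, method_b))
    def handler(self, first_result: str) -> str:
        # With the change above, whichever start method finishes first
        # fires this listener once; the slower racer is cancelled.
        return f"triggered by {first_result}"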
@@ -1756,11 +1997,27 @@ class Flow(Generic[T], metaclass=FlowMeta):
                        listener_result = router_result_to_feedback.get(
                            str(current_trigger), result
                        )
                        # Execute listeners sequentially to prevent race conditions on shared state
                        for listener_name in listeners_triggered:
                            await self._execute_single_listener(
                                listener_name, listener_result
                        racing_group = self._get_racing_group_for_listeners(
                            listeners_triggered
                        )
                        if racing_group:
                            racing_members, _ = racing_group
                            other_listeners = [
                                name
                                for name in listeners_triggered
                                if name not in racing_members
                            ]
                            await self._execute_racing_listeners(
                                racing_members, other_listeners, listener_result
                            )
                        else:
                            tasks = [
                                self._execute_single_listener(
                                    listener_name, listener_result
                                )
                                for listener_name in listeners_triggered
                            ]
                            await asyncio.gather(*tasks)

                    if current_trigger in router_results:
                        # Find start methods triggered by this router result
@@ -1974,7 +2231,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
            # For cyclic flows, clear from completed to allow re-execution
            self._completed_methods.discard(listener_name)
            # Also clear from fired OR listeners for cyclic flows
            self._fired_or_listeners.discard(listener_name)
            self._discard_or_listener(listener_name)

        try:
            method = self._methods[listener_name]
@@ -2007,9 +2264,25 @@ class Flow(Generic[T], metaclass=FlowMeta):
                        if self.last_human_feedback is not None
                        else listener_result
                    )
                    # Execute listeners sequentially to prevent race conditions on shared state
                    for name in listeners_for_result:
                        await self._execute_single_listener(name, feedback_result)
                    racing_group = self._get_racing_group_for_listeners(
                        listeners_for_result
                    )
                    if racing_group:
                        racing_members, _ = racing_group
                        other_listeners = [
                            name
                            for name in listeners_for_result
                            if name not in racing_members
                        ]
                        await self._execute_racing_listeners(
                            racing_members, other_listeners, feedback_result
                        )
                    else:
                        tasks = [
                            self._execute_single_listener(name, feedback_result)
                            for name in listeners_for_result
                        ]
                        await asyncio.gather(*tasks)

        except Exception as e:
            # Don't log HumanFeedbackPending as an error - it's expected control flow
@@ -2123,7 +2396,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
        from crewai.llms.base_llm import BaseLLM as BaseLLMClass
        from crewai.utilities.i18n import get_i18n

        # Get or create LLM instance
        llm_instance: BaseLLMClass
        if isinstance(llm, str):
            llm_instance = LLM(model=llm)
        elif isinstance(llm, BaseLLMClass):
@@ -2158,26 +2431,23 @@ class Flow(Generic[T], metaclass=FlowMeta):
            response_model=FeedbackOutcome,
        )

        # Parse the response - LLM returns JSON string when using response_model
        if isinstance(response, str):
            import json

            try:
                parsed = json.loads(response)
                return parsed.get("outcome", outcomes[0])
                return str(parsed.get("outcome", outcomes[0]))
            except json.JSONDecodeError:
                # Not valid JSON, might be raw outcome string
                response_clean = response.strip()
                for outcome in outcomes:
                    if outcome.lower() == response_clean.lower():
                        return outcome
                return outcomes[0]
        elif isinstance(response, FeedbackOutcome):
            return response.outcome
            return str(response.outcome)
        elif hasattr(response, "outcome"):
            return response.outcome
            return str(response.outcome)
        else:
            # Unexpected type, fall back to first outcome
            logger.warning(f"Unexpected response type: {type(response)}")
            return outcomes[0]
@@ -61,7 +61,7 @@ class PersistenceDecorator:
    @classmethod
    def persist_state(
        cls,
        flow_instance: Flow,
        flow_instance: Flow[Any],
        method_name: str,
        persistence_instance: FlowPersistence,
        verbose: bool = False,
@@ -90,7 +90,13 @@ class PersistenceDecorator:
        flow_uuid: str | None = None
        if isinstance(state, dict):
            flow_uuid = state.get("id")
        elif isinstance(state, BaseModel):
        elif hasattr(state, "_unwrap"):
            unwrapped = state._unwrap()
            if isinstance(unwrapped, dict):
                flow_uuid = unwrapped.get("id")
            else:
                flow_uuid = getattr(unwrapped, "id", None)
        elif isinstance(state, BaseModel) or hasattr(state, "id"):
            flow_uuid = getattr(state, "id", None)

        if not flow_uuid:
@@ -104,10 +110,11 @@ class PersistenceDecorator:
            logger.info(LOG_MESSAGES["save_state"].format(flow_uuid))

        try:
            state_data = state._unwrap() if hasattr(state, "_unwrap") else state
            persistence_instance.save_state(
                flow_uuid=flow_uuid,
                method_name=method_name,
                state_data=state,
                state_data=state_data,
            )
        except Exception as e:
            error_msg = LOG_MESSAGES["save_error"].format(method_name, str(e))
@@ -126,7 +133,9 @@ class PersistenceDecorator:
            raise ValueError(error_msg) from e


def persist(persistence: FlowPersistence | None = None, verbose: bool = False):
def persist(
    persistence: FlowPersistence | None = None, verbose: bool = False
) -> Callable[[type | Callable[..., T]], type | Callable[..., T]]:
    """Decorator to persist flow state.

    This decorator can be applied at either the class level or method level.
@@ -189,8 +198,8 @@ def persist(persistence: FlowPersistence | None = None, verbose: bool = False):
        if asyncio.iscoroutinefunction(method):
            # Create a closure to capture the current name and method
            def create_async_wrapper(
                method_name: str, original_method: Callable
            ):
                method_name: str, original_method: Callable[..., Any]
            ) -> Callable[..., Any]:
                @functools.wraps(original_method)
                async def method_wrapper(
                    self: Any, *args: Any, **kwargs: Any
@@ -221,8 +230,8 @@ def persist(persistence: FlowPersistence | None = None, verbose: bool = False):
        else:
            # Create a closure to capture the current name and method
            def create_sync_wrapper(
                method_name: str, original_method: Callable
            ):
                method_name: str, original_method: Callable[..., Any]
            ) -> Callable[..., Any]:
                @functools.wraps(original_method)
                def method_wrapper(self: Any, *args: Any, **kwargs: Any) -> Any:
                    result = original_method(self, *args, **kwargs)
@@ -268,7 +277,7 @@ def persist(persistence: FlowPersistence | None = None, verbose: bool = False):
                    PersistenceDecorator.persist_state(
                        flow_instance, method.__name__, actual_persistence, verbose
                    )
                    return result
                    return cast(T, result)

        for attr in [
            "__is_start_method__",
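As context for these wrapper changes, the decorator is applied like this; a hedged sketch assuming the persist and SQLiteFlowPersistence exports from crewai.flow.persistence, with an SQLite backend constructed with its defaults:

from crewai.flow.flow import Flow, listen, start
from crewai.flow.persistence import SQLiteFlowPersistence, persist


@persist(SQLiteFlowPersistence(), verbose=True)  # class level: every method persists
class GreetFlow(Flow[dict]):
    @start()
    def begin(self) -> str:
        self.state["greeting"] = "hello"
        return "begun"

    @listen(begin)
    def finish(self, _: str) -> str:
        # persist_state() runs after each wrapped method,
        # saving state keyed by the flow's id.
        return self.state["greeting"]


result = GreetFlow().kickoff()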
@@ -1202,9 +1202,9 @@ def test_complex_and_or_branching():
    )
    assert execution_order.index("branch_2b") > min_branch_1_index


    # Final should be after both 2a and 2b
    # Note: final may not be absolutely last due to independent branches (like branch_1c)
    # that don't contribute to the final result path with sequential listener execution
    assert execution_order[-1] == "final"
    assert execution_order.index("final") > execution_order.index("branch_2a")
    assert execution_order.index("final") > execution_order.index("branch_2b")
@@ -1256,10 +1256,11 @@ def test_conditional_router_paths_exclusivity():


def test_state_consistency_across_parallel_branches():
    """Test that state remains consistent when branches execute sequentially.
    """Test that state remains consistent when branches execute in parallel.

    Note: Branches triggered by the same parent execute sequentially, not in parallel.
    This ensures predictable state mutations and prevents race conditions.
    Note: Branches triggered by the same parent execute in parallel for efficiency.
    Thread-safe state access via StateProxy ensures no race conditions.
    We check the execution order to ensure the branches execute in parallel.
    """
    execution_order = []
@@ -1296,12 +1297,14 @@ def test_state_consistency_across_parallel_branches():
    flow = StateConsistencyFlow()
    flow.kickoff()

    # Branches execute sequentially, so branch_a runs first, then branch_b
    assert flow.state["branch_a_value"] == 10  # Sees initial value
    assert flow.state["branch_b_value"] == 11  # Sees value after branch_a increment
    assert "branch_a" in execution_order
    assert "branch_b" in execution_order
    assert "verify_state" in execution_order

    # Final counter should reflect both increments sequentially
    assert flow.state["counter"] == 16  # 10 + 1 + 5
    assert flow.state["branch_a_value"] is not None
    assert flow.state["branch_b_value"] is not None

    assert flow.state["counter"] == 16


def test_deeply_nested_conditions():
@@ -247,4 +247,4 @@ def test_persistence_with_base_model(tmp_path):
    assert message.role == "user"
    assert message.type == "text"
    assert message.content == "Hello, World!"
    assert isinstance(flow.state, State)
    assert isinstance(flow.state._unwrap(), State)