diff --git a/lib/crewai/src/crewai/agent/core.py b/lib/crewai/src/crewai/agent/core.py index e84b02018..3336313e0 100644 --- a/lib/crewai/src/crewai/agent/core.py +++ b/lib/crewai/src/crewai/agent/core.py @@ -246,6 +246,7 @@ class Agent(BaseAgent): Can be a single A2AConfig/A2AClientConfig/A2AServerConfig, or a list of any number of A2AConfig/A2AClientConfig with a single A2AServerConfig. """, ) + executor_class: type[CrewAgentExecutor] | type[AgentExecutor] = Field( executor_class: type[CrewAgentExecutor] | type[AgentExecutor] = Field( default=CrewAgentExecutor, description="Class to use for the agent executor. Defaults to CrewAgentExecutor, can optionally use AgentExecutor.", @@ -1615,6 +1616,7 @@ class Agent(BaseAgent): ) return None + def _prepare_kickoff( def _prepare_kickoff( self, messages: str | list[LLMMessage], diff --git a/lib/crewai/src/crewai/experimental/agent_executor.py b/lib/crewai/src/crewai/experimental/agent_executor.py index f6372518c..ea6f41822 100644 --- a/lib/crewai/src/crewai/experimental/agent_executor.py +++ b/lib/crewai/src/crewai/experimental/agent_executor.py @@ -1,8 +1,6 @@ from __future__ import annotations from collections.abc import Callable, Coroutine -from datetime import datetime -import json import threading import time from typing import TYPE_CHECKING, Any, Literal, cast @@ -88,7 +86,7 @@ class AgentReActState(BaseModel): class AgentExecutor(Flow[AgentReActState], CrewAgentExecutorMixin): - """Flow-based agent executor for both standalone and crew-bound execution. + """Agent Executor for both standalone agents and crew-bound agents. Inherits from: - Flow[AgentReActState]: Provides flow orchestration capabilities @@ -199,11 +197,6 @@ class AgentExecutor(Flow[AgentReActState], CrewAgentExecutorMixin): else self.stop ) ) - - # Native tool calling support - self._openai_tools: list[dict[str, Any]] = [] - self._available_functions: dict[str, Callable[..., Any]] = {} - self._state = AgentReActState() def _ensure_flow_initialized(self) -> None: @@ -949,6 +942,91 @@ class AgentExecutor(Flow[AgentReActState], CrewAgentExecutorMixin): finally: self._is_executing = False + async def invoke_async(self, inputs: dict[str, Any]) -> dict[str, Any]: + """Execute agent asynchronously with given inputs. + + This method is designed for use within async contexts, such as when + the agent is called from within an async Flow method. It uses + kickoff_async() directly instead of running in a separate thread. + + Args: + inputs: Input dictionary containing prompt variables. + + Returns: + Dictionary with agent output, or a coroutine if inside an event loop. + """ + # Magic auto-async: if inside event loop, return coroutine for Flow to await + if is_inside_event_loop(): + return self.invoke_async(inputs) + + self._ensure_flow_initialized() + + with self._execution_lock: + if self._is_executing: + raise RuntimeError( + "Executor is already running. " + "Cannot invoke the same executor instance concurrently." + ) + self._is_executing = True + self._has_been_invoked = True + + try: + # Reset state for fresh execution + self.state.messages.clear() + self.state.iterations = 0 + self.state.current_answer = None + self.state.is_finished = False + + if "system" in self.prompt: + prompt = cast("SystemPromptResult", self.prompt) + system_prompt = self._format_prompt(prompt["system"], inputs) + user_prompt = self._format_prompt(prompt["user"], inputs) + self.state.messages.append( + format_message_for_llm(system_prompt, role="system") + ) + self.state.messages.append(format_message_for_llm(user_prompt)) + else: + user_prompt = self._format_prompt(self.prompt["prompt"], inputs) + self.state.messages.append(format_message_for_llm(user_prompt)) + + self.state.ask_for_human_input = bool( + inputs.get("ask_for_human_input", False) + ) + + # Use async kickoff directly since we're already in an async context + await self.kickoff_async() + + formatted_answer = self.state.current_answer + + if not isinstance(formatted_answer, AgentFinish): + raise RuntimeError( + "Agent execution ended without reaching a final answer." + ) + + if self.state.ask_for_human_input: + formatted_answer = self._handle_human_feedback(formatted_answer) + + self._create_short_term_memory(formatted_answer) + self._create_long_term_memory(formatted_answer) + self._create_external_memory(formatted_answer) + + return {"output": formatted_answer.output} + + except AssertionError: + fail_text = Text() + fail_text.append("❌ ", style="red bold") + fail_text.append( + "Agent failed to reach a final answer. This is likely a bug - please report it.", + style="red", + ) + self._console.print(fail_text) + raise + except Exception as e: + handle_unknown_error(self._printer, e) + raise + finally: + self._is_executing = False + async def invoke_async(self, inputs: dict[str, Any]) -> dict[str, Any]: """Execute agent asynchronously with given inputs. diff --git a/lib/crewai/src/crewai/flow/flow.py b/lib/crewai/src/crewai/flow/flow.py index 9a0ea5a2e..a3e5f69ac 100644 --- a/lib/crewai/src/crewai/flow/flow.py +++ b/lib/crewai/src/crewai/flow/flow.py @@ -12,6 +12,7 @@ from concurrent.futures import Future import copy import inspect import logging +import threading from typing import ( TYPE_CHECKING, Any, @@ -64,6 +65,7 @@ from crewai.flow.persistence.base import FlowPersistence from crewai.flow.types import FlowExecutionData, FlowMethodName, PendingListenerKey from crewai.flow.utils import ( _extract_all_methods, + _extract_all_methods_recursive, _normalize_condition, get_possible_return_constants, is_flow_condition_dict, @@ -397,6 +399,62 @@ def and_(*conditions: str | FlowCondition | Callable[..., Any]) -> FlowCondition return {"type": AND_CONDITION, "conditions": processed_conditions} +class StateProxy(Generic[T]): + """Proxy that provides thread-safe access to flow state. + + Wraps state objects (dict or BaseModel) and uses a lock for all write + operations to prevent race conditions when parallel listeners modify state. + """ + + __slots__ = ("_proxy_lock", "_proxy_state") + + def __init__(self, state: T, lock: threading.Lock) -> None: + object.__setattr__(self, "_proxy_state", state) + object.__setattr__(self, "_proxy_lock", lock) + + def __getattr__(self, name: str) -> Any: + return getattr(object.__getattribute__(self, "_proxy_state"), name) + + def __setattr__(self, name: str, value: Any) -> None: + if name in ("_proxy_state", "_proxy_lock"): + object.__setattr__(self, name, value) + else: + with object.__getattribute__(self, "_proxy_lock"): + setattr(object.__getattribute__(self, "_proxy_state"), name, value) + + def __getitem__(self, key: str) -> Any: + return object.__getattribute__(self, "_proxy_state")[key] + + def __setitem__(self, key: str, value: Any) -> None: + with object.__getattribute__(self, "_proxy_lock"): + object.__getattribute__(self, "_proxy_state")[key] = value + + def __delitem__(self, key: str) -> None: + with object.__getattribute__(self, "_proxy_lock"): + del object.__getattribute__(self, "_proxy_state")[key] + + def __contains__(self, key: str) -> bool: + return key in object.__getattribute__(self, "_proxy_state") + + def __repr__(self) -> str: + return repr(object.__getattribute__(self, "_proxy_state")) + + def _unwrap(self) -> T: + """Return the underlying state object.""" + return cast(T, object.__getattribute__(self, "_proxy_state")) + + def model_dump(self) -> dict[str, Any]: + """Return state as a dictionary. + + Works for both dict and BaseModel underlying states. + """ + state = object.__getattribute__(self, "_proxy_state") + if isinstance(state, dict): + return state + result: dict[str, Any] = state.model_dump() + return result + + class FlowMeta(type): def __new__( mcs, @@ -524,6 +582,8 @@ class Flow(Generic[T], metaclass=FlowMeta): set() ) # Track OR listeners that already fired self._method_outputs: list[Any] = [] # list to store all method outputs + self._state_lock = threading.Lock() + self._or_listeners_lock = threading.Lock() self._completed_methods: set[FlowMethodName] = ( set() ) # Track completed methods for reload @@ -568,6 +628,175 @@ class Flow(Generic[T], metaclass=FlowMeta): method = method.__get__(self, self.__class__) self._methods[method.__name__] = method + def _mark_or_listener_fired(self, listener_name: FlowMethodName) -> bool: + """Mark an OR listener as fired atomically. + + Args: + listener_name: The name of the OR listener to mark. + + Returns: + True if this call was the first to fire the listener. + False if the listener was already fired. + """ + with self._or_listeners_lock: + if listener_name in self._fired_or_listeners: + return False + self._fired_or_listeners.add(listener_name) + return True + + def _clear_or_listeners(self) -> None: + """Clear fired OR listeners for cyclic flows.""" + with self._or_listeners_lock: + self._fired_or_listeners.clear() + + def _discard_or_listener(self, listener_name: FlowMethodName) -> None: + """Discard a single OR listener from the fired set.""" + with self._or_listeners_lock: + self._fired_or_listeners.discard(listener_name) + + def _build_racing_groups(self) -> dict[frozenset[FlowMethodName], FlowMethodName]: + """Identify groups of methods that race for the same OR listener. + + Analyzes the flow graph to find listeners with OR conditions that have + multiple trigger methods. These trigger methods form a "racing group" + where only the first to complete should trigger the OR listener. + + Only methods that are EXCLUSIVELY sources for the OR listener are included + in the racing group. Methods that are also triggers for other listeners + (e.g., AND conditions) are not cancelled when another racing source wins. + + Returns: + Dictionary mapping frozensets of racing method names to their + shared OR listener name. + + Example: + If we have `@listen(or_(method_a, method_b))` on `handler`, + and method_a/method_b aren't used elsewhere, + this returns: {frozenset({'method_a', 'method_b'}): 'handler'} + """ + racing_groups: dict[frozenset[FlowMethodName], FlowMethodName] = {} + + method_to_listeners: dict[FlowMethodName, set[FlowMethodName]] = {} + for listener_name, condition_data in self._listeners.items(): + if is_simple_flow_condition(condition_data): + _, methods = condition_data + for m in methods: + method_to_listeners.setdefault(m, set()).add(listener_name) + elif is_flow_condition_dict(condition_data): + all_methods = _extract_all_methods_recursive(condition_data) + for m in all_methods: + method_name = FlowMethodName(m) if isinstance(m, str) else m + method_to_listeners.setdefault(method_name, set()).add( + listener_name + ) + + for listener_name, condition_data in self._listeners.items(): + if listener_name in self._routers: + continue + + trigger_methods: set[FlowMethodName] = set() + + if is_simple_flow_condition(condition_data): + condition_type, methods = condition_data + if condition_type == OR_CONDITION and len(methods) > 1: + trigger_methods = set(methods) + + elif is_flow_condition_dict(condition_data): + top_level_type = condition_data.get("type", OR_CONDITION) + if top_level_type == OR_CONDITION: + all_methods = _extract_all_methods_recursive(condition_data) + if len(all_methods) > 1: + trigger_methods = set( + FlowMethodName(m) if isinstance(m, str) else m + for m in all_methods + ) + + if trigger_methods: + exclusive_methods = { + m + for m in trigger_methods + if method_to_listeners.get(m, set()) == {listener_name} + } + if len(exclusive_methods) > 1: + racing_groups[frozenset(exclusive_methods)] = listener_name + + return racing_groups + + def _get_racing_group_for_listeners( + self, + listener_names: list[FlowMethodName], + ) -> tuple[frozenset[FlowMethodName], FlowMethodName] | None: + """Check if the given listeners form a racing group. + + Args: + listener_names: List of listener method names being executed. + + Returns: + Tuple of (racing_members, or_listener_name) if these listeners race, + None otherwise. + """ + if not hasattr(self, "_racing_groups_cache"): + self._racing_groups_cache = self._build_racing_groups() + + listener_set = set(listener_names) + + for racing_members, or_listener in self._racing_groups_cache.items(): + if racing_members & listener_set: + racing_subset = racing_members & listener_set + if len(racing_subset) > 1: + return (frozenset(racing_subset), or_listener) + + return None + + async def _execute_racing_listeners( + self, + racing_listeners: frozenset[FlowMethodName], + other_listeners: list[FlowMethodName], + result: Any, + ) -> None: + """Execute racing listeners with first-wins semantics. + + Racing listeners are executed in parallel, but once the first one + completes, the others are cancelled. Non-racing listeners in the + same batch are executed normally in parallel. + + Args: + racing_listeners: Set of listener names that race for an OR condition. + other_listeners: Other listeners to execute in parallel (not racing). + result: The result from the triggering method. + """ + racing_tasks = [ + asyncio.create_task( + self._execute_single_listener(name, result), + name=str(name), + ) + for name in racing_listeners + ] + + other_tasks = [ + asyncio.create_task( + self._execute_single_listener(name, result), + name=str(name), + ) + for name in other_listeners + ] + + if racing_tasks: + for coro in asyncio.as_completed(racing_tasks): + try: + await coro + except Exception as e: + logger.debug(f"Racing listener failed: {e}") + continue + break + + for task in racing_tasks: + if not task.done(): + task.cancel() + + if other_tasks: + await asyncio.gather(*other_tasks, return_exceptions=True) + @classmethod def from_pending( cls, @@ -745,12 +974,14 @@ class Flow(Generic[T], metaclass=FlowMeta): # No default and no feedback - use first outcome collapsed_outcome = emit[0] elif emit: - # Collapse feedback to outcome using LLM - collapsed_outcome = self._collapse_to_outcome( - feedback=feedback, - outcomes=emit, - llm=llm, - ) + if llm is not None: + collapsed_outcome = self._collapse_to_outcome( + feedback=feedback, + outcomes=emit, + llm=llm, + ) + else: + collapsed_outcome = emit[0] # Create result result = HumanFeedbackResult( @@ -789,21 +1020,16 @@ class Flow(Generic[T], metaclass=FlowMeta): # This allows methods to re-execute in loops (e.g., implement_changes → suggest_changes → implement_changes) self._is_execution_resuming = False - # Determine what to pass to listeners + final_result: Any = result try: if emit and collapsed_outcome: - # Router behavior - the outcome itself triggers listeners - # First, add the outcome to method outputs as a router would self._method_outputs.append(collapsed_outcome) - - # Then trigger listeners for the outcome (e.g., "approved" triggers @listen("approved")) - final_result = await self._execute_listeners( - FlowMethodName(collapsed_outcome), # Use outcome as trigger - result, # Pass HumanFeedbackResult to listeners + await self._execute_listeners( + FlowMethodName(collapsed_outcome), + result, ) else: - # Normal behavior - pass the HumanFeedbackResult - final_result = await self._execute_listeners( + await self._execute_listeners( FlowMethodName(context.method_name), result, ) @@ -899,18 +1125,17 @@ class Flow(Generic[T], metaclass=FlowMeta): # Handle case where initial_state is a type (class) if isinstance(self.initial_state, type): - if issubclass(self.initial_state, FlowState): - return self.initial_state() # Uses model defaults - if issubclass(self.initial_state, BaseModel): - # Validate that the model has an id field - model_fields = getattr(self.initial_state, "model_fields", None) + state_class: type[T] = self.initial_state + if issubclass(state_class, FlowState): + return state_class() + if issubclass(state_class, BaseModel): + model_fields = getattr(state_class, "model_fields", None) if not model_fields or "id" not in model_fields: raise ValueError("Flow state model must have an 'id' field") - instance = self.initial_state() - # Ensure id is set - generate UUID if empty - if not getattr(instance, "id", None): - object.__setattr__(instance, "id", str(uuid4())) - return instance + model_instance = state_class() + if not getattr(model_instance, "id", None): + object.__setattr__(model_instance, "id", str(uuid4())) + return model_instance if self.initial_state is dict: return cast(T, {"id": str(uuid4())}) @@ -975,7 +1200,7 @@ class Flow(Generic[T], metaclass=FlowMeta): @property def state(self) -> T: - return self._state + return StateProxy(self._state, self._state_lock) # type: ignore[return-value] @property def method_outputs(self) -> list[Any]: @@ -1300,7 +1525,7 @@ class Flow(Generic[T], metaclass=FlowMeta): self._completed_methods.clear() self._method_outputs.clear() self._pending_and_listeners.clear() - self._fired_or_listeners.clear() + self._clear_or_listeners() else: # We're restoring from persistence, set the flag self._is_execution_resuming = True @@ -1506,7 +1731,7 @@ class Flow(Generic[T], metaclass=FlowMeta): # For cyclic flows, clear from completed to allow re-execution self._completed_methods.discard(start_method_name) # Also clear fired OR listeners to allow them to fire again in new cycle - self._fired_or_listeners.clear() + self._clear_or_listeners() method = self._methods[start_method_name] enhanced_method = self._inject_trigger_payload_for_start_method(method) @@ -1529,9 +1754,25 @@ class Flow(Generic[T], metaclass=FlowMeta): if self.last_human_feedback is not None else result ) - # Execute listeners sequentially to prevent race conditions on shared state - for listener_name in listeners_for_result: - await self._execute_single_listener(listener_name, listener_result) + racing_group = self._get_racing_group_for_listeners( + listeners_for_result + ) + if racing_group: + racing_members, _ = racing_group + other_listeners = [ + name + for name in listeners_for_result + if name not in racing_members + ] + await self._execute_racing_listeners( + racing_members, other_listeners, listener_result + ) + else: + tasks = [ + self._execute_single_listener(listener_name, listener_result) + for listener_name in listeners_for_result + ] + await asyncio.gather(*tasks) else: await self._execute_listeners(start_method_name, result) @@ -1756,11 +1997,27 @@ class Flow(Generic[T], metaclass=FlowMeta): listener_result = router_result_to_feedback.get( str(current_trigger), result ) - # Execute listeners sequentially to prevent race conditions on shared state - for listener_name in listeners_triggered: - await self._execute_single_listener( - listener_name, listener_result + racing_group = self._get_racing_group_for_listeners( + listeners_triggered + ) + if racing_group: + racing_members, _ = racing_group + other_listeners = [ + name + for name in listeners_triggered + if name not in racing_members + ] + await self._execute_racing_listeners( + racing_members, other_listeners, listener_result ) + else: + tasks = [ + self._execute_single_listener( + listener_name, listener_result + ) + for listener_name in listeners_triggered + ] + await asyncio.gather(*tasks) if current_trigger in router_results: # Find start methods triggered by this router result @@ -1974,7 +2231,7 @@ class Flow(Generic[T], metaclass=FlowMeta): # For cyclic flows, clear from completed to allow re-execution self._completed_methods.discard(listener_name) # Also clear from fired OR listeners for cyclic flows - self._fired_or_listeners.discard(listener_name) + self._discard_or_listener(listener_name) try: method = self._methods[listener_name] @@ -2007,9 +2264,25 @@ class Flow(Generic[T], metaclass=FlowMeta): if self.last_human_feedback is not None else listener_result ) - # Execute listeners sequentially to prevent race conditions on shared state - for name in listeners_for_result: - await self._execute_single_listener(name, feedback_result) + racing_group = self._get_racing_group_for_listeners( + listeners_for_result + ) + if racing_group: + racing_members, _ = racing_group + other_listeners = [ + name + for name in listeners_for_result + if name not in racing_members + ] + await self._execute_racing_listeners( + racing_members, other_listeners, feedback_result + ) + else: + tasks = [ + self._execute_single_listener(name, feedback_result) + for name in listeners_for_result + ] + await asyncio.gather(*tasks) except Exception as e: # Don't log HumanFeedbackPending as an error - it's expected control flow @@ -2123,7 +2396,7 @@ class Flow(Generic[T], metaclass=FlowMeta): from crewai.llms.base_llm import BaseLLM as BaseLLMClass from crewai.utilities.i18n import get_i18n - # Get or create LLM instance + llm_instance: BaseLLMClass if isinstance(llm, str): llm_instance = LLM(model=llm) elif isinstance(llm, BaseLLMClass): @@ -2158,26 +2431,23 @@ class Flow(Generic[T], metaclass=FlowMeta): response_model=FeedbackOutcome, ) - # Parse the response - LLM returns JSON string when using response_model if isinstance(response, str): import json try: parsed = json.loads(response) - return parsed.get("outcome", outcomes[0]) + return str(parsed.get("outcome", outcomes[0])) except json.JSONDecodeError: - # Not valid JSON, might be raw outcome string response_clean = response.strip() for outcome in outcomes: if outcome.lower() == response_clean.lower(): return outcome return outcomes[0] elif isinstance(response, FeedbackOutcome): - return response.outcome + return str(response.outcome) elif hasattr(response, "outcome"): - return response.outcome + return str(response.outcome) else: - # Unexpected type, fall back to first outcome logger.warning(f"Unexpected response type: {type(response)}") return outcomes[0] diff --git a/lib/crewai/src/crewai/flow/persistence/decorators.py b/lib/crewai/src/crewai/flow/persistence/decorators.py index 3f5be17db..dbbeaa16f 100644 --- a/lib/crewai/src/crewai/flow/persistence/decorators.py +++ b/lib/crewai/src/crewai/flow/persistence/decorators.py @@ -61,7 +61,7 @@ class PersistenceDecorator: @classmethod def persist_state( cls, - flow_instance: Flow, + flow_instance: Flow[Any], method_name: str, persistence_instance: FlowPersistence, verbose: bool = False, @@ -90,7 +90,13 @@ class PersistenceDecorator: flow_uuid: str | None = None if isinstance(state, dict): flow_uuid = state.get("id") - elif isinstance(state, BaseModel): + elif hasattr(state, "_unwrap"): + unwrapped = state._unwrap() + if isinstance(unwrapped, dict): + flow_uuid = unwrapped.get("id") + else: + flow_uuid = getattr(unwrapped, "id", None) + elif isinstance(state, BaseModel) or hasattr(state, "id"): flow_uuid = getattr(state, "id", None) if not flow_uuid: @@ -104,10 +110,11 @@ class PersistenceDecorator: logger.info(LOG_MESSAGES["save_state"].format(flow_uuid)) try: + state_data = state._unwrap() if hasattr(state, "_unwrap") else state persistence_instance.save_state( flow_uuid=flow_uuid, method_name=method_name, - state_data=state, + state_data=state_data, ) except Exception as e: error_msg = LOG_MESSAGES["save_error"].format(method_name, str(e)) @@ -126,7 +133,9 @@ class PersistenceDecorator: raise ValueError(error_msg) from e -def persist(persistence: FlowPersistence | None = None, verbose: bool = False): +def persist( + persistence: FlowPersistence | None = None, verbose: bool = False +) -> Callable[[type | Callable[..., T]], type | Callable[..., T]]: """Decorator to persist flow state. This decorator can be applied at either the class level or method level. @@ -189,8 +198,8 @@ def persist(persistence: FlowPersistence | None = None, verbose: bool = False): if asyncio.iscoroutinefunction(method): # Create a closure to capture the current name and method def create_async_wrapper( - method_name: str, original_method: Callable - ): + method_name: str, original_method: Callable[..., Any] + ) -> Callable[..., Any]: @functools.wraps(original_method) async def method_wrapper( self: Any, *args: Any, **kwargs: Any @@ -221,8 +230,8 @@ def persist(persistence: FlowPersistence | None = None, verbose: bool = False): else: # Create a closure to capture the current name and method def create_sync_wrapper( - method_name: str, original_method: Callable - ): + method_name: str, original_method: Callable[..., Any] + ) -> Callable[..., Any]: @functools.wraps(original_method) def method_wrapper(self: Any, *args: Any, **kwargs: Any) -> Any: result = original_method(self, *args, **kwargs) @@ -268,7 +277,7 @@ def persist(persistence: FlowPersistence | None = None, verbose: bool = False): PersistenceDecorator.persist_state( flow_instance, method.__name__, actual_persistence, verbose ) - return result + return cast(T, result) for attr in [ "__is_start_method__", diff --git a/lib/crewai/tests/test_flow.py b/lib/crewai/tests/test_flow.py index 54562d41a..6926e15d5 100644 --- a/lib/crewai/tests/test_flow.py +++ b/lib/crewai/tests/test_flow.py @@ -1202,9 +1202,9 @@ def test_complex_and_or_branching(): ) assert execution_order.index("branch_2b") > min_branch_1_index + # Final should be after both 2a and 2b - # Note: final may not be absolutely last due to independent branches (like branch_1c) - # that don't contribute to the final result path with sequential listener execution + assert execution_order[-1] == "final" assert execution_order.index("final") > execution_order.index("branch_2a") assert execution_order.index("final") > execution_order.index("branch_2b") @@ -1256,10 +1256,11 @@ def test_conditional_router_paths_exclusivity(): def test_state_consistency_across_parallel_branches(): - """Test that state remains consistent when branches execute sequentially. + """Test that state remains consistent when branches execute in parallel. - Note: Branches triggered by the same parent execute sequentially, not in parallel. - This ensures predictable state mutations and prevents race conditions. + Note: Branches triggered by the same parent execute in parallel for efficiency. + Thread-safe state access via StateProxy ensures no race conditions. + We check the execution order to ensure the branches execute in parallel. """ execution_order = [] @@ -1296,12 +1297,14 @@ def test_state_consistency_across_parallel_branches(): flow = StateConsistencyFlow() flow.kickoff() - # Branches execute sequentially, so branch_a runs first, then branch_b - assert flow.state["branch_a_value"] == 10 # Sees initial value - assert flow.state["branch_b_value"] == 11 # Sees value after branch_a increment + assert "branch_a" in execution_order + assert "branch_b" in execution_order + assert "verify_state" in execution_order - # Final counter should reflect both increments sequentially - assert flow.state["counter"] == 16 # 10 + 1 + 5 + assert flow.state["branch_a_value"] is not None + assert flow.state["branch_b_value"] is not None + + assert flow.state["counter"] == 16 def test_deeply_nested_conditions(): diff --git a/lib/crewai/tests/test_flow_persistence.py b/lib/crewai/tests/test_flow_persistence.py index 53e059b52..06bbf7231 100644 --- a/lib/crewai/tests/test_flow_persistence.py +++ b/lib/crewai/tests/test_flow_persistence.py @@ -247,4 +247,4 @@ def test_persistence_with_base_model(tmp_path): assert message.role == "user" assert message.type == "text" assert message.content == "Hello, World!" - assert isinstance(flow.state, State) + assert isinstance(flow.state._unwrap(), State)