From f3cacdcf1d9387b96d262cbec394525348e07cd8 Mon Sep 17 00:00:00 2001 From: Greyson LaLonde Date: Tue, 20 Jan 2026 15:27:16 -0500 Subject: [PATCH] fix: ensure context isolation --- lib/crewai/src/crewai/a2a/types.py | 28 +++- lib/crewai/src/crewai/a2a/utils/delegation.py | 66 +++++--- lib/crewai/src/crewai/agent/core.py | 158 ++++++++++++------ lib/crewai/src/crewai/events/__init__.py | 2 + lib/crewai/src/crewai/events/base_events.py | 25 ++- lib/crewai/src/crewai/events/event_bus.py | 4 +- lib/crewai/src/crewai/events/event_context.py | 2 + lib/crewai/src/crewai/events/event_types.py | 2 + .../src/crewai/events/types/memory_events.py | 10 +- 9 files changed, 199 insertions(+), 98 deletions(-) diff --git a/lib/crewai/src/crewai/a2a/types.py b/lib/crewai/src/crewai/a2a/types.py index 90473b669..ea15abd80 100644 --- a/lib/crewai/src/crewai/a2a/types.py +++ b/lib/crewai/src/crewai/a2a/types.py @@ -14,15 +14,25 @@ from typing import ( from pydantic import BeforeValidator, HttpUrl, TypeAdapter from typing_extensions import NotRequired -from crewai.a2a.updates import ( - PollingConfig, - PollingHandler, - PushNotificationConfig, - PushNotificationHandler, - StreamingConfig, - StreamingHandler, - UpdateConfig, -) + +try: + from crewai.a2a.updates import ( + PollingConfig, + PollingHandler, + PushNotificationConfig, + PushNotificationHandler, + StreamingConfig, + StreamingHandler, + UpdateConfig, + ) +except ImportError: + PollingConfig = Any # type: ignore[misc,assignment] + PollingHandler = Any # type: ignore[misc,assignment] + PushNotificationConfig = Any # type: ignore[misc,assignment] + PushNotificationHandler = Any # type: ignore[misc,assignment] + StreamingConfig = Any # type: ignore[misc,assignment] + StreamingHandler = Any # type: ignore[misc,assignment] + UpdateConfig = Any # type: ignore[misc,assignment] TransportType = Literal["JSONRPC", "GRPC", "HTTP+JSON"] diff --git a/lib/crewai/src/crewai/a2a/utils/delegation.py b/lib/crewai/src/crewai/a2a/utils/delegation.py index 0fc9eaec5..f322bbf74 100644 --- a/lib/crewai/src/crewai/a2a/utils/delegation.py +++ b/lib/crewai/src/crewai/a2a/utils/delegation.py @@ -251,30 +251,48 @@ async def aexecute_a2a_delegation( if turn_number is None: turn_number = len([m for m in conversation_history if m.role == Role.user]) + 1 - result = await _aexecute_a2a_delegation_impl( - endpoint=endpoint, - auth=auth, - timeout=timeout, - task_description=task_description, - context=context, - context_id=context_id, - task_id=task_id, - reference_task_ids=reference_task_ids, - metadata=metadata, - extensions=extensions, - conversation_history=conversation_history, - is_multiturn=is_multiturn, - turn_number=turn_number, - agent_branch=agent_branch, - agent_id=agent_id, - agent_role=agent_role, - response_model=response_model, - updates=updates, - transport_protocol=transport_protocol, - from_task=from_task, - from_agent=from_agent, - skill_id=skill_id, - ) + try: + result = await _aexecute_a2a_delegation_impl( + endpoint=endpoint, + auth=auth, + timeout=timeout, + task_description=task_description, + context=context, + context_id=context_id, + task_id=task_id, + reference_task_ids=reference_task_ids, + metadata=metadata, + extensions=extensions, + conversation_history=conversation_history, + is_multiturn=is_multiturn, + turn_number=turn_number, + agent_branch=agent_branch, + agent_id=agent_id, + agent_role=agent_role, + response_model=response_model, + updates=updates, + transport_protocol=transport_protocol, + from_task=from_task, + from_agent=from_agent, + skill_id=skill_id, + ) + except Exception as e: + crewai_event_bus.emit( + agent_branch, + A2ADelegationCompletedEvent( + status="failed", + result=None, + error=str(e), + context_id=context_id, + is_multiturn=is_multiturn, + endpoint=endpoint, + metadata=metadata, + extensions=list(extensions.keys()) if extensions else None, + from_task=from_task, + from_agent=from_agent, + ), + ) + raise agent_card_data: dict[str, Any] = result.get("agent_card") or {} crewai_event_bus.emit( diff --git a/lib/crewai/src/crewai/agent/core.py b/lib/crewai/src/crewai/agent/core.py index bc964754c..ca1effb14 100644 --- a/lib/crewai/src/crewai/agent/core.py +++ b/lib/crewai/src/crewai/agent/core.py @@ -14,7 +14,14 @@ from typing import ( ) from urllib.parse import urlparse -from pydantic import BaseModel, Field, InstanceOf, PrivateAttr, model_validator +from pydantic import ( + BaseModel, + ConfigDict, + Field, + InstanceOf, + PrivateAttr, + model_validator, +) from typing_extensions import Self from crewai.agent.utils import ( @@ -41,6 +48,7 @@ from crewai.events.types.knowledge_events import ( ) from crewai.events.types.memory_events import ( MemoryRetrievalCompletedEvent, + MemoryRetrievalFailedEvent, MemoryRetrievalStartedEvent, ) from crewai.experimental.crew_agent_executor_flow import CrewAgentExecutorFlow @@ -77,17 +85,10 @@ from crewai.utilities.token_counter_callback import TokenCalcHandler from crewai.utilities.training_handler import CrewTrainingHandler -try: - from crewai.a2a.config import A2AClientConfig, A2AConfig, A2AServerConfig -except ImportError: - A2AClientConfig = Any - A2AConfig = Any - A2AServerConfig = Any - - if TYPE_CHECKING: from crewai_tools import CodeInterpreterTool + from crewai.a2a.config import A2AClientConfig, A2AConfig, A2AServerConfig from crewai.agents.agent_builder.base_agent import PlatformAppOrAction from crewai.lite_agent_output import LiteAgentOutput from crewai.task import Task @@ -133,6 +134,8 @@ class Agent(BaseAgent): mcps: List of MCP server references for tool integration. """ + model_config = ConfigDict() + _times_executed: int = PrivateAttr(default=0) _mcp_clients: list[Any] = PrivateAttr(default_factory=list) _last_messages: list[LLMMessage] = PrivateAttr(default_factory=list) @@ -346,30 +349,43 @@ class Agent(BaseAgent): ) start_time = time.time() + memory = "" - contextual_memory = ContextualMemory( - self.crew._short_term_memory, - self.crew._long_term_memory, - self.crew._entity_memory, - self.crew._external_memory, - agent=self, - task=task, - ) - memory = contextual_memory.build_context_for_task(task, context or "") - if memory.strip() != "": - task_prompt += self.i18n.slice("memory").format(memory=memory) + try: + contextual_memory = ContextualMemory( + self.crew._short_term_memory, + self.crew._long_term_memory, + self.crew._entity_memory, + self.crew._external_memory, + agent=self, + task=task, + ) + memory = contextual_memory.build_context_for_task(task, context or "") + if memory.strip() != "": + task_prompt += self.i18n.slice("memory").format(memory=memory) - crewai_event_bus.emit( - self, - event=MemoryRetrievalCompletedEvent( - task_id=str(task.id) if task else None, - memory_content=memory, - retrieval_time_ms=(time.time() - start_time) * 1000, - source_type="agent", - from_agent=self, - from_task=task, - ), - ) + crewai_event_bus.emit( + self, + event=MemoryRetrievalCompletedEvent( + task_id=str(task.id) if task else None, + memory_content=memory, + retrieval_time_ms=(time.time() - start_time) * 1000, + source_type="agent", + from_agent=self, + from_task=task, + ), + ) + except Exception as e: + crewai_event_bus.emit( + self, + event=MemoryRetrievalFailedEvent( + task_id=str(task.id) if task else None, + source_type="agent", + from_agent=self, + from_task=task, + error=str(e), + ), + ) knowledge_config = get_knowledge_config(self) task_prompt = handle_knowledge_retrieval( @@ -555,32 +571,45 @@ class Agent(BaseAgent): ) start_time = time.time() + memory = "" - contextual_memory = ContextualMemory( - self.crew._short_term_memory, - self.crew._long_term_memory, - self.crew._entity_memory, - self.crew._external_memory, - agent=self, - task=task, - ) - memory = await contextual_memory.abuild_context_for_task( - task, context or "" - ) - if memory.strip() != "": - task_prompt += self.i18n.slice("memory").format(memory=memory) + try: + contextual_memory = ContextualMemory( + self.crew._short_term_memory, + self.crew._long_term_memory, + self.crew._entity_memory, + self.crew._external_memory, + agent=self, + task=task, + ) + memory = await contextual_memory.abuild_context_for_task( + task, context or "" + ) + if memory.strip() != "": + task_prompt += self.i18n.slice("memory").format(memory=memory) - crewai_event_bus.emit( - self, - event=MemoryRetrievalCompletedEvent( - task_id=str(task.id) if task else None, - memory_content=memory, - retrieval_time_ms=(time.time() - start_time) * 1000, - source_type="agent", - from_agent=self, - from_task=task, - ), - ) + crewai_event_bus.emit( + self, + event=MemoryRetrievalCompletedEvent( + task_id=str(task.id) if task else None, + memory_content=memory, + retrieval_time_ms=(time.time() - start_time) * 1000, + source_type="agent", + from_agent=self, + from_task=task, + ), + ) + except Exception as e: + crewai_event_bus.emit( + self, + event=MemoryRetrievalFailedEvent( + task_id=str(task.id) if task else None, + source_type="agent", + from_agent=self, + from_task=task, + error=str(e), + ), + ) knowledge_config = get_knowledge_config(self) task_prompt = await ahandle_knowledge_retrieval( @@ -1669,3 +1698,22 @@ class Agent(BaseAgent): ) return await lite_agent.kickoff_async(messages) + + +# Rebuild Agent model to resolve A2A type forward references +try: + from crewai.a2a.config import ( + A2AClientConfig as _A2AClientConfig, + A2AConfig as _A2AConfig, + A2AServerConfig as _A2AServerConfig, + ) + + Agent.model_rebuild( + _types_namespace={ + "A2AConfig": _A2AConfig, + "A2AClientConfig": _A2AClientConfig, + "A2AServerConfig": _A2AServerConfig, + } + ) +except ImportError: + pass diff --git a/lib/crewai/src/crewai/events/__init__.py b/lib/crewai/src/crewai/events/__init__.py index efbb479cd..61c0ec380 100644 --- a/lib/crewai/src/crewai/events/__init__.py +++ b/lib/crewai/src/crewai/events/__init__.py @@ -75,6 +75,7 @@ from crewai.events.types.memory_events import ( MemoryQueryFailedEvent, MemoryQueryStartedEvent, MemoryRetrievalCompletedEvent, + MemoryRetrievalFailedEvent, MemoryRetrievalStartedEvent, MemorySaveCompletedEvent, MemorySaveFailedEvent, @@ -174,6 +175,7 @@ __all__ = [ "MemoryQueryFailedEvent", "MemoryQueryStartedEvent", "MemoryRetrievalCompletedEvent", + "MemoryRetrievalFailedEvent", "MemoryRetrievalStartedEvent", "MemorySaveCompletedEvent", "MemorySaveFailedEvent", diff --git a/lib/crewai/src/crewai/events/base_events.py b/lib/crewai/src/crewai/events/base_events.py index 43a1ef797..cfeb1b32e 100644 --- a/lib/crewai/src/crewai/events/base_events.py +++ b/lib/crewai/src/crewai/events/base_events.py @@ -1,4 +1,5 @@ from collections.abc import Iterator +import contextvars from datetime import datetime, timezone import itertools from typing import Any @@ -9,27 +10,37 @@ from pydantic import BaseModel, Field from crewai.utilities.serialization import Serializable, to_serializable -_emission_counter: Iterator[int] = itertools.count(start=1) +_emission_counter: contextvars.ContextVar[Iterator[int]] = contextvars.ContextVar( + "_emission_counter" +) + + +def _get_or_create_counter() -> Iterator[int]: + """Get the emission counter for the current context, creating if needed.""" + try: + return _emission_counter.get() + except LookupError: + counter: Iterator[int] = itertools.count(start=1) + _emission_counter.set(counter) + return counter def get_next_emission_sequence() -> int: """Get the next emission sequence number. - Thread-safe due to atomic next() on itertools.count under the GIL. - Returns: The next sequence number. """ - return next(_emission_counter) + return next(_get_or_create_counter()) def reset_emission_counter() -> None: """Reset the emission sequence counter to 1. - Useful for test isolation. + Resets for the current context only. """ - global _emission_counter - _emission_counter = itertools.count(start=1) + counter: Iterator[int] = itertools.count(start=1) + _emission_counter.set(counter) class BaseEvent(BaseModel): diff --git a/lib/crewai/src/crewai/events/event_bus.py b/lib/crewai/src/crewai/events/event_bus.py index 4287efb58..11a564cca 100644 --- a/lib/crewai/src/crewai/events/event_bus.py +++ b/lib/crewai/src/crewai/events/event_bus.py @@ -420,7 +420,7 @@ class CrewAIEventsBus: return None - def flush(self, timeout: float | None = None) -> bool: + def flush(self, timeout: float | None = 30.0) -> bool: """Block until all pending event handlers complete. This method waits for all futures from previously emitted events to @@ -429,7 +429,7 @@ class CrewAIEventsBus: Args: timeout: Maximum time in seconds to wait for handlers to complete. - If None, waits indefinitely. + Defaults to 30 seconds. Pass None to wait indefinitely. Returns: True if all handlers completed, False if timeout occurred. diff --git a/lib/crewai/src/crewai/events/event_context.py b/lib/crewai/src/crewai/events/event_context.py index b097e0141..bc691e05b 100644 --- a/lib/crewai/src/crewai/events/event_context.py +++ b/lib/crewai/src/crewai/events/event_context.py @@ -192,6 +192,7 @@ SCOPE_ENDING_EVENTS: frozenset[str] = frozenset( "mcp_tool_execution_completed", "mcp_tool_execution_failed", "memory_retrieval_completed", + "memory_retrieval_failed", "memory_save_completed", "memory_save_failed", "memory_query_completed", @@ -241,6 +242,7 @@ VALID_EVENT_PAIRS: dict[str, str] = { "mcp_tool_execution_completed": "mcp_tool_execution_started", "mcp_tool_execution_failed": "mcp_tool_execution_started", "memory_retrieval_completed": "memory_retrieval_started", + "memory_retrieval_failed": "memory_retrieval_started", "memory_save_completed": "memory_save_started", "memory_save_failed": "memory_save_started", "memory_query_completed": "memory_query_started", diff --git a/lib/crewai/src/crewai/events/event_types.py b/lib/crewai/src/crewai/events/event_types.py index 78aa11fe0..5fca4bd7d 100644 --- a/lib/crewai/src/crewai/events/event_types.py +++ b/lib/crewai/src/crewai/events/event_types.py @@ -79,6 +79,7 @@ from crewai.events.types.memory_events import ( MemoryQueryFailedEvent, MemoryQueryStartedEvent, MemoryRetrievalCompletedEvent, + MemoryRetrievalFailedEvent, MemoryRetrievalStartedEvent, MemorySaveCompletedEvent, MemorySaveFailedEvent, @@ -173,6 +174,7 @@ EventTypes = ( | MemoryQueryFailedEvent | MemoryRetrievalStartedEvent | MemoryRetrievalCompletedEvent + | MemoryRetrievalFailedEvent | MCPConnectionStartedEvent | MCPConnectionCompletedEvent | MCPConnectionFailedEvent diff --git a/lib/crewai/src/crewai/events/types/memory_events.py b/lib/crewai/src/crewai/events/types/memory_events.py index 7e954427a..0fd57a352 100644 --- a/lib/crewai/src/crewai/events/types/memory_events.py +++ b/lib/crewai/src/crewai/events/types/memory_events.py @@ -14,7 +14,7 @@ class MemoryBaseEvent(BaseEvent): agent_role: str | None = None agent_id: str | None = None - def __init__(self, **data): + def __init__(self, **data: Any) -> None: super().__init__(**data) self._set_agent_params(data) self._set_task_params(data) @@ -93,3 +93,11 @@ class MemoryRetrievalCompletedEvent(MemoryBaseEvent): task_id: str | None = None memory_content: str retrieval_time_ms: float + + +class MemoryRetrievalFailedEvent(MemoryBaseEvent): + """Event emitted when memory retrieval for a task prompt fails.""" + + type: str = "memory_retrieval_failed" + task_id: str | None = None + error: str