fix: ensure context isolation

This commit is contained in:
Greyson LaLonde
2026-01-20 15:27:16 -05:00
parent f9ee229ea1
commit f3cacdcf1d
9 changed files with 199 additions and 98 deletions

View File

@@ -14,15 +14,25 @@ from typing import (
from pydantic import BeforeValidator, HttpUrl, TypeAdapter from pydantic import BeforeValidator, HttpUrl, TypeAdapter
from typing_extensions import NotRequired from typing_extensions import NotRequired
from crewai.a2a.updates import (
PollingConfig, try:
PollingHandler, from crewai.a2a.updates import (
PushNotificationConfig, PollingConfig,
PushNotificationHandler, PollingHandler,
StreamingConfig, PushNotificationConfig,
StreamingHandler, PushNotificationHandler,
UpdateConfig, StreamingConfig,
) StreamingHandler,
UpdateConfig,
)
except ImportError:
PollingConfig = Any # type: ignore[misc,assignment]
PollingHandler = Any # type: ignore[misc,assignment]
PushNotificationConfig = Any # type: ignore[misc,assignment]
PushNotificationHandler = Any # type: ignore[misc,assignment]
StreamingConfig = Any # type: ignore[misc,assignment]
StreamingHandler = Any # type: ignore[misc,assignment]
UpdateConfig = Any # type: ignore[misc,assignment]
TransportType = Literal["JSONRPC", "GRPC", "HTTP+JSON"] TransportType = Literal["JSONRPC", "GRPC", "HTTP+JSON"]

View File

@@ -251,30 +251,48 @@ async def aexecute_a2a_delegation(
if turn_number is None: if turn_number is None:
turn_number = len([m for m in conversation_history if m.role == Role.user]) + 1 turn_number = len([m for m in conversation_history if m.role == Role.user]) + 1
result = await _aexecute_a2a_delegation_impl( try:
endpoint=endpoint, result = await _aexecute_a2a_delegation_impl(
auth=auth, endpoint=endpoint,
timeout=timeout, auth=auth,
task_description=task_description, timeout=timeout,
context=context, task_description=task_description,
context_id=context_id, context=context,
task_id=task_id, context_id=context_id,
reference_task_ids=reference_task_ids, task_id=task_id,
metadata=metadata, reference_task_ids=reference_task_ids,
extensions=extensions, metadata=metadata,
conversation_history=conversation_history, extensions=extensions,
is_multiturn=is_multiturn, conversation_history=conversation_history,
turn_number=turn_number, is_multiturn=is_multiturn,
agent_branch=agent_branch, turn_number=turn_number,
agent_id=agent_id, agent_branch=agent_branch,
agent_role=agent_role, agent_id=agent_id,
response_model=response_model, agent_role=agent_role,
updates=updates, response_model=response_model,
transport_protocol=transport_protocol, updates=updates,
from_task=from_task, transport_protocol=transport_protocol,
from_agent=from_agent, from_task=from_task,
skill_id=skill_id, from_agent=from_agent,
) skill_id=skill_id,
)
except Exception as e:
crewai_event_bus.emit(
agent_branch,
A2ADelegationCompletedEvent(
status="failed",
result=None,
error=str(e),
context_id=context_id,
is_multiturn=is_multiturn,
endpoint=endpoint,
metadata=metadata,
extensions=list(extensions.keys()) if extensions else None,
from_task=from_task,
from_agent=from_agent,
),
)
raise
agent_card_data: dict[str, Any] = result.get("agent_card") or {} agent_card_data: dict[str, Any] = result.get("agent_card") or {}
crewai_event_bus.emit( crewai_event_bus.emit(

View File

@@ -14,7 +14,14 @@ from typing import (
) )
from urllib.parse import urlparse from urllib.parse import urlparse
from pydantic import BaseModel, Field, InstanceOf, PrivateAttr, model_validator from pydantic import (
BaseModel,
ConfigDict,
Field,
InstanceOf,
PrivateAttr,
model_validator,
)
from typing_extensions import Self from typing_extensions import Self
from crewai.agent.utils import ( from crewai.agent.utils import (
@@ -41,6 +48,7 @@ from crewai.events.types.knowledge_events import (
) )
from crewai.events.types.memory_events import ( from crewai.events.types.memory_events import (
MemoryRetrievalCompletedEvent, MemoryRetrievalCompletedEvent,
MemoryRetrievalFailedEvent,
MemoryRetrievalStartedEvent, MemoryRetrievalStartedEvent,
) )
from crewai.experimental.crew_agent_executor_flow import CrewAgentExecutorFlow from crewai.experimental.crew_agent_executor_flow import CrewAgentExecutorFlow
@@ -77,17 +85,10 @@ from crewai.utilities.token_counter_callback import TokenCalcHandler
from crewai.utilities.training_handler import CrewTrainingHandler from crewai.utilities.training_handler import CrewTrainingHandler
try:
from crewai.a2a.config import A2AClientConfig, A2AConfig, A2AServerConfig
except ImportError:
A2AClientConfig = Any
A2AConfig = Any
A2AServerConfig = Any
if TYPE_CHECKING: if TYPE_CHECKING:
from crewai_tools import CodeInterpreterTool from crewai_tools import CodeInterpreterTool
from crewai.a2a.config import A2AClientConfig, A2AConfig, A2AServerConfig
from crewai.agents.agent_builder.base_agent import PlatformAppOrAction from crewai.agents.agent_builder.base_agent import PlatformAppOrAction
from crewai.lite_agent_output import LiteAgentOutput from crewai.lite_agent_output import LiteAgentOutput
from crewai.task import Task from crewai.task import Task
@@ -133,6 +134,8 @@ class Agent(BaseAgent):
mcps: List of MCP server references for tool integration. mcps: List of MCP server references for tool integration.
""" """
model_config = ConfigDict()
_times_executed: int = PrivateAttr(default=0) _times_executed: int = PrivateAttr(default=0)
_mcp_clients: list[Any] = PrivateAttr(default_factory=list) _mcp_clients: list[Any] = PrivateAttr(default_factory=list)
_last_messages: list[LLMMessage] = PrivateAttr(default_factory=list) _last_messages: list[LLMMessage] = PrivateAttr(default_factory=list)
@@ -346,30 +349,43 @@ class Agent(BaseAgent):
) )
start_time = time.time() start_time = time.time()
memory = ""
contextual_memory = ContextualMemory( try:
self.crew._short_term_memory, contextual_memory = ContextualMemory(
self.crew._long_term_memory, self.crew._short_term_memory,
self.crew._entity_memory, self.crew._long_term_memory,
self.crew._external_memory, self.crew._entity_memory,
agent=self, self.crew._external_memory,
task=task, agent=self,
) task=task,
memory = contextual_memory.build_context_for_task(task, context or "") )
if memory.strip() != "": memory = contextual_memory.build_context_for_task(task, context or "")
task_prompt += self.i18n.slice("memory").format(memory=memory) if memory.strip() != "":
task_prompt += self.i18n.slice("memory").format(memory=memory)
crewai_event_bus.emit( crewai_event_bus.emit(
self, self,
event=MemoryRetrievalCompletedEvent( event=MemoryRetrievalCompletedEvent(
task_id=str(task.id) if task else None, task_id=str(task.id) if task else None,
memory_content=memory, memory_content=memory,
retrieval_time_ms=(time.time() - start_time) * 1000, retrieval_time_ms=(time.time() - start_time) * 1000,
source_type="agent", source_type="agent",
from_agent=self, from_agent=self,
from_task=task, from_task=task,
), ),
) )
except Exception as e:
crewai_event_bus.emit(
self,
event=MemoryRetrievalFailedEvent(
task_id=str(task.id) if task else None,
source_type="agent",
from_agent=self,
from_task=task,
error=str(e),
),
)
knowledge_config = get_knowledge_config(self) knowledge_config = get_knowledge_config(self)
task_prompt = handle_knowledge_retrieval( task_prompt = handle_knowledge_retrieval(
@@ -555,32 +571,45 @@ class Agent(BaseAgent):
) )
start_time = time.time() start_time = time.time()
memory = ""
contextual_memory = ContextualMemory( try:
self.crew._short_term_memory, contextual_memory = ContextualMemory(
self.crew._long_term_memory, self.crew._short_term_memory,
self.crew._entity_memory, self.crew._long_term_memory,
self.crew._external_memory, self.crew._entity_memory,
agent=self, self.crew._external_memory,
task=task, agent=self,
) task=task,
memory = await contextual_memory.abuild_context_for_task( )
task, context or "" memory = await contextual_memory.abuild_context_for_task(
) task, context or ""
if memory.strip() != "": )
task_prompt += self.i18n.slice("memory").format(memory=memory) if memory.strip() != "":
task_prompt += self.i18n.slice("memory").format(memory=memory)
crewai_event_bus.emit( crewai_event_bus.emit(
self, self,
event=MemoryRetrievalCompletedEvent( event=MemoryRetrievalCompletedEvent(
task_id=str(task.id) if task else None, task_id=str(task.id) if task else None,
memory_content=memory, memory_content=memory,
retrieval_time_ms=(time.time() - start_time) * 1000, retrieval_time_ms=(time.time() - start_time) * 1000,
source_type="agent", source_type="agent",
from_agent=self, from_agent=self,
from_task=task, from_task=task,
), ),
) )
except Exception as e:
crewai_event_bus.emit(
self,
event=MemoryRetrievalFailedEvent(
task_id=str(task.id) if task else None,
source_type="agent",
from_agent=self,
from_task=task,
error=str(e),
),
)
knowledge_config = get_knowledge_config(self) knowledge_config = get_knowledge_config(self)
task_prompt = await ahandle_knowledge_retrieval( task_prompt = await ahandle_knowledge_retrieval(
@@ -1669,3 +1698,22 @@ class Agent(BaseAgent):
) )
return await lite_agent.kickoff_async(messages) return await lite_agent.kickoff_async(messages)
# Rebuild Agent model to resolve A2A type forward references
try:
from crewai.a2a.config import (
A2AClientConfig as _A2AClientConfig,
A2AConfig as _A2AConfig,
A2AServerConfig as _A2AServerConfig,
)
Agent.model_rebuild(
_types_namespace={
"A2AConfig": _A2AConfig,
"A2AClientConfig": _A2AClientConfig,
"A2AServerConfig": _A2AServerConfig,
}
)
except ImportError:
pass

View File

@@ -75,6 +75,7 @@ from crewai.events.types.memory_events import (
MemoryQueryFailedEvent, MemoryQueryFailedEvent,
MemoryQueryStartedEvent, MemoryQueryStartedEvent,
MemoryRetrievalCompletedEvent, MemoryRetrievalCompletedEvent,
MemoryRetrievalFailedEvent,
MemoryRetrievalStartedEvent, MemoryRetrievalStartedEvent,
MemorySaveCompletedEvent, MemorySaveCompletedEvent,
MemorySaveFailedEvent, MemorySaveFailedEvent,
@@ -174,6 +175,7 @@ __all__ = [
"MemoryQueryFailedEvent", "MemoryQueryFailedEvent",
"MemoryQueryStartedEvent", "MemoryQueryStartedEvent",
"MemoryRetrievalCompletedEvent", "MemoryRetrievalCompletedEvent",
"MemoryRetrievalFailedEvent",
"MemoryRetrievalStartedEvent", "MemoryRetrievalStartedEvent",
"MemorySaveCompletedEvent", "MemorySaveCompletedEvent",
"MemorySaveFailedEvent", "MemorySaveFailedEvent",

View File

@@ -1,4 +1,5 @@
from collections.abc import Iterator from collections.abc import Iterator
import contextvars
from datetime import datetime, timezone from datetime import datetime, timezone
import itertools import itertools
from typing import Any from typing import Any
@@ -9,27 +10,37 @@ from pydantic import BaseModel, Field
from crewai.utilities.serialization import Serializable, to_serializable from crewai.utilities.serialization import Serializable, to_serializable
_emission_counter: Iterator[int] = itertools.count(start=1) _emission_counter: contextvars.ContextVar[Iterator[int]] = contextvars.ContextVar(
"_emission_counter"
)
def _get_or_create_counter() -> Iterator[int]:
"""Get the emission counter for the current context, creating if needed."""
try:
return _emission_counter.get()
except LookupError:
counter: Iterator[int] = itertools.count(start=1)
_emission_counter.set(counter)
return counter
def get_next_emission_sequence() -> int: def get_next_emission_sequence() -> int:
"""Get the next emission sequence number. """Get the next emission sequence number.
Thread-safe due to atomic next() on itertools.count under the GIL.
Returns: Returns:
The next sequence number. The next sequence number.
""" """
return next(_emission_counter) return next(_get_or_create_counter())
def reset_emission_counter() -> None: def reset_emission_counter() -> None:
"""Reset the emission sequence counter to 1. """Reset the emission sequence counter to 1.
Useful for test isolation. Resets for the current context only.
""" """
global _emission_counter counter: Iterator[int] = itertools.count(start=1)
_emission_counter = itertools.count(start=1) _emission_counter.set(counter)
class BaseEvent(BaseModel): class BaseEvent(BaseModel):

View File

@@ -420,7 +420,7 @@ class CrewAIEventsBus:
return None return None
def flush(self, timeout: float | None = None) -> bool: def flush(self, timeout: float | None = 30.0) -> bool:
"""Block until all pending event handlers complete. """Block until all pending event handlers complete.
This method waits for all futures from previously emitted events to This method waits for all futures from previously emitted events to
@@ -429,7 +429,7 @@ class CrewAIEventsBus:
Args: Args:
timeout: Maximum time in seconds to wait for handlers to complete. timeout: Maximum time in seconds to wait for handlers to complete.
If None, waits indefinitely. Defaults to 30 seconds. Pass None to wait indefinitely.
Returns: Returns:
True if all handlers completed, False if timeout occurred. True if all handlers completed, False if timeout occurred.

View File

@@ -192,6 +192,7 @@ SCOPE_ENDING_EVENTS: frozenset[str] = frozenset(
"mcp_tool_execution_completed", "mcp_tool_execution_completed",
"mcp_tool_execution_failed", "mcp_tool_execution_failed",
"memory_retrieval_completed", "memory_retrieval_completed",
"memory_retrieval_failed",
"memory_save_completed", "memory_save_completed",
"memory_save_failed", "memory_save_failed",
"memory_query_completed", "memory_query_completed",
@@ -241,6 +242,7 @@ VALID_EVENT_PAIRS: dict[str, str] = {
"mcp_tool_execution_completed": "mcp_tool_execution_started", "mcp_tool_execution_completed": "mcp_tool_execution_started",
"mcp_tool_execution_failed": "mcp_tool_execution_started", "mcp_tool_execution_failed": "mcp_tool_execution_started",
"memory_retrieval_completed": "memory_retrieval_started", "memory_retrieval_completed": "memory_retrieval_started",
"memory_retrieval_failed": "memory_retrieval_started",
"memory_save_completed": "memory_save_started", "memory_save_completed": "memory_save_started",
"memory_save_failed": "memory_save_started", "memory_save_failed": "memory_save_started",
"memory_query_completed": "memory_query_started", "memory_query_completed": "memory_query_started",

View File

@@ -79,6 +79,7 @@ from crewai.events.types.memory_events import (
MemoryQueryFailedEvent, MemoryQueryFailedEvent,
MemoryQueryStartedEvent, MemoryQueryStartedEvent,
MemoryRetrievalCompletedEvent, MemoryRetrievalCompletedEvent,
MemoryRetrievalFailedEvent,
MemoryRetrievalStartedEvent, MemoryRetrievalStartedEvent,
MemorySaveCompletedEvent, MemorySaveCompletedEvent,
MemorySaveFailedEvent, MemorySaveFailedEvent,
@@ -173,6 +174,7 @@ EventTypes = (
| MemoryQueryFailedEvent | MemoryQueryFailedEvent
| MemoryRetrievalStartedEvent | MemoryRetrievalStartedEvent
| MemoryRetrievalCompletedEvent | MemoryRetrievalCompletedEvent
| MemoryRetrievalFailedEvent
| MCPConnectionStartedEvent | MCPConnectionStartedEvent
| MCPConnectionCompletedEvent | MCPConnectionCompletedEvent
| MCPConnectionFailedEvent | MCPConnectionFailedEvent

View File

@@ -14,7 +14,7 @@ class MemoryBaseEvent(BaseEvent):
agent_role: str | None = None agent_role: str | None = None
agent_id: str | None = None agent_id: str | None = None
def __init__(self, **data): def __init__(self, **data: Any) -> None:
super().__init__(**data) super().__init__(**data)
self._set_agent_params(data) self._set_agent_params(data)
self._set_task_params(data) self._set_task_params(data)
@@ -93,3 +93,11 @@ class MemoryRetrievalCompletedEvent(MemoryBaseEvent):
task_id: str | None = None task_id: str | None = None
memory_content: str memory_content: str
retrieval_time_ms: float retrieval_time_ms: float
class MemoryRetrievalFailedEvent(MemoryBaseEvent):
"""Event emitted when memory retrieval for a task prompt fails."""
type: str = "memory_retrieval_failed"
task_id: str | None = None
error: str