fix: ensure context isolation

This commit is contained in:
Greyson LaLonde
2026-01-20 15:27:16 -05:00
parent f9ee229ea1
commit f3cacdcf1d
9 changed files with 199 additions and 98 deletions

View File

@@ -14,15 +14,25 @@ from typing import (
from pydantic import BeforeValidator, HttpUrl, TypeAdapter
from typing_extensions import NotRequired
from crewai.a2a.updates import (
PollingConfig,
PollingHandler,
PushNotificationConfig,
PushNotificationHandler,
StreamingConfig,
StreamingHandler,
UpdateConfig,
)
try:
from crewai.a2a.updates import (
PollingConfig,
PollingHandler,
PushNotificationConfig,
PushNotificationHandler,
StreamingConfig,
StreamingHandler,
UpdateConfig,
)
except ImportError:
PollingConfig = Any # type: ignore[misc,assignment]
PollingHandler = Any # type: ignore[misc,assignment]
PushNotificationConfig = Any # type: ignore[misc,assignment]
PushNotificationHandler = Any # type: ignore[misc,assignment]
StreamingConfig = Any # type: ignore[misc,assignment]
StreamingHandler = Any # type: ignore[misc,assignment]
UpdateConfig = Any # type: ignore[misc,assignment]
TransportType = Literal["JSONRPC", "GRPC", "HTTP+JSON"]

View File

@@ -251,30 +251,48 @@ async def aexecute_a2a_delegation(
if turn_number is None:
turn_number = len([m for m in conversation_history if m.role == Role.user]) + 1
result = await _aexecute_a2a_delegation_impl(
endpoint=endpoint,
auth=auth,
timeout=timeout,
task_description=task_description,
context=context,
context_id=context_id,
task_id=task_id,
reference_task_ids=reference_task_ids,
metadata=metadata,
extensions=extensions,
conversation_history=conversation_history,
is_multiturn=is_multiturn,
turn_number=turn_number,
agent_branch=agent_branch,
agent_id=agent_id,
agent_role=agent_role,
response_model=response_model,
updates=updates,
transport_protocol=transport_protocol,
from_task=from_task,
from_agent=from_agent,
skill_id=skill_id,
)
try:
result = await _aexecute_a2a_delegation_impl(
endpoint=endpoint,
auth=auth,
timeout=timeout,
task_description=task_description,
context=context,
context_id=context_id,
task_id=task_id,
reference_task_ids=reference_task_ids,
metadata=metadata,
extensions=extensions,
conversation_history=conversation_history,
is_multiturn=is_multiturn,
turn_number=turn_number,
agent_branch=agent_branch,
agent_id=agent_id,
agent_role=agent_role,
response_model=response_model,
updates=updates,
transport_protocol=transport_protocol,
from_task=from_task,
from_agent=from_agent,
skill_id=skill_id,
)
except Exception as e:
crewai_event_bus.emit(
agent_branch,
A2ADelegationCompletedEvent(
status="failed",
result=None,
error=str(e),
context_id=context_id,
is_multiturn=is_multiturn,
endpoint=endpoint,
metadata=metadata,
extensions=list(extensions.keys()) if extensions else None,
from_task=from_task,
from_agent=from_agent,
),
)
raise
agent_card_data: dict[str, Any] = result.get("agent_card") or {}
crewai_event_bus.emit(

View File

@@ -14,7 +14,14 @@ from typing import (
)
from urllib.parse import urlparse
from pydantic import BaseModel, Field, InstanceOf, PrivateAttr, model_validator
from pydantic import (
BaseModel,
ConfigDict,
Field,
InstanceOf,
PrivateAttr,
model_validator,
)
from typing_extensions import Self
from crewai.agent.utils import (
@@ -41,6 +48,7 @@ from crewai.events.types.knowledge_events import (
)
from crewai.events.types.memory_events import (
MemoryRetrievalCompletedEvent,
MemoryRetrievalFailedEvent,
MemoryRetrievalStartedEvent,
)
from crewai.experimental.crew_agent_executor_flow import CrewAgentExecutorFlow
@@ -77,17 +85,10 @@ from crewai.utilities.token_counter_callback import TokenCalcHandler
from crewai.utilities.training_handler import CrewTrainingHandler
try:
from crewai.a2a.config import A2AClientConfig, A2AConfig, A2AServerConfig
except ImportError:
A2AClientConfig = Any
A2AConfig = Any
A2AServerConfig = Any
if TYPE_CHECKING:
from crewai_tools import CodeInterpreterTool
from crewai.a2a.config import A2AClientConfig, A2AConfig, A2AServerConfig
from crewai.agents.agent_builder.base_agent import PlatformAppOrAction
from crewai.lite_agent_output import LiteAgentOutput
from crewai.task import Task
@@ -133,6 +134,8 @@ class Agent(BaseAgent):
mcps: List of MCP server references for tool integration.
"""
model_config = ConfigDict()
_times_executed: int = PrivateAttr(default=0)
_mcp_clients: list[Any] = PrivateAttr(default_factory=list)
_last_messages: list[LLMMessage] = PrivateAttr(default_factory=list)
@@ -346,30 +349,43 @@ class Agent(BaseAgent):
)
start_time = time.time()
memory = ""
contextual_memory = ContextualMemory(
self.crew._short_term_memory,
self.crew._long_term_memory,
self.crew._entity_memory,
self.crew._external_memory,
agent=self,
task=task,
)
memory = contextual_memory.build_context_for_task(task, context or "")
if memory.strip() != "":
task_prompt += self.i18n.slice("memory").format(memory=memory)
try:
contextual_memory = ContextualMemory(
self.crew._short_term_memory,
self.crew._long_term_memory,
self.crew._entity_memory,
self.crew._external_memory,
agent=self,
task=task,
)
memory = contextual_memory.build_context_for_task(task, context or "")
if memory.strip() != "":
task_prompt += self.i18n.slice("memory").format(memory=memory)
crewai_event_bus.emit(
self,
event=MemoryRetrievalCompletedEvent(
task_id=str(task.id) if task else None,
memory_content=memory,
retrieval_time_ms=(time.time() - start_time) * 1000,
source_type="agent",
from_agent=self,
from_task=task,
),
)
crewai_event_bus.emit(
self,
event=MemoryRetrievalCompletedEvent(
task_id=str(task.id) if task else None,
memory_content=memory,
retrieval_time_ms=(time.time() - start_time) * 1000,
source_type="agent",
from_agent=self,
from_task=task,
),
)
except Exception as e:
crewai_event_bus.emit(
self,
event=MemoryRetrievalFailedEvent(
task_id=str(task.id) if task else None,
source_type="agent",
from_agent=self,
from_task=task,
error=str(e),
),
)
knowledge_config = get_knowledge_config(self)
task_prompt = handle_knowledge_retrieval(
@@ -555,32 +571,45 @@ class Agent(BaseAgent):
)
start_time = time.time()
memory = ""
contextual_memory = ContextualMemory(
self.crew._short_term_memory,
self.crew._long_term_memory,
self.crew._entity_memory,
self.crew._external_memory,
agent=self,
task=task,
)
memory = await contextual_memory.abuild_context_for_task(
task, context or ""
)
if memory.strip() != "":
task_prompt += self.i18n.slice("memory").format(memory=memory)
try:
contextual_memory = ContextualMemory(
self.crew._short_term_memory,
self.crew._long_term_memory,
self.crew._entity_memory,
self.crew._external_memory,
agent=self,
task=task,
)
memory = await contextual_memory.abuild_context_for_task(
task, context or ""
)
if memory.strip() != "":
task_prompt += self.i18n.slice("memory").format(memory=memory)
crewai_event_bus.emit(
self,
event=MemoryRetrievalCompletedEvent(
task_id=str(task.id) if task else None,
memory_content=memory,
retrieval_time_ms=(time.time() - start_time) * 1000,
source_type="agent",
from_agent=self,
from_task=task,
),
)
crewai_event_bus.emit(
self,
event=MemoryRetrievalCompletedEvent(
task_id=str(task.id) if task else None,
memory_content=memory,
retrieval_time_ms=(time.time() - start_time) * 1000,
source_type="agent",
from_agent=self,
from_task=task,
),
)
except Exception as e:
crewai_event_bus.emit(
self,
event=MemoryRetrievalFailedEvent(
task_id=str(task.id) if task else None,
source_type="agent",
from_agent=self,
from_task=task,
error=str(e),
),
)
knowledge_config = get_knowledge_config(self)
task_prompt = await ahandle_knowledge_retrieval(
@@ -1669,3 +1698,22 @@ class Agent(BaseAgent):
)
return await lite_agent.kickoff_async(messages)
# Rebuild Agent model to resolve A2A type forward references
try:
from crewai.a2a.config import (
A2AClientConfig as _A2AClientConfig,
A2AConfig as _A2AConfig,
A2AServerConfig as _A2AServerConfig,
)
Agent.model_rebuild(
_types_namespace={
"A2AConfig": _A2AConfig,
"A2AClientConfig": _A2AClientConfig,
"A2AServerConfig": _A2AServerConfig,
}
)
except ImportError:
pass

View File

@@ -75,6 +75,7 @@ from crewai.events.types.memory_events import (
MemoryQueryFailedEvent,
MemoryQueryStartedEvent,
MemoryRetrievalCompletedEvent,
MemoryRetrievalFailedEvent,
MemoryRetrievalStartedEvent,
MemorySaveCompletedEvent,
MemorySaveFailedEvent,
@@ -174,6 +175,7 @@ __all__ = [
"MemoryQueryFailedEvent",
"MemoryQueryStartedEvent",
"MemoryRetrievalCompletedEvent",
"MemoryRetrievalFailedEvent",
"MemoryRetrievalStartedEvent",
"MemorySaveCompletedEvent",
"MemorySaveFailedEvent",

View File

@@ -1,4 +1,5 @@
from collections.abc import Iterator
import contextvars
from datetime import datetime, timezone
import itertools
from typing import Any
@@ -9,27 +10,37 @@ from pydantic import BaseModel, Field
from crewai.utilities.serialization import Serializable, to_serializable
_emission_counter: Iterator[int] = itertools.count(start=1)
_emission_counter: contextvars.ContextVar[Iterator[int]] = contextvars.ContextVar(
"_emission_counter"
)
def _get_or_create_counter() -> Iterator[int]:
"""Get the emission counter for the current context, creating if needed."""
try:
return _emission_counter.get()
except LookupError:
counter: Iterator[int] = itertools.count(start=1)
_emission_counter.set(counter)
return counter
def get_next_emission_sequence() -> int:
"""Get the next emission sequence number.
Thread-safe due to atomic next() on itertools.count under the GIL.
Returns:
The next sequence number.
"""
return next(_emission_counter)
return next(_get_or_create_counter())
def reset_emission_counter() -> None:
"""Reset the emission sequence counter to 1.
Useful for test isolation.
Resets for the current context only.
"""
global _emission_counter
_emission_counter = itertools.count(start=1)
counter: Iterator[int] = itertools.count(start=1)
_emission_counter.set(counter)
class BaseEvent(BaseModel):

View File

@@ -420,7 +420,7 @@ class CrewAIEventsBus:
return None
def flush(self, timeout: float | None = None) -> bool:
def flush(self, timeout: float | None = 30.0) -> bool:
"""Block until all pending event handlers complete.
This method waits for all futures from previously emitted events to
@@ -429,7 +429,7 @@ class CrewAIEventsBus:
Args:
timeout: Maximum time in seconds to wait for handlers to complete.
If None, waits indefinitely.
Defaults to 30 seconds. Pass None to wait indefinitely.
Returns:
True if all handlers completed, False if timeout occurred.

View File

@@ -192,6 +192,7 @@ SCOPE_ENDING_EVENTS: frozenset[str] = frozenset(
"mcp_tool_execution_completed",
"mcp_tool_execution_failed",
"memory_retrieval_completed",
"memory_retrieval_failed",
"memory_save_completed",
"memory_save_failed",
"memory_query_completed",
@@ -241,6 +242,7 @@ VALID_EVENT_PAIRS: dict[str, str] = {
"mcp_tool_execution_completed": "mcp_tool_execution_started",
"mcp_tool_execution_failed": "mcp_tool_execution_started",
"memory_retrieval_completed": "memory_retrieval_started",
"memory_retrieval_failed": "memory_retrieval_started",
"memory_save_completed": "memory_save_started",
"memory_save_failed": "memory_save_started",
"memory_query_completed": "memory_query_started",

View File

@@ -79,6 +79,7 @@ from crewai.events.types.memory_events import (
MemoryQueryFailedEvent,
MemoryQueryStartedEvent,
MemoryRetrievalCompletedEvent,
MemoryRetrievalFailedEvent,
MemoryRetrievalStartedEvent,
MemorySaveCompletedEvent,
MemorySaveFailedEvent,
@@ -173,6 +174,7 @@ EventTypes = (
| MemoryQueryFailedEvent
| MemoryRetrievalStartedEvent
| MemoryRetrievalCompletedEvent
| MemoryRetrievalFailedEvent
| MCPConnectionStartedEvent
| MCPConnectionCompletedEvent
| MCPConnectionFailedEvent

View File

@@ -14,7 +14,7 @@ class MemoryBaseEvent(BaseEvent):
agent_role: str | None = None
agent_id: str | None = None
def __init__(self, **data):
def __init__(self, **data: Any) -> None:
super().__init__(**data)
self._set_agent_params(data)
self._set_task_params(data)
@@ -93,3 +93,11 @@ class MemoryRetrievalCompletedEvent(MemoryBaseEvent):
task_id: str | None = None
memory_content: str
retrieval_time_ms: float
class MemoryRetrievalFailedEvent(MemoryBaseEvent):
"""Event emitted when memory retrieval for a task prompt fails."""
type: str = "memory_retrieval_failed"
task_id: str | None = None
error: str