mirror of https://github.com/crewAIInc/crewAI.git (synced 2026-01-27 17:18:13 +00:00)

fix: ensure context isolation
@@ -14,15 +14,25 @@ from typing import (
 from pydantic import BeforeValidator, HttpUrl, TypeAdapter
 from typing_extensions import NotRequired
 
-from crewai.a2a.updates import (
-    PollingConfig,
-    PollingHandler,
-    PushNotificationConfig,
-    PushNotificationHandler,
-    StreamingConfig,
-    StreamingHandler,
-    UpdateConfig,
-)
+
+try:
+    from crewai.a2a.updates import (
+        PollingConfig,
+        PollingHandler,
+        PushNotificationConfig,
+        PushNotificationHandler,
+        StreamingConfig,
+        StreamingHandler,
+        UpdateConfig,
+    )
+except ImportError:
+    PollingConfig = Any  # type: ignore[misc,assignment]
+    PollingHandler = Any  # type: ignore[misc,assignment]
+    PushNotificationConfig = Any  # type: ignore[misc,assignment]
+    PushNotificationHandler = Any  # type: ignore[misc,assignment]
+    StreamingConfig = Any  # type: ignore[misc,assignment]
+    StreamingHandler = Any  # type: ignore[misc,assignment]
+    UpdateConfig = Any  # type: ignore[misc,assignment]
 
 
 TransportType = Literal["JSONRPC", "GRPC", "HTTP+JSON"]
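
The hunk above makes the a2a.updates imports optional: when the A2A extra is not installed, the names degrade to Any so the module still imports. A minimal sketch of that pattern, using a placeholder package name rather than a real crewAI module:

from typing import Any

try:
    from optional_pkg import FancyConfig  # placeholder for an optional dependency
except ImportError:
    # Degrade to a runtime placeholder; type checkers still see the real class
    # when the extra is installed.
    FancyConfig = Any  # type: ignore[misc,assignment]


def configure(cfg: "FancyConfig | None" = None) -> None:
    """Annotations keep working whether or not the optional package exists."""
    ...
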
@@ -251,30 +251,48 @@ async def aexecute_a2a_delegation(
     if turn_number is None:
         turn_number = len([m for m in conversation_history if m.role == Role.user]) + 1
 
-    result = await _aexecute_a2a_delegation_impl(
-        endpoint=endpoint,
-        auth=auth,
-        timeout=timeout,
-        task_description=task_description,
-        context=context,
-        context_id=context_id,
-        task_id=task_id,
-        reference_task_ids=reference_task_ids,
-        metadata=metadata,
-        extensions=extensions,
-        conversation_history=conversation_history,
-        is_multiturn=is_multiturn,
-        turn_number=turn_number,
-        agent_branch=agent_branch,
-        agent_id=agent_id,
-        agent_role=agent_role,
-        response_model=response_model,
-        updates=updates,
-        transport_protocol=transport_protocol,
-        from_task=from_task,
-        from_agent=from_agent,
-        skill_id=skill_id,
-    )
+    try:
+        result = await _aexecute_a2a_delegation_impl(
+            endpoint=endpoint,
+            auth=auth,
+            timeout=timeout,
+            task_description=task_description,
+            context=context,
+            context_id=context_id,
+            task_id=task_id,
+            reference_task_ids=reference_task_ids,
+            metadata=metadata,
+            extensions=extensions,
+            conversation_history=conversation_history,
+            is_multiturn=is_multiturn,
+            turn_number=turn_number,
+            agent_branch=agent_branch,
+            agent_id=agent_id,
+            agent_role=agent_role,
+            response_model=response_model,
+            updates=updates,
+            transport_protocol=transport_protocol,
+            from_task=from_task,
+            from_agent=from_agent,
+            skill_id=skill_id,
+        )
+    except Exception as e:
+        crewai_event_bus.emit(
+            agent_branch,
+            A2ADelegationCompletedEvent(
+                status="failed",
+                result=None,
+                error=str(e),
+                context_id=context_id,
+                is_multiturn=is_multiturn,
+                endpoint=endpoint,
+                metadata=metadata,
+                extensions=list(extensions.keys()) if extensions else None,
+                from_task=from_task,
+                from_agent=from_agent,
+            ),
+        )
+        raise
 
     agent_card_data: dict[str, Any] = result.get("agent_card") or {}
     crewai_event_bus.emit(
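
The delegation call is now wrapped so that any exception first produces a "failed" A2ADelegationCompletedEvent and is then re-raised, keeping the original error behavior while making the failure observable on the event bus. A self-contained sketch of that shape (placeholder bus and event classes, not the crewAI ones):

import asyncio
from dataclasses import dataclass
from typing import Any, Awaitable, Callable


@dataclass
class DelegationCompleted:  # placeholder event
    status: str
    error: str | None = None


class Bus:  # placeholder bus with an emit(source, event) shape
    def emit(self, source: Any, event: Any) -> None:
        print("emitted:", event)


async def delegate(bus: Bus, run: Callable[[], Awaitable[str]]) -> str:
    try:
        return await run()
    except Exception as exc:
        # Report the failure, then propagate the original exception unchanged.
        bus.emit(None, DelegationCompleted(status="failed", error=str(exc)))
        raise


async def main() -> None:
    async def boom() -> str:
        raise RuntimeError("remote agent unreachable")

    try:
        await delegate(Bus(), boom)
    except RuntimeError:
        pass  # the caller still sees the error


asyncio.run(main())
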
@@ -14,7 +14,14 @@ from typing import (
 )
 from urllib.parse import urlparse
 
-from pydantic import BaseModel, Field, InstanceOf, PrivateAttr, model_validator
+from pydantic import (
+    BaseModel,
+    ConfigDict,
+    Field,
+    InstanceOf,
+    PrivateAttr,
+    model_validator,
+)
 from typing_extensions import Self
 
 from crewai.agent.utils import (
@@ -41,6 +48,7 @@ from crewai.events.types.knowledge_events import (
 )
 from crewai.events.types.memory_events import (
     MemoryRetrievalCompletedEvent,
+    MemoryRetrievalFailedEvent,
     MemoryRetrievalStartedEvent,
 )
 from crewai.experimental.crew_agent_executor_flow import CrewAgentExecutorFlow
@@ -77,17 +85,10 @@ from crewai.utilities.token_counter_callback import TokenCalcHandler
 from crewai.utilities.training_handler import CrewTrainingHandler
 
 
-try:
-    from crewai.a2a.config import A2AClientConfig, A2AConfig, A2AServerConfig
-except ImportError:
-    A2AClientConfig = Any
-    A2AConfig = Any
-    A2AServerConfig = Any
-
-
 if TYPE_CHECKING:
     from crewai_tools import CodeInterpreterTool
 
+    from crewai.a2a.config import A2AClientConfig, A2AConfig, A2AServerConfig
     from crewai.agents.agent_builder.base_agent import PlatformAppOrAction
     from crewai.lite_agent_output import LiteAgentOutput
     from crewai.task import Task
@@ -133,6 +134,8 @@ class Agent(BaseAgent):
         mcps: List of MCP server references for tool integration.
     """
 
+    model_config = ConfigDict()
+
     _times_executed: int = PrivateAttr(default=0)
     _mcp_clients: list[Any] = PrivateAttr(default_factory=list)
     _last_messages: list[LLMMessage] = PrivateAttr(default_factory=list)
@@ -346,30 +349,43 @@ class Agent(BaseAgent):
             )
 
             start_time = time.time()
+            memory = ""
 
-            contextual_memory = ContextualMemory(
-                self.crew._short_term_memory,
-                self.crew._long_term_memory,
-                self.crew._entity_memory,
-                self.crew._external_memory,
-                agent=self,
-                task=task,
-            )
-            memory = contextual_memory.build_context_for_task(task, context or "")
-            if memory.strip() != "":
-                task_prompt += self.i18n.slice("memory").format(memory=memory)
-
-            crewai_event_bus.emit(
-                self,
-                event=MemoryRetrievalCompletedEvent(
-                    task_id=str(task.id) if task else None,
-                    memory_content=memory,
-                    retrieval_time_ms=(time.time() - start_time) * 1000,
-                    source_type="agent",
-                    from_agent=self,
-                    from_task=task,
-                ),
-            )
+            try:
+                contextual_memory = ContextualMemory(
+                    self.crew._short_term_memory,
+                    self.crew._long_term_memory,
+                    self.crew._entity_memory,
+                    self.crew._external_memory,
+                    agent=self,
+                    task=task,
+                )
+                memory = contextual_memory.build_context_for_task(task, context or "")
+                if memory.strip() != "":
+                    task_prompt += self.i18n.slice("memory").format(memory=memory)
+
+                crewai_event_bus.emit(
+                    self,
+                    event=MemoryRetrievalCompletedEvent(
+                        task_id=str(task.id) if task else None,
+                        memory_content=memory,
+                        retrieval_time_ms=(time.time() - start_time) * 1000,
+                        source_type="agent",
+                        from_agent=self,
+                        from_task=task,
+                    ),
+                )
+            except Exception as e:
+                crewai_event_bus.emit(
+                    self,
+                    event=MemoryRetrievalFailedEvent(
+                        task_id=str(task.id) if task else None,
+                        source_type="agent",
+                        from_agent=self,
+                        from_task=task,
+                        error=str(e),
+                    ),
+                )
 
         knowledge_config = get_knowledge_config(self)
         task_prompt = handle_knowledge_retrieval(
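
With memory initialized to an empty string and the retrieval wrapped in try/except, a failing memory backend no longer aborts task execution: the agent proceeds without memory context and a MemoryRetrievalFailedEvent records why. A sketch of listening for the new event; the import paths and the on() decorator are assumed from the crewAI event-bus API and may differ across versions:

from crewai.events import crewai_event_bus
from crewai.events.types.memory_events import MemoryRetrievalFailedEvent


@crewai_event_bus.on(MemoryRetrievalFailedEvent)
def log_memory_failure(source, event: MemoryRetrievalFailedEvent) -> None:
    # The agent continues with an empty memory slice; this handler only records why.
    print(f"memory retrieval failed (task={event.task_id}): {event.error}")
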
@@ -555,32 +571,45 @@ class Agent(BaseAgent):
             )
 
             start_time = time.time()
+            memory = ""
 
-            contextual_memory = ContextualMemory(
-                self.crew._short_term_memory,
-                self.crew._long_term_memory,
-                self.crew._entity_memory,
-                self.crew._external_memory,
-                agent=self,
-                task=task,
-            )
-            memory = await contextual_memory.abuild_context_for_task(
-                task, context or ""
-            )
-            if memory.strip() != "":
-                task_prompt += self.i18n.slice("memory").format(memory=memory)
-
-            crewai_event_bus.emit(
-                self,
-                event=MemoryRetrievalCompletedEvent(
-                    task_id=str(task.id) if task else None,
-                    memory_content=memory,
-                    retrieval_time_ms=(time.time() - start_time) * 1000,
-                    source_type="agent",
-                    from_agent=self,
-                    from_task=task,
-                ),
-            )
+            try:
+                contextual_memory = ContextualMemory(
+                    self.crew._short_term_memory,
+                    self.crew._long_term_memory,
+                    self.crew._entity_memory,
+                    self.crew._external_memory,
+                    agent=self,
+                    task=task,
+                )
+                memory = await contextual_memory.abuild_context_for_task(
+                    task, context or ""
+                )
+                if memory.strip() != "":
+                    task_prompt += self.i18n.slice("memory").format(memory=memory)
+
+                crewai_event_bus.emit(
+                    self,
+                    event=MemoryRetrievalCompletedEvent(
+                        task_id=str(task.id) if task else None,
+                        memory_content=memory,
+                        retrieval_time_ms=(time.time() - start_time) * 1000,
+                        source_type="agent",
+                        from_agent=self,
+                        from_task=task,
+                    ),
+                )
+            except Exception as e:
+                crewai_event_bus.emit(
+                    self,
+                    event=MemoryRetrievalFailedEvent(
+                        task_id=str(task.id) if task else None,
+                        source_type="agent",
+                        from_agent=self,
+                        from_task=task,
+                        error=str(e),
+                    ),
+                )
 
         knowledge_config = get_knowledge_config(self)
         task_prompt = await ahandle_knowledge_retrieval(
@@ -1669,3 +1698,22 @@ class Agent(BaseAgent):
         )
 
         return await lite_agent.kickoff_async(messages)
+
+
+# Rebuild Agent model to resolve A2A type forward references
+try:
+    from crewai.a2a.config import (
+        A2AClientConfig as _A2AClientConfig,
+        A2AConfig as _A2AConfig,
+        A2AServerConfig as _A2AServerConfig,
+    )
+
+    Agent.model_rebuild(
+        _types_namespace={
+            "A2AConfig": _A2AConfig,
+            "A2AClientConfig": _A2AClientConfig,
+            "A2AServerConfig": _A2AServerConfig,
+        }
+    )
+except ImportError:
+    pass
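
Because the A2A config classes are now imported only under TYPE_CHECKING (see the earlier import hunk), the Agent annotations that mention them are unresolved forward references at runtime; this rebuild supplies the real classes when the optional package is available. A generic Pydantic illustration of the same step, with placeholder names:

from typing import TYPE_CHECKING, Optional

from pydantic import BaseModel

if TYPE_CHECKING:
    from optional_pkg import RemoteConfig  # placeholder optional dependency


class Agentish(BaseModel):
    # Forward reference: so far the name only exists for the type checker.
    remote: Optional["RemoteConfig"] = None


try:
    from optional_pkg import RemoteConfig as _RemoteConfig

    # Resolve the forward reference now that the class is importable.
    Agentish.model_rebuild(_types_namespace={"RemoteConfig": _RemoteConfig})
except ImportError:
    pass  # extra not installed; the field stays unresolved until it is
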
@@ -75,6 +75,7 @@ from crewai.events.types.memory_events import (
     MemoryQueryFailedEvent,
     MemoryQueryStartedEvent,
     MemoryRetrievalCompletedEvent,
+    MemoryRetrievalFailedEvent,
     MemoryRetrievalStartedEvent,
     MemorySaveCompletedEvent,
     MemorySaveFailedEvent,
@@ -174,6 +175,7 @@ __all__ = [
     "MemoryQueryFailedEvent",
     "MemoryQueryStartedEvent",
     "MemoryRetrievalCompletedEvent",
+    "MemoryRetrievalFailedEvent",
     "MemoryRetrievalStartedEvent",
     "MemorySaveCompletedEvent",
     "MemorySaveFailedEvent",
@@ -1,4 +1,5 @@
 from collections.abc import Iterator
+import contextvars
 from datetime import datetime, timezone
 import itertools
 from typing import Any
@@ -9,27 +10,37 @@ from pydantic import BaseModel, Field
 from crewai.utilities.serialization import Serializable, to_serializable
 
 
-_emission_counter: Iterator[int] = itertools.count(start=1)
+_emission_counter: contextvars.ContextVar[Iterator[int]] = contextvars.ContextVar(
+    "_emission_counter"
+)
+
+
+def _get_or_create_counter() -> Iterator[int]:
+    """Get the emission counter for the current context, creating if needed."""
+    try:
+        return _emission_counter.get()
+    except LookupError:
+        counter: Iterator[int] = itertools.count(start=1)
+        _emission_counter.set(counter)
+        return counter
 
 
 def get_next_emission_sequence() -> int:
     """Get the next emission sequence number.
 
-    Thread-safe due to atomic next() on itertools.count under the GIL.
-
     Returns:
         The next sequence number.
     """
-    return next(_emission_counter)
+    return next(_get_or_create_counter())
 
 
 def reset_emission_counter() -> None:
     """Reset the emission sequence counter to 1.
 
-    Useful for test isolation.
+    Resets for the current context only.
     """
-    global _emission_counter
-    _emission_counter = itertools.count(start=1)
+    counter: Iterator[int] = itertools.count(start=1)
+    _emission_counter.set(counter)
 
 
 class BaseEvent(BaseModel):
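
This is the change the commit title refers to: the module-level itertools.count is replaced by a ContextVar that lazily creates a counter per execution context, so concurrent flows and isolated test contexts each get their own sequence instead of sharing one global iterator. A standard-library sketch of the behavior this enables:

import contextvars
import itertools
from collections.abc import Iterator

_counter: contextvars.ContextVar[Iterator[int]] = contextvars.ContextVar("_counter")


def next_seq() -> int:
    """Return the next value from the counter owned by the current context."""
    try:
        counter = _counter.get()
    except LookupError:
        counter = itertools.count(start=1)
        _counter.set(counter)
    return next(counter)


print(next_seq(), next_seq())   # 1 2  -- outer context
fresh = contextvars.Context()   # an empty context where the var is unset
print(fresh.run(next_seq))      # 1    -- gets its own counter
print(next_seq())               # 3    -- outer context is unaffected
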
@@ -420,7 +420,7 @@ class CrewAIEventsBus:
 
         return None
 
-    def flush(self, timeout: float | None = None) -> bool:
+    def flush(self, timeout: float | None = 30.0) -> bool:
         """Block until all pending event handlers complete.
 
         This method waits for all futures from previously emitted events to
@@ -429,7 +429,7 @@ class CrewAIEventsBus:
 
         Args:
             timeout: Maximum time in seconds to wait for handlers to complete.
-                If None, waits indefinitely.
+                Defaults to 30 seconds. Pass None to wait indefinitely.
 
         Returns:
             True if all handlers completed, False if timeout occurred.
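
flush() now bounds its wait at 30 seconds by default instead of blocking forever; callers who want the old behavior must pass timeout=None explicitly. A usage sketch (the bus import path is assumed and may differ by version):

from crewai.events import crewai_event_bus

if not crewai_event_bus.flush():        # waits up to 30 seconds by default
    print("some event handlers were still running when the timeout elapsed")

crewai_event_bus.flush(timeout=None)    # previous behavior: wait indefinitely
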
@@ -192,6 +192,7 @@ SCOPE_ENDING_EVENTS: frozenset[str] = frozenset(
         "mcp_tool_execution_completed",
         "mcp_tool_execution_failed",
         "memory_retrieval_completed",
+        "memory_retrieval_failed",
         "memory_save_completed",
         "memory_save_failed",
         "memory_query_completed",
@@ -241,6 +242,7 @@ VALID_EVENT_PAIRS: dict[str, str] = {
     "mcp_tool_execution_completed": "mcp_tool_execution_started",
     "mcp_tool_execution_failed": "mcp_tool_execution_started",
     "memory_retrieval_completed": "memory_retrieval_started",
+    "memory_retrieval_failed": "memory_retrieval_started",
     "memory_save_completed": "memory_save_started",
     "memory_save_failed": "memory_save_started",
     "memory_query_completed": "memory_query_started",
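
Registering memory_retrieval_failed as scope-ending and pairing it with memory_retrieval_started means a failed retrieval closes the scope its started event opened, just like a completion. A toy illustration of how such a pairing table can be consumed (not crewAI's actual tracker):

VALID_EVENT_PAIRS = {
    "memory_retrieval_completed": "memory_retrieval_started",
    "memory_retrieval_failed": "memory_retrieval_started",
}

open_scopes: set[str] = set()


def track(event_type: str) -> None:
    starter = VALID_EVENT_PAIRS.get(event_type)
    if starter is not None:
        open_scopes.discard(starter)      # completion or failure closes the scope
    elif event_type.endswith("_started"):
        open_scopes.add(event_type)


track("memory_retrieval_started")
track("memory_retrieval_failed")
assert not open_scopes                    # the failure closed the retrieval scope
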
@@ -79,6 +79,7 @@ from crewai.events.types.memory_events import (
     MemoryQueryFailedEvent,
     MemoryQueryStartedEvent,
     MemoryRetrievalCompletedEvent,
+    MemoryRetrievalFailedEvent,
     MemoryRetrievalStartedEvent,
     MemorySaveCompletedEvent,
     MemorySaveFailedEvent,
@@ -173,6 +174,7 @@ EventTypes = (
     | MemoryQueryFailedEvent
     | MemoryRetrievalStartedEvent
     | MemoryRetrievalCompletedEvent
+    | MemoryRetrievalFailedEvent
     | MCPConnectionStartedEvent
     | MCPConnectionCompletedEvent
     | MCPConnectionFailedEvent
@@ -14,7 +14,7 @@ class MemoryBaseEvent(BaseEvent):
     agent_role: str | None = None
     agent_id: str | None = None
 
-    def __init__(self, **data):
+    def __init__(self, **data: Any) -> None:
         super().__init__(**data)
         self._set_agent_params(data)
         self._set_task_params(data)
@@ -93,3 +93,11 @@ class MemoryRetrievalCompletedEvent(MemoryBaseEvent):
     task_id: str | None = None
     memory_content: str
     retrieval_time_ms: float
+
+
+class MemoryRetrievalFailedEvent(MemoryBaseEvent):
+    """Event emitted when memory retrieval for a task prompt fails."""
+
+    type: str = "memory_retrieval_failed"
+    task_id: str | None = None
+    error: str
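
MemoryRetrievalFailedEvent inherits the agent and task attribution from MemoryBaseEvent and adds a required error string. A small construction sketch (field values are illustrative and the import path is assumed; verify against your crewAI version):

from crewai.events.types.memory_events import MemoryRetrievalFailedEvent

event = MemoryRetrievalFailedEvent(
    task_id="task-123",                  # optional; may be None
    error="vector store unreachable",
    source_type="agent",
)
print(event.type)                        # "memory_retrieval_failed"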