Mirror of https://github.com/crewAIInc/crewAI.git (synced 2026-01-20 21:38:14 +00:00)
fix: ensure context isolation
@@ -14,15 +14,25 @@ from typing import (
from pydantic import BeforeValidator, HttpUrl, TypeAdapter
from typing_extensions import NotRequired

from crewai.a2a.updates import (
    PollingConfig,
    PollingHandler,
    PushNotificationConfig,
    PushNotificationHandler,
    StreamingConfig,
    StreamingHandler,
    UpdateConfig,
)

try:
    from crewai.a2a.updates import (
        PollingConfig,
        PollingHandler,
        PushNotificationConfig,
        PushNotificationHandler,
        StreamingConfig,
        StreamingHandler,
        UpdateConfig,
    )
except ImportError:
    PollingConfig = Any  # type: ignore[misc,assignment]
    PollingHandler = Any  # type: ignore[misc,assignment]
    PushNotificationConfig = Any  # type: ignore[misc,assignment]
    PushNotificationHandler = Any  # type: ignore[misc,assignment]
    StreamingConfig = Any  # type: ignore[misc,assignment]
    StreamingHandler = Any  # type: ignore[misc,assignment]
    UpdateConfig = Any  # type: ignore[misc,assignment]


TransportType = Literal["JSONRPC", "GRPC", "HTTP+JSON"]
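Note: this hunk swaps the unconditional crewai.a2a.updates import for a guarded one, so the module still imports when the optional A2A extra is missing and the names degrade to Any placeholders for annotations. A minimal sketch of the same pattern, with a hypothetical package and name:

from typing import Any

try:
    # Optional dependency: present only when the corresponding extra is installed.
    from optional_pkg import FancyConfig  # hypothetical package and name
except ImportError:
    # Degrade to Any so annotations and attribute defaults still resolve.
    FancyConfig = Any  # type: ignore[misc,assignment]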
@@ -251,30 +251,48 @@ async def aexecute_a2a_delegation(
    if turn_number is None:
        turn_number = len([m for m in conversation_history if m.role == Role.user]) + 1

    result = await _aexecute_a2a_delegation_impl(
        endpoint=endpoint,
        auth=auth,
        timeout=timeout,
        task_description=task_description,
        context=context,
        context_id=context_id,
        task_id=task_id,
        reference_task_ids=reference_task_ids,
        metadata=metadata,
        extensions=extensions,
        conversation_history=conversation_history,
        is_multiturn=is_multiturn,
        turn_number=turn_number,
        agent_branch=agent_branch,
        agent_id=agent_id,
        agent_role=agent_role,
        response_model=response_model,
        updates=updates,
        transport_protocol=transport_protocol,
        from_task=from_task,
        from_agent=from_agent,
        skill_id=skill_id,
    )
    try:
        result = await _aexecute_a2a_delegation_impl(
            endpoint=endpoint,
            auth=auth,
            timeout=timeout,
            task_description=task_description,
            context=context,
            context_id=context_id,
            task_id=task_id,
            reference_task_ids=reference_task_ids,
            metadata=metadata,
            extensions=extensions,
            conversation_history=conversation_history,
            is_multiturn=is_multiturn,
            turn_number=turn_number,
            agent_branch=agent_branch,
            agent_id=agent_id,
            agent_role=agent_role,
            response_model=response_model,
            updates=updates,
            transport_protocol=transport_protocol,
            from_task=from_task,
            from_agent=from_agent,
            skill_id=skill_id,
        )
    except Exception as e:
        crewai_event_bus.emit(
            agent_branch,
            A2ADelegationCompletedEvent(
                status="failed",
                result=None,
                error=str(e),
                context_id=context_id,
                is_multiturn=is_multiturn,
                endpoint=endpoint,
                metadata=metadata,
                extensions=list(extensions.keys()) if extensions else None,
                from_task=from_task,
                from_agent=from_agent,
            ),
        )
        raise

    agent_card_data: dict[str, Any] = result.get("agent_card") or {}
    crewai_event_bus.emit(
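Note: the delegation call is now wrapped so that a failure emits an A2ADelegationCompletedEvent with status="failed" before the exception is re-raised; callers still see the error, but event-bus listeners get a terminal event. A stripped-down sketch of that pattern (the two callables are placeholders, not the real API):

def delegate_with_failure_event(do_delegation, emit_failed_event):
    """Run the delegation; on error, report it on the bus and then re-raise."""
    try:
        return do_delegation()
    except Exception as exc:
        emit_failed_event(error=str(exc))  # listeners observe the failure
        raise  # callers still receive the original exception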
@@ -14,7 +14,14 @@ from typing import (
)
from urllib.parse import urlparse

from pydantic import BaseModel, Field, InstanceOf, PrivateAttr, model_validator
from pydantic import (
    BaseModel,
    ConfigDict,
    Field,
    InstanceOf,
    PrivateAttr,
    model_validator,
)
from typing_extensions import Self

from crewai.agent.utils import (
@@ -41,6 +48,7 @@ from crewai.events.types.knowledge_events import (
)
from crewai.events.types.memory_events import (
    MemoryRetrievalCompletedEvent,
    MemoryRetrievalFailedEvent,
    MemoryRetrievalStartedEvent,
)
from crewai.experimental.crew_agent_executor_flow import CrewAgentExecutorFlow
@@ -77,17 +85,10 @@ from crewai.utilities.token_counter_callback import TokenCalcHandler
from crewai.utilities.training_handler import CrewTrainingHandler


try:
    from crewai.a2a.config import A2AClientConfig, A2AConfig, A2AServerConfig
except ImportError:
    A2AClientConfig = Any
    A2AConfig = Any
    A2AServerConfig = Any


if TYPE_CHECKING:
    from crewai_tools import CodeInterpreterTool

    from crewai.a2a.config import A2AClientConfig, A2AConfig, A2AServerConfig
    from crewai.agents.agent_builder.base_agent import PlatformAppOrAction
    from crewai.lite_agent_output import LiteAgentOutput
    from crewai.task import Task
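Note: the runtime try/except fallback for the A2A config classes is dropped here; they are now imported only under TYPE_CHECKING, so annotations keep working without importing the optional package at runtime (the model_rebuild call at the end of the file resolves the resulting forward references). The general shape of the pattern, with illustrative names:

from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Seen by type checkers only; no import cost or hard dependency at runtime.
    from optional_pkg import OptionalConfig  # hypothetical import


def configure(config: OptionalConfig | None = None) -> None:
    """The annotation stays a string at runtime thanks to postponed evaluation."""
    ...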
@@ -133,6 +134,8 @@ class Agent(BaseAgent):
        mcps: List of MCP server references for tool integration.
    """

    model_config = ConfigDict()

    _times_executed: int = PrivateAttr(default=0)
    _mcp_clients: list[Any] = PrivateAttr(default_factory=list)
    _last_messages: list[LLMMessage] = PrivateAttr(default_factory=list)
@@ -346,30 +349,43 @@ class Agent(BaseAgent):
        )

        start_time = time.time()
        memory = ""

        contextual_memory = ContextualMemory(
            self.crew._short_term_memory,
            self.crew._long_term_memory,
            self.crew._entity_memory,
            self.crew._external_memory,
            agent=self,
            task=task,
        )
        memory = contextual_memory.build_context_for_task(task, context or "")
        if memory.strip() != "":
            task_prompt += self.i18n.slice("memory").format(memory=memory)
        try:
            contextual_memory = ContextualMemory(
                self.crew._short_term_memory,
                self.crew._long_term_memory,
                self.crew._entity_memory,
                self.crew._external_memory,
                agent=self,
                task=task,
            )
            memory = contextual_memory.build_context_for_task(task, context or "")
            if memory.strip() != "":
                task_prompt += self.i18n.slice("memory").format(memory=memory)

        crewai_event_bus.emit(
            self,
            event=MemoryRetrievalCompletedEvent(
                task_id=str(task.id) if task else None,
                memory_content=memory,
                retrieval_time_ms=(time.time() - start_time) * 1000,
                source_type="agent",
                from_agent=self,
                from_task=task,
            ),
        )
            crewai_event_bus.emit(
                self,
                event=MemoryRetrievalCompletedEvent(
                    task_id=str(task.id) if task else None,
                    memory_content=memory,
                    retrieval_time_ms=(time.time() - start_time) * 1000,
                    source_type="agent",
                    from_agent=self,
                    from_task=task,
                ),
            )
        except Exception as e:
            crewai_event_bus.emit(
                self,
                event=MemoryRetrievalFailedEvent(
                    task_id=str(task.id) if task else None,
                    source_type="agent",
                    from_agent=self,
                    from_task=task,
                    error=str(e),
                ),
            )

        knowledge_config = get_knowledge_config(self)
        task_prompt = handle_knowledge_retrieval(
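Note: memory retrieval is now wrapped in try/except. On success the MemoryRetrievalCompletedEvent is emitted as before; on failure a MemoryRetrievalFailedEvent is emitted and the agent continues with an unmodified prompt rather than aborting the task. Roughly, with placeholder helpers:

def build_prompt_with_memory(task_prompt, fetch_memory, emit_completed, emit_failed):
    """Append retrieved memory to the prompt; never let retrieval errors abort the task."""
    try:
        memory = fetch_memory()
        if memory.strip() != "":
            task_prompt += "\n# Relevant memory\n" + memory
        emit_completed(memory)
    except Exception as exc:
        emit_failed(str(exc))  # surface the error, then carry on without memory
    return task_prompt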
@@ -555,32 +571,45 @@ class Agent(BaseAgent):
        )

        start_time = time.time()
        memory = ""

        contextual_memory = ContextualMemory(
            self.crew._short_term_memory,
            self.crew._long_term_memory,
            self.crew._entity_memory,
            self.crew._external_memory,
            agent=self,
            task=task,
        )
        memory = await contextual_memory.abuild_context_for_task(
            task, context or ""
        )
        if memory.strip() != "":
            task_prompt += self.i18n.slice("memory").format(memory=memory)
        try:
            contextual_memory = ContextualMemory(
                self.crew._short_term_memory,
                self.crew._long_term_memory,
                self.crew._entity_memory,
                self.crew._external_memory,
                agent=self,
                task=task,
            )
            memory = await contextual_memory.abuild_context_for_task(
                task, context or ""
            )
            if memory.strip() != "":
                task_prompt += self.i18n.slice("memory").format(memory=memory)

        crewai_event_bus.emit(
            self,
            event=MemoryRetrievalCompletedEvent(
                task_id=str(task.id) if task else None,
                memory_content=memory,
                retrieval_time_ms=(time.time() - start_time) * 1000,
                source_type="agent",
                from_agent=self,
                from_task=task,
            ),
        )
            crewai_event_bus.emit(
                self,
                event=MemoryRetrievalCompletedEvent(
                    task_id=str(task.id) if task else None,
                    memory_content=memory,
                    retrieval_time_ms=(time.time() - start_time) * 1000,
                    source_type="agent",
                    from_agent=self,
                    from_task=task,
                ),
            )
        except Exception as e:
            crewai_event_bus.emit(
                self,
                event=MemoryRetrievalFailedEvent(
                    task_id=str(task.id) if task else None,
                    source_type="agent",
                    from_agent=self,
                    from_task=task,
                    error=str(e),
                ),
            )

        knowledge_config = get_knowledge_config(self)
        task_prompt = await ahandle_knowledge_retrieval(
@@ -1669,3 +1698,22 @@ class Agent(BaseAgent):
        )

        return await lite_agent.kickoff_async(messages)


# Rebuild Agent model to resolve A2A type forward references
try:
    from crewai.a2a.config import (
        A2AClientConfig as _A2AClientConfig,
        A2AConfig as _A2AConfig,
        A2AServerConfig as _A2AServerConfig,
    )

    Agent.model_rebuild(
        _types_namespace={
            "A2AConfig": _A2AConfig,
            "A2AClientConfig": _A2AClientConfig,
            "A2AServerConfig": _A2AServerConfig,
        }
    )
except ImportError:
    pass
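Note: because the A2A config types are now TYPE_CHECKING-only imports, the module finishes by rebuilding the Agent model with an explicit types namespace when the optional package is importable, and silently skips the rebuild otherwise. The same idea in isolation (Widget and WidgetConfig are stand-ins, not crewAI types):

from typing import Optional

from pydantic import BaseModel


class Widget(BaseModel):
    # Forward reference to a type that may only be importable under TYPE_CHECKING.
    config: Optional["WidgetConfig"] = None


class WidgetConfig(BaseModel):  # stand-in for the optionally imported type
    name: str = "default"


# Resolve the forward reference explicitly once the concrete type is in hand.
Widget.model_rebuild(_types_namespace={"WidgetConfig": WidgetConfig})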
@@ -75,6 +75,7 @@ from crewai.events.types.memory_events import (
    MemoryQueryFailedEvent,
    MemoryQueryStartedEvent,
    MemoryRetrievalCompletedEvent,
    MemoryRetrievalFailedEvent,
    MemoryRetrievalStartedEvent,
    MemorySaveCompletedEvent,
    MemorySaveFailedEvent,
@@ -174,6 +175,7 @@ __all__ = [
    "MemoryQueryFailedEvent",
    "MemoryQueryStartedEvent",
    "MemoryRetrievalCompletedEvent",
    "MemoryRetrievalFailedEvent",
    "MemoryRetrievalStartedEvent",
    "MemorySaveCompletedEvent",
    "MemorySaveFailedEvent",
@@ -1,4 +1,5 @@
from collections.abc import Iterator
import contextvars
from datetime import datetime, timezone
import itertools
from typing import Any
@@ -9,27 +10,37 @@ from pydantic import BaseModel, Field
from crewai.utilities.serialization import Serializable, to_serializable


_emission_counter: Iterator[int] = itertools.count(start=1)
_emission_counter: contextvars.ContextVar[Iterator[int]] = contextvars.ContextVar(
    "_emission_counter"
)


def _get_or_create_counter() -> Iterator[int]:
    """Get the emission counter for the current context, creating if needed."""
    try:
        return _emission_counter.get()
    except LookupError:
        counter: Iterator[int] = itertools.count(start=1)
        _emission_counter.set(counter)
        return counter


def get_next_emission_sequence() -> int:
    """Get the next emission sequence number.

    Thread-safe due to atomic next() on itertools.count under the GIL.

    Returns:
        The next sequence number.
    """
    return next(_emission_counter)
    return next(_get_or_create_counter())


def reset_emission_counter() -> None:
    """Reset the emission sequence counter to 1.

    Useful for test isolation.
    Resets for the current context only.
    """
    global _emission_counter
    _emission_counter = itertools.count(start=1)
    counter: Iterator[int] = itertools.count(start=1)
    _emission_counter.set(counter)


class BaseEvent(BaseModel):
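Note: this is the change the commit title points at. The process-global itertools.count is replaced with a ContextVar that lazily creates one counter per execution context, so contexts that have not inherited an existing counter number their events independently instead of sharing one sequence. A self-contained illustration of the isolation:

import contextvars
import itertools
from collections.abc import Iterator

_counter: contextvars.ContextVar[Iterator[int]] = contextvars.ContextVar("_counter")


def next_seq() -> int:
    """Return the next number from the current context's counter, creating it lazily."""
    try:
        counter = _counter.get()
    except LookupError:
        counter = itertools.count(start=1)
        _counter.set(counter)
    return next(counter)


# Two fresh contexts each build their own counter, so the sequences do not interleave.
ctx_a, ctx_b = contextvars.copy_context(), contextvars.copy_context()
print(ctx_a.run(next_seq), ctx_a.run(next_seq))  # 1 2
print(ctx_b.run(next_seq), ctx_b.run(next_seq))  # 1 2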
@@ -420,7 +420,7 @@ class CrewAIEventsBus:
        return None

    def flush(self, timeout: float | None = None) -> bool:
    def flush(self, timeout: float | None = 30.0) -> bool:
        """Block until all pending event handlers complete.

        This method waits for all futures from previously emitted events to
@@ -429,7 +429,7 @@ class CrewAIEventsBus:
        Args:
            timeout: Maximum time in seconds to wait for handlers to complete.
                If None, waits indefinitely.
                Defaults to 30 seconds. Pass None to wait indefinitely.

        Returns:
            True if all handlers completed, False if timeout occurred.
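Note: flush() now defaults to a 30-second timeout instead of waiting forever; passing timeout=None restores the old unbounded wait. Assumed usage (the import path for the shared bus instance is an assumption, not taken from this diff):

from crewai.events import crewai_event_bus  # assumed location of the shared bus

if not crewai_event_bus.flush():  # waits up to the new 30 s default
    print("some event handlers did not finish within 30 seconds")

crewai_event_bus.flush(timeout=None)  # opt back in to waiting indefinitely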
@@ -192,6 +192,7 @@ SCOPE_ENDING_EVENTS: frozenset[str] = frozenset(
        "mcp_tool_execution_completed",
        "mcp_tool_execution_failed",
        "memory_retrieval_completed",
        "memory_retrieval_failed",
        "memory_save_completed",
        "memory_save_failed",
        "memory_query_completed",
@@ -241,6 +242,7 @@ VALID_EVENT_PAIRS: dict[str, str] = {
    "mcp_tool_execution_completed": "mcp_tool_execution_started",
    "mcp_tool_execution_failed": "mcp_tool_execution_started",
    "memory_retrieval_completed": "memory_retrieval_started",
    "memory_retrieval_failed": "memory_retrieval_started",
    "memory_save_completed": "memory_save_started",
    "memory_save_failed": "memory_save_started",
    "memory_query_completed": "memory_query_started",
@@ -79,6 +79,7 @@ from crewai.events.types.memory_events import (
    MemoryQueryFailedEvent,
    MemoryQueryStartedEvent,
    MemoryRetrievalCompletedEvent,
    MemoryRetrievalFailedEvent,
    MemoryRetrievalStartedEvent,
    MemorySaveCompletedEvent,
    MemorySaveFailedEvent,
@@ -173,6 +174,7 @@ EventTypes = (
    | MemoryQueryFailedEvent
    | MemoryRetrievalStartedEvent
    | MemoryRetrievalCompletedEvent
    | MemoryRetrievalFailedEvent
    | MCPConnectionStartedEvent
    | MCPConnectionCompletedEvent
    | MCPConnectionFailedEvent
@@ -14,7 +14,7 @@ class MemoryBaseEvent(BaseEvent):
    agent_role: str | None = None
    agent_id: str | None = None

    def __init__(self, **data):
    def __init__(self, **data: Any) -> None:
        super().__init__(**data)
        self._set_agent_params(data)
        self._set_task_params(data)
@@ -93,3 +93,11 @@ class MemoryRetrievalCompletedEvent(MemoryBaseEvent):
    task_id: str | None = None
    memory_content: str
    retrieval_time_ms: float


class MemoryRetrievalFailedEvent(MemoryBaseEvent):
    """Event emitted when memory retrieval for a task prompt fails."""

    type: str = "memory_retrieval_failed"
    task_id: str | None = None
    error: str
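Note: the new MemoryRetrievalFailedEvent carries the task id and the error string, and it is registered above as a scope-ending event paired with memory_retrieval_started. A hedged sketch of a listener; the decorator-style registration and the bus import path are assumptions based on crewAI's usual event API, not taken from this diff:

from crewai.events import crewai_event_bus  # assumed import path
from crewai.events.types.memory_events import MemoryRetrievalFailedEvent


@crewai_event_bus.on(MemoryRetrievalFailedEvent)
def log_memory_failure(source, event):
    # The agent keeps running without memory; this handler just surfaces the error.
    print(f"memory retrieval failed for task {event.task_id}: {event.error}")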