feat: add pass additional data to a2a events

This commit is contained in:
Greyson LaLonde
2026-01-16 03:37:09 -05:00
parent b7434af0ce
commit a675327ffa
6 changed files with 62 additions and 6 deletions

View File

@@ -164,12 +164,11 @@ def process_task_state(
Returns:
Result dictionary if terminal/actionable state, None otherwise.
"""
should_extract = result_parts is None
if result_parts is None:
result_parts = []
if a2a_task.status.state == TaskState.completed:
if should_extract:
if not result_parts:
extracted_parts = extract_task_result_parts(a2a_task)
result_parts.extend(extracted_parts)
if a2a_task.history:
@@ -287,6 +286,7 @@ async def send_message_and_get_task_id(
from_agent: Any | None = None,
endpoint: str | None = None,
a2a_agent_name: str | None = None,
context_id: str | None = None,
) -> str | TaskStateResult:
"""Send message and process initial response.
@@ -305,6 +305,7 @@ async def send_message_and_get_task_id(
from_agent: Optional CrewAI Agent object for event metadata.
endpoint: Optional A2A endpoint URL.
a2a_agent_name: Optional A2A agent name.
context_id: Optional A2A context ID for correlation.
Returns:
Task ID string if agent needs polling/waiting, or TaskStateResult if done.
@@ -377,6 +378,7 @@ async def send_message_and_get_task_id(
role=Role.agent,
message_id=str(uuid.uuid4()),
parts=[Part(root=TextPart(text=error_msg))],
context_id=context_id,
)
new_messages.append(error_message)
@@ -389,6 +391,7 @@ async def send_message_and_get_task_id(
status_code=e.status_code,
a2a_agent_name=a2a_agent_name,
operation="send_message",
context_id=context_id,
from_task=from_task,
from_agent=from_agent,
),
@@ -398,6 +401,7 @@ async def send_message_and_get_task_id(
A2AResponseReceivedEvent(
response=error_msg,
turn_number=turn_number,
context_id=context_id,
is_multiturn=is_multiturn,
status="failed",
final=True,
@@ -421,6 +425,7 @@ async def send_message_and_get_task_id(
role=Role.agent,
message_id=str(uuid.uuid4()),
parts=[Part(root=TextPart(text=error_msg))],
context_id=context_id,
)
new_messages.append(error_message)
@@ -432,6 +437,7 @@ async def send_message_and_get_task_id(
error_type="unexpected_error",
a2a_agent_name=a2a_agent_name,
operation="send_message",
context_id=context_id,
from_task=from_task,
from_agent=from_agent,
),
@@ -441,6 +447,7 @@ async def send_message_and_get_task_id(
A2AResponseReceivedEvent(
response=error_msg,
turn_number=turn_number,
context_id=context_id,
is_multiturn=is_multiturn,
status="failed",
final=True,

View File

@@ -88,11 +88,12 @@ async def _poll_task_until_complete(
)
elapsed = time.monotonic() - start_time
effective_context_id = task.context_id or context_id
crewai_event_bus.emit(
agent_branch,
A2APollingStatusEvent(
task_id=task_id,
context_id=context_id,
context_id=effective_context_id,
state=str(task.status.state.value) if task.status.state else "unknown",
elapsed_seconds=elapsed,
poll_count=poll_count,
@@ -169,6 +170,7 @@ class PollingHandler:
from_agent=from_agent,
endpoint=endpoint,
a2a_agent_name=a2a_agent_name,
context_id=context_id,
)
if not isinstance(result_or_task_id, str):

View File

@@ -196,6 +196,7 @@ class PushNotificationHandler:
from_agent=from_agent,
endpoint=endpoint,
a2a_agent_name=a2a_agent_name,
context_id=context_id,
)
if not isinstance(result_or_task_id, str):

View File

@@ -93,6 +93,7 @@ class StreamingHandler:
async for event in event_stream:
if isinstance(event, Message):
new_messages.append(event)
message_context_id = event.context_id or context_id
for part in event.parts:
if part.root.kind == "text":
text = part.root.text
@@ -100,8 +101,8 @@ class StreamingHandler:
crewai_event_bus.emit(
agent_branch,
A2AStreamingChunkEvent(
task_id=task_id,
context_id=context_id,
task_id=event.task_id or task_id,
context_id=message_context_id,
chunk=text,
chunk_index=chunk_index,
endpoint=endpoint,
@@ -132,6 +133,7 @@ class StreamingHandler:
else len(getattr(p.root, "data", b""))
for p in artifact.parts
)
effective_context_id = a2a_task.context_id or context_id
crewai_event_bus.emit(
agent_branch,
A2AArtifactReceivedEvent(
@@ -147,7 +149,7 @@ class StreamingHandler:
last_chunk=update.last_chunk or False,
endpoint=endpoint,
a2a_agent_name=a2a_agent_name,
context_id=context_id,
context_id=effective_context_id,
turn_number=turn_number,
is_multiturn=is_multiturn,
from_task=from_task,
@@ -158,6 +160,16 @@ class StreamingHandler:
is_final_update = False
if isinstance(update, TaskStatusUpdateEvent):
is_final_update = update.final
if (
update.status
and update.status.message
and update.status.message.parts
):
result_parts.extend(
part.root.text
for part in update.status.message.parts
if part.root.kind == "text" and part.root.text
)
if (
not is_final_update

View File

@@ -8,6 +8,8 @@ from crewai.events.types.a2a_events import (
A2ADelegationCompletedEvent,
A2ADelegationStartedEvent,
A2AMessageSentEvent,
A2AParallelDelegationCompletedEvent,
A2AParallelDelegationStartedEvent,
A2APollingStartedEvent,
A2APollingStatusEvent,
A2APushNotificationReceivedEvent,
@@ -122,6 +124,8 @@ EventTypes = (
| A2AServerTaskStartedEvent
| A2AStreamingChunkEvent
| A2AStreamingStartedEvent
| A2AParallelDelegationStartedEvent
| A2AParallelDelegationCompletedEvent
| CrewKickoffStartedEvent
| CrewKickoffCompletedEvent
| CrewKickoffFailedEvent

View File

@@ -624,3 +624,33 @@ class A2AServerTaskFailedEvent(A2AEventBase):
context_id: str
error: str
metadata: dict[str, Any] | None = None
class A2AParallelDelegationStartedEvent(A2AEventBase):
"""Event emitted when parallel delegation to multiple A2A agents begins.
Attributes:
endpoints: List of A2A agent endpoints being delegated to.
task_description: Description of the task being delegated.
"""
type: str = "a2a_parallel_delegation_started"
endpoints: list[str]
task_description: str
class A2AParallelDelegationCompletedEvent(A2AEventBase):
"""Event emitted when parallel delegation to multiple A2A agents completes.
Attributes:
endpoints: List of A2A agent endpoints that were delegated to.
success_count: Number of successful delegations.
failure_count: Number of failed delegations.
results: Summary of results from each agent.
"""
type: str = "a2a_parallel_delegation_completed"
endpoints: list[str]
success_count: int
failure_count: int
results: dict[str, str] | None = None