diff --git a/lib/crewai/src/crewai/a2a/extensions/a2ui/client_extension.py b/lib/crewai/src/crewai/a2a/extensions/a2ui/client_extension.py index 90d214d6b..26332db4f 100644 --- a/lib/crewai/src/crewai/a2a/extensions/a2ui/client_extension.py +++ b/lib/crewai/src/crewai/a2a/extensions/a2ui/client_extension.py @@ -3,9 +3,11 @@ from __future__ import annotations from collections.abc import Sequence -from dataclasses import dataclass, field import logging -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, TypedDict, cast + +from pydantic import Field +from pydantic.dataclasses import dataclass from crewai.a2a.extensions.a2ui.models import extract_a2ui_json_objects from crewai.a2a.extensions.a2ui.prompt import build_a2ui_system_prompt @@ -25,13 +27,77 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) +class StylesDict(TypedDict, total=False): + """Serialized surface styling.""" + + font: str + primaryColor: str + + +class ComponentEntryDict(TypedDict, total=False): + """Serialized component entry in a surface update.""" + + id: str + weight: float + component: dict[str, Any] + + +class BeginRenderingDict(TypedDict, total=False): + """Serialized beginRendering payload.""" + + surfaceId: str + root: str + catalogId: str + styles: StylesDict + + +class SurfaceUpdateDict(TypedDict, total=False): + """Serialized surfaceUpdate payload.""" + + surfaceId: str + components: list[ComponentEntryDict] + + +class DataEntryDict(TypedDict, total=False): + """Serialized data model entry.""" + + key: str + valueString: str + valueNumber: float + valueBoolean: bool + valueMap: list[DataEntryDict] + + +class DataModelUpdateDict(TypedDict, total=False): + """Serialized dataModelUpdate payload.""" + + surfaceId: str + path: str + contents: list[DataEntryDict] + + +class DeleteSurfaceDict(TypedDict): + """Serialized deleteSurface payload.""" + + surfaceId: str + + +class A2UIMessageDict(TypedDict, total=False): + """Serialized A2UI server-to-client message with exactly one key set.""" + + beginRendering: BeginRenderingDict + surfaceUpdate: SurfaceUpdateDict + dataModelUpdate: DataModelUpdateDict + deleteSurface: DeleteSurfaceDict + + @dataclass class A2UIConversationState: """Tracks active A2UI surfaces and data models across a conversation.""" - active_surfaces: dict[str, dict[str, Any]] = field(default_factory=dict) - data_models: dict[str, list[dict[str, Any]]] = field(default_factory=dict) - last_a2ui_messages: list[dict[str, Any]] = field(default_factory=list) + active_surfaces: dict[str, dict[str, Any]] = Field(default_factory=dict) + data_models: dict[str, list[dict[str, Any]]] = Field(default_factory=dict) + last_a2ui_messages: list[A2UIMessageDict] = Field(default_factory=list) def is_ready(self) -> bool: """Return True when at least one surface is active.""" @@ -74,12 +140,13 @@ class A2UIClientExtension: def extract_state_from_history( self, conversation_history: Sequence[Message] ) -> A2UIConversationState | None: - """Scan conversation history for A2UI DataParts and track surface state.""" + """Scan conversation history for A2UI DataParts and track surface state. + + When ``catalog_id`` is set, only surfaces matching that catalog are tracked. + """ state = A2UIConversationState() for message in conversation_history: - if not _has_parts(message): - continue for part in message.parts: root = part.root if root.kind != "data": @@ -97,6 +164,11 @@ class A2UIClientExtension: if not surface_id: continue + if self._catalog_id and "beginRendering" in data: + catalog_id = data["beginRendering"].get("catalogId") + if catalog_id and catalog_id != self._catalog_id: + continue + if "deleteSurface" in data: state.active_surfaces.pop(surface_id, None) state.data_models.pop(surface_id, None) @@ -115,7 +187,7 @@ class A2UIClientExtension: def augment_prompt( self, base_prompt: str, - conversation_state: A2UIConversationState | None, + _conversation_state: A2UIConversationState | None, ) -> str: """Append A2UI system prompt instructions to the base prompt.""" a2ui_prompt = build_a2ui_system_prompt( @@ -131,26 +203,25 @@ class A2UIClientExtension: ) -> Any: """Extract and validate A2UI JSON from agent output. - Stores extracted A2UI messages on the conversation state and returns - the original response unchanged to preserve the AgentResponseProtocol - contract. + When ``allowed_components`` is set, components not in the allowlist are + logged and stripped from surface updates. Stores extracted A2UI messages + on the conversation state and returns the original response unchanged. """ text = ( agent_response if isinstance(agent_response, str) else str(agent_response) ) a2ui_messages = _extract_and_validate(text) + if self._allowed_components: + allowed = set(self._allowed_components) + a2ui_messages = [_filter_components(msg, allowed) for msg in a2ui_messages] + if a2ui_messages and conversation_state is not None: conversation_state.last_a2ui_messages = a2ui_messages return agent_response -def _has_parts(message: Any) -> bool: - """Check if a message has a parts attribute.""" - return isinstance(getattr(message, "parts", None), list) - - def _get_surface_id(data: dict[str, Any]) -> str | None: """Extract surfaceId from any A2UI message type.""" for key in ("beginRendering", "surfaceUpdate", "dataModelUpdate", "deleteSurface"): @@ -162,7 +233,36 @@ def _get_surface_id(data: dict[str, Any]) -> str | None: return None -def _extract_and_validate(text: str) -> list[dict[str, Any]]: +def _filter_components(msg: A2UIMessageDict, allowed: set[str]) -> A2UIMessageDict: + """Strip components whose type is not in *allowed* from a surfaceUpdate.""" + surface_update = msg.get("surfaceUpdate") + if not isinstance(surface_update, dict): + return msg + + components = surface_update.get("components") + if not isinstance(components, list): + return msg + + filtered = [] + for entry in components: + component = entry.get("component", {}) + component_types = set(component.keys()) + disallowed = component_types - allowed + if disallowed: + logger.debug( + "Stripping disallowed component type(s) %s from surface update", + disallowed, + ) + continue + filtered.append(entry) + + if len(filtered) == len(components): + return msg + + return {**msg, "surfaceUpdate": {**surface_update, "components": filtered}} + + +def _extract_and_validate(text: str) -> list[A2UIMessageDict]: """Extract A2UI JSON objects from text and validate them.""" return [ dumped @@ -171,7 +271,7 @@ def _extract_and_validate(text: str) -> list[dict[str, Any]]: ] -def _try_validate(candidate: dict[str, Any]) -> dict[str, Any] | None: +def _try_validate(candidate: dict[str, Any]) -> A2UIMessageDict | None: """Validate a single A2UI candidate, returning None on failure.""" try: msg = validate_a2ui_message(candidate) @@ -181,4 +281,4 @@ def _try_validate(candidate: dict[str, Any]) -> dict[str, Any] | None: exc_info=True, ) return None - return msg.model_dump(by_alias=True, exclude_none=True) + return cast(A2UIMessageDict, msg.model_dump(by_alias=True, exclude_none=True))