refactor: improve typing and use instance state in A2UI client extension

This commit is contained in:
Greyson Lalonde
2026-03-14 16:38:05 -04:00
parent afb6cbbb6e
commit d2e74fc0be

View File

@@ -3,9 +3,11 @@
from __future__ import annotations from __future__ import annotations
from collections.abc import Sequence from collections.abc import Sequence
from dataclasses import dataclass, field
import logging import logging
from typing import TYPE_CHECKING, Any from typing import TYPE_CHECKING, Any, TypedDict, cast
from pydantic import Field
from pydantic.dataclasses import dataclass
from crewai.a2a.extensions.a2ui.models import extract_a2ui_json_objects from crewai.a2a.extensions.a2ui.models import extract_a2ui_json_objects
from crewai.a2a.extensions.a2ui.prompt import build_a2ui_system_prompt from crewai.a2a.extensions.a2ui.prompt import build_a2ui_system_prompt
@@ -25,13 +27,77 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class StylesDict(TypedDict, total=False):
"""Serialized surface styling."""
font: str
primaryColor: str
class ComponentEntryDict(TypedDict, total=False):
"""Serialized component entry in a surface update."""
id: str
weight: float
component: dict[str, Any]
class BeginRenderingDict(TypedDict, total=False):
"""Serialized beginRendering payload."""
surfaceId: str
root: str
catalogId: str
styles: StylesDict
class SurfaceUpdateDict(TypedDict, total=False):
"""Serialized surfaceUpdate payload."""
surfaceId: str
components: list[ComponentEntryDict]
class DataEntryDict(TypedDict, total=False):
"""Serialized data model entry."""
key: str
valueString: str
valueNumber: float
valueBoolean: bool
valueMap: list[DataEntryDict]
class DataModelUpdateDict(TypedDict, total=False):
"""Serialized dataModelUpdate payload."""
surfaceId: str
path: str
contents: list[DataEntryDict]
class DeleteSurfaceDict(TypedDict):
"""Serialized deleteSurface payload."""
surfaceId: str
class A2UIMessageDict(TypedDict, total=False):
"""Serialized A2UI server-to-client message with exactly one key set."""
beginRendering: BeginRenderingDict
surfaceUpdate: SurfaceUpdateDict
dataModelUpdate: DataModelUpdateDict
deleteSurface: DeleteSurfaceDict
@dataclass @dataclass
class A2UIConversationState: class A2UIConversationState:
"""Tracks active A2UI surfaces and data models across a conversation.""" """Tracks active A2UI surfaces and data models across a conversation."""
active_surfaces: dict[str, dict[str, Any]] = field(default_factory=dict) active_surfaces: dict[str, dict[str, Any]] = Field(default_factory=dict)
data_models: dict[str, list[dict[str, Any]]] = field(default_factory=dict) data_models: dict[str, list[dict[str, Any]]] = Field(default_factory=dict)
last_a2ui_messages: list[dict[str, Any]] = field(default_factory=list) last_a2ui_messages: list[A2UIMessageDict] = Field(default_factory=list)
def is_ready(self) -> bool: def is_ready(self) -> bool:
"""Return True when at least one surface is active.""" """Return True when at least one surface is active."""
@@ -74,12 +140,13 @@ class A2UIClientExtension:
def extract_state_from_history( def extract_state_from_history(
self, conversation_history: Sequence[Message] self, conversation_history: Sequence[Message]
) -> A2UIConversationState | None: ) -> A2UIConversationState | None:
"""Scan conversation history for A2UI DataParts and track surface state.""" """Scan conversation history for A2UI DataParts and track surface state.
When ``catalog_id`` is set, only surfaces matching that catalog are tracked.
"""
state = A2UIConversationState() state = A2UIConversationState()
for message in conversation_history: for message in conversation_history:
if not _has_parts(message):
continue
for part in message.parts: for part in message.parts:
root = part.root root = part.root
if root.kind != "data": if root.kind != "data":
@@ -97,6 +164,11 @@ class A2UIClientExtension:
if not surface_id: if not surface_id:
continue continue
if self._catalog_id and "beginRendering" in data:
catalog_id = data["beginRendering"].get("catalogId")
if catalog_id and catalog_id != self._catalog_id:
continue
if "deleteSurface" in data: if "deleteSurface" in data:
state.active_surfaces.pop(surface_id, None) state.active_surfaces.pop(surface_id, None)
state.data_models.pop(surface_id, None) state.data_models.pop(surface_id, None)
@@ -115,7 +187,7 @@ class A2UIClientExtension:
def augment_prompt( def augment_prompt(
self, self,
base_prompt: str, base_prompt: str,
conversation_state: A2UIConversationState | None, _conversation_state: A2UIConversationState | None,
) -> str: ) -> str:
"""Append A2UI system prompt instructions to the base prompt.""" """Append A2UI system prompt instructions to the base prompt."""
a2ui_prompt = build_a2ui_system_prompt( a2ui_prompt = build_a2ui_system_prompt(
@@ -131,26 +203,25 @@ class A2UIClientExtension:
) -> Any: ) -> Any:
"""Extract and validate A2UI JSON from agent output. """Extract and validate A2UI JSON from agent output.
Stores extracted A2UI messages on the conversation state and returns When ``allowed_components`` is set, components not in the allowlist are
the original response unchanged to preserve the AgentResponseProtocol logged and stripped from surface updates. Stores extracted A2UI messages
contract. on the conversation state and returns the original response unchanged.
""" """
text = ( text = (
agent_response if isinstance(agent_response, str) else str(agent_response) agent_response if isinstance(agent_response, str) else str(agent_response)
) )
a2ui_messages = _extract_and_validate(text) a2ui_messages = _extract_and_validate(text)
if self._allowed_components:
allowed = set(self._allowed_components)
a2ui_messages = [_filter_components(msg, allowed) for msg in a2ui_messages]
if a2ui_messages and conversation_state is not None: if a2ui_messages and conversation_state is not None:
conversation_state.last_a2ui_messages = a2ui_messages conversation_state.last_a2ui_messages = a2ui_messages
return agent_response return agent_response
def _has_parts(message: Any) -> bool:
"""Check if a message has a parts attribute."""
return isinstance(getattr(message, "parts", None), list)
def _get_surface_id(data: dict[str, Any]) -> str | None: def _get_surface_id(data: dict[str, Any]) -> str | None:
"""Extract surfaceId from any A2UI message type.""" """Extract surfaceId from any A2UI message type."""
for key in ("beginRendering", "surfaceUpdate", "dataModelUpdate", "deleteSurface"): for key in ("beginRendering", "surfaceUpdate", "dataModelUpdate", "deleteSurface"):
@@ -162,7 +233,36 @@ def _get_surface_id(data: dict[str, Any]) -> str | None:
return None return None
def _extract_and_validate(text: str) -> list[dict[str, Any]]: def _filter_components(msg: A2UIMessageDict, allowed: set[str]) -> A2UIMessageDict:
"""Strip components whose type is not in *allowed* from a surfaceUpdate."""
surface_update = msg.get("surfaceUpdate")
if not isinstance(surface_update, dict):
return msg
components = surface_update.get("components")
if not isinstance(components, list):
return msg
filtered = []
for entry in components:
component = entry.get("component", {})
component_types = set(component.keys())
disallowed = component_types - allowed
if disallowed:
logger.debug(
"Stripping disallowed component type(s) %s from surface update",
disallowed,
)
continue
filtered.append(entry)
if len(filtered) == len(components):
return msg
return {**msg, "surfaceUpdate": {**surface_update, "components": filtered}}
def _extract_and_validate(text: str) -> list[A2UIMessageDict]:
"""Extract A2UI JSON objects from text and validate them.""" """Extract A2UI JSON objects from text and validate them."""
return [ return [
dumped dumped
@@ -171,7 +271,7 @@ def _extract_and_validate(text: str) -> list[dict[str, Any]]:
] ]
def _try_validate(candidate: dict[str, Any]) -> dict[str, Any] | None: def _try_validate(candidate: dict[str, Any]) -> A2UIMessageDict | None:
"""Validate a single A2UI candidate, returning None on failure.""" """Validate a single A2UI candidate, returning None on failure."""
try: try:
msg = validate_a2ui_message(candidate) msg = validate_a2ui_message(candidate)
@@ -181,4 +281,4 @@ def _try_validate(candidate: dict[str, Any]) -> dict[str, Any] | None:
exc_info=True, exc_info=True,
) )
return None return None
return msg.model_dump(by_alias=True, exclude_none=True) return cast(A2UIMessageDict, msg.model_dump(by_alias=True, exclude_none=True))