refactor: improve typing and use instance state in A2UI client extension

This commit is contained in:
Greyson Lalonde
2026-03-14 16:38:05 -04:00
parent afb6cbbb6e
commit d2e74fc0be

View File

@@ -3,9 +3,11 @@
from __future__ import annotations
from collections.abc import Sequence
from dataclasses import dataclass, field
import logging
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, TypedDict, cast
from pydantic import Field
from pydantic.dataclasses import dataclass
from crewai.a2a.extensions.a2ui.models import extract_a2ui_json_objects
from crewai.a2a.extensions.a2ui.prompt import build_a2ui_system_prompt
@@ -25,13 +27,77 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
class StylesDict(TypedDict, total=False):
"""Serialized surface styling."""
font: str
primaryColor: str
class ComponentEntryDict(TypedDict, total=False):
"""Serialized component entry in a surface update."""
id: str
weight: float
component: dict[str, Any]
class BeginRenderingDict(TypedDict, total=False):
"""Serialized beginRendering payload."""
surfaceId: str
root: str
catalogId: str
styles: StylesDict
class SurfaceUpdateDict(TypedDict, total=False):
"""Serialized surfaceUpdate payload."""
surfaceId: str
components: list[ComponentEntryDict]
class DataEntryDict(TypedDict, total=False):
"""Serialized data model entry."""
key: str
valueString: str
valueNumber: float
valueBoolean: bool
valueMap: list[DataEntryDict]
class DataModelUpdateDict(TypedDict, total=False):
"""Serialized dataModelUpdate payload."""
surfaceId: str
path: str
contents: list[DataEntryDict]
class DeleteSurfaceDict(TypedDict):
"""Serialized deleteSurface payload."""
surfaceId: str
class A2UIMessageDict(TypedDict, total=False):
"""Serialized A2UI server-to-client message with exactly one key set."""
beginRendering: BeginRenderingDict
surfaceUpdate: SurfaceUpdateDict
dataModelUpdate: DataModelUpdateDict
deleteSurface: DeleteSurfaceDict
@dataclass
class A2UIConversationState:
"""Tracks active A2UI surfaces and data models across a conversation."""
active_surfaces: dict[str, dict[str, Any]] = field(default_factory=dict)
data_models: dict[str, list[dict[str, Any]]] = field(default_factory=dict)
last_a2ui_messages: list[dict[str, Any]] = field(default_factory=list)
active_surfaces: dict[str, dict[str, Any]] = Field(default_factory=dict)
data_models: dict[str, list[dict[str, Any]]] = Field(default_factory=dict)
last_a2ui_messages: list[A2UIMessageDict] = Field(default_factory=list)
def is_ready(self) -> bool:
"""Return True when at least one surface is active."""
@@ -74,12 +140,13 @@ class A2UIClientExtension:
def extract_state_from_history(
self, conversation_history: Sequence[Message]
) -> A2UIConversationState | None:
"""Scan conversation history for A2UI DataParts and track surface state."""
"""Scan conversation history for A2UI DataParts and track surface state.
When ``catalog_id`` is set, only surfaces matching that catalog are tracked.
"""
state = A2UIConversationState()
for message in conversation_history:
if not _has_parts(message):
continue
for part in message.parts:
root = part.root
if root.kind != "data":
@@ -97,6 +164,11 @@ class A2UIClientExtension:
if not surface_id:
continue
if self._catalog_id and "beginRendering" in data:
catalog_id = data["beginRendering"].get("catalogId")
if catalog_id and catalog_id != self._catalog_id:
continue
if "deleteSurface" in data:
state.active_surfaces.pop(surface_id, None)
state.data_models.pop(surface_id, None)
@@ -115,7 +187,7 @@ class A2UIClientExtension:
def augment_prompt(
self,
base_prompt: str,
conversation_state: A2UIConversationState | None,
_conversation_state: A2UIConversationState | None,
) -> str:
"""Append A2UI system prompt instructions to the base prompt."""
a2ui_prompt = build_a2ui_system_prompt(
@@ -131,26 +203,25 @@ class A2UIClientExtension:
) -> Any:
"""Extract and validate A2UI JSON from agent output.
Stores extracted A2UI messages on the conversation state and returns
the original response unchanged to preserve the AgentResponseProtocol
contract.
When ``allowed_components`` is set, components not in the allowlist are
logged and stripped from surface updates. Stores extracted A2UI messages
on the conversation state and returns the original response unchanged.
"""
text = (
agent_response if isinstance(agent_response, str) else str(agent_response)
)
a2ui_messages = _extract_and_validate(text)
if self._allowed_components:
allowed = set(self._allowed_components)
a2ui_messages = [_filter_components(msg, allowed) for msg in a2ui_messages]
if a2ui_messages and conversation_state is not None:
conversation_state.last_a2ui_messages = a2ui_messages
return agent_response
def _has_parts(message: Any) -> bool:
"""Check if a message has a parts attribute."""
return isinstance(getattr(message, "parts", None), list)
def _get_surface_id(data: dict[str, Any]) -> str | None:
"""Extract surfaceId from any A2UI message type."""
for key in ("beginRendering", "surfaceUpdate", "dataModelUpdate", "deleteSurface"):
@@ -162,7 +233,36 @@ def _get_surface_id(data: dict[str, Any]) -> str | None:
return None
def _extract_and_validate(text: str) -> list[dict[str, Any]]:
def _filter_components(msg: A2UIMessageDict, allowed: set[str]) -> A2UIMessageDict:
"""Strip components whose type is not in *allowed* from a surfaceUpdate."""
surface_update = msg.get("surfaceUpdate")
if not isinstance(surface_update, dict):
return msg
components = surface_update.get("components")
if not isinstance(components, list):
return msg
filtered = []
for entry in components:
component = entry.get("component", {})
component_types = set(component.keys())
disallowed = component_types - allowed
if disallowed:
logger.debug(
"Stripping disallowed component type(s) %s from surface update",
disallowed,
)
continue
filtered.append(entry)
if len(filtered) == len(components):
return msg
return {**msg, "surfaceUpdate": {**surface_update, "components": filtered}}
def _extract_and_validate(text: str) -> list[A2UIMessageDict]:
"""Extract A2UI JSON objects from text and validate them."""
return [
dumped
@@ -171,7 +271,7 @@ def _extract_and_validate(text: str) -> list[dict[str, Any]]:
]
def _try_validate(candidate: dict[str, Any]) -> dict[str, Any] | None:
def _try_validate(candidate: dict[str, Any]) -> A2UIMessageDict | None:
"""Validate a single A2UI candidate, returning None on failure."""
try:
msg = validate_a2ui_message(candidate)
@@ -181,4 +281,4 @@ def _try_validate(candidate: dict[str, Any]) -> dict[str, Any] | None:
exc_info=True,
)
return None
return msg.model_dump(by_alias=True, exclude_none=True)
return cast(A2UIMessageDict, msg.model_dump(by_alias=True, exclude_none=True))