mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-04-07 03:28:29 +00:00
refactor: improve typing and use instance state in A2UI client extension
This commit is contained in:
@@ -3,9 +3,11 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Sequence
|
||||
from dataclasses import dataclass, field
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Any
|
||||
from typing import TYPE_CHECKING, Any, TypedDict, cast
|
||||
|
||||
from pydantic import Field
|
||||
from pydantic.dataclasses import dataclass
|
||||
|
||||
from crewai.a2a.extensions.a2ui.models import extract_a2ui_json_objects
|
||||
from crewai.a2a.extensions.a2ui.prompt import build_a2ui_system_prompt
|
||||
@@ -25,13 +27,77 @@ if TYPE_CHECKING:
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class StylesDict(TypedDict, total=False):
|
||||
"""Serialized surface styling."""
|
||||
|
||||
font: str
|
||||
primaryColor: str
|
||||
|
||||
|
||||
class ComponentEntryDict(TypedDict, total=False):
|
||||
"""Serialized component entry in a surface update."""
|
||||
|
||||
id: str
|
||||
weight: float
|
||||
component: dict[str, Any]
|
||||
|
||||
|
||||
class BeginRenderingDict(TypedDict, total=False):
|
||||
"""Serialized beginRendering payload."""
|
||||
|
||||
surfaceId: str
|
||||
root: str
|
||||
catalogId: str
|
||||
styles: StylesDict
|
||||
|
||||
|
||||
class SurfaceUpdateDict(TypedDict, total=False):
|
||||
"""Serialized surfaceUpdate payload."""
|
||||
|
||||
surfaceId: str
|
||||
components: list[ComponentEntryDict]
|
||||
|
||||
|
||||
class DataEntryDict(TypedDict, total=False):
|
||||
"""Serialized data model entry."""
|
||||
|
||||
key: str
|
||||
valueString: str
|
||||
valueNumber: float
|
||||
valueBoolean: bool
|
||||
valueMap: list[DataEntryDict]
|
||||
|
||||
|
||||
class DataModelUpdateDict(TypedDict, total=False):
|
||||
"""Serialized dataModelUpdate payload."""
|
||||
|
||||
surfaceId: str
|
||||
path: str
|
||||
contents: list[DataEntryDict]
|
||||
|
||||
|
||||
class DeleteSurfaceDict(TypedDict):
|
||||
"""Serialized deleteSurface payload."""
|
||||
|
||||
surfaceId: str
|
||||
|
||||
|
||||
class A2UIMessageDict(TypedDict, total=False):
|
||||
"""Serialized A2UI server-to-client message with exactly one key set."""
|
||||
|
||||
beginRendering: BeginRenderingDict
|
||||
surfaceUpdate: SurfaceUpdateDict
|
||||
dataModelUpdate: DataModelUpdateDict
|
||||
deleteSurface: DeleteSurfaceDict
|
||||
|
||||
|
||||
@dataclass
|
||||
class A2UIConversationState:
|
||||
"""Tracks active A2UI surfaces and data models across a conversation."""
|
||||
|
||||
active_surfaces: dict[str, dict[str, Any]] = field(default_factory=dict)
|
||||
data_models: dict[str, list[dict[str, Any]]] = field(default_factory=dict)
|
||||
last_a2ui_messages: list[dict[str, Any]] = field(default_factory=list)
|
||||
active_surfaces: dict[str, dict[str, Any]] = Field(default_factory=dict)
|
||||
data_models: dict[str, list[dict[str, Any]]] = Field(default_factory=dict)
|
||||
last_a2ui_messages: list[A2UIMessageDict] = Field(default_factory=list)
|
||||
|
||||
def is_ready(self) -> bool:
|
||||
"""Return True when at least one surface is active."""
|
||||
@@ -74,12 +140,13 @@ class A2UIClientExtension:
|
||||
def extract_state_from_history(
|
||||
self, conversation_history: Sequence[Message]
|
||||
) -> A2UIConversationState | None:
|
||||
"""Scan conversation history for A2UI DataParts and track surface state."""
|
||||
"""Scan conversation history for A2UI DataParts and track surface state.
|
||||
|
||||
When ``catalog_id`` is set, only surfaces matching that catalog are tracked.
|
||||
"""
|
||||
state = A2UIConversationState()
|
||||
|
||||
for message in conversation_history:
|
||||
if not _has_parts(message):
|
||||
continue
|
||||
for part in message.parts:
|
||||
root = part.root
|
||||
if root.kind != "data":
|
||||
@@ -97,6 +164,11 @@ class A2UIClientExtension:
|
||||
if not surface_id:
|
||||
continue
|
||||
|
||||
if self._catalog_id and "beginRendering" in data:
|
||||
catalog_id = data["beginRendering"].get("catalogId")
|
||||
if catalog_id and catalog_id != self._catalog_id:
|
||||
continue
|
||||
|
||||
if "deleteSurface" in data:
|
||||
state.active_surfaces.pop(surface_id, None)
|
||||
state.data_models.pop(surface_id, None)
|
||||
@@ -115,7 +187,7 @@ class A2UIClientExtension:
|
||||
def augment_prompt(
|
||||
self,
|
||||
base_prompt: str,
|
||||
conversation_state: A2UIConversationState | None,
|
||||
_conversation_state: A2UIConversationState | None,
|
||||
) -> str:
|
||||
"""Append A2UI system prompt instructions to the base prompt."""
|
||||
a2ui_prompt = build_a2ui_system_prompt(
|
||||
@@ -131,26 +203,25 @@ class A2UIClientExtension:
|
||||
) -> Any:
|
||||
"""Extract and validate A2UI JSON from agent output.
|
||||
|
||||
Stores extracted A2UI messages on the conversation state and returns
|
||||
the original response unchanged to preserve the AgentResponseProtocol
|
||||
contract.
|
||||
When ``allowed_components`` is set, components not in the allowlist are
|
||||
logged and stripped from surface updates. Stores extracted A2UI messages
|
||||
on the conversation state and returns the original response unchanged.
|
||||
"""
|
||||
text = (
|
||||
agent_response if isinstance(agent_response, str) else str(agent_response)
|
||||
)
|
||||
a2ui_messages = _extract_and_validate(text)
|
||||
|
||||
if self._allowed_components:
|
||||
allowed = set(self._allowed_components)
|
||||
a2ui_messages = [_filter_components(msg, allowed) for msg in a2ui_messages]
|
||||
|
||||
if a2ui_messages and conversation_state is not None:
|
||||
conversation_state.last_a2ui_messages = a2ui_messages
|
||||
|
||||
return agent_response
|
||||
|
||||
|
||||
def _has_parts(message: Any) -> bool:
|
||||
"""Check if a message has a parts attribute."""
|
||||
return isinstance(getattr(message, "parts", None), list)
|
||||
|
||||
|
||||
def _get_surface_id(data: dict[str, Any]) -> str | None:
|
||||
"""Extract surfaceId from any A2UI message type."""
|
||||
for key in ("beginRendering", "surfaceUpdate", "dataModelUpdate", "deleteSurface"):
|
||||
@@ -162,7 +233,36 @@ def _get_surface_id(data: dict[str, Any]) -> str | None:
|
||||
return None
|
||||
|
||||
|
||||
def _extract_and_validate(text: str) -> list[dict[str, Any]]:
|
||||
def _filter_components(msg: A2UIMessageDict, allowed: set[str]) -> A2UIMessageDict:
|
||||
"""Strip components whose type is not in *allowed* from a surfaceUpdate."""
|
||||
surface_update = msg.get("surfaceUpdate")
|
||||
if not isinstance(surface_update, dict):
|
||||
return msg
|
||||
|
||||
components = surface_update.get("components")
|
||||
if not isinstance(components, list):
|
||||
return msg
|
||||
|
||||
filtered = []
|
||||
for entry in components:
|
||||
component = entry.get("component", {})
|
||||
component_types = set(component.keys())
|
||||
disallowed = component_types - allowed
|
||||
if disallowed:
|
||||
logger.debug(
|
||||
"Stripping disallowed component type(s) %s from surface update",
|
||||
disallowed,
|
||||
)
|
||||
continue
|
||||
filtered.append(entry)
|
||||
|
||||
if len(filtered) == len(components):
|
||||
return msg
|
||||
|
||||
return {**msg, "surfaceUpdate": {**surface_update, "components": filtered}}
|
||||
|
||||
|
||||
def _extract_and_validate(text: str) -> list[A2UIMessageDict]:
|
||||
"""Extract A2UI JSON objects from text and validate them."""
|
||||
return [
|
||||
dumped
|
||||
@@ -171,7 +271,7 @@ def _extract_and_validate(text: str) -> list[dict[str, Any]]:
|
||||
]
|
||||
|
||||
|
||||
def _try_validate(candidate: dict[str, Any]) -> dict[str, Any] | None:
|
||||
def _try_validate(candidate: dict[str, Any]) -> A2UIMessageDict | None:
|
||||
"""Validate a single A2UI candidate, returning None on failure."""
|
||||
try:
|
||||
msg = validate_a2ui_message(candidate)
|
||||
@@ -181,4 +281,4 @@ def _try_validate(candidate: dict[str, Any]) -> dict[str, Any] | None:
|
||||
exc_info=True,
|
||||
)
|
||||
return None
|
||||
return msg.model_dump(by_alias=True, exclude_none=True)
|
||||
return cast(A2UIMessageDict, msg.model_dump(by_alias=True, exclude_none=True))
|
||||
|
||||
Reference in New Issue
Block a user