mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-04-30 23:02:50 +00:00
refactor: improve typing and use instance state in A2UI client extension
This commit is contained in:
@@ -3,9 +3,11 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from collections.abc import Sequence
|
from collections.abc import Sequence
|
||||||
from dataclasses import dataclass, field
|
|
||||||
import logging
|
import logging
|
||||||
from typing import TYPE_CHECKING, Any
|
from typing import TYPE_CHECKING, Any, TypedDict, cast
|
||||||
|
|
||||||
|
from pydantic import Field
|
||||||
|
from pydantic.dataclasses import dataclass
|
||||||
|
|
||||||
from crewai.a2a.extensions.a2ui.models import extract_a2ui_json_objects
|
from crewai.a2a.extensions.a2ui.models import extract_a2ui_json_objects
|
||||||
from crewai.a2a.extensions.a2ui.prompt import build_a2ui_system_prompt
|
from crewai.a2a.extensions.a2ui.prompt import build_a2ui_system_prompt
|
||||||
@@ -25,13 +27,77 @@ if TYPE_CHECKING:
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class StylesDict(TypedDict, total=False):
|
||||||
|
"""Serialized surface styling."""
|
||||||
|
|
||||||
|
font: str
|
||||||
|
primaryColor: str
|
||||||
|
|
||||||
|
|
||||||
|
class ComponentEntryDict(TypedDict, total=False):
|
||||||
|
"""Serialized component entry in a surface update."""
|
||||||
|
|
||||||
|
id: str
|
||||||
|
weight: float
|
||||||
|
component: dict[str, Any]
|
||||||
|
|
||||||
|
|
||||||
|
class BeginRenderingDict(TypedDict, total=False):
|
||||||
|
"""Serialized beginRendering payload."""
|
||||||
|
|
||||||
|
surfaceId: str
|
||||||
|
root: str
|
||||||
|
catalogId: str
|
||||||
|
styles: StylesDict
|
||||||
|
|
||||||
|
|
||||||
|
class SurfaceUpdateDict(TypedDict, total=False):
|
||||||
|
"""Serialized surfaceUpdate payload."""
|
||||||
|
|
||||||
|
surfaceId: str
|
||||||
|
components: list[ComponentEntryDict]
|
||||||
|
|
||||||
|
|
||||||
|
class DataEntryDict(TypedDict, total=False):
|
||||||
|
"""Serialized data model entry."""
|
||||||
|
|
||||||
|
key: str
|
||||||
|
valueString: str
|
||||||
|
valueNumber: float
|
||||||
|
valueBoolean: bool
|
||||||
|
valueMap: list[DataEntryDict]
|
||||||
|
|
||||||
|
|
||||||
|
class DataModelUpdateDict(TypedDict, total=False):
|
||||||
|
"""Serialized dataModelUpdate payload."""
|
||||||
|
|
||||||
|
surfaceId: str
|
||||||
|
path: str
|
||||||
|
contents: list[DataEntryDict]
|
||||||
|
|
||||||
|
|
||||||
|
class DeleteSurfaceDict(TypedDict):
|
||||||
|
"""Serialized deleteSurface payload."""
|
||||||
|
|
||||||
|
surfaceId: str
|
||||||
|
|
||||||
|
|
||||||
|
class A2UIMessageDict(TypedDict, total=False):
|
||||||
|
"""Serialized A2UI server-to-client message with exactly one key set."""
|
||||||
|
|
||||||
|
beginRendering: BeginRenderingDict
|
||||||
|
surfaceUpdate: SurfaceUpdateDict
|
||||||
|
dataModelUpdate: DataModelUpdateDict
|
||||||
|
deleteSurface: DeleteSurfaceDict
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class A2UIConversationState:
|
class A2UIConversationState:
|
||||||
"""Tracks active A2UI surfaces and data models across a conversation."""
|
"""Tracks active A2UI surfaces and data models across a conversation."""
|
||||||
|
|
||||||
active_surfaces: dict[str, dict[str, Any]] = field(default_factory=dict)
|
active_surfaces: dict[str, dict[str, Any]] = Field(default_factory=dict)
|
||||||
data_models: dict[str, list[dict[str, Any]]] = field(default_factory=dict)
|
data_models: dict[str, list[dict[str, Any]]] = Field(default_factory=dict)
|
||||||
last_a2ui_messages: list[dict[str, Any]] = field(default_factory=list)
|
last_a2ui_messages: list[A2UIMessageDict] = Field(default_factory=list)
|
||||||
|
|
||||||
def is_ready(self) -> bool:
|
def is_ready(self) -> bool:
|
||||||
"""Return True when at least one surface is active."""
|
"""Return True when at least one surface is active."""
|
||||||
@@ -74,12 +140,13 @@ class A2UIClientExtension:
|
|||||||
def extract_state_from_history(
|
def extract_state_from_history(
|
||||||
self, conversation_history: Sequence[Message]
|
self, conversation_history: Sequence[Message]
|
||||||
) -> A2UIConversationState | None:
|
) -> A2UIConversationState | None:
|
||||||
"""Scan conversation history for A2UI DataParts and track surface state."""
|
"""Scan conversation history for A2UI DataParts and track surface state.
|
||||||
|
|
||||||
|
When ``catalog_id`` is set, only surfaces matching that catalog are tracked.
|
||||||
|
"""
|
||||||
state = A2UIConversationState()
|
state = A2UIConversationState()
|
||||||
|
|
||||||
for message in conversation_history:
|
for message in conversation_history:
|
||||||
if not _has_parts(message):
|
|
||||||
continue
|
|
||||||
for part in message.parts:
|
for part in message.parts:
|
||||||
root = part.root
|
root = part.root
|
||||||
if root.kind != "data":
|
if root.kind != "data":
|
||||||
@@ -97,6 +164,11 @@ class A2UIClientExtension:
|
|||||||
if not surface_id:
|
if not surface_id:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
if self._catalog_id and "beginRendering" in data:
|
||||||
|
catalog_id = data["beginRendering"].get("catalogId")
|
||||||
|
if catalog_id and catalog_id != self._catalog_id:
|
||||||
|
continue
|
||||||
|
|
||||||
if "deleteSurface" in data:
|
if "deleteSurface" in data:
|
||||||
state.active_surfaces.pop(surface_id, None)
|
state.active_surfaces.pop(surface_id, None)
|
||||||
state.data_models.pop(surface_id, None)
|
state.data_models.pop(surface_id, None)
|
||||||
@@ -115,7 +187,7 @@ class A2UIClientExtension:
|
|||||||
def augment_prompt(
|
def augment_prompt(
|
||||||
self,
|
self,
|
||||||
base_prompt: str,
|
base_prompt: str,
|
||||||
conversation_state: A2UIConversationState | None,
|
_conversation_state: A2UIConversationState | None,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Append A2UI system prompt instructions to the base prompt."""
|
"""Append A2UI system prompt instructions to the base prompt."""
|
||||||
a2ui_prompt = build_a2ui_system_prompt(
|
a2ui_prompt = build_a2ui_system_prompt(
|
||||||
@@ -131,26 +203,25 @@ class A2UIClientExtension:
|
|||||||
) -> Any:
|
) -> Any:
|
||||||
"""Extract and validate A2UI JSON from agent output.
|
"""Extract and validate A2UI JSON from agent output.
|
||||||
|
|
||||||
Stores extracted A2UI messages on the conversation state and returns
|
When ``allowed_components`` is set, components not in the allowlist are
|
||||||
the original response unchanged to preserve the AgentResponseProtocol
|
logged and stripped from surface updates. Stores extracted A2UI messages
|
||||||
contract.
|
on the conversation state and returns the original response unchanged.
|
||||||
"""
|
"""
|
||||||
text = (
|
text = (
|
||||||
agent_response if isinstance(agent_response, str) else str(agent_response)
|
agent_response if isinstance(agent_response, str) else str(agent_response)
|
||||||
)
|
)
|
||||||
a2ui_messages = _extract_and_validate(text)
|
a2ui_messages = _extract_and_validate(text)
|
||||||
|
|
||||||
|
if self._allowed_components:
|
||||||
|
allowed = set(self._allowed_components)
|
||||||
|
a2ui_messages = [_filter_components(msg, allowed) for msg in a2ui_messages]
|
||||||
|
|
||||||
if a2ui_messages and conversation_state is not None:
|
if a2ui_messages and conversation_state is not None:
|
||||||
conversation_state.last_a2ui_messages = a2ui_messages
|
conversation_state.last_a2ui_messages = a2ui_messages
|
||||||
|
|
||||||
return agent_response
|
return agent_response
|
||||||
|
|
||||||
|
|
||||||
def _has_parts(message: Any) -> bool:
|
|
||||||
"""Check if a message has a parts attribute."""
|
|
||||||
return isinstance(getattr(message, "parts", None), list)
|
|
||||||
|
|
||||||
|
|
||||||
def _get_surface_id(data: dict[str, Any]) -> str | None:
|
def _get_surface_id(data: dict[str, Any]) -> str | None:
|
||||||
"""Extract surfaceId from any A2UI message type."""
|
"""Extract surfaceId from any A2UI message type."""
|
||||||
for key in ("beginRendering", "surfaceUpdate", "dataModelUpdate", "deleteSurface"):
|
for key in ("beginRendering", "surfaceUpdate", "dataModelUpdate", "deleteSurface"):
|
||||||
@@ -162,7 +233,36 @@ def _get_surface_id(data: dict[str, Any]) -> str | None:
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
def _extract_and_validate(text: str) -> list[dict[str, Any]]:
|
def _filter_components(msg: A2UIMessageDict, allowed: set[str]) -> A2UIMessageDict:
|
||||||
|
"""Strip components whose type is not in *allowed* from a surfaceUpdate."""
|
||||||
|
surface_update = msg.get("surfaceUpdate")
|
||||||
|
if not isinstance(surface_update, dict):
|
||||||
|
return msg
|
||||||
|
|
||||||
|
components = surface_update.get("components")
|
||||||
|
if not isinstance(components, list):
|
||||||
|
return msg
|
||||||
|
|
||||||
|
filtered = []
|
||||||
|
for entry in components:
|
||||||
|
component = entry.get("component", {})
|
||||||
|
component_types = set(component.keys())
|
||||||
|
disallowed = component_types - allowed
|
||||||
|
if disallowed:
|
||||||
|
logger.debug(
|
||||||
|
"Stripping disallowed component type(s) %s from surface update",
|
||||||
|
disallowed,
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
filtered.append(entry)
|
||||||
|
|
||||||
|
if len(filtered) == len(components):
|
||||||
|
return msg
|
||||||
|
|
||||||
|
return {**msg, "surfaceUpdate": {**surface_update, "components": filtered}}
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_and_validate(text: str) -> list[A2UIMessageDict]:
|
||||||
"""Extract A2UI JSON objects from text and validate them."""
|
"""Extract A2UI JSON objects from text and validate them."""
|
||||||
return [
|
return [
|
||||||
dumped
|
dumped
|
||||||
@@ -171,7 +271,7 @@ def _extract_and_validate(text: str) -> list[dict[str, Any]]:
|
|||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
def _try_validate(candidate: dict[str, Any]) -> dict[str, Any] | None:
|
def _try_validate(candidate: dict[str, Any]) -> A2UIMessageDict | None:
|
||||||
"""Validate a single A2UI candidate, returning None on failure."""
|
"""Validate a single A2UI candidate, returning None on failure."""
|
||||||
try:
|
try:
|
||||||
msg = validate_a2ui_message(candidate)
|
msg = validate_a2ui_message(candidate)
|
||||||
@@ -181,4 +281,4 @@ def _try_validate(candidate: dict[str, Any]) -> dict[str, Any] | None:
|
|||||||
exc_info=True,
|
exc_info=True,
|
||||||
)
|
)
|
||||||
return None
|
return None
|
||||||
return msg.model_dump(by_alias=True, exclude_none=True)
|
return cast(A2UIMessageDict, msg.model_dump(by_alias=True, exclude_none=True))
|
||||||
|
|||||||
Reference in New Issue
Block a user