feat: enhance ChatTextArea with @mention autocomplete and improve UI feedback

- Updated ChatTextArea to support @mention autocomplete using Tab for completion.
- Added MentionChanged message to handle autocomplete state changes.
- Improved user experience by displaying a hint for available mentions.
- Enhanced error handling in AgentTUI for agent message timeouts.
- Updated rendering logic to ensure proper display of system messages with Rich markup.
This commit is contained in:
Joao Moura
2026-05-13 17:17:55 -04:00
parent 84568860c3
commit 8f3196e1cf

View File

@@ -71,7 +71,7 @@ except ImportError:
class ChatTextArea(TextArea):
"""Multiline chat input: Enter submits, Shift+Enter inserts a newline."""
"""Multiline chat input: Enter submits, Shift+Enter inserts newline, Tab completes @mentions."""
BINDINGS = [
Binding("enter", "submit", "Send", show=False),
@@ -85,10 +85,72 @@ class ChatTextArea(TextArea):
self.text_area = text_area
self.value = value
class MentionChanged(Message):
"""Posted when @mention autocomplete state changes."""
def __init__(self, prefix: str, matches: list[str]) -> None:
super().__init__()
self.prefix = prefix
self.matches = matches
def __init__(self, agent_names: list[str] | None = None, **kwargs: Any) -> None:
super().__init__(**kwargs)
self._agent_names = agent_names or []
self._last_mention_prefix: str | None = None
def _get_mention_context(self) -> tuple[int, int, str] | None:
"""Return (row, at_col, partial) if cursor is inside an @mention."""
row, col = self.cursor_location
lines = self.text.split("\n")
if row >= len(lines):
return None
line_to_cursor = lines[row][:col]
at_idx = line_to_cursor.rfind("@")
if at_idx == -1:
return None
after = line_to_cursor[at_idx + 1 :]
if " " in after:
return None
return row, at_idx, after.lower()
def _get_matches(self, prefix: str) -> list[str]:
if not prefix:
return self._agent_names[:]
return [n for n in self._agent_names if n.lower().startswith(prefix)]
def _emit_mention_state(self) -> None:
ctx = self._get_mention_context()
if ctx is None:
if self._last_mention_prefix is not None:
self._last_mention_prefix = None
self.post_message(self.MentionChanged("", []))
return
_, _, prefix = ctx
if prefix != self._last_mention_prefix:
self._last_mention_prefix = prefix
matches = self._get_matches(prefix)
self.post_message(self.MentionChanged(prefix, matches))
def action_submit(self) -> None:
text = self.text
self.clear()
self._last_mention_prefix = None
self.post_message(self.Submitted(self, text))
self.post_message(self.MentionChanged("", []))
def action_complete(self) -> None:
"""Complete the current @mention with Tab."""
ctx = self._get_mention_context()
if ctx is None:
return
row, at_col, prefix = ctx
matches = self._get_matches(prefix)
if not matches:
return
_, col = self.cursor_location
self.replace(matches[0] + " ", start=(row, at_col + 1), end=(row, col))
self._last_mention_prefix = None
self.post_message(self.MentionChanged("", []))
async def _on_key(self, event: events.Key) -> None:
if event.key == "shift+enter":
@@ -99,7 +161,15 @@ class ChatTextArea(TextArea):
event.prevent_default()
self.action_submit()
return
if event.key == "tab":
event.prevent_default()
self.action_complete()
return
if event.key == "escape":
self._last_mention_prefix = None
self.post_message(self.MentionChanged("", []))
await super()._on_key(event)
self._emit_mention_state()
_CORAL = "#eb6658"
@@ -198,7 +268,7 @@ class ThinkingIndicator(Static):
meta_parts.append(f"{step_elapsed:.1f}s")
meta = " · ".join(meta_parts)
suffix = f" ({meta})" if meta else ""
done_line = f" [{_DIM}]✓ {self._current_status}{suffix}[/]"
done_line = f" [{_DIM}]✓ {_safe_render(self._current_status)}{suffix}[/]"
if not any(self._current_status in s for s in self._steps):
self._steps.append(done_line)
if len(self._steps) > 6:
@@ -233,13 +303,18 @@ class ThinkingIndicator(Static):
lines: list[str] = []
for step in self._steps:
lines.append(step)
status_esc = _safe_render(self._current_status)
current = (
f"[{_CORAL}]{ch}[/] [{_DIM}]{self._agent_name}[/] {self._current_status}"
f"[{_CORAL}]{ch}[/] [{_DIM}]{self._agent_name}[/] {status_esc}"
)
if self._tokens:
current += f" {self._tokens}"
lines.append(current)
self.update("\n".join(lines))
content = "\n".join(lines)
try:
self.update(content)
except Exception:
self.update(_safe_render(content))
class CreateRoomScreen(ModalScreen[dict[str, Any] | None]):
@@ -497,6 +572,15 @@ class AgentTUI(App[None]):
margin: 0 0 0 0;
height: auto;
}}
#completion-hint {{
display: none;
height: auto;
max-height: 2;
padding: 0 2;
margin: 0 1;
background: #333333;
color: {_TEAL};
}}
"""
def __init__(
@@ -543,6 +627,7 @@ class AgentTUI(App[None]):
)
with Vertical(id="chat-area"):
yield VerticalScroll(id="chat-scroll")
yield Static("", id="completion-hint")
with Horizontal(id="input-row"):
yield ChatTextArea(
id="chat-input",
@@ -645,7 +730,9 @@ class AgentTUI(App[None]):
self._update_placeholder()
self._load_history_from_disk()
self._render_chat()
self.query_one("#chat-input", ChatTextArea).focus()
chat_input = self.query_one("#chat-input", ChatTextArea)
chat_input._agent_names = self._agent_names
chat_input.focus()
try:
from crewai.new_agent.scheduler import TaskScheduler
@@ -736,6 +823,22 @@ class AgentTUI(App[None]):
self._update_placeholder()
self._render_chat()
# ── @mention autocomplete hint ──
def on_chat_text_area_mention_changed(
self, event: ChatTextArea.MentionChanged
) -> None:
try:
hint = self.query_one("#completion-hint", Static)
except Exception:
return
if not event.matches:
hint.display = False
return
names = " ".join(f"@{n}" for n in event.matches[:6])
hint.update(f"Tab to complete: {names}")
hint.display = True
# ── Message routing ──
async def on_chat_text_area_submitted(self, event: ChatTextArea.Submitted) -> None:
@@ -1264,8 +1367,17 @@ class AgentTUI(App[None]):
ValueError(f"Could not load '{target}'{detail}"),
)
msg = room_context if room_context else text
resp = await asyncio.to_thread(agent.message, msg)
resp = await asyncio.wait_for(
asyncio.to_thread(agent.message, msg),
timeout=300.0,
)
return target, resp, None
except asyncio.TimeoutError:
return (
target,
None,
TimeoutError(f"Agent '{target}' timed out after 5 minutes"),
)
except Exception as exc:
return target, None, exc
@@ -1275,9 +1387,9 @@ class AgentTUI(App[None]):
await self._safe_remove(indicators.get(target)) # type: ignore[arg-type]
if error or response is None:
msg = (
f"Error from {target}: {error}"
f"Error from {_safe_render(target)}: {_safe_render(str(error))}"
if error
else f"Could not load agent '{target}'."
else f"Could not load agent '{_safe_render(target)}'."
)
self._append_msg(room, "system", msg)
if self._current_room == room:
@@ -1337,14 +1449,17 @@ class AgentTUI(App[None]):
self._mount_sys("No agent available.")
return
agent = await asyncio.to_thread(self._get_or_create_agent, target)
agent = await asyncio.wait_for(
asyncio.to_thread(self._get_or_create_agent, target),
timeout=60.0,
)
if agent is None:
await self._safe_remove(thinking)
error_detail = getattr(self, "_last_agent_error", "")
if error_detail:
msg = f"Could not load agent '{target}': {error_detail}"
msg = f"Could not load agent '{_safe_render(target)}': {_safe_render(error_detail)}"
else:
msg = f"Could not load agent '{target}'."
msg = f"Could not load agent '{_safe_render(target)}'."
self._append_msg(room, "system", msg)
if self._current_room == room:
self._mount_sys(msg)
@@ -1365,7 +1480,6 @@ class AgentTUI(App[None]):
# Stream response token-by-token
scroll = self.query_one("#chat-scroll", VerticalScroll)
follow_tail = self._is_near_bottom(scroll)
bubble: ChatBubble | None = None
accumulated = ""
stream_start = time.monotonic()
@@ -1378,7 +1492,7 @@ class AgentTUI(App[None]):
mk = f"[bold {_CORAL}]{target}[/]\n{rendered}"
if final:
if metadata:
mk += f"\n\n[{_DIM}]{metadata}[/]"
mk += f"\n\n[{_DIM}]{_safe_render(metadata)}[/]"
else:
cursor = f"[{_CORAL}]▎[/]"
elapsed = time.monotonic() - stream_start
@@ -1387,21 +1501,46 @@ class AgentTUI(App[None]):
mk += f"{cursor}\n\n{progress}"
return mk
async for chunk in agent.stream(message_text):
# Timeout-protected streaming: prevents UI freeze if LLM stalls
stream = agent.stream(message_text)
first_chunk = True
while True:
try:
timeout = 180.0 if first_chunk else 120.0
chunk = await asyncio.wait_for(
anext(stream), timeout=timeout # type: ignore[arg-type]
)
first_chunk = False
except StopAsyncIteration:
break
except asyncio.TimeoutError:
accumulated += "\n\n[Response timed out]"
break
except Exception:
break
accumulated += chunk
stream_chars += len(chunk)
if bubble is None and self._current_room == room:
bubble = ChatBubble(
_stream_markup(accumulated), classes="agent-bubble"
)
# Insert bubble before thinking so indicator stays at bottom
scroll.mount(bubble, before=thinking)
if follow_tail:
try:
bubble = ChatBubble(
_stream_markup(accumulated), classes="agent-bubble"
)
scroll.mount(bubble, before=thinking)
except Exception:
bubble = ChatBubble(
_safe_render(accumulated), classes="agent-bubble"
)
scroll.mount(bubble, before=thinking)
if self._is_near_bottom(scroll):
scroll.scroll_end(animate=False)
elif bubble is not None:
bubble.update(_stream_markup(accumulated))
if follow_tail:
try:
bubble.update(_stream_markup(accumulated))
except Exception:
bubble.update(_safe_render(accumulated))
if self._is_near_bottom(scroll):
scroll.scroll_end(animate=False)
# Remove cursor, add final metadata
@@ -1414,24 +1553,26 @@ class AgentTUI(App[None]):
response, "output_tokens", 0
):
meta_parts.append(
f"{response.input_tokens or 0:,} "
f"{response.output_tokens or 0:,} tokens"
f"~{response.output_tokens or 0:,} tokens"
)
if getattr(response, "response_time_ms", 0):
meta_parts.append(f"{response.response_time_ms / 1000:.1f}s")
metadata = " · ".join(meta_parts)
if bubble is not None:
bubble.update(
_stream_markup(accumulated, final=True, metadata=metadata)
)
try:
bubble.update(
_stream_markup(accumulated, final=True, metadata=metadata)
)
except Exception:
bubble.update(_safe_render(accumulated))
content = accumulated or (response.content if response else "")
self._append_msg(room, target, content, metadata)
except Exception as e:
await self._safe_remove(thinking)
msg = f"Error: {e}"
msg = f"Error: {_safe_render(str(e))}"
self._append_msg(room, "system", msg)
if self._current_room == room:
self._mount_sys(msg)
@@ -1486,18 +1627,42 @@ class AgentTUI(App[None]):
scroll.scroll_end(animate=False)
def _mount_sys(self, text: str) -> None:
self._mount_bubble("system", text)
"""Mount a system message. Accepts pre-formatted Rich markup."""
scroll = self.query_one("#chat-scroll", VerticalScroll)
near_bottom = self._is_near_bottom(scroll)
try:
bubble = ChatBubble(f"[{_DIM}]{text}[/]", classes="system-bubble")
except Exception:
bubble = ChatBubble(_safe_render(text), classes="system-bubble")
scroll.mount(bubble)
if near_bottom:
scroll.scroll_end(animate=False)
def _highlight_mentions(self, escaped_text: str) -> str:
"""Highlight @agent_name mentions in pre-escaped text."""
for name in self._agent_names:
escaped_text = re.sub(
r"@" + re.escape(name) + r"\b",
f"[bold {_TEAL}]@{name}[/]",
escaped_text,
flags=re.IGNORECASE,
)
return escaped_text
def _make_bubble(self, sender: str, content: str, metadata: str = "") -> ChatBubble:
if sender == "You":
markup = f"[bold #e8e8e8]You[/]\n{_safe_render(content)}"
rendered = self._highlight_mentions(_safe_render(content))
markup = f"[bold #e8e8e8]You[/]\n{rendered}"
return ChatBubble(markup, classes="user-bubble")
if sender == "system":
markup = f"[dim italic]{_safe_render(content)}[/]"
return ChatBubble(markup, classes="system-bubble")
markup = f"[bold {_CORAL}]{sender}[/]\n{_safe_render(content)}"
try:
return ChatBubble(f"[{_DIM}]{content}[/]", classes="system-bubble")
except Exception:
return ChatBubble(_safe_render(content), classes="system-bubble")
rendered = _safe_render(content)
markup = f"[bold {_CORAL}]{sender}[/]\n{rendered}"
if metadata:
markup += f"\n\n[{_DIM}]{metadata}[/]"
markup += f"\n\n[{_DIM}]{_safe_render(metadata)}[/]"
return ChatBubble(markup, classes="agent-bubble")
def _render_chat(self) -> None: