mirror of https://github.com/crewAIInc/crewAI.git
synced 2026-01-08 15:48:29 +00:00
refactor: improve code structure and logging in LiteAgent and ConsoleFormatter
- Refactored imports in lite_agent.py for better readability.
- Enhanced guardrail property initialization in LiteAgent.
- Updated logging functionality to emit AgentLogsExecutionEvent for better tracking.
- Modified ConsoleFormatter to include tool arguments and final output in status updates.
- Improved output formatting for long text in ConsoleFormatter.
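Note: a minimal sketch (not part of this commit) of how a custom listener could consume the new AgentLogsExecutionEvent. The event fields and the (source, event) handler signature mirror the hunks below; the crewai_event_bus import path is an assumption.

from crewai.utilities.events.agent_events import AgentLogsExecutionEvent
from crewai.utilities.events.crewai_event_bus import crewai_event_bus  # assumed module path


@crewai_event_bus.on(AgentLogsExecutionEvent)
def on_agent_logs(source, event):
    # source is the emitting LiteAgent; the field names match the emit() call in lite_agent.py below
    if event.verbose:
        print(f"[{event.agent_role}] {event.formatted_answer}")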
@@ -1,14 +1,33 @@
 import asyncio
 import inspect
 import uuid
-from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union, cast, get_args, get_origin
+from typing import (
+    Any,
+    Callable,
+    Dict,
+    List,
+    Optional,
+    Tuple,
+    Type,
+    Union,
+    cast,
+    get_args,
+    get_origin,
+)
 
 try:
     from typing import Self
 except ImportError:
     from typing_extensions import Self
 
-from pydantic import BaseModel, Field, InstanceOf, PrivateAttr, model_validator, field_validator
+from pydantic import (
+    BaseModel,
+    Field,
+    InstanceOf,
+    PrivateAttr,
+    model_validator,
+    field_validator,
+)
 
 from crewai.agents.agent_builder.base_agent import BaseAgent
 from crewai.agents.agent_builder.utilities.base_token_process import TokenProcess
@@ -39,10 +58,10 @@ from crewai.utilities.agent_utils import (
     parse_tools,
     process_llm_response,
     render_text_description_and_args,
-    show_agent_logs,
 )
 from crewai.utilities.converter import generate_model_description
 from crewai.utilities.events.agent_events import (
+    AgentLogsExecutionEvent,
     LiteAgentExecutionCompletedEvent,
     LiteAgentExecutionErrorEvent,
     LiteAgentExecutionStartedEvent,
@@ -153,9 +172,11 @@ class LiteAgent(FlowTrackable, BaseModel):
     )
 
     # Guardrail Properties
-    guardrail: Optional[Union[Callable[[LiteAgentOutput], Tuple[bool, Any]], str]] = Field(
-        default=None,
-        description="Function or string description of a guardrail to validate agent output"
+    guardrail: Optional[Union[Callable[[LiteAgentOutput], Tuple[bool, Any]], str]] = (
+        Field(
+            default=None,
+            description="Function or string description of a guardrail to validate agent output",
+        )
     )
     guardrail_max_retries: int = Field(
         default=3, description="Maximum number of retries when guardrail fails"
@@ -181,7 +202,6 @@ class LiteAgent(FlowTrackable, BaseModel):
     _guardrail: Optional[Callable] = PrivateAttr(default=None)
     _guardrail_retry_count: int = PrivateAttr(default=0)
 
-
     @model_validator(mode="after")
     def setup_llm(self):
         """Set up the LLM and other components after initialization."""
@@ -208,17 +228,18 @@ class LiteAgent(FlowTrackable, BaseModel):
             self._guardrail = self.guardrail
         elif isinstance(self.guardrail, str):
             from crewai.tasks.llm_guardrail import LLMGuardrail
 
             assert isinstance(self.llm, LLM)
 
-            self._guardrail = LLMGuardrail(
-                description=self.guardrail, llm=self.llm
-            )
+            self._guardrail = LLMGuardrail(description=self.guardrail, llm=self.llm)
 
         return self
 
     @field_validator("guardrail", mode="before")
     @classmethod
-    def validate_guardrail_function(cls, v: Optional[Union[Callable, str]]) -> Optional[Union[Callable, str]]:
+    def validate_guardrail_function(
+        cls, v: Optional[Union[Callable, str]]
+    ) -> Optional[Union[Callable, str]]:
         """Validate that the guardrail function has the correct signature.
 
         If v is a callable, validate that it has the correct signature.
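Note: a minimal sketch of a callable guardrail matching the Callable[[LiteAgentOutput], Tuple[bool, Any]] annotation above. How the output text is read (str(output)) is an assumption; LiteAgentOutput's attributes are not shown in this diff.

from typing import Any, Tuple


def no_todo_markers(output) -> Tuple[bool, Any]:
    # Returning (False, error) makes LiteAgent append the error as a user message and
    # retry, up to guardrail_max_retries; (True, result) accepts the output.
    text = str(output)  # assumption: LiteAgentOutput has a readable string form
    if "TODO" in text:
        return False, "Remove TODO placeholders before finishing."
    return True, text

Passing a plain string instead (e.g. guardrail="the answer must cite a source") takes the LLMGuardrail branch shown in this hunk.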
@@ -330,9 +351,7 @@ class LiteAgent(FlowTrackable, BaseModel):
         if self.response_format:
             try:
                 # Cast to BaseModel to ensure type safety
-                result = self.response_format.model_validate_json(
-                    agent_finish.output
-                )
+                result = self.response_format.model_validate_json(agent_finish.output)
                 if isinstance(result, BaseModel):
                     formatted_result = result
             except Exception as e:
@@ -357,15 +376,15 @@ class LiteAgent(FlowTrackable, BaseModel):
         guardrail_result = process_guardrail(
             output=output,
             guardrail=self._guardrail,
-            retry_count=self._guardrail_retry_count
+            retry_count=self._guardrail_retry_count,
         )
 
         if not guardrail_result.success:
             if self._guardrail_retry_count >= self.guardrail_max_retries:
-                raise Exception(
-                    f"Agent's guardrail failed validation after {self.guardrail_max_retries} retries. "
-                    f"Last error: {guardrail_result.error}"
-                )
+                raise Exception(
+                    f"Agent's guardrail failed validation after {self.guardrail_max_retries} retries. "
+                    f"Last error: {guardrail_result.error}"
+                )
             self._guardrail_retry_count += 1
             if self.verbose:
                 self._printer.print(
@@ -373,10 +392,13 @@ class LiteAgent(FlowTrackable, BaseModel):
                     f"\n{guardrail_result.error}"
                 )
 
-            self._messages.append({
-                "role": "user",
-                "content": guardrail_result.error or "Guardrail validation failed"
-            })
+            self._messages.append(
+                {
+                    "role": "user",
+                    "content": guardrail_result.error
+                    or "Guardrail validation failed",
+                }
+            )
 
             return self._execute_core(agent_info=agent_info)
 
@@ -580,11 +602,13 @@ class LiteAgent(FlowTrackable, BaseModel):
 
     def _show_logs(self, formatted_answer: Union[AgentAction, AgentFinish]):
         """Show logs for the agent's execution."""
-        show_agent_logs(
-            printer=self._printer,
-            agent_role=self.role,
-            formatted_answer=formatted_answer,
-            verbose=self.verbose,
+        crewai_event_bus.emit(
+            self,
+            AgentLogsExecutionEvent(
+                agent_role=self.role,
+                formatted_answer=formatted_answer,
+                verbose=self.verbose,
+            ),
         )
 
     def _append_message(self, text: str, role: str = "assistant") -> None:
@@ -110,6 +110,7 @@ class EventListener(BaseEventListener):
                 event.crew_name or "Crew",
                 source.id,
                 "completed",
+                final_string_output,
             )
 
         @crewai_event_bus.on(CrewKickoffFailedEvent)
@@ -288,6 +289,7 @@ class EventListener(BaseEventListener):
             if isinstance(source, LLM):
                 self.formatter.handle_llm_tool_usage_started(
                     event.tool_name,
+                    event.tool_args,
                 )
             else:
                 self.formatter.handle_tool_usage_started(
@@ -41,7 +41,12 @@ class ConsoleFormatter:
         )
 
     def create_status_content(
-        self, title: str, name: str, status_style: str = "blue", **fields
+        self,
+        title: str,
+        name: str,
+        status_style: str = "blue",
+        tool_args: Dict[str, Any] | str = "",
+        **fields,
     ) -> Text:
         """Create standardized status content with consistent formatting."""
         content = Text()
@@ -54,6 +59,8 @@ class ConsoleFormatter:
             content.append(
                 f"{value}\n", style=fields.get(f"{label}_style", status_style)
             )
+        content.append("Tool Args: ", style="white")
+        content.append(f"{tool_args}\n", style=status_style)
 
         return content
 
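Note: a standalone sketch, using rich directly, of roughly what the two appended "Tool Args" lines above contribute to a status panel. The tool name and arguments are made up; the "Tool Usage"/"green" panel styling mirrors handle_llm_tool_usage_started further down.

from rich.console import Console
from rich.panel import Panel
from rich.text import Text

tool_args = {"query": "latest CrewAI release"}  # hypothetical tool call

content = Text()
content.append("Tool Usage Started\n", style="bold")
content.append("web_search\n", style="green")
content.append("Status: ", style="white")
content.append("In Progress\n", style="green")
content.append("Tool Args: ", style="white")     # the two lines added in this hunk
content.append(f"{tool_args}\n", style="green")

Console().print(Panel(content, title="Tool Usage", border_style="green"))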
@@ -153,6 +160,7 @@ class ConsoleFormatter:
         crew_name: str,
         source_id: str,
         status: str = "completed",
+        final_string_output: str = "",
     ) -> None:
         """Handle crew tree updates with consistent formatting."""
         if not self.verbose or tree is None:
@@ -184,6 +192,7 @@ class ConsoleFormatter:
             style,
             ID=source_id,
         )
+        content.append(f"Final Output: {final_string_output}\n", style="white")
 
         self.print_panel(content, title, style)
 
@@ -456,12 +465,19 @@ class ConsoleFormatter:
     def handle_llm_tool_usage_started(
         self,
         tool_name: str,
+        tool_args: Dict[str, Any] | str,
     ):
-        tree = self.get_llm_tree(tool_name)
-        self.add_tree_node(tree, "🔄 Tool Usage Started", "green")
-        self.print(tree)
+        # Create status content for the tool usage
+        content = self.create_status_content(
+            "Tool Usage Started", tool_name, Status="In Progress", tool_args=tool_args
+        )
+
+        # Create and print the panel
+        self.print_panel(content, "Tool Usage", "green")
+        self.print()
-        return tree
+
+        # Still return the tree for compatibility with existing code
+        return self.get_llm_tree(tool_name)
 
     def handle_llm_tool_usage_finished(
         self,
@@ -492,6 +508,7 @@ class ConsoleFormatter:
         agent_branch: Optional[Tree],
         tool_name: str,
         crew_tree: Optional[Tree],
+        tool_args: Dict[str, Any] | str = "",
     ) -> Optional[Tree]:
         """Handle tool usage started event."""
         if not self.verbose:
@@ -1404,8 +1421,8 @@ class ConsoleFormatter:
 
         # Create tool output content with better formatting
         output_text = str(formatted_answer.result)
-        if len(output_text) > 1000:
-            output_text = output_text[:997] + "..."
+        if len(output_text) > 2000:
+            output_text = output_text[:1997] + "..."
 
         output_panel = Panel(
             Text(output_text, style="bright_green"),
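Note: the truncation threshold moves from 1000 to 2000 characters; slicing to 1997 before appending "..." keeps the result at exactly 2000 characters. A quick check with a made-up oversized string:

output_text = "x" * 2500                      # hypothetical oversized tool result
if len(output_text) > 2000:
    output_text = output_text[:1997] + "..."
assert len(output_text) == 2000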