mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-10 00:28:31 +00:00
refactor: improve code structure and logging in LiteAgent and ConsoleFormatter
- Refactored imports in lite_agent.py for better readability. - Enhanced guardrail property initialization in LiteAgent. - Updated logging functionality to emit AgentLogsExecutionEvent for better tracking. - Modified ConsoleFormatter to include tool arguments and final output in status updates. - Improved output formatting for long text in ConsoleFormatter.
This commit is contained in:
@@ -1,14 +1,33 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import inspect
|
import inspect
|
||||||
import uuid
|
import uuid
|
||||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union, cast, get_args, get_origin
|
from typing import (
|
||||||
|
Any,
|
||||||
|
Callable,
|
||||||
|
Dict,
|
||||||
|
List,
|
||||||
|
Optional,
|
||||||
|
Tuple,
|
||||||
|
Type,
|
||||||
|
Union,
|
||||||
|
cast,
|
||||||
|
get_args,
|
||||||
|
get_origin,
|
||||||
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from typing import Self
|
from typing import Self
|
||||||
except ImportError:
|
except ImportError:
|
||||||
from typing_extensions import Self
|
from typing_extensions import Self
|
||||||
|
|
||||||
from pydantic import BaseModel, Field, InstanceOf, PrivateAttr, model_validator, field_validator
|
from pydantic import (
|
||||||
|
BaseModel,
|
||||||
|
Field,
|
||||||
|
InstanceOf,
|
||||||
|
PrivateAttr,
|
||||||
|
model_validator,
|
||||||
|
field_validator,
|
||||||
|
)
|
||||||
|
|
||||||
from crewai.agents.agent_builder.base_agent import BaseAgent
|
from crewai.agents.agent_builder.base_agent import BaseAgent
|
||||||
from crewai.agents.agent_builder.utilities.base_token_process import TokenProcess
|
from crewai.agents.agent_builder.utilities.base_token_process import TokenProcess
|
||||||
@@ -39,10 +58,10 @@ from crewai.utilities.agent_utils import (
|
|||||||
parse_tools,
|
parse_tools,
|
||||||
process_llm_response,
|
process_llm_response,
|
||||||
render_text_description_and_args,
|
render_text_description_and_args,
|
||||||
show_agent_logs,
|
|
||||||
)
|
)
|
||||||
from crewai.utilities.converter import generate_model_description
|
from crewai.utilities.converter import generate_model_description
|
||||||
from crewai.utilities.events.agent_events import (
|
from crewai.utilities.events.agent_events import (
|
||||||
|
AgentLogsExecutionEvent,
|
||||||
LiteAgentExecutionCompletedEvent,
|
LiteAgentExecutionCompletedEvent,
|
||||||
LiteAgentExecutionErrorEvent,
|
LiteAgentExecutionErrorEvent,
|
||||||
LiteAgentExecutionStartedEvent,
|
LiteAgentExecutionStartedEvent,
|
||||||
@@ -153,9 +172,11 @@ class LiteAgent(FlowTrackable, BaseModel):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Guardrail Properties
|
# Guardrail Properties
|
||||||
guardrail: Optional[Union[Callable[[LiteAgentOutput], Tuple[bool, Any]], str]] = Field(
|
guardrail: Optional[Union[Callable[[LiteAgentOutput], Tuple[bool, Any]], str]] = (
|
||||||
|
Field(
|
||||||
default=None,
|
default=None,
|
||||||
description="Function or string description of a guardrail to validate agent output"
|
description="Function or string description of a guardrail to validate agent output",
|
||||||
|
)
|
||||||
)
|
)
|
||||||
guardrail_max_retries: int = Field(
|
guardrail_max_retries: int = Field(
|
||||||
default=3, description="Maximum number of retries when guardrail fails"
|
default=3, description="Maximum number of retries when guardrail fails"
|
||||||
@@ -181,7 +202,6 @@ class LiteAgent(FlowTrackable, BaseModel):
|
|||||||
_guardrail: Optional[Callable] = PrivateAttr(default=None)
|
_guardrail: Optional[Callable] = PrivateAttr(default=None)
|
||||||
_guardrail_retry_count: int = PrivateAttr(default=0)
|
_guardrail_retry_count: int = PrivateAttr(default=0)
|
||||||
|
|
||||||
|
|
||||||
@model_validator(mode="after")
|
@model_validator(mode="after")
|
||||||
def setup_llm(self):
|
def setup_llm(self):
|
||||||
"""Set up the LLM and other components after initialization."""
|
"""Set up the LLM and other components after initialization."""
|
||||||
@@ -208,17 +228,18 @@ class LiteAgent(FlowTrackable, BaseModel):
|
|||||||
self._guardrail = self.guardrail
|
self._guardrail = self.guardrail
|
||||||
elif isinstance(self.guardrail, str):
|
elif isinstance(self.guardrail, str):
|
||||||
from crewai.tasks.llm_guardrail import LLMGuardrail
|
from crewai.tasks.llm_guardrail import LLMGuardrail
|
||||||
|
|
||||||
assert isinstance(self.llm, LLM)
|
assert isinstance(self.llm, LLM)
|
||||||
|
|
||||||
self._guardrail = LLMGuardrail(
|
self._guardrail = LLMGuardrail(description=self.guardrail, llm=self.llm)
|
||||||
description=self.guardrail, llm=self.llm
|
|
||||||
)
|
|
||||||
|
|
||||||
return self
|
return self
|
||||||
|
|
||||||
@field_validator("guardrail", mode="before")
|
@field_validator("guardrail", mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
def validate_guardrail_function(cls, v: Optional[Union[Callable, str]]) -> Optional[Union[Callable, str]]:
|
def validate_guardrail_function(
|
||||||
|
cls, v: Optional[Union[Callable, str]]
|
||||||
|
) -> Optional[Union[Callable, str]]:
|
||||||
"""Validate that the guardrail function has the correct signature.
|
"""Validate that the guardrail function has the correct signature.
|
||||||
|
|
||||||
If v is a callable, validate that it has the correct signature.
|
If v is a callable, validate that it has the correct signature.
|
||||||
@@ -330,9 +351,7 @@ class LiteAgent(FlowTrackable, BaseModel):
|
|||||||
if self.response_format:
|
if self.response_format:
|
||||||
try:
|
try:
|
||||||
# Cast to BaseModel to ensure type safety
|
# Cast to BaseModel to ensure type safety
|
||||||
result = self.response_format.model_validate_json(
|
result = self.response_format.model_validate_json(agent_finish.output)
|
||||||
agent_finish.output
|
|
||||||
)
|
|
||||||
if isinstance(result, BaseModel):
|
if isinstance(result, BaseModel):
|
||||||
formatted_result = result
|
formatted_result = result
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -357,7 +376,7 @@ class LiteAgent(FlowTrackable, BaseModel):
|
|||||||
guardrail_result = process_guardrail(
|
guardrail_result = process_guardrail(
|
||||||
output=output,
|
output=output,
|
||||||
guardrail=self._guardrail,
|
guardrail=self._guardrail,
|
||||||
retry_count=self._guardrail_retry_count
|
retry_count=self._guardrail_retry_count,
|
||||||
)
|
)
|
||||||
|
|
||||||
if not guardrail_result.success:
|
if not guardrail_result.success:
|
||||||
@@ -373,10 +392,13 @@ class LiteAgent(FlowTrackable, BaseModel):
|
|||||||
f"\n{guardrail_result.error}"
|
f"\n{guardrail_result.error}"
|
||||||
)
|
)
|
||||||
|
|
||||||
self._messages.append({
|
self._messages.append(
|
||||||
|
{
|
||||||
"role": "user",
|
"role": "user",
|
||||||
"content": guardrail_result.error or "Guardrail validation failed"
|
"content": guardrail_result.error
|
||||||
})
|
or "Guardrail validation failed",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
return self._execute_core(agent_info=agent_info)
|
return self._execute_core(agent_info=agent_info)
|
||||||
|
|
||||||
@@ -580,11 +602,13 @@ class LiteAgent(FlowTrackable, BaseModel):
|
|||||||
|
|
||||||
def _show_logs(self, formatted_answer: Union[AgentAction, AgentFinish]):
|
def _show_logs(self, formatted_answer: Union[AgentAction, AgentFinish]):
|
||||||
"""Show logs for the agent's execution."""
|
"""Show logs for the agent's execution."""
|
||||||
show_agent_logs(
|
crewai_event_bus.emit(
|
||||||
printer=self._printer,
|
self,
|
||||||
|
AgentLogsExecutionEvent(
|
||||||
agent_role=self.role,
|
agent_role=self.role,
|
||||||
formatted_answer=formatted_answer,
|
formatted_answer=formatted_answer,
|
||||||
verbose=self.verbose,
|
verbose=self.verbose,
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
def _append_message(self, text: str, role: str = "assistant") -> None:
|
def _append_message(self, text: str, role: str = "assistant") -> None:
|
||||||
|
|||||||
@@ -110,6 +110,7 @@ class EventListener(BaseEventListener):
|
|||||||
event.crew_name or "Crew",
|
event.crew_name or "Crew",
|
||||||
source.id,
|
source.id,
|
||||||
"completed",
|
"completed",
|
||||||
|
final_string_output,
|
||||||
)
|
)
|
||||||
|
|
||||||
@crewai_event_bus.on(CrewKickoffFailedEvent)
|
@crewai_event_bus.on(CrewKickoffFailedEvent)
|
||||||
@@ -288,6 +289,7 @@ class EventListener(BaseEventListener):
|
|||||||
if isinstance(source, LLM):
|
if isinstance(source, LLM):
|
||||||
self.formatter.handle_llm_tool_usage_started(
|
self.formatter.handle_llm_tool_usage_started(
|
||||||
event.tool_name,
|
event.tool_name,
|
||||||
|
event.tool_args,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
self.formatter.handle_tool_usage_started(
|
self.formatter.handle_tool_usage_started(
|
||||||
|
|||||||
@@ -41,7 +41,12 @@ class ConsoleFormatter:
|
|||||||
)
|
)
|
||||||
|
|
||||||
def create_status_content(
|
def create_status_content(
|
||||||
self, title: str, name: str, status_style: str = "blue", **fields
|
self,
|
||||||
|
title: str,
|
||||||
|
name: str,
|
||||||
|
status_style: str = "blue",
|
||||||
|
tool_args: Dict[str, Any] | str = "",
|
||||||
|
**fields,
|
||||||
) -> Text:
|
) -> Text:
|
||||||
"""Create standardized status content with consistent formatting."""
|
"""Create standardized status content with consistent formatting."""
|
||||||
content = Text()
|
content = Text()
|
||||||
@@ -54,6 +59,8 @@ class ConsoleFormatter:
|
|||||||
content.append(
|
content.append(
|
||||||
f"{value}\n", style=fields.get(f"{label}_style", status_style)
|
f"{value}\n", style=fields.get(f"{label}_style", status_style)
|
||||||
)
|
)
|
||||||
|
content.append("Tool Args: ", style="white")
|
||||||
|
content.append(f"{tool_args}\n", style=status_style)
|
||||||
|
|
||||||
return content
|
return content
|
||||||
|
|
||||||
@@ -153,6 +160,7 @@ class ConsoleFormatter:
|
|||||||
crew_name: str,
|
crew_name: str,
|
||||||
source_id: str,
|
source_id: str,
|
||||||
status: str = "completed",
|
status: str = "completed",
|
||||||
|
final_string_output: str = "",
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Handle crew tree updates with consistent formatting."""
|
"""Handle crew tree updates with consistent formatting."""
|
||||||
if not self.verbose or tree is None:
|
if not self.verbose or tree is None:
|
||||||
@@ -184,6 +192,7 @@ class ConsoleFormatter:
|
|||||||
style,
|
style,
|
||||||
ID=source_id,
|
ID=source_id,
|
||||||
)
|
)
|
||||||
|
content.append(f"Final Output: {final_string_output}\n", style="white")
|
||||||
|
|
||||||
self.print_panel(content, title, style)
|
self.print_panel(content, title, style)
|
||||||
|
|
||||||
@@ -456,12 +465,19 @@ class ConsoleFormatter:
|
|||||||
def handle_llm_tool_usage_started(
|
def handle_llm_tool_usage_started(
|
||||||
self,
|
self,
|
||||||
tool_name: str,
|
tool_name: str,
|
||||||
|
tool_args: Dict[str, Any] | str,
|
||||||
):
|
):
|
||||||
tree = self.get_llm_tree(tool_name)
|
# Create status content for the tool usage
|
||||||
self.add_tree_node(tree, "🔄 Tool Usage Started", "green")
|
content = self.create_status_content(
|
||||||
self.print(tree)
|
"Tool Usage Started", tool_name, Status="In Progress", tool_args=tool_args
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create and print the panel
|
||||||
|
self.print_panel(content, "Tool Usage", "green")
|
||||||
self.print()
|
self.print()
|
||||||
return tree
|
|
||||||
|
# Still return the tree for compatibility with existing code
|
||||||
|
return self.get_llm_tree(tool_name)
|
||||||
|
|
||||||
def handle_llm_tool_usage_finished(
|
def handle_llm_tool_usage_finished(
|
||||||
self,
|
self,
|
||||||
@@ -492,6 +508,7 @@ class ConsoleFormatter:
|
|||||||
agent_branch: Optional[Tree],
|
agent_branch: Optional[Tree],
|
||||||
tool_name: str,
|
tool_name: str,
|
||||||
crew_tree: Optional[Tree],
|
crew_tree: Optional[Tree],
|
||||||
|
tool_args: Dict[str, Any] | str = "",
|
||||||
) -> Optional[Tree]:
|
) -> Optional[Tree]:
|
||||||
"""Handle tool usage started event."""
|
"""Handle tool usage started event."""
|
||||||
if not self.verbose:
|
if not self.verbose:
|
||||||
@@ -1404,8 +1421,8 @@ class ConsoleFormatter:
|
|||||||
|
|
||||||
# Create tool output content with better formatting
|
# Create tool output content with better formatting
|
||||||
output_text = str(formatted_answer.result)
|
output_text = str(formatted_answer.result)
|
||||||
if len(output_text) > 1000:
|
if len(output_text) > 2000:
|
||||||
output_text = output_text[:997] + "..."
|
output_text = output_text[:1997] + "..."
|
||||||
|
|
||||||
output_panel = Panel(
|
output_panel = Panel(
|
||||||
Text(output_text, style="bright_green"),
|
Text(output_text, style="bright_green"),
|
||||||
|
|||||||
Reference in New Issue
Block a user