chore: improve typing and consolidate utilities

- add type annotations across utility modules  
- refactor printer system, agent utils, and imports for consistency  
- remove unused modules, constants, and redundant patterns  
- improve runtime type checks, exception handling, and guardrail validation  
- standardize warning suppression and logging utilities  
- fix llm typing, threading/typing edge cases, and test behavior
This commit is contained in:
Greyson LaLonde
2025-09-23 11:33:46 -04:00
committed by GitHub
parent 34bed359a6
commit 3e97393f58
47 changed files with 1939 additions and 1233 deletions

View File

@@ -1,7 +1,9 @@
from __future__ import annotations
import json
import re
from collections.abc import Callable, Sequence
from typing import Any
from typing import TYPE_CHECKING, Any, Final, Literal, TypedDict
from rich.console import Console
@@ -19,18 +21,47 @@ from crewai.tools import BaseTool as CrewAITool
from crewai.tools.base_tool import BaseTool
from crewai.tools.structured_tool import CrewStructuredTool
from crewai.tools.tool_types import ToolResult
from crewai.utilities import I18N, Printer
from crewai.utilities.errors import AgentRepositoryError
from crewai.utilities.exceptions.context_window_exceeding_exception import (
LLMContextLengthExceededException,
LLMContextLengthExceededError,
)
from crewai.utilities.i18n import I18N
from crewai.utilities.printer import ColoredText, Printer
from crewai.utilities.types import LLMMessage
if TYPE_CHECKING:
from crewai.agent import Agent
from crewai.task import Task
class SummaryContent(TypedDict):
    """Structure for summary content entries.

    Attributes:
        content: The summarized content, stored as a string.
    """

    content: str
console = Console()
_MULTIPLE_NEWLINES: Final[re.Pattern[str]] = re.compile(r"\n+")
def parse_tools(tools: list[BaseTool]) -> list[CrewStructuredTool]:
"""Parse tools to be used for the task."""
tools_list = []
"""Parse tools to be used for the task.
Args:
tools: List of tools to parse.
Returns:
List of structured tools.
Raises:
ValueError: If a tool is not a CrewStructuredTool or BaseTool.
"""
tools_list: list[CrewStructuredTool] = []
for tool in tools:
if isinstance(tool, CrewAITool):
@@ -42,7 +73,14 @@ def parse_tools(tools: list[BaseTool]) -> list[CrewStructuredTool]:
def get_tool_names(tools: Sequence[CrewStructuredTool | BaseTool]) -> str:
    """Get the names of the tools.

    Args:
        tools: Sequence of tools to get names from. Each item only needs
            to expose a ``name`` attribute.

    Returns:
        Comma-separated string of tool names, in input order. An empty
        sequence yields an empty string.
    """
    return ", ".join(tool.name for tool in tools)
@@ -51,16 +89,30 @@ def render_text_description_and_args(
) -> str:
"""Render the tool name, description, and args in plain text.
search: This tool is used for search, args: {"query": {"type": "string"}}
calculator: This tool is used for math, \
args: {"expression": {"type": "string"}}
search: This tool is used for search, args: {"query": {"type": "string"}}
calculator: This tool is used for math, \
args: {"expression": {"type": "string"}}
Args:
tools: List of tools to render.
Returns:
Plain text description of tools.
"""
tool_strings = [tool.description for tool in tools]
return "\n".join(tool_strings)
def has_reached_max_iterations(iterations: int, max_iterations: int) -> bool:
    """Check if the maximum number of iterations has been reached.

    Args:
        iterations: Current number of iterations.
        max_iterations: Maximum allowed iterations.

    Returns:
        True if ``iterations`` is greater than or equal to
        ``max_iterations``, False otherwise.
    """
    return iterations >= max_iterations
@@ -68,16 +120,19 @@ def handle_max_iterations_exceeded(
formatted_answer: AgentAction | AgentFinish | None,
printer: Printer,
i18n: I18N,
messages: list[dict[str, str]],
messages: list[LLMMessage],
llm: LLM | BaseLLM,
callbacks: list[Any],
callbacks: list[Callable[..., Any]],
) -> AgentAction | AgentFinish:
"""
Handles the case when the maximum number of iterations is exceeded.
Performs one more LLM call to get the final answer.
"""Handles the case when the maximum number of iterations is exceeded. Performs one more LLM call to get the final answer.
Parameters:
Args:
formatted_answer: The last formatted answer from the agent.
printer: Printer instance for output.
i18n: I18N instance for internationalization.
messages: List of messages to send to the LLM.
llm: The LLM instance to call.
callbacks: List of callbacks for the LLM call.
Returns:
The final formatted answer after exceeding max iterations.
@@ -98,7 +153,7 @@ def handle_max_iterations_exceeded(
# Perform one more LLM call to get the final answer
answer = llm.call(
messages,
messages, # type: ignore[arg-type]
callbacks=callbacks,
)
@@ -110,20 +165,38 @@ def handle_max_iterations_exceeded(
raise ValueError("Invalid response from LLM call - None or empty.")
# Return the formatted answer, regardless of its type
return format_answer(answer)
return format_answer(answer=answer)
def format_message_for_llm(
    prompt: str, role: Literal["user", "assistant", "system"] = "user"
) -> LLMMessage:
    """Format a message for the LLM.

    Args:
        prompt: The message content. Trailing whitespace is stripped.
        role: The role of the message sender: 'user', 'assistant', or
            'system'. Defaults to 'user'.

    Returns:
        A dictionary with 'role' and 'content' keys.
    """
    return {"role": role, "content": prompt.rstrip()}
def format_answer(answer: str) -> AgentAction | AgentFinish:
"""Format a response from the LLM into an AgentAction or AgentFinish."""
"""Format a response from the LLM into an AgentAction or AgentFinish.
Args:
answer: The raw response from the LLM
Returns:
Either an AgentAction or AgentFinish
"""
try:
return parse(answer)
except Exception:
# If parsing fails, return a default AgentFinish
return AgentFinish(
thought="Failed to parse LLM response",
output=answer,
@@ -134,23 +207,43 @@ def format_answer(answer: str) -> AgentAction | AgentFinish:
def enforce_rpm_limit(
    request_within_rpm_limit: Callable[[], bool] | None = None,
) -> None:
    """Enforce the requests per minute (RPM) limit if applicable.

    Args:
        request_within_rpm_limit: Optional zero-argument callable that
            enforces the RPM limit. It is invoked once when provided; its
            return value is ignored. Nothing happens when it is None.
    """
    if request_within_rpm_limit:
        request_within_rpm_limit()
def get_llm_response(
llm: LLM | BaseLLM,
messages: list[dict[str, str]],
callbacks: list[Any],
messages: list[LLMMessage],
callbacks: list[Callable[..., Any]],
printer: Printer,
from_task: Any | None = None,
from_agent: Any | None = None,
from_task: Task | None = None,
from_agent: Agent | None = None,
) -> str:
"""Call the LLM and return the response, handling any invalid responses."""
"""Call the LLM and return the response, handling any invalid responses.
Args:
llm: The LLM instance to call
messages: The messages to send to the LLM
callbacks: List of callbacks for the LLM call
printer: Printer instance for output
from_task: Optional task context for the LLM call
from_agent: Optional agent context for the LLM call
Returns:
The response from the LLM as a string
Raises:
Exception: If an error occurs.
ValueError: If the response is None or empty.
"""
try:
answer = llm.call(
messages,
messages, # type: ignore[arg-type]
callbacks=callbacks,
from_task=from_task,
from_agent=from_agent,
@@ -170,7 +263,15 @@ def get_llm_response(
def process_llm_response(
answer: str, use_stop_words: bool
) -> AgentAction | AgentFinish:
"""Process the LLM response and format it into an AgentAction or AgentFinish."""
"""Process the LLM response and format it into an AgentAction or AgentFinish.
Args:
answer: The raw response from the LLM
use_stop_words: Whether to use stop words in the LLM call
Returns:
Either an AgentAction or AgentFinish
"""
if not use_stop_words:
try:
# Preliminary parsing to check for errors.
@@ -200,6 +301,9 @@ def handle_agent_action_core(
Returns:
Either an AgentAction or AgentFinish
Notes:
- TODO: Remove messages parameter and its usage.
"""
if step_callback:
step_callback(tool_result)
@@ -220,7 +324,7 @@ def handle_agent_action_core(
return formatted_answer
def handle_unknown_error(printer: Any, exception: Exception) -> None:
def handle_unknown_error(printer: Printer, exception: Exception) -> None:
"""Handle unknown errors by informing the user.
Args:
@@ -244,10 +348,10 @@ def handle_unknown_error(printer: Any, exception: Exception) -> None:
def handle_output_parser_exception(
e: OutputParserError,
messages: list[dict[str, str]],
messages: list[LLMMessage],
iterations: int,
log_error_after: int = 3,
printer: Any | None = None,
printer: Printer | None = None,
) -> AgentAction:
"""Handle OutputParserError by updating messages and formatted_answer.
@@ -288,18 +392,18 @@ def is_context_length_exceeded(exception: Exception) -> bool:
Returns:
bool: True if the exception is due to context length exceeding
"""
return LLMContextLengthExceededException(str(exception))._is_context_limit_error(
return LLMContextLengthExceededError(str(exception))._is_context_limit_error(
str(exception)
)
def handle_context_length(
respect_context_window: bool,
printer: Any,
messages: list[dict[str, str]],
llm: Any,
callbacks: list[Any],
i18n: Any,
printer: Printer,
messages: list[LLMMessage],
llm: LLM | BaseLLM,
callbacks: list[Callable[..., Any]],
i18n: I18N,
) -> None:
"""Handle context length exceeded by either summarizing or raising an error.
@@ -310,13 +414,16 @@ def handle_context_length(
llm: LLM instance for summarization
callbacks: List of callbacks for LLM
i18n: I18N instance for messages
Raises:
SystemExit: If context length is exceeded and user opts not to summarize
"""
if respect_context_window:
printer.print(
content="Context length exceeded. Summarizing content to fit the model context window. Might take a while...",
color="yellow",
)
summarize_messages(messages, llm, callbacks, i18n)
summarize_messages(messages=messages, llm=llm, callbacks=callbacks, i18n=i18n)
else:
printer.print(
content="Context length exceeded. Consider using smaller text or RAG tools from crewai_tools.",
@@ -328,10 +435,10 @@ def handle_context_length(
def summarize_messages(
messages: list[dict[str, str]],
llm: Any,
callbacks: list[Any],
i18n: Any,
messages: list[LLMMessage],
llm: LLM | BaseLLM,
callbacks: list[Callable[..., Any]],
i18n: I18N,
) -> None:
"""Summarize messages to fit within context window.
@@ -349,7 +456,7 @@ def summarize_messages(
for i in range(0, len(messages_string), cut_size)
]
summarized_contents = []
summarized_contents: list[SummaryContent] = []
total_groups = len(messages_groups)
for idx, group in enumerate(messages_groups, 1):
@@ -357,15 +464,17 @@ def summarize_messages(
content=f"Summarizing {idx}/{total_groups}...",
color="yellow",
)
messages = [
format_message_for_llm(
i18n.slice("summarizer_system_message"), role="system"
),
format_message_for_llm(
i18n.slice("summarize_instruction").format(group=group["content"]),
),
]
summary = llm.call(
[
format_message_for_llm(
i18n.slice("summarizer_system_message"), role="system"
),
format_message_for_llm(
i18n.slice("summarize_instruction").format(group=group["content"]),
),
],
messages, # type: ignore[arg-type]
callbacks=callbacks,
)
summarized_contents.append({"content": str(summary)})
@@ -404,20 +513,29 @@ def show_agent_logs(
if formatted_answer is None:
# Start logs
printer.print(
content=f"\033[1m\033[95m# Agent:\033[00m \033[1m\033[92m{agent_role}\033[00m"
content=[
ColoredText("# Agent: ", "bold_purple"),
ColoredText(agent_role, "bold_green"),
]
)
if task_description:
printer.print(
content=f"\033[95m## Task:\033[00m \033[92m{task_description}\033[00m"
content=[
ColoredText("## Task: ", "purple"),
ColoredText(task_description, "green"),
]
)
else:
# Execution logs
printer.print(
content=f"\n\n\033[1m\033[95m# Agent:\033[00m \033[1m\033[92m{agent_role}\033[00m"
content=[
ColoredText("\n\n# Agent: ", "bold_purple"),
ColoredText(agent_role, "bold_green"),
]
)
if isinstance(formatted_answer, AgentAction):
thought = re.sub(r"\n+", "\n", formatted_answer.thought)
thought = _MULTIPLE_NEWLINES.sub("\n", formatted_answer.thought)
formatted_json = json.dumps(
formatted_answer.tool_input,
indent=2,
@@ -425,24 +543,39 @@ def show_agent_logs(
)
if thought and thought != "":
printer.print(
content=f"\033[95m## Thought:\033[00m \033[92m{thought}\033[00m"
content=[
ColoredText("## Thought: ", "purple"),
ColoredText(thought, "green"),
]
)
printer.print(
content=f"\033[95m## Using tool:\033[00m \033[92m{formatted_answer.tool}\033[00m"
content=[
ColoredText("## Using tool: ", "purple"),
ColoredText(formatted_answer.tool, "green"),
]
)
printer.print(
content=f"\033[95m## Tool Input:\033[00m \033[92m\n{formatted_json}\033[00m"
content=[
ColoredText("## Tool Input: ", "purple"),
ColoredText(f"\n{formatted_json}", "green"),
]
)
printer.print(
content=f"\033[95m## Tool Output:\033[00m \033[92m\n{formatted_answer.result}\033[00m"
content=[
ColoredText("## Tool Output: ", "purple"),
ColoredText(f"\n{formatted_answer.result}", "green"),
]
)
elif isinstance(formatted_answer, AgentFinish):
printer.print(
content=f"\033[95m## Final Answer:\033[00m \033[92m\n{formatted_answer.output}\033[00m\n\n"
content=[
ColoredText("## Final Answer: ", "purple"),
ColoredText(f"\n{formatted_answer.output}\n\n", "green"),
]
)
def _print_current_organization():
def _print_current_organization() -> None:
settings = Settings()
if settings.org_uuid:
console.print(
@@ -457,6 +590,17 @@ def _print_current_organization():
def load_agent_from_repository(from_repository: str) -> dict[str, Any]:
"""Load an agent from the repository.
Args:
from_repository: The name of the agent to load.
Returns:
A dictionary of attributes to use for the agent.
Raises:
AgentRepositoryError: If the agent cannot be loaded.
"""
attributes: dict[str, Any] = {}
if from_repository:
import importlib