diff --git a/src/crewai/agents/parser.py b/src/crewai/agents/parser.py index 912983b11..2e4bba53d 100644 --- a/src/crewai/agents/parser.py +++ b/src/crewai/agents/parser.py @@ -18,7 +18,7 @@ from crewai.agents.constants import ( MISSING_ACTION_INPUT_AFTER_ACTION_ERROR_MESSAGE, UNABLE_TO_REPAIR_JSON_RESULTS, ) -from crewai.utilities import I18N +from crewai.utilities.i18n import I18N _I18N = I18N() diff --git a/src/crewai/llm.py b/src/crewai/llm.py index 6e9d22edb..733b46c79 100644 --- a/src/crewai/llm.py +++ b/src/crewai/llm.py @@ -1,28 +1,26 @@ +import io import json import logging import os import sys import threading -import warnings from collections import defaultdict -from contextlib import contextmanager +from collections.abc import Callable +from datetime import datetime from typing import ( Any, - DefaultDict, - Dict, - List, + Final, Literal, - Optional, - Type, + TextIO, TypedDict, - Union, cast, ) -from datetime import datetime + from dotenv import load_dotenv from litellm.types.utils import ChatCompletionDeltaToolCall from pydantic import BaseModel, Field +from crewai.events.event_bus import crewai_event_bus from crewai.events.types.llm_events import ( LLMCallCompletedEvent, LLMCallFailedEvent, @@ -31,15 +29,19 @@ from crewai.events.types.llm_events import ( LLMStreamChunkEvent, ) from crewai.events.types.tool_usage_events import ( - ToolUsageStartedEvent, - ToolUsageFinishedEvent, ToolUsageErrorEvent, + ToolUsageFinishedEvent, + ToolUsageStartedEvent, ) +from crewai.llms.base_llm import BaseLLM +from crewai.utilities.exceptions.context_window_exceeding_exception import ( + LLMContextLengthExceededError, +) +from crewai.utilities.logger_utils import suppress_warnings -with warnings.catch_warnings(): - warnings.simplefilter("ignore", UserWarning) +with suppress_warnings(): import litellm - from litellm import Choices + from litellm import Choices, CustomLogger from litellm.exceptions import ContextWindowExceededError from litellm.litellm_core_utils.get_supported_openai_params import ( get_supported_openai_params, @@ -47,16 +49,6 @@ with warnings.catch_warnings(): from litellm.types.utils import ModelResponse from litellm.utils import supports_response_schema - -import io -from typing import TextIO - -from crewai.llms.base_llm import BaseLLM -from crewai.events.event_bus import crewai_event_bus -from crewai.utilities.exceptions.context_window_exceeding_exception import ( - LLMContextLengthExceededException, -) - load_dotenv() litellm.suppress_debug_info = True @@ -126,7 +118,11 @@ if not isinstance(sys.stderr, FilteredStream): sys.stderr = FilteredStream(sys.stderr) -LLM_CONTEXT_WINDOW_SIZES = { +MIN_CONTEXT: Final[int] = 1024 +MAX_CONTEXT: Final[int] = 2097152 # Current max from gemini-1.5-pro +ANTHROPIC_PREFIXES: Final[tuple[str, str, str]] = ("anthropic/", "claude-", "claude/") + +LLM_CONTEXT_WINDOW_SIZES: Final[dict[str, int]] = { # openai "gpt-4": 8192, "gpt-4o": 128000, @@ -252,30 +248,19 @@ LLM_CONTEXT_WINDOW_SIZES = { "mistral/mistral-large-2402": 32768, } -DEFAULT_CONTEXT_WINDOW_SIZE = 8192 -CONTEXT_WINDOW_USAGE_RATIO = 0.85 - - -@contextmanager -def suppress_warnings(): - with warnings.catch_warnings(): - warnings.filterwarnings("ignore") - warnings.filterwarnings( - "ignore", message="open_text is deprecated*", category=DeprecationWarning - ) - - yield +DEFAULT_CONTEXT_WINDOW_SIZE: Final[int] = 8192 +CONTEXT_WINDOW_USAGE_RATIO: Final[float] = 0.85 class Delta(TypedDict): - content: Optional[str] - role: Optional[str] + content: str | None + role: str | None class 
StreamingChoices(TypedDict): delta: Delta index: int - finish_reason: Optional[str] + finish_reason: str | None class FunctionArgs(BaseModel): @@ -288,31 +273,31 @@ class AccumulatedToolArgs(BaseModel): class LLM(BaseLLM): - completion_cost: Optional[float] = None + completion_cost: float | None = None def __init__( self, model: str, - timeout: Optional[Union[float, int]] = None, - temperature: Optional[float] = None, - top_p: Optional[float] = None, - n: Optional[int] = None, - stop: Optional[Union[str, List[str]]] = None, - max_completion_tokens: Optional[int] = None, - max_tokens: Optional[int] = None, - presence_penalty: Optional[float] = None, - frequency_penalty: Optional[float] = None, - logit_bias: Optional[Dict[int, float]] = None, - response_format: Optional[Type[BaseModel]] = None, - seed: Optional[int] = None, - logprobs: Optional[int] = None, - top_logprobs: Optional[int] = None, - base_url: Optional[str] = None, - api_base: Optional[str] = None, - api_version: Optional[str] = None, - api_key: Optional[str] = None, - callbacks: List[Any] | None = None, - reasoning_effort: Optional[Literal["none", "low", "medium", "high"]] = None, + timeout: float | int | None = None, + temperature: float | None = None, + top_p: float | None = None, + n: int | None = None, + stop: str | list[str] | None = None, + max_completion_tokens: int | None = None, + max_tokens: int | None = None, + presence_penalty: float | None = None, + frequency_penalty: float | None = None, + logit_bias: dict[int, float] | None = None, + response_format: type[BaseModel] | None = None, + seed: int | None = None, + logprobs: int | None = None, + top_logprobs: int | None = None, + base_url: str | None = None, + api_base: str | None = None, + api_version: str | None = None, + api_key: str | None = None, + callbacks: list[Any] | None = None, + reasoning_effort: Literal["none", "low", "medium", "high"] | None = None, stream: bool = False, **kwargs, ): @@ -345,7 +330,7 @@ class LLM(BaseLLM): # Normalize self.stop to always be a List[str] if stop is None: - self.stop: List[str] = [] + self.stop: list[str] = [] elif isinstance(stop, str): self.stop = [stop] else: @@ -354,7 +339,8 @@ class LLM(BaseLLM): self.set_callbacks(callbacks or []) self.set_env_callbacks() - def _is_anthropic_model(self, model: str) -> bool: + @staticmethod + def _is_anthropic_model(model: str) -> bool: """Determine if the model is from Anthropic provider. Args: @@ -363,21 +349,18 @@ class LLM(BaseLLM): Returns: bool: True if the model is from Anthropic, False otherwise. """ - ANTHROPIC_PREFIXES = ("anthropic/", "claude-", "claude/") return any(prefix in model.lower() for prefix in ANTHROPIC_PREFIXES) def _prepare_completion_params( self, - messages: Union[str, List[Dict[str, str]]], - tools: Optional[List[dict]] = None, - ) -> Dict[str, Any]: + messages: str | list[dict[str, str]], + tools: list[dict] | None = None, + ) -> dict[str, Any]: """Prepare parameters for the completion call. 
Args: messages: Input messages for the LLM tools: Optional list of tool schemas - callbacks: Optional list of callback functions - available_functions: Optional dict of available functions Returns: Dict[str, Any]: Parameters for the completion call @@ -419,11 +402,11 @@ class LLM(BaseLLM): def _handle_streaming_response( self, - params: Dict[str, Any], - callbacks: Optional[List[Any]] = None, - available_functions: Optional[Dict[str, Any]] = None, - from_task: Optional[Any] = None, - from_agent: Optional[Any] = None, + params: dict[str, Any], + callbacks: list[Any] | None = None, + available_functions: dict[str, Any] | None = None, + from_task: Any | None = None, + from_agent: Any | None = None, ) -> str: """Handle a streaming response from the LLM. @@ -445,9 +428,8 @@ class LLM(BaseLLM): last_chunk = None chunk_count = 0 usage_info = None - tool_calls = None - accumulated_tool_args: DefaultDict[int, AccumulatedToolArgs] = defaultdict( + accumulated_tool_args: defaultdict[int, AccumulatedToolArgs] = defaultdict( AccumulatedToolArgs ) @@ -472,16 +454,16 @@ class LLM(BaseLLM): choices = chunk["choices"] elif hasattr(chunk, "choices"): # Check if choices is not a type but an actual attribute with value - if not isinstance(getattr(chunk, "choices"), type): - choices = getattr(chunk, "choices") + if not isinstance(chunk.choices, type): + choices = chunk.choices # Try to extract usage information if available if isinstance(chunk, dict) and "usage" in chunk: usage_info = chunk["usage"] elif hasattr(chunk, "usage"): # Check if usage is not a type but an actual attribute with value - if not isinstance(getattr(chunk, "usage"), type): - usage_info = getattr(chunk, "usage") + if not isinstance(chunk.usage, type): + usage_info = chunk.usage if choices and len(choices) > 0: choice = choices[0] @@ -491,7 +473,7 @@ class LLM(BaseLLM): if isinstance(choice, dict) and "delta" in choice: delta = choice["delta"] elif hasattr(choice, "delta"): - delta = getattr(choice, "delta") + delta = choice.delta # Extract content from delta if delta: @@ -501,7 +483,7 @@ class LLM(BaseLLM): chunk_content = delta["content"] # Handle object format elif hasattr(delta, "content"): - chunk_content = getattr(delta, "content") + chunk_content = delta.content # Handle case where content might be None or empty if chunk_content is None and isinstance(delta, dict): @@ -533,7 +515,9 @@ class LLM(BaseLLM): full_response += chunk_content # Emit the chunk event - assert hasattr(crewai_event_bus, "emit") + if not hasattr(crewai_event_bus, "emit"): + raise Exception("crewai_event_bus must have an `emit` method") + crewai_event_bus.emit( self, event=LLMStreamChunkEvent( @@ -572,8 +556,8 @@ class LLM(BaseLLM): if isinstance(last_chunk, dict) and "choices" in last_chunk: choices = last_chunk["choices"] elif hasattr(last_chunk, "choices"): - if not isinstance(getattr(last_chunk, "choices"), type): - choices = getattr(last_chunk, "choices") + if not isinstance(last_chunk.choices, type): + choices = last_chunk.choices if choices and len(choices) > 0: choice = choices[0] @@ -583,14 +567,14 @@ class LLM(BaseLLM): if isinstance(choice, dict) and "message" in choice: message = choice["message"] elif hasattr(choice, "message"): - message = getattr(choice, "message") + message = choice.message if message: content = None if isinstance(message, dict) and "content" in message: content = message["content"] elif hasattr(message, "content"): - content = getattr(message, "content") + content = message.content if content: full_response = content @@ -617,8 
+601,8 @@ class LLM(BaseLLM): if isinstance(last_chunk, dict) and "choices" in last_chunk: choices = last_chunk["choices"] elif hasattr(last_chunk, "choices"): - if not isinstance(getattr(last_chunk, "choices"), type): - choices = getattr(last_chunk, "choices") + if not isinstance(last_chunk.choices, type): + choices = last_chunk.choices if choices and len(choices) > 0: choice = choices[0] @@ -627,13 +611,13 @@ class LLM(BaseLLM): if isinstance(choice, dict) and "message" in choice: message = choice["message"] elif hasattr(choice, "message"): - message = getattr(choice, "message") + message = choice.message if message: if isinstance(message, dict) and "tool_calls" in message: tool_calls = message["tool_calls"] elif hasattr(message, "tool_calls"): - tool_calls = getattr(message, "tool_calls") + tool_calls = message.tool_calls except Exception as e: logging.debug(f"Error checking for tool calls: {e}") # --- 8) If no tool calls or no available functions, return the text response directly @@ -673,11 +657,11 @@ class LLM(BaseLLM): # Catch context window errors from litellm and convert them to our own exception type. # This exception is handled by CrewAgentExecutor._invoke_loop() which can then # decide whether to summarize the content or abort based on the respect_context_window flag. - raise LLMContextLengthExceededException(str(e)) + raise LLMContextLengthExceededError(str(e)) from e except Exception as e: - logging.error(f"Error in streaming response: {str(e)}") + logging.error(f"Error in streaming response: {e!s}") if full_response.strip(): - logging.warning(f"Returning partial response despite error: {str(e)}") + logging.warning(f"Returning partial response despite error: {e!s}") self._handle_emit_call_events( response=full_response, call_type=LLMCallType.LLM_CALL, @@ -688,22 +672,25 @@ class LLM(BaseLLM): return full_response # Emit failed event and re-raise the exception - assert hasattr(crewai_event_bus, "emit") + if not hasattr(crewai_event_bus, "emit"): + raise AttributeError( + "crewai_event_bus must have an 'emit' method" + ) from e crewai_event_bus.emit( self, event=LLMCallFailedEvent( error=str(e), from_task=from_task, from_agent=from_agent ), ) - raise Exception(f"Failed to get streaming response: {str(e)}") + raise Exception(f"Failed to get streaming response: {e!s}") from e def _handle_streaming_tool_calls( self, - tool_calls: List[ChatCompletionDeltaToolCall], - accumulated_tool_args: DefaultDict[int, AccumulatedToolArgs], - available_functions: Optional[Dict[str, Any]] = None, - from_task: Optional[Any] = None, - from_agent: Optional[Any] = None, + tool_calls: list[ChatCompletionDeltaToolCall], + accumulated_tool_args: defaultdict[int, AccumulatedToolArgs], + available_functions: dict[str, Any] | None = None, + from_task: Any | None = None, + from_agent: Any | None = None, ) -> None | str: for tool_call in tool_calls: current_tool_accumulator = accumulated_tool_args[tool_call.index] @@ -715,7 +702,8 @@ class LLM(BaseLLM): current_tool_accumulator.function.arguments += ( tool_call.function.arguments ) - assert hasattr(crewai_event_bus, "emit") + if not hasattr(crewai_event_bus, "emit"): + raise AttributeError("crewai_event_bus must have an 'emit' method") crewai_event_bus.emit( self, event=LLMStreamChunkEvent( @@ -742,11 +730,11 @@ class LLM(BaseLLM): continue return None + @staticmethod def _handle_streaming_callbacks( - self, - callbacks: Optional[List[Any]], - usage_info: Optional[Dict[str, Any]], - last_chunk: Optional[Any], + callbacks: list[Any] | None, + usage_info: 
dict[str, Any] | None, + last_chunk: Any | None, ) -> None: """Handle callbacks with usage info for streaming responses. @@ -769,10 +757,8 @@ class LLM(BaseLLM): ): usage_info = last_chunk["usage"] elif hasattr(last_chunk, "usage"): - if not isinstance( - getattr(last_chunk, "usage"), type - ): - usage_info = getattr(last_chunk, "usage") + if not isinstance(last_chunk.usage, type): + usage_info = last_chunk.usage except Exception as e: logging.debug(f"Error extracting usage info: {e}") @@ -786,11 +772,11 @@ class LLM(BaseLLM): def _handle_non_streaming_response( self, - params: Dict[str, Any], - callbacks: Optional[List[Any]] = None, - available_functions: Optional[Dict[str, Any]] = None, - from_task: Optional[Any] = None, - from_agent: Optional[Any] = None, + params: dict[str, Any], + callbacks: list[Any] | None = None, + available_functions: dict[str, Any] | None = None, + from_task: Any | None = None, + from_agent: Any | None = None, ) -> str | Any: """Handle a non-streaming response from the LLM. @@ -815,7 +801,7 @@ class LLM(BaseLLM): except ContextWindowExceededError as e: # Convert litellm's context window error to our own exception type # for consistent handling in the rest of the codebase - raise LLMContextLengthExceededException(str(e)) + raise LLMContextLengthExceededError(str(e)) from e # --- 2) Extract response message and content response_message = cast(Choices, cast(ModelResponse, response).choices)[ 0 @@ -847,7 +833,7 @@ class LLM(BaseLLM): ) return text_response # --- 6) If there is no text response, no available functions, but there are tool calls, return the tool calls - elif tool_calls and not available_functions and not text_response: + if tool_calls and not available_functions and not text_response: return tool_calls # --- 7) Handle tool calls if present @@ -868,19 +854,21 @@ class LLM(BaseLLM): def _handle_tool_call( self, - tool_calls: List[Any], - available_functions: Optional[Dict[str, Any]] = None, - from_task: Optional[Any] = None, - from_agent: Optional[Any] = None, - ) -> Optional[str]: + tool_calls: list[Any], + available_functions: dict[str, Any] | None = None, + from_task: Any | None = None, + from_agent: Any | None = None, + ) -> str | None: """Handle a tool call from the LLM. 
Args: tool_calls: List of tool calls from the LLM available_functions: Dict of available functions + from_task: Optional Task that invoked the LLM + from_agent: Optional Agent that invoked the LLM Returns: - Optional[str]: The result of the tool call, or None if no tool call was made + The result of the tool call, or None if no tool call was made """ # --- 1) Validate tool calls and available functions if not tool_calls or not available_functions: @@ -899,7 +887,8 @@ class LLM(BaseLLM): fn = available_functions[function_name] # --- 3.2) Execute function - assert hasattr(crewai_event_bus, "emit") + if not hasattr(crewai_event_bus, "emit"): + raise AttributeError("crewai_event_bus must have an 'emit' method") started_at = datetime.now() crewai_event_bus.emit( self, @@ -939,17 +928,20 @@ class LLM(BaseLLM): function_name, lambda: None ) # Ensure fn is always a callable logging.error(f"Error executing function '{function_name}': {e}") - assert hasattr(crewai_event_bus, "emit") + if not hasattr(crewai_event_bus, "emit"): + raise AttributeError( + "crewai_event_bus must have an 'emit' method" + ) from e crewai_event_bus.emit( self, - event=LLMCallFailedEvent(error=f"Tool execution error: {str(e)}"), + event=LLMCallFailedEvent(error=f"Tool execution error: {e!s}"), ) crewai_event_bus.emit( self, event=ToolUsageErrorEvent( tool_name=function_name, tool_args=function_args, - error=f"Tool execution error: {str(e)}", + error=f"Tool execution error: {e!s}", from_task=from_task, from_agent=from_agent, ), @@ -958,13 +950,13 @@ class LLM(BaseLLM): def call( self, - messages: Union[str, List[Dict[str, str]]], - tools: Optional[List[dict]] = None, - callbacks: Optional[List[Any]] = None, - available_functions: Optional[Dict[str, Any]] = None, - from_task: Optional[Any] = None, - from_agent: Optional[Any] = None, - ) -> Union[str, Any]: + messages: str | list[dict[str, str]], + tools: list[dict] | None = None, + callbacks: list[Any] | None = None, + available_functions: dict[str, Any] | None = None, + from_task: Any | None = None, + from_agent: Any | None = None, + ) -> str | Any: """High-level LLM call method. 
Args: @@ -988,10 +980,11 @@ class LLM(BaseLLM): Raises: TypeError: If messages format is invalid ValueError: If response format is not supported - LLMContextLengthExceededException: If input exceeds model's context limit + LLMContextLengthExceededError: If input exceeds model's context limit """ # --- 1) Emit call started event - assert hasattr(crewai_event_bus, "emit") + if not hasattr(crewai_event_bus, "emit"): + raise AttributeError("crewai_event_bus must have an 'emit' method") crewai_event_bus.emit( self, event=LLMCallStartedEvent( @@ -1028,13 +1021,12 @@ class LLM(BaseLLM): return self._handle_streaming_response( params, callbacks, available_functions, from_task, from_agent ) - else: - return self._handle_non_streaming_response( - params, callbacks, available_functions, from_task, from_agent - ) + return self._handle_non_streaming_response( + params, callbacks, available_functions, from_task, from_agent + ) - except LLMContextLengthExceededException: - # Re-raise LLMContextLengthExceededException as it should be handled + except LLMContextLengthExceededError: + # Re-raise LLMContextLengthExceededError as it should be handled # by the CrewAgentExecutor._invoke_loop method, which can then decide # whether to summarize the content or abort based on the respect_context_window flag raise @@ -1065,7 +1057,10 @@ class LLM(BaseLLM): from_agent=from_agent, ) - assert hasattr(crewai_event_bus, "emit") + if not hasattr(crewai_event_bus, "emit"): + raise AttributeError( + "crewai_event_bus must have an 'emit' method" + ) from e crewai_event_bus.emit( self, event=LLMCallFailedEvent( @@ -1078,8 +1073,8 @@ class LLM(BaseLLM): self, response: Any, call_type: LLMCallType, - from_task: Optional[Any] = None, - from_agent: Optional[Any] = None, + from_task: Any | None = None, + from_agent: Any | None = None, messages: str | list[dict[str, Any]] | None = None, ): """Handle the events for the LLM call. @@ -1091,7 +1086,8 @@ class LLM(BaseLLM): from_agent: Optional agent object messages: Optional messages object """ - assert hasattr(crewai_event_bus, "emit") + if not hasattr(crewai_event_bus, "emit"): + raise AttributeError("crewai_event_bus must have an 'emit' method") crewai_event_bus.emit( self, event=LLMCallCompletedEvent( @@ -1105,8 +1101,8 @@ class LLM(BaseLLM): ) def _format_messages_for_provider( - self, messages: List[Dict[str, str]] - ) -> List[Dict[str, str]]: + self, messages: list[dict[str, str]] + ) -> list[dict[str, str]]: """Format messages according to provider requirements. Args: @@ -1147,7 +1143,7 @@ class LLM(BaseLLM): if "mistral" in self.model.lower(): # Check if the last message has a role of 'assistant' if messages and messages[-1]["role"] == "assistant": - return messages + [{"role": "user", "content": "Please continue."}] + return [*messages, {"role": "user", "content": "Please continue."}] return messages # TODO: Remove this code after merging PR https://github.com/BerriAI/litellm/pull/10917 @@ -1157,7 +1153,7 @@ class LLM(BaseLLM): and messages and messages[-1]["role"] == "assistant" ): - return messages + [{"role": "user", "content": ""}] + return [*messages, {"role": "user", "content": ""}] # Handle Anthropic models if not self.is_anthropic: @@ -1170,7 +1166,7 @@ class LLM(BaseLLM): return messages - def _get_custom_llm_provider(self) -> Optional[str]: + def _get_custom_llm_provider(self) -> str | None: """ Derives the custom_llm_provider from the model string. - For example, if the model is "openrouter/deepseek/deepseek-chat", returns "openrouter". 
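# Illustrative sketch, not part of this patch: the prefix derivation that the
# _get_custom_llm_provider docstring above describes ("openrouter/deepseek/deepseek-chat"
# -> "openrouter"). The standalone helper name `derive_provider` is hypothetical;
# in the patch this logic lives on LLM and its exact body is not shown here.

def derive_provider(model: str) -> str | None:
    """Return the provider prefix before the first '/', or None if absent."""
    if "/" in model:
        return model.split("/", 1)[0]
    return None


assert derive_provider("openrouter/deepseek/deepseek-chat") == "openrouter"
assert derive_provider("gpt-4o") is None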
@@ -1207,7 +1203,7 @@ class LLM(BaseLLM): self.model, custom_llm_provider=provider ) except Exception as e: - logging.error(f"Failed to check function calling support: {str(e)}") + logging.error(f"Failed to check function calling support: {e!s}") return False def supports_stop_words(self) -> bool: @@ -1215,7 +1211,7 @@ class LLM(BaseLLM): params = get_supported_openai_params(model=self.model) return params is not None and "stop" in params except Exception as e: - logging.error(f"Failed to get supported params: {str(e)}") + logging.error(f"Failed to get supported params: {e!s}") return False def get_context_window_size(self) -> int: @@ -1229,9 +1225,6 @@ class LLM(BaseLLM): if self.context_window_size != 0: return self.context_window_size - MIN_CONTEXT = 1024 - MAX_CONTEXT = 2097152 # Current max from gemini-1.5-pro - # Validate all context window sizes for key, value in LLM_CONTEXT_WINDOW_SIZES.items(): if value < MIN_CONTEXT or value > MAX_CONTEXT: @@ -1247,7 +1240,8 @@ class LLM(BaseLLM): self.context_window_size = int(value * CONTEXT_WINDOW_USAGE_RATIO) return self.context_window_size - def set_callbacks(self, callbacks: List[Any]): + @staticmethod + def set_callbacks(callbacks: list[Any]): """ Attempt to keep a single set of callbacks in litellm by removing old duplicates and adding new ones. @@ -1264,9 +1258,9 @@ class LLM(BaseLLM): litellm.callbacks = callbacks - def set_env_callbacks(self): - """ - Sets the success and failure callbacks for the LiteLLM library from environment variables. + @staticmethod + def set_env_callbacks() -> None: + """Sets the success and failure callbacks for the LiteLLM library from environment variables. This method reads the `LITELLM_SUCCESS_CALLBACKS` and `LITELLM_FAILURE_CALLBACKS` environment variables, which should contain comma-separated lists of callback names. @@ -1276,7 +1270,7 @@ class LLM(BaseLLM): If the environment variables are not set or are empty, the corresponding callback lists will be set to empty lists. 
- Example: + Examples: LITELLM_SUCCESS_CALLBACKS="langfuse,langsmith" LITELLM_FAILURE_CALLBACKS="langfuse" @@ -1285,16 +1279,15 @@ class LLM(BaseLLM): """ with suppress_warnings(): success_callbacks_str = os.environ.get("LITELLM_SUCCESS_CALLBACKS", "") - success_callbacks = [] + success_callbacks: list[str | Callable[..., Any] | CustomLogger] = [] if success_callbacks_str: success_callbacks = [ cb.strip() for cb in success_callbacks_str.split(",") if cb.strip() ] failure_callbacks_str = os.environ.get("LITELLM_FAILURE_CALLBACKS", "") - failure_callbacks = [] if failure_callbacks_str: - failure_callbacks = [ + failure_callbacks: list[str | Callable[..., Any] | CustomLogger] = [ cb.strip() for cb in failure_callbacks_str.split(",") if cb.strip() ] diff --git a/src/crewai/utilities/__init__.py b/src/crewai/utilities/__init__.py index 26d35a6cc..8ca82a1f4 100644 --- a/src/crewai/utilities/__init__.py +++ b/src/crewai/utilities/__init__.py @@ -1,26 +1,24 @@ -from .converter import Converter, ConverterError -from .file_handler import FileHandler -from .i18n import I18N -from .internal_instructor import InternalInstructor -from .logger import Logger -from .parser import YamlParser -from .printer import Printer -from .prompts import Prompts -from .rpm_controller import RPMController -from .exceptions.context_window_exceeding_exception import ( - LLMContextLengthExceededException, +from crewai.utilities.converter import Converter, ConverterError +from crewai.utilities.exceptions.context_window_exceeding_exception import ( + LLMContextLengthExceededError, ) +from crewai.utilities.file_handler import FileHandler +from crewai.utilities.i18n import I18N +from crewai.utilities.internal_instructor import InternalInstructor +from crewai.utilities.logger import Logger +from crewai.utilities.printer import Printer +from crewai.utilities.prompts import Prompts +from crewai.utilities.rpm_controller import RPMController __all__ = [ + "I18N", "Converter", "ConverterError", "FileHandler", - "I18N", "InternalInstructor", + "LLMContextLengthExceededError", "Logger", "Printer", "Prompts", "RPMController", - "YamlParser", - "LLMContextLengthExceededException", ] diff --git a/src/crewai/utilities/agent_utils.py b/src/crewai/utilities/agent_utils.py index 9b2d1df15..5bc2bcb7f 100644 --- a/src/crewai/utilities/agent_utils.py +++ b/src/crewai/utilities/agent_utils.py @@ -1,7 +1,9 @@ +from __future__ import annotations + import json import re from collections.abc import Callable, Sequence -from typing import Any +from typing import TYPE_CHECKING, Any, Final, Literal, TypedDict from rich.console import Console @@ -19,18 +21,47 @@ from crewai.tools import BaseTool as CrewAITool from crewai.tools.base_tool import BaseTool from crewai.tools.structured_tool import CrewStructuredTool from crewai.tools.tool_types import ToolResult -from crewai.utilities import I18N, Printer from crewai.utilities.errors import AgentRepositoryError from crewai.utilities.exceptions.context_window_exceeding_exception import ( - LLMContextLengthExceededException, + LLMContextLengthExceededError, ) +from crewai.utilities.i18n import I18N +from crewai.utilities.printer import ColoredText, Printer +from crewai.utilities.types import LLMMessage + +if TYPE_CHECKING: + from crewai.agent import Agent + from crewai.task import Task + + +class SummaryContent(TypedDict): + """Structure for summary content entries. + + Attributes: + content: The summarized content. 
+ """ + + content: str + console = Console() +_MULTIPLE_NEWLINES: Final[re.Pattern[str]] = re.compile(r"\n+") + def parse_tools(tools: list[BaseTool]) -> list[CrewStructuredTool]: - """Parse tools to be used for the task.""" - tools_list = [] + """Parse tools to be used for the task. + + Args: + tools: List of tools to parse. + + Returns: + List of structured tools. + + Raises: + ValueError: If a tool is not a CrewStructuredTool or BaseTool. + """ + tools_list: list[CrewStructuredTool] = [] for tool in tools: if isinstance(tool, CrewAITool): @@ -42,7 +73,14 @@ def parse_tools(tools: list[BaseTool]) -> list[CrewStructuredTool]: def get_tool_names(tools: Sequence[CrewStructuredTool | BaseTool]) -> str: - """Get the names of the tools.""" + """Get the names of the tools. + + Args: + tools: List of tools to get names from. + + Returns: + Comma-separated string of tool names. + """ return ", ".join([t.name for t in tools]) @@ -51,16 +89,30 @@ def render_text_description_and_args( ) -> str: """Render the tool name, description, and args in plain text. - search: This tool is used for search, args: {"query": {"type": "string"}} - calculator: This tool is used for math, \ - args: {"expression": {"type": "string"}} + search: This tool is used for search, args: {"query": {"type": "string"}} + calculator: This tool is used for math, \ + args: {"expression": {"type": "string"}} + + Args: + tools: List of tools to render. + + Returns: + Plain text description of tools. """ tool_strings = [tool.description for tool in tools] return "\n".join(tool_strings) def has_reached_max_iterations(iterations: int, max_iterations: int) -> bool: - """Check if the maximum number of iterations has been reached.""" + """Check if the maximum number of iterations has been reached. + + Args: + iterations: Current number of iterations. + max_iterations: Maximum allowed iterations. + + Returns: + True if maximum iterations reached, False otherwise. + """ return iterations >= max_iterations @@ -68,16 +120,19 @@ def handle_max_iterations_exceeded( formatted_answer: AgentAction | AgentFinish | None, printer: Printer, i18n: I18N, - messages: list[dict[str, str]], + messages: list[LLMMessage], llm: LLM | BaseLLM, - callbacks: list[Any], + callbacks: list[Callable[..., Any]], ) -> AgentAction | AgentFinish: - """ - Handles the case when the maximum number of iterations is exceeded. - Performs one more LLM call to get the final answer. + """Handles the case when the maximum number of iterations is exceeded. Performs one more LLM call to get the final answer. - Parameters: + Args: formatted_answer: The last formatted answer from the agent. + printer: Printer instance for output. + i18n: I18N instance for internationalization. + messages: List of messages to send to the LLM. + llm: The LLM instance to call. + callbacks: List of callbacks for the LLM call. Returns: The final formatted answer after exceeding max iterations. 
@@ -98,7 +153,7 @@ def handle_max_iterations_exceeded( # Perform one more LLM call to get the final answer answer = llm.call( - messages, + messages, # type: ignore[arg-type] callbacks=callbacks, ) @@ -110,20 +165,38 @@ def handle_max_iterations_exceeded( raise ValueError("Invalid response from LLM call - None or empty.") # Return the formatted answer, regardless of its type - return format_answer(answer) + return format_answer(answer=answer) -def format_message_for_llm(prompt: str, role: str = "user") -> dict[str, str]: +def format_message_for_llm( + prompt: str, role: Literal["user", "assistant", "system"] = "user" +) -> LLMMessage: + """Format a message for the LLM. + + Args: + prompt: The message content. + role: The role of the message sender, either 'user' or 'assistant'. + + Returns: + A dictionary with 'role' and 'content' keys. + + """ prompt = prompt.rstrip() return {"role": role, "content": prompt} def format_answer(answer: str) -> AgentAction | AgentFinish: - """Format a response from the LLM into an AgentAction or AgentFinish.""" + """Format a response from the LLM into an AgentAction or AgentFinish. + + Args: + answer: The raw response from the LLM + + Returns: + Either an AgentAction or AgentFinish + """ try: return parse(answer) except Exception: - # If parsing fails, return a default AgentFinish return AgentFinish( thought="Failed to parse LLM response", output=answer, @@ -134,23 +207,43 @@ def format_answer(answer: str) -> AgentAction | AgentFinish: def enforce_rpm_limit( request_within_rpm_limit: Callable[[], bool] | None = None, ) -> None: - """Enforce the requests per minute (RPM) limit if applicable.""" + """Enforce the requests per minute (RPM) limit if applicable. + + Args: + request_within_rpm_limit: Function to enforce RPM limit. + """ if request_within_rpm_limit: request_within_rpm_limit() def get_llm_response( llm: LLM | BaseLLM, - messages: list[dict[str, str]], - callbacks: list[Any], + messages: list[LLMMessage], + callbacks: list[Callable[..., Any]], printer: Printer, - from_task: Any | None = None, - from_agent: Any | None = None, + from_task: Task | None = None, + from_agent: Agent | None = None, ) -> str: - """Call the LLM and return the response, handling any invalid responses.""" + """Call the LLM and return the response, handling any invalid responses. + + Args: + llm: The LLM instance to call + messages: The messages to send to the LLM + callbacks: List of callbacks for the LLM call + printer: Printer instance for output + from_task: Optional task context for the LLM call + from_agent: Optional agent context for the LLM call + + Returns: + The response from the LLM as a string + + Raises: + Exception: If an error occurs. + ValueError: If the response is None or empty. + """ try: answer = llm.call( - messages, + messages, # type: ignore[arg-type] callbacks=callbacks, from_task=from_task, from_agent=from_agent, @@ -170,7 +263,15 @@ def get_llm_response( def process_llm_response( answer: str, use_stop_words: bool ) -> AgentAction | AgentFinish: - """Process the LLM response and format it into an AgentAction or AgentFinish.""" + """Process the LLM response and format it into an AgentAction or AgentFinish. + + Args: + answer: The raw response from the LLM + use_stop_words: Whether to use stop words in the LLM call + + Returns: + Either an AgentAction or AgentFinish + """ if not use_stop_words: try: # Preliminary parsing to check for errors. 
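# Illustrative usage sketch, not part of this patch: format_message_for_llm as
# re-typed in the hunk above. The role parameter is now a Literal, so a type
# checker flags unsupported roles, and trailing whitespace is stripped from the
# prompt. The LLMMessage return type is stood in for by a plain dict here.
from typing import Literal


def format_message_for_llm(
    prompt: str, role: Literal["user", "assistant", "system"] = "user"
) -> dict[str, str]:
    prompt = prompt.rstrip()
    return {"role": role, "content": prompt}


system_msg = format_message_for_llm("You are a helpful assistant.\n", role="system")
user_msg = format_message_for_llm("Summarize the report. ")

assert system_msg == {"role": "system", "content": "You are a helpful assistant."}
assert user_msg == {"role": "user", "content": "Summarize the report."}
# format_message_for_llm("hi", role="tool")  # would be rejected by a type checker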
@@ -200,6 +301,9 @@ def handle_agent_action_core( Returns: Either an AgentAction or AgentFinish + + Notes: + - TODO: Remove messages parameter and its usage. """ if step_callback: step_callback(tool_result) @@ -220,7 +324,7 @@ def handle_agent_action_core( return formatted_answer -def handle_unknown_error(printer: Any, exception: Exception) -> None: +def handle_unknown_error(printer: Printer, exception: Exception) -> None: """Handle unknown errors by informing the user. Args: @@ -244,10 +348,10 @@ def handle_unknown_error(printer: Any, exception: Exception) -> None: def handle_output_parser_exception( e: OutputParserError, - messages: list[dict[str, str]], + messages: list[LLMMessage], iterations: int, log_error_after: int = 3, - printer: Any | None = None, + printer: Printer | None = None, ) -> AgentAction: """Handle OutputParserError by updating messages and formatted_answer. @@ -288,18 +392,18 @@ def is_context_length_exceeded(exception: Exception) -> bool: Returns: bool: True if the exception is due to context length exceeding """ - return LLMContextLengthExceededException(str(exception))._is_context_limit_error( + return LLMContextLengthExceededError(str(exception))._is_context_limit_error( str(exception) ) def handle_context_length( respect_context_window: bool, - printer: Any, - messages: list[dict[str, str]], - llm: Any, - callbacks: list[Any], - i18n: Any, + printer: Printer, + messages: list[LLMMessage], + llm: LLM | BaseLLM, + callbacks: list[Callable[..., Any]], + i18n: I18N, ) -> None: """Handle context length exceeded by either summarizing or raising an error. @@ -310,13 +414,16 @@ def handle_context_length( llm: LLM instance for summarization callbacks: List of callbacks for LLM i18n: I18N instance for messages + + Raises: + SystemExit: If context length is exceeded and user opts not to summarize """ if respect_context_window: printer.print( content="Context length exceeded. Summarizing content to fit the model context window. Might take a while...", color="yellow", ) - summarize_messages(messages, llm, callbacks, i18n) + summarize_messages(messages=messages, llm=llm, callbacks=callbacks, i18n=i18n) else: printer.print( content="Context length exceeded. Consider using smaller text or RAG tools from crewai_tools.", @@ -328,10 +435,10 @@ def handle_context_length( def summarize_messages( - messages: list[dict[str, str]], - llm: Any, - callbacks: list[Any], - i18n: Any, + messages: list[LLMMessage], + llm: LLM | BaseLLM, + callbacks: list[Callable[..., Any]], + i18n: I18N, ) -> None: """Summarize messages to fit within context window. 
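# Illustrative sketch, not part of this patch: the chunking step that
# summarize_messages performs before summarizing - the concatenated message text
# is sliced into fixed-size groups and each group is summarized in its own LLM
# call. The cut_size value below is an arbitrary stand-in; the real one is
# derived from the model's context window.

def split_into_groups(messages_string: str, cut_size: int) -> list[str]:
    return [
        messages_string[i : i + cut_size]
        for i in range(0, len(messages_string), cut_size)
    ]


groups = split_into_groups("x" * 2500, cut_size=1000)
assert [len(g) for g in groups] == [1000, 1000, 500]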
@@ -349,7 +456,7 @@ def summarize_messages( for i in range(0, len(messages_string), cut_size) ] - summarized_contents = [] + summarized_contents: list[SummaryContent] = [] total_groups = len(messages_groups) for idx, group in enumerate(messages_groups, 1): @@ -357,15 +464,17 @@ def summarize_messages( content=f"Summarizing {idx}/{total_groups}...", color="yellow", ) + + messages = [ + format_message_for_llm( + i18n.slice("summarizer_system_message"), role="system" + ), + format_message_for_llm( + i18n.slice("summarize_instruction").format(group=group["content"]), + ), + ] summary = llm.call( - [ - format_message_for_llm( - i18n.slice("summarizer_system_message"), role="system" - ), - format_message_for_llm( - i18n.slice("summarize_instruction").format(group=group["content"]), - ), - ], + messages, # type: ignore[arg-type] callbacks=callbacks, ) summarized_contents.append({"content": str(summary)}) @@ -404,20 +513,29 @@ def show_agent_logs( if formatted_answer is None: # Start logs printer.print( - content=f"\033[1m\033[95m# Agent:\033[00m \033[1m\033[92m{agent_role}\033[00m" + content=[ + ColoredText("# Agent: ", "bold_purple"), + ColoredText(agent_role, "bold_green"), + ] ) if task_description: printer.print( - content=f"\033[95m## Task:\033[00m \033[92m{task_description}\033[00m" + content=[ + ColoredText("## Task: ", "purple"), + ColoredText(task_description, "green"), + ] ) else: # Execution logs printer.print( - content=f"\n\n\033[1m\033[95m# Agent:\033[00m \033[1m\033[92m{agent_role}\033[00m" + content=[ + ColoredText("\n\n# Agent: ", "bold_purple"), + ColoredText(agent_role, "bold_green"), + ] ) if isinstance(formatted_answer, AgentAction): - thought = re.sub(r"\n+", "\n", formatted_answer.thought) + thought = _MULTIPLE_NEWLINES.sub("\n", formatted_answer.thought) formatted_json = json.dumps( formatted_answer.tool_input, indent=2, @@ -425,24 +543,39 @@ def show_agent_logs( ) if thought and thought != "": printer.print( - content=f"\033[95m## Thought:\033[00m \033[92m{thought}\033[00m" + content=[ + ColoredText("## Thought: ", "purple"), + ColoredText(thought, "green"), + ] ) printer.print( - content=f"\033[95m## Using tool:\033[00m \033[92m{formatted_answer.tool}\033[00m" + content=[ + ColoredText("## Using tool: ", "purple"), + ColoredText(formatted_answer.tool, "green"), + ] ) printer.print( - content=f"\033[95m## Tool Input:\033[00m \033[92m\n{formatted_json}\033[00m" + content=[ + ColoredText("## Tool Input: ", "purple"), + ColoredText(f"\n{formatted_json}", "green"), + ] ) printer.print( - content=f"\033[95m## Tool Output:\033[00m \033[92m\n{formatted_answer.result}\033[00m" + content=[ + ColoredText("## Tool Output: ", "purple"), + ColoredText(f"\n{formatted_answer.result}", "green"), + ] ) elif isinstance(formatted_answer, AgentFinish): printer.print( - content=f"\033[95m## Final Answer:\033[00m \033[92m\n{formatted_answer.output}\033[00m\n\n" + content=[ + ColoredText("## Final Answer: ", "purple"), + ColoredText(f"\n{formatted_answer.output}\n\n", "green"), + ] ) -def _print_current_organization(): +def _print_current_organization() -> None: settings = Settings() if settings.org_uuid: console.print( @@ -457,6 +590,17 @@ def _print_current_organization(): def load_agent_from_repository(from_repository: str) -> dict[str, Any]: + """Load an agent from the repository. + + Args: + from_repository: The name of the agent to load. + + Returns: + A dictionary of attributes to use for the agent. + + Raises: + AgentRepositoryError: If the agent cannot be loaded. 
+ """ attributes: dict[str, Any] = {} if from_repository: import importlib diff --git a/src/crewai/utilities/config.py b/src/crewai/utilities/config.py index 156a3e66b..95a542c5e 100644 --- a/src/crewai/utilities/config.py +++ b/src/crewai/utilities/config.py @@ -1,20 +1,19 @@ -from typing import Any, Dict, Type +from typing import Any from pydantic import BaseModel def process_config( - values: Dict[str, Any], model_class: Type[BaseModel] -) -> Dict[str, Any]: - """ - Process the config dictionary and update the values accordingly. + values: dict[str, Any], model_class: type[BaseModel] +) -> dict[str, Any]: + """Process the config dictionary and update the values accordingly. Args: - values (Dict[str, Any]): The dictionary of values to update. - model_class (Type[BaseModel]): The Pydantic model class to reference for field validation. + values: The dictionary of values to update. + model_class: The Pydantic model class to reference for field validation. Returns: - Dict[str, Any]: The updated values dictionary. + The updated values dictionary. """ config = values.get("config", {}) if not config: diff --git a/src/crewai/utilities/constants.py b/src/crewai/utilities/constants.py index 184ea6e9b..c1d808a32 100644 --- a/src/crewai/utilities/constants.py +++ b/src/crewai/utilities/constants.py @@ -1,19 +1,32 @@ -TRAINING_DATA_FILE = "training_data.pkl" -TRAINED_AGENTS_DATA_FILE = "trained_agents_data.pkl" -DEFAULT_SCORE_THRESHOLD = 0.35 -KNOWLEDGE_DIRECTORY = "knowledge" -MAX_LLM_RETRY = 3 -MAX_FILE_NAME_LENGTH = 255 -EMITTER_COLOR = "bold_blue" +from typing import Annotated, Final + +from crewai.utilities.printer import PrinterColor + +TRAINING_DATA_FILE: Final[str] = "training_data.pkl" +TRAINED_AGENTS_DATA_FILE: Final[str] = "trained_agents_data.pkl" +KNOWLEDGE_DIRECTORY: Final[str] = "knowledge" +MAX_FILE_NAME_LENGTH: Final[int] = 255 +EMITTER_COLOR: Final[PrinterColor] = "bold_blue" class _NotSpecified: - def __repr__(self): + """Sentinel class to detect when no value has been explicitly provided. + + Notes: + - TODO: Consider moving this class and NOT_SPECIFIED to types.py + as they are more type-related constructs than business constants. + """ + + def __repr__(self) -> str: return "NOT_SPECIFIED" -# Sentinel value used to detect when no value has been explicitly provided. -# Unlike `None`, which might be a valid value from the user, `NOT_SPECIFIED` allows -# us to distinguish between "not passed at all" and "explicitly passed None" or "[]". -NOT_SPECIFIED = _NotSpecified() -CREWAI_BASE_URL = "https://app.crewai.com" +NOT_SPECIFIED: Final[ + Annotated[ + _NotSpecified, + "Sentinel value used to detect when no value has been explicitly provided. 
" + "Unlike `None`, which might be a valid value from the user, `NOT_SPECIFIED` " + "allows us to distinguish between 'not passed at all' and 'explicitly passed None' or '[]'.", + ] +] = _NotSpecified() +CREWAI_BASE_URL: Final[str] = "https://app.crewai.com" diff --git a/src/crewai/utilities/converter.py b/src/crewai/utilities/converter.py index a6144868e..8e8aa94af 100644 --- a/src/crewai/utilities/converter.py +++ b/src/crewai/utilities/converter.py @@ -1,18 +1,35 @@ +from __future__ import annotations + import json import re -from typing import Any, Optional, Type, Union, get_args, get_origin +from typing import TYPE_CHECKING, Any, Final, TypedDict, Union, get_args, get_origin from pydantic import BaseModel, ValidationError +from typing_extensions import Unpack from crewai.agents.agent_builder.utilities.base_output_converter import OutputConverter +from crewai.utilities.internal_instructor import InternalInstructor from crewai.utilities.printer import Printer from crewai.utilities.pydantic_schema_parser import PydanticSchemaParser +if TYPE_CHECKING: + from crewai.agent import Agent + from crewai.llm import LLM + from crewai.llms.base_llm import BaseLLM + +_JSON_PATTERN: Final[re.Pattern[str]] = re.compile(r"({.*})", re.DOTALL) + class ConverterError(Exception): """Error raised when Converter fails to parse the input.""" def __init__(self, message: str, *args: object) -> None: + """Initialize the ConverterError with a message. + + Args: + message: The error message. + *args: Additional arguments for the base Exception class. + """ super().__init__(message, *args) self.message = message @@ -20,8 +37,18 @@ class ConverterError(Exception): class Converter(OutputConverter): """Class that converts text into either pydantic or json.""" - def to_pydantic(self, current_attempt=1) -> BaseModel: - """Convert text to pydantic.""" + def to_pydantic(self, current_attempt: int = 1) -> BaseModel: + """Convert text to pydantic. + + Args: + current_attempt: The current attempt number for conversion retries. + + Returns: + A Pydantic BaseModel instance. + + Raises: + ConverterError: If conversion fails after maximum attempts. + """ try: if self.llm.supports_function_calling(): result = self._create_instructor().to_pydantic() @@ -37,104 +64,124 @@ class Converter(OutputConverter): result = self.model.model_validate_json(response) except ValidationError: # If direct validation fails, attempt to extract valid JSON - result = handle_partial_json(response, self.model, False, None) + result = handle_partial_json( + result=response, + model=self.model, + is_json_output=False, + agent=None, + ) # Ensure result is a BaseModel instance if not isinstance(result, BaseModel): if isinstance(result, dict): - result = self.model.parse_obj(result) + result = self.model.model_validate(result) elif isinstance(result, str): try: parsed = json.loads(result) - result = self.model.parse_obj(parsed) + result = self.model.model_validate(parsed) except Exception as parse_err: raise ConverterError( f"Failed to convert partial JSON result into Pydantic: {parse_err}" - ) + ) from parse_err else: raise ConverterError( "handle_partial_json returned an unexpected type." 
- ) + ) from None return result except ValidationError as e: if current_attempt < self.max_attempts: return self.to_pydantic(current_attempt + 1) raise ConverterError( f"Failed to convert text into a Pydantic model due to validation error: {e}" - ) + ) from e except Exception as e: if current_attempt < self.max_attempts: return self.to_pydantic(current_attempt + 1) raise ConverterError( f"Failed to convert text into a Pydantic model due to error: {e}" - ) + ) from e - def to_json(self, current_attempt=1): - """Convert text to json.""" + def to_json(self, current_attempt: int = 1) -> str | ConverterError | Any: # type: ignore[override] + """Convert text to json. + + Args: + current_attempt: The current attempt number for conversion retries. + + Returns: + A JSON string or ConverterError if conversion fails. + + Raises: + ConverterError: If conversion fails after maximum attempts. + + """ try: if self.llm.supports_function_calling(): return self._create_instructor().to_json() - else: - return json.dumps( - self.llm.call( - [ - {"role": "system", "content": self.instructions}, - {"role": "user", "content": self.text}, - ] - ) + return json.dumps( + self.llm.call( + [ + {"role": "system", "content": self.instructions}, + {"role": "user", "content": self.text}, + ] ) + ) except Exception as e: if current_attempt < self.max_attempts: return self.to_json(current_attempt + 1) return ConverterError(f"Failed to convert text into JSON, error: {e}.") - def _create_instructor(self): + def _create_instructor(self) -> InternalInstructor: """Create an instructor.""" - from crewai.utilities import InternalInstructor - inst = InternalInstructor( + return InternalInstructor( llm=self.llm, model=self.model, content=self.text, ) - return inst - - def _convert_with_instructions(self): - """Create a chain.""" - from crewai.utilities.crew_pydantic_output_parser import ( - CrewPydanticOutputParser, - ) - - parser = CrewPydanticOutputParser(pydantic_object=self.model) - result = self.llm.call( - [ - {"role": "system", "content": self.instructions}, - {"role": "user", "content": self.text}, - ] - ) - return parser.parse_result(result) def convert_to_model( result: str, - output_pydantic: Optional[Type[BaseModel]], - output_json: Optional[Type[BaseModel]], - agent: Any, - converter_cls: Optional[Type[Converter]] = None, -) -> Union[dict, BaseModel, str]: + output_pydantic: type[BaseModel] | None, + output_json: type[BaseModel] | None, + agent: Agent | None = None, + converter_cls: type[Converter] | None = None, +) -> dict[str, Any] | BaseModel | str: + """Convert a result string to a Pydantic model or JSON. + + Args: + result: The result string to convert. + output_pydantic: The Pydantic model class to convert to. + output_json: The Pydantic model class to convert to JSON. + agent: The agent instance. + converter_cls: The converter class to use. + + Returns: + The converted result as a dict, BaseModel, or original string. 
+ """ model = output_pydantic or output_json if model is None: return result try: escaped_result = json.dumps(json.loads(result, strict=False)) - return validate_model(escaped_result, model, bool(output_json)) + return validate_model( + result=escaped_result, model=model, is_json_output=bool(output_json) + ) except json.JSONDecodeError: return handle_partial_json( - result, model, bool(output_json), agent, converter_cls + result=result, + model=model, + is_json_output=bool(output_json), + agent=agent, + converter_cls=converter_cls, ) except ValidationError: return handle_partial_json( - result, model, bool(output_json), agent, converter_cls + result=result, + model=model, + is_json_output=bool(output_json), + agent=agent, + converter_cls=converter_cls, ) except Exception as e: @@ -146,8 +193,18 @@ def convert_to_model( def validate_model( - result: str, model: Type[BaseModel], is_json_output: bool -) -> Union[dict, BaseModel]: + result: str, model: type[BaseModel], is_json_output: bool +) -> dict[str, Any] | BaseModel: + """Validate and convert a JSON string to a Pydantic model or dict. + + Args: + result: The JSON string to validate and convert. + model: The Pydantic model class to convert to. + is_json_output: Whether to return a dict (True) or Pydantic model (False). + + Returns: + The converted result as a dict or BaseModel. + """ exported_result = model.model_validate_json(result) if is_json_output: return exported_result.model_dump() @@ -156,15 +213,27 @@ def validate_model( def handle_partial_json( result: str, - model: Type[BaseModel], + model: type[BaseModel], is_json_output: bool, - agent: Any, - converter_cls: Optional[Type[Converter]] = None, -) -> Union[dict, BaseModel, str]: - match = re.search(r"({.*})", result, re.DOTALL) + agent: Agent | None, + converter_cls: type[Converter] | None = None, +) -> dict[str, Any] | BaseModel | str: + """Handle partial JSON in a result string and convert to Pydantic model or dict. + + Args: + result: The result string to process. + model: The Pydantic model class to convert to. + is_json_output: Whether to return a dict (True) or Pydantic model (False). + agent: The agent instance. + converter_cls: The converter class to use. + + Returns: + The converted result as a dict, BaseModel, or original string. + """ + match = _JSON_PATTERN.search(result) if match: try: - exported_result = model.model_validate_json(match.group(0)) + exported_result = model.model_validate_json(match.group()) if is_json_output: return exported_result.model_dump() return exported_result @@ -179,19 +248,43 @@ def handle_partial_json( ) return convert_with_instructions( - result, model, is_json_output, agent, converter_cls + result=result, + model=model, + is_json_output=is_json_output, + agent=agent, + converter_cls=converter_cls, ) def convert_with_instructions( result: str, - model: Type[BaseModel], + model: type[BaseModel], is_json_output: bool, - agent: Any, - converter_cls: Optional[Type[Converter]] = None, -) -> Union[dict, BaseModel, str]: + agent: Agent | None, + converter_cls: type[Converter] | None = None, +) -> dict | BaseModel | str: + """Convert a result string to a Pydantic model or JSON using instructions. + + Args: + result: The result string to convert. + model: The Pydantic model class to convert to. + is_json_output: Whether to return a dict (True) or Pydantic model (False). + agent: The agent instance. + converter_cls: The converter class to use. + + Returns: + The converted result as a dict, BaseModel, or original string. 
+ + Raises: + TypeError: If neither agent nor converter_cls is provided. + + Notes: + - TODO: Fix llm typing issues, return llm should not be able to be str or None. + """ + if agent is None: + raise TypeError("Agent must be provided if converter_cls is not specified.") llm = agent.function_calling_llm or agent.llm - instructions = get_conversion_instructions(model, llm) + instructions = get_conversion_instructions(model=model, llm=llm) converter = create_converter( agent=agent, converter_cls=converter_cls, @@ -214,9 +307,25 @@ def convert_with_instructions( return exported_result -def get_conversion_instructions(model: Type[BaseModel], llm: Any) -> str: +def get_conversion_instructions( + model: type[BaseModel], llm: BaseLLM | LLM | str +) -> str: + """Generate conversion instructions based on the model and LLM capabilities. + + Args: + model: A Pydantic model class. + llm: The language model instance. + + Returns: + + """ instructions = "Please convert the following text into valid JSON." - if llm and not isinstance(llm, str) and llm.supports_function_calling(): + if ( + llm + and not isinstance(llm, str) + and hasattr(llm, "supports_function_calling") + and llm.supports_function_calling() + ): model_schema = PydanticSchemaParser(model=model).get_schema() instructions += ( f"\n\nOutput ONLY the valid JSON and nothing else.\n\n" @@ -231,12 +340,45 @@ def get_conversion_instructions(model: Type[BaseModel], llm: Any) -> str: return instructions +class CreateConverterKwargs(TypedDict, total=False): + """Keyword arguments for creating a converter. + + Attributes: + llm: The language model instance. + text: The text to convert. + model: The Pydantic model class. + instructions: The conversion instructions. + """ + + llm: BaseLLM | LLM | str + text: str + model: type[BaseModel] + instructions: str + + def create_converter( - agent: Optional[Any] = None, - converter_cls: Optional[Type[Converter]] = None, - *args, - **kwargs, + agent: Agent | None = None, + converter_cls: type[Converter] | None = None, + *args: Any, + **kwargs: Unpack[CreateConverterKwargs], ) -> Converter: + """Create a converter instance based on the agent or provided class. + + Args: + agent: The agent instance. + converter_cls: The converter class to instantiate. + *args: The positional arguments to pass to the converter. + **kwargs: The keyword arguments to pass to the converter. + + Returns: + An instance of the specified converter class. + + Raises: + ValueError: If neither agent nor converter_cls is provided. + AttributeError: If the agent does not have a 'get_output_converter' method. + Exception: If no converter instance is created. + + """ if agent and not converter_cls: if hasattr(agent, "get_output_converter"): converter = agent.get_output_converter(*args, **kwargs) @@ -253,17 +395,30 @@ def create_converter( return converter -def generate_model_description(model: Type[BaseModel]) -> str: - """ - Generate a string description of a Pydantic model's fields and their types. +def generate_model_description(model: type[BaseModel]) -> str: + """Generate a string description of a Pydantic model's fields and their types. This function takes a Pydantic model class and returns a string that describes the model's fields and their respective types. The description includes handling of complex types such as `Optional`, `List`, and `Dict`, as well as nested Pydantic models. + + Args: + model: A Pydantic model class. + + Returns: + A string representation of the model's fields and types. 
""" - def describe_field(field_type): + def describe_field(field_type: Any) -> str: + """Recursively describe a field's type. + + Args: + field_type: The type of the field to describe. + + Returns: + A string representation of the field's type. + """ origin = get_origin(field_type) args = get_args(field_type) @@ -272,20 +427,18 @@ def generate_model_description(model: Type[BaseModel]) -> str: non_none_args = [arg for arg in args if arg is not type(None)] if len(non_none_args) == 1: return f"Optional[{describe_field(non_none_args[0])}]" - else: - return f"Optional[Union[{', '.join(describe_field(arg) for arg in non_none_args)}]]" - elif origin is list: + return f"Optional[Union[{', '.join(describe_field(arg) for arg in non_none_args)}]]" + if origin is list: return f"List[{describe_field(args[0])}]" - elif origin is dict: + if origin is dict: key_type = describe_field(args[0]) value_type = describe_field(args[1]) return f"Dict[{key_type}, {value_type}]" - elif isinstance(field_type, type) and issubclass(field_type, BaseModel): + if isinstance(field_type, type) and issubclass(field_type, BaseModel): return generate_model_description(field_type) - elif hasattr(field_type, "__name__"): + if hasattr(field_type, "__name__"): return field_type.__name__ - else: - return str(field_type) + return str(field_type) fields = model.model_fields field_descriptions = [ diff --git a/src/crewai/utilities/crew/__init__.py b/src/crewai/utilities/crew/__init__.py index db74f269b..b51db2d0f 100644 --- a/src/crewai/utilities/crew/__init__.py +++ b/src/crewai/utilities/crew/__init__.py @@ -1 +1 @@ -"""Crew-specific utilities.""" \ No newline at end of file +"""Crew-specific utilities.""" diff --git a/src/crewai/utilities/crew/crew_context.py b/src/crewai/utilities/crew/crew_context.py index 3f287b566..bb97194d4 100644 --- a/src/crewai/utilities/crew/crew_context.py +++ b/src/crewai/utilities/crew/crew_context.py @@ -1,16 +1,16 @@ """Context management utilities for tracking crew and task execution context using OpenTelemetry baggage.""" -from typing import Optional +from typing import cast from opentelemetry import baggage from crewai.utilities.crew.models import CrewContext -def get_crew_context() -> Optional[CrewContext]: +def get_crew_context() -> CrewContext | None: """Get the current crew context from OpenTelemetry baggage. Returns: CrewContext instance containing crew context information, or None if no context is set """ - return baggage.get_baggage("crew_context") + return cast(CrewContext | None, baggage.get_baggage("crew_context")) diff --git a/src/crewai/utilities/crew/models.py b/src/crewai/utilities/crew/models.py index 78a1f33a6..d18faee1f 100644 --- a/src/crewai/utilities/crew/models.py +++ b/src/crewai/utilities/crew/models.py @@ -1,16 +1,17 @@ """Models for crew-related data structures.""" -from typing import Optional - from pydantic import BaseModel, Field class CrewContext(BaseModel): - """Model representing crew context information.""" + """Model representing crew context information. - id: Optional[str] = Field( - default=None, description="Unique identifier for the crew" - ) - key: Optional[str] = Field( + Attributes: + id: Unique identifier for the crew. + key: Optional crew key/name for identification. 
+ """ + + id: str | None = Field(default=None, description="Unique identifier for the crew") + key: str | None = Field( default=None, description="Optional crew key/name for identification" ) diff --git a/src/crewai/utilities/crew_json_encoder.py b/src/crewai/utilities/crew_json_encoder.py index 6e667431d..745340a7b 100644 --- a/src/crewai/utilities/crew_json_encoder.py +++ b/src/crewai/utilities/crew_json_encoder.py @@ -4,6 +4,7 @@ import json from datetime import date, datetime from decimal import Decimal from enum import Enum +from typing import Any from uuid import UUID from pydantic import BaseModel @@ -11,18 +12,28 @@ from pydantic import BaseModel class CrewJSONEncoder(json.JSONEncoder): """Custom JSON encoder for CrewAI objects and special types.""" - def default(self, obj): + + def default(self, obj: Any) -> Any: + """Custom serialization for CrewAI specific types. + + Args: + obj: The object to serialize. + + Returns: + A JSON-serializable representation of the object. + """ if isinstance(obj, BaseModel): return self._handle_pydantic_model(obj) - elif isinstance(obj, UUID) or isinstance(obj, Decimal) or isinstance(obj, Enum): + if isinstance(obj, (UUID, Decimal, Enum)): return str(obj) - elif isinstance(obj, datetime) or isinstance(obj, date): + if isinstance(obj, (datetime, date)): return obj.isoformat() return super().default(obj) - def _handle_pydantic_model(self, obj): + @staticmethod + def _handle_pydantic_model(obj: BaseModel) -> str | Any: try: data = obj.model_dump() # Remove circular references diff --git a/src/crewai/utilities/crew_pydantic_output_parser.py b/src/crewai/utilities/crew_pydantic_output_parser.py deleted file mode 100644 index c40bf679b..000000000 --- a/src/crewai/utilities/crew_pydantic_output_parser.py +++ /dev/null @@ -1,48 +0,0 @@ -import json -from typing import Any - -import regex -from pydantic import BaseModel, ValidationError - -from crewai.agents.parser import OutputParserError - -"""Parser for converting text outputs into Pydantic models.""" - - -class CrewPydanticOutputParser: - """Parses text outputs into specified Pydantic models.""" - - pydantic_object: type[BaseModel] - - def parse_result(self, result: str) -> Any: - result = self._transform_in_valid_json(result) - - # Treating edge case of function calling llm returning the name instead of tool_name - json_object = json.loads(result) - if "tool_name" not in json_object: - json_object["tool_name"] = json_object.get("name", "") - result = json.dumps(json_object) - - try: - return self.pydantic_object.model_validate(json_object) - except ValidationError as e: - name = self.pydantic_object.__name__ - msg = f"Failed to parse {name} from completion {json_object}. 
Got: {e}" - raise OutputParserError(error=msg) from e - - def _transform_in_valid_json(self, text) -> str: - text = text.replace("```", "").replace("json", "") - json_pattern = r"\{(?:[^{}]|(?R))*\}" - matches = regex.finditer(json_pattern, text) - - for match in matches: - try: - # Attempt to parse the matched string as JSON - json_obj = json.loads(match.group()) - # Return the first successfully parsed JSON object - json_obj = json.dumps(json_obj) - return str(json_obj) - except json.JSONDecodeError: # noqa: PERF203 - # If parsing fails, skip to the next match - continue - return text diff --git a/src/crewai/utilities/errors.py b/src/crewai/utilities/errors.py index e9aa40872..6bf6ac1fd 100644 --- a/src/crewai/utilities/errors.py +++ b/src/crewai/utilities/errors.py @@ -8,7 +8,11 @@ from typing import Final class DatabaseOperationError(Exception): - """Base exception class for database operation errors.""" + """Base exception class for database operation errors. + + Attributes: + original_error: The original exception that caused this error, if any. + """ def __init__(self, message: str, original_error: Exception | None = None) -> None: """Initialize the database operation error. diff --git a/src/crewai/utilities/evaluators/crew_evaluator_handler.py b/src/crewai/utilities/evaluators/crew_evaluator_handler.py index c10946494..47c2bb100 100644 --- a/src/crewai/utilities/evaluators/crew_evaluator_handler.py +++ b/src/crewai/utilities/evaluators/crew_evaluator_handler.py @@ -1,4 +1,7 @@ +from __future__ import annotations + from collections import defaultdict +from typing import TYPE_CHECKING from pydantic import BaseModel, Field, InstanceOf from rich.box import HEAVY_EDGE @@ -6,11 +9,14 @@ from rich.console import Console from rich.table import Table from crewai.agent import Agent -from crewai.llm import BaseLLM -from crewai.task import Task -from crewai.tasks.task_output import TaskOutput from crewai.events.event_bus import crewai_event_bus from crewai.events.types.crew_events import CrewTestResultEvent +from crewai.llms.base_llm import BaseLLM +from crewai.task import Task +from crewai.tasks.task_output import TaskOutput + +if TYPE_CHECKING: + from crewai.crew import Crew class TaskEvaluationPydanticOutput(BaseModel): @@ -20,23 +26,21 @@ class TaskEvaluationPydanticOutput(BaseModel): class CrewEvaluator: - """ - A class to evaluate the performance of the agents in the crew based on the tasks they have performed. + """A class to evaluate the performance of the agents in the crew based on the tasks they have performed. Attributes: - crew (Crew): The crew of agents to evaluate. - eval_llm (BaseLLM): Language model instance to use for evaluations - tasks_scores (defaultdict): A dictionary to store the scores of the agents for each task. - iteration (int): The current iteration of the evaluation. + crew: The crew of agents to evaluate. + tasks_scores: A dictionary to store the scores of the agents for each task. + run_execution_times: A dictionary to store execution times for each run. + iteration: The current iteration of the evaluation. 
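+        llm: The language model used to run the evaluation tasks.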
""" - tasks_scores: defaultdict = defaultdict(list) - run_execution_times: defaultdict = defaultdict(list) - iteration: int = 0 - - def __init__(self, crew, eval_llm: InstanceOf[BaseLLM]): + def __init__(self, crew: Crew, eval_llm: InstanceOf[BaseLLM]) -> None: self.crew = crew self.llm = eval_llm + self.tasks_scores: defaultdict[int, list[float]] = defaultdict(list) + self.run_execution_times: defaultdict[int, list[float]] = defaultdict(list) + self.iteration: int = 0 self._setup_for_evaluating() def _setup_for_evaluating(self) -> None: @@ -44,7 +48,7 @@ class CrewEvaluator: for task in self.crew.tasks: task.callback = self.evaluate - def _evaluator_agent(self): + def _evaluator_agent(self) -> Agent: return Agent( role="Task Execution Evaluator", goal=( @@ -55,8 +59,9 @@ class CrewEvaluator: llm=self.llm, ) + @staticmethod def _evaluation_task( - self, evaluator_agent: Agent, task_to_evaluate: Task, task_output: str + evaluator_agent: Agent, task_to_evaluate: Task, task_output: str ) -> Task: return Task( description=( @@ -73,6 +78,11 @@ class CrewEvaluator: ) def set_iteration(self, iteration: int) -> None: + """Sets the current iteration of the evaluation. + + Args: + iteration: The current iteration number. + """ self.iteration = iteration def print_crew_evaluation_result(self) -> None: @@ -97,7 +107,8 @@ class CrewEvaluator: └────────────────────┴───────┴───────┴───────┴────────────┴──────────────────────────────┘ """ task_averages = [ - sum(scores) / len(scores) for scores in zip(*self.tasks_scores.values()) + sum(scores) / len(scores) + for scores in zip(*self.tasks_scores.values(), strict=False) ] crew_average = sum(task_averages) / len(task_averages) @@ -158,8 +169,12 @@ class CrewEvaluator: console.print("\n") console.print(table) - def evaluate(self, task_output: TaskOutput): - """Evaluates the performance of the agents in the crew based on the tasks they have performed.""" + def evaluate(self, task_output: TaskOutput) -> None: + """Evaluates the performance of the agents in the crew based on the tasks they have performed. + + Args: + task_output: The output of the task to evaluate. 
+ """ current_task = None for task in self.crew.tasks: if task.description == task_output.description: @@ -179,19 +194,24 @@ class CrewEvaluator: evaluation_result = evaluation_task.execute_sync() if isinstance(evaluation_result.pydantic, TaskEvaluationPydanticOutput): + quality_score = evaluation_result.pydantic.quality + if quality_score is None: + raise ValueError("Evaluation quality score cannot be None") + crewai_event_bus.emit( self.crew, CrewTestResultEvent( - quality=evaluation_result.pydantic.quality, + quality=quality_score, execution_duration=current_task.execution_duration, model=self.llm.model, crew_name=self.crew.name, crew=self.crew, ), ) - self.tasks_scores[self.iteration].append(evaluation_result.pydantic.quality) - self.run_execution_times[self.iteration].append( - current_task.execution_duration - ) + self.tasks_scores[self.iteration].append(quality_score) + if current_task.execution_duration is not None: + self.run_execution_times[self.iteration].append( + current_task.execution_duration + ) else: raise ValueError("Evaluation result is not in the expected format") diff --git a/src/crewai/utilities/evaluators/task_evaluator.py b/src/crewai/utilities/evaluators/task_evaluator.py index 1b1d05b4c..ad1b993cf 100644 --- a/src/crewai/utilities/evaluators/task_evaluator.py +++ b/src/crewai/utilities/evaluators/task_evaluator.py @@ -1,35 +1,42 @@ -from typing import List +from __future__ import annotations + +from typing import TYPE_CHECKING, cast from pydantic import BaseModel, Field -from crewai.utilities import Converter -from crewai.events.types.task_events import TaskEvaluationEvent from crewai.events.event_bus import crewai_event_bus +from crewai.events.types.task_events import TaskEvaluationEvent +from crewai.llm import LLM +from crewai.utilities.converter import Converter from crewai.utilities.pydantic_schema_parser import PydanticSchemaParser from crewai.utilities.training_converter import TrainingConverter +if TYPE_CHECKING: + from crewai.agent import Agent + from crewai.task import Task + class Entity(BaseModel): name: str = Field(description="The name of the entity.") type: str = Field(description="The type of the entity.") description: str = Field(description="Description of the entity.") - relationships: List[str] = Field(description="Relationships of the entity.") + relationships: list[str] = Field(description="Relationships of the entity.") class TaskEvaluation(BaseModel): - suggestions: List[str] = Field( + suggestions: list[str] = Field( description="Suggestions to improve future similar tasks." ) quality: float = Field( description="A score from 0 to 10 evaluating on completion, quality, and overall performance, all taking into account the task description, expected output, and the result of the task." ) - entities: List[Entity] = Field( + entities: list[Entity] = Field( description="Entities extracted from the task output." ) class TrainingTaskEvaluation(BaseModel): - suggestions: List[str] = Field( + suggestions: list[str] = Field( description="List of clear, actionable instructions derived from the Human Feedbacks to enhance the Agent's performance. Analyze the differences between Initial Outputs and Improved Outputs to generate specific action items for future tasks. Ensure all key and specific points from the human feedback are incorporated into these instructions." 
) quality: float = Field( @@ -41,11 +48,35 @@ class TrainingTaskEvaluation(BaseModel): class TaskEvaluator: - def __init__(self, original_agent): - self.llm = original_agent.llm + """A class to evaluate the performance of an agent based on the tasks they have performed. + + Attributes: + llm: The LLM to use for evaluation. + original_agent: The agent to evaluate. + """ + + def __init__(self, original_agent: Agent) -> None: + """Initializes the TaskEvaluator with the given LLM and agent. + + Args: + original_agent: The agent to evaluate. + """ + self.llm = cast(LLM, original_agent.llm) self.original_agent = original_agent - def evaluate(self, task, output) -> TaskEvaluation: + def evaluate(self, task: Task, output: str) -> TaskEvaluation: + """ + + Args: + task: The task to be evaluated. + output: The output of the task. + + Returns: + TaskEvaluation: The evaluation of the task. + + Notes: + - Investigate the Converter.to_pydantic signature, returns BaseModel strictly? + """ crewai_event_bus.emit( self, TaskEvaluationEvent(evaluation_type="task_evaluation", task=task) ) @@ -73,7 +104,7 @@ class TaskEvaluator: instructions=instructions, ) - return converter.to_pydantic() + return cast(TaskEvaluation, converter.to_pydantic()) def evaluate_training_data( self, training_data: dict, agent_id: str @@ -81,9 +112,12 @@ class TaskEvaluator: """ Evaluate the training data based on the llm output, human feedback, and improved output. - Parameters: - - training_data (dict): The training data to be evaluated. - - agent_id (str): The ID of the agent. + Args: + - training_data: The training data to be evaluated. + - agent_id: The ID of the agent. + + Notes: + - Investigate the Converter.to_pydantic signature, returns BaseModel strictly? """ crewai_event_bus.emit( self, TaskEvaluationEvent(evaluation_type="training_data_evaluation") @@ -142,5 +176,4 @@ class TaskEvaluator: instructions=instructions, ) - pydantic_result = converter.to_pydantic() - return pydantic_result + return cast(TrainingTaskEvaluation, converter.to_pydantic()) diff --git a/src/crewai/utilities/events/__init__.py b/src/crewai/utilities/events/__init__.py index 24184086a..2b484d125 100644 --- a/src/crewai/utilities/events/__init__.py +++ b/src/crewai/utilities/events/__init__.py @@ -3,9 +3,10 @@ import warnings from abc import ABC from collections.abc import Callable -from typing import Any, Type, TypeVar +from typing import Any, TypeVar from typing_extensions import deprecated + import crewai.events as new_events from crewai.events.base_events import BaseEvent from crewai.events.event_types import EventTypes @@ -17,14 +18,14 @@ warnings.warn( "Importing from 'crewai.utilities.events' is deprecated and will be removed in v1.0.0. 
" "Please use 'crewai.events' instead.", DeprecationWarning, - stacklevel=2 + stacklevel=2, ) @deprecated("Use 'from crewai.events import BaseEventListener' instead") class BaseEventListener(new_events.BaseEventListener, ABC): """Deprecated: Use crewai.events.BaseEventListener instead.""" - pass + @deprecated("Use 'from crewai.events import crewai_event_bus' instead") class crewai_event_bus: # noqa: N801 @@ -32,7 +33,7 @@ class crewai_event_bus: # noqa: N801 @classmethod def on( - cls, event_type: Type[EventT] + cls, event_type: type[EventT] ) -> Callable[[Callable[[Any, EventT], None]], Callable[[Any, EventT], None]]: """Delegate to the actual event bus instance.""" return new_events.crewai_event_bus.on(event_type) @@ -44,7 +45,7 @@ class crewai_event_bus: # noqa: N801 @classmethod def register_handler( - cls, event_type: Type[EventTypes], handler: Callable[[Any, EventTypes], None] + cls, event_type: type[EventTypes], handler: Callable[[Any, EventTypes], None] ) -> None: """Delegate to the actual event bus instance.""" return new_events.crewai_event_bus.register_handler(event_type, handler) @@ -54,87 +55,88 @@ class crewai_event_bus: # noqa: N801 """Delegate to the actual event bus instance.""" return new_events.crewai_event_bus.scoped_handlers() + @deprecated("Use 'from crewai.events import CrewKickoffStartedEvent' instead") class CrewKickoffStartedEvent(new_events.CrewKickoffStartedEvent): """Deprecated: Use crewai.events.CrewKickoffStartedEvent instead.""" - pass + @deprecated("Use 'from crewai.events import CrewKickoffCompletedEvent' instead") class CrewKickoffCompletedEvent(new_events.CrewKickoffCompletedEvent): """Deprecated: Use crewai.events.CrewKickoffCompletedEvent instead.""" - pass + @deprecated("Use 'from crewai.events import AgentExecutionCompletedEvent' instead") class AgentExecutionCompletedEvent(new_events.AgentExecutionCompletedEvent): """Deprecated: Use crewai.events.AgentExecutionCompletedEvent instead.""" - pass + @deprecated("Use 'from crewai.events import MemoryQueryCompletedEvent' instead") class MemoryQueryCompletedEvent(new_events.MemoryQueryCompletedEvent): """Deprecated: Use crewai.events.MemoryQueryCompletedEvent instead.""" - pass + @deprecated("Use 'from crewai.events import MemorySaveCompletedEvent' instead") class MemorySaveCompletedEvent(new_events.MemorySaveCompletedEvent): """Deprecated: Use crewai.events.MemorySaveCompletedEvent instead.""" - pass + @deprecated("Use 'from crewai.events import MemorySaveStartedEvent' instead") class MemorySaveStartedEvent(new_events.MemorySaveStartedEvent): """Deprecated: Use crewai.events.MemorySaveStartedEvent instead.""" - pass + @deprecated("Use 'from crewai.events import MemoryQueryStartedEvent' instead") class MemoryQueryStartedEvent(new_events.MemoryQueryStartedEvent): """Deprecated: Use crewai.events.MemoryQueryStartedEvent instead.""" - pass + @deprecated("Use 'from crewai.events import MemoryRetrievalCompletedEvent' instead") class MemoryRetrievalCompletedEvent(new_events.MemoryRetrievalCompletedEvent): """Deprecated: Use crewai.events.MemoryRetrievalCompletedEvent instead.""" - pass + @deprecated("Use 'from crewai.events import MemorySaveFailedEvent' instead") class MemorySaveFailedEvent(new_events.MemorySaveFailedEvent): """Deprecated: Use crewai.events.MemorySaveFailedEvent instead.""" - pass + @deprecated("Use 'from crewai.events import MemoryQueryFailedEvent' instead") class MemoryQueryFailedEvent(new_events.MemoryQueryFailedEvent): """Deprecated: Use crewai.events.MemoryQueryFailedEvent instead.""" - 
pass + @deprecated("Use 'from crewai.events import KnowledgeRetrievalStartedEvent' instead") class KnowledgeRetrievalStartedEvent(new_events.KnowledgeRetrievalStartedEvent): """Deprecated: Use crewai.events.KnowledgeRetrievalStartedEvent instead.""" - pass + @deprecated("Use 'from crewai.events import KnowledgeRetrievalCompletedEvent' instead") class KnowledgeRetrievalCompletedEvent(new_events.KnowledgeRetrievalCompletedEvent): """Deprecated: Use crewai.events.KnowledgeRetrievalCompletedEvent instead.""" - pass + @deprecated("Use 'from crewai.events import LLMStreamChunkEvent' instead") class LLMStreamChunkEvent(new_events.LLMStreamChunkEvent): """Deprecated: Use crewai.events.LLMStreamChunkEvent instead.""" - pass + __all__ = [ - 'BaseEventListener', - 'crewai_event_bus', - 'CrewKickoffStartedEvent', - 'CrewKickoffCompletedEvent', - 'AgentExecutionCompletedEvent', - 'MemoryQueryCompletedEvent', - 'MemorySaveCompletedEvent', - 'MemorySaveStartedEvent', - 'MemoryQueryStartedEvent', - 'MemoryRetrievalCompletedEvent', - 'MemorySaveFailedEvent', - 'MemoryQueryFailedEvent', - 'KnowledgeRetrievalStartedEvent', - 'KnowledgeRetrievalCompletedEvent', - 'LLMStreamChunkEvent', + "AgentExecutionCompletedEvent", + "BaseEventListener", + "CrewKickoffCompletedEvent", + "CrewKickoffStartedEvent", + "KnowledgeRetrievalCompletedEvent", + "KnowledgeRetrievalStartedEvent", + "LLMStreamChunkEvent", + "MemoryQueryCompletedEvent", + "MemoryQueryFailedEvent", + "MemoryQueryStartedEvent", + "MemoryRetrievalCompletedEvent", + "MemorySaveCompletedEvent", + "MemorySaveFailedEvent", + "MemorySaveStartedEvent", + "crewai_event_bus", ] __deprecated__ = "Use 'crewai.events' instead of 'crewai.utilities.events'" diff --git a/src/crewai/utilities/events/base_event_listener.py b/src/crewai/utilities/events/base_event_listener.py index 349295ce3..a4fd8330b 100644 --- a/src/crewai/utilities/events/base_event_listener.py +++ b/src/crewai/utilities/events/base_event_listener.py @@ -1,6 +1,7 @@ """Backwards compatibility stub for crewai.utilities.events.base_event_listener.""" import warnings + from crewai.events import BaseEventListener warnings.warn( diff --git a/src/crewai/utilities/events/crewai_event_bus.py b/src/crewai/utilities/events/crewai_event_bus.py index 337f267ea..959dedb6f 100644 --- a/src/crewai/utilities/events/crewai_event_bus.py +++ b/src/crewai/utilities/events/crewai_event_bus.py @@ -1,6 +1,7 @@ """Backwards compatibility stub for crewai.utilities.events.crewai_event_bus.""" import warnings + from crewai.events import crewai_event_bus warnings.warn( diff --git a/src/crewai/utilities/exceptions/context_window_exceeding_exception.py b/src/crewai/utilities/exceptions/context_window_exceeding_exception.py index 399cf5a00..cbbe3e0a5 100644 --- a/src/crewai/utilities/exceptions/context_window_exceeding_exception.py +++ b/src/crewai/utilities/exceptions/context_window_exceeding_exception.py @@ -1,26 +1,57 @@ -class LLMContextLengthExceededException(Exception): - CONTEXT_LIMIT_ERRORS = [ - "expected a string with maximum length", - "maximum context length", - "context length exceeded", - "context_length_exceeded", - "context window full", - "too many tokens", - "input is too long", - "exceeds token limit", - ] +from typing import Final - def __init__(self, error_message: str): +CONTEXT_LIMIT_ERRORS: Final[list[str]] = [ + "expected a string with maximum length", + "maximum context length", + "context length exceeded", + "context_length_exceeded", + "context window full", + "too many tokens", + "input is too long", + 
"exceeds token limit", +] + + +class LLMContextLengthExceededError(Exception): + """Exception raised when the context length of a language model is exceeded. + + Attributes: + original_error_message: The original error message from the LLM. + """ + + def __init__(self, error_message: str) -> None: + """Initialize the exception with the original error message. + + Args: + error_message: The original error message from the LLM. + """ self.original_error_message = error_message super().__init__(self._get_error_message(error_message)) - def _is_context_limit_error(self, error_message: str) -> bool: + @staticmethod + def _is_context_limit_error(error_message: str) -> bool: + """Check if the error message indicates a context length limit error. + + Args: + error_message: The error message to check. + + Returns: + True if the error message indicates a context length limit error, False otherwise. + """ return any( - phrase.lower() in error_message.lower() - for phrase in self.CONTEXT_LIMIT_ERRORS + phrase.lower() in error_message.lower() for phrase in CONTEXT_LIMIT_ERRORS ) - def _get_error_message(self, error_message: str): + @staticmethod + def _get_error_message(error_message: str) -> str: + """Generate a user-friendly error message based on the original error message. + + Args: + error_message: The original error message from the LLM. + + Returns: + A user-friendly error message. + """ return ( f"LLM context length exceeded. Original error: {error_message}\n" "Consider using a smaller input or implementing a text splitting strategy." diff --git a/src/crewai/utilities/file_handler.py b/src/crewai/utilities/file_handler.py index 85d9766c5..106cb76b3 100644 --- a/src/crewai/utilities/file_handler.py +++ b/src/crewai/utilities/file_handler.py @@ -2,71 +2,140 @@ import json import os import pickle from datetime import datetime -from typing import Union +from typing import Any, TypedDict + +from typing_extensions import Unpack + + +class LogEntry(TypedDict, total=False): + """TypedDict for log entry kwargs with optional fields for flexibility.""" + + task_name: str + task: str + agent: str + status: str + output: str + input: str + message: str + level: str + crew: str + flow: str + tool: str + error: str + duration: float + metadata: dict[str, Any] class FileHandler: """Handler for file operations supporting both JSON and text-based logging. - - Args: - file_path (Union[bool, str]): Path to the log file or boolean flag + + Attributes: + _path: The path to the log file. """ - def __init__(self, file_path: Union[bool, str]): + def __init__(self, file_path: bool | str) -> None: + """Initialize the FileHandler with the specified file path. + Args: + file_path: Path to the log file or boolean flag. + """ self._initialize_path(file_path) - - def _initialize_path(self, file_path: Union[bool, str]): + + def _initialize_path(self, file_path: bool | str) -> None: + """Initialize the file path based on the input type. + + Args: + file_path: Path to the log file or boolean flag. + + Raises: + ValueError: If file_path is neither a string nor a boolean. 
+ """ if file_path is True: # File path is boolean True self._path = os.path.join(os.curdir, "logs.txt") - + elif isinstance(file_path, str): # File path is a string if file_path.endswith((".json", ".txt")): - self._path = file_path # No modification if the file ends with .json or .txt + self._path = ( + file_path # No modification if the file ends with .json or .txt + ) else: - self._path = file_path + ".txt" # Append .txt if the file doesn't end with .json or .txt - + self._path = ( + file_path + ".txt" + ) # Append .txt if the file doesn't end with .json or .txt + else: - raise ValueError("file_path must be a string or boolean.") # Handle the case where file_path isn't valid - - def log(self, **kwargs): + raise ValueError( + "file_path must be a string or boolean." + ) # Handle the case where file_path isn't valid + + def log(self, **kwargs: Unpack[LogEntry]) -> None: + """Log data with structured fields. + + Keyword Args: + task_name: Name of the task. + task: Description of the task. + agent: Name of the agent. + status: Status of the operation. + output: Output data. + input: Input data. + message: Log message. + level: Log level (e.g., INFO, ERROR). + crew: Name of the crew. + flow: Name of the flow. + tool: Name of the tool used. + error: Error message if any. + duration: Duration of the operation in seconds. + metadata: Additional metadata as a dictionary. + + Raises: + ValueError: If logging fails. + """ try: now = datetime.now().strftime("%Y-%m-%d %H:%M:%S") log_entry = {"timestamp": now, **kwargs} if self._path.endswith(".json"): # Append log in JSON format - with open(self._path, "a", encoding="utf-8") as file: - # If the file is empty, start with a list; else, append to it - try: - # Try reading existing content to avoid overwriting - with open(self._path, "r", encoding="utf-8") as read_file: - existing_data = json.load(read_file) - existing_data.append(log_entry) - except (json.JSONDecodeError, FileNotFoundError): - # If no valid JSON or file doesn't exist, start with an empty list - existing_data = [log_entry] - - with open(self._path, "w", encoding="utf-8") as write_file: - json.dump(existing_data, write_file, indent=4) - write_file.write("\n") - + try: + # Try reading existing content to avoid overwriting + with open(self._path, encoding="utf-8") as read_file: + existing_data = json.load(read_file) + existing_data.append(log_entry) + except (json.JSONDecodeError, FileNotFoundError): + # If no valid JSON or file doesn't exist, start with an empty list + existing_data = [log_entry] + + with open(self._path, "w", encoding="utf-8") as write_file: + json.dump(existing_data, write_file, indent=4) + write_file.write("\n") + else: # Append log in plain text format - message = f"{now}: " + ", ".join([f"{key}=\"{value}\"" for key, value in kwargs.items()]) + "\n" + message = ( + f"{now}: " + + ", ".join([f'{key}="{value}"' for key, value in kwargs.items()]) + + "\n" + ) with open(self._path, "a", encoding="utf-8") as file: file.write(message) except Exception as e: - raise ValueError(f"Failed to log message: {str(e)}") - + raise ValueError(f"Failed to log message: {e!s}") from e + + class PickleHandler: + """Handler for saving and loading data using pickle. + + Attributes: + file_path: The path to the pickle file. + """ + def __init__(self, file_name: str) -> None: - """ - Initialize the PickleHandler with the name of the file where data will be stored. + """Initialize the PickleHandler with the name of the file where data will be stored. 
+ The file will be saved in the current directory. - Parameters: - - file_name (str): The name of the file for saving and loading data. + Args: + file_name: The name of the file for saving and loading data. """ if not file_name.endswith(".pkl"): file_name += ".pkl" @@ -74,34 +143,31 @@ class PickleHandler: self.file_path = os.path.join(os.getcwd(), file_name) def initialize_file(self) -> None: - """ - Initialize the file with an empty dictionary and overwrite any existing data. - """ + """Initialize the file with an empty dictionary and overwrite any existing data.""" self.save({}) - def save(self, data) -> None: + def save(self, data: Any) -> None: """ Save the data to the specified file using pickle. - Parameters: - - data (object): The data to be saved. + Args: + data: The data to be saved to the file. """ - with open(self.file_path, "wb") as file: - pickle.dump(data, file) + with open(self.file_path, "wb") as f: + pickle.dump(obj=data, file=f) - def load(self) -> dict: - """ - Load the data from the specified file using pickle. + def load(self) -> Any: + """Load the data from the specified file using pickle. Returns: - - dict: The data loaded from the file. + The data loaded from the file. """ if not os.path.exists(self.file_path) or os.path.getsize(self.file_path) == 0: return {} # Return an empty dictionary if the file does not exist or is empty with open(self.file_path, "rb") as file: try: - return pickle.load(file) # nosec + return pickle.load(file) # noqa: S301 except EOFError: return {} # Return an empty dictionary if the file is empty or corrupted except Exception: diff --git a/src/crewai/utilities/formatter.py b/src/crewai/utilities/formatter.py index f57a7f884..892167d39 100644 --- a/src/crewai/utilities/formatter.py +++ b/src/crewai/utilities/formatter.py @@ -1,4 +1,7 @@ -from typing import TYPE_CHECKING, List, Union +from __future__ import annotations + +from typing import TYPE_CHECKING, Final + from crewai.utilities.constants import _NotSpecified if TYPE_CHECKING: @@ -6,17 +9,31 @@ if TYPE_CHECKING: from crewai.tasks.task_output import TaskOutput -def aggregate_raw_outputs_from_task_outputs(task_outputs: List["TaskOutput"]) -> str: - """Generate string context from the task outputs.""" - dividers = "\n\n----------\n\n" - - # Join task outputs with dividers - context = dividers.join(output.raw for output in task_outputs) - return context +DIVIDERS: Final[str] = "\n\n----------\n\n" -def aggregate_raw_outputs_from_tasks(tasks: Union[List["Task"],_NotSpecified]) -> str: - """Generate string context from the tasks.""" +def aggregate_raw_outputs_from_task_outputs(task_outputs: list[TaskOutput]) -> str: + """Generate string context from the task outputs. + + Args: + task_outputs: List of TaskOutput objects. + + Returns: + A string containing the aggregated raw outputs from the task outputs. + """ + + return DIVIDERS.join(output.raw for output in task_outputs) + + +def aggregate_raw_outputs_from_tasks(tasks: list[Task] | _NotSpecified) -> str: + """Generate string context from the tasks. + + Args: + tasks: List of Task objects or _NotSpecified. + + Returns: + A string containing the aggregated raw outputs from the tasks. 
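+
+    Note:
+        Tasks that have not produced an output yet are skipped.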
+ """ task_outputs = ( [task.output for task in tasks if task.output is not None] diff --git a/src/crewai/utilities/guardrail.py b/src/crewai/utilities/guardrail.py index dff394fae..837ffad8c 100644 --- a/src/crewai/utilities/guardrail.py +++ b/src/crewai/utilities/guardrail.py @@ -1,7 +1,14 @@ -from collections.abc import Callable -from typing import Any +from __future__ import annotations -from pydantic import BaseModel, field_validator +from collections.abc import Callable +from typing import TYPE_CHECKING, Any + +from pydantic import BaseModel, Field, field_validator +from typing_extensions import Self + +if TYPE_CHECKING: + from crewai.lite_agent import LiteAgentOutput + from crewai.tasks.task_output import TaskOutput class GuardrailResult(BaseModel): @@ -12,18 +19,31 @@ class GuardrailResult(BaseModel): be easily handled by the task execution system. Attributes: - success (bool): Whether the guardrail validation passed - result (Any, optional): The validated/transformed result if successful - error (str, optional): Error message if validation failed + success: Whether the guardrail validation passed + result: The validated/transformed result if successful + error: Error message if validation failed """ - success: bool - result: Any | None = None - error: str | None = None + success: bool = Field(description="Whether the guardrail validation passed") + result: Any | None = Field( + default=None, description="The validated/transformed result if successful" + ) + error: str | None = Field( + default=None, description="Error message if validation failed" + ) @field_validator("result", "error") @classmethod def validate_result_error_exclusivity(cls, v: Any, info) -> Any: + """Ensure that result and error are mutually exclusive based on success. + + Args: + v: The value being validated (either result or error) + info: Validation info containing the entire model data + + Returns: + The original value if validation passes + """ values = info.data if "success" in values: if values["success"] and v and "error" in values and values["error"]: @@ -37,15 +57,14 @@ class GuardrailResult(BaseModel): return v @classmethod - def from_tuple(cls, result: tuple[bool, Any | str]) -> "GuardrailResult": + def from_tuple(cls, result: tuple[bool, Any | str]) -> Self: """Create a GuardrailResult from a validation tuple. Args: - result: A tuple of (success, data) where data is either - the validated result or error message. + result: A tuple of (success, data) where data is either the validated result or error message. Returns: - GuardrailResult: A new instance with the tuple data. + A new instance with the tuple data. """ success, data = result return cls( @@ -56,7 +75,10 @@ class GuardrailResult(BaseModel): def process_guardrail( - output: Any, guardrail: Callable, retry_count: int, event_source: Any | None = None + output: TaskOutput | LiteAgentOutput, + guardrail: Callable[[Any], tuple[bool, Any | str]], + retry_count: int, + event_source: Any | None = None, ) -> GuardrailResult: """Process the guardrail for the agent output. 
@@ -68,7 +90,19 @@ def process_guardrail( Returns: GuardrailResult: The result of the guardrail validation + + Raises: + TypeError: If output is not a TaskOutput or LiteAgentOutput + ValueError: If guardrail is None """ + from crewai.lite_agent import LiteAgentOutput + from crewai.tasks.task_output import TaskOutput + + if not isinstance(output, (TaskOutput, LiteAgentOutput)): + raise TypeError("Output must be a TaskOutput or LiteAgentOutput") + if guardrail is None: + raise ValueError("Guardrail must not be None") + from crewai.events.event_bus import crewai_event_bus from crewai.events.types.llm_guardrail_events import ( LLMGuardrailCompletedEvent, diff --git a/src/crewai/utilities/i18n.py b/src/crewai/utilities/i18n.py index f2540e455..5bc8c764c 100644 --- a/src/crewai/utilities/i18n.py +++ b/src/crewai/utilities/i18n.py @@ -1,36 +1,51 @@ -import json -import os -from typing import Dict, Optional, Union - -from pydantic import BaseModel, Field, PrivateAttr, model_validator - """Internationalization support for CrewAI prompts and messages.""" +import json +import os +from typing import Literal + +from pydantic import BaseModel, Field, PrivateAttr, model_validator +from typing_extensions import Self + + class I18N(BaseModel): - """Handles loading and retrieving internationalized prompts.""" - _prompts: Dict[str, Dict[str, str]] = PrivateAttr() - prompt_file: Optional[str] = Field( + """Handles loading and retrieving internationalized prompts. + + Attributes: + _prompts: Internal dictionary storing loaded prompts. + prompt_file: Optional path to a custom JSON file containing prompts. + """ + + _prompts: dict[str, dict[str, str]] = PrivateAttr() + prompt_file: str | None = Field( default=None, description="Path to the prompt_file file to load", ) @model_validator(mode="after") - def load_prompts(self) -> "I18N": - """Load prompts from a JSON file.""" + def load_prompts(self) -> Self: + """Load prompts from a JSON file. + + Returns: + The I18N instance with loaded prompts. + + Raises: + Exception: If the prompt file is not found or cannot be decoded. + """ try: if self.prompt_file: - with open(self.prompt_file, "r", encoding="utf-8") as f: + with open(self.prompt_file, encoding="utf-8") as f: self._prompts = json.load(f) else: dir_path = os.path.dirname(os.path.realpath(__file__)) prompts_path = os.path.join(dir_path, "../translations/en.json") - with open(prompts_path, "r", encoding="utf-8") as f: + with open(prompts_path, encoding="utf-8") as f: self._prompts = json.load(f) - except FileNotFoundError: - raise Exception(f"Prompt file '{self.prompt_file}' not found.") - except json.JSONDecodeError: - raise Exception("Error decoding JSON from the prompts file.") + except FileNotFoundError as e: + raise Exception(f"Prompt file '{self.prompt_file}' not found.") from e + except json.JSONDecodeError as e: + raise Exception("Error decoding JSON from the prompts file.") from e if not self._prompts: self._prompts = {} @@ -38,16 +53,58 @@ class I18N(BaseModel): return self def slice(self, slice: str) -> str: + """Retrieve a prompt slice by key. + + Args: + slice: The key of the prompt slice to retrieve. + + Returns: + The prompt slice as a string. + """ return self.retrieve("slices", slice) def errors(self, error: str) -> str: + """Retrieve an error message by key. + + Args: + error: The key of the error message to retrieve. + + Returns: + The error message as a string. 
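+
+        Example (the key is illustrative)::
+
+            i18n = I18N()
+            message = i18n.errors("some_error_key")  # hypothetical key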
+ """ return self.retrieve("errors", error) - def tools(self, tool: str) -> Union[str, Dict[str, str]]: + def tools(self, tool: str) -> str | dict[str, str]: + """Retrieve a tool prompt by key. + + Args: + tool: The key of the tool prompt to retrieve. + + Returns: + The tool prompt as a string or dictionary. + """ return self.retrieve("tools", tool) - def retrieve(self, kind, key) -> str: + def retrieve( + self, + kind: Literal[ + "slices", "errors", "tools", "reasoning", "hierarchical_manager_agent" + ], + key: str, + ) -> str: + """Retrieve a prompt by kind and key. + + Args: + kind: The kind of prompt. + key: The key of the specific prompt to retrieve. + + Returns: + The prompt as a string. + + Raises: + Exception: If the prompt for the given kind and key is not found. + """ try: return self._prompts[kind][key] - except Exception as _: - raise Exception(f"Prompt for '{kind}':'{key}' not found.") + except Exception as e: + raise Exception(f"Prompt for '{kind}':'{key}' not found.") from e diff --git a/src/crewai/utilities/import_utils.py b/src/crewai/utilities/import_utils.py index 47e46f4ba..e6d807c36 100644 --- a/src/crewai/utilities/import_utils.py +++ b/src/crewai/utilities/import_utils.py @@ -7,8 +7,6 @@ from types import ModuleType class OptionalDependencyError(ImportError): """Exception raised when an optional dependency is not installed.""" - pass - def require(name: str, *, purpose: str) -> ModuleType: """Import a module, raising a helpful error if it's not installed. diff --git a/src/crewai/utilities/internal_instructor.py b/src/crewai/utilities/internal_instructor.py index e9401c778..aefbcb28b 100644 --- a/src/crewai/utilities/internal_instructor.py +++ b/src/crewai/utilities/internal_instructor.py @@ -1,43 +1,98 @@ -import warnings -from typing import Any, Optional, Type +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Generic, TypeGuard, TypeVar + +from pydantic import BaseModel + +if TYPE_CHECKING: + from crewai.agent import Agent + from crewai.llm import LLM + from crewai.llms.base_llm import BaseLLM + +from crewai.utilities.logger_utils import suppress_warnings +from crewai.utilities.types import LLMMessage + +T = TypeVar("T", bound=BaseModel) -class InternalInstructor: - """Class that wraps an agent llm with instructor.""" +def _is_valid_llm(llm: Any) -> TypeGuard[str | LLM | BaseLLM]: + """Type guard to ensure LLM is valid and not None. + + Args: + llm: The LLM to validate + + Returns: + True if LLM is valid (string or has model attribute), False otherwise + """ + return llm is not None and (isinstance(llm, str) or hasattr(llm, "model")) + + +class InternalInstructor(Generic[T]): + """Class that wraps an agent LLM with instructor for structured output generation. + + Attributes: + content: The content to be processed + model: The Pydantic model class for the response + agent: The agent with LLM + llm: The LLM instance or model name + """ def __init__( self, content: str, - model: Type, - agent: Optional[Any] = None, - llm: Optional[str] = None, - ): + model: type[T], + agent: Agent | None = None, + llm: LLM | BaseLLM | str | None = None, + ) -> None: + """Initialize InternalInstructor. 
+ + Args: + content: The content to be processed + model: The Pydantic model class for the response + agent: The agent with LLM + llm: The LLM instance or model name + """ self.content = content self.agent = agent - self.llm = llm self.model = model - self._client = None - self.set_instructor() + self.llm = llm or (agent.function_calling_llm or agent.llm if agent else None) - def set_instructor(self): - """Set instructor.""" - if self.agent and not self.llm: - self.llm = self.agent.function_calling_llm or self.agent.llm - - with warnings.catch_warnings(): - warnings.simplefilter("ignore", UserWarning) + with suppress_warnings(): import instructor from litellm import completion self._client = instructor.from_litellm(completion) - def to_json(self): - model = self.to_pydantic() - return model.model_dump_json(indent=2) + def to_json(self) -> str: + """Convert the structured output to JSON format. - def to_pydantic(self): - messages = [{"role": "user", "content": self.content}] - model = self._client.chat.completions.create( - model=self.llm.model, response_model=self.model, messages=messages + Returns: + JSON string representation of the structured output + """ + pydantic_model = self.to_pydantic() + return pydantic_model.model_dump_json(indent=2) + + def to_pydantic(self) -> T: + """Generate structured output using the specified Pydantic model. + + Returns: + Instance of the specified Pydantic model with structured data + + Raises: + ValueError: If LLM is not provided or invalid + """ + messages: list[LLMMessage] = [{"role": "user", "content": self.content}] + + if not _is_valid_llm(self.llm): + raise ValueError( + "LLM must be provided and have a model attribute or be a string" + ) + + if isinstance(self.llm, str): + model_name = self.llm + else: + model_name = self.llm.model + + return self._client.chat.completions.create( + model=model_name, response_model=self.model, messages=messages ) - return model diff --git a/src/crewai/utilities/llm_utils.py b/src/crewai/utilities/llm_utils.py index 3998a9bce..d3b439e5d 100644 --- a/src/crewai/utilities/llm_utils.py +++ b/src/crewai/utilities/llm_utils.py @@ -1,62 +1,55 @@ +import logging import os -from typing import Any, Dict, List, Optional, Union +from typing import Any, Final from crewai.cli.constants import DEFAULT_LLM_MODEL, ENV_VARS, LITELLM_PARAMS -from crewai.llm import LLM, BaseLLM +from crewai.llm import LLM +from crewai.llms.base_llm import BaseLLM + +logger = logging.getLogger(__name__) def create_llm( - llm_value: Union[str, LLM, Any, None] = None, -) -> Optional[LLM | BaseLLM]: - """ - Creates or returns an LLM instance based on the given llm_value. + llm_value: str | LLM | Any | None = None, +) -> LLM | BaseLLM | None: + """Creates or returns an LLM instance based on the given llm_value. Args: - llm_value (str | BaseLLM | Any | None): - - str: The model name (e.g., "gpt-4"). - - BaseLLM: Already instantiated BaseLLM (including LLM), returned as-is. - - Any: Attempt to extract known attributes like model_name, temperature, etc. - - None: Use environment-based or fallback default model. + llm_value: LLM instance, model name string, None, or an object with LLM attributes. Returns: A BaseLLM instance if successful, or None if something fails. 
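+
+    Example (model names are illustrative)::
+
+        llm = create_llm("gpt-4o-mini")  # wraps a model name in an LLM instance
+        same = create_llm(llm)           # existing LLM/BaseLLM instances are returned unchanged
+        auto = create_llm(None)          # falls back to environment variables or the default model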
""" - # 1) If llm_value is already a BaseLLM or LLM object, return it directly - if isinstance(llm_value, LLM) or isinstance(llm_value, BaseLLM): + if isinstance(llm_value, (LLM, BaseLLM)): return llm_value - # 2) If llm_value is a string (model name) if isinstance(llm_value, str): try: - created_llm = LLM(model=llm_value) - return created_llm + return LLM(model=llm_value) except Exception as e: - print(f"Failed to instantiate LLM with model='{llm_value}': {e}") + logger.debug(f"Failed to instantiate LLM with model='{llm_value}': {e}") return None - # 3) If llm_value is None, parse environment variables or use default if llm_value is None: return _llm_via_environment_or_fallback() - # 4) Otherwise, attempt to extract relevant attributes from an unknown object try: - # Extract attributes with explicit types model = ( getattr(llm_value, "model", None) or getattr(llm_value, "model_name", None) or getattr(llm_value, "deployment_name", None) or str(llm_value) ) - temperature: Optional[float] = getattr(llm_value, "temperature", None) - max_tokens: Optional[int] = getattr(llm_value, "max_tokens", None) - logprobs: Optional[int] = getattr(llm_value, "logprobs", None) - timeout: Optional[float] = getattr(llm_value, "timeout", None) - api_key: Optional[str] = getattr(llm_value, "api_key", None) - base_url: Optional[str] = getattr(llm_value, "base_url", None) - api_base: Optional[str] = getattr(llm_value, "api_base", None) + temperature: float | None = getattr(llm_value, "temperature", None) + max_tokens: int | None = getattr(llm_value, "max_tokens", None) + logprobs: int | None = getattr(llm_value, "logprobs", None) + timeout: float | None = getattr(llm_value, "timeout", None) + api_key: str | None = getattr(llm_value, "api_key", None) + base_url: str | None = getattr(llm_value, "base_url", None) + api_base: str | None = getattr(llm_value, "api_base", None) - created_llm = LLM( + return LLM( model=model, temperature=temperature, max_tokens=max_tokens, @@ -66,15 +59,23 @@ def create_llm( base_url=base_url, api_base=api_base, ) - return created_llm except Exception as e: - print(f"Error instantiating LLM from unknown object type: {e}") + logger.debug(f"Error instantiating LLM from unknown object type: {e}") return None -def _llm_via_environment_or_fallback() -> Optional[LLM]: - """ - Helper function: if llm_value is None, we load environment variables or fallback default model. +UNACCEPTED_ATTRIBUTES: Final[list[str]] = [ + "AWS_ACCESS_KEY_ID", + "AWS_SECRET_ACCESS_KEY", + "AWS_REGION_NAME", +] + + +def _llm_via_environment_or_fallback() -> LLM | None: + """Creates an LLM instance based on environment variables or defaults. + + Returns: + A BaseLLM instance if successful, or None if something fails. 
""" model_name = ( os.environ.get("MODEL") @@ -83,28 +84,25 @@ def _llm_via_environment_or_fallback() -> Optional[LLM]: or DEFAULT_LLM_MODEL ) - # Initialize parameters with correct types model: str = model_name - temperature: Optional[float] = None - max_tokens: Optional[int] = None - max_completion_tokens: Optional[int] = None - logprobs: Optional[int] = None - timeout: Optional[float] = None - api_key: Optional[str] = None - base_url: Optional[str] = None - api_version: Optional[str] = None - presence_penalty: Optional[float] = None - frequency_penalty: Optional[float] = None - top_p: Optional[float] = None - n: Optional[int] = None - stop: Optional[Union[str, List[str]]] = None - logit_bias: Optional[Dict[int, float]] = None - response_format: Optional[Dict[str, Any]] = None - seed: Optional[int] = None - top_logprobs: Optional[int] = None - callbacks: List[Any] = [] + temperature: float | None = None + max_tokens: int | None = None + max_completion_tokens: int | None = None + logprobs: int | None = None + timeout: float | None = None + api_key: str | None = None + api_version: str | None = None + presence_penalty: float | None = None + frequency_penalty: float | None = None + top_p: float | None = None + n: int | None = None + stop: str | list[str] | None = None + logit_bias: dict[int, float] | None = None + response_format: dict[str, Any] | None = None + seed: int | None = None + top_logprobs: int | None = None + callbacks: list[Any] = [] - # Optional base URL from env base_url = ( os.environ.get("BASE_URL") or os.environ.get("OPENAI_API_BASE") @@ -119,8 +117,7 @@ def _llm_via_environment_or_fallback() -> Optional[LLM]: elif api_base and not base_url: base_url = api_base - # Initialize llm_params dictionary - llm_params: Dict[str, Any] = { + llm_params: dict[str, Any] = { "model": model, "temperature": temperature, "max_tokens": max_tokens, @@ -143,11 +140,6 @@ def _llm_via_environment_or_fallback() -> Optional[LLM]: "callbacks": callbacks, } - UNACCEPTED_ATTRIBUTES = [ - "AWS_ACCESS_KEY_ID", - "AWS_SECRET_ACCESS_KEY", - "AWS_REGION_NAME", - ] set_provider = model_name.partition("/")[0] if "/" in model_name else "openai" if set_provider in ENV_VARS: @@ -167,28 +159,26 @@ def _llm_via_environment_or_fallback() -> Optional[LLM]: if key not in ["prompt", "key_name", "default"]: llm_params[key.lower()] = value else: - print( + logger.debug( f"Expected env_var to be a dictionary, but got {type(env_var)}" ) - # Remove None values llm_params = {k: v for k, v in llm_params.items() if v is not None} - # Try creating the LLM try: - new_llm = LLM(**llm_params) - return new_llm + return LLM(**llm_params) except Exception as e: - print( + logger.debug( f"Error instantiating LLM from environment/fallback: {type(e).__name__}: {e}" ) return None def _normalize_key_name(key_name: str) -> str: - """ - Maps environment variable names to recognized litellm parameter keys, - using patterns from LITELLM_PARAMS. + """Maps environment variable names to recognized litellm parameter keys. + + Args: + key_name: The environment variable name to normalize. 
""" for pattern in LITELLM_PARAMS: if pattern in key_name: diff --git a/src/crewai/utilities/logger.py b/src/crewai/utilities/logger.py index 2f69e7abc..6796f26e0 100644 --- a/src/crewai/utilities/logger.py +++ b/src/crewai/utilities/logger.py @@ -2,19 +2,34 @@ from datetime import datetime from pydantic import BaseModel, Field, PrivateAttr -from crewai.utilities.printer import Printer +from crewai.utilities.printer import ColoredText, Printer, PrinterColor class Logger(BaseModel): - verbose: bool = Field(default=False) + verbose: bool = Field( + default=False, + description="Enables verbose logging with timestamps", + ) + default_color: PrinterColor = Field( + default="bold_yellow", + description="Default color for log messages", + ) _printer: Printer = PrivateAttr(default_factory=Printer) - default_color: str = Field(default="bold_yellow") - def log(self, level, message, color=None): - if color is None: - color = self.default_color + def log(self, level: str, message: str, color: PrinterColor | None = None) -> None: + """Log a message with timestamp if verbose mode is enabled. + + Args: + level: The log level (e.g., 'info', 'warning', 'error'). + message: The message to log. + color: Optional color for the message. Defaults to default_color. + """ if self.verbose: - timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S") + timestamp: str = datetime.now().strftime("%Y-%m-%d %H:%M:%S") self._printer.print( - f"\n[{timestamp}][{level.upper()}]: {message}", color=color + [ + ColoredText(f"\n[{timestamp}]", "cyan"), + ColoredText(f"[{level.upper()}]: ", "yellow"), + ColoredText(message, color or self.default_color), + ] ) diff --git a/src/crewai/utilities/logger_utils.py b/src/crewai/utilities/logger_utils.py index 7d0e806be..f0ad21f18 100644 --- a/src/crewai/utilities/logger_utils.py +++ b/src/crewai/utilities/logger_utils.py @@ -47,10 +47,12 @@ def suppress_warnings() -> Generator[None, None, None]: None during the context execution. Note: - There is a similar implementation in src/crewai/llm.py that also - suppresses a specific deprecation warning. That version may be - consolidated here in the future. + This implementation consolidates warning suppression used throughout + the codebase, including specific deprecation warnings from dependencies. """ with warnings.catch_warnings(): warnings.filterwarnings("ignore") + warnings.filterwarnings( + "ignore", message="open_text is deprecated*", category=DeprecationWarning + ) yield diff --git a/src/crewai/utilities/parser.py b/src/crewai/utilities/parser.py deleted file mode 100644 index c19cc1133..000000000 --- a/src/crewai/utilities/parser.py +++ /dev/null @@ -1,31 +0,0 @@ -import re - - -class YamlParser: - @staticmethod - def parse(file): - """ - Parses a YAML file, modifies specific patterns, and checks for unsupported 'context' usage. - Args: - file (file object): The YAML file to parse. - Returns: - str: The modified content of the YAML file. - Raises: - ValueError: If 'context:' is used incorrectly. - """ - content = file.read() - - # Replace single { and } with doubled ones, while leaving already doubled ones intact and the other special characters {# and {% - modified_content = re.sub(r"(? str: """Returns the path for SQLite database storage. 
@@ -19,13 +20,6 @@ def db_storage_path() -> str: return str(data_dir) -def get_project_directory_name(): +def get_project_directory_name() -> str: """Returns the current project directory name.""" - project_directory_name = os.environ.get("CREWAI_STORAGE_DIR") - - if project_directory_name: - return project_directory_name - else: - cwd = Path.cwd() - project_directory_name = cwd.name - return project_directory_name \ No newline at end of file + return os.environ.get("CREWAI_STORAGE_DIR", Path.cwd().name) diff --git a/src/crewai/utilities/planning_handler.py b/src/crewai/utilities/planning_handler.py index 1bd14a0c8..c1470d77f 100644 --- a/src/crewai/utilities/planning_handler.py +++ b/src/crewai/utilities/planning_handler.py @@ -1,16 +1,19 @@ +"""Handles planning and coordination of crew tasks.""" + import logging -from typing import Any, List, Optional from pydantic import BaseModel, Field from crewai.agent import Agent +from crewai.llms.base_llm import BaseLLM from crewai.task import Task -"""Handles planning and coordination of crew tasks.""" logger = logging.getLogger(__name__) + class PlanPerTask(BaseModel): """Represents a plan for a specific task.""" + task: str = Field(..., description="The task for which the plan is created") plan: str = Field( ..., @@ -20,28 +23,48 @@ class PlanPerTask(BaseModel): class PlannerTaskPydanticOutput(BaseModel): """Output format for task planning results.""" - list_of_plans_per_task: List[PlanPerTask] = Field( + + list_of_plans_per_task: list[PlanPerTask] = Field( ..., description="Step by step plan on how the agents can execute their tasks using the available tools with mastery", ) class CrewPlanner: - """Plans and coordinates the execution of crew tasks.""" - def __init__(self, tasks: List[Task], planning_agent_llm: Optional[Any] = None): - self.tasks = tasks + """Plans and coordinates the execution of crew tasks. - if planning_agent_llm is None: - self.planning_agent_llm = "gpt-4o-mini" - else: - self.planning_agent_llm = planning_agent_llm + Attributes: + tasks: List of tasks to be planned. + planning_agent_llm: Optional LLM model for the planning agent. + """ + + def __init__( + self, tasks: list[Task], planning_agent_llm: str | BaseLLM | None = None + ) -> None: + """Initialize CrewPlanner with tasks and optional planning agent LLM. + + Args: + tasks: List of tasks to be planned. + planning_agent_llm: Optional LLM model for the planning agent. Defaults to None. + """ + self.tasks = tasks + self.planning_agent_llm = planning_agent_llm or "gpt-4o-mini" def _handle_crew_planning(self) -> PlannerTaskPydanticOutput: - """Handles the Crew planning by creating detailed step-by-step plans for each task.""" + """Handles the Crew planning by creating detailed step-by-step plans for each task. + + Returns: + A PlannerTaskPydanticOutput containing the detailed plans for each task. + + Raises: + ValueError: If the planning output cannot be obtained. + """ planning_agent = self._create_planning_agent() tasks_summary = self._create_tasks_summary() - planner_task = self._create_planner_task(planning_agent, tasks_summary) + planner_task = self._create_planner_task( + planning_agent=planning_agent, tasks_summary=tasks_summary + ) result = planner_task.execute_sync() @@ -51,7 +74,11 @@ class CrewPlanner: raise ValueError("Failed to get the Planning output") def _create_planning_agent(self) -> Agent: - """Creates the planning agent for the crew planning.""" + """Creates the planning agent for the crew planning. 
+ + Returns: + An Agent instance configured for planning tasks. + """ return Agent( role="Task Execution Planner", goal=( @@ -62,8 +89,17 @@ class CrewPlanner: llm=self.planning_agent_llm, ) - def _create_planner_task(self, planning_agent: Agent, tasks_summary: str) -> Task: - """Creates the planner task using the given agent and tasks summary.""" + @staticmethod + def _create_planner_task(planning_agent: Agent, tasks_summary: str) -> Task: + """Creates the planner task using the given agent and tasks summary. + + Args: + planning_agent: The agent responsible for planning. + tasks_summary: A summary of all tasks to be included in the planning. + + Returns: + A Task instance configured for planning. + """ return Task( description=( f"Based on these tasks summary: {tasks_summary} \n Create the most descriptive plan based on the tasks " @@ -74,31 +110,42 @@ class CrewPlanner: output_pydantic=PlannerTaskPydanticOutput, ) - def _get_agent_knowledge(self, task: Task) -> List[str]: - """ - Safely retrieve knowledge source content from the task's agent. + @staticmethod + def _get_agent_knowledge(task: Task) -> list[str]: + """Safely retrieve knowledge source content from the task's agent. Args: task: The task containing an agent with potential knowledge sources Returns: - List[str]: A list of knowledge source strings + A list of knowledge source strings """ try: if task.agent and task.agent.knowledge_sources: - return [source.content for source in task.agent.knowledge_sources] + return [ + getattr(source, "content", str(source)) + for source in task.agent.knowledge_sources + ] except AttributeError: logger.warning("Error accessing agent knowledge sources") return [] def _create_tasks_summary(self) -> str: - """Creates a summary of all tasks.""" + """Creates a summary of all tasks. + + Returns: + A string summarizing all tasks with their details. 
+ """ tasks_summary = [] for idx, task in enumerate(self.tasks): knowledge_list = self._get_agent_knowledge(task) agent_tools = ( - f"[{', '.join(str(tool) for tool in task.agent.tools)}]" if task.agent and task.agent.tools else '"agent has no tools"', - f',\n "agent_knowledge": "[\\"{knowledge_list[0]}\\"]"' if knowledge_list and str(knowledge_list) != "None" else "" + f"[{', '.join(str(tool) for tool in task.agent.tools)}]" + if task.agent and task.agent.tools + else '"agent has no tools"', + f',\n "agent_knowledge": "[\\"{knowledge_list[0]}\\"]"' + if knowledge_list and str(knowledge_list) != "None" + else "", ) task_summary = f""" Task Number {idx + 1} - {task.description} diff --git a/src/crewai/utilities/printer.py b/src/crewai/utilities/printer.py index 74ad9a30b..cce14aba7 100644 --- a/src/crewai/utilities/printer.py +++ b/src/crewai/utilities/printer.py @@ -1,71 +1,72 @@ """Utility for colored console output.""" -from typing import Optional +from typing import Final, Literal, NamedTuple + +PrinterColor = Literal[ + "purple", + "bold_purple", + "green", + "bold_green", + "cyan", + "bold_cyan", + "magenta", + "bold_magenta", + "yellow", + "bold_yellow", + "red", + "blue", + "bold_blue", +] + +_COLOR_CODES: Final[dict[PrinterColor, str]] = { + "purple": "\033[95m", + "bold_purple": "\033[1m\033[95m", + "red": "\033[91m", + "bold_green": "\033[1m\033[92m", + "green": "\033[32m", + "blue": "\033[94m", + "bold_blue": "\033[1m\033[94m", + "yellow": "\033[93m", + "bold_yellow": "\033[1m\033[93m", + "cyan": "\033[96m", + "bold_cyan": "\033[1m\033[96m", + "magenta": "\033[35m", + "bold_magenta": "\033[1m\033[35m", +} + +RESET: Final[str] = "\033[0m" + + +class ColoredText(NamedTuple): + """Represents text with an optional color for console output. + + Attributes: + text: The text content to be printed. + color: Optional color for the text, specified as a PrinterColor. + """ + + text: str + color: PrinterColor | None class Printer: """Handles colored console output formatting.""" - def print(self, content: str, color: Optional[str] = None): - if color == "purple": - self._print_purple(content) - elif color == "red": - self._print_red(content) - elif color == "bold_green": - self._print_bold_green(content) - elif color == "bold_purple": - self._print_bold_purple(content) - elif color == "bold_blue": - self._print_bold_blue(content) - elif color == "yellow": - self._print_yellow(content) - elif color == "bold_yellow": - self._print_bold_yellow(content) - elif color == "cyan": - self._print_cyan(content) - elif color == "bold_cyan": - self._print_bold_cyan(content) - elif color == "magenta": - self._print_magenta(content) - elif color == "bold_magenta": - self._print_bold_magenta(content) - elif color == "green": - self._print_green(content) - else: - print(content) + @staticmethod + def print( + content: str | list[ColoredText], color: PrinterColor | None = None + ) -> None: + """Prints content to the console with optional color formatting. 
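A brief usage sketch of the table-driven printer API introduced above; both calls assume Printer and ColoredText are imported from crewai.utilities.printer, and each segment is emitted with its ANSI color code followed by a reset.

from crewai.utilities.printer import ColoredText, Printer

Printer.print("Task started", color="bold_green")  # single-color string, same effect as the old helpers
Printer.print([
    ColoredText("Agent: ", "bold_cyan"),
    ColoredText("planning complete", None),  # uncolored segment, still reset-terminated
])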
- def _print_bold_purple(self, content): - print("\033[1m\033[95m {}\033[00m".format(content)) - - def _print_bold_green(self, content): - print("\033[1m\033[92m {}\033[00m".format(content)) - - def _print_purple(self, content): - print("\033[95m {}\033[00m".format(content)) - - def _print_red(self, content): - print("\033[91m {}\033[00m".format(content)) - - def _print_bold_blue(self, content): - print("\033[1m\033[94m {}\033[00m".format(content)) - - def _print_yellow(self, content): - print("\033[93m {}\033[00m".format(content)) - - def _print_bold_yellow(self, content): - print("\033[1m\033[93m {}\033[00m".format(content)) - - def _print_cyan(self, content): - print("\033[96m {}\033[00m".format(content)) - - def _print_bold_cyan(self, content): - print("\033[1m\033[96m {}\033[00m".format(content)) - - def _print_magenta(self, content): - print("\033[35m {}\033[00m".format(content)) - - def _print_bold_magenta(self, content): - print("\033[1m\033[35m {}\033[00m".format(content)) - - def _print_green(self, content): - print("\033[32m {}\033[00m".format(content)) + Args: + content: Either a string or a list of ColoredText objects for multicolor output. + color: Optional color for the text when content is a string. Ignored when content is a list. + """ + if isinstance(content, str): + content = [ColoredText(content, color)] + print( + "".join( + f"{_COLOR_CODES[c.color] if c.color else ''}{c.text}{RESET}" + for c in content + ) + ) diff --git a/src/crewai/utilities/prompts.py b/src/crewai/utilities/prompts.py index cd3577874..4d3e168e1 100644 --- a/src/crewai/utilities/prompts.py +++ b/src/crewai/utilities/prompts.py @@ -1,29 +1,59 @@ -from typing import Any, Optional +from __future__ import annotations + +from typing import Any, TypedDict from pydantic import BaseModel, Field -from crewai.utilities import I18N +from crewai.utilities.i18n import I18N + + +class StandardPromptResult(TypedDict): + """Result with only prompt field for standard mode.""" + + prompt: str + + +class SystemPromptResult(StandardPromptResult): + """Result with system, user, and prompt fields for system prompt mode.""" + + system: str + user: str class Prompts(BaseModel): """Manages and generates prompts for a generic agent.""" - i18n: I18N = Field(default=I18N()) - has_tools: bool = False - system_template: Optional[str] = None - prompt_template: Optional[str] = None - response_template: Optional[str] = None - use_system_prompt: Optional[bool] = False - agent: Any + i18n: I18N = Field(default_factory=I18N) + has_tools: bool = Field( + default=False, description="Indicates if the agent has access to tools" + ) + system_template: str | None = Field( + default=None, description="Custom system prompt template" + ) + prompt_template: str | None = Field( + default=None, description="Custom user prompt template" + ) + response_template: str | None = Field( + default=None, description="Custom response prompt template" + ) + use_system_prompt: bool | None = Field( + default=False, + description="Whether to use the system prompt when no custom templates are provided", + ) + agent: Any = Field(description="Reference to the agent using these prompts") - def task_execution(self) -> dict[str, str]: - """Generate a standard prompt for task execution.""" - slices = ["role_playing"] + def task_execution(self) -> SystemPromptResult | StandardPromptResult: + """Generate a standard prompt for task execution. + + Returns: + A dictionary containing the constructed prompt(s). 
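To illustrate the two TypedDict result shapes above, a hedged sketch of consuming task_execution(); the agent fields and flag values are placeholders, and no LLM call is involved since only prompt strings are assembled.

from crewai import Agent
from crewai.utilities.prompts import Prompts

agent = Agent(role="Planner", goal="Plan tasks", backstory="Methodical and terse.")
prompts = Prompts(agent=agent, has_tools=False, use_system_prompt=True)
result = prompts.task_execution()
if "system" in result:  # SystemPromptResult: separate system and user messages
    system_msg, user_msg = result["system"], result["user"]
else:  # StandardPromptResult: a single combined prompt string
    combined = result["prompt"]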
+ """ + slices: list[str] = ["role_playing"] if self.has_tools: slices.append("tools") else: slices.append("no_tools") - system = self._build_prompt(slices) + system: str = self._build_prompt(slices) slices.append("task") if ( @@ -31,54 +61,67 @@ class Prompts(BaseModel): and not self.prompt_template and self.use_system_prompt ): - return { - "system": system, - "user": self._build_prompt(["task"]), - "prompt": self._build_prompt(slices), - } - else: - return { - "prompt": self._build_prompt( - slices, - self.system_template, - self.prompt_template, - self.response_template, - ) - } + return SystemPromptResult( + system=system, + user=self._build_prompt(["task"]), + prompt=self._build_prompt(slices), + ) + return StandardPromptResult( + prompt=self._build_prompt( + slices, + self.system_template, + self.prompt_template, + self.response_template, + ) + ) def _build_prompt( self, components: list[str], - system_template=None, - prompt_template=None, - response_template=None, + system_template: str | None = None, + prompt_template: str | None = None, + response_template: str | None = None, ) -> str: - """Constructs a prompt string from specified components.""" + """Constructs a prompt string from specified components. + + Args: + components: List of component names to include in the prompt. + system_template: Optional custom template for the system prompt. + prompt_template: Optional custom template for the user prompt. + response_template: Optional custom template for the response prompt. + + Returns: + The constructed prompt string. + """ + prompt: str if not system_template or not prompt_template: # If any of the required templates are missing, fall back to the default format - prompt_parts = [self.i18n.slice(component) for component in components] + prompt_parts: list[str] = [ + self.i18n.slice(component) for component in components + ] prompt = "".join(prompt_parts) else: # All templates are provided, use them - prompt_parts = [ + template_parts: list[str] = [ self.i18n.slice(component) for component in components if component != "task" ] - system = system_template.replace("{{ .System }}", "".join(prompt_parts)) + system: str = system_template.replace( + "{{ .System }}", "".join(template_parts) + ) prompt = prompt_template.replace( "{{ .Prompt }}", "".join(self.i18n.slice("task")) ) # Handle missing response_template if response_template: - response = response_template.split("{{ .Response }}")[0] + response: str = response_template.split("{{ .Response }}")[0] prompt = f"{system}\n{prompt}\n{response}" else: prompt = f"{system}\n{prompt}" - prompt = ( + return ( prompt.replace("{goal}", self.agent.goal) .replace("{role}", self.agent.role) .replace("{backstory}", self.agent.backstory) ) - return prompt diff --git a/src/crewai/utilities/pydantic_schema_parser.py b/src/crewai/utilities/pydantic_schema_parser.py index 2827d70aa..a5bbb5088 100644 --- a/src/crewai/utilities/pydantic_schema_parser.py +++ b/src/crewai/utilities/pydantic_schema_parser.py @@ -1,57 +1,62 @@ -from typing import Dict, List, Type, Union, get_args, get_origin +from typing import Any, Union, get_args, get_origin -from pydantic import BaseModel +from pydantic import BaseModel, Field class PydanticSchemaParser(BaseModel): - model: Type[BaseModel] + model: type[BaseModel] = Field(..., description="The Pydantic model to parse.") def get_schema(self) -> str: - """ - Public method to get the schema of a Pydantic model. + """Public method to get the schema of a Pydantic model. 
- :return: String representation of the model schema. + Returns: + String representation of the model schema. """ return "{\n" + self._get_model_schema(self.model) + "\n}" - def _get_model_schema(self, model: Type[BaseModel], depth: int = 0) -> str: - indent = " " * 4 * depth - lines = [ - f"{indent} {field_name}: {self._get_field_type(field, depth + 1)}" + def _get_model_schema(self, model: type[BaseModel], depth: int = 0) -> str: + """Recursively get the schema of a Pydantic model, handling nested models and lists. + + Args: + model: The Pydantic model to process. + depth: The current depth of recursion for indentation purposes. + + Returns: + A string representation of the model schema. + """ + indent: str = " " * 4 * depth + lines: list[str] = [ + f"{indent} {field_name}: {self._get_field_type_for_annotation(field.annotation, depth + 1)}" for field_name, field in model.model_fields.items() ] return ",\n".join(lines) - def _get_field_type(self, field, depth: int) -> str: - field_type = field.annotation - origin = get_origin(field_type) + def _format_list_type(self, list_item_type: Any, depth: int) -> str: + """Format a List type, handling nested models if necessary. - if origin in {list, List}: - list_item_type = get_args(field_type)[0] - return self._format_list_type(list_item_type, depth) + Args: + list_item_type: The type of items in the list. + depth: The current depth of recursion for indentation purposes. - if origin in {dict, Dict}: - key_type, value_type = get_args(field_type) - return f"Dict[{key_type.__name__}, {value_type.__name__}]" - - if origin is Union: - return self._format_union_type(field_type, depth) - - if isinstance(field_type, type) and issubclass(field_type, BaseModel): - nested_schema = self._get_model_schema(field_type, depth) - nested_indent = " " * 4 * depth - return f"{field_type.__name__}\n{nested_indent}{{\n{nested_schema}\n{nested_indent}}}" - - return field_type.__name__ - - def _format_list_type(self, list_item_type, depth: int) -> str: + Returns: + A string representation of the List type. + """ if isinstance(list_item_type, type) and issubclass(list_item_type, BaseModel): nested_schema = self._get_model_schema(list_item_type, depth + 1) - nested_indent = " " * 4 * (depth) + nested_indent = " " * 4 * depth return f"List[\n{nested_indent}{{\n{nested_schema}\n{nested_indent}}}\n{nested_indent}]" return f"List[{list_item_type.__name__}]" - def _format_union_type(self, field_type, depth: int) -> str: + def _format_union_type(self, field_type: Any, depth: int) -> str: + """Format a Union type, handling Optional and nested types. + + Args: + field_type: The Union type to format. + depth: The current depth of recursion for indentation purposes. + + Returns: + A string representation of the Union type. 
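A small sketch of the parser's typical output for a toy model; the rendered string in the comment is approximate rather than copied from the tests.

from typing import Optional

from pydantic import BaseModel
from crewai.utilities.pydantic_schema_parser import PydanticSchemaParser

class Item(BaseModel):
    name: str
    tags: list[str]
    price: Optional[float]

print(PydanticSchemaParser(model=Item).get_schema())
# Approximately:
# {
#  name: str,
#  tags: List[str],
#  price: Optional[float]
# }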
+ """ args = get_args(field_type) if type(None) in args: # It's an Optional type @@ -61,26 +66,32 @@ class PydanticSchemaParser(BaseModel): non_none_args[0], depth ) return f"Optional[{inner_type}]" - else: - # Union with None and multiple other types - inner_types = ", ".join( - self._get_field_type_for_annotation(arg, depth) - for arg in non_none_args - ) - return f"Optional[Union[{inner_types}]]" - else: - # General Union type + # Union with None and multiple other types inner_types = ", ".join( - self._get_field_type_for_annotation(arg, depth) for arg in args + self._get_field_type_for_annotation(arg, depth) for arg in non_none_args ) - return f"Union[{inner_types}]" + return f"Optional[Union[{inner_types}]]" + # General Union type + inner_types = ", ".join( + self._get_field_type_for_annotation(arg, depth) for arg in args + ) + return f"Union[{inner_types}]" - def _get_field_type_for_annotation(self, annotation, depth: int) -> str: - origin = get_origin(annotation) - if origin in {list, List}: + def _get_field_type_for_annotation(self, annotation: Any, depth: int) -> str: + """Recursively get the string representation of a field's type annotation. + + Args: + annotation: The type annotation to process. + depth: The current depth of recursion for indentation purposes. + + Returns: + A string representation of the type annotation. + """ + origin: Any = get_origin(annotation) + if origin is list: list_item_type = get_args(annotation)[0] return self._format_list_type(list_item_type, depth) - if origin in {dict, Dict}: + if origin is dict: key_type, value_type = get_args(annotation) return f"Dict[{key_type.__name__}, {value_type.__name__}]" if origin is Union: diff --git a/src/crewai/utilities/reasoning_handler.py b/src/crewai/utilities/reasoning_handler.py index f5c4636ae..56ac8c1a0 100644 --- a/src/crewai/utilities/reasoning_handler.py +++ b/src/crewai/utilities/reasoning_handler.py @@ -1,19 +1,19 @@ -import logging import json -from typing import Tuple, cast +import logging +from typing import Any, Final, Literal, cast from pydantic import BaseModel, Field from crewai.agent import Agent -from crewai.task import Task -from crewai.utilities import I18N -from crewai.llm import LLM from crewai.events.event_bus import crewai_event_bus from crewai.events.types.reasoning_events import ( - AgentReasoningStartedEvent, AgentReasoningCompletedEvent, AgentReasoningFailedEvent, + AgentReasoningStartedEvent, ) +from crewai.llm import LLM +from crewai.task import Task +from crewai.utilities.i18n import I18N class ReasoningPlan(BaseModel): @@ -29,22 +29,49 @@ class AgentReasoningOutput(BaseModel): plan: ReasoningPlan = Field(description="The reasoning plan for the task.") -class ReasoningFunction(BaseModel): - """Model for function calling with reasoning.""" - - plan: str = Field(description="The detailed reasoning plan for the task.") - ready: bool = Field(description="Whether the agent is ready to execute the task.") +FUNCTION_SCHEMA: Final[dict[str, Any]] = { + "type": "function", + "function": { + "name": "create_reasoning_plan", + "description": "Create or refine a reasoning plan for a task", + "parameters": { + "type": "object", + "properties": { + "plan": { + "type": "string", + "description": "The detailed reasoning plan for the task.", + }, + "ready": { + "type": "boolean", + "description": "Whether the agent is ready to execute the task.", + }, + }, + "required": ["plan", "ready"], + }, + }, +} class AgentReasoning: """ Handles the agent reasoning process, enabling an agent to reflect and 
create a plan before executing a task. + + Attributes: + task: The task for which the agent is reasoning. + agent: The agent performing the reasoning. + llm: The language model used for reasoning. + logger: Logger for logging events and errors. + i18n: Internationalization utility for retrieving prompts. """ - def __init__(self, task: Task, agent: Agent): - if not task or not agent: - raise ValueError("Both task and agent must be provided.") + def __init__(self, task: Task, agent: Agent) -> None: + """Initialize the AgentReasoning with a task and an agent. + + Args: + task: The task for which the agent is reasoning. + agent: The agent performing the reasoning. + """ self.task = task self.agent = agent self.llm = cast(LLM, agent.llm) @@ -52,9 +79,7 @@ class AgentReasoning: self.i18n = I18N() def handle_agent_reasoning(self) -> AgentReasoningOutput: - """ - Public method for the reasoning process that creates and refines a plan - for the task until the agent is ready to execute it. + """Public method for the reasoning process that creates and refines a plan for the task until the agent is ready to execute it. Returns: AgentReasoningOutput: The output of the agent reasoning process. @@ -70,7 +95,7 @@ class AgentReasoning: from_task=self.task, ), ) - except Exception: + except Exception: # noqa: S110 # Ignore event bus errors to avoid breaking execution pass @@ -90,7 +115,7 @@ class AgentReasoning: from_task=self.task, ), ) - except Exception: + except Exception: # noqa: S110 pass return output @@ -107,17 +132,16 @@ class AgentReasoning: from_task=self.task, ), ) - except Exception: + except Exception: # noqa: S110 pass raise def __handle_agent_reasoning(self) -> AgentReasoningOutput: - """ - Private method that handles the agent reasoning process. + """Private method that handles the agent reasoning process. Returns: - AgentReasoningOutput: The output of the agent reasoning process. + The output of the agent reasoning process. """ plan, ready = self.__create_initial_plan() @@ -126,46 +150,38 @@ class AgentReasoning: reasoning_plan = ReasoningPlan(plan=plan, ready=ready) return AgentReasoningOutput(plan=reasoning_plan) - def __create_initial_plan(self) -> Tuple[str, bool]: - """ - Creates the initial reasoning plan for the task. + def __create_initial_plan(self) -> tuple[str, bool]: + """Creates the initial reasoning plan for the task. Returns: - Tuple[str, bool]: The initial plan and whether the agent is ready to execute the task. + The initial plan and whether the agent is ready to execute the task. 
""" reasoning_prompt = self.__create_reasoning_prompt() if self.llm.supports_function_calling(): plan, ready = self.__call_with_function(reasoning_prompt, "initial_plan") return plan, ready - else: - system_prompt = self.i18n.retrieve("reasoning", "initial_plan").format( - role=self.agent.role, - goal=self.agent.goal, - backstory=self.__get_agent_backstory(), - ) + response = _call_llm_with_reasoning_prompt( + llm=self.llm, + prompt=reasoning_prompt, + task=self.task, + agent=self.agent, + i18n=self.i18n, + backstory=self.__get_agent_backstory(), + plan_type="initial_plan", + ) - response = self.llm.call( - [ - {"role": "system", "content": system_prompt}, - {"role": "user", "content": reasoning_prompt}, - ], - from_task=self.task, - from_agent=self.agent, - ) + return self.__parse_reasoning_response(str(response)) - return self.__parse_reasoning_response(str(response)) - - def __refine_plan_if_needed(self, plan: str, ready: bool) -> Tuple[str, bool]: - """ - Refines the reasoning plan if the agent is not ready to execute the task. + def __refine_plan_if_needed(self, plan: str, ready: bool) -> tuple[str, bool]: + """Refines the reasoning plan if the agent is not ready to execute the task. Args: plan: The current reasoning plan. ready: Whether the agent is ready to execute the task. Returns: - Tuple[str, bool]: The refined plan and whether the agent is ready to execute the task. + The refined plan and whether the agent is ready to execute the task. """ attempt = 1 max_attempts = self.agent.max_reasoning_attempts @@ -182,7 +198,7 @@ class AgentReasoning: from_task=self.task, ), ) - except Exception: + except Exception: # noqa: S110 pass refine_prompt = self.__create_refine_prompt(plan) @@ -190,19 +206,14 @@ class AgentReasoning: if self.llm.supports_function_calling(): plan, ready = self.__call_with_function(refine_prompt, "refine_plan") else: - system_prompt = self.i18n.retrieve("reasoning", "refine_plan").format( - role=self.agent.role, - goal=self.agent.goal, + response = _call_llm_with_reasoning_prompt( + llm=self.llm, + prompt=refine_prompt, + task=self.task, + agent=self.agent, + i18n=self.i18n, backstory=self.__get_agent_backstory(), - ) - - response = self.llm.call( - [ - {"role": "system", "content": system_prompt}, - {"role": "user", "content": refine_prompt}, - ], - from_task=self.task, - from_agent=self.agent, + plan_type="refine_plan", ) plan, ready = self.__parse_reasoning_response(str(response)) @@ -216,41 +227,18 @@ class AgentReasoning: return plan, ready - def __call_with_function(self, prompt: str, prompt_type: str) -> Tuple[str, bool]: - """ - Calls the LLM with function calling to get a reasoning plan. + def __call_with_function(self, prompt: str, prompt_type: str) -> tuple[str, bool]: + """Calls the LLM with function calling to get a reasoning plan. Args: prompt: The prompt to send to the LLM. prompt_type: The type of prompt (initial_plan or refine_plan). Returns: - Tuple[str, bool]: A tuple containing the plan and whether the agent is ready. + A tuple containing the plan and whether the agent is ready. 
""" self.logger.debug(f"Using function calling for {prompt_type} reasoning") - function_schema = { - "type": "function", - "function": { - "name": "create_reasoning_plan", - "description": "Create or refine a reasoning plan for a task", - "parameters": { - "type": "object", - "properties": { - "plan": { - "type": "string", - "description": "The detailed reasoning plan for the task.", - }, - "ready": { - "type": "boolean", - "description": "Whether the agent is ready to execute the task.", - }, - }, - "required": ["plan", "ready"], - }, - }, - } - try: system_prompt = self.i18n.retrieve("reasoning", prompt_type).format( role=self.agent.role, @@ -259,7 +247,7 @@ class AgentReasoning: ) # Prepare a simple callable that just returns the tool arguments as JSON - def _create_reasoning_plan(plan: str, ready: bool = True): # noqa: N802 + def _create_reasoning_plan(plan: str, ready: bool = True): """Return the reasoning plan result in JSON string form.""" return json.dumps({"plan": plan, "ready": ready}) @@ -268,7 +256,7 @@ class AgentReasoning: {"role": "system", "content": system_prompt}, {"role": "user", "content": prompt}, ], - tools=[function_schema], + tools=[FUNCTION_SCHEMA], available_functions={"create_reasoning_plan": _create_reasoning_plan}, from_task=self.task, from_agent=self.agent, @@ -291,7 +279,7 @@ class AgentReasoning: except Exception as e: self.logger.warning( - f"Error during function calling: {str(e)}. Falling back to text parsing." + f"Error during function calling: {e!s}. Falling back to text parsing." ) try: @@ -316,7 +304,7 @@ class AgentReasoning: "READY: I am ready to execute the task." in fallback_str, ) except Exception as inner_e: - self.logger.error(f"Error during fallback text parsing: {str(inner_e)}") + self.logger.error(f"Error during fallback text parsing: {inner_e!s}") return ( "Failed to generate a plan due to an error.", True, @@ -378,7 +366,8 @@ class AgentReasoning: current_plan=current_plan, ) - def __parse_reasoning_response(self, response: str) -> Tuple[str, bool]: + @staticmethod + def __parse_reasoning_response(response: str) -> tuple[str, bool]: """ Parses the reasoning response to extract the plan and whether the agent is ready to execute the task. @@ -387,7 +376,7 @@ class AgentReasoning: response: The LLM response. Returns: - Tuple[str, bool]: The plan and whether the agent is ready to execute the task. + The plan and whether the agent is ready to execute the task. """ if not response: return "No plan was generated.", False @@ -412,3 +401,43 @@ class AgentReasoning: "The _handle_agent_reasoning method is deprecated. Use handle_agent_reasoning instead." ) return self.handle_agent_reasoning() + + +def _call_llm_with_reasoning_prompt( + llm: LLM, + prompt: str, + task: Task, + agent: Agent, + i18n: I18N, + backstory: str, + plan_type: Literal["initial_plan", "refine_plan"], +) -> str: + """Calls the LLM with the reasoning prompt. + + Args: + llm: The language model to use. + prompt: The prompt to send to the LLM. + task: The task for which the agent is reasoning. + agent: The agent performing the reasoning. + i18n: Internationalization utility for retrieving prompts. + backstory: The agent's backstory. + plan_type: The type of plan being created ("initial_plan" or "refine_plan"). + + Returns: + The LLM response. 
+ """ + system_prompt = i18n.retrieve("reasoning", plan_type).format( + role=agent.role, + goal=agent.goal, + backstory=backstory, + ) + + response = llm.call( + [ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": prompt}, + ], + from_task=task, + from_agent=agent, + ) + return str(response) diff --git a/src/crewai/utilities/rpm_controller.py b/src/crewai/utilities/rpm_controller.py index ec59b8304..4704c3e5b 100644 --- a/src/crewai/utilities/rpm_controller.py +++ b/src/crewai/utilities/rpm_controller.py @@ -1,41 +1,54 @@ +"""Controls request rate limiting for API calls.""" + import threading import time -from typing import Optional from pydantic import BaseModel, Field, PrivateAttr, model_validator +from typing_extensions import Self from crewai.utilities.logger import Logger -"""Controls request rate limiting for API calls.""" - class RPMController(BaseModel): """Manages requests per minute limiting.""" - max_rpm: Optional[int] = Field(default=None) + max_rpm: int | None = Field( + default=None, + description="Maximum requests per minute. If None, no limit is applied.", + ) logger: Logger = Field(default_factory=lambda: Logger(verbose=False)) _current_rpm: int = PrivateAttr(default=0) - _timer: Optional[threading.Timer] = PrivateAttr(default=None) - _lock: Optional[threading.Lock] = PrivateAttr(default=None) + _timer: "threading.Timer | None" = PrivateAttr(default=None) + _lock: "threading.Lock | None" = PrivateAttr(default=None) _shutdown_flag: bool = PrivateAttr(default=False) @model_validator(mode="after") - def reset_counter(self): + def reset_counter(self) -> Self: + """Resets the RPM counter and starts the timer if max_rpm is set. + + Returns: + The instance of the RPMController. + """ if self.max_rpm is not None: if not self._shutdown_flag: self._lock = threading.Lock() self._reset_request_count() return self - def check_or_wait(self): + def check_or_wait(self) -> bool: + """Checks if a new request can be made based on the RPM limit. + + Returns: + True if a new request can be made, False otherwise. + """ if self.max_rpm is None: return True - def _check_and_increment(): + def _check_and_increment() -> bool: if self.max_rpm is not None and self._current_rpm < self.max_rpm: self._current_rpm += 1 return True - elif self.max_rpm is not None: + if self.max_rpm is not None: self.logger.log( "info", "Max RPM reached, waiting for next minute to start." 
) @@ -50,16 +63,18 @@ class RPMController(BaseModel): else: return _check_and_increment() - def stop_rpm_counter(self): + def stop_rpm_counter(self) -> None: + """Stops the RPM counter and cancels any active timers.""" + self._shutdown_flag = True if self._timer: self._timer.cancel() self._timer = None - def _wait_for_next_minute(self): + def _wait_for_next_minute(self) -> None: time.sleep(60) self._current_rpm = 0 - def _reset_request_count(self): + def _reset_request_count(self) -> None: def _reset(): self._current_rpm = 0 if not self._shutdown_flag: @@ -71,7 +86,3 @@ class RPMController(BaseModel): _reset() else: _reset() - - if self._timer: - self._shutdown_flag = True - self._timer.cancel() diff --git a/src/crewai/utilities/serialization.py b/src/crewai/utilities/serialization.py index c3c0c3d47..0267d0c83 100644 --- a/src/crewai/utilities/serialization.py +++ b/src/crewai/utilities/serialization.py @@ -1,14 +1,16 @@ +from __future__ import annotations + import json import uuid from datetime import date, datetime -from typing import Any, Dict, List, Union +from typing import Any, TypeAlias from pydantic import BaseModel -SerializablePrimitive = Union[str, int, float, bool, None] -Serializable = Union[ - SerializablePrimitive, List["Serializable"], Dict[str, "Serializable"] -] +SerializablePrimitive: TypeAlias = str | int | float | bool | None +Serializable: TypeAlias = ( + SerializablePrimitive | list["Serializable"] | dict[str, "Serializable"] +) def to_serializable( @@ -24,9 +26,10 @@ def to_serializable( Non-convertible objects default to their string representations. Args: - obj (Any): Object to transform. - exclude (set[str], optional): Set of keys to exclude from the result. - max_depth (int, optional): Maximum recursion depth. Defaults to 5. + obj: Object to transform. + exclude: Set of keys to exclude from the result. + max_depth: Maximum recursion depth. Defaults to 5. + _current_depth: Current recursion depth (for internal use). Returns: Serializable: A JSON-compatible structure. @@ -39,18 +42,18 @@ def to_serializable( if isinstance(obj, (str, int, float, bool, type(None))): return obj - elif isinstance(obj, uuid.UUID): + if isinstance(obj, uuid.UUID): return str(obj) - elif isinstance(obj, (date, datetime)): + if isinstance(obj, (date, datetime)): return obj.isoformat() - elif isinstance(obj, (list, tuple, set)): + if isinstance(obj, (list, tuple, set)): return [ to_serializable( item, max_depth=max_depth, _current_depth=_current_depth + 1 ) for item in obj ] - elif isinstance(obj, dict): + if isinstance(obj, dict): return { _to_serializable_key(key): to_serializable( obj=value, @@ -61,33 +64,31 @@ def to_serializable( for key, value in obj.items() if key not in exclude } - elif isinstance(obj, BaseModel): + if isinstance(obj, BaseModel): return to_serializable( obj=obj.model_dump(exclude=exclude), max_depth=max_depth, _current_depth=_current_depth + 1, ) - else: - return repr(obj) + return repr(obj) def _to_serializable_key(key: Any) -> str: if isinstance(key, (str, int)): return str(key) - return f"key_{id(key)}_{repr(key)}" + return f"key_{id(key)}_{key!r}" def to_string(obj: Any) -> str | None: """Serializes an object into a JSON string. Args: - obj (Any): Object to serialize. + obj: Object to serialize. Returns: - str | None: A JSON-formatted string or `None` if empty. + A JSON-formatted string or `None` if empty. 
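A short illustration of the serialization helpers above; the printed values follow the documented behavior (UUIDs and datetimes become strings, sets become lists, and empty input yields None).

import uuid
from datetime import datetime

from crewai.utilities.serialization import to_serializable, to_string

payload = {
    "id": uuid.uuid4(),
    "when": datetime(2024, 1, 1),
    "tags": {"a", "b"},  # sets are converted to lists
}
print(to_serializable(payload)["when"])  # "2024-01-01T00:00:00"
print(to_string(None))  # None, not the string "null"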
""" serializable = to_serializable(obj) if serializable is None: return None - else: - return json.dumps(serializable) + return json.dumps(serializable) diff --git a/src/crewai/utilities/string_utils.py b/src/crewai/utilities/string_utils.py index 255e66a0b..40181459f 100644 --- a/src/crewai/utilities/string_utils.py +++ b/src/crewai/utilities/string_utils.py @@ -1,10 +1,12 @@ import re -from typing import Any, Dict, List, Optional, Union +from typing import Any, Final + +_VARIABLE_PATTERN: Final[re.Pattern[str]] = re.compile(r"\{([A-Za-z_][A-Za-z0-9_\-]*)}") def interpolate_only( - input_string: Optional[str], - inputs: Dict[str, Union[str, int, float, Dict[str, Any], List[Any]]], + input_string: str | None, + inputs: dict[str, str | int | float | dict[str, Any] | list[Any]], ) -> str: """Interpolate placeholders (e.g., {key}) in a string while leaving JSON untouched. Only interpolates placeholders that follow the pattern {variable_name} where @@ -26,26 +28,30 @@ def interpolate_only( """ # Validation function for recursive type checking - def validate_type(value: Any) -> None: - if value is None: + def _validate_type(validate_value: Any) -> None: + if validate_value is None: return - if isinstance(value, (str, int, float, bool)): + if isinstance(validate_value, (str, int, float, bool)): return - if isinstance(value, (dict, list)): - for item in value.values() if isinstance(value, dict) else value: - validate_type(item) + if isinstance(validate_value, (dict, list)): + for item in ( + validate_value.values() + if isinstance(validate_value, dict) + else validate_value + ): + _validate_type(item) return raise ValueError( - f"Unsupported type {type(value).__name__} in inputs. " + f"Unsupported type {type(validate_value).__name__} in inputs. " "Only str, int, float, bool, dict, and list are allowed." ) # Validate all input values for key, value in inputs.items(): try: - validate_type(value) - except ValueError as e: - raise ValueError(f"Invalid value for key '{key}': {str(e)}") from e + _validate_type(value) + except ValueError as e: # noqa: PERF203 + raise ValueError(f"Invalid value for key '{key}': {e!s}") from e if input_string is None or not input_string: return "" @@ -56,13 +62,7 @@ def interpolate_only( "Inputs dictionary cannot be empty when interpolating variables" ) - # The regex pattern to find valid variable placeholders - # Matches {variable_name} where variable_name starts with a letter/underscore - # and contains only letters, numbers, and underscores - pattern = r"\{([A-Za-z_][A-Za-z0-9_\-]*)\}" - - # Find all matching variables in the input string - variables = re.findall(pattern, input_string) + variables = _VARIABLE_PATTERN.findall(input_string) result = input_string # Check if all variables exist in inputs diff --git a/src/crewai/utilities/task_output_storage_handler.py b/src/crewai/utilities/task_output_storage_handler.py index 95d366bcb..2259bb833 100644 --- a/src/crewai/utilities/task_output_storage_handler.py +++ b/src/crewai/utilities/task_output_storage_handler.py @@ -4,50 +4,14 @@ This module provides functionality for storing and retrieving task outputs from persistent storage, supporting replay and audit capabilities. """ -from datetime import datetime from typing import Any -from pydantic import BaseModel, Field - from crewai.memory.storage.kickoff_task_outputs_storage import ( KickoffTaskOutputsSQLiteStorage, ) from crewai.task import Task -class ExecutionLog(BaseModel): - """Represents a log entry for task execution. 
- - Attributes: - task_id: Unique identifier for the task. - expected_output: The expected output description for the task. - output: The actual output produced by the task. - timestamp: When the task was executed. - task_index: The position of the task in the execution sequence. - inputs: Input parameters provided to the task. - was_replayed: Whether this output was replayed from a previous run. - """ - - task_id: str - expected_output: str | None = None - output: dict[str, Any] - timestamp: datetime = Field(default_factory=datetime.now) - task_index: int - inputs: dict[str, Any] = Field(default_factory=dict) - was_replayed: bool = False - - def __getitem__(self, key: str) -> Any: - """Enable dictionary-style access to execution log attributes. - - Args: - key: The attribute name to access. - - Returns: - The value of the requested attribute. - """ - return getattr(self, key) - - class TaskOutputStorageHandler: """Manages storage and retrieval of task outputs. diff --git a/src/crewai/utilities/token_counter_callback.py b/src/crewai/utilities/token_counter_callback.py index 4f61d7557..96124f226 100644 --- a/src/crewai/utilities/token_counter_callback.py +++ b/src/crewai/utilities/token_counter_callback.py @@ -4,13 +4,13 @@ This module provides a callback handler that tracks token usage for LLM API calls through the litellm library. """ -import warnings from typing import Any from litellm.integrations.custom_logger import CustomLogger from litellm.types.utils import Usage from crewai.agents.agent_builder.utilities.base_token_process import TokenProcess +from crewai.utilities.logger_utils import suppress_warnings class TokenCalcHandler(CustomLogger): @@ -23,12 +23,13 @@ class TokenCalcHandler(CustomLogger): token_cost_process: The token process tracker to accumulate usage metrics. """ - def __init__(self, token_cost_process: TokenProcess | None) -> None: + def __init__(self, token_cost_process: TokenProcess | None, **kwargs: Any) -> None: """Initialize the token calculation handler. Args: token_cost_process: Optional token process tracker for accumulating metrics. 
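A hedged sketch of wiring the handler above into an LLM; the no-argument TokenProcess constructor and the callbacks parameter are assumptions from the wider codebase rather than facts established in this diff.

from crewai.agents.agent_builder.utilities.base_token_process import TokenProcess
from crewai.llm import LLM
from crewai.utilities.token_counter_callback import TokenCalcHandler

token_process = TokenProcess()  # assumed no-arg constructor
handler = TokenCalcHandler(token_process)  # extra CustomLogger kwargs now forward via **kwargs
llm = LLM(model="gpt-4o", callbacks=[handler])  # passed through the LLM callbacks parameter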
""" + super().__init__(**kwargs) self.token_cost_process = token_cost_process def log_success_event( @@ -49,8 +50,7 @@ class TokenCalcHandler(CustomLogger): if self.token_cost_process is None: return - with warnings.catch_warnings(): - warnings.simplefilter("ignore", UserWarning) + with suppress_warnings(): if isinstance(response_obj, dict) and "usage" in response_obj: usage: Usage = response_obj["usage"] if usage: diff --git a/src/crewai/utilities/tool_utils.py b/src/crewai/utilities/tool_utils.py index c1c20bc66..851492a04 100644 --- a/src/crewai/utilities/tool_utils.py +++ b/src/crewai/utilities/tool_utils.py @@ -1,12 +1,21 @@ -from typing import Any +from __future__ import annotations + +from typing import TYPE_CHECKING from crewai.agents.parser import AgentAction -from crewai.security import Fingerprint +from crewai.agents.tools_handler import ToolsHandler +from crewai.security.fingerprint import Fingerprint from crewai.tools.structured_tool import CrewStructuredTool from crewai.tools.tool_types import ToolResult from crewai.tools.tool_usage import ToolUsage, ToolUsageError from crewai.utilities.i18n import I18N +if TYPE_CHECKING: + from crewai.agent import Agent + from crewai.llm import LLM + from crewai.llms.base_llm import BaseLLM + from crewai.task import Task + def execute_tool_and_check_finality( agent_action: AgentAction, @@ -14,10 +23,10 @@ def execute_tool_and_check_finality( i18n: I18N, agent_key: str | None = None, agent_role: str | None = None, - tools_handler: Any | None = None, - task: Any | None = None, - agent: Any | None = None, - function_calling_llm: Any | None = None, + tools_handler: ToolsHandler | None = None, + task: Task | None = None, + agent: Agent | None = None, + function_calling_llm: BaseLLM | LLM | None = None, fingerprint_context: dict[str, str] | None = None, ) -> ToolResult: """Execute a tool and check if the result should be treated as a final answer. 
@@ -32,59 +41,54 @@ def execute_tool_and_check_finality( task: Optional task for tool execution agent: Optional agent instance for tool execution function_calling_llm: Optional LLM for function calling + fingerprint_context: Optional context for fingerprinting Returns: ToolResult containing the execution result and whether it should be treated as a final answer """ - try: - tool_name_to_tool_map = {tool.name: tool for tool in tools} + tool_name_to_tool_map = {tool.name: tool for tool in tools} - if agent_key and agent_role and agent: - fingerprint_context = fingerprint_context or {} - if agent: - if hasattr(agent, "set_fingerprint") and callable( - agent.set_fingerprint - ): - if isinstance(fingerprint_context, dict): - try: - fingerprint_obj = Fingerprint.from_dict(fingerprint_context) - agent.set_fingerprint(fingerprint_obj) - except Exception as e: - raise ValueError(f"Failed to set fingerprint: {e}") from e + if agent_key and agent_role and agent: + fingerprint_context = fingerprint_context or {} + if agent: + if hasattr(agent, "set_fingerprint") and callable(agent.set_fingerprint): + if isinstance(fingerprint_context, dict): + try: + fingerprint_obj = Fingerprint.from_dict(fingerprint_context) + agent.set_fingerprint(fingerprint=fingerprint_obj) + except Exception as e: + raise ValueError(f"Failed to set fingerprint: {e}") from e - # Create tool usage instance - tool_usage = ToolUsage( - tools_handler=tools_handler, - tools=tools, - function_calling_llm=function_calling_llm, - task=task, - agent=agent, - action=agent_action, - ) + # Create tool usage instance + tool_usage = ToolUsage( + tools_handler=tools_handler, + tools=tools, + function_calling_llm=function_calling_llm, + task=task, + agent=agent, + action=agent_action, + ) - # Parse tool calling - tool_calling = tool_usage.parse_tool_calling(agent_action.text) + # Parse tool calling + tool_calling = tool_usage.parse_tool_calling(agent_action.text) - if isinstance(tool_calling, ToolUsageError): - return ToolResult(tool_calling.message, False) + if isinstance(tool_calling, ToolUsageError): + return ToolResult(tool_calling.message, False) - # Check if tool name matches - if tool_calling.tool_name.casefold().strip() in [ - name.casefold().strip() for name in tool_name_to_tool_map - ] or tool_calling.tool_name.casefold().replace("_", " ") in [ - name.casefold().strip() for name in tool_name_to_tool_map - ]: - tool_result = tool_usage.use(tool_calling, agent_action.text) - tool = tool_name_to_tool_map.get(tool_calling.tool_name) - if tool: - return ToolResult(tool_result, tool.result_as_answer) + # Check if tool name matches + if tool_calling.tool_name.casefold().strip() in [ + name.casefold().strip() for name in tool_name_to_tool_map + ] or tool_calling.tool_name.casefold().replace("_", " ") in [ + name.casefold().strip() for name in tool_name_to_tool_map + ]: + tool_result = tool_usage.use(tool_calling, agent_action.text) + tool = tool_name_to_tool_map.get(tool_calling.tool_name) + if tool: + return ToolResult(tool_result, tool.result_as_answer) - # Handle invalid tool name - tool_result = i18n.errors("wrong_tool_name").format( - tool=tool_calling.tool_name, - tools=", ".join([tool.name.casefold() for tool in tools]), - ) - return ToolResult(tool_result, False) - - except Exception as e: - raise e + # Handle invalid tool name + tool_result = i18n.errors("wrong_tool_name").format( + tool=tool_calling.tool_name, + tools=", ".join([tool.name.casefold() for tool in tools]), + ) + return ToolResult(result=tool_result, 
result_as_answer=False) diff --git a/src/crewai/utilities/training_converter.py b/src/crewai/utilities/training_converter.py index 4aef94cd5..05f74fa53 100644 --- a/src/crewai/utilities/training_converter.py +++ b/src/crewai/utilities/training_converter.py @@ -1,42 +1,72 @@ import json import re -from typing import Any, get_origin +from typing import Any, Final, get_origin from pydantic import BaseModel, ValidationError from crewai.utilities.converter import Converter, ConverterError +_FLOAT_PATTERN: Final[re.Pattern[str]] = re.compile(r"(\d+(?:\.\d+)?)") + class TrainingConverter(Converter): - """ - A specialized converter for smaller LLMs (up to 7B parameters) that handles validation errors + """A specialized converter for smaller LLMs (up to 7B parameters) that handles validation errors by breaking down the model into individual fields and querying the LLM for each field separately. """ - def to_pydantic(self, current_attempt=1) -> BaseModel: + def to_pydantic(self, current_attempt: int = 1) -> BaseModel: + """Convert the text to a Pydantic model, with fallback to field-by-field extraction on failure. + + Args: + current_attempt: The current attempt number for conversion. + + Returns: + An instance of the Pydantic model. + + Raises: + ConverterError: If conversion fails after field-by-field extraction. + """ try: return super().to_pydantic(current_attempt) except ConverterError: return self._convert_field_by_field() def _convert_field_by_field(self) -> BaseModel: - field_values = {} + field_values: dict[str, Any] = {} for field_name, field_info in self.model.model_fields.items(): - field_description = field_info.description - field_type = field_info.annotation + field_description: str | None = field_info.description + field_type: type | None = field_info.annotation - response = self._ask_llm_for_field(field_name, field_description) - value = self._process_field_value(response, field_type) + if field_description is None: + raise ValueError(f"Field '{field_name}' has no description") + response: str = self._ask_llm_for_field( + field_name=field_name, field_description=field_description + ) + value: Any = self._process_field_value( + response=response, field_type=field_type + ) field_values[field_name] = value try: return self.model(**field_values) except ValidationError as e: - raise ConverterError(f"Failed to create model from individually collected fields: {e}") + raise ConverterError( + f"Failed to create model from individually collected fields: {e}" + ) from e def _ask_llm_for_field(self, field_name: str, field_description: str) -> str: - prompt = f""" + """Query the LLM for a specific field value based on its description. + + Args: + field_name: The name of the field to extract. + field_description: The description of the field to guide extraction. + + Returns: + The LLM's response containing the field value. + """ + + prompt: str = f""" Based on the following information: {self.text} @@ -45,14 +75,19 @@ Please provide ONLY the {field_name} field value as described: Respond with ONLY the requested information, nothing else. 
""" - return self.llm.call([ - {"role": "system", "content": f"Extract the {field_name} from the previous information."}, - {"role": "user", "content": prompt} - ]) + return self.llm.call( + [ + { + "role": "system", + "content": f"Extract the {field_name} from the previous information.", + }, + {"role": "user", "content": prompt}, + ] + ) - def _process_field_value(self, response: str, field_type: Any) -> Any: + def _process_field_value(self, response: str, field_type: type | None) -> Any: response = response.strip() - origin = get_origin(field_type) + origin: type[Any] | None = get_origin(field_type) if origin is list: return self._parse_list(response) @@ -65,25 +100,45 @@ Respond with ONLY the requested information, nothing else. return response - def _parse_list(self, response: str) -> list: + def _parse_list(self, response: str) -> list[Any]: try: - if response.startswith('['): + if response.startswith("["): return json.loads(response) - items = [item.strip() for item in response.split('\n') if item.strip()] + items: list[str] = [ + item.strip() for item in response.split("\n") if item.strip() + ] return [self._strip_bullet(item) for item in items] except json.JSONDecodeError: return [response] - def _parse_float(self, response: str) -> float: + @staticmethod + def _parse_float(response: str) -> float: + """Parse a float from the response, extracting the first numeric value found. + + Args: + response: The response string from which to extract the float. + + Returns: + The extracted float value, or 0.0 if no valid float is found. + """ try: - match = re.search(r'(\d+(\.\d+)?)', response) + match = _FLOAT_PATTERN.search(response) return float(match.group(1)) if match else 0.0 - except Exception: + except (ValueError, AttributeError): return 0.0 - def _strip_bullet(self, item: str) -> str: - if item.startswith(('- ', '* ')): + @staticmethod + def _strip_bullet(item: str) -> str: + """Remove common bullet point characters from the start of a string. + + Args: + item: The string item to process. + + Returns: + The string without leading bullet characters. + """ + if item.startswith(("- ", "* ")): return item[2:].strip() - return item.strip() \ No newline at end of file + return item.strip() diff --git a/src/crewai/utilities/training_handler.py b/src/crewai/utilities/training_handler.py index 2d34f3261..4bc87d237 100644 --- a/src/crewai/utilities/training_handler.py +++ b/src/crewai/utilities/training_handler.py @@ -1,35 +1,33 @@ import os +from typing import Any from crewai.utilities.file_handler import PickleHandler class CrewTrainingHandler(PickleHandler): - def save_trained_data(self, agent_id: str, trained_data: dict) -> None: - """ - Save the trained data for a specific agent. + def save_trained_data(self, agent_id: str, trained_data: dict[int, Any]) -> None: + """Save the trained data for a specific agent. - Parameters: - - agent_id (str): The ID of the agent. - - trained_data (dict): The trained data to be saved. + Args: + agent_id: The ID of the agent. + trained_data: The trained data to be saved. """ data = self.load() data[agent_id] = trained_data self.save(data) - def append(self, train_iteration: int, agent_id: str, new_data) -> None: - """ - Append new data to the existing pickle file. + def append(self, train_iteration: int, agent_id: str, new_data: Any) -> None: + """Append new training data for a specific agent and iteration. - Parameters: - - new_data (object): The new data to be appended. + Args: + train_iteration: The training iteration number. 
+ agent_id: The ID of the agent. + new_data: The new training data to append. """ data = self.load() - - if agent_id in data: - data[agent_id][train_iteration] = new_data - else: - data[agent_id] = {train_iteration: new_data} - + if agent_id not in data: + data[agent_id] = {} + data[agent_id][train_iteration] = new_data self.save(data) def clear(self) -> None: diff --git a/src/crewai/utilities/types.py b/src/crewai/utilities/types.py new file mode 100644 index 000000000..0cdaa1878 --- /dev/null +++ b/src/crewai/utilities/types.py @@ -0,0 +1,15 @@ +"""Types for CrewAI utilities.""" + +from typing import Literal, TypedDict + + +class LLMMessage(TypedDict): + """Type for formatted LLM messages. + + Notes: + - TODO: Update the LLM.call & BaseLLM.call signatures to use this type + instead of str | list[dict[str, str]] + """ + + role: Literal["user", "assistant", "system"] + content: str diff --git a/tests/agents/test_agent_reasoning.py b/tests/agents/test_agent_reasoning.py index e6c6b77a4..62e6e9f89 100644 --- a/tests/agents/test_agent_reasoning.py +++ b/tests/agents/test_agent_reasoning.py @@ -1,11 +1,11 @@ """Tests for reasoning in agents.""" import json + import pytest from crewai import Agent, Task from crewai.llm import LLM -from crewai.utilities.reasoning_handler import AgentReasoning @pytest.fixture @@ -79,10 +79,8 @@ def test_agent_with_reasoning_not_ready_initially(mock_llm_responses): call_count[0] += 1 if call_count[0] == 1: return mock_llm_responses["not_ready"] - else: - return mock_llm_responses["ready_after_refine"] - else: - return "2x" + return mock_llm_responses["ready_after_refine"] + return "2x" agent.llm.call = mock_llm_call @@ -121,8 +119,7 @@ def test_agent_with_reasoning_max_attempts_reached(): ) or any("refine your plan" in msg.get("content", "") for msg in messages): call_count[0] += 1 return f"Attempt {call_count[0]}: I need more time to think.\n\nNOT READY: I need to refine my plan further." - else: - return "This is an unsolved problem in mathematics." + return "This is an unsolved problem in mathematics." agent.llm.call = mock_llm_call @@ -135,26 +132,6 @@ def test_agent_with_reasoning_max_attempts_reached(): assert "Reasoning Plan:" in task.description -def test_agent_reasoning_input_validation(): - """Test input validation in AgentReasoning.""" - llm = LLM("gpt-3.5-turbo") - - agent = Agent( - role="Test Agent", - goal="To test the reasoning feature", - backstory="I am a test agent created to verify the reasoning feature works correctly.", - llm=llm, - reasoning=True, - ) - - with pytest.raises(ValueError, match="Both task and agent must be provided"): - AgentReasoning(task=None, agent=agent) - - task = Task(description="Simple task", expected_output="Simple output") - with pytest.raises(ValueError, match="Both task and agent must be provided"): - AgentReasoning(task=task, agent=None) - - def test_agent_reasoning_error_handling(): """Test error handling during the reasoning process.""" llm = LLM("gpt-3.5-turbo") @@ -215,8 +192,7 @@ def test_agent_with_function_calling(): return json.dumps( {"plan": "I'll solve this simple math problem: 2+2=4.", "ready": True} ) - else: - return "4" + return "4" agent.llm.call = mock_function_call @@ -251,8 +227,7 @@ def test_agent_with_function_calling_fallback(): def mock_function_call(messages, *args, **kwargs): if "tools" in kwargs: return "Invalid JSON that will trigger fallback. READY: I am ready to execute the task." 
- else: - return "4" + return "4" agent.llm.call = mock_function_call diff --git a/tests/memory/test_short_term_memory.py b/tests/memory/test_short_term_memory.py index 18dc28fa8..b50f6d2fe 100644 --- a/tests/memory/test_short_term_memory.py +++ b/tests/memory/test_short_term_memory.py @@ -39,7 +39,7 @@ def short_term_memory(): def test_short_term_memory_search_events(short_term_memory): events = defaultdict(list) - with patch("crewai.rag.chromadb.client.ChromaDBClient.search", return_value=[]): + with patch.object(short_term_memory.storage, "search", return_value=[]): with crewai_event_bus.scoped_handlers(): @crewai_event_bus.on(MemoryQueryStartedEvent) diff --git a/tests/test_llm.py b/tests/test_llm.py index 0c333af06..065687565 100644 --- a/tests/test_llm.py +++ b/tests/test_llm.py @@ -7,15 +7,14 @@ import pytest from pydantic import BaseModel from crewai.agents.agent_builder.utilities.base_token_process import TokenProcess -from crewai.llm import CONTEXT_WINDOW_USAGE_RATIO, LLM from crewai.events.event_types import ( LLMCallCompletedEvent, LLMStreamChunkEvent, - ToolUsageStartedEvent, - ToolUsageFinishedEvent, ToolUsageErrorEvent, + ToolUsageFinishedEvent, + ToolUsageStartedEvent, ) - +from crewai.llm import CONTEXT_WINDOW_USAGE_RATIO, LLM from crewai.utilities.token_counter_callback import TokenCalcHandler @@ -376,11 +375,11 @@ def get_weather_tool_schema(): def test_context_window_exceeded_error_handling(): - """Test that litellm.ContextWindowExceededError is converted to LLMContextLengthExceededException.""" + """Test that litellm.ContextWindowExceededError is converted to LLMContextLengthExceededError.""" from litellm.exceptions import ContextWindowExceededError from crewai.utilities.exceptions.context_window_exceeding_exception import ( - LLMContextLengthExceededException, + LLMContextLengthExceededError, ) llm = LLM(model="gpt-4") @@ -393,7 +392,7 @@ def test_context_window_exceeded_error_handling(): llm_provider="openai", ) - with pytest.raises(LLMContextLengthExceededException) as excinfo: + with pytest.raises(LLMContextLengthExceededError) as excinfo: llm.call("This is a test message") assert "context length exceeded" in str(excinfo.value).lower() @@ -408,7 +407,7 @@ def test_context_window_exceeded_error_handling(): llm_provider="openai", ) - with pytest.raises(LLMContextLengthExceededException) as excinfo: + with pytest.raises(LLMContextLengthExceededError) as excinfo: llm.call("This is a test message") assert "context length exceeded" in str(excinfo.value).lower()
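Finally, a small sketch of the new LLMMessage TypedDict added in src/crewai/utilities/types.py; per its TODO note the LLM call signatures still accept plain dicts, so this is purely a typing aid.

from crewai.utilities.types import LLMMessage

messages: list[LLMMessage] = [
    {"role": "system", "content": "You are a task planner."},
    {"role": "user", "content": "Draft a plan for the next task."},
]
# These remain ordinary dicts at runtime and can be passed to LLM.call unchanged.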