chore: improve typing and consolidate utilities

- add type annotations across utility modules  
- refactor printer system, agent utils, and imports for consistency  
- remove unused modules, constants, and redundant patterns  
- improve runtime type checks, exception handling, and guardrail validation  
- standardize warning suppression and logging utilities  
- fix llm typing, threading/typing edge cases, and test behavior

Commit 3e97393f58 by Greyson LaLonde, 2025-09-23 11:33:46 -04:00 (committed via GitHub)
Parent commit: 34bed359a6
47 changed files with 1939 additions and 1233 deletions
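
Most of the churn below follows one mechanical pattern: `Optional`, `Union`, `Dict`, `List`, and `Type` from `typing` become PEP 604 unions and built-in generics. A minimal before/after sketch with illustrative names (`call_before`/`call_after` are not from the diff):

```python
from typing import Dict, List, Optional, Union


# Before: typing-module generics, as removed throughout this commit
def call_before(
    messages: Union[str, List[Dict[str, str]]], timeout: Optional[float] = None
) -> Optional[str]:
    ...


# After: PEP 604 unions and built-in generics (Python 3.10+ syntax)
def call_after(
    messages: str | list[dict[str, str]], timeout: float | None = None
) -> str | None:
    ...
```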

View File

@@ -18,7 +18,7 @@ from crewai.agents.constants import (
MISSING_ACTION_INPUT_AFTER_ACTION_ERROR_MESSAGE, MISSING_ACTION_INPUT_AFTER_ACTION_ERROR_MESSAGE,
UNABLE_TO_REPAIR_JSON_RESULTS, UNABLE_TO_REPAIR_JSON_RESULTS,
) )
from crewai.utilities import I18N from crewai.utilities.i18n import I18N
_I18N = I18N() _I18N = I18N()

View File

@@ -1,28 +1,26 @@
import io
import json import json
import logging import logging
import os import os
import sys import sys
import threading import threading
import warnings
from collections import defaultdict from collections import defaultdict
from contextlib import contextmanager from collections.abc import Callable
from datetime import datetime
from typing import ( from typing import (
Any, Any,
DefaultDict, Final,
Dict,
List,
Literal, Literal,
Optional, TextIO,
Type,
TypedDict, TypedDict,
Union,
cast, cast,
) )
from datetime import datetime
from dotenv import load_dotenv from dotenv import load_dotenv
from litellm.types.utils import ChatCompletionDeltaToolCall from litellm.types.utils import ChatCompletionDeltaToolCall
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.llm_events import ( from crewai.events.types.llm_events import (
LLMCallCompletedEvent, LLMCallCompletedEvent,
LLMCallFailedEvent, LLMCallFailedEvent,
@@ -31,15 +29,19 @@ from crewai.events.types.llm_events import (
LLMStreamChunkEvent, LLMStreamChunkEvent,
) )
from crewai.events.types.tool_usage_events import ( from crewai.events.types.tool_usage_events import (
ToolUsageStartedEvent,
ToolUsageFinishedEvent,
ToolUsageErrorEvent, ToolUsageErrorEvent,
ToolUsageFinishedEvent,
ToolUsageStartedEvent,
) )
from crewai.llms.base_llm import BaseLLM
from crewai.utilities.exceptions.context_window_exceeding_exception import (
LLMContextLengthExceededError,
)
from crewai.utilities.logger_utils import suppress_warnings
with warnings.catch_warnings(): with suppress_warnings():
warnings.simplefilter("ignore", UserWarning)
import litellm import litellm
from litellm import Choices from litellm import Choices, CustomLogger
from litellm.exceptions import ContextWindowExceededError from litellm.exceptions import ContextWindowExceededError
from litellm.litellm_core_utils.get_supported_openai_params import ( from litellm.litellm_core_utils.get_supported_openai_params import (
get_supported_openai_params, get_supported_openai_params,
@@ -47,16 +49,6 @@ with warnings.catch_warnings():
from litellm.types.utils import ModelResponse from litellm.types.utils import ModelResponse
from litellm.utils import supports_response_schema from litellm.utils import supports_response_schema
import io
from typing import TextIO
from crewai.llms.base_llm import BaseLLM
from crewai.events.event_bus import crewai_event_bus
from crewai.utilities.exceptions.context_window_exceeding_exception import (
LLMContextLengthExceededException,
)
load_dotenv() load_dotenv()
litellm.suppress_debug_info = True litellm.suppress_debug_info = True
@@ -126,7 +118,11 @@ if not isinstance(sys.stderr, FilteredStream):
sys.stderr = FilteredStream(sys.stderr) sys.stderr = FilteredStream(sys.stderr)
LLM_CONTEXT_WINDOW_SIZES = { MIN_CONTEXT: Final[int] = 1024
MAX_CONTEXT: Final[int] = 2097152 # Current max from gemini-1.5-pro
ANTHROPIC_PREFIXES: Final[tuple[str, str, str]] = ("anthropic/", "claude-", "claude/")
LLM_CONTEXT_WINDOW_SIZES: Final[dict[str, int]] = {
# openai # openai
"gpt-4": 8192, "gpt-4": 8192,
"gpt-4o": 128000, "gpt-4o": 128000,
@@ -252,30 +248,19 @@ LLM_CONTEXT_WINDOW_SIZES = {
"mistral/mistral-large-2402": 32768, "mistral/mistral-large-2402": 32768,
} }
DEFAULT_CONTEXT_WINDOW_SIZE = 8192 DEFAULT_CONTEXT_WINDOW_SIZE: Final[int] = 8192
CONTEXT_WINDOW_USAGE_RATIO = 0.85 CONTEXT_WINDOW_USAGE_RATIO: Final[float] = 0.85
@contextmanager
def suppress_warnings():
with warnings.catch_warnings():
warnings.filterwarnings("ignore")
warnings.filterwarnings(
"ignore", message="open_text is deprecated*", category=DeprecationWarning
)
yield
class Delta(TypedDict): class Delta(TypedDict):
content: Optional[str] content: str | None
role: Optional[str] role: str | None
class StreamingChoices(TypedDict): class StreamingChoices(TypedDict):
delta: Delta delta: Delta
index: int index: int
finish_reason: Optional[str] finish_reason: str | None
class FunctionArgs(BaseModel): class FunctionArgs(BaseModel):
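
The imports above replace the inline `warnings.catch_warnings()` block with `suppress_warnings` from `crewai.utilities.logger_utils`, and this hunk deletes the old module-local context manager. The consolidated helper is not shown in this diff; a minimal sketch matching the removed inline version would be:

```python
import warnings
from collections.abc import Iterator
from contextlib import contextmanager


@contextmanager
def suppress_warnings() -> Iterator[None]:
    """Silence all warnings for the duration of the block (e.g. around `import litellm`)."""
    with warnings.catch_warnings():
        warnings.filterwarnings("ignore")
        warnings.filterwarnings(
            "ignore", message="open_text is deprecated*", category=DeprecationWarning
        )
        yield
```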
@@ -288,31 +273,31 @@ class AccumulatedToolArgs(BaseModel):
class LLM(BaseLLM): class LLM(BaseLLM):
completion_cost: Optional[float] = None completion_cost: float | None = None
def __init__( def __init__(
self, self,
model: str, model: str,
timeout: Optional[Union[float, int]] = None, timeout: float | int | None = None,
temperature: Optional[float] = None, temperature: float | None = None,
top_p: Optional[float] = None, top_p: float | None = None,
n: Optional[int] = None, n: int | None = None,
stop: Optional[Union[str, List[str]]] = None, stop: str | list[str] | None = None,
max_completion_tokens: Optional[int] = None, max_completion_tokens: int | None = None,
max_tokens: Optional[int] = None, max_tokens: int | None = None,
presence_penalty: Optional[float] = None, presence_penalty: float | None = None,
frequency_penalty: Optional[float] = None, frequency_penalty: float | None = None,
logit_bias: Optional[Dict[int, float]] = None, logit_bias: dict[int, float] | None = None,
response_format: Optional[Type[BaseModel]] = None, response_format: type[BaseModel] | None = None,
seed: Optional[int] = None, seed: int | None = None,
logprobs: Optional[int] = None, logprobs: int | None = None,
top_logprobs: Optional[int] = None, top_logprobs: int | None = None,
base_url: Optional[str] = None, base_url: str | None = None,
api_base: Optional[str] = None, api_base: str | None = None,
api_version: Optional[str] = None, api_version: str | None = None,
api_key: Optional[str] = None, api_key: str | None = None,
callbacks: List[Any] | None = None, callbacks: list[Any] | None = None,
reasoning_effort: Optional[Literal["none", "low", "medium", "high"]] = None, reasoning_effort: Literal["none", "low", "medium", "high"] | None = None,
stream: bool = False, stream: bool = False,
**kwargs, **kwargs,
): ):
@@ -345,7 +330,7 @@ class LLM(BaseLLM):
# Normalize self.stop to always be a List[str] # Normalize self.stop to always be a List[str]
if stop is None: if stop is None:
self.stop: List[str] = [] self.stop: list[str] = []
elif isinstance(stop, str): elif isinstance(stop, str):
self.stop = [stop] self.stop = [stop]
else: else:
@@ -354,7 +339,8 @@ class LLM(BaseLLM):
self.set_callbacks(callbacks or []) self.set_callbacks(callbacks or [])
self.set_env_callbacks() self.set_env_callbacks()
def _is_anthropic_model(self, model: str) -> bool: @staticmethod
def _is_anthropic_model(model: str) -> bool:
"""Determine if the model is from Anthropic provider. """Determine if the model is from Anthropic provider.
Args: Args:
@@ -363,21 +349,18 @@ class LLM(BaseLLM):
Returns: Returns:
bool: True if the model is from Anthropic, False otherwise. bool: True if the model is from Anthropic, False otherwise.
""" """
ANTHROPIC_PREFIXES = ("anthropic/", "claude-", "claude/")
return any(prefix in model.lower() for prefix in ANTHROPIC_PREFIXES) return any(prefix in model.lower() for prefix in ANTHROPIC_PREFIXES)
def _prepare_completion_params( def _prepare_completion_params(
self, self,
messages: Union[str, List[Dict[str, str]]], messages: str | list[dict[str, str]],
tools: Optional[List[dict]] = None, tools: list[dict] | None = None,
) -> Dict[str, Any]: ) -> dict[str, Any]:
"""Prepare parameters for the completion call. """Prepare parameters for the completion call.
Args: Args:
messages: Input messages for the LLM messages: Input messages for the LLM
tools: Optional list of tool schemas tools: Optional list of tool schemas
callbacks: Optional list of callback functions
available_functions: Optional dict of available functions
Returns: Returns:
Dict[str, Any]: Parameters for the completion call Dict[str, Any]: Parameters for the completion call
@@ -419,11 +402,11 @@ class LLM(BaseLLM):
def _handle_streaming_response( def _handle_streaming_response(
self, self,
params: Dict[str, Any], params: dict[str, Any],
callbacks: Optional[List[Any]] = None, callbacks: list[Any] | None = None,
available_functions: Optional[Dict[str, Any]] = None, available_functions: dict[str, Any] | None = None,
from_task: Optional[Any] = None, from_task: Any | None = None,
from_agent: Optional[Any] = None, from_agent: Any | None = None,
) -> str: ) -> str:
"""Handle a streaming response from the LLM. """Handle a streaming response from the LLM.
@@ -445,9 +428,8 @@ class LLM(BaseLLM):
last_chunk = None last_chunk = None
chunk_count = 0 chunk_count = 0
usage_info = None usage_info = None
tool_calls = None
accumulated_tool_args: DefaultDict[int, AccumulatedToolArgs] = defaultdict( accumulated_tool_args: defaultdict[int, AccumulatedToolArgs] = defaultdict(
AccumulatedToolArgs AccumulatedToolArgs
) )
@@ -472,16 +454,16 @@ class LLM(BaseLLM):
choices = chunk["choices"] choices = chunk["choices"]
elif hasattr(chunk, "choices"): elif hasattr(chunk, "choices"):
# Check if choices is not a type but an actual attribute with value # Check if choices is not a type but an actual attribute with value
if not isinstance(getattr(chunk, "choices"), type): if not isinstance(chunk.choices, type):
choices = getattr(chunk, "choices") choices = chunk.choices
# Try to extract usage information if available # Try to extract usage information if available
if isinstance(chunk, dict) and "usage" in chunk: if isinstance(chunk, dict) and "usage" in chunk:
usage_info = chunk["usage"] usage_info = chunk["usage"]
elif hasattr(chunk, "usage"): elif hasattr(chunk, "usage"):
# Check if usage is not a type but an actual attribute with value # Check if usage is not a type but an actual attribute with value
if not isinstance(getattr(chunk, "usage"), type): if not isinstance(chunk.usage, type):
usage_info = getattr(chunk, "usage") usage_info = chunk.usage
if choices and len(choices) > 0: if choices and len(choices) > 0:
choice = choices[0] choice = choices[0]
@@ -491,7 +473,7 @@ class LLM(BaseLLM):
if isinstance(choice, dict) and "delta" in choice: if isinstance(choice, dict) and "delta" in choice:
delta = choice["delta"] delta = choice["delta"]
elif hasattr(choice, "delta"): elif hasattr(choice, "delta"):
delta = getattr(choice, "delta") delta = choice.delta
# Extract content from delta # Extract content from delta
if delta: if delta:
@@ -501,7 +483,7 @@ class LLM(BaseLLM):
chunk_content = delta["content"] chunk_content = delta["content"]
# Handle object format # Handle object format
elif hasattr(delta, "content"): elif hasattr(delta, "content"):
chunk_content = getattr(delta, "content") chunk_content = delta.content
# Handle case where content might be None or empty # Handle case where content might be None or empty
if chunk_content is None and isinstance(delta, dict): if chunk_content is None and isinstance(delta, dict):
@@ -533,7 +515,9 @@ class LLM(BaseLLM):
full_response += chunk_content full_response += chunk_content
# Emit the chunk event # Emit the chunk event
assert hasattr(crewai_event_bus, "emit") if not hasattr(crewai_event_bus, "emit"):
raise Exception("crewai_event_bus must have an `emit` method")
crewai_event_bus.emit( crewai_event_bus.emit(
self, self,
event=LLMStreamChunkEvent( event=LLMStreamChunkEvent(
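
Several hunks swap `assert hasattr(crewai_event_bus, "emit")` for an explicit check-and-raise. The motivation is that `assert` statements are stripped entirely under `python -O`, so the guard would silently vanish in optimized runs. A generic sketch of the pattern (the `bus` and `event` names here are illustrative):

```python
def emit_event(bus: object, event: object) -> None:
    # An assert here would be removed by `python -O`; raise explicitly instead.
    if not hasattr(bus, "emit"):
        raise AttributeError("event bus must have an 'emit' method")
    bus.emit(event)  # type: ignore[attr-defined]
```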
@@ -572,8 +556,8 @@ class LLM(BaseLLM):
if isinstance(last_chunk, dict) and "choices" in last_chunk: if isinstance(last_chunk, dict) and "choices" in last_chunk:
choices = last_chunk["choices"] choices = last_chunk["choices"]
elif hasattr(last_chunk, "choices"): elif hasattr(last_chunk, "choices"):
if not isinstance(getattr(last_chunk, "choices"), type): if not isinstance(last_chunk.choices, type):
choices = getattr(last_chunk, "choices") choices = last_chunk.choices
if choices and len(choices) > 0: if choices and len(choices) > 0:
choice = choices[0] choice = choices[0]
@@ -583,14 +567,14 @@ class LLM(BaseLLM):
if isinstance(choice, dict) and "message" in choice: if isinstance(choice, dict) and "message" in choice:
message = choice["message"] message = choice["message"]
elif hasattr(choice, "message"): elif hasattr(choice, "message"):
message = getattr(choice, "message") message = choice.message
if message: if message:
content = None content = None
if isinstance(message, dict) and "content" in message: if isinstance(message, dict) and "content" in message:
content = message["content"] content = message["content"]
elif hasattr(message, "content"): elif hasattr(message, "content"):
content = getattr(message, "content") content = message.content
if content: if content:
full_response = content full_response = content
@@ -617,8 +601,8 @@ class LLM(BaseLLM):
if isinstance(last_chunk, dict) and "choices" in last_chunk: if isinstance(last_chunk, dict) and "choices" in last_chunk:
choices = last_chunk["choices"] choices = last_chunk["choices"]
elif hasattr(last_chunk, "choices"): elif hasattr(last_chunk, "choices"):
if not isinstance(getattr(last_chunk, "choices"), type): if not isinstance(last_chunk.choices, type):
choices = getattr(last_chunk, "choices") choices = last_chunk.choices
if choices and len(choices) > 0: if choices and len(choices) > 0:
choice = choices[0] choice = choices[0]
@@ -627,13 +611,13 @@ class LLM(BaseLLM):
if isinstance(choice, dict) and "message" in choice: if isinstance(choice, dict) and "message" in choice:
message = choice["message"] message = choice["message"]
elif hasattr(choice, "message"): elif hasattr(choice, "message"):
message = getattr(choice, "message") message = choice.message
if message: if message:
if isinstance(message, dict) and "tool_calls" in message: if isinstance(message, dict) and "tool_calls" in message:
tool_calls = message["tool_calls"] tool_calls = message["tool_calls"]
elif hasattr(message, "tool_calls"): elif hasattr(message, "tool_calls"):
tool_calls = getattr(message, "tool_calls") tool_calls = message.tool_calls
except Exception as e: except Exception as e:
logging.debug(f"Error checking for tool calls: {e}") logging.debug(f"Error checking for tool calls: {e}")
# --- 8) If no tool calls or no available functions, return the text response directly # --- 8) If no tool calls or no available functions, return the text response directly
@@ -673,11 +657,11 @@ class LLM(BaseLLM):
# Catch context window errors from litellm and convert them to our own exception type. # Catch context window errors from litellm and convert them to our own exception type.
# This exception is handled by CrewAgentExecutor._invoke_loop() which can then # This exception is handled by CrewAgentExecutor._invoke_loop() which can then
# decide whether to summarize the content or abort based on the respect_context_window flag. # decide whether to summarize the content or abort based on the respect_context_window flag.
raise LLMContextLengthExceededException(str(e)) raise LLMContextLengthExceededError(str(e)) from e
except Exception as e: except Exception as e:
logging.error(f"Error in streaming response: {str(e)}") logging.error(f"Error in streaming response: {e!s}")
if full_response.strip(): if full_response.strip():
logging.warning(f"Returning partial response despite error: {str(e)}") logging.warning(f"Returning partial response despite error: {e!s}")
self._handle_emit_call_events( self._handle_emit_call_events(
response=full_response, response=full_response,
call_type=LLMCallType.LLM_CALL, call_type=LLMCallType.LLM_CALL,
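
This hunk also shows the renamed exception (`LLMContextLengthExceededException` becomes `LLMContextLengthExceededError`) and the switch to `raise ... from e`, which keeps litellm's original error attached as `__cause__` rather than surfacing it as a secondary "during handling of the above exception" traceback. A condensed sketch, assuming crewai and litellm are importable and using an illustrative wrapper function:

```python
from collections.abc import Callable

from litellm.exceptions import ContextWindowExceededError

from crewai.utilities.exceptions.context_window_exceeding_exception import (
    LLMContextLengthExceededError,
)


def call_with_context_guard(completion: Callable[[], str]) -> str:
    """Run a completion call, translating litellm's context error into CrewAI's."""
    try:
        return completion()
    except ContextWindowExceededError as e:
        # `from e` chains the litellm error so the original traceback is preserved
        raise LLMContextLengthExceededError(str(e)) from e
```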
@@ -688,22 +672,25 @@ class LLM(BaseLLM):
return full_response return full_response
# Emit failed event and re-raise the exception # Emit failed event and re-raise the exception
assert hasattr(crewai_event_bus, "emit") if not hasattr(crewai_event_bus, "emit"):
raise AttributeError(
"crewai_event_bus must have an 'emit' method"
) from e
crewai_event_bus.emit( crewai_event_bus.emit(
self, self,
event=LLMCallFailedEvent( event=LLMCallFailedEvent(
error=str(e), from_task=from_task, from_agent=from_agent error=str(e), from_task=from_task, from_agent=from_agent
), ),
) )
raise Exception(f"Failed to get streaming response: {str(e)}") raise Exception(f"Failed to get streaming response: {e!s}") from e
def _handle_streaming_tool_calls( def _handle_streaming_tool_calls(
self, self,
tool_calls: List[ChatCompletionDeltaToolCall], tool_calls: list[ChatCompletionDeltaToolCall],
accumulated_tool_args: DefaultDict[int, AccumulatedToolArgs], accumulated_tool_args: defaultdict[int, AccumulatedToolArgs],
available_functions: Optional[Dict[str, Any]] = None, available_functions: dict[str, Any] | None = None,
from_task: Optional[Any] = None, from_task: Any | None = None,
from_agent: Optional[Any] = None, from_agent: Any | None = None,
) -> None | str: ) -> None | str:
for tool_call in tool_calls: for tool_call in tool_calls:
current_tool_accumulator = accumulated_tool_args[tool_call.index] current_tool_accumulator = accumulated_tool_args[tool_call.index]
@@ -715,7 +702,8 @@ class LLM(BaseLLM):
current_tool_accumulator.function.arguments += ( current_tool_accumulator.function.arguments += (
tool_call.function.arguments tool_call.function.arguments
) )
assert hasattr(crewai_event_bus, "emit") if not hasattr(crewai_event_bus, "emit"):
raise AttributeError("crewai_event_bus must have an 'emit' method")
crewai_event_bus.emit( crewai_event_bus.emit(
self, self,
event=LLMStreamChunkEvent( event=LLMStreamChunkEvent(
@@ -742,11 +730,11 @@ class LLM(BaseLLM):
continue continue
return None return None
@staticmethod
def _handle_streaming_callbacks( def _handle_streaming_callbacks(
self, callbacks: list[Any] | None,
callbacks: Optional[List[Any]], usage_info: dict[str, Any] | None,
usage_info: Optional[Dict[str, Any]], last_chunk: Any | None,
last_chunk: Optional[Any],
) -> None: ) -> None:
"""Handle callbacks with usage info for streaming responses. """Handle callbacks with usage info for streaming responses.
@@ -769,10 +757,8 @@ class LLM(BaseLLM):
): ):
usage_info = last_chunk["usage"] usage_info = last_chunk["usage"]
elif hasattr(last_chunk, "usage"): elif hasattr(last_chunk, "usage"):
if not isinstance( if not isinstance(last_chunk.usage, type):
getattr(last_chunk, "usage"), type usage_info = last_chunk.usage
):
usage_info = getattr(last_chunk, "usage")
except Exception as e: except Exception as e:
logging.debug(f"Error extracting usage info: {e}") logging.debug(f"Error extracting usage info: {e}")
@@ -786,11 +772,11 @@ class LLM(BaseLLM):
def _handle_non_streaming_response( def _handle_non_streaming_response(
self, self,
params: Dict[str, Any], params: dict[str, Any],
callbacks: Optional[List[Any]] = None, callbacks: list[Any] | None = None,
available_functions: Optional[Dict[str, Any]] = None, available_functions: dict[str, Any] | None = None,
from_task: Optional[Any] = None, from_task: Any | None = None,
from_agent: Optional[Any] = None, from_agent: Any | None = None,
) -> str | Any: ) -> str | Any:
"""Handle a non-streaming response from the LLM. """Handle a non-streaming response from the LLM.
@@ -815,7 +801,7 @@ class LLM(BaseLLM):
except ContextWindowExceededError as e: except ContextWindowExceededError as e:
# Convert litellm's context window error to our own exception type # Convert litellm's context window error to our own exception type
# for consistent handling in the rest of the codebase # for consistent handling in the rest of the codebase
raise LLMContextLengthExceededException(str(e)) raise LLMContextLengthExceededError(str(e)) from e
# --- 2) Extract response message and content # --- 2) Extract response message and content
response_message = cast(Choices, cast(ModelResponse, response).choices)[ response_message = cast(Choices, cast(ModelResponse, response).choices)[
0 0
@@ -847,7 +833,7 @@ class LLM(BaseLLM):
) )
return text_response return text_response
# --- 6) If there is no text response, no available functions, but there are tool calls, return the tool calls # --- 6) If there is no text response, no available functions, but there are tool calls, return the tool calls
elif tool_calls and not available_functions and not text_response: if tool_calls and not available_functions and not text_response:
return tool_calls return tool_calls
# --- 7) Handle tool calls if present # --- 7) Handle tool calls if present
@@ -868,19 +854,21 @@ class LLM(BaseLLM):
def _handle_tool_call( def _handle_tool_call(
self, self,
tool_calls: List[Any], tool_calls: list[Any],
available_functions: Optional[Dict[str, Any]] = None, available_functions: dict[str, Any] | None = None,
from_task: Optional[Any] = None, from_task: Any | None = None,
from_agent: Optional[Any] = None, from_agent: Any | None = None,
) -> Optional[str]: ) -> str | None:
"""Handle a tool call from the LLM. """Handle a tool call from the LLM.
Args: Args:
tool_calls: List of tool calls from the LLM tool_calls: List of tool calls from the LLM
available_functions: Dict of available functions available_functions: Dict of available functions
from_task: Optional Task that invoked the LLM
from_agent: Optional Agent that invoked the LLM
Returns: Returns:
Optional[str]: The result of the tool call, or None if no tool call was made The result of the tool call, or None if no tool call was made
""" """
# --- 1) Validate tool calls and available functions # --- 1) Validate tool calls and available functions
if not tool_calls or not available_functions: if not tool_calls or not available_functions:
@@ -899,7 +887,8 @@ class LLM(BaseLLM):
fn = available_functions[function_name] fn = available_functions[function_name]
# --- 3.2) Execute function # --- 3.2) Execute function
assert hasattr(crewai_event_bus, "emit") if not hasattr(crewai_event_bus, "emit"):
raise AttributeError("crewai_event_bus must have an 'emit' method")
started_at = datetime.now() started_at = datetime.now()
crewai_event_bus.emit( crewai_event_bus.emit(
self, self,
@@ -939,17 +928,20 @@ class LLM(BaseLLM):
function_name, lambda: None function_name, lambda: None
) # Ensure fn is always a callable ) # Ensure fn is always a callable
logging.error(f"Error executing function '{function_name}': {e}") logging.error(f"Error executing function '{function_name}': {e}")
assert hasattr(crewai_event_bus, "emit") if not hasattr(crewai_event_bus, "emit"):
raise AttributeError(
"crewai_event_bus must have an 'emit' method"
) from e
crewai_event_bus.emit( crewai_event_bus.emit(
self, self,
event=LLMCallFailedEvent(error=f"Tool execution error: {str(e)}"), event=LLMCallFailedEvent(error=f"Tool execution error: {e!s}"),
) )
crewai_event_bus.emit( crewai_event_bus.emit(
self, self,
event=ToolUsageErrorEvent( event=ToolUsageErrorEvent(
tool_name=function_name, tool_name=function_name,
tool_args=function_args, tool_args=function_args,
error=f"Tool execution error: {str(e)}", error=f"Tool execution error: {e!s}",
from_task=from_task, from_task=from_task,
from_agent=from_agent, from_agent=from_agent,
), ),
@@ -958,13 +950,13 @@ class LLM(BaseLLM):
def call( def call(
self, self,
messages: Union[str, List[Dict[str, str]]], messages: str | list[dict[str, str]],
tools: Optional[List[dict]] = None, tools: list[dict] | None = None,
callbacks: Optional[List[Any]] = None, callbacks: list[Any] | None = None,
available_functions: Optional[Dict[str, Any]] = None, available_functions: dict[str, Any] | None = None,
from_task: Optional[Any] = None, from_task: Any | None = None,
from_agent: Optional[Any] = None, from_agent: Any | None = None,
) -> Union[str, Any]: ) -> str | Any:
"""High-level LLM call method. """High-level LLM call method.
Args: Args:
@@ -988,10 +980,11 @@ class LLM(BaseLLM):
Raises: Raises:
TypeError: If messages format is invalid TypeError: If messages format is invalid
ValueError: If response format is not supported ValueError: If response format is not supported
LLMContextLengthExceededException: If input exceeds model's context limit LLMContextLengthExceededError: If input exceeds model's context limit
""" """
# --- 1) Emit call started event # --- 1) Emit call started event
assert hasattr(crewai_event_bus, "emit") if not hasattr(crewai_event_bus, "emit"):
raise AttributeError("crewai_event_bus must have an 'emit' method")
crewai_event_bus.emit( crewai_event_bus.emit(
self, self,
event=LLMCallStartedEvent( event=LLMCallStartedEvent(
@@ -1028,13 +1021,12 @@ class LLM(BaseLLM):
return self._handle_streaming_response( return self._handle_streaming_response(
params, callbacks, available_functions, from_task, from_agent params, callbacks, available_functions, from_task, from_agent
) )
else: return self._handle_non_streaming_response(
return self._handle_non_streaming_response( params, callbacks, available_functions, from_task, from_agent
params, callbacks, available_functions, from_task, from_agent )
)
except LLMContextLengthExceededException: except LLMContextLengthExceededError:
# Re-raise LLMContextLengthExceededException as it should be handled # Re-raise LLMContextLengthExceededError as it should be handled
# by the CrewAgentExecutor._invoke_loop method, which can then decide # by the CrewAgentExecutor._invoke_loop method, which can then decide
# whether to summarize the content or abort based on the respect_context_window flag # whether to summarize the content or abort based on the respect_context_window flag
raise raise
@@ -1065,7 +1057,10 @@ class LLM(BaseLLM):
from_agent=from_agent, from_agent=from_agent,
) )
assert hasattr(crewai_event_bus, "emit") if not hasattr(crewai_event_bus, "emit"):
raise AttributeError(
"crewai_event_bus must have an 'emit' method"
) from e
crewai_event_bus.emit( crewai_event_bus.emit(
self, self,
event=LLMCallFailedEvent( event=LLMCallFailedEvent(
@@ -1078,8 +1073,8 @@ class LLM(BaseLLM):
self, self,
response: Any, response: Any,
call_type: LLMCallType, call_type: LLMCallType,
from_task: Optional[Any] = None, from_task: Any | None = None,
from_agent: Optional[Any] = None, from_agent: Any | None = None,
messages: str | list[dict[str, Any]] | None = None, messages: str | list[dict[str, Any]] | None = None,
): ):
"""Handle the events for the LLM call. """Handle the events for the LLM call.
@@ -1091,7 +1086,8 @@ class LLM(BaseLLM):
from_agent: Optional agent object from_agent: Optional agent object
messages: Optional messages object messages: Optional messages object
""" """
assert hasattr(crewai_event_bus, "emit") if not hasattr(crewai_event_bus, "emit"):
raise AttributeError("crewai_event_bus must have an 'emit' method")
crewai_event_bus.emit( crewai_event_bus.emit(
self, self,
event=LLMCallCompletedEvent( event=LLMCallCompletedEvent(
@@ -1105,8 +1101,8 @@ class LLM(BaseLLM):
) )
def _format_messages_for_provider( def _format_messages_for_provider(
self, messages: List[Dict[str, str]] self, messages: list[dict[str, str]]
) -> List[Dict[str, str]]: ) -> list[dict[str, str]]:
"""Format messages according to provider requirements. """Format messages according to provider requirements.
Args: Args:
@@ -1147,7 +1143,7 @@ class LLM(BaseLLM):
if "mistral" in self.model.lower(): if "mistral" in self.model.lower():
# Check if the last message has a role of 'assistant' # Check if the last message has a role of 'assistant'
if messages and messages[-1]["role"] == "assistant": if messages and messages[-1]["role"] == "assistant":
return messages + [{"role": "user", "content": "Please continue."}] return [*messages, {"role": "user", "content": "Please continue."}]
return messages return messages
# TODO: Remove this code after merging PR https://github.com/BerriAI/litellm/pull/10917 # TODO: Remove this code after merging PR https://github.com/BerriAI/litellm/pull/10917
@@ -1157,7 +1153,7 @@ class LLM(BaseLLM):
and messages and messages
and messages[-1]["role"] == "assistant" and messages[-1]["role"] == "assistant"
): ):
return messages + [{"role": "user", "content": ""}] return [*messages, {"role": "user", "content": ""}]
# Handle Anthropic models # Handle Anthropic models
if not self.is_anthropic: if not self.is_anthropic:
@@ -1170,7 +1166,7 @@ class LLM(BaseLLM):
return messages return messages
def _get_custom_llm_provider(self) -> Optional[str]: def _get_custom_llm_provider(self) -> str | None:
""" """
Derives the custom_llm_provider from the model string. Derives the custom_llm_provider from the model string.
- For example, if the model is "openrouter/deepseek/deepseek-chat", returns "openrouter". - For example, if the model is "openrouter/deepseek/deepseek-chat", returns "openrouter".
@@ -1207,7 +1203,7 @@ class LLM(BaseLLM):
self.model, custom_llm_provider=provider self.model, custom_llm_provider=provider
) )
except Exception as e: except Exception as e:
logging.error(f"Failed to check function calling support: {str(e)}") logging.error(f"Failed to check function calling support: {e!s}")
return False return False
def supports_stop_words(self) -> bool: def supports_stop_words(self) -> bool:
@@ -1215,7 +1211,7 @@ class LLM(BaseLLM):
params = get_supported_openai_params(model=self.model) params = get_supported_openai_params(model=self.model)
return params is not None and "stop" in params return params is not None and "stop" in params
except Exception as e: except Exception as e:
logging.error(f"Failed to get supported params: {str(e)}") logging.error(f"Failed to get supported params: {e!s}")
return False return False
def get_context_window_size(self) -> int: def get_context_window_size(self) -> int:
@@ -1229,9 +1225,6 @@ class LLM(BaseLLM):
if self.context_window_size != 0: if self.context_window_size != 0:
return self.context_window_size return self.context_window_size
MIN_CONTEXT = 1024
MAX_CONTEXT = 2097152 # Current max from gemini-1.5-pro
# Validate all context window sizes # Validate all context window sizes
for key, value in LLM_CONTEXT_WINDOW_SIZES.items(): for key, value in LLM_CONTEXT_WINDOW_SIZES.items():
if value < MIN_CONTEXT or value > MAX_CONTEXT: if value < MIN_CONTEXT or value > MAX_CONTEXT:
@@ -1247,7 +1240,8 @@ class LLM(BaseLLM):
self.context_window_size = int(value * CONTEXT_WINDOW_USAGE_RATIO) self.context_window_size = int(value * CONTEXT_WINDOW_USAGE_RATIO)
return self.context_window_size return self.context_window_size
def set_callbacks(self, callbacks: List[Any]): @staticmethod
def set_callbacks(callbacks: list[Any]):
""" """
Attempt to keep a single set of callbacks in litellm by removing old Attempt to keep a single set of callbacks in litellm by removing old
duplicates and adding new ones. duplicates and adding new ones.
@@ -1264,9 +1258,9 @@ class LLM(BaseLLM):
litellm.callbacks = callbacks litellm.callbacks = callbacks
def set_env_callbacks(self): @staticmethod
""" def set_env_callbacks() -> None:
Sets the success and failure callbacks for the LiteLLM library from environment variables. """Sets the success and failure callbacks for the LiteLLM library from environment variables.
This method reads the `LITELLM_SUCCESS_CALLBACKS` and `LITELLM_FAILURE_CALLBACKS` This method reads the `LITELLM_SUCCESS_CALLBACKS` and `LITELLM_FAILURE_CALLBACKS`
environment variables, which should contain comma-separated lists of callback names. environment variables, which should contain comma-separated lists of callback names.
@@ -1276,7 +1270,7 @@ class LLM(BaseLLM):
If the environment variables are not set or are empty, the corresponding callback lists If the environment variables are not set or are empty, the corresponding callback lists
will be set to empty lists. will be set to empty lists.
Example: Examples:
LITELLM_SUCCESS_CALLBACKS="langfuse,langsmith" LITELLM_SUCCESS_CALLBACKS="langfuse,langsmith"
LITELLM_FAILURE_CALLBACKS="langfuse" LITELLM_FAILURE_CALLBACKS="langfuse"
@@ -1285,16 +1279,15 @@ class LLM(BaseLLM):
""" """
with suppress_warnings(): with suppress_warnings():
success_callbacks_str = os.environ.get("LITELLM_SUCCESS_CALLBACKS", "") success_callbacks_str = os.environ.get("LITELLM_SUCCESS_CALLBACKS", "")
success_callbacks = [] success_callbacks: list[str | Callable[..., Any] | CustomLogger] = []
if success_callbacks_str: if success_callbacks_str:
success_callbacks = [ success_callbacks = [
cb.strip() for cb in success_callbacks_str.split(",") if cb.strip() cb.strip() for cb in success_callbacks_str.split(",") if cb.strip()
] ]
failure_callbacks_str = os.environ.get("LITELLM_FAILURE_CALLBACKS", "") failure_callbacks_str = os.environ.get("LITELLM_FAILURE_CALLBACKS", "")
failure_callbacks = []
if failure_callbacks_str: if failure_callbacks_str:
failure_callbacks = [ failure_callbacks: list[str | Callable[..., Any] | CustomLogger] = [
cb.strip() for cb in failure_callbacks_str.split(",") if cb.strip() cb.strip() for cb in failure_callbacks_str.split(",") if cb.strip()
] ]

View File

@@ -1,26 +1,24 @@
from .converter import Converter, ConverterError from crewai.utilities.converter import Converter, ConverterError
from .file_handler import FileHandler from crewai.utilities.exceptions.context_window_exceeding_exception import (
from .i18n import I18N LLMContextLengthExceededError,
from .internal_instructor import InternalInstructor
from .logger import Logger
from .parser import YamlParser
from .printer import Printer
from .prompts import Prompts
from .rpm_controller import RPMController
from .exceptions.context_window_exceeding_exception import (
LLMContextLengthExceededException,
) )
from crewai.utilities.file_handler import FileHandler
from crewai.utilities.i18n import I18N
from crewai.utilities.internal_instructor import InternalInstructor
from crewai.utilities.logger import Logger
from crewai.utilities.printer import Printer
from crewai.utilities.prompts import Prompts
from crewai.utilities.rpm_controller import RPMController
__all__ = [ __all__ = [
"I18N",
"Converter", "Converter",
"ConverterError", "ConverterError",
"FileHandler", "FileHandler",
"I18N",
"InternalInstructor", "InternalInstructor",
"LLMContextLengthExceededError",
"Logger", "Logger",
"Printer", "Printer",
"Prompts", "Prompts",
"RPMController", "RPMController",
"YamlParser",
"LLMContextLengthExceededException",
] ]
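
For downstream code, the `__init__` rewrite above means the package-level re-export and the direct module import resolve to the same renamed class, while `YamlParser` and the old `LLMContextLengthExceededException` name are no longer exported:

```python
from crewai.utilities import LLMContextLengthExceededError
from crewai.utilities.exceptions.context_window_exceeding_exception import (
    LLMContextLengthExceededError as DirectImport,
)

assert LLMContextLengthExceededError is DirectImport
```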

View File

@@ -1,7 +1,9 @@
from __future__ import annotations
import json import json
import re import re
from collections.abc import Callable, Sequence from collections.abc import Callable, Sequence
from typing import Any from typing import TYPE_CHECKING, Any, Final, Literal, TypedDict
from rich.console import Console from rich.console import Console
@@ -19,18 +21,47 @@ from crewai.tools import BaseTool as CrewAITool
from crewai.tools.base_tool import BaseTool from crewai.tools.base_tool import BaseTool
from crewai.tools.structured_tool import CrewStructuredTool from crewai.tools.structured_tool import CrewStructuredTool
from crewai.tools.tool_types import ToolResult from crewai.tools.tool_types import ToolResult
from crewai.utilities import I18N, Printer
from crewai.utilities.errors import AgentRepositoryError from crewai.utilities.errors import AgentRepositoryError
from crewai.utilities.exceptions.context_window_exceeding_exception import ( from crewai.utilities.exceptions.context_window_exceeding_exception import (
LLMContextLengthExceededException, LLMContextLengthExceededError,
) )
from crewai.utilities.i18n import I18N
from crewai.utilities.printer import ColoredText, Printer
from crewai.utilities.types import LLMMessage
if TYPE_CHECKING:
from crewai.agent import Agent
from crewai.task import Task
class SummaryContent(TypedDict):
"""Structure for summary content entries.
Attributes:
content: The summarized content.
"""
content: str
console = Console() console = Console()
_MULTIPLE_NEWLINES: Final[re.Pattern[str]] = re.compile(r"\n+")
def parse_tools(tools: list[BaseTool]) -> list[CrewStructuredTool]: def parse_tools(tools: list[BaseTool]) -> list[CrewStructuredTool]:
"""Parse tools to be used for the task.""" """Parse tools to be used for the task.
tools_list = []
Args:
tools: List of tools to parse.
Returns:
List of structured tools.
Raises:
ValueError: If a tool is not a CrewStructuredTool or BaseTool.
"""
tools_list: list[CrewStructuredTool] = []
for tool in tools: for tool in tools:
if isinstance(tool, CrewAITool): if isinstance(tool, CrewAITool):
@@ -42,7 +73,14 @@ def parse_tools(tools: list[BaseTool]) -> list[CrewStructuredTool]:
def get_tool_names(tools: Sequence[CrewStructuredTool | BaseTool]) -> str: def get_tool_names(tools: Sequence[CrewStructuredTool | BaseTool]) -> str:
"""Get the names of the tools.""" """Get the names of the tools.
Args:
tools: List of tools to get names from.
Returns:
Comma-separated string of tool names.
"""
return ", ".join([t.name for t in tools]) return ", ".join([t.name for t in tools])
@@ -51,16 +89,30 @@ def render_text_description_and_args(
) -> str: ) -> str:
"""Render the tool name, description, and args in plain text. """Render the tool name, description, and args in plain text.
search: This tool is used for search, args: {"query": {"type": "string"}} search: This tool is used for search, args: {"query": {"type": "string"}}
calculator: This tool is used for math, \ calculator: This tool is used for math, \
args: {"expression": {"type": "string"}} args: {"expression": {"type": "string"}}
Args:
tools: List of tools to render.
Returns:
Plain text description of tools.
""" """
tool_strings = [tool.description for tool in tools] tool_strings = [tool.description for tool in tools]
return "\n".join(tool_strings) return "\n".join(tool_strings)
def has_reached_max_iterations(iterations: int, max_iterations: int) -> bool: def has_reached_max_iterations(iterations: int, max_iterations: int) -> bool:
"""Check if the maximum number of iterations has been reached.""" """Check if the maximum number of iterations has been reached.
Args:
iterations: Current number of iterations.
max_iterations: Maximum allowed iterations.
Returns:
True if maximum iterations reached, False otherwise.
"""
return iterations >= max_iterations return iterations >= max_iterations
@@ -68,16 +120,19 @@ def handle_max_iterations_exceeded(
formatted_answer: AgentAction | AgentFinish | None, formatted_answer: AgentAction | AgentFinish | None,
printer: Printer, printer: Printer,
i18n: I18N, i18n: I18N,
messages: list[dict[str, str]], messages: list[LLMMessage],
llm: LLM | BaseLLM, llm: LLM | BaseLLM,
callbacks: list[Any], callbacks: list[Callable[..., Any]],
) -> AgentAction | AgentFinish: ) -> AgentAction | AgentFinish:
""" """Handles the case when the maximum number of iterations is exceeded. Performs one more LLM call to get the final answer.
Handles the case when the maximum number of iterations is exceeded.
Performs one more LLM call to get the final answer.
Parameters: Args:
formatted_answer: The last formatted answer from the agent. formatted_answer: The last formatted answer from the agent.
printer: Printer instance for output.
i18n: I18N instance for internationalization.
messages: List of messages to send to the LLM.
llm: The LLM instance to call.
callbacks: List of callbacks for the LLM call.
Returns: Returns:
The final formatted answer after exceeding max iterations. The final formatted answer after exceeding max iterations.
@@ -98,7 +153,7 @@ def handle_max_iterations_exceeded(
# Perform one more LLM call to get the final answer # Perform one more LLM call to get the final answer
answer = llm.call( answer = llm.call(
messages, messages, # type: ignore[arg-type]
callbacks=callbacks, callbacks=callbacks,
) )
@@ -110,20 +165,38 @@ def handle_max_iterations_exceeded(
raise ValueError("Invalid response from LLM call - None or empty.") raise ValueError("Invalid response from LLM call - None or empty.")
# Return the formatted answer, regardless of its type # Return the formatted answer, regardless of its type
return format_answer(answer) return format_answer(answer=answer)
def format_message_for_llm(prompt: str, role: str = "user") -> dict[str, str]: def format_message_for_llm(
prompt: str, role: Literal["user", "assistant", "system"] = "user"
) -> LLMMessage:
"""Format a message for the LLM.
Args:
prompt: The message content.
role: The role of the message sender, either 'user' or 'assistant'.
Returns:
A dictionary with 'role' and 'content' keys.
"""
prompt = prompt.rstrip() prompt = prompt.rstrip()
return {"role": role, "content": prompt} return {"role": role, "content": prompt}
def format_answer(answer: str) -> AgentAction | AgentFinish: def format_answer(answer: str) -> AgentAction | AgentFinish:
"""Format a response from the LLM into an AgentAction or AgentFinish.""" """Format a response from the LLM into an AgentAction or AgentFinish.
Args:
answer: The raw response from the LLM
Returns:
Either an AgentAction or AgentFinish
"""
try: try:
return parse(answer) return parse(answer)
except Exception: except Exception:
# If parsing fails, return a default AgentFinish
return AgentFinish( return AgentFinish(
thought="Failed to parse LLM response", thought="Failed to parse LLM response",
output=answer, output=answer,
@@ -134,23 +207,43 @@ def format_answer(answer: str) -> AgentAction | AgentFinish:
def enforce_rpm_limit( def enforce_rpm_limit(
request_within_rpm_limit: Callable[[], bool] | None = None, request_within_rpm_limit: Callable[[], bool] | None = None,
) -> None: ) -> None:
"""Enforce the requests per minute (RPM) limit if applicable.""" """Enforce the requests per minute (RPM) limit if applicable.
Args:
request_within_rpm_limit: Function to enforce RPM limit.
"""
if request_within_rpm_limit: if request_within_rpm_limit:
request_within_rpm_limit() request_within_rpm_limit()
def get_llm_response( def get_llm_response(
llm: LLM | BaseLLM, llm: LLM | BaseLLM,
messages: list[dict[str, str]], messages: list[LLMMessage],
callbacks: list[Any], callbacks: list[Callable[..., Any]],
printer: Printer, printer: Printer,
from_task: Any | None = None, from_task: Task | None = None,
from_agent: Any | None = None, from_agent: Agent | None = None,
) -> str: ) -> str:
"""Call the LLM and return the response, handling any invalid responses.""" """Call the LLM and return the response, handling any invalid responses.
Args:
llm: The LLM instance to call
messages: The messages to send to the LLM
callbacks: List of callbacks for the LLM call
printer: Printer instance for output
from_task: Optional task context for the LLM call
from_agent: Optional agent context for the LLM call
Returns:
The response from the LLM as a string
Raises:
Exception: If an error occurs.
ValueError: If the response is None or empty.
"""
try: try:
answer = llm.call( answer = llm.call(
messages, messages, # type: ignore[arg-type]
callbacks=callbacks, callbacks=callbacks,
from_task=from_task, from_task=from_task,
from_agent=from_agent, from_agent=from_agent,
@@ -170,7 +263,15 @@ def get_llm_response(
def process_llm_response( def process_llm_response(
answer: str, use_stop_words: bool answer: str, use_stop_words: bool
) -> AgentAction | AgentFinish: ) -> AgentAction | AgentFinish:
"""Process the LLM response and format it into an AgentAction or AgentFinish.""" """Process the LLM response and format it into an AgentAction or AgentFinish.
Args:
answer: The raw response from the LLM
use_stop_words: Whether to use stop words in the LLM call
Returns:
Either an AgentAction or AgentFinish
"""
if not use_stop_words: if not use_stop_words:
try: try:
# Preliminary parsing to check for errors. # Preliminary parsing to check for errors.
@@ -200,6 +301,9 @@ def handle_agent_action_core(
Returns: Returns:
Either an AgentAction or AgentFinish Either an AgentAction or AgentFinish
Notes:
- TODO: Remove messages parameter and its usage.
""" """
if step_callback: if step_callback:
step_callback(tool_result) step_callback(tool_result)
@@ -220,7 +324,7 @@ def handle_agent_action_core(
return formatted_answer return formatted_answer
def handle_unknown_error(printer: Any, exception: Exception) -> None: def handle_unknown_error(printer: Printer, exception: Exception) -> None:
"""Handle unknown errors by informing the user. """Handle unknown errors by informing the user.
Args: Args:
@@ -244,10 +348,10 @@ def handle_unknown_error(printer: Any, exception: Exception) -> None:
def handle_output_parser_exception( def handle_output_parser_exception(
e: OutputParserError, e: OutputParserError,
messages: list[dict[str, str]], messages: list[LLMMessage],
iterations: int, iterations: int,
log_error_after: int = 3, log_error_after: int = 3,
printer: Any | None = None, printer: Printer | None = None,
) -> AgentAction: ) -> AgentAction:
"""Handle OutputParserError by updating messages and formatted_answer. """Handle OutputParserError by updating messages and formatted_answer.
@@ -288,18 +392,18 @@ def is_context_length_exceeded(exception: Exception) -> bool:
Returns: Returns:
bool: True if the exception is due to context length exceeding bool: True if the exception is due to context length exceeding
""" """
return LLMContextLengthExceededException(str(exception))._is_context_limit_error( return LLMContextLengthExceededError(str(exception))._is_context_limit_error(
str(exception) str(exception)
) )
def handle_context_length( def handle_context_length(
respect_context_window: bool, respect_context_window: bool,
printer: Any, printer: Printer,
messages: list[dict[str, str]], messages: list[LLMMessage],
llm: Any, llm: LLM | BaseLLM,
callbacks: list[Any], callbacks: list[Callable[..., Any]],
i18n: Any, i18n: I18N,
) -> None: ) -> None:
"""Handle context length exceeded by either summarizing or raising an error. """Handle context length exceeded by either summarizing or raising an error.
@@ -310,13 +414,16 @@ def handle_context_length(
llm: LLM instance for summarization llm: LLM instance for summarization
callbacks: List of callbacks for LLM callbacks: List of callbacks for LLM
i18n: I18N instance for messages i18n: I18N instance for messages
Raises:
SystemExit: If context length is exceeded and user opts not to summarize
""" """
if respect_context_window: if respect_context_window:
printer.print( printer.print(
content="Context length exceeded. Summarizing content to fit the model context window. Might take a while...", content="Context length exceeded. Summarizing content to fit the model context window. Might take a while...",
color="yellow", color="yellow",
) )
summarize_messages(messages, llm, callbacks, i18n) summarize_messages(messages=messages, llm=llm, callbacks=callbacks, i18n=i18n)
else: else:
printer.print( printer.print(
content="Context length exceeded. Consider using smaller text or RAG tools from crewai_tools.", content="Context length exceeded. Consider using smaller text or RAG tools from crewai_tools.",
@@ -328,10 +435,10 @@ def handle_context_length(
def summarize_messages( def summarize_messages(
messages: list[dict[str, str]], messages: list[LLMMessage],
llm: Any, llm: LLM | BaseLLM,
callbacks: list[Any], callbacks: list[Callable[..., Any]],
i18n: Any, i18n: I18N,
) -> None: ) -> None:
"""Summarize messages to fit within context window. """Summarize messages to fit within context window.
@@ -349,7 +456,7 @@ def summarize_messages(
for i in range(0, len(messages_string), cut_size) for i in range(0, len(messages_string), cut_size)
] ]
summarized_contents = [] summarized_contents: list[SummaryContent] = []
total_groups = len(messages_groups) total_groups = len(messages_groups)
for idx, group in enumerate(messages_groups, 1): for idx, group in enumerate(messages_groups, 1):
@@ -357,15 +464,17 @@ def summarize_messages(
content=f"Summarizing {idx}/{total_groups}...", content=f"Summarizing {idx}/{total_groups}...",
color="yellow", color="yellow",
) )
messages = [
format_message_for_llm(
i18n.slice("summarizer_system_message"), role="system"
),
format_message_for_llm(
i18n.slice("summarize_instruction").format(group=group["content"]),
),
]
summary = llm.call( summary = llm.call(
[ messages, # type: ignore[arg-type]
format_message_for_llm(
i18n.slice("summarizer_system_message"), role="system"
),
format_message_for_llm(
i18n.slice("summarize_instruction").format(group=group["content"]),
),
],
callbacks=callbacks, callbacks=callbacks,
) )
summarized_contents.append({"content": str(summary)}) summarized_contents.append({"content": str(summary)})
@@ -404,20 +513,29 @@ def show_agent_logs(
if formatted_answer is None: if formatted_answer is None:
# Start logs # Start logs
printer.print( printer.print(
content=f"\033[1m\033[95m# Agent:\033[00m \033[1m\033[92m{agent_role}\033[00m" content=[
ColoredText("# Agent: ", "bold_purple"),
ColoredText(agent_role, "bold_green"),
]
) )
if task_description: if task_description:
printer.print( printer.print(
content=f"\033[95m## Task:\033[00m \033[92m{task_description}\033[00m" content=[
ColoredText("## Task: ", "purple"),
ColoredText(task_description, "green"),
]
) )
else: else:
# Execution logs # Execution logs
printer.print( printer.print(
content=f"\n\n\033[1m\033[95m# Agent:\033[00m \033[1m\033[92m{agent_role}\033[00m" content=[
ColoredText("\n\n# Agent: ", "bold_purple"),
ColoredText(agent_role, "bold_green"),
]
) )
if isinstance(formatted_answer, AgentAction): if isinstance(formatted_answer, AgentAction):
thought = re.sub(r"\n+", "\n", formatted_answer.thought) thought = _MULTIPLE_NEWLINES.sub("\n", formatted_answer.thought)
formatted_json = json.dumps( formatted_json = json.dumps(
formatted_answer.tool_input, formatted_answer.tool_input,
indent=2, indent=2,
@@ -425,24 +543,39 @@ def show_agent_logs(
) )
if thought and thought != "": if thought and thought != "":
printer.print( printer.print(
content=f"\033[95m## Thought:\033[00m \033[92m{thought}\033[00m" content=[
ColoredText("## Thought: ", "purple"),
ColoredText(thought, "green"),
]
) )
printer.print( printer.print(
content=f"\033[95m## Using tool:\033[00m \033[92m{formatted_answer.tool}\033[00m" content=[
ColoredText("## Using tool: ", "purple"),
ColoredText(formatted_answer.tool, "green"),
]
) )
printer.print( printer.print(
content=f"\033[95m## Tool Input:\033[00m \033[92m\n{formatted_json}\033[00m" content=[
ColoredText("## Tool Input: ", "purple"),
ColoredText(f"\n{formatted_json}", "green"),
]
) )
printer.print( printer.print(
content=f"\033[95m## Tool Output:\033[00m \033[92m\n{formatted_answer.result}\033[00m" content=[
ColoredText("## Tool Output: ", "purple"),
ColoredText(f"\n{formatted_answer.result}", "green"),
]
) )
elif isinstance(formatted_answer, AgentFinish): elif isinstance(formatted_answer, AgentFinish):
printer.print( printer.print(
content=f"\033[95m## Final Answer:\033[00m \033[92m\n{formatted_answer.output}\033[00m\n\n" content=[
ColoredText("## Final Answer: ", "purple"),
ColoredText(f"\n{formatted_answer.output}\n\n", "green"),
]
) )
def _print_current_organization(): def _print_current_organization() -> None:
settings = Settings() settings = Settings()
if settings.org_uuid: if settings.org_uuid:
console.print( console.print(
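
`show_agent_logs` above replaces hard-coded ANSI escape strings with structured `ColoredText` segments passed to `Printer.print`, keeping color names such as `purple` and `bold_green` in one place. A usage sketch based only on the calls visible in this hunk (the agent role string is illustrative, and it assumes `Printer()` takes no constructor arguments):

```python
from crewai.utilities.printer import ColoredText, Printer

printer = Printer()
printer.print(
    content=[
        ColoredText("# Agent: ", "bold_purple"),
        ColoredText("Research Analyst", "bold_green"),
    ]
)
```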
@@ -457,6 +590,17 @@ def _print_current_organization():
def load_agent_from_repository(from_repository: str) -> dict[str, Any]: def load_agent_from_repository(from_repository: str) -> dict[str, Any]:
"""Load an agent from the repository.
Args:
from_repository: The name of the agent to load.
Returns:
A dictionary of attributes to use for the agent.
Raises:
AgentRepositoryError: If the agent cannot be loaded.
"""
attributes: dict[str, Any] = {} attributes: dict[str, Any] = {}
if from_repository: if from_repository:
import importlib import importlib
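
Throughout this file, message lists are now typed as `LLMMessage` from `crewai.utilities.types`, and `format_message_for_llm` restricts `role` to a `Literal`. The `LLMMessage` definition is not part of this diff; a plausible minimal shape, consistent with how it is used here, would be:

```python
from typing import Literal, TypedDict


class LLMMessage(TypedDict):
    """Assumed shape of crewai.utilities.types.LLMMessage: a chat-style message dict."""

    role: Literal["user", "assistant", "system"]
    content: str


def format_message_for_llm(
    prompt: str, role: Literal["user", "assistant", "system"] = "user"
) -> LLMMessage:
    return {"role": role, "content": prompt.rstrip()}
```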

View File

@@ -1,20 +1,19 @@
from typing import Any, Dict, Type from typing import Any
from pydantic import BaseModel from pydantic import BaseModel
def process_config( def process_config(
values: Dict[str, Any], model_class: Type[BaseModel] values: dict[str, Any], model_class: type[BaseModel]
) -> Dict[str, Any]: ) -> dict[str, Any]:
""" """Process the config dictionary and update the values accordingly.
Process the config dictionary and update the values accordingly.
Args: Args:
values (Dict[str, Any]): The dictionary of values to update. values: The dictionary of values to update.
model_class (Type[BaseModel]): The Pydantic model class to reference for field validation. model_class: The Pydantic model class to reference for field validation.
Returns: Returns:
Dict[str, Any]: The updated values dictionary. The updated values dictionary.
""" """
config = values.get("config", {}) config = values.get("config", {})
if not config: if not config:

View File

@@ -1,19 +1,32 @@
TRAINING_DATA_FILE = "training_data.pkl" from typing import Annotated, Final
TRAINED_AGENTS_DATA_FILE = "trained_agents_data.pkl"
DEFAULT_SCORE_THRESHOLD = 0.35 from crewai.utilities.printer import PrinterColor
KNOWLEDGE_DIRECTORY = "knowledge"
MAX_LLM_RETRY = 3 TRAINING_DATA_FILE: Final[str] = "training_data.pkl"
MAX_FILE_NAME_LENGTH = 255 TRAINED_AGENTS_DATA_FILE: Final[str] = "trained_agents_data.pkl"
EMITTER_COLOR = "bold_blue" KNOWLEDGE_DIRECTORY: Final[str] = "knowledge"
MAX_FILE_NAME_LENGTH: Final[int] = 255
EMITTER_COLOR: Final[PrinterColor] = "bold_blue"
class _NotSpecified: class _NotSpecified:
def __repr__(self): """Sentinel class to detect when no value has been explicitly provided.
Notes:
- TODO: Consider moving this class and NOT_SPECIFIED to types.py
as they are more type-related constructs than business constants.
"""
def __repr__(self) -> str:
return "NOT_SPECIFIED" return "NOT_SPECIFIED"
# Sentinel value used to detect when no value has been explicitly provided. NOT_SPECIFIED: Final[
# Unlike `None`, which might be a valid value from the user, `NOT_SPECIFIED` allows Annotated[
# us to distinguish between "not passed at all" and "explicitly passed None" or "[]". _NotSpecified,
NOT_SPECIFIED = _NotSpecified() "Sentinel value used to detect when no value has been explicitly provided. "
CREWAI_BASE_URL = "https://app.crewai.com" "Unlike `None`, which might be a valid value from the user, `NOT_SPECIFIED` "
"allows us to distinguish between 'not passed at all' and 'explicitly passed None' or '[]'.",
]
] = _NotSpecified()
CREWAI_BASE_URL: Final[str] = "https://app.crewai.com"
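
The constants file now marks everything `Final` and documents the `NOT_SPECIFIED` sentinel through `Annotated` metadata. The sentinel exists to tell "argument never passed" apart from an explicit `None` or `[]`; a small usage sketch (the `configure` function is illustrative, not from the codebase):

```python
class _NotSpecified:
    def __repr__(self) -> str:
        return "NOT_SPECIFIED"


NOT_SPECIFIED = _NotSpecified()


def configure(tools: list[str] | None | _NotSpecified = NOT_SPECIFIED) -> str:
    if tools is NOT_SPECIFIED:
        return "caller did not pass `tools` at all"
    if tools is None:
        return "caller explicitly passed None"
    return f"caller passed {tools!r}"
```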

View File

@@ -1,18 +1,35 @@
from __future__ import annotations
import json import json
import re import re
from typing import Any, Optional, Type, Union, get_args, get_origin from typing import TYPE_CHECKING, Any, Final, TypedDict, Union, get_args, get_origin
from pydantic import BaseModel, ValidationError from pydantic import BaseModel, ValidationError
from typing_extensions import Unpack
from crewai.agents.agent_builder.utilities.base_output_converter import OutputConverter from crewai.agents.agent_builder.utilities.base_output_converter import OutputConverter
from crewai.utilities.internal_instructor import InternalInstructor
from crewai.utilities.printer import Printer from crewai.utilities.printer import Printer
from crewai.utilities.pydantic_schema_parser import PydanticSchemaParser from crewai.utilities.pydantic_schema_parser import PydanticSchemaParser
if TYPE_CHECKING:
from crewai.agent import Agent
from crewai.llm import LLM
from crewai.llms.base_llm import BaseLLM
_JSON_PATTERN: Final[re.Pattern[str]] = re.compile(r"({.*})", re.DOTALL)
class ConverterError(Exception): class ConverterError(Exception):
"""Error raised when Converter fails to parse the input.""" """Error raised when Converter fails to parse the input."""
def __init__(self, message: str, *args: object) -> None: def __init__(self, message: str, *args: object) -> None:
"""Initialize the ConverterError with a message.
Args:
message: The error message.
*args: Additional arguments for the base Exception class.
"""
super().__init__(message, *args) super().__init__(message, *args)
self.message = message self.message = message
@@ -20,8 +37,18 @@ class ConverterError(Exception):
class Converter(OutputConverter): class Converter(OutputConverter):
"""Class that converts text into either pydantic or json.""" """Class that converts text into either pydantic or json."""
def to_pydantic(self, current_attempt=1) -> BaseModel: def to_pydantic(self, current_attempt: int = 1) -> BaseModel:
"""Convert text to pydantic.""" """Convert text to pydantic.
Args:
current_attempt: The current attempt number for conversion retries.
Returns:
A Pydantic BaseModel instance.
Raises:
ConverterError: If conversion fails after maximum attempts.
"""
try: try:
if self.llm.supports_function_calling(): if self.llm.supports_function_calling():
result = self._create_instructor().to_pydantic() result = self._create_instructor().to_pydantic()
@@ -37,104 +64,124 @@ class Converter(OutputConverter):
result = self.model.model_validate_json(response) result = self.model.model_validate_json(response)
except ValidationError: except ValidationError:
# If direct validation fails, attempt to extract valid JSON # If direct validation fails, attempt to extract valid JSON
result = handle_partial_json(response, self.model, False, None) result = handle_partial_json(
result=response,
model=self.model,
is_json_output=False,
agent=None,
)
# Ensure result is a BaseModel instance # Ensure result is a BaseModel instance
if not isinstance(result, BaseModel): if not isinstance(result, BaseModel):
if isinstance(result, dict): if isinstance(result, dict):
result = self.model.parse_obj(result) result = self.model.model_validate(result)
elif isinstance(result, str): elif isinstance(result, str):
try: try:
parsed = json.loads(result) parsed = json.loads(result)
result = self.model.parse_obj(parsed) result = self.model.model_validate(parsed)
except Exception as parse_err: except Exception as parse_err:
raise ConverterError( raise ConverterError(
f"Failed to convert partial JSON result into Pydantic: {parse_err}" f"Failed to convert partial JSON result into Pydantic: {parse_err}"
) ) from parse_err
else: else:
raise ConverterError( raise ConverterError(
"handle_partial_json returned an unexpected type." "handle_partial_json returned an unexpected type."
) ) from None
return result return result
except ValidationError as e: except ValidationError as e:
if current_attempt < self.max_attempts: if current_attempt < self.max_attempts:
return self.to_pydantic(current_attempt + 1) return self.to_pydantic(current_attempt + 1)
raise ConverterError( raise ConverterError(
f"Failed to convert text into a Pydantic model due to validation error: {e}" f"Failed to convert text into a Pydantic model due to validation error: {e}"
) ) from e
except Exception as e: except Exception as e:
if current_attempt < self.max_attempts: if current_attempt < self.max_attempts:
return self.to_pydantic(current_attempt + 1) return self.to_pydantic(current_attempt + 1)
raise ConverterError( raise ConverterError(
f"Failed to convert text into a Pydantic model due to error: {e}" f"Failed to convert text into a Pydantic model due to error: {e}"
) ) from e
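The error handling above re-enters to_pydantic with an incremented attempt counter until max_attempts is exhausted, then re-raises as ConverterError with exception chaining. A stripped-down sketch of that retry pattern; parse and MAX_ATTEMPTS are illustrative stand-ins, not names from this diff:

    MAX_ATTEMPTS = 3

    def convert(text: str, parse, current_attempt: int = 1):
        try:
            return parse(text)
        except ValueError as e:
            if current_attempt < MAX_ATTEMPTS:
                # Retry by recursing with the incremented attempt counter.
                return convert(text, parse, current_attempt + 1)
            raise RuntimeError(f"Failed after {current_attempt} attempts: {e}") from e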
def to_json(self, current_attempt=1): def to_json(self, current_attempt: int = 1) -> str | ConverterError | Any: # type: ignore[override]
"""Convert text to json.""" """Convert text to json.
Args:
current_attempt: The current attempt number for conversion retries.
Returns:
A JSON string or ConverterError if conversion fails.
Raises:
ConverterError: If conversion fails after maximum attempts.
"""
try: try:
if self.llm.supports_function_calling(): if self.llm.supports_function_calling():
return self._create_instructor().to_json() return self._create_instructor().to_json()
else: return json.dumps(
return json.dumps( self.llm.call(
self.llm.call( [
[ {"role": "system", "content": self.instructions},
{"role": "system", "content": self.instructions}, {"role": "user", "content": self.text},
{"role": "user", "content": self.text}, ]
]
)
) )
)
except Exception as e: except Exception as e:
if current_attempt < self.max_attempts: if current_attempt < self.max_attempts:
return self.to_json(current_attempt + 1) return self.to_json(current_attempt + 1)
return ConverterError(f"Failed to convert text into JSON, error: {e}.") return ConverterError(f"Failed to convert text into JSON, error: {e}.")
def _create_instructor(self): def _create_instructor(self) -> InternalInstructor:
"""Create an instructor.""" """Create an instructor."""
from crewai.utilities import InternalInstructor
inst = InternalInstructor( return InternalInstructor(
llm=self.llm, llm=self.llm,
model=self.model, model=self.model,
content=self.text, content=self.text,
) )
return inst
def _convert_with_instructions(self):
"""Create a chain."""
from crewai.utilities.crew_pydantic_output_parser import (
CrewPydanticOutputParser,
)
parser = CrewPydanticOutputParser(pydantic_object=self.model)
result = self.llm.call(
[
{"role": "system", "content": self.instructions},
{"role": "user", "content": self.text},
]
)
return parser.parse_result(result)
def convert_to_model( def convert_to_model(
result: str, result: str,
output_pydantic: Optional[Type[BaseModel]], output_pydantic: type[BaseModel] | None,
output_json: Optional[Type[BaseModel]], output_json: type[BaseModel] | None,
agent: Any, agent: Agent | None = None,
converter_cls: Optional[Type[Converter]] = None, converter_cls: type[Converter] | None = None,
) -> Union[dict, BaseModel, str]: ) -> dict[str, Any] | BaseModel | str:
"""Convert a result string to a Pydantic model or JSON.
Args:
result: The result string to convert.
output_pydantic: The Pydantic model class to convert to.
output_json: The Pydantic model class to convert to JSON.
agent: The agent instance.
converter_cls: The converter class to use.
Returns:
The converted result as a dict, BaseModel, or original string.
"""
model = output_pydantic or output_json model = output_pydantic or output_json
if model is None: if model is None:
return result return result
try: try:
escaped_result = json.dumps(json.loads(result, strict=False)) escaped_result = json.dumps(json.loads(result, strict=False))
return validate_model(escaped_result, model, bool(output_json)) return validate_model(
result=escaped_result, model=model, is_json_output=bool(output_json)
)
except json.JSONDecodeError: except json.JSONDecodeError:
return handle_partial_json( return handle_partial_json(
result, model, bool(output_json), agent, converter_cls result=result,
model=model,
is_json_output=bool(output_json),
agent=agent,
converter_cls=converter_cls,
) )
except ValidationError: except ValidationError:
return handle_partial_json( return handle_partial_json(
result, model, bool(output_json), agent, converter_cls result=result,
model=model,
is_json_output=bool(output_json),
agent=agent,
converter_cls=converter_cls,
) )
except Exception as e: except Exception as e:
@@ -146,8 +193,18 @@ def convert_to_model(
def validate_model( def validate_model(
result: str, model: Type[BaseModel], is_json_output: bool result: str, model: type[BaseModel], is_json_output: bool
) -> Union[dict, BaseModel]: ) -> dict[str, Any] | BaseModel:
"""Validate and convert a JSON string to a Pydantic model or dict.
Args:
result: The JSON string to validate and convert.
model: The Pydantic model class to convert to.
is_json_output: Whether to return a dict (True) or Pydantic model (False).
Returns:
The converted result as a dict or BaseModel.
"""
exported_result = model.model_validate_json(result) exported_result = model.model_validate_json(result)
if is_json_output: if is_json_output:
return exported_result.model_dump() return exported_result.model_dump()
@@ -156,15 +213,27 @@ def validate_model(
def handle_partial_json( def handle_partial_json(
result: str, result: str,
model: Type[BaseModel], model: type[BaseModel],
is_json_output: bool, is_json_output: bool,
agent: Any, agent: Agent | None,
converter_cls: Optional[Type[Converter]] = None, converter_cls: type[Converter] | None = None,
) -> Union[dict, BaseModel, str]: ) -> dict[str, Any] | BaseModel | str:
match = re.search(r"({.*})", result, re.DOTALL) """Handle partial JSON in a result string and convert to Pydantic model or dict.
Args:
result: The result string to process.
model: The Pydantic model class to convert to.
is_json_output: Whether to return a dict (True) or Pydantic model (False).
agent: The agent instance.
converter_cls: The converter class to use.
Returns:
The converted result as a dict, BaseModel, or original string.
"""
match = _JSON_PATTERN.search(result)
if match: if match:
try: try:
exported_result = model.model_validate_json(match.group(0)) exported_result = model.model_validate_json(match.group())
if is_json_output: if is_json_output:
return exported_result.model_dump() return exported_result.model_dump()
return exported_result return exported_result
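handle_partial_json now reuses a precompiled module-level pattern to pull the first {...} block out of noisy LLM text before validating it. A small self-contained sketch of that path; the Person model and the raw string are invented for illustration:

    import re
    from pydantic import BaseModel

    _JSON_PATTERN = re.compile(r"({.*})", re.DOTALL)

    class Person(BaseModel):
        name: str
        age: int

    raw = 'Sure! Here is the data: {"name": "Ada", "age": 36} Hope this helps.'
    match = _JSON_PATTERN.search(raw)
    if match:
        # Validate only the extracted JSON fragment, ignoring the chatter around it.
        person = Person.model_validate_json(match.group())
        print(person)  # name='Ada' age=36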
@@ -179,19 +248,43 @@ def handle_partial_json(
) )
return convert_with_instructions( return convert_with_instructions(
result, model, is_json_output, agent, converter_cls result=result,
model=model,
is_json_output=is_json_output,
agent=agent,
converter_cls=converter_cls,
) )
def convert_with_instructions( def convert_with_instructions(
result: str, result: str,
model: Type[BaseModel], model: type[BaseModel],
is_json_output: bool, is_json_output: bool,
agent: Any, agent: Agent | None,
converter_cls: Optional[Type[Converter]] = None, converter_cls: type[Converter] | None = None,
) -> Union[dict, BaseModel, str]: ) -> dict | BaseModel | str:
"""Convert a result string to a Pydantic model or JSON using instructions.
Args:
result: The result string to convert.
model: The Pydantic model class to convert to.
is_json_output: Whether to return a dict (True) or Pydantic model (False).
agent: The agent instance.
converter_cls: The converter class to use.
Returns:
The converted result as a dict, BaseModel, or original string.
Raises:
TypeError: If neither agent nor converter_cls is provided.
Notes:
- TODO: Fix llm typing issues, return llm should not be able to be str or None.
"""
if agent is None:
raise TypeError("Agent must be provided if converter_cls is not specified.")
llm = agent.function_calling_llm or agent.llm llm = agent.function_calling_llm or agent.llm
instructions = get_conversion_instructions(model, llm) instructions = get_conversion_instructions(model=model, llm=llm)
converter = create_converter( converter = create_converter(
agent=agent, agent=agent,
converter_cls=converter_cls, converter_cls=converter_cls,
@@ -214,9 +307,25 @@ def convert_with_instructions(
return exported_result return exported_result
def get_conversion_instructions(model: Type[BaseModel], llm: Any) -> str: def get_conversion_instructions(
model: type[BaseModel], llm: BaseLLM | LLM | str
) -> str:
"""Generate conversion instructions based on the model and LLM capabilities.
Args:
model: A Pydantic model class.
llm: The language model instance.
Returns:
    A string of instructions for converting the text into valid JSON.
"""
instructions = "Please convert the following text into valid JSON." instructions = "Please convert the following text into valid JSON."
if llm and not isinstance(llm, str) and llm.supports_function_calling(): if (
llm
and not isinstance(llm, str)
and hasattr(llm, "supports_function_calling")
and llm.supports_function_calling()
):
model_schema = PydanticSchemaParser(model=model).get_schema() model_schema = PydanticSchemaParser(model=model).get_schema()
instructions += ( instructions += (
f"\n\nOutput ONLY the valid JSON and nothing else.\n\n" f"\n\nOutput ONLY the valid JSON and nothing else.\n\n"
@@ -231,12 +340,45 @@ def get_conversion_instructions(model: Type[BaseModel], llm: Any) -> str:
return instructions return instructions
class CreateConverterKwargs(TypedDict, total=False):
"""Keyword arguments for creating a converter.
Attributes:
llm: The language model instance.
text: The text to convert.
model: The Pydantic model class.
instructions: The conversion instructions.
"""
llm: BaseLLM | LLM | str
text: str
model: type[BaseModel]
instructions: str
def create_converter( def create_converter(
agent: Optional[Any] = None, agent: Agent | None = None,
converter_cls: Optional[Type[Converter]] = None, converter_cls: type[Converter] | None = None,
*args, *args: Any,
**kwargs, **kwargs: Unpack[CreateConverterKwargs],
) -> Converter: ) -> Converter:
"""Create a converter instance based on the agent or provided class.
Args:
agent: The agent instance.
converter_cls: The converter class to instantiate.
*args: The positional arguments to pass to the converter.
**kwargs: The keyword arguments to pass to the converter.
Returns:
An instance of the specified converter class.
Raises:
ValueError: If neither agent nor converter_cls is provided.
AttributeError: If the agent does not have a 'get_output_converter' method.
Exception: If no converter instance is created.
"""
if agent and not converter_cls: if agent and not converter_cls:
if hasattr(agent, "get_output_converter"): if hasattr(agent, "get_output_converter"):
converter = agent.get_output_converter(*args, **kwargs) converter = agent.get_output_converter(*args, **kwargs)
@@ -253,17 +395,30 @@ def create_converter(
return converter return converter
def generate_model_description(model: Type[BaseModel]) -> str: def generate_model_description(model: type[BaseModel]) -> str:
""" """Generate a string description of a Pydantic model's fields and their types.
Generate a string description of a Pydantic model's fields and their types.
This function takes a Pydantic model class and returns a string that describes This function takes a Pydantic model class and returns a string that describes
the model's fields and their respective types. The description includes handling the model's fields and their respective types. The description includes handling
of complex types such as `Optional`, `List`, and `Dict`, as well as nested Pydantic of complex types such as `Optional`, `List`, and `Dict`, as well as nested Pydantic
models. models.
Args:
model: A Pydantic model class.
Returns:
A string representation of the model's fields and types.
""" """
def describe_field(field_type): def describe_field(field_type: Any) -> str:
"""Recursively describe a field's type.
Args:
field_type: The type of the field to describe.
Returns:
A string representation of the field's type.
"""
origin = get_origin(field_type) origin = get_origin(field_type)
args = get_args(field_type) args = get_args(field_type)
@@ -272,20 +427,18 @@ def generate_model_description(model: Type[BaseModel]) -> str:
non_none_args = [arg for arg in args if arg is not type(None)] non_none_args = [arg for arg in args if arg is not type(None)]
if len(non_none_args) == 1: if len(non_none_args) == 1:
return f"Optional[{describe_field(non_none_args[0])}]" return f"Optional[{describe_field(non_none_args[0])}]"
else: return f"Optional[Union[{', '.join(describe_field(arg) for arg in non_none_args)}]]"
return f"Optional[Union[{', '.join(describe_field(arg) for arg in non_none_args)}]]" if origin is list:
elif origin is list:
return f"List[{describe_field(args[0])}]" return f"List[{describe_field(args[0])}]"
elif origin is dict: if origin is dict:
key_type = describe_field(args[0]) key_type = describe_field(args[0])
value_type = describe_field(args[1]) value_type = describe_field(args[1])
return f"Dict[{key_type}, {value_type}]" return f"Dict[{key_type}, {value_type}]"
elif isinstance(field_type, type) and issubclass(field_type, BaseModel): if isinstance(field_type, type) and issubclass(field_type, BaseModel):
return generate_model_description(field_type) return generate_model_description(field_type)
elif hasattr(field_type, "__name__"): if hasattr(field_type, "__name__"):
return field_type.__name__ return field_type.__name__
else: return str(field_type)
return str(field_type)
fields = model.model_fields fields = model.model_fields
field_descriptions = [ field_descriptions = [
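describe_field builds these type descriptions by recursing over typing.get_origin and get_args. A self-contained sketch of the same idea; the Report model is hypothetical:

    from typing import get_args, get_origin
    from pydantic import BaseModel

    def describe(tp) -> str:
        origin, args = get_origin(tp), get_args(tp)
        if origin is None:
            return getattr(tp, "__name__", str(tp))
        if type(None) in args:  # Optional[...] / X | None
            inner = [a for a in args if a is not type(None)]
            return f"Optional[{describe(inner[0])}]"
        if origin is list:
            return f"List[{describe(args[0])}]"
        if origin is dict:
            return f"Dict[{describe(args[0])}, {describe(args[1])}]"
        return str(tp)

    class Report(BaseModel):
        title: str
        tags: list[str]
        scores: dict[str, float] | None = None

    print({name: describe(field.annotation) for name, field in Report.model_fields.items()})
    # {'title': 'str', 'tags': 'List[str]', 'scores': 'Optional[Dict[str, float]]'}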
View File
@@ -1,16 +1,16 @@
"""Context management utilities for tracking crew and task execution context using OpenTelemetry baggage.""" """Context management utilities for tracking crew and task execution context using OpenTelemetry baggage."""
from typing import Optional from typing import cast
from opentelemetry import baggage from opentelemetry import baggage
from crewai.utilities.crew.models import CrewContext from crewai.utilities.crew.models import CrewContext
def get_crew_context() -> Optional[CrewContext]: def get_crew_context() -> CrewContext | None:
"""Get the current crew context from OpenTelemetry baggage. """Get the current crew context from OpenTelemetry baggage.
Returns: Returns:
CrewContext instance containing crew context information, or None if no context is set CrewContext instance containing crew context information, or None if no context is set
""" """
return baggage.get_baggage("crew_context") return cast(CrewContext | None, baggage.get_baggage("crew_context"))
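For get_crew_context to return anything, some caller must first have attached a CrewContext to the current OpenTelemetry context. A hedged sketch of that choreography; the attach/detach usage is an assumption about the calling code, not shown in this diff:

    from opentelemetry import baggage, context
    from crewai.utilities.crew.models import CrewContext

    # Store the crew context in baggage and make it the current context.
    token = context.attach(
        baggage.set_baggage("crew_context", CrewContext(id="crew-1", key="research"))
    )
    try:
        current = baggage.get_baggage("crew_context")  # CrewContext | None
        print(current)
    finally:
        context.detach(token)  # restore the previous context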
View File
@@ -1,16 +1,17 @@
"""Models for crew-related data structures.""" """Models for crew-related data structures."""
from typing import Optional
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
class CrewContext(BaseModel): class CrewContext(BaseModel):
"""Model representing crew context information.""" """Model representing crew context information.
id: Optional[str] = Field( Attributes:
default=None, description="Unique identifier for the crew" id: Unique identifier for the crew.
) key: Optional crew key/name for identification.
key: Optional[str] = Field( """
id: str | None = Field(default=None, description="Unique identifier for the crew")
key: str | None = Field(
default=None, description="Optional crew key/name for identification" default=None, description="Optional crew key/name for identification"
) )
View File
@@ -4,6 +4,7 @@ import json
from datetime import date, datetime from datetime import date, datetime
from decimal import Decimal from decimal import Decimal
from enum import Enum from enum import Enum
from typing import Any
from uuid import UUID from uuid import UUID
from pydantic import BaseModel from pydantic import BaseModel
@@ -11,18 +12,28 @@ from pydantic import BaseModel
class CrewJSONEncoder(json.JSONEncoder): class CrewJSONEncoder(json.JSONEncoder):
"""Custom JSON encoder for CrewAI objects and special types.""" """Custom JSON encoder for CrewAI objects and special types."""
def default(self, obj):
def default(self, obj: Any) -> Any:
"""Custom serialization for CrewAI specific types.
Args:
obj: The object to serialize.
Returns:
A JSON-serializable representation of the object.
"""
if isinstance(obj, BaseModel): if isinstance(obj, BaseModel):
return self._handle_pydantic_model(obj) return self._handle_pydantic_model(obj)
elif isinstance(obj, UUID) or isinstance(obj, Decimal) or isinstance(obj, Enum): if isinstance(obj, (UUID, Decimal, Enum)):
return str(obj) return str(obj)
elif isinstance(obj, datetime) or isinstance(obj, date): if isinstance(obj, (datetime, date)):
return obj.isoformat() return obj.isoformat()
return super().default(obj) return super().default(obj)
def _handle_pydantic_model(self, obj): @staticmethod
def _handle_pydantic_model(obj: BaseModel) -> str | Any:
try: try:
data = obj.model_dump() data = obj.model_dump()
# Remove circular references # Remove circular references
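A simplified, standalone encoder in the same spirit (not the crewai class itself), showing how json.dumps defers unknown value types to default():

    import json
    from datetime import datetime
    from decimal import Decimal
    from enum import Enum
    from uuid import UUID, uuid4
    from pydantic import BaseModel

    class Status(Enum):
        DONE = "done"

    class Result(BaseModel):
        score: float

    class DemoEncoder(json.JSONEncoder):
        def default(self, obj):
            if isinstance(obj, BaseModel):
                return obj.model_dump()          # pydantic models become plain dicts
            if isinstance(obj, (UUID, Decimal, Enum)):
                return str(obj)                  # stringify non-JSON-native scalars
            if isinstance(obj, datetime):
                return obj.isoformat()           # ISO 8601 timestamps
            return super().default(obj)

    payload = {"id": uuid4(), "when": datetime.now(), "status": Status.DONE,
               "cost": Decimal("1.25"), "result": Result(score=9.5)}
    print(json.dumps(payload, cls=DemoEncoder, indent=2))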
View File
@@ -1,48 +0,0 @@
import json
from typing import Any
import regex
from pydantic import BaseModel, ValidationError
from crewai.agents.parser import OutputParserError
"""Parser for converting text outputs into Pydantic models."""
class CrewPydanticOutputParser:
"""Parses text outputs into specified Pydantic models."""
pydantic_object: type[BaseModel]
def parse_result(self, result: str) -> Any:
result = self._transform_in_valid_json(result)
# Treating edge case of function calling llm returning the name instead of tool_name
json_object = json.loads(result)
if "tool_name" not in json_object:
json_object["tool_name"] = json_object.get("name", "")
result = json.dumps(json_object)
try:
return self.pydantic_object.model_validate(json_object)
except ValidationError as e:
name = self.pydantic_object.__name__
msg = f"Failed to parse {name} from completion {json_object}. Got: {e}"
raise OutputParserError(error=msg) from e
def _transform_in_valid_json(self, text) -> str:
text = text.replace("```", "").replace("json", "")
json_pattern = r"\{(?:[^{}]|(?R))*\}"
matches = regex.finditer(json_pattern, text)
for match in matches:
try:
# Attempt to parse the matched string as JSON
json_obj = json.loads(match.group())
# Return the first successfully parsed JSON object
json_obj = json.dumps(json_obj)
return str(json_obj)
except json.JSONDecodeError: # noqa: PERF203
# If parsing fails, skip to the next match
continue
return text
View File
@@ -8,7 +8,11 @@ from typing import Final
class DatabaseOperationError(Exception): class DatabaseOperationError(Exception):
"""Base exception class for database operation errors.""" """Base exception class for database operation errors.
Attributes:
original_error: The original exception that caused this error, if any.
"""
def __init__(self, message: str, original_error: Exception | None = None) -> None: def __init__(self, message: str, original_error: Exception | None = None) -> None:
"""Initialize the database operation error. """Initialize the database operation error.
View File
@@ -1,4 +1,7 @@
from __future__ import annotations
from collections import defaultdict from collections import defaultdict
from typing import TYPE_CHECKING
from pydantic import BaseModel, Field, InstanceOf from pydantic import BaseModel, Field, InstanceOf
from rich.box import HEAVY_EDGE from rich.box import HEAVY_EDGE
@@ -6,11 +9,14 @@ from rich.console import Console
from rich.table import Table from rich.table import Table
from crewai.agent import Agent from crewai.agent import Agent
from crewai.llm import BaseLLM
from crewai.task import Task
from crewai.tasks.task_output import TaskOutput
from crewai.events.event_bus import crewai_event_bus from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.crew_events import CrewTestResultEvent from crewai.events.types.crew_events import CrewTestResultEvent
from crewai.llms.base_llm import BaseLLM
from crewai.task import Task
from crewai.tasks.task_output import TaskOutput
if TYPE_CHECKING:
from crewai.crew import Crew
class TaskEvaluationPydanticOutput(BaseModel): class TaskEvaluationPydanticOutput(BaseModel):
@@ -20,23 +26,21 @@ class TaskEvaluationPydanticOutput(BaseModel):
class CrewEvaluator: class CrewEvaluator:
""" """A class to evaluate the performance of the agents in the crew based on the tasks they have performed.
A class to evaluate the performance of the agents in the crew based on the tasks they have performed.
Attributes: Attributes:
crew (Crew): The crew of agents to evaluate. crew: The crew of agents to evaluate.
eval_llm (BaseLLM): Language model instance to use for evaluations tasks_scores: A dictionary to store the scores of the agents for each task.
tasks_scores (defaultdict): A dictionary to store the scores of the agents for each task. run_execution_times: A dictionary to store execution times for each run.
iteration (int): The current iteration of the evaluation. iteration: The current iteration of the evaluation.
""" """
tasks_scores: defaultdict = defaultdict(list) def __init__(self, crew: Crew, eval_llm: InstanceOf[BaseLLM]) -> None:
run_execution_times: defaultdict = defaultdict(list)
iteration: int = 0
def __init__(self, crew, eval_llm: InstanceOf[BaseLLM]):
self.crew = crew self.crew = crew
self.llm = eval_llm self.llm = eval_llm
self.tasks_scores: defaultdict[int, list[float]] = defaultdict(list)
self.run_execution_times: defaultdict[int, list[float]] = defaultdict(list)
self.iteration: int = 0
self._setup_for_evaluating() self._setup_for_evaluating()
def _setup_for_evaluating(self) -> None: def _setup_for_evaluating(self) -> None:
@@ -44,7 +48,7 @@ class CrewEvaluator:
for task in self.crew.tasks: for task in self.crew.tasks:
task.callback = self.evaluate task.callback = self.evaluate
def _evaluator_agent(self): def _evaluator_agent(self) -> Agent:
return Agent( return Agent(
role="Task Execution Evaluator", role="Task Execution Evaluator",
goal=( goal=(
@@ -55,8 +59,9 @@ class CrewEvaluator:
llm=self.llm, llm=self.llm,
) )
@staticmethod
def _evaluation_task( def _evaluation_task(
self, evaluator_agent: Agent, task_to_evaluate: Task, task_output: str evaluator_agent: Agent, task_to_evaluate: Task, task_output: str
) -> Task: ) -> Task:
return Task( return Task(
description=( description=(
@@ -73,6 +78,11 @@ class CrewEvaluator:
) )
def set_iteration(self, iteration: int) -> None: def set_iteration(self, iteration: int) -> None:
"""Sets the current iteration of the evaluation.
Args:
iteration: The current iteration number.
"""
self.iteration = iteration self.iteration = iteration
def print_crew_evaluation_result(self) -> None: def print_crew_evaluation_result(self) -> None:
@@ -97,7 +107,8 @@ class CrewEvaluator:
└────────────────────┴───────┴───────┴───────┴────────────┴──────────────────────────────┘ └────────────────────┴───────┴───────┴───────┴────────────┴──────────────────────────────┘
""" """
task_averages = [ task_averages = [
sum(scores) / len(scores) for scores in zip(*self.tasks_scores.values()) sum(scores) / len(scores)
for scores in zip(*self.tasks_scores.values(), strict=False)
] ]
crew_average = sum(task_averages) / len(task_averages) crew_average = sum(task_averages) / len(task_averages)
@@ -158,8 +169,12 @@ class CrewEvaluator:
console.print("\n") console.print("\n")
console.print(table) console.print(table)
def evaluate(self, task_output: TaskOutput): def evaluate(self, task_output: TaskOutput) -> None:
"""Evaluates the performance of the agents in the crew based on the tasks they have performed.""" """Evaluates the performance of the agents in the crew based on the tasks they have performed.
Args:
task_output: The output of the task to evaluate.
"""
current_task = None current_task = None
for task in self.crew.tasks: for task in self.crew.tasks:
if task.description == task_output.description: if task.description == task_output.description:
@@ -179,19 +194,24 @@ class CrewEvaluator:
evaluation_result = evaluation_task.execute_sync() evaluation_result = evaluation_task.execute_sync()
if isinstance(evaluation_result.pydantic, TaskEvaluationPydanticOutput): if isinstance(evaluation_result.pydantic, TaskEvaluationPydanticOutput):
quality_score = evaluation_result.pydantic.quality
if quality_score is None:
raise ValueError("Evaluation quality score cannot be None")
crewai_event_bus.emit( crewai_event_bus.emit(
self.crew, self.crew,
CrewTestResultEvent( CrewTestResultEvent(
quality=evaluation_result.pydantic.quality, quality=quality_score,
execution_duration=current_task.execution_duration, execution_duration=current_task.execution_duration,
model=self.llm.model, model=self.llm.model,
crew_name=self.crew.name, crew_name=self.crew.name,
crew=self.crew, crew=self.crew,
), ),
) )
self.tasks_scores[self.iteration].append(evaluation_result.pydantic.quality) self.tasks_scores[self.iteration].append(quality_score)
self.run_execution_times[self.iteration].append( if current_task.execution_duration is not None:
current_task.execution_duration self.run_execution_times[self.iteration].append(
) current_task.execution_duration
)
else: else:
raise ValueError("Evaluation result is not in the expected format") raise ValueError("Evaluation result is not in the expected format")
View File
@@ -1,35 +1,42 @@
from typing import List from __future__ import annotations
from typing import TYPE_CHECKING, cast
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from crewai.utilities import Converter
from crewai.events.types.task_events import TaskEvaluationEvent
from crewai.events.event_bus import crewai_event_bus from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.task_events import TaskEvaluationEvent
from crewai.llm import LLM
from crewai.utilities.converter import Converter
from crewai.utilities.pydantic_schema_parser import PydanticSchemaParser from crewai.utilities.pydantic_schema_parser import PydanticSchemaParser
from crewai.utilities.training_converter import TrainingConverter from crewai.utilities.training_converter import TrainingConverter
if TYPE_CHECKING:
from crewai.agent import Agent
from crewai.task import Task
class Entity(BaseModel): class Entity(BaseModel):
name: str = Field(description="The name of the entity.") name: str = Field(description="The name of the entity.")
type: str = Field(description="The type of the entity.") type: str = Field(description="The type of the entity.")
description: str = Field(description="Description of the entity.") description: str = Field(description="Description of the entity.")
relationships: List[str] = Field(description="Relationships of the entity.") relationships: list[str] = Field(description="Relationships of the entity.")
class TaskEvaluation(BaseModel): class TaskEvaluation(BaseModel):
suggestions: List[str] = Field( suggestions: list[str] = Field(
description="Suggestions to improve future similar tasks." description="Suggestions to improve future similar tasks."
) )
quality: float = Field( quality: float = Field(
description="A score from 0 to 10 evaluating on completion, quality, and overall performance, all taking into account the task description, expected output, and the result of the task." description="A score from 0 to 10 evaluating on completion, quality, and overall performance, all taking into account the task description, expected output, and the result of the task."
) )
entities: List[Entity] = Field( entities: list[Entity] = Field(
description="Entities extracted from the task output." description="Entities extracted from the task output."
) )
class TrainingTaskEvaluation(BaseModel): class TrainingTaskEvaluation(BaseModel):
suggestions: List[str] = Field( suggestions: list[str] = Field(
description="List of clear, actionable instructions derived from the Human Feedbacks to enhance the Agent's performance. Analyze the differences between Initial Outputs and Improved Outputs to generate specific action items for future tasks. Ensure all key and specific points from the human feedback are incorporated into these instructions." description="List of clear, actionable instructions derived from the Human Feedbacks to enhance the Agent's performance. Analyze the differences between Initial Outputs and Improved Outputs to generate specific action items for future tasks. Ensure all key and specific points from the human feedback are incorporated into these instructions."
) )
quality: float = Field( quality: float = Field(
@@ -41,11 +48,35 @@ class TrainingTaskEvaluation(BaseModel):
class TaskEvaluator: class TaskEvaluator:
def __init__(self, original_agent): """A class to evaluate the performance of an agent based on the tasks they have performed.
self.llm = original_agent.llm
Attributes:
llm: The LLM to use for evaluation.
original_agent: The agent to evaluate.
"""
def __init__(self, original_agent: Agent) -> None:
"""Initializes the TaskEvaluator with the given LLM and agent.
Args:
original_agent: The agent to evaluate.
"""
self.llm = cast(LLM, original_agent.llm)
self.original_agent = original_agent self.original_agent = original_agent
def evaluate(self, task, output) -> TaskEvaluation: def evaluate(self, task: Task, output: str) -> TaskEvaluation:
"""
Args:
task: The task to be evaluated.
output: The output of the task.
Returns:
TaskEvaluation: The evaluation of the task.
Notes:
- Investigate the Converter.to_pydantic signature, returns BaseModel strictly?
"""
crewai_event_bus.emit( crewai_event_bus.emit(
self, TaskEvaluationEvent(evaluation_type="task_evaluation", task=task) self, TaskEvaluationEvent(evaluation_type="task_evaluation", task=task)
) )
@@ -73,7 +104,7 @@ class TaskEvaluator:
instructions=instructions, instructions=instructions,
) )
return converter.to_pydantic() return cast(TaskEvaluation, converter.to_pydantic())
def evaluate_training_data( def evaluate_training_data(
self, training_data: dict, agent_id: str self, training_data: dict, agent_id: str
@@ -81,9 +112,12 @@ class TaskEvaluator:
""" """
Evaluate the training data based on the llm output, human feedback, and improved output. Evaluate the training data based on the llm output, human feedback, and improved output.
Parameters: Args:
- training_data (dict): The training data to be evaluated. - training_data: The training data to be evaluated.
- agent_id (str): The ID of the agent. - agent_id: The ID of the agent.
Notes:
- Investigate the Converter.to_pydantic signature, returns BaseModel strictly?
""" """
crewai_event_bus.emit( crewai_event_bus.emit(
self, TaskEvaluationEvent(evaluation_type="training_data_evaluation") self, TaskEvaluationEvent(evaluation_type="training_data_evaluation")
@@ -142,5 +176,4 @@ class TaskEvaluator:
instructions=instructions, instructions=instructions,
) )
pydantic_result = converter.to_pydantic() return cast(TrainingTaskEvaluation, converter.to_pydantic())
return pydantic_result
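Both methods now lean on typing.cast, which only informs static type checkers; at runtime it returns its argument untouched and performs no validation. A small sketch, with a local stand-in for the TaskEvaluation model:

    from typing import cast
    from pydantic import BaseModel

    class TaskEvaluation(BaseModel):  # stand-in for the model defined above
        quality: float

    def to_pydantic() -> BaseModel:   # broad return type, as on Converter
        return TaskEvaluation(quality=9.0)

    evaluation = cast(TaskEvaluation, to_pydantic())
    print(evaluation.quality)  # 9.0 -- cast did no conversion, only narrowed the type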
View File
@@ -3,9 +3,10 @@
import warnings import warnings
from abc import ABC from abc import ABC
from collections.abc import Callable from collections.abc import Callable
from typing import Any, Type, TypeVar from typing import Any, TypeVar
from typing_extensions import deprecated from typing_extensions import deprecated
import crewai.events as new_events import crewai.events as new_events
from crewai.events.base_events import BaseEvent from crewai.events.base_events import BaseEvent
from crewai.events.event_types import EventTypes from crewai.events.event_types import EventTypes
@@ -17,14 +18,14 @@ warnings.warn(
"Importing from 'crewai.utilities.events' is deprecated and will be removed in v1.0.0. " "Importing from 'crewai.utilities.events' is deprecated and will be removed in v1.0.0. "
"Please use 'crewai.events' instead.", "Please use 'crewai.events' instead.",
DeprecationWarning, DeprecationWarning,
stacklevel=2 stacklevel=2,
) )
@deprecated("Use 'from crewai.events import BaseEventListener' instead") @deprecated("Use 'from crewai.events import BaseEventListener' instead")
class BaseEventListener(new_events.BaseEventListener, ABC): class BaseEventListener(new_events.BaseEventListener, ABC):
"""Deprecated: Use crewai.events.BaseEventListener instead.""" """Deprecated: Use crewai.events.BaseEventListener instead."""
pass
@deprecated("Use 'from crewai.events import crewai_event_bus' instead") @deprecated("Use 'from crewai.events import crewai_event_bus' instead")
class crewai_event_bus: # noqa: N801 class crewai_event_bus: # noqa: N801
@@ -32,7 +33,7 @@ class crewai_event_bus: # noqa: N801
@classmethod @classmethod
def on( def on(
cls, event_type: Type[EventT] cls, event_type: type[EventT]
) -> Callable[[Callable[[Any, EventT], None]], Callable[[Any, EventT], None]]: ) -> Callable[[Callable[[Any, EventT], None]], Callable[[Any, EventT], None]]:
"""Delegate to the actual event bus instance.""" """Delegate to the actual event bus instance."""
return new_events.crewai_event_bus.on(event_type) return new_events.crewai_event_bus.on(event_type)
@@ -44,7 +45,7 @@ class crewai_event_bus: # noqa: N801
@classmethod @classmethod
def register_handler( def register_handler(
cls, event_type: Type[EventTypes], handler: Callable[[Any, EventTypes], None] cls, event_type: type[EventTypes], handler: Callable[[Any, EventTypes], None]
) -> None: ) -> None:
"""Delegate to the actual event bus instance.""" """Delegate to the actual event bus instance."""
return new_events.crewai_event_bus.register_handler(event_type, handler) return new_events.crewai_event_bus.register_handler(event_type, handler)
@@ -54,87 +55,88 @@ class crewai_event_bus: # noqa: N801
"""Delegate to the actual event bus instance.""" """Delegate to the actual event bus instance."""
return new_events.crewai_event_bus.scoped_handlers() return new_events.crewai_event_bus.scoped_handlers()
@deprecated("Use 'from crewai.events import CrewKickoffStartedEvent' instead") @deprecated("Use 'from crewai.events import CrewKickoffStartedEvent' instead")
class CrewKickoffStartedEvent(new_events.CrewKickoffStartedEvent): class CrewKickoffStartedEvent(new_events.CrewKickoffStartedEvent):
"""Deprecated: Use crewai.events.CrewKickoffStartedEvent instead.""" """Deprecated: Use crewai.events.CrewKickoffStartedEvent instead."""
pass
@deprecated("Use 'from crewai.events import CrewKickoffCompletedEvent' instead") @deprecated("Use 'from crewai.events import CrewKickoffCompletedEvent' instead")
class CrewKickoffCompletedEvent(new_events.CrewKickoffCompletedEvent): class CrewKickoffCompletedEvent(new_events.CrewKickoffCompletedEvent):
"""Deprecated: Use crewai.events.CrewKickoffCompletedEvent instead.""" """Deprecated: Use crewai.events.CrewKickoffCompletedEvent instead."""
pass
@deprecated("Use 'from crewai.events import AgentExecutionCompletedEvent' instead") @deprecated("Use 'from crewai.events import AgentExecutionCompletedEvent' instead")
class AgentExecutionCompletedEvent(new_events.AgentExecutionCompletedEvent): class AgentExecutionCompletedEvent(new_events.AgentExecutionCompletedEvent):
"""Deprecated: Use crewai.events.AgentExecutionCompletedEvent instead.""" """Deprecated: Use crewai.events.AgentExecutionCompletedEvent instead."""
pass
@deprecated("Use 'from crewai.events import MemoryQueryCompletedEvent' instead") @deprecated("Use 'from crewai.events import MemoryQueryCompletedEvent' instead")
class MemoryQueryCompletedEvent(new_events.MemoryQueryCompletedEvent): class MemoryQueryCompletedEvent(new_events.MemoryQueryCompletedEvent):
"""Deprecated: Use crewai.events.MemoryQueryCompletedEvent instead.""" """Deprecated: Use crewai.events.MemoryQueryCompletedEvent instead."""
pass
@deprecated("Use 'from crewai.events import MemorySaveCompletedEvent' instead") @deprecated("Use 'from crewai.events import MemorySaveCompletedEvent' instead")
class MemorySaveCompletedEvent(new_events.MemorySaveCompletedEvent): class MemorySaveCompletedEvent(new_events.MemorySaveCompletedEvent):
"""Deprecated: Use crewai.events.MemorySaveCompletedEvent instead.""" """Deprecated: Use crewai.events.MemorySaveCompletedEvent instead."""
pass
@deprecated("Use 'from crewai.events import MemorySaveStartedEvent' instead") @deprecated("Use 'from crewai.events import MemorySaveStartedEvent' instead")
class MemorySaveStartedEvent(new_events.MemorySaveStartedEvent): class MemorySaveStartedEvent(new_events.MemorySaveStartedEvent):
"""Deprecated: Use crewai.events.MemorySaveStartedEvent instead.""" """Deprecated: Use crewai.events.MemorySaveStartedEvent instead."""
pass
@deprecated("Use 'from crewai.events import MemoryQueryStartedEvent' instead") @deprecated("Use 'from crewai.events import MemoryQueryStartedEvent' instead")
class MemoryQueryStartedEvent(new_events.MemoryQueryStartedEvent): class MemoryQueryStartedEvent(new_events.MemoryQueryStartedEvent):
"""Deprecated: Use crewai.events.MemoryQueryStartedEvent instead.""" """Deprecated: Use crewai.events.MemoryQueryStartedEvent instead."""
pass
@deprecated("Use 'from crewai.events import MemoryRetrievalCompletedEvent' instead") @deprecated("Use 'from crewai.events import MemoryRetrievalCompletedEvent' instead")
class MemoryRetrievalCompletedEvent(new_events.MemoryRetrievalCompletedEvent): class MemoryRetrievalCompletedEvent(new_events.MemoryRetrievalCompletedEvent):
"""Deprecated: Use crewai.events.MemoryRetrievalCompletedEvent instead.""" """Deprecated: Use crewai.events.MemoryRetrievalCompletedEvent instead."""
pass
@deprecated("Use 'from crewai.events import MemorySaveFailedEvent' instead") @deprecated("Use 'from crewai.events import MemorySaveFailedEvent' instead")
class MemorySaveFailedEvent(new_events.MemorySaveFailedEvent): class MemorySaveFailedEvent(new_events.MemorySaveFailedEvent):
"""Deprecated: Use crewai.events.MemorySaveFailedEvent instead.""" """Deprecated: Use crewai.events.MemorySaveFailedEvent instead."""
pass
@deprecated("Use 'from crewai.events import MemoryQueryFailedEvent' instead") @deprecated("Use 'from crewai.events import MemoryQueryFailedEvent' instead")
class MemoryQueryFailedEvent(new_events.MemoryQueryFailedEvent): class MemoryQueryFailedEvent(new_events.MemoryQueryFailedEvent):
"""Deprecated: Use crewai.events.MemoryQueryFailedEvent instead.""" """Deprecated: Use crewai.events.MemoryQueryFailedEvent instead."""
pass
@deprecated("Use 'from crewai.events import KnowledgeRetrievalStartedEvent' instead") @deprecated("Use 'from crewai.events import KnowledgeRetrievalStartedEvent' instead")
class KnowledgeRetrievalStartedEvent(new_events.KnowledgeRetrievalStartedEvent): class KnowledgeRetrievalStartedEvent(new_events.KnowledgeRetrievalStartedEvent):
"""Deprecated: Use crewai.events.KnowledgeRetrievalStartedEvent instead.""" """Deprecated: Use crewai.events.KnowledgeRetrievalStartedEvent instead."""
pass
@deprecated("Use 'from crewai.events import KnowledgeRetrievalCompletedEvent' instead") @deprecated("Use 'from crewai.events import KnowledgeRetrievalCompletedEvent' instead")
class KnowledgeRetrievalCompletedEvent(new_events.KnowledgeRetrievalCompletedEvent): class KnowledgeRetrievalCompletedEvent(new_events.KnowledgeRetrievalCompletedEvent):
"""Deprecated: Use crewai.events.KnowledgeRetrievalCompletedEvent instead.""" """Deprecated: Use crewai.events.KnowledgeRetrievalCompletedEvent instead."""
pass
@deprecated("Use 'from crewai.events import LLMStreamChunkEvent' instead") @deprecated("Use 'from crewai.events import LLMStreamChunkEvent' instead")
class LLMStreamChunkEvent(new_events.LLMStreamChunkEvent): class LLMStreamChunkEvent(new_events.LLMStreamChunkEvent):
"""Deprecated: Use crewai.events.LLMStreamChunkEvent instead.""" """Deprecated: Use crewai.events.LLMStreamChunkEvent instead."""
pass
__all__ = [ __all__ = [
'BaseEventListener', "AgentExecutionCompletedEvent",
'crewai_event_bus', "BaseEventListener",
'CrewKickoffStartedEvent', "CrewKickoffCompletedEvent",
'CrewKickoffCompletedEvent', "CrewKickoffStartedEvent",
'AgentExecutionCompletedEvent', "KnowledgeRetrievalCompletedEvent",
'MemoryQueryCompletedEvent', "KnowledgeRetrievalStartedEvent",
'MemorySaveCompletedEvent', "LLMStreamChunkEvent",
'MemorySaveStartedEvent', "MemoryQueryCompletedEvent",
'MemoryQueryStartedEvent', "MemoryQueryFailedEvent",
'MemoryRetrievalCompletedEvent', "MemoryQueryStartedEvent",
'MemorySaveFailedEvent', "MemoryRetrievalCompletedEvent",
'MemoryQueryFailedEvent', "MemorySaveCompletedEvent",
'KnowledgeRetrievalStartedEvent', "MemorySaveFailedEvent",
'KnowledgeRetrievalCompletedEvent', "MemorySaveStartedEvent",
'LLMStreamChunkEvent', "crewai_event_bus",
] ]
__deprecated__ = "Use 'crewai.events' instead of 'crewai.utilities.events'" __deprecated__ = "Use 'crewai.events' instead of 'crewai.utilities.events'"
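The @deprecated decorator used throughout this shim comes from typing_extensions; besides flagging the symbol for type checkers, recent versions also emit a DeprecationWarning when the decorated class is instantiated. A sketch with an invented class name:

    import warnings
    from typing_extensions import deprecated

    @deprecated("Use 'from crewai.events import LLMStreamChunkEvent' instead")
    class OldLLMStreamChunkEvent:
        pass

    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        OldLLMStreamChunkEvent()  # instantiation triggers the warning

    print(caught[0].category.__name__, "-", caught[0].message)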
View File
@@ -1,6 +1,7 @@
"""Backwards compatibility stub for crewai.utilities.events.base_event_listener.""" """Backwards compatibility stub for crewai.utilities.events.base_event_listener."""
import warnings import warnings
from crewai.events import BaseEventListener from crewai.events import BaseEventListener
warnings.warn( warnings.warn(
View File
@@ -1,6 +1,7 @@
"""Backwards compatibility stub for crewai.utilities.events.crewai_event_bus.""" """Backwards compatibility stub for crewai.utilities.events.crewai_event_bus."""
import warnings import warnings
from crewai.events import crewai_event_bus from crewai.events import crewai_event_bus
warnings.warn( warnings.warn(
View File
@@ -1,26 +1,57 @@
class LLMContextLengthExceededException(Exception):
    CONTEXT_LIMIT_ERRORS = [
        "expected a string with maximum length",
        "maximum context length",
        "context length exceeded",
        "context_length_exceeded",
        "context window full",
        "too many tokens",
        "input is too long",
        "exceeds token limit",
    ]

    def __init__(self, error_message: str):

from typing import Final

CONTEXT_LIMIT_ERRORS: Final[list[str]] = [
    "expected a string with maximum length",
    "maximum context length",
    "context length exceeded",
    "context_length_exceeded",
    "context window full",
    "too many tokens",
    "input is too long",
    "exceeds token limit",
]
class LLMContextLengthExceededError(Exception):
"""Exception raised when the context length of a language model is exceeded.
Attributes:
original_error_message: The original error message from the LLM.
"""
def __init__(self, error_message: str) -> None:
"""Initialize the exception with the original error message.
Args:
error_message: The original error message from the LLM.
"""
self.original_error_message = error_message self.original_error_message = error_message
super().__init__(self._get_error_message(error_message)) super().__init__(self._get_error_message(error_message))
def _is_context_limit_error(self, error_message: str) -> bool: @staticmethod
def _is_context_limit_error(error_message: str) -> bool:
"""Check if the error message indicates a context length limit error.
Args:
error_message: The error message to check.
Returns:
True if the error message indicates a context length limit error, False otherwise.
"""
return any( return any(
phrase.lower() in error_message.lower() phrase.lower() in error_message.lower() for phrase in CONTEXT_LIMIT_ERRORS
for phrase in self.CONTEXT_LIMIT_ERRORS
) )
def _get_error_message(self, error_message: str): @staticmethod
def _get_error_message(error_message: str) -> str:
"""Generate a user-friendly error message based on the original error message.
Args:
error_message: The original error message from the LLM.
Returns:
A user-friendly error message.
"""
return ( return (
f"LLM context length exceeded. Original error: {error_message}\n" f"LLM context length exceeded. Original error: {error_message}\n"
"Consider using a smaller input or implementing a text splitting strategy." "Consider using a smaller input or implementing a text splitting strategy."
View File
@@ -2,71 +2,140 @@ import json
import os import os
import pickle import pickle
from datetime import datetime from datetime import datetime
from typing import Union from typing import Any, TypedDict
from typing_extensions import Unpack
class LogEntry(TypedDict, total=False):
"""TypedDict for log entry kwargs with optional fields for flexibility."""
task_name: str
task: str
agent: str
status: str
output: str
input: str
message: str
level: str
crew: str
flow: str
tool: str
error: str
duration: float
metadata: dict[str, Any]
class FileHandler: class FileHandler:
"""Handler for file operations supporting both JSON and text-based logging. """Handler for file operations supporting both JSON and text-based logging.
Args: Attributes:
file_path (Union[bool, str]): Path to the log file or boolean flag _path: The path to the log file.
""" """
def __init__(self, file_path: Union[bool, str]): def __init__(self, file_path: bool | str) -> None:
"""Initialize the FileHandler with the specified file path.
Args:
file_path: Path to the log file or boolean flag.
"""
self._initialize_path(file_path) self._initialize_path(file_path)
def _initialize_path(self, file_path: Union[bool, str]): def _initialize_path(self, file_path: bool | str) -> None:
"""Initialize the file path based on the input type.
Args:
file_path: Path to the log file or boolean flag.
Raises:
ValueError: If file_path is neither a string nor a boolean.
"""
if file_path is True: # File path is boolean True if file_path is True: # File path is boolean True
self._path = os.path.join(os.curdir, "logs.txt") self._path = os.path.join(os.curdir, "logs.txt")
elif isinstance(file_path, str): # File path is a string elif isinstance(file_path, str): # File path is a string
if file_path.endswith((".json", ".txt")): if file_path.endswith((".json", ".txt")):
self._path = file_path # No modification if the file ends with .json or .txt self._path = (
file_path # No modification if the file ends with .json or .txt
)
else: else:
self._path = file_path + ".txt" # Append .txt if the file doesn't end with .json or .txt self._path = (
file_path + ".txt"
) # Append .txt if the file doesn't end with .json or .txt
else: else:
raise ValueError("file_path must be a string or boolean.") # Handle the case where file_path isn't valid raise ValueError(
"file_path must be a string or boolean."
) # Handle the case where file_path isn't valid
def log(self, **kwargs): def log(self, **kwargs: Unpack[LogEntry]) -> None:
"""Log data with structured fields.
Keyword Args:
task_name: Name of the task.
task: Description of the task.
agent: Name of the agent.
status: Status of the operation.
output: Output data.
input: Input data.
message: Log message.
level: Log level (e.g., INFO, ERROR).
crew: Name of the crew.
flow: Name of the flow.
tool: Name of the tool used.
error: Error message if any.
duration: Duration of the operation in seconds.
metadata: Additional metadata as a dictionary.
Raises:
ValueError: If logging fails.
"""
try: try:
now = datetime.now().strftime("%Y-%m-%d %H:%M:%S") now = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
log_entry = {"timestamp": now, **kwargs} log_entry = {"timestamp": now, **kwargs}
if self._path.endswith(".json"): if self._path.endswith(".json"):
# Append log in JSON format # Append log in JSON format
with open(self._path, "a", encoding="utf-8") as file: try:
# If the file is empty, start with a list; else, append to it # Try reading existing content to avoid overwriting
try: with open(self._path, encoding="utf-8") as read_file:
# Try reading existing content to avoid overwriting existing_data = json.load(read_file)
with open(self._path, "r", encoding="utf-8") as read_file: existing_data.append(log_entry)
existing_data = json.load(read_file) except (json.JSONDecodeError, FileNotFoundError):
existing_data.append(log_entry) # If no valid JSON or file doesn't exist, start with an empty list
except (json.JSONDecodeError, FileNotFoundError): existing_data = [log_entry]
# If no valid JSON or file doesn't exist, start with an empty list
existing_data = [log_entry]
with open(self._path, "w", encoding="utf-8") as write_file: with open(self._path, "w", encoding="utf-8") as write_file:
json.dump(existing_data, write_file, indent=4) json.dump(existing_data, write_file, indent=4)
write_file.write("\n") write_file.write("\n")
else: else:
# Append log in plain text format # Append log in plain text format
message = f"{now}: " + ", ".join([f"{key}=\"{value}\"" for key, value in kwargs.items()]) + "\n" message = (
f"{now}: "
+ ", ".join([f'{key}="{value}"' for key, value in kwargs.items()])
+ "\n"
)
with open(self._path, "a", encoding="utf-8") as file: with open(self._path, "a", encoding="utf-8") as file:
file.write(message) file.write(message)
except Exception as e: except Exception as e:
raise ValueError(f"Failed to log message: {str(e)}") raise ValueError(f"Failed to log message: {e!s}") from e
class PickleHandler: class PickleHandler:
"""Handler for saving and loading data using pickle.
Attributes:
file_path: The path to the pickle file.
"""
def __init__(self, file_name: str) -> None: def __init__(self, file_name: str) -> None:
""" """Initialize the PickleHandler with the name of the file where data will be stored.
Initialize the PickleHandler with the name of the file where data will be stored.
The file will be saved in the current directory. The file will be saved in the current directory.
Parameters: Args:
- file_name (str): The name of the file for saving and loading data. file_name: The name of the file for saving and loading data.
""" """
if not file_name.endswith(".pkl"): if not file_name.endswith(".pkl"):
file_name += ".pkl" file_name += ".pkl"
@@ -74,34 +143,31 @@ class PickleHandler:
self.file_path = os.path.join(os.getcwd(), file_name) self.file_path = os.path.join(os.getcwd(), file_name)
def initialize_file(self) -> None: def initialize_file(self) -> None:
""" """Initialize the file with an empty dictionary and overwrite any existing data."""
Initialize the file with an empty dictionary and overwrite any existing data.
"""
self.save({}) self.save({})
def save(self, data) -> None: def save(self, data: Any) -> None:
""" """
Save the data to the specified file using pickle. Save the data to the specified file using pickle.
Parameters: Args:
- data (object): The data to be saved. data: The data to be saved to the file.
""" """
with open(self.file_path, "wb") as file: with open(self.file_path, "wb") as f:
pickle.dump(data, file) pickle.dump(obj=data, file=f)
def load(self) -> dict: def load(self) -> Any:
""" """Load the data from the specified file using pickle.
Load the data from the specified file using pickle.
Returns: Returns:
- dict: The data loaded from the file. The data loaded from the file.
""" """
if not os.path.exists(self.file_path) or os.path.getsize(self.file_path) == 0: if not os.path.exists(self.file_path) or os.path.getsize(self.file_path) == 0:
return {} # Return an empty dictionary if the file does not exist or is empty return {} # Return an empty dictionary if the file does not exist or is empty
with open(self.file_path, "rb") as file: with open(self.file_path, "rb") as file:
try: try:
return pickle.load(file) # nosec return pickle.load(file) # noqa: S301
except EOFError: except EOFError:
return {} # Return an empty dictionary if the file is empty or corrupted return {} # Return an empty dictionary if the file is empty or corrupted
except Exception: except Exception:
View File
@@ -1,4 +1,7 @@
from typing import TYPE_CHECKING, List, Union from __future__ import annotations
from typing import TYPE_CHECKING, Final
from crewai.utilities.constants import _NotSpecified from crewai.utilities.constants import _NotSpecified
if TYPE_CHECKING: if TYPE_CHECKING:
@@ -6,17 +9,31 @@ if TYPE_CHECKING:
from crewai.tasks.task_output import TaskOutput from crewai.tasks.task_output import TaskOutput
def aggregate_raw_outputs_from_task_outputs(task_outputs: List["TaskOutput"]) -> str: DIVIDERS: Final[str] = "\n\n----------\n\n"
"""Generate string context from the task outputs."""
dividers = "\n\n----------\n\n"
# Join task outputs with dividers
context = dividers.join(output.raw for output in task_outputs)
return context
def aggregate_raw_outputs_from_tasks(tasks: Union[List["Task"],_NotSpecified]) -> str: def aggregate_raw_outputs_from_task_outputs(task_outputs: list[TaskOutput]) -> str:
"""Generate string context from the tasks.""" """Generate string context from the task outputs.
Args:
task_outputs: List of TaskOutput objects.
Returns:
A string containing the aggregated raw outputs from the task outputs.
"""
return DIVIDERS.join(output.raw for output in task_outputs)
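A tiny illustration of the divider-joined context string this helper produces; the outputs are hypothetical TaskOutput.raw values:

    DIVIDERS = "\n\n----------\n\n"

    raw_outputs = ["Market summary: demand is up.", "Top competitors: A, B, C."]
    print(DIVIDERS.join(raw_outputs))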
def aggregate_raw_outputs_from_tasks(tasks: list[Task] | _NotSpecified) -> str:
"""Generate string context from the tasks.
Args:
tasks: List of Task objects or _NotSpecified.
Returns:
A string containing the aggregated raw outputs from the tasks.
"""
task_outputs = ( task_outputs = (
[task.output for task in tasks if task.output is not None] [task.output for task in tasks if task.output is not None]
View File
@@ -1,7 +1,14 @@
from collections.abc import Callable from __future__ import annotations
from typing import Any
from pydantic import BaseModel, field_validator from collections.abc import Callable
from typing import TYPE_CHECKING, Any
from pydantic import BaseModel, Field, field_validator
from typing_extensions import Self
if TYPE_CHECKING:
from crewai.lite_agent import LiteAgentOutput
from crewai.tasks.task_output import TaskOutput
class GuardrailResult(BaseModel): class GuardrailResult(BaseModel):
@@ -12,18 +19,31 @@ class GuardrailResult(BaseModel):
be easily handled by the task execution system. be easily handled by the task execution system.
Attributes: Attributes:
success (bool): Whether the guardrail validation passed success: Whether the guardrail validation passed
result (Any, optional): The validated/transformed result if successful result: The validated/transformed result if successful
error (str, optional): Error message if validation failed error: Error message if validation failed
""" """
success: bool success: bool = Field(description="Whether the guardrail validation passed")
result: Any | None = None result: Any | None = Field(
error: str | None = None default=None, description="The validated/transformed result if successful"
)
error: str | None = Field(
default=None, description="Error message if validation failed"
)
@field_validator("result", "error") @field_validator("result", "error")
@classmethod @classmethod
def validate_result_error_exclusivity(cls, v: Any, info) -> Any: def validate_result_error_exclusivity(cls, v: Any, info) -> Any:
"""Ensure that result and error are mutually exclusive based on success.
Args:
v: The value being validated (either result or error)
info: Validation info containing the entire model data
Returns:
The original value if validation passes
"""
values = info.data values = info.data
if "success" in values: if "success" in values:
if values["success"] and v and "error" in values and values["error"]: if values["success"] and v and "error" in values and values["error"]:
@@ -37,15 +57,14 @@ class GuardrailResult(BaseModel):
return v return v
@classmethod @classmethod
def from_tuple(cls, result: tuple[bool, Any | str]) -> "GuardrailResult": def from_tuple(cls, result: tuple[bool, Any | str]) -> Self:
"""Create a GuardrailResult from a validation tuple. """Create a GuardrailResult from a validation tuple.
Args: Args:
result: A tuple of (success, data) where data is either result: A tuple of (success, data) where data is either the validated result or error message.
the validated result or error message.
Returns: Returns:
GuardrailResult: A new instance with the tuple data. A new instance with the tuple data.
""" """
success, data = result success, data = result
return cls( return cls(
@@ -56,7 +75,10 @@ class GuardrailResult(BaseModel):
def process_guardrail( def process_guardrail(
output: Any, guardrail: Callable, retry_count: int, event_source: Any | None = None output: TaskOutput | LiteAgentOutput,
guardrail: Callable[[Any], tuple[bool, Any | str]],
retry_count: int,
event_source: Any | None = None,
) -> GuardrailResult: ) -> GuardrailResult:
"""Process the guardrail for the agent output. """Process the guardrail for the agent output.
@@ -68,7 +90,19 @@ def process_guardrail(
Returns: Returns:
GuardrailResult: The result of the guardrail validation GuardrailResult: The result of the guardrail validation
Raises:
TypeError: If output is not a TaskOutput or LiteAgentOutput
ValueError: If guardrail is None
""" """
from crewai.lite_agent import LiteAgentOutput
from crewai.tasks.task_output import TaskOutput
if not isinstance(output, (TaskOutput, LiteAgentOutput)):
raise TypeError("Output must be a TaskOutput or LiteAgentOutput")
if guardrail is None:
raise ValueError("Guardrail must not be None")
from crewai.events.event_bus import crewai_event_bus from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.llm_guardrail_events import ( from crewai.events.types.llm_guardrail_events import (
LLMGuardrailCompletedEvent, LLMGuardrailCompletedEvent,
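A minimal usage sketch of the (success, data) contract that process_guardrail validates; the module path is assumed from this commit's layout, and the callable below is purely illustrative:
from crewai.utilities.guardrail import GuardrailResult  # import path assumed

def no_banned_words(output) -> tuple[bool, str]:
    # Example guardrail: accept the output only when no banned word appears.
    text = output if isinstance(output, str) else output.raw
    if "badword" in text.lower():
        return False, "Output contained a banned word."
    return True, text

result = GuardrailResult.from_tuple(no_banned_words("All clear here."))
assert result.success and result.error is None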

View File

@@ -1,36 +1,51 @@
import json
import os
from typing import Dict, Optional, Union
from pydantic import BaseModel, Field, PrivateAttr, model_validator
"""Internationalization support for CrewAI prompts and messages.""" """Internationalization support for CrewAI prompts and messages."""
import json
import os
from typing import Literal
from pydantic import BaseModel, Field, PrivateAttr, model_validator
from typing_extensions import Self
class I18N(BaseModel): class I18N(BaseModel):
"""Handles loading and retrieving internationalized prompts.""" """Handles loading and retrieving internationalized prompts.
_prompts: Dict[str, Dict[str, str]] = PrivateAttr()
prompt_file: Optional[str] = Field( Attributes:
_prompts: Internal dictionary storing loaded prompts.
prompt_file: Optional path to a custom JSON file containing prompts.
"""
_prompts: dict[str, dict[str, str]] = PrivateAttr()
prompt_file: str | None = Field(
default=None, default=None,
description="Path to the prompt_file file to load", description="Path to the prompt_file file to load",
) )
@model_validator(mode="after") @model_validator(mode="after")
def load_prompts(self) -> "I18N": def load_prompts(self) -> Self:
"""Load prompts from a JSON file.""" """Load prompts from a JSON file.
Returns:
The I18N instance with loaded prompts.
Raises:
Exception: If the prompt file is not found or cannot be decoded.
"""
try: try:
if self.prompt_file: if self.prompt_file:
with open(self.prompt_file, "r", encoding="utf-8") as f: with open(self.prompt_file, encoding="utf-8") as f:
self._prompts = json.load(f) self._prompts = json.load(f)
else: else:
dir_path = os.path.dirname(os.path.realpath(__file__)) dir_path = os.path.dirname(os.path.realpath(__file__))
prompts_path = os.path.join(dir_path, "../translations/en.json") prompts_path = os.path.join(dir_path, "../translations/en.json")
with open(prompts_path, "r", encoding="utf-8") as f: with open(prompts_path, encoding="utf-8") as f:
self._prompts = json.load(f) self._prompts = json.load(f)
except FileNotFoundError: except FileNotFoundError as e:
raise Exception(f"Prompt file '{self.prompt_file}' not found.") raise Exception(f"Prompt file '{self.prompt_file}' not found.") from e
except json.JSONDecodeError: except json.JSONDecodeError as e:
raise Exception("Error decoding JSON from the prompts file.") raise Exception("Error decoding JSON from the prompts file.") from e
if not self._prompts: if not self._prompts:
self._prompts = {} self._prompts = {}
@@ -38,16 +53,58 @@ class I18N(BaseModel):
return self return self
def slice(self, slice: str) -> str: def slice(self, slice: str) -> str:
"""Retrieve a prompt slice by key.
Args:
slice: The key of the prompt slice to retrieve.
Returns:
The prompt slice as a string.
"""
return self.retrieve("slices", slice) return self.retrieve("slices", slice)
def errors(self, error: str) -> str: def errors(self, error: str) -> str:
"""Retrieve an error message by key.
Args:
error: The key of the error message to retrieve.
Returns:
The error message as a string.
"""
return self.retrieve("errors", error) return self.retrieve("errors", error)
def tools(self, tool: str) -> Union[str, Dict[str, str]]: def tools(self, tool: str) -> str | dict[str, str]:
"""Retrieve a tool prompt by key.
Args:
tool: The key of the tool prompt to retrieve.
Returns:
The tool prompt as a string or dictionary.
"""
return self.retrieve("tools", tool) return self.retrieve("tools", tool)
def retrieve(self, kind, key) -> str: def retrieve(
self,
kind: Literal[
"slices", "errors", "tools", "reasoning", "hierarchical_manager_agent"
],
key: str,
) -> str:
"""Retrieve a prompt by kind and key.
Args:
kind: The kind of prompt.
key: The key of the specific prompt to retrieve.
Returns:
The prompt as a string.
Raises:
Exception: If the prompt for the given kind and key is not found.
"""
try: try:
return self._prompts[kind][key] return self._prompts[kind][key]
except Exception as _: except Exception as e:
raise Exception(f"Prompt for '{kind}':'{key}' not found.") raise Exception(f"Prompt for '{kind}':'{key}' not found.") from e

View File

@@ -7,8 +7,6 @@ from types import ModuleType
class OptionalDependencyError(ImportError): class OptionalDependencyError(ImportError):
"""Exception raised when an optional dependency is not installed.""" """Exception raised when an optional dependency is not installed."""
pass
def require(name: str, *, purpose: str) -> ModuleType: def require(name: str, *, purpose: str) -> ModuleType:
"""Import a module, raising a helpful error if it's not installed. """Import a module, raising a helpful error if it's not installed.

View File

@@ -1,43 +1,98 @@
import warnings from __future__ import annotations
from typing import Any, Optional, Type
from typing import TYPE_CHECKING, Any, Generic, TypeGuard, TypeVar
from pydantic import BaseModel
if TYPE_CHECKING:
from crewai.agent import Agent
from crewai.llm import LLM
from crewai.llms.base_llm import BaseLLM
from crewai.utilities.logger_utils import suppress_warnings
from crewai.utilities.types import LLMMessage
T = TypeVar("T", bound=BaseModel)
class InternalInstructor: def _is_valid_llm(llm: Any) -> TypeGuard[str | LLM | BaseLLM]:
"""Class that wraps an agent llm with instructor.""" """Type guard to ensure LLM is valid and not None.
Args:
llm: The LLM to validate
Returns:
True if LLM is valid (string or has model attribute), False otherwise
"""
return llm is not None and (isinstance(llm, str) or hasattr(llm, "model"))
class InternalInstructor(Generic[T]):
"""Class that wraps an agent LLM with instructor for structured output generation.
Attributes:
content: The content to be processed
model: The Pydantic model class for the response
agent: The agent with LLM
llm: The LLM instance or model name
"""
def __init__( def __init__(
self, self,
content: str, content: str,
model: Type, model: type[T],
agent: Optional[Any] = None, agent: Agent | None = None,
llm: Optional[str] = None, llm: LLM | BaseLLM | str | None = None,
): ) -> None:
"""Initialize InternalInstructor.
Args:
content: The content to be processed
model: The Pydantic model class for the response
agent: The agent with LLM
llm: The LLM instance or model name
"""
self.content = content self.content = content
self.agent = agent self.agent = agent
self.llm = llm
self.model = model self.model = model
self._client = None self.llm = llm or (agent.function_calling_llm or agent.llm if agent else None)
self.set_instructor()
def set_instructor(self): with suppress_warnings():
"""Set instructor."""
if self.agent and not self.llm:
self.llm = self.agent.function_calling_llm or self.agent.llm
with warnings.catch_warnings():
warnings.simplefilter("ignore", UserWarning)
import instructor import instructor
from litellm import completion from litellm import completion
self._client = instructor.from_litellm(completion) self._client = instructor.from_litellm(completion)
def to_json(self): def to_json(self) -> str:
model = self.to_pydantic() """Convert the structured output to JSON format.
return model.model_dump_json(indent=2)
def to_pydantic(self): Returns:
messages = [{"role": "user", "content": self.content}] JSON string representation of the structured output
model = self._client.chat.completions.create( """
model=self.llm.model, response_model=self.model, messages=messages pydantic_model = self.to_pydantic()
return pydantic_model.model_dump_json(indent=2)
def to_pydantic(self) -> T:
"""Generate structured output using the specified Pydantic model.
Returns:
Instance of the specified Pydantic model with structured data
Raises:
ValueError: If LLM is not provided or invalid
"""
messages: list[LLMMessage] = [{"role": "user", "content": self.content}]
if not _is_valid_llm(self.llm):
raise ValueError(
"LLM must be provided and have a model attribute or be a string"
)
if isinstance(self.llm, str):
model_name = self.llm
else:
model_name = self.llm.model
return self._client.chat.completions.create(
model=model_name, response_model=self.model, messages=messages
) )
return model
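A hedged sketch of the generic wrapper; it assumes the optional instructor dependency is installed and that litellm can reach the named model with valid credentials at runtime:
from pydantic import BaseModel
from crewai.utilities.internal_instructor import InternalInstructor  # path assumed

class Person(BaseModel):
    name: str
    age: int

extractor = InternalInstructor(
    content="Extract: Alice is 31 years old.",
    model=Person,
    llm="gpt-4o-mini",  # a plain model-name string is accepted
)
person = extractor.to_pydantic()  # performs a real completion and returns a Person instance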

View File

@@ -1,62 +1,55 @@
import logging
import os import os
from typing import Any, Dict, List, Optional, Union from typing import Any, Final
from crewai.cli.constants import DEFAULT_LLM_MODEL, ENV_VARS, LITELLM_PARAMS from crewai.cli.constants import DEFAULT_LLM_MODEL, ENV_VARS, LITELLM_PARAMS
from crewai.llm import LLM, BaseLLM from crewai.llm import LLM
from crewai.llms.base_llm import BaseLLM
logger = logging.getLogger(__name__)
def create_llm( def create_llm(
llm_value: Union[str, LLM, Any, None] = None, llm_value: str | LLM | Any | None = None,
) -> Optional[LLM | BaseLLM]: ) -> LLM | BaseLLM | None:
""" """Creates or returns an LLM instance based on the given llm_value.
Creates or returns an LLM instance based on the given llm_value.
Args: Args:
llm_value (str | BaseLLM | Any | None): llm_value: LLM instance, model name string, None, or an object with LLM attributes.
- str: The model name (e.g., "gpt-4").
- BaseLLM: Already instantiated BaseLLM (including LLM), returned as-is.
- Any: Attempt to extract known attributes like model_name, temperature, etc.
- None: Use environment-based or fallback default model.
Returns: Returns:
A BaseLLM instance if successful, or None if something fails. A BaseLLM instance if successful, or None if something fails.
""" """
# 1) If llm_value is already a BaseLLM or LLM object, return it directly if isinstance(llm_value, (LLM, BaseLLM)):
if isinstance(llm_value, LLM) or isinstance(llm_value, BaseLLM):
return llm_value return llm_value
# 2) If llm_value is a string (model name)
if isinstance(llm_value, str): if isinstance(llm_value, str):
try: try:
created_llm = LLM(model=llm_value) return LLM(model=llm_value)
return created_llm
except Exception as e: except Exception as e:
print(f"Failed to instantiate LLM with model='{llm_value}': {e}") logger.debug(f"Failed to instantiate LLM with model='{llm_value}': {e}")
return None return None
# 3) If llm_value is None, parse environment variables or use default
if llm_value is None: if llm_value is None:
return _llm_via_environment_or_fallback() return _llm_via_environment_or_fallback()
# 4) Otherwise, attempt to extract relevant attributes from an unknown object
try: try:
# Extract attributes with explicit types
model = ( model = (
getattr(llm_value, "model", None) getattr(llm_value, "model", None)
or getattr(llm_value, "model_name", None) or getattr(llm_value, "model_name", None)
or getattr(llm_value, "deployment_name", None) or getattr(llm_value, "deployment_name", None)
or str(llm_value) or str(llm_value)
) )
temperature: Optional[float] = getattr(llm_value, "temperature", None) temperature: float | None = getattr(llm_value, "temperature", None)
max_tokens: Optional[int] = getattr(llm_value, "max_tokens", None) max_tokens: int | None = getattr(llm_value, "max_tokens", None)
logprobs: Optional[int] = getattr(llm_value, "logprobs", None) logprobs: int | None = getattr(llm_value, "logprobs", None)
timeout: Optional[float] = getattr(llm_value, "timeout", None) timeout: float | None = getattr(llm_value, "timeout", None)
api_key: Optional[str] = getattr(llm_value, "api_key", None) api_key: str | None = getattr(llm_value, "api_key", None)
base_url: Optional[str] = getattr(llm_value, "base_url", None) base_url: str | None = getattr(llm_value, "base_url", None)
api_base: Optional[str] = getattr(llm_value, "api_base", None) api_base: str | None = getattr(llm_value, "api_base", None)
created_llm = LLM( return LLM(
model=model, model=model,
temperature=temperature, temperature=temperature,
max_tokens=max_tokens, max_tokens=max_tokens,
@@ -66,15 +59,23 @@ def create_llm(
base_url=base_url, base_url=base_url,
api_base=api_base, api_base=api_base,
) )
return created_llm
except Exception as e: except Exception as e:
print(f"Error instantiating LLM from unknown object type: {e}") logger.debug(f"Error instantiating LLM from unknown object type: {e}")
return None return None
def _llm_via_environment_or_fallback() -> Optional[LLM]: UNACCEPTED_ATTRIBUTES: Final[list[str]] = [
""" "AWS_ACCESS_KEY_ID",
Helper function: if llm_value is None, we load environment variables or fallback default model. "AWS_SECRET_ACCESS_KEY",
"AWS_REGION_NAME",
]
def _llm_via_environment_or_fallback() -> LLM | None:
"""Creates an LLM instance based on environment variables or defaults.
Returns:
A BaseLLM instance if successful, or None if something fails.
""" """
model_name = ( model_name = (
os.environ.get("MODEL") os.environ.get("MODEL")
@@ -83,28 +84,25 @@ def _llm_via_environment_or_fallback() -> Optional[LLM]:
or DEFAULT_LLM_MODEL or DEFAULT_LLM_MODEL
) )
# Initialize parameters with correct types
model: str = model_name model: str = model_name
temperature: Optional[float] = None temperature: float | None = None
max_tokens: Optional[int] = None max_tokens: int | None = None
max_completion_tokens: Optional[int] = None max_completion_tokens: int | None = None
logprobs: Optional[int] = None logprobs: int | None = None
timeout: Optional[float] = None timeout: float | None = None
api_key: Optional[str] = None api_key: str | None = None
base_url: Optional[str] = None api_version: str | None = None
api_version: Optional[str] = None presence_penalty: float | None = None
presence_penalty: Optional[float] = None frequency_penalty: float | None = None
frequency_penalty: Optional[float] = None top_p: float | None = None
top_p: Optional[float] = None n: int | None = None
n: Optional[int] = None stop: str | list[str] | None = None
stop: Optional[Union[str, List[str]]] = None logit_bias: dict[int, float] | None = None
logit_bias: Optional[Dict[int, float]] = None response_format: dict[str, Any] | None = None
response_format: Optional[Dict[str, Any]] = None seed: int | None = None
seed: Optional[int] = None top_logprobs: int | None = None
top_logprobs: Optional[int] = None callbacks: list[Any] = []
callbacks: List[Any] = []
# Optional base URL from env
base_url = ( base_url = (
os.environ.get("BASE_URL") os.environ.get("BASE_URL")
or os.environ.get("OPENAI_API_BASE") or os.environ.get("OPENAI_API_BASE")
@@ -119,8 +117,7 @@ def _llm_via_environment_or_fallback() -> Optional[LLM]:
elif api_base and not base_url: elif api_base and not base_url:
base_url = api_base base_url = api_base
# Initialize llm_params dictionary llm_params: dict[str, Any] = {
llm_params: Dict[str, Any] = {
"model": model, "model": model,
"temperature": temperature, "temperature": temperature,
"max_tokens": max_tokens, "max_tokens": max_tokens,
@@ -143,11 +140,6 @@ def _llm_via_environment_or_fallback() -> Optional[LLM]:
"callbacks": callbacks, "callbacks": callbacks,
} }
UNACCEPTED_ATTRIBUTES = [
"AWS_ACCESS_KEY_ID",
"AWS_SECRET_ACCESS_KEY",
"AWS_REGION_NAME",
]
set_provider = model_name.partition("/")[0] if "/" in model_name else "openai" set_provider = model_name.partition("/")[0] if "/" in model_name else "openai"
if set_provider in ENV_VARS: if set_provider in ENV_VARS:
@@ -167,28 +159,26 @@ def _llm_via_environment_or_fallback() -> Optional[LLM]:
if key not in ["prompt", "key_name", "default"]: if key not in ["prompt", "key_name", "default"]:
llm_params[key.lower()] = value llm_params[key.lower()] = value
else: else:
print( logger.debug(
f"Expected env_var to be a dictionary, but got {type(env_var)}" f"Expected env_var to be a dictionary, but got {type(env_var)}"
) )
# Remove None values
llm_params = {k: v for k, v in llm_params.items() if v is not None} llm_params = {k: v for k, v in llm_params.items() if v is not None}
# Try creating the LLM
try: try:
new_llm = LLM(**llm_params) return LLM(**llm_params)
return new_llm
except Exception as e: except Exception as e:
print( logger.debug(
f"Error instantiating LLM from environment/fallback: {type(e).__name__}: {e}" f"Error instantiating LLM from environment/fallback: {type(e).__name__}: {e}"
) )
return None return None
def _normalize_key_name(key_name: str) -> str: def _normalize_key_name(key_name: str) -> str:
""" """Maps environment variable names to recognized litellm parameter keys.
Maps environment variable names to recognized litellm parameter keys,
using patterns from LITELLM_PARAMS. Args:
key_name: The environment variable name to normalize.
""" """
for pattern in LITELLM_PARAMS: for pattern in LITELLM_PARAMS:
if pattern in key_name: if pattern in key_name:
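Usage sketch for the consolidated factory (module path assumed); none of these calls hit the network at construction time:
from crewai.utilities.llm_utils import create_llm

llm = create_llm("gpt-4o-mini")   # string -> LLM(model="gpt-4o-mini")
same = create_llm(llm)            # an existing LLM/BaseLLM is returned unchanged
fallback = create_llm(None)       # reads MODEL/env vars, else DEFAULT_LLM_MODEL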

View File

@@ -2,19 +2,34 @@ from datetime import datetime
from pydantic import BaseModel, Field, PrivateAttr from pydantic import BaseModel, Field, PrivateAttr
from crewai.utilities.printer import Printer from crewai.utilities.printer import ColoredText, Printer, PrinterColor
class Logger(BaseModel): class Logger(BaseModel):
verbose: bool = Field(default=False) verbose: bool = Field(
default=False,
description="Enables verbose logging with timestamps",
)
default_color: PrinterColor = Field(
default="bold_yellow",
description="Default color for log messages",
)
_printer: Printer = PrivateAttr(default_factory=Printer) _printer: Printer = PrivateAttr(default_factory=Printer)
default_color: str = Field(default="bold_yellow")
def log(self, level, message, color=None): def log(self, level: str, message: str, color: PrinterColor | None = None) -> None:
if color is None: """Log a message with timestamp if verbose mode is enabled.
color = self.default_color
Args:
level: The log level (e.g., 'info', 'warning', 'error').
message: The message to log.
color: Optional color for the message. Defaults to default_color.
"""
if self.verbose: if self.verbose:
timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S") timestamp: str = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
self._printer.print( self._printer.print(
f"\n[{timestamp}][{level.upper()}]: {message}", color=color [
ColoredText(f"\n[{timestamp}]", "cyan"),
ColoredText(f"[{level.upper()}]: ", "yellow"),
ColoredText(message, color or self.default_color),
]
) )
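The multicolor logging path in use; each segment carries its own color, and the message color can still be overridden per call:
from crewai.utilities.logger import Logger

logger = Logger(verbose=True, default_color="bold_green")
logger.log("info", "Crew kickoff started")            # timestamp in cyan, level in yellow
logger.log("warning", "Retrying task", color="red")   # message color overridden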

View File

@@ -47,10 +47,12 @@ def suppress_warnings() -> Generator[None, None, None]:
None during the context execution. None during the context execution.
Note: Note:
There is a similar implementation in src/crewai/llm.py that also This implementation consolidates warning suppression used throughout
suppresses a specific deprecation warning. That version may be the codebase, including specific deprecation warnings from dependencies.
consolidated here in the future.
""" """
with warnings.catch_warnings(): with warnings.catch_warnings():
warnings.filterwarnings("ignore") warnings.filterwarnings("ignore")
warnings.filterwarnings(
"ignore", message="open_text is deprecated*", category=DeprecationWarning
)
yield yield
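A self-contained check of the consolidated context manager: anything raised through the warnings machinery inside the block is silenced:
import warnings
from crewai.utilities.logger_utils import suppress_warnings

with suppress_warnings():
    warnings.warn("open_text is deprecated ...", DeprecationWarning)  # filtered out
    warnings.warn("some other noisy warning", UserWarning)            # also filtered out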

View File

@@ -1,31 +0,0 @@
import re
class YamlParser:
@staticmethod
def parse(file):
"""
Parses a YAML file, modifies specific patterns, and checks for unsupported 'context' usage.
Args:
file (file object): The YAML file to parse.
Returns:
str: The modified content of the YAML file.
Raises:
ValueError: If 'context:' is used incorrectly.
"""
content = file.read()
# Replace single { and } with doubled ones, while leaving already doubled ones intact and the other special characters {# and {%
modified_content = re.sub(r"(?<!\{){(?!\{)(?!\#)(?!\%)", "{{", content)
modified_content = re.sub(
r"(?<!\})(?<!\%)(?<!\#)\}(?!})", "}}", modified_content
)
# Check for 'context:' not followed by '[' and raise an error
if re.search(r"context:(?!\s*\[)", modified_content):
raise ValueError(
"Context is currently only supported in code when creating a task. "
"Please use the 'context' key in the task configuration."
)
return modified_content

View File

@@ -1,9 +1,10 @@
"""Path management utilities for CrewAI storage and configuration."""
import os import os
from pathlib import Path from pathlib import Path
import appdirs import appdirs
"""Path management utilities for CrewAI storage and configuration."""
def db_storage_path() -> str: def db_storage_path() -> str:
"""Returns the path for SQLite database storage. """Returns the path for SQLite database storage.
@@ -19,13 +20,6 @@ def db_storage_path() -> str:
return str(data_dir) return str(data_dir)
def get_project_directory_name(): def get_project_directory_name() -> str:
"""Returns the current project directory name.""" """Returns the current project directory name."""
project_directory_name = os.environ.get("CREWAI_STORAGE_DIR") return os.environ.get("CREWAI_STORAGE_DIR", Path.cwd().name)
if project_directory_name:
return project_directory_name
else:
cwd = Path.cwd()
project_directory_name = cwd.name
return project_directory_name
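Behavior sketch for the simplified helper: the CREWAI_STORAGE_DIR override wins, otherwise the current directory name is used (module path assumed):
import os
from crewai.utilities.paths import get_project_directory_name

os.environ["CREWAI_STORAGE_DIR"] = "my_project"
assert get_project_directory_name() == "my_project"
del os.environ["CREWAI_STORAGE_DIR"]
print(get_project_directory_name())  # falls back to Path.cwd().name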

View File

@@ -1,16 +1,19 @@
"""Handles planning and coordination of crew tasks."""
import logging import logging
from typing import Any, List, Optional
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from crewai.agent import Agent from crewai.agent import Agent
from crewai.llms.base_llm import BaseLLM
from crewai.task import Task from crewai.task import Task
"""Handles planning and coordination of crew tasks."""
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class PlanPerTask(BaseModel): class PlanPerTask(BaseModel):
"""Represents a plan for a specific task.""" """Represents a plan for a specific task."""
task: str = Field(..., description="The task for which the plan is created") task: str = Field(..., description="The task for which the plan is created")
plan: str = Field( plan: str = Field(
..., ...,
@@ -20,28 +23,48 @@ class PlanPerTask(BaseModel):
class PlannerTaskPydanticOutput(BaseModel): class PlannerTaskPydanticOutput(BaseModel):
"""Output format for task planning results.""" """Output format for task planning results."""
list_of_plans_per_task: List[PlanPerTask] = Field(
list_of_plans_per_task: list[PlanPerTask] = Field(
..., ...,
description="Step by step plan on how the agents can execute their tasks using the available tools with mastery", description="Step by step plan on how the agents can execute their tasks using the available tools with mastery",
) )
class CrewPlanner: class CrewPlanner:
"""Plans and coordinates the execution of crew tasks.""" """Plans and coordinates the execution of crew tasks.
def __init__(self, tasks: List[Task], planning_agent_llm: Optional[Any] = None):
self.tasks = tasks
if planning_agent_llm is None: Attributes:
self.planning_agent_llm = "gpt-4o-mini" tasks: List of tasks to be planned.
else: planning_agent_llm: Optional LLM model for the planning agent.
self.planning_agent_llm = planning_agent_llm """
def __init__(
self, tasks: list[Task], planning_agent_llm: str | BaseLLM | None = None
) -> None:
"""Initialize CrewPlanner with tasks and optional planning agent LLM.
Args:
tasks: List of tasks to be planned.
planning_agent_llm: Optional LLM model for the planning agent. Defaults to None.
"""
self.tasks = tasks
self.planning_agent_llm = planning_agent_llm or "gpt-4o-mini"
def _handle_crew_planning(self) -> PlannerTaskPydanticOutput: def _handle_crew_planning(self) -> PlannerTaskPydanticOutput:
"""Handles the Crew planning by creating detailed step-by-step plans for each task.""" """Handles the Crew planning by creating detailed step-by-step plans for each task.
Returns:
A PlannerTaskPydanticOutput containing the detailed plans for each task.
Raises:
ValueError: If the planning output cannot be obtained.
"""
planning_agent = self._create_planning_agent() planning_agent = self._create_planning_agent()
tasks_summary = self._create_tasks_summary() tasks_summary = self._create_tasks_summary()
planner_task = self._create_planner_task(planning_agent, tasks_summary) planner_task = self._create_planner_task(
planning_agent=planning_agent, tasks_summary=tasks_summary
)
result = planner_task.execute_sync() result = planner_task.execute_sync()
@@ -51,7 +74,11 @@ class CrewPlanner:
raise ValueError("Failed to get the Planning output") raise ValueError("Failed to get the Planning output")
def _create_planning_agent(self) -> Agent: def _create_planning_agent(self) -> Agent:
"""Creates the planning agent for the crew planning.""" """Creates the planning agent for the crew planning.
Returns:
An Agent instance configured for planning tasks.
"""
return Agent( return Agent(
role="Task Execution Planner", role="Task Execution Planner",
goal=( goal=(
@@ -62,8 +89,17 @@ class CrewPlanner:
llm=self.planning_agent_llm, llm=self.planning_agent_llm,
) )
def _create_planner_task(self, planning_agent: Agent, tasks_summary: str) -> Task: @staticmethod
"""Creates the planner task using the given agent and tasks summary.""" def _create_planner_task(planning_agent: Agent, tasks_summary: str) -> Task:
"""Creates the planner task using the given agent and tasks summary.
Args:
planning_agent: The agent responsible for planning.
tasks_summary: A summary of all tasks to be included in the planning.
Returns:
A Task instance configured for planning.
"""
return Task( return Task(
description=( description=(
f"Based on these tasks summary: {tasks_summary} \n Create the most descriptive plan based on the tasks " f"Based on these tasks summary: {tasks_summary} \n Create the most descriptive plan based on the tasks "
@@ -74,31 +110,42 @@ class CrewPlanner:
output_pydantic=PlannerTaskPydanticOutput, output_pydantic=PlannerTaskPydanticOutput,
) )
def _get_agent_knowledge(self, task: Task) -> List[str]: @staticmethod
""" def _get_agent_knowledge(task: Task) -> list[str]:
Safely retrieve knowledge source content from the task's agent. """Safely retrieve knowledge source content from the task's agent.
Args: Args:
task: The task containing an agent with potential knowledge sources task: The task containing an agent with potential knowledge sources
Returns: Returns:
List[str]: A list of knowledge source strings A list of knowledge source strings
""" """
try: try:
if task.agent and task.agent.knowledge_sources: if task.agent and task.agent.knowledge_sources:
return [source.content for source in task.agent.knowledge_sources] return [
getattr(source, "content", str(source))
for source in task.agent.knowledge_sources
]
except AttributeError: except AttributeError:
logger.warning("Error accessing agent knowledge sources") logger.warning("Error accessing agent knowledge sources")
return [] return []
def _create_tasks_summary(self) -> str: def _create_tasks_summary(self) -> str:
"""Creates a summary of all tasks.""" """Creates a summary of all tasks.
Returns:
A string summarizing all tasks with their details.
"""
tasks_summary = [] tasks_summary = []
for idx, task in enumerate(self.tasks): for idx, task in enumerate(self.tasks):
knowledge_list = self._get_agent_knowledge(task) knowledge_list = self._get_agent_knowledge(task)
agent_tools = ( agent_tools = (
f"[{', '.join(str(tool) for tool in task.agent.tools)}]" if task.agent and task.agent.tools else '"agent has no tools"', f"[{', '.join(str(tool) for tool in task.agent.tools)}]"
f',\n "agent_knowledge": "[\\"{knowledge_list[0]}\\"]"' if knowledge_list and str(knowledge_list) != "None" else "" if task.agent and task.agent.tools
else '"agent has no tools"',
f',\n "agent_knowledge": "[\\"{knowledge_list[0]}\\"]"'
if knowledge_list and str(knowledge_list) != "None"
else "",
) )
task_summary = f""" task_summary = f"""
Task Number {idx + 1} - {task.description} Task Number {idx + 1} - {task.description}
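A construction-only sketch of the planner (module path assumed); calling _handle_crew_planning would invoke the planning LLM, so it is left out here:
from crewai.agent import Agent
from crewai.task import Task
from crewai.utilities.planning_handler import CrewPlanner

researcher = Agent(role="Researcher", goal="Find facts", backstory="A diligent analyst.")
summary_task = Task(
    description="Summarize recent findings",
    expected_output="A short summary",
    agent=researcher,
)
planner = CrewPlanner(tasks=[summary_task], planning_agent_llm="gpt-4o-mini")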

View File

@@ -1,71 +1,72 @@
"""Utility for colored console output.""" """Utility for colored console output."""
from typing import Optional from typing import Final, Literal, NamedTuple
PrinterColor = Literal[
"purple",
"bold_purple",
"green",
"bold_green",
"cyan",
"bold_cyan",
"magenta",
"bold_magenta",
"yellow",
"bold_yellow",
"red",
"blue",
"bold_blue",
]
_COLOR_CODES: Final[dict[PrinterColor, str]] = {
"purple": "\033[95m",
"bold_purple": "\033[1m\033[95m",
"red": "\033[91m",
"bold_green": "\033[1m\033[92m",
"green": "\033[32m",
"blue": "\033[94m",
"bold_blue": "\033[1m\033[94m",
"yellow": "\033[93m",
"bold_yellow": "\033[1m\033[93m",
"cyan": "\033[96m",
"bold_cyan": "\033[1m\033[96m",
"magenta": "\033[35m",
"bold_magenta": "\033[1m\033[35m",
}
RESET: Final[str] = "\033[0m"
class ColoredText(NamedTuple):
"""Represents text with an optional color for console output.
Attributes:
text: The text content to be printed.
color: Optional color for the text, specified as a PrinterColor.
"""
text: str
color: PrinterColor | None
class Printer: class Printer:
"""Handles colored console output formatting.""" """Handles colored console output formatting."""
def print(self, content: str, color: Optional[str] = None): @staticmethod
if color == "purple": def print(
self._print_purple(content) content: str | list[ColoredText], color: PrinterColor | None = None
elif color == "red": ) -> None:
self._print_red(content) """Prints content to the console with optional color formatting.
elif color == "bold_green":
self._print_bold_green(content)
elif color == "bold_purple":
self._print_bold_purple(content)
elif color == "bold_blue":
self._print_bold_blue(content)
elif color == "yellow":
self._print_yellow(content)
elif color == "bold_yellow":
self._print_bold_yellow(content)
elif color == "cyan":
self._print_cyan(content)
elif color == "bold_cyan":
self._print_bold_cyan(content)
elif color == "magenta":
self._print_magenta(content)
elif color == "bold_magenta":
self._print_bold_magenta(content)
elif color == "green":
self._print_green(content)
else:
print(content)
def _print_bold_purple(self, content): Args:
print("\033[1m\033[95m {}\033[00m".format(content)) content: Either a string or a list of ColoredText objects for multicolor output.
color: Optional color for the text when content is a string. Ignored when content is a list.
def _print_bold_green(self, content): """
print("\033[1m\033[92m {}\033[00m".format(content)) if isinstance(content, str):
content = [ColoredText(content, color)]
def _print_purple(self, content): print(
print("\033[95m {}\033[00m".format(content)) "".join(
f"{_COLOR_CODES[c.color] if c.color else ''}{c.text}{RESET}"
def _print_red(self, content): for c in content
print("\033[91m {}\033[00m".format(content)) )
)
def _print_bold_blue(self, content):
print("\033[1m\033[94m {}\033[00m".format(content))
def _print_yellow(self, content):
print("\033[93m {}\033[00m".format(content))
def _print_bold_yellow(self, content):
print("\033[1m\033[93m {}\033[00m".format(content))
def _print_cyan(self, content):
print("\033[96m {}\033[00m".format(content))
def _print_bold_cyan(self, content):
print("\033[1m\033[96m {}\033[00m".format(content))
def _print_magenta(self, content):
print("\033[35m {}\033[00m".format(content))
def _print_bold_magenta(self, content):
print("\033[1m\033[35m {}\033[00m".format(content))
def _print_green(self, content):
print("\033[32m {}\033[00m".format(content))
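The table-driven printer in use: a plain string keeps the old single-color behavior, while a list of ColoredText segments renders multicolor output in one call:
from crewai.utilities.printer import ColoredText, Printer

Printer.print("all good", color="bold_green")
Printer.print(
    [
        ColoredText("[crew] ", "cyan"),
        ColoredText("task finished", "green"),
        ColoredText(" (2.3s)", None),  # uncolored segment
    ]
)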

View File

@@ -1,29 +1,59 @@
from typing import Any, Optional from __future__ import annotations
from typing import Any, TypedDict
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from crewai.utilities import I18N from crewai.utilities.i18n import I18N
class StandardPromptResult(TypedDict):
"""Result with only prompt field for standard mode."""
prompt: str
class SystemPromptResult(StandardPromptResult):
"""Result with system, user, and prompt fields for system prompt mode."""
system: str
user: str
class Prompts(BaseModel): class Prompts(BaseModel):
"""Manages and generates prompts for a generic agent.""" """Manages and generates prompts for a generic agent."""
i18n: I18N = Field(default=I18N()) i18n: I18N = Field(default_factory=I18N)
has_tools: bool = False has_tools: bool = Field(
system_template: Optional[str] = None default=False, description="Indicates if the agent has access to tools"
prompt_template: Optional[str] = None )
response_template: Optional[str] = None system_template: str | None = Field(
use_system_prompt: Optional[bool] = False default=None, description="Custom system prompt template"
agent: Any )
prompt_template: str | None = Field(
default=None, description="Custom user prompt template"
)
response_template: str | None = Field(
default=None, description="Custom response prompt template"
)
use_system_prompt: bool | None = Field(
default=False,
description="Whether to use the system prompt when no custom templates are provided",
)
agent: Any = Field(description="Reference to the agent using these prompts")
def task_execution(self) -> dict[str, str]: def task_execution(self) -> SystemPromptResult | StandardPromptResult:
"""Generate a standard prompt for task execution.""" """Generate a standard prompt for task execution.
slices = ["role_playing"]
Returns:
A dictionary containing the constructed prompt(s).
"""
slices: list[str] = ["role_playing"]
if self.has_tools: if self.has_tools:
slices.append("tools") slices.append("tools")
else: else:
slices.append("no_tools") slices.append("no_tools")
system = self._build_prompt(slices) system: str = self._build_prompt(slices)
slices.append("task") slices.append("task")
if ( if (
@@ -31,54 +61,67 @@ class Prompts(BaseModel):
and not self.prompt_template and not self.prompt_template
and self.use_system_prompt and self.use_system_prompt
): ):
return { return SystemPromptResult(
"system": system, system=system,
"user": self._build_prompt(["task"]), user=self._build_prompt(["task"]),
"prompt": self._build_prompt(slices), prompt=self._build_prompt(slices),
} )
else: return StandardPromptResult(
return { prompt=self._build_prompt(
"prompt": self._build_prompt( slices,
slices, self.system_template,
self.system_template, self.prompt_template,
self.prompt_template, self.response_template,
self.response_template, )
) )
}
def _build_prompt( def _build_prompt(
self, self,
components: list[str], components: list[str],
system_template=None, system_template: str | None = None,
prompt_template=None, prompt_template: str | None = None,
response_template=None, response_template: str | None = None,
) -> str: ) -> str:
"""Constructs a prompt string from specified components.""" """Constructs a prompt string from specified components.
Args:
components: List of component names to include in the prompt.
system_template: Optional custom template for the system prompt.
prompt_template: Optional custom template for the user prompt.
response_template: Optional custom template for the response prompt.
Returns:
The constructed prompt string.
"""
prompt: str
if not system_template or not prompt_template: if not system_template or not prompt_template:
# If any of the required templates are missing, fall back to the default format # If any of the required templates are missing, fall back to the default format
prompt_parts = [self.i18n.slice(component) for component in components] prompt_parts: list[str] = [
self.i18n.slice(component) for component in components
]
prompt = "".join(prompt_parts) prompt = "".join(prompt_parts)
else: else:
# All templates are provided, use them # All templates are provided, use them
prompt_parts = [ template_parts: list[str] = [
self.i18n.slice(component) self.i18n.slice(component)
for component in components for component in components
if component != "task" if component != "task"
] ]
system = system_template.replace("{{ .System }}", "".join(prompt_parts)) system: str = system_template.replace(
"{{ .System }}", "".join(template_parts)
)
prompt = prompt_template.replace( prompt = prompt_template.replace(
"{{ .Prompt }}", "".join(self.i18n.slice("task")) "{{ .Prompt }}", "".join(self.i18n.slice("task"))
) )
# Handle missing response_template # Handle missing response_template
if response_template: if response_template:
response = response_template.split("{{ .Response }}")[0] response: str = response_template.split("{{ .Response }}")[0]
prompt = f"{system}\n{prompt}\n{response}" prompt = f"{system}\n{prompt}\n{response}"
else: else:
prompt = f"{system}\n{prompt}" prompt = f"{system}\n{prompt}"
prompt = ( return (
prompt.replace("{goal}", self.agent.goal) prompt.replace("{goal}", self.agent.goal)
.replace("{role}", self.agent.role) .replace("{role}", self.agent.role)
.replace("{backstory}", self.agent.backstory) .replace("{backstory}", self.agent.backstory)
) )
return prompt
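A shape check for the TypedDict returns; a lightweight stand-in agent is enough because only role, goal, and backstory are substituted into the templates (module path assumed):
from types import SimpleNamespace
from crewai.utilities.prompts import Prompts

agent = SimpleNamespace(role="Planner", goal="Plan the work", backstory="Methodical.")
prompts = Prompts(agent=agent, has_tools=False, use_system_prompt=True)
result = prompts.task_execution()
print(sorted(result))  # ['prompt', 'system', 'user'] in system-prompt mode, otherwise just ['prompt']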

View File

@@ -1,57 +1,62 @@
from typing import Dict, List, Type, Union, get_args, get_origin from typing import Any, Union, get_args, get_origin
from pydantic import BaseModel from pydantic import BaseModel, Field
class PydanticSchemaParser(BaseModel): class PydanticSchemaParser(BaseModel):
model: Type[BaseModel] model: type[BaseModel] = Field(..., description="The Pydantic model to parse.")
def get_schema(self) -> str: def get_schema(self) -> str:
""" """Public method to get the schema of a Pydantic model.
Public method to get the schema of a Pydantic model.
:return: String representation of the model schema. Returns:
String representation of the model schema.
""" """
return "{\n" + self._get_model_schema(self.model) + "\n}" return "{\n" + self._get_model_schema(self.model) + "\n}"
def _get_model_schema(self, model: Type[BaseModel], depth: int = 0) -> str: def _get_model_schema(self, model: type[BaseModel], depth: int = 0) -> str:
indent = " " * 4 * depth """Recursively get the schema of a Pydantic model, handling nested models and lists.
lines = [
f"{indent} {field_name}: {self._get_field_type(field, depth + 1)}" Args:
model: The Pydantic model to process.
depth: The current depth of recursion for indentation purposes.
Returns:
A string representation of the model schema.
"""
indent: str = " " * 4 * depth
lines: list[str] = [
f"{indent} {field_name}: {self._get_field_type_for_annotation(field.annotation, depth + 1)}"
for field_name, field in model.model_fields.items() for field_name, field in model.model_fields.items()
] ]
return ",\n".join(lines) return ",\n".join(lines)
def _get_field_type(self, field, depth: int) -> str: def _format_list_type(self, list_item_type: Any, depth: int) -> str:
field_type = field.annotation """Format a List type, handling nested models if necessary.
origin = get_origin(field_type)
if origin in {list, List}: Args:
list_item_type = get_args(field_type)[0] list_item_type: The type of items in the list.
return self._format_list_type(list_item_type, depth) depth: The current depth of recursion for indentation purposes.
if origin in {dict, Dict}: Returns:
key_type, value_type = get_args(field_type) A string representation of the List type.
return f"Dict[{key_type.__name__}, {value_type.__name__}]" """
if origin is Union:
return self._format_union_type(field_type, depth)
if isinstance(field_type, type) and issubclass(field_type, BaseModel):
nested_schema = self._get_model_schema(field_type, depth)
nested_indent = " " * 4 * depth
return f"{field_type.__name__}\n{nested_indent}{{\n{nested_schema}\n{nested_indent}}}"
return field_type.__name__
def _format_list_type(self, list_item_type, depth: int) -> str:
if isinstance(list_item_type, type) and issubclass(list_item_type, BaseModel): if isinstance(list_item_type, type) and issubclass(list_item_type, BaseModel):
nested_schema = self._get_model_schema(list_item_type, depth + 1) nested_schema = self._get_model_schema(list_item_type, depth + 1)
nested_indent = " " * 4 * (depth) nested_indent = " " * 4 * depth
return f"List[\n{nested_indent}{{\n{nested_schema}\n{nested_indent}}}\n{nested_indent}]" return f"List[\n{nested_indent}{{\n{nested_schema}\n{nested_indent}}}\n{nested_indent}]"
return f"List[{list_item_type.__name__}]" return f"List[{list_item_type.__name__}]"
def _format_union_type(self, field_type, depth: int) -> str: def _format_union_type(self, field_type: Any, depth: int) -> str:
"""Format a Union type, handling Optional and nested types.
Args:
field_type: The Union type to format.
depth: The current depth of recursion for indentation purposes.
Returns:
A string representation of the Union type.
"""
args = get_args(field_type) args = get_args(field_type)
if type(None) in args: if type(None) in args:
# It's an Optional type # It's an Optional type
@@ -61,26 +66,32 @@ class PydanticSchemaParser(BaseModel):
non_none_args[0], depth non_none_args[0], depth
) )
return f"Optional[{inner_type}]" return f"Optional[{inner_type}]"
else: # Union with None and multiple other types
# Union with None and multiple other types
inner_types = ", ".join(
self._get_field_type_for_annotation(arg, depth)
for arg in non_none_args
)
return f"Optional[Union[{inner_types}]]"
else:
# General Union type
inner_types = ", ".join( inner_types = ", ".join(
self._get_field_type_for_annotation(arg, depth) for arg in args self._get_field_type_for_annotation(arg, depth) for arg in non_none_args
) )
return f"Union[{inner_types}]" return f"Optional[Union[{inner_types}]]"
# General Union type
inner_types = ", ".join(
self._get_field_type_for_annotation(arg, depth) for arg in args
)
return f"Union[{inner_types}]"
def _get_field_type_for_annotation(self, annotation, depth: int) -> str: def _get_field_type_for_annotation(self, annotation: Any, depth: int) -> str:
origin = get_origin(annotation) """Recursively get the string representation of a field's type annotation.
if origin in {list, List}:
Args:
annotation: The type annotation to process.
depth: The current depth of recursion for indentation purposes.
Returns:
A string representation of the type annotation.
"""
origin: Any = get_origin(annotation)
if origin is list:
list_item_type = get_args(annotation)[0] list_item_type = get_args(annotation)[0]
return self._format_list_type(list_item_type, depth) return self._format_list_type(list_item_type, depth)
if origin in {dict, Dict}: if origin is dict:
key_type, value_type = get_args(annotation) key_type, value_type = get_args(annotation)
return f"Dict[{key_type.__name__}, {value_type.__name__}]" return f"Dict[{key_type.__name__}, {value_type.__name__}]"
if origin is Union: if origin is Union:
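A quick check of the parser on a nested model; the output is the indented pseudo-schema string built above (module path assumed):
from pydantic import BaseModel
from crewai.utilities.pydantic_schema_parser import PydanticSchemaParser

class Address(BaseModel):
    city: str
    zip_code: str

class User(BaseModel):
    name: str
    tags: list[str]
    addresses: list[Address]

print(PydanticSchemaParser(model=User).get_schema())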

View File

@@ -1,19 +1,19 @@
import logging
import json import json
from typing import Tuple, cast import logging
from typing import Any, Final, Literal, cast
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from crewai.agent import Agent from crewai.agent import Agent
from crewai.task import Task
from crewai.utilities import I18N
from crewai.llm import LLM
from crewai.events.event_bus import crewai_event_bus from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.reasoning_events import ( from crewai.events.types.reasoning_events import (
AgentReasoningStartedEvent,
AgentReasoningCompletedEvent, AgentReasoningCompletedEvent,
AgentReasoningFailedEvent, AgentReasoningFailedEvent,
AgentReasoningStartedEvent,
) )
from crewai.llm import LLM
from crewai.task import Task
from crewai.utilities.i18n import I18N
class ReasoningPlan(BaseModel): class ReasoningPlan(BaseModel):
@@ -29,22 +29,49 @@ class AgentReasoningOutput(BaseModel):
plan: ReasoningPlan = Field(description="The reasoning plan for the task.") plan: ReasoningPlan = Field(description="The reasoning plan for the task.")
class ReasoningFunction(BaseModel): FUNCTION_SCHEMA: Final[dict[str, Any]] = {
"""Model for function calling with reasoning.""" "type": "function",
"function": {
plan: str = Field(description="The detailed reasoning plan for the task.") "name": "create_reasoning_plan",
ready: bool = Field(description="Whether the agent is ready to execute the task.") "description": "Create or refine a reasoning plan for a task",
"parameters": {
"type": "object",
"properties": {
"plan": {
"type": "string",
"description": "The detailed reasoning plan for the task.",
},
"ready": {
"type": "boolean",
"description": "Whether the agent is ready to execute the task.",
},
},
"required": ["plan", "ready"],
},
},
}
class AgentReasoning: class AgentReasoning:
""" """
Handles the agent reasoning process, enabling an agent to reflect and create a plan Handles the agent reasoning process, enabling an agent to reflect and create a plan
before executing a task. before executing a task.
Attributes:
task: The task for which the agent is reasoning.
agent: The agent performing the reasoning.
llm: The language model used for reasoning.
logger: Logger for logging events and errors.
i18n: Internationalization utility for retrieving prompts.
""" """
def __init__(self, task: Task, agent: Agent): def __init__(self, task: Task, agent: Agent) -> None:
if not task or not agent: """Initialize the AgentReasoning with a task and an agent.
raise ValueError("Both task and agent must be provided.")
Args:
task: The task for which the agent is reasoning.
agent: The agent performing the reasoning.
"""
self.task = task self.task = task
self.agent = agent self.agent = agent
self.llm = cast(LLM, agent.llm) self.llm = cast(LLM, agent.llm)
@@ -52,9 +79,7 @@ class AgentReasoning:
self.i18n = I18N() self.i18n = I18N()
def handle_agent_reasoning(self) -> AgentReasoningOutput: def handle_agent_reasoning(self) -> AgentReasoningOutput:
""" """Public method for the reasoning process that creates and refines a plan for the task until the agent is ready to execute it.
Public method for the reasoning process that creates and refines a plan
for the task until the agent is ready to execute it.
Returns: Returns:
AgentReasoningOutput: The output of the agent reasoning process. AgentReasoningOutput: The output of the agent reasoning process.
@@ -70,7 +95,7 @@ class AgentReasoning:
from_task=self.task, from_task=self.task,
), ),
) )
except Exception: except Exception: # noqa: S110
# Ignore event bus errors to avoid breaking execution # Ignore event bus errors to avoid breaking execution
pass pass
@@ -90,7 +115,7 @@ class AgentReasoning:
from_task=self.task, from_task=self.task,
), ),
) )
except Exception: except Exception: # noqa: S110
pass pass
return output return output
@@ -107,17 +132,16 @@ class AgentReasoning:
from_task=self.task, from_task=self.task,
), ),
) )
except Exception: except Exception: # noqa: S110
pass pass
raise raise
def __handle_agent_reasoning(self) -> AgentReasoningOutput: def __handle_agent_reasoning(self) -> AgentReasoningOutput:
""" """Private method that handles the agent reasoning process.
Private method that handles the agent reasoning process.
Returns: Returns:
AgentReasoningOutput: The output of the agent reasoning process. The output of the agent reasoning process.
""" """
plan, ready = self.__create_initial_plan() plan, ready = self.__create_initial_plan()
@@ -126,46 +150,38 @@ class AgentReasoning:
reasoning_plan = ReasoningPlan(plan=plan, ready=ready) reasoning_plan = ReasoningPlan(plan=plan, ready=ready)
return AgentReasoningOutput(plan=reasoning_plan) return AgentReasoningOutput(plan=reasoning_plan)
def __create_initial_plan(self) -> Tuple[str, bool]: def __create_initial_plan(self) -> tuple[str, bool]:
""" """Creates the initial reasoning plan for the task.
Creates the initial reasoning plan for the task.
Returns: Returns:
Tuple[str, bool]: The initial plan and whether the agent is ready to execute the task. The initial plan and whether the agent is ready to execute the task.
""" """
reasoning_prompt = self.__create_reasoning_prompt() reasoning_prompt = self.__create_reasoning_prompt()
if self.llm.supports_function_calling(): if self.llm.supports_function_calling():
plan, ready = self.__call_with_function(reasoning_prompt, "initial_plan") plan, ready = self.__call_with_function(reasoning_prompt, "initial_plan")
return plan, ready return plan, ready
else: response = _call_llm_with_reasoning_prompt(
system_prompt = self.i18n.retrieve("reasoning", "initial_plan").format( llm=self.llm,
role=self.agent.role, prompt=reasoning_prompt,
goal=self.agent.goal, task=self.task,
backstory=self.__get_agent_backstory(), agent=self.agent,
) i18n=self.i18n,
backstory=self.__get_agent_backstory(),
plan_type="initial_plan",
)
response = self.llm.call( return self.__parse_reasoning_response(str(response))
[
{"role": "system", "content": system_prompt},
{"role": "user", "content": reasoning_prompt},
],
from_task=self.task,
from_agent=self.agent,
)
return self.__parse_reasoning_response(str(response)) def __refine_plan_if_needed(self, plan: str, ready: bool) -> tuple[str, bool]:
"""Refines the reasoning plan if the agent is not ready to execute the task.
def __refine_plan_if_needed(self, plan: str, ready: bool) -> Tuple[str, bool]:
"""
Refines the reasoning plan if the agent is not ready to execute the task.
Args: Args:
plan: The current reasoning plan. plan: The current reasoning plan.
ready: Whether the agent is ready to execute the task. ready: Whether the agent is ready to execute the task.
Returns: Returns:
Tuple[str, bool]: The refined plan and whether the agent is ready to execute the task. The refined plan and whether the agent is ready to execute the task.
""" """
attempt = 1 attempt = 1
max_attempts = self.agent.max_reasoning_attempts max_attempts = self.agent.max_reasoning_attempts
@@ -182,7 +198,7 @@ class AgentReasoning:
from_task=self.task, from_task=self.task,
), ),
) )
except Exception: except Exception: # noqa: S110
pass pass
refine_prompt = self.__create_refine_prompt(plan) refine_prompt = self.__create_refine_prompt(plan)
@@ -190,19 +206,14 @@ class AgentReasoning:
if self.llm.supports_function_calling(): if self.llm.supports_function_calling():
plan, ready = self.__call_with_function(refine_prompt, "refine_plan") plan, ready = self.__call_with_function(refine_prompt, "refine_plan")
else: else:
system_prompt = self.i18n.retrieve("reasoning", "refine_plan").format( response = _call_llm_with_reasoning_prompt(
role=self.agent.role, llm=self.llm,
goal=self.agent.goal, prompt=refine_prompt,
task=self.task,
agent=self.agent,
i18n=self.i18n,
backstory=self.__get_agent_backstory(), backstory=self.__get_agent_backstory(),
) plan_type="refine_plan",
response = self.llm.call(
[
{"role": "system", "content": system_prompt},
{"role": "user", "content": refine_prompt},
],
from_task=self.task,
from_agent=self.agent,
) )
plan, ready = self.__parse_reasoning_response(str(response)) plan, ready = self.__parse_reasoning_response(str(response))
@@ -216,41 +227,18 @@ class AgentReasoning:
return plan, ready return plan, ready
def __call_with_function(self, prompt: str, prompt_type: str) -> Tuple[str, bool]: def __call_with_function(self, prompt: str, prompt_type: str) -> tuple[str, bool]:
""" """Calls the LLM with function calling to get a reasoning plan.
Calls the LLM with function calling to get a reasoning plan.
Args: Args:
prompt: The prompt to send to the LLM. prompt: The prompt to send to the LLM.
prompt_type: The type of prompt (initial_plan or refine_plan). prompt_type: The type of prompt (initial_plan or refine_plan).
Returns: Returns:
Tuple[str, bool]: A tuple containing the plan and whether the agent is ready. A tuple containing the plan and whether the agent is ready.
""" """
self.logger.debug(f"Using function calling for {prompt_type} reasoning") self.logger.debug(f"Using function calling for {prompt_type} reasoning")
function_schema = {
"type": "function",
"function": {
"name": "create_reasoning_plan",
"description": "Create or refine a reasoning plan for a task",
"parameters": {
"type": "object",
"properties": {
"plan": {
"type": "string",
"description": "The detailed reasoning plan for the task.",
},
"ready": {
"type": "boolean",
"description": "Whether the agent is ready to execute the task.",
},
},
"required": ["plan", "ready"],
},
},
}
try: try:
system_prompt = self.i18n.retrieve("reasoning", prompt_type).format( system_prompt = self.i18n.retrieve("reasoning", prompt_type).format(
role=self.agent.role, role=self.agent.role,
@@ -259,7 +247,7 @@ class AgentReasoning:
) )
# Prepare a simple callable that just returns the tool arguments as JSON # Prepare a simple callable that just returns the tool arguments as JSON
def _create_reasoning_plan(plan: str, ready: bool = True): # noqa: N802 def _create_reasoning_plan(plan: str, ready: bool = True):
"""Return the reasoning plan result in JSON string form.""" """Return the reasoning plan result in JSON string form."""
return json.dumps({"plan": plan, "ready": ready}) return json.dumps({"plan": plan, "ready": ready})
@@ -268,7 +256,7 @@ class AgentReasoning:
{"role": "system", "content": system_prompt}, {"role": "system", "content": system_prompt},
{"role": "user", "content": prompt}, {"role": "user", "content": prompt},
], ],
tools=[function_schema], tools=[FUNCTION_SCHEMA],
available_functions={"create_reasoning_plan": _create_reasoning_plan}, available_functions={"create_reasoning_plan": _create_reasoning_plan},
from_task=self.task, from_task=self.task,
from_agent=self.agent, from_agent=self.agent,
@@ -291,7 +279,7 @@ class AgentReasoning:
except Exception as e: except Exception as e:
self.logger.warning( self.logger.warning(
f"Error during function calling: {str(e)}. Falling back to text parsing." f"Error during function calling: {e!s}. Falling back to text parsing."
) )
try: try:
@@ -316,7 +304,7 @@ class AgentReasoning:
"READY: I am ready to execute the task." in fallback_str, "READY: I am ready to execute the task." in fallback_str,
) )
except Exception as inner_e: except Exception as inner_e:
self.logger.error(f"Error during fallback text parsing: {str(inner_e)}") self.logger.error(f"Error during fallback text parsing: {inner_e!s}")
return ( return (
"Failed to generate a plan due to an error.", "Failed to generate a plan due to an error.",
True, True,
@@ -378,7 +366,8 @@ class AgentReasoning:
current_plan=current_plan, current_plan=current_plan,
) )
def __parse_reasoning_response(self, response: str) -> Tuple[str, bool]: @staticmethod
def __parse_reasoning_response(response: str) -> tuple[str, bool]:
""" """
Parses the reasoning response to extract the plan and whether Parses the reasoning response to extract the plan and whether
the agent is ready to execute the task. the agent is ready to execute the task.
@@ -387,7 +376,7 @@ class AgentReasoning:
response: The LLM response. response: The LLM response.
Returns: Returns:
Tuple[str, bool]: The plan and whether the agent is ready to execute the task. The plan and whether the agent is ready to execute the task.
""" """
if not response: if not response:
return "No plan was generated.", False return "No plan was generated.", False
@@ -412,3 +401,43 @@ class AgentReasoning:
"The _handle_agent_reasoning method is deprecated. Use handle_agent_reasoning instead." "The _handle_agent_reasoning method is deprecated. Use handle_agent_reasoning instead."
) )
return self.handle_agent_reasoning() return self.handle_agent_reasoning()
def _call_llm_with_reasoning_prompt(
llm: LLM,
prompt: str,
task: Task,
agent: Agent,
i18n: I18N,
backstory: str,
plan_type: Literal["initial_plan", "refine_plan"],
) -> str:
"""Calls the LLM with the reasoning prompt.
Args:
llm: The language model to use.
prompt: The prompt to send to the LLM.
task: The task for which the agent is reasoning.
agent: The agent performing the reasoning.
i18n: Internationalization utility for retrieving prompts.
backstory: The agent's backstory.
plan_type: The type of plan being created ("initial_plan" or "refine_plan").
Returns:
The LLM response.
"""
system_prompt = i18n.retrieve("reasoning", plan_type).format(
role=agent.role,
goal=agent.goal,
backstory=backstory,
)
response = llm.call(
[
{"role": "system", "content": system_prompt},
{"role": "user", "content": prompt},
],
from_task=task,
from_agent=agent,
)
return str(response)

View File

@@ -1,41 +1,54 @@
"""Controls request rate limiting for API calls."""
import threading import threading
import time import time
from typing import Optional
from pydantic import BaseModel, Field, PrivateAttr, model_validator from pydantic import BaseModel, Field, PrivateAttr, model_validator
from typing_extensions import Self
from crewai.utilities.logger import Logger from crewai.utilities.logger import Logger
"""Controls request rate limiting for API calls."""
class RPMController(BaseModel): class RPMController(BaseModel):
"""Manages requests per minute limiting.""" """Manages requests per minute limiting."""
max_rpm: Optional[int] = Field(default=None) max_rpm: int | None = Field(
default=None,
description="Maximum requests per minute. If None, no limit is applied.",
)
logger: Logger = Field(default_factory=lambda: Logger(verbose=False)) logger: Logger = Field(default_factory=lambda: Logger(verbose=False))
_current_rpm: int = PrivateAttr(default=0) _current_rpm: int = PrivateAttr(default=0)
_timer: Optional[threading.Timer] = PrivateAttr(default=None) _timer: "threading.Timer | None" = PrivateAttr(default=None)
_lock: Optional[threading.Lock] = PrivateAttr(default=None) _lock: "threading.Lock | None" = PrivateAttr(default=None)
_shutdown_flag: bool = PrivateAttr(default=False) _shutdown_flag: bool = PrivateAttr(default=False)
@model_validator(mode="after") @model_validator(mode="after")
def reset_counter(self): def reset_counter(self) -> Self:
"""Resets the RPM counter and starts the timer if max_rpm is set.
Returns:
The instance of the RPMController.
"""
if self.max_rpm is not None: if self.max_rpm is not None:
if not self._shutdown_flag: if not self._shutdown_flag:
self._lock = threading.Lock() self._lock = threading.Lock()
self._reset_request_count() self._reset_request_count()
return self return self
def check_or_wait(self): def check_or_wait(self) -> bool:
"""Checks if a new request can be made based on the RPM limit.
Returns:
True if a new request can be made, False otherwise.
"""
if self.max_rpm is None: if self.max_rpm is None:
return True return True
def _check_and_increment(): def _check_and_increment() -> bool:
if self.max_rpm is not None and self._current_rpm < self.max_rpm: if self.max_rpm is not None and self._current_rpm < self.max_rpm:
self._current_rpm += 1 self._current_rpm += 1
return True return True
elif self.max_rpm is not None: if self.max_rpm is not None:
self.logger.log( self.logger.log(
"info", "Max RPM reached, waiting for next minute to start." "info", "Max RPM reached, waiting for next minute to start."
) )
@@ -50,16 +63,18 @@ class RPMController(BaseModel):
else: else:
return _check_and_increment() return _check_and_increment()
def stop_rpm_counter(self): def stop_rpm_counter(self) -> None:
"""Stops the RPM counter and cancels any active timers."""
self._shutdown_flag = True
if self._timer: if self._timer:
self._timer.cancel() self._timer.cancel()
self._timer = None self._timer = None
def _wait_for_next_minute(self): def _wait_for_next_minute(self) -> None:
time.sleep(60) time.sleep(60)
self._current_rpm = 0 self._current_rpm = 0
def _reset_request_count(self): def _reset_request_count(self) -> None:
def _reset(): def _reset():
self._current_rpm = 0 self._current_rpm = 0
if not self._shutdown_flag: if not self._shutdown_flag:
@@ -71,7 +86,3 @@ class RPMController(BaseModel):
_reset() _reset()
else: else:
_reset() _reset()
if self._timer:
self._shutdown_flag = True
self._timer.cancel()
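As a quick illustration of the typed API above, a minimal throttling sketch; the import path and the make_api_call stand-in are assumptions, not part of this diff:

from crewai.utilities.rpm_controller import RPMController  # assumed module path

def make_api_call(i: int) -> None:
    """Hypothetical stand-in for the real rate-limited request."""
    print(f"request {i}")

controller = RPMController(max_rpm=10)

for i in range(25):
    controller.check_or_wait()  # blocks once 10 requests have been made in the current minute
    make_api_call(i)

controller.stop_rpm_counter()   # sets the shutdown flag and cancels the background reset timer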

View File

@@ -1,14 +1,16 @@
from __future__ import annotations
import json import json
import uuid import uuid
from datetime import date, datetime from datetime import date, datetime
from typing import Any, Dict, List, Union from typing import Any, TypeAlias
from pydantic import BaseModel from pydantic import BaseModel
SerializablePrimitive = Union[str, int, float, bool, None] SerializablePrimitive: TypeAlias = str | int | float | bool | None
Serializable = Union[ Serializable: TypeAlias = (
SerializablePrimitive, List["Serializable"], Dict[str, "Serializable"] SerializablePrimitive | list["Serializable"] | dict[str, "Serializable"]
] )
def to_serializable( def to_serializable(
@@ -24,9 +26,10 @@ def to_serializable(
Non-convertible objects default to their string representations. Non-convertible objects default to their string representations.
Args: Args:
obj (Any): Object to transform. obj: Object to transform.
exclude (set[str], optional): Set of keys to exclude from the result. exclude: Set of keys to exclude from the result.
max_depth (int, optional): Maximum recursion depth. Defaults to 5. max_depth: Maximum recursion depth. Defaults to 5.
_current_depth: Current recursion depth (for internal use).
Returns: Returns:
Serializable: A JSON-compatible structure. Serializable: A JSON-compatible structure.
@@ -39,18 +42,18 @@ def to_serializable(
if isinstance(obj, (str, int, float, bool, type(None))): if isinstance(obj, (str, int, float, bool, type(None))):
return obj return obj
elif isinstance(obj, uuid.UUID): if isinstance(obj, uuid.UUID):
return str(obj) return str(obj)
elif isinstance(obj, (date, datetime)): if isinstance(obj, (date, datetime)):
return obj.isoformat() return obj.isoformat()
elif isinstance(obj, (list, tuple, set)): if isinstance(obj, (list, tuple, set)):
return [ return [
to_serializable( to_serializable(
item, max_depth=max_depth, _current_depth=_current_depth + 1 item, max_depth=max_depth, _current_depth=_current_depth + 1
) )
for item in obj for item in obj
] ]
elif isinstance(obj, dict): if isinstance(obj, dict):
return { return {
_to_serializable_key(key): to_serializable( _to_serializable_key(key): to_serializable(
obj=value, obj=value,
@@ -61,33 +64,31 @@ def to_serializable(
for key, value in obj.items() for key, value in obj.items()
if key not in exclude if key not in exclude
} }
elif isinstance(obj, BaseModel): if isinstance(obj, BaseModel):
return to_serializable( return to_serializable(
obj=obj.model_dump(exclude=exclude), obj=obj.model_dump(exclude=exclude),
max_depth=max_depth, max_depth=max_depth,
_current_depth=_current_depth + 1, _current_depth=_current_depth + 1,
) )
else: return repr(obj)
return repr(obj)
def _to_serializable_key(key: Any) -> str: def _to_serializable_key(key: Any) -> str:
if isinstance(key, (str, int)): if isinstance(key, (str, int)):
return str(key) return str(key)
return f"key_{id(key)}_{repr(key)}" return f"key_{id(key)}_{key!r}"
def to_string(obj: Any) -> str | None: def to_string(obj: Any) -> str | None:
"""Serializes an object into a JSON string. """Serializes an object into a JSON string.
Args: Args:
obj (Any): Object to serialize. obj: Object to serialize.
Returns: Returns:
str | None: A JSON-formatted string or `None` if empty. A JSON-formatted string or `None` if empty.
""" """
serializable = to_serializable(obj) serializable = to_serializable(obj)
if serializable is None: if serializable is None:
return None return None
else: return json.dumps(serializable)
return json.dumps(serializable)
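A small usage sketch of the two helpers, assuming they live in crewai.utilities.serialization:

import uuid
from datetime import datetime

from pydantic import BaseModel

from crewai.utilities.serialization import to_serializable, to_string  # assumed path

class Report(BaseModel):
    id: uuid.UUID
    created_at: datetime
    tags: set[str]

report = Report(id=uuid.uuid4(), created_at=datetime.now(), tags={"draft"})

# UUIDs and datetimes become strings, sets become lists, BaseModels become dicts.
print(to_serializable(report))
# to_string wraps the same structure in json.dumps, returning None when the value serializes to None.
print(to_string({"report": report, "count": 3}))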

View File

@@ -1,10 +1,12 @@
import re import re
from typing import Any, Dict, List, Optional, Union from typing import Any, Final
_VARIABLE_PATTERN: Final[re.Pattern[str]] = re.compile(r"\{([A-Za-z_][A-Za-z0-9_\-]*)}")
def interpolate_only( def interpolate_only(
input_string: Optional[str], input_string: str | None,
inputs: Dict[str, Union[str, int, float, Dict[str, Any], List[Any]]], inputs: dict[str, str | int | float | dict[str, Any] | list[Any]],
) -> str: ) -> str:
"""Interpolate placeholders (e.g., {key}) in a string while leaving JSON untouched. """Interpolate placeholders (e.g., {key}) in a string while leaving JSON untouched.
Only interpolates placeholders that follow the pattern {variable_name} where Only interpolates placeholders that follow the pattern {variable_name} where
@@ -26,26 +28,30 @@ def interpolate_only(
""" """
# Validation function for recursive type checking # Validation function for recursive type checking
def validate_type(value: Any) -> None: def _validate_type(validate_value: Any) -> None:
if value is None: if validate_value is None:
return return
if isinstance(value, (str, int, float, bool)): if isinstance(validate_value, (str, int, float, bool)):
return return
if isinstance(value, (dict, list)): if isinstance(validate_value, (dict, list)):
for item in value.values() if isinstance(value, dict) else value: for item in (
validate_type(item) validate_value.values()
if isinstance(validate_value, dict)
else validate_value
):
_validate_type(item)
return return
raise ValueError( raise ValueError(
f"Unsupported type {type(value).__name__} in inputs. " f"Unsupported type {type(validate_value).__name__} in inputs. "
"Only str, int, float, bool, dict, and list are allowed." "Only str, int, float, bool, dict, and list are allowed."
) )
# Validate all input values # Validate all input values
for key, value in inputs.items(): for key, value in inputs.items():
try: try:
validate_type(value) _validate_type(value)
except ValueError as e: except ValueError as e: # noqa: PERF203
raise ValueError(f"Invalid value for key '{key}': {str(e)}") from e raise ValueError(f"Invalid value for key '{key}': {e!s}") from e
if input_string is None or not input_string: if input_string is None or not input_string:
return "" return ""
@@ -56,13 +62,7 @@ def interpolate_only(
"Inputs dictionary cannot be empty when interpolating variables" "Inputs dictionary cannot be empty when interpolating variables"
) )
# The regex pattern to find valid variable placeholders variables = _VARIABLE_PATTERN.findall(input_string)
# Matches {variable_name} where variable_name starts with a letter/underscore
# and contains only letters, numbers, and underscores
pattern = r"\{([A-Za-z_][A-Za-z0-9_\-]*)\}"
# Find all matching variables in the input string
variables = re.findall(pattern, input_string)
result = input_string result = input_string
# Check if all variables exist in inputs # Check if all variables exist in inputs
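For reference, a short sketch of the interpolation behaviour; the module path is an assumption:

from crewai.utilities.string_utils import interpolate_only  # assumed module path

template = 'Summarize {topic} for {audience}. Keep literal JSON like {"nested": 1} untouched.'
result = interpolate_only(
    input_string=template,
    inputs={"topic": "Q3 earnings", "audience": "executives"},
)
print(result)  # only {topic} and {audience} are replaced; {"nested": 1} does not match the pattern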

View File

@@ -4,50 +4,14 @@ This module provides functionality for storing and retrieving task outputs
from persistent storage, supporting replay and audit capabilities. from persistent storage, supporting replay and audit capabilities.
""" """
from datetime import datetime
from typing import Any from typing import Any
from pydantic import BaseModel, Field
from crewai.memory.storage.kickoff_task_outputs_storage import ( from crewai.memory.storage.kickoff_task_outputs_storage import (
KickoffTaskOutputsSQLiteStorage, KickoffTaskOutputsSQLiteStorage,
) )
from crewai.task import Task from crewai.task import Task
class ExecutionLog(BaseModel):
"""Represents a log entry for task execution.
Attributes:
task_id: Unique identifier for the task.
expected_output: The expected output description for the task.
output: The actual output produced by the task.
timestamp: When the task was executed.
task_index: The position of the task in the execution sequence.
inputs: Input parameters provided to the task.
was_replayed: Whether this output was replayed from a previous run.
"""
task_id: str
expected_output: str | None = None
output: dict[str, Any]
timestamp: datetime = Field(default_factory=datetime.now)
task_index: int
inputs: dict[str, Any] = Field(default_factory=dict)
was_replayed: bool = False
def __getitem__(self, key: str) -> Any:
"""Enable dictionary-style access to execution log attributes.
Args:
key: The attribute name to access.
Returns:
The value of the requested attribute.
"""
return getattr(self, key)
class TaskOutputStorageHandler: class TaskOutputStorageHandler:
"""Manages storage and retrieval of task outputs. """Manages storage and retrieval of task outputs.

View File

@@ -4,13 +4,13 @@ This module provides a callback handler that tracks token usage
for LLM API calls through the litellm library. for LLM API calls through the litellm library.
""" """
import warnings
from typing import Any from typing import Any
from litellm.integrations.custom_logger import CustomLogger from litellm.integrations.custom_logger import CustomLogger
from litellm.types.utils import Usage from litellm.types.utils import Usage
from crewai.agents.agent_builder.utilities.base_token_process import TokenProcess from crewai.agents.agent_builder.utilities.base_token_process import TokenProcess
from crewai.utilities.logger_utils import suppress_warnings
class TokenCalcHandler(CustomLogger): class TokenCalcHandler(CustomLogger):
@@ -23,12 +23,13 @@ class TokenCalcHandler(CustomLogger):
token_cost_process: The token process tracker to accumulate usage metrics. token_cost_process: The token process tracker to accumulate usage metrics.
""" """
def __init__(self, token_cost_process: TokenProcess | None) -> None: def __init__(self, token_cost_process: TokenProcess | None, **kwargs: Any) -> None:
"""Initialize the token calculation handler. """Initialize the token calculation handler.
Args: Args:
token_cost_process: Optional token process tracker for accumulating metrics. token_cost_process: Optional token process tracker for accumulating metrics.
""" """
super().__init__(**kwargs)
self.token_cost_process = token_cost_process self.token_cost_process = token_cost_process
def log_success_event( def log_success_event(
@@ -49,8 +50,7 @@ class TokenCalcHandler(CustomLogger):
if self.token_cost_process is None: if self.token_cost_process is None:
return return
with warnings.catch_warnings(): with suppress_warnings():
warnings.simplefilter("ignore", UserWarning)
if isinstance(response_obj, dict) and "usage" in response_obj: if isinstance(response_obj, dict) and "usage" in response_obj:
usage: Usage = response_obj["usage"] usage: Usage = response_obj["usage"]
if usage: if usage:
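A minimal wiring sketch for the handler; registering it via litellm.callbacks and the get_summary() accessor are assumptions based on typical usage, not shown in this diff:

import litellm

from crewai.agents.agent_builder.utilities.base_token_process import TokenProcess
from crewai.utilities.token_counter_callback import TokenCalcHandler

token_process = TokenProcess()
handler = TokenCalcHandler(token_cost_process=token_process)

# litellm invokes log_success_event on registered CustomLogger instances after each completion.
litellm.callbacks = [handler]

# ... run completions here ...
print(token_process.get_summary())  # assumed accessor returning the accumulated usage metrics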

View File

@@ -1,12 +1,21 @@
from typing import Any from __future__ import annotations
from typing import TYPE_CHECKING
from crewai.agents.parser import AgentAction from crewai.agents.parser import AgentAction
from crewai.security import Fingerprint from crewai.agents.tools_handler import ToolsHandler
from crewai.security.fingerprint import Fingerprint
from crewai.tools.structured_tool import CrewStructuredTool from crewai.tools.structured_tool import CrewStructuredTool
from crewai.tools.tool_types import ToolResult from crewai.tools.tool_types import ToolResult
from crewai.tools.tool_usage import ToolUsage, ToolUsageError from crewai.tools.tool_usage import ToolUsage, ToolUsageError
from crewai.utilities.i18n import I18N from crewai.utilities.i18n import I18N
if TYPE_CHECKING:
from crewai.agent import Agent
from crewai.llm import LLM
from crewai.llms.base_llm import BaseLLM
from crewai.task import Task
def execute_tool_and_check_finality( def execute_tool_and_check_finality(
agent_action: AgentAction, agent_action: AgentAction,
@@ -14,10 +23,10 @@ def execute_tool_and_check_finality(
i18n: I18N, i18n: I18N,
agent_key: str | None = None, agent_key: str | None = None,
agent_role: str | None = None, agent_role: str | None = None,
tools_handler: Any | None = None, tools_handler: ToolsHandler | None = None,
task: Any | None = None, task: Task | None = None,
agent: Any | None = None, agent: Agent | None = None,
function_calling_llm: Any | None = None, function_calling_llm: BaseLLM | LLM | None = None,
fingerprint_context: dict[str, str] | None = None, fingerprint_context: dict[str, str] | None = None,
) -> ToolResult: ) -> ToolResult:
"""Execute a tool and check if the result should be treated as a final answer. """Execute a tool and check if the result should be treated as a final answer.
@@ -32,59 +41,54 @@ def execute_tool_and_check_finality(
task: Optional task for tool execution task: Optional task for tool execution
agent: Optional agent instance for tool execution agent: Optional agent instance for tool execution
function_calling_llm: Optional LLM for function calling function_calling_llm: Optional LLM for function calling
fingerprint_context: Optional context for fingerprinting
Returns: Returns:
ToolResult containing the execution result and whether it should be treated as a final answer ToolResult containing the execution result and whether it should be treated as a final answer
""" """
try: tool_name_to_tool_map = {tool.name: tool for tool in tools}
tool_name_to_tool_map = {tool.name: tool for tool in tools}
if agent_key and agent_role and agent: if agent_key and agent_role and agent:
fingerprint_context = fingerprint_context or {} fingerprint_context = fingerprint_context or {}
if agent: if agent:
if hasattr(agent, "set_fingerprint") and callable( if hasattr(agent, "set_fingerprint") and callable(agent.set_fingerprint):
agent.set_fingerprint if isinstance(fingerprint_context, dict):
): try:
if isinstance(fingerprint_context, dict): fingerprint_obj = Fingerprint.from_dict(fingerprint_context)
try: agent.set_fingerprint(fingerprint=fingerprint_obj)
fingerprint_obj = Fingerprint.from_dict(fingerprint_context) except Exception as e:
agent.set_fingerprint(fingerprint_obj) raise ValueError(f"Failed to set fingerprint: {e}") from e
except Exception as e:
raise ValueError(f"Failed to set fingerprint: {e}") from e
# Create tool usage instance # Create tool usage instance
tool_usage = ToolUsage( tool_usage = ToolUsage(
tools_handler=tools_handler, tools_handler=tools_handler,
tools=tools, tools=tools,
function_calling_llm=function_calling_llm, function_calling_llm=function_calling_llm,
task=task, task=task,
agent=agent, agent=agent,
action=agent_action, action=agent_action,
) )
# Parse tool calling # Parse tool calling
tool_calling = tool_usage.parse_tool_calling(agent_action.text) tool_calling = tool_usage.parse_tool_calling(agent_action.text)
if isinstance(tool_calling, ToolUsageError): if isinstance(tool_calling, ToolUsageError):
return ToolResult(tool_calling.message, False) return ToolResult(tool_calling.message, False)
# Check if tool name matches # Check if tool name matches
if tool_calling.tool_name.casefold().strip() in [ if tool_calling.tool_name.casefold().strip() in [
name.casefold().strip() for name in tool_name_to_tool_map name.casefold().strip() for name in tool_name_to_tool_map
] or tool_calling.tool_name.casefold().replace("_", " ") in [ ] or tool_calling.tool_name.casefold().replace("_", " ") in [
name.casefold().strip() for name in tool_name_to_tool_map name.casefold().strip() for name in tool_name_to_tool_map
]: ]:
tool_result = tool_usage.use(tool_calling, agent_action.text) tool_result = tool_usage.use(tool_calling, agent_action.text)
tool = tool_name_to_tool_map.get(tool_calling.tool_name) tool = tool_name_to_tool_map.get(tool_calling.tool_name)
if tool: if tool:
return ToolResult(tool_result, tool.result_as_answer) return ToolResult(tool_result, tool.result_as_answer)
# Handle invalid tool name # Handle invalid tool name
tool_result = i18n.errors("wrong_tool_name").format( tool_result = i18n.errors("wrong_tool_name").format(
tool=tool_calling.tool_name, tool=tool_calling.tool_name,
tools=", ".join([tool.name.casefold() for tool in tools]), tools=", ".join([tool.name.casefold() for tool in tools]),
) )
return ToolResult(tool_result, False) return ToolResult(result=tool_result, result_as_answer=False)
except Exception as e:
raise e
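A hedged call-site sketch of the newly typed signature; the action, tools, agent, and task objects are assumed to already exist inside the executor loop, so this is illustrative rather than standalone:

# Inside a hypothetical agent executor step:
result = execute_tool_and_check_finality(
    agent_action=action,              # AgentAction parsed from the LLM output
    tools=structured_tools,           # list[CrewStructuredTool] available to the agent
    i18n=I18N(),
    agent_key=agent.key,
    agent_role=agent.role,
    tools_handler=agent.tools_handler,
    task=task,
    agent=agent,
)
if result.result_as_answer:
    final_answer = result.result      # the tool output is treated as the final answer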

View File

@@ -1,42 +1,72 @@
import json import json
import re import re
from typing import Any, get_origin from typing import Any, Final, get_origin
from pydantic import BaseModel, ValidationError from pydantic import BaseModel, ValidationError
from crewai.utilities.converter import Converter, ConverterError from crewai.utilities.converter import Converter, ConverterError
_FLOAT_PATTERN: Final[re.Pattern[str]] = re.compile(r"(\d+(?:\.\d+)?)")
class TrainingConverter(Converter): class TrainingConverter(Converter):
""" """A specialized converter for smaller LLMs (up to 7B parameters) that handles validation errors
A specialized converter for smaller LLMs (up to 7B parameters) that handles validation errors
by breaking down the model into individual fields and querying the LLM for each field separately. by breaking down the model into individual fields and querying the LLM for each field separately.
""" """
def to_pydantic(self, current_attempt=1) -> BaseModel: def to_pydantic(self, current_attempt: int = 1) -> BaseModel:
"""Convert the text to a Pydantic model, with fallback to field-by-field extraction on failure.
Args:
current_attempt: The current attempt number for conversion.
Returns:
An instance of the Pydantic model.
Raises:
ConverterError: If conversion fails after field-by-field extraction.
"""
try: try:
return super().to_pydantic(current_attempt) return super().to_pydantic(current_attempt)
except ConverterError: except ConverterError:
return self._convert_field_by_field() return self._convert_field_by_field()
def _convert_field_by_field(self) -> BaseModel: def _convert_field_by_field(self) -> BaseModel:
field_values = {} field_values: dict[str, Any] = {}
for field_name, field_info in self.model.model_fields.items(): for field_name, field_info in self.model.model_fields.items():
field_description = field_info.description field_description: str | None = field_info.description
field_type = field_info.annotation field_type: type | None = field_info.annotation
response = self._ask_llm_for_field(field_name, field_description) if field_description is None:
value = self._process_field_value(response, field_type) raise ValueError(f"Field '{field_name}' has no description")
response: str = self._ask_llm_for_field(
field_name=field_name, field_description=field_description
)
value: Any = self._process_field_value(
response=response, field_type=field_type
)
field_values[field_name] = value field_values[field_name] = value
try: try:
return self.model(**field_values) return self.model(**field_values)
except ValidationError as e: except ValidationError as e:
raise ConverterError(f"Failed to create model from individually collected fields: {e}") raise ConverterError(
f"Failed to create model from individually collected fields: {e}"
) from e
def _ask_llm_for_field(self, field_name: str, field_description: str) -> str: def _ask_llm_for_field(self, field_name: str, field_description: str) -> str:
prompt = f""" """Query the LLM for a specific field value based on its description.
Args:
field_name: The name of the field to extract.
field_description: The description of the field to guide extraction.
Returns:
The LLM's response containing the field value.
"""
prompt: str = f"""
Based on the following information: Based on the following information:
{self.text} {self.text}
@@ -45,14 +75,19 @@ Please provide ONLY the {field_name} field value as described:
Respond with ONLY the requested information, nothing else. Respond with ONLY the requested information, nothing else.
""" """
return self.llm.call([ return self.llm.call(
{"role": "system", "content": f"Extract the {field_name} from the previous information."}, [
{"role": "user", "content": prompt} {
]) "role": "system",
"content": f"Extract the {field_name} from the previous information.",
},
{"role": "user", "content": prompt},
]
)
def _process_field_value(self, response: str, field_type: Any) -> Any: def _process_field_value(self, response: str, field_type: type | None) -> Any:
response = response.strip() response = response.strip()
origin = get_origin(field_type) origin: type[Any] | None = get_origin(field_type)
if origin is list: if origin is list:
return self._parse_list(response) return self._parse_list(response)
@@ -65,25 +100,45 @@ Respond with ONLY the requested information, nothing else.
return response return response
def _parse_list(self, response: str) -> list: def _parse_list(self, response: str) -> list[Any]:
try: try:
if response.startswith('['): if response.startswith("["):
return json.loads(response) return json.loads(response)
items = [item.strip() for item in response.split('\n') if item.strip()] items: list[str] = [
item.strip() for item in response.split("\n") if item.strip()
]
return [self._strip_bullet(item) for item in items] return [self._strip_bullet(item) for item in items]
except json.JSONDecodeError: except json.JSONDecodeError:
return [response] return [response]
def _parse_float(self, response: str) -> float: @staticmethod
def _parse_float(response: str) -> float:
"""Parse a float from the response, extracting the first numeric value found.
Args:
response: The response string from which to extract the float.
Returns:
The extracted float value, or 0.0 if no valid float is found.
"""
try: try:
match = re.search(r'(\d+(\.\d+)?)', response) match = _FLOAT_PATTERN.search(response)
return float(match.group(1)) if match else 0.0 return float(match.group(1)) if match else 0.0
except Exception: except (ValueError, AttributeError):
return 0.0 return 0.0
def _strip_bullet(self, item: str) -> str: @staticmethod
if item.startswith(('- ', '* ')): def _strip_bullet(item: str) -> str:
"""Remove common bullet point characters from the start of a string.
Args:
item: The string item to process.
Returns:
The string without leading bullet characters.
"""
if item.startswith(("- ", "* ")):
return item[2:].strip() return item[2:].strip()
return item.strip() return item.strip()
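An illustrative sketch of the field-by-field fallback; the module path, model name, and instruction text are assumptions:

from pydantic import BaseModel, Field

from crewai.llm import LLM
from crewai.utilities.training_converter import TrainingConverter  # assumed path

class Feedback(BaseModel):
    # Every field needs a description: the fallback now raises if one is missing.
    quality: float = Field(description="Quality score between 0 and 10")
    suggestions: list[str] = Field(description="Concrete improvement suggestions")

converter = TrainingConverter(
    llm=LLM(model="gpt-4o-mini"),
    text="The answer was mostly correct but missed two edge cases around empty input.",
    model=Feedback,
    instructions="Extract structured training feedback from the text.",
)

# Falls back to one LLM call per field if the whole-model conversion raises ConverterError.
feedback = converter.to_pydantic()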

View File

@@ -1,35 +1,33 @@
import os import os
from typing import Any
from crewai.utilities.file_handler import PickleHandler from crewai.utilities.file_handler import PickleHandler
class CrewTrainingHandler(PickleHandler): class CrewTrainingHandler(PickleHandler):
def save_trained_data(self, agent_id: str, trained_data: dict) -> None: def save_trained_data(self, agent_id: str, trained_data: dict[int, Any]) -> None:
""" """Save the trained data for a specific agent.
Save the trained data for a specific agent.
Parameters: Args:
- agent_id (str): The ID of the agent. agent_id: The ID of the agent.
- trained_data (dict): The trained data to be saved. trained_data: The trained data to be saved.
""" """
data = self.load() data = self.load()
data[agent_id] = trained_data data[agent_id] = trained_data
self.save(data) self.save(data)
def append(self, train_iteration: int, agent_id: str, new_data) -> None: def append(self, train_iteration: int, agent_id: str, new_data: Any) -> None:
""" """Append new training data for a specific agent and iteration.
Append new data to the existing pickle file.
Parameters: Args:
- new_data (object): The new data to be appended. train_iteration: The training iteration number.
agent_id: The ID of the agent.
new_data: The new training data to append.
""" """
data = self.load() data = self.load()
if agent_id not in data:
if agent_id in data: data[agent_id] = {}
data[agent_id][train_iteration] = new_data data[agent_id][train_iteration] = new_data
else:
data[agent_id] = {train_iteration: new_data}
self.save(data) self.save(data)
def clear(self) -> None: def clear(self) -> None:
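A small sketch of the simplified append path; the file name and payloads are illustrative, and the module path is assumed:

from crewai.utilities.training_handler import CrewTrainingHandler  # assumed module path

handler = CrewTrainingHandler("training_data.pkl")
handler.save_trained_data(
    agent_id="researcher",
    trained_data={0: {"quality": 7.5, "suggestions": ["Cite sources"]}},
)
# append now creates the per-agent dict on first use instead of branching on membership.
handler.append(train_iteration=1, agent_id="researcher", new_data={"quality": 8.0})
print(handler.load())
handler.clear()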

View File

@@ -0,0 +1,15 @@
"""Types for CrewAI utilities."""
from typing import Literal, TypedDict
class LLMMessage(TypedDict):
"""Type for formatted LLM messages.
Notes:
- TODO: Update the LLM.call & BaseLLM.call signatures to use this type
instead of str | list[dict[str, str]]
"""
role: Literal["user", "assistant", "system"]
content: str
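A quick sketch of the new TypedDict in use; the module path is assumed from the diff's location under crewai/utilities:

from crewai.utilities.types import LLMMessage  # assumed module path

messages: list[LLMMessage] = [
    {"role": "system", "content": "You are a concise assistant."},
    {"role": "user", "content": "Summarize the release notes in one sentence."},
]
# Until the TODO in the docstring lands, LLM.call still accepts these as plain dicts:
# response = llm.call(messages)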

View File

@@ -1,11 +1,11 @@
"""Tests for reasoning in agents.""" """Tests for reasoning in agents."""
import json import json
import pytest import pytest
from crewai import Agent, Task from crewai import Agent, Task
from crewai.llm import LLM from crewai.llm import LLM
from crewai.utilities.reasoning_handler import AgentReasoning
@pytest.fixture @pytest.fixture
@@ -79,10 +79,8 @@ def test_agent_with_reasoning_not_ready_initially(mock_llm_responses):
call_count[0] += 1 call_count[0] += 1
if call_count[0] == 1: if call_count[0] == 1:
return mock_llm_responses["not_ready"] return mock_llm_responses["not_ready"]
else: return mock_llm_responses["ready_after_refine"]
return mock_llm_responses["ready_after_refine"] return "2x"
else:
return "2x"
agent.llm.call = mock_llm_call agent.llm.call = mock_llm_call
@@ -121,8 +119,7 @@ def test_agent_with_reasoning_max_attempts_reached():
) or any("refine your plan" in msg.get("content", "") for msg in messages): ) or any("refine your plan" in msg.get("content", "") for msg in messages):
call_count[0] += 1 call_count[0] += 1
return f"Attempt {call_count[0]}: I need more time to think.\n\nNOT READY: I need to refine my plan further." return f"Attempt {call_count[0]}: I need more time to think.\n\nNOT READY: I need to refine my plan further."
else: return "This is an unsolved problem in mathematics."
return "This is an unsolved problem in mathematics."
agent.llm.call = mock_llm_call agent.llm.call = mock_llm_call
@@ -135,26 +132,6 @@ def test_agent_with_reasoning_max_attempts_reached():
assert "Reasoning Plan:" in task.description assert "Reasoning Plan:" in task.description
def test_agent_reasoning_input_validation():
"""Test input validation in AgentReasoning."""
llm = LLM("gpt-3.5-turbo")
agent = Agent(
role="Test Agent",
goal="To test the reasoning feature",
backstory="I am a test agent created to verify the reasoning feature works correctly.",
llm=llm,
reasoning=True,
)
with pytest.raises(ValueError, match="Both task and agent must be provided"):
AgentReasoning(task=None, agent=agent)
task = Task(description="Simple task", expected_output="Simple output")
with pytest.raises(ValueError, match="Both task and agent must be provided"):
AgentReasoning(task=task, agent=None)
def test_agent_reasoning_error_handling(): def test_agent_reasoning_error_handling():
"""Test error handling during the reasoning process.""" """Test error handling during the reasoning process."""
llm = LLM("gpt-3.5-turbo") llm = LLM("gpt-3.5-turbo")
@@ -215,8 +192,7 @@ def test_agent_with_function_calling():
return json.dumps( return json.dumps(
{"plan": "I'll solve this simple math problem: 2+2=4.", "ready": True} {"plan": "I'll solve this simple math problem: 2+2=4.", "ready": True}
) )
else: return "4"
return "4"
agent.llm.call = mock_function_call agent.llm.call = mock_function_call
@@ -251,8 +227,7 @@ def test_agent_with_function_calling_fallback():
def mock_function_call(messages, *args, **kwargs): def mock_function_call(messages, *args, **kwargs):
if "tools" in kwargs: if "tools" in kwargs:
return "Invalid JSON that will trigger fallback. READY: I am ready to execute the task." return "Invalid JSON that will trigger fallback. READY: I am ready to execute the task."
else: return "4"
return "4"
agent.llm.call = mock_function_call agent.llm.call = mock_function_call

View File

@@ -39,7 +39,7 @@ def short_term_memory():
def test_short_term_memory_search_events(short_term_memory): def test_short_term_memory_search_events(short_term_memory):
events = defaultdict(list) events = defaultdict(list)
with patch("crewai.rag.chromadb.client.ChromaDBClient.search", return_value=[]): with patch.object(short_term_memory.storage, "search", return_value=[]):
with crewai_event_bus.scoped_handlers(): with crewai_event_bus.scoped_handlers():
@crewai_event_bus.on(MemoryQueryStartedEvent) @crewai_event_bus.on(MemoryQueryStartedEvent)

View File

@@ -7,15 +7,14 @@ import pytest
from pydantic import BaseModel from pydantic import BaseModel
from crewai.agents.agent_builder.utilities.base_token_process import TokenProcess from crewai.agents.agent_builder.utilities.base_token_process import TokenProcess
from crewai.llm import CONTEXT_WINDOW_USAGE_RATIO, LLM
from crewai.events.event_types import ( from crewai.events.event_types import (
LLMCallCompletedEvent, LLMCallCompletedEvent,
LLMStreamChunkEvent, LLMStreamChunkEvent,
ToolUsageStartedEvent,
ToolUsageFinishedEvent,
ToolUsageErrorEvent, ToolUsageErrorEvent,
ToolUsageFinishedEvent,
ToolUsageStartedEvent,
) )
from crewai.llm import CONTEXT_WINDOW_USAGE_RATIO, LLM
from crewai.utilities.token_counter_callback import TokenCalcHandler from crewai.utilities.token_counter_callback import TokenCalcHandler
@@ -376,11 +375,11 @@ def get_weather_tool_schema():
def test_context_window_exceeded_error_handling(): def test_context_window_exceeded_error_handling():
"""Test that litellm.ContextWindowExceededError is converted to LLMContextLengthExceededException.""" """Test that litellm.ContextWindowExceededError is converted to LLMContextLengthExceededError."""
from litellm.exceptions import ContextWindowExceededError from litellm.exceptions import ContextWindowExceededError
from crewai.utilities.exceptions.context_window_exceeding_exception import ( from crewai.utilities.exceptions.context_window_exceeding_exception import (
LLMContextLengthExceededException, LLMContextLengthExceededError,
) )
llm = LLM(model="gpt-4") llm = LLM(model="gpt-4")
@@ -393,7 +392,7 @@ def test_context_window_exceeded_error_handling():
llm_provider="openai", llm_provider="openai",
) )
with pytest.raises(LLMContextLengthExceededException) as excinfo: with pytest.raises(LLMContextLengthExceededError) as excinfo:
llm.call("This is a test message") llm.call("This is a test message")
assert "context length exceeded" in str(excinfo.value).lower() assert "context length exceeded" in str(excinfo.value).lower()
@@ -408,7 +407,7 @@ def test_context_window_exceeded_error_handling():
llm_provider="openai", llm_provider="openai",
) )
with pytest.raises(LLMContextLengthExceededException) as excinfo: with pytest.raises(LLMContextLengthExceededError) as excinfo:
llm.call("This is a test message") llm.call("This is a test message")
assert "context length exceeded" in str(excinfo.value).lower() assert "context length exceeded" in str(excinfo.value).lower()