chore: improve typing and consolidate utilities

- add type annotations across utility modules  
- refactor printer system, agent utils, and imports for consistency  
- remove unused modules, constants, and redundant patterns  
- improve runtime type checks, exception handling, and guardrail validation  
- standardize warning suppression and logging utilities  
- fix llm typing, threading/typing edge cases, and test behavior
Greyson LaLonde, committed via GitHub, 2025-09-23 11:33:46 -04:00
parent 34bed359a6, commit 3e97393f58
47 changed files with 1939 additions and 1233 deletions
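The "standardize warning suppression" bullet refers to the shared context manager in crewai.utilities.logger_utils that llm.py switches to below. A minimal sketch of the consolidated pattern, mirroring the usage shown in that hunk:

from crewai.utilities.logger_utils import suppress_warnings

with suppress_warnings():
    import litellm  # noisy third-party import; warnings are silenced only inside the block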

View File

@@ -18,7 +18,7 @@ from crewai.agents.constants import (
MISSING_ACTION_INPUT_AFTER_ACTION_ERROR_MESSAGE,
UNABLE_TO_REPAIR_JSON_RESULTS,
)
from crewai.utilities import I18N
from crewai.utilities.i18n import I18N
_I18N = I18N()

View File

@@ -1,28 +1,26 @@
import io
import json
import logging
import os
import sys
import threading
import warnings
from collections import defaultdict
from contextlib import contextmanager
from collections.abc import Callable
from datetime import datetime
from typing import (
Any,
DefaultDict,
Dict,
List,
Final,
Literal,
Optional,
Type,
TextIO,
TypedDict,
Union,
cast,
)
from datetime import datetime
from dotenv import load_dotenv
from litellm.types.utils import ChatCompletionDeltaToolCall
from pydantic import BaseModel, Field
from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.llm_events import (
LLMCallCompletedEvent,
LLMCallFailedEvent,
@@ -31,15 +29,19 @@ from crewai.events.types.llm_events import (
LLMStreamChunkEvent,
)
from crewai.events.types.tool_usage_events import (
ToolUsageStartedEvent,
ToolUsageFinishedEvent,
ToolUsageErrorEvent,
ToolUsageFinishedEvent,
ToolUsageStartedEvent,
)
from crewai.llms.base_llm import BaseLLM
from crewai.utilities.exceptions.context_window_exceeding_exception import (
LLMContextLengthExceededError,
)
from crewai.utilities.logger_utils import suppress_warnings
with warnings.catch_warnings():
warnings.simplefilter("ignore", UserWarning)
with suppress_warnings():
import litellm
from litellm import Choices
from litellm import Choices, CustomLogger
from litellm.exceptions import ContextWindowExceededError
from litellm.litellm_core_utils.get_supported_openai_params import (
get_supported_openai_params,
@@ -47,16 +49,6 @@ with warnings.catch_warnings():
from litellm.types.utils import ModelResponse
from litellm.utils import supports_response_schema
import io
from typing import TextIO
from crewai.llms.base_llm import BaseLLM
from crewai.events.event_bus import crewai_event_bus
from crewai.utilities.exceptions.context_window_exceeding_exception import (
LLMContextLengthExceededException,
)
load_dotenv()
litellm.suppress_debug_info = True
@@ -126,7 +118,11 @@ if not isinstance(sys.stderr, FilteredStream):
sys.stderr = FilteredStream(sys.stderr)
LLM_CONTEXT_WINDOW_SIZES = {
MIN_CONTEXT: Final[int] = 1024
MAX_CONTEXT: Final[int] = 2097152 # Current max from gemini-1.5-pro
ANTHROPIC_PREFIXES: Final[tuple[str, str, str]] = ("anthropic/", "claude-", "claude/")
LLM_CONTEXT_WINDOW_SIZES: Final[dict[str, int]] = {
# openai
"gpt-4": 8192,
"gpt-4o": 128000,
@@ -252,30 +248,19 @@ LLM_CONTEXT_WINDOW_SIZES = {
"mistral/mistral-large-2402": 32768,
}
DEFAULT_CONTEXT_WINDOW_SIZE = 8192
CONTEXT_WINDOW_USAGE_RATIO = 0.85
@contextmanager
def suppress_warnings():
with warnings.catch_warnings():
warnings.filterwarnings("ignore")
warnings.filterwarnings(
"ignore", message="open_text is deprecated*", category=DeprecationWarning
)
yield
DEFAULT_CONTEXT_WINDOW_SIZE: Final[int] = 8192
CONTEXT_WINDOW_USAGE_RATIO: Final[float] = 0.85
class Delta(TypedDict):
content: Optional[str]
role: Optional[str]
content: str | None
role: str | None
class StreamingChoices(TypedDict):
delta: Delta
index: int
finish_reason: Optional[str]
finish_reason: str | None
class FunctionArgs(BaseModel):
@@ -288,31 +273,31 @@ class AccumulatedToolArgs(BaseModel):
class LLM(BaseLLM):
completion_cost: Optional[float] = None
completion_cost: float | None = None
def __init__(
self,
model: str,
timeout: Optional[Union[float, int]] = None,
temperature: Optional[float] = None,
top_p: Optional[float] = None,
n: Optional[int] = None,
stop: Optional[Union[str, List[str]]] = None,
max_completion_tokens: Optional[int] = None,
max_tokens: Optional[int] = None,
presence_penalty: Optional[float] = None,
frequency_penalty: Optional[float] = None,
logit_bias: Optional[Dict[int, float]] = None,
response_format: Optional[Type[BaseModel]] = None,
seed: Optional[int] = None,
logprobs: Optional[int] = None,
top_logprobs: Optional[int] = None,
base_url: Optional[str] = None,
api_base: Optional[str] = None,
api_version: Optional[str] = None,
api_key: Optional[str] = None,
callbacks: List[Any] | None = None,
reasoning_effort: Optional[Literal["none", "low", "medium", "high"]] = None,
timeout: float | int | None = None,
temperature: float | None = None,
top_p: float | None = None,
n: int | None = None,
stop: str | list[str] | None = None,
max_completion_tokens: int | None = None,
max_tokens: int | None = None,
presence_penalty: float | None = None,
frequency_penalty: float | None = None,
logit_bias: dict[int, float] | None = None,
response_format: type[BaseModel] | None = None,
seed: int | None = None,
logprobs: int | None = None,
top_logprobs: int | None = None,
base_url: str | None = None,
api_base: str | None = None,
api_version: str | None = None,
api_key: str | None = None,
callbacks: list[Any] | None = None,
reasoning_effort: Literal["none", "low", "medium", "high"] | None = None,
stream: bool = False,
**kwargs,
):
@@ -345,7 +330,7 @@ class LLM(BaseLLM):
# Normalize self.stop to always be a List[str]
if stop is None:
self.stop: List[str] = []
self.stop: list[str] = []
elif isinstance(stop, str):
self.stop = [stop]
else:
@@ -354,7 +339,8 @@ class LLM(BaseLLM):
self.set_callbacks(callbacks or [])
self.set_env_callbacks()
def _is_anthropic_model(self, model: str) -> bool:
@staticmethod
def _is_anthropic_model(model: str) -> bool:
"""Determine if the model is from Anthropic provider.
Args:
@@ -363,21 +349,18 @@ class LLM(BaseLLM):
Returns:
bool: True if the model is from Anthropic, False otherwise.
"""
ANTHROPIC_PREFIXES = ("anthropic/", "claude-", "claude/")
return any(prefix in model.lower() for prefix in ANTHROPIC_PREFIXES)
def _prepare_completion_params(
self,
messages: Union[str, List[Dict[str, str]]],
tools: Optional[List[dict]] = None,
) -> Dict[str, Any]:
messages: str | list[dict[str, str]],
tools: list[dict] | None = None,
) -> dict[str, Any]:
"""Prepare parameters for the completion call.
Args:
messages: Input messages for the LLM
tools: Optional list of tool schemas
callbacks: Optional list of callback functions
available_functions: Optional dict of available functions
Returns:
Dict[str, Any]: Parameters for the completion call
@@ -419,11 +402,11 @@ class LLM(BaseLLM):
def _handle_streaming_response(
self,
params: Dict[str, Any],
callbacks: Optional[List[Any]] = None,
available_functions: Optional[Dict[str, Any]] = None,
from_task: Optional[Any] = None,
from_agent: Optional[Any] = None,
params: dict[str, Any],
callbacks: list[Any] | None = None,
available_functions: dict[str, Any] | None = None,
from_task: Any | None = None,
from_agent: Any | None = None,
) -> str:
"""Handle a streaming response from the LLM.
@@ -445,9 +428,8 @@ class LLM(BaseLLM):
last_chunk = None
chunk_count = 0
usage_info = None
tool_calls = None
accumulated_tool_args: DefaultDict[int, AccumulatedToolArgs] = defaultdict(
accumulated_tool_args: defaultdict[int, AccumulatedToolArgs] = defaultdict(
AccumulatedToolArgs
)
@@ -472,16 +454,16 @@ class LLM(BaseLLM):
choices = chunk["choices"]
elif hasattr(chunk, "choices"):
# Check if choices is not a type but an actual attribute with value
if not isinstance(getattr(chunk, "choices"), type):
choices = getattr(chunk, "choices")
if not isinstance(chunk.choices, type):
choices = chunk.choices
# Try to extract usage information if available
if isinstance(chunk, dict) and "usage" in chunk:
usage_info = chunk["usage"]
elif hasattr(chunk, "usage"):
# Check if usage is not a type but an actual attribute with value
if not isinstance(getattr(chunk, "usage"), type):
usage_info = getattr(chunk, "usage")
if not isinstance(chunk.usage, type):
usage_info = chunk.usage
if choices and len(choices) > 0:
choice = choices[0]
@@ -491,7 +473,7 @@ class LLM(BaseLLM):
if isinstance(choice, dict) and "delta" in choice:
delta = choice["delta"]
elif hasattr(choice, "delta"):
delta = getattr(choice, "delta")
delta = choice.delta
# Extract content from delta
if delta:
@@ -501,7 +483,7 @@ class LLM(BaseLLM):
chunk_content = delta["content"]
# Handle object format
elif hasattr(delta, "content"):
chunk_content = getattr(delta, "content")
chunk_content = delta.content
# Handle case where content might be None or empty
if chunk_content is None and isinstance(delta, dict):
@@ -533,7 +515,9 @@ class LLM(BaseLLM):
full_response += chunk_content
# Emit the chunk event
assert hasattr(crewai_event_bus, "emit")
if not hasattr(crewai_event_bus, "emit"):
raise Exception("crewai_event_bus must have an `emit` method")
crewai_event_bus.emit(
self,
event=LLMStreamChunkEvent(
@@ -572,8 +556,8 @@ class LLM(BaseLLM):
if isinstance(last_chunk, dict) and "choices" in last_chunk:
choices = last_chunk["choices"]
elif hasattr(last_chunk, "choices"):
if not isinstance(getattr(last_chunk, "choices"), type):
choices = getattr(last_chunk, "choices")
if not isinstance(last_chunk.choices, type):
choices = last_chunk.choices
if choices and len(choices) > 0:
choice = choices[0]
@@ -583,14 +567,14 @@ class LLM(BaseLLM):
if isinstance(choice, dict) and "message" in choice:
message = choice["message"]
elif hasattr(choice, "message"):
message = getattr(choice, "message")
message = choice.message
if message:
content = None
if isinstance(message, dict) and "content" in message:
content = message["content"]
elif hasattr(message, "content"):
content = getattr(message, "content")
content = message.content
if content:
full_response = content
@@ -617,8 +601,8 @@ class LLM(BaseLLM):
if isinstance(last_chunk, dict) and "choices" in last_chunk:
choices = last_chunk["choices"]
elif hasattr(last_chunk, "choices"):
if not isinstance(getattr(last_chunk, "choices"), type):
choices = getattr(last_chunk, "choices")
if not isinstance(last_chunk.choices, type):
choices = last_chunk.choices
if choices and len(choices) > 0:
choice = choices[0]
@@ -627,13 +611,13 @@ class LLM(BaseLLM):
if isinstance(choice, dict) and "message" in choice:
message = choice["message"]
elif hasattr(choice, "message"):
message = getattr(choice, "message")
message = choice.message
if message:
if isinstance(message, dict) and "tool_calls" in message:
tool_calls = message["tool_calls"]
elif hasattr(message, "tool_calls"):
tool_calls = getattr(message, "tool_calls")
tool_calls = message.tool_calls
except Exception as e:
logging.debug(f"Error checking for tool calls: {e}")
# --- 8) If no tool calls or no available functions, return the text response directly
@@ -673,11 +657,11 @@ class LLM(BaseLLM):
# Catch context window errors from litellm and convert them to our own exception type.
# This exception is handled by CrewAgentExecutor._invoke_loop() which can then
# decide whether to summarize the content or abort based on the respect_context_window flag.
raise LLMContextLengthExceededException(str(e))
raise LLMContextLengthExceededError(str(e)) from e
except Exception as e:
logging.error(f"Error in streaming response: {str(e)}")
logging.error(f"Error in streaming response: {e!s}")
if full_response.strip():
logging.warning(f"Returning partial response despite error: {str(e)}")
logging.warning(f"Returning partial response despite error: {e!s}")
self._handle_emit_call_events(
response=full_response,
call_type=LLMCallType.LLM_CALL,
@@ -688,22 +672,25 @@ class LLM(BaseLLM):
return full_response
# Emit failed event and re-raise the exception
assert hasattr(crewai_event_bus, "emit")
if not hasattr(crewai_event_bus, "emit"):
raise AttributeError(
"crewai_event_bus must have an 'emit' method"
) from e
crewai_event_bus.emit(
self,
event=LLMCallFailedEvent(
error=str(e), from_task=from_task, from_agent=from_agent
),
)
raise Exception(f"Failed to get streaming response: {str(e)}")
raise Exception(f"Failed to get streaming response: {e!s}") from e
def _handle_streaming_tool_calls(
self,
tool_calls: List[ChatCompletionDeltaToolCall],
accumulated_tool_args: DefaultDict[int, AccumulatedToolArgs],
available_functions: Optional[Dict[str, Any]] = None,
from_task: Optional[Any] = None,
from_agent: Optional[Any] = None,
tool_calls: list[ChatCompletionDeltaToolCall],
accumulated_tool_args: defaultdict[int, AccumulatedToolArgs],
available_functions: dict[str, Any] | None = None,
from_task: Any | None = None,
from_agent: Any | None = None,
) -> None | str:
for tool_call in tool_calls:
current_tool_accumulator = accumulated_tool_args[tool_call.index]
@@ -715,7 +702,8 @@ class LLM(BaseLLM):
current_tool_accumulator.function.arguments += (
tool_call.function.arguments
)
assert hasattr(crewai_event_bus, "emit")
if not hasattr(crewai_event_bus, "emit"):
raise AttributeError("crewai_event_bus must have an 'emit' method")
crewai_event_bus.emit(
self,
event=LLMStreamChunkEvent(
@@ -742,11 +730,11 @@ class LLM(BaseLLM):
continue
return None
@staticmethod
def _handle_streaming_callbacks(
self,
callbacks: Optional[List[Any]],
usage_info: Optional[Dict[str, Any]],
last_chunk: Optional[Any],
callbacks: list[Any] | None,
usage_info: dict[str, Any] | None,
last_chunk: Any | None,
) -> None:
"""Handle callbacks with usage info for streaming responses.
@@ -769,10 +757,8 @@ class LLM(BaseLLM):
):
usage_info = last_chunk["usage"]
elif hasattr(last_chunk, "usage"):
if not isinstance(
getattr(last_chunk, "usage"), type
):
usage_info = getattr(last_chunk, "usage")
if not isinstance(last_chunk.usage, type):
usage_info = last_chunk.usage
except Exception as e:
logging.debug(f"Error extracting usage info: {e}")
@@ -786,11 +772,11 @@ class LLM(BaseLLM):
def _handle_non_streaming_response(
self,
params: Dict[str, Any],
callbacks: Optional[List[Any]] = None,
available_functions: Optional[Dict[str, Any]] = None,
from_task: Optional[Any] = None,
from_agent: Optional[Any] = None,
params: dict[str, Any],
callbacks: list[Any] | None = None,
available_functions: dict[str, Any] | None = None,
from_task: Any | None = None,
from_agent: Any | None = None,
) -> str | Any:
"""Handle a non-streaming response from the LLM.
@@ -815,7 +801,7 @@ class LLM(BaseLLM):
except ContextWindowExceededError as e:
# Convert litellm's context window error to our own exception type
# for consistent handling in the rest of the codebase
raise LLMContextLengthExceededException(str(e))
raise LLMContextLengthExceededError(str(e)) from e
# --- 2) Extract response message and content
response_message = cast(Choices, cast(ModelResponse, response).choices)[
0
@@ -847,7 +833,7 @@ class LLM(BaseLLM):
)
return text_response
# --- 6) If there is no text response, no available functions, but there are tool calls, return the tool calls
elif tool_calls and not available_functions and not text_response:
if tool_calls and not available_functions and not text_response:
return tool_calls
# --- 7) Handle tool calls if present
@@ -868,19 +854,21 @@ class LLM(BaseLLM):
def _handle_tool_call(
self,
tool_calls: List[Any],
available_functions: Optional[Dict[str, Any]] = None,
from_task: Optional[Any] = None,
from_agent: Optional[Any] = None,
) -> Optional[str]:
tool_calls: list[Any],
available_functions: dict[str, Any] | None = None,
from_task: Any | None = None,
from_agent: Any | None = None,
) -> str | None:
"""Handle a tool call from the LLM.
Args:
tool_calls: List of tool calls from the LLM
available_functions: Dict of available functions
from_task: Optional Task that invoked the LLM
from_agent: Optional Agent that invoked the LLM
Returns:
Optional[str]: The result of the tool call, or None if no tool call was made
The result of the tool call, or None if no tool call was made
"""
# --- 1) Validate tool calls and available functions
if not tool_calls or not available_functions:
@@ -899,7 +887,8 @@ class LLM(BaseLLM):
fn = available_functions[function_name]
# --- 3.2) Execute function
assert hasattr(crewai_event_bus, "emit")
if not hasattr(crewai_event_bus, "emit"):
raise AttributeError("crewai_event_bus must have an 'emit' method")
started_at = datetime.now()
crewai_event_bus.emit(
self,
@@ -939,17 +928,20 @@ class LLM(BaseLLM):
function_name, lambda: None
) # Ensure fn is always a callable
logging.error(f"Error executing function '{function_name}': {e}")
assert hasattr(crewai_event_bus, "emit")
if not hasattr(crewai_event_bus, "emit"):
raise AttributeError(
"crewai_event_bus must have an 'emit' method"
) from e
crewai_event_bus.emit(
self,
event=LLMCallFailedEvent(error=f"Tool execution error: {str(e)}"),
event=LLMCallFailedEvent(error=f"Tool execution error: {e!s}"),
)
crewai_event_bus.emit(
self,
event=ToolUsageErrorEvent(
tool_name=function_name,
tool_args=function_args,
error=f"Tool execution error: {str(e)}",
error=f"Tool execution error: {e!s}",
from_task=from_task,
from_agent=from_agent,
),
@@ -958,13 +950,13 @@ class LLM(BaseLLM):
def call(
self,
messages: Union[str, List[Dict[str, str]]],
tools: Optional[List[dict]] = None,
callbacks: Optional[List[Any]] = None,
available_functions: Optional[Dict[str, Any]] = None,
from_task: Optional[Any] = None,
from_agent: Optional[Any] = None,
) -> Union[str, Any]:
messages: str | list[dict[str, str]],
tools: list[dict] | None = None,
callbacks: list[Any] | None = None,
available_functions: dict[str, Any] | None = None,
from_task: Any | None = None,
from_agent: Any | None = None,
) -> str | Any:
"""High-level LLM call method.
Args:
@@ -988,10 +980,11 @@ class LLM(BaseLLM):
Raises:
TypeError: If messages format is invalid
ValueError: If response format is not supported
LLMContextLengthExceededException: If input exceeds model's context limit
LLMContextLengthExceededError: If input exceeds model's context limit
"""
# --- 1) Emit call started event
assert hasattr(crewai_event_bus, "emit")
if not hasattr(crewai_event_bus, "emit"):
raise AttributeError("crewai_event_bus must have an 'emit' method")
crewai_event_bus.emit(
self,
event=LLMCallStartedEvent(
@@ -1028,13 +1021,12 @@ class LLM(BaseLLM):
return self._handle_streaming_response(
params, callbacks, available_functions, from_task, from_agent
)
else:
return self._handle_non_streaming_response(
params, callbacks, available_functions, from_task, from_agent
)
return self._handle_non_streaming_response(
params, callbacks, available_functions, from_task, from_agent
)
except LLMContextLengthExceededException:
# Re-raise LLMContextLengthExceededException as it should be handled
except LLMContextLengthExceededError:
# Re-raise LLMContextLengthExceededError as it should be handled
# by the CrewAgentExecutor._invoke_loop method, which can then decide
# whether to summarize the content or abort based on the respect_context_window flag
raise
@@ -1065,7 +1057,10 @@ class LLM(BaseLLM):
from_agent=from_agent,
)
assert hasattr(crewai_event_bus, "emit")
if not hasattr(crewai_event_bus, "emit"):
raise AttributeError(
"crewai_event_bus must have an 'emit' method"
) from e
crewai_event_bus.emit(
self,
event=LLMCallFailedEvent(
@@ -1078,8 +1073,8 @@ class LLM(BaseLLM):
self,
response: Any,
call_type: LLMCallType,
from_task: Optional[Any] = None,
from_agent: Optional[Any] = None,
from_task: Any | None = None,
from_agent: Any | None = None,
messages: str | list[dict[str, Any]] | None = None,
):
"""Handle the events for the LLM call.
@@ -1091,7 +1086,8 @@ class LLM(BaseLLM):
from_agent: Optional agent object
messages: Optional messages object
"""
assert hasattr(crewai_event_bus, "emit")
if not hasattr(crewai_event_bus, "emit"):
raise AttributeError("crewai_event_bus must have an 'emit' method")
crewai_event_bus.emit(
self,
event=LLMCallCompletedEvent(
@@ -1105,8 +1101,8 @@ class LLM(BaseLLM):
)
def _format_messages_for_provider(
self, messages: List[Dict[str, str]]
) -> List[Dict[str, str]]:
self, messages: list[dict[str, str]]
) -> list[dict[str, str]]:
"""Format messages according to provider requirements.
Args:
@@ -1147,7 +1143,7 @@ class LLM(BaseLLM):
if "mistral" in self.model.lower():
# Check if the last message has a role of 'assistant'
if messages and messages[-1]["role"] == "assistant":
return messages + [{"role": "user", "content": "Please continue."}]
return [*messages, {"role": "user", "content": "Please continue."}]
return messages
# TODO: Remove this code after merging PR https://github.com/BerriAI/litellm/pull/10917
@@ -1157,7 +1153,7 @@ class LLM(BaseLLM):
and messages
and messages[-1]["role"] == "assistant"
):
return messages + [{"role": "user", "content": ""}]
return [*messages, {"role": "user", "content": ""}]
# Handle Anthropic models
if not self.is_anthropic:
@@ -1170,7 +1166,7 @@ class LLM(BaseLLM):
return messages
def _get_custom_llm_provider(self) -> Optional[str]:
def _get_custom_llm_provider(self) -> str | None:
"""
Derives the custom_llm_provider from the model string.
- For example, if the model is "openrouter/deepseek/deepseek-chat", returns "openrouter".
@@ -1207,7 +1203,7 @@ class LLM(BaseLLM):
self.model, custom_llm_provider=provider
)
except Exception as e:
logging.error(f"Failed to check function calling support: {str(e)}")
logging.error(f"Failed to check function calling support: {e!s}")
return False
def supports_stop_words(self) -> bool:
@@ -1215,7 +1211,7 @@ class LLM(BaseLLM):
params = get_supported_openai_params(model=self.model)
return params is not None and "stop" in params
except Exception as e:
logging.error(f"Failed to get supported params: {str(e)}")
logging.error(f"Failed to get supported params: {e!s}")
return False
def get_context_window_size(self) -> int:
@@ -1229,9 +1225,6 @@ class LLM(BaseLLM):
if self.context_window_size != 0:
return self.context_window_size
MIN_CONTEXT = 1024
MAX_CONTEXT = 2097152 # Current max from gemini-1.5-pro
# Validate all context window sizes
for key, value in LLM_CONTEXT_WINDOW_SIZES.items():
if value < MIN_CONTEXT or value > MAX_CONTEXT:
@@ -1247,7 +1240,8 @@ class LLM(BaseLLM):
self.context_window_size = int(value * CONTEXT_WINDOW_USAGE_RATIO)
return self.context_window_size
def set_callbacks(self, callbacks: List[Any]):
@staticmethod
def set_callbacks(callbacks: list[Any]):
"""
Attempt to keep a single set of callbacks in litellm by removing old
duplicates and adding new ones.
@@ -1264,9 +1258,9 @@ class LLM(BaseLLM):
litellm.callbacks = callbacks
def set_env_callbacks(self):
"""
Sets the success and failure callbacks for the LiteLLM library from environment variables.
@staticmethod
def set_env_callbacks() -> None:
"""Sets the success and failure callbacks for the LiteLLM library from environment variables.
This method reads the `LITELLM_SUCCESS_CALLBACKS` and `LITELLM_FAILURE_CALLBACKS`
environment variables, which should contain comma-separated lists of callback names.
@@ -1276,7 +1270,7 @@ class LLM(BaseLLM):
If the environment variables are not set or are empty, the corresponding callback lists
will be set to empty lists.
Example:
Examples:
LITELLM_SUCCESS_CALLBACKS="langfuse,langsmith"
LITELLM_FAILURE_CALLBACKS="langfuse"
@@ -1285,16 +1279,15 @@ class LLM(BaseLLM):
"""
with suppress_warnings():
success_callbacks_str = os.environ.get("LITELLM_SUCCESS_CALLBACKS", "")
success_callbacks = []
success_callbacks: list[str | Callable[..., Any] | CustomLogger] = []
if success_callbacks_str:
success_callbacks = [
cb.strip() for cb in success_callbacks_str.split(",") if cb.strip()
]
failure_callbacks_str = os.environ.get("LITELLM_FAILURE_CALLBACKS", "")
failure_callbacks = []
if failure_callbacks_str:
failure_callbacks = [
failure_callbacks: list[str | Callable[..., Any] | CustomLogger] = [
cb.strip() for cb in failure_callbacks_str.split(",") if cb.strip()
]
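A minimal usage sketch against the retyped constructor and call() signature above; the values are illustrative and only model is required:

from crewai.llm import LLM

llm = LLM(
    model="gpt-4o",
    temperature=0.2,
    stop=["Observation:"],  # str | list[str] | None after the retyping
    max_tokens=1024,
    stream=False,
)

# call() accepts either a plain prompt string or a list of role/content dicts.
answer = llm.call("Summarize the release notes in one sentence.")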

View File

@@ -1,26 +1,24 @@
from .converter import Converter, ConverterError
from .file_handler import FileHandler
from .i18n import I18N
from .internal_instructor import InternalInstructor
from .logger import Logger
from .parser import YamlParser
from .printer import Printer
from .prompts import Prompts
from .rpm_controller import RPMController
from .exceptions.context_window_exceeding_exception import (
LLMContextLengthExceededException,
from crewai.utilities.converter import Converter, ConverterError
from crewai.utilities.exceptions.context_window_exceeding_exception import (
LLMContextLengthExceededError,
)
from crewai.utilities.file_handler import FileHandler
from crewai.utilities.i18n import I18N
from crewai.utilities.internal_instructor import InternalInstructor
from crewai.utilities.logger import Logger
from crewai.utilities.printer import Printer
from crewai.utilities.prompts import Prompts
from crewai.utilities.rpm_controller import RPMController
__all__ = [
"I18N",
"Converter",
"ConverterError",
"FileHandler",
"I18N",
"InternalInstructor",
"LLMContextLengthExceededError",
"Logger",
"Printer",
"Prompts",
"RPMController",
"YamlParser",
"LLMContextLengthExceededException",
]

View File

@@ -1,7 +1,9 @@
from __future__ import annotations
import json
import re
from collections.abc import Callable, Sequence
from typing import Any
from typing import TYPE_CHECKING, Any, Final, Literal, TypedDict
from rich.console import Console
@@ -19,18 +21,47 @@ from crewai.tools import BaseTool as CrewAITool
from crewai.tools.base_tool import BaseTool
from crewai.tools.structured_tool import CrewStructuredTool
from crewai.tools.tool_types import ToolResult
from crewai.utilities import I18N, Printer
from crewai.utilities.errors import AgentRepositoryError
from crewai.utilities.exceptions.context_window_exceeding_exception import (
LLMContextLengthExceededException,
LLMContextLengthExceededError,
)
from crewai.utilities.i18n import I18N
from crewai.utilities.printer import ColoredText, Printer
from crewai.utilities.types import LLMMessage
if TYPE_CHECKING:
from crewai.agent import Agent
from crewai.task import Task
class SummaryContent(TypedDict):
"""Structure for summary content entries.
Attributes:
content: The summarized content.
"""
content: str
console = Console()
_MULTIPLE_NEWLINES: Final[re.Pattern[str]] = re.compile(r"\n+")
def parse_tools(tools: list[BaseTool]) -> list[CrewStructuredTool]:
"""Parse tools to be used for the task."""
tools_list = []
"""Parse tools to be used for the task.
Args:
tools: List of tools to parse.
Returns:
List of structured tools.
Raises:
ValueError: If a tool is not a CrewStructuredTool or BaseTool.
"""
tools_list: list[CrewStructuredTool] = []
for tool in tools:
if isinstance(tool, CrewAITool):
@@ -42,7 +73,14 @@ def parse_tools(tools: list[BaseTool]) -> list[CrewStructuredTool]:
def get_tool_names(tools: Sequence[CrewStructuredTool | BaseTool]) -> str:
"""Get the names of the tools."""
"""Get the names of the tools.
Args:
tools: List of tools to get names from.
Returns:
Comma-separated string of tool names.
"""
return ", ".join([t.name for t in tools])
@@ -51,16 +89,30 @@ def render_text_description_and_args(
) -> str:
"""Render the tool name, description, and args in plain text.
search: This tool is used for search, args: {"query": {"type": "string"}}
calculator: This tool is used for math, \
args: {"expression": {"type": "string"}}
search: This tool is used for search, args: {"query": {"type": "string"}}
calculator: This tool is used for math, \
args: {"expression": {"type": "string"}}
Args:
tools: List of tools to render.
Returns:
Plain text description of tools.
"""
tool_strings = [tool.description for tool in tools]
return "\n".join(tool_strings)
def has_reached_max_iterations(iterations: int, max_iterations: int) -> bool:
"""Check if the maximum number of iterations has been reached."""
"""Check if the maximum number of iterations has been reached.
Args:
iterations: Current number of iterations.
max_iterations: Maximum allowed iterations.
Returns:
True if maximum iterations reached, False otherwise.
"""
return iterations >= max_iterations
@@ -68,16 +120,19 @@ def handle_max_iterations_exceeded(
formatted_answer: AgentAction | AgentFinish | None,
printer: Printer,
i18n: I18N,
messages: list[dict[str, str]],
messages: list[LLMMessage],
llm: LLM | BaseLLM,
callbacks: list[Any],
callbacks: list[Callable[..., Any]],
) -> AgentAction | AgentFinish:
"""
Handles the case when the maximum number of iterations is exceeded.
Performs one more LLM call to get the final answer.
"""Handles the case when the maximum number of iterations is exceeded. Performs one more LLM call to get the final answer.
Parameters:
Args:
formatted_answer: The last formatted answer from the agent.
printer: Printer instance for output.
i18n: I18N instance for internationalization.
messages: List of messages to send to the LLM.
llm: The LLM instance to call.
callbacks: List of callbacks for the LLM call.
Returns:
The final formatted answer after exceeding max iterations.
@@ -98,7 +153,7 @@ def handle_max_iterations_exceeded(
# Perform one more LLM call to get the final answer
answer = llm.call(
messages,
messages, # type: ignore[arg-type]
callbacks=callbacks,
)
@@ -110,20 +165,38 @@ def handle_max_iterations_exceeded(
raise ValueError("Invalid response from LLM call - None or empty.")
# Return the formatted answer, regardless of its type
return format_answer(answer)
return format_answer(answer=answer)
def format_message_for_llm(prompt: str, role: str = "user") -> dict[str, str]:
def format_message_for_llm(
prompt: str, role: Literal["user", "assistant", "system"] = "user"
) -> LLMMessage:
"""Format a message for the LLM.
Args:
prompt: The message content.
role: The role of the message sender: 'user', 'assistant', or 'system'.
Returns:
A dictionary with 'role' and 'content' keys.
"""
prompt = prompt.rstrip()
return {"role": role, "content": prompt}
def format_answer(answer: str) -> AgentAction | AgentFinish:
"""Format a response from the LLM into an AgentAction or AgentFinish."""
"""Format a response from the LLM into an AgentAction or AgentFinish.
Args:
answer: The raw response from the LLM
Returns:
Either an AgentAction or AgentFinish
"""
try:
return parse(answer)
except Exception:
# If parsing fails, return a default AgentFinish
return AgentFinish(
thought="Failed to parse LLM response",
output=answer,
@@ -134,23 +207,43 @@ def format_answer(answer: str) -> AgentAction | AgentFinish:
def enforce_rpm_limit(
request_within_rpm_limit: Callable[[], bool] | None = None,
) -> None:
"""Enforce the requests per minute (RPM) limit if applicable."""
"""Enforce the requests per minute (RPM) limit if applicable.
Args:
request_within_rpm_limit: Function to enforce RPM limit.
"""
if request_within_rpm_limit:
request_within_rpm_limit()
def get_llm_response(
llm: LLM | BaseLLM,
messages: list[dict[str, str]],
callbacks: list[Any],
messages: list[LLMMessage],
callbacks: list[Callable[..., Any]],
printer: Printer,
from_task: Any | None = None,
from_agent: Any | None = None,
from_task: Task | None = None,
from_agent: Agent | None = None,
) -> str:
"""Call the LLM and return the response, handling any invalid responses."""
"""Call the LLM and return the response, handling any invalid responses.
Args:
llm: The LLM instance to call
messages: The messages to send to the LLM
callbacks: List of callbacks for the LLM call
printer: Printer instance for output
from_task: Optional task context for the LLM call
from_agent: Optional agent context for the LLM call
Returns:
The response from the LLM as a string
Raises:
Exception: If an error occurs.
ValueError: If the response is None or empty.
"""
try:
answer = llm.call(
messages,
messages, # type: ignore[arg-type]
callbacks=callbacks,
from_task=from_task,
from_agent=from_agent,
@@ -170,7 +263,15 @@ def get_llm_response(
def process_llm_response(
answer: str, use_stop_words: bool
) -> AgentAction | AgentFinish:
"""Process the LLM response and format it into an AgentAction or AgentFinish."""
"""Process the LLM response and format it into an AgentAction or AgentFinish.
Args:
answer: The raw response from the LLM
use_stop_words: Whether to use stop words in the LLM call
Returns:
Either an AgentAction or AgentFinish
"""
if not use_stop_words:
try:
# Preliminary parsing to check for errors.
@@ -200,6 +301,9 @@ def handle_agent_action_core(
Returns:
Either an AgentAction or AgentFinish
Notes:
- TODO: Remove messages parameter and its usage.
"""
if step_callback:
step_callback(tool_result)
@@ -220,7 +324,7 @@ def handle_agent_action_core(
return formatted_answer
def handle_unknown_error(printer: Any, exception: Exception) -> None:
def handle_unknown_error(printer: Printer, exception: Exception) -> None:
"""Handle unknown errors by informing the user.
Args:
@@ -244,10 +348,10 @@ def handle_unknown_error(printer: Any, exception: Exception) -> None:
def handle_output_parser_exception(
e: OutputParserError,
messages: list[dict[str, str]],
messages: list[LLMMessage],
iterations: int,
log_error_after: int = 3,
printer: Any | None = None,
printer: Printer | None = None,
) -> AgentAction:
"""Handle OutputParserError by updating messages and formatted_answer.
@@ -288,18 +392,18 @@ def is_context_length_exceeded(exception: Exception) -> bool:
Returns:
bool: True if the exception is due to context length exceeding
"""
return LLMContextLengthExceededException(str(exception))._is_context_limit_error(
return LLMContextLengthExceededError(str(exception))._is_context_limit_error(
str(exception)
)
def handle_context_length(
respect_context_window: bool,
printer: Any,
messages: list[dict[str, str]],
llm: Any,
callbacks: list[Any],
i18n: Any,
printer: Printer,
messages: list[LLMMessage],
llm: LLM | BaseLLM,
callbacks: list[Callable[..., Any]],
i18n: I18N,
) -> None:
"""Handle context length exceeded by either summarizing or raising an error.
@@ -310,13 +414,16 @@ def handle_context_length(
llm: LLM instance for summarization
callbacks: List of callbacks for LLM
i18n: I18N instance for messages
Raises:
SystemExit: If context length is exceeded and user opts not to summarize
"""
if respect_context_window:
printer.print(
content="Context length exceeded. Summarizing content to fit the model context window. Might take a while...",
color="yellow",
)
summarize_messages(messages, llm, callbacks, i18n)
summarize_messages(messages=messages, llm=llm, callbacks=callbacks, i18n=i18n)
else:
printer.print(
content="Context length exceeded. Consider using smaller text or RAG tools from crewai_tools.",
@@ -328,10 +435,10 @@ def handle_context_length(
def summarize_messages(
messages: list[dict[str, str]],
llm: Any,
callbacks: list[Any],
i18n: Any,
messages: list[LLMMessage],
llm: LLM | BaseLLM,
callbacks: list[Callable[..., Any]],
i18n: I18N,
) -> None:
"""Summarize messages to fit within context window.
@@ -349,7 +456,7 @@ def summarize_messages(
for i in range(0, len(messages_string), cut_size)
]
summarized_contents = []
summarized_contents: list[SummaryContent] = []
total_groups = len(messages_groups)
for idx, group in enumerate(messages_groups, 1):
@@ -357,15 +464,17 @@ def summarize_messages(
content=f"Summarizing {idx}/{total_groups}...",
color="yellow",
)
messages = [
format_message_for_llm(
i18n.slice("summarizer_system_message"), role="system"
),
format_message_for_llm(
i18n.slice("summarize_instruction").format(group=group["content"]),
),
]
summary = llm.call(
[
format_message_for_llm(
i18n.slice("summarizer_system_message"), role="system"
),
format_message_for_llm(
i18n.slice("summarize_instruction").format(group=group["content"]),
),
],
messages, # type: ignore[arg-type]
callbacks=callbacks,
)
summarized_contents.append({"content": str(summary)})
@@ -404,20 +513,29 @@ def show_agent_logs(
if formatted_answer is None:
# Start logs
printer.print(
content=f"\033[1m\033[95m# Agent:\033[00m \033[1m\033[92m{agent_role}\033[00m"
content=[
ColoredText("# Agent: ", "bold_purple"),
ColoredText(agent_role, "bold_green"),
]
)
if task_description:
printer.print(
content=f"\033[95m## Task:\033[00m \033[92m{task_description}\033[00m"
content=[
ColoredText("## Task: ", "purple"),
ColoredText(task_description, "green"),
]
)
else:
# Execution logs
printer.print(
content=f"\n\n\033[1m\033[95m# Agent:\033[00m \033[1m\033[92m{agent_role}\033[00m"
content=[
ColoredText("\n\n# Agent: ", "bold_purple"),
ColoredText(agent_role, "bold_green"),
]
)
if isinstance(formatted_answer, AgentAction):
thought = re.sub(r"\n+", "\n", formatted_answer.thought)
thought = _MULTIPLE_NEWLINES.sub("\n", formatted_answer.thought)
formatted_json = json.dumps(
formatted_answer.tool_input,
indent=2,
@@ -425,24 +543,39 @@ def show_agent_logs(
)
if thought and thought != "":
printer.print(
content=f"\033[95m## Thought:\033[00m \033[92m{thought}\033[00m"
content=[
ColoredText("## Thought: ", "purple"),
ColoredText(thought, "green"),
]
)
printer.print(
content=f"\033[95m## Using tool:\033[00m \033[92m{formatted_answer.tool}\033[00m"
content=[
ColoredText("## Using tool: ", "purple"),
ColoredText(formatted_answer.tool, "green"),
]
)
printer.print(
content=f"\033[95m## Tool Input:\033[00m \033[92m\n{formatted_json}\033[00m"
content=[
ColoredText("## Tool Input: ", "purple"),
ColoredText(f"\n{formatted_json}", "green"),
]
)
printer.print(
content=f"\033[95m## Tool Output:\033[00m \033[92m\n{formatted_answer.result}\033[00m"
content=[
ColoredText("## Tool Output: ", "purple"),
ColoredText(f"\n{formatted_answer.result}", "green"),
]
)
elif isinstance(formatted_answer, AgentFinish):
printer.print(
content=f"\033[95m## Final Answer:\033[00m \033[92m\n{formatted_answer.output}\033[00m\n\n"
content=[
ColoredText("## Final Answer: ", "purple"),
ColoredText(f"\n{formatted_answer.output}\n\n", "green"),
]
)
def _print_current_organization():
def _print_current_organization() -> None:
settings = Settings()
if settings.org_uuid:
console.print(
@@ -457,6 +590,17 @@ def _print_current_organization():
def load_agent_from_repository(from_repository: str) -> dict[str, Any]:
"""Load an agent from the repository.
Args:
from_repository: The name of the agent to load.
Returns:
A dictionary of attributes to use for the agent.
Raises:
AgentRepositoryError: If the agent cannot be loaded.
"""
attributes: dict[str, Any] = {}
if from_repository:
import importlib
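The show_agent_logs hunk above replaces hand-built ANSI escape strings with structured ColoredText segments. A small sketch of the new call style, assuming Printer takes no constructor arguments as elsewhere in the codebase:

from crewai.utilities.printer import ColoredText, Printer

printer = Printer()
printer.print(
    content=[
        ColoredText("# Agent: ", "bold_purple"),
        ColoredText("Researcher", "bold_green"),
    ]
)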

View File

@@ -1,20 +1,19 @@
from typing import Any, Dict, Type
from typing import Any
from pydantic import BaseModel
def process_config(
values: Dict[str, Any], model_class: Type[BaseModel]
) -> Dict[str, Any]:
"""
Process the config dictionary and update the values accordingly.
values: dict[str, Any], model_class: type[BaseModel]
) -> dict[str, Any]:
"""Process the config dictionary and update the values accordingly.
Args:
values (Dict[str, Any]): The dictionary of values to update.
model_class (Type[BaseModel]): The Pydantic model class to reference for field validation.
values: The dictionary of values to update.
model_class: The Pydantic model class to reference for field validation.
Returns:
Dict[str, Any]: The updated values dictionary.
The updated values dictionary.
"""
config = values.get("config", {})
if not config:

View File

@@ -1,19 +1,32 @@
TRAINING_DATA_FILE = "training_data.pkl"
TRAINED_AGENTS_DATA_FILE = "trained_agents_data.pkl"
DEFAULT_SCORE_THRESHOLD = 0.35
KNOWLEDGE_DIRECTORY = "knowledge"
MAX_LLM_RETRY = 3
MAX_FILE_NAME_LENGTH = 255
EMITTER_COLOR = "bold_blue"
from typing import Annotated, Final
from crewai.utilities.printer import PrinterColor
TRAINING_DATA_FILE: Final[str] = "training_data.pkl"
TRAINED_AGENTS_DATA_FILE: Final[str] = "trained_agents_data.pkl"
KNOWLEDGE_DIRECTORY: Final[str] = "knowledge"
MAX_FILE_NAME_LENGTH: Final[int] = 255
EMITTER_COLOR: Final[PrinterColor] = "bold_blue"
class _NotSpecified:
def __repr__(self):
"""Sentinel class to detect when no value has been explicitly provided.
Notes:
- TODO: Consider moving this class and NOT_SPECIFIED to types.py
as they are more type-related constructs than business constants.
"""
def __repr__(self) -> str:
return "NOT_SPECIFIED"
# Sentinel value used to detect when no value has been explicitly provided.
# Unlike `None`, which might be a valid value from the user, `NOT_SPECIFIED` allows
# us to distinguish between "not passed at all" and "explicitly passed None" or "[]".
NOT_SPECIFIED = _NotSpecified()
CREWAI_BASE_URL = "https://app.crewai.com"
NOT_SPECIFIED: Final[
Annotated[
_NotSpecified,
"Sentinel value used to detect when no value has been explicitly provided. "
"Unlike `None`, which might be a valid value from the user, `NOT_SPECIFIED` "
"allows us to distinguish between 'not passed at all' and 'explicitly passed None' or '[]'.",
]
] = _NotSpecified()
CREWAI_BASE_URL: Final[str] = "https://app.crewai.com"
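A short illustration of why the NOT_SPECIFIED sentinel exists: unlike None, it lets a default value distinguish "argument omitted" from "None passed explicitly". The helper below is hypothetical and for illustration only; the import path is assumed from the hunk above:

from crewai.utilities.constants import NOT_SPECIFIED  # module path assumed

def resolve_tools(tools=NOT_SPECIFIED):  # hypothetical helper, not part of the diff
    if tools is NOT_SPECIFIED:
        return ["default-tool"]  # caller passed nothing at all
    return tools  # caller explicitly passed None, [] or a real list

print(resolve_tools())      # ['default-tool']
print(resolve_tools(None))  # None, distinguishable from the omitted case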

View File

@@ -1,18 +1,35 @@
from __future__ import annotations
import json
import re
from typing import Any, Optional, Type, Union, get_args, get_origin
from typing import TYPE_CHECKING, Any, Final, TypedDict, Union, get_args, get_origin
from pydantic import BaseModel, ValidationError
from typing_extensions import Unpack
from crewai.agents.agent_builder.utilities.base_output_converter import OutputConverter
from crewai.utilities.internal_instructor import InternalInstructor
from crewai.utilities.printer import Printer
from crewai.utilities.pydantic_schema_parser import PydanticSchemaParser
if TYPE_CHECKING:
from crewai.agent import Agent
from crewai.llm import LLM
from crewai.llms.base_llm import BaseLLM
_JSON_PATTERN: Final[re.Pattern[str]] = re.compile(r"({.*})", re.DOTALL)
class ConverterError(Exception):
"""Error raised when Converter fails to parse the input."""
def __init__(self, message: str, *args: object) -> None:
"""Initialize the ConverterError with a message.
Args:
message: The error message.
*args: Additional arguments for the base Exception class.
"""
super().__init__(message, *args)
self.message = message
@@ -20,8 +37,18 @@ class ConverterError(Exception):
class Converter(OutputConverter):
"""Class that converts text into either pydantic or json."""
def to_pydantic(self, current_attempt=1) -> BaseModel:
"""Convert text to pydantic."""
def to_pydantic(self, current_attempt: int = 1) -> BaseModel:
"""Convert text to pydantic.
Args:
current_attempt: The current attempt number for conversion retries.
Returns:
A Pydantic BaseModel instance.
Raises:
ConverterError: If conversion fails after maximum attempts.
"""
try:
if self.llm.supports_function_calling():
result = self._create_instructor().to_pydantic()
@@ -37,104 +64,124 @@ class Converter(OutputConverter):
result = self.model.model_validate_json(response)
except ValidationError:
# If direct validation fails, attempt to extract valid JSON
result = handle_partial_json(response, self.model, False, None)
result = handle_partial_json(
result=response,
model=self.model,
is_json_output=False,
agent=None,
)
# Ensure result is a BaseModel instance
if not isinstance(result, BaseModel):
if isinstance(result, dict):
result = self.model.parse_obj(result)
result = self.model.model_validate(result)
elif isinstance(result, str):
try:
parsed = json.loads(result)
result = self.model.parse_obj(parsed)
result = self.model.model_validate(parsed)
except Exception as parse_err:
raise ConverterError(
f"Failed to convert partial JSON result into Pydantic: {parse_err}"
)
) from parse_err
else:
raise ConverterError(
"handle_partial_json returned an unexpected type."
)
) from None
return result
except ValidationError as e:
if current_attempt < self.max_attempts:
return self.to_pydantic(current_attempt + 1)
raise ConverterError(
f"Failed to convert text into a Pydantic model due to validation error: {e}"
)
) from e
except Exception as e:
if current_attempt < self.max_attempts:
return self.to_pydantic(current_attempt + 1)
raise ConverterError(
f"Failed to convert text into a Pydantic model due to error: {e}"
)
) from e
def to_json(self, current_attempt=1):
"""Convert text to json."""
def to_json(self, current_attempt: int = 1) -> str | ConverterError | Any: # type: ignore[override]
"""Convert text to json.
Args:
current_attempt: The current attempt number for conversion retries.
Returns:
A JSON string or ConverterError if conversion fails.
Raises:
ConverterError: If conversion fails after maximum attempts.
"""
try:
if self.llm.supports_function_calling():
return self._create_instructor().to_json()
else:
return json.dumps(
self.llm.call(
[
{"role": "system", "content": self.instructions},
{"role": "user", "content": self.text},
]
)
return json.dumps(
self.llm.call(
[
{"role": "system", "content": self.instructions},
{"role": "user", "content": self.text},
]
)
)
except Exception as e:
if current_attempt < self.max_attempts:
return self.to_json(current_attempt + 1)
return ConverterError(f"Failed to convert text into JSON, error: {e}.")
def _create_instructor(self):
def _create_instructor(self) -> InternalInstructor:
"""Create an instructor."""
from crewai.utilities import InternalInstructor
inst = InternalInstructor(
return InternalInstructor(
llm=self.llm,
model=self.model,
content=self.text,
)
return inst
def _convert_with_instructions(self):
"""Create a chain."""
from crewai.utilities.crew_pydantic_output_parser import (
CrewPydanticOutputParser,
)
parser = CrewPydanticOutputParser(pydantic_object=self.model)
result = self.llm.call(
[
{"role": "system", "content": self.instructions},
{"role": "user", "content": self.text},
]
)
return parser.parse_result(result)
def convert_to_model(
result: str,
output_pydantic: Optional[Type[BaseModel]],
output_json: Optional[Type[BaseModel]],
agent: Any,
converter_cls: Optional[Type[Converter]] = None,
) -> Union[dict, BaseModel, str]:
output_pydantic: type[BaseModel] | None,
output_json: type[BaseModel] | None,
agent: Agent | None = None,
converter_cls: type[Converter] | None = None,
) -> dict[str, Any] | BaseModel | str:
"""Convert a result string to a Pydantic model or JSON.
Args:
result: The result string to convert.
output_pydantic: The Pydantic model class to convert to.
output_json: The Pydantic model class to convert to JSON.
agent: The agent instance.
converter_cls: The converter class to use.
Returns:
The converted result as a dict, BaseModel, or original string.
"""
model = output_pydantic or output_json
if model is None:
return result
try:
escaped_result = json.dumps(json.loads(result, strict=False))
return validate_model(escaped_result, model, bool(output_json))
return validate_model(
result=escaped_result, model=model, is_json_output=bool(output_json)
)
except json.JSONDecodeError:
return handle_partial_json(
result, model, bool(output_json), agent, converter_cls
result=result,
model=model,
is_json_output=bool(output_json),
agent=agent,
converter_cls=converter_cls,
)
except ValidationError:
return handle_partial_json(
result, model, bool(output_json), agent, converter_cls
result=result,
model=model,
is_json_output=bool(output_json),
agent=agent,
converter_cls=converter_cls,
)
except Exception as e:
@@ -146,8 +193,18 @@ def convert_to_model(
def validate_model(
result: str, model: Type[BaseModel], is_json_output: bool
) -> Union[dict, BaseModel]:
result: str, model: type[BaseModel], is_json_output: bool
) -> dict[str, Any] | BaseModel:
"""Validate and convert a JSON string to a Pydantic model or dict.
Args:
result: The JSON string to validate and convert.
model: The Pydantic model class to convert to.
is_json_output: Whether to return a dict (True) or Pydantic model (False).
Returns:
The converted result as a dict or BaseModel.
"""
exported_result = model.model_validate_json(result)
if is_json_output:
return exported_result.model_dump()
@@ -156,15 +213,27 @@ def validate_model(
def handle_partial_json(
result: str,
model: Type[BaseModel],
model: type[BaseModel],
is_json_output: bool,
agent: Any,
converter_cls: Optional[Type[Converter]] = None,
) -> Union[dict, BaseModel, str]:
match = re.search(r"({.*})", result, re.DOTALL)
agent: Agent | None,
converter_cls: type[Converter] | None = None,
) -> dict[str, Any] | BaseModel | str:
"""Handle partial JSON in a result string and convert to Pydantic model or dict.
Args:
result: The result string to process.
model: The Pydantic model class to convert to.
is_json_output: Whether to return a dict (True) or Pydantic model (False).
agent: The agent instance.
converter_cls: The converter class to use.
Returns:
The converted result as a dict, BaseModel, or original string.
"""
match = _JSON_PATTERN.search(result)
if match:
try:
exported_result = model.model_validate_json(match.group(0))
exported_result = model.model_validate_json(match.group())
if is_json_output:
return exported_result.model_dump()
return exported_result
@@ -179,19 +248,43 @@ def handle_partial_json(
)
return convert_with_instructions(
result, model, is_json_output, agent, converter_cls
result=result,
model=model,
is_json_output=is_json_output,
agent=agent,
converter_cls=converter_cls,
)
def convert_with_instructions(
result: str,
model: Type[BaseModel],
model: type[BaseModel],
is_json_output: bool,
agent: Any,
converter_cls: Optional[Type[Converter]] = None,
) -> Union[dict, BaseModel, str]:
agent: Agent | None,
converter_cls: type[Converter] | None = None,
) -> dict | BaseModel | str:
"""Convert a result string to a Pydantic model or JSON using instructions.
Args:
result: The result string to convert.
model: The Pydantic model class to convert to.
is_json_output: Whether to return a dict (True) or Pydantic model (False).
agent: The agent instance.
converter_cls: The converter class to use.
Returns:
The converted result as a dict, BaseModel, or original string.
Raises:
TypeError: If neither agent nor converter_cls is provided.
Notes:
- TODO: Fix llm typing issues, return llm should not be able to be str or None.
"""
if agent is None:
raise TypeError("Agent must be provided if converter_cls is not specified.")
llm = agent.function_calling_llm or agent.llm
instructions = get_conversion_instructions(model, llm)
instructions = get_conversion_instructions(model=model, llm=llm)
converter = create_converter(
agent=agent,
converter_cls=converter_cls,
@@ -214,9 +307,25 @@ def convert_with_instructions(
return exported_result
def get_conversion_instructions(model: Type[BaseModel], llm: Any) -> str:
def get_conversion_instructions(
model: type[BaseModel], llm: BaseLLM | LLM | str
) -> str:
"""Generate conversion instructions based on the model and LLM capabilities.
Args:
model: A Pydantic model class.
llm: The language model instance.
Returns:
The conversion instructions string.
"""
instructions = "Please convert the following text into valid JSON."
if llm and not isinstance(llm, str) and llm.supports_function_calling():
if (
llm
and not isinstance(llm, str)
and hasattr(llm, "supports_function_calling")
and llm.supports_function_calling()
):
model_schema = PydanticSchemaParser(model=model).get_schema()
instructions += (
f"\n\nOutput ONLY the valid JSON and nothing else.\n\n"
@@ -231,12 +340,45 @@ def get_conversion_instructions(model: Type[BaseModel], llm: Any) -> str:
return instructions
class CreateConverterKwargs(TypedDict, total=False):
"""Keyword arguments for creating a converter.
Attributes:
llm: The language model instance.
text: The text to convert.
model: The Pydantic model class.
instructions: The conversion instructions.
"""
llm: BaseLLM | LLM | str
text: str
model: type[BaseModel]
instructions: str
def create_converter(
agent: Optional[Any] = None,
converter_cls: Optional[Type[Converter]] = None,
*args,
**kwargs,
agent: Agent | None = None,
converter_cls: type[Converter] | None = None,
*args: Any,
**kwargs: Unpack[CreateConverterKwargs],
) -> Converter:
"""Create a converter instance based on the agent or provided class.
Args:
agent: The agent instance.
converter_cls: The converter class to instantiate.
*args: The positional arguments to pass to the converter.
**kwargs: The keyword arguments to pass to the converter.
Returns:
An instance of the specified converter class.
Raises:
ValueError: If neither agent nor converter_cls is provided.
AttributeError: If the agent does not have a 'get_output_converter' method.
Exception: If no converter instance is created.
"""
if agent and not converter_cls:
if hasattr(agent, "get_output_converter"):
converter = agent.get_output_converter(*args, **kwargs)
@@ -253,17 +395,30 @@ def create_converter(
return converter
def generate_model_description(model: Type[BaseModel]) -> str:
"""
Generate a string description of a Pydantic model's fields and their types.
def generate_model_description(model: type[BaseModel]) -> str:
"""Generate a string description of a Pydantic model's fields and their types.
This function takes a Pydantic model class and returns a string that describes
the model's fields and their respective types. The description includes handling
of complex types such as `Optional`, `List`, and `Dict`, as well as nested Pydantic
models.
Args:
model: A Pydantic model class.
Returns:
A string representation of the model's fields and types.
"""
def describe_field(field_type):
def describe_field(field_type: Any) -> str:
"""Recursively describe a field's type.
Args:
field_type: The type of the field to describe.
Returns:
A string representation of the field's type.
"""
origin = get_origin(field_type)
args = get_args(field_type)
@@ -272,20 +427,18 @@ def generate_model_description(model: Type[BaseModel]) -> str:
non_none_args = [arg for arg in args if arg is not type(None)]
if len(non_none_args) == 1:
return f"Optional[{describe_field(non_none_args[0])}]"
else:
return f"Optional[Union[{', '.join(describe_field(arg) for arg in non_none_args)}]]"
elif origin is list:
return f"Optional[Union[{', '.join(describe_field(arg) for arg in non_none_args)}]]"
if origin is list:
return f"List[{describe_field(args[0])}]"
elif origin is dict:
if origin is dict:
key_type = describe_field(args[0])
value_type = describe_field(args[1])
return f"Dict[{key_type}, {value_type}]"
elif isinstance(field_type, type) and issubclass(field_type, BaseModel):
if isinstance(field_type, type) and issubclass(field_type, BaseModel):
return generate_model_description(field_type)
elif hasattr(field_type, "__name__"):
if hasattr(field_type, "__name__"):
return field_type.__name__
else:
return str(field_type)
return str(field_type)
fields = model.model_fields
field_descriptions = [

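Closing out the converter hunks, a keyword-style usage sketch of convert_to_model as retyped; the model and input below are invented for illustration:

from pydantic import BaseModel

from crewai.utilities.converter import convert_to_model

class Report(BaseModel):  # illustrative model, not from the diff
    title: str
    score: float

raw = '{"title": "Q3 summary", "score": 0.92}'
parsed = convert_to_model(
    result=raw,
    output_pydantic=Report,
    output_json=None,
    agent=None,  # now optional, defaulting to None
)
print(parsed)  # a validated Report instance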
View File

@@ -1 +1 @@
"""Crew-specific utilities."""
"""Crew-specific utilities."""

View File

@@ -1,16 +1,16 @@
"""Context management utilities for tracking crew and task execution context using OpenTelemetry baggage."""
from typing import Optional
from typing import cast
from opentelemetry import baggage
from crewai.utilities.crew.models import CrewContext
def get_crew_context() -> Optional[CrewContext]:
def get_crew_context() -> CrewContext | None:
"""Get the current crew context from OpenTelemetry baggage.
Returns:
CrewContext instance containing crew context information, or None if no context is set
"""
return baggage.get_baggage("crew_context")
return cast(CrewContext | None, baggage.get_baggage("crew_context"))

View File

@@ -1,16 +1,17 @@
"""Models for crew-related data structures."""
from typing import Optional
from pydantic import BaseModel, Field
class CrewContext(BaseModel):
"""Model representing crew context information."""
"""Model representing crew context information.
id: Optional[str] = Field(
default=None, description="Unique identifier for the crew"
)
key: Optional[str] = Field(
Attributes:
id: Unique identifier for the crew.
key: Optional crew key/name for identification.
"""
id: str | None = Field(default=None, description="Unique identifier for the crew")
key: str | None = Field(
default=None, description="Optional crew key/name for identification"
)

View File

@@ -4,6 +4,7 @@ import json
from datetime import date, datetime
from decimal import Decimal
from enum import Enum
from typing import Any
from uuid import UUID
from pydantic import BaseModel
@@ -11,18 +12,28 @@ from pydantic import BaseModel
class CrewJSONEncoder(json.JSONEncoder):
"""Custom JSON encoder for CrewAI objects and special types."""
def default(self, obj):
def default(self, obj: Any) -> Any:
"""Custom serialization for CrewAI specific types.
Args:
obj: The object to serialize.
Returns:
A JSON-serializable representation of the object.
"""
if isinstance(obj, BaseModel):
return self._handle_pydantic_model(obj)
elif isinstance(obj, UUID) or isinstance(obj, Decimal) or isinstance(obj, Enum):
if isinstance(obj, (UUID, Decimal, Enum)):
return str(obj)
elif isinstance(obj, datetime) or isinstance(obj, date):
if isinstance(obj, (datetime, date)):
return obj.isoformat()
return super().default(obj)
def _handle_pydantic_model(self, obj):
@staticmethod
def _handle_pydantic_model(obj: BaseModel) -> str | Any:
try:
data = obj.model_dump()
# Remove circular references

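A hedged usage sketch for the encoder above: since CrewJSONEncoder subclasses json.JSONEncoder, it plugs into the standard cls= hook of json.dumps; the import path is assumed from the class name:

import json
from datetime import datetime
from uuid import uuid4

from crewai.utilities.crew_json_encoder import CrewJSONEncoder  # path assumed

payload = {"id": uuid4(), "started": datetime.now()}
print(json.dumps(payload, cls=CrewJSONEncoder))  # UUID and datetime serialized as strings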
View File

@@ -1,48 +0,0 @@
import json
from typing import Any
import regex
from pydantic import BaseModel, ValidationError
from crewai.agents.parser import OutputParserError
"""Parser for converting text outputs into Pydantic models."""
class CrewPydanticOutputParser:
"""Parses text outputs into specified Pydantic models."""
pydantic_object: type[BaseModel]
def parse_result(self, result: str) -> Any:
result = self._transform_in_valid_json(result)
# Treating edge case of function calling llm returning the name instead of tool_name
json_object = json.loads(result)
if "tool_name" not in json_object:
json_object["tool_name"] = json_object.get("name", "")
result = json.dumps(json_object)
try:
return self.pydantic_object.model_validate(json_object)
except ValidationError as e:
name = self.pydantic_object.__name__
msg = f"Failed to parse {name} from completion {json_object}. Got: {e}"
raise OutputParserError(error=msg) from e
def _transform_in_valid_json(self, text) -> str:
text = text.replace("```", "").replace("json", "")
json_pattern = r"\{(?:[^{}]|(?R))*\}"
matches = regex.finditer(json_pattern, text)
for match in matches:
try:
# Attempt to parse the matched string as JSON
json_obj = json.loads(match.group())
# Return the first successfully parsed JSON object
json_obj = json.dumps(json_obj)
return str(json_obj)
except json.JSONDecodeError: # noqa: PERF203
# If parsing fails, skip to the next match
continue
return text

View File

@@ -8,7 +8,11 @@ from typing import Final
class DatabaseOperationError(Exception):
"""Base exception class for database operation errors."""
"""Base exception class for database operation errors.
Attributes:
original_error: The original exception that caused this error, if any.
"""
def __init__(self, message: str, original_error: Exception | None = None) -> None:
"""Initialize the database operation error.

View File

@@ -1,4 +1,7 @@
from __future__ import annotations
from collections import defaultdict
from typing import TYPE_CHECKING
from pydantic import BaseModel, Field, InstanceOf
from rich.box import HEAVY_EDGE
@@ -6,11 +9,14 @@ from rich.console import Console
from rich.table import Table
from crewai.agent import Agent
from crewai.llm import BaseLLM
from crewai.task import Task
from crewai.tasks.task_output import TaskOutput
from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.crew_events import CrewTestResultEvent
from crewai.llms.base_llm import BaseLLM
from crewai.task import Task
from crewai.tasks.task_output import TaskOutput
if TYPE_CHECKING:
from crewai.crew import Crew
class TaskEvaluationPydanticOutput(BaseModel):
@@ -20,23 +26,21 @@ class TaskEvaluationPydanticOutput(BaseModel):
class CrewEvaluator:
"""
A class to evaluate the performance of the agents in the crew based on the tasks they have performed.
"""A class to evaluate the performance of the agents in the crew based on the tasks they have performed.
Attributes:
crew (Crew): The crew of agents to evaluate.
eval_llm (BaseLLM): Language model instance to use for evaluations
tasks_scores (defaultdict): A dictionary to store the scores of the agents for each task.
iteration (int): The current iteration of the evaluation.
crew: The crew of agents to evaluate.
tasks_scores: A dictionary to store the scores of the agents for each task.
run_execution_times: A dictionary to store execution times for each run.
iteration: The current iteration of the evaluation.
"""
tasks_scores: defaultdict = defaultdict(list)
run_execution_times: defaultdict = defaultdict(list)
iteration: int = 0
def __init__(self, crew, eval_llm: InstanceOf[BaseLLM]):
def __init__(self, crew: Crew, eval_llm: InstanceOf[BaseLLM]) -> None:
self.crew = crew
self.llm = eval_llm
self.tasks_scores: defaultdict[int, list[float]] = defaultdict(list)
self.run_execution_times: defaultdict[int, list[float]] = defaultdict(list)
self.iteration: int = 0
self._setup_for_evaluating()
def _setup_for_evaluating(self) -> None:
@@ -44,7 +48,7 @@ class CrewEvaluator:
for task in self.crew.tasks:
task.callback = self.evaluate
def _evaluator_agent(self):
def _evaluator_agent(self) -> Agent:
return Agent(
role="Task Execution Evaluator",
goal=(
@@ -55,8 +59,9 @@ class CrewEvaluator:
llm=self.llm,
)
@staticmethod
def _evaluation_task(
self, evaluator_agent: Agent, task_to_evaluate: Task, task_output: str
evaluator_agent: Agent, task_to_evaluate: Task, task_output: str
) -> Task:
return Task(
description=(
@@ -73,6 +78,11 @@ class CrewEvaluator:
)
def set_iteration(self, iteration: int) -> None:
"""Sets the current iteration of the evaluation.
Args:
iteration: The current iteration number.
"""
self.iteration = iteration
def print_crew_evaluation_result(self) -> None:
@@ -97,7 +107,8 @@ class CrewEvaluator:
└────────────────────┴───────┴───────┴───────┴────────────┴──────────────────────────────┘
"""
task_averages = [
sum(scores) / len(scores) for scores in zip(*self.tasks_scores.values())
sum(scores) / len(scores)
for scores in zip(*self.tasks_scores.values(), strict=False)
]
crew_average = sum(task_averages) / len(task_averages)
@@ -158,8 +169,12 @@ class CrewEvaluator:
console.print("\n")
console.print(table)
def evaluate(self, task_output: TaskOutput):
"""Evaluates the performance of the agents in the crew based on the tasks they have performed."""
def evaluate(self, task_output: TaskOutput) -> None:
"""Evaluates the performance of the agents in the crew based on the tasks they have performed.
Args:
task_output: The output of the task to evaluate.
"""
current_task = None
for task in self.crew.tasks:
if task.description == task_output.description:
@@ -179,19 +194,24 @@ class CrewEvaluator:
evaluation_result = evaluation_task.execute_sync()
if isinstance(evaluation_result.pydantic, TaskEvaluationPydanticOutput):
quality_score = evaluation_result.pydantic.quality
if quality_score is None:
raise ValueError("Evaluation quality score cannot be None")
crewai_event_bus.emit(
self.crew,
CrewTestResultEvent(
quality=evaluation_result.pydantic.quality,
quality=quality_score,
execution_duration=current_task.execution_duration,
model=self.llm.model,
crew_name=self.crew.name,
crew=self.crew,
),
)
self.tasks_scores[self.iteration].append(evaluation_result.pydantic.quality)
self.run_execution_times[self.iteration].append(
current_task.execution_duration
)
self.tasks_scores[self.iteration].append(quality_score)
if current_task.execution_duration is not None:
self.run_execution_times[self.iteration].append(
current_task.execution_duration
)
else:
raise ValueError("Evaluation result is not in the expected format")

View File

@@ -1,35 +1,42 @@
from typing import List
from __future__ import annotations
from typing import TYPE_CHECKING, cast
from pydantic import BaseModel, Field
from crewai.utilities import Converter
from crewai.events.types.task_events import TaskEvaluationEvent
from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.task_events import TaskEvaluationEvent
from crewai.llm import LLM
from crewai.utilities.converter import Converter
from crewai.utilities.pydantic_schema_parser import PydanticSchemaParser
from crewai.utilities.training_converter import TrainingConverter
if TYPE_CHECKING:
from crewai.agent import Agent
from crewai.task import Task
class Entity(BaseModel):
name: str = Field(description="The name of the entity.")
type: str = Field(description="The type of the entity.")
description: str = Field(description="Description of the entity.")
relationships: List[str] = Field(description="Relationships of the entity.")
relationships: list[str] = Field(description="Relationships of the entity.")
class TaskEvaluation(BaseModel):
suggestions: List[str] = Field(
suggestions: list[str] = Field(
description="Suggestions to improve future similar tasks."
)
quality: float = Field(
description="A score from 0 to 10 evaluating on completion, quality, and overall performance, all taking into account the task description, expected output, and the result of the task."
)
entities: List[Entity] = Field(
entities: list[Entity] = Field(
description="Entities extracted from the task output."
)
class TrainingTaskEvaluation(BaseModel):
suggestions: List[str] = Field(
suggestions: list[str] = Field(
description="List of clear, actionable instructions derived from the Human Feedbacks to enhance the Agent's performance. Analyze the differences between Initial Outputs and Improved Outputs to generate specific action items for future tasks. Ensure all key and specific points from the human feedback are incorporated into these instructions."
)
quality: float = Field(
@@ -41,11 +48,35 @@ class TrainingTaskEvaluation(BaseModel):
class TaskEvaluator:
def __init__(self, original_agent):
self.llm = original_agent.llm
"""A class to evaluate the performance of an agent based on the tasks they have performed.
Attributes:
llm: The LLM to use for evaluation.
original_agent: The agent to evaluate.
"""
def __init__(self, original_agent: Agent) -> None:
"""Initializes the TaskEvaluator with the given LLM and agent.
Args:
original_agent: The agent to evaluate.
"""
self.llm = cast(LLM, original_agent.llm)
self.original_agent = original_agent
def evaluate(self, task, output) -> TaskEvaluation:
def evaluate(self, task: Task, output: str) -> TaskEvaluation:
"""
Args:
task: The task to be evaluated.
output: The output of the task.
Returns:
TaskEvaluation: The evaluation of the task.
Notes:
- Investigate the Converter.to_pydantic signature: does it strictly return BaseModel?
"""
crewai_event_bus.emit(
self, TaskEvaluationEvent(evaluation_type="task_evaluation", task=task)
)
@@ -73,7 +104,7 @@ class TaskEvaluator:
instructions=instructions,
)
return converter.to_pydantic()
return cast(TaskEvaluation, converter.to_pydantic())
def evaluate_training_data(
self, training_data: dict, agent_id: str
@@ -81,9 +112,12 @@ class TaskEvaluator:
"""
Evaluate the training data based on the llm output, human feedback, and improved output.
Parameters:
- training_data (dict): The training data to be evaluated.
- agent_id (str): The ID of the agent.
Args:
- training_data: The training data to be evaluated.
- agent_id: The ID of the agent.
Notes:
- Investigate the Converter.to_pydantic signature: does it strictly return BaseModel?
"""
crewai_event_bus.emit(
self, TaskEvaluationEvent(evaluation_type="training_data_evaluation")
@@ -142,5 +176,4 @@ class TaskEvaluator:
instructions=instructions,
)
pydantic_result = converter.to_pydantic()
return pydantic_result
return cast(TrainingTaskEvaluation, converter.to_pydantic())

View File

@@ -3,9 +3,10 @@
import warnings
from abc import ABC
from collections.abc import Callable
from typing import Any, Type, TypeVar
from typing import Any, TypeVar
from typing_extensions import deprecated
import crewai.events as new_events
from crewai.events.base_events import BaseEvent
from crewai.events.event_types import EventTypes
@@ -17,14 +18,14 @@ warnings.warn(
"Importing from 'crewai.utilities.events' is deprecated and will be removed in v1.0.0. "
"Please use 'crewai.events' instead.",
DeprecationWarning,
stacklevel=2
stacklevel=2,
)
@deprecated("Use 'from crewai.events import BaseEventListener' instead")
class BaseEventListener(new_events.BaseEventListener, ABC):
"""Deprecated: Use crewai.events.BaseEventListener instead."""
pass
@deprecated("Use 'from crewai.events import crewai_event_bus' instead")
class crewai_event_bus: # noqa: N801
@@ -32,7 +33,7 @@ class crewai_event_bus: # noqa: N801
@classmethod
def on(
cls, event_type: Type[EventT]
cls, event_type: type[EventT]
) -> Callable[[Callable[[Any, EventT], None]], Callable[[Any, EventT], None]]:
"""Delegate to the actual event bus instance."""
return new_events.crewai_event_bus.on(event_type)
@@ -44,7 +45,7 @@ class crewai_event_bus: # noqa: N801
@classmethod
def register_handler(
cls, event_type: Type[EventTypes], handler: Callable[[Any, EventTypes], None]
cls, event_type: type[EventTypes], handler: Callable[[Any, EventTypes], None]
) -> None:
"""Delegate to the actual event bus instance."""
return new_events.crewai_event_bus.register_handler(event_type, handler)
@@ -54,87 +55,88 @@ class crewai_event_bus: # noqa: N801
"""Delegate to the actual event bus instance."""
return new_events.crewai_event_bus.scoped_handlers()
@deprecated("Use 'from crewai.events import CrewKickoffStartedEvent' instead")
class CrewKickoffStartedEvent(new_events.CrewKickoffStartedEvent):
"""Deprecated: Use crewai.events.CrewKickoffStartedEvent instead."""
pass
@deprecated("Use 'from crewai.events import CrewKickoffCompletedEvent' instead")
class CrewKickoffCompletedEvent(new_events.CrewKickoffCompletedEvent):
"""Deprecated: Use crewai.events.CrewKickoffCompletedEvent instead."""
pass
@deprecated("Use 'from crewai.events import AgentExecutionCompletedEvent' instead")
class AgentExecutionCompletedEvent(new_events.AgentExecutionCompletedEvent):
"""Deprecated: Use crewai.events.AgentExecutionCompletedEvent instead."""
pass
@deprecated("Use 'from crewai.events import MemoryQueryCompletedEvent' instead")
class MemoryQueryCompletedEvent(new_events.MemoryQueryCompletedEvent):
"""Deprecated: Use crewai.events.MemoryQueryCompletedEvent instead."""
pass
@deprecated("Use 'from crewai.events import MemorySaveCompletedEvent' instead")
class MemorySaveCompletedEvent(new_events.MemorySaveCompletedEvent):
"""Deprecated: Use crewai.events.MemorySaveCompletedEvent instead."""
pass
@deprecated("Use 'from crewai.events import MemorySaveStartedEvent' instead")
class MemorySaveStartedEvent(new_events.MemorySaveStartedEvent):
"""Deprecated: Use crewai.events.MemorySaveStartedEvent instead."""
pass
@deprecated("Use 'from crewai.events import MemoryQueryStartedEvent' instead")
class MemoryQueryStartedEvent(new_events.MemoryQueryStartedEvent):
"""Deprecated: Use crewai.events.MemoryQueryStartedEvent instead."""
pass
@deprecated("Use 'from crewai.events import MemoryRetrievalCompletedEvent' instead")
class MemoryRetrievalCompletedEvent(new_events.MemoryRetrievalCompletedEvent):
"""Deprecated: Use crewai.events.MemoryRetrievalCompletedEvent instead."""
pass
@deprecated("Use 'from crewai.events import MemorySaveFailedEvent' instead")
class MemorySaveFailedEvent(new_events.MemorySaveFailedEvent):
"""Deprecated: Use crewai.events.MemorySaveFailedEvent instead."""
pass
@deprecated("Use 'from crewai.events import MemoryQueryFailedEvent' instead")
class MemoryQueryFailedEvent(new_events.MemoryQueryFailedEvent):
"""Deprecated: Use crewai.events.MemoryQueryFailedEvent instead."""
pass
@deprecated("Use 'from crewai.events import KnowledgeRetrievalStartedEvent' instead")
class KnowledgeRetrievalStartedEvent(new_events.KnowledgeRetrievalStartedEvent):
"""Deprecated: Use crewai.events.KnowledgeRetrievalStartedEvent instead."""
pass
@deprecated("Use 'from crewai.events import KnowledgeRetrievalCompletedEvent' instead")
class KnowledgeRetrievalCompletedEvent(new_events.KnowledgeRetrievalCompletedEvent):
"""Deprecated: Use crewai.events.KnowledgeRetrievalCompletedEvent instead."""
pass
@deprecated("Use 'from crewai.events import LLMStreamChunkEvent' instead")
class LLMStreamChunkEvent(new_events.LLMStreamChunkEvent):
"""Deprecated: Use crewai.events.LLMStreamChunkEvent instead."""
pass
__all__ = [
'BaseEventListener',
'crewai_event_bus',
'CrewKickoffStartedEvent',
'CrewKickoffCompletedEvent',
'AgentExecutionCompletedEvent',
'MemoryQueryCompletedEvent',
'MemorySaveCompletedEvent',
'MemorySaveStartedEvent',
'MemoryQueryStartedEvent',
'MemoryRetrievalCompletedEvent',
'MemorySaveFailedEvent',
'MemoryQueryFailedEvent',
'KnowledgeRetrievalStartedEvent',
'KnowledgeRetrievalCompletedEvent',
'LLMStreamChunkEvent',
"AgentExecutionCompletedEvent",
"BaseEventListener",
"CrewKickoffCompletedEvent",
"CrewKickoffStartedEvent",
"KnowledgeRetrievalCompletedEvent",
"KnowledgeRetrievalStartedEvent",
"LLMStreamChunkEvent",
"MemoryQueryCompletedEvent",
"MemoryQueryFailedEvent",
"MemoryQueryStartedEvent",
"MemoryRetrievalCompletedEvent",
"MemorySaveCompletedEvent",
"MemorySaveFailedEvent",
"MemorySaveStartedEvent",
"crewai_event_bus",
]
__deprecated__ = "Use 'crewai.events' instead of 'crewai.utilities.events'"
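Every symbol above is a thin re-export; the deprecation messages point at the replacement import, e.g.:
# preferred going forward
from crewai.events import BaseEventListener, LLMStreamChunkEvent, crewai_event_bus

# still works, but emits a DeprecationWarning at import time
from crewai.utilities.events import crewai_event_bus  # noqa: F811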

View File

@@ -1,6 +1,7 @@
"""Backwards compatibility stub for crewai.utilities.events.base_event_listener."""
import warnings
from crewai.events import BaseEventListener
warnings.warn(

View File

@@ -1,6 +1,7 @@
"""Backwards compatibility stub for crewai.utilities.events.crewai_event_bus."""
import warnings
from crewai.events import crewai_event_bus
warnings.warn(

View File

@@ -1,26 +1,57 @@
class LLMContextLengthExceededException(Exception):
CONTEXT_LIMIT_ERRORS = [
"expected a string with maximum length",
"maximum context length",
"context length exceeded",
"context_length_exceeded",
"context window full",
"too many tokens",
"input is too long",
"exceeds token limit",
]
from typing import Final
def __init__(self, error_message: str):
CONTEXT_LIMIT_ERRORS: Final[list[str]] = [
"expected a string with maximum length",
"maximum context length",
"context length exceeded",
"context_length_exceeded",
"context window full",
"too many tokens",
"input is too long",
"exceeds token limit",
]
class LLMContextLengthExceededError(Exception):
"""Exception raised when the context length of a language model is exceeded.
Attributes:
original_error_message: The original error message from the LLM.
"""
def __init__(self, error_message: str) -> None:
"""Initialize the exception with the original error message.
Args:
error_message: The original error message from the LLM.
"""
self.original_error_message = error_message
super().__init__(self._get_error_message(error_message))
def _is_context_limit_error(self, error_message: str) -> bool:
@staticmethod
def _is_context_limit_error(error_message: str) -> bool:
"""Check if the error message indicates a context length limit error.
Args:
error_message: The error message to check.
Returns:
True if the error message indicates a context length limit error, False otherwise.
"""
return any(
phrase.lower() in error_message.lower()
for phrase in self.CONTEXT_LIMIT_ERRORS
phrase.lower() in error_message.lower() for phrase in CONTEXT_LIMIT_ERRORS
)
def _get_error_message(self, error_message: str):
@staticmethod
def _get_error_message(error_message: str) -> str:
"""Generate a user-friendly error message based on the original error message.
Args:
error_message: The original error message from the LLM.
Returns:
A user-friendly error message.
"""
return (
f"LLM context length exceeded. Original error: {error_message}\n"
"Consider using a smaller input or implementing a text splitting strategy."

View File

@@ -2,71 +2,140 @@ import json
import os
import pickle
from datetime import datetime
from typing import Union
from typing import Any, TypedDict
from typing_extensions import Unpack
class LogEntry(TypedDict, total=False):
"""TypedDict for log entry kwargs with optional fields for flexibility."""
task_name: str
task: str
agent: str
status: str
output: str
input: str
message: str
level: str
crew: str
flow: str
tool: str
error: str
duration: float
metadata: dict[str, Any]
class FileHandler:
"""Handler for file operations supporting both JSON and text-based logging.
Args:
file_path (Union[bool, str]): Path to the log file or boolean flag
Attributes:
_path: The path to the log file.
"""
def __init__(self, file_path: Union[bool, str]):
def __init__(self, file_path: bool | str) -> None:
"""Initialize the FileHandler with the specified file path.
Args:
file_path: Path to the log file or boolean flag.
"""
self._initialize_path(file_path)
def _initialize_path(self, file_path: Union[bool, str]):
def _initialize_path(self, file_path: bool | str) -> None:
"""Initialize the file path based on the input type.
Args:
file_path: Path to the log file or boolean flag.
Raises:
ValueError: If file_path is neither a string nor a boolean.
"""
if file_path is True: # File path is boolean True
self._path = os.path.join(os.curdir, "logs.txt")
elif isinstance(file_path, str): # File path is a string
if file_path.endswith((".json", ".txt")):
self._path = file_path # No modification if the file ends with .json or .txt
self._path = (
file_path # No modification if the file ends with .json or .txt
)
else:
self._path = file_path + ".txt" # Append .txt if the file doesn't end with .json or .txt
self._path = (
file_path + ".txt"
) # Append .txt if the file doesn't end with .json or .txt
else:
raise ValueError("file_path must be a string or boolean.") # Handle the case where file_path isn't valid
def log(self, **kwargs):
raise ValueError(
"file_path must be a string or boolean."
) # Handle the case where file_path isn't valid
def log(self, **kwargs: Unpack[LogEntry]) -> None:
"""Log data with structured fields.
Keyword Args:
task_name: Name of the task.
task: Description of the task.
agent: Name of the agent.
status: Status of the operation.
output: Output data.
input: Input data.
message: Log message.
level: Log level (e.g., INFO, ERROR).
crew: Name of the crew.
flow: Name of the flow.
tool: Name of the tool used.
error: Error message if any.
duration: Duration of the operation in seconds.
metadata: Additional metadata as a dictionary.
Raises:
ValueError: If logging fails.
"""
try:
now = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
log_entry = {"timestamp": now, **kwargs}
if self._path.endswith(".json"):
# Append log in JSON format
with open(self._path, "a", encoding="utf-8") as file:
# If the file is empty, start with a list; else, append to it
try:
# Try reading existing content to avoid overwriting
with open(self._path, "r", encoding="utf-8") as read_file:
existing_data = json.load(read_file)
existing_data.append(log_entry)
except (json.JSONDecodeError, FileNotFoundError):
# If no valid JSON or file doesn't exist, start with an empty list
existing_data = [log_entry]
with open(self._path, "w", encoding="utf-8") as write_file:
json.dump(existing_data, write_file, indent=4)
write_file.write("\n")
try:
# Try reading existing content to avoid overwriting
with open(self._path, encoding="utf-8") as read_file:
existing_data = json.load(read_file)
existing_data.append(log_entry)
except (json.JSONDecodeError, FileNotFoundError):
# If no valid JSON or file doesn't exist, start with an empty list
existing_data = [log_entry]
with open(self._path, "w", encoding="utf-8") as write_file:
json.dump(existing_data, write_file, indent=4)
write_file.write("\n")
else:
# Append log in plain text format
message = f"{now}: " + ", ".join([f"{key}=\"{value}\"" for key, value in kwargs.items()]) + "\n"
message = (
f"{now}: "
+ ", ".join([f'{key}="{value}"' for key, value in kwargs.items()])
+ "\n"
)
with open(self._path, "a", encoding="utf-8") as file:
file.write(message)
except Exception as e:
raise ValueError(f"Failed to log message: {str(e)}")
raise ValueError(f"Failed to log message: {e!s}") from e
class PickleHandler:
"""Handler for saving and loading data using pickle.
Attributes:
file_path: The path to the pickle file.
"""
def __init__(self, file_name: str) -> None:
"""
Initialize the PickleHandler with the name of the file where data will be stored.
"""Initialize the PickleHandler with the name of the file where data will be stored.
The file will be saved in the current directory.
Parameters:
- file_name (str): The name of the file for saving and loading data.
Args:
file_name: The name of the file for saving and loading data.
"""
if not file_name.endswith(".pkl"):
file_name += ".pkl"
@@ -74,34 +143,31 @@ class PickleHandler:
self.file_path = os.path.join(os.getcwd(), file_name)
def initialize_file(self) -> None:
"""
Initialize the file with an empty dictionary and overwrite any existing data.
"""
"""Initialize the file with an empty dictionary and overwrite any existing data."""
self.save({})
def save(self, data) -> None:
def save(self, data: Any) -> None:
"""
Save the data to the specified file using pickle.
Parameters:
- data (object): The data to be saved.
Args:
data: The data to be saved to the file.
"""
with open(self.file_path, "wb") as file:
pickle.dump(data, file)
with open(self.file_path, "wb") as f:
pickle.dump(obj=data, file=f)
def load(self) -> dict:
"""
Load the data from the specified file using pickle.
def load(self) -> Any:
"""Load the data from the specified file using pickle.
Returns:
- dict: The data loaded from the file.
The data loaded from the file.
"""
if not os.path.exists(self.file_path) or os.path.getsize(self.file_path) == 0:
return {} # Return an empty dictionary if the file does not exist or is empty
with open(self.file_path, "rb") as file:
try:
return pickle.load(file) # nosec
return pickle.load(file) # noqa: S301
except EOFError:
return {} # Return an empty dictionary if the file is empty or corrupted
except Exception:
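An illustrative call with the typed kwargs (import path assumed); a .json path appends to a JSON array, any other path falls back to plain-text lines:
from crewai.utilities.file_handler import FileHandler  # path assumed

handler = FileHandler("logs.json")
handler.log(
    task_name="research",
    agent="Researcher",
    status="completed",
    duration=12.5,
)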

View File

@@ -1,4 +1,7 @@
from typing import TYPE_CHECKING, List, Union
from __future__ import annotations
from typing import TYPE_CHECKING, Final
from crewai.utilities.constants import _NotSpecified
if TYPE_CHECKING:
@@ -6,17 +9,31 @@ if TYPE_CHECKING:
from crewai.tasks.task_output import TaskOutput
def aggregate_raw_outputs_from_task_outputs(task_outputs: List["TaskOutput"]) -> str:
"""Generate string context from the task outputs."""
dividers = "\n\n----------\n\n"
# Join task outputs with dividers
context = dividers.join(output.raw for output in task_outputs)
return context
DIVIDERS: Final[str] = "\n\n----------\n\n"
def aggregate_raw_outputs_from_tasks(tasks: Union[List["Task"],_NotSpecified]) -> str:
"""Generate string context from the tasks."""
def aggregate_raw_outputs_from_task_outputs(task_outputs: list[TaskOutput]) -> str:
"""Generate string context from the task outputs.
Args:
task_outputs: List of TaskOutput objects.
Returns:
A string containing the aggregated raw outputs from the task outputs.
"""
return DIVIDERS.join(output.raw for output in task_outputs)
def aggregate_raw_outputs_from_tasks(tasks: list[Task] | _NotSpecified) -> str:
"""Generate string context from the tasks.
Args:
tasks: List of Task objects or _NotSpecified.
Returns:
A string containing the aggregated raw outputs from the tasks.
"""
task_outputs = (
[task.output for task in tasks if task.output is not None]

View File

@@ -1,7 +1,14 @@
from collections.abc import Callable
from typing import Any
from __future__ import annotations
from pydantic import BaseModel, field_validator
from collections.abc import Callable
from typing import TYPE_CHECKING, Any
from pydantic import BaseModel, Field, field_validator
from typing_extensions import Self
if TYPE_CHECKING:
from crewai.lite_agent import LiteAgentOutput
from crewai.tasks.task_output import TaskOutput
class GuardrailResult(BaseModel):
@@ -12,18 +19,31 @@ class GuardrailResult(BaseModel):
be easily handled by the task execution system.
Attributes:
success (bool): Whether the guardrail validation passed
result (Any, optional): The validated/transformed result if successful
error (str, optional): Error message if validation failed
success: Whether the guardrail validation passed
result: The validated/transformed result if successful
error: Error message if validation failed
"""
success: bool
result: Any | None = None
error: str | None = None
success: bool = Field(description="Whether the guardrail validation passed")
result: Any | None = Field(
default=None, description="The validated/transformed result if successful"
)
error: str | None = Field(
default=None, description="Error message if validation failed"
)
@field_validator("result", "error")
@classmethod
def validate_result_error_exclusivity(cls, v: Any, info) -> Any:
"""Ensure that result and error are mutually exclusive based on success.
Args:
v: The value being validated (either result or error)
info: Validation info containing the entire model data
Returns:
The original value if validation passes
"""
values = info.data
if "success" in values:
if values["success"] and v and "error" in values and values["error"]:
@@ -37,15 +57,14 @@ class GuardrailResult(BaseModel):
return v
@classmethod
def from_tuple(cls, result: tuple[bool, Any | str]) -> "GuardrailResult":
def from_tuple(cls, result: tuple[bool, Any | str]) -> Self:
"""Create a GuardrailResult from a validation tuple.
Args:
result: A tuple of (success, data) where data is either
the validated result or error message.
result: A tuple of (success, data) where data is either the validated result or error message.
Returns:
GuardrailResult: A new instance with the tuple data.
A new instance with the tuple data.
"""
success, data = result
return cls(
@@ -56,7 +75,10 @@ class GuardrailResult(BaseModel):
def process_guardrail(
output: Any, guardrail: Callable, retry_count: int, event_source: Any | None = None
output: TaskOutput | LiteAgentOutput,
guardrail: Callable[[Any], tuple[bool, Any | str]],
retry_count: int,
event_source: Any | None = None,
) -> GuardrailResult:
"""Process the guardrail for the agent output.
@@ -68,7 +90,19 @@ def process_guardrail(
Returns:
GuardrailResult: The result of the guardrail validation
Raises:
TypeError: If output is not a TaskOutput or LiteAgentOutput
ValueError: If guardrail is None
"""
from crewai.lite_agent import LiteAgentOutput
from crewai.tasks.task_output import TaskOutput
if not isinstance(output, (TaskOutput, LiteAgentOutput)):
raise TypeError("Output must be a TaskOutput or LiteAgentOutput")
if guardrail is None:
raise ValueError("Guardrail must not be None")
from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.llm_guardrail_events import (
LLMGuardrailCompletedEvent,
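A hedged sketch of the (success, data) tuple contract consumed by from_tuple, assuming it maps data to result on success and to error on failure, as the field validator implies (import path assumed):
from crewai.utilities.guardrail import GuardrailResult  # module path assumed

passed = GuardrailResult.from_tuple((True, "validated output"))
failed = GuardrailResult.from_tuple((False, "output was empty"))
print(passed.success, passed.result)  # True validated output
print(failed.success, failed.error)   # False output was empty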

View File

@@ -1,36 +1,51 @@
import json
import os
from typing import Dict, Optional, Union
from pydantic import BaseModel, Field, PrivateAttr, model_validator
"""Internationalization support for CrewAI prompts and messages."""
import json
import os
from typing import Literal
from pydantic import BaseModel, Field, PrivateAttr, model_validator
from typing_extensions import Self
class I18N(BaseModel):
"""Handles loading and retrieving internationalized prompts."""
_prompts: Dict[str, Dict[str, str]] = PrivateAttr()
prompt_file: Optional[str] = Field(
"""Handles loading and retrieving internationalized prompts.
Attributes:
_prompts: Internal dictionary storing loaded prompts.
prompt_file: Optional path to a custom JSON file containing prompts.
"""
_prompts: dict[str, dict[str, str]] = PrivateAttr()
prompt_file: str | None = Field(
default=None,
description="Path to the prompt_file file to load",
)
@model_validator(mode="after")
def load_prompts(self) -> "I18N":
"""Load prompts from a JSON file."""
def load_prompts(self) -> Self:
"""Load prompts from a JSON file.
Returns:
The I18N instance with loaded prompts.
Raises:
Exception: If the prompt file is not found or cannot be decoded.
"""
try:
if self.prompt_file:
with open(self.prompt_file, "r", encoding="utf-8") as f:
with open(self.prompt_file, encoding="utf-8") as f:
self._prompts = json.load(f)
else:
dir_path = os.path.dirname(os.path.realpath(__file__))
prompts_path = os.path.join(dir_path, "../translations/en.json")
with open(prompts_path, "r", encoding="utf-8") as f:
with open(prompts_path, encoding="utf-8") as f:
self._prompts = json.load(f)
except FileNotFoundError:
raise Exception(f"Prompt file '{self.prompt_file}' not found.")
except json.JSONDecodeError:
raise Exception("Error decoding JSON from the prompts file.")
except FileNotFoundError as e:
raise Exception(f"Prompt file '{self.prompt_file}' not found.") from e
except json.JSONDecodeError as e:
raise Exception("Error decoding JSON from the prompts file.") from e
if not self._prompts:
self._prompts = {}
@@ -38,16 +53,58 @@ class I18N(BaseModel):
return self
def slice(self, slice: str) -> str:
"""Retrieve a prompt slice by key.
Args:
slice: The key of the prompt slice to retrieve.
Returns:
The prompt slice as a string.
"""
return self.retrieve("slices", slice)
def errors(self, error: str) -> str:
"""Retrieve an error message by key.
Args:
error: The key of the error message to retrieve.
Returns:
The error message as a string.
"""
return self.retrieve("errors", error)
def tools(self, tool: str) -> Union[str, Dict[str, str]]:
def tools(self, tool: str) -> str | dict[str, str]:
"""Retrieve a tool prompt by key.
Args:
tool: The key of the tool prompt to retrieve.
Returns:
The tool prompt as a string or dictionary.
"""
return self.retrieve("tools", tool)
def retrieve(self, kind, key) -> str:
def retrieve(
self,
kind: Literal[
"slices", "errors", "tools", "reasoning", "hierarchical_manager_agent"
],
key: str,
) -> str:
"""Retrieve a prompt by kind and key.
Args:
kind: The kind of prompt.
key: The key of the specific prompt to retrieve.
Returns:
The prompt as a string.
Raises:
Exception: If the prompt for the given kind and key is not found.
"""
try:
return self._prompts[kind][key]
except Exception as _:
raise Exception(f"Prompt for '{kind}':'{key}' not found.")
except Exception as e:
raise Exception(f"Prompt for '{kind}':'{key}' not found.") from e

View File

@@ -7,8 +7,6 @@ from types import ModuleType
class OptionalDependencyError(ImportError):
"""Exception raised when an optional dependency is not installed."""
pass
def require(name: str, *, purpose: str) -> ModuleType:
"""Import a module, raising a helpful error if it's not installed.

View File

@@ -1,43 +1,98 @@
import warnings
from typing import Any, Optional, Type
from __future__ import annotations
from typing import TYPE_CHECKING, Any, Generic, TypeGuard, TypeVar
from pydantic import BaseModel
if TYPE_CHECKING:
from crewai.agent import Agent
from crewai.llm import LLM
from crewai.llms.base_llm import BaseLLM
from crewai.utilities.logger_utils import suppress_warnings
from crewai.utilities.types import LLMMessage
T = TypeVar("T", bound=BaseModel)
class InternalInstructor:
"""Class that wraps an agent llm with instructor."""
def _is_valid_llm(llm: Any) -> TypeGuard[str | LLM | BaseLLM]:
"""Type guard to ensure LLM is valid and not None.
Args:
llm: The LLM to validate
Returns:
True if LLM is valid (string or has model attribute), False otherwise
"""
return llm is not None and (isinstance(llm, str) or hasattr(llm, "model"))
class InternalInstructor(Generic[T]):
"""Class that wraps an agent LLM with instructor for structured output generation.
Attributes:
content: The content to be processed
model: The Pydantic model class for the response
agent: The agent with LLM
llm: The LLM instance or model name
"""
def __init__(
self,
content: str,
model: Type,
agent: Optional[Any] = None,
llm: Optional[str] = None,
):
model: type[T],
agent: Agent | None = None,
llm: LLM | BaseLLM | str | None = None,
) -> None:
"""Initialize InternalInstructor.
Args:
content: The content to be processed
model: The Pydantic model class for the response
agent: The agent with LLM
llm: The LLM instance or model name
"""
self.content = content
self.agent = agent
self.llm = llm
self.model = model
self._client = None
self.set_instructor()
self.llm = llm or (agent.function_calling_llm or agent.llm if agent else None)
def set_instructor(self):
"""Set instructor."""
if self.agent and not self.llm:
self.llm = self.agent.function_calling_llm or self.agent.llm
with warnings.catch_warnings():
warnings.simplefilter("ignore", UserWarning)
with suppress_warnings():
import instructor
from litellm import completion
self._client = instructor.from_litellm(completion)
def to_json(self):
model = self.to_pydantic()
return model.model_dump_json(indent=2)
def to_json(self) -> str:
"""Convert the structured output to JSON format.
def to_pydantic(self):
messages = [{"role": "user", "content": self.content}]
model = self._client.chat.completions.create(
model=self.llm.model, response_model=self.model, messages=messages
Returns:
JSON string representation of the structured output
"""
pydantic_model = self.to_pydantic()
return pydantic_model.model_dump_json(indent=2)
def to_pydantic(self) -> T:
"""Generate structured output using the specified Pydantic model.
Returns:
Instance of the specified Pydantic model with structured data
Raises:
ValueError: If LLM is not provided or invalid
"""
messages: list[LLMMessage] = [{"role": "user", "content": self.content}]
if not _is_valid_llm(self.llm):
raise ValueError(
"LLM must be provided and have a model attribute or be a string"
)
if isinstance(self.llm, str):
model_name = self.llm
else:
model_name = self.llm.model
return self._client.chat.completions.create(
model=model_name, response_model=self.model, messages=messages
)
return model
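A hypothetical sketch of the generic wrapper (import path and model name assumed; actually running it requires a configured LLM provider key):
from pydantic import BaseModel
from crewai.utilities.internal_instructor import InternalInstructor  # path assumed

class City(BaseModel):
    name: str
    country: str

instructor = InternalInstructor(
    content="Paris is the capital of France.",
    model=City,
    llm="gpt-4o-mini",  # model name assumed
)
print(instructor.to_json())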

View File

@@ -1,62 +1,55 @@
import logging
import os
from typing import Any, Dict, List, Optional, Union
from typing import Any, Final
from crewai.cli.constants import DEFAULT_LLM_MODEL, ENV_VARS, LITELLM_PARAMS
from crewai.llm import LLM, BaseLLM
from crewai.llm import LLM
from crewai.llms.base_llm import BaseLLM
logger = logging.getLogger(__name__)
def create_llm(
llm_value: Union[str, LLM, Any, None] = None,
) -> Optional[LLM | BaseLLM]:
"""
Creates or returns an LLM instance based on the given llm_value.
llm_value: str | LLM | Any | None = None,
) -> LLM | BaseLLM | None:
"""Creates or returns an LLM instance based on the given llm_value.
Args:
llm_value (str | BaseLLM | Any | None):
- str: The model name (e.g., "gpt-4").
- BaseLLM: Already instantiated BaseLLM (including LLM), returned as-is.
- Any: Attempt to extract known attributes like model_name, temperature, etc.
- None: Use environment-based or fallback default model.
llm_value: LLM instance, model name string, None, or an object with LLM attributes.
Returns:
A BaseLLM instance if successful, or None if something fails.
"""
# 1) If llm_value is already a BaseLLM or LLM object, return it directly
if isinstance(llm_value, LLM) or isinstance(llm_value, BaseLLM):
if isinstance(llm_value, (LLM, BaseLLM)):
return llm_value
# 2) If llm_value is a string (model name)
if isinstance(llm_value, str):
try:
created_llm = LLM(model=llm_value)
return created_llm
return LLM(model=llm_value)
except Exception as e:
print(f"Failed to instantiate LLM with model='{llm_value}': {e}")
logger.debug(f"Failed to instantiate LLM with model='{llm_value}': {e}")
return None
# 3) If llm_value is None, parse environment variables or use default
if llm_value is None:
return _llm_via_environment_or_fallback()
# 4) Otherwise, attempt to extract relevant attributes from an unknown object
try:
# Extract attributes with explicit types
model = (
getattr(llm_value, "model", None)
or getattr(llm_value, "model_name", None)
or getattr(llm_value, "deployment_name", None)
or str(llm_value)
)
temperature: Optional[float] = getattr(llm_value, "temperature", None)
max_tokens: Optional[int] = getattr(llm_value, "max_tokens", None)
logprobs: Optional[int] = getattr(llm_value, "logprobs", None)
timeout: Optional[float] = getattr(llm_value, "timeout", None)
api_key: Optional[str] = getattr(llm_value, "api_key", None)
base_url: Optional[str] = getattr(llm_value, "base_url", None)
api_base: Optional[str] = getattr(llm_value, "api_base", None)
temperature: float | None = getattr(llm_value, "temperature", None)
max_tokens: int | None = getattr(llm_value, "max_tokens", None)
logprobs: int | None = getattr(llm_value, "logprobs", None)
timeout: float | None = getattr(llm_value, "timeout", None)
api_key: str | None = getattr(llm_value, "api_key", None)
base_url: str | None = getattr(llm_value, "base_url", None)
api_base: str | None = getattr(llm_value, "api_base", None)
created_llm = LLM(
return LLM(
model=model,
temperature=temperature,
max_tokens=max_tokens,
@@ -66,15 +59,23 @@ def create_llm(
base_url=base_url,
api_base=api_base,
)
return created_llm
except Exception as e:
print(f"Error instantiating LLM from unknown object type: {e}")
logger.debug(f"Error instantiating LLM from unknown object type: {e}")
return None
def _llm_via_environment_or_fallback() -> Optional[LLM]:
"""
Helper function: if llm_value is None, we load environment variables or fallback default model.
UNACCEPTED_ATTRIBUTES: Final[list[str]] = [
"AWS_ACCESS_KEY_ID",
"AWS_SECRET_ACCESS_KEY",
"AWS_REGION_NAME",
]
def _llm_via_environment_or_fallback() -> LLM | None:
"""Creates an LLM instance based on environment variables or defaults.
Returns:
A BaseLLM instance if successful, or None if something fails.
"""
model_name = (
os.environ.get("MODEL")
@@ -83,28 +84,25 @@ def _llm_via_environment_or_fallback() -> Optional[LLM]:
or DEFAULT_LLM_MODEL
)
# Initialize parameters with correct types
model: str = model_name
temperature: Optional[float] = None
max_tokens: Optional[int] = None
max_completion_tokens: Optional[int] = None
logprobs: Optional[int] = None
timeout: Optional[float] = None
api_key: Optional[str] = None
base_url: Optional[str] = None
api_version: Optional[str] = None
presence_penalty: Optional[float] = None
frequency_penalty: Optional[float] = None
top_p: Optional[float] = None
n: Optional[int] = None
stop: Optional[Union[str, List[str]]] = None
logit_bias: Optional[Dict[int, float]] = None
response_format: Optional[Dict[str, Any]] = None
seed: Optional[int] = None
top_logprobs: Optional[int] = None
callbacks: List[Any] = []
temperature: float | None = None
max_tokens: int | None = None
max_completion_tokens: int | None = None
logprobs: int | None = None
timeout: float | None = None
api_key: str | None = None
api_version: str | None = None
presence_penalty: float | None = None
frequency_penalty: float | None = None
top_p: float | None = None
n: int | None = None
stop: str | list[str] | None = None
logit_bias: dict[int, float] | None = None
response_format: dict[str, Any] | None = None
seed: int | None = None
top_logprobs: int | None = None
callbacks: list[Any] = []
# Optional base URL from env
base_url = (
os.environ.get("BASE_URL")
or os.environ.get("OPENAI_API_BASE")
@@ -119,8 +117,7 @@ def _llm_via_environment_or_fallback() -> Optional[LLM]:
elif api_base and not base_url:
base_url = api_base
# Initialize llm_params dictionary
llm_params: Dict[str, Any] = {
llm_params: dict[str, Any] = {
"model": model,
"temperature": temperature,
"max_tokens": max_tokens,
@@ -143,11 +140,6 @@ def _llm_via_environment_or_fallback() -> Optional[LLM]:
"callbacks": callbacks,
}
UNACCEPTED_ATTRIBUTES = [
"AWS_ACCESS_KEY_ID",
"AWS_SECRET_ACCESS_KEY",
"AWS_REGION_NAME",
]
set_provider = model_name.partition("/")[0] if "/" in model_name else "openai"
if set_provider in ENV_VARS:
@@ -167,28 +159,26 @@ def _llm_via_environment_or_fallback() -> Optional[LLM]:
if key not in ["prompt", "key_name", "default"]:
llm_params[key.lower()] = value
else:
print(
logger.debug(
f"Expected env_var to be a dictionary, but got {type(env_var)}"
)
# Remove None values
llm_params = {k: v for k, v in llm_params.items() if v is not None}
# Try creating the LLM
try:
new_llm = LLM(**llm_params)
return new_llm
return LLM(**llm_params)
except Exception as e:
print(
logger.debug(
f"Error instantiating LLM from environment/fallback: {type(e).__name__}: {e}"
)
return None
def _normalize_key_name(key_name: str) -> str:
"""
Maps environment variable names to recognized litellm parameter keys,
using patterns from LITELLM_PARAMS.
"""Maps environment variable names to recognized litellm parameter keys.
Args:
key_name: The environment variable name to normalize.
"""
for pattern in LITELLM_PARAMS:
if pattern in key_name:
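The three accepted inputs, sketched (import path assumed); failures now log at debug level and return None instead of printing:
from crewai.llm import LLM
from crewai.utilities.llm_utils import create_llm  # path assumed

from_name = create_llm("gpt-4o-mini")          # str -> new LLM instance
passthrough = create_llm(LLM(model="gpt-4o"))  # LLM/BaseLLM -> returned as-is
from_env = create_llm(None)                    # env-configured or default model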

View File

@@ -2,19 +2,34 @@ from datetime import datetime
from pydantic import BaseModel, Field, PrivateAttr
from crewai.utilities.printer import Printer
from crewai.utilities.printer import ColoredText, Printer, PrinterColor
class Logger(BaseModel):
verbose: bool = Field(default=False)
verbose: bool = Field(
default=False,
description="Enables verbose logging with timestamps",
)
default_color: PrinterColor = Field(
default="bold_yellow",
description="Default color for log messages",
)
_printer: Printer = PrivateAttr(default_factory=Printer)
default_color: str = Field(default="bold_yellow")
def log(self, level, message, color=None):
if color is None:
color = self.default_color
def log(self, level: str, message: str, color: PrinterColor | None = None) -> None:
"""Log a message with timestamp if verbose mode is enabled.
Args:
level: The log level (e.g., 'info', 'warning', 'error').
message: The message to log.
color: Optional color for the message. Defaults to default_color.
"""
if self.verbose:
timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
timestamp: str = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
self._printer.print(
f"\n[{timestamp}][{level.upper()}]: {message}", color=color
[
ColoredText(f"\n[{timestamp}]", "cyan"),
ColoredText(f"[{level.upper()}]: ", "yellow"),
ColoredText(message, color or self.default_color),
]
)
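Usage sketch (import path assumed); with verbose enabled, the timestamp, level, and message are rendered as separate ColoredText segments:
from crewai.utilities.logger import Logger  # path assumed

logger = Logger(verbose=True)
logger.log("info", "Crew kickoff started")
logger.log("error", "Task failed", color="red")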

View File

@@ -47,10 +47,12 @@ def suppress_warnings() -> Generator[None, None, None]:
None during the context execution.
Note:
There is a similar implementation in src/crewai/llm.py that also
suppresses a specific deprecation warning. That version may be
consolidated here in the future.
This implementation consolidates warning suppression used throughout
the codebase, including specific deprecation warnings from dependencies.
"""
with warnings.catch_warnings():
warnings.filterwarnings("ignore")
warnings.filterwarnings(
"ignore", message="open_text is deprecated*", category=DeprecationWarning
)
yield
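This matches how the LLM module wraps its noisy third-party imports:
from crewai.utilities.logger_utils import suppress_warnings

with suppress_warnings():
    import litellm  # dependency deprecation warnings are silenced here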

View File

@@ -1,31 +0,0 @@
import re
class YamlParser:
@staticmethod
def parse(file):
"""
Parses a YAML file, modifies specific patterns, and checks for unsupported 'context' usage.
Args:
file (file object): The YAML file to parse.
Returns:
str: The modified content of the YAML file.
Raises:
ValueError: If 'context:' is used incorrectly.
"""
content = file.read()
# Replace single { and } with doubled ones, while leaving already doubled ones intact and the other special characters {# and {%
modified_content = re.sub(r"(?<!\{){(?!\{)(?!\#)(?!\%)", "{{", content)
modified_content = re.sub(
r"(?<!\})(?<!\%)(?<!\#)\}(?!})", "}}", modified_content
)
# Check for 'context:' not followed by '[' and raise an error
if re.search(r"context:(?!\s*\[)", modified_content):
raise ValueError(
"Context is currently only supported in code when creating a task. "
"Please use the 'context' key in the task configuration."
)
return modified_content

View File

@@ -1,9 +1,10 @@
"""Path management utilities for CrewAI storage and configuration."""
import os
from pathlib import Path
import appdirs
"""Path management utilities for CrewAI storage and configuration."""
def db_storage_path() -> str:
"""Returns the path for SQLite database storage.
@@ -19,13 +20,6 @@ def db_storage_path() -> str:
return str(data_dir)
def get_project_directory_name():
def get_project_directory_name() -> str:
"""Returns the current project directory name."""
project_directory_name = os.environ.get("CREWAI_STORAGE_DIR")
if project_directory_name:
return project_directory_name
else:
cwd = Path.cwd()
project_directory_name = cwd.name
return project_directory_name
return os.environ.get("CREWAI_STORAGE_DIR", Path.cwd().name)
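Quick sketch (import path assumed):
from crewai.utilities.paths import db_storage_path, get_project_directory_name  # path assumed

print(db_storage_path())             # appdirs-based data directory for SQLite storage
print(get_project_directory_name())  # CREWAI_STORAGE_DIR if set, else the cwd name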

View File

@@ -1,16 +1,19 @@
"""Handles planning and coordination of crew tasks."""
import logging
from typing import Any, List, Optional
from pydantic import BaseModel, Field
from crewai.agent import Agent
from crewai.llms.base_llm import BaseLLM
from crewai.task import Task
"""Handles planning and coordination of crew tasks."""
logger = logging.getLogger(__name__)
class PlanPerTask(BaseModel):
"""Represents a plan for a specific task."""
task: str = Field(..., description="The task for which the plan is created")
plan: str = Field(
...,
@@ -20,28 +23,48 @@ class PlanPerTask(BaseModel):
class PlannerTaskPydanticOutput(BaseModel):
"""Output format for task planning results."""
list_of_plans_per_task: List[PlanPerTask] = Field(
list_of_plans_per_task: list[PlanPerTask] = Field(
...,
description="Step by step plan on how the agents can execute their tasks using the available tools with mastery",
)
class CrewPlanner:
"""Plans and coordinates the execution of crew tasks."""
def __init__(self, tasks: List[Task], planning_agent_llm: Optional[Any] = None):
self.tasks = tasks
"""Plans and coordinates the execution of crew tasks.
if planning_agent_llm is None:
self.planning_agent_llm = "gpt-4o-mini"
else:
self.planning_agent_llm = planning_agent_llm
Attributes:
tasks: List of tasks to be planned.
planning_agent_llm: Optional LLM model for the planning agent.
"""
def __init__(
self, tasks: list[Task], planning_agent_llm: str | BaseLLM | None = None
) -> None:
"""Initialize CrewPlanner with tasks and optional planning agent LLM.
Args:
tasks: List of tasks to be planned.
planning_agent_llm: Optional LLM model for the planning agent. Defaults to None.
"""
self.tasks = tasks
self.planning_agent_llm = planning_agent_llm or "gpt-4o-mini"
def _handle_crew_planning(self) -> PlannerTaskPydanticOutput:
"""Handles the Crew planning by creating detailed step-by-step plans for each task."""
"""Handles the Crew planning by creating detailed step-by-step plans for each task.
Returns:
A PlannerTaskPydanticOutput containing the detailed plans for each task.
Raises:
ValueError: If the planning output cannot be obtained.
"""
planning_agent = self._create_planning_agent()
tasks_summary = self._create_tasks_summary()
planner_task = self._create_planner_task(planning_agent, tasks_summary)
planner_task = self._create_planner_task(
planning_agent=planning_agent, tasks_summary=tasks_summary
)
result = planner_task.execute_sync()
@@ -51,7 +74,11 @@ class CrewPlanner:
raise ValueError("Failed to get the Planning output")
def _create_planning_agent(self) -> Agent:
"""Creates the planning agent for the crew planning."""
"""Creates the planning agent for the crew planning.
Returns:
An Agent instance configured for planning tasks.
"""
return Agent(
role="Task Execution Planner",
goal=(
@@ -62,8 +89,17 @@ class CrewPlanner:
llm=self.planning_agent_llm,
)
def _create_planner_task(self, planning_agent: Agent, tasks_summary: str) -> Task:
"""Creates the planner task using the given agent and tasks summary."""
@staticmethod
def _create_planner_task(planning_agent: Agent, tasks_summary: str) -> Task:
"""Creates the planner task using the given agent and tasks summary.
Args:
planning_agent: The agent responsible for planning.
tasks_summary: A summary of all tasks to be included in the planning.
Returns:
A Task instance configured for planning.
"""
return Task(
description=(
f"Based on these tasks summary: {tasks_summary} \n Create the most descriptive plan based on the tasks "
@@ -74,31 +110,42 @@ class CrewPlanner:
output_pydantic=PlannerTaskPydanticOutput,
)
def _get_agent_knowledge(self, task: Task) -> List[str]:
"""
Safely retrieve knowledge source content from the task's agent.
@staticmethod
def _get_agent_knowledge(task: Task) -> list[str]:
"""Safely retrieve knowledge source content from the task's agent.
Args:
task: The task containing an agent with potential knowledge sources
Returns:
List[str]: A list of knowledge source strings
A list of knowledge source strings
"""
try:
if task.agent and task.agent.knowledge_sources:
return [source.content for source in task.agent.knowledge_sources]
return [
getattr(source, "content", str(source))
for source in task.agent.knowledge_sources
]
except AttributeError:
logger.warning("Error accessing agent knowledge sources")
return []
def _create_tasks_summary(self) -> str:
"""Creates a summary of all tasks."""
"""Creates a summary of all tasks.
Returns:
A string summarizing all tasks with their details.
"""
tasks_summary = []
for idx, task in enumerate(self.tasks):
knowledge_list = self._get_agent_knowledge(task)
agent_tools = (
f"[{', '.join(str(tool) for tool in task.agent.tools)}]" if task.agent and task.agent.tools else '"agent has no tools"',
f',\n "agent_knowledge": "[\\"{knowledge_list[0]}\\"]"' if knowledge_list and str(knowledge_list) != "None" else ""
f"[{', '.join(str(tool) for tool in task.agent.tools)}]"
if task.agent and task.agent.tools
else '"agent has no tools"',
f',\n "agent_knowledge": "[\\"{knowledge_list[0]}\\"]"'
if knowledge_list and str(knowledge_list) != "None"
else "",
)
task_summary = f"""
Task Number {idx + 1} - {task.description}

View File

@@ -1,71 +1,72 @@
"""Utility for colored console output."""
from typing import Optional
from typing import Final, Literal, NamedTuple
PrinterColor = Literal[
"purple",
"bold_purple",
"green",
"bold_green",
"cyan",
"bold_cyan",
"magenta",
"bold_magenta",
"yellow",
"bold_yellow",
"red",
"blue",
"bold_blue",
]
_COLOR_CODES: Final[dict[PrinterColor, str]] = {
"purple": "\033[95m",
"bold_purple": "\033[1m\033[95m",
"red": "\033[91m",
"bold_green": "\033[1m\033[92m",
"green": "\033[32m",
"blue": "\033[94m",
"bold_blue": "\033[1m\033[94m",
"yellow": "\033[93m",
"bold_yellow": "\033[1m\033[93m",
"cyan": "\033[96m",
"bold_cyan": "\033[1m\033[96m",
"magenta": "\033[35m",
"bold_magenta": "\033[1m\033[35m",
}
RESET: Final[str] = "\033[0m"
class ColoredText(NamedTuple):
"""Represents text with an optional color for console output.
Attributes:
text: The text content to be printed.
color: Optional color for the text, specified as a PrinterColor.
"""
text: str
color: PrinterColor | None
class Printer:
"""Handles colored console output formatting."""
def print(self, content: str, color: Optional[str] = None):
if color == "purple":
self._print_purple(content)
elif color == "red":
self._print_red(content)
elif color == "bold_green":
self._print_bold_green(content)
elif color == "bold_purple":
self._print_bold_purple(content)
elif color == "bold_blue":
self._print_bold_blue(content)
elif color == "yellow":
self._print_yellow(content)
elif color == "bold_yellow":
self._print_bold_yellow(content)
elif color == "cyan":
self._print_cyan(content)
elif color == "bold_cyan":
self._print_bold_cyan(content)
elif color == "magenta":
self._print_magenta(content)
elif color == "bold_magenta":
self._print_bold_magenta(content)
elif color == "green":
self._print_green(content)
else:
print(content)
@staticmethod
def print(
content: str | list[ColoredText], color: PrinterColor | None = None
) -> None:
"""Prints content to the console with optional color formatting.
def _print_bold_purple(self, content):
print("\033[1m\033[95m {}\033[00m".format(content))
def _print_bold_green(self, content):
print("\033[1m\033[92m {}\033[00m".format(content))
def _print_purple(self, content):
print("\033[95m {}\033[00m".format(content))
def _print_red(self, content):
print("\033[91m {}\033[00m".format(content))
def _print_bold_blue(self, content):
print("\033[1m\033[94m {}\033[00m".format(content))
def _print_yellow(self, content):
print("\033[93m {}\033[00m".format(content))
def _print_bold_yellow(self, content):
print("\033[1m\033[93m {}\033[00m".format(content))
def _print_cyan(self, content):
print("\033[96m {}\033[00m".format(content))
def _print_bold_cyan(self, content):
print("\033[1m\033[96m {}\033[00m".format(content))
def _print_magenta(self, content):
print("\033[35m {}\033[00m".format(content))
def _print_bold_magenta(self, content):
print("\033[1m\033[35m {}\033[00m".format(content))
def _print_green(self, content):
print("\033[32m {}\033[00m".format(content))
Args:
content: Either a string or a list of ColoredText objects for multicolor output.
color: Optional color for the text when content is a string. Ignored when content is a list.
"""
if isinstance(content, str):
content = [ColoredText(content, color)]
print(
"".join(
f"{_COLOR_CODES[c.color] if c.color else ''}{c.text}{RESET}"
for c in content
)
)
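Both call styles supported by the consolidated print, sketched:
from crewai.utilities.printer import ColoredText, Printer

Printer.print("All tasks completed", color="bold_green")
Printer.print(
    [
        ColoredText("[crew] ", "cyan"),
        ColoredText("planning ", "yellow"),
        ColoredText("done", "bold_green"),
    ]
)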

View File

@@ -1,29 +1,59 @@
from typing import Any, Optional
from __future__ import annotations
from typing import Any, TypedDict
from pydantic import BaseModel, Field
from crewai.utilities import I18N
from crewai.utilities.i18n import I18N
class StandardPromptResult(TypedDict):
"""Result with only prompt field for standard mode."""
prompt: str
class SystemPromptResult(StandardPromptResult):
"""Result with system, user, and prompt fields for system prompt mode."""
system: str
user: str
class Prompts(BaseModel):
"""Manages and generates prompts for a generic agent."""
i18n: I18N = Field(default=I18N())
has_tools: bool = False
system_template: Optional[str] = None
prompt_template: Optional[str] = None
response_template: Optional[str] = None
use_system_prompt: Optional[bool] = False
agent: Any
i18n: I18N = Field(default_factory=I18N)
has_tools: bool = Field(
default=False, description="Indicates if the agent has access to tools"
)
system_template: str | None = Field(
default=None, description="Custom system prompt template"
)
prompt_template: str | None = Field(
default=None, description="Custom user prompt template"
)
response_template: str | None = Field(
default=None, description="Custom response prompt template"
)
use_system_prompt: bool | None = Field(
default=False,
description="Whether to use the system prompt when no custom templates are provided",
)
agent: Any = Field(description="Reference to the agent using these prompts")
def task_execution(self) -> dict[str, str]:
"""Generate a standard prompt for task execution."""
slices = ["role_playing"]
def task_execution(self) -> SystemPromptResult | StandardPromptResult:
"""Generate a standard prompt for task execution.
Returns:
A dictionary containing the constructed prompt(s).
"""
slices: list[str] = ["role_playing"]
if self.has_tools:
slices.append("tools")
else:
slices.append("no_tools")
system = self._build_prompt(slices)
system: str = self._build_prompt(slices)
slices.append("task")
if (
@@ -31,54 +61,67 @@ class Prompts(BaseModel):
and not self.prompt_template
and self.use_system_prompt
):
return {
"system": system,
"user": self._build_prompt(["task"]),
"prompt": self._build_prompt(slices),
}
else:
return {
"prompt": self._build_prompt(
slices,
self.system_template,
self.prompt_template,
self.response_template,
)
}
return SystemPromptResult(
system=system,
user=self._build_prompt(["task"]),
prompt=self._build_prompt(slices),
)
return StandardPromptResult(
prompt=self._build_prompt(
slices,
self.system_template,
self.prompt_template,
self.response_template,
)
)
def _build_prompt(
self,
components: list[str],
system_template=None,
prompt_template=None,
response_template=None,
system_template: str | None = None,
prompt_template: str | None = None,
response_template: str | None = None,
) -> str:
"""Constructs a prompt string from specified components."""
"""Constructs a prompt string from specified components.
Args:
components: List of component names to include in the prompt.
system_template: Optional custom template for the system prompt.
prompt_template: Optional custom template for the user prompt.
response_template: Optional custom template for the response prompt.
Returns:
The constructed prompt string.
"""
prompt: str
if not system_template or not prompt_template:
# If any of the required templates are missing, fall back to the default format
prompt_parts = [self.i18n.slice(component) for component in components]
prompt_parts: list[str] = [
self.i18n.slice(component) for component in components
]
prompt = "".join(prompt_parts)
else:
# All templates are provided, use them
prompt_parts = [
template_parts: list[str] = [
self.i18n.slice(component)
for component in components
if component != "task"
]
system = system_template.replace("{{ .System }}", "".join(prompt_parts))
system: str = system_template.replace(
"{{ .System }}", "".join(template_parts)
)
prompt = prompt_template.replace(
"{{ .Prompt }}", "".join(self.i18n.slice("task"))
)
# Handle missing response_template
if response_template:
response = response_template.split("{{ .Response }}")[0]
response: str = response_template.split("{{ .Response }}")[0]
prompt = f"{system}\n{prompt}\n{response}"
else:
prompt = f"{system}\n{prompt}"
prompt = (
return (
prompt.replace("{goal}", self.agent.goal)
.replace("{role}", self.agent.role)
.replace("{backstory}", self.agent.backstory)
)
return prompt
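task_execution now advertises its two possible shapes through the TypedDicts above. A minimal sketch of both branches, assuming an agent-like object exposing role, goal and backstory, and an assumed import path:

    from types import SimpleNamespace
    from crewai.utilities.prompts import Prompts  # path assumed for illustration

    agent = SimpleNamespace(role="Researcher", goal="Summarize findings", backstory="Veteran analyst")

    # No custom templates and use_system_prompt=True -> SystemPromptResult (system/user/prompt keys).
    with_system = Prompts(agent=agent, use_system_prompt=True).task_execution()

    # Otherwise -> StandardPromptResult (only a prompt key).
    standard = Prompts(agent=agent).task_execution()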

View File

@@ -1,57 +1,62 @@
from typing import Dict, List, Type, Union, get_args, get_origin
from typing import Any, Union, get_args, get_origin
from pydantic import BaseModel
from pydantic import BaseModel, Field
class PydanticSchemaParser(BaseModel):
model: Type[BaseModel]
model: type[BaseModel] = Field(..., description="The Pydantic model to parse.")
def get_schema(self) -> str:
"""
Public method to get the schema of a Pydantic model.
"""Public method to get the schema of a Pydantic model.
:return: String representation of the model schema.
Returns:
String representation of the model schema.
"""
return "{\n" + self._get_model_schema(self.model) + "\n}"
def _get_model_schema(self, model: Type[BaseModel], depth: int = 0) -> str:
indent = " " * 4 * depth
lines = [
f"{indent} {field_name}: {self._get_field_type(field, depth + 1)}"
def _get_model_schema(self, model: type[BaseModel], depth: int = 0) -> str:
"""Recursively get the schema of a Pydantic model, handling nested models and lists.
Args:
model: The Pydantic model to process.
depth: The current depth of recursion for indentation purposes.
Returns:
A string representation of the model schema.
"""
indent: str = " " * 4 * depth
lines: list[str] = [
f"{indent} {field_name}: {self._get_field_type_for_annotation(field.annotation, depth + 1)}"
for field_name, field in model.model_fields.items()
]
return ",\n".join(lines)
def _get_field_type(self, field, depth: int) -> str:
field_type = field.annotation
origin = get_origin(field_type)
def _format_list_type(self, list_item_type: Any, depth: int) -> str:
"""Format a List type, handling nested models if necessary.
if origin in {list, List}:
list_item_type = get_args(field_type)[0]
return self._format_list_type(list_item_type, depth)
Args:
list_item_type: The type of items in the list.
depth: The current depth of recursion for indentation purposes.
if origin in {dict, Dict}:
key_type, value_type = get_args(field_type)
return f"Dict[{key_type.__name__}, {value_type.__name__}]"
if origin is Union:
return self._format_union_type(field_type, depth)
if isinstance(field_type, type) and issubclass(field_type, BaseModel):
nested_schema = self._get_model_schema(field_type, depth)
nested_indent = " " * 4 * depth
return f"{field_type.__name__}\n{nested_indent}{{\n{nested_schema}\n{nested_indent}}}"
return field_type.__name__
def _format_list_type(self, list_item_type, depth: int) -> str:
Returns:
A string representation of the List type.
"""
if isinstance(list_item_type, type) and issubclass(list_item_type, BaseModel):
nested_schema = self._get_model_schema(list_item_type, depth + 1)
nested_indent = " " * 4 * (depth)
nested_indent = " " * 4 * depth
return f"List[\n{nested_indent}{{\n{nested_schema}\n{nested_indent}}}\n{nested_indent}]"
return f"List[{list_item_type.__name__}]"
def _format_union_type(self, field_type, depth: int) -> str:
def _format_union_type(self, field_type: Any, depth: int) -> str:
"""Format a Union type, handling Optional and nested types.
Args:
field_type: The Union type to format.
depth: The current depth of recursion for indentation purposes.
Returns:
A string representation of the Union type.
"""
args = get_args(field_type)
if type(None) in args:
# It's an Optional type
@@ -61,26 +66,32 @@ class PydanticSchemaParser(BaseModel):
non_none_args[0], depth
)
return f"Optional[{inner_type}]"
else:
# Union with None and multiple other types
inner_types = ", ".join(
self._get_field_type_for_annotation(arg, depth)
for arg in non_none_args
)
return f"Optional[Union[{inner_types}]]"
else:
# General Union type
# Union with None and multiple other types
inner_types = ", ".join(
self._get_field_type_for_annotation(arg, depth) for arg in args
self._get_field_type_for_annotation(arg, depth) for arg in non_none_args
)
return f"Union[{inner_types}]"
return f"Optional[Union[{inner_types}]]"
# General Union type
inner_types = ", ".join(
self._get_field_type_for_annotation(arg, depth) for arg in args
)
return f"Union[{inner_types}]"
def _get_field_type_for_annotation(self, annotation, depth: int) -> str:
origin = get_origin(annotation)
if origin in {list, List}:
def _get_field_type_for_annotation(self, annotation: Any, depth: int) -> str:
"""Recursively get the string representation of a field's type annotation.
Args:
annotation: The type annotation to process.
depth: The current depth of recursion for indentation purposes.
Returns:
A string representation of the type annotation.
"""
origin: Any = get_origin(annotation)
if origin is list:
list_item_type = get_args(annotation)[0]
return self._format_list_type(list_item_type, depth)
if origin in {dict, Dict}:
if origin is dict:
key_type, value_type = get_args(annotation)
return f"Dict[{key_type.__name__}, {value_type.__name__}]"
if origin is Union:
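A minimal sketch of the schema parser on a nested model, assuming the import path below; the emitted string is an indented, brace-wrapped outline rather than JSON:

    from pydantic import BaseModel
    from crewai.utilities.pydantic_schema_parser import PydanticSchemaParser  # path assumed

    class Address(BaseModel):
        city: str
        zip_code: str

    class Person(BaseModel):
        name: str
        age: int
        addresses: list[Address]

    print(PydanticSchemaParser(model=Person).get_schema())
    # name and age render as plain annotations; addresses renders as a nested List[ { ... } ] block.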

View File

@@ -1,19 +1,19 @@
import logging
import json
from typing import Tuple, cast
import logging
from typing import Any, Final, Literal, cast
from pydantic import BaseModel, Field
from crewai.agent import Agent
from crewai.task import Task
from crewai.utilities import I18N
from crewai.llm import LLM
from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.reasoning_events import (
AgentReasoningStartedEvent,
AgentReasoningCompletedEvent,
AgentReasoningFailedEvent,
AgentReasoningStartedEvent,
)
from crewai.llm import LLM
from crewai.task import Task
from crewai.utilities.i18n import I18N
class ReasoningPlan(BaseModel):
@@ -29,22 +29,49 @@ class AgentReasoningOutput(BaseModel):
plan: ReasoningPlan = Field(description="The reasoning plan for the task.")
class ReasoningFunction(BaseModel):
"""Model for function calling with reasoning."""
plan: str = Field(description="The detailed reasoning plan for the task.")
ready: bool = Field(description="Whether the agent is ready to execute the task.")
FUNCTION_SCHEMA: Final[dict[str, Any]] = {
"type": "function",
"function": {
"name": "create_reasoning_plan",
"description": "Create or refine a reasoning plan for a task",
"parameters": {
"type": "object",
"properties": {
"plan": {
"type": "string",
"description": "The detailed reasoning plan for the task.",
},
"ready": {
"type": "boolean",
"description": "Whether the agent is ready to execute the task.",
},
},
"required": ["plan", "ready"],
},
},
}
class AgentReasoning:
"""
Handles the agent reasoning process, enabling an agent to reflect and create a plan
before executing a task.
Attributes:
task: The task for which the agent is reasoning.
agent: The agent performing the reasoning.
llm: The language model used for reasoning.
logger: Logger for logging events and errors.
i18n: Internationalization utility for retrieving prompts.
"""
def __init__(self, task: Task, agent: Agent):
if not task or not agent:
raise ValueError("Both task and agent must be provided.")
def __init__(self, task: Task, agent: Agent) -> None:
"""Initialize the AgentReasoning with a task and an agent.
Args:
task: The task for which the agent is reasoning.
agent: The agent performing the reasoning.
"""
self.task = task
self.agent = agent
self.llm = cast(LLM, agent.llm)
@@ -52,9 +79,7 @@ class AgentReasoning:
self.i18n = I18N()
def handle_agent_reasoning(self) -> AgentReasoningOutput:
"""
Public method for the reasoning process that creates and refines a plan
for the task until the agent is ready to execute it.
"""Public method for the reasoning process that creates and refines a plan for the task until the agent is ready to execute it.
Returns:
AgentReasoningOutput: The output of the agent reasoning process.
@@ -70,7 +95,7 @@ class AgentReasoning:
from_task=self.task,
),
)
except Exception:
except Exception: # noqa: S110
# Ignore event bus errors to avoid breaking execution
pass
@@ -90,7 +115,7 @@ class AgentReasoning:
from_task=self.task,
),
)
except Exception:
except Exception: # noqa: S110
pass
return output
@@ -107,17 +132,16 @@ class AgentReasoning:
from_task=self.task,
),
)
except Exception:
except Exception: # noqa: S110
pass
raise
def __handle_agent_reasoning(self) -> AgentReasoningOutput:
"""
Private method that handles the agent reasoning process.
"""Private method that handles the agent reasoning process.
Returns:
AgentReasoningOutput: The output of the agent reasoning process.
The output of the agent reasoning process.
"""
plan, ready = self.__create_initial_plan()
@@ -126,46 +150,38 @@ class AgentReasoning:
reasoning_plan = ReasoningPlan(plan=plan, ready=ready)
return AgentReasoningOutput(plan=reasoning_plan)
def __create_initial_plan(self) -> Tuple[str, bool]:
"""
Creates the initial reasoning plan for the task.
def __create_initial_plan(self) -> tuple[str, bool]:
"""Creates the initial reasoning plan for the task.
Returns:
Tuple[str, bool]: The initial plan and whether the agent is ready to execute the task.
The initial plan and whether the agent is ready to execute the task.
"""
reasoning_prompt = self.__create_reasoning_prompt()
if self.llm.supports_function_calling():
plan, ready = self.__call_with_function(reasoning_prompt, "initial_plan")
return plan, ready
else:
system_prompt = self.i18n.retrieve("reasoning", "initial_plan").format(
role=self.agent.role,
goal=self.agent.goal,
backstory=self.__get_agent_backstory(),
)
response = _call_llm_with_reasoning_prompt(
llm=self.llm,
prompt=reasoning_prompt,
task=self.task,
agent=self.agent,
i18n=self.i18n,
backstory=self.__get_agent_backstory(),
plan_type="initial_plan",
)
response = self.llm.call(
[
{"role": "system", "content": system_prompt},
{"role": "user", "content": reasoning_prompt},
],
from_task=self.task,
from_agent=self.agent,
)
return self.__parse_reasoning_response(str(response))
return self.__parse_reasoning_response(str(response))
def __refine_plan_if_needed(self, plan: str, ready: bool) -> Tuple[str, bool]:
"""
Refines the reasoning plan if the agent is not ready to execute the task.
def __refine_plan_if_needed(self, plan: str, ready: bool) -> tuple[str, bool]:
"""Refines the reasoning plan if the agent is not ready to execute the task.
Args:
plan: The current reasoning plan.
ready: Whether the agent is ready to execute the task.
Returns:
Tuple[str, bool]: The refined plan and whether the agent is ready to execute the task.
The refined plan and whether the agent is ready to execute the task.
"""
attempt = 1
max_attempts = self.agent.max_reasoning_attempts
@@ -182,7 +198,7 @@ class AgentReasoning:
from_task=self.task,
),
)
except Exception:
except Exception: # noqa: S110
pass
refine_prompt = self.__create_refine_prompt(plan)
@@ -190,19 +206,14 @@ class AgentReasoning:
if self.llm.supports_function_calling():
plan, ready = self.__call_with_function(refine_prompt, "refine_plan")
else:
system_prompt = self.i18n.retrieve("reasoning", "refine_plan").format(
role=self.agent.role,
goal=self.agent.goal,
response = _call_llm_with_reasoning_prompt(
llm=self.llm,
prompt=refine_prompt,
task=self.task,
agent=self.agent,
i18n=self.i18n,
backstory=self.__get_agent_backstory(),
)
response = self.llm.call(
[
{"role": "system", "content": system_prompt},
{"role": "user", "content": refine_prompt},
],
from_task=self.task,
from_agent=self.agent,
plan_type="refine_plan",
)
plan, ready = self.__parse_reasoning_response(str(response))
@@ -216,41 +227,18 @@ class AgentReasoning:
return plan, ready
def __call_with_function(self, prompt: str, prompt_type: str) -> Tuple[str, bool]:
"""
Calls the LLM with function calling to get a reasoning plan.
def __call_with_function(self, prompt: str, prompt_type: str) -> tuple[str, bool]:
"""Calls the LLM with function calling to get a reasoning plan.
Args:
prompt: The prompt to send to the LLM.
prompt_type: The type of prompt (initial_plan or refine_plan).
Returns:
Tuple[str, bool]: A tuple containing the plan and whether the agent is ready.
A tuple containing the plan and whether the agent is ready.
"""
self.logger.debug(f"Using function calling for {prompt_type} reasoning")
function_schema = {
"type": "function",
"function": {
"name": "create_reasoning_plan",
"description": "Create or refine a reasoning plan for a task",
"parameters": {
"type": "object",
"properties": {
"plan": {
"type": "string",
"description": "The detailed reasoning plan for the task.",
},
"ready": {
"type": "boolean",
"description": "Whether the agent is ready to execute the task.",
},
},
"required": ["plan", "ready"],
},
},
}
try:
system_prompt = self.i18n.retrieve("reasoning", prompt_type).format(
role=self.agent.role,
@@ -259,7 +247,7 @@ class AgentReasoning:
)
# Prepare a simple callable that just returns the tool arguments as JSON
def _create_reasoning_plan(plan: str, ready: bool = True): # noqa: N802
def _create_reasoning_plan(plan: str, ready: bool = True):
"""Return the reasoning plan result in JSON string form."""
return json.dumps({"plan": plan, "ready": ready})
@@ -268,7 +256,7 @@ class AgentReasoning:
{"role": "system", "content": system_prompt},
{"role": "user", "content": prompt},
],
tools=[function_schema],
tools=[FUNCTION_SCHEMA],
available_functions={"create_reasoning_plan": _create_reasoning_plan},
from_task=self.task,
from_agent=self.agent,
@@ -291,7 +279,7 @@ class AgentReasoning:
except Exception as e:
self.logger.warning(
f"Error during function calling: {str(e)}. Falling back to text parsing."
f"Error during function calling: {e!s}. Falling back to text parsing."
)
try:
@@ -316,7 +304,7 @@ class AgentReasoning:
"READY: I am ready to execute the task." in fallback_str,
)
except Exception as inner_e:
self.logger.error(f"Error during fallback text parsing: {str(inner_e)}")
self.logger.error(f"Error during fallback text parsing: {inner_e!s}")
return (
"Failed to generate a plan due to an error.",
True,
@@ -378,7 +366,8 @@ class AgentReasoning:
current_plan=current_plan,
)
def __parse_reasoning_response(self, response: str) -> Tuple[str, bool]:
@staticmethod
def __parse_reasoning_response(response: str) -> tuple[str, bool]:
"""
Parses the reasoning response to extract the plan and whether
the agent is ready to execute the task.
@@ -387,7 +376,7 @@ class AgentReasoning:
response: The LLM response.
Returns:
Tuple[str, bool]: The plan and whether the agent is ready to execute the task.
The plan and whether the agent is ready to execute the task.
"""
if not response:
return "No plan was generated.", False
@@ -412,3 +401,43 @@ class AgentReasoning:
"The _handle_agent_reasoning method is deprecated. Use handle_agent_reasoning instead."
)
return self.handle_agent_reasoning()
def _call_llm_with_reasoning_prompt(
llm: LLM,
prompt: str,
task: Task,
agent: Agent,
i18n: I18N,
backstory: str,
plan_type: Literal["initial_plan", "refine_plan"],
) -> str:
"""Calls the LLM with the reasoning prompt.
Args:
llm: The language model to use.
prompt: The prompt to send to the LLM.
task: The task for which the agent is reasoning.
agent: The agent performing the reasoning.
i18n: Internationalization utility for retrieving prompts.
backstory: The agent's backstory.
plan_type: The type of plan being created ("initial_plan" or "refine_plan").
Returns:
The LLM response.
"""
system_prompt = i18n.retrieve("reasoning", plan_type).format(
role=agent.role,
goal=agent.goal,
backstory=backstory,
)
response = llm.call(
[
{"role": "system", "content": system_prompt},
{"role": "user", "content": prompt},
],
from_task=task,
from_agent=agent,
)
return str(response)
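For the function-calling branch, the extracted FUNCTION_SCHEMA constant is paired with a local callable that simply echoes the tool arguments back as JSON. A minimal sketch of that wiring, assuming llm, system_prompt and reasoning_prompt are already in scope as in __call_with_function above:

    import json

    def _create_reasoning_plan(plan: str, ready: bool = True) -> str:
        # Echo the tool arguments back so the caller can recover plan/ready from the response.
        return json.dumps({"plan": plan, "ready": ready})

    response = llm.call(
        [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": reasoning_prompt},
        ],
        tools=[FUNCTION_SCHEMA],
        available_functions={"create_reasoning_plan": _create_reasoning_plan},
    )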

View File

@@ -1,41 +1,54 @@
"""Controls request rate limiting for API calls."""
import threading
import time
from typing import Optional
from pydantic import BaseModel, Field, PrivateAttr, model_validator
from typing_extensions import Self
from crewai.utilities.logger import Logger
"""Controls request rate limiting for API calls."""
class RPMController(BaseModel):
"""Manages requests per minute limiting."""
max_rpm: Optional[int] = Field(default=None)
max_rpm: int | None = Field(
default=None,
description="Maximum requests per minute. If None, no limit is applied.",
)
logger: Logger = Field(default_factory=lambda: Logger(verbose=False))
_current_rpm: int = PrivateAttr(default=0)
_timer: Optional[threading.Timer] = PrivateAttr(default=None)
_lock: Optional[threading.Lock] = PrivateAttr(default=None)
_timer: "threading.Timer | None" = PrivateAttr(default=None)
_lock: "threading.Lock | None" = PrivateAttr(default=None)
_shutdown_flag: bool = PrivateAttr(default=False)
@model_validator(mode="after")
def reset_counter(self):
def reset_counter(self) -> Self:
"""Resets the RPM counter and starts the timer if max_rpm is set.
Returns:
The instance of the RPMController.
"""
if self.max_rpm is not None:
if not self._shutdown_flag:
self._lock = threading.Lock()
self._reset_request_count()
return self
def check_or_wait(self):
def check_or_wait(self) -> bool:
"""Checks if a new request can be made based on the RPM limit.
Returns:
True if a new request can be made, False otherwise.
"""
if self.max_rpm is None:
return True
def _check_and_increment():
def _check_and_increment() -> bool:
if self.max_rpm is not None and self._current_rpm < self.max_rpm:
self._current_rpm += 1
return True
elif self.max_rpm is not None:
if self.max_rpm is not None:
self.logger.log(
"info", "Max RPM reached, waiting for next minute to start."
)
@@ -50,16 +63,18 @@ class RPMController(BaseModel):
else:
return _check_and_increment()
def stop_rpm_counter(self):
def stop_rpm_counter(self) -> None:
"""Stops the RPM counter and cancels any active timers."""
self._shutdown_flag = True
if self._timer:
self._timer.cancel()
self._timer = None
def _wait_for_next_minute(self):
def _wait_for_next_minute(self) -> None:
time.sleep(60)
self._current_rpm = 0
def _reset_request_count(self):
def _reset_request_count(self) -> None:
def _reset():
self._current_rpm = 0
if not self._shutdown_flag:
@@ -71,7 +86,3 @@ class RPMController(BaseModel):
_reset()
else:
_reset()
if self._timer:
self._shutdown_flag = True
self._timer.cancel()
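A minimal usage sketch of the rate limiter, assuming the import path below; check_or_wait blocks once max_rpm requests have been made in the current minute:

    from crewai.utilities.rpm_controller import RPMController  # path assumed

    controller = RPMController(max_rpm=10)
    for _ in range(12):
        controller.check_or_wait()  # the 11th call waits for the next minute window
    controller.stop_rpm_counter()   # cancel the background reset timer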

View File

@@ -1,14 +1,16 @@
from __future__ import annotations
import json
import uuid
from datetime import date, datetime
from typing import Any, Dict, List, Union
from typing import Any, TypeAlias
from pydantic import BaseModel
SerializablePrimitive = Union[str, int, float, bool, None]
Serializable = Union[
SerializablePrimitive, List["Serializable"], Dict[str, "Serializable"]
]
SerializablePrimitive: TypeAlias = str | int | float | bool | None
Serializable: TypeAlias = (
SerializablePrimitive | list["Serializable"] | dict[str, "Serializable"]
)
def to_serializable(
@@ -24,9 +26,10 @@ def to_serializable(
Non-convertible objects default to their string representations.
Args:
obj (Any): Object to transform.
exclude (set[str], optional): Set of keys to exclude from the result.
max_depth (int, optional): Maximum recursion depth. Defaults to 5.
obj: Object to transform.
exclude: Set of keys to exclude from the result.
max_depth: Maximum recursion depth. Defaults to 5.
_current_depth: Current recursion depth (for internal use).
Returns:
Serializable: A JSON-compatible structure.
@@ -39,18 +42,18 @@ def to_serializable(
if isinstance(obj, (str, int, float, bool, type(None))):
return obj
elif isinstance(obj, uuid.UUID):
if isinstance(obj, uuid.UUID):
return str(obj)
elif isinstance(obj, (date, datetime)):
if isinstance(obj, (date, datetime)):
return obj.isoformat()
elif isinstance(obj, (list, tuple, set)):
if isinstance(obj, (list, tuple, set)):
return [
to_serializable(
item, max_depth=max_depth, _current_depth=_current_depth + 1
)
for item in obj
]
elif isinstance(obj, dict):
if isinstance(obj, dict):
return {
_to_serializable_key(key): to_serializable(
obj=value,
@@ -61,33 +64,31 @@ def to_serializable(
for key, value in obj.items()
if key not in exclude
}
elif isinstance(obj, BaseModel):
if isinstance(obj, BaseModel):
return to_serializable(
obj=obj.model_dump(exclude=exclude),
max_depth=max_depth,
_current_depth=_current_depth + 1,
)
else:
return repr(obj)
return repr(obj)
def _to_serializable_key(key: Any) -> str:
if isinstance(key, (str, int)):
return str(key)
return f"key_{id(key)}_{repr(key)}"
return f"key_{id(key)}_{key!r}"
def to_string(obj: Any) -> str | None:
"""Serializes an object into a JSON string.
Args:
obj (Any): Object to serialize.
obj: Object to serialize.
Returns:
str | None: A JSON-formatted string or `None` if empty.
A JSON-formatted string or `None` if empty.
"""
serializable = to_serializable(obj)
if serializable is None:
return None
else:
return json.dumps(serializable)
return json.dumps(serializable)
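A minimal sketch of the serialization helpers, assuming the import path below:

    import uuid
    from datetime import datetime
    from crewai.utilities.serialization import to_serializable, to_string  # path assumed

    payload = {
        "id": uuid.uuid4(),
        "created": datetime(2025, 1, 1, 12, 0),
        "tags": {"alpha", "beta"},
    }
    to_serializable(payload)  # UUID/datetime become strings, the set becomes a list
    to_string(payload)        # same structure, JSON-encoded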

View File

@@ -1,10 +1,12 @@
import re
from typing import Any, Dict, List, Optional, Union
from typing import Any, Final
_VARIABLE_PATTERN: Final[re.Pattern[str]] = re.compile(r"\{([A-Za-z_][A-Za-z0-9_\-]*)}")
def interpolate_only(
input_string: Optional[str],
inputs: Dict[str, Union[str, int, float, Dict[str, Any], List[Any]]],
input_string: str | None,
inputs: dict[str, str | int | float | dict[str, Any] | list[Any]],
) -> str:
"""Interpolate placeholders (e.g., {key}) in a string while leaving JSON untouched.
Only interpolates placeholders that follow the pattern {variable_name} where
@@ -26,26 +28,30 @@ def interpolate_only(
"""
# Validation function for recursive type checking
def validate_type(value: Any) -> None:
if value is None:
def _validate_type(validate_value: Any) -> None:
if validate_value is None:
return
if isinstance(value, (str, int, float, bool)):
if isinstance(validate_value, (str, int, float, bool)):
return
if isinstance(value, (dict, list)):
for item in value.values() if isinstance(value, dict) else value:
validate_type(item)
if isinstance(validate_value, (dict, list)):
for item in (
validate_value.values()
if isinstance(validate_value, dict)
else validate_value
):
_validate_type(item)
return
raise ValueError(
f"Unsupported type {type(value).__name__} in inputs. "
f"Unsupported type {type(validate_value).__name__} in inputs. "
"Only str, int, float, bool, dict, and list are allowed."
)
# Validate all input values
for key, value in inputs.items():
try:
validate_type(value)
except ValueError as e:
raise ValueError(f"Invalid value for key '{key}': {str(e)}") from e
_validate_type(value)
except ValueError as e: # noqa: PERF203
raise ValueError(f"Invalid value for key '{key}': {e!s}") from e
if input_string is None or not input_string:
return ""
@@ -56,13 +62,7 @@ def interpolate_only(
"Inputs dictionary cannot be empty when interpolating variables"
)
# The regex pattern to find valid variable placeholders
# Matches {variable_name} where variable_name starts with a letter/underscore
# and contains only letters, numbers, and underscores
pattern = r"\{([A-Za-z_][A-Za-z0-9_\-]*)\}"
# Find all matching variables in the input string
variables = re.findall(pattern, input_string)
variables = _VARIABLE_PATTERN.findall(input_string)
result = input_string
# Check if all variables exist in inputs
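A minimal sketch of the interpolation helper, assuming the import path below; only {identifier}-style placeholders are substituted, so literal JSON braces survive untouched:

    from crewai.utilities.string_utils import interpolate_only  # path assumed

    template = 'Report on {topic}. Keep this literal: {"status": "ok"}'
    interpolate_only(template, {"topic": "Q3 revenue"})
    # -> 'Report on Q3 revenue. Keep this literal: {"status": "ok"}'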

View File

@@ -4,50 +4,14 @@ This module provides functionality for storing and retrieving task outputs
from persistent storage, supporting replay and audit capabilities.
"""
from datetime import datetime
from typing import Any
from pydantic import BaseModel, Field
from crewai.memory.storage.kickoff_task_outputs_storage import (
KickoffTaskOutputsSQLiteStorage,
)
from crewai.task import Task
class ExecutionLog(BaseModel):
"""Represents a log entry for task execution.
Attributes:
task_id: Unique identifier for the task.
expected_output: The expected output description for the task.
output: The actual output produced by the task.
timestamp: When the task was executed.
task_index: The position of the task in the execution sequence.
inputs: Input parameters provided to the task.
was_replayed: Whether this output was replayed from a previous run.
"""
task_id: str
expected_output: str | None = None
output: dict[str, Any]
timestamp: datetime = Field(default_factory=datetime.now)
task_index: int
inputs: dict[str, Any] = Field(default_factory=dict)
was_replayed: bool = False
def __getitem__(self, key: str) -> Any:
"""Enable dictionary-style access to execution log attributes.
Args:
key: The attribute name to access.
Returns:
The value of the requested attribute.
"""
return getattr(self, key)
class TaskOutputStorageHandler:
"""Manages storage and retrieval of task outputs.

View File

@@ -4,13 +4,13 @@ This module provides a callback handler that tracks token usage
for LLM API calls through the litellm library.
"""
import warnings
from typing import Any
from litellm.integrations.custom_logger import CustomLogger
from litellm.types.utils import Usage
from crewai.agents.agent_builder.utilities.base_token_process import TokenProcess
from crewai.utilities.logger_utils import suppress_warnings
class TokenCalcHandler(CustomLogger):
@@ -23,12 +23,13 @@ class TokenCalcHandler(CustomLogger):
token_cost_process: The token process tracker to accumulate usage metrics.
"""
def __init__(self, token_cost_process: TokenProcess | None) -> None:
def __init__(self, token_cost_process: TokenProcess | None, **kwargs: Any) -> None:
"""Initialize the token calculation handler.
Args:
token_cost_process: Optional token process tracker for accumulating metrics.
"""
super().__init__(**kwargs)
self.token_cost_process = token_cost_process
def log_success_event(
@@ -49,8 +50,7 @@ class TokenCalcHandler(CustomLogger):
if self.token_cost_process is None:
return
with warnings.catch_warnings():
warnings.simplefilter("ignore", UserWarning)
with suppress_warnings():
if isinstance(response_obj, dict) and "usage" in response_obj:
usage: Usage = response_obj["usage"]
if usage:
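A minimal sketch of wiring the handler into litellm, assuming the TokenCalcHandler import path below; litellm then invokes log_success_event after each completion:

    import litellm

    from crewai.agents.agent_builder.utilities.base_token_process import TokenProcess
    from crewai.utilities.token_counter_callback import TokenCalcHandler  # path assumed

    token_process = TokenProcess()
    litellm.callbacks = [TokenCalcHandler(token_process)]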

View File

@@ -1,12 +1,21 @@
from typing import Any
from __future__ import annotations
from typing import TYPE_CHECKING
from crewai.agents.parser import AgentAction
from crewai.security import Fingerprint
from crewai.agents.tools_handler import ToolsHandler
from crewai.security.fingerprint import Fingerprint
from crewai.tools.structured_tool import CrewStructuredTool
from crewai.tools.tool_types import ToolResult
from crewai.tools.tool_usage import ToolUsage, ToolUsageError
from crewai.utilities.i18n import I18N
if TYPE_CHECKING:
from crewai.agent import Agent
from crewai.llm import LLM
from crewai.llms.base_llm import BaseLLM
from crewai.task import Task
def execute_tool_and_check_finality(
agent_action: AgentAction,
@@ -14,10 +23,10 @@ def execute_tool_and_check_finality(
i18n: I18N,
agent_key: str | None = None,
agent_role: str | None = None,
tools_handler: Any | None = None,
task: Any | None = None,
agent: Any | None = None,
function_calling_llm: Any | None = None,
tools_handler: ToolsHandler | None = None,
task: Task | None = None,
agent: Agent | None = None,
function_calling_llm: BaseLLM | LLM | None = None,
fingerprint_context: dict[str, str] | None = None,
) -> ToolResult:
"""Execute a tool and check if the result should be treated as a final answer.
@@ -32,59 +41,54 @@ def execute_tool_and_check_finality(
task: Optional task for tool execution
agent: Optional agent instance for tool execution
function_calling_llm: Optional LLM for function calling
fingerprint_context: Optional context for fingerprinting
Returns:
ToolResult containing the execution result and whether it should be treated as a final answer
"""
try:
tool_name_to_tool_map = {tool.name: tool for tool in tools}
tool_name_to_tool_map = {tool.name: tool for tool in tools}
if agent_key and agent_role and agent:
fingerprint_context = fingerprint_context or {}
if agent:
if hasattr(agent, "set_fingerprint") and callable(
agent.set_fingerprint
):
if isinstance(fingerprint_context, dict):
try:
fingerprint_obj = Fingerprint.from_dict(fingerprint_context)
agent.set_fingerprint(fingerprint_obj)
except Exception as e:
raise ValueError(f"Failed to set fingerprint: {e}") from e
if agent_key and agent_role and agent:
fingerprint_context = fingerprint_context or {}
if agent:
if hasattr(agent, "set_fingerprint") and callable(agent.set_fingerprint):
if isinstance(fingerprint_context, dict):
try:
fingerprint_obj = Fingerprint.from_dict(fingerprint_context)
agent.set_fingerprint(fingerprint=fingerprint_obj)
except Exception as e:
raise ValueError(f"Failed to set fingerprint: {e}") from e
# Create tool usage instance
tool_usage = ToolUsage(
tools_handler=tools_handler,
tools=tools,
function_calling_llm=function_calling_llm,
task=task,
agent=agent,
action=agent_action,
)
# Create tool usage instance
tool_usage = ToolUsage(
tools_handler=tools_handler,
tools=tools,
function_calling_llm=function_calling_llm,
task=task,
agent=agent,
action=agent_action,
)
# Parse tool calling
tool_calling = tool_usage.parse_tool_calling(agent_action.text)
# Parse tool calling
tool_calling = tool_usage.parse_tool_calling(agent_action.text)
if isinstance(tool_calling, ToolUsageError):
return ToolResult(tool_calling.message, False)
if isinstance(tool_calling, ToolUsageError):
return ToolResult(tool_calling.message, False)
# Check if tool name matches
if tool_calling.tool_name.casefold().strip() in [
name.casefold().strip() for name in tool_name_to_tool_map
] or tool_calling.tool_name.casefold().replace("_", " ") in [
name.casefold().strip() for name in tool_name_to_tool_map
]:
tool_result = tool_usage.use(tool_calling, agent_action.text)
tool = tool_name_to_tool_map.get(tool_calling.tool_name)
if tool:
return ToolResult(tool_result, tool.result_as_answer)
# Check if tool name matches
if tool_calling.tool_name.casefold().strip() in [
name.casefold().strip() for name in tool_name_to_tool_map
] or tool_calling.tool_name.casefold().replace("_", " ") in [
name.casefold().strip() for name in tool_name_to_tool_map
]:
tool_result = tool_usage.use(tool_calling, agent_action.text)
tool = tool_name_to_tool_map.get(tool_calling.tool_name)
if tool:
return ToolResult(tool_result, tool.result_as_answer)
# Handle invalid tool name
tool_result = i18n.errors("wrong_tool_name").format(
tool=tool_calling.tool_name,
tools=", ".join([tool.name.casefold() for tool in tools]),
)
return ToolResult(tool_result, False)
except Exception as e:
raise e
# Handle invalid tool name
tool_result = i18n.errors("wrong_tool_name").format(
tool=tool_calling.tool_name,
tools=", ".join([tool.name.casefold() for tool in tools]),
)
return ToolResult(result=tool_result, result_as_answer=False)

View File

@@ -1,42 +1,72 @@
import json
import re
from typing import Any, get_origin
from typing import Any, Final, get_origin
from pydantic import BaseModel, ValidationError
from crewai.utilities.converter import Converter, ConverterError
_FLOAT_PATTERN: Final[re.Pattern[str]] = re.compile(r"(\d+(?:\.\d+)?)")
class TrainingConverter(Converter):
"""
A specialized converter for smaller LLMs (up to 7B parameters) that handles validation errors
"""A specialized converter for smaller LLMs (up to 7B parameters) that handles validation errors
by breaking down the model into individual fields and querying the LLM for each field separately.
"""
def to_pydantic(self, current_attempt=1) -> BaseModel:
def to_pydantic(self, current_attempt: int = 1) -> BaseModel:
"""Convert the text to a Pydantic model, with fallback to field-by-field extraction on failure.
Args:
current_attempt: The current attempt number for conversion.
Returns:
An instance of the Pydantic model.
Raises:
ConverterError: If conversion fails after field-by-field extraction.
"""
try:
return super().to_pydantic(current_attempt)
except ConverterError:
return self._convert_field_by_field()
def _convert_field_by_field(self) -> BaseModel:
field_values = {}
field_values: dict[str, Any] = {}
for field_name, field_info in self.model.model_fields.items():
field_description = field_info.description
field_type = field_info.annotation
field_description: str | None = field_info.description
field_type: type | None = field_info.annotation
response = self._ask_llm_for_field(field_name, field_description)
value = self._process_field_value(response, field_type)
if field_description is None:
raise ValueError(f"Field '{field_name}' has no description")
response: str = self._ask_llm_for_field(
field_name=field_name, field_description=field_description
)
value: Any = self._process_field_value(
response=response, field_type=field_type
)
field_values[field_name] = value
try:
return self.model(**field_values)
except ValidationError as e:
raise ConverterError(f"Failed to create model from individually collected fields: {e}")
raise ConverterError(
f"Failed to create model from individually collected fields: {e}"
) from e
def _ask_llm_for_field(self, field_name: str, field_description: str) -> str:
prompt = f"""
"""Query the LLM for a specific field value based on its description.
Args:
field_name: The name of the field to extract.
field_description: The description of the field to guide extraction.
Returns:
The LLM's response containing the field value.
"""
prompt: str = f"""
Based on the following information:
{self.text}
@@ -45,14 +75,19 @@ Please provide ONLY the {field_name} field value as described:
Respond with ONLY the requested information, nothing else.
"""
return self.llm.call([
{"role": "system", "content": f"Extract the {field_name} from the previous information."},
{"role": "user", "content": prompt}
])
return self.llm.call(
[
{
"role": "system",
"content": f"Extract the {field_name} from the previous information.",
},
{"role": "user", "content": prompt},
]
)
def _process_field_value(self, response: str, field_type: Any) -> Any:
def _process_field_value(self, response: str, field_type: type | None) -> Any:
response = response.strip()
origin = get_origin(field_type)
origin: type[Any] | None = get_origin(field_type)
if origin is list:
return self._parse_list(response)
@@ -65,25 +100,45 @@ Respond with ONLY the requested information, nothing else.
return response
def _parse_list(self, response: str) -> list:
def _parse_list(self, response: str) -> list[Any]:
try:
if response.startswith('['):
if response.startswith("["):
return json.loads(response)
items = [item.strip() for item in response.split('\n') if item.strip()]
items: list[str] = [
item.strip() for item in response.split("\n") if item.strip()
]
return [self._strip_bullet(item) for item in items]
except json.JSONDecodeError:
return [response]
def _parse_float(self, response: str) -> float:
@staticmethod
def _parse_float(response: str) -> float:
"""Parse a float from the response, extracting the first numeric value found.
Args:
response: The response string from which to extract the float.
Returns:
The extracted float value, or 0.0 if no valid float is found.
"""
try:
match = re.search(r'(\d+(\.\d+)?)', response)
match = _FLOAT_PATTERN.search(response)
return float(match.group(1)) if match else 0.0
except Exception:
except (ValueError, AttributeError):
return 0.0
def _strip_bullet(self, item: str) -> str:
if item.startswith(('- ', '* ')):
@staticmethod
def _strip_bullet(item: str) -> str:
"""Remove common bullet point characters from the start of a string.
Args:
item: The string item to process.
Returns:
The string without leading bullet characters.
"""
if item.startswith(("- ", "* ")):
return item[2:].strip()
return item.strip()
return item.strip()
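The parsing helpers are now static, so their behavior can be illustrated directly; a minimal sketch, assuming the class is importable from the path below:

    from crewai.utilities.training_converter import TrainingConverter  # path assumed

    TrainingConverter._parse_float("scored roughly 4.5 out of 10")  # -> 4.5
    TrainingConverter._parse_float("no score given")                # -> 0.0
    TrainingConverter._strip_bullet("- improve test coverage")      # -> 'improve test coverage'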

View File

@@ -1,35 +1,33 @@
import os
from typing import Any
from crewai.utilities.file_handler import PickleHandler
class CrewTrainingHandler(PickleHandler):
def save_trained_data(self, agent_id: str, trained_data: dict) -> None:
"""
Save the trained data for a specific agent.
def save_trained_data(self, agent_id: str, trained_data: dict[int, Any]) -> None:
"""Save the trained data for a specific agent.
Parameters:
- agent_id (str): The ID of the agent.
- trained_data (dict): The trained data to be saved.
Args:
agent_id: The ID of the agent.
trained_data: The trained data to be saved.
"""
data = self.load()
data[agent_id] = trained_data
self.save(data)
def append(self, train_iteration: int, agent_id: str, new_data) -> None:
"""
Append new data to the existing pickle file.
def append(self, train_iteration: int, agent_id: str, new_data: Any) -> None:
"""Append new training data for a specific agent and iteration.
Parameters:
- new_data (object): The new data to be appended.
Args:
train_iteration: The training iteration number.
agent_id: The ID of the agent.
new_data: The new training data to append.
"""
data = self.load()
if agent_id in data:
data[agent_id][train_iteration] = new_data
else:
data[agent_id] = {train_iteration: new_data}
if agent_id not in data:
data[agent_id] = {}
data[agent_id][train_iteration] = new_data
self.save(data)
def clear(self) -> None:
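A minimal sketch of the simplified append logic, assuming the handler is constructed with a pickle file path and imported from the path below:

    from crewai.utilities.training_handler import CrewTrainingHandler  # path assumed

    handler = CrewTrainingHandler("training_data.pkl")
    handler.append(train_iteration=0, agent_id="agent-123", new_data={"feedback": "be concise"})
    handler.append(train_iteration=1, agent_id="agent-123", new_data={"feedback": "cite sources"})
    # handler.load() -> {"agent-123": {0: {...}, 1: {...}}}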

View File

@@ -0,0 +1,15 @@
"""Types for CrewAI utilities."""
from typing import Literal, TypedDict
class LLMMessage(TypedDict):
"""Type for formatted LLM messages.
Notes:
- TODO: Update the LLM.call & BaseLLM.call signatures to use this type
instead of str | list[dict[str, str]]
"""
role: Literal["user", "assistant", "system"]
content: str
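A minimal sketch of constructing a typed message, assuming the module lands at the path below:

    from crewai.utilities.types import LLMMessage  # path assumed

    message: LLMMessage = {"role": "user", "content": "Summarize the latest task output."}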