Compare commits


36 Commits

Author SHA1 Message Date
Greyson LaLonde
af84ba2272 Merge branch 'main' into gl/fix/cache-handler-types-and-imports 2025-09-05 10:03:41 -04:00
Greyson LaLonde
e93d597721 fix: add type annotations to contextual_memory.py 2025-09-05 09:57:02 -04:00
Greyson LaLonde
a414e7f2a7 fix: update test files to use 'limit' instead of 'latest_n' and fix metadata in MemorySaveCompletedEvent 2025-09-05 09:22:21 -04:00
Greyson LaLonde
fbcd8bcd83 fix: update contextual_memory to use 'limit' instead of 'latest_n' 2025-09-05 09:09:33 -04:00
Greyson LaLonde
5f776bbb0a fix: update import to use crewai.llms.base_llm 2025-09-05 08:58:43 -04:00
Greyson LaLonde
909b2fd0ef fix: use create_default_llm when llm is None in BaseEvaluator 2025-09-05 08:57:58 -04:00
Greyson LaLonde
929f9dadb4 fix: remove unnecessary wraps parameter in test patch 2025-09-04 23:15:20 -04:00
Greyson LaLonde
4ef4632a8c fix: update return type annotations in OpenAIAgentAdapter 2025-09-04 23:08:56 -04:00
Greyson LaLonde
c246df3cb2 fix: add type annotations to converter instance variables 2025-09-04 23:02:13 -04:00
Greyson LaLonde
4fd40e7857 fix: add missing super call in LangGraphConverterAdapter 2025-09-04 22:58:18 -04:00
Greyson LaLonde
25204c6cb8 fix: add type annotations to structured output converter 2025-09-04 22:53:32 -04:00
Greyson LaLonde
b44776c367 fix: resolve mypy type errors across agent adapters and core modules 2025-09-04 22:47:18 -04:00
Greyson LaLonde
843801f554 fix: make task required in CrewAgentExecutor and fix all type annotations
- Make task parameter required in CrewAgentExecutor.__init__
- Update Agent.create_agent_executor to require task parameter
- Handle cases where crew can be None (standalone agent usage)
- Update base class signatures to match
- Remove unnecessary create_agent_executor calls during setup
- Add missing type annotations in base_agent_executor_mixin
- Fix all type errors in base_agent.py using Self return type
- Add assert for agent_executor before use
- Fix crew access checks to handle None case
2025-09-04 22:13:46 -04:00
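
A minimal sketch of the Self return-type pattern referenced above, using an illustrative Pydantic model rather than the real BaseAgent:

from pydantic import BaseModel, model_validator
from typing_extensions import Self

class ExampleAgent(BaseModel):
    role: str
    verbose: bool = False

    @model_validator(mode="after")
    def set_private_attrs(self) -> Self:
        # Returning Self rather than the literal class name keeps the
        # annotation accurate for subclasses under mypy --strict.
        return self
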
Greyson LaLonde
2faa13ddcb refactor: improve type annotations and simplify code in CrewAgentExecutor 2025-09-04 17:07:02 -04:00
Greyson LaLonde
e385b45667 fix: update cache test assertions for JSON serialization 2025-09-04 16:10:46 -04:00
Greyson LaLonde
f03567d463 Merge branch 'main' into gl/fix/cache-handler-types-and-imports 2025-09-04 16:01:26 -04:00
Greyson LaLonde
e9f4ac070b chore: Relax mypy to not run on tests dir for now 2025-09-04 15:57:17 -04:00
Greyson LaLonde
bcee792390 fix: resolve mypy errors in storage and tracing modules 2025-09-04 15:39:01 -04:00
Greyson LaLonde
221bfcccce refactor: consolidate ChromaDB response extraction logic 2025-09-04 15:21:48 -04:00
Greyson LaLonde
4812986f58 fix: resolve mypy type annotation issues in storage and telemetry modules
- Add proper type parameters for EmbeddingFunction generics
- Fix ChromaDB query response handling with proper type checking
- Add missing return type annotations to telemetry methods
- Fix trace listener type annotations and imports
- Handle potential None values in nested list indexing
- Improve type safety in RAG and knowledge storage modules
2025-09-04 14:58:28 -04:00
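
A hedged sketch of the None-safe nested indexing this commit describes; the response shape is an assumption modeled on ChromaDB-style query results, not code from the diff:

from typing import Any

def extract_first_documents(response: dict[str, Any]) -> list[str]:
    # Query responses nest one result list per query; both the outer field
    # and the inner lists may be None, so guard each level before indexing.
    documents = response.get("documents")
    if not documents or not documents[0]:
        return []
    return [doc for doc in documents[0] if doc is not None]
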
Greyson LaLonde
23c60befd8 fix: resolve additional mypy type annotation issues
- Fixed rag_storage.py embedder type compatibility and query response handling
- Fixed knowledge_storage.py dict type parameters and return types
- Added comprehensive type annotations to telemetry.py methods
- Added type annotations to trace_listener.py event handlers and methods
- Fixed ChromaDB response indexing safety checks
2025-09-04 13:23:57 -04:00
Greyson LaLonde
8dd3493e9c fix: resolve additional mypy type annotation issues
- Fixed file_handler.py PickleHandler type annotations
- Fixed task_events.py None checks before accessing task.fingerprint
- Added type annotations to memory_listener.py event handlers
2025-09-04 13:07:37 -04:00
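
A small sketch of the None-check pattern described above, with a hypothetical stand-in for the real Task model:

from dataclasses import dataclass

@dataclass
class Task:  # minimal stand-in for the real Task model
    fingerprint: str | None = None

def fingerprint_metadata(task: Task | None) -> dict[str, str]:
    # Checking each level for None lets mypy narrow the types and avoids
    # attribute errors when an event fires without an attached task.
    if task is not None and task.fingerprint is not None:
        return {"task_fingerprint": task.fingerprint}
    return {}
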
Greyson LaLonde
9306d889a7 fix: resolve remaining mypy type annotation issues
- Applied proper decorator typing with ParamSpec and typing_extensions.Self
- Fixed event bus decorator to preserve type information
- Added type annotations to BaseEventListener and TraceCollectionListener
- Fixed LongTermMemory.search to handle None return from storage.load
- Resolved all type errors tracked in strict mode
2025-09-04 13:00:11 -04:00
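
A self-contained sketch of ParamSpec-based decorator typing as named in the first bullet; logged is illustrative, not the actual event bus decorator:

from collections.abc import Callable
from typing import TypeVar
from typing_extensions import ParamSpec

P = ParamSpec("P")
R = TypeVar("R")

def logged(func: Callable[P, R]) -> Callable[P, R]:
    # ParamSpec threads the wrapped function's exact signature through the
    # decorator, so decorated call sites keep full argument type checking.
    def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
        print(f"calling {func.__name__}")
        return func(*args, **kwargs)
    return wrapper

@logged
def add(a: int, b: int) -> int:
    return a + b
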
Greyson LaLonde
8354cdf061 fix: add missing type annotations to fix mypy strict mode errors
Added type annotations to 10 files to resolve mypy type checking errors:
- Added return type annotations to methods missing them
- Added parameter type annotations where missing
- Fixed Optional type hints to be explicit
- Removed redundant type cast in crew.py
- Changed _execute_with_timeout return type from str to Any in agent.py

Additional type errors remain in other files throughout the codebase.
2025-09-04 11:41:57 -04:00
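
A sketch of the worker-thread timeout pattern behind the str-to-Any return change; the structure is illustrative, not the repository's implementation:

from collections.abc import Callable
from concurrent.futures import ThreadPoolExecutor, TimeoutError as FutureTimeout
from typing import Any

def execute_with_timeout(fn: Callable[[str], Any], prompt: str, timeout: int) -> Any:
    # Run the call on a worker thread and bound the wait; returning Any
    # reflects that executor output is not pinned to str.
    with ThreadPoolExecutor(max_workers=1) as pool:
        future = pool.submit(fn, prompt)
        try:
            return future.result(timeout=timeout)
        except FutureTimeout:
            future.cancel()
            raise RuntimeError(f"Task execution exceeded {timeout} seconds")
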
Greyson LaLonde
2ba48dd82a fix: add type annotations and exclude tests from mypy
- Add type: ignore for mem0 import
- Fix tool_usage.py cache_function None check
- Change _execute_without_timeout return type to Any
- Add type annotations to multiple functions:
  - add_sources() -> None
  - log() with proper parameter types
  - stop_rpm_counter() -> None
  - EventListener.__new__() -> Self
  - setup_listeners() -> None
  - Memory class __init__ methods -> None
  - TaskEvaluator.__init__() -> None
  - get_skipped_task_output() -> TaskOutput
- Exclude tests directory from mypy checks in pyproject.toml
- Update deprecated typing imports to use built-in types
2025-09-04 11:11:59 -04:00
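
A before/after sketch of the deprecated-typing modernization in the last bullet (PEP 585 built-in generics, PEP 604 unions); the log() signature here is hypothetical:

# Before: deprecated typing aliases
# from typing import Dict, List, Optional
# def log(events: List[Dict[str, str]], level: Optional[str] = None) -> None: ...

# After: built-in generics and union syntax
def log(events: list[dict[str, str]], level: str | None = None) -> None:
    for event in events:
        print(level or "info", event)
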
Greyson LaLonde
0bab041531 fix: resolve remaining mypy type errors
- Fix tool_usage.py: rename result variable to avoid redefinition
- Fix lite_agent.py: import TaskOutput from correct module and add type casts
- Add explicit type annotation for data dict in tool_usage.py
2025-09-04 10:40:33 -04:00
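
A minimal illustration of the rename-to-avoid-redefinition fix; run_tool and both result names are invented for the example:

import json

def run_tool() -> str:
    return json.dumps({"ok": True})

# Re-binding one name to values of different types trips mypy's
# redefinition checks; distinct names keep each inferred type stable.
raw_result: str = run_tool()
parsed_result: dict[str, bool] = json.loads(raw_result)
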
Greyson LaLonde
eed2ffde5f fix: resolve additional mypy type errors
- Fix tool_usage.py: proper type annotations for result and fingerprint metadata
- Fix lite_agent.py: proper Union type for guardrail callable accepting both LiteAgentOutput and TaskOutput
- Add missing return type annotations to task_output_storage_handler.py methods
- Fix crew.py: replace Json generic check with str, remove unused type:ignore and redundant cast
2025-09-03 23:23:36 -04:00
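
A sketch of the guardrail Union typing described above, with stand-in classes in place of the real LiteAgentOutput and TaskOutput:

from collections.abc import Callable

class TaskOutput:  # stand-in for the real output model
    raw: str = ""

class LiteAgentOutput:  # stand-in for the real output model
    raw: str = ""

Guardrail = Callable[[LiteAgentOutput | TaskOutput], tuple[bool, str]]

def non_empty_guardrail(output: LiteAgentOutput | TaskOutput) -> tuple[bool, str]:
    # Accepting the union lets a single guardrail validate output from
    # both the lite and the full task execution paths.
    return (bool(output.raw), output.raw)
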
Greyson LaLonde
b6e7311d2d fix: update cache tests to use input_data parameter name
The CacheHandler methods use 'input_data' not 'input' as the parameter name
2025-09-03 23:09:51 -04:00
Greyson LaLonde
90ca02b9dc fix: address mypy type errors in multiple files
- Fix return type and argument handling in cache_tools.py
- Add missing return statements in agent.py
- Fix _inject_date_to_task signature to accept Task object
- Remove unused type:ignore comments in tool_usage.py
- Add type annotations to internal methods in mem0_storage.py
2025-09-03 23:05:07 -04:00
Greyson LaLonde
06d5c3f170 fix: update remaining deprecated type annotations in tests 2025-09-03 22:40:05 -04:00
Greyson LaLonde
b94fbd3d3a fix: improve type annotations across codebase 2025-09-03 22:29:41 -04:00
Greyson LaLonde
43880b49a6 Merge branch 'main' into gl/fix/cache-handler-types-and-imports 2025-09-03 21:15:07 -04:00
Greyson LaLonde
bdfc38ba32 refactor: update CacheHandler imports to use direct path
- Update imports from crewai.agents.cache to crewai.agents.cache.cache_handler
- Remove CacheHandler from agents module __all__ export
2025-09-03 18:18:05 -04:00
Greyson LaLonde
94029017c3 refactor: remove __all__ from internal cache module
- Remove __all__ export as this is an internal module
- Add module docstring describing package purpose
2025-09-03 18:17:19 -04:00
Greyson LaLonde
89df777887 refactor: use absolute imports in parser module
- Import I18N directly from utilities.i18n
2025-09-03 18:16:03 -04:00
Greyson LaLonde
d1fbf24d9e fix: add type annotations to CacheHandler methods
- Replace Optional with union syntax
- Rename input parameter to input_data to avoid shadowing
- Add JSON serialization for dict cache keys
- Add thread-safety TODO note
2025-09-03 18:15:46 -04:00
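
A small sketch of why the cache keys are JSON-serialized; cache_key is an illustrative helper, not the method itself:

import json

def cache_key(tool: str, input_data: dict | None) -> str:
    # sort_keys=True makes the serialized form order-independent, so two
    # dicts with the same items always map to the same cache entry.
    serialized = json.dumps(input_data, sort_keys=True) if input_data else ""
    return f"{tool}-{serialized}"

assert cache_key("search", {"q": "ai", "n": 1}) == cache_key("search", {"n": 1, "q": "ai"})
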
45 changed files with 1692 additions and 1263 deletions

View File

@@ -11,4 +11,6 @@ repos:
rev: v1.17.1
hooks:
- id: mypy
args: ["--config-file", "pyproject.toml"]
args: ["--strict", "--exclude", "src/crewai/cli/templates"]
files: ^src/
exclude: ^tests/

View File

@@ -1,29 +1,50 @@
import shutil
import subprocess
import time
from collections.abc import Callable, Sequence
from typing import (
Any,
Callable,
Dict,
List,
Literal,
Optional,
Sequence,
Tuple,
Type,
Union,
)
from pydantic import Field, InstanceOf, PrivateAttr, model_validator
from pydantic import (
BeforeValidator,
Field,
InstanceOf,
PrivateAttr,
computed_field,
field_validator,
model_validator,
)
from typing_extensions import Self
from crewai.agents import CacheHandler
from crewai.agents.agent_builder.base_agent import BaseAgent
from crewai.agents.cache.cache_handler import CacheHandler
from crewai.agents.crew_agent_executor import CrewAgentExecutor
from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.agent_events import (
AgentExecutionCompletedEvent,
AgentExecutionErrorEvent,
AgentExecutionStartedEvent,
)
from crewai.events.types.knowledge_events import (
KnowledgeQueryCompletedEvent,
KnowledgeQueryFailedEvent,
KnowledgeQueryStartedEvent,
KnowledgeRetrievalCompletedEvent,
KnowledgeRetrievalStartedEvent,
KnowledgeSearchQueryFailedEvent,
)
from crewai.events.types.memory_events import (
MemoryRetrievalCompletedEvent,
MemoryRetrievalStartedEvent,
)
from crewai.knowledge.knowledge import Knowledge
from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource
from crewai.knowledge.utils.knowledge_utils import extract_knowledge_context
from crewai.lite_agent import LiteAgent, LiteAgentOutput
from crewai.llm import BaseLLM
from crewai.llms.base_llm import BaseLLM
from crewai.memory.contextual.contextual_memory import ContextualMemory
from crewai.security import Fingerprint
from crewai.task import Task
@@ -38,25 +59,7 @@ from crewai.utilities.agent_utils import (
)
from crewai.utilities.constants import TRAINED_AGENTS_DATA_FILE, TRAINING_DATA_FILE
from crewai.utilities.converter import generate_model_description
from crewai.events.types.agent_events import (
AgentExecutionCompletedEvent,
AgentExecutionErrorEvent,
AgentExecutionStartedEvent,
)
from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.memory_events import (
MemoryRetrievalStartedEvent,
MemoryRetrievalCompletedEvent,
)
from crewai.events.types.knowledge_events import (
KnowledgeQueryCompletedEvent,
KnowledgeQueryFailedEvent,
KnowledgeQueryStartedEvent,
KnowledgeRetrievalCompletedEvent,
KnowledgeRetrievalStartedEvent,
KnowledgeSearchQueryFailedEvent,
)
from crewai.utilities.llm_utils import create_llm
from crewai.utilities.llm_utils import create_default_llm, create_llm
from crewai.utilities.token_counter_callback import TokenCalcHandler
from crewai.utilities.training_handler import CrewTrainingHandler
@@ -87,6 +90,8 @@ class Agent(BaseAgent):
"""
_times_executed: int = PrivateAttr(default=0)
_llm: BaseLLM = PrivateAttr()
_function_calling_llm: BaseLLM | None = PrivateAttr(default=None)
max_execution_time: Optional[int] = Field(
default=None,
description="Maximum execution time for an agent to execute a task",
@@ -101,10 +106,11 @@ class Agent(BaseAgent):
default=True,
description="Use system prompt for the agent.",
)
llm: Union[str, InstanceOf[BaseLLM], Any] = Field(
description="Language model that will run the agent.", default=None
llm: str | InstanceOf[BaseLLM] | None = Field(
description="Language model that will run the agent.",
default_factory=create_default_llm,
)
function_calling_llm: Optional[Union[str, InstanceOf[BaseLLM], Any]] = Field(
function_calling_llm: str | InstanceOf[BaseLLM] | None = Field(
description="Language model that will run the agent.", default=None
)
system_template: Optional[str] = Field(
@@ -151,7 +157,7 @@ class Agent(BaseAgent):
default=None,
description="Maximum number of reasoning attempts before executing the task. If None, will try until ready.",
)
embedder: Optional[Dict[str, Any]] = Field(
embedder: Optional[dict[str, Any]] = Field(
default=None,
description="Embedder configuration for the agent.",
)
@@ -171,7 +177,7 @@ class Agent(BaseAgent):
default=None,
description="The Agent's role to be used from your repository.",
)
guardrail: Optional[Union[Callable[[Any], Tuple[bool, Any]], str]] = Field(
guardrail: Optional[Callable[[Any], tuple[bool, Any]] | str] = Field(
default=None,
description="Function or string description of a guardrail to validate agent output",
)
@@ -180,20 +186,36 @@ class Agent(BaseAgent):
)
@model_validator(mode="before")
def validate_from_repository(cls, v):
@classmethod
def validate_from_repository(cls, v: Any) -> Any:
if v is not None and (from_repository := v.get("from_repository")):
return load_agent_from_repository(from_repository) | v
return v
@field_validator("function_calling_llm", mode="after")
@classmethod
def validate_function_calling_llm(cls, v: Any) -> BaseLLM | None:
if not v or isinstance(v, BaseLLM):
return v
return create_llm(v)
@model_validator(mode="after")
def post_init_setup(self):
def post_init_setup(self) -> Self:
self.agent_ops_agent_name = self.role
self.llm = create_llm(self.llm)
if self.function_calling_llm and not isinstance(
self.function_calling_llm, BaseLLM
):
self.function_calling_llm = create_llm(self.function_calling_llm)
# Validate and set the private LLM attributes
if isinstance(self.llm, BaseLLM):
self._llm = self.llm
elif self.llm is None:
self._llm = create_default_llm()
else:
self._llm = create_llm(self.llm)
if self.function_calling_llm:
if isinstance(self.function_calling_llm, BaseLLM):
self._function_calling_llm = self.function_calling_llm
else:
self._function_calling_llm = create_llm(self.function_calling_llm)
if not self.agent_executor:
self._setup_agent_executor()
@@ -203,12 +225,12 @@ class Agent(BaseAgent):
return self
def _setup_agent_executor(self):
def _setup_agent_executor(self) -> None:
if not self.cache_handler:
self.cache_handler = CacheHandler()
self.set_cache_handler(self.cache_handler)
def set_knowledge(self, crew_embedder: Optional[Dict[str, Any]] = None):
def set_knowledge(self, crew_embedder: Optional[dict[str, Any]] = None) -> None:
try:
if self.embedder is None and crew_embedder:
self.embedder = crew_embedder
@@ -245,8 +267,8 @@ class Agent(BaseAgent):
self,
task: Task,
context: Optional[str] = None,
tools: Optional[List[BaseTool]] = None,
) -> str:
tools: Optional[list[BaseTool]] = None,
) -> Any:
"""Execute a task with the agent.
Args:
@@ -417,7 +439,7 @@ class Agent(BaseAgent):
)
tools = tools or self.tools or []
self.create_agent_executor(tools=tools, task=task)
self.create_agent_executor(task=task, tools=tools)
if self.crew and self.crew._train:
task_prompt = self._training_handler(task_prompt=task_prompt)
@@ -492,7 +514,7 @@ class Agent(BaseAgent):
# If there was any tool in self.tools_results that had result_as_answer
# set to True, return the results of the last tool that had
# result_as_answer set to True
for tool_result in self.tools_results: # type: ignore # Item "None" of "list[Any] | None" has no attribute "__iter__" (not iterable)
for tool_result in self.tools_results:
if tool_result.get("result_as_answer", False):
result = tool_result["result"]
crewai_event_bus.emit(
@@ -501,7 +523,7 @@ class Agent(BaseAgent):
)
return result
def _execute_with_timeout(self, task_prompt: str, task: Task, timeout: int) -> str:
def _execute_with_timeout(self, task_prompt: str, task: Task, timeout: int) -> Any:
"""Execute a task with a timeout.
Args:
@@ -534,7 +556,7 @@ class Agent(BaseAgent):
future.cancel()
raise RuntimeError(f"Task execution failed: {str(e)}")
def _execute_without_timeout(self, task_prompt: str, task: Task) -> str:
def _execute_without_timeout(self, task_prompt: str, task: Task) -> Any:
"""Execute a task without a timeout.
Args:
@@ -544,6 +566,9 @@ class Agent(BaseAgent):
Returns:
The output of the agent.
"""
assert self.agent_executor is not None, (
"Agent executor must be created before execution"
)
return self.agent_executor.invoke(
{
"input": task_prompt,
@@ -554,14 +579,15 @@ class Agent(BaseAgent):
)["output"]
def create_agent_executor(
self, tools: Optional[List[BaseTool]] = None, task=None
self, task: Task, tools: Optional[list[BaseTool]] = None
) -> None:
"""Create an agent executor for the agent.
Returns:
An instance of the CrewAgentExecutor class.
Args:
task: Task to execute.
tools: Optional list of tools to use.
"""
raw_tools: List[BaseTool] = tools or self.tools or []
raw_tools: list[BaseTool] = tools or self.tools or []
parsed_tools = parse_tools(raw_tools)
prompt = Prompts(
@@ -582,7 +608,7 @@ class Agent(BaseAgent):
)
self.agent_executor = CrewAgentExecutor(
llm=self.llm,
llm=self._llm,
task=task,
agent=self,
crew=self.crew,
@@ -595,15 +621,15 @@ class Agent(BaseAgent):
tools_names=get_tool_names(parsed_tools),
tools_description=render_text_description_and_args(parsed_tools),
step_callback=self.step_callback,
function_calling_llm=self.function_calling_llm,
function_calling_llm=self._function_calling_llm,
respect_context_window=self.respect_context_window,
request_within_rpm_limit=(
self._rpm_controller.check_or_wait if self._rpm_controller else None
),
callbacks=[TokenCalcHandler(self._token_process)],
litellm_callbacks=[TokenCalcHandler(self._token_process)],
)
def get_delegation_tools(self, agents: List[BaseAgent]):
def get_delegation_tools(self, agents: list[BaseAgent]) -> list[BaseTool]:
agent_tools = AgentTools(agents=agents)
tools = agent_tools.tools()
return tools
@@ -613,7 +639,7 @@ class Agent(BaseAgent):
return [AddImageTool()]
def get_code_execution_tools(self):
def get_code_execution_tools(self) -> list[BaseTool]:
try:
from crewai_tools import CodeInterpreterTool # type: ignore
@@ -624,8 +650,11 @@ class Agent(BaseAgent):
self._logger.log(
"info", "Coding tools not available. Install crewai_tools. "
)
return []
def get_output_converter(self, llm, text, model, instructions):
def get_output_converter(
self, llm: BaseLLM, text: str, model: str, instructions: str
) -> Converter:
return Converter(llm=llm, text=text, model=model, instructions=instructions)
def _training_handler(self, task_prompt: str) -> str:
@@ -654,7 +683,7 @@ class Agent(BaseAgent):
)
return task_prompt
def _render_text_description(self, tools: List[Any]) -> str:
def _render_text_description(self, tools: list[Any]) -> str:
"""Render the tool name and description in plain text.
Output will be in the format of:
@@ -673,7 +702,7 @@ class Agent(BaseAgent):
return description
def _inject_date_to_task(self, task):
def _inject_date_to_task(self, task: Task) -> None:
"""Inject the current date into the task description if inject_date is enabled."""
if self.inject_date:
from datetime import datetime
@@ -723,7 +752,7 @@ class Agent(BaseAgent):
f"Docker is not running. Please start Docker to use code execution with agent: {self.role}"
)
def __repr__(self):
def __repr__(self) -> str:
return f"Agent(role={self.role}, goal={self.goal}, backstory={self.backstory})"
@property
@@ -736,7 +765,7 @@ class Agent(BaseAgent):
"""
return self.security_config.fingerprint
def set_fingerprint(self, fingerprint: Fingerprint):
def set_fingerprint(self, fingerprint: Fingerprint) -> None:
self.security_config.fingerprint = fingerprint
def _get_knowledge_search_query(self, task_prompt: str) -> str | None:
@@ -752,22 +781,8 @@ class Agent(BaseAgent):
task_prompt=task_prompt
)
rewriter_prompt = self.i18n.slice("knowledge_search_query_system_prompt")
if not isinstance(self.llm, BaseLLM):
self._logger.log(
"warning",
f"Knowledge search query failed: LLM for agent '{self.role}' is not an instance of BaseLLM",
)
crewai_event_bus.emit(
self,
event=KnowledgeQueryFailedEvent(
agent=self,
error="LLM is not compatible with knowledge search queries",
),
)
return None
try:
rewritten_query = self.llm.call(
rewritten_query = self._llm.call(
[
{
"role": "system",
@@ -796,8 +811,8 @@ class Agent(BaseAgent):
def kickoff(
self,
messages: Union[str, List[Dict[str, str]]],
response_format: Optional[Type[Any]] = None,
messages: str | list[dict[str, str]],
response_format: Optional[type[Any]] = None,
) -> LiteAgentOutput:
"""
Execute the agent with the given messages using a LiteAgent instance.
@@ -819,7 +834,7 @@ class Agent(BaseAgent):
role=self.role,
goal=self.goal,
backstory=self.backstory,
llm=self.llm,
llm=self._llm,
tools=self.tools or [],
max_iterations=self.max_iter,
max_execution_time=self.max_execution_time,
@@ -836,8 +851,8 @@ class Agent(BaseAgent):
async def kickoff_async(
self,
messages: Union[str, List[Dict[str, str]]],
response_format: Optional[Type[Any]] = None,
messages: str | list[dict[str, str]],
response_format: Optional[type[Any]] = None,
) -> LiteAgentOutput:
"""
Execute the agent asynchronously with the given messages using a LiteAgent instance.
@@ -857,7 +872,7 @@ class Agent(BaseAgent):
role=self.role,
goal=self.goal,
backstory=self.backstory,
llm=self.llm,
llm=self._llm,
tools=self.tools or [],
max_iterations=self.max_iter,
max_execution_time=self.max_execution_time,

View File

@@ -1,5 +1,10 @@
-from crewai.agents.cache.cache_handler import CacheHandler
 from crewai.agents.parser import parse, AgentAction, AgentFinish, OutputParserException
 from crewai.agents.tools_handler import ToolsHandler
 
-__all__ = ["CacheHandler", "parse", "AgentAction", "AgentFinish", "OutputParserException", "ToolsHandler"]
+__all__ = [
+    "parse",
+    "AgentAction",
+    "AgentFinish",
+    "OutputParserException",
+    "ToolsHandler",
+]

View File

@@ -1,4 +1,4 @@
from typing import Any, Dict, List, Optional
from typing import Any, Optional
from pydantic import Field, PrivateAttr
@@ -10,20 +10,22 @@ from crewai.agents.agent_adapters.langgraph.structured_output_converter import (
LangGraphConverterAdapter,
)
from crewai.agents.agent_builder.base_agent import BaseAgent
from crewai.tools.agent_tools.agent_tools import AgentTools
from crewai.tools.base_tool import BaseTool
from crewai.utilities import Logger
from crewai.utilities.converter import Converter
from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.agent_events import (
AgentExecutionCompletedEvent,
AgentExecutionErrorEvent,
AgentExecutionStartedEvent,
)
from crewai.tools.agent_tools.agent_tools import AgentTools
from crewai.tools.base_tool import BaseTool
from crewai.utilities import Logger
from crewai.utilities.converter import Converter
try:
from langgraph.checkpoint.memory import MemorySaver
from langgraph.prebuilt import create_react_agent
from langgraph.checkpoint.memory import ( # type: ignore
MemorySaver,
)
from langgraph.prebuilt import create_react_agent # type: ignore
LANGGRAPH_AVAILABLE = True
except ImportError:
@@ -51,11 +53,11 @@ class LangGraphAgentAdapter(BaseAgentAdapter):
role: str,
goal: str,
backstory: str,
tools: Optional[List[BaseTool]] = None,
tools: Optional[list[BaseTool]] = None,
llm: Any = None,
max_iterations: int = 10,
agent_config: Optional[Dict[str, Any]] = None,
**kwargs,
agent_config: Optional[dict[str, Any]] = None,
**kwargs: Any,
):
"""Initialize the LangGraph agent adapter."""
if not LANGGRAPH_AVAILABLE:
@@ -81,7 +83,7 @@ class LangGraphAgentAdapter(BaseAgentAdapter):
try:
self._memory = MemorySaver()
converted_tools: List[Any] = self._tool_adapter.tools()
converted_tools: list[Any] = self._tool_adapter.tools()
if self._agent_config:
self._graph = create_react_agent(
model=self.llm,
@@ -111,7 +113,7 @@ class LangGraphAgentAdapter(BaseAgentAdapter):
"""Build a system prompt for the LangGraph agent."""
base_prompt = f"""
You are {self.role}.
Your goal is: {self.goal}
Your backstory: {self.backstory}
@@ -124,10 +126,10 @@ class LangGraphAgentAdapter(BaseAgentAdapter):
self,
task: Any,
context: Optional[str] = None,
tools: Optional[List[BaseTool]] = None,
tools: Optional[list[BaseTool]] = None,
) -> str:
"""Execute a task using the LangGraph workflow."""
self.create_agent_executor(tools)
self.create_agent_executor(task, tools)
self.configure_structured_output(task)
@@ -197,11 +199,13 @@ class LangGraphAgentAdapter(BaseAgentAdapter):
)
raise
def create_agent_executor(self, tools: Optional[List[BaseTool]] = None) -> None:
def create_agent_executor(
self, task: Any = None, tools: Optional[list[BaseTool]] = None
) -> None:
"""Configure the LangGraph agent for execution."""
self.configure_tools(tools)
def configure_tools(self, tools: Optional[List[BaseTool]] = None) -> None:
def configure_tools(self, tools: Optional[list[BaseTool]] = None) -> None:
"""Configure tools for the LangGraph agent."""
if tools:
all_tools = list(self.tools or []) + list(tools or [])
@@ -209,7 +213,7 @@ class LangGraphAgentAdapter(BaseAgentAdapter):
available_tools = self._tool_adapter.tools()
self._graph.tools = available_tools
def get_delegation_tools(self, agents: List[BaseAgent]) -> List[BaseTool]:
def get_delegation_tools(self, agents: list[BaseAgent]) -> list[BaseTool]:
"""Implement delegation tools support for LangGraph."""
agent_tools = AgentTools(agents=agents)
return agent_tools.tools()
@@ -220,6 +224,6 @@ class LangGraphAgentAdapter(BaseAgentAdapter):
"""Convert output format if needed."""
return Converter(llm=llm, text=text, model=model, instructions=instructions)
def configure_structured_output(self, task) -> None:
def configure_structured_output(self, task: Any) -> None:
"""Configure the structured output for LangGraph."""
self._converter_adapter.configure_structured_output(task)

View File

@@ -1,4 +1,5 @@
import json
from typing import Any
from crewai.agents.agent_adapters.base_converter_adapter import BaseConverterAdapter
from crewai.utilities.converter import generate_model_description
@@ -7,14 +8,15 @@ from crewai.utilities.converter import generate_model_description
class LangGraphConverterAdapter(BaseConverterAdapter):
"""Adapter for handling structured output conversion in LangGraph agents"""
def __init__(self, agent_adapter):
def __init__(self, agent_adapter: Any) -> None:
"""Initialize the converter adapter with a reference to the agent adapter"""
super().__init__(agent_adapter) # type: ignore
self.agent_adapter = agent_adapter
self._output_format = None
self._schema = None
self._system_prompt_appendix = None
self._output_format: str | None = None
self._schema: str | None = None
self._system_prompt_appendix: str | None = None
def configure_structured_output(self, task) -> None:
def configure_structured_output(self, task: Any) -> None:
"""Configure the structured output for LangGraph."""
if not (task.output_json or task.output_pydantic):
self._output_format = None
@@ -41,7 +43,7 @@ Important: Your final answer MUST be provided in the following structured format
{self._schema}
DO NOT include any markdown code blocks, backticks, or other formatting around your response.
DO NOT include any markdown code blocks, backticks, or other formatting around your response.
The output should be raw JSON that exactly matches the specified schema.
"""

View File

@@ -1,4 +1,4 @@
from typing import Any, List, Optional
from typing import Any, Optional
from pydantic import Field, PrivateAttr
@@ -7,19 +7,19 @@ from crewai.agents.agent_adapters.openai_agents.structured_output_converter impo
OpenAIConverterAdapter,
)
from crewai.agents.agent_builder.base_agent import BaseAgent
from crewai.tools import BaseTool
from crewai.tools.agent_tools.agent_tools import AgentTools
from crewai.utilities import Logger
from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.agent_events import (
AgentExecutionCompletedEvent,
AgentExecutionErrorEvent,
AgentExecutionStartedEvent,
)
from crewai.tools import BaseTool
from crewai.tools.agent_tools.agent_tools import AgentTools
from crewai.utilities import Logger
try:
from agents import Agent as OpenAIAgent # type: ignore
from agents import Runner, enable_verbose_stdout_logging # type: ignore
from agents import Agent as OpenAIAgent # type: ignore[import-not-found]
from agents import Runner, enable_verbose_stdout_logging
from .openai_agent_tool_adapter import OpenAIAgentToolAdapter
@@ -40,13 +40,14 @@ class OpenAIAgentAdapter(BaseAgentAdapter):
step_callback: Any = Field(default=None)
_tool_adapter: "OpenAIAgentToolAdapter" = PrivateAttr()
_converter_adapter: OpenAIConverterAdapter = PrivateAttr()
agent_executor: Any = Field(default=None)
def __init__(
self,
model: str = "gpt-4o-mini",
tools: Optional[List[BaseTool]] = None,
agent_config: Optional[dict] = None,
**kwargs,
tools: Optional[list[BaseTool]] = None,
agent_config: Optional[dict[str, Any]] = None,
**kwargs: Any,
):
if not OPENAI_AVAILABLE:
raise ImportError(
@@ -72,7 +73,7 @@ class OpenAIAgentAdapter(BaseAgentAdapter):
"""Build a system prompt for the OpenAI agent."""
base_prompt = f"""
You are {self.role}.
Your goal is: {self.goal}
Your backstory: {self.backstory}
@@ -85,11 +86,11 @@ class OpenAIAgentAdapter(BaseAgentAdapter):
self,
task: Any,
context: Optional[str] = None,
tools: Optional[List[BaseTool]] = None,
) -> str:
tools: Optional[list[BaseTool]] = None,
) -> Any:
"""Execute a task using the OpenAI Assistant"""
self._converter_adapter.configure_structured_output(task)
self.create_agent_executor(tools)
self.create_agent_executor(task, tools)
if self.verbose:
enable_verbose_stdout_logging()
@@ -109,6 +110,7 @@ class OpenAIAgentAdapter(BaseAgentAdapter):
task=task,
),
)
assert hasattr(self, "agent_executor"), "agent_executor not initialized"
result = self.agent_executor.run_sync(self._openai_agent, task_prompt)
final_answer = self.handle_execution_result(result)
crewai_event_bus.emit(
@@ -131,7 +133,9 @@ class OpenAIAgentAdapter(BaseAgentAdapter):
)
raise
def create_agent_executor(self, tools: Optional[List[BaseTool]] = None) -> None:
def create_agent_executor(
self, task: Any = None, tools: Optional[list[BaseTool]] = None
) -> None:
"""
Configure the OpenAI agent for execution.
While OpenAI handles execution differently through Runner,
@@ -152,24 +156,24 @@ class OpenAIAgentAdapter(BaseAgentAdapter):
self.agent_executor = Runner
def configure_tools(self, tools: Optional[List[BaseTool]] = None) -> None:
def configure_tools(self, tools: Optional[list[BaseTool]] = None) -> None:
"""Configure tools for the OpenAI Assistant"""
if tools:
self._tool_adapter.configure_tools(tools)
if self._tool_adapter.converted_tools:
self._openai_agent.tools = self._tool_adapter.converted_tools
def handle_execution_result(self, result: Any) -> str:
def handle_execution_result(self, result: Any) -> Any:
"""Process OpenAI Assistant execution result converting any structured output to a string"""
return self._converter_adapter.post_process_result(result.final_output)
def get_delegation_tools(self, agents: List[BaseAgent]) -> List[BaseTool]:
def get_delegation_tools(self, agents: list[BaseAgent]) -> list[BaseTool]:
"""Implement delegation tools support"""
agent_tools = AgentTools(agents=agents)
tools = agent_tools.tools()
return tools
def configure_structured_output(self, task) -> None:
def configure_structured_output(self, task: Any) -> None:
"""Configure the structured output for the specific agent implementation.
Args:

View File

@@ -1,5 +1,6 @@
import json
import re
from typing import Any
from crewai.agents.agent_adapters.base_converter_adapter import BaseConverterAdapter
from crewai.utilities.converter import generate_model_description
@@ -19,14 +20,15 @@ class OpenAIConverterAdapter(BaseConverterAdapter):
_output_model: The Pydantic model for the output
"""
def __init__(self, agent_adapter):
def __init__(self, agent_adapter: Any) -> None:
"""Initialize the converter adapter with a reference to the agent adapter"""
super().__init__(agent_adapter) # type: ignore
self.agent_adapter = agent_adapter
self._output_format = None
self._schema = None
self._output_model = None
self._output_format: str | None = None
self._schema: str | None = None
self._output_model: Any = None
def configure_structured_output(self, task) -> None:
def configure_structured_output(self, task: Any) -> None:
"""
Configure the structured output for OpenAI agent based on task requirements.
@@ -75,7 +77,7 @@ class OpenAIConverterAdapter(BaseConverterAdapter):
return f"{base_prompt}\n\n{output_schema}"
def post_process_result(self, result: str) -> str:
def post_process_result(self, result: str) -> Any:
"""
Post-process the result to ensure it matches the expected format.

View File

@@ -1,8 +1,9 @@
import uuid
from abc import ABC, abstractmethod
from collections.abc import Callable
from copy import copy as shallow_copy
from hashlib import md5
from typing import Any, Callable, Dict, List, Optional, TypeVar
from typing import Any, Optional, TypeVar
from pydantic import (
UUID4,
@@ -14,6 +15,7 @@ from pydantic import (
model_validator,
)
from pydantic_core import PydanticCustomError
from typing_extensions import Self
from crewai.agents.agent_builder.utilities.base_token_process import TokenProcess
from crewai.agents.cache.cache_handler import CacheHandler
@@ -61,7 +63,7 @@ class BaseAgent(ABC, BaseModel):
Methods:
execute_task(task: Any, context: Optional[str] = None, tools: Optional[List[BaseTool]] = None) -> str:
Abstract method to execute a task.
create_agent_executor(tools=None) -> None:
create_agent_executor(task, tools=None) -> None:
Abstract method to create an agent executor.
get_delegation_tools(agents: List["BaseAgent"]):
Abstract method to set the agents task tools for handling delegation and question asking to other agents in crew.
@@ -79,7 +81,7 @@ class BaseAgent(ABC, BaseModel):
Set private attributes.
"""
__hash__ = object.__hash__ # type: ignore
__hash__ = object.__hash__
_logger: Logger = PrivateAttr(default_factory=lambda: Logger(verbose=False))
_rpm_controller: Optional[RPMController] = PrivateAttr(default=None)
_request_within_rpm_limit: Any = PrivateAttr(default=None)
@@ -91,7 +93,7 @@ class BaseAgent(ABC, BaseModel):
role: str = Field(description="Role of the agent")
goal: str = Field(description="Objective of the agent")
backstory: str = Field(description="Backstory of the agent")
config: Optional[Dict[str, Any]] = Field(
config: Optional[dict[str, Any]] = Field(
description="Configuration for the agent", default=None, exclude=True
)
cache: bool = Field(
@@ -108,14 +110,14 @@ class BaseAgent(ABC, BaseModel):
default=False,
description="Enable agent to delegate and ask questions among each other.",
)
tools: Optional[List[BaseTool]] = Field(
tools: Optional[list[BaseTool]] = Field(
default_factory=list, description="Tools at agents' disposal"
)
max_iter: int = Field(
default=25, description="Maximum iterations for an agent to execute a task"
)
agent_executor: InstanceOf = Field(
default=None, description="An instance of the CrewAgentExecutor class."
agent_executor: Optional[Any] = Field(
default=None, description="An instance of the agent executor class."
)
llm: Any = Field(
default=None, description="Language model that will run the agent."
@@ -129,7 +131,7 @@ class BaseAgent(ABC, BaseModel):
default_factory=ToolsHandler,
description="An instance of the ToolsHandler class.",
)
tools_results: List[Dict[str, Any]] = Field(
tools_results: list[dict[str, Any]] = Field(
default=[], description="Results of the tools used by the agent."
)
max_tokens: Optional[int] = Field(
@@ -138,7 +140,7 @@ class BaseAgent(ABC, BaseModel):
knowledge: Optional[Knowledge] = Field(
default=None, description="Knowledge for the agent."
)
knowledge_sources: Optional[List[BaseKnowledgeSource]] = Field(
knowledge_sources: Optional[list[BaseKnowledgeSource]] = Field(
default=None,
description="Knowledge sources for the agent.",
)
@@ -150,7 +152,7 @@ class BaseAgent(ABC, BaseModel):
default_factory=SecurityConfig,
description="Security configuration for the agent, including fingerprinting.",
)
callbacks: List[Callable] = Field(
callbacks: list[Callable[..., Any]] = Field(
default=[], description="Callbacks to be used for the agent"
)
adapted_agent: bool = Field(
@@ -163,12 +165,12 @@ class BaseAgent(ABC, BaseModel):
@model_validator(mode="before")
@classmethod
def process_model_config(cls, values):
def process_model_config(cls, values: Any) -> Any:
return process_config(values, cls)
@field_validator("tools")
@classmethod
def validate_tools(cls, tools: List[Any]) -> List[BaseTool]:
def validate_tools(cls, tools: list[Any]) -> list[BaseTool]:
"""Validate and process the tools provided to the agent.
This method ensures that each tool is either an instance of BaseTool
@@ -196,7 +198,7 @@ class BaseAgent(ABC, BaseModel):
return processed_tools
@model_validator(mode="after")
def validate_and_set_attributes(self):
def validate_and_set_attributes(self) -> Self:
# Validate required fields
for field in ["role", "goal", "backstory"]:
if getattr(self, field) is None:
@@ -228,7 +230,7 @@ class BaseAgent(ABC, BaseModel):
)
@model_validator(mode="after")
def set_private_attrs(self):
def set_private_attrs(self) -> Self:
"""Set private attributes."""
self._logger = Logger(verbose=self.verbose)
if self.max_rpm and not self._rpm_controller:
@@ -240,7 +242,7 @@ class BaseAgent(ABC, BaseModel):
return self
@property
def key(self):
def key(self) -> str:
source = [
self._original_role or self.role,
self._original_goal or self.goal,
@@ -253,16 +255,18 @@ class BaseAgent(ABC, BaseModel):
self,
task: Any,
context: Optional[str] = None,
tools: Optional[List[BaseTool]] = None,
tools: Optional[list[BaseTool]] = None,
) -> str:
pass
@abstractmethod
def create_agent_executor(self, tools=None) -> None:
def create_agent_executor(
self, task: Any, tools: Optional[list[BaseTool]] = None
) -> None:
pass
@abstractmethod
def get_delegation_tools(self, agents: List["BaseAgent"]) -> List[BaseTool]:
def get_delegation_tools(self, agents: list["BaseAgent"]) -> list[BaseTool]:
"""Set the task tools that init BaseAgenTools class."""
pass
@@ -320,7 +324,7 @@ class BaseAgent(ABC, BaseModel):
return copied_agent
def interpolate_inputs(self, inputs: Dict[str, Any]) -> None:
def interpolate_inputs(self, inputs: dict[str, Any]) -> None:
"""Interpolate inputs into the agent description and backstory."""
if self._original_role is None:
self._original_role = self.role
@@ -350,7 +354,7 @@ class BaseAgent(ABC, BaseModel):
if self.cache:
self.cache_handler = cache_handler
self.tools_handler.cache = cache_handler
self.create_agent_executor()
# Executor will be created when a task is executed
def set_rpm_controller(self, rpm_controller: RPMController) -> None:
"""Set the rpm controller for the agent.
@@ -360,7 +364,7 @@ class BaseAgent(ABC, BaseModel):
"""
if not self._rpm_controller:
self._rpm_controller = rpm_controller
self.create_agent_executor()
# Executor will be created when a task is executed
def set_knowledge(self, crew_embedder: Optional[Dict[str, Any]] = None):
def set_knowledge(self, crew_embedder: Optional[dict[str, Any]] = None) -> None:
pass

View File

@@ -1,31 +1,32 @@
import time
from typing import TYPE_CHECKING, Dict, List
from typing import TYPE_CHECKING
from crewai.events.event_listener import event_listener
from crewai.memory.entity.entity_memory_item import EntityMemoryItem
from crewai.memory.long_term.long_term_memory_item import LongTermMemoryItem
from crewai.utilities import I18N
from crewai.utilities.converter import ConverterError
from crewai.utilities.evaluators.task_evaluator import TaskEvaluator
from crewai.utilities.printer import Printer
from crewai.events.event_listener import event_listener
if TYPE_CHECKING:
from crewai.agents.agent_builder.base_agent import BaseAgent
from crewai.agents.parser import AgentFinish
from crewai.crew import Crew
from crewai.task import Task
class CrewAgentExecutorMixin:
crew: "Crew"
crew: "Crew | None"
agent: "BaseAgent"
task: "Task"
iterations: int
max_iter: int
messages: List[Dict[str, str]]
messages: list[dict[str, str]]
_i18n: I18N
_printer: Printer = Printer()
def _create_short_term_memory(self, output) -> None:
def _create_short_term_memory(self, output: "AgentFinish") -> None:
"""Create and save a short-term memory item if conditions are met."""
if (
self.crew
@@ -35,7 +36,8 @@ class CrewAgentExecutorMixin:
):
try:
if (
hasattr(self.crew, "_short_term_memory")
self.crew
and hasattr(self.crew, "_short_term_memory")
and self.crew._short_term_memory
):
self.crew._short_term_memory.save(
@@ -48,7 +50,7 @@ class CrewAgentExecutorMixin:
print(f"Failed to add to short term memory: {e}")
pass
def _create_external_memory(self, output) -> None:
def _create_external_memory(self, output: "AgentFinish") -> None:
"""Create and save a external-term memory item if conditions are met."""
if (
self.crew
@@ -69,7 +71,7 @@ class CrewAgentExecutorMixin:
print(f"Failed to add to external memory: {e}")
pass
def _create_long_term_memory(self, output) -> None:
def _create_long_term_memory(self, output: "AgentFinish") -> None:
"""Create and save long-term and entity memory items based on evaluation."""
if (
self.crew

View File

@@ -1,3 +1,5 @@
-from .cache_handler import CacheHandler
-
-__all__ = ["CacheHandler"]
+"""Internal caching utilities for agent tool execution.
+
+This package provides caching mechanisms for storing and retrieving
+tool execution results to avoid redundant operations.
+"""

View File

@@ -1,15 +1,50 @@
-from typing import Any, Dict, Optional
+"""Cache handler for storing and retrieving tool execution results.
+
+This module provides a caching mechanism for tool outputs in the CrewAI framework,
+allowing agents to reuse previous tool execution results when the same tool is
+called with identical arguments.
+
+Classes:
+    CacheHandler: Manages the caching of tool execution results using an in-memory
+        dictionary with serialized tool arguments as keys.
+"""
+
+import json
+from typing import Any
 
 from pydantic import BaseModel, PrivateAttr
 
 
 class CacheHandler(BaseModel):
-    """Callback handler for tool usage."""
+    """Callback handler for tool usage.
+
+    Notes:
+        TODO: Make thread-safe, currently not thread-safe.
+    """
 
-    _cache: Dict[str, Any] = PrivateAttr(default_factory=dict)
+    _cache: dict[str, Any] = PrivateAttr(default_factory=dict)
 
-    def add(self, tool, input, output):
-        self._cache[f"{tool}-{input}"] = output
+    def add(self, tool: str, input_data: dict[str, Any] | None, output: str) -> None:
+        """Add a tool execution result to the cache.
+
+        Args:
+            tool: The name of the tool.
+            input_data: The input arguments for the tool.
+            output: The output from the tool execution.
+        """
+        cache_key = json.dumps(input_data, sort_keys=True) if input_data else ""
+        self._cache[f"{tool}-{cache_key}"] = output
 
-    def read(self, tool, input) -> Optional[str]:
-        return self._cache.get(f"{tool}-{input}")
+    def read(self, tool: str, input_data: dict[str, Any] | None) -> str | None:
+        """Read a tool execution result from the cache.
+
+        Args:
+            tool: The name of the tool.
+            input_data: The input arguments for the tool.
+
+        Returns:
+            The cached output if found, None otherwise.
+        """
+        cache_key = json.dumps(input_data, sort_keys=True) if input_data else ""
+        return self._cache.get(f"{tool}-{cache_key}")
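
A usage sketch of the reworked handler above, assuming CacheHandler is imported from crewai.agents.cache.cache_handler:

handler = CacheHandler()
handler.add(tool="search", input_data={"query": "crewai"}, output="10 results")
assert handler.read(tool="search", input_data={"query": "crewai"}) == "10 results"
assert handler.read(tool="search", input_data=None) is None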

View File

@@ -4,8 +4,14 @@ Handles agent execution flow including LLM interactions, tool execution,
and memory management.
"""
from __future__ import annotations
from collections.abc import Callable
from typing import Any
from typing import TYPE_CHECKING, Any
if TYPE_CHECKING:
from crewai.crew import Crew
from crewai.task import Task
from crewai.agents.agent_builder.base_agent import BaseAgent
from crewai.agents.agent_builder.base_agent_executor_mixin import CrewAgentExecutorMixin
@@ -21,6 +27,7 @@ from crewai.events.types.logging_events import (
AgentLogsStartedEvent,
)
from crewai.llms.base_llm import BaseLLM
from crewai.tools.base_tool import BaseTool
from crewai.tools.structured_tool import CrewStructuredTool
from crewai.tools.tool_types import ToolResult
from crewai.utilities import I18N, Printer
@@ -51,9 +58,9 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
def __init__(
self,
llm: Any,
task: Any,
crew: Any,
llm: BaseLLM,
task: Task,
crew: Crew | None,
agent: BaseAgent,
prompt: dict[str, str],
max_iter: int,
@@ -62,19 +69,19 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
stop_words: list[str],
tools_description: str,
tools_handler: ToolsHandler,
step_callback: Any = None,
original_tools: list[Any] | None = None,
function_calling_llm: Any = None,
step_callback: Callable[[AgentAction | AgentFinish], None] | None = None,
original_tools: list[BaseTool] | None = None,
function_calling_llm: BaseLLM | None = None,
respect_context_window: bool = False,
request_within_rpm_limit: Callable[[], bool] | None = None,
callbacks: list[Any] | None = None,
litellm_callbacks: list[Any] | None = None,
) -> None:
"""Initialize executor.
Args:
llm: Language model instance.
task: Task to execute.
crew: Crew instance.
crew: Optional Crew instance.
agent: Agent to execute.
prompt: Prompt templates.
max_iter: Maximum iterations.
@@ -88,19 +95,19 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
function_calling_llm: Optional function calling LLM.
respect_context_window: Respect context limits.
request_within_rpm_limit: RPM limit check function.
callbacks: Optional callbacks list.
litellm_callbacks: Optional litellm callbacks list.
"""
self._i18n: I18N = I18N()
self.llm: BaseLLM = llm
self.llm = llm
self.task = task
self.agent = agent
self.crew = crew
self.crew: Crew | None = crew
self.prompt = prompt
self.tools = tools
self.tools_names = tools_names
self.stop = stop_words
self.max_iter = max_iter
self.callbacks = callbacks or []
self.litellm_callbacks = litellm_callbacks or []
self._printer: Printer = Printer()
self.tools_handler = tools_handler
self.original_tools = original_tools or []
@@ -123,7 +130,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
)
)
def invoke(self, inputs: dict[str, str]) -> dict[str, Any]:
def invoke(self, inputs: dict[str, str]) -> dict[str, str]:
"""Execute the agent with given inputs.
Args:
@@ -131,6 +138,10 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
Returns:
Dictionary with agent output.
Raises:
AssertionError: If agent fails to reach final answer.
Exception: If unknown error occurs during execution.
"""
if "system" in self.prompt:
system_prompt = self._format_prompt(self.prompt.get("system", ""), inputs)
@@ -170,6 +181,9 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
Returns:
Final answer from the agent.
Raises:
Exception: If litellm error or unknown error occurs.
"""
formatted_answer = None
while not isinstance(formatted_answer, AgentFinish):
@@ -181,7 +195,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
i18n=self._i18n,
messages=self.messages,
llm=self.llm,
callbacks=self.callbacks,
callbacks=self.litellm_callbacks,
)
enforce_rpm_limit(self.request_within_rpm_limit)
@@ -189,7 +203,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
answer = get_llm_response(
llm=self.llm,
messages=self.messages,
callbacks=self.callbacks,
callbacks=self.litellm_callbacks,
printer=self._printer,
from_task=self.task,
)
@@ -198,10 +212,8 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
if isinstance(formatted_answer, AgentAction):
# Extract agent fingerprint if available
fingerprint_context = {}
if (
self.agent
and hasattr(self.agent, "security_config")
and hasattr(self.agent.security_config, "fingerprint")
if hasattr(self.agent, "security_config") and hasattr(
self.agent.security_config, "fingerprint"
):
fingerprint_context = {
"agent_fingerprint": str(
@@ -214,8 +226,8 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
fingerprint_context=fingerprint_context,
tools=self.tools,
i18n=self._i18n,
agent_key=self.agent.key if self.agent else None,
agent_role=self.agent.role if self.agent else None,
agent_key=self.agent.key,
agent_role=self.agent.role,
tools_handler=self.tools_handler,
task=self.task,
agent=self.agent,
@@ -247,7 +259,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
printer=self._printer,
messages=self.messages,
llm=self.llm,
callbacks=self.callbacks,
callbacks=self.litellm_callbacks,
i18n=self._i18n,
)
continue
@@ -317,18 +329,13 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
def _show_start_logs(self) -> None:
"""Emit agent start event."""
if self.agent is None:
raise ValueError("Agent cannot be None")
crewai_event_bus.emit(
self.agent,
AgentLogsStartedEvent(
agent_role=self.agent.role,
task_description=(
getattr(self.task, "description") if self.task else "Not Found"
),
task_description=self.task.description,
verbose=self.agent.verbose
or (hasattr(self, "crew") and getattr(self.crew, "verbose", False)),
or (self.crew.verbose if self.crew else False),
),
)
@@ -338,16 +345,13 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
Args:
formatted_answer: Agent's response to log.
"""
if self.agent is None:
raise ValueError("Agent cannot be None")
crewai_event_bus.emit(
self.agent,
AgentLogsExecutionEvent(
agent_role=self.agent.role,
formatted_answer=formatted_answer,
verbose=self.agent.verbose
or (hasattr(self, "crew") and getattr(self.crew, "verbose", False)),
or (self.crew.verbose if self.crew else False),
),
)
@@ -361,9 +365,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
human_feedback: Optional feedback from human.
"""
agent_id = str(self.agent.id)
train_iteration = (
getattr(self.crew, "_train_iteration", None) if self.crew else None
)
train_iteration = getattr(self.crew, "_train_iteration", None)
if train_iteration is None or not isinstance(train_iteration, int):
self._printer.print(

View File

@@ -7,18 +7,18 @@ AgentAction or AgentFinish objects.
from dataclasses import dataclass
from json_repair import repair_json
from json_repair import repair_json # type: ignore[import-untyped]
from crewai.agents.constants import (
ACTION_INPUT_ONLY_REGEX,
ACTION_INPUT_REGEX,
ACTION_REGEX,
ACTION_INPUT_ONLY_REGEX,
FINAL_ANSWER_ACTION,
MISSING_ACTION_AFTER_THOUGHT_ERROR_MESSAGE,
MISSING_ACTION_INPUT_AFTER_ACTION_ERROR_MESSAGE,
UNABLE_TO_REPAIR_JSON_RESULTS,
)
from crewai.utilities import I18N
from crewai.utilities.i18n import I18N
_I18N = I18N()

View File

@@ -1,8 +1,8 @@
"""Tools handler for managing tool execution and caching."""
from crewai.agents.cache.cache_handler import CacheHandler
from crewai.tools.cache_tools.cache_tools import CacheTools
from crewai.tools.tool_calling import InstructorToolCalling, ToolCalling
from crewai.agents.cache.cache_handler import CacheHandler
class ToolsHandler:
@@ -39,6 +39,6 @@ class ToolsHandler:
if self.cache and should_cache and calling.tool_name != CacheTools().name:
self.cache.add(
tool=calling.tool_name,
input=calling.arguments,
input_data=calling.arguments,
output=output,
)

View File

@@ -3,26 +3,18 @@ import json
import re
import uuid
import warnings
from collections.abc import Callable, Mapping, Set
from concurrent.futures import Future
from copy import copy as shallow_copy
from hashlib import md5
from typing import (
Any,
Callable,
Dict,
List,
Optional,
Set,
Tuple,
Union,
cast,
)
from opentelemetry import baggage
from opentelemetry.context import attach, detach
from crewai.utilities.crew.models import CrewContext
from pydantic import (
UUID4,
BaseModel,
@@ -34,15 +26,36 @@ from pydantic import (
model_validator,
)
from pydantic_core import PydanticCustomError
from typing_extensions import Self
from crewai.agent import Agent
from crewai.agents.agent_builder.base_agent import BaseAgent
from crewai.agents.cache import CacheHandler
from crewai.agents.cache.cache_handler import CacheHandler
from crewai.crews.crew_output import CrewOutput
from crewai.events.event_bus import crewai_event_bus
from crewai.events.event_listener import EventListener
from crewai.events.listeners.tracing.trace_listener import (
TraceCollectionListener,
)
from crewai.events.listeners.tracing.utils import (
is_tracing_enabled,
)
from crewai.events.types.crew_events import (
CrewKickoffCompletedEvent,
CrewKickoffFailedEvent,
CrewKickoffStartedEvent,
CrewTestCompletedEvent,
CrewTestFailedEvent,
CrewTestStartedEvent,
CrewTrainCompletedEvent,
CrewTrainFailedEvent,
CrewTrainStartedEvent,
)
from crewai.flow.flow_trackable import FlowTrackable
from crewai.knowledge.knowledge import Knowledge
from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource
from crewai.llm import LLM, BaseLLM
from crewai.llm import LLM
from crewai.llms.base_llm import BaseLLM
from crewai.memory.entity.entity_memory import EntityMemory
from crewai.memory.external.external_memory import ExternalMemory
from crewai.memory.long_term.long_term_memory import LongTermMemory
@@ -57,29 +70,9 @@ from crewai.tools.base_tool import BaseTool, Tool
from crewai.types.usage_metrics import UsageMetrics
from crewai.utilities import I18N, FileHandler, Logger, RPMController
from crewai.utilities.constants import NOT_SPECIFIED, TRAINING_DATA_FILE
from crewai.utilities.crew.models import CrewContext
from crewai.utilities.evaluators.crew_evaluator_handler import CrewEvaluator
from crewai.utilities.evaluators.task_evaluator import TaskEvaluator
from crewai.events.types.crew_events import (
CrewKickoffCompletedEvent,
CrewKickoffFailedEvent,
CrewKickoffStartedEvent,
CrewTestCompletedEvent,
CrewTestFailedEvent,
CrewTestStartedEvent,
CrewTrainCompletedEvent,
CrewTrainFailedEvent,
CrewTrainStartedEvent,
)
from crewai.events.event_bus import crewai_event_bus
from crewai.events.event_listener import EventListener
from crewai.events.listeners.tracing.trace_listener import (
TraceCollectionListener,
)
from crewai.events.listeners.tracing.utils import (
is_tracing_enabled,
)
from crewai.utilities.formatter import (
aggregate_raw_outputs_from_task_outputs,
aggregate_raw_outputs_from_tasks,
@@ -116,9 +109,12 @@ class Crew(FlowTrackable, BaseModel):
planning: Plan the crew execution and add the plan to the crew.
chat_llm: The language model used for orchestrating chat interactions with the crew.
security_config: Security configuration for the crew, including fingerprinting.
Notes:
TODO: Improve the embedder type from dict[str, Any] to a more specific TypedDict or dataclass.
"""
__hash__ = object.__hash__ # type: ignore
__hash__ = object.__hash__
_execution_span: Any = PrivateAttr()
_rpm_controller: RPMController = PrivateAttr()
_logger: Logger = PrivateAttr()
@@ -130,7 +126,7 @@ class Crew(FlowTrackable, BaseModel):
_external_memory: Optional[InstanceOf[ExternalMemory]] = PrivateAttr()
_train: Optional[bool] = PrivateAttr(default=False)
_train_iteration: Optional[int] = PrivateAttr()
_inputs: Optional[Dict[str, Any]] = PrivateAttr(default=None)
_inputs: Optional[dict[str, Any]] = PrivateAttr(default=None)
_logging_color: str = PrivateAttr(
default="bold_purple",
)
@@ -140,8 +136,8 @@ class Crew(FlowTrackable, BaseModel):
name: Optional[str] = Field(default="crew")
cache: bool = Field(default=True)
tasks: List[Task] = Field(default_factory=list)
agents: List[BaseAgent] = Field(default_factory=list)
tasks: list[Task] = Field(default_factory=list)
agents: list[BaseAgent] = Field(default_factory=list)
process: Process = Field(default=Process.sequential)
verbose: bool = Field(default=False)
memory: bool = Field(
@@ -164,7 +160,7 @@ class Crew(FlowTrackable, BaseModel):
default=None,
description="An Instance of the ExternalMemory to be used by the Crew",
)
embedder: Optional[dict] = Field(
embedder: Optional[dict[str, Any]] = Field(
default=None,
description="Configuration for the embedder to be used for the crew.",
)
@@ -172,16 +168,16 @@ class Crew(FlowTrackable, BaseModel):
default=None,
description="Metrics for the LLM usage during all tasks execution.",
)
manager_llm: Optional[Union[str, InstanceOf[BaseLLM], Any]] = Field(
manager_llm: Optional[str | InstanceOf[BaseLLM] | Any] = Field(
description="Language model that will run the agent.", default=None
)
manager_agent: Optional[BaseAgent] = Field(
description="Custom agent that will be used as manager.", default=None
)
function_calling_llm: Optional[Union[str, InstanceOf[LLM], Any]] = Field(
function_calling_llm: Optional[str | InstanceOf[LLM] | Any] = Field(
description="Language model that will run the agent.", default=None
)
config: Optional[Union[Json, Dict[str, Any]]] = Field(default=None)
config: Optional[Json[dict[str, Any]] | dict[str, Any]] = Field(default=None)
id: UUID4 = Field(default_factory=uuid.uuid4, frozen=True)
share_crew: Optional[bool] = Field(default=False)
step_callback: Optional[Any] = Field(
@@ -192,13 +188,13 @@ class Crew(FlowTrackable, BaseModel):
default=None,
description="Callback to be executed after each task for all agents execution.",
)
before_kickoff_callbacks: List[
Callable[[Optional[Dict[str, Any]]], Optional[Dict[str, Any]]]
before_kickoff_callbacks: list[
Callable[[Optional[dict[str, Any]]], Optional[dict[str, Any]]]
] = Field(
default_factory=list,
description="List of callbacks to be executed before crew kickoff. It may be used to adjust inputs before the crew is executed.",
)
after_kickoff_callbacks: List[Callable[[CrewOutput], CrewOutput]] = Field(
after_kickoff_callbacks: list[Callable[[CrewOutput], CrewOutput]] = Field(
default_factory=list,
description="List of callbacks to be executed after crew kickoff. It may be used to adjust the output of the crew.",
)
@@ -210,7 +206,7 @@ class Crew(FlowTrackable, BaseModel):
default=None,
description="Path to the prompt json file to be used for the crew.",
)
output_log_file: Optional[Union[bool, str]] = Field(
output_log_file: Optional[bool | str] = Field(
default=None,
description="Path to the log file to be saved",
)
@@ -218,23 +214,23 @@ class Crew(FlowTrackable, BaseModel):
default=False,
description="Plan the crew execution and add the plan to the crew.",
)
planning_llm: Optional[Union[str, InstanceOf[BaseLLM], Any]] = Field(
planning_llm: Optional[str | InstanceOf[BaseLLM] | Any] = Field(
default=None,
description="Language model that will run the AgentPlanner if planning is True.",
)
task_execution_output_json_files: Optional[List[str]] = Field(
task_execution_output_json_files: Optional[list[str]] = Field(
default=None,
description="List of file paths for task execution JSON files.",
)
execution_logs: List[Dict[str, Any]] = Field(
execution_logs: list[dict[str, Any]] = Field(
default=[],
description="List of execution logs for tasks",
)
knowledge_sources: Optional[List[BaseKnowledgeSource]] = Field(
knowledge_sources: Optional[list[BaseKnowledgeSource]] = Field(
default=None,
description="Knowledge sources for the crew. Add knowledge sources to the knowledge object.",
)
chat_llm: Optional[Union[str, InstanceOf[BaseLLM], Any]] = Field(
chat_llm: Optional[str | InstanceOf[BaseLLM] | Any] = Field(
default=None,
description="LLM used to handle chatting with the crew.",
)
@@ -267,8 +263,8 @@ class Crew(FlowTrackable, BaseModel):
@field_validator("config", mode="before")
@classmethod
def check_config_type(
cls, v: Union[Json, Dict[str, Any]]
) -> Union[Json, Dict[str, Any]]:
cls, v: Json[dict[str, Any]] | dict[str, Any]
) -> Json[dict[str, Any]] | dict[str, Any]:
"""Validates that the config is a valid type.
Args:
v: The config to be validated.
@@ -277,10 +273,10 @@ class Crew(FlowTrackable, BaseModel):
"""
# TODO: Improve typing
return json.loads(v) if isinstance(v, Json) else v # type: ignore
return json.loads(v) if isinstance(v, str) else v
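Side note on the validator fix above: Pydantic's Json is an annotation-time construct, so by the time a mode="before" validator runs, the raw value is either a JSON string or an already-parsed dict, which is why the isinstance check now targets str. A minimal standalone sketch of the same decode-or-pass-through logic (not the crewai source):

import json
from typing import Any

def check_config_type(v: str | dict[str, Any]) -> dict[str, Any]:
    # A JSON string is decoded; an already-parsed dict passes through untouched.
    return json.loads(v) if isinstance(v, str) else v

assert check_config_type('{"agents": []}') == {"agents": []}
assert check_config_type({"tasks": []}) == {"tasks": []}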
@model_validator(mode="after")
def set_private_attrs(self) -> "Crew":
def set_private_attrs(self) -> Self:
"""Set private attributes."""
self._cache_handler = CacheHandler()
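The "Crew" -> Self swap applied throughout these validators follows the standard typing_extensions.Self pattern; a minimal sketch, assuming Pydantic v2 semantics:

from pydantic import BaseModel, model_validator
from typing_extensions import Self

class Base(BaseModel):
    name: str = "base"

    @model_validator(mode="after")
    def check(self) -> Self:
        # Returning Self keeps the inferred type correct for subclasses,
        # where a hard-coded "Base" (or "Crew") annotation would not.
        return self

class Child(Base):
    extra: int = 0

assert isinstance(Child(), Child)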
@@ -300,7 +296,7 @@ class Crew(FlowTrackable, BaseModel):
return self
def _initialize_default_memories(self):
def _initialize_default_memories(self) -> None:
self._long_term_memory = self._long_term_memory or LongTermMemory()
self._short_term_memory = self._short_term_memory or ShortTermMemory(
crew=self,
@@ -311,7 +307,7 @@ class Crew(FlowTrackable, BaseModel):
)
@model_validator(mode="after")
def create_crew_memory(self) -> "Crew":
def create_crew_memory(self) -> Self:
"""Initialize private memory attributes."""
self._external_memory = (
# External memory doesn't support a default value since it was designed to be managed entirely externally

@@ -328,7 +324,7 @@ class Crew(FlowTrackable, BaseModel):
return self
@model_validator(mode="after")
def create_crew_knowledge(self) -> "Crew":
def create_crew_knowledge(self) -> Self:
"""Create the knowledge for the crew."""
if self.knowledge_sources:
try:
@@ -349,7 +345,7 @@ class Crew(FlowTrackable, BaseModel):
return self
@model_validator(mode="after")
def check_manager_llm(self):
def check_manager_llm(self) -> Self:
"""Validates that the language model is set when using hierarchical process."""
if self.process == Process.hierarchical:
if not self.manager_llm and not self.manager_agent:
@@ -371,7 +367,7 @@ class Crew(FlowTrackable, BaseModel):
return self
@model_validator(mode="after")
def check_config(self):
def check_config(self) -> Self:
"""Validates that the crew is properly configured with agents and tasks."""
if not self.config and not self.tasks and not self.agents:
raise PydanticCustomError(
@@ -392,20 +388,20 @@ class Crew(FlowTrackable, BaseModel):
return self
@model_validator(mode="after")
def validate_tasks(self):
def validate_tasks(self) -> Self:
if self.process == Process.sequential:
for task in self.tasks:
if task.agent is None:
raise PydanticCustomError(
"missing_agent_in_task",
f"Sequential process error: Agent is missing in the task with the following description: {task.description}", # type: ignore # Argument of type "str" cannot be assigned to parameter "message_template" of type "LiteralString"
f"Sequential process error: Agent is missing in the task with the following description: {task.description}",
{},
)
return self
@model_validator(mode="after")
def validate_end_with_at_most_one_async_task(self):
def validate_end_with_at_most_one_async_task(self) -> Self:
"""Validates that the crew ends with at most one asynchronous task."""
final_async_task_count = 0
@@ -426,7 +422,7 @@ class Crew(FlowTrackable, BaseModel):
return self
@model_validator(mode="after")
def validate_must_have_non_conditional_task(self) -> "Crew":
def validate_must_have_non_conditional_task(self) -> Self:
"""Ensure that a crew has at least one non-conditional task."""
if not self.tasks:
return self
@@ -442,7 +438,7 @@ class Crew(FlowTrackable, BaseModel):
return self
@model_validator(mode="after")
def validate_first_task(self) -> "Crew":
def validate_first_task(self) -> Self:
"""Ensure the first task is not a ConditionalTask."""
if self.tasks and isinstance(self.tasks[0], ConditionalTask):
raise PydanticCustomError(
@@ -453,19 +449,21 @@ class Crew(FlowTrackable, BaseModel):
return self
@model_validator(mode="after")
def validate_async_tasks_not_async(self) -> "Crew":
def validate_async_tasks_not_async(self) -> Self:
"""Ensure that ConditionalTask is not async."""
for task in self.tasks:
if task.async_execution and isinstance(task, ConditionalTask):
raise PydanticCustomError(
"invalid_async_conditional_task",
f"Conditional Task: {task.description} , cannot be executed asynchronously.", # type: ignore # Argument of type "str" cannot be assigned to parameter "message_template" of type "LiteralString"
f"Conditional Task: {task.description} , cannot be executed asynchronously.",
{},
)
return self
@model_validator(mode="after")
def validate_async_task_cannot_include_sequential_async_tasks_in_context(self):
def validate_async_task_cannot_include_sequential_async_tasks_in_context(
self,
) -> Self:
"""
Validates that if a task is set to be executed asynchronously,
it cannot include other asynchronous tasks in its context unless
@@ -485,7 +483,7 @@ class Crew(FlowTrackable, BaseModel):
return self
@model_validator(mode="after")
def validate_context_no_future_tasks(self):
def validate_context_no_future_tasks(self) -> Self:
"""Validates that a task's context does not include future tasks."""
task_indices = {id(task): i for i, task in enumerate(self.tasks)}
@@ -502,7 +500,7 @@ class Crew(FlowTrackable, BaseModel):
@property
def key(self) -> str:
source: List[str] = [agent.key for agent in self.agents] + [
source: list[str] = [agent.key for agent in self.agents] + [
task.key for task in self.tasks
]
return md5("|".join(source).encode(), usedforsecurity=False).hexdigest()
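For reference, the key property above reduces to a stable fingerprint over agent and task keys; a standalone sketch:

from hashlib import md5

def crew_key(agent_keys: list[str], task_keys: list[str]) -> str:
    source: list[str] = agent_keys + task_keys
    # usedforsecurity=False flags the digest as an identity fingerprint rather
    # than a security primitive (and keeps FIPS-restricted builds working).
    return md5("|".join(source).encode(), usedforsecurity=False).hexdigest()

print(crew_key(["researcher"], ["write-report"]))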
@@ -517,7 +515,7 @@ class Crew(FlowTrackable, BaseModel):
"""
return self.security_config.fingerprint
def _setup_from_config(self):
def _setup_from_config(self) -> None:
"""Initializes agents and tasks from the provided config."""
assert self.config is not None, "Config should not be None."
@@ -530,7 +528,7 @@ class Crew(FlowTrackable, BaseModel):
self.agents = [Agent(**agent) for agent in self.config["agents"]]
self.tasks = [self._create_task(task) for task in self.config["tasks"]]
def _create_task(self, task_config: Dict[str, Any]) -> Task:
def _create_task(self, task_config: dict[str, Any]) -> Task:
"""Creates a task instance from its configuration.
Args:
@@ -559,7 +557,7 @@ class Crew(FlowTrackable, BaseModel):
CrewTrainingHandler(filename).initialize_file()
def train(
self, n_iterations: int, filename: str, inputs: Optional[Dict[str, Any]] = None
self, n_iterations: int, filename: str, inputs: Optional[dict[str, Any]] = None
) -> None:
"""Trains the crew for a given number of iterations."""
inputs = inputs or {}
@@ -611,7 +609,7 @@ class Crew(FlowTrackable, BaseModel):
def kickoff(
self,
inputs: Optional[Dict[str, Any]] = None,
inputs: Optional[dict[str, Any]] = None,
) -> CrewOutput:
ctx = baggage.set_baggage(
"crew_context", CrewContext(id=str(self.id), key=self.key)
@@ -643,8 +641,7 @@ class Crew(FlowTrackable, BaseModel):
for agent in self.agents:
agent.i18n = i18n
# type: ignore[attr-defined] # Argument 1 to "_interpolate_inputs" of "Crew" has incompatible type "dict[str, Any] | None"; expected "dict[str, Any]"
agent.crew = self # type: ignore[attr-defined]
agent.crew = self
agent.set_knowledge(crew_embedder=self.embedder)
# TODO: Create an AgentFunctionCalling protocol for future refactoring
if not agent.function_calling_llm: # type: ignore # "BaseAgent" has no attribute "function_calling_llm"
@@ -653,7 +650,7 @@ class Crew(FlowTrackable, BaseModel):
if not agent.step_callback: # type: ignore # "BaseAgent" has no attribute "step_callback"
agent.step_callback = self.step_callback # type: ignore # "BaseAgent" has no attribute "step_callback"
agent.create_agent_executor()
# Agent executor will be created when tasks are executed
if self.planning:
self._handle_crew_planning()
@@ -682,9 +679,9 @@ class Crew(FlowTrackable, BaseModel):
finally:
detach(token)
def kickoff_for_each(self, inputs: List[Dict[str, Any]]) -> List[CrewOutput]:
def kickoff_for_each(self, inputs: list[dict[str, Any]]) -> list[CrewOutput]:
"""Executes the Crew's workflow for each input in the list and aggregates results."""
results: List[CrewOutput] = []
results: list[CrewOutput] = []
# Initialize the parent crew's usage metrics
total_usage_metrics = UsageMetrics()
@@ -704,16 +701,18 @@ class Crew(FlowTrackable, BaseModel):
return results
async def kickoff_async(
self, inputs: Optional[Dict[str, Any]] = None
self, inputs: Optional[dict[str, Any]] = None
) -> CrewOutput:
"""Asynchronous kickoff method to start the crew execution."""
inputs = inputs or {}
return await asyncio.to_thread(self.kickoff, inputs)
async def kickoff_for_each_async(self, inputs: List[Dict]) -> List[CrewOutput]:
async def kickoff_for_each_async(
self, inputs: list[dict[str, Any]]
) -> list[CrewOutput]:
crew_copies = [self.copy() for _ in inputs]
async def run_crew(crew, input_data):
async def run_crew(crew: Self, input_data: dict[str, Any]) -> CrewOutput:
return await crew.kickoff_async(inputs=input_data)
tasks = [
@@ -732,7 +731,7 @@ class Crew(FlowTrackable, BaseModel):
self._task_output_handler.reset()
return results
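A hedged sketch of the concurrency shape behind kickoff_for_each_async: one crew copy per input, each awaited via kickoff_async, which offloads the blocking kickoff to a thread. The helper names below are stand-ins, not the crewai API:

import asyncio
from typing import Any

async def kickoff_async(inputs: dict[str, Any]) -> str:
    # Stand-in for Crew.kickoff_async: offload a blocking kickoff to a thread.
    return await asyncio.to_thread(lambda: f"done: {inputs['topic']}")

async def kickoff_for_each_async(all_inputs: list[dict[str, Any]]) -> list[str]:
    # One coroutine per input, run concurrently; mirrors the per-input crew copies.
    return await asyncio.gather(*(kickoff_async(i) for i in all_inputs))

print(asyncio.run(kickoff_for_each_async([{"topic": "a"}, {"topic": "b"}])))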
def _handle_crew_planning(self):
def _handle_crew_planning(self) -> None:
"""Handles the Crew planning."""
self._logger.log("info", "Planning the crew execution")
result = CrewPlanner(
@@ -748,7 +747,7 @@ class Crew(FlowTrackable, BaseModel):
output: TaskOutput,
task_index: int,
was_replayed: bool = False,
):
) -> None:
if self._inputs:
inputs = self._inputs
else:
@@ -780,7 +779,7 @@ class Crew(FlowTrackable, BaseModel):
self._create_manager_agent()
return self._execute_tasks(self.tasks)
def _create_manager_agent(self):
def _create_manager_agent(self) -> None:
i18n = I18N(prompt_file=self.prompt_file)
if self.manager_agent is not None:
self.manager_agent.allow_delegation = True
@@ -792,7 +791,12 @@ class Crew(FlowTrackable, BaseModel):
manager.tools = []
raise Exception("Manager agent should not have tools")
else:
self.manager_llm = create_llm(self.manager_llm)
if self.manager_llm is None:
from crewai.utilities.llm_utils import create_default_llm
self.manager_llm = create_default_llm()
else:
self.manager_llm = create_llm(self.manager_llm)
manager = Agent(
role=i18n.retrieve("hierarchical_manager_agent", "role"),
goal=i18n.retrieve("hierarchical_manager_agent", "goal"),
@@ -807,7 +811,7 @@ class Crew(FlowTrackable, BaseModel):
def _execute_tasks(
self,
tasks: List[Task],
tasks: list[Task],
start_index: Optional[int] = 0,
was_replayed: bool = False,
) -> CrewOutput:
@@ -821,8 +825,8 @@ class Crew(FlowTrackable, BaseModel):
CrewOutput: Final output of the crew
"""
task_outputs: List[TaskOutput] = []
futures: List[Tuple[Task, Future[TaskOutput], int]] = []
task_outputs: list[TaskOutput] = []
futures: list[tuple[Task, Future[TaskOutput], int]] = []
last_sync_output: Optional[TaskOutput] = None
for task_index, task in enumerate(tasks):
@@ -847,7 +851,7 @@ class Crew(FlowTrackable, BaseModel):
tools_for_task = self._prepare_tools(
agent_to_use,
task,
cast(Union[List[Tool], List[BaseTool]], tools_for_task),
cast(list[Tool] | list[BaseTool], tools_for_task),
)
self._log_task_start(task, agent_to_use.role)
@@ -867,7 +871,7 @@ class Crew(FlowTrackable, BaseModel):
future = task.execute_async(
agent=agent_to_use,
context=context,
tools=cast(List[BaseTool], tools_for_task),
tools=tools_for_task,
)
futures.append((task, future, task_index))
else:
@@ -879,7 +883,7 @@ class Crew(FlowTrackable, BaseModel):
task_output = task.execute_sync(
agent=agent_to_use,
context=context,
tools=cast(List[BaseTool], tools_for_task),
tools=tools_for_task,
)
task_outputs.append(task_output)
self._process_task_result(task, task_output)
@@ -893,8 +897,8 @@ class Crew(FlowTrackable, BaseModel):
def _handle_conditional_task(
self,
task: ConditionalTask,
task_outputs: List[TaskOutput],
futures: List[Tuple[Task, Future[TaskOutput], int]],
task_outputs: list[TaskOutput],
futures: list[tuple[Task, Future[TaskOutput], int]],
task_index: int,
was_replayed: bool,
) -> Optional[TaskOutput]:
@@ -917,8 +921,8 @@ class Crew(FlowTrackable, BaseModel):
return None
def _prepare_tools(
self, agent: BaseAgent, task: Task, tools: Union[List[Tool], List[BaseTool]]
) -> List[BaseTool]:
self, agent: BaseAgent, task: Task, tools: list[Tool] | list[BaseTool]
) -> list[BaseTool]:
# Add delegation tools if agent allows delegation
if hasattr(agent, "allow_delegation") and getattr(
agent, "allow_delegation", False
@@ -948,7 +952,7 @@ class Crew(FlowTrackable, BaseModel):
tools = self._add_multimodal_tools(agent, tools)
# Return a list[BaseTool] which is compatible with both Task.execute_sync and Task.execute_async
return cast(List[BaseTool], tools)
return cast(list[BaseTool], tools)
def _get_agent_to_use(self, task: Task) -> Optional[BaseAgent]:
if self.process == Process.hierarchical:
@@ -957,12 +961,12 @@ class Crew(FlowTrackable, BaseModel):
def _merge_tools(
self,
existing_tools: Union[List[Tool], List[BaseTool]],
new_tools: Union[List[Tool], List[BaseTool]],
) -> List[BaseTool]:
existing_tools: list[Tool] | list[BaseTool],
new_tools: list[Tool] | list[BaseTool],
) -> list[BaseTool]:
"""Merge new tools into existing tools list, avoiding duplicates by tool name."""
if not new_tools:
return cast(List[BaseTool], existing_tools)
return cast(list[BaseTool], existing_tools)
# Create mapping of tool names to new tools
new_tool_map = {tool.name: tool for tool in new_tools}
@@ -973,41 +977,41 @@ class Crew(FlowTrackable, BaseModel):
# Add all new tools
tools.extend(new_tools)
return cast(List[BaseTool], tools)
return tools
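The cast removed above works because _merge_tools already builds a plain list; its merge-by-name semantics, sketched with a simplified stand-in Tool type (the shadow-filter step is an assumption, since that line sits outside this hunk):

from dataclasses import dataclass

@dataclass
class Tool:  # simplified stand-in for crewai's tool types
    name: str

def merge_tools(existing: list[Tool], new: list[Tool]) -> list[Tool]:
    if not new:
        return existing
    new_names = {t.name for t in new}
    # Keep existing tools not shadowed by a same-named new tool, then append
    # all new tools.
    kept = [t for t in existing if t.name not in new_names]
    return kept + new

merged = merge_tools([Tool("search"), Tool("math")], [Tool("search")])
assert [t.name for t in merged] == ["math", "search"]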
def _inject_delegation_tools(
self,
tools: Union[List[Tool], List[BaseTool]],
tools: list[Tool] | list[BaseTool],
task_agent: BaseAgent,
agents: List[BaseAgent],
) -> List[BaseTool]:
agents: list[BaseAgent],
) -> list[BaseTool]:
if hasattr(task_agent, "get_delegation_tools"):
delegation_tools = task_agent.get_delegation_tools(agents)
# Cast delegation_tools to the expected type for _merge_tools
return self._merge_tools(tools, cast(List[BaseTool], delegation_tools))
return cast(List[BaseTool], tools)
return self._merge_tools(tools, delegation_tools)
return cast(list[BaseTool], tools)
def _add_multimodal_tools(
self, agent: BaseAgent, tools: Union[List[Tool], List[BaseTool]]
) -> List[BaseTool]:
self, agent: BaseAgent, tools: list[Tool] | list[BaseTool]
) -> list[BaseTool]:
if hasattr(agent, "get_multimodal_tools"):
multimodal_tools = agent.get_multimodal_tools()
# Cast multimodal_tools to the expected type for _merge_tools
return self._merge_tools(tools, cast(List[BaseTool], multimodal_tools))
return cast(List[BaseTool], tools)
return self._merge_tools(tools, cast(list[BaseTool], multimodal_tools))
return cast(list[BaseTool], tools)
def _add_code_execution_tools(
self, agent: BaseAgent, tools: Union[List[Tool], List[BaseTool]]
) -> List[BaseTool]:
self, agent: BaseAgent, tools: list[Tool] | list[BaseTool]
) -> list[BaseTool]:
if hasattr(agent, "get_code_execution_tools"):
code_tools = agent.get_code_execution_tools()
# Cast code_tools to the expected type for _merge_tools
return self._merge_tools(tools, cast(List[BaseTool], code_tools))
return cast(List[BaseTool], tools)
return self._merge_tools(tools, cast(list[BaseTool], code_tools))
return cast(list[BaseTool], tools)
def _add_delegation_tools(
self, task: Task, tools: Union[List[Tool], List[BaseTool]]
) -> List[BaseTool]:
self, task: Task, tools: list[Tool] | list[BaseTool]
) -> list[BaseTool]:
agents_for_delegation = [agent for agent in self.agents if agent != task.agent]
if len(self.agents) > 1 and len(agents_for_delegation) > 0 and task.agent:
if not tools:
@@ -1015,17 +1019,17 @@ class Crew(FlowTrackable, BaseModel):
tools = self._inject_delegation_tools(
tools, task.agent, agents_for_delegation
)
return cast(List[BaseTool], tools)
return cast(list[BaseTool], tools)
def _log_task_start(self, task: Task, role: str = "None"):
def _log_task_start(self, task: Task, role: str = "None") -> None:
if self.output_log_file:
self._file_handler.log(
task_name=task.name, task=task.description, agent=role, status="started"
)
def _update_manager_tools(
self, task: Task, tools: Union[List[Tool], List[BaseTool]]
) -> List[BaseTool]:
self, task: Task, tools: list[Tool] | list[BaseTool]
) -> list[BaseTool]:
if self.manager_agent:
if task.agent:
tools = self._inject_delegation_tools(tools, task.agent, [task.agent])
@@ -1033,9 +1037,9 @@ class Crew(FlowTrackable, BaseModel):
tools = self._inject_delegation_tools(
tools, self.manager_agent, self.agents
)
return cast(List[BaseTool], tools)
return cast(list[BaseTool], tools)
def _get_context(self, task: Task, task_outputs: List[TaskOutput]) -> str:
def _get_context(self, task: Task, task_outputs: list[TaskOutput]) -> str:
if not task.context:
return ""
@@ -1057,7 +1061,7 @@ class Crew(FlowTrackable, BaseModel):
output=output.raw,
)
def _create_crew_output(self, task_outputs: List[TaskOutput]) -> CrewOutput:
def _create_crew_output(self, task_outputs: list[TaskOutput]) -> CrewOutput:
if not task_outputs:
raise ValueError("No task outputs available to create crew output.")
@@ -1088,10 +1092,10 @@ class Crew(FlowTrackable, BaseModel):
def _process_async_tasks(
self,
futures: List[Tuple[Task, Future[TaskOutput], int]],
futures: list[tuple[Task, Future[TaskOutput], int]],
was_replayed: bool = False,
) -> List[TaskOutput]:
task_outputs: List[TaskOutput] = []
) -> list[TaskOutput]:
task_outputs: list[TaskOutput] = []
for future_task, future, task_index in futures:
task_output = future.result()
task_outputs.append(task_output)
@@ -1102,7 +1106,7 @@ class Crew(FlowTrackable, BaseModel):
return task_outputs
def _find_task_index(
self, task_id: str, stored_outputs: List[Any]
self, task_id: str, stored_outputs: list[Any]
) -> Optional[int]:
return next(
(
@@ -1114,7 +1118,7 @@ class Crew(FlowTrackable, BaseModel):
)
def replay(
self, task_id: str, inputs: Optional[Dict[str, Any]] = None
self, task_id: str, inputs: Optional[dict[str, Any]] = None
) -> CrewOutput:
stored_outputs = self._task_output_handler.load()
if not stored_outputs:
@@ -1155,15 +1159,15 @@ class Crew(FlowTrackable, BaseModel):
return result
def query_knowledge(
self, query: List[str], results_limit: int = 3, score_threshold: float = 0.35
) -> Union[List[Dict[str, Any]], None]:
self, query: list[str], results_limit: int = 3, score_threshold: float = 0.35
) -> list[dict[str, Any]] | None:
if self.knowledge:
return self.knowledge.query(
query, results_limit=results_limit, score_threshold=score_threshold
)
return None
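query_knowledge deliberately returns None (no knowledge store configured) rather than an empty list, so callers must narrow before iterating; a minimal stand-in illustrating the guard:

from typing import Any

def query_knowledge(query: list[str]) -> list[dict[str, Any]] | None:
    # Stand-in: a crew without a knowledge store answers None, not [].
    return None

hits = query_knowledge(["vacation policy"])
if hits is None:
    print("no knowledge configured")
else:
    for hit in hits:
        print(hit)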
def fetch_inputs(self) -> Set[str]:
def fetch_inputs(self) -> set[str]:
"""
Gathers placeholders (e.g., {something}) referenced in tasks or agents.
Scans each task's 'description' + 'expected_output', and each agent's
@@ -1172,7 +1176,7 @@ class Crew(FlowTrackable, BaseModel):
Returns a set of all discovered placeholder names.
"""
placeholder_pattern = re.compile(r"\{(.+?)\}")
required_inputs: Set[str] = set()
required_inputs: set[str] = set()
# Scan tasks for inputs
for task in self.tasks:
@@ -1188,7 +1192,18 @@ class Crew(FlowTrackable, BaseModel):
return required_inputs
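The placeholder scan relies on the non-greedy {(.+?)} pattern plus a set for de-duplication; a self-contained sketch of the same extraction:

import re

placeholder_pattern = re.compile(r"\{(.+?)\}")

def find_placeholders(*texts: str) -> set[str]:
    found: set[str] = set()
    for text in texts:
        # findall returns the captured group, i.e. the bare placeholder name.
        found.update(placeholder_pattern.findall(text))
    return found

assert find_placeholders(
    "Research {topic}", "Report on {topic} for {audience}"
) == {"topic", "audience"}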
def copy(self):
def copy(
self,
*,
include: Optional[
Set[int] | Set[str] | Mapping[int, Any] | Mapping[str, Any]
] = None,
exclude: Optional[
Set[int] | Set[str] | Mapping[int, Any] | Mapping[str, Any]
] = None,
update: Optional[dict[str, Any]] = None,
deep: bool = True,
) -> "Crew":
"""
Creates a deep copy of the Crew instance.
@@ -1219,7 +1234,7 @@ class Crew(FlowTrackable, BaseModel):
manager_agent = self.manager_agent.copy() if self.manager_agent else None
manager_llm = shallow_copy(self.manager_llm) if self.manager_llm else None
task_mapping = {}
task_mapping: dict[str, Task] = {}
cloned_tasks = []
existing_knowledge_sources = shallow_copy(self.knowledge_sources)
@@ -1274,16 +1289,10 @@ class Crew(FlowTrackable, BaseModel):
if not task.callback:
task.callback = self.task_callback
def _interpolate_inputs(self, inputs: Dict[str, Any]) -> None:
def _interpolate_inputs(self, inputs: dict[str, Any]) -> None:
"""Interpolates the inputs in the tasks and agents."""
[
task.interpolate_inputs_and_add_conversation_history(
# type: ignore # "interpolate_inputs" of "Task" does not return a value (it only ever returns None)
inputs
)
for task in self.tasks
]
# type: ignore # "interpolate_inputs" of "Agent" does not return a value (it only ever returns None)
for task in self.tasks:
task.interpolate_inputs_and_add_conversation_history(inputs)
for agent in self.agents:
agent.interpolate_inputs(inputs)
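The rewrite above replaces a list comprehension executed purely for its side effects with a plain loop; a minimal illustration of why the loop form is preferred:

items = ["a", "b", "c"]

# Anti-pattern: builds and discards [None, None, None] just to call print.
[print(item) for item in items]

# Preferred: the loop form used in the diff states the side effect directly.
for item in items:
    print(item)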
@@ -1307,8 +1316,8 @@ class Crew(FlowTrackable, BaseModel):
def test(
self,
n_iterations: int,
eval_llm: Union[str, InstanceOf[BaseLLM]],
inputs: Optional[Dict[str, Any]] = None,
eval_llm: str | InstanceOf[BaseLLM],
inputs: Optional[dict[str, Any]] = None,
) -> None:
"""Test and evaluate the Crew with the given inputs for n iterations concurrently using concurrent.futures."""
try:
@@ -1349,7 +1358,7 @@ class Crew(FlowTrackable, BaseModel):
)
raise
def __repr__(self):
def __repr__(self) -> str:
return f"Crew(id={self.id}, process={self.process}, number_of_agents={len(self.agents)}, number_of_tasks={len(self.tasks)})"
def reset_memories(self, command_type: str) -> None:
@@ -1401,7 +1410,9 @@ class Crew(FlowTrackable, BaseModel):
if (system := config.get("system")) is not None:
name = config.get("name")
try:
reset_fn: Callable = cast(Callable, config.get("reset"))
reset_fn: Callable[..., None] = cast(
Callable[..., None], config.get("reset")
)
reset_fn(system)
self._logger.log(
"info",
@@ -1430,7 +1441,9 @@ class Crew(FlowTrackable, BaseModel):
raise RuntimeError(f"{name} memory system is not initialized")
try:
reset_fn: Callable = cast(Callable, config.get("reset"))
reset_fn: Callable[..., None] = cast(
Callable[..., None], config.get("reset")
)
reset_fn(system)
self._logger.log(
"info",
@@ -1441,18 +1454,18 @@ class Crew(FlowTrackable, BaseModel):
f"[Crew ({self.name if self.name else self.id})] Failed to reset {name} memory: {str(e)}"
) from e
def _get_memory_systems(self):
def _get_memory_systems(self) -> dict[str, dict[str, Any]]:
"""Get all available memory systems with their configuration.
Returns:
Dict containing all memory systems with their reset functions and display names.
"""
def default_reset(memory):
return memory.reset()
def default_reset(memory: Any) -> None:
memory.reset()
def knowledge_reset(memory):
return self.reset_knowledge(memory)
def knowledge_reset(memory: Any) -> None:
self.reset_knowledge(memory)
# Get knowledge for agents
agent_knowledges = [
@@ -1506,12 +1519,12 @@ class Crew(FlowTrackable, BaseModel):
},
}
def reset_knowledge(self, knowledges: List[Knowledge]) -> None:
def reset_knowledge(self, knowledges: list[Knowledge]) -> None:
"""Reset crew and agent knowledge storage."""
for ks in knowledges:
ks.reset()
def _set_allow_crewai_trigger_context_for_first_task(self):
def _set_allow_crewai_trigger_context_for_first_task(self) -> None:
crewai_trigger_payload = self._inputs and self._inputs.get(
"crewai_trigger_payload"
)

View File

@@ -6,10 +6,10 @@ from crewai.events.event_bus import CrewAIEventsBus, crewai_event_bus
class BaseEventListener(ABC):
verbose: bool = False
def __init__(self):
def __init__(self) -> None:
super().__init__()
self.setup_listeners(crewai_event_bus)
@abstractmethod
def setup_listeners(self, crewai_event_bus: CrewAIEventsBus):
def setup_listeners(self, crewai_event_bus: CrewAIEventsBus) -> None:
pass

View File

@@ -1,15 +1,17 @@
from __future__ import annotations
import threading
from collections.abc import Callable, Iterator
from contextlib import contextmanager
from typing import Any, Callable, Dict, List, Type, TypeVar, cast
from typing import Any, ParamSpec, TypeVar, cast
from blinker import Signal
from typing_extensions import Self
from crewai.events.base_events import BaseEvent
from crewai.events.event_types import EventTypes
EventT = TypeVar("EventT", bound=BaseEvent)
P = ParamSpec("P")
class CrewAIEventsBus:
@@ -21,21 +23,21 @@ class CrewAIEventsBus:
_instance = None
_lock = threading.Lock()
def __new__(cls):
def __new__(cls) -> Self:
if cls._instance is None:
with cls._lock:
if cls._instance is None: # prevent race condition
cls._instance = super(CrewAIEventsBus, cls).__new__(cls)
cls._instance = super().__new__(cls)
cls._instance._initialize()
return cls._instance
def _initialize(self) -> None:
"""Initialize the event bus internal state"""
self._signal = Signal("crewai_event_bus")
self._handlers: Dict[Type[BaseEvent], List[Callable]] = {}
self._handlers: dict[type[BaseEvent], list[Callable[[Any, Any], None]]] = {}
def on(
self, event_type: Type[EventT]
self, event_type: type[EventT]
) -> Callable[[Callable[[Any, EventT], None]], Callable[[Any, EventT], None]]:
"""
Decorator to register an event handler for a specific event type.
@@ -54,9 +56,7 @@ class CrewAIEventsBus:
) -> Callable[[Any, EventT], None]:
if event_type not in self._handlers:
self._handlers[event_type] = []
self._handlers[event_type].append(
cast(Callable[[Any, EventT], None], handler)
)
self._handlers[event_type].append(cast(Callable[[Any, Any], None], handler))
return handler
return decorator
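Typical use of the on() decorator, sketched with a hypothetical MyEvent subclass (assuming BaseEvent is a Pydantic model with a type field, as the event modules in this PR suggest):

from typing import Any

from crewai.events.base_events import BaseEvent
from crewai.events.event_bus import crewai_event_bus

class MyEvent(BaseEvent):  # hypothetical event type, not part of this PR
    type: str = "my_event"

@crewai_event_bus.on(MyEvent)
def handle(source: Any, event: MyEvent) -> None:
    # Handlers receive the emitting source plus the typed event instance.
    print(f"received {event.type} from {type(source).__name__}")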
@@ -82,17 +82,15 @@ class CrewAIEventsBus:
self._signal.send(source, event=event)
def register_handler(
self, event_type: Type[EventTypes], handler: Callable[[Any, EventTypes], None]
self, event_type: type[BaseEvent], handler: Callable[[Any, Any], None]
) -> None:
"""Register an event handler for a specific event type"""
if event_type not in self._handlers:
self._handlers[event_type] = []
self._handlers[event_type].append(
cast(Callable[[Any, EventTypes], None], handler)
)
self._handlers[event_type].append(handler)
@contextmanager
def scoped_handlers(self):
def scoped_handlers(self) -> Iterator[None]:
"""
Context manager for temporary event handling scope.
Useful for testing or temporary event handling.
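A hedged usage sketch of scoped_handlers in a test, assuming the semantics described above (handlers registered inside the block are dropped on exit):

from typing import Any

from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.crew_events import CrewKickoffStartedEvent

def test_scoped_handler() -> None:
    seen: list[str] = []

    with crewai_event_bus.scoped_handlers():
        @crewai_event_bus.on(CrewKickoffStartedEvent)
        def record(source: Any, event: CrewKickoffStartedEvent) -> None:
            seen.append(event.crew_name or "Crew")

        # ... kick off a crew here; record is only registered inside this block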

View File

@@ -1,15 +1,32 @@
from __future__ import annotations
from io import StringIO
from typing import Any, Dict
from typing import Any
from pydantic import Field, PrivateAttr
from crewai.llm import LLM
from crewai.task import Task
from crewai.telemetry.telemetry import Telemetry
from crewai.utilities import Logger
from crewai.utilities.constants import EMITTER_COLOR
from typing_extensions import Self
from crewai.events.base_event_listener import BaseEventListener
from crewai.events.event_bus import CrewAIEventsBus
from crewai.events.types.agent_events import (
AgentExecutionCompletedEvent,
AgentExecutionStartedEvent,
LiteAgentExecutionCompletedEvent,
LiteAgentExecutionErrorEvent,
LiteAgentExecutionStartedEvent,
)
from crewai.events.types.crew_events import (
CrewKickoffCompletedEvent,
CrewKickoffFailedEvent,
CrewKickoffStartedEvent,
CrewTestCompletedEvent,
CrewTestFailedEvent,
CrewTestResultEvent,
CrewTestStartedEvent,
CrewTrainCompletedEvent,
CrewTrainFailedEvent,
CrewTrainStartedEvent,
)
from crewai.events.types.knowledge_events import (
KnowledgeQueryCompletedEvent,
KnowledgeQueryFailedEvent,
@@ -25,34 +42,21 @@ from crewai.events.types.llm_events import (
LLMStreamChunkEvent,
)
from crewai.events.types.llm_guardrail_events import (
LLMGuardrailStartedEvent,
LLMGuardrailCompletedEvent,
)
from crewai.events.utils.console_formatter import ConsoleFormatter
from crewai.events.types.agent_events import (
AgentExecutionCompletedEvent,
AgentExecutionStartedEvent,
LiteAgentExecutionCompletedEvent,
LiteAgentExecutionErrorEvent,
LiteAgentExecutionStartedEvent,
LLMGuardrailStartedEvent,
)
from crewai.events.types.logging_events import (
AgentLogsStartedEvent,
AgentLogsExecutionEvent,
AgentLogsStartedEvent,
)
from crewai.events.types.crew_events import (
CrewKickoffCompletedEvent,
CrewKickoffFailedEvent,
CrewKickoffStartedEvent,
CrewTestCompletedEvent,
CrewTestFailedEvent,
CrewTestResultEvent,
CrewTestStartedEvent,
CrewTrainCompletedEvent,
CrewTrainFailedEvent,
CrewTrainStartedEvent,
)
from crewai.events.utils.console_formatter import ConsoleFormatter
from crewai.llm import LLM
from crewai.task import Task
from crewai.telemetry.telemetry import Telemetry
from crewai.utilities import Logger
from crewai.utilities.constants import EMITTER_COLOR
from .listeners.memory_listener import MemoryListener
from .types.flow_events import (
FlowCreatedEvent,
FlowFinishedEvent,
@@ -61,38 +65,37 @@ from .types.flow_events import (
MethodExecutionFinishedEvent,
MethodExecutionStartedEvent,
)
from .types.reasoning_events import (
AgentReasoningCompletedEvent,
AgentReasoningFailedEvent,
AgentReasoningStartedEvent,
)
from .types.task_events import TaskCompletedEvent, TaskFailedEvent, TaskStartedEvent
from .types.tool_usage_events import (
ToolUsageErrorEvent,
ToolUsageFinishedEvent,
ToolUsageStartedEvent,
)
from .types.reasoning_events import (
AgentReasoningStartedEvent,
AgentReasoningCompletedEvent,
AgentReasoningFailedEvent,
)
from .listeners.memory_listener import MemoryListener
class EventListener(BaseEventListener):
_instance = None
_initialized: bool = False
_telemetry: Telemetry = PrivateAttr(default_factory=lambda: Telemetry())
logger = Logger(verbose=True, default_color=EMITTER_COLOR)
execution_spans: Dict[Task, Any] = Field(default_factory=dict)
execution_spans: dict[Task, Any] = Field(default_factory=dict)
next_chunk = 0
text_stream = StringIO()
knowledge_retrieval_in_progress = False
knowledge_query_in_progress = False
def __new__(cls):
def __new__(cls) -> Self:
if cls._instance is None:
cls._instance = super().__new__(cls)
cls._instance._initialized = False
return cls._instance
def __init__(self):
def __init__(self) -> None:
if not hasattr(self, "_initialized") or not self._initialized:
super().__init__()
self._telemetry = Telemetry()
@@ -105,14 +108,14 @@ class EventListener(BaseEventListener):
# ----------- CREW EVENTS -----------
def setup_listeners(self, crewai_event_bus):
def setup_listeners(self, crewai_event_bus: CrewAIEventsBus) -> None:
@crewai_event_bus.on(CrewKickoffStartedEvent)
def on_crew_started(source, event: CrewKickoffStartedEvent):
def on_crew_started(source: Any, event: CrewKickoffStartedEvent) -> None:
self.formatter.create_crew_tree(event.crew_name or "Crew", source.id)
self._telemetry.crew_execution_span(source, event.inputs)
@crewai_event_bus.on(CrewKickoffCompletedEvent)
def on_crew_completed(source, event: CrewKickoffCompletedEvent):
def on_crew_completed(source: Any, event: CrewKickoffCompletedEvent) -> None:
# Handle telemetry
final_string_output = event.output.raw
self._telemetry.end_crew(source, final_string_output)
@@ -126,7 +129,7 @@ class EventListener(BaseEventListener):
)
@crewai_event_bus.on(CrewKickoffFailedEvent)
def on_crew_failed(source, event: CrewKickoffFailedEvent):
def on_crew_failed(source: Any, event: CrewKickoffFailedEvent) -> None:
self.formatter.update_crew_tree(
self.formatter.current_crew_tree,
event.crew_name or "Crew",
@@ -135,23 +138,25 @@ class EventListener(BaseEventListener):
)
@crewai_event_bus.on(CrewTrainStartedEvent)
def on_crew_train_started(source, event: CrewTrainStartedEvent):
def on_crew_train_started(source: Any, event: CrewTrainStartedEvent) -> None:
self.formatter.handle_crew_train_started(
event.crew_name or "Crew", str(event.timestamp)
)
@crewai_event_bus.on(CrewTrainCompletedEvent)
def on_crew_train_completed(source, event: CrewTrainCompletedEvent):
def on_crew_train_completed(
source: Any, event: CrewTrainCompletedEvent
) -> None:
self.formatter.handle_crew_train_completed(
event.crew_name or "Crew", str(event.timestamp)
)
@crewai_event_bus.on(CrewTrainFailedEvent)
def on_crew_train_failed(source, event: CrewTrainFailedEvent):
def on_crew_train_failed(source: Any, event: CrewTrainFailedEvent) -> None:
self.formatter.handle_crew_train_failed(event.crew_name or "Crew")
@crewai_event_bus.on(CrewTestResultEvent)
def on_crew_test_result(source, event: CrewTestResultEvent):
def on_crew_test_result(source: Any, event: CrewTestResultEvent) -> None:
self._telemetry.individual_test_result_span(
source.crew,
event.quality,
@@ -162,7 +167,7 @@ class EventListener(BaseEventListener):
# ----------- TASK EVENTS -----------
@crewai_event_bus.on(TaskStartedEvent)
def on_task_started(source, event: TaskStartedEvent):
def on_task_started(source: Any, event: TaskStartedEvent) -> None:
span = self._telemetry.task_started(crew=source.agent.crew, task=source)
self.execution_spans[source] = span
# Pass both task ID and task name (if set)
@@ -172,7 +177,7 @@ class EventListener(BaseEventListener):
)
@crewai_event_bus.on(TaskCompletedEvent)
def on_task_completed(source, event: TaskCompletedEvent):
def on_task_completed(source: Any, event: TaskCompletedEvent) -> None:
# Handle telemetry
span = self.execution_spans.get(source)
if span:
@@ -190,7 +195,7 @@ class EventListener(BaseEventListener):
)
@crewai_event_bus.on(TaskFailedEvent)
def on_task_failed(source, event: TaskFailedEvent):
def on_task_failed(source: Any, event: TaskFailedEvent) -> None:
span = self.execution_spans.get(source)
if span:
if source.agent and source.agent.crew:
@@ -210,7 +215,9 @@ class EventListener(BaseEventListener):
# ----------- AGENT EVENTS -----------
@crewai_event_bus.on(AgentExecutionStartedEvent)
def on_agent_execution_started(source, event: AgentExecutionStartedEvent):
def on_agent_execution_started(
source: Any, event: AgentExecutionStartedEvent
) -> None:
self.formatter.create_agent_branch(
self.formatter.current_task_branch,
event.agent.role,
@@ -218,7 +225,9 @@ class EventListener(BaseEventListener):
)
@crewai_event_bus.on(AgentExecutionCompletedEvent)
def on_agent_execution_completed(source, event: AgentExecutionCompletedEvent):
def on_agent_execution_completed(
source: Any, event: AgentExecutionCompletedEvent
) -> None:
self.formatter.update_agent_status(
self.formatter.current_agent_branch,
event.agent.role,
@@ -229,8 +238,8 @@ class EventListener(BaseEventListener):
@crewai_event_bus.on(LiteAgentExecutionStartedEvent)
def on_lite_agent_execution_started(
source, event: LiteAgentExecutionStartedEvent
):
source: Any, event: LiteAgentExecutionStartedEvent
) -> None:
"""Handle LiteAgent execution started event."""
self.formatter.handle_lite_agent_execution(
event.agent_info["role"], status="started", **event.agent_info
@@ -238,15 +247,17 @@ class EventListener(BaseEventListener):
@crewai_event_bus.on(LiteAgentExecutionCompletedEvent)
def on_lite_agent_execution_completed(
source, event: LiteAgentExecutionCompletedEvent
):
source: Any, event: LiteAgentExecutionCompletedEvent
) -> None:
"""Handle LiteAgent execution completed event."""
self.formatter.handle_lite_agent_execution(
event.agent_info["role"], status="completed", **event.agent_info
)
@crewai_event_bus.on(LiteAgentExecutionErrorEvent)
def on_lite_agent_execution_error(source, event: LiteAgentExecutionErrorEvent):
def on_lite_agent_execution_error(
source: Any, event: LiteAgentExecutionErrorEvent
) -> None:
"""Handle LiteAgent execution error event."""
self.formatter.handle_lite_agent_execution(
event.agent_info["role"],
@@ -258,25 +269,27 @@ class EventListener(BaseEventListener):
# ----------- FLOW EVENTS -----------
@crewai_event_bus.on(FlowCreatedEvent)
def on_flow_created(source, event: FlowCreatedEvent):
def on_flow_created(source: Any, event: FlowCreatedEvent) -> None:
self._telemetry.flow_creation_span(event.flow_name)
self.formatter.create_flow_tree(event.flow_name, str(source.flow_id))
@crewai_event_bus.on(FlowStartedEvent)
def on_flow_started(source, event: FlowStartedEvent):
def on_flow_started(source: Any, event: FlowStartedEvent) -> None:
self._telemetry.flow_execution_span(
event.flow_name, list(source._methods.keys())
)
self.formatter.start_flow(event.flow_name, str(source.flow_id))
@crewai_event_bus.on(FlowFinishedEvent)
def on_flow_finished(source, event: FlowFinishedEvent):
def on_flow_finished(source: Any, event: FlowFinishedEvent) -> None:
self.formatter.update_flow_status(
self.formatter.current_flow_tree, event.flow_name, source.flow_id
)
@crewai_event_bus.on(MethodExecutionStartedEvent)
def on_method_execution_started(source, event: MethodExecutionStartedEvent):
def on_method_execution_started(
source: Any, event: MethodExecutionStartedEvent
) -> None:
self.formatter.update_method_status(
self.formatter.current_method_branch,
self.formatter.current_flow_tree,
@@ -285,7 +298,9 @@ class EventListener(BaseEventListener):
)
@crewai_event_bus.on(MethodExecutionFinishedEvent)
def on_method_execution_finished(source, event: MethodExecutionFinishedEvent):
def on_method_execution_finished(
source: Any, event: MethodExecutionFinishedEvent
) -> None:
self.formatter.update_method_status(
self.formatter.current_method_branch,
self.formatter.current_flow_tree,
@@ -294,7 +309,9 @@ class EventListener(BaseEventListener):
)
@crewai_event_bus.on(MethodExecutionFailedEvent)
def on_method_execution_failed(source, event: MethodExecutionFailedEvent):
def on_method_execution_failed(
source: Any, event: MethodExecutionFailedEvent
) -> None:
self.formatter.update_method_status(
self.formatter.current_method_branch,
self.formatter.current_flow_tree,
@@ -305,7 +322,7 @@ class EventListener(BaseEventListener):
# ----------- TOOL USAGE EVENTS -----------
@crewai_event_bus.on(ToolUsageStartedEvent)
def on_tool_usage_started(source, event: ToolUsageStartedEvent):
def on_tool_usage_started(source: Any, event: ToolUsageStartedEvent) -> None:
if isinstance(source, LLM):
self.formatter.handle_llm_tool_usage_started(
event.tool_name,
@@ -319,7 +336,7 @@ class EventListener(BaseEventListener):
)
@crewai_event_bus.on(ToolUsageFinishedEvent)
def on_tool_usage_finished(source, event: ToolUsageFinishedEvent):
def on_tool_usage_finished(source: Any, event: ToolUsageFinishedEvent) -> None:
if isinstance(source, LLM):
self.formatter.handle_llm_tool_usage_finished(
event.tool_name,
@@ -332,7 +349,7 @@ class EventListener(BaseEventListener):
)
@crewai_event_bus.on(ToolUsageErrorEvent)
def on_tool_usage_error(source, event: ToolUsageErrorEvent):
def on_tool_usage_error(source: Any, event: ToolUsageErrorEvent) -> None:
if isinstance(source, LLM):
self.formatter.handle_llm_tool_usage_error(
event.tool_name,
@@ -349,7 +366,7 @@ class EventListener(BaseEventListener):
# ----------- LLM EVENTS -----------
@crewai_event_bus.on(LLMCallStartedEvent)
def on_llm_call_started(source, event: LLMCallStartedEvent):
def on_llm_call_started(source: Any, event: LLMCallStartedEvent) -> None:
# Capture the returned tool branch and update the current_tool_branch reference
thinking_branch = self.formatter.handle_llm_call_started(
self.formatter.current_agent_branch,
@@ -360,7 +377,7 @@ class EventListener(BaseEventListener):
self.formatter.current_tool_branch = thinking_branch
@crewai_event_bus.on(LLMCallCompletedEvent)
def on_llm_call_completed(source, event: LLMCallCompletedEvent):
def on_llm_call_completed(source: Any, event: LLMCallCompletedEvent) -> None:
self.formatter.handle_llm_call_completed(
self.formatter.current_tool_branch,
self.formatter.current_agent_branch,
@@ -368,7 +385,7 @@ class EventListener(BaseEventListener):
)
@crewai_event_bus.on(LLMCallFailedEvent)
def on_llm_call_failed(source, event: LLMCallFailedEvent):
def on_llm_call_failed(source: Any, event: LLMCallFailedEvent) -> None:
self.formatter.handle_llm_call_failed(
self.formatter.current_tool_branch,
event.error,
@@ -376,7 +393,7 @@ class EventListener(BaseEventListener):
)
@crewai_event_bus.on(LLMStreamChunkEvent)
def on_llm_stream_chunk(source, event: LLMStreamChunkEvent):
def on_llm_stream_chunk(source: Any, event: LLMStreamChunkEvent) -> None:
self.text_stream.write(event.chunk)
self.text_stream.seek(self.next_chunk)
@@ -389,7 +406,9 @@ class EventListener(BaseEventListener):
# ----------- LLM GUARDRAIL EVENTS -----------
@crewai_event_bus.on(LLMGuardrailStartedEvent)
def on_llm_guardrail_started(source, event: LLMGuardrailStartedEvent):
def on_llm_guardrail_started(
source: Any, event: LLMGuardrailStartedEvent
) -> None:
guardrail_str = str(event.guardrail)
guardrail_name = (
guardrail_str[:50] + "..." if len(guardrail_str) > 50 else guardrail_str
@@ -398,13 +417,15 @@ class EventListener(BaseEventListener):
self.formatter.handle_guardrail_started(guardrail_name, event.retry_count)
@crewai_event_bus.on(LLMGuardrailCompletedEvent)
def on_llm_guardrail_completed(source, event: LLMGuardrailCompletedEvent):
def on_llm_guardrail_completed(
source: Any, event: LLMGuardrailCompletedEvent
) -> None:
self.formatter.handle_guardrail_completed(
event.success, event.error, event.retry_count
)
@crewai_event_bus.on(CrewTestStartedEvent)
def on_crew_test_started(source, event: CrewTestStartedEvent):
def on_crew_test_started(source: Any, event: CrewTestStartedEvent) -> None:
cloned_crew = source.copy()
self._telemetry.test_execution_span(
cloned_crew,
@@ -418,20 +439,20 @@ class EventListener(BaseEventListener):
)
@crewai_event_bus.on(CrewTestCompletedEvent)
def on_crew_test_completed(source, event: CrewTestCompletedEvent):
def on_crew_test_completed(source: Any, event: CrewTestCompletedEvent) -> None:
self.formatter.handle_crew_test_completed(
self.formatter.current_flow_tree,
event.crew_name or "Crew",
)
@crewai_event_bus.on(CrewTestFailedEvent)
def on_crew_test_failed(source, event: CrewTestFailedEvent):
def on_crew_test_failed(source: Any, event: CrewTestFailedEvent) -> None:
self.formatter.handle_crew_test_failed(event.crew_name or "Crew")
@crewai_event_bus.on(KnowledgeRetrievalStartedEvent)
def on_knowledge_retrieval_started(
source, event: KnowledgeRetrievalStartedEvent
):
source: Any, event: KnowledgeRetrievalStartedEvent
) -> None:
if self.knowledge_retrieval_in_progress:
return
@@ -444,8 +465,8 @@ class EventListener(BaseEventListener):
@crewai_event_bus.on(KnowledgeRetrievalCompletedEvent)
def on_knowledge_retrieval_completed(
source, event: KnowledgeRetrievalCompletedEvent
):
source: Any, event: KnowledgeRetrievalCompletedEvent
) -> None:
if not self.knowledge_retrieval_in_progress:
return
@@ -457,11 +478,15 @@ class EventListener(BaseEventListener):
)
@crewai_event_bus.on(KnowledgeQueryStartedEvent)
def on_knowledge_query_started(source, event: KnowledgeQueryStartedEvent):
def on_knowledge_query_started(
source: Any, event: KnowledgeQueryStartedEvent
) -> None:
pass
@crewai_event_bus.on(KnowledgeQueryFailedEvent)
def on_knowledge_query_failed(source, event: KnowledgeQueryFailedEvent):
def on_knowledge_query_failed(
source: Any, event: KnowledgeQueryFailedEvent
) -> None:
self.formatter.handle_knowledge_query_failed(
self.formatter.current_agent_branch,
event.error,
@@ -469,13 +494,15 @@ class EventListener(BaseEventListener):
)
@crewai_event_bus.on(KnowledgeQueryCompletedEvent)
def on_knowledge_query_completed(source, event: KnowledgeQueryCompletedEvent):
def on_knowledge_query_completed(
source: Any, event: KnowledgeQueryCompletedEvent
) -> None:
pass
@crewai_event_bus.on(KnowledgeSearchQueryFailedEvent)
def on_knowledge_search_query_failed(
source, event: KnowledgeSearchQueryFailedEvent
):
source: Any, event: KnowledgeSearchQueryFailedEvent
) -> None:
self.formatter.handle_knowledge_search_query_failed(
self.formatter.current_agent_branch,
event.error,
@@ -485,7 +512,9 @@ class EventListener(BaseEventListener):
# ----------- REASONING EVENTS -----------
@crewai_event_bus.on(AgentReasoningStartedEvent)
def on_agent_reasoning_started(source, event: AgentReasoningStartedEvent):
def on_agent_reasoning_started(
source: Any, event: AgentReasoningStartedEvent
) -> None:
self.formatter.handle_reasoning_started(
self.formatter.current_agent_branch,
event.attempt,
@@ -493,7 +522,9 @@ class EventListener(BaseEventListener):
)
@crewai_event_bus.on(AgentReasoningCompletedEvent)
def on_agent_reasoning_completed(source, event: AgentReasoningCompletedEvent):
def on_agent_reasoning_completed(
source: Any, event: AgentReasoningCompletedEvent
) -> None:
self.formatter.handle_reasoning_completed(
event.plan,
event.ready,
@@ -501,7 +532,9 @@ class EventListener(BaseEventListener):
)
@crewai_event_bus.on(AgentReasoningFailedEvent)
def on_agent_reasoning_failed(source, event: AgentReasoningFailedEvent):
def on_agent_reasoning_failed(
source: Any, event: AgentReasoningFailedEvent
) -> None:
self.formatter.handle_reasoning_failed(
event.error,
self.formatter.current_crew_tree,
@@ -510,7 +543,7 @@ class EventListener(BaseEventListener):
# ----------- AGENT LOGGING EVENTS -----------
@crewai_event_bus.on(AgentLogsStartedEvent)
def on_agent_logs_started(source, event: AgentLogsStartedEvent):
def on_agent_logs_started(source: Any, event: AgentLogsStartedEvent) -> None:
self.formatter.handle_agent_logs_started(
event.agent_role,
event.task_description,
@@ -518,7 +551,9 @@ class EventListener(BaseEventListener):
)
@crewai_event_bus.on(AgentLogsExecutionEvent)
def on_agent_logs_execution(source, event: AgentLogsExecutionEvent):
def on_agent_logs_execution(
source: Any, event: AgentLogsExecutionEvent
) -> None:
self.formatter.handle_agent_logs_execution(
event.agent_role,
event.formatted_answer,

View File

@@ -1,25 +1,30 @@
from typing import Any
from crewai.events.base_event_listener import BaseEventListener
from crewai.events.event_bus import CrewAIEventsBus
from crewai.events.types.memory_events import (
MemoryQueryCompletedEvent,
MemoryQueryFailedEvent,
MemoryRetrievalCompletedEvent,
MemoryRetrievalStartedEvent,
MemoryQueryFailedEvent,
MemoryQueryCompletedEvent,
MemorySaveStartedEvent,
MemorySaveCompletedEvent,
MemorySaveFailedEvent,
MemorySaveStartedEvent,
)
class MemoryListener(BaseEventListener):
def __init__(self, formatter):
def __init__(self, formatter: Any) -> None:
super().__init__()
self.formatter = formatter
self.memory_retrieval_in_progress = False
self.memory_save_in_progress = False
def setup_listeners(self, crewai_event_bus):
def setup_listeners(self, crewai_event_bus: CrewAIEventsBus) -> None:
@crewai_event_bus.on(MemoryRetrievalStartedEvent)
def on_memory_retrieval_started(source, event: MemoryRetrievalStartedEvent):
def on_memory_retrieval_started(
source: Any, event: MemoryRetrievalStartedEvent
) -> None:
if self.memory_retrieval_in_progress:
return
@@ -31,7 +36,9 @@ class MemoryListener(BaseEventListener):
)
@crewai_event_bus.on(MemoryRetrievalCompletedEvent)
def on_memory_retrieval_completed(source, event: MemoryRetrievalCompletedEvent):
def on_memory_retrieval_completed(
source: Any, event: MemoryRetrievalCompletedEvent
) -> None:
if not self.memory_retrieval_in_progress:
return
@@ -44,7 +51,9 @@ class MemoryListener(BaseEventListener):
)
@crewai_event_bus.on(MemoryQueryCompletedEvent)
def on_memory_query_completed(source, event: MemoryQueryCompletedEvent):
def on_memory_query_completed(
source: Any, event: MemoryQueryCompletedEvent
) -> None:
if not self.memory_retrieval_in_progress:
return
@@ -56,7 +65,7 @@ class MemoryListener(BaseEventListener):
)
@crewai_event_bus.on(MemoryQueryFailedEvent)
def on_memory_query_failed(source, event: MemoryQueryFailedEvent):
def on_memory_query_failed(source: Any, event: MemoryQueryFailedEvent) -> None:
if not self.memory_retrieval_in_progress:
return
@@ -68,7 +77,7 @@ class MemoryListener(BaseEventListener):
)
@crewai_event_bus.on(MemorySaveStartedEvent)
def on_memory_save_started(source, event: MemorySaveStartedEvent):
def on_memory_save_started(source: Any, event: MemorySaveStartedEvent) -> None:
if self.memory_save_in_progress:
return
@@ -80,7 +89,9 @@ class MemoryListener(BaseEventListener):
)
@crewai_event_bus.on(MemorySaveCompletedEvent)
def on_memory_save_completed(source, event: MemorySaveCompletedEvent):
def on_memory_save_completed(
source: Any, event: MemorySaveCompletedEvent
) -> None:
if not self.memory_save_in_progress:
return
@@ -94,7 +105,7 @@ class MemoryListener(BaseEventListener):
)
@crewai_event_bus.on(MemorySaveFailedEvent)
def on_memory_save_failed(source, event: MemorySaveFailedEvent):
def on_memory_save_failed(source: Any, event: MemorySaveFailedEvent) -> None:
if not self.memory_save_in_progress:
return

View File

@@ -1,28 +1,59 @@
import os
import uuid
from typing import Dict, Any, Optional
from typing import Any, Optional
from typing_extensions import Self
from crewai.cli.authentication.token import AuthError, get_auth_token
from crewai.cli.version import get_crewai_version
from crewai.events.base_event_listener import BaseEventListener
from crewai.events.event_bus import CrewAIEventsBus
from crewai.events.listeners.tracing.trace_batch_manager import TraceBatchManager
from crewai.events.listeners.tracing.types import TraceEvent
from crewai.events.types.agent_events import (
AgentExecutionCompletedEvent,
AgentExecutionErrorEvent,
AgentExecutionStartedEvent,
LiteAgentExecutionStartedEvent,
LiteAgentExecutionCompletedEvent,
LiteAgentExecutionErrorEvent,
AgentExecutionErrorEvent,
)
from crewai.events.listeners.tracing.types import TraceEvent
from crewai.events.types.reasoning_events import (
AgentReasoningStartedEvent,
AgentReasoningCompletedEvent,
AgentReasoningFailedEvent,
LiteAgentExecutionStartedEvent,
)
from crewai.events.types.crew_events import (
CrewKickoffCompletedEvent,
CrewKickoffFailedEvent,
CrewKickoffStartedEvent,
)
from crewai.events.types.flow_events import (
FlowCreatedEvent,
FlowFinishedEvent,
FlowPlotEvent,
FlowStartedEvent,
MethodExecutionFailedEvent,
MethodExecutionFinishedEvent,
MethodExecutionStartedEvent,
)
from crewai.events.types.llm_events import (
LLMCallCompletedEvent,
LLMCallFailedEvent,
LLMCallStartedEvent,
)
from crewai.events.types.llm_guardrail_events import (
LLMGuardrailCompletedEvent,
LLMGuardrailStartedEvent,
)
from crewai.events.types.memory_events import (
MemoryQueryCompletedEvent,
MemoryQueryFailedEvent,
MemoryQueryStartedEvent,
MemorySaveCompletedEvent,
MemorySaveFailedEvent,
MemorySaveStartedEvent,
)
from crewai.events.types.reasoning_events import (
AgentReasoningCompletedEvent,
AgentReasoningFailedEvent,
AgentReasoningStartedEvent,
)
from crewai.events.types.task_events import (
TaskCompletedEvent,
TaskFailedEvent,
@@ -33,43 +64,9 @@ from crewai.events.types.tool_usage_events import (
ToolUsageFinishedEvent,
ToolUsageStartedEvent,
)
from crewai.events.types.llm_events import (
LLMCallCompletedEvent,
LLMCallFailedEvent,
LLMCallStartedEvent,
)
from crewai.events.types.flow_events import (
FlowCreatedEvent,
FlowStartedEvent,
FlowFinishedEvent,
MethodExecutionStartedEvent,
MethodExecutionFinishedEvent,
MethodExecutionFailedEvent,
FlowPlotEvent,
)
from crewai.events.types.llm_guardrail_events import (
LLMGuardrailStartedEvent,
LLMGuardrailCompletedEvent,
)
from crewai.utilities.serialization import to_serializable
from .trace_batch_manager import TraceBatchManager
from crewai.events.types.memory_events import (
MemoryQueryStartedEvent,
MemoryQueryCompletedEvent,
MemoryQueryFailedEvent,
MemorySaveStartedEvent,
MemorySaveCompletedEvent,
MemorySaveFailedEvent,
)
from crewai.cli.authentication.token import AuthError, get_auth_token
from crewai.cli.version import get_crewai_version
class TraceCollectionListener(BaseEventListener):
"""
Trace collection listener that orchestrates trace collection
@@ -88,7 +85,7 @@ class TraceCollectionListener(BaseEventListener):
_initialized = False
_listeners_setup = False
def __new__(cls, batch_manager=None):
def __new__(cls, batch_manager: Optional[Any] = None) -> Self:
if cls._instance is None:
cls._instance = super().__new__(cls)
return cls._instance
@@ -101,10 +98,11 @@ class TraceCollectionListener(BaseEventListener):
return
super().__init__()
self.batch_manager = batch_manager or TraceBatchManager()
self.batch_manager = batch_manager or TraceBatchManager() # type: ignore
self._initialized = True
def _check_authenticated(self) -> bool:
@staticmethod
def _check_authenticated() -> bool:
"""Check if tracing should be enabled"""
try:
res = bool(get_auth_token())
@@ -112,7 +110,8 @@ class TraceCollectionListener(BaseEventListener):
except AuthError:
return False
def _get_user_context(self) -> Dict[str, str]:
@staticmethod
def _get_user_context() -> dict[str, str]:
"""Extract user context for tracing"""
return {
"user_id": os.getenv("CREWAI_USER_ID", "anonymous"),
@@ -121,7 +120,7 @@ class TraceCollectionListener(BaseEventListener):
"trace_id": str(uuid.uuid4()),
}
def setup_listeners(self, crewai_event_bus):
def setup_listeners(self, crewai_event_bus: CrewAIEventsBus) -> None:
"""Setup event listeners - delegates to specific handlers"""
if self._listeners_setup:
@@ -133,169 +132,169 @@ class TraceCollectionListener(BaseEventListener):
self._listeners_setup = True
def _register_flow_event_handlers(self, event_bus):
def _register_flow_event_handlers(self, event_bus: CrewAIEventsBus) -> None:
"""Register handlers for flow events"""
@event_bus.on(FlowCreatedEvent)
def on_flow_created(source, event):
def on_flow_created(source: Any, event: Any) -> None:
pass
@event_bus.on(FlowStartedEvent)
def on_flow_started(source, event):
def on_flow_started(source: Any, event: Any) -> None:
if not self.batch_manager.is_batch_initialized():
self._initialize_flow_batch(source, event)
self._handle_trace_event("flow_started", source, event)
@event_bus.on(MethodExecutionStartedEvent)
def on_method_started(source, event):
def on_method_started(source: Any, event: Any) -> None:
self._handle_trace_event("method_execution_started", source, event)
@event_bus.on(MethodExecutionFinishedEvent)
def on_method_finished(source, event):
def on_method_finished(source: Any, event: Any) -> None:
self._handle_trace_event("method_execution_finished", source, event)
@event_bus.on(MethodExecutionFailedEvent)
def on_method_failed(source, event):
def on_method_failed(source: Any, event: Any) -> None:
self._handle_trace_event("method_execution_failed", source, event)
@event_bus.on(FlowFinishedEvent)
def on_flow_finished(source, event):
def on_flow_finished(source: Any, event: Any) -> None:
self._handle_trace_event("flow_finished", source, event)
if self.batch_manager.batch_owner_type == "flow":
self.batch_manager.finalize_batch()
@event_bus.on(FlowPlotEvent)
def on_flow_plot(source, event):
def on_flow_plot(source: Any, event: Any) -> None:
self._handle_action_event("flow_plot", source, event)
def _register_context_event_handlers(self, event_bus):
def _register_context_event_handlers(self, event_bus: CrewAIEventsBus) -> None:
"""Register handlers for context events (start/end)"""
@event_bus.on(CrewKickoffStartedEvent)
def on_crew_started(source, event):
def on_crew_started(source: Any, event: Any) -> None:
if not self.batch_manager.is_batch_initialized():
self._initialize_crew_batch(source, event)
self._handle_trace_event("crew_kickoff_started", source, event)
@event_bus.on(CrewKickoffCompletedEvent)
def on_crew_completed(source, event):
def on_crew_completed(source: Any, event: Any) -> None:
self._handle_trace_event("crew_kickoff_completed", source, event)
if self.batch_manager.batch_owner_type == "crew":
self.batch_manager.finalize_batch()
@event_bus.on(CrewKickoffFailedEvent)
def on_crew_failed(source, event):
def on_crew_failed(source: Any, event: Any) -> None:
self._handle_trace_event("crew_kickoff_failed", source, event)
self.batch_manager.finalize_batch()
@event_bus.on(TaskStartedEvent)
def on_task_started(source, event):
def on_task_started(source: Any, event: Any) -> None:
self._handle_trace_event("task_started", source, event)
@event_bus.on(TaskCompletedEvent)
def on_task_completed(source, event):
def on_task_completed(source: Any, event: Any) -> None:
self._handle_trace_event("task_completed", source, event)
@event_bus.on(TaskFailedEvent)
def on_task_failed(source, event):
def on_task_failed(source: Any, event: Any) -> None:
self._handle_trace_event("task_failed", source, event)
@event_bus.on(AgentExecutionStartedEvent)
def on_agent_started(source, event):
def on_agent_started(source: Any, event: Any) -> None:
self._handle_trace_event("agent_execution_started", source, event)
@event_bus.on(AgentExecutionCompletedEvent)
def on_agent_completed(source, event):
def on_agent_completed(source: Any, event: Any) -> None:
self._handle_trace_event("agent_execution_completed", source, event)
@event_bus.on(LiteAgentExecutionStartedEvent)
def on_lite_agent_started(source, event):
def on_lite_agent_started(source: Any, event: Any) -> None:
self._handle_trace_event("lite_agent_execution_started", source, event)
@event_bus.on(LiteAgentExecutionCompletedEvent)
def on_lite_agent_completed(source, event):
def on_lite_agent_completed(source: Any, event: Any) -> None:
self._handle_trace_event("lite_agent_execution_completed", source, event)
@event_bus.on(LiteAgentExecutionErrorEvent)
def on_lite_agent_error(source, event):
def on_lite_agent_error(source: Any, event: Any) -> None:
self._handle_trace_event("lite_agent_execution_error", source, event)
@event_bus.on(AgentExecutionErrorEvent)
def on_agent_error(source, event):
def on_agent_error(source: Any, event: Any) -> None:
self._handle_trace_event("agent_execution_error", source, event)
@event_bus.on(LLMGuardrailStartedEvent)
def on_guardrail_started(source, event):
def on_guardrail_started(source: Any, event: Any) -> None:
self._handle_trace_event("llm_guardrail_started", source, event)
@event_bus.on(LLMGuardrailCompletedEvent)
def on_guardrail_completed(source, event):
def on_guardrail_completed(source: Any, event: Any) -> None:
self._handle_trace_event("llm_guardrail_completed", source, event)
def _register_action_event_handlers(self, event_bus):
def _register_action_event_handlers(self, event_bus: CrewAIEventsBus) -> None:
"""Register handlers for action events (LLM calls, tool usage)"""
@event_bus.on(LLMCallStartedEvent)
def on_llm_call_started(source, event):
def on_llm_call_started(source: Any, event: Any) -> None:
self._handle_action_event("llm_call_started", source, event)
@event_bus.on(LLMCallCompletedEvent)
def on_llm_call_completed(source, event):
def on_llm_call_completed(source: Any, event: Any) -> None:
self._handle_action_event("llm_call_completed", source, event)
@event_bus.on(LLMCallFailedEvent)
def on_llm_call_failed(source, event):
def on_llm_call_failed(source: Any, event: Any) -> None:
self._handle_action_event("llm_call_failed", source, event)
@event_bus.on(ToolUsageStartedEvent)
def on_tool_started(source, event):
def on_tool_started(source: Any, event: Any) -> None:
self._handle_action_event("tool_usage_started", source, event)
@event_bus.on(ToolUsageFinishedEvent)
def on_tool_finished(source, event):
def on_tool_finished(source: Any, event: Any) -> None:
self._handle_action_event("tool_usage_finished", source, event)
@event_bus.on(ToolUsageErrorEvent)
def on_tool_error(source, event):
def on_tool_error(source: Any, event: Any) -> None:
self._handle_action_event("tool_usage_error", source, event)
@event_bus.on(MemoryQueryStartedEvent)
def on_memory_query_started(source, event):
def on_memory_query_started(source: Any, event: Any) -> None:
self._handle_action_event("memory_query_started", source, event)
@event_bus.on(MemoryQueryCompletedEvent)
def on_memory_query_completed(source, event):
def on_memory_query_completed(source: Any, event: Any) -> None:
self._handle_action_event("memory_query_completed", source, event)
@event_bus.on(MemoryQueryFailedEvent)
def on_memory_query_failed(source, event):
def on_memory_query_failed(source: Any, event: Any) -> None:
self._handle_action_event("memory_query_failed", source, event)
@event_bus.on(MemorySaveStartedEvent)
def on_memory_save_started(source, event):
def on_memory_save_started(source: Any, event: Any) -> None:
self._handle_action_event("memory_save_started", source, event)
@event_bus.on(MemorySaveCompletedEvent)
def on_memory_save_completed(source, event):
def on_memory_save_completed(source: Any, event: Any) -> None:
self._handle_action_event("memory_save_completed", source, event)
@event_bus.on(MemorySaveFailedEvent)
def on_memory_save_failed(source, event):
def on_memory_save_failed(source: Any, event: Any) -> None:
self._handle_action_event("memory_save_failed", source, event)
@event_bus.on(AgentReasoningStartedEvent)
def on_agent_reasoning_started(source, event):
def on_agent_reasoning_started(source: Any, event: Any) -> None:
self._handle_action_event("agent_reasoning_started", source, event)
@event_bus.on(AgentReasoningCompletedEvent)
def on_agent_reasoning_completed(source, event):
def on_agent_reasoning_completed(source: Any, event: Any) -> None:
self._handle_action_event("agent_reasoning_completed", source, event)
@event_bus.on(AgentReasoningFailedEvent)
def on_agent_reasoning_failed(source, event):
def on_agent_reasoning_failed(source: Any, event: Any) -> None:
self._handle_action_event("agent_reasoning_failed", source, event)
def _initialize_crew_batch(self, source: Any, event: Any):
def _initialize_crew_batch(self, source: Any, event: Any) -> None:
"""Initialize trace batch"""
user_context = self._get_user_context()
execution_metadata = {
@@ -309,7 +308,7 @@ class TraceCollectionListener(BaseEventListener):
self._initialize_batch(user_context, execution_metadata)
def _initialize_flow_batch(self, source: Any, event: Any):
def _initialize_flow_batch(self, source: Any, event: Any) -> None:
"""Initialize trace batch for Flow execution"""
user_context = self._get_user_context()
execution_metadata = {
@@ -325,26 +324,24 @@ class TraceCollectionListener(BaseEventListener):
self._initialize_batch(user_context, execution_metadata)
def _initialize_batch(
self, user_context: Dict[str, str], execution_metadata: Dict[str, Any]
):
self, user_context: dict[str, str], execution_metadata: dict[str, Any]
) -> None:
"""Initialize trace batch if ephemeral"""
if not self._check_authenticated():
self.batch_manager.initialize_batch(
user_context, execution_metadata, use_ephemeral=True
)
else:
self.batch_manager.initialize_batch(
user_context, execution_metadata, use_ephemeral=False
)
self.batch_manager.initialize_batch(user_context, execution_metadata)
def _handle_trace_event(self, event_type: str, source: Any, event: Any):
def _handle_trace_event(self, event_type: str, source: Any, event: Any) -> None:
"""Generic handler for context end events"""
trace_event = self._create_trace_event(event_type, source, event)
self.batch_manager.add_event(trace_event)
def _handle_action_event(self, event_type: str, source: Any, event: Any):
def _handle_action_event(self, event_type: str, source: Any, event: Any) -> None:
"""Generic handler for action events (LLM calls, tool usage)"""
if not self.batch_manager.is_batch_initialized():
@@ -371,7 +368,7 @@ class TraceCollectionListener(BaseEventListener):
def _build_event_data(
self, event_type: str, event: Any, source: Any
) -> Dict[str, Any]:
) -> dict[str, Any]:
"""Build event data"""
if event_type not in self.complex_events:
return self._safe_serialize_to_dict(event)
@@ -426,11 +423,19 @@ class TraceCollectionListener(BaseEventListener):
"source": source,
}
# TODO: move to utils
@staticmethod
def _safe_serialize_to_dict(
self, obj, exclude: set[str] | None = None
) -> Dict[str, Any]:
"""Safely serialize an object to a dictionary for event data."""
obj: Any, exclude: set[str] | None = None
) -> dict[str, Any]:
"""Safely serialize an object to a dictionary for event data.
Args:
obj: The object to serialize.
exclude: Optional set of attribute names to exclude from serialization.
Notes:
- TODO: refactor to utilities function.
"""
try:
serialized = to_serializable(obj, exclude)
if isinstance(serialized, dict):
@@ -440,9 +445,20 @@ class TraceCollectionListener(BaseEventListener):
except Exception as e:
return {"serialization_error": str(e), "object_type": type(obj).__name__}
# TODO: move to utils
def _truncate_messages(self, messages, max_content_length=500, max_messages=5):
"""Truncate message content and limit number of messages"""
@staticmethod
def _truncate_messages(
messages: Any, max_content_length: int = 500, max_messages: int = 5
) -> Any:
"""Truncate message content and limit number of messages
Args:
messages: List of message dicts with 'content' keys.
max_content_length: Max length of each message content.
max_messages: Max number of messages to retain.
Notes:
- TODO: refactor to utilities function.
"""
if not messages or not isinstance(messages, list):
return messages
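The listener wiring in this file is a thin decorator-registration layer: each @event_bus.on handler closes over self and forwards to a generic _handle_trace_event or _handle_action_event. A minimal self-contained sketch of that pattern, with a toy bus rather than the crewai CrewAIEventsBus:

from collections import defaultdict
from collections.abc import Callable
from typing import Any

class ToyEventBus:
    def __init__(self) -> None:
        self._handlers: dict[type, list[Callable[[Any, Any], None]]] = defaultdict(list)

    def on(self, event_type: type) -> Callable[[Callable[[Any, Any], None]], Callable[[Any, Any], None]]:
        def decorator(fn: Callable[[Any, Any], None]) -> Callable[[Any, Any], None]:
            # Register the handler for this event type and return it unchanged.
            self._handlers[event_type].append(fn)
            return fn
        return decorator

    def emit(self, source: Any, event: Any) -> None:
        for handler in self._handlers[type(event)]:
            handler(source, event)

class TaskStarted:
    pass

bus = ToyEventBus()

@bus.on(TaskStarted)
def on_task_started(source: Any, event: Any) -> None:
    print(f"task_started from {source!r}")

bus.emit("crew-1", TaskStarted())  # prints: task_started from 'crew-1'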

View File

@@ -1,7 +1,7 @@
from typing import Any, Optional
from crewai.tasks.task_output import TaskOutput
from crewai.events.base_events import BaseEvent
from crewai.tasks.task_output import TaskOutput
class TaskStartedEvent(BaseEvent):
@@ -11,14 +11,15 @@ class TaskStartedEvent(BaseEvent):
context: Optional[str]
task: Optional[Any] = None
def __init__(self, **data):
def __init__(self, **data: Any) -> None:
super().__init__(**data)
# Set fingerprint data from the task
if hasattr(self.task, "fingerprint") and self.task.fingerprint:
if self.task and hasattr(self.task, "fingerprint") and self.task.fingerprint:
self.source_fingerprint = self.task.fingerprint.uuid_str
self.source_type = "task"
if (
hasattr(self.task.fingerprint, "metadata")
self.task
and hasattr(self.task.fingerprint, "metadata")
and self.task.fingerprint.metadata
):
self.fingerprint_metadata = self.task.fingerprint.metadata
@@ -31,14 +32,15 @@ class TaskCompletedEvent(BaseEvent):
type: str = "task_completed"
task: Optional[Any] = None
def __init__(self, **data):
def __init__(self, **data: Any) -> None:
super().__init__(**data)
# Set fingerprint data from the task
if hasattr(self.task, "fingerprint") and self.task.fingerprint:
if self.task and hasattr(self.task, "fingerprint") and self.task.fingerprint:
self.source_fingerprint = self.task.fingerprint.uuid_str
self.source_type = "task"
if (
hasattr(self.task.fingerprint, "metadata")
self.task
and hasattr(self.task.fingerprint, "metadata")
and self.task.fingerprint.metadata
):
self.fingerprint_metadata = self.task.fingerprint.metadata
@@ -51,14 +53,15 @@ class TaskFailedEvent(BaseEvent):
type: str = "task_failed"
task: Optional[Any] = None
def __init__(self, **data):
def __init__(self, **data: Any) -> None:
super().__init__(**data)
# Set fingerprint data from the task
if hasattr(self.task, "fingerprint") and self.task.fingerprint:
if self.task and hasattr(self.task, "fingerprint") and self.task.fingerprint:
self.source_fingerprint = self.task.fingerprint.uuid_str
self.source_type = "task"
if (
hasattr(self.task.fingerprint, "metadata")
self.task
and hasattr(self.task.fingerprint, "metadata")
and self.task.fingerprint.metadata
):
self.fingerprint_metadata = self.task.fingerprint.metadata
@@ -71,14 +74,15 @@ class TaskEvaluationEvent(BaseEvent):
evaluation_type: str
task: Optional[Any] = None
def __init__(self, **data):
def __init__(self, **data: Any) -> None:
super().__init__(**data)
# Set fingerprint data from the task
if hasattr(self.task, "fingerprint") and self.task.fingerprint:
if self.task and hasattr(self.task, "fingerprint") and self.task.fingerprint:
self.source_fingerprint = self.task.fingerprint.uuid_str
self.source_type = "task"
if (
hasattr(self.task.fingerprint, "metadata")
self.task
and hasattr(self.task.fingerprint, "metadata")
and self.task.fingerprint.metadata
):
self.fingerprint_metadata = self.task.fingerprint.metadata
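The recurring change in this file is the explicit None guard before walking task.fingerprint. A sketch of the guarded attribute walk with toy stand-ins (Fingerprint and Task here are illustrative, not the crewai classes):

class Fingerprint:
    uuid_str = "abc-123"
    metadata = {"env": "test"}

class Task:
    fingerprint = Fingerprint()

def fingerprint_of(task) -> str | None:
    # Guard on the task itself first, instead of relying on
    # hasattr(None, "fingerprint") happening to be False.
    if task and hasattr(task, "fingerprint") and task.fingerprint:
        return task.fingerprint.uuid_str
    return None

assert fingerprint_of(Task()) == "abc-123"
assert fingerprint_of(None) is None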

View File

@@ -1,14 +1,15 @@
import abc
import enum
from enum import Enum
from typing import Any, Dict, List, Optional
from typing import Any, Optional
from pydantic import BaseModel, Field
from crewai.agent import Agent
from crewai.llms.base_llm import BaseLLM
from crewai.task import Task
from crewai.llm import BaseLLM
from crewai.utilities.llm_utils import create_llm
from crewai.utilities.llm_utils import create_default_llm, create_llm
class MetricCategory(enum.Enum):
GOAL_ALIGNMENT = "goal_alignment"
@@ -18,8 +19,8 @@ class MetricCategory(enum.Enum):
PARAMETER_EXTRACTION = "parameter_extraction"
TOOL_INVOCATION = "tool_invocation"
def title(self):
return self.value.replace('_', ' ').title()
def title(self) -> str:
return self.value.replace("_", " ").title()
class EvaluationScore(BaseModel):
@@ -27,15 +28,13 @@ class EvaluationScore(BaseModel):
default=5.0,
description="Numeric score from 0-10 where 0 is worst and 10 is best, None if not applicable",
ge=0.0,
le=10.0
le=10.0,
)
feedback: str = Field(
default="",
description="Detailed feedback explaining the evaluation score"
default="", description="Detailed feedback explaining the evaluation score"
)
raw_response: str | None = Field(
default=None,
description="Raw response from the evaluator (e.g., LLM)"
default=None, description="Raw response from the evaluator (e.g., LLM)"
)
def __str__(self) -> str:
@@ -46,7 +45,9 @@ class EvaluationScore(BaseModel):
class BaseEvaluator(abc.ABC):
def __init__(self, llm: BaseLLM | None = None):
self.llm: BaseLLM | None = create_llm(llm)
self.llm: BaseLLM | None = (
create_llm(llm) if llm is not None else create_default_llm()
)
@property
@abc.abstractmethod
@@ -57,7 +58,7 @@ class BaseEvaluator(abc.ABC):
def evaluate(
self,
agent: Agent,
execution_trace: Dict[str, Any],
execution_trace: dict[str, Any],
final_output: Any,
task: Task | None = None,
) -> EvaluationScore:
@@ -67,9 +68,8 @@ class BaseEvaluator(abc.ABC):
class AgentEvaluationResult(BaseModel):
agent_id: str = Field(description="ID of the evaluated agent")
task_id: str = Field(description="ID of the task that was executed")
metrics: Dict[MetricCategory, EvaluationScore] = Field(
default_factory=dict,
description="Evaluation scores for each metric category"
metrics: dict[MetricCategory, EvaluationScore] = Field(
default_factory=dict, description="Evaluation scores for each metric category"
)
@@ -81,33 +81,23 @@ class AggregationStrategy(Enum):
class AgentAggregatedEvaluationResult(BaseModel):
agent_id: str = Field(
default="",
description="ID of the agent"
)
agent_role: str = Field(
default="",
description="Role of the agent"
)
agent_id: str = Field(default="", description="ID of the agent")
agent_role: str = Field(default="", description="Role of the agent")
task_count: int = Field(
default=0,
description="Number of tasks included in this aggregation"
default=0, description="Number of tasks included in this aggregation"
)
aggregation_strategy: AggregationStrategy = Field(
default=AggregationStrategy.SIMPLE_AVERAGE,
description="Strategy used for aggregation"
description="Strategy used for aggregation",
)
metrics: Dict[MetricCategory, EvaluationScore] = Field(
default_factory=dict,
description="Aggregated metrics across all tasks"
metrics: dict[MetricCategory, EvaluationScore] = Field(
default_factory=dict, description="Aggregated metrics across all tasks"
)
task_results: List[str] = Field(
default_factory=list,
description="IDs of tasks included in this aggregation"
task_results: list[str] = Field(
default_factory=list, description="IDs of tasks included in this aggregation"
)
overall_score: Optional[float] = Field(
default=None,
description="Overall score for this agent"
default=None, description="Overall score for this agent"
)
def __str__(self) -> str:
@@ -119,7 +109,7 @@ class AgentAggregatedEvaluationResult(BaseModel):
result += f"\n\n- {category.value.upper()}: {score.score}/10\n"
if score.feedback:
detailed_feedback = "\n ".join(score.feedback.split('\n'))
detailed_feedback = "\n ".join(score.feedback.split("\n"))
result += f" {detailed_feedback}\n"
return result
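The constructor change above swaps an unconditional create_llm(llm) for a None-aware fallback. A runnable sketch of the same wiring with stand-in factories (the real create_llm and create_default_llm live in crewai.utilities.llm_utils):

from typing import Optional

class BaseLLM:
    def __init__(self, model: str) -> None:
        self.model = model

def create_default_llm() -> BaseLLM:
    # Stand-in: the real factory picks a configured default model.
    return BaseLLM("default-model")

def create_llm(llm: BaseLLM) -> BaseLLM:
    # Stand-in: the real helper also accepts strings and config dicts.
    return llm

class Evaluator:
    def __init__(self, llm: Optional[BaseLLM] = None) -> None:
        self.llm: BaseLLM = create_llm(llm) if llm is not None else create_default_llm()

assert Evaluator().llm.model == "default-model"
assert Evaluator(BaseLLM("gpt-4o")).llm.model == "gpt-4o"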

View File

@@ -1,5 +1,5 @@
import os
from typing import Any, Dict, List, Optional
from typing import Any, Optional
from pydantic import BaseModel, ConfigDict, Field
@@ -18,20 +18,20 @@ class Knowledge(BaseModel):
embedder: Optional[Dict[str, Any]] = None
"""
sources: List[BaseKnowledgeSource] = Field(default_factory=list)
sources: list[BaseKnowledgeSource] = Field(default_factory=list)
model_config = ConfigDict(arbitrary_types_allowed=True)
storage: Optional[KnowledgeStorage] = Field(default=None)
embedder: Optional[Dict[str, Any]] = None
embedder: Optional[dict[str, Any]] = None
collection_name: Optional[str] = None
def __init__(
self,
collection_name: str,
sources: List[BaseKnowledgeSource],
embedder: Optional[Dict[str, Any]] = None,
sources: list[BaseKnowledgeSource],
embedder: Optional[dict[str, Any]] = None,
storage: Optional[KnowledgeStorage] = None,
**data,
):
**data: Any,
) -> None:
super().__init__(**data)
if storage:
self.storage = storage
@@ -43,8 +43,8 @@ class Knowledge(BaseModel):
self.storage.initialize_knowledge_storage()
def query(
self, query: List[str], results_limit: int = 3, score_threshold: float = 0.35
) -> List[Dict[str, Any]]:
self, query: list[str], results_limit: int = 3, score_threshold: float = 0.35
) -> list[dict[str, Any]]:
"""
Query across all knowledge sources to find the most relevant information.
Returns the top_k most relevant chunks.
@@ -62,7 +62,7 @@ class Knowledge(BaseModel):
)
return results
def add_sources(self):
def add_sources(self) -> None:
try:
for source in self.sources:
source.storage = self.storage
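Most edits in this file migrate typing.Dict and typing.List to the PEP 585 builtins, which are subscriptable at runtime since Python 3.9, so simple container annotations need no typing import at all. A hypothetical stand-in showing only the annotation style:

def query(queries: list[str], results_limit: int = 3) -> list[dict[str, object]]:
    # Illustrative only; not the real Knowledge.query implementation.
    return [{"query": q, "limit": results_limit} for q in queries]

print(query(["what is crewai?"]))  # [{'query': 'what is crewai?', 'limit': 3}]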

View File

@@ -2,23 +2,56 @@ import hashlib
import logging
import os
import shutil
from typing import Any, Dict, List, Optional, Union
import warnings
from collections.abc import Mapping
from typing import Any, Optional, Union
import chromadb
import chromadb.errors
from chromadb import EmbeddingFunction
from chromadb.api import ClientAPI
from chromadb.api.types import OneOrMany
from chromadb.config import Settings
import warnings
from crewai.knowledge.storage.base_knowledge_storage import BaseKnowledgeStorage
from crewai.rag.embeddings.configurator import EmbeddingConfigurator
from crewai.utilities.chromadb import sanitize_collection_name
from crewai.utilities.chromadb import create_persistent_client, sanitize_collection_name
from crewai.utilities.constants import KNOWLEDGE_DIRECTORY
from crewai.utilities.logger import Logger
from crewai.utilities.paths import db_storage_path
from crewai.utilities.chromadb import create_persistent_client
from crewai.utilities.logger_utils import suppress_logging
from crewai.utilities.paths import db_storage_path
def _extract_chromadb_response_item(
response_data: Any,
index: int,
expected_type: type[Any] | tuple[type[Any], ...],
) -> Any | None:
"""Extract an item from ChromaDB response data at the given index.
Args:
response_data: The response data from ChromaDB query (e.g., documents, metadatas).
index: The index of the item to extract.
expected_type: The expected type(s) of the item.
Returns:
The extracted item if it exists and matches the expected type, otherwise None.
"""
if response_data is None or not response_data:
return None
# ChromaDB sometimes returns nested lists; handle both cases
data_list = (
response_data[0]
if response_data and isinstance(response_data[0], list)
else response_data
)
if index < len(data_list):
item = data_list[index]
if isinstance(item, expected_type):
return item
return None
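Given the helper as defined above, this is its behavior on the two response shapes it normalizes (illustrative data, not real ChromaDB output):

nested = [["doc-a", "doc-b"]]
flat = ["doc-a", "doc-b"]
assert _extract_chromadb_response_item(nested, 1, str) == "doc-b"
assert _extract_chromadb_response_item(flat, 0, str) == "doc-a"
assert _extract_chromadb_response_item(nested, 5, str) is None   # index out of range
assert _extract_chromadb_response_item(None, 0, str) is None     # missing field
assert _extract_chromadb_response_item(nested, 0, dict) is None  # wrong type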
class KnowledgeStorage(BaseKnowledgeStorage):
@@ -30,10 +63,11 @@ class KnowledgeStorage(BaseKnowledgeStorage):
collection: Optional[chromadb.Collection] = None
collection_name: Optional[str] = "knowledge"
app: Optional[ClientAPI] = None
embedder: Optional[EmbeddingFunction[Any]] = None
def __init__(
self,
embedder: Optional[Dict[str, Any]] = None,
embedder: Optional[dict[str, Any]] = None,
collection_name: Optional[str] = None,
):
self.collection_name = collection_name
@@ -41,11 +75,11 @@ class KnowledgeStorage(BaseKnowledgeStorage):
def search(
self,
query: List[str],
query: list[str],
limit: int = 3,
filter: Optional[dict] = None,
filter: Optional[dict[str, Any]] = None,
score_threshold: float = 0.35,
) -> List[Dict[str, Any]]:
) -> list[dict[str, Any]]:
with suppress_logging(
"chromadb.segment.impl.vector.local_persistent_hnsw", logging.ERROR
):
@@ -56,20 +90,51 @@ class KnowledgeStorage(BaseKnowledgeStorage):
where=filter,
)
results = []
for i in range(len(fetched["ids"][0])): # type: ignore
result = {
"id": fetched["ids"][0][i], # type: ignore
"metadata": fetched["metadatas"][0][i], # type: ignore
"context": fetched["documents"][0][i], # type: ignore
"score": fetched["distances"][0][i], # type: ignore
}
if result["score"] >= score_threshold:
results.append(result)
if (
fetched
and "ids" in fetched
and fetched["ids"]
and len(fetched["ids"]) > 0
):
ids_list = (
fetched["ids"][0]
if isinstance(fetched["ids"][0], list)
else fetched["ids"]
)
for i in range(len(ids_list)):
# Handle metadatas
meta_item = _extract_chromadb_response_item(
fetched.get("metadatas"), i, dict
)
metadata: dict[str, Any] = meta_item if meta_item else {}
# Handle documents
doc_item = _extract_chromadb_response_item(
fetched.get("documents"), i, str
)
context = doc_item if doc_item else ""
# Handle distances
dist_item = _extract_chromadb_response_item(
fetched.get("distances"), i, (int, float)
)
score = dist_item if dist_item is not None else 1.0
result = {
"id": ids_list[i],
"metadata": metadata,
"context": context,
"score": score,
}
# Check score threshold - distances are smaller when more similar
if isinstance(score, (int, float)) and score <= score_threshold:
results.append(result)
return results
else:
raise Exception("Collection not initialized")
def initialize_knowledge_storage(self):
def initialize_knowledge_storage(self) -> None:
# Suppress deprecation warnings from chromadb, which are not relevant to us
# TODO: Remove this once we upgrade chromadb to at least 1.0.8.
warnings.filterwarnings(
@@ -99,7 +164,7 @@ class KnowledgeStorage(BaseKnowledgeStorage):
except Exception:
raise Exception("Failed to create or get collection")
def reset(self):
def reset(self) -> None:
base_path = os.path.join(db_storage_path(), KNOWLEDGE_DIRECTORY)
if not self.app:
self.app = create_persistent_client(
@@ -113,9 +178,9 @@ class KnowledgeStorage(BaseKnowledgeStorage):
def save(
self,
documents: List[str],
metadata: Optional[Union[Dict[str, Any], List[Dict[str, Any]]]] = None,
):
documents: list[str],
metadata: Optional[dict[str, Any] | list[dict[str, Any]]] = None,
) -> None:
if not self.collection:
raise Exception("Collection not initialized")
@@ -147,7 +212,7 @@ class KnowledgeStorage(BaseKnowledgeStorage):
# If we have no metadata at all, set it to None
final_metadata: Optional[OneOrMany[chromadb.Metadata]] = (
None if all(m is None for m in filtered_metadata) else filtered_metadata
None if all(m is None for m in filtered_metadata) else filtered_metadata # type: ignore[assignment]
)
self.collection.upsert(
@@ -170,7 +235,7 @@ class KnowledgeStorage(BaseKnowledgeStorage):
Logger(verbose=True).log("error", f"Failed to upsert documents: {e}", "red")
raise
def _create_default_embedding_function(self):
def _create_default_embedding_function(self) -> Any:
from chromadb.utils.embedding_functions.openai_embedding_function import (
OpenAIEmbeddingFunction,
)
@@ -179,15 +244,18 @@ class KnowledgeStorage(BaseKnowledgeStorage):
api_key=os.getenv("OPENAI_API_KEY"), model_name="text-embedding-3-small"
)
def _set_embedder_config(self, embedder: Optional[Dict[str, Any]] = None) -> None:
def _set_embedder_config(self, embedder: Optional[dict[str, Any]] = None) -> None:
"""Set the embedding configuration for the knowledge storage.
Args:
embedder_config (Optional[Dict[str, Any]]): Configuration dictionary for the embedder.
embedder (Optional[Dict[str, Any]]): Configuration dictionary for the embedder.
If None or empty, defaults to the default embedding function.
Notes:
- TODO: Improve typing for embedder configuration, remove type: ignore
"""
self.embedder = (
EmbeddingConfigurator().configure_embedder(embedder)
EmbeddingConfigurator().configure_embedder(embedder) # type: ignore
if embedder
else self._create_default_embedding_function()
)

View File

@@ -1,50 +1,49 @@
import asyncio
import inspect
import uuid
from collections.abc import Callable
from typing import (
Any,
Callable,
Dict,
List,
Optional,
Tuple,
Type,
Union,
cast,
get_args,
get_origin,
)
try:
from typing import Self
except ImportError:
from typing_extensions import Self
from pydantic import (
UUID4,
BaseModel,
Field,
InstanceOf,
PrivateAttr,
model_validator,
field_validator,
model_validator,
)
from typing_extensions import Self
from crewai.agents.agent_builder.base_agent import BaseAgent
from crewai.agents.agent_builder.utilities.base_token_process import TokenProcess
from crewai.agents.cache import CacheHandler
from crewai.agents.cache.cache_handler import CacheHandler
from crewai.agents.parser import (
AgentAction,
AgentFinish,
OutputParserException,
)
from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.agent_events import (
LiteAgentExecutionCompletedEvent,
LiteAgentExecutionErrorEvent,
LiteAgentExecutionStartedEvent,
)
from crewai.events.types.logging_events import AgentLogsExecutionEvent
from crewai.flow.flow_trackable import FlowTrackable
from crewai.llm import LLM, BaseLLM
from crewai.llm import LLM
from crewai.llms.base_llm import BaseLLM
from crewai.tasks import TaskOutput
from crewai.tools.base_tool import BaseTool
from crewai.tools.structured_tool import CrewStructuredTool
from crewai.utilities import I18N
from crewai.utilities.guardrail import process_guardrail
from crewai.utilities.agent_utils import (
enforce_rpm_limit,
format_message_for_llm,
@@ -62,15 +61,8 @@ from crewai.utilities.agent_utils import (
render_text_description_and_args,
)
from crewai.utilities.converter import generate_model_description
from crewai.events.types.logging_events import AgentLogsExecutionEvent
from crewai.events.types.agent_events import (
LiteAgentExecutionCompletedEvent,
LiteAgentExecutionErrorEvent,
LiteAgentExecutionStartedEvent,
)
from crewai.events.event_bus import crewai_event_bus
from crewai.utilities.llm_utils import create_llm
from crewai.utilities.guardrail import process_guardrail
from crewai.utilities.llm_utils import create_default_llm, create_llm
from crewai.utilities.printer import Printer
from crewai.utilities.token_counter_callback import TokenCalcHandler
from crewai.utilities.tool_utils import execute_tool_and_check_finality
@@ -86,11 +78,11 @@ class LiteAgentOutput(BaseModel):
description="Pydantic output of the agent", default=None
)
agent_role: str = Field(description="Role of the agent that produced this output")
usage_metrics: Optional[Dict[str, Any]] = Field(
usage_metrics: Optional[dict[str, Any]] = Field(
description="Token usage metrics for this execution", default=None
)
def to_dict(self) -> Dict[str, Any]:
def to_dict(self) -> dict[str, Any]:
"""Convert pydantic_output to a dictionary."""
if self.pydantic:
return self.pydantic.model_dump()
@@ -130,10 +122,10 @@ class LiteAgent(FlowTrackable, BaseModel):
role: str = Field(description="Role of the agent")
goal: str = Field(description="Goal of the agent")
backstory: str = Field(description="Backstory of the agent")
llm: Optional[Union[str, InstanceOf[BaseLLM], Any]] = Field(
llm: Optional[str | InstanceOf[BaseLLM] | Any] = Field(
default=None, description="Language model that will run the agent"
)
tools: List[BaseTool] = Field(
tools: list[BaseTool] = Field(
default_factory=list, description="Tools at agent's disposal"
)
@@ -159,29 +151,27 @@ class LiteAgent(FlowTrackable, BaseModel):
i18n: I18N = Field(default=I18N(), description="Internationalization settings.")
# Output and Formatting Properties
response_format: Optional[Type[BaseModel]] = Field(
response_format: Optional[type[BaseModel]] = Field(
default=None, description="Pydantic model for structured output"
)
verbose: bool = Field(
default=False, description="Whether to print execution details"
)
callbacks: List[Callable] = Field(
callbacks: list[Callable[..., Any]] = Field(
default=[], description="Callbacks to be used for the agent"
)
# Guardrail Properties
guardrail: Optional[Union[Callable[[LiteAgentOutput], Tuple[bool, Any]], str]] = (
Field(
default=None,
description="Function or string description of a guardrail to validate agent output",
)
guardrail: Optional[Callable[[LiteAgentOutput], tuple[bool, Any]] | str] = Field(
default=None,
description="Function or string description of a guardrail to validate agent output",
)
guardrail_max_retries: int = Field(
default=3, description="Maximum number of retries when guardrail fails"
)
# State and Results
tools_results: List[Dict[str, Any]] = Field(
tools_results: list[dict[str, Any]] = Field(
default=[], description="Results of the tools used by the agent."
)
@@ -190,20 +180,25 @@ class LiteAgent(FlowTrackable, BaseModel):
default=None, description="Reference to the agent that created this LiteAgent"
)
# Private Attributes
_parsed_tools: List[CrewStructuredTool] = PrivateAttr(default_factory=list)
_parsed_tools: list[CrewStructuredTool] = PrivateAttr(default_factory=list)
_token_process: TokenProcess = PrivateAttr(default_factory=TokenProcess)
_cache_handler: CacheHandler = PrivateAttr(default_factory=CacheHandler)
_key: str = PrivateAttr(default_factory=lambda: str(uuid.uuid4()))
_messages: List[Dict[str, str]] = PrivateAttr(default_factory=list)
_messages: list[dict[str, str]] = PrivateAttr(default_factory=list)
_iterations: int = PrivateAttr(default=0)
_printer: Printer = PrivateAttr(default_factory=Printer)
_guardrail: Optional[Callable] = PrivateAttr(default=None)
_guardrail: Optional[Callable[[LiteAgentOutput | TaskOutput], tuple[bool, Any]]] = (
PrivateAttr(default=None)
)
_guardrail_retry_count: int = PrivateAttr(default=0)
@model_validator(mode="after")
def setup_llm(self):
def setup_llm(self) -> Self:
"""Set up the LLM and other components after initialization."""
self.llm = create_llm(self.llm)
if self.llm is None:
self.llm = create_default_llm()
else:
self.llm = create_llm(self.llm)
if not isinstance(self.llm, BaseLLM):
raise ValueError(
f"Expected LLM instance of type BaseLLM, got {type(self.llm).__name__}"
@@ -216,7 +211,7 @@ class LiteAgent(FlowTrackable, BaseModel):
return self
@model_validator(mode="after")
def parse_tools(self):
def parse_tools(self) -> Self:
"""Parse the tools and convert them to CrewStructuredTool instances."""
self._parsed_tools = parse_tools(self.tools)
@@ -225,7 +220,10 @@ class LiteAgent(FlowTrackable, BaseModel):
@model_validator(mode="after")
def ensure_guardrail_is_callable(self) -> Self:
if callable(self.guardrail):
self._guardrail = self.guardrail
self._guardrail = cast(
Callable[[LiteAgentOutput | TaskOutput], tuple[bool, Any]],
self.guardrail,
)
elif isinstance(self.guardrail, str):
from crewai.tasks.llm_guardrail import LLMGuardrail
@@ -234,15 +232,18 @@ class LiteAgent(FlowTrackable, BaseModel):
f"Guardrail requires LLM instance of type BaseLLM, got {type(self.llm).__name__}"
)
self._guardrail = LLMGuardrail(description=self.guardrail, llm=self.llm)
self._guardrail = cast(
Callable[[LiteAgentOutput | TaskOutput], tuple[bool, Any]],
LLMGuardrail(description=self.guardrail, llm=self.llm),
)
return self
@field_validator("guardrail", mode="before")
@classmethod
def validate_guardrail_function(
cls, v: Optional[Union[Callable, str]]
) -> Optional[Union[Callable, str]]:
cls, v: Optional[Callable[[Any], tuple[bool, Any]] | str]
) -> Optional[Callable[[Any], tuple[bool, Any]] | str]:
"""Validate that the guardrail function has the correct signature.
If v is a callable, validate that it has the correct signature.
@@ -267,7 +268,7 @@ class LiteAgent(FlowTrackable, BaseModel):
# Check return annotation if present
if sig.return_annotation is not sig.empty:
if sig.return_annotation == Tuple[bool, Any]:
if sig.return_annotation == tuple[bool, Any]:
return v
origin = get_origin(sig.return_annotation)
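The annotation check above now compares against the builtin tuple[bool, Any] and falls back to get_origin. A toy version of that inspect-based validation, not the crewai implementation:

import inspect
from typing import Any, get_origin

def looks_like_guardrail(fn) -> bool:
    sig = inspect.signature(fn)
    if len(sig.parameters) != 1:
        return False
    ann = sig.return_annotation
    if ann is sig.empty:
        return True  # unannotated callables are accepted as-is
    # Exact match on tuple[bool, Any], or any parameterized tuple return.
    return ann == tuple[bool, Any] or get_origin(ann) is tuple

def ok(output: Any) -> tuple[bool, Any]:
    return True, output

def bad(a: Any, b: Any) -> None:
    return None

assert looks_like_guardrail(ok)
assert not looks_like_guardrail(bad)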
@@ -290,7 +291,7 @@ class LiteAgent(FlowTrackable, BaseModel):
"""Return the original role for compatibility with tool interfaces."""
return self.role
def kickoff(self, messages: Union[str, List[Dict[str, str]]]) -> LiteAgentOutput:
def kickoff(self, messages: str | list[dict[str, str]]) -> LiteAgentOutput:
"""
Execute the agent with the given messages.
@@ -338,7 +339,7 @@ class LiteAgent(FlowTrackable, BaseModel):
)
raise e
def _execute_core(self, agent_info: Dict[str, Any]) -> LiteAgentOutput:
def _execute_core(self, agent_info: dict[str, Any]) -> LiteAgentOutput:
# Emit event for agent execution start
crewai_event_bus.emit(
self,
@@ -428,7 +429,7 @@ class LiteAgent(FlowTrackable, BaseModel):
return output
async def kickoff_async(
self, messages: Union[str, List[Dict[str, str]]]
self, messages: str | list[dict[str, str]]
) -> LiteAgentOutput:
"""
Execute the agent asynchronously with the given messages.
@@ -475,8 +476,8 @@ class LiteAgent(FlowTrackable, BaseModel):
return base_prompt
def _format_messages(
self, messages: Union[str, List[Dict[str, str]]]
) -> List[Dict[str, str]]:
self, messages: str | list[dict[str, str]]
) -> list[dict[str, str]]:
"""Format messages for the LLM."""
if isinstance(messages, str):
messages = [{"role": "user", "content": messages}]
@@ -582,7 +583,7 @@ class LiteAgent(FlowTrackable, BaseModel):
self._show_logs(formatted_answer)
return formatted_answer
def _show_logs(self, formatted_answer: Union[AgentAction, AgentFinish]):
def _show_logs(self, formatted_answer: AgentAction | AgentFinish) -> None:
"""Show logs for the agent's execution."""
crewai_event_bus.emit(
self,

View File

@@ -1,4 +1,6 @@
from typing import Optional, TYPE_CHECKING
from __future__ import annotations
from typing import TYPE_CHECKING, Any, Optional
from crewai.memory import (
EntityMemory,
@@ -19,8 +21,8 @@ class ContextualMemory:
ltm: LongTermMemory,
em: EntityMemory,
exm: ExternalMemory,
agent: Optional["Agent"] = None,
task: Optional["Task"] = None,
agent: Optional[Agent] = None,
task: Optional[Task] = None,
):
self.stm = stm
self.ltm = ltm
@@ -42,7 +44,7 @@ class ContextualMemory:
self.exm.agent = self.agent
self.exm.task = self.task
def build_context_for_task(self, task, context) -> str:
def build_context_for_task(self, task: Task, context: str) -> str:
"""
Automatically builds a minimal, highly relevant set of contextual information
for a given task.
@@ -52,14 +54,14 @@ class ContextualMemory:
if query == "":
return ""
context = []
context.append(self._fetch_ltm_context(task.description))
context.append(self._fetch_stm_context(query))
context.append(self._fetch_entity_context(query))
context.append(self._fetch_external_context(query))
return "\n".join(filter(None, context))
context_parts = []
context_parts.append(self._fetch_ltm_context(task.description))
context_parts.append(self._fetch_stm_context(query))
context_parts.append(self._fetch_entity_context(query))
context_parts.append(self._fetch_external_context(query))
return "\n".join(filter(None, context_parts))
def _fetch_stm_context(self, query) -> str:
def _fetch_stm_context(self, query: str) -> str:
"""
Fetches recent relevant insights from STM related to the task's description and expected_output,
formatted as bullet points.
@@ -74,7 +76,7 @@ class ContextualMemory:
)
return f"Recent Insights:\n{formatted_results}" if stm_results else ""
def _fetch_ltm_context(self, task) -> Optional[str]:
def _fetch_ltm_context(self, task: str) -> Optional[str]:
"""
Fetches historical data or insights from LTM that are relevant to the task's description and expected_output,
formatted as bullet points.
@@ -83,21 +85,23 @@ class ContextualMemory:
if self.ltm is None:
return ""
ltm_results = self.ltm.search(task, latest_n=2)
ltm_results = self.ltm.search(task, limit=2)
if not ltm_results:
return None
formatted_results = [
suggestion
for result in ltm_results
for suggestion in result["metadata"]["suggestions"] # type: ignore # Invalid index type "str" for "str"; expected type "SupportsIndex | slice"
for suggestion in result["metadata"]["suggestions"]
]
formatted_results = list(dict.fromkeys(formatted_results))
formatted_results = "\n".join([f"- {result}" for result in formatted_results]) # type: ignore # Incompatible types in assignment (expression has type "str", variable has type "list[str]")
formatted_results_str = "\n".join(
[f"- {result}" for result in formatted_results]
)
return f"Historical Data:\n{formatted_results}" if ltm_results else ""
return f"Historical Data:\n{formatted_results_str}" if ltm_results else ""
def _fetch_entity_context(self, query) -> str:
def _fetch_entity_context(self, query: str) -> str:
"""
Fetches relevant entity information from Entity Memory related to the task's description and expected_output,
formatted as bullet points.
@@ -107,7 +111,7 @@ class ContextualMemory:
em_results = self.em.search(query)
formatted_results = "\n".join(
[f"- {result['context']}" for result in em_results] # type: ignore # Invalid index type "str" for "str"; expected type "SupportsIndex | slice"
[f"- {result['context']}" for result in em_results]
)
return f"Entities:\n{formatted_results}" if em_results else ""
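Renaming the accumulator to context_parts stops it from shadowing the context argument; the join-over-filter idiom it feeds is worth seeing in isolation, since filter(None, ...) drops both empty strings and None sections:

parts = ["Historical Data:\n- retry flaky steps", "", None, "Entities:\n- ACME Corp"]
print("\n".join(filter(None, parts)))
# Historical Data:
# - retry flaky steps
# Entities:
# - ACME Corp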

View File

@@ -1,20 +1,20 @@
from typing import Any
import time
from typing import Any
from pydantic import PrivateAttr
from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.memory_events import (
MemoryQueryCompletedEvent,
MemoryQueryFailedEvent,
MemoryQueryStartedEvent,
MemorySaveCompletedEvent,
MemorySaveFailedEvent,
MemorySaveStartedEvent,
)
from crewai.memory.entity.entity_memory_item import EntityMemoryItem
from crewai.memory.memory import Memory
from crewai.memory.storage.rag_storage import RAGStorage
from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.memory_events import (
MemoryQueryStartedEvent,
MemoryQueryCompletedEvent,
MemoryQueryFailedEvent,
MemorySaveStartedEvent,
MemorySaveCompletedEvent,
MemorySaveFailedEvent,
)
class EntityMemory(Memory):
@@ -26,7 +26,13 @@ class EntityMemory(Memory):
_memory_provider: str | None = PrivateAttr()
def __init__(self, crew=None, embedder_config=None, storage=None, path=None):
def __init__(
self,
crew: Any = None,
embedder_config: Any = None,
storage: Any = None,
path: Any = None,
) -> None:
memory_provider = embedder_config.get("provider") if embedder_config else None
if memory_provider == "mem0":
try:
@@ -155,7 +161,7 @@ class EntityMemory(Memory):
query: str,
limit: int = 3,
score_threshold: float = 0.35,
):
) -> Any:
crewai_event_bus.emit(
self,
event=MemoryQueryStartedEvent(

View File

@@ -1,17 +1,17 @@
from typing import Any, Dict, List
import time
from typing import Any, Optional
from crewai.memory.long_term.long_term_memory_item import LongTermMemoryItem
from crewai.memory.memory import Memory
from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.memory_events import (
MemoryQueryStartedEvent,
MemoryQueryCompletedEvent,
MemoryQueryFailedEvent,
MemorySaveStartedEvent,
MemoryQueryStartedEvent,
MemorySaveCompletedEvent,
MemorySaveFailedEvent,
MemorySaveStartedEvent,
)
from crewai.memory.long_term.long_term_memory_item import LongTermMemoryItem
from crewai.memory.memory import Memory
from crewai.memory.storage.ltm_sqlite_storage import LTMSQLiteStorage
@@ -24,72 +24,84 @@ class LongTermMemory(Memory):
LongTermMemoryItem instances.
"""
def __init__(self, storage=None, path=None):
def __init__(self, storage: Any = None, path: Any = None) -> None:
if not storage:
storage = LTMSQLiteStorage(db_path=path) if path else LTMSQLiteStorage()
super().__init__(storage=storage)
def save(self, item: LongTermMemoryItem) -> None: # type: ignore # BUG?: Signature of "save" incompatible with supertype "Memory"
crewai_event_bus.emit(
self,
event=MemorySaveStartedEvent(
value=item.task,
metadata=item.metadata,
agent_role=item.agent,
source_type="long_term_memory",
from_agent=self.agent,
from_task=self.task,
),
)
start_time = time.time()
try:
metadata = item.metadata
metadata.update(
{"agent": item.agent, "expected_output": item.expected_output}
)
self.storage.save( # type: ignore # BUG?: Unexpected keyword argument "task_description","score","datetime" for "save" of "Storage"
task_description=item.task,
score=metadata["quality"],
metadata=metadata,
datetime=item.datetime,
)
def save(
self,
value: Any,
metadata: Optional[dict[str, Any]] = None,
) -> None:
# Handle both LongTermMemoryItem and regular save calls
if isinstance(value, LongTermMemoryItem):
item = value
crewai_event_bus.emit(
self,
event=MemorySaveCompletedEvent(
event=MemorySaveStartedEvent(
value=item.task,
metadata=item.metadata,
agent_role=item.agent,
save_time_ms=(time.time() - start_time) * 1000,
source_type="long_term_memory",
from_agent=self.agent,
from_task=self.task,
),
)
except Exception as e:
crewai_event_bus.emit(
self,
event=MemorySaveFailedEvent(
value=item.task,
metadata=item.metadata,
agent_role=item.agent,
error=str(e),
source_type="long_term_memory",
),
)
raise
def search( # type: ignore # signature of "search" incompatible with supertype "Memory"
start_time = time.time()
try:
metadata = item.metadata.copy()
metadata.update(
{"agent": item.agent, "expected_output": item.expected_output}
)
self.storage.save(
task_description=item.task,
score=metadata["quality"],
metadata=metadata,
datetime=item.datetime,
)
crewai_event_bus.emit(
self,
event=MemorySaveCompletedEvent(
value=item.task,
metadata=metadata,
agent_role=item.agent,
save_time_ms=(time.time() - start_time) * 1000,
source_type="long_term_memory",
from_agent=self.agent,
from_task=self.task,
),
)
except Exception as e:
crewai_event_bus.emit(
self,
event=MemorySaveFailedEvent(
value=item.task,
metadata=item.metadata,
agent_role=item.agent,
error=str(e),
source_type="long_term_memory",
),
)
raise
else:
# Regular save for compatibility with parent class
metadata = metadata or {}
self.storage.save(value, metadata)
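The reworked save above dispatches on the payload type: LongTermMemoryItem keeps the rich emit-and-store path, anything else falls back to the parent-style save(value, metadata). A self-contained sketch of that dual path with toy stand-ins, not the crewai classes:

from dataclasses import dataclass, field
from typing import Any, Optional

@dataclass
class LongTermMemoryItem:
    task: str
    agent: str
    expected_output: str
    datetime: str
    metadata: dict[str, Any] = field(default_factory=dict)

class ToyLTM:
    def __init__(self) -> None:
        self.rows: list[dict[str, Any]] = []

    def save(self, value: Any, metadata: Optional[dict[str, Any]] = None) -> None:
        if isinstance(value, LongTermMemoryItem):
            meta = value.metadata.copy()  # copy() avoids mutating the caller's dict
            meta.update({"agent": value.agent, "expected_output": value.expected_output})
            self.rows.append({"task": value.task, "metadata": meta})
        else:
            self.rows.append({"task": value, "metadata": metadata or {}})

ltm = ToyLTM()
ltm.save(LongTermMemoryItem("t1", "researcher", "a report", "2025-09-05", {"quality": 9}))
ltm.save("plain value", {"k": "v"})
print(len(ltm.rows))  # 2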
def search(
self,
task: str,
latest_n: int = 3,
) -> List[Dict[str, Any]]: # type: ignore # signature of "search" incompatible with supertype "Memory"
query: str,
limit: int = 3,
score_threshold: float = 0.35,
) -> list[Any]:
crewai_event_bus.emit(
self,
event=MemoryQueryStartedEvent(
query=task,
limit=latest_n,
query=query,
limit=limit,
source_type="long_term_memory",
from_agent=self.agent,
from_task=self.task,
@@ -98,14 +110,16 @@ class LongTermMemory(Memory):
start_time = time.time()
try:
results = self.storage.load(task, latest_n) # type: ignore # BUG?: "Storage" has no attribute "load"
# storage.load uses its own parameter names,
# so pass the aligned values positionally
results = self.storage.load(query, limit)
crewai_event_bus.emit(
self,
event=MemoryQueryCompletedEvent(
query=task,
query=query,
results=results,
limit=latest_n,
limit=limit,
query_time_ms=(time.time() - start_time) * 1000,
source_type="long_term_memory",
from_agent=self.agent,
@@ -113,15 +127,17 @@ class LongTermMemory(Memory):
),
)
return results
return results if results is not None else []
except Exception as e:
crewai_event_bus.emit(
self,
event=MemoryQueryFailedEvent(
query=task,
limit=latest_n,
query=query,
limit=limit,
error=str(e),
source_type="long_term_memory",
from_agent=self.agent,
from_task=self.task,
),
)
raise
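The save_time_ms and query_time_ms fields emitted throughout these memory events come from a plain wall-clock measurement; the pattern in isolation:

import time

def timed_search(fn, *args: object) -> tuple[object, float]:
    start = time.time()
    result = fn(*args)
    return result, (time.time() - start) * 1000  # elapsed milliseconds

result, ms = timed_search(sorted, [3, 1, 2])
print(result, f"{ms:.3f}ms")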

View File

@@ -1,20 +1,20 @@
from typing import Any, Dict, Optional
import time
from typing import Any, Optional
from pydantic import PrivateAttr
from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.memory_events import (
MemoryQueryCompletedEvent,
MemoryQueryFailedEvent,
MemoryQueryStartedEvent,
MemorySaveCompletedEvent,
MemorySaveFailedEvent,
MemorySaveStartedEvent,
)
from crewai.memory.memory import Memory
from crewai.memory.short_term.short_term_memory_item import ShortTermMemoryItem
from crewai.memory.storage.rag_storage import RAGStorage
from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.memory_events import (
MemoryQueryStartedEvent,
MemoryQueryCompletedEvent,
MemoryQueryFailedEvent,
MemorySaveStartedEvent,
MemorySaveCompletedEvent,
MemorySaveFailedEvent,
)
class ShortTermMemory(Memory):
@@ -28,7 +28,13 @@ class ShortTermMemory(Memory):
_memory_provider: Optional[str] = PrivateAttr()
def __init__(self, crew=None, embedder_config=None, storage=None, path=None):
def __init__(
self,
crew: Any = None,
embedder_config: Any = None,
storage: Any = None,
path: Any = None,
) -> None:
memory_provider = embedder_config.get("provider") if embedder_config else None
if memory_provider == "mem0":
try:
@@ -56,7 +62,7 @@ class ShortTermMemory(Memory):
def save(
self,
value: Any,
metadata: Optional[Dict[str, Any]] = None,
metadata: Optional[dict[str, Any]] = None,
) -> None:
crewai_event_bus.emit(
self,
@@ -114,7 +120,7 @@ class ShortTermMemory(Memory):
query: str,
limit: int = 3,
score_threshold: float = 0.35,
):
) -> Any:
crewai_event_bus.emit(
self,
event=MemoryQueryStartedEvent(
@@ -131,7 +137,7 @@ class ShortTermMemory(Memory):
try:
results = self.storage.search(
query=query, limit=limit, score_threshold=score_threshold
) # type: ignore # BUG? The reference is to the parent class, but the parent class does not have this parameters
)
crewai_event_bus.emit(
self,

View File

@@ -1,10 +1,11 @@
import os
from typing import Any, Dict, List
from collections import defaultdict
from mem0 import Memory, MemoryClient
from crewai.utilities.chromadb import sanitize_collection_name
from typing import Any
from mem0 import Memory, MemoryClient # type: ignore[import-not-found]
from crewai.memory.storage.interface import Storage
from crewai.utilities.chromadb import sanitize_collection_name
MAX_AGENT_ID_LENGTH_MEM0 = 255
@@ -13,7 +14,10 @@ class Mem0Storage(Storage):
"""
Extends Storage to handle embedding and searching across entities using Mem0.
"""
def __init__(self, type, crew=None, config=None):
def __init__(
self, type: str, crew: Any = None, config: dict[str, Any] | None = None
) -> None:
super().__init__()
self._validate_type(type)
@@ -24,21 +28,21 @@ class Mem0Storage(Storage):
self._extract_config_values()
self._initialize_memory()
def _validate_type(self, type):
def _validate_type(self, type: str) -> None:
supported_types = {"short_term", "long_term", "entities", "external"}
if type not in supported_types:
raise ValueError(
f"Invalid type '{type}' for Mem0Storage. Must be one of: {', '.join(supported_types)}"
)
def _extract_config_values(self):
def _extract_config_values(self) -> None:
self.mem0_run_id = self.config.get("run_id")
self.includes = self.config.get("includes")
self.excludes = self.config.get("excludes")
self.custom_categories = self.config.get("custom_categories")
self.infer = self.config.get("infer", True)
def _initialize_memory(self):
def _initialize_memory(self) -> None:
api_key = self.config.get("api_key") or os.getenv("MEM0_API_KEY")
org_id = self.config.get("org_id")
project_id = self.config.get("project_id")
@@ -59,7 +63,7 @@ class Mem0Storage(Storage):
else Memory()
)
def _create_filter_for_search(self):
def _create_filter_for_search(self) -> dict[str, Any]:
"""
Returns:
dict: A filter dictionary containing AND conditions for querying data.
@@ -86,21 +90,21 @@ class Mem0Storage(Storage):
return filter
def save(self, value: Any, metadata: Dict[str, Any]) -> None:
def save(self, value: Any, metadata: dict[str, Any]) -> None:
user_id = self.config.get("user_id", "")
assistant_message = [{"role" : "assistant","content" : value}]
assistant_message = [{"role": "assistant", "content": value}]
base_metadata = {
"short_term": "short_term",
"long_term": "long_term",
"entities": "entity",
"external": "external"
"external": "external",
}
# Shared base params
params: dict[str, Any] = {
"metadata": {"type": base_metadata[self.memory_type], **metadata},
"infer": self.infer
"infer": self.infer,
}
# MemoryClient-specific overrides
@@ -121,13 +125,15 @@ class Mem0Storage(Storage):
self.memory.add(assistant_message, **params)
def search(self,query: str,limit: int = 3,score_threshold: float = 0.35) -> List[Any]:
def search(
self, query: str, limit: int = 3, score_threshold: float = 0.35
) -> list[Any]:
params = {
"query": query,
"limit": limit,
"version": "v2",
"output_format": "v1.1"
}
"output_format": "v1.1",
}
if user_id := self.config.get("user_id", ""):
params["user_id"] = user_id
@@ -148,10 +154,10 @@ class Mem0Storage(Storage):
# automatically when the crew is created.
params["filters"] = self._create_filter_for_search()
params['threshold'] = score_threshold
params["threshold"] = score_threshold
if isinstance(self.memory, Memory):
del params["metadata"], params["version"], params['output_format']
del params["metadata"], params["version"], params["output_format"]
if params.get("run_id"):
del params["run_id"]
@@ -160,10 +166,10 @@ class Mem0Storage(Storage):
# This makes it compatible for Contextual Memory to retrieve
for result in results["results"]:
result["context"] = result["memory"]
return [r for r in results["results"]]
def reset(self):
def reset(self) -> None:
if self.memory:
self.memory.reset()
@@ -180,4 +186,6 @@ class Mem0Storage(Storage):
agents = self.crew.agents
agents = [self._sanitize_role(agent.role) for agent in agents]
agents = "_".join(agents)
return sanitize_collection_name(name=agents, max_collection_length=MAX_AGENT_ID_LENGTH_MEM0)
return sanitize_collection_name(
name=agents, max_collection_length=MAX_AGENT_ID_LENGTH_MEM0
)
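search above builds a superset of keyword arguments and then prunes the ones the local Memory backend rejects. A toy sketch of that shape-then-prune step (dict logic only, no mem0 calls; the local_backend flag stands in for the isinstance(self.memory, Memory) check):

from typing import Any

def build_params(query: str, limit: int, user_id: str = "", local_backend: bool = False) -> dict[str, Any]:
    params: dict[str, Any] = {
        "query": query,
        "limit": limit,
        "version": "v2",
        "output_format": "v1.1",
    }
    if user_id:
        params["user_id"] = user_id
    if local_backend:
        # The local backend does not accept the hosted-API-only keys.
        for key in ("version", "output_format"):
            params.pop(key, None)
    return params

print(build_params("deploy steps", 3, user_id="u-1", local_backend=True))
# {'query': 'deploy steps', 'limit': 3, 'user_id': 'u-1'}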

View File

@@ -2,29 +2,72 @@ import logging
import os
import shutil
import uuid
import warnings
from typing import Any
from typing import Any, Dict, List, Optional
from chromadb import EmbeddingFunction
from chromadb.api import ClientAPI
from crewai.rag.storage.base_rag_storage import BaseRAGStorage
from crewai.rag.embeddings.configurator import EmbeddingConfigurator
from crewai.rag.storage.base_rag_storage import BaseRAGStorage
from crewai.utilities.chromadb import create_persistent_client
from crewai.utilities.constants import MAX_FILE_NAME_LENGTH
from crewai.utilities.paths import db_storage_path
from crewai.utilities.logger_utils import suppress_logging
import warnings
from crewai.utilities.paths import db_storage_path
def _extract_chromadb_response_item(
response_data: Any,
index: int,
expected_type: type[Any] | tuple[type[Any], ...],
) -> Any | None:
"""Extract an item from ChromaDB response data at the given index.
Args:
response_data: The response data from ChromaDB query (e.g., documents, metadatas).
index: The index of the item to extract.
expected_type: The expected type(s) of the item.
Returns:
The extracted item if it exists and matches the expected type, otherwise None.
"""
if response_data is None or not response_data:
return None
# ChromaDB sometimes returns nested lists; handle both cases
data_list = (
response_data[0]
if response_data and isinstance(response_data[0], list)
else response_data
)
if index < len(data_list):
item = data_list[index]
if isinstance(item, expected_type):
return item
return None
class RAGStorage(BaseRAGStorage):
"""
Extends Storage to handle embeddings for memory entries, improving
search efficiency.
Notes:
- TODO: Add type hints to EmbeddingFunction in next typing PR.
"""
app: ClientAPI | None = None
embedder_config: EmbeddingFunction[Any] | None = None # type: ignore
def __init__(
self, type, allow_reset=True, embedder_config=None, crew=None, path=None
):
self,
type: str,
allow_reset: bool = True,
embedder_config: Any = None,
crew: Any = None,
path: str | None = None,
) -> None:
super().__init__(type, allow_reset, embedder_config, crew)
agents = crew.agents if crew else []
agents = [self._sanitize_role(agent.role) for agent in agents]
@@ -33,16 +76,29 @@ class RAGStorage(BaseRAGStorage):
self.storage_file_name = self._build_storage_file_name(type, agents)
self.type = type
self._original_embedder_config = (
embedder_config # Store for later use in _set_embedder_config
)
self.allow_reset = allow_reset
self.path = path
self._initialize_app()
def _set_embedder_config(self):
configurator = EmbeddingConfigurator()
self.embedder_config = configurator.configure_embedder(self.embedder_config)
def _set_embedder_config(self) -> None:
"""Sets the embedder_config using EmbeddingConfigurator.
def _initialize_app(self):
Notes:
- TODO: remove the type: ignore on next typing pr.
"""
configurator = EmbeddingConfigurator() # type: ignore
# Pass the original embedder_config from __init__, not self.embedder_config
if hasattr(self, "_original_embedder_config"):
self.embedder_config = configurator.configure_embedder(
self._original_embedder_config
)
else:
self.embedder_config = configurator.configure_embedder()
def _initialize_app(self) -> None:
from chromadb.config import Settings
# Suppress deprecation warnings from chromadb, which are not relevant to us
@@ -71,7 +127,8 @@ class RAGStorage(BaseRAGStorage):
"""
return role.replace("\n", "").replace(" ", "_").replace("/", "_")
def _build_storage_file_name(self, type: str, file_name: str) -> str:
@staticmethod
def _build_storage_file_name(type: str, file_name: str) -> str:
"""
Ensures file name does not exceed max allowed by OS
"""
@@ -85,7 +142,7 @@ class RAGStorage(BaseRAGStorage):
return f"{base_path}/{file_name}"
def save(self, value: Any, metadata: Dict[str, Any]) -> None:
def save(self, value: Any, metadata: dict[str, Any]) -> None:
if not hasattr(self, "app") or not hasattr(self, "collection"):
self._initialize_app()
try:
@@ -97,9 +154,9 @@ class RAGStorage(BaseRAGStorage):
self,
query: str,
limit: int = 3,
filter: Optional[dict] = None,
filter: dict[str, Any] | None = None,
score_threshold: float = 0.35,
) -> List[Any]:
) -> list[Any]:
if not hasattr(self, "app"):
self._initialize_app()
@@ -110,26 +167,68 @@ class RAGStorage(BaseRAGStorage):
response = self.collection.query(query_texts=query, n_results=limit)
results = []
for i in range(len(response["ids"][0])):
result = {
"id": response["ids"][0][i],
"metadata": response["metadatas"][0][i],
"context": response["documents"][0][i],
"score": response["distances"][0][i],
}
if result["score"] >= score_threshold:
results.append(result)
if (
response
and "ids" in response
and response["ids"]
and len(response["ids"]) > 0
):
ids_list = (
response["ids"][0]
if isinstance(response["ids"][0], list)
else response["ids"]
)
for i in range(len(ids_list)):
# Handle metadatas
meta_item = _extract_chromadb_response_item(
response.get("metadatas"), i, dict
)
metadata: dict[str, Any] = meta_item if meta_item else {}
# Handle documents
doc_item = _extract_chromadb_response_item(
response.get("documents"), i, str
)
context = doc_item if doc_item else ""
# Handle distances
dist_item = _extract_chromadb_response_item(
response.get("distances"), i, (int, float)
)
score = dist_item if dist_item is not None else 1.0
result = {
"id": ids_list[i],
"metadata": metadata,
"context": context,
"score": score,
}
# Check score threshold - distances are smaller when more similar
if isinstance(score, (int, float)) and score <= score_threshold:
results.append(result)
return results
except Exception as e:
logging.error(f"Error during {self.type} search: {str(e)}")
return []
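The comparison flip from >= to <= reflects distance semantics: ChromaDB query results carry distances, where smaller means more similar, so a threshold keeps results at or below it. In isolation:

hits = [
    {"id": "a", "score": 0.12},  # close match (small distance)
    {"id": "b", "score": 0.80},  # far match (large distance)
]
score_threshold = 0.35
kept = [h for h in hits if h["score"] <= score_threshold]
print([h["id"] for h in kept])  # ['a']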
def _generate_embedding(self, text: str, metadata: Dict[str, Any]) -> None: # type: ignore
def _generate_embedding(
self, text: str, metadata: dict[str, Any] | None = None
) -> Any:
"""Generates and stores the embedding for the given text and metadata.
Args:
text: The text to generate an embedding for.
metadata: Optional metadata associated with the text.
Notes:
- Need to constrain the typing in the base class; this result isn't used
"""
if not hasattr(self, "app") or not hasattr(self, "collection"):
self._initialize_app()
self.collection.add(
return self.collection.add(
documents=[text],
metadatas=[metadata or {}],
ids=[str(uuid.uuid4())],
@@ -151,7 +250,8 @@ class RAGStorage(BaseRAGStorage):
f"An error occurred while resetting the {self.type} memory: {e}"
)
def _create_default_embedding_function(self):
@staticmethod
def _create_default_embedding_function() -> EmbeddingFunction[Any]:
from chromadb.utils.embedding_functions.openai_embedding_function import (
OpenAIEmbeddingFunction,
)

View File

@@ -1,4 +1,5 @@
from typing import Any, Callable
from collections.abc import Callable
from typing import Any, Optional
from pydantic import Field
@@ -13,7 +14,7 @@ class ConditionalTask(Task):
Note: This cannot be the only task you have in your crew and cannot be the first, since it needs context from the previous task.
"""
condition: Callable[[TaskOutput], bool] = Field(
condition: Optional[Callable[[TaskOutput], bool]] = Field(
default=None,
description="Condition that determines whether this task should run, evaluated against the previous task's output.",
)
@@ -21,8 +22,8 @@ class ConditionalTask(Task):
def __init__(
self,
condition: Callable[[Any], bool],
**kwargs,
):
**kwargs: Any,
) -> None:
super().__init__(**kwargs)
self.condition = condition
@@ -36,9 +37,11 @@ class ConditionalTask(Task):
Returns:
bool: True if the task should be executed, False otherwise.
"""
if self.condition is None:
return False
return self.condition(context)
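With condition now Optional, should_execute treats a missing condition as "skip"; the guard in isolation:

from collections.abc import Callable
from typing import Any, Optional

def should_execute(condition: Optional[Callable[[Any], bool]], context: Any) -> bool:
    if condition is None:
        return False  # no condition wired in: do not execute
    return condition(context)

assert should_execute(None, "ctx") is False
assert should_execute(lambda out: bool(out), "ctx") is True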
def get_skipped_task_output(self):
def get_skipped_task_output(self) -> TaskOutput:
return TaskOutput(
description=self.description,
raw="",

View File

@@ -5,11 +5,12 @@ import json
import logging
import os
import platform
import threading
import warnings
from collections.abc import Callable
from contextlib import contextmanager
from importlib.metadata import version
from typing import TYPE_CHECKING, Any, Callable, Optional
import threading
from typing import TYPE_CHECKING, Any, Optional
from opentelemetry import trace
from opentelemetry.exporter.otlp.proto.http.trace_exporter import (
@@ -32,7 +33,7 @@ logger = logging.getLogger(__name__)
@contextmanager
def suppress_warnings():
def suppress_warnings() -> Any:
with warnings.catch_warnings():
warnings.filterwarnings("ignore")
yield
@@ -44,7 +45,7 @@ if TYPE_CHECKING:
class SafeOTLPSpanExporter(OTLPSpanExporter):
def export(self, spans) -> SpanExportResult:
def export(self, spans: Any) -> SpanExportResult:
try:
return super().export(spans)
except Exception as e:
@@ -68,18 +69,18 @@ class Telemetry:
_instance = None
_lock = threading.Lock()
def __new__(cls):
def __new__(cls) -> Telemetry:
if cls._instance is None:
with cls._lock:
if cls._instance is None:
cls._instance = super(Telemetry, cls).__new__(cls)
cls._instance = super().__new__(cls)
cls._instance._initialized = False
return cls._instance
def __init__(self) -> None:
if hasattr(self, '_initialized') and self._initialized:
if hasattr(self, "_initialized") and self._initialized:
return
self.ready: bool = False
self.trace_set: bool = False
self._initialized: bool = True
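The __new__ above is a double-checked-locking singleton: the first check avoids the lock on the hot path, the second check inside the lock prevents two threads from both creating an instance. A standalone sketch of the same pattern, separate from the Telemetry class itself:
import threading
class Singleton:
    _instance: "Singleton | None" = None
    _lock = threading.Lock()
    def __new__(cls) -> "Singleton":
        if cls._instance is None:          # fast path, no lock taken
            with cls._lock:
                if cls._instance is None:  # re-check while holding the lock
                    cls._instance = super().__new__(cls)
        return cls._instance
assert Singleton() is Singleton()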
@@ -124,7 +125,7 @@ class Telemetry:
"""Check if telemetry operations should be executed."""
return self.ready and not self._is_telemetry_disabled()
def set_tracer(self):
def set_tracer(self) -> None:
if self.ready and not self.trace_set:
try:
with suppress_warnings():
@@ -143,10 +144,10 @@ class Telemetry:
except Exception:
pass
def crew_creation(self, crew: Crew, inputs: dict[str, Any] | None):
def crew_creation(self, crew: Crew, inputs: dict[str, Any] | None) -> None:
"""Records the creation of a crew."""
def operation():
def operation() -> Any:
tracer = trace.get_tracer("crewai.telemetry")
span = tracer.start_span("Crew Created")
self._add_attribute(
@@ -351,7 +352,7 @@ class Telemetry:
def task_started(self, crew: Crew, task: Task) -> Span | None:
"""Records task started in a crew."""
def operation():
def operation() -> Any:
tracer = trace.get_tracer("crewai.telemetry")
created_span = tracer.start_span("Task Created")
@@ -438,7 +439,7 @@ class Telemetry:
self._safe_telemetry_operation(operation)
return None
def task_ended(self, span: Span, task: Task, crew: Crew):
def task_ended(self, span: Span, task: Task, crew: Crew) -> None:
"""Records the completion of a task execution in a crew.
Args:
@@ -450,7 +451,7 @@ class Telemetry:
If share_crew is enabled, this will also record the task output
"""
def operation():
def operation() -> Any:
# Ensure fingerprint data is present on completion span
if hasattr(task, "fingerprint") and task.fingerprint:
self._add_attribute(span, "task_fingerprint", task.fingerprint.uuid_str)
@@ -467,7 +468,7 @@ class Telemetry:
self._safe_telemetry_operation(operation)
def tool_repeated_usage(self, llm: Any, tool_name: str, attempts: int):
def tool_repeated_usage(self, llm: Any, tool_name: str, attempts: int) -> None:
"""Records when a tool is used repeatedly, which might indicate an issue.
Args:
@@ -476,7 +477,7 @@ class Telemetry:
attempts (int): Number of attempts made with this tool
"""
def operation():
def operation() -> Any:
tracer = trace.get_tracer("crewai.telemetry")
span = tracer.start_span("Tool Repeated Usage")
self._add_attribute(
@@ -493,7 +494,9 @@ class Telemetry:
self._safe_telemetry_operation(operation)
def tool_usage(self, llm: Any, tool_name: str, attempts: int, agent: Any = None):
def tool_usage(
self, llm: Any, tool_name: str, attempts: int, agent: Any = None
) -> None:
"""Records the usage of a tool by an agent.
Args:
@@ -503,7 +506,7 @@ class Telemetry:
agent (Any, optional): The agent using the tool
"""
def operation():
def operation() -> Any:
tracer = trace.get_tracer("crewai.telemetry")
span = tracer.start_span("Tool Usage")
self._add_attribute(
@@ -531,7 +534,7 @@ class Telemetry:
def tool_usage_error(
self, llm: Any, agent: Any = None, tool_name: Optional[str] = None
):
) -> None:
"""Records when a tool usage results in an error.
Args:
@@ -540,7 +543,7 @@ class Telemetry:
tool_name (str, optional): Name of the tool that caused the error
"""
def operation():
def operation() -> Any:
tracer = trace.get_tracer("crewai.telemetry")
span = tracer.start_span("Tool Usage Error")
self._add_attribute(
@@ -569,7 +572,7 @@ class Telemetry:
def individual_test_result_span(
self, crew: Crew, quality: float, exec_time: int, model_name: str
):
) -> None:
"""Records individual test results for a crew execution.
Args:
@@ -579,7 +582,7 @@ class Telemetry:
model_name (str): Name of the model used
"""
def operation():
def operation() -> Any:
tracer = trace.get_tracer("crewai.telemetry")
span = tracer.start_span("Crew Individual Test Result")
@@ -604,7 +607,7 @@ class Telemetry:
iterations: int,
inputs: dict[str, Any] | None,
model_name: str,
):
) -> None:
"""Records the execution of a test suite for a crew.
Args:
@@ -614,7 +617,7 @@ class Telemetry:
model_name (str): Name of the model used in testing
"""
def operation():
def operation() -> Any:
tracer = trace.get_tracer("crewai.telemetry")
span = tracer.start_span("Crew Test Execution")
@@ -638,10 +641,10 @@ class Telemetry:
self._safe_telemetry_operation(operation)
def deploy_signup_error_span(self):
def deploy_signup_error_span(self) -> None:
"""Records when an error occurs during the deployment signup process."""
def operation():
def operation() -> Any:
tracer = trace.get_tracer("crewai.telemetry")
span = tracer.start_span("Deploy Signup Error")
span.set_status(Status(StatusCode.OK))
@@ -649,14 +652,14 @@ class Telemetry:
self._safe_telemetry_operation(operation)
def start_deployment_span(self, uuid: Optional[str] = None):
def start_deployment_span(self, uuid: Optional[str] = None) -> None:
"""Records the start of a deployment process.
Args:
uuid (Optional[str]): Unique identifier for the deployment
"""
def operation():
def operation() -> Any:
tracer = trace.get_tracer("crewai.telemetry")
span = tracer.start_span("Start Deployment")
if uuid:
@@ -666,10 +669,10 @@ class Telemetry:
self._safe_telemetry_operation(operation)
def create_crew_deployment_span(self):
def create_crew_deployment_span(self) -> None:
"""Records the creation of a new crew deployment."""
def operation():
def operation() -> Any:
tracer = trace.get_tracer("crewai.telemetry")
span = tracer.start_span("Create Crew Deployment")
span.set_status(Status(StatusCode.OK))
@@ -677,7 +680,9 @@ class Telemetry:
self._safe_telemetry_operation(operation)
def get_crew_logs_span(self, uuid: Optional[str], log_type: str = "deployment"):
def get_crew_logs_span(
self, uuid: Optional[str], log_type: str = "deployment"
) -> None:
"""Records the retrieval of crew logs.
Args:
@@ -685,7 +690,7 @@ class Telemetry:
log_type (str, optional): Type of logs being retrieved. Defaults to "deployment".
"""
def operation():
def operation() -> Any:
tracer = trace.get_tracer("crewai.telemetry")
span = tracer.start_span("Get Crew Logs")
self._add_attribute(span, "log_type", log_type)
@@ -696,14 +701,14 @@ class Telemetry:
self._safe_telemetry_operation(operation)
def remove_crew_span(self, uuid: Optional[str] = None):
def remove_crew_span(self, uuid: Optional[str] = None) -> None:
"""Records the removal of a crew.
Args:
uuid (Optional[str]): Unique identifier for the crew being removed
"""
def operation():
def operation() -> Any:
tracer = trace.get_tracer("crewai.telemetry")
span = tracer.start_span("Remove Crew")
if uuid:
@@ -713,13 +718,13 @@ class Telemetry:
self._safe_telemetry_operation(operation)
def crew_execution_span(self, crew: Crew, inputs: dict[str, Any] | None):
def crew_execution_span(self, crew: Crew, inputs: dict[str, Any] | None) -> None:
"""Records the complete execution of a crew.
This is only collected if the user has opted in to share the crew.
"""
self.crew_creation(crew, inputs)
def operation():
def operation() -> Any:
tracer = trace.get_tracer("crewai.telemetry")
span = tracer.start_span("Crew Execution")
self._add_attribute(
@@ -787,11 +792,12 @@ class Telemetry:
if crew.share_crew:
self._safe_telemetry_operation(operation)
return operation()
result = operation()
return result # type: ignore[no-any-return]
return None
def end_crew(self, crew, final_string_output):
def operation():
def end_crew(self, crew: Any, final_string_output: str) -> None:
def operation() -> Any:
self._add_attribute(
crew._execution_span,
"crewai_version",
@@ -820,22 +826,22 @@ class Telemetry:
if crew.share_crew:
self._safe_telemetry_operation(operation)
def _add_attribute(self, span, key, value):
def _add_attribute(self, span: Any, key: str, value: Any) -> None:
"""Add an attribute to a span."""
def operation():
def operation() -> Any:
return span.set_attribute(key, value)
self._safe_telemetry_operation(operation)
def flow_creation_span(self, flow_name: str):
def flow_creation_span(self, flow_name: str) -> None:
"""Records the creation of a new flow.
Args:
flow_name (str): Name of the flow being created
"""
def operation():
def operation() -> Any:
tracer = trace.get_tracer("crewai.telemetry")
span = tracer.start_span("Flow Creation")
self._add_attribute(span, "flow_name", flow_name)
@@ -844,7 +850,7 @@ class Telemetry:
self._safe_telemetry_operation(operation)
def flow_plotting_span(self, flow_name: str, node_names: list[str]):
def flow_plotting_span(self, flow_name: str, node_names: list[str]) -> None:
"""Records flow visualization/plotting activity.
Args:
@@ -852,7 +858,7 @@ class Telemetry:
node_names (list[str]): List of node names in the flow
"""
def operation():
def operation() -> Any:
tracer = trace.get_tracer("crewai.telemetry")
span = tracer.start_span("Flow Plotting")
self._add_attribute(span, "flow_name", flow_name)
@@ -862,7 +868,7 @@ class Telemetry:
self._safe_telemetry_operation(operation)
def flow_execution_span(self, flow_name: str, node_names: list[str]):
def flow_execution_span(self, flow_name: str, node_names: list[str]) -> None:
"""Records the execution of a flow.
Args:
@@ -870,7 +876,7 @@ class Telemetry:
node_names (list[str]): List of nodes being executed in the flow
"""
def operation():
def operation() -> Any:
tracer = trace.get_tracer("crewai.telemetry")
span = tracer.start_span("Flow Execution")
self._add_attribute(span, "flow_name", flow_name)

View File

@@ -1,6 +1,6 @@
from pydantic import BaseModel, Field
from crewai.agents.cache import CacheHandler
from crewai.agents.cache.cache_handler import CacheHandler
from crewai.tools.structured_tool import CrewStructuredTool
@@ -13,15 +13,19 @@ class CacheTools(BaseModel):
default_factory=CacheHandler,
)
def tool(self):
def tool(self) -> CrewStructuredTool:
return CrewStructuredTool.from_function(
func=self.hit_cache,
name=self.name,
description="Reads directly from the cache",
)
def hit_cache(self, key):
def hit_cache(self, key: str) -> str:
import json
split = key.split("tool:")
tool = split[1].split("|input:")[0].strip()
tool_input = split[1].split("|input:")[1].strip()
return self.cache_handler.read(tool, tool_input)
tool_input_str = split[1].split("|input:")[1].strip()
tool_input = json.loads(tool_input_str) if tool_input_str else None
result = self.cache_handler.read(tool, tool_input)
return result if result is not None else ""
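For reference, the cache-key format hit_cache parses above, inferred from the split logic (format: tool:<name>|input:<json args>):
import json
key = 'tool:multiplier|input:{"first_number": 2, "second_number": 6}'
split = key.split("tool:")
tool = split[1].split("|input:")[0].strip()            # -> "multiplier"
tool_input_str = split[1].split("|input:")[1].strip()
tool_input = json.loads(tool_input_str)                # -> {"first_number": 2, "second_number": 6}
print(tool, tool_input)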

View File

@@ -5,12 +5,20 @@ import time
from difflib import SequenceMatcher
from json import JSONDecodeError
from textwrap import dedent
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
from typing import TYPE_CHECKING, Any, Optional, Union
import json5
from json_repair import repair_json
from json_repair import repair_json # type: ignore[import-untyped]
from crewai.agents.tools_handler import ToolsHandler
from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.tool_usage_events import (
ToolSelectionErrorEvent,
ToolUsageErrorEvent,
ToolUsageFinishedEvent,
ToolUsageStartedEvent,
ToolValidateInputErrorEvent,
)
from crewai.task import Task
from crewai.telemetry import Telemetry
from crewai.tools.structured_tool import CrewStructuredTool
@@ -20,14 +28,6 @@ from crewai.utilities.agent_utils import (
get_tool_names,
render_text_description_and_args,
)
from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.tool_usage_events import (
ToolSelectionErrorEvent,
ToolUsageErrorEvent,
ToolUsageFinishedEvent,
ToolUsageStartedEvent,
ToolValidateInputErrorEvent,
)
if TYPE_CHECKING:
from crewai.agents.agent_builder.base_agent import BaseAgent
@@ -69,12 +69,12 @@ class ToolUsage:
def __init__(
self,
tools_handler: Optional[ToolsHandler],
tools: List[CrewStructuredTool],
tools: list[CrewStructuredTool],
task: Optional[Task],
function_calling_llm: Any,
agent: Optional[Union["BaseAgent", "LiteAgent"]] = None,
action: Any = None,
fingerprint_context: Optional[Dict[str, str]] = None,
fingerprint_context: Optional[dict[str, str]] = None,
) -> None:
self._i18n: I18N = agent.i18n if agent else I18N()
self._printer: Printer = Printer()
@@ -100,12 +100,12 @@ class ToolUsage:
self._max_parsing_attempts = 2
self._remember_format_after_usages = 4
def parse_tool_calling(self, tool_string: str):
def parse_tool_calling(self, tool_string: str) -> Any:
"""Parse the tool string and return the tool calling."""
return self._tool_calling(tool_string)
def use(
self, calling: Union[ToolCalling, InstructorToolCalling], tool_string: str
self, calling: ToolCalling | InstructorToolCalling, tool_string: str
) -> str:
if isinstance(calling, ToolUsageErrorException):
error = calling.message
@@ -147,11 +147,25 @@ class ToolUsage:
self,
tool_string: str,
tool: CrewStructuredTool,
calling: Union[ToolCalling, InstructorToolCalling],
calling: ToolCalling | InstructorToolCalling,
) -> str:
if self._check_tool_repeated_usage(calling=calling): # type: ignore # _check_tool_repeated_usage of "ToolUsage" does not return a value (it only ever returns None)
"""Use a tool with the given calling information.
Args:
tool_string: The string representation of the tool call.
tool: The CrewStructuredTool instance to use.
calling: The tool calling information.
Returns:
The formatted result of the tool usage.
Notes:
TODO: Investigate why BaseAgent/LiteAgent don't have fingerprint attribute.
Currently using hasattr check as a workaround (lines 179-180).
"""
if self._check_tool_repeated_usage(calling=calling):
try:
result = self._i18n.errors("task_repeated_usage").format(
repeated_usage_msg = self._i18n.errors("task_repeated_usage").format(
tool_names=self.tools_names
)
self._telemetry.tool_repeated_usage(
@@ -159,8 +173,8 @@ class ToolUsage:
tool_name=tool.name,
attempts=self._run_attempts,
)
result = self._format_result(result=result) # type: ignore # "_format_result" of "ToolUsage" does not return a value (it only ever returns None)
return result # type: ignore # Fix the return type of this function
repeated_usage_result = self._format_result(result=repeated_usage_msg)
return repeated_usage_result
except Exception:
if self.task:
@@ -176,7 +190,7 @@ class ToolUsage:
"agent": self.agent,
}
if self.agent.fingerprint:
if hasattr(self.agent, "fingerprint") and self.agent.fingerprint:
event_data.update(self.agent.fingerprint)
if self.task:
event_data["task_name"] = self.task.name or self.task.description
@@ -185,12 +199,12 @@ class ToolUsage:
started_at = time.time()
from_cache = False
result = None # type: ignore
result: str | None = None
if self.tools_handler and self.tools_handler.cache:
result = self.tools_handler.cache.read(
tool=calling.tool_name, input=calling.arguments
) # type: ignore
tool=calling.tool_name, input_data=calling.arguments
)
from_cache = result is not None
available_tool = next(
@@ -229,7 +243,7 @@ class ToolUsage:
try:
acceptable_args = tool.args_schema.model_json_schema()[
"properties"
].keys() # type: ignore
].keys()
arguments = {
k: v
for k, v in calling.arguments.items()
@@ -264,19 +278,20 @@ class ToolUsage:
self._printer.print(
content=f"\n\n{error_message}\n", color="red"
)
return error # type: ignore # No return value expected
return error
if self.task:
self.task.increment_tools_errors()
return self.use(calling=calling, tool_string=tool_string) # type: ignore # No return value expected
return self.use(calling=calling, tool_string=tool_string)
if self.tools_handler:
should_cache = True
if (
hasattr(available_tool, "cache_function")
and available_tool.cache_function # type: ignore # Item "None" of "Any | None" has no attribute "cache_function"
available_tool
and hasattr(available_tool, "cache_function")
and available_tool.cache_function
):
should_cache = available_tool.cache_function( # type: ignore # Item "None" of "Any | None" has no attribute "cache_function"
should_cache = available_tool.cache_function(
calling.arguments, result
)
@@ -288,8 +303,8 @@ class ToolUsage:
tool_name=tool.name,
attempts=self._run_attempts,
)
result = self._format_result(result=result) # type: ignore # "_format_result" of "ToolUsage" does not return a value (it only ever returns None)
data = {
result = self._format_result(result=result)
data: dict[str, Any] = {
"result": result,
"tool_name": tool.name,
"tool_args": calling.arguments,
@@ -308,7 +323,7 @@ class ToolUsage:
and available_tool.result_as_answer # type: ignore # Item "None" of "Any | None" has no attribute "result_as_answer"
):
result_as_answer = available_tool.result_as_answer # type: ignore # Item "None" of "Any | None" has no attribute "result_as_answer"
data["result_as_answer"] = result_as_answer # type: ignore
data["result_as_answer"] = result_as_answer
if self.agent and hasattr(self.agent, "tools_results"):
self.agent.tools_results.append(data)
@@ -346,7 +361,7 @@ class ToolUsage:
return result
def _check_tool_repeated_usage(
self, calling: Union[ToolCalling, InstructorToolCalling]
self, calling: ToolCalling | InstructorToolCalling
) -> bool:
if not self.tools_handler:
return False
@@ -393,7 +408,7 @@ class ToolUsage:
return tool
if self.task:
self.task.increment_tools_errors()
tool_selection_data: Dict[str, Any] = {
tool_selection_data: dict[str, Any] = {
"agent_key": getattr(self.agent, "key", None) if self.agent else None,
"agent_role": getattr(self.agent, "role", None) if self.agent else None,
"tool_name": tool_name,
@@ -430,7 +445,7 @@ class ToolUsage:
def _function_calling(
self, tool_string: str
) -> Union[ToolCalling, InstructorToolCalling]:
) -> ToolCalling | InstructorToolCalling:
model = (
InstructorToolCalling
if self.function_calling_llm.supports_function_calling()
@@ -459,7 +474,7 @@ class ToolUsage:
def _original_tool_calling(
self, tool_string: str, raise_error: bool = False
) -> Union[ToolCalling, InstructorToolCalling, ToolUsageErrorException]:
) -> ToolCalling | InstructorToolCalling | ToolUsageErrorException:
tool_name = self.action.tool
tool = self._select_tool(tool_name)
try:
@@ -488,7 +503,7 @@ class ToolUsage:
def _tool_calling(
self, tool_string: str
) -> Union[ToolCalling, InstructorToolCalling, ToolUsageErrorException]:
) -> ToolCalling | InstructorToolCalling | ToolUsageErrorException:
try:
try:
return self._original_tool_calling(tool_string, raise_error=True)
@@ -505,12 +520,12 @@ class ToolUsage:
self.task.increment_tools_errors()
if self.agent and self.agent.verbose:
self._printer.print(content=f"\n\n{e}\n", color="red")
return ToolUsageErrorException( # type: ignore # Incompatible return value type (got "ToolUsageErrorException", expected "ToolCalling | InstructorToolCalling")
return ToolUsageErrorException(
f"{self._i18n.errors('tool_usage_error').format(error=e)}\nMoving on then. {self._i18n.slice('format').format(tool_names=self.tools_names)}"
)
return self._tool_calling(tool_string)
def _validate_tool_input(self, tool_input: Optional[str]) -> Dict[str, Any]:
def _validate_tool_input(self, tool_input: Optional[str]) -> dict[str, Any]:
if tool_input is None:
return {}
@@ -564,7 +579,7 @@ class ToolUsage:
# If all parsing attempts fail, raise an error
raise Exception(error_message)
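Given the json5 and json_repair imports earlier in this file, the validation above presumably falls back from strict JSON to progressively more tolerant parsing; a hedged sketch of such a chain follows (the exact order and helpers in crewai may differ).
import json
from typing import Any
import json5                        # tolerant: single quotes, trailing commas
from json_repair import repair_json
def parse_tool_input(tool_input: str) -> Any:
    try:
        return json.loads(tool_input)               # strict JSON first
    except json.JSONDecodeError:
        pass
    try:
        return json5.loads(tool_input)              # tolerant parse
    except Exception:
        return json.loads(repair_json(tool_input))  # last resort: repair, then parse
print(parse_tool_input("{'query': 'crewai'}"))      # -> {'query': 'crewai'}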
def _emit_validate_input_error(self, final_error: str):
def _emit_validate_input_error(self, final_error: str) -> None:
tool_selection_data = {
"agent_key": getattr(self.agent, "key", None) if self.agent else None,
"agent_role": getattr(self.agent, "role", None) if self.agent else None,
@@ -586,7 +601,7 @@ class ToolUsage:
def on_tool_error(
self,
tool: Any,
tool_calling: Union[ToolCalling, InstructorToolCalling],
tool_calling: ToolCalling | InstructorToolCalling,
e: Exception,
) -> None:
event_data = self._prepare_event_data(tool, tool_calling)
@@ -595,7 +610,7 @@ class ToolUsage:
def on_tool_use_finished(
self,
tool: Any,
tool_calling: Union[ToolCalling, InstructorToolCalling],
tool_calling: ToolCalling | InstructorToolCalling,
from_cache: bool,
started_at: float,
result: Any,
@@ -616,8 +631,21 @@ class ToolUsage:
crewai_event_bus.emit(self, ToolUsageFinishedEvent(**event_data))
def _prepare_event_data(
self, tool: Any, tool_calling: Union[ToolCalling, InstructorToolCalling]
) -> dict:
self, tool: Any, tool_calling: ToolCalling | InstructorToolCalling
) -> dict[str, Any]:
"""Prepare event data for tool usage events.
Args:
tool: The tool being used.
tool_calling: The tool calling information containing arguments.
Returns:
A dictionary containing event data for tool usage tracking.
Notes:
TODO: Create a better type representation for the return value,
possibly using TypedDict or a dataclass for stronger typing.
"""
event_data = {
"run_attempts": self._run_attempts,
"delegations": self.task.delegations if self.task else 0,
@@ -641,7 +669,7 @@ class ToolUsage:
return event_data
def _add_fingerprint_metadata(self, arguments: dict) -> dict:
def _add_fingerprint_metadata(self, arguments: dict[str, Any]) -> dict[str, Any]:
"""Add fingerprint metadata to tool arguments if available.
Args:

View File

@@ -1,10 +1,10 @@
from typing import List
from typing import Any, cast
from pydantic import BaseModel, Field
from crewai.utilities import Converter
from crewai.events.types.task_events import TaskEvaluationEvent
from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.task_events import TaskEvaluationEvent
from crewai.utilities import Converter
from crewai.utilities.pydantic_schema_parser import PydanticSchemaParser
from crewai.utilities.training_converter import TrainingConverter
@@ -13,23 +13,23 @@ class Entity(BaseModel):
name: str = Field(description="The name of the entity.")
type: str = Field(description="The type of the entity.")
description: str = Field(description="Description of the entity.")
relationships: List[str] = Field(description="Relationships of the entity.")
relationships: list[str] = Field(description="Relationships of the entity.")
class TaskEvaluation(BaseModel):
suggestions: List[str] = Field(
suggestions: list[str] = Field(
description="Suggestions to improve future similar tasks."
)
quality: float = Field(
description="A score from 0 to 10 evaluating on completion, quality, and overall performance, all taking into account the task description, expected output, and the result of the task."
)
entities: List[Entity] = Field(
entities: list[Entity] = Field(
description="Entities extracted from the task output."
)
class TrainingTaskEvaluation(BaseModel):
suggestions: List[str] = Field(
suggestions: list[str] = Field(
description="List of clear, actionable instructions derived from the Human Feedbacks to enhance the Agent's performance. Analyze the differences between Initial Outputs and Improved Outputs to generate specific action items for future tasks. Ensure all key and specific points from the human feedback are incorporated into these instructions."
)
quality: float = Field(
@@ -41,11 +41,11 @@ class TrainingTaskEvaluation(BaseModel):
class TaskEvaluator:
def __init__(self, original_agent):
def __init__(self, original_agent: Any) -> None:
self.llm = original_agent.llm
self.original_agent = original_agent
def evaluate(self, task, output) -> TaskEvaluation:
def evaluate(self, task: Any, output: Any) -> TaskEvaluation:
crewai_event_bus.emit(
self, TaskEvaluationEvent(evaluation_type="task_evaluation", task=task)
)
@@ -73,10 +73,10 @@ class TaskEvaluator:
instructions=instructions,
)
return converter.to_pydantic()
return cast(TaskEvaluation, converter.to_pydantic())
def evaluate_training_data(
self, training_data: dict, agent_id: str
self, training_data: dict[str, Any], agent_id: str
) -> TrainingTaskEvaluation:
"""
Evaluate the training data based on the llm output, human feedback, and improved output.
@@ -143,4 +143,4 @@ class TaskEvaluator:
)
pydantic_result = converter.to_pydantic()
return pydantic_result
return cast(TrainingTaskEvaluation, pydantic_result)
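Why the casts above: Converter.to_pydantic() is presumably annotated to return the base model type, so mypy needs cast() to narrow it to the concrete model; a standalone illustration with stand-in names (not crewai's actual classes):
from typing import cast
from pydantic import BaseModel
class EvalResult(BaseModel):      # stand-in for the concrete evaluation model
    quality: float
def to_pydantic() -> BaseModel:   # stand-in for Converter.to_pydantic
    return EvalResult(quality=9.5)
result = cast(EvalResult, to_pydantic())
print(result.quality)             # mypy sees EvalResult; runtime behavior is unchanged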

View File

@@ -2,33 +2,39 @@ import json
import os
import pickle
from datetime import datetime
from typing import Union
from typing import Any, Union
class FileHandler:
"""Handler for file operations supporting both JSON and text-based logging.
Args:
file_path (Union[bool, str]): Path to the log file or boolean flag
"""
def __init__(self, file_path: Union[bool, str]):
def __init__(self, file_path: bool | str):
self._initialize_path(file_path)
def _initialize_path(self, file_path: Union[bool, str]):
def _initialize_path(self, file_path: bool | str) -> None:
if file_path is True: # File path is boolean True
self._path = os.path.join(os.curdir, "logs.txt")
elif isinstance(file_path, str): # File path is a string
if file_path.endswith((".json", ".txt")):
self._path = file_path # No modification if the file ends with .json or .txt
self._path = (
file_path # No modification if the file ends with .json or .txt
)
else:
self._path = file_path + ".txt" # Append .txt if the file doesn't end with .json or .txt
self._path = (
file_path + ".txt"
) # Append .txt if the file doesn't end with .json or .txt
else:
raise ValueError("file_path must be a string or boolean.") # Handle the case where file_path isn't valid
def log(self, **kwargs):
raise ValueError(
"file_path must be a string or boolean."
) # Handle the case where file_path isn't valid
def log(self, **kwargs: Any) -> None:
try:
now = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
log_entry = {"timestamp": now, **kwargs}
@@ -45,20 +51,25 @@ class FileHandler:
except (json.JSONDecodeError, FileNotFoundError):
# If no valid JSON or file doesn't exist, start with an empty list
existing_data = [log_entry]
with open(self._path, "w", encoding="utf-8") as write_file:
json.dump(existing_data, write_file, indent=4)
write_file.write("\n")
else:
# Append log in plain text format
message = f"{now}: " + ", ".join([f"{key}=\"{value}\"" for key, value in kwargs.items()]) + "\n"
message = (
f"{now}: "
+ ", ".join([f'{key}="{value}"' for key, value in kwargs.items()])
+ "\n"
)
with open(self._path, "a", encoding="utf-8") as file:
file.write(message)
except Exception as e:
raise ValueError(f"Failed to log message: {str(e)}")
class PickleHandler:
def __init__(self, file_name: str) -> None:
"""
@@ -79,7 +90,7 @@ class PickleHandler:
"""
self.save({})
def save(self, data) -> None:
def save(self, data: Any) -> None:
"""
Save the data to the specified file using pickle.
@@ -89,12 +100,12 @@ class PickleHandler:
with open(self.file_path, "wb") as file:
pickle.dump(data, file)
def load(self) -> dict:
def load(self) -> Any:
"""
Load the data from the specified file using pickle.
Returns:
- dict: The data loaded from the file.
- The data loaded from the file.
"""
if not os.path.exists(self.file_path) or os.path.getsize(self.file_path) == 0:
return {} # Return an empty dictionary if the file does not exist or is empty
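A short usage sketch for the two logging modes of the FileHandler defined above (paths are hypothetical): a .json path gets structured JSON entries, while True resolves to ./logs.txt with plain-text lines.
json_handler = FileHandler("crew_logs.json")   # .json path kept as-is, JSON entries
json_handler.log(task="research", status="started")
text_handler = FileHandler(True)               # boolean True -> ./logs.txt
text_handler.log(task="research", status="completed")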

View File

@@ -1,78 +1,105 @@
import os
from typing import Any, Dict, List, Optional, Union
from typing import Any, Protocol, TypedDict, runtime_checkable
from typing_extensions import Required
from crewai.cli.constants import DEFAULT_LLM_MODEL, ENV_VARS, LITELLM_PARAMS
from crewai.llm import LLM, BaseLLM
from crewai.llm import LLM
from crewai.llms.base_llm import BaseLLM
class LLMParams(TypedDict, total=False):
"""TypedDict defining LLM parameters we extract from LLMLike objects."""
model: Required[str]
temperature: float
max_tokens: int
logprobs: int
timeout: float
api_key: str
base_url: str
api_base: str
@runtime_checkable
class LLMLike(Protocol):
"""Protocol for objects that can be converted to an LLM instance."""
model: str | None
model_name: str | None
deployment_name: str | None
temperature: float | None
max_tokens: int | None
logprobs: int | None
timeout: float | None
api_key: str | None
base_url: str | None
api_base: str | None
def create_default_llm() -> LLM:
"""Creates a default LLM instance using environment variables or fallback defaults.
Returns:
A default LLM instance configured from environment or using defaults.
Raises:
ValueError: If LLM creation fails.
"""
result = _llm_via_environment_or_fallback()
if result is None:
raise ValueError("Failed to create default LLM instance")
return result
def create_llm(
llm_value: Union[str, LLM, Any, None] = None,
) -> Optional[LLM | BaseLLM]:
llm_value: str | BaseLLM | LLMLike,
) -> BaseLLM:
"""
Creates or returns an LLM instance based on the given llm_value.
Args:
llm_value (str | BaseLLM | Any | None):
llm_value (str | BaseLLM | LLMLike):
- str: The model name (e.g., "gpt-4").
- BaseLLM: Already instantiated BaseLLM (including LLM), returned as-is.
- Any: Attempt to extract known attributes like model_name, temperature, etc.
- None: Use environment-based or fallback default model.
- LLMLike: Object with LLM-compatible attributes (model_name, temperature, etc.)
Returns:
A BaseLLM instance if successful, or None if something fails.
"""
# 1) If llm_value is already a BaseLLM or LLM object, return it directly
if isinstance(llm_value, LLM) or isinstance(llm_value, BaseLLM):
A BaseLLM instance.
Raises:
ValueError: If LLM creation fails.
"""
if isinstance(llm_value, BaseLLM):
return llm_value
# 2) If llm_value is a string (model name)
if isinstance(llm_value, str):
try:
created_llm = LLM(model=llm_value)
return created_llm
except Exception as e:
print(f"Failed to instantiate LLM with model='{llm_value}': {e}")
return None
return LLM(model=llm_value)
# 3) If llm_value is None, parse environment variables or use default
if llm_value is None:
return _llm_via_environment_or_fallback()
# 4) Otherwise, attempt to extract relevant attributes from an unknown object
try:
# Extract attributes with explicit types
model = (
obj_attrs = set(dir(llm_value))
llm_kwargs = {
param: getattr(llm_value, param)
for param in LLMParams.__annotations__
if param != "model"
and param in obj_attrs
and getattr(llm_value, param) is not None
}
llm_kwargs["model"] = (
getattr(llm_value, "model", None)
or getattr(llm_value, "model_name", None)
or getattr(llm_value, "deployment_name", None)
or str(llm_value)
)
temperature: Optional[float] = getattr(llm_value, "temperature", None)
max_tokens: Optional[int] = getattr(llm_value, "max_tokens", None)
logprobs: Optional[int] = getattr(llm_value, "logprobs", None)
timeout: Optional[float] = getattr(llm_value, "timeout", None)
api_key: Optional[str] = getattr(llm_value, "api_key", None)
base_url: Optional[str] = getattr(llm_value, "base_url", None)
api_base: Optional[str] = getattr(llm_value, "api_base", None)
created_llm = LLM(
model=model,
temperature=temperature,
max_tokens=max_tokens,
logprobs=logprobs,
timeout=timeout,
api_key=api_key,
base_url=base_url,
api_base=api_base,
)
return created_llm
return LLM(**llm_kwargs)
except Exception as e:
print(f"Error instantiating LLM from unknown object type: {e}")
return None
raise ValueError(f"Error instantiating LLM from object: {e}")
def _llm_via_environment_or_fallback() -> Optional[LLM]:
def _llm_via_environment_or_fallback() -> LLM | None:
"""
Helper: when llm_value is None, build an LLM from environment variables or fall back to the default model.
"""
@@ -85,24 +112,24 @@ def _llm_via_environment_or_fallback() -> Optional[LLM]:
# Initialize parameters with correct types
model: str = model_name
temperature: Optional[float] = None
max_tokens: Optional[int] = None
max_completion_tokens: Optional[int] = None
logprobs: Optional[int] = None
timeout: Optional[float] = None
api_key: Optional[str] = None
base_url: Optional[str] = None
api_version: Optional[str] = None
presence_penalty: Optional[float] = None
frequency_penalty: Optional[float] = None
top_p: Optional[float] = None
n: Optional[int] = None
stop: Optional[Union[str, List[str]]] = None
logit_bias: Optional[Dict[int, float]] = None
response_format: Optional[Dict[str, Any]] = None
seed: Optional[int] = None
top_logprobs: Optional[int] = None
callbacks: List[Any] = []
temperature: float | None = None
max_tokens: int | None = None
max_completion_tokens: int | None = None
logprobs: int | None = None
timeout: float | None = None
api_key: str | None = None
base_url: str | None = None
api_version: str | None = None
presence_penalty: float | None = None
frequency_penalty: float | None = None
top_p: float | None = None
n: int | None = None
stop: str | list[str] | None = None
logit_bias: dict[int, float] | None = None
response_format: dict[str, Any] | None = None
seed: int | None = None
top_logprobs: int | None = None
callbacks: list[Any] = []
# Optional base URL from env
base_url = (
@@ -120,7 +147,7 @@ def _llm_via_environment_or_fallback() -> Optional[LLM]:
base_url = api_base
# Initialize llm_params dictionary
llm_params: Dict[str, Any] = {
llm_params: dict[str, Any] = {
"model": model,
"temperature": temperature,
"max_tokens": max_tokens,

View File

@@ -10,7 +10,7 @@ class Logger(BaseModel):
_printer: Printer = PrivateAttr(default_factory=Printer)
default_color: str = Field(default="bold_yellow")
def log(self, level, message, color=None):
def log(self, level: str, message: str, color: str | None = None) -> None:
if color is None:
color = self.default_color
if self.verbose:

View File

@@ -20,18 +20,18 @@ class RPMController(BaseModel):
_shutdown_flag: bool = PrivateAttr(default=False)
@model_validator(mode="after")
def reset_counter(self):
def reset_counter(self) -> "RPMController":
if self.max_rpm is not None:
if not self._shutdown_flag:
self._lock = threading.Lock()
self._reset_request_count()
return self
def check_or_wait(self):
def check_or_wait(self) -> bool:
if self.max_rpm is None:
return True
def _check_and_increment():
def _check_and_increment() -> bool:
if self.max_rpm is not None and self._current_rpm < self.max_rpm:
self._current_rpm += 1
return True
@@ -50,17 +50,17 @@ class RPMController(BaseModel):
else:
return _check_and_increment()
def stop_rpm_counter(self):
def stop_rpm_counter(self) -> None:
if self._timer:
self._timer.cancel()
self._timer = None
def _wait_for_next_minute(self):
def _wait_for_next_minute(self) -> None:
time.sleep(60)
self._current_rpm = 0
def _reset_request_count(self):
def _reset():
def _reset_request_count(self) -> None:
def _reset() -> None:
self._current_rpm = 0
if not self._shutdown_flag:
self._timer = threading.Timer(60.0, self._reset_request_count)
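A minimal usage sketch for RPMController as typed above: check_or_wait() returns True once a request is admitted and blocks when max_rpm is reached, until the per-minute timer resets the counter (so the third call below waits up to a minute).
from crewai.utilities import RPMController
controller = RPMController(max_rpm=2)
for _ in range(3):
    admitted = controller.check_or_wait()   # third call blocks until the reset
controller.stop_rpm_counter()               # cancel the pending reset timer on shutdown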

View File

@@ -1,5 +1,5 @@
from datetime import datetime
from typing import Any, Dict, List, Optional
from typing import Any, Optional
from pydantic import BaseModel, Field
@@ -16,10 +16,10 @@ class ExecutionLog(BaseModel):
task_id: str
expected_output: Optional[str] = None
output: Dict[str, Any]
output: dict[str, Any]
timestamp: datetime = Field(default_factory=datetime.now)
task_index: int
inputs: Dict[str, Any] = Field(default_factory=dict)
inputs: dict[str, Any] = Field(default_factory=dict)
was_replayed: bool = False
def __getitem__(self, key: str) -> Any:
@@ -33,7 +33,7 @@ class TaskOutputStorageHandler:
def __init__(self) -> None:
self.storage = KickoffTaskOutputsSQLiteStorage()
def update(self, task_index: int, log: Dict[str, Any]):
def update(self, task_index: int, log: dict[str, Any]) -> None:
saved_outputs = self.load()
if saved_outputs is None:
raise ValueError("Logs cannot be None")
@@ -56,16 +56,16 @@ class TaskOutputStorageHandler:
def add(
self,
task: Task,
output: Dict[str, Any],
output: dict[str, Any],
task_index: int,
inputs: Dict[str, Any] | None = None,
inputs: dict[str, Any] | None = None,
was_replayed: bool = False,
):
) -> None:
inputs = inputs or {}
self.storage.add(task, output, task_index, was_replayed, inputs)
def reset(self):
def reset(self) -> None:
self.storage.delete_all()
def load(self) -> Optional[List[Dict[str, Any]]]:
def load(self) -> Optional[list[dict[str, Any]]]:
return self.storage.load()

View File

@@ -7,21 +7,21 @@ from unittest.mock import MagicMock, patch
import pytest
from crewai import Agent, Crew, Task
from crewai.agents.cache import CacheHandler
from crewai.agents.cache.cache_handler import CacheHandler
from crewai.agents.crew_agent_executor import AgentFinish, CrewAgentExecutor
from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.tool_usage_events import ToolUsageFinishedEvent
from crewai.knowledge.knowledge import Knowledge
from crewai.knowledge.knowledge_config import KnowledgeConfig
from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource
from crewai.knowledge.source.string_knowledge_source import StringKnowledgeSource
from crewai.llm import LLM
from crewai.process import Process
from crewai.tools import tool
from crewai.tools.tool_calling import InstructorToolCalling
from crewai.tools.tool_usage import ToolUsage
from crewai.utilities import RPMController
from crewai.utilities.errors import AgentRepositoryError
from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.tool_usage_events import ToolUsageFinishedEvent
from crewai.process import Process
def test_agent_llm_creation_with_env_vars():
@@ -287,8 +287,8 @@ def test_cache_hitting():
output = agent.execute_task(task1)
output = agent.execute_task(task2)
assert cache_handler._cache == {
"multiplier-{'first_number': 2, 'second_number': 6}": 12,
"multiplier-{'first_number': 3, 'second_number': 3}": 9,
'multiplier-{"first_number": 2, "second_number": 6}': 12,
'multiplier-{"first_number": 3, "second_number": 3}': 9,
}
task = Task(
@@ -300,9 +300,9 @@ def test_cache_hitting():
assert output == "36"
assert cache_handler._cache == {
"multiplier-{'first_number': 2, 'second_number': 6}": 12,
"multiplier-{'first_number': 3, 'second_number': 3}": 9,
"multiplier-{'first_number': 12, 'second_number': 3}": 36,
'multiplier-{"first_number": 2, "second_number": 6}': 12,
'multiplier-{"first_number": 3, "second_number": 3}': 9,
'multiplier-{"first_number": 12, "second_number": 3}': 36,
}
received_events = []
@@ -322,7 +322,7 @@ def test_cache_hitting():
output = agent.execute_task(task)
assert output == "0"
read.assert_called_with(
tool="multiplier", input={"first_number": 2, "second_number": 6}
tool="multiplier", input_data={"first_number": 2, "second_number": 6}
)
assert len(received_events) == 1
assert isinstance(received_events[0], ToolUsageFinishedEvent)
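Why the expected keys and kwargs changed above: per this PR, cache keys are now built from json.dumps of the inputs (double-quoted) rather than str() (repr-style single quotes), and read/add take input_data instead of input. A quick illustration of the quoting difference:
import json
args = {"first_number": 2, "second_number": 6}
print(f"multiplier-{args}")              # old key: {'first_number': 2, 'second_number': 6}
print(f"multiplier-{json.dumps(args)}")  # new key: {"first_number": 2, "second_number": 6}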
@@ -559,9 +559,9 @@ def test_agent_repeated_tool_usage(capsys):
expected_message = (
"I tried reusing the same input, I must stop using this action input."
)
assert (
expected_message in output
), f"Expected message not found in output. Output was: {output}"
assert expected_message in output, (
f"Expected message not found in output. Output was: {output}"
)
@pytest.mark.vcr(filter_headers=["authorization"])
@@ -602,9 +602,9 @@ def test_agent_repeated_tool_usage_check_even_with_disabled_cache(capsys):
has_max_iterations = "maximum iterations reached" in output_lower
has_final_answer = "final answer" in output_lower or "42" in captured.out
assert (
has_repeated_usage_message or (has_max_iterations and has_final_answer)
), f"Expected repeated tool usage handling or proper max iteration handling. Output was: {captured.out[:500]}..."
assert has_repeated_usage_message or (has_max_iterations and has_final_answer), (
f"Expected repeated tool usage handling or proper max iteration handling. Output was: {captured.out[:500]}..."
)
@pytest.mark.vcr(filter_headers=["authorization"])
@@ -1116,7 +1116,9 @@ def test_not_using_system_prompt():
use_system_prompt=False,
)
agent.create_agent_executor()
# Create a dummy task for testing
task = Task(description="Test task", expected_output="Test output", agent=agent)
agent.create_agent_executor(task)
assert not agent.agent_executor.prompt.get("user")
assert not agent.agent_executor.prompt.get("system")
@@ -1128,7 +1130,9 @@ def test_using_system_prompt():
backstory="I am the master of {role}",
)
agent.create_agent_executor()
# Create a dummy task for testing
task = Task(description="Test task", expected_output="Test output", agent=agent)
agent.create_agent_executor(task)
assert agent.agent_executor.prompt.get("user")
assert agent.agent_executor.prompt.get("system")
@@ -2312,9 +2316,9 @@ def test_agent_from_repository(mock_get_agent, mock_get_auth_token):
# Mock embedchain initialization to prevent race conditions in parallel CI execution
with patch("embedchain.client.Client.setup"):
from crewai_tools import (
SerperDevTool,
FileReadTool,
EnterpriseActionTool,
FileReadTool,
SerperDevTool,
)
mock_get_response = MagicMock()

View File

@@ -1,15 +1,17 @@
import pytest
from unittest.mock import ANY
from collections import defaultdict
from unittest.mock import ANY
import pytest
from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.memory_events import (
MemoryQueryCompletedEvent,
MemoryQueryStartedEvent,
MemorySaveCompletedEvent,
MemorySaveStartedEvent,
)
from crewai.memory.long_term.long_term_memory import LongTermMemory
from crewai.memory.long_term.long_term_memory_item import LongTermMemoryItem
from crewai.events.types.memory_events import (
MemorySaveStartedEvent,
MemorySaveCompletedEvent,
MemoryQueryStartedEvent,
MemoryQueryCompletedEvent,
)
@pytest.fixture
@@ -98,7 +100,7 @@ def test_long_term_memory_search_events(long_term_memory):
test_query = "test query"
long_term_memory.search(test_query, latest_n=5)
long_term_memory.search(test_query, limit=5)
assert len(events["MemoryQueryStartedEvent"]) == 1
assert len(events["MemoryQueryCompletedEvent"]) == 1
@@ -151,7 +153,7 @@ def test_save_and_search(long_term_memory):
metadata={"task": "test_task", "quality": 0.5},
)
long_term_memory.save(memory)
find = long_term_memory.search("test_task", latest_n=5)[0]
find = long_term_memory.search("test_task", limit=5)[0]
assert find["score"] == 0.5
assert find["datetime"] == "test_datetime"
assert find["metadata"]["agent"] == "test_agent"
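The parameter rename exercised above, as a one-line sketch: LongTermMemory.search now takes limit (formerly latest_n) for the number of results to return.
from crewai.memory.long_term.long_term_memory import LongTermMemory
memory = LongTermMemory()
results = memory.search("test_task", limit=5)   # was: latest_n=5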

View File

@@ -2,23 +2,41 @@
import hashlib
import json
from collections import defaultdict
from concurrent.futures import Future
from unittest import mock
from unittest.mock import ANY, MagicMock, patch
from collections import defaultdict
import pydantic_core
import pytest
from crewai.agent import Agent
from crewai.agents import CacheHandler
from crewai.agents.cache.cache_handler import CacheHandler
from crewai.crew import Crew
from crewai.crews.crew_output import CrewOutput
from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.crew_events import (
CrewTestCompletedEvent,
CrewTestStartedEvent,
CrewTrainCompletedEvent,
CrewTrainStartedEvent,
)
from crewai.events.types.memory_events import (
MemoryQueryCompletedEvent,
MemoryQueryFailedEvent,
MemoryQueryStartedEvent,
MemoryRetrievalCompletedEvent,
MemoryRetrievalStartedEvent,
MemorySaveCompletedEvent,
MemorySaveFailedEvent,
MemorySaveStartedEvent,
)
from crewai.flow import Flow, start
from crewai.knowledge.knowledge import Knowledge
from crewai.knowledge.source.string_knowledge_source import StringKnowledgeSource
from crewai.llm import LLM
from crewai.memory.contextual.contextual_memory import ContextualMemory
from crewai.memory.external.external_memory import ExternalMemory
from crewai.memory.long_term.long_term_memory import LongTermMemory
from crewai.memory.short_term.short_term_memory import ShortTermMemory
from crewai.process import Process
@@ -27,28 +45,9 @@ from crewai.tasks.conditional_task import ConditionalTask
from crewai.tasks.output_format import OutputFormat
from crewai.tasks.task_output import TaskOutput
from crewai.types.usage_metrics import UsageMetrics
from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.crew_events import (
CrewTestCompletedEvent,
CrewTestStartedEvent,
CrewTrainCompletedEvent,
CrewTrainStartedEvent,
)
from crewai.utilities.rpm_controller import RPMController
from crewai.utilities.task_output_storage_handler import TaskOutputStorageHandler
from crewai.events.types.memory_events import (
MemorySaveStartedEvent,
MemorySaveCompletedEvent,
MemorySaveFailedEvent,
MemoryQueryStartedEvent,
MemoryQueryCompletedEvent,
MemoryQueryFailedEvent,
MemoryRetrievalStartedEvent,
MemoryRetrievalCompletedEvent,
)
from crewai.memory.external.external_memory import ExternalMemory
@pytest.fixture
def ceo():
@@ -570,8 +569,6 @@ def test_crew_with_delegating_agents(ceo, writer):
@pytest.mark.vcr(filter_headers=["authorization"])
def test_crew_with_delegating_agents_should_not_override_task_tools(ceo, writer):
from typing import Type
from pydantic import BaseModel, Field
from crewai.tools import BaseTool
@@ -584,7 +581,7 @@ def test_crew_with_delegating_agents_should_not_override_task_tools(ceo, writer)
class TestTool(BaseTool):
name: str = "Test Tool"
description: str = "A test tool that just returns the input"
args_schema: Type[BaseModel] = TestToolInput
args_schema: type[BaseModel] = TestToolInput
def _run(self, query: str) -> str:
return f"Processed: {query}"
@@ -622,18 +619,16 @@ def test_crew_with_delegating_agents_should_not_override_task_tools(ceo, writer)
_, kwargs = mock_execute_sync.call_args
tools = kwargs["tools"]
assert any(
isinstance(tool, TestTool) for tool in tools
), "TestTool should be present"
assert any(
"delegate" in tool.name.lower() for tool in tools
), "Delegation tool should be present"
assert any(isinstance(tool, TestTool) for tool in tools), (
"TestTool should be present"
)
assert any("delegate" in tool.name.lower() for tool in tools), (
"Delegation tool should be present"
)
@pytest.mark.vcr(filter_headers=["authorization"])
def test_crew_with_delegating_agents_should_not_override_agent_tools(ceo, writer):
from typing import Type
from pydantic import BaseModel, Field
from crewai.tools import BaseTool
@@ -646,7 +641,7 @@ def test_crew_with_delegating_agents_should_not_override_agent_tools(ceo, writer
class TestTool(BaseTool):
name: str = "Test Tool"
description: str = "A test tool that just returns the input"
args_schema: Type[BaseModel] = TestToolInput
args_schema: type[BaseModel] = TestToolInput
def _run(self, query: str) -> str:
return f"Processed: {query}"
@@ -686,18 +681,16 @@ def test_crew_with_delegating_agents_should_not_override_agent_tools(ceo, writer
_, kwargs = mock_execute_sync.call_args
tools = kwargs["tools"]
assert any(
isinstance(tool, TestTool) for tool in new_ceo.tools
), "TestTool should be present"
assert any(
"delegate" in tool.name.lower() for tool in tools
), "Delegation tool should be present"
assert any(isinstance(tool, TestTool) for tool in new_ceo.tools), (
"TestTool should be present"
)
assert any("delegate" in tool.name.lower() for tool in tools), (
"Delegation tool should be present"
)
@pytest.mark.vcr(filter_headers=["authorization"])
def test_task_tools_override_agent_tools(researcher):
from typing import Type
from pydantic import BaseModel, Field
from crewai.tools import BaseTool
@@ -710,7 +703,7 @@ def test_task_tools_override_agent_tools(researcher):
class TestTool(BaseTool):
name: str = "Test Tool"
description: str = "A test tool that just returns the input"
args_schema: Type[BaseModel] = TestToolInput
args_schema: type[BaseModel] = TestToolInput
def _run(self, query: str) -> str:
return f"Processed: {query}"
@@ -718,7 +711,7 @@ def test_task_tools_override_agent_tools(researcher):
class AnotherTestTool(BaseTool):
name: str = "Another Test Tool"
description: str = "Another test tool"
args_schema: Type[BaseModel] = TestToolInput
args_schema: type[BaseModel] = TestToolInput
def _run(self, query: str) -> str:
return f"Another processed: {query}"
@@ -754,7 +747,6 @@ def test_task_tools_override_agent_tools_with_allow_delegation(researcher, write
"""
Test that task tools override agent tools while preserving delegation tools when allow_delegation=True
"""
from typing import Type
from pydantic import BaseModel, Field
@@ -766,7 +758,7 @@ def test_task_tools_override_agent_tools_with_allow_delegation(researcher, write
class TestTool(BaseTool):
name: str = "Test Tool"
description: str = "A test tool that just returns the input"
args_schema: Type[BaseModel] = TestToolInput
args_schema: type[BaseModel] = TestToolInput
def _run(self, query: str) -> str:
return f"Processed: {query}"
@@ -774,7 +766,7 @@ def test_task_tools_override_agent_tools_with_allow_delegation(researcher, write
class AnotherTestTool(BaseTool):
name: str = "Another Test Tool"
description: str = "Another test tool"
args_schema: Type[BaseModel] = TestToolInput
args_schema: type[BaseModel] = TestToolInput
def _run(self, query: str) -> str:
return f"Another processed: {query}"
@@ -815,17 +807,17 @@ def test_task_tools_override_agent_tools_with_allow_delegation(researcher, write
used_tools = kwargs["tools"]
# Confirm AnotherTestTool is present but TestTool is not
assert any(
isinstance(tool, AnotherTestTool) for tool in used_tools
), "AnotherTestTool should be present"
assert not any(
isinstance(tool, TestTool) for tool in used_tools
), "TestTool should not be present among used tools"
assert any(isinstance(tool, AnotherTestTool) for tool in used_tools), (
"AnotherTestTool should be present"
)
assert not any(isinstance(tool, TestTool) for tool in used_tools), (
"TestTool should not be present among used tools"
)
# Confirm delegation tool(s) are present
assert any(
"delegate" in tool.name.lower() for tool in used_tools
), "Delegation tool should be present"
assert any("delegate" in tool.name.lower() for tool in used_tools), (
"Delegation tool should be present"
)
# Finally, make sure the agent's original tools remain unchanged
assert len(researcher_with_delegation.tools) == 1
@@ -912,13 +904,13 @@ def test_cache_hitting_between_agents(researcher, writer, ceo):
crew.kickoff()
assert read.call_count == 2, "read was not called exactly twice"
# Filter the mock calls to only include the ones with 'tool' and 'input' keywords
# Filter the mock calls to only include the ones with 'tool' and 'input_data' keywords
cache_calls = [
call
for call in read.call_args_list
if len(call.kwargs) == 2
and "tool" in call.kwargs
and "input" in call.kwargs
and "input_data" in call.kwargs
]
# Check if we have the expected number of cache calls
@@ -926,12 +918,12 @@ def test_cache_hitting_between_agents(researcher, writer, ceo):
# Check if both calls were made with the expected arguments
expected_call = call(
tool="multiplier", input={"first_number": 2, "second_number": 6}
tool="multiplier", input_data={"first_number": 2, "second_number": 6}
)
assert cache_calls[0] == expected_call, f"First call mismatch: {cache_calls[0]}"
assert (
cache_calls[1] == expected_call
), f"Second call mismatch: {cache_calls[1]}"
assert cache_calls[1] == expected_call, (
f"Second call mismatch: {cache_calls[1]}"
)
@pytest.mark.vcr(filter_headers=["authorization"])
@@ -1674,9 +1666,9 @@ def test_code_execution_flag_adds_code_tool_upon_kickoff():
# Verify that exactly one tool was used and it was a CodeInterpreterTool
assert len(used_tools) == 1, "Should have exactly one tool"
assert isinstance(
used_tools[0], CodeInterpreterTool
), "Tool should be CodeInterpreterTool"
assert isinstance(used_tools[0], CodeInterpreterTool), (
"Tool should be CodeInterpreterTool"
)
@pytest.mark.vcr(filter_headers=["authorization"])
@@ -2237,7 +2229,7 @@ def test_tools_with_custom_caching():
# Verify that one of those calls was with the even number that should be cached
add_to_cache.assert_any_call(
tool="multiplcation_tool",
input={"first_number": 2, "second_number": 6},
input_data={"first_number": 2, "second_number": 6},
output=12,
)
@@ -3815,16 +3807,15 @@ def test_fetch_inputs():
expected_placeholders = {"role_detail", "topic", "field"}
actual_placeholders = crew.fetch_inputs()
assert (
actual_placeholders == expected_placeholders
), f"Expected {expected_placeholders}, but got {actual_placeholders}"
assert actual_placeholders == expected_placeholders, (
f"Expected {expected_placeholders}, but got {actual_placeholders}"
)
def test_task_tools_preserve_code_execution_tools():
"""
Test that task tools don't override code execution tools when allow_code_execution=True
"""
from typing import Type
# Mock embedchain initialization to prevent race conditions in parallel CI execution
with patch("embedchain.client.Client.setup"):
@@ -3841,7 +3832,7 @@ def test_task_tools_preserve_code_execution_tools():
class TestTool(BaseTool):
name: str = "Test Tool"
description: str = "A test tool that just returns the input"
args_schema: Type[BaseModel] = TestToolInput
args_schema: type[BaseModel] = TestToolInput
def _run(self, query: str) -> str:
return f"Processed: {query}"
@@ -3892,20 +3883,20 @@ def test_task_tools_preserve_code_execution_tools():
used_tools = kwargs["tools"]
# Verify all expected tools are present
assert any(
isinstance(tool, TestTool) for tool in used_tools
), "Task's TestTool should be present"
assert any(
isinstance(tool, CodeInterpreterTool) for tool in used_tools
), "CodeInterpreterTool should be present"
assert any(
"delegate" in tool.name.lower() for tool in used_tools
), "Delegation tool should be present"
assert any(isinstance(tool, TestTool) for tool in used_tools), (
"Task's TestTool should be present"
)
assert any(isinstance(tool, CodeInterpreterTool) for tool in used_tools), (
"CodeInterpreterTool should be present"
)
assert any("delegate" in tool.name.lower() for tool in used_tools), (
"Delegation tool should be present"
)
# Verify the total number of tools (TestTool + CodeInterpreter + 2 delegation tools)
assert (
len(used_tools) == 4
), "Should have TestTool, CodeInterpreter, and 2 delegation tools"
assert len(used_tools) == 4, (
"Should have TestTool, CodeInterpreter, and 2 delegation tools"
)
@pytest.mark.vcr(filter_headers=["authorization"])
@@ -3949,9 +3940,9 @@ def test_multimodal_flag_adds_multimodal_tools():
used_tools = kwargs["tools"]
# Check that the multimodal tool was added
assert any(
isinstance(tool, AddImageTool) for tool in used_tools
), "AddImageTool should be present when agent is multimodal"
assert any(isinstance(tool, AddImageTool) for tool in used_tools), (
"AddImageTool should be present when agent is multimodal"
)
# Verify we have exactly one tool (just the AddImageTool)
assert len(used_tools) == 1, "Should only have the AddImageTool"
@@ -4215,9 +4206,9 @@ def test_crew_guardrail_feedback_in_context():
assert len(execution_contexts) > 1, "Task should have been executed multiple times"
# Verify that the second execution included the guardrail feedback
assert (
"Output must contain the keyword 'IMPORTANT'" in execution_contexts[1]
), "Guardrail feedback should be included in retry context"
assert "Output must contain the keyword 'IMPORTANT'" in execution_contexts[1], (
"Guardrail feedback should be included in retry context"
)
# Verify final output meets guardrail requirements
assert "IMPORTANT" in result.raw, "Final output should contain required keyword"
@@ -4232,13 +4223,11 @@ def test_before_kickoff_callback():
@CrewBase
class TestCrewClass:
from typing import List
from crewai.agents.agent_builder.base_agent import BaseAgent
from crewai.project import CrewBase, agent, before_kickoff, crew, task
agents: List[BaseAgent]
tasks: List[Task]
agents: list[BaseAgent]
tasks: list[Task]
agents_config = None
tasks_config = None
@@ -4433,46 +4422,46 @@ def test_crew_copy_with_memory():
try:
crew_copy = crew.copy()
assert hasattr(
crew_copy, "_short_term_memory"
), "Copied crew should have _short_term_memory"
assert (
crew_copy._short_term_memory is not None
), "Copied _short_term_memory should not be None"
assert (
id(crew_copy._short_term_memory) != original_short_term_id
), "Copied _short_term_memory should be a new object"
assert hasattr(crew_copy, "_short_term_memory"), (
"Copied crew should have _short_term_memory"
)
assert crew_copy._short_term_memory is not None, (
"Copied _short_term_memory should not be None"
)
assert id(crew_copy._short_term_memory) != original_short_term_id, (
"Copied _short_term_memory should be a new object"
)
assert hasattr(
crew_copy, "_long_term_memory"
), "Copied crew should have _long_term_memory"
assert (
crew_copy._long_term_memory is not None
), "Copied _long_term_memory should not be None"
assert (
id(crew_copy._long_term_memory) != original_long_term_id
), "Copied _long_term_memory should be a new object"
assert hasattr(crew_copy, "_long_term_memory"), (
"Copied crew should have _long_term_memory"
)
assert crew_copy._long_term_memory is not None, (
"Copied _long_term_memory should not be None"
)
assert id(crew_copy._long_term_memory) != original_long_term_id, (
"Copied _long_term_memory should be a new object"
)
assert hasattr(
crew_copy, "_entity_memory"
), "Copied crew should have _entity_memory"
assert (
crew_copy._entity_memory is not None
), "Copied _entity_memory should not be None"
assert (
id(crew_copy._entity_memory) != original_entity_id
), "Copied _entity_memory should be a new object"
assert hasattr(crew_copy, "_entity_memory"), (
"Copied crew should have _entity_memory"
)
assert crew_copy._entity_memory is not None, (
"Copied _entity_memory should not be None"
)
assert id(crew_copy._entity_memory) != original_entity_id, (
"Copied _entity_memory should be a new object"
)
if original_external_id:
assert hasattr(
crew_copy, "_external_memory"
), "Copied crew should have _external_memory"
assert (
crew_copy._external_memory is not None
), "Copied _external_memory should not be None"
assert (
id(crew_copy._external_memory) != original_external_id
), "Copied _external_memory should be a new object"
assert hasattr(crew_copy, "_external_memory"), (
"Copied crew should have _external_memory"
)
assert crew_copy._external_memory is not None, (
"Copied _external_memory should not be None"
)
assert id(crew_copy._external_memory) != original_external_id, (
"Copied _external_memory should be a new object"
)
else:
assert (
not hasattr(crew_copy, "_external_memory")

View File

@@ -7,10 +7,8 @@ from pydantic import Field
from crewai.agent import Agent
from crewai.agents.crew_agent_executor import CrewAgentExecutor
from crewai.crew import Crew
from crewai.flow.flow import Flow, listen, start
from crewai.llm import LLM
from crewai.task import Task
from crewai.tools.base_tool import BaseTool
from crewai.events.event_bus import crewai_event_bus
from crewai.events.event_listener import EventListener
from crewai.events.types.agent_events import (
AgentExecutionCompletedEvent,
AgentExecutionErrorEvent,
@@ -24,9 +22,6 @@ from crewai.events.types.crew_events import (
CrewTestResultEvent,
CrewTestStartedEvent,
)
from crewai.events.event_bus import crewai_event_bus
from crewai.events.event_listener import EventListener
from crewai.events.types.tool_usage_events import ToolUsageFinishedEvent
from crewai.events.types.flow_events import (
FlowCreatedEvent,
FlowFinishedEvent,
@@ -47,7 +42,12 @@ from crewai.events.types.task_events import (
)
from crewai.events.types.tool_usage_events import (
ToolUsageErrorEvent,
ToolUsageFinishedEvent,
)
from crewai.flow.flow import Flow, listen, start
from crewai.llm import LLM
from crewai.task import Task
from crewai.tools.base_tool import BaseTool
@pytest.fixture(scope="module")
@@ -327,9 +327,9 @@ def test_agent_emits_execution_error_event(base_agent, base_task):
error_message = "Error happening while sending prompt to model."
base_agent.max_retry_limit = 0
with patch.object(
CrewAgentExecutor, "invoke", wraps=base_agent.agent_executor.invoke
) as invoke_mock:
# Patch the invoke method on the CrewAgentExecutor class directly
with patch.object(CrewAgentExecutor, "invoke") as invoke_mock:
invoke_mock.side_effect = Exception(error_message)
with pytest.raises(Exception):