mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-03 21:28:29 +00:00
Compare commits
36 Commits
0.193.0
...
gl/fix/cac
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
af84ba2272 | ||
|
|
e93d597721 | ||
|
|
a414e7f2a7 | ||
|
|
fbcd8bcd83 | ||
|
|
5f776bbb0a | ||
|
|
909b2fd0ef | ||
|
|
929f9dadb4 | ||
|
|
4ef4632a8c | ||
|
|
c246df3cb2 | ||
|
|
4fd40e7857 | ||
|
|
25204c6cb8 | ||
|
|
b44776c367 | ||
|
|
843801f554 | ||
|
|
2faa13ddcb | ||
|
|
e385b45667 | ||
|
|
f03567d463 | ||
|
|
e9f4ac070b | ||
|
|
bcee792390 | ||
|
|
221bfcccce | ||
|
|
4812986f58 | ||
|
|
23c60befd8 | ||
|
|
8dd3493e9c | ||
|
|
9306d889a7 | ||
|
|
8354cdf061 | ||
|
|
2ba48dd82a | ||
|
|
0bab041531 | ||
|
|
eed2ffde5f | ||
|
|
b6e7311d2d | ||
|
|
90ca02b9dc | ||
|
|
06d5c3f170 | ||
|
|
b94fbd3d3a | ||
|
|
43880b49a6 | ||
|
|
bdfc38ba32 | ||
|
|
94029017c3 | ||
|
|
89df777887 | ||
|
|
d1fbf24d9e |
@@ -11,4 +11,6 @@ repos:
|
||||
rev: v1.17.1
|
||||
hooks:
|
||||
- id: mypy
|
||||
args: ["--config-file", "pyproject.toml"]
|
||||
args: ["--strict", "--exclude", "src/crewai/cli/templates"]
|
||||
files: ^src/
|
||||
exclude: ^tests/
|
||||
|
||||
@@ -1,29 +1,50 @@
|
||||
import shutil
|
||||
import subprocess
|
||||
import time
|
||||
from collections.abc import Callable, Sequence
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
Dict,
|
||||
List,
|
||||
Literal,
|
||||
Optional,
|
||||
Sequence,
|
||||
Tuple,
|
||||
Type,
|
||||
Union,
|
||||
)
|
||||
|
||||
from pydantic import Field, InstanceOf, PrivateAttr, model_validator
|
||||
from pydantic import (
|
||||
BeforeValidator,
|
||||
Field,
|
||||
InstanceOf,
|
||||
PrivateAttr,
|
||||
computed_field,
|
||||
field_validator,
|
||||
model_validator,
|
||||
)
|
||||
from typing_extensions import Self
|
||||
|
||||
from crewai.agents import CacheHandler
|
||||
from crewai.agents.agent_builder.base_agent import BaseAgent
|
||||
from crewai.agents.cache.cache_handler import CacheHandler
|
||||
from crewai.agents.crew_agent_executor import CrewAgentExecutor
|
||||
from crewai.events.event_bus import crewai_event_bus
|
||||
from crewai.events.types.agent_events import (
|
||||
AgentExecutionCompletedEvent,
|
||||
AgentExecutionErrorEvent,
|
||||
AgentExecutionStartedEvent,
|
||||
)
|
||||
from crewai.events.types.knowledge_events import (
|
||||
KnowledgeQueryCompletedEvent,
|
||||
KnowledgeQueryFailedEvent,
|
||||
KnowledgeQueryStartedEvent,
|
||||
KnowledgeRetrievalCompletedEvent,
|
||||
KnowledgeRetrievalStartedEvent,
|
||||
KnowledgeSearchQueryFailedEvent,
|
||||
)
|
||||
from crewai.events.types.memory_events import (
|
||||
MemoryRetrievalCompletedEvent,
|
||||
MemoryRetrievalStartedEvent,
|
||||
)
|
||||
from crewai.knowledge.knowledge import Knowledge
|
||||
from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource
|
||||
from crewai.knowledge.utils.knowledge_utils import extract_knowledge_context
|
||||
from crewai.lite_agent import LiteAgent, LiteAgentOutput
|
||||
from crewai.llm import BaseLLM
|
||||
from crewai.llms.base_llm import BaseLLM
|
||||
from crewai.memory.contextual.contextual_memory import ContextualMemory
|
||||
from crewai.security import Fingerprint
|
||||
from crewai.task import Task
|
||||
@@ -38,25 +59,7 @@ from crewai.utilities.agent_utils import (
|
||||
)
|
||||
from crewai.utilities.constants import TRAINED_AGENTS_DATA_FILE, TRAINING_DATA_FILE
|
||||
from crewai.utilities.converter import generate_model_description
|
||||
from crewai.events.types.agent_events import (
|
||||
AgentExecutionCompletedEvent,
|
||||
AgentExecutionErrorEvent,
|
||||
AgentExecutionStartedEvent,
|
||||
)
|
||||
from crewai.events.event_bus import crewai_event_bus
|
||||
from crewai.events.types.memory_events import (
|
||||
MemoryRetrievalStartedEvent,
|
||||
MemoryRetrievalCompletedEvent,
|
||||
)
|
||||
from crewai.events.types.knowledge_events import (
|
||||
KnowledgeQueryCompletedEvent,
|
||||
KnowledgeQueryFailedEvent,
|
||||
KnowledgeQueryStartedEvent,
|
||||
KnowledgeRetrievalCompletedEvent,
|
||||
KnowledgeRetrievalStartedEvent,
|
||||
KnowledgeSearchQueryFailedEvent,
|
||||
)
|
||||
from crewai.utilities.llm_utils import create_llm
|
||||
from crewai.utilities.llm_utils import create_default_llm, create_llm
|
||||
from crewai.utilities.token_counter_callback import TokenCalcHandler
|
||||
from crewai.utilities.training_handler import CrewTrainingHandler
|
||||
|
||||
@@ -87,6 +90,8 @@ class Agent(BaseAgent):
|
||||
"""
|
||||
|
||||
_times_executed: int = PrivateAttr(default=0)
|
||||
_llm: BaseLLM = PrivateAttr()
|
||||
_function_calling_llm: BaseLLM | None = PrivateAttr(default=None)
|
||||
max_execution_time: Optional[int] = Field(
|
||||
default=None,
|
||||
description="Maximum execution time for an agent to execute a task",
|
||||
@@ -101,10 +106,11 @@ class Agent(BaseAgent):
|
||||
default=True,
|
||||
description="Use system prompt for the agent.",
|
||||
)
|
||||
llm: Union[str, InstanceOf[BaseLLM], Any] = Field(
|
||||
description="Language model that will run the agent.", default=None
|
||||
llm: str | InstanceOf[BaseLLM] | None = Field(
|
||||
description="Language model that will run the agent.",
|
||||
default_factory=create_default_llm,
|
||||
)
|
||||
function_calling_llm: Optional[Union[str, InstanceOf[BaseLLM], Any]] = Field(
|
||||
function_calling_llm: str | InstanceOf[BaseLLM] | None = Field(
|
||||
description="Language model that will run the agent.", default=None
|
||||
)
|
||||
system_template: Optional[str] = Field(
|
||||
@@ -151,7 +157,7 @@ class Agent(BaseAgent):
|
||||
default=None,
|
||||
description="Maximum number of reasoning attempts before executing the task. If None, will try until ready.",
|
||||
)
|
||||
embedder: Optional[Dict[str, Any]] = Field(
|
||||
embedder: Optional[dict[str, Any]] = Field(
|
||||
default=None,
|
||||
description="Embedder configuration for the agent.",
|
||||
)
|
||||
@@ -171,7 +177,7 @@ class Agent(BaseAgent):
|
||||
default=None,
|
||||
description="The Agent's role to be used from your repository.",
|
||||
)
|
||||
guardrail: Optional[Union[Callable[[Any], Tuple[bool, Any]], str]] = Field(
|
||||
guardrail: Optional[Callable[[Any], tuple[bool, Any]] | str] = Field(
|
||||
default=None,
|
||||
description="Function or string description of a guardrail to validate agent output",
|
||||
)
|
||||
@@ -180,20 +186,36 @@ class Agent(BaseAgent):
|
||||
)
|
||||
|
||||
@model_validator(mode="before")
|
||||
def validate_from_repository(cls, v):
|
||||
@classmethod
|
||||
def validate_from_repository(cls, v: Any) -> Any:
|
||||
if v is not None and (from_repository := v.get("from_repository")):
|
||||
return load_agent_from_repository(from_repository) | v
|
||||
return v
|
||||
|
||||
@field_validator("function_calling_llm", mode="after")
|
||||
@classmethod
|
||||
def validate_function_calling_llm(cls, v: Any) -> BaseLLM | None:
|
||||
if not v or isinstance(v, BaseLLM):
|
||||
return v
|
||||
return create_llm(v)
|
||||
|
||||
@model_validator(mode="after")
|
||||
def post_init_setup(self):
|
||||
def post_init_setup(self) -> Self:
|
||||
self.agent_ops_agent_name = self.role
|
||||
|
||||
self.llm = create_llm(self.llm)
|
||||
if self.function_calling_llm and not isinstance(
|
||||
self.function_calling_llm, BaseLLM
|
||||
):
|
||||
self.function_calling_llm = create_llm(self.function_calling_llm)
|
||||
# Validate and set the private LLM attributes
|
||||
if isinstance(self.llm, BaseLLM):
|
||||
self._llm = self.llm
|
||||
elif self.llm is None:
|
||||
self._llm = create_default_llm()
|
||||
else:
|
||||
self._llm = create_llm(self.llm)
|
||||
|
||||
if self.function_calling_llm:
|
||||
if isinstance(self.function_calling_llm, BaseLLM):
|
||||
self._function_calling_llm = self.function_calling_llm
|
||||
else:
|
||||
self._function_calling_llm = create_llm(self.function_calling_llm)
|
||||
|
||||
if not self.agent_executor:
|
||||
self._setup_agent_executor()
|
||||
@@ -203,12 +225,12 @@ class Agent(BaseAgent):
|
||||
|
||||
return self
|
||||
|
||||
def _setup_agent_executor(self):
|
||||
def _setup_agent_executor(self) -> None:
|
||||
if not self.cache_handler:
|
||||
self.cache_handler = CacheHandler()
|
||||
self.set_cache_handler(self.cache_handler)
|
||||
|
||||
def set_knowledge(self, crew_embedder: Optional[Dict[str, Any]] = None):
|
||||
def set_knowledge(self, crew_embedder: Optional[dict[str, Any]] = None) -> None:
|
||||
try:
|
||||
if self.embedder is None and crew_embedder:
|
||||
self.embedder = crew_embedder
|
||||
@@ -245,8 +267,8 @@ class Agent(BaseAgent):
|
||||
self,
|
||||
task: Task,
|
||||
context: Optional[str] = None,
|
||||
tools: Optional[List[BaseTool]] = None,
|
||||
) -> str:
|
||||
tools: Optional[list[BaseTool]] = None,
|
||||
) -> Any:
|
||||
"""Execute a task with the agent.
|
||||
|
||||
Args:
|
||||
@@ -417,7 +439,7 @@ class Agent(BaseAgent):
|
||||
)
|
||||
|
||||
tools = tools or self.tools or []
|
||||
self.create_agent_executor(tools=tools, task=task)
|
||||
self.create_agent_executor(task=task, tools=tools)
|
||||
|
||||
if self.crew and self.crew._train:
|
||||
task_prompt = self._training_handler(task_prompt=task_prompt)
|
||||
@@ -492,7 +514,7 @@ class Agent(BaseAgent):
|
||||
# If there was any tool in self.tools_results that had result_as_answer
|
||||
# set to True, return the results of the last tool that had
|
||||
# result_as_answer set to True
|
||||
for tool_result in self.tools_results: # type: ignore # Item "None" of "list[Any] | None" has no attribute "__iter__" (not iterable)
|
||||
for tool_result in self.tools_results:
|
||||
if tool_result.get("result_as_answer", False):
|
||||
result = tool_result["result"]
|
||||
crewai_event_bus.emit(
|
||||
@@ -501,7 +523,7 @@ class Agent(BaseAgent):
|
||||
)
|
||||
return result
|
||||
|
||||
def _execute_with_timeout(self, task_prompt: str, task: Task, timeout: int) -> str:
|
||||
def _execute_with_timeout(self, task_prompt: str, task: Task, timeout: int) -> Any:
|
||||
"""Execute a task with a timeout.
|
||||
|
||||
Args:
|
||||
@@ -534,7 +556,7 @@ class Agent(BaseAgent):
|
||||
future.cancel()
|
||||
raise RuntimeError(f"Task execution failed: {str(e)}")
|
||||
|
||||
def _execute_without_timeout(self, task_prompt: str, task: Task) -> str:
|
||||
def _execute_without_timeout(self, task_prompt: str, task: Task) -> Any:
|
||||
"""Execute a task without a timeout.
|
||||
|
||||
Args:
|
||||
@@ -544,6 +566,9 @@ class Agent(BaseAgent):
|
||||
Returns:
|
||||
The output of the agent.
|
||||
"""
|
||||
assert self.agent_executor is not None, (
|
||||
"Agent executor must be created before execution"
|
||||
)
|
||||
return self.agent_executor.invoke(
|
||||
{
|
||||
"input": task_prompt,
|
||||
@@ -554,14 +579,15 @@ class Agent(BaseAgent):
|
||||
)["output"]
|
||||
|
||||
def create_agent_executor(
|
||||
self, tools: Optional[List[BaseTool]] = None, task=None
|
||||
self, task: Task, tools: Optional[list[BaseTool]] = None
|
||||
) -> None:
|
||||
"""Create an agent executor for the agent.
|
||||
|
||||
Returns:
|
||||
An instance of the CrewAgentExecutor class.
|
||||
Args:
|
||||
task: Task to execute.
|
||||
tools: Optional list of tools to use.
|
||||
"""
|
||||
raw_tools: List[BaseTool] = tools or self.tools or []
|
||||
raw_tools: list[BaseTool] = tools or self.tools or []
|
||||
parsed_tools = parse_tools(raw_tools)
|
||||
|
||||
prompt = Prompts(
|
||||
@@ -582,7 +608,7 @@ class Agent(BaseAgent):
|
||||
)
|
||||
|
||||
self.agent_executor = CrewAgentExecutor(
|
||||
llm=self.llm,
|
||||
llm=self._llm,
|
||||
task=task,
|
||||
agent=self,
|
||||
crew=self.crew,
|
||||
@@ -595,15 +621,15 @@ class Agent(BaseAgent):
|
||||
tools_names=get_tool_names(parsed_tools),
|
||||
tools_description=render_text_description_and_args(parsed_tools),
|
||||
step_callback=self.step_callback,
|
||||
function_calling_llm=self.function_calling_llm,
|
||||
function_calling_llm=self._function_calling_llm,
|
||||
respect_context_window=self.respect_context_window,
|
||||
request_within_rpm_limit=(
|
||||
self._rpm_controller.check_or_wait if self._rpm_controller else None
|
||||
),
|
||||
callbacks=[TokenCalcHandler(self._token_process)],
|
||||
litellm_callbacks=[TokenCalcHandler(self._token_process)],
|
||||
)
|
||||
|
||||
def get_delegation_tools(self, agents: List[BaseAgent]):
|
||||
def get_delegation_tools(self, agents: list[BaseAgent]) -> list[BaseTool]:
|
||||
agent_tools = AgentTools(agents=agents)
|
||||
tools = agent_tools.tools()
|
||||
return tools
|
||||
@@ -613,7 +639,7 @@ class Agent(BaseAgent):
|
||||
|
||||
return [AddImageTool()]
|
||||
|
||||
def get_code_execution_tools(self):
|
||||
def get_code_execution_tools(self) -> list[BaseTool]:
|
||||
try:
|
||||
from crewai_tools import CodeInterpreterTool # type: ignore
|
||||
|
||||
@@ -624,8 +650,11 @@ class Agent(BaseAgent):
|
||||
self._logger.log(
|
||||
"info", "Coding tools not available. Install crewai_tools. "
|
||||
)
|
||||
return []
|
||||
|
||||
def get_output_converter(self, llm, text, model, instructions):
|
||||
def get_output_converter(
|
||||
self, llm: BaseLLM, text: str, model: str, instructions: str
|
||||
) -> Converter:
|
||||
return Converter(llm=llm, text=text, model=model, instructions=instructions)
|
||||
|
||||
def _training_handler(self, task_prompt: str) -> str:
|
||||
@@ -654,7 +683,7 @@ class Agent(BaseAgent):
|
||||
)
|
||||
return task_prompt
|
||||
|
||||
def _render_text_description(self, tools: List[Any]) -> str:
|
||||
def _render_text_description(self, tools: list[Any]) -> str:
|
||||
"""Render the tool name and description in plain text.
|
||||
|
||||
Output will be in the format of:
|
||||
@@ -673,7 +702,7 @@ class Agent(BaseAgent):
|
||||
|
||||
return description
|
||||
|
||||
def _inject_date_to_task(self, task):
|
||||
def _inject_date_to_task(self, task: Task) -> None:
|
||||
"""Inject the current date into the task description if inject_date is enabled."""
|
||||
if self.inject_date:
|
||||
from datetime import datetime
|
||||
@@ -723,7 +752,7 @@ class Agent(BaseAgent):
|
||||
f"Docker is not running. Please start Docker to use code execution with agent: {self.role}"
|
||||
)
|
||||
|
||||
def __repr__(self):
|
||||
def __repr__(self) -> str:
|
||||
return f"Agent(role={self.role}, goal={self.goal}, backstory={self.backstory})"
|
||||
|
||||
@property
|
||||
@@ -736,7 +765,7 @@ class Agent(BaseAgent):
|
||||
"""
|
||||
return self.security_config.fingerprint
|
||||
|
||||
def set_fingerprint(self, fingerprint: Fingerprint):
|
||||
def set_fingerprint(self, fingerprint: Fingerprint) -> None:
|
||||
self.security_config.fingerprint = fingerprint
|
||||
|
||||
def _get_knowledge_search_query(self, task_prompt: str) -> str | None:
|
||||
@@ -752,22 +781,8 @@ class Agent(BaseAgent):
|
||||
task_prompt=task_prompt
|
||||
)
|
||||
rewriter_prompt = self.i18n.slice("knowledge_search_query_system_prompt")
|
||||
if not isinstance(self.llm, BaseLLM):
|
||||
self._logger.log(
|
||||
"warning",
|
||||
f"Knowledge search query failed: LLM for agent '{self.role}' is not an instance of BaseLLM",
|
||||
)
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=KnowledgeQueryFailedEvent(
|
||||
agent=self,
|
||||
error="LLM is not compatible with knowledge search queries",
|
||||
),
|
||||
)
|
||||
return None
|
||||
|
||||
try:
|
||||
rewritten_query = self.llm.call(
|
||||
rewritten_query = self._llm.call(
|
||||
[
|
||||
{
|
||||
"role": "system",
|
||||
@@ -796,8 +811,8 @@ class Agent(BaseAgent):
|
||||
|
||||
def kickoff(
|
||||
self,
|
||||
messages: Union[str, List[Dict[str, str]]],
|
||||
response_format: Optional[Type[Any]] = None,
|
||||
messages: str | list[dict[str, str]],
|
||||
response_format: Optional[type[Any]] = None,
|
||||
) -> LiteAgentOutput:
|
||||
"""
|
||||
Execute the agent with the given messages using a LiteAgent instance.
|
||||
@@ -819,7 +834,7 @@ class Agent(BaseAgent):
|
||||
role=self.role,
|
||||
goal=self.goal,
|
||||
backstory=self.backstory,
|
||||
llm=self.llm,
|
||||
llm=self._llm,
|
||||
tools=self.tools or [],
|
||||
max_iterations=self.max_iter,
|
||||
max_execution_time=self.max_execution_time,
|
||||
@@ -836,8 +851,8 @@ class Agent(BaseAgent):
|
||||
|
||||
async def kickoff_async(
|
||||
self,
|
||||
messages: Union[str, List[Dict[str, str]]],
|
||||
response_format: Optional[Type[Any]] = None,
|
||||
messages: str | list[dict[str, str]],
|
||||
response_format: Optional[type[Any]] = None,
|
||||
) -> LiteAgentOutput:
|
||||
"""
|
||||
Execute the agent asynchronously with the given messages using a LiteAgent instance.
|
||||
@@ -857,7 +872,7 @@ class Agent(BaseAgent):
|
||||
role=self.role,
|
||||
goal=self.goal,
|
||||
backstory=self.backstory,
|
||||
llm=self.llm,
|
||||
llm=self._llm,
|
||||
tools=self.tools or [],
|
||||
max_iterations=self.max_iter,
|
||||
max_execution_time=self.max_execution_time,
|
||||
|
||||
@@ -1,5 +1,10 @@
|
||||
from crewai.agents.cache.cache_handler import CacheHandler
|
||||
from crewai.agents.parser import parse, AgentAction, AgentFinish, OutputParserException
|
||||
from crewai.agents.tools_handler import ToolsHandler
|
||||
|
||||
__all__ = ["CacheHandler", "parse", "AgentAction", "AgentFinish", "OutputParserException", "ToolsHandler"]
|
||||
__all__ = [
|
||||
"parse",
|
||||
"AgentAction",
|
||||
"AgentFinish",
|
||||
"OutputParserException",
|
||||
"ToolsHandler",
|
||||
]
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any, Optional
|
||||
|
||||
from pydantic import Field, PrivateAttr
|
||||
|
||||
@@ -10,20 +10,22 @@ from crewai.agents.agent_adapters.langgraph.structured_output_converter import (
|
||||
LangGraphConverterAdapter,
|
||||
)
|
||||
from crewai.agents.agent_builder.base_agent import BaseAgent
|
||||
from crewai.tools.agent_tools.agent_tools import AgentTools
|
||||
from crewai.tools.base_tool import BaseTool
|
||||
from crewai.utilities import Logger
|
||||
from crewai.utilities.converter import Converter
|
||||
from crewai.events.event_bus import crewai_event_bus
|
||||
from crewai.events.types.agent_events import (
|
||||
AgentExecutionCompletedEvent,
|
||||
AgentExecutionErrorEvent,
|
||||
AgentExecutionStartedEvent,
|
||||
)
|
||||
from crewai.tools.agent_tools.agent_tools import AgentTools
|
||||
from crewai.tools.base_tool import BaseTool
|
||||
from crewai.utilities import Logger
|
||||
from crewai.utilities.converter import Converter
|
||||
|
||||
try:
|
||||
from langgraph.checkpoint.memory import MemorySaver
|
||||
from langgraph.prebuilt import create_react_agent
|
||||
from langgraph.checkpoint.memory import ( # type: ignore
|
||||
MemorySaver,
|
||||
)
|
||||
from langgraph.prebuilt import create_react_agent # type: ignore
|
||||
|
||||
LANGGRAPH_AVAILABLE = True
|
||||
except ImportError:
|
||||
@@ -51,11 +53,11 @@ class LangGraphAgentAdapter(BaseAgentAdapter):
|
||||
role: str,
|
||||
goal: str,
|
||||
backstory: str,
|
||||
tools: Optional[List[BaseTool]] = None,
|
||||
tools: Optional[list[BaseTool]] = None,
|
||||
llm: Any = None,
|
||||
max_iterations: int = 10,
|
||||
agent_config: Optional[Dict[str, Any]] = None,
|
||||
**kwargs,
|
||||
agent_config: Optional[dict[str, Any]] = None,
|
||||
**kwargs: Any,
|
||||
):
|
||||
"""Initialize the LangGraph agent adapter."""
|
||||
if not LANGGRAPH_AVAILABLE:
|
||||
@@ -81,7 +83,7 @@ class LangGraphAgentAdapter(BaseAgentAdapter):
|
||||
try:
|
||||
self._memory = MemorySaver()
|
||||
|
||||
converted_tools: List[Any] = self._tool_adapter.tools()
|
||||
converted_tools: list[Any] = self._tool_adapter.tools()
|
||||
if self._agent_config:
|
||||
self._graph = create_react_agent(
|
||||
model=self.llm,
|
||||
@@ -111,7 +113,7 @@ class LangGraphAgentAdapter(BaseAgentAdapter):
|
||||
"""Build a system prompt for the LangGraph agent."""
|
||||
base_prompt = f"""
|
||||
You are {self.role}.
|
||||
|
||||
|
||||
Your goal is: {self.goal}
|
||||
|
||||
Your backstory: {self.backstory}
|
||||
@@ -124,10 +126,10 @@ class LangGraphAgentAdapter(BaseAgentAdapter):
|
||||
self,
|
||||
task: Any,
|
||||
context: Optional[str] = None,
|
||||
tools: Optional[List[BaseTool]] = None,
|
||||
tools: Optional[list[BaseTool]] = None,
|
||||
) -> str:
|
||||
"""Execute a task using the LangGraph workflow."""
|
||||
self.create_agent_executor(tools)
|
||||
self.create_agent_executor(task, tools)
|
||||
|
||||
self.configure_structured_output(task)
|
||||
|
||||
@@ -197,11 +199,13 @@ class LangGraphAgentAdapter(BaseAgentAdapter):
|
||||
)
|
||||
raise
|
||||
|
||||
def create_agent_executor(self, tools: Optional[List[BaseTool]] = None) -> None:
|
||||
def create_agent_executor(
|
||||
self, task: Any = None, tools: Optional[list[BaseTool]] = None
|
||||
) -> None:
|
||||
"""Configure the LangGraph agent for execution."""
|
||||
self.configure_tools(tools)
|
||||
|
||||
def configure_tools(self, tools: Optional[List[BaseTool]] = None) -> None:
|
||||
def configure_tools(self, tools: Optional[list[BaseTool]] = None) -> None:
|
||||
"""Configure tools for the LangGraph agent."""
|
||||
if tools:
|
||||
all_tools = list(self.tools or []) + list(tools or [])
|
||||
@@ -209,7 +213,7 @@ class LangGraphAgentAdapter(BaseAgentAdapter):
|
||||
available_tools = self._tool_adapter.tools()
|
||||
self._graph.tools = available_tools
|
||||
|
||||
def get_delegation_tools(self, agents: List[BaseAgent]) -> List[BaseTool]:
|
||||
def get_delegation_tools(self, agents: list[BaseAgent]) -> list[BaseTool]:
|
||||
"""Implement delegation tools support for LangGraph."""
|
||||
agent_tools = AgentTools(agents=agents)
|
||||
return agent_tools.tools()
|
||||
@@ -220,6 +224,6 @@ class LangGraphAgentAdapter(BaseAgentAdapter):
|
||||
"""Convert output format if needed."""
|
||||
return Converter(llm=llm, text=text, model=model, instructions=instructions)
|
||||
|
||||
def configure_structured_output(self, task) -> None:
|
||||
def configure_structured_output(self, task: Any) -> None:
|
||||
"""Configure the structured output for LangGraph."""
|
||||
self._converter_adapter.configure_structured_output(task)
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import json
|
||||
from typing import Any
|
||||
|
||||
from crewai.agents.agent_adapters.base_converter_adapter import BaseConverterAdapter
|
||||
from crewai.utilities.converter import generate_model_description
|
||||
@@ -7,14 +8,15 @@ from crewai.utilities.converter import generate_model_description
|
||||
class LangGraphConverterAdapter(BaseConverterAdapter):
|
||||
"""Adapter for handling structured output conversion in LangGraph agents"""
|
||||
|
||||
def __init__(self, agent_adapter):
|
||||
def __init__(self, agent_adapter: Any) -> None:
|
||||
"""Initialize the converter adapter with a reference to the agent adapter"""
|
||||
super().__init__(agent_adapter) # type: ignore
|
||||
self.agent_adapter = agent_adapter
|
||||
self._output_format = None
|
||||
self._schema = None
|
||||
self._system_prompt_appendix = None
|
||||
self._output_format: str | None = None
|
||||
self._schema: str | None = None
|
||||
self._system_prompt_appendix: str | None = None
|
||||
|
||||
def configure_structured_output(self, task) -> None:
|
||||
def configure_structured_output(self, task: Any) -> None:
|
||||
"""Configure the structured output for LangGraph."""
|
||||
if not (task.output_json or task.output_pydantic):
|
||||
self._output_format = None
|
||||
@@ -41,7 +43,7 @@ Important: Your final answer MUST be provided in the following structured format
|
||||
|
||||
{self._schema}
|
||||
|
||||
DO NOT include any markdown code blocks, backticks, or other formatting around your response.
|
||||
DO NOT include any markdown code blocks, backticks, or other formatting around your response.
|
||||
The output should be raw JSON that exactly matches the specified schema.
|
||||
"""
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import Any, List, Optional
|
||||
from typing import Any, Optional
|
||||
|
||||
from pydantic import Field, PrivateAttr
|
||||
|
||||
@@ -7,19 +7,19 @@ from crewai.agents.agent_adapters.openai_agents.structured_output_converter impo
|
||||
OpenAIConverterAdapter,
|
||||
)
|
||||
from crewai.agents.agent_builder.base_agent import BaseAgent
|
||||
from crewai.tools import BaseTool
|
||||
from crewai.tools.agent_tools.agent_tools import AgentTools
|
||||
from crewai.utilities import Logger
|
||||
from crewai.events.event_bus import crewai_event_bus
|
||||
from crewai.events.types.agent_events import (
|
||||
AgentExecutionCompletedEvent,
|
||||
AgentExecutionErrorEvent,
|
||||
AgentExecutionStartedEvent,
|
||||
)
|
||||
from crewai.tools import BaseTool
|
||||
from crewai.tools.agent_tools.agent_tools import AgentTools
|
||||
from crewai.utilities import Logger
|
||||
|
||||
try:
|
||||
from agents import Agent as OpenAIAgent # type: ignore
|
||||
from agents import Runner, enable_verbose_stdout_logging # type: ignore
|
||||
from agents import Agent as OpenAIAgent # type: ignore[import-not-found]
|
||||
from agents import Runner, enable_verbose_stdout_logging
|
||||
|
||||
from .openai_agent_tool_adapter import OpenAIAgentToolAdapter
|
||||
|
||||
@@ -40,13 +40,14 @@ class OpenAIAgentAdapter(BaseAgentAdapter):
|
||||
step_callback: Any = Field(default=None)
|
||||
_tool_adapter: "OpenAIAgentToolAdapter" = PrivateAttr()
|
||||
_converter_adapter: OpenAIConverterAdapter = PrivateAttr()
|
||||
agent_executor: Any = Field(default=None)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: str = "gpt-4o-mini",
|
||||
tools: Optional[List[BaseTool]] = None,
|
||||
agent_config: Optional[dict] = None,
|
||||
**kwargs,
|
||||
tools: Optional[list[BaseTool]] = None,
|
||||
agent_config: Optional[dict[str, Any]] = None,
|
||||
**kwargs: Any,
|
||||
):
|
||||
if not OPENAI_AVAILABLE:
|
||||
raise ImportError(
|
||||
@@ -72,7 +73,7 @@ class OpenAIAgentAdapter(BaseAgentAdapter):
|
||||
"""Build a system prompt for the OpenAI agent."""
|
||||
base_prompt = f"""
|
||||
You are {self.role}.
|
||||
|
||||
|
||||
Your goal is: {self.goal}
|
||||
|
||||
Your backstory: {self.backstory}
|
||||
@@ -85,11 +86,11 @@ class OpenAIAgentAdapter(BaseAgentAdapter):
|
||||
self,
|
||||
task: Any,
|
||||
context: Optional[str] = None,
|
||||
tools: Optional[List[BaseTool]] = None,
|
||||
) -> str:
|
||||
tools: Optional[list[BaseTool]] = None,
|
||||
) -> Any:
|
||||
"""Execute a task using the OpenAI Assistant"""
|
||||
self._converter_adapter.configure_structured_output(task)
|
||||
self.create_agent_executor(tools)
|
||||
self.create_agent_executor(task, tools)
|
||||
|
||||
if self.verbose:
|
||||
enable_verbose_stdout_logging()
|
||||
@@ -109,6 +110,7 @@ class OpenAIAgentAdapter(BaseAgentAdapter):
|
||||
task=task,
|
||||
),
|
||||
)
|
||||
assert hasattr(self, "agent_executor"), "agent_executor not initialized"
|
||||
result = self.agent_executor.run_sync(self._openai_agent, task_prompt)
|
||||
final_answer = self.handle_execution_result(result)
|
||||
crewai_event_bus.emit(
|
||||
@@ -131,7 +133,9 @@ class OpenAIAgentAdapter(BaseAgentAdapter):
|
||||
)
|
||||
raise
|
||||
|
||||
def create_agent_executor(self, tools: Optional[List[BaseTool]] = None) -> None:
|
||||
def create_agent_executor(
|
||||
self, task: Any = None, tools: Optional[list[BaseTool]] = None
|
||||
) -> None:
|
||||
"""
|
||||
Configure the OpenAI agent for execution.
|
||||
While OpenAI handles execution differently through Runner,
|
||||
@@ -152,24 +156,24 @@ class OpenAIAgentAdapter(BaseAgentAdapter):
|
||||
|
||||
self.agent_executor = Runner
|
||||
|
||||
def configure_tools(self, tools: Optional[List[BaseTool]] = None) -> None:
|
||||
def configure_tools(self, tools: Optional[list[BaseTool]] = None) -> None:
|
||||
"""Configure tools for the OpenAI Assistant"""
|
||||
if tools:
|
||||
self._tool_adapter.configure_tools(tools)
|
||||
if self._tool_adapter.converted_tools:
|
||||
self._openai_agent.tools = self._tool_adapter.converted_tools
|
||||
|
||||
def handle_execution_result(self, result: Any) -> str:
|
||||
def handle_execution_result(self, result: Any) -> Any:
|
||||
"""Process OpenAI Assistant execution result converting any structured output to a string"""
|
||||
return self._converter_adapter.post_process_result(result.final_output)
|
||||
|
||||
def get_delegation_tools(self, agents: List[BaseAgent]) -> List[BaseTool]:
|
||||
def get_delegation_tools(self, agents: list[BaseAgent]) -> list[BaseTool]:
|
||||
"""Implement delegation tools support"""
|
||||
agent_tools = AgentTools(agents=agents)
|
||||
tools = agent_tools.tools()
|
||||
return tools
|
||||
|
||||
def configure_structured_output(self, task) -> None:
|
||||
def configure_structured_output(self, task: Any) -> None:
|
||||
"""Configure the structured output for the specific agent implementation.
|
||||
|
||||
Args:
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import json
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
from crewai.agents.agent_adapters.base_converter_adapter import BaseConverterAdapter
|
||||
from crewai.utilities.converter import generate_model_description
|
||||
@@ -19,14 +20,15 @@ class OpenAIConverterAdapter(BaseConverterAdapter):
|
||||
_output_model: The Pydantic model for the output
|
||||
"""
|
||||
|
||||
def __init__(self, agent_adapter):
|
||||
def __init__(self, agent_adapter: Any) -> None:
|
||||
"""Initialize the converter adapter with a reference to the agent adapter"""
|
||||
super().__init__(agent_adapter) # type: ignore
|
||||
self.agent_adapter = agent_adapter
|
||||
self._output_format = None
|
||||
self._schema = None
|
||||
self._output_model = None
|
||||
self._output_format: str | None = None
|
||||
self._schema: str | None = None
|
||||
self._output_model: Any = None
|
||||
|
||||
def configure_structured_output(self, task) -> None:
|
||||
def configure_structured_output(self, task: Any) -> None:
|
||||
"""
|
||||
Configure the structured output for OpenAI agent based on task requirements.
|
||||
|
||||
@@ -75,7 +77,7 @@ class OpenAIConverterAdapter(BaseConverterAdapter):
|
||||
|
||||
return f"{base_prompt}\n\n{output_schema}"
|
||||
|
||||
def post_process_result(self, result: str) -> str:
|
||||
def post_process_result(self, result: str) -> Any:
|
||||
"""
|
||||
Post-process the result to ensure it matches the expected format.
|
||||
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
import uuid
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Callable
|
||||
from copy import copy as shallow_copy
|
||||
from hashlib import md5
|
||||
from typing import Any, Callable, Dict, List, Optional, TypeVar
|
||||
from typing import Any, Optional, TypeVar
|
||||
|
||||
from pydantic import (
|
||||
UUID4,
|
||||
@@ -14,6 +15,7 @@ from pydantic import (
|
||||
model_validator,
|
||||
)
|
||||
from pydantic_core import PydanticCustomError
|
||||
from typing_extensions import Self
|
||||
|
||||
from crewai.agents.agent_builder.utilities.base_token_process import TokenProcess
|
||||
from crewai.agents.cache.cache_handler import CacheHandler
|
||||
@@ -61,7 +63,7 @@ class BaseAgent(ABC, BaseModel):
|
||||
Methods:
|
||||
execute_task(task: Any, context: Optional[str] = None, tools: Optional[List[BaseTool]] = None) -> str:
|
||||
Abstract method to execute a task.
|
||||
create_agent_executor(tools=None) -> None:
|
||||
create_agent_executor(task, tools=None) -> None:
|
||||
Abstract method to create an agent executor.
|
||||
get_delegation_tools(agents: List["BaseAgent"]):
|
||||
Abstract method to set the agents task tools for handling delegation and question asking to other agents in crew.
|
||||
@@ -79,7 +81,7 @@ class BaseAgent(ABC, BaseModel):
|
||||
Set private attributes.
|
||||
"""
|
||||
|
||||
__hash__ = object.__hash__ # type: ignore
|
||||
__hash__ = object.__hash__
|
||||
_logger: Logger = PrivateAttr(default_factory=lambda: Logger(verbose=False))
|
||||
_rpm_controller: Optional[RPMController] = PrivateAttr(default=None)
|
||||
_request_within_rpm_limit: Any = PrivateAttr(default=None)
|
||||
@@ -91,7 +93,7 @@ class BaseAgent(ABC, BaseModel):
|
||||
role: str = Field(description="Role of the agent")
|
||||
goal: str = Field(description="Objective of the agent")
|
||||
backstory: str = Field(description="Backstory of the agent")
|
||||
config: Optional[Dict[str, Any]] = Field(
|
||||
config: Optional[dict[str, Any]] = Field(
|
||||
description="Configuration for the agent", default=None, exclude=True
|
||||
)
|
||||
cache: bool = Field(
|
||||
@@ -108,14 +110,14 @@ class BaseAgent(ABC, BaseModel):
|
||||
default=False,
|
||||
description="Enable agent to delegate and ask questions among each other.",
|
||||
)
|
||||
tools: Optional[List[BaseTool]] = Field(
|
||||
tools: Optional[list[BaseTool]] = Field(
|
||||
default_factory=list, description="Tools at agents' disposal"
|
||||
)
|
||||
max_iter: int = Field(
|
||||
default=25, description="Maximum iterations for an agent to execute a task"
|
||||
)
|
||||
agent_executor: InstanceOf = Field(
|
||||
default=None, description="An instance of the CrewAgentExecutor class."
|
||||
agent_executor: Optional[Any] = Field(
|
||||
default=None, description="An instance of the agent executor class."
|
||||
)
|
||||
llm: Any = Field(
|
||||
default=None, description="Language model that will run the agent."
|
||||
@@ -129,7 +131,7 @@ class BaseAgent(ABC, BaseModel):
|
||||
default_factory=ToolsHandler,
|
||||
description="An instance of the ToolsHandler class.",
|
||||
)
|
||||
tools_results: List[Dict[str, Any]] = Field(
|
||||
tools_results: list[dict[str, Any]] = Field(
|
||||
default=[], description="Results of the tools used by the agent."
|
||||
)
|
||||
max_tokens: Optional[int] = Field(
|
||||
@@ -138,7 +140,7 @@ class BaseAgent(ABC, BaseModel):
|
||||
knowledge: Optional[Knowledge] = Field(
|
||||
default=None, description="Knowledge for the agent."
|
||||
)
|
||||
knowledge_sources: Optional[List[BaseKnowledgeSource]] = Field(
|
||||
knowledge_sources: Optional[list[BaseKnowledgeSource]] = Field(
|
||||
default=None,
|
||||
description="Knowledge sources for the agent.",
|
||||
)
|
||||
@@ -150,7 +152,7 @@ class BaseAgent(ABC, BaseModel):
|
||||
default_factory=SecurityConfig,
|
||||
description="Security configuration for the agent, including fingerprinting.",
|
||||
)
|
||||
callbacks: List[Callable] = Field(
|
||||
callbacks: list[Callable[..., Any]] = Field(
|
||||
default=[], description="Callbacks to be used for the agent"
|
||||
)
|
||||
adapted_agent: bool = Field(
|
||||
@@ -163,12 +165,12 @@ class BaseAgent(ABC, BaseModel):
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def process_model_config(cls, values):
|
||||
def process_model_config(cls, values: Any) -> Any:
|
||||
return process_config(values, cls)
|
||||
|
||||
@field_validator("tools")
|
||||
@classmethod
|
||||
def validate_tools(cls, tools: List[Any]) -> List[BaseTool]:
|
||||
def validate_tools(cls, tools: list[Any]) -> list[BaseTool]:
|
||||
"""Validate and process the tools provided to the agent.
|
||||
|
||||
This method ensures that each tool is either an instance of BaseTool
|
||||
@@ -196,7 +198,7 @@ class BaseAgent(ABC, BaseModel):
|
||||
return processed_tools
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_and_set_attributes(self):
|
||||
def validate_and_set_attributes(self) -> Self:
|
||||
# Validate required fields
|
||||
for field in ["role", "goal", "backstory"]:
|
||||
if getattr(self, field) is None:
|
||||
@@ -228,7 +230,7 @@ class BaseAgent(ABC, BaseModel):
|
||||
)
|
||||
|
||||
@model_validator(mode="after")
|
||||
def set_private_attrs(self):
|
||||
def set_private_attrs(self) -> Self:
|
||||
"""Set private attributes."""
|
||||
self._logger = Logger(verbose=self.verbose)
|
||||
if self.max_rpm and not self._rpm_controller:
|
||||
@@ -240,7 +242,7 @@ class BaseAgent(ABC, BaseModel):
|
||||
return self
|
||||
|
||||
@property
|
||||
def key(self):
|
||||
def key(self) -> str:
|
||||
source = [
|
||||
self._original_role or self.role,
|
||||
self._original_goal or self.goal,
|
||||
@@ -253,16 +255,18 @@ class BaseAgent(ABC, BaseModel):
|
||||
self,
|
||||
task: Any,
|
||||
context: Optional[str] = None,
|
||||
tools: Optional[List[BaseTool]] = None,
|
||||
tools: Optional[list[BaseTool]] = None,
|
||||
) -> str:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def create_agent_executor(self, tools=None) -> None:
|
||||
def create_agent_executor(
|
||||
self, task: Any, tools: Optional[list[BaseTool]] = None
|
||||
) -> None:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_delegation_tools(self, agents: List["BaseAgent"]) -> List[BaseTool]:
|
||||
def get_delegation_tools(self, agents: list["BaseAgent"]) -> list[BaseTool]:
|
||||
"""Set the task tools that init BaseAgenTools class."""
|
||||
pass
|
||||
|
||||
@@ -320,7 +324,7 @@ class BaseAgent(ABC, BaseModel):
|
||||
|
||||
return copied_agent
|
||||
|
||||
def interpolate_inputs(self, inputs: Dict[str, Any]) -> None:
|
||||
def interpolate_inputs(self, inputs: dict[str, Any]) -> None:
|
||||
"""Interpolate inputs into the agent description and backstory."""
|
||||
if self._original_role is None:
|
||||
self._original_role = self.role
|
||||
@@ -350,7 +354,7 @@ class BaseAgent(ABC, BaseModel):
|
||||
if self.cache:
|
||||
self.cache_handler = cache_handler
|
||||
self.tools_handler.cache = cache_handler
|
||||
self.create_agent_executor()
|
||||
# Executor will be created when a task is executed
|
||||
|
||||
def set_rpm_controller(self, rpm_controller: RPMController) -> None:
|
||||
"""Set the rpm controller for the agent.
|
||||
@@ -360,7 +364,7 @@ class BaseAgent(ABC, BaseModel):
|
||||
"""
|
||||
if not self._rpm_controller:
|
||||
self._rpm_controller = rpm_controller
|
||||
self.create_agent_executor()
|
||||
# Executor will be created when a task is executed
|
||||
|
||||
def set_knowledge(self, crew_embedder: Optional[Dict[str, Any]] = None):
|
||||
def set_knowledge(self, crew_embedder: Optional[dict[str, Any]] = None) -> None:
|
||||
pass
|
||||
|
||||
@@ -1,31 +1,32 @@
|
||||
import time
|
||||
from typing import TYPE_CHECKING, Dict, List
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from crewai.events.event_listener import event_listener
|
||||
from crewai.memory.entity.entity_memory_item import EntityMemoryItem
|
||||
from crewai.memory.long_term.long_term_memory_item import LongTermMemoryItem
|
||||
from crewai.utilities import I18N
|
||||
from crewai.utilities.converter import ConverterError
|
||||
from crewai.utilities.evaluators.task_evaluator import TaskEvaluator
|
||||
from crewai.utilities.printer import Printer
|
||||
from crewai.events.event_listener import event_listener
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from crewai.agents.agent_builder.base_agent import BaseAgent
|
||||
from crewai.agents.parser import AgentFinish
|
||||
from crewai.crew import Crew
|
||||
from crewai.task import Task
|
||||
|
||||
|
||||
class CrewAgentExecutorMixin:
|
||||
crew: "Crew"
|
||||
crew: "Crew | None"
|
||||
agent: "BaseAgent"
|
||||
task: "Task"
|
||||
iterations: int
|
||||
max_iter: int
|
||||
messages: List[Dict[str, str]]
|
||||
messages: list[dict[str, str]]
|
||||
_i18n: I18N
|
||||
_printer: Printer = Printer()
|
||||
|
||||
def _create_short_term_memory(self, output) -> None:
|
||||
def _create_short_term_memory(self, output: "AgentFinish") -> None:
|
||||
"""Create and save a short-term memory item if conditions are met."""
|
||||
if (
|
||||
self.crew
|
||||
@@ -35,7 +36,8 @@ class CrewAgentExecutorMixin:
|
||||
):
|
||||
try:
|
||||
if (
|
||||
hasattr(self.crew, "_short_term_memory")
|
||||
self.crew
|
||||
and hasattr(self.crew, "_short_term_memory")
|
||||
and self.crew._short_term_memory
|
||||
):
|
||||
self.crew._short_term_memory.save(
|
||||
@@ -48,7 +50,7 @@ class CrewAgentExecutorMixin:
|
||||
print(f"Failed to add to short term memory: {e}")
|
||||
pass
|
||||
|
||||
def _create_external_memory(self, output) -> None:
|
||||
def _create_external_memory(self, output: "AgentFinish") -> None:
|
||||
"""Create and save a external-term memory item if conditions are met."""
|
||||
if (
|
||||
self.crew
|
||||
@@ -69,7 +71,7 @@ class CrewAgentExecutorMixin:
|
||||
print(f"Failed to add to external memory: {e}")
|
||||
pass
|
||||
|
||||
def _create_long_term_memory(self, output) -> None:
|
||||
def _create_long_term_memory(self, output: "AgentFinish") -> None:
|
||||
"""Create and save long-term and entity memory items based on evaluation."""
|
||||
if (
|
||||
self.crew
|
||||
|
||||
6
src/crewai/agents/cache/__init__.py
vendored
6
src/crewai/agents/cache/__init__.py
vendored
@@ -1,3 +1,5 @@
|
||||
from .cache_handler import CacheHandler
|
||||
"""Internal caching utilities for agent tool execution.
|
||||
|
||||
__all__ = ["CacheHandler"]
|
||||
This package provides caching mechanisms for storing and retrieving
|
||||
tool execution results to avoid redundant operations.
|
||||
"""
|
||||
|
||||
49
src/crewai/agents/cache/cache_handler.py
vendored
49
src/crewai/agents/cache/cache_handler.py
vendored
@@ -1,15 +1,50 @@
|
||||
from typing import Any, Dict, Optional
|
||||
"""Cache handler for storing and retrieving tool execution results.
|
||||
|
||||
This module provides a caching mechanism for tool outputs in the CrewAI framework,
|
||||
allowing agents to reuse previous tool execution results when the same tool is
|
||||
called with identical arguments.
|
||||
|
||||
Classes:
|
||||
CacheHandler: Manages the caching of tool execution results using an in-memory
|
||||
dictionary with serialized tool arguments as keys.
|
||||
"""
|
||||
|
||||
import json
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, PrivateAttr
|
||||
|
||||
|
||||
class CacheHandler(BaseModel):
|
||||
"""Callback handler for tool usage."""
|
||||
"""Callback handler for tool usage.
|
||||
|
||||
_cache: Dict[str, Any] = PrivateAttr(default_factory=dict)
|
||||
|
||||
def add(self, tool, input, output):
|
||||
self._cache[f"{tool}-{input}"] = output
|
||||
Notes:
|
||||
TODO: Make thread-safe, currently not thread-safe.
|
||||
"""
|
||||
|
||||
def read(self, tool, input) -> Optional[str]:
|
||||
return self._cache.get(f"{tool}-{input}")
|
||||
_cache: dict[str, Any] = PrivateAttr(default_factory=dict)
|
||||
|
||||
def add(self, tool: str, input_data: dict[str, Any] | None, output: str) -> None:
|
||||
"""Add a tool execution result to the cache.
|
||||
|
||||
Args:
|
||||
tool: The name of the tool.
|
||||
input_data: The input arguments for the tool.
|
||||
output: The output from the tool execution.
|
||||
"""
|
||||
cache_key = json.dumps(input_data, sort_keys=True) if input_data else ""
|
||||
self._cache[f"{tool}-{cache_key}"] = output
|
||||
|
||||
def read(self, tool: str, input_data: dict[str, Any] | None) -> str | None:
|
||||
"""Read a tool execution result from the cache.
|
||||
|
||||
Args:
|
||||
tool: The name of the tool.
|
||||
input_data: The input arguments for the tool.
|
||||
|
||||
Returns:
|
||||
The cached output if found, None otherwise.
|
||||
"""
|
||||
cache_key = json.dumps(input_data, sort_keys=True) if input_data else ""
|
||||
return self._cache.get(f"{tool}-{cache_key}")
|
||||
|
||||
@@ -4,8 +4,14 @@ Handles agent execution flow including LLM interactions, tool execution,
|
||||
and memory management.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from crewai.crew import Crew
|
||||
from crewai.task import Task
|
||||
|
||||
from crewai.agents.agent_builder.base_agent import BaseAgent
|
||||
from crewai.agents.agent_builder.base_agent_executor_mixin import CrewAgentExecutorMixin
|
||||
@@ -21,6 +27,7 @@ from crewai.events.types.logging_events import (
|
||||
AgentLogsStartedEvent,
|
||||
)
|
||||
from crewai.llms.base_llm import BaseLLM
|
||||
from crewai.tools.base_tool import BaseTool
|
||||
from crewai.tools.structured_tool import CrewStructuredTool
|
||||
from crewai.tools.tool_types import ToolResult
|
||||
from crewai.utilities import I18N, Printer
|
||||
@@ -51,9 +58,9 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
llm: Any,
|
||||
task: Any,
|
||||
crew: Any,
|
||||
llm: BaseLLM,
|
||||
task: Task,
|
||||
crew: Crew | None,
|
||||
agent: BaseAgent,
|
||||
prompt: dict[str, str],
|
||||
max_iter: int,
|
||||
@@ -62,19 +69,19 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
|
||||
stop_words: list[str],
|
||||
tools_description: str,
|
||||
tools_handler: ToolsHandler,
|
||||
step_callback: Any = None,
|
||||
original_tools: list[Any] | None = None,
|
||||
function_calling_llm: Any = None,
|
||||
step_callback: Callable[[AgentAction | AgentFinish], None] | None = None,
|
||||
original_tools: list[BaseTool] | None = None,
|
||||
function_calling_llm: BaseLLM | None = None,
|
||||
respect_context_window: bool = False,
|
||||
request_within_rpm_limit: Callable[[], bool] | None = None,
|
||||
callbacks: list[Any] | None = None,
|
||||
litellm_callbacks: list[Any] | None = None,
|
||||
) -> None:
|
||||
"""Initialize executor.
|
||||
|
||||
Args:
|
||||
llm: Language model instance.
|
||||
task: Task to execute.
|
||||
crew: Crew instance.
|
||||
crew: Optional Crew instance.
|
||||
agent: Agent to execute.
|
||||
prompt: Prompt templates.
|
||||
max_iter: Maximum iterations.
|
||||
@@ -88,19 +95,19 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
|
||||
function_calling_llm: Optional function calling LLM.
|
||||
respect_context_window: Respect context limits.
|
||||
request_within_rpm_limit: RPM limit check function.
|
||||
callbacks: Optional callbacks list.
|
||||
litellm_callbacks: Optional litellm callbacks list.
|
||||
"""
|
||||
self._i18n: I18N = I18N()
|
||||
self.llm: BaseLLM = llm
|
||||
self.llm = llm
|
||||
self.task = task
|
||||
self.agent = agent
|
||||
self.crew = crew
|
||||
self.crew: Crew | None = crew
|
||||
self.prompt = prompt
|
||||
self.tools = tools
|
||||
self.tools_names = tools_names
|
||||
self.stop = stop_words
|
||||
self.max_iter = max_iter
|
||||
self.callbacks = callbacks or []
|
||||
self.litellm_callbacks = litellm_callbacks or []
|
||||
self._printer: Printer = Printer()
|
||||
self.tools_handler = tools_handler
|
||||
self.original_tools = original_tools or []
|
||||
@@ -123,7 +130,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
|
||||
)
|
||||
)
|
||||
|
||||
def invoke(self, inputs: dict[str, str]) -> dict[str, Any]:
|
||||
def invoke(self, inputs: dict[str, str]) -> dict[str, str]:
|
||||
"""Execute the agent with given inputs.
|
||||
|
||||
Args:
|
||||
@@ -131,6 +138,10 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
|
||||
|
||||
Returns:
|
||||
Dictionary with agent output.
|
||||
|
||||
Raises:
|
||||
AssertionError: If agent fails to reach final answer.
|
||||
Exception: If unknown error occurs during execution.
|
||||
"""
|
||||
if "system" in self.prompt:
|
||||
system_prompt = self._format_prompt(self.prompt.get("system", ""), inputs)
|
||||
@@ -170,6 +181,9 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
|
||||
|
||||
Returns:
|
||||
Final answer from the agent.
|
||||
|
||||
Raises:
|
||||
Exception: If litellm error or unknown error occurs.
|
||||
"""
|
||||
formatted_answer = None
|
||||
while not isinstance(formatted_answer, AgentFinish):
|
||||
@@ -181,7 +195,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
|
||||
i18n=self._i18n,
|
||||
messages=self.messages,
|
||||
llm=self.llm,
|
||||
callbacks=self.callbacks,
|
||||
callbacks=self.litellm_callbacks,
|
||||
)
|
||||
|
||||
enforce_rpm_limit(self.request_within_rpm_limit)
|
||||
@@ -189,7 +203,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
|
||||
answer = get_llm_response(
|
||||
llm=self.llm,
|
||||
messages=self.messages,
|
||||
callbacks=self.callbacks,
|
||||
callbacks=self.litellm_callbacks,
|
||||
printer=self._printer,
|
||||
from_task=self.task,
|
||||
)
|
||||
@@ -198,10 +212,8 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
|
||||
if isinstance(formatted_answer, AgentAction):
|
||||
# Extract agent fingerprint if available
|
||||
fingerprint_context = {}
|
||||
if (
|
||||
self.agent
|
||||
and hasattr(self.agent, "security_config")
|
||||
and hasattr(self.agent.security_config, "fingerprint")
|
||||
if hasattr(self.agent, "security_config") and hasattr(
|
||||
self.agent.security_config, "fingerprint"
|
||||
):
|
||||
fingerprint_context = {
|
||||
"agent_fingerprint": str(
|
||||
@@ -214,8 +226,8 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
|
||||
fingerprint_context=fingerprint_context,
|
||||
tools=self.tools,
|
||||
i18n=self._i18n,
|
||||
agent_key=self.agent.key if self.agent else None,
|
||||
agent_role=self.agent.role if self.agent else None,
|
||||
agent_key=self.agent.key,
|
||||
agent_role=self.agent.role,
|
||||
tools_handler=self.tools_handler,
|
||||
task=self.task,
|
||||
agent=self.agent,
|
||||
@@ -247,7 +259,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
|
||||
printer=self._printer,
|
||||
messages=self.messages,
|
||||
llm=self.llm,
|
||||
callbacks=self.callbacks,
|
||||
callbacks=self.litellm_callbacks,
|
||||
i18n=self._i18n,
|
||||
)
|
||||
continue
|
||||
@@ -317,18 +329,13 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
|
||||
|
||||
def _show_start_logs(self) -> None:
|
||||
"""Emit agent start event."""
|
||||
if self.agent is None:
|
||||
raise ValueError("Agent cannot be None")
|
||||
|
||||
crewai_event_bus.emit(
|
||||
self.agent,
|
||||
AgentLogsStartedEvent(
|
||||
agent_role=self.agent.role,
|
||||
task_description=(
|
||||
getattr(self.task, "description") if self.task else "Not Found"
|
||||
),
|
||||
task_description=self.task.description,
|
||||
verbose=self.agent.verbose
|
||||
or (hasattr(self, "crew") and getattr(self.crew, "verbose", False)),
|
||||
or (self.crew.verbose if self.crew else False),
|
||||
),
|
||||
)
|
||||
|
||||
@@ -338,16 +345,13 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
|
||||
Args:
|
||||
formatted_answer: Agent's response to log.
|
||||
"""
|
||||
if self.agent is None:
|
||||
raise ValueError("Agent cannot be None")
|
||||
|
||||
crewai_event_bus.emit(
|
||||
self.agent,
|
||||
AgentLogsExecutionEvent(
|
||||
agent_role=self.agent.role,
|
||||
formatted_answer=formatted_answer,
|
||||
verbose=self.agent.verbose
|
||||
or (hasattr(self, "crew") and getattr(self.crew, "verbose", False)),
|
||||
or (self.crew.verbose if self.crew else False),
|
||||
),
|
||||
)
|
||||
|
||||
@@ -361,9 +365,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
|
||||
human_feedback: Optional feedback from human.
|
||||
"""
|
||||
agent_id = str(self.agent.id)
|
||||
train_iteration = (
|
||||
getattr(self.crew, "_train_iteration", None) if self.crew else None
|
||||
)
|
||||
train_iteration = getattr(self.crew, "_train_iteration", None)
|
||||
|
||||
if train_iteration is None or not isinstance(train_iteration, int):
|
||||
self._printer.print(
|
||||
|
||||
@@ -7,18 +7,18 @@ AgentAction or AgentFinish objects.
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
from json_repair import repair_json
|
||||
from json_repair import repair_json # type: ignore[import-untyped]
|
||||
|
||||
from crewai.agents.constants import (
|
||||
ACTION_INPUT_ONLY_REGEX,
|
||||
ACTION_INPUT_REGEX,
|
||||
ACTION_REGEX,
|
||||
ACTION_INPUT_ONLY_REGEX,
|
||||
FINAL_ANSWER_ACTION,
|
||||
MISSING_ACTION_AFTER_THOUGHT_ERROR_MESSAGE,
|
||||
MISSING_ACTION_INPUT_AFTER_ACTION_ERROR_MESSAGE,
|
||||
UNABLE_TO_REPAIR_JSON_RESULTS,
|
||||
)
|
||||
from crewai.utilities import I18N
|
||||
from crewai.utilities.i18n import I18N
|
||||
|
||||
_I18N = I18N()
|
||||
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
"""Tools handler for managing tool execution and caching."""
|
||||
|
||||
from crewai.agents.cache.cache_handler import CacheHandler
|
||||
from crewai.tools.cache_tools.cache_tools import CacheTools
|
||||
from crewai.tools.tool_calling import InstructorToolCalling, ToolCalling
|
||||
from crewai.agents.cache.cache_handler import CacheHandler
|
||||
|
||||
|
||||
class ToolsHandler:
|
||||
@@ -39,6 +39,6 @@ class ToolsHandler:
|
||||
if self.cache and should_cache and calling.tool_name != CacheTools().name:
|
||||
self.cache.add(
|
||||
tool=calling.tool_name,
|
||||
input=calling.arguments,
|
||||
input_data=calling.arguments,
|
||||
output=output,
|
||||
)
|
||||
|
||||
@@ -3,26 +3,18 @@ import json
|
||||
import re
|
||||
import uuid
|
||||
import warnings
|
||||
from collections.abc import Callable, Mapping, Set
|
||||
from concurrent.futures import Future
|
||||
from copy import copy as shallow_copy
|
||||
from hashlib import md5
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
Dict,
|
||||
List,
|
||||
Optional,
|
||||
Set,
|
||||
Tuple,
|
||||
Union,
|
||||
cast,
|
||||
)
|
||||
|
||||
from opentelemetry import baggage
|
||||
from opentelemetry.context import attach, detach
|
||||
|
||||
from crewai.utilities.crew.models import CrewContext
|
||||
|
||||
from pydantic import (
|
||||
UUID4,
|
||||
BaseModel,
|
||||
@@ -34,15 +26,36 @@ from pydantic import (
|
||||
model_validator,
|
||||
)
|
||||
from pydantic_core import PydanticCustomError
|
||||
from typing_extensions import Self
|
||||
|
||||
from crewai.agent import Agent
|
||||
from crewai.agents.agent_builder.base_agent import BaseAgent
|
||||
from crewai.agents.cache import CacheHandler
|
||||
from crewai.agents.cache.cache_handler import CacheHandler
|
||||
from crewai.crews.crew_output import CrewOutput
|
||||
from crewai.events.event_bus import crewai_event_bus
|
||||
from crewai.events.event_listener import EventListener
|
||||
from crewai.events.listeners.tracing.trace_listener import (
|
||||
TraceCollectionListener,
|
||||
)
|
||||
from crewai.events.listeners.tracing.utils import (
|
||||
is_tracing_enabled,
|
||||
)
|
||||
from crewai.events.types.crew_events import (
|
||||
CrewKickoffCompletedEvent,
|
||||
CrewKickoffFailedEvent,
|
||||
CrewKickoffStartedEvent,
|
||||
CrewTestCompletedEvent,
|
||||
CrewTestFailedEvent,
|
||||
CrewTestStartedEvent,
|
||||
CrewTrainCompletedEvent,
|
||||
CrewTrainFailedEvent,
|
||||
CrewTrainStartedEvent,
|
||||
)
|
||||
from crewai.flow.flow_trackable import FlowTrackable
|
||||
from crewai.knowledge.knowledge import Knowledge
|
||||
from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource
|
||||
from crewai.llm import LLM, BaseLLM
|
||||
from crewai.llm import LLM
|
||||
from crewai.llms.base_llm import BaseLLM
|
||||
from crewai.memory.entity.entity_memory import EntityMemory
|
||||
from crewai.memory.external.external_memory import ExternalMemory
|
||||
from crewai.memory.long_term.long_term_memory import LongTermMemory
|
||||
@@ -57,29 +70,9 @@ from crewai.tools.base_tool import BaseTool, Tool
|
||||
from crewai.types.usage_metrics import UsageMetrics
|
||||
from crewai.utilities import I18N, FileHandler, Logger, RPMController
|
||||
from crewai.utilities.constants import NOT_SPECIFIED, TRAINING_DATA_FILE
|
||||
from crewai.utilities.crew.models import CrewContext
|
||||
from crewai.utilities.evaluators.crew_evaluator_handler import CrewEvaluator
|
||||
from crewai.utilities.evaluators.task_evaluator import TaskEvaluator
|
||||
from crewai.events.types.crew_events import (
|
||||
CrewKickoffCompletedEvent,
|
||||
CrewKickoffFailedEvent,
|
||||
CrewKickoffStartedEvent,
|
||||
CrewTestCompletedEvent,
|
||||
CrewTestFailedEvent,
|
||||
CrewTestStartedEvent,
|
||||
CrewTrainCompletedEvent,
|
||||
CrewTrainFailedEvent,
|
||||
CrewTrainStartedEvent,
|
||||
)
|
||||
from crewai.events.event_bus import crewai_event_bus
|
||||
from crewai.events.event_listener import EventListener
|
||||
from crewai.events.listeners.tracing.trace_listener import (
|
||||
TraceCollectionListener,
|
||||
)
|
||||
|
||||
|
||||
from crewai.events.listeners.tracing.utils import (
|
||||
is_tracing_enabled,
|
||||
)
|
||||
from crewai.utilities.formatter import (
|
||||
aggregate_raw_outputs_from_task_outputs,
|
||||
aggregate_raw_outputs_from_tasks,
|
||||
@@ -116,9 +109,12 @@ class Crew(FlowTrackable, BaseModel):
|
||||
planning: Plan the crew execution and add the plan to the crew.
|
||||
chat_llm: The language model used for orchestrating chat interactions with the crew.
|
||||
security_config: Security configuration for the crew, including fingerprinting.
|
||||
|
||||
Notes:
|
||||
TODO: Improve the embedder type from dict[str, Any] to a more specific TypedDict or dataclass.
|
||||
"""
|
||||
|
||||
__hash__ = object.__hash__ # type: ignore
|
||||
__hash__ = object.__hash__
|
||||
_execution_span: Any = PrivateAttr()
|
||||
_rpm_controller: RPMController = PrivateAttr()
|
||||
_logger: Logger = PrivateAttr()
|
||||
@@ -130,7 +126,7 @@ class Crew(FlowTrackable, BaseModel):
|
||||
_external_memory: Optional[InstanceOf[ExternalMemory]] = PrivateAttr()
|
||||
_train: Optional[bool] = PrivateAttr(default=False)
|
||||
_train_iteration: Optional[int] = PrivateAttr()
|
||||
_inputs: Optional[Dict[str, Any]] = PrivateAttr(default=None)
|
||||
_inputs: Optional[dict[str, Any]] = PrivateAttr(default=None)
|
||||
_logging_color: str = PrivateAttr(
|
||||
default="bold_purple",
|
||||
)
|
||||
@@ -140,8 +136,8 @@ class Crew(FlowTrackable, BaseModel):
|
||||
|
||||
name: Optional[str] = Field(default="crew")
|
||||
cache: bool = Field(default=True)
|
||||
tasks: List[Task] = Field(default_factory=list)
|
||||
agents: List[BaseAgent] = Field(default_factory=list)
|
||||
tasks: list[Task] = Field(default_factory=list)
|
||||
agents: list[BaseAgent] = Field(default_factory=list)
|
||||
process: Process = Field(default=Process.sequential)
|
||||
verbose: bool = Field(default=False)
|
||||
memory: bool = Field(
|
||||
@@ -164,7 +160,7 @@ class Crew(FlowTrackable, BaseModel):
|
||||
default=None,
|
||||
description="An Instance of the ExternalMemory to be used by the Crew",
|
||||
)
|
||||
embedder: Optional[dict] = Field(
|
||||
embedder: Optional[dict[str, Any]] = Field(
|
||||
default=None,
|
||||
description="Configuration for the embedder to be used for the crew.",
|
||||
)
|
||||
@@ -172,16 +168,16 @@ class Crew(FlowTrackable, BaseModel):
|
||||
default=None,
|
||||
description="Metrics for the LLM usage during all tasks execution.",
|
||||
)
|
||||
manager_llm: Optional[Union[str, InstanceOf[BaseLLM], Any]] = Field(
|
||||
manager_llm: Optional[str | InstanceOf[BaseLLM] | Any] = Field(
|
||||
description="Language model that will run the agent.", default=None
|
||||
)
|
||||
manager_agent: Optional[BaseAgent] = Field(
|
||||
description="Custom agent that will be used as manager.", default=None
|
||||
)
|
||||
function_calling_llm: Optional[Union[str, InstanceOf[LLM], Any]] = Field(
|
||||
function_calling_llm: Optional[str | InstanceOf[LLM] | Any] = Field(
|
||||
description="Language model that will run the agent.", default=None
|
||||
)
|
||||
config: Optional[Union[Json, Dict[str, Any]]] = Field(default=None)
|
||||
config: Optional[Json[dict[str, Any]] | dict[str, Any]] = Field(default=None)
|
||||
id: UUID4 = Field(default_factory=uuid.uuid4, frozen=True)
|
||||
share_crew: Optional[bool] = Field(default=False)
|
||||
step_callback: Optional[Any] = Field(
|
||||
@@ -192,13 +188,13 @@ class Crew(FlowTrackable, BaseModel):
|
||||
default=None,
|
||||
description="Callback to be executed after each task for all agents execution.",
|
||||
)
|
||||
before_kickoff_callbacks: List[
|
||||
Callable[[Optional[Dict[str, Any]]], Optional[Dict[str, Any]]]
|
||||
before_kickoff_callbacks: list[
|
||||
Callable[[Optional[dict[str, Any]]], Optional[dict[str, Any]]]
|
||||
] = Field(
|
||||
default_factory=list,
|
||||
description="List of callbacks to be executed before crew kickoff. It may be used to adjust inputs before the crew is executed.",
|
||||
)
|
||||
after_kickoff_callbacks: List[Callable[[CrewOutput], CrewOutput]] = Field(
|
||||
after_kickoff_callbacks: list[Callable[[CrewOutput], CrewOutput]] = Field(
|
||||
default_factory=list,
|
||||
description="List of callbacks to be executed after crew kickoff. It may be used to adjust the output of the crew.",
|
||||
)
|
||||
@@ -210,7 +206,7 @@ class Crew(FlowTrackable, BaseModel):
|
||||
default=None,
|
||||
description="Path to the prompt json file to be used for the crew.",
|
||||
)
|
||||
output_log_file: Optional[Union[bool, str]] = Field(
|
||||
output_log_file: Optional[bool | str] = Field(
|
||||
default=None,
|
||||
description="Path to the log file to be saved",
|
||||
)
|
||||
@@ -218,23 +214,23 @@ class Crew(FlowTrackable, BaseModel):
|
||||
default=False,
|
||||
description="Plan the crew execution and add the plan to the crew.",
|
||||
)
|
||||
planning_llm: Optional[Union[str, InstanceOf[BaseLLM], Any]] = Field(
|
||||
planning_llm: Optional[str | InstanceOf[BaseLLM] | Any] = Field(
|
||||
default=None,
|
||||
description="Language model that will run the AgentPlanner if planning is True.",
|
||||
)
|
||||
task_execution_output_json_files: Optional[List[str]] = Field(
|
||||
task_execution_output_json_files: Optional[list[str]] = Field(
|
||||
default=None,
|
||||
description="List of file paths for task execution JSON files.",
|
||||
)
|
||||
execution_logs: List[Dict[str, Any]] = Field(
|
||||
execution_logs: list[dict[str, Any]] = Field(
|
||||
default=[],
|
||||
description="List of execution logs for tasks",
|
||||
)
|
||||
knowledge_sources: Optional[List[BaseKnowledgeSource]] = Field(
|
||||
knowledge_sources: Optional[list[BaseKnowledgeSource]] = Field(
|
||||
default=None,
|
||||
description="Knowledge sources for the crew. Add knowledge sources to the knowledge object.",
|
||||
)
|
||||
chat_llm: Optional[Union[str, InstanceOf[BaseLLM], Any]] = Field(
|
||||
chat_llm: Optional[str | InstanceOf[BaseLLM] | Any] = Field(
|
||||
default=None,
|
||||
description="LLM used to handle chatting with the crew.",
|
||||
)
|
||||
@@ -267,8 +263,8 @@ class Crew(FlowTrackable, BaseModel):
|
||||
@field_validator("config", mode="before")
|
||||
@classmethod
|
||||
def check_config_type(
|
||||
cls, v: Union[Json, Dict[str, Any]]
|
||||
) -> Union[Json, Dict[str, Any]]:
|
||||
cls, v: Json[dict[str, Any]] | dict[str, Any]
|
||||
) -> Json[dict[str, Any]] | dict[str, Any]:
|
||||
"""Validates that the config is a valid type.
|
||||
Args:
|
||||
v: The config to be validated.
|
||||
@@ -277,10 +273,10 @@ class Crew(FlowTrackable, BaseModel):
|
||||
"""
|
||||
|
||||
# TODO: Improve typing
|
||||
return json.loads(v) if isinstance(v, Json) else v # type: ignore
|
||||
return json.loads(v) if isinstance(v, str) else v
|
||||
|
||||
@model_validator(mode="after")
|
||||
def set_private_attrs(self) -> "Crew":
|
||||
def set_private_attrs(self) -> Self:
|
||||
"""Set private attributes."""
|
||||
|
||||
self._cache_handler = CacheHandler()
|
||||
@@ -300,7 +296,7 @@ class Crew(FlowTrackable, BaseModel):
|
||||
|
||||
return self
|
||||
|
||||
def _initialize_default_memories(self):
|
||||
def _initialize_default_memories(self) -> None:
|
||||
self._long_term_memory = self._long_term_memory or LongTermMemory()
|
||||
self._short_term_memory = self._short_term_memory or ShortTermMemory(
|
||||
crew=self,
|
||||
@@ -311,7 +307,7 @@ class Crew(FlowTrackable, BaseModel):
|
||||
)
|
||||
|
||||
@model_validator(mode="after")
|
||||
def create_crew_memory(self) -> "Crew":
|
||||
def create_crew_memory(self) -> Self:
|
||||
"""Initialize private memory attributes."""
|
||||
self._external_memory = (
|
||||
# External memory doesn’t support a default value since it was designed to be managed entirely externally
|
||||
@@ -328,7 +324,7 @@ class Crew(FlowTrackable, BaseModel):
|
||||
return self
|
||||
|
||||
@model_validator(mode="after")
|
||||
def create_crew_knowledge(self) -> "Crew":
|
||||
def create_crew_knowledge(self) -> Self:
|
||||
"""Create the knowledge for the crew."""
|
||||
if self.knowledge_sources:
|
||||
try:
|
||||
@@ -349,7 +345,7 @@ class Crew(FlowTrackable, BaseModel):
|
||||
return self
|
||||
|
||||
@model_validator(mode="after")
|
||||
def check_manager_llm(self):
|
||||
def check_manager_llm(self) -> Self:
|
||||
"""Validates that the language model is set when using hierarchical process."""
|
||||
if self.process == Process.hierarchical:
|
||||
if not self.manager_llm and not self.manager_agent:
|
||||
@@ -371,7 +367,7 @@ class Crew(FlowTrackable, BaseModel):
|
||||
return self
|
||||
|
||||
@model_validator(mode="after")
|
||||
def check_config(self):
|
||||
def check_config(self) -> Self:
|
||||
"""Validates that the crew is properly configured with agents and tasks."""
|
||||
if not self.config and not self.tasks and not self.agents:
|
||||
raise PydanticCustomError(
|
||||
@@ -392,20 +388,20 @@ class Crew(FlowTrackable, BaseModel):
|
||||
return self
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_tasks(self):
|
||||
def validate_tasks(self) -> Self:
|
||||
if self.process == Process.sequential:
|
||||
for task in self.tasks:
|
||||
if task.agent is None:
|
||||
raise PydanticCustomError(
|
||||
"missing_agent_in_task",
|
||||
f"Sequential process error: Agent is missing in the task with the following description: {task.description}", # type: ignore # Argument of type "str" cannot be assigned to parameter "message_template" of type "LiteralString"
|
||||
f"Sequential process error: Agent is missing in the task with the following description: {task.description}",
|
||||
{},
|
||||
)
|
||||
|
||||
return self
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_end_with_at_most_one_async_task(self):
|
||||
def validate_end_with_at_most_one_async_task(self) -> Self:
|
||||
"""Validates that the crew ends with at most one asynchronous task."""
|
||||
final_async_task_count = 0
|
||||
|
||||
@@ -426,7 +422,7 @@ class Crew(FlowTrackable, BaseModel):
|
||||
return self
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_must_have_non_conditional_task(self) -> "Crew":
|
||||
def validate_must_have_non_conditional_task(self) -> Self:
|
||||
"""Ensure that a crew has at least one non-conditional task."""
|
||||
if not self.tasks:
|
||||
return self
|
||||
@@ -442,7 +438,7 @@ class Crew(FlowTrackable, BaseModel):
|
||||
return self
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_first_task(self) -> "Crew":
|
||||
def validate_first_task(self) -> Self:
|
||||
"""Ensure the first task is not a ConditionalTask."""
|
||||
if self.tasks and isinstance(self.tasks[0], ConditionalTask):
|
||||
raise PydanticCustomError(
|
||||
@@ -453,19 +449,21 @@ class Crew(FlowTrackable, BaseModel):
|
||||
return self
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_async_tasks_not_async(self) -> "Crew":
|
||||
def validate_async_tasks_not_async(self) -> Self:
|
||||
"""Ensure that ConditionalTask is not async."""
|
||||
for task in self.tasks:
|
||||
if task.async_execution and isinstance(task, ConditionalTask):
|
||||
raise PydanticCustomError(
|
||||
"invalid_async_conditional_task",
|
||||
f"Conditional Task: {task.description} , cannot be executed asynchronously.", # type: ignore # Argument of type "str" cannot be assigned to parameter "message_template" of type "LiteralString"
|
||||
f"Conditional Task: {task.description} , cannot be executed asynchronously.",
|
||||
{},
|
||||
)
|
||||
return self
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_async_task_cannot_include_sequential_async_tasks_in_context(self):
|
||||
def validate_async_task_cannot_include_sequential_async_tasks_in_context(
|
||||
self,
|
||||
) -> Self:
|
||||
"""
|
||||
Validates that if a task is set to be executed asynchronously,
|
||||
it cannot include other asynchronous tasks in its context unless
|
||||
@@ -485,7 +483,7 @@ class Crew(FlowTrackable, BaseModel):
|
||||
return self
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_context_no_future_tasks(self):
|
||||
def validate_context_no_future_tasks(self) -> Self:
|
||||
"""Validates that a task's context does not include future tasks."""
|
||||
task_indices = {id(task): i for i, task in enumerate(self.tasks)}
|
||||
|
||||
@@ -502,7 +500,7 @@ class Crew(FlowTrackable, BaseModel):
|
||||
|
||||
@property
|
||||
def key(self) -> str:
|
||||
source: List[str] = [agent.key for agent in self.agents] + [
|
||||
source: list[str] = [agent.key for agent in self.agents] + [
|
||||
task.key for task in self.tasks
|
||||
]
|
||||
return md5("|".join(source).encode(), usedforsecurity=False).hexdigest()
|
||||
@@ -517,7 +515,7 @@ class Crew(FlowTrackable, BaseModel):
|
||||
"""
|
||||
return self.security_config.fingerprint
|
||||
|
||||
def _setup_from_config(self):
|
||||
def _setup_from_config(self) -> None:
|
||||
assert self.config is not None, "Config should not be None."
|
||||
|
||||
"""Initializes agents and tasks from the provided config."""
|
||||
@@ -530,7 +528,7 @@ class Crew(FlowTrackable, BaseModel):
|
||||
self.agents = [Agent(**agent) for agent in self.config["agents"]]
|
||||
self.tasks = [self._create_task(task) for task in self.config["tasks"]]
|
||||
|
||||
def _create_task(self, task_config: Dict[str, Any]) -> Task:
|
||||
def _create_task(self, task_config: dict[str, Any]) -> Task:
|
||||
"""Creates a task instance from its configuration.
|
||||
|
||||
Args:
|
||||
@@ -559,7 +557,7 @@ class Crew(FlowTrackable, BaseModel):
|
||||
CrewTrainingHandler(filename).initialize_file()
|
||||
|
||||
def train(
|
||||
self, n_iterations: int, filename: str, inputs: Optional[Dict[str, Any]] = None
|
||||
self, n_iterations: int, filename: str, inputs: Optional[dict[str, Any]] = None
|
||||
) -> None:
|
||||
"""Trains the crew for a given number of iterations."""
|
||||
inputs = inputs or {}
|
||||
@@ -611,7 +609,7 @@ class Crew(FlowTrackable, BaseModel):
|
||||
|
||||
def kickoff(
|
||||
self,
|
||||
inputs: Optional[Dict[str, Any]] = None,
|
||||
inputs: Optional[dict[str, Any]] = None,
|
||||
) -> CrewOutput:
|
||||
ctx = baggage.set_baggage(
|
||||
"crew_context", CrewContext(id=str(self.id), key=self.key)
|
||||
@@ -643,8 +641,7 @@ class Crew(FlowTrackable, BaseModel):
|
||||
|
||||
for agent in self.agents:
|
||||
agent.i18n = i18n
|
||||
# type: ignore[attr-defined] # Argument 1 to "_interpolate_inputs" of "Crew" has incompatible type "dict[str, Any] | None"; expected "dict[str, Any]"
|
||||
agent.crew = self # type: ignore[attr-defined]
|
||||
agent.crew = self
|
||||
agent.set_knowledge(crew_embedder=self.embedder)
|
||||
# TODO: Create an AgentFunctionCalling protocol for future refactoring
|
||||
if not agent.function_calling_llm: # type: ignore # "BaseAgent" has no attribute "function_calling_llm"
|
||||
@@ -653,7 +650,7 @@ class Crew(FlowTrackable, BaseModel):
|
||||
if not agent.step_callback: # type: ignore # "BaseAgent" has no attribute "step_callback"
|
||||
agent.step_callback = self.step_callback # type: ignore # "BaseAgent" has no attribute "step_callback"
|
||||
|
||||
agent.create_agent_executor()
|
||||
# Agent executor will be created when tasks are executed
|
||||
|
||||
if self.planning:
|
||||
self._handle_crew_planning()
|
||||
@@ -682,9 +679,9 @@ class Crew(FlowTrackable, BaseModel):
|
||||
finally:
|
||||
detach(token)
|
||||
|
||||
def kickoff_for_each(self, inputs: List[Dict[str, Any]]) -> List[CrewOutput]:
|
||||
def kickoff_for_each(self, inputs: list[dict[str, Any]]) -> list[CrewOutput]:
|
||||
"""Executes the Crew's workflow for each input in the list and aggregates results."""
|
||||
results: List[CrewOutput] = []
|
||||
results: list[CrewOutput] = []
|
||||
|
||||
# Initialize the parent crew's usage metrics
|
||||
total_usage_metrics = UsageMetrics()
|
||||
@@ -704,16 +701,18 @@ class Crew(FlowTrackable, BaseModel):
|
||||
return results
|
||||
|
||||
async def kickoff_async(
|
||||
self, inputs: Optional[Dict[str, Any]] = None
|
||||
self, inputs: Optional[dict[str, Any]] = None
|
||||
) -> CrewOutput:
|
||||
"""Asynchronous kickoff method to start the crew execution."""
|
||||
inputs = inputs or {}
|
||||
return await asyncio.to_thread(self.kickoff, inputs)
|
||||
|
||||
async def kickoff_for_each_async(self, inputs: List[Dict]) -> List[CrewOutput]:
|
||||
async def kickoff_for_each_async(
|
||||
self, inputs: list[dict[str, Any]]
|
||||
) -> list[CrewOutput]:
|
||||
crew_copies = [self.copy() for _ in inputs]
|
||||
|
||||
async def run_crew(crew, input_data):
|
||||
async def run_crew(crew: Self, input_data: dict[str, Any]) -> CrewOutput:
|
||||
return await crew.kickoff_async(inputs=input_data)
|
||||
|
||||
tasks = [
|
||||
@@ -732,7 +731,7 @@ class Crew(FlowTrackable, BaseModel):
|
||||
self._task_output_handler.reset()
|
||||
return results
|
||||
|
||||
def _handle_crew_planning(self):
|
||||
def _handle_crew_planning(self) -> None:
|
||||
"""Handles the Crew planning."""
|
||||
self._logger.log("info", "Planning the crew execution")
|
||||
result = CrewPlanner(
|
||||
@@ -748,7 +747,7 @@ class Crew(FlowTrackable, BaseModel):
|
||||
output: TaskOutput,
|
||||
task_index: int,
|
||||
was_replayed: bool = False,
|
||||
):
|
||||
) -> None:
|
||||
if self._inputs:
|
||||
inputs = self._inputs
|
||||
else:
|
||||
@@ -780,7 +779,7 @@ class Crew(FlowTrackable, BaseModel):
|
||||
self._create_manager_agent()
|
||||
return self._execute_tasks(self.tasks)
|
||||
|
||||
def _create_manager_agent(self):
|
||||
def _create_manager_agent(self) -> None:
|
||||
i18n = I18N(prompt_file=self.prompt_file)
|
||||
if self.manager_agent is not None:
|
||||
self.manager_agent.allow_delegation = True
|
||||
@@ -792,7 +791,12 @@ class Crew(FlowTrackable, BaseModel):
|
||||
manager.tools = []
|
||||
raise Exception("Manager agent should not have tools")
|
||||
else:
|
||||
self.manager_llm = create_llm(self.manager_llm)
|
||||
if self.manager_llm is None:
|
||||
from crewai.utilities.llm_utils import create_default_llm
|
||||
|
||||
self.manager_llm = create_default_llm()
|
||||
else:
|
||||
self.manager_llm = create_llm(self.manager_llm)
|
||||
manager = Agent(
|
||||
role=i18n.retrieve("hierarchical_manager_agent", "role"),
|
||||
goal=i18n.retrieve("hierarchical_manager_agent", "goal"),
|
||||
@@ -807,7 +811,7 @@ class Crew(FlowTrackable, BaseModel):
|
||||
|
||||
def _execute_tasks(
|
||||
self,
|
||||
tasks: List[Task],
|
||||
tasks: list[Task],
|
||||
start_index: Optional[int] = 0,
|
||||
was_replayed: bool = False,
|
||||
) -> CrewOutput:
|
||||
@@ -821,8 +825,8 @@ class Crew(FlowTrackable, BaseModel):
|
||||
CrewOutput: Final output of the crew
|
||||
"""
|
||||
|
||||
task_outputs: List[TaskOutput] = []
|
||||
futures: List[Tuple[Task, Future[TaskOutput], int]] = []
|
||||
task_outputs: list[TaskOutput] = []
|
||||
futures: list[tuple[Task, Future[TaskOutput], int]] = []
|
||||
last_sync_output: Optional[TaskOutput] = None
|
||||
|
||||
for task_index, task in enumerate(tasks):
|
||||
@@ -847,7 +851,7 @@ class Crew(FlowTrackable, BaseModel):
|
||||
tools_for_task = self._prepare_tools(
|
||||
agent_to_use,
|
||||
task,
|
||||
cast(Union[List[Tool], List[BaseTool]], tools_for_task),
|
||||
cast(list[Tool] | list[BaseTool], tools_for_task),
|
||||
)
|
||||
|
||||
self._log_task_start(task, agent_to_use.role)
|
||||
@@ -867,7 +871,7 @@ class Crew(FlowTrackable, BaseModel):
|
||||
future = task.execute_async(
|
||||
agent=agent_to_use,
|
||||
context=context,
|
||||
tools=cast(List[BaseTool], tools_for_task),
|
||||
tools=tools_for_task,
|
||||
)
|
||||
futures.append((task, future, task_index))
|
||||
else:
|
||||
@@ -879,7 +883,7 @@ class Crew(FlowTrackable, BaseModel):
|
||||
task_output = task.execute_sync(
|
||||
agent=agent_to_use,
|
||||
context=context,
|
||||
tools=cast(List[BaseTool], tools_for_task),
|
||||
tools=tools_for_task,
|
||||
)
|
||||
task_outputs.append(task_output)
|
||||
self._process_task_result(task, task_output)
|
||||
@@ -893,8 +897,8 @@ class Crew(FlowTrackable, BaseModel):
|
||||
def _handle_conditional_task(
|
||||
self,
|
||||
task: ConditionalTask,
|
||||
task_outputs: List[TaskOutput],
|
||||
futures: List[Tuple[Task, Future[TaskOutput], int]],
|
||||
task_outputs: list[TaskOutput],
|
||||
futures: list[tuple[Task, Future[TaskOutput], int]],
|
||||
task_index: int,
|
||||
was_replayed: bool,
|
||||
) -> Optional[TaskOutput]:
|
||||
@@ -917,8 +921,8 @@ class Crew(FlowTrackable, BaseModel):
|
||||
return None
|
||||
|
||||
def _prepare_tools(
|
||||
self, agent: BaseAgent, task: Task, tools: Union[List[Tool], List[BaseTool]]
|
||||
) -> List[BaseTool]:
|
||||
self, agent: BaseAgent, task: Task, tools: list[Tool] | list[BaseTool]
|
||||
) -> list[BaseTool]:
|
||||
# Add delegation tools if agent allows delegation
|
||||
if hasattr(agent, "allow_delegation") and getattr(
|
||||
agent, "allow_delegation", False
|
||||
@@ -948,7 +952,7 @@ class Crew(FlowTrackable, BaseModel):
|
||||
tools = self._add_multimodal_tools(agent, tools)
|
||||
|
||||
# Return a List[BaseTool] which is compatible with both Task.execute_sync and Task.execute_async
|
||||
return cast(List[BaseTool], tools)
|
||||
return cast(list[BaseTool], tools)
|
||||
|
||||
def _get_agent_to_use(self, task: Task) -> Optional[BaseAgent]:
|
||||
if self.process == Process.hierarchical:
|
||||
@@ -957,12 +961,12 @@ class Crew(FlowTrackable, BaseModel):
|
||||
|
||||
def _merge_tools(
|
||||
self,
|
||||
existing_tools: Union[List[Tool], List[BaseTool]],
|
||||
new_tools: Union[List[Tool], List[BaseTool]],
|
||||
) -> List[BaseTool]:
|
||||
existing_tools: list[Tool] | list[BaseTool],
|
||||
new_tools: list[Tool] | list[BaseTool],
|
||||
) -> list[BaseTool]:
|
||||
"""Merge new tools into existing tools list, avoiding duplicates by tool name."""
|
||||
if not new_tools:
|
||||
return cast(List[BaseTool], existing_tools)
|
||||
return cast(list[BaseTool], existing_tools)
|
||||
|
||||
# Create mapping of tool names to new tools
|
||||
new_tool_map = {tool.name: tool for tool in new_tools}
|
||||
@@ -973,41 +977,41 @@ class Crew(FlowTrackable, BaseModel):
|
||||
# Add all new tools
|
||||
tools.extend(new_tools)
|
||||
|
||||
return cast(List[BaseTool], tools)
|
||||
return tools
|
||||
|
||||
def _inject_delegation_tools(
|
||||
self,
|
||||
tools: Union[List[Tool], List[BaseTool]],
|
||||
tools: list[Tool] | list[BaseTool],
|
||||
task_agent: BaseAgent,
|
||||
agents: List[BaseAgent],
|
||||
) -> List[BaseTool]:
|
||||
agents: list[BaseAgent],
|
||||
) -> list[BaseTool]:
|
||||
if hasattr(task_agent, "get_delegation_tools"):
|
||||
delegation_tools = task_agent.get_delegation_tools(agents)
|
||||
# Cast delegation_tools to the expected type for _merge_tools
|
||||
return self._merge_tools(tools, cast(List[BaseTool], delegation_tools))
|
||||
return cast(List[BaseTool], tools)
|
||||
return self._merge_tools(tools, delegation_tools)
|
||||
return cast(list[BaseTool], tools)
|
||||
|
||||
def _add_multimodal_tools(
|
||||
self, agent: BaseAgent, tools: Union[List[Tool], List[BaseTool]]
|
||||
) -> List[BaseTool]:
|
||||
self, agent: BaseAgent, tools: list[Tool] | list[BaseTool]
|
||||
) -> list[BaseTool]:
|
||||
if hasattr(agent, "get_multimodal_tools"):
|
||||
multimodal_tools = agent.get_multimodal_tools()
|
||||
# Cast multimodal_tools to the expected type for _merge_tools
|
||||
return self._merge_tools(tools, cast(List[BaseTool], multimodal_tools))
|
||||
return cast(List[BaseTool], tools)
|
||||
return self._merge_tools(tools, cast(list[BaseTool], multimodal_tools))
|
||||
return cast(list[BaseTool], tools)
|
||||
|
||||
def _add_code_execution_tools(
|
||||
self, agent: BaseAgent, tools: Union[List[Tool], List[BaseTool]]
|
||||
) -> List[BaseTool]:
|
||||
self, agent: BaseAgent, tools: list[Tool] | list[BaseTool]
|
||||
) -> list[BaseTool]:
|
||||
if hasattr(agent, "get_code_execution_tools"):
|
||||
code_tools = agent.get_code_execution_tools()
|
||||
# Cast code_tools to the expected type for _merge_tools
|
||||
return self._merge_tools(tools, cast(List[BaseTool], code_tools))
|
||||
return cast(List[BaseTool], tools)
|
||||
return self._merge_tools(tools, cast(list[BaseTool], code_tools))
|
||||
return cast(list[BaseTool], tools)
|
||||
|
||||
def _add_delegation_tools(
|
||||
self, task: Task, tools: Union[List[Tool], List[BaseTool]]
|
||||
) -> List[BaseTool]:
|
||||
self, task: Task, tools: list[Tool] | list[BaseTool]
|
||||
) -> list[BaseTool]:
|
||||
agents_for_delegation = [agent for agent in self.agents if agent != task.agent]
|
||||
if len(self.agents) > 1 and len(agents_for_delegation) > 0 and task.agent:
|
||||
if not tools:
|
||||
@@ -1015,17 +1019,17 @@ class Crew(FlowTrackable, BaseModel):
|
||||
tools = self._inject_delegation_tools(
|
||||
tools, task.agent, agents_for_delegation
|
||||
)
|
||||
return cast(List[BaseTool], tools)
|
||||
return cast(list[BaseTool], tools)
|
||||
|
||||
def _log_task_start(self, task: Task, role: str = "None"):
|
||||
def _log_task_start(self, task: Task, role: str = "None") -> None:
|
||||
if self.output_log_file:
|
||||
self._file_handler.log(
|
||||
task_name=task.name, task=task.description, agent=role, status="started"
|
||||
)
|
||||
|
||||
def _update_manager_tools(
|
||||
self, task: Task, tools: Union[List[Tool], List[BaseTool]]
|
||||
) -> List[BaseTool]:
|
||||
self, task: Task, tools: list[Tool] | list[BaseTool]
|
||||
) -> list[BaseTool]:
|
||||
if self.manager_agent:
|
||||
if task.agent:
|
||||
tools = self._inject_delegation_tools(tools, task.agent, [task.agent])
|
||||
@@ -1033,9 +1037,9 @@ class Crew(FlowTrackable, BaseModel):
|
||||
tools = self._inject_delegation_tools(
|
||||
tools, self.manager_agent, self.agents
|
||||
)
|
||||
return cast(List[BaseTool], tools)
|
||||
return cast(list[BaseTool], tools)
|
||||
|
||||
def _get_context(self, task: Task, task_outputs: List[TaskOutput]) -> str:
|
||||
def _get_context(self, task: Task, task_outputs: list[TaskOutput]) -> str:
|
||||
if not task.context:
|
||||
return ""
|
||||
|
||||
@@ -1057,7 +1061,7 @@ class Crew(FlowTrackable, BaseModel):
|
||||
output=output.raw,
|
||||
)
|
||||
|
||||
def _create_crew_output(self, task_outputs: List[TaskOutput]) -> CrewOutput:
|
||||
def _create_crew_output(self, task_outputs: list[TaskOutput]) -> CrewOutput:
|
||||
if not task_outputs:
|
||||
raise ValueError("No task outputs available to create crew output.")
|
||||
|
||||
@@ -1088,10 +1092,10 @@ class Crew(FlowTrackable, BaseModel):
|
||||
|
||||
def _process_async_tasks(
|
||||
self,
|
||||
futures: List[Tuple[Task, Future[TaskOutput], int]],
|
||||
futures: list[tuple[Task, Future[TaskOutput], int]],
|
||||
was_replayed: bool = False,
|
||||
) -> List[TaskOutput]:
|
||||
task_outputs: List[TaskOutput] = []
|
||||
) -> list[TaskOutput]:
|
||||
task_outputs: list[TaskOutput] = []
|
||||
for future_task, future, task_index in futures:
|
||||
task_output = future.result()
|
||||
task_outputs.append(task_output)
|
||||
@@ -1102,7 +1106,7 @@ class Crew(FlowTrackable, BaseModel):
|
||||
return task_outputs
|
||||
|
||||
def _find_task_index(
|
||||
self, task_id: str, stored_outputs: List[Any]
|
||||
self, task_id: str, stored_outputs: list[Any]
|
||||
) -> Optional[int]:
|
||||
return next(
|
||||
(
|
||||
@@ -1114,7 +1118,7 @@ class Crew(FlowTrackable, BaseModel):
|
||||
)
|
||||
|
||||
def replay(
|
||||
self, task_id: str, inputs: Optional[Dict[str, Any]] = None
|
||||
self, task_id: str, inputs: Optional[dict[str, Any]] = None
|
||||
) -> CrewOutput:
|
||||
stored_outputs = self._task_output_handler.load()
|
||||
if not stored_outputs:
|
||||
@@ -1155,15 +1159,15 @@ class Crew(FlowTrackable, BaseModel):
|
||||
return result
|
||||
|
||||
def query_knowledge(
|
||||
self, query: List[str], results_limit: int = 3, score_threshold: float = 0.35
|
||||
) -> Union[List[Dict[str, Any]], None]:
|
||||
self, query: list[str], results_limit: int = 3, score_threshold: float = 0.35
|
||||
) -> list[dict[str, Any]] | None:
|
||||
if self.knowledge:
|
||||
return self.knowledge.query(
|
||||
query, results_limit=results_limit, score_threshold=score_threshold
|
||||
)
|
||||
return None
|
||||
|
||||
def fetch_inputs(self) -> Set[str]:
|
||||
def fetch_inputs(self) -> set[str]:
|
||||
"""
|
||||
Gathers placeholders (e.g., {something}) referenced in tasks or agents.
|
||||
Scans each task's 'description' + 'expected_output', and each agent's
|
||||
@@ -1172,7 +1176,7 @@ class Crew(FlowTrackable, BaseModel):
|
||||
Returns a set of all discovered placeholder names.
|
||||
"""
|
||||
placeholder_pattern = re.compile(r"\{(.+?)\}")
|
||||
required_inputs: Set[str] = set()
|
||||
required_inputs: set[str] = set()
|
||||
|
||||
# Scan tasks for inputs
|
||||
for task in self.tasks:
|
||||
@@ -1188,7 +1192,18 @@ class Crew(FlowTrackable, BaseModel):
|
||||
|
||||
return required_inputs
|
||||
|
||||
def copy(self):
|
||||
def copy(
|
||||
self,
|
||||
*,
|
||||
include: Optional[
|
||||
Set[int] | Set[str] | Mapping[int, Any] | Mapping[str, Any]
|
||||
] = None,
|
||||
exclude: Optional[
|
||||
Set[int] | Set[str] | Mapping[int, Any] | Mapping[str, Any]
|
||||
] = None,
|
||||
update: Optional[dict[str, Any]] = None,
|
||||
deep: bool = True,
|
||||
) -> "Crew":
|
||||
"""
|
||||
Creates a deep copy of the Crew instance.
|
||||
|
||||
@@ -1219,7 +1234,7 @@ class Crew(FlowTrackable, BaseModel):
|
||||
manager_agent = self.manager_agent.copy() if self.manager_agent else None
|
||||
manager_llm = shallow_copy(self.manager_llm) if self.manager_llm else None
|
||||
|
||||
task_mapping = {}
|
||||
task_mapping: dict[str, Task] = {}
|
||||
|
||||
cloned_tasks = []
|
||||
existing_knowledge_sources = shallow_copy(self.knowledge_sources)
|
||||
@@ -1274,16 +1289,10 @@ class Crew(FlowTrackable, BaseModel):
|
||||
if not task.callback:
|
||||
task.callback = self.task_callback
|
||||
|
||||
def _interpolate_inputs(self, inputs: Dict[str, Any]) -> None:
|
||||
def _interpolate_inputs(self, inputs: dict[str, Any]) -> None:
|
||||
"""Interpolates the inputs in the tasks and agents."""
|
||||
[
|
||||
task.interpolate_inputs_and_add_conversation_history(
|
||||
# type: ignore # "interpolate_inputs" of "Task" does not return a value (it only ever returns None)
|
||||
inputs
|
||||
)
|
||||
for task in self.tasks
|
||||
]
|
||||
# type: ignore # "interpolate_inputs" of "Agent" does not return a value (it only ever returns None)
|
||||
for task in self.tasks:
|
||||
task.interpolate_inputs_and_add_conversation_history(inputs)
|
||||
for agent in self.agents:
|
||||
agent.interpolate_inputs(inputs)
|
||||
|
||||
@@ -1307,8 +1316,8 @@ class Crew(FlowTrackable, BaseModel):
|
||||
def test(
|
||||
self,
|
||||
n_iterations: int,
|
||||
eval_llm: Union[str, InstanceOf[BaseLLM]],
|
||||
inputs: Optional[Dict[str, Any]] = None,
|
||||
eval_llm: str | InstanceOf[BaseLLM],
|
||||
inputs: Optional[dict[str, Any]] = None,
|
||||
) -> None:
|
||||
"""Test and evaluate the Crew with the given inputs for n iterations concurrently using concurrent.futures."""
|
||||
try:
|
||||
@@ -1349,7 +1358,7 @@ class Crew(FlowTrackable, BaseModel):
|
||||
)
|
||||
raise
|
||||
|
||||
def __repr__(self):
|
||||
def __repr__(self) -> str:
|
||||
return f"Crew(id={self.id}, process={self.process}, number_of_agents={len(self.agents)}, number_of_tasks={len(self.tasks)})"
|
||||
|
||||
def reset_memories(self, command_type: str) -> None:
|
||||
@@ -1401,7 +1410,9 @@ class Crew(FlowTrackable, BaseModel):
|
||||
if (system := config.get("system")) is not None:
|
||||
name = config.get("name")
|
||||
try:
|
||||
reset_fn: Callable = cast(Callable, config.get("reset"))
|
||||
reset_fn: Callable[..., None] = cast(
|
||||
Callable[..., None], config.get("reset")
|
||||
)
|
||||
reset_fn(system)
|
||||
self._logger.log(
|
||||
"info",
|
||||
@@ -1430,7 +1441,9 @@ class Crew(FlowTrackable, BaseModel):
|
||||
raise RuntimeError(f"{name} memory system is not initialized")
|
||||
|
||||
try:
|
||||
reset_fn: Callable = cast(Callable, config.get("reset"))
|
||||
reset_fn: Callable[..., None] = cast(
|
||||
Callable[..., None], config.get("reset")
|
||||
)
|
||||
reset_fn(system)
|
||||
self._logger.log(
|
||||
"info",
|
||||
@@ -1441,18 +1454,18 @@ class Crew(FlowTrackable, BaseModel):
|
||||
f"[Crew ({self.name if self.name else self.id})] Failed to reset {name} memory: {str(e)}"
|
||||
) from e
|
||||
|
||||
def _get_memory_systems(self):
|
||||
def _get_memory_systems(self) -> dict[str, dict[str, Any]]:
|
||||
"""Get all available memory systems with their configuration.
|
||||
|
||||
Returns:
|
||||
Dict containing all memory systems with their reset functions and display names.
|
||||
"""
|
||||
|
||||
def default_reset(memory):
|
||||
return memory.reset()
|
||||
def default_reset(memory: Any) -> None:
|
||||
memory.reset()
|
||||
|
||||
def knowledge_reset(memory):
|
||||
return self.reset_knowledge(memory)
|
||||
def knowledge_reset(memory: Any) -> None:
|
||||
self.reset_knowledge(memory)
|
||||
|
||||
# Get knowledge for agents
|
||||
agent_knowledges = [
|
||||
@@ -1506,12 +1519,12 @@ class Crew(FlowTrackable, BaseModel):
|
||||
},
|
||||
}
|
||||
|
||||
def reset_knowledge(self, knowledges: List[Knowledge]) -> None:
|
||||
def reset_knowledge(self, knowledges: list[Knowledge]) -> None:
|
||||
"""Reset crew and agent knowledge storage."""
|
||||
for ks in knowledges:
|
||||
ks.reset()
|
||||
|
||||
def _set_allow_crewai_trigger_context_for_first_task(self):
|
||||
def _set_allow_crewai_trigger_context_for_first_task(self) -> None:
|
||||
crewai_trigger_payload = self._inputs and self._inputs.get(
|
||||
"crewai_trigger_payload"
|
||||
)
|
||||
|
||||
@@ -6,10 +6,10 @@ from crewai.events.event_bus import CrewAIEventsBus, crewai_event_bus
|
||||
class BaseEventListener(ABC):
|
||||
verbose: bool = False
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.setup_listeners(crewai_event_bus)
|
||||
|
||||
@abstractmethod
|
||||
def setup_listeners(self, crewai_event_bus: CrewAIEventsBus):
|
||||
def setup_listeners(self, crewai_event_bus: CrewAIEventsBus) -> None:
|
||||
pass
|
||||
|
||||
@@ -1,15 +1,17 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import threading
|
||||
from collections.abc import Callable, Iterator
|
||||
from contextlib import contextmanager
|
||||
from typing import Any, Callable, Dict, List, Type, TypeVar, cast
|
||||
from typing import Any, ParamSpec, TypeVar, cast
|
||||
|
||||
from blinker import Signal
|
||||
from typing_extensions import Self
|
||||
|
||||
from crewai.events.base_events import BaseEvent
|
||||
from crewai.events.event_types import EventTypes
|
||||
|
||||
EventT = TypeVar("EventT", bound=BaseEvent)
|
||||
P = ParamSpec("P")
|
||||
|
||||
|
||||
class CrewAIEventsBus:
|
||||
@@ -21,21 +23,21 @@ class CrewAIEventsBus:
|
||||
_instance = None
|
||||
_lock = threading.Lock()
|
||||
|
||||
def __new__(cls):
|
||||
def __new__(cls) -> Self:
|
||||
if cls._instance is None:
|
||||
with cls._lock:
|
||||
if cls._instance is None: # prevent race condition
|
||||
cls._instance = super(CrewAIEventsBus, cls).__new__(cls)
|
||||
cls._instance = super().__new__(cls)
|
||||
cls._instance._initialize()
|
||||
return cls._instance
|
||||
|
||||
def _initialize(self) -> None:
|
||||
"""Initialize the event bus internal state"""
|
||||
self._signal = Signal("crewai_event_bus")
|
||||
self._handlers: Dict[Type[BaseEvent], List[Callable]] = {}
|
||||
self._handlers: dict[type[BaseEvent], list[Callable[[Any, Any], None]]] = {}
|
||||
|
||||
def on(
|
||||
self, event_type: Type[EventT]
|
||||
self, event_type: type[EventT]
|
||||
) -> Callable[[Callable[[Any, EventT], None]], Callable[[Any, EventT], None]]:
|
||||
"""
|
||||
Decorator to register an event handler for a specific event type.
|
||||
@@ -54,9 +56,7 @@ class CrewAIEventsBus:
|
||||
) -> Callable[[Any, EventT], None]:
|
||||
if event_type not in self._handlers:
|
||||
self._handlers[event_type] = []
|
||||
self._handlers[event_type].append(
|
||||
cast(Callable[[Any, EventT], None], handler)
|
||||
)
|
||||
self._handlers[event_type].append(cast(Callable[[Any, Any], None], handler))
|
||||
return handler
|
||||
|
||||
return decorator
|
||||
@@ -82,17 +82,15 @@ class CrewAIEventsBus:
|
||||
self._signal.send(source, event=event)
|
||||
|
||||
def register_handler(
|
||||
self, event_type: Type[EventTypes], handler: Callable[[Any, EventTypes], None]
|
||||
self, event_type: type[BaseEvent], handler: Callable[[Any, Any], None]
|
||||
) -> None:
|
||||
"""Register an event handler for a specific event type"""
|
||||
if event_type not in self._handlers:
|
||||
self._handlers[event_type] = []
|
||||
self._handlers[event_type].append(
|
||||
cast(Callable[[Any, EventTypes], None], handler)
|
||||
)
|
||||
self._handlers[event_type].append(handler)
|
||||
|
||||
@contextmanager
|
||||
def scoped_handlers(self):
|
||||
def scoped_handlers(self) -> Iterator[None]:
|
||||
"""
|
||||
Context manager for temporary event handling scope.
|
||||
Useful for testing or temporary event handling.
|
||||
|
||||
@@ -1,15 +1,32 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from io import StringIO
|
||||
from typing import Any, Dict
|
||||
from typing import Any
|
||||
|
||||
from pydantic import Field, PrivateAttr
|
||||
from crewai.llm import LLM
|
||||
from crewai.task import Task
|
||||
from crewai.telemetry.telemetry import Telemetry
|
||||
from crewai.utilities import Logger
|
||||
from crewai.utilities.constants import EMITTER_COLOR
|
||||
from typing_extensions import Self
|
||||
|
||||
from crewai.events.base_event_listener import BaseEventListener
|
||||
from crewai.events.event_bus import CrewAIEventsBus
|
||||
from crewai.events.types.agent_events import (
|
||||
AgentExecutionCompletedEvent,
|
||||
AgentExecutionStartedEvent,
|
||||
LiteAgentExecutionCompletedEvent,
|
||||
LiteAgentExecutionErrorEvent,
|
||||
LiteAgentExecutionStartedEvent,
|
||||
)
|
||||
from crewai.events.types.crew_events import (
|
||||
CrewKickoffCompletedEvent,
|
||||
CrewKickoffFailedEvent,
|
||||
CrewKickoffStartedEvent,
|
||||
CrewTestCompletedEvent,
|
||||
CrewTestFailedEvent,
|
||||
CrewTestResultEvent,
|
||||
CrewTestStartedEvent,
|
||||
CrewTrainCompletedEvent,
|
||||
CrewTrainFailedEvent,
|
||||
CrewTrainStartedEvent,
|
||||
)
|
||||
from crewai.events.types.knowledge_events import (
|
||||
KnowledgeQueryCompletedEvent,
|
||||
KnowledgeQueryFailedEvent,
|
||||
@@ -25,34 +42,21 @@ from crewai.events.types.llm_events import (
|
||||
LLMStreamChunkEvent,
|
||||
)
|
||||
from crewai.events.types.llm_guardrail_events import (
|
||||
LLMGuardrailStartedEvent,
|
||||
LLMGuardrailCompletedEvent,
|
||||
)
|
||||
from crewai.events.utils.console_formatter import ConsoleFormatter
|
||||
|
||||
from crewai.events.types.agent_events import (
|
||||
AgentExecutionCompletedEvent,
|
||||
AgentExecutionStartedEvent,
|
||||
LiteAgentExecutionCompletedEvent,
|
||||
LiteAgentExecutionErrorEvent,
|
||||
LiteAgentExecutionStartedEvent,
|
||||
LLMGuardrailStartedEvent,
|
||||
)
|
||||
from crewai.events.types.logging_events import (
|
||||
AgentLogsStartedEvent,
|
||||
AgentLogsExecutionEvent,
|
||||
AgentLogsStartedEvent,
|
||||
)
|
||||
from crewai.events.types.crew_events import (
|
||||
CrewKickoffCompletedEvent,
|
||||
CrewKickoffFailedEvent,
|
||||
CrewKickoffStartedEvent,
|
||||
CrewTestCompletedEvent,
|
||||
CrewTestFailedEvent,
|
||||
CrewTestResultEvent,
|
||||
CrewTestStartedEvent,
|
||||
CrewTrainCompletedEvent,
|
||||
CrewTrainFailedEvent,
|
||||
CrewTrainStartedEvent,
|
||||
)
|
||||
from crewai.events.utils.console_formatter import ConsoleFormatter
|
||||
from crewai.llm import LLM
|
||||
from crewai.task import Task
|
||||
from crewai.telemetry.telemetry import Telemetry
|
||||
from crewai.utilities import Logger
|
||||
from crewai.utilities.constants import EMITTER_COLOR
|
||||
|
||||
from .listeners.memory_listener import MemoryListener
|
||||
from .types.flow_events import (
|
||||
FlowCreatedEvent,
|
||||
FlowFinishedEvent,
|
||||
@@ -61,38 +65,37 @@ from .types.flow_events import (
|
||||
MethodExecutionFinishedEvent,
|
||||
MethodExecutionStartedEvent,
|
||||
)
|
||||
from .types.reasoning_events import (
|
||||
AgentReasoningCompletedEvent,
|
||||
AgentReasoningFailedEvent,
|
||||
AgentReasoningStartedEvent,
|
||||
)
|
||||
from .types.task_events import TaskCompletedEvent, TaskFailedEvent, TaskStartedEvent
|
||||
from .types.tool_usage_events import (
|
||||
ToolUsageErrorEvent,
|
||||
ToolUsageFinishedEvent,
|
||||
ToolUsageStartedEvent,
|
||||
)
|
||||
from .types.reasoning_events import (
|
||||
AgentReasoningStartedEvent,
|
||||
AgentReasoningCompletedEvent,
|
||||
AgentReasoningFailedEvent,
|
||||
)
|
||||
|
||||
from .listeners.memory_listener import MemoryListener
|
||||
|
||||
|
||||
class EventListener(BaseEventListener):
|
||||
_instance = None
|
||||
_initialized: bool = False
|
||||
_telemetry: Telemetry = PrivateAttr(default_factory=lambda: Telemetry())
|
||||
logger = Logger(verbose=True, default_color=EMITTER_COLOR)
|
||||
execution_spans: Dict[Task, Any] = Field(default_factory=dict)
|
||||
execution_spans: dict[Task, Any] = Field(default_factory=dict)
|
||||
next_chunk = 0
|
||||
text_stream = StringIO()
|
||||
knowledge_retrieval_in_progress = False
|
||||
knowledge_query_in_progress = False
|
||||
|
||||
def __new__(cls):
|
||||
def __new__(cls) -> Self:
|
||||
if cls._instance is None:
|
||||
cls._instance = super().__new__(cls)
|
||||
cls._instance._initialized = False
|
||||
return cls._instance
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
if not hasattr(self, "_initialized") or not self._initialized:
|
||||
super().__init__()
|
||||
self._telemetry = Telemetry()
|
||||
@@ -105,14 +108,14 @@ class EventListener(BaseEventListener):
|
||||
|
||||
# ----------- CREW EVENTS -----------
|
||||
|
||||
def setup_listeners(self, crewai_event_bus):
|
||||
def setup_listeners(self, crewai_event_bus: CrewAIEventsBus) -> None:
|
||||
@crewai_event_bus.on(CrewKickoffStartedEvent)
|
||||
def on_crew_started(source, event: CrewKickoffStartedEvent):
|
||||
def on_crew_started(source: Any, event: CrewKickoffStartedEvent) -> None:
|
||||
self.formatter.create_crew_tree(event.crew_name or "Crew", source.id)
|
||||
self._telemetry.crew_execution_span(source, event.inputs)
|
||||
|
||||
@crewai_event_bus.on(CrewKickoffCompletedEvent)
|
||||
def on_crew_completed(source, event: CrewKickoffCompletedEvent):
|
||||
def on_crew_completed(source: Any, event: CrewKickoffCompletedEvent) -> None:
|
||||
# Handle telemetry
|
||||
final_string_output = event.output.raw
|
||||
self._telemetry.end_crew(source, final_string_output)
|
||||
@@ -126,7 +129,7 @@ class EventListener(BaseEventListener):
|
||||
)
|
||||
|
||||
@crewai_event_bus.on(CrewKickoffFailedEvent)
|
||||
def on_crew_failed(source, event: CrewKickoffFailedEvent):
|
||||
def on_crew_failed(source: Any, event: CrewKickoffFailedEvent) -> None:
|
||||
self.formatter.update_crew_tree(
|
||||
self.formatter.current_crew_tree,
|
||||
event.crew_name or "Crew",
|
||||
@@ -135,23 +138,25 @@ class EventListener(BaseEventListener):
|
||||
)
|
||||
|
||||
@crewai_event_bus.on(CrewTrainStartedEvent)
|
||||
def on_crew_train_started(source, event: CrewTrainStartedEvent):
|
||||
def on_crew_train_started(source: Any, event: CrewTrainStartedEvent) -> None:
|
||||
self.formatter.handle_crew_train_started(
|
||||
event.crew_name or "Crew", str(event.timestamp)
|
||||
)
|
||||
|
||||
@crewai_event_bus.on(CrewTrainCompletedEvent)
|
||||
def on_crew_train_completed(source, event: CrewTrainCompletedEvent):
|
||||
def on_crew_train_completed(
|
||||
source: Any, event: CrewTrainCompletedEvent
|
||||
) -> None:
|
||||
self.formatter.handle_crew_train_completed(
|
||||
event.crew_name or "Crew", str(event.timestamp)
|
||||
)
|
||||
|
||||
@crewai_event_bus.on(CrewTrainFailedEvent)
|
||||
def on_crew_train_failed(source, event: CrewTrainFailedEvent):
|
||||
def on_crew_train_failed(source: Any, event: CrewTrainFailedEvent) -> None:
|
||||
self.formatter.handle_crew_train_failed(event.crew_name or "Crew")
|
||||
|
||||
@crewai_event_bus.on(CrewTestResultEvent)
|
||||
def on_crew_test_result(source, event: CrewTestResultEvent):
|
||||
def on_crew_test_result(source: Any, event: CrewTestResultEvent) -> None:
|
||||
self._telemetry.individual_test_result_span(
|
||||
source.crew,
|
||||
event.quality,
|
||||
@@ -162,7 +167,7 @@ class EventListener(BaseEventListener):
|
||||
# ----------- TASK EVENTS -----------
|
||||
|
||||
@crewai_event_bus.on(TaskStartedEvent)
|
||||
def on_task_started(source, event: TaskStartedEvent):
|
||||
def on_task_started(source: Any, event: TaskStartedEvent) -> None:
|
||||
span = self._telemetry.task_started(crew=source.agent.crew, task=source)
|
||||
self.execution_spans[source] = span
|
||||
# Pass both task ID and task name (if set)
|
||||
@@ -172,7 +177,7 @@ class EventListener(BaseEventListener):
|
||||
)
|
||||
|
||||
@crewai_event_bus.on(TaskCompletedEvent)
|
||||
def on_task_completed(source, event: TaskCompletedEvent):
|
||||
def on_task_completed(source: Any, event: TaskCompletedEvent) -> None:
|
||||
# Handle telemetry
|
||||
span = self.execution_spans.get(source)
|
||||
if span:
|
||||
@@ -190,7 +195,7 @@ class EventListener(BaseEventListener):
|
||||
)
|
||||
|
||||
@crewai_event_bus.on(TaskFailedEvent)
|
||||
def on_task_failed(source, event: TaskFailedEvent):
|
||||
def on_task_failed(source: Any, event: TaskFailedEvent) -> None:
|
||||
span = self.execution_spans.get(source)
|
||||
if span:
|
||||
if source.agent and source.agent.crew:
|
||||
@@ -210,7 +215,9 @@ class EventListener(BaseEventListener):
|
||||
# ----------- AGENT EVENTS -----------
|
||||
|
||||
@crewai_event_bus.on(AgentExecutionStartedEvent)
|
||||
def on_agent_execution_started(source, event: AgentExecutionStartedEvent):
|
||||
def on_agent_execution_started(
|
||||
source: Any, event: AgentExecutionStartedEvent
|
||||
) -> None:
|
||||
self.formatter.create_agent_branch(
|
||||
self.formatter.current_task_branch,
|
||||
event.agent.role,
|
||||
@@ -218,7 +225,9 @@ class EventListener(BaseEventListener):
|
||||
)
|
||||
|
||||
@crewai_event_bus.on(AgentExecutionCompletedEvent)
|
||||
def on_agent_execution_completed(source, event: AgentExecutionCompletedEvent):
|
||||
def on_agent_execution_completed(
|
||||
source: Any, event: AgentExecutionCompletedEvent
|
||||
) -> None:
|
||||
self.formatter.update_agent_status(
|
||||
self.formatter.current_agent_branch,
|
||||
event.agent.role,
|
||||
@@ -229,8 +238,8 @@ class EventListener(BaseEventListener):
|
||||
|
||||
@crewai_event_bus.on(LiteAgentExecutionStartedEvent)
|
||||
def on_lite_agent_execution_started(
|
||||
source, event: LiteAgentExecutionStartedEvent
|
||||
):
|
||||
source: Any, event: LiteAgentExecutionStartedEvent
|
||||
) -> None:
|
||||
"""Handle LiteAgent execution started event."""
|
||||
self.formatter.handle_lite_agent_execution(
|
||||
event.agent_info["role"], status="started", **event.agent_info
|
||||
@@ -238,15 +247,17 @@ class EventListener(BaseEventListener):
|
||||
|
||||
@crewai_event_bus.on(LiteAgentExecutionCompletedEvent)
|
||||
def on_lite_agent_execution_completed(
|
||||
source, event: LiteAgentExecutionCompletedEvent
|
||||
):
|
||||
source: Any, event: LiteAgentExecutionCompletedEvent
|
||||
) -> None:
|
||||
"""Handle LiteAgent execution completed event."""
|
||||
self.formatter.handle_lite_agent_execution(
|
||||
event.agent_info["role"], status="completed", **event.agent_info
|
||||
)
|
||||
|
||||
@crewai_event_bus.on(LiteAgentExecutionErrorEvent)
|
||||
def on_lite_agent_execution_error(source, event: LiteAgentExecutionErrorEvent):
|
||||
def on_lite_agent_execution_error(
|
||||
source: Any, event: LiteAgentExecutionErrorEvent
|
||||
) -> None:
|
||||
"""Handle LiteAgent execution error event."""
|
||||
self.formatter.handle_lite_agent_execution(
|
||||
event.agent_info["role"],
|
||||
@@ -258,25 +269,27 @@ class EventListener(BaseEventListener):
|
||||
# ----------- FLOW EVENTS -----------
|
||||
|
||||
@crewai_event_bus.on(FlowCreatedEvent)
|
||||
def on_flow_created(source, event: FlowCreatedEvent):
|
||||
def on_flow_created(source: Any, event: FlowCreatedEvent) -> None:
|
||||
self._telemetry.flow_creation_span(event.flow_name)
|
||||
self.formatter.create_flow_tree(event.flow_name, str(source.flow_id))
|
||||
|
||||
@crewai_event_bus.on(FlowStartedEvent)
|
||||
def on_flow_started(source, event: FlowStartedEvent):
|
||||
def on_flow_started(source: Any, event: FlowStartedEvent) -> None:
|
||||
self._telemetry.flow_execution_span(
|
||||
event.flow_name, list(source._methods.keys())
|
||||
)
|
||||
self.formatter.start_flow(event.flow_name, str(source.flow_id))
|
||||
|
||||
@crewai_event_bus.on(FlowFinishedEvent)
|
||||
def on_flow_finished(source, event: FlowFinishedEvent):
|
||||
def on_flow_finished(source: Any, event: FlowFinishedEvent) -> None:
|
||||
self.formatter.update_flow_status(
|
||||
self.formatter.current_flow_tree, event.flow_name, source.flow_id
|
||||
)
|
||||
|
||||
@crewai_event_bus.on(MethodExecutionStartedEvent)
|
||||
def on_method_execution_started(source, event: MethodExecutionStartedEvent):
|
||||
def on_method_execution_started(
|
||||
source: Any, event: MethodExecutionStartedEvent
|
||||
) -> None:
|
||||
self.formatter.update_method_status(
|
||||
self.formatter.current_method_branch,
|
||||
self.formatter.current_flow_tree,
|
||||
@@ -285,7 +298,9 @@ class EventListener(BaseEventListener):
|
||||
)
|
||||
|
||||
@crewai_event_bus.on(MethodExecutionFinishedEvent)
|
||||
def on_method_execution_finished(source, event: MethodExecutionFinishedEvent):
|
||||
def on_method_execution_finished(
|
||||
source: Any, event: MethodExecutionFinishedEvent
|
||||
) -> None:
|
||||
self.formatter.update_method_status(
|
||||
self.formatter.current_method_branch,
|
||||
self.formatter.current_flow_tree,
|
||||
@@ -294,7 +309,9 @@ class EventListener(BaseEventListener):
|
||||
)
|
||||
|
||||
@crewai_event_bus.on(MethodExecutionFailedEvent)
|
||||
def on_method_execution_failed(source, event: MethodExecutionFailedEvent):
|
||||
def on_method_execution_failed(
|
||||
source: Any, event: MethodExecutionFailedEvent
|
||||
) -> None:
|
||||
self.formatter.update_method_status(
|
||||
self.formatter.current_method_branch,
|
||||
self.formatter.current_flow_tree,
|
||||
@@ -305,7 +322,7 @@ class EventListener(BaseEventListener):
|
||||
# ----------- TOOL USAGE EVENTS -----------
|
||||
|
||||
@crewai_event_bus.on(ToolUsageStartedEvent)
|
||||
def on_tool_usage_started(source, event: ToolUsageStartedEvent):
|
||||
def on_tool_usage_started(source: Any, event: ToolUsageStartedEvent) -> None:
|
||||
if isinstance(source, LLM):
|
||||
self.formatter.handle_llm_tool_usage_started(
|
||||
event.tool_name,
|
||||
@@ -319,7 +336,7 @@ class EventListener(BaseEventListener):
|
||||
)
|
||||
|
||||
@crewai_event_bus.on(ToolUsageFinishedEvent)
|
||||
def on_tool_usage_finished(source, event: ToolUsageFinishedEvent):
|
||||
def on_tool_usage_finished(source: Any, event: ToolUsageFinishedEvent) -> None:
|
||||
if isinstance(source, LLM):
|
||||
self.formatter.handle_llm_tool_usage_finished(
|
||||
event.tool_name,
|
||||
@@ -332,7 +349,7 @@ class EventListener(BaseEventListener):
|
||||
)
|
||||
|
||||
@crewai_event_bus.on(ToolUsageErrorEvent)
|
||||
def on_tool_usage_error(source, event: ToolUsageErrorEvent):
|
||||
def on_tool_usage_error(source: Any, event: ToolUsageErrorEvent) -> None:
|
||||
if isinstance(source, LLM):
|
||||
self.formatter.handle_llm_tool_usage_error(
|
||||
event.tool_name,
|
||||
@@ -349,7 +366,7 @@ class EventListener(BaseEventListener):
|
||||
# ----------- LLM EVENTS -----------
|
||||
|
||||
@crewai_event_bus.on(LLMCallStartedEvent)
|
||||
def on_llm_call_started(source, event: LLMCallStartedEvent):
|
||||
def on_llm_call_started(source: Any, event: LLMCallStartedEvent) -> None:
|
||||
# Capture the returned tool branch and update the current_tool_branch reference
|
||||
thinking_branch = self.formatter.handle_llm_call_started(
|
||||
self.formatter.current_agent_branch,
|
||||
@@ -360,7 +377,7 @@ class EventListener(BaseEventListener):
|
||||
self.formatter.current_tool_branch = thinking_branch
|
||||
|
||||
@crewai_event_bus.on(LLMCallCompletedEvent)
|
||||
def on_llm_call_completed(source, event: LLMCallCompletedEvent):
|
||||
def on_llm_call_completed(source: Any, event: LLMCallCompletedEvent) -> None:
|
||||
self.formatter.handle_llm_call_completed(
|
||||
self.formatter.current_tool_branch,
|
||||
self.formatter.current_agent_branch,
|
||||
@@ -368,7 +385,7 @@ class EventListener(BaseEventListener):
|
||||
)
|
||||
|
||||
@crewai_event_bus.on(LLMCallFailedEvent)
|
||||
def on_llm_call_failed(source, event: LLMCallFailedEvent):
|
||||
def on_llm_call_failed(source: Any, event: LLMCallFailedEvent) -> None:
|
||||
self.formatter.handle_llm_call_failed(
|
||||
self.formatter.current_tool_branch,
|
||||
event.error,
|
||||
@@ -376,7 +393,7 @@ class EventListener(BaseEventListener):
|
||||
)
|
||||
|
||||
@crewai_event_bus.on(LLMStreamChunkEvent)
|
||||
def on_llm_stream_chunk(source, event: LLMStreamChunkEvent):
|
||||
def on_llm_stream_chunk(source: Any, event: LLMStreamChunkEvent) -> None:
|
||||
self.text_stream.write(event.chunk)
|
||||
|
||||
self.text_stream.seek(self.next_chunk)
|
||||
@@ -389,7 +406,9 @@ class EventListener(BaseEventListener):
|
||||
# ----------- LLM GUARDRAIL EVENTS -----------
|
||||
|
||||
@crewai_event_bus.on(LLMGuardrailStartedEvent)
|
||||
def on_llm_guardrail_started(source, event: LLMGuardrailStartedEvent):
|
||||
def on_llm_guardrail_started(
|
||||
source: Any, event: LLMGuardrailStartedEvent
|
||||
) -> None:
|
||||
guardrail_str = str(event.guardrail)
|
||||
guardrail_name = (
|
||||
guardrail_str[:50] + "..." if len(guardrail_str) > 50 else guardrail_str
|
||||
@@ -398,13 +417,15 @@ class EventListener(BaseEventListener):
|
||||
self.formatter.handle_guardrail_started(guardrail_name, event.retry_count)
|
||||
|
||||
@crewai_event_bus.on(LLMGuardrailCompletedEvent)
|
||||
def on_llm_guardrail_completed(source, event: LLMGuardrailCompletedEvent):
|
||||
def on_llm_guardrail_completed(
|
||||
source: Any, event: LLMGuardrailCompletedEvent
|
||||
) -> None:
|
||||
self.formatter.handle_guardrail_completed(
|
||||
event.success, event.error, event.retry_count
|
||||
)
|
||||
|
||||
@crewai_event_bus.on(CrewTestStartedEvent)
|
||||
def on_crew_test_started(source, event: CrewTestStartedEvent):
|
||||
def on_crew_test_started(source: Any, event: CrewTestStartedEvent) -> None:
|
||||
cloned_crew = source.copy()
|
||||
self._telemetry.test_execution_span(
|
||||
cloned_crew,
|
||||
@@ -418,20 +439,20 @@ class EventListener(BaseEventListener):
|
||||
)
|
||||
|
||||
@crewai_event_bus.on(CrewTestCompletedEvent)
|
||||
def on_crew_test_completed(source, event: CrewTestCompletedEvent):
|
||||
def on_crew_test_completed(source: Any, event: CrewTestCompletedEvent) -> None:
|
||||
self.formatter.handle_crew_test_completed(
|
||||
self.formatter.current_flow_tree,
|
||||
event.crew_name or "Crew",
|
||||
)
|
||||
|
||||
@crewai_event_bus.on(CrewTestFailedEvent)
|
||||
def on_crew_test_failed(source, event: CrewTestFailedEvent):
|
||||
def on_crew_test_failed(source: Any, event: CrewTestFailedEvent) -> None:
|
||||
self.formatter.handle_crew_test_failed(event.crew_name or "Crew")
|
||||
|
||||
@crewai_event_bus.on(KnowledgeRetrievalStartedEvent)
|
||||
def on_knowledge_retrieval_started(
|
||||
source, event: KnowledgeRetrievalStartedEvent
|
||||
):
|
||||
source: Any, event: KnowledgeRetrievalStartedEvent
|
||||
) -> None:
|
||||
if self.knowledge_retrieval_in_progress:
|
||||
return
|
||||
|
||||
@@ -444,8 +465,8 @@ class EventListener(BaseEventListener):
|
||||
|
||||
@crewai_event_bus.on(KnowledgeRetrievalCompletedEvent)
|
||||
def on_knowledge_retrieval_completed(
|
||||
source, event: KnowledgeRetrievalCompletedEvent
|
||||
):
|
||||
source: Any, event: KnowledgeRetrievalCompletedEvent
|
||||
) -> None:
|
||||
if not self.knowledge_retrieval_in_progress:
|
||||
return
|
||||
|
||||
@@ -457,11 +478,15 @@ class EventListener(BaseEventListener):
|
||||
)
|
||||
|
||||
@crewai_event_bus.on(KnowledgeQueryStartedEvent)
|
||||
def on_knowledge_query_started(source, event: KnowledgeQueryStartedEvent):
|
||||
def on_knowledge_query_started(
|
||||
source: Any, event: KnowledgeQueryStartedEvent
|
||||
) -> None:
|
||||
pass
|
||||
|
||||
@crewai_event_bus.on(KnowledgeQueryFailedEvent)
|
||||
def on_knowledge_query_failed(source, event: KnowledgeQueryFailedEvent):
|
||||
def on_knowledge_query_failed(
|
||||
source: Any, event: KnowledgeQueryFailedEvent
|
||||
) -> None:
|
||||
self.formatter.handle_knowledge_query_failed(
|
||||
self.formatter.current_agent_branch,
|
||||
event.error,
|
||||
@@ -469,13 +494,15 @@ class EventListener(BaseEventListener):
|
||||
)
|
||||
|
||||
@crewai_event_bus.on(KnowledgeQueryCompletedEvent)
|
||||
def on_knowledge_query_completed(source, event: KnowledgeQueryCompletedEvent):
|
||||
def on_knowledge_query_completed(
|
||||
source: Any, event: KnowledgeQueryCompletedEvent
|
||||
) -> None:
|
||||
pass
|
||||
|
||||
@crewai_event_bus.on(KnowledgeSearchQueryFailedEvent)
|
||||
def on_knowledge_search_query_failed(
|
||||
source, event: KnowledgeSearchQueryFailedEvent
|
||||
):
|
||||
source: Any, event: KnowledgeSearchQueryFailedEvent
|
||||
) -> None:
|
||||
self.formatter.handle_knowledge_search_query_failed(
|
||||
self.formatter.current_agent_branch,
|
||||
event.error,
|
||||
@@ -485,7 +512,9 @@ class EventListener(BaseEventListener):
|
||||
# ----------- REASONING EVENTS -----------
|
||||
|
||||
@crewai_event_bus.on(AgentReasoningStartedEvent)
|
||||
def on_agent_reasoning_started(source, event: AgentReasoningStartedEvent):
|
||||
def on_agent_reasoning_started(
|
||||
source: Any, event: AgentReasoningStartedEvent
|
||||
) -> None:
|
||||
self.formatter.handle_reasoning_started(
|
||||
self.formatter.current_agent_branch,
|
||||
event.attempt,
|
||||
@@ -493,7 +522,9 @@ class EventListener(BaseEventListener):
|
||||
)
|
||||
|
||||
@crewai_event_bus.on(AgentReasoningCompletedEvent)
|
||||
def on_agent_reasoning_completed(source, event: AgentReasoningCompletedEvent):
|
||||
def on_agent_reasoning_completed(
|
||||
source: Any, event: AgentReasoningCompletedEvent
|
||||
) -> None:
|
||||
self.formatter.handle_reasoning_completed(
|
||||
event.plan,
|
||||
event.ready,
|
||||
@@ -501,7 +532,9 @@ class EventListener(BaseEventListener):
|
||||
)
|
||||
|
||||
@crewai_event_bus.on(AgentReasoningFailedEvent)
|
||||
def on_agent_reasoning_failed(source, event: AgentReasoningFailedEvent):
|
||||
def on_agent_reasoning_failed(
|
||||
source: Any, event: AgentReasoningFailedEvent
|
||||
) -> None:
|
||||
self.formatter.handle_reasoning_failed(
|
||||
event.error,
|
||||
self.formatter.current_crew_tree,
|
||||
@@ -510,7 +543,7 @@ class EventListener(BaseEventListener):
|
||||
# ----------- AGENT LOGGING EVENTS -----------
|
||||
|
||||
@crewai_event_bus.on(AgentLogsStartedEvent)
|
||||
def on_agent_logs_started(source, event: AgentLogsStartedEvent):
|
||||
def on_agent_logs_started(source: Any, event: AgentLogsStartedEvent) -> None:
|
||||
self.formatter.handle_agent_logs_started(
|
||||
event.agent_role,
|
||||
event.task_description,
|
||||
@@ -518,7 +551,9 @@ class EventListener(BaseEventListener):
|
||||
)
|
||||
|
||||
@crewai_event_bus.on(AgentLogsExecutionEvent)
|
||||
def on_agent_logs_execution(source, event: AgentLogsExecutionEvent):
|
||||
def on_agent_logs_execution(
|
||||
source: Any, event: AgentLogsExecutionEvent
|
||||
) -> None:
|
||||
self.formatter.handle_agent_logs_execution(
|
||||
event.agent_role,
|
||||
event.formatted_answer,
|
||||
|
||||
@@ -1,25 +1,30 @@
|
||||
from typing import Any
|
||||
|
||||
from crewai.events.base_event_listener import BaseEventListener
|
||||
from crewai.events.event_bus import CrewAIEventsBus
|
||||
from crewai.events.types.memory_events import (
|
||||
MemoryQueryCompletedEvent,
|
||||
MemoryQueryFailedEvent,
|
||||
MemoryRetrievalCompletedEvent,
|
||||
MemoryRetrievalStartedEvent,
|
||||
MemoryQueryFailedEvent,
|
||||
MemoryQueryCompletedEvent,
|
||||
MemorySaveStartedEvent,
|
||||
MemorySaveCompletedEvent,
|
||||
MemorySaveFailedEvent,
|
||||
MemorySaveStartedEvent,
|
||||
)
|
||||
|
||||
|
||||
class MemoryListener(BaseEventListener):
|
||||
def __init__(self, formatter):
|
||||
def __init__(self, formatter: Any) -> None:
|
||||
super().__init__()
|
||||
self.formatter = formatter
|
||||
self.memory_retrieval_in_progress = False
|
||||
self.memory_save_in_progress = False
|
||||
|
||||
def setup_listeners(self, crewai_event_bus):
|
||||
def setup_listeners(self, crewai_event_bus: CrewAIEventsBus) -> None:
|
||||
@crewai_event_bus.on(MemoryRetrievalStartedEvent)
|
||||
def on_memory_retrieval_started(source, event: MemoryRetrievalStartedEvent):
|
||||
def on_memory_retrieval_started(
|
||||
source: Any, event: MemoryRetrievalStartedEvent
|
||||
) -> None:
|
||||
if self.memory_retrieval_in_progress:
|
||||
return
|
||||
|
||||
@@ -31,7 +36,9 @@ class MemoryListener(BaseEventListener):
|
||||
)
|
||||
|
||||
@crewai_event_bus.on(MemoryRetrievalCompletedEvent)
|
||||
def on_memory_retrieval_completed(source, event: MemoryRetrievalCompletedEvent):
|
||||
def on_memory_retrieval_completed(
|
||||
source: Any, event: MemoryRetrievalCompletedEvent
|
||||
) -> None:
|
||||
if not self.memory_retrieval_in_progress:
|
||||
return
|
||||
|
||||
@@ -44,7 +51,9 @@ class MemoryListener(BaseEventListener):
|
||||
)
|
||||
|
||||
@crewai_event_bus.on(MemoryQueryCompletedEvent)
|
||||
def on_memory_query_completed(source, event: MemoryQueryCompletedEvent):
|
||||
def on_memory_query_completed(
|
||||
source: Any, event: MemoryQueryCompletedEvent
|
||||
) -> None:
|
||||
if not self.memory_retrieval_in_progress:
|
||||
return
|
||||
|
||||
@@ -56,7 +65,7 @@ class MemoryListener(BaseEventListener):
|
||||
)
|
||||
|
||||
@crewai_event_bus.on(MemoryQueryFailedEvent)
|
||||
def on_memory_query_failed(source, event: MemoryQueryFailedEvent):
|
||||
def on_memory_query_failed(source: Any, event: MemoryQueryFailedEvent) -> None:
|
||||
if not self.memory_retrieval_in_progress:
|
||||
return
|
||||
|
||||
@@ -68,7 +77,7 @@ class MemoryListener(BaseEventListener):
|
||||
)
|
||||
|
||||
@crewai_event_bus.on(MemorySaveStartedEvent)
|
||||
def on_memory_save_started(source, event: MemorySaveStartedEvent):
|
||||
def on_memory_save_started(source: Any, event: MemorySaveStartedEvent) -> None:
|
||||
if self.memory_save_in_progress:
|
||||
return
|
||||
|
||||
@@ -80,7 +89,9 @@ class MemoryListener(BaseEventListener):
|
||||
)
|
||||
|
||||
@crewai_event_bus.on(MemorySaveCompletedEvent)
|
||||
def on_memory_save_completed(source, event: MemorySaveCompletedEvent):
|
||||
def on_memory_save_completed(
|
||||
source: Any, event: MemorySaveCompletedEvent
|
||||
) -> None:
|
||||
if not self.memory_save_in_progress:
|
||||
return
|
||||
|
||||
@@ -94,7 +105,7 @@ class MemoryListener(BaseEventListener):
|
||||
)
|
||||
|
||||
@crewai_event_bus.on(MemorySaveFailedEvent)
|
||||
def on_memory_save_failed(source, event: MemorySaveFailedEvent):
|
||||
def on_memory_save_failed(source: Any, event: MemorySaveFailedEvent) -> None:
|
||||
if not self.memory_save_in_progress:
|
||||
return
|
||||
|
||||
|
||||
@@ -1,28 +1,59 @@
|
||||
import os
|
||||
import uuid
|
||||
from typing import Any, Optional
|
||||
|
||||
from typing import Dict, Any, Optional
|
||||
from typing_extensions import Self
|
||||
|
||||
from crewai.cli.authentication.token import AuthError, get_auth_token
|
||||
from crewai.cli.version import get_crewai_version
|
||||
from crewai.events.base_event_listener import BaseEventListener
|
||||
from crewai.events.event_bus import CrewAIEventsBus
|
||||
from crewai.events.listeners.tracing.trace_batch_manager import TraceBatchManager
|
||||
from crewai.events.listeners.tracing.types import TraceEvent
|
||||
from crewai.events.types.agent_events import (
|
||||
AgentExecutionCompletedEvent,
|
||||
AgentExecutionErrorEvent,
|
||||
AgentExecutionStartedEvent,
|
||||
LiteAgentExecutionStartedEvent,
|
||||
LiteAgentExecutionCompletedEvent,
|
||||
LiteAgentExecutionErrorEvent,
|
||||
AgentExecutionErrorEvent,
|
||||
)
|
||||
from crewai.events.listeners.tracing.types import TraceEvent
|
||||
from crewai.events.types.reasoning_events import (
|
||||
AgentReasoningStartedEvent,
|
||||
AgentReasoningCompletedEvent,
|
||||
AgentReasoningFailedEvent,
|
||||
LiteAgentExecutionStartedEvent,
|
||||
)
|
||||
from crewai.events.types.crew_events import (
|
||||
CrewKickoffCompletedEvent,
|
||||
CrewKickoffFailedEvent,
|
||||
CrewKickoffStartedEvent,
|
||||
)
|
||||
from crewai.events.types.flow_events import (
|
||||
FlowCreatedEvent,
|
||||
FlowFinishedEvent,
|
||||
FlowPlotEvent,
|
||||
FlowStartedEvent,
|
||||
MethodExecutionFailedEvent,
|
||||
MethodExecutionFinishedEvent,
|
||||
MethodExecutionStartedEvent,
|
||||
)
|
||||
from crewai.events.types.llm_events import (
|
||||
LLMCallCompletedEvent,
|
||||
LLMCallFailedEvent,
|
||||
LLMCallStartedEvent,
|
||||
)
|
||||
from crewai.events.types.llm_guardrail_events import (
|
||||
LLMGuardrailCompletedEvent,
|
||||
LLMGuardrailStartedEvent,
|
||||
)
|
||||
from crewai.events.types.memory_events import (
|
||||
MemoryQueryCompletedEvent,
|
||||
MemoryQueryFailedEvent,
|
||||
MemoryQueryStartedEvent,
|
||||
MemorySaveCompletedEvent,
|
||||
MemorySaveFailedEvent,
|
||||
MemorySaveStartedEvent,
|
||||
)
|
||||
from crewai.events.types.reasoning_events import (
|
||||
AgentReasoningCompletedEvent,
|
||||
AgentReasoningFailedEvent,
|
||||
AgentReasoningStartedEvent,
|
||||
)
|
||||
from crewai.events.types.task_events import (
|
||||
TaskCompletedEvent,
|
||||
TaskFailedEvent,
|
||||
@@ -33,43 +64,9 @@ from crewai.events.types.tool_usage_events import (
|
||||
ToolUsageFinishedEvent,
|
||||
ToolUsageStartedEvent,
|
||||
)
|
||||
from crewai.events.types.llm_events import (
|
||||
LLMCallCompletedEvent,
|
||||
LLMCallFailedEvent,
|
||||
LLMCallStartedEvent,
|
||||
)
|
||||
|
||||
from crewai.events.types.flow_events import (
|
||||
FlowCreatedEvent,
|
||||
FlowStartedEvent,
|
||||
FlowFinishedEvent,
|
||||
MethodExecutionStartedEvent,
|
||||
MethodExecutionFinishedEvent,
|
||||
MethodExecutionFailedEvent,
|
||||
FlowPlotEvent,
|
||||
)
|
||||
from crewai.events.types.llm_guardrail_events import (
|
||||
LLMGuardrailStartedEvent,
|
||||
LLMGuardrailCompletedEvent,
|
||||
)
|
||||
from crewai.utilities.serialization import to_serializable
|
||||
|
||||
|
||||
from .trace_batch_manager import TraceBatchManager
|
||||
|
||||
from crewai.events.types.memory_events import (
|
||||
MemoryQueryStartedEvent,
|
||||
MemoryQueryCompletedEvent,
|
||||
MemoryQueryFailedEvent,
|
||||
MemorySaveStartedEvent,
|
||||
MemorySaveCompletedEvent,
|
||||
MemorySaveFailedEvent,
|
||||
)
|
||||
|
||||
from crewai.cli.authentication.token import AuthError, get_auth_token
|
||||
from crewai.cli.version import get_crewai_version
|
||||
|
||||
|
||||
class TraceCollectionListener(BaseEventListener):
|
||||
"""
|
||||
Trace collection listener that orchestrates trace collection
|
||||
@@ -88,7 +85,7 @@ class TraceCollectionListener(BaseEventListener):
|
||||
_initialized = False
|
||||
_listeners_setup = False
|
||||
|
||||
def __new__(cls, batch_manager=None):
|
||||
def __new__(cls, batch_manager: Optional[Any] = None) -> Self:
|
||||
if cls._instance is None:
|
||||
cls._instance = super().__new__(cls)
|
||||
return cls._instance
|
||||
@@ -101,10 +98,11 @@ class TraceCollectionListener(BaseEventListener):
|
||||
return
|
||||
|
||||
super().__init__()
|
||||
self.batch_manager = batch_manager or TraceBatchManager()
|
||||
self.batch_manager = batch_manager or TraceBatchManager() # type: ignore
|
||||
self._initialized = True
|
||||
|
||||
def _check_authenticated(self) -> bool:
|
||||
@staticmethod
|
||||
def _check_authenticated() -> bool:
|
||||
"""Check if tracing should be enabled"""
|
||||
try:
|
||||
res = bool(get_auth_token())
|
||||
@@ -112,7 +110,8 @@ class TraceCollectionListener(BaseEventListener):
|
||||
except AuthError:
|
||||
return False
|
||||
|
||||
def _get_user_context(self) -> Dict[str, str]:
|
||||
@staticmethod
|
||||
def _get_user_context() -> dict[str, str]:
|
||||
"""Extract user context for tracing"""
|
||||
return {
|
||||
"user_id": os.getenv("CREWAI_USER_ID", "anonymous"),
|
||||
@@ -121,7 +120,7 @@ class TraceCollectionListener(BaseEventListener):
|
||||
"trace_id": str(uuid.uuid4()),
|
||||
}
|
||||
|
||||
def setup_listeners(self, crewai_event_bus):
|
||||
def setup_listeners(self, crewai_event_bus: CrewAIEventsBus) -> None:
|
||||
"""Setup event listeners - delegates to specific handlers"""
|
||||
|
||||
if self._listeners_setup:
|
||||
@@ -133,169 +132,169 @@ class TraceCollectionListener(BaseEventListener):
|
||||
|
||||
self._listeners_setup = True
|
||||
|
||||
def _register_flow_event_handlers(self, event_bus):
|
||||
def _register_flow_event_handlers(self, event_bus: CrewAIEventsBus) -> None:
|
||||
"""Register handlers for flow events"""
|
||||
|
||||
@event_bus.on(FlowCreatedEvent)
|
||||
def on_flow_created(source, event):
|
||||
def on_flow_created(source: Any, event: Any) -> None:
|
||||
pass
|
||||
|
||||
@event_bus.on(FlowStartedEvent)
|
||||
def on_flow_started(source, event):
|
||||
def on_flow_started(source: Any, event: Any) -> None:
|
||||
if not self.batch_manager.is_batch_initialized():
|
||||
self._initialize_flow_batch(source, event)
|
||||
self._handle_trace_event("flow_started", source, event)
|
||||
|
||||
@event_bus.on(MethodExecutionStartedEvent)
|
||||
def on_method_started(source, event):
|
||||
def on_method_started(source: Any, event: Any) -> None:
|
||||
self._handle_trace_event("method_execution_started", source, event)
|
||||
|
||||
@event_bus.on(MethodExecutionFinishedEvent)
|
||||
def on_method_finished(source, event):
|
||||
def on_method_finished(source: Any, event: Any) -> None:
|
||||
self._handle_trace_event("method_execution_finished", source, event)
|
||||
|
||||
@event_bus.on(MethodExecutionFailedEvent)
|
||||
def on_method_failed(source, event):
|
||||
def on_method_failed(source: Any, event: Any) -> None:
|
||||
self._handle_trace_event("method_execution_failed", source, event)
|
||||
|
||||
@event_bus.on(FlowFinishedEvent)
|
||||
def on_flow_finished(source, event):
|
||||
def on_flow_finished(source: Any, event: Any) -> None:
|
||||
self._handle_trace_event("flow_finished", source, event)
|
||||
if self.batch_manager.batch_owner_type == "flow":
|
||||
self.batch_manager.finalize_batch()
|
||||
|
||||
@event_bus.on(FlowPlotEvent)
|
||||
def on_flow_plot(source, event):
|
||||
def on_flow_plot(source: Any, event: Any) -> None:
|
||||
self._handle_action_event("flow_plot", source, event)
|
||||
|
||||
def _register_context_event_handlers(self, event_bus):
|
||||
def _register_context_event_handlers(self, event_bus: CrewAIEventsBus) -> None:
|
||||
"""Register handlers for context events (start/end)"""
|
||||
|
||||
@event_bus.on(CrewKickoffStartedEvent)
|
||||
def on_crew_started(source, event):
|
||||
def on_crew_started(source: Any, event: Any) -> None:
|
||||
if not self.batch_manager.is_batch_initialized():
|
||||
self._initialize_crew_batch(source, event)
|
||||
self._handle_trace_event("crew_kickoff_started", source, event)
|
||||
|
||||
@event_bus.on(CrewKickoffCompletedEvent)
|
||||
def on_crew_completed(source, event):
|
||||
def on_crew_completed(source: Any, event: Any) -> None:
|
||||
self._handle_trace_event("crew_kickoff_completed", source, event)
|
||||
if self.batch_manager.batch_owner_type == "crew":
|
||||
self.batch_manager.finalize_batch()
|
||||
|
||||
@event_bus.on(CrewKickoffFailedEvent)
|
||||
def on_crew_failed(source, event):
|
||||
def on_crew_failed(source: Any, event: Any) -> None:
|
||||
self._handle_trace_event("crew_kickoff_failed", source, event)
|
||||
self.batch_manager.finalize_batch()
|
||||
|
||||
@event_bus.on(TaskStartedEvent)
|
||||
def on_task_started(source, event):
|
||||
def on_task_started(source: Any, event: Any) -> None:
|
||||
self._handle_trace_event("task_started", source, event)
|
||||
|
||||
@event_bus.on(TaskCompletedEvent)
|
||||
def on_task_completed(source, event):
|
||||
def on_task_completed(source: Any, event: Any) -> None:
|
||||
self._handle_trace_event("task_completed", source, event)
|
||||
|
||||
@event_bus.on(TaskFailedEvent)
|
||||
def on_task_failed(source, event):
|
||||
def on_task_failed(source: Any, event: Any) -> None:
|
||||
self._handle_trace_event("task_failed", source, event)
|
||||
|
||||
@event_bus.on(AgentExecutionStartedEvent)
|
||||
def on_agent_started(source, event):
|
||||
def on_agent_started(source: Any, event: Any) -> None:
|
||||
self._handle_trace_event("agent_execution_started", source, event)
|
||||
|
||||
@event_bus.on(AgentExecutionCompletedEvent)
|
||||
def on_agent_completed(source, event):
|
||||
def on_agent_completed(source: Any, event: Any) -> None:
|
||||
self._handle_trace_event("agent_execution_completed", source, event)
|
||||
|
||||
@event_bus.on(LiteAgentExecutionStartedEvent)
|
||||
def on_lite_agent_started(source, event):
|
||||
def on_lite_agent_started(source: Any, event: Any) -> None:
|
||||
self._handle_trace_event("lite_agent_execution_started", source, event)
|
||||
|
||||
@event_bus.on(LiteAgentExecutionCompletedEvent)
|
||||
def on_lite_agent_completed(source, event):
|
||||
def on_lite_agent_completed(source: Any, event: Any) -> None:
|
||||
self._handle_trace_event("lite_agent_execution_completed", source, event)
|
||||
|
||||
@event_bus.on(LiteAgentExecutionErrorEvent)
|
||||
def on_lite_agent_error(source, event):
|
||||
def on_lite_agent_error(source: Any, event: Any) -> None:
|
||||
self._handle_trace_event("lite_agent_execution_error", source, event)
|
||||
|
||||
@event_bus.on(AgentExecutionErrorEvent)
|
||||
def on_agent_error(source, event):
|
||||
def on_agent_error(source: Any, event: Any) -> None:
|
||||
self._handle_trace_event("agent_execution_error", source, event)
|
||||
|
||||
@event_bus.on(LLMGuardrailStartedEvent)
|
||||
def on_guardrail_started(source, event):
|
||||
def on_guardrail_started(source: Any, event: Any) -> None:
|
||||
self._handle_trace_event("llm_guardrail_started", source, event)
|
||||
|
||||
@event_bus.on(LLMGuardrailCompletedEvent)
|
||||
def on_guardrail_completed(source, event):
|
||||
def on_guardrail_completed(source: Any, event: Any) -> None:
|
||||
self._handle_trace_event("llm_guardrail_completed", source, event)
|
||||
|
||||
def _register_action_event_handlers(self, event_bus):
|
||||
def _register_action_event_handlers(self, event_bus: CrewAIEventsBus) -> None:
|
||||
"""Register handlers for action events (LLM calls, tool usage)"""
|
||||
|
||||
@event_bus.on(LLMCallStartedEvent)
|
||||
def on_llm_call_started(source, event):
|
||||
def on_llm_call_started(source: Any, event: Any) -> None:
|
||||
self._handle_action_event("llm_call_started", source, event)
|
||||
|
||||
@event_bus.on(LLMCallCompletedEvent)
|
||||
def on_llm_call_completed(source, event):
|
||||
def on_llm_call_completed(source: Any, event: Any) -> None:
|
||||
self._handle_action_event("llm_call_completed", source, event)
|
||||
|
||||
@event_bus.on(LLMCallFailedEvent)
|
||||
def on_llm_call_failed(source, event):
|
||||
def on_llm_call_failed(source: Any, event: Any) -> None:
|
||||
self._handle_action_event("llm_call_failed", source, event)
|
||||
|
||||
@event_bus.on(ToolUsageStartedEvent)
|
||||
def on_tool_started(source, event):
|
||||
def on_tool_started(source: Any, event: Any) -> None:
|
||||
self._handle_action_event("tool_usage_started", source, event)
|
||||
|
||||
@event_bus.on(ToolUsageFinishedEvent)
|
||||
def on_tool_finished(source, event):
|
||||
def on_tool_finished(source: Any, event: Any) -> None:
|
||||
self._handle_action_event("tool_usage_finished", source, event)
|
||||
|
||||
@event_bus.on(ToolUsageErrorEvent)
|
||||
def on_tool_error(source, event):
|
||||
def on_tool_error(source: Any, event: Any) -> None:
|
||||
self._handle_action_event("tool_usage_error", source, event)
|
||||
|
||||
@event_bus.on(MemoryQueryStartedEvent)
|
||||
def on_memory_query_started(source, event):
|
||||
def on_memory_query_started(source: Any, event: Any) -> None:
|
||||
self._handle_action_event("memory_query_started", source, event)
|
||||
|
||||
@event_bus.on(MemoryQueryCompletedEvent)
|
||||
def on_memory_query_completed(source, event):
|
||||
def on_memory_query_completed(source: Any, event: Any) -> None:
|
||||
self._handle_action_event("memory_query_completed", source, event)
|
||||
|
||||
@event_bus.on(MemoryQueryFailedEvent)
|
||||
def on_memory_query_failed(source, event):
|
||||
def on_memory_query_failed(source: Any, event: Any) -> None:
|
||||
self._handle_action_event("memory_query_failed", source, event)
|
||||
|
||||
@event_bus.on(MemorySaveStartedEvent)
|
||||
def on_memory_save_started(source, event):
|
||||
def on_memory_save_started(source: Any, event: Any) -> None:
|
||||
self._handle_action_event("memory_save_started", source, event)
|
||||
|
||||
@event_bus.on(MemorySaveCompletedEvent)
|
||||
def on_memory_save_completed(source, event):
|
||||
def on_memory_save_completed(source: Any, event: Any) -> None:
|
||||
self._handle_action_event("memory_save_completed", source, event)
|
||||
|
||||
@event_bus.on(MemorySaveFailedEvent)
|
||||
def on_memory_save_failed(source, event):
|
||||
def on_memory_save_failed(source: Any, event: Any) -> None:
|
||||
self._handle_action_event("memory_save_failed", source, event)
|
||||
|
||||
@event_bus.on(AgentReasoningStartedEvent)
|
||||
def on_agent_reasoning_started(source, event):
|
||||
def on_agent_reasoning_started(source: Any, event: Any) -> None:
|
||||
self._handle_action_event("agent_reasoning_started", source, event)
|
||||
|
||||
@event_bus.on(AgentReasoningCompletedEvent)
|
||||
def on_agent_reasoning_completed(source, event):
|
||||
def on_agent_reasoning_completed(source: Any, event: Any) -> None:
|
||||
self._handle_action_event("agent_reasoning_completed", source, event)
|
||||
|
||||
@event_bus.on(AgentReasoningFailedEvent)
|
||||
def on_agent_reasoning_failed(source, event):
|
||||
def on_agent_reasoning_failed(source: Any, event: Any) -> None:
|
||||
self._handle_action_event("agent_reasoning_failed", source, event)
|
||||
|
||||
def _initialize_crew_batch(self, source: Any, event: Any):
|
||||
def _initialize_crew_batch(self, source: Any, event: Any) -> None:
|
||||
"""Initialize trace batch"""
|
||||
user_context = self._get_user_context()
|
||||
execution_metadata = {
|
||||
@@ -309,7 +308,7 @@ class TraceCollectionListener(BaseEventListener):
|
||||
|
||||
self._initialize_batch(user_context, execution_metadata)
|
||||
|
||||
def _initialize_flow_batch(self, source: Any, event: Any):
|
||||
def _initialize_flow_batch(self, source: Any, event: Any) -> None:
|
||||
"""Initialize trace batch for Flow execution"""
|
||||
user_context = self._get_user_context()
|
||||
execution_metadata = {
|
||||
@@ -325,26 +324,24 @@ class TraceCollectionListener(BaseEventListener):
|
||||
self._initialize_batch(user_context, execution_metadata)
|
||||
|
||||
def _initialize_batch(
|
||||
self, user_context: Dict[str, str], execution_metadata: Dict[str, Any]
|
||||
):
|
||||
self, user_context: dict[str, str], execution_metadata: dict[str, Any]
|
||||
) -> None:
|
||||
"""Initialize trace batch if ephemeral"""
|
||||
if not self._check_authenticated():
|
||||
self.batch_manager.initialize_batch(
|
||||
user_context, execution_metadata, use_ephemeral=True
|
||||
)
|
||||
else:
|
||||
self.batch_manager.initialize_batch(
|
||||
user_context, execution_metadata, use_ephemeral=False
|
||||
)
|
||||
self.batch_manager.initialize_batch(user_context, execution_metadata)
|
||||
|
||||
def _handle_trace_event(self, event_type: str, source: Any, event: Any):
|
||||
def _handle_trace_event(self, event_type: str, source: Any, event: Any) -> None:
|
||||
"""Generic handler for context end events"""
|
||||
|
||||
trace_event = self._create_trace_event(event_type, source, event)
|
||||
|
||||
self.batch_manager.add_event(trace_event)
|
||||
|
||||
def _handle_action_event(self, event_type: str, source: Any, event: Any):
|
||||
def _handle_action_event(self, event_type: str, source: Any, event: Any) -> None:
|
||||
"""Generic handler for action events (LLM calls, tool usage)"""
|
||||
|
||||
if not self.batch_manager.is_batch_initialized():
|
||||
@@ -371,7 +368,7 @@ class TraceCollectionListener(BaseEventListener):
|
||||
|
||||
def _build_event_data(
|
||||
self, event_type: str, event: Any, source: Any
|
||||
) -> Dict[str, Any]:
|
||||
) -> dict[str, Any]:
|
||||
"""Build event data"""
|
||||
if event_type not in self.complex_events:
|
||||
return self._safe_serialize_to_dict(event)
|
||||
@@ -426,11 +423,19 @@ class TraceCollectionListener(BaseEventListener):
|
||||
"source": source,
|
||||
}
|
||||
|
||||
# TODO: move to utils
|
||||
@staticmethod
|
||||
def _safe_serialize_to_dict(
|
||||
self, obj, exclude: set[str] | None = None
|
||||
) -> Dict[str, Any]:
|
||||
"""Safely serialize an object to a dictionary for event data."""
|
||||
obj: Any, exclude: set[str] | None = None
|
||||
) -> dict[str, Any]:
|
||||
"""Safely serialize an object to a dictionary for event data.
|
||||
|
||||
Args:
|
||||
obj: The object to serialize.
|
||||
exclude: Optional set of attribute names to exclude from serialization.
|
||||
|
||||
Notes:
|
||||
- TODO: refactor to utilities function.
|
||||
"""
|
||||
try:
|
||||
serialized = to_serializable(obj, exclude)
|
||||
if isinstance(serialized, dict):
|
||||
@@ -440,9 +445,20 @@ class TraceCollectionListener(BaseEventListener):
|
||||
except Exception as e:
|
||||
return {"serialization_error": str(e), "object_type": type(obj).__name__}
|
||||
|
||||
# TODO: move to utils
|
||||
def _truncate_messages(self, messages, max_content_length=500, max_messages=5):
|
||||
"""Truncate message content and limit number of messages"""
|
||||
@staticmethod
|
||||
def _truncate_messages(
|
||||
messages: Any, max_content_length: int = 500, max_messages: int = 5
|
||||
) -> Any:
|
||||
"""Truncate message content and limit number of messages
|
||||
|
||||
Args:
|
||||
messages: List of message dicts with 'content' keys.
|
||||
max_content_length: Max length of each message content.
|
||||
max_messages: Max number of messages to retain.
|
||||
|
||||
Notes:
|
||||
- TODO: refactor to utilities function.
|
||||
"""
|
||||
if not messages or not isinstance(messages, list):
|
||||
return messages
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from typing import Any, Optional
|
||||
|
||||
from crewai.tasks.task_output import TaskOutput
|
||||
from crewai.events.base_events import BaseEvent
|
||||
from crewai.tasks.task_output import TaskOutput
|
||||
|
||||
|
||||
class TaskStartedEvent(BaseEvent):
|
||||
@@ -11,14 +11,15 @@ class TaskStartedEvent(BaseEvent):
|
||||
context: Optional[str]
|
||||
task: Optional[Any] = None
|
||||
|
||||
def __init__(self, **data):
|
||||
def __init__(self, **data: Any) -> None:
|
||||
super().__init__(**data)
|
||||
# Set fingerprint data from the task
|
||||
if hasattr(self.task, "fingerprint") and self.task.fingerprint:
|
||||
if self.task and hasattr(self.task, "fingerprint") and self.task.fingerprint:
|
||||
self.source_fingerprint = self.task.fingerprint.uuid_str
|
||||
self.source_type = "task"
|
||||
if (
|
||||
hasattr(self.task.fingerprint, "metadata")
|
||||
self.task
|
||||
and hasattr(self.task.fingerprint, "metadata")
|
||||
and self.task.fingerprint.metadata
|
||||
):
|
||||
self.fingerprint_metadata = self.task.fingerprint.metadata
|
||||
@@ -31,14 +32,15 @@ class TaskCompletedEvent(BaseEvent):
|
||||
type: str = "task_completed"
|
||||
task: Optional[Any] = None
|
||||
|
||||
def __init__(self, **data):
|
||||
def __init__(self, **data: Any) -> None:
|
||||
super().__init__(**data)
|
||||
# Set fingerprint data from the task
|
||||
if hasattr(self.task, "fingerprint") and self.task.fingerprint:
|
||||
if self.task and hasattr(self.task, "fingerprint") and self.task.fingerprint:
|
||||
self.source_fingerprint = self.task.fingerprint.uuid_str
|
||||
self.source_type = "task"
|
||||
if (
|
||||
hasattr(self.task.fingerprint, "metadata")
|
||||
self.task
|
||||
and hasattr(self.task.fingerprint, "metadata")
|
||||
and self.task.fingerprint.metadata
|
||||
):
|
||||
self.fingerprint_metadata = self.task.fingerprint.metadata
|
||||
@@ -51,14 +53,15 @@ class TaskFailedEvent(BaseEvent):
|
||||
type: str = "task_failed"
|
||||
task: Optional[Any] = None
|
||||
|
||||
def __init__(self, **data):
|
||||
def __init__(self, **data: Any) -> None:
|
||||
super().__init__(**data)
|
||||
# Set fingerprint data from the task
|
||||
if hasattr(self.task, "fingerprint") and self.task.fingerprint:
|
||||
if self.task and hasattr(self.task, "fingerprint") and self.task.fingerprint:
|
||||
self.source_fingerprint = self.task.fingerprint.uuid_str
|
||||
self.source_type = "task"
|
||||
if (
|
||||
hasattr(self.task.fingerprint, "metadata")
|
||||
self.task
|
||||
and hasattr(self.task.fingerprint, "metadata")
|
||||
and self.task.fingerprint.metadata
|
||||
):
|
||||
self.fingerprint_metadata = self.task.fingerprint.metadata
|
||||
@@ -71,14 +74,15 @@ class TaskEvaluationEvent(BaseEvent):
|
||||
evaluation_type: str
|
||||
task: Optional[Any] = None
|
||||
|
||||
def __init__(self, **data):
|
||||
def __init__(self, **data: Any) -> None:
|
||||
super().__init__(**data)
|
||||
# Set fingerprint data from the task
|
||||
if hasattr(self.task, "fingerprint") and self.task.fingerprint:
|
||||
if self.task and hasattr(self.task, "fingerprint") and self.task.fingerprint:
|
||||
self.source_fingerprint = self.task.fingerprint.uuid_str
|
||||
self.source_type = "task"
|
||||
if (
|
||||
hasattr(self.task.fingerprint, "metadata")
|
||||
self.task
|
||||
and hasattr(self.task.fingerprint, "metadata")
|
||||
and self.task.fingerprint.metadata
|
||||
):
|
||||
self.fingerprint_metadata = self.task.fingerprint.metadata
|
||||
|
||||
@@ -1,14 +1,15 @@
|
||||
import abc
|
||||
import enum
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any, Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from crewai.agent import Agent
|
||||
from crewai.llms.base_llm import BaseLLM
|
||||
from crewai.task import Task
|
||||
from crewai.llm import BaseLLM
|
||||
from crewai.utilities.llm_utils import create_llm
|
||||
from crewai.utilities.llm_utils import create_default_llm, create_llm
|
||||
|
||||
|
||||
class MetricCategory(enum.Enum):
|
||||
GOAL_ALIGNMENT = "goal_alignment"
|
||||
@@ -18,8 +19,8 @@ class MetricCategory(enum.Enum):
|
||||
PARAMETER_EXTRACTION = "parameter_extraction"
|
||||
TOOL_INVOCATION = "tool_invocation"
|
||||
|
||||
def title(self):
|
||||
return self.value.replace('_', ' ').title()
|
||||
def title(self) -> str:
|
||||
return self.value.replace("_", " ").title()
|
||||
|
||||
|
||||
class EvaluationScore(BaseModel):
|
||||
@@ -27,15 +28,13 @@ class EvaluationScore(BaseModel):
|
||||
default=5.0,
|
||||
description="Numeric score from 0-10 where 0 is worst and 10 is best, None if not applicable",
|
||||
ge=0.0,
|
||||
le=10.0
|
||||
le=10.0,
|
||||
)
|
||||
feedback: str = Field(
|
||||
default="",
|
||||
description="Detailed feedback explaining the evaluation score"
|
||||
default="", description="Detailed feedback explaining the evaluation score"
|
||||
)
|
||||
raw_response: str | None = Field(
|
||||
default=None,
|
||||
description="Raw response from the evaluator (e.g., LLM)"
|
||||
default=None, description="Raw response from the evaluator (e.g., LLM)"
|
||||
)
|
||||
|
||||
def __str__(self) -> str:
|
||||
@@ -46,7 +45,9 @@ class EvaluationScore(BaseModel):
|
||||
|
||||
class BaseEvaluator(abc.ABC):
|
||||
def __init__(self, llm: BaseLLM | None = None):
|
||||
self.llm: BaseLLM | None = create_llm(llm)
|
||||
self.llm: BaseLLM | None = (
|
||||
create_llm(llm) if llm is not None else create_default_llm()
|
||||
)
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
@@ -57,7 +58,7 @@ class BaseEvaluator(abc.ABC):
|
||||
def evaluate(
|
||||
self,
|
||||
agent: Agent,
|
||||
execution_trace: Dict[str, Any],
|
||||
execution_trace: dict[str, Any],
|
||||
final_output: Any,
|
||||
task: Task | None = None,
|
||||
) -> EvaluationScore:
|
||||
@@ -67,9 +68,8 @@ class BaseEvaluator(abc.ABC):
|
||||
class AgentEvaluationResult(BaseModel):
|
||||
agent_id: str = Field(description="ID of the evaluated agent")
|
||||
task_id: str = Field(description="ID of the task that was executed")
|
||||
metrics: Dict[MetricCategory, EvaluationScore] = Field(
|
||||
default_factory=dict,
|
||||
description="Evaluation scores for each metric category"
|
||||
metrics: dict[MetricCategory, EvaluationScore] = Field(
|
||||
default_factory=dict, description="Evaluation scores for each metric category"
|
||||
)
|
||||
|
||||
|
||||
@@ -81,33 +81,23 @@ class AggregationStrategy(Enum):
|
||||
|
||||
|
||||
class AgentAggregatedEvaluationResult(BaseModel):
|
||||
agent_id: str = Field(
|
||||
default="",
|
||||
description="ID of the agent"
|
||||
)
|
||||
agent_role: str = Field(
|
||||
default="",
|
||||
description="Role of the agent"
|
||||
)
|
||||
agent_id: str = Field(default="", description="ID of the agent")
|
||||
agent_role: str = Field(default="", description="Role of the agent")
|
||||
task_count: int = Field(
|
||||
default=0,
|
||||
description="Number of tasks included in this aggregation"
|
||||
default=0, description="Number of tasks included in this aggregation"
|
||||
)
|
||||
aggregation_strategy: AggregationStrategy = Field(
|
||||
default=AggregationStrategy.SIMPLE_AVERAGE,
|
||||
description="Strategy used for aggregation"
|
||||
description="Strategy used for aggregation",
|
||||
)
|
||||
metrics: Dict[MetricCategory, EvaluationScore] = Field(
|
||||
default_factory=dict,
|
||||
description="Aggregated metrics across all tasks"
|
||||
metrics: dict[MetricCategory, EvaluationScore] = Field(
|
||||
default_factory=dict, description="Aggregated metrics across all tasks"
|
||||
)
|
||||
task_results: List[str] = Field(
|
||||
default_factory=list,
|
||||
description="IDs of tasks included in this aggregation"
|
||||
task_results: list[str] = Field(
|
||||
default_factory=list, description="IDs of tasks included in this aggregation"
|
||||
)
|
||||
overall_score: Optional[float] = Field(
|
||||
default=None,
|
||||
description="Overall score for this agent"
|
||||
default=None, description="Overall score for this agent"
|
||||
)
|
||||
|
||||
def __str__(self) -> str:
|
||||
@@ -119,7 +109,7 @@ class AgentAggregatedEvaluationResult(BaseModel):
|
||||
result += f"\n\n- {category.value.upper()}: {score.score}/10\n"
|
||||
|
||||
if score.feedback:
|
||||
detailed_feedback = "\n ".join(score.feedback.split('\n'))
|
||||
detailed_feedback = "\n ".join(score.feedback.split("\n"))
|
||||
result += f" {detailed_feedback}\n"
|
||||
|
||||
return result
|
||||
return result
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import os
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any, Optional
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
@@ -18,20 +18,20 @@ class Knowledge(BaseModel):
|
||||
embedder: Optional[Dict[str, Any]] = None
|
||||
"""
|
||||
|
||||
sources: List[BaseKnowledgeSource] = Field(default_factory=list)
|
||||
sources: list[BaseKnowledgeSource] = Field(default_factory=list)
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
storage: Optional[KnowledgeStorage] = Field(default=None)
|
||||
embedder: Optional[Dict[str, Any]] = None
|
||||
embedder: Optional[dict[str, Any]] = None
|
||||
collection_name: Optional[str] = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
collection_name: str,
|
||||
sources: List[BaseKnowledgeSource],
|
||||
embedder: Optional[Dict[str, Any]] = None,
|
||||
sources: list[BaseKnowledgeSource],
|
||||
embedder: Optional[dict[str, Any]] = None,
|
||||
storage: Optional[KnowledgeStorage] = None,
|
||||
**data,
|
||||
):
|
||||
**data: Any,
|
||||
) -> None:
|
||||
super().__init__(**data)
|
||||
if storage:
|
||||
self.storage = storage
|
||||
@@ -43,8 +43,8 @@ class Knowledge(BaseModel):
|
||||
self.storage.initialize_knowledge_storage()
|
||||
|
||||
def query(
|
||||
self, query: List[str], results_limit: int = 3, score_threshold: float = 0.35
|
||||
) -> List[Dict[str, Any]]:
|
||||
self, query: list[str], results_limit: int = 3, score_threshold: float = 0.35
|
||||
) -> list[dict[str, Any]]:
|
||||
"""
|
||||
Query across all knowledge sources to find the most relevant information.
|
||||
Returns the top_k most relevant chunks.
|
||||
@@ -62,7 +62,7 @@ class Knowledge(BaseModel):
|
||||
)
|
||||
return results
|
||||
|
||||
def add_sources(self):
|
||||
def add_sources(self) -> None:
|
||||
try:
|
||||
for source in self.sources:
|
||||
source.storage = self.storage
|
||||
|
||||
@@ -2,23 +2,56 @@ import hashlib
|
||||
import logging
|
||||
import os
|
||||
import shutil
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
import warnings
|
||||
from collections.abc import Mapping
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
import chromadb
|
||||
import chromadb.errors
|
||||
from chromadb import EmbeddingFunction
|
||||
from chromadb.api import ClientAPI
|
||||
from chromadb.api.types import OneOrMany
|
||||
from chromadb.config import Settings
|
||||
import warnings
|
||||
|
||||
from crewai.knowledge.storage.base_knowledge_storage import BaseKnowledgeStorage
|
||||
from crewai.rag.embeddings.configurator import EmbeddingConfigurator
|
||||
from crewai.utilities.chromadb import sanitize_collection_name
|
||||
from crewai.utilities.chromadb import create_persistent_client, sanitize_collection_name
|
||||
from crewai.utilities.constants import KNOWLEDGE_DIRECTORY
|
||||
from crewai.utilities.logger import Logger
|
||||
from crewai.utilities.paths import db_storage_path
|
||||
from crewai.utilities.chromadb import create_persistent_client
|
||||
from crewai.utilities.logger_utils import suppress_logging
|
||||
from crewai.utilities.paths import db_storage_path
|
||||
|
||||
|
||||
def _extract_chromadb_response_item(
|
||||
response_data: Any,
|
||||
index: int,
|
||||
expected_type: type[Any] | tuple[type[Any], ...],
|
||||
) -> Any | None:
|
||||
"""Extract an item from ChromaDB response data at the given index.
|
||||
|
||||
Args:
|
||||
response_data: The response data from ChromaDB query (e.g., documents, metadatas).
|
||||
index: The index of the item to extract.
|
||||
expected_type: The expected type(s) of the item.
|
||||
|
||||
Returns:
|
||||
The extracted item if it exists and matches the expected type, otherwise None.
|
||||
"""
|
||||
if response_data is None or not response_data:
|
||||
return None
|
||||
|
||||
# ChromaDB sometimes returns nested lists, handle both cases
|
||||
data_list = (
|
||||
response_data[0]
|
||||
if response_data and isinstance(response_data[0], list)
|
||||
else response_data
|
||||
)
|
||||
|
||||
if index < len(data_list):
|
||||
item = data_list[index]
|
||||
if isinstance(item, expected_type):
|
||||
return item
|
||||
return None
|
||||
|
||||
|
||||
class KnowledgeStorage(BaseKnowledgeStorage):
|
||||
@@ -30,10 +63,11 @@ class KnowledgeStorage(BaseKnowledgeStorage):
|
||||
collection: Optional[chromadb.Collection] = None
|
||||
collection_name: Optional[str] = "knowledge"
|
||||
app: Optional[ClientAPI] = None
|
||||
embedder: Optional[EmbeddingFunction[Any]] = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
embedder: Optional[Dict[str, Any]] = None,
|
||||
embedder: Optional[dict[str, Any]] = None,
|
||||
collection_name: Optional[str] = None,
|
||||
):
|
||||
self.collection_name = collection_name
|
||||
@@ -41,11 +75,11 @@ class KnowledgeStorage(BaseKnowledgeStorage):
|
||||
|
||||
def search(
|
||||
self,
|
||||
query: List[str],
|
||||
query: list[str],
|
||||
limit: int = 3,
|
||||
filter: Optional[dict] = None,
|
||||
filter: Optional[dict[str, Any]] = None,
|
||||
score_threshold: float = 0.35,
|
||||
) -> List[Dict[str, Any]]:
|
||||
) -> list[dict[str, Any]]:
|
||||
with suppress_logging(
|
||||
"chromadb.segment.impl.vector.local_persistent_hnsw", logging.ERROR
|
||||
):
|
||||
@@ -56,20 +90,51 @@ class KnowledgeStorage(BaseKnowledgeStorage):
|
||||
where=filter,
|
||||
)
|
||||
results = []
|
||||
for i in range(len(fetched["ids"][0])): # type: ignore
|
||||
result = {
|
||||
"id": fetched["ids"][0][i], # type: ignore
|
||||
"metadata": fetched["metadatas"][0][i], # type: ignore
|
||||
"context": fetched["documents"][0][i], # type: ignore
|
||||
"score": fetched["distances"][0][i], # type: ignore
|
||||
}
|
||||
if result["score"] >= score_threshold:
|
||||
results.append(result)
|
||||
if (
|
||||
fetched
|
||||
and "ids" in fetched
|
||||
and fetched["ids"]
|
||||
and len(fetched["ids"]) > 0
|
||||
):
|
||||
ids_list = (
|
||||
fetched["ids"][0]
|
||||
if isinstance(fetched["ids"][0], list)
|
||||
else fetched["ids"]
|
||||
)
|
||||
for i in range(len(ids_list)):
|
||||
# Handle metadatas
|
||||
meta_item = _extract_chromadb_response_item(
|
||||
fetched.get("metadatas"), i, dict
|
||||
)
|
||||
metadata: dict[str, Any] = meta_item if meta_item else {}
|
||||
|
||||
# Handle documents
|
||||
doc_item = _extract_chromadb_response_item(
|
||||
fetched.get("documents"), i, str
|
||||
)
|
||||
context = doc_item if doc_item else ""
|
||||
|
||||
# Handle distances
|
||||
dist_item = _extract_chromadb_response_item(
|
||||
fetched.get("distances"), i, (int, float)
|
||||
)
|
||||
score = dist_item if dist_item is not None else 1.0
|
||||
|
||||
result = {
|
||||
"id": ids_list[i],
|
||||
"metadata": metadata,
|
||||
"context": context,
|
||||
"score": score,
|
||||
}
|
||||
|
||||
# Check score threshold - distances are smaller when more similar
|
||||
if isinstance(score, (int, float)) and score <= score_threshold:
|
||||
results.append(result)
|
||||
return results
|
||||
else:
|
||||
raise Exception("Collection not initialized")
|
||||
|
||||
def initialize_knowledge_storage(self):
|
||||
def initialize_knowledge_storage(self) -> None:
|
||||
# Suppress deprecation warnings from chromadb, which are not relevant to us
|
||||
# TODO: Remove this once we upgrade chromadb to at least 1.0.8.
|
||||
warnings.filterwarnings(
|
||||
@@ -99,7 +164,7 @@ class KnowledgeStorage(BaseKnowledgeStorage):
|
||||
except Exception:
|
||||
raise Exception("Failed to create or get collection")
|
||||
|
||||
def reset(self):
|
||||
def reset(self) -> None:
|
||||
base_path = os.path.join(db_storage_path(), KNOWLEDGE_DIRECTORY)
|
||||
if not self.app:
|
||||
self.app = create_persistent_client(
|
||||
@@ -113,9 +178,9 @@ class KnowledgeStorage(BaseKnowledgeStorage):
|
||||
|
||||
def save(
|
||||
self,
|
||||
documents: List[str],
|
||||
metadata: Optional[Union[Dict[str, Any], List[Dict[str, Any]]]] = None,
|
||||
):
|
||||
documents: list[str],
|
||||
metadata: Optional[dict[str, Any] | list[dict[str, Any]]] = None,
|
||||
) -> None:
|
||||
if not self.collection:
|
||||
raise Exception("Collection not initialized")
|
||||
|
||||
@@ -147,7 +212,7 @@ class KnowledgeStorage(BaseKnowledgeStorage):
|
||||
|
||||
# If we have no metadata at all, set it to None
|
||||
final_metadata: Optional[OneOrMany[chromadb.Metadata]] = (
|
||||
None if all(m is None for m in filtered_metadata) else filtered_metadata
|
||||
None if all(m is None for m in filtered_metadata) else filtered_metadata # type: ignore[assignment]
|
||||
)
|
||||
|
||||
self.collection.upsert(
|
||||
@@ -170,7 +235,7 @@ class KnowledgeStorage(BaseKnowledgeStorage):
|
||||
Logger(verbose=True).log("error", f"Failed to upsert documents: {e}", "red")
|
||||
raise
|
||||
|
||||
def _create_default_embedding_function(self):
|
||||
def _create_default_embedding_function(self) -> Any:
|
||||
from chromadb.utils.embedding_functions.openai_embedding_function import (
|
||||
OpenAIEmbeddingFunction,
|
||||
)
|
||||
@@ -179,15 +244,18 @@ class KnowledgeStorage(BaseKnowledgeStorage):
|
||||
api_key=os.getenv("OPENAI_API_KEY"), model_name="text-embedding-3-small"
|
||||
)
|
||||
|
||||
def _set_embedder_config(self, embedder: Optional[Dict[str, Any]] = None) -> None:
|
||||
def _set_embedder_config(self, embedder: Optional[dict[str, Any]] = None) -> None:
|
||||
"""Set the embedding configuration for the knowledge storage.
|
||||
|
||||
Args:
|
||||
embedder_config (Optional[Dict[str, Any]]): Configuration dictionary for the embedder.
|
||||
embedder (Optional[Dict[str, Any]]): Configuration dictionary for the embedder.
|
||||
If None or empty, defaults to the default embedding function.
|
||||
|
||||
Notes:
|
||||
- TODO: Improve typing for embedder configuration, remove type: ignore
|
||||
"""
|
||||
self.embedder = (
|
||||
EmbeddingConfigurator().configure_embedder(embedder)
|
||||
EmbeddingConfigurator().configure_embedder(embedder) # type: ignore
|
||||
if embedder
|
||||
else self._create_default_embedding_function()
|
||||
)
|
||||
|
||||
@@ -1,50 +1,49 @@
|
||||
import asyncio
|
||||
import inspect
|
||||
import uuid
|
||||
from collections.abc import Callable
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
Dict,
|
||||
List,
|
||||
Optional,
|
||||
Tuple,
|
||||
Type,
|
||||
Union,
|
||||
cast,
|
||||
get_args,
|
||||
get_origin,
|
||||
)
|
||||
|
||||
|
||||
try:
|
||||
from typing import Self
|
||||
except ImportError:
|
||||
from typing_extensions import Self
|
||||
|
||||
from pydantic import (
|
||||
UUID4,
|
||||
BaseModel,
|
||||
Field,
|
||||
InstanceOf,
|
||||
PrivateAttr,
|
||||
model_validator,
|
||||
field_validator,
|
||||
model_validator,
|
||||
)
|
||||
from typing_extensions import Self
|
||||
|
||||
from crewai.agents.agent_builder.base_agent import BaseAgent
|
||||
from crewai.agents.agent_builder.utilities.base_token_process import TokenProcess
|
||||
from crewai.agents.cache import CacheHandler
|
||||
from crewai.agents.cache.cache_handler import CacheHandler
|
||||
from crewai.agents.parser import (
|
||||
AgentAction,
|
||||
AgentFinish,
|
||||
OutputParserException,
|
||||
)
|
||||
from crewai.events.event_bus import crewai_event_bus
|
||||
from crewai.events.types.agent_events import (
|
||||
LiteAgentExecutionCompletedEvent,
|
||||
LiteAgentExecutionErrorEvent,
|
||||
LiteAgentExecutionStartedEvent,
|
||||
)
|
||||
from crewai.events.types.logging_events import AgentLogsExecutionEvent
|
||||
from crewai.flow.flow_trackable import FlowTrackable
|
||||
from crewai.llm import LLM, BaseLLM
|
||||
from crewai.llm import LLM
|
||||
from crewai.llms.base_llm import BaseLLM
|
||||
from crewai.tasks import TaskOutput
|
||||
from crewai.tools.base_tool import BaseTool
|
||||
from crewai.tools.structured_tool import CrewStructuredTool
|
||||
from crewai.utilities import I18N
|
||||
from crewai.utilities.guardrail import process_guardrail
|
||||
from crewai.utilities.agent_utils import (
|
||||
enforce_rpm_limit,
|
||||
format_message_for_llm,
|
||||
@@ -62,15 +61,8 @@ from crewai.utilities.agent_utils import (
|
||||
render_text_description_and_args,
|
||||
)
|
||||
from crewai.utilities.converter import generate_model_description
|
||||
from crewai.events.types.logging_events import AgentLogsExecutionEvent
|
||||
from crewai.events.types.agent_events import (
|
||||
LiteAgentExecutionCompletedEvent,
|
||||
LiteAgentExecutionErrorEvent,
|
||||
LiteAgentExecutionStartedEvent,
|
||||
)
|
||||
from crewai.events.event_bus import crewai_event_bus
|
||||
|
||||
from crewai.utilities.llm_utils import create_llm
|
||||
from crewai.utilities.guardrail import process_guardrail
|
||||
from crewai.utilities.llm_utils import create_default_llm, create_llm
|
||||
from crewai.utilities.printer import Printer
|
||||
from crewai.utilities.token_counter_callback import TokenCalcHandler
|
||||
from crewai.utilities.tool_utils import execute_tool_and_check_finality
|
||||
@@ -86,11 +78,11 @@ class LiteAgentOutput(BaseModel):
|
||||
description="Pydantic output of the agent", default=None
|
||||
)
|
||||
agent_role: str = Field(description="Role of the agent that produced this output")
|
||||
usage_metrics: Optional[Dict[str, Any]] = Field(
|
||||
usage_metrics: Optional[dict[str, Any]] = Field(
|
||||
description="Token usage metrics for this execution", default=None
|
||||
)
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
"""Convert pydantic_output to a dictionary."""
|
||||
if self.pydantic:
|
||||
return self.pydantic.model_dump()
|
||||
@@ -130,10 +122,10 @@ class LiteAgent(FlowTrackable, BaseModel):
|
||||
role: str = Field(description="Role of the agent")
|
||||
goal: str = Field(description="Goal of the agent")
|
||||
backstory: str = Field(description="Backstory of the agent")
|
||||
llm: Optional[Union[str, InstanceOf[BaseLLM], Any]] = Field(
|
||||
llm: Optional[str | InstanceOf[BaseLLM] | Any] = Field(
|
||||
default=None, description="Language model that will run the agent"
|
||||
)
|
||||
tools: List[BaseTool] = Field(
|
||||
tools: list[BaseTool] = Field(
|
||||
default_factory=list, description="Tools at agent's disposal"
|
||||
)
|
||||
|
||||
@@ -159,29 +151,27 @@ class LiteAgent(FlowTrackable, BaseModel):
|
||||
i18n: I18N = Field(default=I18N(), description="Internationalization settings.")
|
||||
|
||||
# Output and Formatting Properties
|
||||
response_format: Optional[Type[BaseModel]] = Field(
|
||||
response_format: Optional[type[BaseModel]] = Field(
|
||||
default=None, description="Pydantic model for structured output"
|
||||
)
|
||||
verbose: bool = Field(
|
||||
default=False, description="Whether to print execution details"
|
||||
)
|
||||
callbacks: List[Callable] = Field(
|
||||
callbacks: list[Callable[..., Any]] = Field(
|
||||
default=[], description="Callbacks to be used for the agent"
|
||||
)
|
||||
|
||||
# Guardrail Properties
|
||||
guardrail: Optional[Union[Callable[[LiteAgentOutput], Tuple[bool, Any]], str]] = (
|
||||
Field(
|
||||
default=None,
|
||||
description="Function or string description of a guardrail to validate agent output",
|
||||
)
|
||||
guardrail: Optional[Callable[[LiteAgentOutput], tuple[bool, Any]] | str] = Field(
|
||||
default=None,
|
||||
description="Function or string description of a guardrail to validate agent output",
|
||||
)
|
||||
guardrail_max_retries: int = Field(
|
||||
default=3, description="Maximum number of retries when guardrail fails"
|
||||
)
|
||||
|
||||
# State and Results
|
||||
tools_results: List[Dict[str, Any]] = Field(
|
||||
tools_results: list[dict[str, Any]] = Field(
|
||||
default=[], description="Results of the tools used by the agent."
|
||||
)
|
||||
|
||||
@@ -190,20 +180,25 @@ class LiteAgent(FlowTrackable, BaseModel):
|
||||
default=None, description="Reference to the agent that created this LiteAgent"
|
||||
)
|
||||
# Private Attributes
|
||||
_parsed_tools: List[CrewStructuredTool] = PrivateAttr(default_factory=list)
|
||||
_parsed_tools: list[CrewStructuredTool] = PrivateAttr(default_factory=list)
|
||||
_token_process: TokenProcess = PrivateAttr(default_factory=TokenProcess)
|
||||
_cache_handler: CacheHandler = PrivateAttr(default_factory=CacheHandler)
|
||||
_key: str = PrivateAttr(default_factory=lambda: str(uuid.uuid4()))
|
||||
_messages: List[Dict[str, str]] = PrivateAttr(default_factory=list)
|
||||
_messages: list[dict[str, str]] = PrivateAttr(default_factory=list)
|
||||
_iterations: int = PrivateAttr(default=0)
|
||||
_printer: Printer = PrivateAttr(default_factory=Printer)
|
||||
_guardrail: Optional[Callable] = PrivateAttr(default=None)
|
||||
_guardrail: Optional[Callable[[LiteAgentOutput | TaskOutput], tuple[bool, Any]]] = (
|
||||
PrivateAttr(default=None)
|
||||
)
|
||||
_guardrail_retry_count: int = PrivateAttr(default=0)
|
||||
|
||||
@model_validator(mode="after")
|
||||
def setup_llm(self):
|
||||
def setup_llm(self) -> Self:
|
||||
"""Set up the LLM and other components after initialization."""
|
||||
self.llm = create_llm(self.llm)
|
||||
if self.llm is None:
|
||||
self.llm = create_default_llm()
|
||||
else:
|
||||
self.llm = create_llm(self.llm)
|
||||
if not isinstance(self.llm, BaseLLM):
|
||||
raise ValueError(
|
||||
f"Expected LLM instance of type BaseLLM, got {type(self.llm).__name__}"
|
||||
@@ -216,7 +211,7 @@ class LiteAgent(FlowTrackable, BaseModel):
|
||||
return self
|
||||
|
||||
@model_validator(mode="after")
|
||||
def parse_tools(self):
|
||||
def parse_tools(self) -> Self:
|
||||
"""Parse the tools and convert them to CrewStructuredTool instances."""
|
||||
self._parsed_tools = parse_tools(self.tools)
|
||||
|
||||
@@ -225,7 +220,10 @@ class LiteAgent(FlowTrackable, BaseModel):
|
||||
@model_validator(mode="after")
|
||||
def ensure_guardrail_is_callable(self) -> Self:
|
||||
if callable(self.guardrail):
|
||||
self._guardrail = self.guardrail
|
||||
self._guardrail = cast(
|
||||
Callable[[LiteAgentOutput | TaskOutput], tuple[bool, Any]],
|
||||
self.guardrail,
|
||||
)
|
||||
elif isinstance(self.guardrail, str):
|
||||
from crewai.tasks.llm_guardrail import LLMGuardrail
|
||||
|
||||
@@ -234,15 +232,18 @@ class LiteAgent(FlowTrackable, BaseModel):
|
||||
f"Guardrail requires LLM instance of type BaseLLM, got {type(self.llm).__name__}"
|
||||
)
|
||||
|
||||
self._guardrail = LLMGuardrail(description=self.guardrail, llm=self.llm)
|
||||
self._guardrail = cast(
|
||||
Callable[[LiteAgentOutput | TaskOutput], tuple[bool, Any]],
|
||||
LLMGuardrail(description=self.guardrail, llm=self.llm),
|
||||
)
|
||||
|
||||
return self
|
||||
|
||||
@field_validator("guardrail", mode="before")
|
||||
@classmethod
|
||||
def validate_guardrail_function(
|
||||
cls, v: Optional[Union[Callable, str]]
|
||||
) -> Optional[Union[Callable, str]]:
|
||||
cls, v: Optional[Callable[[Any], tuple[bool, Any]] | str]
|
||||
) -> Optional[Callable[[Any], tuple[bool, Any]] | str]:
|
||||
"""Validate that the guardrail function has the correct signature.
|
||||
|
||||
If v is a callable, validate that it has the correct signature.
|
||||
@@ -267,7 +268,7 @@ class LiteAgent(FlowTrackable, BaseModel):
|
||||
|
||||
# Check return annotation if present
|
||||
if sig.return_annotation is not sig.empty:
|
||||
if sig.return_annotation == Tuple[bool, Any]:
|
||||
if sig.return_annotation == tuple[bool, Any]:
|
||||
return v
|
||||
|
||||
origin = get_origin(sig.return_annotation)
|
||||
@@ -290,7 +291,7 @@ class LiteAgent(FlowTrackable, BaseModel):
|
||||
"""Return the original role for compatibility with tool interfaces."""
|
||||
return self.role
|
||||
|
||||
def kickoff(self, messages: Union[str, List[Dict[str, str]]]) -> LiteAgentOutput:
|
||||
def kickoff(self, messages: str | list[dict[str, str]]) -> LiteAgentOutput:
|
||||
"""
|
||||
Execute the agent with the given messages.
|
||||
|
||||
@@ -338,7 +339,7 @@ class LiteAgent(FlowTrackable, BaseModel):
|
||||
)
|
||||
raise e
|
||||
|
||||
def _execute_core(self, agent_info: Dict[str, Any]) -> LiteAgentOutput:
|
||||
def _execute_core(self, agent_info: dict[str, Any]) -> LiteAgentOutput:
|
||||
# Emit event for agent execution start
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
@@ -428,7 +429,7 @@ class LiteAgent(FlowTrackable, BaseModel):
|
||||
return output
|
||||
|
||||
async def kickoff_async(
|
||||
self, messages: Union[str, List[Dict[str, str]]]
|
||||
self, messages: str | list[dict[str, str]]
|
||||
) -> LiteAgentOutput:
|
||||
"""
|
||||
Execute the agent asynchronously with the given messages.
|
||||
@@ -475,8 +476,8 @@ class LiteAgent(FlowTrackable, BaseModel):
|
||||
return base_prompt
|
||||
|
||||
def _format_messages(
|
||||
self, messages: Union[str, List[Dict[str, str]]]
|
||||
) -> List[Dict[str, str]]:
|
||||
self, messages: str | list[dict[str, str]]
|
||||
) -> list[dict[str, str]]:
|
||||
"""Format messages for the LLM."""
|
||||
if isinstance(messages, str):
|
||||
messages = [{"role": "user", "content": messages}]
|
||||
@@ -582,7 +583,7 @@ class LiteAgent(FlowTrackable, BaseModel):
|
||||
self._show_logs(formatted_answer)
|
||||
return formatted_answer
|
||||
|
||||
def _show_logs(self, formatted_answer: Union[AgentAction, AgentFinish]):
|
||||
def _show_logs(self, formatted_answer: AgentAction | AgentFinish) -> None:
|
||||
"""Show logs for the agent's execution."""
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
from typing import Optional, TYPE_CHECKING
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
|
||||
from crewai.memory import (
|
||||
EntityMemory,
|
||||
@@ -19,8 +21,8 @@ class ContextualMemory:
|
||||
ltm: LongTermMemory,
|
||||
em: EntityMemory,
|
||||
exm: ExternalMemory,
|
||||
agent: Optional["Agent"] = None,
|
||||
task: Optional["Task"] = None,
|
||||
agent: Optional[Agent] = None,
|
||||
task: Optional[Task] = None,
|
||||
):
|
||||
self.stm = stm
|
||||
self.ltm = ltm
|
||||
@@ -42,7 +44,7 @@ class ContextualMemory:
|
||||
self.exm.agent = self.agent
|
||||
self.exm.task = self.task
|
||||
|
||||
def build_context_for_task(self, task, context) -> str:
|
||||
def build_context_for_task(self, task: Task, context: str) -> str:
|
||||
"""
|
||||
Automatically builds a minimal, highly relevant set of contextual information
|
||||
for a given task.
|
||||
@@ -52,14 +54,14 @@ class ContextualMemory:
|
||||
if query == "":
|
||||
return ""
|
||||
|
||||
context = []
|
||||
context.append(self._fetch_ltm_context(task.description))
|
||||
context.append(self._fetch_stm_context(query))
|
||||
context.append(self._fetch_entity_context(query))
|
||||
context.append(self._fetch_external_context(query))
|
||||
return "\n".join(filter(None, context))
|
||||
context_parts = []
|
||||
context_parts.append(self._fetch_ltm_context(task.description))
|
||||
context_parts.append(self._fetch_stm_context(query))
|
||||
context_parts.append(self._fetch_entity_context(query))
|
||||
context_parts.append(self._fetch_external_context(query))
|
||||
return "\n".join(filter(None, context_parts))
|
||||
|
||||
def _fetch_stm_context(self, query) -> str:
|
||||
def _fetch_stm_context(self, query: str) -> str:
|
||||
"""
|
||||
Fetches recent relevant insights from STM related to the task's description and expected_output,
|
||||
formatted as bullet points.
|
||||
@@ -74,7 +76,7 @@ class ContextualMemory:
|
||||
)
|
||||
return f"Recent Insights:\n{formatted_results}" if stm_results else ""
|
||||
|
||||
def _fetch_ltm_context(self, task) -> Optional[str]:
|
||||
def _fetch_ltm_context(self, task: str) -> Optional[str]:
|
||||
"""
|
||||
Fetches historical data or insights from LTM that are relevant to the task's description and expected_output,
|
||||
formatted as bullet points.
|
||||
@@ -83,21 +85,23 @@ class ContextualMemory:
|
||||
if self.ltm is None:
|
||||
return ""
|
||||
|
||||
ltm_results = self.ltm.search(task, latest_n=2)
|
||||
ltm_results = self.ltm.search(task, limit=2)
|
||||
if not ltm_results:
|
||||
return None
|
||||
|
||||
formatted_results = [
|
||||
suggestion
|
||||
for result in ltm_results
|
||||
for suggestion in result["metadata"]["suggestions"] # type: ignore # Invalid index type "str" for "str"; expected type "SupportsIndex | slice"
|
||||
for suggestion in result["metadata"]["suggestions"]
|
||||
]
|
||||
formatted_results = list(dict.fromkeys(formatted_results))
|
||||
formatted_results = "\n".join([f"- {result}" for result in formatted_results]) # type: ignore # Incompatible types in assignment (expression has type "str", variable has type "list[str]")
|
||||
formatted_results_str = "\n".join(
|
||||
[f"- {result}" for result in formatted_results]
|
||||
)
|
||||
|
||||
return f"Historical Data:\n{formatted_results}" if ltm_results else ""
|
||||
return f"Historical Data:\n{formatted_results_str}" if ltm_results else ""
|
||||
|
||||
def _fetch_entity_context(self, query) -> str:
|
||||
def _fetch_entity_context(self, query: str) -> str:
|
||||
"""
|
||||
Fetches relevant entity information from Entity Memory related to the task's description and expected_output,
|
||||
formatted as bullet points.
|
||||
@@ -107,7 +111,7 @@ class ContextualMemory:
|
||||
|
||||
em_results = self.em.search(query)
|
||||
formatted_results = "\n".join(
|
||||
[f"- {result['context']}" for result in em_results] # type: ignore # Invalid index type "str" for "str"; expected type "SupportsIndex | slice"
|
||||
[f"- {result['context']}" for result in em_results]
|
||||
)
|
||||
return f"Entities:\n{formatted_results}" if em_results else ""
|
||||
|
||||
|
||||
@@ -1,20 +1,20 @@
|
||||
from typing import Any
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
from pydantic import PrivateAttr
|
||||
|
||||
from crewai.events.event_bus import crewai_event_bus
|
||||
from crewai.events.types.memory_events import (
|
||||
MemoryQueryCompletedEvent,
|
||||
MemoryQueryFailedEvent,
|
||||
MemoryQueryStartedEvent,
|
||||
MemorySaveCompletedEvent,
|
||||
MemorySaveFailedEvent,
|
||||
MemorySaveStartedEvent,
|
||||
)
|
||||
from crewai.memory.entity.entity_memory_item import EntityMemoryItem
|
||||
from crewai.memory.memory import Memory
|
||||
from crewai.memory.storage.rag_storage import RAGStorage
|
||||
from crewai.events.event_bus import crewai_event_bus
|
||||
from crewai.events.types.memory_events import (
|
||||
MemoryQueryStartedEvent,
|
||||
MemoryQueryCompletedEvent,
|
||||
MemoryQueryFailedEvent,
|
||||
MemorySaveStartedEvent,
|
||||
MemorySaveCompletedEvent,
|
||||
MemorySaveFailedEvent,
|
||||
)
|
||||
|
||||
|
||||
class EntityMemory(Memory):
|
||||
@@ -26,7 +26,13 @@ class EntityMemory(Memory):
|
||||
|
||||
_memory_provider: str | None = PrivateAttr()
|
||||
|
||||
def __init__(self, crew=None, embedder_config=None, storage=None, path=None):
|
||||
def __init__(
|
||||
self,
|
||||
crew: Any = None,
|
||||
embedder_config: Any = None,
|
||||
storage: Any = None,
|
||||
path: Any = None,
|
||||
) -> None:
|
||||
memory_provider = embedder_config.get("provider") if embedder_config else None
|
||||
if memory_provider == "mem0":
|
||||
try:
|
||||
@@ -155,7 +161,7 @@ class EntityMemory(Memory):
|
||||
query: str,
|
||||
limit: int = 3,
|
||||
score_threshold: float = 0.35,
|
||||
):
|
||||
) -> Any:
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=MemoryQueryStartedEvent(
|
||||
|
||||
@@ -1,17 +1,17 @@
|
||||
from typing import Any, Dict, List
|
||||
import time
|
||||
from typing import Any, Optional
|
||||
|
||||
from crewai.memory.long_term.long_term_memory_item import LongTermMemoryItem
|
||||
from crewai.memory.memory import Memory
|
||||
from crewai.events.event_bus import crewai_event_bus
|
||||
from crewai.events.types.memory_events import (
|
||||
MemoryQueryStartedEvent,
|
||||
MemoryQueryCompletedEvent,
|
||||
MemoryQueryFailedEvent,
|
||||
MemorySaveStartedEvent,
|
||||
MemoryQueryStartedEvent,
|
||||
MemorySaveCompletedEvent,
|
||||
MemorySaveFailedEvent,
|
||||
MemorySaveStartedEvent,
|
||||
)
|
||||
from crewai.memory.long_term.long_term_memory_item import LongTermMemoryItem
|
||||
from crewai.memory.memory import Memory
|
||||
from crewai.memory.storage.ltm_sqlite_storage import LTMSQLiteStorage
|
||||
|
||||
|
||||
@@ -24,72 +24,84 @@ class LongTermMemory(Memory):
|
||||
LongTermMemoryItem instances.
|
||||
"""
|
||||
|
||||
def __init__(self, storage=None, path=None):
|
||||
def __init__(self, storage: Any = None, path: Any = None) -> None:
|
||||
if not storage:
|
||||
storage = LTMSQLiteStorage(db_path=path) if path else LTMSQLiteStorage()
|
||||
super().__init__(storage=storage)
|
||||
|
||||
def save(self, item: LongTermMemoryItem) -> None: # type: ignore # BUG?: Signature of "save" incompatible with supertype "Memory"
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=MemorySaveStartedEvent(
|
||||
value=item.task,
|
||||
metadata=item.metadata,
|
||||
agent_role=item.agent,
|
||||
source_type="long_term_memory",
|
||||
from_agent=self.agent,
|
||||
from_task=self.task,
|
||||
),
|
||||
)
|
||||
|
||||
start_time = time.time()
|
||||
try:
|
||||
metadata = item.metadata
|
||||
metadata.update(
|
||||
{"agent": item.agent, "expected_output": item.expected_output}
|
||||
)
|
||||
self.storage.save( # type: ignore # BUG?: Unexpected keyword argument "task_description","score","datetime" for "save" of "Storage"
|
||||
task_description=item.task,
|
||||
score=metadata["quality"],
|
||||
metadata=metadata,
|
||||
datetime=item.datetime,
|
||||
)
|
||||
|
||||
def save(
|
||||
self,
|
||||
value: Any,
|
||||
metadata: Optional[dict[str, Any]] = None,
|
||||
) -> None:
|
||||
# Handle both LongTermMemoryItem and regular save calls
|
||||
if isinstance(value, LongTermMemoryItem):
|
||||
item = value
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=MemorySaveCompletedEvent(
|
||||
event=MemorySaveStartedEvent(
|
||||
value=item.task,
|
||||
metadata=item.metadata,
|
||||
agent_role=item.agent,
|
||||
save_time_ms=(time.time() - start_time) * 1000,
|
||||
source_type="long_term_memory",
|
||||
from_agent=self.agent,
|
||||
from_task=self.task,
|
||||
),
|
||||
)
|
||||
except Exception as e:
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=MemorySaveFailedEvent(
|
||||
value=item.task,
|
||||
metadata=item.metadata,
|
||||
agent_role=item.agent,
|
||||
error=str(e),
|
||||
source_type="long_term_memory",
|
||||
),
|
||||
)
|
||||
raise
|
||||
|
||||
def search( # type: ignore # signature of "search" incompatible with supertype "Memory"
|
||||
start_time = time.time()
|
||||
try:
|
||||
metadata = item.metadata.copy()
|
||||
metadata.update(
|
||||
{"agent": item.agent, "expected_output": item.expected_output}
|
||||
)
|
||||
self.storage.save(
|
||||
task_description=item.task,
|
||||
score=metadata["quality"],
|
||||
metadata=metadata,
|
||||
datetime=item.datetime,
|
||||
)
|
||||
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=MemorySaveCompletedEvent(
|
||||
value=item.task,
|
||||
metadata=metadata,
|
||||
agent_role=item.agent,
|
||||
save_time_ms=(time.time() - start_time) * 1000,
|
||||
source_type="long_term_memory",
|
||||
from_agent=self.agent,
|
||||
from_task=self.task,
|
||||
),
|
||||
)
|
||||
except Exception as e:
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=MemorySaveFailedEvent(
|
||||
value=item.task,
|
||||
metadata=item.metadata,
|
||||
agent_role=item.agent,
|
||||
error=str(e),
|
||||
source_type="long_term_memory",
|
||||
),
|
||||
)
|
||||
raise
|
||||
else:
|
||||
# Regular save for compatibility with parent class
|
||||
metadata = metadata or {}
|
||||
self.storage.save(value, metadata)
|
||||
|
||||
def search(
|
||||
self,
|
||||
task: str,
|
||||
latest_n: int = 3,
|
||||
) -> List[Dict[str, Any]]: # type: ignore # signature of "search" incompatible with supertype "Memory"
|
||||
query: str,
|
||||
limit: int = 3,
|
||||
score_threshold: float = 0.35,
|
||||
) -> list[Any]:
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=MemoryQueryStartedEvent(
|
||||
query=task,
|
||||
limit=latest_n,
|
||||
query=query,
|
||||
limit=limit,
|
||||
source_type="long_term_memory",
|
||||
from_agent=self.agent,
|
||||
from_task=self.task,
|
||||
@@ -98,14 +110,16 @@ class LongTermMemory(Memory):
|
||||
|
||||
start_time = time.time()
|
||||
try:
|
||||
results = self.storage.load(task, latest_n) # type: ignore # BUG?: "Storage" has no attribute "load"
|
||||
# The storage.load method uses different parameter names
|
||||
# but we'll call it with the aligned names
|
||||
results = self.storage.load(query, limit)
|
||||
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=MemoryQueryCompletedEvent(
|
||||
query=task,
|
||||
query=query,
|
||||
results=results,
|
||||
limit=latest_n,
|
||||
limit=limit,
|
||||
query_time_ms=(time.time() - start_time) * 1000,
|
||||
source_type="long_term_memory",
|
||||
from_agent=self.agent,
|
||||
@@ -113,15 +127,17 @@ class LongTermMemory(Memory):
|
||||
),
|
||||
)
|
||||
|
||||
return results
|
||||
return results if results is not None else []
|
||||
except Exception as e:
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=MemoryQueryFailedEvent(
|
||||
query=task,
|
||||
limit=latest_n,
|
||||
query=query,
|
||||
limit=limit,
|
||||
error=str(e),
|
||||
source_type="long_term_memory",
|
||||
from_agent=self.agent,
|
||||
from_task=self.task,
|
||||
),
|
||||
)
|
||||
raise
|
||||
|
||||
@@ -1,20 +1,20 @@
|
||||
from typing import Any, Dict, Optional
|
||||
import time
|
||||
from typing import Any, Optional
|
||||
|
||||
from pydantic import PrivateAttr
|
||||
|
||||
from crewai.events.event_bus import crewai_event_bus
|
||||
from crewai.events.types.memory_events import (
|
||||
MemoryQueryCompletedEvent,
|
||||
MemoryQueryFailedEvent,
|
||||
MemoryQueryStartedEvent,
|
||||
MemorySaveCompletedEvent,
|
||||
MemorySaveFailedEvent,
|
||||
MemorySaveStartedEvent,
|
||||
)
|
||||
from crewai.memory.memory import Memory
|
||||
from crewai.memory.short_term.short_term_memory_item import ShortTermMemoryItem
|
||||
from crewai.memory.storage.rag_storage import RAGStorage
|
||||
from crewai.events.event_bus import crewai_event_bus
|
||||
from crewai.events.types.memory_events import (
|
||||
MemoryQueryStartedEvent,
|
||||
MemoryQueryCompletedEvent,
|
||||
MemoryQueryFailedEvent,
|
||||
MemorySaveStartedEvent,
|
||||
MemorySaveCompletedEvent,
|
||||
MemorySaveFailedEvent,
|
||||
)
|
||||
|
||||
|
||||
class ShortTermMemory(Memory):
|
||||
@@ -28,7 +28,13 @@ class ShortTermMemory(Memory):
|
||||
|
||||
_memory_provider: Optional[str] = PrivateAttr()
|
||||
|
||||
def __init__(self, crew=None, embedder_config=None, storage=None, path=None):
|
||||
def __init__(
|
||||
self,
|
||||
crew: Any = None,
|
||||
embedder_config: Any = None,
|
||||
storage: Any = None,
|
||||
path: Any = None,
|
||||
) -> None:
|
||||
memory_provider = embedder_config.get("provider") if embedder_config else None
|
||||
if memory_provider == "mem0":
|
||||
try:
|
||||
@@ -56,7 +62,7 @@ class ShortTermMemory(Memory):
|
||||
def save(
|
||||
self,
|
||||
value: Any,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
metadata: Optional[dict[str, Any]] = None,
|
||||
) -> None:
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
@@ -114,7 +120,7 @@ class ShortTermMemory(Memory):
|
||||
query: str,
|
||||
limit: int = 3,
|
||||
score_threshold: float = 0.35,
|
||||
):
|
||||
) -> Any:
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=MemoryQueryStartedEvent(
|
||||
@@ -131,7 +137,7 @@ class ShortTermMemory(Memory):
|
||||
try:
|
||||
results = self.storage.search(
|
||||
query=query, limit=limit, score_threshold=score_threshold
|
||||
) # type: ignore # BUG? The reference is to the parent class, but the parent class does not have this parameters
|
||||
)
|
||||
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
|
||||
@@ -1,10 +1,11 @@
|
||||
import os
|
||||
from typing import Any, Dict, List
|
||||
from collections import defaultdict
|
||||
from mem0 import Memory, MemoryClient
|
||||
from crewai.utilities.chromadb import sanitize_collection_name
|
||||
from typing import Any
|
||||
|
||||
from mem0 import Memory, MemoryClient # type: ignore[import-not-found]
|
||||
|
||||
from crewai.memory.storage.interface import Storage
|
||||
from crewai.utilities.chromadb import sanitize_collection_name
|
||||
|
||||
MAX_AGENT_ID_LENGTH_MEM0 = 255
|
||||
|
||||
@@ -13,7 +14,10 @@ class Mem0Storage(Storage):
|
||||
"""
|
||||
Extends Storage to handle embedding and searching across entities using Mem0.
|
||||
"""
|
||||
def __init__(self, type, crew=None, config=None):
|
||||
|
||||
def __init__(
|
||||
self, type: str, crew: Any = None, config: dict[str, Any] | None = None
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self._validate_type(type)
|
||||
@@ -24,21 +28,21 @@ class Mem0Storage(Storage):
|
||||
self._extract_config_values()
|
||||
self._initialize_memory()
|
||||
|
||||
def _validate_type(self, type):
|
||||
def _validate_type(self, type: str) -> None:
|
||||
supported_types = {"short_term", "long_term", "entities", "external"}
|
||||
if type not in supported_types:
|
||||
raise ValueError(
|
||||
f"Invalid type '{type}' for Mem0Storage. Must be one of: {', '.join(supported_types)}"
|
||||
)
|
||||
|
||||
def _extract_config_values(self):
|
||||
def _extract_config_values(self) -> None:
|
||||
self.mem0_run_id = self.config.get("run_id")
|
||||
self.includes = self.config.get("includes")
|
||||
self.excludes = self.config.get("excludes")
|
||||
self.custom_categories = self.config.get("custom_categories")
|
||||
self.infer = self.config.get("infer", True)
|
||||
|
||||
def _initialize_memory(self):
|
||||
def _initialize_memory(self) -> None:
|
||||
api_key = self.config.get("api_key") or os.getenv("MEM0_API_KEY")
|
||||
org_id = self.config.get("org_id")
|
||||
project_id = self.config.get("project_id")
|
||||
@@ -59,7 +63,7 @@ class Mem0Storage(Storage):
|
||||
else Memory()
|
||||
)
|
||||
|
||||
def _create_filter_for_search(self):
|
||||
def _create_filter_for_search(self) -> dict[str, Any]:
|
||||
"""
|
||||
Returns:
|
||||
dict: A filter dictionary containing AND conditions for querying data.
|
||||
@@ -86,21 +90,21 @@ class Mem0Storage(Storage):
|
||||
|
||||
return filter
|
||||
|
||||
def save(self, value: Any, metadata: Dict[str, Any]) -> None:
|
||||
def save(self, value: Any, metadata: dict[str, Any]) -> None:
|
||||
user_id = self.config.get("user_id", "")
|
||||
assistant_message = [{"role" : "assistant","content" : value}]
|
||||
assistant_message = [{"role": "assistant", "content": value}]
|
||||
|
||||
base_metadata = {
|
||||
"short_term": "short_term",
|
||||
"long_term": "long_term",
|
||||
"entities": "entity",
|
||||
"external": "external"
|
||||
"external": "external",
|
||||
}
|
||||
|
||||
# Shared base params
|
||||
params: dict[str, Any] = {
|
||||
"metadata": {"type": base_metadata[self.memory_type], **metadata},
|
||||
"infer": self.infer
|
||||
"infer": self.infer,
|
||||
}
|
||||
|
||||
# MemoryClient-specific overrides
|
||||
@@ -121,13 +125,15 @@ class Mem0Storage(Storage):
|
||||
|
||||
self.memory.add(assistant_message, **params)
|
||||
|
||||
def search(self,query: str,limit: int = 3,score_threshold: float = 0.35) -> List[Any]:
|
||||
def search(
|
||||
self, query: str, limit: int = 3, score_threshold: float = 0.35
|
||||
) -> list[Any]:
|
||||
params = {
|
||||
"query": query,
|
||||
"limit": limit,
|
||||
"version": "v2",
|
||||
"output_format": "v1.1"
|
||||
}
|
||||
"output_format": "v1.1",
|
||||
}
|
||||
|
||||
if user_id := self.config.get("user_id", ""):
|
||||
params["user_id"] = user_id
|
||||
@@ -148,10 +154,10 @@ class Mem0Storage(Storage):
|
||||
# automatically when the crew is created.
|
||||
|
||||
params["filters"] = self._create_filter_for_search()
|
||||
params['threshold'] = score_threshold
|
||||
params["threshold"] = score_threshold
|
||||
|
||||
if isinstance(self.memory, Memory):
|
||||
del params["metadata"], params["version"], params['output_format']
|
||||
del params["metadata"], params["version"], params["output_format"]
|
||||
if params.get("run_id"):
|
||||
del params["run_id"]
|
||||
|
||||
@@ -160,10 +166,10 @@ class Mem0Storage(Storage):
|
||||
# This makes it compatible for Contextual Memory to retrieve
|
||||
for result in results["results"]:
|
||||
result["context"] = result["memory"]
|
||||
|
||||
|
||||
return [r for r in results["results"]]
|
||||
|
||||
def reset(self):
|
||||
def reset(self) -> None:
|
||||
if self.memory:
|
||||
self.memory.reset()
|
||||
|
||||
@@ -180,4 +186,6 @@ class Mem0Storage(Storage):
|
||||
agents = self.crew.agents
|
||||
agents = [self._sanitize_role(agent.role) for agent in agents]
|
||||
agents = "_".join(agents)
|
||||
return sanitize_collection_name(name=agents, max_collection_length=MAX_AGENT_ID_LENGTH_MEM0)
|
||||
return sanitize_collection_name(
|
||||
name=agents, max_collection_length=MAX_AGENT_ID_LENGTH_MEM0
|
||||
)
|
||||
|
||||
@@ -2,29 +2,72 @@ import logging
|
||||
import os
|
||||
import shutil
|
||||
import uuid
|
||||
import warnings
|
||||
from typing import Any
|
||||
|
||||
from typing import Any, Dict, List, Optional
|
||||
from chromadb import EmbeddingFunction
|
||||
from chromadb.api import ClientAPI
|
||||
from crewai.rag.storage.base_rag_storage import BaseRAGStorage
|
||||
|
||||
from crewai.rag.embeddings.configurator import EmbeddingConfigurator
|
||||
from crewai.rag.storage.base_rag_storage import BaseRAGStorage
|
||||
from crewai.utilities.chromadb import create_persistent_client
|
||||
from crewai.utilities.constants import MAX_FILE_NAME_LENGTH
|
||||
from crewai.utilities.paths import db_storage_path
|
||||
from crewai.utilities.logger_utils import suppress_logging
|
||||
import warnings
|
||||
from crewai.utilities.paths import db_storage_path
|
||||
|
||||
|
||||
def _extract_chromadb_response_item(
|
||||
response_data: Any,
|
||||
index: int,
|
||||
expected_type: type[Any] | tuple[type[Any], ...],
|
||||
) -> Any | None:
|
||||
"""Extract an item from ChromaDB response data at the given index.
|
||||
|
||||
Args:
|
||||
response_data: The response data from ChromaDB query (e.g., documents, metadatas).
|
||||
index: The index of the item to extract.
|
||||
expected_type: The expected type(s) of the item.
|
||||
|
||||
Returns:
|
||||
The extracted item if it exists and matches the expected type, otherwise None.
|
||||
"""
|
||||
if response_data is None or not response_data:
|
||||
return None
|
||||
|
||||
# ChromaDB sometimes returns nested lists, handle both cases
|
||||
data_list = (
|
||||
response_data[0]
|
||||
if response_data and isinstance(response_data[0], list)
|
||||
else response_data
|
||||
)
|
||||
|
||||
if index < len(data_list):
|
||||
item = data_list[index]
|
||||
if isinstance(item, expected_type):
|
||||
return item
|
||||
return None
|
||||
|
||||
|
||||
class RAGStorage(BaseRAGStorage):
|
||||
"""
|
||||
Extends Storage to handle embeddings for memory entries, improving
|
||||
search efficiency.
|
||||
|
||||
Notes:
|
||||
- TODO: Add type hints to EmbeddingFunction in next typing PR.
|
||||
"""
|
||||
|
||||
app: ClientAPI | None = None
|
||||
embedder_config: EmbeddingFunction[Any] | None = None # type: ignore
|
||||
|
||||
def __init__(
|
||||
self, type, allow_reset=True, embedder_config=None, crew=None, path=None
|
||||
):
|
||||
self,
|
||||
type: str,
|
||||
allow_reset: bool = True,
|
||||
embedder_config: Any = None,
|
||||
crew: Any = None,
|
||||
path: str | None = None,
|
||||
) -> None:
|
||||
super().__init__(type, allow_reset, embedder_config, crew)
|
||||
agents = crew.agents if crew else []
|
||||
agents = [self._sanitize_role(agent.role) for agent in agents]
|
||||
@@ -33,16 +76,29 @@ class RAGStorage(BaseRAGStorage):
|
||||
self.storage_file_name = self._build_storage_file_name(type, agents)
|
||||
|
||||
self.type = type
|
||||
|
||||
self._original_embedder_config = (
|
||||
embedder_config # Store for later use in _set_embedder_config
|
||||
)
|
||||
self.allow_reset = allow_reset
|
||||
self.path = path
|
||||
self._initialize_app()
|
||||
|
||||
def _set_embedder_config(self):
|
||||
configurator = EmbeddingConfigurator()
|
||||
self.embedder_config = configurator.configure_embedder(self.embedder_config)
|
||||
def _set_embedder_config(self) -> None:
|
||||
"""Sets the embedder_config using EmbeddingConfigurator.
|
||||
|
||||
def _initialize_app(self):
|
||||
Notes:
|
||||
- TODO: remove the type: ignore on next typing pr.
|
||||
"""
|
||||
configurator = EmbeddingConfigurator() # type: ignore
|
||||
# Pass the original embedder_config from __init__, not self.embedder_config
|
||||
if hasattr(self, "_original_embedder_config"):
|
||||
self.embedder_config = configurator.configure_embedder(
|
||||
self._original_embedder_config
|
||||
)
|
||||
else:
|
||||
self.embedder_config = configurator.configure_embedder()
|
||||
|
||||
def _initialize_app(self) -> None:
|
||||
from chromadb.config import Settings
|
||||
|
||||
# Suppress deprecation warnings from chromadb, which are not relevant to us
|
||||
@@ -71,7 +127,8 @@ class RAGStorage(BaseRAGStorage):
|
||||
"""
|
||||
return role.replace("\n", "").replace(" ", "_").replace("/", "_")
|
||||
|
||||
def _build_storage_file_name(self, type: str, file_name: str) -> str:
|
||||
@staticmethod
|
||||
def _build_storage_file_name(type: str, file_name: str) -> str:
|
||||
"""
|
||||
Ensures file name does not exceed max allowed by OS
|
||||
"""
|
||||
@@ -85,7 +142,7 @@ class RAGStorage(BaseRAGStorage):
|
||||
|
||||
return f"{base_path}/{file_name}"
|
||||
|
||||
def save(self, value: Any, metadata: Dict[str, Any]) -> None:
|
||||
def save(self, value: Any, metadata: dict[str, Any]) -> None:
|
||||
if not hasattr(self, "app") or not hasattr(self, "collection"):
|
||||
self._initialize_app()
|
||||
try:
|
||||
@@ -97,9 +154,9 @@ class RAGStorage(BaseRAGStorage):
|
||||
self,
|
||||
query: str,
|
||||
limit: int = 3,
|
||||
filter: Optional[dict] = None,
|
||||
filter: dict[str, Any] | None = None,
|
||||
score_threshold: float = 0.35,
|
||||
) -> List[Any]:
|
||||
) -> list[Any]:
|
||||
if not hasattr(self, "app"):
|
||||
self._initialize_app()
|
||||
|
||||
@@ -110,26 +167,68 @@ class RAGStorage(BaseRAGStorage):
|
||||
response = self.collection.query(query_texts=query, n_results=limit)
|
||||
|
||||
results = []
|
||||
for i in range(len(response["ids"][0])):
|
||||
result = {
|
||||
"id": response["ids"][0][i],
|
||||
"metadata": response["metadatas"][0][i],
|
||||
"context": response["documents"][0][i],
|
||||
"score": response["distances"][0][i],
|
||||
}
|
||||
if result["score"] >= score_threshold:
|
||||
results.append(result)
|
||||
if (
|
||||
response
|
||||
and "ids" in response
|
||||
and response["ids"]
|
||||
and len(response["ids"]) > 0
|
||||
):
|
||||
ids_list = (
|
||||
response["ids"][0]
|
||||
if isinstance(response["ids"][0], list)
|
||||
else response["ids"]
|
||||
)
|
||||
for i in range(len(ids_list)):
|
||||
# Handle metadatas
|
||||
meta_item = _extract_chromadb_response_item(
|
||||
response.get("metadatas"), i, dict
|
||||
)
|
||||
metadata: dict[str, Any] = meta_item if meta_item else {}
|
||||
|
||||
# Handle documents
|
||||
doc_item = _extract_chromadb_response_item(
|
||||
response.get("documents"), i, str
|
||||
)
|
||||
context = doc_item if doc_item else ""
|
||||
|
||||
# Handle distances
|
||||
dist_item = _extract_chromadb_response_item(
|
||||
response.get("distances"), i, (int, float)
|
||||
)
|
||||
score = dist_item if dist_item is not None else 1.0
|
||||
|
||||
result = {
|
||||
"id": ids_list[i],
|
||||
"metadata": metadata,
|
||||
"context": context,
|
||||
"score": score,
|
||||
}
|
||||
|
||||
# Check score threshold - distances are smaller when more similar
|
||||
if isinstance(score, (int, float)) and score <= score_threshold:
|
||||
results.append(result)
|
||||
|
||||
return results
|
||||
except Exception as e:
|
||||
logging.error(f"Error during {self.type} search: {str(e)}")
|
||||
return []
|
||||
|
||||
def _generate_embedding(self, text: str, metadata: Dict[str, Any]) -> None: # type: ignore
|
||||
def _generate_embedding(
|
||||
self, text: str, metadata: dict[str, Any] | None = None
|
||||
) -> Any:
|
||||
"""Generates and stores the embedding for the given text and metadata.
|
||||
|
||||
Args:
|
||||
text: The text to generate an embedding for.
|
||||
metadata: Optional metadata associated with the text.
|
||||
|
||||
Notes:
|
||||
- Need to constrain the typing in the base class, this result isn't used
|
||||
"""
|
||||
if not hasattr(self, "app") or not hasattr(self, "collection"):
|
||||
self._initialize_app()
|
||||
|
||||
self.collection.add(
|
||||
return self.collection.add(
|
||||
documents=[text],
|
||||
metadatas=[metadata or {}],
|
||||
ids=[str(uuid.uuid4())],
|
||||
@@ -151,7 +250,8 @@ class RAGStorage(BaseRAGStorage):
|
||||
f"An error occurred while resetting the {self.type} memory: {e}"
|
||||
)
|
||||
|
||||
def _create_default_embedding_function(self):
|
||||
@staticmethod
|
||||
def _create_default_embedding_function() -> EmbeddingFunction[Any]:
|
||||
from chromadb.utils.embedding_functions.openai_embedding_function import (
|
||||
OpenAIEmbeddingFunction,
|
||||
)
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from typing import Any, Callable
|
||||
from collections.abc import Callable
|
||||
from typing import Any, Optional
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
@@ -13,7 +14,7 @@ class ConditionalTask(Task):
|
||||
Note: This cannot be the only task you have in your crew and cannot be the first since its needs context from the previous task.
|
||||
"""
|
||||
|
||||
condition: Callable[[TaskOutput], bool] = Field(
|
||||
condition: Optional[Callable[[TaskOutput], bool]] = Field(
|
||||
default=None,
|
||||
description="Maximum number of retries for an agent to execute a task when an error occurs.",
|
||||
)
|
||||
@@ -21,8 +22,8 @@ class ConditionalTask(Task):
|
||||
def __init__(
|
||||
self,
|
||||
condition: Callable[[Any], bool],
|
||||
**kwargs,
|
||||
):
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
super().__init__(**kwargs)
|
||||
self.condition = condition
|
||||
|
||||
@@ -36,9 +37,11 @@ class ConditionalTask(Task):
|
||||
Returns:
|
||||
bool: True if the task should be executed, False otherwise.
|
||||
"""
|
||||
if self.condition is None:
|
||||
return False
|
||||
return self.condition(context)
|
||||
|
||||
def get_skipped_task_output(self):
|
||||
def get_skipped_task_output(self) -> TaskOutput:
|
||||
return TaskOutput(
|
||||
description=self.description,
|
||||
raw="",
|
||||
|
||||
@@ -5,11 +5,12 @@ import json
|
||||
import logging
|
||||
import os
|
||||
import platform
|
||||
import threading
|
||||
import warnings
|
||||
from collections.abc import Callable
|
||||
from contextlib import contextmanager
|
||||
from importlib.metadata import version
|
||||
from typing import TYPE_CHECKING, Any, Callable, Optional
|
||||
import threading
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
|
||||
from opentelemetry import trace
|
||||
from opentelemetry.exporter.otlp.proto.http.trace_exporter import (
|
||||
@@ -32,7 +33,7 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def suppress_warnings():
|
||||
def suppress_warnings() -> Any:
|
||||
with warnings.catch_warnings():
|
||||
warnings.filterwarnings("ignore")
|
||||
yield
|
||||
@@ -44,7 +45,7 @@ if TYPE_CHECKING:
|
||||
|
||||
|
||||
class SafeOTLPSpanExporter(OTLPSpanExporter):
|
||||
def export(self, spans) -> SpanExportResult:
|
||||
def export(self, spans: Any) -> SpanExportResult:
|
||||
try:
|
||||
return super().export(spans)
|
||||
except Exception as e:
|
||||
@@ -68,18 +69,18 @@ class Telemetry:
|
||||
_instance = None
|
||||
_lock = threading.Lock()
|
||||
|
||||
def __new__(cls):
|
||||
def __new__(cls) -> Telemetry:
|
||||
if cls._instance is None:
|
||||
with cls._lock:
|
||||
if cls._instance is None:
|
||||
cls._instance = super(Telemetry, cls).__new__(cls)
|
||||
cls._instance = super().__new__(cls)
|
||||
cls._instance._initialized = False
|
||||
return cls._instance
|
||||
|
||||
def __init__(self) -> None:
|
||||
if hasattr(self, '_initialized') and self._initialized:
|
||||
if hasattr(self, "_initialized") and self._initialized:
|
||||
return
|
||||
|
||||
|
||||
self.ready: bool = False
|
||||
self.trace_set: bool = False
|
||||
self._initialized: bool = True
|
||||
@@ -124,7 +125,7 @@ class Telemetry:
|
||||
"""Check if telemetry operations should be executed."""
|
||||
return self.ready and not self._is_telemetry_disabled()
|
||||
|
||||
def set_tracer(self):
|
||||
def set_tracer(self) -> None:
|
||||
if self.ready and not self.trace_set:
|
||||
try:
|
||||
with suppress_warnings():
|
||||
@@ -143,10 +144,10 @@ class Telemetry:
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def crew_creation(self, crew: Crew, inputs: dict[str, Any] | None):
|
||||
def crew_creation(self, crew: Crew, inputs: dict[str, Any] | None) -> None:
|
||||
"""Records the creation of a crew."""
|
||||
|
||||
def operation():
|
||||
def operation() -> Any:
|
||||
tracer = trace.get_tracer("crewai.telemetry")
|
||||
span = tracer.start_span("Crew Created")
|
||||
self._add_attribute(
|
||||
@@ -351,7 +352,7 @@ class Telemetry:
|
||||
def task_started(self, crew: Crew, task: Task) -> Span | None:
|
||||
"""Records task started in a crew."""
|
||||
|
||||
def operation():
|
||||
def operation() -> Any:
|
||||
tracer = trace.get_tracer("crewai.telemetry")
|
||||
|
||||
created_span = tracer.start_span("Task Created")
|
||||
@@ -438,7 +439,7 @@ class Telemetry:
|
||||
self._safe_telemetry_operation(operation)
|
||||
return None
|
||||
|
||||
def task_ended(self, span: Span, task: Task, crew: Crew):
|
||||
def task_ended(self, span: Span, task: Task, crew: Crew) -> None:
|
||||
"""Records the completion of a task execution in a crew.
|
||||
|
||||
Args:
|
||||
@@ -450,7 +451,7 @@ class Telemetry:
|
||||
If share_crew is enabled, this will also record the task output
|
||||
"""
|
||||
|
||||
def operation():
|
||||
def operation() -> Any:
|
||||
# Ensure fingerprint data is present on completion span
|
||||
if hasattr(task, "fingerprint") and task.fingerprint:
|
||||
self._add_attribute(span, "task_fingerprint", task.fingerprint.uuid_str)
|
||||
@@ -467,7 +468,7 @@ class Telemetry:
|
||||
|
||||
self._safe_telemetry_operation(operation)
|
||||
|
||||
def tool_repeated_usage(self, llm: Any, tool_name: str, attempts: int):
|
||||
def tool_repeated_usage(self, llm: Any, tool_name: str, attempts: int) -> None:
|
||||
"""Records when a tool is used repeatedly, which might indicate an issue.
|
||||
|
||||
Args:
|
||||
@@ -476,7 +477,7 @@ class Telemetry:
|
||||
attempts (int): Number of attempts made with this tool
|
||||
"""
|
||||
|
||||
def operation():
|
||||
def operation() -> Any:
|
||||
tracer = trace.get_tracer("crewai.telemetry")
|
||||
span = tracer.start_span("Tool Repeated Usage")
|
||||
self._add_attribute(
|
||||
@@ -493,7 +494,9 @@ class Telemetry:
|
||||
|
||||
self._safe_telemetry_operation(operation)
|
||||
|
||||
def tool_usage(self, llm: Any, tool_name: str, attempts: int, agent: Any = None):
|
||||
def tool_usage(
|
||||
self, llm: Any, tool_name: str, attempts: int, agent: Any = None
|
||||
) -> None:
|
||||
"""Records the usage of a tool by an agent.
|
||||
|
||||
Args:
|
||||
@@ -503,7 +506,7 @@ class Telemetry:
|
||||
agent (Any, optional): The agent using the tool
|
||||
"""
|
||||
|
||||
def operation():
|
||||
def operation() -> Any:
|
||||
tracer = trace.get_tracer("crewai.telemetry")
|
||||
span = tracer.start_span("Tool Usage")
|
||||
self._add_attribute(
|
||||
@@ -531,7 +534,7 @@ class Telemetry:
|
||||
|
||||
def tool_usage_error(
|
||||
self, llm: Any, agent: Any = None, tool_name: Optional[str] = None
|
||||
):
|
||||
) -> None:
|
||||
"""Records when a tool usage results in an error.
|
||||
|
||||
Args:
|
||||
@@ -540,7 +543,7 @@ class Telemetry:
|
||||
tool_name (str, optional): Name of the tool that caused the error
|
||||
"""
|
||||
|
||||
def operation():
|
||||
def operation() -> Any:
|
||||
tracer = trace.get_tracer("crewai.telemetry")
|
||||
span = tracer.start_span("Tool Usage Error")
|
||||
self._add_attribute(
|
||||
@@ -569,7 +572,7 @@ class Telemetry:
|
||||
|
||||
def individual_test_result_span(
|
||||
self, crew: Crew, quality: float, exec_time: int, model_name: str
|
||||
):
|
||||
) -> None:
|
||||
"""Records individual test results for a crew execution.
|
||||
|
||||
Args:
|
||||
@@ -579,7 +582,7 @@ class Telemetry:
|
||||
model_name (str): Name of the model used
|
||||
"""
|
||||
|
||||
def operation():
|
||||
def operation() -> Any:
|
||||
tracer = trace.get_tracer("crewai.telemetry")
|
||||
span = tracer.start_span("Crew Individual Test Result")
|
||||
|
||||
@@ -604,7 +607,7 @@ class Telemetry:
|
||||
iterations: int,
|
||||
inputs: dict[str, Any] | None,
|
||||
model_name: str,
|
||||
):
|
||||
) -> None:
|
||||
"""Records the execution of a test suite for a crew.
|
||||
|
||||
Args:
|
||||
@@ -614,7 +617,7 @@ class Telemetry:
|
||||
model_name (str): Name of the model used in testing
|
||||
"""
|
||||
|
||||
def operation():
|
||||
def operation() -> Any:
|
||||
tracer = trace.get_tracer("crewai.telemetry")
|
||||
span = tracer.start_span("Crew Test Execution")
|
||||
|
||||
@@ -638,10 +641,10 @@ class Telemetry:
|
||||
|
||||
self._safe_telemetry_operation(operation)
|
||||
|
||||
def deploy_signup_error_span(self):
|
||||
def deploy_signup_error_span(self) -> None:
|
||||
"""Records when an error occurs during the deployment signup process."""
|
||||
|
||||
def operation():
|
||||
def operation() -> Any:
|
||||
tracer = trace.get_tracer("crewai.telemetry")
|
||||
span = tracer.start_span("Deploy Signup Error")
|
||||
span.set_status(Status(StatusCode.OK))
|
||||
@@ -649,14 +652,14 @@ class Telemetry:
|
||||
|
||||
self._safe_telemetry_operation(operation)
|
||||
|
||||
def start_deployment_span(self, uuid: Optional[str] = None):
|
||||
def start_deployment_span(self, uuid: Optional[str] = None) -> None:
|
||||
"""Records the start of a deployment process.
|
||||
|
||||
Args:
|
||||
uuid (Optional[str]): Unique identifier for the deployment
|
||||
"""
|
||||
|
||||
def operation():
|
||||
def operation() -> Any:
|
||||
tracer = trace.get_tracer("crewai.telemetry")
|
||||
span = tracer.start_span("Start Deployment")
|
||||
if uuid:
|
||||
@@ -666,10 +669,10 @@ class Telemetry:
|
||||
|
||||
self._safe_telemetry_operation(operation)
|
||||
|
||||
def create_crew_deployment_span(self):
|
||||
def create_crew_deployment_span(self) -> None:
|
||||
"""Records the creation of a new crew deployment."""
|
||||
|
||||
def operation():
|
||||
def operation() -> Any:
|
||||
tracer = trace.get_tracer("crewai.telemetry")
|
||||
span = tracer.start_span("Create Crew Deployment")
|
||||
span.set_status(Status(StatusCode.OK))
|
||||
@@ -677,7 +680,9 @@ class Telemetry:
|
||||
|
||||
self._safe_telemetry_operation(operation)
|
||||
|
||||
def get_crew_logs_span(self, uuid: Optional[str], log_type: str = "deployment"):
|
||||
def get_crew_logs_span(
|
||||
self, uuid: Optional[str], log_type: str = "deployment"
|
||||
) -> None:
|
||||
"""Records the retrieval of crew logs.
|
||||
|
||||
Args:
|
||||
@@ -685,7 +690,7 @@ class Telemetry:
|
||||
log_type (str, optional): Type of logs being retrieved. Defaults to "deployment".
|
||||
"""
|
||||
|
||||
def operation():
|
||||
def operation() -> Any:
|
||||
tracer = trace.get_tracer("crewai.telemetry")
|
||||
span = tracer.start_span("Get Crew Logs")
|
||||
self._add_attribute(span, "log_type", log_type)
|
||||
@@ -696,14 +701,14 @@ class Telemetry:
|
||||
|
||||
self._safe_telemetry_operation(operation)
|
||||
|
||||
def remove_crew_span(self, uuid: Optional[str] = None):
|
||||
def remove_crew_span(self, uuid: Optional[str] = None) -> None:
|
||||
"""Records the removal of a crew.
|
||||
|
||||
Args:
|
||||
uuid (Optional[str]): Unique identifier for the crew being removed
|
||||
"""
|
||||
|
||||
def operation():
|
||||
def operation() -> Any:
|
||||
tracer = trace.get_tracer("crewai.telemetry")
|
||||
span = tracer.start_span("Remove Crew")
|
||||
if uuid:
|
||||
@@ -713,13 +718,13 @@ class Telemetry:
|
||||
|
||||
self._safe_telemetry_operation(operation)
|
||||
|
||||
def crew_execution_span(self, crew: Crew, inputs: dict[str, Any] | None):
|
||||
def crew_execution_span(self, crew: Crew, inputs: dict[str, Any] | None) -> None:
|
||||
"""Records the complete execution of a crew.
|
||||
This is only collected if the user has opted-in to share the crew.
|
||||
"""
|
||||
self.crew_creation(crew, inputs)
|
||||
|
||||
def operation():
|
||||
def operation() -> Any:
|
||||
tracer = trace.get_tracer("crewai.telemetry")
|
||||
span = tracer.start_span("Crew Execution")
|
||||
self._add_attribute(
|
||||
@@ -787,11 +792,12 @@ class Telemetry:
|
||||
|
||||
if crew.share_crew:
|
||||
self._safe_telemetry_operation(operation)
|
||||
return operation()
|
||||
result = operation()
|
||||
return result # type: ignore[no-any-return]
|
||||
return None
|
||||
|
||||
def end_crew(self, crew, final_string_output):
|
||||
def operation():
|
||||
def end_crew(self, crew: Any, final_string_output: str) -> None:
|
||||
def operation() -> Any:
|
||||
self._add_attribute(
|
||||
crew._execution_span,
|
||||
"crewai_version",
|
||||
@@ -820,22 +826,22 @@ class Telemetry:
|
||||
if crew.share_crew:
|
||||
self._safe_telemetry_operation(operation)
|
||||
|
||||
def _add_attribute(self, span, key, value):
|
||||
def _add_attribute(self, span: Any, key: str, value: Any) -> None:
|
||||
"""Add an attribute to a span."""
|
||||
|
||||
def operation():
|
||||
def operation() -> Any:
|
||||
return span.set_attribute(key, value)
|
||||
|
||||
self._safe_telemetry_operation(operation)
|
||||
|
||||
def flow_creation_span(self, flow_name: str):
|
||||
def flow_creation_span(self, flow_name: str) -> None:
|
||||
"""Records the creation of a new flow.
|
||||
|
||||
Args:
|
||||
flow_name (str): Name of the flow being created
|
||||
"""
|
||||
|
||||
def operation():
|
||||
def operation() -> Any:
|
||||
tracer = trace.get_tracer("crewai.telemetry")
|
||||
span = tracer.start_span("Flow Creation")
|
||||
self._add_attribute(span, "flow_name", flow_name)
|
||||
@@ -844,7 +850,7 @@ class Telemetry:
|
||||
|
||||
self._safe_telemetry_operation(operation)
|
||||
|
||||
def flow_plotting_span(self, flow_name: str, node_names: list[str]):
|
||||
def flow_plotting_span(self, flow_name: str, node_names: list[str]) -> None:
|
||||
"""Records flow visualization/plotting activity.
|
||||
|
||||
Args:
|
||||
@@ -852,7 +858,7 @@ class Telemetry:
|
||||
node_names (list[str]): List of node names in the flow
|
||||
"""
|
||||
|
||||
def operation():
|
||||
def operation() -> Any:
|
||||
tracer = trace.get_tracer("crewai.telemetry")
|
||||
span = tracer.start_span("Flow Plotting")
|
||||
self._add_attribute(span, "flow_name", flow_name)
|
||||
@@ -862,7 +868,7 @@ class Telemetry:
|
||||
|
||||
self._safe_telemetry_operation(operation)
|
||||
|
||||
def flow_execution_span(self, flow_name: str, node_names: list[str]):
|
||||
def flow_execution_span(self, flow_name: str, node_names: list[str]) -> None:
|
||||
"""Records the execution of a flow.
|
||||
|
||||
Args:
|
||||
@@ -870,7 +876,7 @@ class Telemetry:
|
||||
node_names (list[str]): List of nodes being executed in the flow
|
||||
"""
|
||||
|
||||
def operation():
|
||||
def operation() -> Any:
|
||||
tracer = trace.get_tracer("crewai.telemetry")
|
||||
span = tracer.start_span("Flow Execution")
|
||||
self._add_attribute(span, "flow_name", flow_name)
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from crewai.agents.cache import CacheHandler
|
||||
from crewai.agents.cache.cache_handler import CacheHandler
|
||||
from crewai.tools.structured_tool import CrewStructuredTool
|
||||
|
||||
|
||||
@@ -13,15 +13,19 @@ class CacheTools(BaseModel):
|
||||
default_factory=CacheHandler,
|
||||
)
|
||||
|
||||
def tool(self):
|
||||
def tool(self) -> CrewStructuredTool:
|
||||
return CrewStructuredTool.from_function(
|
||||
func=self.hit_cache,
|
||||
name=self.name,
|
||||
description="Reads directly from the cache",
|
||||
)
|
||||
|
||||
def hit_cache(self, key):
|
||||
def hit_cache(self, key: str) -> str:
|
||||
import json
|
||||
|
||||
split = key.split("tool:")
|
||||
tool = split[1].split("|input:")[0].strip()
|
||||
tool_input = split[1].split("|input:")[1].strip()
|
||||
return self.cache_handler.read(tool, tool_input)
|
||||
tool_input_str = split[1].split("|input:")[1].strip()
|
||||
tool_input = json.loads(tool_input_str) if tool_input_str else None
|
||||
result = self.cache_handler.read(tool, tool_input)
|
||||
return result if result is not None else ""
|
||||
|
||||
@@ -5,12 +5,20 @@ import time
|
||||
from difflib import SequenceMatcher
|
||||
from json import JSONDecodeError
|
||||
from textwrap import dedent
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
|
||||
from typing import TYPE_CHECKING, Any, Optional, Union
|
||||
|
||||
import json5
|
||||
from json_repair import repair_json
|
||||
from json_repair import repair_json # type: ignore[import-untyped]
|
||||
|
||||
from crewai.agents.tools_handler import ToolsHandler
|
||||
from crewai.events.event_bus import crewai_event_bus
|
||||
from crewai.events.types.tool_usage_events import (
|
||||
ToolSelectionErrorEvent,
|
||||
ToolUsageErrorEvent,
|
||||
ToolUsageFinishedEvent,
|
||||
ToolUsageStartedEvent,
|
||||
ToolValidateInputErrorEvent,
|
||||
)
|
||||
from crewai.task import Task
|
||||
from crewai.telemetry import Telemetry
|
||||
from crewai.tools.structured_tool import CrewStructuredTool
|
||||
@@ -20,14 +28,6 @@ from crewai.utilities.agent_utils import (
|
||||
get_tool_names,
|
||||
render_text_description_and_args,
|
||||
)
|
||||
from crewai.events.event_bus import crewai_event_bus
|
||||
from crewai.events.types.tool_usage_events import (
|
||||
ToolSelectionErrorEvent,
|
||||
ToolUsageErrorEvent,
|
||||
ToolUsageFinishedEvent,
|
||||
ToolUsageStartedEvent,
|
||||
ToolValidateInputErrorEvent,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from crewai.agents.agent_builder.base_agent import BaseAgent
|
||||
@@ -69,12 +69,12 @@ class ToolUsage:
|
||||
def __init__(
|
||||
self,
|
||||
tools_handler: Optional[ToolsHandler],
|
||||
tools: List[CrewStructuredTool],
|
||||
tools: list[CrewStructuredTool],
|
||||
task: Optional[Task],
|
||||
function_calling_llm: Any,
|
||||
agent: Optional[Union["BaseAgent", "LiteAgent"]] = None,
|
||||
action: Any = None,
|
||||
fingerprint_context: Optional[Dict[str, str]] = None,
|
||||
fingerprint_context: Optional[dict[str, str]] = None,
|
||||
) -> None:
|
||||
self._i18n: I18N = agent.i18n if agent else I18N()
|
||||
self._printer: Printer = Printer()
|
||||
@@ -100,12 +100,12 @@ class ToolUsage:
|
||||
self._max_parsing_attempts = 2
|
||||
self._remember_format_after_usages = 4
|
||||
|
||||
def parse_tool_calling(self, tool_string: str):
|
||||
def parse_tool_calling(self, tool_string: str) -> Any:
|
||||
"""Parse the tool string and return the tool calling."""
|
||||
return self._tool_calling(tool_string)
|
||||
|
||||
def use(
|
||||
self, calling: Union[ToolCalling, InstructorToolCalling], tool_string: str
|
||||
self, calling: ToolCalling | InstructorToolCalling, tool_string: str
|
||||
) -> str:
|
||||
if isinstance(calling, ToolUsageErrorException):
|
||||
error = calling.message
|
||||
@@ -147,11 +147,25 @@ class ToolUsage:
|
||||
self,
|
||||
tool_string: str,
|
||||
tool: CrewStructuredTool,
|
||||
calling: Union[ToolCalling, InstructorToolCalling],
|
||||
calling: ToolCalling | InstructorToolCalling,
|
||||
) -> str:
|
||||
if self._check_tool_repeated_usage(calling=calling): # type: ignore # _check_tool_repeated_usage of "ToolUsage" does not return a value (it only ever returns None)
|
||||
"""Use a tool with the given calling information.
|
||||
|
||||
Args:
|
||||
tool_string: The string representation of the tool call.
|
||||
tool: The CrewStructuredTool instance to use.
|
||||
calling: The tool calling information.
|
||||
|
||||
Returns:
|
||||
The formatted result of the tool usage.
|
||||
|
||||
Notes:
|
||||
TODO: Investigate why BaseAgent/LiteAgent don't have fingerprint attribute.
|
||||
Currently using hasattr check as a workaround (lines 179-180).
|
||||
"""
|
||||
if self._check_tool_repeated_usage(calling=calling):
|
||||
try:
|
||||
result = self._i18n.errors("task_repeated_usage").format(
|
||||
repeated_usage_msg = self._i18n.errors("task_repeated_usage").format(
|
||||
tool_names=self.tools_names
|
||||
)
|
||||
self._telemetry.tool_repeated_usage(
|
||||
@@ -159,8 +173,8 @@ class ToolUsage:
|
||||
tool_name=tool.name,
|
||||
attempts=self._run_attempts,
|
||||
)
|
||||
result = self._format_result(result=result) # type: ignore # "_format_result" of "ToolUsage" does not return a value (it only ever returns None)
|
||||
return result # type: ignore # Fix the return type of this function
|
||||
repeated_usage_result = self._format_result(result=repeated_usage_msg)
|
||||
return repeated_usage_result
|
||||
|
||||
except Exception:
|
||||
if self.task:
|
||||
@@ -176,7 +190,7 @@ class ToolUsage:
|
||||
"agent": self.agent,
|
||||
}
|
||||
|
||||
if self.agent.fingerprint:
|
||||
if hasattr(self.agent, "fingerprint") and self.agent.fingerprint:
|
||||
event_data.update(self.agent.fingerprint)
|
||||
if self.task:
|
||||
event_data["task_name"] = self.task.name or self.task.description
|
||||
@@ -185,12 +199,12 @@ class ToolUsage:
|
||||
|
||||
started_at = time.time()
|
||||
from_cache = False
|
||||
result = None # type: ignore
|
||||
result: str | None = None
|
||||
|
||||
if self.tools_handler and self.tools_handler.cache:
|
||||
result = self.tools_handler.cache.read(
|
||||
tool=calling.tool_name, input=calling.arguments
|
||||
) # type: ignore
|
||||
tool=calling.tool_name, input_data=calling.arguments
|
||||
)
|
||||
from_cache = result is not None
|
||||
|
||||
available_tool = next(
|
||||
@@ -229,7 +243,7 @@ class ToolUsage:
|
||||
try:
|
||||
acceptable_args = tool.args_schema.model_json_schema()[
|
||||
"properties"
|
||||
].keys() # type: ignore
|
||||
].keys()
|
||||
arguments = {
|
||||
k: v
|
||||
for k, v in calling.arguments.items()
|
||||
@@ -264,19 +278,20 @@ class ToolUsage:
|
||||
self._printer.print(
|
||||
content=f"\n\n{error_message}\n", color="red"
|
||||
)
|
||||
return error # type: ignore # No return value expected
|
||||
return error
|
||||
|
||||
if self.task:
|
||||
self.task.increment_tools_errors()
|
||||
return self.use(calling=calling, tool_string=tool_string) # type: ignore # No return value expected
|
||||
return self.use(calling=calling, tool_string=tool_string)
|
||||
|
||||
if self.tools_handler:
|
||||
should_cache = True
|
||||
if (
|
||||
hasattr(available_tool, "cache_function")
|
||||
and available_tool.cache_function # type: ignore # Item "None" of "Any | None" has no attribute "cache_function"
|
||||
available_tool
|
||||
and hasattr(available_tool, "cache_function")
|
||||
and available_tool.cache_function
|
||||
):
|
||||
should_cache = available_tool.cache_function( # type: ignore # Item "None" of "Any | None" has no attribute "cache_function"
|
||||
should_cache = available_tool.cache_function(
|
||||
calling.arguments, result
|
||||
)
|
||||
|
||||
@@ -288,8 +303,8 @@ class ToolUsage:
|
||||
tool_name=tool.name,
|
||||
attempts=self._run_attempts,
|
||||
)
|
||||
result = self._format_result(result=result) # type: ignore # "_format_result" of "ToolUsage" does not return a value (it only ever returns None)
|
||||
data = {
|
||||
result = self._format_result(result=result)
|
||||
data: dict[str, Any] = {
|
||||
"result": result,
|
||||
"tool_name": tool.name,
|
||||
"tool_args": calling.arguments,
|
||||
@@ -308,7 +323,7 @@ class ToolUsage:
|
||||
and available_tool.result_as_answer # type: ignore # Item "None" of "Any | None" has no attribute "cache_function"
|
||||
):
|
||||
result_as_answer = available_tool.result_as_answer # type: ignore # Item "None" of "Any | None" has no attribute "result_as_answer"
|
||||
data["result_as_answer"] = result_as_answer # type: ignore
|
||||
data["result_as_answer"] = result_as_answer
|
||||
|
||||
if self.agent and hasattr(self.agent, "tools_results"):
|
||||
self.agent.tools_results.append(data)
|
||||
@@ -346,7 +361,7 @@ class ToolUsage:
|
||||
return result
|
||||
|
||||
def _check_tool_repeated_usage(
|
||||
self, calling: Union[ToolCalling, InstructorToolCalling]
|
||||
self, calling: ToolCalling | InstructorToolCalling
|
||||
) -> bool:
|
||||
if not self.tools_handler:
|
||||
return False
|
||||
@@ -393,7 +408,7 @@ class ToolUsage:
|
||||
return tool
|
||||
if self.task:
|
||||
self.task.increment_tools_errors()
|
||||
tool_selection_data: Dict[str, Any] = {
|
||||
tool_selection_data: dict[str, Any] = {
|
||||
"agent_key": getattr(self.agent, "key", None) if self.agent else None,
|
||||
"agent_role": getattr(self.agent, "role", None) if self.agent else None,
|
||||
"tool_name": tool_name,
|
||||
@@ -430,7 +445,7 @@ class ToolUsage:
|
||||
|
||||
def _function_calling(
|
||||
self, tool_string: str
|
||||
) -> Union[ToolCalling, InstructorToolCalling]:
|
||||
) -> ToolCalling | InstructorToolCalling:
|
||||
model = (
|
||||
InstructorToolCalling
|
||||
if self.function_calling_llm.supports_function_calling()
|
||||
@@ -459,7 +474,7 @@ class ToolUsage:
|
||||
|
||||
def _original_tool_calling(
|
||||
self, tool_string: str, raise_error: bool = False
|
||||
) -> Union[ToolCalling, InstructorToolCalling, ToolUsageErrorException]:
|
||||
) -> ToolCalling | InstructorToolCalling | ToolUsageErrorException:
|
||||
tool_name = self.action.tool
|
||||
tool = self._select_tool(tool_name)
|
||||
try:
|
||||
@@ -488,7 +503,7 @@ class ToolUsage:
|
||||
|
||||
def _tool_calling(
|
||||
self, tool_string: str
|
||||
) -> Union[ToolCalling, InstructorToolCalling, ToolUsageErrorException]:
|
||||
) -> ToolCalling | InstructorToolCalling | ToolUsageErrorException:
|
||||
try:
|
||||
try:
|
||||
return self._original_tool_calling(tool_string, raise_error=True)
|
||||
@@ -505,12 +520,12 @@ class ToolUsage:
|
||||
self.task.increment_tools_errors()
|
||||
if self.agent and self.agent.verbose:
|
||||
self._printer.print(content=f"\n\n{e}\n", color="red")
|
||||
return ToolUsageErrorException( # type: ignore # Incompatible return value type (got "ToolUsageErrorException", expected "ToolCalling | InstructorToolCalling")
|
||||
return ToolUsageErrorException(
|
||||
f"{self._i18n.errors('tool_usage_error').format(error=e)}\nMoving on then. {self._i18n.slice('format').format(tool_names=self.tools_names)}"
|
||||
)
|
||||
return self._tool_calling(tool_string)
|
||||
|
||||
def _validate_tool_input(self, tool_input: Optional[str]) -> Dict[str, Any]:
|
||||
def _validate_tool_input(self, tool_input: Optional[str]) -> dict[str, Any]:
|
||||
if tool_input is None:
|
||||
return {}
|
||||
|
||||
@@ -564,7 +579,7 @@ class ToolUsage:
|
||||
# If all parsing attempts fail, raise an error
|
||||
raise Exception(error_message)
|
||||
|
||||
def _emit_validate_input_error(self, final_error: str):
|
||||
def _emit_validate_input_error(self, final_error: str) -> None:
|
||||
tool_selection_data = {
|
||||
"agent_key": getattr(self.agent, "key", None) if self.agent else None,
|
||||
"agent_role": getattr(self.agent, "role", None) if self.agent else None,
|
||||
@@ -586,7 +601,7 @@ class ToolUsage:
|
||||
def on_tool_error(
|
||||
self,
|
||||
tool: Any,
|
||||
tool_calling: Union[ToolCalling, InstructorToolCalling],
|
||||
tool_calling: ToolCalling | InstructorToolCalling,
|
||||
e: Exception,
|
||||
) -> None:
|
||||
event_data = self._prepare_event_data(tool, tool_calling)
|
||||
@@ -595,7 +610,7 @@ class ToolUsage:
|
||||
def on_tool_use_finished(
|
||||
self,
|
||||
tool: Any,
|
||||
tool_calling: Union[ToolCalling, InstructorToolCalling],
|
||||
tool_calling: ToolCalling | InstructorToolCalling,
|
||||
from_cache: bool,
|
||||
started_at: float,
|
||||
result: Any,
|
||||
@@ -616,8 +631,21 @@ class ToolUsage:
|
||||
crewai_event_bus.emit(self, ToolUsageFinishedEvent(**event_data))
|
||||
|
||||
def _prepare_event_data(
|
||||
self, tool: Any, tool_calling: Union[ToolCalling, InstructorToolCalling]
|
||||
) -> dict:
|
||||
self, tool: Any, tool_calling: ToolCalling | InstructorToolCalling
|
||||
) -> dict[str, Any]:
|
||||
"""Prepare event data for tool usage events.
|
||||
|
||||
Args:
|
||||
tool: The tool being used.
|
||||
tool_calling: The tool calling information containing arguments.
|
||||
|
||||
Returns:
|
||||
A dictionary containing event data for tool usage tracking.
|
||||
|
||||
Notes:
|
||||
TODO: Create a better type representation for the return value,
|
||||
possibly using TypedDict or a dataclass for stronger typing.
|
||||
"""
|
||||
event_data = {
|
||||
"run_attempts": self._run_attempts,
|
||||
"delegations": self.task.delegations if self.task else 0,
|
||||
@@ -641,7 +669,7 @@ class ToolUsage:
|
||||
|
||||
return event_data
|
||||
|
||||
def _add_fingerprint_metadata(self, arguments: dict) -> dict:
|
||||
def _add_fingerprint_metadata(self, arguments: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Add fingerprint metadata to tool arguments if available.
|
||||
|
||||
Args:
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
from typing import List
|
||||
from typing import Any, cast
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from crewai.utilities import Converter
|
||||
from crewai.events.types.task_events import TaskEvaluationEvent
|
||||
from crewai.events.event_bus import crewai_event_bus
|
||||
from crewai.events.types.task_events import TaskEvaluationEvent
|
||||
from crewai.utilities import Converter
|
||||
from crewai.utilities.pydantic_schema_parser import PydanticSchemaParser
|
||||
from crewai.utilities.training_converter import TrainingConverter
|
||||
|
||||
@@ -13,23 +13,23 @@ class Entity(BaseModel):
|
||||
name: str = Field(description="The name of the entity.")
|
||||
type: str = Field(description="The type of the entity.")
|
||||
description: str = Field(description="Description of the entity.")
|
||||
relationships: List[str] = Field(description="Relationships of the entity.")
|
||||
relationships: list[str] = Field(description="Relationships of the entity.")
|
||||
|
||||
|
||||
class TaskEvaluation(BaseModel):
|
||||
suggestions: List[str] = Field(
|
||||
suggestions: list[str] = Field(
|
||||
description="Suggestions to improve future similar tasks."
|
||||
)
|
||||
quality: float = Field(
|
||||
description="A score from 0 to 10 evaluating on completion, quality, and overall performance, all taking into account the task description, expected output, and the result of the task."
|
||||
)
|
||||
entities: List[Entity] = Field(
|
||||
entities: list[Entity] = Field(
|
||||
description="Entities extracted from the task output."
|
||||
)
|
||||
|
||||
|
||||
class TrainingTaskEvaluation(BaseModel):
|
||||
suggestions: List[str] = Field(
|
||||
suggestions: list[str] = Field(
|
||||
description="List of clear, actionable instructions derived from the Human Feedbacks to enhance the Agent's performance. Analyze the differences between Initial Outputs and Improved Outputs to generate specific action items for future tasks. Ensure all key and specific points from the human feedback are incorporated into these instructions."
|
||||
)
|
||||
quality: float = Field(
|
||||
@@ -41,11 +41,11 @@ class TrainingTaskEvaluation(BaseModel):
|
||||
|
||||
|
||||
class TaskEvaluator:
|
||||
def __init__(self, original_agent):
|
||||
def __init__(self, original_agent: Any) -> None:
|
||||
self.llm = original_agent.llm
|
||||
self.original_agent = original_agent
|
||||
|
||||
def evaluate(self, task, output) -> TaskEvaluation:
|
||||
def evaluate(self, task: Any, output: Any) -> TaskEvaluation:
|
||||
crewai_event_bus.emit(
|
||||
self, TaskEvaluationEvent(evaluation_type="task_evaluation", task=task)
|
||||
)
|
||||
@@ -73,10 +73,10 @@ class TaskEvaluator:
|
||||
instructions=instructions,
|
||||
)
|
||||
|
||||
return converter.to_pydantic()
|
||||
return cast(TaskEvaluation, converter.to_pydantic())
|
||||
|
||||
def evaluate_training_data(
|
||||
self, training_data: dict, agent_id: str
|
||||
self, training_data: dict[str, Any], agent_id: str
|
||||
) -> TrainingTaskEvaluation:
|
||||
"""
|
||||
Evaluate the training data based on the llm output, human feedback, and improved output.
|
||||
@@ -143,4 +143,4 @@ class TaskEvaluator:
|
||||
)
|
||||
|
||||
pydantic_result = converter.to_pydantic()
|
||||
return pydantic_result
|
||||
return cast(TrainingTaskEvaluation, pydantic_result)
|
||||
|
||||
@@ -2,33 +2,39 @@ import json
|
||||
import os
|
||||
import pickle
|
||||
from datetime import datetime
|
||||
from typing import Union
|
||||
from typing import Any, Union
|
||||
|
||||
|
||||
class FileHandler:
|
||||
"""Handler for file operations supporting both JSON and text-based logging.
|
||||
|
||||
|
||||
Args:
|
||||
file_path (Union[bool, str]): Path to the log file or boolean flag
|
||||
"""
|
||||
|
||||
def __init__(self, file_path: Union[bool, str]):
|
||||
def __init__(self, file_path: bool | str):
|
||||
self._initialize_path(file_path)
|
||||
|
||||
def _initialize_path(self, file_path: Union[bool, str]):
|
||||
|
||||
def _initialize_path(self, file_path: bool | str) -> None:
|
||||
if file_path is True: # File path is boolean True
|
||||
self._path = os.path.join(os.curdir, "logs.txt")
|
||||
|
||||
|
||||
elif isinstance(file_path, str): # File path is a string
|
||||
if file_path.endswith((".json", ".txt")):
|
||||
self._path = file_path # No modification if the file ends with .json or .txt
|
||||
self._path = (
|
||||
file_path # No modification if the file ends with .json or .txt
|
||||
)
|
||||
else:
|
||||
self._path = file_path + ".txt" # Append .txt if the file doesn't end with .json or .txt
|
||||
|
||||
self._path = (
|
||||
file_path + ".txt"
|
||||
) # Append .txt if the file doesn't end with .json or .txt
|
||||
|
||||
else:
|
||||
raise ValueError("file_path must be a string or boolean.") # Handle the case where file_path isn't valid
|
||||
|
||||
def log(self, **kwargs):
|
||||
raise ValueError(
|
||||
"file_path must be a string or boolean."
|
||||
) # Handle the case where file_path isn't valid
|
||||
|
||||
def log(self, **kwargs: Any) -> None:
|
||||
try:
|
||||
now = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
log_entry = {"timestamp": now, **kwargs}
|
||||
@@ -45,20 +51,25 @@ class FileHandler:
|
||||
except (json.JSONDecodeError, FileNotFoundError):
|
||||
# If no valid JSON or file doesn't exist, start with an empty list
|
||||
existing_data = [log_entry]
|
||||
|
||||
|
||||
with open(self._path, "w", encoding="utf-8") as write_file:
|
||||
json.dump(existing_data, write_file, indent=4)
|
||||
write_file.write("\n")
|
||||
|
||||
|
||||
else:
|
||||
# Append log in plain text format
|
||||
message = f"{now}: " + ", ".join([f"{key}=\"{value}\"" for key, value in kwargs.items()]) + "\n"
|
||||
message = (
|
||||
f"{now}: "
|
||||
+ ", ".join([f'{key}="{value}"' for key, value in kwargs.items()])
|
||||
+ "\n"
|
||||
)
|
||||
with open(self._path, "a", encoding="utf-8") as file:
|
||||
file.write(message)
|
||||
|
||||
except Exception as e:
|
||||
raise ValueError(f"Failed to log message: {str(e)}")
|
||||
|
||||
|
||||
|
||||
class PickleHandler:
|
||||
def __init__(self, file_name: str) -> None:
|
||||
"""
|
||||
@@ -79,7 +90,7 @@ class PickleHandler:
|
||||
"""
|
||||
self.save({})
|
||||
|
||||
def save(self, data) -> None:
|
||||
def save(self, data: Any) -> None:
|
||||
"""
|
||||
Save the data to the specified file using pickle.
|
||||
|
||||
@@ -89,12 +100,12 @@ class PickleHandler:
|
||||
with open(self.file_path, "wb") as file:
|
||||
pickle.dump(data, file)
|
||||
|
||||
def load(self) -> dict:
|
||||
def load(self) -> Any:
|
||||
"""
|
||||
Load the data from the specified file using pickle.
|
||||
|
||||
Returns:
|
||||
- dict: The data loaded from the file.
|
||||
- The data loaded from the file.
|
||||
"""
|
||||
if not os.path.exists(self.file_path) or os.path.getsize(self.file_path) == 0:
|
||||
return {} # Return an empty dictionary if the file does not exist or is empty
|
||||
|
||||
@@ -1,78 +1,105 @@
|
||||
import os
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
from typing import Any, Protocol, TypedDict, runtime_checkable
|
||||
|
||||
from typing_extensions import Required
|
||||
|
||||
from crewai.cli.constants import DEFAULT_LLM_MODEL, ENV_VARS, LITELLM_PARAMS
|
||||
from crewai.llm import LLM, BaseLLM
|
||||
from crewai.llm import LLM
|
||||
from crewai.llms.base_llm import BaseLLM
|
||||
|
||||
|
||||
class LLMParams(TypedDict, total=False):
|
||||
"""TypedDict defining LLM parameters we extract from LLMLike objects."""
|
||||
|
||||
model: Required[str]
|
||||
temperature: float
|
||||
max_tokens: int
|
||||
logprobs: int
|
||||
timeout: float
|
||||
api_key: str
|
||||
base_url: str
|
||||
api_base: str
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class LLMLike(Protocol):
|
||||
"""Protocol for objects that can be converted to an LLM instance."""
|
||||
|
||||
model: str | None
|
||||
model_name: str | None
|
||||
deployment_name: str | None
|
||||
temperature: float | None
|
||||
max_tokens: int | None
|
||||
logprobs: int | None
|
||||
timeout: float | None
|
||||
api_key: str | None
|
||||
base_url: str | None
|
||||
api_base: str | None
|
||||
|
||||
|
||||
def create_default_llm() -> LLM:
|
||||
"""Creates a default LLM instance using environment variables or fallback defaults.
|
||||
|
||||
Returns:
|
||||
A default LLM instance configured from environment or using defaults.
|
||||
|
||||
Raises:
|
||||
ValueError: If LLM creation fails.
|
||||
"""
|
||||
result = _llm_via_environment_or_fallback()
|
||||
if result is None:
|
||||
raise ValueError("Failed to create default LLM instance")
|
||||
return result
|
||||
|
||||
|
||||
def create_llm(
|
||||
llm_value: Union[str, LLM, Any, None] = None,
|
||||
) -> Optional[LLM | BaseLLM]:
|
||||
llm_value: str | BaseLLM | LLMLike,
|
||||
) -> BaseLLM:
|
||||
"""
|
||||
Creates or returns an LLM instance based on the given llm_value.
|
||||
|
||||
Args:
|
||||
llm_value (str | BaseLLM | Any | None):
|
||||
llm_value (str | BaseLLM | LLMLike):
|
||||
- str: The model name (e.g., "gpt-4").
|
||||
- BaseLLM: Already instantiated BaseLLM (including LLM), returned as-is.
|
||||
- Any: Attempt to extract known attributes like model_name, temperature, etc.
|
||||
- None: Use environment-based or fallback default model.
|
||||
- LLMLike: Object with LLM-compatible attributes (model_name, temperature, etc.)
|
||||
|
||||
Returns:
|
||||
A BaseLLM instance if successful, or None if something fails.
|
||||
"""
|
||||
A BaseLLM instance.
|
||||
|
||||
# 1) If llm_value is already a BaseLLM or LLM object, return it directly
|
||||
if isinstance(llm_value, LLM) or isinstance(llm_value, BaseLLM):
|
||||
Raises:
|
||||
ValueError: If LLM creation fails.
|
||||
"""
|
||||
if isinstance(llm_value, BaseLLM):
|
||||
return llm_value
|
||||
|
||||
# 2) If llm_value is a string (model name)
|
||||
if isinstance(llm_value, str):
|
||||
try:
|
||||
created_llm = LLM(model=llm_value)
|
||||
return created_llm
|
||||
except Exception as e:
|
||||
print(f"Failed to instantiate LLM with model='{llm_value}': {e}")
|
||||
return None
|
||||
return LLM(model=llm_value)
|
||||
|
||||
# 3) If llm_value is None, parse environment variables or use default
|
||||
if llm_value is None:
|
||||
return _llm_via_environment_or_fallback()
|
||||
|
||||
# 4) Otherwise, attempt to extract relevant attributes from an unknown object
|
||||
try:
|
||||
# Extract attributes with explicit types
|
||||
model = (
|
||||
obj_attrs = set(dir(llm_value))
|
||||
|
||||
llm_kwargs = {
|
||||
param: getattr(llm_value, param)
|
||||
for param in LLMParams.__annotations__
|
||||
if param != "model"
|
||||
and param in obj_attrs
|
||||
and getattr(llm_value, param) is not None
|
||||
}
|
||||
|
||||
llm_kwargs["model"] = (
|
||||
getattr(llm_value, "model", None)
|
||||
or getattr(llm_value, "model_name", None)
|
||||
or getattr(llm_value, "deployment_name", None)
|
||||
or str(llm_value)
|
||||
)
|
||||
temperature: Optional[float] = getattr(llm_value, "temperature", None)
|
||||
max_tokens: Optional[int] = getattr(llm_value, "max_tokens", None)
|
||||
logprobs: Optional[int] = getattr(llm_value, "logprobs", None)
|
||||
timeout: Optional[float] = getattr(llm_value, "timeout", None)
|
||||
api_key: Optional[str] = getattr(llm_value, "api_key", None)
|
||||
base_url: Optional[str] = getattr(llm_value, "base_url", None)
|
||||
api_base: Optional[str] = getattr(llm_value, "api_base", None)
|
||||
|
||||
created_llm = LLM(
|
||||
model=model,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
logprobs=logprobs,
|
||||
timeout=timeout,
|
||||
api_key=api_key,
|
||||
base_url=base_url,
|
||||
api_base=api_base,
|
||||
)
|
||||
return created_llm
|
||||
return LLM(**llm_kwargs)
|
||||
except Exception as e:
|
||||
print(f"Error instantiating LLM from unknown object type: {e}")
|
||||
return None
|
||||
raise ValueError(f"Error instantiating LLM from object: {e}")
|
||||
|
||||
|
||||
def _llm_via_environment_or_fallback() -> Optional[LLM]:
|
||||
def _llm_via_environment_or_fallback() -> LLM | None:
|
||||
"""
|
||||
Helper function: if llm_value is None, we load environment variables or fallback default model.
|
||||
"""
|
||||
@@ -85,24 +112,24 @@ def _llm_via_environment_or_fallback() -> Optional[LLM]:
|
||||
|
||||
# Initialize parameters with correct types
|
||||
model: str = model_name
|
||||
temperature: Optional[float] = None
|
||||
max_tokens: Optional[int] = None
|
||||
max_completion_tokens: Optional[int] = None
|
||||
logprobs: Optional[int] = None
|
||||
timeout: Optional[float] = None
|
||||
api_key: Optional[str] = None
|
||||
base_url: Optional[str] = None
|
||||
api_version: Optional[str] = None
|
||||
presence_penalty: Optional[float] = None
|
||||
frequency_penalty: Optional[float] = None
|
||||
top_p: Optional[float] = None
|
||||
n: Optional[int] = None
|
||||
stop: Optional[Union[str, List[str]]] = None
|
||||
logit_bias: Optional[Dict[int, float]] = None
|
||||
response_format: Optional[Dict[str, Any]] = None
|
||||
seed: Optional[int] = None
|
||||
top_logprobs: Optional[int] = None
|
||||
callbacks: List[Any] = []
|
||||
temperature: float | None = None
|
||||
max_tokens: int | None = None
|
||||
max_completion_tokens: int | None = None
|
||||
logprobs: int | None = None
|
||||
timeout: float | None = None
|
||||
api_key: str | None = None
|
||||
base_url: str | None = None
|
||||
api_version: str | None = None
|
||||
presence_penalty: float | None = None
|
||||
frequency_penalty: float | None = None
|
||||
top_p: float | None = None
|
||||
n: int | None = None
|
||||
stop: str | list[str] | None = None
|
||||
logit_bias: dict[int, float] | None = None
|
||||
response_format: dict[str, Any] | None = None
|
||||
seed: int | None = None
|
||||
top_logprobs: int | None = None
|
||||
callbacks: list[Any] = []
|
||||
|
||||
# Optional base URL from env
|
||||
base_url = (
|
||||
@@ -120,7 +147,7 @@ def _llm_via_environment_or_fallback() -> Optional[LLM]:
|
||||
base_url = api_base
|
||||
|
||||
# Initialize llm_params dictionary
|
||||
llm_params: Dict[str, Any] = {
|
||||
llm_params: dict[str, Any] = {
|
||||
"model": model,
|
||||
"temperature": temperature,
|
||||
"max_tokens": max_tokens,
|
||||
|
||||
@@ -10,7 +10,7 @@ class Logger(BaseModel):
|
||||
_printer: Printer = PrivateAttr(default_factory=Printer)
|
||||
default_color: str = Field(default="bold_yellow")
|
||||
|
||||
def log(self, level, message, color=None):
|
||||
def log(self, level: str, message: str, color: str | None = None) -> None:
|
||||
if color is None:
|
||||
color = self.default_color
|
||||
if self.verbose:
|
||||
|
||||
@@ -20,18 +20,18 @@ class RPMController(BaseModel):
|
||||
_shutdown_flag: bool = PrivateAttr(default=False)
|
||||
|
||||
@model_validator(mode="after")
|
||||
def reset_counter(self):
|
||||
def reset_counter(self) -> "RPMController":
|
||||
if self.max_rpm is not None:
|
||||
if not self._shutdown_flag:
|
||||
self._lock = threading.Lock()
|
||||
self._reset_request_count()
|
||||
return self
|
||||
|
||||
def check_or_wait(self):
|
||||
def check_or_wait(self) -> bool:
|
||||
if self.max_rpm is None:
|
||||
return True
|
||||
|
||||
def _check_and_increment():
|
||||
def _check_and_increment() -> bool:
|
||||
if self.max_rpm is not None and self._current_rpm < self.max_rpm:
|
||||
self._current_rpm += 1
|
||||
return True
|
||||
@@ -50,17 +50,17 @@ class RPMController(BaseModel):
|
||||
else:
|
||||
return _check_and_increment()
|
||||
|
||||
def stop_rpm_counter(self):
|
||||
def stop_rpm_counter(self) -> None:
|
||||
if self._timer:
|
||||
self._timer.cancel()
|
||||
self._timer = None
|
||||
|
||||
def _wait_for_next_minute(self):
|
||||
def _wait_for_next_minute(self) -> None:
|
||||
time.sleep(60)
|
||||
self._current_rpm = 0
|
||||
|
||||
def _reset_request_count(self):
|
||||
def _reset():
|
||||
def _reset_request_count(self) -> None:
|
||||
def _reset() -> None:
|
||||
self._current_rpm = 0
|
||||
if not self._shutdown_flag:
|
||||
self._timer = threading.Timer(60.0, self._reset_request_count)
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any, Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
@@ -16,10 +16,10 @@ class ExecutionLog(BaseModel):
|
||||
|
||||
task_id: str
|
||||
expected_output: Optional[str] = None
|
||||
output: Dict[str, Any]
|
||||
output: dict[str, Any]
|
||||
timestamp: datetime = Field(default_factory=datetime.now)
|
||||
task_index: int
|
||||
inputs: Dict[str, Any] = Field(default_factory=dict)
|
||||
inputs: dict[str, Any] = Field(default_factory=dict)
|
||||
was_replayed: bool = False
|
||||
|
||||
def __getitem__(self, key: str) -> Any:
|
||||
@@ -33,7 +33,7 @@ class TaskOutputStorageHandler:
|
||||
def __init__(self) -> None:
|
||||
self.storage = KickoffTaskOutputsSQLiteStorage()
|
||||
|
||||
def update(self, task_index: int, log: Dict[str, Any]):
|
||||
def update(self, task_index: int, log: dict[str, Any]) -> None:
|
||||
saved_outputs = self.load()
|
||||
if saved_outputs is None:
|
||||
raise ValueError("Logs cannot be None")
|
||||
@@ -56,16 +56,16 @@ class TaskOutputStorageHandler:
|
||||
def add(
|
||||
self,
|
||||
task: Task,
|
||||
output: Dict[str, Any],
|
||||
output: dict[str, Any],
|
||||
task_index: int,
|
||||
inputs: Dict[str, Any] | None = None,
|
||||
inputs: dict[str, Any] | None = None,
|
||||
was_replayed: bool = False,
|
||||
):
|
||||
) -> None:
|
||||
inputs = inputs or {}
|
||||
self.storage.add(task, output, task_index, was_replayed, inputs)
|
||||
|
||||
def reset(self):
|
||||
def reset(self) -> None:
|
||||
self.storage.delete_all()
|
||||
|
||||
def load(self) -> Optional[List[Dict[str, Any]]]:
|
||||
def load(self) -> Optional[list[dict[str, Any]]]:
|
||||
return self.storage.load()
|
||||
|
||||
@@ -7,21 +7,21 @@ from unittest.mock import MagicMock, patch
|
||||
import pytest
|
||||
|
||||
from crewai import Agent, Crew, Task
|
||||
from crewai.agents.cache import CacheHandler
|
||||
from crewai.agents.cache.cache_handler import CacheHandler
|
||||
from crewai.agents.crew_agent_executor import AgentFinish, CrewAgentExecutor
|
||||
from crewai.events.event_bus import crewai_event_bus
|
||||
from crewai.events.types.tool_usage_events import ToolUsageFinishedEvent
|
||||
from crewai.knowledge.knowledge import Knowledge
|
||||
from crewai.knowledge.knowledge_config import KnowledgeConfig
|
||||
from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource
|
||||
from crewai.knowledge.source.string_knowledge_source import StringKnowledgeSource
|
||||
from crewai.llm import LLM
|
||||
from crewai.process import Process
|
||||
from crewai.tools import tool
|
||||
from crewai.tools.tool_calling import InstructorToolCalling
|
||||
from crewai.tools.tool_usage import ToolUsage
|
||||
from crewai.utilities import RPMController
|
||||
from crewai.utilities.errors import AgentRepositoryError
|
||||
from crewai.events.event_bus import crewai_event_bus
|
||||
from crewai.events.types.tool_usage_events import ToolUsageFinishedEvent
|
||||
from crewai.process import Process
|
||||
|
||||
|
||||
def test_agent_llm_creation_with_env_vars():
|
||||
@@ -287,8 +287,8 @@ def test_cache_hitting():
|
||||
output = agent.execute_task(task1)
|
||||
output = agent.execute_task(task2)
|
||||
assert cache_handler._cache == {
|
||||
"multiplier-{'first_number': 2, 'second_number': 6}": 12,
|
||||
"multiplier-{'first_number': 3, 'second_number': 3}": 9,
|
||||
'multiplier-{"first_number": 2, "second_number": 6}': 12,
|
||||
'multiplier-{"first_number": 3, "second_number": 3}': 9,
|
||||
}
|
||||
|
||||
task = Task(
|
||||
@@ -300,9 +300,9 @@ def test_cache_hitting():
|
||||
assert output == "36"
|
||||
|
||||
assert cache_handler._cache == {
|
||||
"multiplier-{'first_number': 2, 'second_number': 6}": 12,
|
||||
"multiplier-{'first_number': 3, 'second_number': 3}": 9,
|
||||
"multiplier-{'first_number': 12, 'second_number': 3}": 36,
|
||||
'multiplier-{"first_number": 2, "second_number": 6}': 12,
|
||||
'multiplier-{"first_number": 3, "second_number": 3}': 9,
|
||||
'multiplier-{"first_number": 12, "second_number": 3}': 36,
|
||||
}
|
||||
received_events = []
|
||||
|
||||
@@ -322,7 +322,7 @@ def test_cache_hitting():
|
||||
output = agent.execute_task(task)
|
||||
assert output == "0"
|
||||
read.assert_called_with(
|
||||
tool="multiplier", input={"first_number": 2, "second_number": 6}
|
||||
tool="multiplier", input_data={"first_number": 2, "second_number": 6}
|
||||
)
|
||||
assert len(received_events) == 1
|
||||
assert isinstance(received_events[0], ToolUsageFinishedEvent)
|
||||
@@ -559,9 +559,9 @@ def test_agent_repeated_tool_usage(capsys):
|
||||
expected_message = (
|
||||
"I tried reusing the same input, I must stop using this action input."
|
||||
)
|
||||
assert (
|
||||
expected_message in output
|
||||
), f"Expected message not found in output. Output was: {output}"
|
||||
assert expected_message in output, (
|
||||
f"Expected message not found in output. Output was: {output}"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||
@@ -602,9 +602,9 @@ def test_agent_repeated_tool_usage_check_even_with_disabled_cache(capsys):
|
||||
has_max_iterations = "maximum iterations reached" in output_lower
|
||||
has_final_answer = "final answer" in output_lower or "42" in captured.out
|
||||
|
||||
assert (
|
||||
has_repeated_usage_message or (has_max_iterations and has_final_answer)
|
||||
), f"Expected repeated tool usage handling or proper max iteration handling. Output was: {captured.out[:500]}..."
|
||||
assert has_repeated_usage_message or (has_max_iterations and has_final_answer), (
|
||||
f"Expected repeated tool usage handling or proper max iteration handling. Output was: {captured.out[:500]}..."
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||
@@ -1116,7 +1116,9 @@ def test_not_using_system_prompt():
|
||||
use_system_prompt=False,
|
||||
)
|
||||
|
||||
agent.create_agent_executor()
|
||||
# Create a dummy task for testing
|
||||
task = Task(description="Test task", expected_output="Test output", agent=agent)
|
||||
agent.create_agent_executor(task)
|
||||
assert not agent.agent_executor.prompt.get("user")
|
||||
assert not agent.agent_executor.prompt.get("system")
|
||||
|
||||
@@ -1128,7 +1130,9 @@ def test_using_system_prompt():
|
||||
backstory="I am the master of {role}",
|
||||
)
|
||||
|
||||
agent.create_agent_executor()
|
||||
# Create a dummy task for testing
|
||||
task = Task(description="Test task", expected_output="Test output", agent=agent)
|
||||
agent.create_agent_executor(task)
|
||||
assert agent.agent_executor.prompt.get("user")
|
||||
assert agent.agent_executor.prompt.get("system")
|
||||
|
||||
@@ -2312,9 +2316,9 @@ def test_agent_from_repository(mock_get_agent, mock_get_auth_token):
|
||||
# Mock embedchain initialization to prevent race conditions in parallel CI execution
|
||||
with patch("embedchain.client.Client.setup"):
|
||||
from crewai_tools import (
|
||||
SerperDevTool,
|
||||
FileReadTool,
|
||||
EnterpriseActionTool,
|
||||
FileReadTool,
|
||||
SerperDevTool,
|
||||
)
|
||||
|
||||
mock_get_response = MagicMock()
|
||||
|
||||
@@ -1,15 +1,17 @@
|
||||
import pytest
|
||||
from unittest.mock import ANY
|
||||
from collections import defaultdict
|
||||
from unittest.mock import ANY
|
||||
|
||||
import pytest
|
||||
|
||||
from crewai.events.event_bus import crewai_event_bus
|
||||
from crewai.events.types.memory_events import (
|
||||
MemoryQueryCompletedEvent,
|
||||
MemoryQueryStartedEvent,
|
||||
MemorySaveCompletedEvent,
|
||||
MemorySaveStartedEvent,
|
||||
)
|
||||
from crewai.memory.long_term.long_term_memory import LongTermMemory
|
||||
from crewai.memory.long_term.long_term_memory_item import LongTermMemoryItem
|
||||
from crewai.events.types.memory_events import (
|
||||
MemorySaveStartedEvent,
|
||||
MemorySaveCompletedEvent,
|
||||
MemoryQueryStartedEvent,
|
||||
MemoryQueryCompletedEvent,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -98,7 +100,7 @@ def test_long_term_memory_search_events(long_term_memory):
|
||||
|
||||
test_query = "test query"
|
||||
|
||||
long_term_memory.search(test_query, latest_n=5)
|
||||
long_term_memory.search(test_query, limit=5)
|
||||
|
||||
assert len(events["MemoryQueryStartedEvent"]) == 1
|
||||
assert len(events["MemoryQueryCompletedEvent"]) == 1
|
||||
@@ -151,7 +153,7 @@ def test_save_and_search(long_term_memory):
|
||||
metadata={"task": "test_task", "quality": 0.5},
|
||||
)
|
||||
long_term_memory.save(memory)
|
||||
find = long_term_memory.search("test_task", latest_n=5)[0]
|
||||
find = long_term_memory.search("test_task", limit=5)[0]
|
||||
assert find["score"] == 0.5
|
||||
assert find["datetime"] == "test_datetime"
|
||||
assert find["metadata"]["agent"] == "test_agent"
|
||||
|
||||
@@ -2,23 +2,41 @@
|
||||
|
||||
import hashlib
|
||||
import json
|
||||
from collections import defaultdict
|
||||
from concurrent.futures import Future
|
||||
from unittest import mock
|
||||
from unittest.mock import ANY, MagicMock, patch
|
||||
from collections import defaultdict
|
||||
|
||||
import pydantic_core
|
||||
import pytest
|
||||
|
||||
from crewai.agent import Agent
|
||||
from crewai.agents import CacheHandler
|
||||
from crewai.agents.cache.cache_handler import CacheHandler
|
||||
from crewai.crew import Crew
|
||||
from crewai.crews.crew_output import CrewOutput
|
||||
from crewai.events.event_bus import crewai_event_bus
|
||||
from crewai.events.types.crew_events import (
|
||||
CrewTestCompletedEvent,
|
||||
CrewTestStartedEvent,
|
||||
CrewTrainCompletedEvent,
|
||||
CrewTrainStartedEvent,
|
||||
)
|
||||
from crewai.events.types.memory_events import (
|
||||
MemoryQueryCompletedEvent,
|
||||
MemoryQueryFailedEvent,
|
||||
MemoryQueryStartedEvent,
|
||||
MemoryRetrievalCompletedEvent,
|
||||
MemoryRetrievalStartedEvent,
|
||||
MemorySaveCompletedEvent,
|
||||
MemorySaveFailedEvent,
|
||||
MemorySaveStartedEvent,
|
||||
)
|
||||
from crewai.flow import Flow, start
|
||||
from crewai.knowledge.knowledge import Knowledge
|
||||
from crewai.knowledge.source.string_knowledge_source import StringKnowledgeSource
|
||||
from crewai.llm import LLM
|
||||
from crewai.memory.contextual.contextual_memory import ContextualMemory
|
||||
from crewai.memory.external.external_memory import ExternalMemory
|
||||
from crewai.memory.long_term.long_term_memory import LongTermMemory
|
||||
from crewai.memory.short_term.short_term_memory import ShortTermMemory
|
||||
from crewai.process import Process
|
||||
@@ -27,28 +45,9 @@ from crewai.tasks.conditional_task import ConditionalTask
|
||||
from crewai.tasks.output_format import OutputFormat
|
||||
from crewai.tasks.task_output import TaskOutput
|
||||
from crewai.types.usage_metrics import UsageMetrics
|
||||
from crewai.events.event_bus import crewai_event_bus
|
||||
from crewai.events.types.crew_events import (
|
||||
CrewTestCompletedEvent,
|
||||
CrewTestStartedEvent,
|
||||
CrewTrainCompletedEvent,
|
||||
CrewTrainStartedEvent,
|
||||
)
|
||||
from crewai.utilities.rpm_controller import RPMController
|
||||
from crewai.utilities.task_output_storage_handler import TaskOutputStorageHandler
|
||||
|
||||
from crewai.events.types.memory_events import (
|
||||
MemorySaveStartedEvent,
|
||||
MemorySaveCompletedEvent,
|
||||
MemorySaveFailedEvent,
|
||||
MemoryQueryStartedEvent,
|
||||
MemoryQueryCompletedEvent,
|
||||
MemoryQueryFailedEvent,
|
||||
MemoryRetrievalStartedEvent,
|
||||
MemoryRetrievalCompletedEvent,
|
||||
)
|
||||
from crewai.memory.external.external_memory import ExternalMemory
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def ceo():
|
||||
@@ -570,8 +569,6 @@ def test_crew_with_delegating_agents(ceo, writer):
|
||||
|
||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||
def test_crew_with_delegating_agents_should_not_override_task_tools(ceo, writer):
|
||||
from typing import Type
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from crewai.tools import BaseTool
|
||||
@@ -584,7 +581,7 @@ def test_crew_with_delegating_agents_should_not_override_task_tools(ceo, writer)
|
||||
class TestTool(BaseTool):
|
||||
name: str = "Test Tool"
|
||||
description: str = "A test tool that just returns the input"
|
||||
args_schema: Type[BaseModel] = TestToolInput
|
||||
args_schema: type[BaseModel] = TestToolInput
|
||||
|
||||
def _run(self, query: str) -> str:
|
||||
return f"Processed: {query}"
|
||||
@@ -622,18 +619,16 @@ def test_crew_with_delegating_agents_should_not_override_task_tools(ceo, writer)
|
||||
_, kwargs = mock_execute_sync.call_args
|
||||
tools = kwargs["tools"]
|
||||
|
||||
assert any(
|
||||
isinstance(tool, TestTool) for tool in tools
|
||||
), "TestTool should be present"
|
||||
assert any(
|
||||
"delegate" in tool.name.lower() for tool in tools
|
||||
), "Delegation tool should be present"
|
||||
assert any(isinstance(tool, TestTool) for tool in tools), (
|
||||
"TestTool should be present"
|
||||
)
|
||||
assert any("delegate" in tool.name.lower() for tool in tools), (
|
||||
"Delegation tool should be present"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||
def test_crew_with_delegating_agents_should_not_override_agent_tools(ceo, writer):
|
||||
from typing import Type
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from crewai.tools import BaseTool
|
||||
@@ -646,7 +641,7 @@ def test_crew_with_delegating_agents_should_not_override_agent_tools(ceo, writer
|
||||
class TestTool(BaseTool):
|
||||
name: str = "Test Tool"
|
||||
description: str = "A test tool that just returns the input"
|
||||
args_schema: Type[BaseModel] = TestToolInput
|
||||
args_schema: type[BaseModel] = TestToolInput
|
||||
|
||||
def _run(self, query: str) -> str:
|
||||
return f"Processed: {query}"
|
||||
@@ -686,18 +681,16 @@ def test_crew_with_delegating_agents_should_not_override_agent_tools(ceo, writer
|
||||
_, kwargs = mock_execute_sync.call_args
|
||||
tools = kwargs["tools"]
|
||||
|
||||
assert any(
|
||||
isinstance(tool, TestTool) for tool in new_ceo.tools
|
||||
), "TestTool should be present"
|
||||
assert any(
|
||||
"delegate" in tool.name.lower() for tool in tools
|
||||
), "Delegation tool should be present"
|
||||
assert any(isinstance(tool, TestTool) for tool in new_ceo.tools), (
|
||||
"TestTool should be present"
|
||||
)
|
||||
assert any("delegate" in tool.name.lower() for tool in tools), (
|
||||
"Delegation tool should be present"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||
def test_task_tools_override_agent_tools(researcher):
|
||||
from typing import Type
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from crewai.tools import BaseTool
|
||||
@@ -710,7 +703,7 @@ def test_task_tools_override_agent_tools(researcher):
|
||||
class TestTool(BaseTool):
|
||||
name: str = "Test Tool"
|
||||
description: str = "A test tool that just returns the input"
|
||||
args_schema: Type[BaseModel] = TestToolInput
|
||||
args_schema: type[BaseModel] = TestToolInput
|
||||
|
||||
def _run(self, query: str) -> str:
|
||||
return f"Processed: {query}"
|
||||
@@ -718,7 +711,7 @@ def test_task_tools_override_agent_tools(researcher):
|
||||
class AnotherTestTool(BaseTool):
|
||||
name: str = "Another Test Tool"
|
||||
description: str = "Another test tool"
|
||||
args_schema: Type[BaseModel] = TestToolInput
|
||||
args_schema: type[BaseModel] = TestToolInput
|
||||
|
||||
def _run(self, query: str) -> str:
|
||||
return f"Another processed: {query}"
|
||||
@@ -754,7 +747,6 @@ def test_task_tools_override_agent_tools_with_allow_delegation(researcher, write
|
||||
"""
|
||||
Test that task tools override agent tools while preserving delegation tools when allow_delegation=True
|
||||
"""
|
||||
from typing import Type
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
@@ -766,7 +758,7 @@ def test_task_tools_override_agent_tools_with_allow_delegation(researcher, write
|
||||
class TestTool(BaseTool):
|
||||
name: str = "Test Tool"
|
||||
description: str = "A test tool that just returns the input"
|
||||
args_schema: Type[BaseModel] = TestToolInput
|
||||
args_schema: type[BaseModel] = TestToolInput
|
||||
|
||||
def _run(self, query: str) -> str:
|
||||
return f"Processed: {query}"
|
||||
@@ -774,7 +766,7 @@ def test_task_tools_override_agent_tools_with_allow_delegation(researcher, write
|
||||
class AnotherTestTool(BaseTool):
|
||||
name: str = "Another Test Tool"
|
||||
description: str = "Another test tool"
|
||||
args_schema: Type[BaseModel] = TestToolInput
|
||||
args_schema: type[BaseModel] = TestToolInput
|
||||
|
||||
def _run(self, query: str) -> str:
|
||||
return f"Another processed: {query}"
|
||||
@@ -815,17 +807,17 @@ def test_task_tools_override_agent_tools_with_allow_delegation(researcher, write
|
||||
used_tools = kwargs["tools"]
|
||||
|
||||
# Confirm AnotherTestTool is present but TestTool is not
|
||||
assert any(
|
||||
isinstance(tool, AnotherTestTool) for tool in used_tools
|
||||
), "AnotherTestTool should be present"
|
||||
assert not any(
|
||||
isinstance(tool, TestTool) for tool in used_tools
|
||||
), "TestTool should not be present among used tools"
|
||||
assert any(isinstance(tool, AnotherTestTool) for tool in used_tools), (
|
||||
"AnotherTestTool should be present"
|
||||
)
|
||||
assert not any(isinstance(tool, TestTool) for tool in used_tools), (
|
||||
"TestTool should not be present among used tools"
|
||||
)
|
||||
|
||||
# Confirm delegation tool(s) are present
|
||||
assert any(
|
||||
"delegate" in tool.name.lower() for tool in used_tools
|
||||
), "Delegation tool should be present"
|
||||
assert any("delegate" in tool.name.lower() for tool in used_tools), (
|
||||
"Delegation tool should be present"
|
||||
)
|
||||
|
||||
# Finally, make sure the agent's original tools remain unchanged
|
||||
assert len(researcher_with_delegation.tools) == 1
|
||||
@@ -912,13 +904,13 @@ def test_cache_hitting_between_agents(researcher, writer, ceo):
|
||||
crew.kickoff()
|
||||
assert read.call_count == 2, "read was not called exactly twice"
|
||||
|
||||
# Filter the mock calls to only include the ones with 'tool' and 'input' keywords
|
||||
# Filter the mock calls to only include the ones with 'tool' and 'input_data' keywords
|
||||
cache_calls = [
|
||||
call
|
||||
for call in read.call_args_list
|
||||
if len(call.kwargs) == 2
|
||||
and "tool" in call.kwargs
|
||||
and "input" in call.kwargs
|
||||
and "input_data" in call.kwargs
|
||||
]
|
||||
|
||||
# Check if we have the expected number of cache calls
|
||||
@@ -926,12 +918,12 @@ def test_cache_hitting_between_agents(researcher, writer, ceo):
|
||||
|
||||
# Check if both calls were made with the expected arguments
|
||||
expected_call = call(
|
||||
tool="multiplier", input={"first_number": 2, "second_number": 6}
|
||||
tool="multiplier", input_data={"first_number": 2, "second_number": 6}
|
||||
)
|
||||
assert cache_calls[0] == expected_call, f"First call mismatch: {cache_calls[0]}"
|
||||
assert (
|
||||
cache_calls[1] == expected_call
|
||||
), f"Second call mismatch: {cache_calls[1]}"
|
||||
assert cache_calls[1] == expected_call, (
|
||||
f"Second call mismatch: {cache_calls[1]}"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||
@@ -1674,9 +1666,9 @@ def test_code_execution_flag_adds_code_tool_upon_kickoff():
|
||||
|
||||
# Verify that exactly one tool was used and it was a CodeInterpreterTool
|
||||
assert len(used_tools) == 1, "Should have exactly one tool"
|
||||
assert isinstance(
|
||||
used_tools[0], CodeInterpreterTool
|
||||
), "Tool should be CodeInterpreterTool"
|
||||
assert isinstance(used_tools[0], CodeInterpreterTool), (
|
||||
"Tool should be CodeInterpreterTool"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||
@@ -2237,7 +2229,7 @@ def test_tools_with_custom_caching():
|
||||
# Verify that one of those calls was with the even number that should be cached
|
||||
add_to_cache.assert_any_call(
|
||||
tool="multiplcation_tool",
|
||||
input={"first_number": 2, "second_number": 6},
|
||||
input_data={"first_number": 2, "second_number": 6},
|
||||
output=12,
|
||||
)
|
||||
|
||||
@@ -3815,16 +3807,15 @@ def test_fetch_inputs():
|
||||
expected_placeholders = {"role_detail", "topic", "field"}
|
||||
actual_placeholders = crew.fetch_inputs()
|
||||
|
||||
assert (
|
||||
actual_placeholders == expected_placeholders
|
||||
), f"Expected {expected_placeholders}, but got {actual_placeholders}"
|
||||
assert actual_placeholders == expected_placeholders, (
|
||||
f"Expected {expected_placeholders}, but got {actual_placeholders}"
|
||||
)
|
||||
|
||||
|
||||
def test_task_tools_preserve_code_execution_tools():
|
||||
"""
|
||||
Test that task tools don't override code execution tools when allow_code_execution=True
|
||||
"""
|
||||
from typing import Type
|
||||
|
||||
# Mock embedchain initialization to prevent race conditions in parallel CI execution
|
||||
with patch("embedchain.client.Client.setup"):
|
||||
@@ -3841,7 +3832,7 @@ def test_task_tools_preserve_code_execution_tools():
|
||||
class TestTool(BaseTool):
|
||||
name: str = "Test Tool"
|
||||
description: str = "A test tool that just returns the input"
|
||||
args_schema: Type[BaseModel] = TestToolInput
|
||||
args_schema: type[BaseModel] = TestToolInput
|
||||
|
||||
def _run(self, query: str) -> str:
|
||||
return f"Processed: {query}"
|
||||
@@ -3892,20 +3883,20 @@ def test_task_tools_preserve_code_execution_tools():
|
||||
used_tools = kwargs["tools"]
|
||||
|
||||
# Verify all expected tools are present
|
||||
assert any(
|
||||
isinstance(tool, TestTool) for tool in used_tools
|
||||
), "Task's TestTool should be present"
|
||||
assert any(
|
||||
isinstance(tool, CodeInterpreterTool) for tool in used_tools
|
||||
), "CodeInterpreterTool should be present"
|
||||
assert any(
|
||||
"delegate" in tool.name.lower() for tool in used_tools
|
||||
), "Delegation tool should be present"
|
||||
assert any(isinstance(tool, TestTool) for tool in used_tools), (
|
||||
"Task's TestTool should be present"
|
||||
)
|
||||
assert any(isinstance(tool, CodeInterpreterTool) for tool in used_tools), (
|
||||
"CodeInterpreterTool should be present"
|
||||
)
|
||||
assert any("delegate" in tool.name.lower() for tool in used_tools), (
|
||||
"Delegation tool should be present"
|
||||
)
|
||||
|
||||
# Verify the total number of tools (TestTool + CodeInterpreter + 2 delegation tools)
|
||||
assert (
|
||||
len(used_tools) == 4
|
||||
), "Should have TestTool, CodeInterpreter, and 2 delegation tools"
|
||||
assert len(used_tools) == 4, (
|
||||
"Should have TestTool, CodeInterpreter, and 2 delegation tools"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||
@@ -3949,9 +3940,9 @@ def test_multimodal_flag_adds_multimodal_tools():
|
||||
used_tools = kwargs["tools"]
|
||||
|
||||
# Check that the multimodal tool was added
|
||||
assert any(
|
||||
isinstance(tool, AddImageTool) for tool in used_tools
|
||||
), "AddImageTool should be present when agent is multimodal"
|
||||
assert any(isinstance(tool, AddImageTool) for tool in used_tools), (
|
||||
"AddImageTool should be present when agent is multimodal"
|
||||
)
|
||||
|
||||
# Verify we have exactly one tool (just the AddImageTool)
|
||||
assert len(used_tools) == 1, "Should only have the AddImageTool"
|
||||
@@ -4215,9 +4206,9 @@ def test_crew_guardrail_feedback_in_context():
|
||||
assert len(execution_contexts) > 1, "Task should have been executed multiple times"
|
||||
|
||||
# Verify that the second execution included the guardrail feedback
|
||||
assert (
|
||||
"Output must contain the keyword 'IMPORTANT'" in execution_contexts[1]
|
||||
), "Guardrail feedback should be included in retry context"
|
||||
assert "Output must contain the keyword 'IMPORTANT'" in execution_contexts[1], (
|
||||
"Guardrail feedback should be included in retry context"
|
||||
)
|
||||
|
||||
# Verify final output meets guardrail requirements
|
||||
assert "IMPORTANT" in result.raw, "Final output should contain required keyword"
|
||||
@@ -4232,13 +4223,11 @@ def test_before_kickoff_callback():
|
||||
|
||||
@CrewBase
|
||||
class TestCrewClass:
|
||||
from typing import List
|
||||
|
||||
from crewai.agents.agent_builder.base_agent import BaseAgent
|
||||
from crewai.project import CrewBase, agent, before_kickoff, crew, task
|
||||
|
||||
agents: List[BaseAgent]
|
||||
tasks: List[Task]
|
||||
agents: list[BaseAgent]
|
||||
tasks: list[Task]
|
||||
|
||||
agents_config = None
|
||||
tasks_config = None
|
||||
@@ -4433,46 +4422,46 @@ def test_crew_copy_with_memory():
|
||||
try:
|
||||
crew_copy = crew.copy()
|
||||
|
||||
assert hasattr(
|
||||
crew_copy, "_short_term_memory"
|
||||
), "Copied crew should have _short_term_memory"
|
||||
assert (
|
||||
crew_copy._short_term_memory is not None
|
||||
), "Copied _short_term_memory should not be None"
|
||||
assert (
|
||||
id(crew_copy._short_term_memory) != original_short_term_id
|
||||
), "Copied _short_term_memory should be a new object"
|
||||
assert hasattr(crew_copy, "_short_term_memory"), (
|
||||
"Copied crew should have _short_term_memory"
|
||||
)
|
||||
assert crew_copy._short_term_memory is not None, (
|
||||
"Copied _short_term_memory should not be None"
|
||||
)
|
||||
assert id(crew_copy._short_term_memory) != original_short_term_id, (
|
||||
"Copied _short_term_memory should be a new object"
|
||||
)
|
||||
|
||||
assert hasattr(
|
||||
crew_copy, "_long_term_memory"
|
||||
), "Copied crew should have _long_term_memory"
|
||||
assert (
|
||||
crew_copy._long_term_memory is not None
|
||||
), "Copied _long_term_memory should not be None"
|
||||
assert (
|
||||
id(crew_copy._long_term_memory) != original_long_term_id
|
||||
), "Copied _long_term_memory should be a new object"
|
||||
assert hasattr(crew_copy, "_long_term_memory"), (
|
||||
"Copied crew should have _long_term_memory"
|
||||
)
|
||||
assert crew_copy._long_term_memory is not None, (
|
||||
"Copied _long_term_memory should not be None"
|
||||
)
|
||||
assert id(crew_copy._long_term_memory) != original_long_term_id, (
|
||||
"Copied _long_term_memory should be a new object"
|
||||
)
|
||||
|
||||
assert hasattr(
|
||||
crew_copy, "_entity_memory"
|
||||
), "Copied crew should have _entity_memory"
|
||||
assert (
|
||||
crew_copy._entity_memory is not None
|
||||
), "Copied _entity_memory should not be None"
|
||||
assert (
|
||||
id(crew_copy._entity_memory) != original_entity_id
|
||||
), "Copied _entity_memory should be a new object"
|
||||
assert hasattr(crew_copy, "_entity_memory"), (
|
||||
"Copied crew should have _entity_memory"
|
||||
)
|
||||
assert crew_copy._entity_memory is not None, (
|
||||
"Copied _entity_memory should not be None"
|
||||
)
|
||||
assert id(crew_copy._entity_memory) != original_entity_id, (
|
||||
"Copied _entity_memory should be a new object"
|
||||
)
|
||||
|
||||
if original_external_id:
|
||||
assert hasattr(
|
||||
crew_copy, "_external_memory"
|
||||
), "Copied crew should have _external_memory"
|
||||
assert (
|
||||
crew_copy._external_memory is not None
|
||||
), "Copied _external_memory should not be None"
|
||||
assert (
|
||||
id(crew_copy._external_memory) != original_external_id
|
||||
), "Copied _external_memory should be a new object"
|
||||
assert hasattr(crew_copy, "_external_memory"), (
|
||||
"Copied crew should have _external_memory"
|
||||
)
|
||||
assert crew_copy._external_memory is not None, (
|
||||
"Copied _external_memory should not be None"
|
||||
)
|
||||
assert id(crew_copy._external_memory) != original_external_id, (
|
||||
"Copied _external_memory should be a new object"
|
||||
)
|
||||
else:
|
||||
assert (
|
||||
not hasattr(crew_copy, "_external_memory")
|
||||
|
||||
@@ -7,10 +7,8 @@ from pydantic import Field
|
||||
from crewai.agent import Agent
|
||||
from crewai.agents.crew_agent_executor import CrewAgentExecutor
|
||||
from crewai.crew import Crew
|
||||
from crewai.flow.flow import Flow, listen, start
|
||||
from crewai.llm import LLM
|
||||
from crewai.task import Task
|
||||
from crewai.tools.base_tool import BaseTool
|
||||
from crewai.events.event_bus import crewai_event_bus
|
||||
from crewai.events.event_listener import EventListener
|
||||
from crewai.events.types.agent_events import (
|
||||
AgentExecutionCompletedEvent,
|
||||
AgentExecutionErrorEvent,
|
||||
@@ -24,9 +22,6 @@ from crewai.events.types.crew_events import (
|
||||
CrewTestResultEvent,
|
||||
CrewTestStartedEvent,
|
||||
)
|
||||
from crewai.events.event_bus import crewai_event_bus
|
||||
from crewai.events.event_listener import EventListener
|
||||
from crewai.events.types.tool_usage_events import ToolUsageFinishedEvent
|
||||
from crewai.events.types.flow_events import (
|
||||
FlowCreatedEvent,
|
||||
FlowFinishedEvent,
|
||||
@@ -47,7 +42,12 @@ from crewai.events.types.task_events import (
|
||||
)
|
||||
from crewai.events.types.tool_usage_events import (
|
||||
ToolUsageErrorEvent,
|
||||
ToolUsageFinishedEvent,
|
||||
)
|
||||
from crewai.flow.flow import Flow, listen, start
|
||||
from crewai.llm import LLM
|
||||
from crewai.task import Task
|
||||
from crewai.tools.base_tool import BaseTool
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
@@ -327,9 +327,9 @@ def test_agent_emits_execution_error_event(base_agent, base_task):
|
||||
|
||||
error_message = "Error happening while sending prompt to model."
|
||||
base_agent.max_retry_limit = 0
|
||||
with patch.object(
|
||||
CrewAgentExecutor, "invoke", wraps=base_agent.agent_executor.invoke
|
||||
) as invoke_mock:
|
||||
|
||||
# Patch the invoke method on the CrewAgentExecutor class directly
|
||||
with patch.object(CrewAgentExecutor, "invoke") as invoke_mock:
|
||||
invoke_mock.side_effect = Exception(error_message)
|
||||
|
||||
with pytest.raises(Exception):
|
||||
|
||||
Reference in New Issue
Block a user