fix: improve type annotations across codebase

This commit is contained in:
Greyson LaLonde
2025-09-03 22:29:41 -04:00
parent 43880b49a6
commit b94fbd3d3a
7 changed files with 318 additions and 289 deletions

View File

@@ -1,29 +1,42 @@
import shutil
import subprocess
import time
from collections.abc import Callable, Sequence
from typing import (
Any,
Callable,
Dict,
List,
Literal,
Optional,
Sequence,
Tuple,
Type,
Union,
)
from pydantic import Field, InstanceOf, PrivateAttr, model_validator
from typing_extensions import Self
from crewai.agents.cache.cache_handler import CacheHandler
from crewai.agents.agent_builder.base_agent import BaseAgent
from crewai.agents.cache.cache_handler import CacheHandler
from crewai.agents.crew_agent_executor import CrewAgentExecutor
from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.agent_events import (
AgentExecutionCompletedEvent,
AgentExecutionErrorEvent,
AgentExecutionStartedEvent,
)
from crewai.events.types.knowledge_events import (
KnowledgeQueryCompletedEvent,
KnowledgeQueryFailedEvent,
KnowledgeQueryStartedEvent,
KnowledgeRetrievalCompletedEvent,
KnowledgeRetrievalStartedEvent,
KnowledgeSearchQueryFailedEvent,
)
from crewai.events.types.memory_events import (
MemoryRetrievalCompletedEvent,
MemoryRetrievalStartedEvent,
)
from crewai.knowledge.knowledge import Knowledge
from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource
from crewai.knowledge.utils.knowledge_utils import extract_knowledge_context
from crewai.lite_agent import LiteAgent, LiteAgentOutput
from crewai.llm import BaseLLM
from crewai.llms.base_llm import BaseLLM
from crewai.memory.contextual.contextual_memory import ContextualMemory
from crewai.security import Fingerprint
from crewai.task import Task
@@ -38,24 +51,6 @@ from crewai.utilities.agent_utils import (
)
from crewai.utilities.constants import TRAINED_AGENTS_DATA_FILE, TRAINING_DATA_FILE
from crewai.utilities.converter import generate_model_description
from crewai.events.types.agent_events import (
AgentExecutionCompletedEvent,
AgentExecutionErrorEvent,
AgentExecutionStartedEvent,
)
from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.memory_events import (
MemoryRetrievalStartedEvent,
MemoryRetrievalCompletedEvent,
)
from crewai.events.types.knowledge_events import (
KnowledgeQueryCompletedEvent,
KnowledgeQueryFailedEvent,
KnowledgeQueryStartedEvent,
KnowledgeRetrievalCompletedEvent,
KnowledgeRetrievalStartedEvent,
KnowledgeSearchQueryFailedEvent,
)
from crewai.utilities.llm_utils import create_llm
from crewai.utilities.token_counter_callback import TokenCalcHandler
from crewai.utilities.training_handler import CrewTrainingHandler
@@ -101,10 +96,10 @@ class Agent(BaseAgent):
default=True,
description="Use system prompt for the agent.",
)
llm: Union[str, InstanceOf[BaseLLM], Any] = Field(
llm: str | InstanceOf[BaseLLM] | Any = Field(
description="Language model that will run the agent.", default=None
)
function_calling_llm: Optional[Union[str, InstanceOf[BaseLLM], Any]] = Field(
function_calling_llm: Optional[str | InstanceOf[BaseLLM] | Any] = Field(
description="Language model that will run the agent.", default=None
)
system_template: Optional[str] = Field(
@@ -151,7 +146,7 @@ class Agent(BaseAgent):
default=None,
description="Maximum number of reasoning attempts before executing the task. If None, will try until ready.",
)
embedder: Optional[Dict[str, Any]] = Field(
embedder: Optional[dict[str, Any]] = Field(
default=None,
description="Embedder configuration for the agent.",
)
@@ -171,7 +166,7 @@ class Agent(BaseAgent):
default=None,
description="The Agent's role to be used from your repository.",
)
guardrail: Optional[Union[Callable[[Any], Tuple[bool, Any]], str]] = Field(
guardrail: Optional[Callable[[Any], tuple[bool, Any]] | str] = Field(
default=None,
description="Function or string description of a guardrail to validate agent output",
)
@@ -180,13 +175,14 @@ class Agent(BaseAgent):
)
@model_validator(mode="before")
def validate_from_repository(cls, v):
@classmethod
def validate_from_repository(cls, v: Any) -> Any:
if v is not None and (from_repository := v.get("from_repository")):
return load_agent_from_repository(from_repository) | v
return v
@model_validator(mode="after")
def post_init_setup(self):
def post_init_setup(self) -> Self:
self.agent_ops_agent_name = self.role
self.llm = create_llm(self.llm)
@@ -203,12 +199,12 @@ class Agent(BaseAgent):
return self
def _setup_agent_executor(self):
def _setup_agent_executor(self) -> None:
if not self.cache_handler:
self.cache_handler = CacheHandler()
self.set_cache_handler(self.cache_handler)
def set_knowledge(self, crew_embedder: Optional[Dict[str, Any]] = None):
def set_knowledge(self, crew_embedder: Optional[dict[str, Any]] = None) -> None:
try:
if self.embedder is None and crew_embedder:
self.embedder = crew_embedder
@@ -245,7 +241,7 @@ class Agent(BaseAgent):
self,
task: Task,
context: Optional[str] = None,
tools: Optional[List[BaseTool]] = None,
tools: Optional[list[BaseTool]] = None,
) -> str:
"""Execute a task with the agent.
@@ -492,7 +488,7 @@ class Agent(BaseAgent):
# If there was any tool in self.tools_results that had result_as_answer
# set to True, return the results of the last tool that had
# result_as_answer set to True
for tool_result in self.tools_results: # type: ignore # Item "None" of "list[Any] | None" has no attribute "__iter__" (not iterable)
for tool_result in self.tools_results:
if tool_result.get("result_as_answer", False):
result = tool_result["result"]
crewai_event_bus.emit(
@@ -554,14 +550,14 @@ class Agent(BaseAgent):
)["output"]
def create_agent_executor(
self, tools: Optional[List[BaseTool]] = None, task=None
self, tools: Optional[list[BaseTool]] = None, task: Optional[Task] = None
) -> None:
"""Create an agent executor for the agent.
Returns:
An instance of the CrewAgentExecutor class.
"""
raw_tools: List[BaseTool] = tools or self.tools or []
raw_tools: list[BaseTool] = tools or self.tools or []
parsed_tools = parse_tools(raw_tools)
prompt = Prompts(
@@ -603,7 +599,7 @@ class Agent(BaseAgent):
callbacks=[TokenCalcHandler(self._token_process)],
)
def get_delegation_tools(self, agents: List[BaseAgent]):
def get_delegation_tools(self, agents: list[BaseAgent]) -> list[BaseTool]:
agent_tools = AgentTools(agents=agents)
tools = agent_tools.tools()
return tools
@@ -613,7 +609,7 @@ class Agent(BaseAgent):
return [AddImageTool()]
def get_code_execution_tools(self):
def get_code_execution_tools(self) -> list[BaseTool]:
try:
from crewai_tools import CodeInterpreterTool # type: ignore
@@ -625,7 +621,9 @@ class Agent(BaseAgent):
"info", "Coding tools not available. Install crewai_tools. "
)
def get_output_converter(self, llm, text, model, instructions):
def get_output_converter(
self, llm: BaseLLM, text: str, model: str, instructions: str
) -> Converter:
return Converter(llm=llm, text=text, model=model, instructions=instructions)
def _training_handler(self, task_prompt: str) -> str:
@@ -654,7 +652,7 @@ class Agent(BaseAgent):
)
return task_prompt
def _render_text_description(self, tools: List[Any]) -> str:
def _render_text_description(self, tools: list[Any]) -> str:
"""Render the tool name and description in plain text.
Output will be in the format of:
@@ -673,7 +671,7 @@ class Agent(BaseAgent):
return description
def _inject_date_to_task(self, task):
def _inject_date_to_task(self, task: str) -> str:
"""Inject the current date into the task description if inject_date is enabled."""
if self.inject_date:
from datetime import datetime
@@ -723,7 +721,7 @@ class Agent(BaseAgent):
f"Docker is not running. Please start Docker to use code execution with agent: {self.role}"
)
def __repr__(self):
def __repr__(self) -> str:
return f"Agent(role={self.role}, goal={self.goal}, backstory={self.backstory})"
@property
@@ -736,7 +734,7 @@ class Agent(BaseAgent):
"""
return self.security_config.fingerprint
def set_fingerprint(self, fingerprint: Fingerprint):
def set_fingerprint(self, fingerprint: Fingerprint) -> None:
self.security_config.fingerprint = fingerprint
def _get_knowledge_search_query(self, task_prompt: str) -> str | None:
@@ -796,8 +794,8 @@ class Agent(BaseAgent):
def kickoff(
self,
messages: Union[str, List[Dict[str, str]]],
response_format: Optional[Type[Any]] = None,
messages: str | list[dict[str, str]],
response_format: Optional[type[Any]] = None,
) -> LiteAgentOutput:
"""
Execute the agent with the given messages using a LiteAgent instance.
@@ -836,8 +834,8 @@ class Agent(BaseAgent):
async def kickoff_async(
self,
messages: Union[str, List[Dict[str, str]]],
response_format: Optional[Type[Any]] = None,
messages: str | list[dict[str, str]],
response_format: Optional[type[Any]] = None,
) -> LiteAgentOutput:
"""
Execute the agent asynchronously with the given messages using a LiteAgent instance.

View File

@@ -3,26 +3,18 @@ import json
import re
import uuid
import warnings
from collections.abc import Callable, Mapping, Set
from concurrent.futures import Future
from copy import copy as shallow_copy
from hashlib import md5
from typing import (
Any,
Callable,
Dict,
List,
Optional,
Set,
Tuple,
Union,
cast,
)
from opentelemetry import baggage
from opentelemetry.context import attach, detach
from crewai.utilities.crew.models import CrewContext
from pydantic import (
UUID4,
BaseModel,
@@ -34,15 +26,36 @@ from pydantic import (
model_validator,
)
from pydantic_core import PydanticCustomError
from typing_extensions import Self
from crewai.agent import Agent
from crewai.agents.agent_builder.base_agent import BaseAgent
from crewai.agents.cache.cache_handler import CacheHandler
from crewai.crews.crew_output import CrewOutput
from crewai.events.event_bus import crewai_event_bus
from crewai.events.event_listener import EventListener
from crewai.events.listeners.tracing.trace_listener import (
TraceCollectionListener,
)
from crewai.events.listeners.tracing.utils import (
is_tracing_enabled,
)
from crewai.events.types.crew_events import (
CrewKickoffCompletedEvent,
CrewKickoffFailedEvent,
CrewKickoffStartedEvent,
CrewTestCompletedEvent,
CrewTestFailedEvent,
CrewTestStartedEvent,
CrewTrainCompletedEvent,
CrewTrainFailedEvent,
CrewTrainStartedEvent,
)
from crewai.flow.flow_trackable import FlowTrackable
from crewai.knowledge.knowledge import Knowledge
from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource
from crewai.llm import LLM, BaseLLM
from crewai.llm import LLM
from crewai.llms.base_llm import BaseLLM
from crewai.memory.entity.entity_memory import EntityMemory
from crewai.memory.external.external_memory import ExternalMemory
from crewai.memory.long_term.long_term_memory import LongTermMemory
@@ -57,29 +70,9 @@ from crewai.tools.base_tool import BaseTool, Tool
from crewai.types.usage_metrics import UsageMetrics
from crewai.utilities import I18N, FileHandler, Logger, RPMController
from crewai.utilities.constants import NOT_SPECIFIED, TRAINING_DATA_FILE
from crewai.utilities.crew.models import CrewContext
from crewai.utilities.evaluators.crew_evaluator_handler import CrewEvaluator
from crewai.utilities.evaluators.task_evaluator import TaskEvaluator
from crewai.events.types.crew_events import (
CrewKickoffCompletedEvent,
CrewKickoffFailedEvent,
CrewKickoffStartedEvent,
CrewTestCompletedEvent,
CrewTestFailedEvent,
CrewTestStartedEvent,
CrewTrainCompletedEvent,
CrewTrainFailedEvent,
CrewTrainStartedEvent,
)
from crewai.events.event_bus import crewai_event_bus
from crewai.events.event_listener import EventListener
from crewai.events.listeners.tracing.trace_listener import (
TraceCollectionListener,
)
from crewai.events.listeners.tracing.utils import (
is_tracing_enabled,
)
from crewai.utilities.formatter import (
aggregate_raw_outputs_from_task_outputs,
aggregate_raw_outputs_from_tasks,
@@ -116,9 +109,12 @@ class Crew(FlowTrackable, BaseModel):
planning: Plan the crew execution and add the plan to the crew.
chat_llm: The language model used for orchestrating chat interactions with the crew.
security_config: Security configuration for the crew, including fingerprinting.
Notes:
TODO: Improve the embedder type from dict[str, Any] to a more specific TypedDict or dataclass.
"""
__hash__ = object.__hash__ # type: ignore
__hash__ = object.__hash__
_execution_span: Any = PrivateAttr()
_rpm_controller: RPMController = PrivateAttr()
_logger: Logger = PrivateAttr()
@@ -130,7 +126,7 @@ class Crew(FlowTrackable, BaseModel):
_external_memory: Optional[InstanceOf[ExternalMemory]] = PrivateAttr()
_train: Optional[bool] = PrivateAttr(default=False)
_train_iteration: Optional[int] = PrivateAttr()
_inputs: Optional[Dict[str, Any]] = PrivateAttr(default=None)
_inputs: Optional[dict[str, Any]] = PrivateAttr(default=None)
_logging_color: str = PrivateAttr(
default="bold_purple",
)
@@ -140,8 +136,8 @@ class Crew(FlowTrackable, BaseModel):
name: Optional[str] = Field(default="crew")
cache: bool = Field(default=True)
tasks: List[Task] = Field(default_factory=list)
agents: List[BaseAgent] = Field(default_factory=list)
tasks: list[Task] = Field(default_factory=list)
agents: list[BaseAgent] = Field(default_factory=list)
process: Process = Field(default=Process.sequential)
verbose: bool = Field(default=False)
memory: bool = Field(
@@ -164,7 +160,7 @@ class Crew(FlowTrackable, BaseModel):
default=None,
description="An Instance of the ExternalMemory to be used by the Crew",
)
embedder: Optional[dict] = Field(
embedder: Optional[dict[str, Any]] = Field(
default=None,
description="Configuration for the embedder to be used for the crew.",
)
@@ -172,16 +168,16 @@ class Crew(FlowTrackable, BaseModel):
default=None,
description="Metrics for the LLM usage during all tasks execution.",
)
manager_llm: Optional[Union[str, InstanceOf[BaseLLM], Any]] = Field(
manager_llm: Optional[str | InstanceOf[BaseLLM] | Any] = Field(
description="Language model that will run the agent.", default=None
)
manager_agent: Optional[BaseAgent] = Field(
description="Custom agent that will be used as manager.", default=None
)
function_calling_llm: Optional[Union[str, InstanceOf[LLM], Any]] = Field(
function_calling_llm: Optional[str | InstanceOf[LLM] | Any] = Field(
description="Language model that will run the agent.", default=None
)
config: Optional[Union[Json, Dict[str, Any]]] = Field(default=None)
config: Optional[Json[dict[str, Any]] | dict[str, Any]] = Field(default=None)
id: UUID4 = Field(default_factory=uuid.uuid4, frozen=True)
share_crew: Optional[bool] = Field(default=False)
step_callback: Optional[Any] = Field(
@@ -192,13 +188,13 @@ class Crew(FlowTrackable, BaseModel):
default=None,
description="Callback to be executed after each task for all agents execution.",
)
before_kickoff_callbacks: List[
Callable[[Optional[Dict[str, Any]]], Optional[Dict[str, Any]]]
before_kickoff_callbacks: list[
Callable[[Optional[dict[str, Any]]], Optional[dict[str, Any]]]
] = Field(
default_factory=list,
description="List of callbacks to be executed before crew kickoff. It may be used to adjust inputs before the crew is executed.",
)
after_kickoff_callbacks: List[Callable[[CrewOutput], CrewOutput]] = Field(
after_kickoff_callbacks: list[Callable[[CrewOutput], CrewOutput]] = Field(
default_factory=list,
description="List of callbacks to be executed after crew kickoff. It may be used to adjust the output of the crew.",
)
@@ -210,7 +206,7 @@ class Crew(FlowTrackable, BaseModel):
default=None,
description="Path to the prompt json file to be used for the crew.",
)
output_log_file: Optional[Union[bool, str]] = Field(
output_log_file: Optional[bool | str] = Field(
default=None,
description="Path to the log file to be saved",
)
@@ -218,23 +214,23 @@ class Crew(FlowTrackable, BaseModel):
default=False,
description="Plan the crew execution and add the plan to the crew.",
)
planning_llm: Optional[Union[str, InstanceOf[BaseLLM], Any]] = Field(
planning_llm: Optional[str | InstanceOf[BaseLLM] | Any] = Field(
default=None,
description="Language model that will run the AgentPlanner if planning is True.",
)
task_execution_output_json_files: Optional[List[str]] = Field(
task_execution_output_json_files: Optional[list[str]] = Field(
default=None,
description="List of file paths for task execution JSON files.",
)
execution_logs: List[Dict[str, Any]] = Field(
execution_logs: list[dict[str, Any]] = Field(
default=[],
description="List of execution logs for tasks",
)
knowledge_sources: Optional[List[BaseKnowledgeSource]] = Field(
knowledge_sources: Optional[list[BaseKnowledgeSource]] = Field(
default=None,
description="Knowledge sources for the crew. Add knowledge sources to the knowledge object.",
)
chat_llm: Optional[Union[str, InstanceOf[BaseLLM], Any]] = Field(
chat_llm: Optional[str | InstanceOf[BaseLLM] | Any] = Field(
default=None,
description="LLM used to handle chatting with the crew.",
)
@@ -267,8 +263,8 @@ class Crew(FlowTrackable, BaseModel):
@field_validator("config", mode="before")
@classmethod
def check_config_type(
cls, v: Union[Json, Dict[str, Any]]
) -> Union[Json, Dict[str, Any]]:
cls, v: Json[dict[str, Any]] | dict[str, Any]
) -> Json[dict[str, Any]] | dict[str, Any]:
"""Validates that the config is a valid type.
Args:
v: The config to be validated.
@@ -277,10 +273,10 @@ class Crew(FlowTrackable, BaseModel):
"""
# TODO: Improve typing
return json.loads(v) if isinstance(v, Json) else v # type: ignore
return json.loads(v) if isinstance(v, Json) else v
@model_validator(mode="after")
def set_private_attrs(self) -> "Crew":
def set_private_attrs(self) -> Self:
"""Set private attributes."""
self._cache_handler = CacheHandler()
@@ -300,7 +296,7 @@ class Crew(FlowTrackable, BaseModel):
return self
def _initialize_default_memories(self):
def _initialize_default_memories(self) -> None:
self._long_term_memory = self._long_term_memory or LongTermMemory()
self._short_term_memory = self._short_term_memory or ShortTermMemory(
crew=self,
@@ -311,7 +307,7 @@ class Crew(FlowTrackable, BaseModel):
)
@model_validator(mode="after")
def create_crew_memory(self) -> "Crew":
def create_crew_memory(self) -> Self:
"""Initialize private memory attributes."""
self._external_memory = (
# External memory doesnt support a default value since it was designed to be managed entirely externally
@@ -328,7 +324,7 @@ class Crew(FlowTrackable, BaseModel):
return self
@model_validator(mode="after")
def create_crew_knowledge(self) -> "Crew":
def create_crew_knowledge(self) -> Self:
"""Create the knowledge for the crew."""
if self.knowledge_sources:
try:
@@ -349,7 +345,7 @@ class Crew(FlowTrackable, BaseModel):
return self
@model_validator(mode="after")
def check_manager_llm(self):
def check_manager_llm(self) -> Self:
"""Validates that the language model is set when using hierarchical process."""
if self.process == Process.hierarchical:
if not self.manager_llm and not self.manager_agent:
@@ -371,7 +367,7 @@ class Crew(FlowTrackable, BaseModel):
return self
@model_validator(mode="after")
def check_config(self):
def check_config(self) -> Self:
"""Validates that the crew is properly configured with agents and tasks."""
if not self.config and not self.tasks and not self.agents:
raise PydanticCustomError(
@@ -392,20 +388,20 @@ class Crew(FlowTrackable, BaseModel):
return self
@model_validator(mode="after")
def validate_tasks(self):
def validate_tasks(self) -> Self:
if self.process == Process.sequential:
for task in self.tasks:
if task.agent is None:
raise PydanticCustomError(
"missing_agent_in_task",
f"Sequential process error: Agent is missing in the task with the following description: {task.description}", # type: ignore # Argument of type "str" cannot be assigned to parameter "message_template" of type "LiteralString"
f"Sequential process error: Agent is missing in the task with the following description: {task.description}",
{},
)
return self
@model_validator(mode="after")
def validate_end_with_at_most_one_async_task(self):
def validate_end_with_at_most_one_async_task(self) -> Self:
"""Validates that the crew ends with at most one asynchronous task."""
final_async_task_count = 0
@@ -426,7 +422,7 @@ class Crew(FlowTrackable, BaseModel):
return self
@model_validator(mode="after")
def validate_must_have_non_conditional_task(self) -> "Crew":
def validate_must_have_non_conditional_task(self) -> Self:
"""Ensure that a crew has at least one non-conditional task."""
if not self.tasks:
return self
@@ -442,7 +438,7 @@ class Crew(FlowTrackable, BaseModel):
return self
@model_validator(mode="after")
def validate_first_task(self) -> "Crew":
def validate_first_task(self) -> Self:
"""Ensure the first task is not a ConditionalTask."""
if self.tasks and isinstance(self.tasks[0], ConditionalTask):
raise PydanticCustomError(
@@ -453,19 +449,21 @@ class Crew(FlowTrackable, BaseModel):
return self
@model_validator(mode="after")
def validate_async_tasks_not_async(self) -> "Crew":
def validate_async_tasks_not_async(self) -> Self:
"""Ensure that ConditionalTask is not async."""
for task in self.tasks:
if task.async_execution and isinstance(task, ConditionalTask):
raise PydanticCustomError(
"invalid_async_conditional_task",
f"Conditional Task: {task.description} , cannot be executed asynchronously.", # type: ignore # Argument of type "str" cannot be assigned to parameter "message_template" of type "LiteralString"
f"Conditional Task: {task.description} , cannot be executed asynchronously.",
{},
)
return self
@model_validator(mode="after")
def validate_async_task_cannot_include_sequential_async_tasks_in_context(self):
def validate_async_task_cannot_include_sequential_async_tasks_in_context(
self,
) -> Self:
"""
Validates that if a task is set to be executed asynchronously,
it cannot include other asynchronous tasks in its context unless
@@ -485,7 +483,7 @@ class Crew(FlowTrackable, BaseModel):
return self
@model_validator(mode="after")
def validate_context_no_future_tasks(self):
def validate_context_no_future_tasks(self) -> Self:
"""Validates that a task's context does not include future tasks."""
task_indices = {id(task): i for i, task in enumerate(self.tasks)}
@@ -502,7 +500,7 @@ class Crew(FlowTrackable, BaseModel):
@property
def key(self) -> str:
source: List[str] = [agent.key for agent in self.agents] + [
source: list[str] = [agent.key for agent in self.agents] + [
task.key for task in self.tasks
]
return md5("|".join(source).encode(), usedforsecurity=False).hexdigest()
@@ -517,7 +515,7 @@ class Crew(FlowTrackable, BaseModel):
"""
return self.security_config.fingerprint
def _setup_from_config(self):
def _setup_from_config(self) -> None:
assert self.config is not None, "Config should not be None."
"""Initializes agents and tasks from the provided config."""
@@ -530,7 +528,7 @@ class Crew(FlowTrackable, BaseModel):
self.agents = [Agent(**agent) for agent in self.config["agents"]]
self.tasks = [self._create_task(task) for task in self.config["tasks"]]
def _create_task(self, task_config: Dict[str, Any]) -> Task:
def _create_task(self, task_config: dict[str, Any]) -> Task:
"""Creates a task instance from its configuration.
Args:
@@ -559,7 +557,7 @@ class Crew(FlowTrackable, BaseModel):
CrewTrainingHandler(filename).initialize_file()
def train(
self, n_iterations: int, filename: str, inputs: Optional[Dict[str, Any]] = None
self, n_iterations: int, filename: str, inputs: Optional[dict[str, Any]] = None
) -> None:
"""Trains the crew for a given number of iterations."""
inputs = inputs or {}
@@ -611,7 +609,7 @@ class Crew(FlowTrackable, BaseModel):
def kickoff(
self,
inputs: Optional[Dict[str, Any]] = None,
inputs: Optional[dict[str, Any]] = None,
) -> CrewOutput:
ctx = baggage.set_baggage(
"crew_context", CrewContext(id=str(self.id), key=self.key)
@@ -643,7 +641,6 @@ class Crew(FlowTrackable, BaseModel):
for agent in self.agents:
agent.i18n = i18n
# type: ignore[attr-defined] # Argument 1 to "_interpolate_inputs" of "Crew" has incompatible type "dict[str, Any] | None"; expected "dict[str, Any]"
agent.crew = self # type: ignore[attr-defined]
agent.set_knowledge(crew_embedder=self.embedder)
# TODO: Create an AgentFunctionCalling protocol for future refactoring
@@ -682,9 +679,9 @@ class Crew(FlowTrackable, BaseModel):
finally:
detach(token)
def kickoff_for_each(self, inputs: List[Dict[str, Any]]) -> List[CrewOutput]:
def kickoff_for_each(self, inputs: list[dict[str, Any]]) -> list[CrewOutput]:
"""Executes the Crew's workflow for each input in the list and aggregates results."""
results: List[CrewOutput] = []
results: list[CrewOutput] = []
# Initialize the parent crew's usage metrics
total_usage_metrics = UsageMetrics()
@@ -704,16 +701,18 @@ class Crew(FlowTrackable, BaseModel):
return results
async def kickoff_async(
self, inputs: Optional[Dict[str, Any]] = None
self, inputs: Optional[dict[str, Any]] = None
) -> CrewOutput:
"""Asynchronous kickoff method to start the crew execution."""
inputs = inputs or {}
return await asyncio.to_thread(self.kickoff, inputs)
async def kickoff_for_each_async(self, inputs: List[Dict]) -> List[CrewOutput]:
async def kickoff_for_each_async(
self, inputs: list[dict[str, Any]]
) -> list[CrewOutput]:
crew_copies = [self.copy() for _ in inputs]
async def run_crew(crew, input_data):
async def run_crew(crew: Self, input_data: dict[str, Any]) -> CrewOutput:
return await crew.kickoff_async(inputs=input_data)
tasks = [
@@ -732,7 +731,7 @@ class Crew(FlowTrackable, BaseModel):
self._task_output_handler.reset()
return results
def _handle_crew_planning(self):
def _handle_crew_planning(self) -> None:
"""Handles the Crew planning."""
self._logger.log("info", "Planning the crew execution")
result = CrewPlanner(
@@ -748,7 +747,7 @@ class Crew(FlowTrackable, BaseModel):
output: TaskOutput,
task_index: int,
was_replayed: bool = False,
):
) -> None:
if self._inputs:
inputs = self._inputs
else:
@@ -780,7 +779,7 @@ class Crew(FlowTrackable, BaseModel):
self._create_manager_agent()
return self._execute_tasks(self.tasks)
def _create_manager_agent(self):
def _create_manager_agent(self) -> None:
i18n = I18N(prompt_file=self.prompt_file)
if self.manager_agent is not None:
self.manager_agent.allow_delegation = True
@@ -807,7 +806,7 @@ class Crew(FlowTrackable, BaseModel):
def _execute_tasks(
self,
tasks: List[Task],
tasks: list[Task],
start_index: Optional[int] = 0,
was_replayed: bool = False,
) -> CrewOutput:
@@ -821,8 +820,8 @@ class Crew(FlowTrackable, BaseModel):
CrewOutput: Final output of the crew
"""
task_outputs: List[TaskOutput] = []
futures: List[Tuple[Task, Future[TaskOutput], int]] = []
task_outputs: list[TaskOutput] = []
futures: list[tuple[Task, Future[TaskOutput], int]] = []
last_sync_output: Optional[TaskOutput] = None
for task_index, task in enumerate(tasks):
@@ -847,7 +846,7 @@ class Crew(FlowTrackable, BaseModel):
tools_for_task = self._prepare_tools(
agent_to_use,
task,
cast(Union[List[Tool], List[BaseTool]], tools_for_task),
cast(list[Tool] | list[BaseTool], tools_for_task),
)
self._log_task_start(task, agent_to_use.role)
@@ -867,7 +866,7 @@ class Crew(FlowTrackable, BaseModel):
future = task.execute_async(
agent=agent_to_use,
context=context,
tools=cast(List[BaseTool], tools_for_task),
tools=tools_for_task,
)
futures.append((task, future, task_index))
else:
@@ -879,7 +878,7 @@ class Crew(FlowTrackable, BaseModel):
task_output = task.execute_sync(
agent=agent_to_use,
context=context,
tools=cast(List[BaseTool], tools_for_task),
tools=tools_for_task,
)
task_outputs.append(task_output)
self._process_task_result(task, task_output)
@@ -893,8 +892,8 @@ class Crew(FlowTrackable, BaseModel):
def _handle_conditional_task(
self,
task: ConditionalTask,
task_outputs: List[TaskOutput],
futures: List[Tuple[Task, Future[TaskOutput], int]],
task_outputs: list[TaskOutput],
futures: list[tuple[Task, Future[TaskOutput], int]],
task_index: int,
was_replayed: bool,
) -> Optional[TaskOutput]:
@@ -909,7 +908,7 @@ class Crew(FlowTrackable, BaseModel):
f"Skipping conditional task: {task.description}",
color="yellow",
)
skipped_task_output = task.get_skipped_task_output()
skipped_task_output = cast(TaskOutput, task.get_skipped_task_output())
if not was_replayed:
self._store_execution_log(task, skipped_task_output, task_index)
@@ -917,8 +916,8 @@ class Crew(FlowTrackable, BaseModel):
return None
def _prepare_tools(
self, agent: BaseAgent, task: Task, tools: Union[List[Tool], List[BaseTool]]
) -> List[BaseTool]:
self, agent: BaseAgent, task: Task, tools: list[Tool] | list[BaseTool]
) -> list[BaseTool]:
# Add delegation tools if agent allows delegation
if hasattr(agent, "allow_delegation") and getattr(
agent, "allow_delegation", False
@@ -948,7 +947,7 @@ class Crew(FlowTrackable, BaseModel):
tools = self._add_multimodal_tools(agent, tools)
# Return a List[BaseTool] which is compatible with both Task.execute_sync and Task.execute_async
return cast(List[BaseTool], tools)
return cast(list[BaseTool], tools)
def _get_agent_to_use(self, task: Task) -> Optional[BaseAgent]:
if self.process == Process.hierarchical:
@@ -957,12 +956,12 @@ class Crew(FlowTrackable, BaseModel):
def _merge_tools(
self,
existing_tools: Union[List[Tool], List[BaseTool]],
new_tools: Union[List[Tool], List[BaseTool]],
) -> List[BaseTool]:
existing_tools: list[Tool] | list[BaseTool],
new_tools: list[Tool] | list[BaseTool],
) -> list[BaseTool]:
"""Merge new tools into existing tools list, avoiding duplicates by tool name."""
if not new_tools:
return cast(List[BaseTool], existing_tools)
return cast(list[BaseTool], existing_tools)
# Create mapping of tool names to new tools
new_tool_map = {tool.name: tool for tool in new_tools}
@@ -973,41 +972,41 @@ class Crew(FlowTrackable, BaseModel):
# Add all new tools
tools.extend(new_tools)
return cast(List[BaseTool], tools)
return tools
def _inject_delegation_tools(
self,
tools: Union[List[Tool], List[BaseTool]],
tools: list[Tool] | list[BaseTool],
task_agent: BaseAgent,
agents: List[BaseAgent],
) -> List[BaseTool]:
agents: list[BaseAgent],
) -> list[BaseTool]:
if hasattr(task_agent, "get_delegation_tools"):
delegation_tools = task_agent.get_delegation_tools(agents)
# Cast delegation_tools to the expected type for _merge_tools
return self._merge_tools(tools, cast(List[BaseTool], delegation_tools))
return cast(List[BaseTool], tools)
return self._merge_tools(tools, cast(list[BaseTool], delegation_tools))
return cast(list[BaseTool], tools)
def _add_multimodal_tools(
self, agent: BaseAgent, tools: Union[List[Tool], List[BaseTool]]
) -> List[BaseTool]:
self, agent: BaseAgent, tools: list[Tool] | list[BaseTool]
) -> list[BaseTool]:
if hasattr(agent, "get_multimodal_tools"):
multimodal_tools = agent.get_multimodal_tools()
# Cast multimodal_tools to the expected type for _merge_tools
return self._merge_tools(tools, cast(List[BaseTool], multimodal_tools))
return cast(List[BaseTool], tools)
return self._merge_tools(tools, cast(list[BaseTool], multimodal_tools))
return cast(list[BaseTool], tools)
def _add_code_execution_tools(
self, agent: BaseAgent, tools: Union[List[Tool], List[BaseTool]]
) -> List[BaseTool]:
self, agent: BaseAgent, tools: list[Tool] | list[BaseTool]
) -> list[BaseTool]:
if hasattr(agent, "get_code_execution_tools"):
code_tools = agent.get_code_execution_tools()
# Cast code_tools to the expected type for _merge_tools
return self._merge_tools(tools, cast(List[BaseTool], code_tools))
return cast(List[BaseTool], tools)
return self._merge_tools(tools, cast(list[BaseTool], code_tools))
return cast(list[BaseTool], tools)
def _add_delegation_tools(
self, task: Task, tools: Union[List[Tool], List[BaseTool]]
) -> List[BaseTool]:
self, task: Task, tools: list[Tool] | list[BaseTool]
) -> list[BaseTool]:
agents_for_delegation = [agent for agent in self.agents if agent != task.agent]
if len(self.agents) > 1 and len(agents_for_delegation) > 0 and task.agent:
if not tools:
@@ -1015,17 +1014,17 @@ class Crew(FlowTrackable, BaseModel):
tools = self._inject_delegation_tools(
tools, task.agent, agents_for_delegation
)
return cast(List[BaseTool], tools)
return cast(list[BaseTool], tools)
def _log_task_start(self, task: Task, role: str = "None"):
def _log_task_start(self, task: Task, role: str = "None") -> None:
if self.output_log_file:
self._file_handler.log(
task_name=task.name, task=task.description, agent=role, status="started"
)
def _update_manager_tools(
self, task: Task, tools: Union[List[Tool], List[BaseTool]]
) -> List[BaseTool]:
self, task: Task, tools: list[Tool] | list[BaseTool]
) -> list[BaseTool]:
if self.manager_agent:
if task.agent:
tools = self._inject_delegation_tools(tools, task.agent, [task.agent])
@@ -1033,9 +1032,9 @@ class Crew(FlowTrackable, BaseModel):
tools = self._inject_delegation_tools(
tools, self.manager_agent, self.agents
)
return cast(List[BaseTool], tools)
return cast(list[BaseTool], tools)
def _get_context(self, task: Task, task_outputs: List[TaskOutput]) -> str:
def _get_context(self, task: Task, task_outputs: list[TaskOutput]) -> str:
if not task.context:
return ""
@@ -1057,7 +1056,7 @@ class Crew(FlowTrackable, BaseModel):
output=output.raw,
)
def _create_crew_output(self, task_outputs: List[TaskOutput]) -> CrewOutput:
def _create_crew_output(self, task_outputs: list[TaskOutput]) -> CrewOutput:
if not task_outputs:
raise ValueError("No task outputs available to create crew output.")
@@ -1088,10 +1087,10 @@ class Crew(FlowTrackable, BaseModel):
def _process_async_tasks(
self,
futures: List[Tuple[Task, Future[TaskOutput], int]],
futures: list[tuple[Task, Future[TaskOutput], int]],
was_replayed: bool = False,
) -> List[TaskOutput]:
task_outputs: List[TaskOutput] = []
) -> list[TaskOutput]:
task_outputs: list[TaskOutput] = []
for future_task, future, task_index in futures:
task_output = future.result()
task_outputs.append(task_output)
@@ -1102,7 +1101,7 @@ class Crew(FlowTrackable, BaseModel):
return task_outputs
def _find_task_index(
self, task_id: str, stored_outputs: List[Any]
self, task_id: str, stored_outputs: list[Any]
) -> Optional[int]:
return next(
(
@@ -1114,7 +1113,7 @@ class Crew(FlowTrackable, BaseModel):
)
def replay(
self, task_id: str, inputs: Optional[Dict[str, Any]] = None
self, task_id: str, inputs: Optional[dict[str, Any]] = None
) -> CrewOutput:
stored_outputs = self._task_output_handler.load()
if not stored_outputs:
@@ -1155,15 +1154,15 @@ class Crew(FlowTrackable, BaseModel):
return result
def query_knowledge(
self, query: List[str], results_limit: int = 3, score_threshold: float = 0.35
) -> Union[List[Dict[str, Any]], None]:
self, query: list[str], results_limit: int = 3, score_threshold: float = 0.35
) -> list[dict[str, Any]] | None:
if self.knowledge:
return self.knowledge.query(
query, results_limit=results_limit, score_threshold=score_threshold
)
return None
def fetch_inputs(self) -> Set[str]:
def fetch_inputs(self) -> set[str]:
"""
Gathers placeholders (e.g., {something}) referenced in tasks or agents.
Scans each task's 'description' + 'expected_output', and each agent's
@@ -1172,7 +1171,7 @@ class Crew(FlowTrackable, BaseModel):
Returns a set of all discovered placeholder names.
"""
placeholder_pattern = re.compile(r"\{(.+?)\}")
required_inputs: Set[str] = set()
required_inputs: set[str] = set()
# Scan tasks for inputs
for task in self.tasks:
@@ -1188,7 +1187,18 @@ class Crew(FlowTrackable, BaseModel):
return required_inputs
def copy(self):
def copy(
self,
*,
include: Optional[
Set[int] | Set[str] | Mapping[int, Any] | Mapping[str, Any]
] = None,
exclude: Optional[
Set[int] | Set[str] | Mapping[int, Any] | Mapping[str, Any]
] = None,
update: Optional[dict[str, Any]] = None,
deep: bool = True,
) -> "Crew":
"""
Creates a deep copy of the Crew instance.
@@ -1219,7 +1229,7 @@ class Crew(FlowTrackable, BaseModel):
manager_agent = self.manager_agent.copy() if self.manager_agent else None
manager_llm = shallow_copy(self.manager_llm) if self.manager_llm else None
task_mapping = {}
task_mapping: dict[str, Task] = {}
cloned_tasks = []
existing_knowledge_sources = shallow_copy(self.knowledge_sources)
@@ -1274,16 +1284,10 @@ class Crew(FlowTrackable, BaseModel):
if not task.callback:
task.callback = self.task_callback
def _interpolate_inputs(self, inputs: Dict[str, Any]) -> None:
def _interpolate_inputs(self, inputs: dict[str, Any]) -> None:
"""Interpolates the inputs in the tasks and agents."""
[
task.interpolate_inputs_and_add_conversation_history(
# type: ignore # "interpolate_inputs" of "Task" does not return a value (it only ever returns None)
inputs
)
for task in self.tasks
]
# type: ignore # "interpolate_inputs" of "Agent" does not return a value (it only ever returns None)
for task in self.tasks:
task.interpolate_inputs_and_add_conversation_history(inputs)
for agent in self.agents:
agent.interpolate_inputs(inputs)
@@ -1307,8 +1311,8 @@ class Crew(FlowTrackable, BaseModel):
def test(
self,
n_iterations: int,
eval_llm: Union[str, InstanceOf[BaseLLM]],
inputs: Optional[Dict[str, Any]] = None,
eval_llm: str | InstanceOf[BaseLLM],
inputs: Optional[dict[str, Any]] = None,
) -> None:
"""Test and evaluate the Crew with the given inputs for n iterations concurrently using concurrent.futures."""
try:
@@ -1349,7 +1353,7 @@ class Crew(FlowTrackable, BaseModel):
)
raise
def __repr__(self):
def __repr__(self) -> str:
return f"Crew(id={self.id}, process={self.process}, number_of_agents={len(self.agents)}, number_of_tasks={len(self.tasks)})"
def reset_memories(self, command_type: str) -> None:
@@ -1401,7 +1405,9 @@ class Crew(FlowTrackable, BaseModel):
if (system := config.get("system")) is not None:
name = config.get("name")
try:
reset_fn: Callable = cast(Callable, config.get("reset"))
reset_fn: Callable[..., None] = cast(
Callable[..., None], config.get("reset")
)
reset_fn(system)
self._logger.log(
"info",
@@ -1430,7 +1436,9 @@ class Crew(FlowTrackable, BaseModel):
raise RuntimeError(f"{name} memory system is not initialized")
try:
reset_fn: Callable = cast(Callable, config.get("reset"))
reset_fn: Callable[..., None] = cast(
Callable[..., None], config.get("reset")
)
reset_fn(system)
self._logger.log(
"info",
@@ -1441,18 +1449,18 @@ class Crew(FlowTrackable, BaseModel):
f"[Crew ({self.name if self.name else self.id})] Failed to reset {name} memory: {str(e)}"
) from e
def _get_memory_systems(self):
def _get_memory_systems(self) -> dict[str, dict[str, Any]]:
"""Get all available memory systems with their configuration.
Returns:
Dict containing all memory systems with their reset functions and display names.
"""
def default_reset(memory):
return memory.reset()
def default_reset(memory: Any) -> None:
memory.reset()
def knowledge_reset(memory):
return self.reset_knowledge(memory)
def knowledge_reset(memory: Any) -> None:
self.reset_knowledge(memory)
# Get knowledge for agents
agent_knowledges = [
@@ -1506,12 +1514,12 @@ class Crew(FlowTrackable, BaseModel):
},
}
def reset_knowledge(self, knowledges: List[Knowledge]) -> None:
def reset_knowledge(self, knowledges: list[Knowledge]) -> None:
"""Reset crew and agent knowledge storage."""
for ks in knowledges:
ks.reset()
def _set_allow_crewai_trigger_context_for_first_task(self):
def _set_allow_crewai_trigger_context_for_first_task(self) -> None:
crewai_trigger_payload = self._inputs and self._inputs.get(
"crewai_trigger_payload"
)

View File

@@ -1,35 +1,25 @@
import asyncio
import inspect
import uuid
from collections.abc import Callable
from typing import (
Any,
Callable,
Dict,
List,
Optional,
Tuple,
Type,
Union,
cast,
get_args,
get_origin,
)
try:
from typing import Self
except ImportError:
from typing_extensions import Self
from pydantic import (
UUID4,
BaseModel,
Field,
InstanceOf,
PrivateAttr,
model_validator,
field_validator,
model_validator,
)
from typing_extensions import Self
from crewai.agents.agent_builder.base_agent import BaseAgent
from crewai.agents.agent_builder.utilities.base_token_process import TokenProcess
@@ -39,12 +29,19 @@ from crewai.agents.parser import (
AgentFinish,
OutputParserException,
)
from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.agent_events import (
LiteAgentExecutionCompletedEvent,
LiteAgentExecutionErrorEvent,
LiteAgentExecutionStartedEvent,
)
from crewai.events.types.logging_events import AgentLogsExecutionEvent
from crewai.flow.flow_trackable import FlowTrackable
from crewai.llm import LLM, BaseLLM
from crewai.llm import LLM
from crewai.llms.base_llm import BaseLLM
from crewai.tools.base_tool import BaseTool
from crewai.tools.structured_tool import CrewStructuredTool
from crewai.utilities import I18N
from crewai.utilities.guardrail import process_guardrail
from crewai.utilities.agent_utils import (
enforce_rpm_limit,
format_message_for_llm,
@@ -62,14 +59,7 @@ from crewai.utilities.agent_utils import (
render_text_description_and_args,
)
from crewai.utilities.converter import generate_model_description
from crewai.events.types.logging_events import AgentLogsExecutionEvent
from crewai.events.types.agent_events import (
LiteAgentExecutionCompletedEvent,
LiteAgentExecutionErrorEvent,
LiteAgentExecutionStartedEvent,
)
from crewai.events.event_bus import crewai_event_bus
from crewai.utilities.guardrail import process_guardrail
from crewai.utilities.llm_utils import create_llm
from crewai.utilities.printer import Printer
from crewai.utilities.token_counter_callback import TokenCalcHandler
@@ -86,11 +76,11 @@ class LiteAgentOutput(BaseModel):
description="Pydantic output of the agent", default=None
)
agent_role: str = Field(description="Role of the agent that produced this output")
usage_metrics: Optional[Dict[str, Any]] = Field(
usage_metrics: Optional[dict[str, Any]] = Field(
description="Token usage metrics for this execution", default=None
)
def to_dict(self) -> Dict[str, Any]:
def to_dict(self) -> dict[str, Any]:
"""Convert pydantic_output to a dictionary."""
if self.pydantic:
return self.pydantic.model_dump()
@@ -130,10 +120,10 @@ class LiteAgent(FlowTrackable, BaseModel):
role: str = Field(description="Role of the agent")
goal: str = Field(description="Goal of the agent")
backstory: str = Field(description="Backstory of the agent")
llm: Optional[Union[str, InstanceOf[BaseLLM], Any]] = Field(
llm: Optional[str | InstanceOf[BaseLLM] | Any] = Field(
default=None, description="Language model that will run the agent"
)
tools: List[BaseTool] = Field(
tools: list[BaseTool] = Field(
default_factory=list, description="Tools at agent's disposal"
)
@@ -159,29 +149,27 @@ class LiteAgent(FlowTrackable, BaseModel):
i18n: I18N = Field(default=I18N(), description="Internationalization settings.")
# Output and Formatting Properties
response_format: Optional[Type[BaseModel]] = Field(
response_format: Optional[type[BaseModel]] = Field(
default=None, description="Pydantic model for structured output"
)
verbose: bool = Field(
default=False, description="Whether to print execution details"
)
callbacks: List[Callable] = Field(
callbacks: list[Callable[..., Any]] = Field(
default=[], description="Callbacks to be used for the agent"
)
# Guardrail Properties
guardrail: Optional[Union[Callable[[LiteAgentOutput], Tuple[bool, Any]], str]] = (
Field(
default=None,
description="Function or string description of a guardrail to validate agent output",
)
guardrail: Optional[Callable[[LiteAgentOutput], tuple[bool, Any]] | str] = Field(
default=None,
description="Function or string description of a guardrail to validate agent output",
)
guardrail_max_retries: int = Field(
default=3, description="Maximum number of retries when guardrail fails"
)
# State and Results
tools_results: List[Dict[str, Any]] = Field(
tools_results: list[dict[str, Any]] = Field(
default=[], description="Results of the tools used by the agent."
)
@@ -190,18 +178,20 @@ class LiteAgent(FlowTrackable, BaseModel):
default=None, description="Reference to the agent that created this LiteAgent"
)
# Private Attributes
_parsed_tools: List[CrewStructuredTool] = PrivateAttr(default_factory=list)
_parsed_tools: list[CrewStructuredTool] = PrivateAttr(default_factory=list)
_token_process: TokenProcess = PrivateAttr(default_factory=TokenProcess)
_cache_handler: CacheHandler = PrivateAttr(default_factory=CacheHandler)
_key: str = PrivateAttr(default_factory=lambda: str(uuid.uuid4()))
_messages: List[Dict[str, str]] = PrivateAttr(default_factory=list)
_messages: list[dict[str, str]] = PrivateAttr(default_factory=list)
_iterations: int = PrivateAttr(default=0)
_printer: Printer = PrivateAttr(default_factory=Printer)
_guardrail: Optional[Callable] = PrivateAttr(default=None)
_guardrail: Optional[Callable[[LiteAgentOutput], tuple[bool, Any]]] = PrivateAttr(
default=None
)
_guardrail_retry_count: int = PrivateAttr(default=0)
@model_validator(mode="after")
def setup_llm(self):
def setup_llm(self) -> Self:
"""Set up the LLM and other components after initialization."""
self.llm = create_llm(self.llm)
if not isinstance(self.llm, BaseLLM):
@@ -216,7 +206,7 @@ class LiteAgent(FlowTrackable, BaseModel):
return self
@model_validator(mode="after")
def parse_tools(self):
def parse_tools(self) -> Self:
"""Parse the tools and convert them to CrewStructuredTool instances."""
self._parsed_tools = parse_tools(self.tools)
@@ -241,8 +231,8 @@ class LiteAgent(FlowTrackable, BaseModel):
@field_validator("guardrail", mode="before")
@classmethod
def validate_guardrail_function(
cls, v: Optional[Union[Callable, str]]
) -> Optional[Union[Callable, str]]:
cls, v: Optional[Callable[[Any], tuple[bool, Any]] | str]
) -> Optional[Callable[[Any], tuple[bool, Any]] | str]:
"""Validate that the guardrail function has the correct signature.
If v is a callable, validate that it has the correct signature.
@@ -267,7 +257,7 @@ class LiteAgent(FlowTrackable, BaseModel):
# Check return annotation if present
if sig.return_annotation is not sig.empty:
if sig.return_annotation == Tuple[bool, Any]:
if sig.return_annotation == tuple[bool, Any]:
return v
origin = get_origin(sig.return_annotation)
@@ -290,7 +280,7 @@ class LiteAgent(FlowTrackable, BaseModel):
"""Return the original role for compatibility with tool interfaces."""
return self.role
def kickoff(self, messages: Union[str, List[Dict[str, str]]]) -> LiteAgentOutput:
def kickoff(self, messages: str | list[dict[str, str]]) -> LiteAgentOutput:
"""
Execute the agent with the given messages.
@@ -338,7 +328,7 @@ class LiteAgent(FlowTrackable, BaseModel):
)
raise e
def _execute_core(self, agent_info: Dict[str, Any]) -> LiteAgentOutput:
def _execute_core(self, agent_info: dict[str, Any]) -> LiteAgentOutput:
# Emit event for agent execution start
crewai_event_bus.emit(
self,
@@ -428,7 +418,7 @@ class LiteAgent(FlowTrackable, BaseModel):
return output
async def kickoff_async(
self, messages: Union[str, List[Dict[str, str]]]
self, messages: str | list[dict[str, str]]
) -> LiteAgentOutput:
"""
Execute the agent asynchronously with the given messages.
@@ -475,8 +465,8 @@ class LiteAgent(FlowTrackable, BaseModel):
return base_prompt
def _format_messages(
self, messages: Union[str, List[Dict[str, str]]]
) -> List[Dict[str, str]]:
self, messages: str | list[dict[str, str]]
) -> list[dict[str, str]]:
"""Format messages for the LLM."""
if isinstance(messages, str):
messages = [{"role": "user", "content": messages}]
@@ -582,7 +572,7 @@ class LiteAgent(FlowTrackable, BaseModel):
self._show_logs(formatted_answer)
return formatted_answer
def _show_logs(self, formatted_answer: Union[AgentAction, AgentFinish]):
def _show_logs(self, formatted_answer: AgentAction | AgentFinish) -> None:
"""Show logs for the agent's execution."""
crewai_event_bus.emit(
self,

View File

@@ -1,10 +1,11 @@
import os
from typing import Any, Dict, List
from collections import defaultdict
from typing import Any
from mem0 import Memory, MemoryClient
from crewai.utilities.chromadb import sanitize_collection_name
from crewai.memory.storage.interface import Storage
from crewai.utilities.chromadb import sanitize_collection_name
MAX_AGENT_ID_LENGTH_MEM0 = 255
@@ -13,6 +14,7 @@ class Mem0Storage(Storage):
"""
Extends Storage to handle embedding and searching across entities using Mem0.
"""
def __init__(self, type, crew=None, config=None):
super().__init__()
@@ -86,21 +88,21 @@ class Mem0Storage(Storage):
return filter
def save(self, value: Any, metadata: Dict[str, Any]) -> None:
def save(self, value: Any, metadata: dict[str, Any]) -> None:
user_id = self.config.get("user_id", "")
assistant_message = [{"role" : "assistant","content" : value}]
assistant_message = [{"role": "assistant", "content": value}]
base_metadata = {
"short_term": "short_term",
"long_term": "long_term",
"entities": "entity",
"external": "external"
"external": "external",
}
# Shared base params
params: dict[str, Any] = {
"metadata": {"type": base_metadata[self.memory_type], **metadata},
"infer": self.infer
"infer": self.infer,
}
# MemoryClient-specific overrides
@@ -121,13 +123,15 @@ class Mem0Storage(Storage):
self.memory.add(assistant_message, **params)
def search(self,query: str,limit: int = 3,score_threshold: float = 0.35) -> List[Any]:
def search(
self, query: str, limit: int = 3, score_threshold: float = 0.35
) -> list[Any]:
params = {
"query": query,
"limit": limit,
"version": "v2",
"output_format": "v1.1"
}
"output_format": "v1.1",
}
if user_id := self.config.get("user_id", ""):
params["user_id"] = user_id
@@ -148,10 +152,10 @@ class Mem0Storage(Storage):
# automatically when the crew is created.
params["filters"] = self._create_filter_for_search()
params['threshold'] = score_threshold
params["threshold"] = score_threshold
if isinstance(self.memory, Memory):
del params["metadata"], params["version"], params['output_format']
del params["metadata"], params["version"], params["output_format"]
if params.get("run_id"):
del params["run_id"]
@@ -160,10 +164,10 @@ class Mem0Storage(Storage):
# This makes it compatible for Contextual Memory to retrieve
for result in results["results"]:
result["context"] = result["memory"]
return [r for r in results["results"]]
def reset(self):
def reset(self) -> None:
if self.memory:
self.memory.reset()
@@ -180,4 +184,6 @@ class Mem0Storage(Storage):
agents = self.crew.agents
agents = [self._sanitize_role(agent.role) for agent in agents]
agents = "_".join(agents)
return sanitize_collection_name(name=agents, max_collection_length=MAX_AGENT_ID_LENGTH_MEM0)
return sanitize_collection_name(
name=agents, max_collection_length=MAX_AGENT_ID_LENGTH_MEM0
)

View File

@@ -13,14 +13,14 @@ class CacheTools(BaseModel):
default_factory=CacheHandler,
)
def tool(self):
def tool(self) -> CrewStructuredTool:
return CrewStructuredTool.from_function(
func=self.hit_cache,
name=self.name,
description="Reads directly from the cache",
)
def hit_cache(self, key):
def hit_cache(self, key: str) -> str:
split = key.split("tool:")
tool = split[1].split("|input:")[0].strip()
tool_input = split[1].split("|input:")[1].strip()

View File

@@ -5,10 +5,10 @@ import time
from difflib import SequenceMatcher
from json import JSONDecodeError
from textwrap import dedent
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
from typing import TYPE_CHECKING, Any, Optional, Union
import json5
from json_repair import repair_json
from json_repair import repair_json # type: ignore[import-untyped]
from crewai.agents.tools_handler import ToolsHandler
from crewai.events.event_bus import crewai_event_bus
@@ -100,7 +100,7 @@ class ToolUsage:
self._max_parsing_attempts = 2
self._remember_format_after_usages = 4
def parse_tool_calling(self, tool_string: str):
def parse_tool_calling(self, tool_string: str) -> Any:
"""Parse the tool string and return the tool calling."""
return self._tool_calling(tool_string)
@@ -149,7 +149,21 @@ class ToolUsage:
tool: CrewStructuredTool,
calling: ToolCalling | InstructorToolCalling,
) -> str:
if self._check_tool_repeated_usage(calling=calling): # type: ignore # _check_tool_repeated_usage of "ToolUsage" does not return a value (it only ever returns None)
"""Use a tool with the given calling information.
Args:
tool_string: The string representation of the tool call.
tool: The CrewStructuredTool instance to use.
calling: The tool calling information.
Returns:
The formatted result of the tool usage.
Notes:
TODO: Investigate why BaseAgent/LiteAgent don't have fingerprint attribute.
Currently using hasattr check as a workaround (lines 179-180).
"""
if self._check_tool_repeated_usage(calling=calling):
try:
result = self._i18n.errors("task_repeated_usage").format(
tool_names=self.tools_names
@@ -159,8 +173,8 @@ class ToolUsage:
tool_name=tool.name,
attempts=self._run_attempts,
)
result = self._format_result(result=result) # type: ignore # "_format_result" of "ToolUsage" does not return a value (it only ever returns None)
return result # type: ignore # Fix the return type of this function
result = self._format_result(result=result)
return result
except Exception:
if self.task:
@@ -176,7 +190,7 @@ class ToolUsage:
"agent": self.agent,
}
if self.agent.fingerprint:
if hasattr(self.agent, "fingerprint") and self.agent.fingerprint:
event_data.update(self.agent.fingerprint)
if self.task:
event_data["task_name"] = self.task.name or self.task.description
@@ -264,19 +278,19 @@ class ToolUsage:
self._printer.print(
content=f"\n\n{error_message}\n", color="red"
)
return error # type: ignore # No return value expected
return error
if self.task:
self.task.increment_tools_errors()
return self.use(calling=calling, tool_string=tool_string) # type: ignore # No return value expected
return self.use(calling=calling, tool_string=tool_string)
if self.tools_handler:
should_cache = True
if (
hasattr(available_tool, "cache_function")
and available_tool.cache_function # type: ignore # Item "None" of "Any | None" has no attribute "cache_function"
and available_tool.cache_function
):
should_cache = available_tool.cache_function( # type: ignore # Item "None" of "Any | None" has no attribute "cache_function"
should_cache = available_tool.cache_function(
calling.arguments, result
)
@@ -288,7 +302,7 @@ class ToolUsage:
tool_name=tool.name,
attempts=self._run_attempts,
)
result = self._format_result(result=result) # type: ignore # "_format_result" of "ToolUsage" does not return a value (it only ever returns None)
result = self._format_result(result=result)
data = {
"result": result,
"tool_name": tool.name,
@@ -505,7 +519,7 @@ class ToolUsage:
self.task.increment_tools_errors()
if self.agent and self.agent.verbose:
self._printer.print(content=f"\n\n{e}\n", color="red")
return ToolUsageErrorException( # type: ignore # Incompatible return value type (got "ToolUsageErrorException", expected "ToolCalling | InstructorToolCalling")
return ToolUsageErrorException(
f"{self._i18n.errors('tool_usage_error').format(error=e)}\nMoving on then. {self._i18n.slice('format').format(tool_names=self.tools_names)}"
)
return self._tool_calling(tool_string)
@@ -564,7 +578,7 @@ class ToolUsage:
# If all parsing attempts fail, raise an error
raise Exception(error_message)
def _emit_validate_input_error(self, final_error: str):
def _emit_validate_input_error(self, final_error: str) -> None:
tool_selection_data = {
"agent_key": getattr(self.agent, "key", None) if self.agent else None,
"agent_role": getattr(self.agent, "role", None) if self.agent else None,
@@ -617,7 +631,20 @@ class ToolUsage:
def _prepare_event_data(
self, tool: Any, tool_calling: ToolCalling | InstructorToolCalling
) -> dict:
) -> dict[str, Any]:
"""Prepare event data for tool usage events.
Args:
tool: The tool being used.
tool_calling: The tool calling information containing arguments.
Returns:
A dictionary containing event data for tool usage tracking.
Notes:
TODO: Create a better type representation for the return value,
possibly using TypedDict or a dataclass for stronger typing.
"""
event_data = {
"run_attempts": self._run_attempts,
"delegations": self.task.delegations if self.task else 0,

View File

@@ -1,5 +1,5 @@
from datetime import datetime
from typing import Any, Dict, List, Optional
from typing import Any, Optional
from pydantic import BaseModel, Field
@@ -16,10 +16,10 @@ class ExecutionLog(BaseModel):
task_id: str
expected_output: Optional[str] = None
output: Dict[str, Any]
output: dict[str, Any]
timestamp: datetime = Field(default_factory=datetime.now)
task_index: int
inputs: Dict[str, Any] = Field(default_factory=dict)
inputs: dict[str, Any] = Field(default_factory=dict)
was_replayed: bool = False
def __getitem__(self, key: str) -> Any:
@@ -33,7 +33,7 @@ class TaskOutputStorageHandler:
def __init__(self) -> None:
self.storage = KickoffTaskOutputsSQLiteStorage()
def update(self, task_index: int, log: Dict[str, Any]):
def update(self, task_index: int, log: dict[str, Any]):
saved_outputs = self.load()
if saved_outputs is None:
raise ValueError("Logs cannot be None")
@@ -56,16 +56,16 @@ class TaskOutputStorageHandler:
def add(
self,
task: Task,
output: Dict[str, Any],
output: dict[str, Any],
task_index: int,
inputs: Dict[str, Any] | None = None,
inputs: dict[str, Any] | None = None,
was_replayed: bool = False,
):
inputs = inputs or {}
self.storage.add(task, output, task_index, was_replayed, inputs)
def reset(self):
def reset(self) -> None:
self.storage.delete_all()
def load(self) -> Optional[List[Dict[str, Any]]]:
def load(self) -> Optional[list[dict[str, Any]]]:
return self.storage.load()