mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-08 15:48:29 +00:00
fix: improve type annotations across codebase
This commit is contained in:
@@ -1,29 +1,42 @@
|
||||
import shutil
|
||||
import subprocess
|
||||
import time
|
||||
from collections.abc import Callable, Sequence
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
Dict,
|
||||
List,
|
||||
Literal,
|
||||
Optional,
|
||||
Sequence,
|
||||
Tuple,
|
||||
Type,
|
||||
Union,
|
||||
)
|
||||
|
||||
from pydantic import Field, InstanceOf, PrivateAttr, model_validator
|
||||
from typing_extensions import Self
|
||||
|
||||
from crewai.agents.cache.cache_handler import CacheHandler
|
||||
from crewai.agents.agent_builder.base_agent import BaseAgent
|
||||
from crewai.agents.cache.cache_handler import CacheHandler
|
||||
from crewai.agents.crew_agent_executor import CrewAgentExecutor
|
||||
from crewai.events.event_bus import crewai_event_bus
|
||||
from crewai.events.types.agent_events import (
|
||||
AgentExecutionCompletedEvent,
|
||||
AgentExecutionErrorEvent,
|
||||
AgentExecutionStartedEvent,
|
||||
)
|
||||
from crewai.events.types.knowledge_events import (
|
||||
KnowledgeQueryCompletedEvent,
|
||||
KnowledgeQueryFailedEvent,
|
||||
KnowledgeQueryStartedEvent,
|
||||
KnowledgeRetrievalCompletedEvent,
|
||||
KnowledgeRetrievalStartedEvent,
|
||||
KnowledgeSearchQueryFailedEvent,
|
||||
)
|
||||
from crewai.events.types.memory_events import (
|
||||
MemoryRetrievalCompletedEvent,
|
||||
MemoryRetrievalStartedEvent,
|
||||
)
|
||||
from crewai.knowledge.knowledge import Knowledge
|
||||
from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource
|
||||
from crewai.knowledge.utils.knowledge_utils import extract_knowledge_context
|
||||
from crewai.lite_agent import LiteAgent, LiteAgentOutput
|
||||
from crewai.llm import BaseLLM
|
||||
from crewai.llms.base_llm import BaseLLM
|
||||
from crewai.memory.contextual.contextual_memory import ContextualMemory
|
||||
from crewai.security import Fingerprint
|
||||
from crewai.task import Task
|
||||
@@ -38,24 +51,6 @@ from crewai.utilities.agent_utils import (
|
||||
)
|
||||
from crewai.utilities.constants import TRAINED_AGENTS_DATA_FILE, TRAINING_DATA_FILE
|
||||
from crewai.utilities.converter import generate_model_description
|
||||
from crewai.events.types.agent_events import (
|
||||
AgentExecutionCompletedEvent,
|
||||
AgentExecutionErrorEvent,
|
||||
AgentExecutionStartedEvent,
|
||||
)
|
||||
from crewai.events.event_bus import crewai_event_bus
|
||||
from crewai.events.types.memory_events import (
|
||||
MemoryRetrievalStartedEvent,
|
||||
MemoryRetrievalCompletedEvent,
|
||||
)
|
||||
from crewai.events.types.knowledge_events import (
|
||||
KnowledgeQueryCompletedEvent,
|
||||
KnowledgeQueryFailedEvent,
|
||||
KnowledgeQueryStartedEvent,
|
||||
KnowledgeRetrievalCompletedEvent,
|
||||
KnowledgeRetrievalStartedEvent,
|
||||
KnowledgeSearchQueryFailedEvent,
|
||||
)
|
||||
from crewai.utilities.llm_utils import create_llm
|
||||
from crewai.utilities.token_counter_callback import TokenCalcHandler
|
||||
from crewai.utilities.training_handler import CrewTrainingHandler
|
||||
@@ -101,10 +96,10 @@ class Agent(BaseAgent):
|
||||
default=True,
|
||||
description="Use system prompt for the agent.",
|
||||
)
|
||||
llm: Union[str, InstanceOf[BaseLLM], Any] = Field(
|
||||
llm: str | InstanceOf[BaseLLM] | Any = Field(
|
||||
description="Language model that will run the agent.", default=None
|
||||
)
|
||||
function_calling_llm: Optional[Union[str, InstanceOf[BaseLLM], Any]] = Field(
|
||||
function_calling_llm: Optional[str | InstanceOf[BaseLLM] | Any] = Field(
|
||||
description="Language model that will run the agent.", default=None
|
||||
)
|
||||
system_template: Optional[str] = Field(
|
||||
@@ -151,7 +146,7 @@ class Agent(BaseAgent):
|
||||
default=None,
|
||||
description="Maximum number of reasoning attempts before executing the task. If None, will try until ready.",
|
||||
)
|
||||
embedder: Optional[Dict[str, Any]] = Field(
|
||||
embedder: Optional[dict[str, Any]] = Field(
|
||||
default=None,
|
||||
description="Embedder configuration for the agent.",
|
||||
)
|
||||
@@ -171,7 +166,7 @@ class Agent(BaseAgent):
|
||||
default=None,
|
||||
description="The Agent's role to be used from your repository.",
|
||||
)
|
||||
guardrail: Optional[Union[Callable[[Any], Tuple[bool, Any]], str]] = Field(
|
||||
guardrail: Optional[Callable[[Any], tuple[bool, Any]] | str] = Field(
|
||||
default=None,
|
||||
description="Function or string description of a guardrail to validate agent output",
|
||||
)
|
||||
@@ -180,13 +175,14 @@ class Agent(BaseAgent):
|
||||
)
|
||||
|
||||
@model_validator(mode="before")
|
||||
def validate_from_repository(cls, v):
|
||||
@classmethod
|
||||
def validate_from_repository(cls, v: Any) -> Any:
|
||||
if v is not None and (from_repository := v.get("from_repository")):
|
||||
return load_agent_from_repository(from_repository) | v
|
||||
return v
|
||||
|
||||
@model_validator(mode="after")
|
||||
def post_init_setup(self):
|
||||
def post_init_setup(self) -> Self:
|
||||
self.agent_ops_agent_name = self.role
|
||||
|
||||
self.llm = create_llm(self.llm)
|
||||
@@ -203,12 +199,12 @@ class Agent(BaseAgent):
|
||||
|
||||
return self
|
||||
|
||||
def _setup_agent_executor(self):
|
||||
def _setup_agent_executor(self) -> None:
|
||||
if not self.cache_handler:
|
||||
self.cache_handler = CacheHandler()
|
||||
self.set_cache_handler(self.cache_handler)
|
||||
|
||||
def set_knowledge(self, crew_embedder: Optional[Dict[str, Any]] = None):
|
||||
def set_knowledge(self, crew_embedder: Optional[dict[str, Any]] = None) -> None:
|
||||
try:
|
||||
if self.embedder is None and crew_embedder:
|
||||
self.embedder = crew_embedder
|
||||
@@ -245,7 +241,7 @@ class Agent(BaseAgent):
|
||||
self,
|
||||
task: Task,
|
||||
context: Optional[str] = None,
|
||||
tools: Optional[List[BaseTool]] = None,
|
||||
tools: Optional[list[BaseTool]] = None,
|
||||
) -> str:
|
||||
"""Execute a task with the agent.
|
||||
|
||||
@@ -492,7 +488,7 @@ class Agent(BaseAgent):
|
||||
# If there was any tool in self.tools_results that had result_as_answer
|
||||
# set to True, return the results of the last tool that had
|
||||
# result_as_answer set to True
|
||||
for tool_result in self.tools_results: # type: ignore # Item "None" of "list[Any] | None" has no attribute "__iter__" (not iterable)
|
||||
for tool_result in self.tools_results:
|
||||
if tool_result.get("result_as_answer", False):
|
||||
result = tool_result["result"]
|
||||
crewai_event_bus.emit(
|
||||
@@ -554,14 +550,14 @@ class Agent(BaseAgent):
|
||||
)["output"]
|
||||
|
||||
def create_agent_executor(
|
||||
self, tools: Optional[List[BaseTool]] = None, task=None
|
||||
self, tools: Optional[list[BaseTool]] = None, task: Optional[Task] = None
|
||||
) -> None:
|
||||
"""Create an agent executor for the agent.
|
||||
|
||||
Returns:
|
||||
An instance of the CrewAgentExecutor class.
|
||||
"""
|
||||
raw_tools: List[BaseTool] = tools or self.tools or []
|
||||
raw_tools: list[BaseTool] = tools or self.tools or []
|
||||
parsed_tools = parse_tools(raw_tools)
|
||||
|
||||
prompt = Prompts(
|
||||
@@ -603,7 +599,7 @@ class Agent(BaseAgent):
|
||||
callbacks=[TokenCalcHandler(self._token_process)],
|
||||
)
|
||||
|
||||
def get_delegation_tools(self, agents: List[BaseAgent]):
|
||||
def get_delegation_tools(self, agents: list[BaseAgent]) -> list[BaseTool]:
|
||||
agent_tools = AgentTools(agents=agents)
|
||||
tools = agent_tools.tools()
|
||||
return tools
|
||||
@@ -613,7 +609,7 @@ class Agent(BaseAgent):
|
||||
|
||||
return [AddImageTool()]
|
||||
|
||||
def get_code_execution_tools(self):
|
||||
def get_code_execution_tools(self) -> list[BaseTool]:
|
||||
try:
|
||||
from crewai_tools import CodeInterpreterTool # type: ignore
|
||||
|
||||
@@ -625,7 +621,9 @@ class Agent(BaseAgent):
|
||||
"info", "Coding tools not available. Install crewai_tools. "
|
||||
)
|
||||
|
||||
def get_output_converter(self, llm, text, model, instructions):
|
||||
def get_output_converter(
|
||||
self, llm: BaseLLM, text: str, model: str, instructions: str
|
||||
) -> Converter:
|
||||
return Converter(llm=llm, text=text, model=model, instructions=instructions)
|
||||
|
||||
def _training_handler(self, task_prompt: str) -> str:
|
||||
@@ -654,7 +652,7 @@ class Agent(BaseAgent):
|
||||
)
|
||||
return task_prompt
|
||||
|
||||
def _render_text_description(self, tools: List[Any]) -> str:
|
||||
def _render_text_description(self, tools: list[Any]) -> str:
|
||||
"""Render the tool name and description in plain text.
|
||||
|
||||
Output will be in the format of:
|
||||
@@ -673,7 +671,7 @@ class Agent(BaseAgent):
|
||||
|
||||
return description
|
||||
|
||||
def _inject_date_to_task(self, task):
|
||||
def _inject_date_to_task(self, task: str) -> str:
|
||||
"""Inject the current date into the task description if inject_date is enabled."""
|
||||
if self.inject_date:
|
||||
from datetime import datetime
|
||||
@@ -723,7 +721,7 @@ class Agent(BaseAgent):
|
||||
f"Docker is not running. Please start Docker to use code execution with agent: {self.role}"
|
||||
)
|
||||
|
||||
def __repr__(self):
|
||||
def __repr__(self) -> str:
|
||||
return f"Agent(role={self.role}, goal={self.goal}, backstory={self.backstory})"
|
||||
|
||||
@property
|
||||
@@ -736,7 +734,7 @@ class Agent(BaseAgent):
|
||||
"""
|
||||
return self.security_config.fingerprint
|
||||
|
||||
def set_fingerprint(self, fingerprint: Fingerprint):
|
||||
def set_fingerprint(self, fingerprint: Fingerprint) -> None:
|
||||
self.security_config.fingerprint = fingerprint
|
||||
|
||||
def _get_knowledge_search_query(self, task_prompt: str) -> str | None:
|
||||
@@ -796,8 +794,8 @@ class Agent(BaseAgent):
|
||||
|
||||
def kickoff(
|
||||
self,
|
||||
messages: Union[str, List[Dict[str, str]]],
|
||||
response_format: Optional[Type[Any]] = None,
|
||||
messages: str | list[dict[str, str]],
|
||||
response_format: Optional[type[Any]] = None,
|
||||
) -> LiteAgentOutput:
|
||||
"""
|
||||
Execute the agent with the given messages using a LiteAgent instance.
|
||||
@@ -836,8 +834,8 @@ class Agent(BaseAgent):
|
||||
|
||||
async def kickoff_async(
|
||||
self,
|
||||
messages: Union[str, List[Dict[str, str]]],
|
||||
response_format: Optional[Type[Any]] = None,
|
||||
messages: str | list[dict[str, str]],
|
||||
response_format: Optional[type[Any]] = None,
|
||||
) -> LiteAgentOutput:
|
||||
"""
|
||||
Execute the agent asynchronously with the given messages using a LiteAgent instance.
|
||||
|
||||
@@ -3,26 +3,18 @@ import json
|
||||
import re
|
||||
import uuid
|
||||
import warnings
|
||||
from collections.abc import Callable, Mapping, Set
|
||||
from concurrent.futures import Future
|
||||
from copy import copy as shallow_copy
|
||||
from hashlib import md5
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
Dict,
|
||||
List,
|
||||
Optional,
|
||||
Set,
|
||||
Tuple,
|
||||
Union,
|
||||
cast,
|
||||
)
|
||||
|
||||
from opentelemetry import baggage
|
||||
from opentelemetry.context import attach, detach
|
||||
|
||||
from crewai.utilities.crew.models import CrewContext
|
||||
|
||||
from pydantic import (
|
||||
UUID4,
|
||||
BaseModel,
|
||||
@@ -34,15 +26,36 @@ from pydantic import (
|
||||
model_validator,
|
||||
)
|
||||
from pydantic_core import PydanticCustomError
|
||||
from typing_extensions import Self
|
||||
|
||||
from crewai.agent import Agent
|
||||
from crewai.agents.agent_builder.base_agent import BaseAgent
|
||||
from crewai.agents.cache.cache_handler import CacheHandler
|
||||
from crewai.crews.crew_output import CrewOutput
|
||||
from crewai.events.event_bus import crewai_event_bus
|
||||
from crewai.events.event_listener import EventListener
|
||||
from crewai.events.listeners.tracing.trace_listener import (
|
||||
TraceCollectionListener,
|
||||
)
|
||||
from crewai.events.listeners.tracing.utils import (
|
||||
is_tracing_enabled,
|
||||
)
|
||||
from crewai.events.types.crew_events import (
|
||||
CrewKickoffCompletedEvent,
|
||||
CrewKickoffFailedEvent,
|
||||
CrewKickoffStartedEvent,
|
||||
CrewTestCompletedEvent,
|
||||
CrewTestFailedEvent,
|
||||
CrewTestStartedEvent,
|
||||
CrewTrainCompletedEvent,
|
||||
CrewTrainFailedEvent,
|
||||
CrewTrainStartedEvent,
|
||||
)
|
||||
from crewai.flow.flow_trackable import FlowTrackable
|
||||
from crewai.knowledge.knowledge import Knowledge
|
||||
from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource
|
||||
from crewai.llm import LLM, BaseLLM
|
||||
from crewai.llm import LLM
|
||||
from crewai.llms.base_llm import BaseLLM
|
||||
from crewai.memory.entity.entity_memory import EntityMemory
|
||||
from crewai.memory.external.external_memory import ExternalMemory
|
||||
from crewai.memory.long_term.long_term_memory import LongTermMemory
|
||||
@@ -57,29 +70,9 @@ from crewai.tools.base_tool import BaseTool, Tool
|
||||
from crewai.types.usage_metrics import UsageMetrics
|
||||
from crewai.utilities import I18N, FileHandler, Logger, RPMController
|
||||
from crewai.utilities.constants import NOT_SPECIFIED, TRAINING_DATA_FILE
|
||||
from crewai.utilities.crew.models import CrewContext
|
||||
from crewai.utilities.evaluators.crew_evaluator_handler import CrewEvaluator
|
||||
from crewai.utilities.evaluators.task_evaluator import TaskEvaluator
|
||||
from crewai.events.types.crew_events import (
|
||||
CrewKickoffCompletedEvent,
|
||||
CrewKickoffFailedEvent,
|
||||
CrewKickoffStartedEvent,
|
||||
CrewTestCompletedEvent,
|
||||
CrewTestFailedEvent,
|
||||
CrewTestStartedEvent,
|
||||
CrewTrainCompletedEvent,
|
||||
CrewTrainFailedEvent,
|
||||
CrewTrainStartedEvent,
|
||||
)
|
||||
from crewai.events.event_bus import crewai_event_bus
|
||||
from crewai.events.event_listener import EventListener
|
||||
from crewai.events.listeners.tracing.trace_listener import (
|
||||
TraceCollectionListener,
|
||||
)
|
||||
|
||||
|
||||
from crewai.events.listeners.tracing.utils import (
|
||||
is_tracing_enabled,
|
||||
)
|
||||
from crewai.utilities.formatter import (
|
||||
aggregate_raw_outputs_from_task_outputs,
|
||||
aggregate_raw_outputs_from_tasks,
|
||||
@@ -116,9 +109,12 @@ class Crew(FlowTrackable, BaseModel):
|
||||
planning: Plan the crew execution and add the plan to the crew.
|
||||
chat_llm: The language model used for orchestrating chat interactions with the crew.
|
||||
security_config: Security configuration for the crew, including fingerprinting.
|
||||
|
||||
Notes:
|
||||
TODO: Improve the embedder type from dict[str, Any] to a more specific TypedDict or dataclass.
|
||||
"""
|
||||
|
||||
__hash__ = object.__hash__ # type: ignore
|
||||
__hash__ = object.__hash__
|
||||
_execution_span: Any = PrivateAttr()
|
||||
_rpm_controller: RPMController = PrivateAttr()
|
||||
_logger: Logger = PrivateAttr()
|
||||
@@ -130,7 +126,7 @@ class Crew(FlowTrackable, BaseModel):
|
||||
_external_memory: Optional[InstanceOf[ExternalMemory]] = PrivateAttr()
|
||||
_train: Optional[bool] = PrivateAttr(default=False)
|
||||
_train_iteration: Optional[int] = PrivateAttr()
|
||||
_inputs: Optional[Dict[str, Any]] = PrivateAttr(default=None)
|
||||
_inputs: Optional[dict[str, Any]] = PrivateAttr(default=None)
|
||||
_logging_color: str = PrivateAttr(
|
||||
default="bold_purple",
|
||||
)
|
||||
@@ -140,8 +136,8 @@ class Crew(FlowTrackable, BaseModel):
|
||||
|
||||
name: Optional[str] = Field(default="crew")
|
||||
cache: bool = Field(default=True)
|
||||
tasks: List[Task] = Field(default_factory=list)
|
||||
agents: List[BaseAgent] = Field(default_factory=list)
|
||||
tasks: list[Task] = Field(default_factory=list)
|
||||
agents: list[BaseAgent] = Field(default_factory=list)
|
||||
process: Process = Field(default=Process.sequential)
|
||||
verbose: bool = Field(default=False)
|
||||
memory: bool = Field(
|
||||
@@ -164,7 +160,7 @@ class Crew(FlowTrackable, BaseModel):
|
||||
default=None,
|
||||
description="An Instance of the ExternalMemory to be used by the Crew",
|
||||
)
|
||||
embedder: Optional[dict] = Field(
|
||||
embedder: Optional[dict[str, Any]] = Field(
|
||||
default=None,
|
||||
description="Configuration for the embedder to be used for the crew.",
|
||||
)
|
||||
@@ -172,16 +168,16 @@ class Crew(FlowTrackable, BaseModel):
|
||||
default=None,
|
||||
description="Metrics for the LLM usage during all tasks execution.",
|
||||
)
|
||||
manager_llm: Optional[Union[str, InstanceOf[BaseLLM], Any]] = Field(
|
||||
manager_llm: Optional[str | InstanceOf[BaseLLM] | Any] = Field(
|
||||
description="Language model that will run the agent.", default=None
|
||||
)
|
||||
manager_agent: Optional[BaseAgent] = Field(
|
||||
description="Custom agent that will be used as manager.", default=None
|
||||
)
|
||||
function_calling_llm: Optional[Union[str, InstanceOf[LLM], Any]] = Field(
|
||||
function_calling_llm: Optional[str | InstanceOf[LLM] | Any] = Field(
|
||||
description="Language model that will run the agent.", default=None
|
||||
)
|
||||
config: Optional[Union[Json, Dict[str, Any]]] = Field(default=None)
|
||||
config: Optional[Json[dict[str, Any]] | dict[str, Any]] = Field(default=None)
|
||||
id: UUID4 = Field(default_factory=uuid.uuid4, frozen=True)
|
||||
share_crew: Optional[bool] = Field(default=False)
|
||||
step_callback: Optional[Any] = Field(
|
||||
@@ -192,13 +188,13 @@ class Crew(FlowTrackable, BaseModel):
|
||||
default=None,
|
||||
description="Callback to be executed after each task for all agents execution.",
|
||||
)
|
||||
before_kickoff_callbacks: List[
|
||||
Callable[[Optional[Dict[str, Any]]], Optional[Dict[str, Any]]]
|
||||
before_kickoff_callbacks: list[
|
||||
Callable[[Optional[dict[str, Any]]], Optional[dict[str, Any]]]
|
||||
] = Field(
|
||||
default_factory=list,
|
||||
description="List of callbacks to be executed before crew kickoff. It may be used to adjust inputs before the crew is executed.",
|
||||
)
|
||||
after_kickoff_callbacks: List[Callable[[CrewOutput], CrewOutput]] = Field(
|
||||
after_kickoff_callbacks: list[Callable[[CrewOutput], CrewOutput]] = Field(
|
||||
default_factory=list,
|
||||
description="List of callbacks to be executed after crew kickoff. It may be used to adjust the output of the crew.",
|
||||
)
|
||||
@@ -210,7 +206,7 @@ class Crew(FlowTrackable, BaseModel):
|
||||
default=None,
|
||||
description="Path to the prompt json file to be used for the crew.",
|
||||
)
|
||||
output_log_file: Optional[Union[bool, str]] = Field(
|
||||
output_log_file: Optional[bool | str] = Field(
|
||||
default=None,
|
||||
description="Path to the log file to be saved",
|
||||
)
|
||||
@@ -218,23 +214,23 @@ class Crew(FlowTrackable, BaseModel):
|
||||
default=False,
|
||||
description="Plan the crew execution and add the plan to the crew.",
|
||||
)
|
||||
planning_llm: Optional[Union[str, InstanceOf[BaseLLM], Any]] = Field(
|
||||
planning_llm: Optional[str | InstanceOf[BaseLLM] | Any] = Field(
|
||||
default=None,
|
||||
description="Language model that will run the AgentPlanner if planning is True.",
|
||||
)
|
||||
task_execution_output_json_files: Optional[List[str]] = Field(
|
||||
task_execution_output_json_files: Optional[list[str]] = Field(
|
||||
default=None,
|
||||
description="List of file paths for task execution JSON files.",
|
||||
)
|
||||
execution_logs: List[Dict[str, Any]] = Field(
|
||||
execution_logs: list[dict[str, Any]] = Field(
|
||||
default=[],
|
||||
description="List of execution logs for tasks",
|
||||
)
|
||||
knowledge_sources: Optional[List[BaseKnowledgeSource]] = Field(
|
||||
knowledge_sources: Optional[list[BaseKnowledgeSource]] = Field(
|
||||
default=None,
|
||||
description="Knowledge sources for the crew. Add knowledge sources to the knowledge object.",
|
||||
)
|
||||
chat_llm: Optional[Union[str, InstanceOf[BaseLLM], Any]] = Field(
|
||||
chat_llm: Optional[str | InstanceOf[BaseLLM] | Any] = Field(
|
||||
default=None,
|
||||
description="LLM used to handle chatting with the crew.",
|
||||
)
|
||||
@@ -267,8 +263,8 @@ class Crew(FlowTrackable, BaseModel):
|
||||
@field_validator("config", mode="before")
|
||||
@classmethod
|
||||
def check_config_type(
|
||||
cls, v: Union[Json, Dict[str, Any]]
|
||||
) -> Union[Json, Dict[str, Any]]:
|
||||
cls, v: Json[dict[str, Any]] | dict[str, Any]
|
||||
) -> Json[dict[str, Any]] | dict[str, Any]:
|
||||
"""Validates that the config is a valid type.
|
||||
Args:
|
||||
v: The config to be validated.
|
||||
@@ -277,10 +273,10 @@ class Crew(FlowTrackable, BaseModel):
|
||||
"""
|
||||
|
||||
# TODO: Improve typing
|
||||
return json.loads(v) if isinstance(v, Json) else v # type: ignore
|
||||
return json.loads(v) if isinstance(v, Json) else v
|
||||
|
||||
@model_validator(mode="after")
|
||||
def set_private_attrs(self) -> "Crew":
|
||||
def set_private_attrs(self) -> Self:
|
||||
"""Set private attributes."""
|
||||
|
||||
self._cache_handler = CacheHandler()
|
||||
@@ -300,7 +296,7 @@ class Crew(FlowTrackable, BaseModel):
|
||||
|
||||
return self
|
||||
|
||||
def _initialize_default_memories(self):
|
||||
def _initialize_default_memories(self) -> None:
|
||||
self._long_term_memory = self._long_term_memory or LongTermMemory()
|
||||
self._short_term_memory = self._short_term_memory or ShortTermMemory(
|
||||
crew=self,
|
||||
@@ -311,7 +307,7 @@ class Crew(FlowTrackable, BaseModel):
|
||||
)
|
||||
|
||||
@model_validator(mode="after")
|
||||
def create_crew_memory(self) -> "Crew":
|
||||
def create_crew_memory(self) -> Self:
|
||||
"""Initialize private memory attributes."""
|
||||
self._external_memory = (
|
||||
# External memory doesn’t support a default value since it was designed to be managed entirely externally
|
||||
@@ -328,7 +324,7 @@ class Crew(FlowTrackable, BaseModel):
|
||||
return self
|
||||
|
||||
@model_validator(mode="after")
|
||||
def create_crew_knowledge(self) -> "Crew":
|
||||
def create_crew_knowledge(self) -> Self:
|
||||
"""Create the knowledge for the crew."""
|
||||
if self.knowledge_sources:
|
||||
try:
|
||||
@@ -349,7 +345,7 @@ class Crew(FlowTrackable, BaseModel):
|
||||
return self
|
||||
|
||||
@model_validator(mode="after")
|
||||
def check_manager_llm(self):
|
||||
def check_manager_llm(self) -> Self:
|
||||
"""Validates that the language model is set when using hierarchical process."""
|
||||
if self.process == Process.hierarchical:
|
||||
if not self.manager_llm and not self.manager_agent:
|
||||
@@ -371,7 +367,7 @@ class Crew(FlowTrackable, BaseModel):
|
||||
return self
|
||||
|
||||
@model_validator(mode="after")
|
||||
def check_config(self):
|
||||
def check_config(self) -> Self:
|
||||
"""Validates that the crew is properly configured with agents and tasks."""
|
||||
if not self.config and not self.tasks and not self.agents:
|
||||
raise PydanticCustomError(
|
||||
@@ -392,20 +388,20 @@ class Crew(FlowTrackable, BaseModel):
|
||||
return self
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_tasks(self):
|
||||
def validate_tasks(self) -> Self:
|
||||
if self.process == Process.sequential:
|
||||
for task in self.tasks:
|
||||
if task.agent is None:
|
||||
raise PydanticCustomError(
|
||||
"missing_agent_in_task",
|
||||
f"Sequential process error: Agent is missing in the task with the following description: {task.description}", # type: ignore # Argument of type "str" cannot be assigned to parameter "message_template" of type "LiteralString"
|
||||
f"Sequential process error: Agent is missing in the task with the following description: {task.description}",
|
||||
{},
|
||||
)
|
||||
|
||||
return self
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_end_with_at_most_one_async_task(self):
|
||||
def validate_end_with_at_most_one_async_task(self) -> Self:
|
||||
"""Validates that the crew ends with at most one asynchronous task."""
|
||||
final_async_task_count = 0
|
||||
|
||||
@@ -426,7 +422,7 @@ class Crew(FlowTrackable, BaseModel):
|
||||
return self
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_must_have_non_conditional_task(self) -> "Crew":
|
||||
def validate_must_have_non_conditional_task(self) -> Self:
|
||||
"""Ensure that a crew has at least one non-conditional task."""
|
||||
if not self.tasks:
|
||||
return self
|
||||
@@ -442,7 +438,7 @@ class Crew(FlowTrackable, BaseModel):
|
||||
return self
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_first_task(self) -> "Crew":
|
||||
def validate_first_task(self) -> Self:
|
||||
"""Ensure the first task is not a ConditionalTask."""
|
||||
if self.tasks and isinstance(self.tasks[0], ConditionalTask):
|
||||
raise PydanticCustomError(
|
||||
@@ -453,19 +449,21 @@ class Crew(FlowTrackable, BaseModel):
|
||||
return self
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_async_tasks_not_async(self) -> "Crew":
|
||||
def validate_async_tasks_not_async(self) -> Self:
|
||||
"""Ensure that ConditionalTask is not async."""
|
||||
for task in self.tasks:
|
||||
if task.async_execution and isinstance(task, ConditionalTask):
|
||||
raise PydanticCustomError(
|
||||
"invalid_async_conditional_task",
|
||||
f"Conditional Task: {task.description} , cannot be executed asynchronously.", # type: ignore # Argument of type "str" cannot be assigned to parameter "message_template" of type "LiteralString"
|
||||
f"Conditional Task: {task.description} , cannot be executed asynchronously.",
|
||||
{},
|
||||
)
|
||||
return self
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_async_task_cannot_include_sequential_async_tasks_in_context(self):
|
||||
def validate_async_task_cannot_include_sequential_async_tasks_in_context(
|
||||
self,
|
||||
) -> Self:
|
||||
"""
|
||||
Validates that if a task is set to be executed asynchronously,
|
||||
it cannot include other asynchronous tasks in its context unless
|
||||
@@ -485,7 +483,7 @@ class Crew(FlowTrackable, BaseModel):
|
||||
return self
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_context_no_future_tasks(self):
|
||||
def validate_context_no_future_tasks(self) -> Self:
|
||||
"""Validates that a task's context does not include future tasks."""
|
||||
task_indices = {id(task): i for i, task in enumerate(self.tasks)}
|
||||
|
||||
@@ -502,7 +500,7 @@ class Crew(FlowTrackable, BaseModel):
|
||||
|
||||
@property
|
||||
def key(self) -> str:
|
||||
source: List[str] = [agent.key for agent in self.agents] + [
|
||||
source: list[str] = [agent.key for agent in self.agents] + [
|
||||
task.key for task in self.tasks
|
||||
]
|
||||
return md5("|".join(source).encode(), usedforsecurity=False).hexdigest()
|
||||
@@ -517,7 +515,7 @@ class Crew(FlowTrackable, BaseModel):
|
||||
"""
|
||||
return self.security_config.fingerprint
|
||||
|
||||
def _setup_from_config(self):
|
||||
def _setup_from_config(self) -> None:
|
||||
assert self.config is not None, "Config should not be None."
|
||||
|
||||
"""Initializes agents and tasks from the provided config."""
|
||||
@@ -530,7 +528,7 @@ class Crew(FlowTrackable, BaseModel):
|
||||
self.agents = [Agent(**agent) for agent in self.config["agents"]]
|
||||
self.tasks = [self._create_task(task) for task in self.config["tasks"]]
|
||||
|
||||
def _create_task(self, task_config: Dict[str, Any]) -> Task:
|
||||
def _create_task(self, task_config: dict[str, Any]) -> Task:
|
||||
"""Creates a task instance from its configuration.
|
||||
|
||||
Args:
|
||||
@@ -559,7 +557,7 @@ class Crew(FlowTrackable, BaseModel):
|
||||
CrewTrainingHandler(filename).initialize_file()
|
||||
|
||||
def train(
|
||||
self, n_iterations: int, filename: str, inputs: Optional[Dict[str, Any]] = None
|
||||
self, n_iterations: int, filename: str, inputs: Optional[dict[str, Any]] = None
|
||||
) -> None:
|
||||
"""Trains the crew for a given number of iterations."""
|
||||
inputs = inputs or {}
|
||||
@@ -611,7 +609,7 @@ class Crew(FlowTrackable, BaseModel):
|
||||
|
||||
def kickoff(
|
||||
self,
|
||||
inputs: Optional[Dict[str, Any]] = None,
|
||||
inputs: Optional[dict[str, Any]] = None,
|
||||
) -> CrewOutput:
|
||||
ctx = baggage.set_baggage(
|
||||
"crew_context", CrewContext(id=str(self.id), key=self.key)
|
||||
@@ -643,7 +641,6 @@ class Crew(FlowTrackable, BaseModel):
|
||||
|
||||
for agent in self.agents:
|
||||
agent.i18n = i18n
|
||||
# type: ignore[attr-defined] # Argument 1 to "_interpolate_inputs" of "Crew" has incompatible type "dict[str, Any] | None"; expected "dict[str, Any]"
|
||||
agent.crew = self # type: ignore[attr-defined]
|
||||
agent.set_knowledge(crew_embedder=self.embedder)
|
||||
# TODO: Create an AgentFunctionCalling protocol for future refactoring
|
||||
@@ -682,9 +679,9 @@ class Crew(FlowTrackable, BaseModel):
|
||||
finally:
|
||||
detach(token)
|
||||
|
||||
def kickoff_for_each(self, inputs: List[Dict[str, Any]]) -> List[CrewOutput]:
|
||||
def kickoff_for_each(self, inputs: list[dict[str, Any]]) -> list[CrewOutput]:
|
||||
"""Executes the Crew's workflow for each input in the list and aggregates results."""
|
||||
results: List[CrewOutput] = []
|
||||
results: list[CrewOutput] = []
|
||||
|
||||
# Initialize the parent crew's usage metrics
|
||||
total_usage_metrics = UsageMetrics()
|
||||
@@ -704,16 +701,18 @@ class Crew(FlowTrackable, BaseModel):
|
||||
return results
|
||||
|
||||
async def kickoff_async(
|
||||
self, inputs: Optional[Dict[str, Any]] = None
|
||||
self, inputs: Optional[dict[str, Any]] = None
|
||||
) -> CrewOutput:
|
||||
"""Asynchronous kickoff method to start the crew execution."""
|
||||
inputs = inputs or {}
|
||||
return await asyncio.to_thread(self.kickoff, inputs)
|
||||
|
||||
async def kickoff_for_each_async(self, inputs: List[Dict]) -> List[CrewOutput]:
|
||||
async def kickoff_for_each_async(
|
||||
self, inputs: list[dict[str, Any]]
|
||||
) -> list[CrewOutput]:
|
||||
crew_copies = [self.copy() for _ in inputs]
|
||||
|
||||
async def run_crew(crew, input_data):
|
||||
async def run_crew(crew: Self, input_data: dict[str, Any]) -> CrewOutput:
|
||||
return await crew.kickoff_async(inputs=input_data)
|
||||
|
||||
tasks = [
|
||||
@@ -732,7 +731,7 @@ class Crew(FlowTrackable, BaseModel):
|
||||
self._task_output_handler.reset()
|
||||
return results
|
||||
|
||||
def _handle_crew_planning(self):
|
||||
def _handle_crew_planning(self) -> None:
|
||||
"""Handles the Crew planning."""
|
||||
self._logger.log("info", "Planning the crew execution")
|
||||
result = CrewPlanner(
|
||||
@@ -748,7 +747,7 @@ class Crew(FlowTrackable, BaseModel):
|
||||
output: TaskOutput,
|
||||
task_index: int,
|
||||
was_replayed: bool = False,
|
||||
):
|
||||
) -> None:
|
||||
if self._inputs:
|
||||
inputs = self._inputs
|
||||
else:
|
||||
@@ -780,7 +779,7 @@ class Crew(FlowTrackable, BaseModel):
|
||||
self._create_manager_agent()
|
||||
return self._execute_tasks(self.tasks)
|
||||
|
||||
def _create_manager_agent(self):
|
||||
def _create_manager_agent(self) -> None:
|
||||
i18n = I18N(prompt_file=self.prompt_file)
|
||||
if self.manager_agent is not None:
|
||||
self.manager_agent.allow_delegation = True
|
||||
@@ -807,7 +806,7 @@ class Crew(FlowTrackable, BaseModel):
|
||||
|
||||
def _execute_tasks(
|
||||
self,
|
||||
tasks: List[Task],
|
||||
tasks: list[Task],
|
||||
start_index: Optional[int] = 0,
|
||||
was_replayed: bool = False,
|
||||
) -> CrewOutput:
|
||||
@@ -821,8 +820,8 @@ class Crew(FlowTrackable, BaseModel):
|
||||
CrewOutput: Final output of the crew
|
||||
"""
|
||||
|
||||
task_outputs: List[TaskOutput] = []
|
||||
futures: List[Tuple[Task, Future[TaskOutput], int]] = []
|
||||
task_outputs: list[TaskOutput] = []
|
||||
futures: list[tuple[Task, Future[TaskOutput], int]] = []
|
||||
last_sync_output: Optional[TaskOutput] = None
|
||||
|
||||
for task_index, task in enumerate(tasks):
|
||||
@@ -847,7 +846,7 @@ class Crew(FlowTrackable, BaseModel):
|
||||
tools_for_task = self._prepare_tools(
|
||||
agent_to_use,
|
||||
task,
|
||||
cast(Union[List[Tool], List[BaseTool]], tools_for_task),
|
||||
cast(list[Tool] | list[BaseTool], tools_for_task),
|
||||
)
|
||||
|
||||
self._log_task_start(task, agent_to_use.role)
|
||||
@@ -867,7 +866,7 @@ class Crew(FlowTrackable, BaseModel):
|
||||
future = task.execute_async(
|
||||
agent=agent_to_use,
|
||||
context=context,
|
||||
tools=cast(List[BaseTool], tools_for_task),
|
||||
tools=tools_for_task,
|
||||
)
|
||||
futures.append((task, future, task_index))
|
||||
else:
|
||||
@@ -879,7 +878,7 @@ class Crew(FlowTrackable, BaseModel):
|
||||
task_output = task.execute_sync(
|
||||
agent=agent_to_use,
|
||||
context=context,
|
||||
tools=cast(List[BaseTool], tools_for_task),
|
||||
tools=tools_for_task,
|
||||
)
|
||||
task_outputs.append(task_output)
|
||||
self._process_task_result(task, task_output)
|
||||
@@ -893,8 +892,8 @@ class Crew(FlowTrackable, BaseModel):
|
||||
def _handle_conditional_task(
|
||||
self,
|
||||
task: ConditionalTask,
|
||||
task_outputs: List[TaskOutput],
|
||||
futures: List[Tuple[Task, Future[TaskOutput], int]],
|
||||
task_outputs: list[TaskOutput],
|
||||
futures: list[tuple[Task, Future[TaskOutput], int]],
|
||||
task_index: int,
|
||||
was_replayed: bool,
|
||||
) -> Optional[TaskOutput]:
|
||||
@@ -909,7 +908,7 @@ class Crew(FlowTrackable, BaseModel):
|
||||
f"Skipping conditional task: {task.description}",
|
||||
color="yellow",
|
||||
)
|
||||
skipped_task_output = task.get_skipped_task_output()
|
||||
skipped_task_output = cast(TaskOutput, task.get_skipped_task_output())
|
||||
|
||||
if not was_replayed:
|
||||
self._store_execution_log(task, skipped_task_output, task_index)
|
||||
@@ -917,8 +916,8 @@ class Crew(FlowTrackable, BaseModel):
|
||||
return None
|
||||
|
||||
def _prepare_tools(
|
||||
self, agent: BaseAgent, task: Task, tools: Union[List[Tool], List[BaseTool]]
|
||||
) -> List[BaseTool]:
|
||||
self, agent: BaseAgent, task: Task, tools: list[Tool] | list[BaseTool]
|
||||
) -> list[BaseTool]:
|
||||
# Add delegation tools if agent allows delegation
|
||||
if hasattr(agent, "allow_delegation") and getattr(
|
||||
agent, "allow_delegation", False
|
||||
@@ -948,7 +947,7 @@ class Crew(FlowTrackable, BaseModel):
|
||||
tools = self._add_multimodal_tools(agent, tools)
|
||||
|
||||
# Return a List[BaseTool] which is compatible with both Task.execute_sync and Task.execute_async
|
||||
return cast(List[BaseTool], tools)
|
||||
return cast(list[BaseTool], tools)
|
||||
|
||||
def _get_agent_to_use(self, task: Task) -> Optional[BaseAgent]:
|
||||
if self.process == Process.hierarchical:
|
||||
@@ -957,12 +956,12 @@ class Crew(FlowTrackable, BaseModel):
|
||||
|
||||
def _merge_tools(
|
||||
self,
|
||||
existing_tools: Union[List[Tool], List[BaseTool]],
|
||||
new_tools: Union[List[Tool], List[BaseTool]],
|
||||
) -> List[BaseTool]:
|
||||
existing_tools: list[Tool] | list[BaseTool],
|
||||
new_tools: list[Tool] | list[BaseTool],
|
||||
) -> list[BaseTool]:
|
||||
"""Merge new tools into existing tools list, avoiding duplicates by tool name."""
|
||||
if not new_tools:
|
||||
return cast(List[BaseTool], existing_tools)
|
||||
return cast(list[BaseTool], existing_tools)
|
||||
|
||||
# Create mapping of tool names to new tools
|
||||
new_tool_map = {tool.name: tool for tool in new_tools}
|
||||
@@ -973,41 +972,41 @@ class Crew(FlowTrackable, BaseModel):
|
||||
# Add all new tools
|
||||
tools.extend(new_tools)
|
||||
|
||||
return cast(List[BaseTool], tools)
|
||||
return tools
|
||||
|
||||
def _inject_delegation_tools(
|
||||
self,
|
||||
tools: Union[List[Tool], List[BaseTool]],
|
||||
tools: list[Tool] | list[BaseTool],
|
||||
task_agent: BaseAgent,
|
||||
agents: List[BaseAgent],
|
||||
) -> List[BaseTool]:
|
||||
agents: list[BaseAgent],
|
||||
) -> list[BaseTool]:
|
||||
if hasattr(task_agent, "get_delegation_tools"):
|
||||
delegation_tools = task_agent.get_delegation_tools(agents)
|
||||
# Cast delegation_tools to the expected type for _merge_tools
|
||||
return self._merge_tools(tools, cast(List[BaseTool], delegation_tools))
|
||||
return cast(List[BaseTool], tools)
|
||||
return self._merge_tools(tools, cast(list[BaseTool], delegation_tools))
|
||||
return cast(list[BaseTool], tools)
|
||||
|
||||
def _add_multimodal_tools(
|
||||
self, agent: BaseAgent, tools: Union[List[Tool], List[BaseTool]]
|
||||
) -> List[BaseTool]:
|
||||
self, agent: BaseAgent, tools: list[Tool] | list[BaseTool]
|
||||
) -> list[BaseTool]:
|
||||
if hasattr(agent, "get_multimodal_tools"):
|
||||
multimodal_tools = agent.get_multimodal_tools()
|
||||
# Cast multimodal_tools to the expected type for _merge_tools
|
||||
return self._merge_tools(tools, cast(List[BaseTool], multimodal_tools))
|
||||
return cast(List[BaseTool], tools)
|
||||
return self._merge_tools(tools, cast(list[BaseTool], multimodal_tools))
|
||||
return cast(list[BaseTool], tools)
|
||||
|
||||
def _add_code_execution_tools(
|
||||
self, agent: BaseAgent, tools: Union[List[Tool], List[BaseTool]]
|
||||
) -> List[BaseTool]:
|
||||
self, agent: BaseAgent, tools: list[Tool] | list[BaseTool]
|
||||
) -> list[BaseTool]:
|
||||
if hasattr(agent, "get_code_execution_tools"):
|
||||
code_tools = agent.get_code_execution_tools()
|
||||
# Cast code_tools to the expected type for _merge_tools
|
||||
return self._merge_tools(tools, cast(List[BaseTool], code_tools))
|
||||
return cast(List[BaseTool], tools)
|
||||
return self._merge_tools(tools, cast(list[BaseTool], code_tools))
|
||||
return cast(list[BaseTool], tools)
|
||||
|
||||
def _add_delegation_tools(
|
||||
self, task: Task, tools: Union[List[Tool], List[BaseTool]]
|
||||
) -> List[BaseTool]:
|
||||
self, task: Task, tools: list[Tool] | list[BaseTool]
|
||||
) -> list[BaseTool]:
|
||||
agents_for_delegation = [agent for agent in self.agents if agent != task.agent]
|
||||
if len(self.agents) > 1 and len(agents_for_delegation) > 0 and task.agent:
|
||||
if not tools:
|
||||
@@ -1015,17 +1014,17 @@ class Crew(FlowTrackable, BaseModel):
|
||||
tools = self._inject_delegation_tools(
|
||||
tools, task.agent, agents_for_delegation
|
||||
)
|
||||
return cast(List[BaseTool], tools)
|
||||
return cast(list[BaseTool], tools)
|
||||
|
||||
def _log_task_start(self, task: Task, role: str = "None"):
|
||||
def _log_task_start(self, task: Task, role: str = "None") -> None:
|
||||
if self.output_log_file:
|
||||
self._file_handler.log(
|
||||
task_name=task.name, task=task.description, agent=role, status="started"
|
||||
)
|
||||
|
||||
def _update_manager_tools(
|
||||
self, task: Task, tools: Union[List[Tool], List[BaseTool]]
|
||||
) -> List[BaseTool]:
|
||||
self, task: Task, tools: list[Tool] | list[BaseTool]
|
||||
) -> list[BaseTool]:
|
||||
if self.manager_agent:
|
||||
if task.agent:
|
||||
tools = self._inject_delegation_tools(tools, task.agent, [task.agent])
|
||||
@@ -1033,9 +1032,9 @@ class Crew(FlowTrackable, BaseModel):
|
||||
tools = self._inject_delegation_tools(
|
||||
tools, self.manager_agent, self.agents
|
||||
)
|
||||
return cast(List[BaseTool], tools)
|
||||
return cast(list[BaseTool], tools)
|
||||
|
||||
def _get_context(self, task: Task, task_outputs: List[TaskOutput]) -> str:
|
||||
def _get_context(self, task: Task, task_outputs: list[TaskOutput]) -> str:
|
||||
if not task.context:
|
||||
return ""
|
||||
|
||||
@@ -1057,7 +1056,7 @@ class Crew(FlowTrackable, BaseModel):
|
||||
output=output.raw,
|
||||
)
|
||||
|
||||
def _create_crew_output(self, task_outputs: List[TaskOutput]) -> CrewOutput:
|
||||
def _create_crew_output(self, task_outputs: list[TaskOutput]) -> CrewOutput:
|
||||
if not task_outputs:
|
||||
raise ValueError("No task outputs available to create crew output.")
|
||||
|
||||
@@ -1088,10 +1087,10 @@ class Crew(FlowTrackable, BaseModel):
|
||||
|
||||
def _process_async_tasks(
|
||||
self,
|
||||
futures: List[Tuple[Task, Future[TaskOutput], int]],
|
||||
futures: list[tuple[Task, Future[TaskOutput], int]],
|
||||
was_replayed: bool = False,
|
||||
) -> List[TaskOutput]:
|
||||
task_outputs: List[TaskOutput] = []
|
||||
) -> list[TaskOutput]:
|
||||
task_outputs: list[TaskOutput] = []
|
||||
for future_task, future, task_index in futures:
|
||||
task_output = future.result()
|
||||
task_outputs.append(task_output)
|
||||
@@ -1102,7 +1101,7 @@ class Crew(FlowTrackable, BaseModel):
|
||||
return task_outputs
|
||||
|
||||
def _find_task_index(
|
||||
self, task_id: str, stored_outputs: List[Any]
|
||||
self, task_id: str, stored_outputs: list[Any]
|
||||
) -> Optional[int]:
|
||||
return next(
|
||||
(
|
||||
@@ -1114,7 +1113,7 @@ class Crew(FlowTrackable, BaseModel):
|
||||
)
|
||||
|
||||
def replay(
|
||||
self, task_id: str, inputs: Optional[Dict[str, Any]] = None
|
||||
self, task_id: str, inputs: Optional[dict[str, Any]] = None
|
||||
) -> CrewOutput:
|
||||
stored_outputs = self._task_output_handler.load()
|
||||
if not stored_outputs:
|
||||
@@ -1155,15 +1154,15 @@ class Crew(FlowTrackable, BaseModel):
|
||||
return result
|
||||
|
||||
def query_knowledge(
|
||||
self, query: List[str], results_limit: int = 3, score_threshold: float = 0.35
|
||||
) -> Union[List[Dict[str, Any]], None]:
|
||||
self, query: list[str], results_limit: int = 3, score_threshold: float = 0.35
|
||||
) -> list[dict[str, Any]] | None:
|
||||
if self.knowledge:
|
||||
return self.knowledge.query(
|
||||
query, results_limit=results_limit, score_threshold=score_threshold
|
||||
)
|
||||
return None
|
||||
|
||||
def fetch_inputs(self) -> Set[str]:
|
||||
def fetch_inputs(self) -> set[str]:
|
||||
"""
|
||||
Gathers placeholders (e.g., {something}) referenced in tasks or agents.
|
||||
Scans each task's 'description' + 'expected_output', and each agent's
|
||||
@@ -1172,7 +1171,7 @@ class Crew(FlowTrackable, BaseModel):
|
||||
Returns a set of all discovered placeholder names.
|
||||
"""
|
||||
placeholder_pattern = re.compile(r"\{(.+?)\}")
|
||||
required_inputs: Set[str] = set()
|
||||
required_inputs: set[str] = set()
|
||||
|
||||
# Scan tasks for inputs
|
||||
for task in self.tasks:
|
||||
@@ -1188,7 +1187,18 @@ class Crew(FlowTrackable, BaseModel):
|
||||
|
||||
return required_inputs
|
||||
|
||||
def copy(self):
|
||||
def copy(
|
||||
self,
|
||||
*,
|
||||
include: Optional[
|
||||
Set[int] | Set[str] | Mapping[int, Any] | Mapping[str, Any]
|
||||
] = None,
|
||||
exclude: Optional[
|
||||
Set[int] | Set[str] | Mapping[int, Any] | Mapping[str, Any]
|
||||
] = None,
|
||||
update: Optional[dict[str, Any]] = None,
|
||||
deep: bool = True,
|
||||
) -> "Crew":
|
||||
"""
|
||||
Creates a deep copy of the Crew instance.
|
||||
|
||||
@@ -1219,7 +1229,7 @@ class Crew(FlowTrackable, BaseModel):
|
||||
manager_agent = self.manager_agent.copy() if self.manager_agent else None
|
||||
manager_llm = shallow_copy(self.manager_llm) if self.manager_llm else None
|
||||
|
||||
task_mapping = {}
|
||||
task_mapping: dict[str, Task] = {}
|
||||
|
||||
cloned_tasks = []
|
||||
existing_knowledge_sources = shallow_copy(self.knowledge_sources)
|
||||
@@ -1274,16 +1284,10 @@ class Crew(FlowTrackable, BaseModel):
|
||||
if not task.callback:
|
||||
task.callback = self.task_callback
|
||||
|
||||
def _interpolate_inputs(self, inputs: Dict[str, Any]) -> None:
|
||||
def _interpolate_inputs(self, inputs: dict[str, Any]) -> None:
|
||||
"""Interpolates the inputs in the tasks and agents."""
|
||||
[
|
||||
task.interpolate_inputs_and_add_conversation_history(
|
||||
# type: ignore # "interpolate_inputs" of "Task" does not return a value (it only ever returns None)
|
||||
inputs
|
||||
)
|
||||
for task in self.tasks
|
||||
]
|
||||
# type: ignore # "interpolate_inputs" of "Agent" does not return a value (it only ever returns None)
|
||||
for task in self.tasks:
|
||||
task.interpolate_inputs_and_add_conversation_history(inputs)
|
||||
for agent in self.agents:
|
||||
agent.interpolate_inputs(inputs)
|
||||
|
||||
@@ -1307,8 +1311,8 @@ class Crew(FlowTrackable, BaseModel):
|
||||
def test(
|
||||
self,
|
||||
n_iterations: int,
|
||||
eval_llm: Union[str, InstanceOf[BaseLLM]],
|
||||
inputs: Optional[Dict[str, Any]] = None,
|
||||
eval_llm: str | InstanceOf[BaseLLM],
|
||||
inputs: Optional[dict[str, Any]] = None,
|
||||
) -> None:
|
||||
"""Test and evaluate the Crew with the given inputs for n iterations concurrently using concurrent.futures."""
|
||||
try:
|
||||
@@ -1349,7 +1353,7 @@ class Crew(FlowTrackable, BaseModel):
|
||||
)
|
||||
raise
|
||||
|
||||
def __repr__(self):
|
||||
def __repr__(self) -> str:
|
||||
return f"Crew(id={self.id}, process={self.process}, number_of_agents={len(self.agents)}, number_of_tasks={len(self.tasks)})"
|
||||
|
||||
def reset_memories(self, command_type: str) -> None:
|
||||
@@ -1401,7 +1405,9 @@ class Crew(FlowTrackable, BaseModel):
|
||||
if (system := config.get("system")) is not None:
|
||||
name = config.get("name")
|
||||
try:
|
||||
reset_fn: Callable = cast(Callable, config.get("reset"))
|
||||
reset_fn: Callable[..., None] = cast(
|
||||
Callable[..., None], config.get("reset")
|
||||
)
|
||||
reset_fn(system)
|
||||
self._logger.log(
|
||||
"info",
|
||||
@@ -1430,7 +1436,9 @@ class Crew(FlowTrackable, BaseModel):
|
||||
raise RuntimeError(f"{name} memory system is not initialized")
|
||||
|
||||
try:
|
||||
reset_fn: Callable = cast(Callable, config.get("reset"))
|
||||
reset_fn: Callable[..., None] = cast(
|
||||
Callable[..., None], config.get("reset")
|
||||
)
|
||||
reset_fn(system)
|
||||
self._logger.log(
|
||||
"info",
|
||||
@@ -1441,18 +1449,18 @@ class Crew(FlowTrackable, BaseModel):
|
||||
f"[Crew ({self.name if self.name else self.id})] Failed to reset {name} memory: {str(e)}"
|
||||
) from e
|
||||
|
||||
def _get_memory_systems(self):
|
||||
def _get_memory_systems(self) -> dict[str, dict[str, Any]]:
|
||||
"""Get all available memory systems with their configuration.
|
||||
|
||||
Returns:
|
||||
Dict containing all memory systems with their reset functions and display names.
|
||||
"""
|
||||
|
||||
def default_reset(memory):
|
||||
return memory.reset()
|
||||
def default_reset(memory: Any) -> None:
|
||||
memory.reset()
|
||||
|
||||
def knowledge_reset(memory):
|
||||
return self.reset_knowledge(memory)
|
||||
def knowledge_reset(memory: Any) -> None:
|
||||
self.reset_knowledge(memory)
|
||||
|
||||
# Get knowledge for agents
|
||||
agent_knowledges = [
|
||||
@@ -1506,12 +1514,12 @@ class Crew(FlowTrackable, BaseModel):
|
||||
},
|
||||
}
|
||||
|
||||
def reset_knowledge(self, knowledges: List[Knowledge]) -> None:
|
||||
def reset_knowledge(self, knowledges: list[Knowledge]) -> None:
|
||||
"""Reset crew and agent knowledge storage."""
|
||||
for ks in knowledges:
|
||||
ks.reset()
|
||||
|
||||
def _set_allow_crewai_trigger_context_for_first_task(self):
|
||||
def _set_allow_crewai_trigger_context_for_first_task(self) -> None:
|
||||
crewai_trigger_payload = self._inputs and self._inputs.get(
|
||||
"crewai_trigger_payload"
|
||||
)
|
||||
|
||||
@@ -1,35 +1,25 @@
|
||||
import asyncio
|
||||
import inspect
|
||||
import uuid
|
||||
from collections.abc import Callable
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
Dict,
|
||||
List,
|
||||
Optional,
|
||||
Tuple,
|
||||
Type,
|
||||
Union,
|
||||
cast,
|
||||
get_args,
|
||||
get_origin,
|
||||
)
|
||||
|
||||
|
||||
try:
|
||||
from typing import Self
|
||||
except ImportError:
|
||||
from typing_extensions import Self
|
||||
|
||||
from pydantic import (
|
||||
UUID4,
|
||||
BaseModel,
|
||||
Field,
|
||||
InstanceOf,
|
||||
PrivateAttr,
|
||||
model_validator,
|
||||
field_validator,
|
||||
model_validator,
|
||||
)
|
||||
from typing_extensions import Self
|
||||
|
||||
from crewai.agents.agent_builder.base_agent import BaseAgent
|
||||
from crewai.agents.agent_builder.utilities.base_token_process import TokenProcess
|
||||
@@ -39,12 +29,19 @@ from crewai.agents.parser import (
|
||||
AgentFinish,
|
||||
OutputParserException,
|
||||
)
|
||||
from crewai.events.event_bus import crewai_event_bus
|
||||
from crewai.events.types.agent_events import (
|
||||
LiteAgentExecutionCompletedEvent,
|
||||
LiteAgentExecutionErrorEvent,
|
||||
LiteAgentExecutionStartedEvent,
|
||||
)
|
||||
from crewai.events.types.logging_events import AgentLogsExecutionEvent
|
||||
from crewai.flow.flow_trackable import FlowTrackable
|
||||
from crewai.llm import LLM, BaseLLM
|
||||
from crewai.llm import LLM
|
||||
from crewai.llms.base_llm import BaseLLM
|
||||
from crewai.tools.base_tool import BaseTool
|
||||
from crewai.tools.structured_tool import CrewStructuredTool
|
||||
from crewai.utilities import I18N
|
||||
from crewai.utilities.guardrail import process_guardrail
|
||||
from crewai.utilities.agent_utils import (
|
||||
enforce_rpm_limit,
|
||||
format_message_for_llm,
|
||||
@@ -62,14 +59,7 @@ from crewai.utilities.agent_utils import (
|
||||
render_text_description_and_args,
|
||||
)
|
||||
from crewai.utilities.converter import generate_model_description
|
||||
from crewai.events.types.logging_events import AgentLogsExecutionEvent
|
||||
from crewai.events.types.agent_events import (
|
||||
LiteAgentExecutionCompletedEvent,
|
||||
LiteAgentExecutionErrorEvent,
|
||||
LiteAgentExecutionStartedEvent,
|
||||
)
|
||||
from crewai.events.event_bus import crewai_event_bus
|
||||
|
||||
from crewai.utilities.guardrail import process_guardrail
|
||||
from crewai.utilities.llm_utils import create_llm
|
||||
from crewai.utilities.printer import Printer
|
||||
from crewai.utilities.token_counter_callback import TokenCalcHandler
|
||||
@@ -86,11 +76,11 @@ class LiteAgentOutput(BaseModel):
|
||||
description="Pydantic output of the agent", default=None
|
||||
)
|
||||
agent_role: str = Field(description="Role of the agent that produced this output")
|
||||
usage_metrics: Optional[Dict[str, Any]] = Field(
|
||||
usage_metrics: Optional[dict[str, Any]] = Field(
|
||||
description="Token usage metrics for this execution", default=None
|
||||
)
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
"""Convert pydantic_output to a dictionary."""
|
||||
if self.pydantic:
|
||||
return self.pydantic.model_dump()
|
||||
@@ -130,10 +120,10 @@ class LiteAgent(FlowTrackable, BaseModel):
|
||||
role: str = Field(description="Role of the agent")
|
||||
goal: str = Field(description="Goal of the agent")
|
||||
backstory: str = Field(description="Backstory of the agent")
|
||||
llm: Optional[Union[str, InstanceOf[BaseLLM], Any]] = Field(
|
||||
llm: Optional[str | InstanceOf[BaseLLM] | Any] = Field(
|
||||
default=None, description="Language model that will run the agent"
|
||||
)
|
||||
tools: List[BaseTool] = Field(
|
||||
tools: list[BaseTool] = Field(
|
||||
default_factory=list, description="Tools at agent's disposal"
|
||||
)
|
||||
|
||||
@@ -159,29 +149,27 @@ class LiteAgent(FlowTrackable, BaseModel):
|
||||
i18n: I18N = Field(default=I18N(), description="Internationalization settings.")
|
||||
|
||||
# Output and Formatting Properties
|
||||
response_format: Optional[Type[BaseModel]] = Field(
|
||||
response_format: Optional[type[BaseModel]] = Field(
|
||||
default=None, description="Pydantic model for structured output"
|
||||
)
|
||||
verbose: bool = Field(
|
||||
default=False, description="Whether to print execution details"
|
||||
)
|
||||
callbacks: List[Callable] = Field(
|
||||
callbacks: list[Callable[..., Any]] = Field(
|
||||
default=[], description="Callbacks to be used for the agent"
|
||||
)
|
||||
|
||||
# Guardrail Properties
|
||||
guardrail: Optional[Union[Callable[[LiteAgentOutput], Tuple[bool, Any]], str]] = (
|
||||
Field(
|
||||
default=None,
|
||||
description="Function or string description of a guardrail to validate agent output",
|
||||
)
|
||||
guardrail: Optional[Callable[[LiteAgentOutput], tuple[bool, Any]] | str] = Field(
|
||||
default=None,
|
||||
description="Function or string description of a guardrail to validate agent output",
|
||||
)
|
||||
guardrail_max_retries: int = Field(
|
||||
default=3, description="Maximum number of retries when guardrail fails"
|
||||
)
|
||||
|
||||
# State and Results
|
||||
tools_results: List[Dict[str, Any]] = Field(
|
||||
tools_results: list[dict[str, Any]] = Field(
|
||||
default=[], description="Results of the tools used by the agent."
|
||||
)
|
||||
|
||||
@@ -190,18 +178,20 @@ class LiteAgent(FlowTrackable, BaseModel):
|
||||
default=None, description="Reference to the agent that created this LiteAgent"
|
||||
)
|
||||
# Private Attributes
|
||||
_parsed_tools: List[CrewStructuredTool] = PrivateAttr(default_factory=list)
|
||||
_parsed_tools: list[CrewStructuredTool] = PrivateAttr(default_factory=list)
|
||||
_token_process: TokenProcess = PrivateAttr(default_factory=TokenProcess)
|
||||
_cache_handler: CacheHandler = PrivateAttr(default_factory=CacheHandler)
|
||||
_key: str = PrivateAttr(default_factory=lambda: str(uuid.uuid4()))
|
||||
_messages: List[Dict[str, str]] = PrivateAttr(default_factory=list)
|
||||
_messages: list[dict[str, str]] = PrivateAttr(default_factory=list)
|
||||
_iterations: int = PrivateAttr(default=0)
|
||||
_printer: Printer = PrivateAttr(default_factory=Printer)
|
||||
_guardrail: Optional[Callable] = PrivateAttr(default=None)
|
||||
_guardrail: Optional[Callable[[LiteAgentOutput], tuple[bool, Any]]] = PrivateAttr(
|
||||
default=None
|
||||
)
|
||||
_guardrail_retry_count: int = PrivateAttr(default=0)
|
||||
|
||||
@model_validator(mode="after")
|
||||
def setup_llm(self):
|
||||
def setup_llm(self) -> Self:
|
||||
"""Set up the LLM and other components after initialization."""
|
||||
self.llm = create_llm(self.llm)
|
||||
if not isinstance(self.llm, BaseLLM):
|
||||
@@ -216,7 +206,7 @@ class LiteAgent(FlowTrackable, BaseModel):
|
||||
return self
|
||||
|
||||
@model_validator(mode="after")
|
||||
def parse_tools(self):
|
||||
def parse_tools(self) -> Self:
|
||||
"""Parse the tools and convert them to CrewStructuredTool instances."""
|
||||
self._parsed_tools = parse_tools(self.tools)
|
||||
|
||||
@@ -241,8 +231,8 @@ class LiteAgent(FlowTrackable, BaseModel):
|
||||
@field_validator("guardrail", mode="before")
|
||||
@classmethod
|
||||
def validate_guardrail_function(
|
||||
cls, v: Optional[Union[Callable, str]]
|
||||
) -> Optional[Union[Callable, str]]:
|
||||
cls, v: Optional[Callable[[Any], tuple[bool, Any]] | str]
|
||||
) -> Optional[Callable[[Any], tuple[bool, Any]] | str]:
|
||||
"""Validate that the guardrail function has the correct signature.
|
||||
|
||||
If v is a callable, validate that it has the correct signature.
|
||||
@@ -267,7 +257,7 @@ class LiteAgent(FlowTrackable, BaseModel):
|
||||
|
||||
# Check return annotation if present
|
||||
if sig.return_annotation is not sig.empty:
|
||||
if sig.return_annotation == Tuple[bool, Any]:
|
||||
if sig.return_annotation == tuple[bool, Any]:
|
||||
return v
|
||||
|
||||
origin = get_origin(sig.return_annotation)
|
||||
@@ -290,7 +280,7 @@ class LiteAgent(FlowTrackable, BaseModel):
|
||||
"""Return the original role for compatibility with tool interfaces."""
|
||||
return self.role
|
||||
|
||||
def kickoff(self, messages: Union[str, List[Dict[str, str]]]) -> LiteAgentOutput:
|
||||
def kickoff(self, messages: str | list[dict[str, str]]) -> LiteAgentOutput:
|
||||
"""
|
||||
Execute the agent with the given messages.
|
||||
|
||||
@@ -338,7 +328,7 @@ class LiteAgent(FlowTrackable, BaseModel):
|
||||
)
|
||||
raise e
|
||||
|
||||
def _execute_core(self, agent_info: Dict[str, Any]) -> LiteAgentOutput:
|
||||
def _execute_core(self, agent_info: dict[str, Any]) -> LiteAgentOutput:
|
||||
# Emit event for agent execution start
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
@@ -428,7 +418,7 @@ class LiteAgent(FlowTrackable, BaseModel):
|
||||
return output
|
||||
|
||||
async def kickoff_async(
|
||||
self, messages: Union[str, List[Dict[str, str]]]
|
||||
self, messages: str | list[dict[str, str]]
|
||||
) -> LiteAgentOutput:
|
||||
"""
|
||||
Execute the agent asynchronously with the given messages.
|
||||
@@ -475,8 +465,8 @@ class LiteAgent(FlowTrackable, BaseModel):
|
||||
return base_prompt
|
||||
|
||||
def _format_messages(
|
||||
self, messages: Union[str, List[Dict[str, str]]]
|
||||
) -> List[Dict[str, str]]:
|
||||
self, messages: str | list[dict[str, str]]
|
||||
) -> list[dict[str, str]]:
|
||||
"""Format messages for the LLM."""
|
||||
if isinstance(messages, str):
|
||||
messages = [{"role": "user", "content": messages}]
|
||||
@@ -582,7 +572,7 @@ class LiteAgent(FlowTrackable, BaseModel):
|
||||
self._show_logs(formatted_answer)
|
||||
return formatted_answer
|
||||
|
||||
def _show_logs(self, formatted_answer: Union[AgentAction, AgentFinish]):
|
||||
def _show_logs(self, formatted_answer: AgentAction | AgentFinish) -> None:
|
||||
"""Show logs for the agent's execution."""
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
|
||||
@@ -1,10 +1,11 @@
|
||||
import os
|
||||
from typing import Any, Dict, List
|
||||
from collections import defaultdict
|
||||
from typing import Any
|
||||
|
||||
from mem0 import Memory, MemoryClient
|
||||
from crewai.utilities.chromadb import sanitize_collection_name
|
||||
|
||||
from crewai.memory.storage.interface import Storage
|
||||
from crewai.utilities.chromadb import sanitize_collection_name
|
||||
|
||||
MAX_AGENT_ID_LENGTH_MEM0 = 255
|
||||
|
||||
@@ -13,6 +14,7 @@ class Mem0Storage(Storage):
|
||||
"""
|
||||
Extends Storage to handle embedding and searching across entities using Mem0.
|
||||
"""
|
||||
|
||||
def __init__(self, type, crew=None, config=None):
|
||||
super().__init__()
|
||||
|
||||
@@ -86,21 +88,21 @@ class Mem0Storage(Storage):
|
||||
|
||||
return filter
|
||||
|
||||
def save(self, value: Any, metadata: Dict[str, Any]) -> None:
|
||||
def save(self, value: Any, metadata: dict[str, Any]) -> None:
|
||||
user_id = self.config.get("user_id", "")
|
||||
assistant_message = [{"role" : "assistant","content" : value}]
|
||||
assistant_message = [{"role": "assistant", "content": value}]
|
||||
|
||||
base_metadata = {
|
||||
"short_term": "short_term",
|
||||
"long_term": "long_term",
|
||||
"entities": "entity",
|
||||
"external": "external"
|
||||
"external": "external",
|
||||
}
|
||||
|
||||
# Shared base params
|
||||
params: dict[str, Any] = {
|
||||
"metadata": {"type": base_metadata[self.memory_type], **metadata},
|
||||
"infer": self.infer
|
||||
"infer": self.infer,
|
||||
}
|
||||
|
||||
# MemoryClient-specific overrides
|
||||
@@ -121,13 +123,15 @@ class Mem0Storage(Storage):
|
||||
|
||||
self.memory.add(assistant_message, **params)
|
||||
|
||||
def search(self,query: str,limit: int = 3,score_threshold: float = 0.35) -> List[Any]:
|
||||
def search(
|
||||
self, query: str, limit: int = 3, score_threshold: float = 0.35
|
||||
) -> list[Any]:
|
||||
params = {
|
||||
"query": query,
|
||||
"limit": limit,
|
||||
"version": "v2",
|
||||
"output_format": "v1.1"
|
||||
}
|
||||
"output_format": "v1.1",
|
||||
}
|
||||
|
||||
if user_id := self.config.get("user_id", ""):
|
||||
params["user_id"] = user_id
|
||||
@@ -148,10 +152,10 @@ class Mem0Storage(Storage):
|
||||
# automatically when the crew is created.
|
||||
|
||||
params["filters"] = self._create_filter_for_search()
|
||||
params['threshold'] = score_threshold
|
||||
params["threshold"] = score_threshold
|
||||
|
||||
if isinstance(self.memory, Memory):
|
||||
del params["metadata"], params["version"], params['output_format']
|
||||
del params["metadata"], params["version"], params["output_format"]
|
||||
if params.get("run_id"):
|
||||
del params["run_id"]
|
||||
|
||||
@@ -160,10 +164,10 @@ class Mem0Storage(Storage):
|
||||
# This makes it compatible for Contextual Memory to retrieve
|
||||
for result in results["results"]:
|
||||
result["context"] = result["memory"]
|
||||
|
||||
|
||||
return [r for r in results["results"]]
|
||||
|
||||
def reset(self):
|
||||
def reset(self) -> None:
|
||||
if self.memory:
|
||||
self.memory.reset()
|
||||
|
||||
@@ -180,4 +184,6 @@ class Mem0Storage(Storage):
|
||||
agents = self.crew.agents
|
||||
agents = [self._sanitize_role(agent.role) for agent in agents]
|
||||
agents = "_".join(agents)
|
||||
return sanitize_collection_name(name=agents, max_collection_length=MAX_AGENT_ID_LENGTH_MEM0)
|
||||
return sanitize_collection_name(
|
||||
name=agents, max_collection_length=MAX_AGENT_ID_LENGTH_MEM0
|
||||
)
|
||||
|
||||
@@ -13,14 +13,14 @@ class CacheTools(BaseModel):
|
||||
default_factory=CacheHandler,
|
||||
)
|
||||
|
||||
def tool(self):
|
||||
def tool(self) -> CrewStructuredTool:
|
||||
return CrewStructuredTool.from_function(
|
||||
func=self.hit_cache,
|
||||
name=self.name,
|
||||
description="Reads directly from the cache",
|
||||
)
|
||||
|
||||
def hit_cache(self, key):
|
||||
def hit_cache(self, key: str) -> str:
|
||||
split = key.split("tool:")
|
||||
tool = split[1].split("|input:")[0].strip()
|
||||
tool_input = split[1].split("|input:")[1].strip()
|
||||
|
||||
@@ -5,10 +5,10 @@ import time
|
||||
from difflib import SequenceMatcher
|
||||
from json import JSONDecodeError
|
||||
from textwrap import dedent
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
|
||||
from typing import TYPE_CHECKING, Any, Optional, Union
|
||||
|
||||
import json5
|
||||
from json_repair import repair_json
|
||||
from json_repair import repair_json # type: ignore[import-untyped]
|
||||
|
||||
from crewai.agents.tools_handler import ToolsHandler
|
||||
from crewai.events.event_bus import crewai_event_bus
|
||||
@@ -100,7 +100,7 @@ class ToolUsage:
|
||||
self._max_parsing_attempts = 2
|
||||
self._remember_format_after_usages = 4
|
||||
|
||||
def parse_tool_calling(self, tool_string: str):
|
||||
def parse_tool_calling(self, tool_string: str) -> Any:
|
||||
"""Parse the tool string and return the tool calling."""
|
||||
return self._tool_calling(tool_string)
|
||||
|
||||
@@ -149,7 +149,21 @@ class ToolUsage:
|
||||
tool: CrewStructuredTool,
|
||||
calling: ToolCalling | InstructorToolCalling,
|
||||
) -> str:
|
||||
if self._check_tool_repeated_usage(calling=calling): # type: ignore # _check_tool_repeated_usage of "ToolUsage" does not return a value (it only ever returns None)
|
||||
"""Use a tool with the given calling information.
|
||||
|
||||
Args:
|
||||
tool_string: The string representation of the tool call.
|
||||
tool: The CrewStructuredTool instance to use.
|
||||
calling: The tool calling information.
|
||||
|
||||
Returns:
|
||||
The formatted result of the tool usage.
|
||||
|
||||
Notes:
|
||||
TODO: Investigate why BaseAgent/LiteAgent don't have fingerprint attribute.
|
||||
Currently using hasattr check as a workaround (lines 179-180).
|
||||
"""
|
||||
if self._check_tool_repeated_usage(calling=calling):
|
||||
try:
|
||||
result = self._i18n.errors("task_repeated_usage").format(
|
||||
tool_names=self.tools_names
|
||||
@@ -159,8 +173,8 @@ class ToolUsage:
|
||||
tool_name=tool.name,
|
||||
attempts=self._run_attempts,
|
||||
)
|
||||
result = self._format_result(result=result) # type: ignore # "_format_result" of "ToolUsage" does not return a value (it only ever returns None)
|
||||
return result # type: ignore # Fix the return type of this function
|
||||
result = self._format_result(result=result)
|
||||
return result
|
||||
|
||||
except Exception:
|
||||
if self.task:
|
||||
@@ -176,7 +190,7 @@ class ToolUsage:
|
||||
"agent": self.agent,
|
||||
}
|
||||
|
||||
if self.agent.fingerprint:
|
||||
if hasattr(self.agent, "fingerprint") and self.agent.fingerprint:
|
||||
event_data.update(self.agent.fingerprint)
|
||||
if self.task:
|
||||
event_data["task_name"] = self.task.name or self.task.description
|
||||
@@ -264,19 +278,19 @@ class ToolUsage:
|
||||
self._printer.print(
|
||||
content=f"\n\n{error_message}\n", color="red"
|
||||
)
|
||||
return error # type: ignore # No return value expected
|
||||
return error
|
||||
|
||||
if self.task:
|
||||
self.task.increment_tools_errors()
|
||||
return self.use(calling=calling, tool_string=tool_string) # type: ignore # No return value expected
|
||||
return self.use(calling=calling, tool_string=tool_string)
|
||||
|
||||
if self.tools_handler:
|
||||
should_cache = True
|
||||
if (
|
||||
hasattr(available_tool, "cache_function")
|
||||
and available_tool.cache_function # type: ignore # Item "None" of "Any | None" has no attribute "cache_function"
|
||||
and available_tool.cache_function
|
||||
):
|
||||
should_cache = available_tool.cache_function( # type: ignore # Item "None" of "Any | None" has no attribute "cache_function"
|
||||
should_cache = available_tool.cache_function(
|
||||
calling.arguments, result
|
||||
)
|
||||
|
||||
@@ -288,7 +302,7 @@ class ToolUsage:
|
||||
tool_name=tool.name,
|
||||
attempts=self._run_attempts,
|
||||
)
|
||||
result = self._format_result(result=result) # type: ignore # "_format_result" of "ToolUsage" does not return a value (it only ever returns None)
|
||||
result = self._format_result(result=result)
|
||||
data = {
|
||||
"result": result,
|
||||
"tool_name": tool.name,
|
||||
@@ -505,7 +519,7 @@ class ToolUsage:
|
||||
self.task.increment_tools_errors()
|
||||
if self.agent and self.agent.verbose:
|
||||
self._printer.print(content=f"\n\n{e}\n", color="red")
|
||||
return ToolUsageErrorException( # type: ignore # Incompatible return value type (got "ToolUsageErrorException", expected "ToolCalling | InstructorToolCalling")
|
||||
return ToolUsageErrorException(
|
||||
f"{self._i18n.errors('tool_usage_error').format(error=e)}\nMoving on then. {self._i18n.slice('format').format(tool_names=self.tools_names)}"
|
||||
)
|
||||
return self._tool_calling(tool_string)
|
||||
@@ -564,7 +578,7 @@ class ToolUsage:
|
||||
# If all parsing attempts fail, raise an error
|
||||
raise Exception(error_message)
|
||||
|
||||
def _emit_validate_input_error(self, final_error: str):
|
||||
def _emit_validate_input_error(self, final_error: str) -> None:
|
||||
tool_selection_data = {
|
||||
"agent_key": getattr(self.agent, "key", None) if self.agent else None,
|
||||
"agent_role": getattr(self.agent, "role", None) if self.agent else None,
|
||||
@@ -617,7 +631,20 @@ class ToolUsage:
|
||||
|
||||
def _prepare_event_data(
|
||||
self, tool: Any, tool_calling: ToolCalling | InstructorToolCalling
|
||||
) -> dict:
|
||||
) -> dict[str, Any]:
|
||||
"""Prepare event data for tool usage events.
|
||||
|
||||
Args:
|
||||
tool: The tool being used.
|
||||
tool_calling: The tool calling information containing arguments.
|
||||
|
||||
Returns:
|
||||
A dictionary containing event data for tool usage tracking.
|
||||
|
||||
Notes:
|
||||
TODO: Create a better type representation for the return value,
|
||||
possibly using TypedDict or a dataclass for stronger typing.
|
||||
"""
|
||||
event_data = {
|
||||
"run_attempts": self._run_attempts,
|
||||
"delegations": self.task.delegations if self.task else 0,
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any, Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
@@ -16,10 +16,10 @@ class ExecutionLog(BaseModel):
|
||||
|
||||
task_id: str
|
||||
expected_output: Optional[str] = None
|
||||
output: Dict[str, Any]
|
||||
output: dict[str, Any]
|
||||
timestamp: datetime = Field(default_factory=datetime.now)
|
||||
task_index: int
|
||||
inputs: Dict[str, Any] = Field(default_factory=dict)
|
||||
inputs: dict[str, Any] = Field(default_factory=dict)
|
||||
was_replayed: bool = False
|
||||
|
||||
def __getitem__(self, key: str) -> Any:
|
||||
@@ -33,7 +33,7 @@ class TaskOutputStorageHandler:
|
||||
def __init__(self) -> None:
|
||||
self.storage = KickoffTaskOutputsSQLiteStorage()
|
||||
|
||||
def update(self, task_index: int, log: Dict[str, Any]):
|
||||
def update(self, task_index: int, log: dict[str, Any]):
|
||||
saved_outputs = self.load()
|
||||
if saved_outputs is None:
|
||||
raise ValueError("Logs cannot be None")
|
||||
@@ -56,16 +56,16 @@ class TaskOutputStorageHandler:
|
||||
def add(
|
||||
self,
|
||||
task: Task,
|
||||
output: Dict[str, Any],
|
||||
output: dict[str, Any],
|
||||
task_index: int,
|
||||
inputs: Dict[str, Any] | None = None,
|
||||
inputs: dict[str, Any] | None = None,
|
||||
was_replayed: bool = False,
|
||||
):
|
||||
inputs = inputs or {}
|
||||
self.storage.add(task, output, task_index, was_replayed, inputs)
|
||||
|
||||
def reset(self):
|
||||
def reset(self) -> None:
|
||||
self.storage.delete_all()
|
||||
|
||||
def load(self) -> Optional[List[Dict[str, Any]]]:
|
||||
def load(self) -> Optional[list[dict[str, Any]]]:
|
||||
return self.storage.load()
|
||||
|
||||
Reference in New Issue
Block a user