From ad1ea46bbb959b6f6842d2de72c62d3418beeea8 Mon Sep 17 00:00:00 2001 From: Devin AI <158243242+devin-ai-integration[bot]@users.noreply.github.com> Date: Mon, 12 May 2025 13:30:50 +0000 Subject: [PATCH] Apply automatic linting fixes to src directory Co-Authored-By: Joe Moura --- src/crewai/agent.py | 153 ++++--- .../agent_adapters/base_agent_adapter.py | 12 +- .../agent_adapters/base_converter_adapter.py | 5 +- .../agent_adapters/base_tool_adapter.py | 16 +- .../langgraph/langgraph_adapter.py | 39 +- .../langgraph/langgraph_tool_adapter.py | 18 +- .../langgraph/structured_output_converter.py | 14 +- .../openai_agents/openai_adapter.py | 77 ++-- .../openai_agent_tool_adapter.py | 32 +- .../structured_output_converter.py | 20 +- src/crewai/agents/agent_builder/base_agent.py | 105 +++-- .../base_agent_executor_mixin.py | 19 +- .../utilities/base_output_converter.py | 8 +- src/crewai/agents/cache/cache_handler.py | 8 +- src/crewai/agents/crew_agent_executor.py | 125 +++--- src/crewai/agents/parser.py | 41 +- src/crewai/agents/tools_handler.py | 13 +- src/crewai/cli/add_crew_to_flow.py | 12 +- src/crewai/cli/authentication/main.py | 19 +- src/crewai/cli/authentication/token.py | 2 +- src/crewai/cli/authentication/utils.py | 31 +- src/crewai/cli/cli.py | 88 ++-- src/crewai/cli/command.py | 12 +- src/crewai/cli/config.py | 15 +- src/crewai/cli/constants.py | 14 +- src/crewai/cli/create_crew.py | 14 +- src/crewai/cli/create_flow.py | 6 +- src/crewai/cli/crew_chat.py | 108 +++-- src/crewai/cli/deploy/main.py | 111 ++--- src/crewai/cli/evaluate_crew.py | 7 +- src/crewai/cli/git.py | 19 +- src/crewai/cli/install_crew.py | 6 +- src/crewai/cli/kickoff_flow.py | 4 +- src/crewai/cli/plot_flow.py | 4 +- src/crewai/cli/plus_api.py | 21 +- src/crewai/cli/provider.py | 50 +-- src/crewai/cli/replay_from_task.py | 4 +- src/crewai/cli/reset_memories_command.py | 22 +- src/crewai/cli/run_crew.py | 12 +- src/crewai/cli/tools/main.py | 45 +- src/crewai/cli/train_crew.py | 10 +- src/crewai/cli/update_crew.py | 12 +- src/crewai/cli/utils.py | 98 ++-- src/crewai/cli/version.py | 2 +- src/crewai/crew.py | 417 +++++++++--------- src/crewai/crews/crew_output.py | 27 +- src/crewai/flow/flow.py | 251 ++++++----- src/crewai/flow/flow_trackable.py | 3 +- src/crewai/flow/flow_visualizer.py | 107 +++-- src/crewai/flow/html_template_handler.py | 23 +- src/crewai/flow/path_utils.py | 75 ++-- src/crewai/flow/persistence/base.py | 27 +- src/crewai/flow/persistence/decorators.py | 110 ++--- src/crewai/flow/persistence/sqlite.py | 25 +- src/crewai/flow/utils.py | 108 ++--- src/crewai/flow/visualization_utils.py | 64 ++- .../knowledge/embedder/base_embedder.py | 27 +- src/crewai/knowledge/embedder/fastembed.py | 42 +- src/crewai/knowledge/knowledge.py | 50 ++- src/crewai/knowledge/knowledge_config.py | 1 + .../source/base_file_knowledge_source.py | 48 +- .../knowledge/source/base_knowledge_source.py | 26 +- .../knowledge/source/crew_docling_source.py | 51 ++- .../knowledge/source/csv_knowledge_source.py | 10 +- .../source/excel_knowledge_source.py | 52 +-- .../knowledge/source/json_knowledge_source.py | 15 +- .../knowledge/source/pdf_knowledge_source.py | 13 +- .../source/string_knowledge_source.py | 12 +- .../source/text_file_knowledge_source.py | 12 +- .../storage/base_knowledge_storage.py | 13 +- .../knowledge/storage/knowledge_storage.py | 62 +-- src/crewai/knowledge/utils/knowledge_utils.py | 4 +- src/crewai/lite_agent.py | 114 ++--- src/crewai/llm.py | 249 ++++++----- src/crewai/llms/base_llm.py | 25 +- src/crewai/llms/third_party/ai_suite.py | 20 +- .../memory/contextual/contextual_memory.py | 39 +- src/crewai/memory/entity/entity_memory.py | 14 +- .../memory/entity/entity_memory_item.py | 2 +- src/crewai/memory/external/external_memory.py | 23 +- .../memory/external/external_memory_item.py | 8 +- .../memory/long_term/long_term_memory.py | 9 +- .../memory/long_term/long_term_memory_item.py | 8 +- src/crewai/memory/memory.py | 20 +- .../memory/short_term/short_term_memory.py | 21 +- .../short_term/short_term_memory_item.py | 8 +- src/crewai/memory/storage/base_rag_storage.py | 34 +- src/crewai/memory/storage/interface.py | 10 +- .../storage/kickoff_task_outputs_storage.py | 35 +- .../memory/storage/ltm_sqlite_storage.py | 25 +- src/crewai/memory/storage/mem0_storage.py | 39 +- src/crewai/memory/storage/rag_storage.py | 44 +- src/crewai/memory/user/user_memory.py | 20 +- src/crewai/memory/user/user_memory_item.py | 4 +- src/crewai/process.py | 4 +- src/crewai/project/annotations.py | 16 +- src/crewai/project/crew_base.py | 66 +-- src/crewai/security/fingerprint.py | 90 ++-- src/crewai/security/security_config.py | 65 ++- src/crewai/task.py | 218 ++++----- src/crewai/tasks/conditional_task.py | 12 +- src/crewai/tasks/guardrail_result.py | 22 +- src/crewai/tasks/llm_guardrail.py | 24 +- src/crewai/tasks/task_output.py | 29 +- src/crewai/telemetry/telemetry.py | 148 ++++--- .../tools/agent_tools/add_image_tool.py | 9 +- src/crewai/tools/agent_tools/agent_tools.py | 6 +- .../tools/agent_tools/ask_question_tool.py | 5 +- .../tools/agent_tools/base_agent_tools.py | 29 +- .../tools/agent_tools/delegate_work_tool.py | 7 +- src/crewai/tools/base_tool.py | 43 +- src/crewai/tools/structured_tool.py | 54 ++- src/crewai/tools/tool_calling.py | 12 +- src/crewai/tools/tool_usage.py | 134 +++--- src/crewai/types/crew_chat.py | 21 +- src/crewai/types/usage_metrics.py | 18 +- src/crewai/utilities/agent_utils.py | 123 +++--- src/crewai/utilities/chromadb.py | 13 +- src/crewai/utilities/config.py | 10 +- src/crewai/utilities/converter.py | 102 ++--- src/crewai/utilities/crew_json_encoder.py | 9 +- .../utilities/crew_pydantic_output_parser.py | 4 +- .../utilities/embedding_configurator.py | 35 +- src/crewai/utilities/errors.py | 5 +- .../evaluators/crew_evaluator_handler.py | 29 +- .../utilities/evaluators/task_evaluator.py | 41 +- src/crewai/utilities/events/agent_events.py | 36 +- .../utilities/events/base_event_listener.py | 3 +- src/crewai/utilities/events/base_events.py | 14 +- src/crewai/utilities/events/crew_events.py | 36 +- .../utilities/events/crewai_event_bus.py | 40 +- src/crewai/utilities/events/event_listener.py | 109 +++-- src/crewai/utilities/events/flow_events.py | 28 +- .../utilities/events/knowledge_events.py | 5 +- src/crewai/utilities/events/llm_events.py | 29 +- .../utilities/events/llm_guardrail_events.py | 14 +- src/crewai/utilities/events/task_events.py | 28 +- .../events/third_party/agentops_listener.py | 16 +- .../utilities/events/tool_usage_events.py | 29 +- .../events/utils/console_formatter.py | 178 ++++---- .../context_window_exceeding_exception.py | 4 +- src/crewai/utilities/file_handler.py | 56 +-- src/crewai/utilities/formatter.py | 10 +- src/crewai/utilities/i18n.py | 21 +- src/crewai/utilities/internal_instructor.py | 17 +- src/crewai/utilities/llm_utils.py | 98 ++-- src/crewai/utilities/logger.py | 4 +- src/crewai/utilities/parser.py | 14 +- src/crewai/utilities/paths.py | 7 +- src/crewai/utilities/planning_handler.py | 20 +- src/crewai/utilities/printer.py | 53 ++- src/crewai/utilities/prompts.py | 32 +- .../utilities/pydantic_schema_parser.py | 39 +- src/crewai/utilities/rpm_controller.py | 28 +- src/crewai/utilities/serialization.py | 26 +- src/crewai/utilities/string_utils.py | 21 +- .../utilities/task_output_storage_handler.py | 26 +- .../utilities/token_counter_callback.py | 16 +- src/crewai/utilities/tool_utils.py | 42 +- src/crewai/utilities/training_handler.py | 14 +- 160 files changed, 3218 insertions(+), 3197 deletions(-) diff --git a/src/crewai/agent.py b/src/crewai/agent.py index dc637967f..f7bad9098 100644 --- a/src/crewai/agent.py +++ b/src/crewai/agent.py @@ -1,6 +1,7 @@ import shutil import subprocess -from typing import Any, Dict, List, Literal, Optional, Sequence, Type, Union +from collections.abc import Sequence +from typing import Any, Literal from pydantic import Field, InstanceOf, PrivateAttr, model_validator @@ -67,40 +68,41 @@ class Agent(BaseAgent): step_callback: Callback to be executed after each step of the agent execution. knowledge_sources: Knowledge sources for the agent. embedder: Embedder configuration for the agent. + """ _times_executed: int = PrivateAttr(default=0) - max_execution_time: Optional[int] = Field( + max_execution_time: int | None = Field( default=None, description="Maximum execution time for an agent to execute a task", ) agent_ops_agent_name: str = None # type: ignore # Incompatible types in assignment (expression has type "None", variable has type "str") agent_ops_agent_id: str = None # type: ignore # Incompatible types in assignment (expression has type "None", variable has type "str") - step_callback: Optional[Any] = Field( + step_callback: Any | None = Field( default=None, description="Callback to be executed after each step of the agent execution.", ) - use_system_prompt: Optional[bool] = Field( + use_system_prompt: bool | None = Field( default=True, description="Use system prompt for the agent.", ) - llm: Union[str, InstanceOf[BaseLLM], Any] = Field( - description="Language model that will run the agent.", default=None + llm: str | InstanceOf[BaseLLM] | Any = Field( + description="Language model that will run the agent.", default=None, ) - function_calling_llm: Optional[Union[str, InstanceOf[BaseLLM], Any]] = Field( - description="Language model that will run the agent.", default=None + function_calling_llm: str | InstanceOf[BaseLLM] | Any | None = Field( + description="Language model that will run the agent.", default=None, ) - system_template: Optional[str] = Field( - default=None, description="System format for the agent." + system_template: str | None = Field( + default=None, description="System format for the agent.", ) - prompt_template: Optional[str] = Field( - default=None, description="Prompt format for the agent." + prompt_template: str | None = Field( + default=None, description="Prompt format for the agent.", ) - response_template: Optional[str] = Field( - default=None, description="Response format for the agent." + response_template: str | None = Field( + default=None, description="Response format for the agent.", ) - allow_code_execution: Optional[bool] = Field( - default=False, description="Enable code execution for the agent." + allow_code_execution: bool | None = Field( + default=False, description="Enable code execution for the agent.", ) respect_context_window: bool = Field( default=True, @@ -118,19 +120,19 @@ class Agent(BaseAgent): default="safe", description="Mode for code execution: 'safe' (using Docker) or 'unsafe' (direct execution).", ) - embedder: Optional[Dict[str, Any]] = Field( + embedder: dict[str, Any] | None = Field( default=None, description="Embedder configuration for the agent.", ) - agent_knowledge_context: Optional[str] = Field( + agent_knowledge_context: str | None = Field( default=None, description="Knowledge context for the agent.", ) - crew_knowledge_context: Optional[str] = Field( + crew_knowledge_context: str | None = Field( default=None, description="Knowledge context for the crew.", ) - knowledge_search_query: Optional[str] = Field( + knowledge_search_query: str | None = Field( default=None, description="Knowledge search query for the agent dynamically generated by the agent.", ) @@ -141,7 +143,7 @@ class Agent(BaseAgent): self.llm = create_llm(self.llm) if self.function_calling_llm and not isinstance( - self.function_calling_llm, BaseLLM + self.function_calling_llm, BaseLLM, ): self.function_calling_llm = create_llm(self.function_calling_llm) @@ -153,12 +155,12 @@ class Agent(BaseAgent): return self - def _setup_agent_executor(self): + def _setup_agent_executor(self) -> None: if not self.cache_handler: self.cache_handler = CacheHandler() self.set_cache_handler(self.cache_handler) - def set_knowledge(self, crew_embedder: Optional[Dict[str, Any]] = None): + def set_knowledge(self, crew_embedder: dict[str, Any] | None = None) -> None: try: if self.embedder is None and crew_embedder: self.embedder = crew_embedder @@ -174,7 +176,8 @@ class Agent(BaseAgent): storage=self.knowledge_storage or None, ) except (TypeError, ValueError) as e: - raise ValueError(f"Invalid Knowledge Configuration: {str(e)}") + msg = f"Invalid Knowledge Configuration: {e!s}" + raise ValueError(msg) def _is_any_available_memory(self) -> bool: """Check if any memory is available.""" @@ -196,8 +199,8 @@ class Agent(BaseAgent): def execute_task( self, task: Task, - context: Optional[str] = None, - tools: Optional[List[BaseTool]] = None, + context: str | None = None, + tools: list[BaseTool] | None = None, ) -> str: """Execute a task with the agent. @@ -213,6 +216,7 @@ class Agent(BaseAgent): TimeoutError: If execution exceeds the maximum execution time. ValueError: If the max execution time is not a positive integer. RuntimeError: If the agent execution fails for other reasons. + """ if self.tools_handler: self.tools_handler.last_used_tool = {} # type: ignore # Incompatible types in assignment (expression has type "dict[Never, Never]", variable has type "ToolCalling") @@ -228,18 +232,18 @@ class Agent(BaseAgent): # schema = json.dumps(task.output_json, indent=2) schema = generate_model_description(task.output_json) task_prompt += "\n" + self.i18n.slice( - "formatted_task_instructions" + "formatted_task_instructions", ).format(output_format=schema) elif task.output_pydantic: schema = generate_model_description(task.output_pydantic) task_prompt += "\n" + self.i18n.slice( - "formatted_task_instructions" + "formatted_task_instructions", ).format(output_format=schema) if context: task_prompt = self.i18n.slice("task_with_context").format( - task=task_prompt, context=context + task=task_prompt, context=context, ) if self._is_any_available_memory(): @@ -267,25 +271,25 @@ class Agent(BaseAgent): ) try: self.knowledge_search_query = self._get_knowledge_search_query( - task_prompt + task_prompt, ) if self.knowledge_search_query: agent_knowledge_snippets = self.knowledge.query( - [self.knowledge_search_query], **knowledge_config + [self.knowledge_search_query], **knowledge_config, ) if agent_knowledge_snippets: self.agent_knowledge_context = extract_knowledge_context( - agent_knowledge_snippets + agent_knowledge_snippets, ) if self.agent_knowledge_context: task_prompt += self.agent_knowledge_context if self.crew: knowledge_snippets = self.crew.query_knowledge( - [self.knowledge_search_query], **knowledge_config + [self.knowledge_search_query], **knowledge_config, ) if knowledge_snippets: self.crew_knowledge_context = extract_knowledge_context( - knowledge_snippets + knowledge_snippets, ) if self.crew_knowledge_context: task_prompt += self.crew_knowledge_context @@ -342,11 +346,12 @@ class Agent(BaseAgent): not isinstance(self.max_execution_time, int) or self.max_execution_time <= 0 ): + msg = "Max Execution time must be a positive integer greater than zero" raise ValueError( - "Max Execution time must be a positive integer greater than zero" + msg, ) result = self._execute_with_timeout( - task_prompt, task, self.max_execution_time + task_prompt, task, self.max_execution_time, ) else: result = self._execute_without_timeout(task_prompt, task) @@ -361,7 +366,7 @@ class Agent(BaseAgent): error=str(e), ), ) - raise e + raise except Exception as e: if e.__class__.__module__.startswith("litellm"): # Do not retry on litellm errors @@ -373,7 +378,7 @@ class Agent(BaseAgent): error=str(e), ), ) - raise e + raise self._times_executed += 1 if self._times_executed > self.max_retry_limit: crewai_event_bus.emit( @@ -384,7 +389,7 @@ class Agent(BaseAgent): error=str(e), ), ) - raise e + raise result = self.execute_task(task, context, tools) if self.max_rpm and self._rpm_controller: @@ -416,24 +421,27 @@ class Agent(BaseAgent): Raises: TimeoutError: If execution exceeds the timeout. RuntimeError: If execution fails for other reasons. + """ import concurrent.futures with concurrent.futures.ThreadPoolExecutor() as executor: future = executor.submit( - self._execute_without_timeout, task_prompt=task_prompt, task=task + self._execute_without_timeout, task_prompt=task_prompt, task=task, ) try: return future.result(timeout=timeout) except concurrent.futures.TimeoutError: future.cancel() + msg = f"Task '{task.description}' execution timed out after {timeout} seconds. Consider increasing max_execution_time or optimizing the task." raise TimeoutError( - f"Task '{task.description}' execution timed out after {timeout} seconds. Consider increasing max_execution_time or optimizing the task." + msg, ) except Exception as e: future.cancel() - raise RuntimeError(f"Task execution failed: {str(e)}") + msg = f"Task execution failed: {e!s}" + raise RuntimeError(msg) def _execute_without_timeout(self, task_prompt: str, task: Task) -> str: """Execute a task without a timeout. @@ -444,6 +452,7 @@ class Agent(BaseAgent): Returns: The output of the agent. + """ return self.agent_executor.invoke( { @@ -451,18 +460,19 @@ class Agent(BaseAgent): "tool_names": self.agent_executor.tools_names, "tools": self.agent_executor.tools_description, "ask_for_human_input": task.human_input, - } + }, )["output"] def create_agent_executor( - self, tools: Optional[List[BaseTool]] = None, task=None + self, tools: list[BaseTool] | None = None, task=None, ) -> None: """Create an agent executor for the agent. Returns: An instance of the CrewAgentExecutor class. + """ - raw_tools: List[BaseTool] = tools or self.tools or [] + raw_tools: list[BaseTool] = tools or self.tools or [] parsed_tools = parse_tools(raw_tools) prompt = Prompts( @@ -479,7 +489,7 @@ class Agent(BaseAgent): if self.response_template: stop_words.append( - self.response_template.split("{{ .Response }}")[1].strip() + self.response_template.split("{{ .Response }}")[1].strip(), ) self.agent_executor = CrewAgentExecutor( @@ -504,10 +514,9 @@ class Agent(BaseAgent): callbacks=[TokenCalcHandler(self._token_process)], ) - def get_delegation_tools(self, agents: List[BaseAgent]): + def get_delegation_tools(self, agents: list[BaseAgent]): agent_tools = AgentTools(agents=agents) - tools = agent_tools.tools() - return tools + return agent_tools.tools() def get_multimodal_tools(self) -> Sequence[BaseTool]: from crewai.tools.agent_tools.add_image_tool import AddImageTool @@ -523,7 +532,7 @@ class Agent(BaseAgent): return [CodeInterpreterTool(unsafe_mode=unsafe_mode)] except ModuleNotFoundError: self._logger.log( - "info", "Coding tools not available. Install crewai_tools. " + "info", "Coding tools not available. Install crewai_tools. ", ) def get_output_converter(self, llm, text, model, instructions): @@ -555,7 +564,7 @@ class Agent(BaseAgent): ) return task_prompt - def _render_text_description(self, tools: List[Any]) -> str: + def _render_text_description(self, tools: list[Any]) -> str: """Render the tool name and description in plain text. Output will be in the format of: @@ -565,48 +574,48 @@ class Agent(BaseAgent): search: This tool is used for search calculator: This tool is used for math """ - description = "\n".join( + return "\n".join( [ f"Tool name: {tool.name}\nTool description:\n{tool.description}" for tool in tools - ] + ], ) - return description def _validate_docker_installation(self) -> None: """Check if Docker is installed and running.""" if not shutil.which("docker"): + msg = f"Docker is not installed. Please install Docker to use code execution with agent: {self.role}" raise RuntimeError( - f"Docker is not installed. Please install Docker to use code execution with agent: {self.role}" + msg, ) try: subprocess.run( ["docker", "info"], check=True, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, + capture_output=True, ) except subprocess.CalledProcessError: + msg = f"Docker is not running. Please start Docker to use code execution with agent: {self.role}" raise RuntimeError( - f"Docker is not running. Please start Docker to use code execution with agent: {self.role}" + msg, ) - def __repr__(self): + def __repr__(self) -> str: return f"Agent(role={self.role}, goal={self.goal}, backstory={self.backstory})" @property def fingerprint(self) -> Fingerprint: - """ - Get the agent's fingerprint. + """Get the agent's fingerprint. Returns: Fingerprint: The agent's fingerprint + """ return self.security_config.fingerprint - def set_fingerprint(self, fingerprint: Fingerprint): + def set_fingerprint(self, fingerprint: Fingerprint) -> None: self.security_config.fingerprint = fingerprint def _get_knowledge_search_query(self, task_prompt: str) -> str | None: @@ -619,7 +628,7 @@ class Agent(BaseAgent): ), ) query = self.i18n.slice("knowledge_search_query").format( - task_prompt=task_prompt + task_prompt=task_prompt, ) rewriter_prompt = self.i18n.slice("knowledge_search_query_system_prompt") if not isinstance(self.llm, BaseLLM): @@ -644,7 +653,7 @@ class Agent(BaseAgent): "content": rewriter_prompt, }, {"role": "user", "content": query}, - ] + ], ) crewai_event_bus.emit( self, @@ -666,11 +675,10 @@ class Agent(BaseAgent): def kickoff( self, - messages: Union[str, List[Dict[str, str]]], - response_format: Optional[Type[Any]] = None, + messages: str | list[dict[str, str]], + response_format: type[Any] | None = None, ) -> LiteAgentOutput: - """ - Execute the agent with the given messages using a LiteAgent instance. + """Execute the agent with the given messages using a LiteAgent instance. This method is useful when you want to use the Agent configuration but with the simpler and more direct execution flow of LiteAgent. @@ -683,6 +691,7 @@ class Agent(BaseAgent): Returns: LiteAgentOutput: The result of the agent execution. + """ lite_agent = LiteAgent( role=self.role, @@ -703,11 +712,10 @@ class Agent(BaseAgent): async def kickoff_async( self, - messages: Union[str, List[Dict[str, str]]], - response_format: Optional[Type[Any]] = None, + messages: str | list[dict[str, str]], + response_format: type[Any] | None = None, ) -> LiteAgentOutput: - """ - Execute the agent asynchronously with the given messages using a LiteAgent instance. + """Execute the agent asynchronously with the given messages using a LiteAgent instance. This is the async version of the kickoff method. @@ -719,6 +727,7 @@ class Agent(BaseAgent): Returns: LiteAgentOutput: The result of the agent execution. + """ lite_agent = LiteAgent( role=self.role, diff --git a/src/crewai/agents/agent_adapters/base_agent_adapter.py b/src/crewai/agents/agent_adapters/base_agent_adapter.py index 6b8a151d6..604829c51 100644 --- a/src/crewai/agents/agent_adapters/base_agent_adapter.py +++ b/src/crewai/agents/agent_adapters/base_agent_adapter.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from typing import Any, Dict, List, Optional +from typing import Any from pydantic import PrivateAttr @@ -16,27 +16,27 @@ class BaseAgentAdapter(BaseAgent, ABC): """ adapted_structured_output: bool = False - _agent_config: Optional[Dict[str, Any]] = PrivateAttr(default=None) + _agent_config: dict[str, Any] | None = PrivateAttr(default=None) model_config = {"arbitrary_types_allowed": True} - def __init__(self, agent_config: Optional[Dict[str, Any]] = None, **kwargs: Any): + def __init__(self, agent_config: dict[str, Any] | None = None, **kwargs: Any) -> None: super().__init__(adapted_agent=True, **kwargs) self._agent_config = agent_config @abstractmethod - def configure_tools(self, tools: Optional[List[BaseTool]] = None) -> None: + def configure_tools(self, tools: list[BaseTool] | None = None) -> None: """Configure and adapt tools for the specific agent implementation. Args: tools: Optional list of BaseTool instances to be configured + """ - pass def configure_structured_output(self, structured_output: Any) -> None: """Configure the structured output for the specific agent implementation. Args: structured_output: The structured output to be configured + """ - pass diff --git a/src/crewai/agents/agent_adapters/base_converter_adapter.py b/src/crewai/agents/agent_adapters/base_converter_adapter.py index 557e627f3..e194eae75 100644 --- a/src/crewai/agents/agent_adapters/base_converter_adapter.py +++ b/src/crewai/agents/agent_adapters/base_converter_adapter.py @@ -8,7 +8,7 @@ class BaseConverterAdapter(ABC): converter adapters must implement for converting structured output. """ - def __init__(self, agent_adapter): + def __init__(self, agent_adapter) -> None: self.agent_adapter = agent_adapter @abstractmethod @@ -16,14 +16,11 @@ class BaseConverterAdapter(ABC): """Configure agents to return structured output. Must support json and pydantic output. """ - pass @abstractmethod def enhance_system_prompt(self, base_prompt: str) -> str: """Enhance the system prompt with structured output instructions.""" - pass @abstractmethod def post_process_result(self, result: str) -> str: """Post-process the result to ensure it matches the expected format: string.""" - pass diff --git a/src/crewai/agents/agent_adapters/base_tool_adapter.py b/src/crewai/agents/agent_adapters/base_tool_adapter.py index f1ee438a8..e1cc68231 100644 --- a/src/crewai/agents/agent_adapters/base_tool_adapter.py +++ b/src/crewai/agents/agent_adapters/base_tool_adapter.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from typing import Any, List, Optional +from typing import Any from crewai.tools.base_tool import BaseTool @@ -12,23 +12,23 @@ class BaseToolAdapter(ABC): different frameworks and platforms. """ - original_tools: List[BaseTool] - converted_tools: List[Any] + original_tools: list[BaseTool] + converted_tools: list[Any] - def __init__(self, tools: Optional[List[BaseTool]] = None): + def __init__(self, tools: list[BaseTool] | None = None) -> None: self.original_tools = tools or [] self.converted_tools = [] @abstractmethod - def configure_tools(self, tools: List[BaseTool]) -> None: + def configure_tools(self, tools: list[BaseTool]) -> None: """Configure and convert tools for the specific implementation. Args: tools: List of BaseTool instances to be configured and converted - """ - pass - def tools(self) -> List[Any]: + """ + + def tools(self) -> list[Any]: """Return all converted tools.""" return self.converted_tools diff --git a/src/crewai/agents/agent_adapters/langgraph/langgraph_adapter.py b/src/crewai/agents/agent_adapters/langgraph/langgraph_adapter.py index ea2e373d2..d5afbebab 100644 --- a/src/crewai/agents/agent_adapters/langgraph/langgraph_adapter.py +++ b/src/crewai/agents/agent_adapters/langgraph/langgraph_adapter.py @@ -1,4 +1,4 @@ -from typing import Any, AsyncIterable, Dict, List, Optional +from typing import Any from pydantic import Field, PrivateAttr @@ -52,16 +52,17 @@ class LangGraphAgentAdapter(BaseAgentAdapter): role: str, goal: str, backstory: str, - tools: Optional[List[BaseTool]] = None, + tools: list[BaseTool] | None = None, llm: Any = None, max_iterations: int = 10, - agent_config: Optional[Dict[str, Any]] = None, + agent_config: dict[str, Any] | None = None, **kwargs, - ): + ) -> None: """Initialize the LangGraph agent adapter.""" if not LANGGRAPH_AVAILABLE: + msg = "LangGraph Agent Dependencies are not installed. Please install it using `uv add langchain-core langgraph`" raise ImportError( - "LangGraph Agent Dependencies are not installed. Please install it using `uv add langchain-core langgraph`" + msg, ) super().__init__( role=role, @@ -82,7 +83,7 @@ class LangGraphAgentAdapter(BaseAgentAdapter): try: self._memory = MemorySaver() - converted_tools: List[Any] = self._tool_adapter.tools() + converted_tools: list[Any] = self._tool_adapter.tools() if self._agent_config: self._graph = create_react_agent( model=self.llm, @@ -101,18 +102,18 @@ class LangGraphAgentAdapter(BaseAgentAdapter): except ImportError as e: self._logger.log( - "error", f"Failed to import LangGraph dependencies: {str(e)}" + "error", f"Failed to import LangGraph dependencies: {e!s}", ) raise except Exception as e: - self._logger.log("error", f"Error setting up LangGraph agent: {str(e)}") + self._logger.log("error", f"Error setting up LangGraph agent: {e!s}") raise def _build_system_prompt(self) -> str: """Build a system prompt for the LangGraph agent.""" base_prompt = f""" You are {self.role}. - + Your goal is: {self.goal} Your backstory: {self.backstory} @@ -124,8 +125,8 @@ class LangGraphAgentAdapter(BaseAgentAdapter): def execute_task( self, task: Any, - context: Optional[str] = None, - tools: Optional[List[BaseTool]] = None, + context: str | None = None, + tools: list[BaseTool] | None = None, ) -> str: """Execute a task using the LangGraph workflow.""" self.create_agent_executor(tools) @@ -137,7 +138,7 @@ class LangGraphAgentAdapter(BaseAgentAdapter): if context: task_prompt = self.i18n.slice("task_with_context").format( - task=task_prompt, context=context + task=task_prompt, context=context, ) crewai_event_bus.emit( @@ -159,7 +160,7 @@ class LangGraphAgentAdapter(BaseAgentAdapter): "messages": [ ("system", self._build_system_prompt()), ("user", task_prompt), - ] + ], }, config, ) @@ -180,14 +181,14 @@ class LangGraphAgentAdapter(BaseAgentAdapter): crewai_event_bus.emit( self, event=AgentExecutionCompletedEvent( - agent=self, task=task, output=final_answer + agent=self, task=task, output=final_answer, ), ) return final_answer except Exception as e: - self._logger.log("error", f"Error executing LangGraph task: {str(e)}") + self._logger.log("error", f"Error executing LangGraph task: {e!s}") crewai_event_bus.emit( self, event=AgentExecutionErrorEvent( @@ -198,11 +199,11 @@ class LangGraphAgentAdapter(BaseAgentAdapter): ) raise - def create_agent_executor(self, tools: Optional[List[BaseTool]] = None) -> None: + def create_agent_executor(self, tools: list[BaseTool] | None = None) -> None: """Configure the LangGraph agent for execution.""" self.configure_tools(tools) - def configure_tools(self, tools: Optional[List[BaseTool]] = None) -> None: + def configure_tools(self, tools: list[BaseTool] | None = None) -> None: """Configure tools for the LangGraph agent.""" if tools: all_tools = list(self.tools or []) + list(tools or []) @@ -210,13 +211,13 @@ class LangGraphAgentAdapter(BaseAgentAdapter): available_tools = self._tool_adapter.tools() self._graph.tools = available_tools - def get_delegation_tools(self, agents: List[BaseAgent]) -> List[BaseTool]: + def get_delegation_tools(self, agents: list[BaseAgent]) -> list[BaseTool]: """Implement delegation tools support for LangGraph.""" agent_tools = AgentTools(agents=agents) return agent_tools.tools() def get_output_converter( - self, llm: Any, text: str, model: Any, instructions: str + self, llm: Any, text: str, model: Any, instructions: str, ) -> Any: """Convert output format if needed.""" return Converter(llm=llm, text=text, model=model, instructions=instructions) diff --git a/src/crewai/agents/agent_adapters/langgraph/langgraph_tool_adapter.py b/src/crewai/agents/agent_adapters/langgraph/langgraph_tool_adapter.py index 0bc31d201..d1fcfc166 100644 --- a/src/crewai/agents/agent_adapters/langgraph/langgraph_tool_adapter.py +++ b/src/crewai/agents/agent_adapters/langgraph/langgraph_tool_adapter.py @@ -1,29 +1,25 @@ import inspect -from typing import Any, List, Optional +from typing import Any from crewai.agents.agent_adapters.base_tool_adapter import BaseToolAdapter from crewai.tools.base_tool import BaseTool class LangGraphToolAdapter(BaseToolAdapter): - """Adapts CrewAI tools to LangGraph agent tool compatible format""" + """Adapts CrewAI tools to LangGraph agent tool compatible format.""" - def __init__(self, tools: Optional[List[BaseTool]] = None): + def __init__(self, tools: list[BaseTool] | None = None) -> None: self.original_tools = tools or [] self.converted_tools = [] - def configure_tools(self, tools: List[BaseTool]) -> None: - """ - Configure and convert CrewAI tools to LangGraph-compatible format. + def configure_tools(self, tools: list[BaseTool]) -> None: + """Configure and convert CrewAI tools to LangGraph-compatible format. LangGraph expects tools in langchain_core.tools format. """ from langchain_core.tools import BaseTool, StructuredTool converted_tools = [] - if self.original_tools: - all_tools = tools + self.original_tools - else: - all_tools = tools + all_tools = tools + self.original_tools if self.original_tools else tools for tool in all_tools: if isinstance(tool, BaseTool): converted_tools.append(tool) @@ -57,5 +53,5 @@ class LangGraphToolAdapter(BaseToolAdapter): self.converted_tools = converted_tools - def tools(self) -> List[Any]: + def tools(self) -> list[Any]: return self.converted_tools or [] diff --git a/src/crewai/agents/agent_adapters/langgraph/structured_output_converter.py b/src/crewai/agents/agent_adapters/langgraph/structured_output_converter.py index 79f6dcb15..b1560e074 100644 --- a/src/crewai/agents/agent_adapters/langgraph/structured_output_converter.py +++ b/src/crewai/agents/agent_adapters/langgraph/structured_output_converter.py @@ -5,10 +5,10 @@ from crewai.utilities.converter import generate_model_description class LangGraphConverterAdapter(BaseConverterAdapter): - """Adapter for handling structured output conversion in LangGraph agents""" + """Adapter for handling structured output conversion in LangGraph agents.""" - def __init__(self, agent_adapter): - """Initialize the converter adapter with a reference to the agent adapter""" + def __init__(self, agent_adapter) -> None: + """Initialize the converter adapter with a reference to the agent adapter.""" self.agent_adapter = agent_adapter self._output_format = None self._schema = None @@ -32,7 +32,7 @@ class LangGraphConverterAdapter(BaseConverterAdapter): self._system_prompt_appendix = self._generate_system_prompt_appendix() def _generate_system_prompt_appendix(self) -> str: - """Generate an appendix for the system prompt to enforce structured output""" + """Generate an appendix for the system prompt to enforce structured output.""" if not self._output_format or not self._schema: return "" @@ -41,19 +41,19 @@ Important: Your final answer MUST be provided in the following structured format {self._schema} -DO NOT include any markdown code blocks, backticks, or other formatting around your response. +DO NOT include any markdown code blocks, backticks, or other formatting around your response. The output should be raw JSON that exactly matches the specified schema. """ def enhance_system_prompt(self, original_prompt: str) -> str: - """Add structured output instructions to the system prompt if needed""" + """Add structured output instructions to the system prompt if needed.""" if not self._system_prompt_appendix: return original_prompt return f"{original_prompt}\n{self._system_prompt_appendix}" def post_process_result(self, result: str) -> str: - """Post-process the result to ensure it matches the expected format""" + """Post-process the result to ensure it matches the expected format.""" if not self._output_format: return result diff --git a/src/crewai/agents/agent_adapters/openai_agents/openai_adapter.py b/src/crewai/agents/agent_adapters/openai_agents/openai_adapter.py index ac368c1a3..43a621055 100644 --- a/src/crewai/agents/agent_adapters/openai_agents/openai_adapter.py +++ b/src/crewai/agents/agent_adapters/openai_agents/openai_adapter.py @@ -1,4 +1,4 @@ -from typing import Any, List, Optional +from typing import Any from pydantic import Field, PrivateAttr @@ -29,13 +29,13 @@ except ImportError: class OpenAIAgentAdapter(BaseAgentAdapter): - """Adapter for OpenAI Assistants""" + """Adapter for OpenAI Assistants.""" model_config = {"arbitrary_types_allowed": True} _openai_agent: "OpenAIAgent" = PrivateAttr() _logger: Logger = PrivateAttr(default_factory=lambda: Logger()) - _active_thread: Optional[str] = PrivateAttr(default=None) + _active_thread: str | None = PrivateAttr(default=None) function_calling_llm: Any = Field(default=None) step_callback: Any = Field(default=None) _tool_adapter: "OpenAIAgentToolAdapter" = PrivateAttr() @@ -44,35 +44,35 @@ class OpenAIAgentAdapter(BaseAgentAdapter): def __init__( self, model: str = "gpt-4o-mini", - tools: Optional[List[BaseTool]] = None, - agent_config: Optional[dict] = None, + tools: list[BaseTool] | None = None, + agent_config: dict | None = None, **kwargs, - ): + ) -> None: if not OPENAI_AVAILABLE: + msg = "OpenAI Agent Dependencies are not installed. Please install it using `uv add openai-agents`" raise ImportError( - "OpenAI Agent Dependencies are not installed. Please install it using `uv add openai-agents`" + msg, ) - else: - role = kwargs.pop("role", None) - goal = kwargs.pop("goal", None) - backstory = kwargs.pop("backstory", None) - super().__init__( - role=role, - goal=goal, - backstory=backstory, - tools=tools, - agent_config=agent_config, - **kwargs, - ) - self._tool_adapter = OpenAIAgentToolAdapter(tools=tools) - self.llm = model - self._converter_adapter = OpenAIConverterAdapter(self) + role = kwargs.pop("role", None) + goal = kwargs.pop("goal", None) + backstory = kwargs.pop("backstory", None) + super().__init__( + role=role, + goal=goal, + backstory=backstory, + tools=tools, + agent_config=agent_config, + **kwargs, + ) + self._tool_adapter = OpenAIAgentToolAdapter(tools=tools) + self.llm = model + self._converter_adapter = OpenAIConverterAdapter(self) def _build_system_prompt(self) -> str: """Build a system prompt for the OpenAI agent.""" base_prompt = f""" You are {self.role}. - + Your goal is: {self.goal} Your backstory: {self.backstory} @@ -84,10 +84,10 @@ class OpenAIAgentAdapter(BaseAgentAdapter): def execute_task( self, task: Any, - context: Optional[str] = None, - tools: Optional[List[BaseTool]] = None, + context: str | None = None, + tools: list[BaseTool] | None = None, ) -> str: - """Execute a task using the OpenAI Assistant""" + """Execute a task using the OpenAI Assistant.""" self._converter_adapter.configure_structured_output(task) self.create_agent_executor(tools) @@ -98,7 +98,7 @@ class OpenAIAgentAdapter(BaseAgentAdapter): task_prompt = task.prompt() if context: task_prompt = self.i18n.slice("task_with_context").format( - task=task_prompt, context=context + task=task_prompt, context=context, ) crewai_event_bus.emit( self, @@ -114,13 +114,13 @@ class OpenAIAgentAdapter(BaseAgentAdapter): crewai_event_bus.emit( self, event=AgentExecutionCompletedEvent( - agent=self, task=task, output=final_answer + agent=self, task=task, output=final_answer, ), ) return final_answer except Exception as e: - self._logger.log("error", f"Error executing OpenAI task: {str(e)}") + self._logger.log("error", f"Error executing OpenAI task: {e!s}") crewai_event_bus.emit( self, event=AgentExecutionErrorEvent( @@ -131,9 +131,8 @@ class OpenAIAgentAdapter(BaseAgentAdapter): ) raise - def create_agent_executor(self, tools: Optional[List[BaseTool]] = None) -> None: - """ - Configure the OpenAI agent for execution. + def create_agent_executor(self, tools: list[BaseTool] | None = None) -> None: + """Configure the OpenAI agent for execution. While OpenAI handles execution differently through Runner, we can use this method to set up tools and configurations. """ @@ -152,27 +151,27 @@ class OpenAIAgentAdapter(BaseAgentAdapter): self.agent_executor = Runner - def configure_tools(self, tools: Optional[List[BaseTool]] = None) -> None: - """Configure tools for the OpenAI Assistant""" + def configure_tools(self, tools: list[BaseTool] | None = None) -> None: + """Configure tools for the OpenAI Assistant.""" if tools: self._tool_adapter.configure_tools(tools) if self._tool_adapter.converted_tools: self._openai_agent.tools = self._tool_adapter.converted_tools def handle_execution_result(self, result: Any) -> str: - """Process OpenAI Assistant execution result converting any structured output to a string""" + """Process OpenAI Assistant execution result converting any structured output to a string.""" return self._converter_adapter.post_process_result(result.final_output) - def get_delegation_tools(self, agents: List[BaseAgent]) -> List[BaseTool]: - """Implement delegation tools support""" + def get_delegation_tools(self, agents: list[BaseAgent]) -> list[BaseTool]: + """Implement delegation tools support.""" agent_tools = AgentTools(agents=agents) - tools = agent_tools.tools() - return tools + return agent_tools.tools() def configure_structured_output(self, task) -> None: """Configure the structured output for the specific agent implementation. Args: structured_output: The structured output to be configured + """ self._converter_adapter.configure_structured_output(task) diff --git a/src/crewai/agents/agent_adapters/openai_agents/openai_agent_tool_adapter.py b/src/crewai/agents/agent_adapters/openai_agents/openai_agent_tool_adapter.py index 92eeb7b00..d83872ed4 100644 --- a/src/crewai/agents/agent_adapters/openai_agents/openai_agent_tool_adapter.py +++ b/src/crewai/agents/agent_adapters/openai_agents/openai_agent_tool_adapter.py @@ -1,5 +1,5 @@ import inspect -from typing import Any, List, Optional +from typing import Any from agents import FunctionTool, Tool @@ -8,42 +8,36 @@ from crewai.tools import BaseTool class OpenAIAgentToolAdapter(BaseToolAdapter): - """Adapter for OpenAI Assistant tools""" + """Adapter for OpenAI Assistant tools.""" - def __init__(self, tools: Optional[List[BaseTool]] = None): + def __init__(self, tools: list[BaseTool] | None = None) -> None: self.original_tools = tools or [] - def configure_tools(self, tools: List[BaseTool]) -> None: - """Configure tools for the OpenAI Assistant""" - if self.original_tools: - all_tools = tools + self.original_tools - else: - all_tools = tools + def configure_tools(self, tools: list[BaseTool]) -> None: + """Configure tools for the OpenAI Assistant.""" + all_tools = tools + self.original_tools if self.original_tools else tools if all_tools: self.converted_tools = self._convert_tools_to_openai_format(all_tools) def _convert_tools_to_openai_format( - self, tools: Optional[List[BaseTool]] - ) -> List[Tool]: - """Convert CrewAI tools to OpenAI Assistant tool format""" + self, tools: list[BaseTool] | None, + ) -> list[Tool]: + """Convert CrewAI tools to OpenAI Assistant tool format.""" if not tools: return [] def sanitize_tool_name(name: str) -> str: - """Convert tool name to match OpenAI's required pattern""" + """Convert tool name to match OpenAI's required pattern.""" import re - sanitized = re.sub(r"[^a-zA-Z0-9_-]", "_", name).lower() - return sanitized + return re.sub(r"[^a-zA-Z0-9_-]", "_", name).lower() def create_tool_wrapper(tool: BaseTool): - """Create a wrapper function that handles the OpenAI function tool interface""" + """Create a wrapper function that handles the OpenAI function tool interface.""" async def wrapper(context_wrapper: Any, arguments: Any) -> Any: # Get the parameter name from the schema - param_name = list( - tool.args_schema.model_json_schema()["properties"].keys() - )[0] + param_name = next(iter(tool.args_schema.model_json_schema()["properties"].keys())) # Handle different argument types if isinstance(arguments, dict): diff --git a/src/crewai/agents/agent_adapters/openai_agents/structured_output_converter.py b/src/crewai/agents/agent_adapters/openai_agents/structured_output_converter.py index 252374bf0..fa86de3b3 100644 --- a/src/crewai/agents/agent_adapters/openai_agents/structured_output_converter.py +++ b/src/crewai/agents/agent_adapters/openai_agents/structured_output_converter.py @@ -7,8 +7,7 @@ from crewai.utilities.i18n import I18N class OpenAIConverterAdapter(BaseConverterAdapter): - """ - Adapter for handling structured output conversion in OpenAI agents. + """Adapter for handling structured output conversion in OpenAI agents. This adapter enhances the OpenAI agent to handle structured output formats and post-processes the results when needed. @@ -17,21 +16,22 @@ class OpenAIConverterAdapter(BaseConverterAdapter): _output_format: The expected output format (json, pydantic, or None) _schema: The schema description for the expected output _output_model: The Pydantic model for the output + """ - def __init__(self, agent_adapter): - """Initialize the converter adapter with a reference to the agent adapter""" + def __init__(self, agent_adapter) -> None: + """Initialize the converter adapter with a reference to the agent adapter.""" self.agent_adapter = agent_adapter self._output_format = None self._schema = None self._output_model = None def configure_structured_output(self, task) -> None: - """ - Configure the structured output for OpenAI agent based on task requirements. + """Configure the structured output for OpenAI agent based on task requirements. Args: task: The task containing output format requirements + """ # Reset configuration self._output_format = None @@ -55,14 +55,14 @@ class OpenAIConverterAdapter(BaseConverterAdapter): self._output_model = task.output_pydantic def enhance_system_prompt(self, base_prompt: str) -> str: - """ - Enhance the base system prompt with structured output requirements if needed. + """Enhance the base system prompt with structured output requirements if needed. Args: base_prompt: The original system prompt Returns: Enhanced system prompt with output format instructions if needed + """ if not self._output_format: return base_prompt @@ -76,8 +76,7 @@ class OpenAIConverterAdapter(BaseConverterAdapter): return f"{base_prompt}\n\n{output_schema}" def post_process_result(self, result: str) -> str: - """ - Post-process the result to ensure it matches the expected format. + """Post-process the result to ensure it matches the expected format. This method attempts to extract valid JSON from the result if necessary. @@ -86,6 +85,7 @@ class OpenAIConverterAdapter(BaseConverterAdapter): Returns: Processed result conforming to the expected output format + """ if not self._output_format: return result diff --git a/src/crewai/agents/agent_builder/base_agent.py b/src/crewai/agents/agent_builder/base_agent.py index ba2596f63..1c82e0011 100644 --- a/src/crewai/agents/agent_builder/base_agent.py +++ b/src/crewai/agents/agent_builder/base_agent.py @@ -1,8 +1,9 @@ import uuid from abc import ABC, abstractmethod +from collections.abc import Callable from copy import copy as shallow_copy from hashlib import md5 -from typing import Any, Callable, Dict, List, Optional, TypeVar +from typing import Any, TypeVar from pydantic import ( UUID4, @@ -14,6 +15,7 @@ from pydantic import ( model_validator, ) from pydantic_core import PydanticCustomError +from typing_extensions import Self from crewai.agents.agent_builder.utilities.base_token_process import TokenProcess from crewai.agents.cache.cache_handler import CacheHandler @@ -25,7 +27,6 @@ from crewai.security.security_config import SecurityConfig from crewai.tools.base_tool import BaseTool, Tool from crewai.utilities import I18N, Logger, RPMController from crewai.utilities.config import process_config -from crewai.utilities.converter import Converter from crewai.utilities.string_utils import interpolate_only T = TypeVar("T", bound="BaseAgent") @@ -77,30 +78,31 @@ class BaseAgent(ABC, BaseModel): Set the rpm controller for the agent. set_private_attrs() -> "BaseAgent": Set private attributes. + """ __hash__ = object.__hash__ # type: ignore _logger: Logger = PrivateAttr(default_factory=lambda: Logger(verbose=False)) - _rpm_controller: Optional[RPMController] = PrivateAttr(default=None) + _rpm_controller: RPMController | None = PrivateAttr(default=None) _request_within_rpm_limit: Any = PrivateAttr(default=None) - _original_role: Optional[str] = PrivateAttr(default=None) - _original_goal: Optional[str] = PrivateAttr(default=None) - _original_backstory: Optional[str] = PrivateAttr(default=None) + _original_role: str | None = PrivateAttr(default=None) + _original_goal: str | None = PrivateAttr(default=None) + _original_backstory: str | None = PrivateAttr(default=None) _token_process: TokenProcess = PrivateAttr(default_factory=TokenProcess) id: UUID4 = Field(default_factory=uuid.uuid4, frozen=True) role: str = Field(description="Role of the agent") goal: str = Field(description="Objective of the agent") backstory: str = Field(description="Backstory of the agent") - config: Optional[Dict[str, Any]] = Field( - description="Configuration for the agent", default=None, exclude=True + config: dict[str, Any] | None = Field( + description="Configuration for the agent", default=None, exclude=True, ) cache: bool = Field( - default=True, description="Whether the agent should use a cache for tool usage." + default=True, description="Whether the agent should use a cache for tool usage.", ) verbose: bool = Field( - default=False, description="Verbose mode for the Agent Execution" + default=False, description="Verbose mode for the Agent Execution", ) - max_rpm: Optional[int] = Field( + max_rpm: int | None = Field( default=None, description="Maximum number of requests per minute for the agent execution to be respected.", ) @@ -108,41 +110,41 @@ class BaseAgent(ABC, BaseModel): default=False, description="Enable agent to delegate and ask questions among each other.", ) - tools: Optional[List[BaseTool]] = Field( - default_factory=list, description="Tools at agents' disposal" + tools: list[BaseTool] | None = Field( + default_factory=list, description="Tools at agents' disposal", ) max_iter: int = Field( - default=25, description="Maximum iterations for an agent to execute a task" + default=25, description="Maximum iterations for an agent to execute a task", ) agent_executor: InstanceOf = Field( - default=None, description="An instance of the CrewAgentExecutor class." + default=None, description="An instance of the CrewAgentExecutor class.", ) llm: Any = Field( - default=None, description="Language model that will run the agent." + default=None, description="Language model that will run the agent.", ) crew: Any = Field(default=None, description="Crew to which the agent belongs.") i18n: I18N = Field(default=I18N(), description="Internationalization settings.") - cache_handler: Optional[InstanceOf[CacheHandler]] = Field( - default=None, description="An instance of the CacheHandler class." + cache_handler: InstanceOf[CacheHandler] | None = Field( + default=None, description="An instance of the CacheHandler class.", ) tools_handler: InstanceOf[ToolsHandler] = Field( default_factory=ToolsHandler, description="An instance of the ToolsHandler class.", ) - tools_results: List[Dict[str, Any]] = Field( - default=[], description="Results of the tools used by the agent." + tools_results: list[dict[str, Any]] = Field( + default=[], description="Results of the tools used by the agent.", ) - max_tokens: Optional[int] = Field( - default=None, description="Maximum number of tokens for the agent's execution." + max_tokens: int | None = Field( + default=None, description="Maximum number of tokens for the agent's execution.", ) - knowledge: Optional[Knowledge] = Field( - default=None, description="Knowledge for the agent." + knowledge: Knowledge | None = Field( + default=None, description="Knowledge for the agent.", ) - knowledge_sources: Optional[List[BaseKnowledgeSource]] = Field( + knowledge_sources: list[BaseKnowledgeSource] | None = Field( default=None, description="Knowledge sources for the agent.", ) - knowledge_storage: Optional[Any] = Field( + knowledge_storage: Any | None = Field( default=None, description="Custom knowledge storage for the agent.", ) @@ -150,13 +152,13 @@ class BaseAgent(ABC, BaseModel): default_factory=SecurityConfig, description="Security configuration for the agent, including fingerprinting.", ) - callbacks: List[Callable] = Field( - default=[], description="Callbacks to be used for the agent" + callbacks: list[Callable] = Field( + default=[], description="Callbacks to be used for the agent", ) adapted_agent: bool = Field( - default=False, description="Whether the agent is adapted" + default=False, description="Whether the agent is adapted", ) - knowledge_config: Optional[KnowledgeConfig] = Field( + knowledge_config: KnowledgeConfig | None = Field( default=None, description="Knowledge configuration for the agent such as limits and threshold", ) @@ -168,7 +170,7 @@ class BaseAgent(ABC, BaseModel): @field_validator("tools") @classmethod - def validate_tools(cls, tools: List[Any]) -> List[BaseTool]: + def validate_tools(cls, tools: list[Any]) -> list[BaseTool]: """Validate and process the tools provided to the agent. This method ensures that each tool is either an instance of BaseTool @@ -188,11 +190,14 @@ class BaseAgent(ABC, BaseModel): # Tool has the required attributes, create a Tool instance processed_tools.append(Tool.from_langchain(tool)) else: - raise ValueError( + msg = ( f"Invalid tool type: {type(tool)}. " "Tool must be an instance of BaseTool or " "an object with 'name', 'func', and 'description' attributes." ) + raise ValueError( + msg, + ) return processed_tools @model_validator(mode="after") @@ -200,15 +205,16 @@ class BaseAgent(ABC, BaseModel): # Validate required fields for field in ["role", "goal", "backstory"]: if getattr(self, field) is None: + msg = f"{field} must be provided either directly or through config" raise ValueError( - f"{field} must be provided either directly or through config" + msg, ) # Set private attributes self._logger = Logger(verbose=self.verbose) if self.max_rpm and not self._rpm_controller: self._rpm_controller = RPMController( - max_rpm=self.max_rpm, logger=self._logger + max_rpm=self.max_rpm, logger=self._logger, ) if not self._token_process: self._token_process = TokenProcess() @@ -221,10 +227,11 @@ class BaseAgent(ABC, BaseModel): @field_validator("id", mode="before") @classmethod - def _deny_user_set_id(cls, v: Optional[UUID4]) -> None: + def _deny_user_set_id(cls, v: UUID4 | None) -> None: if v: + msg = "may_not_set_field" raise PydanticCustomError( - "may_not_set_field", "This field is not to be set by the user.", {} + msg, "This field is not to be set by the user.", {}, ) @model_validator(mode="after") @@ -233,7 +240,7 @@ class BaseAgent(ABC, BaseModel): self._logger = Logger(verbose=self.verbose) if self.max_rpm and not self._rpm_controller: self._rpm_controller = RPMController( - max_rpm=self.max_rpm, logger=self._logger + max_rpm=self.max_rpm, logger=self._logger, ) if not self._token_process: self._token_process = TokenProcess() @@ -252,8 +259,8 @@ class BaseAgent(ABC, BaseModel): def execute_task( self, task: Any, - context: Optional[str] = None, - tools: Optional[List[BaseTool]] = None, + context: str | None = None, + tools: list[BaseTool] | None = None, ) -> str: pass @@ -262,11 +269,10 @@ class BaseAgent(ABC, BaseModel): pass @abstractmethod - def get_delegation_tools(self, agents: List["BaseAgent"]) -> List[BaseTool]: + def get_delegation_tools(self, agents: list["BaseAgent"]) -> list[BaseTool]: """Set the task tools that init BaseAgenTools class.""" - pass - def copy(self: T) -> T: # type: ignore # Signature of "copy" incompatible with supertype "BaseModel" + def copy(self) -> Self: # type: ignore # Signature of "copy" incompatible with supertype "BaseModel" """Create a deep copy of the Agent.""" exclude = { "id", @@ -309,7 +315,7 @@ class BaseAgent(ABC, BaseModel): copied_data = self.model_dump(exclude=exclude) copied_data = {k: v for k, v in copied_data.items() if v is not None} - copied_agent = type(self)( + return type(self)( **copied_data, llm=existing_llm, tools=self.tools, @@ -318,9 +324,8 @@ class BaseAgent(ABC, BaseModel): knowledge_storage=copied_knowledge_storage, ) - return copied_agent - def interpolate_inputs(self, inputs: Dict[str, Any]) -> None: + def interpolate_inputs(self, inputs: dict[str, Any]) -> None: """Interpolate inputs into the agent description and backstory.""" if self._original_role is None: self._original_role = self.role @@ -331,13 +336,13 @@ class BaseAgent(ABC, BaseModel): if inputs: self.role = interpolate_only( - input_string=self._original_role, inputs=inputs + input_string=self._original_role, inputs=inputs, ) self.goal = interpolate_only( - input_string=self._original_goal, inputs=inputs + input_string=self._original_goal, inputs=inputs, ) self.backstory = interpolate_only( - input_string=self._original_backstory, inputs=inputs + input_string=self._original_backstory, inputs=inputs, ) def set_cache_handler(self, cache_handler: CacheHandler) -> None: @@ -345,6 +350,7 @@ class BaseAgent(ABC, BaseModel): Args: cache_handler: An instance of the CacheHandler class. + """ self.tools_handler = ToolsHandler() if self.cache: @@ -357,10 +363,11 @@ class BaseAgent(ABC, BaseModel): Args: rpm_controller: An instance of the RPMController class. + """ if not self._rpm_controller: self._rpm_controller = rpm_controller self.create_agent_executor() - def set_knowledge(self, crew_embedder: Optional[Dict[str, Any]] = None): + def set_knowledge(self, crew_embedder: dict[str, Any] | None = None) -> None: pass diff --git a/src/crewai/agents/agent_builder/base_agent_executor_mixin.py b/src/crewai/agents/agent_builder/base_agent_executor_mixin.py index c46c46844..1197270c6 100644 --- a/src/crewai/agents/agent_builder/base_agent_executor_mixin.py +++ b/src/crewai/agents/agent_builder/base_agent_executor_mixin.py @@ -1,3 +1,4 @@ +import contextlib import time from typing import TYPE_CHECKING @@ -43,8 +44,7 @@ class CrewAgentExecutorMixin: }, agent=self.agent.role, ) - except Exception as e: - print(f"Failed to add to short term memory: {e}") + except Exception: pass def _create_external_memory(self, output) -> None: @@ -56,7 +56,7 @@ class CrewAgentExecutorMixin: and hasattr(self.crew, "_external_memory") and self.crew._external_memory ): - try: + with contextlib.suppress(Exception): self.crew._external_memory.save( value=output.text, metadata={ @@ -64,9 +64,6 @@ class CrewAgentExecutorMixin: }, agent=self.agent.role, ) - except Exception as e: - print(f"Failed to add to external memory: {e}") - pass def _create_long_term_memory(self, output) -> None: """Create and save long-term and entity memory items based on evaluation.""" @@ -103,15 +100,13 @@ class CrewAgentExecutorMixin: type=entity.type, description=entity.description, relationships="\n".join( - [f"- {r}" for r in entity.relationships] + [f"- {r}" for r in entity.relationships], ), ) self.crew._entity_memory.save(entity_memory) - except AttributeError as e: - print(f"Missing attributes for long term memory: {e}") + except AttributeError: pass - except Exception as e: - print(f"Failed to add to long term memory: {e}") + except Exception: pass elif ( self.crew @@ -126,7 +121,7 @@ class CrewAgentExecutorMixin: def _ask_human_input(self, final_answer: str) -> str: """Prompt human input with mode-appropriate messaging.""" self._printer.print( - content=f"\033[1m\033[95m ## Final Result:\033[00m \033[92m{final_answer}\033[00m" + content=f"\033[1m\033[95m ## Final Result:\033[00m \033[92m{final_answer}\033[00m", ) # Training mode prompt (single iteration) diff --git a/src/crewai/agents/agent_builder/utilities/base_output_converter.py b/src/crewai/agents/agent_builder/utilities/base_output_converter.py index 938a6b29a..7dcbcdacf 100644 --- a/src/crewai/agents/agent_builder/utilities/base_output_converter.py +++ b/src/crewai/agents/agent_builder/utilities/base_output_converter.py @@ -1,12 +1,11 @@ from abc import ABC, abstractmethod -from typing import Any, Optional +from typing import Any from pydantic import BaseModel, Field class OutputConverter(BaseModel, ABC): - """ - Abstract base class for converting task results into structured formats. + """Abstract base class for converting task results into structured formats. This class provides a framework for converting unstructured text into either Pydantic models or JSON, tailored for specific agent requirements. @@ -19,6 +18,7 @@ class OutputConverter(BaseModel, ABC): model (Any): The target model for structuring the output. instructions (str): Specific instructions for the conversion process. max_attempts (int): Maximum number of conversion attempts (default: 3). + """ text: str = Field(description="Text to be converted.") @@ -33,9 +33,7 @@ class OutputConverter(BaseModel, ABC): @abstractmethod def to_pydantic(self, current_attempt=1) -> BaseModel: """Convert text to pydantic.""" - pass @abstractmethod def to_json(self, current_attempt=1) -> dict: """Convert text to json.""" - pass diff --git a/src/crewai/agents/cache/cache_handler.py b/src/crewai/agents/cache/cache_handler.py index 09dd76f26..82577bb89 100644 --- a/src/crewai/agents/cache/cache_handler.py +++ b/src/crewai/agents/cache/cache_handler.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, Optional +from typing import Any from pydantic import BaseModel, PrivateAttr @@ -6,10 +6,10 @@ from pydantic import BaseModel, PrivateAttr class CacheHandler(BaseModel): """Callback handler for tool usage.""" - _cache: Dict[str, Any] = PrivateAttr(default_factory=dict) + _cache: dict[str, Any] = PrivateAttr(default_factory=dict) - def add(self, tool, input, output): + def add(self, tool, input, output) -> None: self._cache[f"{tool}-{input}"] = output - def read(self, tool, input) -> Optional[str]: + def read(self, tool, input) -> str | None: return self._cache.get(f"{tool}-{input}") diff --git a/src/crewai/agents/crew_agent_executor.py b/src/crewai/agents/crew_agent_executor.py index 914f837ee..ed66c67f6 100644 --- a/src/crewai/agents/crew_agent_executor.py +++ b/src/crewai/agents/crew_agent_executor.py @@ -1,6 +1,5 @@ -import json -import re -from typing import Any, Callable, Dict, List, Optional, Union +from collections.abc import Callable +from typing import TYPE_CHECKING, Any from crewai.agents.agent_builder.base_agent import BaseAgent from crewai.agents.agent_builder.base_agent_executor_mixin import CrewAgentExecutorMixin @@ -10,8 +9,6 @@ from crewai.agents.parser import ( OutputParserException, ) from crewai.agents.tools_handler import ToolsHandler -from crewai.llm import BaseLLM -from crewai.tools.base_tool import BaseTool from crewai.tools.structured_tool import CrewStructuredTool from crewai.tools.tool_types import ToolResult from crewai.utilities import I18N, Printer @@ -34,6 +31,10 @@ from crewai.utilities.logger import Logger from crewai.utilities.tool_utils import execute_tool_and_check_finality from crewai.utilities.training_handler import CrewTrainingHandler +if TYPE_CHECKING: + from crewai.llm import BaseLLM + from crewai.tools.base_tool import BaseTool + class CrewAgentExecutor(CrewAgentExecutorMixin): _logger: Logger = Logger() @@ -46,18 +47,22 @@ class CrewAgentExecutor(CrewAgentExecutorMixin): agent: BaseAgent, prompt: dict[str, str], max_iter: int, - tools: List[CrewStructuredTool], + tools: list[CrewStructuredTool], tools_names: str, - stop_words: List[str], + stop_words: list[str], tools_description: str, tools_handler: ToolsHandler, step_callback: Any = None, - original_tools: List[Any] = [], + original_tools: list[Any] | None = None, function_calling_llm: Any = None, respect_context_window: bool = False, - request_within_rpm_limit: Optional[Callable[[], bool]] = None, - callbacks: List[Any] = [], - ): + request_within_rpm_limit: Callable[[], bool] | None = None, + callbacks: list[Any] | None = None, + ) -> None: + if callbacks is None: + callbacks = [] + if original_tools is None: + original_tools = [] self._i18n: I18N = I18N() self.llm: BaseLLM = llm self.task = task @@ -79,10 +84,10 @@ class CrewAgentExecutor(CrewAgentExecutorMixin): self.respect_context_window = respect_context_window self.request_within_rpm_limit = request_within_rpm_limit self.ask_for_human_input = False - self.messages: List[Dict[str, str]] = [] + self.messages: list[dict[str, str]] = [] self.iterations = 0 self.log_error_after = 3 - self.tool_name_to_tool_map: Dict[str, Union[CrewStructuredTool, BaseTool]] = { + self.tool_name_to_tool_map: dict[str, CrewStructuredTool | BaseTool] = { tool.name: tool for tool in self.tools } existing_stop = self.llm.stop or [] @@ -90,11 +95,11 @@ class CrewAgentExecutor(CrewAgentExecutorMixin): set( existing_stop + self.stop if isinstance(existing_stop, list) - else self.stop - ) + else self.stop, + ), ) - def invoke(self, inputs: Dict[str, str]) -> Dict[str, Any]: + def invoke(self, inputs: dict[str, str]) -> dict[str, Any]: if "system" in self.prompt: system_prompt = self._format_prompt(self.prompt.get("system", ""), inputs) user_prompt = self._format_prompt(self.prompt.get("user", ""), inputs) @@ -120,9 +125,8 @@ class CrewAgentExecutor(CrewAgentExecutorMixin): handle_unknown_error(self._printer, e) if e.__class__.__module__.startswith("litellm"): # Do not retry on litellm errors - raise e - else: - raise e + raise + raise if self.ask_for_human_input: formatted_answer = self._handle_human_feedback(formatted_answer) @@ -133,8 +137,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin): return {"output": formatted_answer.output} def _invoke_loop(self) -> AgentFinish: - """ - Main loop to invoke the agent's thought process until it reaches a conclusion + """Main loop to invoke the agent's thought process until it reaches a conclusion or the maximum number of iterations is reached. """ formatted_answer = None @@ -170,8 +173,8 @@ class CrewAgentExecutor(CrewAgentExecutorMixin): ): fingerprint_context = { "agent_fingerprint": str( - self.agent.security_config.fingerprint - ) + self.agent.security_config.fingerprint, + ), } tool_result = execute_tool_and_check_finality( @@ -187,7 +190,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin): function_calling_llm=self.function_calling_llm, ) formatted_answer = self._handle_agent_action( - formatted_answer, tool_result + formatted_answer, tool_result, ) self._invoke_step_callback(formatted_answer) @@ -205,7 +208,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin): except Exception as e: if e.__class__.__module__.startswith("litellm"): # Do not retry on litellm errors - raise e + raise if is_context_length_exceeded(e): handle_context_length( respect_context_window=self.respect_context_window, @@ -216,9 +219,8 @@ class CrewAgentExecutor(CrewAgentExecutorMixin): i18n=self._i18n, ) continue - else: - handle_unknown_error(self._printer, e) - raise e + handle_unknown_error(self._printer, e) + raise finally: self.iterations += 1 @@ -231,8 +233,8 @@ class CrewAgentExecutor(CrewAgentExecutorMixin): return formatted_answer def _handle_agent_action( - self, formatted_answer: AgentAction, tool_result: ToolResult - ) -> Union[AgentAction, AgentFinish]: + self, formatted_answer: AgentAction, tool_result: ToolResult, + ) -> AgentAction | AgentFinish: """Handle the AgentAction, execute tools, and process the results.""" # Special case for add_image_tool add_image_tool = self._i18n.tools("add_image") @@ -261,24 +263,26 @@ class CrewAgentExecutor(CrewAgentExecutorMixin): """Append a message to the message list with the given role.""" self.messages.append(format_message_for_llm(text, role=role)) - def _show_start_logs(self): + def _show_start_logs(self) -> None: """Show logs for the start of agent execution.""" if self.agent is None: - raise ValueError("Agent cannot be None") + msg = "Agent cannot be None" + raise ValueError(msg) show_agent_logs( printer=self._printer, agent_role=self.agent.role, task_description=( - getattr(self.task, "description") if self.task else "Not Found" + self.task.description if self.task else "Not Found" ), verbose=self.agent.verbose or (hasattr(self, "crew") and getattr(self.crew, "verbose", False)), ) - def _show_logs(self, formatted_answer: Union[AgentAction, AgentFinish]): + def _show_logs(self, formatted_answer: AgentAction | AgentFinish) -> None: """Show logs for the agent's execution.""" if self.agent is None: - raise ValueError("Agent cannot be None") + msg = "Agent cannot be None" + raise ValueError(msg) show_agent_logs( printer=self._printer, agent_role=self.agent.role, @@ -300,11 +304,11 @@ class CrewAgentExecutor(CrewAgentExecutorMixin): summary = self.llm.call( [ format_message_for_llm( - self._i18n.slice("summarizer_system_message"), role="system" + self._i18n.slice("summarizer_system_message"), role="system", ), format_message_for_llm( self._i18n.slice("summarize_instruction").format( - group=group["content"] + group=group["content"], ), ), ], @@ -316,12 +320,12 @@ class CrewAgentExecutor(CrewAgentExecutorMixin): self.messages = [ format_message_for_llm( - self._i18n.slice("summary").format(merged_summary=merged_summary) - ) + self._i18n.slice("summary").format(merged_summary=merged_summary), + ), ] def _handle_crew_training_output( - self, result: AgentFinish, human_feedback: Optional[str] = None + self, result: AgentFinish, human_feedback: str | None = None, ) -> None: """Handle the process of saving training data.""" agent_id = str(self.agent.id) # type: ignore @@ -348,29 +352,27 @@ class CrewAgentExecutor(CrewAgentExecutorMixin): "initial_output": result.output, "human_feedback": human_feedback, } + # Save improved output + elif train_iteration in agent_training_data: + agent_training_data[train_iteration]["improved_output"] = result.output else: - # Save improved output - if train_iteration in agent_training_data: - agent_training_data[train_iteration]["improved_output"] = result.output - else: - self._printer.print( - content=( - f"No existing training data for agent {agent_id} and iteration " - f"{train_iteration}. Cannot save improved output." - ), - color="red", - ) - return + self._printer.print( + content=( + f"No existing training data for agent {agent_id} and iteration " + f"{train_iteration}. Cannot save improved output." + ), + color="red", + ) + return # Update the training data and save training_data[agent_id] = agent_training_data training_handler.save(training_data) - def _format_prompt(self, prompt: str, inputs: Dict[str, str]) -> str: + def _format_prompt(self, prompt: str, inputs: dict[str, str]) -> str: prompt = prompt.replace("{input}", inputs["input"]) prompt = prompt.replace("{tool_names}", inputs["tool_names"]) - prompt = prompt.replace("{tools}", inputs["tools"]) - return prompt + return prompt.replace("{tools}", inputs["tools"]) def _handle_human_feedback(self, formatted_answer: AgentFinish) -> AgentFinish: """Handle human feedback with different flows for training vs regular use. @@ -380,6 +382,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin): Returns: AgentFinish: The final answer after processing feedback + """ human_feedback = self._ask_human_input(formatted_answer.output) @@ -393,14 +396,14 @@ class CrewAgentExecutor(CrewAgentExecutorMixin): return bool(self.crew and self.crew._train) def _handle_training_feedback( - self, initial_answer: AgentFinish, feedback: str + self, initial_answer: AgentFinish, feedback: str, ) -> AgentFinish: """Process feedback for training scenarios with single iteration.""" self._handle_crew_training_output(initial_answer, feedback) self.messages.append( format_message_for_llm( - self._i18n.slice("feedback_instructions").format(feedback=feedback) - ) + self._i18n.slice("feedback_instructions").format(feedback=feedback), + ), ) improved_answer = self._invoke_loop() self._handle_crew_training_output(improved_answer) @@ -408,7 +411,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin): return improved_answer def _handle_regular_feedback( - self, current_answer: AgentFinish, initial_feedback: str + self, current_answer: AgentFinish, initial_feedback: str, ) -> AgentFinish: """Process feedback for regular use with potential multiple iterations.""" feedback = initial_feedback @@ -428,8 +431,8 @@ class CrewAgentExecutor(CrewAgentExecutorMixin): """Process a single feedback iteration.""" self.messages.append( format_message_for_llm( - self._i18n.slice("feedback_instructions").format(feedback=feedback) - ) + self._i18n.slice("feedback_instructions").format(feedback=feedback), + ), ) return self._invoke_loop() diff --git a/src/crewai/agents/parser.py b/src/crewai/agents/parser.py index 58605b692..bf9b411c7 100644 --- a/src/crewai/agents/parser.py +++ b/src/crewai/agents/parser.py @@ -1,5 +1,5 @@ import re -from typing import Any, Optional, Union +from typing import Any from json_repair import repair_json @@ -18,7 +18,7 @@ class AgentAction: text: str result: str - def __init__(self, thought: str, tool: str, tool_input: str, text: str): + def __init__(self, thought: str, tool: str, tool_input: str, text: str) -> None: self.thought = thought self.tool = tool self.tool_input = tool_input @@ -30,7 +30,7 @@ class AgentFinish: output: str text: str - def __init__(self, thought: str, output: str, text: str): + def __init__(self, thought: str, output: str, text: str) -> None: self.thought = thought self.output = output self.text = text @@ -39,7 +39,7 @@ class AgentFinish: class OutputParserException(Exception): error: str - def __init__(self, error: str): + def __init__(self, error: str) -> None: self.error = error @@ -67,24 +67,24 @@ class CrewAgentParser: _i18n: I18N = I18N() agent: Any = None - def __init__(self, agent: Optional[Any] = None): + def __init__(self, agent: Any | None = None) -> None: self.agent = agent @staticmethod - def parse_text(text: str) -> Union[AgentAction, AgentFinish]: - """ - Static method to parse text into an AgentAction or AgentFinish without needing to instantiate the class. + def parse_text(text: str) -> AgentAction | AgentFinish: + """Static method to parse text into an AgentAction or AgentFinish without needing to instantiate the class. Args: text: The text to parse. Returns: Either an AgentAction or AgentFinish based on the parsed content. + """ parser = CrewAgentParser() return parser.parse(text) - def parse(self, text: str) -> Union[AgentAction, AgentFinish]: + def parse(self, text: str) -> AgentAction | AgentFinish: thought = self._extract_thought(text) includes_answer = FINAL_ANSWER_ACTION in text regex = ( @@ -102,7 +102,7 @@ class CrewAgentParser: final_answer = final_answer[:-3].rstrip() return AgentFinish(thought, final_answer, text) - elif action_match: + if action_match: action = action_match.group(1) clean_action = self._clean_action(action) @@ -114,21 +114,21 @@ class CrewAgentParser: return AgentAction(thought, clean_action, safe_tool_input, text) if not re.search(r"Action\s*\d*\s*:[\s]*(.*?)", text, re.DOTALL): + msg = f"{MISSING_ACTION_AFTER_THOUGHT_ERROR_MESSAGE}\n{self._i18n.slice('final_answer_format')}" raise OutputParserException( - f"{MISSING_ACTION_AFTER_THOUGHT_ERROR_MESSAGE}\n{self._i18n.slice('final_answer_format')}", + msg, ) - elif not re.search( - r"[\s]*Action\s*\d*\s*Input\s*\d*\s*:[\s]*(.*)", text, re.DOTALL + if not re.search( + r"[\s]*Action\s*\d*\s*Input\s*\d*\s*:[\s]*(.*)", text, re.DOTALL, ): raise OutputParserException( MISSING_ACTION_INPUT_AFTER_ACTION_ERROR_MESSAGE, ) - else: - format = self._i18n.slice("format_without_tools") - error = f"{format}" - raise OutputParserException( - error, - ) + format = self._i18n.slice("format_without_tools") + error = f"{format}" + raise OutputParserException( + error, + ) def _extract_thought(self, text: str) -> str: thought_index = text.find("\nAction") @@ -138,8 +138,7 @@ class CrewAgentParser: return "" thought = text[:thought_index].strip() # Remove any triple backticks from the thought string - thought = thought.replace("```", "").strip() - return thought + return thought.replace("```", "").strip() def _clean_action(self, text: str) -> str: """Clean action string by removing non-essential formatting characters.""" diff --git a/src/crewai/agents/tools_handler.py b/src/crewai/agents/tools_handler.py index fd4bec7ee..e44a02abf 100644 --- a/src/crewai/agents/tools_handler.py +++ b/src/crewai/agents/tools_handler.py @@ -1,7 +1,8 @@ -from typing import Any, Optional, Union +from typing import Any + +from crewai.tools.cache_tools.cache_tools import CacheTools +from crewai.tools.tool_calling import InstructorToolCalling, ToolCalling -from ..tools.cache_tools.cache_tools import CacheTools -from ..tools.tool_calling import InstructorToolCalling, ToolCalling from .cache.cache_handler import CacheHandler @@ -9,16 +10,16 @@ class ToolsHandler: """Callback handler for tool usage.""" last_used_tool: ToolCalling = {} # type: ignore # BUG?: Incompatible types in assignment (expression has type "Dict[...]", variable has type "ToolCalling") - cache: Optional[CacheHandler] + cache: CacheHandler | None - def __init__(self, cache: Optional[CacheHandler] = None): + def __init__(self, cache: CacheHandler | None = None) -> None: """Initialize the callback handler.""" self.cache = cache self.last_used_tool = {} # type: ignore # BUG?: same as above def on_tool_use( self, - calling: Union[ToolCalling, InstructorToolCalling], + calling: ToolCalling | InstructorToolCalling, output: str, should_cache: bool = True, ) -> Any: diff --git a/src/crewai/cli/add_crew_to_flow.py b/src/crewai/cli/add_crew_to_flow.py index ef693a22b..125ae3a2d 100644 --- a/src/crewai/cli/add_crew_to_flow.py +++ b/src/crewai/cli/add_crew_to_flow.py @@ -9,9 +9,9 @@ def add_crew_to_flow(crew_name: str) -> None: """Add a new crew to the current flow.""" # Check if pyproject.toml exists in the current directory if not Path("pyproject.toml").exists(): - print("This command must be run from the root of a flow project.") + msg = "This command must be run from the root of a flow project." raise click.ClickException( - "This command must be run from the root of a flow project." + msg, ) # Determine the flow folder based on the current directory @@ -19,8 +19,8 @@ def add_crew_to_flow(crew_name: str) -> None: crews_folder = flow_folder / "src" / flow_folder.name / "crews" if not crews_folder.exists(): - print("Crews folder does not exist in the current flow.") - raise click.ClickException("Crews folder does not exist in the current flow.") + msg = "Crews folder does not exist in the current flow." + raise click.ClickException(msg) # Create the crew within the flow's crews directory create_embedded_crew(crew_name, parent_folder=crews_folder) @@ -39,7 +39,7 @@ def create_embedded_crew(crew_name: str, parent_folder: Path) -> None: if crew_folder.exists(): if not click.confirm( - f"Crew {folder_name} already exists. Do you want to override it?" + f"Crew {folder_name} already exists. Do you want to override it?", ): click.secho("Operation cancelled.", fg="yellow") return @@ -66,5 +66,5 @@ def create_embedded_crew(crew_name: str, parent_folder: Path) -> None: copy_template(src_file, dst_file, crew_name, class_name, folder_name) click.secho( - f"Crew {crew_name} added to the flow successfully!", fg="green", bold=True + f"Crew {crew_name} added to the flow successfully!", fg="green", bold=True, ) diff --git a/src/crewai/cli/authentication/main.py b/src/crewai/cli/authentication/main.py index 5a335e1f2..a79f9e938 100644 --- a/src/crewai/cli/authentication/main.py +++ b/src/crewai/cli/authentication/main.py @@ -1,6 +1,6 @@ import time import webbrowser -from typing import Any, Dict +from typing import Any import requests from rich.console import Console @@ -17,38 +17,37 @@ class AuthenticationCommand: DEVICE_CODE_URL = f"https://{AUTH0_DOMAIN}/oauth/device/code" TOKEN_URL = f"https://{AUTH0_DOMAIN}/oauth/token" - def __init__(self): + def __init__(self) -> None: self.token_manager = TokenManager() def signup(self) -> None: - """Sign up to CrewAI+""" + """Sign up to CrewAI+.""" console.print("Signing Up to CrewAI+ \n", style="bold blue") device_code_data = self._get_device_code() self._display_auth_instructions(device_code_data) return self._poll_for_token(device_code_data) - def _get_device_code(self) -> Dict[str, Any]: + def _get_device_code(self) -> dict[str, Any]: """Get the device code to authenticate the user.""" - device_code_payload = { "client_id": AUTH0_CLIENT_ID, "scope": "openid", "audience": AUTH0_AUDIENCE, } response = requests.post( - url=self.DEVICE_CODE_URL, data=device_code_payload, timeout=20 + url=self.DEVICE_CODE_URL, data=device_code_payload, timeout=20, ) response.raise_for_status() return response.json() - def _display_auth_instructions(self, device_code_data: Dict[str, str]) -> None: + def _display_auth_instructions(self, device_code_data: dict[str, str]) -> None: """Display the authentication instructions to the user.""" console.print("1. Navigate to: ", device_code_data["verification_uri_complete"]) console.print("2. Enter the following code: ", device_code_data["user_code"]) webbrowser.open(device_code_data["verification_uri_complete"]) - def _poll_for_token(self, device_code_data: Dict[str, Any]) -> None: + def _poll_for_token(self, device_code_data: dict[str, Any]) -> None: """Poll the server for the token.""" token_payload = { "grant_type": "urn:ietf:params:oauth:grant-type:device_code", @@ -81,7 +80,7 @@ class AuthenticationCommand: ) console.print( - "\n[bold green]Welcome to CrewAI Enterprise![/bold green]\n" + "\n[bold green]Welcome to CrewAI Enterprise![/bold green]\n", ) return @@ -92,5 +91,5 @@ class AuthenticationCommand: attempts += 1 console.print( - "Timeout: Failed to get the token. Please try again.", style="bold red" + "Timeout: Failed to get the token. Please try again.", style="bold red", ) diff --git a/src/crewai/cli/authentication/token.py b/src/crewai/cli/authentication/token.py index 30a33b4ba..fcd481860 100644 --- a/src/crewai/cli/authentication/token.py +++ b/src/crewai/cli/authentication/token.py @@ -5,5 +5,5 @@ def get_auth_token() -> str: """Get the authentication token.""" access_token = TokenManager().get_token() if not access_token: - raise Exception() + raise Exception return access_token diff --git a/src/crewai/cli/authentication/utils.py b/src/crewai/cli/authentication/utils.py index 2f5fc183f..8a4c997fa 100644 --- a/src/crewai/cli/authentication/utils.py +++ b/src/crewai/cli/authentication/utils.py @@ -3,7 +3,6 @@ import os import sys from datetime import datetime, timedelta from pathlib import Path -from typing import Optional from auth0.authentication.token_verifier import ( AsymmetricSignatureVerifier, @@ -15,8 +14,7 @@ from .constants import AUTH0_CLIENT_ID, AUTH0_DOMAIN def validate_token(id_token: str) -> None: - """ - Verify the token and its precedence + """Verify the token and its precedence. :param id_token: """ @@ -24,15 +22,14 @@ def validate_token(id_token: str) -> None: issuer = f"https://{AUTH0_DOMAIN}/" signature_verifier = AsymmetricSignatureVerifier(jwks_url) token_verifier = TokenVerifier( - signature_verifier=signature_verifier, issuer=issuer, audience=AUTH0_CLIENT_ID + signature_verifier=signature_verifier, issuer=issuer, audience=AUTH0_CLIENT_ID, ) token_verifier.verify(id_token) class TokenManager: def __init__(self, file_path: str = "tokens.enc") -> None: - """ - Initialize the TokenManager class. + """Initialize the TokenManager class. :param file_path: The file path to store the encrypted tokens. Default is "tokens.enc". """ @@ -41,8 +38,7 @@ class TokenManager: self.fernet = Fernet(self.key) def _get_or_create_key(self) -> bytes: - """ - Get or create the encryption key. + """Get or create the encryption key. :return: The encryption key. """ @@ -57,8 +53,7 @@ class TokenManager: return new_key def save_tokens(self, access_token: str, expires_in: int) -> None: - """ - Save the access token and its expiration time. + """Save the access token and its expiration time. :param access_token: The access token to save. :param expires_in: The expiration time of the access token in seconds. @@ -71,9 +66,8 @@ class TokenManager: encrypted_data = self.fernet.encrypt(json.dumps(data).encode()) self.save_secure_file(self.file_path, encrypted_data) - def get_token(self) -> Optional[str]: - """ - Get the access token if it is valid and not expired. + def get_token(self) -> str | None: + """Get the access token if it is valid and not expired. :return: The access token if valid and not expired, otherwise None. """ @@ -89,8 +83,7 @@ class TokenManager: return data["access_token"] def get_secure_storage_path(self) -> Path: - """ - Get the secure storage path based on the operating system. + """Get the secure storage path based on the operating system. :return: The secure storage path. """ @@ -112,8 +105,7 @@ class TokenManager: return storage_path def save_secure_file(self, filename: str, content: bytes) -> None: - """ - Save the content to a secure file. + """Save the content to a secure file. :param filename: The name of the file. :param content: The content to save. @@ -127,9 +119,8 @@ class TokenManager: # Set appropriate permissions (read/write for owner only) os.chmod(file_path, 0o600) - def read_secure_file(self, filename: str) -> Optional[bytes]: - """ - Read the content of a secure file. + def read_secure_file(self, filename: str) -> bytes | None: + """Read the content of a secure file. :param filename: The name of the file. :return: The content of the file if it exists, otherwise None. diff --git a/src/crewai/cli/cli.py b/src/crewai/cli/cli.py index b2d59adbe..d9414bee8 100644 --- a/src/crewai/cli/cli.py +++ b/src/crewai/cli/cli.py @@ -1,6 +1,4 @@ -import os from importlib.metadata import version as get_version -from typing import Optional, Tuple import click @@ -28,7 +26,7 @@ from .update_crew import update_crew @click.group() @click.version_option(get_version("crewai")) -def crewai(): +def crewai() -> None: """Top-level command group for crewai.""" @@ -37,7 +35,7 @@ def crewai(): @click.argument("name") @click.option("--provider", type=str, help="The provider to use for the crew") @click.option("--skip_provider", is_flag=True, help="Skip provider validation") -def create(type, name, provider, skip_provider=False): +def create(type, name, provider, skip_provider=False) -> None: """Create a new crew, or flow.""" if type == "crew": create_crew(name, provider, skip_provider) @@ -49,9 +47,9 @@ def create(type, name, provider, skip_provider=False): @crewai.command() @click.option( - "--tools", is_flag=True, help="Show the installed version of crewai tools" + "--tools", is_flag=True, help="Show the installed version of crewai tools", ) -def version(tools): +def version(tools) -> None: """Show the installed version of crewai.""" try: crewai_version = get_version("crewai") @@ -82,7 +80,7 @@ def version(tools): default="trained_agents_data.pkl", help="Path to a custom file for training", ) -def train(n_iterations: int, filename: str): +def train(n_iterations: int, filename: str) -> None: """Train the crew.""" click.echo(f"Training the Crew for {n_iterations} iterations") train_crew(n_iterations, filename) @@ -96,11 +94,11 @@ def train(n_iterations: int, filename: str): help="Replay the crew from this task ID, including all subsequent tasks.", ) def replay(task_id: str) -> None: - """ - Replay the crew execution from a specific task. + """Replay the crew execution from a specific task. Args: task_id (str): The ID of the task to replay from. + """ try: click.echo(f"Replaying the crew from task {task_id}") @@ -111,16 +109,14 @@ def replay(task_id: str) -> None: @crewai.command() def log_tasks_outputs() -> None: - """ - Retrieve your latest crew.kickoff() task outputs. - """ + """Retrieve your latest crew.kickoff() task outputs.""" try: storage = KickoffTaskOutputsSQLiteStorage() tasks = storage.load() if not tasks: click.echo( - "No task outputs found. Only crew kickoff task outputs are logged." + "No task outputs found. Only crew kickoff task outputs are logged.", ) return @@ -153,13 +149,11 @@ def reset_memories( kickoff_outputs: bool, all: bool, ) -> None: - """ - Reset the crew memories (long, short, entity, latest_crew_kickoff_ouputs). This will delete all the data saved. - """ + """Reset the crew memories (long, short, entity, latest_crew_kickoff_ouputs). This will delete all the data saved.""" try: if not all and not (long or short or entities or knowledge or kickoff_outputs): click.echo( - "Please specify at least one memory type to reset using the appropriate flags." + "Please specify at least one memory type to reset using the appropriate flags.", ) return reset_memories_command(long, short, entities, knowledge, kickoff_outputs, all) @@ -182,71 +176,69 @@ def reset_memories( default="gpt-4o-mini", help="LLM Model to run the tests on the Crew. For now only accepting only OpenAI models.", ) -def test(n_iterations: int, model: str): +def test(n_iterations: int, model: str) -> None: """Test the crew and evaluate the results.""" click.echo(f"Testing the crew for {n_iterations} iterations with model {model}") evaluate_crew(n_iterations, model) @crewai.command( - context_settings=dict( - ignore_unknown_options=True, - allow_extra_args=True, - ) + context_settings={ + "ignore_unknown_options": True, + "allow_extra_args": True, + }, ) @click.pass_context -def install(context): +def install(context) -> None: """Install the Crew.""" install_crew(context.args) @crewai.command() -def run(): +def run() -> None: """Run the Crew.""" run_crew() @crewai.command() -def update(): +def update() -> None: """Update the pyproject.toml of the Crew project to use uv.""" update_crew() @crewai.command() -def signup(): +def signup() -> None: """Sign Up/Login to CrewAI+.""" AuthenticationCommand().signup() @crewai.command() -def login(): +def login() -> None: """Sign Up/Login to CrewAI+.""" AuthenticationCommand().signup() # DEPLOY CREWAI+ COMMANDS @crewai.group() -def deploy(): +def deploy() -> None: """Deploy the Crew CLI group.""" - pass @crewai.group() -def tool(): +def tool() -> None: """Tool Repository related commands.""" - pass @deploy.command(name="create") @click.option("-y", "--yes", is_flag=True, help="Skip the confirmation prompt") -def deploy_create(yes: bool): +def deploy_create(yes: bool) -> None: """Create a Crew deployment.""" deploy_cmd = DeployCommand() deploy_cmd.create_crew(yes) @deploy.command(name="list") -def deploy_list(): +def deploy_list() -> None: """List all deployments.""" deploy_cmd = DeployCommand() deploy_cmd.list_crews() @@ -254,7 +246,7 @@ def deploy_list(): @deploy.command(name="push") @click.option("-u", "--uuid", type=str, help="Crew UUID parameter") -def deploy_push(uuid: Optional[str]): +def deploy_push(uuid: str | None) -> None: """Deploy the Crew.""" deploy_cmd = DeployCommand() deploy_cmd.deploy(uuid=uuid) @@ -262,7 +254,7 @@ def deploy_push(uuid: Optional[str]): @deploy.command(name="status") @click.option("-u", "--uuid", type=str, help="Crew UUID parameter") -def deply_status(uuid: Optional[str]): +def deply_status(uuid: str | None) -> None: """Get the status of a deployment.""" deploy_cmd = DeployCommand() deploy_cmd.get_crew_status(uuid=uuid) @@ -270,7 +262,7 @@ def deply_status(uuid: Optional[str]): @deploy.command(name="logs") @click.option("-u", "--uuid", type=str, help="Crew UUID parameter") -def deploy_logs(uuid: Optional[str]): +def deploy_logs(uuid: str | None) -> None: """Get the logs of a deployment.""" deploy_cmd = DeployCommand() deploy_cmd.get_crew_logs(uuid=uuid) @@ -278,7 +270,7 @@ def deploy_logs(uuid: Optional[str]): @deploy.command(name="remove") @click.option("-u", "--uuid", type=str, help="Crew UUID parameter") -def deploy_remove(uuid: Optional[str]): +def deploy_remove(uuid: str | None) -> None: """Remove a deployment.""" deploy_cmd = DeployCommand() deploy_cmd.remove_crew(uuid=uuid) @@ -286,14 +278,14 @@ def deploy_remove(uuid: Optional[str]): @tool.command(name="create") @click.argument("handle") -def tool_create(handle: str): +def tool_create(handle: str) -> None: tool_cmd = ToolCommand() tool_cmd.create(handle) @tool.command(name="install") @click.argument("handle") -def tool_install(handle: str): +def tool_install(handle: str) -> None: tool_cmd = ToolCommand() tool_cmd.login() tool_cmd.install(handle) @@ -309,27 +301,26 @@ def tool_install(handle: str): ) @click.option("--public", "is_public", flag_value=True, default=False) @click.option("--private", "is_public", flag_value=False) -def tool_publish(is_public: bool, force: bool): +def tool_publish(is_public: bool, force: bool) -> None: tool_cmd = ToolCommand() tool_cmd.login() tool_cmd.publish(is_public, force) @crewai.group() -def flow(): +def flow() -> None: """Flow related commands.""" - pass @flow.command(name="kickoff") -def flow_run(): +def flow_run() -> None: """Kickoff the Flow.""" click.echo("Running the Flow") kickoff_flow() @flow.command(name="plot") -def flow_plot(): +def flow_plot() -> None: """Plot the Flow.""" click.echo("Plotting the Flow") plot_flow() @@ -337,20 +328,19 @@ def flow_plot(): @flow.command(name="add-crew") @click.argument("crew_name") -def flow_add_crew(crew_name): +def flow_add_crew(crew_name) -> None: """Add a crew to an existing flow.""" click.echo(f"Adding crew {crew_name} to the flow") add_crew_to_flow(crew_name) @crewai.command() -def chat(): - """ - Start a conversation with the Crew, collecting user-supplied inputs, +def chat() -> None: + """Start a conversation with the Crew, collecting user-supplied inputs, and using the Chat LLM to generate responses. """ click.secho( - "\nStarting a conversation with the Crew\n" "Type 'exit' or Ctrl+C to quit.\n", + "\nStarting a conversation with the Crew\nType 'exit' or Ctrl+C to quit.\n", ) run_chat() diff --git a/src/crewai/cli/command.py b/src/crewai/cli/command.py index 2bef8985d..c80a97fa9 100644 --- a/src/crewai/cli/command.py +++ b/src/crewai/cli/command.py @@ -10,13 +10,13 @@ console = Console() class BaseCommand: - def __init__(self): + def __init__(self) -> None: self._telemetry = Telemetry() self._telemetry.set_tracer() class PlusAPIMixin: - def __init__(self, telemetry): + def __init__(self, telemetry) -> None: try: telemetry.set_tracer() self.plus_api_client = PlusAPI(api_key=get_auth_token()) @@ -30,11 +30,11 @@ class PlusAPIMixin: raise SystemExit def _validate_response(self, response: requests.Response) -> None: - """ - Handle and display error messages from API responses. + """Handle and display error messages from API responses. Args: response (requests.Response): The response from the Plus API + """ try: json_response = response.json() @@ -55,13 +55,13 @@ class PlusAPIMixin: for field, messages in json_response.items(): for message in messages: console.print( - f"* [bold red]{field.capitalize()}[/bold red] {message}" + f"* [bold red]{field.capitalize()}[/bold red] {message}", ) raise SystemExit if not response.ok: console.print( - "Request to Enterprise API failed. Details:", style="bold red" + "Request to Enterprise API failed. Details:", style="bold red", ) details = ( json_response.get("error") diff --git a/src/crewai/cli/config.py b/src/crewai/cli/config.py index 8e30767ca..499f1a5b6 100644 --- a/src/crewai/cli/config.py +++ b/src/crewai/cli/config.py @@ -1,6 +1,5 @@ import json from pathlib import Path -from typing import Optional from pydantic import BaseModel, Field @@ -8,16 +7,16 @@ DEFAULT_CONFIG_PATH = Path.home() / ".config" / "crewai" / "settings.json" class Settings(BaseModel): - tool_repository_username: Optional[str] = Field( - None, description="Username for interacting with the Tool Repository" + tool_repository_username: str | None = Field( + None, description="Username for interacting with the Tool Repository", ) - tool_repository_password: Optional[str] = Field( - None, description="Password for interacting with the Tool Repository" + tool_repository_password: str | None = Field( + None, description="Password for interacting with the Tool Repository", ) config_path: Path = Field(default=DEFAULT_CONFIG_PATH, exclude=True) - def __init__(self, config_path: Path = DEFAULT_CONFIG_PATH, **data): - """Load Settings from config path""" + def __init__(self, config_path: Path = DEFAULT_CONFIG_PATH, **data) -> None: + """Load Settings from config path.""" config_path.parent.mkdir(parents=True, exist_ok=True) file_data = {} @@ -32,7 +31,7 @@ class Settings(BaseModel): super().__init__(config_path=config_path, **merged_data) def dump(self) -> None: - """Save current settings to settings.json""" + """Save current settings to settings.json.""" if self.config_path.is_file(): with self.config_path.open("r") as f: existing_data = json.load(f) diff --git a/src/crewai/cli/constants.py b/src/crewai/cli/constants.py index 306f1108b..c729ba2df 100644 --- a/src/crewai/cli/constants.py +++ b/src/crewai/cli/constants.py @@ -3,31 +3,31 @@ ENV_VARS = { { "prompt": "Enter your OPENAI API key (press Enter to skip)", "key_name": "OPENAI_API_KEY", - } + }, ], "anthropic": [ { "prompt": "Enter your ANTHROPIC API key (press Enter to skip)", "key_name": "ANTHROPIC_API_KEY", - } + }, ], "gemini": [ { "prompt": "Enter your GEMINI API key from https://ai.dev/apikey (press Enter to skip)", "key_name": "GEMINI_API_KEY", - } + }, ], "nvidia_nim": [ { "prompt": "Enter your NVIDIA API key (press Enter to skip)", "key_name": "NVIDIA_NIM_API_KEY", - } + }, ], "groq": [ { "prompt": "Enter your GROQ API key (press Enter to skip)", "key_name": "GROQ_API_KEY", - } + }, ], "watson": [ { @@ -47,7 +47,7 @@ ENV_VARS = { { "default": True, "API_BASE": "http://localhost:11434", - } + }, ], "bedrock": [ { @@ -101,7 +101,7 @@ ENV_VARS = { { "prompt": "Enter your SambaNovaCloud API key (press Enter to skip)", "key_name": "SAMBANOVA_API_KEY", - } + }, ], } diff --git a/src/crewai/cli/create_crew.py b/src/crewai/cli/create_crew.py index c658b0de1..e5bf76b2b 100644 --- a/src/crewai/cli/create_crew.py +++ b/src/crewai/cli/create_crew.py @@ -24,7 +24,7 @@ def create_folder_structure(name, parent_folder=None): if folder_path.exists(): if not click.confirm( - f"Folder {folder_name} already exists. Do you want to override it?" + f"Folder {folder_name} already exists. Do you want to override it?", ): click.secho("Operation cancelled.", fg="yellow") sys.exit(0) @@ -48,7 +48,7 @@ def create_folder_structure(name, parent_folder=None): return folder_path, folder_name, class_name -def copy_template_files(folder_path, name, class_name, parent_folder): +def copy_template_files(folder_path, name, class_name, parent_folder) -> None: package_dir = Path(__file__).parent templates_dir = package_dir / "templates" / "crew" @@ -89,7 +89,7 @@ def copy_template_files(folder_path, name, class_name, parent_folder): copy_template(src_file, dst_file, name, class_name, folder_path.name) -def create_crew(name, provider=None, skip_provider=False, parent_folder=None): +def create_crew(name, provider=None, skip_provider=False, parent_folder=None) -> None: folder_path, folder_name, class_name = create_folder_structure(name, parent_folder) env_vars = load_env_vars(folder_path) if not skip_provider: @@ -109,7 +109,7 @@ def create_crew(name, provider=None, skip_provider=False, parent_folder=None): if existing_provider: if not click.confirm( - f"Found existing environment variable configuration for {existing_provider.capitalize()}. Do you want to override it?" + f"Found existing environment variable configuration for {existing_provider.capitalize()}. Do you want to override it?", ): click.secho("Keeping existing provider configuration.", fg="yellow") return @@ -126,11 +126,11 @@ def create_crew(name, provider=None, skip_provider=False, parent_folder=None): if selected_provider: # Valid selection break click.secho( - "No provider selected. Please try again or press 'q' to exit.", fg="red" + "No provider selected. Please try again or press 'q' to exit.", fg="red", ) # Check if the selected provider has predefined models - if selected_provider in MODELS and MODELS[selected_provider]: + if MODELS.get(selected_provider): while True: selected_model = select_model(selected_provider, provider_models) if selected_model is None: # User typed 'q' @@ -167,7 +167,7 @@ def create_crew(name, provider=None, skip_provider=False, parent_folder=None): click.secho("API keys and model saved to .env file", fg="green") else: click.secho( - "No API keys provided. Skipping .env file creation.", fg="yellow" + "No API keys provided. Skipping .env file creation.", fg="yellow", ) click.secho(f"Selected model: {env_vars.get('MODEL', 'N/A')}", fg="green") diff --git a/src/crewai/cli/create_flow.py b/src/crewai/cli/create_flow.py index ec68611b5..2de5a882c 100644 --- a/src/crewai/cli/create_flow.py +++ b/src/crewai/cli/create_flow.py @@ -5,7 +5,7 @@ import click from crewai.telemetry import Telemetry -def create_flow(name): +def create_flow(name) -> None: """Create a new flow.""" folder_name = name.replace(" ", "_").replace("-", "_").lower() class_name = name.replace("_", " ").replace("-", " ").title().replace(" ", "") @@ -43,12 +43,12 @@ def create_flow(name): "poem_crew", ] - def process_file(src_file, dst_file): + def process_file(src_file, dst_file) -> None: if src_file.suffix in [".pyc", ".pyo", ".pyd"]: return try: - with open(src_file, "r", encoding="utf-8") as file: + with open(src_file, encoding="utf-8") as file: content = file.read() except Exception as e: click.secho(f"Error processing file {src_file}: {e}", fg="red") diff --git a/src/crewai/cli/crew_chat.py b/src/crewai/cli/crew_chat.py index 1b4e18c78..c76d063a8 100644 --- a/src/crewai/cli/crew_chat.py +++ b/src/crewai/cli/crew_chat.py @@ -5,7 +5,7 @@ import sys import threading import time from pathlib import Path -from typing import Any, Dict, List, Optional, Set, Tuple +from typing import Any import click import tomli @@ -22,10 +22,9 @@ MIN_REQUIRED_VERSION = "0.98.0" def check_conversational_crews_version( - crewai_version: str, pyproject_data: dict + crewai_version: str, pyproject_data: dict, ) -> bool: - """ - Check if the installed crewAI version supports conversational crews. + """Check if the installed crewAI version supports conversational crews. Args: crewai_version: The current version of crewAI. @@ -33,6 +32,7 @@ def check_conversational_crews_version( Returns: bool: True if version check passes, False otherwise. + """ try: if version.parse(crewai_version) < version.parse(MIN_REQUIRED_VERSION): @@ -48,9 +48,8 @@ def check_conversational_crews_version( return True -def run_chat(): - """ - Runs an interactive chat loop using the Crew's chat LLM with function calling. +def run_chat() -> None: + """Runs an interactive chat loop using the Crew's chat LLM with function calling. Incorporates crew_name, crew_description, and input fields to build a tool schema. Exits if crew_name or crew_description are missing. """ @@ -84,7 +83,7 @@ def run_chat(): # Call the LLM to generate the introductory message introductory_message = chat_llm.call( - messages=[{"role": "system", "content": system_message}] + messages=[{"role": "system", "content": system_message}], ) finally: # Stop loading indicator @@ -108,15 +107,13 @@ def run_chat(): chat_loop(chat_llm, messages, crew_tool_schema, available_functions) -def show_loading(event: threading.Event): +def show_loading(event: threading.Event) -> None: """Display animated loading dots while processing.""" while not event.is_set(): - print(".", end="", flush=True) time.sleep(1) - print() -def initialize_chat_llm(crew: Crew) -> Optional[LLM | BaseLLM]: +def initialize_chat_llm(crew: Crew) -> LLM | BaseLLM | None: """Initializes the chat LLM and handles exceptions.""" try: return create_llm(crew.chat_llm) @@ -157,7 +154,7 @@ def build_system_message(crew_chat_inputs: ChatInputs) -> str: ) -def create_tool_function(crew: Crew, messages: List[Dict[str, str]]) -> Any: +def create_tool_function(crew: Crew, messages: list[dict[str, str]]) -> Any: """Creates a wrapper function for running the crew tool with messages.""" def run_crew_tool_with_messages(**kwargs): @@ -166,7 +163,7 @@ def create_tool_function(crew: Crew, messages: List[Dict[str, str]]) -> Any: return run_crew_tool_with_messages -def flush_input(): +def flush_input() -> None: """Flush any pending input from the user.""" if platform.system() == "Windows": # Windows platform @@ -181,7 +178,7 @@ def flush_input(): termios.tcflush(sys.stdin, termios.TCIFLUSH) -def chat_loop(chat_llm, messages, crew_tool_schema, available_functions): +def chat_loop(chat_llm, messages, crew_tool_schema, available_functions) -> None: """Main chat loop for interacting with the user.""" while True: try: @@ -190,7 +187,7 @@ def chat_loop(chat_llm, messages, crew_tool_schema, available_functions): user_input = get_user_input() handle_user_input( - user_input, chat_llm, messages, crew_tool_schema, available_functions + user_input, chat_llm, messages, crew_tool_schema, available_functions, ) except KeyboardInterrupt: @@ -221,9 +218,9 @@ def get_user_input() -> str: def handle_user_input( user_input: str, chat_llm: LLM, - messages: List[Dict[str, str]], - crew_tool_schema: Dict[str, Any], - available_functions: Dict[str, Any], + messages: list[dict[str, str]], + crew_tool_schema: dict[str, Any], + available_functions: dict[str, Any], ) -> None: if user_input.strip().lower() == "exit": click.echo("Exiting chat. Goodbye!") @@ -251,8 +248,7 @@ def handle_user_input( def generate_crew_tool_schema(crew_inputs: ChatInputs) -> dict: - """ - Dynamically build a Littellm 'function' schema for the given crew. + """Dynamically build a Littellm 'function' schema for the given crew. crew_name: The name of the crew (used for the function 'name'). crew_inputs: A ChatInputs object containing crew_description @@ -281,9 +277,8 @@ def generate_crew_tool_schema(crew_inputs: ChatInputs) -> dict: } -def run_crew_tool(crew: Crew, messages: List[Dict[str, str]], **kwargs): - """ - Runs the crew using crew.kickoff(inputs=kwargs) and returns the output. +def run_crew_tool(crew: Crew, messages: list[dict[str, str]], **kwargs): + """Runs the crew using crew.kickoff(inputs=kwargs) and returns the output. Args: crew (Crew): The crew instance to run. @@ -295,6 +290,7 @@ def run_crew_tool(crew: Crew, messages: List[Dict[str, str]], **kwargs): Raises: SystemExit: Exits the chat if an error occurs during crew execution. + """ try: # Serialize 'messages' to JSON string before adding to kwargs @@ -304,9 +300,8 @@ def run_crew_tool(crew: Crew, messages: List[Dict[str, str]], **kwargs): crew_output = crew.kickoff(inputs=kwargs) # Convert CrewOutput to a string to send back to the user - result = str(crew_output) + return str(crew_output) - return result except Exception as e: # Exit the chat and show the error message click.secho("An error occurred while running the crew:", fg="red") @@ -314,12 +309,12 @@ def run_crew_tool(crew: Crew, messages: List[Dict[str, str]], **kwargs): sys.exit(1) -def load_crew_and_name() -> Tuple[Crew, str]: - """ - Loads the crew by importing the crew class from the user's project. +def load_crew_and_name() -> tuple[Crew, str]: + """Loads the crew by importing the crew class from the user's project. Returns: Tuple[Crew, str]: A tuple containing the Crew instance and the name of the crew. + """ # Get the current working directory cwd = Path.cwd() @@ -327,7 +322,8 @@ def load_crew_and_name() -> Tuple[Crew, str]: # Path to the pyproject.toml file pyproject_path = cwd / "pyproject.toml" if not pyproject_path.exists(): - raise FileNotFoundError("pyproject.toml not found in the current directory.") + msg = "pyproject.toml not found in the current directory." + raise FileNotFoundError(msg) # Load the pyproject.toml file using 'tomli' with pyproject_path.open("rb") as f: @@ -351,14 +347,16 @@ def load_crew_and_name() -> Tuple[Crew, str]: try: crew_module = __import__(crew_module_name, fromlist=[crew_class_name]) except ImportError as e: - raise ImportError(f"Failed to import crew module {crew_module_name}: {e}") + msg = f"Failed to import crew module {crew_module_name}: {e}" + raise ImportError(msg) # Get the crew class from the module try: crew_class = getattr(crew_module, crew_class_name) except AttributeError: + msg = f"Crew class {crew_class_name} not found in module {crew_module_name}" raise AttributeError( - f"Crew class {crew_class_name} not found in module {crew_module_name}" + msg, ) # Instantiate the crew @@ -367,8 +365,7 @@ def load_crew_and_name() -> Tuple[Crew, str]: def generate_crew_chat_inputs(crew: Crew, crew_name: str, chat_llm) -> ChatInputs: - """ - Generates the ChatInputs required for the crew by analyzing the tasks and agents. + """Generates the ChatInputs required for the crew by analyzing the tasks and agents. Args: crew (Crew): The crew object containing tasks and agents. @@ -377,6 +374,7 @@ def generate_crew_chat_inputs(crew: Crew, crew_name: str, chat_llm) -> ChatInput Returns: ChatInputs: An object containing the crew's name, description, and input fields. + """ # Extract placeholders from tasks and agents required_inputs = fetch_required_inputs(crew) @@ -391,22 +389,22 @@ def generate_crew_chat_inputs(crew: Crew, crew_name: str, chat_llm) -> ChatInput crew_description = generate_crew_description_with_ai(crew, chat_llm) return ChatInputs( - crew_name=crew_name, crew_description=crew_description, inputs=input_fields + crew_name=crew_name, crew_description=crew_description, inputs=input_fields, ) -def fetch_required_inputs(crew: Crew) -> Set[str]: - """ - Extracts placeholders from the crew's tasks and agents. +def fetch_required_inputs(crew: Crew) -> set[str]: + """Extracts placeholders from the crew's tasks and agents. Args: crew (Crew): The crew object. Returns: Set[str]: A set of placeholder names. + """ placeholder_pattern = re.compile(r"\{(.+?)\}") - required_inputs: Set[str] = set() + required_inputs: set[str] = set() # Scan tasks for task in crew.tasks: @@ -422,8 +420,7 @@ def fetch_required_inputs(crew: Crew) -> Set[str]: def generate_input_description_with_ai(input_name: str, crew: Crew, chat_llm) -> str: - """ - Generates an input description using AI based on the context of the crew. + """Generates an input description using AI based on the context of the crew. Args: input_name (str): The name of the input placeholder. @@ -432,6 +429,7 @@ def generate_input_description_with_ai(input_name: str, crew: Crew, chat_llm) -> Returns: str: A concise description of the input. + """ # Gather context from tasks and agents where the input is used context_texts = [] @@ -444,10 +442,10 @@ def generate_input_description_with_ai(input_name: str, crew: Crew, chat_llm) -> ): # Replace placeholders with input names task_description = placeholder_pattern.sub( - lambda m: m.group(1), task.description or "" + lambda m: m.group(1), task.description or "", ) expected_output = placeholder_pattern.sub( - lambda m: m.group(1), task.expected_output or "" + lambda m: m.group(1), task.expected_output or "", ) context_texts.append(f"Task Description: {task_description}") context_texts.append(f"Expected Output: {expected_output}") @@ -461,7 +459,7 @@ def generate_input_description_with_ai(input_name: str, crew: Crew, chat_llm) -> agent_role = placeholder_pattern.sub(lambda m: m.group(1), agent.role or "") agent_goal = placeholder_pattern.sub(lambda m: m.group(1), agent.goal or "") agent_backstory = placeholder_pattern.sub( - lambda m: m.group(1), agent.backstory or "" + lambda m: m.group(1), agent.backstory or "", ) context_texts.append(f"Agent Role: {agent_role}") context_texts.append(f"Agent Goal: {agent_goal}") @@ -470,7 +468,8 @@ def generate_input_description_with_ai(input_name: str, crew: Crew, chat_llm) -> context = "\n".join(context_texts) if not context: # If no context is found for the input, raise an exception as per instruction - raise ValueError(f"No context found for input '{input_name}'.") + msg = f"No context found for input '{input_name}'." + raise ValueError(msg) prompt = ( f"Based on the following context, write a concise description (15 words or less) of the input '{input_name}'.\n" @@ -479,14 +478,12 @@ def generate_input_description_with_ai(input_name: str, crew: Crew, chat_llm) -> f"{context}" ) response = chat_llm.call(messages=[{"role": "user", "content": prompt}]) - description = response.strip() + return response.strip() - return description def generate_crew_description_with_ai(crew: Crew, chat_llm) -> str: - """ - Generates a brief description of the crew using AI. + """Generates a brief description of the crew using AI. Args: crew (Crew): The crew object. @@ -494,6 +491,7 @@ def generate_crew_description_with_ai(crew: Crew, chat_llm) -> str: Returns: str: A concise description of the crew's purpose (15 words or less). + """ # Gather context from tasks and agents context_texts = [] @@ -502,10 +500,10 @@ def generate_crew_description_with_ai(crew: Crew, chat_llm) -> str: for task in crew.tasks: # Replace placeholders with input names task_description = placeholder_pattern.sub( - lambda m: m.group(1), task.description or "" + lambda m: m.group(1), task.description or "", ) expected_output = placeholder_pattern.sub( - lambda m: m.group(1), task.expected_output or "" + lambda m: m.group(1), task.expected_output or "", ) context_texts.append(f"Task Description: {task_description}") context_texts.append(f"Expected Output: {expected_output}") @@ -514,7 +512,7 @@ def generate_crew_description_with_ai(crew: Crew, chat_llm) -> str: agent_role = placeholder_pattern.sub(lambda m: m.group(1), agent.role or "") agent_goal = placeholder_pattern.sub(lambda m: m.group(1), agent.goal or "") agent_backstory = placeholder_pattern.sub( - lambda m: m.group(1), agent.backstory or "" + lambda m: m.group(1), agent.backstory or "", ) context_texts.append(f"Agent Role: {agent_role}") context_texts.append(f"Agent Goal: {agent_goal}") @@ -522,7 +520,8 @@ def generate_crew_description_with_ai(crew: Crew, chat_llm) -> str: context = "\n".join(context_texts) if not context: - raise ValueError("No context found for generating crew description.") + msg = "No context found for generating crew description." + raise ValueError(msg) prompt = ( "Based on the following context, write a concise, action-oriented description (15 words or less) of the crew's purpose.\n" @@ -531,6 +530,5 @@ def generate_crew_description_with_ai(crew: Crew, chat_llm) -> str: f"{context}" ) response = chat_llm.call(messages=[{"role": "user", "content": prompt}]) - crew_description = response.strip() + return response.strip() - return crew_description diff --git a/src/crewai/cli/deploy/main.py b/src/crewai/cli/deploy/main.py index 486959201..726378bc8 100644 --- a/src/crewai/cli/deploy/main.py +++ b/src/crewai/cli/deploy/main.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List, Optional +from typing import Any from rich.console import Console @@ -10,34 +10,27 @@ console = Console() class DeployCommand(BaseCommand, PlusAPIMixin): - """ - A class to handle deployment-related operations for CrewAI projects. - """ - - def __init__(self): - """ - Initialize the DeployCommand with project name and API client. - """ + """A class to handle deployment-related operations for CrewAI projects.""" + def __init__(self) -> None: + """Initialize the DeployCommand with project name and API client.""" BaseCommand.__init__(self) PlusAPIMixin.__init__(self, telemetry=self._telemetry) self.project_name = get_project_name(require=True) def _standard_no_param_error_message(self) -> None: - """ - Display a standard error message when no UUID or project name is available. - """ + """Display a standard error message when no UUID or project name is available.""" console.print( "No UUID provided, project pyproject.toml not found or with error.", style="bold red", ) - def _display_deployment_info(self, json_response: Dict[str, Any]) -> None: - """ - Display deployment information. + def _display_deployment_info(self, json_response: dict[str, Any]) -> None: + """Display deployment information. Args: json_response (Dict[str, Any]): The deployment information to display. + """ console.print("Deploying the crew...\n", style="bold blue") for key, value in json_response.items(): @@ -47,24 +40,24 @@ class DeployCommand(BaseCommand, PlusAPIMixin): console.print(" or") console.print(f"crewai deploy status --uuid \"{json_response['uuid']}\"") - def _display_logs(self, log_messages: List[Dict[str, Any]]) -> None: - """ - Display log messages. + def _display_logs(self, log_messages: list[dict[str, Any]]) -> None: + """Display log messages. Args: log_messages (List[Dict[str, Any]]): The log messages to display. + """ for log_message in log_messages: console.print( - f"{log_message['timestamp']} - {log_message['level']}: {log_message['message']}" + f"{log_message['timestamp']} - {log_message['level']}: {log_message['message']}", ) - def deploy(self, uuid: Optional[str] = None) -> None: - """ - Deploy a crew using either UUID or project name. + def deploy(self, uuid: str | None = None) -> None: + """Deploy a crew using either UUID or project name. Args: uuid (Optional[str]): The UUID of the crew to deploy. + """ self._start_deployment_span = self._telemetry.start_deployment_span(uuid) console.print("Starting deployment...", style="bold blue") @@ -80,9 +73,7 @@ class DeployCommand(BaseCommand, PlusAPIMixin): self._display_deployment_info(response.json()) def create_crew(self, confirm: bool = False) -> None: - """ - Create a new crew deployment. - """ + """Create a new crew deployment.""" self._create_crew_deployment_span = ( self._telemetry.create_crew_deployment_span() ) @@ -110,29 +101,28 @@ class DeployCommand(BaseCommand, PlusAPIMixin): self._display_creation_success(response.json()) def _confirm_input( - self, env_vars: Dict[str, str], remote_repo_url: str, confirm: bool + self, env_vars: dict[str, str], remote_repo_url: str, confirm: bool, ) -> None: - """ - Confirm input parameters with the user. + """Confirm input parameters with the user. Args: env_vars (Dict[str, str]): Environment variables. remote_repo_url (str): Remote repository URL. confirm (bool): Whether to confirm input. + """ if not confirm: input(f"Press Enter to continue with the following Env vars: {env_vars}") input( - f"Press Enter to continue with the following remote repository: {remote_repo_url}\n" + f"Press Enter to continue with the following remote repository: {remote_repo_url}\n", ) def _create_payload( self, - env_vars: Dict[str, str], + env_vars: dict[str, str], remote_repo_url: str, - ) -> Dict[str, Any]: - """ - Create the payload for crew creation. + ) -> dict[str, Any]: + """Create the payload for crew creation. Args: remote_repo_url (str): Remote repository URL. @@ -140,25 +130,26 @@ class DeployCommand(BaseCommand, PlusAPIMixin): Returns: Dict[str, Any]: The payload for crew creation. + """ return { "deploy": { "name": self.project_name, "repo_clone_url": remote_repo_url, "env": env_vars, - } + }, } - def _display_creation_success(self, json_response: Dict[str, Any]) -> None: - """ - Display success message after crew creation. + def _display_creation_success(self, json_response: dict[str, Any]) -> None: + """Display success message after crew creation. Args: json_response (Dict[str, Any]): The response containing crew information. + """ console.print("Deployment created successfully!\n", style="bold green") console.print( - f"Name: {self.project_name} ({json_response['uuid']})", style="bold green" + f"Name: {self.project_name} ({json_response['uuid']})", style="bold green", ) console.print(f"Status: {json_response['status']}", style="bold green") console.print("\nTo (re)deploy the crew, run:") @@ -167,9 +158,7 @@ class DeployCommand(BaseCommand, PlusAPIMixin): console.print(f"crewai deploy push --uuid {json_response['uuid']}") def list_crews(self) -> None: - """ - List all available crews. - """ + """List all available crews.""" console.print("Listing all Crews\n", style="bold blue") response = self.plus_api_client.list_crews() @@ -179,31 +168,29 @@ class DeployCommand(BaseCommand, PlusAPIMixin): else: self._display_no_crews_message() - def _display_crews(self, crews_data: List[Dict[str, Any]]) -> None: - """ - Display the list of crews. + def _display_crews(self, crews_data: list[dict[str, Any]]) -> None: + """Display the list of crews. Args: crews_data (List[Dict[str, Any]]): List of crew data to display. + """ for crew_data in crews_data: console.print( - f"- {crew_data['name']} ({crew_data['uuid']}) [blue]{crew_data['status']}[/blue]" + f"- {crew_data['name']} ({crew_data['uuid']}) [blue]{crew_data['status']}[/blue]", ) def _display_no_crews_message(self) -> None: - """ - Display a message when no crews are available. - """ + """Display a message when no crews are available.""" console.print("You don't have any Crews yet. Let's create one!", style="yellow") console.print(" crewai create crew ", style="green") - def get_crew_status(self, uuid: Optional[str] = None) -> None: - """ - Get the status of a crew. + def get_crew_status(self, uuid: str | None = None) -> None: + """Get the status of a crew. Args: uuid (Optional[str]): The UUID of the crew to check. + """ console.print("Fetching deployment status...", style="bold blue") if uuid: @@ -217,23 +204,23 @@ class DeployCommand(BaseCommand, PlusAPIMixin): self._validate_response(response) self._display_crew_status(response.json()) - def _display_crew_status(self, status_data: Dict[str, str]) -> None: - """ - Display the status of a crew. + def _display_crew_status(self, status_data: dict[str, str]) -> None: + """Display the status of a crew. Args: status_data (Dict[str, str]): The status data to display. + """ console.print(f"Name:\t {status_data['name']}") console.print(f"Status:\t {status_data['status']}") - def get_crew_logs(self, uuid: Optional[str], log_type: str = "deployment") -> None: - """ - Get logs for a crew. + def get_crew_logs(self, uuid: str | None, log_type: str = "deployment") -> None: + """Get logs for a crew. Args: uuid (Optional[str]): The UUID of the crew to get logs for. log_type (str): The type of logs to retrieve (default: "deployment"). + """ self._get_crew_logs_span = self._telemetry.get_crew_logs_span(uuid, log_type) console.print(f"Fetching {log_type} logs...", style="bold blue") @@ -249,12 +236,12 @@ class DeployCommand(BaseCommand, PlusAPIMixin): self._validate_response(response) self._display_logs(response.json()) - def remove_crew(self, uuid: Optional[str]) -> None: - """ - Remove a crew deployment. + def remove_crew(self, uuid: str | None) -> None: + """Remove a crew deployment. Args: uuid (Optional[str]): The UUID of the crew to remove. + """ self._remove_crew_span = self._telemetry.remove_crew_span(uuid) console.print("Removing deployment...", style="bold blue") @@ -269,9 +256,9 @@ class DeployCommand(BaseCommand, PlusAPIMixin): if response.status_code == 204: console.print( - f"Crew '{self.project_name}' removed successfully.", style="green" + f"Crew '{self.project_name}' removed successfully.", style="green", ) else: console.print( - f"Failed to remove crew '{self.project_name}'", style="bold red" + f"Failed to remove crew '{self.project_name}'", style="bold red", ) diff --git a/src/crewai/cli/evaluate_crew.py b/src/crewai/cli/evaluate_crew.py index 374f9f27d..b9cbe6b8a 100644 --- a/src/crewai/cli/evaluate_crew.py +++ b/src/crewai/cli/evaluate_crew.py @@ -4,18 +4,19 @@ import click def evaluate_crew(n_iterations: int, model: str) -> None: - """ - Test and Evaluate the crew by running a command in the UV environment. + """Test and Evaluate the crew by running a command in the UV environment. Args: n_iterations (int): The number of iterations to test the crew. model (str): The model to test the crew with. + """ command = ["uv", "run", "test", str(n_iterations), model] try: if n_iterations <= 0: - raise ValueError("The number of iterations must be a positive integer.") + msg = "The number of iterations must be a positive integer." + raise ValueError(msg) result = subprocess.run(command, capture_output=False, text=True, check=True) diff --git a/src/crewai/cli/git.py b/src/crewai/cli/git.py index 58836e733..7f29fd6ab 100644 --- a/src/crewai/cli/git.py +++ b/src/crewai/cli/git.py @@ -1,16 +1,18 @@ import subprocess -from functools import lru_cache +from functools import cache class Repository: - def __init__(self, path="."): + def __init__(self, path=".") -> None: self.path = path if not self.is_git_installed(): - raise ValueError("Git is not installed or not found in your PATH.") + msg = "Git is not installed or not found in your PATH." + raise ValueError(msg) if not self.is_git_repo(): - raise ValueError(f"{self.path} is not a Git repository.") + msg = f"{self.path} is not a Git repository." + raise ValueError(msg) self.fetch() @@ -18,7 +20,7 @@ class Repository: """Check if Git is installed and available in the system.""" try: subprocess.run( - ["git", "--version"], capture_output=True, check=True, text=True + ["git", "--version"], capture_output=True, check=True, text=True, ) return True except (subprocess.CalledProcessError, FileNotFoundError): @@ -36,7 +38,7 @@ class Repository: encoding="utf-8", ).strip() - @lru_cache(maxsize=None) + @cache def is_git_repo(self) -> bool: """Check if the current directory is a git repository.""" try: @@ -62,10 +64,7 @@ class Repository: def is_synced(self) -> bool: """Return True if the Git repository is fully synced with the remote, False otherwise.""" - if self.has_uncommitted_changes() or self.is_ahead_or_behind(): - return False - else: - return True + return not (self.has_uncommitted_changes() or self.is_ahead_or_behind()) def origin_url(self) -> str | None: """Get the Git repository's remote URL.""" diff --git a/src/crewai/cli/install_crew.py b/src/crewai/cli/install_crew.py index bd0f35879..913f10096 100644 --- a/src/crewai/cli/install_crew.py +++ b/src/crewai/cli/install_crew.py @@ -8,11 +8,9 @@ import click # so if you expect this to support more things you will need to replicate it there # ask @joaomdmoura if you are unsure def install_crew(proxy_options: list[str]) -> None: - """ - Install the crew by running the UV command to lock and install. - """ + """Install the crew by running the UV command to lock and install.""" try: - command = ["uv", "sync"] + proxy_options + command = ["uv", "sync", *proxy_options] subprocess.run(command, check=True, capture_output=False, text=True) except subprocess.CalledProcessError as e: diff --git a/src/crewai/cli/kickoff_flow.py b/src/crewai/cli/kickoff_flow.py index 2123a6c15..2d67b68e1 100644 --- a/src/crewai/cli/kickoff_flow.py +++ b/src/crewai/cli/kickoff_flow.py @@ -4,9 +4,7 @@ import click def kickoff_flow() -> None: - """ - Kickoff the flow by running a command in the UV environment. - """ + """Kickoff the flow by running a command in the UV environment.""" command = ["uv", "run", "kickoff"] try: diff --git a/src/crewai/cli/plot_flow.py b/src/crewai/cli/plot_flow.py index 848c55d69..55e1a51af 100644 --- a/src/crewai/cli/plot_flow.py +++ b/src/crewai/cli/plot_flow.py @@ -4,9 +4,7 @@ import click def plot_flow() -> None: - """ - Plot the flow by running a command in the UV environment. - """ + """Plot the flow by running a command in the UV environment.""" command = ["uv", "run", "plot"] try: diff --git a/src/crewai/cli/plus_api.py b/src/crewai/cli/plus_api.py index 23032ca8f..790a36412 100644 --- a/src/crewai/cli/plus_api.py +++ b/src/crewai/cli/plus_api.py @@ -1,5 +1,4 @@ from os import getenv -from typing import Optional from urllib.parse import urljoin import requests @@ -8,9 +7,7 @@ from crewai.cli.version import get_crewai_version class PlusAPI: - """ - This class exposes methods for working with the CrewAI+ API. - """ + """This class exposes methods for working with the CrewAI+ API.""" TOOLS_RESOURCE = "/crewai_plus/api/v1/tools" CREWS_RESOURCE = "/crewai_plus/api/v1/crews" @@ -42,7 +39,7 @@ class PlusAPI: handle: str, is_public: bool, version: str, - description: Optional[str], + description: str | None, encoded_file: str, ): params = { @@ -56,7 +53,7 @@ class PlusAPI: def deploy_by_name(self, project_name: str) -> requests.Response: return self._make_request( - "POST", f"{self.CREWS_RESOURCE}/by-name/{project_name}/deploy" + "POST", f"{self.CREWS_RESOURCE}/by-name/{project_name}/deploy", ) def deploy_by_uuid(self, uuid: str) -> requests.Response: @@ -64,29 +61,29 @@ class PlusAPI: def crew_status_by_name(self, project_name: str) -> requests.Response: return self._make_request( - "GET", f"{self.CREWS_RESOURCE}/by-name/{project_name}/status" + "GET", f"{self.CREWS_RESOURCE}/by-name/{project_name}/status", ) def crew_status_by_uuid(self, uuid: str) -> requests.Response: return self._make_request("GET", f"{self.CREWS_RESOURCE}/{uuid}/status") def crew_by_name( - self, project_name: str, log_type: str = "deployment" + self, project_name: str, log_type: str = "deployment", ) -> requests.Response: return self._make_request( - "GET", f"{self.CREWS_RESOURCE}/by-name/{project_name}/logs/{log_type}" + "GET", f"{self.CREWS_RESOURCE}/by-name/{project_name}/logs/{log_type}", ) def crew_by_uuid( - self, uuid: str, log_type: str = "deployment" + self, uuid: str, log_type: str = "deployment", ) -> requests.Response: return self._make_request( - "GET", f"{self.CREWS_RESOURCE}/{uuid}/logs/{log_type}" + "GET", f"{self.CREWS_RESOURCE}/{uuid}/logs/{log_type}", ) def delete_crew_by_name(self, project_name: str) -> requests.Response: return self._make_request( - "DELETE", f"{self.CREWS_RESOURCE}/by-name/{project_name}" + "DELETE", f"{self.CREWS_RESOURCE}/by-name/{project_name}", ) def delete_crew_by_uuid(self, uuid: str) -> requests.Response: diff --git a/src/crewai/cli/provider.py b/src/crewai/cli/provider.py index 529ca5e26..00aa034ec 100644 --- a/src/crewai/cli/provider.py +++ b/src/crewai/cli/provider.py @@ -10,8 +10,7 @@ from crewai.cli.constants import JSON_URL, MODELS, PROVIDERS def select_choice(prompt_message, choices): - """ - Presents a list of choices to the user and prompts them to select one. + """Presents a list of choices to the user and prompts them to select one. Args: - prompt_message (str): The message to display to the user before presenting the choices. @@ -19,11 +18,11 @@ def select_choice(prompt_message, choices): Returns: - str: The selected choice from the list, or None if the user chooses to quit. - """ + """ provider_models = get_provider_data() if not provider_models: - return + return None click.secho(prompt_message, fg="cyan") for idx, choice in enumerate(choices, start=1): click.secho(f"{idx}. {choice}", fg="cyan") @@ -31,7 +30,7 @@ def select_choice(prompt_message, choices): while True: choice = click.prompt( - "Enter the number of your choice or 'q' to quit", type=str + "Enter the number of your choice or 'q' to quit", type=str, ) if choice.lower() == "q": @@ -51,8 +50,7 @@ def select_choice(prompt_message, choices): def select_provider(provider_models): - """ - Presents a list of providers to the user and prompts them to select one. + """Presents a list of providers to the user and prompts them to select one. Args: - provider_models (dict): A dictionary of provider models. @@ -60,12 +58,13 @@ def select_provider(provider_models): Returns: - str: The selected provider - None: If user explicitly quits + """ predefined_providers = [p.lower() for p in PROVIDERS] all_providers = sorted(set(predefined_providers + list(provider_models.keys()))) provider = select_choice( - "Select a provider to set up:", predefined_providers + ["other"] + "Select a provider to set up:", [*predefined_providers, "other"], ) if provider is None: # User typed 'q' return None @@ -79,8 +78,7 @@ def select_provider(provider_models): def select_model(provider, provider_models): - """ - Presents a list of models for a given provider to the user and prompts them to select one. + """Presents a list of models for a given provider to the user and prompts them to select one. Args: - provider (str): The provider for which to select a model. @@ -88,6 +86,7 @@ def select_model(provider, provider_models): Returns: - str: The selected model, or None if the operation is aborted or an invalid selection is made. + """ predefined_providers = [p.lower() for p in PROVIDERS] @@ -100,15 +99,13 @@ def select_model(provider, provider_models): click.secho(f"No models available for provider '{provider}'.", fg="red") return None - selected_model = select_choice( - f"Select a model to use for {provider.capitalize()}:", available_models + return select_choice( + f"Select a model to use for {provider.capitalize()}:", available_models, ) - return selected_model def load_provider_data(cache_file, cache_expiry): - """ - Loads provider data from a cache file if it exists and is not expired. If the cache is expired or corrupted, it fetches the data from the web. + """Loads provider data from a cache file if it exists and is not expired. If the cache is expired or corrupted, it fetches the data from the web. Args: - cache_file (Path): The path to the cache file. @@ -116,6 +113,7 @@ def load_provider_data(cache_file, cache_expiry): Returns: - dict or None: The loaded provider data or None if the operation fails. + """ current_time = time.time() if ( @@ -126,7 +124,7 @@ def load_provider_data(cache_file, cache_expiry): if data: return data click.secho( - "Cache is corrupted. Fetching provider data from the web...", fg="yellow" + "Cache is corrupted. Fetching provider data from the web...", fg="yellow", ) else: click.secho( @@ -137,31 +135,31 @@ def load_provider_data(cache_file, cache_expiry): def read_cache_file(cache_file): - """ - Reads and returns the JSON content from a cache file. Returns None if the file contains invalid JSON. + """Reads and returns the JSON content from a cache file. Returns None if the file contains invalid JSON. Args: - cache_file (Path): The path to the cache file. Returns: - dict or None: The JSON content of the cache file or None if the JSON is invalid. + """ try: - with open(cache_file, "r") as f: + with open(cache_file) as f: return json.load(f) except json.JSONDecodeError: return None def fetch_provider_data(cache_file): - """ - Fetches provider data from a specified URL and caches it to a file. + """Fetches provider data from a specified URL and caches it to a file. Args: - cache_file (Path): The path to the cache file. Returns: - dict or None: The fetched provider data or None if the operation fails. + """ try: response = requests.get(JSON_URL, stream=True, timeout=60) @@ -178,20 +176,20 @@ def fetch_provider_data(cache_file): def download_data(response): - """ - Downloads data from a given HTTP response and returns the JSON content. + """Downloads data from a given HTTP response and returns the JSON content. Args: - response (requests.Response): The HTTP response object. Returns: - dict: The JSON content of the response. + """ total_size = int(response.headers.get("content-length", 0)) block_size = 8192 data_chunks = [] with click.progressbar( - length=total_size, label="Downloading", show_pos=True + length=total_size, label="Downloading", show_pos=True, ) as progress_bar: for chunk in response.iter_content(block_size): if chunk: @@ -202,11 +200,11 @@ def download_data(response): def get_provider_data(): - """ - Retrieves provider data from a cache file, filters out models based on provider criteria, and returns a dictionary of providers mapped to their models. + """Retrieves provider data from a cache file, filters out models based on provider criteria, and returns a dictionary of providers mapped to their models. Returns: - dict or None: A dictionary of providers mapped to their models or None if the operation fails. + """ cache_dir = Path.home() / ".crewai" cache_dir.mkdir(exist_ok=True) diff --git a/src/crewai/cli/replay_from_task.py b/src/crewai/cli/replay_from_task.py index 7e34c3394..b2312bb7c 100644 --- a/src/crewai/cli/replay_from_task.py +++ b/src/crewai/cli/replay_from_task.py @@ -4,11 +4,11 @@ import click def replay_task_command(task_id: str) -> None: - """ - Replay the crew execution from a specific task. + """Replay the crew execution from a specific task. Args: task_id (str): The ID of the task to replay from. + """ command = ["uv", "run", "replay", task_id] diff --git a/src/crewai/cli/reset_memories_command.py b/src/crewai/cli/reset_memories_command.py index eaf54ffb7..cd9461ddc 100644 --- a/src/crewai/cli/reset_memories_command.py +++ b/src/crewai/cli/reset_memories_command.py @@ -13,8 +13,7 @@ def reset_memories_command( kickoff_outputs, all, ) -> None: - """ - Reset the crew memories. + """Reset the crew memories. Args: long (bool): Whether to reset the long-term memory. @@ -23,49 +22,50 @@ def reset_memories_command( kickoff_outputs (bool): Whether to reset the latest kickoff task outputs. all (bool): Whether to reset all memories. knowledge (bool): Whether to reset the knowledge. - """ + """ try: if not any([long, short, entity, kickoff_outputs, knowledge, all]): click.echo( - "No memory type specified. Please specify at least one type to reset." + "No memory type specified. Please specify at least one type to reset.", ) return crews = get_crews() if not crews: - raise ValueError("No crew found.") + msg = "No crew found." + raise ValueError(msg) for crew in crews: if all: crew.reset_memories(command_type="all") click.echo( - f"[Crew ({crew.name if crew.name else crew.id})] Reset memories command has been completed." + f"[Crew ({crew.name if crew.name else crew.id})] Reset memories command has been completed.", ) continue if long: crew.reset_memories(command_type="long") click.echo( - f"[Crew ({crew.name if crew.name else crew.id})] Long term memory has been reset." + f"[Crew ({crew.name if crew.name else crew.id})] Long term memory has been reset.", ) if short: crew.reset_memories(command_type="short") click.echo( - f"[Crew ({crew.name if crew.name else crew.id})] Short term memory has been reset." + f"[Crew ({crew.name if crew.name else crew.id})] Short term memory has been reset.", ) if entity: crew.reset_memories(command_type="entity") click.echo( - f"[Crew ({crew.name if crew.name else crew.id})] Entity memory has been reset." + f"[Crew ({crew.name if crew.name else crew.id})] Entity memory has been reset.", ) if kickoff_outputs: crew.reset_memories(command_type="kickoff_outputs") click.echo( - f"[Crew ({crew.name if crew.name else crew.id})] Latest Kickoff outputs stored has been reset." + f"[Crew ({crew.name if crew.name else crew.id})] Latest Kickoff outputs stored has been reset.", ) if knowledge: crew.reset_memories(command_type="knowledge") click.echo( - f"[Crew ({crew.name if crew.name else crew.id})] Knowledge has been reset." + f"[Crew ({crew.name if crew.name else crew.id})] Knowledge has been reset.", ) except subprocess.CalledProcessError as e: diff --git a/src/crewai/cli/run_crew.py b/src/crewai/cli/run_crew.py index 62241a4b5..f28e9d90d 100644 --- a/src/crewai/cli/run_crew.py +++ b/src/crewai/cli/run_crew.py @@ -1,6 +1,5 @@ import subprocess from enum import Enum -from typing import List, Optional import click from packaging import version @@ -15,8 +14,7 @@ class CrewType(Enum): def run_crew() -> None: - """ - Run the crew or flow by running a command in the UV environment. + """Run the crew or flow by running a command in the UV environment. Starting from version 0.103.0, this command can be used to run both standard crews and flows. For flows, it detects the type from pyproject.toml @@ -48,11 +46,11 @@ def run_crew() -> None: def execute_command(crew_type: CrewType) -> None: - """ - Execute the appropriate command based on crew type. + """Execute the appropriate command based on crew type. Args: crew_type: The type of crew to run + """ command = ["uv", "run", "kickoff" if crew_type == CrewType.FLOW else "run_crew"] @@ -67,12 +65,12 @@ def execute_command(crew_type: CrewType) -> None: def handle_error(error: subprocess.CalledProcessError, crew_type: CrewType) -> None: - """ - Handle subprocess errors with appropriate messaging. + """Handle subprocess errors with appropriate messaging. Args: error: The subprocess error that occurred crew_type: The type of crew that was being run + """ entity_type = "flow" if crew_type == CrewType.FLOW else "crew" click.echo(f"An error occurred while running the {entity_type}: {error}", err=True) diff --git a/src/crewai/cli/tools/main.py b/src/crewai/cli/tools/main.py index 8fbe1948b..ee82125a1 100644 --- a/src/crewai/cli/tools/main.py +++ b/src/crewai/cli/tools/main.py @@ -22,15 +22,13 @@ console = Console() class ToolCommand(BaseCommand, PlusAPIMixin): - """ - A class to handle tool repository related operations for CrewAI projects. - """ + """A class to handle tool repository related operations for CrewAI projects.""" - def __init__(self): + def __init__(self) -> None: BaseCommand.__init__(self) PlusAPIMixin.__init__(self, telemetry=self._telemetry) - def create(self, handle: str): + def create(self, handle: str) -> None: self._ensure_not_in_project() folder_name = handle.replace(" ", "_").replace("-", "_").lower() @@ -40,8 +38,7 @@ class ToolCommand(BaseCommand, PlusAPIMixin): if project_root.exists(): click.secho(f"Folder {folder_name} already exists.", fg="red") raise SystemExit - else: - os.makedirs(project_root) + os.makedirs(project_root) click.secho(f"Creating custom tool {folder_name}...", fg="green", bold=True) @@ -56,12 +53,12 @@ class ToolCommand(BaseCommand, PlusAPIMixin): self.login() subprocess.run(["git", "init"], check=True) console.print( - f"[green]Created custom tool [bold]{folder_name}[/bold]. Run [bold]cd {project_root}[/bold] to start working.[/green]" + f"[green]Created custom tool [bold]{folder_name}[/bold]. Run [bold]cd {project_root}[/bold] to start working.[/green]", ) finally: os.chdir(old_directory) - def publish(self, is_public: bool, force: bool = False): + def publish(self, is_public: bool, force: bool = False) -> None: if not git.Repository().is_synced() and not force: console.print( "[bold red]Failed to publish tool.[/bold red]\n" @@ -69,9 +66,9 @@ class ToolCommand(BaseCommand, PlusAPIMixin): "* [bold]Commit[/bold] your changes.\n" "* [bold]Push[/bold] to sync with the remote.\n" "* [bold]Pull[/bold] the latest changes from the remote.\n" - "\nOnce your repository is up-to-date, retry publishing the tool." + "\nOnce your repository is up-to-date, retry publishing the tool.", ) - raise SystemExit() + raise SystemExit project_name = get_project_name(require=True) assert isinstance(project_name, str) @@ -90,7 +87,7 @@ class ToolCommand(BaseCommand, PlusAPIMixin): ) tarball_filename = next( - (f for f in os.listdir(temp_build_dir) if f.endswith(".tar.gz")), None + (f for f in os.listdir(temp_build_dir) if f.endswith(".tar.gz")), None, ) if not tarball_filename: console.print( @@ -123,7 +120,7 @@ class ToolCommand(BaseCommand, PlusAPIMixin): style="bold green", ) - def install(self, handle: str): + def install(self, handle: str) -> None: get_response = self.plus_api_client.get_tool(handle) if get_response.status_code == 404: @@ -132,9 +129,9 @@ class ToolCommand(BaseCommand, PlusAPIMixin): style="bold red", ) raise SystemExit - elif get_response.status_code != 200: + if get_response.status_code != 200: console.print( - "Failed to get tool details. Please try again later.", style="bold red" + "Failed to get tool details. Please try again later.", style="bold red", ) raise SystemExit @@ -142,7 +139,7 @@ class ToolCommand(BaseCommand, PlusAPIMixin): console.print(f"Successfully installed {handle}", style="bold green") - def login(self): + def login(self) -> None: login_response = self.plus_api_client.login_to_tool_repository() if login_response.status_code != 200: @@ -164,10 +161,10 @@ class ToolCommand(BaseCommand, PlusAPIMixin): settings.dump() console.print( - "Successfully authenticated to the tool repository.", style="bold green" + "Successfully authenticated to the tool repository.", style="bold green", ) - def _add_package(self, tool_details): + def _add_package(self, tool_details) -> None: tool_handle = tool_details["handle"] repository_handle = tool_details["repository"]["handle"] repository_url = tool_details["repository"]["url"] @@ -192,16 +189,16 @@ class ToolCommand(BaseCommand, PlusAPIMixin): click.echo(add_package_result.stderr, err=True) raise SystemExit - def _ensure_not_in_project(self): + def _ensure_not_in_project(self) -> None: if os.path.isfile("./pyproject.toml"): console.print( - "[bold red]Oops! It looks like you're inside a project.[/bold red]" + "[bold red]Oops! It looks like you're inside a project.[/bold red]", ) console.print( - "You can't create a new tool while inside an existing project." + "You can't create a new tool while inside an existing project.", ) console.print( - "[bold yellow]Tip:[/bold yellow] Navigate to a different directory and try again." + "[bold yellow]Tip:[/bold yellow] Navigate to a different directory and try again.", ) raise SystemExit @@ -211,10 +208,10 @@ class ToolCommand(BaseCommand, PlusAPIMixin): env = os.environ.copy() env[f"UV_INDEX_{repository_handle}_USERNAME"] = str( - settings.tool_repository_username or "" + settings.tool_repository_username or "", ) env[f"UV_INDEX_{repository_handle}_PASSWORD"] = str( - settings.tool_repository_password or "" + settings.tool_repository_password or "", ) return env diff --git a/src/crewai/cli/train_crew.py b/src/crewai/cli/train_crew.py index 14a5e1a06..5b6ec6a19 100644 --- a/src/crewai/cli/train_crew.py +++ b/src/crewai/cli/train_crew.py @@ -4,20 +4,22 @@ import click def train_crew(n_iterations: int, filename: str) -> None: - """ - Train the crew by running a command in the UV environment. + """Train the crew by running a command in the UV environment. Args: n_iterations (int): The number of iterations to train the crew. + """ command = ["uv", "run", "train", str(n_iterations), filename] try: if n_iterations <= 0: - raise ValueError("The number of iterations must be a positive integer.") + msg = "The number of iterations must be a positive integer." + raise ValueError(msg) if not filename.endswith(".pkl"): - raise ValueError("The filename must not end with .pkl") + msg = "The filename must not end with .pkl" + raise ValueError(msg) result = subprocess.run(command, capture_output=False, text=True, check=True) diff --git a/src/crewai/cli/update_crew.py b/src/crewai/cli/update_crew.py index e7ed69aa1..e1ea4def8 100644 --- a/src/crewai/cli/update_crew.py +++ b/src/crewai/cli/update_crew.py @@ -11,9 +11,8 @@ def update_crew() -> None: migrate_pyproject("pyproject.toml", "pyproject.toml") -def migrate_pyproject(input_file, output_file): - """ - Migrate the pyproject.toml to the new format. +def migrate_pyproject(input_file, output_file) -> None: + """Migrate the pyproject.toml to the new format. This function is used to migrate the pyproject.toml to the new format. And it will be used to migrate the pyproject.toml to the new format when uv is used. @@ -81,7 +80,7 @@ def migrate_pyproject(input_file, output_file): # Extract the module name from any existing script existing_scripts = new_pyproject["project"]["scripts"] module_name = next( - (value.split(".")[0] for value in existing_scripts.values() if "." in value) + value.split(".")[0] for value in existing_scripts.values() if "." in value ) new_pyproject["project"]["scripts"]["run_crew"] = f"{module_name}.main:run" @@ -93,22 +92,19 @@ def migrate_pyproject(input_file, output_file): # Backup the old pyproject.toml backup_file = "pyproject-old.toml" shutil.copy2(input_file, backup_file) - print(f"Original pyproject.toml backed up as {backup_file}") # Rename the poetry.lock file lock_file = "poetry.lock" lock_backup = "poetry-old.lock" if os.path.exists(lock_file): os.rename(lock_file, lock_backup) - print(f"Original poetry.lock renamed to {lock_backup}") else: - print("No poetry.lock file found to rename.") + pass # Write the new pyproject.toml with open(output_file, "wb") as f: tomli_w.dump(new_pyproject, f) - print(f"Migration complete. New pyproject.toml written to {output_file}") def parse_version(version: str) -> str: diff --git a/src/crewai/cli/utils.py b/src/crewai/cli/utils.py index 74fc414d9..393a2c326 100644 --- a/src/crewai/cli/utils.py +++ b/src/crewai/cli/utils.py @@ -3,7 +3,7 @@ import shutil import sys from functools import reduce from inspect import isfunction, ismethod -from typing import Any, Dict, List, get_type_hints +from typing import Any, get_type_hints import click import tomli @@ -19,9 +19,9 @@ if sys.version_info >= (3, 11): console = Console() -def copy_template(src, dst, name, class_name, folder_name): +def copy_template(src, dst, name, class_name, folder_name) -> None: """Copy a file from src to dst.""" - with open(src, "r") as file: + with open(src) as file: content = file.read() # Interpolate the content @@ -39,8 +39,7 @@ def copy_template(src, dst, name, class_name, folder_name): def read_toml(file_path: str = "pyproject.toml"): """Read the content of a TOML file and return it as a dictionary.""" with open(file_path, "rb") as f: - toml_dict = tomli.load(f) - return toml_dict + return tomli.load(f) def parse_toml(content): @@ -50,59 +49,56 @@ def parse_toml(content): def get_project_name( - pyproject_path: str = "pyproject.toml", require: bool = False + pyproject_path: str = "pyproject.toml", require: bool = False, ) -> str | None: """Get the project name from the pyproject.toml file.""" return _get_project_attribute(pyproject_path, ["project", "name"], require=require) def get_project_version( - pyproject_path: str = "pyproject.toml", require: bool = False + pyproject_path: str = "pyproject.toml", require: bool = False, ) -> str | None: """Get the project version from the pyproject.toml file.""" return _get_project_attribute( - pyproject_path, ["project", "version"], require=require + pyproject_path, ["project", "version"], require=require, ) def get_project_description( - pyproject_path: str = "pyproject.toml", require: bool = False + pyproject_path: str = "pyproject.toml", require: bool = False, ) -> str | None: """Get the project description from the pyproject.toml file.""" return _get_project_attribute( - pyproject_path, ["project", "description"], require=require + pyproject_path, ["project", "description"], require=require, ) def _get_project_attribute( - pyproject_path: str, keys: List[str], require: bool + pyproject_path: str, keys: list[str], require: bool, ) -> Any | None: """Get an attribute from the pyproject.toml file.""" attribute = None try: - with open(pyproject_path, "r") as f: + with open(pyproject_path) as f: pyproject_content = parse_toml(f.read()) dependencies = ( _get_nested_value(pyproject_content, ["project", "dependencies"]) or [] ) if not any(True for dep in dependencies if "crewai" in dep): - raise Exception("crewai is not in the dependencies.") + msg = "crewai is not in the dependencies." + raise Exception(msg) attribute = _get_nested_value(pyproject_content, keys) except FileNotFoundError: - print(f"Error: {pyproject_path} not found.") + pass except KeyError: - print(f"Error: {pyproject_path} is not a valid pyproject.toml file.") - except tomllib.TOMLDecodeError if sys.version_info >= (3, 11) else Exception as e: # type: ignore - print( - f"Error: {pyproject_path} is not a valid TOML file." - if sys.version_info >= (3, 11) - else f"Error reading the pyproject.toml file: {e}" - ) - except Exception as e: - print(f"Error reading the pyproject.toml file: {e}") + pass + except tomllib.TOMLDecodeError if sys.version_info >= (3, 11) else Exception: # type: ignore + pass + except Exception: + pass if require and not attribute: console.print( @@ -114,7 +110,7 @@ def _get_project_attribute( return attribute -def _get_nested_value(data: Dict[str, Any], keys: List[str]) -> Any: +def _get_nested_value(data: dict[str, Any], keys: list[str]) -> Any: return reduce(dict.__getitem__, keys, data) @@ -122,7 +118,7 @@ def fetch_and_json_env_file(env_file_path: str = ".env") -> dict: """Fetch the environment variables from a .env file and return them as a dictionary.""" try: # Read the .env file - with open(env_file_path, "r") as f: + with open(env_file_path) as f: env_content = f.read() # Parse the .env file content to a dictionary @@ -135,14 +131,14 @@ def fetch_and_json_env_file(env_file_path: str = ".env") -> dict: return env_dict except FileNotFoundError: - print(f"Error: {env_file_path} not found.") - except Exception as e: - print(f"Error reading the .env file: {e}") + pass + except Exception: + pass return {} -def tree_copy(source, destination): +def tree_copy(source, destination) -> None: """Copies the entire directory structure from the source to the destination.""" for item in os.listdir(source): source_item = os.path.join(source, item) @@ -153,7 +149,7 @@ def tree_copy(source, destination): shutil.copy2(source_item, destination_item) -def tree_find_and_replace(directory, find, replace): +def tree_find_and_replace(directory, find, replace) -> None: """Recursively searches through a directory, replacing a target string in both file contents and filenames with a specified replacement string. """ @@ -161,7 +157,7 @@ def tree_find_and_replace(directory, find, replace): for filename in files: filepath = os.path.join(path, filename) - with open(filepath, "r") as file: + with open(filepath) as file: contents = file.read() with open(filepath, "w") as file: file.write(contents.replace(find, replace)) @@ -180,19 +176,19 @@ def tree_find_and_replace(directory, find, replace): def load_env_vars(folder_path): - """ - Loads environment variables from a .env file in the specified folder path. + """Loads environment variables from a .env file in the specified folder path. Args: - folder_path (Path): The path to the folder containing the .env file. Returns: - dict: A dictionary of environment variables. + """ env_file_path = folder_path / ".env" env_vars = {} if env_file_path.exists(): - with open(env_file_path, "r") as file: + with open(env_file_path) as file: for line in file: key, _, value = line.strip().partition("=") if key and value: @@ -201,8 +197,7 @@ def load_env_vars(folder_path): def update_env_vars(env_vars, provider, model): - """ - Updates environment variables with the API key for the selected provider and model. + """Updates environment variables with the API key for the selected provider and model. Args: - env_vars (dict): Environment variables dictionary. @@ -211,6 +206,7 @@ def update_env_vars(env_vars, provider, model): Returns: - None + """ api_key_var = ENV_VARS.get( provider, @@ -218,14 +214,14 @@ def update_env_vars(env_vars, provider, model): click.prompt( f"Enter the environment variable name for your {provider.capitalize()} API key", type=str, - ) + ), ], )[0] if api_key_var not in env_vars: try: env_vars[api_key_var] = click.prompt( - f"Enter your {provider.capitalize()} API key", type=str, hide_input=True + f"Enter your {provider.capitalize()} API key", type=str, hide_input=True, ) except click.exceptions.Abort: click.secho("Operation aborted by the user.", fg="red") @@ -238,13 +234,13 @@ def update_env_vars(env_vars, provider, model): return env_vars -def write_env_file(folder_path, env_vars): - """ - Writes environment variables to a .env file in the specified folder. +def write_env_file(folder_path, env_vars) -> None: + """Writes environment variables to a .env file in the specified folder. Args: - folder_path (Path): The path to the folder where the .env file will be written. - env_vars (dict): A dictionary of environment variables to write. + """ env_file_path = folder_path / ".env" with open(env_file_path, "w") as file: @@ -263,7 +259,7 @@ def get_crews(crew_path: str = "crew.py", require: bool = False) -> list[Crew]: crew_os_path = os.path.join(root, crew_path) try: spec = importlib.util.spec_from_file_location( - "crew_module", crew_os_path + "crew_module", crew_os_path, ) if not spec or not spec.loader: continue @@ -277,19 +273,16 @@ def get_crews(crew_path: str = "crew.py", require: bool = False) -> list[Crew]: try: crew_instances.extend(fetch_crews(module_attr)) - except Exception as e: - print(f"Error processing attribute {attr_name}: {e}") + except Exception: continue - except Exception as exec_error: - print(f"Error executing module: {exec_error}") + except Exception: import traceback - print(f"Traceback: {traceback.format_exc()}") except (ImportError, AttributeError) as e: if require: console.print( - f"Error importing crew from {crew_path}: {str(e)}", + f"Error importing crew from {crew_path}: {e!s}", style="bold red", ) continue @@ -303,7 +296,7 @@ def get_crews(crew_path: str = "crew.py", require: bool = False) -> list[Crew]: except Exception as e: if require: console.print( - f"Unexpected error while loading crew: {str(e)}", style="bold red" + f"Unexpected error while loading crew: {e!s}", style="bold red", ) raise SystemExit return crew_instances @@ -317,13 +310,12 @@ def get_crew_instance(module_attr) -> Crew | None: ): return module_attr().crew() if (ismethod(module_attr) or isfunction(module_attr)) and get_type_hints( - module_attr + module_attr, ).get("return") is Crew: return module_attr() - elif isinstance(module_attr, Crew): + if isinstance(module_attr, Crew): return module_attr - else: - return None + return None def fetch_crews(module_attr) -> list[Crew]: diff --git a/src/crewai/cli/version.py b/src/crewai/cli/version.py index a7c1087a7..3ede9ff38 100644 --- a/src/crewai/cli/version.py +++ b/src/crewai/cli/version.py @@ -2,5 +2,5 @@ import importlib.metadata def get_crewai_version() -> str: - """Get the version number of CrewAI running the CLI""" + """Get the version number of CrewAI running the CLI.""" return importlib.metadata.version("crewai") diff --git a/src/crewai/crew.py b/src/crewai/crew.py index 102f22881..87f92de03 100644 --- a/src/crewai/crew.py +++ b/src/crewai/crew.py @@ -3,18 +3,12 @@ import json import re import uuid import warnings +from collections.abc import Callable from concurrent.futures import Future from copy import copy as shallow_copy from hashlib import md5 from typing import ( Any, - Callable, - Dict, - List, - Optional, - Set, - Tuple, - Union, cast, ) @@ -81,8 +75,7 @@ warnings.filterwarnings("ignore", category=SyntaxWarning, module="pysbd") class Crew(FlowTrackable, BaseModel): - """ - Represents a group of agents, defining how they should collaborate and the tasks they should perform. + """Represents a group of agents, defining how they should collaborate and the tasks they should perform. Attributes: tasks: List of tasks assigned to the crew. @@ -105,6 +98,7 @@ class Crew(FlowTrackable, BaseModel): planning: Plan the crew execution and add the plan to the crew. chat_llm: The language model used for orchestrating chat interactions with the crew. security_config: Security configuration for the crew, including fingerprinting. + """ __hash__ = object.__hash__ # type: ignore @@ -113,130 +107,130 @@ class Crew(FlowTrackable, BaseModel): _logger: Logger = PrivateAttr() _file_handler: FileHandler = PrivateAttr() _cache_handler: InstanceOf[CacheHandler] = PrivateAttr(default=CacheHandler()) - _short_term_memory: Optional[InstanceOf[ShortTermMemory]] = PrivateAttr() - _long_term_memory: Optional[InstanceOf[LongTermMemory]] = PrivateAttr() - _entity_memory: Optional[InstanceOf[EntityMemory]] = PrivateAttr() - _user_memory: Optional[InstanceOf[UserMemory]] = PrivateAttr() - _external_memory: Optional[InstanceOf[ExternalMemory]] = PrivateAttr() - _train: Optional[bool] = PrivateAttr(default=False) - _train_iteration: Optional[int] = PrivateAttr() - _inputs: Optional[Dict[str, Any]] = PrivateAttr(default=None) + _short_term_memory: InstanceOf[ShortTermMemory] | None = PrivateAttr() + _long_term_memory: InstanceOf[LongTermMemory] | None = PrivateAttr() + _entity_memory: InstanceOf[EntityMemory] | None = PrivateAttr() + _user_memory: InstanceOf[UserMemory] | None = PrivateAttr() + _external_memory: InstanceOf[ExternalMemory] | None = PrivateAttr() + _train: bool | None = PrivateAttr(default=False) + _train_iteration: int | None = PrivateAttr() + _inputs: dict[str, Any] | None = PrivateAttr(default=None) _logging_color: str = PrivateAttr( default="bold_purple", ) _task_output_handler: TaskOutputStorageHandler = PrivateAttr( - default_factory=TaskOutputStorageHandler + default_factory=TaskOutputStorageHandler, ) - name: Optional[str] = Field(default=None) + name: str | None = Field(default=None) cache: bool = Field(default=True) - tasks: List[Task] = Field(default_factory=list) - agents: List[BaseAgent] = Field(default_factory=list) + tasks: list[Task] = Field(default_factory=list) + agents: list[BaseAgent] = Field(default_factory=list) process: Process = Field(default=Process.sequential) verbose: bool = Field(default=False) memory: bool = Field( default=False, description="Whether the crew should use memory to store memories of it's execution", ) - memory_config: Optional[Dict[str, Any]] = Field( + memory_config: dict[str, Any] | None = Field( default=None, description="Configuration for the memory to be used for the crew.", ) - short_term_memory: Optional[InstanceOf[ShortTermMemory]] = Field( + short_term_memory: InstanceOf[ShortTermMemory] | None = Field( default=None, description="An Instance of the ShortTermMemory to be used by the Crew", ) - long_term_memory: Optional[InstanceOf[LongTermMemory]] = Field( + long_term_memory: InstanceOf[LongTermMemory] | None = Field( default=None, description="An Instance of the LongTermMemory to be used by the Crew", ) - entity_memory: Optional[InstanceOf[EntityMemory]] = Field( + entity_memory: InstanceOf[EntityMemory] | None = Field( default=None, description="An Instance of the EntityMemory to be used by the Crew", ) - user_memory: Optional[InstanceOf[UserMemory]] = Field( + user_memory: InstanceOf[UserMemory] | None = Field( default=None, description="An instance of the UserMemory to be used by the Crew to store/fetch memories of a specific user.", ) - external_memory: Optional[InstanceOf[ExternalMemory]] = Field( + external_memory: InstanceOf[ExternalMemory] | None = Field( default=None, description="An Instance of the ExternalMemory to be used by the Crew", ) - embedder: Optional[dict] = Field( + embedder: dict | None = Field( default=None, description="Configuration for the embedder to be used for the crew.", ) - usage_metrics: Optional[UsageMetrics] = Field( + usage_metrics: UsageMetrics | None = Field( default=None, description="Metrics for the LLM usage during all tasks execution.", ) - manager_llm: Optional[Union[str, InstanceOf[BaseLLM], Any]] = Field( - description="Language model that will run the agent.", default=None + manager_llm: str | InstanceOf[BaseLLM] | Any | None = Field( + description="Language model that will run the agent.", default=None, ) - manager_agent: Optional[BaseAgent] = Field( - description="Custom agent that will be used as manager.", default=None + manager_agent: BaseAgent | None = Field( + description="Custom agent that will be used as manager.", default=None, ) - function_calling_llm: Optional[Union[str, InstanceOf[LLM], Any]] = Field( - description="Language model that will run the agent.", default=None + function_calling_llm: str | InstanceOf[LLM] | Any | None = Field( + description="Language model that will run the agent.", default=None, ) - config: Optional[Union[Json, Dict[str, Any]]] = Field(default=None) + config: Json | dict[str, Any] | None = Field(default=None) id: UUID4 = Field(default_factory=uuid.uuid4, frozen=True) - share_crew: Optional[bool] = Field(default=False) - step_callback: Optional[Any] = Field( + share_crew: bool | None = Field(default=False) + step_callback: Any | None = Field( default=None, description="Callback to be executed after each step for all agents execution.", ) - task_callback: Optional[Any] = Field( + task_callback: Any | None = Field( default=None, description="Callback to be executed after each task for all agents execution.", ) - before_kickoff_callbacks: List[ - Callable[[Optional[Dict[str, Any]]], Optional[Dict[str, Any]]] + before_kickoff_callbacks: list[ + Callable[[dict[str, Any] | None], dict[str, Any] | None] ] = Field( default_factory=list, description="List of callbacks to be executed before crew kickoff. It may be used to adjust inputs before the crew is executed.", ) - after_kickoff_callbacks: List[Callable[[CrewOutput], CrewOutput]] = Field( + after_kickoff_callbacks: list[Callable[[CrewOutput], CrewOutput]] = Field( default_factory=list, description="List of callbacks to be executed after crew kickoff. It may be used to adjust the output of the crew.", ) - max_rpm: Optional[int] = Field( + max_rpm: int | None = Field( default=None, description="Maximum number of requests per minute for the crew execution to be respected.", ) - prompt_file: Optional[str] = Field( + prompt_file: str | None = Field( default=None, description="Path to the prompt json file to be used for the crew.", ) - output_log_file: Optional[Union[bool, str]] = Field( + output_log_file: bool | str | None = Field( default=None, description="Path to the log file to be saved", ) - planning: Optional[bool] = Field( + planning: bool | None = Field( default=False, description="Plan the crew execution and add the plan to the crew.", ) - planning_llm: Optional[Union[str, InstanceOf[BaseLLM], Any]] = Field( + planning_llm: str | InstanceOf[BaseLLM] | Any | None = Field( default=None, description="Language model that will run the AgentPlanner if planning is True.", ) - task_execution_output_json_files: Optional[List[str]] = Field( + task_execution_output_json_files: list[str] | None = Field( default=None, description="List of file paths for task execution JSON files.", ) - execution_logs: List[Dict[str, Any]] = Field( + execution_logs: list[dict[str, Any]] = Field( default=[], description="List of execution logs for tasks", ) - knowledge_sources: Optional[List[BaseKnowledgeSource]] = Field( + knowledge_sources: list[BaseKnowledgeSource] | None = Field( default=None, description="Knowledge sources for the crew. Add knowledge sources to the knowledge object.", ) - chat_llm: Optional[Union[str, InstanceOf[BaseLLM], Any]] = Field( + chat_llm: str | InstanceOf[BaseLLM] | Any | None = Field( default=None, description="LLM used to handle chatting with the crew.", ) - knowledge: Optional[Knowledge] = Field( + knowledge: Knowledge | None = Field( default=None, description="Knowledge for the crew.", ) @@ -247,32 +241,34 @@ class Crew(FlowTrackable, BaseModel): @field_validator("id", mode="before") @classmethod - def _deny_user_set_id(cls, v: Optional[UUID4]) -> None: + def _deny_user_set_id(cls, v: UUID4 | None) -> None: """Prevent manual setting of the 'id' field by users.""" if v: + msg = "may_not_set_field" raise PydanticCustomError( - "may_not_set_field", "The 'id' field cannot be set by the user.", {} + msg, "The 'id' field cannot be set by the user.", {}, ) @field_validator("config", mode="before") @classmethod def check_config_type( - cls, v: Union[Json, Dict[str, Any]] - ) -> Union[Json, Dict[str, Any]]: + cls, v: Json | dict[str, Any], + ) -> Json | dict[str, Any]: """Validates that the config is a valid type. + Args: v: The config to be validated. + Returns: The config if it is valid. - """ + """ # TODO: Improve typing return json.loads(v) if isinstance(v, Json) else v # type: ignore @model_validator(mode="after") def set_private_attrs(self) -> "Crew": """Set private attributes.""" - self._cache_handler = CacheHandler() event_listener = EventListener() event_listener.verbose = self.verbose @@ -286,7 +282,7 @@ class Crew(FlowTrackable, BaseModel): return self - def _initialize_user_memory(self): + def _initialize_user_memory(self) -> None: if ( self.memory_config and "user_memory" in self.memory_config @@ -294,20 +290,21 @@ class Crew(FlowTrackable, BaseModel): ): # Check for user_memory in config user_memory_config = self.memory_config["user_memory"] if isinstance( - user_memory_config, dict + user_memory_config, dict, ): # Check if it's a configuration dict self._user_memory = UserMemory(crew=self) else: - raise TypeError("user_memory must be a configuration dictionary") + msg = "user_memory must be a configuration dictionary" + raise TypeError(msg) - def _initialize_default_memories(self): + def _initialize_default_memories(self) -> None: self._long_term_memory = self._long_term_memory or LongTermMemory() self._short_term_memory = self._short_term_memory or ShortTermMemory( crew=self, embedder_config=self.embedder, ) self._entity_memory = self.entity_memory or EntityMemory( - crew=self, embedder_config=self.embedder + crew=self, embedder_config=self.embedder, ) @model_validator(mode="after") @@ -350,7 +347,7 @@ class Crew(FlowTrackable, BaseModel): except Exception as e: self._logger.log( - "warning", f"Failed to init knowledge: {e}", color="yellow" + "warning", f"Failed to init knowledge: {e}", color="yellow", ) return self @@ -359,8 +356,9 @@ class Crew(FlowTrackable, BaseModel): """Validates that the language model is set when using hierarchical process.""" if self.process == Process.hierarchical: if not self.manager_llm and not self.manager_agent: + msg = "missing_manager_llm_or_manager_agent" raise PydanticCustomError( - "missing_manager_llm_or_manager_agent", + msg, "Attribute `manager_llm` or `manager_agent` is required when using hierarchical process.", {}, ) @@ -368,8 +366,9 @@ class Crew(FlowTrackable, BaseModel): if (self.manager_agent is not None) and ( self.agents.count(self.manager_agent) > 0 ): + msg = "manager_agent_in_agents" raise PydanticCustomError( - "manager_agent_in_agents", + msg, "Manager agent should not be included in agents list.", {}, ) @@ -380,8 +379,9 @@ class Crew(FlowTrackable, BaseModel): def check_config(self): """Validates that the crew is properly configured with agents and tasks.""" if not self.config and not self.tasks and not self.agents: + msg = "missing_keys" raise PydanticCustomError( - "missing_keys", + msg, "Either 'agents' and 'tasks' need to be set or 'config'.", {}, ) @@ -402,8 +402,9 @@ class Crew(FlowTrackable, BaseModel): if self.process == Process.sequential: for task in self.tasks: if task.agent is None: + msg = "missing_agent_in_task" raise PydanticCustomError( - "missing_agent_in_task", + msg, f"Sequential process error: Agent is missing in the task with the following description: {task.description}", # type: ignore # Argument of type "str" cannot be assigned to parameter "message_template" of type "LiteralString" {}, ) @@ -423,8 +424,9 @@ class Crew(FlowTrackable, BaseModel): break # Stop traversing as soon as a non-async task is encountered if final_async_task_count > 1: + msg = "async_task_count" raise PydanticCustomError( - "async_task_count", + msg, "The crew must end with at most one asynchronous task.", {}, ) @@ -440,8 +442,9 @@ class Crew(FlowTrackable, BaseModel): 1 for task in self.tasks if not isinstance(task, ConditionalTask) ) if non_conditional_count == 0: + msg = "only_conditional_tasks" raise PydanticCustomError( - "only_conditional_tasks", + msg, "Crew must include at least one non-conditional task", {}, ) @@ -451,8 +454,9 @@ class Crew(FlowTrackable, BaseModel): def validate_first_task(self) -> "Crew": """Ensure the first task is not a ConditionalTask.""" if self.tasks and isinstance(self.tasks[0], ConditionalTask): + msg = "invalid_first_task" raise PydanticCustomError( - "invalid_first_task", + msg, "The first task cannot be a ConditionalTask.", {}, ) @@ -463,8 +467,9 @@ class Crew(FlowTrackable, BaseModel): """Ensure that ConditionalTask is not async.""" for task in self.tasks: if task.async_execution and isinstance(task, ConditionalTask): + msg = "invalid_async_conditional_task" raise PydanticCustomError( - "invalid_async_conditional_task", + msg, f"Conditional Task: {task.description} , cannot be executed asynchronously.", # type: ignore # Argument of type "str" cannot be assigned to parameter "message_template" of type "LiteralString" {}, ) @@ -472,8 +477,7 @@ class Crew(FlowTrackable, BaseModel): @model_validator(mode="after") def validate_async_task_cannot_include_sequential_async_tasks_in_context(self): - """ - Validates that if a task is set to be executed asynchronously, + """Validates that if a task is set to be executed asynchronously, it cannot include other asynchronous tasks in its context unless separated by a synchronous task. """ @@ -483,8 +487,9 @@ class Crew(FlowTrackable, BaseModel): if context_task.async_execution: for j in range(i - 1, -1, -1): if self.tasks[j] == context_task: + msg = f"Task '{task.description}' is asynchronous and cannot include other sequential asynchronous tasks in its context." raise ValueError( - f"Task '{task.description}' is asynchronous and cannot include other sequential asynchronous tasks in its context." + msg, ) if not self.tasks[j].async_execution: break @@ -501,42 +506,44 @@ class Crew(FlowTrackable, BaseModel): if id(context_task) not in task_indices: continue # Skip context tasks not in the main tasks list if task_indices[id(context_task)] > task_indices[id(task)]: + msg = f"Task '{task.description}' has a context dependency on a future task '{context_task.description}', which is not allowed." raise ValueError( - f"Task '{task.description}' has a context dependency on a future task '{context_task.description}', which is not allowed." + msg, ) return self @property def key(self) -> str: - source: List[str] = [agent.key for agent in self.agents] + [ + source: list[str] = [agent.key for agent in self.agents] + [ task.key for task in self.tasks ] return md5("|".join(source).encode(), usedforsecurity=False).hexdigest() @property def fingerprint(self) -> Fingerprint: - """ - Get the crew's fingerprint. + """Get the crew's fingerprint. Returns: Fingerprint: The crew's fingerprint + """ return self.security_config.fingerprint - def _setup_from_config(self): + def _setup_from_config(self) -> None: assert self.config is not None, "Config should not be None." """Initializes agents and tasks from the provided config.""" if not self.config.get("agents") or not self.config.get("tasks"): + msg = "missing_keys_in_config" raise PydanticCustomError( - "missing_keys_in_config", "Config should have 'agents' and 'tasks'.", {} + msg, "Config should have 'agents' and 'tasks'.", {}, ) self.process = self.config.get("process", self.process) self.agents = [Agent(**agent) for agent in self.config["agents"]] self.tasks = [self._create_task(task) for task in self.config["tasks"]] - def _create_task(self, task_config: Dict[str, Any]) -> Task: + def _create_task(self, task_config: dict[str, Any]) -> Task: """Creates a task instance from its configuration. Args: @@ -544,6 +551,7 @@ class Crew(FlowTrackable, BaseModel): Returns: A task instance. + """ task_agent = next( agt for agt in self.agents if agt.role == task_config["agent"] @@ -565,9 +573,11 @@ class Crew(FlowTrackable, BaseModel): CrewTrainingHandler(filename).initialize_file() def train( - self, n_iterations: int, filename: str, inputs: Optional[Dict[str, Any]] = {} + self, n_iterations: int, filename: str, inputs: dict[str, Any] | None = None, ) -> None: """Trains the crew for a given number of iterations.""" + if inputs is None: + inputs = {} try: crewai_event_bus.emit( self, @@ -590,10 +600,10 @@ class Crew(FlowTrackable, BaseModel): for agent in train_crew.agents: if training_data.get(str(agent.id)): result = TaskEvaluator(agent).evaluate_training_data( - training_data=training_data, agent_id=str(agent.id) + training_data=training_data, agent_id=str(agent.id), ) CrewTrainingHandler(filename).save_trained_data( - agent_id=str(agent.role), trained_data=result.model_dump() + agent_id=str(agent.role), trained_data=result.model_dump(), ) crewai_event_bus.emit( @@ -616,7 +626,7 @@ class Crew(FlowTrackable, BaseModel): def kickoff( self, - inputs: Optional[Dict[str, Any]] = None, + inputs: dict[str, Any] | None = None, ) -> CrewOutput: try: for before_callback in self.before_kickoff_callbacks: @@ -657,15 +667,16 @@ class Crew(FlowTrackable, BaseModel): if self.planning: self._handle_crew_planning() - metrics: List[UsageMetrics] = [] + metrics: list[UsageMetrics] = [] if self.process == Process.sequential: result = self._run_sequential_process() elif self.process == Process.hierarchical: result = self._run_hierarchical_process() else: + msg = f"The process '{self.process}' is not implemented yet." raise NotImplementedError( - f"The process '{self.process}' is not implemented yet." + msg, ) for after_callback in self.after_kickoff_callbacks: @@ -684,9 +695,9 @@ class Crew(FlowTrackable, BaseModel): ) raise - def kickoff_for_each(self, inputs: List[Dict[str, Any]]) -> List[CrewOutput]: + def kickoff_for_each(self, inputs: list[dict[str, Any]]) -> list[CrewOutput]: """Executes the Crew's workflow for each input in the list and aggregates results.""" - results: List[CrewOutput] = [] + results: list[CrewOutput] = [] # Initialize the parent crew's usage metrics total_usage_metrics = UsageMetrics() @@ -705,11 +716,13 @@ class Crew(FlowTrackable, BaseModel): self._task_output_handler.reset() return results - async def kickoff_async(self, inputs: Optional[Dict[str, Any]] = {}) -> CrewOutput: + async def kickoff_async(self, inputs: dict[str, Any] | None = None) -> CrewOutput: """Asynchronous kickoff method to start the crew execution.""" + if inputs is None: + inputs = {} return await asyncio.to_thread(self.kickoff, inputs) - async def kickoff_for_each_async(self, inputs: List[Dict]) -> List[CrewOutput]: + async def kickoff_for_each_async(self, inputs: list[dict]) -> list[CrewOutput]: crew_copies = [self.copy() for _ in inputs] async def run_crew(crew, input_data): @@ -731,14 +744,14 @@ class Crew(FlowTrackable, BaseModel): self._task_output_handler.reset() return results - def _handle_crew_planning(self): + def _handle_crew_planning(self) -> None: """Handles the Crew planning.""" self._logger.log("info", "Planning the crew execution") result = CrewPlanner( - tasks=self.tasks, planning_agent_llm=self.planning_llm + tasks=self.tasks, planning_agent_llm=self.planning_llm, )._handle_crew_planning() - for task, step_plan in zip(self.tasks, result.list_of_plans_per_task): + for task, step_plan in zip(self.tasks, result.list_of_plans_per_task, strict=False): task.description += step_plan.plan def _store_execution_log( @@ -747,11 +760,8 @@ class Crew(FlowTrackable, BaseModel): output: TaskOutput, task_index: int, was_replayed: bool = False, - ): - if self._inputs: - inputs = self._inputs - else: - inputs = {} + ) -> None: + inputs = self._inputs if self._inputs else {} log = { "task": task, @@ -779,17 +789,18 @@ class Crew(FlowTrackable, BaseModel): self._create_manager_agent() return self._execute_tasks(self.tasks) - def _create_manager_agent(self): + def _create_manager_agent(self) -> None: i18n = I18N(prompt_file=self.prompt_file) if self.manager_agent is not None: self.manager_agent.allow_delegation = True manager = self.manager_agent if manager.tools is not None and len(manager.tools) > 0: self._logger.log( - "warning", "Manager agent should not have tools", color="orange" + "warning", "Manager agent should not have tools", color="orange", ) manager.tools = [] - raise Exception("Manager agent should not have tools") + msg = "Manager agent should not have tools" + raise Exception(msg) else: self.manager_llm = create_llm(self.manager_llm) manager = Agent( @@ -806,8 +817,8 @@ class Crew(FlowTrackable, BaseModel): def _execute_tasks( self, - tasks: List[Task], - start_index: Optional[int] = 0, + tasks: list[Task], + start_index: int | None = 0, was_replayed: bool = False, ) -> CrewOutput: """Executes tasks sequentially and returns the final output. @@ -818,11 +829,11 @@ class Crew(FlowTrackable, BaseModel): Returns: CrewOutput: Final output of the crew - """ - task_outputs: List[TaskOutput] = [] - futures: List[Tuple[Task, Future[TaskOutput], int]] = [] - last_sync_output: Optional[TaskOutput] = None + """ + task_outputs: list[TaskOutput] = [] + futures: list[tuple[Task, Future[TaskOutput], int]] = [] + last_sync_output: TaskOutput | None = None for task_index, task in enumerate(tasks): if start_index is not None and task_index < start_index: @@ -836,8 +847,9 @@ class Crew(FlowTrackable, BaseModel): agent_to_use = self._get_agent_to_use(task) if agent_to_use is None: + msg = f"No agent available for task: {task.description}. Ensure that either the task has an assigned agent or a manager agent is provided." raise ValueError( - f"No agent available for task: {task.description}. Ensure that either the task has an assigned agent or a manager agent is provided." + msg, ) # Determine which tools to use - task tools take precedence over agent tools @@ -846,14 +858,14 @@ class Crew(FlowTrackable, BaseModel): tools_for_task = self._prepare_tools( agent_to_use, task, - cast(Union[List[Tool], List[BaseTool]], tools_for_task), + cast("list[Tool] | list[BaseTool]", tools_for_task), ) self._log_task_start(task, agent_to_use.role) if isinstance(task, ConditionalTask): skipped_task_output = self._handle_conditional_task( - task, task_outputs, futures, task_index, was_replayed + task, task_outputs, futures, task_index, was_replayed, ) if skipped_task_output: task_outputs.append(skipped_task_output) @@ -861,12 +873,12 @@ class Crew(FlowTrackable, BaseModel): if task.async_execution: context = self._get_context( - task, [last_sync_output] if last_sync_output else [] + task, [last_sync_output] if last_sync_output else [], ) future = task.execute_async( agent=agent_to_use, context=context, - tools=cast(List[BaseTool], tools_for_task), + tools=cast("list[BaseTool]", tools_for_task), ) futures.append((task, future, task_index)) else: @@ -878,7 +890,7 @@ class Crew(FlowTrackable, BaseModel): task_output = task.execute_sync( agent=agent_to_use, context=context, - tools=cast(List[BaseTool], tools_for_task), + tools=cast("list[BaseTool]", tools_for_task), ) task_outputs.append(task_output) self._process_task_result(task, task_output) @@ -892,11 +904,11 @@ class Crew(FlowTrackable, BaseModel): def _handle_conditional_task( self, task: ConditionalTask, - task_outputs: List[TaskOutput], - futures: List[Tuple[Task, Future[TaskOutput], int]], + task_outputs: list[TaskOutput], + futures: list[tuple[Task, Future[TaskOutput], int]], task_index: int, was_replayed: bool, - ) -> Optional[TaskOutput]: + ) -> TaskOutput | None: if futures: task_outputs = self._process_async_tasks(futures, was_replayed) futures.clear() @@ -916,18 +928,19 @@ class Crew(FlowTrackable, BaseModel): return None def _prepare_tools( - self, agent: BaseAgent, task: Task, tools: Union[List[Tool], List[BaseTool]] - ) -> List[BaseTool]: + self, agent: BaseAgent, task: Task, tools: list[Tool] | list[BaseTool], + ) -> list[BaseTool]: # Add delegation tools if agent allows delegation if hasattr(agent, "allow_delegation") and getattr( - agent, "allow_delegation", False + agent, "allow_delegation", False, ): if self.process == Process.hierarchical: if self.manager_agent: tools = self._update_manager_tools(task, tools) else: + msg = "Manager agent is required for hierarchical process." raise ValueError( - "Manager agent is required for hierarchical process." + msg, ) elif agent: @@ -935,7 +948,7 @@ class Crew(FlowTrackable, BaseModel): # Add code execution tools if agent allows code execution if hasattr(agent, "allow_code_execution") and getattr( - agent, "allow_code_execution", False + agent, "allow_code_execution", False, ): tools = self._add_code_execution_tools(agent, tools) @@ -947,21 +960,21 @@ class Crew(FlowTrackable, BaseModel): tools = self._add_multimodal_tools(agent, tools) # Return a List[BaseTool] which is compatible with both Task.execute_sync and Task.execute_async - return cast(List[BaseTool], tools) + return cast("list[BaseTool]", tools) - def _get_agent_to_use(self, task: Task) -> Optional[BaseAgent]: + def _get_agent_to_use(self, task: Task) -> BaseAgent | None: if self.process == Process.hierarchical: return self.manager_agent return task.agent def _merge_tools( self, - existing_tools: Union[List[Tool], List[BaseTool]], - new_tools: Union[List[Tool], List[BaseTool]], - ) -> List[BaseTool]: + existing_tools: list[Tool] | list[BaseTool], + new_tools: list[Tool] | list[BaseTool], + ) -> list[BaseTool]: """Merge new tools into existing tools list, avoiding duplicates by tool name.""" if not new_tools: - return cast(List[BaseTool], existing_tools) + return cast("list[BaseTool]", existing_tools) # Create mapping of tool names to new tools new_tool_map = {tool.name: tool for tool in new_tools} @@ -972,75 +985,74 @@ class Crew(FlowTrackable, BaseModel): # Add all new tools tools.extend(new_tools) - return cast(List[BaseTool], tools) + return cast("list[BaseTool]", tools) def _inject_delegation_tools( self, - tools: Union[List[Tool], List[BaseTool]], + tools: list[Tool] | list[BaseTool], task_agent: BaseAgent, - agents: List[BaseAgent], - ) -> List[BaseTool]: + agents: list[BaseAgent], + ) -> list[BaseTool]: if hasattr(task_agent, "get_delegation_tools"): delegation_tools = task_agent.get_delegation_tools(agents) # Cast delegation_tools to the expected type for _merge_tools - return self._merge_tools(tools, cast(List[BaseTool], delegation_tools)) - return cast(List[BaseTool], tools) + return self._merge_tools(tools, cast("list[BaseTool]", delegation_tools)) + return cast("list[BaseTool]", tools) def _add_multimodal_tools( - self, agent: BaseAgent, tools: Union[List[Tool], List[BaseTool]] - ) -> List[BaseTool]: + self, agent: BaseAgent, tools: list[Tool] | list[BaseTool], + ) -> list[BaseTool]: if hasattr(agent, "get_multimodal_tools"): multimodal_tools = agent.get_multimodal_tools() # Cast multimodal_tools to the expected type for _merge_tools - return self._merge_tools(tools, cast(List[BaseTool], multimodal_tools)) - return cast(List[BaseTool], tools) + return self._merge_tools(tools, cast("list[BaseTool]", multimodal_tools)) + return cast("list[BaseTool]", tools) def _add_code_execution_tools( - self, agent: BaseAgent, tools: Union[List[Tool], List[BaseTool]] - ) -> List[BaseTool]: + self, agent: BaseAgent, tools: list[Tool] | list[BaseTool], + ) -> list[BaseTool]: if hasattr(agent, "get_code_execution_tools"): code_tools = agent.get_code_execution_tools() # Cast code_tools to the expected type for _merge_tools - return self._merge_tools(tools, cast(List[BaseTool], code_tools)) - return cast(List[BaseTool], tools) + return self._merge_tools(tools, cast("list[BaseTool]", code_tools)) + return cast("list[BaseTool]", tools) def _add_delegation_tools( - self, task: Task, tools: Union[List[Tool], List[BaseTool]] - ) -> List[BaseTool]: + self, task: Task, tools: list[Tool] | list[BaseTool], + ) -> list[BaseTool]: agents_for_delegation = [agent for agent in self.agents if agent != task.agent] if len(self.agents) > 1 and len(agents_for_delegation) > 0 and task.agent: if not tools: tools = [] tools = self._inject_delegation_tools( - tools, task.agent, agents_for_delegation + tools, task.agent, agents_for_delegation, ) - return cast(List[BaseTool], tools) + return cast("list[BaseTool]", tools) - def _log_task_start(self, task: Task, role: str = "None"): + def _log_task_start(self, task: Task, role: str = "None") -> None: if self.output_log_file: self._file_handler.log( - task_name=task.name, task=task.description, agent=role, status="started" + task_name=task.name, task=task.description, agent=role, status="started", ) def _update_manager_tools( - self, task: Task, tools: Union[List[Tool], List[BaseTool]] - ) -> List[BaseTool]: + self, task: Task, tools: list[Tool] | list[BaseTool], + ) -> list[BaseTool]: if self.manager_agent: if task.agent: tools = self._inject_delegation_tools(tools, task.agent, [task.agent]) else: tools = self._inject_delegation_tools( - tools, self.manager_agent, self.agents + tools, self.manager_agent, self.agents, ) - return cast(List[BaseTool], tools) + return cast("list[BaseTool]", tools) - def _get_context(self, task: Task, task_outputs: List[TaskOutput]): - context = ( + def _get_context(self, task: Task, task_outputs: list[TaskOutput]): + return ( aggregate_raw_outputs_from_tasks(task.context) if task.context else aggregate_raw_outputs_from_task_outputs(task_outputs) ) - return context def _process_task_result(self, task: Task, output: TaskOutput) -> None: role = task.agent.role if task.agent is not None else "None" @@ -1053,14 +1065,16 @@ class Crew(FlowTrackable, BaseModel): output=output.raw, ) - def _create_crew_output(self, task_outputs: List[TaskOutput]) -> CrewOutput: + def _create_crew_output(self, task_outputs: list[TaskOutput]) -> CrewOutput: if not task_outputs: - raise ValueError("No task outputs available to create crew output.") + msg = "No task outputs available to create crew output." + raise ValueError(msg) # Filter out empty outputs and get the last valid one as the main output valid_outputs = [t for t in task_outputs if t.raw] if not valid_outputs: - raise ValueError("No valid task outputs available to create crew output.") + msg = "No valid task outputs available to create crew output." + raise ValueError(msg) final_task_output = valid_outputs[-1] final_string_output = final_task_output.raw @@ -1069,7 +1083,7 @@ class Crew(FlowTrackable, BaseModel): crewai_event_bus.emit( self, CrewKickoffCompletedEvent( - crew_name=self.name or "crew", output=final_task_output + crew_name=self.name or "crew", output=final_task_output, ), ) return CrewOutput( @@ -1082,22 +1096,22 @@ class Crew(FlowTrackable, BaseModel): def _process_async_tasks( self, - futures: List[Tuple[Task, Future[TaskOutput], int]], + futures: list[tuple[Task, Future[TaskOutput], int]], was_replayed: bool = False, - ) -> List[TaskOutput]: - task_outputs: List[TaskOutput] = [] + ) -> list[TaskOutput]: + task_outputs: list[TaskOutput] = [] for future_task, future, task_index in futures: task_output = future.result() task_outputs.append(task_output) self._process_task_result(future_task, task_output) self._store_execution_log( - future_task, task_output, task_index, was_replayed + future_task, task_output, task_index, was_replayed, ) return task_outputs def _find_task_index( - self, task_id: str, stored_outputs: List[Any] - ) -> Optional[int]: + self, task_id: str, stored_outputs: list[Any], + ) -> int | None: return next( ( index @@ -1108,16 +1122,18 @@ class Crew(FlowTrackable, BaseModel): ) def replay( - self, task_id: str, inputs: Optional[Dict[str, Any]] = None + self, task_id: str, inputs: dict[str, Any] | None = None, ) -> CrewOutput: stored_outputs = self._task_output_handler.load() if not stored_outputs: - raise ValueError(f"Task with id {task_id} not found in the crew's tasks.") + msg = f"Task with id {task_id} not found in the crew's tasks." + raise ValueError(msg) start_index = self._find_task_index(task_id, stored_outputs) if start_index is None: - raise ValueError(f"Task with id {task_id} not found in the crew's tasks.") + msg = f"Task with id {task_id} not found in the crew's tasks." + raise ValueError(msg) replay_inputs = ( inputs if inputs is not None else stored_outputs[start_index]["inputs"] @@ -1145,28 +1161,26 @@ class Crew(FlowTrackable, BaseModel): self.tasks[i].output = task_output self._logging_color = "bold_blue" - result = self._execute_tasks(self.tasks, start_index, True) - return result + return self._execute_tasks(self.tasks, start_index, True) def query_knowledge( - self, query: List[str], results_limit: int = 3, score_threshold: float = 0.35 - ) -> Union[List[Dict[str, Any]], None]: + self, query: list[str], results_limit: int = 3, score_threshold: float = 0.35, + ) -> list[dict[str, Any]] | None: if self.knowledge: return self.knowledge.query( - query, results_limit=results_limit, score_threshold=score_threshold + query, results_limit=results_limit, score_threshold=score_threshold, ) return None - def fetch_inputs(self) -> Set[str]: - """ - Gathers placeholders (e.g., {something}) referenced in tasks or agents. + def fetch_inputs(self) -> set[str]: + """Gathers placeholders (e.g., {something}) referenced in tasks or agents. Scans each task's 'description' + 'expected_output', and each agent's 'role', 'goal', and 'backstory'. Returns a set of all discovered placeholder names. """ placeholder_pattern = re.compile(r"\{(.+?)\}") - required_inputs: Set[str] = set() + required_inputs: set[str] = set() # Scan tasks for inputs for task in self.tasks: @@ -1183,13 +1197,12 @@ class Crew(FlowTrackable, BaseModel): return required_inputs def copy(self): - """ - Creates a deep copy of the Crew instance. + """Creates a deep copy of the Crew instance. Returns: Crew: A new instance with copied components - """ + """ exclude = { "id", "_rpm_controller", @@ -1225,7 +1238,7 @@ class Crew(FlowTrackable, BaseModel): cloned_tasks.append(cloned_task) task_mapping[task.key] = cloned_task - for cloned_task, original_task in zip(cloned_tasks, self.tasks): + for cloned_task, original_task in zip(cloned_tasks, self.tasks, strict=False): if original_task.context: cloned_context = [ task_mapping[context_task.key] @@ -1237,11 +1250,11 @@ class Crew(FlowTrackable, BaseModel): copied_data = {k: v for k, v in copied_data.items() if v is not None} if self.short_term_memory: copied_data["short_term_memory"] = self.short_term_memory.model_copy( - deep=True + deep=True, ) if self.long_term_memory: copied_data["long_term_memory"] = self.long_term_memory.model_copy( - deep=True + deep=True, ) if self.entity_memory: copied_data["entity_memory"] = self.entity_memory.model_copy(deep=True) @@ -1253,7 +1266,7 @@ class Crew(FlowTrackable, BaseModel): copied_data.pop("agents", None) copied_data.pop("tasks", None) - copied_crew = Crew( + return Crew( **copied_data, agents=cloned_agents, tasks=cloned_tasks, @@ -1263,20 +1276,19 @@ class Crew(FlowTrackable, BaseModel): manager_llm=manager_llm, ) - return copied_crew def _set_tasks_callbacks(self) -> None: - """Sets callback for every task suing task_callback""" + """Sets callback for every task suing task_callback.""" for task in self.tasks: if not task.callback: task.callback = self.task_callback - def _interpolate_inputs(self, inputs: Dict[str, Any]) -> None: + def _interpolate_inputs(self, inputs: dict[str, Any]) -> None: """Interpolates the inputs in the tasks and agents.""" [ task.interpolate_inputs_and_add_conversation_history( # type: ignore # "interpolate_inputs" of "Task" does not return a value (it only ever returns None) - inputs + inputs, ) for task in self.tasks ] @@ -1304,15 +1316,16 @@ class Crew(FlowTrackable, BaseModel): def test( self, n_iterations: int, - eval_llm: Union[str, InstanceOf[BaseLLM]], - inputs: Optional[Dict[str, Any]] = None, + eval_llm: str | InstanceOf[BaseLLM], + inputs: dict[str, Any] | None = None, ) -> None: """Test and evaluate the Crew with the given inputs for n iterations concurrently using concurrent.futures.""" try: # Create LLM instance and ensure it's of type LLM for CrewEvaluator llm_instance = create_llm(eval_llm) if not llm_instance: - raise ValueError("Failed to create LLM instance.") + msg = "Failed to create LLM instance." + raise ValueError(msg) crewai_event_bus.emit( self, @@ -1345,7 +1358,7 @@ class Crew(FlowTrackable, BaseModel): ) raise - def __repr__(self): + def __repr__(self) -> str: return f"Crew(id={self.id}, process={self.process}, number_of_agents={len(self.agents)}, number_of_tasks={len(self.tasks)})" def reset_memories(self, command_type: str) -> None: @@ -1359,6 +1372,7 @@ class Crew(FlowTrackable, BaseModel): Raises: ValueError: If an invalid command type is provided. RuntimeError: If memory reset operation fails. + """ VALID_TYPES = frozenset( [ @@ -1369,12 +1383,13 @@ class Crew(FlowTrackable, BaseModel): "kickoff_outputs", "all", "external", - ] + ], ) if command_type not in VALID_TYPES: + msg = f"Invalid command type. Must be one of: {', '.join(sorted(VALID_TYPES))}" raise ValueError( - f"Invalid command type. Must be one of: {', '.join(sorted(VALID_TYPES))}" + msg, ) try: @@ -1384,7 +1399,7 @@ class Crew(FlowTrackable, BaseModel): self._reset_specific_memory(command_type) except Exception as e: - error_msg = f"Failed to reset {command_type} memory: {str(e)}" + error_msg = f"Failed to reset {command_type} memory: {e!s}" self._logger.log("error", error_msg) raise RuntimeError(error_msg) from e @@ -1408,8 +1423,9 @@ class Crew(FlowTrackable, BaseModel): f"[Crew ({self.name if self.name else self.id})] {name} memory has been reset", ) except Exception as e: + msg = f"[Crew ({self.name if self.name else self.id})] Failed to reset {name} memory: {e!s}" raise RuntimeError( - f"[Crew ({self.name if self.name else self.id})] Failed to reset {name} memory: {str(e)}" + msg, ) from e def _reset_specific_memory(self, memory_type: str) -> None: @@ -1420,6 +1436,7 @@ class Crew(FlowTrackable, BaseModel): Raises: RuntimeError: If the specified memory system fails to reset + """ reset_functions = { "long": (getattr(self, "_long_term_memory", None), "long term"), @@ -1435,7 +1452,8 @@ class Crew(FlowTrackable, BaseModel): memory_system, name = reset_functions[memory_type] if memory_system is None: - raise RuntimeError(f"{name} memory system is not initialized") + msg = f"{name} memory system is not initialized" + raise RuntimeError(msg) try: memory_system.reset() @@ -1444,6 +1462,7 @@ class Crew(FlowTrackable, BaseModel): f"[Crew ({self.name if self.name else self.id})] {name} memory has been reset", ) except Exception as e: + msg = f"[Crew ({self.name if self.name else self.id})] Failed to reset {name} memory: {e!s}" raise RuntimeError( - f"[Crew ({self.name if self.name else self.id})] Failed to reset {name} memory: {str(e)}" + msg, ) from e diff --git a/src/crewai/crews/crew_output.py b/src/crewai/crews/crew_output.py index c9a92a0d0..13102956f 100644 --- a/src/crewai/crews/crew_output.py +++ b/src/crewai/crews/crew_output.py @@ -1,5 +1,5 @@ import json -from typing import Any, Dict, Optional +from typing import Any from pydantic import BaseModel, Field @@ -12,27 +12,28 @@ class CrewOutput(BaseModel): """Class that represents the result of a crew.""" raw: str = Field(description="Raw output of crew", default="") - pydantic: Optional[BaseModel] = Field( - description="Pydantic output of Crew", default=None + pydantic: BaseModel | None = Field( + description="Pydantic output of Crew", default=None, ) - json_dict: Optional[Dict[str, Any]] = Field( - description="JSON dict output of Crew", default=None + json_dict: dict[str, Any] | None = Field( + description="JSON dict output of Crew", default=None, ) tasks_output: list[TaskOutput] = Field( - description="Output of each task", default=[] + description="Output of each task", default=[], ) token_usage: UsageMetrics = Field(description="Processed token summary", default={}) @property - def json(self) -> Optional[str]: + def json(self) -> str | None: if self.tasks_output[-1].output_format != OutputFormat.JSON: + msg = "No JSON output found in the final task. Please make sure to set the output_json property in the final task in your crew." raise ValueError( - "No JSON output found in the final task. Please make sure to set the output_json property in the final task in your crew." + msg, ) return json.dumps(self.json_dict) - def to_dict(self) -> Dict[str, Any]: + def to_dict(self) -> dict[str, Any]: """Convert json_output and pydantic_output to a dictionary.""" output_dict = {} if self.json_dict: @@ -44,12 +45,12 @@ class CrewOutput(BaseModel): def __getitem__(self, key): if self.pydantic and hasattr(self.pydantic, key): return getattr(self.pydantic, key) - elif self.json_dict and key in self.json_dict: + if self.json_dict and key in self.json_dict: return self.json_dict[key] - else: - raise KeyError(f"Key '{key}' not found in CrewOutput.") + msg = f"Key '{key}' not found in CrewOutput." + raise KeyError(msg) - def __str__(self): + def __str__(self) -> str: if self.pydantic: return str(self.pydantic) if self.json_dict: diff --git a/src/crewai/flow/flow.py b/src/crewai/flow/flow.py index 99ae82c96..d6edcd14e 100644 --- a/src/crewai/flow/flow.py +++ b/src/crewai/flow/flow.py @@ -2,17 +2,11 @@ import asyncio import copy import inspect import logging +from collections.abc import Callable from typing import ( Any, - Callable, - Dict, Generic, - List, - Optional, - Set, - Type, TypeVar, - Union, cast, ) from uuid import uuid4 @@ -48,14 +42,14 @@ class FlowState(BaseModel): # Type variables with explicit bounds T = TypeVar( - "T", bound=Union[Dict[str, Any], BaseModel] + "T", bound=dict[str, Any] | BaseModel, ) # Generic flow state type parameter StateT = TypeVar( - "StateT", bound=Union[Dict[str, Any], BaseModel] + "StateT", bound=dict[str, Any] | BaseModel, ) # State validation type parameter -def ensure_state_type(state: Any, expected_type: Type[StateT]) -> StateT: +def ensure_state_type(state: Any, expected_type: type[StateT]) -> StateT: """Ensure state matches expected type with proper validation. Args: @@ -68,6 +62,7 @@ def ensure_state_type(state: Any, expected_type: Type[StateT]) -> StateT: Raises: TypeError: If state doesn't match expected type ValueError: If state validation fails + """ """Ensure state matches expected type with proper validation. @@ -84,20 +79,22 @@ def ensure_state_type(state: Any, expected_type: Type[StateT]) -> StateT: """ if expected_type is dict: if not isinstance(state, dict): - raise TypeError(f"Expected dict, got {type(state).__name__}") - return cast(StateT, state) + msg = f"Expected dict, got {type(state).__name__}" + raise TypeError(msg) + return cast("StateT", state) if isinstance(expected_type, type) and issubclass(expected_type, BaseModel): if not isinstance(state, expected_type): + msg = f"Expected {expected_type.__name__}, got {type(state).__name__}" raise TypeError( - f"Expected {expected_type.__name__}, got {type(state).__name__}" + msg, ) - return cast(StateT, state) - raise TypeError(f"Invalid expected_type: {expected_type}") + return cast("StateT", state) + msg = f"Invalid expected_type: {expected_type}" + raise TypeError(msg) -def start(condition: Optional[Union[str, dict, Callable]] = None) -> Callable: - """ - Marks a method as a flow's starting point. +def start(condition: str | dict | Callable | None = None) -> Callable: + """Marks a method as a flow's starting point. This decorator designates a method as an entry point for the flow execution. It can optionally specify conditions that trigger the start based on other @@ -135,6 +132,7 @@ def start(condition: Optional[Union[str, dict, Callable]] = None) -> Callable: >>> @start(and_("method1", "method2")) # Start after multiple methods >>> def complex_start(self): ... pass + """ def decorator(func): @@ -154,17 +152,17 @@ def start(condition: Optional[Union[str, dict, Callable]] = None) -> Callable: func.__trigger_methods__ = [condition.__name__] func.__condition_type__ = "OR" else: + msg = "Condition must be a method, string, or a result of or_() or and_()" raise ValueError( - "Condition must be a method, string, or a result of or_() or and_()" + msg, ) return func return decorator -def listen(condition: Union[str, dict, Callable]) -> Callable: - """ - Creates a listener that executes when specified conditions are met. +def listen(condition: str | dict | Callable) -> Callable: + """Creates a listener that executes when specified conditions are met. This decorator sets up a method to execute in response to other method executions in the flow. It supports both simple and complex triggering @@ -197,6 +195,7 @@ def listen(condition: Union[str, dict, Callable]) -> Callable: >>> @listen(or_("success", "failure")) # Listen to multiple methods >>> def handle_completion(self): ... pass + """ def decorator(func): @@ -214,17 +213,17 @@ def listen(condition: Union[str, dict, Callable]) -> Callable: func.__trigger_methods__ = [condition.__name__] func.__condition_type__ = "OR" else: + msg = "Condition must be a method, string, or a result of or_() or and_()" raise ValueError( - "Condition must be a method, string, or a result of or_() or and_()" + msg, ) return func return decorator -def router(condition: Union[str, dict, Callable]) -> Callable: - """ - Creates a routing method that directs flow execution based on conditions. +def router(condition: str | dict | Callable) -> Callable: + """Creates a routing method that directs flow execution based on conditions. This decorator marks a method as a router, which can dynamically determine the next steps in the flow based on its return value. Routers are triggered @@ -262,6 +261,7 @@ def router(condition: Union[str, dict, Callable]) -> Callable: ... if all([self.state.valid, self.state.processed]): ... return CONTINUE ... return STOP + """ def decorator(func): @@ -280,17 +280,17 @@ def router(condition: Union[str, dict, Callable]) -> Callable: func.__trigger_methods__ = [condition.__name__] func.__condition_type__ = "OR" else: + msg = "Condition must be a method, string, or a result of or_() or and_()" raise ValueError( - "Condition must be a method, string, or a result of or_() or and_()" + msg, ) return func return decorator -def or_(*conditions: Union[str, dict, Callable]) -> dict: - """ - Combines multiple conditions with OR logic for flow control. +def or_(*conditions: str | dict | Callable) -> dict: + """Combines multiple conditions with OR logic for flow control. Creates a condition that is satisfied when any of the specified conditions are met. This is used with @start, @listen, or @router decorators to create @@ -320,6 +320,7 @@ def or_(*conditions: Union[str, dict, Callable]) -> dict: >>> @listen(or_("success", "timeout")) >>> def handle_completion(self): ... pass + """ methods = [] for condition in conditions: @@ -330,13 +331,13 @@ def or_(*conditions: Union[str, dict, Callable]) -> dict: elif callable(condition): methods.append(getattr(condition, "__name__", repr(condition))) else: - raise ValueError("Invalid condition in or_()") + msg = "Invalid condition in or_()" + raise ValueError(msg) return {"type": "OR", "methods": methods} -def and_(*conditions: Union[str, dict, Callable]) -> dict: - """ - Combines multiple conditions with AND logic for flow control. +def and_(*conditions: str | dict | Callable) -> dict: + """Combines multiple conditions with AND logic for flow control. Creates a condition that is satisfied only when all specified conditions are met. This is used with @start, @listen, or @router decorators to create @@ -366,6 +367,7 @@ def and_(*conditions: Union[str, dict, Callable]) -> dict: >>> @listen(and_("validated", "processed")) >>> def handle_complete_data(self): ... pass + """ methods = [] for condition in conditions: @@ -376,7 +378,8 @@ def and_(*conditions: Union[str, dict, Callable]) -> dict: elif callable(condition): methods.append(getattr(condition, "__name__", repr(condition))) else: - raise ValueError("Invalid condition in and_()") + msg = "Invalid condition in and_()" + raise ValueError(msg) return {"type": "AND", "methods": methods} @@ -416,10 +419,10 @@ class FlowMeta(type): if possible_returns: router_paths[attr_name] = possible_returns - setattr(cls, "_start_methods", start_methods) - setattr(cls, "_listeners", listeners) - setattr(cls, "_routers", routers) - setattr(cls, "_router_paths", router_paths) + cls._start_methods = start_methods + cls._listeners = listeners + cls._routers = routers + cls._router_paths = router_paths return cls @@ -427,17 +430,18 @@ class FlowMeta(type): class Flow(Generic[T], metaclass=FlowMeta): """Base class for all flows. - Type parameter T must be either Dict[str, Any] or a subclass of BaseModel.""" + Type parameter T must be either Dict[str, Any] or a subclass of BaseModel. + """ _printer = Printer() - _start_methods: List[str] = [] - _listeners: Dict[str, tuple[str, List[str]]] = {} - _routers: Set[str] = set() - _router_paths: Dict[str, List[str]] = {} - initial_state: Union[Type[T], T, None] = None + _start_methods: list[str] = [] + _listeners: dict[str, tuple[str, list[str]]] = {} + _routers: set[str] = set() + _router_paths: dict[str, list[str]] = {} + initial_state: type[T] | T | None = None - def __class_getitem__(cls: Type["Flow"], item: Type[T]) -> Type["Flow"]: + def __class_getitem__(cls: type["Flow"], item: type[T]) -> type["Flow"]: class _FlowGeneric(cls): # type: ignore _initial_state_T = item # type: ignore @@ -446,7 +450,7 @@ class Flow(Generic[T], metaclass=FlowMeta): def __init__( self, - persistence: Optional[FlowPersistence] = None, + persistence: FlowPersistence | None = None, **kwargs: Any, ) -> None: """Initialize a new Flow instance. @@ -454,13 +458,14 @@ class Flow(Generic[T], metaclass=FlowMeta): Args: persistence: Optional persistence backend for storing flow states **kwargs: Additional state values to initialize or override + """ # Initialize basic instance attributes - self._methods: Dict[str, Callable] = {} - self._method_execution_counts: Dict[str, int] = {} - self._pending_and_listeners: Dict[str, Set[str]] = {} - self._method_outputs: List[Any] = [] # List to store all method outputs - self._persistence: Optional[FlowPersistence] = persistence + self._methods: dict[str, Callable] = {} + self._method_execution_counts: dict[str, int] = {} + self._pending_and_listeners: dict[str, set[str]] = {} + self._method_outputs: list[Any] = [] # List to store all method outputs + self._persistence: FlowPersistence | None = persistence # Initialize state with initial values self._state = self._create_initial_state() @@ -502,58 +507,61 @@ class Flow(Generic[T], metaclass=FlowMeta): Raises: ValueError: If structured state model lacks 'id' field TypeError: If state is neither BaseModel nor dictionary + """ # Handle case where initial_state is None but we have a type parameter if self.initial_state is None and hasattr(self, "_initial_state_T"): - state_type = getattr(self, "_initial_state_T") + state_type = self._initial_state_T if isinstance(state_type, type): if issubclass(state_type, FlowState): # Create instance without id, then set it instance = state_type() if not hasattr(instance, "id"): - setattr(instance, "id", str(uuid4())) - return cast(T, instance) - elif issubclass(state_type, BaseModel): + instance.id = str(uuid4()) + return cast("T", instance) + if issubclass(state_type, BaseModel): # Create a new type that includes the ID field class StateWithId(state_type, FlowState): # type: ignore pass instance = StateWithId() if not hasattr(instance, "id"): - setattr(instance, "id", str(uuid4())) - return cast(T, instance) - elif state_type is dict: - return cast(T, {"id": str(uuid4())}) + instance.id = str(uuid4()) + return cast("T", instance) + if state_type is dict: + return cast("T", {"id": str(uuid4())}) # Handle case where no initial state is provided if self.initial_state is None: - return cast(T, {"id": str(uuid4())}) + return cast("T", {"id": str(uuid4())}) # Handle case where initial_state is a type (class) if isinstance(self.initial_state, type): if issubclass(self.initial_state, FlowState): - return cast(T, self.initial_state()) # Uses model defaults - elif issubclass(self.initial_state, BaseModel): + return cast("T", self.initial_state()) # Uses model defaults + if issubclass(self.initial_state, BaseModel): # Validate that the model has an id field model_fields = getattr(self.initial_state, "model_fields", None) if not model_fields or "id" not in model_fields: - raise ValueError("Flow state model must have an 'id' field") - return cast(T, self.initial_state()) # Uses model defaults - elif self.initial_state is dict: - return cast(T, {"id": str(uuid4())}) + msg = "Flow state model must have an 'id' field" + raise ValueError(msg) + return cast("T", self.initial_state()) # Uses model defaults + if self.initial_state is dict: + return cast("T", {"id": str(uuid4())}) # Handle dictionary instance case if isinstance(self.initial_state, dict): new_state = dict(self.initial_state) # Copy to avoid mutations if "id" not in new_state: new_state["id"] = str(uuid4()) - return cast(T, new_state) + return cast("T", new_state) # Handle BaseModel instance case if isinstance(self.initial_state, BaseModel): - model = cast(BaseModel, self.initial_state) + model = cast("BaseModel", self.initial_state) if not hasattr(model, "id"): - raise ValueError("Flow state model must have an 'id' field") + msg = "Flow state model must have an 'id' field" + raise ValueError(msg) # Create new instance with same values to avoid mutations if hasattr(model, "model_dump"): @@ -570,9 +578,10 @@ class Flow(Generic[T], metaclass=FlowMeta): # Create new instance of the same class model_class = type(model) - return cast(T, model_class(**state_dict)) + return cast("T", model_class(**state_dict)) + msg = f"Initial state must be dict or BaseModel, got {type(self.initial_state)}" raise TypeError( - f"Initial state must be dict or BaseModel, got {type(self.initial_state)}" + msg, ) def _copy_state(self) -> T: @@ -583,7 +592,7 @@ class Flow(Generic[T], metaclass=FlowMeta): return self._state @property - def method_outputs(self) -> List[Any]: + def method_outputs(self) -> list[Any]: """Returns the list of all outputs from executed methods.""" return self._method_outputs @@ -607,6 +616,7 @@ class Flow(Generic[T], metaclass=FlowMeta): flow = MyFlow() print(f"Current flow ID: {flow.flow_id}") # Safely get flow ID ``` + """ try: if not hasattr(self, "_state"): @@ -614,13 +624,13 @@ class Flow(Generic[T], metaclass=FlowMeta): if isinstance(self._state, dict): return str(self._state.get("id", "")) - elif isinstance(self._state, BaseModel): + if isinstance(self._state, BaseModel): return str(getattr(self._state, "id", "")) return "" except (AttributeError, TypeError): return "" # Safely handle any unexpected attribute access issues - def _initialize_state(self, inputs: Dict[str, Any]) -> None: + def _initialize_state(self, inputs: dict[str, Any]) -> None: """Initialize or update flow state with new inputs. Args: @@ -629,6 +639,7 @@ class Flow(Generic[T], metaclass=FlowMeta): Raises: ValueError: If validation fails for structured state TypeError: If state is neither BaseModel nor dictionary + """ if isinstance(self._state, dict): # For dict states, preserve existing fields unless overridden @@ -644,7 +655,7 @@ class Flow(Generic[T], metaclass=FlowMeta): elif isinstance(self._state, BaseModel): # For BaseModel states, preserve existing fields unless overridden try: - model = cast(BaseModel, self._state) + model = cast("BaseModel", self._state) # Get current state as dict if hasattr(model, "model_dump"): current_state = model.model_dump() @@ -662,19 +673,21 @@ class Flow(Generic[T], metaclass=FlowMeta): model_class = type(model) if hasattr(model_class, "model_validate"): # Pydantic v2 - self._state = cast(T, model_class.model_validate(new_state)) + self._state = cast("T", model_class.model_validate(new_state)) elif hasattr(model_class, "parse_obj"): # Pydantic v1 - self._state = cast(T, model_class.parse_obj(new_state)) + self._state = cast("T", model_class.parse_obj(new_state)) else: # Fallback for other BaseModel implementations - self._state = cast(T, model_class(**new_state)) + self._state = cast("T", model_class(**new_state)) except ValidationError as e: - raise ValueError(f"Invalid inputs for structured state: {e}") from e + msg = f"Invalid inputs for structured state: {e}" + raise ValueError(msg) from e else: - raise TypeError("State must be a BaseModel instance or a dictionary.") + msg = "State must be a BaseModel instance or a dictionary." + raise TypeError(msg) - def _restore_state(self, stored_state: Dict[str, Any]) -> None: + def _restore_state(self, stored_state: dict[str, Any]) -> None: """Restore flow state from persistence. Args: @@ -683,11 +696,13 @@ class Flow(Generic[T], metaclass=FlowMeta): Raises: ValueError: If validation fails for structured state TypeError: If state is neither BaseModel nor dictionary + """ # When restoring from persistence, use the stored ID stored_id = stored_state.get("id") if not stored_id: - raise ValueError("Stored state must have an 'id' field") + msg = "Stored state must have an 'id' field" + raise ValueError(msg) if isinstance(self._state, dict): # For dict states, update all fields from stored state @@ -695,22 +710,22 @@ class Flow(Generic[T], metaclass=FlowMeta): self._state.update(stored_state) elif isinstance(self._state, BaseModel): # For BaseModel states, create new instance with stored values - model = cast(BaseModel, self._state) + model = cast("BaseModel", self._state) if hasattr(model, "model_validate"): # Pydantic v2 - self._state = cast(T, type(model).model_validate(stored_state)) + self._state = cast("T", type(model).model_validate(stored_state)) elif hasattr(model, "parse_obj"): # Pydantic v1 - self._state = cast(T, type(model).parse_obj(stored_state)) + self._state = cast("T", type(model).parse_obj(stored_state)) else: # Fallback for other BaseModel implementations - self._state = cast(T, type(model)(**stored_state)) + self._state = cast("T", type(model)(**stored_state)) else: - raise TypeError(f"State must be dict or BaseModel, got {type(self._state)}") + msg = f"State must be dict or BaseModel, got {type(self._state)}" + raise TypeError(msg) - def kickoff(self, inputs: Optional[Dict[str, Any]] = None) -> Any: - """ - Start the flow execution in a synchronous context. + def kickoff(self, inputs: dict[str, Any] | None = None) -> Any: + """Start the flow execution in a synchronous context. This method wraps kickoff_async so that all state initialization and event emission is handled in the asynchronous method. @@ -721,9 +736,8 @@ class Flow(Generic[T], metaclass=FlowMeta): return asyncio.run(run_flow()) - async def kickoff_async(self, inputs: Optional[Dict[str, Any]] = None) -> Any: - """ - Start the flow execution asynchronously. + async def kickoff_async(self, inputs: dict[str, Any] | None = None) -> Any: + """Start the flow execution asynchronously. This method performs state restoration (if an 'id' is provided and persistence is available) and updates the flow state with any additional inputs. It then emits the FlowStartedEvent, @@ -735,6 +749,7 @@ class Flow(Generic[T], metaclass=FlowMeta): Returns: The final output from the flow, which is the result of the last executed method. + """ if inputs: # Override the id in the state if it exists in inputs @@ -742,7 +757,7 @@ class Flow(Generic[T], metaclass=FlowMeta): if isinstance(self._state, dict): self._state["id"] = inputs["id"] elif isinstance(self._state, BaseModel): - setattr(self._state, "id", inputs["id"]) + self._state.id = inputs["id"] # If persistence is enabled, attempt to restore the stored state using the provided id. if "id" in inputs and self._persistence is not None: @@ -756,7 +771,7 @@ class Flow(Generic[T], metaclass=FlowMeta): self._restore_state(stored_state) else: self._log_flow_event( - f"No flow state found for UUID: {restore_uuid}", color="red" + f"No flow state found for UUID: {restore_uuid}", color="red", ) # Update state with any additional inputs (ignoring the 'id' key) @@ -774,7 +789,7 @@ class Flow(Generic[T], metaclass=FlowMeta): ), ) self._log_flow_event( - f"Flow started with ID: {self.flow_id}", color="bold_magenta" + f"Flow started with ID: {self.flow_id}", color="bold_magenta", ) if inputs is not None and "id" not in inputs: @@ -800,8 +815,7 @@ class Flow(Generic[T], metaclass=FlowMeta): return final_output async def _execute_start_method(self, start_method_name: str) -> None: - """ - Executes a flow's start method and its triggered listeners. + """Executes a flow's start method and its triggered listeners. This internal method handles the execution of methods marked with @start decorator and manages the subsequent chain of listener executions. @@ -816,14 +830,15 @@ class Flow(Generic[T], metaclass=FlowMeta): - Executes the start method and captures its result - Triggers execution of any listeners waiting on this start method - Part of the flow's initialization sequence + """ result = await self._execute_method( - start_method_name, self._methods[start_method_name] + start_method_name, self._methods[start_method_name], ) await self._execute_listeners(start_method_name, result) async def _execute_method( - self, method_name: str, method: Callable, *args: Any, **kwargs: Any + self, method_name: str, method: Callable, *args: Any, **kwargs: Any, ) -> Any: try: dumped_params = {f"_{i}": arg for i, arg in enumerate(args)} | ( @@ -873,11 +888,10 @@ class Flow(Generic[T], metaclass=FlowMeta): error=e, ), ) - raise e + raise async def _execute_listeners(self, trigger_method: str, result: Any) -> None: - """ - Executes all listeners and routers triggered by a method completion. + """Executes all listeners and routers triggered by a method completion. This internal method manages the execution flow by: 1. First executing all triggered routers sequentially @@ -897,6 +911,7 @@ class Flow(Generic[T], metaclass=FlowMeta): - Each router's result becomes a new trigger_method - Normal listeners are executed in parallel for efficiency - Listeners can receive the trigger method's result as a parameter + """ # First, handle routers repeatedly until no router triggers anymore router_results = [] @@ -904,7 +919,7 @@ class Flow(Generic[T], metaclass=FlowMeta): while True: routers_triggered = self._find_triggered_methods( - current_trigger, router_only=True + current_trigger, router_only=True, ) if not routers_triggered: break @@ -920,12 +935,12 @@ class Flow(Generic[T], metaclass=FlowMeta): ) # Now execute normal listeners for all router results and the original trigger - all_triggers = [trigger_method] + router_results + all_triggers = [trigger_method, *router_results] for current_trigger in all_triggers: if current_trigger: # Skip None results listeners_triggered = self._find_triggered_methods( - current_trigger, router_only=False + current_trigger, router_only=False, ) if listeners_triggered: tasks = [ @@ -935,10 +950,9 @@ class Flow(Generic[T], metaclass=FlowMeta): await asyncio.gather(*tasks) def _find_triggered_methods( - self, trigger_method: str, router_only: bool - ) -> List[str]: - """ - Finds all methods that should be triggered based on conditions. + self, trigger_method: str, router_only: bool, + ) -> list[str]: + """Finds all methods that should be triggered based on conditions. This internal method evaluates both OR and AND conditions to determine which methods should be executed next in the flow. @@ -963,6 +977,7 @@ class Flow(Generic[T], metaclass=FlowMeta): * AND: Triggers only when all conditions are met - Maintains state for AND conditions using _pending_and_listeners - Separates router and normal listener evaluation + """ triggered = [] for listener_name, (condition_type, methods) in self._listeners.items(): @@ -992,8 +1007,7 @@ class Flow(Generic[T], metaclass=FlowMeta): return triggered async def _execute_single_listener(self, listener_name: str, result: Any) -> None: - """ - Executes a single listener method with proper event handling. + """Executes a single listener method with proper event handling. This internal method manages the execution of an individual listener, including parameter inspection, event emission, and error handling. @@ -1018,6 +1032,7 @@ class Flow(Generic[T], metaclass=FlowMeta): ------------- Catches and logs any exceptions during execution, preventing individual listener failures from breaking the entire flow. + """ try: method = self._methods[listener_name] @@ -1028,7 +1043,7 @@ class Flow(Generic[T], metaclass=FlowMeta): if method_params: listener_result = await self._execute_method( - listener_name, method, result + listener_name, method, result, ) else: listener_result = await self._execute_method(listener_name, method) @@ -1036,17 +1051,14 @@ class Flow(Generic[T], metaclass=FlowMeta): # Execute listeners (and possibly routers) of this listener await self._execute_listeners(listener_name, listener_result) - except Exception as e: - print( - f"[Flow._execute_single_listener] Error in method {listener_name}: {e}" - ) + except Exception: import traceback traceback.print_exc() raise def _log_flow_event( - self, message: str, color: str = "yellow", level: str = "info" + self, message: str, color: str = "yellow", level: str = "info", ) -> None: """Centralized logging method for flow events. @@ -1064,6 +1076,7 @@ class Flow(Generic[T], metaclass=FlowMeta): Note: This method uses the Printer utility for colored console output and the standard logging module for log level support. + """ self._printer.print(message, color=color) if level == "info": diff --git a/src/crewai/flow/flow_trackable.py b/src/crewai/flow/flow_trackable.py index 64e90630c..d333a535e 100644 --- a/src/crewai/flow/flow_trackable.py +++ b/src/crewai/flow/flow_trackable.py @@ -1,5 +1,4 @@ import inspect -from typing import Optional from pydantic import BaseModel, Field, InstanceOf, model_validator @@ -14,7 +13,7 @@ class FlowTrackable(BaseModel): inspecting the call stack. """ - parent_flow: Optional[InstanceOf[Flow]] = Field( + parent_flow: InstanceOf[Flow] | None = Field( default=None, description="The parent flow of the instance, if it was created inside a flow.", ) diff --git a/src/crewai/flow/flow_visualizer.py b/src/crewai/flow/flow_visualizer.py index a70e91a18..e71d7ac44 100644 --- a/src/crewai/flow/flow_visualizer.py +++ b/src/crewai/flow/flow_visualizer.py @@ -1,14 +1,13 @@ # flow_visualizer.py import os -from pathlib import Path from pyvis.network import Network from crewai.flow.config import COLORS, NODE_STYLES from crewai.flow.html_template_handler import HTMLTemplateHandler from crewai.flow.legend_generator import generate_legend_items_html, get_legend_items -from crewai.flow.path_utils import safe_path_join, validate_path_exists +from crewai.flow.path_utils import safe_path_join from crewai.flow.utils import calculate_node_levels from crewai.flow.visualization_utils import ( add_edges, @@ -20,9 +19,8 @@ from crewai.flow.visualization_utils import ( class FlowPlot: """Handles the creation and rendering of flow visualization diagrams.""" - def __init__(self, flow): - """ - Initialize FlowPlot with a flow object. + def __init__(self, flow) -> None: + """Initialize FlowPlot with a flow object. Parameters ---------- @@ -33,21 +31,24 @@ class FlowPlot: ------ ValueError If flow object is invalid or missing required attributes. + """ - if not hasattr(flow, '_methods'): - raise ValueError("Invalid flow object: missing '_methods' attribute") - if not hasattr(flow, '_listeners'): - raise ValueError("Invalid flow object: missing '_listeners' attribute") - if not hasattr(flow, '_start_methods'): - raise ValueError("Invalid flow object: missing '_start_methods' attribute") - + if not hasattr(flow, "_methods"): + msg = "Invalid flow object: missing '_methods' attribute" + raise ValueError(msg) + if not hasattr(flow, "_listeners"): + msg = "Invalid flow object: missing '_listeners' attribute" + raise ValueError(msg) + if not hasattr(flow, "_start_methods"): + msg = "Invalid flow object: missing '_start_methods' attribute" + raise ValueError(msg) + self.flow = flow self.colors = COLORS self.node_styles = NODE_STYLES - def plot(self, filename): - """ - Generate and save an HTML visualization of the flow. + def plot(self, filename) -> None: + """Generate and save an HTML visualization of the flow. Parameters ---------- @@ -62,10 +63,12 @@ class FlowPlot: If file operations fail or visualization cannot be generated. RuntimeError If network visualization generation fails. + """ if not filename or not isinstance(filename, str): - raise ValueError("Filename must be a non-empty string") - + msg = "Filename must be a non-empty string" + raise ValueError(msg) + try: # Initialize network net = Network( @@ -89,58 +92,63 @@ class FlowPlot: "enabled": false } } - """ + """, ) # Calculate levels for nodes try: node_levels = calculate_node_levels(self.flow) except Exception as e: - raise ValueError(f"Failed to calculate node levels: {str(e)}") + msg = f"Failed to calculate node levels: {e!s}" + raise ValueError(msg) # Compute positions try: node_positions = compute_positions(self.flow, node_levels) except Exception as e: - raise ValueError(f"Failed to compute node positions: {str(e)}") + msg = f"Failed to compute node positions: {e!s}" + raise ValueError(msg) # Add nodes to the network try: add_nodes_to_network(net, self.flow, node_positions, self.node_styles) except Exception as e: - raise RuntimeError(f"Failed to add nodes to network: {str(e)}") + msg = f"Failed to add nodes to network: {e!s}" + raise RuntimeError(msg) # Add edges to the network try: add_edges(net, self.flow, node_positions, self.colors) except Exception as e: - raise RuntimeError(f"Failed to add edges to network: {str(e)}") + msg = f"Failed to add edges to network: {e!s}" + raise RuntimeError(msg) # Generate HTML try: network_html = net.generate_html() final_html_content = self._generate_final_html(network_html) except Exception as e: - raise RuntimeError(f"Failed to generate network visualization: {str(e)}") + msg = f"Failed to generate network visualization: {e!s}" + raise RuntimeError(msg) # Save the final HTML content to the file try: with open(f"{filename}.html", "w", encoding="utf-8") as f: f.write(final_html_content) - print(f"Plot saved as {filename}.html") - except IOError as e: - raise IOError(f"Failed to save flow visualization to {filename}.html: {str(e)}") + except OSError as e: + msg = f"Failed to save flow visualization to {filename}.html: {e!s}" + raise OSError(msg) - except (ValueError, RuntimeError, IOError) as e: - raise e + except (OSError, ValueError, RuntimeError): + raise except Exception as e: - raise RuntimeError(f"Unexpected error during flow visualization: {str(e)}") + msg = f"Unexpected error during flow visualization: {e!s}" + raise RuntimeError(msg) finally: self._cleanup_pyvis_lib() def _generate_final_html(self, network_html): - """ - Generate the final HTML content with network visualization and legend. + """Generate the final HTML content with network visualization and legend. Parameters ---------- @@ -158,9 +166,11 @@ class FlowPlot: If template or logo files cannot be accessed. ValueError If network_html is invalid. + """ if not network_html: - raise ValueError("Invalid network HTML content") + msg = "Invalid network HTML content" + raise ValueError(msg) try: # Extract just the body content from the generated HTML @@ -169,9 +179,11 @@ class FlowPlot: logo_path = safe_path_join("assets", "crewai_logo.svg", root=current_dir) if not os.path.exists(template_path): - raise IOError(f"Template file not found: {template_path}") + msg = f"Template file not found: {template_path}" + raise OSError(msg) if not os.path.exists(logo_path): - raise IOError(f"Logo file not found: {logo_path}") + msg = f"Logo file not found: {logo_path}" + raise OSError(msg) html_handler = HTMLTemplateHandler(template_path, logo_path) network_body = html_handler.extract_body_content(network_html) @@ -179,16 +191,15 @@ class FlowPlot: # Generate the legend items HTML legend_items = get_legend_items(self.colors) legend_items_html = generate_legend_items_html(legend_items) - final_html_content = html_handler.generate_final_html( - network_body, legend_items_html + return html_handler.generate_final_html( + network_body, legend_items_html, ) - return final_html_content except Exception as e: - raise IOError(f"Failed to generate visualization HTML: {str(e)}") + msg = f"Failed to generate visualization HTML: {e!s}" + raise OSError(msg) - def _cleanup_pyvis_lib(self): - """ - Clean up the generated lib folder from pyvis. + def _cleanup_pyvis_lib(self) -> None: + """Clean up the generated lib folder from pyvis. This method safely removes the temporary lib directory created by pyvis during network visualization generation. @@ -198,15 +209,14 @@ class FlowPlot: if os.path.exists(lib_folder) and os.path.isdir(lib_folder): import shutil shutil.rmtree(lib_folder) - except ValueError as e: - print(f"Error validating lib folder path: {e}") - except Exception as e: - print(f"Error cleaning up lib folder: {e}") + except ValueError: + pass + except Exception: + pass -def plot_flow(flow, filename="flow_plot"): - """ - Convenience function to create and save a flow visualization. +def plot_flow(flow, filename="flow_plot") -> None: + """Convenience function to create and save a flow visualization. Parameters ---------- @@ -221,6 +231,7 @@ def plot_flow(flow, filename="flow_plot"): If flow object or filename is invalid. IOError If file operations fail. + """ visualizer = FlowPlot(flow) visualizer.plot(filename) diff --git a/src/crewai/flow/html_template_handler.py b/src/crewai/flow/html_template_handler.py index f0d2d89ad..85577f499 100644 --- a/src/crewai/flow/html_template_handler.py +++ b/src/crewai/flow/html_template_handler.py @@ -1,16 +1,14 @@ import base64 import re -from pathlib import Path -from crewai.flow.path_utils import safe_path_join, validate_path_exists +from crewai.flow.path_utils import validate_path_exists class HTMLTemplateHandler: """Handles HTML template processing and generation for flow visualization diagrams.""" - def __init__(self, template_path, logo_path): - """ - Initialize HTMLTemplateHandler with validated template and logo paths. + def __init__(self, template_path, logo_path) -> None: + """Initialize HTMLTemplateHandler with validated template and logo paths. Parameters ---------- @@ -23,16 +21,18 @@ class HTMLTemplateHandler: ------ ValueError If template or logo paths are invalid or files don't exist. + """ try: self.template_path = validate_path_exists(template_path, "file") self.logo_path = validate_path_exists(logo_path, "file") except ValueError as e: - raise ValueError(f"Invalid template or logo path: {e}") + msg = f"Invalid template or logo path: {e}" + raise ValueError(msg) def read_template(self): """Read and return the HTML template file contents.""" - with open(self.template_path, "r", encoding="utf-8") as f: + with open(self.template_path, encoding="utf-8") as f: return f.read() def encode_logo(self): @@ -81,13 +81,12 @@ class HTMLTemplateHandler: final_html_content = html_template.replace("{{ title }}", title) final_html_content = final_html_content.replace( - "{{ network_content }}", network_body + "{{ network_content }}", network_body, ) final_html_content = final_html_content.replace( - "{{ logo_svg_base64 }}", logo_svg_base64 + "{{ logo_svg_base64 }}", logo_svg_base64, ) - final_html_content = final_html_content.replace( - "", legend_items_html + return final_html_content.replace( + "", legend_items_html, ) - return final_html_content diff --git a/src/crewai/flow/path_utils.py b/src/crewai/flow/path_utils.py index 09ae8cd3d..04098a686 100644 --- a/src/crewai/flow/path_utils.py +++ b/src/crewai/flow/path_utils.py @@ -1,18 +1,14 @@ -""" -Path utilities for secure file operations in CrewAI flow module. +"""Path utilities for secure file operations in CrewAI flow module. This module provides utilities for secure path handling to prevent directory traversal attacks and ensure paths remain within allowed boundaries. """ -import os from pathlib import Path -from typing import List, Union -def safe_path_join(*parts: str, root: Union[str, Path, None] = None) -> str: - """ - Safely join path components and ensure the result is within allowed boundaries. +def safe_path_join(*parts: str, root: str | Path | None = None) -> str: + """Safely join path components and ensure the result is within allowed boundaries. Parameters ---------- @@ -31,39 +27,43 @@ def safe_path_join(*parts: str, root: Union[str, Path, None] = None) -> str: ValueError If the resulting path would be outside the root directory or if any path component is invalid. + """ if not parts: - raise ValueError("No path components provided") + msg = "No path components provided" + raise ValueError(msg) try: # Convert all parts to strings and clean them clean_parts = [str(part).strip() for part in parts if part] if not clean_parts: - raise ValueError("No valid path components provided") + msg = "No valid path components provided" + raise ValueError(msg) # Establish root directory root_path = Path(root).resolve() if root else Path.cwd() - + # Join and resolve the full path full_path = Path(root_path, *clean_parts).resolve() - + # Check if the resolved path is within root if not str(full_path).startswith(str(root_path)): + msg = f"Invalid path: Potential directory traversal. Path must be within {root_path}" raise ValueError( - f"Invalid path: Potential directory traversal. Path must be within {root_path}" + msg, ) - + return str(full_path) - + except Exception as e: if isinstance(e, ValueError): raise - raise ValueError(f"Invalid path components: {str(e)}") + msg = f"Invalid path components: {e!s}" + raise ValueError(msg) -def validate_path_exists(path: Union[str, Path], file_type: str = "file") -> str: - """ - Validate that a path exists and is of the expected type. +def validate_path_exists(path: str | Path, file_type: str = "file") -> str: + """Validate that a path exists and is of the expected type. Parameters ---------- @@ -81,29 +81,33 @@ def validate_path_exists(path: Union[str, Path], file_type: str = "file") -> str ------ ValueError If path doesn't exist or is not of expected type. + """ try: path_obj = Path(path).resolve() - + if not path_obj.exists(): - raise ValueError(f"Path does not exist: {path}") - + msg = f"Path does not exist: {path}" + raise ValueError(msg) + if file_type == "file" and not path_obj.is_file(): - raise ValueError(f"Path is not a file: {path}") - elif file_type == "directory" and not path_obj.is_dir(): - raise ValueError(f"Path is not a directory: {path}") - + msg = f"Path is not a file: {path}" + raise ValueError(msg) + if file_type == "directory" and not path_obj.is_dir(): + msg = f"Path is not a directory: {path}" + raise ValueError(msg) + return str(path_obj) - + except Exception as e: if isinstance(e, ValueError): raise - raise ValueError(f"Invalid path: {str(e)}") + msg = f"Invalid path: {e!s}" + raise ValueError(msg) -def list_files(directory: Union[str, Path], pattern: str = "*") -> List[str]: - """ - Safely list files in a directory matching a pattern. +def list_files(directory: str | Path, pattern: str = "*") -> list[str]: + """Safely list files in a directory matching a pattern. Parameters ---------- @@ -121,15 +125,18 @@ def list_files(directory: Union[str, Path], pattern: str = "*") -> List[str]: ------ ValueError If directory is invalid or inaccessible. + """ try: dir_path = Path(directory).resolve() if not dir_path.is_dir(): - raise ValueError(f"Not a directory: {directory}") - + msg = f"Not a directory: {directory}" + raise ValueError(msg) + return [str(p) for p in dir_path.glob(pattern) if p.is_file()] - + except Exception as e: if isinstance(e, ValueError): raise - raise ValueError(f"Error listing files: {str(e)}") + msg = f"Error listing files: {e!s}" + raise ValueError(msg) diff --git a/src/crewai/flow/persistence/base.py b/src/crewai/flow/persistence/base.py index c926f6f34..901aa5b57 100644 --- a/src/crewai/flow/persistence/base.py +++ b/src/crewai/flow/persistence/base.py @@ -1,53 +1,52 @@ """Base class for flow state persistence.""" import abc -from typing import Any, Dict, Optional, Union +from typing import Any from pydantic import BaseModel class FlowPersistence(abc.ABC): """Abstract base class for flow state persistence. - + This class defines the interface that all persistence implementations must follow. It supports both structured (Pydantic BaseModel) and unstructured (dict) states. """ - + @abc.abstractmethod def init_db(self) -> None: """Initialize the persistence backend. - + This method should handle any necessary setup, such as: - Creating tables - Establishing connections - Setting up indexes """ - pass - + @abc.abstractmethod def save_state( self, flow_uuid: str, method_name: str, - state_data: Union[Dict[str, Any], BaseModel] + state_data: dict[str, Any] | BaseModel, ) -> None: """Persist the flow state after method completion. - + Args: flow_uuid: Unique identifier for the flow instance method_name: Name of the method that just completed state_data: Current state data (either dict or Pydantic model) + """ - pass - + @abc.abstractmethod - def load_state(self, flow_uuid: str) -> Optional[Dict[str, Any]]: + def load_state(self, flow_uuid: str) -> dict[str, Any] | None: """Load the most recent state for a given flow UUID. - + Args: flow_uuid: Unique identifier for the flow instance - + Returns: The most recent state as a dictionary, or None if no state exists + """ - pass diff --git a/src/crewai/flow/persistence/decorators.py b/src/crewai/flow/persistence/decorators.py index 7b3bd447c..df087280b 100644 --- a/src/crewai/flow/persistence/decorators.py +++ b/src/crewai/flow/persistence/decorators.py @@ -1,5 +1,4 @@ -""" -Decorators for flow state persistence. +"""Decorators for flow state persistence. Example: ```python @@ -19,18 +18,16 @@ Example: # Asynchronous method implementation await some_async_operation() ``` + """ import asyncio import functools import logging +from collections.abc import Callable from typing import ( Any, - Callable, - Optional, - Type, TypeVar, - Union, cast, ) @@ -48,7 +45,7 @@ LOG_MESSAGES = { "save_state": "Saving flow state to memory for ID: {}", "save_error": "Failed to persist state for method {}: {}", "state_missing": "Flow instance has no state", - "id_missing": "Flow state must have an 'id' field for persistence" + "id_missing": "Flow state must have an 'id' field for persistence", } @@ -74,20 +71,23 @@ class PersistenceDecorator: ValueError: If flow has no state or state lacks an ID RuntimeError: If state persistence fails AttributeError: If flow instance lacks required state attributes + """ try: - state = getattr(flow_instance, 'state', None) + state = getattr(flow_instance, "state", None) if state is None: - raise ValueError("Flow instance has no state") + msg = "Flow instance has no state" + raise ValueError(msg) - flow_uuid: Optional[str] = None + flow_uuid: str | None = None if isinstance(state, dict): - flow_uuid = state.get('id') + flow_uuid = state.get("id") elif isinstance(state, BaseModel): - flow_uuid = getattr(state, 'id', None) + flow_uuid = getattr(state, "id", None) if not flow_uuid: - raise ValueError("Flow state must have an 'id' field for persistence") + msg = "Flow state must have an 'id' field for persistence" + raise ValueError(msg) # Log state saving only if verbose is True if verbose: @@ -103,21 +103,22 @@ class PersistenceDecorator: except Exception as e: error_msg = LOG_MESSAGES["save_error"].format(method_name, str(e)) cls._printer.print(error_msg, color="red") - logger.error(error_msg) - raise RuntimeError(f"State persistence failed: {str(e)}") from e + logger.exception(error_msg) + msg = f"State persistence failed: {e!s}" + raise RuntimeError(msg) from e except AttributeError: error_msg = LOG_MESSAGES["state_missing"] cls._printer.print(error_msg, color="red") - logger.error(error_msg) + logger.exception(error_msg) raise ValueError(error_msg) except (TypeError, ValueError) as e: error_msg = LOG_MESSAGES["id_missing"] cls._printer.print(error_msg, color="red") - logger.error(error_msg) + logger.exception(error_msg) raise ValueError(error_msg) from e -def persist(persistence: Optional[FlowPersistence] = None, verbose: bool = False): +def persist(persistence: FlowPersistence | None = None, verbose: bool = False): """Decorator to persist flow state. This decorator can be applied at either the class level or method level. @@ -143,22 +144,23 @@ def persist(persistence: Optional[FlowPersistence] = None, verbose: bool = False @start() def begin(self): pass + """ - def decorator(target: Union[Type, Callable[..., T]]) -> Union[Type, Callable[..., T]]: + def decorator(target: type | Callable[..., T]) -> type | Callable[..., T]: """Decorator that handles both class and method decoration.""" actual_persistence = persistence or SQLiteFlowPersistence() if isinstance(target, type): # Class decoration - original_init = getattr(target, "__init__") + original_init = target.__init__ @functools.wraps(original_init) def new_init(self: Any, *args: Any, **kwargs: Any) -> None: - if 'persistence' not in kwargs: - kwargs['persistence'] = actual_persistence + if "persistence" not in kwargs: + kwargs["persistence"] = actual_persistence original_init(self, *args, **kwargs) - setattr(target, "__init__", new_init) + target.__init__ = new_init # Store original methods to preserve their decorators original_methods = {} @@ -191,7 +193,7 @@ def persist(persistence: Optional[FlowPersistence] = None, verbose: bool = False for attr in ["__is_start_method__", "__trigger_methods__", "__condition_type__", "__is_router__"]: if hasattr(method, attr): setattr(wrapped, attr, getattr(method, attr)) - setattr(wrapped, "__is_flow_method__", True) + wrapped.__is_flow_method__ = True # Update the class with the wrapped method setattr(target, name, wrapped) @@ -211,44 +213,42 @@ def persist(persistence: Optional[FlowPersistence] = None, verbose: bool = False for attr in ["__is_start_method__", "__trigger_methods__", "__condition_type__", "__is_router__"]: if hasattr(method, attr): setattr(wrapped, attr, getattr(method, attr)) - setattr(wrapped, "__is_flow_method__", True) + wrapped.__is_flow_method__ = True # Update the class with the wrapped method setattr(target, name, wrapped) return target - else: - # Method decoration - method = target - setattr(method, "__is_flow_method__", True) + # Method decoration + method = target + method.__is_flow_method__ = True - if asyncio.iscoroutinefunction(method): - @functools.wraps(method) - async def method_async_wrapper(flow_instance: Any, *args: Any, **kwargs: Any) -> T: - method_coro = method(flow_instance, *args, **kwargs) - if asyncio.iscoroutine(method_coro): - result = await method_coro - else: - result = method_coro - PersistenceDecorator.persist_state(flow_instance, method.__name__, actual_persistence, verbose) - return result + if asyncio.iscoroutinefunction(method): + @functools.wraps(method) + async def method_async_wrapper(flow_instance: Any, *args: Any, **kwargs: Any) -> T: + method_coro = method(flow_instance, *args, **kwargs) + if asyncio.iscoroutine(method_coro): + result = await method_coro + else: + result = method_coro + PersistenceDecorator.persist_state(flow_instance, method.__name__, actual_persistence, verbose) + return result - for attr in ["__is_start_method__", "__trigger_methods__", "__condition_type__", "__is_router__"]: - if hasattr(method, attr): - setattr(method_async_wrapper, attr, getattr(method, attr)) - setattr(method_async_wrapper, "__is_flow_method__", True) - return cast(Callable[..., T], method_async_wrapper) - else: - @functools.wraps(method) - def method_sync_wrapper(flow_instance: Any, *args: Any, **kwargs: Any) -> T: - result = method(flow_instance, *args, **kwargs) - PersistenceDecorator.persist_state(flow_instance, method.__name__, actual_persistence, verbose) - return result + for attr in ["__is_start_method__", "__trigger_methods__", "__condition_type__", "__is_router__"]: + if hasattr(method, attr): + setattr(method_async_wrapper, attr, getattr(method, attr)) + method_async_wrapper.__is_flow_method__ = True + return cast("Callable[..., T]", method_async_wrapper) + @functools.wraps(method) + def method_sync_wrapper(flow_instance: Any, *args: Any, **kwargs: Any) -> T: + result = method(flow_instance, *args, **kwargs) + PersistenceDecorator.persist_state(flow_instance, method.__name__, actual_persistence, verbose) + return result - for attr in ["__is_start_method__", "__trigger_methods__", "__condition_type__", "__is_router__"]: - if hasattr(method, attr): - setattr(method_sync_wrapper, attr, getattr(method, attr)) - setattr(method_sync_wrapper, "__is_flow_method__", True) - return cast(Callable[..., T], method_sync_wrapper) + for attr in ["__is_start_method__", "__trigger_methods__", "__condition_type__", "__is_router__"]: + if hasattr(method, attr): + setattr(method_sync_wrapper, attr, getattr(method, attr)) + method_sync_wrapper.__is_flow_method__ = True + return cast("Callable[..., T]", method_sync_wrapper) return decorator diff --git a/src/crewai/flow/persistence/sqlite.py b/src/crewai/flow/persistence/sqlite.py index 8b2a0f3f2..696517d8b 100644 --- a/src/crewai/flow/persistence/sqlite.py +++ b/src/crewai/flow/persistence/sqlite.py @@ -1,12 +1,10 @@ -""" -SQLite-based implementation of flow state persistence. -""" +"""SQLite-based implementation of flow state persistence.""" import json import sqlite3 from datetime import datetime, timezone from pathlib import Path -from typing import Any, Dict, Optional, Union +from typing import Any from pydantic import BaseModel @@ -23,7 +21,7 @@ class SQLiteFlowPersistence(FlowPersistence): db_path: str - def __init__(self, db_path: Optional[str] = None): + def __init__(self, db_path: str | None = None) -> None: """Initialize SQLite persistence. Args: @@ -32,6 +30,7 @@ class SQLiteFlowPersistence(FlowPersistence): Raises: ValueError: If db_path is invalid + """ from crewai.utilities.paths import db_storage_path @@ -39,7 +38,8 @@ class SQLiteFlowPersistence(FlowPersistence): path = db_path or str(Path(db_storage_path()) / "flow_states.db") if not path: - raise ValueError("Database path must be provided") + msg = "Database path must be provided" + raise ValueError(msg) self.db_path = path # Now mypy knows this is str self.init_db() @@ -56,21 +56,21 @@ class SQLiteFlowPersistence(FlowPersistence): timestamp DATETIME NOT NULL, state_json TEXT NOT NULL ) - """ + """, ) # Add index for faster UUID lookups conn.execute( """ CREATE INDEX IF NOT EXISTS idx_flow_states_uuid ON flow_states(flow_uuid) - """ + """, ) def save_state( self, flow_uuid: str, method_name: str, - state_data: Union[Dict[str, Any], BaseModel], + state_data: dict[str, Any] | BaseModel, ) -> None: """Save the current flow state to SQLite. @@ -78,6 +78,7 @@ class SQLiteFlowPersistence(FlowPersistence): flow_uuid: Unique identifier for the flow instance method_name: Name of the method that just completed state_data: Current state data (either dict or Pydantic model) + """ # Convert state_data to dict, handling both Pydantic and dict cases if isinstance(state_data, BaseModel): @@ -85,8 +86,9 @@ class SQLiteFlowPersistence(FlowPersistence): elif isinstance(state_data, dict): state_dict = state_data else: + msg = f"state_data must be either a Pydantic BaseModel or dict, got {type(state_data)}" raise ValueError( - f"state_data must be either a Pydantic BaseModel or dict, got {type(state_data)}" + msg, ) with sqlite3.connect(self.db_path) as conn: @@ -107,7 +109,7 @@ class SQLiteFlowPersistence(FlowPersistence): ), ) - def load_state(self, flow_uuid: str) -> Optional[Dict[str, Any]]: + def load_state(self, flow_uuid: str) -> dict[str, Any] | None: """Load the most recent state for a given flow UUID. Args: @@ -115,6 +117,7 @@ class SQLiteFlowPersistence(FlowPersistence): Returns: The most recent state as a dictionary, or None if no state exists + """ with sqlite3.connect(self.db_path) as conn: cursor = conn.execute( diff --git a/src/crewai/flow/utils.py b/src/crewai/flow/utils.py index 81f3c1041..40b430402 100644 --- a/src/crewai/flow/utils.py +++ b/src/crewai/flow/utils.py @@ -1,33 +1,32 @@ -""" -Utility functions for flow visualization and dependency analysis. +"""Utility functions for flow visualization and dependency analysis. This module provides core functionality for analyzing and manipulating flow structures, including node level calculation, ancestor tracking, and return value analysis. Functions in this module are primarily used by the visualization system to create accurate and informative flow diagrams. -Example +Example: ------- >>> flow = Flow() >>> node_levels = calculate_node_levels(flow) >>> ancestors = build_ancestor_dict(flow) + """ import ast import inspect import textwrap from collections import defaultdict, deque -from typing import Any, Deque, Dict, List, Optional, Set, Union +from typing import Any -def get_possible_return_constants(function: Any) -> Optional[List[str]]: +def get_possible_return_constants(function: Any) -> list[str] | None: try: source = inspect.getsource(function) except OSError: # Can't get source code return None - except Exception as e: - print(f"Error retrieving source code for function {function.__name__}: {e}") + except Exception: return None try: @@ -35,24 +34,18 @@ def get_possible_return_constants(function: Any) -> Optional[List[str]]: source = textwrap.dedent(source) # Parse the source code into an AST code_ast = ast.parse(source) - except IndentationError as e: - print(f"IndentationError while parsing source code of {function.__name__}: {e}") - print(f"Source code:\n{source}") + except IndentationError: return None - except SyntaxError as e: - print(f"SyntaxError while parsing source code of {function.__name__}: {e}") - print(f"Source code:\n{source}") + except SyntaxError: return None - except Exception as e: - print(f"Unexpected error while parsing source code of {function.__name__}: {e}") - print(f"Source code:\n{source}") + except Exception: return None return_values = set() dict_definitions = {} class DictionaryAssignmentVisitor(ast.NodeVisitor): - def visit_Assign(self, node): + def visit_Assign(self, node) -> None: # Check if this assignment is assigning a dictionary literal to a variable if isinstance(node.value, ast.Dict) and len(node.targets) == 1: target = node.targets[0] @@ -69,10 +62,10 @@ def get_possible_return_constants(function: Any) -> Optional[List[str]]: self.generic_visit(node) class ReturnVisitor(ast.NodeVisitor): - def visit_Return(self, node): + def visit_Return(self, node) -> None: # Direct string return if isinstance(node.value, ast.Constant) and isinstance( - node.value.value, str + node.value.value, str, ): return_values.add(node.value.value) # Dictionary-based return, like return paths[result] @@ -94,9 +87,8 @@ def get_possible_return_constants(function: Any) -> Optional[List[str]]: return list(return_values) if return_values else None -def calculate_node_levels(flow: Any) -> Dict[str, int]: - """ - Calculate the hierarchical level of each node in the flow. +def calculate_node_levels(flow: Any) -> dict[str, int]: + """Calculate the hierarchical level of each node in the flow. Performs a breadth-first traversal of the flow graph to assign levels to nodes, starting with start methods at level 0. @@ -117,11 +109,12 @@ def calculate_node_levels(flow: Any) -> Dict[str, int]: - Each subsequent connected node is assigned level = parent_level + 1 - Handles both OR and AND conditions for listeners - Processes router paths separately + """ - levels: Dict[str, int] = {} - queue: Deque[str] = deque() - visited: Set[str] = set() - pending_and_listeners: Dict[str, Set[str]] = {} + levels: dict[str, int] = {} + queue: deque[str] = deque() + visited: set[str] = set() + pending_and_listeners: dict[str, set[str]] = {} # Make all start methods at level 0 for method_name, method in flow._methods.items(): @@ -172,9 +165,8 @@ def calculate_node_levels(flow: Any) -> Dict[str, int]: return levels -def count_outgoing_edges(flow: Any) -> Dict[str, int]: - """ - Count the number of outgoing edges for each method in the flow. +def count_outgoing_edges(flow: Any) -> dict[str, int]: + """Count the number of outgoing edges for each method in the flow. Parameters ---------- @@ -185,6 +177,7 @@ def count_outgoing_edges(flow: Any) -> Dict[str, int]: ------- Dict[str, int] Dictionary mapping method names to their outgoing edge count. + """ counts = {} for method_name in flow._methods: @@ -197,9 +190,8 @@ def count_outgoing_edges(flow: Any) -> Dict[str, int]: return counts -def build_ancestor_dict(flow: Any) -> Dict[str, Set[str]]: - """ - Build a dictionary mapping each node to its ancestor nodes. +def build_ancestor_dict(flow: Any) -> dict[str, set[str]]: + """Build a dictionary mapping each node to its ancestor nodes. Parameters ---------- @@ -210,9 +202,10 @@ def build_ancestor_dict(flow: Any) -> Dict[str, Set[str]]: ------- Dict[str, Set[str]] Dictionary mapping each node to a set of its ancestor nodes. + """ - ancestors: Dict[str, Set[str]] = {node: set() for node in flow._methods} - visited: Set[str] = set() + ancestors: dict[str, set[str]] = {node: set() for node in flow._methods} + visited: set[str] = set() for node in flow._methods: if node not in visited: dfs_ancestors(node, ancestors, visited, flow) @@ -220,10 +213,9 @@ def build_ancestor_dict(flow: Any) -> Dict[str, Set[str]]: def dfs_ancestors( - node: str, ancestors: Dict[str, Set[str]], visited: Set[str], flow: Any + node: str, ancestors: dict[str, set[str]], visited: set[str], flow: Any, ) -> None: - """ - Perform depth-first search to build ancestor relationships. + """Perform depth-first search to build ancestor relationships. Parameters ---------- @@ -240,6 +232,7 @@ def dfs_ancestors( ----- This function modifies the ancestors dictionary in-place to build the complete ancestor graph. + """ if node in visited: return @@ -265,10 +258,9 @@ def dfs_ancestors( def is_ancestor( - node: str, ancestor_candidate: str, ancestors: Dict[str, Set[str]] + node: str, ancestor_candidate: str, ancestors: dict[str, set[str]], ) -> bool: - """ - Check if one node is an ancestor of another. + """Check if one node is an ancestor of another. Parameters ---------- @@ -283,13 +275,13 @@ def is_ancestor( ------- bool True if ancestor_candidate is an ancestor of node, False otherwise. + """ return ancestor_candidate in ancestors.get(node, set()) -def build_parent_children_dict(flow: Any) -> Dict[str, List[str]]: - """ - Build a dictionary mapping parent nodes to their children. +def build_parent_children_dict(flow: Any) -> dict[str, list[str]]: + """Build a dictionary mapping parent nodes to their children. Parameters ---------- @@ -306,8 +298,9 @@ def build_parent_children_dict(flow: Any) -> Dict[str, List[str]]: - Maps listeners to their trigger methods - Maps router methods to their paths and listeners - Children lists are sorted for consistent ordering + """ - parent_children: Dict[str, List[str]] = {} + parent_children: dict[str, list[str]] = {} # Map listeners to their trigger methods for listener_name, (_, trigger_methods) in flow._listeners.items(): @@ -332,10 +325,9 @@ def build_parent_children_dict(flow: Any) -> Dict[str, List[str]]: def get_child_index( - parent: str, child: str, parent_children: Dict[str, List[str]] + parent: str, child: str, parent_children: dict[str, list[str]], ) -> int: - """ - Get the index of a child node in its parent's sorted children list. + """Get the index of a child node in its parent's sorted children list. Parameters ---------- @@ -350,27 +342,25 @@ def get_child_index( ------- int Zero-based index of the child in its parent's sorted children list. + """ children = parent_children.get(parent, []) children.sort() return children.index(child) -def process_router_paths(flow, current, current_level, levels, queue): - """ - Handle the router connections for the current node. - """ +def process_router_paths(flow, current, current_level, levels, queue) -> None: + """Handle the router connections for the current node.""" if current in flow._routers: paths = flow._router_paths.get(current, []) for path in paths: for listener_name, ( - condition_type, + _condition_type, trigger_methods, ) in flow._listeners.items(): - if path in trigger_methods: - if ( - listener_name not in levels - or levels[listener_name] > current_level + 1 - ): - levels[listener_name] = current_level + 1 - queue.append(listener_name) + if path in trigger_methods and ( + listener_name not in levels + or levels[listener_name] > current_level + 1 + ): + levels[listener_name] = current_level + 1 + queue.append(listener_name) diff --git a/src/crewai/flow/visualization_utils.py b/src/crewai/flow/visualization_utils.py index 781677276..d5ce5be2d 100644 --- a/src/crewai/flow/visualization_utils.py +++ b/src/crewai/flow/visualization_utils.py @@ -1,23 +1,23 @@ -""" -Utilities for creating visual representations of flow structures. +"""Utilities for creating visual representations of flow structures. This module provides functions for generating network visualizations of flows, including node placement, edge creation, and visual styling. It handles the conversion of flow structures into visual network graphs with appropriate styling and layout. -Example +Example: ------- >>> flow = Flow() >>> net = Network(directed=True) >>> node_positions = compute_positions(flow, node_levels) >>> add_nodes_to_network(net, flow, node_positions, node_styles) >>> add_edges(net, flow, node_positions, colors) + """ import ast import inspect -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any from .utils import ( build_ancestor_dict, @@ -28,8 +28,7 @@ from .utils import ( def method_calls_crew(method: Any) -> bool: - """ - Check if the method contains a call to `.crew()`. + """Check if the method contains a call to `.crew()`. Parameters ---------- @@ -45,21 +44,22 @@ def method_calls_crew(method: Any) -> bool: ----- Uses AST analysis to detect method calls, specifically looking for attribute access of 'crew'. + """ try: source = inspect.getsource(method) source = inspect.cleandoc(source) tree = ast.parse(source) - except Exception as e: - print(f"Could not parse method {method.__name__}: {e}") + except Exception: return False class CrewCallVisitor(ast.NodeVisitor): """AST visitor to detect .crew() method calls.""" - def __init__(self): + + def __init__(self) -> None: self.found = False - def visit_Call(self, node): + def visit_Call(self, node) -> None: if isinstance(node.func, ast.Attribute): if node.func.attr == "crew": self.found = True @@ -73,11 +73,10 @@ def method_calls_crew(method: Any) -> bool: def add_nodes_to_network( net: Any, flow: Any, - node_positions: Dict[str, Tuple[float, float]], - node_styles: Dict[str, Dict[str, Any]] + node_positions: dict[str, tuple[float, float]], + node_styles: dict[str, dict[str, Any]], ) -> None: - """ - Add nodes to the network visualization with appropriate styling. + """Add nodes to the network visualization with appropriate styling. Parameters ---------- @@ -97,6 +96,7 @@ def add_nodes_to_network( - Router methods - Crew methods - Regular methods + """ def human_friendly_label(method_name): return method_name.replace("_", " ").title() @@ -123,7 +123,7 @@ def add_nodes_to_network( "multi": "html", "color": node_style.get("font", {}).get("color", "#FFFFFF"), }, - } + }, ) net.add_node( @@ -138,12 +138,11 @@ def add_nodes_to_network( def compute_positions( flow: Any, - node_levels: Dict[str, int], + node_levels: dict[str, int], y_spacing: float = 150, - x_spacing: float = 150 -) -> Dict[str, Tuple[float, float]]: - """ - Compute the (x, y) positions for each node in the flow graph. + x_spacing: float = 150, +) -> dict[str, tuple[float, float]]: + """Compute the (x, y) positions for each node in the flow graph. Parameters ---------- @@ -160,9 +159,10 @@ def compute_positions( ------- Dict[str, Tuple[float, float]] Dictionary mapping node names to their (x, y) coordinates. + """ - level_nodes: Dict[int, List[str]] = {} - node_positions: Dict[str, Tuple[float, float]] = {} + level_nodes: dict[int, list[str]] = {} + node_positions: dict[str, tuple[float, float]] = {} for method_name, level in node_levels.items(): level_nodes.setdefault(level, []).append(method_name) @@ -180,10 +180,10 @@ def compute_positions( def add_edges( net: Any, flow: Any, - node_positions: Dict[str, Tuple[float, float]], - colors: Dict[str, str] + node_positions: dict[str, tuple[float, float]], + colors: dict[str, str], ) -> None: - edge_smooth: Dict[str, Union[str, float]] = {"type": "continuous"} # Default value + edge_smooth: dict[str, str | float] = {"type": "continuous"} # Default value """ Add edges to the network visualization with appropriate styling. @@ -245,7 +245,7 @@ def add_edges( "color": edge_color, "width": 2, "arrows": "to", - "dashes": True if is_router_edge or is_and_condition else False, + "dashes": bool(is_router_edge or is_and_condition), "smooth": edge_smooth, } @@ -261,9 +261,7 @@ def add_edges( # If it's a known router edge and the method is known, don't warn. # This means the path is legitimate, just not reflected as nodes here. if not (is_router_edge and method_known): - print( - f"Warning: No node found for '{trigger}' or '{method_name}'. Skipping edge." - ) + pass # Edges for router return paths for router_method_name, paths in flow._router_paths.items(): @@ -278,7 +276,7 @@ def add_edges( and listener_name in node_positions ): is_cycle_edge = is_ancestor( - router_method_name, listener_name, ancestors + router_method_name, listener_name, ancestors, ) parent_has_multiple_children = ( len(parent_children.get(router_method_name, [])) > 1 @@ -293,7 +291,7 @@ def add_edges( dx = target_pos[0] - source_pos[0] smooth_type = "curvedCCW" if dx <= 0 else "curvedCW" index = get_child_index( - router_method_name, listener_name, parent_children + router_method_name, listener_name, parent_children, ) edge_smooth = { "type": smooth_type, @@ -316,6 +314,4 @@ def add_edges( # Same check here: known router edge and known method? method_known = listener_name in flow._methods if not method_known: - print( - f"Warning: No node found for '{router_method_name}' or '{listener_name}'. Skipping edge." - ) + pass diff --git a/src/crewai/knowledge/embedder/base_embedder.py b/src/crewai/knowledge/embedder/base_embedder.py index c3252bf43..43f3e6cff 100644 --- a/src/crewai/knowledge/embedder/base_embedder.py +++ b/src/crewai/knowledge/embedder/base_embedder.py @@ -1,55 +1,48 @@ from abc import ABC, abstractmethod -from typing import List import numpy as np class BaseEmbedder(ABC): - """ - Abstract base class for text embedding models - """ + """Abstract base class for text embedding models.""" @abstractmethod - def embed_chunks(self, chunks: List[str]) -> np.ndarray: - """ - Generate embeddings for a list of text chunks + def embed_chunks(self, chunks: list[str]) -> np.ndarray: + """Generate embeddings for a list of text chunks. Args: chunks: List of text chunks to embed Returns: Array of embeddings + """ - pass @abstractmethod - def embed_texts(self, texts: List[str]) -> np.ndarray: - """ - Generate embeddings for a list of texts + def embed_texts(self, texts: list[str]) -> np.ndarray: + """Generate embeddings for a list of texts. Args: texts: List of texts to embed Returns: Array of embeddings + """ - pass @abstractmethod def embed_text(self, text: str) -> np.ndarray: - """ - Generate embedding for a single text + """Generate embedding for a single text. Args: text: Text to embed Returns: Embedding array + """ - pass @property @abstractmethod def dimension(self) -> int: - """Get the dimension of the embeddings""" - pass + """Get the dimension of the embeddings.""" diff --git a/src/crewai/knowledge/embedder/fastembed.py b/src/crewai/knowledge/embedder/fastembed.py index 54db11643..4709171d1 100644 --- a/src/crewai/knowledge/embedder/fastembed.py +++ b/src/crewai/knowledge/embedder/fastembed.py @@ -1,5 +1,4 @@ from pathlib import Path -from typing import List, Optional, Union import numpy as np @@ -19,75 +18,74 @@ except ImportError: class FastEmbed(BaseEmbedder): - """ - A wrapper class for text embedding models using FastEmbed - """ + """A wrapper class for text embedding models using FastEmbed.""" def __init__( self, model_name: str = "BAAI/bge-small-en-v1.5", - cache_dir: Optional[Union[str, Path]] = None, - ): - """ - Initialize the embedding model + cache_dir: str | Path | None = None, + ) -> None: + """Initialize the embedding model. Args: model_name: Name of the model to use cache_dir: Directory to cache the model gpu: Whether to use GPU acceleration + """ if not FASTEMBED_AVAILABLE: - raise ImportError( + msg = ( "FastEmbed is not installed. Please install it with: " "uv pip install fastembed or uv pip install fastembed-gpu for GPU support" ) + raise ImportError( + msg, + ) self.model = TextEmbedding( model_name=model_name, cache_dir=str(cache_dir) if cache_dir else None, ) - def embed_chunks(self, chunks: List[str]) -> List[np.ndarray]: - """ - Generate embeddings for a list of text chunks + def embed_chunks(self, chunks: list[str]) -> list[np.ndarray]: + """Generate embeddings for a list of text chunks. Args: chunks: List of text chunks to embed Returns: List of embeddings - """ - embeddings = list(self.model.embed(chunks)) - return embeddings - def embed_texts(self, texts: List[str]) -> List[np.ndarray]: """ - Generate embeddings for a list of texts + return list(self.model.embed(chunks)) + + def embed_texts(self, texts: list[str]) -> list[np.ndarray]: + """Generate embeddings for a list of texts. Args: texts: List of texts to embed Returns: List of embeddings + """ - embeddings = list(self.model.embed(texts)) - return embeddings + return list(self.model.embed(texts)) def embed_text(self, text: str) -> np.ndarray: - """ - Generate embedding for a single text + """Generate embedding for a single text. Args: text: Text to embed Returns: Embedding array + """ return self.embed_texts([text])[0] @property def dimension(self) -> int: - """Get the dimension of the embeddings""" + """Get the dimension of the embeddings.""" # Generate a test embedding to get dimensions test_embed = self.embed_text("test") return len(test_embed) diff --git a/src/crewai/knowledge/knowledge.py b/src/crewai/knowledge/knowledge.py index 2340dec90..de856d825 100644 --- a/src/crewai/knowledge/knowledge.py +++ b/src/crewai/knowledge/knowledge.py @@ -1,5 +1,5 @@ import os -from typing import Any, Dict, List, Optional +from typing import Any from pydantic import BaseModel, ConfigDict, Field @@ -10,68 +10,70 @@ os.environ["TOKENIZERS_PARALLELISM"] = "false" # removes logging from fastembed class Knowledge(BaseModel): - """ - Knowledge is a collection of sources and setup for the vector store to save and query relevant context. + """Knowledge is a collection of sources and setup for the vector store to save and query relevant context. + Args: sources: List[BaseKnowledgeSource] = Field(default_factory=list) storage: Optional[KnowledgeStorage] = Field(default=None) - embedder: Optional[Dict[str, Any]] = None + embedder: Optional[Dict[str, Any]] = None. + """ - sources: List[BaseKnowledgeSource] = Field(default_factory=list) + sources: list[BaseKnowledgeSource] = Field(default_factory=list) model_config = ConfigDict(arbitrary_types_allowed=True) - storage: Optional[KnowledgeStorage] = Field(default=None) - embedder: Optional[Dict[str, Any]] = None - collection_name: Optional[str] = None + storage: KnowledgeStorage | None = Field(default=None) + embedder: dict[str, Any] | None = None + collection_name: str | None = None def __init__( self, collection_name: str, - sources: List[BaseKnowledgeSource], - embedder: Optional[Dict[str, Any]] = None, - storage: Optional[KnowledgeStorage] = None, + sources: list[BaseKnowledgeSource], + embedder: dict[str, Any] | None = None, + storage: KnowledgeStorage | None = None, **data, - ): + ) -> None: super().__init__(**data) if storage: self.storage = storage else: self.storage = KnowledgeStorage( - embedder=embedder, collection_name=collection_name + embedder=embedder, collection_name=collection_name, ) self.sources = sources self.storage.initialize_knowledge_storage() def query( - self, query: List[str], results_limit: int = 3, score_threshold: float = 0.35 - ) -> List[Dict[str, Any]]: - """ - Query across all knowledge sources to find the most relevant information. + self, query: list[str], results_limit: int = 3, score_threshold: float = 0.35, + ) -> list[dict[str, Any]]: + """Query across all knowledge sources to find the most relevant information. Returns the top_k most relevant chunks. Raises: ValueError: If storage is not initialized. + """ if self.storage is None: - raise ValueError("Storage is not initialized.") + msg = "Storage is not initialized." + raise ValueError(msg) - results = self.storage.search( + return self.storage.search( query, limit=results_limit, score_threshold=score_threshold, ) - return results - def add_sources(self): + def add_sources(self) -> None: try: for source in self.sources: source.storage = self.storage source.add() - except Exception as e: - raise e + except Exception: + raise def reset(self) -> None: if self.storage: self.storage.reset() else: - raise ValueError("Storage is not initialized.") + msg = "Storage is not initialized." + raise ValueError(msg) diff --git a/src/crewai/knowledge/knowledge_config.py b/src/crewai/knowledge/knowledge_config.py index e84341f6a..5f0556c8a 100644 --- a/src/crewai/knowledge/knowledge_config.py +++ b/src/crewai/knowledge/knowledge_config.py @@ -7,6 +7,7 @@ class KnowledgeConfig(BaseModel): Args: results_limit (int): The number of relevant documents to return. score_threshold (float): The minimum score for a document to be considered relevant. + """ results_limit: int = Field(default=3, description="The number of results to return") diff --git a/src/crewai/knowledge/source/base_file_knowledge_source.py b/src/crewai/knowledge/source/base_file_knowledge_source.py index 4c4b9b337..be761fcc1 100644 --- a/src/crewai/knowledge/source/base_file_knowledge_source.py +++ b/src/crewai/knowledge/source/base_file_knowledge_source.py @@ -1,6 +1,5 @@ from abc import ABC, abstractmethod from pathlib import Path -from typing import Dict, List, Optional, Union from pydantic import Field, field_validator @@ -14,43 +13,43 @@ class BaseFileKnowledgeSource(BaseKnowledgeSource, ABC): """Base class for knowledge sources that load content from files.""" _logger: Logger = Logger(verbose=True) - file_path: Optional[Union[Path, List[Path], str, List[str]]] = Field( + file_path: Path | list[Path] | str | list[str] | None = Field( default=None, description="[Deprecated] The path to the file. Use file_paths instead.", ) - file_paths: Optional[Union[Path, List[Path], str, List[str]]] = Field( - default_factory=list, description="The path to the file" + file_paths: Path | list[Path] | str | list[str] | None = Field( + default_factory=list, description="The path to the file", ) - content: Dict[Path, str] = Field(init=False, default_factory=dict) - storage: Optional[KnowledgeStorage] = Field(default=None) - safe_file_paths: List[Path] = Field(default_factory=list) + content: dict[Path, str] = Field(init=False, default_factory=dict) + storage: KnowledgeStorage | None = Field(default=None) + safe_file_paths: list[Path] = Field(default_factory=list) @field_validator("file_path", "file_paths", mode="before") - def validate_file_path(cls, v, info): + def validate_file_path(self, v, info): """Validate that at least one of file_path or file_paths is provided.""" # Single check if both are None, O(1) instead of nested conditions if ( v is None and info.data.get( - "file_path" if info.field_name == "file_paths" else "file_paths" + "file_path" if info.field_name == "file_paths" else "file_paths", ) is None ): - raise ValueError("Either file_path or file_paths must be provided") + msg = "Either file_path or file_paths must be provided" + raise ValueError(msg) return v - def model_post_init(self, _): + def model_post_init(self, _) -> None: """Post-initialization method to load content.""" self.safe_file_paths = self._process_file_paths() self.validate_content() self.content = self.load_content() @abstractmethod - def load_content(self) -> Dict[Path, str]: + def load_content(self) -> dict[Path, str]: """Load and preprocess file content. Should be overridden by subclasses. Assume that the file path is relative to the project root in the knowledge directory.""" - pass - def validate_content(self): + def validate_content(self) -> None: """Validate the paths.""" for path in self.safe_file_paths: if not path.exists(): @@ -59,7 +58,8 @@ class BaseFileKnowledgeSource(BaseKnowledgeSource, ABC): f"File not found: {path}. Try adding sources to the knowledge directory. If it's inside the knowledge directory, use the relative path.", color="red", ) - raise FileNotFoundError(f"File not found: {path}") + msg = f"File not found: {path}" + raise FileNotFoundError(msg) if not path.is_file(): self._logger.log( "error", @@ -67,20 +67,20 @@ class BaseFileKnowledgeSource(BaseKnowledgeSource, ABC): color="red", ) - def _save_documents(self): + def _save_documents(self) -> None: """Save the documents to the storage.""" if self.storage: self.storage.save(self.chunks) else: - raise ValueError("No storage found to save documents.") + msg = "No storage found to save documents." + raise ValueError(msg) - def convert_to_path(self, path: Union[Path, str]) -> Path: + def convert_to_path(self, path: Path | str) -> Path: """Convert a path to a Path object.""" return Path(KNOWLEDGE_DIRECTORY + "/" + path) if isinstance(path, str) else path - def _process_file_paths(self) -> List[Path]: + def _process_file_paths(self) -> list[Path]: """Convert file_path to a list of Path objects.""" - if hasattr(self, "file_path") and self.file_path is not None: self._logger.log( "warning", @@ -90,10 +90,11 @@ class BaseFileKnowledgeSource(BaseKnowledgeSource, ABC): self.file_paths = self.file_path if self.file_paths is None: - raise ValueError("Your source must be provided with a file_paths: []") + msg = "Your source must be provided with a file_paths: []" + raise ValueError(msg) # Convert single path to list - path_list: List[Union[Path, str]] = ( + path_list: list[Path | str] = ( [self.file_paths] if isinstance(self.file_paths, (str, Path)) else list(self.file_paths) @@ -102,8 +103,9 @@ class BaseFileKnowledgeSource(BaseKnowledgeSource, ABC): ) if not path_list: + msg = "file_path/file_paths must be a Path, str, or a list of these types" raise ValueError( - "file_path/file_paths must be a Path, str, or a list of these types" + msg, ) return [self.convert_to_path(path) for path in path_list] diff --git a/src/crewai/knowledge/source/base_knowledge_source.py b/src/crewai/knowledge/source/base_knowledge_source.py index b558a4b9a..e9fc2fbc4 100644 --- a/src/crewai/knowledge/source/base_knowledge_source.py +++ b/src/crewai/knowledge/source/base_knowledge_source.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from typing import Any, Dict, List, Optional +from typing import Any import numpy as np from pydantic import BaseModel, ConfigDict, Field @@ -12,41 +12,39 @@ class BaseKnowledgeSource(BaseModel, ABC): chunk_size: int = 4000 chunk_overlap: int = 200 - chunks: List[str] = Field(default_factory=list) - chunk_embeddings: List[np.ndarray] = Field(default_factory=list) + chunks: list[str] = Field(default_factory=list) + chunk_embeddings: list[np.ndarray] = Field(default_factory=list) model_config = ConfigDict(arbitrary_types_allowed=True) - storage: Optional[KnowledgeStorage] = Field(default=None) - metadata: Dict[str, Any] = Field(default_factory=dict) # Currently unused - collection_name: Optional[str] = Field(default=None) + storage: KnowledgeStorage | None = Field(default=None) + metadata: dict[str, Any] = Field(default_factory=dict) # Currently unused + collection_name: str | None = Field(default=None) @abstractmethod def validate_content(self) -> Any: """Load and preprocess content from the source.""" - pass @abstractmethod def add(self) -> None: """Process content, chunk it, compute embeddings, and save them.""" - pass - def get_embeddings(self) -> List[np.ndarray]: + def get_embeddings(self) -> list[np.ndarray]: """Return the list of embeddings for the chunks.""" return self.chunk_embeddings - def _chunk_text(self, text: str) -> List[str]: + def _chunk_text(self, text: str) -> list[str]: """Utility method to split text into chunks.""" return [ text[i : i + self.chunk_size] for i in range(0, len(text), self.chunk_size - self.chunk_overlap) ] - def _save_documents(self): - """ - Save the documents to the storage. + def _save_documents(self) -> None: + """Save the documents to the storage. This method should be called after the chunks and embeddings are generated. """ if self.storage: self.storage.save(self.chunks) else: - raise ValueError("No storage found to save documents.") + msg = "No storage found to save documents." + raise ValueError(msg) diff --git a/src/crewai/knowledge/source/crew_docling_source.py b/src/crewai/knowledge/source/crew_docling_source.py index 6ca0ae967..b7dfefe13 100644 --- a/src/crewai/knowledge/source/crew_docling_source.py +++ b/src/crewai/knowledge/source/crew_docling_source.py @@ -1,5 +1,6 @@ +from collections.abc import Iterator from pathlib import Path -from typing import Iterator, List, Optional, Union +from typing import TYPE_CHECKING from urllib.parse import urlparse try: @@ -7,7 +8,6 @@ try: from docling.document_converter import DocumentConverter from docling.exceptions import ConversionError from docling_core.transforms.chunker.hierarchical_chunker import HierarchicalChunker - from docling_core.types.doc.document import DoclingDocument DOCLING_AVAILABLE = True except ImportError: @@ -19,27 +19,33 @@ from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource from crewai.utilities.constants import KNOWLEDGE_DIRECTORY from crewai.utilities.logger import Logger +if TYPE_CHECKING: + from docling_core.types.doc.document import DoclingDocument + class CrewDoclingSource(BaseKnowledgeSource): """Default Source class for converting documents to markdown or json This will auto support PDF, DOCX, and TXT, XLSX, Images, and HTML files without any additional dependencies and follows the docling package as the source of truth. """ - def __init__(self, *args, **kwargs): + def __init__(self, *args, **kwargs) -> None: if not DOCLING_AVAILABLE: - raise ImportError( + msg = ( "The docling package is required to use CrewDoclingSource. " "Please install it using: uv add docling" ) + raise ImportError( + msg, + ) super().__init__(*args, **kwargs) _logger: Logger = Logger(verbose=True) - file_path: Optional[List[Union[Path, str]]] = Field(default=None) - file_paths: List[Union[Path, str]] = Field(default_factory=list) - chunks: List[str] = Field(default_factory=list) - safe_file_paths: List[Union[Path, str]] = Field(default_factory=list) - content: List["DoclingDocument"] = Field(default_factory=list) + file_path: list[Path | str] | None = Field(default=None) + file_paths: list[Path | str] = Field(default_factory=list) + chunks: list[str] = Field(default_factory=list) + safe_file_paths: list[Path | str] = Field(default_factory=list) + content: list["DoclingDocument"] = Field(default_factory=list) document_converter: "DocumentConverter" = Field( default_factory=lambda: DocumentConverter( allowed_formats=[ @@ -51,8 +57,8 @@ class CrewDoclingSource(BaseKnowledgeSource): InputFormat.IMAGE, InputFormat.XLSX, InputFormat.PPTX, - ] - ) + ], + ), ) def model_post_init(self, _) -> None: @@ -66,7 +72,7 @@ class CrewDoclingSource(BaseKnowledgeSource): self.safe_file_paths = self.validate_content() self.content = self._load_content() - def _load_content(self) -> List["DoclingDocument"]: + def _load_content(self) -> list["DoclingDocument"]: try: return self._convert_source_to_docling_documents() except ConversionError as e: @@ -75,10 +81,10 @@ class CrewDoclingSource(BaseKnowledgeSource): f"Error loading content: {e}. Supported formats: {self.document_converter.allowed_formats}", "red", ) - raise e + raise except Exception as e: self._logger.log("error", f"Error loading content: {e}") - raise e + raise def add(self) -> None: if self.content is None: @@ -88,7 +94,7 @@ class CrewDoclingSource(BaseKnowledgeSource): self.chunks.extend(list(new_chunks_iterable)) self._save_documents() - def _convert_source_to_docling_documents(self) -> List["DoclingDocument"]: + def _convert_source_to_docling_documents(self) -> list["DoclingDocument"]: conv_results_iter = self.document_converter.convert_all(self.safe_file_paths) return [result.document for result in conv_results_iter] @@ -97,8 +103,8 @@ class CrewDoclingSource(BaseKnowledgeSource): for chunk in chunker.chunk(doc): yield chunk.text - def validate_content(self) -> List[Union[Path, str]]: - processed_paths: List[Union[Path, str]] = [] + def validate_content(self) -> list[Path | str]: + processed_paths: list[Path | str] = [] for path in self.file_paths: if isinstance(path, str): if path.startswith(("http://", "https://")): @@ -106,15 +112,18 @@ class CrewDoclingSource(BaseKnowledgeSource): if self._validate_url(path): processed_paths.append(path) else: - raise ValueError(f"Invalid URL format: {path}") + msg = f"Invalid URL format: {path}" + raise ValueError(msg) except Exception as e: - raise ValueError(f"Invalid URL: {path}. Error: {str(e)}") + msg = f"Invalid URL: {path}. Error: {e!s}" + raise ValueError(msg) else: local_path = Path(KNOWLEDGE_DIRECTORY + "/" + path) if local_path.exists(): processed_paths.append(local_path) else: - raise FileNotFoundError(f"File not found: {local_path}") + msg = f"File not found: {local_path}" + raise FileNotFoundError(msg) else: # this is an instance of Path processed_paths.append(path) @@ -128,7 +137,7 @@ class CrewDoclingSource(BaseKnowledgeSource): result.scheme in ("http", "https"), result.netloc, len(result.netloc.split(".")) >= 2, # Ensure domain has TLD - ] + ], ) except Exception: return False diff --git a/src/crewai/knowledge/source/csv_knowledge_source.py b/src/crewai/knowledge/source/csv_knowledge_source.py index 3bb0714d9..faa91caed 100644 --- a/src/crewai/knowledge/source/csv_knowledge_source.py +++ b/src/crewai/knowledge/source/csv_knowledge_source.py @@ -1,6 +1,5 @@ import csv from pathlib import Path -from typing import Dict, List from crewai.knowledge.source.base_file_knowledge_source import BaseFileKnowledgeSource @@ -8,11 +7,11 @@ from crewai.knowledge.source.base_file_knowledge_source import BaseFileKnowledge class CSVKnowledgeSource(BaseFileKnowledgeSource): """A knowledge source that stores and queries CSV file content using embeddings.""" - def load_content(self) -> Dict[Path, str]: + def load_content(self) -> dict[Path, str]: """Load and preprocess CSV file content.""" content_dict = {} for file_path in self.safe_file_paths: - with open(file_path, "r", encoding="utf-8") as csvfile: + with open(file_path, encoding="utf-8") as csvfile: reader = csv.reader(csvfile) content = "" for row in reader: @@ -21,8 +20,7 @@ class CSVKnowledgeSource(BaseFileKnowledgeSource): return content_dict def add(self) -> None: - """ - Add CSV file content to the knowledge source, chunk it, compute embeddings, + """Add CSV file content to the knowledge source, chunk it, compute embeddings, and save the embeddings. """ content_str = ( @@ -32,7 +30,7 @@ class CSVKnowledgeSource(BaseFileKnowledgeSource): self.chunks.extend(new_chunks) self._save_documents() - def _chunk_text(self, text: str) -> List[str]: + def _chunk_text(self, text: str) -> list[str]: """Utility method to split text into chunks.""" return [ text[i : i + self.chunk_size] diff --git a/src/crewai/knowledge/source/excel_knowledge_source.py b/src/crewai/knowledge/source/excel_knowledge_source.py index a73afb1df..ada90eae5 100644 --- a/src/crewai/knowledge/source/excel_knowledge_source.py +++ b/src/crewai/knowledge/source/excel_knowledge_source.py @@ -1,6 +1,4 @@ from pathlib import Path -from typing import Dict, Iterator, List, Optional, Union -from urllib.parse import urlparse from pydantic import Field, field_validator @@ -16,34 +14,34 @@ class ExcelKnowledgeSource(BaseKnowledgeSource): _logger: Logger = Logger(verbose=True) - file_path: Optional[Union[Path, List[Path], str, List[str]]] = Field( + file_path: Path | list[Path] | str | list[str] | None = Field( default=None, description="[Deprecated] The path to the file. Use file_paths instead.", ) - file_paths: Optional[Union[Path, List[Path], str, List[str]]] = Field( - default_factory=list, description="The path to the file" + file_paths: Path | list[Path] | str | list[str] | None = Field( + default_factory=list, description="The path to the file", ) - chunks: List[str] = Field(default_factory=list) - content: Dict[Path, Dict[str, str]] = Field(default_factory=dict) - safe_file_paths: List[Path] = Field(default_factory=list) + chunks: list[str] = Field(default_factory=list) + content: dict[Path, dict[str, str]] = Field(default_factory=dict) + safe_file_paths: list[Path] = Field(default_factory=list) @field_validator("file_path", "file_paths", mode="before") - def validate_file_path(cls, v, info): + def validate_file_path(self, v, info): """Validate that at least one of file_path or file_paths is provided.""" # Single check if both are None, O(1) instead of nested conditions if ( v is None and info.data.get( - "file_path" if info.field_name == "file_paths" else "file_paths" + "file_path" if info.field_name == "file_paths" else "file_paths", ) is None ): - raise ValueError("Either file_path or file_paths must be provided") + msg = "Either file_path or file_paths must be provided" + raise ValueError(msg) return v - def _process_file_paths(self) -> List[Path]: + def _process_file_paths(self) -> list[Path]: """Convert file_path to a list of Path objects.""" - if hasattr(self, "file_path") and self.file_path is not None: self._logger.log( "warning", @@ -53,10 +51,11 @@ class ExcelKnowledgeSource(BaseKnowledgeSource): self.file_paths = self.file_path if self.file_paths is None: - raise ValueError("Your source must be provided with a file_paths: []") + msg = "Your source must be provided with a file_paths: []" + raise ValueError(msg) # Convert single path to list - path_list: List[Union[Path, str]] = ( + path_list: list[Path | str] = ( [self.file_paths] if isinstance(self.file_paths, (str, Path)) else list(self.file_paths) @@ -65,13 +64,14 @@ class ExcelKnowledgeSource(BaseKnowledgeSource): ) if not path_list: + msg = "file_path/file_paths must be a Path, str, or a list of these types" raise ValueError( - "file_path/file_paths must be a Path, str, or a list of these types" + msg, ) return [self.convert_to_path(path) for path in path_list] - def validate_content(self): + def validate_content(self) -> None: """Validate the paths.""" for path in self.safe_file_paths: if not path.exists(): @@ -80,7 +80,8 @@ class ExcelKnowledgeSource(BaseKnowledgeSource): f"File not found: {path}. Try adding sources to the knowledge directory. If it's inside the knowledge directory, use the relative path.", color="red", ) - raise FileNotFoundError(f"File not found: {path}") + msg = f"File not found: {path}" + raise FileNotFoundError(msg) if not path.is_file(): self._logger.log( "error", @@ -100,7 +101,7 @@ class ExcelKnowledgeSource(BaseKnowledgeSource): self.validate_content() self.content = self._load_content() - def _load_content(self) -> Dict[Path, Dict[str, str]]: + def _load_content(self) -> dict[Path, dict[str, str]]: """Load and preprocess Excel file content from multiple sheets. Each sheet's content is converted to CSV format and stored. @@ -111,6 +112,7 @@ class ExcelKnowledgeSource(BaseKnowledgeSource): Raises: ImportError: If required dependencies are missing. FileNotFoundError: If the specified Excel file cannot be opened. + """ pd = self._import_dependencies() content_dict = {} @@ -119,14 +121,14 @@ class ExcelKnowledgeSource(BaseKnowledgeSource): with pd.ExcelFile(file_path) as xl: sheet_dict = { str(sheet_name): str( - pd.read_excel(xl, sheet_name).to_csv(index=False) + pd.read_excel(xl, sheet_name).to_csv(index=False), ) for sheet_name in xl.sheet_names } content_dict[file_path] = sheet_dict return content_dict - def convert_to_path(self, path: Union[Path, str]) -> Path: + def convert_to_path(self, path: Path | str) -> Path: """Convert a path to a Path object.""" return Path(KNOWLEDGE_DIRECTORY + "/" + path) if isinstance(path, str) else path @@ -138,13 +140,13 @@ class ExcelKnowledgeSource(BaseKnowledgeSource): return pd except ImportError as e: missing_package = str(e).split()[-1] + msg = f"{missing_package} is not installed. Please install it with: pip install {missing_package}" raise ImportError( - f"{missing_package} is not installed. Please install it with: pip install {missing_package}" + msg, ) def add(self) -> None: - """ - Add Excel file content to the knowledge source, chunk it, compute embeddings, + """Add Excel file content to the knowledge source, chunk it, compute embeddings, and save the embeddings. """ # Convert dictionary values to a single string if content is a dictionary @@ -161,7 +163,7 @@ class ExcelKnowledgeSource(BaseKnowledgeSource): self.chunks.extend(new_chunks) self._save_documents() - def _chunk_text(self, text: str) -> List[str]: + def _chunk_text(self, text: str) -> list[str]: """Utility method to split text into chunks.""" return [ text[i : i + self.chunk_size] diff --git a/src/crewai/knowledge/source/json_knowledge_source.py b/src/crewai/knowledge/source/json_knowledge_source.py index b02d438e6..cba067263 100644 --- a/src/crewai/knowledge/source/json_knowledge_source.py +++ b/src/crewai/knowledge/source/json_knowledge_source.py @@ -1,6 +1,6 @@ import json from pathlib import Path -from typing import Any, Dict, List +from typing import Any from crewai.knowledge.source.base_file_knowledge_source import BaseFileKnowledgeSource @@ -8,12 +8,12 @@ from crewai.knowledge.source.base_file_knowledge_source import BaseFileKnowledge class JSONKnowledgeSource(BaseFileKnowledgeSource): """A knowledge source that stores and queries JSON file content using embeddings.""" - def load_content(self) -> Dict[Path, str]: + def load_content(self) -> dict[Path, str]: """Load and preprocess JSON file content.""" - content: Dict[Path, str] = {} + content: dict[Path, str] = {} for path in self.safe_file_paths: path = self.convert_to_path(path) - with open(path, "r", encoding="utf-8") as json_file: + with open(path, encoding="utf-8") as json_file: data = json.load(json_file) content[path] = self._json_to_text(data) return content @@ -29,12 +29,11 @@ class JSONKnowledgeSource(BaseFileKnowledgeSource): for item in data: text += f"{indent}- {self._json_to_text(item, level + 1)}\n" else: - text += f"{str(data)}" + text += f"{data!s}" return text def add(self) -> None: - """ - Add JSON file content to the knowledge source, chunk it, compute embeddings, + """Add JSON file content to the knowledge source, chunk it, compute embeddings, and save the embeddings. """ content_str = ( @@ -44,7 +43,7 @@ class JSONKnowledgeSource(BaseFileKnowledgeSource): self.chunks.extend(new_chunks) self._save_documents() - def _chunk_text(self, text: str) -> List[str]: + def _chunk_text(self, text: str) -> list[str]: """Utility method to split text into chunks.""" return [ text[i : i + self.chunk_size] diff --git a/src/crewai/knowledge/source/pdf_knowledge_source.py b/src/crewai/knowledge/source/pdf_knowledge_source.py index 38cd67807..441a0403a 100644 --- a/src/crewai/knowledge/source/pdf_knowledge_source.py +++ b/src/crewai/knowledge/source/pdf_knowledge_source.py @@ -1,5 +1,4 @@ from pathlib import Path -from typing import Dict, List from crewai.knowledge.source.base_file_knowledge_source import BaseFileKnowledgeSource @@ -7,7 +6,7 @@ from crewai.knowledge.source.base_file_knowledge_source import BaseFileKnowledge class PDFKnowledgeSource(BaseFileKnowledgeSource): """A knowledge source that stores and queries PDF file content using embeddings.""" - def load_content(self) -> Dict[Path, str]: + def load_content(self) -> dict[Path, str]: """Load and preprocess PDF file content.""" pdfplumber = self._import_pdfplumber() @@ -31,21 +30,21 @@ class PDFKnowledgeSource(BaseFileKnowledgeSource): return pdfplumber except ImportError: + msg = "pdfplumber is not installed. Please install it with: pip install pdfplumber" raise ImportError( - "pdfplumber is not installed. Please install it with: pip install pdfplumber" + msg, ) def add(self) -> None: - """ - Add PDF file content to the knowledge source, chunk it, compute embeddings, + """Add PDF file content to the knowledge source, chunk it, compute embeddings, and save the embeddings. """ - for _, text in self.content.items(): + for text in self.content.values(): new_chunks = self._chunk_text(text) self.chunks.extend(new_chunks) self._save_documents() - def _chunk_text(self, text: str) -> List[str]: + def _chunk_text(self, text: str) -> list[str]: """Utility method to split text into chunks.""" return [ text[i : i + self.chunk_size] diff --git a/src/crewai/knowledge/source/string_knowledge_source.py b/src/crewai/knowledge/source/string_knowledge_source.py index 614303b1f..1b103fb33 100644 --- a/src/crewai/knowledge/source/string_knowledge_source.py +++ b/src/crewai/knowledge/source/string_knowledge_source.py @@ -1,4 +1,3 @@ -from typing import List, Optional from pydantic import Field @@ -9,16 +8,17 @@ class StringKnowledgeSource(BaseKnowledgeSource): """A knowledge source that stores and queries plain text content using embeddings.""" content: str = Field(...) - collection_name: Optional[str] = Field(default=None) + collection_name: str | None = Field(default=None) - def model_post_init(self, _): + def model_post_init(self, _) -> None: """Post-initialization method to validate content.""" self.validate_content() - def validate_content(self): + def validate_content(self) -> None: """Validate string content.""" if not isinstance(self.content, str): - raise ValueError("StringKnowledgeSource only accepts string content") + msg = "StringKnowledgeSource only accepts string content" + raise ValueError(msg) def add(self) -> None: """Add string content to the knowledge source, chunk it, compute embeddings, and save them.""" @@ -26,7 +26,7 @@ class StringKnowledgeSource(BaseKnowledgeSource): self.chunks.extend(new_chunks) self._save_documents() - def _chunk_text(self, text: str) -> List[str]: + def _chunk_text(self, text: str) -> list[str]: """Utility method to split text into chunks.""" return [ text[i : i + self.chunk_size] diff --git a/src/crewai/knowledge/source/text_file_knowledge_source.py b/src/crewai/knowledge/source/text_file_knowledge_source.py index ddb1f2516..21e9ea058 100644 --- a/src/crewai/knowledge/source/text_file_knowledge_source.py +++ b/src/crewai/knowledge/source/text_file_knowledge_source.py @@ -1,5 +1,4 @@ from pathlib import Path -from typing import Dict, List from crewai.knowledge.source.base_file_knowledge_source import BaseFileKnowledgeSource @@ -7,26 +6,25 @@ from crewai.knowledge.source.base_file_knowledge_source import BaseFileKnowledge class TextFileKnowledgeSource(BaseFileKnowledgeSource): """A knowledge source that stores and queries text file content using embeddings.""" - def load_content(self) -> Dict[Path, str]: + def load_content(self) -> dict[Path, str]: """Load and preprocess text file content.""" content = {} for path in self.safe_file_paths: path = self.convert_to_path(path) - with open(path, "r", encoding="utf-8") as f: + with open(path, encoding="utf-8") as f: content[path] = f.read() return content def add(self) -> None: - """ - Add text file content to the knowledge source, chunk it, compute embeddings, + """Add text file content to the knowledge source, chunk it, compute embeddings, and save the embeddings. """ - for _, text in self.content.items(): + for text in self.content.values(): new_chunks = self._chunk_text(text) self.chunks.extend(new_chunks) self._save_documents() - def _chunk_text(self, text: str) -> List[str]: + def _chunk_text(self, text: str) -> list[str]: """Utility method to split text into chunks.""" return [ text[i : i + self.chunk_size] diff --git a/src/crewai/knowledge/storage/base_knowledge_storage.py b/src/crewai/knowledge/storage/base_knowledge_storage.py index d4887e85b..b0e18504f 100644 --- a/src/crewai/knowledge/storage/base_knowledge_storage.py +++ b/src/crewai/knowledge/storage/base_knowledge_storage.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from typing import Any, Dict, List, Optional +from typing import Any class BaseKnowledgeStorage(ABC): @@ -8,22 +8,19 @@ class BaseKnowledgeStorage(ABC): @abstractmethod def search( self, - query: List[str], + query: list[str], limit: int = 3, - filter: Optional[dict] = None, + filter: dict | None = None, score_threshold: float = 0.35, - ) -> List[Dict[str, Any]]: + ) -> list[dict[str, Any]]: """Search for documents in the knowledge base.""" - pass @abstractmethod def save( - self, documents: List[str], metadata: Dict[str, Any] | List[Dict[str, Any]] + self, documents: list[str], metadata: dict[str, Any] | list[dict[str, Any]], ) -> None: """Save documents to the knowledge base.""" - pass @abstractmethod def reset(self) -> None: """Reset the knowledge base.""" - pass diff --git a/src/crewai/knowledge/storage/knowledge_storage.py b/src/crewai/knowledge/storage/knowledge_storage.py index d49cc9876..1fb3f6d76 100644 --- a/src/crewai/knowledge/storage/knowledge_storage.py +++ b/src/crewai/knowledge/storage/knowledge_storage.py @@ -4,12 +4,11 @@ import io import logging import os import shutil -from typing import Any, Dict, List, Optional, Union +from typing import TYPE_CHECKING, Any import chromadb import chromadb.errors from chromadb.api import ClientAPI -from chromadb.api.types import OneOrMany from chromadb.config import Settings from crewai.knowledge.storage.base_knowledge_storage import BaseKnowledgeStorage @@ -19,6 +18,9 @@ from crewai.utilities.constants import KNOWLEDGE_DIRECTORY from crewai.utilities.logger import Logger from crewai.utilities.paths import db_storage_path +if TYPE_CHECKING: + from chromadb.api.types import OneOrMany + @contextlib.contextmanager def suppress_logging( @@ -38,30 +40,29 @@ def suppress_logging( class KnowledgeStorage(BaseKnowledgeStorage): - """ - Extends Storage to handle embeddings for memory entries, improving + """Extends Storage to handle embeddings for memory entries, improving search efficiency. """ - collection: Optional[chromadb.Collection] = None - collection_name: Optional[str] = "knowledge" - app: Optional[ClientAPI] = None + collection: chromadb.Collection | None = None + collection_name: str | None = "knowledge" + app: ClientAPI | None = None def __init__( self, - embedder: Optional[Dict[str, Any]] = None, - collection_name: Optional[str] = None, - ): + embedder: dict[str, Any] | None = None, + collection_name: str | None = None, + ) -> None: self.collection_name = collection_name self._set_embedder_config(embedder) def search( self, - query: List[str], + query: list[str], limit: int = 3, - filter: Optional[dict] = None, + filter: dict | None = None, score_threshold: float = 0.35, - ) -> List[Dict[str, Any]]: + ) -> list[dict[str, Any]]: with suppress_logging(): if self.collection: fetched = self.collection.query( @@ -80,10 +81,10 @@ class KnowledgeStorage(BaseKnowledgeStorage): if result["score"] >= score_threshold: results.append(result) return results - else: - raise Exception("Collection not initialized") + msg = "Collection not initialized" + raise Exception(msg) - def initialize_knowledge_storage(self): + def initialize_knowledge_storage(self) -> None: base_path = os.path.join(db_storage_path(), "knowledge") chroma_client = chromadb.PersistentClient( path=base_path, @@ -104,11 +105,13 @@ class KnowledgeStorage(BaseKnowledgeStorage): embedding_function=self.embedder, ) else: - raise Exception("Vector Database Client not initialized") + msg = "Vector Database Client not initialized" + raise Exception(msg) except Exception: - raise Exception("Failed to create or get collection") + msg = "Failed to create or get collection" + raise Exception(msg) - def reset(self): + def reset(self) -> None: base_path = os.path.join(db_storage_path(), KNOWLEDGE_DIRECTORY) if not self.app: self.app = chromadb.PersistentClient( @@ -123,11 +126,12 @@ class KnowledgeStorage(BaseKnowledgeStorage): def save( self, - documents: List[str], - metadata: Optional[Union[Dict[str, Any], List[Dict[str, Any]]]] = None, - ): + documents: list[str], + metadata: dict[str, Any] | list[dict[str, Any]] | None = None, + ) -> None: if not self.collection: - raise Exception("Collection not initialized") + msg = "Collection not initialized" + raise Exception(msg) try: # Create a dictionary to store unique documents @@ -156,7 +160,7 @@ class KnowledgeStorage(BaseKnowledgeStorage): filtered_ids.append(doc_id) # If we have no metadata at all, set it to None - final_metadata: Optional[OneOrMany[chromadb.Metadata]] = ( + final_metadata: OneOrMany[chromadb.Metadata] | None = ( None if all(m is None for m in filtered_metadata) else filtered_metadata ) @@ -171,10 +175,13 @@ class KnowledgeStorage(BaseKnowledgeStorage): "Embedding dimension mismatch. This usually happens when mixing different embedding models. Try resetting the collection using `crewai reset-memories -a`", "red", ) - raise ValueError( + msg = ( "Embedding dimension mismatch. Make sure you're using the same embedding model " "across all operations with this collection." "Try resetting the collection using `crewai reset-memories -a`" + ) + raise ValueError( + msg, ) from e except Exception as e: Logger(verbose=True).log("error", f"Failed to upsert documents: {e}", "red") @@ -186,15 +193,16 @@ class KnowledgeStorage(BaseKnowledgeStorage): ) return OpenAIEmbeddingFunction( - api_key=os.getenv("OPENAI_API_KEY"), model_name="text-embedding-3-small" + api_key=os.getenv("OPENAI_API_KEY"), model_name="text-embedding-3-small", ) - def _set_embedder_config(self, embedder: Optional[Dict[str, Any]] = None) -> None: + def _set_embedder_config(self, embedder: dict[str, Any] | None = None) -> None: """Set the embedding configuration for the knowledge storage. Args: embedder_config (Optional[Dict[str, Any]]): Configuration dictionary for the embedder. If None or empty, defaults to the default embedding function. + """ self.embedder = ( EmbeddingConfigurator().configure_embedder(embedder) diff --git a/src/crewai/knowledge/utils/knowledge_utils.py b/src/crewai/knowledge/utils/knowledge_utils.py index bdd8b9a4e..b8b76b8b2 100644 --- a/src/crewai/knowledge/utils/knowledge_utils.py +++ b/src/crewai/knowledge/utils/knowledge_utils.py @@ -1,7 +1,7 @@ -from typing import Any, Dict, List +from typing import Any -def extract_knowledge_context(knowledge_snippets: List[Dict[str, Any]]) -> str: +def extract_knowledge_context(knowledge_snippets: list[dict[str, Any]]) -> str: """Extract knowledge from the task prompt.""" valid_snippets = [ result["context"] diff --git a/src/crewai/lite_agent.py b/src/crewai/lite_agent.py index 4cb46c1f0..6ad330ad5 100644 --- a/src/crewai/lite_agent.py +++ b/src/crewai/lite_agent.py @@ -1,7 +1,7 @@ import asyncio import uuid -from datetime import datetime -from typing import Any, Callable, Dict, List, Optional, Type, Union, cast +from collections.abc import Callable +from typing import Any, cast from pydantic import BaseModel, Field, InstanceOf, PrivateAttr, model_validator @@ -35,7 +35,7 @@ from crewai.utilities.agent_utils import ( render_text_description_and_args, show_agent_logs, ) -from crewai.utilities.converter import convert_to_model, generate_model_description +from crewai.utilities.converter import generate_model_description from crewai.utilities.events.agent_events import ( LiteAgentExecutionCompletedEvent, LiteAgentExecutionErrorEvent, @@ -60,15 +60,15 @@ class LiteAgentOutput(BaseModel): model_config = {"arbitrary_types_allowed": True} raw: str = Field(description="Raw output of the agent", default="") - pydantic: Optional[BaseModel] = Field( - description="Pydantic output of the agent", default=None + pydantic: BaseModel | None = Field( + description="Pydantic output of the agent", default=None, ) agent_role: str = Field(description="Role of the agent that produced this output") - usage_metrics: Optional[Dict[str, Any]] = Field( - description="Token usage metrics for this execution", default=None + usage_metrics: dict[str, Any] | None = Field( + description="Token usage metrics for this execution", default=None, ) - def to_dict(self) -> Dict[str, Any]: + def to_dict(self) -> dict[str, Any]: """Convert pydantic_output to a dictionary.""" if self.pydantic: return self.pydantic.model_dump() @@ -82,8 +82,7 @@ class LiteAgentOutput(BaseModel): class LiteAgent(FlowTrackable, BaseModel): - """ - A lightweight agent that can process messages and use tools. + """A lightweight agent that can process messages and use tools. This agent is simpler than the full Agent class, focusing on direct execution rather than task delegation. It's designed to be used for simple interactions @@ -99,6 +98,7 @@ class LiteAgent(FlowTrackable, BaseModel): max_iterations: Maximum number of iterations for tool usage. max_execution_time: Maximum execution time in seconds. response_format: Optional Pydantic model for structured output. + """ model_config = {"arbitrary_types_allowed": True} @@ -107,19 +107,19 @@ class LiteAgent(FlowTrackable, BaseModel): role: str = Field(description="Role of the agent") goal: str = Field(description="Goal of the agent") backstory: str = Field(description="Backstory of the agent") - llm: Optional[Union[str, InstanceOf[LLM], Any]] = Field( - default=None, description="Language model that will run the agent" + llm: str | InstanceOf[LLM] | Any | None = Field( + default=None, description="Language model that will run the agent", ) - tools: List[BaseTool] = Field( - default_factory=list, description="Tools at agent's disposal" + tools: list[BaseTool] = Field( + default_factory=list, description="Tools at agent's disposal", ) # Execution Control Properties max_iterations: int = Field( - default=15, description="Maximum number of iterations for tool usage" + default=15, description="Maximum number of iterations for tool usage", ) - max_execution_time: Optional[int] = Field( - default=None, description="Maximum execution time in seconds" + max_execution_time: int | None = Field( + default=None, description="Maximum execution time in seconds", ) respect_context_window: bool = Field( default=True, @@ -129,38 +129,38 @@ class LiteAgent(FlowTrackable, BaseModel): default=True, description="Whether to use stop words to prevent the LLM from using tools", ) - request_within_rpm_limit: Optional[Callable[[], bool]] = Field( + request_within_rpm_limit: Callable[[], bool] | None = Field( default=None, description="Callback to check if the request is within the RPM limit", ) i18n: I18N = Field(default=I18N(), description="Internationalization settings.") # Output and Formatting Properties - response_format: Optional[Type[BaseModel]] = Field( - default=None, description="Pydantic model for structured output" + response_format: type[BaseModel] | None = Field( + default=None, description="Pydantic model for structured output", ) verbose: bool = Field( - default=False, description="Whether to print execution details" + default=False, description="Whether to print execution details", ) - callbacks: List[Callable] = Field( - default=[], description="Callbacks to be used for the agent" + callbacks: list[Callable] = Field( + default=[], description="Callbacks to be used for the agent", ) # State and Results - tools_results: List[Dict[str, Any]] = Field( - default=[], description="Results of the tools used by the agent." + tools_results: list[dict[str, Any]] = Field( + default=[], description="Results of the tools used by the agent.", ) # Reference of Agent - original_agent: Optional[BaseAgent] = Field( - default=None, description="Reference to the agent that created this LiteAgent" + original_agent: BaseAgent | None = Field( + default=None, description="Reference to the agent that created this LiteAgent", ) # Private Attributes - _parsed_tools: List[CrewStructuredTool] = PrivateAttr(default_factory=list) + _parsed_tools: list[CrewStructuredTool] = PrivateAttr(default_factory=list) _token_process: TokenProcess = PrivateAttr(default_factory=TokenProcess) _cache_handler: CacheHandler = PrivateAttr(default_factory=CacheHandler) _key: str = PrivateAttr(default_factory=lambda: str(uuid.uuid4())) - _messages: List[Dict[str, str]] = PrivateAttr(default_factory=list) + _messages: list[dict[str, str]] = PrivateAttr(default_factory=list) _iterations: int = PrivateAttr(default=0) _printer: Printer = PrivateAttr(default_factory=Printer) @@ -169,7 +169,8 @@ class LiteAgent(FlowTrackable, BaseModel): """Set up the LLM and other components after initialization.""" self.llm = create_llm(self.llm) if not isinstance(self.llm, LLM): - raise ValueError("Unable to create LLM instance") + msg = "Unable to create LLM instance" + raise ValueError(msg) # Initialize callbacks token_callback = TokenCalcHandler(token_cost_process=self._token_process) @@ -194,9 +195,8 @@ class LiteAgent(FlowTrackable, BaseModel): """Return the original role for compatibility with tool interfaces.""" return self.role - def kickoff(self, messages: Union[str, List[Dict[str, str]]]) -> LiteAgentOutput: - """ - Execute the agent with the given messages. + def kickoff(self, messages: str | list[dict[str, str]]) -> LiteAgentOutput: + """Execute the agent with the given messages. Args: messages: Either a string query or a list of message dictionaries. @@ -205,6 +205,7 @@ class LiteAgent(FlowTrackable, BaseModel): Returns: LiteAgentOutput: The result of the agent execution. + """ # Create agent info for event emission agent_info = { @@ -235,18 +236,18 @@ class LiteAgent(FlowTrackable, BaseModel): # Execute the agent using invoke loop agent_finish = self._invoke_loop() - formatted_result: Optional[BaseModel] = None + formatted_result: BaseModel | None = None if self.response_format: try: # Cast to BaseModel to ensure type safety result = self.response_format.model_validate_json( - agent_finish.output + agent_finish.output, ) if isinstance(result, BaseModel): formatted_result = result except Exception as e: self._printer.print( - content=f"Failed to parse output into response format: {str(e)}", + content=f"Failed to parse output into response format: {e!s}", color="yellow", ) @@ -286,13 +287,12 @@ class LiteAgent(FlowTrackable, BaseModel): error=str(e), ), ) - raise e + raise async def kickoff_async( - self, messages: Union[str, List[Dict[str, str]]] + self, messages: str | list[dict[str, str]], ) -> LiteAgentOutput: - """ - Execute the agent asynchronously with the given messages. + """Execute the agent asynchronously with the given messages. Args: messages: Either a string query or a list of message dictionaries. @@ -301,6 +301,7 @@ class LiteAgent(FlowTrackable, BaseModel): Returns: LiteAgentOutput: The result of the agent execution. + """ return await asyncio.to_thread(self.kickoff, messages) @@ -319,7 +320,7 @@ class LiteAgent(FlowTrackable, BaseModel): else: # Use the prompt template for agents without tools base_prompt = self.i18n.slice( - "lite_agent_system_prompt_without_tools" + "lite_agent_system_prompt_without_tools", ).format( role=self.role, backstory=self.backstory, @@ -330,14 +331,14 @@ class LiteAgent(FlowTrackable, BaseModel): if self.response_format: schema = generate_model_description(self.response_format) base_prompt += self.i18n.slice("lite_agent_response_format").format( - response_format=schema + response_format=schema, ) return base_prompt def _format_messages( - self, messages: Union[str, List[Dict[str, str]]] - ) -> List[Dict[str, str]]: + self, messages: str | list[dict[str, str]], + ) -> list[dict[str, str]]: """Format messages for the LLM.""" if isinstance(messages, str): messages = [{"role": "user", "content": messages}] @@ -353,11 +354,11 @@ class LiteAgent(FlowTrackable, BaseModel): return formatted_messages def _invoke_loop(self) -> AgentFinish: - """ - Run the agent's thought process until it reaches a conclusion or max iterations. + """Run the agent's thought process until it reaches a conclusion or max iterations. Returns: AgentFinish: The final result of the agent execution. + """ # Execute the agent loop formatted_answer = None @@ -369,7 +370,7 @@ class LiteAgent(FlowTrackable, BaseModel): printer=self._printer, i18n=self.i18n, messages=self._messages, - llm=cast(LLM, self.llm), + llm=cast("LLM", self.llm), callbacks=self._callbacks, ) @@ -387,7 +388,7 @@ class LiteAgent(FlowTrackable, BaseModel): try: answer = get_llm_response( - llm=cast(LLM, self.llm), + llm=cast("LLM", self.llm), messages=self._messages, callbacks=self._callbacks, printer=self._printer, @@ -407,7 +408,7 @@ class LiteAgent(FlowTrackable, BaseModel): self, event=LLMCallFailedEvent(error=str(e)), ) - raise e + raise formatted_answer = process_llm_response(answer, self.use_stop_words) @@ -421,8 +422,8 @@ class LiteAgent(FlowTrackable, BaseModel): agent_role=self.role, agent=self.original_agent, ) - except Exception as e: - raise e + except Exception: + raise formatted_answer = handle_agent_action_core( formatted_answer=formatted_answer, @@ -443,20 +444,19 @@ class LiteAgent(FlowTrackable, BaseModel): except Exception as e: if e.__class__.__module__.startswith("litellm"): # Do not retry on litellm errors - raise e + raise if is_context_length_exceeded(e): handle_context_length( respect_context_window=self.respect_context_window, printer=self._printer, messages=self._messages, - llm=cast(LLM, self.llm), + llm=cast("LLM", self.llm), callbacks=self._callbacks, i18n=self.i18n, ) continue - else: - handle_unknown_error(self._printer, e) - raise e + handle_unknown_error(self._printer, e) + raise finally: self._iterations += 1 @@ -465,7 +465,7 @@ class LiteAgent(FlowTrackable, BaseModel): self._show_logs(formatted_answer) return formatted_answer - def _show_logs(self, formatted_answer: Union[AgentAction, AgentFinish]): + def _show_logs(self, formatted_answer: AgentAction | AgentFinish) -> None: """Show logs for the agent's execution.""" show_agent_logs( printer=self._printer, diff --git a/src/crewai/llm.py b/src/crewai/llm.py index c8c456297..e09be9a96 100644 --- a/src/crewai/llm.py +++ b/src/crewai/llm.py @@ -6,17 +6,10 @@ import threading import warnings from collections import defaultdict from contextlib import contextmanager -from types import SimpleNamespace from typing import ( Any, - DefaultDict, - Dict, - List, Literal, - Optional, - Type, TypedDict, - Union, cast, ) @@ -31,7 +24,6 @@ from crewai.utilities.events.llm_events import ( LLMCallType, LLMStreamChunkEvent, ) -from crewai.utilities.events.tool_usage_events import ToolExecutionErrorEvent with warnings.catch_warnings(): warnings.simplefilter("ignore", UserWarning) @@ -55,7 +47,7 @@ load_dotenv() class FilteredStream: - def __init__(self, original_stream): + def __init__(self, original_stream) -> None: self._original_stream = original_stream self._lock = threading.Lock() @@ -210,7 +202,7 @@ def suppress_warnings(): with warnings.catch_warnings(): warnings.filterwarnings("ignore") warnings.filterwarnings( - "ignore", message="open_text is deprecated*", category=DeprecationWarning + "ignore", message="open_text is deprecated*", category=DeprecationWarning, ) # Redirect stdout and stderr @@ -226,14 +218,14 @@ def suppress_warnings(): class Delta(TypedDict): - content: Optional[str] - role: Optional[str] + content: str | None + role: str | None class StreamingChoices(TypedDict): delta: Delta index: int - finish_reason: Optional[str] + finish_reason: str | None class FunctionArgs(BaseModel): @@ -249,29 +241,31 @@ class LLM(BaseLLM): def __init__( self, model: str, - timeout: Optional[Union[float, int]] = None, - temperature: Optional[float] = None, - top_p: Optional[float] = None, - n: Optional[int] = None, - stop: Optional[Union[str, List[str]]] = None, - max_completion_tokens: Optional[int] = None, - max_tokens: Optional[int] = None, - presence_penalty: Optional[float] = None, - frequency_penalty: Optional[float] = None, - logit_bias: Optional[Dict[int, float]] = None, - response_format: Optional[Type[BaseModel]] = None, - seed: Optional[int] = None, - logprobs: Optional[int] = None, - top_logprobs: Optional[int] = None, - base_url: Optional[str] = None, - api_base: Optional[str] = None, - api_version: Optional[str] = None, - api_key: Optional[str] = None, - callbacks: List[Any] = [], - reasoning_effort: Optional[Literal["none", "low", "medium", "high"]] = None, + timeout: float | None = None, + temperature: float | None = None, + top_p: float | None = None, + n: int | None = None, + stop: str | list[str] | None = None, + max_completion_tokens: int | None = None, + max_tokens: int | None = None, + presence_penalty: float | None = None, + frequency_penalty: float | None = None, + logit_bias: dict[int, float] | None = None, + response_format: type[BaseModel] | None = None, + seed: int | None = None, + logprobs: int | None = None, + top_logprobs: int | None = None, + base_url: str | None = None, + api_base: str | None = None, + api_version: str | None = None, + api_key: str | None = None, + callbacks: list[Any] | None = None, + reasoning_effort: Literal["none", "low", "medium", "high"] | None = None, stream: bool = False, **kwargs, - ): + ) -> None: + if callbacks is None: + callbacks = [] self.model = model self.timeout = timeout self.temperature = temperature @@ -301,7 +295,7 @@ class LLM(BaseLLM): # Normalize self.stop to always be a List[str] if stop is None: - self.stop: List[str] = [] + self.stop: list[str] = [] elif isinstance(stop, str): self.stop = [stop] else: @@ -318,15 +312,16 @@ class LLM(BaseLLM): Returns: bool: True if the model is from Anthropic, False otherwise. + """ ANTHROPIC_PREFIXES = ("anthropic/", "claude-", "claude/") return any(prefix in model.lower() for prefix in ANTHROPIC_PREFIXES) def _prepare_completion_params( self, - messages: Union[str, List[Dict[str, str]]], - tools: Optional[List[dict]] = None, - ) -> Dict[str, Any]: + messages: str | list[dict[str, str]], + tools: list[dict] | None = None, + ) -> dict[str, Any]: """Prepare parameters for the completion call. Args: @@ -337,6 +332,7 @@ class LLM(BaseLLM): Returns: Dict[str, Any]: Parameters for the completion call + """ # --- 1) Format messages according to provider requirements if isinstance(messages, str): @@ -375,9 +371,9 @@ class LLM(BaseLLM): def _handle_streaming_response( self, - params: Dict[str, Any], - callbacks: Optional[List[Any]] = None, - available_functions: Optional[Dict[str, Any]] = None, + params: dict[str, Any], + callbacks: list[Any] | None = None, + available_functions: dict[str, Any] | None = None, ) -> str: """Handle a streaming response from the LLM. @@ -391,6 +387,7 @@ class LLM(BaseLLM): Raises: Exception: If no content is received from the streaming response + """ # --- 1) Initialize response tracking full_response = "" @@ -399,8 +396,8 @@ class LLM(BaseLLM): usage_info = None tool_calls = None - accumulated_tool_args: DefaultDict[int, AccumulatedToolArgs] = defaultdict( - AccumulatedToolArgs + accumulated_tool_args: defaultdict[int, AccumulatedToolArgs] = defaultdict( + AccumulatedToolArgs, ) # --- 2) Make sure stream is set to True and include usage metrics @@ -424,16 +421,16 @@ class LLM(BaseLLM): choices = chunk["choices"] elif hasattr(chunk, "choices"): # Check if choices is not a type but an actual attribute with value - if not isinstance(getattr(chunk, "choices"), type): - choices = getattr(chunk, "choices") + if not isinstance(chunk.choices, type): + choices = chunk.choices # Try to extract usage information if available if isinstance(chunk, dict) and "usage" in chunk: usage_info = chunk["usage"] elif hasattr(chunk, "usage"): # Check if usage is not a type but an actual attribute with value - if not isinstance(getattr(chunk, "usage"), type): - usage_info = getattr(chunk, "usage") + if not isinstance(chunk.usage, type): + usage_info = chunk.usage if choices and len(choices) > 0: choice = choices[0] @@ -443,7 +440,7 @@ class LLM(BaseLLM): if isinstance(choice, dict) and "delta" in choice: delta = choice["delta"] elif hasattr(choice, "delta"): - delta = getattr(choice, "delta") + delta = choice.delta # Extract content from delta if delta: @@ -453,7 +450,7 @@ class LLM(BaseLLM): chunk_content = delta["content"] # Handle object format elif hasattr(delta, "content"): - chunk_content = getattr(delta, "content") + chunk_content = delta.content # Handle case where content might be None or empty if chunk_content is None and isinstance(delta, dict): @@ -491,21 +488,21 @@ class LLM(BaseLLM): # --- 4) Fallback to non-streaming if no content received if not full_response.strip() and chunk_count == 0: logging.warning( - "No chunks received in streaming response, falling back to non-streaming" + "No chunks received in streaming response, falling back to non-streaming", ) non_streaming_params = params.copy() non_streaming_params["stream"] = False non_streaming_params.pop( - "stream_options", None + "stream_options", None, ) # Remove stream_options for non-streaming call return self._handle_non_streaming_response( - non_streaming_params, callbacks, available_functions + non_streaming_params, callbacks, available_functions, ) # --- 5) Handle empty response with chunks if not full_response.strip() and chunk_count > 0: logging.warning( - f"Received {chunk_count} chunks but no content was extracted" + f"Received {chunk_count} chunks but no content was extracted", ) if last_chunk is not None: try: @@ -514,8 +511,8 @@ class LLM(BaseLLM): if isinstance(last_chunk, dict) and "choices" in last_chunk: choices = last_chunk["choices"] elif hasattr(last_chunk, "choices"): - if not isinstance(getattr(last_chunk, "choices"), type): - choices = getattr(last_chunk, "choices") + if not isinstance(last_chunk.choices, type): + choices = last_chunk.choices if choices and len(choices) > 0: choice = choices[0] @@ -525,30 +522,31 @@ class LLM(BaseLLM): if isinstance(choice, dict) and "message" in choice: message = choice["message"] elif hasattr(choice, "message"): - message = getattr(choice, "message") + message = choice.message if message: content = None if isinstance(message, dict) and "content" in message: content = message["content"] elif hasattr(message, "content"): - content = getattr(message, "content") + content = message.content if content: full_response = content logging.info( - f"Extracted content from last chunk message: {full_response}" + f"Extracted content from last chunk message: {full_response}", ) except Exception as e: logging.debug(f"Error extracting content from last chunk: {e}") logging.debug( - f"Last chunk format: {type(last_chunk)}, content: {last_chunk}" + f"Last chunk format: {type(last_chunk)}, content: {last_chunk}", ) # --- 6) If still empty, raise an error instead of using a default response if not full_response.strip() and len(accumulated_tool_args) == 0: + msg = "No content received from streaming response. Received empty chunks or failed to extract content." raise Exception( - "No content received from streaming response. Received empty chunks or failed to extract content." + msg, ) # --- 7) Check for tool calls in the final response @@ -559,8 +557,8 @@ class LLM(BaseLLM): if isinstance(last_chunk, dict) and "choices" in last_chunk: choices = last_chunk["choices"] elif hasattr(last_chunk, "choices"): - if not isinstance(getattr(last_chunk, "choices"), type): - choices = getattr(last_chunk, "choices") + if not isinstance(last_chunk.choices, type): + choices = last_chunk.choices if choices and len(choices) > 0: choice = choices[0] @@ -569,13 +567,13 @@ class LLM(BaseLLM): if isinstance(choice, dict) and "message" in choice: message = choice["message"] elif hasattr(choice, "message"): - message = getattr(choice, "message") + message = choice.message if message: if isinstance(message, dict) and "tool_calls" in message: tool_calls = message["tool_calls"] elif hasattr(message, "tool_calls"): - tool_calls = getattr(message, "tool_calls") + tool_calls = message.tool_calls except Exception as e: logging.debug(f"Error checking for tool calls: {e}") # --- 8) If no tool calls or no available functions, return the text response directly @@ -605,9 +603,9 @@ class LLM(BaseLLM): # decide whether to summarize the content or abort based on the respect_context_window flag. raise LLMContextLengthExceededException(str(e)) except Exception as e: - logging.error(f"Error in streaming response: {str(e)}") + logging.exception(f"Error in streaming response: {e!s}") if full_response.strip(): - logging.warning(f"Returning partial response despite error: {str(e)}") + logging.warning(f"Returning partial response despite error: {e!s}") self._handle_emit_call_events(full_response, LLMCallType.LLM_CALL) return full_response @@ -617,13 +615,14 @@ class LLM(BaseLLM): self, event=LLMCallFailedEvent(error=str(e)), ) - raise Exception(f"Failed to get streaming response: {str(e)}") + msg = f"Failed to get streaming response: {e!s}" + raise Exception(msg) def _handle_streaming_tool_calls( self, - tool_calls: List[ChatCompletionDeltaToolCall], - accumulated_tool_args: DefaultDict[int, AccumulatedToolArgs], - available_functions: Optional[Dict[str, Any]] = None, + tool_calls: list[ChatCompletionDeltaToolCall], + accumulated_tool_args: defaultdict[int, AccumulatedToolArgs], + available_functions: dict[str, Any] | None = None, ) -> None | str: for tool_call in tool_calls: current_tool_accumulator = accumulated_tool_args[tool_call.index] @@ -662,9 +661,9 @@ class LLM(BaseLLM): def _handle_streaming_callbacks( self, - callbacks: Optional[List[Any]], - usage_info: Optional[Dict[str, Any]], - last_chunk: Optional[Any], + callbacks: list[Any] | None, + usage_info: dict[str, Any] | None, + last_chunk: Any | None, ) -> None: """Handle callbacks with usage info for streaming responses. @@ -672,6 +671,7 @@ class LLM(BaseLLM): callbacks: Optional list of callback functions usage_info: Usage information collected during streaming last_chunk: The last chunk received from the streaming response + """ if callbacks and len(callbacks) > 0: for callback in callbacks: @@ -688,9 +688,9 @@ class LLM(BaseLLM): usage_info = last_chunk["usage"] elif hasattr(last_chunk, "usage"): if not isinstance( - getattr(last_chunk, "usage"), type + last_chunk.usage, type, ): - usage_info = getattr(last_chunk, "usage") + usage_info = last_chunk.usage except Exception as e: logging.debug(f"Error extracting usage info: {e}") @@ -704,9 +704,9 @@ class LLM(BaseLLM): def _handle_non_streaming_response( self, - params: Dict[str, Any], - callbacks: Optional[List[Any]] = None, - available_functions: Optional[Dict[str, Any]] = None, + params: dict[str, Any], + callbacks: list[Any] | None = None, + available_functions: dict[str, Any] | None = None, ) -> str: """Handle a non-streaming response from the LLM. @@ -717,6 +717,7 @@ class LLM(BaseLLM): Returns: str: The response text + """ # --- 1) Make the completion call try: @@ -731,7 +732,7 @@ class LLM(BaseLLM): raise LLMContextLengthExceededException(str(e)) # --- 2) Extract response message and content - response_message = cast(Choices, cast(ModelResponse, response).choices)[ + response_message = cast("Choices", cast("ModelResponse", response).choices)[ 0 ].message text_response = response_message.content or "" @@ -768,9 +769,9 @@ class LLM(BaseLLM): def _handle_tool_call( self, - tool_calls: List[Any], - available_functions: Optional[Dict[str, Any]] = None, - ) -> Optional[str]: + tool_calls: list[Any], + available_functions: dict[str, Any] | None = None, + ) -> str | None: """Handle a tool call from the LLM. Args: @@ -779,6 +780,7 @@ class LLM(BaseLLM): Returns: Optional[str]: The result of the tool call, or None if no tool call was made + """ # --- 1) Validate tool calls and available functions if not tool_calls or not available_functions: @@ -805,23 +807,23 @@ class LLM(BaseLLM): except Exception as e: # --- 3.4) Handle execution errors fn = available_functions.get( - function_name, lambda: None + function_name, lambda: None, ) # Ensure fn is always a callable - logging.error(f"Error executing function '{function_name}': {e}") + logging.exception(f"Error executing function '{function_name}': {e}") assert hasattr(crewai_event_bus, "emit") crewai_event_bus.emit( self, - event=LLMCallFailedEvent(error=f"Tool execution error: {str(e)}"), + event=LLMCallFailedEvent(error=f"Tool execution error: {e!s}"), ) return None def call( self, - messages: Union[str, List[Dict[str, str]]], - tools: Optional[List[dict]] = None, - callbacks: Optional[List[Any]] = None, - available_functions: Optional[Dict[str, Any]] = None, - ) -> Union[str, Any]: + messages: str | list[dict[str, str]], + tools: list[dict] | None = None, + callbacks: list[Any] | None = None, + available_functions: dict[str, Any] | None = None, + ) -> str | Any: """High-level LLM call method. Args: @@ -844,6 +846,7 @@ class LLM(BaseLLM): TypeError: If messages format is invalid ValueError: If response format is not supported LLMContextLengthExceededException: If input exceeds model's context limit + """ # --- 1) Emit call started event assert hasattr(crewai_event_bus, "emit") @@ -882,12 +885,11 @@ class LLM(BaseLLM): # --- 7) Make the completion call and handle response if self.stream: return self._handle_streaming_response( - params, callbacks, available_functions - ) - else: - return self._handle_non_streaming_response( - params, callbacks, available_functions + params, callbacks, available_functions, ) + return self._handle_non_streaming_response( + params, callbacks, available_functions, + ) except LLMContextLengthExceededException: # Re-raise LLMContextLengthExceededException as it should be handled @@ -900,15 +902,16 @@ class LLM(BaseLLM): self, event=LLMCallFailedEvent(error=str(e)), ) - logging.error(f"LiteLLM call failed: {str(e)}") + logging.exception(f"LiteLLM call failed: {e!s}") raise - def _handle_emit_call_events(self, response: Any, call_type: LLMCallType): + def _handle_emit_call_events(self, response: Any, call_type: LLMCallType) -> None: """Handle the events for the LLM call. Args: response (str): The response from the LLM call. call_type (str): The type of call, either "tool_call" or "llm_call". + """ assert hasattr(crewai_event_bus, "emit") crewai_event_bus.emit( @@ -917,8 +920,8 @@ class LLM(BaseLLM): ) def _format_messages_for_provider( - self, messages: List[Dict[str, str]] - ) -> List[Dict[str, str]]: + self, messages: list[dict[str, str]], + ) -> list[dict[str, str]]: """Format messages according to provider requirements. Args: @@ -931,15 +934,18 @@ class LLM(BaseLLM): Raises: TypeError: If messages is None or contains invalid message format. + """ if messages is None: - raise TypeError("Messages cannot be None") + msg = "Messages cannot be None" + raise TypeError(msg) # Validate message format first for msg in messages: if not isinstance(msg, dict) or "role" not in msg or "content" not in msg: + msg = "Invalid message format. Each message must be a dict with 'role' and 'content' keys" raise TypeError( - "Invalid message format. Each message must be a dict with 'role' and 'content' keys" + msg, ) # Handle O1 models specially @@ -949,7 +955,7 @@ class LLM(BaseLLM): # Convert system messages to assistant messages if msg["role"] == "system": formatted_messages.append( - {"role": "assistant", "content": msg["content"]} + {"role": "assistant", "content": msg["content"]}, ) else: formatted_messages.append(msg) @@ -977,9 +983,8 @@ class LLM(BaseLLM): return messages - def _get_custom_llm_provider(self) -> Optional[str]: - """ - Derives the custom_llm_provider from the model string. + def _get_custom_llm_provider(self) -> str | None: + """Derives the custom_llm_provider from the model string. - For example, if the model is "openrouter/deepseek/deepseek-chat", returns "openrouter". - If the model is "gemini/gemini-1.5-pro", returns "gemini". - If there is no '/', defaults to "openai". @@ -989,8 +994,7 @@ class LLM(BaseLLM): return None def _validate_call_params(self) -> None: - """ - Validate parameters before making a call. Currently this only checks if + """Validate parameters before making a call. Currently this only checks if a response_format is provided and whether the model supports it. The custom_llm_provider is dynamically determined from the model: - E.g., "openrouter/deepseek/deepseek-chat" yields "openrouter" @@ -1002,19 +1006,22 @@ class LLM(BaseLLM): model=self.model, custom_llm_provider=provider, ): - raise ValueError( + msg = ( f"The model {self.model} does not support response_format for provider '{provider}'. " "Please remove response_format or use a supported model." ) + raise ValueError( + msg, + ) def supports_function_calling(self) -> bool: try: provider = self._get_custom_llm_provider() return litellm.utils.supports_function_calling( - self.model, custom_llm_provider=provider + self.model, custom_llm_provider=provider, ) except Exception as e: - logging.error(f"Failed to check function calling support: {str(e)}") + logging.exception(f"Failed to check function calling support: {e!s}") return False def supports_stop_words(self) -> bool: @@ -1022,16 +1029,16 @@ class LLM(BaseLLM): params = get_supported_openai_params(model=self.model) return params is not None and "stop" in params except Exception as e: - logging.error(f"Failed to get supported params: {str(e)}") + logging.exception(f"Failed to get supported params: {e!s}") return False def get_context_window_size(self) -> int: - """ - Returns the context window size, using 75% of the maximum to avoid + """Returns the context window size, using 75% of the maximum to avoid cutting off messages mid-thread. Raises: ValueError: If a model's context window size is outside valid bounds (1024-2097152) + """ if self.context_window_size != 0: return self.context_window_size @@ -1042,21 +1049,21 @@ class LLM(BaseLLM): # Validate all context window sizes for key, value in LLM_CONTEXT_WINDOW_SIZES.items(): if value < MIN_CONTEXT or value > MAX_CONTEXT: + msg = f"Context window for {key} must be between {MIN_CONTEXT} and {MAX_CONTEXT}" raise ValueError( - f"Context window for {key} must be between {MIN_CONTEXT} and {MAX_CONTEXT}" + msg, ) self.context_window_size = int( - DEFAULT_CONTEXT_WINDOW_SIZE * CONTEXT_WINDOW_USAGE_RATIO + DEFAULT_CONTEXT_WINDOW_SIZE * CONTEXT_WINDOW_USAGE_RATIO, ) for key, value in LLM_CONTEXT_WINDOW_SIZES.items(): if self.model.startswith(key): self.context_window_size = int(value * CONTEXT_WINDOW_USAGE_RATIO) return self.context_window_size - def set_callbacks(self, callbacks: List[Any]): - """ - Attempt to keep a single set of callbacks in litellm by removing old + def set_callbacks(self, callbacks: list[Any]) -> None: + """Attempt to keep a single set of callbacks in litellm by removing old duplicates and adding new ones. """ with suppress_warnings(): @@ -1071,9 +1078,8 @@ class LLM(BaseLLM): litellm.callbacks = callbacks - def set_env_callbacks(self): - """ - Sets the success and failure callbacks for the LiteLLM library from environment variables. + def set_env_callbacks(self) -> None: + """Sets the success and failure callbacks for the LiteLLM library from environment variables. This method reads the `LITELLM_SUCCESS_CALLBACKS` and `LITELLM_FAILURE_CALLBACKS` environment variables, which should contain comma-separated lists of callback names. @@ -1089,6 +1095,7 @@ class LLM(BaseLLM): This will set `litellm.success_callback` to ["langfuse", "langsmith"] and `litellm.failure_callback` to ["langfuse"]. + """ with suppress_warnings(): success_callbacks_str = os.environ.get("LITELLM_SUCCESS_CALLBACKS", "") diff --git a/src/crewai/llms/base_llm.py b/src/crewai/llms/base_llm.py index c51e8847d..3e93f48aa 100644 --- a/src/crewai/llms/base_llm.py +++ b/src/crewai/llms/base_llm.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from typing import Any, Callable, Dict, List, Optional, Union +from typing import Any class BaseLLM(ABC): @@ -17,17 +17,18 @@ class BaseLLM(ABC): Attributes: stop (list): A list of stop sequences that the LLM should use to stop generation. This is used by the CrewAgentExecutor and other components. + """ model: str - temperature: Optional[float] = None - stop: Optional[List[str]] = None + temperature: float | None = None + stop: list[str] | None = None def __init__( self, model: str, - temperature: Optional[float] = None, - ): + temperature: float | None = None, + ) -> None: """Initialize the BaseLLM with default attributes. This constructor sets default values for attributes that are expected @@ -43,11 +44,11 @@ class BaseLLM(ABC): @abstractmethod def call( self, - messages: Union[str, List[Dict[str, str]]], - tools: Optional[List[dict]] = None, - callbacks: Optional[List[Any]] = None, - available_functions: Optional[Dict[str, Any]] = None, - ) -> Union[str, Any]: + messages: str | list[dict[str, str]], + tools: list[dict] | None = None, + callbacks: list[Any] | None = None, + available_functions: dict[str, Any] | None = None, + ) -> str | Any: """Call the LLM with the given messages. Args: @@ -70,14 +71,15 @@ class BaseLLM(ABC): ValueError: If the messages format is invalid. TimeoutError: If the LLM request times out. RuntimeError: If the LLM request fails for other reasons. + """ - pass def supports_stop_words(self) -> bool: """Check if the LLM supports stop words. Returns: bool: True if the LLM supports stop words, False otherwise. + """ return True # Default implementation assumes support for stop words @@ -86,6 +88,7 @@ class BaseLLM(ABC): Returns: int: The number of tokens/characters the model can handle. + """ # Default implementation - subclasses should override with model-specific values return 4096 diff --git a/src/crewai/llms/third_party/ai_suite.py b/src/crewai/llms/third_party/ai_suite.py index 78185a081..9ecc8ba29 100644 --- a/src/crewai/llms/third_party/ai_suite.py +++ b/src/crewai/llms/third_party/ai_suite.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List, Optional, Union +from typing import Any import aisuite as ai @@ -6,17 +6,17 @@ from crewai.llms.base_llm import BaseLLM class AISuiteLLM(BaseLLM): - def __init__(self, model: str, temperature: Optional[float] = None, **kwargs): + def __init__(self, model: str, temperature: float | None = None, **kwargs) -> None: super().__init__(model, temperature, **kwargs) self.client = ai.Client() def call( self, - messages: Union[str, List[Dict[str, str]]], - tools: Optional[List[dict]] = None, - callbacks: Optional[List[Any]] = None, - available_functions: Optional[Dict[str, Any]] = None, - ) -> Union[str, Any]: + messages: str | list[dict[str, str]], + tools: list[dict] | None = None, + callbacks: list[Any] | None = None, + available_functions: dict[str, Any] | None = None, + ) -> str | Any: completion_params = self._prepare_completion_params(messages, tools) response = self.client.chat.completions.create(**completion_params) @@ -24,9 +24,9 @@ class AISuiteLLM(BaseLLM): def _prepare_completion_params( self, - messages: Union[str, List[Dict[str, str]]], - tools: Optional[List[dict]] = None, - ) -> Dict[str, Any]: + messages: str | list[dict[str, str]], + tools: list[dict] | None = None, + ) -> dict[str, Any]: return { "model": self.model, "messages": messages, diff --git a/src/crewai/memory/contextual/contextual_memory.py b/src/crewai/memory/contextual/contextual_memory.py index c88614800..2e41a3817 100644 --- a/src/crewai/memory/contextual/contextual_memory.py +++ b/src/crewai/memory/contextual/contextual_memory.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, Optional +from typing import Any from crewai.memory import ( EntityMemory, @@ -12,13 +12,13 @@ from crewai.memory import ( class ContextualMemory: def __init__( self, - memory_config: Optional[Dict[str, Any]], + memory_config: dict[str, Any] | None, stm: ShortTermMemory, ltm: LongTermMemory, em: EntityMemory, um: UserMemory, exm: ExternalMemory, - ): + ) -> None: if memory_config is not None: self.memory_provider = memory_config.get("provider") else: @@ -30,8 +30,7 @@ class ContextualMemory: self.exm = exm def build_context_for_task(self, task, context) -> str: - """ - Automatically builds a minimal, highly relevant set of contextual information + """Automatically builds a minimal, highly relevant set of contextual information for a given task. """ query = f"{task.description} {context}".strip() @@ -49,11 +48,9 @@ class ContextualMemory: return "\n".join(filter(None, context)) def _fetch_stm_context(self, query) -> str: - """ - Fetches recent relevant insights from STM related to the task's description and expected_output, + """Fetches recent relevant insights from STM related to the task's description and expected_output, formatted as bullet points. """ - if self.stm is None: return "" @@ -62,16 +59,14 @@ class ContextualMemory: [ f"- {result['memory'] if self.memory_provider == 'mem0' else result['context']}" for result in stm_results - ] + ], ) return f"Recent Insights:\n{formatted_results}" if stm_results else "" - def _fetch_ltm_context(self, task) -> Optional[str]: - """ - Fetches historical data or insights from LTM that are relevant to the task's description and expected_output, + def _fetch_ltm_context(self, task) -> str | None: + """Fetches historical data or insights from LTM that are relevant to the task's description and expected_output, formatted as bullet points. """ - if self.ltm is None: return "" @@ -90,8 +85,7 @@ class ContextualMemory: return f"Historical Data:\n{formatted_results}" if ltm_results else "" def _fetch_entity_context(self, query) -> str: - """ - Fetches relevant entity information from Entity Memory related to the task's description and expected_output, + """Fetches relevant entity information from Entity Memory related to the task's description and expected_output, formatted as bullet points. """ if self.em is None: @@ -102,19 +96,20 @@ class ContextualMemory: [ f"- {result['memory'] if self.memory_provider == 'mem0' else result['context']}" for result in em_results - ] # type: ignore # Invalid index type "str" for "str"; expected type "SupportsIndex | slice" + ], # type: ignore # Invalid index type "str" for "str"; expected type "SupportsIndex | slice" ) return f"Entities:\n{formatted_results}" if em_results else "" def _fetch_user_context(self, query: str) -> str: - """ - Fetches and formats relevant user information from User Memory. + """Fetches and formats relevant user information from User Memory. + Args: query (str): The search query to find relevant user memories. + Returns: str: Formatted user memories as bullet points, or an empty string if none found. - """ + """ if self.um is None: return "" @@ -128,12 +123,14 @@ class ContextualMemory: return f"User memories/preferences:\n{formatted_memories}" def _fetch_external_context(self, query: str) -> str: - """ - Fetches and formats relevant information from External Memory. + """Fetches and formats relevant information from External Memory. + Args: query (str): The search query to find relevant information. + Returns: str: Formatted information as bullet points, or an empty string if none found. + """ if self.exm is None: return "" diff --git a/src/crewai/memory/entity/entity_memory.py b/src/crewai/memory/entity/entity_memory.py index 264b64103..b9db97c61 100644 --- a/src/crewai/memory/entity/entity_memory.py +++ b/src/crewai/memory/entity/entity_memory.py @@ -1,4 +1,3 @@ -from typing import Optional from pydantic import PrivateAttr @@ -8,15 +7,14 @@ from crewai.memory.storage.rag_storage import RAGStorage class EntityMemory(Memory): - """ - EntityMemory class for managing structured information about entities + """EntityMemory class for managing structured information about entities and their relationships using SQLite storage. Inherits from the Memory class. """ - _memory_provider: Optional[str] = PrivateAttr() + _memory_provider: str | None = PrivateAttr() - def __init__(self, crew=None, embedder_config=None, storage=None, path=None): + def __init__(self, crew=None, embedder_config=None, storage=None, path=None) -> None: if crew and hasattr(crew, "memory_config") and crew.memory_config is not None: memory_provider = crew.memory_config.get("provider") else: @@ -26,8 +24,9 @@ class EntityMemory(Memory): try: from crewai.memory.storage.mem0_storage import Mem0Storage except ImportError: + msg = "Mem0 is not installed. Please install it with `pip install mem0ai`." raise ImportError( - "Mem0 is not installed. Please install it with `pip install mem0ai`." + msg, ) storage = Mem0Storage(type="entities", crew=crew) else: @@ -63,4 +62,5 @@ class EntityMemory(Memory): try: self.storage.reset() except Exception as e: - raise Exception(f"An error occurred while resetting the entity memory: {e}") + msg = f"An error occurred while resetting the entity memory: {e}" + raise Exception(msg) diff --git a/src/crewai/memory/entity/entity_memory_item.py b/src/crewai/memory/entity/entity_memory_item.py index 7e1ef1c0e..7c24b027b 100644 --- a/src/crewai/memory/entity/entity_memory_item.py +++ b/src/crewai/memory/entity/entity_memory_item.py @@ -5,7 +5,7 @@ class EntityMemoryItem: type: str, description: str, relationships: str, - ): + ) -> None: self.name = name self.type = type self.description = description diff --git a/src/crewai/memory/external/external_memory.py b/src/crewai/memory/external/external_memory.py index be35f513b..e0f99e353 100644 --- a/src/crewai/memory/external/external_memory.py +++ b/src/crewai/memory/external/external_memory.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING, Any, Dict, Optional +from typing import TYPE_CHECKING, Any from crewai.memory.external.external_memory_item import ExternalMemoryItem from crewai.memory.memory import Memory @@ -9,41 +9,44 @@ if TYPE_CHECKING: class ExternalMemory(Memory): - def __init__(self, storage: Optional[Storage] = None, **data: Any): + def __init__(self, storage: Storage | None = None, **data: Any) -> None: super().__init__(storage=storage, **data) @staticmethod - def _configure_mem0(crew: Any, config: Dict[str, Any]) -> "Mem0Storage": + def _configure_mem0(crew: Any, config: dict[str, Any]) -> "Mem0Storage": from crewai.memory.storage.mem0_storage import Mem0Storage return Mem0Storage(type="external", crew=crew, config=config) @staticmethod - def external_supported_storages() -> Dict[str, Any]: + def external_supported_storages() -> dict[str, Any]: return { "mem0": ExternalMemory._configure_mem0, } @staticmethod - def create_storage(crew: Any, embedder_config: Optional[Dict[str, Any]]) -> Storage: + def create_storage(crew: Any, embedder_config: dict[str, Any] | None) -> Storage: if not embedder_config: - raise ValueError("embedder_config is required") + msg = "embedder_config is required" + raise ValueError(msg) if "provider" not in embedder_config: - raise ValueError("embedder_config must include a 'provider' key") + msg = "embedder_config must include a 'provider' key" + raise ValueError(msg) provider = embedder_config["provider"] supported_storages = ExternalMemory.external_supported_storages() if provider not in supported_storages: - raise ValueError(f"Provider {provider} not supported") + msg = f"Provider {provider} not supported" + raise ValueError(msg) return supported_storages[provider](crew, embedder_config.get("config", {})) def save( self, value: Any, - metadata: Optional[Dict[str, Any]] = None, - agent: Optional[str] = None, + metadata: dict[str, Any] | None = None, + agent: str | None = None, ) -> None: """Saves a value into the external storage.""" item = ExternalMemoryItem(value=value, metadata=metadata, agent=agent) diff --git a/src/crewai/memory/external/external_memory_item.py b/src/crewai/memory/external/external_memory_item.py index c97cccd59..ca12d3385 100644 --- a/src/crewai/memory/external/external_memory_item.py +++ b/src/crewai/memory/external/external_memory_item.py @@ -1,13 +1,13 @@ -from typing import Any, Dict, Optional +from typing import Any class ExternalMemoryItem: def __init__( self, value: Any, - metadata: Optional[Dict[str, Any]] = None, - agent: Optional[str] = None, - ): + metadata: dict[str, Any] | None = None, + agent: str | None = None, + ) -> None: self.value = value self.metadata = metadata self.agent = agent diff --git a/src/crewai/memory/long_term/long_term_memory.py b/src/crewai/memory/long_term/long_term_memory.py index 94aac3a97..3c52ddb8a 100644 --- a/src/crewai/memory/long_term/long_term_memory.py +++ b/src/crewai/memory/long_term/long_term_memory.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List +from typing import Any from crewai.memory.long_term.long_term_memory_item import LongTermMemoryItem from crewai.memory.memory import Memory @@ -6,15 +6,14 @@ from crewai.memory.storage.ltm_sqlite_storage import LTMSQLiteStorage class LongTermMemory(Memory): - """ - LongTermMemory class for managing cross runs data related to overall crew's + """LongTermMemory class for managing cross runs data related to overall crew's execution and performance. Inherits from the Memory class and utilizes an instance of a class that adheres to the Storage for data storage, specifically working with LongTermMemoryItem instances. """ - def __init__(self, storage=None, path=None): + def __init__(self, storage=None, path=None) -> None: if not storage: storage = LTMSQLiteStorage(db_path=path) if path else LTMSQLiteStorage() super().__init__(storage=storage) @@ -29,7 +28,7 @@ class LongTermMemory(Memory): datetime=item.datetime, ) - def search(self, task: str, latest_n: int = 3) -> List[Dict[str, Any]]: # type: ignore # signature of "search" incompatible with supertype "Memory" + def search(self, task: str, latest_n: int = 3) -> list[dict[str, Any]]: # type: ignore # signature of "search" incompatible with supertype "Memory" return self.storage.load(task, latest_n) # type: ignore # BUG?: "Storage" has no attribute "load" def reset(self) -> None: diff --git a/src/crewai/memory/long_term/long_term_memory_item.py b/src/crewai/memory/long_term/long_term_memory_item.py index b2164f242..5c8ec7cc0 100644 --- a/src/crewai/memory/long_term/long_term_memory_item.py +++ b/src/crewai/memory/long_term/long_term_memory_item.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, Optional, Union +from typing import Any class LongTermMemoryItem: @@ -8,9 +8,9 @@ class LongTermMemoryItem: task: str, expected_output: str, datetime: str, - quality: Optional[Union[int, float]] = None, - metadata: Optional[Dict[str, Any]] = None, - ): + quality: float | None = None, + metadata: dict[str, Any] | None = None, + ) -> None: self.task = task self.agent = agent self.quality = quality diff --git a/src/crewai/memory/memory.py b/src/crewai/memory/memory.py index 20538a186..1aa89df41 100644 --- a/src/crewai/memory/memory.py +++ b/src/crewai/memory/memory.py @@ -1,26 +1,24 @@ -from typing import Any, Dict, List, Optional +from typing import Any from pydantic import BaseModel class Memory(BaseModel): - """ - Base class for memory, now supporting agent tags and generic metadata. - """ + """Base class for memory, now supporting agent tags and generic metadata.""" - embedder_config: Optional[Dict[str, Any]] = None - crew: Optional[Any] = None + embedder_config: dict[str, Any] | None = None + crew: Any | None = None storage: Any - def __init__(self, storage: Any, **data: Any): + def __init__(self, storage: Any, **data: Any) -> None: super().__init__(storage=storage, **data) def save( self, value: Any, - metadata: Optional[Dict[str, Any]] = None, - agent: Optional[str] = None, + metadata: dict[str, Any] | None = None, + agent: str | None = None, ) -> None: metadata = metadata or {} if agent: @@ -33,9 +31,9 @@ class Memory(BaseModel): query: str, limit: int = 3, score_threshold: float = 0.35, - ) -> List[Any]: + ) -> list[Any]: return self.storage.search( - query=query, limit=limit, score_threshold=score_threshold + query=query, limit=limit, score_threshold=score_threshold, ) def set_crew(self, crew: Any) -> "Memory": diff --git a/src/crewai/memory/short_term/short_term_memory.py b/src/crewai/memory/short_term/short_term_memory.py index b7581f400..6833b5c60 100644 --- a/src/crewai/memory/short_term/short_term_memory.py +++ b/src/crewai/memory/short_term/short_term_memory.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, Optional +from typing import Any from pydantic import PrivateAttr @@ -8,17 +8,16 @@ from crewai.memory.storage.rag_storage import RAGStorage class ShortTermMemory(Memory): - """ - ShortTermMemory class for managing transient data related to immediate tasks + """ShortTermMemory class for managing transient data related to immediate tasks and interactions. Inherits from the Memory class and utilizes an instance of a class that adheres to the Storage for data storage, specifically working with MemoryItem instances. """ - _memory_provider: Optional[str] = PrivateAttr() + _memory_provider: str | None = PrivateAttr() - def __init__(self, crew=None, embedder_config=None, storage=None, path=None): + def __init__(self, crew=None, embedder_config=None, storage=None, path=None) -> None: if crew and hasattr(crew, "memory_config") and crew.memory_config is not None: memory_provider = crew.memory_config.get("provider") else: @@ -28,8 +27,9 @@ class ShortTermMemory(Memory): try: from crewai.memory.storage.mem0_storage import Mem0Storage except ImportError: + msg = "Mem0 is not installed. Please install it with `pip install mem0ai`." raise ImportError( - "Mem0 is not installed. Please install it with `pip install mem0ai`." + msg, ) storage = Mem0Storage(type="short_term", crew=crew) else: @@ -49,8 +49,8 @@ class ShortTermMemory(Memory): def save( self, value: Any, - metadata: Optional[Dict[str, Any]] = None, - agent: Optional[str] = None, + metadata: dict[str, Any] | None = None, + agent: str | None = None, ) -> None: item = ShortTermMemoryItem(data=value, metadata=metadata, agent=agent) if self._memory_provider == "mem0": @@ -65,13 +65,14 @@ class ShortTermMemory(Memory): score_threshold: float = 0.35, ): return self.storage.search( - query=query, limit=limit, score_threshold=score_threshold + query=query, limit=limit, score_threshold=score_threshold, ) # type: ignore # BUG? The reference is to the parent class, but the parent class does not have this parameters def reset(self) -> None: try: self.storage.reset() except Exception as e: + msg = f"An error occurred while resetting the short-term memory: {e}" raise Exception( - f"An error occurred while resetting the short-term memory: {e}" + msg, ) diff --git a/src/crewai/memory/short_term/short_term_memory_item.py b/src/crewai/memory/short_term/short_term_memory_item.py index 83b7f842f..428ef2e1b 100644 --- a/src/crewai/memory/short_term/short_term_memory_item.py +++ b/src/crewai/memory/short_term/short_term_memory_item.py @@ -1,13 +1,13 @@ -from typing import Any, Dict, Optional +from typing import Any class ShortTermMemoryItem: def __init__( self, data: Any, - agent: Optional[str] = None, - metadata: Optional[Dict[str, Any]] = None, - ): + agent: str | None = None, + metadata: dict[str, Any] | None = None, + ) -> None: self.data = data self.agent = agent self.metadata = metadata if metadata is not None else {} diff --git a/src/crewai/memory/storage/base_rag_storage.py b/src/crewai/memory/storage/base_rag_storage.py index 4ab9acb99..97a4af500 100644 --- a/src/crewai/memory/storage/base_rag_storage.py +++ b/src/crewai/memory/storage/base_rag_storage.py @@ -1,11 +1,9 @@ from abc import ABC, abstractmethod -from typing import Any, Dict, List, Optional +from typing import Any class BaseRAGStorage(ABC): - """ - Base class for RAG-based Storage implementations. - """ + """Base class for RAG-based Storage implementations.""" app: Any | None = None @@ -13,9 +11,9 @@ class BaseRAGStorage(ABC): self, type: str, allow_reset: bool = True, - embedder_config: Optional[Dict[str, Any]] = None, + embedder_config: dict[str, Any] | None = None, crew: Any = None, - ): + ) -> None: self.type = type self.allow_reset = allow_reset self.embedder_config = embedder_config @@ -25,52 +23,44 @@ class BaseRAGStorage(ABC): def _initialize_agents(self) -> str: if self.crew: return "_".join( - [self._sanitize_role(agent.role) for agent in self.crew.agents] + [self._sanitize_role(agent.role) for agent in self.crew.agents], ) return "" @abstractmethod def _sanitize_role(self, role: str) -> str: """Sanitizes agent roles to ensure valid directory names.""" - pass @abstractmethod - def save(self, value: Any, metadata: Dict[str, Any]) -> None: + def save(self, value: Any, metadata: dict[str, Any]) -> None: """Save a value with metadata to the storage.""" - pass @abstractmethod def search( self, query: str, limit: int = 3, - filter: Optional[dict] = None, + filter: dict | None = None, score_threshold: float = 0.35, - ) -> List[Any]: + ) -> list[Any]: """Search for entries in the storage.""" - pass @abstractmethod def reset(self) -> None: """Reset the storage.""" - pass @abstractmethod def _generate_embedding( - self, text: str, metadata: Optional[Dict[str, Any]] = None + self, text: str, metadata: dict[str, Any] | None = None, ) -> Any: """Generate an embedding for the given text and metadata.""" - pass @abstractmethod def _initialize_app(self): """Initialize the vector db.""" - pass - def setup_config(self, config: Dict[str, Any]): + def setup_config(self, config: dict[str, Any]) -> None: """Setup the config of the storage.""" - pass - def initialize_client(self): - """Initialize the client of the storage. This should setup the app and the db collection""" - pass + def initialize_client(self) -> None: + """Initialize the client of the storage. This should setup the app and the db collection.""" diff --git a/src/crewai/memory/storage/interface.py b/src/crewai/memory/storage/interface.py index 8bec9a14f..9ca647f8c 100644 --- a/src/crewai/memory/storage/interface.py +++ b/src/crewai/memory/storage/interface.py @@ -1,15 +1,15 @@ -from typing import Any, Dict, List +from typing import Any class Storage: - """Abstract base class defining the storage interface""" + """Abstract base class defining the storage interface.""" - def save(self, value: Any, metadata: Dict[str, Any]) -> None: + def save(self, value: Any, metadata: dict[str, Any]) -> None: pass def search( - self, query: str, limit: int, score_threshold: float - ) -> Dict[str, Any] | List[Any]: + self, query: str, limit: int, score_threshold: float, + ) -> dict[str, Any] | list[Any]: return {} def reset(self) -> None: diff --git a/src/crewai/memory/storage/kickoff_task_outputs_storage.py b/src/crewai/memory/storage/kickoff_task_outputs_storage.py index 2a035833d..5a360c297 100644 --- a/src/crewai/memory/storage/kickoff_task_outputs_storage.py +++ b/src/crewai/memory/storage/kickoff_task_outputs_storage.py @@ -2,7 +2,7 @@ import json import logging import sqlite3 from pathlib import Path -from typing import Any, Dict, List, Optional +from typing import Any from crewai.task import Task from crewai.utilities import Printer @@ -14,12 +14,10 @@ logger = logging.getLogger(__name__) class KickoffTaskOutputsSQLiteStorage: - """ - An updated SQLite storage class for kickoff task outputs storage. - """ + """An updated SQLite storage class for kickoff task outputs storage.""" def __init__( - self, db_path: Optional[str] = None + self, db_path: str | None = None, ) -> None: if db_path is None: # Get the parent directory of the default db path and create our db file there @@ -37,6 +35,7 @@ class KickoffTaskOutputsSQLiteStorage: Raises: DatabaseOperationError: If database initialization fails due to SQLite errors. + """ try: with sqlite3.connect(self.db_path) as conn: @@ -52,22 +51,22 @@ class KickoffTaskOutputsSQLiteStorage: was_replayed BOOLEAN, timestamp DATETIME DEFAULT CURRENT_TIMESTAMP ) - """ + """, ) conn.commit() except sqlite3.Error as e: error_msg = DatabaseError.format_error(DatabaseError.INIT_ERROR, e) - logger.error(error_msg) + logger.exception(error_msg) raise DatabaseOperationError(error_msg, e) def add( self, task: Task, - output: Dict[str, Any], + output: dict[str, Any], task_index: int, was_replayed: bool = False, - inputs: Dict[str, Any] = {}, + inputs: dict[str, Any] | None = None, ) -> None: """Add a new task output record to the database. @@ -80,7 +79,10 @@ class KickoffTaskOutputsSQLiteStorage: Raises: DatabaseOperationError: If saving the task output fails due to SQLite errors. + """ + if inputs is None: + inputs = {} try: with sqlite3.connect(self.db_path) as conn: conn.execute("BEGIN TRANSACTION") @@ -103,7 +105,7 @@ class KickoffTaskOutputsSQLiteStorage: conn.commit() except sqlite3.Error as e: error_msg = DatabaseError.format_error(DatabaseError.SAVE_ERROR, e) - logger.error(error_msg) + logger.exception(error_msg) raise DatabaseOperationError(error_msg, e) def update( @@ -123,6 +125,7 @@ class KickoffTaskOutputsSQLiteStorage: Raises: DatabaseOperationError: If updating the task output fails due to SQLite errors. + """ try: with sqlite3.connect(self.db_path) as conn: @@ -136,7 +139,7 @@ class KickoffTaskOutputsSQLiteStorage: values.append( json.dumps(value, cls=CrewJSONEncoder) if isinstance(value, dict) - else value + else value, ) query = f"UPDATE latest_kickoff_task_outputs SET {', '.join(fields)} WHERE task_index = ?" # nosec @@ -149,10 +152,10 @@ class KickoffTaskOutputsSQLiteStorage: logger.warning(f"No row found with task_index {task_index}. No update performed.") except sqlite3.Error as e: error_msg = DatabaseError.format_error(DatabaseError.UPDATE_ERROR, e) - logger.error(error_msg) + logger.exception(error_msg) raise DatabaseOperationError(error_msg, e) - def load(self) -> List[Dict[str, Any]]: + def load(self) -> list[dict[str, Any]]: """Load all task output records from the database. Returns: @@ -162,6 +165,7 @@ class KickoffTaskOutputsSQLiteStorage: Raises: DatabaseOperationError: If loading task outputs fails due to SQLite errors. + """ try: with sqlite3.connect(self.db_path) as conn: @@ -190,7 +194,7 @@ class KickoffTaskOutputsSQLiteStorage: except sqlite3.Error as e: error_msg = DatabaseError.format_error(DatabaseError.LOAD_ERROR, e) - logger.error(error_msg) + logger.exception(error_msg) raise DatabaseOperationError(error_msg, e) def delete_all(self) -> None: @@ -201,6 +205,7 @@ class KickoffTaskOutputsSQLiteStorage: Raises: DatabaseOperationError: If deleting task outputs fails due to SQLite errors. + """ try: with sqlite3.connect(self.db_path) as conn: @@ -210,5 +215,5 @@ class KickoffTaskOutputsSQLiteStorage: conn.commit() except sqlite3.Error as e: error_msg = DatabaseError.format_error(DatabaseError.DELETE_ERROR, e) - logger.error(error_msg) + logger.exception(error_msg) raise DatabaseOperationError(error_msg, e) diff --git a/src/crewai/memory/storage/ltm_sqlite_storage.py b/src/crewai/memory/storage/ltm_sqlite_storage.py index 35f54e0e7..7b43f9c46 100644 --- a/src/crewai/memory/storage/ltm_sqlite_storage.py +++ b/src/crewai/memory/storage/ltm_sqlite_storage.py @@ -1,19 +1,17 @@ import json import sqlite3 from pathlib import Path -from typing import Any, Dict, List, Optional, Union +from typing import Any from crewai.utilities import Printer from crewai.utilities.paths import db_storage_path class LTMSQLiteStorage: - """ - An updated SQLite storage class for LTM data storage. - """ + """An updated SQLite storage class for LTM data storage.""" def __init__( - self, db_path: Optional[str] = None + self, db_path: str | None = None, ) -> None: if db_path is None: # Get the parent directory of the default db path and create our db file there @@ -24,10 +22,8 @@ class LTMSQLiteStorage: Path(self.db_path).parent.mkdir(parents=True, exist_ok=True) self._initialize_db() - def _initialize_db(self): - """ - Initializes the SQLite database and creates LTM table - """ + def _initialize_db(self) -> None: + """Initializes the SQLite database and creates LTM table.""" try: with sqlite3.connect(self.db_path) as conn: cursor = conn.cursor() @@ -40,7 +36,7 @@ class LTMSQLiteStorage: datetime TEXT, score REAL ) - """ + """, ) conn.commit() @@ -53,9 +49,9 @@ class LTMSQLiteStorage: def save( self, task_description: str, - metadata: Dict[str, Any], + metadata: dict[str, Any], datetime: str, - score: Union[int, float], + score: float, ) -> None: """Saves data to the LTM table with error handling.""" try: @@ -76,8 +72,8 @@ class LTMSQLiteStorage: ) def load( - self, task_description: str, latest_n: int - ) -> Optional[List[Dict[str, Any]]]: + self, task_description: str, latest_n: int, + ) -> list[dict[str, Any]] | None: """Queries the LTM table by task description with error handling.""" try: with sqlite3.connect(self.db_path) as conn: @@ -125,4 +121,3 @@ class LTMSQLiteStorage: content=f"MEMORY ERROR: An error occurred while deleting all rows in LTM: {e}", color="red", ) - return None diff --git a/src/crewai/memory/storage/mem0_storage.py b/src/crewai/memory/storage/mem0_storage.py index 5d601ac1f..51d68dac6 100644 --- a/src/crewai/memory/storage/mem0_storage.py +++ b/src/crewai/memory/storage/mem0_storage.py @@ -1,5 +1,5 @@ import os -from typing import Any, Dict, List +from typing import Any from mem0 import Memory, MemoryClient @@ -7,17 +7,15 @@ from crewai.memory.storage.interface import Storage class Mem0Storage(Storage): - """ - Extends Storage to handle embedding and searching across entities using Mem0. - """ + """Extends Storage to handle embedding and searching across entities using Mem0.""" - def __init__(self, type, crew=None, config=None): + def __init__(self, type, crew=None, config=None) -> None: super().__init__() supported_types = ["user", "short_term", "long_term", "entities", "external"] if type not in supported_types: raise ValueError( f"Invalid type '{type}' for Mem0Storage. Must be one of: " - + ", ".join(supported_types) + + ", ".join(supported_types), ) self.memory_type = type @@ -29,7 +27,8 @@ class Mem0Storage(Storage): # User ID is required for user memory type "user" since it's used as a unique identifier for the user. user_id = self._get_user_id() if type == "user" and not user_id: - raise ValueError("User ID is required for user memory type") + msg = "User ID is required for user memory type" + raise ValueError(msg) # API key in memory config overrides the environment variable config = self._get_config() @@ -42,23 +41,20 @@ class Mem0Storage(Storage): if mem0_api_key: if mem0_org_id and mem0_project_id: self.memory = MemoryClient( - api_key=mem0_api_key, org_id=mem0_org_id, project_id=mem0_project_id + api_key=mem0_api_key, org_id=mem0_org_id, project_id=mem0_project_id, ) else: self.memory = MemoryClient(api_key=mem0_api_key) + elif mem0_local_config and len(mem0_local_config): + self.memory = Memory.from_config(mem0_local_config) else: - if mem0_local_config and len(mem0_local_config): - self.memory = Memory.from_config(mem0_local_config) - else: - self.memory = Memory() + self.memory = Memory() def _sanitize_role(self, role: str) -> str: - """ - Sanitizes agent roles to ensure valid directory names. - """ + """Sanitizes agent roles to ensure valid directory names.""" return role.replace("\n", "").replace(" ", "_").replace("/", "_") - def save(self, value: Any, metadata: Dict[str, Any]) -> None: + def save(self, value: Any, metadata: dict[str, Any]) -> None: user_id = self._get_user_id() agent_name = self._get_agent_name() params = None @@ -97,7 +93,7 @@ class Mem0Storage(Storage): query: str, limit: int = 3, score_threshold: float = 0.35, - ) -> List[Any]: + ) -> list[Any]: params = {"query": query, "limit": limit, "output_format": "v1.1"} if user_id := self._get_user_id(): params["user_id"] = user_id @@ -120,7 +116,7 @@ class Mem0Storage(Storage): # automatically when the crew is created. if isinstance(self.memory, Memory): del params["metadata"], params["output_format"] - + results = self.memory.search(**params) return [r for r in results["results"] if r["score"] >= score_threshold] @@ -133,12 +129,11 @@ class Mem0Storage(Storage): agents = self.crew.agents agents = [self._sanitize_role(agent.role) for agent in agents] - agents = "_".join(agents) - return agents + return "_".join(agents) - def _get_config(self) -> Dict[str, Any]: + def _get_config(self) -> dict[str, Any]: return self.config or getattr(self, "memory_config", {}).get("config", {}) or {} - def reset(self): + def reset(self) -> None: if self.memory: self.memory.reset() diff --git a/src/crewai/memory/storage/rag_storage.py b/src/crewai/memory/storage/rag_storage.py index fd4c77838..94613c160 100644 --- a/src/crewai/memory/storage/rag_storage.py +++ b/src/crewai/memory/storage/rag_storage.py @@ -4,7 +4,7 @@ import logging import os import shutil import uuid -from typing import Any, Dict, List, Optional +from typing import Any from chromadb.api import ClientAPI @@ -32,16 +32,15 @@ def suppress_logging( class RAGStorage(BaseRAGStorage): - """ - Extends Storage to handle embeddings for memory entries, improving + """Extends Storage to handle embeddings for memory entries, improving search efficiency. """ app: ClientAPI | None = None def __init__( - self, type, allow_reset=True, embedder_config=None, crew=None, path=None - ): + self, type, allow_reset=True, embedder_config=None, crew=None, path=None, + ) -> None: super().__init__(type, allow_reset, embedder_config, crew) agents = crew.agents if crew else [] agents = [self._sanitize_role(agent.role) for agent in agents] @@ -55,11 +54,11 @@ class RAGStorage(BaseRAGStorage): self.path = path self._initialize_app() - def _set_embedder_config(self): + def _set_embedder_config(self) -> None: configurator = EmbeddingConfigurator() self.embedder_config = configurator.configure_embedder(self.embedder_config) - def _initialize_app(self): + def _initialize_app(self) -> None: import chromadb from chromadb.config import Settings @@ -73,48 +72,44 @@ class RAGStorage(BaseRAGStorage): try: self.collection = self.app.get_collection( - name=self.type, embedding_function=self.embedder_config + name=self.type, embedding_function=self.embedder_config, ) except Exception: self.collection = self.app.create_collection( - name=self.type, embedding_function=self.embedder_config + name=self.type, embedding_function=self.embedder_config, ) def _sanitize_role(self, role: str) -> str: - """ - Sanitizes agent roles to ensure valid directory names. - """ + """Sanitizes agent roles to ensure valid directory names.""" return role.replace("\n", "").replace(" ", "_").replace("/", "_") def _build_storage_file_name(self, type: str, file_name: str) -> str: - """ - Ensures file name does not exceed max allowed by OS - """ + """Ensures file name does not exceed max allowed by OS.""" base_path = f"{db_storage_path()}/{type}" if len(file_name) > MAX_FILE_NAME_LENGTH: logging.warning( - f"Trimming file name from {len(file_name)} to {MAX_FILE_NAME_LENGTH} characters." + f"Trimming file name from {len(file_name)} to {MAX_FILE_NAME_LENGTH} characters.", ) file_name = file_name[:MAX_FILE_NAME_LENGTH] return f"{base_path}/{file_name}" - def save(self, value: Any, metadata: Dict[str, Any]) -> None: + def save(self, value: Any, metadata: dict[str, Any]) -> None: if not hasattr(self, "app") or not hasattr(self, "collection"): self._initialize_app() try: self._generate_embedding(value, metadata) except Exception as e: - logging.error(f"Error during {self.type} save: {str(e)}") + logging.exception(f"Error during {self.type} save: {e!s}") def search( self, query: str, limit: int = 3, - filter: Optional[dict] = None, + filter: dict | None = None, score_threshold: float = 0.35, - ) -> List[Any]: + ) -> list[Any]: if not hasattr(self, "app"): self._initialize_app() @@ -135,10 +130,10 @@ class RAGStorage(BaseRAGStorage): return results except Exception as e: - logging.error(f"Error during {self.type} search: {str(e)}") + logging.exception(f"Error during {self.type} search: {e!s}") return [] - def _generate_embedding(self, text: str, metadata: Dict[str, Any]) -> None: # type: ignore + def _generate_embedding(self, text: str, metadata: dict[str, Any]) -> None: # type: ignore if not hasattr(self, "app") or not hasattr(self, "collection"): self._initialize_app() @@ -160,8 +155,9 @@ class RAGStorage(BaseRAGStorage): # Ignore this specific error pass else: + msg = f"An error occurred while resetting the {self.type} memory: {e}" raise Exception( - f"An error occurred while resetting the {self.type} memory: {e}" + msg, ) def _create_default_embedding_function(self): @@ -170,5 +166,5 @@ class RAGStorage(BaseRAGStorage): ) return OpenAIEmbeddingFunction( - api_key=os.getenv("OPENAI_API_KEY"), model_name="text-embedding-3-small" + api_key=os.getenv("OPENAI_API_KEY"), model_name="text-embedding-3-small", ) diff --git a/src/crewai/memory/user/user_memory.py b/src/crewai/memory/user/user_memory.py index 1baebee1d..81971990f 100644 --- a/src/crewai/memory/user/user_memory.py +++ b/src/crewai/memory/user/user_memory.py @@ -1,18 +1,17 @@ import warnings -from typing import Any, Dict, Optional +from typing import Any from crewai.memory.memory import Memory class UserMemory(Memory): - """ - UserMemory class for handling user memory storage and retrieval. + """UserMemory class for handling user memory storage and retrieval. Inherits from the Memory class and utilizes an instance of a class that adheres to the Storage for data storage, specifically working with MemoryItem instances. """ - def __init__(self, crew=None): + def __init__(self, crew=None) -> None: warnings.warn( "UserMemory is deprecated and will be removed in a future version. " "Please use ExternalMemory instead.", @@ -22,8 +21,9 @@ class UserMemory(Memory): try: from crewai.memory.storage.mem0_storage import Mem0Storage except ImportError: + msg = "Mem0 is not installed. Please install it with `pip install mem0ai`." raise ImportError( - "Mem0 is not installed. Please install it with `pip install mem0ai`." + msg, ) storage = Mem0Storage(type="user", crew=crew) super().__init__(storage) @@ -31,8 +31,8 @@ class UserMemory(Memory): def save( self, value, - metadata: Optional[Dict[str, Any]] = None, - agent: Optional[str] = None, + metadata: dict[str, Any] | None = None, + agent: str | None = None, ) -> None: # TODO: Change this function since we want to take care of the case where we save memories for the usr data = f"Remember the details about the user: {value}" @@ -44,15 +44,15 @@ class UserMemory(Memory): limit: int = 3, score_threshold: float = 0.35, ): - results = self.storage.search( + return self.storage.search( query=query, limit=limit, score_threshold=score_threshold, ) - return results def reset(self) -> None: try: self.storage.reset() except Exception as e: - raise Exception(f"An error occurred while resetting the user memory: {e}") + msg = f"An error occurred while resetting the user memory: {e}" + raise Exception(msg) diff --git a/src/crewai/memory/user/user_memory_item.py b/src/crewai/memory/user/user_memory_item.py index 288c1544a..56f797b63 100644 --- a/src/crewai/memory/user/user_memory_item.py +++ b/src/crewai/memory/user/user_memory_item.py @@ -1,8 +1,8 @@ -from typing import Any, Dict, Optional +from typing import Any class UserMemoryItem: - def __init__(self, data: Any, user: str, metadata: Optional[Dict[str, Any]] = None): + def __init__(self, data: Any, user: str, metadata: dict[str, Any] | None = None) -> None: self.data = data self.user = user self.metadata = metadata if metadata is not None else {} diff --git a/src/crewai/process.py b/src/crewai/process.py index 2311c0e45..8acd97d22 100644 --- a/src/crewai/process.py +++ b/src/crewai/process.py @@ -2,9 +2,7 @@ from enum import Enum class Process(str, Enum): - """ - Class representing the different processes that can be used to tackle tasks - """ + """Class representing the different processes that can be used to tackle tasks.""" sequential = "sequential" hierarchical = "hierarchical" diff --git a/src/crewai/project/annotations.py b/src/crewai/project/annotations.py index d7c636ccf..d14f57b37 100644 --- a/src/crewai/project/annotations.py +++ b/src/crewai/project/annotations.py @@ -1,5 +1,5 @@ +from collections.abc import Callable from functools import wraps -from typing import Callable from crewai import Crew from crewai.project.utils import memoize @@ -36,15 +36,13 @@ def task(func): def agent(func): """Marks a method as a crew agent.""" func.is_agent = True - func = memoize(func) - return func + return memoize(func) def llm(func): """Marks a method as an LLM provider.""" func.is_llm = True - func = memoize(func) - return func + return memoize(func) def output_json(cls): @@ -91,7 +89,7 @@ def crew(func) -> Callable[..., Crew]: agents = self._original_agents.items() # Instantiate tasks in order - for task_name, task_method in tasks: + for _task_name, task_method in tasks: task_instance = task_method(self) instantiated_tasks.append(task_instance) agent_instance = getattr(task_instance, "agent", None) @@ -100,7 +98,7 @@ def crew(func) -> Callable[..., Crew]: agent_roles.add(agent_instance.role) # Instantiate agents not included by tasks - for agent_name, agent_method in agents: + for _agent_name, agent_method in agents: agent_instance = agent_method(self) if agent_instance.role not in agent_roles: instantiated_agents.append(agent_instance) @@ -117,9 +115,9 @@ def crew(func) -> Callable[..., Crew]: return wrapper - for _, callback in self._before_kickoff.items(): + for callback in self._before_kickoff.values(): crew.before_kickoff_callbacks.append(callback_wrapper(callback, self)) - for _, callback in self._after_kickoff.items(): + for callback in self._after_kickoff.values(): crew.after_kickoff_callbacks.append(callback_wrapper(callback, self)) return crew diff --git a/src/crewai/project/crew_base.py b/src/crewai/project/crew_base.py index e90a0d30e..8f6da34ad 100644 --- a/src/crewai/project/crew_base.py +++ b/src/crewai/project/crew_base.py @@ -1,7 +1,8 @@ import inspect import logging +from collections.abc import Callable from pathlib import Path -from typing import Any, Callable, Dict, TypeVar, cast +from typing import Any, TypeVar, cast import yaml from dotenv import load_dotenv @@ -23,11 +24,11 @@ def CrewBase(cls: T) -> T: base_directory = Path(inspect.getfile(cls)).parent original_agents_config_path = getattr( - cls, "agents_config", "config/agents.yaml" + cls, "agents_config", "config/agents.yaml", ) original_tasks_config_path = getattr(cls, "tasks_config", "config/tasks.yaml") - def __init__(self, *args, **kwargs): + def __init__(self, *args, **kwargs) -> None: super().__init__(*args, **kwargs) self.load_configurations() self.map_all_agent_variables() @@ -49,22 +50,22 @@ def CrewBase(cls: T) -> T: } # Store specific function types self._original_tasks = self._filter_functions( - self._original_functions, "is_task" + self._original_functions, "is_task", ) self._original_agents = self._filter_functions( - self._original_functions, "is_agent" + self._original_functions, "is_agent", ) self._before_kickoff = self._filter_functions( - self._original_functions, "is_before_kickoff" + self._original_functions, "is_before_kickoff", ) self._after_kickoff = self._filter_functions( - self._original_functions, "is_after_kickoff" + self._original_functions, "is_after_kickoff", ) self._kickoff = self._filter_functions( - self._original_functions, "is_kickoff" + self._original_functions, "is_kickoff", ) - def load_configurations(self): + def load_configurations(self) -> None: """Load agent and task configurations from YAML files.""" if isinstance(self.original_agents_config_path, str): agents_config_path = ( @@ -75,12 +76,12 @@ def CrewBase(cls: T) -> T: except FileNotFoundError: logging.warning( f"Agent config file not found at {agents_config_path}. " - "Proceeding with empty agent configurations." + "Proceeding with empty agent configurations.", ) self.agents_config = {} else: logging.warning( - "No agent configuration path provided. Proceeding with empty agent configurations." + "No agent configuration path provided. Proceeding with empty agent configurations.", ) self.agents_config = {} @@ -93,22 +94,21 @@ def CrewBase(cls: T) -> T: except FileNotFoundError: logging.warning( f"Task config file not found at {tasks_config_path}. " - "Proceeding with empty task configurations." + "Proceeding with empty task configurations.", ) self.tasks_config = {} else: logging.warning( - "No task configuration path provided. Proceeding with empty task configurations." + "No task configuration path provided. Proceeding with empty task configurations.", ) self.tasks_config = {} @staticmethod def load_yaml(config_path: Path): try: - with open(config_path, "r", encoding="utf-8") as file: + with open(config_path, encoding="utf-8") as file: return yaml.safe_load(file) except FileNotFoundError: - print(f"File not found: {config_path}") raise def _get_all_functions(self): @@ -119,8 +119,8 @@ def CrewBase(cls: T) -> T: } def _filter_functions( - self, functions: Dict[str, Callable], attribute: str - ) -> Dict[str, Callable]: + self, functions: dict[str, Callable], attribute: str, + ) -> dict[str, Callable]: return { name: func for name, func in functions.items() @@ -132,7 +132,7 @@ def CrewBase(cls: T) -> T: llms = self._filter_functions(all_functions, "is_llm") tool_functions = self._filter_functions(all_functions, "is_tool") cache_handler_functions = self._filter_functions( - all_functions, "is_cache_handler" + all_functions, "is_cache_handler", ) callbacks = self._filter_functions(all_functions, "is_callback") @@ -149,11 +149,11 @@ def CrewBase(cls: T) -> T: def _map_agent_variables( self, agent_name: str, - agent_info: Dict[str, Any], - llms: Dict[str, Callable], - tool_functions: Dict[str, Callable], - cache_handler_functions: Dict[str, Callable], - callbacks: Dict[str, Callable], + agent_info: dict[str, Any], + llms: dict[str, Callable], + tool_functions: dict[str, Callable], + cache_handler_functions: dict[str, Callable], + callbacks: dict[str, Callable], ) -> None: if llm := agent_info.get("llm"): try: @@ -187,12 +187,12 @@ def CrewBase(cls: T) -> T: agents = self._filter_functions(all_functions, "is_agent") tasks = self._filter_functions(all_functions, "is_task") output_json_functions = self._filter_functions( - all_functions, "is_output_json" + all_functions, "is_output_json", ) tool_functions = self._filter_functions(all_functions, "is_tool") callback_functions = self._filter_functions(all_functions, "is_callback") output_pydantic_functions = self._filter_functions( - all_functions, "is_output_pydantic" + all_functions, "is_output_pydantic", ) for task_name, task_info in self.tasks_config.items(): @@ -210,13 +210,13 @@ def CrewBase(cls: T) -> T: def _map_task_variables( self, task_name: str, - task_info: Dict[str, Any], - agents: Dict[str, Callable], - tasks: Dict[str, Callable], - output_json_functions: Dict[str, Callable], - tool_functions: Dict[str, Callable], - callback_functions: Dict[str, Callable], - output_pydantic_functions: Dict[str, Callable], + task_info: dict[str, Any], + agents: dict[str, Callable], + tasks: dict[str, Callable], + output_json_functions: dict[str, Callable], + tool_functions: dict[str, Callable], + callback_functions: dict[str, Callable], + output_pydantic_functions: dict[str, Callable], ) -> None: if context_list := task_info.get("context"): self.tasks_config[task_name]["context"] = [ @@ -253,4 +253,4 @@ def CrewBase(cls: T) -> T: WrappedClass.__name__ = CrewBase.__name__ + "(" + cls.__name__ + ")" WrappedClass.__qualname__ = CrewBase.__qualname__ + "(" + cls.__name__ + ")" - return cast(T, WrappedClass) + return cast("T", WrappedClass) diff --git a/src/crewai/security/fingerprint.py b/src/crewai/security/fingerprint.py index 982c62492..b5dd69c96 100644 --- a/src/crewai/security/fingerprint.py +++ b/src/crewai/security/fingerprint.py @@ -1,5 +1,4 @@ -""" -Fingerprint Module +"""Fingerprint Module. This module provides functionality for generating and validating unique identifiers for CrewAI agents. These identifiers are used for tracking, auditing, and security. @@ -7,14 +6,13 @@ for CrewAI agents. These identifiers are used for tracking, auditing, and securi import uuid from datetime import datetime -from typing import Any, Dict, Optional +from typing import Any from pydantic import BaseModel, ConfigDict, Field, field_validator class Fingerprint(BaseModel): - """ - A class for generating and managing unique identifiers for agents. + """A class for generating and managing unique identifiers for agents. Each agent has dual identifiers: - Human-readable ID: For debugging and reference (derived from role if not specified) @@ -24,48 +22,54 @@ class Fingerprint(BaseModel): uuid_str (str): String representation of the UUID for this fingerprint, auto-generated created_at (datetime): When this fingerprint was created, auto-generated metadata (Dict[str, Any]): Additional metadata associated with this fingerprint + """ uuid_str: str = Field(default_factory=lambda: str(uuid.uuid4()), description="String representation of the UUID") created_at: datetime = Field(default_factory=datetime.now, description="When this fingerprint was created") - metadata: Dict[str, Any] = Field(default_factory=dict, description="Additional metadata for this fingerprint") + metadata: dict[str, Any] = Field(default_factory=dict, description="Additional metadata for this fingerprint") model_config = ConfigDict(arbitrary_types_allowed=True) - - @field_validator('metadata') + + @field_validator("metadata") @classmethod def validate_metadata(cls, v): """Validate that metadata is a dictionary with string keys and valid values.""" if not isinstance(v, dict): - raise ValueError("Metadata must be a dictionary") - + msg = "Metadata must be a dictionary" + raise ValueError(msg) + # Validate that all keys are strings for key, value in v.items(): if not isinstance(key, str): - raise ValueError(f"Metadata keys must be strings, got {type(key)}") - + msg = f"Metadata keys must be strings, got {type(key)}" + raise ValueError(msg) + # Validate nested dictionaries (prevent deeply nested structures) if isinstance(value, dict): # Check for nested dictionaries (limit depth to 1) for nested_key, nested_value in value.items(): if not isinstance(nested_key, str): - raise ValueError(f"Nested metadata keys must be strings, got {type(nested_key)}") + msg = f"Nested metadata keys must be strings, got {type(nested_key)}" + raise ValueError(msg) if isinstance(nested_value, dict): - raise ValueError("Metadata can only be nested one level deep") - + msg = "Metadata can only be nested one level deep" + raise ValueError(msg) + # Check for maximum metadata size (prevent DoS) if len(str(v)) > 10000: # Limit metadata size to 10KB - raise ValueError("Metadata size exceeds maximum allowed (10KB)") - + msg = "Metadata size exceeds maximum allowed (10KB)" + raise ValueError(msg) + return v - def __init__(self, **data): + def __init__(self, **data) -> None: """Initialize a Fingerprint with auto-generated uuid_str and created_at.""" # Remove uuid_str and created_at from data to ensure they're auto-generated - if 'uuid_str' in data: - data.pop('uuid_str') - if 'created_at' in data: - data.pop('created_at') + if "uuid_str" in data: + data.pop("uuid_str") + if "created_at" in data: + data.pop("created_at") # Call the parent constructor with the modified data super().__init__(**data) @@ -77,32 +81,33 @@ class Fingerprint(BaseModel): @classmethod def _generate_uuid(cls, seed: str) -> str: - """ - Generate a deterministic UUID based on a seed string. + """Generate a deterministic UUID based on a seed string. Args: seed (str): The seed string to use for UUID generation Returns: str: A string representation of the UUID consistently generated from the seed + """ if not isinstance(seed, str): - raise ValueError("Seed must be a string") - + msg = "Seed must be a string" + raise ValueError(msg) + if not seed.strip(): - raise ValueError("Seed cannot be empty or whitespace") - + msg = "Seed cannot be empty or whitespace" + raise ValueError(msg) + # Create a deterministic UUID using v5 (SHA-1) # Custom namespace for CrewAI to enhance security # Using a unique namespace specific to CrewAI to reduce collision risks - CREW_AI_NAMESPACE = uuid.UUID('f47ac10b-58cc-4372-a567-0e02b2c3d479') + CREW_AI_NAMESPACE = uuid.UUID("f47ac10b-58cc-4372-a567-0e02b2c3d479") return str(uuid.uuid5(CREW_AI_NAMESPACE, seed)) @classmethod - def generate(cls, seed: Optional[str] = None, metadata: Optional[Dict[str, Any]] = None) -> 'Fingerprint': - """ - Static factory method to create a new Fingerprint. + def generate(cls, seed: str | None = None, metadata: dict[str, Any] | None = None) -> "Fingerprint": + """Static factory method to create a new Fingerprint. Args: seed (Optional[str]): A string to use as seed for the UUID generation. @@ -111,11 +116,12 @@ class Fingerprint(BaseModel): Returns: Fingerprint: A new Fingerprint instance + """ fingerprint = cls(metadata=metadata or {}) if seed: # For seed-based generation, we need to manually set the uuid_str after creation - object.__setattr__(fingerprint, 'uuid_str', cls._generate_uuid(seed)) + object.__setattr__(fingerprint, "uuid_str", cls._generate_uuid(seed)) return fingerprint def __str__(self) -> str: @@ -132,29 +138,29 @@ class Fingerprint(BaseModel): """Hash of the fingerprint (based on UUID).""" return hash(self.uuid_str) - def to_dict(self) -> Dict[str, Any]: - """ - Convert the fingerprint to a dictionary representation. + def to_dict(self) -> dict[str, Any]: + """Convert the fingerprint to a dictionary representation. Returns: Dict[str, Any]: Dictionary representation of the fingerprint + """ return { "uuid_str": self.uuid_str, "created_at": self.created_at.isoformat(), - "metadata": self.metadata + "metadata": self.metadata, } @classmethod - def from_dict(cls, data: Dict[str, Any]) -> 'Fingerprint': - """ - Create a Fingerprint from a dictionary representation. + def from_dict(cls, data: dict[str, Any]) -> "Fingerprint": + """Create a Fingerprint from a dictionary representation. Args: data (Dict[str, Any]): Dictionary representation of a fingerprint Returns: Fingerprint: A new Fingerprint instance + """ if not data: return cls() @@ -163,8 +169,8 @@ class Fingerprint(BaseModel): # For consistency with existing stored fingerprints, we need to manually set these if "uuid_str" in data: - object.__setattr__(fingerprint, 'uuid_str', data["uuid_str"]) + object.__setattr__(fingerprint, "uuid_str", data["uuid_str"]) if "created_at" in data and isinstance(data["created_at"], str): - object.__setattr__(fingerprint, 'created_at', datetime.fromisoformat(data["created_at"])) + object.__setattr__(fingerprint, "created_at", datetime.fromisoformat(data["created_at"])) return fingerprint diff --git a/src/crewai/security/security_config.py b/src/crewai/security/security_config.py index 9f680de42..f2f115229 100644 --- a/src/crewai/security/security_config.py +++ b/src/crewai/security/security_config.py @@ -1,5 +1,4 @@ -""" -Security Configuration Module +"""Security Configuration Module. This module provides configuration for CrewAI security features, including: - Authentication settings @@ -10,7 +9,7 @@ The SecurityConfig class is the primary interface for managing security settings in CrewAI applications. """ -from typing import Any, Dict, Optional +from typing import Any from pydantic import BaseModel, ConfigDict, Field, model_validator @@ -18,8 +17,7 @@ from crewai.security.fingerprint import Fingerprint class SecurityConfig(BaseModel): - """ - Configuration for CrewAI security features. + """Configuration for CrewAI security features. This class manages security settings for CrewAI agents, including: - Authentication credentials *TODO* @@ -30,82 +28,83 @@ class SecurityConfig(BaseModel): Attributes: version (str): Version of the security configuration fingerprint (Fingerprint): The unique fingerprint automatically generated for the component + """ model_config = ConfigDict( - arbitrary_types_allowed=True + arbitrary_types_allowed=True, # Note: Cannot use frozen=True as existing tests modify the fingerprint property ) version: str = Field( - default="1.0.0", - description="Version of the security configuration" + default="1.0.0", + description="Version of the security configuration", ) fingerprint: Fingerprint = Field( - default_factory=Fingerprint, - description="Unique identifier for the component" + default_factory=Fingerprint, + description="Unique identifier for the component", ) - + def is_compatible(self, min_version: str) -> bool: - """ - Check if this security configuration is compatible with the minimum required version. - + """Check if this security configuration is compatible with the minimum required version. + Args: min_version (str): Minimum required version in semver format (e.g., "1.0.0") - + Returns: bool: True if this configuration is compatible, False otherwise + """ # Simple version comparison (can be enhanced with packaging.version if needed) current = [int(x) for x in self.version.split(".")] minimum = [int(x) for x in min_version.split(".")] - + # Compare major, minor, patch versions - for c, m in zip(current, minimum): + for c, m in zip(current, minimum, strict=False): if c > m: return True if c < m: return False return True - @model_validator(mode='before') + @model_validator(mode="before") @classmethod def validate_fingerprint(cls, values): """Ensure fingerprint is properly initialized.""" if isinstance(values, dict): # Handle case where fingerprint is not provided or is None - if 'fingerprint' not in values or values['fingerprint'] is None: - values['fingerprint'] = Fingerprint() + if "fingerprint" not in values or values["fingerprint"] is None: + values["fingerprint"] = Fingerprint() # Handle case where fingerprint is a string (seed) - elif isinstance(values['fingerprint'], str): - if not values['fingerprint'].strip(): - raise ValueError("Fingerprint seed cannot be empty") - values['fingerprint'] = Fingerprint.generate(seed=values['fingerprint']) + elif isinstance(values["fingerprint"], str): + if not values["fingerprint"].strip(): + msg = "Fingerprint seed cannot be empty" + raise ValueError(msg) + values["fingerprint"] = Fingerprint.generate(seed=values["fingerprint"]) return values - def to_dict(self) -> Dict[str, Any]: - """ - Convert the security config to a dictionary. + def to_dict(self) -> dict[str, Any]: + """Convert the security config to a dictionary. Returns: Dict[str, Any]: Dictionary representation of the security config + """ - result = { - "fingerprint": self.fingerprint.to_dict() + return { + "fingerprint": self.fingerprint.to_dict(), } - return result @classmethod - def from_dict(cls, data: Dict[str, Any]) -> 'SecurityConfig': - """ - Create a SecurityConfig from a dictionary. + def from_dict(cls, data: dict[str, Any]) -> "SecurityConfig": + """Create a SecurityConfig from a dictionary. Args: data (Dict[str, Any]): Dictionary representation of a security config Returns: SecurityConfig: A new SecurityConfig instance + """ # Make a copy to avoid modifying the original data_copy = data.copy() diff --git a/src/crewai/task.py b/src/crewai/task.py index e4a25f438..550ad81fe 100644 --- a/src/crewai/task.py +++ b/src/crewai/task.py @@ -2,23 +2,16 @@ import datetime import inspect import json import logging -import re import threading import uuid +from collections.abc import Callable from concurrent.futures import Future from copy import copy from hashlib import md5 from pathlib import Path from typing import ( Any, - Callable, ClassVar, - Dict, - List, - Optional, - Set, - Tuple, - Type, Union, get_args, get_origin, @@ -71,6 +64,7 @@ class Task(BaseModel): output_pydantic: Pydantic model for task output. security_config: Security configuration including fingerprinting. tools: List of tools/resources limited for task execution. + """ __hash__ = object.__hash__ # type: ignore @@ -79,46 +73,46 @@ class Task(BaseModel): tools_errors: int = 0 delegations: int = 0 i18n: I18N = I18N() - name: Optional[str] = Field(default=None) - prompt_context: Optional[str] = None + name: str | None = Field(default=None) + prompt_context: str | None = None description: str = Field(description="Description of the actual task.") expected_output: str = Field( - description="Clear definition of expected output for the task." + description="Clear definition of expected output for the task.", ) - config: Optional[Dict[str, Any]] = Field( + config: dict[str, Any] | None = Field( description="Configuration for the agent", default=None, ) - callback: Optional[Any] = Field( - description="Callback to be executed after the task is completed.", default=None + callback: Any | None = Field( + description="Callback to be executed after the task is completed.", default=None, ) - agent: Optional[BaseAgent] = Field( - description="Agent responsible for execution the task.", default=None + agent: BaseAgent | None = Field( + description="Agent responsible for execution the task.", default=None, ) - context: Optional[List["Task"]] = Field( + context: list["Task"] | None = Field( description="Other tasks that will have their output used as context for this task.", default=None, ) - async_execution: Optional[bool] = Field( + async_execution: bool | None = Field( description="Whether the task should be executed asynchronously or not.", default=False, ) - output_json: Optional[Type[BaseModel]] = Field( + output_json: type[BaseModel] | None = Field( description="A Pydantic model to be used to create a JSON output.", default=None, ) - output_pydantic: Optional[Type[BaseModel]] = Field( + output_pydantic: type[BaseModel] | None = Field( description="A Pydantic model to be used to create a Pydantic output.", default=None, ) - output_file: Optional[str] = Field( + output_file: str | None = Field( description="A file path to be used to create a file output.", default=None, ) - output: Optional[TaskOutput] = Field( - description="Task output, it's final result after being executed", default=None + output: TaskOutput | None = Field( + description="Task output, it's final result after being executed", default=None, ) - tools: Optional[List[BaseTool]] = Field( + tools: list[BaseTool] | None = Field( default_factory=list, description="Tools the agent is limited to use for this task.", ) @@ -131,37 +125,36 @@ class Task(BaseModel): frozen=True, description="Unique identifier for the object, not set by user.", ) - human_input: Optional[bool] = Field( + human_input: bool | None = Field( description="Whether the task should have a human review the final answer of the agent", default=False, ) - converter_cls: Optional[Type[Converter]] = Field( + converter_cls: type[Converter] | None = Field( description="A converter class used to export structured output", default=None, ) - processed_by_agents: Set[str] = Field(default_factory=set) - guardrail: Optional[Union[Callable[[TaskOutput], Tuple[bool, Any]], str]] = Field( + processed_by_agents: set[str] = Field(default_factory=set) + guardrail: Callable[[TaskOutput], tuple[bool, Any]] | str | None = Field( default=None, description="Function or string description of a guardrail to validate task output before proceeding to next task", ) max_retries: int = Field( - default=3, description="Maximum number of retries when guardrail fails" + default=3, description="Maximum number of retries when guardrail fails", ) retry_count: int = Field(default=0, description="Current number of retries") - start_time: Optional[datetime.datetime] = Field( - default=None, description="Start time of the task execution" + start_time: datetime.datetime | None = Field( + default=None, description="Start time of the task execution", ) - end_time: Optional[datetime.datetime] = Field( - default=None, description="End time of the task execution" + end_time: datetime.datetime | None = Field( + default=None, description="End time of the task execution", ) @field_validator("guardrail") @classmethod def validate_guardrail_function( - cls, v: Optional[str | Callable] - ) -> Optional[str | Callable]: - """ - If v is a callable, validate that the guardrail function has the correct signature and behavior. + cls, v: str | Callable | None, + ) -> str | Callable | None: + """If v is a callable, validate that the guardrail function has the correct signature and behavior. If v is a string, return it as is. While type hints provide static checking, this validator ensures runtime safety by: @@ -183,6 +176,7 @@ class Task(BaseModel): Raises: ValueError: If the function signature is invalid or return annotation doesn't match Tuple[bool, Any] + """ if v is not None and callable(v): sig = inspect.signature(v) @@ -192,7 +186,8 @@ class Task(BaseModel): if param.default is inspect.Parameter.empty ] if len(positional_args) != 1: - raise ValueError("Guardrail function must accept exactly one parameter") + msg = "Guardrail function must accept exactly one parameter" + raise ValueError(msg) # Check return annotation if present, but don't require it return_annotation = sig.return_annotation @@ -210,16 +205,17 @@ class Task(BaseModel): or return_annotation_args[1] == Union[str, TaskOutput] ) ): + msg = "If return type is annotated, it must be Tuple[bool, Any]" raise ValueError( - "If return type is annotated, it must be Tuple[bool, Any]" + msg, ) return v - _guardrail: Optional[Callable] = PrivateAttr(default=None) - _original_description: Optional[str] = PrivateAttr(default=None) - _original_expected_output: Optional[str] = PrivateAttr(default=None) - _original_output_file: Optional[str] = PrivateAttr(default=None) - _thread: Optional[threading.Thread] = PrivateAttr(default=None) + _guardrail: Callable | None = PrivateAttr(default=None) + _original_description: str | None = PrivateAttr(default=None) + _original_expected_output: str | None = PrivateAttr(default=None) + _original_output_file: str | None = PrivateAttr(default=None) + _thread: threading.Thread | None = PrivateAttr(default=None) @model_validator(mode="before") @classmethod @@ -231,8 +227,9 @@ class Task(BaseModel): required_fields = ["description", "expected_output"] for field in required_fields: if getattr(self, field) is None: + msg = f"{field} must be provided either directly or through config" raise ValueError( - f"{field} must be provided either directly or through config" + msg, ) return self @@ -245,22 +242,23 @@ class Task(BaseModel): assert self.agent is not None self._guardrail = LLMGuardrail( - description=self.guardrail, llm=self.agent.llm + description=self.guardrail, llm=self.agent.llm, ) return self @field_validator("id", mode="before") @classmethod - def _deny_user_set_id(cls, v: Optional[UUID4]) -> None: + def _deny_user_set_id(cls, v: UUID4 | None) -> None: if v: + msg = "may_not_set_field" raise PydanticCustomError( - "may_not_set_field", "This field is not to be set by the user.", {} + msg, "This field is not to be set by the user.", {}, ) @field_validator("output_file") @classmethod - def output_file_validation(cls, value: Optional[str]) -> Optional[str]: + def output_file_validation(cls, value: str | None) -> str | None: """Validate the output file path. Args: @@ -274,26 +272,30 @@ class Task(BaseModel): Raises: ValueError: If the path contains invalid characters, path traversal attempts, or other security concerns. + """ if value is None: return None # Basic security checks if ".." in value: + msg = "Path traversal attempts are not allowed in output_file paths" raise ValueError( - "Path traversal attempts are not allowed in output_file paths" + msg, ) # Check for shell expansion first - if value.startswith("~") or value.startswith("$"): + if value.startswith(("~", "$")): + msg = "Shell expansion characters are not allowed in output_file paths" raise ValueError( - "Shell expansion characters are not allowed in output_file paths" + msg, ) # Then check other shell special characters if any(char in value for char in ["|", ">", "<", "&", ";"]): + msg = "Shell special characters are not allowed in output_file paths" raise ValueError( - "Shell special characters are not allowed in output_file paths" + msg, ) # Don't strip leading slash if it's a template path with variables @@ -302,7 +304,8 @@ class Task(BaseModel): template_vars = [part.split("}")[0] for part in value.split("{")[1:]] for var in template_vars: if not var.isidentifier(): - raise ValueError(f"Invalid template variable name: {var}") + msg = f"Invalid template variable name: {var}" + raise ValueError(msg) return value # Strip leading slash for regular paths @@ -330,8 +333,9 @@ class Task(BaseModel): """Check if an output type is set.""" output_types = [self.output_json, self.output_pydantic] if len([type for type in output_types if type]) > 1: + msg = "output_type" raise PydanticCustomError( - "output_type", + msg, "Only one output type can be set, either output_pydantic or output_json.", {}, ) @@ -339,9 +343,9 @@ class Task(BaseModel): def execute_sync( self, - agent: Optional[BaseAgent] = None, - context: Optional[str] = None, - tools: Optional[List[BaseTool]] = None, + agent: BaseAgent | None = None, + context: str | None = None, + tools: list[BaseTool] | None = None, ) -> TaskOutput: """Execute the task synchronously.""" return self._execute_core(agent, context, tools) @@ -363,8 +367,8 @@ class Task(BaseModel): def execute_async( self, agent: BaseAgent | None = None, - context: Optional[str] = None, - tools: Optional[List[BaseTool]] = None, + context: str | None = None, + tools: list[BaseTool] | None = None, ) -> Future[TaskOutput]: """Execute the task asynchronously.""" future: Future[TaskOutput] = Future() @@ -377,9 +381,9 @@ class Task(BaseModel): def _execute_task_async( self, - agent: Optional[BaseAgent], - context: Optional[str], - tools: Optional[List[Any]], + agent: BaseAgent | None, + context: str | None, + tools: list[Any] | None, future: Future[TaskOutput], ) -> None: """Execute the task asynchronously with context handling.""" @@ -388,17 +392,18 @@ class Task(BaseModel): def _execute_core( self, - agent: Optional[BaseAgent], - context: Optional[str], - tools: Optional[List[Any]], + agent: BaseAgent | None, + context: str | None, + tools: list[Any] | None, ) -> TaskOutput: """Run the core execution logic of the task.""" try: agent = agent or self.agent self.agent = agent if not agent: + msg = f"The task '{self.description}' has no agent assigned, therefore it can't be executed directly and should be executed in a Crew using a specific process that support that, like hierarchical." raise Exception( - f"The task '{self.description}' has no agent assigned, therefore it can't be executed directly and should be executed in a Crew using a specific process that support that, like hierarchical." + msg, ) self.start_time = datetime.datetime.now() @@ -430,10 +435,13 @@ class Task(BaseModel): guardrail_result = self._process_guardrail(task_output) if not guardrail_result.success: if self.retry_count >= self.max_retries: - raise Exception( + msg = ( f"Task failed guardrail validation after {self.max_retries} retries. " f"Last error: {guardrail_result.error}" ) + raise Exception( + msg, + ) self.retry_count += 1 context = self.i18n.errors("validation_error").format( @@ -448,14 +456,15 @@ class Task(BaseModel): return self._execute_core(agent, context, tools) if guardrail_result.result is None: + msg = "Task guardrail returned None as result. This is not allowed." raise Exception( - "Task guardrail returned None as result. This is not allowed." + msg, ) if isinstance(guardrail_result.result, str): task_output.raw = guardrail_result.result pydantic_output, json_output = self._export_output( - guardrail_result.result + guardrail_result.result, ) task_output.pydantic = pydantic_output task_output.json_dict = json_output @@ -482,13 +491,13 @@ class Task(BaseModel): ) self._save_file(content) crewai_event_bus.emit( - self, TaskCompletedEvent(output=task_output, task=self) + self, TaskCompletedEvent(output=task_output, task=self), ) return task_output except Exception as e: self.end_time = datetime.datetime.now() crewai_event_bus.emit(self, TaskFailedEvent(error=str(e), task=self)) - raise e # Re-raise the exception after emitting the event + raise # Re-raise the exception after emitting the event def _process_guardrail(self, task_output: TaskOutput) -> GuardrailResult: assert self._guardrail is not None @@ -504,7 +513,7 @@ class Task(BaseModel): crewai_event_bus.emit( self, LLMGuardrailStartedEvent( - guardrail=self._guardrail, retry_count=self.retry_count + guardrail=self._guardrail, retry_count=self.retry_count, ), ) @@ -526,17 +535,18 @@ class Task(BaseModel): Returns: Prompt of the task. + """ tasks_slices = [self.description] output = self.i18n.slice("expected_output").format( - expected_output=self.expected_output + expected_output=self.expected_output, ) tasks_slices = [self.description, output] return "\n".join(tasks_slices) def interpolate_inputs_and_add_conversation_history( - self, inputs: Dict[str, Union[str, int, float, Dict[str, Any], List[Any]]] + self, inputs: dict[str, str | int | float | dict[str, Any] | list[Any]], ) -> None: """Interpolate inputs into the task description, expected output, and output file path. Add conversation history if present. @@ -547,6 +557,7 @@ class Task(BaseModel): Raises: ValueError: If a required template variable is missing from inputs. + """ if self._original_description is None: self._original_description = self.description @@ -560,43 +571,46 @@ class Task(BaseModel): try: self.description = interpolate_only( - input_string=self._original_description, inputs=inputs + input_string=self._original_description, inputs=inputs, ) except KeyError as e: + msg = f"Missing required template variable '{e.args[0]}' in description" raise ValueError( - f"Missing required template variable '{e.args[0]}' in description" + msg, ) from e except ValueError as e: - raise ValueError(f"Error interpolating description: {str(e)}") from e + msg = f"Error interpolating description: {e!s}" + raise ValueError(msg) from e try: self.expected_output = interpolate_only( - input_string=self._original_expected_output, inputs=inputs + input_string=self._original_expected_output, inputs=inputs, ) except (KeyError, ValueError) as e: - raise ValueError(f"Error interpolating expected_output: {str(e)}") from e + msg = f"Error interpolating expected_output: {e!s}" + raise ValueError(msg) from e if self.output_file is not None: try: self.output_file = interpolate_only( - input_string=self._original_output_file, inputs=inputs + input_string=self._original_output_file, inputs=inputs, ) except (KeyError, ValueError) as e: + msg = f"Error interpolating output_file path: {e!s}" raise ValueError( - f"Error interpolating output_file path: {str(e)}" + msg, ) from e - if "crew_chat_messages" in inputs and inputs["crew_chat_messages"]: + if inputs.get("crew_chat_messages"): conversation_instruction = self.i18n.slice( - "conversation_history_instruction" + "conversation_history_instruction", ) crew_chat_messages_json = str(inputs["crew_chat_messages"]) try: crew_chat_messages = json.loads(crew_chat_messages_json) - except json.JSONDecodeError as e: - print("An error occurred while parsing crew chat messages:", e) + except json.JSONDecodeError: raise conversation_history = "\n".join( @@ -613,14 +627,14 @@ class Task(BaseModel): """Increment the tools errors counter.""" self.tools_errors += 1 - def increment_delegations(self, agent_name: Optional[str]) -> None: + def increment_delegations(self, agent_name: str | None) -> None: """Increment the delegations counter.""" if agent_name: self.processed_by_agents.add(agent_name) self.delegations += 1 def copy( - self, agents: List["BaseAgent"], task_mapping: Dict[str, "Task"] + self, agents: list["BaseAgent"], task_mapping: dict[str, "Task"], ) -> "Task": """Creates a deep copy of the Task while preserving its original class type. @@ -630,6 +644,7 @@ class Task(BaseModel): Returns: A copy of the task with the same class type as the original. + """ exclude = { "id", @@ -653,20 +668,19 @@ class Task(BaseModel): cloned_agent = get_agent_by_role(self.agent.role) if self.agent else None cloned_tools = copy(self.tools) if self.tools else [] - copied_task = self.__class__( + return self.__class__( **copied_data, context=cloned_context, agent=cloned_agent, tools=cloned_tools, ) - return copied_task def _export_output( - self, result: str - ) -> Tuple[Optional[BaseModel], Optional[Dict[str, Any]]]: - pydantic_output: Optional[BaseModel] = None - json_output: Optional[Dict[str, Any]] = None + self, result: str, + ) -> tuple[BaseModel | None, dict[str, Any] | None]: + pydantic_output: BaseModel | None = None + json_output: dict[str, Any] | None = None if self.output_pydantic or self.output_json: model_output = convert_to_model( @@ -696,7 +710,7 @@ class Task(BaseModel): return OutputFormat.PYDANTIC return OutputFormat.RAW - def _save_file(self, result: Union[Dict, str, Any]) -> None: + def _save_file(self, result: dict | str | Any) -> None: """Save task output to a file. Note: @@ -713,9 +727,11 @@ class Task(BaseModel): RuntimeError: If there is an error writing to the file. For cross-platform compatibility, especially on Windows, use FileWriterTool from crewai_tools package. + """ if self.output_file is None: - raise ValueError("output_file is not set.") + msg = "output_file is not set." + raise ValueError(msg) FILEWRITER_RECOMMENDATION = ( "For cross-platform file writing, especially on Windows, " @@ -736,15 +752,14 @@ class Task(BaseModel): json.dump(result, file, ensure_ascii=False, indent=2) else: file.write(str(result)) - except (OSError, IOError) as e: + except OSError as e: raise RuntimeError( "\n".join( - [f"Failed to save output file: {e}", FILEWRITER_RECOMMENDATION] - ) + [f"Failed to save output file: {e}", FILEWRITER_RECOMMENDATION], + ), ) - return None - def __repr__(self): + def __repr__(self) -> str: return f"Task(description={self.description}, expected_output={self.expected_output})" @property @@ -753,5 +768,6 @@ class Task(BaseModel): Returns: Fingerprint: The fingerprint of the task + """ return self.security_config.fingerprint diff --git a/src/crewai/tasks/conditional_task.py b/src/crewai/tasks/conditional_task.py index d2edee12f..3458d355d 100644 --- a/src/crewai/tasks/conditional_task.py +++ b/src/crewai/tasks/conditional_task.py @@ -1,4 +1,5 @@ -from typing import Any, Callable +from collections.abc import Callable +from typing import Any from pydantic import Field @@ -8,8 +9,7 @@ from crewai.tasks.task_output import TaskOutput class ConditionalTask(Task): - """ - A task that can be conditionally executed based on the output of another task. + """A task that can be conditionally executed based on the output of another task. Note: This cannot be the only task you have in your crew and cannot be the first since its needs context from the previous task. """ @@ -22,19 +22,19 @@ class ConditionalTask(Task): self, condition: Callable[[Any], bool], **kwargs, - ): + ) -> None: super().__init__(**kwargs) self.condition = condition def should_execute(self, context: TaskOutput) -> bool: - """ - Determines whether the conditional task should be executed based on the provided context. + """Determines whether the conditional task should be executed based on the provided context. Args: context (Any): The context or output from the previous task that will be evaluated by the condition. Returns: bool: True if the task should be executed, False otherwise. + """ return self.condition(context) diff --git a/src/crewai/tasks/guardrail_result.py b/src/crewai/tasks/guardrail_result.py index ba8ebc552..73d44125f 100644 --- a/src/crewai/tasks/guardrail_result.py +++ b/src/crewai/tasks/guardrail_result.py @@ -1,11 +1,10 @@ -""" -Module for handling task guardrail validation results. +"""Module for handling task guardrail validation results. This module provides the GuardrailResult class which standardizes the way task guardrails return their validation results. """ -from typing import Any, Optional, Tuple, Union +from typing import Any from pydantic import BaseModel, field_validator @@ -21,10 +20,12 @@ class GuardrailResult(BaseModel): success (bool): Whether the guardrail validation passed result (Any, optional): The validated/transformed result if successful error (str, optional): Error message if validation failed + """ + success: bool - result: Optional[Any] = None - error: Optional[str] = None + result: Any | None = None + error: str | None = None @field_validator("result", "error") @classmethod @@ -32,13 +33,15 @@ class GuardrailResult(BaseModel): values = info.data if "success" in values: if values["success"] and v and "error" in values and values["error"]: - raise ValueError("Cannot have both result and error when success is True") + msg = "Cannot have both result and error when success is True" + raise ValueError(msg) if not values["success"] and v and "result" in values and values["result"]: - raise ValueError("Cannot have both result and error when success is False") + msg = "Cannot have both result and error when success is False" + raise ValueError(msg) return v @classmethod - def from_tuple(cls, result: Tuple[bool, Union[Any, str]]) -> "GuardrailResult": + def from_tuple(cls, result: tuple[bool, Any | str]) -> "GuardrailResult": """Create a GuardrailResult from a validation tuple. Args: @@ -47,10 +50,11 @@ class GuardrailResult(BaseModel): Returns: GuardrailResult: A new instance with the tuple data. + """ success, data = result return cls( success=success, result=data if success else None, - error=data if not success else None + error=data if not success else None, ) diff --git a/src/crewai/tasks/llm_guardrail.py b/src/crewai/tasks/llm_guardrail.py index 2bb948075..2d2bccab3 100644 --- a/src/crewai/tasks/llm_guardrail.py +++ b/src/crewai/tasks/llm_guardrail.py @@ -1,16 +1,15 @@ -from typing import Any, Optional, Tuple +from typing import Any from pydantic import BaseModel, Field from crewai.agent import Agent, LiteAgentOutput from crewai.llm import LLM -from crewai.task import Task from crewai.tasks.task_output import TaskOutput class LLMGuardrailResult(BaseModel): valid: bool = Field( - description="Whether the task output complies with the guardrail" + description="Whether the task output complies with the guardrail", ) feedback: str | None = Field( description="A feedback about the task output if it is not valid", @@ -27,13 +26,14 @@ class LLMGuardrail: Args: description (str): The description of the validation criteria. llm (LLM, optional): The language model to use for code generation. + """ def __init__( self, description: str, llm: LLM, - ): + ) -> None: self.description = description self.llm: LLM = llm @@ -54,7 +54,7 @@ class LLMGuardrail: Guardrail: {self.description} - + Your task: - Confirm if the Task result complies with the guardrail. - If not, provide clear feedback explaining what is wrong (e.g., by how much it violates the rule, or what specific part fails). @@ -62,11 +62,10 @@ class LLMGuardrail: - If the Task result complies with the guardrail, saying that is valid """ - result = agent.kickoff(query, response_format=LLMGuardrailResult) + return agent.kickoff(query, response_format=LLMGuardrailResult) - return result - def __call__(self, task_output: TaskOutput) -> Tuple[bool, Any]: + def __call__(self, task_output: TaskOutput) -> tuple[bool, Any]: """Validates the output of a task based on specified criteria. Args: @@ -76,17 +75,16 @@ class LLMGuardrail: Tuple[bool, Any]: A tuple containing: - bool: True if validation passed, False otherwise - Any: The validation result or error message - """ + """ try: result = self._validate_output(task_output) assert isinstance( - result.pydantic, LLMGuardrailResult + result.pydantic, LLMGuardrailResult, ), "The guardrail result is not a valid pydantic model" if result.pydantic.valid: return True, task_output.raw - else: - return False, result.pydantic.feedback + return False, result.pydantic.feedback except Exception as e: - return False, f"Error while validating the task output: {str(e)}" + return False, f"Error while validating the task output: {e!s}" diff --git a/src/crewai/tasks/task_output.py b/src/crewai/tasks/task_output.py index b0e8aecd4..946c63991 100644 --- a/src/crewai/tasks/task_output.py +++ b/src/crewai/tasks/task_output.py @@ -1,5 +1,5 @@ import json -from typing import Any, Dict, Optional +from typing import Any from pydantic import BaseModel, Field, model_validator @@ -10,21 +10,21 @@ class TaskOutput(BaseModel): """Class that represents the result of a task.""" description: str = Field(description="Description of the task") - name: Optional[str] = Field(description="Name of the task", default=None) - expected_output: Optional[str] = Field( - description="Expected output of the task", default=None + name: str | None = Field(description="Name of the task", default=None) + expected_output: str | None = Field( + description="Expected output of the task", default=None, ) - summary: Optional[str] = Field(description="Summary of the task", default=None) + summary: str | None = Field(description="Summary of the task", default=None) raw: str = Field(description="Raw output of the task", default="") - pydantic: Optional[BaseModel] = Field( - description="Pydantic output of task", default=None + pydantic: BaseModel | None = Field( + description="Pydantic output of task", default=None, ) - json_dict: Optional[Dict[str, Any]] = Field( - description="JSON dictionary of task", default=None + json_dict: dict[str, Any] | None = Field( + description="JSON dictionary of task", default=None, ) agent: str = Field(description="Agent that executed the task") output_format: OutputFormat = Field( - description="Output format of the task", default=OutputFormat.RAW + description="Output format of the task", default=OutputFormat.RAW, ) @model_validator(mode="after") @@ -35,19 +35,22 @@ class TaskOutput(BaseModel): return self @property - def json(self) -> Optional[str]: + def json(self) -> str | None: if self.output_format != OutputFormat.JSON: - raise ValueError( + msg = ( """ Invalid output format requested. If you would like to access the JSON output, please make sure to set the output_json property for the task """ ) + raise ValueError( + msg, + ) return json.dumps(self.json_dict) - def to_dict(self) -> Dict[str, Any]: + def to_dict(self) -> dict[str, Any]: """Convert json_output and pydantic_output to a dictionary.""" output_dict = {} if self.json_dict: diff --git a/src/crewai/telemetry/telemetry.py b/src/crewai/telemetry/telemetry.py index 142cafb2a..59a61c2f3 100644 --- a/src/crewai/telemetry/telemetry.py +++ b/src/crewai/telemetry/telemetry.py @@ -6,9 +6,9 @@ import logging import os import platform import warnings -from contextlib import contextmanager +from contextlib import contextmanager, suppress from importlib.metadata import version -from typing import TYPE_CHECKING, Any, Optional +from typing import TYPE_CHECKING, Any from crewai.telemetry.constants import ( CREWAI_TELEMETRY_BASE_URL, @@ -27,7 +27,7 @@ def suppress_warnings(): from opentelemetry import trace # noqa: E402 from opentelemetry.exporter.otlp.proto.http.trace_exporter import ( - OTLPSpanExporter, # noqa: E402 + OTLPSpanExporter, ) from opentelemetry.sdk.resources import SERVICE_NAME, Resource # noqa: E402 from opentelemetry.sdk.trace import TracerProvider # noqa: E402 @@ -47,7 +47,7 @@ class SafeOTLPSpanExporter(OTLPSpanExporter): try: return super().export(spans) except Exception as e: - logger.error(e) + logger.exception(e) return SpanExportResult.FAILURE @@ -64,7 +64,7 @@ class Telemetry: attribute in the Crew class. """ - def __init__(self): + def __init__(self) -> None: self.ready: bool = False self.trace_set: bool = False @@ -82,7 +82,7 @@ class Telemetry: SafeOTLPSpanExporter( endpoint=f"{CREWAI_TELEMETRY_BASE_URL}/v1/traces", timeout=30, - ) + ), ) self.provider.add_span_processor(processor) @@ -102,7 +102,7 @@ class Telemetry: or os.getenv("CREWAI_DISABLE_TELEMETRY", "false").lower() == "true" ) - def set_tracer(self): + def set_tracer(self) -> None: if self.ready and not self.trace_set: try: with suppress_warnings(): @@ -112,18 +112,16 @@ class Telemetry: self.ready = False self.trace_set = False - def _safe_telemetry_operation(self, operation): + def _safe_telemetry_operation(self, operation) -> None: if not self.ready: return - try: + with suppress(Exception): operation() - except Exception: - pass - def crew_creation(self, crew: Crew, inputs: dict[str, Any] | None): + def crew_creation(self, crew: Crew, inputs: dict[str, Any] | None) -> None: """Records the creation of a crew.""" - def operation(): + def operation() -> None: tracer = trace.get_tracer("crewai.telemetry") span = tracer.start_span("Crew Created") self._add_attribute( @@ -183,7 +181,7 @@ class Telemetry: "llm": agent.llm.model, "delegation_enabled?": agent.allow_delegation, "allow_code_execution?": getattr( - agent, "allow_code_execution", False + agent, "allow_code_execution", False, ), "max_retry_limit": getattr(agent, "max_retry_limit", 3), "tools_names": [ @@ -211,7 +209,7 @@ class Telemetry: ), } for agent in crew.agents - ] + ], ), ) self._add_attribute( @@ -251,7 +249,7 @@ class Telemetry: ), } for task in crew.tasks - ] + ], ), ) self._add_attribute(span, "platform", platform.platform()) @@ -260,7 +258,7 @@ class Telemetry: self._add_attribute(span, "platform_version", platform.version()) self._add_attribute(span, "cpus", os.cpu_count()) self._add_attribute( - span, "crew_inputs", json.dumps(inputs) if inputs else None + span, "crew_inputs", json.dumps(inputs) if inputs else None, ) else: self._add_attribute( @@ -287,7 +285,7 @@ class Telemetry: "llm": agent.llm.model, "delegation_enabled?": agent.allow_delegation, "allow_code_execution?": getattr( - agent, "allow_code_execution", False + agent, "allow_code_execution", False, ), "max_retry_limit": getattr(agent, "max_retry_limit", 3), "tools_names": [ @@ -295,7 +293,7 @@ class Telemetry: ], } for agent in crew.agents - ] + ], ), ) self._add_attribute( @@ -317,7 +315,7 @@ class Telemetry: ], } for task in crew.tasks - ] + ], ), ) span.set_status(Status(StatusCode.OK)) @@ -341,12 +339,12 @@ class Telemetry: # Add fingerprint data if hasattr(crew, "fingerprint") and crew.fingerprint: self._add_attribute( - created_span, "crew_fingerprint", crew.fingerprint.uuid_str + created_span, "crew_fingerprint", crew.fingerprint.uuid_str, ) if hasattr(task, "fingerprint") and task.fingerprint: self._add_attribute( - created_span, "task_fingerprint", task.fingerprint.uuid_str + created_span, "task_fingerprint", task.fingerprint.uuid_str, ) self._add_attribute( created_span, @@ -364,19 +362,19 @@ class Telemetry: # Add agent fingerprint if task has an assigned agent if hasattr(task, "agent") and task.agent: agent_fingerprint = getattr( - getattr(task.agent, "fingerprint", None), "uuid_str", None + getattr(task.agent, "fingerprint", None), "uuid_str", None, ) if agent_fingerprint: self._add_attribute( - created_span, "agent_fingerprint", agent_fingerprint + created_span, "agent_fingerprint", agent_fingerprint, ) if crew.share_crew: self._add_attribute( - created_span, "formatted_description", task.description + created_span, "formatted_description", task.description, ) self._add_attribute( - created_span, "formatted_expected_output", task.expected_output + created_span, "formatted_expected_output", task.expected_output, ) created_span.set_status(Status(StatusCode.OK)) @@ -399,7 +397,7 @@ class Telemetry: # Add agent fingerprint if task has an assigned agent if hasattr(task, "agent") and task.agent: agent_fingerprint = getattr( - getattr(task.agent, "fingerprint", None), "uuid_str", None + getattr(task.agent, "fingerprint", None), "uuid_str", None, ) if agent_fingerprint: self._add_attribute(span, "agent_fingerprint", agent_fingerprint) @@ -407,14 +405,14 @@ class Telemetry: if crew.share_crew: self._add_attribute(span, "formatted_description", task.description) self._add_attribute( - span, "formatted_expected_output", task.expected_output + span, "formatted_expected_output", task.expected_output, ) return span return self._safe_telemetry_operation(operation) - def task_ended(self, span: Span, task: Task, crew: Crew): + def task_ended(self, span: Span, task: Task, crew: Crew) -> None: """Records the completion of a task execution in a crew. Args: @@ -424,9 +422,10 @@ class Telemetry: Note: If share_crew is enabled, this will also record the task output + """ - def operation(): + def operation() -> None: # Ensure fingerprint data is present on completion span if hasattr(task, "fingerprint") and task.fingerprint: self._add_attribute(span, "task_fingerprint", task.fingerprint.uuid_str) @@ -443,16 +442,17 @@ class Telemetry: self._safe_telemetry_operation(operation) - def tool_repeated_usage(self, llm: Any, tool_name: str, attempts: int): + def tool_repeated_usage(self, llm: Any, tool_name: str, attempts: int) -> None: """Records when a tool is used repeatedly, which might indicate an issue. Args: llm (Any): The language model being used tool_name (str): Name of the tool being repeatedly used attempts (int): Number of attempts made with this tool + """ - def operation(): + def operation() -> None: tracer = trace.get_tracer("crewai.telemetry") span = tracer.start_span("Tool Repeated Usage") self._add_attribute( @@ -469,7 +469,7 @@ class Telemetry: self._safe_telemetry_operation(operation) - def tool_usage(self, llm: Any, tool_name: str, attempts: int, agent: Any = None): + def tool_usage(self, llm: Any, tool_name: str, attempts: int, agent: Any = None) -> None: """Records the usage of a tool by an agent. Args: @@ -477,9 +477,10 @@ class Telemetry: tool_name (str): Name of the tool being used attempts (int): Number of attempts made with this tool agent (Any, optional): The agent using the tool + """ - def operation(): + def operation() -> None: tracer = trace.get_tracer("crewai.telemetry") span = tracer.start_span("Tool Usage") self._add_attribute( @@ -495,7 +496,7 @@ class Telemetry: # Add agent fingerprint data if available if agent and hasattr(agent, "fingerprint") and agent.fingerprint: self._add_attribute( - span, "agent_fingerprint", agent.fingerprint.uuid_str + span, "agent_fingerprint", agent.fingerprint.uuid_str, ) if hasattr(agent, "role"): self._add_attribute(span, "agent_role", agent.role) @@ -506,17 +507,18 @@ class Telemetry: self._safe_telemetry_operation(operation) def tool_usage_error( - self, llm: Any, agent: Any = None, tool_name: Optional[str] = None - ): + self, llm: Any, agent: Any = None, tool_name: str | None = None, + ) -> None: """Records when a tool usage results in an error. Args: llm (Any): The language model being used when the error occurred agent (Any, optional): The agent using the tool tool_name (str, optional): Name of the tool that caused the error + """ - def operation(): + def operation() -> None: tracer = trace.get_tracer("crewai.telemetry") span = tracer.start_span("Tool Usage Error") self._add_attribute( @@ -533,7 +535,7 @@ class Telemetry: # Add agent fingerprint data if available if agent and hasattr(agent, "fingerprint") and agent.fingerprint: self._add_attribute( - span, "agent_fingerprint", agent.fingerprint.uuid_str + span, "agent_fingerprint", agent.fingerprint.uuid_str, ) if hasattr(agent, "role"): self._add_attribute(span, "agent_role", agent.role) @@ -544,8 +546,8 @@ class Telemetry: self._safe_telemetry_operation(operation) def individual_test_result_span( - self, crew: Crew, quality: float, exec_time: int, model_name: str - ): + self, crew: Crew, quality: float, exec_time: int, model_name: str, + ) -> None: """Records individual test results for a crew execution. Args: @@ -553,9 +555,10 @@ class Telemetry: quality (float): Quality score of the execution exec_time (int): Execution time in seconds model_name (str): Name of the model used + """ - def operation(): + def operation() -> None: tracer = trace.get_tracer("crewai.telemetry") span = tracer.start_span("Crew Individual Test Result") @@ -580,7 +583,7 @@ class Telemetry: iterations: int, inputs: dict[str, Any] | None, model_name: str, - ): + ) -> None: """Records the execution of a test suite for a crew. Args: @@ -588,9 +591,10 @@ class Telemetry: iterations (int): Number of test iterations inputs (dict[str, Any] | None): Input parameters for the test model_name (str): Name of the model used in testing + """ - def operation(): + def operation() -> None: tracer = trace.get_tracer("crewai.telemetry") span = tracer.start_span("Crew Test Execution") @@ -606,7 +610,7 @@ class Telemetry: if crew.share_crew: self._add_attribute( - span, "inputs", json.dumps(inputs) if inputs else None + span, "inputs", json.dumps(inputs) if inputs else None, ) span.set_status(Status(StatusCode.OK)) @@ -614,10 +618,10 @@ class Telemetry: self._safe_telemetry_operation(operation) - def deploy_signup_error_span(self): + def deploy_signup_error_span(self) -> None: """Records when an error occurs during the deployment signup process.""" - def operation(): + def operation() -> None: tracer = trace.get_tracer("crewai.telemetry") span = tracer.start_span("Deploy Signup Error") span.set_status(Status(StatusCode.OK)) @@ -625,14 +629,15 @@ class Telemetry: self._safe_telemetry_operation(operation) - def start_deployment_span(self, uuid: Optional[str] = None): + def start_deployment_span(self, uuid: str | None = None) -> None: """Records the start of a deployment process. Args: uuid (Optional[str]): Unique identifier for the deployment + """ - def operation(): + def operation() -> None: tracer = trace.get_tracer("crewai.telemetry") span = tracer.start_span("Start Deployment") if uuid: @@ -642,10 +647,10 @@ class Telemetry: self._safe_telemetry_operation(operation) - def create_crew_deployment_span(self): + def create_crew_deployment_span(self) -> None: """Records the creation of a new crew deployment.""" - def operation(): + def operation() -> None: tracer = trace.get_tracer("crewai.telemetry") span = tracer.start_span("Create Crew Deployment") span.set_status(Status(StatusCode.OK)) @@ -653,15 +658,16 @@ class Telemetry: self._safe_telemetry_operation(operation) - def get_crew_logs_span(self, uuid: Optional[str], log_type: str = "deployment"): + def get_crew_logs_span(self, uuid: str | None, log_type: str = "deployment") -> None: """Records the retrieval of crew logs. Args: uuid (Optional[str]): Unique identifier for the crew log_type (str, optional): Type of logs being retrieved. Defaults to "deployment". + """ - def operation(): + def operation() -> None: tracer = trace.get_tracer("crewai.telemetry") span = tracer.start_span("Get Crew Logs") self._add_attribute(span, "log_type", log_type) @@ -672,14 +678,15 @@ class Telemetry: self._safe_telemetry_operation(operation) - def remove_crew_span(self, uuid: Optional[str] = None): + def remove_crew_span(self, uuid: str | None = None) -> None: """Records the removal of a crew. Args: uuid (Optional[str]): Unique identifier for the crew being removed + """ - def operation(): + def operation() -> None: tracer = trace.get_tracer("crewai.telemetry") span = tracer.start_span("Remove Crew") if uuid: @@ -706,7 +713,7 @@ class Telemetry: self._add_attribute(span, "crew_key", crew.key) self._add_attribute(span, "crew_id", str(crew.id)) self._add_attribute( - span, "crew_inputs", json.dumps(inputs) if inputs else None + span, "crew_inputs", json.dumps(inputs) if inputs else None, ) self._add_attribute( span, @@ -730,7 +737,7 @@ class Telemetry: ], } for agent in crew.agents - ] + ], ), ) self._add_attribute( @@ -756,7 +763,7 @@ class Telemetry: ], } for task in crew.tasks - ] + ], ), ) return span @@ -765,15 +772,15 @@ class Telemetry: return self._safe_telemetry_operation(operation) return None - def end_crew(self, crew, final_string_output): - def operation(): + def end_crew(self, crew, final_string_output) -> None: + def operation() -> None: self._add_attribute( crew._execution_span, "crewai_version", version("crewai"), ) self._add_attribute( - crew._execution_span, "crew_output", final_string_output + crew._execution_span, "crew_output", final_string_output, ) self._add_attribute( crew._execution_span, @@ -786,7 +793,7 @@ class Telemetry: "output": task.output.raw_output, } for task in crew.tasks - ] + ], ), ) crew._execution_span.set_status(Status(StatusCode.OK)) @@ -795,7 +802,7 @@ class Telemetry: if crew.share_crew: self._safe_telemetry_operation(operation) - def _add_attribute(self, span, key, value): + def _add_attribute(self, span, key, value) -> None: """Add an attribute to a span.""" def operation(): @@ -803,14 +810,15 @@ class Telemetry: self._safe_telemetry_operation(operation) - def flow_creation_span(self, flow_name: str): + def flow_creation_span(self, flow_name: str) -> None: """Records the creation of a new flow. Args: flow_name (str): Name of the flow being created + """ - def operation(): + def operation() -> None: tracer = trace.get_tracer("crewai.telemetry") span = tracer.start_span("Flow Creation") self._add_attribute(span, "flow_name", flow_name) @@ -819,15 +827,16 @@ class Telemetry: self._safe_telemetry_operation(operation) - def flow_plotting_span(self, flow_name: str, node_names: list[str]): + def flow_plotting_span(self, flow_name: str, node_names: list[str]) -> None: """Records flow visualization/plotting activity. Args: flow_name (str): Name of the flow being plotted node_names (list[str]): List of node names in the flow + """ - def operation(): + def operation() -> None: tracer = trace.get_tracer("crewai.telemetry") span = tracer.start_span("Flow Plotting") self._add_attribute(span, "flow_name", flow_name) @@ -837,15 +846,16 @@ class Telemetry: self._safe_telemetry_operation(operation) - def flow_execution_span(self, flow_name: str, node_names: list[str]): + def flow_execution_span(self, flow_name: str, node_names: list[str]) -> None: """Records the execution of a flow. Args: flow_name (str): Name of the flow being executed node_names (list[str]): List of nodes being executed in the flow + """ - def operation(): + def operation() -> None: tracer = trace.get_tracer("crewai.telemetry") span = tracer.start_span("Flow Execution") self._add_attribute(span, "flow_name", flow_name) diff --git a/src/crewai/tools/agent_tools/add_image_tool.py b/src/crewai/tools/agent_tools/add_image_tool.py index 939dff2df..eba121771 100644 --- a/src/crewai/tools/agent_tools/add_image_tool.py +++ b/src/crewai/tools/agent_tools/add_image_tool.py @@ -1,4 +1,3 @@ -from typing import Dict, Optional, Union from pydantic import BaseModel, Field @@ -10,13 +9,13 @@ i18n = I18N() class AddImageToolSchema(BaseModel): image_url: str = Field(..., description="The URL or path of the image to add") - action: Optional[str] = Field( - default=None, description="Optional context or question about the image" + action: str | None = Field( + default=None, description="Optional context or question about the image", ) class AddImageTool(BaseTool): - """Tool for adding images to the content""" + """Tool for adding images to the content.""" name: str = Field(default_factory=lambda: i18n.tools("add_image")["name"]) # type: ignore description: str = Field(default_factory=lambda: i18n.tools("add_image")["description"]) # type: ignore @@ -25,7 +24,7 @@ class AddImageTool(BaseTool): def _run( self, image_url: str, - action: Optional[str] = None, + action: str | None = None, **kwargs, ) -> dict: action = action or i18n.tools("add_image")["default_action"] # type: ignore diff --git a/src/crewai/tools/agent_tools/agent_tools.py b/src/crewai/tools/agent_tools/agent_tools.py index 77d3c2d89..67d976370 100644 --- a/src/crewai/tools/agent_tools/agent_tools.py +++ b/src/crewai/tools/agent_tools/agent_tools.py @@ -7,14 +7,14 @@ from .delegate_work_tool import DelegateWorkTool class AgentTools: - """Manager class for agent-related tools""" + """Manager class for agent-related tools.""" - def __init__(self, agents: list[BaseAgent], i18n: I18N = I18N()): + def __init__(self, agents: list[BaseAgent], i18n: I18N = I18N()) -> None: self.agents = agents self.i18n = i18n def tools(self) -> list[BaseTool]: - """Get all available agent tools""" + """Get all available agent tools.""" coworkers = ", ".join([f"{agent.role}" for agent in self.agents]) delegate_tool = DelegateWorkTool( diff --git a/src/crewai/tools/agent_tools/ask_question_tool.py b/src/crewai/tools/agent_tools/ask_question_tool.py index 9294770e5..d7925640d 100644 --- a/src/crewai/tools/agent_tools/ask_question_tool.py +++ b/src/crewai/tools/agent_tools/ask_question_tool.py @@ -1,4 +1,3 @@ -from typing import Optional from pydantic import BaseModel, Field @@ -12,7 +11,7 @@ class AskQuestionToolSchema(BaseModel): class AskQuestionTool(BaseAgentTool): - """Tool for asking questions to coworkers""" + """Tool for asking questions to coworkers.""" name: str = "Ask question to coworker" args_schema: type[BaseModel] = AskQuestionToolSchema @@ -21,7 +20,7 @@ class AskQuestionTool(BaseAgentTool): self, question: str, context: str, - coworker: Optional[str] = None, + coworker: str | None = None, **kwargs, ) -> str: coworker = self._get_coworker(coworker, **kwargs) diff --git a/src/crewai/tools/agent_tools/base_agent_tools.py b/src/crewai/tools/agent_tools/base_agent_tools.py index b00fbb7b5..1b43e569b 100644 --- a/src/crewai/tools/agent_tools/base_agent_tools.py +++ b/src/crewai/tools/agent_tools/base_agent_tools.py @@ -1,5 +1,4 @@ import logging -from typing import Optional from pydantic import Field @@ -12,16 +11,15 @@ logger = logging.getLogger(__name__) class BaseAgentTool(BaseTool): - """Base class for agent-related tools""" + """Base class for agent-related tools.""" agents: list[BaseAgent] = Field(description="List of available agents") i18n: I18N = Field( - default_factory=I18N, description="Internationalization settings" + default_factory=I18N, description="Internationalization settings", ) def sanitize_agent_name(self, name: str) -> str: - """ - Sanitize agent role name by normalizing whitespace and setting to lowercase. + """Sanitize agent role name by normalizing whitespace and setting to lowercase. Converts all whitespace (including newlines) to single spaces and removes quotes. Args: @@ -30,6 +28,7 @@ class BaseAgentTool(BaseTool): Returns: str: The sanitized agent role name, with whitespace normalized, converted to lowercase, and quotes removed + """ if not name: return "" @@ -38,7 +37,7 @@ class BaseAgentTool(BaseTool): # Remove quotes and convert to lowercase return normalized.replace('"', "").casefold() - def _get_coworker(self, coworker: Optional[str], **kwargs) -> Optional[str]: + def _get_coworker(self, coworker: str | None, **kwargs) -> str | None: coworker = coworker or kwargs.get("co_worker") or kwargs.get("coworker") if coworker: is_list = coworker.startswith("[") and coworker.endswith("]") @@ -48,12 +47,11 @@ class BaseAgentTool(BaseTool): def _execute( self, - agent_name: Optional[str], + agent_name: str | None, task: str, - context: Optional[str] = None + context: str | None = None, ) -> str: - """ - Execute delegation to an agent with case-insensitive and whitespace-tolerant matching. + """Execute delegation to an agent with case-insensitive and whitespace-tolerant matching. Args: agent_name: Name/role of the agent to delegate to (case-insensitive) @@ -63,6 +61,7 @@ class BaseAgentTool(BaseTool): Returns: str: The execution result from the delegated agent or an error message if the agent cannot be found + """ try: if agent_name is None: @@ -92,18 +91,18 @@ class BaseAgentTool(BaseTool): # Handle specific exceptions that might occur during role name processing return self.i18n.errors("agent_tool_unexisting_coworker").format( coworkers="\n".join( - [f"- {self.sanitize_agent_name(agent.role)}" for agent in self.agents] + [f"- {self.sanitize_agent_name(agent.role)}" for agent in self.agents], ), - error=str(e) + error=str(e), ) if not agent: # No matching agent found after sanitization return self.i18n.errors("agent_tool_unexisting_coworker").format( coworkers="\n".join( - [f"- {self.sanitize_agent_name(agent.role)}" for agent in self.agents] + [f"- {self.sanitize_agent_name(agent.role)}" for agent in self.agents], ), - error=f"No agent found with role '{sanitized_name}'" + error=f"No agent found with role '{sanitized_name}'", ) agent = agent[0] @@ -120,5 +119,5 @@ class BaseAgentTool(BaseTool): # Handle task creation or execution errors return self.i18n.errors("agent_tool_execution_error").format( agent_role=self.sanitize_agent_name(agent.role), - error=str(e) + error=str(e), ) diff --git a/src/crewai/tools/agent_tools/delegate_work_tool.py b/src/crewai/tools/agent_tools/delegate_work_tool.py index 9dbf6c920..431736073 100644 --- a/src/crewai/tools/agent_tools/delegate_work_tool.py +++ b/src/crewai/tools/agent_tools/delegate_work_tool.py @@ -1,4 +1,3 @@ -from typing import Optional from pydantic import BaseModel, Field @@ -9,12 +8,12 @@ class DelegateWorkToolSchema(BaseModel): task: str = Field(..., description="The task to delegate") context: str = Field(..., description="The context for the task") coworker: str = Field( - ..., description="The role/name of the coworker to delegate to" + ..., description="The role/name of the coworker to delegate to", ) class DelegateWorkTool(BaseAgentTool): - """Tool for delegating work to coworkers""" + """Tool for delegating work to coworkers.""" name: str = "Delegate work to coworker" args_schema: type[BaseModel] = DelegateWorkToolSchema @@ -23,7 +22,7 @@ class DelegateWorkTool(BaseAgentTool): self, task: str, context: str, - coworker: Optional[str] = None, + coworker: str | None = None, **kwargs, ) -> str: coworker = self._get_coworker(coworker, **kwargs) diff --git a/src/crewai/tools/base_tool.py b/src/crewai/tools/base_tool.py index 0e8a7a22b..2eb2384f1 100644 --- a/src/crewai/tools/base_tool.py +++ b/src/crewai/tools/base_tool.py @@ -1,8 +1,8 @@ import asyncio -import warnings from abc import ABC, abstractmethod +from collections.abc import Callable from inspect import signature -from typing import Any, Callable, Type, get_args, get_origin +from typing import Any, get_args, get_origin from pydantic import ( BaseModel, @@ -26,8 +26,8 @@ class BaseTool(BaseModel, ABC): """The unique name of the tool that clearly communicates its purpose.""" description: str """Used to tell the model how/when/why to use the tool.""" - args_schema: Type[PydanticBaseModel] = Field( - default_factory=_ArgsSchemaPlaceholder, validate_default=True + args_schema: type[PydanticBaseModel] = Field( + default_factory=_ArgsSchemaPlaceholder, validate_default=True, ) """The schema for the arguments that the tool accepts.""" description_updated: bool = False @@ -40,8 +40,8 @@ class BaseTool(BaseModel, ABC): @field_validator("args_schema", mode="before") @classmethod def _default_args_schema( - cls, v: Type[PydanticBaseModel] - ) -> Type[PydanticBaseModel]: + cls, v: type[PydanticBaseModel], + ) -> type[PydanticBaseModel]: if not isinstance(v, cls._ArgsSchemaPlaceholder): return v @@ -65,7 +65,6 @@ class BaseTool(BaseModel, ABC): *args: Any, **kwargs: Any, ) -> Any: - print(f"Using Tool: {self.name}") result = self._run(*args, **kwargs) # If _run is async, we safely run it @@ -102,7 +101,8 @@ class BaseTool(BaseModel, ABC): attribute and infers the argument schema if not explicitly provided. """ if not hasattr(tool, "func") or not callable(tool.func): - raise ValueError("The provided tool must have a callable 'func' attribute.") + msg = "The provided tool must have a callable 'func' attribute." + raise ValueError(msg) args_schema = getattr(tool, "args_schema", None) @@ -126,7 +126,7 @@ class BaseTool(BaseModel, ABC): else: # Create a default schema with no fields if no parameters are found args_schema = create_model( - f"{tool.name}Input", __base__=PydanticBaseModel + f"{tool.name}Input", __base__=PydanticBaseModel, ) return cls( @@ -136,7 +136,7 @@ class BaseTool(BaseModel, ABC): args_schema=args_schema, ) - def _set_args_schema(self): + def _set_args_schema(self) -> None: if self.args_schema is None: class_name = f"{self.__class__.__name__}Schema" self.args_schema = type( @@ -151,7 +151,7 @@ class BaseTool(BaseModel, ABC): }, ) - def _generate_description(self): + def _generate_description(self) -> None: args_schema = { name: { "description": field.description, @@ -208,9 +208,11 @@ class Tool(BaseTool): Raises: ValueError: If the provided tool does not have a callable 'func' attribute. + """ if not hasattr(tool, "func") or not callable(tool.func): - raise ValueError("The provided tool must have a callable 'func' attribute.") + msg = "The provided tool must have a callable 'func' attribute." + raise ValueError(msg) args_schema = getattr(tool, "args_schema", None) @@ -234,7 +236,7 @@ class Tool(BaseTool): else: # Create a default schema with no fields if no parameters are found args_schema = create_model( - f"{tool.name}Input", __base__=PydanticBaseModel + f"{tool.name}Input", __base__=PydanticBaseModel, ) return cls( @@ -252,20 +254,22 @@ def to_langchain( def tool(*args, result_as_answer=False): - """ - Decorator to create a tool from a function. - + """Decorator to create a tool from a function. + Args: *args: Positional arguments, either the function to decorate or the tool name. result_as_answer: Flag to indicate if the tool result should be used as the final agent answer. + """ def _make_with_name(tool_name: str) -> Callable: def _make_tool(f: Callable) -> BaseTool: if f.__doc__ is None: - raise ValueError("Function must have a docstring") + msg = "Function must have a docstring" + raise ValueError(msg) if f.__annotations__ is None: - raise ValueError("Function must have type annotations") + msg = "Function must have type annotations" + raise ValueError(msg) class_name = "".join(tool_name.split()).title() args_schema = type( @@ -292,4 +296,5 @@ def tool(*args, result_as_answer=False): return _make_with_name(args[0].__name__)(args[0]) if len(args) == 1 and isinstance(args[0], str): return _make_with_name(args[0]) - raise ValueError("Invalid arguments") + msg = "Invalid arguments" + raise ValueError(msg) diff --git a/src/crewai/tools/structured_tool.py b/src/crewai/tools/structured_tool.py index dfd23a9cb..e75ae1289 100644 --- a/src/crewai/tools/structured_tool.py +++ b/src/crewai/tools/structured_tool.py @@ -2,12 +2,15 @@ from __future__ import annotations import inspect import textwrap -from typing import Any, Callable, Optional, Union, get_type_hints +from typing import TYPE_CHECKING, Any, get_type_hints from pydantic import BaseModel, Field, create_model from crewai.utilities.logger import Logger +if TYPE_CHECKING: + from collections.abc import Callable + class CrewStructuredTool: """A structured tool that can operate on any number of inputs. @@ -32,6 +35,7 @@ class CrewStructuredTool: args_schema: The pydantic model for the tool's arguments func: The function to run when the tool is called result_as_answer: Whether to return the output directly + """ self.name = name self.description = description @@ -47,10 +51,10 @@ class CrewStructuredTool: def from_function( cls, func: Callable, - name: Optional[str] = None, - description: Optional[str] = None, + name: str | None = None, + description: str | None = None, return_direct: bool = False, - args_schema: Optional[type[BaseModel]] = None, + args_schema: type[BaseModel] | None = None, infer_schema: bool = True, **kwargs: Any, ) -> CrewStructuredTool: @@ -73,13 +77,15 @@ class CrewStructuredTool: ... '''Add two numbers''' ... return a + b >>> tool = CrewStructuredTool.from_function(add) + """ name = name or func.__name__ description = description or inspect.getdoc(func) if description is None: + msg = f"Function {name} must have a docstring if description not provided." raise ValueError( - f"Function {name} must have a docstring if description not provided." + msg, ) # Clean up the description @@ -92,8 +98,9 @@ class CrewStructuredTool: # Infer schema from function signature schema = cls._create_schema_from_function(name, func) else: + msg = "Either args_schema must be provided or infer_schema must be True." raise ValueError( - "Either args_schema must be provided or infer_schema must be True." + msg, ) return cls( @@ -117,6 +124,7 @@ class CrewStructuredTool: Returns: A Pydantic model class + """ # Get function signature sig = inspect.signature(func) @@ -165,12 +173,15 @@ class CrewStructuredTool: # Only validate required parameters without defaults if param.default == inspect.Parameter.empty: if param_name not in schema_fields: - raise ValueError( + msg = ( f"Required function parameter '{param_name}' " f"not found in args_schema" ) + raise ValueError( + msg, + ) - def _parse_args(self, raw_args: Union[str, dict]) -> dict: + def _parse_args(self, raw_args: str | dict) -> dict: """Parse and validate the input arguments against the schema. Args: @@ -178,6 +189,7 @@ class CrewStructuredTool: Returns: The validated arguments as a dictionary + """ if isinstance(raw_args, str): try: @@ -185,18 +197,20 @@ class CrewStructuredTool: raw_args = json.loads(raw_args) except json.JSONDecodeError as e: - raise ValueError(f"Failed to parse arguments as JSON: {e}") + msg = f"Failed to parse arguments as JSON: {e}" + raise ValueError(msg) try: validated_args = self.args_schema.model_validate(raw_args) return validated_args.model_dump() except Exception as e: - raise ValueError(f"Arguments validation failed: {e}") + msg = f"Arguments validation failed: {e}" + raise ValueError(msg) async def ainvoke( self, - input: Union[str, dict], - config: Optional[dict] = None, + input: str | dict, + config: dict | None = None, **kwargs: Any, ) -> Any: """Asynchronously invoke the tool. @@ -208,28 +222,28 @@ class CrewStructuredTool: Returns: The result of the tool execution + """ parsed_args = self._parse_args(input) if inspect.iscoroutinefunction(self.func): return await self.func(**parsed_args, **kwargs) - else: - # Run sync functions in a thread pool - import asyncio + # Run sync functions in a thread pool + import asyncio - return await asyncio.get_event_loop().run_in_executor( - None, lambda: self.func(**parsed_args, **kwargs) - ) + return await asyncio.get_event_loop().run_in_executor( + None, lambda: self.func(**parsed_args, **kwargs), + ) def _run(self, *args, **kwargs) -> Any: """Legacy method for compatibility.""" # Convert args/kwargs to our expected format - input_dict = dict(zip(self.args_schema.model_fields.keys(), args)) + input_dict = dict(zip(self.args_schema.model_fields.keys(), args, strict=False)) input_dict.update(kwargs) return self.invoke(input_dict) def invoke( - self, input: Union[str, dict], config: Optional[dict] = None, **kwargs: Any + self, input: str | dict, config: dict | None = None, **kwargs: Any, ) -> Any: """Main method for tool execution.""" parsed_args = self._parse_args(input) diff --git a/src/crewai/tools/tool_calling.py b/src/crewai/tools/tool_calling.py index 16c5e0bfc..1206a1fd7 100644 --- a/src/crewai/tools/tool_calling.py +++ b/src/crewai/tools/tool_calling.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, Optional +from typing import Any from pydantic import BaseModel, Field from pydantic import BaseModel as PydanticBaseModel @@ -7,15 +7,15 @@ from pydantic import Field as PydanticField class ToolCalling(BaseModel): tool_name: str = Field(..., description="The name of the tool to be called.") - arguments: Optional[Dict[str, Any]] = Field( - ..., description="A dictionary of arguments to be passed to the tool." + arguments: dict[str, Any] | None = Field( + ..., description="A dictionary of arguments to be passed to the tool.", ) class InstructorToolCalling(PydanticBaseModel): tool_name: str = PydanticField( - ..., description="The name of the tool to be called." + ..., description="The name of the tool to be called.", ) - arguments: Optional[Dict[str, Any]] = PydanticField( - ..., description="A dictionary of arguments to be passed to the tool." + arguments: dict[str, Any] | None = PydanticField( + ..., description="A dictionary of arguments to be passed to the tool.", ) diff --git a/src/crewai/tools/tool_usage.py b/src/crewai/tools/tool_usage.py index dc5f8f29a..af8852038 100644 --- a/src/crewai/tools/tool_usage.py +++ b/src/crewai/tools/tool_usage.py @@ -1,11 +1,12 @@ import ast +import contextlib import datetime import json import time from difflib import SequenceMatcher from json import JSONDecodeError from textwrap import dedent -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union +from typing import TYPE_CHECKING, Any, Union import json5 from json_repair import repair_json @@ -53,8 +54,7 @@ class ToolUsageErrorException(Exception): class ToolUsage: - """ - Class that represents the usage of a tool by an agent. + """Class that represents the usage of a tool by an agent. Attributes: task: Task being executed. @@ -64,17 +64,18 @@ class ToolUsage: tools_description: Description of the tools available for the agent. tools_names: Names of the tools available for the agent. function_calling_llm: Language model to be used for the tool usage. + """ def __init__( self, - tools_handler: Optional[ToolsHandler], - tools: List[CrewStructuredTool], - task: Optional[Task], + tools_handler: ToolsHandler | None, + tools: list[CrewStructuredTool], + task: Task | None, function_calling_llm: Any, - agent: Optional[Union["BaseAgent", "LiteAgent"]] = None, + agent: Union["BaseAgent", "LiteAgent"] | None = None, action: Any = None, - fingerprint_context: Optional[Dict[str, str]] = None, + fingerprint_context: dict[str, str] | None = None, ) -> None: self._i18n: I18N = agent.i18n if agent else I18N() self._printer: Printer = Printer() @@ -105,7 +106,7 @@ class ToolUsage: return self._tool_calling(tool_string) def use( - self, calling: Union[ToolCalling, InstructorToolCalling], tool_string: str + self, calling: ToolCalling | InstructorToolCalling, tool_string: str, ) -> str: if isinstance(calling, ToolUsageErrorException): error = calling.message @@ -130,8 +131,7 @@ class ToolUsage: and tool.name == self._i18n.tools("add_image")["name"] # type: ignore ): try: - result = self._use(tool_string=tool_string, tool=tool, calling=calling) - return result + return self._use(tool_string=tool_string, tool=tool, calling=calling) except Exception as e: error = getattr(e, "message", str(e)) @@ -147,20 +147,19 @@ class ToolUsage: self, tool_string: str, tool: CrewStructuredTool, - calling: Union[ToolCalling, InstructorToolCalling], + calling: ToolCalling | InstructorToolCalling, ) -> str: if self._check_tool_repeated_usage(calling=calling): # type: ignore # _check_tool_repeated_usage of "ToolUsage" does not return a value (it only ever returns None) try: result = self._i18n.errors("task_repeated_usage").format( - tool_names=self.tools_names + tool_names=self.tools_names, ) self._telemetry.tool_repeated_usage( llm=self.function_calling_llm, tool_name=tool.name, attempts=self._run_attempts, ) - result = self._format_result(result=result) # type: ignore # "_format_result" of "ToolUsage" does not return a value (it only ever returns None) - return result # type: ignore # Fix the return type of this function + return self._format_result(result=result) # type: ignore # "_format_result" of "ToolUsage" does not return a value (it only ever returns None) except Exception: if self.task: @@ -180,14 +179,14 @@ class ToolUsage: event_data.update(self.agent.fingerprint) crewai_event_bus.emit(self,ToolUsageStartedEvent(**event_data)) - + started_at = time.time() from_cache = False result = None # type: ignore if self.tools_handler and self.tools_handler.cache: result = self.tools_handler.cache.read( - tool=calling.tool_name, input=calling.arguments + tool=calling.tool_name, input=calling.arguments, ) # type: ignore from_cache = result is not None @@ -240,16 +239,16 @@ class ToolUsage: if self._run_attempts > self._max_parsing_attempts: self._telemetry.tool_usage_error(llm=self.function_calling_llm) error_message = self._i18n.errors("tool_usage_exception").format( - error=e, tool=tool.name, tool_inputs=tool.description + error=e, tool=tool.name, tool_inputs=tool.description, ) error = ToolUsageErrorException( - f"\n{error_message}.\nMoving on then. {self._i18n.slice('format').format(tool_names=self.tools_names)}" + f"\n{error_message}.\nMoving on then. {self._i18n.slice('format').format(tool_names=self.tools_names)}", ).message if self.task: self.task.increment_tools_errors() if self.agent and self.agent.verbose: self._printer.print( - content=f"\n\n{error_message}\n", color="red" + content=f"\n\n{error_message}\n", color="red", ) return error # type: ignore # No return value expected @@ -264,11 +263,11 @@ class ToolUsage: and available_tool.cache_function # type: ignore # Item "None" of "Any | None" has no attribute "cache_function" ): should_cache = available_tool.cache_function( # type: ignore # Item "None" of "Any | None" has no attribute "cache_function" - calling.arguments, result + calling.arguments, result, ) self.tools_handler.on_tool_use( - calling=calling, output=result, should_cache=should_cache + calling=calling, output=result, should_cache=should_cache, ) self._telemetry.tool_usage( llm=self.function_calling_llm, @@ -317,12 +316,12 @@ class ToolUsage: def _remember_format(self, result: str) -> str: result = str(result) result += "\n\n" + self._i18n.slice("tools").format( - tools=self.tools_description, tool_names=self.tools_names + tools=self.tools_description, tool_names=self.tools_names, ) return result def _check_tool_repeated_usage( - self, calling: Union[ToolCalling, InstructorToolCalling] + self, calling: ToolCalling | InstructorToolCalling, ) -> bool: if not self.tools_handler: return False @@ -336,7 +335,7 @@ class ToolUsage: order_tools = sorted( self.tools, key=lambda tool: SequenceMatcher( - None, tool.name.lower().strip(), tool_name.lower().strip() + None, tool.name.lower().strip(), tool_name.lower().strip(), ).ratio(), reverse=True, ) @@ -344,14 +343,14 @@ class ToolUsage: if ( tool.name.lower().strip() == tool_name.lower().strip() or SequenceMatcher( - None, tool.name.lower().strip(), tool_name.lower().strip() + None, tool.name.lower().strip(), tool_name.lower().strip(), ).ratio() > 0.85 ): return tool if self.task: self.task.increment_tools_errors() - tool_selection_data: Dict[str, Any] = { + tool_selection_data: dict[str, Any] = { "agent_key": getattr(self.agent, "key", None) if self.agent else None, "agent_role": getattr(self.agent, "role", None) if self.agent else None, "tool_name": tool_name, @@ -368,16 +367,15 @@ class ToolUsage: ), ) raise Exception(error) - else: - error = f"I forgot the Action name, these are the only available Actions: {self.tools_description}" - crewai_event_bus.emit( - self, - ToolSelectionErrorEvent( - **tool_selection_data, - error=error, - ), - ) - raise Exception(error) + error = f"I forgot the Action name, these are the only available Actions: {self.tools_description}" + crewai_event_bus.emit( + self, + ToolSelectionErrorEvent( + **tool_selection_data, + error=error, + ), + ) + raise Exception(error) def _render(self) -> str: """Render the tool name and description in plain text.""" @@ -387,8 +385,8 @@ class ToolUsage: return "\n--\n".join(descriptions) def _function_calling( - self, tool_string: str - ) -> Union[ToolCalling, InstructorToolCalling]: + self, tool_string: str, + ) -> ToolCalling | InstructorToolCalling: model = ( InstructorToolCalling if self.function_calling_llm.supports_function_calling() @@ -411,13 +409,14 @@ class ToolUsage: ) tool_object = converter.to_pydantic() if not isinstance(tool_object, (ToolCalling, InstructorToolCalling)): - raise ToolUsageErrorException("Failed to parse tool calling") + msg = "Failed to parse tool calling" + raise ToolUsageErrorException(msg) return tool_object def _original_tool_calling( - self, tool_string: str, raise_error: bool = False - ) -> Union[ToolCalling, InstructorToolCalling, ToolUsageErrorException]: + self, tool_string: str, raise_error: bool = False, + ) -> ToolCalling | InstructorToolCalling | ToolUsageErrorException: tool_name = self.action.tool tool = self._select_tool(tool_name) try: @@ -426,18 +425,16 @@ class ToolUsage: except Exception: if raise_error: raise - else: - return ToolUsageErrorException( - f"{self._i18n.errors('tool_arguments_error')}" - ) + return ToolUsageErrorException( + f"{self._i18n.errors('tool_arguments_error')}", + ) if not isinstance(arguments, dict): if raise_error: raise - else: - return ToolUsageErrorException( - f"{self._i18n.errors('tool_arguments_error')}" - ) + return ToolUsageErrorException( + f"{self._i18n.errors('tool_arguments_error')}", + ) return ToolCalling( tool_name=tool.name, @@ -445,16 +442,15 @@ class ToolUsage: ) def _tool_calling( - self, tool_string: str - ) -> Union[ToolCalling, InstructorToolCalling, ToolUsageErrorException]: + self, tool_string: str, + ) -> ToolCalling | InstructorToolCalling | ToolUsageErrorException: try: try: return self._original_tool_calling(tool_string, raise_error=True) except Exception: if self.function_calling_llm: return self._function_calling(tool_string) - else: - return self._original_tool_calling(tool_string) + return self._original_tool_calling(tool_string) except Exception as e: self._run_attempts += 1 if self._run_attempts > self._max_parsing_attempts: @@ -464,17 +460,18 @@ class ToolUsage: if self.agent and self.agent.verbose: self._printer.print(content=f"\n\n{e}\n", color="red") return ToolUsageErrorException( # type: ignore # Incompatible return value type (got "ToolUsageErrorException", expected "ToolCalling | InstructorToolCalling") - f"{self._i18n.errors('tool_usage_error').format(error=e)}\nMoving on then. {self._i18n.slice('format').format(tool_names=self.tools_names)}" + f"{self._i18n.errors('tool_usage_error').format(error=e)}\nMoving on then. {self._i18n.slice('format').format(tool_names=self.tools_names)}", ) return self._tool_calling(tool_string) - def _validate_tool_input(self, tool_input: Optional[str]) -> Dict[str, Any]: + def _validate_tool_input(self, tool_input: str | None) -> dict[str, Any]: if tool_input is None: return {} if not isinstance(tool_input, str) or not tool_input.strip(): + msg = "Tool input must be a valid dictionary in JSON or Python literal format" raise Exception( - "Tool input must be a valid dictionary in JSON or Python literal format" + msg, ) # Attempt 1: Parse as JSON @@ -492,7 +489,7 @@ class ToolUsage: return arguments except (ValueError, SyntaxError): repaired_input = repair_json(tool_input) - pass # Continue to the next parsing attempt + # Continue to the next parsing attempt # Attempt 3: Parse as JSON5 try: @@ -506,7 +503,7 @@ class ToolUsage: try: repaired_input = str(repair_json(tool_input, skip_json_loads=True)) self._printer.print( - content=f"Repaired JSON: {repaired_input}", color="blue" + content=f"Repaired JSON: {repaired_input}", color="blue", ) arguments = json.loads(repaired_input) if isinstance(arguments, dict): @@ -522,7 +519,7 @@ class ToolUsage: # If all parsing attempts fail, raise an error raise Exception(error_message) - def _emit_validate_input_error(self, final_error: str): + def _emit_validate_input_error(self, final_error: str) -> None: tool_selection_data = { "agent_key": getattr(self.agent, "key", None) if self.agent else None, "agent_role": getattr(self.agent, "role", None) if self.agent else None, @@ -544,7 +541,7 @@ class ToolUsage: def on_tool_error( self, tool: Any, - tool_calling: Union[ToolCalling, InstructorToolCalling], + tool_calling: ToolCalling | InstructorToolCalling, e: Exception, ) -> None: event_data = self._prepare_event_data(tool, tool_calling) @@ -553,7 +550,7 @@ class ToolUsage: def on_tool_use_finished( self, tool: Any, - tool_calling: Union[ToolCalling, InstructorToolCalling], + tool_calling: ToolCalling | InstructorToolCalling, from_cache: bool, started_at: float, result: Any, @@ -566,12 +563,12 @@ class ToolUsage: "finished_at": datetime.datetime.fromtimestamp(finished_at), "from_cache": from_cache, "output": result, - } + }, ) crewai_event_bus.emit(self, ToolUsageFinishedEvent(**event_data)) def _prepare_event_data( - self, tool: Any, tool_calling: Union[ToolCalling, InstructorToolCalling] + self, tool: Any, tool_calling: ToolCalling | InstructorToolCalling, ) -> dict: event_data = { "run_attempts": self._run_attempts, @@ -604,6 +601,7 @@ class ToolUsage: Returns: Updated arguments dictionary with fingerprint metadata + """ # Create a shallow copy to avoid modifying the original arguments = arguments.copy() @@ -618,22 +616,18 @@ class ToolUsage: if self.agent and hasattr(self.agent, "security_config"): security_config = getattr(self.agent, "security_config", None) if security_config and hasattr(security_config, "fingerprint"): - try: + with contextlib.suppress(AttributeError): security_context["agent_fingerprint"] = ( security_config.fingerprint.to_dict() ) - except AttributeError: - pass # Add task fingerprint if available if self.task and hasattr(self.task, "security_config"): security_config = getattr(self.task, "security_config", None) if security_config and hasattr(security_config, "fingerprint"): - try: + with contextlib.suppress(AttributeError): security_context["task_fingerprint"] = ( security_config.fingerprint.to_dict() ) - except AttributeError: - pass return arguments diff --git a/src/crewai/types/crew_chat.py b/src/crewai/types/crew_chat.py index 354642442..80d343005 100644 --- a/src/crewai/types/crew_chat.py +++ b/src/crewai/types/crew_chat.py @@ -1,16 +1,16 @@ -from typing import List from pydantic import BaseModel, Field class ChatInputField(BaseModel): - """ - Represents a single required input for the crew, with a name and short description. + """Represents a single required input for the crew, with a name and short description. + Example: { "name": "topic", "description": "The topic to focus on for the conversation" - } + }. + """ name: str = Field(..., description="The name of the input field") @@ -18,8 +18,8 @@ class ChatInputField(BaseModel): class ChatInputs(BaseModel): - """ - Holds a high-level crew_description plus a list of ChatInputFields. + """Holds a high-level crew_description plus a list of ChatInputFields. + Example: { "crew_name": "topic-based-qa", @@ -28,13 +28,14 @@ class ChatInputs(BaseModel): {"name": "topic", "description": "The topic to focus on"}, {"name": "username", "description": "Name of the user"}, ] - } + }. + """ crew_name: str = Field(..., description="The name of the crew") crew_description: str = Field( - ..., description="A description of the crew's purpose" + ..., description="A description of the crew's purpose", ) - inputs: List[ChatInputField] = Field( - default_factory=list, description="A list of input fields for the crew" + inputs: list[ChatInputField] = Field( + default_factory=list, description="A list of input fields for the crew", ) diff --git a/src/crewai/types/usage_metrics.py b/src/crewai/types/usage_metrics.py index e87a79e33..e96697456 100644 --- a/src/crewai/types/usage_metrics.py +++ b/src/crewai/types/usage_metrics.py @@ -2,8 +2,7 @@ from pydantic import BaseModel, Field class UsageMetrics(BaseModel): - """ - Model to track usage metrics for the crew's execution. + """Model to track usage metrics for the crew's execution. Attributes: total_tokens: Total number of tokens used. @@ -11,28 +10,29 @@ class UsageMetrics(BaseModel): cached_prompt_tokens: Number of cached prompt tokens used. completion_tokens: Number of tokens used in completions. successful_requests: Number of successful requests made. + """ total_tokens: int = Field(default=0, description="Total number of tokens used.") prompt_tokens: int = Field( - default=0, description="Number of tokens used in prompts." + default=0, description="Number of tokens used in prompts.", ) cached_prompt_tokens: int = Field( - default=0, description="Number of cached prompt tokens used." + default=0, description="Number of cached prompt tokens used.", ) completion_tokens: int = Field( - default=0, description="Number of tokens used in completions." + default=0, description="Number of tokens used in completions.", ) successful_requests: int = Field( - default=0, description="Number of successful requests made." + default=0, description="Number of successful requests made.", ) - def add_usage_metrics(self, usage_metrics: "UsageMetrics"): - """ - Add the usage metrics from another UsageMetrics object. + def add_usage_metrics(self, usage_metrics: "UsageMetrics") -> None: + """Add the usage metrics from another UsageMetrics object. Args: usage_metrics (UsageMetrics): The usage metrics to add. + """ self.total_tokens += usage_metrics.total_tokens self.prompt_tokens += usage_metrics.prompt_tokens diff --git a/src/crewai/utilities/agent_utils.py b/src/crewai/utilities/agent_utils.py index 8af665140..e96ecdc76 100644 --- a/src/crewai/utilities/agent_utils.py +++ b/src/crewai/utilities/agent_utils.py @@ -1,6 +1,7 @@ import json import re -from typing import Any, Callable, Dict, List, Optional, Sequence, Union +from collections.abc import Callable, Sequence +from typing import Any from crewai.agents.parser import ( FINAL_ANSWER_AND_PARSABLE_ACTION_ERROR_MESSAGE, @@ -21,7 +22,7 @@ from crewai.utilities.exceptions.context_window_exceeding_exception import ( ) -def parse_tools(tools: List[BaseTool]) -> List[CrewStructuredTool]: +def parse_tools(tools: list[BaseTool]) -> list[CrewStructuredTool]: """Parse tools to be used for the task.""" tools_list = [] @@ -29,23 +30,24 @@ def parse_tools(tools: List[BaseTool]) -> List[CrewStructuredTool]: if isinstance(tool, CrewAITool): tools_list.append(tool.to_structured_tool()) else: - raise ValueError("Tool is not a CrewStructuredTool or BaseTool") + msg = "Tool is not a CrewStructuredTool or BaseTool" + raise ValueError(msg) return tools_list -def get_tool_names(tools: Sequence[Union[CrewStructuredTool, BaseTool]]) -> str: +def get_tool_names(tools: Sequence[CrewStructuredTool | BaseTool]) -> str: """Get the names of the tools.""" return ", ".join([t.name for t in tools]) def render_text_description_and_args( - tools: Sequence[Union[CrewStructuredTool, BaseTool]], + tools: Sequence[CrewStructuredTool | BaseTool], ) -> str: """Render the tool name, description, and args in plain text. - - search: This tool is used for search, args: {"query": {"type": "string"}} - calculator: This tool is used for math, \ + + search: This tool is used for search, args: {"query": {"type": "string"}} + calculator: This tool is used for math, \ args: {"expression": {"type": "string"}} """ tool_strings = [] @@ -61,22 +63,24 @@ def has_reached_max_iterations(iterations: int, max_iterations: int) -> bool: def handle_max_iterations_exceeded( - formatted_answer: Union[AgentAction, AgentFinish, None], + formatted_answer: AgentAction | AgentFinish | None, printer: Printer, i18n: I18N, - messages: List[Dict[str, str]], - llm: Union[LLM, BaseLLM], - callbacks: List[Any], -) -> Union[AgentAction, AgentFinish]: - """ - Handles the case when the maximum number of iterations is exceeded. + messages: list[dict[str, str]], + llm: LLM | BaseLLM, + callbacks: list[Any], +) -> AgentAction | AgentFinish: + """Handles the case when the maximum number of iterations is exceeded. Performs one more LLM call to get the final answer. - Parameters: + Parameters + ---------- formatted_answer: The last formatted answer from the agent. - Returns: + Returns + ------- The final formatted answer after exceeding max iterations. + """ printer.print( content="Maximum iterations reached. Requesting final answer.", @@ -103,19 +107,19 @@ def handle_max_iterations_exceeded( content="Received None or empty response from LLM call.", color="red", ) - raise ValueError("Invalid response from LLM call - None or empty.") + msg = "Invalid response from LLM call - None or empty." + raise ValueError(msg) - formatted_answer = format_answer(answer) + return format_answer(answer) # Return the formatted answer, regardless of its type - return formatted_answer -def format_message_for_llm(prompt: str, role: str = "user") -> Dict[str, str]: +def format_message_for_llm(prompt: str, role: str = "user") -> dict[str, str]: prompt = prompt.rstrip() return {"role": role, "content": prompt} -def format_answer(answer: str) -> Union[AgentAction, AgentFinish]: +def format_answer(answer: str) -> AgentAction | AgentFinish: """Format a response from the LLM into an AgentAction or AgentFinish.""" try: return CrewAgentParser.parse_text(answer) @@ -129,7 +133,7 @@ def format_answer(answer: str) -> Union[AgentAction, AgentFinish]: def enforce_rpm_limit( - request_within_rpm_limit: Optional[Callable[[], bool]] = None, + request_within_rpm_limit: Callable[[], bool] | None = None, ) -> None: """Enforce the requests per minute (RPM) limit if applicable.""" if request_within_rpm_limit: @@ -137,9 +141,9 @@ def enforce_rpm_limit( def get_llm_response( - llm: Union[LLM, BaseLLM], - messages: List[Dict[str, str]], - callbacks: List[Any], + llm: LLM | BaseLLM, + messages: list[dict[str, str]], + callbacks: list[Any], printer: Printer, ) -> str: """Call the LLM and return the response, handling any invalid responses.""" @@ -153,20 +157,21 @@ def get_llm_response( content=f"Error during LLM call: {e}", color="red", ) - raise e + raise if not answer: printer.print( content="Received None or empty response from LLM call.", color="red", ) - raise ValueError("Invalid response from LLM call - None or empty.") + msg = "Invalid response from LLM call - None or empty." + raise ValueError(msg) return answer def process_llm_response( - answer: str, use_stop_words: bool -) -> Union[AgentAction, AgentFinish]: + answer: str, use_stop_words: bool, +) -> AgentAction | AgentFinish: """Process the LLM response and format it into an AgentAction or AgentFinish.""" if not use_stop_words: try: @@ -182,10 +187,10 @@ def process_llm_response( def handle_agent_action_core( formatted_answer: AgentAction, tool_result: ToolResult, - messages: Optional[List[Dict[str, str]]] = None, - step_callback: Optional[Callable] = None, - show_logs: Optional[Callable] = None, -) -> Union[AgentAction, AgentFinish]: + messages: list[dict[str, str]] | None = None, + step_callback: Callable | None = None, + show_logs: Callable | None = None, +) -> AgentAction | AgentFinish: """Core logic for handling agent actions and tool results. Args: @@ -197,6 +202,7 @@ def handle_agent_action_core( Returns: Either an AgentAction or AgentFinish + """ if step_callback: step_callback(tool_result) @@ -226,6 +232,7 @@ def handle_unknown_error(printer: Any, exception: Exception) -> None: Args: printer: Printer instance for output exception: The exception that occurred + """ printer.print( content="An unknown error occurred. Please check the details below.", @@ -239,10 +246,10 @@ def handle_unknown_error(printer: Any, exception: Exception) -> None: def handle_output_parser_exception( e: OutputParserException, - messages: List[Dict[str, str]], + messages: list[dict[str, str]], iterations: int, log_error_after: int = 3, - printer: Optional[Any] = None, + printer: Any | None = None, ) -> AgentAction: """Handle OutputParserException by updating messages and formatted_answer. @@ -255,6 +262,7 @@ def handle_output_parser_exception( Returns: AgentAction: A formatted answer with the error + """ messages.append({"role": "user", "content": e.error}) @@ -282,18 +290,19 @@ def is_context_length_exceeded(exception: Exception) -> bool: Returns: bool: True if the exception is due to context length exceeding + """ return LLMContextLengthExceededException(str(exception))._is_context_limit_error( - str(exception) + str(exception), ) def handle_context_length( respect_context_window: bool, printer: Any, - messages: List[Dict[str, str]], + messages: list[dict[str, str]], llm: Any, - callbacks: List[Any], + callbacks: list[Any], i18n: Any, ) -> None: """Handle context length exceeded by either summarizing or raising an error. @@ -305,6 +314,7 @@ def handle_context_length( llm: LLM instance for summarization callbacks: List of callbacks for LLM i18n: I18N instance for messages + """ if respect_context_window: printer.print( @@ -317,15 +327,16 @@ def handle_context_length( content="Context length exceeded. Consider using smaller text or RAG tools from crewai_tools.", color="red", ) + msg = "Context length exceeded and user opted not to summarize. Consider using smaller text or RAG tools from crewai_tools." raise SystemExit( - "Context length exceeded and user opted not to summarize. Consider using smaller text or RAG tools from crewai_tools." + msg, ) def summarize_messages( - messages: List[Dict[str, str]], + messages: list[dict[str, str]], llm: Any, - callbacks: List[Any], + callbacks: list[Any], i18n: Any, ) -> None: """Summarize messages to fit within context window. @@ -335,6 +346,7 @@ def summarize_messages( llm: LLM instance for summarization callbacks: List of callbacks for LLM i18n: I18N instance for messages + """ messages_groups = [] for message in messages: @@ -348,7 +360,7 @@ def summarize_messages( summary = llm.call( [ format_message_for_llm( - i18n.slice("summarizer_system_message"), role="system" + i18n.slice("summarizer_system_message"), role="system", ), format_message_for_llm( i18n.slice("summarize_instruction").format(group=group["content"]), @@ -363,16 +375,16 @@ def summarize_messages( messages.clear() messages.append( format_message_for_llm( - i18n.slice("summary").format(merged_summary=merged_summary) - ) + i18n.slice("summary").format(merged_summary=merged_summary), + ), ) def show_agent_logs( printer: Printer, agent_role: str, - formatted_answer: Optional[Union[AgentAction, AgentFinish]] = None, - task_description: Optional[str] = None, + formatted_answer: AgentAction | AgentFinish | None = None, + task_description: str | None = None, verbose: bool = False, ) -> None: """Show agent logs for both start and execution states. @@ -383,6 +395,7 @@ def show_agent_logs( formatted_answer: Optional AgentAction or AgentFinish for execution logs task_description: Optional task description for start logs verbose: Whether to show verbose output + """ if not verbose: return @@ -392,16 +405,16 @@ def show_agent_logs( if formatted_answer is None: # Start logs printer.print( - content=f"\033[1m\033[95m# Agent:\033[00m \033[1m\033[92m{agent_role}\033[00m" + content=f"\033[1m\033[95m# Agent:\033[00m \033[1m\033[92m{agent_role}\033[00m", ) if task_description: printer.print( - content=f"\033[95m## Task:\033[00m \033[92m{task_description}\033[00m" + content=f"\033[95m## Task:\033[00m \033[92m{task_description}\033[00m", ) else: # Execution logs printer.print( - content=f"\n\n\033[1m\033[95m# Agent:\033[00m \033[1m\033[92m{agent_role}\033[00m" + content=f"\n\n\033[1m\033[95m# Agent:\033[00m \033[1m\033[92m{agent_role}\033[00m", ) if isinstance(formatted_answer, AgentAction): @@ -413,18 +426,18 @@ def show_agent_logs( ) if thought and thought != "": printer.print( - content=f"\033[95m## Thought:\033[00m \033[92m{thought}\033[00m" + content=f"\033[95m## Thought:\033[00m \033[92m{thought}\033[00m", ) printer.print( - content=f"\033[95m## Using tool:\033[00m \033[92m{formatted_answer.tool}\033[00m" + content=f"\033[95m## Using tool:\033[00m \033[92m{formatted_answer.tool}\033[00m", ) printer.print( - content=f"\033[95m## Tool Input:\033[00m \033[92m\n{formatted_json}\033[00m" + content=f"\033[95m## Tool Input:\033[00m \033[92m\n{formatted_json}\033[00m", ) printer.print( - content=f"\033[95m## Tool Output:\033[00m \033[92m\n{formatted_answer.result}\033[00m" + content=f"\033[95m## Tool Output:\033[00m \033[92m\n{formatted_answer.result}\033[00m", ) elif isinstance(formatted_answer, AgentFinish): printer.print( - content=f"\033[95m## Final Answer:\033[00m \033[92m\n{formatted_answer.output}\033[00m\n\n" + content=f"\033[95m## Final Answer:\033[00m \033[92m\n{formatted_answer.output}\033[00m\n\n", ) diff --git a/src/crewai/utilities/chromadb.py b/src/crewai/utilities/chromadb.py index d993a5896..4a12c0bbc 100644 --- a/src/crewai/utilities/chromadb.py +++ b/src/crewai/utilities/chromadb.py @@ -1,5 +1,4 @@ import re -from typing import Optional MIN_COLLECTION_LENGTH = 3 MAX_COLLECTION_LENGTH = 63 @@ -11,32 +10,32 @@ IPV4_PATTERN = re.compile(r"^(\d{1,3}\.){3}\d{1,3}$") def is_ipv4_pattern(name: str) -> bool: - """ - Check if a string matches an IPv4 address pattern. + """Check if a string matches an IPv4 address pattern. Args: name: The string to check Returns: True if the string matches an IPv4 pattern, False otherwise + """ return bool(IPV4_PATTERN.match(name)) -def sanitize_collection_name(name: Optional[str]) -> str: - """ - Sanitize a collection name to meet ChromaDB requirements: +def sanitize_collection_name(name: str | None) -> str: + """Sanitize a collection name to meet ChromaDB requirements: 1. 3-63 characters long 2. Starts and ends with alphanumeric character 3. Contains only alphanumeric characters, underscores, or hyphens 4. No consecutive periods - 5. Not a valid IPv4 address + 5. Not a valid IPv4 address. Args: name: The original collection name to sanitize Returns: A sanitized collection name that meets ChromaDB requirements + """ if not name: return DEFAULT_COLLECTION diff --git a/src/crewai/utilities/config.py b/src/crewai/utilities/config.py index 156a3e66b..5e865e84d 100644 --- a/src/crewai/utilities/config.py +++ b/src/crewai/utilities/config.py @@ -1,13 +1,12 @@ -from typing import Any, Dict, Type +from typing import Any from pydantic import BaseModel def process_config( - values: Dict[str, Any], model_class: Type[BaseModel] -) -> Dict[str, Any]: - """ - Process the config dictionary and update the values accordingly. + values: dict[str, Any], model_class: type[BaseModel], +) -> dict[str, Any]: + """Process the config dictionary and update the values accordingly. Args: values (Dict[str, Any]): The dictionary of values to update. @@ -15,6 +14,7 @@ def process_config( Returns: Dict[str, Any]: The updated values dictionary. + """ config = values.get("config", {}) if not config: diff --git a/src/crewai/utilities/converter.py b/src/crewai/utilities/converter.py index a6144868e..73eb4c476 100644 --- a/src/crewai/utilities/converter.py +++ b/src/crewai/utilities/converter.py @@ -1,6 +1,6 @@ import json import re -from typing import Any, Optional, Type, Union, get_args, get_origin +from typing import Any, Union, get_args, get_origin from pydantic import BaseModel, ValidationError @@ -30,7 +30,7 @@ class Converter(OutputConverter): [ {"role": "system", "content": self.instructions}, {"role": "user", "content": self.text}, - ] + ], ) try: # Try to directly validate the response JSON @@ -47,25 +47,29 @@ class Converter(OutputConverter): parsed = json.loads(result) result = self.model.parse_obj(parsed) except Exception as parse_err: + msg = f"Failed to convert partial JSON result into Pydantic: {parse_err}" raise ConverterError( - f"Failed to convert partial JSON result into Pydantic: {parse_err}" + msg, ) else: + msg = "handle_partial_json returned an unexpected type." raise ConverterError( - "handle_partial_json returned an unexpected type." + msg, ) return result except ValidationError as e: if current_attempt < self.max_attempts: return self.to_pydantic(current_attempt + 1) + msg = f"Failed to convert text into a Pydantic model due to validation error: {e}" raise ConverterError( - f"Failed to convert text into a Pydantic model due to validation error: {e}" + msg, ) except Exception as e: if current_attempt < self.max_attempts: return self.to_pydantic(current_attempt + 1) + msg = f"Failed to convert text into a Pydantic model due to error: {e}" raise ConverterError( - f"Failed to convert text into a Pydantic model due to error: {e}" + msg, ) def to_json(self, current_attempt=1): @@ -73,15 +77,14 @@ class Converter(OutputConverter): try: if self.llm.supports_function_calling(): return self._create_instructor().to_json() - else: - return json.dumps( - self.llm.call( - [ - {"role": "system", "content": self.instructions}, - {"role": "user", "content": self.text}, - ] - ) - ) + return json.dumps( + self.llm.call( + [ + {"role": "system", "content": self.instructions}, + {"role": "user", "content": self.text}, + ], + ), + ) except Exception as e: if current_attempt < self.max_attempts: return self.to_json(current_attempt + 1) @@ -91,12 +94,11 @@ class Converter(OutputConverter): """Create an instructor.""" from crewai.utilities import InternalInstructor - inst = InternalInstructor( + return InternalInstructor( llm=self.llm, model=self.model, content=self.text, ) - return inst def _convert_with_instructions(self): """Create a chain.""" @@ -109,18 +111,18 @@ class Converter(OutputConverter): [ {"role": "system", "content": self.instructions}, {"role": "user", "content": self.text}, - ] + ], ) return parser.parse_result(result) def convert_to_model( result: str, - output_pydantic: Optional[Type[BaseModel]], - output_json: Optional[Type[BaseModel]], + output_pydantic: type[BaseModel] | None, + output_json: type[BaseModel] | None, agent: Any, - converter_cls: Optional[Type[Converter]] = None, -) -> Union[dict, BaseModel, str]: + converter_cls: type[Converter] | None = None, +) -> dict | BaseModel | str: model = output_pydantic or output_json if model is None: return result @@ -129,12 +131,12 @@ def convert_to_model( return validate_model(escaped_result, model, bool(output_json)) except json.JSONDecodeError: return handle_partial_json( - result, model, bool(output_json), agent, converter_cls + result, model, bool(output_json), agent, converter_cls, ) except ValidationError: return handle_partial_json( - result, model, bool(output_json), agent, converter_cls + result, model, bool(output_json), agent, converter_cls, ) except Exception as e: @@ -146,8 +148,8 @@ def convert_to_model( def validate_model( - result: str, model: Type[BaseModel], is_json_output: bool -) -> Union[dict, BaseModel]: + result: str, model: type[BaseModel], is_json_output: bool, +) -> dict | BaseModel: exported_result = model.model_validate_json(result) if is_json_output: return exported_result.model_dump() @@ -156,11 +158,11 @@ def validate_model( def handle_partial_json( result: str, - model: Type[BaseModel], + model: type[BaseModel], is_json_output: bool, agent: Any, - converter_cls: Optional[Type[Converter]] = None, -) -> Union[dict, BaseModel, str]: + converter_cls: type[Converter] | None = None, +) -> dict | BaseModel | str: match = re.search(r"({.*})", result, re.DOTALL) if match: try: @@ -179,17 +181,17 @@ def handle_partial_json( ) return convert_with_instructions( - result, model, is_json_output, agent, converter_cls + result, model, is_json_output, agent, converter_cls, ) def convert_with_instructions( result: str, - model: Type[BaseModel], + model: type[BaseModel], is_json_output: bool, agent: Any, - converter_cls: Optional[Type[Converter]] = None, -) -> Union[dict, BaseModel, str]: + converter_cls: type[Converter] | None = None, +) -> dict | BaseModel | str: llm = agent.function_calling_llm or agent.llm instructions = get_conversion_instructions(model, llm) converter = create_converter( @@ -214,7 +216,7 @@ def convert_with_instructions( return exported_result -def get_conversion_instructions(model: Type[BaseModel], llm: Any) -> str: +def get_conversion_instructions(model: type[BaseModel], llm: Any) -> str: instructions = "Please convert the following text into valid JSON." if llm and not isinstance(llm, str) and llm.supports_function_calling(): model_schema = PydanticSchemaParser(model=model).get_schema() @@ -232,8 +234,8 @@ def get_conversion_instructions(model: Type[BaseModel], llm: Any) -> str: def create_converter( - agent: Optional[Any] = None, - converter_cls: Optional[Type[Converter]] = None, + agent: Any | None = None, + converter_cls: type[Converter] | None = None, *args, **kwargs, ) -> Converter: @@ -241,21 +243,23 @@ def create_converter( if hasattr(agent, "get_output_converter"): converter = agent.get_output_converter(*args, **kwargs) else: - raise AttributeError("Agent does not have a 'get_output_converter' method") + msg = "Agent does not have a 'get_output_converter' method" + raise AttributeError(msg) elif converter_cls: converter = converter_cls(*args, **kwargs) else: - raise ValueError("Either agent or converter_cls must be provided") + msg = "Either agent or converter_cls must be provided" + raise ValueError(msg) if not converter: - raise Exception("No output converter found or set.") + msg = "No output converter found or set." + raise Exception(msg) return converter -def generate_model_description(model: Type[BaseModel]) -> str: - """ - Generate a string description of a Pydantic model's fields and their types. +def generate_model_description(model: type[BaseModel]) -> str: + """Generate a string description of a Pydantic model's fields and their types. This function takes a Pydantic model class and returns a string that describes the model's fields and their respective types. The description includes handling @@ -272,20 +276,18 @@ def generate_model_description(model: Type[BaseModel]) -> str: non_none_args = [arg for arg in args if arg is not type(None)] if len(non_none_args) == 1: return f"Optional[{describe_field(non_none_args[0])}]" - else: - return f"Optional[Union[{', '.join(describe_field(arg) for arg in non_none_args)}]]" - elif origin is list: + return f"Optional[Union[{', '.join(describe_field(arg) for arg in non_none_args)}]]" + if origin is list: return f"List[{describe_field(args[0])}]" - elif origin is dict: + if origin is dict: key_type = describe_field(args[0]) value_type = describe_field(args[1]) return f"Dict[{key_type}, {value_type}]" - elif isinstance(field_type, type) and issubclass(field_type, BaseModel): + if isinstance(field_type, type) and issubclass(field_type, BaseModel): return generate_model_description(field_type) - elif hasattr(field_type, "__name__"): + if hasattr(field_type, "__name__"): return field_type.__name__ - else: - return str(field_type) + return str(field_type) fields = model.model_fields field_descriptions = [ diff --git a/src/crewai/utilities/crew_json_encoder.py b/src/crewai/utilities/crew_json_encoder.py index 6e667431d..c1b910365 100644 --- a/src/crewai/utilities/crew_json_encoder.py +++ b/src/crewai/utilities/crew_json_encoder.py @@ -11,13 +11,14 @@ from pydantic import BaseModel class CrewJSONEncoder(json.JSONEncoder): """Custom JSON encoder for CrewAI objects and special types.""" + def default(self, obj): if isinstance(obj, BaseModel): return self._handle_pydantic_model(obj) - elif isinstance(obj, UUID) or isinstance(obj, Decimal) or isinstance(obj, Enum): + if isinstance(obj, (UUID, Decimal, Enum)): return str(obj) - elif isinstance(obj, datetime) or isinstance(obj, date): + if isinstance(obj, (datetime, date)): return obj.isoformat() return super().default(obj) @@ -29,10 +30,10 @@ class CrewJSONEncoder(json.JSONEncoder): for key, value in data.items(): if isinstance(value, BaseModel): data[key] = str( - value + value, ) # Convert nested models to string representation return data except RecursionError: return str( - obj + obj, ) # Fall back to string representation if circular reference is detected diff --git a/src/crewai/utilities/crew_pydantic_output_parser.py b/src/crewai/utilities/crew_pydantic_output_parser.py index d0dbfae06..0ee010894 100644 --- a/src/crewai/utilities/crew_pydantic_output_parser.py +++ b/src/crewai/utilities/crew_pydantic_output_parser.py @@ -1,5 +1,5 @@ import json -from typing import Any, Type +from typing import Any import regex from pydantic import BaseModel, ValidationError @@ -11,7 +11,7 @@ from crewai.agents.parser import OutputParserException class CrewPydanticOutputParser: """Parses text outputs into specified Pydantic models.""" - pydantic_object: Type[BaseModel] + pydantic_object: type[BaseModel] def parse_result(self, result: str) -> Any: result = self._transform_in_valid_json(result) diff --git a/src/crewai/utilities/embedding_configurator.py b/src/crewai/utilities/embedding_configurator.py index e523b60f0..c484d05d5 100644 --- a/src/crewai/utilities/embedding_configurator.py +++ b/src/crewai/utilities/embedding_configurator.py @@ -1,12 +1,12 @@ import os -from typing import Any, Dict, Optional, cast +from typing import Any, cast from chromadb import Documents, EmbeddingFunction, Embeddings from chromadb.api.types import validate_embedding_function class EmbeddingConfigurator: - def __init__(self): + def __init__(self) -> None: self.embedding_functions = { "openai": self._configure_openai, "azure": self._configure_azure, @@ -23,7 +23,7 @@ class EmbeddingConfigurator: def configure_embedder( self, - embedder_config: Optional[Dict[str, Any]] = None, + embedder_config: dict[str, Any] | None = None, ) -> EmbeddingFunction: """Configures and returns an embedding function based on the provided config.""" if embedder_config is None: @@ -34,8 +34,9 @@ class EmbeddingConfigurator: model_name = config.get("model") if provider != "custom" else None if provider not in self.embedding_functions: + msg = f"Unsupported embedding provider: {provider}, supported providers: {list(self.embedding_functions.keys())}" raise Exception( - f"Unsupported embedding provider: {provider}, supported providers: {list(self.embedding_functions.keys())}" + msg, ) embedding_function = self.embedding_functions[provider] @@ -52,7 +53,7 @@ class EmbeddingConfigurator: ) return OpenAIEmbeddingFunction( - api_key=os.getenv("OPENAI_API_KEY"), model_name="text-embedding-3-small" + api_key=os.getenv("OPENAI_API_KEY"), model_name="text-embedding-3-small", ) @staticmethod @@ -178,8 +179,9 @@ class EmbeddingConfigurator: from ibm_watsonx_ai import Credentials from ibm_watsonx_ai.metanames import EmbedTextParamsMetaNames as EmbedParams except ImportError as e: + msg = "IBM Watson dependencies are not installed. Please install them to use Watson embedding." raise ImportError( - "IBM Watson dependencies are not installed. Please install them to use Watson embedding." + msg, ) from e class WatsonEmbeddingFunction(EmbeddingFunction): @@ -196,17 +198,16 @@ class EmbeddingConfigurator: model_id=config.get("model"), params=embed_params, credentials=Credentials( - api_key=config.get("api_key"), url=config.get("api_url") + api_key=config.get("api_key"), url=config.get("api_url"), ), project_id=config.get("project_id"), ) try: embeddings = embedding.embed_documents(input) - return cast(Embeddings, embeddings) - except Exception as e: - print("Error during Watson embedding:", e) - raise e + return cast("Embeddings", embeddings) + except Exception: + raise return WatsonEmbeddingFunction() @@ -218,19 +219,23 @@ class EmbeddingConfigurator: validate_embedding_function(custom_embedder) return custom_embedder except Exception as e: - raise ValueError(f"Invalid custom embedding function: {str(e)}") + msg = f"Invalid custom embedding function: {e!s}" + raise ValueError(msg) elif callable(custom_embedder): try: instance = custom_embedder() if isinstance(instance, EmbeddingFunction): validate_embedding_function(instance) return instance + msg = "Custom embedder does not create an EmbeddingFunction instance" raise ValueError( - "Custom embedder does not create an EmbeddingFunction instance" + msg, ) except Exception as e: - raise ValueError(f"Error instantiating custom embedder: {str(e)}") + msg = f"Error instantiating custom embedder: {e!s}" + raise ValueError(msg) else: + msg = "Custom embedder must be an instance of `EmbeddingFunction` or a callable that creates one" raise ValueError( - "Custom embedder must be an instance of `EmbeddingFunction` or a callable that creates one" + msg, ) diff --git a/src/crewai/utilities/errors.py b/src/crewai/utilities/errors.py index f673c0600..1bd61ec5e 100644 --- a/src/crewai/utilities/errors.py +++ b/src/crewai/utilities/errors.py @@ -1,16 +1,16 @@ """Error message definitions for CrewAI database operations.""" -from typing import Optional class DatabaseOperationError(Exception): """Base exception class for database operation errors.""" - def __init__(self, message: str, original_error: Optional[Exception] = None): + def __init__(self, message: str, original_error: Exception | None = None) -> None: """Initialize the database operation error. Args: message: The error message to display original_error: The original exception that caused this error, if any + """ super().__init__(message) self.original_error = original_error @@ -35,5 +35,6 @@ class DatabaseError: Returns: The formatted error message + """ return template.format(str(error)) diff --git a/src/crewai/utilities/evaluators/crew_evaluator_handler.py b/src/crewai/utilities/evaluators/crew_evaluator_handler.py index 984dcf97f..a31adb2e8 100644 --- a/src/crewai/utilities/evaluators/crew_evaluator_handler.py +++ b/src/crewai/utilities/evaluators/crew_evaluator_handler.py @@ -14,26 +14,26 @@ from crewai.telemetry import Telemetry class TaskEvaluationPydanticOutput(BaseModel): quality: float = Field( - description="A score from 1 to 10 evaluating on completion, quality, and overall performance from the task_description and task_expected_output to the actual Task Output." + description="A score from 1 to 10 evaluating on completion, quality, and overall performance from the task_description and task_expected_output to the actual Task Output.", ) class CrewEvaluator: - """ - A class to evaluate the performance of the agents in the crew based on the tasks they have performed. + """A class to evaluate the performance of the agents in the crew based on the tasks they have performed. Attributes: crew (Crew): The crew of agents to evaluate. eval_llm (BaseLLM): Language model instance to use for evaluations tasks_scores (defaultdict): A dictionary to store the scores of the agents for each task. iteration (int): The current iteration of the evaluation. + """ tasks_scores: defaultdict = defaultdict(list) run_execution_times: defaultdict = defaultdict(list) iteration: int = 0 - def __init__(self, crew, eval_llm: InstanceOf[BaseLLM]): + def __init__(self, crew, eval_llm: InstanceOf[BaseLLM]) -> None: self.crew = crew self.llm = eval_llm self._telemetry = Telemetry() @@ -56,7 +56,7 @@ class CrewEvaluator: ) def _evaluation_task( - self, evaluator_agent: Agent, task_to_evaluate: Task, task_output: str + self, evaluator_agent: Agent, task_to_evaluate: Task, task_output: str, ) -> Task: return Task( description=( @@ -76,8 +76,7 @@ class CrewEvaluator: self.iteration = iteration def print_crew_evaluation_result(self) -> None: - """ - Prints the evaluation result of the crew in a table. + """Prints the evaluation result of the crew in a table. A Crew with 2 tasks using the command crewai test -n 3 will output the following table: @@ -97,7 +96,7 @@ class CrewEvaluator: └────────────────────┴───────┴───────┴───────┴────────────┴──────────────────────────────┘ """ task_averages = [ - sum(scores) / len(scores) for scores in zip(*self.tasks_scores.values()) + sum(scores) / len(scores) for scores in zip(*self.tasks_scores.values(), strict=False) ] crew_average = sum(task_averages) / len(task_averages) @@ -151,13 +150,13 @@ class CrewEvaluator: ] execution_time_avg = int(sum(run_exec_times) / len(run_exec_times)) table.add_row( - "Execution Time (s)", *map(str, run_exec_times), f"{execution_time_avg}", "" + "Execution Time (s)", *map(str, run_exec_times), f"{execution_time_avg}", "", ) console = Console() console.print(table) - def evaluate(self, task_output: TaskOutput): + def evaluate(self, task_output: TaskOutput) -> None: """Evaluates the performance of the agents in the crew based on the tasks they have performed.""" current_task = None for task in self.crew.tasks: @@ -166,13 +165,14 @@ class CrewEvaluator: break if not current_task or not task_output: + msg = "Task to evaluate and task output are required for evaluation" raise ValueError( - "Task to evaluate and task output are required for evaluation" + msg, ) evaluator_agent = self._evaluator_agent() evaluation_task = self._evaluation_task( - evaluator_agent, current_task, task_output.raw + evaluator_agent, current_task, task_output.raw, ) evaluation_result = evaluation_task.execute_sync() @@ -186,7 +186,8 @@ class CrewEvaluator: ) self.tasks_scores[self.iteration].append(evaluation_result.pydantic.quality) self.run_execution_times[self.iteration].append( - current_task.execution_duration + current_task.execution_duration, ) else: - raise ValueError("Evaluation result is not in the expected format") + msg = "Evaluation result is not in the expected format" + raise ValueError(msg) diff --git a/src/crewai/utilities/evaluators/task_evaluator.py b/src/crewai/utilities/evaluators/task_evaluator.py index 6dde83c24..3cfc5a59c 100644 --- a/src/crewai/utilities/evaluators/task_evaluator.py +++ b/src/crewai/utilities/evaluators/task_evaluator.py @@ -1,4 +1,3 @@ -from typing import List from pydantic import BaseModel, Field @@ -11,41 +10,41 @@ class Entity(BaseModel): name: str = Field(description="The name of the entity.") type: str = Field(description="The type of the entity.") description: str = Field(description="Description of the entity.") - relationships: List[str] = Field(description="Relationships of the entity.") + relationships: list[str] = Field(description="Relationships of the entity.") class TaskEvaluation(BaseModel): - suggestions: List[str] = Field( - description="Suggestions to improve future similar tasks." + suggestions: list[str] = Field( + description="Suggestions to improve future similar tasks.", ) quality: float = Field( - description="A score from 0 to 10 evaluating on completion, quality, and overall performance, all taking into account the task description, expected output, and the result of the task." + description="A score from 0 to 10 evaluating on completion, quality, and overall performance, all taking into account the task description, expected output, and the result of the task.", ) - entities: List[Entity] = Field( - description="Entities extracted from the task output." + entities: list[Entity] = Field( + description="Entities extracted from the task output.", ) class TrainingTaskEvaluation(BaseModel): - suggestions: List[str] = Field( - description="List of clear, actionable instructions derived from the Human Feedbacks to enhance the Agent's performance. Analyze the differences between Initial Outputs and Improved Outputs to generate specific action items for future tasks. Ensure all key and specific points from the human feedback are incorporated into these instructions." + suggestions: list[str] = Field( + description="List of clear, actionable instructions derived from the Human Feedbacks to enhance the Agent's performance. Analyze the differences between Initial Outputs and Improved Outputs to generate specific action items for future tasks. Ensure all key and specific points from the human feedback are incorporated into these instructions.", ) quality: float = Field( - description="A score from 0 to 10 evaluating on completion, quality, and overall performance from the improved output to the initial output based on the human feedback." + description="A score from 0 to 10 evaluating on completion, quality, and overall performance from the improved output to the initial output based on the human feedback.", ) final_summary: str = Field( - description="A step by step action items to improve the next Agent based on the human-feedback and improved output." + description="A step by step action items to improve the next Agent based on the human-feedback and improved output.", ) class TaskEvaluator: - def __init__(self, original_agent): + def __init__(self, original_agent) -> None: self.llm = original_agent.llm self.original_agent = original_agent def evaluate(self, task, output) -> TaskEvaluation: crewai_event_bus.emit( - self, TaskEvaluationEvent(evaluation_type="task_evaluation", task=task) + self, TaskEvaluationEvent(evaluation_type="task_evaluation", task=task), ) evaluation_query = ( f"Assess the quality of the task completed based on the description, expected output, and actual results.\n\n" @@ -74,17 +73,18 @@ class TaskEvaluator: return converter.to_pydantic() def evaluate_training_data( - self, training_data: dict, agent_id: str + self, training_data: dict, agent_id: str, ) -> TrainingTaskEvaluation: - """ - Evaluate the training data based on the llm output, human feedback, and improved output. + """Evaluate the training data based on the llm output, human feedback, and improved output. - Parameters: + Parameters + ---------- - training_data (dict): The training data to be evaluated. - agent_id (str): The ID of the agent. + """ crewai_event_bus.emit( - self, TaskEvaluationEvent(evaluation_type="training_data_evaluation") + self, TaskEvaluationEvent(evaluation_type="training_data_evaluation"), ) output_training_data = training_data[agent_id] @@ -129,7 +129,7 @@ class TaskEvaluator: if not self.llm.supports_function_calling(): model_schema = PydanticSchemaParser( - model=TrainingTaskEvaluation + model=TrainingTaskEvaluation, ).get_schema() instructions = f"{instructions}\n\nThe json should have the following structure, with the following keys:\n{model_schema}" @@ -140,5 +140,4 @@ class TaskEvaluator: instructions=instructions, ) - pydantic_result = converter.to_pydantic() - return pydantic_result + return converter.to_pydantic() diff --git a/src/crewai/utilities/events/agent_events.py b/src/crewai/utilities/events/agent_events.py index 51b8d2122..6638670c4 100644 --- a/src/crewai/utilities/events/agent_events.py +++ b/src/crewai/utilities/events/agent_events.py @@ -1,4 +1,5 @@ -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Union +from collections.abc import Sequence +from typing import Any from crewai.agents.agent_builder.base_agent import BaseAgent from crewai.tools.base_tool import BaseTool @@ -6,22 +7,19 @@ from crewai.tools.structured_tool import CrewStructuredTool from .base_events import BaseEvent -if TYPE_CHECKING: - from crewai.agents.agent_builder.base_agent import BaseAgent - class AgentExecutionStartedEvent(BaseEvent): - """Event emitted when an agent starts executing a task""" + """Event emitted when an agent starts executing a task.""" agent: BaseAgent task: Any - tools: Optional[Sequence[Union[BaseTool, CrewStructuredTool]]] + tools: Sequence[BaseTool | CrewStructuredTool] | None task_prompt: str type: str = "agent_execution_started" model_config = {"arbitrary_types_allowed": True} - def __init__(self, **data): + def __init__(self, **data) -> None: super().__init__(**data) # Set fingerprint data from the agent if hasattr(self.agent, "fingerprint") and self.agent.fingerprint: @@ -35,14 +33,14 @@ class AgentExecutionStartedEvent(BaseEvent): class AgentExecutionCompletedEvent(BaseEvent): - """Event emitted when an agent completes executing a task""" + """Event emitted when an agent completes executing a task.""" agent: BaseAgent task: Any output: str type: str = "agent_execution_completed" - def __init__(self, **data): + def __init__(self, **data) -> None: super().__init__(**data) # Set fingerprint data from the agent if hasattr(self.agent, "fingerprint") and self.agent.fingerprint: @@ -56,14 +54,14 @@ class AgentExecutionCompletedEvent(BaseEvent): class AgentExecutionErrorEvent(BaseEvent): - """Event emitted when an agent encounters an error during execution""" + """Event emitted when an agent encounters an error during execution.""" agent: BaseAgent task: Any error: str type: str = "agent_execution_error" - def __init__(self, **data): + def __init__(self, **data) -> None: super().__init__(**data) # Set fingerprint data from the agent if hasattr(self.agent, "fingerprint") and self.agent.fingerprint: @@ -78,27 +76,27 @@ class AgentExecutionErrorEvent(BaseEvent): # New event classes for LiteAgent class LiteAgentExecutionStartedEvent(BaseEvent): - """Event emitted when a LiteAgent starts executing""" + """Event emitted when a LiteAgent starts executing.""" - agent_info: Dict[str, Any] - tools: Optional[Sequence[Union[BaseTool, CrewStructuredTool]]] - messages: Union[str, List[Dict[str, str]]] + agent_info: dict[str, Any] + tools: Sequence[BaseTool | CrewStructuredTool] | None + messages: str | list[dict[str, str]] type: str = "lite_agent_execution_started" model_config = {"arbitrary_types_allowed": True} class LiteAgentExecutionCompletedEvent(BaseEvent): - """Event emitted when a LiteAgent completes execution""" + """Event emitted when a LiteAgent completes execution.""" - agent_info: Dict[str, Any] + agent_info: dict[str, Any] output: str type: str = "lite_agent_execution_completed" class LiteAgentExecutionErrorEvent(BaseEvent): - """Event emitted when a LiteAgent encounters an error during execution""" + """Event emitted when a LiteAgent encounters an error during execution.""" - agent_info: Dict[str, Any] + agent_info: dict[str, Any] error: str type: str = "lite_agent_execution_error" diff --git a/src/crewai/utilities/events/base_event_listener.py b/src/crewai/utilities/events/base_event_listener.py index f08b70025..12b5d03a0 100644 --- a/src/crewai/utilities/events/base_event_listener.py +++ b/src/crewai/utilities/events/base_event_listener.py @@ -1,5 +1,4 @@ from abc import ABC, abstractmethod -from logging import Logger from crewai.utilities.events.crewai_event_bus import CrewAIEventsBus, crewai_event_bus @@ -7,7 +6,7 @@ from crewai.utilities.events.crewai_event_bus import CrewAIEventsBus, crewai_eve class BaseEventListener(ABC): verbose: bool = False - def __init__(self): + def __init__(self) -> None: super().__init__() self.setup_listeners(crewai_event_bus) diff --git a/src/crewai/utilities/events/base_events.py b/src/crewai/utilities/events/base_events.py index 46648500b..867f82bc3 100644 --- a/src/crewai/utilities/events/base_events.py +++ b/src/crewai/utilities/events/base_events.py @@ -1,5 +1,5 @@ from datetime import datetime -from typing import Any, Dict, Optional +from typing import Any from pydantic import BaseModel, Field @@ -7,22 +7,22 @@ from crewai.utilities.serialization import to_serializable class BaseEvent(BaseModel): - """Base class for all events""" + """Base class for all events.""" timestamp: datetime = Field(default_factory=datetime.now) type: str - source_fingerprint: Optional[str] = None # UUID string of the source entity - source_type: Optional[str] = None # "agent", "task", "crew" - fingerprint_metadata: Optional[Dict[str, Any]] = None # Any relevant metadata + source_fingerprint: str | None = None # UUID string of the source entity + source_type: str | None = None # "agent", "task", "crew" + fingerprint_metadata: dict[str, Any] | None = None # Any relevant metadata def to_json(self, exclude: set[str] | None = None): - """ - Converts the event to a JSON-serializable dictionary. + """Converts the event to a JSON-serializable dictionary. Args: exclude (set[str], optional): Set of keys to exclude from the result. Defaults to None. Returns: dict: A JSON-serializable dictionary. + """ return to_serializable(self, exclude=exclude) diff --git a/src/crewai/utilities/events/crew_events.py b/src/crewai/utilities/events/crew_events.py index d73cd95d3..4fc3cbacd 100644 --- a/src/crewai/utilities/events/crew_events.py +++ b/src/crewai/utilities/events/crew_events.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING, Any, Dict, Optional, Union +from typing import TYPE_CHECKING, Any from crewai.utilities.events.base_events import BaseEvent @@ -9,12 +9,12 @@ else: class CrewBaseEvent(BaseEvent): - """Base class for crew events with fingerprint handling""" + """Base class for crew events with fingerprint handling.""" - crew_name: Optional[str] - crew: Optional[Crew] = None + crew_name: str | None + crew: Crew | None = None - def __init__(self, **data): + def __init__(self, **data) -> None: super().__init__(**data) self.set_crew_fingerprint() @@ -36,37 +36,37 @@ class CrewBaseEvent(BaseEvent): class CrewKickoffStartedEvent(CrewBaseEvent): - """Event emitted when a crew starts execution""" + """Event emitted when a crew starts execution.""" - inputs: Optional[Dict[str, Any]] + inputs: dict[str, Any] | None type: str = "crew_kickoff_started" class CrewKickoffCompletedEvent(CrewBaseEvent): - """Event emitted when a crew completes execution""" + """Event emitted when a crew completes execution.""" output: Any type: str = "crew_kickoff_completed" class CrewKickoffFailedEvent(CrewBaseEvent): - """Event emitted when a crew fails to complete execution""" + """Event emitted when a crew fails to complete execution.""" error: str type: str = "crew_kickoff_failed" class CrewTrainStartedEvent(CrewBaseEvent): - """Event emitted when a crew starts training""" + """Event emitted when a crew starts training.""" n_iterations: int filename: str - inputs: Optional[Dict[str, Any]] + inputs: dict[str, Any] | None type: str = "crew_train_started" class CrewTrainCompletedEvent(CrewBaseEvent): - """Event emitted when a crew completes training""" + """Event emitted when a crew completes training.""" n_iterations: int filename: str @@ -74,29 +74,29 @@ class CrewTrainCompletedEvent(CrewBaseEvent): class CrewTrainFailedEvent(CrewBaseEvent): - """Event emitted when a crew fails to complete training""" + """Event emitted when a crew fails to complete training.""" error: str type: str = "crew_train_failed" class CrewTestStartedEvent(CrewBaseEvent): - """Event emitted when a crew starts testing""" + """Event emitted when a crew starts testing.""" n_iterations: int - eval_llm: Optional[Union[str, Any]] - inputs: Optional[Dict[str, Any]] + eval_llm: str | Any | None + inputs: dict[str, Any] | None type: str = "crew_test_started" class CrewTestCompletedEvent(CrewBaseEvent): - """Event emitted when a crew completes testing""" + """Event emitted when a crew completes testing.""" type: str = "crew_test_completed" class CrewTestFailedEvent(CrewBaseEvent): - """Event emitted when a crew fails to complete testing""" + """Event emitted when a crew fails to complete testing.""" error: str type: str = "crew_test_failed" diff --git a/src/crewai/utilities/events/crewai_event_bus.py b/src/crewai/utilities/events/crewai_event_bus.py index f255e5513..bcd745137 100644 --- a/src/crewai/utilities/events/crewai_event_bus.py +++ b/src/crewai/utilities/events/crewai_event_bus.py @@ -1,6 +1,7 @@ import threading -from contextlib import contextmanager -from typing import Any, Callable, Dict, List, Type, TypeVar, cast +from collections.abc import Callable +from contextlib import contextmanager, suppress +from typing import Any, TypeVar, cast from blinker import Signal @@ -11,8 +12,7 @@ EventT = TypeVar("EventT", bound=BaseEvent) class CrewAIEventsBus: - """ - A singleton event bus that uses blinker signals for event handling. + """A singleton event bus that uses blinker signals for event handling. Allows both internal (Flow/Crew) and external event handling. """ @@ -23,20 +23,19 @@ class CrewAIEventsBus: if cls._instance is None: with cls._lock: if cls._instance is None: # prevent race condition - cls._instance = super(CrewAIEventsBus, cls).__new__(cls) + cls._instance = super().__new__(cls) cls._instance._initialize() return cls._instance def _initialize(self) -> None: - """Initialize the event bus internal state""" + """Initialize the event bus internal state.""" self._signal = Signal("crewai_event_bus") - self._handlers: Dict[Type[BaseEvent], List[Callable]] = {} + self._handlers: dict[type[BaseEvent], list[Callable]] = {} def on( - self, event_type: Type[EventT] + self, event_type: type[EventT], ) -> Callable[[Callable[[Any, EventT], None]], Callable[[Any, EventT], None]]: - """ - Decorator to register an event handler for a specific event type. + """Decorator to register an event handler for a specific event type. Usage: @crewai_event_bus.on(AgentExecutionCompletedEvent) @@ -53,46 +52,41 @@ class CrewAIEventsBus: if event_type not in self._handlers: self._handlers[event_type] = [] self._handlers[event_type].append( - cast(Callable[[Any, EventT], None], handler) + cast("Callable[[Any, EventT], None]", handler), ) return handler return decorator def emit(self, source: Any, event: BaseEvent) -> None: - """ - Emit an event to all registered handlers + """Emit an event to all registered handlers. Args: source: The object emitting the event event: The event instance to emit + """ for event_type, handlers in self._handlers.items(): if isinstance(event, event_type): for handler in handlers: - try: + with suppress(Exception): handler(source, event) - except Exception as e: - print( - f"[EventBus Error] Handler '{handler.__name__}' failed for event '{event_type.__name__}': {e}" - ) self._signal.send(source, event=event) def register_handler( - self, event_type: Type[EventTypes], handler: Callable[[Any, EventTypes], None] + self, event_type: type[EventTypes], handler: Callable[[Any, EventTypes], None], ) -> None: - """Register an event handler for a specific event type""" + """Register an event handler for a specific event type.""" if event_type not in self._handlers: self._handlers[event_type] = [] self._handlers[event_type].append( - cast(Callable[[Any, EventTypes], None], handler) + cast("Callable[[Any, EventTypes], None]", handler), ) @contextmanager def scoped_handlers(self): - """ - Context manager for temporary event handling scope. + """Context manager for temporary event handling scope. Useful for testing or temporary event handling. Usage: diff --git a/src/crewai/utilities/events/event_listener.py b/src/crewai/utilities/events/event_listener.py index a76b87964..66809b54a 100644 --- a/src/crewai/utilities/events/event_listener.py +++ b/src/crewai/utilities/events/event_listener.py @@ -1,5 +1,5 @@ from io import StringIO -from typing import Any, Dict +from typing import Any from pydantic import Field, PrivateAttr @@ -62,7 +62,7 @@ class EventListener(BaseEventListener): _instance = None _telemetry: Telemetry = PrivateAttr(default_factory=lambda: Telemetry()) logger = Logger(verbose=True, default_color=EMITTER_COLOR) - execution_spans: Dict[Task, Any] = Field(default_factory=dict) + execution_spans: dict[Task, Any] = Field(default_factory=dict) next_chunk = 0 text_stream = StringIO() knowledge_retrieval_in_progress = False @@ -74,7 +74,7 @@ class EventListener(BaseEventListener): cls._instance._initialized = False return cls._instance - def __init__(self): + def __init__(self) -> None: if not hasattr(self, "_initialized") or not self._initialized: super().__init__() self._telemetry = Telemetry() @@ -85,14 +85,14 @@ class EventListener(BaseEventListener): # ----------- CREW EVENTS ----------- - def setup_listeners(self, crewai_event_bus): + def setup_listeners(self, crewai_event_bus) -> None: @crewai_event_bus.on(CrewKickoffStartedEvent) - def on_crew_started(source, event: CrewKickoffStartedEvent): + def on_crew_started(source, event: CrewKickoffStartedEvent) -> None: self.formatter.create_crew_tree(event.crew_name or "Crew", source.id) self._telemetry.crew_execution_span(source, event.inputs) @crewai_event_bus.on(CrewKickoffCompletedEvent) - def on_crew_completed(source, event: CrewKickoffCompletedEvent): + def on_crew_completed(source, event: CrewKickoffCompletedEvent) -> None: # Handle telemetry final_string_output = event.output.raw self._telemetry.end_crew(source, final_string_output) @@ -105,7 +105,7 @@ class EventListener(BaseEventListener): ) @crewai_event_bus.on(CrewKickoffFailedEvent) - def on_crew_failed(source, event: CrewKickoffFailedEvent): + def on_crew_failed(source, event: CrewKickoffFailedEvent) -> None: self.formatter.update_crew_tree( self.formatter.current_crew_tree, event.crew_name or "Crew", @@ -114,33 +114,33 @@ class EventListener(BaseEventListener): ) @crewai_event_bus.on(CrewTrainStartedEvent) - def on_crew_train_started(source, event: CrewTrainStartedEvent): + def on_crew_train_started(source, event: CrewTrainStartedEvent) -> None: self.formatter.handle_crew_train_started( - event.crew_name or "Crew", str(event.timestamp) + event.crew_name or "Crew", str(event.timestamp), ) @crewai_event_bus.on(CrewTrainCompletedEvent) - def on_crew_train_completed(source, event: CrewTrainCompletedEvent): + def on_crew_train_completed(source, event: CrewTrainCompletedEvent) -> None: self.formatter.handle_crew_train_completed( - event.crew_name or "Crew", str(event.timestamp) + event.crew_name or "Crew", str(event.timestamp), ) @crewai_event_bus.on(CrewTrainFailedEvent) - def on_crew_train_failed(source, event: CrewTrainFailedEvent): + def on_crew_train_failed(source, event: CrewTrainFailedEvent) -> None: self.formatter.handle_crew_train_failed(event.crew_name or "Crew") # ----------- TASK EVENTS ----------- @crewai_event_bus.on(TaskStartedEvent) - def on_task_started(source, event: TaskStartedEvent): + def on_task_started(source, event: TaskStartedEvent) -> None: span = self._telemetry.task_started(crew=source.agent.crew, task=source) self.execution_spans[source] = span self.formatter.create_task_branch( - self.formatter.current_crew_tree, source.id + self.formatter.current_crew_tree, source.id, ) @crewai_event_bus.on(TaskCompletedEvent) - def on_task_completed(source, event: TaskCompletedEvent): + def on_task_completed(source, event: TaskCompletedEvent) -> None: # Handle telemetry span = self.execution_spans.get(source) if span: @@ -155,7 +155,7 @@ class EventListener(BaseEventListener): ) @crewai_event_bus.on(TaskFailedEvent) - def on_task_failed(source, event: TaskFailedEvent): + def on_task_failed(source, event: TaskFailedEvent) -> None: span = self.execution_spans.get(source) if span: if source.agent and source.agent.crew: @@ -172,7 +172,7 @@ class EventListener(BaseEventListener): # ----------- AGENT EVENTS ----------- @crewai_event_bus.on(AgentExecutionStartedEvent) - def on_agent_execution_started(source, event: AgentExecutionStartedEvent): + def on_agent_execution_started(source, event: AgentExecutionStartedEvent) -> None: self.formatter.create_agent_branch( self.formatter.current_task_branch, event.agent.role, @@ -180,7 +180,7 @@ class EventListener(BaseEventListener): ) @crewai_event_bus.on(AgentExecutionCompletedEvent) - def on_agent_execution_completed(source, event: AgentExecutionCompletedEvent): + def on_agent_execution_completed(source, event: AgentExecutionCompletedEvent) -> None: self.formatter.update_agent_status( self.formatter.current_agent_branch, event.agent.role, @@ -191,24 +191,24 @@ class EventListener(BaseEventListener): @crewai_event_bus.on(LiteAgentExecutionStartedEvent) def on_lite_agent_execution_started( - source, event: LiteAgentExecutionStartedEvent - ): + source, event: LiteAgentExecutionStartedEvent, + ) -> None: """Handle LiteAgent execution started event.""" self.formatter.handle_lite_agent_execution( - event.agent_info["role"], status="started", **event.agent_info + event.agent_info["role"], status="started", **event.agent_info, ) @crewai_event_bus.on(LiteAgentExecutionCompletedEvent) def on_lite_agent_execution_completed( - source, event: LiteAgentExecutionCompletedEvent - ): + source, event: LiteAgentExecutionCompletedEvent, + ) -> None: """Handle LiteAgent execution completed event.""" self.formatter.handle_lite_agent_execution( - event.agent_info["role"], status="completed", **event.agent_info + event.agent_info["role"], status="completed", **event.agent_info, ) @crewai_event_bus.on(LiteAgentExecutionErrorEvent) - def on_lite_agent_execution_error(source, event: LiteAgentExecutionErrorEvent): + def on_lite_agent_execution_error(source, event: LiteAgentExecutionErrorEvent) -> None: """Handle LiteAgent execution error event.""" self.formatter.handle_lite_agent_execution( event.agent_info["role"], @@ -220,25 +220,25 @@ class EventListener(BaseEventListener): # ----------- FLOW EVENTS ----------- @crewai_event_bus.on(FlowCreatedEvent) - def on_flow_created(source, event: FlowCreatedEvent): + def on_flow_created(source, event: FlowCreatedEvent) -> None: self._telemetry.flow_creation_span(event.flow_name) self.formatter.create_flow_tree(event.flow_name, str(source.flow_id)) @crewai_event_bus.on(FlowStartedEvent) - def on_flow_started(source, event: FlowStartedEvent): + def on_flow_started(source, event: FlowStartedEvent) -> None: self._telemetry.flow_execution_span( - event.flow_name, list(source._methods.keys()) + event.flow_name, list(source._methods.keys()), ) self.formatter.start_flow(event.flow_name, str(source.flow_id)) @crewai_event_bus.on(FlowFinishedEvent) - def on_flow_finished(source, event: FlowFinishedEvent): + def on_flow_finished(source, event: FlowFinishedEvent) -> None: self.formatter.update_flow_status( - self.formatter.current_flow_tree, event.flow_name, source.flow_id + self.formatter.current_flow_tree, event.flow_name, source.flow_id, ) @crewai_event_bus.on(MethodExecutionStartedEvent) - def on_method_execution_started(source, event: MethodExecutionStartedEvent): + def on_method_execution_started(source, event: MethodExecutionStartedEvent) -> None: self.formatter.update_method_status( self.formatter.current_method_branch, self.formatter.current_flow_tree, @@ -247,7 +247,7 @@ class EventListener(BaseEventListener): ) @crewai_event_bus.on(MethodExecutionFinishedEvent) - def on_method_execution_finished(source, event: MethodExecutionFinishedEvent): + def on_method_execution_finished(source, event: MethodExecutionFinishedEvent) -> None: self.formatter.update_method_status( self.formatter.current_method_branch, self.formatter.current_flow_tree, @@ -256,7 +256,7 @@ class EventListener(BaseEventListener): ) @crewai_event_bus.on(MethodExecutionFailedEvent) - def on_method_execution_failed(source, event: MethodExecutionFailedEvent): + def on_method_execution_failed(source, event: MethodExecutionFailedEvent) -> None: self.formatter.update_method_status( self.formatter.current_method_branch, self.formatter.current_flow_tree, @@ -267,7 +267,7 @@ class EventListener(BaseEventListener): # ----------- TOOL USAGE EVENTS ----------- @crewai_event_bus.on(ToolUsageStartedEvent) - def on_tool_usage_started(source, event: ToolUsageStartedEvent): + def on_tool_usage_started(source, event: ToolUsageStartedEvent) -> None: self.formatter.handle_tool_usage_started( self.formatter.current_agent_branch, event.tool_name, @@ -275,7 +275,7 @@ class EventListener(BaseEventListener): ) @crewai_event_bus.on(ToolUsageFinishedEvent) - def on_tool_usage_finished(source, event: ToolUsageFinishedEvent): + def on_tool_usage_finished(source, event: ToolUsageFinishedEvent) -> None: self.formatter.handle_tool_usage_finished( self.formatter.current_tool_branch, event.tool_name, @@ -283,7 +283,7 @@ class EventListener(BaseEventListener): ) @crewai_event_bus.on(ToolUsageErrorEvent) - def on_tool_usage_error(source, event: ToolUsageErrorEvent): + def on_tool_usage_error(source, event: ToolUsageErrorEvent) -> None: self.formatter.handle_tool_usage_error( self.formatter.current_tool_branch, event.tool_name, @@ -294,14 +294,14 @@ class EventListener(BaseEventListener): # ----------- LLM EVENTS ----------- @crewai_event_bus.on(LLMCallStartedEvent) - def on_llm_call_started(source, event: LLMCallStartedEvent): + def on_llm_call_started(source, event: LLMCallStartedEvent) -> None: self.formatter.handle_llm_call_started( self.formatter.current_agent_branch, self.formatter.current_crew_tree, ) @crewai_event_bus.on(LLMCallCompletedEvent) - def on_llm_call_completed(source, event: LLMCallCompletedEvent): + def on_llm_call_completed(source, event: LLMCallCompletedEvent) -> None: self.formatter.handle_llm_call_completed( self.formatter.current_tool_branch, self.formatter.current_agent_branch, @@ -309,7 +309,7 @@ class EventListener(BaseEventListener): ) @crewai_event_bus.on(LLMCallFailedEvent) - def on_llm_call_failed(source, event: LLMCallFailedEvent): + def on_llm_call_failed(source, event: LLMCallFailedEvent) -> None: self.formatter.handle_llm_call_failed( self.formatter.current_tool_branch, event.error, @@ -317,18 +317,17 @@ class EventListener(BaseEventListener): ) @crewai_event_bus.on(LLMStreamChunkEvent) - def on_llm_stream_chunk(source, event: LLMStreamChunkEvent): + def on_llm_stream_chunk(source, event: LLMStreamChunkEvent) -> None: self.text_stream.write(event.chunk) self.text_stream.seek(self.next_chunk) # Read from the in-memory stream - content = self.text_stream.read() - print(content, end="", flush=True) + self.text_stream.read() self.next_chunk = self.text_stream.tell() @crewai_event_bus.on(CrewTestStartedEvent) - def on_crew_test_started(source, event: CrewTestStartedEvent): + def on_crew_test_started(source, event: CrewTestStartedEvent) -> None: cloned_crew = source.copy() self._telemetry.test_execution_span( cloned_crew, @@ -338,24 +337,24 @@ class EventListener(BaseEventListener): ) self.formatter.handle_crew_test_started( - event.crew_name or "Crew", source.id, event.n_iterations + event.crew_name or "Crew", source.id, event.n_iterations, ) @crewai_event_bus.on(CrewTestCompletedEvent) - def on_crew_test_completed(source, event: CrewTestCompletedEvent): + def on_crew_test_completed(source, event: CrewTestCompletedEvent) -> None: self.formatter.handle_crew_test_completed( self.formatter.current_flow_tree, event.crew_name or "Crew", ) @crewai_event_bus.on(CrewTestFailedEvent) - def on_crew_test_failed(source, event: CrewTestFailedEvent): + def on_crew_test_failed(source, event: CrewTestFailedEvent) -> None: self.formatter.handle_crew_test_failed(event.crew_name or "Crew") @crewai_event_bus.on(KnowledgeRetrievalStartedEvent) def on_knowledge_retrieval_started( - source, event: KnowledgeRetrievalStartedEvent - ): + source, event: KnowledgeRetrievalStartedEvent, + ) -> None: if self.knowledge_retrieval_in_progress: return @@ -368,8 +367,8 @@ class EventListener(BaseEventListener): @crewai_event_bus.on(KnowledgeRetrievalCompletedEvent) def on_knowledge_retrieval_completed( - source, event: KnowledgeRetrievalCompletedEvent - ): + source, event: KnowledgeRetrievalCompletedEvent, + ) -> None: if not self.knowledge_retrieval_in_progress: return @@ -381,11 +380,11 @@ class EventListener(BaseEventListener): ) @crewai_event_bus.on(KnowledgeQueryStartedEvent) - def on_knowledge_query_started(source, event: KnowledgeQueryStartedEvent): + def on_knowledge_query_started(source, event: KnowledgeQueryStartedEvent) -> None: pass @crewai_event_bus.on(KnowledgeQueryFailedEvent) - def on_knowledge_query_failed(source, event: KnowledgeQueryFailedEvent): + def on_knowledge_query_failed(source, event: KnowledgeQueryFailedEvent) -> None: self.formatter.handle_knowledge_query_failed( self.formatter.current_agent_branch, event.error, @@ -393,13 +392,13 @@ class EventListener(BaseEventListener): ) @crewai_event_bus.on(KnowledgeQueryCompletedEvent) - def on_knowledge_query_completed(source, event: KnowledgeQueryCompletedEvent): + def on_knowledge_query_completed(source, event: KnowledgeQueryCompletedEvent) -> None: pass @crewai_event_bus.on(KnowledgeSearchQueryFailedEvent) def on_knowledge_search_query_failed( - source, event: KnowledgeSearchQueryFailedEvent - ): + source, event: KnowledgeSearchQueryFailedEvent, + ) -> None: self.formatter.handle_knowledge_search_query_failed( self.formatter.current_agent_branch, event.error, diff --git a/src/crewai/utilities/events/flow_events.py b/src/crewai/utilities/events/flow_events.py index 7f48215e9..ee93c84e9 100644 --- a/src/crewai/utilities/events/flow_events.py +++ b/src/crewai/utilities/events/flow_events.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, Optional, Union +from typing import Any from pydantic import BaseModel, ConfigDict @@ -6,49 +6,49 @@ from .base_events import BaseEvent class FlowEvent(BaseEvent): - """Base class for all flow events""" + """Base class for all flow events.""" type: str flow_name: str class FlowStartedEvent(FlowEvent): - """Event emitted when a flow starts execution""" + """Event emitted when a flow starts execution.""" flow_name: str - inputs: Optional[Dict[str, Any]] = None + inputs: dict[str, Any] | None = None type: str = "flow_started" class FlowCreatedEvent(FlowEvent): - """Event emitted when a flow is created""" + """Event emitted when a flow is created.""" flow_name: str type: str = "flow_created" class MethodExecutionStartedEvent(FlowEvent): - """Event emitted when a flow method starts execution""" + """Event emitted when a flow method starts execution.""" flow_name: str method_name: str - state: Union[Dict[str, Any], BaseModel] - params: Optional[Dict[str, Any]] = None + state: dict[str, Any] | BaseModel + params: dict[str, Any] | None = None type: str = "method_execution_started" class MethodExecutionFinishedEvent(FlowEvent): - """Event emitted when a flow method completes execution""" + """Event emitted when a flow method completes execution.""" flow_name: str method_name: str result: Any = None - state: Union[Dict[str, Any], BaseModel] + state: dict[str, Any] | BaseModel type: str = "method_execution_finished" class MethodExecutionFailedEvent(FlowEvent): - """Event emitted when a flow method fails execution""" + """Event emitted when a flow method fails execution.""" flow_name: str method_name: str @@ -59,15 +59,15 @@ class MethodExecutionFailedEvent(FlowEvent): class FlowFinishedEvent(FlowEvent): - """Event emitted when a flow completes execution""" + """Event emitted when a flow completes execution.""" flow_name: str - result: Optional[Any] = None + result: Any | None = None type: str = "flow_finished" class FlowPlotEvent(FlowEvent): - """Event emitted when a flow plot is created""" + """Event emitted when a flow plot is created.""" flow_name: str type: str = "flow_plot" diff --git a/src/crewai/utilities/events/knowledge_events.py b/src/crewai/utilities/events/knowledge_events.py index e512ca575..9e8c7b520 100644 --- a/src/crewai/utilities/events/knowledge_events.py +++ b/src/crewai/utilities/events/knowledge_events.py @@ -1,11 +1,8 @@ -from typing import TYPE_CHECKING, Any +from typing import Any from crewai.agents.agent_builder.base_agent import BaseAgent from crewai.utilities.events.base_events import BaseEvent -if TYPE_CHECKING: - from crewai.agents.agent_builder.base_agent import BaseAgent - class KnowledgeRetrievalStartedEvent(BaseEvent): """Event emitted when a knowledge retrieval is started.""" diff --git a/src/crewai/utilities/events/llm_events.py b/src/crewai/utilities/events/llm_events.py index ca8d0367a..ba2936855 100644 --- a/src/crewai/utilities/events/llm_events.py +++ b/src/crewai/utilities/events/llm_events.py @@ -1,5 +1,5 @@ from enum import Enum -from typing import Any, Dict, List, Optional, Union +from typing import Any from pydantic import BaseModel @@ -7,29 +7,30 @@ from crewai.utilities.events.base_events import BaseEvent class LLMCallType(Enum): - """Type of LLM call being made""" + """Type of LLM call being made.""" TOOL_CALL = "tool_call" LLM_CALL = "llm_call" class LLMCallStartedEvent(BaseEvent): - """Event emitted when a LLM call starts + """Event emitted when a LLM call starts. Attributes: messages: Content can be either a string or a list of dictionaries that support multimodal content (text, images, etc.) + """ type: str = "llm_call_started" - messages: Union[str, List[Dict[str, Any]]] - tools: Optional[List[dict]] = None - callbacks: Optional[List[Any]] = None - available_functions: Optional[Dict[str, Any]] = None + messages: str | list[dict[str, Any]] + tools: list[dict] | None = None + callbacks: list[Any] | None = None + available_functions: dict[str, Any] | None = None class LLMCallCompletedEvent(BaseEvent): - """Event emitted when a LLM call completes""" + """Event emitted when a LLM call completes.""" type: str = "llm_call_completed" response: Any @@ -37,7 +38,7 @@ class LLMCallCompletedEvent(BaseEvent): class LLMCallFailedEvent(BaseEvent): - """Event emitted when a LLM call fails""" + """Event emitted when a LLM call fails.""" error: str type: str = "llm_call_failed" @@ -45,19 +46,19 @@ class LLMCallFailedEvent(BaseEvent): class FunctionCall(BaseModel): arguments: str - name: Optional[str] = None + name: str | None = None class ToolCall(BaseModel): - id: Optional[str] = None + id: str | None = None function: FunctionCall - type: Optional[str] = None + type: str | None = None index: int class LLMStreamChunkEvent(BaseEvent): - """Event emitted when a streaming chunk is received""" + """Event emitted when a streaming chunk is received.""" type: str = "llm_stream_chunk" chunk: str - tool_call: Optional[ToolCall] = None + tool_call: ToolCall | None = None diff --git a/src/crewai/utilities/events/llm_guardrail_events.py b/src/crewai/utilities/events/llm_guardrail_events.py index a484c187a..164b0e8f6 100644 --- a/src/crewai/utilities/events/llm_guardrail_events.py +++ b/src/crewai/utilities/events/llm_guardrail_events.py @@ -1,21 +1,23 @@ -from typing import Any, Callable, Optional, Union +from collections.abc import Callable +from typing import Any from crewai.utilities.events.base_events import BaseEvent class LLMGuardrailStartedEvent(BaseEvent): - """Event emitted when a guardrail task starts + """Event emitted when a guardrail task starts. Attributes: guardrail: The guardrail callable or LLMGuardrail instance retry_count: The number of times the guardrail has been retried + """ type: str = "llm_guardrail_started" - guardrail: Union[str, Callable] + guardrail: str | Callable retry_count: int - def __init__(self, **data): + def __init__(self, **data) -> None: from inspect import getsource from crewai.tasks.llm_guardrail import LLMGuardrail @@ -29,10 +31,10 @@ class LLMGuardrailStartedEvent(BaseEvent): class LLMGuardrailCompletedEvent(BaseEvent): - """Event emitted when a guardrail task completes""" + """Event emitted when a guardrail task completes.""" type: str = "llm_guardrail_completed" success: bool result: Any - error: Optional[str] = None + error: str | None = None retry_count: int diff --git a/src/crewai/utilities/events/task_events.py b/src/crewai/utilities/events/task_events.py index 1bf5baf8c..1c5d377f3 100644 --- a/src/crewai/utilities/events/task_events.py +++ b/src/crewai/utilities/events/task_events.py @@ -1,17 +1,17 @@ -from typing import Any, Optional +from typing import Any from crewai.tasks.task_output import TaskOutput from crewai.utilities.events.base_events import BaseEvent class TaskStartedEvent(BaseEvent): - """Event emitted when a task starts""" + """Event emitted when a task starts.""" type: str = "task_started" - context: Optional[str] - task: Optional[Any] = None + context: str | None + task: Any | None = None - def __init__(self, **data): + def __init__(self, **data) -> None: super().__init__(**data) # Set fingerprint data from the task if hasattr(self.task, "fingerprint") and self.task.fingerprint: @@ -25,13 +25,13 @@ class TaskStartedEvent(BaseEvent): class TaskCompletedEvent(BaseEvent): - """Event emitted when a task completes""" + """Event emitted when a task completes.""" output: TaskOutput type: str = "task_completed" - task: Optional[Any] = None + task: Any | None = None - def __init__(self, **data): + def __init__(self, **data) -> None: super().__init__(**data) # Set fingerprint data from the task if hasattr(self.task, "fingerprint") and self.task.fingerprint: @@ -45,13 +45,13 @@ class TaskCompletedEvent(BaseEvent): class TaskFailedEvent(BaseEvent): - """Event emitted when a task fails""" + """Event emitted when a task fails.""" error: str type: str = "task_failed" - task: Optional[Any] = None + task: Any | None = None - def __init__(self, **data): + def __init__(self, **data) -> None: super().__init__(**data) # Set fingerprint data from the task if hasattr(self.task, "fingerprint") and self.task.fingerprint: @@ -65,13 +65,13 @@ class TaskFailedEvent(BaseEvent): class TaskEvaluationEvent(BaseEvent): - """Event emitted when a task evaluation is completed""" + """Event emitted when a task evaluation is completed.""" type: str = "task_evaluation" evaluation_type: str - task: Optional[Any] = None + task: Any | None = None - def __init__(self, **data): + def __init__(self, **data) -> None: super().__init__(**data) # Set fingerprint data from the task if hasattr(self.task, "fingerprint") and self.task.fingerprint: diff --git a/src/crewai/utilities/events/third_party/agentops_listener.py b/src/crewai/utilities/events/third_party/agentops_listener.py index 294a820ee..b0ce3e3a9 100644 --- a/src/crewai/utilities/events/third_party/agentops_listener.py +++ b/src/crewai/utilities/events/third_party/agentops_listener.py @@ -21,15 +21,15 @@ class AgentOpsListener(BaseEventListener): tool_event: Optional["agentops.ToolEvent"] = None session: Optional["agentops.Session"] = None - def __init__(self): + def __init__(self) -> None: super().__init__() - def setup_listeners(self, crewai_event_bus): + def setup_listeners(self, crewai_event_bus) -> None: if not AGENTOPS_INSTALLED: return @crewai_event_bus.on(CrewKickoffStartedEvent) - def on_crew_kickoff_started(source, event: CrewKickoffStartedEvent): + def on_crew_kickoff_started(source, event: CrewKickoffStartedEvent) -> None: self.session = agentops.init() for agent in source.agents: if self.session: @@ -39,7 +39,7 @@ class AgentOpsListener(BaseEventListener): ) @crewai_event_bus.on(CrewKickoffCompletedEvent) - def on_crew_kickoff_completed(source, event: CrewKickoffCompletedEvent): + def on_crew_kickoff_completed(source, event: CrewKickoffCompletedEvent) -> None: if self.session: self.session.end_session( end_state="Success", @@ -47,20 +47,20 @@ class AgentOpsListener(BaseEventListener): ) @crewai_event_bus.on(ToolUsageStartedEvent) - def on_tool_usage_started(source, event: ToolUsageStartedEvent): + def on_tool_usage_started(source, event: ToolUsageStartedEvent) -> None: self.tool_event = agentops.ToolEvent(name=event.tool_name) if self.session: self.session.record(self.tool_event) @crewai_event_bus.on(ToolUsageErrorEvent) - def on_tool_usage_error(source, event: ToolUsageErrorEvent): + def on_tool_usage_error(source, event: ToolUsageErrorEvent) -> None: agentops.ErrorEvent(exception=event.error, trigger_event=self.tool_event) @crewai_event_bus.on(TaskEvaluationEvent) - def on_task_evaluation(source, event: TaskEvaluationEvent): + def on_task_evaluation(source, event: TaskEvaluationEvent) -> None: if self.session: self.session.create_agent( - name="Task Evaluator", agent_id=str(source.original_agent.id) + name="Task Evaluator", agent_id=str(source.original_agent.id), ) diff --git a/src/crewai/utilities/events/tool_usage_events.py b/src/crewai/utilities/events/tool_usage_events.py index 8ab22f667..2cb41ad26 100644 --- a/src/crewai/utilities/events/tool_usage_events.py +++ b/src/crewai/utilities/events/tool_usage_events.py @@ -1,24 +1,25 @@ +from collections.abc import Callable from datetime import datetime -from typing import Any, Callable, Dict, Optional +from typing import Any from .base_events import BaseEvent class ToolUsageEvent(BaseEvent): - """Base event for tool usage tracking""" + """Base event for tool usage tracking.""" agent_key: str agent_role: str tool_name: str - tool_args: Dict[str, Any] | str + tool_args: dict[str, Any] | str tool_class: str run_attempts: int | None = None delegations: int | None = None - agent: Optional[Any] = None + agent: Any | None = None model_config = {"arbitrary_types_allowed": True} - def __init__(self, **data): + def __init__(self, **data) -> None: super().__init__(**data) # Set fingerprint data from the agent if self.agent and hasattr(self.agent, "fingerprint") and self.agent.fingerprint: @@ -32,13 +33,13 @@ class ToolUsageEvent(BaseEvent): class ToolUsageStartedEvent(ToolUsageEvent): - """Event emitted when a tool execution is started""" + """Event emitted when a tool execution is started.""" type: str = "tool_usage_started" class ToolUsageFinishedEvent(ToolUsageEvent): - """Event emitted when a tool execution is completed""" + """Event emitted when a tool execution is completed.""" started_at: datetime finished_at: datetime @@ -48,37 +49,37 @@ class ToolUsageFinishedEvent(ToolUsageEvent): class ToolUsageErrorEvent(ToolUsageEvent): - """Event emitted when a tool execution encounters an error""" + """Event emitted when a tool execution encounters an error.""" error: Any type: str = "tool_usage_error" class ToolValidateInputErrorEvent(ToolUsageEvent): - """Event emitted when a tool input validation encounters an error""" + """Event emitted when a tool input validation encounters an error.""" error: Any type: str = "tool_validate_input_error" class ToolSelectionErrorEvent(ToolUsageEvent): - """Event emitted when a tool selection encounters an error""" + """Event emitted when a tool selection encounters an error.""" error: Any type: str = "tool_selection_error" class ToolExecutionErrorEvent(BaseEvent): - """Event emitted when a tool execution encounters an error""" + """Event emitted when a tool execution encounters an error.""" error: Any type: str = "tool_execution_error" tool_name: str - tool_args: Dict[str, Any] + tool_args: dict[str, Any] tool_class: Callable - agent: Optional[Any] = None + agent: Any | None = None - def __init__(self, **data): + def __init__(self, **data) -> None: super().__init__(**data) # Set fingerprint data from the agent if self.agent and hasattr(self.agent, "fingerprint") and self.agent.fingerprint: diff --git a/src/crewai/utilities/events/utils/console_formatter.py b/src/crewai/utilities/events/utils/console_formatter.py index b9adc9fda..7f9abfcb2 100644 --- a/src/crewai/utilities/events/utils/console_formatter.py +++ b/src/crewai/utilities/events/utils/console_formatter.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, Optional +from typing import Any from rich.console import Console from rich.panel import Panel @@ -7,16 +7,16 @@ from rich.tree import Tree class ConsoleFormatter: - current_crew_tree: Optional[Tree] = None - current_task_branch: Optional[Tree] = None - current_agent_branch: Optional[Tree] = None - current_tool_branch: Optional[Tree] = None - current_flow_tree: Optional[Tree] = None - current_method_branch: Optional[Tree] = None - current_lite_agent_branch: Optional[Tree] = None - tool_usage_counts: Dict[str, int] = {} + current_crew_tree: Tree | None = None + current_task_branch: Tree | None = None + current_agent_branch: Tree | None = None + current_tool_branch: Tree | None = None + current_flow_tree: Tree | None = None + current_method_branch: Tree | None = None + current_lite_agent_branch: Tree | None = None + tool_usage_counts: dict[str, int] = {} - def __init__(self, verbose: bool = False): + def __init__(self, verbose: bool = False) -> None: self.console = Console(width=None) self.verbose = verbose @@ -30,7 +30,7 @@ class ConsoleFormatter: ) def create_status_content( - self, title: str, name: str, status_style: str = "blue", **fields + self, title: str, name: str, status_style: str = "blue", **fields, ) -> Text: """Create standardized status content with consistent formatting.""" content = Text() @@ -41,7 +41,7 @@ class ConsoleFormatter: for label, value in fields.items(): content.append(f"{label}: ", style="white") content.append( - f"{value}\n", style=fields.get(f"{label}_style", status_style) + f"{value}\n", style=fields.get(f"{label}_style", status_style), ) return content @@ -52,7 +52,7 @@ class ConsoleFormatter: prefix: str, name: str, style: str = "blue", - status: Optional[str] = None, + status: str | None = None, ) -> None: """Update tree label with consistent formatting.""" label = Text() @@ -72,21 +72,17 @@ class ConsoleFormatter: self.console.print(*args, **kwargs) def print_panel( - self, content: Text, title: str, style: str = "blue", is_flow: bool = False + self, content: Text, title: str, style: str = "blue", is_flow: bool = False, ) -> None: """Print a panel with consistent formatting if verbose is enabled.""" panel = self.create_panel(content, title, style) - if is_flow: + if is_flow or self.verbose: self.print(panel) self.print() - else: - if self.verbose: - self.print(panel) - self.print() def update_crew_tree( self, - tree: Optional[Tree], + tree: Tree | None, crew_name: str, source_id: str, status: str = "completed", @@ -124,13 +120,13 @@ class ConsoleFormatter: self.print_panel(content, title, style) - def create_crew_tree(self, crew_name: str, source_id: str) -> Optional[Tree]: + def create_crew_tree(self, crew_name: str, source_id: str) -> Tree | None: """Create and initialize a new crew tree with initial status.""" if not self.verbose: return None tree = Tree( - Text("🚀 Crew: ", style="cyan bold") + Text(crew_name, style="cyan") + Text("🚀 Crew: ", style="cyan bold") + Text(crew_name, style="cyan"), ) content = self.create_status_content( @@ -148,8 +144,8 @@ class ConsoleFormatter: return tree def create_task_branch( - self, crew_tree: Optional[Tree], task_id: str - ) -> Optional[Tree]: + self, crew_tree: Tree | None, task_id: str, + ) -> Tree | None: """Create and initialize a task branch.""" if not self.verbose: return None @@ -175,7 +171,7 @@ class ConsoleFormatter: def update_task_status( self, - crew_tree: Optional[Tree], + crew_tree: Tree | None, task_id: str, agent_role: str, status: str = "completed", @@ -208,20 +204,20 @@ class ConsoleFormatter: # Show status panel content = self.create_status_content( - f"Task {status.title()}", str(task_id), style, Agent=agent_role + f"Task {status.title()}", str(task_id), style, Agent=agent_role, ) self.print_panel(content, panel_title, style) def create_agent_branch( - self, task_branch: Optional[Tree], agent_role: str, crew_tree: Optional[Tree] - ) -> Optional[Tree]: + self, task_branch: Tree | None, agent_role: str, crew_tree: Tree | None, + ) -> Tree | None: """Create and initialize an agent branch.""" if not self.verbose or not task_branch or not crew_tree: return None agent_branch = task_branch.add("") self.update_tree_label( - agent_branch, "🤖 Agent:", agent_role, "green", "In Progress" + agent_branch, "🤖 Agent:", agent_role, "green", "In Progress", ) self.print(crew_tree) @@ -234,9 +230,9 @@ class ConsoleFormatter: def update_agent_status( self, - agent_branch: Optional[Tree], + agent_branch: Tree | None, agent_role: str, - crew_tree: Optional[Tree], + crew_tree: Tree | None, status: str = "completed", ) -> None: """Update agent status in the tree.""" @@ -254,10 +250,10 @@ class ConsoleFormatter: self.print(crew_tree) self.print() - def create_flow_tree(self, flow_name: str, flow_id: str) -> Optional[Tree]: + def create_flow_tree(self, flow_name: str, flow_id: str) -> Tree | None: """Create and initialize a flow tree.""" content = self.create_status_content( - "Starting Flow Execution", flow_name, "blue", ID=flow_id + "Starting Flow Execution", flow_name, "blue", ID=flow_id, ) self.print_panel(content, "Flow Execution", "blue", is_flow=True) @@ -274,7 +270,7 @@ class ConsoleFormatter: return flow_tree - def start_flow(self, flow_name: str, flow_id: str) -> Optional[Tree]: + def start_flow(self, flow_name: str, flow_id: str) -> Tree | None: """Initialize a flow execution tree.""" flow_tree = Tree("") flow_label = Text() @@ -294,7 +290,7 @@ class ConsoleFormatter: def update_flow_status( self, - flow_tree: Optional[Tree], + flow_tree: Tree | None, flow_name: str, flow_id: str, status: str = "completed", @@ -336,16 +332,16 @@ class ConsoleFormatter: ) self.print(flow_tree) self.print_panel( - content, "Flow Completion", "green" if status == "completed" else "red" + content, "Flow Completion", "green" if status == "completed" else "red", ) def update_method_status( self, - method_branch: Optional[Tree], - flow_tree: Optional[Tree], + method_branch: Tree | None, + flow_tree: Tree | None, method_name: str, status: str = "running", - ) -> Optional[Tree]: + ) -> Tree | None: """Update method status in the flow tree.""" if not flow_tree: return None @@ -377,7 +373,7 @@ class ConsoleFormatter: method_branch = flow_tree.add("") method_branch.label = Text(prefix, style=f"{style} bold") + Text( - f" {method_name}", style=style + f" {method_name}", style=style, ) self.print(flow_tree) @@ -386,10 +382,10 @@ class ConsoleFormatter: def handle_tool_usage_started( self, - agent_branch: Optional[Tree], + agent_branch: Tree | None, tool_name: str, - crew_tree: Optional[Tree], - ) -> Optional[Tree]: + crew_tree: Tree | None, + ) -> Tree | None: """Handle tool usage started event.""" if not self.verbose: return None @@ -427,9 +423,9 @@ class ConsoleFormatter: def handle_tool_usage_finished( self, - tool_branch: Optional[Tree], + tool_branch: Tree | None, tool_name: str, - crew_tree: Optional[Tree], + crew_tree: Tree | None, ) -> None: """Handle tool usage finished event.""" if not self.verbose or tool_branch is None: @@ -458,10 +454,10 @@ class ConsoleFormatter: def handle_tool_usage_error( self, - tool_branch: Optional[Tree], + tool_branch: Tree | None, tool_name: str, error: str, - crew_tree: Optional[Tree], + crew_tree: Tree | None, ) -> None: """Handle tool usage error event.""" if not self.verbose: @@ -483,15 +479,15 @@ class ConsoleFormatter: # Show error panel error_content = self.create_status_content( - "Tool Usage Failed", tool_name, "red", Error=error + "Tool Usage Failed", tool_name, "red", Error=error, ) self.print_panel(error_content, "Tool Error", "red") def handle_llm_call_started( self, - agent_branch: Optional[Tree], - crew_tree: Optional[Tree], - ) -> Optional[Tree]: + agent_branch: Tree | None, + crew_tree: Tree | None, + ) -> Tree | None: """Handle LLM call started event.""" if not self.verbose: return None @@ -515,9 +511,9 @@ class ConsoleFormatter: def handle_llm_call_completed( self, - tool_branch: Optional[Tree], - agent_branch: Optional[Tree], - crew_tree: Optional[Tree], + tool_branch: Tree | None, + agent_branch: Tree | None, + crew_tree: Tree | None, ) -> None: """Handle LLM call completed event.""" if not self.verbose or tool_branch is None: @@ -543,7 +539,7 @@ class ConsoleFormatter: pass def handle_llm_call_failed( - self, tool_branch: Optional[Tree], error: str, crew_tree: Optional[Tree] + self, tool_branch: Tree | None, error: str, crew_tree: Tree | None, ) -> None: """Handle LLM call failed event.""" if not self.verbose: @@ -568,8 +564,8 @@ class ConsoleFormatter: self.print_panel(error_content, "LLM Error", "red") def handle_crew_test_started( - self, crew_name: str, source_id: str, n_iterations: int - ) -> Optional[Tree]: + self, crew_name: str, source_id: str, n_iterations: int, + ) -> Tree | None: """Handle crew test started event.""" if not self.verbose: return None @@ -603,7 +599,7 @@ class ConsoleFormatter: return test_tree def handle_crew_test_completed( - self, flow_tree: Optional[Tree], crew_name: str + self, flow_tree: Tree | None, crew_name: str, ) -> None: """Handle crew test completed event.""" if not self.verbose: @@ -693,7 +689,7 @@ class ConsoleFormatter: self.print_panel(failure_content, "Test Failure", "red") self.print() - def create_lite_agent_branch(self, lite_agent_role: str) -> Optional[Tree]: + def create_lite_agent_branch(self, lite_agent_role: str) -> Tree | None: """Create and initialize a lite agent branch.""" if not self.verbose: return None @@ -715,10 +711,10 @@ class ConsoleFormatter: def update_lite_agent_status( self, - lite_agent_branch: Optional[Tree], + lite_agent_branch: Tree | None, lite_agent_role: str, status: str = "completed", - **fields: Dict[str, Any], + **fields: dict[str, Any], ) -> None: """Update lite agent status in the tree.""" if not self.verbose or lite_agent_branch is None: @@ -752,7 +748,7 @@ class ConsoleFormatter: # Show status panel if additional fields are provided if fields: content = self.create_status_content( - f"LiteAgent {status.title()}", lite_agent_role, style, **fields + f"LiteAgent {status.title()}", lite_agent_role, style, **fields, ) self.print_panel(content, title, style) @@ -761,7 +757,7 @@ class ConsoleFormatter: lite_agent_role: str, status: str = "started", error: Any = None, - **fields: Dict[str, Any], + **fields: dict[str, Any], ) -> None: """Handle lite agent execution events with consistent formatting.""" if not self.verbose: @@ -773,7 +769,7 @@ class ConsoleFormatter: if lite_agent_branch and fields: # Show initial status panel content = self.create_status_content( - "LiteAgent Session Started", lite_agent_role, "cyan", **fields + "LiteAgent Session Started", lite_agent_role, "cyan", **fields, ) self.print_panel(content, "LiteAgent Started", "cyan") else: @@ -781,14 +777,14 @@ class ConsoleFormatter: if error: fields["Error"] = error self.update_lite_agent_status( - self.current_lite_agent_branch, lite_agent_role, status, **fields + self.current_lite_agent_branch, lite_agent_role, status, **fields, ) def handle_knowledge_retrieval_started( self, - agent_branch: Optional[Tree], - crew_tree: Optional[Tree], - ) -> Optional[Tree]: + agent_branch: Tree | None, + crew_tree: Tree | None, + ) -> Tree | None: """Handle knowledge retrieval started event.""" if not self.verbose: return None @@ -805,7 +801,7 @@ class ConsoleFormatter: knowledge_branch = branch_to_use.add("") self.update_tree_label( - knowledge_branch, "🔍", "Knowledge Retrieval Started", "blue" + knowledge_branch, "🔍", "Knowledge Retrieval Started", "blue", ) self.print(tree_to_use) @@ -814,13 +810,13 @@ class ConsoleFormatter: def handle_knowledge_retrieval_completed( self, - agent_branch: Optional[Tree], - crew_tree: Optional[Tree], + agent_branch: Tree | None, + crew_tree: Tree | None, retrieved_knowledge: Any, ) -> None: """Handle knowledge retrieval completed event.""" if not self.verbose: - return None + return branch_to_use = self.current_lite_agent_branch or agent_branch tree_to_use = branch_to_use or crew_tree @@ -842,13 +838,13 @@ class ConsoleFormatter: ) self.print(knowledge_panel) self.print() - return None + return knowledge_branch_found = False for child in branch_to_use.children: if "Knowledge Retrieval Started" in str(child.label): self.update_tree_label( - child, "✅", "Knowledge Retrieval Completed", "green" + child, "✅", "Knowledge Retrieval Completed", "green", ) knowledge_branch_found = True break @@ -861,7 +857,7 @@ class ConsoleFormatter: and "Completed" not in str(child.label) ): self.update_tree_label( - child, "✅", "Knowledge Retrieval Completed", "green" + child, "✅", "Knowledge Retrieval Completed", "green", ) knowledge_branch_found = True break @@ -869,7 +865,7 @@ class ConsoleFormatter: if not knowledge_branch_found: knowledge_branch = branch_to_use.add("") self.update_tree_label( - knowledge_branch, "✅", "Knowledge Retrieval Completed", "green" + knowledge_branch, "✅", "Knowledge Retrieval Completed", "green", ) self.print(tree_to_use) @@ -891,22 +887,22 @@ class ConsoleFormatter: def handle_knowledge_query_started( self, - agent_branch: Optional[Tree], + agent_branch: Tree | None, task_prompt: str, - crew_tree: Optional[Tree], + crew_tree: Tree | None, ) -> None: """Handle knowledge query generated event.""" if not self.verbose: - return None + return branch_to_use = self.current_lite_agent_branch or agent_branch tree_to_use = branch_to_use or crew_tree if branch_to_use is None or tree_to_use is None: - return None + return query_branch = branch_to_use.add("") self.update_tree_label( - query_branch, "🔎", f"Query: {task_prompt[:50]}...", "yellow" + query_branch, "🔎", f"Query: {task_prompt[:50]}...", "yellow", ) self.print(tree_to_use) @@ -914,9 +910,9 @@ class ConsoleFormatter: def handle_knowledge_query_failed( self, - agent_branch: Optional[Tree], + agent_branch: Tree | None, error: str, - crew_tree: Optional[Tree], + crew_tree: Tree | None, ) -> None: """Handle knowledge query failed event.""" if not self.verbose: @@ -933,24 +929,24 @@ class ConsoleFormatter: # Show error panel error_content = self.create_status_content( - "Knowledge Query Failed", "Query Error", "red", Error=error + "Knowledge Query Failed", "Query Error", "red", Error=error, ) self.print_panel(error_content, "Knowledge Error", "red") def handle_knowledge_query_completed( self, - agent_branch: Optional[Tree], - crew_tree: Optional[Tree], + agent_branch: Tree | None, + crew_tree: Tree | None, ) -> None: """Handle knowledge query completed event.""" if not self.verbose: - return None + return branch_to_use = self.current_lite_agent_branch or agent_branch tree_to_use = branch_to_use or crew_tree if branch_to_use is None or tree_to_use is None: - return None + return query_branch = branch_to_use.add("") self.update_tree_label(query_branch, "✅", "Knowledge Query Completed", "green") @@ -960,9 +956,9 @@ class ConsoleFormatter: def handle_knowledge_search_query_failed( self, - agent_branch: Optional[Tree], + agent_branch: Tree | None, error: str, - crew_tree: Optional[Tree], + crew_tree: Tree | None, ) -> None: """Handle knowledge search query failed event.""" if not self.verbose: @@ -979,6 +975,6 @@ class ConsoleFormatter: # Show error panel error_content = self.create_status_content( - "Knowledge Search Failed", "Search Error", "red", Error=error + "Knowledge Search Failed", "Search Error", "red", Error=error, ) self.print_panel(error_content, "Search Error", "red") diff --git a/src/crewai/utilities/exceptions/context_window_exceeding_exception.py b/src/crewai/utilities/exceptions/context_window_exceeding_exception.py index 399cf5a00..c5c00bb01 100644 --- a/src/crewai/utilities/exceptions/context_window_exceeding_exception.py +++ b/src/crewai/utilities/exceptions/context_window_exceeding_exception.py @@ -10,7 +10,7 @@ class LLMContextLengthExceededException(Exception): "exceeds token limit", ] - def __init__(self, error_message: str): + def __init__(self, error_message: str) -> None: self.original_error_message = error_message super().__init__(self._get_error_message(error_message)) @@ -20,7 +20,7 @@ class LLMContextLengthExceededException(Exception): for phrase in self.CONTEXT_LIMIT_ERRORS ) - def _get_error_message(self, error_message: str): + def _get_error_message(self, error_message: str) -> str: return ( f"LLM context length exceeded. Original error: {error_message}\n" "Consider using a smaller input or implementing a text splitting strategy." diff --git a/src/crewai/utilities/file_handler.py b/src/crewai/utilities/file_handler.py index 85d9766c5..77d27b1ee 100644 --- a/src/crewai/utilities/file_handler.py +++ b/src/crewai/utilities/file_handler.py @@ -2,33 +2,34 @@ import json import os import pickle from datetime import datetime -from typing import Union class FileHandler: """Handler for file operations supporting both JSON and text-based logging. - + Args: file_path (Union[bool, str]): Path to the log file or boolean flag + """ - def __init__(self, file_path: Union[bool, str]): + def __init__(self, file_path: bool | str) -> None: self._initialize_path(file_path) - - def _initialize_path(self, file_path: Union[bool, str]): + + def _initialize_path(self, file_path: bool | str) -> None: if file_path is True: # File path is boolean True self._path = os.path.join(os.curdir, "logs.txt") - + elif isinstance(file_path, str): # File path is a string if file_path.endswith((".json", ".txt")): self._path = file_path # No modification if the file ends with .json or .txt else: self._path = file_path + ".txt" # Append .txt if the file doesn't end with .json or .txt - + else: - raise ValueError("file_path must be a string or boolean.") # Handle the case where file_path isn't valid - - def log(self, **kwargs): + msg = "file_path must be a string or boolean." + raise ValueError(msg) # Handle the case where file_path isn't valid + + def log(self, **kwargs) -> None: try: now = datetime.now().strftime("%Y-%m-%d %H:%M:%S") log_entry = {"timestamp": now, **kwargs} @@ -39,34 +40,36 @@ class FileHandler: # If the file is empty, start with a list; else, append to it try: # Try reading existing content to avoid overwriting - with open(self._path, "r", encoding="utf-8") as read_file: + with open(self._path, encoding="utf-8") as read_file: existing_data = json.load(read_file) existing_data.append(log_entry) except (json.JSONDecodeError, FileNotFoundError): # If no valid JSON or file doesn't exist, start with an empty list existing_data = [log_entry] - + with open(self._path, "w", encoding="utf-8") as write_file: json.dump(existing_data, write_file, indent=4) write_file.write("\n") - + else: # Append log in plain text format - message = f"{now}: " + ", ".join([f"{key}=\"{value}\"" for key, value in kwargs.items()]) + "\n" + message = f"{now}: " + ", ".join([f'{key}="{value}"' for key, value in kwargs.items()]) + "\n" with open(self._path, "a", encoding="utf-8") as file: file.write(message) except Exception as e: - raise ValueError(f"Failed to log message: {str(e)}") - + msg = f"Failed to log message: {e!s}" + raise ValueError(msg) + class PickleHandler: def __init__(self, file_name: str) -> None: - """ - Initialize the PickleHandler with the name of the file where data will be stored. + """Initialize the PickleHandler with the name of the file where data will be stored. The file will be saved in the current directory. - Parameters: + Parameters + ---------- - file_name (str): The name of the file for saving and loading data. + """ if not file_name.endswith(".pkl"): file_name += ".pkl" @@ -74,27 +77,26 @@ class PickleHandler: self.file_path = os.path.join(os.getcwd(), file_name) def initialize_file(self) -> None: - """ - Initialize the file with an empty dictionary and overwrite any existing data. - """ + """Initialize the file with an empty dictionary and overwrite any existing data.""" self.save({}) def save(self, data) -> None: - """ - Save the data to the specified file using pickle. + """Save the data to the specified file using pickle. - Parameters: + Parameters + ---------- - data (object): The data to be saved. + """ with open(self.file_path, "wb") as file: pickle.dump(data, file) def load(self) -> dict: - """ - Load the data from the specified file using pickle. + """Load the data from the specified file using pickle. Returns: - dict: The data loaded from the file. + """ if not os.path.exists(self.file_path) or os.path.getsize(self.file_path) == 0: return {} # Return an empty dictionary if the file does not exist or is empty diff --git a/src/crewai/utilities/formatter.py b/src/crewai/utilities/formatter.py index 19b2a74f9..b6c3c9910 100644 --- a/src/crewai/utilities/formatter.py +++ b/src/crewai/utilities/formatter.py @@ -1,21 +1,19 @@ -import re -from typing import TYPE_CHECKING, List +from typing import TYPE_CHECKING if TYPE_CHECKING: from crewai.task import Task from crewai.tasks.task_output import TaskOutput -def aggregate_raw_outputs_from_task_outputs(task_outputs: List["TaskOutput"]) -> str: +def aggregate_raw_outputs_from_task_outputs(task_outputs: list["TaskOutput"]) -> str: """Generate string context from the task outputs.""" dividers = "\n\n----------\n\n" # Join task outputs with dividers - context = dividers.join(output.raw for output in task_outputs) - return context + return dividers.join(output.raw for output in task_outputs) -def aggregate_raw_outputs_from_tasks(tasks: List["Task"]) -> str: +def aggregate_raw_outputs_from_tasks(tasks: list["Task"]) -> str: """Generate string context from the tasks.""" task_outputs = [task.output for task in tasks if task.output is not None] diff --git a/src/crewai/utilities/i18n.py b/src/crewai/utilities/i18n.py index f2540e455..1f91cd17f 100644 --- a/src/crewai/utilities/i18n.py +++ b/src/crewai/utilities/i18n.py @@ -1,6 +1,5 @@ import json import os -from typing import Dict, Optional, Union from pydantic import BaseModel, Field, PrivateAttr, model_validator @@ -8,8 +7,9 @@ from pydantic import BaseModel, Field, PrivateAttr, model_validator class I18N(BaseModel): """Handles loading and retrieving internationalized prompts.""" - _prompts: Dict[str, Dict[str, str]] = PrivateAttr() - prompt_file: Optional[str] = Field( + + _prompts: dict[str, dict[str, str]] = PrivateAttr() + prompt_file: str | None = Field( default=None, description="Path to the prompt_file file to load", ) @@ -19,18 +19,20 @@ class I18N(BaseModel): """Load prompts from a JSON file.""" try: if self.prompt_file: - with open(self.prompt_file, "r", encoding="utf-8") as f: + with open(self.prompt_file, encoding="utf-8") as f: self._prompts = json.load(f) else: dir_path = os.path.dirname(os.path.realpath(__file__)) prompts_path = os.path.join(dir_path, "../translations/en.json") - with open(prompts_path, "r", encoding="utf-8") as f: + with open(prompts_path, encoding="utf-8") as f: self._prompts = json.load(f) except FileNotFoundError: - raise Exception(f"Prompt file '{self.prompt_file}' not found.") + msg = f"Prompt file '{self.prompt_file}' not found." + raise Exception(msg) except json.JSONDecodeError: - raise Exception("Error decoding JSON from the prompts file.") + msg = "Error decoding JSON from the prompts file." + raise Exception(msg) if not self._prompts: self._prompts = {} @@ -43,11 +45,12 @@ class I18N(BaseModel): def errors(self, error: str) -> str: return self.retrieve("errors", error) - def tools(self, tool: str) -> Union[str, Dict[str, str]]: + def tools(self, tool: str) -> str | dict[str, str]: return self.retrieve("tools", tool) def retrieve(self, kind, key) -> str: try: return self._prompts[kind][key] except Exception as _: - raise Exception(f"Prompt for '{kind}':'{key}' not found.") + msg = f"Prompt for '{kind}':'{key}' not found." + raise Exception(msg) diff --git a/src/crewai/utilities/internal_instructor.py b/src/crewai/utilities/internal_instructor.py index e9401c778..7588e166b 100644 --- a/src/crewai/utilities/internal_instructor.py +++ b/src/crewai/utilities/internal_instructor.py @@ -1,5 +1,5 @@ import warnings -from typing import Any, Optional, Type +from typing import Any class InternalInstructor: @@ -8,10 +8,10 @@ class InternalInstructor: def __init__( self, content: str, - model: Type, - agent: Optional[Any] = None, - llm: Optional[str] = None, - ): + model: type, + agent: Any | None = None, + llm: str | None = None, + ) -> None: self.content = content self.agent = agent self.llm = llm @@ -19,7 +19,7 @@ class InternalInstructor: self._client = None self.set_instructor() - def set_instructor(self): + def set_instructor(self) -> None: """Set instructor.""" if self.agent and not self.llm: self.llm = self.agent.function_calling_llm or self.agent.llm @@ -37,7 +37,6 @@ class InternalInstructor: def to_pydantic(self): messages = [{"role": "user", "content": self.content}] - model = self._client.chat.completions.create( - model=self.llm.model, response_model=self.model, messages=messages + return self._client.chat.completions.create( + model=self.llm.model, response_model=self.model, messages=messages, ) - return model diff --git a/src/crewai/utilities/llm_utils.py b/src/crewai/utilities/llm_utils.py index 1eb0a4693..6bd4c7412 100644 --- a/src/crewai/utilities/llm_utils.py +++ b/src/crewai/utilities/llm_utils.py @@ -1,15 +1,14 @@ import os -from typing import Any, Dict, List, Optional, Union +from typing import Any from crewai.cli.constants import DEFAULT_LLM_MODEL, ENV_VARS, LITELLM_PARAMS from crewai.llm import LLM, BaseLLM def create_llm( - llm_value: Union[str, LLM, Any, None] = None, -) -> Optional[LLM | BaseLLM]: - """ - Creates or returns an LLM instance based on the given llm_value. + llm_value: str | LLM | Any | None = None, +) -> LLM | BaseLLM | None: + """Creates or returns an LLM instance based on the given llm_value. Args: llm_value (str | BaseLLM | Any | None): @@ -20,19 +19,17 @@ def create_llm( Returns: A BaseLLM instance if successful, or None if something fails. - """ + """ # 1) If llm_value is already a BaseLLM or LLM object, return it directly - if isinstance(llm_value, LLM) or isinstance(llm_value, BaseLLM): + if isinstance(llm_value, (LLM, BaseLLM)): return llm_value # 2) If llm_value is a string (model name) if isinstance(llm_value, str): try: - created_llm = LLM(model=llm_value) - return created_llm - except Exception as e: - print(f"Failed to instantiate LLM with model='{llm_value}': {e}") + return LLM(model=llm_value) + except Exception: return None # 3) If llm_value is None, parse environment variables or use default @@ -48,15 +45,15 @@ def create_llm( or getattr(llm_value, "deployment_name", None) or str(llm_value) ) - temperature: Optional[float] = getattr(llm_value, "temperature", None) - max_tokens: Optional[int] = getattr(llm_value, "max_tokens", None) - logprobs: Optional[int] = getattr(llm_value, "logprobs", None) - timeout: Optional[float] = getattr(llm_value, "timeout", None) - api_key: Optional[str] = getattr(llm_value, "api_key", None) - base_url: Optional[str] = getattr(llm_value, "base_url", None) - api_base: Optional[str] = getattr(llm_value, "api_base", None) + temperature: float | None = getattr(llm_value, "temperature", None) + max_tokens: int | None = getattr(llm_value, "max_tokens", None) + logprobs: int | None = getattr(llm_value, "logprobs", None) + timeout: float | None = getattr(llm_value, "timeout", None) + api_key: str | None = getattr(llm_value, "api_key", None) + base_url: str | None = getattr(llm_value, "base_url", None) + api_base: str | None = getattr(llm_value, "api_base", None) - created_llm = LLM( + return LLM( model=model, temperature=temperature, max_tokens=max_tokens, @@ -66,16 +63,12 @@ def create_llm( base_url=base_url, api_base=api_base, ) - return created_llm - except Exception as e: - print(f"Error instantiating LLM from unknown object type: {e}") + except Exception: return None -def _llm_via_environment_or_fallback() -> Optional[LLM]: - """ - Helper function: if llm_value is None, we load environment variables or fallback default model. - """ +def _llm_via_environment_or_fallback() -> LLM | None: + """Helper function: if llm_value is None, we load environment variables or fallback default model.""" model_name = ( os.environ.get("MODEL") or os.environ.get("MODEL_NAME") @@ -85,24 +78,24 @@ def _llm_via_environment_or_fallback() -> Optional[LLM]: # Initialize parameters with correct types model: str = model_name - temperature: Optional[float] = None - max_tokens: Optional[int] = None - max_completion_tokens: Optional[int] = None - logprobs: Optional[int] = None - timeout: Optional[float] = None - api_key: Optional[str] = None - base_url: Optional[str] = None - api_version: Optional[str] = None - presence_penalty: Optional[float] = None - frequency_penalty: Optional[float] = None - top_p: Optional[float] = None - n: Optional[int] = None - stop: Optional[Union[str, List[str]]] = None - logit_bias: Optional[Dict[int, float]] = None - response_format: Optional[Dict[str, Any]] = None - seed: Optional[int] = None - top_logprobs: Optional[int] = None - callbacks: List[Any] = [] + temperature: float | None = None + max_tokens: int | None = None + max_completion_tokens: int | None = None + logprobs: int | None = None + timeout: float | None = None + api_key: str | None = None + base_url: str | None = None + api_version: str | None = None + presence_penalty: float | None = None + frequency_penalty: float | None = None + top_p: float | None = None + n: int | None = None + stop: str | list[str] | None = None + logit_bias: dict[int, float] | None = None + response_format: dict[str, Any] | None = None + seed: int | None = None + top_logprobs: int | None = None + callbacks: list[Any] = [] # Optional base URL from env base_url = ( @@ -120,7 +113,7 @@ def _llm_via_environment_or_fallback() -> Optional[LLM]: base_url = api_base # Initialize llm_params dictionary - llm_params: Dict[str, Any] = { + llm_params: dict[str, Any] = { "model": model, "temperature": temperature, "max_tokens": max_tokens, @@ -167,27 +160,20 @@ def _llm_via_environment_or_fallback() -> Optional[LLM]: if key not in ["prompt", "key_name", "default"]: llm_params[key.lower()] = value else: - print( - f"Expected env_var to be a dictionary, but got {type(env_var)}" - ) + pass # Remove None values llm_params = {k: v for k, v in llm_params.items() if v is not None} # Try creating the LLM try: - new_llm = LLM(**llm_params) - return new_llm - except Exception as e: - print( - f"Error instantiating LLM from environment/fallback: {type(e).__name__}: {e}" - ) + return LLM(**llm_params) + except Exception: return None def _normalize_key_name(key_name: str) -> str: - """ - Maps environment variable names to recognized litellm parameter keys, + """Maps environment variable names to recognized litellm parameter keys, using patterns from LITELLM_PARAMS. """ for pattern in LITELLM_PARAMS: diff --git a/src/crewai/utilities/logger.py b/src/crewai/utilities/logger.py index 2f69e7abc..6e5c71e74 100644 --- a/src/crewai/utilities/logger.py +++ b/src/crewai/utilities/logger.py @@ -10,11 +10,11 @@ class Logger(BaseModel): _printer: Printer = PrivateAttr(default_factory=Printer) default_color: str = Field(default="bold_yellow") - def log(self, level, message, color=None): + def log(self, level, message, color=None) -> None: if color is None: color = self.default_color if self.verbose: timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S") self._printer.print( - f"\n[{timestamp}][{level.upper()}]: {message}", color=color + f"\n[{timestamp}][{level.upper()}]: {message}", color=color, ) diff --git a/src/crewai/utilities/parser.py b/src/crewai/utilities/parser.py index c19cc1133..7debfc6d1 100644 --- a/src/crewai/utilities/parser.py +++ b/src/crewai/utilities/parser.py @@ -4,28 +4,34 @@ import re class YamlParser: @staticmethod def parse(file): - """ - Parses a YAML file, modifies specific patterns, and checks for unsupported 'context' usage. + """Parses a YAML file, modifies specific patterns, and checks for unsupported 'context' usage. + Args: file (file object): The YAML file to parse. + Returns: str: The modified content of the YAML file. + Raises: ValueError: If 'context:' is used incorrectly. + """ content = file.read() # Replace single { and } with doubled ones, while leaving already doubled ones intact and the other special characters {# and {% modified_content = re.sub(r"(? str: Returns: str: Full path to the SQLite database file + """ app_name = get_project_directory_name() app_author = "CrewAI" @@ -25,7 +26,5 @@ def get_project_directory_name(): if project_directory_name: return project_directory_name - else: - cwd = Path.cwd() - project_directory_name = cwd.name - return project_directory_name \ No newline at end of file + cwd = Path.cwd() + return cwd.name diff --git a/src/crewai/utilities/planning_handler.py b/src/crewai/utilities/planning_handler.py index 1bd14a0c8..1a9958cb2 100644 --- a/src/crewai/utilities/planning_handler.py +++ b/src/crewai/utilities/planning_handler.py @@ -1,5 +1,5 @@ import logging -from typing import Any, List, Optional +from typing import Any from pydantic import BaseModel, Field @@ -11,6 +11,7 @@ logger = logging.getLogger(__name__) class PlanPerTask(BaseModel): """Represents a plan for a specific task.""" + task: str = Field(..., description="The task for which the plan is created") plan: str = Field( ..., @@ -20,7 +21,8 @@ class PlanPerTask(BaseModel): class PlannerTaskPydanticOutput(BaseModel): """Output format for task planning results.""" - list_of_plans_per_task: List[PlanPerTask] = Field( + + list_of_plans_per_task: list[PlanPerTask] = Field( ..., description="Step by step plan on how the agents can execute their tasks using the available tools with mastery", ) @@ -28,7 +30,8 @@ class PlannerTaskPydanticOutput(BaseModel): class CrewPlanner: """Plans and coordinates the execution of crew tasks.""" - def __init__(self, tasks: List[Task], planning_agent_llm: Optional[Any] = None): + + def __init__(self, tasks: list[Task], planning_agent_llm: Any | None = None) -> None: self.tasks = tasks if planning_agent_llm is None: @@ -48,7 +51,8 @@ class CrewPlanner: if isinstance(result.pydantic, PlannerTaskPydanticOutput): return result.pydantic - raise ValueError("Failed to get the Planning output") + msg = "Failed to get the Planning output" + raise ValueError(msg) def _create_planning_agent(self) -> Agent: """Creates the planning agent for the crew planning.""" @@ -74,15 +78,15 @@ class CrewPlanner: output_pydantic=PlannerTaskPydanticOutput, ) - def _get_agent_knowledge(self, task: Task) -> List[str]: - """ - Safely retrieve knowledge source content from the task's agent. + def _get_agent_knowledge(self, task: Task) -> list[str]: + """Safely retrieve knowledge source content from the task's agent. Args: task: The task containing an agent with potential knowledge sources Returns: List[str]: A list of knowledge source strings + """ try: if task.agent and task.agent.knowledge_sources: @@ -98,7 +102,7 @@ class CrewPlanner: knowledge_list = self._get_agent_knowledge(task) agent_tools = ( f"[{', '.join(str(tool) for tool in task.agent.tools)}]" if task.agent and task.agent.tools else '"agent has no tools"', - f',\n "agent_knowledge": "[\\"{knowledge_list[0]}\\"]"' if knowledge_list and str(knowledge_list) != "None" else "" + f',\n "agent_knowledge": "[\\"{knowledge_list[0]}\\"]"' if knowledge_list and str(knowledge_list) != "None" else "", ) task_summary = f""" Task Number {idx + 1} - {task.description} diff --git a/src/crewai/utilities/printer.py b/src/crewai/utilities/printer.py index 74ad9a30b..658dcaf3f 100644 --- a/src/crewai/utilities/printer.py +++ b/src/crewai/utilities/printer.py @@ -1,12 +1,11 @@ """Utility for colored console output.""" -from typing import Optional class Printer: """Handles colored console output formatting.""" - def print(self, content: str, color: Optional[str] = None): + def print(self, content: str, color: str | None = None) -> None: if color == "purple": self._print_purple(content) elif color == "red": @@ -32,40 +31,40 @@ class Printer: elif color == "green": self._print_green(content) else: - print(content) + pass - def _print_bold_purple(self, content): - print("\033[1m\033[95m {}\033[00m".format(content)) + def _print_bold_purple(self, content) -> None: + pass - def _print_bold_green(self, content): - print("\033[1m\033[92m {}\033[00m".format(content)) + def _print_bold_green(self, content) -> None: + pass - def _print_purple(self, content): - print("\033[95m {}\033[00m".format(content)) + def _print_purple(self, content) -> None: + pass - def _print_red(self, content): - print("\033[91m {}\033[00m".format(content)) + def _print_red(self, content) -> None: + pass - def _print_bold_blue(self, content): - print("\033[1m\033[94m {}\033[00m".format(content)) + def _print_bold_blue(self, content) -> None: + pass - def _print_yellow(self, content): - print("\033[93m {}\033[00m".format(content)) + def _print_yellow(self, content) -> None: + pass - def _print_bold_yellow(self, content): - print("\033[1m\033[93m {}\033[00m".format(content)) + def _print_bold_yellow(self, content) -> None: + pass - def _print_cyan(self, content): - print("\033[96m {}\033[00m".format(content)) + def _print_cyan(self, content) -> None: + pass - def _print_bold_cyan(self, content): - print("\033[1m\033[96m {}\033[00m".format(content)) + def _print_bold_cyan(self, content) -> None: + pass - def _print_magenta(self, content): - print("\033[35m {}\033[00m".format(content)) + def _print_magenta(self, content) -> None: + pass - def _print_bold_magenta(self, content): - print("\033[1m\033[35m {}\033[00m".format(content)) + def _print_bold_magenta(self, content) -> None: + pass - def _print_green(self, content): - print("\033[32m {}\033[00m".format(content)) + def _print_green(self, content) -> None: + pass diff --git a/src/crewai/utilities/prompts.py b/src/crewai/utilities/prompts.py index cd3577874..aca27f162 100644 --- a/src/crewai/utilities/prompts.py +++ b/src/crewai/utilities/prompts.py @@ -1,4 +1,4 @@ -from typing import Any, Optional +from typing import Any from pydantic import BaseModel, Field @@ -10,10 +10,10 @@ class Prompts(BaseModel): i18n: I18N = Field(default=I18N()) has_tools: bool = False - system_template: Optional[str] = None - prompt_template: Optional[str] = None - response_template: Optional[str] = None - use_system_prompt: Optional[bool] = False + system_template: str | None = None + prompt_template: str | None = None + response_template: str | None = None + use_system_prompt: bool | None = False agent: Any def task_execution(self) -> dict[str, str]: @@ -36,15 +36,14 @@ class Prompts(BaseModel): "user": self._build_prompt(["task"]), "prompt": self._build_prompt(slices), } - else: - return { - "prompt": self._build_prompt( - slices, - self.system_template, - self.prompt_template, - self.response_template, - ) - } + return { + "prompt": self._build_prompt( + slices, + self.system_template, + self.prompt_template, + self.response_template, + ), + } def _build_prompt( self, @@ -67,7 +66,7 @@ class Prompts(BaseModel): ] system = system_template.replace("{{ .System }}", "".join(prompt_parts)) prompt = prompt_template.replace( - "{{ .Prompt }}", "".join(self.i18n.slice("task")) + "{{ .Prompt }}", "".join(self.i18n.slice("task")), ) # Handle missing response_template if response_template: @@ -76,9 +75,8 @@ class Prompts(BaseModel): else: prompt = f"{system}\n{prompt}" - prompt = ( + return ( prompt.replace("{goal}", self.agent.goal) .replace("{role}", self.agent.role) .replace("{backstory}", self.agent.backstory) ) - return prompt diff --git a/src/crewai/utilities/pydantic_schema_parser.py b/src/crewai/utilities/pydantic_schema_parser.py index 2827d70aa..9e2dbd528 100644 --- a/src/crewai/utilities/pydantic_schema_parser.py +++ b/src/crewai/utilities/pydantic_schema_parser.py @@ -1,20 +1,19 @@ -from typing import Dict, List, Type, Union, get_args, get_origin +from typing import Union, get_args, get_origin from pydantic import BaseModel class PydanticSchemaParser(BaseModel): - model: Type[BaseModel] + model: type[BaseModel] def get_schema(self) -> str: - """ - Public method to get the schema of a Pydantic model. + """Public method to get the schema of a Pydantic model. :return: String representation of the model schema. """ return "{\n" + self._get_model_schema(self.model) + "\n}" - def _get_model_schema(self, model: Type[BaseModel], depth: int = 0) -> str: + def _get_model_schema(self, model: type[BaseModel], depth: int = 0) -> str: indent = " " * 4 * depth lines = [ f"{indent} {field_name}: {self._get_field_type(field, depth + 1)}" @@ -26,11 +25,11 @@ class PydanticSchemaParser(BaseModel): field_type = field.annotation origin = get_origin(field_type) - if origin in {list, List}: + if origin in {list, list}: list_item_type = get_args(field_type)[0] return self._format_list_type(list_item_type, depth) - if origin in {dict, Dict}: + if origin in {dict, dict}: key_type, value_type = get_args(field_type) return f"Dict[{key_type.__name__}, {value_type.__name__}]" @@ -58,29 +57,27 @@ class PydanticSchemaParser(BaseModel): non_none_args = [arg for arg in args if arg is not type(None)] if len(non_none_args) == 1: inner_type = self._get_field_type_for_annotation( - non_none_args[0], depth + non_none_args[0], depth, ) return f"Optional[{inner_type}]" - else: - # Union with None and multiple other types - inner_types = ", ".join( - self._get_field_type_for_annotation(arg, depth) - for arg in non_none_args - ) - return f"Optional[Union[{inner_types}]]" - else: - # General Union type + # Union with None and multiple other types inner_types = ", ".join( - self._get_field_type_for_annotation(arg, depth) for arg in args + self._get_field_type_for_annotation(arg, depth) + for arg in non_none_args ) - return f"Union[{inner_types}]" + return f"Optional[Union[{inner_types}]]" + # General Union type + inner_types = ", ".join( + self._get_field_type_for_annotation(arg, depth) for arg in args + ) + return f"Union[{inner_types}]" def _get_field_type_for_annotation(self, annotation, depth: int) -> str: origin = get_origin(annotation) - if origin in {list, List}: + if origin in {list, list}: list_item_type = get_args(annotation)[0] return self._format_list_type(list_item_type, depth) - if origin in {dict, Dict}: + if origin in {dict, dict}: key_type, value_type = get_args(annotation) return f"Dict[{key_type.__name__}, {value_type.__name__}]" if origin is Union: diff --git a/src/crewai/utilities/rpm_controller.py b/src/crewai/utilities/rpm_controller.py index ec59b8304..cdc5107eb 100644 --- a/src/crewai/utilities/rpm_controller.py +++ b/src/crewai/utilities/rpm_controller.py @@ -1,6 +1,5 @@ import threading import time -from typing import Optional from pydantic import BaseModel, Field, PrivateAttr, model_validator @@ -12,32 +11,31 @@ from crewai.utilities.logger import Logger class RPMController(BaseModel): """Manages requests per minute limiting.""" - max_rpm: Optional[int] = Field(default=None) + max_rpm: int | None = Field(default=None) logger: Logger = Field(default_factory=lambda: Logger(verbose=False)) _current_rpm: int = PrivateAttr(default=0) - _timer: Optional[threading.Timer] = PrivateAttr(default=None) - _lock: Optional[threading.Lock] = PrivateAttr(default=None) + _timer: threading.Timer | None = PrivateAttr(default=None) + _lock: threading.Lock | None = PrivateAttr(default=None) _shutdown_flag: bool = PrivateAttr(default=False) @model_validator(mode="after") def reset_counter(self): - if self.max_rpm is not None: - if not self._shutdown_flag: - self._lock = threading.Lock() - self._reset_request_count() + if self.max_rpm is not None and not self._shutdown_flag: + self._lock = threading.Lock() + self._reset_request_count() return self def check_or_wait(self): if self.max_rpm is None: return True - def _check_and_increment(): + def _check_and_increment() -> bool: if self.max_rpm is not None and self._current_rpm < self.max_rpm: self._current_rpm += 1 return True - elif self.max_rpm is not None: + if self.max_rpm is not None: self.logger.log( - "info", "Max RPM reached, waiting for next minute to start." + "info", "Max RPM reached, waiting for next minute to start.", ) self._wait_for_next_minute() self._current_rpm = 1 @@ -50,17 +48,17 @@ class RPMController(BaseModel): else: return _check_and_increment() - def stop_rpm_counter(self): + def stop_rpm_counter(self) -> None: if self._timer: self._timer.cancel() self._timer = None - def _wait_for_next_minute(self): + def _wait_for_next_minute(self) -> None: time.sleep(60) self._current_rpm = 0 - def _reset_request_count(self): - def _reset(): + def _reset_request_count(self) -> None: + def _reset() -> None: self._current_rpm = 0 if not self._shutdown_flag: self._timer = threading.Timer(60.0, self._reset_request_count) diff --git a/src/crewai/utilities/serialization.py b/src/crewai/utilities/serialization.py index c3c0c3d47..390f69476 100644 --- a/src/crewai/utilities/serialization.py +++ b/src/crewai/utilities/serialization.py @@ -1,13 +1,13 @@ import json import uuid from datetime import date, datetime -from typing import Any, Dict, List, Union +from typing import Any, Union from pydantic import BaseModel SerializablePrimitive = Union[str, int, float, bool, None] Serializable = Union[ - SerializablePrimitive, List["Serializable"], Dict[str, "Serializable"] + SerializablePrimitive, list["Serializable"], dict[str, "Serializable"], ] @@ -30,6 +30,7 @@ def to_serializable( Returns: Serializable: A JSON-compatible structure. + """ if _current_depth >= max_depth: return repr(obj) @@ -39,18 +40,18 @@ def to_serializable( if isinstance(obj, (str, int, float, bool, type(None))): return obj - elif isinstance(obj, uuid.UUID): + if isinstance(obj, uuid.UUID): return str(obj) - elif isinstance(obj, (date, datetime)): + if isinstance(obj, (date, datetime)): return obj.isoformat() - elif isinstance(obj, (list, tuple, set)): + if isinstance(obj, (list, tuple, set)): return [ to_serializable( - item, max_depth=max_depth, _current_depth=_current_depth + 1 + item, max_depth=max_depth, _current_depth=_current_depth + 1, ) for item in obj ] - elif isinstance(obj, dict): + if isinstance(obj, dict): return { _to_serializable_key(key): to_serializable( obj=value, @@ -61,20 +62,19 @@ def to_serializable( for key, value in obj.items() if key not in exclude } - elif isinstance(obj, BaseModel): + if isinstance(obj, BaseModel): return to_serializable( obj=obj.model_dump(exclude=exclude), max_depth=max_depth, _current_depth=_current_depth + 1, ) - else: - return repr(obj) + return repr(obj) def _to_serializable_key(key: Any) -> str: if isinstance(key, (str, int)): return str(key) - return f"key_{id(key)}_{repr(key)}" + return f"key_{id(key)}_{key!r}" def to_string(obj: Any) -> str | None: @@ -85,9 +85,9 @@ def to_string(obj: Any) -> str | None: Returns: str | None: A JSON-formatted string or `None` if empty. + """ serializable = to_serializable(obj) if serializable is None: return None - else: - return json.dumps(serializable) + return json.dumps(serializable) diff --git a/src/crewai/utilities/string_utils.py b/src/crewai/utilities/string_utils.py index 9a1857781..ee0b24900 100644 --- a/src/crewai/utilities/string_utils.py +++ b/src/crewai/utilities/string_utils.py @@ -1,10 +1,10 @@ import re -from typing import Any, Dict, List, Optional, Union +from typing import Any def interpolate_only( - input_string: Optional[str], - inputs: Dict[str, Union[str, int, float, Dict[str, Any], List[Any]]], + input_string: str | None, + inputs: dict[str, str | int | float | dict[str, Any] | list[Any]], ) -> str: """Interpolate placeholders (e.g., {key}) in a string while leaving JSON untouched. Only interpolates placeholders that follow the pattern {variable_name} where @@ -23,6 +23,7 @@ def interpolate_only( Raises: ValueError: If a value contains unsupported types or a template variable is missing + """ # Validation function for recursive type checking @@ -35,25 +36,30 @@ def interpolate_only( for item in value.values() if isinstance(value, dict) else value: validate_type(item) return - raise ValueError( + msg = ( f"Unsupported type {type(value).__name__} in inputs. " "Only str, int, float, bool, dict, and list are allowed." ) + raise ValueError( + msg, + ) # Validate all input values for key, value in inputs.items(): try: validate_type(value) except ValueError as e: - raise ValueError(f"Invalid value for key '{key}': {str(e)}") from e + msg = f"Invalid value for key '{key}': {e!s}" + raise ValueError(msg) from e if input_string is None or not input_string: return "" if "{" not in input_string and "}" not in input_string: return input_string if not inputs: + msg = "Inputs dictionary cannot be empty when interpolating variables" raise ValueError( - "Inputs dictionary cannot be empty when interpolating variables" + msg, ) # The regex pattern to find valid variable placeholders @@ -68,8 +74,9 @@ def interpolate_only( # Check if all variables exist in inputs missing_vars = [var for var in variables if var not in inputs] if missing_vars: + msg = f"Template variable '{missing_vars[0]}' not found in inputs dictionary" raise KeyError( - f"Template variable '{missing_vars[0]}' not found in inputs dictionary" + msg, ) # Replace each variable with its value diff --git a/src/crewai/utilities/task_output_storage_handler.py b/src/crewai/utilities/task_output_storage_handler.py index 80e749bee..6d470ed33 100644 --- a/src/crewai/utilities/task_output_storage_handler.py +++ b/src/crewai/utilities/task_output_storage_handler.py @@ -1,5 +1,5 @@ from datetime import datetime -from typing import Any, Dict, List, Optional +from typing import Any from pydantic import BaseModel, Field @@ -12,12 +12,13 @@ from crewai.task import Task class ExecutionLog(BaseModel): """Represents a log entry for task execution.""" + task_id: str - expected_output: Optional[str] = None - output: Dict[str, Any] + expected_output: str | None = None + output: dict[str, Any] timestamp: datetime = Field(default_factory=datetime.now) task_index: int - inputs: Dict[str, Any] = Field(default_factory=dict) + inputs: dict[str, Any] = Field(default_factory=dict) was_replayed: bool = False def __getitem__(self, key: str) -> Any: @@ -30,10 +31,11 @@ class TaskOutputStorageHandler: def __init__(self) -> None: self.storage = KickoffTaskOutputsSQLiteStorage() - def update(self, task_index: int, log: Dict[str, Any]): + def update(self, task_index: int, log: dict[str, Any]) -> None: saved_outputs = self.load() if saved_outputs is None: - raise ValueError("Logs cannot be None") + msg = "Logs cannot be None" + raise ValueError(msg) if log.get("was_replayed", False): replayed = { @@ -53,15 +55,17 @@ class TaskOutputStorageHandler: def add( self, task: Task, - output: Dict[str, Any], + output: dict[str, Any], task_index: int, - inputs: Dict[str, Any] = {}, + inputs: dict[str, Any] | None = None, was_replayed: bool = False, - ): + ) -> None: + if inputs is None: + inputs = {} self.storage.add(task, output, task_index, was_replayed, inputs) - def reset(self): + def reset(self) -> None: self.storage.delete_all() - def load(self) -> Optional[List[Dict[str, Any]]]: + def load(self) -> list[dict[str, Any]] | None: return self.storage.load() diff --git a/src/crewai/utilities/token_counter_callback.py b/src/crewai/utilities/token_counter_callback.py index 7037ad5c4..fdaf2bc1a 100644 --- a/src/crewai/utilities/token_counter_callback.py +++ b/src/crewai/utilities/token_counter_callback.py @@ -1,20 +1,22 @@ import warnings -from typing import Any, Dict, Optional +from typing import TYPE_CHECKING, Any from litellm.integrations.custom_logger import CustomLogger -from litellm.types.utils import Usage from crewai.agents.agent_builder.utilities.base_token_process import TokenProcess +if TYPE_CHECKING: + from litellm.types.utils import Usage + class TokenCalcHandler(CustomLogger): - def __init__(self, token_cost_process: Optional[TokenProcess]): + def __init__(self, token_cost_process: TokenProcess | None) -> None: self.token_cost_process = token_cost_process def log_success_event( self, - kwargs: Dict[str, Any], - response_obj: Dict[str, Any], + kwargs: dict[str, Any], + response_obj: dict[str, Any], start_time: float, end_time: float, ) -> None: @@ -31,7 +33,7 @@ class TokenCalcHandler(CustomLogger): self.token_cost_process.sum_prompt_tokens(usage.prompt_tokens) if hasattr(usage, "completion_tokens"): self.token_cost_process.sum_completion_tokens( - usage.completion_tokens + usage.completion_tokens, ) if ( hasattr(usage, "prompt_tokens_details") @@ -39,5 +41,5 @@ class TokenCalcHandler(CustomLogger): and usage.prompt_tokens_details.cached_tokens ): self.token_cost_process.sum_cached_prompt_tokens( - usage.prompt_tokens_details.cached_tokens + usage.prompt_tokens_details.cached_tokens, ) diff --git a/src/crewai/utilities/tool_utils.py b/src/crewai/utilities/tool_utils.py index eaf065477..231532c1a 100644 --- a/src/crewai/utilities/tool_utils.py +++ b/src/crewai/utilities/tool_utils.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List, Optional +from typing import Any from crewai.agents.parser import AgentAction from crewai.security import Fingerprint @@ -10,15 +10,15 @@ from crewai.utilities.i18n import I18N def execute_tool_and_check_finality( agent_action: AgentAction, - tools: List[CrewStructuredTool], + tools: list[CrewStructuredTool], i18n: I18N, - agent_key: Optional[str] = None, - agent_role: Optional[str] = None, - tools_handler: Optional[Any] = None, - task: Optional[Any] = None, - agent: Optional[Any] = None, - function_calling_llm: Optional[Any] = None, - fingerprint_context: Optional[Dict[str, str]] = None, + agent_key: str | None = None, + agent_role: str | None = None, + tools_handler: Any | None = None, + task: Any | None = None, + agent: Any | None = None, + function_calling_llm: Any | None = None, + fingerprint_context: dict[str, str] | None = None, ) -> ToolResult: """Execute a tool and check if the result should be treated as a final answer. @@ -35,22 +35,22 @@ def execute_tool_and_check_finality( Returns: ToolResult containing the execution result and whether it should be treated as a final answer + """ try: tool_name_to_tool_map = {tool.name: tool for tool in tools} if agent_key and agent_role and agent: fingerprint_context = fingerprint_context or {} - if agent: - if hasattr(agent, "set_fingerprint") and callable( - agent.set_fingerprint - ): - if isinstance(fingerprint_context, dict): - try: - fingerprint_obj = Fingerprint.from_dict(fingerprint_context) - agent.set_fingerprint(fingerprint_obj) - except Exception as e: - raise ValueError(f"Failed to set fingerprint: {e}") + if agent and hasattr(agent, "set_fingerprint") and callable( + agent.set_fingerprint, + ) and isinstance(fingerprint_context, dict): + try: + fingerprint_obj = Fingerprint.from_dict(fingerprint_context) + agent.set_fingerprint(fingerprint_obj) + except Exception as e: + msg = f"Failed to set fingerprint: {e}" + raise ValueError(msg) # Create tool usage instance tool_usage = ToolUsage( @@ -86,5 +86,5 @@ def execute_tool_and_check_finality( ) return ToolResult(tool_result, False) - except Exception as e: - raise e + except Exception: + raise diff --git a/src/crewai/utilities/training_handler.py b/src/crewai/utilities/training_handler.py index 2d34f3261..c85999493 100644 --- a/src/crewai/utilities/training_handler.py +++ b/src/crewai/utilities/training_handler.py @@ -5,23 +5,25 @@ from crewai.utilities.file_handler import PickleHandler class CrewTrainingHandler(PickleHandler): def save_trained_data(self, agent_id: str, trained_data: dict) -> None: - """ - Save the trained data for a specific agent. + """Save the trained data for a specific agent. - Parameters: + Parameters + ---------- - agent_id (str): The ID of the agent. - trained_data (dict): The trained data to be saved. + """ data = self.load() data[agent_id] = trained_data self.save(data) def append(self, train_iteration: int, agent_id: str, new_data) -> None: - """ - Append new data to the existing pickle file. + """Append new data to the existing pickle file. - Parameters: + Parameters + ---------- - new_data (object): The new data to be appended. + """ data = self.load()