commit ad1ea46bbb
parent 807dfe0558
Author: Devin AI
Date: 2025-05-12 13:30:50 +00:00

Apply automatic linting fixes to src directory

Co-Authored-By: Joe Moura <joao@crewai.com>

160 changed files with 3218 additions and 3197 deletions

View File

@@ -1,6 +1,7 @@
 import shutil
 import subprocess
-from typing import Any, Dict, List, Literal, Optional, Sequence, Type, Union
+from collections.abc import Sequence
+from typing import Any, Literal

 from pydantic import Field, InstanceOf, PrivateAttr, model_validator
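
The import rewrite above sets the pattern for the whole commit: built-in generics (dict, list) and X | None unions per PEP 585/604, with Sequence imported from collections.abc instead of typing. A minimal before/after sketch (illustrative names, not from this commit):

    from collections.abc import Sequence
    from typing import Any

    # Before: Optional[Dict[str, Any]] and Union[int, str] from typing.
    # After: built-in generics and the | union operator.
    def summarize(items: Sequence[str], limits: dict[str, Any] | None = None) -> int | str:
        """Count items, or explain when no limits apply."""
        if limits is None:
            return f"{len(items)} items, unlimited"
        return len(items)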
@@ -67,40 +68,41 @@ class Agent(BaseAgent):
         step_callback: Callback to be executed after each step of the agent execution.
         knowledge_sources: Knowledge sources for the agent.
         embedder: Embedder configuration for the agent.
+
     """

     _times_executed: int = PrivateAttr(default=0)
-    max_execution_time: Optional[int] = Field(
+    max_execution_time: int | None = Field(
         default=None,
         description="Maximum execution time for an agent to execute a task",
     )
     agent_ops_agent_name: str = None  # type: ignore # Incompatible types in assignment (expression has type "None", variable has type "str")
     agent_ops_agent_id: str = None  # type: ignore # Incompatible types in assignment (expression has type "None", variable has type "str")
-    step_callback: Optional[Any] = Field(
+    step_callback: Any | None = Field(
         default=None,
         description="Callback to be executed after each step of the agent execution.",
     )
-    use_system_prompt: Optional[bool] = Field(
+    use_system_prompt: bool | None = Field(
         default=True,
         description="Use system prompt for the agent.",
     )
-    llm: Union[str, InstanceOf[BaseLLM], Any] = Field(
-        description="Language model that will run the agent.", default=None
+    llm: str | InstanceOf[BaseLLM] | Any = Field(
+        description="Language model that will run the agent.", default=None,
     )
-    function_calling_llm: Optional[Union[str, InstanceOf[BaseLLM], Any]] = Field(
-        description="Language model that will run the agent.", default=None
+    function_calling_llm: str | InstanceOf[BaseLLM] | Any | None = Field(
+        description="Language model that will run the agent.", default=None,
     )
-    system_template: Optional[str] = Field(
-        default=None, description="System format for the agent."
+    system_template: str | None = Field(
+        default=None, description="System format for the agent.",
     )
-    prompt_template: Optional[str] = Field(
-        default=None, description="Prompt format for the agent."
+    prompt_template: str | None = Field(
+        default=None, description="Prompt format for the agent.",
    )
-    response_template: Optional[str] = Field(
-        default=None, description="Response format for the agent."
+    response_template: str | None = Field(
+        default=None, description="Response format for the agent.",
    )
-    allow_code_execution: Optional[bool] = Field(
-        default=False, description="Enable code execution for the agent."
+    allow_code_execution: bool | None = Field(
+        default=False, description="Enable code execution for the agent.",
    )
     respect_context_window: bool = Field(
         default=True,
@@ -118,19 +120,19 @@ class Agent(BaseAgent):
default="safe", default="safe",
description="Mode for code execution: 'safe' (using Docker) or 'unsafe' (direct execution).", description="Mode for code execution: 'safe' (using Docker) or 'unsafe' (direct execution).",
) )
embedder: Optional[Dict[str, Any]] = Field( embedder: dict[str, Any] | None = Field(
default=None, default=None,
description="Embedder configuration for the agent.", description="Embedder configuration for the agent.",
) )
agent_knowledge_context: Optional[str] = Field( agent_knowledge_context: str | None = Field(
default=None, default=None,
description="Knowledge context for the agent.", description="Knowledge context for the agent.",
) )
crew_knowledge_context: Optional[str] = Field( crew_knowledge_context: str | None = Field(
default=None, default=None,
description="Knowledge context for the crew.", description="Knowledge context for the crew.",
) )
knowledge_search_query: Optional[str] = Field( knowledge_search_query: str | None = Field(
default=None, default=None,
description="Knowledge search query for the agent dynamically generated by the agent.", description="Knowledge search query for the agent dynamically generated by the agent.",
) )
@@ -141,7 +143,7 @@ class Agent(BaseAgent):
         self.llm = create_llm(self.llm)
         if self.function_calling_llm and not isinstance(
-            self.function_calling_llm, BaseLLM
+            self.function_calling_llm, BaseLLM,
         ):
             self.function_calling_llm = create_llm(self.function_calling_llm)
@@ -153,12 +155,12 @@ class Agent(BaseAgent):
         return self

-    def _setup_agent_executor(self):
+    def _setup_agent_executor(self) -> None:
         if not self.cache_handler:
             self.cache_handler = CacheHandler()
         self.set_cache_handler(self.cache_handler)

-    def set_knowledge(self, crew_embedder: Optional[Dict[str, Any]] = None):
+    def set_knowledge(self, crew_embedder: dict[str, Any] | None = None) -> None:
         try:
             if self.embedder is None and crew_embedder:
                 self.embedder = crew_embedder
@@ -174,7 +176,8 @@ class Agent(BaseAgent):
                 storage=self.knowledge_storage or None,
             )
         except (TypeError, ValueError) as e:
-            raise ValueError(f"Invalid Knowledge Configuration: {str(e)}")
+            msg = f"Invalid Knowledge Configuration: {e!s}"
+            raise ValueError(msg)

     def _is_any_available_memory(self) -> bool:
         """Check if any memory is available."""
@@ -196,8 +199,8 @@ class Agent(BaseAgent):
     def execute_task(
         self,
         task: Task,
-        context: Optional[str] = None,
-        tools: Optional[List[BaseTool]] = None,
+        context: str | None = None,
+        tools: list[BaseTool] | None = None,
     ) -> str:
         """Execute a task with the agent.
@@ -213,6 +216,7 @@ class Agent(BaseAgent):
             TimeoutError: If execution exceeds the maximum execution time.
             ValueError: If the max execution time is not a positive integer.
             RuntimeError: If the agent execution fails for other reasons.
+
         """
         if self.tools_handler:
             self.tools_handler.last_used_tool = {}  # type: ignore # Incompatible types in assignment (expression has type "dict[Never, Never]", variable has type "ToolCalling")
@@ -228,18 +232,18 @@ class Agent(BaseAgent):
                 # schema = json.dumps(task.output_json, indent=2)
                 schema = generate_model_description(task.output_json)
                 task_prompt += "\n" + self.i18n.slice(
-                    "formatted_task_instructions"
+                    "formatted_task_instructions",
                 ).format(output_format=schema)
             elif task.output_pydantic:
                 schema = generate_model_description(task.output_pydantic)
                 task_prompt += "\n" + self.i18n.slice(
-                    "formatted_task_instructions"
+                    "formatted_task_instructions",
                 ).format(output_format=schema)

         if context:
             task_prompt = self.i18n.slice("task_with_context").format(
-                task=task_prompt, context=context
+                task=task_prompt, context=context,
             )

         if self._is_any_available_memory():
@@ -267,25 +271,25 @@ class Agent(BaseAgent):
             )

             try:
                 self.knowledge_search_query = self._get_knowledge_search_query(
-                    task_prompt
+                    task_prompt,
                 )

                 if self.knowledge_search_query:
                     agent_knowledge_snippets = self.knowledge.query(
-                        [self.knowledge_search_query], **knowledge_config
+                        [self.knowledge_search_query], **knowledge_config,
                     )
                     if agent_knowledge_snippets:
                         self.agent_knowledge_context = extract_knowledge_context(
-                            agent_knowledge_snippets
+                            agent_knowledge_snippets,
                         )
                         if self.agent_knowledge_context:
                             task_prompt += self.agent_knowledge_context

                     if self.crew:
                         knowledge_snippets = self.crew.query_knowledge(
-                            [self.knowledge_search_query], **knowledge_config
+                            [self.knowledge_search_query], **knowledge_config,
                         )
                         if knowledge_snippets:
                             self.crew_knowledge_context = extract_knowledge_context(
-                                knowledge_snippets
+                                knowledge_snippets,
                             )
                             if self.crew_knowledge_context:
                                 task_prompt += self.crew_knowledge_context
@@ -342,11 +346,12 @@ class Agent(BaseAgent):
                 not isinstance(self.max_execution_time, int)
                 or self.max_execution_time <= 0
             ):
-                raise ValueError(
-                    "Max Execution time must be a positive integer greater than zero"
+                msg = "Max Execution time must be a positive integer greater than zero"
+                raise ValueError(
+                    msg,
                 )
             result = self._execute_with_timeout(
-                task_prompt, task, self.max_execution_time
+                task_prompt, task, self.max_execution_time,
             )
         else:
             result = self._execute_without_timeout(task_prompt, task)
@@ -361,7 +366,7 @@ class Agent(BaseAgent):
                     error=str(e),
                 ),
             )
-            raise e
+            raise
         except Exception as e:
             if e.__class__.__module__.startswith("litellm"):
                 # Do not retry on litellm errors
@@ -373,7 +378,7 @@ class Agent(BaseAgent):
                         error=str(e),
                     ),
                 )
-                raise e
+                raise
             self._times_executed += 1
             if self._times_executed > self.max_retry_limit:
                 crewai_event_bus.emit(
@@ -384,7 +389,7 @@ class Agent(BaseAgent):
                         error=str(e),
                     ),
                 )
-                raise e
+                raise
             result = self.execute_task(task, context, tools)

         if self.max_rpm and self._rpm_controller:
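
All three raise e -> raise rewrites above rely on the same semantics: inside an except block, a bare raise re-raises the active exception unchanged, while raise e redundantly re-raises it and adds the re-raise site as an extra traceback frame. A self-contained illustration:

    def forward() -> None:
        try:
            raise RuntimeError("boom")
        except RuntimeError:
            # Bare raise keeps the original traceback intact;
            # `raise e` would append this line as an extra frame.
            raise

    try:
        forward()
    except RuntimeError as err:
        print(err)  # boom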
@@ -416,24 +421,27 @@ class Agent(BaseAgent):
         Raises:
             TimeoutError: If execution exceeds the timeout.
             RuntimeError: If execution fails for other reasons.
+
         """
         import concurrent.futures

         with concurrent.futures.ThreadPoolExecutor() as executor:
             future = executor.submit(
-                self._execute_without_timeout, task_prompt=task_prompt, task=task
+                self._execute_without_timeout, task_prompt=task_prompt, task=task,
             )
             try:
                 return future.result(timeout=timeout)
             except concurrent.futures.TimeoutError:
                 future.cancel()
+                msg = f"Task '{task.description}' execution timed out after {timeout} seconds. Consider increasing max_execution_time or optimizing the task."
                 raise TimeoutError(
-                    f"Task '{task.description}' execution timed out after {timeout} seconds. Consider increasing max_execution_time or optimizing the task."
+                    msg,
                 )
             except Exception as e:
                 future.cancel()
-                raise RuntimeError(f"Task execution failed: {str(e)}")
+                msg = f"Task execution failed: {e!s}"
+                raise RuntimeError(msg)

     def _execute_without_timeout(self, task_prompt: str, task: Task) -> str:
         """Execute a task without a timeout.
@@ -444,6 +452,7 @@ class Agent(BaseAgent):
         Returns:
             The output of the agent.
+
         """
         return self.agent_executor.invoke(
             {
@@ -451,18 +460,19 @@ class Agent(BaseAgent):
"tool_names": self.agent_executor.tools_names, "tool_names": self.agent_executor.tools_names,
"tools": self.agent_executor.tools_description, "tools": self.agent_executor.tools_description,
"ask_for_human_input": task.human_input, "ask_for_human_input": task.human_input,
} },
)["output"] )["output"]
def create_agent_executor( def create_agent_executor(
self, tools: Optional[List[BaseTool]] = None, task=None self, tools: list[BaseTool] | None = None, task=None,
) -> None: ) -> None:
"""Create an agent executor for the agent. """Create an agent executor for the agent.
Returns: Returns:
An instance of the CrewAgentExecutor class. An instance of the CrewAgentExecutor class.
""" """
raw_tools: List[BaseTool] = tools or self.tools or [] raw_tools: list[BaseTool] = tools or self.tools or []
parsed_tools = parse_tools(raw_tools) parsed_tools = parse_tools(raw_tools)
prompt = Prompts( prompt = Prompts(
@@ -479,7 +489,7 @@ class Agent(BaseAgent):
         if self.response_template:
             stop_words.append(
-                self.response_template.split("{{ .Response }}")[1].strip()
+                self.response_template.split("{{ .Response }}")[1].strip(),
             )

         self.agent_executor = CrewAgentExecutor(
@@ -504,10 +514,9 @@ class Agent(BaseAgent):
             callbacks=[TokenCalcHandler(self._token_process)],
         )

-    def get_delegation_tools(self, agents: List[BaseAgent]):
+    def get_delegation_tools(self, agents: list[BaseAgent]):
         agent_tools = AgentTools(agents=agents)
-        tools = agent_tools.tools()
-        return tools
+        return agent_tools.tools()

     def get_multimodal_tools(self) -> Sequence[BaseTool]:
         from crewai.tools.agent_tools.add_image_tool import AddImageTool
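
The get_delegation_tools change drops an assignment that was only consumed by the following return. The same shape in isolation:

    def tool_names(tools: list[str]) -> str:
        # Before:
        #     names = ", ".join(tools)
        #     return names
        return ", ".join(tools)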
@@ -523,7 +532,7 @@ class Agent(BaseAgent):
                 return [CodeInterpreterTool(unsafe_mode=unsafe_mode)]
             except ModuleNotFoundError:
                 self._logger.log(
-                    "info", "Coding tools not available. Install crewai_tools. "
+                    "info", "Coding tools not available. Install crewai_tools. ",
                 )

     def get_output_converter(self, llm, text, model, instructions):
@@ -555,7 +564,7 @@ class Agent(BaseAgent):
         )
         return task_prompt

-    def _render_text_description(self, tools: List[Any]) -> str:
+    def _render_text_description(self, tools: list[Any]) -> str:
         """Render the tool name and description in plain text.

         Output will be in the format of:
@@ -565,48 +574,48 @@ class Agent(BaseAgent):
         search: This tool is used for search
         calculator: This tool is used for math

         """
-        description = "\n".join(
+        return "\n".join(
             [
                 f"Tool name: {tool.name}\nTool description:\n{tool.description}"
                 for tool in tools
-            ]
+            ],
         )
-        return description

     def _validate_docker_installation(self) -> None:
         """Check if Docker is installed and running."""
         if not shutil.which("docker"):
+            msg = f"Docker is not installed. Please install Docker to use code execution with agent: {self.role}"
             raise RuntimeError(
-                f"Docker is not installed. Please install Docker to use code execution with agent: {self.role}"
+                msg,
             )

         try:
             subprocess.run(
                 ["docker", "info"],
                 check=True,
-                stdout=subprocess.PIPE,
-                stderr=subprocess.PIPE,
+                capture_output=True,
             )
         except subprocess.CalledProcessError:
+            msg = f"Docker is not running. Please start Docker to use code execution with agent: {self.role}"
             raise RuntimeError(
-                f"Docker is not running. Please start Docker to use code execution with agent: {self.role}"
+                msg,
             )

-    def __repr__(self):
+    def __repr__(self) -> str:
         return f"Agent(role={self.role}, goal={self.goal}, backstory={self.backstory})"

     @property
     def fingerprint(self) -> Fingerprint:
-        """
-        Get the agent's fingerprint.
+        """Get the agent's fingerprint.

         Returns:
             Fingerprint: The agent's fingerprint
+
         """
         return self.security_config.fingerprint

-    def set_fingerprint(self, fingerprint: Fingerprint):
+    def set_fingerprint(self, fingerprint: Fingerprint) -> None:
         self.security_config.fingerprint = fingerprint

     def _get_knowledge_search_query(self, task_prompt: str) -> str | None:
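
In the Docker check above, capture_output=True is a behavior-preserving shorthand for stdout=subprocess.PIPE, stderr=subprocess.PIPE, available since Python 3.7. A sketch of the same check outside the class (illustrative function name):

    import shutil
    import subprocess

    def docker_is_running() -> bool:
        """Return True if the docker CLI exists and the daemon answers `docker info`."""
        if not shutil.which("docker"):
            return False
        try:
            # capture_output=True swallows stdout/stderr exactly like the two PIPE kwargs.
            subprocess.run(["docker", "info"], check=True, capture_output=True)
        except subprocess.CalledProcessError:
            return False
        return True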
@@ -619,7 +628,7 @@ class Agent(BaseAgent):
             ),
         )
         query = self.i18n.slice("knowledge_search_query").format(
-            task_prompt=task_prompt
+            task_prompt=task_prompt,
         )
         rewriter_prompt = self.i18n.slice("knowledge_search_query_system_prompt")
         if not isinstance(self.llm, BaseLLM):
@@ -644,7 +653,7 @@ class Agent(BaseAgent):
"content": rewriter_prompt, "content": rewriter_prompt,
}, },
{"role": "user", "content": query}, {"role": "user", "content": query},
] ],
) )
crewai_event_bus.emit( crewai_event_bus.emit(
self, self,
@@ -666,11 +675,10 @@ class Agent(BaseAgent):
     def kickoff(
         self,
-        messages: Union[str, List[Dict[str, str]]],
-        response_format: Optional[Type[Any]] = None,
+        messages: str | list[dict[str, str]],
+        response_format: type[Any] | None = None,
     ) -> LiteAgentOutput:
-        """
-        Execute the agent with the given messages using a LiteAgent instance.
+        """Execute the agent with the given messages using a LiteAgent instance.

         This method is useful when you want to use the Agent configuration but
         with the simpler and more direct execution flow of LiteAgent.
@@ -683,6 +691,7 @@ class Agent(BaseAgent):
         Returns:
             LiteAgentOutput: The result of the agent execution.
+
         """
         lite_agent = LiteAgent(
             role=self.role,
@@ -703,11 +712,10 @@ class Agent(BaseAgent):
     async def kickoff_async(
         self,
-        messages: Union[str, List[Dict[str, str]]],
-        response_format: Optional[Type[Any]] = None,
+        messages: str | list[dict[str, str]],
+        response_format: type[Any] | None = None,
     ) -> LiteAgentOutput:
-        """
-        Execute the agent asynchronously with the given messages using a LiteAgent instance.
+        """Execute the agent asynchronously with the given messages using a LiteAgent instance.

         This is the async version of the kickoff method.
@@ -719,6 +727,7 @@ class Agent(BaseAgent):
         Returns:
             LiteAgentOutput: The result of the agent execution.
+
         """
         lite_agent = LiteAgent(
             role=self.role,

View File

@@ -1,5 +1,5 @@
 from abc import ABC, abstractmethod
-from typing import Any, Dict, List, Optional
+from typing import Any

 from pydantic import PrivateAttr
@@ -16,27 +16,27 @@ class BaseAgentAdapter(BaseAgent, ABC):
""" """
adapted_structured_output: bool = False adapted_structured_output: bool = False
_agent_config: Optional[Dict[str, Any]] = PrivateAttr(default=None) _agent_config: dict[str, Any] | None = PrivateAttr(default=None)
model_config = {"arbitrary_types_allowed": True} model_config = {"arbitrary_types_allowed": True}
def __init__(self, agent_config: Optional[Dict[str, Any]] = None, **kwargs: Any): def __init__(self, agent_config: dict[str, Any] | None = None, **kwargs: Any) -> None:
super().__init__(adapted_agent=True, **kwargs) super().__init__(adapted_agent=True, **kwargs)
self._agent_config = agent_config self._agent_config = agent_config
@abstractmethod @abstractmethod
def configure_tools(self, tools: Optional[List[BaseTool]] = None) -> None: def configure_tools(self, tools: list[BaseTool] | None = None) -> None:
"""Configure and adapt tools for the specific agent implementation. """Configure and adapt tools for the specific agent implementation.
Args: Args:
tools: Optional list of BaseTool instances to be configured tools: Optional list of BaseTool instances to be configured
""" """
pass
def configure_structured_output(self, structured_output: Any) -> None: def configure_structured_output(self, structured_output: Any) -> None:
"""Configure the structured output for the specific agent implementation. """Configure the structured output for the specific agent implementation.
Args: Args:
structured_output: The structured output to be configured structured_output: The structured output to be configured
""" """
pass
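
Removing pass after these docstrings is safe: a docstring is itself a statement, so it alone forms a valid function body. A minimal illustration:

    from abc import ABC, abstractmethod

    class Exporter(ABC):
        @abstractmethod
        def export(self, payload: dict) -> str:
            """Serialize the payload; the docstring alone is a valid body."""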

View File

@@ -8,7 +8,7 @@ class BaseConverterAdapter(ABC):
     converter adapters must implement for converting structured output.
     """

-    def __init__(self, agent_adapter):
+    def __init__(self, agent_adapter) -> None:
         self.agent_adapter = agent_adapter

     @abstractmethod
@@ -16,14 +16,11 @@ class BaseConverterAdapter(ABC):
"""Configure agents to return structured output. """Configure agents to return structured output.
Must support json and pydantic output. Must support json and pydantic output.
""" """
pass
@abstractmethod @abstractmethod
def enhance_system_prompt(self, base_prompt: str) -> str: def enhance_system_prompt(self, base_prompt: str) -> str:
"""Enhance the system prompt with structured output instructions.""" """Enhance the system prompt with structured output instructions."""
pass
@abstractmethod @abstractmethod
def post_process_result(self, result: str) -> str: def post_process_result(self, result: str) -> str:
"""Post-process the result to ensure it matches the expected format: string.""" """Post-process the result to ensure it matches the expected format: string."""
pass

View File

@@ -1,5 +1,5 @@
 from abc import ABC, abstractmethod
-from typing import Any, List, Optional
+from typing import Any

 from crewai.tools.base_tool import BaseTool
@@ -12,23 +12,23 @@ class BaseToolAdapter(ABC):
     different frameworks and platforms.
     """

-    original_tools: List[BaseTool]
-    converted_tools: List[Any]
+    original_tools: list[BaseTool]
+    converted_tools: list[Any]

-    def __init__(self, tools: Optional[List[BaseTool]] = None):
+    def __init__(self, tools: list[BaseTool] | None = None) -> None:
         self.original_tools = tools or []
         self.converted_tools = []

     @abstractmethod
-    def configure_tools(self, tools: List[BaseTool]) -> None:
+    def configure_tools(self, tools: list[BaseTool]) -> None:
         """Configure and convert tools for the specific implementation.

         Args:
             tools: List of BaseTool instances to be configured and converted
+
         """
-        pass

-    def tools(self) -> List[Any]:
+    def tools(self) -> list[Any]:
         """Return all converted tools."""
         return self.converted_tools

View File

@@ -1,4 +1,4 @@
-from typing import Any, AsyncIterable, Dict, List, Optional
+from typing import Any

 from pydantic import Field, PrivateAttr
@@ -52,16 +52,17 @@ class LangGraphAgentAdapter(BaseAgentAdapter):
         role: str,
         goal: str,
         backstory: str,
-        tools: Optional[List[BaseTool]] = None,
+        tools: list[BaseTool] | None = None,
         llm: Any = None,
         max_iterations: int = 10,
-        agent_config: Optional[Dict[str, Any]] = None,
+        agent_config: dict[str, Any] | None = None,
         **kwargs,
-    ):
+    ) -> None:
         """Initialize the LangGraph agent adapter."""
         if not LANGGRAPH_AVAILABLE:
+            msg = "LangGraph Agent Dependencies are not installed. Please install it using `uv add langchain-core langgraph`"
             raise ImportError(
-                "LangGraph Agent Dependencies are not installed. Please install it using `uv add langchain-core langgraph`"
+                msg,
             )
         super().__init__(
             role=role,
@@ -82,7 +83,7 @@ class LangGraphAgentAdapter(BaseAgentAdapter):
         try:
             self._memory = MemorySaver()

-            converted_tools: List[Any] = self._tool_adapter.tools()
+            converted_tools: list[Any] = self._tool_adapter.tools()
             if self._agent_config:
                 self._graph = create_react_agent(
                     model=self.llm,
@@ -101,11 +102,11 @@ class LangGraphAgentAdapter(BaseAgentAdapter):
         except ImportError as e:
             self._logger.log(
-                "error", f"Failed to import LangGraph dependencies: {str(e)}"
+                "error", f"Failed to import LangGraph dependencies: {e!s}",
             )
             raise
         except Exception as e:
-            self._logger.log("error", f"Error setting up LangGraph agent: {str(e)}")
+            self._logger.log("error", f"Error setting up LangGraph agent: {e!s}")
             raise

     def _build_system_prompt(self) -> str:
@@ -124,8 +125,8 @@ class LangGraphAgentAdapter(BaseAgentAdapter):
     def execute_task(
         self,
         task: Any,
-        context: Optional[str] = None,
-        tools: Optional[List[BaseTool]] = None,
+        context: str | None = None,
+        tools: list[BaseTool] | None = None,
     ) -> str:
         """Execute a task using the LangGraph workflow."""
         self.create_agent_executor(tools)
@@ -137,7 +138,7 @@ class LangGraphAgentAdapter(BaseAgentAdapter):
         if context:
             task_prompt = self.i18n.slice("task_with_context").format(
-                task=task_prompt, context=context
+                task=task_prompt, context=context,
             )

         crewai_event_bus.emit(
@@ -159,7 +160,7 @@ class LangGraphAgentAdapter(BaseAgentAdapter):
"messages": [ "messages": [
("system", self._build_system_prompt()), ("system", self._build_system_prompt()),
("user", task_prompt), ("user", task_prompt),
] ],
}, },
config, config,
) )
@@ -180,14 +181,14 @@ class LangGraphAgentAdapter(BaseAgentAdapter):
             crewai_event_bus.emit(
                 self,
                 event=AgentExecutionCompletedEvent(
-                    agent=self, task=task, output=final_answer
+                    agent=self, task=task, output=final_answer,
                 ),
             )
             return final_answer
         except Exception as e:
-            self._logger.log("error", f"Error executing LangGraph task: {str(e)}")
+            self._logger.log("error", f"Error executing LangGraph task: {e!s}")
             crewai_event_bus.emit(
                 self,
                 event=AgentExecutionErrorEvent(
@@ -198,11 +199,11 @@ class LangGraphAgentAdapter(BaseAgentAdapter):
             )
             raise

-    def create_agent_executor(self, tools: Optional[List[BaseTool]] = None) -> None:
+    def create_agent_executor(self, tools: list[BaseTool] | None = None) -> None:
         """Configure the LangGraph agent for execution."""
         self.configure_tools(tools)

-    def configure_tools(self, tools: Optional[List[BaseTool]] = None) -> None:
+    def configure_tools(self, tools: list[BaseTool] | None = None) -> None:
         """Configure tools for the LangGraph agent."""
         if tools:
             all_tools = list(self.tools or []) + list(tools or [])
@@ -210,13 +211,13 @@ class LangGraphAgentAdapter(BaseAgentAdapter):
             available_tools = self._tool_adapter.tools()
             self._graph.tools = available_tools

-    def get_delegation_tools(self, agents: List[BaseAgent]) -> List[BaseTool]:
+    def get_delegation_tools(self, agents: list[BaseAgent]) -> list[BaseTool]:
         """Implement delegation tools support for LangGraph."""
         agent_tools = AgentTools(agents=agents)
         return agent_tools.tools()

     def get_output_converter(
-        self, llm: Any, text: str, model: Any, instructions: str
+        self, llm: Any, text: str, model: Any, instructions: str,
     ) -> Any:
         """Convert output format if needed."""
         return Converter(llm=llm, text=text, model=model, instructions=instructions)

View File

@@ -1,29 +1,25 @@
 import inspect
-from typing import Any, List, Optional
+from typing import Any

 from crewai.agents.agent_adapters.base_tool_adapter import BaseToolAdapter
 from crewai.tools.base_tool import BaseTool


 class LangGraphToolAdapter(BaseToolAdapter):
-    """Adapts CrewAI tools to LangGraph agent tool compatible format"""
+    """Adapts CrewAI tools to LangGraph agent tool compatible format."""

-    def __init__(self, tools: Optional[List[BaseTool]] = None):
+    def __init__(self, tools: list[BaseTool] | None = None) -> None:
         self.original_tools = tools or []
         self.converted_tools = []

-    def configure_tools(self, tools: List[BaseTool]) -> None:
-        """
-        Configure and convert CrewAI tools to LangGraph-compatible format.
+    def configure_tools(self, tools: list[BaseTool]) -> None:
+        """Configure and convert CrewAI tools to LangGraph-compatible format.

         LangGraph expects tools in langchain_core.tools format.
         """
         from langchain_core.tools import BaseTool, StructuredTool

         converted_tools = []
-        if self.original_tools:
-            all_tools = tools + self.original_tools
-        else:
-            all_tools = tools
+        all_tools = tools + self.original_tools if self.original_tools else tools

         for tool in all_tools:
             if isinstance(tool, BaseTool):
                 converted_tools.append(tool)
@@ -57,5 +53,5 @@ class LangGraphToolAdapter(BaseToolAdapter):
         self.converted_tools = converted_tools

-    def tools(self) -> List[Any]:
+    def tools(self) -> list[Any]:
         return self.converted_tools or []
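
The collapsed if/else in the hunk above is the standard ternary rewrite; the assignment survives because all_tools is used afterwards. The same transformation on a toy function:

    def merge_tools(new: list[str], existing: list[str]) -> list[str]:
        # Before:
        #     if existing:
        #         all_tools = new + existing
        #     else:
        #         all_tools = new
        all_tools = new + existing if existing else new
        return all_tools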

View File

@@ -5,10 +5,10 @@ from crewai.utilities.converter import generate_model_description
 class LangGraphConverterAdapter(BaseConverterAdapter):
-    """Adapter for handling structured output conversion in LangGraph agents"""
+    """Adapter for handling structured output conversion in LangGraph agents."""

-    def __init__(self, agent_adapter):
-        """Initialize the converter adapter with a reference to the agent adapter"""
+    def __init__(self, agent_adapter) -> None:
+        """Initialize the converter adapter with a reference to the agent adapter."""
         self.agent_adapter = agent_adapter
         self._output_format = None
         self._schema = None
@@ -32,7 +32,7 @@ class LangGraphConverterAdapter(BaseConverterAdapter):
         self._system_prompt_appendix = self._generate_system_prompt_appendix()

     def _generate_system_prompt_appendix(self) -> str:
-        """Generate an appendix for the system prompt to enforce structured output"""
+        """Generate an appendix for the system prompt to enforce structured output."""
         if not self._output_format or not self._schema:
             return ""
@@ -46,14 +46,14 @@ The output should be raw JSON that exactly matches the specified schema.
""" """
def enhance_system_prompt(self, original_prompt: str) -> str: def enhance_system_prompt(self, original_prompt: str) -> str:
"""Add structured output instructions to the system prompt if needed""" """Add structured output instructions to the system prompt if needed."""
if not self._system_prompt_appendix: if not self._system_prompt_appendix:
return original_prompt return original_prompt
return f"{original_prompt}\n{self._system_prompt_appendix}" return f"{original_prompt}\n{self._system_prompt_appendix}"
def post_process_result(self, result: str) -> str: def post_process_result(self, result: str) -> str:
"""Post-process the result to ensure it matches the expected format""" """Post-process the result to ensure it matches the expected format."""
if not self._output_format: if not self._output_format:
return result return result

View File

@@ -1,4 +1,4 @@
-from typing import Any, List, Optional
+from typing import Any

 from pydantic import Field, PrivateAttr
@@ -29,13 +29,13 @@ except ImportError:
 class OpenAIAgentAdapter(BaseAgentAdapter):
-    """Adapter for OpenAI Assistants"""
+    """Adapter for OpenAI Assistants."""

     model_config = {"arbitrary_types_allowed": True}

     _openai_agent: "OpenAIAgent" = PrivateAttr()
     _logger: Logger = PrivateAttr(default_factory=lambda: Logger())
-    _active_thread: Optional[str] = PrivateAttr(default=None)
+    _active_thread: str | None = PrivateAttr(default=None)
     function_calling_llm: Any = Field(default=None)
     step_callback: Any = Field(default=None)
     _tool_adapter: "OpenAIAgentToolAdapter" = PrivateAttr()
@@ -44,29 +44,29 @@ class OpenAIAgentAdapter(BaseAgentAdapter):
     def __init__(
         self,
         model: str = "gpt-4o-mini",
-        tools: Optional[List[BaseTool]] = None,
-        agent_config: Optional[dict] = None,
+        tools: list[BaseTool] | None = None,
+        agent_config: dict | None = None,
         **kwargs,
-    ):
+    ) -> None:
         if not OPENAI_AVAILABLE:
+            msg = "OpenAI Agent Dependencies are not installed. Please install it using `uv add openai-agents`"
             raise ImportError(
-                "OpenAI Agent Dependencies are not installed. Please install it using `uv add openai-agents`"
+                msg,
             )
-        else:
-            role = kwargs.pop("role", None)
-            goal = kwargs.pop("goal", None)
-            backstory = kwargs.pop("backstory", None)
-            super().__init__(
-                role=role,
-                goal=goal,
-                backstory=backstory,
-                tools=tools,
-                agent_config=agent_config,
-                **kwargs,
-            )
-            self._tool_adapter = OpenAIAgentToolAdapter(tools=tools)
-            self.llm = model
-            self._converter_adapter = OpenAIConverterAdapter(self)
+        role = kwargs.pop("role", None)
+        goal = kwargs.pop("goal", None)
+        backstory = kwargs.pop("backstory", None)
+        super().__init__(
+            role=role,
+            goal=goal,
+            backstory=backstory,
+            tools=tools,
+            agent_config=agent_config,
+            **kwargs,
+        )
+        self._tool_adapter = OpenAIAgentToolAdapter(tools=tools)
+        self.llm = model
+        self._converter_adapter = OpenAIConverterAdapter(self)

     def _build_system_prompt(self) -> str:
         """Build a system prompt for the OpenAI agent."""
@@ -84,10 +84,10 @@ class OpenAIAgentAdapter(BaseAgentAdapter):
     def execute_task(
         self,
         task: Any,
-        context: Optional[str] = None,
-        tools: Optional[List[BaseTool]] = None,
+        context: str | None = None,
+        tools: list[BaseTool] | None = None,
     ) -> str:
-        """Execute a task using the OpenAI Assistant"""
+        """Execute a task using the OpenAI Assistant."""
         self._converter_adapter.configure_structured_output(task)
         self.create_agent_executor(tools)
@@ -98,7 +98,7 @@ class OpenAIAgentAdapter(BaseAgentAdapter):
         task_prompt = task.prompt()

         if context:
             task_prompt = self.i18n.slice("task_with_context").format(
-                task=task_prompt, context=context
+                task=task_prompt, context=context,
             )

         crewai_event_bus.emit(
             self,
@@ -114,13 +114,13 @@ class OpenAIAgentAdapter(BaseAgentAdapter):
             crewai_event_bus.emit(
                 self,
                 event=AgentExecutionCompletedEvent(
-                    agent=self, task=task, output=final_answer
+                    agent=self, task=task, output=final_answer,
                 ),
             )
             return final_answer
         except Exception as e:
-            self._logger.log("error", f"Error executing OpenAI task: {str(e)}")
+            self._logger.log("error", f"Error executing OpenAI task: {e!s}")
             crewai_event_bus.emit(
                 self,
                 event=AgentExecutionErrorEvent(
@@ -131,9 +131,8 @@ class OpenAIAgentAdapter(BaseAgentAdapter):
             )
             raise

-    def create_agent_executor(self, tools: Optional[List[BaseTool]] = None) -> None:
-        """
-        Configure the OpenAI agent for execution.
+    def create_agent_executor(self, tools: list[BaseTool] | None = None) -> None:
+        """Configure the OpenAI agent for execution.

         While OpenAI handles execution differently through Runner,
         we can use this method to set up tools and configurations.
         """
@@ -152,27 +151,27 @@ class OpenAIAgentAdapter(BaseAgentAdapter):
         self.agent_executor = Runner

-    def configure_tools(self, tools: Optional[List[BaseTool]] = None) -> None:
-        """Configure tools for the OpenAI Assistant"""
+    def configure_tools(self, tools: list[BaseTool] | None = None) -> None:
+        """Configure tools for the OpenAI Assistant."""
         if tools:
             self._tool_adapter.configure_tools(tools)
             if self._tool_adapter.converted_tools:
                 self._openai_agent.tools = self._tool_adapter.converted_tools

     def handle_execution_result(self, result: Any) -> str:
-        """Process OpenAI Assistant execution result converting any structured output to a string"""
+        """Process OpenAI Assistant execution result converting any structured output to a string."""
         return self._converter_adapter.post_process_result(result.final_output)

-    def get_delegation_tools(self, agents: List[BaseAgent]) -> List[BaseTool]:
-        """Implement delegation tools support"""
+    def get_delegation_tools(self, agents: list[BaseAgent]) -> list[BaseTool]:
+        """Implement delegation tools support."""
         agent_tools = AgentTools(agents=agents)
-        tools = agent_tools.tools()
-        return tools
+        return agent_tools.tools()

     def configure_structured_output(self, task) -> None:
         """Configure the structured output for the specific agent implementation.

         Args:
             structured_output: The structured output to be configured
+
         """
         self._converter_adapter.configure_structured_output(task)

View File

@@ -1,5 +1,5 @@
 import inspect
-from typing import Any, List, Optional
+from typing import Any

 from agents import FunctionTool, Tool
@@ -8,42 +8,36 @@ from crewai.tools import BaseTool
 class OpenAIAgentToolAdapter(BaseToolAdapter):
-    """Adapter for OpenAI Assistant tools"""
+    """Adapter for OpenAI Assistant tools."""

-    def __init__(self, tools: Optional[List[BaseTool]] = None):
+    def __init__(self, tools: list[BaseTool] | None = None) -> None:
         self.original_tools = tools or []

-    def configure_tools(self, tools: List[BaseTool]) -> None:
-        """Configure tools for the OpenAI Assistant"""
-        if self.original_tools:
-            all_tools = tools + self.original_tools
-        else:
-            all_tools = tools
+    def configure_tools(self, tools: list[BaseTool]) -> None:
+        """Configure tools for the OpenAI Assistant."""
+        all_tools = tools + self.original_tools if self.original_tools else tools
         if all_tools:
             self.converted_tools = self._convert_tools_to_openai_format(all_tools)

     def _convert_tools_to_openai_format(
-        self, tools: Optional[List[BaseTool]]
-    ) -> List[Tool]:
-        """Convert CrewAI tools to OpenAI Assistant tool format"""
+        self, tools: list[BaseTool] | None,
+    ) -> list[Tool]:
+        """Convert CrewAI tools to OpenAI Assistant tool format."""
         if not tools:
             return []

         def sanitize_tool_name(name: str) -> str:
-            """Convert tool name to match OpenAI's required pattern"""
+            """Convert tool name to match OpenAI's required pattern."""
             import re

-            sanitized = re.sub(r"[^a-zA-Z0-9_-]", "_", name).lower()
-            return sanitized
+            return re.sub(r"[^a-zA-Z0-9_-]", "_", name).lower()

         def create_tool_wrapper(tool: BaseTool):
-            """Create a wrapper function that handles the OpenAI function tool interface"""
+            """Create a wrapper function that handles the OpenAI function tool interface."""

             async def wrapper(context_wrapper: Any, arguments: Any) -> Any:
                 # Get the parameter name from the schema
-                param_name = list(
-                    tool.args_schema.model_json_schema()["properties"].keys()
-                )[0]
+                param_name = next(iter(tool.args_schema.model_json_schema()["properties"].keys()))

                 # Handle different argument types
                 if isinstance(arguments, dict):
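
The param_name rewrite above avoids materializing every key just to take the first one: next(iter(...)) pulls a single element lazily, and dict order is insertion order on Python 3.7+. The pattern in isolation:

    schema = {"properties": {"query": {"type": "string"}, "limit": {"type": "integer"}}}

    # Before: param_name = list(schema["properties"].keys())[0]
    param_name = next(iter(schema["properties"].keys()))
    print(param_name)  # query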

View File

@@ -7,8 +7,7 @@ from crewai.utilities.i18n import I18N
 class OpenAIConverterAdapter(BaseConverterAdapter):
-    """
-    Adapter for handling structured output conversion in OpenAI agents.
+    """Adapter for handling structured output conversion in OpenAI agents.

     This adapter enhances the OpenAI agent to handle structured output formats
     and post-processes the results when needed.
@@ -17,21 +16,22 @@ class OpenAIConverterAdapter(BaseConverterAdapter):
         _output_format: The expected output format (json, pydantic, or None)
         _schema: The schema description for the expected output
         _output_model: The Pydantic model for the output
+
     """

-    def __init__(self, agent_adapter):
-        """Initialize the converter adapter with a reference to the agent adapter"""
+    def __init__(self, agent_adapter) -> None:
+        """Initialize the converter adapter with a reference to the agent adapter."""
         self.agent_adapter = agent_adapter
         self._output_format = None
         self._schema = None
         self._output_model = None

     def configure_structured_output(self, task) -> None:
-        """
-        Configure the structured output for OpenAI agent based on task requirements.
+        """Configure the structured output for OpenAI agent based on task requirements.

         Args:
             task: The task containing output format requirements
+
         """
         # Reset configuration
         self._output_format = None
@@ -55,14 +55,14 @@ class OpenAIConverterAdapter(BaseConverterAdapter):
             self._output_model = task.output_pydantic

     def enhance_system_prompt(self, base_prompt: str) -> str:
-        """
-        Enhance the base system prompt with structured output requirements if needed.
+        """Enhance the base system prompt with structured output requirements if needed.

         Args:
             base_prompt: The original system prompt

         Returns:
             Enhanced system prompt with output format instructions if needed
+
         """
         if not self._output_format:
             return base_prompt
@@ -76,8 +76,7 @@ class OpenAIConverterAdapter(BaseConverterAdapter):
return f"{base_prompt}\n\n{output_schema}" return f"{base_prompt}\n\n{output_schema}"
def post_process_result(self, result: str) -> str: def post_process_result(self, result: str) -> str:
""" """Post-process the result to ensure it matches the expected format.
Post-process the result to ensure it matches the expected format.
This method attempts to extract valid JSON from the result if necessary. This method attempts to extract valid JSON from the result if necessary.
@@ -86,6 +85,7 @@ class OpenAIConverterAdapter(BaseConverterAdapter):
         Returns:
             Processed result conforming to the expected output format
+
         """
         if not self._output_format:
             return result

View File

@@ -1,8 +1,9 @@
 import uuid
 from abc import ABC, abstractmethod
+from collections.abc import Callable
 from copy import copy as shallow_copy
 from hashlib import md5
-from typing import Any, Callable, Dict, List, Optional, TypeVar
+from typing import Any, TypeVar

 from pydantic import (
     UUID4,
@@ -14,6 +15,7 @@ from pydantic import (
     model_validator,
 )
 from pydantic_core import PydanticCustomError
+from typing_extensions import Self

 from crewai.agents.agent_builder.utilities.base_token_process import TokenProcess
 from crewai.agents.cache.cache_handler import CacheHandler
@@ -25,7 +27,6 @@ from crewai.security.security_config import SecurityConfig
from crewai.tools.base_tool import BaseTool, Tool from crewai.tools.base_tool import BaseTool, Tool
from crewai.utilities import I18N, Logger, RPMController from crewai.utilities import I18N, Logger, RPMController
from crewai.utilities.config import process_config from crewai.utilities.config import process_config
from crewai.utilities.converter import Converter
from crewai.utilities.string_utils import interpolate_only from crewai.utilities.string_utils import interpolate_only
T = TypeVar("T", bound="BaseAgent") T = TypeVar("T", bound="BaseAgent")
@@ -77,30 +78,31 @@ class BaseAgent(ABC, BaseModel):
         Set the rpm controller for the agent.
     set_private_attrs() -> "BaseAgent":
         Set private attributes.
+
     """

     __hash__ = object.__hash__  # type: ignore
     _logger: Logger = PrivateAttr(default_factory=lambda: Logger(verbose=False))
-    _rpm_controller: Optional[RPMController] = PrivateAttr(default=None)
+    _rpm_controller: RPMController | None = PrivateAttr(default=None)
     _request_within_rpm_limit: Any = PrivateAttr(default=None)
-    _original_role: Optional[str] = PrivateAttr(default=None)
-    _original_goal: Optional[str] = PrivateAttr(default=None)
-    _original_backstory: Optional[str] = PrivateAttr(default=None)
+    _original_role: str | None = PrivateAttr(default=None)
+    _original_goal: str | None = PrivateAttr(default=None)
+    _original_backstory: str | None = PrivateAttr(default=None)
     _token_process: TokenProcess = PrivateAttr(default_factory=TokenProcess)

     id: UUID4 = Field(default_factory=uuid.uuid4, frozen=True)
     role: str = Field(description="Role of the agent")
     goal: str = Field(description="Objective of the agent")
     backstory: str = Field(description="Backstory of the agent")
-    config: Optional[Dict[str, Any]] = Field(
-        description="Configuration for the agent", default=None, exclude=True
+    config: dict[str, Any] | None = Field(
+        description="Configuration for the agent", default=None, exclude=True,
     )
     cache: bool = Field(
-        default=True, description="Whether the agent should use a cache for tool usage."
+        default=True, description="Whether the agent should use a cache for tool usage.",
     )
     verbose: bool = Field(
-        default=False, description="Verbose mode for the Agent Execution"
+        default=False, description="Verbose mode for the Agent Execution",
     )
-    max_rpm: Optional[int] = Field(
+    max_rpm: int | None = Field(
         default=None,
         description="Maximum number of requests per minute for the agent execution to be respected.",
     )
@@ -108,41 +110,41 @@ class BaseAgent(ABC, BaseModel):
         default=False,
         description="Enable agent to delegate and ask questions among each other.",
     )
-    tools: Optional[List[BaseTool]] = Field(
-        default_factory=list, description="Tools at agents' disposal"
+    tools: list[BaseTool] | None = Field(
+        default_factory=list, description="Tools at agents' disposal",
     )
     max_iter: int = Field(
-        default=25, description="Maximum iterations for an agent to execute a task"
+        default=25, description="Maximum iterations for an agent to execute a task",
     )
     agent_executor: InstanceOf = Field(
-        default=None, description="An instance of the CrewAgentExecutor class."
+        default=None, description="An instance of the CrewAgentExecutor class.",
     )
     llm: Any = Field(
-        default=None, description="Language model that will run the agent."
+        default=None, description="Language model that will run the agent.",
     )
     crew: Any = Field(default=None, description="Crew to which the agent belongs.")
     i18n: I18N = Field(default=I18N(), description="Internationalization settings.")
-    cache_handler: Optional[InstanceOf[CacheHandler]] = Field(
-        default=None, description="An instance of the CacheHandler class."
+    cache_handler: InstanceOf[CacheHandler] | None = Field(
+        default=None, description="An instance of the CacheHandler class.",
     )
     tools_handler: InstanceOf[ToolsHandler] = Field(
         default_factory=ToolsHandler,
         description="An instance of the ToolsHandler class.",
     )
-    tools_results: List[Dict[str, Any]] = Field(
-        default=[], description="Results of the tools used by the agent."
+    tools_results: list[dict[str, Any]] = Field(
+        default=[], description="Results of the tools used by the agent.",
     )
-    max_tokens: Optional[int] = Field(
-        default=None, description="Maximum number of tokens for the agent's execution."
+    max_tokens: int | None = Field(
+        default=None, description="Maximum number of tokens for the agent's execution.",
     )
-    knowledge: Optional[Knowledge] = Field(
-        default=None, description="Knowledge for the agent."
+    knowledge: Knowledge | None = Field(
+        default=None, description="Knowledge for the agent.",
    )
-    knowledge_sources: Optional[List[BaseKnowledgeSource]] = Field(
+    knowledge_sources: list[BaseKnowledgeSource] | None = Field(
         default=None,
         description="Knowledge sources for the agent.",
     )
-    knowledge_storage: Optional[Any] = Field(
+    knowledge_storage: Any | None = Field(
         default=None,
         description="Custom knowledge storage for the agent.",
     )
@@ -150,13 +152,13 @@ class BaseAgent(ABC, BaseModel):
         default_factory=SecurityConfig,
         description="Security configuration for the agent, including fingerprinting.",
     )
-    callbacks: List[Callable] = Field(
-        default=[], description="Callbacks to be used for the agent"
+    callbacks: list[Callable] = Field(
+        default=[], description="Callbacks to be used for the agent",
     )
     adapted_agent: bool = Field(
-        default=False, description="Whether the agent is adapted"
+        default=False, description="Whether the agent is adapted",
     )
-    knowledge_config: Optional[KnowledgeConfig] = Field(
+    knowledge_config: KnowledgeConfig | None = Field(
         default=None,
         description="Knowledge configuration for the agent such as limits and threshold",
     )
@@ -168,7 +170,7 @@ class BaseAgent(ABC, BaseModel):
@field_validator("tools") @field_validator("tools")
@classmethod @classmethod
def validate_tools(cls, tools: List[Any]) -> List[BaseTool]: def validate_tools(cls, tools: list[Any]) -> list[BaseTool]:
"""Validate and process the tools provided to the agent. """Validate and process the tools provided to the agent.
This method ensures that each tool is either an instance of BaseTool This method ensures that each tool is either an instance of BaseTool
@@ -188,11 +190,14 @@ class BaseAgent(ABC, BaseModel):
# Tool has the required attributes, create a Tool instance # Tool has the required attributes, create a Tool instance
processed_tools.append(Tool.from_langchain(tool)) processed_tools.append(Tool.from_langchain(tool))
else: else:
raise ValueError( msg = (
f"Invalid tool type: {type(tool)}. " f"Invalid tool type: {type(tool)}. "
"Tool must be an instance of BaseTool or " "Tool must be an instance of BaseTool or "
"an object with 'name', 'func', and 'description' attributes." "an object with 'name', 'func', and 'description' attributes."
) )
raise ValueError(
msg,
)
return processed_tools return processed_tools
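The transformation above is the extract-message-then-raise pattern, presumably ruff's EM102 fix (f-string passed directly to an exception constructor). A minimal sketch of the before and after, with a hypothetical check_positive:

# Before: the f-string sits inside the raise, which the linter flags (EM102).
def check_positive(value: int) -> int:
    if value < 0:
        raise ValueError(f"expected a non-negative value, got {value}")
    return value

# After: the message is bound to `msg` first, so the raise line carries
# no string literal and the message is reusable for logging.
def check_positive_fixed(value: int) -> int:
    if value < 0:
        msg = f"expected a non-negative value, got {value}"
        raise ValueError(msg)
    return value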
@model_validator(mode="after") @model_validator(mode="after")
@@ -200,15 +205,16 @@ class BaseAgent(ABC, BaseModel):
# Validate required fields # Validate required fields
for field in ["role", "goal", "backstory"]: for field in ["role", "goal", "backstory"]:
if getattr(self, field) is None: if getattr(self, field) is None:
msg = f"{field} must be provided either directly or through config"
raise ValueError( raise ValueError(
f"{field} must be provided either directly or through config" msg,
) )
# Set private attributes # Set private attributes
self._logger = Logger(verbose=self.verbose) self._logger = Logger(verbose=self.verbose)
if self.max_rpm and not self._rpm_controller: if self.max_rpm and not self._rpm_controller:
self._rpm_controller = RPMController( self._rpm_controller = RPMController(
max_rpm=self.max_rpm, logger=self._logger max_rpm=self.max_rpm, logger=self._logger,
) )
if not self._token_process: if not self._token_process:
self._token_process = TokenProcess() self._token_process = TokenProcess()
@@ -221,10 +227,11 @@ class BaseAgent(ABC, BaseModel):
@field_validator("id", mode="before") @field_validator("id", mode="before")
@classmethod @classmethod
def _deny_user_set_id(cls, v: Optional[UUID4]) -> None: def _deny_user_set_id(cls, v: UUID4 | None) -> None:
if v: if v:
msg = "may_not_set_field"
raise PydanticCustomError( raise PydanticCustomError(
"may_not_set_field", "This field is not to be set by the user.", {} msg, "This field is not to be set by the user.", {},
) )
@model_validator(mode="after") @model_validator(mode="after")
@@ -233,7 +240,7 @@ class BaseAgent(ABC, BaseModel):
self._logger = Logger(verbose=self.verbose) self._logger = Logger(verbose=self.verbose)
if self.max_rpm and not self._rpm_controller: if self.max_rpm and not self._rpm_controller:
self._rpm_controller = RPMController( self._rpm_controller = RPMController(
max_rpm=self.max_rpm, logger=self._logger max_rpm=self.max_rpm, logger=self._logger,
) )
if not self._token_process: if not self._token_process:
self._token_process = TokenProcess() self._token_process = TokenProcess()
@@ -252,8 +259,8 @@ class BaseAgent(ABC, BaseModel):
def execute_task( def execute_task(
self, self,
task: Any, task: Any,
context: Optional[str] = None, context: str | None = None,
tools: Optional[List[BaseTool]] = None, tools: list[BaseTool] | None = None,
) -> str: ) -> str:
pass pass
@@ -262,11 +269,10 @@ class BaseAgent(ABC, BaseModel):
pass pass
@abstractmethod @abstractmethod
def get_delegation_tools(self, agents: List["BaseAgent"]) -> List[BaseTool]: def get_delegation_tools(self, agents: list["BaseAgent"]) -> list[BaseTool]:
"""Set the task tools that init BaseAgenTools class.""" """Set the task tools that init BaseAgenTools class."""
pass
def copy(self: T) -> T: # type: ignore # Signature of "copy" incompatible with supertype "BaseModel" def copy(self) -> Self: # type: ignore # Signature of "copy" incompatible with supertype "BaseModel"
"""Create a deep copy of the Agent.""" """Create a deep copy of the Agent."""
exclude = { exclude = {
"id", "id",
@@ -309,7 +315,7 @@ class BaseAgent(ABC, BaseModel):
copied_data = self.model_dump(exclude=exclude) copied_data = self.model_dump(exclude=exclude)
copied_data = {k: v for k, v in copied_data.items() if v is not None} copied_data = {k: v for k, v in copied_data.items() if v is not None}
copied_agent = type(self)( return type(self)(
**copied_data, **copied_data,
llm=existing_llm, llm=existing_llm,
tools=self.tools, tools=self.tools,
@@ -318,9 +324,8 @@ class BaseAgent(ABC, BaseModel):
knowledge_storage=copied_knowledge_storage, knowledge_storage=copied_knowledge_storage,
) )
return copied_agent
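copy(self: T) -> T becoming copy(self) -> Self looks like the PEP 673 upgrade: Self replaces the bound TypeVar and still resolves to the runtime subclass. A minimal sketch under that assumption (Node is hypothetical):

from typing import Self  # Python 3.11+; typing_extensions.Self on older versions

class Node:
    def __init__(self, value: int) -> None:
        self.value = value

    def clone(self) -> Self:
        # `Self` binds to the caller's class, so SubNode().clone() types as SubNode.
        return type(self)(self.value)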
def interpolate_inputs(self, inputs: Dict[str, Any]) -> None: def interpolate_inputs(self, inputs: dict[str, Any]) -> None:
"""Interpolate inputs into the agent description and backstory.""" """Interpolate inputs into the agent description and backstory."""
if self._original_role is None: if self._original_role is None:
self._original_role = self.role self._original_role = self.role
@@ -331,13 +336,13 @@ class BaseAgent(ABC, BaseModel):
if inputs: if inputs:
self.role = interpolate_only( self.role = interpolate_only(
input_string=self._original_role, inputs=inputs input_string=self._original_role, inputs=inputs,
) )
self.goal = interpolate_only( self.goal = interpolate_only(
input_string=self._original_goal, inputs=inputs input_string=self._original_goal, inputs=inputs,
) )
self.backstory = interpolate_only( self.backstory = interpolate_only(
input_string=self._original_backstory, inputs=inputs input_string=self._original_backstory, inputs=inputs,
) )
def set_cache_handler(self, cache_handler: CacheHandler) -> None: def set_cache_handler(self, cache_handler: CacheHandler) -> None:
@@ -345,6 +350,7 @@ class BaseAgent(ABC, BaseModel):
Args: Args:
cache_handler: An instance of the CacheHandler class. cache_handler: An instance of the CacheHandler class.
""" """
self.tools_handler = ToolsHandler() self.tools_handler = ToolsHandler()
if self.cache: if self.cache:
@@ -357,10 +363,11 @@ class BaseAgent(ABC, BaseModel):
Args: Args:
rpm_controller: An instance of the RPMController class. rpm_controller: An instance of the RPMController class.
""" """
if not self._rpm_controller: if not self._rpm_controller:
self._rpm_controller = rpm_controller self._rpm_controller = rpm_controller
self.create_agent_executor() self.create_agent_executor()
def set_knowledge(self, crew_embedder: Optional[Dict[str, Any]] = None): def set_knowledge(self, crew_embedder: dict[str, Any] | None = None) -> None:
pass pass
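The remaining annotation changes in this file are the same mechanical rewrites seen throughout the commit: Optional[X] to X | None and Dict/List to the built-in generics, presumably ruff's UP006/UP007 upgrades (PEP 585/604, Python 3.10+). A minimal sketch with a hypothetical lookup:

# Before (typing-module generics):
#     from typing import Dict, List, Optional
#     def lookup(table: Dict[str, List[int]], key: str) -> Optional[List[int]]: ...

# After: no typing imports needed for these forms on Python 3.10+.
def lookup(table: dict[str, list[int]], key: str) -> list[int] | None:
    return table.get(key)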
View File
@@ -1,3 +1,4 @@
import contextlib
import time import time
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
@@ -43,8 +44,7 @@ class CrewAgentExecutorMixin:
}, },
agent=self.agent.role, agent=self.agent.role,
) )
except Exception as e: except Exception:
print(f"Failed to add to short term memory: {e}")
pass pass
def _create_external_memory(self, output) -> None: def _create_external_memory(self, output) -> None:
@@ -56,7 +56,7 @@ class CrewAgentExecutorMixin:
and hasattr(self.crew, "_external_memory") and hasattr(self.crew, "_external_memory")
and self.crew._external_memory and self.crew._external_memory
): ):
try: with contextlib.suppress(Exception):
self.crew._external_memory.save( self.crew._external_memory.save(
value=output.text, value=output.text,
metadata={ metadata={
@@ -64,9 +64,6 @@ class CrewAgentExecutorMixin:
}, },
agent=self.agent.role, agent=self.agent.role,
) )
except Exception as e:
print(f"Failed to add to external memory: {e}")
pass
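The try/except/pass blocks around memory writes collapse into contextlib.suppress(Exception), which reads like ruff's SIM105 fix. Note the autofix also drops the print of the failure, so these saves now fail fully silently. A runnable sketch with a hypothetical save:

import contextlib

def save(value: str) -> None:  # hypothetical stand-in for the memory backend
    raise RuntimeError("backend unavailable")

# Before:
try:
    save("output")
except Exception as e:
    print(f"Failed to save: {e}")

# After: identical control flow, minus the diagnostic output.
with contextlib.suppress(Exception):
    save("output")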
def _create_long_term_memory(self, output) -> None: def _create_long_term_memory(self, output) -> None:
"""Create and save long-term and entity memory items based on evaluation.""" """Create and save long-term and entity memory items based on evaluation."""
@@ -103,15 +100,13 @@ class CrewAgentExecutorMixin:
type=entity.type, type=entity.type,
description=entity.description, description=entity.description,
relationships="\n".join( relationships="\n".join(
[f"- {r}" for r in entity.relationships] [f"- {r}" for r in entity.relationships],
), ),
) )
self.crew._entity_memory.save(entity_memory) self.crew._entity_memory.save(entity_memory)
except AttributeError as e: except AttributeError:
print(f"Missing attributes for long term memory: {e}")
pass pass
except Exception as e: except Exception:
print(f"Failed to add to long term memory: {e}")
pass pass
elif ( elif (
self.crew self.crew
@@ -126,7 +121,7 @@ class CrewAgentExecutorMixin:
def _ask_human_input(self, final_answer: str) -> str: def _ask_human_input(self, final_answer: str) -> str:
"""Prompt human input with mode-appropriate messaging.""" """Prompt human input with mode-appropriate messaging."""
self._printer.print( self._printer.print(
content=f"\033[1m\033[95m ## Final Result:\033[00m \033[92m{final_answer}\033[00m" content=f"\033[1m\033[95m ## Final Result:\033[00m \033[92m{final_answer}\033[00m",
) )
# Training mode prompt (single iteration) # Training mode prompt (single iteration)
View File
@@ -1,12 +1,11 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Any, Optional from typing import Any
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
class OutputConverter(BaseModel, ABC): class OutputConverter(BaseModel, ABC):
""" """Abstract base class for converting task results into structured formats.
Abstract base class for converting task results into structured formats.
This class provides a framework for converting unstructured text into This class provides a framework for converting unstructured text into
either Pydantic models or JSON, tailored for specific agent requirements. either Pydantic models or JSON, tailored for specific agent requirements.
@@ -19,6 +18,7 @@ class OutputConverter(BaseModel, ABC):
model (Any): The target model for structuring the output. model (Any): The target model for structuring the output.
instructions (str): Specific instructions for the conversion process. instructions (str): Specific instructions for the conversion process.
max_attempts (int): Maximum number of conversion attempts (default: 3). max_attempts (int): Maximum number of conversion attempts (default: 3).
""" """
text: str = Field(description="Text to be converted.") text: str = Field(description="Text to be converted.")
@@ -33,9 +33,7 @@ class OutputConverter(BaseModel, ABC):
@abstractmethod @abstractmethod
def to_pydantic(self, current_attempt=1) -> BaseModel: def to_pydantic(self, current_attempt=1) -> BaseModel:
"""Convert text to pydantic.""" """Convert text to pydantic."""
pass
@abstractmethod @abstractmethod
def to_json(self, current_attempt=1) -> dict: def to_json(self, current_attempt=1) -> dict:
"""Convert text to json.""" """Convert text to json."""
pass
View File
@@ -1,4 +1,4 @@
from typing import Any, Dict, Optional from typing import Any
from pydantic import BaseModel, PrivateAttr from pydantic import BaseModel, PrivateAttr
@@ -6,10 +6,10 @@ from pydantic import BaseModel, PrivateAttr
class CacheHandler(BaseModel): class CacheHandler(BaseModel):
"""Callback handler for tool usage.""" """Callback handler for tool usage."""
_cache: Dict[str, Any] = PrivateAttr(default_factory=dict) _cache: dict[str, Any] = PrivateAttr(default_factory=dict)
def add(self, tool, input, output): def add(self, tool, input, output) -> None:
self._cache[f"{tool}-{input}"] = output self._cache[f"{tool}-{input}"] = output
def read(self, tool, input) -> Optional[str]: def read(self, tool, input) -> str | None:
return self._cache.get(f"{tool}-{input}") return self._cache.get(f"{tool}-{input}")
View File
@@ -1,6 +1,5 @@
import json from collections.abc import Callable
import re from typing import TYPE_CHECKING, Any
from typing import Any, Callable, Dict, List, Optional, Union
from crewai.agents.agent_builder.base_agent import BaseAgent from crewai.agents.agent_builder.base_agent import BaseAgent
from crewai.agents.agent_builder.base_agent_executor_mixin import CrewAgentExecutorMixin from crewai.agents.agent_builder.base_agent_executor_mixin import CrewAgentExecutorMixin
@@ -10,8 +9,6 @@ from crewai.agents.parser import (
OutputParserException, OutputParserException,
) )
from crewai.agents.tools_handler import ToolsHandler from crewai.agents.tools_handler import ToolsHandler
from crewai.llm import BaseLLM
from crewai.tools.base_tool import BaseTool
from crewai.tools.structured_tool import CrewStructuredTool from crewai.tools.structured_tool import CrewStructuredTool
from crewai.tools.tool_types import ToolResult from crewai.tools.tool_types import ToolResult
from crewai.utilities import I18N, Printer from crewai.utilities import I18N, Printer
@@ -34,6 +31,10 @@ from crewai.utilities.logger import Logger
from crewai.utilities.tool_utils import execute_tool_and_check_finality from crewai.utilities.tool_utils import execute_tool_and_check_finality
from crewai.utilities.training_handler import CrewTrainingHandler from crewai.utilities.training_handler import CrewTrainingHandler
if TYPE_CHECKING:
from crewai.llm import BaseLLM
from crewai.tools.base_tool import BaseTool
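Moving BaseLLM and BaseTool under if TYPE_CHECKING: means they are imported only during static analysis, which removes a runtime import (and any cycle it created); annotations that use them must then be strings, or the module needs from __future__ import annotations. A minimal sketch with a stand-in import:

from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from decimal import Decimal  # stand-in for a heavy or cyclic import

def describe(price: Decimal) -> str:
    # With postponed evaluation, this annotation is never executed at
    # runtime, so the guarded import is sufficient.
    return f"costs {price}"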
class CrewAgentExecutor(CrewAgentExecutorMixin): class CrewAgentExecutor(CrewAgentExecutorMixin):
_logger: Logger = Logger() _logger: Logger = Logger()
@@ -46,18 +47,22 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
agent: BaseAgent, agent: BaseAgent,
prompt: dict[str, str], prompt: dict[str, str],
max_iter: int, max_iter: int,
tools: List[CrewStructuredTool], tools: list[CrewStructuredTool],
tools_names: str, tools_names: str,
stop_words: List[str], stop_words: list[str],
tools_description: str, tools_description: str,
tools_handler: ToolsHandler, tools_handler: ToolsHandler,
step_callback: Any = None, step_callback: Any = None,
original_tools: List[Any] = [], original_tools: list[Any] | None = None,
function_calling_llm: Any = None, function_calling_llm: Any = None,
respect_context_window: bool = False, respect_context_window: bool = False,
request_within_rpm_limit: Optional[Callable[[], bool]] = None, request_within_rpm_limit: Callable[[], bool] | None = None,
callbacks: List[Any] = [], callbacks: list[Any] | None = None,
): ) -> None:
if callbacks is None:
callbacks = []
if original_tools is None:
original_tools = []
self._i18n: I18N = I18N() self._i18n: I18N = I18N()
self.llm: BaseLLM = llm self.llm: BaseLLM = llm
self.task = task self.task = task
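original_tools: list[Any] = [] and callbacks: list[Any] = [] turn into None defaults normalized at the top of __init__, the standard fix for mutable default arguments (ruff B006, by the look of it): the old defaults were single lists shared across every call. A runnable sketch of the bug and the fix:

# Before: one default list is created at def time and shared by all calls.
def register(event: str, handlers: list[str] = []) -> list[str]:
    handlers.append(event)
    return handlers

print(register("a"))  # ['a']
print(register("b"))  # ['a', 'b'] -- state leaked from the first call

# After: each call that omits `handlers` gets a fresh list.
def register_fixed(event: str, handlers: list[str] | None = None) -> list[str]:
    if handlers is None:
        handlers = []
    handlers.append(event)
    return handlers

print(register_fixed("a"))  # ['a']
print(register_fixed("b"))  # ['b']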
@@ -79,10 +84,10 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
self.respect_context_window = respect_context_window self.respect_context_window = respect_context_window
self.request_within_rpm_limit = request_within_rpm_limit self.request_within_rpm_limit = request_within_rpm_limit
self.ask_for_human_input = False self.ask_for_human_input = False
self.messages: List[Dict[str, str]] = [] self.messages: list[dict[str, str]] = []
self.iterations = 0 self.iterations = 0
self.log_error_after = 3 self.log_error_after = 3
self.tool_name_to_tool_map: Dict[str, Union[CrewStructuredTool, BaseTool]] = { self.tool_name_to_tool_map: dict[str, CrewStructuredTool | BaseTool] = {
tool.name: tool for tool in self.tools tool.name: tool for tool in self.tools
} }
existing_stop = self.llm.stop or [] existing_stop = self.llm.stop or []
@@ -90,11 +95,11 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
set( set(
existing_stop + self.stop existing_stop + self.stop
if isinstance(existing_stop, list) if isinstance(existing_stop, list)
else self.stop else self.stop,
) ),
) )
def invoke(self, inputs: Dict[str, str]) -> Dict[str, Any]: def invoke(self, inputs: dict[str, str]) -> dict[str, Any]:
if "system" in self.prompt: if "system" in self.prompt:
system_prompt = self._format_prompt(self.prompt.get("system", ""), inputs) system_prompt = self._format_prompt(self.prompt.get("system", ""), inputs)
user_prompt = self._format_prompt(self.prompt.get("user", ""), inputs) user_prompt = self._format_prompt(self.prompt.get("user", ""), inputs)
@@ -120,9 +125,8 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
handle_unknown_error(self._printer, e) handle_unknown_error(self._printer, e)
if e.__class__.__module__.startswith("litellm"): if e.__class__.__module__.startswith("litellm"):
# Do not retry on litellm errors # Do not retry on litellm errors
raise e raise
else: raise
raise e
if self.ask_for_human_input: if self.ask_for_human_input:
formatted_answer = self._handle_human_feedback(formatted_answer) formatted_answer = self._handle_human_feedback(formatted_answer)
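Within except blocks, raise e becomes a bare raise (presumably ruff TRY201): bare raise re-raises the active exception with its original traceback, while raise e appends the current frame. The redundant else collapses as well. A runnable sketch:

def fetch() -> None:
    raise ConnectionError("boom")

def fetch_guarded() -> None:
    try:
        fetch()
    except Exception as e:
        if type(e).__module__.startswith("litellm"):  # mirrors the check above
            raise  # re-raise unchanged; no retry for this family of errors
        raise      # bare raise keeps the original traceback intact

try:
    fetch_guarded()
except ConnectionError as err:
    print(f"propagated unchanged: {err}")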
@@ -133,8 +137,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
return {"output": formatted_answer.output} return {"output": formatted_answer.output}
def _invoke_loop(self) -> AgentFinish: def _invoke_loop(self) -> AgentFinish:
""" """Main loop to invoke the agent's thought process until it reaches a conclusion
Main loop to invoke the agent's thought process until it reaches a conclusion
or the maximum number of iterations is reached. or the maximum number of iterations is reached.
""" """
formatted_answer = None formatted_answer = None
@@ -170,8 +173,8 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
): ):
fingerprint_context = { fingerprint_context = {
"agent_fingerprint": str( "agent_fingerprint": str(
self.agent.security_config.fingerprint self.agent.security_config.fingerprint,
) ),
} }
tool_result = execute_tool_and_check_finality( tool_result = execute_tool_and_check_finality(
@@ -187,7 +190,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
function_calling_llm=self.function_calling_llm, function_calling_llm=self.function_calling_llm,
) )
formatted_answer = self._handle_agent_action( formatted_answer = self._handle_agent_action(
formatted_answer, tool_result formatted_answer, tool_result,
) )
self._invoke_step_callback(formatted_answer) self._invoke_step_callback(formatted_answer)
@@ -205,7 +208,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
except Exception as e: except Exception as e:
if e.__class__.__module__.startswith("litellm"): if e.__class__.__module__.startswith("litellm"):
# Do not retry on litellm errors # Do not retry on litellm errors
raise e raise
if is_context_length_exceeded(e): if is_context_length_exceeded(e):
handle_context_length( handle_context_length(
respect_context_window=self.respect_context_window, respect_context_window=self.respect_context_window,
@@ -216,9 +219,8 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
i18n=self._i18n, i18n=self._i18n,
) )
continue continue
else: handle_unknown_error(self._printer, e)
handle_unknown_error(self._printer, e) raise
raise e
finally: finally:
self.iterations += 1 self.iterations += 1
@@ -231,8 +233,8 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
return formatted_answer return formatted_answer
def _handle_agent_action( def _handle_agent_action(
self, formatted_answer: AgentAction, tool_result: ToolResult self, formatted_answer: AgentAction, tool_result: ToolResult,
) -> Union[AgentAction, AgentFinish]: ) -> AgentAction | AgentFinish:
"""Handle the AgentAction, execute tools, and process the results.""" """Handle the AgentAction, execute tools, and process the results."""
# Special case for add_image_tool # Special case for add_image_tool
add_image_tool = self._i18n.tools("add_image") add_image_tool = self._i18n.tools("add_image")
@@ -261,24 +263,26 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
"""Append a message to the message list with the given role.""" """Append a message to the message list with the given role."""
self.messages.append(format_message_for_llm(text, role=role)) self.messages.append(format_message_for_llm(text, role=role))
def _show_start_logs(self): def _show_start_logs(self) -> None:
"""Show logs for the start of agent execution.""" """Show logs for the start of agent execution."""
if self.agent is None: if self.agent is None:
raise ValueError("Agent cannot be None") msg = "Agent cannot be None"
raise ValueError(msg)
show_agent_logs( show_agent_logs(
printer=self._printer, printer=self._printer,
agent_role=self.agent.role, agent_role=self.agent.role,
task_description=( task_description=(
getattr(self.task, "description") if self.task else "Not Found" self.task.description if self.task else "Not Found"
), ),
verbose=self.agent.verbose verbose=self.agent.verbose
or (hasattr(self, "crew") and getattr(self.crew, "verbose", False)), or (hasattr(self, "crew") and getattr(self.crew, "verbose", False)),
) )
def _show_logs(self, formatted_answer: Union[AgentAction, AgentFinish]): def _show_logs(self, formatted_answer: AgentAction | AgentFinish) -> None:
"""Show logs for the agent's execution.""" """Show logs for the agent's execution."""
if self.agent is None: if self.agent is None:
raise ValueError("Agent cannot be None") msg = "Agent cannot be None"
raise ValueError(msg)
show_agent_logs( show_agent_logs(
printer=self._printer, printer=self._printer,
agent_role=self.agent.role, agent_role=self.agent.role,
@@ -300,11 +304,11 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
summary = self.llm.call( summary = self.llm.call(
[ [
format_message_for_llm( format_message_for_llm(
self._i18n.slice("summarizer_system_message"), role="system" self._i18n.slice("summarizer_system_message"), role="system",
), ),
format_message_for_llm( format_message_for_llm(
self._i18n.slice("summarize_instruction").format( self._i18n.slice("summarize_instruction").format(
group=group["content"] group=group["content"],
), ),
), ),
], ],
@@ -316,12 +320,12 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
self.messages = [ self.messages = [
format_message_for_llm( format_message_for_llm(
self._i18n.slice("summary").format(merged_summary=merged_summary) self._i18n.slice("summary").format(merged_summary=merged_summary),
) ),
] ]
def _handle_crew_training_output( def _handle_crew_training_output(
self, result: AgentFinish, human_feedback: Optional[str] = None self, result: AgentFinish, human_feedback: str | None = None,
) -> None: ) -> None:
"""Handle the process of saving training data.""" """Handle the process of saving training data."""
agent_id = str(self.agent.id) # type: ignore agent_id = str(self.agent.id) # type: ignore
@@ -348,29 +352,27 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
"initial_output": result.output, "initial_output": result.output,
"human_feedback": human_feedback, "human_feedback": human_feedback,
} }
# Save improved output
elif train_iteration in agent_training_data:
agent_training_data[train_iteration]["improved_output"] = result.output
else: else:
# Save improved output self._printer.print(
if train_iteration in agent_training_data: content=(
agent_training_data[train_iteration]["improved_output"] = result.output f"No existing training data for agent {agent_id} and iteration "
else: f"{train_iteration}. Cannot save improved output."
self._printer.print( ),
content=( color="red",
f"No existing training data for agent {agent_id} and iteration " )
f"{train_iteration}. Cannot save improved output." return
),
color="red",
)
return
# Update the training data and save # Update the training data and save
training_data[agent_id] = agent_training_data training_data[agent_id] = agent_training_data
training_handler.save(training_data) training_handler.save(training_data)
def _format_prompt(self, prompt: str, inputs: Dict[str, str]) -> str: def _format_prompt(self, prompt: str, inputs: dict[str, str]) -> str:
prompt = prompt.replace("{input}", inputs["input"]) prompt = prompt.replace("{input}", inputs["input"])
prompt = prompt.replace("{tool_names}", inputs["tool_names"]) prompt = prompt.replace("{tool_names}", inputs["tool_names"])
prompt = prompt.replace("{tools}", inputs["tools"]) return prompt.replace("{tools}", inputs["tools"])
return prompt
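The final prompt = prompt.replace(...) plus return prompt folds into returning the expression directly, which matches ruff's RET504 (unnecessary assignment before return). A minimal sketch:

# Before:
def render(template: str, name: str) -> str:
    out = template.replace("{name}", name)
    return out

# After: same result, one fewer temporary.
def render_fixed(template: str, name: str) -> str:
    return template.replace("{name}", name)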
def _handle_human_feedback(self, formatted_answer: AgentFinish) -> AgentFinish: def _handle_human_feedback(self, formatted_answer: AgentFinish) -> AgentFinish:
"""Handle human feedback with different flows for training vs regular use. """Handle human feedback with different flows for training vs regular use.
@@ -380,6 +382,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
Returns: Returns:
AgentFinish: The final answer after processing feedback AgentFinish: The final answer after processing feedback
""" """
human_feedback = self._ask_human_input(formatted_answer.output) human_feedback = self._ask_human_input(formatted_answer.output)
@@ -393,14 +396,14 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
return bool(self.crew and self.crew._train) return bool(self.crew and self.crew._train)
def _handle_training_feedback( def _handle_training_feedback(
self, initial_answer: AgentFinish, feedback: str self, initial_answer: AgentFinish, feedback: str,
) -> AgentFinish: ) -> AgentFinish:
"""Process feedback for training scenarios with single iteration.""" """Process feedback for training scenarios with single iteration."""
self._handle_crew_training_output(initial_answer, feedback) self._handle_crew_training_output(initial_answer, feedback)
self.messages.append( self.messages.append(
format_message_for_llm( format_message_for_llm(
self._i18n.slice("feedback_instructions").format(feedback=feedback) self._i18n.slice("feedback_instructions").format(feedback=feedback),
) ),
) )
improved_answer = self._invoke_loop() improved_answer = self._invoke_loop()
self._handle_crew_training_output(improved_answer) self._handle_crew_training_output(improved_answer)
@@ -408,7 +411,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
return improved_answer return improved_answer
def _handle_regular_feedback( def _handle_regular_feedback(
self, current_answer: AgentFinish, initial_feedback: str self, current_answer: AgentFinish, initial_feedback: str,
) -> AgentFinish: ) -> AgentFinish:
"""Process feedback for regular use with potential multiple iterations.""" """Process feedback for regular use with potential multiple iterations."""
feedback = initial_feedback feedback = initial_feedback
@@ -428,8 +431,8 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
"""Process a single feedback iteration.""" """Process a single feedback iteration."""
self.messages.append( self.messages.append(
format_message_for_llm( format_message_for_llm(
self._i18n.slice("feedback_instructions").format(feedback=feedback) self._i18n.slice("feedback_instructions").format(feedback=feedback),
) ),
) )
return self._invoke_loop() return self._invoke_loop()
View File
@@ -1,5 +1,5 @@
import re import re
from typing import Any, Optional, Union from typing import Any
from json_repair import repair_json from json_repair import repair_json
@@ -18,7 +18,7 @@ class AgentAction:
text: str text: str
result: str result: str
def __init__(self, thought: str, tool: str, tool_input: str, text: str): def __init__(self, thought: str, tool: str, tool_input: str, text: str) -> None:
self.thought = thought self.thought = thought
self.tool = tool self.tool = tool
self.tool_input = tool_input self.tool_input = tool_input
@@ -30,7 +30,7 @@ class AgentFinish:
output: str output: str
text: str text: str
def __init__(self, thought: str, output: str, text: str): def __init__(self, thought: str, output: str, text: str) -> None:
self.thought = thought self.thought = thought
self.output = output self.output = output
self.text = text self.text = text
@@ -39,7 +39,7 @@ class AgentFinish:
class OutputParserException(Exception): class OutputParserException(Exception):
error: str error: str
def __init__(self, error: str): def __init__(self, error: str) -> None:
self.error = error self.error = error
@@ -67,24 +67,24 @@ class CrewAgentParser:
_i18n: I18N = I18N() _i18n: I18N = I18N()
agent: Any = None agent: Any = None
def __init__(self, agent: Optional[Any] = None): def __init__(self, agent: Any | None = None) -> None:
self.agent = agent self.agent = agent
@staticmethod @staticmethod
def parse_text(text: str) -> Union[AgentAction, AgentFinish]: def parse_text(text: str) -> AgentAction | AgentFinish:
""" """Static method to parse text into an AgentAction or AgentFinish without needing to instantiate the class.
Static method to parse text into an AgentAction or AgentFinish without needing to instantiate the class.
Args: Args:
text: The text to parse. text: The text to parse.
Returns: Returns:
Either an AgentAction or AgentFinish based on the parsed content. Either an AgentAction or AgentFinish based on the parsed content.
""" """
parser = CrewAgentParser() parser = CrewAgentParser()
return parser.parse(text) return parser.parse(text)
def parse(self, text: str) -> Union[AgentAction, AgentFinish]: def parse(self, text: str) -> AgentAction | AgentFinish:
thought = self._extract_thought(text) thought = self._extract_thought(text)
includes_answer = FINAL_ANSWER_ACTION in text includes_answer = FINAL_ANSWER_ACTION in text
regex = ( regex = (
@@ -102,7 +102,7 @@ class CrewAgentParser:
final_answer = final_answer[:-3].rstrip() final_answer = final_answer[:-3].rstrip()
return AgentFinish(thought, final_answer, text) return AgentFinish(thought, final_answer, text)
elif action_match: if action_match:
action = action_match.group(1) action = action_match.group(1)
clean_action = self._clean_action(action) clean_action = self._clean_action(action)
@@ -114,21 +114,21 @@ class CrewAgentParser:
return AgentAction(thought, clean_action, safe_tool_input, text) return AgentAction(thought, clean_action, safe_tool_input, text)
if not re.search(r"Action\s*\d*\s*:[\s]*(.*?)", text, re.DOTALL): if not re.search(r"Action\s*\d*\s*:[\s]*(.*?)", text, re.DOTALL):
msg = f"{MISSING_ACTION_AFTER_THOUGHT_ERROR_MESSAGE}\n{self._i18n.slice('final_answer_format')}"
raise OutputParserException( raise OutputParserException(
f"{MISSING_ACTION_AFTER_THOUGHT_ERROR_MESSAGE}\n{self._i18n.slice('final_answer_format')}", msg,
) )
elif not re.search( if not re.search(
r"[\s]*Action\s*\d*\s*Input\s*\d*\s*:[\s]*(.*)", text, re.DOTALL r"[\s]*Action\s*\d*\s*Input\s*\d*\s*:[\s]*(.*)", text, re.DOTALL,
): ):
raise OutputParserException( raise OutputParserException(
MISSING_ACTION_INPUT_AFTER_ACTION_ERROR_MESSAGE, MISSING_ACTION_INPUT_AFTER_ACTION_ERROR_MESSAGE,
) )
else: format = self._i18n.slice("format_without_tools")
format = self._i18n.slice("format_without_tools") error = f"{format}"
error = f"{format}" raise OutputParserException(
raise OutputParserException( error,
error, )
)
def _extract_thought(self, text: str) -> str: def _extract_thought(self, text: str) -> str:
thought_index = text.find("\nAction") thought_index = text.find("\nAction")
@@ -138,8 +138,7 @@ class CrewAgentParser:
return "" return ""
thought = text[:thought_index].strip() thought = text[:thought_index].strip()
# Remove any triple backticks from the thought string # Remove any triple backticks from the thought string
thought = thought.replace("```", "").strip() return thought.replace("```", "").strip()
return thought
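The parser also swaps elif for if after branches that unconditionally return (ruff RET505, presumably): once a branch returns or raises, the chain can flatten without changing behavior. A minimal sketch:

def classify(n: int) -> str:
    if n < 0:
        return "negative"
    if n == 0:  # was `elif`; the branch above always returns
        return "zero"
    return "positive"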
def _clean_action(self, text: str) -> str: def _clean_action(self, text: str) -> str:
"""Clean action string by removing non-essential formatting characters.""" """Clean action string by removing non-essential formatting characters."""
View File
@@ -1,7 +1,8 @@
from typing import Any, Optional, Union from typing import Any
from crewai.tools.cache_tools.cache_tools import CacheTools
from crewai.tools.tool_calling import InstructorToolCalling, ToolCalling
from ..tools.cache_tools.cache_tools import CacheTools
from ..tools.tool_calling import InstructorToolCalling, ToolCalling
from .cache.cache_handler import CacheHandler from .cache.cache_handler import CacheHandler
@@ -9,16 +10,16 @@ class ToolsHandler:
"""Callback handler for tool usage.""" """Callback handler for tool usage."""
last_used_tool: ToolCalling = {} # type: ignore # BUG?: Incompatible types in assignment (expression has type "Dict[...]", variable has type "ToolCalling") last_used_tool: ToolCalling = {} # type: ignore # BUG?: Incompatible types in assignment (expression has type "Dict[...]", variable has type "ToolCalling")
cache: Optional[CacheHandler] cache: CacheHandler | None
def __init__(self, cache: Optional[CacheHandler] = None): def __init__(self, cache: CacheHandler | None = None) -> None:
"""Initialize the callback handler.""" """Initialize the callback handler."""
self.cache = cache self.cache = cache
self.last_used_tool = {} # type: ignore # BUG?: same as above self.last_used_tool = {} # type: ignore # BUG?: same as above
def on_tool_use( def on_tool_use(
self, self,
calling: Union[ToolCalling, InstructorToolCalling], calling: ToolCalling | InstructorToolCalling,
output: str, output: str,
should_cache: bool = True, should_cache: bool = True,
) -> Any: ) -> Any:
View File
@@ -9,9 +9,9 @@ def add_crew_to_flow(crew_name: str) -> None:
"""Add a new crew to the current flow.""" """Add a new crew to the current flow."""
# Check if pyproject.toml exists in the current directory # Check if pyproject.toml exists in the current directory
if not Path("pyproject.toml").exists(): if not Path("pyproject.toml").exists():
print("This command must be run from the root of a flow project.") msg = "This command must be run from the root of a flow project."
raise click.ClickException( raise click.ClickException(
"This command must be run from the root of a flow project." msg,
) )
# Determine the flow folder based on the current directory # Determine the flow folder based on the current directory
@@ -19,8 +19,8 @@ def add_crew_to_flow(crew_name: str) -> None:
crews_folder = flow_folder / "src" / flow_folder.name / "crews" crews_folder = flow_folder / "src" / flow_folder.name / "crews"
if not crews_folder.exists(): if not crews_folder.exists():
print("Crews folder does not exist in the current flow.") msg = "Crews folder does not exist in the current flow."
raise click.ClickException("Crews folder does not exist in the current flow.") raise click.ClickException(msg)
# Create the crew within the flow's crews directory # Create the crew within the flow's crews directory
create_embedded_crew(crew_name, parent_folder=crews_folder) create_embedded_crew(crew_name, parent_folder=crews_folder)
@@ -39,7 +39,7 @@ def create_embedded_crew(crew_name: str, parent_folder: Path) -> None:
if crew_folder.exists(): if crew_folder.exists():
if not click.confirm( if not click.confirm(
f"Crew {folder_name} already exists. Do you want to override it?" f"Crew {folder_name} already exists. Do you want to override it?",
): ):
click.secho("Operation cancelled.", fg="yellow") click.secho("Operation cancelled.", fg="yellow")
return return
@@ -66,5 +66,5 @@ def create_embedded_crew(crew_name: str, parent_folder: Path) -> None:
copy_template(src_file, dst_file, crew_name, class_name, folder_name) copy_template(src_file, dst_file, crew_name, class_name, folder_name)
click.secho( click.secho(
f"Crew {crew_name} added to the flow successfully!", fg="green", bold=True f"Crew {crew_name} added to the flow successfully!", fg="green", bold=True,
) )
View File
@@ -1,6 +1,6 @@
import time import time
import webbrowser import webbrowser
from typing import Any, Dict from typing import Any
import requests import requests
from rich.console import Console from rich.console import Console
@@ -17,38 +17,37 @@ class AuthenticationCommand:
DEVICE_CODE_URL = f"https://{AUTH0_DOMAIN}/oauth/device/code" DEVICE_CODE_URL = f"https://{AUTH0_DOMAIN}/oauth/device/code"
TOKEN_URL = f"https://{AUTH0_DOMAIN}/oauth/token" TOKEN_URL = f"https://{AUTH0_DOMAIN}/oauth/token"
def __init__(self): def __init__(self) -> None:
self.token_manager = TokenManager() self.token_manager = TokenManager()
def signup(self) -> None: def signup(self) -> None:
"""Sign up to CrewAI+""" """Sign up to CrewAI+."""
console.print("Signing Up to CrewAI+ \n", style="bold blue") console.print("Signing Up to CrewAI+ \n", style="bold blue")
device_code_data = self._get_device_code() device_code_data = self._get_device_code()
self._display_auth_instructions(device_code_data) self._display_auth_instructions(device_code_data)
return self._poll_for_token(device_code_data) return self._poll_for_token(device_code_data)
def _get_device_code(self) -> Dict[str, Any]: def _get_device_code(self) -> dict[str, Any]:
"""Get the device code to authenticate the user.""" """Get the device code to authenticate the user."""
device_code_payload = { device_code_payload = {
"client_id": AUTH0_CLIENT_ID, "client_id": AUTH0_CLIENT_ID,
"scope": "openid", "scope": "openid",
"audience": AUTH0_AUDIENCE, "audience": AUTH0_AUDIENCE,
} }
response = requests.post( response = requests.post(
url=self.DEVICE_CODE_URL, data=device_code_payload, timeout=20 url=self.DEVICE_CODE_URL, data=device_code_payload, timeout=20,
) )
response.raise_for_status() response.raise_for_status()
return response.json() return response.json()
def _display_auth_instructions(self, device_code_data: Dict[str, str]) -> None: def _display_auth_instructions(self, device_code_data: dict[str, str]) -> None:
"""Display the authentication instructions to the user.""" """Display the authentication instructions to the user."""
console.print("1. Navigate to: ", device_code_data["verification_uri_complete"]) console.print("1. Navigate to: ", device_code_data["verification_uri_complete"])
console.print("2. Enter the following code: ", device_code_data["user_code"]) console.print("2. Enter the following code: ", device_code_data["user_code"])
webbrowser.open(device_code_data["verification_uri_complete"]) webbrowser.open(device_code_data["verification_uri_complete"])
def _poll_for_token(self, device_code_data: Dict[str, Any]) -> None: def _poll_for_token(self, device_code_data: dict[str, Any]) -> None:
"""Poll the server for the token.""" """Poll the server for the token."""
token_payload = { token_payload = {
"grant_type": "urn:ietf:params:oauth:grant-type:device_code", "grant_type": "urn:ietf:params:oauth:grant-type:device_code",
@@ -81,7 +80,7 @@ class AuthenticationCommand:
) )
console.print( console.print(
"\n[bold green]Welcome to CrewAI Enterprise![/bold green]\n" "\n[bold green]Welcome to CrewAI Enterprise![/bold green]\n",
) )
return return
@@ -92,5 +91,5 @@ class AuthenticationCommand:
attempts += 1 attempts += 1
console.print( console.print(
"Timeout: Failed to get the token. Please try again.", style="bold red" "Timeout: Failed to get the token. Please try again.", style="bold red",
) )
View File
@@ -5,5 +5,5 @@ def get_auth_token() -> str:
"""Get the authentication token.""" """Get the authentication token."""
access_token = TokenManager().get_token() access_token = TokenManager().get_token()
if not access_token: if not access_token:
raise Exception() raise Exception
return access_token return access_token
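raise Exception() loses its empty parentheses (ruff RSE102): raising the class is equivalent, as Python instantiates it with no arguments either way. A minimal sketch with a hypothetical require_token:

def require_token(token: str | None) -> str:
    if not token:
        raise Exception  # equivalent to `raise Exception()`
    return token

A bare Exception with no message remains uninformative; a dedicated error type would say more, but that is beyond what an autofix will do.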
View File
@@ -3,7 +3,6 @@ import os
import sys import sys
from datetime import datetime, timedelta from datetime import datetime, timedelta
from pathlib import Path from pathlib import Path
from typing import Optional
from auth0.authentication.token_verifier import ( from auth0.authentication.token_verifier import (
AsymmetricSignatureVerifier, AsymmetricSignatureVerifier,
@@ -15,8 +14,7 @@ from .constants import AUTH0_CLIENT_ID, AUTH0_DOMAIN
def validate_token(id_token: str) -> None: def validate_token(id_token: str) -> None:
""" """Verify the token and its precedence.
Verify the token and its precedence
:param id_token: :param id_token:
""" """
@@ -24,15 +22,14 @@ def validate_token(id_token: str) -> None:
issuer = f"https://{AUTH0_DOMAIN}/" issuer = f"https://{AUTH0_DOMAIN}/"
signature_verifier = AsymmetricSignatureVerifier(jwks_url) signature_verifier = AsymmetricSignatureVerifier(jwks_url)
token_verifier = TokenVerifier( token_verifier = TokenVerifier(
signature_verifier=signature_verifier, issuer=issuer, audience=AUTH0_CLIENT_ID signature_verifier=signature_verifier, issuer=issuer, audience=AUTH0_CLIENT_ID,
) )
token_verifier.verify(id_token) token_verifier.verify(id_token)
class TokenManager: class TokenManager:
def __init__(self, file_path: str = "tokens.enc") -> None: def __init__(self, file_path: str = "tokens.enc") -> None:
""" """Initialize the TokenManager class.
Initialize the TokenManager class.
:param file_path: The file path to store the encrypted tokens. Default is "tokens.enc". :param file_path: The file path to store the encrypted tokens. Default is "tokens.enc".
""" """
@@ -41,8 +38,7 @@ class TokenManager:
self.fernet = Fernet(self.key) self.fernet = Fernet(self.key)
def _get_or_create_key(self) -> bytes: def _get_or_create_key(self) -> bytes:
""" """Get or create the encryption key.
Get or create the encryption key.
:return: The encryption key. :return: The encryption key.
""" """
@@ -57,8 +53,7 @@ class TokenManager:
return new_key return new_key
def save_tokens(self, access_token: str, expires_in: int) -> None: def save_tokens(self, access_token: str, expires_in: int) -> None:
""" """Save the access token and its expiration time.
Save the access token and its expiration time.
:param access_token: The access token to save. :param access_token: The access token to save.
:param expires_in: The expiration time of the access token in seconds. :param expires_in: The expiration time of the access token in seconds.
@@ -71,9 +66,8 @@ class TokenManager:
encrypted_data = self.fernet.encrypt(json.dumps(data).encode()) encrypted_data = self.fernet.encrypt(json.dumps(data).encode())
self.save_secure_file(self.file_path, encrypted_data) self.save_secure_file(self.file_path, encrypted_data)
def get_token(self) -> Optional[str]: def get_token(self) -> str | None:
""" """Get the access token if it is valid and not expired.
Get the access token if it is valid and not expired.
:return: The access token if valid and not expired, otherwise None. :return: The access token if valid and not expired, otherwise None.
""" """
@@ -89,8 +83,7 @@ class TokenManager:
return data["access_token"] return data["access_token"]
def get_secure_storage_path(self) -> Path: def get_secure_storage_path(self) -> Path:
""" """Get the secure storage path based on the operating system.
Get the secure storage path based on the operating system.
:return: The secure storage path. :return: The secure storage path.
""" """
@@ -112,8 +105,7 @@ class TokenManager:
return storage_path return storage_path
def save_secure_file(self, filename: str, content: bytes) -> None: def save_secure_file(self, filename: str, content: bytes) -> None:
""" """Save the content to a secure file.
Save the content to a secure file.
:param filename: The name of the file. :param filename: The name of the file.
:param content: The content to save. :param content: The content to save.
@@ -127,9 +119,8 @@ class TokenManager:
# Set appropriate permissions (read/write for owner only) # Set appropriate permissions (read/write for owner only)
os.chmod(file_path, 0o600) os.chmod(file_path, 0o600)
def read_secure_file(self, filename: str) -> Optional[bytes]: def read_secure_file(self, filename: str) -> bytes | None:
""" """Read the content of a secure file.
Read the content of a secure file.
:param filename: The name of the file. :param filename: The name of the file.
:return: The content of the file if it exists, otherwise None. :return: The content of the file if it exists, otherwise None.
View File
@@ -1,6 +1,4 @@
import os
from importlib.metadata import version as get_version from importlib.metadata import version as get_version
from typing import Optional, Tuple
import click import click
@@ -28,7 +26,7 @@ from .update_crew import update_crew
@click.group() @click.group()
@click.version_option(get_version("crewai")) @click.version_option(get_version("crewai"))
def crewai(): def crewai() -> None:
"""Top-level command group for crewai.""" """Top-level command group for crewai."""
@@ -37,7 +35,7 @@ def crewai():
@click.argument("name") @click.argument("name")
@click.option("--provider", type=str, help="The provider to use for the crew") @click.option("--provider", type=str, help="The provider to use for the crew")
@click.option("--skip_provider", is_flag=True, help="Skip provider validation") @click.option("--skip_provider", is_flag=True, help="Skip provider validation")
def create(type, name, provider, skip_provider=False): def create(type, name, provider, skip_provider=False) -> None:
"""Create a new crew, or flow.""" """Create a new crew, or flow."""
if type == "crew": if type == "crew":
create_crew(name, provider, skip_provider) create_crew(name, provider, skip_provider)
@@ -49,9 +47,9 @@ def create(type, name, provider, skip_provider=False):
@crewai.command() @crewai.command()
@click.option( @click.option(
"--tools", is_flag=True, help="Show the installed version of crewai tools" "--tools", is_flag=True, help="Show the installed version of crewai tools",
) )
def version(tools): def version(tools) -> None:
"""Show the installed version of crewai.""" """Show the installed version of crewai."""
try: try:
crewai_version = get_version("crewai") crewai_version = get_version("crewai")
@@ -82,7 +80,7 @@ def version(tools):
default="trained_agents_data.pkl", default="trained_agents_data.pkl",
help="Path to a custom file for training", help="Path to a custom file for training",
) )
def train(n_iterations: int, filename: str): def train(n_iterations: int, filename: str) -> None:
"""Train the crew.""" """Train the crew."""
click.echo(f"Training the Crew for {n_iterations} iterations") click.echo(f"Training the Crew for {n_iterations} iterations")
train_crew(n_iterations, filename) train_crew(n_iterations, filename)
@@ -96,11 +94,11 @@ def train(n_iterations: int, filename: str):
help="Replay the crew from this task ID, including all subsequent tasks.", help="Replay the crew from this task ID, including all subsequent tasks.",
) )
def replay(task_id: str) -> None: def replay(task_id: str) -> None:
""" """Replay the crew execution from a specific task.
Replay the crew execution from a specific task.
Args: Args:
task_id (str): The ID of the task to replay from. task_id (str): The ID of the task to replay from.
""" """
try: try:
click.echo(f"Replaying the crew from task {task_id}") click.echo(f"Replaying the crew from task {task_id}")
@@ -111,16 +109,14 @@ def replay(task_id: str) -> None:
@crewai.command() @crewai.command()
def log_tasks_outputs() -> None: def log_tasks_outputs() -> None:
""" """Retrieve your latest crew.kickoff() task outputs."""
Retrieve your latest crew.kickoff() task outputs.
"""
try: try:
storage = KickoffTaskOutputsSQLiteStorage() storage = KickoffTaskOutputsSQLiteStorage()
tasks = storage.load() tasks = storage.load()
if not tasks: if not tasks:
click.echo( click.echo(
"No task outputs found. Only crew kickoff task outputs are logged." "No task outputs found. Only crew kickoff task outputs are logged.",
) )
return return
@@ -153,13 +149,11 @@ def reset_memories(
kickoff_outputs: bool, kickoff_outputs: bool,
all: bool, all: bool,
) -> None: ) -> None:
""" """Reset the crew memories (long, short, entity, latest_crew_kickoff_ouputs). This will delete all the data saved."""
Reset the crew memories (long, short, entity, latest_crew_kickoff_ouputs). This will delete all the data saved.
"""
try: try:
if not all and not (long or short or entities or knowledge or kickoff_outputs): if not all and not (long or short or entities or knowledge or kickoff_outputs):
click.echo( click.echo(
"Please specify at least one memory type to reset using the appropriate flags." "Please specify at least one memory type to reset using the appropriate flags.",
) )
return return
reset_memories_command(long, short, entities, knowledge, kickoff_outputs, all) reset_memories_command(long, short, entities, knowledge, kickoff_outputs, all)
@@ -182,71 +176,69 @@ def reset_memories(
default="gpt-4o-mini", default="gpt-4o-mini",
help="LLM Model to run the tests on the Crew. For now only accepting only OpenAI models.", help="LLM Model to run the tests on the Crew. For now only accepting only OpenAI models.",
) )
def test(n_iterations: int, model: str): def test(n_iterations: int, model: str) -> None:
"""Test the crew and evaluate the results.""" """Test the crew and evaluate the results."""
click.echo(f"Testing the crew for {n_iterations} iterations with model {model}") click.echo(f"Testing the crew for {n_iterations} iterations with model {model}")
evaluate_crew(n_iterations, model) evaluate_crew(n_iterations, model)
@crewai.command( @crewai.command(
context_settings=dict( context_settings={
ignore_unknown_options=True, "ignore_unknown_options": True,
allow_extra_args=True, "allow_extra_args": True,
) },
) )
@click.pass_context @click.pass_context
def install(context): def install(context) -> None:
"""Install the Crew.""" """Install the Crew."""
install_crew(context.args) install_crew(context.args)
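context_settings=dict(...) becomes a dict literal, matching ruff's C408 (unnecessary dict() call): the literal builds the same mapping without a name lookup and a function call. A minimal sketch:

# Before:
settings = dict(ignore_unknown_options=True, allow_extra_args=True)

# After: same mapping, constructed directly by the compiler.
settings_fixed = {
    "ignore_unknown_options": True,
    "allow_extra_args": True,
}
assert settings == settings_fixed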
@crewai.command() @crewai.command()
def run(): def run() -> None:
"""Run the Crew.""" """Run the Crew."""
run_crew() run_crew()
@crewai.command() @crewai.command()
def update(): def update() -> None:
"""Update the pyproject.toml of the Crew project to use uv.""" """Update the pyproject.toml of the Crew project to use uv."""
update_crew() update_crew()
@crewai.command() @crewai.command()
def signup(): def signup() -> None:
"""Sign Up/Login to CrewAI+.""" """Sign Up/Login to CrewAI+."""
AuthenticationCommand().signup() AuthenticationCommand().signup()
@crewai.command() @crewai.command()
def login(): def login() -> None:
"""Sign Up/Login to CrewAI+.""" """Sign Up/Login to CrewAI+."""
AuthenticationCommand().signup() AuthenticationCommand().signup()
# DEPLOY CREWAI+ COMMANDS # DEPLOY CREWAI+ COMMANDS
@crewai.group() @crewai.group()
def deploy(): def deploy() -> None:
"""Deploy the Crew CLI group.""" """Deploy the Crew CLI group."""
pass
@crewai.group() @crewai.group()
def tool(): def tool() -> None:
"""Tool Repository related commands.""" """Tool Repository related commands."""
pass
@deploy.command(name="create") @deploy.command(name="create")
@click.option("-y", "--yes", is_flag=True, help="Skip the confirmation prompt") @click.option("-y", "--yes", is_flag=True, help="Skip the confirmation prompt")
def deploy_create(yes: bool): def deploy_create(yes: bool) -> None:
"""Create a Crew deployment.""" """Create a Crew deployment."""
deploy_cmd = DeployCommand() deploy_cmd = DeployCommand()
deploy_cmd.create_crew(yes) deploy_cmd.create_crew(yes)
@deploy.command(name="list") @deploy.command(name="list")
def deploy_list(): def deploy_list() -> None:
"""List all deployments.""" """List all deployments."""
deploy_cmd = DeployCommand() deploy_cmd = DeployCommand()
deploy_cmd.list_crews() deploy_cmd.list_crews()
@@ -254,7 +246,7 @@ def deploy_list():
@deploy.command(name="push") @deploy.command(name="push")
@click.option("-u", "--uuid", type=str, help="Crew UUID parameter") @click.option("-u", "--uuid", type=str, help="Crew UUID parameter")
def deploy_push(uuid: Optional[str]): def deploy_push(uuid: str | None) -> None:
"""Deploy the Crew.""" """Deploy the Crew."""
deploy_cmd = DeployCommand() deploy_cmd = DeployCommand()
deploy_cmd.deploy(uuid=uuid) deploy_cmd.deploy(uuid=uuid)
@@ -262,7 +254,7 @@ def deploy_push(uuid: Optional[str]):
@deploy.command(name="status") @deploy.command(name="status")
@click.option("-u", "--uuid", type=str, help="Crew UUID parameter") @click.option("-u", "--uuid", type=str, help="Crew UUID parameter")
def deply_status(uuid: Optional[str]): def deply_status(uuid: str | None) -> None:
"""Get the status of a deployment.""" """Get the status of a deployment."""
deploy_cmd = DeployCommand() deploy_cmd = DeployCommand()
deploy_cmd.get_crew_status(uuid=uuid) deploy_cmd.get_crew_status(uuid=uuid)
@@ -270,7 +262,7 @@ def deply_status(uuid: Optional[str]):
@deploy.command(name="logs") @deploy.command(name="logs")
@click.option("-u", "--uuid", type=str, help="Crew UUID parameter") @click.option("-u", "--uuid", type=str, help="Crew UUID parameter")
def deploy_logs(uuid: Optional[str]): def deploy_logs(uuid: str | None) -> None:
"""Get the logs of a deployment.""" """Get the logs of a deployment."""
deploy_cmd = DeployCommand() deploy_cmd = DeployCommand()
deploy_cmd.get_crew_logs(uuid=uuid) deploy_cmd.get_crew_logs(uuid=uuid)
@@ -278,7 +270,7 @@ def deploy_logs(uuid: Optional[str]):
@deploy.command(name="remove") @deploy.command(name="remove")
@click.option("-u", "--uuid", type=str, help="Crew UUID parameter") @click.option("-u", "--uuid", type=str, help="Crew UUID parameter")
def deploy_remove(uuid: Optional[str]): def deploy_remove(uuid: str | None) -> None:
"""Remove a deployment.""" """Remove a deployment."""
deploy_cmd = DeployCommand() deploy_cmd = DeployCommand()
deploy_cmd.remove_crew(uuid=uuid) deploy_cmd.remove_crew(uuid=uuid)
@@ -286,14 +278,14 @@ def deploy_remove(uuid: Optional[str]):
@tool.command(name="create") @tool.command(name="create")
@click.argument("handle") @click.argument("handle")
def tool_create(handle: str): def tool_create(handle: str) -> None:
tool_cmd = ToolCommand() tool_cmd = ToolCommand()
tool_cmd.create(handle) tool_cmd.create(handle)
@tool.command(name="install") @tool.command(name="install")
@click.argument("handle") @click.argument("handle")
def tool_install(handle: str): def tool_install(handle: str) -> None:
tool_cmd = ToolCommand() tool_cmd = ToolCommand()
tool_cmd.login() tool_cmd.login()
tool_cmd.install(handle) tool_cmd.install(handle)
@@ -309,27 +301,26 @@ def tool_install(handle: str):
) )
@click.option("--public", "is_public", flag_value=True, default=False) @click.option("--public", "is_public", flag_value=True, default=False)
@click.option("--private", "is_public", flag_value=False) @click.option("--private", "is_public", flag_value=False)
def tool_publish(is_public: bool, force: bool): def tool_publish(is_public: bool, force: bool) -> None:
tool_cmd = ToolCommand() tool_cmd = ToolCommand()
tool_cmd.login() tool_cmd.login()
tool_cmd.publish(is_public, force) tool_cmd.publish(is_public, force)
@crewai.group() @crewai.group()
def flow(): def flow() -> None:
"""Flow related commands.""" """Flow related commands."""
pass
@flow.command(name="kickoff") @flow.command(name="kickoff")
def flow_run(): def flow_run() -> None:
"""Kickoff the Flow.""" """Kickoff the Flow."""
click.echo("Running the Flow") click.echo("Running the Flow")
kickoff_flow() kickoff_flow()
@flow.command(name="plot") @flow.command(name="plot")
def flow_plot(): def flow_plot() -> None:
"""Plot the Flow.""" """Plot the Flow."""
click.echo("Plotting the Flow") click.echo("Plotting the Flow")
plot_flow() plot_flow()
@@ -337,20 +328,19 @@ def flow_plot():
@flow.command(name="add-crew") @flow.command(name="add-crew")
@click.argument("crew_name") @click.argument("crew_name")
def flow_add_crew(crew_name): def flow_add_crew(crew_name) -> None:
"""Add a crew to an existing flow.""" """Add a crew to an existing flow."""
click.echo(f"Adding crew {crew_name} to the flow") click.echo(f"Adding crew {crew_name} to the flow")
add_crew_to_flow(crew_name) add_crew_to_flow(crew_name)
@crewai.command() @crewai.command()
def chat(): def chat() -> None:
""" """Start a conversation with the Crew, collecting user-supplied inputs,
Start a conversation with the Crew, collecting user-supplied inputs,
and using the Chat LLM to generate responses. and using the Chat LLM to generate responses.
""" """
click.secho( click.secho(
"\nStarting a conversation with the Crew\n" "Type 'exit' or Ctrl+C to quit.\n", "\nStarting a conversation with the Crew\nType 'exit' or Ctrl+C to quit.\n",
) )
run_chat() run_chat()
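Every command in this file gains an explicit -> None return annotation, presumably ruff's ANN201 (missing return type on a public function), making the "returns nothing" contract checkable. A minimal sketch with a hypothetical greet command:

import click

@click.command()
@click.argument("name")
def greet(name: str) -> None:  # ANN201: public functions declare a return type
    click.echo(f"Hello, {name}!")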
View File
@@ -10,13 +10,13 @@ console = Console()
class BaseCommand: class BaseCommand:
def __init__(self): def __init__(self) -> None:
self._telemetry = Telemetry() self._telemetry = Telemetry()
self._telemetry.set_tracer() self._telemetry.set_tracer()
class PlusAPIMixin: class PlusAPIMixin:
def __init__(self, telemetry): def __init__(self, telemetry) -> None:
try: try:
telemetry.set_tracer() telemetry.set_tracer()
self.plus_api_client = PlusAPI(api_key=get_auth_token()) self.plus_api_client = PlusAPI(api_key=get_auth_token())
@@ -30,11 +30,11 @@ class PlusAPIMixin:
raise SystemExit raise SystemExit
def _validate_response(self, response: requests.Response) -> None: def _validate_response(self, response: requests.Response) -> None:
""" """Handle and display error messages from API responses.
Handle and display error messages from API responses.
Args: Args:
response (requests.Response): The response from the Plus API response (requests.Response): The response from the Plus API
""" """
try: try:
json_response = response.json() json_response = response.json()
@@ -55,13 +55,13 @@ class PlusAPIMixin:
for field, messages in json_response.items(): for field, messages in json_response.items():
for message in messages: for message in messages:
console.print( console.print(
f"* [bold red]{field.capitalize()}[/bold red] {message}" f"* [bold red]{field.capitalize()}[/bold red] {message}",
) )
raise SystemExit raise SystemExit
if not response.ok: if not response.ok:
console.print( console.print(
"Request to Enterprise API failed. Details:", style="bold red" "Request to Enterprise API failed. Details:", style="bold red",
) )
details = ( details = (
json_response.get("error") json_response.get("error")
@@ -1,6 +1,5 @@
import json import json
from pathlib import Path from pathlib import Path
from typing import Optional
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
@@ -8,16 +7,16 @@ DEFAULT_CONFIG_PATH = Path.home() / ".config" / "crewai" / "settings.json"
class Settings(BaseModel): class Settings(BaseModel):
tool_repository_username: Optional[str] = Field( tool_repository_username: str | None = Field(
None, description="Username for interacting with the Tool Repository" None, description="Username for interacting with the Tool Repository",
) )
tool_repository_password: Optional[str] = Field( tool_repository_password: str | None = Field(
None, description="Password for interacting with the Tool Repository" None, description="Password for interacting with the Tool Repository",
) )
config_path: Path = Field(default=DEFAULT_CONFIG_PATH, exclude=True) config_path: Path = Field(default=DEFAULT_CONFIG_PATH, exclude=True)
def __init__(self, config_path: Path = DEFAULT_CONFIG_PATH, **data): def __init__(self, config_path: Path = DEFAULT_CONFIG_PATH, **data) -> None:
"""Load Settings from config path""" """Load Settings from config path."""
config_path.parent.mkdir(parents=True, exist_ok=True) config_path.parent.mkdir(parents=True, exist_ok=True)
file_data = {} file_data = {}
@@ -32,7 +31,7 @@ class Settings(BaseModel):
super().__init__(config_path=config_path, **merged_data) super().__init__(config_path=config_path, **merged_data)
def dump(self) -> None: def dump(self) -> None:
"""Save current settings to settings.json""" """Save current settings to settings.json."""
if self.config_path.is_file(): if self.config_path.is_file():
with self.config_path.open("r") as f: with self.config_path.open("r") as f:
existing_data = json.load(f) existing_data = json.load(f)
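The typing change in this file is the PEP 604 rewrite (plausibly Ruff's UP007/UP045): Optional[X] becomes X | None, so the typing import disappears. A self-contained sketch; note that | in annotations needs Python 3.10+ unless from __future__ import annotations is used:

from typing import Optional  # before: still needed


def find_user_before(name: Optional[str] = None) -> Optional[str]:
    return name


def find_user_after(name: str | None = None) -> str | None:  # after
    return name


assert find_user_before("ada") == find_user_after("ada") == "ada"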
@@ -3,31 +3,31 @@ ENV_VARS = {
{ {
"prompt": "Enter your OPENAI API key (press Enter to skip)", "prompt": "Enter your OPENAI API key (press Enter to skip)",
"key_name": "OPENAI_API_KEY", "key_name": "OPENAI_API_KEY",
} },
], ],
"anthropic": [ "anthropic": [
{ {
"prompt": "Enter your ANTHROPIC API key (press Enter to skip)", "prompt": "Enter your ANTHROPIC API key (press Enter to skip)",
"key_name": "ANTHROPIC_API_KEY", "key_name": "ANTHROPIC_API_KEY",
} },
], ],
"gemini": [ "gemini": [
{ {
"prompt": "Enter your GEMINI API key from https://ai.dev/apikey (press Enter to skip)", "prompt": "Enter your GEMINI API key from https://ai.dev/apikey (press Enter to skip)",
"key_name": "GEMINI_API_KEY", "key_name": "GEMINI_API_KEY",
} },
], ],
"nvidia_nim": [ "nvidia_nim": [
{ {
"prompt": "Enter your NVIDIA API key (press Enter to skip)", "prompt": "Enter your NVIDIA API key (press Enter to skip)",
"key_name": "NVIDIA_NIM_API_KEY", "key_name": "NVIDIA_NIM_API_KEY",
} },
], ],
"groq": [ "groq": [
{ {
"prompt": "Enter your GROQ API key (press Enter to skip)", "prompt": "Enter your GROQ API key (press Enter to skip)",
"key_name": "GROQ_API_KEY", "key_name": "GROQ_API_KEY",
} },
], ],
"watson": [ "watson": [
{ {
@@ -47,7 +47,7 @@ ENV_VARS = {
{ {
"default": True, "default": True,
"API_BASE": "http://localhost:11434", "API_BASE": "http://localhost:11434",
} },
], ],
"bedrock": [ "bedrock": [
{ {
@@ -101,7 +101,7 @@ ENV_VARS = {
{ {
"prompt": "Enter your SambaNovaCloud API key (press Enter to skip)", "prompt": "Enter your SambaNovaCloud API key (press Enter to skip)",
"key_name": "SAMBANOVA_API_KEY", "key_name": "SAMBANOVA_API_KEY",
} },
], ],
} }
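Every change in this hunk appends a trailing comma to the last element of a multi-line literal (plausibly Ruff's COM812); appending a new entry later then touches a single line. A sketch reusing one entry from above:

ENV_VARS_SKETCH = {
    "openai": [
        {
            "prompt": "Enter your OPENAI API key (press Enter to skip)",
            "key_name": "OPENAI_API_KEY",
        },  # <- the comma added by the fix
    ],
}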
@@ -24,7 +24,7 @@ def create_folder_structure(name, parent_folder=None):
if folder_path.exists(): if folder_path.exists():
if not click.confirm( if not click.confirm(
f"Folder {folder_name} already exists. Do you want to override it?" f"Folder {folder_name} already exists. Do you want to override it?",
): ):
click.secho("Operation cancelled.", fg="yellow") click.secho("Operation cancelled.", fg="yellow")
sys.exit(0) sys.exit(0)
@@ -48,7 +48,7 @@ def create_folder_structure(name, parent_folder=None):
return folder_path, folder_name, class_name return folder_path, folder_name, class_name
def copy_template_files(folder_path, name, class_name, parent_folder): def copy_template_files(folder_path, name, class_name, parent_folder) -> None:
package_dir = Path(__file__).parent package_dir = Path(__file__).parent
templates_dir = package_dir / "templates" / "crew" templates_dir = package_dir / "templates" / "crew"
@@ -89,7 +89,7 @@ def copy_template_files(folder_path, name, class_name, parent_folder):
copy_template(src_file, dst_file, name, class_name, folder_path.name) copy_template(src_file, dst_file, name, class_name, folder_path.name)
def create_crew(name, provider=None, skip_provider=False, parent_folder=None): def create_crew(name, provider=None, skip_provider=False, parent_folder=None) -> None:
folder_path, folder_name, class_name = create_folder_structure(name, parent_folder) folder_path, folder_name, class_name = create_folder_structure(name, parent_folder)
env_vars = load_env_vars(folder_path) env_vars = load_env_vars(folder_path)
if not skip_provider: if not skip_provider:
@@ -109,7 +109,7 @@ def create_crew(name, provider=None, skip_provider=False, parent_folder=None):
if existing_provider: if existing_provider:
if not click.confirm( if not click.confirm(
f"Found existing environment variable configuration for {existing_provider.capitalize()}. Do you want to override it?" f"Found existing environment variable configuration for {existing_provider.capitalize()}. Do you want to override it?",
): ):
click.secho("Keeping existing provider configuration.", fg="yellow") click.secho("Keeping existing provider configuration.", fg="yellow")
return return
@@ -126,11 +126,11 @@ def create_crew(name, provider=None, skip_provider=False, parent_folder=None):
if selected_provider: # Valid selection if selected_provider: # Valid selection
break break
click.secho( click.secho(
"No provider selected. Please try again or press 'q' to exit.", fg="red" "No provider selected. Please try again or press 'q' to exit.", fg="red",
) )
# Check if the selected provider has predefined models # Check if the selected provider has predefined models
if selected_provider in MODELS and MODELS[selected_provider]: if MODELS.get(selected_provider):
while True: while True:
selected_model = select_model(selected_provider, provider_models) selected_model = select_model(selected_provider, provider_models)
if selected_model is None: # User typed 'q' if selected_model is None: # User typed 'q'
@@ -167,7 +167,7 @@ def create_crew(name, provider=None, skip_provider=False, parent_folder=None):
click.secho("API keys and model saved to .env file", fg="green") click.secho("API keys and model saved to .env file", fg="green")
else: else:
click.secho( click.secho(
"No API keys provided. Skipping .env file creation.", fg="yellow" "No API keys provided. Skipping .env file creation.", fg="yellow",
) )
click.secho(f"Selected model: {env_vars.get('MODEL', 'N/A')}", fg="green") click.secho(f"Selected model: {env_vars.get('MODEL', 'N/A')}", fg="green")
@@ -5,7 +5,7 @@ import click
from crewai.telemetry import Telemetry from crewai.telemetry import Telemetry
def create_flow(name): def create_flow(name) -> None:
"""Create a new flow.""" """Create a new flow."""
folder_name = name.replace(" ", "_").replace("-", "_").lower() folder_name = name.replace(" ", "_").replace("-", "_").lower()
class_name = name.replace("_", " ").replace("-", " ").title().replace(" ", "") class_name = name.replace("_", " ").replace("-", " ").title().replace(" ", "")
@@ -43,12 +43,12 @@ def create_flow(name):
"poem_crew", "poem_crew",
] ]
def process_file(src_file, dst_file): def process_file(src_file, dst_file) -> None:
if src_file.suffix in [".pyc", ".pyo", ".pyd"]: if src_file.suffix in [".pyc", ".pyo", ".pyd"]:
return return
try: try:
with open(src_file, "r", encoding="utf-8") as file: with open(src_file, encoding="utf-8") as file:
content = file.read() content = file.read()
except Exception as e: except Exception as e:
click.secho(f"Error processing file {src_file}: {e}", fg="red") click.secho(f"Error processing file {src_file}: {e}", fg="red")
@@ -5,7 +5,7 @@ import sys
import threading import threading
import time import time
from pathlib import Path from pathlib import Path
from typing import Any, Dict, List, Optional, Set, Tuple from typing import Any
import click import click
import tomli import tomli
@@ -22,10 +22,9 @@ MIN_REQUIRED_VERSION = "0.98.0"
def check_conversational_crews_version( def check_conversational_crews_version(
crewai_version: str, pyproject_data: dict crewai_version: str, pyproject_data: dict,
) -> bool: ) -> bool:
""" """Check if the installed crewAI version supports conversational crews.
Check if the installed crewAI version supports conversational crews.
Args: Args:
crewai_version: The current version of crewAI. crewai_version: The current version of crewAI.
@@ -33,6 +32,7 @@ def check_conversational_crews_version(
Returns: Returns:
bool: True if version check passes, False otherwise. bool: True if version check passes, False otherwise.
""" """
try: try:
if version.parse(crewai_version) < version.parse(MIN_REQUIRED_VERSION): if version.parse(crewai_version) < version.parse(MIN_REQUIRED_VERSION):
@@ -48,9 +48,8 @@ def check_conversational_crews_version(
return True return True
def run_chat(): def run_chat() -> None:
""" """Runs an interactive chat loop using the Crew's chat LLM with function calling.
Runs an interactive chat loop using the Crew's chat LLM with function calling.
Incorporates crew_name, crew_description, and input fields to build a tool schema. Incorporates crew_name, crew_description, and input fields to build a tool schema.
Exits if crew_name or crew_description are missing. Exits if crew_name or crew_description are missing.
""" """
@@ -84,7 +83,7 @@ def run_chat():
# Call the LLM to generate the introductory message # Call the LLM to generate the introductory message
introductory_message = chat_llm.call( introductory_message = chat_llm.call(
messages=[{"role": "system", "content": system_message}] messages=[{"role": "system", "content": system_message}],
) )
finally: finally:
# Stop loading indicator # Stop loading indicator
@@ -108,15 +107,13 @@ def run_chat():
chat_loop(chat_llm, messages, crew_tool_schema, available_functions) chat_loop(chat_llm, messages, crew_tool_schema, available_functions)
def show_loading(event: threading.Event): def show_loading(event: threading.Event) -> None:
"""Display animated loading dots while processing.""" """Display animated loading dots while processing."""
while not event.is_set(): while not event.is_set():
print(".", end="", flush=True)
time.sleep(1) time.sleep(1)
print()
def initialize_chat_llm(crew: Crew) -> Optional[LLM | BaseLLM]: def initialize_chat_llm(crew: Crew) -> LLM | BaseLLM | None:
"""Initializes the chat LLM and handles exceptions.""" """Initializes the chat LLM and handles exceptions."""
try: try:
return create_llm(crew.chat_llm) return create_llm(crew.chat_llm)
@@ -157,7 +154,7 @@ def build_system_message(crew_chat_inputs: ChatInputs) -> str:
) )
def create_tool_function(crew: Crew, messages: List[Dict[str, str]]) -> Any: def create_tool_function(crew: Crew, messages: list[dict[str, str]]) -> Any:
"""Creates a wrapper function for running the crew tool with messages.""" """Creates a wrapper function for running the crew tool with messages."""
def run_crew_tool_with_messages(**kwargs): def run_crew_tool_with_messages(**kwargs):
@@ -166,7 +163,7 @@ def create_tool_function(crew: Crew, messages: List[Dict[str, str]]) -> Any:
return run_crew_tool_with_messages return run_crew_tool_with_messages
def flush_input(): def flush_input() -> None:
"""Flush any pending input from the user.""" """Flush any pending input from the user."""
if platform.system() == "Windows": if platform.system() == "Windows":
# Windows platform # Windows platform
@@ -181,7 +178,7 @@ def flush_input():
termios.tcflush(sys.stdin, termios.TCIFLUSH) termios.tcflush(sys.stdin, termios.TCIFLUSH)
def chat_loop(chat_llm, messages, crew_tool_schema, available_functions): def chat_loop(chat_llm, messages, crew_tool_schema, available_functions) -> None:
"""Main chat loop for interacting with the user.""" """Main chat loop for interacting with the user."""
while True: while True:
try: try:
@@ -190,7 +187,7 @@ def chat_loop(chat_llm, messages, crew_tool_schema, available_functions):
user_input = get_user_input() user_input = get_user_input()
handle_user_input( handle_user_input(
user_input, chat_llm, messages, crew_tool_schema, available_functions user_input, chat_llm, messages, crew_tool_schema, available_functions,
) )
except KeyboardInterrupt: except KeyboardInterrupt:
@@ -221,9 +218,9 @@ def get_user_input() -> str:
def handle_user_input( def handle_user_input(
user_input: str, user_input: str,
chat_llm: LLM, chat_llm: LLM,
messages: List[Dict[str, str]], messages: list[dict[str, str]],
crew_tool_schema: Dict[str, Any], crew_tool_schema: dict[str, Any],
available_functions: Dict[str, Any], available_functions: dict[str, Any],
) -> None: ) -> None:
if user_input.strip().lower() == "exit": if user_input.strip().lower() == "exit":
click.echo("Exiting chat. Goodbye!") click.echo("Exiting chat. Goodbye!")
@@ -251,8 +248,7 @@ def handle_user_input(
def generate_crew_tool_schema(crew_inputs: ChatInputs) -> dict: def generate_crew_tool_schema(crew_inputs: ChatInputs) -> dict:
""" """Dynamically build a Littellm 'function' schema for the given crew.
Dynamically build a Littellm 'function' schema for the given crew.
crew_name: The name of the crew (used for the function 'name'). crew_name: The name of the crew (used for the function 'name').
crew_inputs: A ChatInputs object containing crew_description crew_inputs: A ChatInputs object containing crew_description
@@ -281,9 +277,8 @@ def generate_crew_tool_schema(crew_inputs: ChatInputs) -> dict:
} }
def run_crew_tool(crew: Crew, messages: List[Dict[str, str]], **kwargs): def run_crew_tool(crew: Crew, messages: list[dict[str, str]], **kwargs):
""" """Runs the crew using crew.kickoff(inputs=kwargs) and returns the output.
Runs the crew using crew.kickoff(inputs=kwargs) and returns the output.
Args: Args:
crew (Crew): The crew instance to run. crew (Crew): The crew instance to run.
@@ -295,6 +290,7 @@ def run_crew_tool(crew: Crew, messages: List[Dict[str, str]], **kwargs):
Raises: Raises:
SystemExit: Exits the chat if an error occurs during crew execution. SystemExit: Exits the chat if an error occurs during crew execution.
""" """
try: try:
# Serialize 'messages' to JSON string before adding to kwargs # Serialize 'messages' to JSON string before adding to kwargs
@@ -304,9 +300,8 @@ def run_crew_tool(crew: Crew, messages: List[Dict[str, str]], **kwargs):
crew_output = crew.kickoff(inputs=kwargs) crew_output = crew.kickoff(inputs=kwargs)
# Convert CrewOutput to a string to send back to the user # Convert CrewOutput to a string to send back to the user
result = str(crew_output) return str(crew_output)
return result
except Exception as e: except Exception as e:
# Exit the chat and show the error message # Exit the chat and show the error message
click.secho("An error occurred while running the crew:", fg="red") click.secho("An error occurred while running the crew:", fg="red")
@@ -314,12 +309,12 @@ def run_crew_tool(crew: Crew, messages: List[Dict[str, str]], **kwargs):
sys.exit(1) sys.exit(1)
def load_crew_and_name() -> Tuple[Crew, str]: def load_crew_and_name() -> tuple[Crew, str]:
""" """Loads the crew by importing the crew class from the user's project.
Loads the crew by importing the crew class from the user's project.
Returns: Returns:
Tuple[Crew, str]: A tuple containing the Crew instance and the name of the crew. Tuple[Crew, str]: A tuple containing the Crew instance and the name of the crew.
""" """
# Get the current working directory # Get the current working directory
cwd = Path.cwd() cwd = Path.cwd()
@@ -327,7 +322,8 @@ def load_crew_and_name() -> Tuple[Crew, str]:
# Path to the pyproject.toml file # Path to the pyproject.toml file
pyproject_path = cwd / "pyproject.toml" pyproject_path = cwd / "pyproject.toml"
if not pyproject_path.exists(): if not pyproject_path.exists():
raise FileNotFoundError("pyproject.toml not found in the current directory.") msg = "pyproject.toml not found in the current directory."
raise FileNotFoundError(msg)
# Load the pyproject.toml file using 'tomli' # Load the pyproject.toml file using 'tomli'
with pyproject_path.open("rb") as f: with pyproject_path.open("rb") as f:
@@ -351,14 +347,16 @@ def load_crew_and_name() -> Tuple[Crew, str]:
try: try:
crew_module = __import__(crew_module_name, fromlist=[crew_class_name]) crew_module = __import__(crew_module_name, fromlist=[crew_class_name])
except ImportError as e: except ImportError as e:
raise ImportError(f"Failed to import crew module {crew_module_name}: {e}") msg = f"Failed to import crew module {crew_module_name}: {e}"
raise ImportError(msg)
# Get the crew class from the module # Get the crew class from the module
try: try:
crew_class = getattr(crew_module, crew_class_name) crew_class = getattr(crew_module, crew_class_name)
except AttributeError: except AttributeError:
msg = f"Crew class {crew_class_name} not found in module {crew_module_name}"
raise AttributeError( raise AttributeError(
f"Crew class {crew_class_name} not found in module {crew_module_name}" msg,
) )
# Instantiate the crew # Instantiate the crew
@@ -367,8 +365,7 @@ def load_crew_and_name() -> Tuple[Crew, str]:
def generate_crew_chat_inputs(crew: Crew, crew_name: str, chat_llm) -> ChatInputs: def generate_crew_chat_inputs(crew: Crew, crew_name: str, chat_llm) -> ChatInputs:
""" """Generates the ChatInputs required for the crew by analyzing the tasks and agents.
Generates the ChatInputs required for the crew by analyzing the tasks and agents.
Args: Args:
crew (Crew): The crew object containing tasks and agents. crew (Crew): The crew object containing tasks and agents.
@@ -377,6 +374,7 @@ def generate_crew_chat_inputs(crew: Crew, crew_name: str, chat_llm) -> ChatInput
Returns: Returns:
ChatInputs: An object containing the crew's name, description, and input fields. ChatInputs: An object containing the crew's name, description, and input fields.
""" """
# Extract placeholders from tasks and agents # Extract placeholders from tasks and agents
required_inputs = fetch_required_inputs(crew) required_inputs = fetch_required_inputs(crew)
@@ -391,22 +389,22 @@ def generate_crew_chat_inputs(crew: Crew, crew_name: str, chat_llm) -> ChatInput
crew_description = generate_crew_description_with_ai(crew, chat_llm) crew_description = generate_crew_description_with_ai(crew, chat_llm)
return ChatInputs( return ChatInputs(
crew_name=crew_name, crew_description=crew_description, inputs=input_fields crew_name=crew_name, crew_description=crew_description, inputs=input_fields,
) )
def fetch_required_inputs(crew: Crew) -> Set[str]: def fetch_required_inputs(crew: Crew) -> set[str]:
""" """Extracts placeholders from the crew's tasks and agents.
Extracts placeholders from the crew's tasks and agents.
Args: Args:
crew (Crew): The crew object. crew (Crew): The crew object.
Returns: Returns:
Set[str]: A set of placeholder names. Set[str]: A set of placeholder names.
""" """
placeholder_pattern = re.compile(r"\{(.+?)\}") placeholder_pattern = re.compile(r"\{(.+?)\}")
required_inputs: Set[str] = set() required_inputs: set[str] = set()
# Scan tasks # Scan tasks
for task in crew.tasks: for task in crew.tasks:
@@ -422,8 +420,7 @@ def fetch_required_inputs(crew: Crew) -> Set[str]:
def generate_input_description_with_ai(input_name: str, crew: Crew, chat_llm) -> str: def generate_input_description_with_ai(input_name: str, crew: Crew, chat_llm) -> str:
""" """Generates an input description using AI based on the context of the crew.
Generates an input description using AI based on the context of the crew.
Args: Args:
input_name (str): The name of the input placeholder. input_name (str): The name of the input placeholder.
@@ -432,6 +429,7 @@ def generate_input_description_with_ai(input_name: str, crew: Crew, chat_llm) ->
Returns: Returns:
str: A concise description of the input. str: A concise description of the input.
""" """
# Gather context from tasks and agents where the input is used # Gather context from tasks and agents where the input is used
context_texts = [] context_texts = []
@@ -444,10 +442,10 @@ def generate_input_description_with_ai(input_name: str, crew: Crew, chat_llm) ->
): ):
# Replace placeholders with input names # Replace placeholders with input names
task_description = placeholder_pattern.sub( task_description = placeholder_pattern.sub(
lambda m: m.group(1), task.description or "" lambda m: m.group(1), task.description or "",
) )
expected_output = placeholder_pattern.sub( expected_output = placeholder_pattern.sub(
lambda m: m.group(1), task.expected_output or "" lambda m: m.group(1), task.expected_output or "",
) )
context_texts.append(f"Task Description: {task_description}") context_texts.append(f"Task Description: {task_description}")
context_texts.append(f"Expected Output: {expected_output}") context_texts.append(f"Expected Output: {expected_output}")
@@ -461,7 +459,7 @@ def generate_input_description_with_ai(input_name: str, crew: Crew, chat_llm) ->
agent_role = placeholder_pattern.sub(lambda m: m.group(1), agent.role or "") agent_role = placeholder_pattern.sub(lambda m: m.group(1), agent.role or "")
agent_goal = placeholder_pattern.sub(lambda m: m.group(1), agent.goal or "") agent_goal = placeholder_pattern.sub(lambda m: m.group(1), agent.goal or "")
agent_backstory = placeholder_pattern.sub( agent_backstory = placeholder_pattern.sub(
lambda m: m.group(1), agent.backstory or "" lambda m: m.group(1), agent.backstory or "",
) )
context_texts.append(f"Agent Role: {agent_role}") context_texts.append(f"Agent Role: {agent_role}")
context_texts.append(f"Agent Goal: {agent_goal}") context_texts.append(f"Agent Goal: {agent_goal}")
@@ -470,7 +468,8 @@ def generate_input_description_with_ai(input_name: str, crew: Crew, chat_llm) ->
context = "\n".join(context_texts) context = "\n".join(context_texts)
if not context: if not context:
# If no context is found for the input, raise an exception as per instruction # If no context is found for the input, raise an exception as per instruction
raise ValueError(f"No context found for input '{input_name}'.") msg = f"No context found for input '{input_name}'."
raise ValueError(msg)
prompt = ( prompt = (
f"Based on the following context, write a concise description (15 words or less) of the input '{input_name}'.\n" f"Based on the following context, write a concise description (15 words or less) of the input '{input_name}'.\n"
@@ -479,14 +478,12 @@ def generate_input_description_with_ai(input_name: str, crew: Crew, chat_llm) ->
f"{context}" f"{context}"
) )
response = chat_llm.call(messages=[{"role": "user", "content": prompt}]) response = chat_llm.call(messages=[{"role": "user", "content": prompt}])
description = response.strip() return response.strip()
return description
def generate_crew_description_with_ai(crew: Crew, chat_llm) -> str: def generate_crew_description_with_ai(crew: Crew, chat_llm) -> str:
""" """Generates a brief description of the crew using AI.
Generates a brief description of the crew using AI.
Args: Args:
crew (Crew): The crew object. crew (Crew): The crew object.
@@ -494,6 +491,7 @@ def generate_crew_description_with_ai(crew: Crew, chat_llm) -> str:
Returns: Returns:
str: A concise description of the crew's purpose (15 words or less). str: A concise description of the crew's purpose (15 words or less).
""" """
# Gather context from tasks and agents # Gather context from tasks and agents
context_texts = [] context_texts = []
@@ -502,10 +500,10 @@ def generate_crew_description_with_ai(crew: Crew, chat_llm) -> str:
for task in crew.tasks: for task in crew.tasks:
# Replace placeholders with input names # Replace placeholders with input names
task_description = placeholder_pattern.sub( task_description = placeholder_pattern.sub(
lambda m: m.group(1), task.description or "" lambda m: m.group(1), task.description or "",
) )
expected_output = placeholder_pattern.sub( expected_output = placeholder_pattern.sub(
lambda m: m.group(1), task.expected_output or "" lambda m: m.group(1), task.expected_output or "",
) )
context_texts.append(f"Task Description: {task_description}") context_texts.append(f"Task Description: {task_description}")
context_texts.append(f"Expected Output: {expected_output}") context_texts.append(f"Expected Output: {expected_output}")
@@ -514,7 +512,7 @@ def generate_crew_description_with_ai(crew: Crew, chat_llm) -> str:
agent_role = placeholder_pattern.sub(lambda m: m.group(1), agent.role or "") agent_role = placeholder_pattern.sub(lambda m: m.group(1), agent.role or "")
agent_goal = placeholder_pattern.sub(lambda m: m.group(1), agent.goal or "") agent_goal = placeholder_pattern.sub(lambda m: m.group(1), agent.goal or "")
agent_backstory = placeholder_pattern.sub( agent_backstory = placeholder_pattern.sub(
lambda m: m.group(1), agent.backstory or "" lambda m: m.group(1), agent.backstory or "",
) )
context_texts.append(f"Agent Role: {agent_role}") context_texts.append(f"Agent Role: {agent_role}")
context_texts.append(f"Agent Goal: {agent_goal}") context_texts.append(f"Agent Goal: {agent_goal}")
@@ -522,7 +520,8 @@ def generate_crew_description_with_ai(crew: Crew, chat_llm) -> str:
context = "\n".join(context_texts) context = "\n".join(context_texts)
if not context: if not context:
raise ValueError("No context found for generating crew description.") msg = "No context found for generating crew description."
raise ValueError(msg)
prompt = ( prompt = (
"Based on the following context, write a concise, action-oriented description (15 words or less) of the crew's purpose.\n" "Based on the following context, write a concise, action-oriented description (15 words or less) of the crew's purpose.\n"
@@ -531,6 +530,5 @@ def generate_crew_description_with_ai(crew: Crew, chat_llm) -> str:
f"{context}" f"{context}"
) )
response = chat_llm.call(messages=[{"role": "user", "content": prompt}]) response = chat_llm.call(messages=[{"role": "user", "content": prompt}])
crew_description = response.strip() return response.strip()
return crew_description
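Two rewrites recur through this file: exception messages are hoisted into a msg variable before the raise (plausibly Ruff's EM101/EM102, which keeps long literals off the traceback's raise line), and assign-then-return collapses to a direct return. A sketch of the raise pattern:

def load_before(path: str) -> None:
    raise FileNotFoundError(f"{path} not found in the current directory.")


def load_after(path: str) -> None:
    msg = f"{path} not found in the current directory."
    raise FileNotFoundError(msg)  # the raise line stays short and stable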
@@ -1,4 +1,4 @@
from typing import Any, Dict, List, Optional from typing import Any
from rich.console import Console from rich.console import Console
@@ -10,34 +10,27 @@ console = Console()
class DeployCommand(BaseCommand, PlusAPIMixin): class DeployCommand(BaseCommand, PlusAPIMixin):
""" """A class to handle deployment-related operations for CrewAI projects."""
A class to handle deployment-related operations for CrewAI projects.
"""
def __init__(self):
"""
Initialize the DeployCommand with project name and API client.
"""
def __init__(self) -> None:
"""Initialize the DeployCommand with project name and API client."""
BaseCommand.__init__(self) BaseCommand.__init__(self)
PlusAPIMixin.__init__(self, telemetry=self._telemetry) PlusAPIMixin.__init__(self, telemetry=self._telemetry)
self.project_name = get_project_name(require=True) self.project_name = get_project_name(require=True)
def _standard_no_param_error_message(self) -> None: def _standard_no_param_error_message(self) -> None:
""" """Display a standard error message when no UUID or project name is available."""
Display a standard error message when no UUID or project name is available.
"""
console.print( console.print(
"No UUID provided, project pyproject.toml not found or with error.", "No UUID provided, project pyproject.toml not found or with error.",
style="bold red", style="bold red",
) )
def _display_deployment_info(self, json_response: Dict[str, Any]) -> None: def _display_deployment_info(self, json_response: dict[str, Any]) -> None:
""" """Display deployment information.
Display deployment information.
Args: Args:
json_response (Dict[str, Any]): The deployment information to display. json_response (Dict[str, Any]): The deployment information to display.
""" """
console.print("Deploying the crew...\n", style="bold blue") console.print("Deploying the crew...\n", style="bold blue")
for key, value in json_response.items(): for key, value in json_response.items():
@@ -47,24 +40,24 @@ class DeployCommand(BaseCommand, PlusAPIMixin):
console.print(" or") console.print(" or")
console.print(f"crewai deploy status --uuid \"{json_response['uuid']}\"") console.print(f"crewai deploy status --uuid \"{json_response['uuid']}\"")
def _display_logs(self, log_messages: List[Dict[str, Any]]) -> None: def _display_logs(self, log_messages: list[dict[str, Any]]) -> None:
""" """Display log messages.
Display log messages.
Args: Args:
log_messages (List[Dict[str, Any]]): The log messages to display. log_messages (List[Dict[str, Any]]): The log messages to display.
""" """
for log_message in log_messages: for log_message in log_messages:
console.print( console.print(
f"{log_message['timestamp']} - {log_message['level']}: {log_message['message']}" f"{log_message['timestamp']} - {log_message['level']}: {log_message['message']}",
) )
def deploy(self, uuid: Optional[str] = None) -> None: def deploy(self, uuid: str | None = None) -> None:
""" """Deploy a crew using either UUID or project name.
Deploy a crew using either UUID or project name.
Args: Args:
uuid (Optional[str]): The UUID of the crew to deploy. uuid (Optional[str]): The UUID of the crew to deploy.
""" """
self._start_deployment_span = self._telemetry.start_deployment_span(uuid) self._start_deployment_span = self._telemetry.start_deployment_span(uuid)
console.print("Starting deployment...", style="bold blue") console.print("Starting deployment...", style="bold blue")
@@ -80,9 +73,7 @@ class DeployCommand(BaseCommand, PlusAPIMixin):
self._display_deployment_info(response.json()) self._display_deployment_info(response.json())
def create_crew(self, confirm: bool = False) -> None: def create_crew(self, confirm: bool = False) -> None:
""" """Create a new crew deployment."""
Create a new crew deployment.
"""
self._create_crew_deployment_span = ( self._create_crew_deployment_span = (
self._telemetry.create_crew_deployment_span() self._telemetry.create_crew_deployment_span()
) )
@@ -110,29 +101,28 @@ class DeployCommand(BaseCommand, PlusAPIMixin):
self._display_creation_success(response.json()) self._display_creation_success(response.json())
def _confirm_input( def _confirm_input(
self, env_vars: Dict[str, str], remote_repo_url: str, confirm: bool self, env_vars: dict[str, str], remote_repo_url: str, confirm: bool,
) -> None: ) -> None:
""" """Confirm input parameters with the user.
Confirm input parameters with the user.
Args: Args:
env_vars (Dict[str, str]): Environment variables. env_vars (Dict[str, str]): Environment variables.
remote_repo_url (str): Remote repository URL. remote_repo_url (str): Remote repository URL.
confirm (bool): Whether to confirm input. confirm (bool): Whether to confirm input.
""" """
if not confirm: if not confirm:
input(f"Press Enter to continue with the following Env vars: {env_vars}") input(f"Press Enter to continue with the following Env vars: {env_vars}")
input( input(
f"Press Enter to continue with the following remote repository: {remote_repo_url}\n" f"Press Enter to continue with the following remote repository: {remote_repo_url}\n",
) )
def _create_payload( def _create_payload(
self, self,
env_vars: Dict[str, str], env_vars: dict[str, str],
remote_repo_url: str, remote_repo_url: str,
) -> Dict[str, Any]: ) -> dict[str, Any]:
""" """Create the payload for crew creation.
Create the payload for crew creation.
Args: Args:
remote_repo_url (str): Remote repository URL. remote_repo_url (str): Remote repository URL.
@@ -140,25 +130,26 @@ class DeployCommand(BaseCommand, PlusAPIMixin):
Returns: Returns:
Dict[str, Any]: The payload for crew creation. Dict[str, Any]: The payload for crew creation.
""" """
return { return {
"deploy": { "deploy": {
"name": self.project_name, "name": self.project_name,
"repo_clone_url": remote_repo_url, "repo_clone_url": remote_repo_url,
"env": env_vars, "env": env_vars,
} },
} }
def _display_creation_success(self, json_response: Dict[str, Any]) -> None: def _display_creation_success(self, json_response: dict[str, Any]) -> None:
""" """Display success message after crew creation.
Display success message after crew creation.
Args: Args:
json_response (Dict[str, Any]): The response containing crew information. json_response (Dict[str, Any]): The response containing crew information.
""" """
console.print("Deployment created successfully!\n", style="bold green") console.print("Deployment created successfully!\n", style="bold green")
console.print( console.print(
f"Name: {self.project_name} ({json_response['uuid']})", style="bold green" f"Name: {self.project_name} ({json_response['uuid']})", style="bold green",
) )
console.print(f"Status: {json_response['status']}", style="bold green") console.print(f"Status: {json_response['status']}", style="bold green")
console.print("\nTo (re)deploy the crew, run:") console.print("\nTo (re)deploy the crew, run:")
@@ -167,9 +158,7 @@ class DeployCommand(BaseCommand, PlusAPIMixin):
console.print(f"crewai deploy push --uuid {json_response['uuid']}") console.print(f"crewai deploy push --uuid {json_response['uuid']}")
def list_crews(self) -> None: def list_crews(self) -> None:
""" """List all available crews."""
List all available crews.
"""
console.print("Listing all Crews\n", style="bold blue") console.print("Listing all Crews\n", style="bold blue")
response = self.plus_api_client.list_crews() response = self.plus_api_client.list_crews()
@@ -179,31 +168,29 @@ class DeployCommand(BaseCommand, PlusAPIMixin):
else: else:
self._display_no_crews_message() self._display_no_crews_message()
def _display_crews(self, crews_data: List[Dict[str, Any]]) -> None: def _display_crews(self, crews_data: list[dict[str, Any]]) -> None:
""" """Display the list of crews.
Display the list of crews.
Args: Args:
crews_data (List[Dict[str, Any]]): List of crew data to display. crews_data (List[Dict[str, Any]]): List of crew data to display.
""" """
for crew_data in crews_data: for crew_data in crews_data:
console.print( console.print(
f"- {crew_data['name']} ({crew_data['uuid']}) [blue]{crew_data['status']}[/blue]" f"- {crew_data['name']} ({crew_data['uuid']}) [blue]{crew_data['status']}[/blue]",
) )
def _display_no_crews_message(self) -> None: def _display_no_crews_message(self) -> None:
""" """Display a message when no crews are available."""
Display a message when no crews are available.
"""
console.print("You don't have any Crews yet. Let's create one!", style="yellow") console.print("You don't have any Crews yet. Let's create one!", style="yellow")
console.print(" crewai create crew <crew_name>", style="green") console.print(" crewai create crew <crew_name>", style="green")
def get_crew_status(self, uuid: Optional[str] = None) -> None: def get_crew_status(self, uuid: str | None = None) -> None:
""" """Get the status of a crew.
Get the status of a crew.
Args: Args:
uuid (Optional[str]): The UUID of the crew to check. uuid (Optional[str]): The UUID of the crew to check.
""" """
console.print("Fetching deployment status...", style="bold blue") console.print("Fetching deployment status...", style="bold blue")
if uuid: if uuid:
@@ -217,23 +204,23 @@ class DeployCommand(BaseCommand, PlusAPIMixin):
self._validate_response(response) self._validate_response(response)
self._display_crew_status(response.json()) self._display_crew_status(response.json())
def _display_crew_status(self, status_data: Dict[str, str]) -> None: def _display_crew_status(self, status_data: dict[str, str]) -> None:
""" """Display the status of a crew.
Display the status of a crew.
Args: Args:
status_data (Dict[str, str]): The status data to display. status_data (Dict[str, str]): The status data to display.
""" """
console.print(f"Name:\t {status_data['name']}") console.print(f"Name:\t {status_data['name']}")
console.print(f"Status:\t {status_data['status']}") console.print(f"Status:\t {status_data['status']}")
def get_crew_logs(self, uuid: Optional[str], log_type: str = "deployment") -> None: def get_crew_logs(self, uuid: str | None, log_type: str = "deployment") -> None:
""" """Get logs for a crew.
Get logs for a crew.
Args: Args:
uuid (Optional[str]): The UUID of the crew to get logs for. uuid (Optional[str]): The UUID of the crew to get logs for.
log_type (str): The type of logs to retrieve (default: "deployment"). log_type (str): The type of logs to retrieve (default: "deployment").
""" """
self._get_crew_logs_span = self._telemetry.get_crew_logs_span(uuid, log_type) self._get_crew_logs_span = self._telemetry.get_crew_logs_span(uuid, log_type)
console.print(f"Fetching {log_type} logs...", style="bold blue") console.print(f"Fetching {log_type} logs...", style="bold blue")
@@ -249,12 +236,12 @@ class DeployCommand(BaseCommand, PlusAPIMixin):
self._validate_response(response) self._validate_response(response)
self._display_logs(response.json()) self._display_logs(response.json())
def remove_crew(self, uuid: Optional[str]) -> None: def remove_crew(self, uuid: str | None) -> None:
""" """Remove a crew deployment.
Remove a crew deployment.
Args: Args:
uuid (Optional[str]): The UUID of the crew to remove. uuid (Optional[str]): The UUID of the crew to remove.
""" """
self._remove_crew_span = self._telemetry.remove_crew_span(uuid) self._remove_crew_span = self._telemetry.remove_crew_span(uuid)
console.print("Removing deployment...", style="bold blue") console.print("Removing deployment...", style="bold blue")
@@ -269,9 +256,9 @@ class DeployCommand(BaseCommand, PlusAPIMixin):
if response.status_code == 204: if response.status_code == 204:
console.print( console.print(
f"Crew '{self.project_name}' removed successfully.", style="green" f"Crew '{self.project_name}' removed successfully.", style="green",
) )
else: else:
console.print( console.print(
f"Failed to remove crew '{self.project_name}'", style="bold red" f"Failed to remove crew '{self.project_name}'", style="bold red",
) )
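The signatures in this class swap typing.Dict/List/Optional for builtin generics and unions (plausibly Ruff's UP006/UP035 on Python 3.9+), which is why the typing import shrinks to Any alone. A sketch:

from typing import Any, Dict, List  # before: Dict and List still imported


def display_logs_before(log_messages: List[Dict[str, Any]]) -> None:
    for message in log_messages:
        print(f"{message['timestamp']} - {message['level']}: {message['message']}")


def display_logs_after(log_messages: list[dict[str, Any]]) -> None:
    for message in log_messages:
        print(f"{message['timestamp']} - {message['level']}: {message['message']}")


display_logs_after([{"timestamp": "2025-05-12", "level": "INFO", "message": "deployed"}])  # illustrative payload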
@@ -4,18 +4,19 @@ import click
def evaluate_crew(n_iterations: int, model: str) -> None: def evaluate_crew(n_iterations: int, model: str) -> None:
""" """Test and Evaluate the crew by running a command in the UV environment.
Test and Evaluate the crew by running a command in the UV environment.
Args: Args:
n_iterations (int): The number of iterations to test the crew. n_iterations (int): The number of iterations to test the crew.
model (str): The model to test the crew with. model (str): The model to test the crew with.
""" """
command = ["uv", "run", "test", str(n_iterations), model] command = ["uv", "run", "test", str(n_iterations), model]
try: try:
if n_iterations <= 0: if n_iterations <= 0:
raise ValueError("The number of iterations must be a positive integer.") msg = "The number of iterations must be a positive integer."
raise ValueError(msg)
result = subprocess.run(command, capture_output=False, text=True, check=True) result = subprocess.run(command, capture_output=False, text=True, check=True)
@@ -1,16 +1,18 @@
import subprocess import subprocess
from functools import lru_cache from functools import cache
class Repository: class Repository:
def __init__(self, path="."): def __init__(self, path=".") -> None:
self.path = path self.path = path
if not self.is_git_installed(): if not self.is_git_installed():
raise ValueError("Git is not installed or not found in your PATH.") msg = "Git is not installed or not found in your PATH."
raise ValueError(msg)
if not self.is_git_repo(): if not self.is_git_repo():
raise ValueError(f"{self.path} is not a Git repository.") msg = f"{self.path} is not a Git repository."
raise ValueError(msg)
self.fetch() self.fetch()
@@ -18,7 +20,7 @@ class Repository:
"""Check if Git is installed and available in the system.""" """Check if Git is installed and available in the system."""
try: try:
subprocess.run( subprocess.run(
["git", "--version"], capture_output=True, check=True, text=True ["git", "--version"], capture_output=True, check=True, text=True,
) )
return True return True
except (subprocess.CalledProcessError, FileNotFoundError): except (subprocess.CalledProcessError, FileNotFoundError):
@@ -36,7 +38,7 @@ class Repository:
encoding="utf-8", encoding="utf-8",
).strip() ).strip()
@lru_cache(maxsize=None) @cache
def is_git_repo(self) -> bool: def is_git_repo(self) -> bool:
"""Check if the current directory is a git repository.""" """Check if the current directory is a git repository."""
try: try:
@@ -62,10 +64,7 @@ class Repository:
def is_synced(self) -> bool: def is_synced(self) -> bool:
"""Return True if the Git repository is fully synced with the remote, False otherwise.""" """Return True if the Git repository is fully synced with the remote, False otherwise."""
if self.has_uncommitted_changes() or self.is_ahead_or_behind(): return not (self.has_uncommitted_changes() or self.is_ahead_or_behind())
return False
else:
return True
def origin_url(self) -> str | None: def origin_url(self) -> str | None:
"""Get the Git repository's remote URL.""" """Get the Git repository's remote URL."""
@@ -8,11 +8,9 @@ import click
# so if you expect this to support more things you will need to replicate it there # so if you expect this to support more things you will need to replicate it there
# ask @joaomdmoura if you are unsure # ask @joaomdmoura if you are unsure
def install_crew(proxy_options: list[str]) -> None: def install_crew(proxy_options: list[str]) -> None:
""" """Install the crew by running the UV command to lock and install."""
Install the crew by running the UV command to lock and install.
"""
try: try:
command = ["uv", "sync"] + proxy_options command = ["uv", "sync", *proxy_options]
subprocess.run(command, check=True, capture_output=False, text=True) subprocess.run(command, check=True, capture_output=False, text=True)
except subprocess.CalledProcessError as e: except subprocess.CalledProcessError as e:
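The command list above switches from concatenation to in-literal unpacking (plausibly Ruff's RUF005); the resulting list is identical. A sketch with hypothetical proxy options:

proxy_options = ["--index-url", "https://proxy.example/simple"]  # hypothetical values

command_before = ["uv", "sync"] + proxy_options
command_after = ["uv", "sync", *proxy_options]

assert command_before == command_after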
@@ -4,9 +4,7 @@ import click
def kickoff_flow() -> None: def kickoff_flow() -> None:
""" """Kickoff the flow by running a command in the UV environment."""
Kickoff the flow by running a command in the UV environment.
"""
command = ["uv", "run", "kickoff"] command = ["uv", "run", "kickoff"]
try: try:
@@ -4,9 +4,7 @@ import click
def plot_flow() -> None: def plot_flow() -> None:
""" """Plot the flow by running a command in the UV environment."""
Plot the flow by running a command in the UV environment.
"""
command = ["uv", "run", "plot"] command = ["uv", "run", "plot"]
try: try:
@@ -1,5 +1,4 @@
from os import getenv from os import getenv
from typing import Optional
from urllib.parse import urljoin from urllib.parse import urljoin
import requests import requests
@@ -8,9 +7,7 @@ from crewai.cli.version import get_crewai_version
class PlusAPI: class PlusAPI:
""" """This class exposes methods for working with the CrewAI+ API."""
This class exposes methods for working with the CrewAI+ API.
"""
TOOLS_RESOURCE = "/crewai_plus/api/v1/tools" TOOLS_RESOURCE = "/crewai_plus/api/v1/tools"
CREWS_RESOURCE = "/crewai_plus/api/v1/crews" CREWS_RESOURCE = "/crewai_plus/api/v1/crews"
@@ -42,7 +39,7 @@ class PlusAPI:
handle: str, handle: str,
is_public: bool, is_public: bool,
version: str, version: str,
description: Optional[str], description: str | None,
encoded_file: str, encoded_file: str,
): ):
params = { params = {
@@ -56,7 +53,7 @@ class PlusAPI:
def deploy_by_name(self, project_name: str) -> requests.Response: def deploy_by_name(self, project_name: str) -> requests.Response:
return self._make_request( return self._make_request(
"POST", f"{self.CREWS_RESOURCE}/by-name/{project_name}/deploy" "POST", f"{self.CREWS_RESOURCE}/by-name/{project_name}/deploy",
) )
def deploy_by_uuid(self, uuid: str) -> requests.Response: def deploy_by_uuid(self, uuid: str) -> requests.Response:
@@ -64,29 +61,29 @@ class PlusAPI:
def crew_status_by_name(self, project_name: str) -> requests.Response: def crew_status_by_name(self, project_name: str) -> requests.Response:
return self._make_request( return self._make_request(
"GET", f"{self.CREWS_RESOURCE}/by-name/{project_name}/status" "GET", f"{self.CREWS_RESOURCE}/by-name/{project_name}/status",
) )
def crew_status_by_uuid(self, uuid: str) -> requests.Response: def crew_status_by_uuid(self, uuid: str) -> requests.Response:
return self._make_request("GET", f"{self.CREWS_RESOURCE}/{uuid}/status") return self._make_request("GET", f"{self.CREWS_RESOURCE}/{uuid}/status")
def crew_by_name( def crew_by_name(
self, project_name: str, log_type: str = "deployment" self, project_name: str, log_type: str = "deployment",
) -> requests.Response: ) -> requests.Response:
return self._make_request( return self._make_request(
"GET", f"{self.CREWS_RESOURCE}/by-name/{project_name}/logs/{log_type}" "GET", f"{self.CREWS_RESOURCE}/by-name/{project_name}/logs/{log_type}",
) )
def crew_by_uuid( def crew_by_uuid(
self, uuid: str, log_type: str = "deployment" self, uuid: str, log_type: str = "deployment",
) -> requests.Response: ) -> requests.Response:
return self._make_request( return self._make_request(
"GET", f"{self.CREWS_RESOURCE}/{uuid}/logs/{log_type}" "GET", f"{self.CREWS_RESOURCE}/{uuid}/logs/{log_type}",
) )
def delete_crew_by_name(self, project_name: str) -> requests.Response: def delete_crew_by_name(self, project_name: str) -> requests.Response:
return self._make_request( return self._make_request(
"DELETE", f"{self.CREWS_RESOURCE}/by-name/{project_name}" "DELETE", f"{self.CREWS_RESOURCE}/by-name/{project_name}",
) )
def delete_crew_by_uuid(self, uuid: str) -> requests.Response: def delete_crew_by_uuid(self, uuid: str) -> requests.Response:
@@ -10,8 +10,7 @@ from crewai.cli.constants import JSON_URL, MODELS, PROVIDERS
def select_choice(prompt_message, choices): def select_choice(prompt_message, choices):
""" """Presents a list of choices to the user and prompts them to select one.
Presents a list of choices to the user and prompts them to select one.
Args: Args:
- prompt_message (str): The message to display to the user before presenting the choices. - prompt_message (str): The message to display to the user before presenting the choices.
@@ -19,11 +18,11 @@ def select_choice(prompt_message, choices):
Returns: Returns:
- str: The selected choice from the list, or None if the user chooses to quit. - str: The selected choice from the list, or None if the user chooses to quit.
"""
"""
provider_models = get_provider_data() provider_models = get_provider_data()
if not provider_models: if not provider_models:
return return None
click.secho(prompt_message, fg="cyan") click.secho(prompt_message, fg="cyan")
for idx, choice in enumerate(choices, start=1): for idx, choice in enumerate(choices, start=1):
click.secho(f"{idx}. {choice}", fg="cyan") click.secho(f"{idx}. {choice}", fg="cyan")
@@ -31,7 +30,7 @@ def select_choice(prompt_message, choices):
while True: while True:
choice = click.prompt( choice = click.prompt(
"Enter the number of your choice or 'q' to quit", type=str "Enter the number of your choice or 'q' to quit", type=str,
) )
if choice.lower() == "q": if choice.lower() == "q":
@@ -51,8 +50,7 @@ def select_choice(prompt_message, choices):
def select_provider(provider_models): def select_provider(provider_models):
""" """Presents a list of providers to the user and prompts them to select one.
Presents a list of providers to the user and prompts them to select one.
Args: Args:
- provider_models (dict): A dictionary of provider models. - provider_models (dict): A dictionary of provider models.
@@ -60,12 +58,13 @@ def select_provider(provider_models):
Returns: Returns:
- str: The selected provider - str: The selected provider
- None: If user explicitly quits - None: If user explicitly quits
""" """
predefined_providers = [p.lower() for p in PROVIDERS] predefined_providers = [p.lower() for p in PROVIDERS]
all_providers = sorted(set(predefined_providers + list(provider_models.keys()))) all_providers = sorted(set(predefined_providers + list(provider_models.keys())))
provider = select_choice( provider = select_choice(
"Select a provider to set up:", predefined_providers + ["other"] "Select a provider to set up:", [*predefined_providers, "other"],
) )
if provider is None: # User typed 'q' if provider is None: # User typed 'q'
return None return None
@@ -79,8 +78,7 @@ def select_provider(provider_models):
def select_model(provider, provider_models): def select_model(provider, provider_models):
""" """Presents a list of models for a given provider to the user and prompts them to select one.
Presents a list of models for a given provider to the user and prompts them to select one.
Args: Args:
- provider (str): The provider for which to select a model. - provider (str): The provider for which to select a model.
@@ -88,6 +86,7 @@ def select_model(provider, provider_models):
Returns: Returns:
- str: The selected model, or None if the operation is aborted or an invalid selection is made. - str: The selected model, or None if the operation is aborted or an invalid selection is made.
""" """
predefined_providers = [p.lower() for p in PROVIDERS] predefined_providers = [p.lower() for p in PROVIDERS]
@@ -100,15 +99,13 @@ def select_model(provider, provider_models):
click.secho(f"No models available for provider '{provider}'.", fg="red") click.secho(f"No models available for provider '{provider}'.", fg="red")
return None return None
selected_model = select_choice( return select_choice(
f"Select a model to use for {provider.capitalize()}:", available_models f"Select a model to use for {provider.capitalize()}:", available_models,
) )
return selected_model
def load_provider_data(cache_file, cache_expiry): def load_provider_data(cache_file, cache_expiry):
""" """Loads provider data from a cache file if it exists and is not expired. If the cache is expired or corrupted, it fetches the data from the web.
Loads provider data from a cache file if it exists and is not expired. If the cache is expired or corrupted, it fetches the data from the web.
Args: Args:
- cache_file (Path): The path to the cache file. - cache_file (Path): The path to the cache file.
@@ -116,6 +113,7 @@ def load_provider_data(cache_file, cache_expiry):
Returns: Returns:
- dict or None: The loaded provider data or None if the operation fails. - dict or None: The loaded provider data or None if the operation fails.
""" """
current_time = time.time() current_time = time.time()
if ( if (
@@ -126,7 +124,7 @@ def load_provider_data(cache_file, cache_expiry):
if data: if data:
return data return data
click.secho( click.secho(
"Cache is corrupted. Fetching provider data from the web...", fg="yellow" "Cache is corrupted. Fetching provider data from the web...", fg="yellow",
) )
else: else:
click.secho( click.secho(
@@ -137,31 +135,31 @@ def load_provider_data(cache_file, cache_expiry):
def read_cache_file(cache_file): def read_cache_file(cache_file):
""" """Reads and returns the JSON content from a cache file. Returns None if the file contains invalid JSON.
Reads and returns the JSON content from a cache file. Returns None if the file contains invalid JSON.
Args: Args:
- cache_file (Path): The path to the cache file. - cache_file (Path): The path to the cache file.
Returns: Returns:
- dict or None: The JSON content of the cache file or None if the JSON is invalid. - dict or None: The JSON content of the cache file or None if the JSON is invalid.
""" """
try: try:
with open(cache_file, "r") as f: with open(cache_file) as f:
return json.load(f) return json.load(f)
except json.JSONDecodeError: except json.JSONDecodeError:
return None return None
def fetch_provider_data(cache_file): def fetch_provider_data(cache_file):
""" """Fetches provider data from a specified URL and caches it to a file.
Fetches provider data from a specified URL and caches it to a file.
Args: Args:
- cache_file (Path): The path to the cache file. - cache_file (Path): The path to the cache file.
Returns: Returns:
- dict or None: The fetched provider data or None if the operation fails. - dict or None: The fetched provider data or None if the operation fails.
""" """
try: try:
response = requests.get(JSON_URL, stream=True, timeout=60) response = requests.get(JSON_URL, stream=True, timeout=60)
@@ -178,20 +176,20 @@ def fetch_provider_data(cache_file):
def download_data(response): def download_data(response):
""" """Downloads data from a given HTTP response and returns the JSON content.
Downloads data from a given HTTP response and returns the JSON content.
Args: Args:
- response (requests.Response): The HTTP response object. - response (requests.Response): The HTTP response object.
Returns: Returns:
- dict: The JSON content of the response. - dict: The JSON content of the response.
""" """
total_size = int(response.headers.get("content-length", 0)) total_size = int(response.headers.get("content-length", 0))
block_size = 8192 block_size = 8192
data_chunks = [] data_chunks = []
with click.progressbar( with click.progressbar(
length=total_size, label="Downloading", show_pos=True length=total_size, label="Downloading", show_pos=True,
) as progress_bar: ) as progress_bar:
for chunk in response.iter_content(block_size): for chunk in response.iter_content(block_size):
if chunk: if chunk:
@@ -202,11 +200,11 @@ def download_data(response):
def get_provider_data(): def get_provider_data():
""" """Retrieves provider data from a cache file, filters out models based on provider criteria, and returns a dictionary of providers mapped to their models.
Retrieves provider data from a cache file, filters out models based on provider criteria, and returns a dictionary of providers mapped to their models.
Returns: Returns:
- dict or None: A dictionary of providers mapped to their models or None if the operation fails. - dict or None: A dictionary of providers mapped to their models or None if the operation fails.
""" """
cache_dir = Path.home() / ".crewai" cache_dir = Path.home() / ".crewai"
cache_dir.mkdir(exist_ok=True) cache_dir.mkdir(exist_ok=True)
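The return-path fixes in this file make None explicit where other branches return a value, and inline variables assigned only to be returned (plausibly Ruff's RET502 and RET504). A sketch:

def pick_before(choices: list[str]):
    if not choices:
        return  # before: implicit None in a value-returning function
    selected = choices[0]
    return selected


def pick_after(choices: list[str]) -> str | None:
    if not choices:
        return None
    return choices[0]


assert pick_before([]) is None and pick_after([]) is None
assert pick_before(["gpt-4o"]) == pick_after(["gpt-4o"]) == "gpt-4o"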
@@ -4,11 +4,11 @@ import click
def replay_task_command(task_id: str) -> None: def replay_task_command(task_id: str) -> None:
""" """Replay the crew execution from a specific task.
Replay the crew execution from a specific task.
Args: Args:
task_id (str): The ID of the task to replay from. task_id (str): The ID of the task to replay from.
""" """
command = ["uv", "run", "replay", task_id] command = ["uv", "run", "replay", task_id]
@@ -13,8 +13,7 @@ def reset_memories_command(
    kickoff_outputs,
    all,
) -> None:
-    """
-    Reset the crew memories.
+    """Reset the crew memories.
    Args:
        long (bool): Whether to reset the long-term memory.
@@ -23,49 +22,50 @@ def reset_memories_command(
        kickoff_outputs (bool): Whether to reset the latest kickoff task outputs.
        all (bool): Whether to reset all memories.
        knowledge (bool): Whether to reset the knowledge.
+
    """
    try:
        if not any([long, short, entity, kickoff_outputs, knowledge, all]):
            click.echo(
-                "No memory type specified. Please specify at least one type to reset."
+                "No memory type specified. Please specify at least one type to reset.",
            )
            return
        crews = get_crews()
        if not crews:
-            raise ValueError("No crew found.")
+            msg = "No crew found."
+            raise ValueError(msg)
        for crew in crews:
            if all:
                crew.reset_memories(command_type="all")
                click.echo(
-                    f"[Crew ({crew.name if crew.name else crew.id})] Reset memories command has been completed."
+                    f"[Crew ({crew.name if crew.name else crew.id})] Reset memories command has been completed.",
                )
                continue
            if long:
                crew.reset_memories(command_type="long")
                click.echo(
-                    f"[Crew ({crew.name if crew.name else crew.id})] Long term memory has been reset."
+                    f"[Crew ({crew.name if crew.name else crew.id})] Long term memory has been reset.",
                )
            if short:
                crew.reset_memories(command_type="short")
                click.echo(
-                    f"[Crew ({crew.name if crew.name else crew.id})] Short term memory has been reset."
+                    f"[Crew ({crew.name if crew.name else crew.id})] Short term memory has been reset.",
                )
            if entity:
                crew.reset_memories(command_type="entity")
                click.echo(
-                    f"[Crew ({crew.name if crew.name else crew.id})] Entity memory has been reset."
+                    f"[Crew ({crew.name if crew.name else crew.id})] Entity memory has been reset.",
                )
            if kickoff_outputs:
                crew.reset_memories(command_type="kickoff_outputs")
                click.echo(
-                    f"[Crew ({crew.name if crew.name else crew.id})] Latest Kickoff outputs stored has been reset."
+                    f"[Crew ({crew.name if crew.name else crew.id})] Latest Kickoff outputs stored has been reset.",
                )
            if knowledge:
                crew.reset_memories(command_type="knowledge")
                click.echo(
-                    f"[Crew ({crew.name if crew.name else crew.id})] Knowledge has been reset."
+                    f"[Crew ({crew.name if crew.name else crew.id})] Knowledge has been reset.",
                )
    except subprocess.CalledProcessError as e:
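The `msg = ...` rewrites above follow Ruff's EM101/EM102 rules: the exception message is assigned to a variable before raising, so the raise line stays short and the traceback doesn't repeat a long string literal. A minimal sketch of the pattern, with hypothetical names:

def require_crews(crews: list) -> None:
    # Assign the message first, then raise (Ruff EM101); the traceback
    # shows `raise ValueError(msg)` instead of duplicating the literal.
    if not crews:
        msg = "No crew found."
        raise ValueError(msg)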


@@ -1,6 +1,5 @@
import subprocess
from enum import Enum
-from typing import List, Optional
import click
from packaging import version
@@ -15,8 +14,7 @@ class CrewType(Enum):
def run_crew() -> None:
-    """
-    Run the crew or flow by running a command in the UV environment.
+    """Run the crew or flow by running a command in the UV environment.
    Starting from version 0.103.0, this command can be used to run both
    standard crews and flows. For flows, it detects the type from pyproject.toml
@@ -48,11 +46,11 @@ def run_crew() -> None:
def execute_command(crew_type: CrewType) -> None:
-    """
-    Execute the appropriate command based on crew type.
+    """Execute the appropriate command based on crew type.
    Args:
        crew_type: The type of crew to run
+
    """
    command = ["uv", "run", "kickoff" if crew_type == CrewType.FLOW else "run_crew"]
@@ -67,12 +65,12 @@ def execute_command(crew_type: CrewType) -> None:
def handle_error(error: subprocess.CalledProcessError, crew_type: CrewType) -> None:
-    """
-    Handle subprocess errors with appropriate messaging.
+    """Handle subprocess errors with appropriate messaging.
    Args:
        error: The subprocess error that occurred
        crew_type: The type of crew that was being run
+
    """
    entity_type = "flow" if crew_type == CrewType.FLOW else "crew"
    click.echo(f"An error occurred while running the {entity_type}: {error}", err=True)


@@ -22,15 +22,13 @@ console = Console()
class ToolCommand(BaseCommand, PlusAPIMixin):
-    """
-    A class to handle tool repository related operations for CrewAI projects.
-    """
+    """A class to handle tool repository related operations for CrewAI projects."""
-    def __init__(self):
+    def __init__(self) -> None:
        BaseCommand.__init__(self)
        PlusAPIMixin.__init__(self, telemetry=self._telemetry)
-    def create(self, handle: str):
+    def create(self, handle: str) -> None:
        self._ensure_not_in_project()
        folder_name = handle.replace(" ", "_").replace("-", "_").lower()
@@ -40,8 +38,7 @@ class ToolCommand(BaseCommand, PlusAPIMixin):
        if project_root.exists():
            click.secho(f"Folder {folder_name} already exists.", fg="red")
            raise SystemExit
-        else:
-            os.makedirs(project_root)
+        os.makedirs(project_root)
        click.secho(f"Creating custom tool {folder_name}...", fg="green", bold=True)
@@ -56,12 +53,12 @@ class ToolCommand(BaseCommand, PlusAPIMixin):
            self.login()
            subprocess.run(["git", "init"], check=True)
            console.print(
-                f"[green]Created custom tool [bold]{folder_name}[/bold]. Run [bold]cd {project_root}[/bold] to start working.[/green]"
+                f"[green]Created custom tool [bold]{folder_name}[/bold]. Run [bold]cd {project_root}[/bold] to start working.[/green]",
            )
        finally:
            os.chdir(old_directory)
-    def publish(self, is_public: bool, force: bool = False):
+    def publish(self, is_public: bool, force: bool = False) -> None:
        if not git.Repository().is_synced() and not force:
            console.print(
                "[bold red]Failed to publish tool.[/bold red]\n"
@@ -69,9 +66,9 @@ class ToolCommand(BaseCommand, PlusAPIMixin):
                "* [bold]Commit[/bold] your changes.\n"
                "* [bold]Push[/bold] to sync with the remote.\n"
                "* [bold]Pull[/bold] the latest changes from the remote.\n"
-                "\nOnce your repository is up-to-date, retry publishing the tool."
+                "\nOnce your repository is up-to-date, retry publishing the tool.",
            )
-            raise SystemExit()
+            raise SystemExit
        project_name = get_project_name(require=True)
        assert isinstance(project_name, str)
@@ -90,7 +87,7 @@ class ToolCommand(BaseCommand, PlusAPIMixin):
        )
        tarball_filename = next(
-            (f for f in os.listdir(temp_build_dir) if f.endswith(".tar.gz")), None
+            (f for f in os.listdir(temp_build_dir) if f.endswith(".tar.gz")), None,
        )
        if not tarball_filename:
            console.print(
@@ -123,7 +120,7 @@ class ToolCommand(BaseCommand, PlusAPIMixin):
            style="bold green",
        )
-    def install(self, handle: str):
+    def install(self, handle: str) -> None:
        get_response = self.plus_api_client.get_tool(handle)
        if get_response.status_code == 404:
@@ -132,9 +129,9 @@ class ToolCommand(BaseCommand, PlusAPIMixin):
                style="bold red",
            )
            raise SystemExit
-        elif get_response.status_code != 200:
+        if get_response.status_code != 200:
            console.print(
-                "Failed to get tool details. Please try again later.", style="bold red"
+                "Failed to get tool details. Please try again later.", style="bold red",
            )
            raise SystemExit
@@ -142,7 +139,7 @@ class ToolCommand(BaseCommand, PlusAPIMixin):
        console.print(f"Successfully installed {handle}", style="bold green")
-    def login(self):
+    def login(self) -> None:
        login_response = self.plus_api_client.login_to_tool_repository()
        if login_response.status_code != 200:
@@ -164,10 +161,10 @@ class ToolCommand(BaseCommand, PlusAPIMixin):
        settings.dump()
        console.print(
-            "Successfully authenticated to the tool repository.", style="bold green"
+            "Successfully authenticated to the tool repository.", style="bold green",
        )
-    def _add_package(self, tool_details):
+    def _add_package(self, tool_details) -> None:
        tool_handle = tool_details["handle"]
        repository_handle = tool_details["repository"]["handle"]
        repository_url = tool_details["repository"]["url"]
@@ -192,16 +189,16 @@ class ToolCommand(BaseCommand, PlusAPIMixin):
            click.echo(add_package_result.stderr, err=True)
            raise SystemExit
-    def _ensure_not_in_project(self):
+    def _ensure_not_in_project(self) -> None:
        if os.path.isfile("./pyproject.toml"):
            console.print(
-                "[bold red]Oops! It looks like you're inside a project.[/bold red]"
+                "[bold red]Oops! It looks like you're inside a project.[/bold red]",
            )
            console.print(
-                "You can't create a new tool while inside an existing project."
+                "You can't create a new tool while inside an existing project.",
            )
            console.print(
-                "[bold yellow]Tip:[/bold yellow] Navigate to a different directory and try again."
+                "[bold yellow]Tip:[/bold yellow] Navigate to a different directory and try again.",
            )
            raise SystemExit
@@ -211,10 +208,10 @@ class ToolCommand(BaseCommand, PlusAPIMixin):
        env = os.environ.copy()
        env[f"UV_INDEX_{repository_handle}_USERNAME"] = str(
-            settings.tool_repository_username or ""
+            settings.tool_repository_username or "",
        )
        env[f"UV_INDEX_{repository_handle}_PASSWORD"] = str(
-            settings.tool_repository_password or ""
+            settings.tool_repository_password or "",
        )
        return env
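For context on the hunk above: the environment handed to uv is extended with per-repository index credentials, keyed by the repository handle. A standalone sketch of the same shape (the settings object and function name here are hypothetical; only the UV_INDEX_* naming comes from the diff):

import os

def build_uv_env(repository_handle: str, username: str | None, password: str | None) -> dict:
    # Copy the current environment and add UV index credentials for the
    # named repository, falling back to empty strings when unset.
    env = os.environ.copy()
    env[f"UV_INDEX_{repository_handle}_USERNAME"] = str(username or "")
    env[f"UV_INDEX_{repository_handle}_PASSWORD"] = str(password or "")
    return env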


@@ -4,20 +4,22 @@ import click
def train_crew(n_iterations: int, filename: str) -> None:
-    """
-    Train the crew by running a command in the UV environment.
+    """Train the crew by running a command in the UV environment.
    Args:
        n_iterations (int): The number of iterations to train the crew.
+
    """
    command = ["uv", "run", "train", str(n_iterations), filename]
    try:
        if n_iterations <= 0:
-            raise ValueError("The number of iterations must be a positive integer.")
+            msg = "The number of iterations must be a positive integer."
+            raise ValueError(msg)
        if not filename.endswith(".pkl"):
-            raise ValueError("The filename must not end with .pkl")
+            msg = "The filename must not end with .pkl"
+            raise ValueError(msg)
        result = subprocess.run(command, capture_output=False, text=True, check=True)


@@ -11,9 +11,8 @@ def update_crew() -> None:
    migrate_pyproject("pyproject.toml", "pyproject.toml")
-def migrate_pyproject(input_file, output_file):
-    """
-    Migrate the pyproject.toml to the new format.
+def migrate_pyproject(input_file, output_file) -> None:
+    """Migrate the pyproject.toml to the new format.
    This function is used to migrate the pyproject.toml to the new format.
    And it will be used to migrate the pyproject.toml to the new format when uv is used.
@@ -81,7 +80,7 @@ def migrate_pyproject(input_file, output_file):
    # Extract the module name from any existing script
    existing_scripts = new_pyproject["project"]["scripts"]
    module_name = next(
-        (value.split(".")[0] for value in existing_scripts.values() if "." in value)
+        value.split(".")[0] for value in existing_scripts.values() if "." in value
    )
    new_pyproject["project"]["scripts"]["run_crew"] = f"{module_name}.main:run"
@@ -93,22 +92,19 @@ def migrate_pyproject(input_file, output_file):
    # Backup the old pyproject.toml
    backup_file = "pyproject-old.toml"
    shutil.copy2(input_file, backup_file)
-    print(f"Original pyproject.toml backed up as {backup_file}")
    # Rename the poetry.lock file
    lock_file = "poetry.lock"
    lock_backup = "poetry-old.lock"
    if os.path.exists(lock_file):
        os.rename(lock_file, lock_backup)
-        print(f"Original poetry.lock renamed to {lock_backup}")
    else:
-        print("No poetry.lock file found to rename.")
+        pass
    # Write the new pyproject.toml
    with open(output_file, "wb") as f:
        tomli_w.dump(new_pyproject, f)
-    print(f"Migration complete. New pyproject.toml written to {output_file}")
def parse_version(version: str) -> str:


@@ -3,7 +3,7 @@ import shutil
import sys
from functools import reduce
from inspect import isfunction, ismethod
-from typing import Any, Dict, List, get_type_hints
+from typing import Any, get_type_hints
import click
import tomli
@@ -19,9 +19,9 @@ if sys.version_info >= (3, 11):
console = Console()
-def copy_template(src, dst, name, class_name, folder_name):
+def copy_template(src, dst, name, class_name, folder_name) -> None:
    """Copy a file from src to dst."""
-    with open(src, "r") as file:
+    with open(src) as file:
        content = file.read()
    # Interpolate the content
@@ -39,8 +39,7 @@ def copy_template(src, dst, name, class_name, folder_name):
def read_toml(file_path: str = "pyproject.toml"):
    """Read the content of a TOML file and return it as a dictionary."""
    with open(file_path, "rb") as f:
-        toml_dict = tomli.load(f)
-    return toml_dict
+        return tomli.load(f)
def parse_toml(content):
@@ -50,59 +49,56 @@ def parse_toml(content):
def get_project_name(
-    pyproject_path: str = "pyproject.toml", require: bool = False
+    pyproject_path: str = "pyproject.toml", require: bool = False,
) -> str | None:
    """Get the project name from the pyproject.toml file."""
    return _get_project_attribute(pyproject_path, ["project", "name"], require=require)
def get_project_version(
-    pyproject_path: str = "pyproject.toml", require: bool = False
+    pyproject_path: str = "pyproject.toml", require: bool = False,
) -> str | None:
    """Get the project version from the pyproject.toml file."""
    return _get_project_attribute(
-        pyproject_path, ["project", "version"], require=require
+        pyproject_path, ["project", "version"], require=require,
    )
def get_project_description(
-    pyproject_path: str = "pyproject.toml", require: bool = False
+    pyproject_path: str = "pyproject.toml", require: bool = False,
) -> str | None:
    """Get the project description from the pyproject.toml file."""
    return _get_project_attribute(
-        pyproject_path, ["project", "description"], require=require
+        pyproject_path, ["project", "description"], require=require,
    )
def _get_project_attribute(
-    pyproject_path: str, keys: List[str], require: bool
+    pyproject_path: str, keys: list[str], require: bool,
) -> Any | None:
    """Get an attribute from the pyproject.toml file."""
    attribute = None
    try:
-        with open(pyproject_path, "r") as f:
+        with open(pyproject_path) as f:
            pyproject_content = parse_toml(f.read())
        dependencies = (
            _get_nested_value(pyproject_content, ["project", "dependencies"]) or []
        )
        if not any(True for dep in dependencies if "crewai" in dep):
-            raise Exception("crewai is not in the dependencies.")
+            msg = "crewai is not in the dependencies."
+            raise Exception(msg)
        attribute = _get_nested_value(pyproject_content, keys)
    except FileNotFoundError:
-        print(f"Error: {pyproject_path} not found.")
+        pass
    except KeyError:
-        print(f"Error: {pyproject_path} is not a valid pyproject.toml file.")
-    except tomllib.TOMLDecodeError if sys.version_info >= (3, 11) else Exception as e:  # type: ignore
-        print(
-            f"Error: {pyproject_path} is not a valid TOML file."
-            if sys.version_info >= (3, 11)
-            else f"Error reading the pyproject.toml file: {e}"
-        )
-    except Exception as e:
-        print(f"Error reading the pyproject.toml file: {e}")
+        pass
+    except tomllib.TOMLDecodeError if sys.version_info >= (3, 11) else Exception:  # type: ignore
+        pass
+    except Exception:
+        pass
    if require and not attribute:
        console.print(
@@ -114,7 +110,7 @@ def _get_project_attribute(
    return attribute
-def _get_nested_value(data: Dict[str, Any], keys: List[str]) -> Any:
+def _get_nested_value(data: dict[str, Any], keys: list[str]) -> Any:
    return reduce(dict.__getitem__, keys, data)
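The one-liner above walks nested dictionaries with functools.reduce: each key in turn indexes one level deeper. A quick usage sketch of the same idea (function name and sample data hypothetical):

from functools import reduce

def get_nested(data: dict, keys: list):
    # reduce applies dict.__getitem__ repeatedly: data[k1][k2]...[kn].
    return reduce(dict.__getitem__, keys, data)

pyproject = {"project": {"name": "my_crew", "version": "0.1.0"}}
assert get_nested(pyproject, ["project", "name"]) == "my_crew"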
@@ -122,7 +118,7 @@ def fetch_and_json_env_file(env_file_path: str = ".env") -> dict:
    """Fetch the environment variables from a .env file and return them as a dictionary."""
    try:
        # Read the .env file
-        with open(env_file_path, "r") as f:
+        with open(env_file_path) as f:
            env_content = f.read()
        # Parse the .env file content to a dictionary
@@ -135,14 +131,14 @@ def fetch_and_json_env_file(env_file_path: str = ".env") -> dict:
        return env_dict
    except FileNotFoundError:
-        print(f"Error: {env_file_path} not found.")
-    except Exception as e:
-        print(f"Error reading the .env file: {e}")
+        pass
+    except Exception:
+        pass
    return {}
-def tree_copy(source, destination):
+def tree_copy(source, destination) -> None:
    """Copies the entire directory structure from the source to the destination."""
    for item in os.listdir(source):
        source_item = os.path.join(source, item)
@@ -153,7 +149,7 @@ def tree_copy(source, destination):
        shutil.copy2(source_item, destination_item)
-def tree_find_and_replace(directory, find, replace):
+def tree_find_and_replace(directory, find, replace) -> None:
    """Recursively searches through a directory, replacing a target string in
    both file contents and filenames with a specified replacement string.
    """
@@ -161,7 +157,7 @@ def tree_find_and_replace(directory, find, replace):
        for filename in files:
            filepath = os.path.join(path, filename)
-            with open(filepath, "r") as file:
+            with open(filepath) as file:
                contents = file.read()
            with open(filepath, "w") as file:
                file.write(contents.replace(find, replace))
@@ -180,19 +176,19 @@ def tree_find_and_replace(directory, find, replace):
def load_env_vars(folder_path):
-    """
-    Loads environment variables from a .env file in the specified folder path.
+    """Loads environment variables from a .env file in the specified folder path.
    Args:
    - folder_path (Path): The path to the folder containing the .env file.
    Returns:
    - dict: A dictionary of environment variables.
+
    """
    env_file_path = folder_path / ".env"
    env_vars = {}
    if env_file_path.exists():
-        with open(env_file_path, "r") as file:
+        with open(env_file_path) as file:
            for line in file:
                key, _, value = line.strip().partition("=")
                if key and value:
@@ -201,8 +197,7 @@ def load_env_vars(folder_path):
def update_env_vars(env_vars, provider, model):
-    """
-    Updates environment variables with the API key for the selected provider and model.
+    """Updates environment variables with the API key for the selected provider and model.
    Args:
    - env_vars (dict): Environment variables dictionary.
@@ -211,6 +206,7 @@ def update_env_vars(env_vars, provider, model):
    Returns:
    - None
+
    """
    api_key_var = ENV_VARS.get(
        provider,
@@ -218,14 +214,14 @@ def update_env_vars(env_vars, provider, model):
            click.prompt(
                f"Enter the environment variable name for your {provider.capitalize()} API key",
                type=str,
-            )
+            ),
        ],
    )[0]
    if api_key_var not in env_vars:
        try:
            env_vars[api_key_var] = click.prompt(
-                f"Enter your {provider.capitalize()} API key", type=str, hide_input=True
+                f"Enter your {provider.capitalize()} API key", type=str, hide_input=True,
            )
        except click.exceptions.Abort:
            click.secho("Operation aborted by the user.", fg="red")
@@ -238,13 +234,13 @@ def update_env_vars(env_vars, provider, model):
    return env_vars
-def write_env_file(folder_path, env_vars):
-    """
-    Writes environment variables to a .env file in the specified folder.
+def write_env_file(folder_path, env_vars) -> None:
+    """Writes environment variables to a .env file in the specified folder.
    Args:
    - folder_path (Path): The path to the folder where the .env file will be written.
    - env_vars (dict): A dictionary of environment variables to write.
+
    """
    env_file_path = folder_path / ".env"
    with open(env_file_path, "w") as file:
@@ -263,7 +259,7 @@ def get_crews(crew_path: str = "crew.py", require: bool = False) -> list[Crew]:
            crew_os_path = os.path.join(root, crew_path)
            try:
                spec = importlib.util.spec_from_file_location(
-                    "crew_module", crew_os_path
+                    "crew_module", crew_os_path,
                )
                if not spec or not spec.loader:
                    continue
@@ -277,19 +273,16 @@ def get_crews(crew_path: str = "crew.py", require: bool = False) -> list[Crew]:
                    try:
                        crew_instances.extend(fetch_crews(module_attr))
-                    except Exception as e:
-                        print(f"Error processing attribute {attr_name}: {e}")
+                    except Exception:
                        continue
-            except Exception as exec_error:
-                print(f"Error executing module: {exec_error}")
+            except Exception:
                import traceback
-                print(f"Traceback: {traceback.format_exc()}")
            except (ImportError, AttributeError) as e:
                if require:
                    console.print(
-                        f"Error importing crew from {crew_path}: {str(e)}",
+                        f"Error importing crew from {crew_path}: {e!s}",
                        style="bold red",
                    )
                    continue
@@ -303,7 +296,7 @@ def get_crews(crew_path: str = "crew.py", require: bool = False) -> list[Crew]:
        except Exception as e:
            if require:
                console.print(
-                    f"Unexpected error while loading crew: {str(e)}", style="bold red"
+                    f"Unexpected error while loading crew: {e!s}", style="bold red",
                )
                raise SystemExit
    return crew_instances
@@ -317,13 +310,12 @@ def get_crew_instance(module_attr) -> Crew | None:
    ):
        return module_attr().crew()
    if (ismethod(module_attr) or isfunction(module_attr)) and get_type_hints(
-        module_attr
+        module_attr,
    ).get("return") is Crew:
        return module_attr()
-    elif isinstance(module_attr, Crew):
+    if isinstance(module_attr, Crew):
        return module_attr
-    else:
-        return None
+    return None
def fetch_crews(module_attr) -> list[Crew]:


@@ -2,5 +2,5 @@ import importlib.metadata
def get_crewai_version() -> str:
-    """Get the version number of CrewAI running the CLI"""
+    """Get the version number of CrewAI running the CLI."""
    return importlib.metadata.version("crewai")

File diff suppressed because it is too large.


@@ -1,5 +1,5 @@
import json
-from typing import Any, Dict, Optional
+from typing import Any
from pydantic import BaseModel, Field
@@ -12,27 +12,28 @@ class CrewOutput(BaseModel):
    """Class that represents the result of a crew."""
    raw: str = Field(description="Raw output of crew", default="")
-    pydantic: Optional[BaseModel] = Field(
-        description="Pydantic output of Crew", default=None
+    pydantic: BaseModel | None = Field(
+        description="Pydantic output of Crew", default=None,
    )
-    json_dict: Optional[Dict[str, Any]] = Field(
-        description="JSON dict output of Crew", default=None
+    json_dict: dict[str, Any] | None = Field(
+        description="JSON dict output of Crew", default=None,
    )
    tasks_output: list[TaskOutput] = Field(
-        description="Output of each task", default=[]
+        description="Output of each task", default=[],
    )
    token_usage: UsageMetrics = Field(description="Processed token summary", default={})
    @property
-    def json(self) -> Optional[str]:
+    def json(self) -> str | None:
        if self.tasks_output[-1].output_format != OutputFormat.JSON:
+            msg = "No JSON output found in the final task. Please make sure to set the output_json property in the final task in your crew."
            raise ValueError(
-                "No JSON output found in the final task. Please make sure to set the output_json property in the final task in your crew."
+                msg,
            )
        return json.dumps(self.json_dict)
-    def to_dict(self) -> Dict[str, Any]:
+    def to_dict(self) -> dict[str, Any]:
        """Convert json_output and pydantic_output to a dictionary."""
        output_dict = {}
        if self.json_dict:
@@ -44,12 +45,12 @@ class CrewOutput(BaseModel):
    def __getitem__(self, key):
        if self.pydantic and hasattr(self.pydantic, key):
            return getattr(self.pydantic, key)
-        elif self.json_dict and key in self.json_dict:
+        if self.json_dict and key in self.json_dict:
            return self.json_dict[key]
-        else:
-            raise KeyError(f"Key '{key}' not found in CrewOutput.")
+        msg = f"Key '{key}' not found in CrewOutput."
+        raise KeyError(msg)
-    def __str__(self):
+    def __str__(self) -> str:
        if self.pydantic:
            return str(self.pydantic)
        if self.json_dict:
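The typing changes in this hunk are pyupgrade-style rewrites: `Optional[X]` becomes `X | None` (PEP 604) and `typing.Dict`/`typing.List` become the builtin generics, which Python 3.9+/3.10+ accept natively. A minimal sketch of an equivalent field declaration in the modern spelling (class and field names hypothetical):

from typing import Any

from pydantic import BaseModel, Field

class OutputSketch(BaseModel):
    # Builtin generics and PEP 604 unions instead of typing.Dict / Optional.
    raw: str = Field(default="", description="Raw output")
    json_dict: dict[str, Any] | None = Field(default=None, description="JSON dict output")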


@@ -2,17 +2,11 @@ import asyncio
import copy
import inspect
import logging
+from collections.abc import Callable
from typing import (
    Any,
-    Callable,
-    Dict,
    Generic,
-    List,
-    Optional,
-    Set,
-    Type,
    TypeVar,
-    Union,
    cast,
)
from uuid import uuid4
@@ -48,14 +42,14 @@ class FlowState(BaseModel):
# Type variables with explicit bounds
T = TypeVar(
-    "T", bound=Union[Dict[str, Any], BaseModel]
+    "T", bound=dict[str, Any] | BaseModel,
)  # Generic flow state type parameter
StateT = TypeVar(
-    "StateT", bound=Union[Dict[str, Any], BaseModel]
+    "StateT", bound=dict[str, Any] | BaseModel,
)  # State validation type parameter
-def ensure_state_type(state: Any, expected_type: Type[StateT]) -> StateT:
+def ensure_state_type(state: Any, expected_type: type[StateT]) -> StateT:
    """Ensure state matches expected type with proper validation.
    Args:
@@ -68,6 +62,7 @@ def ensure_state_type(state: Any, expected_type: Type[StateT]) -> StateT:
    Raises:
        TypeError: If state doesn't match expected type
        ValueError: If state validation fails
+
    """
    """Ensure state matches expected type with proper validation.
@@ -84,20 +79,22 @@ def ensure_state_type(state: Any, expected_type: Type[StateT]) -> StateT:
    """
    if expected_type is dict:
        if not isinstance(state, dict):
-            raise TypeError(f"Expected dict, got {type(state).__name__}")
-        return cast(StateT, state)
+            msg = f"Expected dict, got {type(state).__name__}"
+            raise TypeError(msg)
+        return cast("StateT", state)
    if isinstance(expected_type, type) and issubclass(expected_type, BaseModel):
        if not isinstance(state, expected_type):
+            msg = f"Expected {expected_type.__name__}, got {type(state).__name__}"
            raise TypeError(
-                f"Expected {expected_type.__name__}, got {type(state).__name__}"
+                msg,
            )
-        return cast(StateT, state)
-    raise TypeError(f"Invalid expected_type: {expected_type}")
+        return cast("StateT", state)
+    msg = f"Invalid expected_type: {expected_type}"
+    raise TypeError(msg)
-def start(condition: Optional[Union[str, dict, Callable]] = None) -> Callable:
-    """
-    Marks a method as a flow's starting point.
+def start(condition: str | dict | Callable | None = None) -> Callable:
+    """Marks a method as a flow's starting point.
    This decorator designates a method as an entry point for the flow execution.
    It can optionally specify conditions that trigger the start based on other
@@ -135,6 +132,7 @@ def start(condition: Optional[Union[str, dict, Callable]] = None) -> Callable:
    >>> @start(and_("method1", "method2"))  # Start after multiple methods
    >>> def complex_start(self):
    ...     pass
+
    """
    def decorator(func):
@@ -154,17 +152,17 @@ def start(condition: Optional[Union[str, dict, Callable]] = None) -> Callable:
            func.__trigger_methods__ = [condition.__name__]
            func.__condition_type__ = "OR"
        else:
+            msg = "Condition must be a method, string, or a result of or_() or and_()"
            raise ValueError(
-                "Condition must be a method, string, or a result of or_() or and_()"
+                msg,
            )
        return func
    return decorator
-def listen(condition: Union[str, dict, Callable]) -> Callable:
-    """
-    Creates a listener that executes when specified conditions are met.
+def listen(condition: str | dict | Callable) -> Callable:
+    """Creates a listener that executes when specified conditions are met.
    This decorator sets up a method to execute in response to other method
    executions in the flow. It supports both simple and complex triggering
@@ -197,6 +195,7 @@ def listen(condition: Union[str, dict, Callable]) -> Callable:
    >>> @listen(or_("success", "failure"))  # Listen to multiple methods
    >>> def handle_completion(self):
    ...     pass
+
    """
    def decorator(func):
@@ -214,17 +213,17 @@ def listen(condition: Union[str, dict, Callable]) -> Callable:
            func.__trigger_methods__ = [condition.__name__]
            func.__condition_type__ = "OR"
        else:
+            msg = "Condition must be a method, string, or a result of or_() or and_()"
            raise ValueError(
-                "Condition must be a method, string, or a result of or_() or and_()"
+                msg,
            )
        return func
    return decorator
-def router(condition: Union[str, dict, Callable]) -> Callable:
-    """
-    Creates a routing method that directs flow execution based on conditions.
+def router(condition: str | dict | Callable) -> Callable:
+    """Creates a routing method that directs flow execution based on conditions.
    This decorator marks a method as a router, which can dynamically determine
    the next steps in the flow based on its return value. Routers are triggered
@@ -262,6 +261,7 @@ def router(condition: Union[str, dict, Callable]) -> Callable:
    ...     if all([self.state.valid, self.state.processed]):
    ...         return CONTINUE
    ...     return STOP
+
    """
    def decorator(func):
@@ -280,17 +280,17 @@ def router(condition: Union[str, dict, Callable]) -> Callable:
            func.__trigger_methods__ = [condition.__name__]
            func.__condition_type__ = "OR"
        else:
+            msg = "Condition must be a method, string, or a result of or_() or and_()"
            raise ValueError(
-                "Condition must be a method, string, or a result of or_() or and_()"
+                msg,
            )
        return func
    return decorator
-def or_(*conditions: Union[str, dict, Callable]) -> dict:
-    """
-    Combines multiple conditions with OR logic for flow control.
+def or_(*conditions: str | dict | Callable) -> dict:
+    """Combines multiple conditions with OR logic for flow control.
    Creates a condition that is satisfied when any of the specified conditions
    are met. This is used with @start, @listen, or @router decorators to create
@@ -320,6 +320,7 @@ def or_(*conditions: Union[str, dict, Callable]) -> dict:
    >>> @listen(or_("success", "timeout"))
    >>> def handle_completion(self):
    ...     pass
+
    """
    methods = []
    for condition in conditions:
@@ -330,13 +331,13 @@ def or_(*conditions: Union[str, dict, Callable]) -> dict:
        elif callable(condition):
            methods.append(getattr(condition, "__name__", repr(condition)))
        else:
-            raise ValueError("Invalid condition in or_()")
+            msg = "Invalid condition in or_()"
+            raise ValueError(msg)
    return {"type": "OR", "methods": methods}
-def and_(*conditions: Union[str, dict, Callable]) -> dict:
-    """
-    Combines multiple conditions with AND logic for flow control.
+def and_(*conditions: str | dict | Callable) -> dict:
+    """Combines multiple conditions with AND logic for flow control.
    Creates a condition that is satisfied only when all specified conditions
    are met. This is used with @start, @listen, or @router decorators to create
@@ -366,6 +367,7 @@ def and_(*conditions: Union[str, dict, Callable]) -> dict:
    >>> @listen(and_("validated", "processed"))
    >>> def handle_complete_data(self):
    ...     pass
+
    """
    methods = []
    for condition in conditions:
@@ -376,7 +378,8 @@ def and_(*conditions: Union[str, dict, Callable]) -> dict:
        elif callable(condition):
            methods.append(getattr(condition, "__name__", repr(condition)))
        else:
-            raise ValueError("Invalid condition in and_()")
+            msg = "Invalid condition in and_()"
+            raise ValueError(msg)
    return {"type": "AND", "methods": methods}
@@ -416,10 +419,10 @@ class FlowMeta(type):
            if possible_returns:
                router_paths[attr_name] = possible_returns
-        setattr(cls, "_start_methods", start_methods)
-        setattr(cls, "_listeners", listeners)
-        setattr(cls, "_routers", routers)
-        setattr(cls, "_router_paths", router_paths)
+        cls._start_methods = start_methods
+        cls._listeners = listeners
+        cls._routers = routers
+        cls._router_paths = router_paths
        return cls
@@ -427,17 +430,18 @@ class FlowMeta(type):
class Flow(Generic[T], metaclass=FlowMeta):
    """Base class for all flows.
-    Type parameter T must be either Dict[str, Any] or a subclass of BaseModel."""
+    Type parameter T must be either Dict[str, Any] or a subclass of BaseModel.
+    """
    _printer = Printer()
-    _start_methods: List[str] = []
-    _listeners: Dict[str, tuple[str, List[str]]] = {}
-    _routers: Set[str] = set()
-    _router_paths: Dict[str, List[str]] = {}
+    _start_methods: list[str] = []
+    _listeners: dict[str, tuple[str, list[str]]] = {}
+    _routers: set[str] = set()
+    _router_paths: dict[str, list[str]] = {}
-    initial_state: Union[Type[T], T, None] = None
+    initial_state: type[T] | T | None = None
-    def __class_getitem__(cls: Type["Flow"], item: Type[T]) -> Type["Flow"]:
+    def __class_getitem__(cls: type["Flow"], item: type[T]) -> type["Flow"]:
        class _FlowGeneric(cls):  # type: ignore
            _initial_state_T = item  # type: ignore
@@ -446,7 +450,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
    def __init__(
        self,
-        persistence: Optional[FlowPersistence] = None,
+        persistence: FlowPersistence | None = None,
        **kwargs: Any,
    ) -> None:
        """Initialize a new Flow instance.
@@ -454,13 +458,14 @@ class Flow(Generic[T], metaclass=FlowMeta):
        Args:
            persistence: Optional persistence backend for storing flow states
            **kwargs: Additional state values to initialize or override
+
        """
        # Initialize basic instance attributes
-        self._methods: Dict[str, Callable] = {}
-        self._method_execution_counts: Dict[str, int] = {}
-        self._pending_and_listeners: Dict[str, Set[str]] = {}
-        self._method_outputs: List[Any] = []  # List to store all method outputs
-        self._persistence: Optional[FlowPersistence] = persistence
+        self._methods: dict[str, Callable] = {}
+        self._method_execution_counts: dict[str, int] = {}
+        self._pending_and_listeners: dict[str, set[str]] = {}
+        self._method_outputs: list[Any] = []  # List to store all method outputs
+        self._persistence: FlowPersistence | None = persistence
        # Initialize state with initial values
        self._state = self._create_initial_state()
@@ -502,58 +507,61 @@ class Flow(Generic[T], metaclass=FlowMeta):
        Raises:
            ValueError: If structured state model lacks 'id' field
            TypeError: If state is neither BaseModel nor dictionary
        """
        # Handle case where initial_state is None but we have a type parameter
        if self.initial_state is None and hasattr(self, "_initial_state_T"):
-            state_type = getattr(self, "_initial_state_T")
+            state_type = self._initial_state_T
            if isinstance(state_type, type):
                if issubclass(state_type, FlowState):
                    # Create instance without id, then set it
                    instance = state_type()
                    if not hasattr(instance, "id"):
-                        setattr(instance, "id", str(uuid4()))
-                    return cast(T, instance)
-                elif issubclass(state_type, BaseModel):
+                        instance.id = str(uuid4())
+                    return cast("T", instance)
+                if issubclass(state_type, BaseModel):
                    # Create a new type that includes the ID field
                    class StateWithId(state_type, FlowState):  # type: ignore
                        pass
                    instance = StateWithId()
                    if not hasattr(instance, "id"):
-                        setattr(instance, "id", str(uuid4()))
-                    return cast(T, instance)
-                elif state_type is dict:
-                    return cast(T, {"id": str(uuid4())})
+                        instance.id = str(uuid4())
+                    return cast("T", instance)
+                if state_type is dict:
+                    return cast("T", {"id": str(uuid4())})
        # Handle case where no initial state is provided
        if self.initial_state is None:
-            return cast(T, {"id": str(uuid4())})
+            return cast("T", {"id": str(uuid4())})
        # Handle case where initial_state is a type (class)
        if isinstance(self.initial_state, type):
            if issubclass(self.initial_state, FlowState):
-                return cast(T, self.initial_state())  # Uses model defaults
-            elif issubclass(self.initial_state, BaseModel):
+                return cast("T", self.initial_state())  # Uses model defaults
+            if issubclass(self.initial_state, BaseModel):
                # Validate that the model has an id field
                model_fields = getattr(self.initial_state, "model_fields", None)
                if not model_fields or "id" not in model_fields:
-                    raise ValueError("Flow state model must have an 'id' field")
-                return cast(T, self.initial_state())  # Uses model defaults
-            elif self.initial_state is dict:
-                return cast(T, {"id": str(uuid4())})
+                    msg = "Flow state model must have an 'id' field"
+                    raise ValueError(msg)
+                return cast("T", self.initial_state())  # Uses model defaults
+            if self.initial_state is dict:
+                return cast("T", {"id": str(uuid4())})
        # Handle dictionary instance case
        if isinstance(self.initial_state, dict):
            new_state = dict(self.initial_state)  # Copy to avoid mutations
            if "id" not in new_state:
                new_state["id"] = str(uuid4())
-            return cast(T, new_state)
+            return cast("T", new_state)
        # Handle BaseModel instance case
        if isinstance(self.initial_state, BaseModel):
-            model = cast(BaseModel, self.initial_state)
+            model = cast("BaseModel", self.initial_state)
            if not hasattr(model, "id"):
-                raise ValueError("Flow state model must have an 'id' field")
+                msg = "Flow state model must have an 'id' field"
+                raise ValueError(msg)
            # Create new instance with same values to avoid mutations
            if hasattr(model, "model_dump"):
@@ -570,9 +578,10 @@ class Flow(Generic[T], metaclass=FlowMeta):
            # Create new instance of the same class
            model_class = type(model)
-            return cast(T, model_class(**state_dict))
+            return cast("T", model_class(**state_dict))
+        msg = f"Initial state must be dict or BaseModel, got {type(self.initial_state)}"
        raise TypeError(
-            f"Initial state must be dict or BaseModel, got {type(self.initial_state)}"
+            msg,
        )
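As the branching above shows, unstructured (dict) states get an auto-generated `id`, while structured states must declare one; the `FlowState` base model exists to supply it. A minimal sketch of a structured state that satisfies the 'id' requirement (class and field names hypothetical):

from uuid import uuid4

from pydantic import BaseModel, Field

class MyState(BaseModel):
    # Structured flow states must carry an 'id' field; a default_factory
    # gives each instance a fresh UUID, matching what the dict path does.
    id: str = Field(default_factory=lambda: str(uuid4()))
    counter: int = 0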
    def _copy_state(self) -> T:
@@ -583,7 +592,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
        return self._state
    @property
-    def method_outputs(self) -> List[Any]:
+    def method_outputs(self) -> list[Any]:
        """Returns the list of all outputs from executed methods."""
        return self._method_outputs
@@ -607,6 +616,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
            flow = MyFlow()
            print(f"Current flow ID: {flow.flow_id}")  # Safely get flow ID
        ```
+
        """
        try:
            if not hasattr(self, "_state"):
@@ -614,13 +624,13 @@ class Flow(Generic[T], metaclass=FlowMeta):
            if isinstance(self._state, dict):
                return str(self._state.get("id", ""))
-            elif isinstance(self._state, BaseModel):
+            if isinstance(self._state, BaseModel):
                return str(getattr(self._state, "id", ""))
            return ""
        except (AttributeError, TypeError):
            return ""  # Safely handle any unexpected attribute access issues
-    def _initialize_state(self, inputs: Dict[str, Any]) -> None:
+    def _initialize_state(self, inputs: dict[str, Any]) -> None:
        """Initialize or update flow state with new inputs.
        Args:
@@ -629,6 +639,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
        Raises:
            ValueError: If validation fails for structured state
            TypeError: If state is neither BaseModel nor dictionary
+
        """
        if isinstance(self._state, dict):
            # For dict states, preserve existing fields unless overridden
@@ -644,7 +655,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
        elif isinstance(self._state, BaseModel):
            # For BaseModel states, preserve existing fields unless overridden
            try:
-                model = cast(BaseModel, self._state)
+                model = cast("BaseModel", self._state)
                # Get current state as dict
                if hasattr(model, "model_dump"):
                    current_state = model.model_dump()
@@ -662,19 +673,21 @@ class Flow(Generic[T], metaclass=FlowMeta):
                model_class = type(model)
                if hasattr(model_class, "model_validate"):
                    # Pydantic v2
-                    self._state = cast(T, model_class.model_validate(new_state))
+                    self._state = cast("T", model_class.model_validate(new_state))
                elif hasattr(model_class, "parse_obj"):
                    # Pydantic v1
-                    self._state = cast(T, model_class.parse_obj(new_state))
+                    self._state = cast("T", model_class.parse_obj(new_state))
                else:
                    # Fallback for other BaseModel implementations
-                    self._state = cast(T, model_class(**new_state))
+                    self._state = cast("T", model_class(**new_state))
            except ValidationError as e:
-                raise ValueError(f"Invalid inputs for structured state: {e}") from e
+                msg = f"Invalid inputs for structured state: {e}"
+                raise ValueError(msg) from e
        else:
-            raise TypeError("State must be a BaseModel instance or a dictionary.")
+            msg = "State must be a BaseModel instance or a dictionary."
+            raise TypeError(msg)
-    def _restore_state(self, stored_state: Dict[str, Any]) -> None:
+    def _restore_state(self, stored_state: dict[str, Any]) -> None:
        """Restore flow state from persistence.
        Args:
@@ -683,11 +696,13 @@ class Flow(Generic[T], metaclass=FlowMeta):
        Raises:
            ValueError: If validation fails for structured state
            TypeError: If state is neither BaseModel nor dictionary
+
        """
        # When restoring from persistence, use the stored ID
        stored_id = stored_state.get("id")
        if not stored_id:
-            raise ValueError("Stored state must have an 'id' field")
+            msg = "Stored state must have an 'id' field"
+            raise ValueError(msg)
        if isinstance(self._state, dict):
            # For dict states, update all fields from stored state
@@ -695,22 +710,22 @@ class Flow(Generic[T], metaclass=FlowMeta):
            self._state.update(stored_state)
        elif isinstance(self._state, BaseModel):
            # For BaseModel states, create new instance with stored values
-            model = cast(BaseModel, self._state)
+            model = cast("BaseModel", self._state)
            if hasattr(model, "model_validate"):
                # Pydantic v2
-                self._state = cast(T, type(model).model_validate(stored_state))
+                self._state = cast("T", type(model).model_validate(stored_state))
            elif hasattr(model, "parse_obj"):
                # Pydantic v1
-                self._state = cast(T, type(model).parse_obj(stored_state))
+                self._state = cast("T", type(model).parse_obj(stored_state))
            else:
                # Fallback for other BaseModel implementations
-                self._state = cast(T, type(model)(**stored_state))
+                self._state = cast("T", type(model)(**stored_state))
        else:
-            raise TypeError(f"State must be dict or BaseModel, got {type(self._state)}")
+            msg = f"State must be dict or BaseModel, got {type(self._state)}"
+            raise TypeError(msg)
-    def kickoff(self, inputs: Optional[Dict[str, Any]] = None) -> Any:
-        """
-        Start the flow execution in a synchronous context.
+    def kickoff(self, inputs: dict[str, Any] | None = None) -> Any:
+        """Start the flow execution in a synchronous context.
        This method wraps kickoff_async so that all state initialization and event
        emission is handled in the asynchronous method.
@@ -721,9 +736,8 @@ class Flow(Generic[T], metaclass=FlowMeta):
        return asyncio.run(run_flow())
-    async def kickoff_async(self, inputs: Optional[Dict[str, Any]] = None) -> Any:
-        """
-        Start the flow execution asynchronously.
+    async def kickoff_async(self, inputs: dict[str, Any] | None = None) -> Any:
+        """Start the flow execution asynchronously.
        This method performs state restoration (if an 'id' is provided and persistence is available)
        and updates the flow state with any additional inputs. It then emits the FlowStartedEvent,
@@ -735,6 +749,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
        Returns:
            The final output from the flow, which is the result of the last executed method.
+
        """
        if inputs:
            # Override the id in the state if it exists in inputs
@@ -742,7 +757,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
            if isinstance(self._state, dict):
                self._state["id"] = inputs["id"]
            elif isinstance(self._state, BaseModel):
-                setattr(self._state, "id", inputs["id"])
+                self._state.id = inputs["id"]
            # If persistence is enabled, attempt to restore the stored state using the provided id.
            if "id" in inputs and self._persistence is not None:
@@ -756,7 +771,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
                    self._restore_state(stored_state)
                else:
                    self._log_flow_event(
-                        f"No flow state found for UUID: {restore_uuid}", color="red"
+                        f"No flow state found for UUID: {restore_uuid}", color="red",
                    )
            # Update state with any additional inputs (ignoring the 'id' key)
@@ -774,7 +789,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
            ),
        )
        self._log_flow_event(
-            f"Flow started with ID: {self.flow_id}", color="bold_magenta"
+            f"Flow started with ID: {self.flow_id}", color="bold_magenta",
        )
        if inputs is not None and "id" not in inputs:
@@ -800,8 +815,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
        return final_output
    async def _execute_start_method(self, start_method_name: str) -> None:
-        """
-        Executes a flow's start method and its triggered listeners.
+        """Executes a flow's start method and its triggered listeners.
        This internal method handles the execution of methods marked with @start
        decorator and manages the subsequent chain of listener executions.
@@ -816,14 +830,15 @@ class Flow(Generic[T], metaclass=FlowMeta):
            - Executes the start method and captures its result
            - Triggers execution of any listeners waiting on this start method
            - Part of the flow's initialization sequence
+
        """
        result = await self._execute_method(
-            start_method_name, self._methods[start_method_name]
+            start_method_name, self._methods[start_method_name],
        )
        await self._execute_listeners(start_method_name, result)
    async def _execute_method(
-        self, method_name: str, method: Callable, *args: Any, **kwargs: Any
+        self, method_name: str, method: Callable, *args: Any, **kwargs: Any,
    ) -> Any:
        try:
            dumped_params = {f"_{i}": arg for i, arg in enumerate(args)} | (
@@ -873,11 +888,10 @@ class Flow(Generic[T], metaclass=FlowMeta):
                    error=e,
                ),
            )
-            raise e
+            raise
async def _execute_listeners(self, trigger_method: str, result: Any) -> None: async def _execute_listeners(self, trigger_method: str, result: Any) -> None:
""" """Executes all listeners and routers triggered by a method completion.
Executes all listeners and routers triggered by a method completion.
This internal method manages the execution flow by: This internal method manages the execution flow by:
1. First executing all triggered routers sequentially 1. First executing all triggered routers sequentially
@@ -897,6 +911,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
- Each router's result becomes a new trigger_method - Each router's result becomes a new trigger_method
- Normal listeners are executed in parallel for efficiency - Normal listeners are executed in parallel for efficiency
- Listeners can receive the trigger method's result as a parameter - Listeners can receive the trigger method's result as a parameter
""" """
# First, handle routers repeatedly until no router triggers anymore # First, handle routers repeatedly until no router triggers anymore
router_results = [] router_results = []
@@ -904,7 +919,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
while True: while True:
routers_triggered = self._find_triggered_methods( routers_triggered = self._find_triggered_methods(
current_trigger, router_only=True current_trigger, router_only=True,
) )
if not routers_triggered: if not routers_triggered:
break break
@@ -920,12 +935,12 @@ class Flow(Generic[T], metaclass=FlowMeta):
) )
# Now execute normal listeners for all router results and the original trigger # Now execute normal listeners for all router results and the original trigger
all_triggers = [trigger_method] + router_results all_triggers = [trigger_method, *router_results]
for current_trigger in all_triggers: for current_trigger in all_triggers:
if current_trigger: # Skip None results if current_trigger: # Skip None results
listeners_triggered = self._find_triggered_methods( listeners_triggered = self._find_triggered_methods(
current_trigger, router_only=False current_trigger, router_only=False,
) )
if listeners_triggered: if listeners_triggered:
tasks = [ tasks = [
@@ -935,10 +950,9 @@ class Flow(Generic[T], metaclass=FlowMeta):
await asyncio.gather(*tasks) await asyncio.gather(*tasks)
def _find_triggered_methods( def _find_triggered_methods(
self, trigger_method: str, router_only: bool self, trigger_method: str, router_only: bool,
) -> List[str]: ) -> list[str]:
""" """Finds all methods that should be triggered based on conditions.
Finds all methods that should be triggered based on conditions.
This internal method evaluates both OR and AND conditions to determine This internal method evaluates both OR and AND conditions to determine
which methods should be executed next in the flow. which methods should be executed next in the flow.
@@ -963,6 +977,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
* AND: Triggers only when all conditions are met * AND: Triggers only when all conditions are met
- Maintains state for AND conditions using _pending_and_listeners - Maintains state for AND conditions using _pending_and_listeners
- Separates router and normal listener evaluation - Separates router and normal listener evaluation
""" """
triggered = [] triggered = []
for listener_name, (condition_type, methods) in self._listeners.items(): for listener_name, (condition_type, methods) in self._listeners.items():
@@ -992,8 +1007,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
return triggered return triggered
async def _execute_single_listener(self, listener_name: str, result: Any) -> None: async def _execute_single_listener(self, listener_name: str, result: Any) -> None:
""" """Executes a single listener method with proper event handling.
Executes a single listener method with proper event handling.
This internal method manages the execution of an individual listener, This internal method manages the execution of an individual listener,
including parameter inspection, event emission, and error handling. including parameter inspection, event emission, and error handling.
@@ -1018,6 +1032,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
------------- -------------
Catches and logs any exceptions during execution, preventing Catches and logs any exceptions during execution, preventing
individual listener failures from breaking the entire flow. individual listener failures from breaking the entire flow.
""" """
try: try:
method = self._methods[listener_name] method = self._methods[listener_name]
@@ -1028,7 +1043,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
if method_params: if method_params:
listener_result = await self._execute_method( listener_result = await self._execute_method(
listener_name, method, result listener_name, method, result,
) )
else: else:
listener_result = await self._execute_method(listener_name, method) listener_result = await self._execute_method(listener_name, method)
@@ -1036,17 +1051,14 @@ class Flow(Generic[T], metaclass=FlowMeta):
# Execute listeners (and possibly routers) of this listener # Execute listeners (and possibly routers) of this listener
await self._execute_listeners(listener_name, listener_result) await self._execute_listeners(listener_name, listener_result)
except Exception as e: except Exception:
print(
f"[Flow._execute_single_listener] Error in method {listener_name}: {e}"
)
import traceback import traceback
traceback.print_exc() traceback.print_exc()
raise raise
def _log_flow_event( def _log_flow_event(
self, message: str, color: str = "yellow", level: str = "info" self, message: str, color: str = "yellow", level: str = "info",
) -> None: ) -> None:
"""Centralized logging method for flow events. """Centralized logging method for flow events.
@@ -1064,6 +1076,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
Note: Note:
This method uses the Printer utility for colored console output This method uses the Printer utility for colored console output
and the standard logging module for log level support. and the standard logging module for log level support.
""" """
self._printer.print(message, color=color) self._printer.print(message, color=color)
if level == "info": if level == "info":
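
For context on the kickoff/restore path patched above, a minimal sketch of a flow kicked off with an `id` input. Class and decorator names follow crewai's documented Flow API, but treat the exact behavior as an assumption rather than a pinned reference:

```python
# Minimal sketch: kicking off a flow with an "id" input. When persistence is
# configured, kickoff() restores any stored state for that UUID before
# running; other input keys simply update the state (the "id" key is skipped).
from crewai.flow.flow import Flow, listen, start


class GreetFlow(Flow):
    @start()
    def begin(self):
        self.state["greeting"] = "hello"  # unstructured (dict) state
        return "begin done"

    @listen(begin)
    def follow_up(self, result):
        # Listeners may accept the trigger method's result as a parameter.
        return f"{result} -> {self.state['greeting']} world"


flow = GreetFlow()
output = flow.kickoff(inputs={"id": "123e4567-e89b-12d3-a456-426614174000"})
print(output)  # final output is the last executed method's return value
```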

View File

@@ -1,5 +1,4 @@
import inspect import inspect
from typing import Optional
from pydantic import BaseModel, Field, InstanceOf, model_validator from pydantic import BaseModel, Field, InstanceOf, model_validator
@@ -14,7 +13,7 @@ class FlowTrackable(BaseModel):
inspecting the call stack. inspecting the call stack.
""" """
parent_flow: Optional[InstanceOf[Flow]] = Field( parent_flow: InstanceOf[Flow] | None = Field(
default=None, default=None,
description="The parent flow of the instance, if it was created inside a flow.", description="The parent flow of the instance, if it was created inside a flow.",
) )
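
The hunk above only shows the `parent_flow` field; the stack inspection the docstring mentions is not part of this diff. A standalone sketch of that technique (hypothetical helper, not crewai code):

```python
import inspect


class Flow:  # stand-in for crewai's Flow in this self-contained sketch
    def run(self):
        return find_parent_flow()


def find_parent_flow():
    # Walk outer frames looking for a `self` local that is a Flow instance,
    # the same idea FlowTrackable's validator uses to find its parent flow.
    for frame_info in inspect.stack()[1:]:
        candidate = frame_info.frame.f_locals.get("self")
        if isinstance(candidate, Flow):
            return candidate
    return None


flow = Flow()
assert flow.run() is flow  # discovered via the caller's frame
```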

View File

@@ -1,14 +1,13 @@
# flow_visualizer.py # flow_visualizer.py
import os import os
from pathlib import Path
from pyvis.network import Network from pyvis.network import Network
from crewai.flow.config import COLORS, NODE_STYLES from crewai.flow.config import COLORS, NODE_STYLES
from crewai.flow.html_template_handler import HTMLTemplateHandler from crewai.flow.html_template_handler import HTMLTemplateHandler
from crewai.flow.legend_generator import generate_legend_items_html, get_legend_items from crewai.flow.legend_generator import generate_legend_items_html, get_legend_items
from crewai.flow.path_utils import safe_path_join, validate_path_exists from crewai.flow.path_utils import safe_path_join
from crewai.flow.utils import calculate_node_levels from crewai.flow.utils import calculate_node_levels
from crewai.flow.visualization_utils import ( from crewai.flow.visualization_utils import (
add_edges, add_edges,
@@ -20,9 +19,8 @@ from crewai.flow.visualization_utils import (
class FlowPlot: class FlowPlot:
"""Handles the creation and rendering of flow visualization diagrams.""" """Handles the creation and rendering of flow visualization diagrams."""
def __init__(self, flow): def __init__(self, flow) -> None:
""" """Initialize FlowPlot with a flow object.
Initialize FlowPlot with a flow object.
Parameters Parameters
---------- ----------
@@ -33,21 +31,24 @@ class FlowPlot:
------ ------
ValueError ValueError
If flow object is invalid or missing required attributes. If flow object is invalid or missing required attributes.
""" """
if not hasattr(flow, '_methods'): if not hasattr(flow, "_methods"):
raise ValueError("Invalid flow object: missing '_methods' attribute") msg = "Invalid flow object: missing '_methods' attribute"
if not hasattr(flow, '_listeners'): raise ValueError(msg)
raise ValueError("Invalid flow object: missing '_listeners' attribute") if not hasattr(flow, "_listeners"):
if not hasattr(flow, '_start_methods'): msg = "Invalid flow object: missing '_listeners' attribute"
raise ValueError("Invalid flow object: missing '_start_methods' attribute") raise ValueError(msg)
if not hasattr(flow, "_start_methods"):
msg = "Invalid flow object: missing '_start_methods' attribute"
raise ValueError(msg)
self.flow = flow self.flow = flow
self.colors = COLORS self.colors = COLORS
self.node_styles = NODE_STYLES self.node_styles = NODE_STYLES
def plot(self, filename): def plot(self, filename) -> None:
""" """Generate and save an HTML visualization of the flow.
Generate and save an HTML visualization of the flow.
Parameters Parameters
---------- ----------
@@ -62,9 +63,11 @@ class FlowPlot:
If file operations fail or visualization cannot be generated. If file operations fail or visualization cannot be generated.
RuntimeError RuntimeError
If network visualization generation fails. If network visualization generation fails.
""" """
if not filename or not isinstance(filename, str): if not filename or not isinstance(filename, str):
raise ValueError("Filename must be a non-empty string") msg = "Filename must be a non-empty string"
raise ValueError(msg)
try: try:
# Initialize network # Initialize network
@@ -89,58 +92,63 @@ class FlowPlot:
"enabled": false "enabled": false
} }
} }
""" """,
) )
# Calculate levels for nodes # Calculate levels for nodes
try: try:
node_levels = calculate_node_levels(self.flow) node_levels = calculate_node_levels(self.flow)
except Exception as e: except Exception as e:
raise ValueError(f"Failed to calculate node levels: {str(e)}") msg = f"Failed to calculate node levels: {e!s}"
raise ValueError(msg)
# Compute positions # Compute positions
try: try:
node_positions = compute_positions(self.flow, node_levels) node_positions = compute_positions(self.flow, node_levels)
except Exception as e: except Exception as e:
raise ValueError(f"Failed to compute node positions: {str(e)}") msg = f"Failed to compute node positions: {e!s}"
raise ValueError(msg)
# Add nodes to the network # Add nodes to the network
try: try:
add_nodes_to_network(net, self.flow, node_positions, self.node_styles) add_nodes_to_network(net, self.flow, node_positions, self.node_styles)
except Exception as e: except Exception as e:
raise RuntimeError(f"Failed to add nodes to network: {str(e)}") msg = f"Failed to add nodes to network: {e!s}"
raise RuntimeError(msg)
# Add edges to the network # Add edges to the network
try: try:
add_edges(net, self.flow, node_positions, self.colors) add_edges(net, self.flow, node_positions, self.colors)
except Exception as e: except Exception as e:
raise RuntimeError(f"Failed to add edges to network: {str(e)}") msg = f"Failed to add edges to network: {e!s}"
raise RuntimeError(msg)
# Generate HTML # Generate HTML
try: try:
network_html = net.generate_html() network_html = net.generate_html()
final_html_content = self._generate_final_html(network_html) final_html_content = self._generate_final_html(network_html)
except Exception as e: except Exception as e:
raise RuntimeError(f"Failed to generate network visualization: {str(e)}") msg = f"Failed to generate network visualization: {e!s}"
raise RuntimeError(msg)
# Save the final HTML content to the file # Save the final HTML content to the file
try: try:
with open(f"{filename}.html", "w", encoding="utf-8") as f: with open(f"{filename}.html", "w", encoding="utf-8") as f:
f.write(final_html_content) f.write(final_html_content)
print(f"Plot saved as {filename}.html") except OSError as e:
except IOError as e: msg = f"Failed to save flow visualization to {filename}.html: {e!s}"
raise IOError(f"Failed to save flow visualization to {filename}.html: {str(e)}") raise OSError(msg)
except (ValueError, RuntimeError, IOError) as e: except (OSError, ValueError, RuntimeError):
raise e raise
except Exception as e: except Exception as e:
raise RuntimeError(f"Unexpected error during flow visualization: {str(e)}") msg = f"Unexpected error during flow visualization: {e!s}"
raise RuntimeError(msg)
finally: finally:
self._cleanup_pyvis_lib() self._cleanup_pyvis_lib()
def _generate_final_html(self, network_html): def _generate_final_html(self, network_html):
""" """Generate the final HTML content with network visualization and legend.
Generate the final HTML content with network visualization and legend.
Parameters Parameters
---------- ----------
@@ -158,9 +166,11 @@ class FlowPlot:
If template or logo files cannot be accessed. If template or logo files cannot be accessed.
ValueError ValueError
If network_html is invalid. If network_html is invalid.
""" """
if not network_html: if not network_html:
raise ValueError("Invalid network HTML content") msg = "Invalid network HTML content"
raise ValueError(msg)
try: try:
# Extract just the body content from the generated HTML # Extract just the body content from the generated HTML
@@ -169,9 +179,11 @@ class FlowPlot:
logo_path = safe_path_join("assets", "crewai_logo.svg", root=current_dir) logo_path = safe_path_join("assets", "crewai_logo.svg", root=current_dir)
if not os.path.exists(template_path): if not os.path.exists(template_path):
raise IOError(f"Template file not found: {template_path}") msg = f"Template file not found: {template_path}"
raise OSError(msg)
if not os.path.exists(logo_path): if not os.path.exists(logo_path):
raise IOError(f"Logo file not found: {logo_path}") msg = f"Logo file not found: {logo_path}"
raise OSError(msg)
html_handler = HTMLTemplateHandler(template_path, logo_path) html_handler = HTMLTemplateHandler(template_path, logo_path)
network_body = html_handler.extract_body_content(network_html) network_body = html_handler.extract_body_content(network_html)
@@ -179,16 +191,15 @@ class FlowPlot:
# Generate the legend items HTML # Generate the legend items HTML
legend_items = get_legend_items(self.colors) legend_items = get_legend_items(self.colors)
legend_items_html = generate_legend_items_html(legend_items) legend_items_html = generate_legend_items_html(legend_items)
final_html_content = html_handler.generate_final_html( return html_handler.generate_final_html(
network_body, legend_items_html network_body, legend_items_html,
) )
return final_html_content
except Exception as e: except Exception as e:
raise IOError(f"Failed to generate visualization HTML: {str(e)}") msg = f"Failed to generate visualization HTML: {e!s}"
raise OSError(msg)
def _cleanup_pyvis_lib(self): def _cleanup_pyvis_lib(self) -> None:
""" """Clean up the generated lib folder from pyvis.
Clean up the generated lib folder from pyvis.
This method safely removes the temporary lib directory created by pyvis This method safely removes the temporary lib directory created by pyvis
during network visualization generation. during network visualization generation.
@@ -198,15 +209,14 @@ class FlowPlot:
if os.path.exists(lib_folder) and os.path.isdir(lib_folder): if os.path.exists(lib_folder) and os.path.isdir(lib_folder):
import shutil import shutil
shutil.rmtree(lib_folder) shutil.rmtree(lib_folder)
except ValueError as e: except ValueError:
print(f"Error validating lib folder path: {e}") pass
except Exception as e: except Exception:
print(f"Error cleaning up lib folder: {e}") pass
def plot_flow(flow, filename="flow_plot"): def plot_flow(flow, filename="flow_plot") -> None:
""" """Convenience function to create and save a flow visualization.
Convenience function to create and save a flow visualization.
Parameters Parameters
---------- ----------
@@ -221,6 +231,7 @@ def plot_flow(flow, filename="flow_plot"):
If flow object or filename is invalid. If flow object or filename is invalid.
IOError IOError
If file operations fail. If file operations fail.
""" """
visualizer = FlowPlot(flow) visualizer = FlowPlot(flow)
visualizer.plot(filename) visualizer.plot(filename)
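
Usage sketch for the visualizer (module path inferred from the `# flow_visualizer.py` header; assumes pyvis is installed):

```python
# Renders the flow graph and writes "<filename>.html" to the working
# directory; FlowPlot validates the flow object up front and cleans up
# pyvis's temporary lib folder afterwards.
from crewai.flow.flow import Flow, start
from crewai.flow.flow_visualizer import plot_flow


class TinyFlow(Flow):
    @start()
    def only_step(self):
        return "done"


plot_flow(TinyFlow(), filename="tiny_flow")  # -> tiny_flow.html
```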

View File

@@ -1,16 +1,14 @@
import base64 import base64
import re import re
from pathlib import Path
from crewai.flow.path_utils import safe_path_join, validate_path_exists from crewai.flow.path_utils import validate_path_exists
class HTMLTemplateHandler: class HTMLTemplateHandler:
"""Handles HTML template processing and generation for flow visualization diagrams.""" """Handles HTML template processing and generation for flow visualization diagrams."""
def __init__(self, template_path, logo_path): def __init__(self, template_path, logo_path) -> None:
""" """Initialize HTMLTemplateHandler with validated template and logo paths.
Initialize HTMLTemplateHandler with validated template and logo paths.
Parameters Parameters
---------- ----------
@@ -23,16 +21,18 @@ class HTMLTemplateHandler:
------ ------
ValueError ValueError
If template or logo paths are invalid or files don't exist. If template or logo paths are invalid or files don't exist.
""" """
try: try:
self.template_path = validate_path_exists(template_path, "file") self.template_path = validate_path_exists(template_path, "file")
self.logo_path = validate_path_exists(logo_path, "file") self.logo_path = validate_path_exists(logo_path, "file")
except ValueError as e: except ValueError as e:
raise ValueError(f"Invalid template or logo path: {e}") msg = f"Invalid template or logo path: {e}"
raise ValueError(msg)
def read_template(self): def read_template(self):
"""Read and return the HTML template file contents.""" """Read and return the HTML template file contents."""
with open(self.template_path, "r", encoding="utf-8") as f: with open(self.template_path, encoding="utf-8") as f:
return f.read() return f.read()
def encode_logo(self): def encode_logo(self):
@@ -81,13 +81,12 @@ class HTMLTemplateHandler:
final_html_content = html_template.replace("{{ title }}", title) final_html_content = html_template.replace("{{ title }}", title)
final_html_content = final_html_content.replace( final_html_content = final_html_content.replace(
"{{ network_content }}", network_body "{{ network_content }}", network_body,
) )
final_html_content = final_html_content.replace( final_html_content = final_html_content.replace(
"{{ logo_svg_base64 }}", logo_svg_base64 "{{ logo_svg_base64 }}", logo_svg_base64,
) )
final_html_content = final_html_content.replace( return final_html_content.replace(
"<!-- LEGEND_ITEMS_PLACEHOLDER -->", legend_items_html "<!-- LEGEND_ITEMS_PLACEHOLDER -->", legend_items_html,
) )
return final_html_content
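
The placeholder substitution above is plain `str.replace` chaining; a self-contained sketch with an illustrative template:

```python
# Same substitution approach as generate_final_html: each placeholder is
# replaced in turn, and the last replace's result is returned directly.
template = (
    "<html><title>{{ title }}</title>"
    "<body>{{ network_content }}<!-- LEGEND_ITEMS_PLACEHOLDER --></body></html>"
)

final_html = (
    template.replace("{{ title }}", "Flow Graph")
    .replace("{{ network_content }}", "<div id='net'>...</div>")
    .replace("<!-- LEGEND_ITEMS_PLACEHOLDER -->", "<li>Start Method</li>")
)
print(final_html)
```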

View File

@@ -1,18 +1,14 @@
""" """Path utilities for secure file operations in CrewAI flow module.
Path utilities for secure file operations in CrewAI flow module.
This module provides utilities for secure path handling to prevent directory This module provides utilities for secure path handling to prevent directory
traversal attacks and ensure paths remain within allowed boundaries. traversal attacks and ensure paths remain within allowed boundaries.
""" """
import os
from pathlib import Path from pathlib import Path
from typing import List, Union
def safe_path_join(*parts: str, root: Union[str, Path, None] = None) -> str: def safe_path_join(*parts: str, root: str | Path | None = None) -> str:
""" """Safely join path components and ensure the result is within allowed boundaries.
Safely join path components and ensure the result is within allowed boundaries.
Parameters Parameters
---------- ----------
@@ -31,15 +27,18 @@ def safe_path_join(*parts: str, root: Union[str, Path, None] = None) -> str:
ValueError ValueError
If the resulting path would be outside the root directory If the resulting path would be outside the root directory
or if any path component is invalid. or if any path component is invalid.
""" """
if not parts: if not parts:
raise ValueError("No path components provided") msg = "No path components provided"
raise ValueError(msg)
try: try:
# Convert all parts to strings and clean them # Convert all parts to strings and clean them
clean_parts = [str(part).strip() for part in parts if part] clean_parts = [str(part).strip() for part in parts if part]
if not clean_parts: if not clean_parts:
raise ValueError("No valid path components provided") msg = "No valid path components provided"
raise ValueError(msg)
# Establish root directory # Establish root directory
root_path = Path(root).resolve() if root else Path.cwd() root_path = Path(root).resolve() if root else Path.cwd()
@@ -49,8 +48,9 @@ def safe_path_join(*parts: str, root: Union[str, Path, None] = None) -> str:
# Check if the resolved path is within root # Check if the resolved path is within root
if not str(full_path).startswith(str(root_path)): if not str(full_path).startswith(str(root_path)):
msg = f"Invalid path: Potential directory traversal. Path must be within {root_path}"
raise ValueError( raise ValueError(
f"Invalid path: Potential directory traversal. Path must be within {root_path}" msg,
) )
return str(full_path) return str(full_path)
@@ -58,12 +58,12 @@ def safe_path_join(*parts: str, root: Union[str, Path, None] = None) -> str:
except Exception as e: except Exception as e:
if isinstance(e, ValueError): if isinstance(e, ValueError):
raise raise
raise ValueError(f"Invalid path components: {str(e)}") msg = f"Invalid path components: {e!s}"
raise ValueError(msg)
def validate_path_exists(path: Union[str, Path], file_type: str = "file") -> str: def validate_path_exists(path: str | Path, file_type: str = "file") -> str:
""" """Validate that a path exists and is of the expected type.
Validate that a path exists and is of the expected type.
Parameters Parameters
---------- ----------
@@ -81,29 +81,33 @@ def validate_path_exists(path: Union[str, Path], file_type: str = "file") -> str
------ ------
ValueError ValueError
If path doesn't exist or is not of expected type. If path doesn't exist or is not of expected type.
""" """
try: try:
path_obj = Path(path).resolve() path_obj = Path(path).resolve()
if not path_obj.exists(): if not path_obj.exists():
raise ValueError(f"Path does not exist: {path}") msg = f"Path does not exist: {path}"
raise ValueError(msg)
if file_type == "file" and not path_obj.is_file(): if file_type == "file" and not path_obj.is_file():
raise ValueError(f"Path is not a file: {path}") msg = f"Path is not a file: {path}"
elif file_type == "directory" and not path_obj.is_dir(): raise ValueError(msg)
raise ValueError(f"Path is not a directory: {path}") if file_type == "directory" and not path_obj.is_dir():
msg = f"Path is not a directory: {path}"
raise ValueError(msg)
return str(path_obj) return str(path_obj)
except Exception as e: except Exception as e:
if isinstance(e, ValueError): if isinstance(e, ValueError):
raise raise
raise ValueError(f"Invalid path: {str(e)}") msg = f"Invalid path: {e!s}"
raise ValueError(msg)
def list_files(directory: Union[str, Path], pattern: str = "*") -> List[str]: def list_files(directory: str | Path, pattern: str = "*") -> list[str]:
""" """Safely list files in a directory matching a pattern.
Safely list files in a directory matching a pattern.
Parameters Parameters
---------- ----------
@@ -121,15 +125,18 @@ def list_files(directory: Union[str, Path], pattern: str = "*") -> List[str]:
------ ------
ValueError ValueError
If directory is invalid or inaccessible. If directory is invalid or inaccessible.
""" """
try: try:
dir_path = Path(directory).resolve() dir_path = Path(directory).resolve()
if not dir_path.is_dir(): if not dir_path.is_dir():
raise ValueError(f"Not a directory: {directory}") msg = f"Not a directory: {directory}"
raise ValueError(msg)
return [str(p) for p in dir_path.glob(pattern) if p.is_file()] return [str(p) for p in dir_path.glob(pattern) if p.is_file()]
except Exception as e: except Exception as e:
if isinstance(e, ValueError): if isinstance(e, ValueError):
raise raise
raise ValueError(f"Error listing files: {str(e)}") msg = f"Error listing files: {e!s}"
raise ValueError(msg)
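
A self-contained sketch of the containment test `safe_path_join` applies: resolve the joined path, then prefix-check it against the root.

```python
from pathlib import Path

root = Path("/srv/app").resolve()
# "assets/../../etc/passwd" resolves to /srv/etc/passwd, escaping the root.
candidate = (root / "assets" / ".." / ".." / "etc" / "passwd").resolve()

# The module raises ValueError on this check; printing here keeps the
# sketch side-effect free.
if not str(candidate).startswith(str(root)):
    print(f"rejected: {candidate} escapes {root}")
```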

View File

@@ -1,7 +1,7 @@
"""Base class for flow state persistence.""" """Base class for flow state persistence."""
import abc import abc
from typing import Any, Dict, Optional, Union from typing import Any
from pydantic import BaseModel from pydantic import BaseModel
@@ -22,14 +22,13 @@ class FlowPersistence(abc.ABC):
- Establishing connections - Establishing connections
- Setting up indexes - Setting up indexes
""" """
pass
@abc.abstractmethod @abc.abstractmethod
def save_state( def save_state(
self, self,
flow_uuid: str, flow_uuid: str,
method_name: str, method_name: str,
state_data: Union[Dict[str, Any], BaseModel] state_data: dict[str, Any] | BaseModel,
) -> None: ) -> None:
"""Persist the flow state after method completion. """Persist the flow state after method completion.
@@ -37,11 +36,11 @@ class FlowPersistence(abc.ABC):
flow_uuid: Unique identifier for the flow instance flow_uuid: Unique identifier for the flow instance
method_name: Name of the method that just completed method_name: Name of the method that just completed
state_data: Current state data (either dict or Pydantic model) state_data: Current state data (either dict or Pydantic model)
""" """
pass
@abc.abstractmethod @abc.abstractmethod
def load_state(self, flow_uuid: str) -> Optional[Dict[str, Any]]: def load_state(self, flow_uuid: str) -> dict[str, Any] | None:
"""Load the most recent state for a given flow UUID. """Load the most recent state for a given flow UUID.
Args: Args:
@@ -49,5 +48,5 @@ class FlowPersistence(abc.ABC):
Returns: Returns:
The most recent state as a dictionary, or None if no state exists The most recent state as a dictionary, or None if no state exists
""" """
pass
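
An in-memory implementation of the interface above, for illustration; the `crewai.flow.persistence.base` import path is assumed from this diff's file layout:

```python
from typing import Any

from pydantic import BaseModel

from crewai.flow.persistence.base import FlowPersistence


class InMemoryFlowPersistence(FlowPersistence):
    """Keeps only the latest state per flow UUID; useful for tests."""

    def __init__(self) -> None:
        self._store: dict[str, dict[str, Any]] = {}

    def init_db(self) -> None:
        self._store.clear()  # nothing to connect or index in memory

    def save_state(
        self,
        flow_uuid: str,
        method_name: str,
        state_data: dict[str, Any] | BaseModel,
    ) -> None:
        data = (
            state_data.model_dump()
            if isinstance(state_data, BaseModel)
            else state_data
        )
        self._store[flow_uuid] = dict(data)

    def load_state(self, flow_uuid: str) -> dict[str, Any] | None:
        return self._store.get(flow_uuid)
```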

View File

@@ -1,5 +1,4 @@
""" """Decorators for flow state persistence.
Decorators for flow state persistence.
Example: Example:
```python ```python
@@ -19,18 +18,16 @@ Example:
# Asynchronous method implementation # Asynchronous method implementation
await some_async_operation() await some_async_operation()
``` ```
""" """
import asyncio import asyncio
import functools import functools
import logging import logging
from collections.abc import Callable
from typing import ( from typing import (
Any, Any,
Callable,
Optional,
Type,
TypeVar, TypeVar,
Union,
cast, cast,
) )
@@ -48,7 +45,7 @@ LOG_MESSAGES = {
"save_state": "Saving flow state to memory for ID: {}", "save_state": "Saving flow state to memory for ID: {}",
"save_error": "Failed to persist state for method {}: {}", "save_error": "Failed to persist state for method {}: {}",
"state_missing": "Flow instance has no state", "state_missing": "Flow instance has no state",
"id_missing": "Flow state must have an 'id' field for persistence" "id_missing": "Flow state must have an 'id' field for persistence",
} }
@@ -74,20 +71,23 @@ class PersistenceDecorator:
ValueError: If flow has no state or state lacks an ID ValueError: If flow has no state or state lacks an ID
RuntimeError: If state persistence fails RuntimeError: If state persistence fails
AttributeError: If flow instance lacks required state attributes AttributeError: If flow instance lacks required state attributes
""" """
try: try:
state = getattr(flow_instance, 'state', None) state = getattr(flow_instance, "state", None)
if state is None: if state is None:
raise ValueError("Flow instance has no state") msg = "Flow instance has no state"
raise ValueError(msg)
flow_uuid: Optional[str] = None flow_uuid: str | None = None
if isinstance(state, dict): if isinstance(state, dict):
flow_uuid = state.get('id') flow_uuid = state.get("id")
elif isinstance(state, BaseModel): elif isinstance(state, BaseModel):
flow_uuid = getattr(state, 'id', None) flow_uuid = getattr(state, "id", None)
if not flow_uuid: if not flow_uuid:
raise ValueError("Flow state must have an 'id' field for persistence") msg = "Flow state must have an 'id' field for persistence"
raise ValueError(msg)
# Log state saving only if verbose is True # Log state saving only if verbose is True
if verbose: if verbose:
@@ -103,21 +103,22 @@ class PersistenceDecorator:
except Exception as e: except Exception as e:
error_msg = LOG_MESSAGES["save_error"].format(method_name, str(e)) error_msg = LOG_MESSAGES["save_error"].format(method_name, str(e))
cls._printer.print(error_msg, color="red") cls._printer.print(error_msg, color="red")
logger.error(error_msg) logger.exception(error_msg)
raise RuntimeError(f"State persistence failed: {str(e)}") from e msg = f"State persistence failed: {e!s}"
raise RuntimeError(msg) from e
except AttributeError: except AttributeError:
error_msg = LOG_MESSAGES["state_missing"] error_msg = LOG_MESSAGES["state_missing"]
cls._printer.print(error_msg, color="red") cls._printer.print(error_msg, color="red")
logger.error(error_msg) logger.exception(error_msg)
raise ValueError(error_msg) raise ValueError(error_msg)
except (TypeError, ValueError) as e: except (TypeError, ValueError) as e:
error_msg = LOG_MESSAGES["id_missing"] error_msg = LOG_MESSAGES["id_missing"]
cls._printer.print(error_msg, color="red") cls._printer.print(error_msg, color="red")
logger.error(error_msg) logger.exception(error_msg)
raise ValueError(error_msg) from e raise ValueError(error_msg) from e
def persist(persistence: Optional[FlowPersistence] = None, verbose: bool = False): def persist(persistence: FlowPersistence | None = None, verbose: bool = False):
"""Decorator to persist flow state. """Decorator to persist flow state.
This decorator can be applied at either the class level or method level. This decorator can be applied at either the class level or method level.
@@ -143,22 +144,23 @@ def persist(persistence: Optional[FlowPersistence] = None, verbose: bool = False
@start() @start()
def begin(self): def begin(self):
pass pass
""" """
def decorator(target: Union[Type, Callable[..., T]]) -> Union[Type, Callable[..., T]]: def decorator(target: type | Callable[..., T]) -> type | Callable[..., T]:
"""Decorator that handles both class and method decoration.""" """Decorator that handles both class and method decoration."""
actual_persistence = persistence or SQLiteFlowPersistence() actual_persistence = persistence or SQLiteFlowPersistence()
if isinstance(target, type): if isinstance(target, type):
# Class decoration # Class decoration
original_init = getattr(target, "__init__") original_init = target.__init__
@functools.wraps(original_init) @functools.wraps(original_init)
def new_init(self: Any, *args: Any, **kwargs: Any) -> None: def new_init(self: Any, *args: Any, **kwargs: Any) -> None:
if 'persistence' not in kwargs: if "persistence" not in kwargs:
kwargs['persistence'] = actual_persistence kwargs["persistence"] = actual_persistence
original_init(self, *args, **kwargs) original_init(self, *args, **kwargs)
setattr(target, "__init__", new_init) target.__init__ = new_init
# Store original methods to preserve their decorators # Store original methods to preserve their decorators
original_methods = {} original_methods = {}
@@ -191,7 +193,7 @@ def persist(persistence: Optional[FlowPersistence] = None, verbose: bool = False
for attr in ["__is_start_method__", "__trigger_methods__", "__condition_type__", "__is_router__"]: for attr in ["__is_start_method__", "__trigger_methods__", "__condition_type__", "__is_router__"]:
if hasattr(method, attr): if hasattr(method, attr):
setattr(wrapped, attr, getattr(method, attr)) setattr(wrapped, attr, getattr(method, attr))
setattr(wrapped, "__is_flow_method__", True) wrapped.__is_flow_method__ = True
# Update the class with the wrapped method # Update the class with the wrapped method
setattr(target, name, wrapped) setattr(target, name, wrapped)
@@ -211,44 +213,42 @@ def persist(persistence: Optional[FlowPersistence] = None, verbose: bool = False
for attr in ["__is_start_method__", "__trigger_methods__", "__condition_type__", "__is_router__"]: for attr in ["__is_start_method__", "__trigger_methods__", "__condition_type__", "__is_router__"]:
if hasattr(method, attr): if hasattr(method, attr):
setattr(wrapped, attr, getattr(method, attr)) setattr(wrapped, attr, getattr(method, attr))
setattr(wrapped, "__is_flow_method__", True) wrapped.__is_flow_method__ = True
# Update the class with the wrapped method # Update the class with the wrapped method
setattr(target, name, wrapped) setattr(target, name, wrapped)
return target return target
else: # Method decoration
# Method decoration method = target
method = target method.__is_flow_method__ = True
setattr(method, "__is_flow_method__", True)
if asyncio.iscoroutinefunction(method): if asyncio.iscoroutinefunction(method):
@functools.wraps(method) @functools.wraps(method)
async def method_async_wrapper(flow_instance: Any, *args: Any, **kwargs: Any) -> T: async def method_async_wrapper(flow_instance: Any, *args: Any, **kwargs: Any) -> T:
method_coro = method(flow_instance, *args, **kwargs) method_coro = method(flow_instance, *args, **kwargs)
if asyncio.iscoroutine(method_coro): if asyncio.iscoroutine(method_coro):
result = await method_coro result = await method_coro
else: else:
result = method_coro result = method_coro
PersistenceDecorator.persist_state(flow_instance, method.__name__, actual_persistence, verbose) PersistenceDecorator.persist_state(flow_instance, method.__name__, actual_persistence, verbose)
return result return result
for attr in ["__is_start_method__", "__trigger_methods__", "__condition_type__", "__is_router__"]: for attr in ["__is_start_method__", "__trigger_methods__", "__condition_type__", "__is_router__"]:
if hasattr(method, attr): if hasattr(method, attr):
setattr(method_async_wrapper, attr, getattr(method, attr)) setattr(method_async_wrapper, attr, getattr(method, attr))
setattr(method_async_wrapper, "__is_flow_method__", True) method_async_wrapper.__is_flow_method__ = True
return cast(Callable[..., T], method_async_wrapper) return cast("Callable[..., T]", method_async_wrapper)
else: @functools.wraps(method)
@functools.wraps(method) def method_sync_wrapper(flow_instance: Any, *args: Any, **kwargs: Any) -> T:
def method_sync_wrapper(flow_instance: Any, *args: Any, **kwargs: Any) -> T: result = method(flow_instance, *args, **kwargs)
result = method(flow_instance, *args, **kwargs) PersistenceDecorator.persist_state(flow_instance, method.__name__, actual_persistence, verbose)
PersistenceDecorator.persist_state(flow_instance, method.__name__, actual_persistence, verbose) return result
return result
for attr in ["__is_start_method__", "__trigger_methods__", "__condition_type__", "__is_router__"]: for attr in ["__is_start_method__", "__trigger_methods__", "__condition_type__", "__is_router__"]:
if hasattr(method, attr): if hasattr(method, attr):
setattr(method_sync_wrapper, attr, getattr(method, attr)) setattr(method_sync_wrapper, attr, getattr(method, attr))
setattr(method_sync_wrapper, "__is_flow_method__", True) method_sync_wrapper.__is_flow_method__ = True
return cast(Callable[..., T], method_sync_wrapper) return cast("Callable[..., T]", method_sync_wrapper)
return decorator return decorator
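
A class-level usage sketch of the decorator above; it defaults to `SQLiteFlowPersistence` when no backend is supplied, and the import paths follow this diff:

```python
from crewai.flow.flow import Flow, start
from crewai.flow.persistence import persist


@persist(verbose=True)  # every flow method now saves state on completion
class CounterFlow(Flow):
    @start()
    def bump(self):
        self.state["count"] = self.state.get("count", 0) + 1
        return self.state["count"]


flow = CounterFlow()
flow.kickoff()
# Kicking off again with inputs={"id": <same uuid>} would restore "count"
# from the SQLite store before running.
```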

View File

@@ -1,12 +1,10 @@
""" """SQLite-based implementation of flow state persistence."""
SQLite-based implementation of flow state persistence.
"""
import json import json
import sqlite3 import sqlite3
from datetime import datetime, timezone from datetime import datetime, timezone
from pathlib import Path from pathlib import Path
from typing import Any, Dict, Optional, Union from typing import Any
from pydantic import BaseModel from pydantic import BaseModel
@@ -23,7 +21,7 @@ class SQLiteFlowPersistence(FlowPersistence):
db_path: str db_path: str
def __init__(self, db_path: Optional[str] = None): def __init__(self, db_path: str | None = None) -> None:
"""Initialize SQLite persistence. """Initialize SQLite persistence.
Args: Args:
@@ -32,6 +30,7 @@ class SQLiteFlowPersistence(FlowPersistence):
Raises: Raises:
ValueError: If db_path is invalid ValueError: If db_path is invalid
""" """
from crewai.utilities.paths import db_storage_path from crewai.utilities.paths import db_storage_path
@@ -39,7 +38,8 @@ class SQLiteFlowPersistence(FlowPersistence):
path = db_path or str(Path(db_storage_path()) / "flow_states.db") path = db_path or str(Path(db_storage_path()) / "flow_states.db")
if not path: if not path:
raise ValueError("Database path must be provided") msg = "Database path must be provided"
raise ValueError(msg)
self.db_path = path # Now mypy knows this is str self.db_path = path # Now mypy knows this is str
self.init_db() self.init_db()
@@ -56,21 +56,21 @@ class SQLiteFlowPersistence(FlowPersistence):
timestamp DATETIME NOT NULL, timestamp DATETIME NOT NULL,
state_json TEXT NOT NULL state_json TEXT NOT NULL
) )
""" """,
) )
# Add index for faster UUID lookups # Add index for faster UUID lookups
conn.execute( conn.execute(
""" """
CREATE INDEX IF NOT EXISTS idx_flow_states_uuid CREATE INDEX IF NOT EXISTS idx_flow_states_uuid
ON flow_states(flow_uuid) ON flow_states(flow_uuid)
""" """,
) )
def save_state( def save_state(
self, self,
flow_uuid: str, flow_uuid: str,
method_name: str, method_name: str,
state_data: Union[Dict[str, Any], BaseModel], state_data: dict[str, Any] | BaseModel,
) -> None: ) -> None:
"""Save the current flow state to SQLite. """Save the current flow state to SQLite.
@@ -78,6 +78,7 @@ class SQLiteFlowPersistence(FlowPersistence):
flow_uuid: Unique identifier for the flow instance flow_uuid: Unique identifier for the flow instance
method_name: Name of the method that just completed method_name: Name of the method that just completed
state_data: Current state data (either dict or Pydantic model) state_data: Current state data (either dict or Pydantic model)
""" """
# Convert state_data to dict, handling both Pydantic and dict cases # Convert state_data to dict, handling both Pydantic and dict cases
if isinstance(state_data, BaseModel): if isinstance(state_data, BaseModel):
@@ -85,8 +86,9 @@ class SQLiteFlowPersistence(FlowPersistence):
elif isinstance(state_data, dict): elif isinstance(state_data, dict):
state_dict = state_data state_dict = state_data
else: else:
msg = f"state_data must be either a Pydantic BaseModel or dict, got {type(state_data)}"
raise ValueError( raise ValueError(
f"state_data must be either a Pydantic BaseModel or dict, got {type(state_data)}" msg,
) )
with sqlite3.connect(self.db_path) as conn: with sqlite3.connect(self.db_path) as conn:
@@ -107,7 +109,7 @@ class SQLiteFlowPersistence(FlowPersistence):
), ),
) )
def load_state(self, flow_uuid: str) -> Optional[Dict[str, Any]]: def load_state(self, flow_uuid: str) -> dict[str, Any] | None:
"""Load the most recent state for a given flow UUID. """Load the most recent state for a given flow UUID.
Args: Args:
@@ -115,6 +117,7 @@ class SQLiteFlowPersistence(FlowPersistence):
Returns: Returns:
The most recent state as a dictionary, or None if no state exists The most recent state as a dictionary, or None if no state exists
""" """
with sqlite3.connect(self.db_path) as conn: with sqlite3.connect(self.db_path) as conn:
cursor = conn.execute( cursor = conn.execute(
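
Direct use of the SQLite backend, assuming a writable `db_path`; `load_state` returns the most recent row for the UUID:

```python
from crewai.flow.persistence import SQLiteFlowPersistence

db = SQLiteFlowPersistence(db_path="/tmp/flow_states.db")
db.save_state(
    flow_uuid="demo-uuid",
    method_name="begin",
    state_data={"id": "demo-uuid", "step": 1},  # dicts and BaseModels both work
)
print(db.load_state("demo-uuid"))  # {'id': 'demo-uuid', 'step': 1}
```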

View File

@@ -1,33 +1,32 @@
""" """Utility functions for flow visualization and dependency analysis.
Utility functions for flow visualization and dependency analysis.
This module provides core functionality for analyzing and manipulating flow structures, This module provides core functionality for analyzing and manipulating flow structures,
including node level calculation, ancestor tracking, and return value analysis. including node level calculation, ancestor tracking, and return value analysis.
Functions in this module are primarily used by the visualization system to create Functions in this module are primarily used by the visualization system to create
accurate and informative flow diagrams. accurate and informative flow diagrams.
Example Example:
------- -------
>>> flow = Flow() >>> flow = Flow()
>>> node_levels = calculate_node_levels(flow) >>> node_levels = calculate_node_levels(flow)
>>> ancestors = build_ancestor_dict(flow) >>> ancestors = build_ancestor_dict(flow)
""" """
import ast import ast
import inspect import inspect
import textwrap import textwrap
from collections import defaultdict, deque from collections import defaultdict, deque
from typing import Any, Deque, Dict, List, Optional, Set, Union from typing import Any
def get_possible_return_constants(function: Any) -> Optional[List[str]]: def get_possible_return_constants(function: Any) -> list[str] | None:
try: try:
source = inspect.getsource(function) source = inspect.getsource(function)
except OSError: except OSError:
# Can't get source code # Can't get source code
return None return None
except Exception as e: except Exception:
print(f"Error retrieving source code for function {function.__name__}: {e}")
return None return None
try: try:
@@ -35,24 +34,18 @@ def get_possible_return_constants(function: Any) -> Optional[List[str]]:
source = textwrap.dedent(source) source = textwrap.dedent(source)
# Parse the source code into an AST # Parse the source code into an AST
code_ast = ast.parse(source) code_ast = ast.parse(source)
except IndentationError as e: except IndentationError:
print(f"IndentationError while parsing source code of {function.__name__}: {e}")
print(f"Source code:\n{source}")
return None return None
except SyntaxError as e: except SyntaxError:
print(f"SyntaxError while parsing source code of {function.__name__}: {e}")
print(f"Source code:\n{source}")
return None return None
except Exception as e: except Exception:
print(f"Unexpected error while parsing source code of {function.__name__}: {e}")
print(f"Source code:\n{source}")
return None return None
return_values = set() return_values = set()
dict_definitions = {} dict_definitions = {}
class DictionaryAssignmentVisitor(ast.NodeVisitor): class DictionaryAssignmentVisitor(ast.NodeVisitor):
def visit_Assign(self, node): def visit_Assign(self, node) -> None:
# Check if this assignment is assigning a dictionary literal to a variable # Check if this assignment is assigning a dictionary literal to a variable
if isinstance(node.value, ast.Dict) and len(node.targets) == 1: if isinstance(node.value, ast.Dict) and len(node.targets) == 1:
target = node.targets[0] target = node.targets[0]
@@ -69,10 +62,10 @@ def get_possible_return_constants(function: Any) -> Optional[List[str]]:
self.generic_visit(node) self.generic_visit(node)
class ReturnVisitor(ast.NodeVisitor): class ReturnVisitor(ast.NodeVisitor):
def visit_Return(self, node): def visit_Return(self, node) -> None:
# Direct string return # Direct string return
if isinstance(node.value, ast.Constant) and isinstance( if isinstance(node.value, ast.Constant) and isinstance(
node.value.value, str node.value.value, str,
): ):
return_values.add(node.value.value) return_values.add(node.value.value)
# Dictionary-based return, like return paths[result] # Dictionary-based return, like return paths[result]
@@ -94,9 +87,8 @@ def get_possible_return_constants(function: Any) -> Optional[List[str]]:
return list(return_values) if return_values else None return list(return_values) if return_values else None
def calculate_node_levels(flow: Any) -> Dict[str, int]: def calculate_node_levels(flow: Any) -> dict[str, int]:
""" """Calculate the hierarchical level of each node in the flow.
Calculate the hierarchical level of each node in the flow.
Performs a breadth-first traversal of the flow graph to assign levels Performs a breadth-first traversal of the flow graph to assign levels
to nodes, starting with start methods at level 0. to nodes, starting with start methods at level 0.
@@ -117,11 +109,12 @@ def calculate_node_levels(flow: Any) -> Dict[str, int]:
- Each subsequent connected node is assigned level = parent_level + 1 - Each subsequent connected node is assigned level = parent_level + 1
- Handles both OR and AND conditions for listeners - Handles both OR and AND conditions for listeners
- Processes router paths separately - Processes router paths separately
""" """
levels: Dict[str, int] = {} levels: dict[str, int] = {}
queue: Deque[str] = deque() queue: deque[str] = deque()
visited: Set[str] = set() visited: set[str] = set()
pending_and_listeners: Dict[str, Set[str]] = {} pending_and_listeners: dict[str, set[str]] = {}
# Make all start methods at level 0 # Make all start methods at level 0
for method_name, method in flow._methods.items(): for method_name, method in flow._methods.items():
@@ -172,9 +165,8 @@ def calculate_node_levels(flow: Any) -> Dict[str, int]:
return levels return levels
def count_outgoing_edges(flow: Any) -> Dict[str, int]: def count_outgoing_edges(flow: Any) -> dict[str, int]:
""" """Count the number of outgoing edges for each method in the flow.
Count the number of outgoing edges for each method in the flow.
Parameters Parameters
---------- ----------
@@ -185,6 +177,7 @@ def count_outgoing_edges(flow: Any) -> Dict[str, int]:
------- -------
Dict[str, int] Dict[str, int]
Dictionary mapping method names to their outgoing edge count. Dictionary mapping method names to their outgoing edge count.
""" """
counts = {} counts = {}
for method_name in flow._methods: for method_name in flow._methods:
@@ -197,9 +190,8 @@ def count_outgoing_edges(flow: Any) -> Dict[str, int]:
return counts return counts
def build_ancestor_dict(flow: Any) -> Dict[str, Set[str]]: def build_ancestor_dict(flow: Any) -> dict[str, set[str]]:
""" """Build a dictionary mapping each node to its ancestor nodes.
Build a dictionary mapping each node to its ancestor nodes.
Parameters Parameters
---------- ----------
@@ -210,9 +202,10 @@ def build_ancestor_dict(flow: Any) -> Dict[str, Set[str]]:
------- -------
Dict[str, Set[str]] Dict[str, Set[str]]
Dictionary mapping each node to a set of its ancestor nodes. Dictionary mapping each node to a set of its ancestor nodes.
""" """
ancestors: Dict[str, Set[str]] = {node: set() for node in flow._methods} ancestors: dict[str, set[str]] = {node: set() for node in flow._methods}
visited: Set[str] = set() visited: set[str] = set()
for node in flow._methods: for node in flow._methods:
if node not in visited: if node not in visited:
dfs_ancestors(node, ancestors, visited, flow) dfs_ancestors(node, ancestors, visited, flow)
@@ -220,10 +213,9 @@ def build_ancestor_dict(flow: Any) -> Dict[str, Set[str]]:
def dfs_ancestors( def dfs_ancestors(
node: str, ancestors: Dict[str, Set[str]], visited: Set[str], flow: Any node: str, ancestors: dict[str, set[str]], visited: set[str], flow: Any,
) -> None: ) -> None:
""" """Perform depth-first search to build ancestor relationships.
Perform depth-first search to build ancestor relationships.
Parameters Parameters
---------- ----------
@@ -240,6 +232,7 @@ def dfs_ancestors(
----- -----
This function modifies the ancestors dictionary in-place to build This function modifies the ancestors dictionary in-place to build
the complete ancestor graph. the complete ancestor graph.
""" """
if node in visited: if node in visited:
return return
@@ -265,10 +258,9 @@ def dfs_ancestors(
def is_ancestor( def is_ancestor(
node: str, ancestor_candidate: str, ancestors: Dict[str, Set[str]] node: str, ancestor_candidate: str, ancestors: dict[str, set[str]],
) -> bool: ) -> bool:
""" """Check if one node is an ancestor of another.
Check if one node is an ancestor of another.
Parameters Parameters
---------- ----------
@@ -283,13 +275,13 @@ def is_ancestor(
------- -------
bool bool
True if ancestor_candidate is an ancestor of node, False otherwise. True if ancestor_candidate is an ancestor of node, False otherwise.
""" """
return ancestor_candidate in ancestors.get(node, set()) return ancestor_candidate in ancestors.get(node, set())
def build_parent_children_dict(flow: Any) -> Dict[str, List[str]]: def build_parent_children_dict(flow: Any) -> dict[str, list[str]]:
""" """Build a dictionary mapping parent nodes to their children.
Build a dictionary mapping parent nodes to their children.
Parameters Parameters
---------- ----------
@@ -306,8 +298,9 @@ def build_parent_children_dict(flow: Any) -> Dict[str, List[str]]:
- Maps listeners to their trigger methods - Maps listeners to their trigger methods
- Maps router methods to their paths and listeners - Maps router methods to their paths and listeners
- Children lists are sorted for consistent ordering - Children lists are sorted for consistent ordering
""" """
parent_children: Dict[str, List[str]] = {} parent_children: dict[str, list[str]] = {}
# Map listeners to their trigger methods # Map listeners to their trigger methods
for listener_name, (_, trigger_methods) in flow._listeners.items(): for listener_name, (_, trigger_methods) in flow._listeners.items():
@@ -332,10 +325,9 @@ def build_parent_children_dict(flow: Any) -> Dict[str, List[str]]:
def get_child_index( def get_child_index(
parent: str, child: str, parent_children: Dict[str, List[str]] parent: str, child: str, parent_children: dict[str, list[str]],
) -> int: ) -> int:
""" """Get the index of a child node in its parent's sorted children list.
Get the index of a child node in its parent's sorted children list.
Parameters Parameters
---------- ----------
@@ -350,27 +342,25 @@ def get_child_index(
------- -------
int int
Zero-based index of the child in its parent's sorted children list. Zero-based index of the child in its parent's sorted children list.
""" """
children = parent_children.get(parent, []) children = parent_children.get(parent, [])
children.sort() children.sort()
return children.index(child) return children.index(child)
def process_router_paths(flow, current, current_level, levels, queue): def process_router_paths(flow, current, current_level, levels, queue) -> None:
""" """Handle the router connections for the current node."""
Handle the router connections for the current node.
"""
if current in flow._routers: if current in flow._routers:
paths = flow._router_paths.get(current, []) paths = flow._router_paths.get(current, [])
for path in paths: for path in paths:
for listener_name, ( for listener_name, (
condition_type, _condition_type,
trigger_methods, trigger_methods,
) in flow._listeners.items(): ) in flow._listeners.items():
if path in trigger_methods: if path in trigger_methods and (
if ( listener_name not in levels
listener_name not in levels or levels[listener_name] > current_level + 1
or levels[listener_name] > current_level + 1 ):
): levels[listener_name] = current_level + 1
levels[listener_name] = current_level + 1 queue.append(listener_name)
queue.append(listener_name)
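
A standalone sketch of the AST walk `get_possible_return_constants` performs: collect string literals returned by a router method (the dictionary-based return handling is the part not reproduced here):

```python
import ast
import inspect
import textwrap


def route(result):
    if result > 0:
        return "success"
    return "retry"


tree = ast.parse(textwrap.dedent(inspect.getsource(route)))
returns = {
    node.value.value
    for node in ast.walk(tree)
    if isinstance(node, ast.Return)
    and isinstance(node.value, ast.Constant)
    and isinstance(node.value.value, str)
}
print(returns)  # {'success', 'retry'} (set order varies)
```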

View File

@@ -1,23 +1,23 @@
""" """Utilities for creating visual representations of flow structures.
Utilities for creating visual representations of flow structures.
This module provides functions for generating network visualizations of flows, This module provides functions for generating network visualizations of flows,
including node placement, edge creation, and visual styling. It handles the including node placement, edge creation, and visual styling. It handles the
conversion of flow structures into visual network graphs with appropriate conversion of flow structures into visual network graphs with appropriate
styling and layout. styling and layout.
Example Example:
------- -------
>>> flow = Flow() >>> flow = Flow()
>>> net = Network(directed=True) >>> net = Network(directed=True)
>>> node_positions = compute_positions(flow, node_levels) >>> node_positions = compute_positions(flow, node_levels)
>>> add_nodes_to_network(net, flow, node_positions, node_styles) >>> add_nodes_to_network(net, flow, node_positions, node_styles)
>>> add_edges(net, flow, node_positions, colors) >>> add_edges(net, flow, node_positions, colors)
""" """
import ast import ast
import inspect import inspect
from typing import Any, Dict, List, Optional, Tuple, Union from typing import Any
from .utils import ( from .utils import (
build_ancestor_dict, build_ancestor_dict,
@@ -28,8 +28,7 @@ from .utils import (
def method_calls_crew(method: Any) -> bool: def method_calls_crew(method: Any) -> bool:
""" """Check if the method contains a call to `.crew()`.
Check if the method contains a call to `.crew()`.
Parameters Parameters
---------- ----------
@@ -45,21 +44,22 @@ def method_calls_crew(method: Any) -> bool:
----- -----
Uses AST analysis to detect method calls, specifically looking for Uses AST analysis to detect method calls, specifically looking for
attribute access of 'crew'. attribute access of 'crew'.
""" """
try: try:
source = inspect.getsource(method) source = inspect.getsource(method)
source = inspect.cleandoc(source) source = inspect.cleandoc(source)
tree = ast.parse(source) tree = ast.parse(source)
except Exception as e: except Exception:
print(f"Could not parse method {method.__name__}: {e}")
return False return False
class CrewCallVisitor(ast.NodeVisitor): class CrewCallVisitor(ast.NodeVisitor):
"""AST visitor to detect .crew() method calls.""" """AST visitor to detect .crew() method calls."""
def __init__(self):
def __init__(self) -> None:
self.found = False self.found = False
def visit_Call(self, node): def visit_Call(self, node) -> None:
if isinstance(node.func, ast.Attribute): if isinstance(node.func, ast.Attribute):
if node.func.attr == "crew": if node.func.attr == "crew":
self.found = True self.found = True
@@ -73,11 +73,10 @@ def method_calls_crew(method: Any) -> bool:
def add_nodes_to_network( def add_nodes_to_network(
net: Any, net: Any,
flow: Any, flow: Any,
node_positions: Dict[str, Tuple[float, float]], node_positions: dict[str, tuple[float, float]],
node_styles: Dict[str, Dict[str, Any]] node_styles: dict[str, dict[str, Any]],
) -> None: ) -> None:
""" """Add nodes to the network visualization with appropriate styling.
Add nodes to the network visualization with appropriate styling.
Parameters Parameters
---------- ----------
@@ -97,6 +96,7 @@ def add_nodes_to_network(
- Router methods - Router methods
- Crew methods - Crew methods
- Regular methods - Regular methods
""" """
def human_friendly_label(method_name): def human_friendly_label(method_name):
return method_name.replace("_", " ").title() return method_name.replace("_", " ").title()
@@ -123,7 +123,7 @@ def add_nodes_to_network(
"multi": "html", "multi": "html",
"color": node_style.get("font", {}).get("color", "#FFFFFF"), "color": node_style.get("font", {}).get("color", "#FFFFFF"),
}, },
} },
) )
net.add_node( net.add_node(
@@ -138,12 +138,11 @@ def add_nodes_to_network(
def compute_positions( def compute_positions(
flow: Any, flow: Any,
node_levels: Dict[str, int], node_levels: dict[str, int],
y_spacing: float = 150, y_spacing: float = 150,
x_spacing: float = 150 x_spacing: float = 150,
) -> Dict[str, Tuple[float, float]]: ) -> dict[str, tuple[float, float]]:
""" """Compute the (x, y) positions for each node in the flow graph.
Compute the (x, y) positions for each node in the flow graph.
Parameters Parameters
---------- ----------
@@ -160,9 +159,10 @@ def compute_positions(
------- -------
Dict[str, Tuple[float, float]] Dict[str, Tuple[float, float]]
Dictionary mapping node names to their (x, y) coordinates. Dictionary mapping node names to their (x, y) coordinates.
""" """
level_nodes: Dict[int, List[str]] = {} level_nodes: dict[int, list[str]] = {}
node_positions: Dict[str, Tuple[float, float]] = {} node_positions: dict[str, tuple[float, float]] = {}
for method_name, level in node_levels.items(): for method_name, level in node_levels.items():
level_nodes.setdefault(level, []).append(method_name) level_nodes.setdefault(level, []).append(method_name)
@@ -180,10 +180,10 @@ def compute_positions(
def add_edges( def add_edges(
net: Any, net: Any,
flow: Any, flow: Any,
node_positions: Dict[str, Tuple[float, float]], node_positions: dict[str, tuple[float, float]],
colors: Dict[str, str] colors: dict[str, str],
) -> None: ) -> None:
edge_smooth: Dict[str, Union[str, float]] = {"type": "continuous"} # Default value edge_smooth: dict[str, str | float] = {"type": "continuous"} # Default value
""" """
Add edges to the network visualization with appropriate styling. Add edges to the network visualization with appropriate styling.
@@ -245,7 +245,7 @@ def add_edges(
"color": edge_color, "color": edge_color,
"width": 2, "width": 2,
"arrows": "to", "arrows": "to",
"dashes": True if is_router_edge or is_and_condition else False, "dashes": bool(is_router_edge or is_and_condition),
"smooth": edge_smooth, "smooth": edge_smooth,
} }
@@ -261,9 +261,7 @@ def add_edges(
# If it's a known router edge and the method is known, don't warn. # If it's a known router edge and the method is known, don't warn.
# This means the path is legitimate, just not reflected as nodes here. # This means the path is legitimate, just not reflected as nodes here.
if not (is_router_edge and method_known): if not (is_router_edge and method_known):
print( pass
f"Warning: No node found for '{trigger}' or '{method_name}'. Skipping edge."
)
# Edges for router return paths # Edges for router return paths
for router_method_name, paths in flow._router_paths.items(): for router_method_name, paths in flow._router_paths.items():
@@ -278,7 +276,7 @@ def add_edges(
and listener_name in node_positions and listener_name in node_positions
): ):
is_cycle_edge = is_ancestor( is_cycle_edge = is_ancestor(
router_method_name, listener_name, ancestors router_method_name, listener_name, ancestors,
) )
parent_has_multiple_children = ( parent_has_multiple_children = (
len(parent_children.get(router_method_name, [])) > 1 len(parent_children.get(router_method_name, [])) > 1
@@ -293,7 +291,7 @@ def add_edges(
dx = target_pos[0] - source_pos[0] dx = target_pos[0] - source_pos[0]
smooth_type = "curvedCCW" if dx <= 0 else "curvedCW" smooth_type = "curvedCCW" if dx <= 0 else "curvedCW"
index = get_child_index( index = get_child_index(
router_method_name, listener_name, parent_children router_method_name, listener_name, parent_children,
) )
edge_smooth = { edge_smooth = {
"type": smooth_type, "type": smooth_type,
@@ -316,6 +314,4 @@ def add_edges(
# Same check here: known router edge and known method? # Same check here: known router edge and known method?
method_known = listener_name in flow._methods method_known = listener_name in flow._methods
if not method_known: if not method_known:
print( pass
f"Warning: No node found for '{router_method_name}' or '{listener_name}'. Skipping edge."
)
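The hunks above only touch signatures, so the placement arithmetic in compute_positions is easy to miss. A standalone sketch of the level-based layout it appears to perform, assuming a symmetric centering formula (layout, node_levels, and the exact offsets are illustrative, not the shipped code, since the loop body falls outside these hunks):

node_levels = {"start": 0, "branch_a": 1, "branch_b": 1, "join": 2}

def layout(levels: dict[str, int], x_spacing: float = 150, y_spacing: float = 150) -> dict[str, tuple[float, float]]:
    # Group node names by their level in the flow graph.
    level_nodes: dict[int, list[str]] = {}
    for name, level in levels.items():
        level_nodes.setdefault(level, []).append(name)
    positions: dict[str, tuple[float, float]] = {}
    for level, names in level_nodes.items():
        for i, name in enumerate(names):
            # Center each row around x = 0 so sibling nodes spread symmetrically.
            x = (i - (len(names) - 1) / 2) * x_spacing
            positions[name] = (x, level * y_spacing)
    return positions

print(layout(node_levels))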
View File
@@ -1,55 +1,48 @@
from abc import ABC, abstractmethod
-from typing import List

import numpy as np

class BaseEmbedder(ABC):
-    """
-    Abstract base class for text embedding models
-    """
+    """Abstract base class for text embedding models."""

    @abstractmethod
-    def embed_chunks(self, chunks: List[str]) -> np.ndarray:
-        """
-        Generate embeddings for a list of text chunks
+    def embed_chunks(self, chunks: list[str]) -> np.ndarray:
+        """Generate embeddings for a list of text chunks.

        Args:
            chunks: List of text chunks to embed

        Returns:
            Array of embeddings
        """
-        pass

    @abstractmethod
-    def embed_texts(self, texts: List[str]) -> np.ndarray:
-        """
-        Generate embeddings for a list of texts
+    def embed_texts(self, texts: list[str]) -> np.ndarray:
+        """Generate embeddings for a list of texts.

        Args:
            texts: List of texts to embed

        Returns:
            Array of embeddings
        """
-        pass

    @abstractmethod
    def embed_text(self, text: str) -> np.ndarray:
-        """
-        Generate embedding for a single text
+        """Generate embedding for a single text.

        Args:
            text: Text to embed

        Returns:
            Embedding array
        """
-        pass

    @property
    @abstractmethod
    def dimension(self) -> int:
-        """Get the dimension of the embeddings"""
-        pass
+        """Get the dimension of the embeddings."""
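For reference, a minimal concrete subclass satisfying the interface above; ToyEmbedder is a hypothetical name and the hash-seeded vectors are a stand-in for a real model:

import numpy as np

class ToyEmbedder(BaseEmbedder):
    def embed_chunks(self, chunks: list[str]) -> np.ndarray:
        return self.embed_texts(chunks)

    def embed_texts(self, texts: list[str]) -> np.ndarray:
        # Stack per-text vectors into a (len(texts), dimension) array.
        return np.stack([self.embed_text(t) for t in texts])

    def embed_text(self, text: str) -> np.ndarray:
        # Deterministic pseudo-embedding seeded from the text itself.
        rng = np.random.default_rng(abs(hash(text)) % (2**32))
        return rng.random(self.dimension)

    @property
    def dimension(self) -> int:
        return 8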
View File
@@ -1,5 +1,4 @@
from pathlib import Path
-from typing import List, Optional, Union

import numpy as np
@@ -19,75 +18,74 @@ except ImportError:
class FastEmbed(BaseEmbedder):
-    """
-    A wrapper class for text embedding models using FastEmbed
-    """
+    """A wrapper class for text embedding models using FastEmbed."""

    def __init__(
        self,
        model_name: str = "BAAI/bge-small-en-v1.5",
-        cache_dir: Optional[Union[str, Path]] = None,
-    ):
-        """
-        Initialize the embedding model
+        cache_dir: str | Path | None = None,
+    ) -> None:
+        """Initialize the embedding model.

        Args:
            model_name: Name of the model to use
            cache_dir: Directory to cache the model
            gpu: Whether to use GPU acceleration
        """
        if not FASTEMBED_AVAILABLE:
-            raise ImportError(
+            msg = (
                "FastEmbed is not installed. Please install it with: "
                "uv pip install fastembed or uv pip install fastembed-gpu for GPU support"
            )
+            raise ImportError(
+                msg,
+            )

        self.model = TextEmbedding(
            model_name=model_name,
            cache_dir=str(cache_dir) if cache_dir else None,
        )

-    def embed_chunks(self, chunks: List[str]) -> List[np.ndarray]:
-        """
-        Generate embeddings for a list of text chunks
+    def embed_chunks(self, chunks: list[str]) -> list[np.ndarray]:
+        """Generate embeddings for a list of text chunks.

        Args:
            chunks: List of text chunks to embed

        Returns:
            List of embeddings
        """
-        embeddings = list(self.model.embed(chunks))
-        return embeddings
+        return list(self.model.embed(chunks))

-    def embed_texts(self, texts: List[str]) -> List[np.ndarray]:
-        """
-        Generate embeddings for a list of texts
+    def embed_texts(self, texts: list[str]) -> list[np.ndarray]:
+        """Generate embeddings for a list of texts.

        Args:
            texts: List of texts to embed

        Returns:
            List of embeddings
        """
-        embeddings = list(self.model.embed(texts))
-        return embeddings
+        return list(self.model.embed(texts))

    def embed_text(self, text: str) -> np.ndarray:
-        """
-        Generate embedding for a single text
+        """Generate embedding for a single text.

        Args:
            text: Text to embed

        Returns:
            Embedding array
        """
        return self.embed_texts([text])[0]

    @property
    def dimension(self) -> int:
-        """Get the dimension of the embeddings"""
+        """Get the dimension of the embeddings."""
        # Generate a test embedding to get dimensions
        test_embed = self.embed_text("test")
        return len(test_embed)
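A hedged usage sketch for the wrapper above; it assumes the optional fastembed dependency is installed and will download the default BAAI/bge-small-en-v1.5 weights on first use:

embedder = FastEmbed()
vectors = embedder.embed_texts(["hello world", "goodbye world"])
# Two vectors; bge-small-en-v1.5 produces 384-dimensional embeddings.
print(len(vectors), embedder.dimension)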
View File
@@ -1,5 +1,5 @@
import os
-from typing import Any, Dict, List, Optional
+from typing import Any

from pydantic import BaseModel, ConfigDict, Field
@@ -10,68 +10,70 @@ os.environ["TOKENIZERS_PARALLELISM"] = "false" # removes logging from fastembed
class Knowledge(BaseModel):
-    """
-    Knowledge is a collection of sources and setup for the vector store to save and query relevant context.
+    """Knowledge is a collection of sources and setup for the vector store to save and query relevant context.

    Args:
        sources: List[BaseKnowledgeSource] = Field(default_factory=list)
        storage: Optional[KnowledgeStorage] = Field(default=None)
-        embedder: Optional[Dict[str, Any]] = None
+        embedder: Optional[Dict[str, Any]] = None.
    """

-    sources: List[BaseKnowledgeSource] = Field(default_factory=list)
+    sources: list[BaseKnowledgeSource] = Field(default_factory=list)
    model_config = ConfigDict(arbitrary_types_allowed=True)
-    storage: Optional[KnowledgeStorage] = Field(default=None)
-    embedder: Optional[Dict[str, Any]] = None
-    collection_name: Optional[str] = None
+    storage: KnowledgeStorage | None = Field(default=None)
+    embedder: dict[str, Any] | None = None
+    collection_name: str | None = None

    def __init__(
        self,
        collection_name: str,
-        sources: List[BaseKnowledgeSource],
-        embedder: Optional[Dict[str, Any]] = None,
-        storage: Optional[KnowledgeStorage] = None,
+        sources: list[BaseKnowledgeSource],
+        embedder: dict[str, Any] | None = None,
+        storage: KnowledgeStorage | None = None,
        **data,
-    ):
+    ) -> None:
        super().__init__(**data)
        if storage:
            self.storage = storage
        else:
            self.storage = KnowledgeStorage(
-                embedder=embedder, collection_name=collection_name
+                embedder=embedder, collection_name=collection_name,
            )
        self.sources = sources
        self.storage.initialize_knowledge_storage()

    def query(
-        self, query: List[str], results_limit: int = 3, score_threshold: float = 0.35
-    ) -> List[Dict[str, Any]]:
-        """
-        Query across all knowledge sources to find the most relevant information.
+        self, query: list[str], results_limit: int = 3, score_threshold: float = 0.35,
+    ) -> list[dict[str, Any]]:
+        """Query across all knowledge sources to find the most relevant information.
        Returns the top_k most relevant chunks.

        Raises:
            ValueError: If storage is not initialized.
        """
        if self.storage is None:
-            raise ValueError("Storage is not initialized.")
+            msg = "Storage is not initialized."
+            raise ValueError(msg)

-        results = self.storage.search(
+        return self.storage.search(
            query,
            limit=results_limit,
            score_threshold=score_threshold,
        )
-        return results

-    def add_sources(self):
+    def add_sources(self) -> None:
        try:
            for source in self.sources:
                source.storage = self.storage
                source.add()
-        except Exception as e:
-            raise e
+        except Exception:
+            raise

    def reset(self) -> None:
        if self.storage:
            self.storage.reset()
        else:
-            raise ValueError("Storage is not initialized.")
+            msg = "Storage is not initialized."
+            raise ValueError(msg)
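A usage sketch for the constructor and query() above, assuming a configured embedder (the default OpenAI function reads OPENAI_API_KEY) and the StringKnowledgeSource defined later in this commit; the "context" and "score" result keys match the search() implementation in the storage file further below:

source = StringKnowledgeSource(content="CrewAI agents can share a knowledge store.")
knowledge = Knowledge(collection_name="demo", sources=[source])
knowledge.add_sources()  # chunk, embed, and persist each source
for hit in knowledge.query(["what can agents share?"], results_limit=2):
    print(hit["context"], hit["score"])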
View File
@@ -7,6 +7,7 @@ class KnowledgeConfig(BaseModel):
    Args:
        results_limit (int): The number of relevant documents to return.
        score_threshold (float): The minimum score for a document to be considered relevant.
    """

    results_limit: int = Field(default=3, description="The number of results to return")
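The two knobs above in use, assuming KnowledgeConfig is in scope; both values here are arbitrary overrides of the defaults (results_limit=3, score_threshold=0.35):

config = KnowledgeConfig(results_limit=5, score_threshold=0.5)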
View File
@@ -1,6 +1,5 @@
from abc import ABC, abstractmethod
from pathlib import Path
-from typing import Dict, List, Optional, Union

from pydantic import Field, field_validator
@@ -14,43 +13,43 @@ class BaseFileKnowledgeSource(BaseKnowledgeSource, ABC):
    """Base class for knowledge sources that load content from files."""

    _logger: Logger = Logger(verbose=True)
-    file_path: Optional[Union[Path, List[Path], str, List[str]]] = Field(
+    file_path: Path | list[Path] | str | list[str] | None = Field(
        default=None,
        description="[Deprecated] The path to the file. Use file_paths instead.",
    )
-    file_paths: Optional[Union[Path, List[Path], str, List[str]]] = Field(
-        default_factory=list, description="The path to the file"
+    file_paths: Path | list[Path] | str | list[str] | None = Field(
+        default_factory=list, description="The path to the file",
    )
-    content: Dict[Path, str] = Field(init=False, default_factory=dict)
-    storage: Optional[KnowledgeStorage] = Field(default=None)
-    safe_file_paths: List[Path] = Field(default_factory=list)
+    content: dict[Path, str] = Field(init=False, default_factory=dict)
+    storage: KnowledgeStorage | None = Field(default=None)
+    safe_file_paths: list[Path] = Field(default_factory=list)

    @field_validator("file_path", "file_paths", mode="before")
-    def validate_file_path(cls, v, info):
+    def validate_file_path(self, v, info):
        """Validate that at least one of file_path or file_paths is provided."""
        # Single check if both are None, O(1) instead of nested conditions
        if (
            v is None
            and info.data.get(
-                "file_path" if info.field_name == "file_paths" else "file_paths"
+                "file_path" if info.field_name == "file_paths" else "file_paths",
            )
            is None
        ):
-            raise ValueError("Either file_path or file_paths must be provided")
+            msg = "Either file_path or file_paths must be provided"
+            raise ValueError(msg)
        return v

-    def model_post_init(self, _):
+    def model_post_init(self, _) -> None:
        """Post-initialization method to load content."""
        self.safe_file_paths = self._process_file_paths()
        self.validate_content()
        self.content = self.load_content()

    @abstractmethod
-    def load_content(self) -> Dict[Path, str]:
+    def load_content(self) -> dict[Path, str]:
        """Load and preprocess file content. Should be overridden by subclasses. Assume that the file path is relative to the project root in the knowledge directory."""
-        pass

-    def validate_content(self):
+    def validate_content(self) -> None:
        """Validate the paths."""
        for path in self.safe_file_paths:
            if not path.exists():
@@ -59,7 +58,8 @@ class BaseFileKnowledgeSource(BaseKnowledgeSource, ABC):
                    f"File not found: {path}. Try adding sources to the knowledge directory. If it's inside the knowledge directory, use the relative path.",
                    color="red",
                )
-                raise FileNotFoundError(f"File not found: {path}")
+                msg = f"File not found: {path}"
+                raise FileNotFoundError(msg)
            if not path.is_file():
                self._logger.log(
                    "error",
@@ -67,20 +67,20 @@ class BaseFileKnowledgeSource(BaseKnowledgeSource, ABC):
                    color="red",
                )

-    def _save_documents(self):
+    def _save_documents(self) -> None:
        """Save the documents to the storage."""
        if self.storage:
            self.storage.save(self.chunks)
        else:
-            raise ValueError("No storage found to save documents.")
+            msg = "No storage found to save documents."
+            raise ValueError(msg)

-    def convert_to_path(self, path: Union[Path, str]) -> Path:
+    def convert_to_path(self, path: Path | str) -> Path:
        """Convert a path to a Path object."""
        return Path(KNOWLEDGE_DIRECTORY + "/" + path) if isinstance(path, str) else path

-    def _process_file_paths(self) -> List[Path]:
+    def _process_file_paths(self) -> list[Path]:
        """Convert file_path to a list of Path objects."""
        if hasattr(self, "file_path") and self.file_path is not None:
            self._logger.log(
                "warning",
@@ -90,10 +90,11 @@ class BaseFileKnowledgeSource(BaseKnowledgeSource, ABC):
            self.file_paths = self.file_path

        if self.file_paths is None:
-            raise ValueError("Your source must be provided with a file_paths: []")
+            msg = "Your source must be provided with a file_paths: []"
+            raise ValueError(msg)

        # Convert single path to list
-        path_list: List[Union[Path, str]] = (
+        path_list: list[Path | str] = (
            [self.file_paths]
            if isinstance(self.file_paths, (str, Path))
            else list(self.file_paths)
@@ -102,8 +103,9 @@ class BaseFileKnowledgeSource(BaseKnowledgeSource, ABC):
        )

        if not path_list:
+            msg = "file_path/file_paths must be a Path, str, or a list of these types"
            raise ValueError(
-                "file_path/file_paths must be a Path, str, or a list of these types"
+                msg,
            )

        return [self.convert_to_path(path) for path in path_list]
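The string branch of convert_to_path above resolves bare names against the knowledge directory while Path objects pass through untouched. Restated standalone, assuming KNOWLEDGE_DIRECTORY is the string "knowledge" (its actual value lives in crewai.utilities.constants):

from pathlib import Path

KNOWLEDGE_DIRECTORY = "knowledge"  # assumed value for illustration

def convert_to_path(path: Path | str) -> Path:
    return Path(KNOWLEDGE_DIRECTORY + "/" + path) if isinstance(path, str) else path

print(convert_to_path("report.pdf"))             # knowledge/report.pdf
print(convert_to_path(Path("/tmp/report.pdf")))  # /tmp/report.pdf, unchanged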
View File
@@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
-from typing import Any, Dict, List, Optional
+from typing import Any

import numpy as np
from pydantic import BaseModel, ConfigDict, Field
@@ -12,41 +12,39 @@ class BaseKnowledgeSource(BaseModel, ABC):
    chunk_size: int = 4000
    chunk_overlap: int = 200
-    chunks: List[str] = Field(default_factory=list)
-    chunk_embeddings: List[np.ndarray] = Field(default_factory=list)
+    chunks: list[str] = Field(default_factory=list)
+    chunk_embeddings: list[np.ndarray] = Field(default_factory=list)

    model_config = ConfigDict(arbitrary_types_allowed=True)
-    storage: Optional[KnowledgeStorage] = Field(default=None)
-    metadata: Dict[str, Any] = Field(default_factory=dict)  # Currently unused
-    collection_name: Optional[str] = Field(default=None)
+    storage: KnowledgeStorage | None = Field(default=None)
+    metadata: dict[str, Any] = Field(default_factory=dict)  # Currently unused
+    collection_name: str | None = Field(default=None)

    @abstractmethod
    def validate_content(self) -> Any:
        """Load and preprocess content from the source."""
-        pass

    @abstractmethod
    def add(self) -> None:
        """Process content, chunk it, compute embeddings, and save them."""
-        pass

-    def get_embeddings(self) -> List[np.ndarray]:
+    def get_embeddings(self) -> list[np.ndarray]:
        """Return the list of embeddings for the chunks."""
        return self.chunk_embeddings

-    def _chunk_text(self, text: str) -> List[str]:
+    def _chunk_text(self, text: str) -> list[str]:
        """Utility method to split text into chunks."""
        return [
            text[i : i + self.chunk_size]
            for i in range(0, len(text), self.chunk_size - self.chunk_overlap)
        ]

-    def _save_documents(self):
-        """
-        Save the documents to the storage.
+    def _save_documents(self) -> None:
+        """Save the documents to the storage.
        This method should be called after the chunks and embeddings are generated.
        """
        if self.storage:
            self.storage.save(self.chunks)
        else:
-            raise ValueError("No storage found to save documents.")
+            msg = "No storage found to save documents."
+            raise ValueError(msg)
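The stride arithmetic in _chunk_text is worth spelling out: consecutive chunks start chunk_size - chunk_overlap characters apart, so each chunk repeats the final chunk_overlap characters of its predecessor. With toy numbers in place of the 4000/200 defaults:

chunk_size, chunk_overlap = 10, 3
text = "abcdefghijklmnopqrstuvwxyz"
chunks = [
    text[i : i + chunk_size]
    for i in range(0, len(text), chunk_size - chunk_overlap)  # stride of 7
]
print(chunks)  # ['abcdefghij', 'hijklmnopq', 'opqrstuvwx', 'vwxyz']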
View File
@@ -1,5 +1,6 @@
+from collections.abc import Iterator
from pathlib import Path
-from typing import Iterator, List, Optional, Union
+from typing import TYPE_CHECKING
from urllib.parse import urlparse

try:
@@ -7,7 +8,6 @@ try:
    from docling.document_converter import DocumentConverter
    from docling.exceptions import ConversionError
    from docling_core.transforms.chunker.hierarchical_chunker import HierarchicalChunker
-    from docling_core.types.doc.document import DoclingDocument

    DOCLING_AVAILABLE = True
except ImportError:
@@ -19,27 +19,33 @@ from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource
from crewai.utilities.constants import KNOWLEDGE_DIRECTORY
from crewai.utilities.logger import Logger

+if TYPE_CHECKING:
+    from docling_core.types.doc.document import DoclingDocument

class CrewDoclingSource(BaseKnowledgeSource):
    """Default Source class for converting documents to markdown or json
    This will auto support PDF, DOCX, and TXT, XLSX, Images, and HTML files without any additional dependencies and follows the docling package as the source of truth.
    """

-    def __init__(self, *args, **kwargs):
+    def __init__(self, *args, **kwargs) -> None:
        if not DOCLING_AVAILABLE:
-            raise ImportError(
+            msg = (
                "The docling package is required to use CrewDoclingSource. "
                "Please install it using: uv add docling"
            )
+            raise ImportError(
+                msg,
+            )
        super().__init__(*args, **kwargs)

    _logger: Logger = Logger(verbose=True)
-    file_path: Optional[List[Union[Path, str]]] = Field(default=None)
-    file_paths: List[Union[Path, str]] = Field(default_factory=list)
-    chunks: List[str] = Field(default_factory=list)
-    safe_file_paths: List[Union[Path, str]] = Field(default_factory=list)
-    content: List["DoclingDocument"] = Field(default_factory=list)
+    file_path: list[Path | str] | None = Field(default=None)
+    file_paths: list[Path | str] = Field(default_factory=list)
+    chunks: list[str] = Field(default_factory=list)
+    safe_file_paths: list[Path | str] = Field(default_factory=list)
+    content: list["DoclingDocument"] = Field(default_factory=list)
    document_converter: "DocumentConverter" = Field(
        default_factory=lambda: DocumentConverter(
            allowed_formats=[
@@ -51,8 +57,8 @@ class CrewDoclingSource(BaseKnowledgeSource):
                InputFormat.IMAGE,
                InputFormat.XLSX,
                InputFormat.PPTX,
-            ]
-        )
+            ],
+        ),
    )

    def model_post_init(self, _) -> None:
@@ -66,7 +72,7 @@ class CrewDoclingSource(BaseKnowledgeSource):
        self.safe_file_paths = self.validate_content()
        self.content = self._load_content()

-    def _load_content(self) -> List["DoclingDocument"]:
+    def _load_content(self) -> list["DoclingDocument"]:
        try:
            return self._convert_source_to_docling_documents()
        except ConversionError as e:
@@ -75,10 +81,10 @@ class CrewDoclingSource(BaseKnowledgeSource):
                f"Error loading content: {e}. Supported formats: {self.document_converter.allowed_formats}",
                "red",
            )
-            raise e
+            raise
        except Exception as e:
            self._logger.log("error", f"Error loading content: {e}")
-            raise e
+            raise

    def add(self) -> None:
        if self.content is None:
@@ -88,7 +94,7 @@ class CrewDoclingSource(BaseKnowledgeSource):
        self.chunks.extend(list(new_chunks_iterable))
        self._save_documents()

-    def _convert_source_to_docling_documents(self) -> List["DoclingDocument"]:
+    def _convert_source_to_docling_documents(self) -> list["DoclingDocument"]:
        conv_results_iter = self.document_converter.convert_all(self.safe_file_paths)
        return [result.document for result in conv_results_iter]
@@ -97,8 +103,8 @@ class CrewDoclingSource(BaseKnowledgeSource):
        for chunk in chunker.chunk(doc):
            yield chunk.text

-    def validate_content(self) -> List[Union[Path, str]]:
-        processed_paths: List[Union[Path, str]] = []
+    def validate_content(self) -> list[Path | str]:
+        processed_paths: list[Path | str] = []
        for path in self.file_paths:
            if isinstance(path, str):
                if path.startswith(("http://", "https://")):
@@ -106,15 +112,18 @@ class CrewDoclingSource(BaseKnowledgeSource):
                        if self._validate_url(path):
                            processed_paths.append(path)
                        else:
-                            raise ValueError(f"Invalid URL format: {path}")
+                            msg = f"Invalid URL format: {path}"
+                            raise ValueError(msg)
                    except Exception as e:
-                        raise ValueError(f"Invalid URL: {path}. Error: {str(e)}")
+                        msg = f"Invalid URL: {path}. Error: {e!s}"
+                        raise ValueError(msg)
                else:
                    local_path = Path(KNOWLEDGE_DIRECTORY + "/" + path)
                    if local_path.exists():
                        processed_paths.append(local_path)
                    else:
-                        raise FileNotFoundError(f"File not found: {local_path}")
+                        msg = f"File not found: {local_path}"
+                        raise FileNotFoundError(msg)
            else:
                # this is an instance of Path
                processed_paths.append(path)
@@ -128,7 +137,7 @@ class CrewDoclingSource(BaseKnowledgeSource):
                    result.scheme in ("http", "https"),
                    result.netloc,
                    len(result.netloc.split(".")) >= 2,  # Ensure domain has TLD
-                ]
+                ],
            )
        except Exception:
            return False
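The _validate_url checks above, restated as a runnable snippet; validate_url is a standalone copy of the method body for illustration, not the method itself:

from urllib.parse import urlparse

def validate_url(url: str) -> bool:
    try:
        result = urlparse(url)
        return all(
            [
                result.scheme in ("http", "https"),
                result.netloc,
                len(result.netloc.split(".")) >= 2,  # Ensure domain has TLD
            ],
        )
    except Exception:
        return False

print(validate_url("https://docs.crewai.com/knowledge"))  # True
print(validate_url("https://localhost"))                  # False: no TLD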
View File
@@ -1,6 +1,5 @@
import csv
from pathlib import Path
-from typing import Dict, List

from crewai.knowledge.source.base_file_knowledge_source import BaseFileKnowledgeSource
@@ -8,11 +7,11 @@ from crewai.knowledge.source.base_file_knowledge_source import BaseFileKnowledge
class CSVKnowledgeSource(BaseFileKnowledgeSource):
    """A knowledge source that stores and queries CSV file content using embeddings."""

-    def load_content(self) -> Dict[Path, str]:
+    def load_content(self) -> dict[Path, str]:
        """Load and preprocess CSV file content."""
        content_dict = {}
        for file_path in self.safe_file_paths:
-            with open(file_path, "r", encoding="utf-8") as csvfile:
+            with open(file_path, encoding="utf-8") as csvfile:
                reader = csv.reader(csvfile)
                content = ""
                for row in reader:
@@ -21,8 +20,7 @@ class CSVKnowledgeSource(BaseFileKnowledgeSource):
        return content_dict

    def add(self) -> None:
-        """
-        Add CSV file content to the knowledge source, chunk it, compute embeddings,
+        """Add CSV file content to the knowledge source, chunk it, compute embeddings,
        and save the embeddings.
        """
        content_str = (
@@ -32,7 +30,7 @@ class CSVKnowledgeSource(BaseFileKnowledgeSource):
        self.chunks.extend(new_chunks)
        self._save_documents()

-    def _chunk_text(self, text: str) -> List[str]:
+    def _chunk_text(self, text: str) -> list[str]:
        """Utility method to split text into chunks."""
        return [
            text[i : i + self.chunk_size]
View File
@@ -1,6 +1,4 @@
from pathlib import Path
-from typing import Dict, Iterator, List, Optional, Union
-from urllib.parse import urlparse

from pydantic import Field, field_validator
@@ -16,34 +14,34 @@ class ExcelKnowledgeSource(BaseKnowledgeSource):
    _logger: Logger = Logger(verbose=True)
-    file_path: Optional[Union[Path, List[Path], str, List[str]]] = Field(
+    file_path: Path | list[Path] | str | list[str] | None = Field(
        default=None,
        description="[Deprecated] The path to the file. Use file_paths instead.",
    )
-    file_paths: Optional[Union[Path, List[Path], str, List[str]]] = Field(
-        default_factory=list, description="The path to the file"
+    file_paths: Path | list[Path] | str | list[str] | None = Field(
+        default_factory=list, description="The path to the file",
    )
-    chunks: List[str] = Field(default_factory=list)
-    content: Dict[Path, Dict[str, str]] = Field(default_factory=dict)
-    safe_file_paths: List[Path] = Field(default_factory=list)
+    chunks: list[str] = Field(default_factory=list)
+    content: dict[Path, dict[str, str]] = Field(default_factory=dict)
+    safe_file_paths: list[Path] = Field(default_factory=list)

    @field_validator("file_path", "file_paths", mode="before")
-    def validate_file_path(cls, v, info):
+    def validate_file_path(self, v, info):
        """Validate that at least one of file_path or file_paths is provided."""
        # Single check if both are None, O(1) instead of nested conditions
        if (
            v is None
            and info.data.get(
-                "file_path" if info.field_name == "file_paths" else "file_paths"
+                "file_path" if info.field_name == "file_paths" else "file_paths",
            )
            is None
        ):
-            raise ValueError("Either file_path or file_paths must be provided")
+            msg = "Either file_path or file_paths must be provided"
+            raise ValueError(msg)
        return v

-    def _process_file_paths(self) -> List[Path]:
+    def _process_file_paths(self) -> list[Path]:
        """Convert file_path to a list of Path objects."""
        if hasattr(self, "file_path") and self.file_path is not None:
            self._logger.log(
                "warning",
@@ -53,10 +51,11 @@ class ExcelKnowledgeSource(BaseKnowledgeSource):
            self.file_paths = self.file_path

        if self.file_paths is None:
-            raise ValueError("Your source must be provided with a file_paths: []")
+            msg = "Your source must be provided with a file_paths: []"
+            raise ValueError(msg)

        # Convert single path to list
-        path_list: List[Union[Path, str]] = (
+        path_list: list[Path | str] = (
            [self.file_paths]
            if isinstance(self.file_paths, (str, Path))
            else list(self.file_paths)
@@ -65,13 +64,14 @@ class ExcelKnowledgeSource(BaseKnowledgeSource):
        )

        if not path_list:
+            msg = "file_path/file_paths must be a Path, str, or a list of these types"
            raise ValueError(
-                "file_path/file_paths must be a Path, str, or a list of these types"
+                msg,
            )

        return [self.convert_to_path(path) for path in path_list]

-    def validate_content(self):
+    def validate_content(self) -> None:
        """Validate the paths."""
        for path in self.safe_file_paths:
            if not path.exists():
@@ -80,7 +80,8 @@ class ExcelKnowledgeSource(BaseKnowledgeSource):
                    f"File not found: {path}. Try adding sources to the knowledge directory. If it's inside the knowledge directory, use the relative path.",
                    color="red",
                )
-                raise FileNotFoundError(f"File not found: {path}")
+                msg = f"File not found: {path}"
+                raise FileNotFoundError(msg)
            if not path.is_file():
                self._logger.log(
                    "error",
@@ -100,7 +101,7 @@ class ExcelKnowledgeSource(BaseKnowledgeSource):
        self.validate_content()
        self.content = self._load_content()

-    def _load_content(self) -> Dict[Path, Dict[str, str]]:
+    def _load_content(self) -> dict[Path, dict[str, str]]:
        """Load and preprocess Excel file content from multiple sheets.

        Each sheet's content is converted to CSV format and stored.
@@ -111,6 +112,7 @@ class ExcelKnowledgeSource(BaseKnowledgeSource):
        Raises:
            ImportError: If required dependencies are missing.
            FileNotFoundError: If the specified Excel file cannot be opened.
        """
        pd = self._import_dependencies()
        content_dict = {}
@@ -119,14 +121,14 @@ class ExcelKnowledgeSource(BaseKnowledgeSource):
            with pd.ExcelFile(file_path) as xl:
                sheet_dict = {
                    str(sheet_name): str(
-                        pd.read_excel(xl, sheet_name).to_csv(index=False)
+                        pd.read_excel(xl, sheet_name).to_csv(index=False),
                    )
                    for sheet_name in xl.sheet_names
                }
                content_dict[file_path] = sheet_dict
        return content_dict

-    def convert_to_path(self, path: Union[Path, str]) -> Path:
+    def convert_to_path(self, path: Path | str) -> Path:
        """Convert a path to a Path object."""
        return Path(KNOWLEDGE_DIRECTORY + "/" + path) if isinstance(path, str) else path
@@ -138,13 +140,13 @@ class ExcelKnowledgeSource(BaseKnowledgeSource):
            return pd
        except ImportError as e:
            missing_package = str(e).split()[-1]
+            msg = f"{missing_package} is not installed. Please install it with: pip install {missing_package}"
            raise ImportError(
-                f"{missing_package} is not installed. Please install it with: pip install {missing_package}"
+                msg,
            )

    def add(self) -> None:
-        """
-        Add Excel file content to the knowledge source, chunk it, compute embeddings,
+        """Add Excel file content to the knowledge source, chunk it, compute embeddings,
        and save the embeddings.
        """
        # Convert dictionary values to a single string if content is a dictionary
@@ -161,7 +163,7 @@ class ExcelKnowledgeSource(BaseKnowledgeSource):
        self.chunks.extend(new_chunks)
        self._save_documents()

-    def _chunk_text(self, text: str) -> List[str]:
+    def _chunk_text(self, text: str) -> list[str]:
        """Utility method to split text into chunks."""
        return [
            text[i : i + self.chunk_size]
View File
@@ -1,6 +1,6 @@
import json
from pathlib import Path
-from typing import Any, Dict, List
+from typing import Any

from crewai.knowledge.source.base_file_knowledge_source import BaseFileKnowledgeSource
@@ -8,12 +8,12 @@ from crewai.knowledge.source.base_file_knowledge_source import BaseFileKnowledge
class JSONKnowledgeSource(BaseFileKnowledgeSource):
    """A knowledge source that stores and queries JSON file content using embeddings."""

-    def load_content(self) -> Dict[Path, str]:
+    def load_content(self) -> dict[Path, str]:
        """Load and preprocess JSON file content."""
-        content: Dict[Path, str] = {}
+        content: dict[Path, str] = {}
        for path in self.safe_file_paths:
            path = self.convert_to_path(path)
-            with open(path, "r", encoding="utf-8") as json_file:
+            with open(path, encoding="utf-8") as json_file:
                data = json.load(json_file)
                content[path] = self._json_to_text(data)
        return content
@@ -29,12 +29,11 @@ class JSONKnowledgeSource(BaseFileKnowledgeSource):
            for item in data:
                text += f"{indent}- {self._json_to_text(item, level + 1)}\n"
        else:
-            text += f"{str(data)}"
+            text += f"{data!s}"
        return text

    def add(self) -> None:
-        """
-        Add JSON file content to the knowledge source, chunk it, compute embeddings,
+        """Add JSON file content to the knowledge source, chunk it, compute embeddings,
        and save the embeddings.
        """
        content_str = (
@@ -44,7 +43,7 @@ class JSONKnowledgeSource(BaseFileKnowledgeSource):
        self.chunks.extend(new_chunks)
        self._save_documents()

-    def _chunk_text(self, text: str) -> List[str]:
+    def _chunk_text(self, text: str) -> list[str]:
        """Utility method to split text into chunks."""
        return [
            text[i : i + self.chunk_size]
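Only the list branch and the else fallback of _json_to_text appear in the hunk above. A plausible standalone reconstruction of the whole flattener; the dict branch is an assumption about the elided lines, not the shipped code:

import json

def json_to_text(data, level: int = 0) -> str:
    text = ""
    indent = "  " * level
    if isinstance(data, dict):
        # Assumed branch: render each key, recursing into its value.
        for key, value in data.items():
            text += f"{indent}{key}: {json_to_text(value, level + 1)}\n"
    elif isinstance(data, list):
        for item in data:
            text += f"{indent}- {json_to_text(item, level + 1)}\n"
    else:
        text += f"{data!s}"
    return text

print(json_to_text(json.loads('{"a": [1, 2], "b": {"c": 3}}')))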
View File
@@ -1,5 +1,4 @@
from pathlib import Path
-from typing import Dict, List

from crewai.knowledge.source.base_file_knowledge_source import BaseFileKnowledgeSource
@@ -7,7 +6,7 @@ from crewai.knowledge.source.base_file_knowledge_source import BaseFileKnowledge
class PDFKnowledgeSource(BaseFileKnowledgeSource):
    """A knowledge source that stores and queries PDF file content using embeddings."""

-    def load_content(self) -> Dict[Path, str]:
+    def load_content(self) -> dict[Path, str]:
        """Load and preprocess PDF file content."""
        pdfplumber = self._import_pdfplumber()
@@ -31,21 +30,21 @@ class PDFKnowledgeSource(BaseFileKnowledgeSource):
            return pdfplumber
        except ImportError:
+            msg = "pdfplumber is not installed. Please install it with: pip install pdfplumber"
            raise ImportError(
-                "pdfplumber is not installed. Please install it with: pip install pdfplumber"
+                msg,
            )

    def add(self) -> None:
-        """
-        Add PDF file content to the knowledge source, chunk it, compute embeddings,
+        """Add PDF file content to the knowledge source, chunk it, compute embeddings,
        and save the embeddings.
        """
-        for _, text in self.content.items():
+        for text in self.content.values():
            new_chunks = self._chunk_text(text)
            self.chunks.extend(new_chunks)
        self._save_documents()

-    def _chunk_text(self, text: str) -> List[str]:
+    def _chunk_text(self, text: str) -> list[str]:
        """Utility method to split text into chunks."""
        return [
            text[i : i + self.chunk_size]
View File
@@ -1,4 +1,3 @@
-from typing import List, Optional

from pydantic import Field
@@ -9,16 +8,17 @@ class StringKnowledgeSource(BaseKnowledgeSource):
    """A knowledge source that stores and queries plain text content using embeddings."""

    content: str = Field(...)
-    collection_name: Optional[str] = Field(default=None)
+    collection_name: str | None = Field(default=None)

-    def model_post_init(self, _):
+    def model_post_init(self, _) -> None:
        """Post-initialization method to validate content."""
        self.validate_content()

-    def validate_content(self):
+    def validate_content(self) -> None:
        """Validate string content."""
        if not isinstance(self.content, str):
-            raise ValueError("StringKnowledgeSource only accepts string content")
+            msg = "StringKnowledgeSource only accepts string content"
+            raise ValueError(msg)

    def add(self) -> None:
        """Add string content to the knowledge source, chunk it, compute embeddings, and save them."""
@@ -26,7 +26,7 @@ class StringKnowledgeSource(BaseKnowledgeSource):
        self.chunks.extend(new_chunks)
        self._save_documents()

-    def _chunk_text(self, text: str) -> List[str]:
+    def _chunk_text(self, text: str) -> list[str]:
        """Utility method to split text into chunks."""
        return [
            text[i : i + self.chunk_size]
View File
@@ -1,5 +1,4 @@
from pathlib import Path
-from typing import Dict, List

from crewai.knowledge.source.base_file_knowledge_source import BaseFileKnowledgeSource
@@ -7,26 +6,25 @@ from crewai.knowledge.source.base_file_knowledge_source import BaseFileKnowledge
class TextFileKnowledgeSource(BaseFileKnowledgeSource):
    """A knowledge source that stores and queries text file content using embeddings."""

-    def load_content(self) -> Dict[Path, str]:
+    def load_content(self) -> dict[Path, str]:
        """Load and preprocess text file content."""
        content = {}
        for path in self.safe_file_paths:
            path = self.convert_to_path(path)
-            with open(path, "r", encoding="utf-8") as f:
+            with open(path, encoding="utf-8") as f:
                content[path] = f.read()
        return content

    def add(self) -> None:
-        """
-        Add text file content to the knowledge source, chunk it, compute embeddings,
+        """Add text file content to the knowledge source, chunk it, compute embeddings,
        and save the embeddings.
        """
-        for _, text in self.content.items():
+        for text in self.content.values():
            new_chunks = self._chunk_text(text)
            self.chunks.extend(new_chunks)
        self._save_documents()

-    def _chunk_text(self, text: str) -> List[str]:
+    def _chunk_text(self, text: str) -> list[str]:
        """Utility method to split text into chunks."""
        return [
            text[i : i + self.chunk_size]
View File
@@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
-from typing import Any, Dict, List, Optional
+from typing import Any

class BaseKnowledgeStorage(ABC):
@@ -8,22 +8,19 @@ class BaseKnowledgeStorage(ABC):
    @abstractmethod
    def search(
        self,
-        query: List[str],
+        query: list[str],
        limit: int = 3,
-        filter: Optional[dict] = None,
+        filter: dict | None = None,
        score_threshold: float = 0.35,
-    ) -> List[Dict[str, Any]]:
+    ) -> list[dict[str, Any]]:
        """Search for documents in the knowledge base."""
-        pass

    @abstractmethod
    def save(
-        self, documents: List[str], metadata: Dict[str, Any] | List[Dict[str, Any]]
+        self, documents: list[str], metadata: dict[str, Any] | list[dict[str, Any]],
    ) -> None:
        """Save documents to the knowledge base."""
-        pass

    @abstractmethod
    def reset(self) -> None:
        """Reset the knowledge base."""
-        pass
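A minimal in-memory implementation of the ABC above, handy for tests that should not touch a vector database; InMemoryKnowledgeStorage is a hypothetical name and the substring scoring is a deliberate simplification of real vector search:

from typing import Any

class InMemoryKnowledgeStorage(BaseKnowledgeStorage):
    def __init__(self) -> None:
        self._docs: list[str] = []

    def search(self, query, limit=3, filter=None, score_threshold=0.35):
        # Naive relevance: any query string appearing in the document.
        hits = [
            {"context": d, "score": 1.0}
            for d in self._docs
            if any(q.lower() in d.lower() for q in query)
        ]
        return hits[:limit]

    def save(self, documents, metadata=None) -> None:
        self._docs.extend(documents)

    def reset(self) -> None:
        self._docs.clear()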
View File
@@ -4,12 +4,11 @@ import io
import logging
import os
import shutil
-from typing import Any, Dict, List, Optional, Union
+from typing import TYPE_CHECKING, Any

import chromadb
import chromadb.errors
from chromadb.api import ClientAPI
-from chromadb.api.types import OneOrMany
from chromadb.config import Settings

from crewai.knowledge.storage.base_knowledge_storage import BaseKnowledgeStorage
@@ -19,6 +18,9 @@ from crewai.utilities.constants import KNOWLEDGE_DIRECTORY
from crewai.utilities.logger import Logger
from crewai.utilities.paths import db_storage_path

+if TYPE_CHECKING:
+    from chromadb.api.types import OneOrMany

@contextlib.contextmanager
def suppress_logging(
@@ -38,30 +40,29 @@ def suppress_logging(
class KnowledgeStorage(BaseKnowledgeStorage):
-    """
-    Extends Storage to handle embeddings for memory entries, improving
+    """Extends Storage to handle embeddings for memory entries, improving
    search efficiency.
    """

-    collection: Optional[chromadb.Collection] = None
-    collection_name: Optional[str] = "knowledge"
-    app: Optional[ClientAPI] = None
+    collection: chromadb.Collection | None = None
+    collection_name: str | None = "knowledge"
+    app: ClientAPI | None = None

    def __init__(
        self,
-        embedder: Optional[Dict[str, Any]] = None,
-        collection_name: Optional[str] = None,
-    ):
+        embedder: dict[str, Any] | None = None,
+        collection_name: str | None = None,
+    ) -> None:
        self.collection_name = collection_name
        self._set_embedder_config(embedder)

    def search(
        self,
-        query: List[str],
+        query: list[str],
        limit: int = 3,
-        filter: Optional[dict] = None,
+        filter: dict | None = None,
        score_threshold: float = 0.35,
-    ) -> List[Dict[str, Any]]:
+    ) -> list[dict[str, Any]]:
        with suppress_logging():
            if self.collection:
                fetched = self.collection.query(
@@ -80,10 +81,10 @@ class KnowledgeStorage(BaseKnowledgeStorage):
                    if result["score"] >= score_threshold:
                        results.append(result)
                return results
-            else:
-                raise Exception("Collection not initialized")
+            msg = "Collection not initialized"
+            raise Exception(msg)

-    def initialize_knowledge_storage(self):
+    def initialize_knowledge_storage(self) -> None:
        base_path = os.path.join(db_storage_path(), "knowledge")
        chroma_client = chromadb.PersistentClient(
            path=base_path,
@@ -104,11 +105,13 @@ class KnowledgeStorage(BaseKnowledgeStorage):
                    embedding_function=self.embedder,
                )
            else:
-                raise Exception("Vector Database Client not initialized")
+                msg = "Vector Database Client not initialized"
+                raise Exception(msg)
        except Exception:
-            raise Exception("Failed to create or get collection")
+            msg = "Failed to create or get collection"
+            raise Exception(msg)

-    def reset(self):
+    def reset(self) -> None:
        base_path = os.path.join(db_storage_path(), KNOWLEDGE_DIRECTORY)
        if not self.app:
            self.app = chromadb.PersistentClient(
@@ -123,11 +126,12 @@ class KnowledgeStorage(BaseKnowledgeStorage):
    def save(
        self,
-        documents: List[str],
-        metadata: Optional[Union[Dict[str, Any], List[Dict[str, Any]]]] = None,
-    ):
+        documents: list[str],
+        metadata: dict[str, Any] | list[dict[str, Any]] | None = None,
+    ) -> None:
        if not self.collection:
-            raise Exception("Collection not initialized")
+            msg = "Collection not initialized"
+            raise Exception(msg)

        try:
            # Create a dictionary to store unique documents
@@ -156,7 +160,7 @@ class KnowledgeStorage(BaseKnowledgeStorage):
                    filtered_ids.append(doc_id)

            # If we have no metadata at all, set it to None
-            final_metadata: Optional[OneOrMany[chromadb.Metadata]] = (
+            final_metadata: OneOrMany[chromadb.Metadata] | None = (
                None if all(m is None for m in filtered_metadata) else filtered_metadata
            )
@@ -171,10 +175,13 @@ class KnowledgeStorage(BaseKnowledgeStorage):
                    "Embedding dimension mismatch. This usually happens when mixing different embedding models. Try resetting the collection using `crewai reset-memories -a`",
                    "red",
                )
-                raise ValueError(
+                msg = (
                    "Embedding dimension mismatch. Make sure you're using the same embedding model "
                    "across all operations with this collection."
                    "Try resetting the collection using `crewai reset-memories -a`"
+                )
+                raise ValueError(
+                    msg,
                ) from e
        except Exception as e:
            Logger(verbose=True).log("error", f"Failed to upsert documents: {e}", "red")
@@ -186,15 +193,16 @@ class KnowledgeStorage(BaseKnowledgeStorage):
        )
        return OpenAIEmbeddingFunction(
-            api_key=os.getenv("OPENAI_API_KEY"), model_name="text-embedding-3-small"
+            api_key=os.getenv("OPENAI_API_KEY"), model_name="text-embedding-3-small",
        )

-    def _set_embedder_config(self, embedder: Optional[Dict[str, Any]] = None) -> None:
+    def _set_embedder_config(self, embedder: dict[str, Any] | None = None) -> None:
        """Set the embedding configuration for the knowledge storage.

        Args:
            embedder_config (Optional[Dict[str, Any]]): Configuration dictionary for the embedder.
                If None or empty, defaults to the default embedding function.
        """
        self.embedder = (
            EmbeddingConfigurator().configure_embedder(embedder)
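An end-to-end sketch for the class above; it assumes chromadb is installed, the process can write to the persistent client's storage path, and an embedder is available (the default OpenAI function needs OPENAI_API_KEY):

storage = KnowledgeStorage(collection_name="demo")
storage.initialize_knowledge_storage()  # creates or fetches the Chroma collection
storage.save(["CrewAI stores knowledge chunks in Chroma."])
print(storage.search(["where are chunks stored?"], limit=1))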
View File
@@ -1,7 +1,7 @@
-from typing import Any, Dict, List
+from typing import Any

-def extract_knowledge_context(knowledge_snippets: List[Dict[str, Any]]) -> str:
+def extract_knowledge_context(knowledge_snippets: list[dict[str, Any]]) -> str:
    """Extract knowledge from the task prompt."""
    valid_snippets = [
        result["context"]
View File
@@ -1,7 +1,7 @@
import asyncio
import uuid
-from datetime import datetime
+from collections.abc import Callable
-from typing import Any, Callable, Dict, List, Optional, Type, Union, cast
+from typing import Any, cast

from pydantic import BaseModel, Field, InstanceOf, PrivateAttr, model_validator
@@ -35,7 +35,7 @@ from crewai.utilities.agent_utils import (
    render_text_description_and_args,
    show_agent_logs,
)
-from crewai.utilities.converter import convert_to_model, generate_model_description
+from crewai.utilities.converter import generate_model_description
from crewai.utilities.events.agent_events import (
    LiteAgentExecutionCompletedEvent,
    LiteAgentExecutionErrorEvent,
@@ -60,15 +60,15 @@ class LiteAgentOutput(BaseModel):
    model_config = {"arbitrary_types_allowed": True}
    raw: str = Field(description="Raw output of the agent", default="")
-    pydantic: Optional[BaseModel] = Field(
-        description="Pydantic output of the agent", default=None
+    pydantic: BaseModel | None = Field(
+        description="Pydantic output of the agent", default=None,
    )
    agent_role: str = Field(description="Role of the agent that produced this output")
-    usage_metrics: Optional[Dict[str, Any]] = Field(
-        description="Token usage metrics for this execution", default=None
+    usage_metrics: dict[str, Any] | None = Field(
+        description="Token usage metrics for this execution", default=None,
    )

-    def to_dict(self) -> Dict[str, Any]:
+    def to_dict(self) -> dict[str, Any]:
        """Convert pydantic_output to a dictionary."""
        if self.pydantic:
            return self.pydantic.model_dump()
@@ -82,8 +82,7 @@ class LiteAgent(FlowTrackable, BaseModel):
class LiteAgent(FlowTrackable, BaseModel):
-    """
-    A lightweight agent that can process messages and use tools.
+    """A lightweight agent that can process messages and use tools.

    This agent is simpler than the full Agent class, focusing on direct execution
    rather than task delegation. It's designed to be used for simple interactions
@@ -99,6 +98,7 @@ class LiteAgent(FlowTrackable, BaseModel):
        max_iterations: Maximum number of iterations for tool usage.
        max_execution_time: Maximum execution time in seconds.
        response_format: Optional Pydantic model for structured output.
    """

    model_config = {"arbitrary_types_allowed": True}
@@ -107,19 +107,19 @@ class LiteAgent(FlowTrackable, BaseModel):
    role: str = Field(description="Role of the agent")
    goal: str = Field(description="Goal of the agent")
    backstory: str = Field(description="Backstory of the agent")
-    llm: Optional[Union[str, InstanceOf[LLM], Any]] = Field(
-        default=None, description="Language model that will run the agent"
+    llm: str | InstanceOf[LLM] | Any | None = Field(
+        default=None, description="Language model that will run the agent",
    )
-    tools: List[BaseTool] = Field(
-        default_factory=list, description="Tools at agent's disposal"
+    tools: list[BaseTool] = Field(
+        default_factory=list, description="Tools at agent's disposal",
    )

    # Execution Control Properties
    max_iterations: int = Field(
-        default=15, description="Maximum number of iterations for tool usage"
+        default=15, description="Maximum number of iterations for tool usage",
    )
-    max_execution_time: Optional[int] = Field(
-        default=None, description="Maximum execution time in seconds"
+    max_execution_time: int | None = Field(
+        default=None, description="Maximum execution time in seconds",
    )
    respect_context_window: bool = Field(
        default=True,
@@ -129,38 +129,38 @@ class LiteAgent(FlowTrackable, BaseModel):
        default=True,
        description="Whether to use stop words to prevent the LLM from using tools",
    )
-    request_within_rpm_limit: Optional[Callable[[], bool]] = Field(
+    request_within_rpm_limit: Callable[[], bool] | None = Field(
        default=None,
        description="Callback to check if the request is within the RPM limit",
    )
    i18n: I18N = Field(default=I18N(), description="Internationalization settings.")

    # Output and Formatting Properties
-    response_format: Optional[Type[BaseModel]] = Field(
-        default=None, description="Pydantic model for structured output"
+    response_format: type[BaseModel] | None = Field(
+        default=None, description="Pydantic model for structured output",
    )
    verbose: bool = Field(
-        default=False, description="Whether to print execution details"
+        default=False, description="Whether to print execution details",
    )
-    callbacks: List[Callable] = Field(
-        default=[], description="Callbacks to be used for the agent"
+    callbacks: list[Callable] = Field(
+        default=[], description="Callbacks to be used for the agent",
    )

    # State and Results
-    tools_results: List[Dict[str, Any]] = Field(
-        default=[], description="Results of the tools used by the agent."
+    tools_results: list[dict[str, Any]] = Field(
+        default=[], description="Results of the tools used by the agent.",
    )

    # Reference of Agent
-    original_agent: Optional[BaseAgent] = Field(
-        default=None, description="Reference to the agent that created this LiteAgent"
+    original_agent: BaseAgent | None = Field(
+        default=None, description="Reference to the agent that created this LiteAgent",
    )

    # Private Attributes
-    _parsed_tools: List[CrewStructuredTool] = PrivateAttr(default_factory=list)
+    _parsed_tools: list[CrewStructuredTool] = PrivateAttr(default_factory=list)
    _token_process: TokenProcess = PrivateAttr(default_factory=TokenProcess)
    _cache_handler: CacheHandler = PrivateAttr(default_factory=CacheHandler)
    _key: str = PrivateAttr(default_factory=lambda: str(uuid.uuid4()))
-    _messages: List[Dict[str, str]] = PrivateAttr(default_factory=list)
+    _messages: list[dict[str, str]] = PrivateAttr(default_factory=list)
_iterations: int = PrivateAttr(default=0) _iterations: int = PrivateAttr(default=0)
_printer: Printer = PrivateAttr(default_factory=Printer) _printer: Printer = PrivateAttr(default_factory=Printer)
@@ -169,7 +169,8 @@ class LiteAgent(FlowTrackable, BaseModel):
"""Set up the LLM and other components after initialization.""" """Set up the LLM and other components after initialization."""
self.llm = create_llm(self.llm) self.llm = create_llm(self.llm)
if not isinstance(self.llm, LLM): if not isinstance(self.llm, LLM):
raise ValueError("Unable to create LLM instance") msg = "Unable to create LLM instance"
raise ValueError(msg)
# Initialize callbacks # Initialize callbacks
token_callback = TokenCalcHandler(token_cost_process=self._token_process) token_callback = TokenCalcHandler(token_cost_process=self._token_process)
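[Editor's note] The ValueError change above is Ruff's EM101 fix: the message moves into a variable so the traceback line reads `raise ValueError(msg)` instead of repeating a long literal. The same shape in isolation (require_llm is illustrative):

    def require_llm(llm: object) -> None:
        if not llm:
            msg = "Unable to create LLM instance"
            raise ValueError(msg)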
@@ -194,9 +195,8 @@ class LiteAgent(FlowTrackable, BaseModel):
"""Return the original role for compatibility with tool interfaces.""" """Return the original role for compatibility with tool interfaces."""
return self.role return self.role
def kickoff(self, messages: Union[str, List[Dict[str, str]]]) -> LiteAgentOutput: def kickoff(self, messages: str | list[dict[str, str]]) -> LiteAgentOutput:
""" """Execute the agent with the given messages.
Execute the agent with the given messages.
Args: Args:
messages: Either a string query or a list of message dictionaries. messages: Either a string query or a list of message dictionaries.
@@ -205,6 +205,7 @@ class LiteAgent(FlowTrackable, BaseModel):
Returns: Returns:
LiteAgentOutput: The result of the agent execution. LiteAgentOutput: The result of the agent execution.
""" """
# Create agent info for event emission # Create agent info for event emission
agent_info = { agent_info = {
@@ -235,18 +236,18 @@ class LiteAgent(FlowTrackable, BaseModel):
# Execute the agent using invoke loop # Execute the agent using invoke loop
agent_finish = self._invoke_loop() agent_finish = self._invoke_loop()
formatted_result: Optional[BaseModel] = None formatted_result: BaseModel | None = None
if self.response_format: if self.response_format:
try: try:
# Cast to BaseModel to ensure type safety # Cast to BaseModel to ensure type safety
result = self.response_format.model_validate_json( result = self.response_format.model_validate_json(
agent_finish.output agent_finish.output,
) )
if isinstance(result, BaseModel): if isinstance(result, BaseModel):
formatted_result = result formatted_result = result
except Exception as e: except Exception as e:
self._printer.print( self._printer.print(
content=f"Failed to parse output into response format: {str(e)}", content=f"Failed to parse output into response format: {e!s}",
color="yellow", color="yellow",
) )
@@ -286,13 +287,12 @@ class LiteAgent(FlowTrackable, BaseModel):
error=str(e), error=str(e),
), ),
) )
raise e raise
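[Editor's note] Two idioms from the hunk above, shown together in a standalone sketch: `{e!s}` is the explicit-conversion form Ruff's RUF010 prefers over `{str(e)}`, and a bare `raise` re-raises the active exception, making `raise e` redundant (TRY201):

    import json

    def parse_strict(payload: str) -> dict:
        try:
            return json.loads(payload)
        except Exception as e:
            print(f"Failed to parse output into response format: {e!s}")
            raise  # re-raises the active exception; naming it again is redundant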
async def kickoff_async( async def kickoff_async(
self, messages: Union[str, List[Dict[str, str]]] self, messages: str | list[dict[str, str]],
) -> LiteAgentOutput: ) -> LiteAgentOutput:
""" """Execute the agent asynchronously with the given messages.
Execute the agent asynchronously with the given messages.
Args: Args:
messages: Either a string query or a list of message dictionaries. messages: Either a string query or a list of message dictionaries.
@@ -301,6 +301,7 @@ class LiteAgent(FlowTrackable, BaseModel):
Returns: Returns:
LiteAgentOutput: The result of the agent execution. LiteAgentOutput: The result of the agent execution.
""" """
return await asyncio.to_thread(self.kickoff, messages) return await asyncio.to_thread(self.kickoff, messages)
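[Editor's note] kickoff_async above is a thin wrapper: asyncio.to_thread runs the blocking kickoff on a worker thread so the event loop stays responsive. A runnable miniature of the same pattern (names illustrative):

    import asyncio
    import time

    def kickoff(query: str) -> str:
        time.sleep(0.1)  # stand-in for a blocking LLM call
        return f"answer to {query}"

    async def kickoff_async(query: str) -> str:
        return await asyncio.to_thread(kickoff, query)

    print(asyncio.run(kickoff_async("hi")))  # -> "answer to hi"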
@@ -319,7 +320,7 @@ class LiteAgent(FlowTrackable, BaseModel):
else: else:
# Use the prompt template for agents without tools # Use the prompt template for agents without tools
base_prompt = self.i18n.slice( base_prompt = self.i18n.slice(
"lite_agent_system_prompt_without_tools" "lite_agent_system_prompt_without_tools",
).format( ).format(
role=self.role, role=self.role,
backstory=self.backstory, backstory=self.backstory,
@@ -330,14 +331,14 @@ class LiteAgent(FlowTrackable, BaseModel):
if self.response_format: if self.response_format:
schema = generate_model_description(self.response_format) schema = generate_model_description(self.response_format)
base_prompt += self.i18n.slice("lite_agent_response_format").format( base_prompt += self.i18n.slice("lite_agent_response_format").format(
response_format=schema response_format=schema,
) )
return base_prompt return base_prompt
def _format_messages( def _format_messages(
self, messages: Union[str, List[Dict[str, str]]] self, messages: str | list[dict[str, str]],
) -> List[Dict[str, str]]: ) -> list[dict[str, str]]:
"""Format messages for the LLM.""" """Format messages for the LLM."""
if isinstance(messages, str): if isinstance(messages, str):
messages = [{"role": "user", "content": messages}] messages = [{"role": "user", "content": messages}]
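[Editor's note] The normalization in _format_messages is small but load-bearing: a bare string becomes a one-element user message list, so downstream code only ever sees list[dict[str, str]]. In isolation:

    def format_messages(messages: str | list[dict[str, str]]) -> list[dict[str, str]]:
        if isinstance(messages, str):
            return [{"role": "user", "content": messages}]
        return messages

    assert format_messages("hi") == [{"role": "user", "content": "hi"}]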
@@ -353,11 +354,11 @@ class LiteAgent(FlowTrackable, BaseModel):
return formatted_messages return formatted_messages
def _invoke_loop(self) -> AgentFinish: def _invoke_loop(self) -> AgentFinish:
""" """Run the agent's thought process until it reaches a conclusion or max iterations.
Run the agent's thought process until it reaches a conclusion or max iterations.
Returns: Returns:
AgentFinish: The final result of the agent execution. AgentFinish: The final result of the agent execution.
""" """
# Execute the agent loop # Execute the agent loop
formatted_answer = None formatted_answer = None
@@ -369,7 +370,7 @@ class LiteAgent(FlowTrackable, BaseModel):
printer=self._printer, printer=self._printer,
i18n=self.i18n, i18n=self.i18n,
messages=self._messages, messages=self._messages,
llm=cast(LLM, self.llm), llm=cast("LLM", self.llm),
callbacks=self._callbacks, callbacks=self._callbacks,
) )
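[Editor's note] cast(LLM, ...) becoming cast("LLM", ...) quotes the target type, likely Ruff's TC006: a quoted cast has no runtime dependency on the name, so the import can live behind a TYPE_CHECKING guard. The general shape, with Decimal standing in for LLM:

    from typing import TYPE_CHECKING, cast

    if TYPE_CHECKING:
        from decimal import Decimal  # only the type checker needs this import

    def as_decimal(value: object) -> "Decimal":
        return cast("Decimal", value)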
@@ -387,7 +388,7 @@ class LiteAgent(FlowTrackable, BaseModel):
try: try:
answer = get_llm_response( answer = get_llm_response(
llm=cast(LLM, self.llm), llm=cast("LLM", self.llm),
messages=self._messages, messages=self._messages,
callbacks=self._callbacks, callbacks=self._callbacks,
printer=self._printer, printer=self._printer,
@@ -407,7 +408,7 @@ class LiteAgent(FlowTrackable, BaseModel):
self, self,
event=LLMCallFailedEvent(error=str(e)), event=LLMCallFailedEvent(error=str(e)),
) )
raise e raise
formatted_answer = process_llm_response(answer, self.use_stop_words) formatted_answer = process_llm_response(answer, self.use_stop_words)
@@ -421,8 +422,8 @@ class LiteAgent(FlowTrackable, BaseModel):
agent_role=self.role, agent_role=self.role,
agent=self.original_agent, agent=self.original_agent,
) )
except Exception as e: except Exception:
raise e raise
formatted_answer = handle_agent_action_core( formatted_answer = handle_agent_action_core(
formatted_answer=formatted_answer, formatted_answer=formatted_answer,
@@ -443,20 +444,19 @@ class LiteAgent(FlowTrackable, BaseModel):
except Exception as e: except Exception as e:
if e.__class__.__module__.startswith("litellm"): if e.__class__.__module__.startswith("litellm"):
# Do not retry on litellm errors # Do not retry on litellm errors
raise e raise
if is_context_length_exceeded(e): if is_context_length_exceeded(e):
handle_context_length( handle_context_length(
respect_context_window=self.respect_context_window, respect_context_window=self.respect_context_window,
printer=self._printer, printer=self._printer,
messages=self._messages, messages=self._messages,
llm=cast(LLM, self.llm), llm=cast("LLM", self.llm),
callbacks=self._callbacks, callbacks=self._callbacks,
i18n=self.i18n, i18n=self.i18n,
) )
continue continue
else: handle_unknown_error(self._printer, e)
handle_unknown_error(self._printer, e) raise
raise e
finally: finally:
self._iterations += 1 self._iterations += 1
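[Editor's note] The restructured except block above applies flake8-return logic (Ruff RET506/RET507): once a branch ends in continue or raise, a trailing else adds indentation without meaning. A compact illustration:

    def drain(errors: list[Exception]) -> None:
        for e in errors:
            if isinstance(e, TimeoutError):
                continue  # recoverable; try the next one
            # no `else:` needed, this point is only reached for fatal errors
            raise e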
@@ -465,7 +465,7 @@ class LiteAgent(FlowTrackable, BaseModel):
self._show_logs(formatted_answer) self._show_logs(formatted_answer)
return formatted_answer return formatted_answer
def _show_logs(self, formatted_answer: Union[AgentAction, AgentFinish]): def _show_logs(self, formatted_answer: AgentAction | AgentFinish) -> None:
"""Show logs for the agent's execution.""" """Show logs for the agent's execution."""
show_agent_logs( show_agent_logs(
printer=self._printer, printer=self._printer,


@@ -6,17 +6,10 @@ import threading
import warnings import warnings
from collections import defaultdict from collections import defaultdict
from contextlib import contextmanager from contextlib import contextmanager
from types import SimpleNamespace
from typing import ( from typing import (
Any, Any,
DefaultDict,
Dict,
List,
Literal, Literal,
Optional,
Type,
TypedDict, TypedDict,
Union,
cast, cast,
) )
@@ -31,7 +24,6 @@ from crewai.utilities.events.llm_events import (
LLMCallType, LLMCallType,
LLMStreamChunkEvent, LLMStreamChunkEvent,
) )
from crewai.utilities.events.tool_usage_events import ToolExecutionErrorEvent
with warnings.catch_warnings(): with warnings.catch_warnings():
warnings.simplefilter("ignore", UserWarning) warnings.simplefilter("ignore", UserWarning)
@@ -55,7 +47,7 @@ load_dotenv()
class FilteredStream: class FilteredStream:
def __init__(self, original_stream): def __init__(self, original_stream) -> None:
self._original_stream = original_stream self._original_stream = original_stream
self._lock = threading.Lock() self._lock = threading.Lock()
@@ -210,7 +202,7 @@ def suppress_warnings():
with warnings.catch_warnings(): with warnings.catch_warnings():
warnings.filterwarnings("ignore") warnings.filterwarnings("ignore")
warnings.filterwarnings( warnings.filterwarnings(
"ignore", message="open_text is deprecated*", category=DeprecationWarning "ignore", message="open_text is deprecated*", category=DeprecationWarning,
) )
# Redirect stdout and stderr # Redirect stdout and stderr
@@ -226,14 +218,14 @@ def suppress_warnings():
class Delta(TypedDict): class Delta(TypedDict):
content: Optional[str] content: str | None
role: Optional[str] role: str | None
class StreamingChoices(TypedDict): class StreamingChoices(TypedDict):
delta: Delta delta: Delta
index: int index: int
finish_reason: Optional[str] finish_reason: str | None
class FunctionArgs(BaseModel): class FunctionArgs(BaseModel):
@@ -249,29 +241,31 @@ class LLM(BaseLLM):
def __init__( def __init__(
self, self,
model: str, model: str,
timeout: Optional[Union[float, int]] = None, timeout: float | None = None,
temperature: Optional[float] = None, temperature: float | None = None,
top_p: Optional[float] = None, top_p: float | None = None,
n: Optional[int] = None, n: int | None = None,
stop: Optional[Union[str, List[str]]] = None, stop: str | list[str] | None = None,
max_completion_tokens: Optional[int] = None, max_completion_tokens: int | None = None,
max_tokens: Optional[int] = None, max_tokens: int | None = None,
presence_penalty: Optional[float] = None, presence_penalty: float | None = None,
frequency_penalty: Optional[float] = None, frequency_penalty: float | None = None,
logit_bias: Optional[Dict[int, float]] = None, logit_bias: dict[int, float] | None = None,
response_format: Optional[Type[BaseModel]] = None, response_format: type[BaseModel] | None = None,
seed: Optional[int] = None, seed: int | None = None,
logprobs: Optional[int] = None, logprobs: int | None = None,
top_logprobs: Optional[int] = None, top_logprobs: int | None = None,
base_url: Optional[str] = None, base_url: str | None = None,
api_base: Optional[str] = None, api_base: str | None = None,
api_version: Optional[str] = None, api_version: str | None = None,
api_key: Optional[str] = None, api_key: str | None = None,
callbacks: List[Any] = [], callbacks: list[Any] | None = None,
reasoning_effort: Optional[Literal["none", "low", "medium", "high"]] = None, reasoning_effort: Literal["none", "low", "medium", "high"] | None = None,
stream: bool = False, stream: bool = False,
**kwargs, **kwargs,
): ) -> None:
if callbacks is None:
callbacks = []
self.model = model self.model = model
self.timeout = timeout self.timeout = timeout
self.temperature = temperature self.temperature = temperature
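[Editor's note] The signature change for callbacks is flake8-bugbear's B006 fix: a mutable default such as `callbacks: List[Any] = []` is created once and shared by every call, so the lint swaps in a None sentinel plus the `if callbacks is None` guard seen above. A demonstration of the hazard and the fix:

    from typing import Any

    def register(callbacks: list[Any] | None = None) -> list[Any]:
        if callbacks is None:
            callbacks = []  # a fresh list per call, never shared
        callbacks.append("token-counter")
        return callbacks

    assert register() == ["token-counter"]
    assert register() == ["token-counter"]  # with a [] default this would be length 2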
@@ -301,7 +295,7 @@ class LLM(BaseLLM):
# Normalize self.stop to always be a List[str] # Normalize self.stop to always be a List[str]
if stop is None: if stop is None:
self.stop: List[str] = [] self.stop: list[str] = []
elif isinstance(stop, str): elif isinstance(stop, str):
self.stop = [stop] self.stop = [stop]
else: else:
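[Editor's note] The stop-word normalization above as a standalone sketch: the constructor accepts None, a single string, or a list, and always stores list[str]:

    def normalize_stop(stop: str | list[str] | None) -> list[str]:
        if stop is None:
            return []
        if isinstance(stop, str):
            return [stop]
        return stop

    assert normalize_stop("END") == ["END"]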
@@ -318,15 +312,16 @@ class LLM(BaseLLM):
Returns: Returns:
bool: True if the model is from Anthropic, False otherwise. bool: True if the model is from Anthropic, False otherwise.
""" """
ANTHROPIC_PREFIXES = ("anthropic/", "claude-", "claude/") ANTHROPIC_PREFIXES = ("anthropic/", "claude-", "claude/")
return any(prefix in model.lower() for prefix in ANTHROPIC_PREFIXES) return any(prefix in model.lower() for prefix in ANTHROPIC_PREFIXES)
def _prepare_completion_params( def _prepare_completion_params(
self, self,
messages: Union[str, List[Dict[str, str]]], messages: str | list[dict[str, str]],
tools: Optional[List[dict]] = None, tools: list[dict] | None = None,
) -> Dict[str, Any]: ) -> dict[str, Any]:
"""Prepare parameters for the completion call. """Prepare parameters for the completion call.
Args: Args:
@@ -337,6 +332,7 @@ class LLM(BaseLLM):
Returns: Returns:
Dict[str, Any]: Parameters for the completion call Dict[str, Any]: Parameters for the completion call
""" """
# --- 1) Format messages according to provider requirements # --- 1) Format messages according to provider requirements
if isinstance(messages, str): if isinstance(messages, str):
@@ -375,9 +371,9 @@ class LLM(BaseLLM):
def _handle_streaming_response( def _handle_streaming_response(
self, self,
params: Dict[str, Any], params: dict[str, Any],
callbacks: Optional[List[Any]] = None, callbacks: list[Any] | None = None,
available_functions: Optional[Dict[str, Any]] = None, available_functions: dict[str, Any] | None = None,
) -> str: ) -> str:
"""Handle a streaming response from the LLM. """Handle a streaming response from the LLM.
@@ -391,6 +387,7 @@ class LLM(BaseLLM):
Raises: Raises:
Exception: If no content is received from the streaming response Exception: If no content is received from the streaming response
""" """
# --- 1) Initialize response tracking # --- 1) Initialize response tracking
full_response = "" full_response = ""
@@ -399,8 +396,8 @@ class LLM(BaseLLM):
usage_info = None usage_info = None
tool_calls = None tool_calls = None
accumulated_tool_args: DefaultDict[int, AccumulatedToolArgs] = defaultdict( accumulated_tool_args: defaultdict[int, AccumulatedToolArgs] = defaultdict(
AccumulatedToolArgs AccumulatedToolArgs,
) )
# --- 2) Make sure stream is set to True and include usage metrics # --- 2) Make sure stream is set to True and include usage metrics
@@ -424,16 +421,16 @@ class LLM(BaseLLM):
choices = chunk["choices"] choices = chunk["choices"]
elif hasattr(chunk, "choices"): elif hasattr(chunk, "choices"):
# Check if choices is not a type but an actual attribute with value # Check if choices is not a type but an actual attribute with value
if not isinstance(getattr(chunk, "choices"), type): if not isinstance(chunk.choices, type):
choices = getattr(chunk, "choices") choices = chunk.choices
# Try to extract usage information if available # Try to extract usage information if available
if isinstance(chunk, dict) and "usage" in chunk: if isinstance(chunk, dict) and "usage" in chunk:
usage_info = chunk["usage"] usage_info = chunk["usage"]
elif hasattr(chunk, "usage"): elif hasattr(chunk, "usage"):
# Check if usage is not a type but an actual attribute with value # Check if usage is not a type but an actual attribute with value
if not isinstance(getattr(chunk, "usage"), type): if not isinstance(chunk.usage, type):
usage_info = getattr(chunk, "usage") usage_info = chunk.usage
if choices and len(choices) > 0: if choices and len(choices) > 0:
choice = choices[0] choice = choices[0]
@@ -443,7 +440,7 @@ class LLM(BaseLLM):
if isinstance(choice, dict) and "delta" in choice: if isinstance(choice, dict) and "delta" in choice:
delta = choice["delta"] delta = choice["delta"]
elif hasattr(choice, "delta"): elif hasattr(choice, "delta"):
delta = getattr(choice, "delta") delta = choice.delta
# Extract content from delta # Extract content from delta
if delta: if delta:
@@ -453,7 +450,7 @@ class LLM(BaseLLM):
chunk_content = delta["content"] chunk_content = delta["content"]
# Handle object format # Handle object format
elif hasattr(delta, "content"): elif hasattr(delta, "content"):
chunk_content = getattr(delta, "content") chunk_content = delta.content
# Handle case where content might be None or empty # Handle case where content might be None or empty
if chunk_content is None and isinstance(delta, dict): if chunk_content is None and isinstance(delta, dict):
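[Editor's note] The getattr rewrites throughout this hunk are flake8-bugbear B009: getattr with a constant attribute name is plain attribute access. Minimal form of the fix:

    class Chunk:
        def __init__(self) -> None:
            self.choices = ["delta"]

    chunk = Chunk()
    # Before: choices = getattr(chunk, "choices")
    choices = chunk.choices  # identical behavior, clearer intent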
@@ -491,21 +488,21 @@ class LLM(BaseLLM):
# --- 4) Fallback to non-streaming if no content received # --- 4) Fallback to non-streaming if no content received
if not full_response.strip() and chunk_count == 0: if not full_response.strip() and chunk_count == 0:
logging.warning( logging.warning(
"No chunks received in streaming response, falling back to non-streaming" "No chunks received in streaming response, falling back to non-streaming",
) )
non_streaming_params = params.copy() non_streaming_params = params.copy()
non_streaming_params["stream"] = False non_streaming_params["stream"] = False
non_streaming_params.pop( non_streaming_params.pop(
"stream_options", None "stream_options", None,
) # Remove stream_options for non-streaming call ) # Remove stream_options for non-streaming call
return self._handle_non_streaming_response( return self._handle_non_streaming_response(
non_streaming_params, callbacks, available_functions non_streaming_params, callbacks, available_functions,
) )
# --- 5) Handle empty response with chunks # --- 5) Handle empty response with chunks
if not full_response.strip() and chunk_count > 0: if not full_response.strip() and chunk_count > 0:
logging.warning( logging.warning(
f"Received {chunk_count} chunks but no content was extracted" f"Received {chunk_count} chunks but no content was extracted",
) )
if last_chunk is not None: if last_chunk is not None:
try: try:
@@ -514,8 +511,8 @@ class LLM(BaseLLM):
if isinstance(last_chunk, dict) and "choices" in last_chunk: if isinstance(last_chunk, dict) and "choices" in last_chunk:
choices = last_chunk["choices"] choices = last_chunk["choices"]
elif hasattr(last_chunk, "choices"): elif hasattr(last_chunk, "choices"):
if not isinstance(getattr(last_chunk, "choices"), type): if not isinstance(last_chunk.choices, type):
choices = getattr(last_chunk, "choices") choices = last_chunk.choices
if choices and len(choices) > 0: if choices and len(choices) > 0:
choice = choices[0] choice = choices[0]
@@ -525,30 +522,31 @@ class LLM(BaseLLM):
if isinstance(choice, dict) and "message" in choice: if isinstance(choice, dict) and "message" in choice:
message = choice["message"] message = choice["message"]
elif hasattr(choice, "message"): elif hasattr(choice, "message"):
message = getattr(choice, "message") message = choice.message
if message: if message:
content = None content = None
if isinstance(message, dict) and "content" in message: if isinstance(message, dict) and "content" in message:
content = message["content"] content = message["content"]
elif hasattr(message, "content"): elif hasattr(message, "content"):
content = getattr(message, "content") content = message.content
if content: if content:
full_response = content full_response = content
logging.info( logging.info(
f"Extracted content from last chunk message: {full_response}" f"Extracted content from last chunk message: {full_response}",
) )
except Exception as e: except Exception as e:
logging.debug(f"Error extracting content from last chunk: {e}") logging.debug(f"Error extracting content from last chunk: {e}")
logging.debug( logging.debug(
f"Last chunk format: {type(last_chunk)}, content: {last_chunk}" f"Last chunk format: {type(last_chunk)}, content: {last_chunk}",
) )
# --- 6) If still empty, raise an error instead of using a default response # --- 6) If still empty, raise an error instead of using a default response
if not full_response.strip() and len(accumulated_tool_args) == 0: if not full_response.strip() and len(accumulated_tool_args) == 0:
msg = "No content received from streaming response. Received empty chunks or failed to extract content."
raise Exception( raise Exception(
"No content received from streaming response. Received empty chunks or failed to extract content." msg,
) )
# --- 7) Check for tool calls in the final response # --- 7) Check for tool calls in the final response
@@ -559,8 +557,8 @@ class LLM(BaseLLM):
if isinstance(last_chunk, dict) and "choices" in last_chunk: if isinstance(last_chunk, dict) and "choices" in last_chunk:
choices = last_chunk["choices"] choices = last_chunk["choices"]
elif hasattr(last_chunk, "choices"): elif hasattr(last_chunk, "choices"):
if not isinstance(getattr(last_chunk, "choices"), type): if not isinstance(last_chunk.choices, type):
choices = getattr(last_chunk, "choices") choices = last_chunk.choices
if choices and len(choices) > 0: if choices and len(choices) > 0:
choice = choices[0] choice = choices[0]
@@ -569,13 +567,13 @@ class LLM(BaseLLM):
if isinstance(choice, dict) and "message" in choice: if isinstance(choice, dict) and "message" in choice:
message = choice["message"] message = choice["message"]
elif hasattr(choice, "message"): elif hasattr(choice, "message"):
message = getattr(choice, "message") message = choice.message
if message: if message:
if isinstance(message, dict) and "tool_calls" in message: if isinstance(message, dict) and "tool_calls" in message:
tool_calls = message["tool_calls"] tool_calls = message["tool_calls"]
elif hasattr(message, "tool_calls"): elif hasattr(message, "tool_calls"):
tool_calls = getattr(message, "tool_calls") tool_calls = message.tool_calls
except Exception as e: except Exception as e:
logging.debug(f"Error checking for tool calls: {e}") logging.debug(f"Error checking for tool calls: {e}")
# --- 8) If no tool calls or no available functions, return the text response directly # --- 8) If no tool calls or no available functions, return the text response directly
@@ -605,9 +603,9 @@ class LLM(BaseLLM):
# decide whether to summarize the content or abort based on the respect_context_window flag. # decide whether to summarize the content or abort based on the respect_context_window flag.
raise LLMContextLengthExceededException(str(e)) raise LLMContextLengthExceededException(str(e))
except Exception as e: except Exception as e:
logging.error(f"Error in streaming response: {str(e)}") logging.exception(f"Error in streaming response: {e!s}")
if full_response.strip(): if full_response.strip():
logging.warning(f"Returning partial response despite error: {str(e)}") logging.warning(f"Returning partial response despite error: {e!s}")
self._handle_emit_call_events(full_response, LLMCallType.LLM_CALL) self._handle_emit_call_events(full_response, LLMCallType.LLM_CALL)
return full_response return full_response
@@ -617,13 +615,14 @@ class LLM(BaseLLM):
self, self,
event=LLMCallFailedEvent(error=str(e)), event=LLMCallFailedEvent(error=str(e)),
) )
raise Exception(f"Failed to get streaming response: {str(e)}") msg = f"Failed to get streaming response: {e!s}"
raise Exception(msg)
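[Editor's note] logging.error becoming logging.exception inside the except block is Ruff's TRY400 fix: both log at ERROR level, but logging.exception also records the active traceback. In isolation:

    import logging

    try:
        1 / 0
    except ZeroDivisionError as e:
        logging.exception(f"Error in streaming response: {e!s}")  # traceback included automatically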
def _handle_streaming_tool_calls( def _handle_streaming_tool_calls(
self, self,
tool_calls: List[ChatCompletionDeltaToolCall], tool_calls: list[ChatCompletionDeltaToolCall],
accumulated_tool_args: DefaultDict[int, AccumulatedToolArgs], accumulated_tool_args: defaultdict[int, AccumulatedToolArgs],
available_functions: Optional[Dict[str, Any]] = None, available_functions: dict[str, Any] | None = None,
) -> None | str: ) -> None | str:
for tool_call in tool_calls: for tool_call in tool_calls:
current_tool_accumulator = accumulated_tool_args[tool_call.index] current_tool_accumulator = accumulated_tool_args[tool_call.index]
@@ -662,9 +661,9 @@ class LLM(BaseLLM):
def _handle_streaming_callbacks( def _handle_streaming_callbacks(
self, self,
callbacks: Optional[List[Any]], callbacks: list[Any] | None,
usage_info: Optional[Dict[str, Any]], usage_info: dict[str, Any] | None,
last_chunk: Optional[Any], last_chunk: Any | None,
) -> None: ) -> None:
"""Handle callbacks with usage info for streaming responses. """Handle callbacks with usage info for streaming responses.
@@ -672,6 +671,7 @@ class LLM(BaseLLM):
callbacks: Optional list of callback functions callbacks: Optional list of callback functions
usage_info: Usage information collected during streaming usage_info: Usage information collected during streaming
last_chunk: The last chunk received from the streaming response last_chunk: The last chunk received from the streaming response
""" """
if callbacks and len(callbacks) > 0: if callbacks and len(callbacks) > 0:
for callback in callbacks: for callback in callbacks:
@@ -688,9 +688,9 @@ class LLM(BaseLLM):
usage_info = last_chunk["usage"] usage_info = last_chunk["usage"]
elif hasattr(last_chunk, "usage"): elif hasattr(last_chunk, "usage"):
if not isinstance( if not isinstance(
getattr(last_chunk, "usage"), type last_chunk.usage, type,
): ):
usage_info = getattr(last_chunk, "usage") usage_info = last_chunk.usage
except Exception as e: except Exception as e:
logging.debug(f"Error extracting usage info: {e}") logging.debug(f"Error extracting usage info: {e}")
@@ -704,9 +704,9 @@ class LLM(BaseLLM):
def _handle_non_streaming_response( def _handle_non_streaming_response(
self, self,
params: Dict[str, Any], params: dict[str, Any],
callbacks: Optional[List[Any]] = None, callbacks: list[Any] | None = None,
available_functions: Optional[Dict[str, Any]] = None, available_functions: dict[str, Any] | None = None,
) -> str: ) -> str:
"""Handle a non-streaming response from the LLM. """Handle a non-streaming response from the LLM.
@@ -717,6 +717,7 @@ class LLM(BaseLLM):
Returns: Returns:
str: The response text str: The response text
""" """
# --- 1) Make the completion call # --- 1) Make the completion call
try: try:
@@ -731,7 +732,7 @@ class LLM(BaseLLM):
raise LLMContextLengthExceededException(str(e)) raise LLMContextLengthExceededException(str(e))
# --- 2) Extract response message and content # --- 2) Extract response message and content
response_message = cast(Choices, cast(ModelResponse, response).choices)[ response_message = cast("Choices", cast("ModelResponse", response).choices)[
0 0
].message ].message
text_response = response_message.content or "" text_response = response_message.content or ""
@@ -768,9 +769,9 @@ class LLM(BaseLLM):
def _handle_tool_call( def _handle_tool_call(
self, self,
tool_calls: List[Any], tool_calls: list[Any],
available_functions: Optional[Dict[str, Any]] = None, available_functions: dict[str, Any] | None = None,
) -> Optional[str]: ) -> str | None:
"""Handle a tool call from the LLM. """Handle a tool call from the LLM.
Args: Args:
@@ -779,6 +780,7 @@ class LLM(BaseLLM):
Returns: Returns:
Optional[str]: The result of the tool call, or None if no tool call was made Optional[str]: The result of the tool call, or None if no tool call was made
""" """
# --- 1) Validate tool calls and available functions # --- 1) Validate tool calls and available functions
if not tool_calls or not available_functions: if not tool_calls or not available_functions:
@@ -805,23 +807,23 @@ class LLM(BaseLLM):
except Exception as e: except Exception as e:
# --- 3.4) Handle execution errors # --- 3.4) Handle execution errors
fn = available_functions.get( fn = available_functions.get(
function_name, lambda: None function_name, lambda: None,
) # Ensure fn is always a callable ) # Ensure fn is always a callable
logging.error(f"Error executing function '{function_name}': {e}") logging.exception(f"Error executing function '{function_name}': {e}")
assert hasattr(crewai_event_bus, "emit") assert hasattr(crewai_event_bus, "emit")
crewai_event_bus.emit( crewai_event_bus.emit(
self, self,
event=LLMCallFailedEvent(error=f"Tool execution error: {str(e)}"), event=LLMCallFailedEvent(error=f"Tool execution error: {e!s}"),
) )
return None return None
def call( def call(
self, self,
messages: Union[str, List[Dict[str, str]]], messages: str | list[dict[str, str]],
tools: Optional[List[dict]] = None, tools: list[dict] | None = None,
callbacks: Optional[List[Any]] = None, callbacks: list[Any] | None = None,
available_functions: Optional[Dict[str, Any]] = None, available_functions: dict[str, Any] | None = None,
) -> Union[str, Any]: ) -> str | Any:
"""High-level LLM call method. """High-level LLM call method.
Args: Args:
@@ -844,6 +846,7 @@ class LLM(BaseLLM):
TypeError: If messages format is invalid TypeError: If messages format is invalid
ValueError: If response format is not supported ValueError: If response format is not supported
LLMContextLengthExceededException: If input exceeds model's context limit LLMContextLengthExceededException: If input exceeds model's context limit
""" """
# --- 1) Emit call started event # --- 1) Emit call started event
assert hasattr(crewai_event_bus, "emit") assert hasattr(crewai_event_bus, "emit")
@@ -882,12 +885,11 @@ class LLM(BaseLLM):
# --- 7) Make the completion call and handle response # --- 7) Make the completion call and handle response
if self.stream: if self.stream:
return self._handle_streaming_response( return self._handle_streaming_response(
params, callbacks, available_functions params, callbacks, available_functions,
)
else:
return self._handle_non_streaming_response(
params, callbacks, available_functions
) )
return self._handle_non_streaming_response(
params, callbacks, available_functions,
)
except LLMContextLengthExceededException: except LLMContextLengthExceededException:
# Re-raise LLMContextLengthExceededException as it should be handled # Re-raise LLMContextLengthExceededException as it should be handled
@@ -900,15 +902,16 @@ class LLM(BaseLLM):
self, self,
event=LLMCallFailedEvent(error=str(e)), event=LLMCallFailedEvent(error=str(e)),
) )
logging.error(f"LiteLLM call failed: {str(e)}") logging.exception(f"LiteLLM call failed: {e!s}")
raise raise
def _handle_emit_call_events(self, response: Any, call_type: LLMCallType): def _handle_emit_call_events(self, response: Any, call_type: LLMCallType) -> None:
"""Handle the events for the LLM call. """Handle the events for the LLM call.
Args: Args:
response (str): The response from the LLM call. response (str): The response from the LLM call.
call_type (str): The type of call, either "tool_call" or "llm_call". call_type (str): The type of call, either "tool_call" or "llm_call".
""" """
assert hasattr(crewai_event_bus, "emit") assert hasattr(crewai_event_bus, "emit")
crewai_event_bus.emit( crewai_event_bus.emit(
@@ -917,8 +920,8 @@ class LLM(BaseLLM):
) )
def _format_messages_for_provider( def _format_messages_for_provider(
self, messages: List[Dict[str, str]] self, messages: list[dict[str, str]],
) -> List[Dict[str, str]]: ) -> list[dict[str, str]]:
"""Format messages according to provider requirements. """Format messages according to provider requirements.
Args: Args:
@@ -931,15 +934,18 @@ class LLM(BaseLLM):
Raises: Raises:
TypeError: If messages is None or contains invalid message format. TypeError: If messages is None or contains invalid message format.
""" """
if messages is None: if messages is None:
raise TypeError("Messages cannot be None") msg = "Messages cannot be None"
raise TypeError(msg)
# Validate message format first # Validate message format first
for msg in messages: for msg in messages:
if not isinstance(msg, dict) or "role" not in msg or "content" not in msg: if not isinstance(msg, dict) or "role" not in msg or "content" not in msg:
msg = "Invalid message format. Each message must be a dict with 'role' and 'content' keys"
raise TypeError( raise TypeError(
"Invalid message format. Each message must be a dict with 'role' and 'content' keys" msg,
) )
# Handle O1 models specially # Handle O1 models specially
@@ -949,7 +955,7 @@ class LLM(BaseLLM):
# Convert system messages to assistant messages # Convert system messages to assistant messages
if msg["role"] == "system": if msg["role"] == "system":
formatted_messages.append( formatted_messages.append(
{"role": "assistant", "content": msg["content"]} {"role": "assistant", "content": msg["content"]},
) )
else: else:
formatted_messages.append(msg) formatted_messages.append(msg)
@@ -977,9 +983,8 @@ class LLM(BaseLLM):
return messages return messages
def _get_custom_llm_provider(self) -> Optional[str]: def _get_custom_llm_provider(self) -> str | None:
""" """Derives the custom_llm_provider from the model string.
Derives the custom_llm_provider from the model string.
- For example, if the model is "openrouter/deepseek/deepseek-chat", returns "openrouter". - For example, if the model is "openrouter/deepseek/deepseek-chat", returns "openrouter".
- If the model is "gemini/gemini-1.5-pro", returns "gemini". - If the model is "gemini/gemini-1.5-pro", returns "gemini".
- If there is no '/', defaults to "openai". - If there is no '/', defaults to "openai".
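[Editor's note] A sketch of the provider derivation the docstring above describes; the full body is outside this hunk, so treat this as assumed behavior: the prefix before the first '/' names the provider, and a bare model name falls through to the default.

    def get_custom_llm_provider(model: str) -> str | None:
        # "openrouter/deepseek/deepseek-chat" -> "openrouter"
        # "gemini/gemini-1.5-pro" -> "gemini"
        if "/" in model:
            return model.split("/")[0]
        return None  # callers treat None as the default (openai) provider

    assert get_custom_llm_provider("gemini/gemini-1.5-pro") == "gemini"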
@@ -989,8 +994,7 @@ class LLM(BaseLLM):
return None return None
def _validate_call_params(self) -> None: def _validate_call_params(self) -> None:
""" """Validate parameters before making a call. Currently this only checks if
Validate parameters before making a call. Currently this only checks if
a response_format is provided and whether the model supports it. a response_format is provided and whether the model supports it.
The custom_llm_provider is dynamically determined from the model: The custom_llm_provider is dynamically determined from the model:
- E.g., "openrouter/deepseek/deepseek-chat" yields "openrouter" - E.g., "openrouter/deepseek/deepseek-chat" yields "openrouter"
@@ -1002,19 +1006,22 @@ class LLM(BaseLLM):
model=self.model, model=self.model,
custom_llm_provider=provider, custom_llm_provider=provider,
): ):
raise ValueError( msg = (
f"The model {self.model} does not support response_format for provider '{provider}'. " f"The model {self.model} does not support response_format for provider '{provider}'. "
"Please remove response_format or use a supported model." "Please remove response_format or use a supported model."
) )
raise ValueError(
msg,
)
def supports_function_calling(self) -> bool: def supports_function_calling(self) -> bool:
try: try:
provider = self._get_custom_llm_provider() provider = self._get_custom_llm_provider()
return litellm.utils.supports_function_calling( return litellm.utils.supports_function_calling(
self.model, custom_llm_provider=provider self.model, custom_llm_provider=provider,
) )
except Exception as e: except Exception as e:
logging.error(f"Failed to check function calling support: {str(e)}") logging.exception(f"Failed to check function calling support: {e!s}")
return False return False
def supports_stop_words(self) -> bool: def supports_stop_words(self) -> bool:
@@ -1022,16 +1029,16 @@ class LLM(BaseLLM):
params = get_supported_openai_params(model=self.model) params = get_supported_openai_params(model=self.model)
return params is not None and "stop" in params return params is not None and "stop" in params
except Exception as e: except Exception as e:
logging.error(f"Failed to get supported params: {str(e)}") logging.exception(f"Failed to get supported params: {e!s}")
return False return False
def get_context_window_size(self) -> int: def get_context_window_size(self) -> int:
""" """Returns the context window size, using 75% of the maximum to avoid
Returns the context window size, using 75% of the maximum to avoid
cutting off messages mid-thread. cutting off messages mid-thread.
Raises: Raises:
ValueError: If a model's context window size is outside valid bounds (1024-2097152) ValueError: If a model's context window size is outside valid bounds (1024-2097152)
""" """
if self.context_window_size != 0: if self.context_window_size != 0:
return self.context_window_size return self.context_window_size
@@ -1042,21 +1049,21 @@ class LLM(BaseLLM):
# Validate all context window sizes # Validate all context window sizes
for key, value in LLM_CONTEXT_WINDOW_SIZES.items(): for key, value in LLM_CONTEXT_WINDOW_SIZES.items():
if value < MIN_CONTEXT or value > MAX_CONTEXT: if value < MIN_CONTEXT or value > MAX_CONTEXT:
msg = f"Context window for {key} must be between {MIN_CONTEXT} and {MAX_CONTEXT}"
raise ValueError( raise ValueError(
f"Context window for {key} must be between {MIN_CONTEXT} and {MAX_CONTEXT}" msg,
) )
self.context_window_size = int( self.context_window_size = int(
DEFAULT_CONTEXT_WINDOW_SIZE * CONTEXT_WINDOW_USAGE_RATIO DEFAULT_CONTEXT_WINDOW_SIZE * CONTEXT_WINDOW_USAGE_RATIO,
) )
for key, value in LLM_CONTEXT_WINDOW_SIZES.items(): for key, value in LLM_CONTEXT_WINDOW_SIZES.items():
if self.model.startswith(key): if self.model.startswith(key):
self.context_window_size = int(value * CONTEXT_WINDOW_USAGE_RATIO) self.context_window_size = int(value * CONTEXT_WINDOW_USAGE_RATIO)
return self.context_window_size return self.context_window_size
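[Editor's note] The loop above picks the matching table entry and keeps 75% of it (CONTEXT_WINDOW_USAGE_RATIO) as usable context, to avoid cutting off messages mid-thread. A self-contained sketch with sample, not actual, table values:

    CONTEXT_WINDOW_USAGE_RATIO = 0.75
    LLM_CONTEXT_WINDOW_SIZES = {"gpt-4": 8_192, "gpt-4o": 128_000}  # sample values

    def usable_context(model: str, default: int = 8_192) -> int:
        size = int(default * CONTEXT_WINDOW_USAGE_RATIO)
        for key, value in LLM_CONTEXT_WINDOW_SIZES.items():
            if model.startswith(key):
                size = int(value * CONTEXT_WINDOW_USAGE_RATIO)
        return size

    assert usable_context("gpt-4o-mini") == 96_000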
def set_callbacks(self, callbacks: List[Any]): def set_callbacks(self, callbacks: list[Any]) -> None:
""" """Attempt to keep a single set of callbacks in litellm by removing old
Attempt to keep a single set of callbacks in litellm by removing old
duplicates and adding new ones. duplicates and adding new ones.
""" """
with suppress_warnings(): with suppress_warnings():
@@ -1071,9 +1078,8 @@ class LLM(BaseLLM):
litellm.callbacks = callbacks litellm.callbacks = callbacks
def set_env_callbacks(self): def set_env_callbacks(self) -> None:
""" """Sets the success and failure callbacks for the LiteLLM library from environment variables.
Sets the success and failure callbacks for the LiteLLM library from environment variables.
This method reads the `LITELLM_SUCCESS_CALLBACKS` and `LITELLM_FAILURE_CALLBACKS` This method reads the `LITELLM_SUCCESS_CALLBACKS` and `LITELLM_FAILURE_CALLBACKS`
environment variables, which should contain comma-separated lists of callback names. environment variables, which should contain comma-separated lists of callback names.
@@ -1089,6 +1095,7 @@ class LLM(BaseLLM):
This will set `litellm.success_callback` to ["langfuse", "langsmith"] and This will set `litellm.success_callback` to ["langfuse", "langsmith"] and
`litellm.failure_callback` to ["langfuse"]. `litellm.failure_callback` to ["langfuse"].
""" """
with suppress_warnings(): with suppress_warnings():
success_callbacks_str = os.environ.get("LITELLM_SUCCESS_CALLBACKS", "") success_callbacks_str = os.environ.get("LITELLM_SUCCESS_CALLBACKS", "")


@@ -1,5 +1,5 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, List, Optional, Union from typing import Any
class BaseLLM(ABC): class BaseLLM(ABC):
@@ -17,17 +17,18 @@ class BaseLLM(ABC):
Attributes: Attributes:
stop (list): A list of stop sequences that the LLM should use to stop generation. stop (list): A list of stop sequences that the LLM should use to stop generation.
This is used by the CrewAgentExecutor and other components. This is used by the CrewAgentExecutor and other components.
""" """
model: str model: str
temperature: Optional[float] = None temperature: float | None = None
stop: Optional[List[str]] = None stop: list[str] | None = None
def __init__( def __init__(
self, self,
model: str, model: str,
temperature: Optional[float] = None, temperature: float | None = None,
): ) -> None:
"""Initialize the BaseLLM with default attributes. """Initialize the BaseLLM with default attributes.
This constructor sets default values for attributes that are expected This constructor sets default values for attributes that are expected
@@ -43,11 +44,11 @@ class BaseLLM(ABC):
@abstractmethod @abstractmethod
def call( def call(
self, self,
messages: Union[str, List[Dict[str, str]]], messages: str | list[dict[str, str]],
tools: Optional[List[dict]] = None, tools: list[dict] | None = None,
callbacks: Optional[List[Any]] = None, callbacks: list[Any] | None = None,
available_functions: Optional[Dict[str, Any]] = None, available_functions: dict[str, Any] | None = None,
) -> Union[str, Any]: ) -> str | Any:
"""Call the LLM with the given messages. """Call the LLM with the given messages.
Args: Args:
@@ -70,14 +71,15 @@ class BaseLLM(ABC):
ValueError: If the messages format is invalid. ValueError: If the messages format is invalid.
TimeoutError: If the LLM request times out. TimeoutError: If the LLM request times out.
RuntimeError: If the LLM request fails for other reasons. RuntimeError: If the LLM request fails for other reasons.
""" """
pass
def supports_stop_words(self) -> bool: def supports_stop_words(self) -> bool:
"""Check if the LLM supports stop words. """Check if the LLM supports stop words.
Returns: Returns:
bool: True if the LLM supports stop words, False otherwise. bool: True if the LLM supports stop words, False otherwise.
""" """
return True # Default implementation assumes support for stop words return True # Default implementation assumes support for stop words
@@ -86,6 +88,7 @@ class BaseLLM(ABC):
Returns: Returns:
int: The number of tokens/characters the model can handle. int: The number of tokens/characters the model can handle.
""" """
# Default implementation - subclasses should override with model-specific values # Default implementation - subclasses should override with model-specific values
return 4096 return 4096
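[Editor's note] In the hunks shown, only call() is abstract on the updated BaseLLM; supports_stop_words() and the context-window default are inherited. A minimal concrete subclass as a sketch (EchoLLM is made up, and the import path assumes crewai's layout):

    from typing import Any

    from crewai.llms.base_llm import BaseLLM

    class EchoLLM(BaseLLM):
        def call(
            self,
            messages: str | list[dict[str, str]],
            tools: list[dict] | None = None,
            callbacks: list[Any] | None = None,
            available_functions: dict[str, Any] | None = None,
        ) -> str | Any:
            # Echo the last user content; enough to exercise the interface.
            if isinstance(messages, str):
                return messages
            return messages[-1]["content"]

    assert EchoLLM(model="echo-1").call("hi") == "hi"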


@@ -1,4 +1,4 @@
from typing import Any, Dict, List, Optional, Union from typing import Any
import aisuite as ai import aisuite as ai
@@ -6,17 +6,17 @@ from crewai.llms.base_llm import BaseLLM
class AISuiteLLM(BaseLLM): class AISuiteLLM(BaseLLM):
def __init__(self, model: str, temperature: Optional[float] = None, **kwargs): def __init__(self, model: str, temperature: float | None = None, **kwargs) -> None:
super().__init__(model, temperature, **kwargs) super().__init__(model, temperature, **kwargs)
self.client = ai.Client() self.client = ai.Client()
def call( def call(
self, self,
messages: Union[str, List[Dict[str, str]]], messages: str | list[dict[str, str]],
tools: Optional[List[dict]] = None, tools: list[dict] | None = None,
callbacks: Optional[List[Any]] = None, callbacks: list[Any] | None = None,
available_functions: Optional[Dict[str, Any]] = None, available_functions: dict[str, Any] | None = None,
) -> Union[str, Any]: ) -> str | Any:
completion_params = self._prepare_completion_params(messages, tools) completion_params = self._prepare_completion_params(messages, tools)
response = self.client.chat.completions.create(**completion_params) response = self.client.chat.completions.create(**completion_params)
@@ -24,9 +24,9 @@ class AISuiteLLM(BaseLLM):
def _prepare_completion_params( def _prepare_completion_params(
self, self,
messages: Union[str, List[Dict[str, str]]], messages: str | list[dict[str, str]],
tools: Optional[List[dict]] = None, tools: list[dict] | None = None,
) -> Dict[str, Any]: ) -> dict[str, Any]:
return { return {
"model": self.model, "model": self.model,
"messages": messages, "messages": messages,


@@ -1,4 +1,4 @@
from typing import Any, Dict, Optional from typing import Any
from crewai.memory import ( from crewai.memory import (
EntityMemory, EntityMemory,
@@ -12,13 +12,13 @@ from crewai.memory import (
class ContextualMemory: class ContextualMemory:
def __init__( def __init__(
self, self,
memory_config: Optional[Dict[str, Any]], memory_config: dict[str, Any] | None,
stm: ShortTermMemory, stm: ShortTermMemory,
ltm: LongTermMemory, ltm: LongTermMemory,
em: EntityMemory, em: EntityMemory,
um: UserMemory, um: UserMemory,
exm: ExternalMemory, exm: ExternalMemory,
): ) -> None:
if memory_config is not None: if memory_config is not None:
self.memory_provider = memory_config.get("provider") self.memory_provider = memory_config.get("provider")
else: else:
@@ -30,8 +30,7 @@ class ContextualMemory:
self.exm = exm self.exm = exm
def build_context_for_task(self, task, context) -> str: def build_context_for_task(self, task, context) -> str:
""" """Automatically builds a minimal, highly relevant set of contextual information
Automatically builds a minimal, highly relevant set of contextual information
for a given task. for a given task.
""" """
query = f"{task.description} {context}".strip() query = f"{task.description} {context}".strip()
@@ -49,11 +48,9 @@ class ContextualMemory:
return "\n".join(filter(None, context)) return "\n".join(filter(None, context))
def _fetch_stm_context(self, query) -> str: def _fetch_stm_context(self, query) -> str:
""" """Fetches recent relevant insights from STM related to the task's description and expected_output,
Fetches recent relevant insights from STM related to the task's description and expected_output,
formatted as bullet points. formatted as bullet points.
""" """
if self.stm is None: if self.stm is None:
return "" return ""
@@ -62,16 +59,14 @@ class ContextualMemory:
[ [
f"- {result['memory'] if self.memory_provider == 'mem0' else result['context']}" f"- {result['memory'] if self.memory_provider == 'mem0' else result['context']}"
for result in stm_results for result in stm_results
] ],
) )
return f"Recent Insights:\n{formatted_results}" if stm_results else "" return f"Recent Insights:\n{formatted_results}" if stm_results else ""
def _fetch_ltm_context(self, task) -> Optional[str]: def _fetch_ltm_context(self, task) -> str | None:
""" """Fetches historical data or insights from LTM that are relevant to the task's description and expected_output,
Fetches historical data or insights from LTM that are relevant to the task's description and expected_output,
formatted as bullet points. formatted as bullet points.
""" """
if self.ltm is None: if self.ltm is None:
return "" return ""
@@ -90,8 +85,7 @@ class ContextualMemory:
return f"Historical Data:\n{formatted_results}" if ltm_results else "" return f"Historical Data:\n{formatted_results}" if ltm_results else ""
def _fetch_entity_context(self, query) -> str: def _fetch_entity_context(self, query) -> str:
""" """Fetches relevant entity information from Entity Memory related to the task's description and expected_output,
Fetches relevant entity information from Entity Memory related to the task's description and expected_output,
formatted as bullet points. formatted as bullet points.
""" """
if self.em is None: if self.em is None:
@@ -102,19 +96,20 @@ class ContextualMemory:
[ [
f"- {result['memory'] if self.memory_provider == 'mem0' else result['context']}" f"- {result['memory'] if self.memory_provider == 'mem0' else result['context']}"
for result in em_results for result in em_results
] # type: ignore # Invalid index type "str" for "str"; expected type "SupportsIndex | slice" ], # type: ignore # Invalid index type "str" for "str"; expected type "SupportsIndex | slice"
) )
return f"Entities:\n{formatted_results}" if em_results else "" return f"Entities:\n{formatted_results}" if em_results else ""
def _fetch_user_context(self, query: str) -> str: def _fetch_user_context(self, query: str) -> str:
""" """Fetches and formats relevant user information from User Memory.
Fetches and formats relevant user information from User Memory.
Args: Args:
query (str): The search query to find relevant user memories. query (str): The search query to find relevant user memories.
Returns: Returns:
str: Formatted user memories as bullet points, or an empty string if none found. str: Formatted user memories as bullet points, or an empty string if none found.
"""
"""
if self.um is None: if self.um is None:
return "" return ""
@@ -128,12 +123,14 @@ class ContextualMemory:
return f"User memories/preferences:\n{formatted_memories}" return f"User memories/preferences:\n{formatted_memories}"
def _fetch_external_context(self, query: str) -> str: def _fetch_external_context(self, query: str) -> str:
""" """Fetches and formats relevant information from External Memory.
Fetches and formats relevant information from External Memory.
Args: Args:
query (str): The search query to find relevant information. query (str): The search query to find relevant information.
Returns: Returns:
str: Formatted information as bullet points, or an empty string if none found. str: Formatted information as bullet points, or an empty string if none found.
""" """
if self.exm is None: if self.exm is None:
return "" return ""


@@ -1,4 +1,3 @@
from typing import Optional
from pydantic import PrivateAttr from pydantic import PrivateAttr
@@ -8,15 +7,14 @@ from crewai.memory.storage.rag_storage import RAGStorage
class EntityMemory(Memory): class EntityMemory(Memory):
""" """EntityMemory class for managing structured information about entities
EntityMemory class for managing structured information about entities
and their relationships using SQLite storage. and their relationships using SQLite storage.
Inherits from the Memory class. Inherits from the Memory class.
""" """
_memory_provider: Optional[str] = PrivateAttr() _memory_provider: str | None = PrivateAttr()
def __init__(self, crew=None, embedder_config=None, storage=None, path=None): def __init__(self, crew=None, embedder_config=None, storage=None, path=None) -> None:
if crew and hasattr(crew, "memory_config") and crew.memory_config is not None: if crew and hasattr(crew, "memory_config") and crew.memory_config is not None:
memory_provider = crew.memory_config.get("provider") memory_provider = crew.memory_config.get("provider")
else: else:
@@ -26,8 +24,9 @@ class EntityMemory(Memory):
try: try:
from crewai.memory.storage.mem0_storage import Mem0Storage from crewai.memory.storage.mem0_storage import Mem0Storage
except ImportError: except ImportError:
msg = "Mem0 is not installed. Please install it with `pip install mem0ai`."
raise ImportError( raise ImportError(
"Mem0 is not installed. Please install it with `pip install mem0ai`." msg,
) )
storage = Mem0Storage(type="entities", crew=crew) storage = Mem0Storage(type="entities", crew=crew)
else: else:
@@ -63,4 +62,5 @@ class EntityMemory(Memory):
try: try:
self.storage.reset() self.storage.reset()
except Exception as e: except Exception as e:
raise Exception(f"An error occurred while resetting the entity memory: {e}") msg = f"An error occurred while resetting the entity memory: {e}"
raise Exception(msg)
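[Editor's note] Both changes in this file follow the same EM101 pattern, here around an optional dependency. The import guard in isolation, as a sketch mirroring the code above:

    def load_mem0_storage():
        try:
            from crewai.memory.storage.mem0_storage import Mem0Storage
        except ImportError:
            msg = "Mem0 is not installed. Please install it with `pip install mem0ai`."
            raise ImportError(msg)
        return Mem0Storage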


@@ -5,7 +5,7 @@ class EntityMemoryItem:
type: str, type: str,
description: str, description: str,
relationships: str, relationships: str,
): ) -> None:
self.name = name self.name = name
self.type = type self.type = type
self.description = description self.description = description


@@ -1,4 +1,4 @@
from typing import TYPE_CHECKING, Any, Dict, Optional from typing import TYPE_CHECKING, Any
from crewai.memory.external.external_memory_item import ExternalMemoryItem from crewai.memory.external.external_memory_item import ExternalMemoryItem
from crewai.memory.memory import Memory from crewai.memory.memory import Memory
@@ -9,41 +9,44 @@ if TYPE_CHECKING:
class ExternalMemory(Memory): class ExternalMemory(Memory):
def __init__(self, storage: Optional[Storage] = None, **data: Any): def __init__(self, storage: Storage | None = None, **data: Any) -> None:
super().__init__(storage=storage, **data) super().__init__(storage=storage, **data)
@staticmethod @staticmethod
def _configure_mem0(crew: Any, config: Dict[str, Any]) -> "Mem0Storage": def _configure_mem0(crew: Any, config: dict[str, Any]) -> "Mem0Storage":
from crewai.memory.storage.mem0_storage import Mem0Storage from crewai.memory.storage.mem0_storage import Mem0Storage
return Mem0Storage(type="external", crew=crew, config=config) return Mem0Storage(type="external", crew=crew, config=config)
@staticmethod @staticmethod
def external_supported_storages() -> Dict[str, Any]: def external_supported_storages() -> dict[str, Any]:
return { return {
"mem0": ExternalMemory._configure_mem0, "mem0": ExternalMemory._configure_mem0,
} }
@staticmethod @staticmethod
def create_storage(crew: Any, embedder_config: Optional[Dict[str, Any]]) -> Storage: def create_storage(crew: Any, embedder_config: dict[str, Any] | None) -> Storage:
if not embedder_config: if not embedder_config:
raise ValueError("embedder_config is required") msg = "embedder_config is required"
raise ValueError(msg)
if "provider" not in embedder_config: if "provider" not in embedder_config:
raise ValueError("embedder_config must include a 'provider' key") msg = "embedder_config must include a 'provider' key"
raise ValueError(msg)
provider = embedder_config["provider"] provider = embedder_config["provider"]
supported_storages = ExternalMemory.external_supported_storages() supported_storages = ExternalMemory.external_supported_storages()
if provider not in supported_storages: if provider not in supported_storages:
raise ValueError(f"Provider {provider} not supported") msg = f"Provider {provider} not supported"
raise ValueError(msg)
return supported_storages[provider](crew, embedder_config.get("config", {})) return supported_storages[provider](crew, embedder_config.get("config", {}))
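[Editor's note] create_storage above is a small registry dispatch: provider names map to factory callables, so supporting a new backend is one dict entry. A generic sketch (all names illustrative), using the same msg-before-raise style the lint introduced:

    from collections.abc import Callable
    from typing import Any

    def _configure_demo(crew: Any, config: dict[str, Any]) -> str:
        return f"demo-storage({config})"

    SUPPORTED: dict[str, Callable[..., Any]] = {"demo": _configure_demo}

    def create_storage(crew: Any, embedder_config: dict[str, Any] | None) -> Any:
        if not embedder_config or "provider" not in embedder_config:
            msg = "embedder_config must include a 'provider' key"
            raise ValueError(msg)
        provider = embedder_config["provider"]
        if provider not in SUPPORTED:
            msg = f"Provider {provider} not supported"
            raise ValueError(msg)
        return SUPPORTED[provider](crew, embedder_config.get("config", {}))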
def save( def save(
self, self,
value: Any, value: Any,
metadata: Optional[Dict[str, Any]] = None, metadata: dict[str, Any] | None = None,
agent: Optional[str] = None, agent: str | None = None,
) -> None: ) -> None:
"""Saves a value into the external storage.""" """Saves a value into the external storage."""
item = ExternalMemoryItem(value=value, metadata=metadata, agent=agent) item = ExternalMemoryItem(value=value, metadata=metadata, agent=agent)
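Most of the churn in this file is typing modernization: `Optional[X]` becomes `X | None` (PEP 604) and `Dict`/`List` become the built-in generics `dict`/`list` (PEP 585), so the `typing` import shrinks to `Any`. Presumably ruff's UP006/UP007-style fixes; the rule names are an assumption. A self-contained sketch:

from typing import Any

# Before (pre-3.10 style):
#   def create_storage(crew: Any, embedder_config: Optional[Dict[str, Any]]) -> Any: ...
def create_storage(crew: Any, embedder_config: dict[str, Any] | None) -> Any:
    # `dict[str, Any] | None` reads the same but needs no typing.Optional/Dict imports.
    return (crew, embedder_config or {})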
View File
@@ -1,13 +1,13 @@
-from typing import Any, Dict, Optional
+from typing import Any

 class ExternalMemoryItem:
     def __init__(
         self,
         value: Any,
-        metadata: Optional[Dict[str, Any]] = None,
-        agent: Optional[str] = None,
-    ):
+        metadata: dict[str, Any] | None = None,
+        agent: str | None = None,
+    ) -> None:
         self.value = value
         self.metadata = metadata
         self.agent = agent
View File
@@ -1,4 +1,4 @@
-from typing import Any, Dict, List
+from typing import Any

 from crewai.memory.long_term.long_term_memory_item import LongTermMemoryItem
 from crewai.memory.memory import Memory
@@ -6,15 +6,14 @@ from crewai.memory.storage.ltm_sqlite_storage import LTMSQLiteStorage

 class LongTermMemory(Memory):
-    """
-    LongTermMemory class for managing cross runs data related to overall crew's
+    """LongTermMemory class for managing cross runs data related to overall crew's
     execution and performance.
     Inherits from the Memory class and utilizes an instance of a class that
     adheres to the Storage for data storage, specifically working with
     LongTermMemoryItem instances.
     """

-    def __init__(self, storage=None, path=None):
+    def __init__(self, storage=None, path=None) -> None:
         if not storage:
             storage = LTMSQLiteStorage(db_path=path) if path else LTMSQLiteStorage()
         super().__init__(storage=storage)
@@ -29,7 +28,7 @@ class LongTermMemory(Memory):
             datetime=item.datetime,
         )

-    def search(self, task: str, latest_n: int = 3) -> List[Dict[str, Any]]:  # type: ignore # signature of "search" incompatible with supertype "Memory"
+    def search(self, task: str, latest_n: int = 3) -> list[dict[str, Any]]:  # type: ignore # signature of "search" incompatible with supertype "Memory"
         return self.storage.load(task, latest_n)  # type: ignore # BUG?: "Storage" has no attribute "load"

     def reset(self) -> None:
View File
@@ -1,4 +1,4 @@
-from typing import Any, Dict, Optional, Union
+from typing import Any

 class LongTermMemoryItem:
@@ -8,9 +8,9 @@ class LongTermMemoryItem:
         task: str,
         expected_output: str,
         datetime: str,
-        quality: Optional[Union[int, float]] = None,
-        metadata: Optional[Dict[str, Any]] = None,
-    ):
+        quality: float | None = None,
+        metadata: dict[str, Any] | None = None,
+    ) -> None:
         self.task = task
         self.agent = agent
         self.quality = quality
View File
@@ -1,26 +1,24 @@
-from typing import Any, Dict, List, Optional
+from typing import Any

 from pydantic import BaseModel

 class Memory(BaseModel):
-    """
-    Base class for memory, now supporting agent tags and generic metadata.
-    """
+    """Base class for memory, now supporting agent tags and generic metadata."""

-    embedder_config: Optional[Dict[str, Any]] = None
-    crew: Optional[Any] = None
+    embedder_config: dict[str, Any] | None = None
+    crew: Any | None = None
     storage: Any

-    def __init__(self, storage: Any, **data: Any):
+    def __init__(self, storage: Any, **data: Any) -> None:
         super().__init__(storage=storage, **data)

     def save(
         self,
         value: Any,
-        metadata: Optional[Dict[str, Any]] = None,
-        agent: Optional[str] = None,
+        metadata: dict[str, Any] | None = None,
+        agent: str | None = None,
     ) -> None:
         metadata = metadata or {}
         if agent:
@@ -33,9 +31,9 @@ class Memory(BaseModel):
         query: str,
         limit: int = 3,
         score_threshold: float = 0.35,
-    ) -> List[Any]:
+    ) -> list[Any]:
         return self.storage.search(
-            query=query, limit=limit, score_threshold=score_threshold
+            query=query, limit=limit, score_threshold=score_threshold,
         )

     def set_crew(self, crew: Any) -> "Memory":
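Several hunks in this file change nothing but a trailing comma after the last argument of a multi-line call. That is the "magic trailing comma" style ruff's COM812 enforces (the rule attribution is my inference); it keeps a future one-argument addition to a one-line diff:

def search(storage, query: str, limit: int = 3, score_threshold: float = 0.35) -> list:
    # The trailing comma after the last keyword argument is the whole change.
    return storage.search(
        query=query, limit=limit, score_threshold=score_threshold,
    )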
View File
@@ -1,4 +1,4 @@
-from typing import Any, Dict, Optional
+from typing import Any

 from pydantic import PrivateAttr
@@ -8,17 +8,16 @@ from crewai.memory.storage.rag_storage import RAGStorage

 class ShortTermMemory(Memory):
-    """
-    ShortTermMemory class for managing transient data related to immediate tasks
+    """ShortTermMemory class for managing transient data related to immediate tasks
     and interactions.
     Inherits from the Memory class and utilizes an instance of a class that
     adheres to the Storage for data storage, specifically working with
     MemoryItem instances.
     """

-    _memory_provider: Optional[str] = PrivateAttr()
+    _memory_provider: str | None = PrivateAttr()

-    def __init__(self, crew=None, embedder_config=None, storage=None, path=None):
+    def __init__(self, crew=None, embedder_config=None, storage=None, path=None) -> None:
         if crew and hasattr(crew, "memory_config") and crew.memory_config is not None:
             memory_provider = crew.memory_config.get("provider")
         else:
@@ -28,8 +27,9 @@ class ShortTermMemory(Memory):
             try:
                 from crewai.memory.storage.mem0_storage import Mem0Storage
             except ImportError:
+                msg = "Mem0 is not installed. Please install it with `pip install mem0ai`."
                 raise ImportError(
-                    "Mem0 is not installed. Please install it with `pip install mem0ai`."
+                    msg,
                 )
             storage = Mem0Storage(type="short_term", crew=crew)
         else:
@@ -49,8 +49,8 @@ class ShortTermMemory(Memory):
     def save(
         self,
         value: Any,
-        metadata: Optional[Dict[str, Any]] = None,
-        agent: Optional[str] = None,
+        metadata: dict[str, Any] | None = None,
+        agent: str | None = None,
     ) -> None:
         item = ShortTermMemoryItem(data=value, metadata=metadata, agent=agent)
         if self._memory_provider == "mem0":
@@ -65,13 +65,14 @@ class ShortTermMemory(Memory):
         score_threshold: float = 0.35,
     ):
         return self.storage.search(
-            query=query, limit=limit, score_threshold=score_threshold
+            query=query, limit=limit, score_threshold=score_threshold,
         )  # type: ignore # BUG? The reference is to the parent class, but the parent class does not have this parameters

     def reset(self) -> None:
         try:
             self.storage.reset()
         except Exception as e:
+            msg = f"An error occurred while resetting the short-term memory: {e}"
             raise Exception(
-                f"An error occurred while resetting the short-term memory: {e}"
+                msg,
             )
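The docstring rewrite here, repeated across the memory classes, pulls the summary up onto the line that opens the quotes. This matches pydocstyle's "multi-line summary should start at the first line" convention (D212, presumably):

class Before:
    """
    ShortTermMemory class for managing transient data related to immediate tasks
    and interactions.
    """

class After:
    """ShortTermMemory class for managing transient data related to immediate tasks
    and interactions.
    """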
View File
@@ -1,13 +1,13 @@
-from typing import Any, Dict, Optional
+from typing import Any

 class ShortTermMemoryItem:
     def __init__(
         self,
         data: Any,
-        agent: Optional[str] = None,
-        metadata: Optional[Dict[str, Any]] = None,
-    ):
+        agent: str | None = None,
+        metadata: dict[str, Any] | None = None,
+    ) -> None:
         self.data = data
         self.agent = agent
         self.metadata = metadata if metadata is not None else {}
View File
@@ -1,11 +1,9 @@
 from abc import ABC, abstractmethod
-from typing import Any, Dict, List, Optional
+from typing import Any

 class BaseRAGStorage(ABC):
-    """
-    Base class for RAG-based Storage implementations.
-    """
+    """Base class for RAG-based Storage implementations."""

     app: Any | None = None
@@ -13,9 +11,9 @@ class BaseRAGStorage(ABC):
         self,
         type: str,
         allow_reset: bool = True,
-        embedder_config: Optional[Dict[str, Any]] = None,
+        embedder_config: dict[str, Any] | None = None,
         crew: Any = None,
-    ):
+    ) -> None:
         self.type = type
         self.allow_reset = allow_reset
         self.embedder_config = embedder_config
@@ -25,52 +23,44 @@ class BaseRAGStorage(ABC):
     def _initialize_agents(self) -> str:
         if self.crew:
             return "_".join(
-                [self._sanitize_role(agent.role) for agent in self.crew.agents]
+                [self._sanitize_role(agent.role) for agent in self.crew.agents],
             )
         return ""

     @abstractmethod
     def _sanitize_role(self, role: str) -> str:
         """Sanitizes agent roles to ensure valid directory names."""
-        pass

     @abstractmethod
-    def save(self, value: Any, metadata: Dict[str, Any]) -> None:
+    def save(self, value: Any, metadata: dict[str, Any]) -> None:
         """Save a value with metadata to the storage."""
-        pass

     @abstractmethod
     def search(
         self,
         query: str,
         limit: int = 3,
-        filter: Optional[dict] = None,
+        filter: dict | None = None,
         score_threshold: float = 0.35,
-    ) -> List[Any]:
+    ) -> list[Any]:
         """Search for entries in the storage."""
-        pass

     @abstractmethod
     def reset(self) -> None:
         """Reset the storage."""
-        pass

     @abstractmethod
     def _generate_embedding(
-        self, text: str, metadata: Optional[Dict[str, Any]] = None
+        self, text: str, metadata: dict[str, Any] | None = None,
     ) -> Any:
         """Generate an embedding for the given text and metadata."""
-        pass

     @abstractmethod
     def _initialize_app(self):
         """Initialize the vector db."""
-        pass

-    def setup_config(self, config: Dict[str, Any]):
+    def setup_config(self, config: dict[str, Any]) -> None:
         """Setup the config of the storage."""
-        pass

-    def initialize_client(self):
-        """Initialize the client of the storage. This should setup the app and the db collection"""
-        pass
+    def initialize_client(self) -> None:
+        """Initialize the client of the storage. This should setup the app and the db collection."""
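Every abstract method above lost its trailing `pass`: a docstring is already a complete function body, so the extra statement is dead weight (ruff flags this as PIE790, "unnecessary pass", as far as I can tell). The class behaves identically without it:

from abc import ABC, abstractmethod

class Example(ABC):
    @abstractmethod
    def reset(self) -> None:
        """Reset the storage."""  # the docstring alone is a valid body; no `pass` needed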
View File
@@ -1,15 +1,15 @@
-from typing import Any, Dict, List
+from typing import Any

 class Storage:
-    """Abstract base class defining the storage interface"""
+    """Abstract base class defining the storage interface."""

-    def save(self, value: Any, metadata: Dict[str, Any]) -> None:
+    def save(self, value: Any, metadata: dict[str, Any]) -> None:
         pass

     def search(
-        self, query: str, limit: int, score_threshold: float
-    ) -> Dict[str, Any] | List[Any]:
+        self, query: str, limit: int, score_threshold: float,
+    ) -> dict[str, Any] | list[Any]:
         return {}

     def reset(self) -> None:
View File
@@ -2,7 +2,7 @@ import json
 import logging
 import sqlite3
 from pathlib import Path
-from typing import Any, Dict, List, Optional
+from typing import Any

 from crewai.task import Task
 from crewai.utilities import Printer
@@ -14,12 +14,10 @@ logger = logging.getLogger(__name__)

 class KickoffTaskOutputsSQLiteStorage:
-    """
-    An updated SQLite storage class for kickoff task outputs storage.
-    """
+    """An updated SQLite storage class for kickoff task outputs storage."""

     def __init__(
-        self, db_path: Optional[str] = None
+        self, db_path: str | None = None,
     ) -> None:
         if db_path is None:
             # Get the parent directory of the default db path and create our db file there
@@ -37,6 +35,7 @@ class KickoffTaskOutputsSQLiteStorage:
         Raises:
             DatabaseOperationError: If database initialization fails due to SQLite errors.
         """
+
         try:
             with sqlite3.connect(self.db_path) as conn:
@@ -52,22 +51,22 @@ class KickoffTaskOutputsSQLiteStorage:
                         was_replayed BOOLEAN,
                         timestamp DATETIME DEFAULT CURRENT_TIMESTAMP
                     )
-                """
+                """,
                 )
                 conn.commit()
         except sqlite3.Error as e:
             error_msg = DatabaseError.format_error(DatabaseError.INIT_ERROR, e)
-            logger.error(error_msg)
+            logger.exception(error_msg)
             raise DatabaseOperationError(error_msg, e)

     def add(
         self,
         task: Task,
-        output: Dict[str, Any],
+        output: dict[str, Any],
         task_index: int,
         was_replayed: bool = False,
-        inputs: Dict[str, Any] = {},
+        inputs: dict[str, Any] | None = None,
     ) -> None:
         """Add a new task output record to the database.
@@ -80,7 +79,10 @@ class KickoffTaskOutputsSQLiteStorage:
         Raises:
             DatabaseOperationError: If saving the task output fails due to SQLite errors.
         """
+        if inputs is None:
+            inputs = {}
+
         try:
             with sqlite3.connect(self.db_path) as conn:
                 conn.execute("BEGIN TRANSACTION")
@@ -103,7 +105,7 @@ class KickoffTaskOutputsSQLiteStorage:
                 conn.commit()
         except sqlite3.Error as e:
             error_msg = DatabaseError.format_error(DatabaseError.SAVE_ERROR, e)
-            logger.error(error_msg)
+            logger.exception(error_msg)
             raise DatabaseOperationError(error_msg, e)

     def update(
@@ -123,6 +125,7 @@ class KickoffTaskOutputsSQLiteStorage:
         Raises:
             DatabaseOperationError: If updating the task output fails due to SQLite errors.
         """
+
         try:
             with sqlite3.connect(self.db_path) as conn:
@@ -136,7 +139,7 @@ class KickoffTaskOutputsSQLiteStorage:
                     values.append(
                         json.dumps(value, cls=CrewJSONEncoder)
                         if isinstance(value, dict)
-                        else value
+                        else value,
                     )
                 query = f"UPDATE latest_kickoff_task_outputs SET {', '.join(fields)} WHERE task_index = ?"  # nosec
@@ -149,10 +152,10 @@ class KickoffTaskOutputsSQLiteStorage:
                 logger.warning(f"No row found with task_index {task_index}. No update performed.")
         except sqlite3.Error as e:
             error_msg = DatabaseError.format_error(DatabaseError.UPDATE_ERROR, e)
-            logger.error(error_msg)
+            logger.exception(error_msg)
             raise DatabaseOperationError(error_msg, e)

-    def load(self) -> List[Dict[str, Any]]:
+    def load(self) -> list[dict[str, Any]]:
         """Load all task output records from the database.

         Returns:
@@ -162,6 +165,7 @@ class KickoffTaskOutputsSQLiteStorage:
         Raises:
             DatabaseOperationError: If loading task outputs fails due to SQLite errors.
         """
+
         try:
             with sqlite3.connect(self.db_path) as conn:
@@ -190,7 +194,7 @@ class KickoffTaskOutputsSQLiteStorage:
         except sqlite3.Error as e:
             error_msg = DatabaseError.format_error(DatabaseError.LOAD_ERROR, e)
-            logger.error(error_msg)
+            logger.exception(error_msg)
             raise DatabaseOperationError(error_msg, e)

     def delete_all(self) -> None:
@@ -201,6 +205,7 @@ class KickoffTaskOutputsSQLiteStorage:
         Raises:
             DatabaseOperationError: If deleting task outputs fails due to SQLite errors.
         """
+
         try:
             with sqlite3.connect(self.db_path) as conn:
@@ -210,5 +215,5 @@ class KickoffTaskOutputsSQLiteStorage:
             conn.commit()
         except sqlite3.Error as e:
             error_msg = DatabaseError.format_error(DatabaseError.DELETE_ERROR, e)
-            logger.error(error_msg)
+            logger.exception(error_msg)
             raise DatabaseOperationError(error_msg, e)
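Two substantive fixes sit among the cosmetic ones in this file. First, `inputs: Dict[str, Any] = {}` was a real bug risk: default values are evaluated once at definition time, so a single dict would be shared and mutated across calls (flake8-bugbear B006 territory, by the look of it). Second, `logger.error` inside `except` blocks becomes `logger.exception`, which logs at the same level but attaches the active traceback. A sketch of the default-argument fix:

from typing import Any

def add(output: dict[str, Any], inputs: dict[str, Any] | None = None) -> dict[str, Any]:
    # A `{}` default would be created once and shared between calls;
    # the None sentinel gives each call a fresh dict.
    if inputs is None:
        inputs = {}
    inputs["last_output"] = output
    return inputs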
View File
@@ -1,19 +1,17 @@
 import json
 import sqlite3
 from pathlib import Path
-from typing import Any, Dict, List, Optional, Union
+from typing import Any

 from crewai.utilities import Printer
 from crewai.utilities.paths import db_storage_path

 class LTMSQLiteStorage:
-    """
-    An updated SQLite storage class for LTM data storage.
-    """
+    """An updated SQLite storage class for LTM data storage."""

     def __init__(
-        self, db_path: Optional[str] = None
+        self, db_path: str | None = None,
     ) -> None:
         if db_path is None:
             # Get the parent directory of the default db path and create our db file there
@@ -24,10 +22,8 @@ class LTMSQLiteStorage:
         Path(self.db_path).parent.mkdir(parents=True, exist_ok=True)
         self._initialize_db()

-    def _initialize_db(self):
-        """
-        Initializes the SQLite database and creates LTM table
-        """
+    def _initialize_db(self) -> None:
+        """Initializes the SQLite database and creates LTM table."""
         try:
             with sqlite3.connect(self.db_path) as conn:
                 cursor = conn.cursor()
@@ -40,7 +36,7 @@ class LTMSQLiteStorage:
                         datetime TEXT,
                         score REAL
                     )
-                """
+                """,
                 )
                 conn.commit()
@@ -53,9 +49,9 @@ class LTMSQLiteStorage:
     def save(
         self,
         task_description: str,
-        metadata: Dict[str, Any],
+        metadata: dict[str, Any],
         datetime: str,
-        score: Union[int, float],
+        score: float,
     ) -> None:
         """Saves data to the LTM table with error handling."""
         try:
@@ -76,8 +72,8 @@ class LTMSQLiteStorage:
             )

     def load(
-        self, task_description: str, latest_n: int
-    ) -> Optional[List[Dict[str, Any]]]:
+        self, task_description: str, latest_n: int,
+    ) -> list[dict[str, Any]] | None:
         """Queries the LTM table by task description with error handling."""
         try:
             with sqlite3.connect(self.db_path) as conn:
@@ -125,4 +121,3 @@ class LTMSQLiteStorage:
                 content=f"MEMORY ERROR: An error occurred while deleting all rows in LTM: {e}",
                 color="red",
             )
-        return None
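The final hunk above simply drops a trailing `return None`: a function that reaches the end of its body returns `None` anyway, so the explicit statement after the try/except adds nothing (ruff's "useless return" check, PLR1711-style, by my reading):

def load(rows: list | None, latest_n: int) -> list | None:
    if rows:
        return rows[:latest_n]
    # Falling off the end returns None implicitly; no closing `return None` needed.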
View File
@@ -1,5 +1,5 @@
 import os
-from typing import Any, Dict, List
+from typing import Any

 from mem0 import Memory, MemoryClient
@@ -7,17 +7,15 @@ from crewai.memory.storage.interface import Storage

 class Mem0Storage(Storage):
-    """
-    Extends Storage to handle embedding and searching across entities using Mem0.
-    """
+    """Extends Storage to handle embedding and searching across entities using Mem0."""

-    def __init__(self, type, crew=None, config=None):
+    def __init__(self, type, crew=None, config=None) -> None:
         super().__init__()
         supported_types = ["user", "short_term", "long_term", "entities", "external"]
         if type not in supported_types:
             raise ValueError(
                 f"Invalid type '{type}' for Mem0Storage. Must be one of: "
-                + ", ".join(supported_types)
+                + ", ".join(supported_types),
             )
         self.memory_type = type
@@ -29,7 +27,8 @@ class Mem0Storage(Storage):
         # User ID is required for user memory type "user" since it's used as a unique identifier for the user.
         user_id = self._get_user_id()
         if type == "user" and not user_id:
-            raise ValueError("User ID is required for user memory type")
+            msg = "User ID is required for user memory type"
+            raise ValueError(msg)

         # API key in memory config overrides the environment variable
         config = self._get_config()
@@ -42,23 +41,20 @@ class Mem0Storage(Storage):
         if mem0_api_key:
             if mem0_org_id and mem0_project_id:
                 self.memory = MemoryClient(
-                    api_key=mem0_api_key, org_id=mem0_org_id, project_id=mem0_project_id
+                    api_key=mem0_api_key, org_id=mem0_org_id, project_id=mem0_project_id,
                 )
             else:
                 self.memory = MemoryClient(api_key=mem0_api_key)
+        elif mem0_local_config and len(mem0_local_config):
+            self.memory = Memory.from_config(mem0_local_config)
         else:
-            if mem0_local_config and len(mem0_local_config):
-                self.memory = Memory.from_config(mem0_local_config)
-            else:
-                self.memory = Memory()
+            self.memory = Memory()

     def _sanitize_role(self, role: str) -> str:
-        """
-        Sanitizes agent roles to ensure valid directory names.
-        """
+        """Sanitizes agent roles to ensure valid directory names."""
         return role.replace("\n", "").replace(" ", "_").replace("/", "_")

-    def save(self, value: Any, metadata: Dict[str, Any]) -> None:
+    def save(self, value: Any, metadata: dict[str, Any]) -> None:
         user_id = self._get_user_id()
         agent_name = self._get_agent_name()
         params = None
@@ -97,7 +93,7 @@ class Mem0Storage(Storage):
         query: str,
         limit: int = 3,
         score_threshold: float = 0.35,
-    ) -> List[Any]:
+    ) -> list[Any]:
         params = {"query": query, "limit": limit, "output_format": "v1.1"}
         if user_id := self._get_user_id():
             params["user_id"] = user_id
@@ -133,12 +129,11 @@ class Mem0Storage(Storage):
         agents = self.crew.agents
         agents = [self._sanitize_role(agent.role) for agent in agents]
-        agents = "_".join(agents)
-        return agents
+        return "_".join(agents)

-    def _get_config(self) -> Dict[str, Any]:
+    def _get_config(self) -> dict[str, Any]:
         return self.config or getattr(self, "memory_config", {}).get("config", {}) or {}

-    def reset(self):
+    def reset(self) -> None:
         if self.memory:
             self.memory.reset()
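The client-selection rewrite above removes one nesting level: an `if`/`else` that lived inside the outer `else:` becomes an `elif` arm of the same chain, with identical behavior. A sketch with stand-in string values instead of the real Mem0 clients:

def pick_memory(api_key: str | None, local_config: dict | None) -> str:
    if api_key:
        memory = "MemoryClient"
    elif local_config and len(local_config):
        # previously: `else:` wrapping a nested `if local_config ... else ...`
        memory = "Memory.from_config"
    else:
        memory = "Memory"
    return memory

The same file also drops the `agents = "_".join(agents)` / `return agents` pair in favor of returning the expression directly.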
View File
@@ -4,7 +4,7 @@ import logging
 import os
 import shutil
 import uuid
-from typing import Any, Dict, List, Optional
+from typing import Any

 from chromadb.api import ClientAPI
@@ -32,16 +32,15 @@ def suppress_logging(

 class RAGStorage(BaseRAGStorage):
-    """
-    Extends Storage to handle embeddings for memory entries, improving
+    """Extends Storage to handle embeddings for memory entries, improving
     search efficiency.
     """

     app: ClientAPI | None = None

     def __init__(
-        self, type, allow_reset=True, embedder_config=None, crew=None, path=None
-    ):
+        self, type, allow_reset=True, embedder_config=None, crew=None, path=None,
+    ) -> None:
         super().__init__(type, allow_reset, embedder_config, crew)
         agents = crew.agents if crew else []
         agents = [self._sanitize_role(agent.role) for agent in agents]
@@ -55,11 +54,11 @@ class RAGStorage(BaseRAGStorage):
         self.path = path
         self._initialize_app()

-    def _set_embedder_config(self):
+    def _set_embedder_config(self) -> None:
         configurator = EmbeddingConfigurator()
         self.embedder_config = configurator.configure_embedder(self.embedder_config)

-    def _initialize_app(self):
+    def _initialize_app(self) -> None:
         import chromadb
         from chromadb.config import Settings
@@ -73,48 +72,44 @@ class RAGStorage(BaseRAGStorage):
         try:
             self.collection = self.app.get_collection(
-                name=self.type, embedding_function=self.embedder_config
+                name=self.type, embedding_function=self.embedder_config,
             )
         except Exception:
             self.collection = self.app.create_collection(
-                name=self.type, embedding_function=self.embedder_config
+                name=self.type, embedding_function=self.embedder_config,
             )

     def _sanitize_role(self, role: str) -> str:
-        """
-        Sanitizes agent roles to ensure valid directory names.
-        """
+        """Sanitizes agent roles to ensure valid directory names."""
         return role.replace("\n", "").replace(" ", "_").replace("/", "_")

     def _build_storage_file_name(self, type: str, file_name: str) -> str:
-        """
-        Ensures file name does not exceed max allowed by OS
-        """
+        """Ensures file name does not exceed max allowed by OS."""
         base_path = f"{db_storage_path()}/{type}"
         if len(file_name) > MAX_FILE_NAME_LENGTH:
             logging.warning(
-                f"Trimming file name from {len(file_name)} to {MAX_FILE_NAME_LENGTH} characters."
+                f"Trimming file name from {len(file_name)} to {MAX_FILE_NAME_LENGTH} characters.",
             )
             file_name = file_name[:MAX_FILE_NAME_LENGTH]
         return f"{base_path}/{file_name}"

-    def save(self, value: Any, metadata: Dict[str, Any]) -> None:
+    def save(self, value: Any, metadata: dict[str, Any]) -> None:
         if not hasattr(self, "app") or not hasattr(self, "collection"):
             self._initialize_app()
         try:
             self._generate_embedding(value, metadata)
         except Exception as e:
-            logging.error(f"Error during {self.type} save: {str(e)}")
+            logging.exception(f"Error during {self.type} save: {e!s}")

     def search(
         self,
         query: str,
         limit: int = 3,
-        filter: Optional[dict] = None,
+        filter: dict | None = None,
         score_threshold: float = 0.35,
-    ) -> List[Any]:
+    ) -> list[Any]:
         if not hasattr(self, "app"):
             self._initialize_app()
@@ -135,10 +130,10 @@ class RAGStorage(BaseRAGStorage):
             return results
         except Exception as e:
-            logging.error(f"Error during {self.type} search: {str(e)}")
+            logging.exception(f"Error during {self.type} search: {e!s}")
             return []

-    def _generate_embedding(self, text: str, metadata: Dict[str, Any]) -> None:  # type: ignore
+    def _generate_embedding(self, text: str, metadata: dict[str, Any]) -> None:  # type: ignore
         if not hasattr(self, "app") or not hasattr(self, "collection"):
             self._initialize_app()
@@ -160,8 +155,9 @@ class RAGStorage(BaseRAGStorage):
                 # Ignore this specific error
                 pass
             else:
+                msg = f"An error occurred while resetting the {self.type} memory: {e}"
                 raise Exception(
-                    f"An error occurred while resetting the {self.type} memory: {e}"
+                    msg,
                 )

     def _create_default_embedding_function(self):
@@ -170,5 +166,5 @@ class RAGStorage(BaseRAGStorage):
         )
         return OpenAIEmbeddingFunction(
-            api_key=os.getenv("OPENAI_API_KEY"), model_name="text-embedding-3-small"
+            api_key=os.getenv("OPENAI_API_KEY"), model_name="text-embedding-3-small",
         )
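Two recurring edits in this file: `logging.error` in an `except` block becomes `logging.exception`, which also records the traceback, and `{str(e)}` becomes the conversion flag `{e!s}` (ruff's TRY400 and RUF010 fixes, presumably; the rule names are my inference):

import logging

def save(generate_embedding) -> None:
    try:
        generate_embedding()
    except Exception as e:
        # Same message, but logging.exception appends the full traceback.
        logging.exception(f"Error during save: {e!s}")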
View File
@@ -1,18 +1,17 @@
 import warnings
-from typing import Any, Dict, Optional
+from typing import Any

 from crewai.memory.memory import Memory

 class UserMemory(Memory):
-    """
-    UserMemory class for handling user memory storage and retrieval.
+    """UserMemory class for handling user memory storage and retrieval.
     Inherits from the Memory class and utilizes an instance of a class that
     adheres to the Storage for data storage, specifically working with
     MemoryItem instances.
     """

-    def __init__(self, crew=None):
+    def __init__(self, crew=None) -> None:
         warnings.warn(
             "UserMemory is deprecated and will be removed in a future version. "
             "Please use ExternalMemory instead.",
@@ -22,8 +21,9 @@ class UserMemory(Memory):
         try:
             from crewai.memory.storage.mem0_storage import Mem0Storage
         except ImportError:
+            msg = "Mem0 is not installed. Please install it with `pip install mem0ai`."
             raise ImportError(
-                "Mem0 is not installed. Please install it with `pip install mem0ai`."
+                msg,
             )
         storage = Mem0Storage(type="user", crew=crew)
         super().__init__(storage)
@@ -31,8 +31,8 @@ class UserMemory(Memory):
     def save(
         self,
         value,
-        metadata: Optional[Dict[str, Any]] = None,
-        agent: Optional[str] = None,
+        metadata: dict[str, Any] | None = None,
+        agent: str | None = None,
     ) -> None:
         # TODO: Change this function since we want to take care of the case where we save memories for the usr
         data = f"Remember the details about the user: {value}"
@@ -44,15 +44,15 @@ class UserMemory(Memory):
         limit: int = 3,
         score_threshold: float = 0.35,
     ):
-        results = self.storage.search(
+        return self.storage.search(
             query=query,
             limit=limit,
             score_threshold=score_threshold,
         )
-        return results

     def reset(self) -> None:
         try:
             self.storage.reset()
         except Exception as e:
-            raise Exception(f"An error occurred while resetting the user memory: {e}")
+            msg = f"An error occurred while resetting the user memory: {e}"
+            raise Exception(msg)
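`search()` above lost its intermediate variable: assigning to `results` only to return it on the next line is the "unnecessary assignment before return" pattern (ruff RET504, by the look of it), so the call is now returned directly:

def search(storage, query: str, limit: int = 3, score_threshold: float = 0.35):
    # Before: results = storage.search(...); return results
    return storage.search(
        query=query,
        limit=limit,
        score_threshold=score_threshold,
    )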
View File
@@ -1,8 +1,8 @@
-from typing import Any, Dict, Optional
+from typing import Any

 class UserMemoryItem:
-    def __init__(self, data: Any, user: str, metadata: Optional[Dict[str, Any]] = None):
+    def __init__(self, data: Any, user: str, metadata: dict[str, Any] | None = None) -> None:
         self.data = data
         self.user = user
         self.metadata = metadata if metadata is not None else {}
View File
@@ -2,9 +2,7 @@ from enum import Enum

 class Process(str, Enum):
-    """
-    Class representing the different processes that can be used to tackle tasks
-    """
+    """Class representing the different processes that can be used to tackle tasks."""

     sequential = "sequential"
     hierarchical = "hierarchical"
View File
@@ -1,5 +1,5 @@
+from collections.abc import Callable
 from functools import wraps
-from typing import Callable

 from crewai import Crew
 from crewai.project.utils import memoize
@@ -36,15 +36,13 @@ def task(func):
 def agent(func):
     """Marks a method as a crew agent."""
     func.is_agent = True
-    func = memoize(func)
-    return func
+    return memoize(func)

 def llm(func):
     """Marks a method as an LLM provider."""
     func.is_llm = True
-    func = memoize(func)
-    return func
+    return memoize(func)

 def output_json(cls):
@@ -91,7 +89,7 @@ def crew(func) -> Callable[..., Crew]:
         agents = self._original_agents.items()

         # Instantiate tasks in order
-        for task_name, task_method in tasks:
+        for _task_name, task_method in tasks:
             task_instance = task_method(self)
             instantiated_tasks.append(task_instance)
             agent_instance = getattr(task_instance, "agent", None)
@@ -100,7 +98,7 @@ def crew(func) -> Callable[..., Crew]:
                 agent_roles.add(agent_instance.role)

         # Instantiate agents not included by tasks
-        for agent_name, agent_method in agents:
+        for _agent_name, agent_method in agents:
             agent_instance = agent_method(self)
             if agent_instance.role not in agent_roles:
                 instantiated_agents.append(agent_instance)
@@ -117,9 +115,9 @@ def crew(func) -> Callable[..., Crew]:
             return wrapper

-        for _, callback in self._before_kickoff.items():
+        for callback in self._before_kickoff.values():
             crew.before_kickoff_callbacks.append(callback_wrapper(callback, self))
-        for _, callback in self._after_kickoff.items():
+        for callback in self._after_kickoff.values():
             crew.after_kickoff_callbacks.append(callback_wrapper(callback, self))

         return crew
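Three small loop cleanups in this file: loop variables that are never read gain an underscore prefix (`_task_name`, `_agent_name`), iterating `.items()` while discarding the key becomes `.values()`, and the `func = memoize(func); return func` pairs collapse into `return memoize(func)`. These look like B007/PERF102-style fixes, as far as I can tell:

def register_callbacks(callbacks: dict[str, object], registry: list) -> None:
    # Before: for _, callback in callbacks.items():
    for callback in callbacks.values():
        registry.append(callback)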
View File
@@ -1,7 +1,8 @@
 import inspect
 import logging
+from collections.abc import Callable
 from pathlib import Path
-from typing import Any, Callable, Dict, TypeVar, cast
+from typing import Any, TypeVar, cast

 import yaml
 from dotenv import load_dotenv
@@ -23,11 +24,11 @@ def CrewBase(cls: T) -> T:
         base_directory = Path(inspect.getfile(cls)).parent
         original_agents_config_path = getattr(
-            cls, "agents_config", "config/agents.yaml"
+            cls, "agents_config", "config/agents.yaml",
         )
         original_tasks_config_path = getattr(cls, "tasks_config", "config/tasks.yaml")

-        def __init__(self, *args, **kwargs):
+        def __init__(self, *args, **kwargs) -> None:
             super().__init__(*args, **kwargs)
             self.load_configurations()
             self.map_all_agent_variables()
@@ -49,22 +50,22 @@ def CrewBase(cls: T) -> T:
             }

             # Store specific function types
             self._original_tasks = self._filter_functions(
-                self._original_functions, "is_task"
+                self._original_functions, "is_task",
             )
             self._original_agents = self._filter_functions(
-                self._original_functions, "is_agent"
+                self._original_functions, "is_agent",
             )
             self._before_kickoff = self._filter_functions(
-                self._original_functions, "is_before_kickoff"
+                self._original_functions, "is_before_kickoff",
             )
             self._after_kickoff = self._filter_functions(
-                self._original_functions, "is_after_kickoff"
+                self._original_functions, "is_after_kickoff",
             )
             self._kickoff = self._filter_functions(
-                self._original_functions, "is_kickoff"
+                self._original_functions, "is_kickoff",
             )

-        def load_configurations(self):
+        def load_configurations(self) -> None:
             """Load agent and task configurations from YAML files."""
             if isinstance(self.original_agents_config_path, str):
                 agents_config_path = (
@@ -75,12 +76,12 @@ def CrewBase(cls: T) -> T:
                 except FileNotFoundError:
                     logging.warning(
                         f"Agent config file not found at {agents_config_path}. "
-                        "Proceeding with empty agent configurations."
+                        "Proceeding with empty agent configurations.",
                     )
                     self.agents_config = {}
             else:
                 logging.warning(
-                    "No agent configuration path provided. Proceeding with empty agent configurations."
+                    "No agent configuration path provided. Proceeding with empty agent configurations.",
                 )
                 self.agents_config = {}
@@ -93,22 +94,21 @@ def CrewBase(cls: T) -> T:
                 except FileNotFoundError:
                     logging.warning(
                         f"Task config file not found at {tasks_config_path}. "
-                        "Proceeding with empty task configurations."
+                        "Proceeding with empty task configurations.",
                     )
                     self.tasks_config = {}
             else:
                 logging.warning(
-                    "No task configuration path provided. Proceeding with empty task configurations."
+                    "No task configuration path provided. Proceeding with empty task configurations.",
                 )
                 self.tasks_config = {}

         @staticmethod
         def load_yaml(config_path: Path):
             try:
-                with open(config_path, "r", encoding="utf-8") as file:
+                with open(config_path, encoding="utf-8") as file:
                     return yaml.safe_load(file)
             except FileNotFoundError:
-                print(f"File not found: {config_path}")
                 raise

         def _get_all_functions(self):
@@ -119,8 +119,8 @@ def CrewBase(cls: T) -> T:
             }

         def _filter_functions(
-            self, functions: Dict[str, Callable], attribute: str
-        ) -> Dict[str, Callable]:
+            self, functions: dict[str, Callable], attribute: str,
+        ) -> dict[str, Callable]:
             return {
                 name: func
                 for name, func in functions.items()
@@ -132,7 +132,7 @@ def CrewBase(cls: T) -> T:
             llms = self._filter_functions(all_functions, "is_llm")
             tool_functions = self._filter_functions(all_functions, "is_tool")
             cache_handler_functions = self._filter_functions(
-                all_functions, "is_cache_handler"
+                all_functions, "is_cache_handler",
             )
             callbacks = self._filter_functions(all_functions, "is_callback")
@@ -149,11 +149,11 @@ def CrewBase(cls: T) -> T:
         def _map_agent_variables(
             self,
             agent_name: str,
-            agent_info: Dict[str, Any],
-            llms: Dict[str, Callable],
-            tool_functions: Dict[str, Callable],
-            cache_handler_functions: Dict[str, Callable],
-            callbacks: Dict[str, Callable],
+            agent_info: dict[str, Any],
+            llms: dict[str, Callable],
+            tool_functions: dict[str, Callable],
+            cache_handler_functions: dict[str, Callable],
+            callbacks: dict[str, Callable],
         ) -> None:
             if llm := agent_info.get("llm"):
                 try:
@@ -187,12 +187,12 @@ def CrewBase(cls: T) -> T:
             agents = self._filter_functions(all_functions, "is_agent")
             tasks = self._filter_functions(all_functions, "is_task")
             output_json_functions = self._filter_functions(
-                all_functions, "is_output_json"
+                all_functions, "is_output_json",
             )
             tool_functions = self._filter_functions(all_functions, "is_tool")
             callback_functions = self._filter_functions(all_functions, "is_callback")
             output_pydantic_functions = self._filter_functions(
-                all_functions, "is_output_pydantic"
+                all_functions, "is_output_pydantic",
             )

             for task_name, task_info in self.tasks_config.items():
@@ -210,13 +210,13 @@ def CrewBase(cls: T) -> T:
         def _map_task_variables(
             self,
             task_name: str,
-            task_info: Dict[str, Any],
-            agents: Dict[str, Callable],
-            tasks: Dict[str, Callable],
-            output_json_functions: Dict[str, Callable],
-            tool_functions: Dict[str, Callable],
-            callback_functions: Dict[str, Callable],
-            output_pydantic_functions: Dict[str, Callable],
+            task_info: dict[str, Any],
+            agents: dict[str, Callable],
+            tasks: dict[str, Callable],
+            output_json_functions: dict[str, Callable],
+            tool_functions: dict[str, Callable],
+            callback_functions: dict[str, Callable],
+            output_pydantic_functions: dict[str, Callable],
         ) -> None:
             if context_list := task_info.get("context"):
                 self.tasks_config[task_name]["context"] = [
@@ -253,4 +253,4 @@ def CrewBase(cls: T) -> T:
     WrappedClass.__name__ = CrewBase.__name__ + "(" + cls.__name__ + ")"
     WrappedClass.__qualname__ = CrewBase.__qualname__ + "(" + cls.__name__ + ")"

-    return cast(T, WrappedClass)
+    return cast("T", WrappedClass)
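Two notable fixes at the edges of this file: `open(config_path, "r", ...)` drops the redundant `"r"` (reading is already the default mode, UP015-style), and `cast(T, WrappedClass)` becomes `cast("T", WrappedClass)`, quoting the TypeVar so the reference works under deferred annotation evaluation (TC006, presumably; both rule names are my inference). The removed `print` before `raise` also leaves the exception itself to carry the information. Sketch of the `open` change:

import yaml
from pathlib import Path

def load_yaml(config_path: Path):
    # "r" is the default mode for open(), so the fixer removes it.
    with open(config_path, encoding="utf-8") as file:
        return yaml.safe_load(file)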
View File
@@ -1,5 +1,4 @@
-"""
-Fingerprint Module
+"""Fingerprint Module.

 This module provides functionality for generating and validating unique identifiers
 for CrewAI agents. These identifiers are used for tracking, auditing, and security.
@@ -7,14 +6,13 @@ for CrewAI agents. These identifiers are used for tracking, auditing, and securi
 import uuid
 from datetime import datetime
-from typing import Any, Dict, Optional
+from typing import Any

 from pydantic import BaseModel, ConfigDict, Field, field_validator

 class Fingerprint(BaseModel):
-    """
-    A class for generating and managing unique identifiers for agents.
+    """A class for generating and managing unique identifiers for agents.

     Each agent has dual identifiers:
     - Human-readable ID: For debugging and reference (derived from role if not specified)
@@ -24,48 +22,54 @@ class Fingerprint(BaseModel):
         uuid_str (str): String representation of the UUID for this fingerprint, auto-generated
         created_at (datetime): When this fingerprint was created, auto-generated
         metadata (Dict[str, Any]): Additional metadata associated with this fingerprint
     """

     uuid_str: str = Field(default_factory=lambda: str(uuid.uuid4()), description="String representation of the UUID")
     created_at: datetime = Field(default_factory=datetime.now, description="When this fingerprint was created")
-    metadata: Dict[str, Any] = Field(default_factory=dict, description="Additional metadata for this fingerprint")
+    metadata: dict[str, Any] = Field(default_factory=dict, description="Additional metadata for this fingerprint")

     model_config = ConfigDict(arbitrary_types_allowed=True)

-    @field_validator('metadata')
+    @field_validator("metadata")
     @classmethod
     def validate_metadata(cls, v):
         """Validate that metadata is a dictionary with string keys and valid values."""
         if not isinstance(v, dict):
-            raise ValueError("Metadata must be a dictionary")
+            msg = "Metadata must be a dictionary"
+            raise ValueError(msg)

         # Validate that all keys are strings
         for key, value in v.items():
             if not isinstance(key, str):
-                raise ValueError(f"Metadata keys must be strings, got {type(key)}")
+                msg = f"Metadata keys must be strings, got {type(key)}"
+                raise ValueError(msg)

             # Validate nested dictionaries (prevent deeply nested structures)
             if isinstance(value, dict):
                 # Check for nested dictionaries (limit depth to 1)
                 for nested_key, nested_value in value.items():
                     if not isinstance(nested_key, str):
-                        raise ValueError(f"Nested metadata keys must be strings, got {type(nested_key)}")
+                        msg = f"Nested metadata keys must be strings, got {type(nested_key)}"
+                        raise ValueError(msg)
                     if isinstance(nested_value, dict):
-                        raise ValueError("Metadata can only be nested one level deep")
+                        msg = "Metadata can only be nested one level deep"
+                        raise ValueError(msg)

         # Check for maximum metadata size (prevent DoS)
         if len(str(v)) > 10000:  # Limit metadata size to 10KB
-            raise ValueError("Metadata size exceeds maximum allowed (10KB)")
+            msg = "Metadata size exceeds maximum allowed (10KB)"
+            raise ValueError(msg)

         return v

-    def __init__(self, **data):
+    def __init__(self, **data) -> None:
         """Initialize a Fingerprint with auto-generated uuid_str and created_at."""
         # Remove uuid_str and created_at from data to ensure they're auto-generated
-        if 'uuid_str' in data:
-            data.pop('uuid_str')
-        if 'created_at' in data:
-            data.pop('created_at')
+        if "uuid_str" in data:
+            data.pop("uuid_str")
+        if "created_at" in data:
+            data.pop("created_at")

         # Call the parent constructor with the modified data
         super().__init__(**data)
@@ -77,32 +81,33 @@ class Fingerprint(BaseModel):
     @classmethod
     def _generate_uuid(cls, seed: str) -> str:
-        """
-        Generate a deterministic UUID based on a seed string.
+        """Generate a deterministic UUID based on a seed string.

         Args:
             seed (str): The seed string to use for UUID generation

         Returns:
             str: A string representation of the UUID consistently generated from the seed
         """
         if not isinstance(seed, str):
-            raise ValueError("Seed must be a string")
+            msg = "Seed must be a string"
+            raise ValueError(msg)

         if not seed.strip():
-            raise ValueError("Seed cannot be empty or whitespace")
+            msg = "Seed cannot be empty or whitespace"
+            raise ValueError(msg)

         # Create a deterministic UUID using v5 (SHA-1)
         # Custom namespace for CrewAI to enhance security
         # Using a unique namespace specific to CrewAI to reduce collision risks
-        CREW_AI_NAMESPACE = uuid.UUID('f47ac10b-58cc-4372-a567-0e02b2c3d479')
+        CREW_AI_NAMESPACE = uuid.UUID("f47ac10b-58cc-4372-a567-0e02b2c3d479")

         return str(uuid.uuid5(CREW_AI_NAMESPACE, seed))

     @classmethod
-    def generate(cls, seed: Optional[str] = None, metadata: Optional[Dict[str, Any]] = None) -> 'Fingerprint':
-        """
-        Static factory method to create a new Fingerprint.
+    def generate(cls, seed: str | None = None, metadata: dict[str, Any] | None = None) -> "Fingerprint":
+        """Static factory method to create a new Fingerprint.

         Args:
             seed (Optional[str]): A string to use as seed for the UUID generation.
@@ -111,11 +116,12 @@ class Fingerprint(BaseModel):
         Returns:
             Fingerprint: A new Fingerprint instance
         """
         fingerprint = cls(metadata=metadata or {})
         if seed:
             # For seed-based generation, we need to manually set the uuid_str after creation
-            object.__setattr__(fingerprint, 'uuid_str', cls._generate_uuid(seed))
+            object.__setattr__(fingerprint, "uuid_str", cls._generate_uuid(seed))

         return fingerprint

     def __str__(self) -> str:
@@ -132,29 +138,29 @@ class Fingerprint(BaseModel):
         """Hash of the fingerprint (based on UUID)."""
         return hash(self.uuid_str)

-    def to_dict(self) -> Dict[str, Any]:
-        """
-        Convert the fingerprint to a dictionary representation.
+    def to_dict(self) -> dict[str, Any]:
+        """Convert the fingerprint to a dictionary representation.

         Returns:
             Dict[str, Any]: Dictionary representation of the fingerprint
         """
         return {
             "uuid_str": self.uuid_str,
             "created_at": self.created_at.isoformat(),
-            "metadata": self.metadata
+            "metadata": self.metadata,
         }

     @classmethod
-    def from_dict(cls, data: Dict[str, Any]) -> 'Fingerprint':
-        """
-        Create a Fingerprint from a dictionary representation.
+    def from_dict(cls, data: dict[str, Any]) -> "Fingerprint":
+        """Create a Fingerprint from a dictionary representation.

         Args:
             data (Dict[str, Any]): Dictionary representation of a fingerprint

         Returns:
             Fingerprint: A new Fingerprint instance
         """
         if not data:
             return cls()
@@ -163,8 +169,8 @@ class Fingerprint(BaseModel):
         # For consistency with existing stored fingerprints, we need to manually set these
         if "uuid_str" in data:
-            object.__setattr__(fingerprint, 'uuid_str', data["uuid_str"])
+            object.__setattr__(fingerprint, "uuid_str", data["uuid_str"])
         if "created_at" in data and isinstance(data["created_at"], str):
-            object.__setattr__(fingerprint, 'created_at', datetime.fromisoformat(data["created_at"]))
+            object.__setattr__(fingerprint, "created_at", datetime.fromisoformat(data["created_at"]))

         return fingerprint
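Beyond the msg-before-raise rewrites, most of the churn in this file is quote normalization: single-quoted strings become double-quoted (ruff's Q000 default, presumably), with zero behavioral change:

import uuid

# Before: uuid.UUID('f47ac10b-58cc-4372-a567-0e02b2c3d479')
CREW_AI_NAMESPACE = uuid.UUID("f47ac10b-58cc-4372-a567-0e02b2c3d479")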
View File
@@ -1,5 +1,4 @@
""" """Security Configuration Module.
Security Configuration Module
This module provides configuration for CrewAI security features, including: This module provides configuration for CrewAI security features, including:
- Authentication settings - Authentication settings
@@ -10,7 +9,7 @@ The SecurityConfig class is the primary interface for managing security settings
in CrewAI applications. in CrewAI applications.
""" """
from typing import Any, Dict, Optional from typing import Any
from pydantic import BaseModel, ConfigDict, Field, model_validator from pydantic import BaseModel, ConfigDict, Field, model_validator
@@ -18,8 +17,7 @@ from crewai.security.fingerprint import Fingerprint
class SecurityConfig(BaseModel): class SecurityConfig(BaseModel):
""" """Configuration for CrewAI security features.
Configuration for CrewAI security features.
This class manages security settings for CrewAI agents, including: This class manages security settings for CrewAI agents, including:
- Authentication credentials *TODO* - Authentication credentials *TODO*
@@ -30,82 +28,83 @@ class SecurityConfig(BaseModel):
Attributes: Attributes:
version (str): Version of the security configuration version (str): Version of the security configuration
fingerprint (Fingerprint): The unique fingerprint automatically generated for the component fingerprint (Fingerprint): The unique fingerprint automatically generated for the component
""" """
model_config = ConfigDict( model_config = ConfigDict(
arbitrary_types_allowed=True arbitrary_types_allowed=True,
# Note: Cannot use frozen=True as existing tests modify the fingerprint property # Note: Cannot use frozen=True as existing tests modify the fingerprint property
) )
version: str = Field( version: str = Field(
default="1.0.0", default="1.0.0",
description="Version of the security configuration" description="Version of the security configuration",
) )
fingerprint: Fingerprint = Field( fingerprint: Fingerprint = Field(
default_factory=Fingerprint, default_factory=Fingerprint,
description="Unique identifier for the component" description="Unique identifier for the component",
) )
def is_compatible(self, min_version: str) -> bool: def is_compatible(self, min_version: str) -> bool:
""" """Check if this security configuration is compatible with the minimum required version.
Check if this security configuration is compatible with the minimum required version.
Args: Args:
min_version (str): Minimum required version in semver format (e.g., "1.0.0") min_version (str): Minimum required version in semver format (e.g., "1.0.0")
Returns: Returns:
bool: True if this configuration is compatible, False otherwise bool: True if this configuration is compatible, False otherwise
""" """
# Simple version comparison (can be enhanced with packaging.version if needed) # Simple version comparison (can be enhanced with packaging.version if needed)
current = [int(x) for x in self.version.split(".")] current = [int(x) for x in self.version.split(".")]
minimum = [int(x) for x in min_version.split(".")] minimum = [int(x) for x in min_version.split(".")]
# Compare major, minor, patch versions # Compare major, minor, patch versions
for c, m in zip(current, minimum): for c, m in zip(current, minimum, strict=False):
if c > m: if c > m:
return True return True
if c < m: if c < m:
return False return False
return True return True
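
A quick check on the comparison semantics above: the first differing component decides, and full equality falls through to True.

    cfg = SecurityConfig(version="1.2.0")
    assert cfg.is_compatible("1.0.0")         # 2 > 0 at the minor position
    assert cfg.is_compatible("1.2.0")         # equal versions are compatible
    assert not cfg.is_compatible("2.0.0")     # 1 < 2 at the major position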
@model_validator(mode='before') @model_validator(mode="before")
@classmethod @classmethod
def validate_fingerprint(cls, values): def validate_fingerprint(cls, values):
"""Ensure fingerprint is properly initialized.""" """Ensure fingerprint is properly initialized."""
if isinstance(values, dict): if isinstance(values, dict):
# Handle case where fingerprint is not provided or is None # Handle case where fingerprint is not provided or is None
if 'fingerprint' not in values or values['fingerprint'] is None: if "fingerprint" not in values or values["fingerprint"] is None:
values['fingerprint'] = Fingerprint() values["fingerprint"] = Fingerprint()
# Handle case where fingerprint is a string (seed) # Handle case where fingerprint is a string (seed)
elif isinstance(values['fingerprint'], str): elif isinstance(values["fingerprint"], str):
if not values['fingerprint'].strip(): if not values["fingerprint"].strip():
raise ValueError("Fingerprint seed cannot be empty") msg = "Fingerprint seed cannot be empty"
values['fingerprint'] = Fingerprint.generate(seed=values['fingerprint']) raise ValueError(msg)
values["fingerprint"] = Fingerprint.generate(seed=values["fingerprint"])
return values return values
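
The validator gives three construction paths; a small sketch (pydantic's ValidationError subclasses ValueError, so the except clause catches the wrapped error):

    SecurityConfig()                          # no fingerprint -> fresh Fingerprint()
    SecurityConfig(fingerprint="my-seed")     # string seed -> Fingerprint.generate(seed="my-seed")
    try:
        SecurityConfig(fingerprint="   ")     # whitespace-only seed
    except ValueError:
        pass                                  # "Fingerprint seed cannot be empty"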
def to_dict(self) -> Dict[str, Any]: def to_dict(self) -> dict[str, Any]:
""" """Convert the security config to a dictionary.
Convert the security config to a dictionary.
Returns: Returns:
Dict[str, Any]: Dictionary representation of the security config Dict[str, Any]: Dictionary representation of the security config
""" """
result = { return {
"fingerprint": self.fingerprint.to_dict() "fingerprint": self.fingerprint.to_dict(),
} }
return result
@classmethod @classmethod
def from_dict(cls, data: Dict[str, Any]) -> 'SecurityConfig': def from_dict(cls, data: dict[str, Any]) -> "SecurityConfig":
""" """Create a SecurityConfig from a dictionary.
Create a SecurityConfig from a dictionary.
Args: Args:
data (Dict[str, Any]): Dictionary representation of a security config data (Dict[str, Any]): Dictionary representation of a security config
Returns: Returns:
SecurityConfig: A new SecurityConfig instance SecurityConfig: A new SecurityConfig instance
""" """
# Make a copy to avoid modifying the original # Make a copy to avoid modifying the original
data_copy = data.copy() data_copy = data.copy()
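
The remainder of from_dict falls outside this hunk; assuming it mirrors Fingerprint.from_dict, the intended round trip would look like:

    cfg = SecurityConfig(fingerprint="deterministic-seed")
    payload = cfg.to_dict()                   # {"fingerprint": {...}}
    restored = SecurityConfig.from_dict(payload)
    # expected to hold if from_dict restores the nested fingerprint:
    assert restored.fingerprint.uuid_str == cfg.fingerprint.uuid_str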


@@ -2,23 +2,16 @@ import datetime
import inspect import inspect
import json import json
import logging import logging
import re
import threading import threading
import uuid import uuid
from collections.abc import Callable
from concurrent.futures import Future from concurrent.futures import Future
from copy import copy from copy import copy
from hashlib import md5 from hashlib import md5
from pathlib import Path from pathlib import Path
from typing import ( from typing import (
Any, Any,
Callable,
ClassVar, ClassVar,
Dict,
List,
Optional,
Set,
Tuple,
Type,
Union, Union,
get_args, get_args,
get_origin, get_origin,
@@ -71,6 +64,7 @@ class Task(BaseModel):
output_pydantic: Pydantic model for task output. output_pydantic: Pydantic model for task output.
security_config: Security configuration including fingerprinting. security_config: Security configuration including fingerprinting.
tools: List of tools/resources limited for task execution. tools: List of tools/resources limited for task execution.
""" """
__hash__ = object.__hash__ # type: ignore __hash__ = object.__hash__ # type: ignore
@@ -79,46 +73,46 @@ class Task(BaseModel):
tools_errors: int = 0 tools_errors: int = 0
delegations: int = 0 delegations: int = 0
i18n: I18N = I18N() i18n: I18N = I18N()
name: Optional[str] = Field(default=None) name: str | None = Field(default=None)
prompt_context: Optional[str] = None prompt_context: str | None = None
description: str = Field(description="Description of the actual task.") description: str = Field(description="Description of the actual task.")
expected_output: str = Field( expected_output: str = Field(
description="Clear definition of expected output for the task." description="Clear definition of expected output for the task.",
) )
config: Optional[Dict[str, Any]] = Field( config: dict[str, Any] | None = Field(
description="Configuration for the agent", description="Configuration for the agent",
default=None, default=None,
) )
callback: Optional[Any] = Field( callback: Any | None = Field(
description="Callback to be executed after the task is completed.", default=None description="Callback to be executed after the task is completed.", default=None,
) )
agent: Optional[BaseAgent] = Field( agent: BaseAgent | None = Field(
description="Agent responsible for execution the task.", default=None description="Agent responsible for execution the task.", default=None,
) )
context: Optional[List["Task"]] = Field( context: list["Task"] | None = Field(
description="Other tasks that will have their output used as context for this task.", description="Other tasks that will have their output used as context for this task.",
default=None, default=None,
) )
async_execution: Optional[bool] = Field( async_execution: bool | None = Field(
description="Whether the task should be executed asynchronously or not.", description="Whether the task should be executed asynchronously or not.",
default=False, default=False,
) )
output_json: Optional[Type[BaseModel]] = Field( output_json: type[BaseModel] | None = Field(
description="A Pydantic model to be used to create a JSON output.", description="A Pydantic model to be used to create a JSON output.",
default=None, default=None,
) )
output_pydantic: Optional[Type[BaseModel]] = Field( output_pydantic: type[BaseModel] | None = Field(
description="A Pydantic model to be used to create a Pydantic output.", description="A Pydantic model to be used to create a Pydantic output.",
default=None, default=None,
) )
output_file: Optional[str] = Field( output_file: str | None = Field(
description="A file path to be used to create a file output.", description="A file path to be used to create a file output.",
default=None, default=None,
) )
output: Optional[TaskOutput] = Field( output: TaskOutput | None = Field(
description="Task output, it's final result after being executed", default=None description="Task output, it's final result after being executed", default=None,
) )
tools: Optional[List[BaseTool]] = Field( tools: list[BaseTool] | None = Field(
default_factory=list, default_factory=list,
description="Tools the agent is limited to use for this task.", description="Tools the agent is limited to use for this task.",
) )
@@ -131,37 +125,36 @@ class Task(BaseModel):
frozen=True, frozen=True,
description="Unique identifier for the object, not set by user.", description="Unique identifier for the object, not set by user.",
) )
human_input: Optional[bool] = Field( human_input: bool | None = Field(
description="Whether the task should have a human review the final answer of the agent", description="Whether the task should have a human review the final answer of the agent",
default=False, default=False,
) )
converter_cls: Optional[Type[Converter]] = Field( converter_cls: type[Converter] | None = Field(
description="A converter class used to export structured output", description="A converter class used to export structured output",
default=None, default=None,
) )
processed_by_agents: Set[str] = Field(default_factory=set) processed_by_agents: set[str] = Field(default_factory=set)
guardrail: Optional[Union[Callable[[TaskOutput], Tuple[bool, Any]], str]] = Field( guardrail: Callable[[TaskOutput], tuple[bool, Any]] | str | None = Field(
default=None, default=None,
description="Function or string description of a guardrail to validate task output before proceeding to next task", description="Function or string description of a guardrail to validate task output before proceeding to next task",
) )
max_retries: int = Field( max_retries: int = Field(
default=3, description="Maximum number of retries when guardrail fails" default=3, description="Maximum number of retries when guardrail fails",
) )
retry_count: int = Field(default=0, description="Current number of retries") retry_count: int = Field(default=0, description="Current number of retries")
start_time: Optional[datetime.datetime] = Field( start_time: datetime.datetime | None = Field(
default=None, description="Start time of the task execution" default=None, description="Start time of the task execution",
) )
end_time: Optional[datetime.datetime] = Field( end_time: datetime.datetime | None = Field(
default=None, description="End time of the task execution" default=None, description="End time of the task execution",
) )
@field_validator("guardrail") @field_validator("guardrail")
@classmethod @classmethod
def validate_guardrail_function( def validate_guardrail_function(
cls, v: Optional[str | Callable] cls, v: str | Callable | None,
) -> Optional[str | Callable]: ) -> str | Callable | None:
""" """If v is a callable, validate that the guardrail function has the correct signature and behavior.
If v is a callable, validate that the guardrail function has the correct signature and behavior.
If v is a string, return it as is. If v is a string, return it as is.
While type hints provide static checking, this validator ensures runtime safety by: While type hints provide static checking, this validator ensures runtime safety by:
@@ -183,6 +176,7 @@ class Task(BaseModel):
Raises: Raises:
ValueError: If the function signature is invalid or return annotation ValueError: If the function signature is invalid or return annotation
doesn't match Tuple[bool, Any] doesn't match Tuple[bool, Any]
""" """
if v is not None and callable(v): if v is not None and callable(v):
sig = inspect.signature(v) sig = inspect.signature(v)
@@ -192,7 +186,8 @@ class Task(BaseModel):
if param.default is inspect.Parameter.empty if param.default is inspect.Parameter.empty
] ]
if len(positional_args) != 1: if len(positional_args) != 1:
raise ValueError("Guardrail function must accept exactly one parameter") msg = "Guardrail function must accept exactly one parameter"
raise ValueError(msg)
# Check return annotation if present, but don't require it # Check return annotation if present, but don't require it
return_annotation = sig.return_annotation return_annotation = sig.return_annotation
@@ -210,16 +205,17 @@ class Task(BaseModel):
or return_annotation_args[1] == Union[str, TaskOutput] or return_annotation_args[1] == Union[str, TaskOutput]
) )
): ):
msg = "If return type is annotated, it must be Tuple[bool, Any]"
raise ValueError( raise ValueError(
"If return type is annotated, it must be Tuple[bool, Any]" msg,
) )
return v return v
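
A hedged sketch of a guardrail that satisfies this validator: exactly one required positional parameter, and a return annotation matching tuple[bool, Any] if one is given. Import paths are assumptions based on this diff.

    from typing import Any

    from crewai import Task
    from crewai.tasks.task_output import TaskOutput

    def length_guardrail(output: TaskOutput) -> tuple[bool, Any]:
        """Pass short outputs through; ask for a retry otherwise."""
        if len(output.raw) < 500:
            return True, output
        return False, "Output too long; please shorten it."

    task = Task(
        description="Summarize the incident report.",
        expected_output="A three-sentence summary.",
        guardrail=length_guardrail,
        max_retries=2,   # used when the guardrail returns (False, error)
    )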
_guardrail: Optional[Callable] = PrivateAttr(default=None) _guardrail: Callable | None = PrivateAttr(default=None)
_original_description: Optional[str] = PrivateAttr(default=None) _original_description: str | None = PrivateAttr(default=None)
_original_expected_output: Optional[str] = PrivateAttr(default=None) _original_expected_output: str | None = PrivateAttr(default=None)
_original_output_file: Optional[str] = PrivateAttr(default=None) _original_output_file: str | None = PrivateAttr(default=None)
_thread: Optional[threading.Thread] = PrivateAttr(default=None) _thread: threading.Thread | None = PrivateAttr(default=None)
@model_validator(mode="before") @model_validator(mode="before")
@classmethod @classmethod
@@ -231,8 +227,9 @@ class Task(BaseModel):
required_fields = ["description", "expected_output"] required_fields = ["description", "expected_output"]
for field in required_fields: for field in required_fields:
if getattr(self, field) is None: if getattr(self, field) is None:
msg = f"{field} must be provided either directly or through config"
raise ValueError( raise ValueError(
f"{field} must be provided either directly or through config" msg,
) )
return self return self
@@ -245,22 +242,23 @@ class Task(BaseModel):
assert self.agent is not None assert self.agent is not None
self._guardrail = LLMGuardrail( self._guardrail = LLMGuardrail(
description=self.guardrail, llm=self.agent.llm description=self.guardrail, llm=self.agent.llm,
) )
return self return self
@field_validator("id", mode="before") @field_validator("id", mode="before")
@classmethod @classmethod
def _deny_user_set_id(cls, v: Optional[UUID4]) -> None: def _deny_user_set_id(cls, v: UUID4 | None) -> None:
if v: if v:
msg = "may_not_set_field"
raise PydanticCustomError( raise PydanticCustomError(
"may_not_set_field", "This field is not to be set by the user.", {} msg, "This field is not to be set by the user.", {},
) )
@field_validator("output_file") @field_validator("output_file")
@classmethod @classmethod
def output_file_validation(cls, value: Optional[str]) -> Optional[str]: def output_file_validation(cls, value: str | None) -> str | None:
"""Validate the output file path. """Validate the output file path.
Args: Args:
@@ -274,26 +272,30 @@ class Task(BaseModel):
Raises: Raises:
ValueError: If the path contains invalid characters, path traversal attempts, ValueError: If the path contains invalid characters, path traversal attempts,
or other security concerns. or other security concerns.
""" """
if value is None: if value is None:
return None return None
# Basic security checks # Basic security checks
if ".." in value: if ".." in value:
msg = "Path traversal attempts are not allowed in output_file paths"
raise ValueError( raise ValueError(
"Path traversal attempts are not allowed in output_file paths" msg,
) )
# Check for shell expansion first # Check for shell expansion first
if value.startswith("~") or value.startswith("$"): if value.startswith(("~", "$")):
msg = "Shell expansion characters are not allowed in output_file paths"
raise ValueError( raise ValueError(
"Shell expansion characters are not allowed in output_file paths" msg,
) )
# Then check other shell special characters # Then check other shell special characters
if any(char in value for char in ["|", ">", "<", "&", ";"]): if any(char in value for char in ["|", ">", "<", "&", ";"]):
msg = "Shell special characters are not allowed in output_file paths"
raise ValueError( raise ValueError(
"Shell special characters are not allowed in output_file paths" msg,
) )
# Don't strip leading slash if it's a template path with variables # Don't strip leading slash if it's a template path with variables
@@ -302,7 +304,8 @@ class Task(BaseModel):
template_vars = [part.split("}")[0] for part in value.split("{")[1:]] template_vars = [part.split("}")[0] for part in value.split("{")[1:]]
for var in template_vars: for var in template_vars:
if not var.isidentifier(): if not var.isidentifier():
raise ValueError(f"Invalid template variable name: {var}") msg = f"Invalid template variable name: {var}"
raise ValueError(msg)
return value return value
# Strip leading slash for regular paths # Strip leading slash for regular paths
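
A sketch of which paths the checks above accept or reject; validation runs at construction, so the rejected paths raise while building the Task:

    base = {"description": "d", "expected_output": "o"}
    Task(**base, output_file="reports/{topic}/summary.md")   # accepted: valid template variable
    # each of the following raises (pydantic wraps the ValueError):
    # Task(**base, output_file="../etc/passwd")       -> path traversal
    # Task(**base, output_file="~/notes.txt")         -> shell expansion
    # Task(**base, output_file="out.txt; rm -rf /")   -> shell special characters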
@@ -330,8 +333,9 @@ class Task(BaseModel):
"""Check if an output type is set.""" """Check if an output type is set."""
output_types = [self.output_json, self.output_pydantic] output_types = [self.output_json, self.output_pydantic]
if len([type for type in output_types if type]) > 1: if len([type for type in output_types if type]) > 1:
msg = "output_type"
raise PydanticCustomError( raise PydanticCustomError(
"output_type", msg,
"Only one output type can be set, either output_pydantic or output_json.", "Only one output type can be set, either output_pydantic or output_json.",
{}, {},
) )
@@ -339,9 +343,9 @@ class Task(BaseModel):
def execute_sync( def execute_sync(
self, self,
agent: Optional[BaseAgent] = None, agent: BaseAgent | None = None,
context: Optional[str] = None, context: str | None = None,
tools: Optional[List[BaseTool]] = None, tools: list[BaseTool] | None = None,
) -> TaskOutput: ) -> TaskOutput:
"""Execute the task synchronously.""" """Execute the task synchronously."""
return self._execute_core(agent, context, tools) return self._execute_core(agent, context, tools)
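
Usage-wise the two entry points differ only in blocking behavior (task and my_agent are placeholders, not defined in this diff):

    output = task.execute_sync(agent=my_agent)    # blocks until the TaskOutput is ready
    future = task.execute_async(agent=my_agent)   # returns a concurrent.futures.Future
    output = future.result()                      # join the background thread's result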
@@ -363,8 +367,8 @@ class Task(BaseModel):
def execute_async( def execute_async(
self, self,
agent: BaseAgent | None = None, agent: BaseAgent | None = None,
context: Optional[str] = None, context: str | None = None,
tools: Optional[List[BaseTool]] = None, tools: list[BaseTool] | None = None,
) -> Future[TaskOutput]: ) -> Future[TaskOutput]:
"""Execute the task asynchronously.""" """Execute the task asynchronously."""
future: Future[TaskOutput] = Future() future: Future[TaskOutput] = Future()
@@ -377,9 +381,9 @@ class Task(BaseModel):
def _execute_task_async( def _execute_task_async(
self, self,
agent: Optional[BaseAgent], agent: BaseAgent | None,
context: Optional[str], context: str | None,
tools: Optional[List[Any]], tools: list[Any] | None,
future: Future[TaskOutput], future: Future[TaskOutput],
) -> None: ) -> None:
"""Execute the task asynchronously with context handling.""" """Execute the task asynchronously with context handling."""
@@ -388,17 +392,18 @@ class Task(BaseModel):
def _execute_core( def _execute_core(
self, self,
agent: Optional[BaseAgent], agent: BaseAgent | None,
context: Optional[str], context: str | None,
tools: Optional[List[Any]], tools: list[Any] | None,
) -> TaskOutput: ) -> TaskOutput:
"""Run the core execution logic of the task.""" """Run the core execution logic of the task."""
try: try:
agent = agent or self.agent agent = agent or self.agent
self.agent = agent self.agent = agent
if not agent: if not agent:
msg = f"The task '{self.description}' has no agent assigned, therefore it can't be executed directly and should be executed in a Crew using a specific process that support that, like hierarchical."
raise Exception( raise Exception(
f"The task '{self.description}' has no agent assigned, therefore it can't be executed directly and should be executed in a Crew using a specific process that support that, like hierarchical." msg,
) )
self.start_time = datetime.datetime.now() self.start_time = datetime.datetime.now()
@@ -430,10 +435,13 @@ class Task(BaseModel):
guardrail_result = self._process_guardrail(task_output) guardrail_result = self._process_guardrail(task_output)
if not guardrail_result.success: if not guardrail_result.success:
if self.retry_count >= self.max_retries: if self.retry_count >= self.max_retries:
raise Exception( msg = (
f"Task failed guardrail validation after {self.max_retries} retries. " f"Task failed guardrail validation after {self.max_retries} retries. "
f"Last error: {guardrail_result.error}" f"Last error: {guardrail_result.error}"
) )
raise Exception(
msg,
)
self.retry_count += 1 self.retry_count += 1
context = self.i18n.errors("validation_error").format( context = self.i18n.errors("validation_error").format(
@@ -448,14 +456,15 @@ class Task(BaseModel):
return self._execute_core(agent, context, tools) return self._execute_core(agent, context, tools)
if guardrail_result.result is None: if guardrail_result.result is None:
msg = "Task guardrail returned None as result. This is not allowed."
raise Exception( raise Exception(
"Task guardrail returned None as result. This is not allowed." msg,
) )
if isinstance(guardrail_result.result, str): if isinstance(guardrail_result.result, str):
task_output.raw = guardrail_result.result task_output.raw = guardrail_result.result
pydantic_output, json_output = self._export_output( pydantic_output, json_output = self._export_output(
guardrail_result.result guardrail_result.result,
) )
task_output.pydantic = pydantic_output task_output.pydantic = pydantic_output
task_output.json_dict = json_output task_output.json_dict = json_output
@@ -482,13 +491,13 @@ class Task(BaseModel):
) )
self._save_file(content) self._save_file(content)
crewai_event_bus.emit( crewai_event_bus.emit(
self, TaskCompletedEvent(output=task_output, task=self) self, TaskCompletedEvent(output=task_output, task=self),
) )
return task_output return task_output
except Exception as e: except Exception as e:
self.end_time = datetime.datetime.now() self.end_time = datetime.datetime.now()
crewai_event_bus.emit(self, TaskFailedEvent(error=str(e), task=self)) crewai_event_bus.emit(self, TaskFailedEvent(error=str(e), task=self))
raise e # Re-raise the exception after emitting the event raise # Re-raise the exception after emitting the event
def _process_guardrail(self, task_output: TaskOutput) -> GuardrailResult: def _process_guardrail(self, task_output: TaskOutput) -> GuardrailResult:
assert self._guardrail is not None assert self._guardrail is not None
@@ -504,7 +513,7 @@ class Task(BaseModel):
crewai_event_bus.emit( crewai_event_bus.emit(
self, self,
LLMGuardrailStartedEvent( LLMGuardrailStartedEvent(
guardrail=self._guardrail, retry_count=self.retry_count guardrail=self._guardrail, retry_count=self.retry_count,
), ),
) )
@@ -526,17 +535,18 @@ class Task(BaseModel):
Returns: Returns:
Prompt of the task. Prompt of the task.
""" """
tasks_slices = [self.description] tasks_slices = [self.description]
output = self.i18n.slice("expected_output").format( output = self.i18n.slice("expected_output").format(
expected_output=self.expected_output expected_output=self.expected_output,
) )
tasks_slices = [self.description, output] tasks_slices = [self.description, output]
return "\n".join(tasks_slices) return "\n".join(tasks_slices)
def interpolate_inputs_and_add_conversation_history( def interpolate_inputs_and_add_conversation_history(
self, inputs: Dict[str, Union[str, int, float, Dict[str, Any], List[Any]]] self, inputs: dict[str, str | int | float | dict[str, Any] | list[Any]],
) -> None: ) -> None:
"""Interpolate inputs into the task description, expected output, and output file path. """Interpolate inputs into the task description, expected output, and output file path.
Add conversation history if present. Add conversation history if present.
@@ -547,6 +557,7 @@ class Task(BaseModel):
Raises: Raises:
ValueError: If a required template variable is missing from inputs. ValueError: If a required template variable is missing from inputs.
""" """
if self._original_description is None: if self._original_description is None:
self._original_description = self.description self._original_description = self.description
@@ -560,43 +571,46 @@ class Task(BaseModel):
try: try:
self.description = interpolate_only( self.description = interpolate_only(
input_string=self._original_description, inputs=inputs input_string=self._original_description, inputs=inputs,
) )
except KeyError as e: except KeyError as e:
msg = f"Missing required template variable '{e.args[0]}' in description"
raise ValueError( raise ValueError(
f"Missing required template variable '{e.args[0]}' in description" msg,
) from e ) from e
except ValueError as e: except ValueError as e:
raise ValueError(f"Error interpolating description: {str(e)}") from e msg = f"Error interpolating description: {e!s}"
raise ValueError(msg) from e
try: try:
self.expected_output = interpolate_only( self.expected_output = interpolate_only(
input_string=self._original_expected_output, inputs=inputs input_string=self._original_expected_output, inputs=inputs,
) )
except (KeyError, ValueError) as e: except (KeyError, ValueError) as e:
raise ValueError(f"Error interpolating expected_output: {str(e)}") from e msg = f"Error interpolating expected_output: {e!s}"
raise ValueError(msg) from e
if self.output_file is not None: if self.output_file is not None:
try: try:
self.output_file = interpolate_only( self.output_file = interpolate_only(
input_string=self._original_output_file, inputs=inputs input_string=self._original_output_file, inputs=inputs,
) )
except (KeyError, ValueError) as e: except (KeyError, ValueError) as e:
msg = f"Error interpolating output_file path: {e!s}"
raise ValueError( raise ValueError(
f"Error interpolating output_file path: {str(e)}" msg,
) from e ) from e
if "crew_chat_messages" in inputs and inputs["crew_chat_messages"]: if inputs.get("crew_chat_messages"):
conversation_instruction = self.i18n.slice( conversation_instruction = self.i18n.slice(
"conversation_history_instruction" "conversation_history_instruction",
) )
crew_chat_messages_json = str(inputs["crew_chat_messages"]) crew_chat_messages_json = str(inputs["crew_chat_messages"])
try: try:
crew_chat_messages = json.loads(crew_chat_messages_json) crew_chat_messages = json.loads(crew_chat_messages_json)
except json.JSONDecodeError as e: except json.JSONDecodeError:
print("An error occurred while parsing crew chat messages:", e)
raise raise
conversation_history = "\n".join( conversation_history = "\n".join(
@@ -613,14 +627,14 @@ class Task(BaseModel):
"""Increment the tools errors counter.""" """Increment the tools errors counter."""
self.tools_errors += 1 self.tools_errors += 1
def increment_delegations(self, agent_name: Optional[str]) -> None: def increment_delegations(self, agent_name: str | None) -> None:
"""Increment the delegations counter.""" """Increment the delegations counter."""
if agent_name: if agent_name:
self.processed_by_agents.add(agent_name) self.processed_by_agents.add(agent_name)
self.delegations += 1 self.delegations += 1
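
A minimal interpolation sketch; the _original_* private attributes above keep the templates, so repeated calls re-interpolate from the originals:

    task = Task(
        description="Write a brief on {topic}.",
        expected_output="One paragraph about {topic}.",
    )
    task.interpolate_inputs_and_add_conversation_history({"topic": "rate limits"})
    assert "rate limits" in task.description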
def copy( def copy(
self, agents: List["BaseAgent"], task_mapping: Dict[str, "Task"] self, agents: list["BaseAgent"], task_mapping: dict[str, "Task"],
) -> "Task": ) -> "Task":
"""Creates a deep copy of the Task while preserving its original class type. """Creates a deep copy of the Task while preserving its original class type.
@@ -630,6 +644,7 @@ class Task(BaseModel):
Returns: Returns:
A copy of the task with the same class type as the original. A copy of the task with the same class type as the original.
""" """
exclude = { exclude = {
"id", "id",
@@ -653,20 +668,19 @@ class Task(BaseModel):
cloned_agent = get_agent_by_role(self.agent.role) if self.agent else None cloned_agent = get_agent_by_role(self.agent.role) if self.agent else None
cloned_tools = copy(self.tools) if self.tools else [] cloned_tools = copy(self.tools) if self.tools else []
copied_task = self.__class__( return self.__class__(
**copied_data, **copied_data,
context=cloned_context, context=cloned_context,
agent=cloned_agent, agent=cloned_agent,
tools=cloned_tools, tools=cloned_tools,
) )
return copied_task
def _export_output( def _export_output(
self, result: str self, result: str,
) -> Tuple[Optional[BaseModel], Optional[Dict[str, Any]]]: ) -> tuple[BaseModel | None, dict[str, Any] | None]:
pydantic_output: Optional[BaseModel] = None pydantic_output: BaseModel | None = None
json_output: Optional[Dict[str, Any]] = None json_output: dict[str, Any] | None = None
if self.output_pydantic or self.output_json: if self.output_pydantic or self.output_json:
model_output = convert_to_model( model_output = convert_to_model(
@@ -696,7 +710,7 @@ class Task(BaseModel):
return OutputFormat.PYDANTIC return OutputFormat.PYDANTIC
return OutputFormat.RAW return OutputFormat.RAW
def _save_file(self, result: Union[Dict, str, Any]) -> None: def _save_file(self, result: dict | str | Any) -> None:
"""Save task output to a file. """Save task output to a file.
Note: Note:
@@ -713,9 +727,11 @@ class Task(BaseModel):
RuntimeError: If there is an error writing to the file. For cross-platform RuntimeError: If there is an error writing to the file. For cross-platform
compatibility, especially on Windows, use FileWriterTool from crewai_tools compatibility, especially on Windows, use FileWriterTool from crewai_tools
package. package.
""" """
if self.output_file is None: if self.output_file is None:
raise ValueError("output_file is not set.") msg = "output_file is not set."
raise ValueError(msg)
FILEWRITER_RECOMMENDATION = ( FILEWRITER_RECOMMENDATION = (
"For cross-platform file writing, especially on Windows, " "For cross-platform file writing, especially on Windows, "
@@ -736,15 +752,14 @@ class Task(BaseModel):
json.dump(result, file, ensure_ascii=False, indent=2) json.dump(result, file, ensure_ascii=False, indent=2)
else: else:
file.write(str(result)) file.write(str(result))
except (OSError, IOError) as e: except OSError as e:
raise RuntimeError( raise RuntimeError(
"\n".join( "\n".join(
[f"Failed to save output file: {e}", FILEWRITER_RECOMMENDATION] [f"Failed to save output file: {e}", FILEWRITER_RECOMMENDATION],
) ),
) )
return None
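
How output_file interacts with structured output, as a hedged sketch (the branch selecting json.dump over str() is only partially visible in this hunk):

    from pydantic import BaseModel

    class Report(BaseModel):
        title: str
        body: str

    task = Task(
        description="Draft a report.",
        expected_output="A structured report.",
        output_json=Report,               # at most one of output_json / output_pydantic
        output_file="out/report.json",    # dict results are json.dump'ed with indent=2
    )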
def __repr__(self): def __repr__(self) -> str:
return f"Task(description={self.description}, expected_output={self.expected_output})" return f"Task(description={self.description}, expected_output={self.expected_output})"
@property @property
@@ -753,5 +768,6 @@ class Task(BaseModel):
Returns: Returns:
Fingerprint: The fingerprint of the task Fingerprint: The fingerprint of the task
""" """
return self.security_config.fingerprint return self.security_config.fingerprint
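
The property is a plain delegation, so both spellings name the same object (task being any Task instance):

    assert task.fingerprint is task.security_config.fingerprint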

Some files were not shown because too many files have changed in this diff