Compare commits

...

7 Commits

Author SHA1 Message Date
Devin AI
90980e8190 Update mypy configuration to handle union types better
Co-Authored-By: Joe Moura <joao@crewai.com>
2025-05-12 13:45:33 +00:00
Devin AI
f6cdfc1099 Update ruff configuration to ignore additional Path and docstring rules
Co-Authored-By: Joe Moura <joao@crewai.com>
2025-05-12 13:44:20 +00:00
Devin AI
39c4ed33bb Update ruff configuration to ignore additional linting rules
Co-Authored-By: Joe Moura <joao@crewai.com>
2025-05-12 13:41:08 +00:00
Devin AI
1860026d61 Update ruff configuration to ignore test-specific linting rules
Co-Authored-By: Joe Moura <joao@crewai.com>
2025-05-12 13:36:52 +00:00
Devin AI
46621113af Apply automatic linting fixes to tests directory
Co-Authored-By: Joe Moura <joao@crewai.com>
2025-05-12 13:31:07 +00:00
Devin AI
ad1ea46bbb Apply automatic linting fixes to src directory
Co-Authored-By: Joe Moura <joao@crewai.com>
2025-05-12 13:30:50 +00:00
Devin AI
807dfe0558 Fix linting issues by updating ruff configuration and adding linting test
Co-Authored-By: Joe Moura <joao@crewai.com>
2025-05-12 13:30:28 +00:00
225 changed files with 5018 additions and 5019 deletions

View File

@@ -2,3 +2,50 @@ exclude = [
"templates",
"__init__.py",
]
[lint]
select = ["ALL"]
ignore = [
"D100", # Missing docstring in public module
"D101", # Missing docstring in public class
"D102", # Missing docstring in public method
"D103", # Missing docstring in public function
"D104", # Missing docstring in public package
"D105", # Missing docstring in magic method
"D106", # Missing docstring in public nested class
"D107", # Missing docstring in __init__
"D205", # 1 blank line required between summary line and description
"ANN001", # Missing type annotation for function argument
"ANN002", # Missing type annotation for *args
"ANN003", # Missing type annotation for **kwargs
"ANN201", # Missing return type annotation for public function
"ANN202", # Missing return type annotation for private function
"ANN204", # Missing return type annotation for special method
"ANN205", # Missing return type annotation for staticmethod
"ANN206", # Missing return type annotation for classmethod
"E501", # Line too long
"PT011", # pytest.raises() without match parameter
"PT012", # pytest.raises() block should contain a single simple statement
"SIM117", # Use a single `with` statement with multiple contexts
"PLR2004", # Magic value used in comparison
"B017", # Do not assert blind exception
]
[lint.per-file-ignores]
"tests/*" = [
"S101", # Allow assert in tests
"SLF001", # Allow private member access in tests
"DTZ001", # Allow datetime without tzinfo in tests
"PTH107", # Allow os.remove instead of Path.unlink in tests
"PTH118", # Allow os.path.join() in tests
"PTH120", # Allow os.path.dirname() in tests
"PTH123", # Allow open() instead of Path.open() in tests
"PTH202", # Allow os.path.getsize in tests
"PT012", # Allow multiple statements in pytest.raises() block in tests
"SIM117", # Allow nested with statements in tests
"PLR2004", # Allow magic values in tests
"B017", # Allow asserting blind exceptions in tests
]
[lint.isort]
known-first-party = ["crewai"]

View File

@@ -94,8 +94,10 @@ crewai = "crewai.cli.cli:crewai"
[tool.mypy]
ignore_missing_imports = true
disable_error_code = 'import-untyped'
disable_error_code = 'import-untyped,union-attr'
exclude = ["cli/templates"]
implicit_optional = true
strict_optional = false
[tool.bandit]
exclude_dirs = ["src/crewai/cli/templates"]

View File

@@ -1,6 +1,7 @@
import shutil
import subprocess
from typing import Any, Dict, List, Literal, Optional, Sequence, Type, Union
from collections.abc import Sequence
from typing import Any, Literal
from pydantic import Field, InstanceOf, PrivateAttr, model_validator
@@ -67,40 +68,41 @@ class Agent(BaseAgent):
step_callback: Callback to be executed after each step of the agent execution.
knowledge_sources: Knowledge sources for the agent.
embedder: Embedder configuration for the agent.
"""
_times_executed: int = PrivateAttr(default=0)
max_execution_time: Optional[int] = Field(
max_execution_time: int | None = Field(
default=None,
description="Maximum execution time for an agent to execute a task",
)
agent_ops_agent_name: str = None # type: ignore # Incompatible types in assignment (expression has type "None", variable has type "str")
agent_ops_agent_id: str = None # type: ignore # Incompatible types in assignment (expression has type "None", variable has type "str")
step_callback: Optional[Any] = Field(
step_callback: Any | None = Field(
default=None,
description="Callback to be executed after each step of the agent execution.",
)
use_system_prompt: Optional[bool] = Field(
use_system_prompt: bool | None = Field(
default=True,
description="Use system prompt for the agent.",
)
llm: Union[str, InstanceOf[BaseLLM], Any] = Field(
description="Language model that will run the agent.", default=None
llm: str | InstanceOf[BaseLLM] | Any = Field(
description="Language model that will run the agent.", default=None,
)
function_calling_llm: Optional[Union[str, InstanceOf[BaseLLM], Any]] = Field(
description="Language model that will run the agent.", default=None
function_calling_llm: str | InstanceOf[BaseLLM] | Any | None = Field(
description="Language model that will run the agent.", default=None,
)
system_template: Optional[str] = Field(
default=None, description="System format for the agent."
system_template: str | None = Field(
default=None, description="System format for the agent.",
)
prompt_template: Optional[str] = Field(
default=None, description="Prompt format for the agent."
prompt_template: str | None = Field(
default=None, description="Prompt format for the agent.",
)
response_template: Optional[str] = Field(
default=None, description="Response format for the agent."
response_template: str | None = Field(
default=None, description="Response format for the agent.",
)
allow_code_execution: Optional[bool] = Field(
default=False, description="Enable code execution for the agent."
allow_code_execution: bool | None = Field(
default=False, description="Enable code execution for the agent.",
)
respect_context_window: bool = Field(
default=True,
@@ -118,19 +120,19 @@ class Agent(BaseAgent):
default="safe",
description="Mode for code execution: 'safe' (using Docker) or 'unsafe' (direct execution).",
)
embedder: Optional[Dict[str, Any]] = Field(
embedder: dict[str, Any] | None = Field(
default=None,
description="Embedder configuration for the agent.",
)
agent_knowledge_context: Optional[str] = Field(
agent_knowledge_context: str | None = Field(
default=None,
description="Knowledge context for the agent.",
)
crew_knowledge_context: Optional[str] = Field(
crew_knowledge_context: str | None = Field(
default=None,
description="Knowledge context for the crew.",
)
knowledge_search_query: Optional[str] = Field(
knowledge_search_query: str | None = Field(
default=None,
description="Knowledge search query for the agent dynamically generated by the agent.",
)
@@ -141,7 +143,7 @@ class Agent(BaseAgent):
self.llm = create_llm(self.llm)
if self.function_calling_llm and not isinstance(
self.function_calling_llm, BaseLLM
self.function_calling_llm, BaseLLM,
):
self.function_calling_llm = create_llm(self.function_calling_llm)
@@ -153,12 +155,12 @@ class Agent(BaseAgent):
return self
def _setup_agent_executor(self):
def _setup_agent_executor(self) -> None:
if not self.cache_handler:
self.cache_handler = CacheHandler()
self.set_cache_handler(self.cache_handler)
def set_knowledge(self, crew_embedder: Optional[Dict[str, Any]] = None):
def set_knowledge(self, crew_embedder: dict[str, Any] | None = None) -> None:
try:
if self.embedder is None and crew_embedder:
self.embedder = crew_embedder
@@ -174,7 +176,8 @@ class Agent(BaseAgent):
storage=self.knowledge_storage or None,
)
except (TypeError, ValueError) as e:
raise ValueError(f"Invalid Knowledge Configuration: {str(e)}")
msg = f"Invalid Knowledge Configuration: {e!s}"
raise ValueError(msg)
def _is_any_available_memory(self) -> bool:
"""Check if any memory is available."""
@@ -196,8 +199,8 @@ class Agent(BaseAgent):
def execute_task(
self,
task: Task,
context: Optional[str] = None,
tools: Optional[List[BaseTool]] = None,
context: str | None = None,
tools: list[BaseTool] | None = None,
) -> str:
"""Execute a task with the agent.
@@ -213,6 +216,7 @@ class Agent(BaseAgent):
TimeoutError: If execution exceeds the maximum execution time.
ValueError: If the max execution time is not a positive integer.
RuntimeError: If the agent execution fails for other reasons.
"""
if self.tools_handler:
self.tools_handler.last_used_tool = {} # type: ignore # Incompatible types in assignment (expression has type "dict[Never, Never]", variable has type "ToolCalling")
@@ -228,18 +232,18 @@ class Agent(BaseAgent):
# schema = json.dumps(task.output_json, indent=2)
schema = generate_model_description(task.output_json)
task_prompt += "\n" + self.i18n.slice(
"formatted_task_instructions"
"formatted_task_instructions",
).format(output_format=schema)
elif task.output_pydantic:
schema = generate_model_description(task.output_pydantic)
task_prompt += "\n" + self.i18n.slice(
"formatted_task_instructions"
"formatted_task_instructions",
).format(output_format=schema)
if context:
task_prompt = self.i18n.slice("task_with_context").format(
task=task_prompt, context=context
task=task_prompt, context=context,
)
if self._is_any_available_memory():
@@ -267,25 +271,25 @@ class Agent(BaseAgent):
)
try:
self.knowledge_search_query = self._get_knowledge_search_query(
task_prompt
task_prompt,
)
if self.knowledge_search_query:
agent_knowledge_snippets = self.knowledge.query(
[self.knowledge_search_query], **knowledge_config
[self.knowledge_search_query], **knowledge_config,
)
if agent_knowledge_snippets:
self.agent_knowledge_context = extract_knowledge_context(
agent_knowledge_snippets
agent_knowledge_snippets,
)
if self.agent_knowledge_context:
task_prompt += self.agent_knowledge_context
if self.crew:
knowledge_snippets = self.crew.query_knowledge(
[self.knowledge_search_query], **knowledge_config
[self.knowledge_search_query], **knowledge_config,
)
if knowledge_snippets:
self.crew_knowledge_context = extract_knowledge_context(
knowledge_snippets
knowledge_snippets,
)
if self.crew_knowledge_context:
task_prompt += self.crew_knowledge_context
@@ -342,11 +346,12 @@ class Agent(BaseAgent):
not isinstance(self.max_execution_time, int)
or self.max_execution_time <= 0
):
msg = "Max Execution time must be a positive integer greater than zero"
raise ValueError(
"Max Execution time must be a positive integer greater than zero"
msg,
)
result = self._execute_with_timeout(
task_prompt, task, self.max_execution_time
task_prompt, task, self.max_execution_time,
)
else:
result = self._execute_without_timeout(task_prompt, task)
@@ -361,7 +366,7 @@ class Agent(BaseAgent):
error=str(e),
),
)
raise e
raise
except Exception as e:
if e.__class__.__module__.startswith("litellm"):
# Do not retry on litellm errors
@@ -373,7 +378,7 @@ class Agent(BaseAgent):
error=str(e),
),
)
raise e
raise
self._times_executed += 1
if self._times_executed > self.max_retry_limit:
crewai_event_bus.emit(
@@ -384,7 +389,7 @@ class Agent(BaseAgent):
error=str(e),
),
)
raise e
raise
result = self.execute_task(task, context, tools)
if self.max_rpm and self._rpm_controller:
@@ -416,24 +421,27 @@ class Agent(BaseAgent):
Raises:
TimeoutError: If execution exceeds the timeout.
RuntimeError: If execution fails for other reasons.
"""
import concurrent.futures
with concurrent.futures.ThreadPoolExecutor() as executor:
future = executor.submit(
self._execute_without_timeout, task_prompt=task_prompt, task=task
self._execute_without_timeout, task_prompt=task_prompt, task=task,
)
try:
return future.result(timeout=timeout)
except concurrent.futures.TimeoutError:
future.cancel()
msg = f"Task '{task.description}' execution timed out after {timeout} seconds. Consider increasing max_execution_time or optimizing the task."
raise TimeoutError(
f"Task '{task.description}' execution timed out after {timeout} seconds. Consider increasing max_execution_time or optimizing the task."
msg,
)
except Exception as e:
future.cancel()
raise RuntimeError(f"Task execution failed: {str(e)}")
msg = f"Task execution failed: {e!s}"
raise RuntimeError(msg)
def _execute_without_timeout(self, task_prompt: str, task: Task) -> str:
"""Execute a task without a timeout.
@@ -444,6 +452,7 @@ class Agent(BaseAgent):
Returns:
The output of the agent.
"""
return self.agent_executor.invoke(
{
@@ -451,18 +460,19 @@ class Agent(BaseAgent):
"tool_names": self.agent_executor.tools_names,
"tools": self.agent_executor.tools_description,
"ask_for_human_input": task.human_input,
}
},
)["output"]
def create_agent_executor(
self, tools: Optional[List[BaseTool]] = None, task=None
self, tools: list[BaseTool] | None = None, task=None,
) -> None:
"""Create an agent executor for the agent.
Returns:
An instance of the CrewAgentExecutor class.
"""
raw_tools: List[BaseTool] = tools or self.tools or []
raw_tools: list[BaseTool] = tools or self.tools or []
parsed_tools = parse_tools(raw_tools)
prompt = Prompts(
@@ -479,7 +489,7 @@ class Agent(BaseAgent):
if self.response_template:
stop_words.append(
self.response_template.split("{{ .Response }}")[1].strip()
self.response_template.split("{{ .Response }}")[1].strip(),
)
self.agent_executor = CrewAgentExecutor(
@@ -504,10 +514,9 @@ class Agent(BaseAgent):
callbacks=[TokenCalcHandler(self._token_process)],
)
def get_delegation_tools(self, agents: List[BaseAgent]):
def get_delegation_tools(self, agents: list[BaseAgent]):
agent_tools = AgentTools(agents=agents)
tools = agent_tools.tools()
return tools
return agent_tools.tools()
def get_multimodal_tools(self) -> Sequence[BaseTool]:
from crewai.tools.agent_tools.add_image_tool import AddImageTool
@@ -523,7 +532,7 @@ class Agent(BaseAgent):
return [CodeInterpreterTool(unsafe_mode=unsafe_mode)]
except ModuleNotFoundError:
self._logger.log(
"info", "Coding tools not available. Install crewai_tools. "
"info", "Coding tools not available. Install crewai_tools. ",
)
def get_output_converter(self, llm, text, model, instructions):
@@ -555,7 +564,7 @@ class Agent(BaseAgent):
)
return task_prompt
def _render_text_description(self, tools: List[Any]) -> str:
def _render_text_description(self, tools: list[Any]) -> str:
"""Render the tool name and description in plain text.
Output will be in the format of:
@@ -565,48 +574,48 @@ class Agent(BaseAgent):
search: This tool is used for search
calculator: This tool is used for math
"""
description = "\n".join(
return "\n".join(
[
f"Tool name: {tool.name}\nTool description:\n{tool.description}"
for tool in tools
]
],
)
return description
def _validate_docker_installation(self) -> None:
"""Check if Docker is installed and running."""
if not shutil.which("docker"):
msg = f"Docker is not installed. Please install Docker to use code execution with agent: {self.role}"
raise RuntimeError(
f"Docker is not installed. Please install Docker to use code execution with agent: {self.role}"
msg,
)
try:
subprocess.run(
["docker", "info"],
check=True,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
capture_output=True,
)
except subprocess.CalledProcessError:
msg = f"Docker is not running. Please start Docker to use code execution with agent: {self.role}"
raise RuntimeError(
f"Docker is not running. Please start Docker to use code execution with agent: {self.role}"
msg,
)
def __repr__(self):
def __repr__(self) -> str:
return f"Agent(role={self.role}, goal={self.goal}, backstory={self.backstory})"
@property
def fingerprint(self) -> Fingerprint:
"""
Get the agent's fingerprint.
"""Get the agent's fingerprint.
Returns:
Fingerprint: The agent's fingerprint
"""
return self.security_config.fingerprint
def set_fingerprint(self, fingerprint: Fingerprint):
def set_fingerprint(self, fingerprint: Fingerprint) -> None:
self.security_config.fingerprint = fingerprint
def _get_knowledge_search_query(self, task_prompt: str) -> str | None:
@@ -619,7 +628,7 @@ class Agent(BaseAgent):
),
)
query = self.i18n.slice("knowledge_search_query").format(
task_prompt=task_prompt
task_prompt=task_prompt,
)
rewriter_prompt = self.i18n.slice("knowledge_search_query_system_prompt")
if not isinstance(self.llm, BaseLLM):
@@ -644,7 +653,7 @@ class Agent(BaseAgent):
"content": rewriter_prompt,
},
{"role": "user", "content": query},
]
],
)
crewai_event_bus.emit(
self,
@@ -666,11 +675,10 @@ class Agent(BaseAgent):
def kickoff(
self,
messages: Union[str, List[Dict[str, str]]],
response_format: Optional[Type[Any]] = None,
messages: str | list[dict[str, str]],
response_format: type[Any] | None = None,
) -> LiteAgentOutput:
"""
Execute the agent with the given messages using a LiteAgent instance.
"""Execute the agent with the given messages using a LiteAgent instance.
This method is useful when you want to use the Agent configuration but
with the simpler and more direct execution flow of LiteAgent.
@@ -683,6 +691,7 @@ class Agent(BaseAgent):
Returns:
LiteAgentOutput: The result of the agent execution.
"""
lite_agent = LiteAgent(
role=self.role,
@@ -703,11 +712,10 @@ class Agent(BaseAgent):
async def kickoff_async(
self,
messages: Union[str, List[Dict[str, str]]],
response_format: Optional[Type[Any]] = None,
messages: str | list[dict[str, str]],
response_format: type[Any] | None = None,
) -> LiteAgentOutput:
"""
Execute the agent asynchronously with the given messages using a LiteAgent instance.
"""Execute the agent asynchronously with the given messages using a LiteAgent instance.
This is the async version of the kickoff method.
@@ -719,6 +727,7 @@ class Agent(BaseAgent):
Returns:
LiteAgentOutput: The result of the agent execution.
"""
lite_agent = LiteAgent(
role=self.role,

View File

@@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional
from typing import Any
from pydantic import PrivateAttr
@@ -16,27 +16,27 @@ class BaseAgentAdapter(BaseAgent, ABC):
"""
adapted_structured_output: bool = False
_agent_config: Optional[Dict[str, Any]] = PrivateAttr(default=None)
_agent_config: dict[str, Any] | None = PrivateAttr(default=None)
model_config = {"arbitrary_types_allowed": True}
def __init__(self, agent_config: Optional[Dict[str, Any]] = None, **kwargs: Any):
def __init__(self, agent_config: dict[str, Any] | None = None, **kwargs: Any) -> None:
super().__init__(adapted_agent=True, **kwargs)
self._agent_config = agent_config
@abstractmethod
def configure_tools(self, tools: Optional[List[BaseTool]] = None) -> None:
def configure_tools(self, tools: list[BaseTool] | None = None) -> None:
"""Configure and adapt tools for the specific agent implementation.
Args:
tools: Optional list of BaseTool instances to be configured
"""
pass
def configure_structured_output(self, structured_output: Any) -> None:
"""Configure the structured output for the specific agent implementation.
Args:
structured_output: The structured output to be configured
"""
pass

View File

@@ -8,7 +8,7 @@ class BaseConverterAdapter(ABC):
converter adapters must implement for converting structured output.
"""
def __init__(self, agent_adapter):
def __init__(self, agent_adapter) -> None:
self.agent_adapter = agent_adapter
@abstractmethod
@@ -16,14 +16,11 @@ class BaseConverterAdapter(ABC):
"""Configure agents to return structured output.
Must support json and pydantic output.
"""
pass
@abstractmethod
def enhance_system_prompt(self, base_prompt: str) -> str:
"""Enhance the system prompt with structured output instructions."""
pass
@abstractmethod
def post_process_result(self, result: str) -> str:
"""Post-process the result to ensure it matches the expected format: string."""
pass

View File

@@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
from typing import Any, List, Optional
from typing import Any
from crewai.tools.base_tool import BaseTool
@@ -12,23 +12,23 @@ class BaseToolAdapter(ABC):
different frameworks and platforms.
"""
original_tools: List[BaseTool]
converted_tools: List[Any]
original_tools: list[BaseTool]
converted_tools: list[Any]
def __init__(self, tools: Optional[List[BaseTool]] = None):
def __init__(self, tools: list[BaseTool] | None = None) -> None:
self.original_tools = tools or []
self.converted_tools = []
@abstractmethod
def configure_tools(self, tools: List[BaseTool]) -> None:
def configure_tools(self, tools: list[BaseTool]) -> None:
"""Configure and convert tools for the specific implementation.
Args:
tools: List of BaseTool instances to be configured and converted
"""
pass
def tools(self) -> List[Any]:
"""
def tools(self) -> list[Any]:
"""Return all converted tools."""
return self.converted_tools

View File

@@ -1,4 +1,4 @@
from typing import Any, AsyncIterable, Dict, List, Optional
from typing import Any
from pydantic import Field, PrivateAttr
@@ -52,16 +52,17 @@ class LangGraphAgentAdapter(BaseAgentAdapter):
role: str,
goal: str,
backstory: str,
tools: Optional[List[BaseTool]] = None,
tools: list[BaseTool] | None = None,
llm: Any = None,
max_iterations: int = 10,
agent_config: Optional[Dict[str, Any]] = None,
agent_config: dict[str, Any] | None = None,
**kwargs,
):
) -> None:
"""Initialize the LangGraph agent adapter."""
if not LANGGRAPH_AVAILABLE:
msg = "LangGraph Agent Dependencies are not installed. Please install it using `uv add langchain-core langgraph`"
raise ImportError(
"LangGraph Agent Dependencies are not installed. Please install it using `uv add langchain-core langgraph`"
msg,
)
super().__init__(
role=role,
@@ -82,7 +83,7 @@ class LangGraphAgentAdapter(BaseAgentAdapter):
try:
self._memory = MemorySaver()
converted_tools: List[Any] = self._tool_adapter.tools()
converted_tools: list[Any] = self._tool_adapter.tools()
if self._agent_config:
self._graph = create_react_agent(
model=self.llm,
@@ -101,18 +102,18 @@ class LangGraphAgentAdapter(BaseAgentAdapter):
except ImportError as e:
self._logger.log(
"error", f"Failed to import LangGraph dependencies: {str(e)}"
"error", f"Failed to import LangGraph dependencies: {e!s}",
)
raise
except Exception as e:
self._logger.log("error", f"Error setting up LangGraph agent: {str(e)}")
self._logger.log("error", f"Error setting up LangGraph agent: {e!s}")
raise
def _build_system_prompt(self) -> str:
"""Build a system prompt for the LangGraph agent."""
base_prompt = f"""
You are {self.role}.
Your goal is: {self.goal}
Your backstory: {self.backstory}
@@ -124,8 +125,8 @@ class LangGraphAgentAdapter(BaseAgentAdapter):
def execute_task(
self,
task: Any,
context: Optional[str] = None,
tools: Optional[List[BaseTool]] = None,
context: str | None = None,
tools: list[BaseTool] | None = None,
) -> str:
"""Execute a task using the LangGraph workflow."""
self.create_agent_executor(tools)
@@ -137,7 +138,7 @@ class LangGraphAgentAdapter(BaseAgentAdapter):
if context:
task_prompt = self.i18n.slice("task_with_context").format(
task=task_prompt, context=context
task=task_prompt, context=context,
)
crewai_event_bus.emit(
@@ -159,7 +160,7 @@ class LangGraphAgentAdapter(BaseAgentAdapter):
"messages": [
("system", self._build_system_prompt()),
("user", task_prompt),
]
],
},
config,
)
@@ -180,14 +181,14 @@ class LangGraphAgentAdapter(BaseAgentAdapter):
crewai_event_bus.emit(
self,
event=AgentExecutionCompletedEvent(
agent=self, task=task, output=final_answer
agent=self, task=task, output=final_answer,
),
)
return final_answer
except Exception as e:
self._logger.log("error", f"Error executing LangGraph task: {str(e)}")
self._logger.log("error", f"Error executing LangGraph task: {e!s}")
crewai_event_bus.emit(
self,
event=AgentExecutionErrorEvent(
@@ -198,11 +199,11 @@ class LangGraphAgentAdapter(BaseAgentAdapter):
)
raise
def create_agent_executor(self, tools: Optional[List[BaseTool]] = None) -> None:
def create_agent_executor(self, tools: list[BaseTool] | None = None) -> None:
"""Configure the LangGraph agent for execution."""
self.configure_tools(tools)
def configure_tools(self, tools: Optional[List[BaseTool]] = None) -> None:
def configure_tools(self, tools: list[BaseTool] | None = None) -> None:
"""Configure tools for the LangGraph agent."""
if tools:
all_tools = list(self.tools or []) + list(tools or [])
@@ -210,13 +211,13 @@ class LangGraphAgentAdapter(BaseAgentAdapter):
available_tools = self._tool_adapter.tools()
self._graph.tools = available_tools
def get_delegation_tools(self, agents: List[BaseAgent]) -> List[BaseTool]:
def get_delegation_tools(self, agents: list[BaseAgent]) -> list[BaseTool]:
"""Implement delegation tools support for LangGraph."""
agent_tools = AgentTools(agents=agents)
return agent_tools.tools()
def get_output_converter(
self, llm: Any, text: str, model: Any, instructions: str
self, llm: Any, text: str, model: Any, instructions: str,
) -> Any:
"""Convert output format if needed."""
return Converter(llm=llm, text=text, model=model, instructions=instructions)

View File

@@ -1,29 +1,25 @@
import inspect
from typing import Any, List, Optional
from typing import Any
from crewai.agents.agent_adapters.base_tool_adapter import BaseToolAdapter
from crewai.tools.base_tool import BaseTool
class LangGraphToolAdapter(BaseToolAdapter):
"""Adapts CrewAI tools to LangGraph agent tool compatible format"""
"""Adapts CrewAI tools to LangGraph agent tool compatible format."""
def __init__(self, tools: Optional[List[BaseTool]] = None):
def __init__(self, tools: list[BaseTool] | None = None) -> None:
self.original_tools = tools or []
self.converted_tools = []
def configure_tools(self, tools: List[BaseTool]) -> None:
"""
Configure and convert CrewAI tools to LangGraph-compatible format.
def configure_tools(self, tools: list[BaseTool]) -> None:
"""Configure and convert CrewAI tools to LangGraph-compatible format.
LangGraph expects tools in langchain_core.tools format.
"""
from langchain_core.tools import BaseTool, StructuredTool
converted_tools = []
if self.original_tools:
all_tools = tools + self.original_tools
else:
all_tools = tools
all_tools = tools + self.original_tools if self.original_tools else tools
for tool in all_tools:
if isinstance(tool, BaseTool):
converted_tools.append(tool)
@@ -57,5 +53,5 @@ class LangGraphToolAdapter(BaseToolAdapter):
self.converted_tools = converted_tools
def tools(self) -> List[Any]:
def tools(self) -> list[Any]:
return self.converted_tools or []

View File

@@ -5,10 +5,10 @@ from crewai.utilities.converter import generate_model_description
class LangGraphConverterAdapter(BaseConverterAdapter):
"""Adapter for handling structured output conversion in LangGraph agents"""
"""Adapter for handling structured output conversion in LangGraph agents."""
def __init__(self, agent_adapter):
"""Initialize the converter adapter with a reference to the agent adapter"""
def __init__(self, agent_adapter) -> None:
"""Initialize the converter adapter with a reference to the agent adapter."""
self.agent_adapter = agent_adapter
self._output_format = None
self._schema = None
@@ -32,7 +32,7 @@ class LangGraphConverterAdapter(BaseConverterAdapter):
self._system_prompt_appendix = self._generate_system_prompt_appendix()
def _generate_system_prompt_appendix(self) -> str:
"""Generate an appendix for the system prompt to enforce structured output"""
"""Generate an appendix for the system prompt to enforce structured output."""
if not self._output_format or not self._schema:
return ""
@@ -41,19 +41,19 @@ Important: Your final answer MUST be provided in the following structured format
{self._schema}
DO NOT include any markdown code blocks, backticks, or other formatting around your response.
DO NOT include any markdown code blocks, backticks, or other formatting around your response.
The output should be raw JSON that exactly matches the specified schema.
"""
def enhance_system_prompt(self, original_prompt: str) -> str:
"""Add structured output instructions to the system prompt if needed"""
"""Add structured output instructions to the system prompt if needed."""
if not self._system_prompt_appendix:
return original_prompt
return f"{original_prompt}\n{self._system_prompt_appendix}"
def post_process_result(self, result: str) -> str:
"""Post-process the result to ensure it matches the expected format"""
"""Post-process the result to ensure it matches the expected format."""
if not self._output_format:
return result

View File

@@ -1,4 +1,4 @@
from typing import Any, List, Optional
from typing import Any
from pydantic import Field, PrivateAttr
@@ -29,13 +29,13 @@ except ImportError:
class OpenAIAgentAdapter(BaseAgentAdapter):
"""Adapter for OpenAI Assistants"""
"""Adapter for OpenAI Assistants."""
model_config = {"arbitrary_types_allowed": True}
_openai_agent: "OpenAIAgent" = PrivateAttr()
_logger: Logger = PrivateAttr(default_factory=lambda: Logger())
_active_thread: Optional[str] = PrivateAttr(default=None)
_active_thread: str | None = PrivateAttr(default=None)
function_calling_llm: Any = Field(default=None)
step_callback: Any = Field(default=None)
_tool_adapter: "OpenAIAgentToolAdapter" = PrivateAttr()
@@ -44,35 +44,35 @@ class OpenAIAgentAdapter(BaseAgentAdapter):
def __init__(
self,
model: str = "gpt-4o-mini",
tools: Optional[List[BaseTool]] = None,
agent_config: Optional[dict] = None,
tools: list[BaseTool] | None = None,
agent_config: dict | None = None,
**kwargs,
):
) -> None:
if not OPENAI_AVAILABLE:
msg = "OpenAI Agent Dependencies are not installed. Please install it using `uv add openai-agents`"
raise ImportError(
"OpenAI Agent Dependencies are not installed. Please install it using `uv add openai-agents`"
msg,
)
else:
role = kwargs.pop("role", None)
goal = kwargs.pop("goal", None)
backstory = kwargs.pop("backstory", None)
super().__init__(
role=role,
goal=goal,
backstory=backstory,
tools=tools,
agent_config=agent_config,
**kwargs,
)
self._tool_adapter = OpenAIAgentToolAdapter(tools=tools)
self.llm = model
self._converter_adapter = OpenAIConverterAdapter(self)
role = kwargs.pop("role", None)
goal = kwargs.pop("goal", None)
backstory = kwargs.pop("backstory", None)
super().__init__(
role=role,
goal=goal,
backstory=backstory,
tools=tools,
agent_config=agent_config,
**kwargs,
)
self._tool_adapter = OpenAIAgentToolAdapter(tools=tools)
self.llm = model
self._converter_adapter = OpenAIConverterAdapter(self)
def _build_system_prompt(self) -> str:
"""Build a system prompt for the OpenAI agent."""
base_prompt = f"""
You are {self.role}.
Your goal is: {self.goal}
Your backstory: {self.backstory}
@@ -84,10 +84,10 @@ class OpenAIAgentAdapter(BaseAgentAdapter):
def execute_task(
self,
task: Any,
context: Optional[str] = None,
tools: Optional[List[BaseTool]] = None,
context: str | None = None,
tools: list[BaseTool] | None = None,
) -> str:
"""Execute a task using the OpenAI Assistant"""
"""Execute a task using the OpenAI Assistant."""
self._converter_adapter.configure_structured_output(task)
self.create_agent_executor(tools)
@@ -98,7 +98,7 @@ class OpenAIAgentAdapter(BaseAgentAdapter):
task_prompt = task.prompt()
if context:
task_prompt = self.i18n.slice("task_with_context").format(
task=task_prompt, context=context
task=task_prompt, context=context,
)
crewai_event_bus.emit(
self,
@@ -114,13 +114,13 @@ class OpenAIAgentAdapter(BaseAgentAdapter):
crewai_event_bus.emit(
self,
event=AgentExecutionCompletedEvent(
agent=self, task=task, output=final_answer
agent=self, task=task, output=final_answer,
),
)
return final_answer
except Exception as e:
self._logger.log("error", f"Error executing OpenAI task: {str(e)}")
self._logger.log("error", f"Error executing OpenAI task: {e!s}")
crewai_event_bus.emit(
self,
event=AgentExecutionErrorEvent(
@@ -131,9 +131,8 @@ class OpenAIAgentAdapter(BaseAgentAdapter):
)
raise
def create_agent_executor(self, tools: Optional[List[BaseTool]] = None) -> None:
"""
Configure the OpenAI agent for execution.
def create_agent_executor(self, tools: list[BaseTool] | None = None) -> None:
"""Configure the OpenAI agent for execution.
While OpenAI handles execution differently through Runner,
we can use this method to set up tools and configurations.
"""
@@ -152,27 +151,27 @@ class OpenAIAgentAdapter(BaseAgentAdapter):
self.agent_executor = Runner
def configure_tools(self, tools: Optional[List[BaseTool]] = None) -> None:
"""Configure tools for the OpenAI Assistant"""
def configure_tools(self, tools: list[BaseTool] | None = None) -> None:
"""Configure tools for the OpenAI Assistant."""
if tools:
self._tool_adapter.configure_tools(tools)
if self._tool_adapter.converted_tools:
self._openai_agent.tools = self._tool_adapter.converted_tools
def handle_execution_result(self, result: Any) -> str:
"""Process OpenAI Assistant execution result converting any structured output to a string"""
"""Process OpenAI Assistant execution result converting any structured output to a string."""
return self._converter_adapter.post_process_result(result.final_output)
def get_delegation_tools(self, agents: List[BaseAgent]) -> List[BaseTool]:
"""Implement delegation tools support"""
def get_delegation_tools(self, agents: list[BaseAgent]) -> list[BaseTool]:
"""Implement delegation tools support."""
agent_tools = AgentTools(agents=agents)
tools = agent_tools.tools()
return tools
return agent_tools.tools()
def configure_structured_output(self, task) -> None:
"""Configure the structured output for the specific agent implementation.
Args:
structured_output: The structured output to be configured
"""
self._converter_adapter.configure_structured_output(task)

View File

@@ -1,5 +1,5 @@
import inspect
from typing import Any, List, Optional
from typing import Any
from agents import FunctionTool, Tool
@@ -8,42 +8,36 @@ from crewai.tools import BaseTool
class OpenAIAgentToolAdapter(BaseToolAdapter):
"""Adapter for OpenAI Assistant tools"""
"""Adapter for OpenAI Assistant tools."""
def __init__(self, tools: Optional[List[BaseTool]] = None):
def __init__(self, tools: list[BaseTool] | None = None) -> None:
self.original_tools = tools or []
def configure_tools(self, tools: List[BaseTool]) -> None:
"""Configure tools for the OpenAI Assistant"""
if self.original_tools:
all_tools = tools + self.original_tools
else:
all_tools = tools
def configure_tools(self, tools: list[BaseTool]) -> None:
"""Configure tools for the OpenAI Assistant."""
all_tools = tools + self.original_tools if self.original_tools else tools
if all_tools:
self.converted_tools = self._convert_tools_to_openai_format(all_tools)
def _convert_tools_to_openai_format(
self, tools: Optional[List[BaseTool]]
) -> List[Tool]:
"""Convert CrewAI tools to OpenAI Assistant tool format"""
self, tools: list[BaseTool] | None,
) -> list[Tool]:
"""Convert CrewAI tools to OpenAI Assistant tool format."""
if not tools:
return []
def sanitize_tool_name(name: str) -> str:
"""Convert tool name to match OpenAI's required pattern"""
"""Convert tool name to match OpenAI's required pattern."""
import re
sanitized = re.sub(r"[^a-zA-Z0-9_-]", "_", name).lower()
return sanitized
return re.sub(r"[^a-zA-Z0-9_-]", "_", name).lower()
def create_tool_wrapper(tool: BaseTool):
"""Create a wrapper function that handles the OpenAI function tool interface"""
"""Create a wrapper function that handles the OpenAI function tool interface."""
async def wrapper(context_wrapper: Any, arguments: Any) -> Any:
# Get the parameter name from the schema
param_name = list(
tool.args_schema.model_json_schema()["properties"].keys()
)[0]
param_name = next(iter(tool.args_schema.model_json_schema()["properties"].keys()))
# Handle different argument types
if isinstance(arguments, dict):

View File

@@ -7,8 +7,7 @@ from crewai.utilities.i18n import I18N
class OpenAIConverterAdapter(BaseConverterAdapter):
"""
Adapter for handling structured output conversion in OpenAI agents.
"""Adapter for handling structured output conversion in OpenAI agents.
This adapter enhances the OpenAI agent to handle structured output formats
and post-processes the results when needed.
@@ -17,21 +16,22 @@ class OpenAIConverterAdapter(BaseConverterAdapter):
_output_format: The expected output format (json, pydantic, or None)
_schema: The schema description for the expected output
_output_model: The Pydantic model for the output
"""
def __init__(self, agent_adapter):
"""Initialize the converter adapter with a reference to the agent adapter"""
def __init__(self, agent_adapter) -> None:
"""Initialize the converter adapter with a reference to the agent adapter."""
self.agent_adapter = agent_adapter
self._output_format = None
self._schema = None
self._output_model = None
def configure_structured_output(self, task) -> None:
"""
Configure the structured output for OpenAI agent based on task requirements.
"""Configure the structured output for OpenAI agent based on task requirements.
Args:
task: The task containing output format requirements
"""
# Reset configuration
self._output_format = None
@@ -55,14 +55,14 @@ class OpenAIConverterAdapter(BaseConverterAdapter):
self._output_model = task.output_pydantic
def enhance_system_prompt(self, base_prompt: str) -> str:
"""
Enhance the base system prompt with structured output requirements if needed.
"""Enhance the base system prompt with structured output requirements if needed.
Args:
base_prompt: The original system prompt
Returns:
Enhanced system prompt with output format instructions if needed
"""
if not self._output_format:
return base_prompt
@@ -76,8 +76,7 @@ class OpenAIConverterAdapter(BaseConverterAdapter):
return f"{base_prompt}\n\n{output_schema}"
def post_process_result(self, result: str) -> str:
"""
Post-process the result to ensure it matches the expected format.
"""Post-process the result to ensure it matches the expected format.
This method attempts to extract valid JSON from the result if necessary.
@@ -86,6 +85,7 @@ class OpenAIConverterAdapter(BaseConverterAdapter):
Returns:
Processed result conforming to the expected output format
"""
if not self._output_format:
return result

View File

@@ -1,8 +1,9 @@
import uuid
from abc import ABC, abstractmethod
from collections.abc import Callable
from copy import copy as shallow_copy
from hashlib import md5
from typing import Any, Callable, Dict, List, Optional, TypeVar
from typing import Any, TypeVar
from pydantic import (
UUID4,
@@ -14,6 +15,7 @@ from pydantic import (
model_validator,
)
from pydantic_core import PydanticCustomError
from typing_extensions import Self
from crewai.agents.agent_builder.utilities.base_token_process import TokenProcess
from crewai.agents.cache.cache_handler import CacheHandler
@@ -25,7 +27,6 @@ from crewai.security.security_config import SecurityConfig
from crewai.tools.base_tool import BaseTool, Tool
from crewai.utilities import I18N, Logger, RPMController
from crewai.utilities.config import process_config
from crewai.utilities.converter import Converter
from crewai.utilities.string_utils import interpolate_only
T = TypeVar("T", bound="BaseAgent")
@@ -77,30 +78,31 @@ class BaseAgent(ABC, BaseModel):
Set the rpm controller for the agent.
set_private_attrs() -> "BaseAgent":
Set private attributes.
"""
__hash__ = object.__hash__ # type: ignore
_logger: Logger = PrivateAttr(default_factory=lambda: Logger(verbose=False))
_rpm_controller: Optional[RPMController] = PrivateAttr(default=None)
_rpm_controller: RPMController | None = PrivateAttr(default=None)
_request_within_rpm_limit: Any = PrivateAttr(default=None)
_original_role: Optional[str] = PrivateAttr(default=None)
_original_goal: Optional[str] = PrivateAttr(default=None)
_original_backstory: Optional[str] = PrivateAttr(default=None)
_original_role: str | None = PrivateAttr(default=None)
_original_goal: str | None = PrivateAttr(default=None)
_original_backstory: str | None = PrivateAttr(default=None)
_token_process: TokenProcess = PrivateAttr(default_factory=TokenProcess)
id: UUID4 = Field(default_factory=uuid.uuid4, frozen=True)
role: str = Field(description="Role of the agent")
goal: str = Field(description="Objective of the agent")
backstory: str = Field(description="Backstory of the agent")
config: Optional[Dict[str, Any]] = Field(
description="Configuration for the agent", default=None, exclude=True
config: dict[str, Any] | None = Field(
description="Configuration for the agent", default=None, exclude=True,
)
cache: bool = Field(
default=True, description="Whether the agent should use a cache for tool usage."
default=True, description="Whether the agent should use a cache for tool usage.",
)
verbose: bool = Field(
default=False, description="Verbose mode for the Agent Execution"
default=False, description="Verbose mode for the Agent Execution",
)
max_rpm: Optional[int] = Field(
max_rpm: int | None = Field(
default=None,
description="Maximum number of requests per minute for the agent execution to be respected.",
)
@@ -108,41 +110,41 @@ class BaseAgent(ABC, BaseModel):
default=False,
description="Enable agent to delegate and ask questions among each other.",
)
tools: Optional[List[BaseTool]] = Field(
default_factory=list, description="Tools at agents' disposal"
tools: list[BaseTool] | None = Field(
default_factory=list, description="Tools at agents' disposal",
)
max_iter: int = Field(
default=25, description="Maximum iterations for an agent to execute a task"
default=25, description="Maximum iterations for an agent to execute a task",
)
agent_executor: InstanceOf = Field(
default=None, description="An instance of the CrewAgentExecutor class."
default=None, description="An instance of the CrewAgentExecutor class.",
)
llm: Any = Field(
default=None, description="Language model that will run the agent."
default=None, description="Language model that will run the agent.",
)
crew: Any = Field(default=None, description="Crew to which the agent belongs.")
i18n: I18N = Field(default=I18N(), description="Internationalization settings.")
cache_handler: Optional[InstanceOf[CacheHandler]] = Field(
default=None, description="An instance of the CacheHandler class."
cache_handler: InstanceOf[CacheHandler] | None = Field(
default=None, description="An instance of the CacheHandler class.",
)
tools_handler: InstanceOf[ToolsHandler] = Field(
default_factory=ToolsHandler,
description="An instance of the ToolsHandler class.",
)
tools_results: List[Dict[str, Any]] = Field(
default=[], description="Results of the tools used by the agent."
tools_results: list[dict[str, Any]] = Field(
default=[], description="Results of the tools used by the agent.",
)
max_tokens: Optional[int] = Field(
default=None, description="Maximum number of tokens for the agent's execution."
max_tokens: int | None = Field(
default=None, description="Maximum number of tokens for the agent's execution.",
)
knowledge: Optional[Knowledge] = Field(
default=None, description="Knowledge for the agent."
knowledge: Knowledge | None = Field(
default=None, description="Knowledge for the agent.",
)
knowledge_sources: Optional[List[BaseKnowledgeSource]] = Field(
knowledge_sources: list[BaseKnowledgeSource] | None = Field(
default=None,
description="Knowledge sources for the agent.",
)
knowledge_storage: Optional[Any] = Field(
knowledge_storage: Any | None = Field(
default=None,
description="Custom knowledge storage for the agent.",
)
@@ -150,13 +152,13 @@ class BaseAgent(ABC, BaseModel):
default_factory=SecurityConfig,
description="Security configuration for the agent, including fingerprinting.",
)
callbacks: List[Callable] = Field(
default=[], description="Callbacks to be used for the agent"
callbacks: list[Callable] = Field(
default=[], description="Callbacks to be used for the agent",
)
adapted_agent: bool = Field(
default=False, description="Whether the agent is adapted"
default=False, description="Whether the agent is adapted",
)
knowledge_config: Optional[KnowledgeConfig] = Field(
knowledge_config: KnowledgeConfig | None = Field(
default=None,
description="Knowledge configuration for the agent such as limits and threshold",
)
@@ -168,7 +170,7 @@ class BaseAgent(ABC, BaseModel):
@field_validator("tools")
@classmethod
def validate_tools(cls, tools: List[Any]) -> List[BaseTool]:
def validate_tools(cls, tools: list[Any]) -> list[BaseTool]:
"""Validate and process the tools provided to the agent.
This method ensures that each tool is either an instance of BaseTool
@@ -188,11 +190,14 @@ class BaseAgent(ABC, BaseModel):
# Tool has the required attributes, create a Tool instance
processed_tools.append(Tool.from_langchain(tool))
else:
raise ValueError(
msg = (
f"Invalid tool type: {type(tool)}. "
"Tool must be an instance of BaseTool or "
"an object with 'name', 'func', and 'description' attributes."
)
raise ValueError(
msg,
)
return processed_tools
@model_validator(mode="after")
@@ -200,15 +205,16 @@ class BaseAgent(ABC, BaseModel):
# Validate required fields
for field in ["role", "goal", "backstory"]:
if getattr(self, field) is None:
msg = f"{field} must be provided either directly or through config"
raise ValueError(
f"{field} must be provided either directly or through config"
msg,
)
# Set private attributes
self._logger = Logger(verbose=self.verbose)
if self.max_rpm and not self._rpm_controller:
self._rpm_controller = RPMController(
max_rpm=self.max_rpm, logger=self._logger
max_rpm=self.max_rpm, logger=self._logger,
)
if not self._token_process:
self._token_process = TokenProcess()
@@ -221,10 +227,11 @@ class BaseAgent(ABC, BaseModel):
@field_validator("id", mode="before")
@classmethod
def _deny_user_set_id(cls, v: Optional[UUID4]) -> None:
def _deny_user_set_id(cls, v: UUID4 | None) -> None:
if v:
msg = "may_not_set_field"
raise PydanticCustomError(
"may_not_set_field", "This field is not to be set by the user.", {}
msg, "This field is not to be set by the user.", {},
)
@model_validator(mode="after")
@@ -233,7 +240,7 @@ class BaseAgent(ABC, BaseModel):
self._logger = Logger(verbose=self.verbose)
if self.max_rpm and not self._rpm_controller:
self._rpm_controller = RPMController(
max_rpm=self.max_rpm, logger=self._logger
max_rpm=self.max_rpm, logger=self._logger,
)
if not self._token_process:
self._token_process = TokenProcess()
@@ -252,8 +259,8 @@ class BaseAgent(ABC, BaseModel):
def execute_task(
self,
task: Any,
context: Optional[str] = None,
tools: Optional[List[BaseTool]] = None,
context: str | None = None,
tools: list[BaseTool] | None = None,
) -> str:
pass
@@ -262,11 +269,10 @@ class BaseAgent(ABC, BaseModel):
pass
@abstractmethod
def get_delegation_tools(self, agents: List["BaseAgent"]) -> List[BaseTool]:
def get_delegation_tools(self, agents: list["BaseAgent"]) -> list[BaseTool]:
"""Set the task tools that init BaseAgenTools class."""
pass
def copy(self: T) -> T: # type: ignore # Signature of "copy" incompatible with supertype "BaseModel"
def copy(self) -> Self: # type: ignore # Signature of "copy" incompatible with supertype "BaseModel"
"""Create a deep copy of the Agent."""
exclude = {
"id",
@@ -309,7 +315,7 @@ class BaseAgent(ABC, BaseModel):
copied_data = self.model_dump(exclude=exclude)
copied_data = {k: v for k, v in copied_data.items() if v is not None}
copied_agent = type(self)(
return type(self)(
**copied_data,
llm=existing_llm,
tools=self.tools,
@@ -318,9 +324,8 @@ class BaseAgent(ABC, BaseModel):
knowledge_storage=copied_knowledge_storage,
)
return copied_agent
def interpolate_inputs(self, inputs: Dict[str, Any]) -> None:
def interpolate_inputs(self, inputs: dict[str, Any]) -> None:
"""Interpolate inputs into the agent description and backstory."""
if self._original_role is None:
self._original_role = self.role
@@ -331,13 +336,13 @@ class BaseAgent(ABC, BaseModel):
if inputs:
self.role = interpolate_only(
input_string=self._original_role, inputs=inputs
input_string=self._original_role, inputs=inputs,
)
self.goal = interpolate_only(
input_string=self._original_goal, inputs=inputs
input_string=self._original_goal, inputs=inputs,
)
self.backstory = interpolate_only(
input_string=self._original_backstory, inputs=inputs
input_string=self._original_backstory, inputs=inputs,
)
def set_cache_handler(self, cache_handler: CacheHandler) -> None:
@@ -345,6 +350,7 @@ class BaseAgent(ABC, BaseModel):
Args:
cache_handler: An instance of the CacheHandler class.
"""
self.tools_handler = ToolsHandler()
if self.cache:
@@ -357,10 +363,11 @@ class BaseAgent(ABC, BaseModel):
Args:
rpm_controller: An instance of the RPMController class.
"""
if not self._rpm_controller:
self._rpm_controller = rpm_controller
self.create_agent_executor()
def set_knowledge(self, crew_embedder: Optional[Dict[str, Any]] = None):
def set_knowledge(self, crew_embedder: dict[str, Any] | None = None) -> None:
pass

View File

@@ -1,3 +1,4 @@
import contextlib
import time
from typing import TYPE_CHECKING
@@ -43,8 +44,7 @@ class CrewAgentExecutorMixin:
},
agent=self.agent.role,
)
except Exception as e:
print(f"Failed to add to short term memory: {e}")
except Exception:
pass
def _create_external_memory(self, output) -> None:
@@ -56,7 +56,7 @@ class CrewAgentExecutorMixin:
and hasattr(self.crew, "_external_memory")
and self.crew._external_memory
):
try:
with contextlib.suppress(Exception):
self.crew._external_memory.save(
value=output.text,
metadata={
@@ -64,9 +64,6 @@ class CrewAgentExecutorMixin:
},
agent=self.agent.role,
)
except Exception as e:
print(f"Failed to add to external memory: {e}")
pass
def _create_long_term_memory(self, output) -> None:
"""Create and save long-term and entity memory items based on evaluation."""
@@ -103,15 +100,13 @@ class CrewAgentExecutorMixin:
type=entity.type,
description=entity.description,
relationships="\n".join(
[f"- {r}" for r in entity.relationships]
[f"- {r}" for r in entity.relationships],
),
)
self.crew._entity_memory.save(entity_memory)
except AttributeError as e:
print(f"Missing attributes for long term memory: {e}")
except AttributeError:
pass
except Exception as e:
print(f"Failed to add to long term memory: {e}")
except Exception:
pass
elif (
self.crew
@@ -126,7 +121,7 @@ class CrewAgentExecutorMixin:
def _ask_human_input(self, final_answer: str) -> str:
"""Prompt human input with mode-appropriate messaging."""
self._printer.print(
content=f"\033[1m\033[95m ## Final Result:\033[00m \033[92m{final_answer}\033[00m"
content=f"\033[1m\033[95m ## Final Result:\033[00m \033[92m{final_answer}\033[00m",
)
# Training mode prompt (single iteration)

View File

@@ -1,12 +1,11 @@
from abc import ABC, abstractmethod
from typing import Any, Optional
from typing import Any
from pydantic import BaseModel, Field
class OutputConverter(BaseModel, ABC):
"""
Abstract base class for converting task results into structured formats.
"""Abstract base class for converting task results into structured formats.
This class provides a framework for converting unstructured text into
either Pydantic models or JSON, tailored for specific agent requirements.
@@ -19,6 +18,7 @@ class OutputConverter(BaseModel, ABC):
model (Any): The target model for structuring the output.
instructions (str): Specific instructions for the conversion process.
max_attempts (int): Maximum number of conversion attempts (default: 3).
"""
text: str = Field(description="Text to be converted.")
@@ -33,9 +33,7 @@ class OutputConverter(BaseModel, ABC):
@abstractmethod
def to_pydantic(self, current_attempt=1) -> BaseModel:
"""Convert text to pydantic."""
pass
@abstractmethod
def to_json(self, current_attempt=1) -> dict:
"""Convert text to json."""
pass

View File

@@ -1,4 +1,4 @@
from typing import Any, Dict, Optional
from typing import Any
from pydantic import BaseModel, PrivateAttr
@@ -6,10 +6,10 @@ from pydantic import BaseModel, PrivateAttr
class CacheHandler(BaseModel):
"""Callback handler for tool usage."""
_cache: Dict[str, Any] = PrivateAttr(default_factory=dict)
_cache: dict[str, Any] = PrivateAttr(default_factory=dict)
def add(self, tool, input, output):
def add(self, tool, input, output) -> None:
self._cache[f"{tool}-{input}"] = output
def read(self, tool, input) -> Optional[str]:
def read(self, tool, input) -> str | None:
return self._cache.get(f"{tool}-{input}")

View File

@@ -1,6 +1,5 @@
import json
import re
from typing import Any, Callable, Dict, List, Optional, Union
from collections.abc import Callable
from typing import TYPE_CHECKING, Any
from crewai.agents.agent_builder.base_agent import BaseAgent
from crewai.agents.agent_builder.base_agent_executor_mixin import CrewAgentExecutorMixin
@@ -10,8 +9,6 @@ from crewai.agents.parser import (
OutputParserException,
)
from crewai.agents.tools_handler import ToolsHandler
from crewai.llm import BaseLLM
from crewai.tools.base_tool import BaseTool
from crewai.tools.structured_tool import CrewStructuredTool
from crewai.tools.tool_types import ToolResult
from crewai.utilities import I18N, Printer
@@ -34,6 +31,10 @@ from crewai.utilities.logger import Logger
from crewai.utilities.tool_utils import execute_tool_and_check_finality
from crewai.utilities.training_handler import CrewTrainingHandler
if TYPE_CHECKING:
from crewai.llm import BaseLLM
from crewai.tools.base_tool import BaseTool
class CrewAgentExecutor(CrewAgentExecutorMixin):
_logger: Logger = Logger()
@@ -46,18 +47,22 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
agent: BaseAgent,
prompt: dict[str, str],
max_iter: int,
tools: List[CrewStructuredTool],
tools: list[CrewStructuredTool],
tools_names: str,
stop_words: List[str],
stop_words: list[str],
tools_description: str,
tools_handler: ToolsHandler,
step_callback: Any = None,
original_tools: List[Any] = [],
original_tools: list[Any] | None = None,
function_calling_llm: Any = None,
respect_context_window: bool = False,
request_within_rpm_limit: Optional[Callable[[], bool]] = None,
callbacks: List[Any] = [],
):
request_within_rpm_limit: Callable[[], bool] | None = None,
callbacks: list[Any] | None = None,
) -> None:
if callbacks is None:
callbacks = []
if original_tools is None:
original_tools = []
self._i18n: I18N = I18N()
self.llm: BaseLLM = llm
self.task = task
@@ -79,10 +84,10 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
self.respect_context_window = respect_context_window
self.request_within_rpm_limit = request_within_rpm_limit
self.ask_for_human_input = False
self.messages: List[Dict[str, str]] = []
self.messages: list[dict[str, str]] = []
self.iterations = 0
self.log_error_after = 3
self.tool_name_to_tool_map: Dict[str, Union[CrewStructuredTool, BaseTool]] = {
self.tool_name_to_tool_map: dict[str, CrewStructuredTool | BaseTool] = {
tool.name: tool for tool in self.tools
}
existing_stop = self.llm.stop or []
@@ -90,11 +95,11 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
set(
existing_stop + self.stop
if isinstance(existing_stop, list)
else self.stop
)
else self.stop,
),
)
def invoke(self, inputs: Dict[str, str]) -> Dict[str, Any]:
def invoke(self, inputs: dict[str, str]) -> dict[str, Any]:
if "system" in self.prompt:
system_prompt = self._format_prompt(self.prompt.get("system", ""), inputs)
user_prompt = self._format_prompt(self.prompt.get("user", ""), inputs)
@@ -120,9 +125,8 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
handle_unknown_error(self._printer, e)
if e.__class__.__module__.startswith("litellm"):
# Do not retry on litellm errors
raise e
else:
raise e
raise
raise
if self.ask_for_human_input:
formatted_answer = self._handle_human_feedback(formatted_answer)
@@ -133,8 +137,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
return {"output": formatted_answer.output}
def _invoke_loop(self) -> AgentFinish:
"""
Main loop to invoke the agent's thought process until it reaches a conclusion
"""Main loop to invoke the agent's thought process until it reaches a conclusion
or the maximum number of iterations is reached.
"""
formatted_answer = None
@@ -170,8 +173,8 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
):
fingerprint_context = {
"agent_fingerprint": str(
self.agent.security_config.fingerprint
)
self.agent.security_config.fingerprint,
),
}
tool_result = execute_tool_and_check_finality(
@@ -187,7 +190,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
function_calling_llm=self.function_calling_llm,
)
formatted_answer = self._handle_agent_action(
formatted_answer, tool_result
formatted_answer, tool_result,
)
self._invoke_step_callback(formatted_answer)
@@ -205,7 +208,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
except Exception as e:
if e.__class__.__module__.startswith("litellm"):
# Do not retry on litellm errors
raise e
raise
if is_context_length_exceeded(e):
handle_context_length(
respect_context_window=self.respect_context_window,
@@ -216,9 +219,8 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
i18n=self._i18n,
)
continue
else:
handle_unknown_error(self._printer, e)
raise e
handle_unknown_error(self._printer, e)
raise
finally:
self.iterations += 1
@@ -231,8 +233,8 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
return formatted_answer
def _handle_agent_action(
self, formatted_answer: AgentAction, tool_result: ToolResult
) -> Union[AgentAction, AgentFinish]:
self, formatted_answer: AgentAction, tool_result: ToolResult,
) -> AgentAction | AgentFinish:
"""Handle the AgentAction, execute tools, and process the results."""
# Special case for add_image_tool
add_image_tool = self._i18n.tools("add_image")
@@ -261,24 +263,26 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
"""Append a message to the message list with the given role."""
self.messages.append(format_message_for_llm(text, role=role))
def _show_start_logs(self):
def _show_start_logs(self) -> None:
"""Show logs for the start of agent execution."""
if self.agent is None:
raise ValueError("Agent cannot be None")
msg = "Agent cannot be None"
raise ValueError(msg)
show_agent_logs(
printer=self._printer,
agent_role=self.agent.role,
task_description=(
getattr(self.task, "description") if self.task else "Not Found"
self.task.description if self.task else "Not Found"
),
verbose=self.agent.verbose
or (hasattr(self, "crew") and getattr(self.crew, "verbose", False)),
)
def _show_logs(self, formatted_answer: Union[AgentAction, AgentFinish]):
def _show_logs(self, formatted_answer: AgentAction | AgentFinish) -> None:
"""Show logs for the agent's execution."""
if self.agent is None:
raise ValueError("Agent cannot be None")
msg = "Agent cannot be None"
raise ValueError(msg)
show_agent_logs(
printer=self._printer,
agent_role=self.agent.role,
@@ -300,11 +304,11 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
summary = self.llm.call(
[
format_message_for_llm(
self._i18n.slice("summarizer_system_message"), role="system"
self._i18n.slice("summarizer_system_message"), role="system",
),
format_message_for_llm(
self._i18n.slice("summarize_instruction").format(
group=group["content"]
group=group["content"],
),
),
],
@@ -316,12 +320,12 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
self.messages = [
format_message_for_llm(
self._i18n.slice("summary").format(merged_summary=merged_summary)
)
self._i18n.slice("summary").format(merged_summary=merged_summary),
),
]
def _handle_crew_training_output(
self, result: AgentFinish, human_feedback: Optional[str] = None
self, result: AgentFinish, human_feedback: str | None = None,
) -> None:
"""Handle the process of saving training data."""
agent_id = str(self.agent.id) # type: ignore
@@ -348,29 +352,27 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
"initial_output": result.output,
"human_feedback": human_feedback,
}
# Save improved output
elif train_iteration in agent_training_data:
agent_training_data[train_iteration]["improved_output"] = result.output
else:
# Save improved output
if train_iteration in agent_training_data:
agent_training_data[train_iteration]["improved_output"] = result.output
else:
self._printer.print(
content=(
f"No existing training data for agent {agent_id} and iteration "
f"{train_iteration}. Cannot save improved output."
),
color="red",
)
return
self._printer.print(
content=(
f"No existing training data for agent {agent_id} and iteration "
f"{train_iteration}. Cannot save improved output."
),
color="red",
)
return
# Update the training data and save
training_data[agent_id] = agent_training_data
training_handler.save(training_data)
def _format_prompt(self, prompt: str, inputs: Dict[str, str]) -> str:
def _format_prompt(self, prompt: str, inputs: dict[str, str]) -> str:
prompt = prompt.replace("{input}", inputs["input"])
prompt = prompt.replace("{tool_names}", inputs["tool_names"])
prompt = prompt.replace("{tools}", inputs["tools"])
return prompt
return prompt.replace("{tools}", inputs["tools"])
def _handle_human_feedback(self, formatted_answer: AgentFinish) -> AgentFinish:
"""Handle human feedback with different flows for training vs regular use.
@@ -380,6 +382,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
Returns:
AgentFinish: The final answer after processing feedback
"""
human_feedback = self._ask_human_input(formatted_answer.output)
@@ -393,14 +396,14 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
return bool(self.crew and self.crew._train)
def _handle_training_feedback(
self, initial_answer: AgentFinish, feedback: str
self, initial_answer: AgentFinish, feedback: str,
) -> AgentFinish:
"""Process feedback for training scenarios with single iteration."""
self._handle_crew_training_output(initial_answer, feedback)
self.messages.append(
format_message_for_llm(
self._i18n.slice("feedback_instructions").format(feedback=feedback)
)
self._i18n.slice("feedback_instructions").format(feedback=feedback),
),
)
improved_answer = self._invoke_loop()
self._handle_crew_training_output(improved_answer)
@@ -408,7 +411,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
return improved_answer
def _handle_regular_feedback(
self, current_answer: AgentFinish, initial_feedback: str
self, current_answer: AgentFinish, initial_feedback: str,
) -> AgentFinish:
"""Process feedback for regular use with potential multiple iterations."""
feedback = initial_feedback
@@ -428,8 +431,8 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
"""Process a single feedback iteration."""
self.messages.append(
format_message_for_llm(
self._i18n.slice("feedback_instructions").format(feedback=feedback)
)
self._i18n.slice("feedback_instructions").format(feedback=feedback),
),
)
return self._invoke_loop()

View File

@@ -1,5 +1,5 @@
import re
from typing import Any, Optional, Union
from typing import Any
from json_repair import repair_json
@@ -18,7 +18,7 @@ class AgentAction:
text: str
result: str
def __init__(self, thought: str, tool: str, tool_input: str, text: str):
def __init__(self, thought: str, tool: str, tool_input: str, text: str) -> None:
self.thought = thought
self.tool = tool
self.tool_input = tool_input
@@ -30,7 +30,7 @@ class AgentFinish:
output: str
text: str
def __init__(self, thought: str, output: str, text: str):
def __init__(self, thought: str, output: str, text: str) -> None:
self.thought = thought
self.output = output
self.text = text
@@ -39,7 +39,7 @@ class AgentFinish:
class OutputParserException(Exception):
error: str
def __init__(self, error: str):
def __init__(self, error: str) -> None:
self.error = error
@@ -67,24 +67,24 @@ class CrewAgentParser:
_i18n: I18N = I18N()
agent: Any = None
def __init__(self, agent: Optional[Any] = None):
def __init__(self, agent: Any | None = None) -> None:
self.agent = agent
@staticmethod
def parse_text(text: str) -> Union[AgentAction, AgentFinish]:
"""
Static method to parse text into an AgentAction or AgentFinish without needing to instantiate the class.
def parse_text(text: str) -> AgentAction | AgentFinish:
"""Static method to parse text into an AgentAction or AgentFinish without needing to instantiate the class.
Args:
text: The text to parse.
Returns:
Either an AgentAction or AgentFinish based on the parsed content.
"""
parser = CrewAgentParser()
return parser.parse(text)
def parse(self, text: str) -> Union[AgentAction, AgentFinish]:
def parse(self, text: str) -> AgentAction | AgentFinish:
thought = self._extract_thought(text)
includes_answer = FINAL_ANSWER_ACTION in text
regex = (
@@ -102,7 +102,7 @@ class CrewAgentParser:
final_answer = final_answer[:-3].rstrip()
return AgentFinish(thought, final_answer, text)
elif action_match:
if action_match:
action = action_match.group(1)
clean_action = self._clean_action(action)
@@ -114,21 +114,21 @@ class CrewAgentParser:
return AgentAction(thought, clean_action, safe_tool_input, text)
if not re.search(r"Action\s*\d*\s*:[\s]*(.*?)", text, re.DOTALL):
msg = f"{MISSING_ACTION_AFTER_THOUGHT_ERROR_MESSAGE}\n{self._i18n.slice('final_answer_format')}"
raise OutputParserException(
f"{MISSING_ACTION_AFTER_THOUGHT_ERROR_MESSAGE}\n{self._i18n.slice('final_answer_format')}",
msg,
)
elif not re.search(
r"[\s]*Action\s*\d*\s*Input\s*\d*\s*:[\s]*(.*)", text, re.DOTALL
if not re.search(
r"[\s]*Action\s*\d*\s*Input\s*\d*\s*:[\s]*(.*)", text, re.DOTALL,
):
raise OutputParserException(
MISSING_ACTION_INPUT_AFTER_ACTION_ERROR_MESSAGE,
)
else:
format = self._i18n.slice("format_without_tools")
error = f"{format}"
raise OutputParserException(
error,
)
format = self._i18n.slice("format_without_tools")
error = f"{format}"
raise OutputParserException(
error,
)
def _extract_thought(self, text: str) -> str:
thought_index = text.find("\nAction")
@@ -138,8 +138,7 @@ class CrewAgentParser:
return ""
thought = text[:thought_index].strip()
# Remove any triple backticks from the thought string
thought = thought.replace("```", "").strip()
return thought
return thought.replace("```", "").strip()
def _clean_action(self, text: str) -> str:
"""Clean action string by removing non-essential formatting characters."""

View File

@@ -1,7 +1,8 @@
from typing import Any, Optional, Union
from typing import Any
from crewai.tools.cache_tools.cache_tools import CacheTools
from crewai.tools.tool_calling import InstructorToolCalling, ToolCalling
from ..tools.cache_tools.cache_tools import CacheTools
from ..tools.tool_calling import InstructorToolCalling, ToolCalling
from .cache.cache_handler import CacheHandler
@@ -9,16 +10,16 @@ class ToolsHandler:
"""Callback handler for tool usage."""
last_used_tool: ToolCalling = {} # type: ignore # BUG?: Incompatible types in assignment (expression has type "Dict[...]", variable has type "ToolCalling")
cache: Optional[CacheHandler]
cache: CacheHandler | None
def __init__(self, cache: Optional[CacheHandler] = None):
def __init__(self, cache: CacheHandler | None = None) -> None:
"""Initialize the callback handler."""
self.cache = cache
self.last_used_tool = {} # type: ignore # BUG?: same as above
def on_tool_use(
self,
calling: Union[ToolCalling, InstructorToolCalling],
calling: ToolCalling | InstructorToolCalling,
output: str,
should_cache: bool = True,
) -> Any:

View File

@@ -9,9 +9,9 @@ def add_crew_to_flow(crew_name: str) -> None:
"""Add a new crew to the current flow."""
# Check if pyproject.toml exists in the current directory
if not Path("pyproject.toml").exists():
print("This command must be run from the root of a flow project.")
msg = "This command must be run from the root of a flow project."
raise click.ClickException(
"This command must be run from the root of a flow project."
msg,
)
# Determine the flow folder based on the current directory
@@ -19,8 +19,8 @@ def add_crew_to_flow(crew_name: str) -> None:
crews_folder = flow_folder / "src" / flow_folder.name / "crews"
if not crews_folder.exists():
print("Crews folder does not exist in the current flow.")
raise click.ClickException("Crews folder does not exist in the current flow.")
msg = "Crews folder does not exist in the current flow."
raise click.ClickException(msg)
# Create the crew within the flow's crews directory
create_embedded_crew(crew_name, parent_folder=crews_folder)
@@ -39,7 +39,7 @@ def create_embedded_crew(crew_name: str, parent_folder: Path) -> None:
if crew_folder.exists():
if not click.confirm(
f"Crew {folder_name} already exists. Do you want to override it?"
f"Crew {folder_name} already exists. Do you want to override it?",
):
click.secho("Operation cancelled.", fg="yellow")
return
@@ -66,5 +66,5 @@ def create_embedded_crew(crew_name: str, parent_folder: Path) -> None:
copy_template(src_file, dst_file, crew_name, class_name, folder_name)
click.secho(
f"Crew {crew_name} added to the flow successfully!", fg="green", bold=True
f"Crew {crew_name} added to the flow successfully!", fg="green", bold=True,
)

View File

@@ -1,6 +1,6 @@
import time
import webbrowser
from typing import Any, Dict
from typing import Any
import requests
from rich.console import Console
@@ -17,38 +17,37 @@ class AuthenticationCommand:
DEVICE_CODE_URL = f"https://{AUTH0_DOMAIN}/oauth/device/code"
TOKEN_URL = f"https://{AUTH0_DOMAIN}/oauth/token"
def __init__(self):
def __init__(self) -> None:
self.token_manager = TokenManager()
def signup(self) -> None:
"""Sign up to CrewAI+"""
"""Sign up to CrewAI+."""
console.print("Signing Up to CrewAI+ \n", style="bold blue")
device_code_data = self._get_device_code()
self._display_auth_instructions(device_code_data)
return self._poll_for_token(device_code_data)
def _get_device_code(self) -> Dict[str, Any]:
def _get_device_code(self) -> dict[str, Any]:
"""Get the device code to authenticate the user."""
device_code_payload = {
"client_id": AUTH0_CLIENT_ID,
"scope": "openid",
"audience": AUTH0_AUDIENCE,
}
response = requests.post(
url=self.DEVICE_CODE_URL, data=device_code_payload, timeout=20
url=self.DEVICE_CODE_URL, data=device_code_payload, timeout=20,
)
response.raise_for_status()
return response.json()
def _display_auth_instructions(self, device_code_data: Dict[str, str]) -> None:
def _display_auth_instructions(self, device_code_data: dict[str, str]) -> None:
"""Display the authentication instructions to the user."""
console.print("1. Navigate to: ", device_code_data["verification_uri_complete"])
console.print("2. Enter the following code: ", device_code_data["user_code"])
webbrowser.open(device_code_data["verification_uri_complete"])
def _poll_for_token(self, device_code_data: Dict[str, Any]) -> None:
def _poll_for_token(self, device_code_data: dict[str, Any]) -> None:
"""Poll the server for the token."""
token_payload = {
"grant_type": "urn:ietf:params:oauth:grant-type:device_code",
@@ -81,7 +80,7 @@ class AuthenticationCommand:
)
console.print(
"\n[bold green]Welcome to CrewAI Enterprise![/bold green]\n"
"\n[bold green]Welcome to CrewAI Enterprise![/bold green]\n",
)
return
@@ -92,5 +91,5 @@ class AuthenticationCommand:
attempts += 1
console.print(
"Timeout: Failed to get the token. Please try again.", style="bold red"
"Timeout: Failed to get the token. Please try again.", style="bold red",
)

View File

@@ -5,5 +5,5 @@ def get_auth_token() -> str:
"""Get the authentication token."""
access_token = TokenManager().get_token()
if not access_token:
raise Exception()
raise Exception
return access_token

View File

@@ -3,7 +3,6 @@ import os
import sys
from datetime import datetime, timedelta
from pathlib import Path
from typing import Optional
from auth0.authentication.token_verifier import (
AsymmetricSignatureVerifier,
@@ -15,8 +14,7 @@ from .constants import AUTH0_CLIENT_ID, AUTH0_DOMAIN
def validate_token(id_token: str) -> None:
"""
Verify the token and its precedence
"""Verify the token and its precedence.
:param id_token:
"""
@@ -24,15 +22,14 @@ def validate_token(id_token: str) -> None:
issuer = f"https://{AUTH0_DOMAIN}/"
signature_verifier = AsymmetricSignatureVerifier(jwks_url)
token_verifier = TokenVerifier(
signature_verifier=signature_verifier, issuer=issuer, audience=AUTH0_CLIENT_ID
signature_verifier=signature_verifier, issuer=issuer, audience=AUTH0_CLIENT_ID,
)
token_verifier.verify(id_token)
class TokenManager:
def __init__(self, file_path: str = "tokens.enc") -> None:
"""
Initialize the TokenManager class.
"""Initialize the TokenManager class.
:param file_path: The file path to store the encrypted tokens. Default is "tokens.enc".
"""
@@ -41,8 +38,7 @@ class TokenManager:
self.fernet = Fernet(self.key)
def _get_or_create_key(self) -> bytes:
"""
Get or create the encryption key.
"""Get or create the encryption key.
:return: The encryption key.
"""
@@ -57,8 +53,7 @@ class TokenManager:
return new_key
def save_tokens(self, access_token: str, expires_in: int) -> None:
"""
Save the access token and its expiration time.
"""Save the access token and its expiration time.
:param access_token: The access token to save.
:param expires_in: The expiration time of the access token in seconds.
@@ -71,9 +66,8 @@ class TokenManager:
encrypted_data = self.fernet.encrypt(json.dumps(data).encode())
self.save_secure_file(self.file_path, encrypted_data)
def get_token(self) -> Optional[str]:
"""
Get the access token if it is valid and not expired.
def get_token(self) -> str | None:
"""Get the access token if it is valid and not expired.
:return: The access token if valid and not expired, otherwise None.
"""
@@ -89,8 +83,7 @@ class TokenManager:
return data["access_token"]
def get_secure_storage_path(self) -> Path:
"""
Get the secure storage path based on the operating system.
"""Get the secure storage path based on the operating system.
:return: The secure storage path.
"""
@@ -112,8 +105,7 @@ class TokenManager:
return storage_path
def save_secure_file(self, filename: str, content: bytes) -> None:
"""
Save the content to a secure file.
"""Save the content to a secure file.
:param filename: The name of the file.
:param content: The content to save.
@@ -127,9 +119,8 @@ class TokenManager:
# Set appropriate permissions (read/write for owner only)
os.chmod(file_path, 0o600)
def read_secure_file(self, filename: str) -> Optional[bytes]:
"""
Read the content of a secure file.
def read_secure_file(self, filename: str) -> bytes | None:
"""Read the content of a secure file.
:param filename: The name of the file.
:return: The content of the file if it exists, otherwise None.

View File

@@ -1,6 +1,4 @@
import os
from importlib.metadata import version as get_version
from typing import Optional, Tuple
import click
@@ -28,7 +26,7 @@ from .update_crew import update_crew
@click.group()
@click.version_option(get_version("crewai"))
def crewai():
def crewai() -> None:
"""Top-level command group for crewai."""
@@ -37,7 +35,7 @@ def crewai():
@click.argument("name")
@click.option("--provider", type=str, help="The provider to use for the crew")
@click.option("--skip_provider", is_flag=True, help="Skip provider validation")
def create(type, name, provider, skip_provider=False):
def create(type, name, provider, skip_provider=False) -> None:
"""Create a new crew, or flow."""
if type == "crew":
create_crew(name, provider, skip_provider)
@@ -49,9 +47,9 @@ def create(type, name, provider, skip_provider=False):
@crewai.command()
@click.option(
"--tools", is_flag=True, help="Show the installed version of crewai tools"
"--tools", is_flag=True, help="Show the installed version of crewai tools",
)
def version(tools):
def version(tools) -> None:
"""Show the installed version of crewai."""
try:
crewai_version = get_version("crewai")
@@ -82,7 +80,7 @@ def version(tools):
default="trained_agents_data.pkl",
help="Path to a custom file for training",
)
def train(n_iterations: int, filename: str):
def train(n_iterations: int, filename: str) -> None:
"""Train the crew."""
click.echo(f"Training the Crew for {n_iterations} iterations")
train_crew(n_iterations, filename)
@@ -96,11 +94,11 @@ def train(n_iterations: int, filename: str):
help="Replay the crew from this task ID, including all subsequent tasks.",
)
def replay(task_id: str) -> None:
"""
Replay the crew execution from a specific task.
"""Replay the crew execution from a specific task.
Args:
task_id (str): The ID of the task to replay from.
"""
try:
click.echo(f"Replaying the crew from task {task_id}")
@@ -111,16 +109,14 @@ def replay(task_id: str) -> None:
@crewai.command()
def log_tasks_outputs() -> None:
"""
Retrieve your latest crew.kickoff() task outputs.
"""
"""Retrieve your latest crew.kickoff() task outputs."""
try:
storage = KickoffTaskOutputsSQLiteStorage()
tasks = storage.load()
if not tasks:
click.echo(
"No task outputs found. Only crew kickoff task outputs are logged."
"No task outputs found. Only crew kickoff task outputs are logged.",
)
return
@@ -153,13 +149,11 @@ def reset_memories(
kickoff_outputs: bool,
all: bool,
) -> None:
"""
Reset the crew memories (long, short, entity, latest_crew_kickoff_ouputs). This will delete all the data saved.
"""
"""Reset the crew memories (long, short, entity, latest_crew_kickoff_ouputs). This will delete all the data saved."""
try:
if not all and not (long or short or entities or knowledge or kickoff_outputs):
click.echo(
"Please specify at least one memory type to reset using the appropriate flags."
"Please specify at least one memory type to reset using the appropriate flags.",
)
return
reset_memories_command(long, short, entities, knowledge, kickoff_outputs, all)
@@ -182,71 +176,69 @@ def reset_memories(
default="gpt-4o-mini",
help="LLM Model to run the tests on the Crew. For now only accepting only OpenAI models.",
)
def test(n_iterations: int, model: str):
def test(n_iterations: int, model: str) -> None:
"""Test the crew and evaluate the results."""
click.echo(f"Testing the crew for {n_iterations} iterations with model {model}")
evaluate_crew(n_iterations, model)
@crewai.command(
context_settings=dict(
ignore_unknown_options=True,
allow_extra_args=True,
)
context_settings={
"ignore_unknown_options": True,
"allow_extra_args": True,
},
)
@click.pass_context
def install(context):
def install(context) -> None:
"""Install the Crew."""
install_crew(context.args)
@crewai.command()
def run():
def run() -> None:
"""Run the Crew."""
run_crew()
@crewai.command()
def update():
def update() -> None:
"""Update the pyproject.toml of the Crew project to use uv."""
update_crew()
@crewai.command()
def signup():
def signup() -> None:
"""Sign Up/Login to CrewAI+."""
AuthenticationCommand().signup()
@crewai.command()
def login():
def login() -> None:
"""Sign Up/Login to CrewAI+."""
AuthenticationCommand().signup()
# DEPLOY CREWAI+ COMMANDS
@crewai.group()
def deploy():
def deploy() -> None:
"""Deploy the Crew CLI group."""
pass
@crewai.group()
def tool():
def tool() -> None:
"""Tool Repository related commands."""
pass
@deploy.command(name="create")
@click.option("-y", "--yes", is_flag=True, help="Skip the confirmation prompt")
def deploy_create(yes: bool):
def deploy_create(yes: bool) -> None:
"""Create a Crew deployment."""
deploy_cmd = DeployCommand()
deploy_cmd.create_crew(yes)
@deploy.command(name="list")
def deploy_list():
def deploy_list() -> None:
"""List all deployments."""
deploy_cmd = DeployCommand()
deploy_cmd.list_crews()
@@ -254,7 +246,7 @@ def deploy_list():
@deploy.command(name="push")
@click.option("-u", "--uuid", type=str, help="Crew UUID parameter")
def deploy_push(uuid: Optional[str]):
def deploy_push(uuid: str | None) -> None:
"""Deploy the Crew."""
deploy_cmd = DeployCommand()
deploy_cmd.deploy(uuid=uuid)
@@ -262,7 +254,7 @@ def deploy_push(uuid: Optional[str]):
@deploy.command(name="status")
@click.option("-u", "--uuid", type=str, help="Crew UUID parameter")
def deply_status(uuid: Optional[str]):
def deply_status(uuid: str | None) -> None:
"""Get the status of a deployment."""
deploy_cmd = DeployCommand()
deploy_cmd.get_crew_status(uuid=uuid)
@@ -270,7 +262,7 @@ def deply_status(uuid: Optional[str]):
@deploy.command(name="logs")
@click.option("-u", "--uuid", type=str, help="Crew UUID parameter")
def deploy_logs(uuid: Optional[str]):
def deploy_logs(uuid: str | None) -> None:
"""Get the logs of a deployment."""
deploy_cmd = DeployCommand()
deploy_cmd.get_crew_logs(uuid=uuid)
@@ -278,7 +270,7 @@ def deploy_logs(uuid: Optional[str]):
@deploy.command(name="remove")
@click.option("-u", "--uuid", type=str, help="Crew UUID parameter")
def deploy_remove(uuid: Optional[str]):
def deploy_remove(uuid: str | None) -> None:
"""Remove a deployment."""
deploy_cmd = DeployCommand()
deploy_cmd.remove_crew(uuid=uuid)
@@ -286,14 +278,14 @@ def deploy_remove(uuid: Optional[str]):
@tool.command(name="create")
@click.argument("handle")
def tool_create(handle: str):
def tool_create(handle: str) -> None:
tool_cmd = ToolCommand()
tool_cmd.create(handle)
@tool.command(name="install")
@click.argument("handle")
def tool_install(handle: str):
def tool_install(handle: str) -> None:
tool_cmd = ToolCommand()
tool_cmd.login()
tool_cmd.install(handle)
@@ -309,27 +301,26 @@ def tool_install(handle: str):
)
@click.option("--public", "is_public", flag_value=True, default=False)
@click.option("--private", "is_public", flag_value=False)
def tool_publish(is_public: bool, force: bool):
def tool_publish(is_public: bool, force: bool) -> None:
tool_cmd = ToolCommand()
tool_cmd.login()
tool_cmd.publish(is_public, force)
@crewai.group()
def flow():
def flow() -> None:
"""Flow related commands."""
pass
@flow.command(name="kickoff")
def flow_run():
def flow_run() -> None:
"""Kickoff the Flow."""
click.echo("Running the Flow")
kickoff_flow()
@flow.command(name="plot")
def flow_plot():
def flow_plot() -> None:
"""Plot the Flow."""
click.echo("Plotting the Flow")
plot_flow()
@@ -337,20 +328,19 @@ def flow_plot():
@flow.command(name="add-crew")
@click.argument("crew_name")
def flow_add_crew(crew_name):
def flow_add_crew(crew_name) -> None:
"""Add a crew to an existing flow."""
click.echo(f"Adding crew {crew_name} to the flow")
add_crew_to_flow(crew_name)
@crewai.command()
def chat():
"""
Start a conversation with the Crew, collecting user-supplied inputs,
def chat() -> None:
"""Start a conversation with the Crew, collecting user-supplied inputs,
and using the Chat LLM to generate responses.
"""
click.secho(
"\nStarting a conversation with the Crew\n" "Type 'exit' or Ctrl+C to quit.\n",
"\nStarting a conversation with the Crew\nType 'exit' or Ctrl+C to quit.\n",
)
run_chat()

View File

@@ -10,13 +10,13 @@ console = Console()
class BaseCommand:
def __init__(self):
def __init__(self) -> None:
self._telemetry = Telemetry()
self._telemetry.set_tracer()
class PlusAPIMixin:
def __init__(self, telemetry):
def __init__(self, telemetry) -> None:
try:
telemetry.set_tracer()
self.plus_api_client = PlusAPI(api_key=get_auth_token())
@@ -30,11 +30,11 @@ class PlusAPIMixin:
raise SystemExit
def _validate_response(self, response: requests.Response) -> None:
"""
Handle and display error messages from API responses.
"""Handle and display error messages from API responses.
Args:
response (requests.Response): The response from the Plus API
"""
try:
json_response = response.json()
@@ -55,13 +55,13 @@ class PlusAPIMixin:
for field, messages in json_response.items():
for message in messages:
console.print(
f"* [bold red]{field.capitalize()}[/bold red] {message}"
f"* [bold red]{field.capitalize()}[/bold red] {message}",
)
raise SystemExit
if not response.ok:
console.print(
"Request to Enterprise API failed. Details:", style="bold red"
"Request to Enterprise API failed. Details:", style="bold red",
)
details = (
json_response.get("error")

View File

@@ -1,6 +1,5 @@
import json
from pathlib import Path
from typing import Optional
from pydantic import BaseModel, Field
@@ -8,16 +7,16 @@ DEFAULT_CONFIG_PATH = Path.home() / ".config" / "crewai" / "settings.json"
class Settings(BaseModel):
tool_repository_username: Optional[str] = Field(
None, description="Username for interacting with the Tool Repository"
tool_repository_username: str | None = Field(
None, description="Username for interacting with the Tool Repository",
)
tool_repository_password: Optional[str] = Field(
None, description="Password for interacting with the Tool Repository"
tool_repository_password: str | None = Field(
None, description="Password for interacting with the Tool Repository",
)
config_path: Path = Field(default=DEFAULT_CONFIG_PATH, exclude=True)
def __init__(self, config_path: Path = DEFAULT_CONFIG_PATH, **data):
"""Load Settings from config path"""
def __init__(self, config_path: Path = DEFAULT_CONFIG_PATH, **data) -> None:
"""Load Settings from config path."""
config_path.parent.mkdir(parents=True, exist_ok=True)
file_data = {}
@@ -32,7 +31,7 @@ class Settings(BaseModel):
super().__init__(config_path=config_path, **merged_data)
def dump(self) -> None:
"""Save current settings to settings.json"""
"""Save current settings to settings.json."""
if self.config_path.is_file():
with self.config_path.open("r") as f:
existing_data = json.load(f)

View File

@@ -3,31 +3,31 @@ ENV_VARS = {
{
"prompt": "Enter your OPENAI API key (press Enter to skip)",
"key_name": "OPENAI_API_KEY",
}
},
],
"anthropic": [
{
"prompt": "Enter your ANTHROPIC API key (press Enter to skip)",
"key_name": "ANTHROPIC_API_KEY",
}
},
],
"gemini": [
{
"prompt": "Enter your GEMINI API key from https://ai.dev/apikey (press Enter to skip)",
"key_name": "GEMINI_API_KEY",
}
},
],
"nvidia_nim": [
{
"prompt": "Enter your NVIDIA API key (press Enter to skip)",
"key_name": "NVIDIA_NIM_API_KEY",
}
},
],
"groq": [
{
"prompt": "Enter your GROQ API key (press Enter to skip)",
"key_name": "GROQ_API_KEY",
}
},
],
"watson": [
{
@@ -47,7 +47,7 @@ ENV_VARS = {
{
"default": True,
"API_BASE": "http://localhost:11434",
}
},
],
"bedrock": [
{
@@ -101,7 +101,7 @@ ENV_VARS = {
{
"prompt": "Enter your SambaNovaCloud API key (press Enter to skip)",
"key_name": "SAMBANOVA_API_KEY",
}
},
],
}

View File

@@ -24,7 +24,7 @@ def create_folder_structure(name, parent_folder=None):
if folder_path.exists():
if not click.confirm(
f"Folder {folder_name} already exists. Do you want to override it?"
f"Folder {folder_name} already exists. Do you want to override it?",
):
click.secho("Operation cancelled.", fg="yellow")
sys.exit(0)
@@ -48,7 +48,7 @@ def create_folder_structure(name, parent_folder=None):
return folder_path, folder_name, class_name
def copy_template_files(folder_path, name, class_name, parent_folder):
def copy_template_files(folder_path, name, class_name, parent_folder) -> None:
package_dir = Path(__file__).parent
templates_dir = package_dir / "templates" / "crew"
@@ -89,7 +89,7 @@ def copy_template_files(folder_path, name, class_name, parent_folder):
copy_template(src_file, dst_file, name, class_name, folder_path.name)
def create_crew(name, provider=None, skip_provider=False, parent_folder=None):
def create_crew(name, provider=None, skip_provider=False, parent_folder=None) -> None:
folder_path, folder_name, class_name = create_folder_structure(name, parent_folder)
env_vars = load_env_vars(folder_path)
if not skip_provider:
@@ -109,7 +109,7 @@ def create_crew(name, provider=None, skip_provider=False, parent_folder=None):
if existing_provider:
if not click.confirm(
f"Found existing environment variable configuration for {existing_provider.capitalize()}. Do you want to override it?"
f"Found existing environment variable configuration for {existing_provider.capitalize()}. Do you want to override it?",
):
click.secho("Keeping existing provider configuration.", fg="yellow")
return
@@ -126,11 +126,11 @@ def create_crew(name, provider=None, skip_provider=False, parent_folder=None):
if selected_provider: # Valid selection
break
click.secho(
"No provider selected. Please try again or press 'q' to exit.", fg="red"
"No provider selected. Please try again or press 'q' to exit.", fg="red",
)
# Check if the selected provider has predefined models
if selected_provider in MODELS and MODELS[selected_provider]:
if MODELS.get(selected_provider):
while True:
selected_model = select_model(selected_provider, provider_models)
if selected_model is None: # User typed 'q'
@@ -167,7 +167,7 @@ def create_crew(name, provider=None, skip_provider=False, parent_folder=None):
click.secho("API keys and model saved to .env file", fg="green")
else:
click.secho(
"No API keys provided. Skipping .env file creation.", fg="yellow"
"No API keys provided. Skipping .env file creation.", fg="yellow",
)
click.secho(f"Selected model: {env_vars.get('MODEL', 'N/A')}", fg="green")

View File

@@ -5,7 +5,7 @@ import click
from crewai.telemetry import Telemetry
def create_flow(name):
def create_flow(name) -> None:
"""Create a new flow."""
folder_name = name.replace(" ", "_").replace("-", "_").lower()
class_name = name.replace("_", " ").replace("-", " ").title().replace(" ", "")
@@ -43,12 +43,12 @@ def create_flow(name):
"poem_crew",
]
def process_file(src_file, dst_file):
def process_file(src_file, dst_file) -> None:
if src_file.suffix in [".pyc", ".pyo", ".pyd"]:
return
try:
with open(src_file, "r", encoding="utf-8") as file:
with open(src_file, encoding="utf-8") as file:
content = file.read()
except Exception as e:
click.secho(f"Error processing file {src_file}: {e}", fg="red")

View File

@@ -5,7 +5,7 @@ import sys
import threading
import time
from pathlib import Path
from typing import Any, Dict, List, Optional, Set, Tuple
from typing import Any
import click
import tomli
@@ -22,10 +22,9 @@ MIN_REQUIRED_VERSION = "0.98.0"
def check_conversational_crews_version(
crewai_version: str, pyproject_data: dict
crewai_version: str, pyproject_data: dict,
) -> bool:
"""
Check if the installed crewAI version supports conversational crews.
"""Check if the installed crewAI version supports conversational crews.
Args:
crewai_version: The current version of crewAI.
@@ -33,6 +32,7 @@ def check_conversational_crews_version(
Returns:
bool: True if version check passes, False otherwise.
"""
try:
if version.parse(crewai_version) < version.parse(MIN_REQUIRED_VERSION):
@@ -48,9 +48,8 @@ def check_conversational_crews_version(
return True
def run_chat():
"""
Runs an interactive chat loop using the Crew's chat LLM with function calling.
def run_chat() -> None:
"""Runs an interactive chat loop using the Crew's chat LLM with function calling.
Incorporates crew_name, crew_description, and input fields to build a tool schema.
Exits if crew_name or crew_description are missing.
"""
@@ -84,7 +83,7 @@ def run_chat():
# Call the LLM to generate the introductory message
introductory_message = chat_llm.call(
messages=[{"role": "system", "content": system_message}]
messages=[{"role": "system", "content": system_message}],
)
finally:
# Stop loading indicator
@@ -108,15 +107,13 @@ def run_chat():
chat_loop(chat_llm, messages, crew_tool_schema, available_functions)
def show_loading(event: threading.Event):
def show_loading(event: threading.Event) -> None:
"""Display animated loading dots while processing."""
while not event.is_set():
print(".", end="", flush=True)
time.sleep(1)
print()
def initialize_chat_llm(crew: Crew) -> Optional[LLM | BaseLLM]:
def initialize_chat_llm(crew: Crew) -> LLM | BaseLLM | None:
"""Initializes the chat LLM and handles exceptions."""
try:
return create_llm(crew.chat_llm)
@@ -157,7 +154,7 @@ def build_system_message(crew_chat_inputs: ChatInputs) -> str:
)
def create_tool_function(crew: Crew, messages: List[Dict[str, str]]) -> Any:
def create_tool_function(crew: Crew, messages: list[dict[str, str]]) -> Any:
"""Creates a wrapper function for running the crew tool with messages."""
def run_crew_tool_with_messages(**kwargs):
@@ -166,7 +163,7 @@ def create_tool_function(crew: Crew, messages: List[Dict[str, str]]) -> Any:
return run_crew_tool_with_messages
def flush_input():
def flush_input() -> None:
"""Flush any pending input from the user."""
if platform.system() == "Windows":
# Windows platform
@@ -181,7 +178,7 @@ def flush_input():
termios.tcflush(sys.stdin, termios.TCIFLUSH)
def chat_loop(chat_llm, messages, crew_tool_schema, available_functions):
def chat_loop(chat_llm, messages, crew_tool_schema, available_functions) -> None:
"""Main chat loop for interacting with the user."""
while True:
try:
@@ -190,7 +187,7 @@ def chat_loop(chat_llm, messages, crew_tool_schema, available_functions):
user_input = get_user_input()
handle_user_input(
user_input, chat_llm, messages, crew_tool_schema, available_functions
user_input, chat_llm, messages, crew_tool_schema, available_functions,
)
except KeyboardInterrupt:
@@ -221,9 +218,9 @@ def get_user_input() -> str:
def handle_user_input(
user_input: str,
chat_llm: LLM,
messages: List[Dict[str, str]],
crew_tool_schema: Dict[str, Any],
available_functions: Dict[str, Any],
messages: list[dict[str, str]],
crew_tool_schema: dict[str, Any],
available_functions: dict[str, Any],
) -> None:
if user_input.strip().lower() == "exit":
click.echo("Exiting chat. Goodbye!")
@@ -251,8 +248,7 @@ def handle_user_input(
def generate_crew_tool_schema(crew_inputs: ChatInputs) -> dict:
"""
Dynamically build a Littellm 'function' schema for the given crew.
"""Dynamically build a Littellm 'function' schema for the given crew.
crew_name: The name of the crew (used for the function 'name').
crew_inputs: A ChatInputs object containing crew_description
@@ -281,9 +277,8 @@ def generate_crew_tool_schema(crew_inputs: ChatInputs) -> dict:
}
def run_crew_tool(crew: Crew, messages: List[Dict[str, str]], **kwargs):
"""
Runs the crew using crew.kickoff(inputs=kwargs) and returns the output.
def run_crew_tool(crew: Crew, messages: list[dict[str, str]], **kwargs):
"""Runs the crew using crew.kickoff(inputs=kwargs) and returns the output.
Args:
crew (Crew): The crew instance to run.
@@ -295,6 +290,7 @@ def run_crew_tool(crew: Crew, messages: List[Dict[str, str]], **kwargs):
Raises:
SystemExit: Exits the chat if an error occurs during crew execution.
"""
try:
# Serialize 'messages' to JSON string before adding to kwargs
@@ -304,9 +300,8 @@ def run_crew_tool(crew: Crew, messages: List[Dict[str, str]], **kwargs):
crew_output = crew.kickoff(inputs=kwargs)
# Convert CrewOutput to a string to send back to the user
result = str(crew_output)
return str(crew_output)
return result
except Exception as e:
# Exit the chat and show the error message
click.secho("An error occurred while running the crew:", fg="red")
@@ -314,12 +309,12 @@ def run_crew_tool(crew: Crew, messages: List[Dict[str, str]], **kwargs):
sys.exit(1)
def load_crew_and_name() -> Tuple[Crew, str]:
"""
Loads the crew by importing the crew class from the user's project.
def load_crew_and_name() -> tuple[Crew, str]:
"""Loads the crew by importing the crew class from the user's project.
Returns:
Tuple[Crew, str]: A tuple containing the Crew instance and the name of the crew.
"""
# Get the current working directory
cwd = Path.cwd()
@@ -327,7 +322,8 @@ def load_crew_and_name() -> Tuple[Crew, str]:
# Path to the pyproject.toml file
pyproject_path = cwd / "pyproject.toml"
if not pyproject_path.exists():
raise FileNotFoundError("pyproject.toml not found in the current directory.")
msg = "pyproject.toml not found in the current directory."
raise FileNotFoundError(msg)
# Load the pyproject.toml file using 'tomli'
with pyproject_path.open("rb") as f:
@@ -351,14 +347,16 @@ def load_crew_and_name() -> Tuple[Crew, str]:
try:
crew_module = __import__(crew_module_name, fromlist=[crew_class_name])
except ImportError as e:
raise ImportError(f"Failed to import crew module {crew_module_name}: {e}")
msg = f"Failed to import crew module {crew_module_name}: {e}"
raise ImportError(msg)
# Get the crew class from the module
try:
crew_class = getattr(crew_module, crew_class_name)
except AttributeError:
msg = f"Crew class {crew_class_name} not found in module {crew_module_name}"
raise AttributeError(
f"Crew class {crew_class_name} not found in module {crew_module_name}"
msg,
)
# Instantiate the crew
@@ -367,8 +365,7 @@ def load_crew_and_name() -> Tuple[Crew, str]:
def generate_crew_chat_inputs(crew: Crew, crew_name: str, chat_llm) -> ChatInputs:
"""
Generates the ChatInputs required for the crew by analyzing the tasks and agents.
"""Generates the ChatInputs required for the crew by analyzing the tasks and agents.
Args:
crew (Crew): The crew object containing tasks and agents.
@@ -377,6 +374,7 @@ def generate_crew_chat_inputs(crew: Crew, crew_name: str, chat_llm) -> ChatInput
Returns:
ChatInputs: An object containing the crew's name, description, and input fields.
"""
# Extract placeholders from tasks and agents
required_inputs = fetch_required_inputs(crew)
@@ -391,22 +389,22 @@ def generate_crew_chat_inputs(crew: Crew, crew_name: str, chat_llm) -> ChatInput
crew_description = generate_crew_description_with_ai(crew, chat_llm)
return ChatInputs(
crew_name=crew_name, crew_description=crew_description, inputs=input_fields
crew_name=crew_name, crew_description=crew_description, inputs=input_fields,
)
def fetch_required_inputs(crew: Crew) -> Set[str]:
"""
Extracts placeholders from the crew's tasks and agents.
def fetch_required_inputs(crew: Crew) -> set[str]:
"""Extracts placeholders from the crew's tasks and agents.
Args:
crew (Crew): The crew object.
Returns:
Set[str]: A set of placeholder names.
"""
placeholder_pattern = re.compile(r"\{(.+?)\}")
required_inputs: Set[str] = set()
required_inputs: set[str] = set()
# Scan tasks
for task in crew.tasks:
@@ -422,8 +420,7 @@ def fetch_required_inputs(crew: Crew) -> Set[str]:
def generate_input_description_with_ai(input_name: str, crew: Crew, chat_llm) -> str:
"""
Generates an input description using AI based on the context of the crew.
"""Generates an input description using AI based on the context of the crew.
Args:
input_name (str): The name of the input placeholder.
@@ -432,6 +429,7 @@ def generate_input_description_with_ai(input_name: str, crew: Crew, chat_llm) ->
Returns:
str: A concise description of the input.
"""
# Gather context from tasks and agents where the input is used
context_texts = []
@@ -444,10 +442,10 @@ def generate_input_description_with_ai(input_name: str, crew: Crew, chat_llm) ->
):
# Replace placeholders with input names
task_description = placeholder_pattern.sub(
lambda m: m.group(1), task.description or ""
lambda m: m.group(1), task.description or "",
)
expected_output = placeholder_pattern.sub(
lambda m: m.group(1), task.expected_output or ""
lambda m: m.group(1), task.expected_output or "",
)
context_texts.append(f"Task Description: {task_description}")
context_texts.append(f"Expected Output: {expected_output}")
@@ -461,7 +459,7 @@ def generate_input_description_with_ai(input_name: str, crew: Crew, chat_llm) ->
agent_role = placeholder_pattern.sub(lambda m: m.group(1), agent.role or "")
agent_goal = placeholder_pattern.sub(lambda m: m.group(1), agent.goal or "")
agent_backstory = placeholder_pattern.sub(
lambda m: m.group(1), agent.backstory or ""
lambda m: m.group(1), agent.backstory or "",
)
context_texts.append(f"Agent Role: {agent_role}")
context_texts.append(f"Agent Goal: {agent_goal}")
@@ -470,7 +468,8 @@ def generate_input_description_with_ai(input_name: str, crew: Crew, chat_llm) ->
context = "\n".join(context_texts)
if not context:
# If no context is found for the input, raise an exception as per instruction
raise ValueError(f"No context found for input '{input_name}'.")
msg = f"No context found for input '{input_name}'."
raise ValueError(msg)
prompt = (
f"Based on the following context, write a concise description (15 words or less) of the input '{input_name}'.\n"
@@ -479,14 +478,12 @@ def generate_input_description_with_ai(input_name: str, crew: Crew, chat_llm) ->
f"{context}"
)
response = chat_llm.call(messages=[{"role": "user", "content": prompt}])
description = response.strip()
return response.strip()
return description
def generate_crew_description_with_ai(crew: Crew, chat_llm) -> str:
"""
Generates a brief description of the crew using AI.
"""Generates a brief description of the crew using AI.
Args:
crew (Crew): The crew object.
@@ -494,6 +491,7 @@ def generate_crew_description_with_ai(crew: Crew, chat_llm) -> str:
Returns:
str: A concise description of the crew's purpose (15 words or less).
"""
# Gather context from tasks and agents
context_texts = []
@@ -502,10 +500,10 @@ def generate_crew_description_with_ai(crew: Crew, chat_llm) -> str:
for task in crew.tasks:
# Replace placeholders with input names
task_description = placeholder_pattern.sub(
lambda m: m.group(1), task.description or ""
lambda m: m.group(1), task.description or "",
)
expected_output = placeholder_pattern.sub(
lambda m: m.group(1), task.expected_output or ""
lambda m: m.group(1), task.expected_output or "",
)
context_texts.append(f"Task Description: {task_description}")
context_texts.append(f"Expected Output: {expected_output}")
@@ -514,7 +512,7 @@ def generate_crew_description_with_ai(crew: Crew, chat_llm) -> str:
agent_role = placeholder_pattern.sub(lambda m: m.group(1), agent.role or "")
agent_goal = placeholder_pattern.sub(lambda m: m.group(1), agent.goal or "")
agent_backstory = placeholder_pattern.sub(
lambda m: m.group(1), agent.backstory or ""
lambda m: m.group(1), agent.backstory or "",
)
context_texts.append(f"Agent Role: {agent_role}")
context_texts.append(f"Agent Goal: {agent_goal}")
@@ -522,7 +520,8 @@ def generate_crew_description_with_ai(crew: Crew, chat_llm) -> str:
context = "\n".join(context_texts)
if not context:
raise ValueError("No context found for generating crew description.")
msg = "No context found for generating crew description."
raise ValueError(msg)
prompt = (
"Based on the following context, write a concise, action-oriented description (15 words or less) of the crew's purpose.\n"
@@ -531,6 +530,5 @@ def generate_crew_description_with_ai(crew: Crew, chat_llm) -> str:
f"{context}"
)
response = chat_llm.call(messages=[{"role": "user", "content": prompt}])
crew_description = response.strip()
return response.strip()
return crew_description

View File

@@ -1,4 +1,4 @@
from typing import Any, Dict, List, Optional
from typing import Any
from rich.console import Console
@@ -10,34 +10,27 @@ console = Console()
class DeployCommand(BaseCommand, PlusAPIMixin):
"""
A class to handle deployment-related operations for CrewAI projects.
"""
def __init__(self):
"""
Initialize the DeployCommand with project name and API client.
"""
"""A class to handle deployment-related operations for CrewAI projects."""
def __init__(self) -> None:
"""Initialize the DeployCommand with project name and API client."""
BaseCommand.__init__(self)
PlusAPIMixin.__init__(self, telemetry=self._telemetry)
self.project_name = get_project_name(require=True)
def _standard_no_param_error_message(self) -> None:
"""
Display a standard error message when no UUID or project name is available.
"""
"""Display a standard error message when no UUID or project name is available."""
console.print(
"No UUID provided, project pyproject.toml not found or with error.",
style="bold red",
)
def _display_deployment_info(self, json_response: Dict[str, Any]) -> None:
"""
Display deployment information.
def _display_deployment_info(self, json_response: dict[str, Any]) -> None:
"""Display deployment information.
Args:
json_response (Dict[str, Any]): The deployment information to display.
"""
console.print("Deploying the crew...\n", style="bold blue")
for key, value in json_response.items():
@@ -47,24 +40,24 @@ class DeployCommand(BaseCommand, PlusAPIMixin):
console.print(" or")
console.print(f"crewai deploy status --uuid \"{json_response['uuid']}\"")
def _display_logs(self, log_messages: List[Dict[str, Any]]) -> None:
"""
Display log messages.
def _display_logs(self, log_messages: list[dict[str, Any]]) -> None:
"""Display log messages.
Args:
log_messages (List[Dict[str, Any]]): The log messages to display.
"""
for log_message in log_messages:
console.print(
f"{log_message['timestamp']} - {log_message['level']}: {log_message['message']}"
f"{log_message['timestamp']} - {log_message['level']}: {log_message['message']}",
)
def deploy(self, uuid: Optional[str] = None) -> None:
"""
Deploy a crew using either UUID or project name.
def deploy(self, uuid: str | None = None) -> None:
"""Deploy a crew using either UUID or project name.
Args:
uuid (Optional[str]): The UUID of the crew to deploy.
"""
self._start_deployment_span = self._telemetry.start_deployment_span(uuid)
console.print("Starting deployment...", style="bold blue")
@@ -80,9 +73,7 @@ class DeployCommand(BaseCommand, PlusAPIMixin):
self._display_deployment_info(response.json())
def create_crew(self, confirm: bool = False) -> None:
"""
Create a new crew deployment.
"""
"""Create a new crew deployment."""
self._create_crew_deployment_span = (
self._telemetry.create_crew_deployment_span()
)
@@ -110,29 +101,28 @@ class DeployCommand(BaseCommand, PlusAPIMixin):
self._display_creation_success(response.json())
def _confirm_input(
self, env_vars: Dict[str, str], remote_repo_url: str, confirm: bool
self, env_vars: dict[str, str], remote_repo_url: str, confirm: bool,
) -> None:
"""
Confirm input parameters with the user.
"""Confirm input parameters with the user.
Args:
env_vars (Dict[str, str]): Environment variables.
remote_repo_url (str): Remote repository URL.
confirm (bool): Whether to confirm input.
"""
if not confirm:
input(f"Press Enter to continue with the following Env vars: {env_vars}")
input(
f"Press Enter to continue with the following remote repository: {remote_repo_url}\n"
f"Press Enter to continue with the following remote repository: {remote_repo_url}\n",
)
def _create_payload(
self,
env_vars: Dict[str, str],
env_vars: dict[str, str],
remote_repo_url: str,
) -> Dict[str, Any]:
"""
Create the payload for crew creation.
) -> dict[str, Any]:
"""Create the payload for crew creation.
Args:
remote_repo_url (str): Remote repository URL.
@@ -140,25 +130,26 @@ class DeployCommand(BaseCommand, PlusAPIMixin):
Returns:
Dict[str, Any]: The payload for crew creation.
"""
return {
"deploy": {
"name": self.project_name,
"repo_clone_url": remote_repo_url,
"env": env_vars,
}
},
}
def _display_creation_success(self, json_response: Dict[str, Any]) -> None:
"""
Display success message after crew creation.
def _display_creation_success(self, json_response: dict[str, Any]) -> None:
"""Display success message after crew creation.
Args:
json_response (Dict[str, Any]): The response containing crew information.
"""
console.print("Deployment created successfully!\n", style="bold green")
console.print(
f"Name: {self.project_name} ({json_response['uuid']})", style="bold green"
f"Name: {self.project_name} ({json_response['uuid']})", style="bold green",
)
console.print(f"Status: {json_response['status']}", style="bold green")
console.print("\nTo (re)deploy the crew, run:")
@@ -167,9 +158,7 @@ class DeployCommand(BaseCommand, PlusAPIMixin):
console.print(f"crewai deploy push --uuid {json_response['uuid']}")
def list_crews(self) -> None:
"""
List all available crews.
"""
"""List all available crews."""
console.print("Listing all Crews\n", style="bold blue")
response = self.plus_api_client.list_crews()
@@ -179,31 +168,29 @@ class DeployCommand(BaseCommand, PlusAPIMixin):
else:
self._display_no_crews_message()
def _display_crews(self, crews_data: List[Dict[str, Any]]) -> None:
"""
Display the list of crews.
def _display_crews(self, crews_data: list[dict[str, Any]]) -> None:
"""Display the list of crews.
Args:
crews_data (List[Dict[str, Any]]): List of crew data to display.
"""
for crew_data in crews_data:
console.print(
f"- {crew_data['name']} ({crew_data['uuid']}) [blue]{crew_data['status']}[/blue]"
f"- {crew_data['name']} ({crew_data['uuid']}) [blue]{crew_data['status']}[/blue]",
)
def _display_no_crews_message(self) -> None:
"""
Display a message when no crews are available.
"""
"""Display a message when no crews are available."""
console.print("You don't have any Crews yet. Let's create one!", style="yellow")
console.print(" crewai create crew <crew_name>", style="green")
def get_crew_status(self, uuid: Optional[str] = None) -> None:
"""
Get the status of a crew.
def get_crew_status(self, uuid: str | None = None) -> None:
"""Get the status of a crew.
Args:
uuid (Optional[str]): The UUID of the crew to check.
"""
console.print("Fetching deployment status...", style="bold blue")
if uuid:
@@ -217,23 +204,23 @@ class DeployCommand(BaseCommand, PlusAPIMixin):
self._validate_response(response)
self._display_crew_status(response.json())
def _display_crew_status(self, status_data: Dict[str, str]) -> None:
"""
Display the status of a crew.
def _display_crew_status(self, status_data: dict[str, str]) -> None:
"""Display the status of a crew.
Args:
status_data (Dict[str, str]): The status data to display.
"""
console.print(f"Name:\t {status_data['name']}")
console.print(f"Status:\t {status_data['status']}")
def get_crew_logs(self, uuid: Optional[str], log_type: str = "deployment") -> None:
"""
Get logs for a crew.
def get_crew_logs(self, uuid: str | None, log_type: str = "deployment") -> None:
"""Get logs for a crew.
Args:
uuid (Optional[str]): The UUID of the crew to get logs for.
log_type (str): The type of logs to retrieve (default: "deployment").
"""
self._get_crew_logs_span = self._telemetry.get_crew_logs_span(uuid, log_type)
console.print(f"Fetching {log_type} logs...", style="bold blue")
@@ -249,12 +236,12 @@ class DeployCommand(BaseCommand, PlusAPIMixin):
self._validate_response(response)
self._display_logs(response.json())
def remove_crew(self, uuid: Optional[str]) -> None:
"""
Remove a crew deployment.
def remove_crew(self, uuid: str | None) -> None:
"""Remove a crew deployment.
Args:
uuid (Optional[str]): The UUID of the crew to remove.
"""
self._remove_crew_span = self._telemetry.remove_crew_span(uuid)
console.print("Removing deployment...", style="bold blue")
@@ -269,9 +256,9 @@ class DeployCommand(BaseCommand, PlusAPIMixin):
if response.status_code == 204:
console.print(
f"Crew '{self.project_name}' removed successfully.", style="green"
f"Crew '{self.project_name}' removed successfully.", style="green",
)
else:
console.print(
f"Failed to remove crew '{self.project_name}'", style="bold red"
f"Failed to remove crew '{self.project_name}'", style="bold red",
)

View File

@@ -4,18 +4,19 @@ import click
def evaluate_crew(n_iterations: int, model: str) -> None:
"""
Test and Evaluate the crew by running a command in the UV environment.
"""Test and Evaluate the crew by running a command in the UV environment.
Args:
n_iterations (int): The number of iterations to test the crew.
model (str): The model to test the crew with.
"""
command = ["uv", "run", "test", str(n_iterations), model]
try:
if n_iterations <= 0:
raise ValueError("The number of iterations must be a positive integer.")
msg = "The number of iterations must be a positive integer."
raise ValueError(msg)
result = subprocess.run(command, capture_output=False, text=True, check=True)

View File

@@ -1,16 +1,18 @@
import subprocess
from functools import lru_cache
from functools import cache
class Repository:
def __init__(self, path="."):
def __init__(self, path=".") -> None:
self.path = path
if not self.is_git_installed():
raise ValueError("Git is not installed or not found in your PATH.")
msg = "Git is not installed or not found in your PATH."
raise ValueError(msg)
if not self.is_git_repo():
raise ValueError(f"{self.path} is not a Git repository.")
msg = f"{self.path} is not a Git repository."
raise ValueError(msg)
self.fetch()
@@ -18,7 +20,7 @@ class Repository:
"""Check if Git is installed and available in the system."""
try:
subprocess.run(
["git", "--version"], capture_output=True, check=True, text=True
["git", "--version"], capture_output=True, check=True, text=True,
)
return True
except (subprocess.CalledProcessError, FileNotFoundError):
@@ -36,7 +38,7 @@ class Repository:
encoding="utf-8",
).strip()
@lru_cache(maxsize=None)
@cache
def is_git_repo(self) -> bool:
"""Check if the current directory is a git repository."""
try:
@@ -62,10 +64,7 @@ class Repository:
def is_synced(self) -> bool:
"""Return True if the Git repository is fully synced with the remote, False otherwise."""
if self.has_uncommitted_changes() or self.is_ahead_or_behind():
return False
else:
return True
return not (self.has_uncommitted_changes() or self.is_ahead_or_behind())
def origin_url(self) -> str | None:
"""Get the Git repository's remote URL."""

View File

@@ -8,11 +8,9 @@ import click
# so if you expect this to support more things you will need to replicate it there
# ask @joaomdmoura if you are unsure
def install_crew(proxy_options: list[str]) -> None:
"""
Install the crew by running the UV command to lock and install.
"""
"""Install the crew by running the UV command to lock and install."""
try:
command = ["uv", "sync"] + proxy_options
command = ["uv", "sync", *proxy_options]
subprocess.run(command, check=True, capture_output=False, text=True)
except subprocess.CalledProcessError as e:

View File

@@ -4,9 +4,7 @@ import click
def kickoff_flow() -> None:
"""
Kickoff the flow by running a command in the UV environment.
"""
"""Kickoff the flow by running a command in the UV environment."""
command = ["uv", "run", "kickoff"]
try:

View File

@@ -4,9 +4,7 @@ import click
def plot_flow() -> None:
"""
Plot the flow by running a command in the UV environment.
"""
"""Plot the flow by running a command in the UV environment."""
command = ["uv", "run", "plot"]
try:

View File

@@ -1,5 +1,4 @@
from os import getenv
from typing import Optional
from urllib.parse import urljoin
import requests
@@ -8,9 +7,7 @@ from crewai.cli.version import get_crewai_version
class PlusAPI:
"""
This class exposes methods for working with the CrewAI+ API.
"""
"""This class exposes methods for working with the CrewAI+ API."""
TOOLS_RESOURCE = "/crewai_plus/api/v1/tools"
CREWS_RESOURCE = "/crewai_plus/api/v1/crews"
@@ -42,7 +39,7 @@ class PlusAPI:
handle: str,
is_public: bool,
version: str,
description: Optional[str],
description: str | None,
encoded_file: str,
):
params = {
@@ -56,7 +53,7 @@ class PlusAPI:
def deploy_by_name(self, project_name: str) -> requests.Response:
return self._make_request(
"POST", f"{self.CREWS_RESOURCE}/by-name/{project_name}/deploy"
"POST", f"{self.CREWS_RESOURCE}/by-name/{project_name}/deploy",
)
def deploy_by_uuid(self, uuid: str) -> requests.Response:
@@ -64,29 +61,29 @@ class PlusAPI:
def crew_status_by_name(self, project_name: str) -> requests.Response:
return self._make_request(
"GET", f"{self.CREWS_RESOURCE}/by-name/{project_name}/status"
"GET", f"{self.CREWS_RESOURCE}/by-name/{project_name}/status",
)
def crew_status_by_uuid(self, uuid: str) -> requests.Response:
return self._make_request("GET", f"{self.CREWS_RESOURCE}/{uuid}/status")
def crew_by_name(
self, project_name: str, log_type: str = "deployment"
self, project_name: str, log_type: str = "deployment",
) -> requests.Response:
return self._make_request(
"GET", f"{self.CREWS_RESOURCE}/by-name/{project_name}/logs/{log_type}"
"GET", f"{self.CREWS_RESOURCE}/by-name/{project_name}/logs/{log_type}",
)
def crew_by_uuid(
self, uuid: str, log_type: str = "deployment"
self, uuid: str, log_type: str = "deployment",
) -> requests.Response:
return self._make_request(
"GET", f"{self.CREWS_RESOURCE}/{uuid}/logs/{log_type}"
"GET", f"{self.CREWS_RESOURCE}/{uuid}/logs/{log_type}",
)
def delete_crew_by_name(self, project_name: str) -> requests.Response:
return self._make_request(
"DELETE", f"{self.CREWS_RESOURCE}/by-name/{project_name}"
"DELETE", f"{self.CREWS_RESOURCE}/by-name/{project_name}",
)
def delete_crew_by_uuid(self, uuid: str) -> requests.Response:

View File

@@ -10,8 +10,7 @@ from crewai.cli.constants import JSON_URL, MODELS, PROVIDERS
def select_choice(prompt_message, choices):
"""
Presents a list of choices to the user and prompts them to select one.
"""Presents a list of choices to the user and prompts them to select one.
Args:
- prompt_message (str): The message to display to the user before presenting the choices.
@@ -19,11 +18,11 @@ def select_choice(prompt_message, choices):
Returns:
- str: The selected choice from the list, or None if the user chooses to quit.
"""
"""
provider_models = get_provider_data()
if not provider_models:
return
return None
click.secho(prompt_message, fg="cyan")
for idx, choice in enumerate(choices, start=1):
click.secho(f"{idx}. {choice}", fg="cyan")
@@ -31,7 +30,7 @@ def select_choice(prompt_message, choices):
while True:
choice = click.prompt(
"Enter the number of your choice or 'q' to quit", type=str
"Enter the number of your choice or 'q' to quit", type=str,
)
if choice.lower() == "q":
@@ -51,8 +50,7 @@ def select_choice(prompt_message, choices):
def select_provider(provider_models):
"""
Presents a list of providers to the user and prompts them to select one.
"""Presents a list of providers to the user and prompts them to select one.
Args:
- provider_models (dict): A dictionary of provider models.
@@ -60,12 +58,13 @@ def select_provider(provider_models):
Returns:
- str: The selected provider
- None: If user explicitly quits
"""
predefined_providers = [p.lower() for p in PROVIDERS]
all_providers = sorted(set(predefined_providers + list(provider_models.keys())))
provider = select_choice(
"Select a provider to set up:", predefined_providers + ["other"]
"Select a provider to set up:", [*predefined_providers, "other"],
)
if provider is None: # User typed 'q'
return None
@@ -79,8 +78,7 @@ def select_provider(provider_models):
def select_model(provider, provider_models):
"""
Presents a list of models for a given provider to the user and prompts them to select one.
"""Presents a list of models for a given provider to the user and prompts them to select one.
Args:
- provider (str): The provider for which to select a model.
@@ -88,6 +86,7 @@ def select_model(provider, provider_models):
Returns:
- str: The selected model, or None if the operation is aborted or an invalid selection is made.
"""
predefined_providers = [p.lower() for p in PROVIDERS]
@@ -100,15 +99,13 @@ def select_model(provider, provider_models):
click.secho(f"No models available for provider '{provider}'.", fg="red")
return None
selected_model = select_choice(
f"Select a model to use for {provider.capitalize()}:", available_models
return select_choice(
f"Select a model to use for {provider.capitalize()}:", available_models,
)
return selected_model
def load_provider_data(cache_file, cache_expiry):
"""
Loads provider data from a cache file if it exists and is not expired. If the cache is expired or corrupted, it fetches the data from the web.
"""Loads provider data from a cache file if it exists and is not expired. If the cache is expired or corrupted, it fetches the data from the web.
Args:
- cache_file (Path): The path to the cache file.
@@ -116,6 +113,7 @@ def load_provider_data(cache_file, cache_expiry):
Returns:
- dict or None: The loaded provider data or None if the operation fails.
"""
current_time = time.time()
if (
@@ -126,7 +124,7 @@ def load_provider_data(cache_file, cache_expiry):
if data:
return data
click.secho(
"Cache is corrupted. Fetching provider data from the web...", fg="yellow"
"Cache is corrupted. Fetching provider data from the web...", fg="yellow",
)
else:
click.secho(
@@ -137,31 +135,31 @@ def load_provider_data(cache_file, cache_expiry):
def read_cache_file(cache_file):
"""
Reads and returns the JSON content from a cache file. Returns None if the file contains invalid JSON.
"""Reads and returns the JSON content from a cache file. Returns None if the file contains invalid JSON.
Args:
- cache_file (Path): The path to the cache file.
Returns:
- dict or None: The JSON content of the cache file or None if the JSON is invalid.
"""
try:
with open(cache_file, "r") as f:
with open(cache_file) as f:
return json.load(f)
except json.JSONDecodeError:
return None
def fetch_provider_data(cache_file):
"""
Fetches provider data from a specified URL and caches it to a file.
"""Fetches provider data from a specified URL and caches it to a file.
Args:
- cache_file (Path): The path to the cache file.
Returns:
- dict or None: The fetched provider data or None if the operation fails.
"""
try:
response = requests.get(JSON_URL, stream=True, timeout=60)
@@ -178,20 +176,20 @@ def fetch_provider_data(cache_file):
def download_data(response):
"""
Downloads data from a given HTTP response and returns the JSON content.
"""Downloads data from a given HTTP response and returns the JSON content.
Args:
- response (requests.Response): The HTTP response object.
Returns:
- dict: The JSON content of the response.
"""
total_size = int(response.headers.get("content-length", 0))
block_size = 8192
data_chunks = []
with click.progressbar(
length=total_size, label="Downloading", show_pos=True
length=total_size, label="Downloading", show_pos=True,
) as progress_bar:
for chunk in response.iter_content(block_size):
if chunk:
@@ -202,11 +200,11 @@ def download_data(response):
def get_provider_data():
"""
Retrieves provider data from a cache file, filters out models based on provider criteria, and returns a dictionary of providers mapped to their models.
"""Retrieves provider data from a cache file, filters out models based on provider criteria, and returns a dictionary of providers mapped to their models.
Returns:
- dict or None: A dictionary of providers mapped to their models or None if the operation fails.
"""
cache_dir = Path.home() / ".crewai"
cache_dir.mkdir(exist_ok=True)

View File

@@ -4,11 +4,11 @@ import click
def replay_task_command(task_id: str) -> None:
"""
Replay the crew execution from a specific task.
"""Replay the crew execution from a specific task.
Args:
task_id (str): The ID of the task to replay from.
"""
command = ["uv", "run", "replay", task_id]

View File

@@ -13,8 +13,7 @@ def reset_memories_command(
kickoff_outputs,
all,
) -> None:
"""
Reset the crew memories.
"""Reset the crew memories.
Args:
long (bool): Whether to reset the long-term memory.
@@ -23,49 +22,50 @@ def reset_memories_command(
kickoff_outputs (bool): Whether to reset the latest kickoff task outputs.
all (bool): Whether to reset all memories.
knowledge (bool): Whether to reset the knowledge.
"""
"""
try:
if not any([long, short, entity, kickoff_outputs, knowledge, all]):
click.echo(
"No memory type specified. Please specify at least one type to reset."
"No memory type specified. Please specify at least one type to reset.",
)
return
crews = get_crews()
if not crews:
raise ValueError("No crew found.")
msg = "No crew found."
raise ValueError(msg)
for crew in crews:
if all:
crew.reset_memories(command_type="all")
click.echo(
f"[Crew ({crew.name if crew.name else crew.id})] Reset memories command has been completed."
f"[Crew ({crew.name if crew.name else crew.id})] Reset memories command has been completed.",
)
continue
if long:
crew.reset_memories(command_type="long")
click.echo(
f"[Crew ({crew.name if crew.name else crew.id})] Long term memory has been reset."
f"[Crew ({crew.name if crew.name else crew.id})] Long term memory has been reset.",
)
if short:
crew.reset_memories(command_type="short")
click.echo(
f"[Crew ({crew.name if crew.name else crew.id})] Short term memory has been reset."
f"[Crew ({crew.name if crew.name else crew.id})] Short term memory has been reset.",
)
if entity:
crew.reset_memories(command_type="entity")
click.echo(
f"[Crew ({crew.name if crew.name else crew.id})] Entity memory has been reset."
f"[Crew ({crew.name if crew.name else crew.id})] Entity memory has been reset.",
)
if kickoff_outputs:
crew.reset_memories(command_type="kickoff_outputs")
click.echo(
f"[Crew ({crew.name if crew.name else crew.id})] Latest Kickoff outputs stored has been reset."
f"[Crew ({crew.name if crew.name else crew.id})] Latest Kickoff outputs stored has been reset.",
)
if knowledge:
crew.reset_memories(command_type="knowledge")
click.echo(
f"[Crew ({crew.name if crew.name else crew.id})] Knowledge has been reset."
f"[Crew ({crew.name if crew.name else crew.id})] Knowledge has been reset.",
)
except subprocess.CalledProcessError as e:

View File

@@ -1,6 +1,5 @@
import subprocess
from enum import Enum
from typing import List, Optional
import click
from packaging import version
@@ -15,8 +14,7 @@ class CrewType(Enum):
def run_crew() -> None:
"""
Run the crew or flow by running a command in the UV environment.
"""Run the crew or flow by running a command in the UV environment.
Starting from version 0.103.0, this command can be used to run both
standard crews and flows. For flows, it detects the type from pyproject.toml
@@ -48,11 +46,11 @@ def run_crew() -> None:
def execute_command(crew_type: CrewType) -> None:
"""
Execute the appropriate command based on crew type.
"""Execute the appropriate command based on crew type.
Args:
crew_type: The type of crew to run
"""
command = ["uv", "run", "kickoff" if crew_type == CrewType.FLOW else "run_crew"]
@@ -67,12 +65,12 @@ def execute_command(crew_type: CrewType) -> None:
def handle_error(error: subprocess.CalledProcessError, crew_type: CrewType) -> None:
"""
Handle subprocess errors with appropriate messaging.
"""Handle subprocess errors with appropriate messaging.
Args:
error: The subprocess error that occurred
crew_type: The type of crew that was being run
"""
entity_type = "flow" if crew_type == CrewType.FLOW else "crew"
click.echo(f"An error occurred while running the {entity_type}: {error}", err=True)

View File

@@ -22,15 +22,13 @@ console = Console()
class ToolCommand(BaseCommand, PlusAPIMixin):
"""
A class to handle tool repository related operations for CrewAI projects.
"""
"""A class to handle tool repository related operations for CrewAI projects."""
def __init__(self):
def __init__(self) -> None:
BaseCommand.__init__(self)
PlusAPIMixin.__init__(self, telemetry=self._telemetry)
def create(self, handle: str):
def create(self, handle: str) -> None:
self._ensure_not_in_project()
folder_name = handle.replace(" ", "_").replace("-", "_").lower()
@@ -40,8 +38,7 @@ class ToolCommand(BaseCommand, PlusAPIMixin):
if project_root.exists():
click.secho(f"Folder {folder_name} already exists.", fg="red")
raise SystemExit
else:
os.makedirs(project_root)
os.makedirs(project_root)
click.secho(f"Creating custom tool {folder_name}...", fg="green", bold=True)
@@ -56,12 +53,12 @@ class ToolCommand(BaseCommand, PlusAPIMixin):
self.login()
subprocess.run(["git", "init"], check=True)
console.print(
f"[green]Created custom tool [bold]{folder_name}[/bold]. Run [bold]cd {project_root}[/bold] to start working.[/green]"
f"[green]Created custom tool [bold]{folder_name}[/bold]. Run [bold]cd {project_root}[/bold] to start working.[/green]",
)
finally:
os.chdir(old_directory)
def publish(self, is_public: bool, force: bool = False):
def publish(self, is_public: bool, force: bool = False) -> None:
if not git.Repository().is_synced() and not force:
console.print(
"[bold red]Failed to publish tool.[/bold red]\n"
@@ -69,9 +66,9 @@ class ToolCommand(BaseCommand, PlusAPIMixin):
"* [bold]Commit[/bold] your changes.\n"
"* [bold]Push[/bold] to sync with the remote.\n"
"* [bold]Pull[/bold] the latest changes from the remote.\n"
"\nOnce your repository is up-to-date, retry publishing the tool."
"\nOnce your repository is up-to-date, retry publishing the tool.",
)
raise SystemExit()
raise SystemExit
project_name = get_project_name(require=True)
assert isinstance(project_name, str)
@@ -90,7 +87,7 @@ class ToolCommand(BaseCommand, PlusAPIMixin):
)
tarball_filename = next(
(f for f in os.listdir(temp_build_dir) if f.endswith(".tar.gz")), None
(f for f in os.listdir(temp_build_dir) if f.endswith(".tar.gz")), None,
)
if not tarball_filename:
console.print(
@@ -123,7 +120,7 @@ class ToolCommand(BaseCommand, PlusAPIMixin):
style="bold green",
)
def install(self, handle: str):
def install(self, handle: str) -> None:
get_response = self.plus_api_client.get_tool(handle)
if get_response.status_code == 404:
@@ -132,9 +129,9 @@ class ToolCommand(BaseCommand, PlusAPIMixin):
style="bold red",
)
raise SystemExit
elif get_response.status_code != 200:
if get_response.status_code != 200:
console.print(
"Failed to get tool details. Please try again later.", style="bold red"
"Failed to get tool details. Please try again later.", style="bold red",
)
raise SystemExit
@@ -142,7 +139,7 @@ class ToolCommand(BaseCommand, PlusAPIMixin):
console.print(f"Successfully installed {handle}", style="bold green")
def login(self):
def login(self) -> None:
login_response = self.plus_api_client.login_to_tool_repository()
if login_response.status_code != 200:
@@ -164,10 +161,10 @@ class ToolCommand(BaseCommand, PlusAPIMixin):
settings.dump()
console.print(
"Successfully authenticated to the tool repository.", style="bold green"
"Successfully authenticated to the tool repository.", style="bold green",
)
def _add_package(self, tool_details):
def _add_package(self, tool_details) -> None:
tool_handle = tool_details["handle"]
repository_handle = tool_details["repository"]["handle"]
repository_url = tool_details["repository"]["url"]
@@ -192,16 +189,16 @@ class ToolCommand(BaseCommand, PlusAPIMixin):
click.echo(add_package_result.stderr, err=True)
raise SystemExit
def _ensure_not_in_project(self):
def _ensure_not_in_project(self) -> None:
if os.path.isfile("./pyproject.toml"):
console.print(
"[bold red]Oops! It looks like you're inside a project.[/bold red]"
"[bold red]Oops! It looks like you're inside a project.[/bold red]",
)
console.print(
"You can't create a new tool while inside an existing project."
"You can't create a new tool while inside an existing project.",
)
console.print(
"[bold yellow]Tip:[/bold yellow] Navigate to a different directory and try again."
"[bold yellow]Tip:[/bold yellow] Navigate to a different directory and try again.",
)
raise SystemExit
@@ -211,10 +208,10 @@ class ToolCommand(BaseCommand, PlusAPIMixin):
env = os.environ.copy()
env[f"UV_INDEX_{repository_handle}_USERNAME"] = str(
settings.tool_repository_username or ""
settings.tool_repository_username or "",
)
env[f"UV_INDEX_{repository_handle}_PASSWORD"] = str(
settings.tool_repository_password or ""
settings.tool_repository_password or "",
)
return env

View File

@@ -4,20 +4,22 @@ import click
def train_crew(n_iterations: int, filename: str) -> None:
"""
Train the crew by running a command in the UV environment.
"""Train the crew by running a command in the UV environment.
Args:
n_iterations (int): The number of iterations to train the crew.
"""
command = ["uv", "run", "train", str(n_iterations), filename]
try:
if n_iterations <= 0:
raise ValueError("The number of iterations must be a positive integer.")
msg = "The number of iterations must be a positive integer."
raise ValueError(msg)
if not filename.endswith(".pkl"):
raise ValueError("The filename must not end with .pkl")
msg = "The filename must not end with .pkl"
raise ValueError(msg)
result = subprocess.run(command, capture_output=False, text=True, check=True)

View File

@@ -11,9 +11,8 @@ def update_crew() -> None:
migrate_pyproject("pyproject.toml", "pyproject.toml")
def migrate_pyproject(input_file, output_file):
"""
Migrate the pyproject.toml to the new format.
def migrate_pyproject(input_file, output_file) -> None:
"""Migrate the pyproject.toml to the new format.
This function is used to migrate the pyproject.toml to the new format.
And it will be used to migrate the pyproject.toml to the new format when uv is used.
@@ -81,7 +80,7 @@ def migrate_pyproject(input_file, output_file):
# Extract the module name from any existing script
existing_scripts = new_pyproject["project"]["scripts"]
module_name = next(
(value.split(".")[0] for value in existing_scripts.values() if "." in value)
value.split(".")[0] for value in existing_scripts.values() if "." in value
)
new_pyproject["project"]["scripts"]["run_crew"] = f"{module_name}.main:run"
@@ -93,22 +92,19 @@ def migrate_pyproject(input_file, output_file):
# Backup the old pyproject.toml
backup_file = "pyproject-old.toml"
shutil.copy2(input_file, backup_file)
print(f"Original pyproject.toml backed up as {backup_file}")
# Rename the poetry.lock file
lock_file = "poetry.lock"
lock_backup = "poetry-old.lock"
if os.path.exists(lock_file):
os.rename(lock_file, lock_backup)
print(f"Original poetry.lock renamed to {lock_backup}")
else:
print("No poetry.lock file found to rename.")
pass
# Write the new pyproject.toml
with open(output_file, "wb") as f:
tomli_w.dump(new_pyproject, f)
print(f"Migration complete. New pyproject.toml written to {output_file}")
def parse_version(version: str) -> str:

View File

@@ -3,7 +3,7 @@ import shutil
import sys
from functools import reduce
from inspect import isfunction, ismethod
from typing import Any, Dict, List, get_type_hints
from typing import Any, get_type_hints
import click
import tomli
@@ -19,9 +19,9 @@ if sys.version_info >= (3, 11):
console = Console()
def copy_template(src, dst, name, class_name, folder_name):
def copy_template(src, dst, name, class_name, folder_name) -> None:
"""Copy a file from src to dst."""
with open(src, "r") as file:
with open(src) as file:
content = file.read()
# Interpolate the content
@@ -39,8 +39,7 @@ def copy_template(src, dst, name, class_name, folder_name):
def read_toml(file_path: str = "pyproject.toml"):
"""Read the content of a TOML file and return it as a dictionary."""
with open(file_path, "rb") as f:
toml_dict = tomli.load(f)
return toml_dict
return tomli.load(f)
def parse_toml(content):
@@ -50,59 +49,56 @@ def parse_toml(content):
def get_project_name(
pyproject_path: str = "pyproject.toml", require: bool = False
pyproject_path: str = "pyproject.toml", require: bool = False,
) -> str | None:
"""Get the project name from the pyproject.toml file."""
return _get_project_attribute(pyproject_path, ["project", "name"], require=require)
def get_project_version(
pyproject_path: str = "pyproject.toml", require: bool = False
pyproject_path: str = "pyproject.toml", require: bool = False,
) -> str | None:
"""Get the project version from the pyproject.toml file."""
return _get_project_attribute(
pyproject_path, ["project", "version"], require=require
pyproject_path, ["project", "version"], require=require,
)
def get_project_description(
pyproject_path: str = "pyproject.toml", require: bool = False
pyproject_path: str = "pyproject.toml", require: bool = False,
) -> str | None:
"""Get the project description from the pyproject.toml file."""
return _get_project_attribute(
pyproject_path, ["project", "description"], require=require
pyproject_path, ["project", "description"], require=require,
)
def _get_project_attribute(
pyproject_path: str, keys: List[str], require: bool
pyproject_path: str, keys: list[str], require: bool,
) -> Any | None:
"""Get an attribute from the pyproject.toml file."""
attribute = None
try:
with open(pyproject_path, "r") as f:
with open(pyproject_path) as f:
pyproject_content = parse_toml(f.read())
dependencies = (
_get_nested_value(pyproject_content, ["project", "dependencies"]) or []
)
if not any(True for dep in dependencies if "crewai" in dep):
raise Exception("crewai is not in the dependencies.")
msg = "crewai is not in the dependencies."
raise Exception(msg)
attribute = _get_nested_value(pyproject_content, keys)
except FileNotFoundError:
print(f"Error: {pyproject_path} not found.")
pass
except KeyError:
print(f"Error: {pyproject_path} is not a valid pyproject.toml file.")
except tomllib.TOMLDecodeError if sys.version_info >= (3, 11) else Exception as e: # type: ignore
print(
f"Error: {pyproject_path} is not a valid TOML file."
if sys.version_info >= (3, 11)
else f"Error reading the pyproject.toml file: {e}"
)
except Exception as e:
print(f"Error reading the pyproject.toml file: {e}")
pass
except tomllib.TOMLDecodeError if sys.version_info >= (3, 11) else Exception: # type: ignore
pass
except Exception:
pass
if require and not attribute:
console.print(
@@ -114,7 +110,7 @@ def _get_project_attribute(
return attribute
def _get_nested_value(data: Dict[str, Any], keys: List[str]) -> Any:
def _get_nested_value(data: dict[str, Any], keys: list[str]) -> Any:
return reduce(dict.__getitem__, keys, data)
@@ -122,7 +118,7 @@ def fetch_and_json_env_file(env_file_path: str = ".env") -> dict:
"""Fetch the environment variables from a .env file and return them as a dictionary."""
try:
# Read the .env file
with open(env_file_path, "r") as f:
with open(env_file_path) as f:
env_content = f.read()
# Parse the .env file content to a dictionary
@@ -135,14 +131,14 @@ def fetch_and_json_env_file(env_file_path: str = ".env") -> dict:
return env_dict
except FileNotFoundError:
print(f"Error: {env_file_path} not found.")
except Exception as e:
print(f"Error reading the .env file: {e}")
pass
except Exception:
pass
return {}
def tree_copy(source, destination):
def tree_copy(source, destination) -> None:
"""Copies the entire directory structure from the source to the destination."""
for item in os.listdir(source):
source_item = os.path.join(source, item)
@@ -153,7 +149,7 @@ def tree_copy(source, destination):
shutil.copy2(source_item, destination_item)
def tree_find_and_replace(directory, find, replace):
def tree_find_and_replace(directory, find, replace) -> None:
"""Recursively searches through a directory, replacing a target string in
both file contents and filenames with a specified replacement string.
"""
@@ -161,7 +157,7 @@ def tree_find_and_replace(directory, find, replace):
for filename in files:
filepath = os.path.join(path, filename)
with open(filepath, "r") as file:
with open(filepath) as file:
contents = file.read()
with open(filepath, "w") as file:
file.write(contents.replace(find, replace))
@@ -180,19 +176,19 @@ def tree_find_and_replace(directory, find, replace):
def load_env_vars(folder_path):
"""
Loads environment variables from a .env file in the specified folder path.
"""Loads environment variables from a .env file in the specified folder path.
Args:
- folder_path (Path): The path to the folder containing the .env file.
Returns:
- dict: A dictionary of environment variables.
"""
env_file_path = folder_path / ".env"
env_vars = {}
if env_file_path.exists():
with open(env_file_path, "r") as file:
with open(env_file_path) as file:
for line in file:
key, _, value = line.strip().partition("=")
if key and value:
@@ -201,8 +197,7 @@ def load_env_vars(folder_path):
def update_env_vars(env_vars, provider, model):
"""
Updates environment variables with the API key for the selected provider and model.
"""Updates environment variables with the API key for the selected provider and model.
Args:
- env_vars (dict): Environment variables dictionary.
@@ -211,6 +206,7 @@ def update_env_vars(env_vars, provider, model):
Returns:
- None
"""
api_key_var = ENV_VARS.get(
provider,
@@ -218,14 +214,14 @@ def update_env_vars(env_vars, provider, model):
click.prompt(
f"Enter the environment variable name for your {provider.capitalize()} API key",
type=str,
)
),
],
)[0]
if api_key_var not in env_vars:
try:
env_vars[api_key_var] = click.prompt(
f"Enter your {provider.capitalize()} API key", type=str, hide_input=True
f"Enter your {provider.capitalize()} API key", type=str, hide_input=True,
)
except click.exceptions.Abort:
click.secho("Operation aborted by the user.", fg="red")
@@ -238,13 +234,13 @@ def update_env_vars(env_vars, provider, model):
return env_vars
def write_env_file(folder_path, env_vars):
"""
Writes environment variables to a .env file in the specified folder.
def write_env_file(folder_path, env_vars) -> None:
"""Writes environment variables to a .env file in the specified folder.
Args:
- folder_path (Path): The path to the folder where the .env file will be written.
- env_vars (dict): A dictionary of environment variables to write.
"""
env_file_path = folder_path / ".env"
with open(env_file_path, "w") as file:
@@ -263,7 +259,7 @@ def get_crews(crew_path: str = "crew.py", require: bool = False) -> list[Crew]:
crew_os_path = os.path.join(root, crew_path)
try:
spec = importlib.util.spec_from_file_location(
"crew_module", crew_os_path
"crew_module", crew_os_path,
)
if not spec or not spec.loader:
continue
@@ -277,19 +273,16 @@ def get_crews(crew_path: str = "crew.py", require: bool = False) -> list[Crew]:
try:
crew_instances.extend(fetch_crews(module_attr))
except Exception as e:
print(f"Error processing attribute {attr_name}: {e}")
except Exception:
continue
except Exception as exec_error:
print(f"Error executing module: {exec_error}")
except Exception:
import traceback
print(f"Traceback: {traceback.format_exc()}")
except (ImportError, AttributeError) as e:
if require:
console.print(
f"Error importing crew from {crew_path}: {str(e)}",
f"Error importing crew from {crew_path}: {e!s}",
style="bold red",
)
continue
@@ -303,7 +296,7 @@ def get_crews(crew_path: str = "crew.py", require: bool = False) -> list[Crew]:
except Exception as e:
if require:
console.print(
f"Unexpected error while loading crew: {str(e)}", style="bold red"
f"Unexpected error while loading crew: {e!s}", style="bold red",
)
raise SystemExit
return crew_instances
@@ -317,13 +310,12 @@ def get_crew_instance(module_attr) -> Crew | None:
):
return module_attr().crew()
if (ismethod(module_attr) or isfunction(module_attr)) and get_type_hints(
module_attr
module_attr,
).get("return") is Crew:
return module_attr()
elif isinstance(module_attr, Crew):
if isinstance(module_attr, Crew):
return module_attr
else:
return None
return None
def fetch_crews(module_attr) -> list[Crew]:

View File

@@ -2,5 +2,5 @@ import importlib.metadata
def get_crewai_version() -> str:
"""Get the version number of CrewAI running the CLI"""
"""Get the version number of CrewAI running the CLI."""
return importlib.metadata.version("crewai")

File diff suppressed because it is too large Load Diff

View File

@@ -1,5 +1,5 @@
import json
from typing import Any, Dict, Optional
from typing import Any
from pydantic import BaseModel, Field
@@ -12,27 +12,28 @@ class CrewOutput(BaseModel):
"""Class that represents the result of a crew."""
raw: str = Field(description="Raw output of crew", default="")
pydantic: Optional[BaseModel] = Field(
description="Pydantic output of Crew", default=None
pydantic: BaseModel | None = Field(
description="Pydantic output of Crew", default=None,
)
json_dict: Optional[Dict[str, Any]] = Field(
description="JSON dict output of Crew", default=None
json_dict: dict[str, Any] | None = Field(
description="JSON dict output of Crew", default=None,
)
tasks_output: list[TaskOutput] = Field(
description="Output of each task", default=[]
description="Output of each task", default=[],
)
token_usage: UsageMetrics = Field(description="Processed token summary", default={})
@property
def json(self) -> Optional[str]:
def json(self) -> str | None:
if self.tasks_output[-1].output_format != OutputFormat.JSON:
msg = "No JSON output found in the final task. Please make sure to set the output_json property in the final task in your crew."
raise ValueError(
"No JSON output found in the final task. Please make sure to set the output_json property in the final task in your crew."
msg,
)
return json.dumps(self.json_dict)
def to_dict(self) -> Dict[str, Any]:
def to_dict(self) -> dict[str, Any]:
"""Convert json_output and pydantic_output to a dictionary."""
output_dict = {}
if self.json_dict:
@@ -44,12 +45,12 @@ class CrewOutput(BaseModel):
def __getitem__(self, key):
if self.pydantic and hasattr(self.pydantic, key):
return getattr(self.pydantic, key)
elif self.json_dict and key in self.json_dict:
if self.json_dict and key in self.json_dict:
return self.json_dict[key]
else:
raise KeyError(f"Key '{key}' not found in CrewOutput.")
msg = f"Key '{key}' not found in CrewOutput."
raise KeyError(msg)
def __str__(self):
def __str__(self) -> str:
if self.pydantic:
return str(self.pydantic)
if self.json_dict:

View File

@@ -2,17 +2,11 @@ import asyncio
import copy
import inspect
import logging
from collections.abc import Callable
from typing import (
Any,
Callable,
Dict,
Generic,
List,
Optional,
Set,
Type,
TypeVar,
Union,
cast,
)
from uuid import uuid4
@@ -48,14 +42,14 @@ class FlowState(BaseModel):
# Type variables with explicit bounds
T = TypeVar(
"T", bound=Union[Dict[str, Any], BaseModel]
"T", bound=dict[str, Any] | BaseModel,
) # Generic flow state type parameter
StateT = TypeVar(
"StateT", bound=Union[Dict[str, Any], BaseModel]
"StateT", bound=dict[str, Any] | BaseModel,
) # State validation type parameter
def ensure_state_type(state: Any, expected_type: Type[StateT]) -> StateT:
def ensure_state_type(state: Any, expected_type: type[StateT]) -> StateT:
"""Ensure state matches expected type with proper validation.
Args:
@@ -68,6 +62,7 @@ def ensure_state_type(state: Any, expected_type: Type[StateT]) -> StateT:
Raises:
TypeError: If state doesn't match expected type
ValueError: If state validation fails
"""
"""Ensure state matches expected type with proper validation.
@@ -84,20 +79,22 @@ def ensure_state_type(state: Any, expected_type: Type[StateT]) -> StateT:
"""
if expected_type is dict:
if not isinstance(state, dict):
raise TypeError(f"Expected dict, got {type(state).__name__}")
return cast(StateT, state)
msg = f"Expected dict, got {type(state).__name__}"
raise TypeError(msg)
return cast("StateT", state)
if isinstance(expected_type, type) and issubclass(expected_type, BaseModel):
if not isinstance(state, expected_type):
msg = f"Expected {expected_type.__name__}, got {type(state).__name__}"
raise TypeError(
f"Expected {expected_type.__name__}, got {type(state).__name__}"
msg,
)
return cast(StateT, state)
raise TypeError(f"Invalid expected_type: {expected_type}")
return cast("StateT", state)
msg = f"Invalid expected_type: {expected_type}"
raise TypeError(msg)
def start(condition: Optional[Union[str, dict, Callable]] = None) -> Callable:
"""
Marks a method as a flow's starting point.
def start(condition: str | dict | Callable | None = None) -> Callable:
"""Marks a method as a flow's starting point.
This decorator designates a method as an entry point for the flow execution.
It can optionally specify conditions that trigger the start based on other
@@ -135,6 +132,7 @@ def start(condition: Optional[Union[str, dict, Callable]] = None) -> Callable:
>>> @start(and_("method1", "method2")) # Start after multiple methods
>>> def complex_start(self):
... pass
"""
def decorator(func):
@@ -154,17 +152,17 @@ def start(condition: Optional[Union[str, dict, Callable]] = None) -> Callable:
func.__trigger_methods__ = [condition.__name__]
func.__condition_type__ = "OR"
else:
msg = "Condition must be a method, string, or a result of or_() or and_()"
raise ValueError(
"Condition must be a method, string, or a result of or_() or and_()"
msg,
)
return func
return decorator
def listen(condition: Union[str, dict, Callable]) -> Callable:
"""
Creates a listener that executes when specified conditions are met.
def listen(condition: str | dict | Callable) -> Callable:
"""Creates a listener that executes when specified conditions are met.
This decorator sets up a method to execute in response to other method
executions in the flow. It supports both simple and complex triggering
@@ -197,6 +195,7 @@ def listen(condition: Union[str, dict, Callable]) -> Callable:
>>> @listen(or_("success", "failure")) # Listen to multiple methods
>>> def handle_completion(self):
... pass
"""
def decorator(func):
@@ -214,17 +213,17 @@ def listen(condition: Union[str, dict, Callable]) -> Callable:
func.__trigger_methods__ = [condition.__name__]
func.__condition_type__ = "OR"
else:
msg = "Condition must be a method, string, or a result of or_() or and_()"
raise ValueError(
"Condition must be a method, string, or a result of or_() or and_()"
msg,
)
return func
return decorator
def router(condition: Union[str, dict, Callable]) -> Callable:
"""
Creates a routing method that directs flow execution based on conditions.
def router(condition: str | dict | Callable) -> Callable:
"""Creates a routing method that directs flow execution based on conditions.
This decorator marks a method as a router, which can dynamically determine
the next steps in the flow based on its return value. Routers are triggered
@@ -262,6 +261,7 @@ def router(condition: Union[str, dict, Callable]) -> Callable:
... if all([self.state.valid, self.state.processed]):
... return CONTINUE
... return STOP
"""
def decorator(func):
@@ -280,17 +280,17 @@ def router(condition: Union[str, dict, Callable]) -> Callable:
func.__trigger_methods__ = [condition.__name__]
func.__condition_type__ = "OR"
else:
msg = "Condition must be a method, string, or a result of or_() or and_()"
raise ValueError(
"Condition must be a method, string, or a result of or_() or and_()"
msg,
)
return func
return decorator
def or_(*conditions: Union[str, dict, Callable]) -> dict:
"""
Combines multiple conditions with OR logic for flow control.
def or_(*conditions: str | dict | Callable) -> dict:
"""Combines multiple conditions with OR logic for flow control.
Creates a condition that is satisfied when any of the specified conditions
are met. This is used with @start, @listen, or @router decorators to create
@@ -320,6 +320,7 @@ def or_(*conditions: Union[str, dict, Callable]) -> dict:
>>> @listen(or_("success", "timeout"))
>>> def handle_completion(self):
... pass
"""
methods = []
for condition in conditions:
@@ -330,13 +331,13 @@ def or_(*conditions: Union[str, dict, Callable]) -> dict:
elif callable(condition):
methods.append(getattr(condition, "__name__", repr(condition)))
else:
raise ValueError("Invalid condition in or_()")
msg = "Invalid condition in or_()"
raise ValueError(msg)
return {"type": "OR", "methods": methods}
def and_(*conditions: Union[str, dict, Callable]) -> dict:
"""
Combines multiple conditions with AND logic for flow control.
def and_(*conditions: str | dict | Callable) -> dict:
"""Combines multiple conditions with AND logic for flow control.
Creates a condition that is satisfied only when all specified conditions
are met. This is used with @start, @listen, or @router decorators to create
@@ -366,6 +367,7 @@ def and_(*conditions: Union[str, dict, Callable]) -> dict:
>>> @listen(and_("validated", "processed"))
>>> def handle_complete_data(self):
... pass
"""
methods = []
for condition in conditions:
@@ -376,7 +378,8 @@ def and_(*conditions: Union[str, dict, Callable]) -> dict:
elif callable(condition):
methods.append(getattr(condition, "__name__", repr(condition)))
else:
raise ValueError("Invalid condition in and_()")
msg = "Invalid condition in and_()"
raise ValueError(msg)
return {"type": "AND", "methods": methods}
@@ -416,10 +419,10 @@ class FlowMeta(type):
if possible_returns:
router_paths[attr_name] = possible_returns
setattr(cls, "_start_methods", start_methods)
setattr(cls, "_listeners", listeners)
setattr(cls, "_routers", routers)
setattr(cls, "_router_paths", router_paths)
cls._start_methods = start_methods
cls._listeners = listeners
cls._routers = routers
cls._router_paths = router_paths
return cls
@@ -427,17 +430,18 @@ class FlowMeta(type):
class Flow(Generic[T], metaclass=FlowMeta):
"""Base class for all flows.
Type parameter T must be either Dict[str, Any] or a subclass of BaseModel."""
Type parameter T must be either Dict[str, Any] or a subclass of BaseModel.
"""
_printer = Printer()
_start_methods: List[str] = []
_listeners: Dict[str, tuple[str, List[str]]] = {}
_routers: Set[str] = set()
_router_paths: Dict[str, List[str]] = {}
initial_state: Union[Type[T], T, None] = None
_start_methods: list[str] = []
_listeners: dict[str, tuple[str, list[str]]] = {}
_routers: set[str] = set()
_router_paths: dict[str, list[str]] = {}
initial_state: type[T] | T | None = None
def __class_getitem__(cls: Type["Flow"], item: Type[T]) -> Type["Flow"]:
def __class_getitem__(cls: type["Flow"], item: type[T]) -> type["Flow"]:
class _FlowGeneric(cls): # type: ignore
_initial_state_T = item # type: ignore
@@ -446,7 +450,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
def __init__(
self,
persistence: Optional[FlowPersistence] = None,
persistence: FlowPersistence | None = None,
**kwargs: Any,
) -> None:
"""Initialize a new Flow instance.
@@ -454,13 +458,14 @@ class Flow(Generic[T], metaclass=FlowMeta):
Args:
persistence: Optional persistence backend for storing flow states
**kwargs: Additional state values to initialize or override
"""
# Initialize basic instance attributes
self._methods: Dict[str, Callable] = {}
self._method_execution_counts: Dict[str, int] = {}
self._pending_and_listeners: Dict[str, Set[str]] = {}
self._method_outputs: List[Any] = [] # List to store all method outputs
self._persistence: Optional[FlowPersistence] = persistence
self._methods: dict[str, Callable] = {}
self._method_execution_counts: dict[str, int] = {}
self._pending_and_listeners: dict[str, set[str]] = {}
self._method_outputs: list[Any] = [] # List to store all method outputs
self._persistence: FlowPersistence | None = persistence
# Initialize state with initial values
self._state = self._create_initial_state()
@@ -502,58 +507,61 @@ class Flow(Generic[T], metaclass=FlowMeta):
Raises:
ValueError: If structured state model lacks 'id' field
TypeError: If state is neither BaseModel nor dictionary
"""
# Handle case where initial_state is None but we have a type parameter
if self.initial_state is None and hasattr(self, "_initial_state_T"):
state_type = getattr(self, "_initial_state_T")
state_type = self._initial_state_T
if isinstance(state_type, type):
if issubclass(state_type, FlowState):
# Create instance without id, then set it
instance = state_type()
if not hasattr(instance, "id"):
setattr(instance, "id", str(uuid4()))
return cast(T, instance)
elif issubclass(state_type, BaseModel):
instance.id = str(uuid4())
return cast("T", instance)
if issubclass(state_type, BaseModel):
# Create a new type that includes the ID field
class StateWithId(state_type, FlowState): # type: ignore
pass
instance = StateWithId()
if not hasattr(instance, "id"):
setattr(instance, "id", str(uuid4()))
return cast(T, instance)
elif state_type is dict:
return cast(T, {"id": str(uuid4())})
instance.id = str(uuid4())
return cast("T", instance)
if state_type is dict:
return cast("T", {"id": str(uuid4())})
# Handle case where no initial state is provided
if self.initial_state is None:
return cast(T, {"id": str(uuid4())})
return cast("T", {"id": str(uuid4())})
# Handle case where initial_state is a type (class)
if isinstance(self.initial_state, type):
if issubclass(self.initial_state, FlowState):
return cast(T, self.initial_state()) # Uses model defaults
elif issubclass(self.initial_state, BaseModel):
return cast("T", self.initial_state()) # Uses model defaults
if issubclass(self.initial_state, BaseModel):
# Validate that the model has an id field
model_fields = getattr(self.initial_state, "model_fields", None)
if not model_fields or "id" not in model_fields:
raise ValueError("Flow state model must have an 'id' field")
return cast(T, self.initial_state()) # Uses model defaults
elif self.initial_state is dict:
return cast(T, {"id": str(uuid4())})
msg = "Flow state model must have an 'id' field"
raise ValueError(msg)
return cast("T", self.initial_state()) # Uses model defaults
if self.initial_state is dict:
return cast("T", {"id": str(uuid4())})
# Handle dictionary instance case
if isinstance(self.initial_state, dict):
new_state = dict(self.initial_state) # Copy to avoid mutations
if "id" not in new_state:
new_state["id"] = str(uuid4())
return cast(T, new_state)
return cast("T", new_state)
# Handle BaseModel instance case
if isinstance(self.initial_state, BaseModel):
model = cast(BaseModel, self.initial_state)
model = cast("BaseModel", self.initial_state)
if not hasattr(model, "id"):
raise ValueError("Flow state model must have an 'id' field")
msg = "Flow state model must have an 'id' field"
raise ValueError(msg)
# Create new instance with same values to avoid mutations
if hasattr(model, "model_dump"):
@@ -570,9 +578,10 @@ class Flow(Generic[T], metaclass=FlowMeta):
# Create new instance of the same class
model_class = type(model)
return cast(T, model_class(**state_dict))
return cast("T", model_class(**state_dict))
msg = f"Initial state must be dict or BaseModel, got {type(self.initial_state)}"
raise TypeError(
f"Initial state must be dict or BaseModel, got {type(self.initial_state)}"
msg,
)
def _copy_state(self) -> T:
@@ -583,7 +592,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
return self._state
@property
def method_outputs(self) -> List[Any]:
def method_outputs(self) -> list[Any]:
"""Returns the list of all outputs from executed methods."""
return self._method_outputs
@@ -607,6 +616,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
flow = MyFlow()
print(f"Current flow ID: {flow.flow_id}") # Safely get flow ID
```
"""
try:
if not hasattr(self, "_state"):
@@ -614,13 +624,13 @@ class Flow(Generic[T], metaclass=FlowMeta):
if isinstance(self._state, dict):
return str(self._state.get("id", ""))
elif isinstance(self._state, BaseModel):
if isinstance(self._state, BaseModel):
return str(getattr(self._state, "id", ""))
return ""
except (AttributeError, TypeError):
return "" # Safely handle any unexpected attribute access issues
def _initialize_state(self, inputs: Dict[str, Any]) -> None:
def _initialize_state(self, inputs: dict[str, Any]) -> None:
"""Initialize or update flow state with new inputs.
Args:
@@ -629,6 +639,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
Raises:
ValueError: If validation fails for structured state
TypeError: If state is neither BaseModel nor dictionary
"""
if isinstance(self._state, dict):
# For dict states, preserve existing fields unless overridden
@@ -644,7 +655,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
elif isinstance(self._state, BaseModel):
# For BaseModel states, preserve existing fields unless overridden
try:
model = cast(BaseModel, self._state)
model = cast("BaseModel", self._state)
# Get current state as dict
if hasattr(model, "model_dump"):
current_state = model.model_dump()
@@ -662,19 +673,21 @@ class Flow(Generic[T], metaclass=FlowMeta):
model_class = type(model)
if hasattr(model_class, "model_validate"):
# Pydantic v2
self._state = cast(T, model_class.model_validate(new_state))
self._state = cast("T", model_class.model_validate(new_state))
elif hasattr(model_class, "parse_obj"):
# Pydantic v1
self._state = cast(T, model_class.parse_obj(new_state))
self._state = cast("T", model_class.parse_obj(new_state))
else:
# Fallback for other BaseModel implementations
self._state = cast(T, model_class(**new_state))
self._state = cast("T", model_class(**new_state))
except ValidationError as e:
raise ValueError(f"Invalid inputs for structured state: {e}") from e
msg = f"Invalid inputs for structured state: {e}"
raise ValueError(msg) from e
else:
raise TypeError("State must be a BaseModel instance or a dictionary.")
msg = "State must be a BaseModel instance or a dictionary."
raise TypeError(msg)
def _restore_state(self, stored_state: Dict[str, Any]) -> None:
def _restore_state(self, stored_state: dict[str, Any]) -> None:
"""Restore flow state from persistence.
Args:
@@ -683,11 +696,13 @@ class Flow(Generic[T], metaclass=FlowMeta):
Raises:
ValueError: If validation fails for structured state
TypeError: If state is neither BaseModel nor dictionary
"""
# When restoring from persistence, use the stored ID
stored_id = stored_state.get("id")
if not stored_id:
raise ValueError("Stored state must have an 'id' field")
msg = "Stored state must have an 'id' field"
raise ValueError(msg)
if isinstance(self._state, dict):
# For dict states, update all fields from stored state
@@ -695,22 +710,22 @@ class Flow(Generic[T], metaclass=FlowMeta):
self._state.update(stored_state)
elif isinstance(self._state, BaseModel):
# For BaseModel states, create new instance with stored values
model = cast(BaseModel, self._state)
model = cast("BaseModel", self._state)
if hasattr(model, "model_validate"):
# Pydantic v2
self._state = cast(T, type(model).model_validate(stored_state))
self._state = cast("T", type(model).model_validate(stored_state))
elif hasattr(model, "parse_obj"):
# Pydantic v1
self._state = cast(T, type(model).parse_obj(stored_state))
self._state = cast("T", type(model).parse_obj(stored_state))
else:
# Fallback for other BaseModel implementations
self._state = cast(T, type(model)(**stored_state))
self._state = cast("T", type(model)(**stored_state))
else:
raise TypeError(f"State must be dict or BaseModel, got {type(self._state)}")
msg = f"State must be dict or BaseModel, got {type(self._state)}"
raise TypeError(msg)
def kickoff(self, inputs: Optional[Dict[str, Any]] = None) -> Any:
"""
Start the flow execution in a synchronous context.
def kickoff(self, inputs: dict[str, Any] | None = None) -> Any:
"""Start the flow execution in a synchronous context.
This method wraps kickoff_async so that all state initialization and event
emission is handled in the asynchronous method.
@@ -721,9 +736,8 @@ class Flow(Generic[T], metaclass=FlowMeta):
return asyncio.run(run_flow())
async def kickoff_async(self, inputs: Optional[Dict[str, Any]] = None) -> Any:
"""
Start the flow execution asynchronously.
async def kickoff_async(self, inputs: dict[str, Any] | None = None) -> Any:
"""Start the flow execution asynchronously.
This method performs state restoration (if an 'id' is provided and persistence is available)
and updates the flow state with any additional inputs. It then emits the FlowStartedEvent,
@@ -735,6 +749,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
Returns:
The final output from the flow, which is the result of the last executed method.
"""
if inputs:
# Override the id in the state if it exists in inputs
@@ -742,7 +757,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
if isinstance(self._state, dict):
self._state["id"] = inputs["id"]
elif isinstance(self._state, BaseModel):
setattr(self._state, "id", inputs["id"])
self._state.id = inputs["id"]
# If persistence is enabled, attempt to restore the stored state using the provided id.
if "id" in inputs and self._persistence is not None:
@@ -756,7 +771,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
self._restore_state(stored_state)
else:
self._log_flow_event(
f"No flow state found for UUID: {restore_uuid}", color="red"
f"No flow state found for UUID: {restore_uuid}", color="red",
)
# Update state with any additional inputs (ignoring the 'id' key)
@@ -774,7 +789,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
),
)
self._log_flow_event(
f"Flow started with ID: {self.flow_id}", color="bold_magenta"
f"Flow started with ID: {self.flow_id}", color="bold_magenta",
)
if inputs is not None and "id" not in inputs:
@@ -800,8 +815,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
return final_output
async def _execute_start_method(self, start_method_name: str) -> None:
"""
Executes a flow's start method and its triggered listeners.
"""Executes a flow's start method and its triggered listeners.
This internal method handles the execution of methods marked with @start
decorator and manages the subsequent chain of listener executions.
@@ -816,14 +830,15 @@ class Flow(Generic[T], metaclass=FlowMeta):
- Executes the start method and captures its result
- Triggers execution of any listeners waiting on this start method
- Part of the flow's initialization sequence
"""
result = await self._execute_method(
start_method_name, self._methods[start_method_name]
start_method_name, self._methods[start_method_name],
)
await self._execute_listeners(start_method_name, result)
async def _execute_method(
self, method_name: str, method: Callable, *args: Any, **kwargs: Any
self, method_name: str, method: Callable, *args: Any, **kwargs: Any,
) -> Any:
try:
dumped_params = {f"_{i}": arg for i, arg in enumerate(args)} | (
@@ -873,11 +888,10 @@ class Flow(Generic[T], metaclass=FlowMeta):
error=e,
),
)
raise e
raise
async def _execute_listeners(self, trigger_method: str, result: Any) -> None:
"""
Executes all listeners and routers triggered by a method completion.
"""Executes all listeners and routers triggered by a method completion.
This internal method manages the execution flow by:
1. First executing all triggered routers sequentially
@@ -897,6 +911,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
- Each router's result becomes a new trigger_method
- Normal listeners are executed in parallel for efficiency
- Listeners can receive the trigger method's result as a parameter
"""
# First, handle routers repeatedly until no router triggers anymore
router_results = []
@@ -904,7 +919,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
while True:
routers_triggered = self._find_triggered_methods(
current_trigger, router_only=True
current_trigger, router_only=True,
)
if not routers_triggered:
break
@@ -920,12 +935,12 @@ class Flow(Generic[T], metaclass=FlowMeta):
)
# Now execute normal listeners for all router results and the original trigger
all_triggers = [trigger_method] + router_results
all_triggers = [trigger_method, *router_results]
for current_trigger in all_triggers:
if current_trigger: # Skip None results
listeners_triggered = self._find_triggered_methods(
current_trigger, router_only=False
current_trigger, router_only=False,
)
if listeners_triggered:
tasks = [
@@ -935,10 +950,9 @@ class Flow(Generic[T], metaclass=FlowMeta):
await asyncio.gather(*tasks)
def _find_triggered_methods(
self, trigger_method: str, router_only: bool
) -> List[str]:
"""
Finds all methods that should be triggered based on conditions.
self, trigger_method: str, router_only: bool,
) -> list[str]:
"""Finds all methods that should be triggered based on conditions.
This internal method evaluates both OR and AND conditions to determine
which methods should be executed next in the flow.
@@ -963,6 +977,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
* AND: Triggers only when all conditions are met
- Maintains state for AND conditions using _pending_and_listeners
- Separates router and normal listener evaluation
"""
triggered = []
for listener_name, (condition_type, methods) in self._listeners.items():
@@ -992,8 +1007,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
return triggered
async def _execute_single_listener(self, listener_name: str, result: Any) -> None:
"""
Executes a single listener method with proper event handling.
"""Executes a single listener method with proper event handling.
This internal method manages the execution of an individual listener,
including parameter inspection, event emission, and error handling.
@@ -1018,6 +1032,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
-------------
Catches and logs any exceptions during execution, preventing
individual listener failures from breaking the entire flow.
"""
try:
method = self._methods[listener_name]
@@ -1028,7 +1043,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
if method_params:
listener_result = await self._execute_method(
listener_name, method, result
listener_name, method, result,
)
else:
listener_result = await self._execute_method(listener_name, method)
@@ -1036,17 +1051,14 @@ class Flow(Generic[T], metaclass=FlowMeta):
# Execute listeners (and possibly routers) of this listener
await self._execute_listeners(listener_name, listener_result)
except Exception as e:
print(
f"[Flow._execute_single_listener] Error in method {listener_name}: {e}"
)
except Exception:
import traceback
traceback.print_exc()
raise
def _log_flow_event(
self, message: str, color: str = "yellow", level: str = "info"
self, message: str, color: str = "yellow", level: str = "info",
) -> None:
"""Centralized logging method for flow events.
@@ -1064,6 +1076,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
Note:
This method uses the Printer utility for colored console output
and the standard logging module for log level support.
"""
self._printer.print(message, color=color)
if level == "info":

View File

@@ -1,5 +1,4 @@
import inspect
from typing import Optional
from pydantic import BaseModel, Field, InstanceOf, model_validator
@@ -14,7 +13,7 @@ class FlowTrackable(BaseModel):
inspecting the call stack.
"""
parent_flow: Optional[InstanceOf[Flow]] = Field(
parent_flow: InstanceOf[Flow] | None = Field(
default=None,
description="The parent flow of the instance, if it was created inside a flow.",
)

View File

@@ -1,14 +1,13 @@
# flow_visualizer.py
import os
from pathlib import Path
from pyvis.network import Network
from crewai.flow.config import COLORS, NODE_STYLES
from crewai.flow.html_template_handler import HTMLTemplateHandler
from crewai.flow.legend_generator import generate_legend_items_html, get_legend_items
from crewai.flow.path_utils import safe_path_join, validate_path_exists
from crewai.flow.path_utils import safe_path_join
from crewai.flow.utils import calculate_node_levels
from crewai.flow.visualization_utils import (
add_edges,
@@ -20,9 +19,8 @@ from crewai.flow.visualization_utils import (
class FlowPlot:
"""Handles the creation and rendering of flow visualization diagrams."""
def __init__(self, flow):
"""
Initialize FlowPlot with a flow object.
def __init__(self, flow) -> None:
"""Initialize FlowPlot with a flow object.
Parameters
----------
@@ -33,21 +31,24 @@ class FlowPlot:
------
ValueError
If flow object is invalid or missing required attributes.
"""
if not hasattr(flow, '_methods'):
raise ValueError("Invalid flow object: missing '_methods' attribute")
if not hasattr(flow, '_listeners'):
raise ValueError("Invalid flow object: missing '_listeners' attribute")
if not hasattr(flow, '_start_methods'):
raise ValueError("Invalid flow object: missing '_start_methods' attribute")
if not hasattr(flow, "_methods"):
msg = "Invalid flow object: missing '_methods' attribute"
raise ValueError(msg)
if not hasattr(flow, "_listeners"):
msg = "Invalid flow object: missing '_listeners' attribute"
raise ValueError(msg)
if not hasattr(flow, "_start_methods"):
msg = "Invalid flow object: missing '_start_methods' attribute"
raise ValueError(msg)
self.flow = flow
self.colors = COLORS
self.node_styles = NODE_STYLES
def plot(self, filename):
"""
Generate and save an HTML visualization of the flow.
def plot(self, filename) -> None:
"""Generate and save an HTML visualization of the flow.
Parameters
----------
@@ -62,10 +63,12 @@ class FlowPlot:
If file operations fail or visualization cannot be generated.
RuntimeError
If network visualization generation fails.
"""
if not filename or not isinstance(filename, str):
raise ValueError("Filename must be a non-empty string")
msg = "Filename must be a non-empty string"
raise ValueError(msg)
try:
# Initialize network
net = Network(
@@ -89,58 +92,63 @@ class FlowPlot:
"enabled": false
}
}
"""
""",
)
# Calculate levels for nodes
try:
node_levels = calculate_node_levels(self.flow)
except Exception as e:
raise ValueError(f"Failed to calculate node levels: {str(e)}")
msg = f"Failed to calculate node levels: {e!s}"
raise ValueError(msg)
# Compute positions
try:
node_positions = compute_positions(self.flow, node_levels)
except Exception as e:
raise ValueError(f"Failed to compute node positions: {str(e)}")
msg = f"Failed to compute node positions: {e!s}"
raise ValueError(msg)
# Add nodes to the network
try:
add_nodes_to_network(net, self.flow, node_positions, self.node_styles)
except Exception as e:
raise RuntimeError(f"Failed to add nodes to network: {str(e)}")
msg = f"Failed to add nodes to network: {e!s}"
raise RuntimeError(msg)
# Add edges to the network
try:
add_edges(net, self.flow, node_positions, self.colors)
except Exception as e:
raise RuntimeError(f"Failed to add edges to network: {str(e)}")
msg = f"Failed to add edges to network: {e!s}"
raise RuntimeError(msg)
# Generate HTML
try:
network_html = net.generate_html()
final_html_content = self._generate_final_html(network_html)
except Exception as e:
raise RuntimeError(f"Failed to generate network visualization: {str(e)}")
msg = f"Failed to generate network visualization: {e!s}"
raise RuntimeError(msg)
# Save the final HTML content to the file
try:
with open(f"{filename}.html", "w", encoding="utf-8") as f:
f.write(final_html_content)
print(f"Plot saved as {filename}.html")
except IOError as e:
raise IOError(f"Failed to save flow visualization to {filename}.html: {str(e)}")
except OSError as e:
msg = f"Failed to save flow visualization to {filename}.html: {e!s}"
raise OSError(msg)
except (ValueError, RuntimeError, IOError) as e:
raise e
except (OSError, ValueError, RuntimeError):
raise
except Exception as e:
raise RuntimeError(f"Unexpected error during flow visualization: {str(e)}")
msg = f"Unexpected error during flow visualization: {e!s}"
raise RuntimeError(msg)
finally:
self._cleanup_pyvis_lib()
def _generate_final_html(self, network_html):
"""
Generate the final HTML content with network visualization and legend.
"""Generate the final HTML content with network visualization and legend.
Parameters
----------
@@ -158,9 +166,11 @@ class FlowPlot:
If template or logo files cannot be accessed.
ValueError
If network_html is invalid.
"""
if not network_html:
raise ValueError("Invalid network HTML content")
msg = "Invalid network HTML content"
raise ValueError(msg)
try:
# Extract just the body content from the generated HTML
@@ -169,9 +179,11 @@ class FlowPlot:
logo_path = safe_path_join("assets", "crewai_logo.svg", root=current_dir)
if not os.path.exists(template_path):
raise IOError(f"Template file not found: {template_path}")
msg = f"Template file not found: {template_path}"
raise OSError(msg)
if not os.path.exists(logo_path):
raise IOError(f"Logo file not found: {logo_path}")
msg = f"Logo file not found: {logo_path}"
raise OSError(msg)
html_handler = HTMLTemplateHandler(template_path, logo_path)
network_body = html_handler.extract_body_content(network_html)
@@ -179,16 +191,15 @@ class FlowPlot:
# Generate the legend items HTML
legend_items = get_legend_items(self.colors)
legend_items_html = generate_legend_items_html(legend_items)
final_html_content = html_handler.generate_final_html(
network_body, legend_items_html
return html_handler.generate_final_html(
network_body, legend_items_html,
)
return final_html_content
except Exception as e:
raise IOError(f"Failed to generate visualization HTML: {str(e)}")
msg = f"Failed to generate visualization HTML: {e!s}"
raise OSError(msg)
def _cleanup_pyvis_lib(self):
"""
Clean up the generated lib folder from pyvis.
def _cleanup_pyvis_lib(self) -> None:
"""Clean up the generated lib folder from pyvis.
This method safely removes the temporary lib directory created by pyvis
during network visualization generation.
@@ -198,15 +209,14 @@ class FlowPlot:
if os.path.exists(lib_folder) and os.path.isdir(lib_folder):
import shutil
shutil.rmtree(lib_folder)
except ValueError as e:
print(f"Error validating lib folder path: {e}")
except Exception as e:
print(f"Error cleaning up lib folder: {e}")
except ValueError:
pass
except Exception:
pass
def plot_flow(flow, filename="flow_plot"):
"""
Convenience function to create and save a flow visualization.
def plot_flow(flow, filename="flow_plot") -> None:
"""Convenience function to create and save a flow visualization.
Parameters
----------
@@ -221,6 +231,7 @@ def plot_flow(flow, filename="flow_plot"):
If flow object or filename is invalid.
IOError
If file operations fail.
"""
visualizer = FlowPlot(flow)
visualizer.plot(filename)

View File

@@ -1,16 +1,14 @@
import base64
import re
from pathlib import Path
from crewai.flow.path_utils import safe_path_join, validate_path_exists
from crewai.flow.path_utils import validate_path_exists
class HTMLTemplateHandler:
"""Handles HTML template processing and generation for flow visualization diagrams."""
def __init__(self, template_path, logo_path):
"""
Initialize HTMLTemplateHandler with validated template and logo paths.
def __init__(self, template_path, logo_path) -> None:
"""Initialize HTMLTemplateHandler with validated template and logo paths.
Parameters
----------
@@ -23,16 +21,18 @@ class HTMLTemplateHandler:
------
ValueError
If template or logo paths are invalid or files don't exist.
"""
try:
self.template_path = validate_path_exists(template_path, "file")
self.logo_path = validate_path_exists(logo_path, "file")
except ValueError as e:
raise ValueError(f"Invalid template or logo path: {e}")
msg = f"Invalid template or logo path: {e}"
raise ValueError(msg)
def read_template(self):
"""Read and return the HTML template file contents."""
with open(self.template_path, "r", encoding="utf-8") as f:
with open(self.template_path, encoding="utf-8") as f:
return f.read()
def encode_logo(self):
@@ -81,13 +81,12 @@ class HTMLTemplateHandler:
final_html_content = html_template.replace("{{ title }}", title)
final_html_content = final_html_content.replace(
"{{ network_content }}", network_body
"{{ network_content }}", network_body,
)
final_html_content = final_html_content.replace(
"{{ logo_svg_base64 }}", logo_svg_base64
"{{ logo_svg_base64 }}", logo_svg_base64,
)
final_html_content = final_html_content.replace(
"<!-- LEGEND_ITEMS_PLACEHOLDER -->", legend_items_html
return final_html_content.replace(
"<!-- LEGEND_ITEMS_PLACEHOLDER -->", legend_items_html,
)
return final_html_content

View File

@@ -1,18 +1,14 @@
"""
Path utilities for secure file operations in CrewAI flow module.
"""Path utilities for secure file operations in CrewAI flow module.
This module provides utilities for secure path handling to prevent directory
traversal attacks and ensure paths remain within allowed boundaries.
"""
import os
from pathlib import Path
from typing import List, Union
def safe_path_join(*parts: str, root: Union[str, Path, None] = None) -> str:
"""
Safely join path components and ensure the result is within allowed boundaries.
def safe_path_join(*parts: str, root: str | Path | None = None) -> str:
"""Safely join path components and ensure the result is within allowed boundaries.
Parameters
----------
@@ -31,39 +27,43 @@ def safe_path_join(*parts: str, root: Union[str, Path, None] = None) -> str:
ValueError
If the resulting path would be outside the root directory
or if any path component is invalid.
"""
if not parts:
raise ValueError("No path components provided")
msg = "No path components provided"
raise ValueError(msg)
try:
# Convert all parts to strings and clean them
clean_parts = [str(part).strip() for part in parts if part]
if not clean_parts:
raise ValueError("No valid path components provided")
msg = "No valid path components provided"
raise ValueError(msg)
# Establish root directory
root_path = Path(root).resolve() if root else Path.cwd()
# Join and resolve the full path
full_path = Path(root_path, *clean_parts).resolve()
# Check if the resolved path is within root
if not str(full_path).startswith(str(root_path)):
msg = f"Invalid path: Potential directory traversal. Path must be within {root_path}"
raise ValueError(
f"Invalid path: Potential directory traversal. Path must be within {root_path}"
msg,
)
return str(full_path)
except Exception as e:
if isinstance(e, ValueError):
raise
raise ValueError(f"Invalid path components: {str(e)}")
msg = f"Invalid path components: {e!s}"
raise ValueError(msg)
def validate_path_exists(path: Union[str, Path], file_type: str = "file") -> str:
"""
Validate that a path exists and is of the expected type.
def validate_path_exists(path: str | Path, file_type: str = "file") -> str:
"""Validate that a path exists and is of the expected type.
Parameters
----------
@@ -81,29 +81,33 @@ def validate_path_exists(path: Union[str, Path], file_type: str = "file") -> str
------
ValueError
If path doesn't exist or is not of expected type.
"""
try:
path_obj = Path(path).resolve()
if not path_obj.exists():
raise ValueError(f"Path does not exist: {path}")
msg = f"Path does not exist: {path}"
raise ValueError(msg)
if file_type == "file" and not path_obj.is_file():
raise ValueError(f"Path is not a file: {path}")
elif file_type == "directory" and not path_obj.is_dir():
raise ValueError(f"Path is not a directory: {path}")
msg = f"Path is not a file: {path}"
raise ValueError(msg)
if file_type == "directory" and not path_obj.is_dir():
msg = f"Path is not a directory: {path}"
raise ValueError(msg)
return str(path_obj)
except Exception as e:
if isinstance(e, ValueError):
raise
raise ValueError(f"Invalid path: {str(e)}")
msg = f"Invalid path: {e!s}"
raise ValueError(msg)
def list_files(directory: Union[str, Path], pattern: str = "*") -> List[str]:
"""
Safely list files in a directory matching a pattern.
def list_files(directory: str | Path, pattern: str = "*") -> list[str]:
"""Safely list files in a directory matching a pattern.
Parameters
----------
@@ -121,15 +125,18 @@ def list_files(directory: Union[str, Path], pattern: str = "*") -> List[str]:
------
ValueError
If directory is invalid or inaccessible.
"""
try:
dir_path = Path(directory).resolve()
if not dir_path.is_dir():
raise ValueError(f"Not a directory: {directory}")
msg = f"Not a directory: {directory}"
raise ValueError(msg)
return [str(p) for p in dir_path.glob(pattern) if p.is_file()]
except Exception as e:
if isinstance(e, ValueError):
raise
raise ValueError(f"Error listing files: {str(e)}")
msg = f"Error listing files: {e!s}"
raise ValueError(msg)

View File

@@ -1,53 +1,52 @@
"""Base class for flow state persistence."""
import abc
from typing import Any, Dict, Optional, Union
from typing import Any
from pydantic import BaseModel
class FlowPersistence(abc.ABC):
"""Abstract base class for flow state persistence.
This class defines the interface that all persistence implementations must follow.
It supports both structured (Pydantic BaseModel) and unstructured (dict) states.
"""
@abc.abstractmethod
def init_db(self) -> None:
"""Initialize the persistence backend.
This method should handle any necessary setup, such as:
- Creating tables
- Establishing connections
- Setting up indexes
"""
pass
@abc.abstractmethod
def save_state(
self,
flow_uuid: str,
method_name: str,
state_data: Union[Dict[str, Any], BaseModel]
state_data: dict[str, Any] | BaseModel,
) -> None:
"""Persist the flow state after method completion.
Args:
flow_uuid: Unique identifier for the flow instance
method_name: Name of the method that just completed
state_data: Current state data (either dict or Pydantic model)
"""
pass
@abc.abstractmethod
def load_state(self, flow_uuid: str) -> Optional[Dict[str, Any]]:
def load_state(self, flow_uuid: str) -> dict[str, Any] | None:
"""Load the most recent state for a given flow UUID.
Args:
flow_uuid: Unique identifier for the flow instance
Returns:
The most recent state as a dictionary, or None if no state exists
"""
pass

View File

@@ -1,5 +1,4 @@
"""
Decorators for flow state persistence.
"""Decorators for flow state persistence.
Example:
```python
@@ -19,18 +18,16 @@ Example:
# Asynchronous method implementation
await some_async_operation()
```
"""
import asyncio
import functools
import logging
from collections.abc import Callable
from typing import (
Any,
Callable,
Optional,
Type,
TypeVar,
Union,
cast,
)
@@ -48,7 +45,7 @@ LOG_MESSAGES = {
"save_state": "Saving flow state to memory for ID: {}",
"save_error": "Failed to persist state for method {}: {}",
"state_missing": "Flow instance has no state",
"id_missing": "Flow state must have an 'id' field for persistence"
"id_missing": "Flow state must have an 'id' field for persistence",
}
@@ -74,20 +71,23 @@ class PersistenceDecorator:
ValueError: If flow has no state or state lacks an ID
RuntimeError: If state persistence fails
AttributeError: If flow instance lacks required state attributes
"""
try:
state = getattr(flow_instance, 'state', None)
state = getattr(flow_instance, "state", None)
if state is None:
raise ValueError("Flow instance has no state")
msg = "Flow instance has no state"
raise ValueError(msg)
flow_uuid: Optional[str] = None
flow_uuid: str | None = None
if isinstance(state, dict):
flow_uuid = state.get('id')
flow_uuid = state.get("id")
elif isinstance(state, BaseModel):
flow_uuid = getattr(state, 'id', None)
flow_uuid = getattr(state, "id", None)
if not flow_uuid:
raise ValueError("Flow state must have an 'id' field for persistence")
msg = "Flow state must have an 'id' field for persistence"
raise ValueError(msg)
# Log state saving only if verbose is True
if verbose:
@@ -103,21 +103,22 @@ class PersistenceDecorator:
except Exception as e:
error_msg = LOG_MESSAGES["save_error"].format(method_name, str(e))
cls._printer.print(error_msg, color="red")
logger.error(error_msg)
raise RuntimeError(f"State persistence failed: {str(e)}") from e
logger.exception(error_msg)
msg = f"State persistence failed: {e!s}"
raise RuntimeError(msg) from e
except AttributeError:
error_msg = LOG_MESSAGES["state_missing"]
cls._printer.print(error_msg, color="red")
logger.error(error_msg)
logger.exception(error_msg)
raise ValueError(error_msg)
except (TypeError, ValueError) as e:
error_msg = LOG_MESSAGES["id_missing"]
cls._printer.print(error_msg, color="red")
logger.error(error_msg)
logger.exception(error_msg)
raise ValueError(error_msg) from e
def persist(persistence: Optional[FlowPersistence] = None, verbose: bool = False):
def persist(persistence: FlowPersistence | None = None, verbose: bool = False):
"""Decorator to persist flow state.
This decorator can be applied at either the class level or method level.
@@ -143,22 +144,23 @@ def persist(persistence: Optional[FlowPersistence] = None, verbose: bool = False
@start()
def begin(self):
pass
"""
def decorator(target: Union[Type, Callable[..., T]]) -> Union[Type, Callable[..., T]]:
def decorator(target: type | Callable[..., T]) -> type | Callable[..., T]:
"""Decorator that handles both class and method decoration."""
actual_persistence = persistence or SQLiteFlowPersistence()
if isinstance(target, type):
# Class decoration
original_init = getattr(target, "__init__")
original_init = target.__init__
@functools.wraps(original_init)
def new_init(self: Any, *args: Any, **kwargs: Any) -> None:
if 'persistence' not in kwargs:
kwargs['persistence'] = actual_persistence
if "persistence" not in kwargs:
kwargs["persistence"] = actual_persistence
original_init(self, *args, **kwargs)
setattr(target, "__init__", new_init)
target.__init__ = new_init
# Store original methods to preserve their decorators
original_methods = {}
@@ -191,7 +193,7 @@ def persist(persistence: Optional[FlowPersistence] = None, verbose: bool = False
for attr in ["__is_start_method__", "__trigger_methods__", "__condition_type__", "__is_router__"]:
if hasattr(method, attr):
setattr(wrapped, attr, getattr(method, attr))
setattr(wrapped, "__is_flow_method__", True)
wrapped.__is_flow_method__ = True
# Update the class with the wrapped method
setattr(target, name, wrapped)
@@ -211,44 +213,42 @@ def persist(persistence: Optional[FlowPersistence] = None, verbose: bool = False
for attr in ["__is_start_method__", "__trigger_methods__", "__condition_type__", "__is_router__"]:
if hasattr(method, attr):
setattr(wrapped, attr, getattr(method, attr))
setattr(wrapped, "__is_flow_method__", True)
wrapped.__is_flow_method__ = True
# Update the class with the wrapped method
setattr(target, name, wrapped)
return target
else:
# Method decoration
method = target
setattr(method, "__is_flow_method__", True)
# Method decoration
method = target
method.__is_flow_method__ = True
if asyncio.iscoroutinefunction(method):
@functools.wraps(method)
async def method_async_wrapper(flow_instance: Any, *args: Any, **kwargs: Any) -> T:
method_coro = method(flow_instance, *args, **kwargs)
if asyncio.iscoroutine(method_coro):
result = await method_coro
else:
result = method_coro
PersistenceDecorator.persist_state(flow_instance, method.__name__, actual_persistence, verbose)
return result
if asyncio.iscoroutinefunction(method):
@functools.wraps(method)
async def method_async_wrapper(flow_instance: Any, *args: Any, **kwargs: Any) -> T:
method_coro = method(flow_instance, *args, **kwargs)
if asyncio.iscoroutine(method_coro):
result = await method_coro
else:
result = method_coro
PersistenceDecorator.persist_state(flow_instance, method.__name__, actual_persistence, verbose)
return result
for attr in ["__is_start_method__", "__trigger_methods__", "__condition_type__", "__is_router__"]:
if hasattr(method, attr):
setattr(method_async_wrapper, attr, getattr(method, attr))
setattr(method_async_wrapper, "__is_flow_method__", True)
return cast(Callable[..., T], method_async_wrapper)
else:
@functools.wraps(method)
def method_sync_wrapper(flow_instance: Any, *args: Any, **kwargs: Any) -> T:
result = method(flow_instance, *args, **kwargs)
PersistenceDecorator.persist_state(flow_instance, method.__name__, actual_persistence, verbose)
return result
for attr in ["__is_start_method__", "__trigger_methods__", "__condition_type__", "__is_router__"]:
if hasattr(method, attr):
setattr(method_async_wrapper, attr, getattr(method, attr))
method_async_wrapper.__is_flow_method__ = True
return cast("Callable[..., T]", method_async_wrapper)
@functools.wraps(method)
def method_sync_wrapper(flow_instance: Any, *args: Any, **kwargs: Any) -> T:
result = method(flow_instance, *args, **kwargs)
PersistenceDecorator.persist_state(flow_instance, method.__name__, actual_persistence, verbose)
return result
for attr in ["__is_start_method__", "__trigger_methods__", "__condition_type__", "__is_router__"]:
if hasattr(method, attr):
setattr(method_sync_wrapper, attr, getattr(method, attr))
setattr(method_sync_wrapper, "__is_flow_method__", True)
return cast(Callable[..., T], method_sync_wrapper)
for attr in ["__is_start_method__", "__trigger_methods__", "__condition_type__", "__is_router__"]:
if hasattr(method, attr):
setattr(method_sync_wrapper, attr, getattr(method, attr))
method_sync_wrapper.__is_flow_method__ = True
return cast("Callable[..., T]", method_sync_wrapper)
return decorator

View File

@@ -1,12 +1,10 @@
"""
SQLite-based implementation of flow state persistence.
"""
"""SQLite-based implementation of flow state persistence."""
import json
import sqlite3
from datetime import datetime, timezone
from pathlib import Path
from typing import Any, Dict, Optional, Union
from typing import Any
from pydantic import BaseModel
@@ -23,7 +21,7 @@ class SQLiteFlowPersistence(FlowPersistence):
db_path: str
def __init__(self, db_path: Optional[str] = None):
def __init__(self, db_path: str | None = None) -> None:
"""Initialize SQLite persistence.
Args:
@@ -32,6 +30,7 @@ class SQLiteFlowPersistence(FlowPersistence):
Raises:
ValueError: If db_path is invalid
"""
from crewai.utilities.paths import db_storage_path
@@ -39,7 +38,8 @@ class SQLiteFlowPersistence(FlowPersistence):
path = db_path or str(Path(db_storage_path()) / "flow_states.db")
if not path:
raise ValueError("Database path must be provided")
msg = "Database path must be provided"
raise ValueError(msg)
self.db_path = path # Now mypy knows this is str
self.init_db()
@@ -56,21 +56,21 @@ class SQLiteFlowPersistence(FlowPersistence):
timestamp DATETIME NOT NULL,
state_json TEXT NOT NULL
)
"""
""",
)
# Add index for faster UUID lookups
conn.execute(
"""
CREATE INDEX IF NOT EXISTS idx_flow_states_uuid
ON flow_states(flow_uuid)
"""
""",
)
def save_state(
self,
flow_uuid: str,
method_name: str,
state_data: Union[Dict[str, Any], BaseModel],
state_data: dict[str, Any] | BaseModel,
) -> None:
"""Save the current flow state to SQLite.
@@ -78,6 +78,7 @@ class SQLiteFlowPersistence(FlowPersistence):
flow_uuid: Unique identifier for the flow instance
method_name: Name of the method that just completed
state_data: Current state data (either dict or Pydantic model)
"""
# Convert state_data to dict, handling both Pydantic and dict cases
if isinstance(state_data, BaseModel):
@@ -85,8 +86,9 @@ class SQLiteFlowPersistence(FlowPersistence):
elif isinstance(state_data, dict):
state_dict = state_data
else:
msg = f"state_data must be either a Pydantic BaseModel or dict, got {type(state_data)}"
raise ValueError(
f"state_data must be either a Pydantic BaseModel or dict, got {type(state_data)}"
msg,
)
with sqlite3.connect(self.db_path) as conn:
@@ -107,7 +109,7 @@ class SQLiteFlowPersistence(FlowPersistence):
),
)
def load_state(self, flow_uuid: str) -> Optional[Dict[str, Any]]:
def load_state(self, flow_uuid: str) -> dict[str, Any] | None:
"""Load the most recent state for a given flow UUID.
Args:
@@ -115,6 +117,7 @@ class SQLiteFlowPersistence(FlowPersistence):
Returns:
The most recent state as a dictionary, or None if no state exists
"""
with sqlite3.connect(self.db_path) as conn:
cursor = conn.execute(

View File

@@ -1,33 +1,32 @@
"""
Utility functions for flow visualization and dependency analysis.
"""Utility functions for flow visualization and dependency analysis.
This module provides core functionality for analyzing and manipulating flow structures,
including node level calculation, ancestor tracking, and return value analysis.
Functions in this module are primarily used by the visualization system to create
accurate and informative flow diagrams.
Example
Example:
-------
>>> flow = Flow()
>>> node_levels = calculate_node_levels(flow)
>>> ancestors = build_ancestor_dict(flow)
"""
import ast
import inspect
import textwrap
from collections import defaultdict, deque
from typing import Any, Deque, Dict, List, Optional, Set, Union
from typing import Any
def get_possible_return_constants(function: Any) -> Optional[List[str]]:
def get_possible_return_constants(function: Any) -> list[str] | None:
try:
source = inspect.getsource(function)
except OSError:
# Can't get source code
return None
except Exception as e:
print(f"Error retrieving source code for function {function.__name__}: {e}")
except Exception:
return None
try:
@@ -35,24 +34,18 @@ def get_possible_return_constants(function: Any) -> Optional[List[str]]:
source = textwrap.dedent(source)
# Parse the source code into an AST
code_ast = ast.parse(source)
except IndentationError as e:
print(f"IndentationError while parsing source code of {function.__name__}: {e}")
print(f"Source code:\n{source}")
except IndentationError:
return None
except SyntaxError as e:
print(f"SyntaxError while parsing source code of {function.__name__}: {e}")
print(f"Source code:\n{source}")
except SyntaxError:
return None
except Exception as e:
print(f"Unexpected error while parsing source code of {function.__name__}: {e}")
print(f"Source code:\n{source}")
except Exception:
return None
return_values = set()
dict_definitions = {}
class DictionaryAssignmentVisitor(ast.NodeVisitor):
def visit_Assign(self, node):
def visit_Assign(self, node) -> None:
# Check if this assignment is assigning a dictionary literal to a variable
if isinstance(node.value, ast.Dict) and len(node.targets) == 1:
target = node.targets[0]
@@ -69,10 +62,10 @@ def get_possible_return_constants(function: Any) -> Optional[List[str]]:
self.generic_visit(node)
class ReturnVisitor(ast.NodeVisitor):
def visit_Return(self, node):
def visit_Return(self, node) -> None:
# Direct string return
if isinstance(node.value, ast.Constant) and isinstance(
node.value.value, str
node.value.value, str,
):
return_values.add(node.value.value)
# Dictionary-based return, like return paths[result]
@@ -94,9 +87,8 @@ def get_possible_return_constants(function: Any) -> Optional[List[str]]:
return list(return_values) if return_values else None
def calculate_node_levels(flow: Any) -> Dict[str, int]:
"""
Calculate the hierarchical level of each node in the flow.
def calculate_node_levels(flow: Any) -> dict[str, int]:
"""Calculate the hierarchical level of each node in the flow.
Performs a breadth-first traversal of the flow graph to assign levels
to nodes, starting with start methods at level 0.
@@ -117,11 +109,12 @@ def calculate_node_levels(flow: Any) -> Dict[str, int]:
- Each subsequent connected node is assigned level = parent_level + 1
- Handles both OR and AND conditions for listeners
- Processes router paths separately
"""
levels: Dict[str, int] = {}
queue: Deque[str] = deque()
visited: Set[str] = set()
pending_and_listeners: Dict[str, Set[str]] = {}
levels: dict[str, int] = {}
queue: deque[str] = deque()
visited: set[str] = set()
pending_and_listeners: dict[str, set[str]] = {}
# Make all start methods at level 0
for method_name, method in flow._methods.items():
@@ -172,9 +165,8 @@ def calculate_node_levels(flow: Any) -> Dict[str, int]:
return levels
def count_outgoing_edges(flow: Any) -> Dict[str, int]:
"""
Count the number of outgoing edges for each method in the flow.
def count_outgoing_edges(flow: Any) -> dict[str, int]:
"""Count the number of outgoing edges for each method in the flow.
Parameters
----------
@@ -185,6 +177,7 @@ def count_outgoing_edges(flow: Any) -> Dict[str, int]:
-------
Dict[str, int]
Dictionary mapping method names to their outgoing edge count.
"""
counts = {}
for method_name in flow._methods:
@@ -197,9 +190,8 @@ def count_outgoing_edges(flow: Any) -> Dict[str, int]:
return counts
def build_ancestor_dict(flow: Any) -> Dict[str, Set[str]]:
"""
Build a dictionary mapping each node to its ancestor nodes.
def build_ancestor_dict(flow: Any) -> dict[str, set[str]]:
"""Build a dictionary mapping each node to its ancestor nodes.
Parameters
----------
@@ -210,9 +202,10 @@ def build_ancestor_dict(flow: Any) -> Dict[str, Set[str]]:
-------
Dict[str, Set[str]]
Dictionary mapping each node to a set of its ancestor nodes.
"""
ancestors: Dict[str, Set[str]] = {node: set() for node in flow._methods}
visited: Set[str] = set()
ancestors: dict[str, set[str]] = {node: set() for node in flow._methods}
visited: set[str] = set()
for node in flow._methods:
if node not in visited:
dfs_ancestors(node, ancestors, visited, flow)
@@ -220,10 +213,9 @@ def build_ancestor_dict(flow: Any) -> Dict[str, Set[str]]:
def dfs_ancestors(
node: str, ancestors: Dict[str, Set[str]], visited: Set[str], flow: Any
node: str, ancestors: dict[str, set[str]], visited: set[str], flow: Any,
) -> None:
"""
Perform depth-first search to build ancestor relationships.
"""Perform depth-first search to build ancestor relationships.
Parameters
----------
@@ -240,6 +232,7 @@ def dfs_ancestors(
-----
This function modifies the ancestors dictionary in-place to build
the complete ancestor graph.
"""
if node in visited:
return
@@ -265,10 +258,9 @@ def dfs_ancestors(
def is_ancestor(
node: str, ancestor_candidate: str, ancestors: Dict[str, Set[str]]
node: str, ancestor_candidate: str, ancestors: dict[str, set[str]],
) -> bool:
"""
Check if one node is an ancestor of another.
"""Check if one node is an ancestor of another.
Parameters
----------
@@ -283,13 +275,13 @@ def is_ancestor(
-------
bool
True if ancestor_candidate is an ancestor of node, False otherwise.
"""
return ancestor_candidate in ancestors.get(node, set())
def build_parent_children_dict(flow: Any) -> Dict[str, List[str]]:
"""
Build a dictionary mapping parent nodes to their children.
def build_parent_children_dict(flow: Any) -> dict[str, list[str]]:
"""Build a dictionary mapping parent nodes to their children.
Parameters
----------
@@ -306,8 +298,9 @@ def build_parent_children_dict(flow: Any) -> Dict[str, List[str]]:
- Maps listeners to their trigger methods
- Maps router methods to their paths and listeners
- Children lists are sorted for consistent ordering
"""
parent_children: Dict[str, List[str]] = {}
parent_children: dict[str, list[str]] = {}
# Map listeners to their trigger methods
for listener_name, (_, trigger_methods) in flow._listeners.items():
@@ -332,10 +325,9 @@ def build_parent_children_dict(flow: Any) -> Dict[str, List[str]]:
def get_child_index(
parent: str, child: str, parent_children: Dict[str, List[str]]
parent: str, child: str, parent_children: dict[str, list[str]],
) -> int:
"""
Get the index of a child node in its parent's sorted children list.
"""Get the index of a child node in its parent's sorted children list.
Parameters
----------
@@ -350,27 +342,25 @@ def get_child_index(
-------
int
Zero-based index of the child in its parent's sorted children list.
"""
children = parent_children.get(parent, [])
children.sort()
return children.index(child)
def process_router_paths(flow, current, current_level, levels, queue):
"""
Handle the router connections for the current node.
"""
def process_router_paths(flow, current, current_level, levels, queue) -> None:
"""Handle the router connections for the current node."""
if current in flow._routers:
paths = flow._router_paths.get(current, [])
for path in paths:
for listener_name, (
condition_type,
_condition_type,
trigger_methods,
) in flow._listeners.items():
if path in trigger_methods:
if (
listener_name not in levels
or levels[listener_name] > current_level + 1
):
levels[listener_name] = current_level + 1
queue.append(listener_name)
if path in trigger_methods and (
listener_name not in levels
or levels[listener_name] > current_level + 1
):
levels[listener_name] = current_level + 1
queue.append(listener_name)

View File

@@ -1,23 +1,23 @@
"""
Utilities for creating visual representations of flow structures.
"""Utilities for creating visual representations of flow structures.
This module provides functions for generating network visualizations of flows,
including node placement, edge creation, and visual styling. It handles the
conversion of flow structures into visual network graphs with appropriate
styling and layout.
Example
Example:
-------
>>> flow = Flow()
>>> net = Network(directed=True)
>>> node_positions = compute_positions(flow, node_levels)
>>> add_nodes_to_network(net, flow, node_positions, node_styles)
>>> add_edges(net, flow, node_positions, colors)
"""
import ast
import inspect
from typing import Any, Dict, List, Optional, Tuple, Union
from typing import Any
from .utils import (
build_ancestor_dict,
@@ -28,8 +28,7 @@ from .utils import (
def method_calls_crew(method: Any) -> bool:
"""
Check if the method contains a call to `.crew()`.
"""Check if the method contains a call to `.crew()`.
Parameters
----------
@@ -45,21 +44,22 @@ def method_calls_crew(method: Any) -> bool:
-----
Uses AST analysis to detect method calls, specifically looking for
attribute access of 'crew'.
"""
try:
source = inspect.getsource(method)
source = inspect.cleandoc(source)
tree = ast.parse(source)
except Exception as e:
print(f"Could not parse method {method.__name__}: {e}")
except Exception:
return False
class CrewCallVisitor(ast.NodeVisitor):
"""AST visitor to detect .crew() method calls."""
def __init__(self):
def __init__(self) -> None:
self.found = False
def visit_Call(self, node):
def visit_Call(self, node) -> None:
if isinstance(node.func, ast.Attribute):
if node.func.attr == "crew":
self.found = True
@@ -73,11 +73,10 @@ def method_calls_crew(method: Any) -> bool:
def add_nodes_to_network(
net: Any,
flow: Any,
node_positions: Dict[str, Tuple[float, float]],
node_styles: Dict[str, Dict[str, Any]]
node_positions: dict[str, tuple[float, float]],
node_styles: dict[str, dict[str, Any]],
) -> None:
"""
Add nodes to the network visualization with appropriate styling.
"""Add nodes to the network visualization with appropriate styling.
Parameters
----------
@@ -97,6 +96,7 @@ def add_nodes_to_network(
- Router methods
- Crew methods
- Regular methods
"""
def human_friendly_label(method_name):
return method_name.replace("_", " ").title()
@@ -123,7 +123,7 @@ def add_nodes_to_network(
"multi": "html",
"color": node_style.get("font", {}).get("color", "#FFFFFF"),
},
}
},
)
net.add_node(
@@ -138,12 +138,11 @@ def add_nodes_to_network(
def compute_positions(
flow: Any,
node_levels: Dict[str, int],
node_levels: dict[str, int],
y_spacing: float = 150,
x_spacing: float = 150
) -> Dict[str, Tuple[float, float]]:
"""
Compute the (x, y) positions for each node in the flow graph.
x_spacing: float = 150,
) -> dict[str, tuple[float, float]]:
"""Compute the (x, y) positions for each node in the flow graph.
Parameters
----------
@@ -160,9 +159,10 @@ def compute_positions(
-------
Dict[str, Tuple[float, float]]
Dictionary mapping node names to their (x, y) coordinates.
"""
level_nodes: Dict[int, List[str]] = {}
node_positions: Dict[str, Tuple[float, float]] = {}
level_nodes: dict[int, list[str]] = {}
node_positions: dict[str, tuple[float, float]] = {}
for method_name, level in node_levels.items():
level_nodes.setdefault(level, []).append(method_name)
@@ -180,10 +180,10 @@ def compute_positions(
def add_edges(
net: Any,
flow: Any,
node_positions: Dict[str, Tuple[float, float]],
colors: Dict[str, str]
node_positions: dict[str, tuple[float, float]],
colors: dict[str, str],
) -> None:
edge_smooth: Dict[str, Union[str, float]] = {"type": "continuous"} # Default value
edge_smooth: dict[str, str | float] = {"type": "continuous"} # Default value
"""
Add edges to the network visualization with appropriate styling.
@@ -245,7 +245,7 @@ def add_edges(
"color": edge_color,
"width": 2,
"arrows": "to",
"dashes": True if is_router_edge or is_and_condition else False,
"dashes": bool(is_router_edge or is_and_condition),
"smooth": edge_smooth,
}
@@ -261,9 +261,7 @@ def add_edges(
# If it's a known router edge and the method is known, don't warn.
# This means the path is legitimate, just not reflected as nodes here.
if not (is_router_edge and method_known):
print(
f"Warning: No node found for '{trigger}' or '{method_name}'. Skipping edge."
)
pass
# Edges for router return paths
for router_method_name, paths in flow._router_paths.items():
@@ -278,7 +276,7 @@ def add_edges(
and listener_name in node_positions
):
is_cycle_edge = is_ancestor(
router_method_name, listener_name, ancestors
router_method_name, listener_name, ancestors,
)
parent_has_multiple_children = (
len(parent_children.get(router_method_name, [])) > 1
@@ -293,7 +291,7 @@ def add_edges(
dx = target_pos[0] - source_pos[0]
smooth_type = "curvedCCW" if dx <= 0 else "curvedCW"
index = get_child_index(
router_method_name, listener_name, parent_children
router_method_name, listener_name, parent_children,
)
edge_smooth = {
"type": smooth_type,
@@ -316,6 +314,4 @@ def add_edges(
# Same check here: known router edge and known method?
method_known = listener_name in flow._methods
if not method_known:
print(
f"Warning: No node found for '{router_method_name}' or '{listener_name}'. Skipping edge."
)
pass

View File

@@ -1,55 +1,48 @@
from abc import ABC, abstractmethod
from typing import List
import numpy as np
class BaseEmbedder(ABC):
"""
Abstract base class for text embedding models
"""
"""Abstract base class for text embedding models."""
@abstractmethod
def embed_chunks(self, chunks: List[str]) -> np.ndarray:
"""
Generate embeddings for a list of text chunks
def embed_chunks(self, chunks: list[str]) -> np.ndarray:
"""Generate embeddings for a list of text chunks.
Args:
chunks: List of text chunks to embed
Returns:
Array of embeddings
"""
pass
@abstractmethod
def embed_texts(self, texts: List[str]) -> np.ndarray:
"""
Generate embeddings for a list of texts
def embed_texts(self, texts: list[str]) -> np.ndarray:
"""Generate embeddings for a list of texts.
Args:
texts: List of texts to embed
Returns:
Array of embeddings
"""
pass
@abstractmethod
def embed_text(self, text: str) -> np.ndarray:
"""
Generate embedding for a single text
"""Generate embedding for a single text.
Args:
text: Text to embed
Returns:
Embedding array
"""
pass
@property
@abstractmethod
def dimension(self) -> int:
"""Get the dimension of the embeddings"""
pass
"""Get the dimension of the embeddings."""

View File

@@ -1,5 +1,4 @@
from pathlib import Path
from typing import List, Optional, Union
import numpy as np
@@ -19,75 +18,74 @@ except ImportError:
class FastEmbed(BaseEmbedder):
"""
A wrapper class for text embedding models using FastEmbed
"""
"""A wrapper class for text embedding models using FastEmbed."""
def __init__(
self,
model_name: str = "BAAI/bge-small-en-v1.5",
cache_dir: Optional[Union[str, Path]] = None,
):
"""
Initialize the embedding model
cache_dir: str | Path | None = None,
) -> None:
"""Initialize the embedding model.
Args:
model_name: Name of the model to use
cache_dir: Directory to cache the model
gpu: Whether to use GPU acceleration
"""
if not FASTEMBED_AVAILABLE:
raise ImportError(
msg = (
"FastEmbed is not installed. Please install it with: "
"uv pip install fastembed or uv pip install fastembed-gpu for GPU support"
)
raise ImportError(
msg,
)
self.model = TextEmbedding(
model_name=model_name,
cache_dir=str(cache_dir) if cache_dir else None,
)
def embed_chunks(self, chunks: List[str]) -> List[np.ndarray]:
"""
Generate embeddings for a list of text chunks
def embed_chunks(self, chunks: list[str]) -> list[np.ndarray]:
"""Generate embeddings for a list of text chunks.
Args:
chunks: List of text chunks to embed
Returns:
List of embeddings
"""
embeddings = list(self.model.embed(chunks))
return embeddings
def embed_texts(self, texts: List[str]) -> List[np.ndarray]:
"""
Generate embeddings for a list of texts
return list(self.model.embed(chunks))
def embed_texts(self, texts: list[str]) -> list[np.ndarray]:
"""Generate embeddings for a list of texts.
Args:
texts: List of texts to embed
Returns:
List of embeddings
"""
embeddings = list(self.model.embed(texts))
return embeddings
return list(self.model.embed(texts))
def embed_text(self, text: str) -> np.ndarray:
"""
Generate embedding for a single text
"""Generate embedding for a single text.
Args:
text: Text to embed
Returns:
Embedding array
"""
return self.embed_texts([text])[0]
@property
def dimension(self) -> int:
"""Get the dimension of the embeddings"""
"""Get the dimension of the embeddings."""
# Generate a test embedding to get dimensions
test_embed = self.embed_text("test")
return len(test_embed)

View File

@@ -1,5 +1,5 @@
import os
from typing import Any, Dict, List, Optional
from typing import Any
from pydantic import BaseModel, ConfigDict, Field
@@ -10,68 +10,70 @@ os.environ["TOKENIZERS_PARALLELISM"] = "false" # removes logging from fastembed
class Knowledge(BaseModel):
"""
Knowledge is a collection of sources and setup for the vector store to save and query relevant context.
"""Knowledge is a collection of sources and setup for the vector store to save and query relevant context.
Args:
sources: List[BaseKnowledgeSource] = Field(default_factory=list)
storage: Optional[KnowledgeStorage] = Field(default=None)
embedder: Optional[Dict[str, Any]] = None
embedder: Optional[Dict[str, Any]] = None.
"""
sources: List[BaseKnowledgeSource] = Field(default_factory=list)
sources: list[BaseKnowledgeSource] = Field(default_factory=list)
model_config = ConfigDict(arbitrary_types_allowed=True)
storage: Optional[KnowledgeStorage] = Field(default=None)
embedder: Optional[Dict[str, Any]] = None
collection_name: Optional[str] = None
storage: KnowledgeStorage | None = Field(default=None)
embedder: dict[str, Any] | None = None
collection_name: str | None = None
def __init__(
self,
collection_name: str,
sources: List[BaseKnowledgeSource],
embedder: Optional[Dict[str, Any]] = None,
storage: Optional[KnowledgeStorage] = None,
sources: list[BaseKnowledgeSource],
embedder: dict[str, Any] | None = None,
storage: KnowledgeStorage | None = None,
**data,
):
) -> None:
super().__init__(**data)
if storage:
self.storage = storage
else:
self.storage = KnowledgeStorage(
embedder=embedder, collection_name=collection_name
embedder=embedder, collection_name=collection_name,
)
self.sources = sources
self.storage.initialize_knowledge_storage()
def query(
self, query: List[str], results_limit: int = 3, score_threshold: float = 0.35
) -> List[Dict[str, Any]]:
"""
Query across all knowledge sources to find the most relevant information.
self, query: list[str], results_limit: int = 3, score_threshold: float = 0.35,
) -> list[dict[str, Any]]:
"""Query across all knowledge sources to find the most relevant information.
Returns the top_k most relevant chunks.
Raises:
ValueError: If storage is not initialized.
"""
if self.storage is None:
raise ValueError("Storage is not initialized.")
msg = "Storage is not initialized."
raise ValueError(msg)
results = self.storage.search(
return self.storage.search(
query,
limit=results_limit,
score_threshold=score_threshold,
)
return results
def add_sources(self):
def add_sources(self) -> None:
try:
for source in self.sources:
source.storage = self.storage
source.add()
except Exception as e:
raise e
except Exception:
raise
def reset(self) -> None:
if self.storage:
self.storage.reset()
else:
raise ValueError("Storage is not initialized.")
msg = "Storage is not initialized."
raise ValueError(msg)

View File

@@ -7,6 +7,7 @@ class KnowledgeConfig(BaseModel):
Args:
results_limit (int): The number of relevant documents to return.
score_threshold (float): The minimum score for a document to be considered relevant.
"""
results_limit: int = Field(default=3, description="The number of results to return")

View File

@@ -1,6 +1,5 @@
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Dict, List, Optional, Union
from pydantic import Field, field_validator
@@ -14,43 +13,43 @@ class BaseFileKnowledgeSource(BaseKnowledgeSource, ABC):
"""Base class for knowledge sources that load content from files."""
_logger: Logger = Logger(verbose=True)
file_path: Optional[Union[Path, List[Path], str, List[str]]] = Field(
file_path: Path | list[Path] | str | list[str] | None = Field(
default=None,
description="[Deprecated] The path to the file. Use file_paths instead.",
)
file_paths: Optional[Union[Path, List[Path], str, List[str]]] = Field(
default_factory=list, description="The path to the file"
file_paths: Path | list[Path] | str | list[str] | None = Field(
default_factory=list, description="The path to the file",
)
content: Dict[Path, str] = Field(init=False, default_factory=dict)
storage: Optional[KnowledgeStorage] = Field(default=None)
safe_file_paths: List[Path] = Field(default_factory=list)
content: dict[Path, str] = Field(init=False, default_factory=dict)
storage: KnowledgeStorage | None = Field(default=None)
safe_file_paths: list[Path] = Field(default_factory=list)
@field_validator("file_path", "file_paths", mode="before")
def validate_file_path(cls, v, info):
def validate_file_path(self, v, info):
"""Validate that at least one of file_path or file_paths is provided."""
# Single check if both are None, O(1) instead of nested conditions
if (
v is None
and info.data.get(
"file_path" if info.field_name == "file_paths" else "file_paths"
"file_path" if info.field_name == "file_paths" else "file_paths",
)
is None
):
raise ValueError("Either file_path or file_paths must be provided")
msg = "Either file_path or file_paths must be provided"
raise ValueError(msg)
return v
def model_post_init(self, _):
def model_post_init(self, _) -> None:
"""Post-initialization method to load content."""
self.safe_file_paths = self._process_file_paths()
self.validate_content()
self.content = self.load_content()
@abstractmethod
def load_content(self) -> Dict[Path, str]:
def load_content(self) -> dict[Path, str]:
"""Load and preprocess file content. Should be overridden by subclasses. Assume that the file path is relative to the project root in the knowledge directory."""
pass
def validate_content(self):
def validate_content(self) -> None:
"""Validate the paths."""
for path in self.safe_file_paths:
if not path.exists():
@@ -59,7 +58,8 @@ class BaseFileKnowledgeSource(BaseKnowledgeSource, ABC):
f"File not found: {path}. Try adding sources to the knowledge directory. If it's inside the knowledge directory, use the relative path.",
color="red",
)
raise FileNotFoundError(f"File not found: {path}")
msg = f"File not found: {path}"
raise FileNotFoundError(msg)
if not path.is_file():
self._logger.log(
"error",
@@ -67,20 +67,20 @@ class BaseFileKnowledgeSource(BaseKnowledgeSource, ABC):
color="red",
)
def _save_documents(self):
def _save_documents(self) -> None:
"""Save the documents to the storage."""
if self.storage:
self.storage.save(self.chunks)
else:
raise ValueError("No storage found to save documents.")
msg = "No storage found to save documents."
raise ValueError(msg)
def convert_to_path(self, path: Union[Path, str]) -> Path:
def convert_to_path(self, path: Path | str) -> Path:
"""Convert a path to a Path object."""
return Path(KNOWLEDGE_DIRECTORY + "/" + path) if isinstance(path, str) else path
def _process_file_paths(self) -> List[Path]:
def _process_file_paths(self) -> list[Path]:
"""Convert file_path to a list of Path objects."""
if hasattr(self, "file_path") and self.file_path is not None:
self._logger.log(
"warning",
@@ -90,10 +90,11 @@ class BaseFileKnowledgeSource(BaseKnowledgeSource, ABC):
self.file_paths = self.file_path
if self.file_paths is None:
raise ValueError("Your source must be provided with a file_paths: []")
msg = "Your source must be provided with a file_paths: []"
raise ValueError(msg)
# Convert single path to list
path_list: List[Union[Path, str]] = (
path_list: list[Path | str] = (
[self.file_paths]
if isinstance(self.file_paths, (str, Path))
else list(self.file_paths)
@@ -102,8 +103,9 @@ class BaseFileKnowledgeSource(BaseKnowledgeSource, ABC):
)
if not path_list:
msg = "file_path/file_paths must be a Path, str, or a list of these types"
raise ValueError(
"file_path/file_paths must be a Path, str, or a list of these types"
msg,
)
return [self.convert_to_path(path) for path in path_list]

View File

@@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional
from typing import Any
import numpy as np
from pydantic import BaseModel, ConfigDict, Field
@@ -12,41 +12,39 @@ class BaseKnowledgeSource(BaseModel, ABC):
chunk_size: int = 4000
chunk_overlap: int = 200
chunks: List[str] = Field(default_factory=list)
chunk_embeddings: List[np.ndarray] = Field(default_factory=list)
chunks: list[str] = Field(default_factory=list)
chunk_embeddings: list[np.ndarray] = Field(default_factory=list)
model_config = ConfigDict(arbitrary_types_allowed=True)
storage: Optional[KnowledgeStorage] = Field(default=None)
metadata: Dict[str, Any] = Field(default_factory=dict) # Currently unused
collection_name: Optional[str] = Field(default=None)
storage: KnowledgeStorage | None = Field(default=None)
metadata: dict[str, Any] = Field(default_factory=dict) # Currently unused
collection_name: str | None = Field(default=None)
@abstractmethod
def validate_content(self) -> Any:
"""Load and preprocess content from the source."""
pass
@abstractmethod
def add(self) -> None:
"""Process content, chunk it, compute embeddings, and save them."""
pass
def get_embeddings(self) -> List[np.ndarray]:
def get_embeddings(self) -> list[np.ndarray]:
"""Return the list of embeddings for the chunks."""
return self.chunk_embeddings
def _chunk_text(self, text: str) -> List[str]:
def _chunk_text(self, text: str) -> list[str]:
"""Utility method to split text into chunks."""
return [
text[i : i + self.chunk_size]
for i in range(0, len(text), self.chunk_size - self.chunk_overlap)
]
def _save_documents(self):
"""
Save the documents to the storage.
def _save_documents(self) -> None:
"""Save the documents to the storage.
This method should be called after the chunks and embeddings are generated.
"""
if self.storage:
self.storage.save(self.chunks)
else:
raise ValueError("No storage found to save documents.")
msg = "No storage found to save documents."
raise ValueError(msg)

View File

@@ -1,5 +1,6 @@
from collections.abc import Iterator
from pathlib import Path
from typing import Iterator, List, Optional, Union
from typing import TYPE_CHECKING
from urllib.parse import urlparse
try:
@@ -7,7 +8,6 @@ try:
from docling.document_converter import DocumentConverter
from docling.exceptions import ConversionError
from docling_core.transforms.chunker.hierarchical_chunker import HierarchicalChunker
from docling_core.types.doc.document import DoclingDocument
DOCLING_AVAILABLE = True
except ImportError:
@@ -19,27 +19,33 @@ from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource
from crewai.utilities.constants import KNOWLEDGE_DIRECTORY
from crewai.utilities.logger import Logger
if TYPE_CHECKING:
from docling_core.types.doc.document import DoclingDocument
class CrewDoclingSource(BaseKnowledgeSource):
"""Default Source class for converting documents to markdown or json
This will auto support PDF, DOCX, and TXT, XLSX, Images, and HTML files without any additional dependencies and follows the docling package as the source of truth.
"""
def __init__(self, *args, **kwargs):
def __init__(self, *args, **kwargs) -> None:
if not DOCLING_AVAILABLE:
raise ImportError(
msg = (
"The docling package is required to use CrewDoclingSource. "
"Please install it using: uv add docling"
)
raise ImportError(
msg,
)
super().__init__(*args, **kwargs)
_logger: Logger = Logger(verbose=True)
file_path: Optional[List[Union[Path, str]]] = Field(default=None)
file_paths: List[Union[Path, str]] = Field(default_factory=list)
chunks: List[str] = Field(default_factory=list)
safe_file_paths: List[Union[Path, str]] = Field(default_factory=list)
content: List["DoclingDocument"] = Field(default_factory=list)
file_path: list[Path | str] | None = Field(default=None)
file_paths: list[Path | str] = Field(default_factory=list)
chunks: list[str] = Field(default_factory=list)
safe_file_paths: list[Path | str] = Field(default_factory=list)
content: list["DoclingDocument"] = Field(default_factory=list)
document_converter: "DocumentConverter" = Field(
default_factory=lambda: DocumentConverter(
allowed_formats=[
@@ -51,8 +57,8 @@ class CrewDoclingSource(BaseKnowledgeSource):
InputFormat.IMAGE,
InputFormat.XLSX,
InputFormat.PPTX,
]
)
],
),
)
def model_post_init(self, _) -> None:
@@ -66,7 +72,7 @@ class CrewDoclingSource(BaseKnowledgeSource):
self.safe_file_paths = self.validate_content()
self.content = self._load_content()
def _load_content(self) -> List["DoclingDocument"]:
def _load_content(self) -> list["DoclingDocument"]:
try:
return self._convert_source_to_docling_documents()
except ConversionError as e:
@@ -75,10 +81,10 @@ class CrewDoclingSource(BaseKnowledgeSource):
f"Error loading content: {e}. Supported formats: {self.document_converter.allowed_formats}",
"red",
)
raise e
raise
except Exception as e:
self._logger.log("error", f"Error loading content: {e}")
raise e
raise
def add(self) -> None:
if self.content is None:
@@ -88,7 +94,7 @@ class CrewDoclingSource(BaseKnowledgeSource):
self.chunks.extend(list(new_chunks_iterable))
self._save_documents()
def _convert_source_to_docling_documents(self) -> List["DoclingDocument"]:
def _convert_source_to_docling_documents(self) -> list["DoclingDocument"]:
conv_results_iter = self.document_converter.convert_all(self.safe_file_paths)
return [result.document for result in conv_results_iter]
@@ -97,8 +103,8 @@ class CrewDoclingSource(BaseKnowledgeSource):
for chunk in chunker.chunk(doc):
yield chunk.text
def validate_content(self) -> List[Union[Path, str]]:
processed_paths: List[Union[Path, str]] = []
def validate_content(self) -> list[Path | str]:
processed_paths: list[Path | str] = []
for path in self.file_paths:
if isinstance(path, str):
if path.startswith(("http://", "https://")):
@@ -106,15 +112,18 @@ class CrewDoclingSource(BaseKnowledgeSource):
if self._validate_url(path):
processed_paths.append(path)
else:
raise ValueError(f"Invalid URL format: {path}")
msg = f"Invalid URL format: {path}"
raise ValueError(msg)
except Exception as e:
raise ValueError(f"Invalid URL: {path}. Error: {str(e)}")
msg = f"Invalid URL: {path}. Error: {e!s}"
raise ValueError(msg)
else:
local_path = Path(KNOWLEDGE_DIRECTORY + "/" + path)
if local_path.exists():
processed_paths.append(local_path)
else:
raise FileNotFoundError(f"File not found: {local_path}")
msg = f"File not found: {local_path}"
raise FileNotFoundError(msg)
else:
# this is an instance of Path
processed_paths.append(path)
@@ -128,7 +137,7 @@ class CrewDoclingSource(BaseKnowledgeSource):
result.scheme in ("http", "https"),
result.netloc,
len(result.netloc.split(".")) >= 2, # Ensure domain has TLD
]
],
)
except Exception:
return False

View File

@@ -1,6 +1,5 @@
import csv
from pathlib import Path
from typing import Dict, List
from crewai.knowledge.source.base_file_knowledge_source import BaseFileKnowledgeSource
@@ -8,11 +7,11 @@ from crewai.knowledge.source.base_file_knowledge_source import BaseFileKnowledge
class CSVKnowledgeSource(BaseFileKnowledgeSource):
"""A knowledge source that stores and queries CSV file content using embeddings."""
def load_content(self) -> Dict[Path, str]:
def load_content(self) -> dict[Path, str]:
"""Load and preprocess CSV file content."""
content_dict = {}
for file_path in self.safe_file_paths:
with open(file_path, "r", encoding="utf-8") as csvfile:
with open(file_path, encoding="utf-8") as csvfile:
reader = csv.reader(csvfile)
content = ""
for row in reader:
@@ -21,8 +20,7 @@ class CSVKnowledgeSource(BaseFileKnowledgeSource):
return content_dict
def add(self) -> None:
"""
Add CSV file content to the knowledge source, chunk it, compute embeddings,
"""Add CSV file content to the knowledge source, chunk it, compute embeddings,
and save the embeddings.
"""
content_str = (
@@ -32,7 +30,7 @@ class CSVKnowledgeSource(BaseFileKnowledgeSource):
self.chunks.extend(new_chunks)
self._save_documents()
def _chunk_text(self, text: str) -> List[str]:
def _chunk_text(self, text: str) -> list[str]:
"""Utility method to split text into chunks."""
return [
text[i : i + self.chunk_size]

View File

@@ -1,6 +1,4 @@
from pathlib import Path
from typing import Dict, Iterator, List, Optional, Union
from urllib.parse import urlparse
from pydantic import Field, field_validator
@@ -16,34 +14,34 @@ class ExcelKnowledgeSource(BaseKnowledgeSource):
_logger: Logger = Logger(verbose=True)
file_path: Optional[Union[Path, List[Path], str, List[str]]] = Field(
file_path: Path | list[Path] | str | list[str] | None = Field(
default=None,
description="[Deprecated] The path to the file. Use file_paths instead.",
)
file_paths: Optional[Union[Path, List[Path], str, List[str]]] = Field(
default_factory=list, description="The path to the file"
file_paths: Path | list[Path] | str | list[str] | None = Field(
default_factory=list, description="The path to the file",
)
chunks: List[str] = Field(default_factory=list)
content: Dict[Path, Dict[str, str]] = Field(default_factory=dict)
safe_file_paths: List[Path] = Field(default_factory=list)
chunks: list[str] = Field(default_factory=list)
content: dict[Path, dict[str, str]] = Field(default_factory=dict)
safe_file_paths: list[Path] = Field(default_factory=list)
@field_validator("file_path", "file_paths", mode="before")
def validate_file_path(cls, v, info):
def validate_file_path(self, v, info):
"""Validate that at least one of file_path or file_paths is provided."""
# Single check if both are None, O(1) instead of nested conditions
if (
v is None
and info.data.get(
"file_path" if info.field_name == "file_paths" else "file_paths"
"file_path" if info.field_name == "file_paths" else "file_paths",
)
is None
):
raise ValueError("Either file_path or file_paths must be provided")
msg = "Either file_path or file_paths must be provided"
raise ValueError(msg)
return v
def _process_file_paths(self) -> List[Path]:
def _process_file_paths(self) -> list[Path]:
"""Convert file_path to a list of Path objects."""
if hasattr(self, "file_path") and self.file_path is not None:
self._logger.log(
"warning",
@@ -53,10 +51,11 @@ class ExcelKnowledgeSource(BaseKnowledgeSource):
self.file_paths = self.file_path
if self.file_paths is None:
raise ValueError("Your source must be provided with a file_paths: []")
msg = "Your source must be provided with a file_paths: []"
raise ValueError(msg)
# Convert single path to list
path_list: List[Union[Path, str]] = (
path_list: list[Path | str] = (
[self.file_paths]
if isinstance(self.file_paths, (str, Path))
else list(self.file_paths)
@@ -65,13 +64,14 @@ class ExcelKnowledgeSource(BaseKnowledgeSource):
)
if not path_list:
msg = "file_path/file_paths must be a Path, str, or a list of these types"
raise ValueError(
"file_path/file_paths must be a Path, str, or a list of these types"
msg,
)
return [self.convert_to_path(path) for path in path_list]
def validate_content(self):
def validate_content(self) -> None:
"""Validate the paths."""
for path in self.safe_file_paths:
if not path.exists():
@@ -80,7 +80,8 @@ class ExcelKnowledgeSource(BaseKnowledgeSource):
f"File not found: {path}. Try adding sources to the knowledge directory. If it's inside the knowledge directory, use the relative path.",
color="red",
)
raise FileNotFoundError(f"File not found: {path}")
msg = f"File not found: {path}"
raise FileNotFoundError(msg)
if not path.is_file():
self._logger.log(
"error",
@@ -100,7 +101,7 @@ class ExcelKnowledgeSource(BaseKnowledgeSource):
self.validate_content()
self.content = self._load_content()
def _load_content(self) -> Dict[Path, Dict[str, str]]:
def _load_content(self) -> dict[Path, dict[str, str]]:
"""Load and preprocess Excel file content from multiple sheets.
Each sheet's content is converted to CSV format and stored.
@@ -111,6 +112,7 @@ class ExcelKnowledgeSource(BaseKnowledgeSource):
Raises:
ImportError: If required dependencies are missing.
FileNotFoundError: If the specified Excel file cannot be opened.
"""
pd = self._import_dependencies()
content_dict = {}
@@ -119,14 +121,14 @@ class ExcelKnowledgeSource(BaseKnowledgeSource):
with pd.ExcelFile(file_path) as xl:
sheet_dict = {
str(sheet_name): str(
pd.read_excel(xl, sheet_name).to_csv(index=False)
pd.read_excel(xl, sheet_name).to_csv(index=False),
)
for sheet_name in xl.sheet_names
}
content_dict[file_path] = sheet_dict
return content_dict
def convert_to_path(self, path: Union[Path, str]) -> Path:
def convert_to_path(self, path: Path | str) -> Path:
"""Convert a path to a Path object."""
return Path(KNOWLEDGE_DIRECTORY + "/" + path) if isinstance(path, str) else path
@@ -138,13 +140,13 @@ class ExcelKnowledgeSource(BaseKnowledgeSource):
return pd
except ImportError as e:
missing_package = str(e).split()[-1]
msg = f"{missing_package} is not installed. Please install it with: pip install {missing_package}"
raise ImportError(
f"{missing_package} is not installed. Please install it with: pip install {missing_package}"
msg,
)
def add(self) -> None:
"""
Add Excel file content to the knowledge source, chunk it, compute embeddings,
"""Add Excel file content to the knowledge source, chunk it, compute embeddings,
and save the embeddings.
"""
# Convert dictionary values to a single string if content is a dictionary
@@ -161,7 +163,7 @@ class ExcelKnowledgeSource(BaseKnowledgeSource):
self.chunks.extend(new_chunks)
self._save_documents()
def _chunk_text(self, text: str) -> List[str]:
def _chunk_text(self, text: str) -> list[str]:
"""Utility method to split text into chunks."""
return [
text[i : i + self.chunk_size]

View File

@@ -1,6 +1,6 @@
import json
from pathlib import Path
from typing import Any, Dict, List
from typing import Any
from crewai.knowledge.source.base_file_knowledge_source import BaseFileKnowledgeSource
@@ -8,12 +8,12 @@ from crewai.knowledge.source.base_file_knowledge_source import BaseFileKnowledge
class JSONKnowledgeSource(BaseFileKnowledgeSource):
"""A knowledge source that stores and queries JSON file content using embeddings."""
def load_content(self) -> Dict[Path, str]:
def load_content(self) -> dict[Path, str]:
"""Load and preprocess JSON file content."""
content: Dict[Path, str] = {}
content: dict[Path, str] = {}
for path in self.safe_file_paths:
path = self.convert_to_path(path)
with open(path, "r", encoding="utf-8") as json_file:
with open(path, encoding="utf-8") as json_file:
data = json.load(json_file)
content[path] = self._json_to_text(data)
return content
@@ -29,12 +29,11 @@ class JSONKnowledgeSource(BaseFileKnowledgeSource):
for item in data:
text += f"{indent}- {self._json_to_text(item, level + 1)}\n"
else:
text += f"{str(data)}"
text += f"{data!s}"
return text
def add(self) -> None:
"""
Add JSON file content to the knowledge source, chunk it, compute embeddings,
"""Add JSON file content to the knowledge source, chunk it, compute embeddings,
and save the embeddings.
"""
content_str = (
@@ -44,7 +43,7 @@ class JSONKnowledgeSource(BaseFileKnowledgeSource):
self.chunks.extend(new_chunks)
self._save_documents()
def _chunk_text(self, text: str) -> List[str]:
def _chunk_text(self, text: str) -> list[str]:
"""Utility method to split text into chunks."""
return [
text[i : i + self.chunk_size]

View File

@@ -1,5 +1,4 @@
from pathlib import Path
from typing import Dict, List
from crewai.knowledge.source.base_file_knowledge_source import BaseFileKnowledgeSource
@@ -7,7 +6,7 @@ from crewai.knowledge.source.base_file_knowledge_source import BaseFileKnowledge
class PDFKnowledgeSource(BaseFileKnowledgeSource):
"""A knowledge source that stores and queries PDF file content using embeddings."""
def load_content(self) -> Dict[Path, str]:
def load_content(self) -> dict[Path, str]:
"""Load and preprocess PDF file content."""
pdfplumber = self._import_pdfplumber()
@@ -31,21 +30,21 @@ class PDFKnowledgeSource(BaseFileKnowledgeSource):
return pdfplumber
except ImportError:
msg = "pdfplumber is not installed. Please install it with: pip install pdfplumber"
raise ImportError(
"pdfplumber is not installed. Please install it with: pip install pdfplumber"
msg,
)
def add(self) -> None:
"""
Add PDF file content to the knowledge source, chunk it, compute embeddings,
"""Add PDF file content to the knowledge source, chunk it, compute embeddings,
and save the embeddings.
"""
for _, text in self.content.items():
for text in self.content.values():
new_chunks = self._chunk_text(text)
self.chunks.extend(new_chunks)
self._save_documents()
def _chunk_text(self, text: str) -> List[str]:
def _chunk_text(self, text: str) -> list[str]:
"""Utility method to split text into chunks."""
return [
text[i : i + self.chunk_size]

View File

@@ -1,4 +1,3 @@
from typing import List, Optional
from pydantic import Field
@@ -9,16 +8,17 @@ class StringKnowledgeSource(BaseKnowledgeSource):
"""A knowledge source that stores and queries plain text content using embeddings."""
content: str = Field(...)
collection_name: Optional[str] = Field(default=None)
collection_name: str | None = Field(default=None)
def model_post_init(self, _):
def model_post_init(self, _) -> None:
"""Post-initialization method to validate content."""
self.validate_content()
def validate_content(self):
def validate_content(self) -> None:
"""Validate string content."""
if not isinstance(self.content, str):
raise ValueError("StringKnowledgeSource only accepts string content")
msg = "StringKnowledgeSource only accepts string content"
raise ValueError(msg)
def add(self) -> None:
"""Add string content to the knowledge source, chunk it, compute embeddings, and save them."""
@@ -26,7 +26,7 @@ class StringKnowledgeSource(BaseKnowledgeSource):
self.chunks.extend(new_chunks)
self._save_documents()
def _chunk_text(self, text: str) -> List[str]:
def _chunk_text(self, text: str) -> list[str]:
"""Utility method to split text into chunks."""
return [
text[i : i + self.chunk_size]

View File

@@ -1,5 +1,4 @@
from pathlib import Path
from typing import Dict, List
from crewai.knowledge.source.base_file_knowledge_source import BaseFileKnowledgeSource
@@ -7,26 +6,25 @@ from crewai.knowledge.source.base_file_knowledge_source import BaseFileKnowledge
class TextFileKnowledgeSource(BaseFileKnowledgeSource):
"""A knowledge source that stores and queries text file content using embeddings."""
def load_content(self) -> Dict[Path, str]:
def load_content(self) -> dict[Path, str]:
"""Load and preprocess text file content."""
content = {}
for path in self.safe_file_paths:
path = self.convert_to_path(path)
with open(path, "r", encoding="utf-8") as f:
with open(path, encoding="utf-8") as f:
content[path] = f.read()
return content
def add(self) -> None:
"""
Add text file content to the knowledge source, chunk it, compute embeddings,
"""Add text file content to the knowledge source, chunk it, compute embeddings,
and save the embeddings.
"""
for _, text in self.content.items():
for text in self.content.values():
new_chunks = self._chunk_text(text)
self.chunks.extend(new_chunks)
self._save_documents()
def _chunk_text(self, text: str) -> List[str]:
def _chunk_text(self, text: str) -> list[str]:
"""Utility method to split text into chunks."""
return [
text[i : i + self.chunk_size]

View File

@@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional
from typing import Any
class BaseKnowledgeStorage(ABC):
@@ -8,22 +8,19 @@ class BaseKnowledgeStorage(ABC):
@abstractmethod
def search(
self,
query: List[str],
query: list[str],
limit: int = 3,
filter: Optional[dict] = None,
filter: dict | None = None,
score_threshold: float = 0.35,
) -> List[Dict[str, Any]]:
) -> list[dict[str, Any]]:
"""Search for documents in the knowledge base."""
pass
@abstractmethod
def save(
self, documents: List[str], metadata: Dict[str, Any] | List[Dict[str, Any]]
self, documents: list[str], metadata: dict[str, Any] | list[dict[str, Any]],
) -> None:
"""Save documents to the knowledge base."""
pass
@abstractmethod
def reset(self) -> None:
"""Reset the knowledge base."""
pass

View File

@@ -4,12 +4,11 @@ import io
import logging
import os
import shutil
from typing import Any, Dict, List, Optional, Union
from typing import TYPE_CHECKING, Any
import chromadb
import chromadb.errors
from chromadb.api import ClientAPI
from chromadb.api.types import OneOrMany
from chromadb.config import Settings
from crewai.knowledge.storage.base_knowledge_storage import BaseKnowledgeStorage
@@ -19,6 +18,9 @@ from crewai.utilities.constants import KNOWLEDGE_DIRECTORY
from crewai.utilities.logger import Logger
from crewai.utilities.paths import db_storage_path
if TYPE_CHECKING:
from chromadb.api.types import OneOrMany
@contextlib.contextmanager
def suppress_logging(
@@ -38,30 +40,29 @@ def suppress_logging(
class KnowledgeStorage(BaseKnowledgeStorage):
"""
Extends Storage to handle embeddings for memory entries, improving
"""Extends Storage to handle embeddings for memory entries, improving
search efficiency.
"""
collection: Optional[chromadb.Collection] = None
collection_name: Optional[str] = "knowledge"
app: Optional[ClientAPI] = None
collection: chromadb.Collection | None = None
collection_name: str | None = "knowledge"
app: ClientAPI | None = None
def __init__(
self,
embedder: Optional[Dict[str, Any]] = None,
collection_name: Optional[str] = None,
):
embedder: dict[str, Any] | None = None,
collection_name: str | None = None,
) -> None:
self.collection_name = collection_name
self._set_embedder_config(embedder)
def search(
self,
query: List[str],
query: list[str],
limit: int = 3,
filter: Optional[dict] = None,
filter: dict | None = None,
score_threshold: float = 0.35,
) -> List[Dict[str, Any]]:
) -> list[dict[str, Any]]:
with suppress_logging():
if self.collection:
fetched = self.collection.query(
@@ -80,10 +81,10 @@ class KnowledgeStorage(BaseKnowledgeStorage):
if result["score"] >= score_threshold:
results.append(result)
return results
else:
raise Exception("Collection not initialized")
msg = "Collection not initialized"
raise Exception(msg)
def initialize_knowledge_storage(self):
def initialize_knowledge_storage(self) -> None:
base_path = os.path.join(db_storage_path(), "knowledge")
chroma_client = chromadb.PersistentClient(
path=base_path,
@@ -104,11 +105,13 @@ class KnowledgeStorage(BaseKnowledgeStorage):
embedding_function=self.embedder,
)
else:
raise Exception("Vector Database Client not initialized")
msg = "Vector Database Client not initialized"
raise Exception(msg)
except Exception:
raise Exception("Failed to create or get collection")
msg = "Failed to create or get collection"
raise Exception(msg)
def reset(self):
def reset(self) -> None:
base_path = os.path.join(db_storage_path(), KNOWLEDGE_DIRECTORY)
if not self.app:
self.app = chromadb.PersistentClient(
@@ -123,11 +126,12 @@ class KnowledgeStorage(BaseKnowledgeStorage):
def save(
self,
documents: List[str],
metadata: Optional[Union[Dict[str, Any], List[Dict[str, Any]]]] = None,
):
documents: list[str],
metadata: dict[str, Any] | list[dict[str, Any]] | None = None,
) -> None:
if not self.collection:
raise Exception("Collection not initialized")
msg = "Collection not initialized"
raise Exception(msg)
try:
# Create a dictionary to store unique documents
@@ -156,7 +160,7 @@ class KnowledgeStorage(BaseKnowledgeStorage):
filtered_ids.append(doc_id)
# If we have no metadata at all, set it to None
final_metadata: Optional[OneOrMany[chromadb.Metadata]] = (
final_metadata: OneOrMany[chromadb.Metadata] | None = (
None if all(m is None for m in filtered_metadata) else filtered_metadata
)
@@ -171,10 +175,13 @@ class KnowledgeStorage(BaseKnowledgeStorage):
"Embedding dimension mismatch. This usually happens when mixing different embedding models. Try resetting the collection using `crewai reset-memories -a`",
"red",
)
raise ValueError(
msg = (
"Embedding dimension mismatch. Make sure you're using the same embedding model "
"across all operations with this collection."
"Try resetting the collection using `crewai reset-memories -a`"
)
raise ValueError(
msg,
) from e
except Exception as e:
Logger(verbose=True).log("error", f"Failed to upsert documents: {e}", "red")
@@ -186,15 +193,16 @@ class KnowledgeStorage(BaseKnowledgeStorage):
)
return OpenAIEmbeddingFunction(
api_key=os.getenv("OPENAI_API_KEY"), model_name="text-embedding-3-small"
api_key=os.getenv("OPENAI_API_KEY"), model_name="text-embedding-3-small",
)
def _set_embedder_config(self, embedder: Optional[Dict[str, Any]] = None) -> None:
def _set_embedder_config(self, embedder: dict[str, Any] | None = None) -> None:
"""Set the embedding configuration for the knowledge storage.
Args:
embedder_config (Optional[Dict[str, Any]]): Configuration dictionary for the embedder.
If None or empty, defaults to the default embedding function.
"""
self.embedder = (
EmbeddingConfigurator().configure_embedder(embedder)

View File

@@ -1,7 +1,7 @@
from typing import Any, Dict, List
from typing import Any
def extract_knowledge_context(knowledge_snippets: List[Dict[str, Any]]) -> str:
def extract_knowledge_context(knowledge_snippets: list[dict[str, Any]]) -> str:
"""Extract knowledge from the task prompt."""
valid_snippets = [
result["context"]

View File

@@ -1,7 +1,7 @@
import asyncio
import uuid
from datetime import datetime
from typing import Any, Callable, Dict, List, Optional, Type, Union, cast
from collections.abc import Callable
from typing import Any, cast
from pydantic import BaseModel, Field, InstanceOf, PrivateAttr, model_validator
@@ -35,7 +35,7 @@ from crewai.utilities.agent_utils import (
render_text_description_and_args,
show_agent_logs,
)
from crewai.utilities.converter import convert_to_model, generate_model_description
from crewai.utilities.converter import generate_model_description
from crewai.utilities.events.agent_events import (
LiteAgentExecutionCompletedEvent,
LiteAgentExecutionErrorEvent,
@@ -60,15 +60,15 @@ class LiteAgentOutput(BaseModel):
model_config = {"arbitrary_types_allowed": True}
raw: str = Field(description="Raw output of the agent", default="")
pydantic: Optional[BaseModel] = Field(
description="Pydantic output of the agent", default=None
pydantic: BaseModel | None = Field(
description="Pydantic output of the agent", default=None,
)
agent_role: str = Field(description="Role of the agent that produced this output")
usage_metrics: Optional[Dict[str, Any]] = Field(
description="Token usage metrics for this execution", default=None
usage_metrics: dict[str, Any] | None = Field(
description="Token usage metrics for this execution", default=None,
)
def to_dict(self) -> Dict[str, Any]:
def to_dict(self) -> dict[str, Any]:
"""Convert pydantic_output to a dictionary."""
if self.pydantic:
return self.pydantic.model_dump()
@@ -82,8 +82,7 @@ class LiteAgentOutput(BaseModel):
class LiteAgent(FlowTrackable, BaseModel):
"""
A lightweight agent that can process messages and use tools.
"""A lightweight agent that can process messages and use tools.
This agent is simpler than the full Agent class, focusing on direct execution
rather than task delegation. It's designed to be used for simple interactions
@@ -99,6 +98,7 @@ class LiteAgent(FlowTrackable, BaseModel):
max_iterations: Maximum number of iterations for tool usage.
max_execution_time: Maximum execution time in seconds.
response_format: Optional Pydantic model for structured output.
"""
model_config = {"arbitrary_types_allowed": True}
@@ -107,19 +107,19 @@ class LiteAgent(FlowTrackable, BaseModel):
role: str = Field(description="Role of the agent")
goal: str = Field(description="Goal of the agent")
backstory: str = Field(description="Backstory of the agent")
llm: Optional[Union[str, InstanceOf[LLM], Any]] = Field(
default=None, description="Language model that will run the agent"
llm: str | InstanceOf[LLM] | Any | None = Field(
default=None, description="Language model that will run the agent",
)
tools: List[BaseTool] = Field(
default_factory=list, description="Tools at agent's disposal"
tools: list[BaseTool] = Field(
default_factory=list, description="Tools at agent's disposal",
)
# Execution Control Properties
max_iterations: int = Field(
default=15, description="Maximum number of iterations for tool usage"
default=15, description="Maximum number of iterations for tool usage",
)
max_execution_time: Optional[int] = Field(
default=None, description="Maximum execution time in seconds"
max_execution_time: int | None = Field(
default=None, description="Maximum execution time in seconds",
)
respect_context_window: bool = Field(
default=True,
@@ -129,38 +129,38 @@ class LiteAgent(FlowTrackable, BaseModel):
default=True,
description="Whether to use stop words to prevent the LLM from using tools",
)
request_within_rpm_limit: Optional[Callable[[], bool]] = Field(
request_within_rpm_limit: Callable[[], bool] | None = Field(
default=None,
description="Callback to check if the request is within the RPM limit",
)
i18n: I18N = Field(default=I18N(), description="Internationalization settings.")
# Output and Formatting Properties
response_format: Optional[Type[BaseModel]] = Field(
default=None, description="Pydantic model for structured output"
response_format: type[BaseModel] | None = Field(
default=None, description="Pydantic model for structured output",
)
verbose: bool = Field(
default=False, description="Whether to print execution details"
default=False, description="Whether to print execution details",
)
callbacks: List[Callable] = Field(
default=[], description="Callbacks to be used for the agent"
callbacks: list[Callable] = Field(
default=[], description="Callbacks to be used for the agent",
)
# State and Results
tools_results: List[Dict[str, Any]] = Field(
default=[], description="Results of the tools used by the agent."
tools_results: list[dict[str, Any]] = Field(
default=[], description="Results of the tools used by the agent.",
)
# Reference of Agent
original_agent: Optional[BaseAgent] = Field(
default=None, description="Reference to the agent that created this LiteAgent"
original_agent: BaseAgent | None = Field(
default=None, description="Reference to the agent that created this LiteAgent",
)
# Private Attributes
_parsed_tools: List[CrewStructuredTool] = PrivateAttr(default_factory=list)
_parsed_tools: list[CrewStructuredTool] = PrivateAttr(default_factory=list)
_token_process: TokenProcess = PrivateAttr(default_factory=TokenProcess)
_cache_handler: CacheHandler = PrivateAttr(default_factory=CacheHandler)
_key: str = PrivateAttr(default_factory=lambda: str(uuid.uuid4()))
_messages: List[Dict[str, str]] = PrivateAttr(default_factory=list)
_messages: list[dict[str, str]] = PrivateAttr(default_factory=list)
_iterations: int = PrivateAttr(default=0)
_printer: Printer = PrivateAttr(default_factory=Printer)
@@ -169,7 +169,8 @@ class LiteAgent(FlowTrackable, BaseModel):
"""Set up the LLM and other components after initialization."""
self.llm = create_llm(self.llm)
if not isinstance(self.llm, LLM):
raise ValueError("Unable to create LLM instance")
msg = "Unable to create LLM instance"
raise ValueError(msg)
# Initialize callbacks
token_callback = TokenCalcHandler(token_cost_process=self._token_process)
@@ -194,9 +195,8 @@ class LiteAgent(FlowTrackable, BaseModel):
"""Return the original role for compatibility with tool interfaces."""
return self.role
def kickoff(self, messages: Union[str, List[Dict[str, str]]]) -> LiteAgentOutput:
"""
Execute the agent with the given messages.
def kickoff(self, messages: str | list[dict[str, str]]) -> LiteAgentOutput:
"""Execute the agent with the given messages.
Args:
messages: Either a string query or a list of message dictionaries.
@@ -205,6 +205,7 @@ class LiteAgent(FlowTrackable, BaseModel):
Returns:
LiteAgentOutput: The result of the agent execution.
"""
# Create agent info for event emission
agent_info = {
@@ -235,18 +236,18 @@ class LiteAgent(FlowTrackable, BaseModel):
# Execute the agent using invoke loop
agent_finish = self._invoke_loop()
formatted_result: Optional[BaseModel] = None
formatted_result: BaseModel | None = None
if self.response_format:
try:
# Cast to BaseModel to ensure type safety
result = self.response_format.model_validate_json(
agent_finish.output
agent_finish.output,
)
if isinstance(result, BaseModel):
formatted_result = result
except Exception as e:
self._printer.print(
content=f"Failed to parse output into response format: {str(e)}",
content=f"Failed to parse output into response format: {e!s}",
color="yellow",
)
@@ -286,13 +287,12 @@ class LiteAgent(FlowTrackable, BaseModel):
error=str(e),
),
)
raise e
raise
async def kickoff_async(
self, messages: Union[str, List[Dict[str, str]]]
self, messages: str | list[dict[str, str]],
) -> LiteAgentOutput:
"""
Execute the agent asynchronously with the given messages.
"""Execute the agent asynchronously with the given messages.
Args:
messages: Either a string query or a list of message dictionaries.
@@ -301,6 +301,7 @@ class LiteAgent(FlowTrackable, BaseModel):
Returns:
LiteAgentOutput: The result of the agent execution.
"""
return await asyncio.to_thread(self.kickoff, messages)
@@ -319,7 +320,7 @@ class LiteAgent(FlowTrackable, BaseModel):
else:
# Use the prompt template for agents without tools
base_prompt = self.i18n.slice(
"lite_agent_system_prompt_without_tools"
"lite_agent_system_prompt_without_tools",
).format(
role=self.role,
backstory=self.backstory,
@@ -330,14 +331,14 @@ class LiteAgent(FlowTrackable, BaseModel):
if self.response_format:
schema = generate_model_description(self.response_format)
base_prompt += self.i18n.slice("lite_agent_response_format").format(
response_format=schema
response_format=schema,
)
return base_prompt
def _format_messages(
self, messages: Union[str, List[Dict[str, str]]]
) -> List[Dict[str, str]]:
self, messages: str | list[dict[str, str]],
) -> list[dict[str, str]]:
"""Format messages for the LLM."""
if isinstance(messages, str):
messages = [{"role": "user", "content": messages}]
@@ -353,11 +354,11 @@ class LiteAgent(FlowTrackable, BaseModel):
return formatted_messages
def _invoke_loop(self) -> AgentFinish:
"""
Run the agent's thought process until it reaches a conclusion or max iterations.
"""Run the agent's thought process until it reaches a conclusion or max iterations.
Returns:
AgentFinish: The final result of the agent execution.
"""
# Execute the agent loop
formatted_answer = None
@@ -369,7 +370,7 @@ class LiteAgent(FlowTrackable, BaseModel):
printer=self._printer,
i18n=self.i18n,
messages=self._messages,
llm=cast(LLM, self.llm),
llm=cast("LLM", self.llm),
callbacks=self._callbacks,
)
@@ -387,7 +388,7 @@ class LiteAgent(FlowTrackable, BaseModel):
try:
answer = get_llm_response(
llm=cast(LLM, self.llm),
llm=cast("LLM", self.llm),
messages=self._messages,
callbacks=self._callbacks,
printer=self._printer,
@@ -407,7 +408,7 @@ class LiteAgent(FlowTrackable, BaseModel):
self,
event=LLMCallFailedEvent(error=str(e)),
)
raise e
raise
formatted_answer = process_llm_response(answer, self.use_stop_words)
@@ -421,8 +422,8 @@ class LiteAgent(FlowTrackable, BaseModel):
agent_role=self.role,
agent=self.original_agent,
)
except Exception as e:
raise e
except Exception:
raise
formatted_answer = handle_agent_action_core(
formatted_answer=formatted_answer,
@@ -443,20 +444,19 @@ class LiteAgent(FlowTrackable, BaseModel):
except Exception as e:
if e.__class__.__module__.startswith("litellm"):
# Do not retry on litellm errors
raise e
raise
if is_context_length_exceeded(e):
handle_context_length(
respect_context_window=self.respect_context_window,
printer=self._printer,
messages=self._messages,
llm=cast(LLM, self.llm),
llm=cast("LLM", self.llm),
callbacks=self._callbacks,
i18n=self.i18n,
)
continue
else:
handle_unknown_error(self._printer, e)
raise e
handle_unknown_error(self._printer, e)
raise
finally:
self._iterations += 1
@@ -465,7 +465,7 @@ class LiteAgent(FlowTrackable, BaseModel):
self._show_logs(formatted_answer)
return formatted_answer
def _show_logs(self, formatted_answer: Union[AgentAction, AgentFinish]):
def _show_logs(self, formatted_answer: AgentAction | AgentFinish) -> None:
"""Show logs for the agent's execution."""
show_agent_logs(
printer=self._printer,

View File

@@ -6,17 +6,10 @@ import threading
import warnings
from collections import defaultdict
from contextlib import contextmanager
from types import SimpleNamespace
from typing import (
Any,
DefaultDict,
Dict,
List,
Literal,
Optional,
Type,
TypedDict,
Union,
cast,
)
@@ -31,7 +24,6 @@ from crewai.utilities.events.llm_events import (
LLMCallType,
LLMStreamChunkEvent,
)
from crewai.utilities.events.tool_usage_events import ToolExecutionErrorEvent
with warnings.catch_warnings():
warnings.simplefilter("ignore", UserWarning)
@@ -55,7 +47,7 @@ load_dotenv()
class FilteredStream:
def __init__(self, original_stream):
def __init__(self, original_stream) -> None:
self._original_stream = original_stream
self._lock = threading.Lock()
@@ -210,7 +202,7 @@ def suppress_warnings():
with warnings.catch_warnings():
warnings.filterwarnings("ignore")
warnings.filterwarnings(
"ignore", message="open_text is deprecated*", category=DeprecationWarning
"ignore", message="open_text is deprecated*", category=DeprecationWarning,
)
# Redirect stdout and stderr
@@ -226,14 +218,14 @@ def suppress_warnings():
class Delta(TypedDict):
content: Optional[str]
role: Optional[str]
content: str | None
role: str | None
class StreamingChoices(TypedDict):
delta: Delta
index: int
finish_reason: Optional[str]
finish_reason: str | None
class FunctionArgs(BaseModel):
@@ -249,29 +241,31 @@ class LLM(BaseLLM):
def __init__(
self,
model: str,
timeout: Optional[Union[float, int]] = None,
temperature: Optional[float] = None,
top_p: Optional[float] = None,
n: Optional[int] = None,
stop: Optional[Union[str, List[str]]] = None,
max_completion_tokens: Optional[int] = None,
max_tokens: Optional[int] = None,
presence_penalty: Optional[float] = None,
frequency_penalty: Optional[float] = None,
logit_bias: Optional[Dict[int, float]] = None,
response_format: Optional[Type[BaseModel]] = None,
seed: Optional[int] = None,
logprobs: Optional[int] = None,
top_logprobs: Optional[int] = None,
base_url: Optional[str] = None,
api_base: Optional[str] = None,
api_version: Optional[str] = None,
api_key: Optional[str] = None,
callbacks: List[Any] = [],
reasoning_effort: Optional[Literal["none", "low", "medium", "high"]] = None,
timeout: float | None = None,
temperature: float | None = None,
top_p: float | None = None,
n: int | None = None,
stop: str | list[str] | None = None,
max_completion_tokens: int | None = None,
max_tokens: int | None = None,
presence_penalty: float | None = None,
frequency_penalty: float | None = None,
logit_bias: dict[int, float] | None = None,
response_format: type[BaseModel] | None = None,
seed: int | None = None,
logprobs: int | None = None,
top_logprobs: int | None = None,
base_url: str | None = None,
api_base: str | None = None,
api_version: str | None = None,
api_key: str | None = None,
callbacks: list[Any] | None = None,
reasoning_effort: Literal["none", "low", "medium", "high"] | None = None,
stream: bool = False,
**kwargs,
):
) -> None:
if callbacks is None:
callbacks = []
self.model = model
self.timeout = timeout
self.temperature = temperature
@@ -301,7 +295,7 @@ class LLM(BaseLLM):
# Normalize self.stop to always be a List[str]
if stop is None:
self.stop: List[str] = []
self.stop: list[str] = []
elif isinstance(stop, str):
self.stop = [stop]
else:
@@ -318,15 +312,16 @@ class LLM(BaseLLM):
Returns:
bool: True if the model is from Anthropic, False otherwise.
"""
ANTHROPIC_PREFIXES = ("anthropic/", "claude-", "claude/")
return any(prefix in model.lower() for prefix in ANTHROPIC_PREFIXES)
def _prepare_completion_params(
self,
messages: Union[str, List[Dict[str, str]]],
tools: Optional[List[dict]] = None,
) -> Dict[str, Any]:
messages: str | list[dict[str, str]],
tools: list[dict] | None = None,
) -> dict[str, Any]:
"""Prepare parameters for the completion call.
Args:
@@ -337,6 +332,7 @@ class LLM(BaseLLM):
Returns:
Dict[str, Any]: Parameters for the completion call
"""
# --- 1) Format messages according to provider requirements
if isinstance(messages, str):
@@ -375,9 +371,9 @@ class LLM(BaseLLM):
def _handle_streaming_response(
self,
params: Dict[str, Any],
callbacks: Optional[List[Any]] = None,
available_functions: Optional[Dict[str, Any]] = None,
params: dict[str, Any],
callbacks: list[Any] | None = None,
available_functions: dict[str, Any] | None = None,
) -> str:
"""Handle a streaming response from the LLM.
@@ -391,6 +387,7 @@ class LLM(BaseLLM):
Raises:
Exception: If no content is received from the streaming response
"""
# --- 1) Initialize response tracking
full_response = ""
@@ -399,8 +396,8 @@ class LLM(BaseLLM):
usage_info = None
tool_calls = None
accumulated_tool_args: DefaultDict[int, AccumulatedToolArgs] = defaultdict(
AccumulatedToolArgs
accumulated_tool_args: defaultdict[int, AccumulatedToolArgs] = defaultdict(
AccumulatedToolArgs,
)
# --- 2) Make sure stream is set to True and include usage metrics
@@ -424,16 +421,16 @@ class LLM(BaseLLM):
choices = chunk["choices"]
elif hasattr(chunk, "choices"):
# Check if choices is not a type but an actual attribute with value
if not isinstance(getattr(chunk, "choices"), type):
choices = getattr(chunk, "choices")
if not isinstance(chunk.choices, type):
choices = chunk.choices
# Try to extract usage information if available
if isinstance(chunk, dict) and "usage" in chunk:
usage_info = chunk["usage"]
elif hasattr(chunk, "usage"):
# Check if usage is not a type but an actual attribute with value
if not isinstance(getattr(chunk, "usage"), type):
usage_info = getattr(chunk, "usage")
if not isinstance(chunk.usage, type):
usage_info = chunk.usage
if choices and len(choices) > 0:
choice = choices[0]
@@ -443,7 +440,7 @@ class LLM(BaseLLM):
if isinstance(choice, dict) and "delta" in choice:
delta = choice["delta"]
elif hasattr(choice, "delta"):
delta = getattr(choice, "delta")
delta = choice.delta
# Extract content from delta
if delta:
@@ -453,7 +450,7 @@ class LLM(BaseLLM):
chunk_content = delta["content"]
# Handle object format
elif hasattr(delta, "content"):
chunk_content = getattr(delta, "content")
chunk_content = delta.content
# Handle case where content might be None or empty
if chunk_content is None and isinstance(delta, dict):
@@ -491,21 +488,21 @@ class LLM(BaseLLM):
# --- 4) Fallback to non-streaming if no content received
if not full_response.strip() and chunk_count == 0:
logging.warning(
"No chunks received in streaming response, falling back to non-streaming"
"No chunks received in streaming response, falling back to non-streaming",
)
non_streaming_params = params.copy()
non_streaming_params["stream"] = False
non_streaming_params.pop(
"stream_options", None
"stream_options", None,
) # Remove stream_options for non-streaming call
return self._handle_non_streaming_response(
non_streaming_params, callbacks, available_functions
non_streaming_params, callbacks, available_functions,
)
# --- 5) Handle empty response with chunks
if not full_response.strip() and chunk_count > 0:
logging.warning(
f"Received {chunk_count} chunks but no content was extracted"
f"Received {chunk_count} chunks but no content was extracted",
)
if last_chunk is not None:
try:
@@ -514,8 +511,8 @@ class LLM(BaseLLM):
if isinstance(last_chunk, dict) and "choices" in last_chunk:
choices = last_chunk["choices"]
elif hasattr(last_chunk, "choices"):
if not isinstance(getattr(last_chunk, "choices"), type):
choices = getattr(last_chunk, "choices")
if not isinstance(last_chunk.choices, type):
choices = last_chunk.choices
if choices and len(choices) > 0:
choice = choices[0]
@@ -525,30 +522,31 @@ class LLM(BaseLLM):
if isinstance(choice, dict) and "message" in choice:
message = choice["message"]
elif hasattr(choice, "message"):
message = getattr(choice, "message")
message = choice.message
if message:
content = None
if isinstance(message, dict) and "content" in message:
content = message["content"]
elif hasattr(message, "content"):
content = getattr(message, "content")
content = message.content
if content:
full_response = content
logging.info(
f"Extracted content from last chunk message: {full_response}"
f"Extracted content from last chunk message: {full_response}",
)
except Exception as e:
logging.debug(f"Error extracting content from last chunk: {e}")
logging.debug(
f"Last chunk format: {type(last_chunk)}, content: {last_chunk}"
f"Last chunk format: {type(last_chunk)}, content: {last_chunk}",
)
# --- 6) If still empty, raise an error instead of using a default response
if not full_response.strip() and len(accumulated_tool_args) == 0:
msg = "No content received from streaming response. Received empty chunks or failed to extract content."
raise Exception(
"No content received from streaming response. Received empty chunks or failed to extract content."
msg,
)
# --- 7) Check for tool calls in the final response
@@ -559,8 +557,8 @@ class LLM(BaseLLM):
if isinstance(last_chunk, dict) and "choices" in last_chunk:
choices = last_chunk["choices"]
elif hasattr(last_chunk, "choices"):
if not isinstance(getattr(last_chunk, "choices"), type):
choices = getattr(last_chunk, "choices")
if not isinstance(last_chunk.choices, type):
choices = last_chunk.choices
if choices and len(choices) > 0:
choice = choices[0]
@@ -569,13 +567,13 @@ class LLM(BaseLLM):
if isinstance(choice, dict) and "message" in choice:
message = choice["message"]
elif hasattr(choice, "message"):
message = getattr(choice, "message")
message = choice.message
if message:
if isinstance(message, dict) and "tool_calls" in message:
tool_calls = message["tool_calls"]
elif hasattr(message, "tool_calls"):
tool_calls = getattr(message, "tool_calls")
tool_calls = message.tool_calls
except Exception as e:
logging.debug(f"Error checking for tool calls: {e}")
# --- 8) If no tool calls or no available functions, return the text response directly
@@ -605,9 +603,9 @@ class LLM(BaseLLM):
# decide whether to summarize the content or abort based on the respect_context_window flag.
raise LLMContextLengthExceededException(str(e))
except Exception as e:
logging.error(f"Error in streaming response: {str(e)}")
logging.exception(f"Error in streaming response: {e!s}")
if full_response.strip():
logging.warning(f"Returning partial response despite error: {str(e)}")
logging.warning(f"Returning partial response despite error: {e!s}")
self._handle_emit_call_events(full_response, LLMCallType.LLM_CALL)
return full_response
@@ -617,13 +615,14 @@ class LLM(BaseLLM):
self,
event=LLMCallFailedEvent(error=str(e)),
)
raise Exception(f"Failed to get streaming response: {str(e)}")
msg = f"Failed to get streaming response: {e!s}"
raise Exception(msg)
def _handle_streaming_tool_calls(
self,
tool_calls: List[ChatCompletionDeltaToolCall],
accumulated_tool_args: DefaultDict[int, AccumulatedToolArgs],
available_functions: Optional[Dict[str, Any]] = None,
tool_calls: list[ChatCompletionDeltaToolCall],
accumulated_tool_args: defaultdict[int, AccumulatedToolArgs],
available_functions: dict[str, Any] | None = None,
) -> None | str:
for tool_call in tool_calls:
current_tool_accumulator = accumulated_tool_args[tool_call.index]
@@ -662,9 +661,9 @@ class LLM(BaseLLM):
def _handle_streaming_callbacks(
self,
callbacks: Optional[List[Any]],
usage_info: Optional[Dict[str, Any]],
last_chunk: Optional[Any],
callbacks: list[Any] | None,
usage_info: dict[str, Any] | None,
last_chunk: Any | None,
) -> None:
"""Handle callbacks with usage info for streaming responses.
@@ -672,6 +671,7 @@ class LLM(BaseLLM):
callbacks: Optional list of callback functions
usage_info: Usage information collected during streaming
last_chunk: The last chunk received from the streaming response
"""
if callbacks and len(callbacks) > 0:
for callback in callbacks:
@@ -688,9 +688,9 @@ class LLM(BaseLLM):
usage_info = last_chunk["usage"]
elif hasattr(last_chunk, "usage"):
if not isinstance(
getattr(last_chunk, "usage"), type
last_chunk.usage, type,
):
usage_info = getattr(last_chunk, "usage")
usage_info = last_chunk.usage
except Exception as e:
logging.debug(f"Error extracting usage info: {e}")
@@ -704,9 +704,9 @@ class LLM(BaseLLM):
def _handle_non_streaming_response(
self,
params: Dict[str, Any],
callbacks: Optional[List[Any]] = None,
available_functions: Optional[Dict[str, Any]] = None,
params: dict[str, Any],
callbacks: list[Any] | None = None,
available_functions: dict[str, Any] | None = None,
) -> str:
"""Handle a non-streaming response from the LLM.
@@ -717,6 +717,7 @@ class LLM(BaseLLM):
Returns:
str: The response text
"""
# --- 1) Make the completion call
try:
@@ -731,7 +732,7 @@ class LLM(BaseLLM):
raise LLMContextLengthExceededException(str(e))
# --- 2) Extract response message and content
response_message = cast(Choices, cast(ModelResponse, response).choices)[
response_message = cast("Choices", cast("ModelResponse", response).choices)[
0
].message
text_response = response_message.content or ""
@@ -768,9 +769,9 @@ class LLM(BaseLLM):
def _handle_tool_call(
self,
tool_calls: List[Any],
available_functions: Optional[Dict[str, Any]] = None,
) -> Optional[str]:
tool_calls: list[Any],
available_functions: dict[str, Any] | None = None,
) -> str | None:
"""Handle a tool call from the LLM.
Args:
@@ -779,6 +780,7 @@ class LLM(BaseLLM):
Returns:
Optional[str]: The result of the tool call, or None if no tool call was made
"""
# --- 1) Validate tool calls and available functions
if not tool_calls or not available_functions:
@@ -805,23 +807,23 @@ class LLM(BaseLLM):
except Exception as e:
# --- 3.4) Handle execution errors
fn = available_functions.get(
function_name, lambda: None
function_name, lambda: None,
) # Ensure fn is always a callable
logging.error(f"Error executing function '{function_name}': {e}")
logging.exception(f"Error executing function '{function_name}': {e}")
assert hasattr(crewai_event_bus, "emit")
crewai_event_bus.emit(
self,
event=LLMCallFailedEvent(error=f"Tool execution error: {str(e)}"),
event=LLMCallFailedEvent(error=f"Tool execution error: {e!s}"),
)
return None
def call(
self,
messages: Union[str, List[Dict[str, str]]],
tools: Optional[List[dict]] = None,
callbacks: Optional[List[Any]] = None,
available_functions: Optional[Dict[str, Any]] = None,
) -> Union[str, Any]:
messages: str | list[dict[str, str]],
tools: list[dict] | None = None,
callbacks: list[Any] | None = None,
available_functions: dict[str, Any] | None = None,
) -> str | Any:
"""High-level LLM call method.
Args:
@@ -844,6 +846,7 @@ class LLM(BaseLLM):
TypeError: If messages format is invalid
ValueError: If response format is not supported
LLMContextLengthExceededException: If input exceeds model's context limit
"""
# --- 1) Emit call started event
assert hasattr(crewai_event_bus, "emit")
@@ -882,12 +885,11 @@ class LLM(BaseLLM):
# --- 7) Make the completion call and handle response
if self.stream:
return self._handle_streaming_response(
params, callbacks, available_functions
)
else:
return self._handle_non_streaming_response(
params, callbacks, available_functions
params, callbacks, available_functions,
)
return self._handle_non_streaming_response(
params, callbacks, available_functions,
)
except LLMContextLengthExceededException:
# Re-raise LLMContextLengthExceededException as it should be handled
@@ -900,15 +902,16 @@ class LLM(BaseLLM):
self,
event=LLMCallFailedEvent(error=str(e)),
)
logging.error(f"LiteLLM call failed: {str(e)}")
logging.exception(f"LiteLLM call failed: {e!s}")
raise
def _handle_emit_call_events(self, response: Any, call_type: LLMCallType):
def _handle_emit_call_events(self, response: Any, call_type: LLMCallType) -> None:
"""Handle the events for the LLM call.
Args:
response (str): The response from the LLM call.
call_type (str): The type of call, either "tool_call" or "llm_call".
"""
assert hasattr(crewai_event_bus, "emit")
crewai_event_bus.emit(
@@ -917,8 +920,8 @@ class LLM(BaseLLM):
)
def _format_messages_for_provider(
self, messages: List[Dict[str, str]]
) -> List[Dict[str, str]]:
self, messages: list[dict[str, str]],
) -> list[dict[str, str]]:
"""Format messages according to provider requirements.
Args:
@@ -931,15 +934,18 @@ class LLM(BaseLLM):
Raises:
TypeError: If messages is None or contains invalid message format.
"""
if messages is None:
raise TypeError("Messages cannot be None")
msg = "Messages cannot be None"
raise TypeError(msg)
# Validate message format first
for msg in messages:
if not isinstance(msg, dict) or "role" not in msg or "content" not in msg:
msg = "Invalid message format. Each message must be a dict with 'role' and 'content' keys"
raise TypeError(
"Invalid message format. Each message must be a dict with 'role' and 'content' keys"
msg,
)
# Handle O1 models specially
@@ -949,7 +955,7 @@ class LLM(BaseLLM):
# Convert system messages to assistant messages
if msg["role"] == "system":
formatted_messages.append(
{"role": "assistant", "content": msg["content"]}
{"role": "assistant", "content": msg["content"]},
)
else:
formatted_messages.append(msg)
@@ -977,9 +983,8 @@ class LLM(BaseLLM):
return messages
def _get_custom_llm_provider(self) -> Optional[str]:
"""
Derives the custom_llm_provider from the model string.
def _get_custom_llm_provider(self) -> str | None:
"""Derives the custom_llm_provider from the model string.
- For example, if the model is "openrouter/deepseek/deepseek-chat", returns "openrouter".
- If the model is "gemini/gemini-1.5-pro", returns "gemini".
- If there is no '/', defaults to "openai".
@@ -989,8 +994,7 @@ class LLM(BaseLLM):
return None
def _validate_call_params(self) -> None:
"""
Validate parameters before making a call. Currently this only checks if
"""Validate parameters before making a call. Currently this only checks if
a response_format is provided and whether the model supports it.
The custom_llm_provider is dynamically determined from the model:
- E.g., "openrouter/deepseek/deepseek-chat" yields "openrouter"
@@ -1002,19 +1006,22 @@ class LLM(BaseLLM):
model=self.model,
custom_llm_provider=provider,
):
raise ValueError(
msg = (
f"The model {self.model} does not support response_format for provider '{provider}'. "
"Please remove response_format or use a supported model."
)
raise ValueError(
msg,
)
def supports_function_calling(self) -> bool:
try:
provider = self._get_custom_llm_provider()
return litellm.utils.supports_function_calling(
self.model, custom_llm_provider=provider
self.model, custom_llm_provider=provider,
)
except Exception as e:
logging.error(f"Failed to check function calling support: {str(e)}")
logging.exception(f"Failed to check function calling support: {e!s}")
return False
def supports_stop_words(self) -> bool:
@@ -1022,16 +1029,16 @@ class LLM(BaseLLM):
params = get_supported_openai_params(model=self.model)
return params is not None and "stop" in params
except Exception as e:
logging.error(f"Failed to get supported params: {str(e)}")
logging.exception(f"Failed to get supported params: {e!s}")
return False
def get_context_window_size(self) -> int:
"""
Returns the context window size, using 75% of the maximum to avoid
"""Returns the context window size, using 75% of the maximum to avoid
cutting off messages mid-thread.
Raises:
ValueError: If a model's context window size is outside valid bounds (1024-2097152)
"""
if self.context_window_size != 0:
return self.context_window_size
@@ -1042,21 +1049,21 @@ class LLM(BaseLLM):
# Validate all context window sizes
for key, value in LLM_CONTEXT_WINDOW_SIZES.items():
if value < MIN_CONTEXT or value > MAX_CONTEXT:
msg = f"Context window for {key} must be between {MIN_CONTEXT} and {MAX_CONTEXT}"
raise ValueError(
f"Context window for {key} must be between {MIN_CONTEXT} and {MAX_CONTEXT}"
msg,
)
self.context_window_size = int(
DEFAULT_CONTEXT_WINDOW_SIZE * CONTEXT_WINDOW_USAGE_RATIO
DEFAULT_CONTEXT_WINDOW_SIZE * CONTEXT_WINDOW_USAGE_RATIO,
)
for key, value in LLM_CONTEXT_WINDOW_SIZES.items():
if self.model.startswith(key):
self.context_window_size = int(value * CONTEXT_WINDOW_USAGE_RATIO)
return self.context_window_size
def set_callbacks(self, callbacks: List[Any]):
"""
Attempt to keep a single set of callbacks in litellm by removing old
def set_callbacks(self, callbacks: list[Any]) -> None:
"""Attempt to keep a single set of callbacks in litellm by removing old
duplicates and adding new ones.
"""
with suppress_warnings():
@@ -1071,9 +1078,8 @@ class LLM(BaseLLM):
litellm.callbacks = callbacks
def set_env_callbacks(self):
"""
Sets the success and failure callbacks for the LiteLLM library from environment variables.
def set_env_callbacks(self) -> None:
"""Sets the success and failure callbacks for the LiteLLM library from environment variables.
This method reads the `LITELLM_SUCCESS_CALLBACKS` and `LITELLM_FAILURE_CALLBACKS`
environment variables, which should contain comma-separated lists of callback names.
@@ -1089,6 +1095,7 @@ class LLM(BaseLLM):
This will set `litellm.success_callback` to ["langfuse", "langsmith"] and
`litellm.failure_callback` to ["langfuse"].
"""
with suppress_warnings():
success_callbacks_str = os.environ.get("LITELLM_SUCCESS_CALLBACKS", "")

View File

@@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, List, Optional, Union
from typing import Any
class BaseLLM(ABC):
@@ -17,17 +17,18 @@ class BaseLLM(ABC):
Attributes:
stop (list): A list of stop sequences that the LLM should use to stop generation.
This is used by the CrewAgentExecutor and other components.
"""
model: str
temperature: Optional[float] = None
stop: Optional[List[str]] = None
temperature: float | None = None
stop: list[str] | None = None
def __init__(
self,
model: str,
temperature: Optional[float] = None,
):
temperature: float | None = None,
) -> None:
"""Initialize the BaseLLM with default attributes.
This constructor sets default values for attributes that are expected
@@ -43,11 +44,11 @@ class BaseLLM(ABC):
@abstractmethod
def call(
self,
messages: Union[str, List[Dict[str, str]]],
tools: Optional[List[dict]] = None,
callbacks: Optional[List[Any]] = None,
available_functions: Optional[Dict[str, Any]] = None,
) -> Union[str, Any]:
messages: str | list[dict[str, str]],
tools: list[dict] | None = None,
callbacks: list[Any] | None = None,
available_functions: dict[str, Any] | None = None,
) -> str | Any:
"""Call the LLM with the given messages.
Args:
@@ -70,14 +71,15 @@ class BaseLLM(ABC):
ValueError: If the messages format is invalid.
TimeoutError: If the LLM request times out.
RuntimeError: If the LLM request fails for other reasons.
"""
pass
def supports_stop_words(self) -> bool:
"""Check if the LLM supports stop words.
Returns:
bool: True if the LLM supports stop words, False otherwise.
"""
return True # Default implementation assumes support for stop words
@@ -86,6 +88,7 @@ class BaseLLM(ABC):
Returns:
int: The number of tokens/characters the model can handle.
"""
# Default implementation - subclasses should override with model-specific values
return 4096

View File

@@ -1,4 +1,4 @@
from typing import Any, Dict, List, Optional, Union
from typing import Any
import aisuite as ai
@@ -6,17 +6,17 @@ from crewai.llms.base_llm import BaseLLM
class AISuiteLLM(BaseLLM):
def __init__(self, model: str, temperature: Optional[float] = None, **kwargs):
def __init__(self, model: str, temperature: float | None = None, **kwargs) -> None:
super().__init__(model, temperature, **kwargs)
self.client = ai.Client()
def call(
self,
messages: Union[str, List[Dict[str, str]]],
tools: Optional[List[dict]] = None,
callbacks: Optional[List[Any]] = None,
available_functions: Optional[Dict[str, Any]] = None,
) -> Union[str, Any]:
messages: str | list[dict[str, str]],
tools: list[dict] | None = None,
callbacks: list[Any] | None = None,
available_functions: dict[str, Any] | None = None,
) -> str | Any:
completion_params = self._prepare_completion_params(messages, tools)
response = self.client.chat.completions.create(**completion_params)
@@ -24,9 +24,9 @@ class AISuiteLLM(BaseLLM):
def _prepare_completion_params(
self,
messages: Union[str, List[Dict[str, str]]],
tools: Optional[List[dict]] = None,
) -> Dict[str, Any]:
messages: str | list[dict[str, str]],
tools: list[dict] | None = None,
) -> dict[str, Any]:
return {
"model": self.model,
"messages": messages,

View File

@@ -1,4 +1,4 @@
from typing import Any, Dict, Optional
from typing import Any
from crewai.memory import (
EntityMemory,
@@ -12,13 +12,13 @@ from crewai.memory import (
class ContextualMemory:
def __init__(
self,
memory_config: Optional[Dict[str, Any]],
memory_config: dict[str, Any] | None,
stm: ShortTermMemory,
ltm: LongTermMemory,
em: EntityMemory,
um: UserMemory,
exm: ExternalMemory,
):
) -> None:
if memory_config is not None:
self.memory_provider = memory_config.get("provider")
else:
@@ -30,8 +30,7 @@ class ContextualMemory:
self.exm = exm
def build_context_for_task(self, task, context) -> str:
"""
Automatically builds a minimal, highly relevant set of contextual information
"""Automatically builds a minimal, highly relevant set of contextual information
for a given task.
"""
query = f"{task.description} {context}".strip()
@@ -49,11 +48,9 @@ class ContextualMemory:
return "\n".join(filter(None, context))
def _fetch_stm_context(self, query) -> str:
"""
Fetches recent relevant insights from STM related to the task's description and expected_output,
"""Fetches recent relevant insights from STM related to the task's description and expected_output,
formatted as bullet points.
"""
if self.stm is None:
return ""
@@ -62,16 +59,14 @@ class ContextualMemory:
[
f"- {result['memory'] if self.memory_provider == 'mem0' else result['context']}"
for result in stm_results
]
],
)
return f"Recent Insights:\n{formatted_results}" if stm_results else ""
def _fetch_ltm_context(self, task) -> Optional[str]:
"""
Fetches historical data or insights from LTM that are relevant to the task's description and expected_output,
def _fetch_ltm_context(self, task) -> str | None:
"""Fetches historical data or insights from LTM that are relevant to the task's description and expected_output,
formatted as bullet points.
"""
if self.ltm is None:
return ""
@@ -90,8 +85,7 @@ class ContextualMemory:
return f"Historical Data:\n{formatted_results}" if ltm_results else ""
def _fetch_entity_context(self, query) -> str:
"""
Fetches relevant entity information from Entity Memory related to the task's description and expected_output,
"""Fetches relevant entity information from Entity Memory related to the task's description and expected_output,
formatted as bullet points.
"""
if self.em is None:
@@ -102,19 +96,20 @@ class ContextualMemory:
[
f"- {result['memory'] if self.memory_provider == 'mem0' else result['context']}"
for result in em_results
] # type: ignore # Invalid index type "str" for "str"; expected type "SupportsIndex | slice"
], # type: ignore # Invalid index type "str" for "str"; expected type "SupportsIndex | slice"
)
return f"Entities:\n{formatted_results}" if em_results else ""
def _fetch_user_context(self, query: str) -> str:
"""
Fetches and formats relevant user information from User Memory.
"""Fetches and formats relevant user information from User Memory.
Args:
query (str): The search query to find relevant user memories.
Returns:
str: Formatted user memories as bullet points, or an empty string if none found.
"""
"""
if self.um is None:
return ""
@@ -128,12 +123,14 @@ class ContextualMemory:
return f"User memories/preferences:\n{formatted_memories}"
def _fetch_external_context(self, query: str) -> str:
"""
Fetches and formats relevant information from External Memory.
"""Fetches and formats relevant information from External Memory.
Args:
query (str): The search query to find relevant information.
Returns:
str: Formatted information as bullet points, or an empty string if none found.
"""
if self.exm is None:
return ""

View File

@@ -1,4 +1,3 @@
from typing import Optional
from pydantic import PrivateAttr
@@ -8,15 +7,14 @@ from crewai.memory.storage.rag_storage import RAGStorage
class EntityMemory(Memory):
"""
EntityMemory class for managing structured information about entities
"""EntityMemory class for managing structured information about entities
and their relationships using SQLite storage.
Inherits from the Memory class.
"""
_memory_provider: Optional[str] = PrivateAttr()
_memory_provider: str | None = PrivateAttr()
def __init__(self, crew=None, embedder_config=None, storage=None, path=None):
def __init__(self, crew=None, embedder_config=None, storage=None, path=None) -> None:
if crew and hasattr(crew, "memory_config") and crew.memory_config is not None:
memory_provider = crew.memory_config.get("provider")
else:
@@ -26,8 +24,9 @@ class EntityMemory(Memory):
try:
from crewai.memory.storage.mem0_storage import Mem0Storage
except ImportError:
msg = "Mem0 is not installed. Please install it with `pip install mem0ai`."
raise ImportError(
"Mem0 is not installed. Please install it with `pip install mem0ai`."
msg,
)
storage = Mem0Storage(type="entities", crew=crew)
else:
@@ -63,4 +62,5 @@ class EntityMemory(Memory):
try:
self.storage.reset()
except Exception as e:
raise Exception(f"An error occurred while resetting the entity memory: {e}")
msg = f"An error occurred while resetting the entity memory: {e}"
raise Exception(msg)

View File

@@ -5,7 +5,7 @@ class EntityMemoryItem:
type: str,
description: str,
relationships: str,
):
) -> None:
self.name = name
self.type = type
self.description = description

View File

@@ -1,4 +1,4 @@
from typing import TYPE_CHECKING, Any, Dict, Optional
from typing import TYPE_CHECKING, Any
from crewai.memory.external.external_memory_item import ExternalMemoryItem
from crewai.memory.memory import Memory
@@ -9,41 +9,44 @@ if TYPE_CHECKING:
class ExternalMemory(Memory):
def __init__(self, storage: Optional[Storage] = None, **data: Any):
def __init__(self, storage: Storage | None = None, **data: Any) -> None:
super().__init__(storage=storage, **data)
@staticmethod
def _configure_mem0(crew: Any, config: Dict[str, Any]) -> "Mem0Storage":
def _configure_mem0(crew: Any, config: dict[str, Any]) -> "Mem0Storage":
from crewai.memory.storage.mem0_storage import Mem0Storage
return Mem0Storage(type="external", crew=crew, config=config)
@staticmethod
def external_supported_storages() -> Dict[str, Any]:
def external_supported_storages() -> dict[str, Any]:
return {
"mem0": ExternalMemory._configure_mem0,
}
@staticmethod
def create_storage(crew: Any, embedder_config: Optional[Dict[str, Any]]) -> Storage:
def create_storage(crew: Any, embedder_config: dict[str, Any] | None) -> Storage:
if not embedder_config:
raise ValueError("embedder_config is required")
msg = "embedder_config is required"
raise ValueError(msg)
if "provider" not in embedder_config:
raise ValueError("embedder_config must include a 'provider' key")
msg = "embedder_config must include a 'provider' key"
raise ValueError(msg)
provider = embedder_config["provider"]
supported_storages = ExternalMemory.external_supported_storages()
if provider not in supported_storages:
raise ValueError(f"Provider {provider} not supported")
msg = f"Provider {provider} not supported"
raise ValueError(msg)
return supported_storages[provider](crew, embedder_config.get("config", {}))
def save(
self,
value: Any,
metadata: Optional[Dict[str, Any]] = None,
agent: Optional[str] = None,
metadata: dict[str, Any] | None = None,
agent: str | None = None,
) -> None:
"""Saves a value into the external storage."""
item = ExternalMemoryItem(value=value, metadata=metadata, agent=agent)

View File

@@ -1,13 +1,13 @@
from typing import Any, Dict, Optional
from typing import Any
class ExternalMemoryItem:
def __init__(
self,
value: Any,
metadata: Optional[Dict[str, Any]] = None,
agent: Optional[str] = None,
):
metadata: dict[str, Any] | None = None,
agent: str | None = None,
) -> None:
self.value = value
self.metadata = metadata
self.agent = agent

View File

@@ -1,4 +1,4 @@
from typing import Any, Dict, List
from typing import Any
from crewai.memory.long_term.long_term_memory_item import LongTermMemoryItem
from crewai.memory.memory import Memory
@@ -6,15 +6,14 @@ from crewai.memory.storage.ltm_sqlite_storage import LTMSQLiteStorage
class LongTermMemory(Memory):
"""
LongTermMemory class for managing cross runs data related to overall crew's
"""LongTermMemory class for managing cross runs data related to overall crew's
execution and performance.
Inherits from the Memory class and utilizes an instance of a class that
adheres to the Storage for data storage, specifically working with
LongTermMemoryItem instances.
"""
def __init__(self, storage=None, path=None):
def __init__(self, storage=None, path=None) -> None:
if not storage:
storage = LTMSQLiteStorage(db_path=path) if path else LTMSQLiteStorage()
super().__init__(storage=storage)
@@ -29,7 +28,7 @@ class LongTermMemory(Memory):
datetime=item.datetime,
)
def search(self, task: str, latest_n: int = 3) -> List[Dict[str, Any]]: # type: ignore # signature of "search" incompatible with supertype "Memory"
def search(self, task: str, latest_n: int = 3) -> list[dict[str, Any]]: # type: ignore # signature of "search" incompatible with supertype "Memory"
return self.storage.load(task, latest_n) # type: ignore # BUG?: "Storage" has no attribute "load"
def reset(self) -> None:

View File

@@ -1,4 +1,4 @@
from typing import Any, Dict, Optional, Union
from typing import Any
class LongTermMemoryItem:
@@ -8,9 +8,9 @@ class LongTermMemoryItem:
task: str,
expected_output: str,
datetime: str,
quality: Optional[Union[int, float]] = None,
metadata: Optional[Dict[str, Any]] = None,
):
quality: float | None = None,
metadata: dict[str, Any] | None = None,
) -> None:
self.task = task
self.agent = agent
self.quality = quality

View File

@@ -1,26 +1,24 @@
from typing import Any, Dict, List, Optional
from typing import Any
from pydantic import BaseModel
class Memory(BaseModel):
"""
Base class for memory, now supporting agent tags and generic metadata.
"""
"""Base class for memory, now supporting agent tags and generic metadata."""
embedder_config: Optional[Dict[str, Any]] = None
crew: Optional[Any] = None
embedder_config: dict[str, Any] | None = None
crew: Any | None = None
storage: Any
def __init__(self, storage: Any, **data: Any):
def __init__(self, storage: Any, **data: Any) -> None:
super().__init__(storage=storage, **data)
def save(
self,
value: Any,
metadata: Optional[Dict[str, Any]] = None,
agent: Optional[str] = None,
metadata: dict[str, Any] | None = None,
agent: str | None = None,
) -> None:
metadata = metadata or {}
if agent:
@@ -33,9 +31,9 @@ class Memory(BaseModel):
query: str,
limit: int = 3,
score_threshold: float = 0.35,
) -> List[Any]:
) -> list[Any]:
return self.storage.search(
query=query, limit=limit, score_threshold=score_threshold
query=query, limit=limit, score_threshold=score_threshold,
)
def set_crew(self, crew: Any) -> "Memory":

View File

@@ -1,4 +1,4 @@
from typing import Any, Dict, Optional
from typing import Any
from pydantic import PrivateAttr
@@ -8,17 +8,16 @@ from crewai.memory.storage.rag_storage import RAGStorage
class ShortTermMemory(Memory):
"""
ShortTermMemory class for managing transient data related to immediate tasks
"""ShortTermMemory class for managing transient data related to immediate tasks
and interactions.
Inherits from the Memory class and utilizes an instance of a class that
adheres to the Storage for data storage, specifically working with
MemoryItem instances.
"""
_memory_provider: Optional[str] = PrivateAttr()
_memory_provider: str | None = PrivateAttr()
def __init__(self, crew=None, embedder_config=None, storage=None, path=None):
def __init__(self, crew=None, embedder_config=None, storage=None, path=None) -> None:
if crew and hasattr(crew, "memory_config") and crew.memory_config is not None:
memory_provider = crew.memory_config.get("provider")
else:
@@ -28,8 +27,9 @@ class ShortTermMemory(Memory):
try:
from crewai.memory.storage.mem0_storage import Mem0Storage
except ImportError:
msg = "Mem0 is not installed. Please install it with `pip install mem0ai`."
raise ImportError(
"Mem0 is not installed. Please install it with `pip install mem0ai`."
msg,
)
storage = Mem0Storage(type="short_term", crew=crew)
else:
@@ -49,8 +49,8 @@ class ShortTermMemory(Memory):
def save(
self,
value: Any,
metadata: Optional[Dict[str, Any]] = None,
agent: Optional[str] = None,
metadata: dict[str, Any] | None = None,
agent: str | None = None,
) -> None:
item = ShortTermMemoryItem(data=value, metadata=metadata, agent=agent)
if self._memory_provider == "mem0":
@@ -65,13 +65,14 @@ class ShortTermMemory(Memory):
score_threshold: float = 0.35,
):
return self.storage.search(
query=query, limit=limit, score_threshold=score_threshold
query=query, limit=limit, score_threshold=score_threshold,
) # type: ignore # BUG? The reference is to the parent class, but the parent class does not have this parameters
def reset(self) -> None:
try:
self.storage.reset()
except Exception as e:
msg = f"An error occurred while resetting the short-term memory: {e}"
raise Exception(
f"An error occurred while resetting the short-term memory: {e}"
msg,
)

View File

@@ -1,13 +1,13 @@
from typing import Any, Dict, Optional
from typing import Any
class ShortTermMemoryItem:
def __init__(
self,
data: Any,
agent: Optional[str] = None,
metadata: Optional[Dict[str, Any]] = None,
):
agent: str | None = None,
metadata: dict[str, Any] | None = None,
) -> None:
self.data = data
self.agent = agent
self.metadata = metadata if metadata is not None else {}

View File

@@ -1,11 +1,9 @@
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional
from typing import Any
class BaseRAGStorage(ABC):
"""
Base class for RAG-based Storage implementations.
"""
"""Base class for RAG-based Storage implementations."""
app: Any | None = None
@@ -13,9 +11,9 @@ class BaseRAGStorage(ABC):
self,
type: str,
allow_reset: bool = True,
embedder_config: Optional[Dict[str, Any]] = None,
embedder_config: dict[str, Any] | None = None,
crew: Any = None,
):
) -> None:
self.type = type
self.allow_reset = allow_reset
self.embedder_config = embedder_config
@@ -25,52 +23,44 @@ class BaseRAGStorage(ABC):
def _initialize_agents(self) -> str:
if self.crew:
return "_".join(
[self._sanitize_role(agent.role) for agent in self.crew.agents]
[self._sanitize_role(agent.role) for agent in self.crew.agents],
)
return ""
@abstractmethod
def _sanitize_role(self, role: str) -> str:
"""Sanitizes agent roles to ensure valid directory names."""
pass
@abstractmethod
def save(self, value: Any, metadata: Dict[str, Any]) -> None:
def save(self, value: Any, metadata: dict[str, Any]) -> None:
"""Save a value with metadata to the storage."""
pass
@abstractmethod
def search(
self,
query: str,
limit: int = 3,
filter: Optional[dict] = None,
filter: dict | None = None,
score_threshold: float = 0.35,
) -> List[Any]:
) -> list[Any]:
"""Search for entries in the storage."""
pass
@abstractmethod
def reset(self) -> None:
"""Reset the storage."""
pass
@abstractmethod
def _generate_embedding(
self, text: str, metadata: Optional[Dict[str, Any]] = None
self, text: str, metadata: dict[str, Any] | None = None,
) -> Any:
"""Generate an embedding for the given text and metadata."""
pass
@abstractmethod
def _initialize_app(self):
"""Initialize the vector db."""
pass
def setup_config(self, config: Dict[str, Any]):
def setup_config(self, config: dict[str, Any]) -> None:
"""Setup the config of the storage."""
pass
def initialize_client(self):
"""Initialize the client of the storage. This should setup the app and the db collection"""
pass
def initialize_client(self) -> None:
"""Initialize the client of the storage. This should setup the app and the db collection."""

View File

@@ -1,15 +1,15 @@
from typing import Any, Dict, List
from typing import Any
class Storage:
"""Abstract base class defining the storage interface"""
"""Abstract base class defining the storage interface."""
def save(self, value: Any, metadata: Dict[str, Any]) -> None:
def save(self, value: Any, metadata: dict[str, Any]) -> None:
pass
def search(
self, query: str, limit: int, score_threshold: float
) -> Dict[str, Any] | List[Any]:
self, query: str, limit: int, score_threshold: float,
) -> dict[str, Any] | list[Any]:
return {}
def reset(self) -> None:

View File

@@ -2,7 +2,7 @@ import json
import logging
import sqlite3
from pathlib import Path
from typing import Any, Dict, List, Optional
from typing import Any
from crewai.task import Task
from crewai.utilities import Printer
@@ -14,12 +14,10 @@ logger = logging.getLogger(__name__)
class KickoffTaskOutputsSQLiteStorage:
"""
An updated SQLite storage class for kickoff task outputs storage.
"""
"""An updated SQLite storage class for kickoff task outputs storage."""
def __init__(
self, db_path: Optional[str] = None
self, db_path: str | None = None,
) -> None:
if db_path is None:
# Get the parent directory of the default db path and create our db file there
@@ -37,6 +35,7 @@ class KickoffTaskOutputsSQLiteStorage:
Raises:
DatabaseOperationError: If database initialization fails due to SQLite errors.
"""
try:
with sqlite3.connect(self.db_path) as conn:
@@ -52,22 +51,22 @@ class KickoffTaskOutputsSQLiteStorage:
was_replayed BOOLEAN,
timestamp DATETIME DEFAULT CURRENT_TIMESTAMP
)
"""
""",
)
conn.commit()
except sqlite3.Error as e:
error_msg = DatabaseError.format_error(DatabaseError.INIT_ERROR, e)
logger.error(error_msg)
logger.exception(error_msg)
raise DatabaseOperationError(error_msg, e)
def add(
self,
task: Task,
output: Dict[str, Any],
output: dict[str, Any],
task_index: int,
was_replayed: bool = False,
inputs: Dict[str, Any] = {},
inputs: dict[str, Any] | None = None,
) -> None:
"""Add a new task output record to the database.
@@ -80,7 +79,10 @@ class KickoffTaskOutputsSQLiteStorage:
Raises:
DatabaseOperationError: If saving the task output fails due to SQLite errors.
"""
if inputs is None:
inputs = {}
try:
with sqlite3.connect(self.db_path) as conn:
conn.execute("BEGIN TRANSACTION")
@@ -103,7 +105,7 @@ class KickoffTaskOutputsSQLiteStorage:
conn.commit()
except sqlite3.Error as e:
error_msg = DatabaseError.format_error(DatabaseError.SAVE_ERROR, e)
logger.error(error_msg)
logger.exception(error_msg)
raise DatabaseOperationError(error_msg, e)
def update(
@@ -123,6 +125,7 @@ class KickoffTaskOutputsSQLiteStorage:
Raises:
DatabaseOperationError: If updating the task output fails due to SQLite errors.
"""
try:
with sqlite3.connect(self.db_path) as conn:
@@ -136,7 +139,7 @@ class KickoffTaskOutputsSQLiteStorage:
values.append(
json.dumps(value, cls=CrewJSONEncoder)
if isinstance(value, dict)
else value
else value,
)
query = f"UPDATE latest_kickoff_task_outputs SET {', '.join(fields)} WHERE task_index = ?" # nosec
@@ -149,10 +152,10 @@ class KickoffTaskOutputsSQLiteStorage:
logger.warning(f"No row found with task_index {task_index}. No update performed.")
except sqlite3.Error as e:
error_msg = DatabaseError.format_error(DatabaseError.UPDATE_ERROR, e)
logger.error(error_msg)
logger.exception(error_msg)
raise DatabaseOperationError(error_msg, e)
def load(self) -> List[Dict[str, Any]]:
def load(self) -> list[dict[str, Any]]:
"""Load all task output records from the database.
Returns:
@@ -162,6 +165,7 @@ class KickoffTaskOutputsSQLiteStorage:
Raises:
DatabaseOperationError: If loading task outputs fails due to SQLite errors.
"""
try:
with sqlite3.connect(self.db_path) as conn:
@@ -190,7 +194,7 @@ class KickoffTaskOutputsSQLiteStorage:
except sqlite3.Error as e:
error_msg = DatabaseError.format_error(DatabaseError.LOAD_ERROR, e)
logger.error(error_msg)
logger.exception(error_msg)
raise DatabaseOperationError(error_msg, e)
def delete_all(self) -> None:
@@ -201,6 +205,7 @@ class KickoffTaskOutputsSQLiteStorage:
Raises:
DatabaseOperationError: If deleting task outputs fails due to SQLite errors.
"""
try:
with sqlite3.connect(self.db_path) as conn:
@@ -210,5 +215,5 @@ class KickoffTaskOutputsSQLiteStorage:
conn.commit()
except sqlite3.Error as e:
error_msg = DatabaseError.format_error(DatabaseError.DELETE_ERROR, e)
logger.error(error_msg)
logger.exception(error_msg)
raise DatabaseOperationError(error_msg, e)

View File

@@ -1,19 +1,17 @@
import json
import sqlite3
from pathlib import Path
from typing import Any, Dict, List, Optional, Union
from typing import Any
from crewai.utilities import Printer
from crewai.utilities.paths import db_storage_path
class LTMSQLiteStorage:
"""
An updated SQLite storage class for LTM data storage.
"""
"""An updated SQLite storage class for LTM data storage."""
def __init__(
self, db_path: Optional[str] = None
self, db_path: str | None = None,
) -> None:
if db_path is None:
# Get the parent directory of the default db path and create our db file there
@@ -24,10 +22,8 @@ class LTMSQLiteStorage:
Path(self.db_path).parent.mkdir(parents=True, exist_ok=True)
self._initialize_db()
def _initialize_db(self):
"""
Initializes the SQLite database and creates LTM table
"""
def _initialize_db(self) -> None:
"""Initializes the SQLite database and creates LTM table."""
try:
with sqlite3.connect(self.db_path) as conn:
cursor = conn.cursor()
@@ -40,7 +36,7 @@ class LTMSQLiteStorage:
datetime TEXT,
score REAL
)
"""
""",
)
conn.commit()
@@ -53,9 +49,9 @@ class LTMSQLiteStorage:
def save(
self,
task_description: str,
metadata: Dict[str, Any],
metadata: dict[str, Any],
datetime: str,
score: Union[int, float],
score: float,
) -> None:
"""Saves data to the LTM table with error handling."""
try:
@@ -76,8 +72,8 @@ class LTMSQLiteStorage:
)
def load(
self, task_description: str, latest_n: int
) -> Optional[List[Dict[str, Any]]]:
self, task_description: str, latest_n: int,
) -> list[dict[str, Any]] | None:
"""Queries the LTM table by task description with error handling."""
try:
with sqlite3.connect(self.db_path) as conn:
@@ -125,4 +121,3 @@ class LTMSQLiteStorage:
content=f"MEMORY ERROR: An error occurred while deleting all rows in LTM: {e}",
color="red",
)
return None

View File

@@ -1,5 +1,5 @@
import os
from typing import Any, Dict, List
from typing import Any
from mem0 import Memory, MemoryClient
@@ -7,17 +7,15 @@ from crewai.memory.storage.interface import Storage
class Mem0Storage(Storage):
"""
Extends Storage to handle embedding and searching across entities using Mem0.
"""
"""Extends Storage to handle embedding and searching across entities using Mem0."""
def __init__(self, type, crew=None, config=None):
def __init__(self, type, crew=None, config=None) -> None:
super().__init__()
supported_types = ["user", "short_term", "long_term", "entities", "external"]
if type not in supported_types:
raise ValueError(
f"Invalid type '{type}' for Mem0Storage. Must be one of: "
+ ", ".join(supported_types)
+ ", ".join(supported_types),
)
self.memory_type = type
@@ -29,7 +27,8 @@ class Mem0Storage(Storage):
# User ID is required for user memory type "user" since it's used as a unique identifier for the user.
user_id = self._get_user_id()
if type == "user" and not user_id:
raise ValueError("User ID is required for user memory type")
msg = "User ID is required for user memory type"
raise ValueError(msg)
# API key in memory config overrides the environment variable
config = self._get_config()
@@ -42,23 +41,20 @@ class Mem0Storage(Storage):
if mem0_api_key:
if mem0_org_id and mem0_project_id:
self.memory = MemoryClient(
api_key=mem0_api_key, org_id=mem0_org_id, project_id=mem0_project_id
api_key=mem0_api_key, org_id=mem0_org_id, project_id=mem0_project_id,
)
else:
self.memory = MemoryClient(api_key=mem0_api_key)
elif mem0_local_config and len(mem0_local_config):
self.memory = Memory.from_config(mem0_local_config)
else:
if mem0_local_config and len(mem0_local_config):
self.memory = Memory.from_config(mem0_local_config)
else:
self.memory = Memory()
self.memory = Memory()
def _sanitize_role(self, role: str) -> str:
"""
Sanitizes agent roles to ensure valid directory names.
"""
"""Sanitizes agent roles to ensure valid directory names."""
return role.replace("\n", "").replace(" ", "_").replace("/", "_")
def save(self, value: Any, metadata: Dict[str, Any]) -> None:
def save(self, value: Any, metadata: dict[str, Any]) -> None:
user_id = self._get_user_id()
agent_name = self._get_agent_name()
params = None
@@ -97,7 +93,7 @@ class Mem0Storage(Storage):
query: str,
limit: int = 3,
score_threshold: float = 0.35,
) -> List[Any]:
) -> list[Any]:
params = {"query": query, "limit": limit, "output_format": "v1.1"}
if user_id := self._get_user_id():
params["user_id"] = user_id
@@ -120,7 +116,7 @@ class Mem0Storage(Storage):
# automatically when the crew is created.
if isinstance(self.memory, Memory):
del params["metadata"], params["output_format"]
results = self.memory.search(**params)
return [r for r in results["results"] if r["score"] >= score_threshold]
@@ -133,12 +129,11 @@ class Mem0Storage(Storage):
agents = self.crew.agents
agents = [self._sanitize_role(agent.role) for agent in agents]
agents = "_".join(agents)
return agents
return "_".join(agents)
def _get_config(self) -> Dict[str, Any]:
def _get_config(self) -> dict[str, Any]:
return self.config or getattr(self, "memory_config", {}).get("config", {}) or {}
def reset(self):
def reset(self) -> None:
if self.memory:
self.memory.reset()

View File

@@ -4,7 +4,7 @@ import logging
import os
import shutil
import uuid
from typing import Any, Dict, List, Optional
from typing import Any
from chromadb.api import ClientAPI
@@ -32,16 +32,15 @@ def suppress_logging(
class RAGStorage(BaseRAGStorage):
"""
Extends Storage to handle embeddings for memory entries, improving
"""Extends Storage to handle embeddings for memory entries, improving
search efficiency.
"""
app: ClientAPI | None = None
def __init__(
self, type, allow_reset=True, embedder_config=None, crew=None, path=None
):
self, type, allow_reset=True, embedder_config=None, crew=None, path=None,
) -> None:
super().__init__(type, allow_reset, embedder_config, crew)
agents = crew.agents if crew else []
agents = [self._sanitize_role(agent.role) for agent in agents]
@@ -55,11 +54,11 @@ class RAGStorage(BaseRAGStorage):
self.path = path
self._initialize_app()
def _set_embedder_config(self):
def _set_embedder_config(self) -> None:
configurator = EmbeddingConfigurator()
self.embedder_config = configurator.configure_embedder(self.embedder_config)
def _initialize_app(self):
def _initialize_app(self) -> None:
import chromadb
from chromadb.config import Settings
@@ -73,48 +72,44 @@ class RAGStorage(BaseRAGStorage):
try:
self.collection = self.app.get_collection(
name=self.type, embedding_function=self.embedder_config
name=self.type, embedding_function=self.embedder_config,
)
except Exception:
self.collection = self.app.create_collection(
name=self.type, embedding_function=self.embedder_config
name=self.type, embedding_function=self.embedder_config,
)
def _sanitize_role(self, role: str) -> str:
"""
Sanitizes agent roles to ensure valid directory names.
"""
"""Sanitizes agent roles to ensure valid directory names."""
return role.replace("\n", "").replace(" ", "_").replace("/", "_")
def _build_storage_file_name(self, type: str, file_name: str) -> str:
"""
Ensures file name does not exceed max allowed by OS
"""
"""Ensures file name does not exceed max allowed by OS."""
base_path = f"{db_storage_path()}/{type}"
if len(file_name) > MAX_FILE_NAME_LENGTH:
logging.warning(
f"Trimming file name from {len(file_name)} to {MAX_FILE_NAME_LENGTH} characters."
f"Trimming file name from {len(file_name)} to {MAX_FILE_NAME_LENGTH} characters.",
)
file_name = file_name[:MAX_FILE_NAME_LENGTH]
return f"{base_path}/{file_name}"
def save(self, value: Any, metadata: Dict[str, Any]) -> None:
def save(self, value: Any, metadata: dict[str, Any]) -> None:
if not hasattr(self, "app") or not hasattr(self, "collection"):
self._initialize_app()
try:
self._generate_embedding(value, metadata)
except Exception as e:
logging.error(f"Error during {self.type} save: {str(e)}")
logging.exception(f"Error during {self.type} save: {e!s}")
def search(
self,
query: str,
limit: int = 3,
filter: Optional[dict] = None,
filter: dict | None = None,
score_threshold: float = 0.35,
) -> List[Any]:
) -> list[Any]:
if not hasattr(self, "app"):
self._initialize_app()
@@ -135,10 +130,10 @@ class RAGStorage(BaseRAGStorage):
return results
except Exception as e:
logging.error(f"Error during {self.type} search: {str(e)}")
logging.exception(f"Error during {self.type} search: {e!s}")
return []
def _generate_embedding(self, text: str, metadata: Dict[str, Any]) -> None: # type: ignore
def _generate_embedding(self, text: str, metadata: dict[str, Any]) -> None: # type: ignore
if not hasattr(self, "app") or not hasattr(self, "collection"):
self._initialize_app()
@@ -160,8 +155,9 @@ class RAGStorage(BaseRAGStorage):
# Ignore this specific error
pass
else:
msg = f"An error occurred while resetting the {self.type} memory: {e}"
raise Exception(
f"An error occurred while resetting the {self.type} memory: {e}"
msg,
)
def _create_default_embedding_function(self):
@@ -170,5 +166,5 @@ class RAGStorage(BaseRAGStorage):
)
return OpenAIEmbeddingFunction(
api_key=os.getenv("OPENAI_API_KEY"), model_name="text-embedding-3-small"
api_key=os.getenv("OPENAI_API_KEY"), model_name="text-embedding-3-small",
)

View File

@@ -1,18 +1,17 @@
import warnings
from typing import Any, Dict, Optional
from typing import Any
from crewai.memory.memory import Memory
class UserMemory(Memory):
"""
UserMemory class for handling user memory storage and retrieval.
"""UserMemory class for handling user memory storage and retrieval.
Inherits from the Memory class and utilizes an instance of a class that
adheres to the Storage for data storage, specifically working with
MemoryItem instances.
"""
def __init__(self, crew=None):
def __init__(self, crew=None) -> None:
warnings.warn(
"UserMemory is deprecated and will be removed in a future version. "
"Please use ExternalMemory instead.",
@@ -22,8 +21,9 @@ class UserMemory(Memory):
try:
from crewai.memory.storage.mem0_storage import Mem0Storage
except ImportError:
msg = "Mem0 is not installed. Please install it with `pip install mem0ai`."
raise ImportError(
"Mem0 is not installed. Please install it with `pip install mem0ai`."
msg,
)
storage = Mem0Storage(type="user", crew=crew)
super().__init__(storage)
@@ -31,8 +31,8 @@ class UserMemory(Memory):
def save(
self,
value,
metadata: Optional[Dict[str, Any]] = None,
agent: Optional[str] = None,
metadata: dict[str, Any] | None = None,
agent: str | None = None,
) -> None:
# TODO: Change this function since we want to take care of the case where we save memories for the usr
data = f"Remember the details about the user: {value}"
@@ -44,15 +44,15 @@ class UserMemory(Memory):
limit: int = 3,
score_threshold: float = 0.35,
):
results = self.storage.search(
return self.storage.search(
query=query,
limit=limit,
score_threshold=score_threshold,
)
return results
def reset(self) -> None:
try:
self.storage.reset()
except Exception as e:
raise Exception(f"An error occurred while resetting the user memory: {e}")
msg = f"An error occurred while resetting the user memory: {e}"
raise Exception(msg)

View File

@@ -1,8 +1,8 @@
from typing import Any, Dict, Optional
from typing import Any
class UserMemoryItem:
def __init__(self, data: Any, user: str, metadata: Optional[Dict[str, Any]] = None):
def __init__(self, data: Any, user: str, metadata: dict[str, Any] | None = None) -> None:
self.data = data
self.user = user
self.metadata = metadata if metadata is not None else {}

View File

@@ -2,9 +2,7 @@ from enum import Enum
class Process(str, Enum):
"""
Class representing the different processes that can be used to tackle tasks
"""
"""Class representing the different processes that can be used to tackle tasks."""
sequential = "sequential"
hierarchical = "hierarchical"

View File

@@ -1,5 +1,5 @@
from collections.abc import Callable
from functools import wraps
from typing import Callable
from crewai import Crew
from crewai.project.utils import memoize
@@ -36,15 +36,13 @@ def task(func):
def agent(func):
"""Marks a method as a crew agent."""
func.is_agent = True
func = memoize(func)
return func
return memoize(func)
def llm(func):
"""Marks a method as an LLM provider."""
func.is_llm = True
func = memoize(func)
return func
return memoize(func)
def output_json(cls):
@@ -91,7 +89,7 @@ def crew(func) -> Callable[..., Crew]:
agents = self._original_agents.items()
# Instantiate tasks in order
for task_name, task_method in tasks:
for _task_name, task_method in tasks:
task_instance = task_method(self)
instantiated_tasks.append(task_instance)
agent_instance = getattr(task_instance, "agent", None)
@@ -100,7 +98,7 @@ def crew(func) -> Callable[..., Crew]:
agent_roles.add(agent_instance.role)
# Instantiate agents not included by tasks
for agent_name, agent_method in agents:
for _agent_name, agent_method in agents:
agent_instance = agent_method(self)
if agent_instance.role not in agent_roles:
instantiated_agents.append(agent_instance)
@@ -117,9 +115,9 @@ def crew(func) -> Callable[..., Crew]:
return wrapper
for _, callback in self._before_kickoff.items():
for callback in self._before_kickoff.values():
crew.before_kickoff_callbacks.append(callback_wrapper(callback, self))
for _, callback in self._after_kickoff.items():
for callback in self._after_kickoff.values():
crew.after_kickoff_callbacks.append(callback_wrapper(callback, self))
return crew

View File

@@ -1,7 +1,8 @@
import inspect
import logging
from collections.abc import Callable
from pathlib import Path
from typing import Any, Callable, Dict, TypeVar, cast
from typing import Any, TypeVar, cast
import yaml
from dotenv import load_dotenv
@@ -23,11 +24,11 @@ def CrewBase(cls: T) -> T:
base_directory = Path(inspect.getfile(cls)).parent
original_agents_config_path = getattr(
cls, "agents_config", "config/agents.yaml"
cls, "agents_config", "config/agents.yaml",
)
original_tasks_config_path = getattr(cls, "tasks_config", "config/tasks.yaml")
def __init__(self, *args, **kwargs):
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
self.load_configurations()
self.map_all_agent_variables()
@@ -49,22 +50,22 @@ def CrewBase(cls: T) -> T:
}
# Store specific function types
self._original_tasks = self._filter_functions(
self._original_functions, "is_task"
self._original_functions, "is_task",
)
self._original_agents = self._filter_functions(
self._original_functions, "is_agent"
self._original_functions, "is_agent",
)
self._before_kickoff = self._filter_functions(
self._original_functions, "is_before_kickoff"
self._original_functions, "is_before_kickoff",
)
self._after_kickoff = self._filter_functions(
self._original_functions, "is_after_kickoff"
self._original_functions, "is_after_kickoff",
)
self._kickoff = self._filter_functions(
self._original_functions, "is_kickoff"
self._original_functions, "is_kickoff",
)
def load_configurations(self):
def load_configurations(self) -> None:
"""Load agent and task configurations from YAML files."""
if isinstance(self.original_agents_config_path, str):
agents_config_path = (
@@ -75,12 +76,12 @@ def CrewBase(cls: T) -> T:
except FileNotFoundError:
logging.warning(
f"Agent config file not found at {agents_config_path}. "
"Proceeding with empty agent configurations."
"Proceeding with empty agent configurations.",
)
self.agents_config = {}
else:
logging.warning(
"No agent configuration path provided. Proceeding with empty agent configurations."
"No agent configuration path provided. Proceeding with empty agent configurations.",
)
self.agents_config = {}
@@ -93,22 +94,21 @@ def CrewBase(cls: T) -> T:
except FileNotFoundError:
logging.warning(
f"Task config file not found at {tasks_config_path}. "
"Proceeding with empty task configurations."
"Proceeding with empty task configurations.",
)
self.tasks_config = {}
else:
logging.warning(
"No task configuration path provided. Proceeding with empty task configurations."
"No task configuration path provided. Proceeding with empty task configurations.",
)
self.tasks_config = {}
@staticmethod
def load_yaml(config_path: Path):
try:
with open(config_path, "r", encoding="utf-8") as file:
with open(config_path, encoding="utf-8") as file:
return yaml.safe_load(file)
except FileNotFoundError:
print(f"File not found: {config_path}")
raise
def _get_all_functions(self):
@@ -119,8 +119,8 @@ def CrewBase(cls: T) -> T:
}
def _filter_functions(
self, functions: Dict[str, Callable], attribute: str
) -> Dict[str, Callable]:
self, functions: dict[str, Callable], attribute: str,
) -> dict[str, Callable]:
return {
name: func
for name, func in functions.items()
@@ -132,7 +132,7 @@ def CrewBase(cls: T) -> T:
llms = self._filter_functions(all_functions, "is_llm")
tool_functions = self._filter_functions(all_functions, "is_tool")
cache_handler_functions = self._filter_functions(
all_functions, "is_cache_handler"
all_functions, "is_cache_handler",
)
callbacks = self._filter_functions(all_functions, "is_callback")
@@ -149,11 +149,11 @@ def CrewBase(cls: T) -> T:
def _map_agent_variables(
self,
agent_name: str,
agent_info: Dict[str, Any],
llms: Dict[str, Callable],
tool_functions: Dict[str, Callable],
cache_handler_functions: Dict[str, Callable],
callbacks: Dict[str, Callable],
agent_info: dict[str, Any],
llms: dict[str, Callable],
tool_functions: dict[str, Callable],
cache_handler_functions: dict[str, Callable],
callbacks: dict[str, Callable],
) -> None:
if llm := agent_info.get("llm"):
try:
@@ -187,12 +187,12 @@ def CrewBase(cls: T) -> T:
agents = self._filter_functions(all_functions, "is_agent")
tasks = self._filter_functions(all_functions, "is_task")
output_json_functions = self._filter_functions(
all_functions, "is_output_json"
all_functions, "is_output_json",
)
tool_functions = self._filter_functions(all_functions, "is_tool")
callback_functions = self._filter_functions(all_functions, "is_callback")
output_pydantic_functions = self._filter_functions(
all_functions, "is_output_pydantic"
all_functions, "is_output_pydantic",
)
for task_name, task_info in self.tasks_config.items():
@@ -210,13 +210,13 @@ def CrewBase(cls: T) -> T:
def _map_task_variables(
self,
task_name: str,
task_info: Dict[str, Any],
agents: Dict[str, Callable],
tasks: Dict[str, Callable],
output_json_functions: Dict[str, Callable],
tool_functions: Dict[str, Callable],
callback_functions: Dict[str, Callable],
output_pydantic_functions: Dict[str, Callable],
task_info: dict[str, Any],
agents: dict[str, Callable],
tasks: dict[str, Callable],
output_json_functions: dict[str, Callable],
tool_functions: dict[str, Callable],
callback_functions: dict[str, Callable],
output_pydantic_functions: dict[str, Callable],
) -> None:
if context_list := task_info.get("context"):
self.tasks_config[task_name]["context"] = [
@@ -253,4 +253,4 @@ def CrewBase(cls: T) -> T:
WrappedClass.__name__ = CrewBase.__name__ + "(" + cls.__name__ + ")"
WrappedClass.__qualname__ = CrewBase.__qualname__ + "(" + cls.__name__ + ")"
return cast(T, WrappedClass)
return cast("T", WrappedClass)

View File

@@ -1,5 +1,4 @@
"""
Fingerprint Module
"""Fingerprint Module.
This module provides functionality for generating and validating unique identifiers
for CrewAI agents. These identifiers are used for tracking, auditing, and security.
@@ -7,14 +6,13 @@ for CrewAI agents. These identifiers are used for tracking, auditing, and securi
import uuid
from datetime import datetime
from typing import Any, Dict, Optional
from typing import Any
from pydantic import BaseModel, ConfigDict, Field, field_validator
class Fingerprint(BaseModel):
"""
A class for generating and managing unique identifiers for agents.
"""A class for generating and managing unique identifiers for agents.
Each agent has dual identifiers:
- Human-readable ID: For debugging and reference (derived from role if not specified)
@@ -24,48 +22,54 @@ class Fingerprint(BaseModel):
uuid_str (str): String representation of the UUID for this fingerprint, auto-generated
created_at (datetime): When this fingerprint was created, auto-generated
metadata (Dict[str, Any]): Additional metadata associated with this fingerprint
"""
uuid_str: str = Field(default_factory=lambda: str(uuid.uuid4()), description="String representation of the UUID")
created_at: datetime = Field(default_factory=datetime.now, description="When this fingerprint was created")
metadata: Dict[str, Any] = Field(default_factory=dict, description="Additional metadata for this fingerprint")
metadata: dict[str, Any] = Field(default_factory=dict, description="Additional metadata for this fingerprint")
model_config = ConfigDict(arbitrary_types_allowed=True)
@field_validator('metadata')
@field_validator("metadata")
@classmethod
def validate_metadata(cls, v):
"""Validate that metadata is a dictionary with string keys and valid values."""
if not isinstance(v, dict):
raise ValueError("Metadata must be a dictionary")
msg = "Metadata must be a dictionary"
raise ValueError(msg)
# Validate that all keys are strings
for key, value in v.items():
if not isinstance(key, str):
raise ValueError(f"Metadata keys must be strings, got {type(key)}")
msg = f"Metadata keys must be strings, got {type(key)}"
raise ValueError(msg)
# Validate nested dictionaries (prevent deeply nested structures)
if isinstance(value, dict):
# Check for nested dictionaries (limit depth to 1)
for nested_key, nested_value in value.items():
if not isinstance(nested_key, str):
raise ValueError(f"Nested metadata keys must be strings, got {type(nested_key)}")
msg = f"Nested metadata keys must be strings, got {type(nested_key)}"
raise ValueError(msg)
if isinstance(nested_value, dict):
raise ValueError("Metadata can only be nested one level deep")
msg = "Metadata can only be nested one level deep"
raise ValueError(msg)
# Check for maximum metadata size (prevent DoS)
if len(str(v)) > 10000: # Limit metadata size to 10KB
raise ValueError("Metadata size exceeds maximum allowed (10KB)")
msg = "Metadata size exceeds maximum allowed (10KB)"
raise ValueError(msg)
return v
def __init__(self, **data):
def __init__(self, **data) -> None:
"""Initialize a Fingerprint with auto-generated uuid_str and created_at."""
# Remove uuid_str and created_at from data to ensure they're auto-generated
if 'uuid_str' in data:
data.pop('uuid_str')
if 'created_at' in data:
data.pop('created_at')
if "uuid_str" in data:
data.pop("uuid_str")
if "created_at" in data:
data.pop("created_at")
# Call the parent constructor with the modified data
super().__init__(**data)
@@ -77,32 +81,33 @@ class Fingerprint(BaseModel):
@classmethod
def _generate_uuid(cls, seed: str) -> str:
"""
Generate a deterministic UUID based on a seed string.
"""Generate a deterministic UUID based on a seed string.
Args:
seed (str): The seed string to use for UUID generation
Returns:
str: A string representation of the UUID consistently generated from the seed
"""
if not isinstance(seed, str):
raise ValueError("Seed must be a string")
msg = "Seed must be a string"
raise ValueError(msg)
if not seed.strip():
raise ValueError("Seed cannot be empty or whitespace")
msg = "Seed cannot be empty or whitespace"
raise ValueError(msg)
# Create a deterministic UUID using v5 (SHA-1)
# Custom namespace for CrewAI to enhance security
# Using a unique namespace specific to CrewAI to reduce collision risks
CREW_AI_NAMESPACE = uuid.UUID('f47ac10b-58cc-4372-a567-0e02b2c3d479')
CREW_AI_NAMESPACE = uuid.UUID("f47ac10b-58cc-4372-a567-0e02b2c3d479")
return str(uuid.uuid5(CREW_AI_NAMESPACE, seed))
@classmethod
def generate(cls, seed: Optional[str] = None, metadata: Optional[Dict[str, Any]] = None) -> 'Fingerprint':
"""
Static factory method to create a new Fingerprint.
def generate(cls, seed: str | None = None, metadata: dict[str, Any] | None = None) -> "Fingerprint":
"""Static factory method to create a new Fingerprint.
Args:
seed (Optional[str]): A string to use as seed for the UUID generation.
@@ -111,11 +116,12 @@ class Fingerprint(BaseModel):
Returns:
Fingerprint: A new Fingerprint instance
"""
fingerprint = cls(metadata=metadata or {})
if seed:
# For seed-based generation, we need to manually set the uuid_str after creation
object.__setattr__(fingerprint, 'uuid_str', cls._generate_uuid(seed))
object.__setattr__(fingerprint, "uuid_str", cls._generate_uuid(seed))
return fingerprint
def __str__(self) -> str:
@@ -132,29 +138,29 @@ class Fingerprint(BaseModel):
"""Hash of the fingerprint (based on UUID)."""
return hash(self.uuid_str)
def to_dict(self) -> Dict[str, Any]:
"""
Convert the fingerprint to a dictionary representation.
def to_dict(self) -> dict[str, Any]:
"""Convert the fingerprint to a dictionary representation.
Returns:
Dict[str, Any]: Dictionary representation of the fingerprint
"""
return {
"uuid_str": self.uuid_str,
"created_at": self.created_at.isoformat(),
"metadata": self.metadata
"metadata": self.metadata,
}
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> 'Fingerprint':
"""
Create a Fingerprint from a dictionary representation.
def from_dict(cls, data: dict[str, Any]) -> "Fingerprint":
"""Create a Fingerprint from a dictionary representation.
Args:
data (Dict[str, Any]): Dictionary representation of a fingerprint
Returns:
Fingerprint: A new Fingerprint instance
"""
if not data:
return cls()
@@ -163,8 +169,8 @@ class Fingerprint(BaseModel):
# For consistency with existing stored fingerprints, we need to manually set these
if "uuid_str" in data:
object.__setattr__(fingerprint, 'uuid_str', data["uuid_str"])
object.__setattr__(fingerprint, "uuid_str", data["uuid_str"])
if "created_at" in data and isinstance(data["created_at"], str):
object.__setattr__(fingerprint, 'created_at', datetime.fromisoformat(data["created_at"]))
object.__setattr__(fingerprint, "created_at", datetime.fromisoformat(data["created_at"]))
return fingerprint

Some files were not shown because too many files have changed in this diff Show More