Mirror of https://github.com/crewAIInc/crewAI.git (synced 2025-12-30 19:28:29 +00:00)

Compare commits: devin/1748...devin/1747 (7 commits)
| Author | SHA1 | Date |
|---|---|---|
|  | 90980e8190 |  |
|  | f6cdfc1099 |  |
|  | 39c4ed33bb |  |
|  | 1860026d61 |  |
|  | 46621113af |  |
|  | ad1ea46bbb |  |
|  | 807dfe0558 |  |
.ruff.toml: 47 changed lines
@@ -2,3 +2,50 @@ exclude = [
     "templates",
     "__init__.py",
 ]
+
+[lint]
+select = ["ALL"]
+ignore = [
+    "D100", # Missing docstring in public module
+    "D101", # Missing docstring in public class
+    "D102", # Missing docstring in public method
+    "D103", # Missing docstring in public function
+    "D104", # Missing docstring in public package
+    "D105", # Missing docstring in magic method
+    "D106", # Missing docstring in public nested class
+    "D107", # Missing docstring in __init__
+    "D205", # 1 blank line required between summary line and description
+    "ANN001", # Missing type annotation for function argument
+    "ANN002", # Missing type annotation for *args
+    "ANN003", # Missing type annotation for **kwargs
+    "ANN201", # Missing return type annotation for public function
+    "ANN202", # Missing return type annotation for private function
+    "ANN204", # Missing return type annotation for special method
+    "ANN205", # Missing return type annotation for staticmethod
+    "ANN206", # Missing return type annotation for classmethod
+    "E501", # Line too long
+    "PT011", # pytest.raises() without match parameter
+    "PT012", # pytest.raises() block should contain a single simple statement
+    "SIM117", # Use a single `with` statement with multiple contexts
+    "PLR2004", # Magic value used in comparison
+    "B017", # Do not assert blind exception
+]
+
+[lint.per-file-ignores]
+"tests/*" = [
+    "S101", # Allow assert in tests
+    "SLF001", # Allow private member access in tests
+    "DTZ001", # Allow datetime without tzinfo in tests
+    "PTH107", # Allow os.remove instead of Path.unlink in tests
+    "PTH118", # Allow os.path.join() in tests
+    "PTH120", # Allow os.path.dirname() in tests
+    "PTH123", # Allow open() instead of Path.open() in tests
+    "PTH202", # Allow os.path.getsize in tests
+    "PT012", # Allow multiple statements in pytest.raises() block in tests
+    "SIM117", # Allow nested with statements in tests
+    "PLR2004", # Allow magic values in tests
+    "B017", # Allow asserting blind exceptions in tests
+]
+
+[lint.isort]
+known-first-party = ["crewai"]
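Note: with `select = ["ALL"]`, Ruff opts into every rule and exceptions are carved out explicitly. A minimal sketch of how the global and per-file ignore lists interact (hypothetical module, not part of this diff):

```python
def add(a, b):  # ANN001/ANN201 (missing annotations) are globally ignored
    return a + b


def test_add():
    # S101 (assert) is ignored only under "tests/*"; in src/ Ruff would flag it
    assert add(2, 3) == 5
```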
@@ -94,8 +94,10 @@ crewai = "crewai.cli.cli:crewai"

 [tool.mypy]
 ignore_missing_imports = true
-disable_error_code = 'import-untyped'
+disable_error_code = 'import-untyped,union-attr'
 exclude = ["cli/templates"]
+implicit_optional = true
+strict_optional = false

 [tool.bandit]
 exclude_dirs = ["src/crewai/cli/templates"]
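Note: `union-attr` is the mypy error produced when an attribute is accessed on a value whose type still includes `None`; disabling it (together with `implicit_optional = true` and `strict_optional = false`) eases the migration to the `X | None` annotations introduced below. A sketch of the kind of access involved (hypothetical class, not from this PR):

```python
class Config:
    name: str | None = None


cfg = Config()
# With union-attr enabled, mypy rejects unguarded access:
#   error: Item "None" of "str | None" has no attribute "upper"  [union-attr]
# label = cfg.name.upper()
# The guarded version type-checks under either setting:
label = cfg.name.upper() if cfg.name else "unnamed"
print(label)  # unnamed
```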
@@ -1,6 +1,7 @@
 import shutil
 import subprocess
-from typing import Any, Dict, List, Literal, Optional, Sequence, Type, Union
+from collections.abc import Sequence
+from typing import Any, Literal

 from pydantic import Field, InstanceOf, PrivateAttr, model_validator

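Note: this import change drives most of the file: `typing.Optional/Union/List/Dict/Type` are replaced by PEP 604 unions (`X | None`) and PEP 585 builtin generics, and `Sequence` moves to `collections.abc`. A before/after sketch with a hypothetical function:

```python
from collections.abc import Sequence


# Old spelling (needs `from typing import List, Optional, Union`):
#     def first(items: List[int], default: Optional[int] = None) -> Union[int, None]: ...
# New spelling, as used throughout this diff (Python 3.10+):
def first(items: Sequence[int], default: int | None = None) -> int | None:
    """Return the first element, or `default` when the sequence is empty."""
    return items[0] if items else default


print(first([1, 2, 3]), first([], default=-1))  # 1 -1
```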
@@ -67,40 +68,41 @@ class Agent(BaseAgent):
         step_callback: Callback to be executed after each step of the agent execution.
         knowledge_sources: Knowledge sources for the agent.
         embedder: Embedder configuration for the agent.
+
     """

     _times_executed: int = PrivateAttr(default=0)
-    max_execution_time: Optional[int] = Field(
+    max_execution_time: int | None = Field(
         default=None,
         description="Maximum execution time for an agent to execute a task",
     )
     agent_ops_agent_name: str = None  # type: ignore # Incompatible types in assignment (expression has type "None", variable has type "str")
     agent_ops_agent_id: str = None  # type: ignore # Incompatible types in assignment (expression has type "None", variable has type "str")
-    step_callback: Optional[Any] = Field(
+    step_callback: Any | None = Field(
         default=None,
         description="Callback to be executed after each step of the agent execution.",
     )
-    use_system_prompt: Optional[bool] = Field(
+    use_system_prompt: bool | None = Field(
         default=True,
         description="Use system prompt for the agent.",
     )
-    llm: Union[str, InstanceOf[BaseLLM], Any] = Field(
-        description="Language model that will run the agent.", default=None
+    llm: str | InstanceOf[BaseLLM] | Any = Field(
+        description="Language model that will run the agent.", default=None,
     )
-    function_calling_llm: Optional[Union[str, InstanceOf[BaseLLM], Any]] = Field(
-        description="Language model that will run the agent.", default=None
+    function_calling_llm: str | InstanceOf[BaseLLM] | Any | None = Field(
+        description="Language model that will run the agent.", default=None,
     )
-    system_template: Optional[str] = Field(
-        default=None, description="System format for the agent."
+    system_template: str | None = Field(
+        default=None, description="System format for the agent.",
     )
-    prompt_template: Optional[str] = Field(
-        default=None, description="Prompt format for the agent."
+    prompt_template: str | None = Field(
+        default=None, description="Prompt format for the agent.",
     )
-    response_template: Optional[str] = Field(
-        default=None, description="Response format for the agent."
+    response_template: str | None = Field(
+        default=None, description="Response format for the agent.",
     )
-    allow_code_execution: Optional[bool] = Field(
-        default=False, description="Enable code execution for the agent."
+    allow_code_execution: bool | None = Field(
+        default=False, description="Enable code execution for the agent.",
    )
     respect_context_window: bool = Field(
         default=True,
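Note: the other mechanical change in these field definitions is the trailing comma that Ruff's COM812 adds to every multi-line argument list, so that appending an argument later touches a single line. A hypothetical pydantic field showing both rewrites together:

```python
from pydantic import BaseModel, Field


class ExampleAgent(BaseModel):
    # Optional[int] becomes int | None; COM812 adds the trailing comma
    max_execution_time: int | None = Field(
        default=None,
        description="Maximum execution time for an agent to execute a task",
    )


print(ExampleAgent().max_execution_time)  # None
```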
@@ -118,19 +120,19 @@ class Agent(BaseAgent):
         default="safe",
         description="Mode for code execution: 'safe' (using Docker) or 'unsafe' (direct execution).",
     )
-    embedder: Optional[Dict[str, Any]] = Field(
+    embedder: dict[str, Any] | None = Field(
         default=None,
         description="Embedder configuration for the agent.",
     )
-    agent_knowledge_context: Optional[str] = Field(
+    agent_knowledge_context: str | None = Field(
         default=None,
         description="Knowledge context for the agent.",
     )
-    crew_knowledge_context: Optional[str] = Field(
+    crew_knowledge_context: str | None = Field(
         default=None,
         description="Knowledge context for the crew.",
     )
-    knowledge_search_query: Optional[str] = Field(
+    knowledge_search_query: str | None = Field(
         default=None,
         description="Knowledge search query for the agent dynamically generated by the agent.",
     )
@@ -141,7 +143,7 @@ class Agent(BaseAgent):

         self.llm = create_llm(self.llm)
         if self.function_calling_llm and not isinstance(
-            self.function_calling_llm, BaseLLM
+            self.function_calling_llm, BaseLLM,
         ):
             self.function_calling_llm = create_llm(self.function_calling_llm)

@@ -153,12 +155,12 @@ class Agent(BaseAgent):

         return self

-    def _setup_agent_executor(self):
+    def _setup_agent_executor(self) -> None:
         if not self.cache_handler:
             self.cache_handler = CacheHandler()
         self.set_cache_handler(self.cache_handler)

-    def set_knowledge(self, crew_embedder: Optional[Dict[str, Any]] = None):
+    def set_knowledge(self, crew_embedder: dict[str, Any] | None = None) -> None:
         try:
             if self.embedder is None and crew_embedder:
                 self.embedder = crew_embedder
@@ -174,7 +176,8 @@ class Agent(BaseAgent):
                 storage=self.knowledge_storage or None,
             )
         except (TypeError, ValueError) as e:
-            raise ValueError(f"Invalid Knowledge Configuration: {str(e)}")
+            msg = f"Invalid Knowledge Configuration: {e!s}"
+            raise ValueError(msg)

     def _is_any_available_memory(self) -> bool:
         """Check if any memory is available."""
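Note: the `msg = ...; raise ValueError(msg)` shape that recurs from here on is Ruff's EM101/EM102 fix (no string literal or f-string directly inside `raise`), and `{e!s}` is RUF010's replacement for `{str(e)}` in f-strings. A standalone sketch of the pattern; the explicit `from e` chaining is an addition here, not something this diff introduces:

```python
def parse_port(raw: str) -> int:
    try:
        return int(raw)
    except ValueError as e:
        msg = f"Invalid port {raw!r}: {e!s}"  # EM102 + RUF010
        raise ValueError(msg) from e  # `from e` keeps the original cause chained


print(parse_port("8080"))  # 8080
```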
@@ -196,8 +199,8 @@ class Agent(BaseAgent):
     def execute_task(
         self,
         task: Task,
-        context: Optional[str] = None,
-        tools: Optional[List[BaseTool]] = None,
+        context: str | None = None,
+        tools: list[BaseTool] | None = None,
     ) -> str:
         """Execute a task with the agent.

@@ -213,6 +216,7 @@ class Agent(BaseAgent):
             TimeoutError: If execution exceeds the maximum execution time.
             ValueError: If the max execution time is not a positive integer.
             RuntimeError: If the agent execution fails for other reasons.
+
         """
         if self.tools_handler:
             self.tools_handler.last_used_tool = {}  # type: ignore # Incompatible types in assignment (expression has type "dict[Never, Never]", variable has type "ToolCalling")
@@ -228,18 +232,18 @@ class Agent(BaseAgent):
             # schema = json.dumps(task.output_json, indent=2)
             schema = generate_model_description(task.output_json)
             task_prompt += "\n" + self.i18n.slice(
-                "formatted_task_instructions"
+                "formatted_task_instructions",
             ).format(output_format=schema)

         elif task.output_pydantic:
             schema = generate_model_description(task.output_pydantic)
             task_prompt += "\n" + self.i18n.slice(
-                "formatted_task_instructions"
+                "formatted_task_instructions",
             ).format(output_format=schema)

         if context:
             task_prompt = self.i18n.slice("task_with_context").format(
-                task=task_prompt, context=context
+                task=task_prompt, context=context,
             )

         if self._is_any_available_memory():
@@ -267,25 +271,25 @@ class Agent(BaseAgent):
             )
             try:
                 self.knowledge_search_query = self._get_knowledge_search_query(
-                    task_prompt
+                    task_prompt,
                 )
                 if self.knowledge_search_query:
                     agent_knowledge_snippets = self.knowledge.query(
-                        [self.knowledge_search_query], **knowledge_config
+                        [self.knowledge_search_query], **knowledge_config,
                     )
                     if agent_knowledge_snippets:
                         self.agent_knowledge_context = extract_knowledge_context(
-                            agent_knowledge_snippets
+                            agent_knowledge_snippets,
                        )
                         if self.agent_knowledge_context:
                             task_prompt += self.agent_knowledge_context
                     if self.crew:
                         knowledge_snippets = self.crew.query_knowledge(
-                            [self.knowledge_search_query], **knowledge_config
+                            [self.knowledge_search_query], **knowledge_config,
                        )
                         if knowledge_snippets:
                             self.crew_knowledge_context = extract_knowledge_context(
-                                knowledge_snippets
+                                knowledge_snippets,
                            )
                             if self.crew_knowledge_context:
                                 task_prompt += self.crew_knowledge_context
@@ -342,11 +346,12 @@ class Agent(BaseAgent):
                 not isinstance(self.max_execution_time, int)
                 or self.max_execution_time <= 0
             ):
+                msg = "Max Execution time must be a positive integer greater than zero"
                 raise ValueError(
-                    "Max Execution time must be a positive integer greater than zero"
+                    msg,
                 )
             result = self._execute_with_timeout(
-                task_prompt, task, self.max_execution_time
+                task_prompt, task, self.max_execution_time,
             )
         else:
             result = self._execute_without_timeout(task_prompt, task)
@@ -361,7 +366,7 @@ class Agent(BaseAgent):
                     error=str(e),
                 ),
             )
-            raise e
+            raise
         except Exception as e:
             if e.__class__.__module__.startswith("litellm"):
                 # Do not retry on litellm errors
@@ -373,7 +378,7 @@ class Agent(BaseAgent):
                         error=str(e),
                     ),
                 )
-                raise e
+                raise
             self._times_executed += 1
             if self._times_executed > self.max_retry_limit:
                 crewai_event_bus.emit(
@@ -384,7 +389,7 @@ class Agent(BaseAgent):
                         error=str(e),
                     ),
                 )
-                raise e
+                raise
             result = self.execute_task(task, context, tools)

         if self.max_rpm and self._rpm_controller:
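Note: replacing `raise e` with a bare `raise` inside an `except` block is Ruff's TRY201; the bare form re-raises the active exception with its traceback untouched. A self-contained sketch:

```python
def fetch() -> None:
    raise ConnectionError("boom")


try:
    try:
        fetch()
    except ConnectionError:
        # TRY201: a bare `raise` re-raises the exception being handled,
        # preserving the original traceback; naming it again is redundant
        raise
except ConnectionError as e:
    print(f"caught again: {e!s}")
```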
@@ -416,24 +421,27 @@ class Agent(BaseAgent):
         Raises:
             TimeoutError: If execution exceeds the timeout.
             RuntimeError: If execution fails for other reasons.
+
         """
         import concurrent.futures

         with concurrent.futures.ThreadPoolExecutor() as executor:
             future = executor.submit(
-                self._execute_without_timeout, task_prompt=task_prompt, task=task
+                self._execute_without_timeout, task_prompt=task_prompt, task=task,
             )

             try:
                 return future.result(timeout=timeout)
             except concurrent.futures.TimeoutError:
                 future.cancel()
+                msg = f"Task '{task.description}' execution timed out after {timeout} seconds. Consider increasing max_execution_time or optimizing the task."
                 raise TimeoutError(
-                    f"Task '{task.description}' execution timed out after {timeout} seconds. Consider increasing max_execution_time or optimizing the task."
+                    msg,
                 )
             except Exception as e:
                 future.cancel()
-                raise RuntimeError(f"Task execution failed: {str(e)}")
+                msg = f"Task execution failed: {e!s}"
+                raise RuntimeError(msg)

     def _execute_without_timeout(self, task_prompt: str, task: Task) -> str:
         """Execute a task without a timeout.
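Note: the timeout here works by running the task in a worker thread and bounding `future.result`. One caveat worth knowing: `future.cancel()` is best-effort and cannot stop a thread that has already started, and the executor's `with` block still waits for it on exit. A minimal reproduction of the pattern:

```python
import concurrent.futures
import time


def slow_job() -> str:
    time.sleep(2)
    return "done"


with concurrent.futures.ThreadPoolExecutor() as executor:
    future = executor.submit(slow_job)
    try:
        print(future.result(timeout=0.5))
    except concurrent.futures.TimeoutError:
        future.cancel()  # too late to stop a running thread, but marks intent
        print("timed out")  # the with-block still waits for slow_job to finish
```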
@@ -444,6 +452,7 @@ class Agent(BaseAgent):

         Returns:
             The output of the agent.
+
         """
         return self.agent_executor.invoke(
             {
@@ -451,18 +460,19 @@ class Agent(BaseAgent):
                 "tool_names": self.agent_executor.tools_names,
                 "tools": self.agent_executor.tools_description,
                 "ask_for_human_input": task.human_input,
-            }
+            },
         )["output"]

     def create_agent_executor(
-        self, tools: Optional[List[BaseTool]] = None, task=None
+        self, tools: list[BaseTool] | None = None, task=None,
     ) -> None:
         """Create an agent executor for the agent.

         Returns:
             An instance of the CrewAgentExecutor class.
+
         """
-        raw_tools: List[BaseTool] = tools or self.tools or []
+        raw_tools: list[BaseTool] = tools or self.tools or []
         parsed_tools = parse_tools(raw_tools)

         prompt = Prompts(
@@ -479,7 +489,7 @@ class Agent(BaseAgent):

         if self.response_template:
             stop_words.append(
-                self.response_template.split("{{ .Response }}")[1].strip()
+                self.response_template.split("{{ .Response }}")[1].strip(),
             )

         self.agent_executor = CrewAgentExecutor(
@@ -504,10 +514,9 @@ class Agent(BaseAgent):
             callbacks=[TokenCalcHandler(self._token_process)],
         )

-    def get_delegation_tools(self, agents: List[BaseAgent]):
+    def get_delegation_tools(self, agents: list[BaseAgent]):
         agent_tools = AgentTools(agents=agents)
-        tools = agent_tools.tools()
-        return tools
+        return agent_tools.tools()

     def get_multimodal_tools(self) -> Sequence[BaseTool]:
         from crewai.tools.agent_tools.add_image_tool import AddImageTool
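Note: `return agent_tools.tools()` replaces the assign-then-return pair, Ruff's RET504 (unnecessary assignment before return). The same shape in isolation, with a hypothetical registry:

```python
class Registry:
    def tools(self) -> list[str]:
        return ["search", "calculator"]


def get_tools(registry: Registry) -> list[str]:
    # RET504: `tools = registry.tools(); return tools` adds nothing
    return registry.tools()


print(get_tools(Registry()))  # ['search', 'calculator']
```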
@@ -523,7 +532,7 @@ class Agent(BaseAgent):
             return [CodeInterpreterTool(unsafe_mode=unsafe_mode)]
         except ModuleNotFoundError:
             self._logger.log(
-                "info", "Coding tools not available. Install crewai_tools. "
+                "info", "Coding tools not available. Install crewai_tools. ",
             )

     def get_output_converter(self, llm, text, model, instructions):
@@ -555,7 +564,7 @@ class Agent(BaseAgent):
         )
         return task_prompt

-    def _render_text_description(self, tools: List[Any]) -> str:
+    def _render_text_description(self, tools: list[Any]) -> str:
         """Render the tool name and description in plain text.

         Output will be in the format of:
@@ -565,48 +574,48 @@ class Agent(BaseAgent):
             search: This tool is used for search
             calculator: This tool is used for math
         """
-        description = "\n".join(
+        return "\n".join(
             [
                 f"Tool name: {tool.name}\nTool description:\n{tool.description}"
                 for tool in tools
-            ]
+            ],
         )

-        return description
-
     def _validate_docker_installation(self) -> None:
         """Check if Docker is installed and running."""
         if not shutil.which("docker"):
+            msg = f"Docker is not installed. Please install Docker to use code execution with agent: {self.role}"
             raise RuntimeError(
-                f"Docker is not installed. Please install Docker to use code execution with agent: {self.role}"
+                msg,
             )

         try:
             subprocess.run(
                 ["docker", "info"],
                 check=True,
-                stdout=subprocess.PIPE,
-                stderr=subprocess.PIPE,
+                capture_output=True,
             )
         except subprocess.CalledProcessError:
+            msg = f"Docker is not running. Please start Docker to use code execution with agent: {self.role}"
             raise RuntimeError(
-                f"Docker is not running. Please start Docker to use code execution with agent: {self.role}"
+                msg,
            )

-    def __repr__(self):
+    def __repr__(self) -> str:
         return f"Agent(role={self.role}, goal={self.goal}, backstory={self.backstory})"

     @property
     def fingerprint(self) -> Fingerprint:
-        """
-        Get the agent's fingerprint.
+        """Get the agent's fingerprint.

         Returns:
             Fingerprint: The agent's fingerprint
+
         """
         return self.security_config.fingerprint

-    def set_fingerprint(self, fingerprint: Fingerprint):
+    def set_fingerprint(self, fingerprint: Fingerprint) -> None:
         self.security_config.fingerprint = fingerprint

     def _get_knowledge_search_query(self, task_prompt: str) -> str | None:
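Note: `capture_output=True` is the documented shorthand for passing `stdout=subprocess.PIPE, stderr=subprocess.PIPE` (Ruff's UP022). Equivalent usage, with `text=True` added here only to make the output printable:

```python
import subprocess
import sys

result = subprocess.run(
    [sys.executable, "--version"],
    check=True,
    capture_output=True,  # UP022: same as stdout=PIPE, stderr=PIPE
    text=True,
)
print(result.stdout.strip() or result.stderr.strip())
```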
@@ -619,7 +628,7 @@ class Agent(BaseAgent):
             ),
         )
         query = self.i18n.slice("knowledge_search_query").format(
-            task_prompt=task_prompt
+            task_prompt=task_prompt,
         )
         rewriter_prompt = self.i18n.slice("knowledge_search_query_system_prompt")
         if not isinstance(self.llm, BaseLLM):
@@ -644,7 +653,7 @@ class Agent(BaseAgent):
                     "content": rewriter_prompt,
                 },
                 {"role": "user", "content": query},
-            ]
+            ],
         )
         crewai_event_bus.emit(
             self,
@@ -666,11 +675,10 @@ class Agent(BaseAgent):

     def kickoff(
         self,
-        messages: Union[str, List[Dict[str, str]]],
-        response_format: Optional[Type[Any]] = None,
+        messages: str | list[dict[str, str]],
+        response_format: type[Any] | None = None,
     ) -> LiteAgentOutput:
-        """
-        Execute the agent with the given messages using a LiteAgent instance.
+        """Execute the agent with the given messages using a LiteAgent instance.

         This method is useful when you want to use the Agent configuration but
         with the simpler and more direct execution flow of LiteAgent.
@@ -683,6 +691,7 @@ class Agent(BaseAgent):

         Returns:
             LiteAgentOutput: The result of the agent execution.
+
         """
         lite_agent = LiteAgent(
             role=self.role,
@@ -703,11 +712,10 @@ class Agent(BaseAgent):

     async def kickoff_async(
         self,
-        messages: Union[str, List[Dict[str, str]]],
-        response_format: Optional[Type[Any]] = None,
+        messages: str | list[dict[str, str]],
+        response_format: type[Any] | None = None,
     ) -> LiteAgentOutput:
-        """
-        Execute the agent asynchronously with the given messages using a LiteAgent instance.
+        """Execute the agent asynchronously with the given messages using a LiteAgent instance.

         This is the async version of the kickoff method.

@@ -719,6 +727,7 @@ class Agent(BaseAgent):

         Returns:
             LiteAgentOutput: The result of the agent execution.
+
         """
         lite_agent = LiteAgent(
             role=self.role,
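Note: many hunks in this file touch only docstrings: the summary moves onto the same line as the opening quotes, and a blank line is added before the closing quotes. This matches Ruff's pydocstyle rules (D212 and related; the exact codes are an inference, since the config above selects ALL). Before and after, on a hypothetical function:

```python
def kickoff_old() -> None:
    """
    Execute the agent.
    """


def kickoff_new() -> None:
    """Execute the agent.

    Returns:
        None: this pair of functions only illustrates the docstring layout.

    """
```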
@@ -1,5 +1,5 @@
 from abc import ABC, abstractmethod
-from typing import Any, Dict, List, Optional
+from typing import Any

 from pydantic import PrivateAttr

@@ -16,27 +16,27 @@ class BaseAgentAdapter(BaseAgent, ABC):
     """

     adapted_structured_output: bool = False
-    _agent_config: Optional[Dict[str, Any]] = PrivateAttr(default=None)
+    _agent_config: dict[str, Any] | None = PrivateAttr(default=None)

     model_config = {"arbitrary_types_allowed": True}

-    def __init__(self, agent_config: Optional[Dict[str, Any]] = None, **kwargs: Any):
+    def __init__(self, agent_config: dict[str, Any] | None = None, **kwargs: Any) -> None:
         super().__init__(adapted_agent=True, **kwargs)
         self._agent_config = agent_config

     @abstractmethod
-    def configure_tools(self, tools: Optional[List[BaseTool]] = None) -> None:
+    def configure_tools(self, tools: list[BaseTool] | None = None) -> None:
         """Configure and adapt tools for the specific agent implementation.

         Args:
             tools: Optional list of BaseTool instances to be configured
+
         """
-        pass

     def configure_structured_output(self, structured_output: Any) -> None:
         """Configure the structured output for the specific agent implementation.

         Args:
             structured_output: The structured output to be configured
+
         """
-        pass
@@ -8,7 +8,7 @@ class BaseConverterAdapter(ABC):
     converter adapters must implement for converting structured output.
     """

-    def __init__(self, agent_adapter):
+    def __init__(self, agent_adapter) -> None:
         self.agent_adapter = agent_adapter

     @abstractmethod
@@ -16,14 +16,11 @@ class BaseConverterAdapter(ABC):
         """Configure agents to return structured output.

         Must support json and pydantic output.
         """
-        pass

     @abstractmethod
     def enhance_system_prompt(self, base_prompt: str) -> str:
         """Enhance the system prompt with structured output instructions."""
-        pass

     @abstractmethod
     def post_process_result(self, result: str) -> str:
         """Post-process the result to ensure it matches the expected format: string."""
-        pass
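Note: the deleted `pass` statements fall under Ruff's PIE790 (unnecessary placeholder): a docstring is already a complete function body, even in an abstract method:

```python
from abc import ABC, abstractmethod


class Converter(ABC):
    @abstractmethod
    def convert(self, raw: str) -> str:
        """Convert raw output to the target format."""
        # PIE790: no `pass` needed; the docstring is a valid body
```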
@@ -1,5 +1,5 @@
 from abc import ABC, abstractmethod
-from typing import Any, List, Optional
+from typing import Any

 from crewai.tools.base_tool import BaseTool

@@ -12,23 +12,23 @@ class BaseToolAdapter(ABC):
     different frameworks and platforms.
     """

-    original_tools: List[BaseTool]
-    converted_tools: List[Any]
+    original_tools: list[BaseTool]
+    converted_tools: list[Any]

-    def __init__(self, tools: Optional[List[BaseTool]] = None):
+    def __init__(self, tools: list[BaseTool] | None = None) -> None:
         self.original_tools = tools or []
         self.converted_tools = []

     @abstractmethod
-    def configure_tools(self, tools: List[BaseTool]) -> None:
+    def configure_tools(self, tools: list[BaseTool]) -> None:
         """Configure and convert tools for the specific implementation.

         Args:
             tools: List of BaseTool instances to be configured and converted
+
         """
-        pass

-    def tools(self) -> List[Any]:
-        """
-
+    def tools(self) -> list[Any]:
+        """Return all converted tools."""
         return self.converted_tools

@@ -1,4 +1,4 @@
-from typing import Any, AsyncIterable, Dict, List, Optional
+from typing import Any

 from pydantic import Field, PrivateAttr

@@ -52,16 +52,17 @@ class LangGraphAgentAdapter(BaseAgentAdapter):
         role: str,
         goal: str,
         backstory: str,
-        tools: Optional[List[BaseTool]] = None,
+        tools: list[BaseTool] | None = None,
         llm: Any = None,
         max_iterations: int = 10,
-        agent_config: Optional[Dict[str, Any]] = None,
+        agent_config: dict[str, Any] | None = None,
         **kwargs,
-    ):
+    ) -> None:
         """Initialize the LangGraph agent adapter."""
         if not LANGGRAPH_AVAILABLE:
+            msg = "LangGraph Agent Dependencies are not installed. Please install it using `uv add langchain-core langgraph`"
             raise ImportError(
-                "LangGraph Agent Dependencies are not installed. Please install it using `uv add langchain-core langgraph`"
+                msg,
             )
         super().__init__(
             role=role,
@@ -82,7 +83,7 @@ class LangGraphAgentAdapter(BaseAgentAdapter):
         try:
             self._memory = MemorySaver()

-            converted_tools: List[Any] = self._tool_adapter.tools()
+            converted_tools: list[Any] = self._tool_adapter.tools()
             if self._agent_config:
                 self._graph = create_react_agent(
                     model=self.llm,
@@ -101,18 +102,18 @@ class LangGraphAgentAdapter(BaseAgentAdapter):

         except ImportError as e:
             self._logger.log(
-                "error", f"Failed to import LangGraph dependencies: {str(e)}"
+                "error", f"Failed to import LangGraph dependencies: {e!s}",
             )
             raise
         except Exception as e:
-            self._logger.log("error", f"Error setting up LangGraph agent: {str(e)}")
+            self._logger.log("error", f"Error setting up LangGraph agent: {e!s}")
             raise

     def _build_system_prompt(self) -> str:
         """Build a system prompt for the LangGraph agent."""
         base_prompt = f"""
         You are {self.role}.


         Your goal is: {self.goal}

         Your backstory: {self.backstory}
@@ -124,8 +125,8 @@ class LangGraphAgentAdapter(BaseAgentAdapter):
     def execute_task(
         self,
         task: Any,
-        context: Optional[str] = None,
-        tools: Optional[List[BaseTool]] = None,
+        context: str | None = None,
+        tools: list[BaseTool] | None = None,
     ) -> str:
         """Execute a task using the LangGraph workflow."""
         self.create_agent_executor(tools)
@@ -137,7 +138,7 @@ class LangGraphAgentAdapter(BaseAgentAdapter):

         if context:
             task_prompt = self.i18n.slice("task_with_context").format(
-                task=task_prompt, context=context
+                task=task_prompt, context=context,
             )

         crewai_event_bus.emit(
@@ -159,7 +160,7 @@ class LangGraphAgentAdapter(BaseAgentAdapter):
                     "messages": [
                         ("system", self._build_system_prompt()),
                         ("user", task_prompt),
-                    ]
+                    ],
                 },
                 config,
             )
@@ -180,14 +181,14 @@ class LangGraphAgentAdapter(BaseAgentAdapter):
             crewai_event_bus.emit(
                 self,
                 event=AgentExecutionCompletedEvent(
-                    agent=self, task=task, output=final_answer
+                    agent=self, task=task, output=final_answer,
                 ),
             )

             return final_answer

         except Exception as e:
-            self._logger.log("error", f"Error executing LangGraph task: {str(e)}")
+            self._logger.log("error", f"Error executing LangGraph task: {e!s}")
             crewai_event_bus.emit(
                 self,
                 event=AgentExecutionErrorEvent(
@@ -198,11 +199,11 @@ class LangGraphAgentAdapter(BaseAgentAdapter):
             )
             raise

-    def create_agent_executor(self, tools: Optional[List[BaseTool]] = None) -> None:
+    def create_agent_executor(self, tools: list[BaseTool] | None = None) -> None:
         """Configure the LangGraph agent for execution."""
         self.configure_tools(tools)

-    def configure_tools(self, tools: Optional[List[BaseTool]] = None) -> None:
+    def configure_tools(self, tools: list[BaseTool] | None = None) -> None:
         """Configure tools for the LangGraph agent."""
         if tools:
             all_tools = list(self.tools or []) + list(tools or [])
@@ -210,13 +211,13 @@ class LangGraphAgentAdapter(BaseAgentAdapter):
             available_tools = self._tool_adapter.tools()
             self._graph.tools = available_tools

-    def get_delegation_tools(self, agents: List[BaseAgent]) -> List[BaseTool]:
+    def get_delegation_tools(self, agents: list[BaseAgent]) -> list[BaseTool]:
         """Implement delegation tools support for LangGraph."""
         agent_tools = AgentTools(agents=agents)
         return agent_tools.tools()

     def get_output_converter(
-        self, llm: Any, text: str, model: Any, instructions: str
+        self, llm: Any, text: str, model: Any, instructions: str,
     ) -> Any:
         """Convert output format if needed."""
         return Converter(llm=llm, text=text, model=model, instructions=instructions)
@@ -1,29 +1,25 @@
 import inspect
-from typing import Any, List, Optional
+from typing import Any

 from crewai.agents.agent_adapters.base_tool_adapter import BaseToolAdapter
 from crewai.tools.base_tool import BaseTool


 class LangGraphToolAdapter(BaseToolAdapter):
-    """Adapts CrewAI tools to LangGraph agent tool compatible format"""
+    """Adapts CrewAI tools to LangGraph agent tool compatible format."""

-    def __init__(self, tools: Optional[List[BaseTool]] = None):
+    def __init__(self, tools: list[BaseTool] | None = None) -> None:
         self.original_tools = tools or []
         self.converted_tools = []

-    def configure_tools(self, tools: List[BaseTool]) -> None:
-        """
-        Configure and convert CrewAI tools to LangGraph-compatible format.
+    def configure_tools(self, tools: list[BaseTool]) -> None:
+        """Configure and convert CrewAI tools to LangGraph-compatible format.
         LangGraph expects tools in langchain_core.tools format.
         """
         from langchain_core.tools import BaseTool, StructuredTool

         converted_tools = []
-        if self.original_tools:
-            all_tools = tools + self.original_tools
-        else:
-            all_tools = tools
+        all_tools = tools + self.original_tools if self.original_tools else tools
         for tool in all_tools:
             if isinstance(tool, BaseTool):
                 converted_tools.append(tool)
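Note: collapsing the `if/else` assignment into a conditional expression is Ruff's SIM108. The same rewrite in isolation:

```python
extra = ["search"]
base = ["calculator", "scraper"]

# SIM108: one conditional expression instead of a four-line if/else
all_tools = base + extra if extra else base
print(all_tools)  # ['calculator', 'scraper', 'search']
```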
@@ -57,5 +53,5 @@ class LangGraphToolAdapter(BaseToolAdapter):

         self.converted_tools = converted_tools

-    def tools(self) -> List[Any]:
+    def tools(self) -> list[Any]:
         return self.converted_tools or []
@@ -5,10 +5,10 @@ from crewai.utilities.converter import generate_model_description


 class LangGraphConverterAdapter(BaseConverterAdapter):
-    """Adapter for handling structured output conversion in LangGraph agents"""
+    """Adapter for handling structured output conversion in LangGraph agents."""

-    def __init__(self, agent_adapter):
-        """Initialize the converter adapter with a reference to the agent adapter"""
+    def __init__(self, agent_adapter) -> None:
+        """Initialize the converter adapter with a reference to the agent adapter."""
         self.agent_adapter = agent_adapter
         self._output_format = None
         self._schema = None
@@ -32,7 +32,7 @@ class LangGraphConverterAdapter(BaseConverterAdapter):
         self._system_prompt_appendix = self._generate_system_prompt_appendix()

     def _generate_system_prompt_appendix(self) -> str:
-        """Generate an appendix for the system prompt to enforce structured output"""
+        """Generate an appendix for the system prompt to enforce structured output."""
         if not self._output_format or not self._schema:
             return ""

@@ -41,19 +41,19 @@ Important: Your final answer MUST be provided in the following structured format

 {self._schema}

-DO NOT include any markdown code blocks, backticks, or other formatting around your response.
+DO NOT include any markdown code blocks, backticks, or other formatting around your response.
 The output should be raw JSON that exactly matches the specified schema.
 """

     def enhance_system_prompt(self, original_prompt: str) -> str:
-        """Add structured output instructions to the system prompt if needed"""
+        """Add structured output instructions to the system prompt if needed."""
         if not self._system_prompt_appendix:
             return original_prompt

         return f"{original_prompt}\n{self._system_prompt_appendix}"

     def post_process_result(self, result: str) -> str:
-        """Post-process the result to ensure it matches the expected format"""
+        """Post-process the result to ensure it matches the expected format."""
         if not self._output_format:
             return result

@@ -1,4 +1,4 @@
-from typing import Any, List, Optional
+from typing import Any

 from pydantic import Field, PrivateAttr

@@ -29,13 +29,13 @@ except ImportError:


 class OpenAIAgentAdapter(BaseAgentAdapter):
-    """Adapter for OpenAI Assistants"""
+    """Adapter for OpenAI Assistants."""

     model_config = {"arbitrary_types_allowed": True}

     _openai_agent: "OpenAIAgent" = PrivateAttr()
     _logger: Logger = PrivateAttr(default_factory=lambda: Logger())
-    _active_thread: Optional[str] = PrivateAttr(default=None)
+    _active_thread: str | None = PrivateAttr(default=None)
     function_calling_llm: Any = Field(default=None)
     step_callback: Any = Field(default=None)
     _tool_adapter: "OpenAIAgentToolAdapter" = PrivateAttr()
@@ -44,35 +44,35 @@ class OpenAIAgentAdapter(BaseAgentAdapter):
     def __init__(
         self,
         model: str = "gpt-4o-mini",
-        tools: Optional[List[BaseTool]] = None,
-        agent_config: Optional[dict] = None,
+        tools: list[BaseTool] | None = None,
+        agent_config: dict | None = None,
         **kwargs,
-    ):
+    ) -> None:
         if not OPENAI_AVAILABLE:
+            msg = "OpenAI Agent Dependencies are not installed. Please install it using `uv add openai-agents`"
             raise ImportError(
-                "OpenAI Agent Dependencies are not installed. Please install it using `uv add openai-agents`"
+                msg,
             )
-        else:
-            role = kwargs.pop("role", None)
-            goal = kwargs.pop("goal", None)
-            backstory = kwargs.pop("backstory", None)
-            super().__init__(
-                role=role,
-                goal=goal,
-                backstory=backstory,
-                tools=tools,
-                agent_config=agent_config,
-                **kwargs,
-            )
-            self._tool_adapter = OpenAIAgentToolAdapter(tools=tools)
-            self.llm = model
-            self._converter_adapter = OpenAIConverterAdapter(self)
+        role = kwargs.pop("role", None)
+        goal = kwargs.pop("goal", None)
+        backstory = kwargs.pop("backstory", None)
+        super().__init__(
+            role=role,
+            goal=goal,
+            backstory=backstory,
+            tools=tools,
+            agent_config=agent_config,
+            **kwargs,
+        )
+        self._tool_adapter = OpenAIAgentToolAdapter(tools=tools)
+        self.llm = model
+        self._converter_adapter = OpenAIConverterAdapter(self)

     def _build_system_prompt(self) -> str:
         """Build a system prompt for the OpenAI agent."""
         base_prompt = f"""
         You are {self.role}.


         Your goal is: {self.goal}

         Your backstory: {self.backstory}
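Note: dropping the `else:` after `raise` and dedenting the happy path is Ruff's RET506 (superfluous else after raise). A sketch of the control flow:

```python
def connect(available: bool) -> str:
    if not available:
        msg = "dependency not installed"
        raise ImportError(msg)
    # RET506: the raise already ends the other branch, so no else is needed
    return "connected"


print(connect(True))  # connected
```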
@@ -84,10 +84,10 @@ class OpenAIAgentAdapter(BaseAgentAdapter):
     def execute_task(
         self,
         task: Any,
-        context: Optional[str] = None,
-        tools: Optional[List[BaseTool]] = None,
+        context: str | None = None,
+        tools: list[BaseTool] | None = None,
     ) -> str:
-        """Execute a task using the OpenAI Assistant"""
+        """Execute a task using the OpenAI Assistant."""
         self._converter_adapter.configure_structured_output(task)
         self.create_agent_executor(tools)

@@ -98,7 +98,7 @@ class OpenAIAgentAdapter(BaseAgentAdapter):
             task_prompt = task.prompt()
             if context:
                 task_prompt = self.i18n.slice("task_with_context").format(
-                    task=task_prompt, context=context
+                    task=task_prompt, context=context,
                 )
             crewai_event_bus.emit(
                 self,
@@ -114,13 +114,13 @@ class OpenAIAgentAdapter(BaseAgentAdapter):
             crewai_event_bus.emit(
                 self,
                 event=AgentExecutionCompletedEvent(
-                    agent=self, task=task, output=final_answer
+                    agent=self, task=task, output=final_answer,
                 ),
             )
             return final_answer

         except Exception as e:
-            self._logger.log("error", f"Error executing OpenAI task: {str(e)}")
+            self._logger.log("error", f"Error executing OpenAI task: {e!s}")
             crewai_event_bus.emit(
                 self,
                 event=AgentExecutionErrorEvent(
@@ -131,9 +131,8 @@ class OpenAIAgentAdapter(BaseAgentAdapter):
             )
             raise

-    def create_agent_executor(self, tools: Optional[List[BaseTool]] = None) -> None:
-        """
-        Configure the OpenAI agent for execution.
+    def create_agent_executor(self, tools: list[BaseTool] | None = None) -> None:
+        """Configure the OpenAI agent for execution.
         While OpenAI handles execution differently through Runner,
         we can use this method to set up tools and configurations.
         """
@@ -152,27 +151,27 @@ class OpenAIAgentAdapter(BaseAgentAdapter):

         self.agent_executor = Runner

-    def configure_tools(self, tools: Optional[List[BaseTool]] = None) -> None:
-        """Configure tools for the OpenAI Assistant"""
+    def configure_tools(self, tools: list[BaseTool] | None = None) -> None:
+        """Configure tools for the OpenAI Assistant."""
         if tools:
             self._tool_adapter.configure_tools(tools)
             if self._tool_adapter.converted_tools:
                 self._openai_agent.tools = self._tool_adapter.converted_tools

     def handle_execution_result(self, result: Any) -> str:
-        """Process OpenAI Assistant execution result converting any structured output to a string"""
+        """Process OpenAI Assistant execution result converting any structured output to a string."""
         return self._converter_adapter.post_process_result(result.final_output)

-    def get_delegation_tools(self, agents: List[BaseAgent]) -> List[BaseTool]:
-        """Implement delegation tools support"""
+    def get_delegation_tools(self, agents: list[BaseAgent]) -> list[BaseTool]:
+        """Implement delegation tools support."""
         agent_tools = AgentTools(agents=agents)
-        tools = agent_tools.tools()
-        return tools
+        return agent_tools.tools()

     def configure_structured_output(self, task) -> None:
         """Configure the structured output for the specific agent implementation.

         Args:
             structured_output: The structured output to be configured
+
         """
         self._converter_adapter.configure_structured_output(task)
@@ -1,5 +1,5 @@
 import inspect
-from typing import Any, List, Optional
+from typing import Any

 from agents import FunctionTool, Tool

@@ -8,42 +8,36 @@ from crewai.tools import BaseTool


 class OpenAIAgentToolAdapter(BaseToolAdapter):
-    """Adapter for OpenAI Assistant tools"""
+    """Adapter for OpenAI Assistant tools."""

-    def __init__(self, tools: Optional[List[BaseTool]] = None):
+    def __init__(self, tools: list[BaseTool] | None = None) -> None:
         self.original_tools = tools or []

-    def configure_tools(self, tools: List[BaseTool]) -> None:
-        """Configure tools for the OpenAI Assistant"""
-        if self.original_tools:
-            all_tools = tools + self.original_tools
-        else:
-            all_tools = tools
+    def configure_tools(self, tools: list[BaseTool]) -> None:
+        """Configure tools for the OpenAI Assistant."""
+        all_tools = tools + self.original_tools if self.original_tools else tools
         if all_tools:
             self.converted_tools = self._convert_tools_to_openai_format(all_tools)

     def _convert_tools_to_openai_format(
-        self, tools: Optional[List[BaseTool]]
-    ) -> List[Tool]:
-        """Convert CrewAI tools to OpenAI Assistant tool format"""
+        self, tools: list[BaseTool] | None,
+    ) -> list[Tool]:
+        """Convert CrewAI tools to OpenAI Assistant tool format."""
         if not tools:
             return []

         def sanitize_tool_name(name: str) -> str:
-            """Convert tool name to match OpenAI's required pattern"""
+            """Convert tool name to match OpenAI's required pattern."""
             import re

-            sanitized = re.sub(r"[^a-zA-Z0-9_-]", "_", name).lower()
-            return sanitized
+            return re.sub(r"[^a-zA-Z0-9_-]", "_", name).lower()

         def create_tool_wrapper(tool: BaseTool):
-            """Create a wrapper function that handles the OpenAI function tool interface"""
+            """Create a wrapper function that handles the OpenAI function tool interface."""

             async def wrapper(context_wrapper: Any, arguments: Any) -> Any:
                 # Get the parameter name from the schema
-                param_name = list(
-                    tool.args_schema.model_json_schema()["properties"].keys()
-                )[0]
+                param_name = next(iter(tool.args_schema.model_json_schema()["properties"].keys()))

                 # Handle different argument types
                 if isinstance(arguments, dict):
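Note: `next(iter(...))` replaces `list(...)[0]`, Ruff's RUF015: it reads the first key without materializing every key into a list first:

```python
properties = {"query": {"type": "string"}, "limit": {"type": "integer"}}

first_old = list(properties.keys())[0]  # builds a full list, then indexes it
first_new = next(iter(properties))      # stops after the first key
assert first_old == first_new == "query"
```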
@@ -7,8 +7,7 @@ from crewai.utilities.i18n import I18N


 class OpenAIConverterAdapter(BaseConverterAdapter):
-    """
-    Adapter for handling structured output conversion in OpenAI agents.
+    """Adapter for handling structured output conversion in OpenAI agents.

     This adapter enhances the OpenAI agent to handle structured output formats
     and post-processes the results when needed.
@@ -17,21 +16,22 @@ class OpenAIConverterAdapter(BaseConverterAdapter):
         _output_format: The expected output format (json, pydantic, or None)
         _schema: The schema description for the expected output
         _output_model: The Pydantic model for the output
+
     """

-    def __init__(self, agent_adapter):
-        """Initialize the converter adapter with a reference to the agent adapter"""
+    def __init__(self, agent_adapter) -> None:
+        """Initialize the converter adapter with a reference to the agent adapter."""
         self.agent_adapter = agent_adapter
         self._output_format = None
         self._schema = None
         self._output_model = None

     def configure_structured_output(self, task) -> None:
-        """
-        Configure the structured output for OpenAI agent based on task requirements.
+        """Configure the structured output for OpenAI agent based on task requirements.

         Args:
             task: The task containing output format requirements
+
         """
         # Reset configuration
         self._output_format = None
@@ -55,14 +55,14 @@ class OpenAIConverterAdapter(BaseConverterAdapter):
             self._output_model = task.output_pydantic

     def enhance_system_prompt(self, base_prompt: str) -> str:
-        """
-        Enhance the base system prompt with structured output requirements if needed.
+        """Enhance the base system prompt with structured output requirements if needed.

         Args:
             base_prompt: The original system prompt

         Returns:
             Enhanced system prompt with output format instructions if needed
+
         """
         if not self._output_format:
             return base_prompt
@@ -76,8 +76,7 @@ class OpenAIConverterAdapter(BaseConverterAdapter):
         return f"{base_prompt}\n\n{output_schema}"

     def post_process_result(self, result: str) -> str:
-        """
-        Post-process the result to ensure it matches the expected format.
+        """Post-process the result to ensure it matches the expected format.

         This method attempts to extract valid JSON from the result if necessary.

@@ -86,6 +85,7 @@ class OpenAIConverterAdapter(BaseConverterAdapter):

         Returns:
             Processed result conforming to the expected output format
+
         """
         if not self._output_format:
             return result
@@ -1,8 +1,9 @@
 import uuid
 from abc import ABC, abstractmethod
+from collections.abc import Callable
 from copy import copy as shallow_copy
 from hashlib import md5
-from typing import Any, Callable, Dict, List, Optional, TypeVar
+from typing import Any, TypeVar

 from pydantic import (
     UUID4,
@@ -14,6 +15,7 @@ from pydantic import (
     model_validator,
 )
 from pydantic_core import PydanticCustomError
+from typing_extensions import Self

 from crewai.agents.agent_builder.utilities.base_token_process import TokenProcess
 from crewai.agents.cache.cache_handler import CacheHandler
@@ -25,7 +27,6 @@ from crewai.security.security_config import SecurityConfig
 from crewai.tools.base_tool import BaseTool, Tool
 from crewai.utilities import I18N, Logger, RPMController
 from crewai.utilities.config import process_config
-from crewai.utilities.converter import Converter
 from crewai.utilities.string_utils import interpolate_only

 T = TypeVar("T", bound="BaseAgent")
@@ -77,30 +78,31 @@ class BaseAgent(ABC, BaseModel):
             Set the rpm controller for the agent.
         set_private_attrs() -> "BaseAgent":
             Set private attributes.
+
     """

     __hash__ = object.__hash__  # type: ignore
     _logger: Logger = PrivateAttr(default_factory=lambda: Logger(verbose=False))
-    _rpm_controller: Optional[RPMController] = PrivateAttr(default=None)
+    _rpm_controller: RPMController | None = PrivateAttr(default=None)
     _request_within_rpm_limit: Any = PrivateAttr(default=None)
-    _original_role: Optional[str] = PrivateAttr(default=None)
-    _original_goal: Optional[str] = PrivateAttr(default=None)
-    _original_backstory: Optional[str] = PrivateAttr(default=None)
+    _original_role: str | None = PrivateAttr(default=None)
+    _original_goal: str | None = PrivateAttr(default=None)
+    _original_backstory: str | None = PrivateAttr(default=None)
     _token_process: TokenProcess = PrivateAttr(default_factory=TokenProcess)
     id: UUID4 = Field(default_factory=uuid.uuid4, frozen=True)
     role: str = Field(description="Role of the agent")
     goal: str = Field(description="Objective of the agent")
     backstory: str = Field(description="Backstory of the agent")
-    config: Optional[Dict[str, Any]] = Field(
-        description="Configuration for the agent", default=None, exclude=True
+    config: dict[str, Any] | None = Field(
+        description="Configuration for the agent", default=None, exclude=True,
     )
     cache: bool = Field(
-        default=True, description="Whether the agent should use a cache for tool usage."
+        default=True, description="Whether the agent should use a cache for tool usage.",
     )
     verbose: bool = Field(
-        default=False, description="Verbose mode for the Agent Execution"
+        default=False, description="Verbose mode for the Agent Execution",
     )
-    max_rpm: Optional[int] = Field(
+    max_rpm: int | None = Field(
         default=None,
         description="Maximum number of requests per minute for the agent execution to be respected.",
     )
@@ -108,41 +110,41 @@ class BaseAgent(ABC, BaseModel):
         default=False,
         description="Enable agent to delegate and ask questions among each other.",
     )
-    tools: Optional[List[BaseTool]] = Field(
-        default_factory=list, description="Tools at agents' disposal"
+    tools: list[BaseTool] | None = Field(
+        default_factory=list, description="Tools at agents' disposal",
     )
     max_iter: int = Field(
-        default=25, description="Maximum iterations for an agent to execute a task"
+        default=25, description="Maximum iterations for an agent to execute a task",
     )
     agent_executor: InstanceOf = Field(
-        default=None, description="An instance of the CrewAgentExecutor class."
+        default=None, description="An instance of the CrewAgentExecutor class.",
     )
     llm: Any = Field(
-        default=None, description="Language model that will run the agent."
+        default=None, description="Language model that will run the agent.",
     )
     crew: Any = Field(default=None, description="Crew to which the agent belongs.")
     i18n: I18N = Field(default=I18N(), description="Internationalization settings.")
-    cache_handler: Optional[InstanceOf[CacheHandler]] = Field(
-        default=None, description="An instance of the CacheHandler class."
+    cache_handler: InstanceOf[CacheHandler] | None = Field(
+        default=None, description="An instance of the CacheHandler class.",
     )
     tools_handler: InstanceOf[ToolsHandler] = Field(
         default_factory=ToolsHandler,
         description="An instance of the ToolsHandler class.",
     )
-    tools_results: List[Dict[str, Any]] = Field(
-        default=[], description="Results of the tools used by the agent."
+    tools_results: list[dict[str, Any]] = Field(
+        default=[], description="Results of the tools used by the agent.",
     )
-    max_tokens: Optional[int] = Field(
-        default=None, description="Maximum number of tokens for the agent's execution."
+    max_tokens: int | None = Field(
+        default=None, description="Maximum number of tokens for the agent's execution.",
     )
-    knowledge: Optional[Knowledge] = Field(
-        default=None, description="Knowledge for the agent."
+    knowledge: Knowledge | None = Field(
+        default=None, description="Knowledge for the agent.",
     )
-    knowledge_sources: Optional[List[BaseKnowledgeSource]] = Field(
+    knowledge_sources: list[BaseKnowledgeSource] | None = Field(
         default=None,
         description="Knowledge sources for the agent.",
     )
-    knowledge_storage: Optional[Any] = Field(
+    knowledge_storage: Any | None = Field(
         default=None,
         description="Custom knowledge storage for the agent.",
     )
@@ -150,13 +152,13 @@ class BaseAgent(ABC, BaseModel):
         default_factory=SecurityConfig,
         description="Security configuration for the agent, including fingerprinting.",
     )
-    callbacks: List[Callable] = Field(
-        default=[], description="Callbacks to be used for the agent"
+    callbacks: list[Callable] = Field(
+        default=[], description="Callbacks to be used for the agent",
     )
     adapted_agent: bool = Field(
-        default=False, description="Whether the agent is adapted"
+        default=False, description="Whether the agent is adapted",
     )
-    knowledge_config: Optional[KnowledgeConfig] = Field(
+    knowledge_config: KnowledgeConfig | None = Field(
         default=None,
         description="Knowledge configuration for the agent such as limits and threshold",
     )
@@ -168,7 +170,7 @@ class BaseAgent(ABC, BaseModel):

     @field_validator("tools")
     @classmethod
-    def validate_tools(cls, tools: List[Any]) -> List[BaseTool]:
+    def validate_tools(cls, tools: list[Any]) -> list[BaseTool]:
         """Validate and process the tools provided to the agent.

         This method ensures that each tool is either an instance of BaseTool
@@ -188,11 +190,14 @@ class BaseAgent(ABC, BaseModel):
                 # Tool has the required attributes, create a Tool instance
                 processed_tools.append(Tool.from_langchain(tool))
             else:
-                raise ValueError(
+                msg = (
                     f"Invalid tool type: {type(tool)}. "
                     "Tool must be an instance of BaseTool or "
                     "an object with 'name', 'func', and 'description' attributes."
                 )
+                raise ValueError(
+                    msg,
+                )
         return processed_tools

     @model_validator(mode="after")
@@ -200,15 +205,16 @@ class BaseAgent(ABC, BaseModel):
         # Validate required fields
         for field in ["role", "goal", "backstory"]:
             if getattr(self, field) is None:
+                msg = f"{field} must be provided either directly or through config"
                 raise ValueError(
-                    f"{field} must be provided either directly or through config"
+                    msg,
                 )

         # Set private attributes
         self._logger = Logger(verbose=self.verbose)
         if self.max_rpm and not self._rpm_controller:
             self._rpm_controller = RPMController(
-                max_rpm=self.max_rpm, logger=self._logger
+                max_rpm=self.max_rpm, logger=self._logger,
             )
         if not self._token_process:
             self._token_process = TokenProcess()
@@ -221,10 +227,11 @@ class BaseAgent(ABC, BaseModel):

     @field_validator("id", mode="before")
     @classmethod
-    def _deny_user_set_id(cls, v: Optional[UUID4]) -> None:
+    def _deny_user_set_id(cls, v: UUID4 | None) -> None:
         if v:
+            msg = "may_not_set_field"
             raise PydanticCustomError(
-                "may_not_set_field", "This field is not to be set by the user.", {}
+                msg, "This field is not to be set by the user.", {},
             )

     @model_validator(mode="after")
@@ -233,7 +240,7 @@ class BaseAgent(ABC, BaseModel):
         self._logger = Logger(verbose=self.verbose)
         if self.max_rpm and not self._rpm_controller:
             self._rpm_controller = RPMController(
-                max_rpm=self.max_rpm, logger=self._logger
+                max_rpm=self.max_rpm, logger=self._logger,
             )
         if not self._token_process:
             self._token_process = TokenProcess()
@@ -252,8 +259,8 @@ class BaseAgent(ABC, BaseModel):
     def execute_task(
         self,
         task: Any,
-        context: Optional[str] = None,
-        tools: Optional[List[BaseTool]] = None,
+        context: str | None = None,
+        tools: list[BaseTool] | None = None,
     ) -> str:
         pass

@@ -262,11 +269,10 @@ class BaseAgent(ABC, BaseModel):
         pass

     @abstractmethod
-    def get_delegation_tools(self, agents: List["BaseAgent"]) -> List[BaseTool]:
+    def get_delegation_tools(self, agents: list["BaseAgent"]) -> list[BaseTool]:
         """Set the task tools that init BaseAgenTools class."""
-        pass

-    def copy(self: T) -> T:  # type: ignore # Signature of "copy" incompatible with supertype "BaseModel"
+    def copy(self) -> Self:  # type: ignore # Signature of "copy" incompatible with supertype "BaseModel"
         """Create a deep copy of the Agent."""
         exclude = {
             "id",
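Note: `def copy(self) -> Self` uses PEP 673's `Self` (imported above from typing_extensions for pre-3.11 support) instead of the `T = TypeVar("T", bound="BaseAgent")` pattern, so subclasses keep their precise type. The idiom in miniature, with a hypothetical class:

```python
from typing_extensions import Self  # typing.Self on Python 3.11+


class Node:
    def __init__(self) -> None:
        self.children: list[Node] = []

    def add(self, child: "Node") -> Self:
        # returning Self lets chained calls keep the subclass type
        self.children.append(child)
        return self


tree = Node().add(Node()).add(Node())
print(len(tree.children))  # 2
```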
@@ -309,7 +315,7 @@ class BaseAgent(ABC, BaseModel):

         copied_data = self.model_dump(exclude=exclude)
         copied_data = {k: v for k, v in copied_data.items() if v is not None}
-        copied_agent = type(self)(
+        return type(self)(
             **copied_data,
             llm=existing_llm,
             tools=self.tools,
@@ -318,9 +324,8 @@ class BaseAgent(ABC, BaseModel):
             knowledge_storage=copied_knowledge_storage,
         )

-        return copied_agent
-
-    def interpolate_inputs(self, inputs: Dict[str, Any]) -> None:
+    def interpolate_inputs(self, inputs: dict[str, Any]) -> None:
         """Interpolate inputs into the agent description and backstory."""
         if self._original_role is None:
             self._original_role = self.role
@@ -331,13 +336,13 @@ class BaseAgent(ABC, BaseModel):

         if inputs:
             self.role = interpolate_only(
-                input_string=self._original_role, inputs=inputs
+                input_string=self._original_role, inputs=inputs,
             )
             self.goal = interpolate_only(
-                input_string=self._original_goal, inputs=inputs
+                input_string=self._original_goal, inputs=inputs,
             )
             self.backstory = interpolate_only(
-                input_string=self._original_backstory, inputs=inputs
+                input_string=self._original_backstory, inputs=inputs,
             )

     def set_cache_handler(self, cache_handler: CacheHandler) -> None:
@@ -345,6 +350,7 @@ class BaseAgent(ABC, BaseModel):
|
||||
|
||||
Args:
|
||||
cache_handler: An instance of the CacheHandler class.
|
||||
|
||||
"""
|
||||
self.tools_handler = ToolsHandler()
|
||||
if self.cache:
|
||||
@@ -357,10 +363,11 @@ class BaseAgent(ABC, BaseModel):
|
||||
|
||||
Args:
|
||||
rpm_controller: An instance of the RPMController class.
|
||||
|
||||
"""
|
||||
if not self._rpm_controller:
|
||||
self._rpm_controller = rpm_controller
|
||||
self.create_agent_executor()
|
||||
|
||||
def set_knowledge(self, crew_embedder: Optional[Dict[str, Any]] = None):
|
||||
def set_knowledge(self, crew_embedder: dict[str, Any] | None = None) -> None:
|
||||
pass
|
||||
|
||||

@@ -1,3 +1,4 @@
+import contextlib
 import time
 from typing import TYPE_CHECKING

@@ -43,8 +44,7 @@ class CrewAgentExecutorMixin:
                     },
                     agent=self.agent.role,
                 )
-            except Exception as e:
-                print(f"Failed to add to short term memory: {e}")
+            except Exception:
+                pass

     def _create_external_memory(self, output) -> None:
@@ -56,7 +56,7 @@ class CrewAgentExecutorMixin:
             and hasattr(self.crew, "_external_memory")
             and self.crew._external_memory
         ):
-            try:
+            with contextlib.suppress(Exception):
                 self.crew._external_memory.save(
                     value=output.text,
                     metadata={
@@ -64,9 +64,6 @@ class CrewAgentExecutorMixin:
                     },
                     agent=self.agent.role,
                 )
-            except Exception as e:
-                print(f"Failed to add to external memory: {e}")
-                pass

     def _create_long_term_memory(self, output) -> None:
         """Create and save long-term and entity memory items based on evaluation."""
@@ -103,15 +100,13 @@ class CrewAgentExecutorMixin:
                             type=entity.type,
                             description=entity.description,
                             relationships="\n".join(
-                                [f"- {r}" for r in entity.relationships]
+                                [f"- {r}" for r in entity.relationships],
                             ),
                         )
                         self.crew._entity_memory.save(entity_memory)
-            except AttributeError as e:
-                print(f"Missing attributes for long term memory: {e}")
+            except AttributeError:
+                pass
-            except Exception as e:
-                print(f"Failed to add to long term memory: {e}")
+            except Exception:
+                pass
         elif (
             self.crew
@@ -126,7 +121,7 @@ class CrewAgentExecutorMixin:
     def _ask_human_input(self, final_answer: str) -> str:
         """Prompt human input with mode-appropriate messaging."""
         self._printer.print(
-            content=f"\033[1m\033[95m ## Final Result:\033[00m \033[92m{final_answer}\033[00m"
+            content=f"\033[1m\033[95m ## Final Result:\033[00m \033[92m{final_answer}\033[00m",
         )

         # Training mode prompt (single iteration)

@@ -1,12 +1,11 @@
 from abc import ABC, abstractmethod
-from typing import Any, Optional
+from typing import Any

 from pydantic import BaseModel, Field


 class OutputConverter(BaseModel, ABC):
-    """
-    Abstract base class for converting task results into structured formats.
+    """Abstract base class for converting task results into structured formats.

     This class provides a framework for converting unstructured text into
     either Pydantic models or JSON, tailored for specific agent requirements.
@@ -19,6 +18,7 @@ class OutputConverter(BaseModel, ABC):
         model (Any): The target model for structuring the output.
         instructions (str): Specific instructions for the conversion process.
         max_attempts (int): Maximum number of conversion attempts (default: 3).
+
     """

     text: str = Field(description="Text to be converted.")
@@ -33,9 +33,7 @@ class OutputConverter(BaseModel, ABC):
     @abstractmethod
     def to_pydantic(self, current_attempt=1) -> BaseModel:
         """Convert text to pydantic."""
-        pass

     @abstractmethod
     def to_json(self, current_attempt=1) -> dict:
         """Convert text to json."""
-        pass

src/crewai/agents/cache/cache_handler.py
@@ -1,4 +1,4 @@
-from typing import Any, Dict, Optional
+from typing import Any

 from pydantic import BaseModel, PrivateAttr

@@ -6,10 +6,10 @@ from pydantic import BaseModel, PrivateAttr
 class CacheHandler(BaseModel):
     """Callback handler for tool usage."""

-    _cache: Dict[str, Any] = PrivateAttr(default_factory=dict)
+    _cache: dict[str, Any] = PrivateAttr(default_factory=dict)

-    def add(self, tool, input, output):
+    def add(self, tool, input, output) -> None:
         self._cache[f"{tool}-{input}"] = output

-    def read(self, tool, input) -> Optional[str]:
+    def read(self, tool, input) -> str | None:
         return self._cache.get(f"{tool}-{input}")

@@ -1,6 +1,5 @@
 import json
 import re
-from typing import Any, Callable, Dict, List, Optional, Union
+from collections.abc import Callable
+from typing import TYPE_CHECKING, Any

 from crewai.agents.agent_builder.base_agent import BaseAgent
 from crewai.agents.agent_builder.base_agent_executor_mixin import CrewAgentExecutorMixin
@@ -10,8 +9,6 @@ from crewai.agents.parser import (
     OutputParserException,
 )
 from crewai.agents.tools_handler import ToolsHandler
-from crewai.llm import BaseLLM
-from crewai.tools.base_tool import BaseTool
 from crewai.tools.structured_tool import CrewStructuredTool
 from crewai.tools.tool_types import ToolResult
 from crewai.utilities import I18N, Printer
@@ -34,6 +31,10 @@ from crewai.utilities.logger import Logger
 from crewai.utilities.tool_utils import execute_tool_and_check_finality
 from crewai.utilities.training_handler import CrewTrainingHandler

+if TYPE_CHECKING:
+    from crewai.llm import BaseLLM
+    from crewai.tools.base_tool import BaseTool
+

 class CrewAgentExecutor(CrewAgentExecutorMixin):
     _logger: Logger = Logger()
@@ -46,18 +47,22 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
         agent: BaseAgent,
         prompt: dict[str, str],
         max_iter: int,
-        tools: List[CrewStructuredTool],
+        tools: list[CrewStructuredTool],
         tools_names: str,
-        stop_words: List[str],
+        stop_words: list[str],
         tools_description: str,
         tools_handler: ToolsHandler,
         step_callback: Any = None,
-        original_tools: List[Any] = [],
+        original_tools: list[Any] | None = None,
         function_calling_llm: Any = None,
         respect_context_window: bool = False,
-        request_within_rpm_limit: Optional[Callable[[], bool]] = None,
-        callbacks: List[Any] = [],
-    ):
+        request_within_rpm_limit: Callable[[], bool] | None = None,
+        callbacks: list[Any] | None = None,
+    ) -> None:
+        if callbacks is None:
+            callbacks = []
+        if original_tools is None:
+            original_tools = []
         self._i18n: I18N = I18N()
         self.llm: BaseLLM = llm
         self.task = task
@@ -79,10 +84,10 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
         self.respect_context_window = respect_context_window
         self.request_within_rpm_limit = request_within_rpm_limit
         self.ask_for_human_input = False
-        self.messages: List[Dict[str, str]] = []
+        self.messages: list[dict[str, str]] = []
         self.iterations = 0
         self.log_error_after = 3
-        self.tool_name_to_tool_map: Dict[str, Union[CrewStructuredTool, BaseTool]] = {
+        self.tool_name_to_tool_map: dict[str, CrewStructuredTool | BaseTool] = {
             tool.name: tool for tool in self.tools
         }
         existing_stop = self.llm.stop or []
@@ -90,11 +95,11 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
             set(
                 existing_stop + self.stop
                 if isinstance(existing_stop, list)
-                else self.stop
-            )
+                else self.stop,
+            ),
         )

-    def invoke(self, inputs: Dict[str, str]) -> Dict[str, Any]:
+    def invoke(self, inputs: dict[str, str]) -> dict[str, Any]:
         if "system" in self.prompt:
             system_prompt = self._format_prompt(self.prompt.get("system", ""), inputs)
             user_prompt = self._format_prompt(self.prompt.get("user", ""), inputs)
@@ -120,9 +125,8 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
             handle_unknown_error(self._printer, e)
             if e.__class__.__module__.startswith("litellm"):
                 # Do not retry on litellm errors
-                raise e
-            else:
-                raise e
+                raise
+            raise

         if self.ask_for_human_input:
             formatted_answer = self._handle_human_feedback(formatted_answer)
@@ -133,8 +137,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
         return {"output": formatted_answer.output}

     def _invoke_loop(self) -> AgentFinish:
-        """
-        Main loop to invoke the agent's thought process until it reaches a conclusion
+        """Main loop to invoke the agent's thought process until it reaches a conclusion
         or the maximum number of iterations is reached.
         """
         formatted_answer = None
@@ -170,8 +173,8 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
                 ):
                     fingerprint_context = {
                         "agent_fingerprint": str(
-                            self.agent.security_config.fingerprint
-                        )
+                            self.agent.security_config.fingerprint,
+                        ),
                     }

                 tool_result = execute_tool_and_check_finality(
@@ -187,7 +190,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
                     function_calling_llm=self.function_calling_llm,
                 )
                 formatted_answer = self._handle_agent_action(
-                    formatted_answer, tool_result
+                    formatted_answer, tool_result,
                 )

                 self._invoke_step_callback(formatted_answer)
@@ -205,7 +208,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
            except Exception as e:
                if e.__class__.__module__.startswith("litellm"):
                    # Do not retry on litellm errors
-                    raise e
+                    raise
                if is_context_length_exceeded(e):
                    handle_context_length(
                        respect_context_window=self.respect_context_window,
@@ -216,9 +219,8 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
                        i18n=self._i18n,
                    )
                    continue
-                else:
-                    handle_unknown_error(self._printer, e)
-                    raise e
+                handle_unknown_error(self._printer, e)
+                raise
            finally:
                self.iterations += 1

@@ -231,8 +233,8 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
         return formatted_answer

     def _handle_agent_action(
-        self, formatted_answer: AgentAction, tool_result: ToolResult
-    ) -> Union[AgentAction, AgentFinish]:
+        self, formatted_answer: AgentAction, tool_result: ToolResult,
+    ) -> AgentAction | AgentFinish:
         """Handle the AgentAction, execute tools, and process the results."""
         # Special case for add_image_tool
         add_image_tool = self._i18n.tools("add_image")
@@ -261,24 +263,26 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
         """Append a message to the message list with the given role."""
         self.messages.append(format_message_for_llm(text, role=role))

-    def _show_start_logs(self):
+    def _show_start_logs(self) -> None:
         """Show logs for the start of agent execution."""
         if self.agent is None:
-            raise ValueError("Agent cannot be None")
+            msg = "Agent cannot be None"
+            raise ValueError(msg)
         show_agent_logs(
             printer=self._printer,
             agent_role=self.agent.role,
             task_description=(
-                getattr(self.task, "description") if self.task else "Not Found"
+                self.task.description if self.task else "Not Found"
             ),
             verbose=self.agent.verbose
             or (hasattr(self, "crew") and getattr(self.crew, "verbose", False)),
         )

-    def _show_logs(self, formatted_answer: Union[AgentAction, AgentFinish]):
+    def _show_logs(self, formatted_answer: AgentAction | AgentFinish) -> None:
         """Show logs for the agent's execution."""
         if self.agent is None:
-            raise ValueError("Agent cannot be None")
+            msg = "Agent cannot be None"
+            raise ValueError(msg)
         show_agent_logs(
             printer=self._printer,
             agent_role=self.agent.role,
@@ -300,11 +304,11 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
             summary = self.llm.call(
                 [
                     format_message_for_llm(
-                        self._i18n.slice("summarizer_system_message"), role="system"
+                        self._i18n.slice("summarizer_system_message"), role="system",
                     ),
                     format_message_for_llm(
                         self._i18n.slice("summarize_instruction").format(
-                            group=group["content"]
+                            group=group["content"],
                         ),
                     ),
                 ],
@@ -316,12 +320,12 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):

         self.messages = [
             format_message_for_llm(
-                self._i18n.slice("summary").format(merged_summary=merged_summary)
-            )
+                self._i18n.slice("summary").format(merged_summary=merged_summary),
+            ),
         ]

     def _handle_crew_training_output(
-        self, result: AgentFinish, human_feedback: Optional[str] = None
+        self, result: AgentFinish, human_feedback: str | None = None,
     ) -> None:
         """Handle the process of saving training data."""
         agent_id = str(self.agent.id)  # type: ignore
@@ -348,29 +352,27 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
                 "initial_output": result.output,
                 "human_feedback": human_feedback,
             }
+        # Save improved output
+        elif train_iteration in agent_training_data:
+            agent_training_data[train_iteration]["improved_output"] = result.output
         else:
-            # Save improved output
-            if train_iteration in agent_training_data:
-                agent_training_data[train_iteration]["improved_output"] = result.output
-            else:
-                self._printer.print(
-                    content=(
-                        f"No existing training data for agent {agent_id} and iteration "
-                        f"{train_iteration}. Cannot save improved output."
-                    ),
-                    color="red",
-                )
-                return
+            self._printer.print(
+                content=(
+                    f"No existing training data for agent {agent_id} and iteration "
+                    f"{train_iteration}. Cannot save improved output."
+                ),
+                color="red",
+            )
+            return

         # Update the training data and save
         training_data[agent_id] = agent_training_data
         training_handler.save(training_data)

-    def _format_prompt(self, prompt: str, inputs: Dict[str, str]) -> str:
+    def _format_prompt(self, prompt: str, inputs: dict[str, str]) -> str:
         prompt = prompt.replace("{input}", inputs["input"])
         prompt = prompt.replace("{tool_names}", inputs["tool_names"])
-        prompt = prompt.replace("{tools}", inputs["tools"])
-        return prompt
+        return prompt.replace("{tools}", inputs["tools"])

     def _handle_human_feedback(self, formatted_answer: AgentFinish) -> AgentFinish:
         """Handle human feedback with different flows for training vs regular use.
@@ -380,6 +382,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):

         Returns:
             AgentFinish: The final answer after processing feedback
+
         """
         human_feedback = self._ask_human_input(formatted_answer.output)

@@ -393,14 +396,14 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
         return bool(self.crew and self.crew._train)

     def _handle_training_feedback(
-        self, initial_answer: AgentFinish, feedback: str
+        self, initial_answer: AgentFinish, feedback: str,
     ) -> AgentFinish:
         """Process feedback for training scenarios with single iteration."""
         self._handle_crew_training_output(initial_answer, feedback)
         self.messages.append(
             format_message_for_llm(
-                self._i18n.slice("feedback_instructions").format(feedback=feedback)
-            )
+                self._i18n.slice("feedback_instructions").format(feedback=feedback),
+            ),
         )
         improved_answer = self._invoke_loop()
         self._handle_crew_training_output(improved_answer)
@@ -408,7 +411,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
         return improved_answer

     def _handle_regular_feedback(
-        self, current_answer: AgentFinish, initial_feedback: str
+        self, current_answer: AgentFinish, initial_feedback: str,
     ) -> AgentFinish:
         """Process feedback for regular use with potential multiple iterations."""
         feedback = initial_feedback
@@ -428,8 +431,8 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
         """Process a single feedback iteration."""
         self.messages.append(
             format_message_for_llm(
-                self._i18n.slice("feedback_instructions").format(feedback=feedback)
-            )
+                self._i18n.slice("feedback_instructions").format(feedback=feedback),
+            ),
         )
         return self._invoke_loop()

@@ -1,5 +1,5 @@
 import re
-from typing import Any, Optional, Union
+from typing import Any

 from json_repair import repair_json

@@ -18,7 +18,7 @@ class AgentAction:
     text: str
     result: str

-    def __init__(self, thought: str, tool: str, tool_input: str, text: str):
+    def __init__(self, thought: str, tool: str, tool_input: str, text: str) -> None:
         self.thought = thought
         self.tool = tool
         self.tool_input = tool_input
@@ -30,7 +30,7 @@ class AgentFinish:
     output: str
     text: str

-    def __init__(self, thought: str, output: str, text: str):
+    def __init__(self, thought: str, output: str, text: str) -> None:
         self.thought = thought
         self.output = output
         self.text = text
@@ -39,7 +39,7 @@
 class OutputParserException(Exception):
     error: str

-    def __init__(self, error: str):
+    def __init__(self, error: str) -> None:
         self.error = error

@@ -67,24 +67,24 @@ class CrewAgentParser:
     _i18n: I18N = I18N()
     agent: Any = None

-    def __init__(self, agent: Optional[Any] = None):
+    def __init__(self, agent: Any | None = None) -> None:
         self.agent = agent

     @staticmethod
-    def parse_text(text: str) -> Union[AgentAction, AgentFinish]:
-        """
-        Static method to parse text into an AgentAction or AgentFinish without needing to instantiate the class.
+    def parse_text(text: str) -> AgentAction | AgentFinish:
+        """Static method to parse text into an AgentAction or AgentFinish without needing to instantiate the class.

         Args:
             text: The text to parse.

         Returns:
             Either an AgentAction or AgentFinish based on the parsed content.
+
         """
         parser = CrewAgentParser()
         return parser.parse(text)

-    def parse(self, text: str) -> Union[AgentAction, AgentFinish]:
+    def parse(self, text: str) -> AgentAction | AgentFinish:
         thought = self._extract_thought(text)
         includes_answer = FINAL_ANSWER_ACTION in text
         regex = (
@@ -102,7 +102,7 @@ class CrewAgentParser:
             final_answer = final_answer[:-3].rstrip()
             return AgentFinish(thought, final_answer, text)

-        elif action_match:
+        if action_match:
             action = action_match.group(1)
             clean_action = self._clean_action(action)

@@ -114,21 +114,21 @@ class CrewAgentParser:
             return AgentAction(thought, clean_action, safe_tool_input, text)

         if not re.search(r"Action\s*\d*\s*:[\s]*(.*?)", text, re.DOTALL):
+            msg = f"{MISSING_ACTION_AFTER_THOUGHT_ERROR_MESSAGE}\n{self._i18n.slice('final_answer_format')}"
             raise OutputParserException(
-                f"{MISSING_ACTION_AFTER_THOUGHT_ERROR_MESSAGE}\n{self._i18n.slice('final_answer_format')}",
+                msg,
             )
-        elif not re.search(
-            r"[\s]*Action\s*\d*\s*Input\s*\d*\s*:[\s]*(.*)", text, re.DOTALL
+        if not re.search(
+            r"[\s]*Action\s*\d*\s*Input\s*\d*\s*:[\s]*(.*)", text, re.DOTALL,
         ):
             raise OutputParserException(
                 MISSING_ACTION_INPUT_AFTER_ACTION_ERROR_MESSAGE,
             )
-        else:
-            format = self._i18n.slice("format_without_tools")
-            error = f"{format}"
-            raise OutputParserException(
-                error,
-            )
+        format = self._i18n.slice("format_without_tools")
+        error = f"{format}"
+        raise OutputParserException(
+            error,
+        )

     def _extract_thought(self, text: str) -> str:
         thought_index = text.find("\nAction")
@@ -138,8 +138,7 @@ class CrewAgentParser:
             return ""
         thought = text[:thought_index].strip()
         # Remove any triple backticks from the thought string
-        thought = thought.replace("```", "").strip()
-        return thought
+        return thought.replace("```", "").strip()

     def _clean_action(self, text: str) -> str:
         """Clean action string by removing non-essential formatting characters."""

@@ -1,7 +1,8 @@
-from typing import Any, Optional, Union
+from typing import Any

+from crewai.tools.cache_tools.cache_tools import CacheTools
+from crewai.tools.tool_calling import InstructorToolCalling, ToolCalling

-from ..tools.cache_tools.cache_tools import CacheTools
-from ..tools.tool_calling import InstructorToolCalling, ToolCalling
 from .cache.cache_handler import CacheHandler

@@ -9,16 +10,16 @@ class ToolsHandler:
     """Callback handler for tool usage."""

     last_used_tool: ToolCalling = {}  # type: ignore # BUG?: Incompatible types in assignment (expression has type "Dict[...]", variable has type "ToolCalling")
-    cache: Optional[CacheHandler]
+    cache: CacheHandler | None

-    def __init__(self, cache: Optional[CacheHandler] = None):
+    def __init__(self, cache: CacheHandler | None = None) -> None:
         """Initialize the callback handler."""
         self.cache = cache
         self.last_used_tool = {}  # type: ignore # BUG?: same as above

     def on_tool_use(
         self,
-        calling: Union[ToolCalling, InstructorToolCalling],
+        calling: ToolCalling | InstructorToolCalling,
         output: str,
         should_cache: bool = True,
     ) -> Any:

@@ -9,9 +9,9 @@ def add_crew_to_flow(crew_name: str) -> None:
     """Add a new crew to the current flow."""
     # Check if pyproject.toml exists in the current directory
     if not Path("pyproject.toml").exists():
-        print("This command must be run from the root of a flow project.")
+        msg = "This command must be run from the root of a flow project."
         raise click.ClickException(
-            "This command must be run from the root of a flow project."
+            msg,
         )

     # Determine the flow folder based on the current directory
@@ -19,8 +19,8 @@ def add_crew_to_flow(crew_name: str) -> None:
     crews_folder = flow_folder / "src" / flow_folder.name / "crews"

     if not crews_folder.exists():
-        print("Crews folder does not exist in the current flow.")
-        raise click.ClickException("Crews folder does not exist in the current flow.")
+        msg = "Crews folder does not exist in the current flow."
+        raise click.ClickException(msg)

     # Create the crew within the flow's crews directory
     create_embedded_crew(crew_name, parent_folder=crews_folder)
@@ -39,7 +39,7 @@ def create_embedded_crew(crew_name: str, parent_folder: Path) -> None:

     if crew_folder.exists():
         if not click.confirm(
-            f"Crew {folder_name} already exists. Do you want to override it?"
+            f"Crew {folder_name} already exists. Do you want to override it?",
         ):
             click.secho("Operation cancelled.", fg="yellow")
             return
@@ -66,5 +66,5 @@ def create_embedded_crew(crew_name: str, parent_folder: Path) -> None:
         copy_template(src_file, dst_file, crew_name, class_name, folder_name)

     click.secho(
-        f"Crew {crew_name} added to the flow successfully!", fg="green", bold=True
+        f"Crew {crew_name} added to the flow successfully!", fg="green", bold=True,
     )

@@ -1,6 +1,6 @@
 import time
 import webbrowser
-from typing import Any, Dict
+from typing import Any

 import requests
 from rich.console import Console
@@ -17,38 +17,37 @@ class AuthenticationCommand:
     DEVICE_CODE_URL = f"https://{AUTH0_DOMAIN}/oauth/device/code"
     TOKEN_URL = f"https://{AUTH0_DOMAIN}/oauth/token"

-    def __init__(self):
+    def __init__(self) -> None:
         self.token_manager = TokenManager()

     def signup(self) -> None:
-        """Sign up to CrewAI+"""
+        """Sign up to CrewAI+."""
         console.print("Signing Up to CrewAI+ \n", style="bold blue")
         device_code_data = self._get_device_code()
         self._display_auth_instructions(device_code_data)

         return self._poll_for_token(device_code_data)

-    def _get_device_code(self) -> Dict[str, Any]:
+    def _get_device_code(self) -> dict[str, Any]:
         """Get the device code to authenticate the user."""
-
         device_code_payload = {
             "client_id": AUTH0_CLIENT_ID,
             "scope": "openid",
             "audience": AUTH0_AUDIENCE,
         }
         response = requests.post(
-            url=self.DEVICE_CODE_URL, data=device_code_payload, timeout=20
+            url=self.DEVICE_CODE_URL, data=device_code_payload, timeout=20,
         )
         response.raise_for_status()
         return response.json()

-    def _display_auth_instructions(self, device_code_data: Dict[str, str]) -> None:
+    def _display_auth_instructions(self, device_code_data: dict[str, str]) -> None:
         """Display the authentication instructions to the user."""
         console.print("1. Navigate to: ", device_code_data["verification_uri_complete"])
         console.print("2. Enter the following code: ", device_code_data["user_code"])
         webbrowser.open(device_code_data["verification_uri_complete"])

-    def _poll_for_token(self, device_code_data: Dict[str, Any]) -> None:
+    def _poll_for_token(self, device_code_data: dict[str, Any]) -> None:
         """Poll the server for the token."""
         token_payload = {
             "grant_type": "urn:ietf:params:oauth:grant-type:device_code",
@@ -81,7 +80,7 @@ class AuthenticationCommand:
                 )

                 console.print(
-                    "\n[bold green]Welcome to CrewAI Enterprise![/bold green]\n"
+                    "\n[bold green]Welcome to CrewAI Enterprise![/bold green]\n",
                 )
                 return

@@ -92,5 +91,5 @@ class AuthenticationCommand:
             attempts += 1

         console.print(
-            "Timeout: Failed to get the token. Please try again.", style="bold red"
+            "Timeout: Failed to get the token. Please try again.", style="bold red",
        )

@@ -5,5 +5,5 @@ def get_auth_token() -> str:
     """Get the authentication token."""
     access_token = TokenManager().get_token()
     if not access_token:
-        raise Exception()
+        raise Exception
     return access_token

@@ -3,7 +3,6 @@ import os
 import sys
 from datetime import datetime, timedelta
 from pathlib import Path
-from typing import Optional

 from auth0.authentication.token_verifier import (
     AsymmetricSignatureVerifier,
@@ -15,8 +14,7 @@ from .constants import AUTH0_CLIENT_ID, AUTH0_DOMAIN


 def validate_token(id_token: str) -> None:
-    """
-    Verify the token and its precedence
+    """Verify the token and its precedence.

     :param id_token:
     """
@@ -24,15 +22,14 @@
     issuer = f"https://{AUTH0_DOMAIN}/"
     signature_verifier = AsymmetricSignatureVerifier(jwks_url)
     token_verifier = TokenVerifier(
-        signature_verifier=signature_verifier, issuer=issuer, audience=AUTH0_CLIENT_ID
+        signature_verifier=signature_verifier, issuer=issuer, audience=AUTH0_CLIENT_ID,
     )
     token_verifier.verify(id_token)


 class TokenManager:
     def __init__(self, file_path: str = "tokens.enc") -> None:
-        """
-        Initialize the TokenManager class.
+        """Initialize the TokenManager class.

         :param file_path: The file path to store the encrypted tokens. Default is "tokens.enc".
         """
@@ -41,8 +38,7 @@ class TokenManager:
         self.fernet = Fernet(self.key)

     def _get_or_create_key(self) -> bytes:
-        """
-        Get or create the encryption key.
+        """Get or create the encryption key.

         :return: The encryption key.
         """
@@ -57,8 +53,7 @@ class TokenManager:
         return new_key

     def save_tokens(self, access_token: str, expires_in: int) -> None:
-        """
-        Save the access token and its expiration time.
+        """Save the access token and its expiration time.

         :param access_token: The access token to save.
         :param expires_in: The expiration time of the access token in seconds.
@@ -71,9 +66,8 @@ class TokenManager:
         encrypted_data = self.fernet.encrypt(json.dumps(data).encode())
         self.save_secure_file(self.file_path, encrypted_data)

-    def get_token(self) -> Optional[str]:
-        """
-        Get the access token if it is valid and not expired.
+    def get_token(self) -> str | None:
+        """Get the access token if it is valid and not expired.

         :return: The access token if valid and not expired, otherwise None.
         """
@@ -89,8 +83,7 @@ class TokenManager:
         return data["access_token"]

     def get_secure_storage_path(self) -> Path:
-        """
-        Get the secure storage path based on the operating system.
+        """Get the secure storage path based on the operating system.

         :return: The secure storage path.
         """
@@ -112,8 +105,7 @@ class TokenManager:
         return storage_path

     def save_secure_file(self, filename: str, content: bytes) -> None:
-        """
-        Save the content to a secure file.
+        """Save the content to a secure file.

         :param filename: The name of the file.
         :param content: The content to save.
@@ -127,9 +119,8 @@ class TokenManager:
         # Set appropriate permissions (read/write for owner only)
         os.chmod(file_path, 0o600)

-    def read_secure_file(self, filename: str) -> Optional[bytes]:
-        """
-        Read the content of a secure file.
+    def read_secure_file(self, filename: str) -> bytes | None:
+        """Read the content of a secure file.

         :param filename: The name of the file.
         :return: The content of the file if it exists, otherwise None.

@@ -1,6 +1,4 @@
 import os
 from importlib.metadata import version as get_version
-from typing import Optional, Tuple

 import click

@@ -28,7 +26,7 @@ from .update_crew import update_crew

 @click.group()
 @click.version_option(get_version("crewai"))
-def crewai():
+def crewai() -> None:
     """Top-level command group for crewai."""


@@ -37,7 +35,7 @@ def crewai():
 @click.argument("name")
 @click.option("--provider", type=str, help="The provider to use for the crew")
 @click.option("--skip_provider", is_flag=True, help="Skip provider validation")
-def create(type, name, provider, skip_provider=False):
+def create(type, name, provider, skip_provider=False) -> None:
     """Create a new crew, or flow."""
     if type == "crew":
         create_crew(name, provider, skip_provider)
@@ -49,9 +47,9 @@ def create(type, name, provider, skip_provider=False):

 @crewai.command()
 @click.option(
-    "--tools", is_flag=True, help="Show the installed version of crewai tools"
+    "--tools", is_flag=True, help="Show the installed version of crewai tools",
 )
-def version(tools):
+def version(tools) -> None:
     """Show the installed version of crewai."""
     try:
         crewai_version = get_version("crewai")
@@ -82,7 +80,7 @@ def version(tools):
     default="trained_agents_data.pkl",
     help="Path to a custom file for training",
 )
-def train(n_iterations: int, filename: str):
+def train(n_iterations: int, filename: str) -> None:
     """Train the crew."""
     click.echo(f"Training the Crew for {n_iterations} iterations")
     train_crew(n_iterations, filename)
@@ -96,11 +94,11 @@ def train(n_iterations: int, filename: str):
     help="Replay the crew from this task ID, including all subsequent tasks.",
 )
 def replay(task_id: str) -> None:
-    """
-    Replay the crew execution from a specific task.
+    """Replay the crew execution from a specific task.

     Args:
         task_id (str): The ID of the task to replay from.
+
     """
     try:
         click.echo(f"Replaying the crew from task {task_id}")
@@ -111,16 +109,14 @@ def replay(task_id: str) -> None:

 @crewai.command()
 def log_tasks_outputs() -> None:
-    """
-    Retrieve your latest crew.kickoff() task outputs.
-    """
+    """Retrieve your latest crew.kickoff() task outputs."""
     try:
         storage = KickoffTaskOutputsSQLiteStorage()
         tasks = storage.load()

         if not tasks:
             click.echo(
-                "No task outputs found. Only crew kickoff task outputs are logged."
+                "No task outputs found. Only crew kickoff task outputs are logged.",
             )
             return

@@ -153,13 +149,11 @@ def reset_memories(
     kickoff_outputs: bool,
     all: bool,
 ) -> None:
-    """
-    Reset the crew memories (long, short, entity, latest_crew_kickoff_ouputs). This will delete all the data saved.
-    """
+    """Reset the crew memories (long, short, entity, latest_crew_kickoff_ouputs). This will delete all the data saved."""
     try:
         if not all and not (long or short or entities or knowledge or kickoff_outputs):
             click.echo(
-                "Please specify at least one memory type to reset using the appropriate flags."
+                "Please specify at least one memory type to reset using the appropriate flags.",
             )
             return
         reset_memories_command(long, short, entities, knowledge, kickoff_outputs, all)
@@ -182,71 +176,69 @@ def reset_memories(
     default="gpt-4o-mini",
     help="LLM Model to run the tests on the Crew. For now only accepting only OpenAI models.",
 )
-def test(n_iterations: int, model: str):
+def test(n_iterations: int, model: str) -> None:
     """Test the crew and evaluate the results."""
     click.echo(f"Testing the crew for {n_iterations} iterations with model {model}")
     evaluate_crew(n_iterations, model)


 @crewai.command(
-    context_settings=dict(
-        ignore_unknown_options=True,
-        allow_extra_args=True,
-    )
+    context_settings={
+        "ignore_unknown_options": True,
+        "allow_extra_args": True,
+    },
 )
 @click.pass_context
-def install(context):
+def install(context) -> None:
     """Install the Crew."""
     install_crew(context.args)


 @crewai.command()
-def run():
+def run() -> None:
     """Run the Crew."""
     run_crew()


 @crewai.command()
-def update():
+def update() -> None:
     """Update the pyproject.toml of the Crew project to use uv."""
     update_crew()


 @crewai.command()
-def signup():
+def signup() -> None:
     """Sign Up/Login to CrewAI+."""
     AuthenticationCommand().signup()


 @crewai.command()
-def login():
+def login() -> None:
     """Sign Up/Login to CrewAI+."""
     AuthenticationCommand().signup()


 # DEPLOY CREWAI+ COMMANDS
 @crewai.group()
-def deploy():
+def deploy() -> None:
     """Deploy the Crew CLI group."""
-    pass


 @crewai.group()
-def tool():
+def tool() -> None:
     """Tool Repository related commands."""
-    pass


 @deploy.command(name="create")
 @click.option("-y", "--yes", is_flag=True, help="Skip the confirmation prompt")
-def deploy_create(yes: bool):
+def deploy_create(yes: bool) -> None:
     """Create a Crew deployment."""
     deploy_cmd = DeployCommand()
     deploy_cmd.create_crew(yes)


 @deploy.command(name="list")
-def deploy_list():
+def deploy_list() -> None:
     """List all deployments."""
     deploy_cmd = DeployCommand()
     deploy_cmd.list_crews()
@@ -254,7 +246,7 @@ def deploy_list():

 @deploy.command(name="push")
 @click.option("-u", "--uuid", type=str, help="Crew UUID parameter")
-def deploy_push(uuid: Optional[str]):
+def deploy_push(uuid: str | None) -> None:
     """Deploy the Crew."""
     deploy_cmd = DeployCommand()
     deploy_cmd.deploy(uuid=uuid)
@@ -262,7 +254,7 @@ def deploy_push(uuid: Optional[str]):

 @deploy.command(name="status")
 @click.option("-u", "--uuid", type=str, help="Crew UUID parameter")
-def deply_status(uuid: Optional[str]):
+def deply_status(uuid: str | None) -> None:
     """Get the status of a deployment."""
     deploy_cmd = DeployCommand()
     deploy_cmd.get_crew_status(uuid=uuid)
@@ -270,7 +262,7 @@ def deply_status(uuid: Optional[str]):

 @deploy.command(name="logs")
 @click.option("-u", "--uuid", type=str, help="Crew UUID parameter")
-def deploy_logs(uuid: Optional[str]):
+def deploy_logs(uuid: str | None) -> None:
     """Get the logs of a deployment."""
     deploy_cmd = DeployCommand()
     deploy_cmd.get_crew_logs(uuid=uuid)
@@ -278,7 +270,7 @@ def deploy_logs(uuid: Optional[str]):

 @deploy.command(name="remove")
 @click.option("-u", "--uuid", type=str, help="Crew UUID parameter")
-def deploy_remove(uuid: Optional[str]):
+def deploy_remove(uuid: str | None) -> None:
     """Remove a deployment."""
     deploy_cmd = DeployCommand()
     deploy_cmd.remove_crew(uuid=uuid)
@@ -286,14 +278,14 @@ def deploy_remove(uuid: Optional[str]):

 @tool.command(name="create")
 @click.argument("handle")
-def tool_create(handle: str):
+def tool_create(handle: str) -> None:
     tool_cmd = ToolCommand()
     tool_cmd.create(handle)


 @tool.command(name="install")
 @click.argument("handle")
-def tool_install(handle: str):
+def tool_install(handle: str) -> None:
     tool_cmd = ToolCommand()
     tool_cmd.login()
     tool_cmd.install(handle)
@@ -309,27 +301,26 @@ def tool_install(handle: str):
 )
 @click.option("--public", "is_public", flag_value=True, default=False)
 @click.option("--private", "is_public", flag_value=False)
-def tool_publish(is_public: bool, force: bool):
+def tool_publish(is_public: bool, force: bool) -> None:
     tool_cmd = ToolCommand()
     tool_cmd.login()
     tool_cmd.publish(is_public, force)


 @crewai.group()
-def flow():
+def flow() -> None:
     """Flow related commands."""
-    pass


 @flow.command(name="kickoff")
-def flow_run():
+def flow_run() -> None:
     """Kickoff the Flow."""
     click.echo("Running the Flow")
     kickoff_flow()


 @flow.command(name="plot")
-def flow_plot():
+def flow_plot() -> None:
     """Plot the Flow."""
     click.echo("Plotting the Flow")
     plot_flow()
@@ -337,20 +328,19 @@ def flow_plot():

 @flow.command(name="add-crew")
 @click.argument("crew_name")
-def flow_add_crew(crew_name):
+def flow_add_crew(crew_name) -> None:
     """Add a crew to an existing flow."""
     click.echo(f"Adding crew {crew_name} to the flow")
     add_crew_to_flow(crew_name)


 @crewai.command()
-def chat():
-    """
-    Start a conversation with the Crew, collecting user-supplied inputs,
+def chat() -> None:
+    """Start a conversation with the Crew, collecting user-supplied inputs,
     and using the Chat LLM to generate responses.
     """
     click.secho(
-        "\nStarting a conversation with the Crew\n" "Type 'exit' or Ctrl+C to quit.\n",
+        "\nStarting a conversation with the Crew\nType 'exit' or Ctrl+C to quit.\n",
     )

     run_chat()

@@ -10,13 +10,13 @@ console = Console()


 class BaseCommand:
-    def __init__(self):
+    def __init__(self) -> None:
         self._telemetry = Telemetry()
         self._telemetry.set_tracer()


 class PlusAPIMixin:
-    def __init__(self, telemetry):
+    def __init__(self, telemetry) -> None:
         try:
             telemetry.set_tracer()
             self.plus_api_client = PlusAPI(api_key=get_auth_token())
@@ -30,11 +30,11 @@ class PlusAPIMixin:
             raise SystemExit

     def _validate_response(self, response: requests.Response) -> None:
-        """
-        Handle and display error messages from API responses.
+        """Handle and display error messages from API responses.

         Args:
             response (requests.Response): The response from the Plus API
+
         """
         try:
             json_response = response.json()
@@ -55,13 +55,13 @@ class PlusAPIMixin:
             for field, messages in json_response.items():
                 for message in messages:
                     console.print(
-                        f"* [bold red]{field.capitalize()}[/bold red] {message}"
+                        f"* [bold red]{field.capitalize()}[/bold red] {message}",
                     )
             raise SystemExit

         if not response.ok:
             console.print(
-                "Request to Enterprise API failed. Details:", style="bold red"
+                "Request to Enterprise API failed. Details:", style="bold red",
             )
             details = (
                 json_response.get("error")

@@ -1,6 +1,5 @@
 import json
 from pathlib import Path
-from typing import Optional

 from pydantic import BaseModel, Field

@@ -8,16 +7,16 @@ DEFAULT_CONFIG_PATH = Path.home() / ".config" / "crewai" / "settings.json"


 class Settings(BaseModel):
-    tool_repository_username: Optional[str] = Field(
-        None, description="Username for interacting with the Tool Repository"
+    tool_repository_username: str | None = Field(
+        None, description="Username for interacting with the Tool Repository",
     )
-    tool_repository_password: Optional[str] = Field(
-        None, description="Password for interacting with the Tool Repository"
+    tool_repository_password: str | None = Field(
+        None, description="Password for interacting with the Tool Repository",
     )
     config_path: Path = Field(default=DEFAULT_CONFIG_PATH, exclude=True)

-    def __init__(self, config_path: Path = DEFAULT_CONFIG_PATH, **data):
-        """Load Settings from config path"""
+    def __init__(self, config_path: Path = DEFAULT_CONFIG_PATH, **data) -> None:
+        """Load Settings from config path."""
         config_path.parent.mkdir(parents=True, exist_ok=True)

         file_data = {}
@@ -32,7 +31,7 @@ class Settings(BaseModel):
         super().__init__(config_path=config_path, **merged_data)

     def dump(self) -> None:
-        """Save current settings to settings.json"""
+        """Save current settings to settings.json."""
         if self.config_path.is_file():
             with self.config_path.open("r") as f:
                 existing_data = json.load(f)

@@ -3,31 +3,31 @@ ENV_VARS = {
         {
             "prompt": "Enter your OPENAI API key (press Enter to skip)",
             "key_name": "OPENAI_API_KEY",
-        }
+        },
     ],
     "anthropic": [
         {
             "prompt": "Enter your ANTHROPIC API key (press Enter to skip)",
             "key_name": "ANTHROPIC_API_KEY",
-        }
+        },
     ],
     "gemini": [
         {
             "prompt": "Enter your GEMINI API key from https://ai.dev/apikey (press Enter to skip)",
             "key_name": "GEMINI_API_KEY",
-        }
+        },
     ],
     "nvidia_nim": [
         {
             "prompt": "Enter your NVIDIA API key (press Enter to skip)",
             "key_name": "NVIDIA_NIM_API_KEY",
-        }
+        },
     ],
     "groq": [
         {
             "prompt": "Enter your GROQ API key (press Enter to skip)",
             "key_name": "GROQ_API_KEY",
-        }
+        },
     ],
     "watson": [
         {
@@ -47,7 +47,7 @@ ENV_VARS = {
         {
             "default": True,
             "API_BASE": "http://localhost:11434",
-        }
+        },
     ],
     "bedrock": [
         {
@@ -101,7 +101,7 @@ ENV_VARS = {
         {
             "prompt": "Enter your SambaNovaCloud API key (press Enter to skip)",
             "key_name": "SAMBANOVA_API_KEY",
-        }
+        },
     ],
 }

@@ -24,7 +24,7 @@ def create_folder_structure(name, parent_folder=None):

     if folder_path.exists():
         if not click.confirm(
-            f"Folder {folder_name} already exists. Do you want to override it?"
+            f"Folder {folder_name} already exists. Do you want to override it?",
         ):
             click.secho("Operation cancelled.", fg="yellow")
             sys.exit(0)
@@ -48,7 +48,7 @@ def create_folder_structure(name, parent_folder=None):
     return folder_path, folder_name, class_name


-def copy_template_files(folder_path, name, class_name, parent_folder):
+def copy_template_files(folder_path, name, class_name, parent_folder) -> None:
     package_dir = Path(__file__).parent
     templates_dir = package_dir / "templates" / "crew"

@@ -89,7 +89,7 @@ def copy_template_files(folder_path, name, class_name, parent_folder):
     copy_template(src_file, dst_file, name, class_name, folder_path.name)


-def create_crew(name, provider=None, skip_provider=False, parent_folder=None):
+def create_crew(name, provider=None, skip_provider=False, parent_folder=None) -> None:
     folder_path, folder_name, class_name = create_folder_structure(name, parent_folder)
     env_vars = load_env_vars(folder_path)
     if not skip_provider:
@@ -109,7 +109,7 @@ def create_crew(name, provider=None, skip_provider=False, parent_folder=None):

     if existing_provider:
         if not click.confirm(
-            f"Found existing environment variable configuration for {existing_provider.capitalize()}. Do you want to override it?"
+            f"Found existing environment variable configuration for {existing_provider.capitalize()}. Do you want to override it?",
         ):
             click.secho("Keeping existing provider configuration.", fg="yellow")
             return
@@ -126,11 +126,11 @@ def create_crew(name, provider=None, skip_provider=False, parent_folder=None):
         if selected_provider:  # Valid selection
             break
         click.secho(
-            "No provider selected. Please try again or press 'q' to exit.", fg="red"
+            "No provider selected. Please try again or press 'q' to exit.", fg="red",
         )

     # Check if the selected provider has predefined models
-    if selected_provider in MODELS and MODELS[selected_provider]:
+    if MODELS.get(selected_provider):
         while True:
             selected_model = select_model(selected_provider, provider_models)
             if selected_model is None:  # User typed 'q'
@@ -167,7 +167,7 @@ def create_crew(name, provider=None, skip_provider=False, parent_folder=None):
             click.secho("API keys and model saved to .env file", fg="green")
         else:
             click.secho(
-                "No API keys provided. Skipping .env file creation.", fg="yellow"
+                "No API keys provided. Skipping .env file creation.", fg="yellow",
             )

     click.secho(f"Selected model: {env_vars.get('MODEL', 'N/A')}", fg="green")

@@ -5,7 +5,7 @@ import click
 from crewai.telemetry import Telemetry


-def create_flow(name):
+def create_flow(name) -> None:
     """Create a new flow."""
     folder_name = name.replace(" ", "_").replace("-", "_").lower()
     class_name = name.replace("_", " ").replace("-", " ").title().replace(" ", "")
@@ -43,12 +43,12 @@ def create_flow(name):
         "poem_crew",
     ]

-    def process_file(src_file, dst_file):
+    def process_file(src_file, dst_file) -> None:
         if src_file.suffix in [".pyc", ".pyo", ".pyd"]:
             return

         try:
-            with open(src_file, "r", encoding="utf-8") as file:
+            with open(src_file, encoding="utf-8") as file:
                 content = file.read()
         except Exception as e:
             click.secho(f"Error processing file {src_file}: {e}", fg="red")

@@ -5,7 +5,7 @@ import sys
 import threading
 import time
 from pathlib import Path
-from typing import Any, Dict, List, Optional, Set, Tuple
+from typing import Any

 import click
 import tomli
@@ -22,10 +22,9 @@ MIN_REQUIRED_VERSION = "0.98.0"


 def check_conversational_crews_version(
-    crewai_version: str, pyproject_data: dict
+    crewai_version: str, pyproject_data: dict,
 ) -> bool:
-    """
-    Check if the installed crewAI version supports conversational crews.
+    """Check if the installed crewAI version supports conversational crews.

     Args:
         crewai_version: The current version of crewAI.
@@ -33,6 +32,7 @@ def check_conversational_crews_version(

     Returns:
         bool: True if version check passes, False otherwise.
+
     """
     try:
         if version.parse(crewai_version) < version.parse(MIN_REQUIRED_VERSION):
@@ -48,9 +48,8 @@ def check_conversational_crews_version(
     return True


-def run_chat():
-    """
-    Runs an interactive chat loop using the Crew's chat LLM with function calling.
+def run_chat() -> None:
+    """Runs an interactive chat loop using the Crew's chat LLM with function calling.
     Incorporates crew_name, crew_description, and input fields to build a tool schema.
     Exits if crew_name or crew_description are missing.
     """
@@ -84,7 +83,7 @@ def run_chat():

         # Call the LLM to generate the introductory message
         introductory_message = chat_llm.call(
-            messages=[{"role": "system", "content": system_message}]
+            messages=[{"role": "system", "content": system_message}],
         )
     finally:
         # Stop loading indicator
@@ -108,15 +107,13 @@ def run_chat():
     chat_loop(chat_llm, messages, crew_tool_schema, available_functions)


-def show_loading(event: threading.Event):
+def show_loading(event: threading.Event) -> None:
     """Display animated loading dots while processing."""
     while not event.is_set():
         print(".", end="", flush=True)
         time.sleep(1)
     print()


-def initialize_chat_llm(crew: Crew) -> Optional[LLM | BaseLLM]:
+def initialize_chat_llm(crew: Crew) -> LLM | BaseLLM | None:
     """Initializes the chat LLM and handles exceptions."""
     try:
         return create_llm(crew.chat_llm)
@@ -157,7 +154,7 @@ def build_system_message(crew_chat_inputs: ChatInputs) -> str:
     )


-def create_tool_function(crew: Crew, messages: List[Dict[str, str]]) -> Any:
+def create_tool_function(crew: Crew, messages: list[dict[str, str]]) -> Any:
     """Creates a wrapper function for running the crew tool with messages."""

     def run_crew_tool_with_messages(**kwargs):
@@ -166,7 +163,7 @@ def create_tool_function(crew: Crew, messages: List[Dict[str, str]]) -> Any:
     return run_crew_tool_with_messages


-def flush_input():
+def flush_input() -> None:
     """Flush any pending input from the user."""
     if platform.system() == "Windows":
         # Windows platform
@@ -181,7 +178,7 @@ def flush_input():
         termios.tcflush(sys.stdin, termios.TCIFLUSH)


-def chat_loop(chat_llm, messages, crew_tool_schema, available_functions):
+def chat_loop(chat_llm, messages, crew_tool_schema, available_functions) -> None:
     """Main chat loop for interacting with the user."""
     while True:
         try:
@@ -190,7 +187,7 @@ def chat_loop(chat_llm, messages, crew_tool_schema, available_functions):

             user_input = get_user_input()
             handle_user_input(
-                user_input, chat_llm, messages, crew_tool_schema, available_functions
+                user_input, chat_llm, messages, crew_tool_schema, available_functions,
             )

         except KeyboardInterrupt:
@@ -221,9 +218,9 @@ def get_user_input() -> str:
 def handle_user_input(
     user_input: str,
     chat_llm: LLM,
-    messages: List[Dict[str, str]],
-    crew_tool_schema: Dict[str, Any],
-    available_functions: Dict[str, Any],
+    messages: list[dict[str, str]],
+    crew_tool_schema: dict[str, Any],
+    available_functions: dict[str, Any],
 ) -> None:
     if user_input.strip().lower() == "exit":
         click.echo("Exiting chat. Goodbye!")
@@ -251,8 +248,7 @@


 def generate_crew_tool_schema(crew_inputs: ChatInputs) -> dict:
-    """
-    Dynamically build a Littellm 'function' schema for the given crew.
+    """Dynamically build a Littellm 'function' schema for the given crew.

     crew_name: The name of the crew (used for the function 'name').
     crew_inputs: A ChatInputs object containing crew_description
@@ -281,9 +277,8 @@ def generate_crew_tool_schema(crew_inputs: ChatInputs) -> dict:
     }


-def run_crew_tool(crew: Crew, messages: List[Dict[str, str]], **kwargs):
-    """
-    Runs the crew using crew.kickoff(inputs=kwargs) and returns the output.
+def run_crew_tool(crew: Crew, messages: list[dict[str, str]], **kwargs):
+    """Runs the crew using crew.kickoff(inputs=kwargs) and returns the output.

     Args:
         crew (Crew): The crew instance to run.
@@ -295,6 +290,7 @@ def run_crew_tool(crew: Crew, messages: List[Dict[str, str]], **kwargs):

     Raises:
         SystemExit: Exits the chat if an error occurs during crew execution.
+
     """
     try:
         # Serialize 'messages' to JSON string before adding to kwargs
@@ -304,9 +300,8 @@ def run_crew_tool(crew: Crew, messages: List[Dict[str, str]], **kwargs):
         crew_output = crew.kickoff(inputs=kwargs)

         # Convert CrewOutput to a string to send back to the user
-        result = str(crew_output)
-
-        return result
+        return str(crew_output)
     except Exception as e:
         # Exit the chat and show the error message
         click.secho("An error occurred while running the crew:", fg="red")
@@ -314,12 +309,12 @@ def run_crew_tool(crew: Crew, messages: List[Dict[str, str]], **kwargs):
         sys.exit(1)


-def load_crew_and_name() -> Tuple[Crew, str]:
-    """
-    Loads the crew by importing the crew class from the user's project.
+def load_crew_and_name() -> tuple[Crew, str]:
+    """Loads the crew by importing the crew class from the user's project.

     Returns:
         Tuple[Crew, str]: A tuple containing the Crew instance and the name of the crew.
+
     """
     # Get the current working directory
     cwd = Path.cwd()
@@ -327,7 +322,8 @@ def load_crew_and_name() -> Tuple[Crew, str]:
     # Path to the pyproject.toml file
     pyproject_path = cwd / "pyproject.toml"
     if not pyproject_path.exists():
-        raise FileNotFoundError("pyproject.toml not found in the current directory.")
+        msg = "pyproject.toml not found in the current directory."
+        raise FileNotFoundError(msg)

     # Load the pyproject.toml file using 'tomli'
     with pyproject_path.open("rb") as f:
@@ -351,14 +347,16 @@ def load_crew_and_name() -> Tuple[Crew, str]:
     try:
         crew_module = __import__(crew_module_name, fromlist=[crew_class_name])
     except ImportError as e:
-        raise ImportError(f"Failed to import crew module {crew_module_name}: {e}")
+        msg = f"Failed to import crew module {crew_module_name}: {e}"
+        raise ImportError(msg)

     # Get the crew class from the module
     try:
         crew_class = getattr(crew_module, crew_class_name)
     except AttributeError:
+        msg = f"Crew class {crew_class_name} not found in module {crew_module_name}"
         raise AttributeError(
-            f"Crew class {crew_class_name} not found in module {crew_module_name}"
+            msg,
         )

     # Instantiate the crew
@@ -367,8 +365,7 @@ def load_crew_and_name() -> Tuple[Crew, str]:


 def generate_crew_chat_inputs(crew: Crew, crew_name: str, chat_llm) -> ChatInputs:
-    """
-    Generates the ChatInputs required for the crew by analyzing the tasks and agents.
+    """Generates the ChatInputs required for the crew by analyzing the tasks and agents.

     Args:
         crew (Crew): The crew object containing tasks and agents.
@@ -377,6 +374,7 @@ def generate_crew_chat_inputs(crew: Crew, crew_name: str, chat_llm) -> ChatInput

     Returns:
         ChatInputs: An object containing the crew's name, description, and input fields.
+
     """
     # Extract placeholders from tasks and agents
     required_inputs = fetch_required_inputs(crew)
@@ -391,22 +389,22 @@ def generate_crew_chat_inputs(crew: Crew, crew_name: str, chat_llm) -> ChatInput
     crew_description = generate_crew_description_with_ai(crew, chat_llm)

     return ChatInputs(
-        crew_name=crew_name, crew_description=crew_description, inputs=input_fields
+        crew_name=crew_name, crew_description=crew_description, inputs=input_fields,
     )


-def fetch_required_inputs(crew: Crew) -> Set[str]:
-    """
-    Extracts placeholders from the crew's tasks and agents.
+def fetch_required_inputs(crew: Crew) -> set[str]:
+    """Extracts placeholders from the crew's tasks and agents.

     Args:
         crew (Crew): The crew object.

     Returns:
         Set[str]: A set of placeholder names.
+
     """
     placeholder_pattern = re.compile(r"\{(.+?)\}")
-    required_inputs: Set[str] = set()
+    required_inputs: set[str] = set()

     # Scan tasks
     for task in crew.tasks:
@@ -422,8 +420,7 @@ def fetch_required_inputs(crew: Crew) -> Set[str]:


 def generate_input_description_with_ai(input_name: str, crew: Crew, chat_llm) -> str:
-    """
-    Generates an input description using AI based on the context of the crew.
+    """Generates an input description using AI based on the context of the crew.

     Args:
         input_name (str): The name of the input placeholder.
@@ -432,6 +429,7 @@ def generate_input_description_with_ai(input_name: str, crew: Crew, chat_llm) ->

     Returns:
         str: A concise description of the input.
+
     """
     # Gather context from tasks and agents where the input is used
     context_texts = []
@@ -444,10 +442,10 @@ def generate_input_description_with_ai(input_name: str, crew: Crew, chat_llm) ->
     ):
         # Replace placeholders with input names
         task_description = placeholder_pattern.sub(
|
||||
lambda m: m.group(1), task.description or ""
|
||||
lambda m: m.group(1), task.description or "",
|
||||
)
|
||||
expected_output = placeholder_pattern.sub(
|
||||
lambda m: m.group(1), task.expected_output or ""
|
||||
lambda m: m.group(1), task.expected_output or "",
|
||||
)
|
||||
context_texts.append(f"Task Description: {task_description}")
|
||||
context_texts.append(f"Expected Output: {expected_output}")
|
||||
@@ -461,7 +459,7 @@ def generate_input_description_with_ai(input_name: str, crew: Crew, chat_llm) ->
|
||||
agent_role = placeholder_pattern.sub(lambda m: m.group(1), agent.role or "")
|
||||
agent_goal = placeholder_pattern.sub(lambda m: m.group(1), agent.goal or "")
|
||||
agent_backstory = placeholder_pattern.sub(
|
||||
lambda m: m.group(1), agent.backstory or ""
|
||||
lambda m: m.group(1), agent.backstory or "",
|
||||
)
|
||||
context_texts.append(f"Agent Role: {agent_role}")
|
||||
context_texts.append(f"Agent Goal: {agent_goal}")
|
||||
@@ -470,7 +468,8 @@ def generate_input_description_with_ai(input_name: str, crew: Crew, chat_llm) ->
|
||||
context = "\n".join(context_texts)
|
||||
if not context:
|
||||
# If no context is found for the input, raise an exception as per instruction
|
||||
raise ValueError(f"No context found for input '{input_name}'.")
|
||||
msg = f"No context found for input '{input_name}'."
|
||||
raise ValueError(msg)
|
||||
|
||||
prompt = (
|
||||
f"Based on the following context, write a concise description (15 words or less) of the input '{input_name}'.\n"
|
||||
@@ -479,14 +478,12 @@ def generate_input_description_with_ai(input_name: str, crew: Crew, chat_llm) ->
|
||||
f"{context}"
|
||||
)
|
||||
response = chat_llm.call(messages=[{"role": "user", "content": prompt}])
|
||||
description = response.strip()
|
||||
return response.strip()
|
||||
|
||||
return description
|
||||
|
||||
|
||||
def generate_crew_description_with_ai(crew: Crew, chat_llm) -> str:
|
||||
"""
|
||||
Generates a brief description of the crew using AI.
|
||||
"""Generates a brief description of the crew using AI.
|
||||
|
||||
Args:
|
||||
crew (Crew): The crew object.
|
||||
@@ -494,6 +491,7 @@ def generate_crew_description_with_ai(crew: Crew, chat_llm) -> str:
|
||||
|
||||
Returns:
|
||||
str: A concise description of the crew's purpose (15 words or less).
|
||||
|
||||
"""
|
||||
# Gather context from tasks and agents
|
||||
context_texts = []
|
||||
@@ -502,10 +500,10 @@ def generate_crew_description_with_ai(crew: Crew, chat_llm) -> str:
|
||||
for task in crew.tasks:
|
||||
# Replace placeholders with input names
|
||||
task_description = placeholder_pattern.sub(
|
||||
lambda m: m.group(1), task.description or ""
|
||||
lambda m: m.group(1), task.description or "",
|
||||
)
|
||||
expected_output = placeholder_pattern.sub(
|
||||
lambda m: m.group(1), task.expected_output or ""
|
||||
lambda m: m.group(1), task.expected_output or "",
|
||||
)
|
||||
context_texts.append(f"Task Description: {task_description}")
|
||||
context_texts.append(f"Expected Output: {expected_output}")
|
||||
@@ -514,7 +512,7 @@ def generate_crew_description_with_ai(crew: Crew, chat_llm) -> str:
|
||||
agent_role = placeholder_pattern.sub(lambda m: m.group(1), agent.role or "")
|
||||
agent_goal = placeholder_pattern.sub(lambda m: m.group(1), agent.goal or "")
|
||||
agent_backstory = placeholder_pattern.sub(
|
||||
lambda m: m.group(1), agent.backstory or ""
|
||||
lambda m: m.group(1), agent.backstory or "",
|
||||
)
|
||||
context_texts.append(f"Agent Role: {agent_role}")
|
||||
context_texts.append(f"Agent Goal: {agent_goal}")
|
||||
@@ -522,7 +520,8 @@ def generate_crew_description_with_ai(crew: Crew, chat_llm) -> str:
|
||||
|
||||
context = "\n".join(context_texts)
|
||||
if not context:
|
||||
raise ValueError("No context found for generating crew description.")
|
||||
msg = "No context found for generating crew description."
|
||||
raise ValueError(msg)
|
||||
|
||||
prompt = (
|
||||
"Based on the following context, write a concise, action-oriented description (15 words or less) of the crew's purpose.\n"
|
||||
@@ -531,6 +530,5 @@ def generate_crew_description_with_ai(crew: Crew, chat_llm) -> str:
|
||||
f"{context}"
|
||||
)
|
||||
response = chat_llm.call(messages=[{"role": "user", "content": prompt}])
|
||||
crew_description = response.strip()
|
||||
return response.strip()
|
||||
|
||||
return crew_description
|
||||
|
||||
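For readers skimming this file's changes: the placeholder scan performed by fetch_required_inputs reduces to one non-greedy regex applied to every task and agent template. A minimal, self-contained sketch of that mechanism (the template strings are invented for illustration):

import re

placeholder_pattern = re.compile(r"\{(.+?)\}")  # same pattern as in the diff above

templates = [
    "Research {topic} and summarize it for {audience}.",
    "Write a short report on {topic}.",
]

required_inputs: set[str] = set()
for text in templates:
    required_inputs.update(placeholder_pattern.findall(text))

print(sorted(required_inputs))  # ['audience', 'topic']
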
@@ -1,4 +1,4 @@
-from typing import Any, Dict, List, Optional
+from typing import Any

 from rich.console import Console

@@ -10,34 +10,27 @@ console = Console()


 class DeployCommand(BaseCommand, PlusAPIMixin):
-    """
-    A class to handle deployment-related operations for CrewAI projects.
-    """
+    """A class to handle deployment-related operations for CrewAI projects."""

-    def __init__(self):
-        """
-        Initialize the DeployCommand with project name and API client.
-        """
+    def __init__(self) -> None:
+        """Initialize the DeployCommand with project name and API client."""
         BaseCommand.__init__(self)
         PlusAPIMixin.__init__(self, telemetry=self._telemetry)
         self.project_name = get_project_name(require=True)

     def _standard_no_param_error_message(self) -> None:
-        """
-        Display a standard error message when no UUID or project name is available.
-        """
+        """Display a standard error message when no UUID or project name is available."""
         console.print(
             "No UUID provided, project pyproject.toml not found or with error.",
             style="bold red",
         )

-    def _display_deployment_info(self, json_response: Dict[str, Any]) -> None:
-        """
-        Display deployment information.
+    def _display_deployment_info(self, json_response: dict[str, Any]) -> None:
+        """Display deployment information.

         Args:
             json_response (Dict[str, Any]): The deployment information to display.
+
         """
         console.print("Deploying the crew...\n", style="bold blue")
         for key, value in json_response.items():
@@ -47,24 +40,24 @@ class DeployCommand(BaseCommand, PlusAPIMixin):
         console.print(" or")
         console.print(f"crewai deploy status --uuid \"{json_response['uuid']}\"")

-    def _display_logs(self, log_messages: List[Dict[str, Any]]) -> None:
-        """
-        Display log messages.
+    def _display_logs(self, log_messages: list[dict[str, Any]]) -> None:
+        """Display log messages.

         Args:
             log_messages (List[Dict[str, Any]]): The log messages to display.
+
         """
         for log_message in log_messages:
             console.print(
-                f"{log_message['timestamp']} - {log_message['level']}: {log_message['message']}"
+                f"{log_message['timestamp']} - {log_message['level']}: {log_message['message']}",
             )

-    def deploy(self, uuid: Optional[str] = None) -> None:
-        """
-        Deploy a crew using either UUID or project name.
+    def deploy(self, uuid: str | None = None) -> None:
+        """Deploy a crew using either UUID or project name.

         Args:
             uuid (Optional[str]): The UUID of the crew to deploy.
+
         """
         self._start_deployment_span = self._telemetry.start_deployment_span(uuid)
         console.print("Starting deployment...", style="bold blue")
@@ -80,9 +73,7 @@ class DeployCommand(BaseCommand, PlusAPIMixin):
         self._display_deployment_info(response.json())

     def create_crew(self, confirm: bool = False) -> None:
-        """
-        Create a new crew deployment.
-        """
+        """Create a new crew deployment."""
         self._create_crew_deployment_span = (
             self._telemetry.create_crew_deployment_span()
         )
@@ -110,29 +101,28 @@ class DeployCommand(BaseCommand, PlusAPIMixin):
         self._display_creation_success(response.json())

     def _confirm_input(
-        self, env_vars: Dict[str, str], remote_repo_url: str, confirm: bool
+        self, env_vars: dict[str, str], remote_repo_url: str, confirm: bool,
     ) -> None:
-        """
-        Confirm input parameters with the user.
+        """Confirm input parameters with the user.

         Args:
             env_vars (Dict[str, str]): Environment variables.
             remote_repo_url (str): Remote repository URL.
             confirm (bool): Whether to confirm input.
+
         """
         if not confirm:
             input(f"Press Enter to continue with the following Env vars: {env_vars}")
             input(
-                f"Press Enter to continue with the following remote repository: {remote_repo_url}\n"
+                f"Press Enter to continue with the following remote repository: {remote_repo_url}\n",
             )

     def _create_payload(
         self,
-        env_vars: Dict[str, str],
+        env_vars: dict[str, str],
         remote_repo_url: str,
-    ) -> Dict[str, Any]:
-        """
-        Create the payload for crew creation.
+    ) -> dict[str, Any]:
+        """Create the payload for crew creation.

         Args:
             remote_repo_url (str): Remote repository URL.
@@ -140,25 +130,26 @@ class DeployCommand(BaseCommand, PlusAPIMixin):

         Returns:
             Dict[str, Any]: The payload for crew creation.
+
         """
         return {
             "deploy": {
                 "name": self.project_name,
                 "repo_clone_url": remote_repo_url,
                 "env": env_vars,
-            }
+            },
         }

-    def _display_creation_success(self, json_response: Dict[str, Any]) -> None:
-        """
-        Display success message after crew creation.
+    def _display_creation_success(self, json_response: dict[str, Any]) -> None:
+        """Display success message after crew creation.

         Args:
             json_response (Dict[str, Any]): The response containing crew information.
+
         """
         console.print("Deployment created successfully!\n", style="bold green")
         console.print(
-            f"Name: {self.project_name} ({json_response['uuid']})", style="bold green"
+            f"Name: {self.project_name} ({json_response['uuid']})", style="bold green",
         )
         console.print(f"Status: {json_response['status']}", style="bold green")
         console.print("\nTo (re)deploy the crew, run:")
@@ -167,9 +158,7 @@ class DeployCommand(BaseCommand, PlusAPIMixin):
         console.print(f"crewai deploy push --uuid {json_response['uuid']}")

     def list_crews(self) -> None:
-        """
-        List all available crews.
-        """
+        """List all available crews."""
         console.print("Listing all Crews\n", style="bold blue")

         response = self.plus_api_client.list_crews()
@@ -179,31 +168,29 @@ class DeployCommand(BaseCommand, PlusAPIMixin):
         else:
             self._display_no_crews_message()

-    def _display_crews(self, crews_data: List[Dict[str, Any]]) -> None:
-        """
-        Display the list of crews.
+    def _display_crews(self, crews_data: list[dict[str, Any]]) -> None:
+        """Display the list of crews.

         Args:
             crews_data (List[Dict[str, Any]]): List of crew data to display.
+
         """
         for crew_data in crews_data:
             console.print(
-                f"- {crew_data['name']} ({crew_data['uuid']}) [blue]{crew_data['status']}[/blue]"
+                f"- {crew_data['name']} ({crew_data['uuid']}) [blue]{crew_data['status']}[/blue]",
            )

     def _display_no_crews_message(self) -> None:
-        """
-        Display a message when no crews are available.
-        """
+        """Display a message when no crews are available."""
         console.print("You don't have any Crews yet. Let's create one!", style="yellow")
         console.print(" crewai create crew <crew_name>", style="green")

-    def get_crew_status(self, uuid: Optional[str] = None) -> None:
-        """
-        Get the status of a crew.
+    def get_crew_status(self, uuid: str | None = None) -> None:
+        """Get the status of a crew.

         Args:
             uuid (Optional[str]): The UUID of the crew to check.
+
         """
         console.print("Fetching deployment status...", style="bold blue")
         if uuid:
@@ -217,23 +204,23 @@ class DeployCommand(BaseCommand, PlusAPIMixin):
         self._validate_response(response)
         self._display_crew_status(response.json())

-    def _display_crew_status(self, status_data: Dict[str, str]) -> None:
-        """
-        Display the status of a crew.
+    def _display_crew_status(self, status_data: dict[str, str]) -> None:
+        """Display the status of a crew.

         Args:
             status_data (Dict[str, str]): The status data to display.
+
         """
         console.print(f"Name:\t {status_data['name']}")
         console.print(f"Status:\t {status_data['status']}")

-    def get_crew_logs(self, uuid: Optional[str], log_type: str = "deployment") -> None:
-        """
-        Get logs for a crew.
+    def get_crew_logs(self, uuid: str | None, log_type: str = "deployment") -> None:
+        """Get logs for a crew.

         Args:
             uuid (Optional[str]): The UUID of the crew to get logs for.
             log_type (str): The type of logs to retrieve (default: "deployment").
+
         """
         self._get_crew_logs_span = self._telemetry.get_crew_logs_span(uuid, log_type)
         console.print(f"Fetching {log_type} logs...", style="bold blue")
@@ -249,12 +236,12 @@ class DeployCommand(BaseCommand, PlusAPIMixin):
         self._validate_response(response)
         self._display_logs(response.json())

-    def remove_crew(self, uuid: Optional[str]) -> None:
-        """
-        Remove a crew deployment.
+    def remove_crew(self, uuid: str | None) -> None:
+        """Remove a crew deployment.

         Args:
             uuid (Optional[str]): The UUID of the crew to remove.
+
         """
         self._remove_crew_span = self._telemetry.remove_crew_span(uuid)
         console.print("Removing deployment...", style="bold blue")
@@ -269,9 +256,9 @@ class DeployCommand(BaseCommand, PlusAPIMixin):

         if response.status_code == 204:
             console.print(
-                f"Crew '{self.project_name}' removed successfully.", style="green"
+                f"Crew '{self.project_name}' removed successfully.", style="green",
             )
         else:
             console.print(
-                f"Failed to remove crew '{self.project_name}'", style="bold red"
+                f"Failed to remove crew '{self.project_name}'", style="bold red",
             )
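The bulk of this file's churn is the PEP 585/604 typing migration: Dict, List, and Optional from typing give way to built-in generics and X | None. A quick sketch showing the two spellings are equivalent at runtime on Python 3.10+ (the function names are illustrative):

from typing import Dict, Optional, get_type_hints


def legacy(uuid: Optional[str] = None) -> Dict[str, str]:
    return {} if uuid is None else {"uuid": uuid}


def modern(uuid: str | None = None) -> dict[str, str]:
    return {} if uuid is None else {"uuid": uuid}


# Both annotations resolve to the same hint objects, so the rewrite is cosmetic.
assert get_type_hints(legacy) == get_type_hints(modern)
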
@@ -4,18 +4,19 @@ import click


 def evaluate_crew(n_iterations: int, model: str) -> None:
-    """
-    Test and Evaluate the crew by running a command in the UV environment.
+    """Test and Evaluate the crew by running a command in the UV environment.

     Args:
         n_iterations (int): The number of iterations to test the crew.
         model (str): The model to test the crew with.
+
     """
     command = ["uv", "run", "test", str(n_iterations), model]

     try:
         if n_iterations <= 0:
-            raise ValueError("The number of iterations must be a positive integer.")
+            msg = "The number of iterations must be a positive integer."
+            raise ValueError(msg)

         result = subprocess.run(command, capture_output=False, text=True, check=True)
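The msg-then-raise rewrite recurs throughout this compare; it is the fix that ruff's flake8-errmsg (EM) rules suggest, so the exception text is bound to a name instead of living inline in the raise statement. A minimal sketch of the pattern:

def check_iterations(n_iterations: int) -> None:
    # Binding the message first keeps the raise line short and the traceback clean.
    if n_iterations <= 0:
        msg = "The number of iterations must be a positive integer."
        raise ValueError(msg)


check_iterations(3)  # passes silently; check_iterations(0) raises ValueError
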
@@ -1,16 +1,18 @@
 import subprocess
-from functools import lru_cache
+from functools import cache


 class Repository:
-    def __init__(self, path="."):
+    def __init__(self, path=".") -> None:
         self.path = path

         if not self.is_git_installed():
-            raise ValueError("Git is not installed or not found in your PATH.")
+            msg = "Git is not installed or not found in your PATH."
+            raise ValueError(msg)

         if not self.is_git_repo():
-            raise ValueError(f"{self.path} is not a Git repository.")
+            msg = f"{self.path} is not a Git repository."
+            raise ValueError(msg)

         self.fetch()

@@ -18,7 +20,7 @@ class Repository:
         """Check if Git is installed and available in the system."""
         try:
             subprocess.run(
-                ["git", "--version"], capture_output=True, check=True, text=True
+                ["git", "--version"], capture_output=True, check=True, text=True,
             )
             return True
         except (subprocess.CalledProcessError, FileNotFoundError):
@@ -36,7 +38,7 @@ class Repository:
             encoding="utf-8",
         ).strip()

-    @lru_cache(maxsize=None)
+    @cache
     def is_git_repo(self) -> bool:
         """Check if the current directory is a git repository."""
         try:
@@ -62,10 +64,7 @@ class Repository:

     def is_synced(self) -> bool:
         """Return True if the Git repository is fully synced with the remote, False otherwise."""
-        if self.has_uncommitted_changes() or self.is_ahead_or_behind():
-            return False
-        else:
-            return True
+        return not (self.has_uncommitted_changes() or self.is_ahead_or_behind())

     def origin_url(self) -> str | None:
         """Get the Git repository's remote URL."""
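functools.cache, swapped in above for lru_cache(maxsize=None), is defined as exactly that: an unbounded memoizer without LRU bookkeeping. A runnable equivalence check:

from functools import cache, lru_cache


@lru_cache(maxsize=None)
def fib_old(n: int) -> int:
    return n if n < 2 else fib_old(n - 1) + fib_old(n - 2)


@cache  # Python 3.9+; identical behavior, less overhead per hit
def fib_new(n: int) -> int:
    return n if n < 2 else fib_new(n - 1) + fib_new(n - 2)


assert fib_old(30) == fib_new(30) == 832040

One caveat when decorating methods such as is_git_repo: the cache keys include self, so cached entries keep Repository instances alive for the life of the process.
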
@@ -8,11 +8,9 @@ import click
 # so if you expect this to support more things you will need to replicate it there
 # ask @joaomdmoura if you are unsure
 def install_crew(proxy_options: list[str]) -> None:
-    """
-    Install the crew by running the UV command to lock and install.
-    """
+    """Install the crew by running the UV command to lock and install."""
     try:
-        command = ["uv", "sync"] + proxy_options
+        command = ["uv", "sync", *proxy_options]
         subprocess.run(command, check=True, capture_output=False, text=True)

     except subprocess.CalledProcessError as e:
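The ["uv", "sync"] + proxy_options change is ruff's RUF005 rewrite: unpack the iterable into the list display instead of concatenating. Both forms build the same list (the proxy flags below are made up):

proxy_options = ["--index-url", "https://proxy.example.com/simple"]  # illustrative

old_command = ["uv", "sync"] + proxy_options   # allocates an intermediate list
new_command = ["uv", "sync", *proxy_options]   # builds the result in one pass

assert old_command == new_command
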
@@ -4,9 +4,7 @@ import click


 def kickoff_flow() -> None:
-    """
-    Kickoff the flow by running a command in the UV environment.
-    """
+    """Kickoff the flow by running a command in the UV environment."""
     command = ["uv", "run", "kickoff"]

     try:

@@ -4,9 +4,7 @@ import click


 def plot_flow() -> None:
-    """
-    Plot the flow by running a command in the UV environment.
-    """
+    """Plot the flow by running a command in the UV environment."""
     command = ["uv", "run", "plot"]

     try:
@@ -1,5 +1,4 @@
 from os import getenv
-from typing import Optional
 from urllib.parse import urljoin

 import requests
@@ -8,9 +7,7 @@ from crewai.cli.version import get_crewai_version


 class PlusAPI:
-    """
-    This class exposes methods for working with the CrewAI+ API.
-    """
+    """This class exposes methods for working with the CrewAI+ API."""

     TOOLS_RESOURCE = "/crewai_plus/api/v1/tools"
     CREWS_RESOURCE = "/crewai_plus/api/v1/crews"
@@ -42,7 +39,7 @@ class PlusAPI:
         handle: str,
         is_public: bool,
         version: str,
-        description: Optional[str],
+        description: str | None,
         encoded_file: str,
     ):
         params = {
@@ -56,7 +53,7 @@ class PlusAPI:

     def deploy_by_name(self, project_name: str) -> requests.Response:
         return self._make_request(
-            "POST", f"{self.CREWS_RESOURCE}/by-name/{project_name}/deploy"
+            "POST", f"{self.CREWS_RESOURCE}/by-name/{project_name}/deploy",
         )

     def deploy_by_uuid(self, uuid: str) -> requests.Response:
@@ -64,29 +61,29 @@ class PlusAPI:

     def crew_status_by_name(self, project_name: str) -> requests.Response:
         return self._make_request(
-            "GET", f"{self.CREWS_RESOURCE}/by-name/{project_name}/status"
+            "GET", f"{self.CREWS_RESOURCE}/by-name/{project_name}/status",
         )

     def crew_status_by_uuid(self, uuid: str) -> requests.Response:
         return self._make_request("GET", f"{self.CREWS_RESOURCE}/{uuid}/status")

     def crew_by_name(
-        self, project_name: str, log_type: str = "deployment"
+        self, project_name: str, log_type: str = "deployment",
     ) -> requests.Response:
         return self._make_request(
-            "GET", f"{self.CREWS_RESOURCE}/by-name/{project_name}/logs/{log_type}"
+            "GET", f"{self.CREWS_RESOURCE}/by-name/{project_name}/logs/{log_type}",
         )

     def crew_by_uuid(
-        self, uuid: str, log_type: str = "deployment"
+        self, uuid: str, log_type: str = "deployment",
     ) -> requests.Response:
         return self._make_request(
-            "GET", f"{self.CREWS_RESOURCE}/{uuid}/logs/{log_type}"
+            "GET", f"{self.CREWS_RESOURCE}/{uuid}/logs/{log_type}",
         )

     def delete_crew_by_name(self, project_name: str) -> requests.Response:
         return self._make_request(
-            "DELETE", f"{self.CREWS_RESOURCE}/by-name/{project_name}"
+            "DELETE", f"{self.CREWS_RESOURCE}/by-name/{project_name}",
         )

     def delete_crew_by_uuid(self, uuid: str) -> requests.Response:
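The _make_request helper these methods delegate to is not shown in this hunk, so here is a hypothetical minimal analogue to make the call pattern concrete; the base URL, token handling, and timeout are all invented for illustration and are not the library's actual values:

from urllib.parse import urljoin

import requests

BASE_URL = "https://app.example.com"  # illustrative, not the real endpoint


def make_request(method: str, resource: str, token: str = "dummy") -> requests.Response:
    # One helper centralizes URL joining and auth headers for every resource call.
    url = urljoin(BASE_URL, resource)
    headers = {"Authorization": f"Bearer {token}"}
    return requests.request(method, url, headers=headers, timeout=30)
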
@@ -10,8 +10,7 @@ from crewai.cli.constants import JSON_URL, MODELS, PROVIDERS


 def select_choice(prompt_message, choices):
-    """
-    Presents a list of choices to the user and prompts them to select one.
+    """Presents a list of choices to the user and prompts them to select one.

     Args:
         - prompt_message (str): The message to display to the user before presenting the choices.
@@ -19,11 +18,11 @@ def select_choice(prompt_message, choices):

     Returns:
         - str: The selected choice from the list, or None if the user chooses to quit.
-    """
+
+    """
     provider_models = get_provider_data()
     if not provider_models:
-        return
+        return None
     click.secho(prompt_message, fg="cyan")
     for idx, choice in enumerate(choices, start=1):
         click.secho(f"{idx}. {choice}", fg="cyan")
@@ -31,7 +30,7 @@ def select_choice(prompt_message, choices):

     while True:
         choice = click.prompt(
-            "Enter the number of your choice or 'q' to quit", type=str
+            "Enter the number of your choice or 'q' to quit", type=str,
         )

         if choice.lower() == "q":
@@ -51,8 +50,7 @@ def select_choice(prompt_message, choices):


 def select_provider(provider_models):
-    """
-    Presents a list of providers to the user and prompts them to select one.
+    """Presents a list of providers to the user and prompts them to select one.

     Args:
         - provider_models (dict): A dictionary of provider models.
@@ -60,12 +58,13 @@ def select_provider(provider_models):
     Returns:
         - str: The selected provider
         - None: If user explicitly quits
+
     """
     predefined_providers = [p.lower() for p in PROVIDERS]
     all_providers = sorted(set(predefined_providers + list(provider_models.keys())))

     provider = select_choice(
-        "Select a provider to set up:", predefined_providers + ["other"]
+        "Select a provider to set up:", [*predefined_providers, "other"],
     )
     if provider is None:  # User typed 'q'
         return None
@@ -79,8 +78,7 @@ def select_provider(provider_models):


 def select_model(provider, provider_models):
-    """
-    Presents a list of models for a given provider to the user and prompts them to select one.
+    """Presents a list of models for a given provider to the user and prompts them to select one.

     Args:
         - provider (str): The provider for which to select a model.
@@ -88,6 +86,7 @@ def select_model(provider, provider_models):

     Returns:
         - str: The selected model, or None if the operation is aborted or an invalid selection is made.
+
     """
     predefined_providers = [p.lower() for p in PROVIDERS]

@@ -100,15 +99,13 @@ def select_model(provider, provider_models):
         click.secho(f"No models available for provider '{provider}'.", fg="red")
         return None

-    selected_model = select_choice(
-        f"Select a model to use for {provider.capitalize()}:", available_models
+    return select_choice(
+        f"Select a model to use for {provider.capitalize()}:", available_models,
     )
-    return selected_model


 def load_provider_data(cache_file, cache_expiry):
-    """
-    Loads provider data from a cache file if it exists and is not expired. If the cache is expired or corrupted, it fetches the data from the web.
+    """Loads provider data from a cache file if it exists and is not expired. If the cache is expired or corrupted, it fetches the data from the web.

     Args:
         - cache_file (Path): The path to the cache file.
@@ -116,6 +113,7 @@ def load_provider_data(cache_file, cache_expiry):

     Returns:
         - dict or None: The loaded provider data or None if the operation fails.
+
     """
     current_time = time.time()
     if (
@@ -126,7 +124,7 @@ def load_provider_data(cache_file, cache_expiry):
         if data:
             return data
         click.secho(
-            "Cache is corrupted. Fetching provider data from the web...", fg="yellow"
+            "Cache is corrupted. Fetching provider data from the web...", fg="yellow",
         )
     else:
         click.secho(
@@ -137,31 +135,31 @@ def load_provider_data(cache_file, cache_expiry):


 def read_cache_file(cache_file):
-    """
-    Reads and returns the JSON content from a cache file. Returns None if the file contains invalid JSON.
+    """Reads and returns the JSON content from a cache file. Returns None if the file contains invalid JSON.

     Args:
         - cache_file (Path): The path to the cache file.

     Returns:
         - dict or None: The JSON content of the cache file or None if the JSON is invalid.
+
     """
     try:
-        with open(cache_file, "r") as f:
+        with open(cache_file) as f:
             return json.load(f)
     except json.JSONDecodeError:
         return None


 def fetch_provider_data(cache_file):
-    """
-    Fetches provider data from a specified URL and caches it to a file.
+    """Fetches provider data from a specified URL and caches it to a file.

     Args:
         - cache_file (Path): The path to the cache file.

     Returns:
         - dict or None: The fetched provider data or None if the operation fails.
+
     """
     try:
         response = requests.get(JSON_URL, stream=True, timeout=60)
@@ -178,20 +176,20 @@ def fetch_provider_data(cache_file):


 def download_data(response):
-    """
-    Downloads data from a given HTTP response and returns the JSON content.
+    """Downloads data from a given HTTP response and returns the JSON content.

     Args:
         - response (requests.Response): The HTTP response object.

     Returns:
         - dict: The JSON content of the response.
+
     """
     total_size = int(response.headers.get("content-length", 0))
     block_size = 8192
     data_chunks = []
     with click.progressbar(
-        length=total_size, label="Downloading", show_pos=True
+        length=total_size, label="Downloading", show_pos=True,
     ) as progress_bar:
         for chunk in response.iter_content(block_size):
             if chunk:
@@ -202,11 +200,11 @@ def download_data(response):


 def get_provider_data():
-    """
-    Retrieves provider data from a cache file, filters out models based on provider criteria, and returns a dictionary of providers mapped to their models.
+    """Retrieves provider data from a cache file, filters out models based on provider criteria, and returns a dictionary of providers mapped to their models.

     Returns:
         - dict or None: A dictionary of providers mapped to their models or None if the operation fails.
+
     """
     cache_dir = Path.home() / ".crewai"
     cache_dir.mkdir(exist_ok=True)
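Condensed, the cache policy in load_provider_data is: serve the cached JSON while it is fresh, treat corrupt JSON as a miss, and let the caller refetch otherwise. A minimal sketch under those assumptions (the 24-hour expiry is an illustrative value, not necessarily the CLI's actual setting):

import json
import time
from pathlib import Path


def load_cached(cache_file: Path, cache_expiry: float = 24 * 3600) -> dict | None:
    # Fresh cache: the file exists and was modified within the expiry window.
    if cache_file.exists() and time.time() - cache_file.stat().st_mtime < cache_expiry:
        try:
            return json.loads(cache_file.read_text())
        except json.JSONDecodeError:
            return None  # corrupted cache: caller falls back to fetching
    return None
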
@@ -4,11 +4,11 @@ import click


 def replay_task_command(task_id: str) -> None:
-    """
-    Replay the crew execution from a specific task.
+    """Replay the crew execution from a specific task.

     Args:
         task_id (str): The ID of the task to replay from.
+
     """
     command = ["uv", "run", "replay", task_id]

@@ -13,8 +13,7 @@ def reset_memories_command(
     kickoff_outputs,
     all,
 ) -> None:
-    """
-    Reset the crew memories.
+    """Reset the crew memories.

     Args:
         long (bool): Whether to reset the long-term memory.
@@ -23,49 +22,50 @@ def reset_memories_command(
         kickoff_outputs (bool): Whether to reset the latest kickoff task outputs.
         all (bool): Whether to reset all memories.
         knowledge (bool): Whether to reset the knowledge.
-    """
+
+    """
     try:
         if not any([long, short, entity, kickoff_outputs, knowledge, all]):
             click.echo(
-                "No memory type specified. Please specify at least one type to reset."
+                "No memory type specified. Please specify at least one type to reset.",
             )
             return

         crews = get_crews()
         if not crews:
-            raise ValueError("No crew found.")
+            msg = "No crew found."
+            raise ValueError(msg)
         for crew in crews:
             if all:
                 crew.reset_memories(command_type="all")
                 click.echo(
-                    f"[Crew ({crew.name if crew.name else crew.id})] Reset memories command has been completed."
+                    f"[Crew ({crew.name if crew.name else crew.id})] Reset memories command has been completed.",
                 )
                 continue
             if long:
                 crew.reset_memories(command_type="long")
                 click.echo(
-                    f"[Crew ({crew.name if crew.name else crew.id})] Long term memory has been reset."
+                    f"[Crew ({crew.name if crew.name else crew.id})] Long term memory has been reset.",
                 )
             if short:
                 crew.reset_memories(command_type="short")
                 click.echo(
-                    f"[Crew ({crew.name if crew.name else crew.id})] Short term memory has been reset."
+                    f"[Crew ({crew.name if crew.name else crew.id})] Short term memory has been reset.",
                 )
             if entity:
                 crew.reset_memories(command_type="entity")
                 click.echo(
-                    f"[Crew ({crew.name if crew.name else crew.id})] Entity memory has been reset."
+                    f"[Crew ({crew.name if crew.name else crew.id})] Entity memory has been reset.",
                 )
             if kickoff_outputs:
                 crew.reset_memories(command_type="kickoff_outputs")
                 click.echo(
-                    f"[Crew ({crew.name if crew.name else crew.id})] Latest Kickoff outputs stored has been reset."
+                    f"[Crew ({crew.name if crew.name else crew.id})] Latest Kickoff outputs stored has been reset.",
                 )
             if knowledge:
                 crew.reset_memories(command_type="knowledge")
                 click.echo(
-                    f"[Crew ({crew.name if crew.name else crew.id})] Knowledge has been reset."
+                    f"[Crew ({crew.name if crew.name else crew.id})] Knowledge has been reset.",
                 )

     except subprocess.CalledProcessError as e:

@@ -1,6 +1,5 @@
 import subprocess
 from enum import Enum
-from typing import List, Optional

 import click
 from packaging import version
@@ -15,8 +14,7 @@ class CrewType(Enum):


 def run_crew() -> None:
-    """
-    Run the crew or flow by running a command in the UV environment.
+    """Run the crew or flow by running a command in the UV environment.

     Starting from version 0.103.0, this command can be used to run both
     standard crews and flows. For flows, it detects the type from pyproject.toml
@@ -48,11 +46,11 @@ def run_crew() -> None:


 def execute_command(crew_type: CrewType) -> None:
-    """
-    Execute the appropriate command based on crew type.
+    """Execute the appropriate command based on crew type.

     Args:
         crew_type: The type of crew to run
+
     """
     command = ["uv", "run", "kickoff" if crew_type == CrewType.FLOW else "run_crew"]

@@ -67,12 +65,12 @@ def execute_command(crew_type: CrewType) -> None:


 def handle_error(error: subprocess.CalledProcessError, crew_type: CrewType) -> None:
-    """
-    Handle subprocess errors with appropriate messaging.
+    """Handle subprocess errors with appropriate messaging.

     Args:
         error: The subprocess error that occurred
         crew_type: The type of crew that was being run
+
     """
     entity_type = "flow" if crew_type == CrewType.FLOW else "crew"
     click.echo(f"An error occurred while running the {entity_type}: {error}", err=True)
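The crew/flow dispatch above boils down to picking a uv script name off an Enum. A runnable sketch of the same shape, with echo standing in for uv so it executes anywhere a shell echo exists:

import subprocess
from enum import Enum


class CrewType(Enum):
    STANDARD = "standard"
    FLOW = "flow"


def execute_command(crew_type: CrewType) -> None:
    # Flows kick off via "kickoff"; standard crews via "run_crew".
    command = ["echo", "kickoff" if crew_type == CrewType.FLOW else "run_crew"]
    subprocess.run(command, check=True)


execute_command(CrewType.FLOW)  # prints: kickoff
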
@@ -22,15 +22,13 @@ console = Console()


 class ToolCommand(BaseCommand, PlusAPIMixin):
-    """
-    A class to handle tool repository related operations for CrewAI projects.
-    """
+    """A class to handle tool repository related operations for CrewAI projects."""

-    def __init__(self):
+    def __init__(self) -> None:
         BaseCommand.__init__(self)
         PlusAPIMixin.__init__(self, telemetry=self._telemetry)

-    def create(self, handle: str):
+    def create(self, handle: str) -> None:
         self._ensure_not_in_project()

         folder_name = handle.replace(" ", "_").replace("-", "_").lower()
@@ -40,8 +38,7 @@ class ToolCommand(BaseCommand, PlusAPIMixin):
         if project_root.exists():
             click.secho(f"Folder {folder_name} already exists.", fg="red")
             raise SystemExit
-        else:
-            os.makedirs(project_root)
+        os.makedirs(project_root)

         click.secho(f"Creating custom tool {folder_name}...", fg="green", bold=True)

@@ -56,12 +53,12 @@ class ToolCommand(BaseCommand, PlusAPIMixin):
             self.login()
             subprocess.run(["git", "init"], check=True)
             console.print(
-                f"[green]Created custom tool [bold]{folder_name}[/bold]. Run [bold]cd {project_root}[/bold] to start working.[/green]"
+                f"[green]Created custom tool [bold]{folder_name}[/bold]. Run [bold]cd {project_root}[/bold] to start working.[/green]",
             )
         finally:
             os.chdir(old_directory)

-    def publish(self, is_public: bool, force: bool = False):
+    def publish(self, is_public: bool, force: bool = False) -> None:
         if not git.Repository().is_synced() and not force:
             console.print(
                 "[bold red]Failed to publish tool.[/bold red]\n"
@@ -69,9 +66,9 @@ class ToolCommand(BaseCommand, PlusAPIMixin):
                 "* [bold]Commit[/bold] your changes.\n"
                 "* [bold]Push[/bold] to sync with the remote.\n"
                 "* [bold]Pull[/bold] the latest changes from the remote.\n"
-                "\nOnce your repository is up-to-date, retry publishing the tool."
+                "\nOnce your repository is up-to-date, retry publishing the tool.",
             )
-            raise SystemExit()
+            raise SystemExit

         project_name = get_project_name(require=True)
         assert isinstance(project_name, str)
@@ -90,7 +87,7 @@ class ToolCommand(BaseCommand, PlusAPIMixin):
         )

         tarball_filename = next(
-            (f for f in os.listdir(temp_build_dir) if f.endswith(".tar.gz")), None
+            (f for f in os.listdir(temp_build_dir) if f.endswith(".tar.gz")), None,
         )
         if not tarball_filename:
             console.print(
@@ -123,7 +120,7 @@ class ToolCommand(BaseCommand, PlusAPIMixin):
             style="bold green",
         )

-    def install(self, handle: str):
+    def install(self, handle: str) -> None:
         get_response = self.plus_api_client.get_tool(handle)

         if get_response.status_code == 404:
@@ -132,9 +129,9 @@ class ToolCommand(BaseCommand, PlusAPIMixin):
                 style="bold red",
             )
             raise SystemExit
-        elif get_response.status_code != 200:
+        if get_response.status_code != 200:
             console.print(
-                "Failed to get tool details. Please try again later.", style="bold red"
+                "Failed to get tool details. Please try again later.", style="bold red",
             )
             raise SystemExit

@@ -142,7 +139,7 @@ class ToolCommand(BaseCommand, PlusAPIMixin):

         console.print(f"Successfully installed {handle}", style="bold green")

-    def login(self):
+    def login(self) -> None:
         login_response = self.plus_api_client.login_to_tool_repository()

         if login_response.status_code != 200:
@@ -164,10 +161,10 @@ class ToolCommand(BaseCommand, PlusAPIMixin):
         settings.dump()

         console.print(
-            "Successfully authenticated to the tool repository.", style="bold green"
+            "Successfully authenticated to the tool repository.", style="bold green",
         )

-    def _add_package(self, tool_details):
+    def _add_package(self, tool_details) -> None:
         tool_handle = tool_details["handle"]
         repository_handle = tool_details["repository"]["handle"]
         repository_url = tool_details["repository"]["url"]
@@ -192,16 +189,16 @@ class ToolCommand(BaseCommand, PlusAPIMixin):
             click.echo(add_package_result.stderr, err=True)
             raise SystemExit

-    def _ensure_not_in_project(self):
+    def _ensure_not_in_project(self) -> None:
         if os.path.isfile("./pyproject.toml"):
             console.print(
-                "[bold red]Oops! It looks like you're inside a project.[/bold red]"
+                "[bold red]Oops! It looks like you're inside a project.[/bold red]",
             )
             console.print(
-                "You can't create a new tool while inside an existing project."
+                "You can't create a new tool while inside an existing project.",
             )
             console.print(
-                "[bold yellow]Tip:[/bold yellow] Navigate to a different directory and try again."
+                "[bold yellow]Tip:[/bold yellow] Navigate to a different directory and try again.",
             )
             raise SystemExit

@@ -211,10 +208,10 @@ class ToolCommand(BaseCommand, PlusAPIMixin):

         env = os.environ.copy()
         env[f"UV_INDEX_{repository_handle}_USERNAME"] = str(
-            settings.tool_repository_username or ""
+            settings.tool_repository_username or "",
         )
         env[f"UV_INDEX_{repository_handle}_PASSWORD"] = str(
-            settings.tool_repository_password or ""
+            settings.tool_repository_password or "",
         )

         return env
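The environment plumbing at the end of this file relies on uv's convention of reading per-index credentials from UV_INDEX_<NAME>_USERNAME / UV_INDEX_<NAME>_PASSWORD. A sketch of the same injection with placeholder values:

import os

repository_handle = "EXAMPLE"  # illustrative index name, not a real handle
env = os.environ.copy()
env[f"UV_INDEX_{repository_handle}_USERNAME"] = "user"  # placeholder credential
env[f"UV_INDEX_{repository_handle}_PASSWORD"] = "not-a-real-secret"  # placeholder
# `env` would then be passed as subprocess.run([...], env=env) when invoking uv,
# so the credentials never touch the command line or shell history.
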
@@ -4,20 +4,22 @@ import click


 def train_crew(n_iterations: int, filename: str) -> None:
-    """
-    Train the crew by running a command in the UV environment.
+    """Train the crew by running a command in the UV environment.

     Args:
         n_iterations (int): The number of iterations to train the crew.
+
     """
     command = ["uv", "run", "train", str(n_iterations), filename]

     try:
         if n_iterations <= 0:
-            raise ValueError("The number of iterations must be a positive integer.")
+            msg = "The number of iterations must be a positive integer."
+            raise ValueError(msg)

         if not filename.endswith(".pkl"):
-            raise ValueError("The filename must not end with .pkl")
+            msg = "The filename must not end with .pkl"
+            raise ValueError(msg)

         result = subprocess.run(command, capture_output=False, text=True, check=True)

@@ -11,9 +11,8 @@ def update_crew() -> None:
     migrate_pyproject("pyproject.toml", "pyproject.toml")


-def migrate_pyproject(input_file, output_file):
-    """
-    Migrate the pyproject.toml to the new format.
+def migrate_pyproject(input_file, output_file) -> None:
+    """Migrate the pyproject.toml to the new format.

     This function is used to migrate the pyproject.toml to the new format.
     And it will be used to migrate the pyproject.toml to the new format when uv is used.
@@ -81,7 +80,7 @@ def migrate_pyproject(input_file, output_file):
         # Extract the module name from any existing script
         existing_scripts = new_pyproject["project"]["scripts"]
         module_name = next(
-            (value.split(".")[0] for value in existing_scripts.values() if "." in value)
+            value.split(".")[0] for value in existing_scripts.values() if "." in value
         )

         new_pyproject["project"]["scripts"]["run_crew"] = f"{module_name}.main:run"
@@ -93,22 +92,19 @@ def migrate_pyproject(input_file, output_file):
     # Backup the old pyproject.toml
     backup_file = "pyproject-old.toml"
     shutil.copy2(input_file, backup_file)
-    print(f"Original pyproject.toml backed up as {backup_file}")

     # Rename the poetry.lock file
     lock_file = "poetry.lock"
     lock_backup = "poetry-old.lock"
     if os.path.exists(lock_file):
         os.rename(lock_file, lock_backup)
-        print(f"Original poetry.lock renamed to {lock_backup}")
     else:
-        print("No poetry.lock file found to rename.")
+        pass

     # Write the new pyproject.toml
     with open(output_file, "wb") as f:
         tomli_w.dump(new_pyproject, f)

-    print(f"Migration complete. New pyproject.toml written to {output_file}")


 def parse_version(version: str) -> str:
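migrate_pyproject is a read-transform-write cycle over TOML. A minimal sketch of the same cycle, assuming the tomli_w writer used in the diff (on Python 3.11+ the reader is the stdlib tomllib; the tomli imported elsewhere in this compare is its backport):

import tomllib  # stdlib on Python 3.11+; `tomli` offers the same API on older versions

import tomli_w

raw = '[project]\nname = "demo"\n\n[project.scripts]\ndemo = "demo.main:run"\n'
data = tomllib.loads(raw)

# Same module-name extraction trick as the diff: take the prefix of any
# existing "module.attr" script entry.
module_name = next(
    value.split(".")[0] for value in data["project"]["scripts"].values() if "." in value
)
data["project"]["scripts"]["run_crew"] = f"{module_name}.main:run"

print(tomli_w.dumps(data))
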
@@ -3,7 +3,7 @@ import shutil
 import sys
 from functools import reduce
 from inspect import isfunction, ismethod
-from typing import Any, Dict, List, get_type_hints
+from typing import Any, get_type_hints

 import click
 import tomli
@@ -19,9 +19,9 @@ if sys.version_info >= (3, 11):
 console = Console()


-def copy_template(src, dst, name, class_name, folder_name):
+def copy_template(src, dst, name, class_name, folder_name) -> None:
     """Copy a file from src to dst."""
-    with open(src, "r") as file:
+    with open(src) as file:
         content = file.read()

     # Interpolate the content
@@ -39,8 +39,7 @@ def copy_template(src, dst, name, class_name, folder_name):
 def read_toml(file_path: str = "pyproject.toml"):
     """Read the content of a TOML file and return it as a dictionary."""
     with open(file_path, "rb") as f:
-        toml_dict = tomli.load(f)
-    return toml_dict
+        return tomli.load(f)


 def parse_toml(content):
@@ -50,59 +49,56 @@ def parse_toml(content):


 def get_project_name(
-    pyproject_path: str = "pyproject.toml", require: bool = False
+    pyproject_path: str = "pyproject.toml", require: bool = False,
 ) -> str | None:
     """Get the project name from the pyproject.toml file."""
     return _get_project_attribute(pyproject_path, ["project", "name"], require=require)


 def get_project_version(
-    pyproject_path: str = "pyproject.toml", require: bool = False
+    pyproject_path: str = "pyproject.toml", require: bool = False,
 ) -> str | None:
     """Get the project version from the pyproject.toml file."""
     return _get_project_attribute(
-        pyproject_path, ["project", "version"], require=require
+        pyproject_path, ["project", "version"], require=require,
     )


 def get_project_description(
-    pyproject_path: str = "pyproject.toml", require: bool = False
+    pyproject_path: str = "pyproject.toml", require: bool = False,
 ) -> str | None:
     """Get the project description from the pyproject.toml file."""
     return _get_project_attribute(
-        pyproject_path, ["project", "description"], require=require
+        pyproject_path, ["project", "description"], require=require,
     )


 def _get_project_attribute(
-    pyproject_path: str, keys: List[str], require: bool
+    pyproject_path: str, keys: list[str], require: bool,
 ) -> Any | None:
     """Get an attribute from the pyproject.toml file."""
     attribute = None

     try:
-        with open(pyproject_path, "r") as f:
+        with open(pyproject_path) as f:
             pyproject_content = parse_toml(f.read())

         dependencies = (
             _get_nested_value(pyproject_content, ["project", "dependencies"]) or []
         )
         if not any(True for dep in dependencies if "crewai" in dep):
-            raise Exception("crewai is not in the dependencies.")
+            msg = "crewai is not in the dependencies."
+            raise Exception(msg)

         attribute = _get_nested_value(pyproject_content, keys)
     except FileNotFoundError:
-        print(f"Error: {pyproject_path} not found.")
+        pass
     except KeyError:
-        print(f"Error: {pyproject_path} is not a valid pyproject.toml file.")
-    except tomllib.TOMLDecodeError if sys.version_info >= (3, 11) else Exception as e:  # type: ignore
-        print(
-            f"Error: {pyproject_path} is not a valid TOML file."
-            if sys.version_info >= (3, 11)
-            else f"Error reading the pyproject.toml file: {e}"
-        )
-    except Exception as e:
-        print(f"Error reading the pyproject.toml file: {e}")
+        pass
+    except tomllib.TOMLDecodeError if sys.version_info >= (3, 11) else Exception:  # type: ignore
+        pass
+    except Exception:
+        pass

     if require and not attribute:
         console.print(
@@ -114,7 +110,7 @@ def _get_project_attribute(
     return attribute


-def _get_nested_value(data: Dict[str, Any], keys: List[str]) -> Any:
+def _get_nested_value(data: dict[str, Any], keys: list[str]) -> Any:
     return reduce(dict.__getitem__, keys, data)


@@ -122,7 +118,7 @@ def fetch_and_json_env_file(env_file_path: str = ".env") -> dict:
     """Fetch the environment variables from a .env file and return them as a dictionary."""
     try:
         # Read the .env file
-        with open(env_file_path, "r") as f:
+        with open(env_file_path) as f:
             env_content = f.read()

         # Parse the .env file content to a dictionary
@@ -135,14 +131,14 @@ def fetch_and_json_env_file(env_file_path: str = ".env") -> dict:
         return env_dict

     except FileNotFoundError:
-        print(f"Error: {env_file_path} not found.")
-    except Exception as e:
-        print(f"Error reading the .env file: {e}")
+        pass
+    except Exception:
+        pass

     return {}


-def tree_copy(source, destination):
+def tree_copy(source, destination) -> None:
     """Copies the entire directory structure from the source to the destination."""
     for item in os.listdir(source):
         source_item = os.path.join(source, item)
@@ -153,7 +149,7 @@ def tree_copy(source, destination):
         shutil.copy2(source_item, destination_item)


-def tree_find_and_replace(directory, find, replace):
+def tree_find_and_replace(directory, find, replace) -> None:
     """Recursively searches through a directory, replacing a target string in
     both file contents and filenames with a specified replacement string.
     """
@@ -161,7 +157,7 @@ def tree_find_and_replace(directory, find, replace):
     for filename in files:
         filepath = os.path.join(path, filename)

-        with open(filepath, "r") as file:
+        with open(filepath) as file:
             contents = file.read()
         with open(filepath, "w") as file:
             file.write(contents.replace(find, replace))
@@ -180,19 +176,19 @@ def tree_find_and_replace(directory, find, replace):


 def load_env_vars(folder_path):
-    """
-    Loads environment variables from a .env file in the specified folder path.
+    """Loads environment variables from a .env file in the specified folder path.

     Args:
         - folder_path (Path): The path to the folder containing the .env file.

     Returns:
         - dict: A dictionary of environment variables.
+
     """
     env_file_path = folder_path / ".env"
     env_vars = {}
     if env_file_path.exists():
-        with open(env_file_path, "r") as file:
+        with open(env_file_path) as file:
             for line in file:
                 key, _, value = line.strip().partition("=")
                 if key and value:
@@ -201,8 +197,7 @@ def load_env_vars(folder_path):


 def update_env_vars(env_vars, provider, model):
-    """
-    Updates environment variables with the API key for the selected provider and model.
+    """Updates environment variables with the API key for the selected provider and model.

     Args:
         - env_vars (dict): Environment variables dictionary.
@@ -211,6 +206,7 @@ def update_env_vars(env_vars, provider, model):

     Returns:
         - None
+
     """
     api_key_var = ENV_VARS.get(
         provider,
@@ -218,14 +214,14 @@ def update_env_vars(env_vars, provider, model):
             click.prompt(
                 f"Enter the environment variable name for your {provider.capitalize()} API key",
                 type=str,
             )
         ),
     ],
     )[0]

     if api_key_var not in env_vars:
         try:
             env_vars[api_key_var] = click.prompt(
-                f"Enter your {provider.capitalize()} API key", type=str, hide_input=True
+                f"Enter your {provider.capitalize()} API key", type=str, hide_input=True,
             )
         except click.exceptions.Abort:
             click.secho("Operation aborted by the user.", fg="red")
@@ -238,13 +234,13 @@ def update_env_vars(env_vars, provider, model):
     return env_vars


-def write_env_file(folder_path, env_vars):
-    """
-    Writes environment variables to a .env file in the specified folder.
+def write_env_file(folder_path, env_vars) -> None:
+    """Writes environment variables to a .env file in the specified folder.

     Args:
         - folder_path (Path): The path to the folder where the .env file will be written.
         - env_vars (dict): A dictionary of environment variables to write.
+
     """
     env_file_path = folder_path / ".env"
     with open(env_file_path, "w") as file:
@@ -263,7 +259,7 @@ def get_crews(crew_path: str = "crew.py", require: bool = False) -> list[Crew]:
             crew_os_path = os.path.join(root, crew_path)
             try:
                 spec = importlib.util.spec_from_file_location(
-                    "crew_module", crew_os_path
+                    "crew_module", crew_os_path,
                 )
                 if not spec or not spec.loader:
                     continue
@@ -277,19 +273,16 @@ def get_crews(crew_path: str = "crew.py", require: bool = False) -> list[Crew]:

                     try:
                         crew_instances.extend(fetch_crews(module_attr))
-                    except Exception as e:
-                        print(f"Error processing attribute {attr_name}: {e}")
+                    except Exception:
                         continue

-            except Exception as exec_error:
-                print(f"Error executing module: {exec_error}")
+            except Exception:
                 import traceback

-                print(f"Traceback: {traceback.format_exc()}")
         except (ImportError, AttributeError) as e:
             if require:
                 console.print(
-                    f"Error importing crew from {crew_path}: {str(e)}",
+                    f"Error importing crew from {crew_path}: {e!s}",
                     style="bold red",
                 )
                 continue
@@ -303,7 +296,7 @@ def get_crews(crew_path: str = "crew.py", require: bool = False) -> list[Crew]:
     except Exception as e:
         if require:
             console.print(
-                f"Unexpected error while loading crew: {str(e)}", style="bold red"
+                f"Unexpected error while loading crew: {e!s}", style="bold red",
             )
             raise SystemExit
     return crew_instances
@@ -317,13 +310,12 @@ def get_crew_instance(module_attr) -> Crew | None:
     ):
         return module_attr().crew()
     if (ismethod(module_attr) or isfunction(module_attr)) and get_type_hints(
-        module_attr
+        module_attr,
     ).get("return") is Crew:
         return module_attr()
-    elif isinstance(module_attr, Crew):
+    if isinstance(module_attr, Crew):
         return module_attr
-    else:
-        return None
+    return None


 def fetch_crews(module_attr) -> list[Crew]:

@@ -2,5 +2,5 @@ import importlib.metadata


 def get_crewai_version() -> str:
-    """Get the version number of CrewAI running the CLI"""
+    """Get the version number of CrewAI running the CLI."""
     return importlib.metadata.version("crewai")
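The one-liner behind _get_nested_value is worth unpacking: reduce applies dict.__getitem__ once per key, so a list of keys becomes a chained lookup. A runnable sketch:

from functools import reduce


def get_nested_value(data: dict, keys: list[str]):
    # Equivalent to data[keys[0]][keys[1]]... for however many keys are given.
    return reduce(dict.__getitem__, keys, data)


pyproject = {"project": {"name": "demo", "dependencies": ["crewai"]}}
assert get_nested_value(pyproject, ["project", "name"]) == "demo"
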
File diff suppressed because it is too large
@@ -1,5 +1,5 @@
 import json
-from typing import Any, Dict, Optional
+from typing import Any

 from pydantic import BaseModel, Field

@@ -12,27 +12,28 @@ class CrewOutput(BaseModel):
     """Class that represents the result of a crew."""

     raw: str = Field(description="Raw output of crew", default="")
-    pydantic: Optional[BaseModel] = Field(
-        description="Pydantic output of Crew", default=None
+    pydantic: BaseModel | None = Field(
+        description="Pydantic output of Crew", default=None,
     )
-    json_dict: Optional[Dict[str, Any]] = Field(
-        description="JSON dict output of Crew", default=None
+    json_dict: dict[str, Any] | None = Field(
+        description="JSON dict output of Crew", default=None,
     )
     tasks_output: list[TaskOutput] = Field(
-        description="Output of each task", default=[]
+        description="Output of each task", default=[],
     )
     token_usage: UsageMetrics = Field(description="Processed token summary", default={})

     @property
-    def json(self) -> Optional[str]:
+    def json(self) -> str | None:
         if self.tasks_output[-1].output_format != OutputFormat.JSON:
+            msg = "No JSON output found in the final task. Please make sure to set the output_json property in the final task in your crew."
             raise ValueError(
-                "No JSON output found in the final task. Please make sure to set the output_json property in the final task in your crew."
+                msg,
             )

         return json.dumps(self.json_dict)

-    def to_dict(self) -> Dict[str, Any]:
+    def to_dict(self) -> dict[str, Any]:
         """Convert json_output and pydantic_output to a dictionary."""
         output_dict = {}
         if self.json_dict:
@@ -44,12 +45,12 @@ class CrewOutput(BaseModel):
     def __getitem__(self, key):
         if self.pydantic and hasattr(self.pydantic, key):
             return getattr(self.pydantic, key)
-        elif self.json_dict and key in self.json_dict:
+        if self.json_dict and key in self.json_dict:
             return self.json_dict[key]
-        else:
-            raise KeyError(f"Key '{key}' not found in CrewOutput.")
+        msg = f"Key '{key}' not found in CrewOutput."
+        raise KeyError(msg)

-    def __str__(self):
+    def __str__(self) -> str:
         if self.pydantic:
             return str(self.pydantic)
         if self.json_dict:
@@ -2,17 +2,11 @@ import asyncio
import copy
import inspect
import logging
from collections.abc import Callable
from typing import (
    Any,
    Callable,
    Dict,
    Generic,
    List,
    Optional,
    Set,
    Type,
    TypeVar,
    Union,
    cast,
)
from uuid import uuid4
@@ -48,14 +42,14 @@ class FlowState(BaseModel):

# Type variables with explicit bounds
T = TypeVar(
    "T", bound=Union[Dict[str, Any], BaseModel]
    "T", bound=dict[str, Any] | BaseModel,
)  # Generic flow state type parameter
StateT = TypeVar(
    "StateT", bound=Union[Dict[str, Any], BaseModel]
    "StateT", bound=dict[str, Any] | BaseModel,
)  # State validation type parameter


def ensure_state_type(state: Any, expected_type: Type[StateT]) -> StateT:
def ensure_state_type(state: Any, expected_type: type[StateT]) -> StateT:
    """Ensure state matches expected type with proper validation.

    Args:
@@ -68,6 +62,7 @@ def ensure_state_type(state: Any, expected_type: Type[StateT]) -> StateT:
    Raises:
        TypeError: If state doesn't match expected type
        ValueError: If state validation fails

    """
    """Ensure state matches expected type with proper validation.

@@ -84,20 +79,22 @@ def ensure_state_type(state: Any, expected_type: Type[StateT]) -> StateT:
    """
    if expected_type is dict:
        if not isinstance(state, dict):
            raise TypeError(f"Expected dict, got {type(state).__name__}")
        return cast(StateT, state)
            msg = f"Expected dict, got {type(state).__name__}"
            raise TypeError(msg)
        return cast("StateT", state)
    if isinstance(expected_type, type) and issubclass(expected_type, BaseModel):
        if not isinstance(state, expected_type):
            msg = f"Expected {expected_type.__name__}, got {type(state).__name__}"
            raise TypeError(
                f"Expected {expected_type.__name__}, got {type(state).__name__}"
                msg,
            )
        return cast(StateT, state)
    raise TypeError(f"Invalid expected_type: {expected_type}")
        return cast("StateT", state)
    msg = f"Invalid expected_type: {expected_type}"
    raise TypeError(msg)


def start(condition: Optional[Union[str, dict, Callable]] = None) -> Callable:
    """
    Marks a method as a flow's starting point.
def start(condition: str | dict | Callable | None = None) -> Callable:
    """Marks a method as a flow's starting point.

    This decorator designates a method as an entry point for the flow execution.
    It can optionally specify conditions that trigger the start based on other
@@ -135,6 +132,7 @@ def start(condition: Optional[Union[str, dict, Callable]] = None) -> Callable:
        >>> @start(and_("method1", "method2"))  # Start after multiple methods
        >>> def complex_start(self):
        ...     pass

    """

    def decorator(func):
@@ -154,17 +152,17 @@ def start(condition: Optional[Union[str, dict, Callable]] = None) -> Callable:
            func.__trigger_methods__ = [condition.__name__]
            func.__condition_type__ = "OR"
        else:
            msg = "Condition must be a method, string, or a result of or_() or and_()"
            raise ValueError(
                "Condition must be a method, string, or a result of or_() or and_()"
                msg,
            )
        return func

    return decorator


def listen(condition: Union[str, dict, Callable]) -> Callable:
    """
    Creates a listener that executes when specified conditions are met.
def listen(condition: str | dict | Callable) -> Callable:
    """Creates a listener that executes when specified conditions are met.

    This decorator sets up a method to execute in response to other method
    executions in the flow. It supports both simple and complex triggering
@@ -197,6 +195,7 @@ def listen(condition: Union[str, dict, Callable]) -> Callable:
        >>> @listen(or_("success", "failure"))  # Listen to multiple methods
        >>> def handle_completion(self):
        ...     pass

    """

    def decorator(func):
@@ -214,17 +213,17 @@ def listen(condition: Union[str, dict, Callable]) -> Callable:
            func.__trigger_methods__ = [condition.__name__]
            func.__condition_type__ = "OR"
        else:
            msg = "Condition must be a method, string, or a result of or_() or and_()"
            raise ValueError(
                "Condition must be a method, string, or a result of or_() or and_()"
                msg,
            )
        return func

    return decorator


def router(condition: Union[str, dict, Callable]) -> Callable:
    """
    Creates a routing method that directs flow execution based on conditions.
def router(condition: str | dict | Callable) -> Callable:
    """Creates a routing method that directs flow execution based on conditions.

    This decorator marks a method as a router, which can dynamically determine
    the next steps in the flow based on its return value. Routers are triggered
@@ -262,6 +261,7 @@ def router(condition: Union[str, dict, Callable]) -> Callable:
        ...     if all([self.state.valid, self.state.processed]):
        ...         return CONTINUE
        ...     return STOP

    """

    def decorator(func):
@@ -280,17 +280,17 @@ def router(condition: Union[str, dict, Callable]) -> Callable:
            func.__trigger_methods__ = [condition.__name__]
            func.__condition_type__ = "OR"
        else:
            msg = "Condition must be a method, string, or a result of or_() or and_()"
            raise ValueError(
                "Condition must be a method, string, or a result of or_() or and_()"
                msg,
            )
        return func

    return decorator


def or_(*conditions: Union[str, dict, Callable]) -> dict:
    """
    Combines multiple conditions with OR logic for flow control.
def or_(*conditions: str | dict | Callable) -> dict:
    """Combines multiple conditions with OR logic for flow control.

    Creates a condition that is satisfied when any of the specified conditions
    are met. This is used with @start, @listen, or @router decorators to create
@@ -320,6 +320,7 @@ def or_(*conditions: Union[str, dict, Callable]) -> dict:
        >>> @listen(or_("success", "timeout"))
        >>> def handle_completion(self):
        ...     pass

    """
    methods = []
    for condition in conditions:
@@ -330,13 +331,13 @@ def or_(*conditions: Union[str, dict, Callable]) -> dict:
        elif callable(condition):
            methods.append(getattr(condition, "__name__", repr(condition)))
        else:
            raise ValueError("Invalid condition in or_()")
            msg = "Invalid condition in or_()"
            raise ValueError(msg)
    return {"type": "OR", "methods": methods}


def and_(*conditions: Union[str, dict, Callable]) -> dict:
    """
    Combines multiple conditions with AND logic for flow control.
def and_(*conditions: str | dict | Callable) -> dict:
    """Combines multiple conditions with AND logic for flow control.

    Creates a condition that is satisfied only when all specified conditions
    are met. This is used with @start, @listen, or @router decorators to create
@@ -366,6 +367,7 @@ def and_(*conditions: Union[str, dict, Callable]) -> dict:
        >>> @listen(and_("validated", "processed"))
        >>> def handle_complete_data(self):
        ...     pass

    """
    methods = []
    for condition in conditions:
@@ -376,7 +378,8 @@ def and_(*conditions: Union[str, dict, Callable]) -> dict:
        elif callable(condition):
            methods.append(getattr(condition, "__name__", repr(condition)))
        else:
            raise ValueError("Invalid condition in and_()")
            msg = "Invalid condition in and_()"
            raise ValueError(msg)
    return {"type": "AND", "methods": methods}


@@ -416,10 +419,10 @@ class FlowMeta(type):
                if possible_returns:
                    router_paths[attr_name] = possible_returns

        setattr(cls, "_start_methods", start_methods)
        setattr(cls, "_listeners", listeners)
        setattr(cls, "_routers", routers)
        setattr(cls, "_router_paths", router_paths)
        cls._start_methods = start_methods
        cls._listeners = listeners
        cls._routers = routers
        cls._router_paths = router_paths

        return cls

@@ -427,17 +430,18 @@ class FlowMeta(type):
class Flow(Generic[T], metaclass=FlowMeta):
    """Base class for all flows.

    Type parameter T must be either Dict[str, Any] or a subclass of BaseModel."""
    Type parameter T must be either Dict[str, Any] or a subclass of BaseModel.
    """

    _printer = Printer()

    _start_methods: List[str] = []
    _listeners: Dict[str, tuple[str, List[str]]] = {}
    _routers: Set[str] = set()
    _router_paths: Dict[str, List[str]] = {}
    initial_state: Union[Type[T], T, None] = None
    _start_methods: list[str] = []
    _listeners: dict[str, tuple[str, list[str]]] = {}
    _routers: set[str] = set()
    _router_paths: dict[str, list[str]] = {}
    initial_state: type[T] | T | None = None

    def __class_getitem__(cls: Type["Flow"], item: Type[T]) -> Type["Flow"]:
    def __class_getitem__(cls: type["Flow"], item: type[T]) -> type["Flow"]:
        class _FlowGeneric(cls):  # type: ignore
            _initial_state_T = item  # type: ignore

@@ -446,7 +450,7 @@ class Flow(Generic[T], metaclass=FlowMeta):

    def __init__(
        self,
        persistence: Optional[FlowPersistence] = None,
        persistence: FlowPersistence | None = None,
        **kwargs: Any,
    ) -> None:
        """Initialize a new Flow instance.
@@ -454,13 +458,14 @@ class Flow(Generic[T], metaclass=FlowMeta):
        Args:
            persistence: Optional persistence backend for storing flow states
            **kwargs: Additional state values to initialize or override

        """
        # Initialize basic instance attributes
        self._methods: Dict[str, Callable] = {}
        self._method_execution_counts: Dict[str, int] = {}
        self._pending_and_listeners: Dict[str, Set[str]] = {}
        self._method_outputs: List[Any] = []  # List to store all method outputs
        self._persistence: Optional[FlowPersistence] = persistence
        self._methods: dict[str, Callable] = {}
        self._method_execution_counts: dict[str, int] = {}
        self._pending_and_listeners: dict[str, set[str]] = {}
        self._method_outputs: list[Any] = []  # List to store all method outputs
        self._persistence: FlowPersistence | None = persistence

        # Initialize state with initial values
        self._state = self._create_initial_state()
@@ -502,58 +507,61 @@ class Flow(Generic[T], metaclass=FlowMeta):
        Raises:
            ValueError: If structured state model lacks 'id' field
            TypeError: If state is neither BaseModel nor dictionary

        """
        # Handle case where initial_state is None but we have a type parameter
        if self.initial_state is None and hasattr(self, "_initial_state_T"):
            state_type = getattr(self, "_initial_state_T")
            state_type = self._initial_state_T
            if isinstance(state_type, type):
                if issubclass(state_type, FlowState):
                    # Create instance without id, then set it
                    instance = state_type()
                    if not hasattr(instance, "id"):
                        setattr(instance, "id", str(uuid4()))
                    return cast(T, instance)
                elif issubclass(state_type, BaseModel):
                        instance.id = str(uuid4())
                    return cast("T", instance)
                if issubclass(state_type, BaseModel):
                    # Create a new type that includes the ID field
                    class StateWithId(state_type, FlowState):  # type: ignore
                        pass

                    instance = StateWithId()
                    if not hasattr(instance, "id"):
                        setattr(instance, "id", str(uuid4()))
                    return cast(T, instance)
                elif state_type is dict:
                    return cast(T, {"id": str(uuid4())})
                        instance.id = str(uuid4())
                    return cast("T", instance)
                if state_type is dict:
                    return cast("T", {"id": str(uuid4())})

        # Handle case where no initial state is provided
        if self.initial_state is None:
            return cast(T, {"id": str(uuid4())})
            return cast("T", {"id": str(uuid4())})

        # Handle case where initial_state is a type (class)
        if isinstance(self.initial_state, type):
            if issubclass(self.initial_state, FlowState):
                return cast(T, self.initial_state())  # Uses model defaults
            elif issubclass(self.initial_state, BaseModel):
                return cast("T", self.initial_state())  # Uses model defaults
            if issubclass(self.initial_state, BaseModel):
                # Validate that the model has an id field
                model_fields = getattr(self.initial_state, "model_fields", None)
                if not model_fields or "id" not in model_fields:
                    raise ValueError("Flow state model must have an 'id' field")
                return cast(T, self.initial_state())  # Uses model defaults
            elif self.initial_state is dict:
                return cast(T, {"id": str(uuid4())})
                    msg = "Flow state model must have an 'id' field"
                    raise ValueError(msg)
                return cast("T", self.initial_state())  # Uses model defaults
            if self.initial_state is dict:
                return cast("T", {"id": str(uuid4())})

        # Handle dictionary instance case
        if isinstance(self.initial_state, dict):
            new_state = dict(self.initial_state)  # Copy to avoid mutations
            if "id" not in new_state:
                new_state["id"] = str(uuid4())
            return cast(T, new_state)
            return cast("T", new_state)

        # Handle BaseModel instance case
        if isinstance(self.initial_state, BaseModel):
            model = cast(BaseModel, self.initial_state)
            model = cast("BaseModel", self.initial_state)
            if not hasattr(model, "id"):
                raise ValueError("Flow state model must have an 'id' field")
                msg = "Flow state model must have an 'id' field"
                raise ValueError(msg)

            # Create new instance with same values to avoid mutations
            if hasattr(model, "model_dump"):
@@ -570,9 +578,10 @@ class Flow(Generic[T], metaclass=FlowMeta):

            # Create new instance of the same class
            model_class = type(model)
            return cast(T, model_class(**state_dict))
            return cast("T", model_class(**state_dict))
        msg = f"Initial state must be dict or BaseModel, got {type(self.initial_state)}"
        raise TypeError(
            f"Initial state must be dict or BaseModel, got {type(self.initial_state)}"
            msg,
        )

    def _copy_state(self) -> T:
@@ -583,7 +592,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
        return self._state

    @property
    def method_outputs(self) -> List[Any]:
    def method_outputs(self) -> list[Any]:
        """Returns the list of all outputs from executed methods."""
        return self._method_outputs

@@ -607,6 +616,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
        flow = MyFlow()
        print(f"Current flow ID: {flow.flow_id}")  # Safely get flow ID
        ```

        """
        try:
            if not hasattr(self, "_state"):
@@ -614,13 +624,13 @@ class Flow(Generic[T], metaclass=FlowMeta):

            if isinstance(self._state, dict):
                return str(self._state.get("id", ""))
            elif isinstance(self._state, BaseModel):
            if isinstance(self._state, BaseModel):
                return str(getattr(self._state, "id", ""))
            return ""
        except (AttributeError, TypeError):
            return ""  # Safely handle any unexpected attribute access issues

    def _initialize_state(self, inputs: Dict[str, Any]) -> None:
    def _initialize_state(self, inputs: dict[str, Any]) -> None:
        """Initialize or update flow state with new inputs.

        Args:
@@ -629,6 +639,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
        Raises:
            ValueError: If validation fails for structured state
            TypeError: If state is neither BaseModel nor dictionary

        """
        if isinstance(self._state, dict):
            # For dict states, preserve existing fields unless overridden
@@ -644,7 +655,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
        elif isinstance(self._state, BaseModel):
            # For BaseModel states, preserve existing fields unless overridden
            try:
                model = cast(BaseModel, self._state)
                model = cast("BaseModel", self._state)
                # Get current state as dict
                if hasattr(model, "model_dump"):
                    current_state = model.model_dump()
@@ -662,19 +673,21 @@ class Flow(Generic[T], metaclass=FlowMeta):
                model_class = type(model)
                if hasattr(model_class, "model_validate"):
                    # Pydantic v2
                    self._state = cast(T, model_class.model_validate(new_state))
                    self._state = cast("T", model_class.model_validate(new_state))
                elif hasattr(model_class, "parse_obj"):
                    # Pydantic v1
                    self._state = cast(T, model_class.parse_obj(new_state))
                    self._state = cast("T", model_class.parse_obj(new_state))
                else:
                    # Fallback for other BaseModel implementations
                    self._state = cast(T, model_class(**new_state))
                    self._state = cast("T", model_class(**new_state))
            except ValidationError as e:
                raise ValueError(f"Invalid inputs for structured state: {e}") from e
                msg = f"Invalid inputs for structured state: {e}"
                raise ValueError(msg) from e
        else:
            raise TypeError("State must be a BaseModel instance or a dictionary.")
            msg = "State must be a BaseModel instance or a dictionary."
            raise TypeError(msg)

    def _restore_state(self, stored_state: Dict[str, Any]) -> None:
    def _restore_state(self, stored_state: dict[str, Any]) -> None:
        """Restore flow state from persistence.

        Args:
@@ -683,11 +696,13 @@ class Flow(Generic[T], metaclass=FlowMeta):
        Raises:
            ValueError: If validation fails for structured state
            TypeError: If state is neither BaseModel nor dictionary

        """
        # When restoring from persistence, use the stored ID
        stored_id = stored_state.get("id")
        if not stored_id:
            raise ValueError("Stored state must have an 'id' field")
            msg = "Stored state must have an 'id' field"
            raise ValueError(msg)

        if isinstance(self._state, dict):
            # For dict states, update all fields from stored state
@@ -695,22 +710,22 @@ class Flow(Generic[T], metaclass=FlowMeta):
            self._state.update(stored_state)
        elif isinstance(self._state, BaseModel):
            # For BaseModel states, create new instance with stored values
            model = cast(BaseModel, self._state)
            model = cast("BaseModel", self._state)
            if hasattr(model, "model_validate"):
                # Pydantic v2
                self._state = cast(T, type(model).model_validate(stored_state))
                self._state = cast("T", type(model).model_validate(stored_state))
            elif hasattr(model, "parse_obj"):
                # Pydantic v1
                self._state = cast(T, type(model).parse_obj(stored_state))
                self._state = cast("T", type(model).parse_obj(stored_state))
            else:
                # Fallback for other BaseModel implementations
                self._state = cast(T, type(model)(**stored_state))
                self._state = cast("T", type(model)(**stored_state))
        else:
            raise TypeError(f"State must be dict or BaseModel, got {type(self._state)}")
            msg = f"State must be dict or BaseModel, got {type(self._state)}"
            raise TypeError(msg)

    def kickoff(self, inputs: Optional[Dict[str, Any]] = None) -> Any:
        """
        Start the flow execution in a synchronous context.
    def kickoff(self, inputs: dict[str, Any] | None = None) -> Any:
        """Start the flow execution in a synchronous context.

        This method wraps kickoff_async so that all state initialization and event
        emission is handled in the asynchronous method.
@@ -721,9 +736,8 @@ class Flow(Generic[T], metaclass=FlowMeta):

        return asyncio.run(run_flow())

    async def kickoff_async(self, inputs: Optional[Dict[str, Any]] = None) -> Any:
        """
        Start the flow execution asynchronously.
    async def kickoff_async(self, inputs: dict[str, Any] | None = None) -> Any:
        """Start the flow execution asynchronously.

        This method performs state restoration (if an 'id' is provided and persistence is available)
        and updates the flow state with any additional inputs. It then emits the FlowStartedEvent,
@@ -735,6 +749,7 @@ class Flow(Generic[T], metaclass=FlowMeta):

        Returns:
            The final output from the flow, which is the result of the last executed method.

        """
        if inputs:
            # Override the id in the state if it exists in inputs
@@ -742,7 +757,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
                if isinstance(self._state, dict):
                    self._state["id"] = inputs["id"]
                elif isinstance(self._state, BaseModel):
                    setattr(self._state, "id", inputs["id"])
                    self._state.id = inputs["id"]

            # If persistence is enabled, attempt to restore the stored state using the provided id.
            if "id" in inputs and self._persistence is not None:
@@ -756,7 +771,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
                    self._restore_state(stored_state)
                else:
                    self._log_flow_event(
                        f"No flow state found for UUID: {restore_uuid}", color="red"
                        f"No flow state found for UUID: {restore_uuid}", color="red",
                    )

            # Update state with any additional inputs (ignoring the 'id' key)
@@ -774,7 +789,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
            ),
        )
        self._log_flow_event(
            f"Flow started with ID: {self.flow_id}", color="bold_magenta"
            f"Flow started with ID: {self.flow_id}", color="bold_magenta",
        )

        if inputs is not None and "id" not in inputs:
@@ -800,8 +815,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
        return final_output

    async def _execute_start_method(self, start_method_name: str) -> None:
        """
        Executes a flow's start method and its triggered listeners.
        """Executes a flow's start method and its triggered listeners.

        This internal method handles the execution of methods marked with @start
        decorator and manages the subsequent chain of listener executions.
@@ -816,14 +830,15 @@ class Flow(Generic[T], metaclass=FlowMeta):
        - Executes the start method and captures its result
        - Triggers execution of any listeners waiting on this start method
        - Part of the flow's initialization sequence

        """
        result = await self._execute_method(
            start_method_name, self._methods[start_method_name]
            start_method_name, self._methods[start_method_name],
        )
        await self._execute_listeners(start_method_name, result)

    async def _execute_method(
        self, method_name: str, method: Callable, *args: Any, **kwargs: Any
        self, method_name: str, method: Callable, *args: Any, **kwargs: Any,
    ) -> Any:
        try:
            dumped_params = {f"_{i}": arg for i, arg in enumerate(args)} | (
@@ -873,11 +888,10 @@ class Flow(Generic[T], metaclass=FlowMeta):
                    error=e,
                ),
            )
            raise e
            raise

    async def _execute_listeners(self, trigger_method: str, result: Any) -> None:
        """
        Executes all listeners and routers triggered by a method completion.
        """Executes all listeners and routers triggered by a method completion.

        This internal method manages the execution flow by:
        1. First executing all triggered routers sequentially
@@ -897,6 +911,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
        - Each router's result becomes a new trigger_method
        - Normal listeners are executed in parallel for efficiency
        - Listeners can receive the trigger method's result as a parameter

        """
        # First, handle routers repeatedly until no router triggers anymore
        router_results = []
@@ -904,7 +919,7 @@ class Flow(Generic[T], metaclass=FlowMeta):

        while True:
            routers_triggered = self._find_triggered_methods(
                current_trigger, router_only=True
                current_trigger, router_only=True,
            )
            if not routers_triggered:
                break
@@ -920,12 +935,12 @@ class Flow(Generic[T], metaclass=FlowMeta):
                )

        # Now execute normal listeners for all router results and the original trigger
        all_triggers = [trigger_method] + router_results
        all_triggers = [trigger_method, *router_results]

        for current_trigger in all_triggers:
            if current_trigger:  # Skip None results
                listeners_triggered = self._find_triggered_methods(
                    current_trigger, router_only=False
                    current_trigger, router_only=False,
                )
                if listeners_triggered:
                    tasks = [
@@ -935,10 +950,9 @@ class Flow(Generic[T], metaclass=FlowMeta):
                    await asyncio.gather(*tasks)

    def _find_triggered_methods(
        self, trigger_method: str, router_only: bool
    ) -> List[str]:
        """
        Finds all methods that should be triggered based on conditions.
        self, trigger_method: str, router_only: bool,
    ) -> list[str]:
        """Finds all methods that should be triggered based on conditions.

        This internal method evaluates both OR and AND conditions to determine
        which methods should be executed next in the flow.
@@ -963,6 +977,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
            * AND: Triggers only when all conditions are met
        - Maintains state for AND conditions using _pending_and_listeners
        - Separates router and normal listener evaluation

        """
        triggered = []
        for listener_name, (condition_type, methods) in self._listeners.items():
@@ -992,8 +1007,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
        return triggered

    async def _execute_single_listener(self, listener_name: str, result: Any) -> None:
        """
        Executes a single listener method with proper event handling.
        """Executes a single listener method with proper event handling.

        This internal method manages the execution of an individual listener,
        including parameter inspection, event emission, and error handling.
@@ -1018,6 +1032,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
        -------------
        Catches and logs any exceptions during execution, preventing
        individual listener failures from breaking the entire flow.

        """
        try:
            method = self._methods[listener_name]
@@ -1028,7 +1043,7 @@ class Flow(Generic[T], metaclass=FlowMeta):

            if method_params:
                listener_result = await self._execute_method(
                    listener_name, method, result
                    listener_name, method, result,
                )
            else:
                listener_result = await self._execute_method(listener_name, method)
@@ -1036,17 +1051,14 @@ class Flow(Generic[T], metaclass=FlowMeta):
            # Execute listeners (and possibly routers) of this listener
            await self._execute_listeners(listener_name, listener_result)

        except Exception as e:
            print(
                f"[Flow._execute_single_listener] Error in method {listener_name}: {e}"
            )
        except Exception:
            import traceback

            traceback.print_exc()
            raise

    def _log_flow_event(
        self, message: str, color: str = "yellow", level: str = "info"
        self, message: str, color: str = "yellow", level: str = "info",
    ) -> None:
        """Centralized logging method for flow events.

@@ -1064,6 +1076,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
        Note:
            This method uses the Printer utility for colored console output
            and the standard logging module for log level support.

        """
        self._printer.print(message, color=color)
        if level == "info":
@@ -1,5 +1,4 @@
import inspect
from typing import Optional

from pydantic import BaseModel, Field, InstanceOf, model_validator

@@ -14,7 +13,7 @@ class FlowTrackable(BaseModel):
    inspecting the call stack.
    """

    parent_flow: Optional[InstanceOf[Flow]] = Field(
    parent_flow: InstanceOf[Flow] | None = Field(
        default=None,
        description="The parent flow of the instance, if it was created inside a flow.",
    )
@@ -1,14 +1,13 @@
# flow_visualizer.py

import os
from pathlib import Path

from pyvis.network import Network

from crewai.flow.config import COLORS, NODE_STYLES
from crewai.flow.html_template_handler import HTMLTemplateHandler
from crewai.flow.legend_generator import generate_legend_items_html, get_legend_items
from crewai.flow.path_utils import safe_path_join, validate_path_exists
from crewai.flow.path_utils import safe_path_join
from crewai.flow.utils import calculate_node_levels
from crewai.flow.visualization_utils import (
    add_edges,
@@ -20,9 +19,8 @@ from crewai.flow.visualization_utils import (
class FlowPlot:
    """Handles the creation and rendering of flow visualization diagrams."""

    def __init__(self, flow):
        """
        Initialize FlowPlot with a flow object.
    def __init__(self, flow) -> None:
        """Initialize FlowPlot with a flow object.

        Parameters
        ----------
@@ -33,21 +31,24 @@ class FlowPlot:
        ------
        ValueError
            If flow object is invalid or missing required attributes.

        """
        if not hasattr(flow, '_methods'):
            raise ValueError("Invalid flow object: missing '_methods' attribute")
        if not hasattr(flow, '_listeners'):
            raise ValueError("Invalid flow object: missing '_listeners' attribute")
        if not hasattr(flow, '_start_methods'):
            raise ValueError("Invalid flow object: missing '_start_methods' attribute")

        if not hasattr(flow, "_methods"):
            msg = "Invalid flow object: missing '_methods' attribute"
            raise ValueError(msg)
        if not hasattr(flow, "_listeners"):
            msg = "Invalid flow object: missing '_listeners' attribute"
            raise ValueError(msg)
        if not hasattr(flow, "_start_methods"):
            msg = "Invalid flow object: missing '_start_methods' attribute"
            raise ValueError(msg)

        self.flow = flow
        self.colors = COLORS
        self.node_styles = NODE_STYLES

    def plot(self, filename):
        """
        Generate and save an HTML visualization of the flow.
    def plot(self, filename) -> None:
        """Generate and save an HTML visualization of the flow.

        Parameters
        ----------
@@ -62,10 +63,12 @@ class FlowPlot:
            If file operations fail or visualization cannot be generated.
        RuntimeError
            If network visualization generation fails.

        """
        if not filename or not isinstance(filename, str):
            raise ValueError("Filename must be a non-empty string")

            msg = "Filename must be a non-empty string"
            raise ValueError(msg)

        try:
            # Initialize network
            net = Network(
@@ -89,58 +92,63 @@ class FlowPlot:
                    "enabled": false
                }
            }
            """
            """,
            )

            # Calculate levels for nodes
            try:
                node_levels = calculate_node_levels(self.flow)
            except Exception as e:
                raise ValueError(f"Failed to calculate node levels: {str(e)}")
                msg = f"Failed to calculate node levels: {e!s}"
                raise ValueError(msg)

            # Compute positions
            try:
                node_positions = compute_positions(self.flow, node_levels)
            except Exception as e:
                raise ValueError(f"Failed to compute node positions: {str(e)}")
                msg = f"Failed to compute node positions: {e!s}"
                raise ValueError(msg)

            # Add nodes to the network
            try:
                add_nodes_to_network(net, self.flow, node_positions, self.node_styles)
            except Exception as e:
                raise RuntimeError(f"Failed to add nodes to network: {str(e)}")
                msg = f"Failed to add nodes to network: {e!s}"
                raise RuntimeError(msg)

            # Add edges to the network
            try:
                add_edges(net, self.flow, node_positions, self.colors)
            except Exception as e:
                raise RuntimeError(f"Failed to add edges to network: {str(e)}")
                msg = f"Failed to add edges to network: {e!s}"
                raise RuntimeError(msg)

            # Generate HTML
            try:
                network_html = net.generate_html()
                final_html_content = self._generate_final_html(network_html)
            except Exception as e:
                raise RuntimeError(f"Failed to generate network visualization: {str(e)}")
                msg = f"Failed to generate network visualization: {e!s}"
                raise RuntimeError(msg)

            # Save the final HTML content to the file
            try:
                with open(f"{filename}.html", "w", encoding="utf-8") as f:
                    f.write(final_html_content)
                print(f"Plot saved as {filename}.html")
            except IOError as e:
                raise IOError(f"Failed to save flow visualization to {filename}.html: {str(e)}")
            except OSError as e:
                msg = f"Failed to save flow visualization to {filename}.html: {e!s}"
                raise OSError(msg)

        except (ValueError, RuntimeError, IOError) as e:
            raise e
        except (OSError, ValueError, RuntimeError):
            raise
        except Exception as e:
            raise RuntimeError(f"Unexpected error during flow visualization: {str(e)}")
            msg = f"Unexpected error during flow visualization: {e!s}"
            raise RuntimeError(msg)
        finally:
            self._cleanup_pyvis_lib()

    def _generate_final_html(self, network_html):
        """
        Generate the final HTML content with network visualization and legend.
        """Generate the final HTML content with network visualization and legend.

        Parameters
        ----------
@@ -158,9 +166,11 @@ class FlowPlot:
            If template or logo files cannot be accessed.
        ValueError
            If network_html is invalid.

        """
        if not network_html:
            raise ValueError("Invalid network HTML content")
            msg = "Invalid network HTML content"
            raise ValueError(msg)

        try:
            # Extract just the body content from the generated HTML
@@ -169,9 +179,11 @@ class FlowPlot:
            logo_path = safe_path_join("assets", "crewai_logo.svg", root=current_dir)

            if not os.path.exists(template_path):
                raise IOError(f"Template file not found: {template_path}")
                msg = f"Template file not found: {template_path}"
                raise OSError(msg)
            if not os.path.exists(logo_path):
                raise IOError(f"Logo file not found: {logo_path}")
                msg = f"Logo file not found: {logo_path}"
                raise OSError(msg)

            html_handler = HTMLTemplateHandler(template_path, logo_path)
            network_body = html_handler.extract_body_content(network_html)
@@ -179,16 +191,15 @@ class FlowPlot:
            # Generate the legend items HTML
            legend_items = get_legend_items(self.colors)
            legend_items_html = generate_legend_items_html(legend_items)
            final_html_content = html_handler.generate_final_html(
                network_body, legend_items_html
            return html_handler.generate_final_html(
                network_body, legend_items_html,
            )
            return final_html_content
        except Exception as e:
            raise IOError(f"Failed to generate visualization HTML: {str(e)}")
            msg = f"Failed to generate visualization HTML: {e!s}"
            raise OSError(msg)

    def _cleanup_pyvis_lib(self):
        """
        Clean up the generated lib folder from pyvis.
    def _cleanup_pyvis_lib(self) -> None:
        """Clean up the generated lib folder from pyvis.

        This method safely removes the temporary lib directory created by pyvis
        during network visualization generation.
@@ -198,15 +209,14 @@ class FlowPlot:
            if os.path.exists(lib_folder) and os.path.isdir(lib_folder):
                import shutil
                shutil.rmtree(lib_folder)
        except ValueError as e:
            print(f"Error validating lib folder path: {e}")
        except Exception as e:
            print(f"Error cleaning up lib folder: {e}")
        except ValueError:
            pass
        except Exception:
            pass


def plot_flow(flow, filename="flow_plot"):
    """
    Convenience function to create and save a flow visualization.
def plot_flow(flow, filename="flow_plot") -> None:
    """Convenience function to create and save a flow visualization.

    Parameters
    ----------
@@ -221,6 +231,7 @@ def plot_flow(flow, filename="flow_plot"):
        If flow object or filename is invalid.
    IOError
        If file operations fail.

    """
    visualizer = FlowPlot(flow)
    visualizer.plot(filename)
@@ -1,16 +1,14 @@
import base64
import re
from pathlib import Path

from crewai.flow.path_utils import safe_path_join, validate_path_exists
from crewai.flow.path_utils import validate_path_exists


class HTMLTemplateHandler:
    """Handles HTML template processing and generation for flow visualization diagrams."""

    def __init__(self, template_path, logo_path):
        """
        Initialize HTMLTemplateHandler with validated template and logo paths.
    def __init__(self, template_path, logo_path) -> None:
        """Initialize HTMLTemplateHandler with validated template and logo paths.

        Parameters
        ----------
@@ -23,16 +21,18 @@ class HTMLTemplateHandler:
        ------
        ValueError
            If template or logo paths are invalid or files don't exist.

        """
        try:
            self.template_path = validate_path_exists(template_path, "file")
            self.logo_path = validate_path_exists(logo_path, "file")
        except ValueError as e:
            raise ValueError(f"Invalid template or logo path: {e}")
            msg = f"Invalid template or logo path: {e}"
            raise ValueError(msg)

    def read_template(self):
        """Read and return the HTML template file contents."""
        with open(self.template_path, "r", encoding="utf-8") as f:
        with open(self.template_path, encoding="utf-8") as f:
            return f.read()

    def encode_logo(self):
@@ -81,13 +81,12 @@ class HTMLTemplateHandler:

        final_html_content = html_template.replace("{{ title }}", title)
        final_html_content = final_html_content.replace(
            "{{ network_content }}", network_body
            "{{ network_content }}", network_body,
        )
        final_html_content = final_html_content.replace(
            "{{ logo_svg_base64 }}", logo_svg_base64
            "{{ logo_svg_base64 }}", logo_svg_base64,
        )
        final_html_content = final_html_content.replace(
            "<!-- LEGEND_ITEMS_PLACEHOLDER -->", legend_items_html
        return final_html_content.replace(
            "<!-- LEGEND_ITEMS_PLACEHOLDER -->", legend_items_html,
        )

        return final_html_content
@@ -1,18 +1,14 @@
"""
Path utilities for secure file operations in CrewAI flow module.
"""Path utilities for secure file operations in CrewAI flow module.

This module provides utilities for secure path handling to prevent directory
traversal attacks and ensure paths remain within allowed boundaries.
"""

import os
from pathlib import Path
from typing import List, Union


def safe_path_join(*parts: str, root: Union[str, Path, None] = None) -> str:
    """
    Safely join path components and ensure the result is within allowed boundaries.
def safe_path_join(*parts: str, root: str | Path | None = None) -> str:
    """Safely join path components and ensure the result is within allowed boundaries.

    Parameters
    ----------
@@ -31,39 +27,43 @@ def safe_path_join(*parts: str, root: Union[str, Path, None] = None) -> str:
    ValueError
        If the resulting path would be outside the root directory
        or if any path component is invalid.

    """
    if not parts:
        raise ValueError("No path components provided")
        msg = "No path components provided"
        raise ValueError(msg)

    try:
        # Convert all parts to strings and clean them
        clean_parts = [str(part).strip() for part in parts if part]
        if not clean_parts:
            raise ValueError("No valid path components provided")
            msg = "No valid path components provided"
            raise ValueError(msg)

        # Establish root directory
        root_path = Path(root).resolve() if root else Path.cwd()


        # Join and resolve the full path
        full_path = Path(root_path, *clean_parts).resolve()


        # Check if the resolved path is within root
        if not str(full_path).startswith(str(root_path)):
            msg = f"Invalid path: Potential directory traversal. Path must be within {root_path}"
            raise ValueError(
                f"Invalid path: Potential directory traversal. Path must be within {root_path}"
                msg,
            )


        return str(full_path)


    except Exception as e:
        if isinstance(e, ValueError):
            raise
        raise ValueError(f"Invalid path components: {str(e)}")
        msg = f"Invalid path components: {e!s}"
        raise ValueError(msg)


def validate_path_exists(path: Union[str, Path], file_type: str = "file") -> str:
    """
    Validate that a path exists and is of the expected type.
def validate_path_exists(path: str | Path, file_type: str = "file") -> str:
    """Validate that a path exists and is of the expected type.

    Parameters
    ----------
@@ -81,29 +81,33 @@ def validate_path_exists(path: Union[str, Path], file_type: str = "file") -> str
    ------
    ValueError
        If path doesn't exist or is not of expected type.

    """
    try:
        path_obj = Path(path).resolve()


        if not path_obj.exists():
            raise ValueError(f"Path does not exist: {path}")

            msg = f"Path does not exist: {path}"
            raise ValueError(msg)

        if file_type == "file" and not path_obj.is_file():
            raise ValueError(f"Path is not a file: {path}")
        elif file_type == "directory" and not path_obj.is_dir():
            raise ValueError(f"Path is not a directory: {path}")

            msg = f"Path is not a file: {path}"
            raise ValueError(msg)
        if file_type == "directory" and not path_obj.is_dir():
            msg = f"Path is not a directory: {path}"
            raise ValueError(msg)

        return str(path_obj)


    except Exception as e:
        if isinstance(e, ValueError):
            raise
        raise ValueError(f"Invalid path: {str(e)}")
        msg = f"Invalid path: {e!s}"
        raise ValueError(msg)


def list_files(directory: Union[str, Path], pattern: str = "*") -> List[str]:
    """
    Safely list files in a directory matching a pattern.
def list_files(directory: str | Path, pattern: str = "*") -> list[str]:
    """Safely list files in a directory matching a pattern.

    Parameters
    ----------
@@ -121,15 +125,18 @@ def list_files(directory: Union[str, Path], pattern: str = "*") -> List[str]:
    ------
    ValueError
        If directory is invalid or inaccessible.

    """
    try:
        dir_path = Path(directory).resolve()
        if not dir_path.is_dir():
            raise ValueError(f"Not a directory: {directory}")

            msg = f"Not a directory: {directory}"
            raise ValueError(msg)

        return [str(p) for p in dir_path.glob(pattern) if p.is_file()]


    except Exception as e:
        if isinstance(e, ValueError):
            raise
        raise ValueError(f"Error listing files: {str(e)}")
        msg = f"Error listing files: {e!s}"
        raise ValueError(msg)
@@ -1,53 +1,52 @@
"""Base class for flow state persistence."""

import abc
from typing import Any, Dict, Optional, Union
from typing import Any

from pydantic import BaseModel


class FlowPersistence(abc.ABC):
    """Abstract base class for flow state persistence.


    This class defines the interface that all persistence implementations must follow.
    It supports both structured (Pydantic BaseModel) and unstructured (dict) states.
    """


    @abc.abstractmethod
    def init_db(self) -> None:
        """Initialize the persistence backend.


        This method should handle any necessary setup, such as:
        - Creating tables
        - Establishing connections
        - Setting up indexes
        """
        pass


    @abc.abstractmethod
    def save_state(
        self,
        flow_uuid: str,
        method_name: str,
        state_data: Union[Dict[str, Any], BaseModel]
        state_data: dict[str, Any] | BaseModel,
    ) -> None:
        """Persist the flow state after method completion.


        Args:
            flow_uuid: Unique identifier for the flow instance
            method_name: Name of the method that just completed
            state_data: Current state data (either dict or Pydantic model)

        """
        pass


    @abc.abstractmethod
    def load_state(self, flow_uuid: str) -> Optional[Dict[str, Any]]:
    def load_state(self, flow_uuid: str) -> dict[str, Any] | None:
        """Load the most recent state for a given flow UUID.


        Args:
            flow_uuid: Unique identifier for the flow instance


        Returns:
            The most recent state as a dictionary, or None if no state exists

        """
        pass
@@ -1,5 +1,4 @@
|
||||
"""
|
||||
Decorators for flow state persistence.
|
||||
"""Decorators for flow state persistence.
|
||||
|
||||
Example:
|
||||
```python
|
||||
@@ -19,18 +18,16 @@ Example:
|
||||
# Asynchronous method implementation
|
||||
await some_async_operation()
|
||||
```
|
||||
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import functools
|
||||
import logging
|
||||
from collections.abc import Callable
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
Optional,
|
||||
Type,
|
||||
TypeVar,
|
||||
Union,
|
||||
cast,
|
||||
)
|
||||
|
||||
@@ -48,7 +45,7 @@ LOG_MESSAGES = {
|
||||
"save_state": "Saving flow state to memory for ID: {}",
|
||||
"save_error": "Failed to persist state for method {}: {}",
|
||||
"state_missing": "Flow instance has no state",
|
||||
"id_missing": "Flow state must have an 'id' field for persistence"
|
||||
"id_missing": "Flow state must have an 'id' field for persistence",
|
||||
}
|
||||
|
||||
|
||||
@@ -74,20 +71,23 @@ class PersistenceDecorator:
|
||||
ValueError: If flow has no state or state lacks an ID
|
||||
RuntimeError: If state persistence fails
|
||||
AttributeError: If flow instance lacks required state attributes
|
||||
|
||||
"""
|
||||
try:
|
||||
state = getattr(flow_instance, 'state', None)
|
||||
state = getattr(flow_instance, "state", None)
|
||||
if state is None:
|
||||
raise ValueError("Flow instance has no state")
|
||||
msg = "Flow instance has no state"
|
||||
raise ValueError(msg)
|
||||
|
||||
flow_uuid: Optional[str] = None
|
||||
flow_uuid: str | None = None
|
||||
if isinstance(state, dict):
|
||||
flow_uuid = state.get('id')
|
||||
flow_uuid = state.get("id")
|
||||
elif isinstance(state, BaseModel):
|
||||
flow_uuid = getattr(state, 'id', None)
|
||||
flow_uuid = getattr(state, "id", None)
|
||||
|
||||
if not flow_uuid:
|
||||
raise ValueError("Flow state must have an 'id' field for persistence")
|
||||
msg = "Flow state must have an 'id' field for persistence"
|
||||
raise ValueError(msg)
|
||||
|
||||
# Log state saving only if verbose is True
|
||||
if verbose:
|
||||
@@ -103,21 +103,22 @@ class PersistenceDecorator:
|
||||
except Exception as e:
|
||||
error_msg = LOG_MESSAGES["save_error"].format(method_name, str(e))
|
||||
cls._printer.print(error_msg, color="red")
|
||||
logger.error(error_msg)
|
||||
raise RuntimeError(f"State persistence failed: {str(e)}") from e
|
||||
logger.exception(error_msg)
|
||||
msg = f"State persistence failed: {e!s}"
|
||||
raise RuntimeError(msg) from e
|
||||
except AttributeError:
|
||||
error_msg = LOG_MESSAGES["state_missing"]
|
||||
cls._printer.print(error_msg, color="red")
|
||||
logger.error(error_msg)
|
||||
logger.exception(error_msg)
|
||||
raise ValueError(error_msg)
|
||||
except (TypeError, ValueError) as e:
|
||||
error_msg = LOG_MESSAGES["id_missing"]
|
||||
cls._printer.print(error_msg, color="red")
|
||||
logger.error(error_msg)
|
||||
logger.exception(error_msg)
|
||||
raise ValueError(error_msg) from e
|
||||
|
||||
|
||||
def persist(persistence: Optional[FlowPersistence] = None, verbose: bool = False):
|
||||
def persist(persistence: FlowPersistence | None = None, verbose: bool = False):
|
||||
"""Decorator to persist flow state.
|
||||
|
||||
This decorator can be applied at either the class level or method level.
|
||||
@@ -143,22 +144,23 @@ def persist(persistence: Optional[FlowPersistence] = None, verbose: bool = False
|
||||
@start()
|
||||
def begin(self):
|
||||
pass
|
||||
|
||||
"""
|
||||
def decorator(target: Union[Type, Callable[..., T]]) -> Union[Type, Callable[..., T]]:
|
||||
def decorator(target: type | Callable[..., T]) -> type | Callable[..., T]:
|
||||
"""Decorator that handles both class and method decoration."""
|
||||
actual_persistence = persistence or SQLiteFlowPersistence()
|
||||
|
||||
if isinstance(target, type):
|
||||
# Class decoration
|
||||
original_init = getattr(target, "__init__")
|
||||
original_init = target.__init__
|
||||
|
||||
@functools.wraps(original_init)
|
||||
def new_init(self: Any, *args: Any, **kwargs: Any) -> None:
|
||||
if 'persistence' not in kwargs:
|
||||
kwargs['persistence'] = actual_persistence
|
if "persistence" not in kwargs:
kwargs["persistence"] = actual_persistence
original_init(self, *args, **kwargs)

setattr(target, "__init__", new_init)
target.__init__ = new_init

# Store original methods to preserve their decorators
original_methods = {}
@@ -191,7 +193,7 @@ def persist(persistence: Optional[FlowPersistence] = None, verbose: bool = False
for attr in ["__is_start_method__", "__trigger_methods__", "__condition_type__", "__is_router__"]:
if hasattr(method, attr):
setattr(wrapped, attr, getattr(method, attr))
setattr(wrapped, "__is_flow_method__", True)
wrapped.__is_flow_method__ = True

# Update the class with the wrapped method
setattr(target, name, wrapped)
@@ -211,44 +213,42 @@ def persist(persistence: Optional[FlowPersistence] = None, verbose: bool = False
for attr in ["__is_start_method__", "__trigger_methods__", "__condition_type__", "__is_router__"]:
if hasattr(method, attr):
setattr(wrapped, attr, getattr(method, attr))
setattr(wrapped, "__is_flow_method__", True)
wrapped.__is_flow_method__ = True

# Update the class with the wrapped method
setattr(target, name, wrapped)

return target
else:
# Method decoration
method = target
setattr(method, "__is_flow_method__", True)
# Method decoration
method = target
method.__is_flow_method__ = True

if asyncio.iscoroutinefunction(method):
@functools.wraps(method)
async def method_async_wrapper(flow_instance: Any, *args: Any, **kwargs: Any) -> T:
method_coro = method(flow_instance, *args, **kwargs)
if asyncio.iscoroutine(method_coro):
result = await method_coro
else:
result = method_coro
PersistenceDecorator.persist_state(flow_instance, method.__name__, actual_persistence, verbose)
return result
if asyncio.iscoroutinefunction(method):
@functools.wraps(method)
async def method_async_wrapper(flow_instance: Any, *args: Any, **kwargs: Any) -> T:
method_coro = method(flow_instance, *args, **kwargs)
if asyncio.iscoroutine(method_coro):
result = await method_coro
else:
result = method_coro
PersistenceDecorator.persist_state(flow_instance, method.__name__, actual_persistence, verbose)
return result

for attr in ["__is_start_method__", "__trigger_methods__", "__condition_type__", "__is_router__"]:
if hasattr(method, attr):
setattr(method_async_wrapper, attr, getattr(method, attr))
setattr(method_async_wrapper, "__is_flow_method__", True)
return cast(Callable[..., T], method_async_wrapper)
else:
@functools.wraps(method)
def method_sync_wrapper(flow_instance: Any, *args: Any, **kwargs: Any) -> T:
result = method(flow_instance, *args, **kwargs)
PersistenceDecorator.persist_state(flow_instance, method.__name__, actual_persistence, verbose)
return result
for attr in ["__is_start_method__", "__trigger_methods__", "__condition_type__", "__is_router__"]:
if hasattr(method, attr):
setattr(method_async_wrapper, attr, getattr(method, attr))
method_async_wrapper.__is_flow_method__ = True
return cast("Callable[..., T]", method_async_wrapper)
@functools.wraps(method)
def method_sync_wrapper(flow_instance: Any, *args: Any, **kwargs: Any) -> T:
result = method(flow_instance, *args, **kwargs)
PersistenceDecorator.persist_state(flow_instance, method.__name__, actual_persistence, verbose)
return result

for attr in ["__is_start_method__", "__trigger_methods__", "__condition_type__", "__is_router__"]:
if hasattr(method, attr):
setattr(method_sync_wrapper, attr, getattr(method, attr))
setattr(method_sync_wrapper, "__is_flow_method__", True)
return cast(Callable[..., T], method_sync_wrapper)
for attr in ["__is_start_method__", "__trigger_methods__", "__condition_type__", "__is_router__"]:
if hasattr(method, attr):
setattr(method_sync_wrapper, attr, getattr(method, attr))
method_sync_wrapper.__is_flow_method__ = True
return cast("Callable[..., T]", method_sync_wrapper)

return decorator
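For orientation, a minimal sketch of how the decorator above is applied at the class level; the counter flow is purely illustrative and the import paths are assumed from this branch's package layout, so treat this as a sketch rather than the canonical usage.

from crewai.flow.flow import Flow, start
from crewai.flow.persistence import persist
from crewai.flow.persistence.sqlite import SQLiteFlowPersistence

# Class-level @persist wraps every flow method, so state is saved after each
# method completes; "flow_states.db" is a placeholder path.
@persist(SQLiteFlowPersistence("flow_states.db"), verbose=True)
class CounterFlow(Flow[dict]):
    @start()
    def bump(self):
        self.state["count"] = self.state.get("count", 0) + 1
        return self.state["count"]

CounterFlow().kickoff()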
@@ -1,12 +1,10 @@
"""
SQLite-based implementation of flow state persistence.
"""
"""SQLite-based implementation of flow state persistence."""

import json
import sqlite3
from datetime import datetime, timezone
from pathlib import Path
from typing import Any, Dict, Optional, Union
from typing import Any

from pydantic import BaseModel

@@ -23,7 +21,7 @@ class SQLiteFlowPersistence(FlowPersistence):

db_path: str

def __init__(self, db_path: Optional[str] = None):
def __init__(self, db_path: str | None = None) -> None:
"""Initialize SQLite persistence.

Args:
@@ -32,6 +30,7 @@ class SQLiteFlowPersistence(FlowPersistence):

Raises:
ValueError: If db_path is invalid

"""
from crewai.utilities.paths import db_storage_path

@@ -39,7 +38,8 @@ class SQLiteFlowPersistence(FlowPersistence):
path = db_path or str(Path(db_storage_path()) / "flow_states.db")

if not path:
raise ValueError("Database path must be provided")
msg = "Database path must be provided"
raise ValueError(msg)

self.db_path = path  # Now mypy knows this is str
self.init_db()
@@ -56,21 +56,21 @@ class SQLiteFlowPersistence(FlowPersistence):
timestamp DATETIME NOT NULL,
state_json TEXT NOT NULL
)
"""
""",
)
# Add index for faster UUID lookups
conn.execute(
"""
CREATE INDEX IF NOT EXISTS idx_flow_states_uuid
ON flow_states(flow_uuid)
"""
""",
)

def save_state(
self,
flow_uuid: str,
method_name: str,
state_data: Union[Dict[str, Any], BaseModel],
state_data: dict[str, Any] | BaseModel,
) -> None:
"""Save the current flow state to SQLite.

@@ -78,6 +78,7 @@ class SQLiteFlowPersistence(FlowPersistence):
flow_uuid: Unique identifier for the flow instance
method_name: Name of the method that just completed
state_data: Current state data (either dict or Pydantic model)

"""
# Convert state_data to dict, handling both Pydantic and dict cases
if isinstance(state_data, BaseModel):
@@ -85,8 +86,9 @@ class SQLiteFlowPersistence(FlowPersistence):
elif isinstance(state_data, dict):
state_dict = state_data
else:
msg = f"state_data must be either a Pydantic BaseModel or dict, got {type(state_data)}"
raise ValueError(
f"state_data must be either a Pydantic BaseModel or dict, got {type(state_data)}"
msg,
)

with sqlite3.connect(self.db_path) as conn:
@@ -107,7 +109,7 @@ class SQLiteFlowPersistence(FlowPersistence):
),
)

def load_state(self, flow_uuid: str) -> Optional[Dict[str, Any]]:
def load_state(self, flow_uuid: str) -> dict[str, Any] | None:
"""Load the most recent state for a given flow UUID.

Args:
@@ -115,6 +117,7 @@ class SQLiteFlowPersistence(FlowPersistence):

Returns:
The most recent state as a dictionary, or None if no state exists

"""
with sqlite3.connect(self.db_path) as conn:
cursor = conn.execute(

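A short usage sketch of the class as it reads after these hunks; the import path is assumed from the repository layout and the UUID is a placeholder.

from crewai.flow.persistence.sqlite import SQLiteFlowPersistence

store = SQLiteFlowPersistence()  # falls back to <db_storage_path>/flow_states.db
store.save_state(
    flow_uuid="demo-uuid",      # placeholder identifier
    method_name="bump",
    state_data={"count": 1},    # dict or pydantic BaseModel are both accepted
)
print(store.load_state("demo-uuid"))  # -> {"count": 1}; None for unknown UUIDs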
@@ -1,33 +1,32 @@
"""
Utility functions for flow visualization and dependency analysis.
"""Utility functions for flow visualization and dependency analysis.

This module provides core functionality for analyzing and manipulating flow structures,
including node level calculation, ancestor tracking, and return value analysis.
Functions in this module are primarily used by the visualization system to create
accurate and informative flow diagrams.

Example
Example:
-------
>>> flow = Flow()
>>> node_levels = calculate_node_levels(flow)
>>> ancestors = build_ancestor_dict(flow)

"""

import ast
import inspect
import textwrap
from collections import defaultdict, deque
from typing import Any, Deque, Dict, List, Optional, Set, Union
from typing import Any


def get_possible_return_constants(function: Any) -> Optional[List[str]]:
def get_possible_return_constants(function: Any) -> list[str] | None:
try:
source = inspect.getsource(function)
except OSError:
# Can't get source code
return None
except Exception as e:
print(f"Error retrieving source code for function {function.__name__}: {e}")
except Exception:
return None

try:
@@ -35,24 +34,18 @@ def get_possible_return_constants(function: Any) -> Optional[List[str]]:
source = textwrap.dedent(source)
# Parse the source code into an AST
code_ast = ast.parse(source)
except IndentationError as e:
print(f"IndentationError while parsing source code of {function.__name__}: {e}")
print(f"Source code:\n{source}")
except IndentationError:
return None
except SyntaxError as e:
print(f"SyntaxError while parsing source code of {function.__name__}: {e}")
print(f"Source code:\n{source}")
except SyntaxError:
return None
except Exception as e:
print(f"Unexpected error while parsing source code of {function.__name__}: {e}")
print(f"Source code:\n{source}")
except Exception:
return None

return_values = set()
dict_definitions = {}

class DictionaryAssignmentVisitor(ast.NodeVisitor):
def visit_Assign(self, node):
def visit_Assign(self, node) -> None:
# Check if this assignment is assigning a dictionary literal to a variable
if isinstance(node.value, ast.Dict) and len(node.targets) == 1:
target = node.targets[0]
@@ -69,10 +62,10 @@ def get_possible_return_constants(function: Any) -> Optional[List[str]]:
self.generic_visit(node)

class ReturnVisitor(ast.NodeVisitor):
def visit_Return(self, node):
def visit_Return(self, node) -> None:
# Direct string return
if isinstance(node.value, ast.Constant) and isinstance(
node.value.value, str
node.value.value, str,
):
return_values.add(node.value.value)
# Dictionary-based return, like return paths[result]
@@ -94,9 +87,8 @@ def get_possible_return_constants(function: Any) -> Optional[List[str]]:
return list(return_values) if return_values else None


def calculate_node_levels(flow: Any) -> Dict[str, int]:
"""
Calculate the hierarchical level of each node in the flow.
def calculate_node_levels(flow: Any) -> dict[str, int]:
"""Calculate the hierarchical level of each node in the flow.

Performs a breadth-first traversal of the flow graph to assign levels
to nodes, starting with start methods at level 0.
@@ -117,11 +109,12 @@ def calculate_node_levels(flow: Any) -> Dict[str, int]:
- Each subsequent connected node is assigned level = parent_level + 1
- Handles both OR and AND conditions for listeners
- Processes router paths separately

"""
levels: Dict[str, int] = {}
queue: Deque[str] = deque()
visited: Set[str] = set()
pending_and_listeners: Dict[str, Set[str]] = {}
levels: dict[str, int] = {}
queue: deque[str] = deque()
visited: set[str] = set()
pending_and_listeners: dict[str, set[str]] = {}

# Make all start methods at level 0
for method_name, method in flow._methods.items():
@@ -172,9 +165,8 @@ def calculate_node_levels(flow: Any) -> Dict[str, int]:
return levels


def count_outgoing_edges(flow: Any) -> Dict[str, int]:
"""
Count the number of outgoing edges for each method in the flow.
def count_outgoing_edges(flow: Any) -> dict[str, int]:
"""Count the number of outgoing edges for each method in the flow.

Parameters
----------
@@ -185,6 +177,7 @@ def count_outgoing_edges(flow: Any) -> Dict[str, int]:
-------
Dict[str, int]
Dictionary mapping method names to their outgoing edge count.

"""
counts = {}
for method_name in flow._methods:
@@ -197,9 +190,8 @@ def count_outgoing_edges(flow: Any) -> Dict[str, int]:
return counts


def build_ancestor_dict(flow: Any) -> Dict[str, Set[str]]:
"""
Build a dictionary mapping each node to its ancestor nodes.
def build_ancestor_dict(flow: Any) -> dict[str, set[str]]:
"""Build a dictionary mapping each node to its ancestor nodes.

Parameters
----------
@@ -210,9 +202,10 @@ def build_ancestor_dict(flow: Any) -> Dict[str, Set[str]]:
-------
Dict[str, Set[str]]
Dictionary mapping each node to a set of its ancestor nodes.

"""
ancestors: Dict[str, Set[str]] = {node: set() for node in flow._methods}
visited: Set[str] = set()
ancestors: dict[str, set[str]] = {node: set() for node in flow._methods}
visited: set[str] = set()
for node in flow._methods:
if node not in visited:
dfs_ancestors(node, ancestors, visited, flow)
@@ -220,10 +213,9 @@ def build_ancestor_dict(flow: Any) -> Dict[str, Set[str]]:


def dfs_ancestors(
node: str, ancestors: Dict[str, Set[str]], visited: Set[str], flow: Any
node: str, ancestors: dict[str, set[str]], visited: set[str], flow: Any,
) -> None:
"""
Perform depth-first search to build ancestor relationships.
"""Perform depth-first search to build ancestor relationships.

Parameters
----------
@@ -240,6 +232,7 @@ def dfs_ancestors(
-----
This function modifies the ancestors dictionary in-place to build
the complete ancestor graph.

"""
if node in visited:
return
@@ -265,10 +258,9 @@ def dfs_ancestors(


def is_ancestor(
node: str, ancestor_candidate: str, ancestors: Dict[str, Set[str]]
node: str, ancestor_candidate: str, ancestors: dict[str, set[str]],
) -> bool:
"""
Check if one node is an ancestor of another.
"""Check if one node is an ancestor of another.

Parameters
----------
@@ -283,13 +275,13 @@ def is_ancestor(
-------
bool
True if ancestor_candidate is an ancestor of node, False otherwise.

"""
return ancestor_candidate in ancestors.get(node, set())


def build_parent_children_dict(flow: Any) -> Dict[str, List[str]]:
"""
Build a dictionary mapping parent nodes to their children.
def build_parent_children_dict(flow: Any) -> dict[str, list[str]]:
"""Build a dictionary mapping parent nodes to their children.

Parameters
----------
@@ -306,8 +298,9 @@ def build_parent_children_dict(flow: Any) -> Dict[str, List[str]]:
- Maps listeners to their trigger methods
- Maps router methods to their paths and listeners
- Children lists are sorted for consistent ordering

"""
parent_children: Dict[str, List[str]] = {}
parent_children: dict[str, list[str]] = {}

# Map listeners to their trigger methods
for listener_name, (_, trigger_methods) in flow._listeners.items():
@@ -332,10 +325,9 @@ def build_parent_children_dict(flow: Any) -> Dict[str, List[str]]:


def get_child_index(
parent: str, child: str, parent_children: Dict[str, List[str]]
parent: str, child: str, parent_children: dict[str, list[str]],
) -> int:
"""
Get the index of a child node in its parent's sorted children list.
"""Get the index of a child node in its parent's sorted children list.

Parameters
----------
@@ -350,27 +342,25 @@ def get_child_index(
-------
int
Zero-based index of the child in its parent's sorted children list.

"""
children = parent_children.get(parent, [])
children.sort()
return children.index(child)


def process_router_paths(flow, current, current_level, levels, queue):
"""
Handle the router connections for the current node.
"""
def process_router_paths(flow, current, current_level, levels, queue) -> None:
"""Handle the router connections for the current node."""
if current in flow._routers:
paths = flow._router_paths.get(current, [])
for path in paths:
for listener_name, (
condition_type,
_condition_type,
trigger_methods,
) in flow._listeners.items():
if path in trigger_methods:
if (
listener_name not in levels
or levels[listener_name] > current_level + 1
):
levels[listener_name] = current_level + 1
queue.append(listener_name)
if path in trigger_methods and (
listener_name not in levels
or levels[listener_name] > current_level + 1
):
levels[listener_name] = current_level + 1
queue.append(listener_name)

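The level-assignment loop in calculate_node_levels is easier to see on a toy graph. This standalone sketch mirrors the BFS relaxation visible in the hunks above (level = parent_level + 1, keeping the smallest level found), leaving out the OR/AND listener and router handling; the node names are illustrative.

from collections import deque

edges = {"begin": ["fetch"], "fetch": ["summarize"], "summarize": []}
levels: dict[str, int] = {"begin": 0}  # start methods sit at level 0
queue: deque[str] = deque(["begin"])
while queue:
    current = queue.popleft()
    for child in edges[current]:
        # Only relax when a shorter path to the child is found.
        if child not in levels or levels[child] > levels[current] + 1:
            levels[child] = levels[current] + 1
            queue.append(child)
assert levels == {"begin": 0, "fetch": 1, "summarize": 2}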
@@ -1,23 +1,23 @@
"""
Utilities for creating visual representations of flow structures.
"""Utilities for creating visual representations of flow structures.

This module provides functions for generating network visualizations of flows,
including node placement, edge creation, and visual styling. It handles the
conversion of flow structures into visual network graphs with appropriate
styling and layout.

Example
Example:
-------
>>> flow = Flow()
>>> net = Network(directed=True)
>>> node_positions = compute_positions(flow, node_levels)
>>> add_nodes_to_network(net, flow, node_positions, node_styles)
>>> add_edges(net, flow, node_positions, colors)

"""

import ast
import inspect
from typing import Any, Dict, List, Optional, Tuple, Union
from typing import Any

from .utils import (
build_ancestor_dict,
@@ -28,8 +28,7 @@ from .utils import (


def method_calls_crew(method: Any) -> bool:
"""
Check if the method contains a call to `.crew()`.
"""Check if the method contains a call to `.crew()`.

Parameters
----------
@@ -45,21 +44,22 @@ def method_calls_crew(method: Any) -> bool:
-----
Uses AST analysis to detect method calls, specifically looking for
attribute access of 'crew'.

"""
try:
source = inspect.getsource(method)
source = inspect.cleandoc(source)
tree = ast.parse(source)
except Exception as e:
print(f"Could not parse method {method.__name__}: {e}")
except Exception:
return False

class CrewCallVisitor(ast.NodeVisitor):
"""AST visitor to detect .crew() method calls."""
def __init__(self):

def __init__(self) -> None:
self.found = False

def visit_Call(self, node):
def visit_Call(self, node) -> None:
if isinstance(node.func, ast.Attribute):
if node.func.attr == "crew":
self.found = True
@@ -73,11 +73,10 @@ def method_calls_crew(method: Any) -> bool:
def add_nodes_to_network(
net: Any,
flow: Any,
node_positions: Dict[str, Tuple[float, float]],
node_styles: Dict[str, Dict[str, Any]]
node_positions: dict[str, tuple[float, float]],
node_styles: dict[str, dict[str, Any]],
) -> None:
"""
Add nodes to the network visualization with appropriate styling.
"""Add nodes to the network visualization with appropriate styling.

Parameters
----------
@@ -97,6 +96,7 @@ def add_nodes_to_network(
- Router methods
- Crew methods
- Regular methods

"""
def human_friendly_label(method_name):
return method_name.replace("_", " ").title()
@@ -123,7 +123,7 @@ def add_nodes_to_network(
"multi": "html",
"color": node_style.get("font", {}).get("color", "#FFFFFF"),
},
}
},
)

net.add_node(
@@ -138,12 +138,11 @@ def add_nodes_to_network(

def compute_positions(
flow: Any,
node_levels: Dict[str, int],
node_levels: dict[str, int],
y_spacing: float = 150,
x_spacing: float = 150
) -> Dict[str, Tuple[float, float]]:
"""
Compute the (x, y) positions for each node in the flow graph.
x_spacing: float = 150,
) -> dict[str, tuple[float, float]]:
"""Compute the (x, y) positions for each node in the flow graph.

Parameters
----------
@@ -160,9 +159,10 @@ def compute_positions(
-------
Dict[str, Tuple[float, float]]
Dictionary mapping node names to their (x, y) coordinates.

"""
level_nodes: Dict[int, List[str]] = {}
node_positions: Dict[str, Tuple[float, float]] = {}
level_nodes: dict[int, list[str]] = {}
node_positions: dict[str, tuple[float, float]] = {}

for method_name, level in node_levels.items():
level_nodes.setdefault(level, []).append(method_name)
@@ -180,10 +180,10 @@ def compute_positions(
def add_edges(
net: Any,
flow: Any,
node_positions: Dict[str, Tuple[float, float]],
colors: Dict[str, str]
node_positions: dict[str, tuple[float, float]],
colors: dict[str, str],
) -> None:
edge_smooth: Dict[str, Union[str, float]] = {"type": "continuous"}  # Default value
edge_smooth: dict[str, str | float] = {"type": "continuous"}  # Default value
"""
Add edges to the network visualization with appropriate styling.

@@ -245,7 +245,7 @@ def add_edges(
"color": edge_color,
"width": 2,
"arrows": "to",
"dashes": True if is_router_edge or is_and_condition else False,
"dashes": bool(is_router_edge or is_and_condition),
"smooth": edge_smooth,
}

@@ -261,9 +261,7 @@ def add_edges(
# If it's a known router edge and the method is known, don't warn.
# This means the path is legitimate, just not reflected as nodes here.
if not (is_router_edge and method_known):
print(
f"Warning: No node found for '{trigger}' or '{method_name}'. Skipping edge."
)
pass

# Edges for router return paths
for router_method_name, paths in flow._router_paths.items():
@@ -278,7 +276,7 @@ def add_edges(
and listener_name in node_positions
):
is_cycle_edge = is_ancestor(
router_method_name, listener_name, ancestors
router_method_name, listener_name, ancestors,
)
parent_has_multiple_children = (
len(parent_children.get(router_method_name, [])) > 1
@@ -293,7 +291,7 @@ def add_edges(
dx = target_pos[0] - source_pos[0]
smooth_type = "curvedCCW" if dx <= 0 else "curvedCW"
index = get_child_index(
router_method_name, listener_name, parent_children
router_method_name, listener_name, parent_children,
)
edge_smooth = {
"type": smooth_type,
@@ -316,6 +314,4 @@ def add_edges(
# Same check here: known router edge and known method?
method_known = listener_name in flow._methods
if not method_known:
print(
f"Warning: No node found for '{router_method_name}' or '{listener_name}'. Skipping edge."
)
pass

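The compute_positions hunks above only show the grouping step, so here is a hedged standalone sketch of the grid layout they imply: one row per level (y = level * y_spacing), with each row centered horizontally. The centering formula is an assumption, not lifted from the source.

def toy_positions(node_levels: dict[str, int], x_spacing: float = 150, y_spacing: float = 150) -> dict[str, tuple[float, float]]:
    level_nodes: dict[int, list[str]] = {}
    for method_name, level in node_levels.items():
        level_nodes.setdefault(level, []).append(method_name)
    positions: dict[str, tuple[float, float]] = {}
    for level, nodes in level_nodes.items():
        width = (len(nodes) - 1) * x_spacing  # total row width
        for i, name in enumerate(nodes):
            positions[name] = (i * x_spacing - width / 2, level * y_spacing)
    return positions

print(toy_positions({"begin": 0, "a": 1, "b": 1}))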
@@ -1,55 +1,48 @@
from abc import ABC, abstractmethod
from typing import List

import numpy as np


class BaseEmbedder(ABC):
"""
Abstract base class for text embedding models
"""
"""Abstract base class for text embedding models."""

@abstractmethod
def embed_chunks(self, chunks: List[str]) -> np.ndarray:
"""
Generate embeddings for a list of text chunks
def embed_chunks(self, chunks: list[str]) -> np.ndarray:
"""Generate embeddings for a list of text chunks.

Args:
chunks: List of text chunks to embed

Returns:
Array of embeddings

"""
pass

@abstractmethod
def embed_texts(self, texts: List[str]) -> np.ndarray:
"""
Generate embeddings for a list of texts
def embed_texts(self, texts: list[str]) -> np.ndarray:
"""Generate embeddings for a list of texts.

Args:
texts: List of texts to embed

Returns:
Array of embeddings

"""
pass

@abstractmethod
def embed_text(self, text: str) -> np.ndarray:
"""
Generate embedding for a single text
"""Generate embedding for a single text.

Args:
text: Text to embed

Returns:
Embedding array

"""
pass

@property
@abstractmethod
def dimension(self) -> int:
"""Get the dimension of the embeddings"""
pass
"""Get the dimension of the embeddings."""

|
||||
from pathlib import Path
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
@@ -19,75 +18,74 @@ except ImportError:
|
||||
|
||||
|
||||
class FastEmbed(BaseEmbedder):
|
||||
"""
|
||||
A wrapper class for text embedding models using FastEmbed
|
||||
"""
|
||||
"""A wrapper class for text embedding models using FastEmbed."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_name: str = "BAAI/bge-small-en-v1.5",
|
||||
cache_dir: Optional[Union[str, Path]] = None,
|
||||
):
|
||||
"""
|
||||
Initialize the embedding model
|
||||
cache_dir: str | Path | None = None,
|
||||
) -> None:
|
||||
"""Initialize the embedding model.
|
||||
|
||||
Args:
|
||||
model_name: Name of the model to use
|
||||
cache_dir: Directory to cache the model
|
||||
gpu: Whether to use GPU acceleration
|
||||
|
||||
"""
|
||||
if not FASTEMBED_AVAILABLE:
|
||||
raise ImportError(
|
||||
msg = (
|
||||
"FastEmbed is not installed. Please install it with: "
|
||||
"uv pip install fastembed or uv pip install fastembed-gpu for GPU support"
|
||||
)
|
||||
raise ImportError(
|
||||
msg,
|
||||
)
|
||||
|
||||
self.model = TextEmbedding(
|
||||
model_name=model_name,
|
||||
cache_dir=str(cache_dir) if cache_dir else None,
|
||||
)
|
||||
|
||||
def embed_chunks(self, chunks: List[str]) -> List[np.ndarray]:
|
||||
"""
|
||||
Generate embeddings for a list of text chunks
|
||||
def embed_chunks(self, chunks: list[str]) -> list[np.ndarray]:
|
||||
"""Generate embeddings for a list of text chunks.
|
||||
|
||||
Args:
|
||||
chunks: List of text chunks to embed
|
||||
|
||||
Returns:
|
||||
List of embeddings
|
||||
"""
|
||||
embeddings = list(self.model.embed(chunks))
|
||||
return embeddings
|
||||
|
||||
def embed_texts(self, texts: List[str]) -> List[np.ndarray]:
|
||||
"""
|
||||
Generate embeddings for a list of texts
|
||||
return list(self.model.embed(chunks))
|
||||
|
||||
def embed_texts(self, texts: list[str]) -> list[np.ndarray]:
|
||||
"""Generate embeddings for a list of texts.
|
||||
|
||||
Args:
|
||||
texts: List of texts to embed
|
||||
|
||||
Returns:
|
||||
List of embeddings
|
||||
|
||||
"""
|
||||
embeddings = list(self.model.embed(texts))
|
||||
return embeddings
|
||||
return list(self.model.embed(texts))
|
||||
|
||||
def embed_text(self, text: str) -> np.ndarray:
|
||||
"""
|
||||
Generate embedding for a single text
|
||||
"""Generate embedding for a single text.
|
||||
|
||||
Args:
|
||||
text: Text to embed
|
||||
|
||||
Returns:
|
||||
Embedding array
|
||||
|
||||
"""
|
||||
return self.embed_texts([text])[0]
|
||||
|
||||
@property
|
||||
def dimension(self) -> int:
|
||||
"""Get the dimension of the embeddings"""
|
||||
"""Get the dimension of the embeddings."""
|
||||
# Generate a test embedding to get dimensions
|
||||
test_embed = self.embed_text("test")
|
||||
return len(test_embed)
|
||||
|
||||
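Usage is unchanged by this hunk; a minimal sketch, assuming the optional fastembed dependency is installed and that the module path matches the repository layout.

from crewai.knowledge.embedder.fastembed import FastEmbed

embedder = FastEmbed(model_name="BAAI/bge-small-en-v1.5")
vector = embedder.embed_text("hello world")
print(len(vector), embedder.dimension)  # both report the embedding width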
@@ -1,5 +1,5 @@
import os
from typing import Any, Dict, List, Optional
from typing import Any

from pydantic import BaseModel, ConfigDict, Field

@@ -10,68 +10,70 @@ os.environ["TOKENIZERS_PARALLELISM"] = "false"  # removes logging from fastembed


class Knowledge(BaseModel):
"""
Knowledge is a collection of sources and setup for the vector store to save and query relevant context.
"""Knowledge is a collection of sources and setup for the vector store to save and query relevant context.

Args:
sources: List[BaseKnowledgeSource] = Field(default_factory=list)
storage: Optional[KnowledgeStorage] = Field(default=None)
embedder: Optional[Dict[str, Any]] = None
embedder: Optional[Dict[str, Any]] = None.

"""

sources: List[BaseKnowledgeSource] = Field(default_factory=list)
sources: list[BaseKnowledgeSource] = Field(default_factory=list)
model_config = ConfigDict(arbitrary_types_allowed=True)
storage: Optional[KnowledgeStorage] = Field(default=None)
embedder: Optional[Dict[str, Any]] = None
collection_name: Optional[str] = None
storage: KnowledgeStorage | None = Field(default=None)
embedder: dict[str, Any] | None = None
collection_name: str | None = None

def __init__(
self,
collection_name: str,
sources: List[BaseKnowledgeSource],
embedder: Optional[Dict[str, Any]] = None,
storage: Optional[KnowledgeStorage] = None,
sources: list[BaseKnowledgeSource],
embedder: dict[str, Any] | None = None,
storage: KnowledgeStorage | None = None,
**data,
):
) -> None:
super().__init__(**data)
if storage:
self.storage = storage
else:
self.storage = KnowledgeStorage(
embedder=embedder, collection_name=collection_name
embedder=embedder, collection_name=collection_name,
)
self.sources = sources
self.storage.initialize_knowledge_storage()

def query(
self, query: List[str], results_limit: int = 3, score_threshold: float = 0.35
) -> List[Dict[str, Any]]:
"""
Query across all knowledge sources to find the most relevant information.
self, query: list[str], results_limit: int = 3, score_threshold: float = 0.35,
) -> list[dict[str, Any]]:
"""Query across all knowledge sources to find the most relevant information.
Returns the top_k most relevant chunks.

Raises:
ValueError: If storage is not initialized.

"""
if self.storage is None:
raise ValueError("Storage is not initialized.")
msg = "Storage is not initialized."
raise ValueError(msg)

results = self.storage.search(
return self.storage.search(
query,
limit=results_limit,
score_threshold=score_threshold,
)
return results

def add_sources(self):
def add_sources(self) -> None:
try:
for source in self.sources:
source.storage = self.storage
source.add()
except Exception as e:
raise e
except Exception:
raise

def reset(self) -> None:
if self.storage:
self.storage.reset()
else:
raise ValueError("Storage is not initialized.")
msg = "Storage is not initialized."
raise ValueError(msg)

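A minimal end-to-end sketch of the class after these hunks; the import paths are assumed from the repository layout, and an embedding backend must be available for the underlying storage to initialize.

from crewai.knowledge.knowledge import Knowledge
from crewai.knowledge.source.string_knowledge_source import StringKnowledgeSource

source = StringKnowledgeSource(content="CrewAI flows can persist state in SQLite.")
kb = Knowledge(collection_name="docs", sources=[source])
kb.add_sources()  # chunk, embed, and store every source
print(kb.query(["How is flow state persisted?"], results_limit=3))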
@@ -7,6 +7,7 @@ class KnowledgeConfig(BaseModel):
Args:
results_limit (int): The number of relevant documents to return.
score_threshold (float): The minimum score for a document to be considered relevant.

"""

results_limit: int = Field(default=3, description="The number of results to return")

@@ -1,6 +1,5 @@
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Dict, List, Optional, Union

from pydantic import Field, field_validator

@@ -14,43 +13,43 @@ class BaseFileKnowledgeSource(BaseKnowledgeSource, ABC):
"""Base class for knowledge sources that load content from files."""

_logger: Logger = Logger(verbose=True)
file_path: Optional[Union[Path, List[Path], str, List[str]]] = Field(
file_path: Path | list[Path] | str | list[str] | None = Field(
default=None,
description="[Deprecated] The path to the file. Use file_paths instead.",
)
file_paths: Optional[Union[Path, List[Path], str, List[str]]] = Field(
default_factory=list, description="The path to the file"
file_paths: Path | list[Path] | str | list[str] | None = Field(
default_factory=list, description="The path to the file",
)
content: Dict[Path, str] = Field(init=False, default_factory=dict)
storage: Optional[KnowledgeStorage] = Field(default=None)
safe_file_paths: List[Path] = Field(default_factory=list)
content: dict[Path, str] = Field(init=False, default_factory=dict)
storage: KnowledgeStorage | None = Field(default=None)
safe_file_paths: list[Path] = Field(default_factory=list)

@field_validator("file_path", "file_paths", mode="before")
def validate_file_path(cls, v, info):
def validate_file_path(self, v, info):
"""Validate that at least one of file_path or file_paths is provided."""
# Single check if both are None, O(1) instead of nested conditions
if (
v is None
and info.data.get(
"file_path" if info.field_name == "file_paths" else "file_paths"
"file_path" if info.field_name == "file_paths" else "file_paths",
)
is None
):
raise ValueError("Either file_path or file_paths must be provided")
msg = "Either file_path or file_paths must be provided"
raise ValueError(msg)
return v

def model_post_init(self, _):
def model_post_init(self, _) -> None:
"""Post-initialization method to load content."""
self.safe_file_paths = self._process_file_paths()
self.validate_content()
self.content = self.load_content()

@abstractmethod
def load_content(self) -> Dict[Path, str]:
def load_content(self) -> dict[Path, str]:
"""Load and preprocess file content. Should be overridden by subclasses. Assume that the file path is relative to the project root in the knowledge directory."""
pass

def validate_content(self):
def validate_content(self) -> None:
"""Validate the paths."""
for path in self.safe_file_paths:
if not path.exists():
@@ -59,7 +58,8 @@ class BaseFileKnowledgeSource(BaseKnowledgeSource, ABC):
f"File not found: {path}. Try adding sources to the knowledge directory. If it's inside the knowledge directory, use the relative path.",
color="red",
)
raise FileNotFoundError(f"File not found: {path}")
msg = f"File not found: {path}"
raise FileNotFoundError(msg)
if not path.is_file():
self._logger.log(
"error",
@@ -67,20 +67,20 @@ class BaseFileKnowledgeSource(BaseKnowledgeSource, ABC):
color="red",
)

def _save_documents(self):
def _save_documents(self) -> None:
"""Save the documents to the storage."""
if self.storage:
self.storage.save(self.chunks)
else:
raise ValueError("No storage found to save documents.")
msg = "No storage found to save documents."
raise ValueError(msg)

def convert_to_path(self, path: Union[Path, str]) -> Path:
def convert_to_path(self, path: Path | str) -> Path:
"""Convert a path to a Path object."""
return Path(KNOWLEDGE_DIRECTORY + "/" + path) if isinstance(path, str) else path

def _process_file_paths(self) -> List[Path]:
def _process_file_paths(self) -> list[Path]:
"""Convert file_path to a list of Path objects."""

if hasattr(self, "file_path") and self.file_path is not None:
self._logger.log(
"warning",
@@ -90,10 +90,11 @@ class BaseFileKnowledgeSource(BaseKnowledgeSource, ABC):
self.file_paths = self.file_path

if self.file_paths is None:
raise ValueError("Your source must be provided with a file_paths: []")
msg = "Your source must be provided with a file_paths: []"
raise ValueError(msg)

# Convert single path to list
path_list: List[Union[Path, str]] = (
path_list: list[Path | str] = (
[self.file_paths]
if isinstance(self.file_paths, (str, Path))
else list(self.file_paths)
@@ -102,8 +103,9 @@ class BaseFileKnowledgeSource(BaseKnowledgeSource, ABC):
)

if not path_list:
msg = "file_path/file_paths must be a Path, str, or a list of these types"
raise ValueError(
"file_path/file_paths must be a Path, str, or a list of these types"
msg,
)

return [self.convert_to_path(path) for path in path_list]

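The convert_to_path helper above is small enough to restate standalone; KNOWLEDGE_DIRECTORY comes from crewai.utilities.constants and is assumed here as the literal "knowledge".

from pathlib import Path

KNOWLEDGE_DIRECTORY = "knowledge"  # stand-in for crewai.utilities.constants

def convert_to_path(path: Path | str) -> Path:
    # Bare strings resolve inside the knowledge directory; Path objects pass through.
    return Path(KNOWLEDGE_DIRECTORY + "/" + path) if isinstance(path, str) else path

assert convert_to_path("faq.txt") == Path("knowledge/faq.txt")
assert convert_to_path(Path("/abs/faq.txt")) == Path("/abs/faq.txt")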
@@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional
from typing import Any

import numpy as np
from pydantic import BaseModel, ConfigDict, Field
@@ -12,41 +12,39 @@ class BaseKnowledgeSource(BaseModel, ABC):

chunk_size: int = 4000
chunk_overlap: int = 200
chunks: List[str] = Field(default_factory=list)
chunk_embeddings: List[np.ndarray] = Field(default_factory=list)
chunks: list[str] = Field(default_factory=list)
chunk_embeddings: list[np.ndarray] = Field(default_factory=list)

model_config = ConfigDict(arbitrary_types_allowed=True)
storage: Optional[KnowledgeStorage] = Field(default=None)
metadata: Dict[str, Any] = Field(default_factory=dict)  # Currently unused
collection_name: Optional[str] = Field(default=None)
storage: KnowledgeStorage | None = Field(default=None)
metadata: dict[str, Any] = Field(default_factory=dict)  # Currently unused
collection_name: str | None = Field(default=None)

@abstractmethod
def validate_content(self) -> Any:
"""Load and preprocess content from the source."""
pass

@abstractmethod
def add(self) -> None:
"""Process content, chunk it, compute embeddings, and save them."""
pass

def get_embeddings(self) -> List[np.ndarray]:
def get_embeddings(self) -> list[np.ndarray]:
"""Return the list of embeddings for the chunks."""
return self.chunk_embeddings

def _chunk_text(self, text: str) -> List[str]:
def _chunk_text(self, text: str) -> list[str]:
"""Utility method to split text into chunks."""
return [
text[i : i + self.chunk_size]
for i in range(0, len(text), self.chunk_size - self.chunk_overlap)
]

def _save_documents(self):
"""
Save the documents to the storage.
def _save_documents(self) -> None:
"""Save the documents to the storage.
This method should be called after the chunks and embeddings are generated.
"""
if self.storage:
self.storage.save(self.chunks)
else:
raise ValueError("No storage found to save documents.")
msg = "No storage found to save documents."
raise ValueError(msg)

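The sliding-window chunker shared by these sources deserves a worked example: windows of chunk_size characters advance by chunk_size - chunk_overlap, so consecutive chunks share chunk_overlap characters.

chunk_size, chunk_overlap = 10, 3
text = "abcdefghijklmnopqrstuvwxyz"
chunks = [
    text[i : i + chunk_size]
    for i in range(0, len(text), chunk_size - chunk_overlap)
]
print(chunks)  # ['abcdefghij', 'hijklmnopq', 'opqrstuvwx', 'vwxyz']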
@@ -1,5 +1,6 @@
from collections.abc import Iterator
from pathlib import Path
from typing import Iterator, List, Optional, Union
from typing import TYPE_CHECKING
from urllib.parse import urlparse

try:
@@ -7,7 +8,6 @@ try:
from docling.document_converter import DocumentConverter
from docling.exceptions import ConversionError
from docling_core.transforms.chunker.hierarchical_chunker import HierarchicalChunker
from docling_core.types.doc.document import DoclingDocument

DOCLING_AVAILABLE = True
except ImportError:
@@ -19,27 +19,33 @@ from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource
from crewai.utilities.constants import KNOWLEDGE_DIRECTORY
from crewai.utilities.logger import Logger

if TYPE_CHECKING:
from docling_core.types.doc.document import DoclingDocument


class CrewDoclingSource(BaseKnowledgeSource):
"""Default Source class for converting documents to markdown or json
This will auto support PDF, DOCX, and TXT, XLSX, Images, and HTML files without any additional dependencies and follows the docling package as the source of truth.
"""

def __init__(self, *args, **kwargs):
def __init__(self, *args, **kwargs) -> None:
if not DOCLING_AVAILABLE:
raise ImportError(
msg = (
"The docling package is required to use CrewDoclingSource. "
"Please install it using: uv add docling"
)
raise ImportError(
msg,
)
super().__init__(*args, **kwargs)

_logger: Logger = Logger(verbose=True)

file_path: Optional[List[Union[Path, str]]] = Field(default=None)
file_paths: List[Union[Path, str]] = Field(default_factory=list)
chunks: List[str] = Field(default_factory=list)
safe_file_paths: List[Union[Path, str]] = Field(default_factory=list)
content: List["DoclingDocument"] = Field(default_factory=list)
file_path: list[Path | str] | None = Field(default=None)
file_paths: list[Path | str] = Field(default_factory=list)
chunks: list[str] = Field(default_factory=list)
safe_file_paths: list[Path | str] = Field(default_factory=list)
content: list["DoclingDocument"] = Field(default_factory=list)
document_converter: "DocumentConverter" = Field(
default_factory=lambda: DocumentConverter(
allowed_formats=[
@@ -51,8 +57,8 @@ class CrewDoclingSource(BaseKnowledgeSource):
InputFormat.IMAGE,
InputFormat.XLSX,
InputFormat.PPTX,
]
)
],
),
)

def model_post_init(self, _) -> None:
@@ -66,7 +72,7 @@ class CrewDoclingSource(BaseKnowledgeSource):
self.safe_file_paths = self.validate_content()
self.content = self._load_content()

def _load_content(self) -> List["DoclingDocument"]:
def _load_content(self) -> list["DoclingDocument"]:
try:
return self._convert_source_to_docling_documents()
except ConversionError as e:
@@ -75,10 +81,10 @@ class CrewDoclingSource(BaseKnowledgeSource):
f"Error loading content: {e}. Supported formats: {self.document_converter.allowed_formats}",
"red",
)
raise e
raise
except Exception as e:
self._logger.log("error", f"Error loading content: {e}")
raise e
raise

def add(self) -> None:
if self.content is None:
@@ -88,7 +94,7 @@ class CrewDoclingSource(BaseKnowledgeSource):
self.chunks.extend(list(new_chunks_iterable))
self._save_documents()

def _convert_source_to_docling_documents(self) -> List["DoclingDocument"]:
def _convert_source_to_docling_documents(self) -> list["DoclingDocument"]:
conv_results_iter = self.document_converter.convert_all(self.safe_file_paths)
return [result.document for result in conv_results_iter]

@@ -97,8 +103,8 @@ class CrewDoclingSource(BaseKnowledgeSource):
for chunk in chunker.chunk(doc):
yield chunk.text

def validate_content(self) -> List[Union[Path, str]]:
processed_paths: List[Union[Path, str]] = []
def validate_content(self) -> list[Path | str]:
processed_paths: list[Path | str] = []
for path in self.file_paths:
if isinstance(path, str):
if path.startswith(("http://", "https://")):
@@ -106,15 +112,18 @@ class CrewDoclingSource(BaseKnowledgeSource):
if self._validate_url(path):
processed_paths.append(path)
else:
raise ValueError(f"Invalid URL format: {path}")
msg = f"Invalid URL format: {path}"
raise ValueError(msg)
except Exception as e:
raise ValueError(f"Invalid URL: {path}. Error: {str(e)}")
msg = f"Invalid URL: {path}. Error: {e!s}"
raise ValueError(msg)
else:
local_path = Path(KNOWLEDGE_DIRECTORY + "/" + path)
if local_path.exists():
processed_paths.append(local_path)
else:
raise FileNotFoundError(f"File not found: {local_path}")
msg = f"File not found: {local_path}"
raise FileNotFoundError(msg)
else:
# this is an instance of Path
processed_paths.append(path)
@@ -128,7 +137,7 @@ class CrewDoclingSource(BaseKnowledgeSource):
result.scheme in ("http", "https"),
result.netloc,
len(result.netloc.split(".")) >= 2,  # Ensure domain has TLD
]
],
)
except Exception:
return False

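The _validate_url check at the end of this file is self-contained; restated standalone, it accepts only http(s) URLs whose host contains at least one dot (a crude TLD check).

from urllib.parse import urlparse

def validate_url(url: str) -> bool:
    try:
        result = urlparse(url)
        return all(
            [
                result.scheme in ("http", "https"),
                result.netloc,
                len(result.netloc.split(".")) >= 2,  # ensure domain has a TLD
            ]
        )
    except Exception:
        return False

assert validate_url("https://docs.crewai.com/guide")
assert not validate_url("ftp://example.com/file")
assert not validate_url("https://localhost")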
@@ -1,6 +1,5 @@
import csv
from pathlib import Path
from typing import Dict, List

from crewai.knowledge.source.base_file_knowledge_source import BaseFileKnowledgeSource

@@ -8,11 +7,11 @@ from crewai.knowledge.source.base_file_knowledge_source import BaseFileKnowledge
class CSVKnowledgeSource(BaseFileKnowledgeSource):
"""A knowledge source that stores and queries CSV file content using embeddings."""

def load_content(self) -> Dict[Path, str]:
def load_content(self) -> dict[Path, str]:
"""Load and preprocess CSV file content."""
content_dict = {}
for file_path in self.safe_file_paths:
with open(file_path, "r", encoding="utf-8") as csvfile:
with open(file_path, encoding="utf-8") as csvfile:
reader = csv.reader(csvfile)
content = ""
for row in reader:
@@ -21,8 +20,7 @@ class CSVKnowledgeSource(BaseFileKnowledgeSource):
return content_dict

def add(self) -> None:
"""
Add CSV file content to the knowledge source, chunk it, compute embeddings,
"""Add CSV file content to the knowledge source, chunk it, compute embeddings,
and save the embeddings.
"""
content_str = (
@@ -32,7 +30,7 @@ class CSVKnowledgeSource(BaseFileKnowledgeSource):
self.chunks.extend(new_chunks)
self._save_documents()

def _chunk_text(self, text: str) -> List[str]:
def _chunk_text(self, text: str) -> list[str]:
"""Utility method to split text into chunks."""
return [
text[i : i + self.chunk_size]

@@ -1,6 +1,4 @@
from pathlib import Path
from typing import Dict, Iterator, List, Optional, Union
from urllib.parse import urlparse

from pydantic import Field, field_validator

@@ -16,34 +14,34 @@ class ExcelKnowledgeSource(BaseKnowledgeSource):

_logger: Logger = Logger(verbose=True)

file_path: Optional[Union[Path, List[Path], str, List[str]]] = Field(
file_path: Path | list[Path] | str | list[str] | None = Field(
default=None,
description="[Deprecated] The path to the file. Use file_paths instead.",
)
file_paths: Optional[Union[Path, List[Path], str, List[str]]] = Field(
default_factory=list, description="The path to the file"
file_paths: Path | list[Path] | str | list[str] | None = Field(
default_factory=list, description="The path to the file",
)
chunks: List[str] = Field(default_factory=list)
content: Dict[Path, Dict[str, str]] = Field(default_factory=dict)
safe_file_paths: List[Path] = Field(default_factory=list)
chunks: list[str] = Field(default_factory=list)
content: dict[Path, dict[str, str]] = Field(default_factory=dict)
safe_file_paths: list[Path] = Field(default_factory=list)

@field_validator("file_path", "file_paths", mode="before")
def validate_file_path(cls, v, info):
def validate_file_path(self, v, info):
"""Validate that at least one of file_path or file_paths is provided."""
# Single check if both are None, O(1) instead of nested conditions
if (
v is None
and info.data.get(
"file_path" if info.field_name == "file_paths" else "file_paths"
"file_path" if info.field_name == "file_paths" else "file_paths",
)
is None
):
raise ValueError("Either file_path or file_paths must be provided")
msg = "Either file_path or file_paths must be provided"
raise ValueError(msg)
return v

def _process_file_paths(self) -> List[Path]:
def _process_file_paths(self) -> list[Path]:
"""Convert file_path to a list of Path objects."""

if hasattr(self, "file_path") and self.file_path is not None:
self._logger.log(
"warning",
@@ -53,10 +51,11 @@ class ExcelKnowledgeSource(BaseKnowledgeSource):
self.file_paths = self.file_path

if self.file_paths is None:
raise ValueError("Your source must be provided with a file_paths: []")
msg = "Your source must be provided with a file_paths: []"
raise ValueError(msg)

# Convert single path to list
path_list: List[Union[Path, str]] = (
path_list: list[Path | str] = (
[self.file_paths]
if isinstance(self.file_paths, (str, Path))
else list(self.file_paths)
@@ -65,13 +64,14 @@ class ExcelKnowledgeSource(BaseKnowledgeSource):
)

if not path_list:
msg = "file_path/file_paths must be a Path, str, or a list of these types"
raise ValueError(
"file_path/file_paths must be a Path, str, or a list of these types"
msg,
)

return [self.convert_to_path(path) for path in path_list]

def validate_content(self):
def validate_content(self) -> None:
"""Validate the paths."""
for path in self.safe_file_paths:
if not path.exists():
@@ -80,7 +80,8 @@ class ExcelKnowledgeSource(BaseKnowledgeSource):
f"File not found: {path}. Try adding sources to the knowledge directory. If it's inside the knowledge directory, use the relative path.",
color="red",
)
raise FileNotFoundError(f"File not found: {path}")
msg = f"File not found: {path}"
raise FileNotFoundError(msg)
if not path.is_file():
self._logger.log(
"error",
@@ -100,7 +101,7 @@ class ExcelKnowledgeSource(BaseKnowledgeSource):
self.validate_content()
self.content = self._load_content()

def _load_content(self) -> Dict[Path, Dict[str, str]]:
def _load_content(self) -> dict[Path, dict[str, str]]:
"""Load and preprocess Excel file content from multiple sheets.

Each sheet's content is converted to CSV format and stored.
@@ -111,6 +112,7 @@ class ExcelKnowledgeSource(BaseKnowledgeSource):
Raises:
ImportError: If required dependencies are missing.
FileNotFoundError: If the specified Excel file cannot be opened.

"""
pd = self._import_dependencies()
content_dict = {}
@@ -119,14 +121,14 @@ class ExcelKnowledgeSource(BaseKnowledgeSource):
with pd.ExcelFile(file_path) as xl:
sheet_dict = {
str(sheet_name): str(
pd.read_excel(xl, sheet_name).to_csv(index=False)
pd.read_excel(xl, sheet_name).to_csv(index=False),
)
for sheet_name in xl.sheet_names
}
content_dict[file_path] = sheet_dict
return content_dict

def convert_to_path(self, path: Union[Path, str]) -> Path:
def convert_to_path(self, path: Path | str) -> Path:
"""Convert a path to a Path object."""
return Path(KNOWLEDGE_DIRECTORY + "/" + path) if isinstance(path, str) else path

@@ -138,13 +140,13 @@ class ExcelKnowledgeSource(BaseKnowledgeSource):
return pd
except ImportError as e:
missing_package = str(e).split()[-1]
msg = f"{missing_package} is not installed. Please install it with: pip install {missing_package}"
raise ImportError(
f"{missing_package} is not installed. Please install it with: pip install {missing_package}"
msg,
)

def add(self) -> None:
"""
Add Excel file content to the knowledge source, chunk it, compute embeddings,
"""Add Excel file content to the knowledge source, chunk it, compute embeddings,
and save the embeddings.
"""
# Convert dictionary values to a single string if content is a dictionary
@@ -161,7 +163,7 @@ class ExcelKnowledgeSource(BaseKnowledgeSource):
self.chunks.extend(new_chunks)
self._save_documents()

def _chunk_text(self, text: str) -> List[str]:
def _chunk_text(self, text: str) -> list[str]:
"""Utility method to split text into chunks."""
return [
text[i : i + self.chunk_size]

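The reworked _load_content renders every sheet to a CSV string keyed by sheet name; the same pattern works standalone with pandas ("report.xlsx" is a placeholder path).

import pandas as pd

with pd.ExcelFile("report.xlsx") as xl:
    # One CSV string per sheet, keyed by sheet name, as in the hunk above.
    sheet_dict = {
        str(sheet_name): str(pd.read_excel(xl, sheet_name).to_csv(index=False))
        for sheet_name in xl.sheet_names
    }
print(list(sheet_dict))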
@@ -1,6 +1,6 @@
import json
from pathlib import Path
from typing import Any, Dict, List
from typing import Any

from crewai.knowledge.source.base_file_knowledge_source import BaseFileKnowledgeSource

@@ -8,12 +8,12 @@ from crewai.knowledge.source.base_file_knowledge_source import BaseFileKnowledge
class JSONKnowledgeSource(BaseFileKnowledgeSource):
"""A knowledge source that stores and queries JSON file content using embeddings."""

def load_content(self) -> Dict[Path, str]:
def load_content(self) -> dict[Path, str]:
"""Load and preprocess JSON file content."""
content: Dict[Path, str] = {}
content: dict[Path, str] = {}
for path in self.safe_file_paths:
path = self.convert_to_path(path)
with open(path, "r", encoding="utf-8") as json_file:
with open(path, encoding="utf-8") as json_file:
data = json.load(json_file)
content[path] = self._json_to_text(data)
return content
@@ -29,12 +29,11 @@ class JSONKnowledgeSource(BaseFileKnowledgeSource):
for item in data:
text += f"{indent}- {self._json_to_text(item, level + 1)}\n"
else:
text += f"{str(data)}"
text += f"{data!s}"
return text

def add(self) -> None:
"""
Add JSON file content to the knowledge source, chunk it, compute embeddings,
"""Add JSON file content to the knowledge source, chunk it, compute embeddings,
and save the embeddings.
"""
content_str = (
@@ -44,7 +43,7 @@ class JSONKnowledgeSource(BaseFileKnowledgeSource):
self.chunks.extend(new_chunks)
self._save_documents()

def _chunk_text(self, text: str) -> List[str]:
def _chunk_text(self, text: str) -> list[str]:
"""Utility method to split text into chunks."""
return [
text[i : i + self.chunk_size]

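Only the list and scalar branches of _json_to_text appear in this hunk; a standalone sketch of the whole converter, with the dict branch reconstructed as an assumption, looks like this.

def json_to_text(data, level: int = 0) -> str:
    indent = "  " * level
    text = ""
    if isinstance(data, dict):  # assumed branch, not shown in the hunk
        for key, value in data.items():
            text += f"{indent}{key}: {json_to_text(value, level + 1)}\n"
    elif isinstance(data, list):
        for item in data:
            text += f"{indent}- {json_to_text(item, level + 1)}\n"
    else:
        text += f"{data!s}"
    return text

print(json_to_text({"name": "crewAI", "tags": ["flows", "knowledge"]}))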
@@ -1,5 +1,4 @@
from pathlib import Path
from typing import Dict, List

from crewai.knowledge.source.base_file_knowledge_source import BaseFileKnowledgeSource

@@ -7,7 +6,7 @@ from crewai.knowledge.source.base_file_knowledge_source import BaseFileKnowledge
class PDFKnowledgeSource(BaseFileKnowledgeSource):
"""A knowledge source that stores and queries PDF file content using embeddings."""

def load_content(self) -> Dict[Path, str]:
def load_content(self) -> dict[Path, str]:
"""Load and preprocess PDF file content."""
pdfplumber = self._import_pdfplumber()

@@ -31,21 +30,21 @@ class PDFKnowledgeSource(BaseFileKnowledgeSource):

return pdfplumber
except ImportError:
msg = "pdfplumber is not installed. Please install it with: pip install pdfplumber"
raise ImportError(
"pdfplumber is not installed. Please install it with: pip install pdfplumber"
msg,
)

def add(self) -> None:
"""
Add PDF file content to the knowledge source, chunk it, compute embeddings,
"""Add PDF file content to the knowledge source, chunk it, compute embeddings,
and save the embeddings.
"""
for _, text in self.content.items():
for text in self.content.values():
new_chunks = self._chunk_text(text)
self.chunks.extend(new_chunks)
self._save_documents()

def _chunk_text(self, text: str) -> List[str]:
def _chunk_text(self, text: str) -> list[str]:
"""Utility method to split text into chunks."""
return [
text[i : i + self.chunk_size]

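The hunks here only touch the import guard and chunking, but the per-page extraction such a source typically performs looks roughly like the following: a hedged sketch using pdfplumber's documented API, with "paper.pdf" as a placeholder path.

import pdfplumber

text = ""
with pdfplumber.open("paper.pdf") as pdf:
    for page in pdf.pages:
        page_text = page.extract_text()
        if page_text:  # extract_text returns None for empty pages
            text += page_text + "\n"
print(len(text))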
@@ -1,4 +1,3 @@
from typing import List, Optional

from pydantic import Field

@@ -9,16 +8,17 @@ class StringKnowledgeSource(BaseKnowledgeSource):
"""A knowledge source that stores and queries plain text content using embeddings."""

content: str = Field(...)
collection_name: Optional[str] = Field(default=None)
collection_name: str | None = Field(default=None)

def model_post_init(self, _):
def model_post_init(self, _) -> None:
"""Post-initialization method to validate content."""
self.validate_content()

def validate_content(self):
def validate_content(self) -> None:
"""Validate string content."""
if not isinstance(self.content, str):
raise ValueError("StringKnowledgeSource only accepts string content")
msg = "StringKnowledgeSource only accepts string content"
raise ValueError(msg)

def add(self) -> None:
"""Add string content to the knowledge source, chunk it, compute embeddings, and save them."""
@@ -26,7 +26,7 @@ class StringKnowledgeSource(BaseKnowledgeSource):
self.chunks.extend(new_chunks)
self._save_documents()

def _chunk_text(self, text: str) -> List[str]:
def _chunk_text(self, text: str) -> list[str]:
"""Utility method to split text into chunks."""
return [
text[i : i + self.chunk_size]

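A minimal usage sketch for the source after this hunk; the import path is assumed from the repository layout.

from crewai.knowledge.source.string_knowledge_source import StringKnowledgeSource

source = StringKnowledgeSource(content="Plain text becomes embeddable chunks.")
# validate_content() has already run in model_post_init; non-str content raises ValueError.
print(source.content)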
@@ -1,5 +1,4 @@
 from pathlib import Path
-from typing import Dict, List

 from crewai.knowledge.source.base_file_knowledge_source import BaseFileKnowledgeSource

@@ -7,26 +6,25 @@ from crewai.knowledge.source.base_file_knowledge_source import BaseFileKnowledge
 class TextFileKnowledgeSource(BaseFileKnowledgeSource):
     """A knowledge source that stores and queries text file content using embeddings."""

-    def load_content(self) -> Dict[Path, str]:
+    def load_content(self) -> dict[Path, str]:
         """Load and preprocess text file content."""
         content = {}
         for path in self.safe_file_paths:
             path = self.convert_to_path(path)
-            with open(path, "r", encoding="utf-8") as f:
+            with open(path, encoding="utf-8") as f:
                 content[path] = f.read()
         return content

     def add(self) -> None:
-        """
-        Add text file content to the knowledge source, chunk it, compute embeddings,
+        """Add text file content to the knowledge source, chunk it, compute embeddings,
         and save the embeddings.
         """
-        for _, text in self.content.items():
+        for text in self.content.values():
             new_chunks = self._chunk_text(text)
             self.chunks.extend(new_chunks)
         self._save_documents()

-    def _chunk_text(self, text: str) -> List[str]:
+    def _chunk_text(self, text: str) -> list[str]:
         """Utility method to split text into chunks."""
         return [
             text[i : i + self.chunk_size]
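Dropping the explicit `"r"` matches Ruff's UP015, since read mode is `open`'s default. Given that `path` is already a `Path` here, an equivalent alternative (an assumption on my part, not what this diff does) would be `Path.read_text`:

    from pathlib import Path

    def load_all(paths: list[Path]) -> dict[Path, str]:
        # read_text opens, reads, and closes the file in one call
        return {path: path.read_text(encoding="utf-8") for path in paths}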
@@ -1,5 +1,5 @@
 from abc import ABC, abstractmethod
-from typing import Any, Dict, List, Optional
+from typing import Any


 class BaseKnowledgeStorage(ABC):
@@ -8,22 +8,19 @@ class BaseKnowledgeStorage(ABC):
     @abstractmethod
     def search(
         self,
-        query: List[str],
+        query: list[str],
         limit: int = 3,
-        filter: Optional[dict] = None,
+        filter: dict | None = None,
         score_threshold: float = 0.35,
-    ) -> List[Dict[str, Any]]:
+    ) -> list[dict[str, Any]]:
         """Search for documents in the knowledge base."""
-        pass

     @abstractmethod
     def save(
-        self, documents: List[str], metadata: Dict[str, Any] | List[Dict[str, Any]]
+        self, documents: list[str], metadata: dict[str, Any] | list[dict[str, Any]],
     ) -> None:
         """Save documents to the knowledge base."""
-        pass

     @abstractmethod
     def reset(self) -> None:
         """Reset the knowledge base."""
-        pass
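Beyond the retyping, the three abstract methods lose their redundant `pass` statements, since a docstring is already a valid method body. To make the contract concrete, here is a toy in-memory implementation — purely illustrative and not part of the diff — showing what a subclass must provide:

    from typing import Any

    class InMemoryKnowledgeStorage(BaseKnowledgeStorage):
        """Illustrative stand-in: substring matching instead of embeddings."""

        def __init__(self) -> None:
            self._docs: list[str] = []

        def search(
            self,
            query: list[str],
            limit: int = 3,
            filter: dict | None = None,
            score_threshold: float = 0.35,
        ) -> list[dict[str, Any]]:
            hits = [
                {"context": doc, "score": 1.0}
                for doc in self._docs
                if any(q.lower() in doc.lower() for q in query)
            ]
            return hits[:limit]

        def save(
            self, documents: list[str], metadata: dict[str, Any] | list[dict[str, Any]],
        ) -> None:
            self._docs.extend(documents)

        def reset(self) -> None:
            self._docs.clear()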
@@ -4,12 +4,11 @@ import io
 import logging
 import os
 import shutil
-from typing import Any, Dict, List, Optional, Union
+from typing import TYPE_CHECKING, Any

 import chromadb
 import chromadb.errors
 from chromadb.api import ClientAPI
-from chromadb.api.types import OneOrMany
 from chromadb.config import Settings

 from crewai.knowledge.storage.base_knowledge_storage import BaseKnowledgeStorage
@@ -19,6 +18,9 @@ from crewai.utilities.constants import KNOWLEDGE_DIRECTORY
 from crewai.utilities.logger import Logger
 from crewai.utilities.paths import db_storage_path

+if TYPE_CHECKING:
+    from chromadb.api.types import OneOrMany


 @contextlib.contextmanager
 def suppress_logging(
@@ -38,30 +40,29 @@ def suppress_logging(


 class KnowledgeStorage(BaseKnowledgeStorage):
-    """
-    Extends Storage to handle embeddings for memory entries, improving
+    """Extends Storage to handle embeddings for memory entries, improving
     search efficiency.
     """

-    collection: Optional[chromadb.Collection] = None
-    collection_name: Optional[str] = "knowledge"
-    app: Optional[ClientAPI] = None
+    collection: chromadb.Collection | None = None
+    collection_name: str | None = "knowledge"
+    app: ClientAPI | None = None

     def __init__(
         self,
-        embedder: Optional[Dict[str, Any]] = None,
-        collection_name: Optional[str] = None,
-    ):
+        embedder: dict[str, Any] | None = None,
+        collection_name: str | None = None,
+    ) -> None:
         self.collection_name = collection_name
         self._set_embedder_config(embedder)

     def search(
         self,
-        query: List[str],
+        query: list[str],
         limit: int = 3,
-        filter: Optional[dict] = None,
+        filter: dict | None = None,
         score_threshold: float = 0.35,
-    ) -> List[Dict[str, Any]]:
+    ) -> list[dict[str, Any]]:
         with suppress_logging():
             if self.collection:
                 fetched = self.collection.query(
@@ -80,10 +81,10 @@ class KnowledgeStorage(BaseKnowledgeStorage):
                     if result["score"] >= score_threshold:
                         results.append(result)
                 return results
-            else:
-                raise Exception("Collection not initialized")
+            msg = "Collection not initialized"
+            raise Exception(msg)

-    def initialize_knowledge_storage(self):
+    def initialize_knowledge_storage(self) -> None:
         base_path = os.path.join(db_storage_path(), "knowledge")
         chroma_client = chromadb.PersistentClient(
             path=base_path,
@@ -104,11 +105,13 @@ class KnowledgeStorage(BaseKnowledgeStorage):
                     embedding_function=self.embedder,
                 )
             else:
-                raise Exception("Vector Database Client not initialized")
+                msg = "Vector Database Client not initialized"
+                raise Exception(msg)
         except Exception:
-            raise Exception("Failed to create or get collection")
+            msg = "Failed to create or get collection"
+            raise Exception(msg)

-    def reset(self):
+    def reset(self) -> None:
         base_path = os.path.join(db_storage_path(), KNOWLEDGE_DIRECTORY)
         if not self.app:
             self.app = chromadb.PersistentClient(
@@ -123,11 +126,12 @@ class KnowledgeStorage(BaseKnowledgeStorage):

     def save(
         self,
-        documents: List[str],
-        metadata: Optional[Union[Dict[str, Any], List[Dict[str, Any]]]] = None,
-    ):
+        documents: list[str],
+        metadata: dict[str, Any] | list[dict[str, Any]] | None = None,
+    ) -> None:
         if not self.collection:
-            raise Exception("Collection not initialized")
+            msg = "Collection not initialized"
+            raise Exception(msg)

         try:
             # Create a dictionary to store unique documents
@@ -156,7 +160,7 @@ class KnowledgeStorage(BaseKnowledgeStorage):
                     filtered_ids.append(doc_id)

             # If we have no metadata at all, set it to None
-            final_metadata: Optional[OneOrMany[chromadb.Metadata]] = (
+            final_metadata: OneOrMany[chromadb.Metadata] | None = (
                 None if all(m is None for m in filtered_metadata) else filtered_metadata
             )

@@ -171,10 +175,13 @@ class KnowledgeStorage(BaseKnowledgeStorage):
                     "Embedding dimension mismatch. This usually happens when mixing different embedding models. Try resetting the collection using `crewai reset-memories -a`",
                     "red",
                 )
-                raise ValueError(
+                msg = (
                     "Embedding dimension mismatch. Make sure you're using the same embedding model "
                     "across all operations with this collection."
                     "Try resetting the collection using `crewai reset-memories -a`"
+                )
+                raise ValueError(
+                    msg,
                 ) from e
             except Exception as e:
                 Logger(verbose=True).log("error", f"Failed to upsert documents: {e}", "red")
@@ -186,15 +193,16 @@ class KnowledgeStorage(BaseKnowledgeStorage):
         )

         return OpenAIEmbeddingFunction(
-            api_key=os.getenv("OPENAI_API_KEY"), model_name="text-embedding-3-small"
+            api_key=os.getenv("OPENAI_API_KEY"), model_name="text-embedding-3-small",
         )

-    def _set_embedder_config(self, embedder: Optional[Dict[str, Any]] = None) -> None:
+    def _set_embedder_config(self, embedder: dict[str, Any] | None = None) -> None:
         """Set the embedding configuration for the knowledge storage.

         Args:
             embedder_config (Optional[Dict[str, Any]]): Configuration dictionary for the embedder.
                 If None or empty, defaults to the default embedding function.
+
         """
         self.embedder = (
             EmbeddingConfigurator().configure_embedder(embedder)
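Moving the `OneOrMany` import under `if TYPE_CHECKING:` is the standard way to keep a typing-only dependency out of the runtime import graph. It works here because `final_metadata` carries a local variable annotation, which Python never evaluates at runtime; in a function signature the name would need quoting or a module-level `from __future__ import annotations`. A generic sketch of the idiom:

    from typing import TYPE_CHECKING

    if TYPE_CHECKING:
        # Seen by mypy/pyright only; never imported when the module runs.
        from chromadb.api.types import OneOrMany

    def first(values: "OneOrMany[str]") -> str:
        # Quoted annotation, so the TYPE_CHECKING-only import suffices.
        return values if isinstance(values, str) else values[0]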
@@ -1,7 +1,7 @@
-from typing import Any, Dict, List
+from typing import Any


-def extract_knowledge_context(knowledge_snippets: List[Dict[str, Any]]) -> str:
+def extract_knowledge_context(knowledge_snippets: list[dict[str, Any]]) -> str:
     """Extract knowledge from the task prompt."""
     valid_snippets = [
         result["context"]
@@ -1,7 +1,7 @@
 import asyncio
 import uuid
 from datetime import datetime
-from typing import Any, Callable, Dict, List, Optional, Type, Union, cast
+from collections.abc import Callable
+from typing import Any, cast

 from pydantic import BaseModel, Field, InstanceOf, PrivateAttr, model_validator

@@ -35,7 +35,7 @@ from crewai.utilities.agent_utils import (
     render_text_description_and_args,
     show_agent_logs,
 )
-from crewai.utilities.converter import convert_to_model, generate_model_description
+from crewai.utilities.converter import generate_model_description
 from crewai.utilities.events.agent_events import (
     LiteAgentExecutionCompletedEvent,
     LiteAgentExecutionErrorEvent,
@@ -60,15 +60,15 @@ class LiteAgentOutput(BaseModel):
     model_config = {"arbitrary_types_allowed": True}

     raw: str = Field(description="Raw output of the agent", default="")
-    pydantic: Optional[BaseModel] = Field(
-        description="Pydantic output of the agent", default=None
+    pydantic: BaseModel | None = Field(
+        description="Pydantic output of the agent", default=None,
     )
     agent_role: str = Field(description="Role of the agent that produced this output")
-    usage_metrics: Optional[Dict[str, Any]] = Field(
-        description="Token usage metrics for this execution", default=None
+    usage_metrics: dict[str, Any] | None = Field(
+        description="Token usage metrics for this execution", default=None,
     )

-    def to_dict(self) -> Dict[str, Any]:
+    def to_dict(self) -> dict[str, Any]:
         """Convert pydantic_output to a dictionary."""
         if self.pydantic:
             return self.pydantic.model_dump()
@@ -82,8 +82,7 @@ class LiteAgentOutput(BaseModel):


 class LiteAgent(FlowTrackable, BaseModel):
-    """
-    A lightweight agent that can process messages and use tools.
+    """A lightweight agent that can process messages and use tools.

     This agent is simpler than the full Agent class, focusing on direct execution
     rather than task delegation. It's designed to be used for simple interactions
@@ -99,6 +98,7 @@ class LiteAgent(FlowTrackable, BaseModel):
         max_iterations: Maximum number of iterations for tool usage.
         max_execution_time: Maximum execution time in seconds.
         response_format: Optional Pydantic model for structured output.
+
     """

     model_config = {"arbitrary_types_allowed": True}
@@ -107,19 +107,19 @@ class LiteAgent(FlowTrackable, BaseModel):
     role: str = Field(description="Role of the agent")
     goal: str = Field(description="Goal of the agent")
     backstory: str = Field(description="Backstory of the agent")
-    llm: Optional[Union[str, InstanceOf[LLM], Any]] = Field(
-        default=None, description="Language model that will run the agent"
+    llm: str | InstanceOf[LLM] | Any | None = Field(
+        default=None, description="Language model that will run the agent",
     )
-    tools: List[BaseTool] = Field(
-        default_factory=list, description="Tools at agent's disposal"
+    tools: list[BaseTool] = Field(
+        default_factory=list, description="Tools at agent's disposal",
     )

     # Execution Control Properties
     max_iterations: int = Field(
-        default=15, description="Maximum number of iterations for tool usage"
+        default=15, description="Maximum number of iterations for tool usage",
     )
-    max_execution_time: Optional[int] = Field(
-        default=None, description="Maximum execution time in seconds"
+    max_execution_time: int | None = Field(
+        default=None, description="Maximum execution time in seconds",
     )
     respect_context_window: bool = Field(
         default=True,
@@ -129,38 +129,38 @@ class LiteAgent(FlowTrackable, BaseModel):
         default=True,
         description="Whether to use stop words to prevent the LLM from using tools",
     )
-    request_within_rpm_limit: Optional[Callable[[], bool]] = Field(
+    request_within_rpm_limit: Callable[[], bool] | None = Field(
         default=None,
         description="Callback to check if the request is within the RPM limit",
     )
     i18n: I18N = Field(default=I18N(), description="Internationalization settings.")

     # Output and Formatting Properties
-    response_format: Optional[Type[BaseModel]] = Field(
-        default=None, description="Pydantic model for structured output"
+    response_format: type[BaseModel] | None = Field(
+        default=None, description="Pydantic model for structured output",
     )
     verbose: bool = Field(
-        default=False, description="Whether to print execution details"
+        default=False, description="Whether to print execution details",
     )
-    callbacks: List[Callable] = Field(
-        default=[], description="Callbacks to be used for the agent"
+    callbacks: list[Callable] = Field(
+        default=[], description="Callbacks to be used for the agent",
     )

     # State and Results
-    tools_results: List[Dict[str, Any]] = Field(
-        default=[], description="Results of the tools used by the agent."
+    tools_results: list[dict[str, Any]] = Field(
+        default=[], description="Results of the tools used by the agent.",
     )

     # Reference of Agent
-    original_agent: Optional[BaseAgent] = Field(
-        default=None, description="Reference to the agent that created this LiteAgent"
+    original_agent: BaseAgent | None = Field(
+        default=None, description="Reference to the agent that created this LiteAgent",
     )
     # Private Attributes
-    _parsed_tools: List[CrewStructuredTool] = PrivateAttr(default_factory=list)
+    _parsed_tools: list[CrewStructuredTool] = PrivateAttr(default_factory=list)
     _token_process: TokenProcess = PrivateAttr(default_factory=TokenProcess)
     _cache_handler: CacheHandler = PrivateAttr(default_factory=CacheHandler)
     _key: str = PrivateAttr(default_factory=lambda: str(uuid.uuid4()))
-    _messages: List[Dict[str, str]] = PrivateAttr(default_factory=list)
+    _messages: list[dict[str, str]] = PrivateAttr(default_factory=list)
     _iterations: int = PrivateAttr(default=0)
     _printer: Printer = PrivateAttr(default_factory=Printer)

@@ -169,7 +169,8 @@ class LiteAgent(FlowTrackable, BaseModel):
         """Set up the LLM and other components after initialization."""
         self.llm = create_llm(self.llm)
         if not isinstance(self.llm, LLM):
-            raise ValueError("Unable to create LLM instance")
+            msg = "Unable to create LLM instance"
+            raise ValueError(msg)

         # Initialize callbacks
         token_callback = TokenCalcHandler(token_cost_process=self._token_process)
@@ -194,9 +195,8 @@ class LiteAgent(FlowTrackable, BaseModel):
         """Return the original role for compatibility with tool interfaces."""
         return self.role

-    def kickoff(self, messages: Union[str, List[Dict[str, str]]]) -> LiteAgentOutput:
-        """
-        Execute the agent with the given messages.
+    def kickoff(self, messages: str | list[dict[str, str]]) -> LiteAgentOutput:
+        """Execute the agent with the given messages.

         Args:
             messages: Either a string query or a list of message dictionaries.
@@ -205,6 +205,7 @@ class LiteAgent(FlowTrackable, BaseModel):

         Returns:
             LiteAgentOutput: The result of the agent execution.
+
         """
         # Create agent info for event emission
         agent_info = {
@@ -235,18 +236,18 @@ class LiteAgent(FlowTrackable, BaseModel):

         # Execute the agent using invoke loop
         agent_finish = self._invoke_loop()
-        formatted_result: Optional[BaseModel] = None
+        formatted_result: BaseModel | None = None
         if self.response_format:
             try:
                 # Cast to BaseModel to ensure type safety
                 result = self.response_format.model_validate_json(
-                    agent_finish.output
+                    agent_finish.output,
                 )
                 if isinstance(result, BaseModel):
                     formatted_result = result
             except Exception as e:
                 self._printer.print(
-                    content=f"Failed to parse output into response format: {str(e)}",
+                    content=f"Failed to parse output into response format: {e!s}",
                     color="yellow",
                 )

@@ -286,13 +287,12 @@ class LiteAgent(FlowTrackable, BaseModel):
                     error=str(e),
                 ),
             )
-            raise e
+            raise

     async def kickoff_async(
-        self, messages: Union[str, List[Dict[str, str]]]
+        self, messages: str | list[dict[str, str]],
     ) -> LiteAgentOutput:
-        """
-        Execute the agent asynchronously with the given messages.
+        """Execute the agent asynchronously with the given messages.

         Args:
             messages: Either a string query or a list of message dictionaries.
@@ -301,6 +301,7 @@ class LiteAgent(FlowTrackable, BaseModel):

         Returns:
             LiteAgentOutput: The result of the agent execution.
+
         """
         return await asyncio.to_thread(self.kickoff, messages)

@@ -319,7 +320,7 @@ class LiteAgent(FlowTrackable, BaseModel):
         else:
             # Use the prompt template for agents without tools
             base_prompt = self.i18n.slice(
-                "lite_agent_system_prompt_without_tools"
+                "lite_agent_system_prompt_without_tools",
             ).format(
                 role=self.role,
                 backstory=self.backstory,
@@ -330,14 +331,14 @@ class LiteAgent(FlowTrackable, BaseModel):
         if self.response_format:
             schema = generate_model_description(self.response_format)
             base_prompt += self.i18n.slice("lite_agent_response_format").format(
-                response_format=schema
+                response_format=schema,
             )

         return base_prompt

     def _format_messages(
-        self, messages: Union[str, List[Dict[str, str]]]
-    ) -> List[Dict[str, str]]:
+        self, messages: str | list[dict[str, str]],
+    ) -> list[dict[str, str]]:
         """Format messages for the LLM."""
         if isinstance(messages, str):
             messages = [{"role": "user", "content": messages}]
@@ -353,11 +354,11 @@ class LiteAgent(FlowTrackable, BaseModel):
         return formatted_messages

     def _invoke_loop(self) -> AgentFinish:
-        """
-        Run the agent's thought process until it reaches a conclusion or max iterations.
+        """Run the agent's thought process until it reaches a conclusion or max iterations.

         Returns:
             AgentFinish: The final result of the agent execution.
+
         """
         # Execute the agent loop
         formatted_answer = None
@@ -369,7 +370,7 @@ class LiteAgent(FlowTrackable, BaseModel):
             printer=self._printer,
             i18n=self.i18n,
             messages=self._messages,
-            llm=cast(LLM, self.llm),
+            llm=cast("LLM", self.llm),
             callbacks=self._callbacks,
         )

@@ -387,7 +388,7 @@ class LiteAgent(FlowTrackable, BaseModel):

                 try:
                     answer = get_llm_response(
-                        llm=cast(LLM, self.llm),
+                        llm=cast("LLM", self.llm),
                         messages=self._messages,
                         callbacks=self._callbacks,
                         printer=self._printer,
@@ -407,7 +408,7 @@ class LiteAgent(FlowTrackable, BaseModel):
                         self,
                         event=LLMCallFailedEvent(error=str(e)),
                     )
-                    raise e
+                    raise

                 formatted_answer = process_llm_response(answer, self.use_stop_words)

@@ -421,8 +422,8 @@ class LiteAgent(FlowTrackable, BaseModel):
                             agent_role=self.role,
                             agent=self.original_agent,
                         )
-                    except Exception as e:
-                        raise e
+                    except Exception:
+                        raise

                     formatted_answer = handle_agent_action_core(
                         formatted_answer=formatted_answer,
@@ -443,20 +444,19 @@ class LiteAgent(FlowTrackable, BaseModel):
             except Exception as e:
                 if e.__class__.__module__.startswith("litellm"):
                     # Do not retry on litellm errors
-                    raise e
+                    raise
                 if is_context_length_exceeded(e):
                     handle_context_length(
                         respect_context_window=self.respect_context_window,
                         printer=self._printer,
                         messages=self._messages,
-                        llm=cast(LLM, self.llm),
+                        llm=cast("LLM", self.llm),
                         callbacks=self._callbacks,
                         i18n=self.i18n,
                     )
                     continue
-                else:
-                    handle_unknown_error(self._printer, e)
-                    raise e
+                handle_unknown_error(self._printer, e)
+                raise

             finally:
                 self._iterations += 1
@@ -465,7 +465,7 @@ class LiteAgent(FlowTrackable, BaseModel):
         self._show_logs(formatted_answer)
         return formatted_answer

-    def _show_logs(self, formatted_answer: Union[AgentAction, AgentFinish]):
+    def _show_logs(self, formatted_answer: AgentAction | AgentFinish) -> None:
         """Show logs for the agent's execution."""
         show_agent_logs(
             printer=self._printer,
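A recurring change through this file is `raise e` becoming a bare `raise` inside `except` blocks (Ruff's TRY201). Both re-raise the active exception, but the bare form is the idiomatic spelling and avoids inserting a redundant re-raise frame into the traceback. Sketch:

    def risky_call() -> str:
        raise TimeoutError("backend did not answer")

    def fetch() -> str:
        try:
            return risky_call()
        except TimeoutError:
            print("call failed, re-raising")  # e.g. cleanup or logging first
            raise  # bare raise: propagates the active exception as-is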
@@ -6,17 +6,10 @@ import threading
 import warnings
 from collections import defaultdict
 from contextlib import contextmanager
 from types import SimpleNamespace
 from typing import (
     Any,
-    DefaultDict,
-    Dict,
-    List,
     Literal,
-    Optional,
-    Type,
     TypedDict,
-    Union,
     cast,
 )

@@ -31,7 +24,6 @@ from crewai.utilities.events.llm_events import (
     LLMCallType,
     LLMStreamChunkEvent,
 )
-from crewai.utilities.events.tool_usage_events import ToolExecutionErrorEvent

 with warnings.catch_warnings():
     warnings.simplefilter("ignore", UserWarning)
@@ -55,7 +47,7 @@ load_dotenv()


 class FilteredStream:
-    def __init__(self, original_stream):
+    def __init__(self, original_stream) -> None:
         self._original_stream = original_stream
         self._lock = threading.Lock()

@@ -210,7 +202,7 @@ def suppress_warnings():
     with warnings.catch_warnings():
         warnings.filterwarnings("ignore")
         warnings.filterwarnings(
-            "ignore", message="open_text is deprecated*", category=DeprecationWarning
+            "ignore", message="open_text is deprecated*", category=DeprecationWarning,
         )

         # Redirect stdout and stderr
@@ -226,14 +218,14 @@ def suppress_warnings():


 class Delta(TypedDict):
-    content: Optional[str]
-    role: Optional[str]
+    content: str | None
+    role: str | None


 class StreamingChoices(TypedDict):
     delta: Delta
     index: int
-    finish_reason: Optional[str]
+    finish_reason: str | None


 class FunctionArgs(BaseModel):
@@ -249,29 +241,31 @@ class LLM(BaseLLM):
     def __init__(
         self,
         model: str,
-        timeout: Optional[Union[float, int]] = None,
-        temperature: Optional[float] = None,
-        top_p: Optional[float] = None,
-        n: Optional[int] = None,
-        stop: Optional[Union[str, List[str]]] = None,
-        max_completion_tokens: Optional[int] = None,
-        max_tokens: Optional[int] = None,
-        presence_penalty: Optional[float] = None,
-        frequency_penalty: Optional[float] = None,
-        logit_bias: Optional[Dict[int, float]] = None,
-        response_format: Optional[Type[BaseModel]] = None,
-        seed: Optional[int] = None,
-        logprobs: Optional[int] = None,
-        top_logprobs: Optional[int] = None,
-        base_url: Optional[str] = None,
-        api_base: Optional[str] = None,
-        api_version: Optional[str] = None,
-        api_key: Optional[str] = None,
-        callbacks: List[Any] = [],
-        reasoning_effort: Optional[Literal["none", "low", "medium", "high"]] = None,
+        timeout: float | None = None,
+        temperature: float | None = None,
+        top_p: float | None = None,
+        n: int | None = None,
+        stop: str | list[str] | None = None,
+        max_completion_tokens: int | None = None,
+        max_tokens: int | None = None,
+        presence_penalty: float | None = None,
+        frequency_penalty: float | None = None,
+        logit_bias: dict[int, float] | None = None,
+        response_format: type[BaseModel] | None = None,
+        seed: int | None = None,
+        logprobs: int | None = None,
+        top_logprobs: int | None = None,
+        base_url: str | None = None,
+        api_base: str | None = None,
+        api_version: str | None = None,
+        api_key: str | None = None,
+        callbacks: list[Any] | None = None,
+        reasoning_effort: Literal["none", "low", "medium", "high"] | None = None,
         stream: bool = False,
         **kwargs,
-    ):
+    ) -> None:
+        if callbacks is None:
+            callbacks = []
         self.model = model
         self.timeout = timeout
         self.temperature = temperature
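The `callbacks: List[Any] = []` parameter becoming `callbacks: list[Any] | None = None` plus the `if callbacks is None: callbacks = []` guard fixes the classic mutable-default pitfall (Ruff's B006): a default list is created once, at function definition time, and silently shared by every call. A sketch of the failure mode and the fix:

    # Buggy: one shared list for all calls
    def register(event: str, events: list[str] = []) -> list[str]:
        events.append(event)
        return events

    print(register("a"))  # ['a']
    print(register("b"))  # ['a', 'b']  <- state leaked from the first call

    # Fixed: None sentinel, fresh list per call
    def register_fixed(event: str, events: list[str] | None = None) -> list[str]:
        if events is None:
            events = []
        events.append(event)
        return events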
@@ -301,7 +295,7 @@ class LLM(BaseLLM):

         # Normalize self.stop to always be a List[str]
         if stop is None:
-            self.stop: List[str] = []
+            self.stop: list[str] = []
         elif isinstance(stop, str):
             self.stop = [stop]
         else:
@@ -318,15 +312,16 @@ class LLM(BaseLLM):

         Returns:
             bool: True if the model is from Anthropic, False otherwise.
+
         """
         ANTHROPIC_PREFIXES = ("anthropic/", "claude-", "claude/")
         return any(prefix in model.lower() for prefix in ANTHROPIC_PREFIXES)

     def _prepare_completion_params(
         self,
-        messages: Union[str, List[Dict[str, str]]],
-        tools: Optional[List[dict]] = None,
-    ) -> Dict[str, Any]:
+        messages: str | list[dict[str, str]],
+        tools: list[dict] | None = None,
+    ) -> dict[str, Any]:
         """Prepare parameters for the completion call.

         Args:
@@ -337,6 +332,7 @@ class LLM(BaseLLM):

         Returns:
             Dict[str, Any]: Parameters for the completion call
+
         """
         # --- 1) Format messages according to provider requirements
         if isinstance(messages, str):
@@ -375,9 +371,9 @@ class LLM(BaseLLM):

     def _handle_streaming_response(
         self,
-        params: Dict[str, Any],
-        callbacks: Optional[List[Any]] = None,
-        available_functions: Optional[Dict[str, Any]] = None,
+        params: dict[str, Any],
+        callbacks: list[Any] | None = None,
+        available_functions: dict[str, Any] | None = None,
     ) -> str:
         """Handle a streaming response from the LLM.

@@ -391,6 +387,7 @@ class LLM(BaseLLM):

         Raises:
             Exception: If no content is received from the streaming response
+
         """
         # --- 1) Initialize response tracking
         full_response = ""
@@ -399,8 +396,8 @@ class LLM(BaseLLM):
         usage_info = None
         tool_calls = None

-        accumulated_tool_args: DefaultDict[int, AccumulatedToolArgs] = defaultdict(
-            AccumulatedToolArgs
+        accumulated_tool_args: defaultdict[int, AccumulatedToolArgs] = defaultdict(
+            AccumulatedToolArgs,
         )

         # --- 2) Make sure stream is set to True and include usage metrics
@@ -424,16 +421,16 @@ class LLM(BaseLLM):
                     choices = chunk["choices"]
                 elif hasattr(chunk, "choices"):
                     # Check if choices is not a type but an actual attribute with value
-                    if not isinstance(getattr(chunk, "choices"), type):
-                        choices = getattr(chunk, "choices")
+                    if not isinstance(chunk.choices, type):
+                        choices = chunk.choices

                 # Try to extract usage information if available
                 if isinstance(chunk, dict) and "usage" in chunk:
                     usage_info = chunk["usage"]
                 elif hasattr(chunk, "usage"):
                     # Check if usage is not a type but an actual attribute with value
-                    if not isinstance(getattr(chunk, "usage"), type):
-                        usage_info = getattr(chunk, "usage")
+                    if not isinstance(chunk.usage, type):
+                        usage_info = chunk.usage

                 if choices and len(choices) > 0:
                     choice = choices[0]
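The chunk-handling rewrites above and below apply Ruff's B009: `getattr(x, "name")` with a constant string is just a slower, less readable spelling of `x.name`, and the preceding `hasattr` guard already ensures the attribute exists. `getattr` remains the right tool only when the attribute name is computed at runtime:

    class Chunk:
        choices: list = []

    chunk = Chunk()

    # Flagged by B009: constant attribute name
    choices = getattr(chunk, "choices")

    # Equivalent and idiomatic
    choices = chunk.choices

    # Still legitimate: the name is only known at runtime
    field = "choices"
    choices = getattr(chunk, field, None)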
@@ -443,7 +440,7 @@ class LLM(BaseLLM):
                     if isinstance(choice, dict) and "delta" in choice:
                         delta = choice["delta"]
                     elif hasattr(choice, "delta"):
-                        delta = getattr(choice, "delta")
+                        delta = choice.delta

                     # Extract content from delta
                     if delta:
@@ -453,7 +450,7 @@ class LLM(BaseLLM):
                             chunk_content = delta["content"]
                         # Handle object format
                         elif hasattr(delta, "content"):
-                            chunk_content = getattr(delta, "content")
+                            chunk_content = delta.content

                         # Handle case where content might be None or empty
                         if chunk_content is None and isinstance(delta, dict):
@@ -491,21 +488,21 @@ class LLM(BaseLLM):
         # --- 4) Fallback to non-streaming if no content received
         if not full_response.strip() and chunk_count == 0:
             logging.warning(
-                "No chunks received in streaming response, falling back to non-streaming"
+                "No chunks received in streaming response, falling back to non-streaming",
             )
             non_streaming_params = params.copy()
             non_streaming_params["stream"] = False
             non_streaming_params.pop(
-                "stream_options", None
+                "stream_options", None,
             )  # Remove stream_options for non-streaming call
             return self._handle_non_streaming_response(
-                non_streaming_params, callbacks, available_functions
+                non_streaming_params, callbacks, available_functions,
             )

         # --- 5) Handle empty response with chunks
         if not full_response.strip() and chunk_count > 0:
             logging.warning(
-                f"Received {chunk_count} chunks but no content was extracted"
+                f"Received {chunk_count} chunks but no content was extracted",
             )
             if last_chunk is not None:
                 try:
@@ -514,8 +511,8 @@ class LLM(BaseLLM):
                     if isinstance(last_chunk, dict) and "choices" in last_chunk:
                         choices = last_chunk["choices"]
                     elif hasattr(last_chunk, "choices"):
-                        if not isinstance(getattr(last_chunk, "choices"), type):
-                            choices = getattr(last_chunk, "choices")
+                        if not isinstance(last_chunk.choices, type):
+                            choices = last_chunk.choices

                     if choices and len(choices) > 0:
                         choice = choices[0]
@@ -525,30 +522,31 @@ class LLM(BaseLLM):
                         if isinstance(choice, dict) and "message" in choice:
                             message = choice["message"]
                         elif hasattr(choice, "message"):
-                            message = getattr(choice, "message")
+                            message = choice.message

                         if message:
                             content = None
                             if isinstance(message, dict) and "content" in message:
                                 content = message["content"]
                             elif hasattr(message, "content"):
-                                content = getattr(message, "content")
+                                content = message.content

                             if content:
                                 full_response = content
                                 logging.info(
-                                    f"Extracted content from last chunk message: {full_response}"
+                                    f"Extracted content from last chunk message: {full_response}",
                                 )
                 except Exception as e:
                     logging.debug(f"Error extracting content from last chunk: {e}")
                     logging.debug(
-                        f"Last chunk format: {type(last_chunk)}, content: {last_chunk}"
+                        f"Last chunk format: {type(last_chunk)}, content: {last_chunk}",
                     )

         # --- 6) If still empty, raise an error instead of using a default response
         if not full_response.strip() and len(accumulated_tool_args) == 0:
+            msg = "No content received from streaming response. Received empty chunks or failed to extract content."
             raise Exception(
-                "No content received from streaming response. Received empty chunks or failed to extract content."
+                msg,
             )

         # --- 7) Check for tool calls in the final response
@@ -559,8 +557,8 @@ class LLM(BaseLLM):
             if isinstance(last_chunk, dict) and "choices" in last_chunk:
                 choices = last_chunk["choices"]
             elif hasattr(last_chunk, "choices"):
-                if not isinstance(getattr(last_chunk, "choices"), type):
-                    choices = getattr(last_chunk, "choices")
+                if not isinstance(last_chunk.choices, type):
+                    choices = last_chunk.choices

             if choices and len(choices) > 0:
                 choice = choices[0]
@@ -569,13 +567,13 @@ class LLM(BaseLLM):
                 if isinstance(choice, dict) and "message" in choice:
                     message = choice["message"]
                 elif hasattr(choice, "message"):
-                    message = getattr(choice, "message")
+                    message = choice.message

                 if message:
                     if isinstance(message, dict) and "tool_calls" in message:
                         tool_calls = message["tool_calls"]
                     elif hasattr(message, "tool_calls"):
-                        tool_calls = getattr(message, "tool_calls")
+                        tool_calls = message.tool_calls
         except Exception as e:
             logging.debug(f"Error checking for tool calls: {e}")
         # --- 8) If no tool calls or no available functions, return the text response directly
@@ -605,9 +603,9 @@ class LLM(BaseLLM):
             # decide whether to summarize the content or abort based on the respect_context_window flag.
             raise LLMContextLengthExceededException(str(e))
         except Exception as e:
-            logging.error(f"Error in streaming response: {str(e)}")
+            logging.exception(f"Error in streaming response: {e!s}")
             if full_response.strip():
-                logging.warning(f"Returning partial response despite error: {str(e)}")
+                logging.warning(f"Returning partial response despite error: {e!s}")
                 self._handle_emit_call_events(full_response, LLMCallType.LLM_CALL)
                 return full_response

@@ -617,13 +615,14 @@ class LLM(BaseLLM):
                 self,
                 event=LLMCallFailedEvent(error=str(e)),
             )
-            raise Exception(f"Failed to get streaming response: {str(e)}")
+            msg = f"Failed to get streaming response: {e!s}"
+            raise Exception(msg)

     def _handle_streaming_tool_calls(
         self,
-        tool_calls: List[ChatCompletionDeltaToolCall],
-        accumulated_tool_args: DefaultDict[int, AccumulatedToolArgs],
-        available_functions: Optional[Dict[str, Any]] = None,
+        tool_calls: list[ChatCompletionDeltaToolCall],
+        accumulated_tool_args: defaultdict[int, AccumulatedToolArgs],
+        available_functions: dict[str, Any] | None = None,
     ) -> None | str:
         for tool_call in tool_calls:
             current_tool_accumulator = accumulated_tool_args[tool_call.index]
@@ -662,9 +661,9 @@ class LLM(BaseLLM):

     def _handle_streaming_callbacks(
         self,
-        callbacks: Optional[List[Any]],
-        usage_info: Optional[Dict[str, Any]],
-        last_chunk: Optional[Any],
+        callbacks: list[Any] | None,
+        usage_info: dict[str, Any] | None,
+        last_chunk: Any | None,
     ) -> None:
         """Handle callbacks with usage info for streaming responses.

@@ -672,6 +671,7 @@ class LLM(BaseLLM):
             callbacks: Optional list of callback functions
             usage_info: Usage information collected during streaming
             last_chunk: The last chunk received from the streaming response
+
         """
         if callbacks and len(callbacks) > 0:
             for callback in callbacks:
@@ -688,9 +688,9 @@ class LLM(BaseLLM):
                             usage_info = last_chunk["usage"]
                         elif hasattr(last_chunk, "usage"):
                             if not isinstance(
-                                getattr(last_chunk, "usage"), type
+                                last_chunk.usage, type,
                             ):
-                                usage_info = getattr(last_chunk, "usage")
+                                usage_info = last_chunk.usage
                     except Exception as e:
                         logging.debug(f"Error extracting usage info: {e}")

@@ -704,9 +704,9 @@ class LLM(BaseLLM):

     def _handle_non_streaming_response(
         self,
-        params: Dict[str, Any],
-        callbacks: Optional[List[Any]] = None,
-        available_functions: Optional[Dict[str, Any]] = None,
+        params: dict[str, Any],
+        callbacks: list[Any] | None = None,
+        available_functions: dict[str, Any] | None = None,
     ) -> str:
         """Handle a non-streaming response from the LLM.

@@ -717,6 +717,7 @@ class LLM(BaseLLM):

         Returns:
             str: The response text
+
         """
         # --- 1) Make the completion call
         try:
@@ -731,7 +732,7 @@ class LLM(BaseLLM):
             raise LLMContextLengthExceededException(str(e))

         # --- 2) Extract response message and content
-        response_message = cast(Choices, cast(ModelResponse, response).choices)[
+        response_message = cast("Choices", cast("ModelResponse", response).choices)[
             0
         ].message
         text_response = response_message.content or ""
@@ -768,9 +769,9 @@ class LLM(BaseLLM):

     def _handle_tool_call(
         self,
-        tool_calls: List[Any],
-        available_functions: Optional[Dict[str, Any]] = None,
-    ) -> Optional[str]:
+        tool_calls: list[Any],
+        available_functions: dict[str, Any] | None = None,
+    ) -> str | None:
         """Handle a tool call from the LLM.

         Args:
@@ -779,6 +780,7 @@ class LLM(BaseLLM):

         Returns:
             Optional[str]: The result of the tool call, or None if no tool call was made
+
         """
         # --- 1) Validate tool calls and available functions
         if not tool_calls or not available_functions:
@@ -805,23 +807,23 @@ class LLM(BaseLLM):
             except Exception as e:
                 # --- 3.4) Handle execution errors
                 fn = available_functions.get(
-                    function_name, lambda: None
+                    function_name, lambda: None,
                 )  # Ensure fn is always a callable
-                logging.error(f"Error executing function '{function_name}': {e}")
+                logging.exception(f"Error executing function '{function_name}': {e}")
                 assert hasattr(crewai_event_bus, "emit")
                 crewai_event_bus.emit(
                     self,
-                    event=LLMCallFailedEvent(error=f"Tool execution error: {str(e)}"),
+                    event=LLMCallFailedEvent(error=f"Tool execution error: {e!s}"),
                 )
                 return None

     def call(
         self,
-        messages: Union[str, List[Dict[str, str]]],
-        tools: Optional[List[dict]] = None,
-        callbacks: Optional[List[Any]] = None,
-        available_functions: Optional[Dict[str, Any]] = None,
-    ) -> Union[str, Any]:
+        messages: str | list[dict[str, str]],
+        tools: list[dict] | None = None,
+        callbacks: list[Any] | None = None,
+        available_functions: dict[str, Any] | None = None,
+    ) -> str | Any:
         """High-level LLM call method.

         Args:
@@ -844,6 +846,7 @@ class LLM(BaseLLM):
             TypeError: If messages format is invalid
             ValueError: If response format is not supported
             LLMContextLengthExceededException: If input exceeds model's context limit
+
         """
         # --- 1) Emit call started event
         assert hasattr(crewai_event_bus, "emit")
@@ -882,12 +885,11 @@ class LLM(BaseLLM):
             # --- 7) Make the completion call and handle response
             if self.stream:
                 return self._handle_streaming_response(
-                    params, callbacks, available_functions
-                )
-            else:
-                return self._handle_non_streaming_response(
-                    params, callbacks, available_functions
+                    params, callbacks, available_functions,
                 )
+            return self._handle_non_streaming_response(
+                params, callbacks, available_functions,
+            )

         except LLMContextLengthExceededException:
             # Re-raise LLMContextLengthExceededException as it should be handled
@@ -900,15 +902,16 @@ class LLM(BaseLLM):
                 self,
                 event=LLMCallFailedEvent(error=str(e)),
             )
-            logging.error(f"LiteLLM call failed: {str(e)}")
+            logging.exception(f"LiteLLM call failed: {e!s}")
             raise

-    def _handle_emit_call_events(self, response: Any, call_type: LLMCallType):
+    def _handle_emit_call_events(self, response: Any, call_type: LLMCallType) -> None:
         """Handle the events for the LLM call.

         Args:
             response (str): The response from the LLM call.
             call_type (str): The type of call, either "tool_call" or "llm_call".
+
         """
         assert hasattr(crewai_event_bus, "emit")
         crewai_event_bus.emit(
@@ -917,8 +920,8 @@ class LLM(BaseLLM):
         )

     def _format_messages_for_provider(
-        self, messages: List[Dict[str, str]]
-    ) -> List[Dict[str, str]]:
+        self, messages: list[dict[str, str]],
+    ) -> list[dict[str, str]]:
         """Format messages according to provider requirements.

         Args:
@@ -931,15 +934,18 @@ class LLM(BaseLLM):

         Raises:
             TypeError: If messages is None or contains invalid message format.
+
         """
         if messages is None:
-            raise TypeError("Messages cannot be None")
+            msg = "Messages cannot be None"
+            raise TypeError(msg)

         # Validate message format first
         for msg in messages:
             if not isinstance(msg, dict) or "role" not in msg or "content" not in msg:
+                msg = "Invalid message format. Each message must be a dict with 'role' and 'content' keys"
                 raise TypeError(
-                    "Invalid message format. Each message must be a dict with 'role' and 'content' keys"
+                    msg,
                 )

         # Handle O1 models specially
@@ -949,7 +955,7 @@ class LLM(BaseLLM):
             # Convert system messages to assistant messages
             if msg["role"] == "system":
                 formatted_messages.append(
-                    {"role": "assistant", "content": msg["content"]}
+                    {"role": "assistant", "content": msg["content"]},
                 )
             else:
                 formatted_messages.append(msg)
@@ -977,9 +983,8 @@ class LLM(BaseLLM):

         return messages

-    def _get_custom_llm_provider(self) -> Optional[str]:
-        """
-        Derives the custom_llm_provider from the model string.
+    def _get_custom_llm_provider(self) -> str | None:
+        """Derives the custom_llm_provider from the model string.
        - For example, if the model is "openrouter/deepseek/deepseek-chat", returns "openrouter".
        - If the model is "gemini/gemini-1.5-pro", returns "gemini".
        - If there is no '/', defaults to "openai".
@@ -989,8 +994,7 @@ class LLM(BaseLLM):
         return None

     def _validate_call_params(self) -> None:
-        """
-        Validate parameters before making a call. Currently this only checks if
+        """Validate parameters before making a call. Currently this only checks if
         a response_format is provided and whether the model supports it.
         The custom_llm_provider is dynamically determined from the model:
        - E.g., "openrouter/deepseek/deepseek-chat" yields "openrouter"
@@ -1002,19 +1006,22 @@ class LLM(BaseLLM):
             model=self.model,
             custom_llm_provider=provider,
         ):
-            raise ValueError(
+            msg = (
                 f"The model {self.model} does not support response_format for provider '{provider}'. "
                 "Please remove response_format or use a supported model."
             )
+            raise ValueError(
+                msg,
+            )

     def supports_function_calling(self) -> bool:
         try:
             provider = self._get_custom_llm_provider()
             return litellm.utils.supports_function_calling(
-                self.model, custom_llm_provider=provider
+                self.model, custom_llm_provider=provider,
             )
         except Exception as e:
-            logging.error(f"Failed to check function calling support: {str(e)}")
+            logging.exception(f"Failed to check function calling support: {e!s}")
             return False

     def supports_stop_words(self) -> bool:
@@ -1022,16 +1029,16 @@ class LLM(BaseLLM):
             params = get_supported_openai_params(model=self.model)
             return params is not None and "stop" in params
         except Exception as e:
-            logging.error(f"Failed to get supported params: {str(e)}")
+            logging.exception(f"Failed to get supported params: {e!s}")
             return False

     def get_context_window_size(self) -> int:
-        """
-        Returns the context window size, using 75% of the maximum to avoid
+        """Returns the context window size, using 75% of the maximum to avoid
         cutting off messages mid-thread.

         Raises:
             ValueError: If a model's context window size is outside valid bounds (1024-2097152)
+
         """
         if self.context_window_size != 0:
             return self.context_window_size
@@ -1042,21 +1049,21 @@ class LLM(BaseLLM):
         # Validate all context window sizes
         for key, value in LLM_CONTEXT_WINDOW_SIZES.items():
             if value < MIN_CONTEXT or value > MAX_CONTEXT:
+                msg = f"Context window for {key} must be between {MIN_CONTEXT} and {MAX_CONTEXT}"
                 raise ValueError(
-                    f"Context window for {key} must be between {MIN_CONTEXT} and {MAX_CONTEXT}"
+                    msg,
                 )

         self.context_window_size = int(
-            DEFAULT_CONTEXT_WINDOW_SIZE * CONTEXT_WINDOW_USAGE_RATIO
+            DEFAULT_CONTEXT_WINDOW_SIZE * CONTEXT_WINDOW_USAGE_RATIO,
         )
         for key, value in LLM_CONTEXT_WINDOW_SIZES.items():
             if self.model.startswith(key):
                 self.context_window_size = int(value * CONTEXT_WINDOW_USAGE_RATIO)
         return self.context_window_size

-    def set_callbacks(self, callbacks: List[Any]):
-        """
-        Attempt to keep a single set of callbacks in litellm by removing old
+    def set_callbacks(self, callbacks: list[Any]) -> None:
+        """Attempt to keep a single set of callbacks in litellm by removing old
         duplicates and adding new ones.
         """
         with suppress_warnings():
@@ -1071,9 +1078,8 @@ class LLM(BaseLLM):

             litellm.callbacks = callbacks

-    def set_env_callbacks(self):
-        """
-        Sets the success and failure callbacks for the LiteLLM library from environment variables.
+    def set_env_callbacks(self) -> None:
+        """Sets the success and failure callbacks for the LiteLLM library from environment variables.

         This method reads the `LITELLM_SUCCESS_CALLBACKS` and `LITELLM_FAILURE_CALLBACKS`
         environment variables, which should contain comma-separated lists of callback names.
@@ -1089,6 +1095,7 @@ class LLM(BaseLLM):

         This will set `litellm.success_callback` to ["langfuse", "langsmith"] and
         `litellm.failure_callback` to ["langfuse"].
+
         """
         with suppress_warnings():
             success_callbacks_str = os.environ.get("LITELLM_SUCCESS_CALLBACKS", "")
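Two error-handling conventions recur through this file: `logging.error` inside `except` blocks becomes `logging.exception` (Ruff's TRY400), which logs at ERROR level and appends the active traceback automatically, and `{str(e)}` inside f-strings becomes the explicit conversion `{e!s}` (RUF010). A compact sketch combining both:

    import json
    import logging

    def parse(payload: str) -> dict:
        try:
            return json.loads(payload)
        except ValueError as e:
            # ERROR-level record plus the current traceback, in one call
            logging.exception(f"Failed to parse payload: {e!s}")
            return {}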
@@ -1,5 +1,5 @@
 from abc import ABC, abstractmethod
-from typing import Any, Callable, Dict, List, Optional, Union
+from typing import Any


 class BaseLLM(ABC):
@@ -17,17 +17,18 @@ class BaseLLM(ABC):
     Attributes:
         stop (list): A list of stop sequences that the LLM should use to stop generation.
             This is used by the CrewAgentExecutor and other components.
+
     """

     model: str
-    temperature: Optional[float] = None
-    stop: Optional[List[str]] = None
+    temperature: float | None = None
+    stop: list[str] | None = None

     def __init__(
         self,
         model: str,
-        temperature: Optional[float] = None,
-    ):
+        temperature: float | None = None,
+    ) -> None:
         """Initialize the BaseLLM with default attributes.

         This constructor sets default values for attributes that are expected
@@ -43,11 +44,11 @@ class BaseLLM(ABC):
     @abstractmethod
     def call(
         self,
-        messages: Union[str, List[Dict[str, str]]],
-        tools: Optional[List[dict]] = None,
-        callbacks: Optional[List[Any]] = None,
-        available_functions: Optional[Dict[str, Any]] = None,
-    ) -> Union[str, Any]:
+        messages: str | list[dict[str, str]],
+        tools: list[dict] | None = None,
+        callbacks: list[Any] | None = None,
+        available_functions: dict[str, Any] | None = None,
+    ) -> str | Any:
         """Call the LLM with the given messages.

         Args:
@@ -70,14 +71,15 @@ class BaseLLM(ABC):
             ValueError: If the messages format is invalid.
             TimeoutError: If the LLM request times out.
             RuntimeError: If the LLM request fails for other reasons.
+
         """
         pass

     def supports_stop_words(self) -> bool:
         """Check if the LLM supports stop words.

         Returns:
             bool: True if the LLM supports stop words, False otherwise.
+
         """
         return True  # Default implementation assumes support for stop words

@@ -86,6 +88,7 @@ class BaseLLM(ABC):

         Returns:
             int: The number of tokens/characters the model can handle.
+
         """
         # Default implementation - subclasses should override with model-specific values
         return 4096
src/crewai/llms/third_party/ai_suite.py (vendored, 20 changes)
@@ -1,4 +1,4 @@
-from typing import Any, Dict, List, Optional, Union
+from typing import Any

 import aisuite as ai

@@ -6,17 +6,17 @@ from crewai.llms.base_llm import BaseLLM


 class AISuiteLLM(BaseLLM):
-    def __init__(self, model: str, temperature: Optional[float] = None, **kwargs):
+    def __init__(self, model: str, temperature: float | None = None, **kwargs) -> None:
         super().__init__(model, temperature, **kwargs)
         self.client = ai.Client()

     def call(
         self,
-        messages: Union[str, List[Dict[str, str]]],
-        tools: Optional[List[dict]] = None,
-        callbacks: Optional[List[Any]] = None,
-        available_functions: Optional[Dict[str, Any]] = None,
-    ) -> Union[str, Any]:
+        messages: str | list[dict[str, str]],
+        tools: list[dict] | None = None,
+        callbacks: list[Any] | None = None,
+        available_functions: dict[str, Any] | None = None,
+    ) -> str | Any:
         completion_params = self._prepare_completion_params(messages, tools)
         response = self.client.chat.completions.create(**completion_params)

@@ -24,9 +24,9 @@ class AISuiteLLM(BaseLLM):

     def _prepare_completion_params(
         self,
-        messages: Union[str, List[Dict[str, str]]],
-        tools: Optional[List[dict]] = None,
-    ) -> Dict[str, Any]:
+        messages: str | list[dict[str, str]],
+        tools: list[dict] | None = None,
+    ) -> dict[str, Any]:
         return {
             "model": self.model,
             "messages": messages,
@@ -1,4 +1,4 @@
-from typing import Any, Dict, Optional
+from typing import Any

 from crewai.memory import (
     EntityMemory,
@@ -12,13 +12,13 @@ from crewai.memory import (
 class ContextualMemory:
     def __init__(
         self,
-        memory_config: Optional[Dict[str, Any]],
+        memory_config: dict[str, Any] | None,
         stm: ShortTermMemory,
         ltm: LongTermMemory,
         em: EntityMemory,
         um: UserMemory,
         exm: ExternalMemory,
-    ):
+    ) -> None:
         if memory_config is not None:
             self.memory_provider = memory_config.get("provider")
         else:
@@ -30,8 +30,7 @@ class ContextualMemory:
         self.exm = exm

     def build_context_for_task(self, task, context) -> str:
-        """
-        Automatically builds a minimal, highly relevant set of contextual information
+        """Automatically builds a minimal, highly relevant set of contextual information
         for a given task.
         """
         query = f"{task.description} {context}".strip()
@@ -49,11 +48,9 @@ class ContextualMemory:
         return "\n".join(filter(None, context))

     def _fetch_stm_context(self, query) -> str:
-        """
-        Fetches recent relevant insights from STM related to the task's description and expected_output,
+        """Fetches recent relevant insights from STM related to the task's description and expected_output,
         formatted as bullet points.
         """
-
         if self.stm is None:
             return ""

@@ -62,16 +59,14 @@ class ContextualMemory:
             [
                 f"- {result['memory'] if self.memory_provider == 'mem0' else result['context']}"
                 for result in stm_results
-            ]
+            ],
         )
         return f"Recent Insights:\n{formatted_results}" if stm_results else ""

-    def _fetch_ltm_context(self, task) -> Optional[str]:
-        """
-        Fetches historical data or insights from LTM that are relevant to the task's description and expected_output,
+    def _fetch_ltm_context(self, task) -> str | None:
+        """Fetches historical data or insights from LTM that are relevant to the task's description and expected_output,
         formatted as bullet points.
         """
-
         if self.ltm is None:
             return ""

@@ -90,8 +85,7 @@ class ContextualMemory:
         return f"Historical Data:\n{formatted_results}" if ltm_results else ""

     def _fetch_entity_context(self, query) -> str:
-        """
-        Fetches relevant entity information from Entity Memory related to the task's description and expected_output,
+        """Fetches relevant entity information from Entity Memory related to the task's description and expected_output,
         formatted as bullet points.
         """
         if self.em is None:
@@ -102,19 +96,20 @@ class ContextualMemory:
             [
                 f"- {result['memory'] if self.memory_provider == 'mem0' else result['context']}"
                 for result in em_results
-            ]  # type: ignore # Invalid index type "str" for "str"; expected type "SupportsIndex | slice"
+            ],  # type: ignore # Invalid index type "str" for "str"; expected type "SupportsIndex | slice"
         )
         return f"Entities:\n{formatted_results}" if em_results else ""

     def _fetch_user_context(self, query: str) -> str:
-        """
-        Fetches and formats relevant user information from User Memory.
+        """Fetches and formats relevant user information from User Memory.

         Args:
             query (str): The search query to find relevant user memories.

         Returns:
             str: Formatted user memories as bullet points, or an empty string if none found.
-        """
+
+        """
         if self.um is None:
             return ""

@@ -128,12 +123,14 @@ class ContextualMemory:
         return f"User memories/preferences:\n{formatted_memories}"

     def _fetch_external_context(self, query: str) -> str:
-        """
-        Fetches and formats relevant information from External Memory.
+        """Fetches and formats relevant information from External Memory.

         Args:
             query (str): The search query to find relevant information.

         Returns:
             str: Formatted information as bullet points, or an empty string if none found.
+
         """
         if self.exm is None:
             return ""
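Every docstring touched in this file is normalized the same way: the summary moves onto the same line as the opening quotes (the pydocstyle D212-style convention, judging by the pattern; the exact rule driving it is not stated in the diff). Before/after:

    def build_context(task):
        """
        Automatically builds a minimal, highly relevant set of context.
        """

    def build_context(task):
        """Automatically builds a minimal, highly relevant set of context."""

The multi-line form is not wrong, but mixing both styles across a codebase is exactly what lint rules like this exist to prevent.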
src/crewai/memory/entity/entity_memory.py

@@ -1,4 +1,3 @@
-from typing import Optional
 
 from pydantic import PrivateAttr
 
@@ -8,15 +7,14 @@ from crewai.memory.storage.rag_storage import RAGStorage
 
 
 class EntityMemory(Memory):
-    """
-    EntityMemory class for managing structured information about entities
+    """EntityMemory class for managing structured information about entities
     and their relationships using SQLite storage.
     Inherits from the Memory class.
     """
 
-    _memory_provider: Optional[str] = PrivateAttr()
+    _memory_provider: str | None = PrivateAttr()
 
-    def __init__(self, crew=None, embedder_config=None, storage=None, path=None):
+    def __init__(self, crew=None, embedder_config=None, storage=None, path=None) -> None:
         if crew and hasattr(crew, "memory_config") and crew.memory_config is not None:
             memory_provider = crew.memory_config.get("provider")
         else:
@@ -26,8 +24,9 @@ class EntityMemory(Memory):
             try:
                 from crewai.memory.storage.mem0_storage import Mem0Storage
             except ImportError:
+                msg = "Mem0 is not installed. Please install it with `pip install mem0ai`."
                 raise ImportError(
-                    "Mem0 is not installed. Please install it with `pip install mem0ai`."
+                    msg,
                 )
             storage = Mem0Storage(type="entities", crew=crew)
         else:
@@ -63,4 +62,5 @@ class EntityMemory(Memory):
         try:
             self.storage.reset()
         except Exception as e:
-            raise Exception(f"An error occurred while resetting the entity memory: {e}")
+            msg = f"An error occurred while resetting the entity memory: {e}"
+            raise Exception(msg)
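Note: the recurring `msg = ...` / `raise Exception(msg)` rewrite (here and in every file below) matches ruff's flake8-errmsg rules — EM101 for string literals, EM102 for f-strings (rule codes inferred from the pattern, not named in the diff). Keeping the message out of the `raise` expression keeps the traceback's source line short. A self-contained sketch of the pattern:

import importlib

def load_backend(name: str):
    try:
        return importlib.import_module(name)
    except ImportError:
        # Assign first so the traceback shows `raise ImportError(msg)`,
        # not a long literal (EM-style fix).
        msg = f"{name} is not installed. Please install it first."
        raise ImportError(msg)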
src/crewai/memory/entity/entity_memory_item.py

@@ -5,7 +5,7 @@ class EntityMemoryItem:
         type: str,
         description: str,
         relationships: str,
-    ):
+    ) -> None:
         self.name = name
         self.type = type
         self.description = description
src/crewai/memory/external/external_memory.py

@@ -1,4 +1,4 @@
-from typing import TYPE_CHECKING, Any, Dict, Optional
+from typing import TYPE_CHECKING, Any
 
 from crewai.memory.external.external_memory_item import ExternalMemoryItem
 from crewai.memory.memory import Memory
@@ -9,41 +9,44 @@ if TYPE_CHECKING:
 
 
 class ExternalMemory(Memory):
-    def __init__(self, storage: Optional[Storage] = None, **data: Any):
+    def __init__(self, storage: Storage | None = None, **data: Any) -> None:
         super().__init__(storage=storage, **data)
 
     @staticmethod
-    def _configure_mem0(crew: Any, config: Dict[str, Any]) -> "Mem0Storage":
+    def _configure_mem0(crew: Any, config: dict[str, Any]) -> "Mem0Storage":
         from crewai.memory.storage.mem0_storage import Mem0Storage
 
         return Mem0Storage(type="external", crew=crew, config=config)
 
     @staticmethod
-    def external_supported_storages() -> Dict[str, Any]:
+    def external_supported_storages() -> dict[str, Any]:
         return {
             "mem0": ExternalMemory._configure_mem0,
         }
 
     @staticmethod
-    def create_storage(crew: Any, embedder_config: Optional[Dict[str, Any]]) -> Storage:
+    def create_storage(crew: Any, embedder_config: dict[str, Any] | None) -> Storage:
         if not embedder_config:
-            raise ValueError("embedder_config is required")
+            msg = "embedder_config is required"
+            raise ValueError(msg)
 
         if "provider" not in embedder_config:
-            raise ValueError("embedder_config must include a 'provider' key")
+            msg = "embedder_config must include a 'provider' key"
+            raise ValueError(msg)
 
         provider = embedder_config["provider"]
         supported_storages = ExternalMemory.external_supported_storages()
         if provider not in supported_storages:
-            raise ValueError(f"Provider {provider} not supported")
+            msg = f"Provider {provider} not supported"
+            raise ValueError(msg)
 
         return supported_storages[provider](crew, embedder_config.get("config", {}))
 
     def save(
         self,
         value: Any,
-        metadata: Optional[Dict[str, Any]] = None,
-        agent: Optional[str] = None,
+        metadata: dict[str, Any] | None = None,
+        agent: str | None = None,
     ) -> None:
         """Saves a value into the external storage."""
         item = ExternalMemoryItem(value=value, metadata=metadata, agent=agent)
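Note: `create_storage` is a small registry dispatch — validate the config, look the provider up in `external_supported_storages()`, and call the matching factory. A standalone sketch of the shape (the factory below is a hypothetical stand-in, not the library's Mem0Storage):

from collections.abc import Callable
from typing import Any

def make_mem0(crew: Any, config: dict[str, Any]) -> str:
    # Stand-in factory; the real code returns a Mem0Storage instance.
    return f"mem0 storage for crew={crew!r}"

SUPPORTED: dict[str, Callable[..., Any]] = {"mem0": make_mem0}

def create_storage(crew: Any, embedder_config: dict[str, Any] | None) -> Any:
    if not embedder_config:
        msg = "embedder_config is required"
        raise ValueError(msg)
    provider = embedder_config.get("provider")
    if provider not in SUPPORTED:
        msg = f"Provider {provider} not supported"
        raise ValueError(msg)
    return SUPPORTED[provider](crew, embedder_config.get("config", {}))

print(create_storage("crew-1", {"provider": "mem0", "config": {}}))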
src/crewai/memory/external/external_memory_item.py

@@ -1,13 +1,13 @@
-from typing import Any, Dict, Optional
+from typing import Any
 
 
 class ExternalMemoryItem:
     def __init__(
         self,
         value: Any,
-        metadata: Optional[Dict[str, Any]] = None,
-        agent: Optional[str] = None,
-    ):
+        metadata: dict[str, Any] | None = None,
+        agent: str | None = None,
+    ) -> None:
         self.value = value
         self.metadata = metadata
         self.agent = agent
src/crewai/memory/long_term/long_term_memory.py

@@ -1,4 +1,4 @@
-from typing import Any, Dict, List
+from typing import Any
 
 from crewai.memory.long_term.long_term_memory_item import LongTermMemoryItem
 from crewai.memory.memory import Memory
@@ -6,15 +6,14 @@ from crewai.memory.storage.ltm_sqlite_storage import LTMSQLiteStorage
 
 
 class LongTermMemory(Memory):
-    """
-    LongTermMemory class for managing cross runs data related to overall crew's
+    """LongTermMemory class for managing cross runs data related to overall crew's
     execution and performance.
     Inherits from the Memory class and utilizes an instance of a class that
     adheres to the Storage for data storage, specifically working with
     LongTermMemoryItem instances.
     """
 
-    def __init__(self, storage=None, path=None):
+    def __init__(self, storage=None, path=None) -> None:
         if not storage:
             storage = LTMSQLiteStorage(db_path=path) if path else LTMSQLiteStorage()
         super().__init__(storage=storage)
@@ -29,7 +28,7 @@ class LongTermMemory(Memory):
             datetime=item.datetime,
         )
 
-    def search(self, task: str, latest_n: int = 3) -> List[Dict[str, Any]]:  # type: ignore # signature of "search" incompatible with supertype "Memory"
+    def search(self, task: str, latest_n: int = 3) -> list[dict[str, Any]]:  # type: ignore # signature of "search" incompatible with supertype "Memory"
         return self.storage.load(task, latest_n)  # type: ignore # BUG?: "Storage" has no attribute "load"
 
     def reset(self) -> None:
src/crewai/memory/long_term/long_term_memory_item.py

@@ -1,4 +1,4 @@
-from typing import Any, Dict, Optional, Union
+from typing import Any
 
 
 class LongTermMemoryItem:
@@ -8,9 +8,9 @@ class LongTermMemoryItem:
         task: str,
         expected_output: str,
         datetime: str,
-        quality: Optional[Union[int, float]] = None,
-        metadata: Optional[Dict[str, Any]] = None,
-    ):
+        quality: float | None = None,
+        metadata: dict[str, Any] | None = None,
+    ) -> None:
         self.task = task
         self.agent = agent
         self.quality = quality
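Note: collapsing `Optional[Union[int, float]]` to `float | None` loses no callers — under PEP 484's numeric-tower convention, `int` is accepted wherever `float` is annotated, so the `int` member was redundant. A quick check:

def score_item(quality: float | None = None) -> float:
    return 0.0 if quality is None else float(quality)

assert score_item(3) == 3.0   # an int still type-checks where float is annotated
assert score_item(0.5) == 0.5
assert score_item() == 0.0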
src/crewai/memory/memory.py

@@ -1,26 +1,24 @@
-from typing import Any, Dict, List, Optional
+from typing import Any
 
 from pydantic import BaseModel
 
 
 class Memory(BaseModel):
-    """
-    Base class for memory, now supporting agent tags and generic metadata.
-    """
+    """Base class for memory, now supporting agent tags and generic metadata."""
 
-    embedder_config: Optional[Dict[str, Any]] = None
-    crew: Optional[Any] = None
+    embedder_config: dict[str, Any] | None = None
+    crew: Any | None = None
 
     storage: Any
 
-    def __init__(self, storage: Any, **data: Any):
+    def __init__(self, storage: Any, **data: Any) -> None:
         super().__init__(storage=storage, **data)
 
     def save(
         self,
         value: Any,
-        metadata: Optional[Dict[str, Any]] = None,
-        agent: Optional[str] = None,
+        metadata: dict[str, Any] | None = None,
+        agent: str | None = None,
     ) -> None:
         metadata = metadata or {}
         if agent:
@@ -33,9 +31,9 @@ class Memory(BaseModel):
         query: str,
         limit: int = 3,
         score_threshold: float = 0.35,
-    ) -> List[Any]:
+    ) -> list[Any]:
         return self.storage.search(
-            query=query, limit=limit, score_threshold=score_threshold
+            query=query, limit=limit, score_threshold=score_threshold,
         )
 
     def set_crew(self, crew: Any) -> "Memory":
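Note: the `Dict`/`List`/`Optional` sweep across these files is the standard PEP 585/604 modernization that ruff applies through its pyupgrade rules (UP006/UP007 or similar; inferred). The two spellings are equivalent at runtime, assuming Python 3.10+:

from typing import Any

# Old: Optional[Dict[str, Any]]  ->  New: dict[str, Any] | None
embedder_config: dict[str, Any] | None = None

def save(value: Any, metadata: dict[str, Any] | None = None) -> None:
    metadata = metadata or {}   # normalize None to an empty dict

save("fact", None)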
src/crewai/memory/short_term/short_term_memory.py

@@ -1,4 +1,4 @@
-from typing import Any, Dict, Optional
+from typing import Any
 
 from pydantic import PrivateAttr
 
@@ -8,17 +8,16 @@ from crewai.memory.storage.rag_storage import RAGStorage
 
 
 class ShortTermMemory(Memory):
-    """
-    ShortTermMemory class for managing transient data related to immediate tasks
+    """ShortTermMemory class for managing transient data related to immediate tasks
     and interactions.
     Inherits from the Memory class and utilizes an instance of a class that
     adheres to the Storage for data storage, specifically working with
     MemoryItem instances.
     """
 
-    _memory_provider: Optional[str] = PrivateAttr()
+    _memory_provider: str | None = PrivateAttr()
 
-    def __init__(self, crew=None, embedder_config=None, storage=None, path=None):
+    def __init__(self, crew=None, embedder_config=None, storage=None, path=None) -> None:
         if crew and hasattr(crew, "memory_config") and crew.memory_config is not None:
             memory_provider = crew.memory_config.get("provider")
         else:
@@ -28,8 +27,9 @@ class ShortTermMemory(Memory):
             try:
                 from crewai.memory.storage.mem0_storage import Mem0Storage
             except ImportError:
+                msg = "Mem0 is not installed. Please install it with `pip install mem0ai`."
                 raise ImportError(
-                    "Mem0 is not installed. Please install it with `pip install mem0ai`."
+                    msg,
                )
             storage = Mem0Storage(type="short_term", crew=crew)
         else:
@@ -49,8 +49,8 @@ class ShortTermMemory(Memory):
     def save(
         self,
         value: Any,
-        metadata: Optional[Dict[str, Any]] = None,
-        agent: Optional[str] = None,
+        metadata: dict[str, Any] | None = None,
+        agent: str | None = None,
     ) -> None:
         item = ShortTermMemoryItem(data=value, metadata=metadata, agent=agent)
         if self._memory_provider == "mem0":
@@ -65,13 +65,14 @@ class ShortTermMemory(Memory):
         score_threshold: float = 0.35,
     ):
         return self.storage.search(
-            query=query, limit=limit, score_threshold=score_threshold
+            query=query, limit=limit, score_threshold=score_threshold,
         )  # type: ignore # BUG? The reference is to the parent class, but the parent class does not have this parameters
 
     def reset(self) -> None:
         try:
             self.storage.reset()
         except Exception as e:
+            msg = f"An error occurred while resetting the short-term memory: {e}"
             raise Exception(
-                f"An error occurred while resetting the short-term memory: {e}"
+                msg,
             )
src/crewai/memory/short_term/short_term_memory_item.py

@@ -1,13 +1,13 @@
-from typing import Any, Dict, Optional
+from typing import Any
 
 
 class ShortTermMemoryItem:
     def __init__(
         self,
         data: Any,
-        agent: Optional[str] = None,
-        metadata: Optional[Dict[str, Any]] = None,
-    ):
+        agent: str | None = None,
+        metadata: dict[str, Any] | None = None,
+    ) -> None:
         self.data = data
         self.agent = agent
         self.metadata = metadata if metadata is not None else {}
src/crewai/memory/storage/base_rag_storage.py

@@ -1,11 +1,9 @@
 from abc import ABC, abstractmethod
-from typing import Any, Dict, List, Optional
+from typing import Any
 
 
 class BaseRAGStorage(ABC):
-    """
-    Base class for RAG-based Storage implementations.
-    """
+    """Base class for RAG-based Storage implementations."""
 
     app: Any | None = None
 
@@ -13,9 +11,9 @@ class BaseRAGStorage(ABC):
         self,
         type: str,
         allow_reset: bool = True,
-        embedder_config: Optional[Dict[str, Any]] = None,
+        embedder_config: dict[str, Any] | None = None,
         crew: Any = None,
-    ):
+    ) -> None:
         self.type = type
         self.allow_reset = allow_reset
         self.embedder_config = embedder_config
@@ -25,52 +23,44 @@ class BaseRAGStorage(ABC):
     def _initialize_agents(self) -> str:
         if self.crew:
             return "_".join(
-                [self._sanitize_role(agent.role) for agent in self.crew.agents]
+                [self._sanitize_role(agent.role) for agent in self.crew.agents],
             )
         return ""
 
     @abstractmethod
     def _sanitize_role(self, role: str) -> str:
         """Sanitizes agent roles to ensure valid directory names."""
-        pass
 
     @abstractmethod
-    def save(self, value: Any, metadata: Dict[str, Any]) -> None:
+    def save(self, value: Any, metadata: dict[str, Any]) -> None:
         """Save a value with metadata to the storage."""
-        pass
 
     @abstractmethod
     def search(
         self,
         query: str,
         limit: int = 3,
-        filter: Optional[dict] = None,
+        filter: dict | None = None,
         score_threshold: float = 0.35,
-    ) -> List[Any]:
+    ) -> list[Any]:
         """Search for entries in the storage."""
-        pass
 
     @abstractmethod
     def reset(self) -> None:
         """Reset the storage."""
-        pass
 
     @abstractmethod
     def _generate_embedding(
-        self, text: str, metadata: Optional[Dict[str, Any]] = None
+        self, text: str, metadata: dict[str, Any] | None = None,
     ) -> Any:
         """Generate an embedding for the given text and metadata."""
-        pass
 
     @abstractmethod
     def _initialize_app(self):
         """Initialize the vector db."""
-        pass
 
-    def setup_config(self, config: Dict[str, Any]):
+    def setup_config(self, config: dict[str, Any]) -> None:
         """Setup the config of the storage."""
-        pass
 
-    def initialize_client(self):
-        """Initialize the client of the storage. This should setup the app and the db collection"""
-        pass
+    def initialize_client(self) -> None:
+        """Initialize the client of the storage. This should setup the app and the db collection."""
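Note: dropping the trailing `pass` after each abstract method's docstring is safe — the docstring is itself a valid function body, so `pass` was dead code (ruff's unnecessary-pass rule, PIE790; inferred). Subclasses are still forced to override:

from abc import ABC, abstractmethod

class BaseStore(ABC):
    @abstractmethod
    def reset(self) -> None:
        """Reset the storage."""   # a docstring alone is a valid function body

class SqliteStore(BaseStore):
    def reset(self) -> None:
        print("cleared")

SqliteStore().reset()               # fine; BaseStore() itself raises TypeError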
src/crewai/memory/storage/interface.py

@@ -1,15 +1,15 @@
-from typing import Any, Dict, List
+from typing import Any
 
 
 class Storage:
-    """Abstract base class defining the storage interface"""
+    """Abstract base class defining the storage interface."""
 
-    def save(self, value: Any, metadata: Dict[str, Any]) -> None:
+    def save(self, value: Any, metadata: dict[str, Any]) -> None:
         pass
 
     def search(
-        self, query: str, limit: int, score_threshold: float
-    ) -> Dict[str, Any] | List[Any]:
+        self, query: str, limit: int, score_threshold: float,
+    ) -> dict[str, Any] | list[Any]:
         return {}
 
     def reset(self) -> None:
src/crewai/memory/storage/kickoff_task_outputs_storage.py

@@ -2,7 +2,7 @@ import json
 import logging
 import sqlite3
 from pathlib import Path
-from typing import Any, Dict, List, Optional
+from typing import Any
 
 from crewai.task import Task
 from crewai.utilities import Printer
@@ -14,12 +14,10 @@ logger = logging.getLogger(__name__)
 
 
 class KickoffTaskOutputsSQLiteStorage:
-    """
-    An updated SQLite storage class for kickoff task outputs storage.
-    """
+    """An updated SQLite storage class for kickoff task outputs storage."""
 
     def __init__(
-        self, db_path: Optional[str] = None
+        self, db_path: str | None = None,
     ) -> None:
         if db_path is None:
             # Get the parent directory of the default db path and create our db file there
@@ -37,6 +35,7 @@ class KickoffTaskOutputsSQLiteStorage:
 
         Raises:
             DatabaseOperationError: If database initialization fails due to SQLite errors.
+
         """
         try:
             with sqlite3.connect(self.db_path) as conn:
@@ -52,22 +51,22 @@ class KickoffTaskOutputsSQLiteStorage:
                         was_replayed BOOLEAN,
                         timestamp DATETIME DEFAULT CURRENT_TIMESTAMP
                     )
-                """
+                """,
                 )
 
                 conn.commit()
         except sqlite3.Error as e:
             error_msg = DatabaseError.format_error(DatabaseError.INIT_ERROR, e)
-            logger.error(error_msg)
+            logger.exception(error_msg)
             raise DatabaseOperationError(error_msg, e)
 
     def add(
         self,
         task: Task,
-        output: Dict[str, Any],
+        output: dict[str, Any],
         task_index: int,
         was_replayed: bool = False,
-        inputs: Dict[str, Any] = {},
+        inputs: dict[str, Any] | None = None,
     ) -> None:
         """Add a new task output record to the database.
 
@@ -80,7 +79,10 @@ class KickoffTaskOutputsSQLiteStorage:
 
         Raises:
             DatabaseOperationError: If saving the task output fails due to SQLite errors.
+
         """
+        if inputs is None:
+            inputs = {}
         try:
             with sqlite3.connect(self.db_path) as conn:
                 conn.execute("BEGIN TRANSACTION")
@@ -103,7 +105,7 @@ class KickoffTaskOutputsSQLiteStorage:
                 conn.commit()
         except sqlite3.Error as e:
             error_msg = DatabaseError.format_error(DatabaseError.SAVE_ERROR, e)
-            logger.error(error_msg)
+            logger.exception(error_msg)
             raise DatabaseOperationError(error_msg, e)
 
     def update(
@@ -123,6 +125,7 @@ class KickoffTaskOutputsSQLiteStorage:
 
         Raises:
             DatabaseOperationError: If updating the task output fails due to SQLite errors.
+
         """
         try:
             with sqlite3.connect(self.db_path) as conn:
@@ -136,7 +139,7 @@ class KickoffTaskOutputsSQLiteStorage:
                     values.append(
                         json.dumps(value, cls=CrewJSONEncoder)
                         if isinstance(value, dict)
-                        else value
+                        else value,
                     )
 
             query = f"UPDATE latest_kickoff_task_outputs SET {', '.join(fields)} WHERE task_index = ?"  # nosec
@@ -149,10 +152,10 @@ class KickoffTaskOutputsSQLiteStorage:
                 logger.warning(f"No row found with task_index {task_index}. No update performed.")
         except sqlite3.Error as e:
             error_msg = DatabaseError.format_error(DatabaseError.UPDATE_ERROR, e)
-            logger.error(error_msg)
+            logger.exception(error_msg)
             raise DatabaseOperationError(error_msg, e)
 
-    def load(self) -> List[Dict[str, Any]]:
+    def load(self) -> list[dict[str, Any]]:
         """Load all task output records from the database.
 
         Returns:
@@ -162,6 +165,7 @@ class KickoffTaskOutputsSQLiteStorage:
 
         Raises:
             DatabaseOperationError: If loading task outputs fails due to SQLite errors.
+
         """
         try:
             with sqlite3.connect(self.db_path) as conn:
@@ -190,7 +194,7 @@ class KickoffTaskOutputsSQLiteStorage:
 
         except sqlite3.Error as e:
             error_msg = DatabaseError.format_error(DatabaseError.LOAD_ERROR, e)
-            logger.error(error_msg)
+            logger.exception(error_msg)
             raise DatabaseOperationError(error_msg, e)
 
     def delete_all(self) -> None:
@@ -201,6 +205,7 @@ class KickoffTaskOutputsSQLiteStorage:
 
         Raises:
             DatabaseOperationError: If deleting task outputs fails due to SQLite errors.
+
         """
         try:
             with sqlite3.connect(self.db_path) as conn:
@@ -210,5 +215,5 @@ class KickoffTaskOutputsSQLiteStorage:
             conn.commit()
         except sqlite3.Error as e:
             error_msg = DatabaseError.format_error(DatabaseError.DELETE_ERROR, e)
-            logger.error(error_msg)
+            logger.exception(error_msg)
             raise DatabaseOperationError(error_msg, e)
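Note: two fixes in this file go beyond the mechanical ones. `inputs: Dict[str, Any] = {}` was a genuine bug risk — a mutable default is created once at definition time and shared across calls, so the `| None` plus in-body `{}` replacement is the standard cure (flake8-bugbear B006-style; inferred). And `logger.error` → `logger.exception` inside `except` blocks attaches the traceback to the log record. A minimal demonstration of the shared-default pitfall:

def add_bad(item: str, bucket: list[str] = []) -> list[str]:
    bucket.append(item)              # the SAME list object is reused on every call
    return bucket

def add_good(item: str, bucket: list[str] | None = None) -> list[str]:
    if bucket is None:
        bucket = []                  # a fresh list per call, like the fixed add()
    bucket.append(item)
    return bucket

assert add_bad("a") == ["a"]
assert add_bad("b") == ["a", "b"]    # state leaked between calls
assert add_good("a") == ["a"]
assert add_good("b") == ["b"]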
src/crewai/memory/storage/ltm_sqlite_storage.py

@@ -1,19 +1,17 @@
 import json
 import sqlite3
 from pathlib import Path
-from typing import Any, Dict, List, Optional, Union
+from typing import Any
 
 from crewai.utilities import Printer
 from crewai.utilities.paths import db_storage_path
 
 
 class LTMSQLiteStorage:
-    """
-    An updated SQLite storage class for LTM data storage.
-    """
+    """An updated SQLite storage class for LTM data storage."""
 
     def __init__(
-        self, db_path: Optional[str] = None
+        self, db_path: str | None = None,
     ) -> None:
         if db_path is None:
             # Get the parent directory of the default db path and create our db file there
@@ -24,10 +22,8 @@ class LTMSQLiteStorage:
         Path(self.db_path).parent.mkdir(parents=True, exist_ok=True)
         self._initialize_db()
 
-    def _initialize_db(self):
-        """
-        Initializes the SQLite database and creates LTM table
-        """
+    def _initialize_db(self) -> None:
+        """Initializes the SQLite database and creates LTM table."""
         try:
             with sqlite3.connect(self.db_path) as conn:
                 cursor = conn.cursor()
@@ -40,7 +36,7 @@ class LTMSQLiteStorage:
                         datetime TEXT,
                         score REAL
                     )
-                """
+                """,
                 )
 
                 conn.commit()
@@ -53,9 +49,9 @@ class LTMSQLiteStorage:
     def save(
         self,
         task_description: str,
-        metadata: Dict[str, Any],
+        metadata: dict[str, Any],
         datetime: str,
-        score: Union[int, float],
+        score: float,
     ) -> None:
         """Saves data to the LTM table with error handling."""
         try:
@@ -76,8 +72,8 @@ class LTMSQLiteStorage:
             )
 
     def load(
-        self, task_description: str, latest_n: int
-    ) -> Optional[List[Dict[str, Any]]]:
+        self, task_description: str, latest_n: int,
+    ) -> list[dict[str, Any]] | None:
         """Queries the LTM table by task description with error handling."""
         try:
             with sqlite3.connect(self.db_path) as conn:
@@ -125,4 +121,3 @@ class LTMSQLiteStorage:
                 content=f"MEMORY ERROR: An error occurred while deleting all rows in LTM: {e}",
                 color="red",
             )
-        return None
src/crewai/memory/storage/mem0_storage.py

@@ -1,5 +1,5 @@
 import os
-from typing import Any, Dict, List
+from typing import Any
 
 from mem0 import Memory, MemoryClient
 
@@ -7,17 +7,15 @@ from crewai.memory.storage.interface import Storage
 
 
 class Mem0Storage(Storage):
-    """
-    Extends Storage to handle embedding and searching across entities using Mem0.
-    """
+    """Extends Storage to handle embedding and searching across entities using Mem0."""
 
-    def __init__(self, type, crew=None, config=None):
+    def __init__(self, type, crew=None, config=None) -> None:
         super().__init__()
         supported_types = ["user", "short_term", "long_term", "entities", "external"]
         if type not in supported_types:
             raise ValueError(
                 f"Invalid type '{type}' for Mem0Storage. Must be one of: "
-                + ", ".join(supported_types)
+                + ", ".join(supported_types),
             )
 
         self.memory_type = type
@@ -29,7 +27,8 @@ class Mem0Storage(Storage):
         # User ID is required for user memory type "user" since it's used as a unique identifier for the user.
         user_id = self._get_user_id()
         if type == "user" and not user_id:
-            raise ValueError("User ID is required for user memory type")
+            msg = "User ID is required for user memory type"
+            raise ValueError(msg)
 
         # API key in memory config overrides the environment variable
         config = self._get_config()
@@ -42,23 +41,20 @@ class Mem0Storage(Storage):
         if mem0_api_key:
             if mem0_org_id and mem0_project_id:
                 self.memory = MemoryClient(
-                    api_key=mem0_api_key, org_id=mem0_org_id, project_id=mem0_project_id
+                    api_key=mem0_api_key, org_id=mem0_org_id, project_id=mem0_project_id,
                 )
             else:
                 self.memory = MemoryClient(api_key=mem0_api_key)
+        elif mem0_local_config and len(mem0_local_config):
+            self.memory = Memory.from_config(mem0_local_config)
         else:
-            if mem0_local_config and len(mem0_local_config):
-                self.memory = Memory.from_config(mem0_local_config)
-            else:
-                self.memory = Memory()
+            self.memory = Memory()
 
     def _sanitize_role(self, role: str) -> str:
-        """
-        Sanitizes agent roles to ensure valid directory names.
-        """
+        """Sanitizes agent roles to ensure valid directory names."""
         return role.replace("\n", "").replace(" ", "_").replace("/", "_")
 
-    def save(self, value: Any, metadata: Dict[str, Any]) -> None:
+    def save(self, value: Any, metadata: dict[str, Any]) -> None:
         user_id = self._get_user_id()
         agent_name = self._get_agent_name()
         params = None
@@ -97,7 +93,7 @@ class Mem0Storage(Storage):
         query: str,
         limit: int = 3,
         score_threshold: float = 0.35,
-    ) -> List[Any]:
+    ) -> list[Any]:
         params = {"query": query, "limit": limit, "output_format": "v1.1"}
         if user_id := self._get_user_id():
             params["user_id"] = user_id
@@ -120,7 +116,7 @@ class Mem0Storage(Storage):
         # automatically when the crew is created.
         if isinstance(self.memory, Memory):
             del params["metadata"], params["output_format"]
 
         results = self.memory.search(**params)
         return [r for r in results["results"] if r["score"] >= score_threshold]
 
@@ -133,12 +129,11 @@ class Mem0Storage(Storage):
 
         agents = self.crew.agents
         agents = [self._sanitize_role(agent.role) for agent in agents]
-        agents = "_".join(agents)
-        return agents
+        return "_".join(agents)
 
-    def _get_config(self) -> Dict[str, Any]:
+    def _get_config(self) -> dict[str, Any]:
         return self.config or getattr(self, "memory_config", {}).get("config", {}) or {}
 
-    def reset(self):
+    def reset(self) -> None:
         if self.memory:
             self.memory.reset()
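Note: the client-selection rewrite flattens an `else:` containing a nested `if/else` into `elif/else` (ruff's collapsible-else-if fix, PLR5501; inferred). Behavior is unchanged; a condensed sketch of the resulting priority order (the inner org/project branch is elided):

def pick_client(api_key: str | None, local_config: dict | None) -> str:
    if api_key:
        client = "MemoryClient"          # hosted client wins when an API key exists
    elif local_config:
        client = "Memory.from_config"    # local configuration is checked second
    else:
        client = "Memory"                # bare default comes last
    return client

assert pick_client(None, {"llm": "..."}) == "Memory.from_config"
assert pick_client(None, None) == "Memory"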
src/crewai/memory/storage/rag_storage.py

@@ -4,7 +4,7 @@ import logging
 import os
 import shutil
 import uuid
-from typing import Any, Dict, List, Optional
+from typing import Any
 
 from chromadb.api import ClientAPI
 
@@ -32,16 +32,15 @@ def suppress_logging(
 
 
 class RAGStorage(BaseRAGStorage):
-    """
-    Extends Storage to handle embeddings for memory entries, improving
+    """Extends Storage to handle embeddings for memory entries, improving
     search efficiency.
     """
 
     app: ClientAPI | None = None
 
     def __init__(
-        self, type, allow_reset=True, embedder_config=None, crew=None, path=None
-    ):
+        self, type, allow_reset=True, embedder_config=None, crew=None, path=None,
+    ) -> None:
         super().__init__(type, allow_reset, embedder_config, crew)
         agents = crew.agents if crew else []
         agents = [self._sanitize_role(agent.role) for agent in agents]
@@ -55,11 +54,11 @@ class RAGStorage(BaseRAGStorage):
         self.path = path
         self._initialize_app()
 
-    def _set_embedder_config(self):
+    def _set_embedder_config(self) -> None:
         configurator = EmbeddingConfigurator()
         self.embedder_config = configurator.configure_embedder(self.embedder_config)
 
-    def _initialize_app(self):
+    def _initialize_app(self) -> None:
         import chromadb
         from chromadb.config import Settings
 
@@ -73,48 +72,44 @@ class RAGStorage(BaseRAGStorage):
 
         try:
             self.collection = self.app.get_collection(
-                name=self.type, embedding_function=self.embedder_config
+                name=self.type, embedding_function=self.embedder_config,
             )
         except Exception:
             self.collection = self.app.create_collection(
-                name=self.type, embedding_function=self.embedder_config
+                name=self.type, embedding_function=self.embedder_config,
             )
 
     def _sanitize_role(self, role: str) -> str:
-        """
-        Sanitizes agent roles to ensure valid directory names.
-        """
+        """Sanitizes agent roles to ensure valid directory names."""
         return role.replace("\n", "").replace(" ", "_").replace("/", "_")
 
     def _build_storage_file_name(self, type: str, file_name: str) -> str:
-        """
-        Ensures file name does not exceed max allowed by OS
-        """
+        """Ensures file name does not exceed max allowed by OS."""
         base_path = f"{db_storage_path()}/{type}"
 
         if len(file_name) > MAX_FILE_NAME_LENGTH:
             logging.warning(
-                f"Trimming file name from {len(file_name)} to {MAX_FILE_NAME_LENGTH} characters."
+                f"Trimming file name from {len(file_name)} to {MAX_FILE_NAME_LENGTH} characters.",
             )
             file_name = file_name[:MAX_FILE_NAME_LENGTH]
 
         return f"{base_path}/{file_name}"
 
-    def save(self, value: Any, metadata: Dict[str, Any]) -> None:
+    def save(self, value: Any, metadata: dict[str, Any]) -> None:
         if not hasattr(self, "app") or not hasattr(self, "collection"):
             self._initialize_app()
         try:
             self._generate_embedding(value, metadata)
         except Exception as e:
-            logging.error(f"Error during {self.type} save: {str(e)}")
+            logging.exception(f"Error during {self.type} save: {e!s}")
 
     def search(
         self,
         query: str,
         limit: int = 3,
-        filter: Optional[dict] = None,
+        filter: dict | None = None,
         score_threshold: float = 0.35,
-    ) -> List[Any]:
+    ) -> list[Any]:
         if not hasattr(self, "app"):
             self._initialize_app()
 
@@ -135,10 +130,10 @@ class RAGStorage(BaseRAGStorage):
 
             return results
         except Exception as e:
-            logging.error(f"Error during {self.type} search: {str(e)}")
+            logging.exception(f"Error during {self.type} search: {e!s}")
             return []
 
-    def _generate_embedding(self, text: str, metadata: Dict[str, Any]) -> None:  # type: ignore
+    def _generate_embedding(self, text: str, metadata: dict[str, Any]) -> None:  # type: ignore
         if not hasattr(self, "app") or not hasattr(self, "collection"):
             self._initialize_app()
 
@@ -160,8 +155,9 @@ class RAGStorage(BaseRAGStorage):
                 # Ignore this specific error
                 pass
             else:
+                msg = f"An error occurred while resetting the {self.type} memory: {e}"
                 raise Exception(
-                    f"An error occurred while resetting the {self.type} memory: {e}"
+                    msg,
                 )
 
     def _create_default_embedding_function(self):
@@ -170,5 +166,5 @@ class RAGStorage(BaseRAGStorage):
         )
 
         return OpenAIEmbeddingFunction(
-            api_key=os.getenv("OPENAI_API_KEY"), model_name="text-embedding-3-small"
+            api_key=os.getenv("OPENAI_API_KEY"), model_name="text-embedding-3-small",
         )
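Note: one micro-pattern in this file is `f"... {str(e)}"` becoming `f"... {e!s}"`. The `!s` conversion applies `str()` inside the f-string (ruff RUF010-style; inferred), and `logging.error` in `except` blocks becomes `logging.exception` so the traceback is recorded. Equivalence check:

e = ValueError("boom")
assert f"Error: {str(e)}" == f"Error: {e!s}" == "Error: boom"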
src/crewai/memory/user/user_memory.py

@@ -1,18 +1,17 @@
 import warnings
-from typing import Any, Dict, Optional
+from typing import Any
 
 from crewai.memory.memory import Memory
 
 
 class UserMemory(Memory):
-    """
-    UserMemory class for handling user memory storage and retrieval.
+    """UserMemory class for handling user memory storage and retrieval.
     Inherits from the Memory class and utilizes an instance of a class that
     adheres to the Storage for data storage, specifically working with
     MemoryItem instances.
     """
 
-    def __init__(self, crew=None):
+    def __init__(self, crew=None) -> None:
         warnings.warn(
             "UserMemory is deprecated and will be removed in a future version. "
             "Please use ExternalMemory instead.",
@@ -22,8 +21,9 @@ class UserMemory(Memory):
         try:
             from crewai.memory.storage.mem0_storage import Mem0Storage
         except ImportError:
+            msg = "Mem0 is not installed. Please install it with `pip install mem0ai`."
             raise ImportError(
-                "Mem0 is not installed. Please install it with `pip install mem0ai`."
+                msg,
             )
         storage = Mem0Storage(type="user", crew=crew)
         super().__init__(storage)
@@ -31,8 +31,8 @@ class UserMemory(Memory):
     def save(
         self,
         value,
-        metadata: Optional[Dict[str, Any]] = None,
-        agent: Optional[str] = None,
+        metadata: dict[str, Any] | None = None,
+        agent: str | None = None,
     ) -> None:
         # TODO: Change this function since we want to take care of the case where we save memories for the usr
         data = f"Remember the details about the user: {value}"
@@ -44,15 +44,15 @@ class UserMemory(Memory):
         limit: int = 3,
         score_threshold: float = 0.35,
     ):
-        results = self.storage.search(
+        return self.storage.search(
             query=query,
             limit=limit,
             score_threshold=score_threshold,
         )
-        return results
 
     def reset(self) -> None:
         try:
             self.storage.reset()
         except Exception as e:
-            raise Exception(f"An error occurred while resetting the user memory: {e}")
+            msg = f"An error occurred while resetting the user memory: {e}"
+            raise Exception(msg)
src/crewai/memory/user/user_memory_item.py

@@ -1,8 +1,8 @@
-from typing import Any, Dict, Optional
+from typing import Any
 
 
 class UserMemoryItem:
-    def __init__(self, data: Any, user: str, metadata: Optional[Dict[str, Any]] = None):
+    def __init__(self, data: Any, user: str, metadata: dict[str, Any] | None = None) -> None:
         self.data = data
         self.user = user
         self.metadata = metadata if metadata is not None else {}
src/crewai/process.py

@@ -2,9 +2,7 @@ from enum import Enum
 
 
 class Process(str, Enum):
-    """
-    Class representing the different processes that can be used to tackle tasks
-    """
+    """Class representing the different processes that can be used to tackle tasks."""
 
     sequential = "sequential"
     hierarchical = "hierarchical"
src/crewai/project/annotations.py

@@ -1,5 +1,5 @@
+from collections.abc import Callable
 from functools import wraps
-from typing import Callable
 
 from crewai import Crew
 from crewai.project.utils import memoize
@@ -36,15 +36,13 @@ def task(func):
 def agent(func):
     """Marks a method as a crew agent."""
     func.is_agent = True
-    func = memoize(func)
-    return func
+    return memoize(func)
 
 
 def llm(func):
     """Marks a method as an LLM provider."""
     func.is_llm = True
-    func = memoize(func)
-    return func
+    return memoize(func)
 
 
 def output_json(cls):
@@ -91,7 +89,7 @@ def crew(func) -> Callable[..., Crew]:
             agents = self._original_agents.items()
 
             # Instantiate tasks in order
-            for task_name, task_method in tasks:
+            for _task_name, task_method in tasks:
                 task_instance = task_method(self)
                 instantiated_tasks.append(task_instance)
                 agent_instance = getattr(task_instance, "agent", None)
@@ -100,7 +98,7 @@ def crew(func) -> Callable[..., Crew]:
                     agent_roles.add(agent_instance.role)
 
             # Instantiate agents not included by tasks
-            for agent_name, agent_method in agents:
+            for _agent_name, agent_method in agents:
                 agent_instance = agent_method(self)
                 if agent_instance.role not in agent_roles:
                     instantiated_agents.append(agent_instance)
@@ -117,9 +115,9 @@ def crew(func) -> Callable[..., Crew]:
 
             return wrapper
 
-        for _, callback in self._before_kickoff.items():
+        for callback in self._before_kickoff.values():
             crew.before_kickoff_callbacks.append(callback_wrapper(callback, self))
-        for _, callback in self._after_kickoff.items():
+        for callback in self._after_kickoff.values():
             crew.after_kickoff_callbacks.append(callback_wrapper(callback, self))
 
         return crew
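Note: the `agent`/`llm` decorator cleanup removes an intermediate assignment (`func = memoize(func); return func` becomes `return memoize(func)`). The marker attribute is set on the undecorated function before memoization. A sketch of why the tag still survives — this `memoize` is a stand-in, not crewai.project.utils.memoize (assumption: the real one also wraps the callable):

import functools

def memoize(func):
    # Stand-in: cache results per call args.
    cache = {}

    @functools.wraps(func)
    def wrapper(*args):
        if args not in cache:
            cache[args] = func(*args)
        return cache[args]

    return wrapper

def agent(func):
    """Marks a method as a crew agent."""
    func.is_agent = True      # tag the original function...
    return memoize(func)      # ...and return the memoized wrapper directly

@agent
def researcher() -> str:
    return "Agent(role='researcher')"

assert researcher.is_agent        # wraps() copies __dict__, so the tag survives
assert researcher() == researcher()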
src/crewai/project/crew_base.py

@@ -1,7 +1,8 @@
 import inspect
 import logging
+from collections.abc import Callable
 from pathlib import Path
-from typing import Any, Callable, Dict, TypeVar, cast
+from typing import Any, TypeVar, cast
 
 import yaml
 from dotenv import load_dotenv
@@ -23,11 +24,11 @@ def CrewBase(cls: T) -> T:
         base_directory = Path(inspect.getfile(cls)).parent
 
         original_agents_config_path = getattr(
-            cls, "agents_config", "config/agents.yaml"
+            cls, "agents_config", "config/agents.yaml",
        )
        original_tasks_config_path = getattr(cls, "tasks_config", "config/tasks.yaml")
 
-        def __init__(self, *args, **kwargs):
+        def __init__(self, *args, **kwargs) -> None:
            super().__init__(*args, **kwargs)
            self.load_configurations()
            self.map_all_agent_variables()
@@ -49,22 +50,22 @@ def CrewBase(cls: T) -> T:
            }
            # Store specific function types
            self._original_tasks = self._filter_functions(
-                self._original_functions, "is_task"
+                self._original_functions, "is_task",
            )
            self._original_agents = self._filter_functions(
-                self._original_functions, "is_agent"
+                self._original_functions, "is_agent",
            )
            self._before_kickoff = self._filter_functions(
-                self._original_functions, "is_before_kickoff"
+                self._original_functions, "is_before_kickoff",
            )
            self._after_kickoff = self._filter_functions(
-                self._original_functions, "is_after_kickoff"
+                self._original_functions, "is_after_kickoff",
            )
            self._kickoff = self._filter_functions(
-                self._original_functions, "is_kickoff"
+                self._original_functions, "is_kickoff",
            )
 
-        def load_configurations(self):
+        def load_configurations(self) -> None:
            """Load agent and task configurations from YAML files."""
            if isinstance(self.original_agents_config_path, str):
                agents_config_path = (
@@ -75,12 +76,12 @@ def CrewBase(cls: T) -> T:
                except FileNotFoundError:
                    logging.warning(
                        f"Agent config file not found at {agents_config_path}. "
-                        "Proceeding with empty agent configurations."
+                        "Proceeding with empty agent configurations.",
                    )
                    self.agents_config = {}
            else:
                logging.warning(
-                    "No agent configuration path provided. Proceeding with empty agent configurations."
+                    "No agent configuration path provided. Proceeding with empty agent configurations.",
                )
                self.agents_config = {}
 
@@ -93,22 +94,21 @@ def CrewBase(cls: T) -> T:
                except FileNotFoundError:
                    logging.warning(
                        f"Task config file not found at {tasks_config_path}. "
-                        "Proceeding with empty task configurations."
+                        "Proceeding with empty task configurations.",
                    )
                    self.tasks_config = {}
            else:
                logging.warning(
-                    "No task configuration path provided. Proceeding with empty task configurations."
+                    "No task configuration path provided. Proceeding with empty task configurations.",
                )
                self.tasks_config = {}
 
        @staticmethod
        def load_yaml(config_path: Path):
            try:
-                with open(config_path, "r", encoding="utf-8") as file:
+                with open(config_path, encoding="utf-8") as file:
                    return yaml.safe_load(file)
            except FileNotFoundError:
                print(f"File not found: {config_path}")
                raise
 
        def _get_all_functions(self):
@@ -119,8 +119,8 @@ def CrewBase(cls: T) -> T:
            }
 
        def _filter_functions(
-            self, functions: Dict[str, Callable], attribute: str
-        ) -> Dict[str, Callable]:
+            self, functions: dict[str, Callable], attribute: str,
+        ) -> dict[str, Callable]:
            return {
                name: func
                for name, func in functions.items()
@@ -132,7 +132,7 @@ def CrewBase(cls: T) -> T:
            llms = self._filter_functions(all_functions, "is_llm")
            tool_functions = self._filter_functions(all_functions, "is_tool")
            cache_handler_functions = self._filter_functions(
-                all_functions, "is_cache_handler"
+                all_functions, "is_cache_handler",
            )
            callbacks = self._filter_functions(all_functions, "is_callback")
 
@@ -149,11 +149,11 @@ def CrewBase(cls: T) -> T:
        def _map_agent_variables(
            self,
            agent_name: str,
-            agent_info: Dict[str, Any],
-            llms: Dict[str, Callable],
-            tool_functions: Dict[str, Callable],
-            cache_handler_functions: Dict[str, Callable],
-            callbacks: Dict[str, Callable],
+            agent_info: dict[str, Any],
+            llms: dict[str, Callable],
+            tool_functions: dict[str, Callable],
+            cache_handler_functions: dict[str, Callable],
+            callbacks: dict[str, Callable],
        ) -> None:
            if llm := agent_info.get("llm"):
                try:
@@ -187,12 +187,12 @@ def CrewBase(cls: T) -> T:
            agents = self._filter_functions(all_functions, "is_agent")
            tasks = self._filter_functions(all_functions, "is_task")
            output_json_functions = self._filter_functions(
-                all_functions, "is_output_json"
+                all_functions, "is_output_json",
            )
            tool_functions = self._filter_functions(all_functions, "is_tool")
            callback_functions = self._filter_functions(all_functions, "is_callback")
            output_pydantic_functions = self._filter_functions(
-                all_functions, "is_output_pydantic"
+                all_functions, "is_output_pydantic",
            )
 
            for task_name, task_info in self.tasks_config.items():
@@ -210,13 +210,13 @@ def CrewBase(cls: T) -> T:
        def _map_task_variables(
            self,
            task_name: str,
-            task_info: Dict[str, Any],
-            agents: Dict[str, Callable],
-            tasks: Dict[str, Callable],
-            output_json_functions: Dict[str, Callable],
-            tool_functions: Dict[str, Callable],
-            callback_functions: Dict[str, Callable],
-            output_pydantic_functions: Dict[str, Callable],
+            task_info: dict[str, Any],
+            agents: dict[str, Callable],
+            tasks: dict[str, Callable],
+            output_json_functions: dict[str, Callable],
+            tool_functions: dict[str, Callable],
+            callback_functions: dict[str, Callable],
+            output_pydantic_functions: dict[str, Callable],
        ) -> None:
            if context_list := task_info.get("context"):
                self.tasks_config[task_name]["context"] = [
@@ -253,4 +253,4 @@ def CrewBase(cls: T) -> T:
    WrappedClass.__name__ = CrewBase.__name__ + "(" + cls.__name__ + ")"
    WrappedClass.__qualname__ = CrewBase.__qualname__ + "(" + cls.__name__ + ")"
 
-    return cast(T, WrappedClass)
+    return cast("T", WrappedClass)
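Note: `cast(T, WrappedClass)` → `cast("T", WrappedClass)` quotes the target type so the TypeVar expression is never evaluated at runtime (a ruff TC006-style fix; inferred). `typing.cast` is an identity function either way:

from typing import TypeVar, cast

T = TypeVar("T")

def wrap(cls):
    class Wrapped(cls):
        pass
    return cast("T", Wrapped)   # string form: the annotation is not evaluated at runtime

print(wrap(dict)().__class__.__name__)   # Wrapped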
src/crewai/security/fingerprint.py

@@ -1,5 +1,4 @@
-"""
-Fingerprint Module
+"""Fingerprint Module.
 
 This module provides functionality for generating and validating unique identifiers
 for CrewAI agents. These identifiers are used for tracking, auditing, and security.
@@ -7,14 +6,13 @@ for CrewAI agents. These identifiers are used for tracking, auditing, and securi
 
 import uuid
 from datetime import datetime
-from typing import Any, Dict, Optional
+from typing import Any
 
 from pydantic import BaseModel, ConfigDict, Field, field_validator
 
 
 class Fingerprint(BaseModel):
-    """
-    A class for generating and managing unique identifiers for agents.
+    """A class for generating and managing unique identifiers for agents.
 
     Each agent has dual identifiers:
     - Human-readable ID: For debugging and reference (derived from role if not specified)
@@ -24,48 +22,54 @@ class Fingerprint(BaseModel):
         uuid_str (str): String representation of the UUID for this fingerprint, auto-generated
         created_at (datetime): When this fingerprint was created, auto-generated
         metadata (Dict[str, Any]): Additional metadata associated with this fingerprint
+
     """
 
     uuid_str: str = Field(default_factory=lambda: str(uuid.uuid4()), description="String representation of the UUID")
     created_at: datetime = Field(default_factory=datetime.now, description="When this fingerprint was created")
-    metadata: Dict[str, Any] = Field(default_factory=dict, description="Additional metadata for this fingerprint")
+    metadata: dict[str, Any] = Field(default_factory=dict, description="Additional metadata for this fingerprint")
 
     model_config = ConfigDict(arbitrary_types_allowed=True)
 
-    @field_validator('metadata')
+    @field_validator("metadata")
     @classmethod
     def validate_metadata(cls, v):
         """Validate that metadata is a dictionary with string keys and valid values."""
         if not isinstance(v, dict):
-            raise ValueError("Metadata must be a dictionary")
+            msg = "Metadata must be a dictionary"
+            raise ValueError(msg)
 
         # Validate that all keys are strings
         for key, value in v.items():
             if not isinstance(key, str):
-                raise ValueError(f"Metadata keys must be strings, got {type(key)}")
+                msg = f"Metadata keys must be strings, got {type(key)}"
+                raise ValueError(msg)
 
             # Validate nested dictionaries (prevent deeply nested structures)
             if isinstance(value, dict):
                 # Check for nested dictionaries (limit depth to 1)
                 for nested_key, nested_value in value.items():
                     if not isinstance(nested_key, str):
-                        raise ValueError(f"Nested metadata keys must be strings, got {type(nested_key)}")
+                        msg = f"Nested metadata keys must be strings, got {type(nested_key)}"
+                        raise ValueError(msg)
                     if isinstance(nested_value, dict):
-                        raise ValueError("Metadata can only be nested one level deep")
+                        msg = "Metadata can only be nested one level deep"
+                        raise ValueError(msg)
 
         # Check for maximum metadata size (prevent DoS)
         if len(str(v)) > 10000:  # Limit metadata size to 10KB
-            raise ValueError("Metadata size exceeds maximum allowed (10KB)")
+            msg = "Metadata size exceeds maximum allowed (10KB)"
+            raise ValueError(msg)
 
         return v
 
-    def __init__(self, **data):
+    def __init__(self, **data) -> None:
         """Initialize a Fingerprint with auto-generated uuid_str and created_at."""
         # Remove uuid_str and created_at from data to ensure they're auto-generated
-        if 'uuid_str' in data:
-            data.pop('uuid_str')
-        if 'created_at' in data:
-            data.pop('created_at')
+        if "uuid_str" in data:
+            data.pop("uuid_str")
+        if "created_at" in data:
+            data.pop("created_at")
 
         # Call the parent constructor with the modified data
         super().__init__(**data)
@@ -77,32 +81,33 @@ class Fingerprint(BaseModel):
 
     @classmethod
     def _generate_uuid(cls, seed: str) -> str:
-        """
-        Generate a deterministic UUID based on a seed string.
+        """Generate a deterministic UUID based on a seed string.
 
         Args:
             seed (str): The seed string to use for UUID generation
 
         Returns:
             str: A string representation of the UUID consistently generated from the seed
+
         """
         if not isinstance(seed, str):
-            raise ValueError("Seed must be a string")
+            msg = "Seed must be a string"
+            raise ValueError(msg)
 
         if not seed.strip():
-            raise ValueError("Seed cannot be empty or whitespace")
+            msg = "Seed cannot be empty or whitespace"
+            raise ValueError(msg)
 
         # Create a deterministic UUID using v5 (SHA-1)
         # Custom namespace for CrewAI to enhance security
 
         # Using a unique namespace specific to CrewAI to reduce collision risks
-        CREW_AI_NAMESPACE = uuid.UUID('f47ac10b-58cc-4372-a567-0e02b2c3d479')
+        CREW_AI_NAMESPACE = uuid.UUID("f47ac10b-58cc-4372-a567-0e02b2c3d479")
         return str(uuid.uuid5(CREW_AI_NAMESPACE, seed))
 
     @classmethod
-    def generate(cls, seed: Optional[str] = None, metadata: Optional[Dict[str, Any]] = None) -> 'Fingerprint':
-        """
-        Static factory method to create a new Fingerprint.
+    def generate(cls, seed: str | None = None, metadata: dict[str, Any] | None = None) -> "Fingerprint":
+        """Static factory method to create a new Fingerprint.
 
         Args:
             seed (Optional[str]): A string to use as seed for the UUID generation.
@@ -111,11 +116,12 @@ class Fingerprint(BaseModel):
 
         Returns:
             Fingerprint: A new Fingerprint instance
+
         """
         fingerprint = cls(metadata=metadata or {})
         if seed:
             # For seed-based generation, we need to manually set the uuid_str after creation
-            object.__setattr__(fingerprint, 'uuid_str', cls._generate_uuid(seed))
+            object.__setattr__(fingerprint, "uuid_str", cls._generate_uuid(seed))
         return fingerprint
 
     def __str__(self) -> str:
@@ -132,29 +138,29 @@ class Fingerprint(BaseModel):
         """Hash of the fingerprint (based on UUID)."""
         return hash(self.uuid_str)
 
-    def to_dict(self) -> Dict[str, Any]:
-        """
-        Convert the fingerprint to a dictionary representation.
+    def to_dict(self) -> dict[str, Any]:
+        """Convert the fingerprint to a dictionary representation.
 
         Returns:
             Dict[str, Any]: Dictionary representation of the fingerprint
+
         """
         return {
             "uuid_str": self.uuid_str,
             "created_at": self.created_at.isoformat(),
-            "metadata": self.metadata
+            "metadata": self.metadata,
        }
 
     @classmethod
-    def from_dict(cls, data: Dict[str, Any]) -> 'Fingerprint':
-        """
-        Create a Fingerprint from a dictionary representation.
+    def from_dict(cls, data: dict[str, Any]) -> "Fingerprint":
+        """Create a Fingerprint from a dictionary representation.
 
         Args:
             data (Dict[str, Any]): Dictionary representation of a fingerprint
 
         Returns:
             Fingerprint: A new Fingerprint instance
+
         """
         if not data:
             return cls()
@@ -163,8 +169,8 @@ class Fingerprint(BaseModel):
 
         # For consistency with existing stored fingerprints, we need to manually set these
         if "uuid_str" in data:
-            object.__setattr__(fingerprint, 'uuid_str', data["uuid_str"])
+            object.__setattr__(fingerprint, "uuid_str", data["uuid_str"])
         if "created_at" in data and isinstance(data["created_at"], str):
-            object.__setattr__(fingerprint, 'created_at', datetime.fromisoformat(data["created_at"]))
+            object.__setattr__(fingerprint, "created_at", datetime.fromisoformat(data["created_at"]))
 
         return fingerprint
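Note: the Fingerprint diff is mostly quote normalization plus the same msg-extraction pattern, but the deterministic-UUID path deserves a word — `uuid.uuid5(namespace, seed)` is a pure function of its inputs, so the same seed always yields the same fingerprint across runs and machines. A quick check (namespace value copied from the diff above):

import uuid

CREW_AI_NAMESPACE = uuid.UUID("f47ac10b-58cc-4372-a567-0e02b2c3d479")

a = uuid.uuid5(CREW_AI_NAMESPACE, "agent-researcher")
b = uuid.uuid5(CREW_AI_NAMESPACE, "agent-researcher")
assert a == b   # uuid5 is deterministic, unlike the uuid4 used for unseeded fingerprints
print(a)        # stable identifier for this seed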
Some files were not shown because too many files have changed in this diff.