mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-04 13:48:31 +00:00
Compare commits
7 Commits
lg-agent-r
...
devin/1747
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
90980e8190 | ||
|
|
f6cdfc1099 | ||
|
|
39c4ed33bb | ||
|
|
1860026d61 | ||
|
|
46621113af | ||
|
|
ad1ea46bbb | ||
|
|
807dfe0558 |
47
.ruff.toml
47
.ruff.toml
@@ -2,3 +2,50 @@ exclude = [
|
||||
"templates",
|
||||
"__init__.py",
|
||||
]
|
||||
|
||||
[lint]
|
||||
select = ["ALL"]
|
||||
ignore = [
|
||||
"D100", # Missing docstring in public module
|
||||
"D101", # Missing docstring in public class
|
||||
"D102", # Missing docstring in public method
|
||||
"D103", # Missing docstring in public function
|
||||
"D104", # Missing docstring in public package
|
||||
"D105", # Missing docstring in magic method
|
||||
"D106", # Missing docstring in public nested class
|
||||
"D107", # Missing docstring in __init__
|
||||
"D205", # 1 blank line required between summary line and description
|
||||
"ANN001", # Missing type annotation for function argument
|
||||
"ANN002", # Missing type annotation for *args
|
||||
"ANN003", # Missing type annotation for **kwargs
|
||||
"ANN201", # Missing return type annotation for public function
|
||||
"ANN202", # Missing return type annotation for private function
|
||||
"ANN204", # Missing return type annotation for special method
|
||||
"ANN205", # Missing return type annotation for staticmethod
|
||||
"ANN206", # Missing return type annotation for classmethod
|
||||
"E501", # Line too long
|
||||
"PT011", # pytest.raises() without match parameter
|
||||
"PT012", # pytest.raises() block should contain a single simple statement
|
||||
"SIM117", # Use a single `with` statement with multiple contexts
|
||||
"PLR2004", # Magic value used in comparison
|
||||
"B017", # Do not assert blind exception
|
||||
]
|
||||
|
||||
[lint.per-file-ignores]
|
||||
"tests/*" = [
|
||||
"S101", # Allow assert in tests
|
||||
"SLF001", # Allow private member access in tests
|
||||
"DTZ001", # Allow datetime without tzinfo in tests
|
||||
"PTH107", # Allow os.remove instead of Path.unlink in tests
|
||||
"PTH118", # Allow os.path.join() in tests
|
||||
"PTH120", # Allow os.path.dirname() in tests
|
||||
"PTH123", # Allow open() instead of Path.open() in tests
|
||||
"PTH202", # Allow os.path.getsize in tests
|
||||
"PT012", # Allow multiple statements in pytest.raises() block in tests
|
||||
"SIM117", # Allow nested with statements in tests
|
||||
"PLR2004", # Allow magic values in tests
|
||||
"B017", # Allow asserting blind exceptions in tests
|
||||
]
|
||||
|
||||
[lint.isort]
|
||||
known-first-party = ["crewai"]
|
||||
|
||||
@@ -1,70 +0,0 @@
|
||||
---
|
||||
title: "Agent Repository"
|
||||
description: "Store and retrieve agents for your CrewAI projects"
|
||||
---
|
||||
|
||||
# Agent Repository
|
||||
|
||||
The Agent Repository allows you to store, manage, and reuse agents across your CrewAI projects. This feature streamlines the development process by enabling you to configure agents once and use them in multiple projects.
|
||||
|
||||
## How It Works
|
||||
|
||||
When you create an agent in the CrewAI interface, it's stored in the Agent Repository. You can then initialize these agents in your code using the `from_repository` parameter.
|
||||
|
||||
## Usage
|
||||
|
||||
To use an agent from the repository in your CrewAI project, initialize it with the following code:
|
||||
|
||||
```python
|
||||
from crewai import Agent
|
||||
|
||||
# Initialize the agent with its role
|
||||
agent = Agent(from_repository="python-job-researcher")
|
||||
```
|
||||
|
||||
### Creating a Crew with Repository Agents
|
||||
|
||||
```python
|
||||
from crewai import Agent, Crew, Task
|
||||
|
||||
agent = Agent(from_repository="python-job-researcher")
|
||||
|
||||
job_search_task = Task(
|
||||
description="Search for recent Python developer job listings online",
|
||||
expected_output="Markdown list of 5 recent Python developer jobs with details.",
|
||||
agent=agent,
|
||||
)
|
||||
|
||||
crew = Crew(agents=[agent], tasks=[job_search_task], verbose=True)
|
||||
|
||||
result = crew.kickoff()
|
||||
print(result)
|
||||
```
|
||||
|
||||
## Important Notes
|
||||
|
||||
- The `from_repository` value must match the agent's role in a URL-safe format.
|
||||
- If you change an agent's role after creation, you must update the `from_repository` value in your code accordingly, or you won't be able to find the agent anymore.
|
||||
- Make sure you have permission to use the agent as mentioned in the key points.
|
||||
|
||||
## Agent Configuration
|
||||
|
||||
When configuring an agent in the repository, you can specify:
|
||||
|
||||
1. **Role** - The agent's primary function (e.g., "Python Job Researcher")
|
||||
2. **Goal** - What the agent aims to achieve (e.g., "Find Python developer job opportunities")
|
||||
3. **Backstory** - Context for the agent's behavior
|
||||
4. **Tools** - Available capabilities for the agent to use when performing tasks
|
||||
5. **Visibility Controls** - Who can access and use the agent
|
||||
|
||||
## Managing Agents
|
||||
|
||||
The Agent Repository interface provides functionality to:
|
||||
|
||||
- View all available agents
|
||||
- Add new agents
|
||||
- Edit existing agents
|
||||
- Delete agents
|
||||
- View agent details including usage examples
|
||||
|
||||
By leveraging the Agent Repository, you can build more modular and reusable AI workflows while maintaining a central location for managing your agents.
|
||||
@@ -94,8 +94,10 @@ crewai = "crewai.cli.cli:crewai"
|
||||
|
||||
[tool.mypy]
|
||||
ignore_missing_imports = true
|
||||
disable_error_code = 'import-untyped'
|
||||
disable_error_code = 'import-untyped,union-attr'
|
||||
exclude = ["cli/templates"]
|
||||
implicit_optional = true
|
||||
strict_optional = false
|
||||
|
||||
[tool.bandit]
|
||||
exclude_dirs = ["src/crewai/cli/templates"]
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import shutil
|
||||
import subprocess
|
||||
from typing import Any, Dict, List, Literal, Optional, Sequence, Type, Union
|
||||
from collections.abc import Sequence
|
||||
from typing import Any, Literal
|
||||
|
||||
from pydantic import Field, InstanceOf, PrivateAttr, model_validator
|
||||
|
||||
@@ -20,7 +21,6 @@ from crewai.tools.agent_tools.agent_tools import AgentTools
|
||||
from crewai.utilities import Converter, Prompts
|
||||
from crewai.utilities.agent_utils import (
|
||||
get_tool_names,
|
||||
load_agent_from_repository,
|
||||
parse_tools,
|
||||
render_text_description_and_args,
|
||||
)
|
||||
@@ -68,40 +68,41 @@ class Agent(BaseAgent):
|
||||
step_callback: Callback to be executed after each step of the agent execution.
|
||||
knowledge_sources: Knowledge sources for the agent.
|
||||
embedder: Embedder configuration for the agent.
|
||||
|
||||
"""
|
||||
|
||||
_times_executed: int = PrivateAttr(default=0)
|
||||
max_execution_time: Optional[int] = Field(
|
||||
max_execution_time: int | None = Field(
|
||||
default=None,
|
||||
description="Maximum execution time for an agent to execute a task",
|
||||
)
|
||||
agent_ops_agent_name: str = None # type: ignore # Incompatible types in assignment (expression has type "None", variable has type "str")
|
||||
agent_ops_agent_id: str = None # type: ignore # Incompatible types in assignment (expression has type "None", variable has type "str")
|
||||
step_callback: Optional[Any] = Field(
|
||||
step_callback: Any | None = Field(
|
||||
default=None,
|
||||
description="Callback to be executed after each step of the agent execution.",
|
||||
)
|
||||
use_system_prompt: Optional[bool] = Field(
|
||||
use_system_prompt: bool | None = Field(
|
||||
default=True,
|
||||
description="Use system prompt for the agent.",
|
||||
)
|
||||
llm: Union[str, InstanceOf[BaseLLM], Any] = Field(
|
||||
description="Language model that will run the agent.", default=None
|
||||
llm: str | InstanceOf[BaseLLM] | Any = Field(
|
||||
description="Language model that will run the agent.", default=None,
|
||||
)
|
||||
function_calling_llm: Optional[Union[str, InstanceOf[BaseLLM], Any]] = Field(
|
||||
description="Language model that will run the agent.", default=None
|
||||
function_calling_llm: str | InstanceOf[BaseLLM] | Any | None = Field(
|
||||
description="Language model that will run the agent.", default=None,
|
||||
)
|
||||
system_template: Optional[str] = Field(
|
||||
default=None, description="System format for the agent."
|
||||
system_template: str | None = Field(
|
||||
default=None, description="System format for the agent.",
|
||||
)
|
||||
prompt_template: Optional[str] = Field(
|
||||
default=None, description="Prompt format for the agent."
|
||||
prompt_template: str | None = Field(
|
||||
default=None, description="Prompt format for the agent.",
|
||||
)
|
||||
response_template: Optional[str] = Field(
|
||||
default=None, description="Response format for the agent."
|
||||
response_template: str | None = Field(
|
||||
default=None, description="Response format for the agent.",
|
||||
)
|
||||
allow_code_execution: Optional[bool] = Field(
|
||||
default=False, description="Enable code execution for the agent."
|
||||
allow_code_execution: bool | None = Field(
|
||||
default=False, description="Enable code execution for the agent.",
|
||||
)
|
||||
respect_context_window: bool = Field(
|
||||
default=True,
|
||||
@@ -119,32 +120,22 @@ class Agent(BaseAgent):
|
||||
default="safe",
|
||||
description="Mode for code execution: 'safe' (using Docker) or 'unsafe' (direct execution).",
|
||||
)
|
||||
embedder: Optional[Dict[str, Any]] = Field(
|
||||
embedder: dict[str, Any] | None = Field(
|
||||
default=None,
|
||||
description="Embedder configuration for the agent.",
|
||||
)
|
||||
agent_knowledge_context: Optional[str] = Field(
|
||||
agent_knowledge_context: str | None = Field(
|
||||
default=None,
|
||||
description="Knowledge context for the agent.",
|
||||
)
|
||||
crew_knowledge_context: Optional[str] = Field(
|
||||
crew_knowledge_context: str | None = Field(
|
||||
default=None,
|
||||
description="Knowledge context for the crew.",
|
||||
)
|
||||
knowledge_search_query: Optional[str] = Field(
|
||||
knowledge_search_query: str | None = Field(
|
||||
default=None,
|
||||
description="Knowledge search query for the agent dynamically generated by the agent.",
|
||||
)
|
||||
from_repository: Optional[str] = Field(
|
||||
default=None,
|
||||
description="The Agent's role to be used from your repository.",
|
||||
)
|
||||
|
||||
@model_validator(mode="before")
|
||||
def validate_from_repository(cls, v):
|
||||
if v is not None and (from_repository := v.get("from_repository")):
|
||||
return load_agent_from_repository(from_repository) | v
|
||||
return v
|
||||
|
||||
@model_validator(mode="after")
|
||||
def post_init_setup(self):
|
||||
@@ -152,7 +143,7 @@ class Agent(BaseAgent):
|
||||
|
||||
self.llm = create_llm(self.llm)
|
||||
if self.function_calling_llm and not isinstance(
|
||||
self.function_calling_llm, BaseLLM
|
||||
self.function_calling_llm, BaseLLM,
|
||||
):
|
||||
self.function_calling_llm = create_llm(self.function_calling_llm)
|
||||
|
||||
@@ -164,12 +155,12 @@ class Agent(BaseAgent):
|
||||
|
||||
return self
|
||||
|
||||
def _setup_agent_executor(self):
|
||||
def _setup_agent_executor(self) -> None:
|
||||
if not self.cache_handler:
|
||||
self.cache_handler = CacheHandler()
|
||||
self.set_cache_handler(self.cache_handler)
|
||||
|
||||
def set_knowledge(self, crew_embedder: Optional[Dict[str, Any]] = None):
|
||||
def set_knowledge(self, crew_embedder: dict[str, Any] | None = None) -> None:
|
||||
try:
|
||||
if self.embedder is None and crew_embedder:
|
||||
self.embedder = crew_embedder
|
||||
@@ -185,7 +176,8 @@ class Agent(BaseAgent):
|
||||
storage=self.knowledge_storage or None,
|
||||
)
|
||||
except (TypeError, ValueError) as e:
|
||||
raise ValueError(f"Invalid Knowledge Configuration: {str(e)}")
|
||||
msg = f"Invalid Knowledge Configuration: {e!s}"
|
||||
raise ValueError(msg)
|
||||
|
||||
def _is_any_available_memory(self) -> bool:
|
||||
"""Check if any memory is available."""
|
||||
@@ -207,8 +199,8 @@ class Agent(BaseAgent):
|
||||
def execute_task(
|
||||
self,
|
||||
task: Task,
|
||||
context: Optional[str] = None,
|
||||
tools: Optional[List[BaseTool]] = None,
|
||||
context: str | None = None,
|
||||
tools: list[BaseTool] | None = None,
|
||||
) -> str:
|
||||
"""Execute a task with the agent.
|
||||
|
||||
@@ -224,6 +216,7 @@ class Agent(BaseAgent):
|
||||
TimeoutError: If execution exceeds the maximum execution time.
|
||||
ValueError: If the max execution time is not a positive integer.
|
||||
RuntimeError: If the agent execution fails for other reasons.
|
||||
|
||||
"""
|
||||
if self.tools_handler:
|
||||
self.tools_handler.last_used_tool = {} # type: ignore # Incompatible types in assignment (expression has type "dict[Never, Never]", variable has type "ToolCalling")
|
||||
@@ -239,18 +232,18 @@ class Agent(BaseAgent):
|
||||
# schema = json.dumps(task.output_json, indent=2)
|
||||
schema = generate_model_description(task.output_json)
|
||||
task_prompt += "\n" + self.i18n.slice(
|
||||
"formatted_task_instructions"
|
||||
"formatted_task_instructions",
|
||||
).format(output_format=schema)
|
||||
|
||||
elif task.output_pydantic:
|
||||
schema = generate_model_description(task.output_pydantic)
|
||||
task_prompt += "\n" + self.i18n.slice(
|
||||
"formatted_task_instructions"
|
||||
"formatted_task_instructions",
|
||||
).format(output_format=schema)
|
||||
|
||||
if context:
|
||||
task_prompt = self.i18n.slice("task_with_context").format(
|
||||
task=task_prompt, context=context
|
||||
task=task_prompt, context=context,
|
||||
)
|
||||
|
||||
if self._is_any_available_memory():
|
||||
@@ -278,25 +271,25 @@ class Agent(BaseAgent):
|
||||
)
|
||||
try:
|
||||
self.knowledge_search_query = self._get_knowledge_search_query(
|
||||
task_prompt
|
||||
task_prompt,
|
||||
)
|
||||
if self.knowledge_search_query:
|
||||
agent_knowledge_snippets = self.knowledge.query(
|
||||
[self.knowledge_search_query], **knowledge_config
|
||||
[self.knowledge_search_query], **knowledge_config,
|
||||
)
|
||||
if agent_knowledge_snippets:
|
||||
self.agent_knowledge_context = extract_knowledge_context(
|
||||
agent_knowledge_snippets
|
||||
agent_knowledge_snippets,
|
||||
)
|
||||
if self.agent_knowledge_context:
|
||||
task_prompt += self.agent_knowledge_context
|
||||
if self.crew:
|
||||
knowledge_snippets = self.crew.query_knowledge(
|
||||
[self.knowledge_search_query], **knowledge_config
|
||||
[self.knowledge_search_query], **knowledge_config,
|
||||
)
|
||||
if knowledge_snippets:
|
||||
self.crew_knowledge_context = extract_knowledge_context(
|
||||
knowledge_snippets
|
||||
knowledge_snippets,
|
||||
)
|
||||
if self.crew_knowledge_context:
|
||||
task_prompt += self.crew_knowledge_context
|
||||
@@ -353,11 +346,12 @@ class Agent(BaseAgent):
|
||||
not isinstance(self.max_execution_time, int)
|
||||
or self.max_execution_time <= 0
|
||||
):
|
||||
msg = "Max Execution time must be a positive integer greater than zero"
|
||||
raise ValueError(
|
||||
"Max Execution time must be a positive integer greater than zero"
|
||||
msg,
|
||||
)
|
||||
result = self._execute_with_timeout(
|
||||
task_prompt, task, self.max_execution_time
|
||||
task_prompt, task, self.max_execution_time,
|
||||
)
|
||||
else:
|
||||
result = self._execute_without_timeout(task_prompt, task)
|
||||
@@ -372,7 +366,7 @@ class Agent(BaseAgent):
|
||||
error=str(e),
|
||||
),
|
||||
)
|
||||
raise e
|
||||
raise
|
||||
except Exception as e:
|
||||
if e.__class__.__module__.startswith("litellm"):
|
||||
# Do not retry on litellm errors
|
||||
@@ -384,7 +378,7 @@ class Agent(BaseAgent):
|
||||
error=str(e),
|
||||
),
|
||||
)
|
||||
raise e
|
||||
raise
|
||||
self._times_executed += 1
|
||||
if self._times_executed > self.max_retry_limit:
|
||||
crewai_event_bus.emit(
|
||||
@@ -395,7 +389,7 @@ class Agent(BaseAgent):
|
||||
error=str(e),
|
||||
),
|
||||
)
|
||||
raise e
|
||||
raise
|
||||
result = self.execute_task(task, context, tools)
|
||||
|
||||
if self.max_rpm and self._rpm_controller:
|
||||
@@ -427,24 +421,27 @@ class Agent(BaseAgent):
|
||||
Raises:
|
||||
TimeoutError: If execution exceeds the timeout.
|
||||
RuntimeError: If execution fails for other reasons.
|
||||
|
||||
"""
|
||||
import concurrent.futures
|
||||
|
||||
with concurrent.futures.ThreadPoolExecutor() as executor:
|
||||
future = executor.submit(
|
||||
self._execute_without_timeout, task_prompt=task_prompt, task=task
|
||||
self._execute_without_timeout, task_prompt=task_prompt, task=task,
|
||||
)
|
||||
|
||||
try:
|
||||
return future.result(timeout=timeout)
|
||||
except concurrent.futures.TimeoutError:
|
||||
future.cancel()
|
||||
msg = f"Task '{task.description}' execution timed out after {timeout} seconds. Consider increasing max_execution_time or optimizing the task."
|
||||
raise TimeoutError(
|
||||
f"Task '{task.description}' execution timed out after {timeout} seconds. Consider increasing max_execution_time or optimizing the task."
|
||||
msg,
|
||||
)
|
||||
except Exception as e:
|
||||
future.cancel()
|
||||
raise RuntimeError(f"Task execution failed: {str(e)}")
|
||||
msg = f"Task execution failed: {e!s}"
|
||||
raise RuntimeError(msg)
|
||||
|
||||
def _execute_without_timeout(self, task_prompt: str, task: Task) -> str:
|
||||
"""Execute a task without a timeout.
|
||||
@@ -455,6 +452,7 @@ class Agent(BaseAgent):
|
||||
|
||||
Returns:
|
||||
The output of the agent.
|
||||
|
||||
"""
|
||||
return self.agent_executor.invoke(
|
||||
{
|
||||
@@ -462,18 +460,19 @@ class Agent(BaseAgent):
|
||||
"tool_names": self.agent_executor.tools_names,
|
||||
"tools": self.agent_executor.tools_description,
|
||||
"ask_for_human_input": task.human_input,
|
||||
}
|
||||
},
|
||||
)["output"]
|
||||
|
||||
def create_agent_executor(
|
||||
self, tools: Optional[List[BaseTool]] = None, task=None
|
||||
self, tools: list[BaseTool] | None = None, task=None,
|
||||
) -> None:
|
||||
"""Create an agent executor for the agent.
|
||||
|
||||
Returns:
|
||||
An instance of the CrewAgentExecutor class.
|
||||
|
||||
"""
|
||||
raw_tools: List[BaseTool] = tools or self.tools or []
|
||||
raw_tools: list[BaseTool] = tools or self.tools or []
|
||||
parsed_tools = parse_tools(raw_tools)
|
||||
|
||||
prompt = Prompts(
|
||||
@@ -490,7 +489,7 @@ class Agent(BaseAgent):
|
||||
|
||||
if self.response_template:
|
||||
stop_words.append(
|
||||
self.response_template.split("{{ .Response }}")[1].strip()
|
||||
self.response_template.split("{{ .Response }}")[1].strip(),
|
||||
)
|
||||
|
||||
self.agent_executor = CrewAgentExecutor(
|
||||
@@ -515,10 +514,9 @@ class Agent(BaseAgent):
|
||||
callbacks=[TokenCalcHandler(self._token_process)],
|
||||
)
|
||||
|
||||
def get_delegation_tools(self, agents: List[BaseAgent]):
|
||||
def get_delegation_tools(self, agents: list[BaseAgent]):
|
||||
agent_tools = AgentTools(agents=agents)
|
||||
tools = agent_tools.tools()
|
||||
return tools
|
||||
return agent_tools.tools()
|
||||
|
||||
def get_multimodal_tools(self) -> Sequence[BaseTool]:
|
||||
from crewai.tools.agent_tools.add_image_tool import AddImageTool
|
||||
@@ -534,7 +532,7 @@ class Agent(BaseAgent):
|
||||
return [CodeInterpreterTool(unsafe_mode=unsafe_mode)]
|
||||
except ModuleNotFoundError:
|
||||
self._logger.log(
|
||||
"info", "Coding tools not available. Install crewai_tools. "
|
||||
"info", "Coding tools not available. Install crewai_tools. ",
|
||||
)
|
||||
|
||||
def get_output_converter(self, llm, text, model, instructions):
|
||||
@@ -566,7 +564,7 @@ class Agent(BaseAgent):
|
||||
)
|
||||
return task_prompt
|
||||
|
||||
def _render_text_description(self, tools: List[Any]) -> str:
|
||||
def _render_text_description(self, tools: list[Any]) -> str:
|
||||
"""Render the tool name and description in plain text.
|
||||
|
||||
Output will be in the format of:
|
||||
@@ -576,48 +574,48 @@ class Agent(BaseAgent):
|
||||
search: This tool is used for search
|
||||
calculator: This tool is used for math
|
||||
"""
|
||||
description = "\n".join(
|
||||
return "\n".join(
|
||||
[
|
||||
f"Tool name: {tool.name}\nTool description:\n{tool.description}"
|
||||
for tool in tools
|
||||
]
|
||||
],
|
||||
)
|
||||
|
||||
return description
|
||||
|
||||
def _validate_docker_installation(self) -> None:
|
||||
"""Check if Docker is installed and running."""
|
||||
if not shutil.which("docker"):
|
||||
msg = f"Docker is not installed. Please install Docker to use code execution with agent: {self.role}"
|
||||
raise RuntimeError(
|
||||
f"Docker is not installed. Please install Docker to use code execution with agent: {self.role}"
|
||||
msg,
|
||||
)
|
||||
|
||||
try:
|
||||
subprocess.run(
|
||||
["docker", "info"],
|
||||
check=True,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE,
|
||||
capture_output=True,
|
||||
)
|
||||
except subprocess.CalledProcessError:
|
||||
msg = f"Docker is not running. Please start Docker to use code execution with agent: {self.role}"
|
||||
raise RuntimeError(
|
||||
f"Docker is not running. Please start Docker to use code execution with agent: {self.role}"
|
||||
msg,
|
||||
)
|
||||
|
||||
def __repr__(self):
|
||||
def __repr__(self) -> str:
|
||||
return f"Agent(role={self.role}, goal={self.goal}, backstory={self.backstory})"
|
||||
|
||||
@property
|
||||
def fingerprint(self) -> Fingerprint:
|
||||
"""
|
||||
Get the agent's fingerprint.
|
||||
"""Get the agent's fingerprint.
|
||||
|
||||
Returns:
|
||||
Fingerprint: The agent's fingerprint
|
||||
|
||||
"""
|
||||
return self.security_config.fingerprint
|
||||
|
||||
def set_fingerprint(self, fingerprint: Fingerprint):
|
||||
def set_fingerprint(self, fingerprint: Fingerprint) -> None:
|
||||
self.security_config.fingerprint = fingerprint
|
||||
|
||||
def _get_knowledge_search_query(self, task_prompt: str) -> str | None:
|
||||
@@ -630,7 +628,7 @@ class Agent(BaseAgent):
|
||||
),
|
||||
)
|
||||
query = self.i18n.slice("knowledge_search_query").format(
|
||||
task_prompt=task_prompt
|
||||
task_prompt=task_prompt,
|
||||
)
|
||||
rewriter_prompt = self.i18n.slice("knowledge_search_query_system_prompt")
|
||||
if not isinstance(self.llm, BaseLLM):
|
||||
@@ -655,7 +653,7 @@ class Agent(BaseAgent):
|
||||
"content": rewriter_prompt,
|
||||
},
|
||||
{"role": "user", "content": query},
|
||||
]
|
||||
],
|
||||
)
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
@@ -677,11 +675,10 @@ class Agent(BaseAgent):
|
||||
|
||||
def kickoff(
|
||||
self,
|
||||
messages: Union[str, List[Dict[str, str]]],
|
||||
response_format: Optional[Type[Any]] = None,
|
||||
messages: str | list[dict[str, str]],
|
||||
response_format: type[Any] | None = None,
|
||||
) -> LiteAgentOutput:
|
||||
"""
|
||||
Execute the agent with the given messages using a LiteAgent instance.
|
||||
"""Execute the agent with the given messages using a LiteAgent instance.
|
||||
|
||||
This method is useful when you want to use the Agent configuration but
|
||||
with the simpler and more direct execution flow of LiteAgent.
|
||||
@@ -694,6 +691,7 @@ class Agent(BaseAgent):
|
||||
|
||||
Returns:
|
||||
LiteAgentOutput: The result of the agent execution.
|
||||
|
||||
"""
|
||||
lite_agent = LiteAgent(
|
||||
role=self.role,
|
||||
@@ -714,11 +712,10 @@ class Agent(BaseAgent):
|
||||
|
||||
async def kickoff_async(
|
||||
self,
|
||||
messages: Union[str, List[Dict[str, str]]],
|
||||
response_format: Optional[Type[Any]] = None,
|
||||
messages: str | list[dict[str, str]],
|
||||
response_format: type[Any] | None = None,
|
||||
) -> LiteAgentOutput:
|
||||
"""
|
||||
Execute the agent asynchronously with the given messages using a LiteAgent instance.
|
||||
"""Execute the agent asynchronously with the given messages using a LiteAgent instance.
|
||||
|
||||
This is the async version of the kickoff method.
|
||||
|
||||
@@ -730,6 +727,7 @@ class Agent(BaseAgent):
|
||||
|
||||
Returns:
|
||||
LiteAgentOutput: The result of the agent execution.
|
||||
|
||||
"""
|
||||
lite_agent = LiteAgent(
|
||||
role=self.role,
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any
|
||||
|
||||
from pydantic import PrivateAttr
|
||||
|
||||
@@ -16,27 +16,27 @@ class BaseAgentAdapter(BaseAgent, ABC):
|
||||
"""
|
||||
|
||||
adapted_structured_output: bool = False
|
||||
_agent_config: Optional[Dict[str, Any]] = PrivateAttr(default=None)
|
||||
_agent_config: dict[str, Any] | None = PrivateAttr(default=None)
|
||||
|
||||
model_config = {"arbitrary_types_allowed": True}
|
||||
|
||||
def __init__(self, agent_config: Optional[Dict[str, Any]] = None, **kwargs: Any):
|
||||
def __init__(self, agent_config: dict[str, Any] | None = None, **kwargs: Any) -> None:
|
||||
super().__init__(adapted_agent=True, **kwargs)
|
||||
self._agent_config = agent_config
|
||||
|
||||
@abstractmethod
|
||||
def configure_tools(self, tools: Optional[List[BaseTool]] = None) -> None:
|
||||
def configure_tools(self, tools: list[BaseTool] | None = None) -> None:
|
||||
"""Configure and adapt tools for the specific agent implementation.
|
||||
|
||||
Args:
|
||||
tools: Optional list of BaseTool instances to be configured
|
||||
|
||||
"""
|
||||
pass
|
||||
|
||||
def configure_structured_output(self, structured_output: Any) -> None:
|
||||
"""Configure the structured output for the specific agent implementation.
|
||||
|
||||
Args:
|
||||
structured_output: The structured output to be configured
|
||||
|
||||
"""
|
||||
pass
|
||||
|
||||
@@ -8,7 +8,7 @@ class BaseConverterAdapter(ABC):
|
||||
converter adapters must implement for converting structured output.
|
||||
"""
|
||||
|
||||
def __init__(self, agent_adapter):
|
||||
def __init__(self, agent_adapter) -> None:
|
||||
self.agent_adapter = agent_adapter
|
||||
|
||||
@abstractmethod
|
||||
@@ -16,14 +16,11 @@ class BaseConverterAdapter(ABC):
|
||||
"""Configure agents to return structured output.
|
||||
Must support json and pydantic output.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def enhance_system_prompt(self, base_prompt: str) -> str:
|
||||
"""Enhance the system prompt with structured output instructions."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def post_process_result(self, result: str) -> str:
|
||||
"""Post-process the result to ensure it matches the expected format: string."""
|
||||
pass
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, List, Optional
|
||||
from typing import Any
|
||||
|
||||
from crewai.tools.base_tool import BaseTool
|
||||
|
||||
@@ -12,23 +12,23 @@ class BaseToolAdapter(ABC):
|
||||
different frameworks and platforms.
|
||||
"""
|
||||
|
||||
original_tools: List[BaseTool]
|
||||
converted_tools: List[Any]
|
||||
original_tools: list[BaseTool]
|
||||
converted_tools: list[Any]
|
||||
|
||||
def __init__(self, tools: Optional[List[BaseTool]] = None):
|
||||
def __init__(self, tools: list[BaseTool] | None = None) -> None:
|
||||
self.original_tools = tools or []
|
||||
self.converted_tools = []
|
||||
|
||||
@abstractmethod
|
||||
def configure_tools(self, tools: List[BaseTool]) -> None:
|
||||
def configure_tools(self, tools: list[BaseTool]) -> None:
|
||||
"""Configure and convert tools for the specific implementation.
|
||||
|
||||
Args:
|
||||
tools: List of BaseTool instances to be configured and converted
|
||||
"""
|
||||
pass
|
||||
|
||||
def tools(self) -> List[Any]:
|
||||
"""
|
||||
|
||||
def tools(self) -> list[Any]:
|
||||
"""Return all converted tools."""
|
||||
return self.converted_tools
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import Any, AsyncIterable, Dict, List, Optional
|
||||
from typing import Any
|
||||
|
||||
from pydantic import Field, PrivateAttr
|
||||
|
||||
@@ -52,16 +52,17 @@ class LangGraphAgentAdapter(BaseAgentAdapter):
|
||||
role: str,
|
||||
goal: str,
|
||||
backstory: str,
|
||||
tools: Optional[List[BaseTool]] = None,
|
||||
tools: list[BaseTool] | None = None,
|
||||
llm: Any = None,
|
||||
max_iterations: int = 10,
|
||||
agent_config: Optional[Dict[str, Any]] = None,
|
||||
agent_config: dict[str, Any] | None = None,
|
||||
**kwargs,
|
||||
):
|
||||
) -> None:
|
||||
"""Initialize the LangGraph agent adapter."""
|
||||
if not LANGGRAPH_AVAILABLE:
|
||||
msg = "LangGraph Agent Dependencies are not installed. Please install it using `uv add langchain-core langgraph`"
|
||||
raise ImportError(
|
||||
"LangGraph Agent Dependencies are not installed. Please install it using `uv add langchain-core langgraph`"
|
||||
msg,
|
||||
)
|
||||
super().__init__(
|
||||
role=role,
|
||||
@@ -82,7 +83,7 @@ class LangGraphAgentAdapter(BaseAgentAdapter):
|
||||
try:
|
||||
self._memory = MemorySaver()
|
||||
|
||||
converted_tools: List[Any] = self._tool_adapter.tools()
|
||||
converted_tools: list[Any] = self._tool_adapter.tools()
|
||||
if self._agent_config:
|
||||
self._graph = create_react_agent(
|
||||
model=self.llm,
|
||||
@@ -101,18 +102,18 @@ class LangGraphAgentAdapter(BaseAgentAdapter):
|
||||
|
||||
except ImportError as e:
|
||||
self._logger.log(
|
||||
"error", f"Failed to import LangGraph dependencies: {str(e)}"
|
||||
"error", f"Failed to import LangGraph dependencies: {e!s}",
|
||||
)
|
||||
raise
|
||||
except Exception as e:
|
||||
self._logger.log("error", f"Error setting up LangGraph agent: {str(e)}")
|
||||
self._logger.log("error", f"Error setting up LangGraph agent: {e!s}")
|
||||
raise
|
||||
|
||||
def _build_system_prompt(self) -> str:
|
||||
"""Build a system prompt for the LangGraph agent."""
|
||||
base_prompt = f"""
|
||||
You are {self.role}.
|
||||
|
||||
|
||||
Your goal is: {self.goal}
|
||||
|
||||
Your backstory: {self.backstory}
|
||||
@@ -124,8 +125,8 @@ class LangGraphAgentAdapter(BaseAgentAdapter):
|
||||
def execute_task(
|
||||
self,
|
||||
task: Any,
|
||||
context: Optional[str] = None,
|
||||
tools: Optional[List[BaseTool]] = None,
|
||||
context: str | None = None,
|
||||
tools: list[BaseTool] | None = None,
|
||||
) -> str:
|
||||
"""Execute a task using the LangGraph workflow."""
|
||||
self.create_agent_executor(tools)
|
||||
@@ -137,7 +138,7 @@ class LangGraphAgentAdapter(BaseAgentAdapter):
|
||||
|
||||
if context:
|
||||
task_prompt = self.i18n.slice("task_with_context").format(
|
||||
task=task_prompt, context=context
|
||||
task=task_prompt, context=context,
|
||||
)
|
||||
|
||||
crewai_event_bus.emit(
|
||||
@@ -159,7 +160,7 @@ class LangGraphAgentAdapter(BaseAgentAdapter):
|
||||
"messages": [
|
||||
("system", self._build_system_prompt()),
|
||||
("user", task_prompt),
|
||||
]
|
||||
],
|
||||
},
|
||||
config,
|
||||
)
|
||||
@@ -180,14 +181,14 @@ class LangGraphAgentAdapter(BaseAgentAdapter):
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=AgentExecutionCompletedEvent(
|
||||
agent=self, task=task, output=final_answer
|
||||
agent=self, task=task, output=final_answer,
|
||||
),
|
||||
)
|
||||
|
||||
return final_answer
|
||||
|
||||
except Exception as e:
|
||||
self._logger.log("error", f"Error executing LangGraph task: {str(e)}")
|
||||
self._logger.log("error", f"Error executing LangGraph task: {e!s}")
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=AgentExecutionErrorEvent(
|
||||
@@ -198,11 +199,11 @@ class LangGraphAgentAdapter(BaseAgentAdapter):
|
||||
)
|
||||
raise
|
||||
|
||||
def create_agent_executor(self, tools: Optional[List[BaseTool]] = None) -> None:
|
||||
def create_agent_executor(self, tools: list[BaseTool] | None = None) -> None:
|
||||
"""Configure the LangGraph agent for execution."""
|
||||
self.configure_tools(tools)
|
||||
|
||||
def configure_tools(self, tools: Optional[List[BaseTool]] = None) -> None:
|
||||
def configure_tools(self, tools: list[BaseTool] | None = None) -> None:
|
||||
"""Configure tools for the LangGraph agent."""
|
||||
if tools:
|
||||
all_tools = list(self.tools or []) + list(tools or [])
|
||||
@@ -210,13 +211,13 @@ class LangGraphAgentAdapter(BaseAgentAdapter):
|
||||
available_tools = self._tool_adapter.tools()
|
||||
self._graph.tools = available_tools
|
||||
|
||||
def get_delegation_tools(self, agents: List[BaseAgent]) -> List[BaseTool]:
|
||||
def get_delegation_tools(self, agents: list[BaseAgent]) -> list[BaseTool]:
|
||||
"""Implement delegation tools support for LangGraph."""
|
||||
agent_tools = AgentTools(agents=agents)
|
||||
return agent_tools.tools()
|
||||
|
||||
def get_output_converter(
|
||||
self, llm: Any, text: str, model: Any, instructions: str
|
||||
self, llm: Any, text: str, model: Any, instructions: str,
|
||||
) -> Any:
|
||||
"""Convert output format if needed."""
|
||||
return Converter(llm=llm, text=text, model=model, instructions=instructions)
|
||||
|
||||
@@ -1,29 +1,25 @@
|
||||
import inspect
|
||||
from typing import Any, List, Optional
|
||||
from typing import Any
|
||||
|
||||
from crewai.agents.agent_adapters.base_tool_adapter import BaseToolAdapter
|
||||
from crewai.tools.base_tool import BaseTool
|
||||
|
||||
|
||||
class LangGraphToolAdapter(BaseToolAdapter):
|
||||
"""Adapts CrewAI tools to LangGraph agent tool compatible format"""
|
||||
"""Adapts CrewAI tools to LangGraph agent tool compatible format."""
|
||||
|
||||
def __init__(self, tools: Optional[List[BaseTool]] = None):
|
||||
def __init__(self, tools: list[BaseTool] | None = None) -> None:
|
||||
self.original_tools = tools or []
|
||||
self.converted_tools = []
|
||||
|
||||
def configure_tools(self, tools: List[BaseTool]) -> None:
|
||||
"""
|
||||
Configure and convert CrewAI tools to LangGraph-compatible format.
|
||||
def configure_tools(self, tools: list[BaseTool]) -> None:
|
||||
"""Configure and convert CrewAI tools to LangGraph-compatible format.
|
||||
LangGraph expects tools in langchain_core.tools format.
|
||||
"""
|
||||
from langchain_core.tools import BaseTool, StructuredTool
|
||||
|
||||
converted_tools = []
|
||||
if self.original_tools:
|
||||
all_tools = tools + self.original_tools
|
||||
else:
|
||||
all_tools = tools
|
||||
all_tools = tools + self.original_tools if self.original_tools else tools
|
||||
for tool in all_tools:
|
||||
if isinstance(tool, BaseTool):
|
||||
converted_tools.append(tool)
|
||||
@@ -57,5 +53,5 @@ class LangGraphToolAdapter(BaseToolAdapter):
|
||||
|
||||
self.converted_tools = converted_tools
|
||||
|
||||
def tools(self) -> List[Any]:
|
||||
def tools(self) -> list[Any]:
|
||||
return self.converted_tools or []
|
||||
|
||||
@@ -5,10 +5,10 @@ from crewai.utilities.converter import generate_model_description
|
||||
|
||||
|
||||
class LangGraphConverterAdapter(BaseConverterAdapter):
|
||||
"""Adapter for handling structured output conversion in LangGraph agents"""
|
||||
"""Adapter for handling structured output conversion in LangGraph agents."""
|
||||
|
||||
def __init__(self, agent_adapter):
|
||||
"""Initialize the converter adapter with a reference to the agent adapter"""
|
||||
def __init__(self, agent_adapter) -> None:
|
||||
"""Initialize the converter adapter with a reference to the agent adapter."""
|
||||
self.agent_adapter = agent_adapter
|
||||
self._output_format = None
|
||||
self._schema = None
|
||||
@@ -32,7 +32,7 @@ class LangGraphConverterAdapter(BaseConverterAdapter):
|
||||
self._system_prompt_appendix = self._generate_system_prompt_appendix()
|
||||
|
||||
def _generate_system_prompt_appendix(self) -> str:
|
||||
"""Generate an appendix for the system prompt to enforce structured output"""
|
||||
"""Generate an appendix for the system prompt to enforce structured output."""
|
||||
if not self._output_format or not self._schema:
|
||||
return ""
|
||||
|
||||
@@ -41,19 +41,19 @@ Important: Your final answer MUST be provided in the following structured format
|
||||
|
||||
{self._schema}
|
||||
|
||||
DO NOT include any markdown code blocks, backticks, or other formatting around your response.
|
||||
DO NOT include any markdown code blocks, backticks, or other formatting around your response.
|
||||
The output should be raw JSON that exactly matches the specified schema.
|
||||
"""
|
||||
|
||||
def enhance_system_prompt(self, original_prompt: str) -> str:
|
||||
"""Add structured output instructions to the system prompt if needed"""
|
||||
"""Add structured output instructions to the system prompt if needed."""
|
||||
if not self._system_prompt_appendix:
|
||||
return original_prompt
|
||||
|
||||
return f"{original_prompt}\n{self._system_prompt_appendix}"
|
||||
|
||||
def post_process_result(self, result: str) -> str:
|
||||
"""Post-process the result to ensure it matches the expected format"""
|
||||
"""Post-process the result to ensure it matches the expected format."""
|
||||
if not self._output_format:
|
||||
return result
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import Any, List, Optional
|
||||
from typing import Any
|
||||
|
||||
from pydantic import Field, PrivateAttr
|
||||
|
||||
@@ -29,13 +29,13 @@ except ImportError:
|
||||
|
||||
|
||||
class OpenAIAgentAdapter(BaseAgentAdapter):
|
||||
"""Adapter for OpenAI Assistants"""
|
||||
"""Adapter for OpenAI Assistants."""
|
||||
|
||||
model_config = {"arbitrary_types_allowed": True}
|
||||
|
||||
_openai_agent: "OpenAIAgent" = PrivateAttr()
|
||||
_logger: Logger = PrivateAttr(default_factory=lambda: Logger())
|
||||
_active_thread: Optional[str] = PrivateAttr(default=None)
|
||||
_active_thread: str | None = PrivateAttr(default=None)
|
||||
function_calling_llm: Any = Field(default=None)
|
||||
step_callback: Any = Field(default=None)
|
||||
_tool_adapter: "OpenAIAgentToolAdapter" = PrivateAttr()
|
||||
@@ -44,35 +44,35 @@ class OpenAIAgentAdapter(BaseAgentAdapter):
|
||||
def __init__(
|
||||
self,
|
||||
model: str = "gpt-4o-mini",
|
||||
tools: Optional[List[BaseTool]] = None,
|
||||
agent_config: Optional[dict] = None,
|
||||
tools: list[BaseTool] | None = None,
|
||||
agent_config: dict | None = None,
|
||||
**kwargs,
|
||||
):
|
||||
) -> None:
|
||||
if not OPENAI_AVAILABLE:
|
||||
msg = "OpenAI Agent Dependencies are not installed. Please install it using `uv add openai-agents`"
|
||||
raise ImportError(
|
||||
"OpenAI Agent Dependencies are not installed. Please install it using `uv add openai-agents`"
|
||||
msg,
|
||||
)
|
||||
else:
|
||||
role = kwargs.pop("role", None)
|
||||
goal = kwargs.pop("goal", None)
|
||||
backstory = kwargs.pop("backstory", None)
|
||||
super().__init__(
|
||||
role=role,
|
||||
goal=goal,
|
||||
backstory=backstory,
|
||||
tools=tools,
|
||||
agent_config=agent_config,
|
||||
**kwargs,
|
||||
)
|
||||
self._tool_adapter = OpenAIAgentToolAdapter(tools=tools)
|
||||
self.llm = model
|
||||
self._converter_adapter = OpenAIConverterAdapter(self)
|
||||
role = kwargs.pop("role", None)
|
||||
goal = kwargs.pop("goal", None)
|
||||
backstory = kwargs.pop("backstory", None)
|
||||
super().__init__(
|
||||
role=role,
|
||||
goal=goal,
|
||||
backstory=backstory,
|
||||
tools=tools,
|
||||
agent_config=agent_config,
|
||||
**kwargs,
|
||||
)
|
||||
self._tool_adapter = OpenAIAgentToolAdapter(tools=tools)
|
||||
self.llm = model
|
||||
self._converter_adapter = OpenAIConverterAdapter(self)
|
||||
|
||||
def _build_system_prompt(self) -> str:
|
||||
"""Build a system prompt for the OpenAI agent."""
|
||||
base_prompt = f"""
|
||||
You are {self.role}.
|
||||
|
||||
|
||||
Your goal is: {self.goal}
|
||||
|
||||
Your backstory: {self.backstory}
|
||||
@@ -84,10 +84,10 @@ class OpenAIAgentAdapter(BaseAgentAdapter):
|
||||
def execute_task(
|
||||
self,
|
||||
task: Any,
|
||||
context: Optional[str] = None,
|
||||
tools: Optional[List[BaseTool]] = None,
|
||||
context: str | None = None,
|
||||
tools: list[BaseTool] | None = None,
|
||||
) -> str:
|
||||
"""Execute a task using the OpenAI Assistant"""
|
||||
"""Execute a task using the OpenAI Assistant."""
|
||||
self._converter_adapter.configure_structured_output(task)
|
||||
self.create_agent_executor(tools)
|
||||
|
||||
@@ -98,7 +98,7 @@ class OpenAIAgentAdapter(BaseAgentAdapter):
|
||||
task_prompt = task.prompt()
|
||||
if context:
|
||||
task_prompt = self.i18n.slice("task_with_context").format(
|
||||
task=task_prompt, context=context
|
||||
task=task_prompt, context=context,
|
||||
)
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
@@ -114,13 +114,13 @@ class OpenAIAgentAdapter(BaseAgentAdapter):
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=AgentExecutionCompletedEvent(
|
||||
agent=self, task=task, output=final_answer
|
||||
agent=self, task=task, output=final_answer,
|
||||
),
|
||||
)
|
||||
return final_answer
|
||||
|
||||
except Exception as e:
|
||||
self._logger.log("error", f"Error executing OpenAI task: {str(e)}")
|
||||
self._logger.log("error", f"Error executing OpenAI task: {e!s}")
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=AgentExecutionErrorEvent(
|
||||
@@ -131,9 +131,8 @@ class OpenAIAgentAdapter(BaseAgentAdapter):
|
||||
)
|
||||
raise
|
||||
|
||||
def create_agent_executor(self, tools: Optional[List[BaseTool]] = None) -> None:
|
||||
"""
|
||||
Configure the OpenAI agent for execution.
|
||||
def create_agent_executor(self, tools: list[BaseTool] | None = None) -> None:
|
||||
"""Configure the OpenAI agent for execution.
|
||||
While OpenAI handles execution differently through Runner,
|
||||
we can use this method to set up tools and configurations.
|
||||
"""
|
||||
@@ -152,27 +151,27 @@ class OpenAIAgentAdapter(BaseAgentAdapter):
|
||||
|
||||
self.agent_executor = Runner
|
||||
|
||||
def configure_tools(self, tools: Optional[List[BaseTool]] = None) -> None:
|
||||
"""Configure tools for the OpenAI Assistant"""
|
||||
def configure_tools(self, tools: list[BaseTool] | None = None) -> None:
|
||||
"""Configure tools for the OpenAI Assistant."""
|
||||
if tools:
|
||||
self._tool_adapter.configure_tools(tools)
|
||||
if self._tool_adapter.converted_tools:
|
||||
self._openai_agent.tools = self._tool_adapter.converted_tools
|
||||
|
||||
def handle_execution_result(self, result: Any) -> str:
|
||||
"""Process OpenAI Assistant execution result converting any structured output to a string"""
|
||||
"""Process OpenAI Assistant execution result converting any structured output to a string."""
|
||||
return self._converter_adapter.post_process_result(result.final_output)
|
||||
|
||||
def get_delegation_tools(self, agents: List[BaseAgent]) -> List[BaseTool]:
|
||||
"""Implement delegation tools support"""
|
||||
def get_delegation_tools(self, agents: list[BaseAgent]) -> list[BaseTool]:
|
||||
"""Implement delegation tools support."""
|
||||
agent_tools = AgentTools(agents=agents)
|
||||
tools = agent_tools.tools()
|
||||
return tools
|
||||
return agent_tools.tools()
|
||||
|
||||
def configure_structured_output(self, task) -> None:
|
||||
"""Configure the structured output for the specific agent implementation.
|
||||
|
||||
Args:
|
||||
structured_output: The structured output to be configured
|
||||
|
||||
"""
|
||||
self._converter_adapter.configure_structured_output(task)
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import inspect
|
||||
from typing import Any, List, Optional
|
||||
from typing import Any
|
||||
|
||||
from agents import FunctionTool, Tool
|
||||
|
||||
@@ -8,42 +8,36 @@ from crewai.tools import BaseTool
|
||||
|
||||
|
||||
class OpenAIAgentToolAdapter(BaseToolAdapter):
|
||||
"""Adapter for OpenAI Assistant tools"""
|
||||
"""Adapter for OpenAI Assistant tools."""
|
||||
|
||||
def __init__(self, tools: Optional[List[BaseTool]] = None):
|
||||
def __init__(self, tools: list[BaseTool] | None = None) -> None:
|
||||
self.original_tools = tools or []
|
||||
|
||||
def configure_tools(self, tools: List[BaseTool]) -> None:
|
||||
"""Configure tools for the OpenAI Assistant"""
|
||||
if self.original_tools:
|
||||
all_tools = tools + self.original_tools
|
||||
else:
|
||||
all_tools = tools
|
||||
def configure_tools(self, tools: list[BaseTool]) -> None:
|
||||
"""Configure tools for the OpenAI Assistant."""
|
||||
all_tools = tools + self.original_tools if self.original_tools else tools
|
||||
if all_tools:
|
||||
self.converted_tools = self._convert_tools_to_openai_format(all_tools)
|
||||
|
||||
def _convert_tools_to_openai_format(
|
||||
self, tools: Optional[List[BaseTool]]
|
||||
) -> List[Tool]:
|
||||
"""Convert CrewAI tools to OpenAI Assistant tool format"""
|
||||
self, tools: list[BaseTool] | None,
|
||||
) -> list[Tool]:
|
||||
"""Convert CrewAI tools to OpenAI Assistant tool format."""
|
||||
if not tools:
|
||||
return []
|
||||
|
||||
def sanitize_tool_name(name: str) -> str:
|
||||
"""Convert tool name to match OpenAI's required pattern"""
|
||||
"""Convert tool name to match OpenAI's required pattern."""
|
||||
import re
|
||||
|
||||
sanitized = re.sub(r"[^a-zA-Z0-9_-]", "_", name).lower()
|
||||
return sanitized
|
||||
return re.sub(r"[^a-zA-Z0-9_-]", "_", name).lower()
|
||||
|
||||
def create_tool_wrapper(tool: BaseTool):
|
||||
"""Create a wrapper function that handles the OpenAI function tool interface"""
|
||||
"""Create a wrapper function that handles the OpenAI function tool interface."""
|
||||
|
||||
async def wrapper(context_wrapper: Any, arguments: Any) -> Any:
|
||||
# Get the parameter name from the schema
|
||||
param_name = list(
|
||||
tool.args_schema.model_json_schema()["properties"].keys()
|
||||
)[0]
|
||||
param_name = next(iter(tool.args_schema.model_json_schema()["properties"].keys()))
|
||||
|
||||
# Handle different argument types
|
||||
if isinstance(arguments, dict):
|
||||
|
||||
@@ -7,8 +7,7 @@ from crewai.utilities.i18n import I18N
|
||||
|
||||
|
||||
class OpenAIConverterAdapter(BaseConverterAdapter):
|
||||
"""
|
||||
Adapter for handling structured output conversion in OpenAI agents.
|
||||
"""Adapter for handling structured output conversion in OpenAI agents.
|
||||
|
||||
This adapter enhances the OpenAI agent to handle structured output formats
|
||||
and post-processes the results when needed.
|
||||
@@ -17,21 +16,22 @@ class OpenAIConverterAdapter(BaseConverterAdapter):
|
||||
_output_format: The expected output format (json, pydantic, or None)
|
||||
_schema: The schema description for the expected output
|
||||
_output_model: The Pydantic model for the output
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, agent_adapter):
|
||||
"""Initialize the converter adapter with a reference to the agent adapter"""
|
||||
def __init__(self, agent_adapter) -> None:
|
||||
"""Initialize the converter adapter with a reference to the agent adapter."""
|
||||
self.agent_adapter = agent_adapter
|
||||
self._output_format = None
|
||||
self._schema = None
|
||||
self._output_model = None
|
||||
|
||||
def configure_structured_output(self, task) -> None:
|
||||
"""
|
||||
Configure the structured output for OpenAI agent based on task requirements.
|
||||
"""Configure the structured output for OpenAI agent based on task requirements.
|
||||
|
||||
Args:
|
||||
task: The task containing output format requirements
|
||||
|
||||
"""
|
||||
# Reset configuration
|
||||
self._output_format = None
|
||||
@@ -55,14 +55,14 @@ class OpenAIConverterAdapter(BaseConverterAdapter):
|
||||
self._output_model = task.output_pydantic
|
||||
|
||||
def enhance_system_prompt(self, base_prompt: str) -> str:
|
||||
"""
|
||||
Enhance the base system prompt with structured output requirements if needed.
|
||||
"""Enhance the base system prompt with structured output requirements if needed.
|
||||
|
||||
Args:
|
||||
base_prompt: The original system prompt
|
||||
|
||||
Returns:
|
||||
Enhanced system prompt with output format instructions if needed
|
||||
|
||||
"""
|
||||
if not self._output_format:
|
||||
return base_prompt
|
||||
@@ -76,8 +76,7 @@ class OpenAIConverterAdapter(BaseConverterAdapter):
|
||||
return f"{base_prompt}\n\n{output_schema}"
|
||||
|
||||
def post_process_result(self, result: str) -> str:
|
||||
"""
|
||||
Post-process the result to ensure it matches the expected format.
|
||||
"""Post-process the result to ensure it matches the expected format.
|
||||
|
||||
This method attempts to extract valid JSON from the result if necessary.
|
||||
|
||||
@@ -86,6 +85,7 @@ class OpenAIConverterAdapter(BaseConverterAdapter):
|
||||
|
||||
Returns:
|
||||
Processed result conforming to the expected output format
|
||||
|
||||
"""
|
||||
if not self._output_format:
|
||||
return result
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
import uuid
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Callable
|
||||
from copy import copy as shallow_copy
|
||||
from hashlib import md5
|
||||
from typing import Any, Callable, Dict, List, Optional, TypeVar
|
||||
from typing import Any, TypeVar
|
||||
|
||||
from pydantic import (
|
||||
UUID4,
|
||||
@@ -14,6 +15,7 @@ from pydantic import (
|
||||
model_validator,
|
||||
)
|
||||
from pydantic_core import PydanticCustomError
|
||||
from typing_extensions import Self
|
||||
|
||||
from crewai.agents.agent_builder.utilities.base_token_process import TokenProcess
|
||||
from crewai.agents.cache.cache_handler import CacheHandler
|
||||
@@ -25,7 +27,6 @@ from crewai.security.security_config import SecurityConfig
|
||||
from crewai.tools.base_tool import BaseTool, Tool
|
||||
from crewai.utilities import I18N, Logger, RPMController
|
||||
from crewai.utilities.config import process_config
|
||||
from crewai.utilities.converter import Converter
|
||||
from crewai.utilities.string_utils import interpolate_only
|
||||
|
||||
T = TypeVar("T", bound="BaseAgent")
|
||||
@@ -77,30 +78,31 @@ class BaseAgent(ABC, BaseModel):
|
||||
Set the rpm controller for the agent.
|
||||
set_private_attrs() -> "BaseAgent":
|
||||
Set private attributes.
|
||||
|
||||
"""
|
||||
|
||||
__hash__ = object.__hash__ # type: ignore
|
||||
_logger: Logger = PrivateAttr(default_factory=lambda: Logger(verbose=False))
|
||||
_rpm_controller: Optional[RPMController] = PrivateAttr(default=None)
|
||||
_rpm_controller: RPMController | None = PrivateAttr(default=None)
|
||||
_request_within_rpm_limit: Any = PrivateAttr(default=None)
|
||||
_original_role: Optional[str] = PrivateAttr(default=None)
|
||||
_original_goal: Optional[str] = PrivateAttr(default=None)
|
||||
_original_backstory: Optional[str] = PrivateAttr(default=None)
|
||||
_original_role: str | None = PrivateAttr(default=None)
|
||||
_original_goal: str | None = PrivateAttr(default=None)
|
||||
_original_backstory: str | None = PrivateAttr(default=None)
|
||||
_token_process: TokenProcess = PrivateAttr(default_factory=TokenProcess)
|
||||
id: UUID4 = Field(default_factory=uuid.uuid4, frozen=True)
|
||||
role: str = Field(description="Role of the agent")
|
||||
goal: str = Field(description="Objective of the agent")
|
||||
backstory: str = Field(description="Backstory of the agent")
|
||||
config: Optional[Dict[str, Any]] = Field(
|
||||
description="Configuration for the agent", default=None, exclude=True
|
||||
config: dict[str, Any] | None = Field(
|
||||
description="Configuration for the agent", default=None, exclude=True,
|
||||
)
|
||||
cache: bool = Field(
|
||||
default=True, description="Whether the agent should use a cache for tool usage."
|
||||
default=True, description="Whether the agent should use a cache for tool usage.",
|
||||
)
|
||||
verbose: bool = Field(
|
||||
default=False, description="Verbose mode for the Agent Execution"
|
||||
default=False, description="Verbose mode for the Agent Execution",
|
||||
)
|
||||
max_rpm: Optional[int] = Field(
|
||||
max_rpm: int | None = Field(
|
||||
default=None,
|
||||
description="Maximum number of requests per minute for the agent execution to be respected.",
|
||||
)
|
||||
@@ -108,41 +110,41 @@ class BaseAgent(ABC, BaseModel):
|
||||
default=False,
|
||||
description="Enable agent to delegate and ask questions among each other.",
|
||||
)
|
||||
tools: Optional[List[BaseTool]] = Field(
|
||||
default_factory=list, description="Tools at agents' disposal"
|
||||
tools: list[BaseTool] | None = Field(
|
||||
default_factory=list, description="Tools at agents' disposal",
|
||||
)
|
||||
max_iter: int = Field(
|
||||
default=25, description="Maximum iterations for an agent to execute a task"
|
||||
default=25, description="Maximum iterations for an agent to execute a task",
|
||||
)
|
||||
agent_executor: InstanceOf = Field(
|
||||
default=None, description="An instance of the CrewAgentExecutor class."
|
||||
default=None, description="An instance of the CrewAgentExecutor class.",
|
||||
)
|
||||
llm: Any = Field(
|
||||
default=None, description="Language model that will run the agent."
|
||||
default=None, description="Language model that will run the agent.",
|
||||
)
|
||||
crew: Any = Field(default=None, description="Crew to which the agent belongs.")
|
||||
i18n: I18N = Field(default=I18N(), description="Internationalization settings.")
|
||||
cache_handler: Optional[InstanceOf[CacheHandler]] = Field(
|
||||
default=None, description="An instance of the CacheHandler class."
|
||||
cache_handler: InstanceOf[CacheHandler] | None = Field(
|
||||
default=None, description="An instance of the CacheHandler class.",
|
||||
)
|
||||
tools_handler: InstanceOf[ToolsHandler] = Field(
|
||||
default_factory=ToolsHandler,
|
||||
description="An instance of the ToolsHandler class.",
|
||||
)
|
||||
tools_results: List[Dict[str, Any]] = Field(
|
||||
default=[], description="Results of the tools used by the agent."
|
||||
tools_results: list[dict[str, Any]] = Field(
|
||||
default=[], description="Results of the tools used by the agent.",
|
||||
)
|
||||
max_tokens: Optional[int] = Field(
|
||||
default=None, description="Maximum number of tokens for the agent's execution."
|
||||
max_tokens: int | None = Field(
|
||||
default=None, description="Maximum number of tokens for the agent's execution.",
|
||||
)
|
||||
knowledge: Optional[Knowledge] = Field(
|
||||
default=None, description="Knowledge for the agent."
|
||||
knowledge: Knowledge | None = Field(
|
||||
default=None, description="Knowledge for the agent.",
|
||||
)
|
||||
knowledge_sources: Optional[List[BaseKnowledgeSource]] = Field(
|
||||
knowledge_sources: list[BaseKnowledgeSource] | None = Field(
|
||||
default=None,
|
||||
description="Knowledge sources for the agent.",
|
||||
)
|
||||
knowledge_storage: Optional[Any] = Field(
|
||||
knowledge_storage: Any | None = Field(
|
||||
default=None,
|
||||
description="Custom knowledge storage for the agent.",
|
||||
)
|
||||
@@ -150,13 +152,13 @@ class BaseAgent(ABC, BaseModel):
|
||||
default_factory=SecurityConfig,
|
||||
description="Security configuration for the agent, including fingerprinting.",
|
||||
)
|
||||
callbacks: List[Callable] = Field(
|
||||
default=[], description="Callbacks to be used for the agent"
|
||||
callbacks: list[Callable] = Field(
|
||||
default=[], description="Callbacks to be used for the agent",
|
||||
)
|
||||
adapted_agent: bool = Field(
|
||||
default=False, description="Whether the agent is adapted"
|
||||
default=False, description="Whether the agent is adapted",
|
||||
)
|
||||
knowledge_config: Optional[KnowledgeConfig] = Field(
|
||||
knowledge_config: KnowledgeConfig | None = Field(
|
||||
default=None,
|
||||
description="Knowledge configuration for the agent such as limits and threshold",
|
||||
)
|
||||
@@ -168,7 +170,7 @@ class BaseAgent(ABC, BaseModel):
|
||||
|
||||
@field_validator("tools")
|
||||
@classmethod
|
||||
def validate_tools(cls, tools: List[Any]) -> List[BaseTool]:
|
||||
def validate_tools(cls, tools: list[Any]) -> list[BaseTool]:
|
||||
"""Validate and process the tools provided to the agent.
|
||||
|
||||
This method ensures that each tool is either an instance of BaseTool
|
||||
@@ -188,11 +190,14 @@ class BaseAgent(ABC, BaseModel):
|
||||
# Tool has the required attributes, create a Tool instance
|
||||
processed_tools.append(Tool.from_langchain(tool))
|
||||
else:
|
||||
raise ValueError(
|
||||
msg = (
|
||||
f"Invalid tool type: {type(tool)}. "
|
||||
"Tool must be an instance of BaseTool or "
|
||||
"an object with 'name', 'func', and 'description' attributes."
|
||||
)
|
||||
raise ValueError(
|
||||
msg,
|
||||
)
|
||||
return processed_tools
|
||||
|
||||
@model_validator(mode="after")
|
||||
@@ -200,15 +205,16 @@ class BaseAgent(ABC, BaseModel):
|
||||
# Validate required fields
|
||||
for field in ["role", "goal", "backstory"]:
|
||||
if getattr(self, field) is None:
|
||||
msg = f"{field} must be provided either directly or through config"
|
||||
raise ValueError(
|
||||
f"{field} must be provided either directly or through config"
|
||||
msg,
|
||||
)
|
||||
|
||||
# Set private attributes
|
||||
self._logger = Logger(verbose=self.verbose)
|
||||
if self.max_rpm and not self._rpm_controller:
|
||||
self._rpm_controller = RPMController(
|
||||
max_rpm=self.max_rpm, logger=self._logger
|
||||
max_rpm=self.max_rpm, logger=self._logger,
|
||||
)
|
||||
if not self._token_process:
|
||||
self._token_process = TokenProcess()
|
||||
@@ -221,10 +227,11 @@ class BaseAgent(ABC, BaseModel):
|
||||
|
||||
@field_validator("id", mode="before")
|
||||
@classmethod
|
||||
def _deny_user_set_id(cls, v: Optional[UUID4]) -> None:
|
||||
def _deny_user_set_id(cls, v: UUID4 | None) -> None:
|
||||
if v:
|
||||
msg = "may_not_set_field"
|
||||
raise PydanticCustomError(
|
||||
"may_not_set_field", "This field is not to be set by the user.", {}
|
||||
msg, "This field is not to be set by the user.", {},
|
||||
)
|
||||
|
||||
@model_validator(mode="after")
|
||||
@@ -233,7 +240,7 @@ class BaseAgent(ABC, BaseModel):
|
||||
self._logger = Logger(verbose=self.verbose)
|
||||
if self.max_rpm and not self._rpm_controller:
|
||||
self._rpm_controller = RPMController(
|
||||
max_rpm=self.max_rpm, logger=self._logger
|
||||
max_rpm=self.max_rpm, logger=self._logger,
|
||||
)
|
||||
if not self._token_process:
|
||||
self._token_process = TokenProcess()
|
||||
@@ -252,8 +259,8 @@ class BaseAgent(ABC, BaseModel):
|
||||
def execute_task(
|
||||
self,
|
||||
task: Any,
|
||||
context: Optional[str] = None,
|
||||
tools: Optional[List[BaseTool]] = None,
|
||||
context: str | None = None,
|
||||
tools: list[BaseTool] | None = None,
|
||||
) -> str:
|
||||
pass
|
||||
|
||||
@@ -262,11 +269,10 @@ class BaseAgent(ABC, BaseModel):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_delegation_tools(self, agents: List["BaseAgent"]) -> List[BaseTool]:
|
||||
def get_delegation_tools(self, agents: list["BaseAgent"]) -> list[BaseTool]:
|
||||
"""Set the task tools that init BaseAgenTools class."""
|
||||
pass
|
||||
|
||||
def copy(self: T) -> T: # type: ignore # Signature of "copy" incompatible with supertype "BaseModel"
|
||||
def copy(self) -> Self: # type: ignore # Signature of "copy" incompatible with supertype "BaseModel"
|
||||
"""Create a deep copy of the Agent."""
|
||||
exclude = {
|
||||
"id",
|
||||
@@ -309,7 +315,7 @@ class BaseAgent(ABC, BaseModel):
|
||||
|
||||
copied_data = self.model_dump(exclude=exclude)
|
||||
copied_data = {k: v for k, v in copied_data.items() if v is not None}
|
||||
copied_agent = type(self)(
|
||||
return type(self)(
|
||||
**copied_data,
|
||||
llm=existing_llm,
|
||||
tools=self.tools,
|
||||
@@ -318,9 +324,8 @@ class BaseAgent(ABC, BaseModel):
|
||||
knowledge_storage=copied_knowledge_storage,
|
||||
)
|
||||
|
||||
return copied_agent
|
||||
|
||||
def interpolate_inputs(self, inputs: Dict[str, Any]) -> None:
|
||||
def interpolate_inputs(self, inputs: dict[str, Any]) -> None:
|
||||
"""Interpolate inputs into the agent description and backstory."""
|
||||
if self._original_role is None:
|
||||
self._original_role = self.role
|
||||
@@ -331,13 +336,13 @@ class BaseAgent(ABC, BaseModel):
|
||||
|
||||
if inputs:
|
||||
self.role = interpolate_only(
|
||||
input_string=self._original_role, inputs=inputs
|
||||
input_string=self._original_role, inputs=inputs,
|
||||
)
|
||||
self.goal = interpolate_only(
|
||||
input_string=self._original_goal, inputs=inputs
|
||||
input_string=self._original_goal, inputs=inputs,
|
||||
)
|
||||
self.backstory = interpolate_only(
|
||||
input_string=self._original_backstory, inputs=inputs
|
||||
input_string=self._original_backstory, inputs=inputs,
|
||||
)
|
||||
|
||||
def set_cache_handler(self, cache_handler: CacheHandler) -> None:
|
||||
@@ -345,6 +350,7 @@ class BaseAgent(ABC, BaseModel):
|
||||
|
||||
Args:
|
||||
cache_handler: An instance of the CacheHandler class.
|
||||
|
||||
"""
|
||||
self.tools_handler = ToolsHandler()
|
||||
if self.cache:
|
||||
@@ -357,10 +363,11 @@ class BaseAgent(ABC, BaseModel):
|
||||
|
||||
Args:
|
||||
rpm_controller: An instance of the RPMController class.
|
||||
|
||||
"""
|
||||
if not self._rpm_controller:
|
||||
self._rpm_controller = rpm_controller
|
||||
self.create_agent_executor()
|
||||
|
||||
def set_knowledge(self, crew_embedder: Optional[Dict[str, Any]] = None):
|
||||
def set_knowledge(self, crew_embedder: dict[str, Any] | None = None) -> None:
|
||||
pass
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import contextlib
|
||||
import time
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
@@ -43,8 +44,7 @@ class CrewAgentExecutorMixin:
|
||||
},
|
||||
agent=self.agent.role,
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"Failed to add to short term memory: {e}")
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def _create_external_memory(self, output) -> None:
|
||||
@@ -56,7 +56,7 @@ class CrewAgentExecutorMixin:
|
||||
and hasattr(self.crew, "_external_memory")
|
||||
and self.crew._external_memory
|
||||
):
|
||||
try:
|
||||
with contextlib.suppress(Exception):
|
||||
self.crew._external_memory.save(
|
||||
value=output.text,
|
||||
metadata={
|
||||
@@ -64,9 +64,6 @@ class CrewAgentExecutorMixin:
|
||||
},
|
||||
agent=self.agent.role,
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"Failed to add to external memory: {e}")
|
||||
pass
|
||||
|
||||
def _create_long_term_memory(self, output) -> None:
|
||||
"""Create and save long-term and entity memory items based on evaluation."""
|
||||
@@ -103,15 +100,13 @@ class CrewAgentExecutorMixin:
|
||||
type=entity.type,
|
||||
description=entity.description,
|
||||
relationships="\n".join(
|
||||
[f"- {r}" for r in entity.relationships]
|
||||
[f"- {r}" for r in entity.relationships],
|
||||
),
|
||||
)
|
||||
self.crew._entity_memory.save(entity_memory)
|
||||
except AttributeError as e:
|
||||
print(f"Missing attributes for long term memory: {e}")
|
||||
except AttributeError:
|
||||
pass
|
||||
except Exception as e:
|
||||
print(f"Failed to add to long term memory: {e}")
|
||||
except Exception:
|
||||
pass
|
||||
elif (
|
||||
self.crew
|
||||
@@ -126,7 +121,7 @@ class CrewAgentExecutorMixin:
|
||||
def _ask_human_input(self, final_answer: str) -> str:
|
||||
"""Prompt human input with mode-appropriate messaging."""
|
||||
self._printer.print(
|
||||
content=f"\033[1m\033[95m ## Final Result:\033[00m \033[92m{final_answer}\033[00m"
|
||||
content=f"\033[1m\033[95m ## Final Result:\033[00m \033[92m{final_answer}\033[00m",
|
||||
)
|
||||
|
||||
# Training mode prompt (single iteration)
|
||||
|
||||
@@ -1,12 +1,11 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Optional
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class OutputConverter(BaseModel, ABC):
|
||||
"""
|
||||
Abstract base class for converting task results into structured formats.
|
||||
"""Abstract base class for converting task results into structured formats.
|
||||
|
||||
This class provides a framework for converting unstructured text into
|
||||
either Pydantic models or JSON, tailored for specific agent requirements.
|
||||
@@ -19,6 +18,7 @@ class OutputConverter(BaseModel, ABC):
|
||||
model (Any): The target model for structuring the output.
|
||||
instructions (str): Specific instructions for the conversion process.
|
||||
max_attempts (int): Maximum number of conversion attempts (default: 3).
|
||||
|
||||
"""
|
||||
|
||||
text: str = Field(description="Text to be converted.")
|
||||
@@ -33,9 +33,7 @@ class OutputConverter(BaseModel, ABC):
|
||||
@abstractmethod
|
||||
def to_pydantic(self, current_attempt=1) -> BaseModel:
|
||||
"""Convert text to pydantic."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def to_json(self, current_attempt=1) -> dict:
|
||||
"""Convert text to json."""
|
||||
pass
|
||||
|
||||
8
src/crewai/agents/cache/cache_handler.py
vendored
8
src/crewai/agents/cache/cache_handler.py
vendored
@@ -1,4 +1,4 @@
|
||||
from typing import Any, Dict, Optional
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, PrivateAttr
|
||||
|
||||
@@ -6,10 +6,10 @@ from pydantic import BaseModel, PrivateAttr
|
||||
class CacheHandler(BaseModel):
|
||||
"""Callback handler for tool usage."""
|
||||
|
||||
_cache: Dict[str, Any] = PrivateAttr(default_factory=dict)
|
||||
_cache: dict[str, Any] = PrivateAttr(default_factory=dict)
|
||||
|
||||
def add(self, tool, input, output):
|
||||
def add(self, tool, input, output) -> None:
|
||||
self._cache[f"{tool}-{input}"] = output
|
||||
|
||||
def read(self, tool, input) -> Optional[str]:
|
||||
def read(self, tool, input) -> str | None:
|
||||
return self._cache.get(f"{tool}-{input}")
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
import json
|
||||
import re
|
||||
from typing import Any, Callable, Dict, List, Optional, Union
|
||||
from collections.abc import Callable
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from crewai.agents.agent_builder.base_agent import BaseAgent
|
||||
from crewai.agents.agent_builder.base_agent_executor_mixin import CrewAgentExecutorMixin
|
||||
@@ -10,8 +9,6 @@ from crewai.agents.parser import (
|
||||
OutputParserException,
|
||||
)
|
||||
from crewai.agents.tools_handler import ToolsHandler
|
||||
from crewai.llm import BaseLLM
|
||||
from crewai.tools.base_tool import BaseTool
|
||||
from crewai.tools.structured_tool import CrewStructuredTool
|
||||
from crewai.tools.tool_types import ToolResult
|
||||
from crewai.utilities import I18N, Printer
|
||||
@@ -34,6 +31,10 @@ from crewai.utilities.logger import Logger
|
||||
from crewai.utilities.tool_utils import execute_tool_and_check_finality
|
||||
from crewai.utilities.training_handler import CrewTrainingHandler
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from crewai.llm import BaseLLM
|
||||
from crewai.tools.base_tool import BaseTool
|
||||
|
||||
|
||||
class CrewAgentExecutor(CrewAgentExecutorMixin):
|
||||
_logger: Logger = Logger()
|
||||
@@ -46,18 +47,22 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
|
||||
agent: BaseAgent,
|
||||
prompt: dict[str, str],
|
||||
max_iter: int,
|
||||
tools: List[CrewStructuredTool],
|
||||
tools: list[CrewStructuredTool],
|
||||
tools_names: str,
|
||||
stop_words: List[str],
|
||||
stop_words: list[str],
|
||||
tools_description: str,
|
||||
tools_handler: ToolsHandler,
|
||||
step_callback: Any = None,
|
||||
original_tools: List[Any] = [],
|
||||
original_tools: list[Any] | None = None,
|
||||
function_calling_llm: Any = None,
|
||||
respect_context_window: bool = False,
|
||||
request_within_rpm_limit: Optional[Callable[[], bool]] = None,
|
||||
callbacks: List[Any] = [],
|
||||
):
|
||||
request_within_rpm_limit: Callable[[], bool] | None = None,
|
||||
callbacks: list[Any] | None = None,
|
||||
) -> None:
|
||||
if callbacks is None:
|
||||
callbacks = []
|
||||
if original_tools is None:
|
||||
original_tools = []
|
||||
self._i18n: I18N = I18N()
|
||||
self.llm: BaseLLM = llm
|
||||
self.task = task
|
||||
@@ -79,10 +84,10 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
|
||||
self.respect_context_window = respect_context_window
|
||||
self.request_within_rpm_limit = request_within_rpm_limit
|
||||
self.ask_for_human_input = False
|
||||
self.messages: List[Dict[str, str]] = []
|
||||
self.messages: list[dict[str, str]] = []
|
||||
self.iterations = 0
|
||||
self.log_error_after = 3
|
||||
self.tool_name_to_tool_map: Dict[str, Union[CrewStructuredTool, BaseTool]] = {
|
||||
self.tool_name_to_tool_map: dict[str, CrewStructuredTool | BaseTool] = {
|
||||
tool.name: tool for tool in self.tools
|
||||
}
|
||||
existing_stop = self.llm.stop or []
|
||||
@@ -90,11 +95,11 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
|
||||
set(
|
||||
existing_stop + self.stop
|
||||
if isinstance(existing_stop, list)
|
||||
else self.stop
|
||||
)
|
||||
else self.stop,
|
||||
),
|
||||
)
|
||||
|
||||
def invoke(self, inputs: Dict[str, str]) -> Dict[str, Any]:
|
||||
def invoke(self, inputs: dict[str, str]) -> dict[str, Any]:
|
||||
if "system" in self.prompt:
|
||||
system_prompt = self._format_prompt(self.prompt.get("system", ""), inputs)
|
||||
user_prompt = self._format_prompt(self.prompt.get("user", ""), inputs)
|
||||
@@ -120,9 +125,8 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
|
||||
handle_unknown_error(self._printer, e)
|
||||
if e.__class__.__module__.startswith("litellm"):
|
||||
# Do not retry on litellm errors
|
||||
raise e
|
||||
else:
|
||||
raise e
|
||||
raise
|
||||
raise
|
||||
|
||||
if self.ask_for_human_input:
|
||||
formatted_answer = self._handle_human_feedback(formatted_answer)
|
||||
@@ -133,8 +137,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
|
||||
return {"output": formatted_answer.output}
|
||||
|
||||
def _invoke_loop(self) -> AgentFinish:
|
||||
"""
|
||||
Main loop to invoke the agent's thought process until it reaches a conclusion
|
||||
"""Main loop to invoke the agent's thought process until it reaches a conclusion
|
||||
or the maximum number of iterations is reached.
|
||||
"""
|
||||
formatted_answer = None
|
||||
@@ -170,8 +173,8 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
|
||||
):
|
||||
fingerprint_context = {
|
||||
"agent_fingerprint": str(
|
||||
self.agent.security_config.fingerprint
|
||||
)
|
||||
self.agent.security_config.fingerprint,
|
||||
),
|
||||
}
|
||||
|
||||
tool_result = execute_tool_and_check_finality(
|
||||
@@ -187,7 +190,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
|
||||
function_calling_llm=self.function_calling_llm,
|
||||
)
|
||||
formatted_answer = self._handle_agent_action(
|
||||
formatted_answer, tool_result
|
||||
formatted_answer, tool_result,
|
||||
)
|
||||
|
||||
self._invoke_step_callback(formatted_answer)
|
||||
@@ -205,7 +208,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
|
||||
except Exception as e:
|
||||
if e.__class__.__module__.startswith("litellm"):
|
||||
# Do not retry on litellm errors
|
||||
raise e
|
||||
raise
|
||||
if is_context_length_exceeded(e):
|
||||
handle_context_length(
|
||||
respect_context_window=self.respect_context_window,
|
||||
@@ -216,9 +219,8 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
|
||||
i18n=self._i18n,
|
||||
)
|
||||
continue
|
||||
else:
|
||||
handle_unknown_error(self._printer, e)
|
||||
raise e
|
||||
handle_unknown_error(self._printer, e)
|
||||
raise
|
||||
finally:
|
||||
self.iterations += 1
|
||||
|
||||
@@ -231,8 +233,8 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
|
||||
return formatted_answer
|
||||
|
||||
def _handle_agent_action(
|
||||
self, formatted_answer: AgentAction, tool_result: ToolResult
|
||||
) -> Union[AgentAction, AgentFinish]:
|
||||
self, formatted_answer: AgentAction, tool_result: ToolResult,
|
||||
) -> AgentAction | AgentFinish:
|
||||
"""Handle the AgentAction, execute tools, and process the results."""
|
||||
# Special case for add_image_tool
|
||||
add_image_tool = self._i18n.tools("add_image")
|
||||
@@ -261,24 +263,26 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
|
||||
"""Append a message to the message list with the given role."""
|
||||
self.messages.append(format_message_for_llm(text, role=role))
|
||||
|
||||
def _show_start_logs(self):
|
||||
def _show_start_logs(self) -> None:
|
||||
"""Show logs for the start of agent execution."""
|
||||
if self.agent is None:
|
||||
raise ValueError("Agent cannot be None")
|
||||
msg = "Agent cannot be None"
|
||||
raise ValueError(msg)
|
||||
show_agent_logs(
|
||||
printer=self._printer,
|
||||
agent_role=self.agent.role,
|
||||
task_description=(
|
||||
getattr(self.task, "description") if self.task else "Not Found"
|
||||
self.task.description if self.task else "Not Found"
|
||||
),
|
||||
verbose=self.agent.verbose
|
||||
or (hasattr(self, "crew") and getattr(self.crew, "verbose", False)),
|
||||
)
|
||||
|
||||
def _show_logs(self, formatted_answer: Union[AgentAction, AgentFinish]):
|
||||
def _show_logs(self, formatted_answer: AgentAction | AgentFinish) -> None:
|
||||
"""Show logs for the agent's execution."""
|
||||
if self.agent is None:
|
||||
raise ValueError("Agent cannot be None")
|
||||
msg = "Agent cannot be None"
|
||||
raise ValueError(msg)
|
||||
show_agent_logs(
|
||||
printer=self._printer,
|
||||
agent_role=self.agent.role,
|
||||
@@ -300,11 +304,11 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
|
||||
summary = self.llm.call(
|
||||
[
|
||||
format_message_for_llm(
|
||||
self._i18n.slice("summarizer_system_message"), role="system"
|
||||
self._i18n.slice("summarizer_system_message"), role="system",
|
||||
),
|
||||
format_message_for_llm(
|
||||
self._i18n.slice("summarize_instruction").format(
|
||||
group=group["content"]
|
||||
group=group["content"],
|
||||
),
|
||||
),
|
||||
],
|
||||
@@ -316,12 +320,12 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
|
||||
|
||||
self.messages = [
|
||||
format_message_for_llm(
|
||||
self._i18n.slice("summary").format(merged_summary=merged_summary)
|
||||
)
|
||||
self._i18n.slice("summary").format(merged_summary=merged_summary),
|
||||
),
|
||||
]
|
||||
|
||||
def _handle_crew_training_output(
|
||||
self, result: AgentFinish, human_feedback: Optional[str] = None
|
||||
self, result: AgentFinish, human_feedback: str | None = None,
|
||||
) -> None:
|
||||
"""Handle the process of saving training data."""
|
||||
agent_id = str(self.agent.id) # type: ignore
|
||||
@@ -348,29 +352,27 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
|
||||
"initial_output": result.output,
|
||||
"human_feedback": human_feedback,
|
||||
}
|
||||
# Save improved output
|
||||
elif train_iteration in agent_training_data:
|
||||
agent_training_data[train_iteration]["improved_output"] = result.output
|
||||
else:
|
||||
# Save improved output
|
||||
if train_iteration in agent_training_data:
|
||||
agent_training_data[train_iteration]["improved_output"] = result.output
|
||||
else:
|
||||
self._printer.print(
|
||||
content=(
|
||||
f"No existing training data for agent {agent_id} and iteration "
|
||||
f"{train_iteration}. Cannot save improved output."
|
||||
),
|
||||
color="red",
|
||||
)
|
||||
return
|
||||
self._printer.print(
|
||||
content=(
|
||||
f"No existing training data for agent {agent_id} and iteration "
|
||||
f"{train_iteration}. Cannot save improved output."
|
||||
),
|
||||
color="red",
|
||||
)
|
||||
return
|
||||
|
||||
# Update the training data and save
|
||||
training_data[agent_id] = agent_training_data
|
||||
training_handler.save(training_data)
|
||||
|
||||
def _format_prompt(self, prompt: str, inputs: Dict[str, str]) -> str:
|
||||
def _format_prompt(self, prompt: str, inputs: dict[str, str]) -> str:
|
||||
prompt = prompt.replace("{input}", inputs["input"])
|
||||
prompt = prompt.replace("{tool_names}", inputs["tool_names"])
|
||||
prompt = prompt.replace("{tools}", inputs["tools"])
|
||||
return prompt
|
||||
return prompt.replace("{tools}", inputs["tools"])
|
||||
|
||||
def _handle_human_feedback(self, formatted_answer: AgentFinish) -> AgentFinish:
|
||||
"""Handle human feedback with different flows for training vs regular use.
|
||||
@@ -380,6 +382,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
|
||||
|
||||
Returns:
|
||||
AgentFinish: The final answer after processing feedback
|
||||
|
||||
"""
|
||||
human_feedback = self._ask_human_input(formatted_answer.output)
|
||||
|
||||
@@ -393,14 +396,14 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
|
||||
return bool(self.crew and self.crew._train)
|
||||
|
||||
def _handle_training_feedback(
|
||||
self, initial_answer: AgentFinish, feedback: str
|
||||
self, initial_answer: AgentFinish, feedback: str,
|
||||
) -> AgentFinish:
|
||||
"""Process feedback for training scenarios with single iteration."""
|
||||
self._handle_crew_training_output(initial_answer, feedback)
|
||||
self.messages.append(
|
||||
format_message_for_llm(
|
||||
self._i18n.slice("feedback_instructions").format(feedback=feedback)
|
||||
)
|
||||
self._i18n.slice("feedback_instructions").format(feedback=feedback),
|
||||
),
|
||||
)
|
||||
improved_answer = self._invoke_loop()
|
||||
self._handle_crew_training_output(improved_answer)
|
||||
@@ -408,7 +411,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
|
||||
return improved_answer
|
||||
|
||||
def _handle_regular_feedback(
|
||||
self, current_answer: AgentFinish, initial_feedback: str
|
||||
self, current_answer: AgentFinish, initial_feedback: str,
|
||||
) -> AgentFinish:
|
||||
"""Process feedback for regular use with potential multiple iterations."""
|
||||
feedback = initial_feedback
|
||||
@@ -428,8 +431,8 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
|
||||
"""Process a single feedback iteration."""
|
||||
self.messages.append(
|
||||
format_message_for_llm(
|
||||
self._i18n.slice("feedback_instructions").format(feedback=feedback)
|
||||
)
|
||||
self._i18n.slice("feedback_instructions").format(feedback=feedback),
|
||||
),
|
||||
)
|
||||
return self._invoke_loop()
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import re
|
||||
from typing import Any, Optional, Union
|
||||
from typing import Any
|
||||
|
||||
from json_repair import repair_json
|
||||
|
||||
@@ -18,7 +18,7 @@ class AgentAction:
|
||||
text: str
|
||||
result: str
|
||||
|
||||
def __init__(self, thought: str, tool: str, tool_input: str, text: str):
|
||||
def __init__(self, thought: str, tool: str, tool_input: str, text: str) -> None:
|
||||
self.thought = thought
|
||||
self.tool = tool
|
||||
self.tool_input = tool_input
|
||||
@@ -30,7 +30,7 @@ class AgentFinish:
|
||||
output: str
|
||||
text: str
|
||||
|
||||
def __init__(self, thought: str, output: str, text: str):
|
||||
def __init__(self, thought: str, output: str, text: str) -> None:
|
||||
self.thought = thought
|
||||
self.output = output
|
||||
self.text = text
|
||||
@@ -39,7 +39,7 @@ class AgentFinish:
|
||||
class OutputParserException(Exception):
|
||||
error: str
|
||||
|
||||
def __init__(self, error: str):
|
||||
def __init__(self, error: str) -> None:
|
||||
self.error = error
|
||||
|
||||
|
||||
@@ -67,24 +67,24 @@ class CrewAgentParser:
|
||||
_i18n: I18N = I18N()
|
||||
agent: Any = None
|
||||
|
||||
def __init__(self, agent: Optional[Any] = None):
|
||||
def __init__(self, agent: Any | None = None) -> None:
|
||||
self.agent = agent
|
||||
|
||||
@staticmethod
|
||||
def parse_text(text: str) -> Union[AgentAction, AgentFinish]:
|
||||
"""
|
||||
Static method to parse text into an AgentAction or AgentFinish without needing to instantiate the class.
|
||||
def parse_text(text: str) -> AgentAction | AgentFinish:
|
||||
"""Static method to parse text into an AgentAction or AgentFinish without needing to instantiate the class.
|
||||
|
||||
Args:
|
||||
text: The text to parse.
|
||||
|
||||
Returns:
|
||||
Either an AgentAction or AgentFinish based on the parsed content.
|
||||
|
||||
"""
|
||||
parser = CrewAgentParser()
|
||||
return parser.parse(text)
|
||||
|
||||
def parse(self, text: str) -> Union[AgentAction, AgentFinish]:
|
||||
def parse(self, text: str) -> AgentAction | AgentFinish:
|
||||
thought = self._extract_thought(text)
|
||||
includes_answer = FINAL_ANSWER_ACTION in text
|
||||
regex = (
|
||||
@@ -102,7 +102,7 @@ class CrewAgentParser:
|
||||
final_answer = final_answer[:-3].rstrip()
|
||||
return AgentFinish(thought, final_answer, text)
|
||||
|
||||
elif action_match:
|
||||
if action_match:
|
||||
action = action_match.group(1)
|
||||
clean_action = self._clean_action(action)
|
||||
|
||||
@@ -114,21 +114,21 @@ class CrewAgentParser:
|
||||
return AgentAction(thought, clean_action, safe_tool_input, text)
|
||||
|
||||
if not re.search(r"Action\s*\d*\s*:[\s]*(.*?)", text, re.DOTALL):
|
||||
msg = f"{MISSING_ACTION_AFTER_THOUGHT_ERROR_MESSAGE}\n{self._i18n.slice('final_answer_format')}"
|
||||
raise OutputParserException(
|
||||
f"{MISSING_ACTION_AFTER_THOUGHT_ERROR_MESSAGE}\n{self._i18n.slice('final_answer_format')}",
|
||||
msg,
|
||||
)
|
||||
elif not re.search(
|
||||
r"[\s]*Action\s*\d*\s*Input\s*\d*\s*:[\s]*(.*)", text, re.DOTALL
|
||||
if not re.search(
|
||||
r"[\s]*Action\s*\d*\s*Input\s*\d*\s*:[\s]*(.*)", text, re.DOTALL,
|
||||
):
|
||||
raise OutputParserException(
|
||||
MISSING_ACTION_INPUT_AFTER_ACTION_ERROR_MESSAGE,
|
||||
)
|
||||
else:
|
||||
format = self._i18n.slice("format_without_tools")
|
||||
error = f"{format}"
|
||||
raise OutputParserException(
|
||||
error,
|
||||
)
|
||||
format = self._i18n.slice("format_without_tools")
|
||||
error = f"{format}"
|
||||
raise OutputParserException(
|
||||
error,
|
||||
)
|
||||
|
||||
def _extract_thought(self, text: str) -> str:
|
||||
thought_index = text.find("\nAction")
|
||||
@@ -138,8 +138,7 @@ class CrewAgentParser:
|
||||
return ""
|
||||
thought = text[:thought_index].strip()
|
||||
# Remove any triple backticks from the thought string
|
||||
thought = thought.replace("```", "").strip()
|
||||
return thought
|
||||
return thought.replace("```", "").strip()
|
||||
|
||||
def _clean_action(self, text: str) -> str:
|
||||
"""Clean action string by removing non-essential formatting characters."""
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
from typing import Any, Optional, Union
|
||||
from typing import Any
|
||||
|
||||
from crewai.tools.cache_tools.cache_tools import CacheTools
|
||||
from crewai.tools.tool_calling import InstructorToolCalling, ToolCalling
|
||||
|
||||
from ..tools.cache_tools.cache_tools import CacheTools
|
||||
from ..tools.tool_calling import InstructorToolCalling, ToolCalling
|
||||
from .cache.cache_handler import CacheHandler
|
||||
|
||||
|
||||
@@ -9,16 +10,16 @@ class ToolsHandler:
|
||||
"""Callback handler for tool usage."""
|
||||
|
||||
last_used_tool: ToolCalling = {} # type: ignore # BUG?: Incompatible types in assignment (expression has type "Dict[...]", variable has type "ToolCalling")
|
||||
cache: Optional[CacheHandler]
|
||||
cache: CacheHandler | None
|
||||
|
||||
def __init__(self, cache: Optional[CacheHandler] = None):
|
||||
def __init__(self, cache: CacheHandler | None = None) -> None:
|
||||
"""Initialize the callback handler."""
|
||||
self.cache = cache
|
||||
self.last_used_tool = {} # type: ignore # BUG?: same as above
|
||||
|
||||
def on_tool_use(
|
||||
self,
|
||||
calling: Union[ToolCalling, InstructorToolCalling],
|
||||
calling: ToolCalling | InstructorToolCalling,
|
||||
output: str,
|
||||
should_cache: bool = True,
|
||||
) -> Any:
|
||||
|
||||
@@ -9,9 +9,9 @@ def add_crew_to_flow(crew_name: str) -> None:
|
||||
"""Add a new crew to the current flow."""
|
||||
# Check if pyproject.toml exists in the current directory
|
||||
if not Path("pyproject.toml").exists():
|
||||
print("This command must be run from the root of a flow project.")
|
||||
msg = "This command must be run from the root of a flow project."
|
||||
raise click.ClickException(
|
||||
"This command must be run from the root of a flow project."
|
||||
msg,
|
||||
)
|
||||
|
||||
# Determine the flow folder based on the current directory
|
||||
@@ -19,8 +19,8 @@ def add_crew_to_flow(crew_name: str) -> None:
|
||||
crews_folder = flow_folder / "src" / flow_folder.name / "crews"
|
||||
|
||||
if not crews_folder.exists():
|
||||
print("Crews folder does not exist in the current flow.")
|
||||
raise click.ClickException("Crews folder does not exist in the current flow.")
|
||||
msg = "Crews folder does not exist in the current flow."
|
||||
raise click.ClickException(msg)
|
||||
|
||||
# Create the crew within the flow's crews directory
|
||||
create_embedded_crew(crew_name, parent_folder=crews_folder)
|
||||
@@ -39,7 +39,7 @@ def create_embedded_crew(crew_name: str, parent_folder: Path) -> None:
|
||||
|
||||
if crew_folder.exists():
|
||||
if not click.confirm(
|
||||
f"Crew {folder_name} already exists. Do you want to override it?"
|
||||
f"Crew {folder_name} already exists. Do you want to override it?",
|
||||
):
|
||||
click.secho("Operation cancelled.", fg="yellow")
|
||||
return
|
||||
@@ -66,5 +66,5 @@ def create_embedded_crew(crew_name: str, parent_folder: Path) -> None:
|
||||
copy_template(src_file, dst_file, crew_name, class_name, folder_name)
|
||||
|
||||
click.secho(
|
||||
f"Crew {crew_name} added to the flow successfully!", fg="green", bold=True
|
||||
f"Crew {crew_name} added to the flow successfully!", fg="green", bold=True,
|
||||
)
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import time
|
||||
import webbrowser
|
||||
from typing import Any, Dict
|
||||
from typing import Any
|
||||
|
||||
import requests
|
||||
from rich.console import Console
|
||||
@@ -17,38 +17,37 @@ class AuthenticationCommand:
|
||||
DEVICE_CODE_URL = f"https://{AUTH0_DOMAIN}/oauth/device/code"
|
||||
TOKEN_URL = f"https://{AUTH0_DOMAIN}/oauth/token"
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
self.token_manager = TokenManager()
|
||||
|
||||
def signup(self) -> None:
|
||||
"""Sign up to CrewAI+"""
|
||||
"""Sign up to CrewAI+."""
|
||||
console.print("Signing Up to CrewAI+ \n", style="bold blue")
|
||||
device_code_data = self._get_device_code()
|
||||
self._display_auth_instructions(device_code_data)
|
||||
|
||||
return self._poll_for_token(device_code_data)
|
||||
|
||||
def _get_device_code(self) -> Dict[str, Any]:
|
||||
def _get_device_code(self) -> dict[str, Any]:
|
||||
"""Get the device code to authenticate the user."""
|
||||
|
||||
device_code_payload = {
|
||||
"client_id": AUTH0_CLIENT_ID,
|
||||
"scope": "openid",
|
||||
"audience": AUTH0_AUDIENCE,
|
||||
}
|
||||
response = requests.post(
|
||||
url=self.DEVICE_CODE_URL, data=device_code_payload, timeout=20
|
||||
url=self.DEVICE_CODE_URL, data=device_code_payload, timeout=20,
|
||||
)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
|
||||
def _display_auth_instructions(self, device_code_data: Dict[str, str]) -> None:
|
||||
def _display_auth_instructions(self, device_code_data: dict[str, str]) -> None:
|
||||
"""Display the authentication instructions to the user."""
|
||||
console.print("1. Navigate to: ", device_code_data["verification_uri_complete"])
|
||||
console.print("2. Enter the following code: ", device_code_data["user_code"])
|
||||
webbrowser.open(device_code_data["verification_uri_complete"])
|
||||
|
||||
def _poll_for_token(self, device_code_data: Dict[str, Any]) -> None:
|
||||
def _poll_for_token(self, device_code_data: dict[str, Any]) -> None:
|
||||
"""Poll the server for the token."""
|
||||
token_payload = {
|
||||
"grant_type": "urn:ietf:params:oauth:grant-type:device_code",
|
||||
@@ -81,7 +80,7 @@ class AuthenticationCommand:
|
||||
)
|
||||
|
||||
console.print(
|
||||
"\n[bold green]Welcome to CrewAI Enterprise![/bold green]\n"
|
||||
"\n[bold green]Welcome to CrewAI Enterprise![/bold green]\n",
|
||||
)
|
||||
return
|
||||
|
||||
@@ -92,5 +91,5 @@ class AuthenticationCommand:
|
||||
attempts += 1
|
||||
|
||||
console.print(
|
||||
"Timeout: Failed to get the token. Please try again.", style="bold red"
|
||||
"Timeout: Failed to get the token. Please try again.", style="bold red",
|
||||
)
|
||||
|
||||
@@ -5,5 +5,5 @@ def get_auth_token() -> str:
|
||||
"""Get the authentication token."""
|
||||
access_token = TokenManager().get_token()
|
||||
if not access_token:
|
||||
raise Exception("No token found, make sure you are logged in")
|
||||
raise Exception
|
||||
return access_token
|
||||
|
||||
@@ -3,7 +3,6 @@ import os
|
||||
import sys
|
||||
from datetime import datetime, timedelta
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from auth0.authentication.token_verifier import (
|
||||
AsymmetricSignatureVerifier,
|
||||
@@ -15,8 +14,7 @@ from .constants import AUTH0_CLIENT_ID, AUTH0_DOMAIN
|
||||
|
||||
|
||||
def validate_token(id_token: str) -> None:
|
||||
"""
|
||||
Verify the token and its precedence
|
||||
"""Verify the token and its precedence.
|
||||
|
||||
:param id_token:
|
||||
"""
|
||||
@@ -24,15 +22,14 @@ def validate_token(id_token: str) -> None:
|
||||
issuer = f"https://{AUTH0_DOMAIN}/"
|
||||
signature_verifier = AsymmetricSignatureVerifier(jwks_url)
|
||||
token_verifier = TokenVerifier(
|
||||
signature_verifier=signature_verifier, issuer=issuer, audience=AUTH0_CLIENT_ID
|
||||
signature_verifier=signature_verifier, issuer=issuer, audience=AUTH0_CLIENT_ID,
|
||||
)
|
||||
token_verifier.verify(id_token)
|
||||
|
||||
|
||||
class TokenManager:
|
||||
def __init__(self, file_path: str = "tokens.enc") -> None:
|
||||
"""
|
||||
Initialize the TokenManager class.
|
||||
"""Initialize the TokenManager class.
|
||||
|
||||
:param file_path: The file path to store the encrypted tokens. Default is "tokens.enc".
|
||||
"""
|
||||
@@ -41,8 +38,7 @@ class TokenManager:
|
||||
self.fernet = Fernet(self.key)
|
||||
|
||||
def _get_or_create_key(self) -> bytes:
|
||||
"""
|
||||
Get or create the encryption key.
|
||||
"""Get or create the encryption key.
|
||||
|
||||
:return: The encryption key.
|
||||
"""
|
||||
@@ -57,8 +53,7 @@ class TokenManager:
|
||||
return new_key
|
||||
|
||||
def save_tokens(self, access_token: str, expires_in: int) -> None:
|
||||
"""
|
||||
Save the access token and its expiration time.
|
||||
"""Save the access token and its expiration time.
|
||||
|
||||
:param access_token: The access token to save.
|
||||
:param expires_in: The expiration time of the access token in seconds.
|
||||
@@ -71,9 +66,8 @@ class TokenManager:
|
||||
encrypted_data = self.fernet.encrypt(json.dumps(data).encode())
|
||||
self.save_secure_file(self.file_path, encrypted_data)
|
||||
|
||||
def get_token(self) -> Optional[str]:
|
||||
"""
|
||||
Get the access token if it is valid and not expired.
|
||||
def get_token(self) -> str | None:
|
||||
"""Get the access token if it is valid and not expired.
|
||||
|
||||
:return: The access token if valid and not expired, otherwise None.
|
||||
"""
|
||||
@@ -89,8 +83,7 @@ class TokenManager:
|
||||
return data["access_token"]
|
||||
|
||||
def get_secure_storage_path(self) -> Path:
|
||||
"""
|
||||
Get the secure storage path based on the operating system.
|
||||
"""Get the secure storage path based on the operating system.
|
||||
|
||||
:return: The secure storage path.
|
||||
"""
|
||||
@@ -112,8 +105,7 @@ class TokenManager:
|
||||
return storage_path
|
||||
|
||||
def save_secure_file(self, filename: str, content: bytes) -> None:
|
||||
"""
|
||||
Save the content to a secure file.
|
||||
"""Save the content to a secure file.
|
||||
|
||||
:param filename: The name of the file.
|
||||
:param content: The content to save.
|
||||
@@ -127,9 +119,8 @@ class TokenManager:
|
||||
# Set appropriate permissions (read/write for owner only)
|
||||
os.chmod(file_path, 0o600)
|
||||
|
||||
def read_secure_file(self, filename: str) -> Optional[bytes]:
|
||||
"""
|
||||
Read the content of a secure file.
|
||||
def read_secure_file(self, filename: str) -> bytes | None:
|
||||
"""Read the content of a secure file.
|
||||
|
||||
:param filename: The name of the file.
|
||||
:return: The content of the file if it exists, otherwise None.
|
||||
|
||||
@@ -1,6 +1,4 @@
|
||||
import os
|
||||
from importlib.metadata import version as get_version
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import click
|
||||
|
||||
@@ -28,7 +26,7 @@ from .update_crew import update_crew
|
||||
|
||||
@click.group()
|
||||
@click.version_option(get_version("crewai"))
|
||||
def crewai():
|
||||
def crewai() -> None:
|
||||
"""Top-level command group for crewai."""
|
||||
|
||||
|
||||
@@ -37,7 +35,7 @@ def crewai():
|
||||
@click.argument("name")
|
||||
@click.option("--provider", type=str, help="The provider to use for the crew")
|
||||
@click.option("--skip_provider", is_flag=True, help="Skip provider validation")
|
||||
def create(type, name, provider, skip_provider=False):
|
||||
def create(type, name, provider, skip_provider=False) -> None:
|
||||
"""Create a new crew, or flow."""
|
||||
if type == "crew":
|
||||
create_crew(name, provider, skip_provider)
|
||||
@@ -49,9 +47,9 @@ def create(type, name, provider, skip_provider=False):
|
||||
|
||||
@crewai.command()
|
||||
@click.option(
|
||||
"--tools", is_flag=True, help="Show the installed version of crewai tools"
|
||||
"--tools", is_flag=True, help="Show the installed version of crewai tools",
|
||||
)
|
||||
def version(tools):
|
||||
def version(tools) -> None:
|
||||
"""Show the installed version of crewai."""
|
||||
try:
|
||||
crewai_version = get_version("crewai")
|
||||
@@ -82,7 +80,7 @@ def version(tools):
|
||||
default="trained_agents_data.pkl",
|
||||
help="Path to a custom file for training",
|
||||
)
|
||||
def train(n_iterations: int, filename: str):
|
||||
def train(n_iterations: int, filename: str) -> None:
|
||||
"""Train the crew."""
|
||||
click.echo(f"Training the Crew for {n_iterations} iterations")
|
||||
train_crew(n_iterations, filename)
|
||||
@@ -96,11 +94,11 @@ def train(n_iterations: int, filename: str):
|
||||
help="Replay the crew from this task ID, including all subsequent tasks.",
|
||||
)
|
||||
def replay(task_id: str) -> None:
|
||||
"""
|
||||
Replay the crew execution from a specific task.
|
||||
"""Replay the crew execution from a specific task.
|
||||
|
||||
Args:
|
||||
task_id (str): The ID of the task to replay from.
|
||||
|
||||
"""
|
||||
try:
|
||||
click.echo(f"Replaying the crew from task {task_id}")
|
||||
@@ -111,16 +109,14 @@ def replay(task_id: str) -> None:
|
||||
|
||||
@crewai.command()
|
||||
def log_tasks_outputs() -> None:
|
||||
"""
|
||||
Retrieve your latest crew.kickoff() task outputs.
|
||||
"""
|
||||
"""Retrieve your latest crew.kickoff() task outputs."""
|
||||
try:
|
||||
storage = KickoffTaskOutputsSQLiteStorage()
|
||||
tasks = storage.load()
|
||||
|
||||
if not tasks:
|
||||
click.echo(
|
||||
"No task outputs found. Only crew kickoff task outputs are logged."
|
||||
"No task outputs found. Only crew kickoff task outputs are logged.",
|
||||
)
|
||||
return
|
||||
|
||||
@@ -153,13 +149,11 @@ def reset_memories(
|
||||
kickoff_outputs: bool,
|
||||
all: bool,
|
||||
) -> None:
|
||||
"""
|
||||
Reset the crew memories (long, short, entity, latest_crew_kickoff_ouputs). This will delete all the data saved.
|
||||
"""
|
||||
"""Reset the crew memories (long, short, entity, latest_crew_kickoff_ouputs). This will delete all the data saved."""
|
||||
try:
|
||||
if not all and not (long or short or entities or knowledge or kickoff_outputs):
|
||||
click.echo(
|
||||
"Please specify at least one memory type to reset using the appropriate flags."
|
||||
"Please specify at least one memory type to reset using the appropriate flags.",
|
||||
)
|
||||
return
|
||||
reset_memories_command(long, short, entities, knowledge, kickoff_outputs, all)
|
||||
@@ -182,71 +176,69 @@ def reset_memories(
|
||||
default="gpt-4o-mini",
|
||||
help="LLM Model to run the tests on the Crew. For now only accepting only OpenAI models.",
|
||||
)
|
||||
def test(n_iterations: int, model: str):
|
||||
def test(n_iterations: int, model: str) -> None:
|
||||
"""Test the crew and evaluate the results."""
|
||||
click.echo(f"Testing the crew for {n_iterations} iterations with model {model}")
|
||||
evaluate_crew(n_iterations, model)
|
||||
|
||||
|
||||
@crewai.command(
|
||||
context_settings=dict(
|
||||
ignore_unknown_options=True,
|
||||
allow_extra_args=True,
|
||||
)
|
||||
context_settings={
|
||||
"ignore_unknown_options": True,
|
||||
"allow_extra_args": True,
|
||||
},
|
||||
)
|
||||
@click.pass_context
|
||||
def install(context):
|
||||
def install(context) -> None:
|
||||
"""Install the Crew."""
|
||||
install_crew(context.args)
|
||||
|
||||
|
||||
@crewai.command()
|
||||
def run():
|
||||
def run() -> None:
|
||||
"""Run the Crew."""
|
||||
run_crew()
|
||||
|
||||
|
||||
@crewai.command()
|
||||
def update():
|
||||
def update() -> None:
|
||||
"""Update the pyproject.toml of the Crew project to use uv."""
|
||||
update_crew()
|
||||
|
||||
|
||||
@crewai.command()
|
||||
def signup():
|
||||
def signup() -> None:
|
||||
"""Sign Up/Login to CrewAI+."""
|
||||
AuthenticationCommand().signup()
|
||||
|
||||
|
||||
@crewai.command()
|
||||
def login():
|
||||
def login() -> None:
|
||||
"""Sign Up/Login to CrewAI+."""
|
||||
AuthenticationCommand().signup()
|
||||
|
||||
|
||||
# DEPLOY CREWAI+ COMMANDS
|
||||
@crewai.group()
|
||||
def deploy():
|
||||
def deploy() -> None:
|
||||
"""Deploy the Crew CLI group."""
|
||||
pass
|
||||
|
||||
|
||||
@crewai.group()
|
||||
def tool():
|
||||
def tool() -> None:
|
||||
"""Tool Repository related commands."""
|
||||
pass
|
||||
|
||||
|
||||
@deploy.command(name="create")
|
||||
@click.option("-y", "--yes", is_flag=True, help="Skip the confirmation prompt")
|
||||
def deploy_create(yes: bool):
|
||||
def deploy_create(yes: bool) -> None:
|
||||
"""Create a Crew deployment."""
|
||||
deploy_cmd = DeployCommand()
|
||||
deploy_cmd.create_crew(yes)
|
||||
|
||||
|
||||
@deploy.command(name="list")
|
||||
def deploy_list():
|
||||
def deploy_list() -> None:
|
||||
"""List all deployments."""
|
||||
deploy_cmd = DeployCommand()
|
||||
deploy_cmd.list_crews()
|
||||
@@ -254,7 +246,7 @@ def deploy_list():
|
||||
|
||||
@deploy.command(name="push")
|
||||
@click.option("-u", "--uuid", type=str, help="Crew UUID parameter")
|
||||
def deploy_push(uuid: Optional[str]):
|
||||
def deploy_push(uuid: str | None) -> None:
|
||||
"""Deploy the Crew."""
|
||||
deploy_cmd = DeployCommand()
|
||||
deploy_cmd.deploy(uuid=uuid)
|
||||
@@ -262,7 +254,7 @@ def deploy_push(uuid: Optional[str]):
|
||||
|
||||
@deploy.command(name="status")
|
||||
@click.option("-u", "--uuid", type=str, help="Crew UUID parameter")
|
||||
def deply_status(uuid: Optional[str]):
|
||||
def deply_status(uuid: str | None) -> None:
|
||||
"""Get the status of a deployment."""
|
||||
deploy_cmd = DeployCommand()
|
||||
deploy_cmd.get_crew_status(uuid=uuid)
|
||||
@@ -270,7 +262,7 @@ def deply_status(uuid: Optional[str]):
|
||||
|
||||
@deploy.command(name="logs")
|
||||
@click.option("-u", "--uuid", type=str, help="Crew UUID parameter")
|
||||
def deploy_logs(uuid: Optional[str]):
|
||||
def deploy_logs(uuid: str | None) -> None:
|
||||
"""Get the logs of a deployment."""
|
||||
deploy_cmd = DeployCommand()
|
||||
deploy_cmd.get_crew_logs(uuid=uuid)
|
||||
@@ -278,7 +270,7 @@ def deploy_logs(uuid: Optional[str]):
|
||||
|
||||
@deploy.command(name="remove")
|
||||
@click.option("-u", "--uuid", type=str, help="Crew UUID parameter")
|
||||
def deploy_remove(uuid: Optional[str]):
|
||||
def deploy_remove(uuid: str | None) -> None:
|
||||
"""Remove a deployment."""
|
||||
deploy_cmd = DeployCommand()
|
||||
deploy_cmd.remove_crew(uuid=uuid)
|
||||
@@ -286,14 +278,14 @@ def deploy_remove(uuid: Optional[str]):
|
||||
|
||||
@tool.command(name="create")
|
||||
@click.argument("handle")
|
||||
def tool_create(handle: str):
|
||||
def tool_create(handle: str) -> None:
|
||||
tool_cmd = ToolCommand()
|
||||
tool_cmd.create(handle)
|
||||
|
||||
|
||||
@tool.command(name="install")
|
||||
@click.argument("handle")
|
||||
def tool_install(handle: str):
|
||||
def tool_install(handle: str) -> None:
|
||||
tool_cmd = ToolCommand()
|
||||
tool_cmd.login()
|
||||
tool_cmd.install(handle)
|
||||
@@ -309,27 +301,26 @@ def tool_install(handle: str):
|
||||
)
|
||||
@click.option("--public", "is_public", flag_value=True, default=False)
|
||||
@click.option("--private", "is_public", flag_value=False)
|
||||
def tool_publish(is_public: bool, force: bool):
|
||||
def tool_publish(is_public: bool, force: bool) -> None:
|
||||
tool_cmd = ToolCommand()
|
||||
tool_cmd.login()
|
||||
tool_cmd.publish(is_public, force)
|
||||
|
||||
|
||||
@crewai.group()
|
||||
def flow():
|
||||
def flow() -> None:
|
||||
"""Flow related commands."""
|
||||
pass
|
||||
|
||||
|
||||
@flow.command(name="kickoff")
|
||||
def flow_run():
|
||||
def flow_run() -> None:
|
||||
"""Kickoff the Flow."""
|
||||
click.echo("Running the Flow")
|
||||
kickoff_flow()
|
||||
|
||||
|
||||
@flow.command(name="plot")
|
||||
def flow_plot():
|
||||
def flow_plot() -> None:
|
||||
"""Plot the Flow."""
|
||||
click.echo("Plotting the Flow")
|
||||
plot_flow()
|
||||
@@ -337,20 +328,19 @@ def flow_plot():
|
||||
|
||||
@flow.command(name="add-crew")
|
||||
@click.argument("crew_name")
|
||||
def flow_add_crew(crew_name):
|
||||
def flow_add_crew(crew_name) -> None:
|
||||
"""Add a crew to an existing flow."""
|
||||
click.echo(f"Adding crew {crew_name} to the flow")
|
||||
add_crew_to_flow(crew_name)
|
||||
|
||||
|
||||
@crewai.command()
|
||||
def chat():
|
||||
"""
|
||||
Start a conversation with the Crew, collecting user-supplied inputs,
|
||||
def chat() -> None:
|
||||
"""Start a conversation with the Crew, collecting user-supplied inputs,
|
||||
and using the Chat LLM to generate responses.
|
||||
"""
|
||||
click.secho(
|
||||
"\nStarting a conversation with the Crew\n" "Type 'exit' or Ctrl+C to quit.\n",
|
||||
"\nStarting a conversation with the Crew\nType 'exit' or Ctrl+C to quit.\n",
|
||||
)
|
||||
|
||||
run_chat()
|
||||
|
||||
@@ -10,13 +10,13 @@ console = Console()
|
||||
|
||||
|
||||
class BaseCommand:
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
self._telemetry = Telemetry()
|
||||
self._telemetry.set_tracer()
|
||||
|
||||
|
||||
class PlusAPIMixin:
|
||||
def __init__(self, telemetry):
|
||||
def __init__(self, telemetry) -> None:
|
||||
try:
|
||||
telemetry.set_tracer()
|
||||
self.plus_api_client = PlusAPI(api_key=get_auth_token())
|
||||
@@ -30,11 +30,11 @@ class PlusAPIMixin:
|
||||
raise SystemExit
|
||||
|
||||
def _validate_response(self, response: requests.Response) -> None:
|
||||
"""
|
||||
Handle and display error messages from API responses.
|
||||
"""Handle and display error messages from API responses.
|
||||
|
||||
Args:
|
||||
response (requests.Response): The response from the Plus API
|
||||
|
||||
"""
|
||||
try:
|
||||
json_response = response.json()
|
||||
@@ -55,13 +55,13 @@ class PlusAPIMixin:
|
||||
for field, messages in json_response.items():
|
||||
for message in messages:
|
||||
console.print(
|
||||
f"* [bold red]{field.capitalize()}[/bold red] {message}"
|
||||
f"* [bold red]{field.capitalize()}[/bold red] {message}",
|
||||
)
|
||||
raise SystemExit
|
||||
|
||||
if not response.ok:
|
||||
console.print(
|
||||
"Request to Enterprise API failed. Details:", style="bold red"
|
||||
"Request to Enterprise API failed. Details:", style="bold red",
|
||||
)
|
||||
details = (
|
||||
json_response.get("error")
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
@@ -8,16 +7,16 @@ DEFAULT_CONFIG_PATH = Path.home() / ".config" / "crewai" / "settings.json"
|
||||
|
||||
|
||||
class Settings(BaseModel):
|
||||
tool_repository_username: Optional[str] = Field(
|
||||
None, description="Username for interacting with the Tool Repository"
|
||||
tool_repository_username: str | None = Field(
|
||||
None, description="Username for interacting with the Tool Repository",
|
||||
)
|
||||
tool_repository_password: Optional[str] = Field(
|
||||
None, description="Password for interacting with the Tool Repository"
|
||||
tool_repository_password: str | None = Field(
|
||||
None, description="Password for interacting with the Tool Repository",
|
||||
)
|
||||
config_path: Path = Field(default=DEFAULT_CONFIG_PATH, exclude=True)
|
||||
|
||||
def __init__(self, config_path: Path = DEFAULT_CONFIG_PATH, **data):
|
||||
"""Load Settings from config path"""
|
||||
def __init__(self, config_path: Path = DEFAULT_CONFIG_PATH, **data) -> None:
|
||||
"""Load Settings from config path."""
|
||||
config_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
file_data = {}
|
||||
@@ -32,7 +31,7 @@ class Settings(BaseModel):
|
||||
super().__init__(config_path=config_path, **merged_data)
|
||||
|
||||
def dump(self) -> None:
|
||||
"""Save current settings to settings.json"""
|
||||
"""Save current settings to settings.json."""
|
||||
if self.config_path.is_file():
|
||||
with self.config_path.open("r") as f:
|
||||
existing_data = json.load(f)
|
||||
|
||||
@@ -3,31 +3,31 @@ ENV_VARS = {
|
||||
{
|
||||
"prompt": "Enter your OPENAI API key (press Enter to skip)",
|
||||
"key_name": "OPENAI_API_KEY",
|
||||
}
|
||||
},
|
||||
],
|
||||
"anthropic": [
|
||||
{
|
||||
"prompt": "Enter your ANTHROPIC API key (press Enter to skip)",
|
||||
"key_name": "ANTHROPIC_API_KEY",
|
||||
}
|
||||
},
|
||||
],
|
||||
"gemini": [
|
||||
{
|
||||
"prompt": "Enter your GEMINI API key from https://ai.dev/apikey (press Enter to skip)",
|
||||
"key_name": "GEMINI_API_KEY",
|
||||
}
|
||||
},
|
||||
],
|
||||
"nvidia_nim": [
|
||||
{
|
||||
"prompt": "Enter your NVIDIA API key (press Enter to skip)",
|
||||
"key_name": "NVIDIA_NIM_API_KEY",
|
||||
}
|
||||
},
|
||||
],
|
||||
"groq": [
|
||||
{
|
||||
"prompt": "Enter your GROQ API key (press Enter to skip)",
|
||||
"key_name": "GROQ_API_KEY",
|
||||
}
|
||||
},
|
||||
],
|
||||
"watson": [
|
||||
{
|
||||
@@ -47,7 +47,7 @@ ENV_VARS = {
|
||||
{
|
||||
"default": True,
|
||||
"API_BASE": "http://localhost:11434",
|
||||
}
|
||||
},
|
||||
],
|
||||
"bedrock": [
|
||||
{
|
||||
@@ -101,7 +101,7 @@ ENV_VARS = {
|
||||
{
|
||||
"prompt": "Enter your SambaNovaCloud API key (press Enter to skip)",
|
||||
"key_name": "SAMBANOVA_API_KEY",
|
||||
}
|
||||
},
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
@@ -24,7 +24,7 @@ def create_folder_structure(name, parent_folder=None):
|
||||
|
||||
if folder_path.exists():
|
||||
if not click.confirm(
|
||||
f"Folder {folder_name} already exists. Do you want to override it?"
|
||||
f"Folder {folder_name} already exists. Do you want to override it?",
|
||||
):
|
||||
click.secho("Operation cancelled.", fg="yellow")
|
||||
sys.exit(0)
|
||||
@@ -48,7 +48,7 @@ def create_folder_structure(name, parent_folder=None):
|
||||
return folder_path, folder_name, class_name
|
||||
|
||||
|
||||
def copy_template_files(folder_path, name, class_name, parent_folder):
|
||||
def copy_template_files(folder_path, name, class_name, parent_folder) -> None:
|
||||
package_dir = Path(__file__).parent
|
||||
templates_dir = package_dir / "templates" / "crew"
|
||||
|
||||
@@ -89,7 +89,7 @@ def copy_template_files(folder_path, name, class_name, parent_folder):
|
||||
copy_template(src_file, dst_file, name, class_name, folder_path.name)
|
||||
|
||||
|
||||
def create_crew(name, provider=None, skip_provider=False, parent_folder=None):
|
||||
def create_crew(name, provider=None, skip_provider=False, parent_folder=None) -> None:
|
||||
folder_path, folder_name, class_name = create_folder_structure(name, parent_folder)
|
||||
env_vars = load_env_vars(folder_path)
|
||||
if not skip_provider:
|
||||
@@ -109,7 +109,7 @@ def create_crew(name, provider=None, skip_provider=False, parent_folder=None):
|
||||
|
||||
if existing_provider:
|
||||
if not click.confirm(
|
||||
f"Found existing environment variable configuration for {existing_provider.capitalize()}. Do you want to override it?"
|
||||
f"Found existing environment variable configuration for {existing_provider.capitalize()}. Do you want to override it?",
|
||||
):
|
||||
click.secho("Keeping existing provider configuration.", fg="yellow")
|
||||
return
|
||||
@@ -126,11 +126,11 @@ def create_crew(name, provider=None, skip_provider=False, parent_folder=None):
|
||||
if selected_provider: # Valid selection
|
||||
break
|
||||
click.secho(
|
||||
"No provider selected. Please try again or press 'q' to exit.", fg="red"
|
||||
"No provider selected. Please try again or press 'q' to exit.", fg="red",
|
||||
)
|
||||
|
||||
# Check if the selected provider has predefined models
|
||||
if selected_provider in MODELS and MODELS[selected_provider]:
|
||||
if MODELS.get(selected_provider):
|
||||
while True:
|
||||
selected_model = select_model(selected_provider, provider_models)
|
||||
if selected_model is None: # User typed 'q'
|
||||
@@ -167,7 +167,7 @@ def create_crew(name, provider=None, skip_provider=False, parent_folder=None):
|
||||
click.secho("API keys and model saved to .env file", fg="green")
|
||||
else:
|
||||
click.secho(
|
||||
"No API keys provided. Skipping .env file creation.", fg="yellow"
|
||||
"No API keys provided. Skipping .env file creation.", fg="yellow",
|
||||
)
|
||||
|
||||
click.secho(f"Selected model: {env_vars.get('MODEL', 'N/A')}", fg="green")
|
||||
|
||||
@@ -5,7 +5,7 @@ import click
|
||||
from crewai.telemetry import Telemetry
|
||||
|
||||
|
||||
def create_flow(name):
|
||||
def create_flow(name) -> None:
|
||||
"""Create a new flow."""
|
||||
folder_name = name.replace(" ", "_").replace("-", "_").lower()
|
||||
class_name = name.replace("_", " ").replace("-", " ").title().replace(" ", "")
|
||||
@@ -43,12 +43,12 @@ def create_flow(name):
|
||||
"poem_crew",
|
||||
]
|
||||
|
||||
def process_file(src_file, dst_file):
|
||||
def process_file(src_file, dst_file) -> None:
|
||||
if src_file.suffix in [".pyc", ".pyo", ".pyd"]:
|
||||
return
|
||||
|
||||
try:
|
||||
with open(src_file, "r", encoding="utf-8") as file:
|
||||
with open(src_file, encoding="utf-8") as file:
|
||||
content = file.read()
|
||||
except Exception as e:
|
||||
click.secho(f"Error processing file {src_file}: {e}", fg="red")
|
||||
|
||||
@@ -5,7 +5,7 @@ import sys
|
||||
import threading
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Set, Tuple
|
||||
from typing import Any
|
||||
|
||||
import click
|
||||
import tomli
|
||||
@@ -22,10 +22,9 @@ MIN_REQUIRED_VERSION = "0.98.0"
|
||||
|
||||
|
||||
def check_conversational_crews_version(
|
||||
crewai_version: str, pyproject_data: dict
|
||||
crewai_version: str, pyproject_data: dict,
|
||||
) -> bool:
|
||||
"""
|
||||
Check if the installed crewAI version supports conversational crews.
|
||||
"""Check if the installed crewAI version supports conversational crews.
|
||||
|
||||
Args:
|
||||
crewai_version: The current version of crewAI.
|
||||
@@ -33,6 +32,7 @@ def check_conversational_crews_version(
|
||||
|
||||
Returns:
|
||||
bool: True if version check passes, False otherwise.
|
||||
|
||||
"""
|
||||
try:
|
||||
if version.parse(crewai_version) < version.parse(MIN_REQUIRED_VERSION):
|
||||
@@ -48,9 +48,8 @@ def check_conversational_crews_version(
|
||||
return True
|
||||
|
||||
|
||||
def run_chat():
|
||||
"""
|
||||
Runs an interactive chat loop using the Crew's chat LLM with function calling.
|
||||
def run_chat() -> None:
|
||||
"""Runs an interactive chat loop using the Crew's chat LLM with function calling.
|
||||
Incorporates crew_name, crew_description, and input fields to build a tool schema.
|
||||
Exits if crew_name or crew_description are missing.
|
||||
"""
|
||||
@@ -84,7 +83,7 @@ def run_chat():
|
||||
|
||||
# Call the LLM to generate the introductory message
|
||||
introductory_message = chat_llm.call(
|
||||
messages=[{"role": "system", "content": system_message}]
|
||||
messages=[{"role": "system", "content": system_message}],
|
||||
)
|
||||
finally:
|
||||
# Stop loading indicator
|
||||
@@ -108,15 +107,13 @@ def run_chat():
|
||||
chat_loop(chat_llm, messages, crew_tool_schema, available_functions)
|
||||
|
||||
|
||||
def show_loading(event: threading.Event):
|
||||
def show_loading(event: threading.Event) -> None:
|
||||
"""Display animated loading dots while processing."""
|
||||
while not event.is_set():
|
||||
print(".", end="", flush=True)
|
||||
time.sleep(1)
|
||||
print()
|
||||
|
||||
|
||||
def initialize_chat_llm(crew: Crew) -> Optional[LLM | BaseLLM]:
|
||||
def initialize_chat_llm(crew: Crew) -> LLM | BaseLLM | None:
|
||||
"""Initializes the chat LLM and handles exceptions."""
|
||||
try:
|
||||
return create_llm(crew.chat_llm)
|
||||
@@ -157,7 +154,7 @@ def build_system_message(crew_chat_inputs: ChatInputs) -> str:
|
||||
)
|
||||
|
||||
|
||||
def create_tool_function(crew: Crew, messages: List[Dict[str, str]]) -> Any:
|
||||
def create_tool_function(crew: Crew, messages: list[dict[str, str]]) -> Any:
|
||||
"""Creates a wrapper function for running the crew tool with messages."""
|
||||
|
||||
def run_crew_tool_with_messages(**kwargs):
|
||||
@@ -166,7 +163,7 @@ def create_tool_function(crew: Crew, messages: List[Dict[str, str]]) -> Any:
|
||||
return run_crew_tool_with_messages
|
||||
|
||||
|
||||
def flush_input():
|
||||
def flush_input() -> None:
|
||||
"""Flush any pending input from the user."""
|
||||
if platform.system() == "Windows":
|
||||
# Windows platform
|
||||
@@ -181,7 +178,7 @@ def flush_input():
|
||||
termios.tcflush(sys.stdin, termios.TCIFLUSH)
|
||||
|
||||
|
||||
def chat_loop(chat_llm, messages, crew_tool_schema, available_functions):
|
||||
def chat_loop(chat_llm, messages, crew_tool_schema, available_functions) -> None:
|
||||
"""Main chat loop for interacting with the user."""
|
||||
while True:
|
||||
try:
|
||||
@@ -190,7 +187,7 @@ def chat_loop(chat_llm, messages, crew_tool_schema, available_functions):
|
||||
|
||||
user_input = get_user_input()
|
||||
handle_user_input(
|
||||
user_input, chat_llm, messages, crew_tool_schema, available_functions
|
||||
user_input, chat_llm, messages, crew_tool_schema, available_functions,
|
||||
)
|
||||
|
||||
except KeyboardInterrupt:
|
||||
@@ -221,9 +218,9 @@ def get_user_input() -> str:
|
||||
def handle_user_input(
|
||||
user_input: str,
|
||||
chat_llm: LLM,
|
||||
messages: List[Dict[str, str]],
|
||||
crew_tool_schema: Dict[str, Any],
|
||||
available_functions: Dict[str, Any],
|
||||
messages: list[dict[str, str]],
|
||||
crew_tool_schema: dict[str, Any],
|
||||
available_functions: dict[str, Any],
|
||||
) -> None:
|
||||
if user_input.strip().lower() == "exit":
|
||||
click.echo("Exiting chat. Goodbye!")
|
||||
@@ -251,8 +248,7 @@ def handle_user_input(
|
||||
|
||||
|
||||
def generate_crew_tool_schema(crew_inputs: ChatInputs) -> dict:
|
||||
"""
|
||||
Dynamically build a Littellm 'function' schema for the given crew.
|
||||
"""Dynamically build a Littellm 'function' schema for the given crew.
|
||||
|
||||
crew_name: The name of the crew (used for the function 'name').
|
||||
crew_inputs: A ChatInputs object containing crew_description
|
||||
@@ -281,9 +277,8 @@ def generate_crew_tool_schema(crew_inputs: ChatInputs) -> dict:
|
||||
}
|
||||
|
||||
|
||||
def run_crew_tool(crew: Crew, messages: List[Dict[str, str]], **kwargs):
|
||||
"""
|
||||
Runs the crew using crew.kickoff(inputs=kwargs) and returns the output.
|
||||
def run_crew_tool(crew: Crew, messages: list[dict[str, str]], **kwargs):
|
||||
"""Runs the crew using crew.kickoff(inputs=kwargs) and returns the output.
|
||||
|
||||
Args:
|
||||
crew (Crew): The crew instance to run.
|
||||
@@ -295,6 +290,7 @@ def run_crew_tool(crew: Crew, messages: List[Dict[str, str]], **kwargs):
|
||||
|
||||
Raises:
|
||||
SystemExit: Exits the chat if an error occurs during crew execution.
|
||||
|
||||
"""
|
||||
try:
|
||||
# Serialize 'messages' to JSON string before adding to kwargs
|
||||
@@ -304,9 +300,8 @@ def run_crew_tool(crew: Crew, messages: List[Dict[str, str]], **kwargs):
|
||||
crew_output = crew.kickoff(inputs=kwargs)
|
||||
|
||||
# Convert CrewOutput to a string to send back to the user
|
||||
result = str(crew_output)
|
||||
return str(crew_output)
|
||||
|
||||
return result
|
||||
except Exception as e:
|
||||
# Exit the chat and show the error message
|
||||
click.secho("An error occurred while running the crew:", fg="red")
|
||||
@@ -314,12 +309,12 @@ def run_crew_tool(crew: Crew, messages: List[Dict[str, str]], **kwargs):
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
def load_crew_and_name() -> Tuple[Crew, str]:
|
||||
"""
|
||||
Loads the crew by importing the crew class from the user's project.
|
||||
def load_crew_and_name() -> tuple[Crew, str]:
|
||||
"""Loads the crew by importing the crew class from the user's project.
|
||||
|
||||
Returns:
|
||||
Tuple[Crew, str]: A tuple containing the Crew instance and the name of the crew.
|
||||
|
||||
"""
|
||||
# Get the current working directory
|
||||
cwd = Path.cwd()
|
||||
@@ -327,7 +322,8 @@ def load_crew_and_name() -> Tuple[Crew, str]:
|
||||
# Path to the pyproject.toml file
|
||||
pyproject_path = cwd / "pyproject.toml"
|
||||
if not pyproject_path.exists():
|
||||
raise FileNotFoundError("pyproject.toml not found in the current directory.")
|
||||
msg = "pyproject.toml not found in the current directory."
|
||||
raise FileNotFoundError(msg)
|
||||
|
||||
# Load the pyproject.toml file using 'tomli'
|
||||
with pyproject_path.open("rb") as f:
|
||||
@@ -351,14 +347,16 @@ def load_crew_and_name() -> Tuple[Crew, str]:
|
||||
try:
|
||||
crew_module = __import__(crew_module_name, fromlist=[crew_class_name])
|
||||
except ImportError as e:
|
||||
raise ImportError(f"Failed to import crew module {crew_module_name}: {e}")
|
||||
msg = f"Failed to import crew module {crew_module_name}: {e}"
|
||||
raise ImportError(msg)
|
||||
|
||||
# Get the crew class from the module
|
||||
try:
|
||||
crew_class = getattr(crew_module, crew_class_name)
|
||||
except AttributeError:
|
||||
msg = f"Crew class {crew_class_name} not found in module {crew_module_name}"
|
||||
raise AttributeError(
|
||||
f"Crew class {crew_class_name} not found in module {crew_module_name}"
|
||||
msg,
|
||||
)
|
||||
|
||||
# Instantiate the crew
|
||||
@@ -367,8 +365,7 @@ def load_crew_and_name() -> Tuple[Crew, str]:
|
||||
|
||||
|
||||
def generate_crew_chat_inputs(crew: Crew, crew_name: str, chat_llm) -> ChatInputs:
|
||||
"""
|
||||
Generates the ChatInputs required for the crew by analyzing the tasks and agents.
|
||||
"""Generates the ChatInputs required for the crew by analyzing the tasks and agents.
|
||||
|
||||
Args:
|
||||
crew (Crew): The crew object containing tasks and agents.
|
||||
@@ -377,6 +374,7 @@ def generate_crew_chat_inputs(crew: Crew, crew_name: str, chat_llm) -> ChatInput
|
||||
|
||||
Returns:
|
||||
ChatInputs: An object containing the crew's name, description, and input fields.
|
||||
|
||||
"""
|
||||
# Extract placeholders from tasks and agents
|
||||
required_inputs = fetch_required_inputs(crew)
|
||||
@@ -391,22 +389,22 @@ def generate_crew_chat_inputs(crew: Crew, crew_name: str, chat_llm) -> ChatInput
|
||||
crew_description = generate_crew_description_with_ai(crew, chat_llm)
|
||||
|
||||
return ChatInputs(
|
||||
crew_name=crew_name, crew_description=crew_description, inputs=input_fields
|
||||
crew_name=crew_name, crew_description=crew_description, inputs=input_fields,
|
||||
)
|
||||
|
||||
|
||||
def fetch_required_inputs(crew: Crew) -> Set[str]:
|
||||
"""
|
||||
Extracts placeholders from the crew's tasks and agents.
|
||||
def fetch_required_inputs(crew: Crew) -> set[str]:
|
||||
"""Extracts placeholders from the crew's tasks and agents.
|
||||
|
||||
Args:
|
||||
crew (Crew): The crew object.
|
||||
|
||||
Returns:
|
||||
Set[str]: A set of placeholder names.
|
||||
|
||||
"""
|
||||
placeholder_pattern = re.compile(r"\{(.+?)\}")
|
||||
required_inputs: Set[str] = set()
|
||||
required_inputs: set[str] = set()
|
||||
|
||||
# Scan tasks
|
||||
for task in crew.tasks:
|
||||
@@ -422,8 +420,7 @@ def fetch_required_inputs(crew: Crew) -> Set[str]:
|
||||
|
||||
|
||||
def generate_input_description_with_ai(input_name: str, crew: Crew, chat_llm) -> str:
|
||||
"""
|
||||
Generates an input description using AI based on the context of the crew.
|
||||
"""Generates an input description using AI based on the context of the crew.
|
||||
|
||||
Args:
|
||||
input_name (str): The name of the input placeholder.
|
||||
@@ -432,6 +429,7 @@ def generate_input_description_with_ai(input_name: str, crew: Crew, chat_llm) ->
|
||||
|
||||
Returns:
|
||||
str: A concise description of the input.
|
||||
|
||||
"""
|
||||
# Gather context from tasks and agents where the input is used
|
||||
context_texts = []
|
||||
@@ -444,10 +442,10 @@ def generate_input_description_with_ai(input_name: str, crew: Crew, chat_llm) ->
|
||||
):
|
||||
# Replace placeholders with input names
|
||||
task_description = placeholder_pattern.sub(
|
||||
lambda m: m.group(1), task.description or ""
|
||||
lambda m: m.group(1), task.description or "",
|
||||
)
|
||||
expected_output = placeholder_pattern.sub(
|
||||
lambda m: m.group(1), task.expected_output or ""
|
||||
lambda m: m.group(1), task.expected_output or "",
|
||||
)
|
||||
context_texts.append(f"Task Description: {task_description}")
|
||||
context_texts.append(f"Expected Output: {expected_output}")
|
||||
@@ -461,7 +459,7 @@ def generate_input_description_with_ai(input_name: str, crew: Crew, chat_llm) ->
|
||||
agent_role = placeholder_pattern.sub(lambda m: m.group(1), agent.role or "")
|
||||
agent_goal = placeholder_pattern.sub(lambda m: m.group(1), agent.goal or "")
|
||||
agent_backstory = placeholder_pattern.sub(
|
||||
lambda m: m.group(1), agent.backstory or ""
|
||||
lambda m: m.group(1), agent.backstory or "",
|
||||
)
|
||||
context_texts.append(f"Agent Role: {agent_role}")
|
||||
context_texts.append(f"Agent Goal: {agent_goal}")
|
||||
@@ -470,7 +468,8 @@ def generate_input_description_with_ai(input_name: str, crew: Crew, chat_llm) ->
|
||||
context = "\n".join(context_texts)
|
||||
if not context:
|
||||
# If no context is found for the input, raise an exception as per instruction
|
||||
raise ValueError(f"No context found for input '{input_name}'.")
|
||||
msg = f"No context found for input '{input_name}'."
|
||||
raise ValueError(msg)
|
||||
|
||||
prompt = (
|
||||
f"Based on the following context, write a concise description (15 words or less) of the input '{input_name}'.\n"
|
||||
@@ -479,14 +478,12 @@ def generate_input_description_with_ai(input_name: str, crew: Crew, chat_llm) ->
|
||||
f"{context}"
|
||||
)
|
||||
response = chat_llm.call(messages=[{"role": "user", "content": prompt}])
|
||||
description = response.strip()
|
||||
return response.strip()
|
||||
|
||||
return description
|
||||
|
||||
|
||||
def generate_crew_description_with_ai(crew: Crew, chat_llm) -> str:
|
||||
"""
|
||||
Generates a brief description of the crew using AI.
|
||||
"""Generates a brief description of the crew using AI.
|
||||
|
||||
Args:
|
||||
crew (Crew): The crew object.
|
||||
@@ -494,6 +491,7 @@ def generate_crew_description_with_ai(crew: Crew, chat_llm) -> str:
|
||||
|
||||
Returns:
|
||||
str: A concise description of the crew's purpose (15 words or less).
|
||||
|
||||
"""
|
||||
# Gather context from tasks and agents
|
||||
context_texts = []
|
||||
@@ -502,10 +500,10 @@ def generate_crew_description_with_ai(crew: Crew, chat_llm) -> str:
|
||||
for task in crew.tasks:
|
||||
# Replace placeholders with input names
|
||||
task_description = placeholder_pattern.sub(
|
||||
lambda m: m.group(1), task.description or ""
|
||||
lambda m: m.group(1), task.description or "",
|
||||
)
|
||||
expected_output = placeholder_pattern.sub(
|
||||
lambda m: m.group(1), task.expected_output or ""
|
||||
lambda m: m.group(1), task.expected_output or "",
|
||||
)
|
||||
context_texts.append(f"Task Description: {task_description}")
|
||||
context_texts.append(f"Expected Output: {expected_output}")
|
||||
@@ -514,7 +512,7 @@ def generate_crew_description_with_ai(crew: Crew, chat_llm) -> str:
|
||||
agent_role = placeholder_pattern.sub(lambda m: m.group(1), agent.role or "")
|
||||
agent_goal = placeholder_pattern.sub(lambda m: m.group(1), agent.goal or "")
|
||||
agent_backstory = placeholder_pattern.sub(
|
||||
lambda m: m.group(1), agent.backstory or ""
|
||||
lambda m: m.group(1), agent.backstory or "",
|
||||
)
|
||||
context_texts.append(f"Agent Role: {agent_role}")
|
||||
context_texts.append(f"Agent Goal: {agent_goal}")
|
||||
@@ -522,7 +520,8 @@ def generate_crew_description_with_ai(crew: Crew, chat_llm) -> str:
|
||||
|
||||
context = "\n".join(context_texts)
|
||||
if not context:
|
||||
raise ValueError("No context found for generating crew description.")
|
||||
msg = "No context found for generating crew description."
|
||||
raise ValueError(msg)
|
||||
|
||||
prompt = (
|
||||
"Based on the following context, write a concise, action-oriented description (15 words or less) of the crew's purpose.\n"
|
||||
@@ -531,6 +530,5 @@ def generate_crew_description_with_ai(crew: Crew, chat_llm) -> str:
|
||||
f"{context}"
|
||||
)
|
||||
response = chat_llm.call(messages=[{"role": "user", "content": prompt}])
|
||||
crew_description = response.strip()
|
||||
return response.strip()
|
||||
|
||||
return crew_description
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any
|
||||
|
||||
from rich.console import Console
|
||||
|
||||
@@ -10,34 +10,27 @@ console = Console()
|
||||
|
||||
|
||||
class DeployCommand(BaseCommand, PlusAPIMixin):
|
||||
"""
|
||||
A class to handle deployment-related operations for CrewAI projects.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
"""
|
||||
Initialize the DeployCommand with project name and API client.
|
||||
"""
|
||||
"""A class to handle deployment-related operations for CrewAI projects."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Initialize the DeployCommand with project name and API client."""
|
||||
BaseCommand.__init__(self)
|
||||
PlusAPIMixin.__init__(self, telemetry=self._telemetry)
|
||||
self.project_name = get_project_name(require=True)
|
||||
|
||||
def _standard_no_param_error_message(self) -> None:
|
||||
"""
|
||||
Display a standard error message when no UUID or project name is available.
|
||||
"""
|
||||
"""Display a standard error message when no UUID or project name is available."""
|
||||
console.print(
|
||||
"No UUID provided, project pyproject.toml not found or with error.",
|
||||
style="bold red",
|
||||
)
|
||||
|
||||
def _display_deployment_info(self, json_response: Dict[str, Any]) -> None:
|
||||
"""
|
||||
Display deployment information.
|
||||
def _display_deployment_info(self, json_response: dict[str, Any]) -> None:
|
||||
"""Display deployment information.
|
||||
|
||||
Args:
|
||||
json_response (Dict[str, Any]): The deployment information to display.
|
||||
|
||||
"""
|
||||
console.print("Deploying the crew...\n", style="bold blue")
|
||||
for key, value in json_response.items():
|
||||
@@ -47,24 +40,24 @@ class DeployCommand(BaseCommand, PlusAPIMixin):
|
||||
console.print(" or")
|
||||
console.print(f"crewai deploy status --uuid \"{json_response['uuid']}\"")
|
||||
|
||||
def _display_logs(self, log_messages: List[Dict[str, Any]]) -> None:
|
||||
"""
|
||||
Display log messages.
|
||||
def _display_logs(self, log_messages: list[dict[str, Any]]) -> None:
|
||||
"""Display log messages.
|
||||
|
||||
Args:
|
||||
log_messages (List[Dict[str, Any]]): The log messages to display.
|
||||
|
||||
"""
|
||||
for log_message in log_messages:
|
||||
console.print(
|
||||
f"{log_message['timestamp']} - {log_message['level']}: {log_message['message']}"
|
||||
f"{log_message['timestamp']} - {log_message['level']}: {log_message['message']}",
|
||||
)
|
||||
|
||||
def deploy(self, uuid: Optional[str] = None) -> None:
|
||||
"""
|
||||
Deploy a crew using either UUID or project name.
|
||||
def deploy(self, uuid: str | None = None) -> None:
|
||||
"""Deploy a crew using either UUID or project name.
|
||||
|
||||
Args:
|
||||
uuid (Optional[str]): The UUID of the crew to deploy.
|
||||
|
||||
"""
|
||||
self._start_deployment_span = self._telemetry.start_deployment_span(uuid)
|
||||
console.print("Starting deployment...", style="bold blue")
|
||||
@@ -80,9 +73,7 @@ class DeployCommand(BaseCommand, PlusAPIMixin):
|
||||
self._display_deployment_info(response.json())
|
||||
|
||||
def create_crew(self, confirm: bool = False) -> None:
|
||||
"""
|
||||
Create a new crew deployment.
|
||||
"""
|
||||
"""Create a new crew deployment."""
|
||||
self._create_crew_deployment_span = (
|
||||
self._telemetry.create_crew_deployment_span()
|
||||
)
|
||||
@@ -110,29 +101,28 @@ class DeployCommand(BaseCommand, PlusAPIMixin):
|
||||
self._display_creation_success(response.json())
|
||||
|
||||
def _confirm_input(
|
||||
self, env_vars: Dict[str, str], remote_repo_url: str, confirm: bool
|
||||
self, env_vars: dict[str, str], remote_repo_url: str, confirm: bool,
|
||||
) -> None:
|
||||
"""
|
||||
Confirm input parameters with the user.
|
||||
"""Confirm input parameters with the user.
|
||||
|
||||
Args:
|
||||
env_vars (Dict[str, str]): Environment variables.
|
||||
remote_repo_url (str): Remote repository URL.
|
||||
confirm (bool): Whether to confirm input.
|
||||
|
||||
"""
|
||||
if not confirm:
|
||||
input(f"Press Enter to continue with the following Env vars: {env_vars}")
|
||||
input(
|
||||
f"Press Enter to continue with the following remote repository: {remote_repo_url}\n"
|
||||
f"Press Enter to continue with the following remote repository: {remote_repo_url}\n",
|
||||
)
|
||||
|
||||
def _create_payload(
|
||||
self,
|
||||
env_vars: Dict[str, str],
|
||||
env_vars: dict[str, str],
|
||||
remote_repo_url: str,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Create the payload for crew creation.
|
||||
) -> dict[str, Any]:
|
||||
"""Create the payload for crew creation.
|
||||
|
||||
Args:
|
||||
remote_repo_url (str): Remote repository URL.
|
||||
@@ -140,25 +130,26 @@ class DeployCommand(BaseCommand, PlusAPIMixin):
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: The payload for crew creation.
|
||||
|
||||
"""
|
||||
return {
|
||||
"deploy": {
|
||||
"name": self.project_name,
|
||||
"repo_clone_url": remote_repo_url,
|
||||
"env": env_vars,
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
def _display_creation_success(self, json_response: Dict[str, Any]) -> None:
|
||||
"""
|
||||
Display success message after crew creation.
|
||||
def _display_creation_success(self, json_response: dict[str, Any]) -> None:
|
||||
"""Display success message after crew creation.
|
||||
|
||||
Args:
|
||||
json_response (Dict[str, Any]): The response containing crew information.
|
||||
|
||||
"""
|
||||
console.print("Deployment created successfully!\n", style="bold green")
|
||||
console.print(
|
||||
f"Name: {self.project_name} ({json_response['uuid']})", style="bold green"
|
||||
f"Name: {self.project_name} ({json_response['uuid']})", style="bold green",
|
||||
)
|
||||
console.print(f"Status: {json_response['status']}", style="bold green")
|
||||
console.print("\nTo (re)deploy the crew, run:")
|
||||
@@ -167,9 +158,7 @@ class DeployCommand(BaseCommand, PlusAPIMixin):
|
||||
console.print(f"crewai deploy push --uuid {json_response['uuid']}")
|
||||
|
||||
def list_crews(self) -> None:
|
||||
"""
|
||||
List all available crews.
|
||||
"""
|
||||
"""List all available crews."""
|
||||
console.print("Listing all Crews\n", style="bold blue")
|
||||
|
||||
response = self.plus_api_client.list_crews()
|
||||
@@ -179,31 +168,29 @@ class DeployCommand(BaseCommand, PlusAPIMixin):
|
||||
else:
|
||||
self._display_no_crews_message()
|
||||
|
||||
def _display_crews(self, crews_data: List[Dict[str, Any]]) -> None:
|
||||
"""
|
||||
Display the list of crews.
|
||||
def _display_crews(self, crews_data: list[dict[str, Any]]) -> None:
|
||||
"""Display the list of crews.
|
||||
|
||||
Args:
|
||||
crews_data (List[Dict[str, Any]]): List of crew data to display.
|
||||
|
||||
"""
|
||||
for crew_data in crews_data:
|
||||
console.print(
|
||||
f"- {crew_data['name']} ({crew_data['uuid']}) [blue]{crew_data['status']}[/blue]"
|
||||
f"- {crew_data['name']} ({crew_data['uuid']}) [blue]{crew_data['status']}[/blue]",
|
||||
)
|
||||
|
||||
def _display_no_crews_message(self) -> None:
|
||||
"""
|
||||
Display a message when no crews are available.
|
||||
"""
|
||||
"""Display a message when no crews are available."""
|
||||
console.print("You don't have any Crews yet. Let's create one!", style="yellow")
|
||||
console.print(" crewai create crew <crew_name>", style="green")
|
||||
|
||||
def get_crew_status(self, uuid: Optional[str] = None) -> None:
|
||||
"""
|
||||
Get the status of a crew.
|
||||
def get_crew_status(self, uuid: str | None = None) -> None:
|
||||
"""Get the status of a crew.
|
||||
|
||||
Args:
|
||||
uuid (Optional[str]): The UUID of the crew to check.
|
||||
|
||||
"""
|
||||
console.print("Fetching deployment status...", style="bold blue")
|
||||
if uuid:
|
||||
@@ -217,23 +204,23 @@ class DeployCommand(BaseCommand, PlusAPIMixin):
|
||||
self._validate_response(response)
|
||||
self._display_crew_status(response.json())
|
||||
|
||||
def _display_crew_status(self, status_data: Dict[str, str]) -> None:
|
||||
"""
|
||||
Display the status of a crew.
|
||||
def _display_crew_status(self, status_data: dict[str, str]) -> None:
|
||||
"""Display the status of a crew.
|
||||
|
||||
Args:
|
||||
status_data (Dict[str, str]): The status data to display.
|
||||
|
||||
"""
|
||||
console.print(f"Name:\t {status_data['name']}")
|
||||
console.print(f"Status:\t {status_data['status']}")
|
||||
|
||||
def get_crew_logs(self, uuid: Optional[str], log_type: str = "deployment") -> None:
|
||||
"""
|
||||
Get logs for a crew.
|
||||
def get_crew_logs(self, uuid: str | None, log_type: str = "deployment") -> None:
|
||||
"""Get logs for a crew.
|
||||
|
||||
Args:
|
||||
uuid (Optional[str]): The UUID of the crew to get logs for.
|
||||
log_type (str): The type of logs to retrieve (default: "deployment").
|
||||
|
||||
"""
|
||||
self._get_crew_logs_span = self._telemetry.get_crew_logs_span(uuid, log_type)
|
||||
console.print(f"Fetching {log_type} logs...", style="bold blue")
|
||||
@@ -249,12 +236,12 @@ class DeployCommand(BaseCommand, PlusAPIMixin):
|
||||
self._validate_response(response)
|
||||
self._display_logs(response.json())
|
||||
|
||||
def remove_crew(self, uuid: Optional[str]) -> None:
|
||||
"""
|
||||
Remove a crew deployment.
|
||||
def remove_crew(self, uuid: str | None) -> None:
|
||||
"""Remove a crew deployment.
|
||||
|
||||
Args:
|
||||
uuid (Optional[str]): The UUID of the crew to remove.
|
||||
|
||||
"""
|
||||
self._remove_crew_span = self._telemetry.remove_crew_span(uuid)
|
||||
console.print("Removing deployment...", style="bold blue")
|
||||
@@ -269,9 +256,9 @@ class DeployCommand(BaseCommand, PlusAPIMixin):
|
||||
|
||||
if response.status_code == 204:
|
||||
console.print(
|
||||
f"Crew '{self.project_name}' removed successfully.", style="green"
|
||||
f"Crew '{self.project_name}' removed successfully.", style="green",
|
||||
)
|
||||
else:
|
||||
console.print(
|
||||
f"Failed to remove crew '{self.project_name}'", style="bold red"
|
||||
f"Failed to remove crew '{self.project_name}'", style="bold red",
|
||||
)
|
||||
|
||||
@@ -4,18 +4,19 @@ import click
|
||||
|
||||
|
||||
def evaluate_crew(n_iterations: int, model: str) -> None:
|
||||
"""
|
||||
Test and Evaluate the crew by running a command in the UV environment.
|
||||
"""Test and Evaluate the crew by running a command in the UV environment.
|
||||
|
||||
Args:
|
||||
n_iterations (int): The number of iterations to test the crew.
|
||||
model (str): The model to test the crew with.
|
||||
|
||||
"""
|
||||
command = ["uv", "run", "test", str(n_iterations), model]
|
||||
|
||||
try:
|
||||
if n_iterations <= 0:
|
||||
raise ValueError("The number of iterations must be a positive integer.")
|
||||
msg = "The number of iterations must be a positive integer."
|
||||
raise ValueError(msg)
|
||||
|
||||
result = subprocess.run(command, capture_output=False, text=True, check=True)
|
||||
|
||||
|
||||
@@ -1,16 +1,18 @@
|
||||
import subprocess
|
||||
from functools import lru_cache
|
||||
from functools import cache
|
||||
|
||||
|
||||
class Repository:
|
||||
def __init__(self, path="."):
|
||||
def __init__(self, path=".") -> None:
|
||||
self.path = path
|
||||
|
||||
if not self.is_git_installed():
|
||||
raise ValueError("Git is not installed or not found in your PATH.")
|
||||
msg = "Git is not installed or not found in your PATH."
|
||||
raise ValueError(msg)
|
||||
|
||||
if not self.is_git_repo():
|
||||
raise ValueError(f"{self.path} is not a Git repository.")
|
||||
msg = f"{self.path} is not a Git repository."
|
||||
raise ValueError(msg)
|
||||
|
||||
self.fetch()
|
||||
|
||||
@@ -18,7 +20,7 @@ class Repository:
|
||||
"""Check if Git is installed and available in the system."""
|
||||
try:
|
||||
subprocess.run(
|
||||
["git", "--version"], capture_output=True, check=True, text=True
|
||||
["git", "--version"], capture_output=True, check=True, text=True,
|
||||
)
|
||||
return True
|
||||
except (subprocess.CalledProcessError, FileNotFoundError):
|
||||
@@ -36,7 +38,7 @@ class Repository:
|
||||
encoding="utf-8",
|
||||
).strip()
|
||||
|
||||
@lru_cache(maxsize=None)
|
||||
@cache
|
||||
def is_git_repo(self) -> bool:
|
||||
"""Check if the current directory is a git repository."""
|
||||
try:
|
||||
@@ -62,10 +64,7 @@ class Repository:
|
||||
|
||||
def is_synced(self) -> bool:
|
||||
"""Return True if the Git repository is fully synced with the remote, False otherwise."""
|
||||
if self.has_uncommitted_changes() or self.is_ahead_or_behind():
|
||||
return False
|
||||
else:
|
||||
return True
|
||||
return not (self.has_uncommitted_changes() or self.is_ahead_or_behind())
|
||||
|
||||
def origin_url(self) -> str | None:
|
||||
"""Get the Git repository's remote URL."""
|
||||
|
||||
@@ -8,11 +8,9 @@ import click
|
||||
# so if you expect this to support more things you will need to replicate it there
|
||||
# ask @joaomdmoura if you are unsure
|
||||
def install_crew(proxy_options: list[str]) -> None:
|
||||
"""
|
||||
Install the crew by running the UV command to lock and install.
|
||||
"""
|
||||
"""Install the crew by running the UV command to lock and install."""
|
||||
try:
|
||||
command = ["uv", "sync"] + proxy_options
|
||||
command = ["uv", "sync", *proxy_options]
|
||||
subprocess.run(command, check=True, capture_output=False, text=True)
|
||||
|
||||
except subprocess.CalledProcessError as e:
|
||||
|
||||
@@ -4,9 +4,7 @@ import click
|
||||
|
||||
|
||||
def kickoff_flow() -> None:
|
||||
"""
|
||||
Kickoff the flow by running a command in the UV environment.
|
||||
"""
|
||||
"""Kickoff the flow by running a command in the UV environment."""
|
||||
command = ["uv", "run", "kickoff"]
|
||||
|
||||
try:
|
||||
|
||||
@@ -4,9 +4,7 @@ import click
|
||||
|
||||
|
||||
def plot_flow() -> None:
|
||||
"""
|
||||
Plot the flow by running a command in the UV environment.
|
||||
"""
|
||||
"""Plot the flow by running a command in the UV environment."""
|
||||
command = ["uv", "run", "plot"]
|
||||
|
||||
try:
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
from os import getenv
|
||||
from typing import Optional
|
||||
from urllib.parse import urljoin
|
||||
|
||||
import requests
|
||||
@@ -8,13 +7,10 @@ from crewai.cli.version import get_crewai_version
|
||||
|
||||
|
||||
class PlusAPI:
|
||||
"""
|
||||
This class exposes methods for working with the CrewAI+ API.
|
||||
"""
|
||||
"""This class exposes methods for working with the CrewAI+ API."""
|
||||
|
||||
TOOLS_RESOURCE = "/crewai_plus/api/v1/tools"
|
||||
CREWS_RESOURCE = "/crewai_plus/api/v1/crews"
|
||||
AGENTS_RESOURCE = "/crewai_plus/api/v1/agents"
|
||||
|
||||
def __init__(self, api_key: str) -> None:
|
||||
self.api_key = api_key
|
||||
@@ -38,15 +34,12 @@ class PlusAPI:
|
||||
def get_tool(self, handle: str):
|
||||
return self._make_request("GET", f"{self.TOOLS_RESOURCE}/{handle}")
|
||||
|
||||
def get_agent(self, handle: str):
|
||||
return self._make_request("GET", f"{self.AGENTS_RESOURCE}/{handle}")
|
||||
|
||||
def publish_tool(
|
||||
self,
|
||||
handle: str,
|
||||
is_public: bool,
|
||||
version: str,
|
||||
description: Optional[str],
|
||||
description: str | None,
|
||||
encoded_file: str,
|
||||
):
|
||||
params = {
|
||||
@@ -60,7 +53,7 @@ class PlusAPI:
|
||||
|
||||
def deploy_by_name(self, project_name: str) -> requests.Response:
|
||||
return self._make_request(
|
||||
"POST", f"{self.CREWS_RESOURCE}/by-name/{project_name}/deploy"
|
||||
"POST", f"{self.CREWS_RESOURCE}/by-name/{project_name}/deploy",
|
||||
)
|
||||
|
||||
def deploy_by_uuid(self, uuid: str) -> requests.Response:
|
||||
@@ -68,29 +61,29 @@ class PlusAPI:
|
||||
|
||||
def crew_status_by_name(self, project_name: str) -> requests.Response:
|
||||
return self._make_request(
|
||||
"GET", f"{self.CREWS_RESOURCE}/by-name/{project_name}/status"
|
||||
"GET", f"{self.CREWS_RESOURCE}/by-name/{project_name}/status",
|
||||
)
|
||||
|
||||
def crew_status_by_uuid(self, uuid: str) -> requests.Response:
|
||||
return self._make_request("GET", f"{self.CREWS_RESOURCE}/{uuid}/status")
|
||||
|
||||
def crew_by_name(
|
||||
self, project_name: str, log_type: str = "deployment"
|
||||
self, project_name: str, log_type: str = "deployment",
|
||||
) -> requests.Response:
|
||||
return self._make_request(
|
||||
"GET", f"{self.CREWS_RESOURCE}/by-name/{project_name}/logs/{log_type}"
|
||||
"GET", f"{self.CREWS_RESOURCE}/by-name/{project_name}/logs/{log_type}",
|
||||
)
|
||||
|
||||
def crew_by_uuid(
|
||||
self, uuid: str, log_type: str = "deployment"
|
||||
self, uuid: str, log_type: str = "deployment",
|
||||
) -> requests.Response:
|
||||
return self._make_request(
|
||||
"GET", f"{self.CREWS_RESOURCE}/{uuid}/logs/{log_type}"
|
||||
"GET", f"{self.CREWS_RESOURCE}/{uuid}/logs/{log_type}",
|
||||
)
|
||||
|
||||
def delete_crew_by_name(self, project_name: str) -> requests.Response:
|
||||
return self._make_request(
|
||||
"DELETE", f"{self.CREWS_RESOURCE}/by-name/{project_name}"
|
||||
"DELETE", f"{self.CREWS_RESOURCE}/by-name/{project_name}",
|
||||
)
|
||||
|
||||
def delete_crew_by_uuid(self, uuid: str) -> requests.Response:
|
||||
|
||||
@@ -10,8 +10,7 @@ from crewai.cli.constants import JSON_URL, MODELS, PROVIDERS
|
||||
|
||||
|
||||
def select_choice(prompt_message, choices):
|
||||
"""
|
||||
Presents a list of choices to the user and prompts them to select one.
|
||||
"""Presents a list of choices to the user and prompts them to select one.
|
||||
|
||||
Args:
|
||||
- prompt_message (str): The message to display to the user before presenting the choices.
|
||||
@@ -19,11 +18,11 @@ def select_choice(prompt_message, choices):
|
||||
|
||||
Returns:
|
||||
- str: The selected choice from the list, or None if the user chooses to quit.
|
||||
"""
|
||||
|
||||
"""
|
||||
provider_models = get_provider_data()
|
||||
if not provider_models:
|
||||
return
|
||||
return None
|
||||
click.secho(prompt_message, fg="cyan")
|
||||
for idx, choice in enumerate(choices, start=1):
|
||||
click.secho(f"{idx}. {choice}", fg="cyan")
|
||||
@@ -31,7 +30,7 @@ def select_choice(prompt_message, choices):
|
||||
|
||||
while True:
|
||||
choice = click.prompt(
|
||||
"Enter the number of your choice or 'q' to quit", type=str
|
||||
"Enter the number of your choice or 'q' to quit", type=str,
|
||||
)
|
||||
|
||||
if choice.lower() == "q":
|
||||
@@ -51,8 +50,7 @@ def select_choice(prompt_message, choices):
|
||||
|
||||
|
||||
def select_provider(provider_models):
|
||||
"""
|
||||
Presents a list of providers to the user and prompts them to select one.
|
||||
"""Presents a list of providers to the user and prompts them to select one.
|
||||
|
||||
Args:
|
||||
- provider_models (dict): A dictionary of provider models.
|
||||
@@ -60,12 +58,13 @@ def select_provider(provider_models):
|
||||
Returns:
|
||||
- str: The selected provider
|
||||
- None: If user explicitly quits
|
||||
|
||||
"""
|
||||
predefined_providers = [p.lower() for p in PROVIDERS]
|
||||
all_providers = sorted(set(predefined_providers + list(provider_models.keys())))
|
||||
|
||||
provider = select_choice(
|
||||
"Select a provider to set up:", predefined_providers + ["other"]
|
||||
"Select a provider to set up:", [*predefined_providers, "other"],
|
||||
)
|
||||
if provider is None: # User typed 'q'
|
||||
return None
|
||||
@@ -79,8 +78,7 @@ def select_provider(provider_models):
|
||||
|
||||
|
||||
def select_model(provider, provider_models):
|
||||
"""
|
||||
Presents a list of models for a given provider to the user and prompts them to select one.
|
||||
"""Presents a list of models for a given provider to the user and prompts them to select one.
|
||||
|
||||
Args:
|
||||
- provider (str): The provider for which to select a model.
|
||||
@@ -88,6 +86,7 @@ def select_model(provider, provider_models):
|
||||
|
||||
Returns:
|
||||
- str: The selected model, or None if the operation is aborted or an invalid selection is made.
|
||||
|
||||
"""
|
||||
predefined_providers = [p.lower() for p in PROVIDERS]
|
||||
|
||||
@@ -100,15 +99,13 @@ def select_model(provider, provider_models):
|
||||
click.secho(f"No models available for provider '{provider}'.", fg="red")
|
||||
return None
|
||||
|
||||
selected_model = select_choice(
|
||||
f"Select a model to use for {provider.capitalize()}:", available_models
|
||||
return select_choice(
|
||||
f"Select a model to use for {provider.capitalize()}:", available_models,
|
||||
)
|
||||
return selected_model
|
||||
|
||||
|
||||
def load_provider_data(cache_file, cache_expiry):
|
||||
"""
|
||||
Loads provider data from a cache file if it exists and is not expired. If the cache is expired or corrupted, it fetches the data from the web.
|
||||
"""Loads provider data from a cache file if it exists and is not expired. If the cache is expired or corrupted, it fetches the data from the web.
|
||||
|
||||
Args:
|
||||
- cache_file (Path): The path to the cache file.
|
||||
@@ -116,6 +113,7 @@ def load_provider_data(cache_file, cache_expiry):
|
||||
|
||||
Returns:
|
||||
- dict or None: The loaded provider data or None if the operation fails.
|
||||
|
||||
"""
|
||||
current_time = time.time()
|
||||
if (
|
||||
@@ -126,7 +124,7 @@ def load_provider_data(cache_file, cache_expiry):
|
||||
if data:
|
||||
return data
|
||||
click.secho(
|
||||
"Cache is corrupted. Fetching provider data from the web...", fg="yellow"
|
||||
"Cache is corrupted. Fetching provider data from the web...", fg="yellow",
|
||||
)
|
||||
else:
|
||||
click.secho(
|
||||
@@ -137,31 +135,31 @@ def load_provider_data(cache_file, cache_expiry):
|
||||
|
||||
|
||||
def read_cache_file(cache_file):
|
||||
"""
|
||||
Reads and returns the JSON content from a cache file. Returns None if the file contains invalid JSON.
|
||||
"""Reads and returns the JSON content from a cache file. Returns None if the file contains invalid JSON.
|
||||
|
||||
Args:
|
||||
- cache_file (Path): The path to the cache file.
|
||||
|
||||
Returns:
|
||||
- dict or None: The JSON content of the cache file or None if the JSON is invalid.
|
||||
|
||||
"""
|
||||
try:
|
||||
with open(cache_file, "r") as f:
|
||||
with open(cache_file) as f:
|
||||
return json.load(f)
|
||||
except json.JSONDecodeError:
|
||||
return None
|
||||
|
||||
|
||||
def fetch_provider_data(cache_file):
|
||||
"""
|
||||
Fetches provider data from a specified URL and caches it to a file.
|
||||
"""Fetches provider data from a specified URL and caches it to a file.
|
||||
|
||||
Args:
|
||||
- cache_file (Path): The path to the cache file.
|
||||
|
||||
Returns:
|
||||
- dict or None: The fetched provider data or None if the operation fails.
|
||||
|
||||
"""
|
||||
try:
|
||||
response = requests.get(JSON_URL, stream=True, timeout=60)
|
||||
@@ -178,20 +176,20 @@ def fetch_provider_data(cache_file):
|
||||
|
||||
|
||||
def download_data(response):
|
||||
"""
|
||||
Downloads data from a given HTTP response and returns the JSON content.
|
||||
"""Downloads data from a given HTTP response and returns the JSON content.
|
||||
|
||||
Args:
|
||||
- response (requests.Response): The HTTP response object.
|
||||
|
||||
Returns:
|
||||
- dict: The JSON content of the response.
|
||||
|
||||
"""
|
||||
total_size = int(response.headers.get("content-length", 0))
|
||||
block_size = 8192
|
||||
data_chunks = []
|
||||
with click.progressbar(
|
||||
length=total_size, label="Downloading", show_pos=True
|
||||
length=total_size, label="Downloading", show_pos=True,
|
||||
) as progress_bar:
|
||||
for chunk in response.iter_content(block_size):
|
||||
if chunk:
|
||||
@@ -202,11 +200,11 @@ def download_data(response):
|
||||
|
||||
|
||||
def get_provider_data():
|
||||
"""
|
||||
Retrieves provider data from a cache file, filters out models based on provider criteria, and returns a dictionary of providers mapped to their models.
|
||||
"""Retrieves provider data from a cache file, filters out models based on provider criteria, and returns a dictionary of providers mapped to their models.
|
||||
|
||||
Returns:
|
||||
- dict or None: A dictionary of providers mapped to their models or None if the operation fails.
|
||||
|
||||
"""
|
||||
cache_dir = Path.home() / ".crewai"
|
||||
cache_dir.mkdir(exist_ok=True)
|
||||
|
||||
@@ -4,11 +4,11 @@ import click
|
||||
|
||||
|
||||
def replay_task_command(task_id: str) -> None:
|
||||
"""
|
||||
Replay the crew execution from a specific task.
|
||||
"""Replay the crew execution from a specific task.
|
||||
|
||||
Args:
|
||||
task_id (str): The ID of the task to replay from.
|
||||
|
||||
"""
|
||||
command = ["uv", "run", "replay", task_id]
|
||||
|
||||
|
||||
@@ -13,8 +13,7 @@ def reset_memories_command(
|
||||
kickoff_outputs,
|
||||
all,
|
||||
) -> None:
|
||||
"""
|
||||
Reset the crew memories.
|
||||
"""Reset the crew memories.
|
||||
|
||||
Args:
|
||||
long (bool): Whether to reset the long-term memory.
|
||||
@@ -23,49 +22,50 @@ def reset_memories_command(
|
||||
kickoff_outputs (bool): Whether to reset the latest kickoff task outputs.
|
||||
all (bool): Whether to reset all memories.
|
||||
knowledge (bool): Whether to reset the knowledge.
|
||||
"""
|
||||
|
||||
"""
|
||||
try:
|
||||
if not any([long, short, entity, kickoff_outputs, knowledge, all]):
|
||||
click.echo(
|
||||
"No memory type specified. Please specify at least one type to reset."
|
||||
"No memory type specified. Please specify at least one type to reset.",
|
||||
)
|
||||
return
|
||||
|
||||
crews = get_crews()
|
||||
if not crews:
|
||||
raise ValueError("No crew found.")
|
||||
msg = "No crew found."
|
||||
raise ValueError(msg)
|
||||
for crew in crews:
|
||||
if all:
|
||||
crew.reset_memories(command_type="all")
|
||||
click.echo(
|
||||
f"[Crew ({crew.name if crew.name else crew.id})] Reset memories command has been completed."
|
||||
f"[Crew ({crew.name if crew.name else crew.id})] Reset memories command has been completed.",
|
||||
)
|
||||
continue
|
||||
if long:
|
||||
crew.reset_memories(command_type="long")
|
||||
click.echo(
|
||||
f"[Crew ({crew.name if crew.name else crew.id})] Long term memory has been reset."
|
||||
f"[Crew ({crew.name if crew.name else crew.id})] Long term memory has been reset.",
|
||||
)
|
||||
if short:
|
||||
crew.reset_memories(command_type="short")
|
||||
click.echo(
|
||||
f"[Crew ({crew.name if crew.name else crew.id})] Short term memory has been reset."
|
||||
f"[Crew ({crew.name if crew.name else crew.id})] Short term memory has been reset.",
|
||||
)
|
||||
if entity:
|
||||
crew.reset_memories(command_type="entity")
|
||||
click.echo(
|
||||
f"[Crew ({crew.name if crew.name else crew.id})] Entity memory has been reset."
|
||||
f"[Crew ({crew.name if crew.name else crew.id})] Entity memory has been reset.",
|
||||
)
|
||||
if kickoff_outputs:
|
||||
crew.reset_memories(command_type="kickoff_outputs")
|
||||
click.echo(
|
||||
f"[Crew ({crew.name if crew.name else crew.id})] Latest Kickoff outputs stored has been reset."
|
||||
f"[Crew ({crew.name if crew.name else crew.id})] Latest Kickoff outputs stored has been reset.",
|
||||
)
|
||||
if knowledge:
|
||||
crew.reset_memories(command_type="knowledge")
|
||||
click.echo(
|
||||
f"[Crew ({crew.name if crew.name else crew.id})] Knowledge has been reset."
|
||||
f"[Crew ({crew.name if crew.name else crew.id})] Knowledge has been reset.",
|
||||
)
|
||||
|
||||
except subprocess.CalledProcessError as e:
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
import subprocess
|
||||
from enum import Enum
|
||||
from typing import List, Optional
|
||||
|
||||
import click
|
||||
from packaging import version
|
||||
@@ -15,8 +14,7 @@ class CrewType(Enum):
|
||||
|
||||
|
||||
def run_crew() -> None:
|
||||
"""
|
||||
Run the crew or flow by running a command in the UV environment.
|
||||
"""Run the crew or flow by running a command in the UV environment.
|
||||
|
||||
Starting from version 0.103.0, this command can be used to run both
|
||||
standard crews and flows. For flows, it detects the type from pyproject.toml
|
||||
@@ -48,11 +46,11 @@ def run_crew() -> None:
|
||||
|
||||
|
||||
def execute_command(crew_type: CrewType) -> None:
|
||||
"""
|
||||
Execute the appropriate command based on crew type.
|
||||
"""Execute the appropriate command based on crew type.
|
||||
|
||||
Args:
|
||||
crew_type: The type of crew to run
|
||||
|
||||
"""
|
||||
command = ["uv", "run", "kickoff" if crew_type == CrewType.FLOW else "run_crew"]
|
||||
|
||||
@@ -67,12 +65,12 @@ def execute_command(crew_type: CrewType) -> None:
|
||||
|
||||
|
||||
def handle_error(error: subprocess.CalledProcessError, crew_type: CrewType) -> None:
|
||||
"""
|
||||
Handle subprocess errors with appropriate messaging.
|
||||
"""Handle subprocess errors with appropriate messaging.
|
||||
|
||||
Args:
|
||||
error: The subprocess error that occurred
|
||||
crew_type: The type of crew that was being run
|
||||
|
||||
"""
|
||||
entity_type = "flow" if crew_type == CrewType.FLOW else "crew"
|
||||
click.echo(f"An error occurred while running the {entity_type}: {error}", err=True)
|
||||
|
||||
@@ -22,15 +22,13 @@ console = Console()
|
||||
|
||||
|
||||
class ToolCommand(BaseCommand, PlusAPIMixin):
|
||||
"""
|
||||
A class to handle tool repository related operations for CrewAI projects.
|
||||
"""
|
||||
"""A class to handle tool repository related operations for CrewAI projects."""
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
BaseCommand.__init__(self)
|
||||
PlusAPIMixin.__init__(self, telemetry=self._telemetry)
|
||||
|
||||
def create(self, handle: str):
|
||||
def create(self, handle: str) -> None:
|
||||
self._ensure_not_in_project()
|
||||
|
||||
folder_name = handle.replace(" ", "_").replace("-", "_").lower()
|
||||
@@ -40,8 +38,7 @@ class ToolCommand(BaseCommand, PlusAPIMixin):
|
||||
if project_root.exists():
|
||||
click.secho(f"Folder {folder_name} already exists.", fg="red")
|
||||
raise SystemExit
|
||||
else:
|
||||
os.makedirs(project_root)
|
||||
os.makedirs(project_root)
|
||||
|
||||
click.secho(f"Creating custom tool {folder_name}...", fg="green", bold=True)
|
||||
|
||||
@@ -56,12 +53,12 @@ class ToolCommand(BaseCommand, PlusAPIMixin):
|
||||
self.login()
|
||||
subprocess.run(["git", "init"], check=True)
|
||||
console.print(
|
||||
f"[green]Created custom tool [bold]{folder_name}[/bold]. Run [bold]cd {project_root}[/bold] to start working.[/green]"
|
||||
f"[green]Created custom tool [bold]{folder_name}[/bold]. Run [bold]cd {project_root}[/bold] to start working.[/green]",
|
||||
)
|
||||
finally:
|
||||
os.chdir(old_directory)
|
||||
|
||||
def publish(self, is_public: bool, force: bool = False):
|
||||
def publish(self, is_public: bool, force: bool = False) -> None:
|
||||
if not git.Repository().is_synced() and not force:
|
||||
console.print(
|
||||
"[bold red]Failed to publish tool.[/bold red]\n"
|
||||
@@ -69,9 +66,9 @@ class ToolCommand(BaseCommand, PlusAPIMixin):
|
||||
"* [bold]Commit[/bold] your changes.\n"
|
||||
"* [bold]Push[/bold] to sync with the remote.\n"
|
||||
"* [bold]Pull[/bold] the latest changes from the remote.\n"
|
||||
"\nOnce your repository is up-to-date, retry publishing the tool."
|
||||
"\nOnce your repository is up-to-date, retry publishing the tool.",
|
||||
)
|
||||
raise SystemExit()
|
||||
raise SystemExit
|
||||
|
||||
project_name = get_project_name(require=True)
|
||||
assert isinstance(project_name, str)
|
||||
@@ -90,7 +87,7 @@ class ToolCommand(BaseCommand, PlusAPIMixin):
|
||||
)
|
||||
|
||||
tarball_filename = next(
|
||||
(f for f in os.listdir(temp_build_dir) if f.endswith(".tar.gz")), None
|
||||
(f for f in os.listdir(temp_build_dir) if f.endswith(".tar.gz")), None,
|
||||
)
|
||||
if not tarball_filename:
|
||||
console.print(
|
||||
@@ -123,7 +120,7 @@ class ToolCommand(BaseCommand, PlusAPIMixin):
|
||||
style="bold green",
|
||||
)
|
||||
|
||||
def install(self, handle: str):
|
||||
def install(self, handle: str) -> None:
|
||||
get_response = self.plus_api_client.get_tool(handle)
|
||||
|
||||
if get_response.status_code == 404:
|
||||
@@ -132,9 +129,9 @@ class ToolCommand(BaseCommand, PlusAPIMixin):
|
||||
style="bold red",
|
||||
)
|
||||
raise SystemExit
|
||||
elif get_response.status_code != 200:
|
||||
if get_response.status_code != 200:
|
||||
console.print(
|
||||
"Failed to get tool details. Please try again later.", style="bold red"
|
||||
"Failed to get tool details. Please try again later.", style="bold red",
|
||||
)
|
||||
raise SystemExit
|
||||
|
||||
@@ -142,7 +139,7 @@ class ToolCommand(BaseCommand, PlusAPIMixin):
|
||||
|
||||
console.print(f"Successfully installed {handle}", style="bold green")
|
||||
|
||||
def login(self):
|
||||
def login(self) -> None:
|
||||
login_response = self.plus_api_client.login_to_tool_repository()
|
||||
|
||||
if login_response.status_code != 200:
|
||||
@@ -164,10 +161,10 @@ class ToolCommand(BaseCommand, PlusAPIMixin):
|
||||
settings.dump()
|
||||
|
||||
console.print(
|
||||
"Successfully authenticated to the tool repository.", style="bold green"
|
||||
"Successfully authenticated to the tool repository.", style="bold green",
|
||||
)
|
||||
|
||||
def _add_package(self, tool_details):
|
||||
def _add_package(self, tool_details) -> None:
|
||||
tool_handle = tool_details["handle"]
|
||||
repository_handle = tool_details["repository"]["handle"]
|
||||
repository_url = tool_details["repository"]["url"]
|
||||
@@ -192,16 +189,16 @@ class ToolCommand(BaseCommand, PlusAPIMixin):
|
||||
click.echo(add_package_result.stderr, err=True)
|
||||
raise SystemExit
|
||||
|
||||
def _ensure_not_in_project(self):
|
||||
def _ensure_not_in_project(self) -> None:
|
||||
if os.path.isfile("./pyproject.toml"):
|
||||
console.print(
|
||||
"[bold red]Oops! It looks like you're inside a project.[/bold red]"
|
||||
"[bold red]Oops! It looks like you're inside a project.[/bold red]",
|
||||
)
|
||||
console.print(
|
||||
"You can't create a new tool while inside an existing project."
|
||||
"You can't create a new tool while inside an existing project.",
|
||||
)
|
||||
console.print(
|
||||
"[bold yellow]Tip:[/bold yellow] Navigate to a different directory and try again."
|
||||
"[bold yellow]Tip:[/bold yellow] Navigate to a different directory and try again.",
|
||||
)
|
||||
raise SystemExit
|
||||
|
||||
@@ -211,10 +208,10 @@ class ToolCommand(BaseCommand, PlusAPIMixin):
|
||||
|
||||
env = os.environ.copy()
|
||||
env[f"UV_INDEX_{repository_handle}_USERNAME"] = str(
|
||||
settings.tool_repository_username or ""
|
||||
settings.tool_repository_username or "",
|
||||
)
|
||||
env[f"UV_INDEX_{repository_handle}_PASSWORD"] = str(
|
||||
settings.tool_repository_password or ""
|
||||
settings.tool_repository_password or "",
|
||||
)
|
||||
|
||||
return env
|
||||
|
||||
@@ -4,20 +4,22 @@ import click
|
||||
|
||||
|
||||
def train_crew(n_iterations: int, filename: str) -> None:
|
||||
"""
|
||||
Train the crew by running a command in the UV environment.
|
||||
"""Train the crew by running a command in the UV environment.
|
||||
|
||||
Args:
|
||||
n_iterations (int): The number of iterations to train the crew.
|
||||
|
||||
"""
|
||||
command = ["uv", "run", "train", str(n_iterations), filename]
|
||||
|
||||
try:
|
||||
if n_iterations <= 0:
|
||||
raise ValueError("The number of iterations must be a positive integer.")
|
||||
msg = "The number of iterations must be a positive integer."
|
||||
raise ValueError(msg)
|
||||
|
||||
if not filename.endswith(".pkl"):
|
||||
raise ValueError("The filename must not end with .pkl")
|
||||
msg = "The filename must not end with .pkl"
|
||||
raise ValueError(msg)
|
||||
|
||||
result = subprocess.run(command, capture_output=False, text=True, check=True)
|
||||
|
||||
|
||||
@@ -11,9 +11,8 @@ def update_crew() -> None:
|
||||
migrate_pyproject("pyproject.toml", "pyproject.toml")
|
||||
|
||||
|
||||
def migrate_pyproject(input_file, output_file):
|
||||
"""
|
||||
Migrate the pyproject.toml to the new format.
|
||||
def migrate_pyproject(input_file, output_file) -> None:
|
||||
"""Migrate the pyproject.toml to the new format.
|
||||
|
||||
This function is used to migrate the pyproject.toml to the new format.
|
||||
And it will be used to migrate the pyproject.toml to the new format when uv is used.
|
||||
@@ -81,7 +80,7 @@ def migrate_pyproject(input_file, output_file):
|
||||
# Extract the module name from any existing script
|
||||
existing_scripts = new_pyproject["project"]["scripts"]
|
||||
module_name = next(
|
||||
(value.split(".")[0] for value in existing_scripts.values() if "." in value)
|
||||
value.split(".")[0] for value in existing_scripts.values() if "." in value
|
||||
)
|
||||
|
||||
new_pyproject["project"]["scripts"]["run_crew"] = f"{module_name}.main:run"
|
||||
@@ -93,22 +92,19 @@ def migrate_pyproject(input_file, output_file):
|
||||
# Backup the old pyproject.toml
|
||||
backup_file = "pyproject-old.toml"
|
||||
shutil.copy2(input_file, backup_file)
|
||||
print(f"Original pyproject.toml backed up as {backup_file}")
|
||||
|
||||
# Rename the poetry.lock file
|
||||
lock_file = "poetry.lock"
|
||||
lock_backup = "poetry-old.lock"
|
||||
if os.path.exists(lock_file):
|
||||
os.rename(lock_file, lock_backup)
|
||||
print(f"Original poetry.lock renamed to {lock_backup}")
|
||||
else:
|
||||
print("No poetry.lock file found to rename.")
|
||||
pass
|
||||
|
||||
# Write the new pyproject.toml
|
||||
with open(output_file, "wb") as f:
|
||||
tomli_w.dump(new_pyproject, f)
|
||||
|
||||
print(f"Migration complete. New pyproject.toml written to {output_file}")
|
||||
|
||||
|
||||
def parse_version(version: str) -> str:
|
||||
|
||||
@@ -3,7 +3,7 @@ import shutil
|
||||
import sys
|
||||
from functools import reduce
|
||||
from inspect import isfunction, ismethod
|
||||
from typing import Any, Dict, List, get_type_hints
|
||||
from typing import Any, get_type_hints
|
||||
|
||||
import click
|
||||
import tomli
|
||||
@@ -19,9 +19,9 @@ if sys.version_info >= (3, 11):
|
||||
console = Console()
|
||||
|
||||
|
||||
def copy_template(src, dst, name, class_name, folder_name):
|
||||
def copy_template(src, dst, name, class_name, folder_name) -> None:
|
||||
"""Copy a file from src to dst."""
|
||||
with open(src, "r") as file:
|
||||
with open(src) as file:
|
||||
content = file.read()
|
||||
|
||||
# Interpolate the content
|
||||
@@ -39,8 +39,7 @@ def copy_template(src, dst, name, class_name, folder_name):
|
||||
def read_toml(file_path: str = "pyproject.toml"):
|
||||
"""Read the content of a TOML file and return it as a dictionary."""
|
||||
with open(file_path, "rb") as f:
|
||||
toml_dict = tomli.load(f)
|
||||
return toml_dict
|
||||
return tomli.load(f)
|
||||
|
||||
|
||||
def parse_toml(content):
|
||||
@@ -50,59 +49,56 @@ def parse_toml(content):
|
||||
|
||||
|
||||
def get_project_name(
|
||||
pyproject_path: str = "pyproject.toml", require: bool = False
|
||||
pyproject_path: str = "pyproject.toml", require: bool = False,
|
||||
) -> str | None:
|
||||
"""Get the project name from the pyproject.toml file."""
|
||||
return _get_project_attribute(pyproject_path, ["project", "name"], require=require)
|
||||
|
||||
|
||||
def get_project_version(
|
||||
pyproject_path: str = "pyproject.toml", require: bool = False
|
||||
pyproject_path: str = "pyproject.toml", require: bool = False,
|
||||
) -> str | None:
|
||||
"""Get the project version from the pyproject.toml file."""
|
||||
return _get_project_attribute(
|
||||
pyproject_path, ["project", "version"], require=require
|
||||
pyproject_path, ["project", "version"], require=require,
|
||||
)
|
||||
|
||||
|
||||
def get_project_description(
|
||||
pyproject_path: str = "pyproject.toml", require: bool = False
|
||||
pyproject_path: str = "pyproject.toml", require: bool = False,
|
||||
) -> str | None:
|
||||
"""Get the project description from the pyproject.toml file."""
|
||||
return _get_project_attribute(
|
||||
pyproject_path, ["project", "description"], require=require
|
||||
pyproject_path, ["project", "description"], require=require,
|
||||
)
|
||||
|
||||
|
||||
def _get_project_attribute(
|
||||
pyproject_path: str, keys: List[str], require: bool
|
||||
pyproject_path: str, keys: list[str], require: bool,
|
||||
) -> Any | None:
|
||||
"""Get an attribute from the pyproject.toml file."""
|
||||
attribute = None
|
||||
|
||||
try:
|
||||
with open(pyproject_path, "r") as f:
|
||||
with open(pyproject_path) as f:
|
||||
pyproject_content = parse_toml(f.read())
|
||||
|
||||
dependencies = (
|
||||
_get_nested_value(pyproject_content, ["project", "dependencies"]) or []
|
||||
)
|
||||
if not any(True for dep in dependencies if "crewai" in dep):
|
||||
raise Exception("crewai is not in the dependencies.")
|
||||
msg = "crewai is not in the dependencies."
|
||||
raise Exception(msg)
|
||||
|
||||
attribute = _get_nested_value(pyproject_content, keys)
|
||||
except FileNotFoundError:
|
||||
print(f"Error: {pyproject_path} not found.")
|
||||
pass
|
||||
except KeyError:
|
||||
print(f"Error: {pyproject_path} is not a valid pyproject.toml file.")
|
||||
except tomllib.TOMLDecodeError if sys.version_info >= (3, 11) else Exception as e: # type: ignore
|
||||
print(
|
||||
f"Error: {pyproject_path} is not a valid TOML file."
|
||||
if sys.version_info >= (3, 11)
|
||||
else f"Error reading the pyproject.toml file: {e}"
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"Error reading the pyproject.toml file: {e}")
|
||||
pass
|
||||
except tomllib.TOMLDecodeError if sys.version_info >= (3, 11) else Exception: # type: ignore
|
||||
pass
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if require and not attribute:
|
||||
console.print(
|
||||
@@ -114,7 +110,7 @@ def _get_project_attribute(
|
||||
return attribute
|
||||
|
||||
|
||||
def _get_nested_value(data: Dict[str, Any], keys: List[str]) -> Any:
|
||||
def _get_nested_value(data: dict[str, Any], keys: list[str]) -> Any:
|
||||
return reduce(dict.__getitem__, keys, data)
|
||||
|
||||
|
||||
@@ -122,7 +118,7 @@ def fetch_and_json_env_file(env_file_path: str = ".env") -> dict:
|
||||
"""Fetch the environment variables from a .env file and return them as a dictionary."""
|
||||
try:
|
||||
# Read the .env file
|
||||
with open(env_file_path, "r") as f:
|
||||
with open(env_file_path) as f:
|
||||
env_content = f.read()
|
||||
|
||||
# Parse the .env file content to a dictionary
|
||||
@@ -135,14 +131,14 @@ def fetch_and_json_env_file(env_file_path: str = ".env") -> dict:
|
||||
return env_dict
|
||||
|
||||
except FileNotFoundError:
|
||||
print(f"Error: {env_file_path} not found.")
|
||||
except Exception as e:
|
||||
print(f"Error reading the .env file: {e}")
|
||||
pass
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return {}
|
||||
|
||||
|
||||
def tree_copy(source, destination):
|
||||
def tree_copy(source, destination) -> None:
|
||||
"""Copies the entire directory structure from the source to the destination."""
|
||||
for item in os.listdir(source):
|
||||
source_item = os.path.join(source, item)
|
||||
@@ -153,7 +149,7 @@ def tree_copy(source, destination):
|
||||
shutil.copy2(source_item, destination_item)
|
||||
|
||||
|
||||
def tree_find_and_replace(directory, find, replace):
|
||||
def tree_find_and_replace(directory, find, replace) -> None:
|
||||
"""Recursively searches through a directory, replacing a target string in
|
||||
both file contents and filenames with a specified replacement string.
|
||||
"""
|
||||
@@ -161,7 +157,7 @@ def tree_find_and_replace(directory, find, replace):
|
||||
for filename in files:
|
||||
filepath = os.path.join(path, filename)
|
||||
|
||||
with open(filepath, "r") as file:
|
||||
with open(filepath) as file:
|
||||
contents = file.read()
|
||||
with open(filepath, "w") as file:
|
||||
file.write(contents.replace(find, replace))
|
||||
@@ -180,19 +176,19 @@ def tree_find_and_replace(directory, find, replace):
|
||||
|
||||
|
||||
def load_env_vars(folder_path):
|
||||
"""
|
||||
Loads environment variables from a .env file in the specified folder path.
|
||||
"""Loads environment variables from a .env file in the specified folder path.
|
||||
|
||||
Args:
|
||||
- folder_path (Path): The path to the folder containing the .env file.
|
||||
|
||||
Returns:
|
||||
- dict: A dictionary of environment variables.
|
||||
|
||||
"""
|
||||
env_file_path = folder_path / ".env"
|
||||
env_vars = {}
|
||||
if env_file_path.exists():
|
||||
with open(env_file_path, "r") as file:
|
||||
with open(env_file_path) as file:
|
||||
for line in file:
|
||||
key, _, value = line.strip().partition("=")
|
||||
if key and value:
|
||||
@@ -201,8 +197,7 @@ def load_env_vars(folder_path):
|
||||
|
||||
|
||||
def update_env_vars(env_vars, provider, model):
|
||||
"""
|
||||
Updates environment variables with the API key for the selected provider and model.
|
||||
"""Updates environment variables with the API key for the selected provider and model.
|
||||
|
||||
Args:
|
||||
- env_vars (dict): Environment variables dictionary.
|
||||
@@ -211,6 +206,7 @@ def update_env_vars(env_vars, provider, model):
|
||||
|
||||
Returns:
|
||||
- None
|
||||
|
||||
"""
|
||||
api_key_var = ENV_VARS.get(
|
||||
provider,
|
||||
@@ -218,14 +214,14 @@ def update_env_vars(env_vars, provider, model):
|
||||
click.prompt(
|
||||
f"Enter the environment variable name for your {provider.capitalize()} API key",
|
||||
type=str,
|
||||
)
|
||||
),
|
||||
],
|
||||
)[0]
|
||||
|
||||
if api_key_var not in env_vars:
|
||||
try:
|
||||
env_vars[api_key_var] = click.prompt(
|
||||
f"Enter your {provider.capitalize()} API key", type=str, hide_input=True
|
||||
f"Enter your {provider.capitalize()} API key", type=str, hide_input=True,
|
||||
)
|
||||
except click.exceptions.Abort:
|
||||
click.secho("Operation aborted by the user.", fg="red")
|
||||
@@ -238,13 +234,13 @@ def update_env_vars(env_vars, provider, model):
|
||||
return env_vars
|
||||
|
||||
|
||||
def write_env_file(folder_path, env_vars):
|
||||
"""
|
||||
Writes environment variables to a .env file in the specified folder.
|
||||
def write_env_file(folder_path, env_vars) -> None:
|
||||
"""Writes environment variables to a .env file in the specified folder.
|
||||
|
||||
Args:
|
||||
- folder_path (Path): The path to the folder where the .env file will be written.
|
||||
- env_vars (dict): A dictionary of environment variables to write.
|
||||
|
||||
"""
|
||||
env_file_path = folder_path / ".env"
|
||||
with open(env_file_path, "w") as file:
|
||||
@@ -263,7 +259,7 @@ def get_crews(crew_path: str = "crew.py", require: bool = False) -> list[Crew]:
|
||||
crew_os_path = os.path.join(root, crew_path)
|
||||
try:
|
||||
spec = importlib.util.spec_from_file_location(
|
||||
"crew_module", crew_os_path
|
||||
"crew_module", crew_os_path,
|
||||
)
|
||||
if not spec or not spec.loader:
|
||||
continue
|
||||
@@ -277,19 +273,16 @@ def get_crews(crew_path: str = "crew.py", require: bool = False) -> list[Crew]:
|
||||
|
||||
try:
|
||||
crew_instances.extend(fetch_crews(module_attr))
|
||||
except Exception as e:
|
||||
print(f"Error processing attribute {attr_name}: {e}")
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
except Exception as exec_error:
|
||||
print(f"Error executing module: {exec_error}")
|
||||
except Exception:
|
||||
import traceback
|
||||
|
||||
print(f"Traceback: {traceback.format_exc()}")
|
||||
except (ImportError, AttributeError) as e:
|
||||
if require:
|
||||
console.print(
|
||||
f"Error importing crew from {crew_path}: {str(e)}",
|
||||
f"Error importing crew from {crew_path}: {e!s}",
|
||||
style="bold red",
|
||||
)
|
||||
continue
|
||||
@@ -303,7 +296,7 @@ def get_crews(crew_path: str = "crew.py", require: bool = False) -> list[Crew]:
|
||||
except Exception as e:
|
||||
if require:
|
||||
console.print(
|
||||
f"Unexpected error while loading crew: {str(e)}", style="bold red"
|
||||
f"Unexpected error while loading crew: {e!s}", style="bold red",
|
||||
)
|
||||
raise SystemExit
|
||||
return crew_instances
|
||||
@@ -317,13 +310,12 @@ def get_crew_instance(module_attr) -> Crew | None:
|
||||
):
|
||||
return module_attr().crew()
|
||||
if (ismethod(module_attr) or isfunction(module_attr)) and get_type_hints(
|
||||
module_attr
|
||||
module_attr,
|
||||
).get("return") is Crew:
|
||||
return module_attr()
|
||||
elif isinstance(module_attr, Crew):
|
||||
if isinstance(module_attr, Crew):
|
||||
return module_attr
|
||||
else:
|
||||
return None
|
||||
return None
|
||||
|
||||
|
||||
def fetch_crews(module_attr) -> list[Crew]:
|
||||
|
||||
@@ -2,5 +2,5 @@ import importlib.metadata
|
||||
|
||||
|
||||
def get_crewai_version() -> str:
|
||||
"""Get the version number of CrewAI running the CLI"""
|
||||
"""Get the version number of CrewAI running the CLI."""
|
||||
return importlib.metadata.version("crewai")
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,5 +1,5 @@
|
||||
import json
|
||||
from typing import Any, Dict, Optional
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
@@ -12,27 +12,28 @@ class CrewOutput(BaseModel):
|
||||
"""Class that represents the result of a crew."""
|
||||
|
||||
raw: str = Field(description="Raw output of crew", default="")
|
||||
pydantic: Optional[BaseModel] = Field(
|
||||
description="Pydantic output of Crew", default=None
|
||||
pydantic: BaseModel | None = Field(
|
||||
description="Pydantic output of Crew", default=None,
|
||||
)
|
||||
json_dict: Optional[Dict[str, Any]] = Field(
|
||||
description="JSON dict output of Crew", default=None
|
||||
json_dict: dict[str, Any] | None = Field(
|
||||
description="JSON dict output of Crew", default=None,
|
||||
)
|
||||
tasks_output: list[TaskOutput] = Field(
|
||||
description="Output of each task", default=[]
|
||||
description="Output of each task", default=[],
|
||||
)
|
||||
token_usage: UsageMetrics = Field(description="Processed token summary", default={})
|
||||
|
||||
@property
|
||||
def json(self) -> Optional[str]:
|
||||
def json(self) -> str | None:
|
||||
if self.tasks_output[-1].output_format != OutputFormat.JSON:
|
||||
msg = "No JSON output found in the final task. Please make sure to set the output_json property in the final task in your crew."
|
||||
raise ValueError(
|
||||
"No JSON output found in the final task. Please make sure to set the output_json property in the final task in your crew."
|
||||
msg,
|
||||
)
|
||||
|
||||
return json.dumps(self.json_dict)
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
"""Convert json_output and pydantic_output to a dictionary."""
|
||||
output_dict = {}
|
||||
if self.json_dict:
|
||||
@@ -44,12 +45,12 @@ class CrewOutput(BaseModel):
|
||||
def __getitem__(self, key):
|
||||
if self.pydantic and hasattr(self.pydantic, key):
|
||||
return getattr(self.pydantic, key)
|
||||
elif self.json_dict and key in self.json_dict:
|
||||
if self.json_dict and key in self.json_dict:
|
||||
return self.json_dict[key]
|
||||
else:
|
||||
raise KeyError(f"Key '{key}' not found in CrewOutput.")
|
||||
msg = f"Key '{key}' not found in CrewOutput."
|
||||
raise KeyError(msg)
|
||||
|
||||
def __str__(self):
|
||||
def __str__(self) -> str:
|
||||
if self.pydantic:
|
||||
return str(self.pydantic)
|
||||
if self.json_dict:
|
||||
|
||||
@@ -2,17 +2,11 @@ import asyncio
|
||||
import copy
|
||||
import inspect
|
||||
import logging
|
||||
from collections.abc import Callable
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
Dict,
|
||||
Generic,
|
||||
List,
|
||||
Optional,
|
||||
Set,
|
||||
Type,
|
||||
TypeVar,
|
||||
Union,
|
||||
cast,
|
||||
)
|
||||
from uuid import uuid4
|
||||
@@ -48,14 +42,14 @@ class FlowState(BaseModel):
|
||||
|
||||
# Type variables with explicit bounds
|
||||
T = TypeVar(
|
||||
"T", bound=Union[Dict[str, Any], BaseModel]
|
||||
"T", bound=dict[str, Any] | BaseModel,
|
||||
) # Generic flow state type parameter
|
||||
StateT = TypeVar(
|
||||
"StateT", bound=Union[Dict[str, Any], BaseModel]
|
||||
"StateT", bound=dict[str, Any] | BaseModel,
|
||||
) # State validation type parameter
|
||||
|
||||
|
||||
def ensure_state_type(state: Any, expected_type: Type[StateT]) -> StateT:
|
||||
def ensure_state_type(state: Any, expected_type: type[StateT]) -> StateT:
|
||||
"""Ensure state matches expected type with proper validation.
|
||||
|
||||
Args:
|
||||
@@ -68,6 +62,7 @@ def ensure_state_type(state: Any, expected_type: Type[StateT]) -> StateT:
|
||||
Raises:
|
||||
TypeError: If state doesn't match expected type
|
||||
ValueError: If state validation fails
|
||||
|
||||
"""
|
||||
"""Ensure state matches expected type with proper validation.
|
||||
|
||||
@@ -84,20 +79,22 @@ def ensure_state_type(state: Any, expected_type: Type[StateT]) -> StateT:
|
||||
"""
|
||||
if expected_type is dict:
|
||||
if not isinstance(state, dict):
|
||||
raise TypeError(f"Expected dict, got {type(state).__name__}")
|
||||
return cast(StateT, state)
|
||||
msg = f"Expected dict, got {type(state).__name__}"
|
||||
raise TypeError(msg)
|
||||
return cast("StateT", state)
|
||||
if isinstance(expected_type, type) and issubclass(expected_type, BaseModel):
|
||||
if not isinstance(state, expected_type):
|
||||
msg = f"Expected {expected_type.__name__}, got {type(state).__name__}"
|
||||
raise TypeError(
|
||||
f"Expected {expected_type.__name__}, got {type(state).__name__}"
|
||||
msg,
|
||||
)
|
||||
return cast(StateT, state)
|
||||
raise TypeError(f"Invalid expected_type: {expected_type}")
|
||||
return cast("StateT", state)
|
||||
msg = f"Invalid expected_type: {expected_type}"
|
||||
raise TypeError(msg)
|
||||
|
||||
|
||||
def start(condition: Optional[Union[str, dict, Callable]] = None) -> Callable:
|
||||
"""
|
||||
Marks a method as a flow's starting point.
|
||||
def start(condition: str | dict | Callable | None = None) -> Callable:
|
||||
"""Marks a method as a flow's starting point.
|
||||
|
||||
This decorator designates a method as an entry point for the flow execution.
|
||||
It can optionally specify conditions that trigger the start based on other
|
||||
@@ -135,6 +132,7 @@ def start(condition: Optional[Union[str, dict, Callable]] = None) -> Callable:
|
||||
>>> @start(and_("method1", "method2")) # Start after multiple methods
|
||||
>>> def complex_start(self):
|
||||
... pass
|
||||
|
||||
"""
|
||||
|
||||
def decorator(func):
|
||||
@@ -154,17 +152,17 @@ def start(condition: Optional[Union[str, dict, Callable]] = None) -> Callable:
|
||||
func.__trigger_methods__ = [condition.__name__]
|
||||
func.__condition_type__ = "OR"
|
||||
else:
|
||||
msg = "Condition must be a method, string, or a result of or_() or and_()"
|
||||
raise ValueError(
|
||||
"Condition must be a method, string, or a result of or_() or and_()"
|
||||
msg,
|
||||
)
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def listen(condition: Union[str, dict, Callable]) -> Callable:
|
||||
"""
|
||||
Creates a listener that executes when specified conditions are met.
|
||||
def listen(condition: str | dict | Callable) -> Callable:
|
||||
"""Creates a listener that executes when specified conditions are met.
|
||||
|
||||
This decorator sets up a method to execute in response to other method
|
||||
executions in the flow. It supports both simple and complex triggering
|
||||
@@ -197,6 +195,7 @@ def listen(condition: Union[str, dict, Callable]) -> Callable:
|
||||
>>> @listen(or_("success", "failure")) # Listen to multiple methods
|
||||
>>> def handle_completion(self):
|
||||
... pass
|
||||
|
||||
"""
|
||||
|
||||
def decorator(func):
|
||||
@@ -214,17 +213,17 @@ def listen(condition: Union[str, dict, Callable]) -> Callable:
|
||||
func.__trigger_methods__ = [condition.__name__]
|
||||
func.__condition_type__ = "OR"
|
||||
else:
|
||||
msg = "Condition must be a method, string, or a result of or_() or and_()"
|
||||
raise ValueError(
|
||||
"Condition must be a method, string, or a result of or_() or and_()"
|
||||
msg,
|
||||
)
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def router(condition: Union[str, dict, Callable]) -> Callable:
|
||||
"""
|
||||
Creates a routing method that directs flow execution based on conditions.
|
||||
def router(condition: str | dict | Callable) -> Callable:
|
||||
"""Creates a routing method that directs flow execution based on conditions.
|
||||
|
||||
This decorator marks a method as a router, which can dynamically determine
|
||||
the next steps in the flow based on its return value. Routers are triggered
|
||||
@@ -262,6 +261,7 @@ def router(condition: Union[str, dict, Callable]) -> Callable:
|
||||
... if all([self.state.valid, self.state.processed]):
|
||||
... return CONTINUE
|
||||
... return STOP
|
||||
|
||||
"""
|
||||
|
||||
def decorator(func):
|
||||
@@ -280,17 +280,17 @@ def router(condition: Union[str, dict, Callable]) -> Callable:
|
||||
func.__trigger_methods__ = [condition.__name__]
|
||||
func.__condition_type__ = "OR"
|
||||
else:
|
||||
msg = "Condition must be a method, string, or a result of or_() or and_()"
|
||||
raise ValueError(
|
||||
"Condition must be a method, string, or a result of or_() or and_()"
|
||||
msg,
|
||||
)
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def or_(*conditions: Union[str, dict, Callable]) -> dict:
|
||||
"""
|
||||
Combines multiple conditions with OR logic for flow control.
|
||||
def or_(*conditions: str | dict | Callable) -> dict:
|
||||
"""Combines multiple conditions with OR logic for flow control.
|
||||
|
||||
Creates a condition that is satisfied when any of the specified conditions
|
||||
are met. This is used with @start, @listen, or @router decorators to create
|
||||
@@ -320,6 +320,7 @@ def or_(*conditions: Union[str, dict, Callable]) -> dict:
|
||||
>>> @listen(or_("success", "timeout"))
|
||||
>>> def handle_completion(self):
|
||||
... pass
|
||||
|
||||
"""
|
||||
methods = []
|
||||
for condition in conditions:
|
||||
@@ -330,13 +331,13 @@ def or_(*conditions: Union[str, dict, Callable]) -> dict:
|
||||
elif callable(condition):
|
||||
methods.append(getattr(condition, "__name__", repr(condition)))
|
||||
else:
|
||||
raise ValueError("Invalid condition in or_()")
|
||||
msg = "Invalid condition in or_()"
|
||||
raise ValueError(msg)
|
||||
return {"type": "OR", "methods": methods}
|
||||
|
||||
|
||||
def and_(*conditions: Union[str, dict, Callable]) -> dict:
|
||||
"""
|
||||
Combines multiple conditions with AND logic for flow control.
|
||||
def and_(*conditions: str | dict | Callable) -> dict:
|
||||
"""Combines multiple conditions with AND logic for flow control.
|
||||
|
||||
Creates a condition that is satisfied only when all specified conditions
|
||||
are met. This is used with @start, @listen, or @router decorators to create
|
||||
@@ -366,6 +367,7 @@ def and_(*conditions: Union[str, dict, Callable]) -> dict:
|
||||
>>> @listen(and_("validated", "processed"))
|
||||
>>> def handle_complete_data(self):
|
||||
... pass
|
||||
|
||||
"""
|
||||
methods = []
|
||||
for condition in conditions:
|
||||
@@ -376,7 +378,8 @@ def and_(*conditions: Union[str, dict, Callable]) -> dict:
|
||||
elif callable(condition):
|
||||
methods.append(getattr(condition, "__name__", repr(condition)))
|
||||
else:
|
||||
raise ValueError("Invalid condition in and_()")
|
||||
msg = "Invalid condition in and_()"
|
||||
raise ValueError(msg)
|
||||
return {"type": "AND", "methods": methods}
|
||||
|
||||
|
||||
@@ -416,10 +419,10 @@ class FlowMeta(type):
|
||||
if possible_returns:
|
||||
router_paths[attr_name] = possible_returns
|
||||
|
||||
setattr(cls, "_start_methods", start_methods)
|
||||
setattr(cls, "_listeners", listeners)
|
||||
setattr(cls, "_routers", routers)
|
||||
setattr(cls, "_router_paths", router_paths)
|
||||
cls._start_methods = start_methods
|
||||
cls._listeners = listeners
|
||||
cls._routers = routers
|
||||
cls._router_paths = router_paths
|
||||
|
||||
return cls
|
||||
|
||||
@@ -427,17 +430,18 @@ class FlowMeta(type):
|
||||
class Flow(Generic[T], metaclass=FlowMeta):
|
||||
"""Base class for all flows.
|
||||
|
||||
Type parameter T must be either Dict[str, Any] or a subclass of BaseModel."""
|
||||
Type parameter T must be either Dict[str, Any] or a subclass of BaseModel.
|
||||
"""
|
||||
|
||||
_printer = Printer()
|
||||
|
||||
_start_methods: List[str] = []
|
||||
_listeners: Dict[str, tuple[str, List[str]]] = {}
|
||||
_routers: Set[str] = set()
|
||||
_router_paths: Dict[str, List[str]] = {}
|
||||
initial_state: Union[Type[T], T, None] = None
|
||||
_start_methods: list[str] = []
|
||||
_listeners: dict[str, tuple[str, list[str]]] = {}
|
||||
_routers: set[str] = set()
|
||||
_router_paths: dict[str, list[str]] = {}
|
||||
initial_state: type[T] | T | None = None
|
||||
|
||||
def __class_getitem__(cls: Type["Flow"], item: Type[T]) -> Type["Flow"]:
|
||||
def __class_getitem__(cls: type["Flow"], item: type[T]) -> type["Flow"]:
|
||||
class _FlowGeneric(cls): # type: ignore
|
||||
_initial_state_T = item # type: ignore
|
||||
|
||||
@@ -446,7 +450,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
persistence: Optional[FlowPersistence] = None,
|
||||
persistence: FlowPersistence | None = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Initialize a new Flow instance.
|
||||
@@ -454,13 +458,14 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
||||
Args:
|
||||
persistence: Optional persistence backend for storing flow states
|
||||
**kwargs: Additional state values to initialize or override
|
||||
|
||||
"""
|
||||
# Initialize basic instance attributes
|
||||
self._methods: Dict[str, Callable] = {}
|
||||
self._method_execution_counts: Dict[str, int] = {}
|
||||
self._pending_and_listeners: Dict[str, Set[str]] = {}
|
||||
self._method_outputs: List[Any] = [] # List to store all method outputs
|
||||
self._persistence: Optional[FlowPersistence] = persistence
|
||||
self._methods: dict[str, Callable] = {}
|
||||
self._method_execution_counts: dict[str, int] = {}
|
||||
self._pending_and_listeners: dict[str, set[str]] = {}
|
||||
self._method_outputs: list[Any] = [] # List to store all method outputs
|
||||
self._persistence: FlowPersistence | None = persistence
|
||||
|
||||
# Initialize state with initial values
|
||||
self._state = self._create_initial_state()
|
||||
@@ -502,58 +507,61 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
||||
Raises:
|
||||
ValueError: If structured state model lacks 'id' field
|
||||
TypeError: If state is neither BaseModel nor dictionary
|
||||
|
||||
"""
|
||||
# Handle case where initial_state is None but we have a type parameter
|
||||
if self.initial_state is None and hasattr(self, "_initial_state_T"):
|
||||
state_type = getattr(self, "_initial_state_T")
|
||||
state_type = self._initial_state_T
|
||||
if isinstance(state_type, type):
|
||||
if issubclass(state_type, FlowState):
|
||||
# Create instance without id, then set it
|
||||
instance = state_type()
|
||||
if not hasattr(instance, "id"):
|
||||
setattr(instance, "id", str(uuid4()))
|
||||
return cast(T, instance)
|
||||
elif issubclass(state_type, BaseModel):
|
||||
instance.id = str(uuid4())
|
||||
return cast("T", instance)
|
||||
if issubclass(state_type, BaseModel):
|
||||
# Create a new type that includes the ID field
|
||||
class StateWithId(state_type, FlowState): # type: ignore
|
||||
pass
|
||||
|
||||
instance = StateWithId()
|
||||
if not hasattr(instance, "id"):
|
||||
setattr(instance, "id", str(uuid4()))
|
||||
return cast(T, instance)
|
||||
elif state_type is dict:
|
||||
return cast(T, {"id": str(uuid4())})
|
||||
instance.id = str(uuid4())
|
||||
return cast("T", instance)
|
||||
if state_type is dict:
|
||||
return cast("T", {"id": str(uuid4())})
|
||||
|
||||
# Handle case where no initial state is provided
|
||||
if self.initial_state is None:
|
||||
return cast(T, {"id": str(uuid4())})
|
||||
return cast("T", {"id": str(uuid4())})
|
||||
|
||||
# Handle case where initial_state is a type (class)
|
||||
if isinstance(self.initial_state, type):
|
||||
if issubclass(self.initial_state, FlowState):
|
||||
return cast(T, self.initial_state()) # Uses model defaults
|
||||
elif issubclass(self.initial_state, BaseModel):
|
||||
return cast("T", self.initial_state()) # Uses model defaults
|
||||
if issubclass(self.initial_state, BaseModel):
|
||||
# Validate that the model has an id field
|
||||
model_fields = getattr(self.initial_state, "model_fields", None)
|
||||
if not model_fields or "id" not in model_fields:
|
||||
raise ValueError("Flow state model must have an 'id' field")
|
||||
return cast(T, self.initial_state()) # Uses model defaults
|
||||
elif self.initial_state is dict:
|
||||
return cast(T, {"id": str(uuid4())})
|
||||
msg = "Flow state model must have an 'id' field"
|
||||
raise ValueError(msg)
|
||||
return cast("T", self.initial_state()) # Uses model defaults
|
||||
if self.initial_state is dict:
|
||||
return cast("T", {"id": str(uuid4())})
|
||||
|
||||
# Handle dictionary instance case
|
||||
if isinstance(self.initial_state, dict):
|
||||
new_state = dict(self.initial_state) # Copy to avoid mutations
|
||||
if "id" not in new_state:
|
||||
new_state["id"] = str(uuid4())
|
||||
return cast(T, new_state)
|
||||
return cast("T", new_state)
|
||||
|
||||
# Handle BaseModel instance case
|
||||
if isinstance(self.initial_state, BaseModel):
|
||||
model = cast(BaseModel, self.initial_state)
|
||||
model = cast("BaseModel", self.initial_state)
|
||||
if not hasattr(model, "id"):
|
||||
raise ValueError("Flow state model must have an 'id' field")
|
||||
msg = "Flow state model must have an 'id' field"
|
||||
raise ValueError(msg)
|
||||
|
||||
# Create new instance with same values to avoid mutations
|
||||
if hasattr(model, "model_dump"):
|
||||
@@ -570,9 +578,10 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
||||
|
||||
# Create new instance of the same class
|
||||
model_class = type(model)
|
||||
return cast(T, model_class(**state_dict))
|
||||
return cast("T", model_class(**state_dict))
|
||||
msg = f"Initial state must be dict or BaseModel, got {type(self.initial_state)}"
|
||||
raise TypeError(
|
||||
f"Initial state must be dict or BaseModel, got {type(self.initial_state)}"
|
||||
msg,
|
||||
)
|
||||
|
||||
def _copy_state(self) -> T:
|
||||
@@ -583,7 +592,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
||||
return self._state
|
||||
|
||||
@property
|
||||
def method_outputs(self) -> List[Any]:
|
||||
def method_outputs(self) -> list[Any]:
|
||||
"""Returns the list of all outputs from executed methods."""
|
||||
return self._method_outputs
|
||||
|
||||
@@ -607,6 +616,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
||||
flow = MyFlow()
|
||||
print(f"Current flow ID: {flow.flow_id}") # Safely get flow ID
|
||||
```
|
||||
|
||||
"""
|
||||
try:
|
||||
if not hasattr(self, "_state"):
|
||||
@@ -614,13 +624,13 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
||||
|
||||
if isinstance(self._state, dict):
|
||||
return str(self._state.get("id", ""))
|
||||
elif isinstance(self._state, BaseModel):
|
||||
if isinstance(self._state, BaseModel):
|
||||
return str(getattr(self._state, "id", ""))
|
||||
return ""
|
||||
except (AttributeError, TypeError):
|
||||
return "" # Safely handle any unexpected attribute access issues
|
||||
|
||||
def _initialize_state(self, inputs: Dict[str, Any]) -> None:
|
||||
def _initialize_state(self, inputs: dict[str, Any]) -> None:
|
||||
"""Initialize or update flow state with new inputs.
|
||||
|
||||
Args:
|
||||
@@ -629,6 +639,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
||||
Raises:
|
||||
ValueError: If validation fails for structured state
|
||||
TypeError: If state is neither BaseModel nor dictionary
|
||||
|
||||
"""
|
||||
if isinstance(self._state, dict):
|
||||
# For dict states, preserve existing fields unless overridden
|
||||
@@ -644,7 +655,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
||||
elif isinstance(self._state, BaseModel):
|
||||
# For BaseModel states, preserve existing fields unless overridden
|
||||
try:
|
||||
model = cast(BaseModel, self._state)
|
||||
model = cast("BaseModel", self._state)
|
||||
# Get current state as dict
|
||||
if hasattr(model, "model_dump"):
|
||||
current_state = model.model_dump()
|
||||
@@ -662,19 +673,21 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
||||
model_class = type(model)
|
||||
if hasattr(model_class, "model_validate"):
|
||||
# Pydantic v2
|
||||
self._state = cast(T, model_class.model_validate(new_state))
|
||||
self._state = cast("T", model_class.model_validate(new_state))
|
||||
elif hasattr(model_class, "parse_obj"):
|
||||
# Pydantic v1
|
||||
self._state = cast(T, model_class.parse_obj(new_state))
|
||||
self._state = cast("T", model_class.parse_obj(new_state))
|
||||
else:
|
||||
# Fallback for other BaseModel implementations
|
||||
self._state = cast(T, model_class(**new_state))
|
||||
self._state = cast("T", model_class(**new_state))
|
||||
except ValidationError as e:
|
||||
raise ValueError(f"Invalid inputs for structured state: {e}") from e
|
||||
msg = f"Invalid inputs for structured state: {e}"
|
||||
raise ValueError(msg) from e
|
||||
else:
|
||||
raise TypeError("State must be a BaseModel instance or a dictionary.")
|
||||
msg = "State must be a BaseModel instance or a dictionary."
|
||||
raise TypeError(msg)
|
||||
|
||||
def _restore_state(self, stored_state: Dict[str, Any]) -> None:
|
||||
def _restore_state(self, stored_state: dict[str, Any]) -> None:
|
||||
"""Restore flow state from persistence.
|
||||
|
||||
Args:
|
||||
@@ -683,11 +696,13 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
||||
Raises:
|
||||
ValueError: If validation fails for structured state
|
||||
TypeError: If state is neither BaseModel nor dictionary
|
||||
|
||||
"""
|
||||
# When restoring from persistence, use the stored ID
|
||||
stored_id = stored_state.get("id")
|
||||
if not stored_id:
|
||||
raise ValueError("Stored state must have an 'id' field")
|
||||
msg = "Stored state must have an 'id' field"
|
||||
raise ValueError(msg)
|
||||
|
||||
if isinstance(self._state, dict):
|
||||
# For dict states, update all fields from stored state
|
||||
@@ -695,22 +710,22 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
||||
self._state.update(stored_state)
|
||||
elif isinstance(self._state, BaseModel):
|
||||
# For BaseModel states, create new instance with stored values
|
||||
model = cast(BaseModel, self._state)
|
||||
model = cast("BaseModel", self._state)
|
||||
if hasattr(model, "model_validate"):
|
||||
# Pydantic v2
|
||||
self._state = cast(T, type(model).model_validate(stored_state))
|
||||
self._state = cast("T", type(model).model_validate(stored_state))
|
||||
elif hasattr(model, "parse_obj"):
|
||||
# Pydantic v1
|
||||
self._state = cast(T, type(model).parse_obj(stored_state))
|
||||
self._state = cast("T", type(model).parse_obj(stored_state))
|
||||
else:
|
||||
# Fallback for other BaseModel implementations
|
||||
self._state = cast(T, type(model)(**stored_state))
|
||||
self._state = cast("T", type(model)(**stored_state))
|
||||
else:
|
||||
raise TypeError(f"State must be dict or BaseModel, got {type(self._state)}")
|
||||
msg = f"State must be dict or BaseModel, got {type(self._state)}"
|
||||
raise TypeError(msg)
|
||||
|
||||
def kickoff(self, inputs: Optional[Dict[str, Any]] = None) -> Any:
|
||||
"""
|
||||
Start the flow execution in a synchronous context.
|
||||
def kickoff(self, inputs: dict[str, Any] | None = None) -> Any:
|
||||
"""Start the flow execution in a synchronous context.
|
||||
|
||||
This method wraps kickoff_async so that all state initialization and event
|
||||
emission is handled in the asynchronous method.
|
||||
@@ -721,9 +736,8 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
||||
|
||||
return asyncio.run(run_flow())
|
||||
|
||||
async def kickoff_async(self, inputs: Optional[Dict[str, Any]] = None) -> Any:
|
||||
"""
|
||||
Start the flow execution asynchronously.
|
||||
async def kickoff_async(self, inputs: dict[str, Any] | None = None) -> Any:
|
||||
"""Start the flow execution asynchronously.
|
||||
|
||||
This method performs state restoration (if an 'id' is provided and persistence is available)
|
||||
and updates the flow state with any additional inputs. It then emits the FlowStartedEvent,
|
||||
@@ -735,6 +749,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
||||
|
||||
Returns:
|
||||
The final output from the flow, which is the result of the last executed method.
|
||||
|
||||
"""
|
||||
if inputs:
|
||||
# Override the id in the state if it exists in inputs
|
||||
@@ -742,7 +757,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
||||
if isinstance(self._state, dict):
|
||||
self._state["id"] = inputs["id"]
|
||||
elif isinstance(self._state, BaseModel):
|
||||
setattr(self._state, "id", inputs["id"])
|
||||
self._state.id = inputs["id"]
|
||||
|
||||
# If persistence is enabled, attempt to restore the stored state using the provided id.
|
||||
if "id" in inputs and self._persistence is not None:
|
||||
@@ -756,7 +771,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
||||
self._restore_state(stored_state)
|
||||
else:
|
||||
self._log_flow_event(
|
||||
f"No flow state found for UUID: {restore_uuid}", color="red"
|
||||
f"No flow state found for UUID: {restore_uuid}", color="red",
|
||||
)
|
||||
|
||||
# Update state with any additional inputs (ignoring the 'id' key)
|
||||
@@ -774,7 +789,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
||||
),
|
||||
)
|
||||
self._log_flow_event(
|
||||
f"Flow started with ID: {self.flow_id}", color="bold_magenta"
|
||||
f"Flow started with ID: {self.flow_id}", color="bold_magenta",
|
||||
)
|
||||
|
||||
if inputs is not None and "id" not in inputs:
|
||||
@@ -800,8 +815,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
||||
return final_output
|
||||
|
||||
async def _execute_start_method(self, start_method_name: str) -> None:
|
||||
"""
|
||||
Executes a flow's start method and its triggered listeners.
|
||||
"""Executes a flow's start method and its triggered listeners.
|
||||
|
||||
This internal method handles the execution of methods marked with @start
|
||||
decorator and manages the subsequent chain of listener executions.
|
||||
@@ -816,14 +830,15 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
||||
- Executes the start method and captures its result
|
||||
- Triggers execution of any listeners waiting on this start method
|
||||
- Part of the flow's initialization sequence
|
||||
|
||||
"""
|
||||
result = await self._execute_method(
|
||||
start_method_name, self._methods[start_method_name]
|
||||
start_method_name, self._methods[start_method_name],
|
||||
)
|
||||
await self._execute_listeners(start_method_name, result)
|
||||
|
||||
async def _execute_method(
|
||||
self, method_name: str, method: Callable, *args: Any, **kwargs: Any
|
||||
self, method_name: str, method: Callable, *args: Any, **kwargs: Any,
|
||||
) -> Any:
|
||||
try:
|
||||
dumped_params = {f"_{i}": arg for i, arg in enumerate(args)} | (
|
||||
@@ -873,11 +888,10 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
||||
error=e,
|
||||
),
|
||||
)
|
||||
raise e
|
||||
raise
|
||||
|
||||
async def _execute_listeners(self, trigger_method: str, result: Any) -> None:
|
||||
"""
|
||||
Executes all listeners and routers triggered by a method completion.
|
||||
"""Executes all listeners and routers triggered by a method completion.
|
||||
|
||||
This internal method manages the execution flow by:
|
||||
1. First executing all triggered routers sequentially
|
||||
@@ -897,6 +911,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
||||
- Each router's result becomes a new trigger_method
|
||||
- Normal listeners are executed in parallel for efficiency
|
||||
- Listeners can receive the trigger method's result as a parameter
|
||||
|
||||
"""
|
||||
# First, handle routers repeatedly until no router triggers anymore
|
||||
router_results = []
|
||||
@@ -904,7 +919,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
||||
|
||||
while True:
|
||||
routers_triggered = self._find_triggered_methods(
|
||||
current_trigger, router_only=True
|
||||
current_trigger, router_only=True,
|
||||
)
|
||||
if not routers_triggered:
|
||||
break
|
||||
@@ -920,12 +935,12 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
||||
)
|
||||
|
||||
# Now execute normal listeners for all router results and the original trigger
|
||||
all_triggers = [trigger_method] + router_results
|
||||
all_triggers = [trigger_method, *router_results]
|
||||
|
||||
for current_trigger in all_triggers:
|
||||
if current_trigger: # Skip None results
|
||||
listeners_triggered = self._find_triggered_methods(
|
||||
current_trigger, router_only=False
|
||||
current_trigger, router_only=False,
|
||||
)
|
||||
if listeners_triggered:
|
||||
tasks = [
|
||||
@@ -935,10 +950,9 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
||||
await asyncio.gather(*tasks)
|
||||
|
||||
def _find_triggered_methods(
|
||||
self, trigger_method: str, router_only: bool
|
||||
) -> List[str]:
|
||||
"""
|
||||
Finds all methods that should be triggered based on conditions.
|
||||
self, trigger_method: str, router_only: bool,
|
||||
) -> list[str]:
|
||||
"""Finds all methods that should be triggered based on conditions.
|
||||
|
||||
This internal method evaluates both OR and AND conditions to determine
|
||||
which methods should be executed next in the flow.
|
||||
@@ -963,6 +977,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
||||
* AND: Triggers only when all conditions are met
|
||||
- Maintains state for AND conditions using _pending_and_listeners
|
||||
- Separates router and normal listener evaluation
|
||||
|
||||
"""
|
||||
triggered = []
|
||||
for listener_name, (condition_type, methods) in self._listeners.items():
|
||||
@@ -992,8 +1007,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
||||
return triggered
|
||||
|
||||
async def _execute_single_listener(self, listener_name: str, result: Any) -> None:
|
||||
"""
|
||||
Executes a single listener method with proper event handling.
|
||||
"""Executes a single listener method with proper event handling.
|
||||
|
||||
This internal method manages the execution of an individual listener,
|
||||
including parameter inspection, event emission, and error handling.
|
||||
@@ -1018,6 +1032,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
||||
-------------
|
||||
Catches and logs any exceptions during execution, preventing
|
||||
individual listener failures from breaking the entire flow.
|
||||
|
||||
"""
|
||||
try:
|
||||
method = self._methods[listener_name]
|
||||
@@ -1028,7 +1043,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
||||
|
||||
if method_params:
|
||||
listener_result = await self._execute_method(
|
||||
listener_name, method, result
|
||||
listener_name, method, result,
|
||||
)
|
||||
else:
|
||||
listener_result = await self._execute_method(listener_name, method)
|
||||
@@ -1036,17 +1051,14 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
||||
# Execute listeners (and possibly routers) of this listener
|
||||
await self._execute_listeners(listener_name, listener_result)
|
||||
|
||||
except Exception as e:
|
||||
print(
|
||||
f"[Flow._execute_single_listener] Error in method {listener_name}: {e}"
|
||||
)
|
||||
except Exception:
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
raise
|
||||
|
||||
def _log_flow_event(
|
||||
self, message: str, color: str = "yellow", level: str = "info"
|
||||
self, message: str, color: str = "yellow", level: str = "info",
|
||||
) -> None:
|
||||
"""Centralized logging method for flow events.
|
||||
|
||||
@@ -1064,6 +1076,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
||||
Note:
|
||||
This method uses the Printer utility for colored console output
|
||||
and the standard logging module for log level support.
|
||||
|
||||
"""
|
||||
self._printer.print(message, color=color)
|
||||
if level == "info":
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
import inspect
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel, Field, InstanceOf, model_validator
|
||||
|
||||
@@ -14,7 +13,7 @@ class FlowTrackable(BaseModel):
|
||||
inspecting the call stack.
|
||||
"""
|
||||
|
||||
parent_flow: Optional[InstanceOf[Flow]] = Field(
|
||||
parent_flow: InstanceOf[Flow] | None = Field(
|
||||
default=None,
|
||||
description="The parent flow of the instance, if it was created inside a flow.",
|
||||
)
|
||||
|
||||
@@ -1,14 +1,13 @@
|
||||
# flow_visualizer.py
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
from pyvis.network import Network
|
||||
|
||||
from crewai.flow.config import COLORS, NODE_STYLES
|
||||
from crewai.flow.html_template_handler import HTMLTemplateHandler
|
||||
from crewai.flow.legend_generator import generate_legend_items_html, get_legend_items
|
||||
from crewai.flow.path_utils import safe_path_join, validate_path_exists
|
||||
from crewai.flow.path_utils import safe_path_join
|
||||
from crewai.flow.utils import calculate_node_levels
|
||||
from crewai.flow.visualization_utils import (
|
||||
add_edges,
|
||||
@@ -20,9 +19,8 @@ from crewai.flow.visualization_utils import (
|
||||
class FlowPlot:
|
||||
"""Handles the creation and rendering of flow visualization diagrams."""
|
||||
|
||||
def __init__(self, flow):
|
||||
"""
|
||||
Initialize FlowPlot with a flow object.
|
||||
def __init__(self, flow) -> None:
|
||||
"""Initialize FlowPlot with a flow object.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
@@ -33,21 +31,24 @@ class FlowPlot:
|
||||
------
|
||||
ValueError
|
||||
If flow object is invalid or missing required attributes.
|
||||
|
||||
"""
|
||||
if not hasattr(flow, '_methods'):
|
||||
raise ValueError("Invalid flow object: missing '_methods' attribute")
|
||||
if not hasattr(flow, '_listeners'):
|
||||
raise ValueError("Invalid flow object: missing '_listeners' attribute")
|
||||
if not hasattr(flow, '_start_methods'):
|
||||
raise ValueError("Invalid flow object: missing '_start_methods' attribute")
|
||||
|
||||
if not hasattr(flow, "_methods"):
|
||||
msg = "Invalid flow object: missing '_methods' attribute"
|
||||
raise ValueError(msg)
|
||||
if not hasattr(flow, "_listeners"):
|
||||
msg = "Invalid flow object: missing '_listeners' attribute"
|
||||
raise ValueError(msg)
|
||||
if not hasattr(flow, "_start_methods"):
|
||||
msg = "Invalid flow object: missing '_start_methods' attribute"
|
||||
raise ValueError(msg)
|
||||
|
||||
self.flow = flow
|
||||
self.colors = COLORS
|
||||
self.node_styles = NODE_STYLES
|
||||
|
||||
def plot(self, filename):
|
||||
"""
|
||||
Generate and save an HTML visualization of the flow.
|
||||
def plot(self, filename) -> None:
|
||||
"""Generate and save an HTML visualization of the flow.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
@@ -62,10 +63,12 @@ class FlowPlot:
|
||||
If file operations fail or visualization cannot be generated.
|
||||
RuntimeError
|
||||
If network visualization generation fails.
|
||||
|
||||
"""
|
||||
if not filename or not isinstance(filename, str):
|
||||
raise ValueError("Filename must be a non-empty string")
|
||||
|
||||
msg = "Filename must be a non-empty string"
|
||||
raise ValueError(msg)
|
||||
|
||||
try:
|
||||
# Initialize network
|
||||
net = Network(
|
||||
@@ -89,58 +92,63 @@ class FlowPlot:
|
||||
"enabled": false
|
||||
}
|
||||
}
|
||||
"""
|
||||
""",
|
||||
)
|
||||
|
||||
# Calculate levels for nodes
|
||||
try:
|
||||
node_levels = calculate_node_levels(self.flow)
|
||||
except Exception as e:
|
||||
raise ValueError(f"Failed to calculate node levels: {str(e)}")
|
||||
msg = f"Failed to calculate node levels: {e!s}"
|
||||
raise ValueError(msg)
|
||||
|
||||
# Compute positions
|
||||
try:
|
||||
node_positions = compute_positions(self.flow, node_levels)
|
||||
except Exception as e:
|
||||
raise ValueError(f"Failed to compute node positions: {str(e)}")
|
||||
msg = f"Failed to compute node positions: {e!s}"
|
||||
raise ValueError(msg)
|
||||
|
||||
# Add nodes to the network
|
||||
try:
|
||||
add_nodes_to_network(net, self.flow, node_positions, self.node_styles)
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Failed to add nodes to network: {str(e)}")
|
||||
msg = f"Failed to add nodes to network: {e!s}"
|
||||
raise RuntimeError(msg)
|
||||
|
||||
# Add edges to the network
|
||||
try:
|
||||
add_edges(net, self.flow, node_positions, self.colors)
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Failed to add edges to network: {str(e)}")
|
||||
msg = f"Failed to add edges to network: {e!s}"
|
||||
raise RuntimeError(msg)
|
||||
|
||||
# Generate HTML
|
||||
try:
|
||||
network_html = net.generate_html()
|
||||
final_html_content = self._generate_final_html(network_html)
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Failed to generate network visualization: {str(e)}")
|
||||
msg = f"Failed to generate network visualization: {e!s}"
|
||||
raise RuntimeError(msg)
|
||||
|
||||
# Save the final HTML content to the file
|
||||
try:
|
||||
with open(f"{filename}.html", "w", encoding="utf-8") as f:
|
||||
f.write(final_html_content)
|
||||
print(f"Plot saved as {filename}.html")
|
||||
except IOError as e:
|
||||
raise IOError(f"Failed to save flow visualization to {filename}.html: {str(e)}")
|
||||
except OSError as e:
|
||||
msg = f"Failed to save flow visualization to {filename}.html: {e!s}"
|
||||
raise OSError(msg)
|
||||
|
||||
except (ValueError, RuntimeError, IOError) as e:
|
||||
raise e
|
||||
except (OSError, ValueError, RuntimeError):
|
||||
raise
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Unexpected error during flow visualization: {str(e)}")
|
||||
msg = f"Unexpected error during flow visualization: {e!s}"
|
||||
raise RuntimeError(msg)
|
||||
finally:
|
||||
self._cleanup_pyvis_lib()
|
||||
|
||||
def _generate_final_html(self, network_html):
|
||||
"""
|
||||
Generate the final HTML content with network visualization and legend.
|
||||
"""Generate the final HTML content with network visualization and legend.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
@@ -158,9 +166,11 @@ class FlowPlot:
|
||||
If template or logo files cannot be accessed.
|
||||
ValueError
|
||||
If network_html is invalid.
|
||||
|
||||
"""
|
||||
if not network_html:
|
||||
raise ValueError("Invalid network HTML content")
|
||||
msg = "Invalid network HTML content"
|
||||
raise ValueError(msg)
|
||||
|
||||
try:
|
||||
# Extract just the body content from the generated HTML
|
||||
@@ -169,9 +179,11 @@ class FlowPlot:
|
||||
logo_path = safe_path_join("assets", "crewai_logo.svg", root=current_dir)
|
||||
|
||||
if not os.path.exists(template_path):
|
||||
raise IOError(f"Template file not found: {template_path}")
|
||||
msg = f"Template file not found: {template_path}"
|
||||
raise OSError(msg)
|
||||
if not os.path.exists(logo_path):
|
||||
raise IOError(f"Logo file not found: {logo_path}")
|
||||
msg = f"Logo file not found: {logo_path}"
|
||||
raise OSError(msg)
|
||||
|
||||
html_handler = HTMLTemplateHandler(template_path, logo_path)
|
||||
network_body = html_handler.extract_body_content(network_html)
|
||||
@@ -179,16 +191,15 @@ class FlowPlot:
|
||||
# Generate the legend items HTML
|
||||
legend_items = get_legend_items(self.colors)
|
||||
legend_items_html = generate_legend_items_html(legend_items)
|
||||
final_html_content = html_handler.generate_final_html(
|
||||
network_body, legend_items_html
|
||||
return html_handler.generate_final_html(
|
||||
network_body, legend_items_html,
|
||||
)
|
||||
return final_html_content
|
||||
except Exception as e:
|
||||
raise IOError(f"Failed to generate visualization HTML: {str(e)}")
|
||||
msg = f"Failed to generate visualization HTML: {e!s}"
|
||||
raise OSError(msg)
|
||||
|
||||
def _cleanup_pyvis_lib(self):
|
||||
"""
|
||||
Clean up the generated lib folder from pyvis.
|
||||
def _cleanup_pyvis_lib(self) -> None:
|
||||
"""Clean up the generated lib folder from pyvis.
|
||||
|
||||
This method safely removes the temporary lib directory created by pyvis
|
||||
during network visualization generation.
|
||||
@@ -198,15 +209,14 @@ class FlowPlot:
|
||||
if os.path.exists(lib_folder) and os.path.isdir(lib_folder):
|
||||
import shutil
|
||||
shutil.rmtree(lib_folder)
|
||||
except ValueError as e:
|
||||
print(f"Error validating lib folder path: {e}")
|
||||
except Exception as e:
|
||||
print(f"Error cleaning up lib folder: {e}")
|
||||
except ValueError:
|
||||
pass
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
def plot_flow(flow, filename="flow_plot"):
|
||||
"""
|
||||
Convenience function to create and save a flow visualization.
|
||||
def plot_flow(flow, filename="flow_plot") -> None:
|
||||
"""Convenience function to create and save a flow visualization.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
@@ -221,6 +231,7 @@ def plot_flow(flow, filename="flow_plot"):
|
||||
If flow object or filename is invalid.
|
||||
IOError
|
||||
If file operations fail.
|
||||
|
||||
"""
|
||||
visualizer = FlowPlot(flow)
|
||||
visualizer.plot(filename)
|
||||
|
||||
@@ -1,16 +1,14 @@
|
||||
import base64
|
||||
import re
|
||||
from pathlib import Path
|
||||
|
||||
from crewai.flow.path_utils import safe_path_join, validate_path_exists
|
||||
from crewai.flow.path_utils import validate_path_exists
|
||||
|
||||
|
||||
class HTMLTemplateHandler:
|
||||
"""Handles HTML template processing and generation for flow visualization diagrams."""
|
||||
|
||||
def __init__(self, template_path, logo_path):
|
||||
"""
|
||||
Initialize HTMLTemplateHandler with validated template and logo paths.
|
||||
def __init__(self, template_path, logo_path) -> None:
|
||||
"""Initialize HTMLTemplateHandler with validated template and logo paths.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
@@ -23,16 +21,18 @@ class HTMLTemplateHandler:
|
||||
------
|
||||
ValueError
|
||||
If template or logo paths are invalid or files don't exist.
|
||||
|
||||
"""
|
||||
try:
|
||||
self.template_path = validate_path_exists(template_path, "file")
|
||||
self.logo_path = validate_path_exists(logo_path, "file")
|
||||
except ValueError as e:
|
||||
raise ValueError(f"Invalid template or logo path: {e}")
|
||||
msg = f"Invalid template or logo path: {e}"
|
||||
raise ValueError(msg)
|
||||
|
||||
def read_template(self):
|
||||
"""Read and return the HTML template file contents."""
|
||||
with open(self.template_path, "r", encoding="utf-8") as f:
|
||||
with open(self.template_path, encoding="utf-8") as f:
|
||||
return f.read()
|
||||
|
||||
def encode_logo(self):
|
||||
@@ -81,13 +81,12 @@ class HTMLTemplateHandler:
|
||||
|
||||
final_html_content = html_template.replace("{{ title }}", title)
|
||||
final_html_content = final_html_content.replace(
|
||||
"{{ network_content }}", network_body
|
||||
"{{ network_content }}", network_body,
|
||||
)
|
||||
final_html_content = final_html_content.replace(
|
||||
"{{ logo_svg_base64 }}", logo_svg_base64
|
||||
"{{ logo_svg_base64 }}", logo_svg_base64,
|
||||
)
|
||||
final_html_content = final_html_content.replace(
|
||||
"<!-- LEGEND_ITEMS_PLACEHOLDER -->", legend_items_html
|
||||
return final_html_content.replace(
|
||||
"<!-- LEGEND_ITEMS_PLACEHOLDER -->", legend_items_html,
|
||||
)
|
||||
|
||||
return final_html_content
|
||||
|
||||
@@ -1,18 +1,14 @@
|
||||
"""
|
||||
Path utilities for secure file operations in CrewAI flow module.
|
||||
"""Path utilities for secure file operations in CrewAI flow module.
|
||||
|
||||
This module provides utilities for secure path handling to prevent directory
|
||||
traversal attacks and ensure paths remain within allowed boundaries.
|
||||
"""
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import List, Union
|
||||
|
||||
|
||||
def safe_path_join(*parts: str, root: Union[str, Path, None] = None) -> str:
|
||||
"""
|
||||
Safely join path components and ensure the result is within allowed boundaries.
|
||||
def safe_path_join(*parts: str, root: str | Path | None = None) -> str:
|
||||
"""Safely join path components and ensure the result is within allowed boundaries.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
@@ -31,39 +27,43 @@ def safe_path_join(*parts: str, root: Union[str, Path, None] = None) -> str:
|
||||
ValueError
|
||||
If the resulting path would be outside the root directory
|
||||
or if any path component is invalid.
|
||||
|
||||
"""
|
||||
if not parts:
|
||||
raise ValueError("No path components provided")
|
||||
msg = "No path components provided"
|
||||
raise ValueError(msg)
|
||||
|
||||
try:
|
||||
# Convert all parts to strings and clean them
|
||||
clean_parts = [str(part).strip() for part in parts if part]
|
||||
if not clean_parts:
|
||||
raise ValueError("No valid path components provided")
|
||||
msg = "No valid path components provided"
|
||||
raise ValueError(msg)
|
||||
|
||||
# Establish root directory
|
||||
root_path = Path(root).resolve() if root else Path.cwd()
|
||||
|
||||
|
||||
# Join and resolve the full path
|
||||
full_path = Path(root_path, *clean_parts).resolve()
|
||||
|
||||
|
||||
# Check if the resolved path is within root
|
||||
if not str(full_path).startswith(str(root_path)):
|
||||
msg = f"Invalid path: Potential directory traversal. Path must be within {root_path}"
|
||||
raise ValueError(
|
||||
f"Invalid path: Potential directory traversal. Path must be within {root_path}"
|
||||
msg,
|
||||
)
|
||||
|
||||
|
||||
return str(full_path)
|
||||
|
||||
|
||||
except Exception as e:
|
||||
if isinstance(e, ValueError):
|
||||
raise
|
||||
raise ValueError(f"Invalid path components: {str(e)}")
|
||||
msg = f"Invalid path components: {e!s}"
|
||||
raise ValueError(msg)
|
||||
|
||||
|
||||
def validate_path_exists(path: Union[str, Path], file_type: str = "file") -> str:
|
||||
"""
|
||||
Validate that a path exists and is of the expected type.
|
||||
def validate_path_exists(path: str | Path, file_type: str = "file") -> str:
|
||||
"""Validate that a path exists and is of the expected type.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
@@ -81,29 +81,33 @@ def validate_path_exists(path: Union[str, Path], file_type: str = "file") -> str
|
||||
------
|
||||
ValueError
|
||||
If path doesn't exist or is not of expected type.
|
||||
|
||||
"""
|
||||
try:
|
||||
path_obj = Path(path).resolve()
|
||||
|
||||
|
||||
if not path_obj.exists():
|
||||
raise ValueError(f"Path does not exist: {path}")
|
||||
|
||||
msg = f"Path does not exist: {path}"
|
||||
raise ValueError(msg)
|
||||
|
||||
if file_type == "file" and not path_obj.is_file():
|
||||
raise ValueError(f"Path is not a file: {path}")
|
||||
elif file_type == "directory" and not path_obj.is_dir():
|
||||
raise ValueError(f"Path is not a directory: {path}")
|
||||
|
||||
msg = f"Path is not a file: {path}"
|
||||
raise ValueError(msg)
|
||||
if file_type == "directory" and not path_obj.is_dir():
|
||||
msg = f"Path is not a directory: {path}"
|
||||
raise ValueError(msg)
|
||||
|
||||
return str(path_obj)
|
||||
|
||||
|
||||
except Exception as e:
|
||||
if isinstance(e, ValueError):
|
||||
raise
|
||||
raise ValueError(f"Invalid path: {str(e)}")
|
||||
msg = f"Invalid path: {e!s}"
|
||||
raise ValueError(msg)
|
||||
|
||||
|
||||
def list_files(directory: Union[str, Path], pattern: str = "*") -> List[str]:
|
||||
"""
|
||||
Safely list files in a directory matching a pattern.
|
||||
def list_files(directory: str | Path, pattern: str = "*") -> list[str]:
|
||||
"""Safely list files in a directory matching a pattern.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
@@ -121,15 +125,18 @@ def list_files(directory: Union[str, Path], pattern: str = "*") -> List[str]:
|
||||
------
|
||||
ValueError
|
||||
If directory is invalid or inaccessible.
|
||||
|
||||
"""
|
||||
try:
|
||||
dir_path = Path(directory).resolve()
|
||||
if not dir_path.is_dir():
|
||||
raise ValueError(f"Not a directory: {directory}")
|
||||
|
||||
msg = f"Not a directory: {directory}"
|
||||
raise ValueError(msg)
|
||||
|
||||
return [str(p) for p in dir_path.glob(pattern) if p.is_file()]
|
||||
|
||||
|
||||
except Exception as e:
|
||||
if isinstance(e, ValueError):
|
||||
raise
|
||||
raise ValueError(f"Error listing files: {str(e)}")
|
||||
msg = f"Error listing files: {e!s}"
|
||||
raise ValueError(msg)
|
||||
|
||||
@@ -1,53 +1,52 @@
|
||||
"""Base class for flow state persistence."""
|
||||
|
||||
import abc
|
||||
from typing import Any, Dict, Optional, Union
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class FlowPersistence(abc.ABC):
|
||||
"""Abstract base class for flow state persistence.
|
||||
|
||||
|
||||
This class defines the interface that all persistence implementations must follow.
|
||||
It supports both structured (Pydantic BaseModel) and unstructured (dict) states.
|
||||
"""
|
||||
|
||||
|
||||
@abc.abstractmethod
|
||||
def init_db(self) -> None:
|
||||
"""Initialize the persistence backend.
|
||||
|
||||
|
||||
This method should handle any necessary setup, such as:
|
||||
- Creating tables
|
||||
- Establishing connections
|
||||
- Setting up indexes
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
@abc.abstractmethod
|
||||
def save_state(
|
||||
self,
|
||||
flow_uuid: str,
|
||||
method_name: str,
|
||||
state_data: Union[Dict[str, Any], BaseModel]
|
||||
state_data: dict[str, Any] | BaseModel,
|
||||
) -> None:
|
||||
"""Persist the flow state after method completion.
|
||||
|
||||
|
||||
Args:
|
||||
flow_uuid: Unique identifier for the flow instance
|
||||
method_name: Name of the method that just completed
|
||||
state_data: Current state data (either dict or Pydantic model)
|
||||
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
@abc.abstractmethod
|
||||
def load_state(self, flow_uuid: str) -> Optional[Dict[str, Any]]:
|
||||
def load_state(self, flow_uuid: str) -> dict[str, Any] | None:
|
||||
"""Load the most recent state for a given flow UUID.
|
||||
|
||||
|
||||
Args:
|
||||
flow_uuid: Unique identifier for the flow instance
|
||||
|
||||
|
||||
Returns:
|
||||
The most recent state as a dictionary, or None if no state exists
|
||||
|
||||
"""
|
||||
pass
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
"""
|
||||
Decorators for flow state persistence.
|
||||
"""Decorators for flow state persistence.
|
||||
|
||||
Example:
|
||||
```python
|
||||
@@ -19,18 +18,16 @@ Example:
|
||||
# Asynchronous method implementation
|
||||
await some_async_operation()
|
||||
```
|
||||
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import functools
|
||||
import logging
|
||||
from collections.abc import Callable
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
Optional,
|
||||
Type,
|
||||
TypeVar,
|
||||
Union,
|
||||
cast,
|
||||
)
|
||||
|
||||
@@ -48,7 +45,7 @@ LOG_MESSAGES = {
|
||||
"save_state": "Saving flow state to memory for ID: {}",
|
||||
"save_error": "Failed to persist state for method {}: {}",
|
||||
"state_missing": "Flow instance has no state",
|
||||
"id_missing": "Flow state must have an 'id' field for persistence"
|
||||
"id_missing": "Flow state must have an 'id' field for persistence",
|
||||
}
|
||||
|
||||
|
||||
@@ -74,20 +71,23 @@ class PersistenceDecorator:
|
||||
ValueError: If flow has no state or state lacks an ID
|
||||
RuntimeError: If state persistence fails
|
||||
AttributeError: If flow instance lacks required state attributes
|
||||
|
||||
"""
|
||||
try:
|
||||
state = getattr(flow_instance, 'state', None)
|
||||
state = getattr(flow_instance, "state", None)
|
||||
if state is None:
|
||||
raise ValueError("Flow instance has no state")
|
||||
msg = "Flow instance has no state"
|
||||
raise ValueError(msg)
|
||||
|
||||
flow_uuid: Optional[str] = None
|
||||
flow_uuid: str | None = None
|
||||
if isinstance(state, dict):
|
||||
flow_uuid = state.get('id')
|
||||
flow_uuid = state.get("id")
|
||||
elif isinstance(state, BaseModel):
|
||||
flow_uuid = getattr(state, 'id', None)
|
||||
flow_uuid = getattr(state, "id", None)
|
||||
|
||||
if not flow_uuid:
|
||||
raise ValueError("Flow state must have an 'id' field for persistence")
|
||||
msg = "Flow state must have an 'id' field for persistence"
|
||||
raise ValueError(msg)
|
||||
|
||||
# Log state saving only if verbose is True
|
||||
if verbose:
|
||||
@@ -103,21 +103,22 @@ class PersistenceDecorator:
|
||||
except Exception as e:
|
||||
error_msg = LOG_MESSAGES["save_error"].format(method_name, str(e))
|
||||
cls._printer.print(error_msg, color="red")
|
||||
logger.error(error_msg)
|
||||
raise RuntimeError(f"State persistence failed: {str(e)}") from e
|
||||
logger.exception(error_msg)
|
||||
msg = f"State persistence failed: {e!s}"
|
||||
raise RuntimeError(msg) from e
|
||||
except AttributeError:
|
||||
error_msg = LOG_MESSAGES["state_missing"]
|
||||
cls._printer.print(error_msg, color="red")
|
||||
logger.error(error_msg)
|
||||
logger.exception(error_msg)
|
||||
raise ValueError(error_msg)
|
||||
except (TypeError, ValueError) as e:
|
||||
error_msg = LOG_MESSAGES["id_missing"]
|
||||
cls._printer.print(error_msg, color="red")
|
||||
logger.error(error_msg)
|
||||
logger.exception(error_msg)
|
||||
raise ValueError(error_msg) from e
|
||||
|
||||
|
||||
def persist(persistence: Optional[FlowPersistence] = None, verbose: bool = False):
|
||||
def persist(persistence: FlowPersistence | None = None, verbose: bool = False):
|
||||
"""Decorator to persist flow state.
|
||||
|
||||
This decorator can be applied at either the class level or method level.
|
||||
@@ -143,22 +144,23 @@ def persist(persistence: Optional[FlowPersistence] = None, verbose: bool = False
|
||||
@start()
|
||||
def begin(self):
|
||||
pass
|
||||
|
||||
"""
|
||||
def decorator(target: Union[Type, Callable[..., T]]) -> Union[Type, Callable[..., T]]:
|
||||
def decorator(target: type | Callable[..., T]) -> type | Callable[..., T]:
|
||||
"""Decorator that handles both class and method decoration."""
|
||||
actual_persistence = persistence or SQLiteFlowPersistence()
|
||||
|
||||
if isinstance(target, type):
|
||||
# Class decoration
|
||||
original_init = getattr(target, "__init__")
|
||||
original_init = target.__init__
|
||||
|
||||
@functools.wraps(original_init)
|
||||
def new_init(self: Any, *args: Any, **kwargs: Any) -> None:
|
||||
if 'persistence' not in kwargs:
|
||||
kwargs['persistence'] = actual_persistence
|
||||
if "persistence" not in kwargs:
|
||||
kwargs["persistence"] = actual_persistence
|
||||
original_init(self, *args, **kwargs)
|
||||
|
||||
setattr(target, "__init__", new_init)
|
||||
target.__init__ = new_init
|
||||
|
||||
# Store original methods to preserve their decorators
|
||||
original_methods = {}
|
||||
@@ -191,7 +193,7 @@ def persist(persistence: Optional[FlowPersistence] = None, verbose: bool = False
|
||||
for attr in ["__is_start_method__", "__trigger_methods__", "__condition_type__", "__is_router__"]:
|
||||
if hasattr(method, attr):
|
||||
setattr(wrapped, attr, getattr(method, attr))
|
||||
setattr(wrapped, "__is_flow_method__", True)
|
||||
wrapped.__is_flow_method__ = True
|
||||
|
||||
# Update the class with the wrapped method
|
||||
setattr(target, name, wrapped)
|
||||
@@ -211,44 +213,42 @@ def persist(persistence: Optional[FlowPersistence] = None, verbose: bool = False
|
||||
for attr in ["__is_start_method__", "__trigger_methods__", "__condition_type__", "__is_router__"]:
|
||||
if hasattr(method, attr):
|
||||
setattr(wrapped, attr, getattr(method, attr))
|
||||
setattr(wrapped, "__is_flow_method__", True)
|
||||
wrapped.__is_flow_method__ = True
|
||||
|
||||
# Update the class with the wrapped method
|
||||
setattr(target, name, wrapped)
|
||||
|
||||
return target
|
||||
else:
|
||||
# Method decoration
|
||||
method = target
|
||||
setattr(method, "__is_flow_method__", True)
|
||||
# Method decoration
|
||||
method = target
|
||||
method.__is_flow_method__ = True
|
||||
|
||||
if asyncio.iscoroutinefunction(method):
|
||||
@functools.wraps(method)
|
||||
async def method_async_wrapper(flow_instance: Any, *args: Any, **kwargs: Any) -> T:
|
||||
method_coro = method(flow_instance, *args, **kwargs)
|
||||
if asyncio.iscoroutine(method_coro):
|
||||
result = await method_coro
|
||||
else:
|
||||
result = method_coro
|
||||
PersistenceDecorator.persist_state(flow_instance, method.__name__, actual_persistence, verbose)
|
||||
return result
|
||||
if asyncio.iscoroutinefunction(method):
|
||||
@functools.wraps(method)
|
||||
async def method_async_wrapper(flow_instance: Any, *args: Any, **kwargs: Any) -> T:
|
||||
method_coro = method(flow_instance, *args, **kwargs)
|
||||
if asyncio.iscoroutine(method_coro):
|
||||
result = await method_coro
|
||||
else:
|
||||
result = method_coro
|
||||
PersistenceDecorator.persist_state(flow_instance, method.__name__, actual_persistence, verbose)
|
||||
return result
|
||||
|
||||
for attr in ["__is_start_method__", "__trigger_methods__", "__condition_type__", "__is_router__"]:
|
||||
if hasattr(method, attr):
|
||||
setattr(method_async_wrapper, attr, getattr(method, attr))
|
||||
setattr(method_async_wrapper, "__is_flow_method__", True)
|
||||
return cast(Callable[..., T], method_async_wrapper)
|
||||
else:
|
||||
@functools.wraps(method)
|
||||
def method_sync_wrapper(flow_instance: Any, *args: Any, **kwargs: Any) -> T:
|
||||
result = method(flow_instance, *args, **kwargs)
|
||||
PersistenceDecorator.persist_state(flow_instance, method.__name__, actual_persistence, verbose)
|
||||
return result
|
||||
for attr in ["__is_start_method__", "__trigger_methods__", "__condition_type__", "__is_router__"]:
|
||||
if hasattr(method, attr):
|
||||
setattr(method_async_wrapper, attr, getattr(method, attr))
|
||||
method_async_wrapper.__is_flow_method__ = True
|
||||
return cast("Callable[..., T]", method_async_wrapper)
|
||||
@functools.wraps(method)
|
||||
def method_sync_wrapper(flow_instance: Any, *args: Any, **kwargs: Any) -> T:
|
||||
result = method(flow_instance, *args, **kwargs)
|
||||
PersistenceDecorator.persist_state(flow_instance, method.__name__, actual_persistence, verbose)
|
||||
return result
|
||||
|
||||
for attr in ["__is_start_method__", "__trigger_methods__", "__condition_type__", "__is_router__"]:
|
||||
if hasattr(method, attr):
|
||||
setattr(method_sync_wrapper, attr, getattr(method, attr))
|
||||
setattr(method_sync_wrapper, "__is_flow_method__", True)
|
||||
return cast(Callable[..., T], method_sync_wrapper)
|
||||
for attr in ["__is_start_method__", "__trigger_methods__", "__condition_type__", "__is_router__"]:
|
||||
if hasattr(method, attr):
|
||||
setattr(method_sync_wrapper, attr, getattr(method, attr))
|
||||
method_sync_wrapper.__is_flow_method__ = True
|
||||
return cast("Callable[..., T]", method_sync_wrapper)
|
||||
|
||||
return decorator
|
||||
|
||||
@@ -1,12 +1,10 @@
|
||||
"""
|
||||
SQLite-based implementation of flow state persistence.
|
||||
"""
|
||||
"""SQLite-based implementation of flow state persistence."""
|
||||
|
||||
import json
|
||||
import sqlite3
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Optional, Union
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
@@ -23,7 +21,7 @@ class SQLiteFlowPersistence(FlowPersistence):
|
||||
|
||||
db_path: str
|
||||
|
||||
def __init__(self, db_path: Optional[str] = None):
|
||||
def __init__(self, db_path: str | None = None) -> None:
|
||||
"""Initialize SQLite persistence.
|
||||
|
||||
Args:
|
||||
@@ -32,6 +30,7 @@ class SQLiteFlowPersistence(FlowPersistence):
|
||||
|
||||
Raises:
|
||||
ValueError: If db_path is invalid
|
||||
|
||||
"""
|
||||
from crewai.utilities.paths import db_storage_path
|
||||
|
||||
@@ -39,7 +38,8 @@ class SQLiteFlowPersistence(FlowPersistence):
|
||||
path = db_path or str(Path(db_storage_path()) / "flow_states.db")
|
||||
|
||||
if not path:
|
||||
raise ValueError("Database path must be provided")
|
||||
msg = "Database path must be provided"
|
||||
raise ValueError(msg)
|
||||
|
||||
self.db_path = path # Now mypy knows this is str
|
||||
self.init_db()
|
||||
@@ -56,21 +56,21 @@ class SQLiteFlowPersistence(FlowPersistence):
|
||||
timestamp DATETIME NOT NULL,
|
||||
state_json TEXT NOT NULL
|
||||
)
|
||||
"""
|
||||
""",
|
||||
)
|
||||
# Add index for faster UUID lookups
|
||||
conn.execute(
|
||||
"""
|
||||
CREATE INDEX IF NOT EXISTS idx_flow_states_uuid
|
||||
ON flow_states(flow_uuid)
|
||||
"""
|
||||
""",
|
||||
)
|
||||
|
||||
def save_state(
|
||||
self,
|
||||
flow_uuid: str,
|
||||
method_name: str,
|
||||
state_data: Union[Dict[str, Any], BaseModel],
|
||||
state_data: dict[str, Any] | BaseModel,
|
||||
) -> None:
|
||||
"""Save the current flow state to SQLite.
|
||||
|
||||
@@ -78,6 +78,7 @@ class SQLiteFlowPersistence(FlowPersistence):
|
||||
flow_uuid: Unique identifier for the flow instance
|
||||
method_name: Name of the method that just completed
|
||||
state_data: Current state data (either dict or Pydantic model)
|
||||
|
||||
"""
|
||||
# Convert state_data to dict, handling both Pydantic and dict cases
|
||||
if isinstance(state_data, BaseModel):
|
||||
@@ -85,8 +86,9 @@ class SQLiteFlowPersistence(FlowPersistence):
|
||||
elif isinstance(state_data, dict):
|
||||
state_dict = state_data
|
||||
else:
|
||||
msg = f"state_data must be either a Pydantic BaseModel or dict, got {type(state_data)}"
|
||||
raise ValueError(
|
||||
f"state_data must be either a Pydantic BaseModel or dict, got {type(state_data)}"
|
||||
msg,
|
||||
)
|
||||
|
||||
with sqlite3.connect(self.db_path) as conn:
|
||||
@@ -107,7 +109,7 @@ class SQLiteFlowPersistence(FlowPersistence):
|
||||
),
|
||||
)
|
||||
|
||||
def load_state(self, flow_uuid: str) -> Optional[Dict[str, Any]]:
|
||||
def load_state(self, flow_uuid: str) -> dict[str, Any] | None:
|
||||
"""Load the most recent state for a given flow UUID.
|
||||
|
||||
Args:
|
||||
@@ -115,6 +117,7 @@ class SQLiteFlowPersistence(FlowPersistence):
|
||||
|
||||
Returns:
|
||||
The most recent state as a dictionary, or None if no state exists
|
||||
|
||||
"""
|
||||
with sqlite3.connect(self.db_path) as conn:
|
||||
cursor = conn.execute(
|
||||
|
||||
@@ -1,33 +1,32 @@
|
||||
"""
|
||||
Utility functions for flow visualization and dependency analysis.
|
||||
"""Utility functions for flow visualization and dependency analysis.
|
||||
|
||||
This module provides core functionality for analyzing and manipulating flow structures,
|
||||
including node level calculation, ancestor tracking, and return value analysis.
|
||||
Functions in this module are primarily used by the visualization system to create
|
||||
accurate and informative flow diagrams.
|
||||
|
||||
Example
|
||||
Example:
|
||||
-------
|
||||
>>> flow = Flow()
|
||||
>>> node_levels = calculate_node_levels(flow)
|
||||
>>> ancestors = build_ancestor_dict(flow)
|
||||
|
||||
"""
|
||||
|
||||
import ast
|
||||
import inspect
|
||||
import textwrap
|
||||
from collections import defaultdict, deque
|
||||
from typing import Any, Deque, Dict, List, Optional, Set, Union
|
||||
from typing import Any
|
||||
|
||||
|
||||
def get_possible_return_constants(function: Any) -> Optional[List[str]]:
|
||||
def get_possible_return_constants(function: Any) -> list[str] | None:
|
||||
try:
|
||||
source = inspect.getsource(function)
|
||||
except OSError:
|
||||
# Can't get source code
|
||||
return None
|
||||
except Exception as e:
|
||||
print(f"Error retrieving source code for function {function.__name__}: {e}")
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
try:
|
||||
@@ -35,24 +34,18 @@ def get_possible_return_constants(function: Any) -> Optional[List[str]]:
|
||||
source = textwrap.dedent(source)
|
||||
# Parse the source code into an AST
|
||||
code_ast = ast.parse(source)
|
||||
except IndentationError as e:
|
||||
print(f"IndentationError while parsing source code of {function.__name__}: {e}")
|
||||
print(f"Source code:\n{source}")
|
||||
except IndentationError:
|
||||
return None
|
||||
except SyntaxError as e:
|
||||
print(f"SyntaxError while parsing source code of {function.__name__}: {e}")
|
||||
print(f"Source code:\n{source}")
|
||||
except SyntaxError:
|
||||
return None
|
||||
except Exception as e:
|
||||
print(f"Unexpected error while parsing source code of {function.__name__}: {e}")
|
||||
print(f"Source code:\n{source}")
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
return_values = set()
|
||||
dict_definitions = {}
|
||||
|
||||
class DictionaryAssignmentVisitor(ast.NodeVisitor):
|
||||
def visit_Assign(self, node):
|
||||
def visit_Assign(self, node) -> None:
|
||||
# Check if this assignment is assigning a dictionary literal to a variable
|
||||
if isinstance(node.value, ast.Dict) and len(node.targets) == 1:
|
||||
target = node.targets[0]
|
||||
@@ -69,10 +62,10 @@ def get_possible_return_constants(function: Any) -> Optional[List[str]]:
|
||||
self.generic_visit(node)
|
||||
|
||||
class ReturnVisitor(ast.NodeVisitor):
|
||||
def visit_Return(self, node):
|
||||
def visit_Return(self, node) -> None:
|
||||
# Direct string return
|
||||
if isinstance(node.value, ast.Constant) and isinstance(
|
||||
node.value.value, str
|
||||
node.value.value, str,
|
||||
):
|
||||
return_values.add(node.value.value)
|
||||
# Dictionary-based return, like return paths[result]
|
||||
@@ -94,9 +87,8 @@ def get_possible_return_constants(function: Any) -> Optional[List[str]]:
|
||||
return list(return_values) if return_values else None
|
||||
|
||||
|
||||
def calculate_node_levels(flow: Any) -> Dict[str, int]:
|
||||
"""
|
||||
Calculate the hierarchical level of each node in the flow.
|
||||
def calculate_node_levels(flow: Any) -> dict[str, int]:
|
||||
"""Calculate the hierarchical level of each node in the flow.
|
||||
|
||||
Performs a breadth-first traversal of the flow graph to assign levels
|
||||
to nodes, starting with start methods at level 0.
|
||||
@@ -117,11 +109,12 @@ def calculate_node_levels(flow: Any) -> Dict[str, int]:
|
||||
- Each subsequent connected node is assigned level = parent_level + 1
|
||||
- Handles both OR and AND conditions for listeners
|
||||
- Processes router paths separately
|
||||
|
||||
"""
|
||||
levels: Dict[str, int] = {}
|
||||
queue: Deque[str] = deque()
|
||||
visited: Set[str] = set()
|
||||
pending_and_listeners: Dict[str, Set[str]] = {}
|
||||
levels: dict[str, int] = {}
|
||||
queue: deque[str] = deque()
|
||||
visited: set[str] = set()
|
||||
pending_and_listeners: dict[str, set[str]] = {}
|
||||
|
||||
# Make all start methods at level 0
|
||||
for method_name, method in flow._methods.items():
|
||||
@@ -172,9 +165,8 @@ def calculate_node_levels(flow: Any) -> Dict[str, int]:
|
||||
return levels
|
||||
|
||||
|
||||
def count_outgoing_edges(flow: Any) -> Dict[str, int]:
|
||||
"""
|
||||
Count the number of outgoing edges for each method in the flow.
|
||||
def count_outgoing_edges(flow: Any) -> dict[str, int]:
|
||||
"""Count the number of outgoing edges for each method in the flow.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
@@ -185,6 +177,7 @@ def count_outgoing_edges(flow: Any) -> Dict[str, int]:
|
||||
-------
|
||||
Dict[str, int]
|
||||
Dictionary mapping method names to their outgoing edge count.
|
||||
|
||||
"""
|
||||
counts = {}
|
||||
for method_name in flow._methods:
|
||||
@@ -197,9 +190,8 @@ def count_outgoing_edges(flow: Any) -> Dict[str, int]:
|
||||
return counts
|
||||
|
||||
|
||||
def build_ancestor_dict(flow: Any) -> Dict[str, Set[str]]:
|
||||
"""
|
||||
Build a dictionary mapping each node to its ancestor nodes.
|
||||
def build_ancestor_dict(flow: Any) -> dict[str, set[str]]:
|
||||
"""Build a dictionary mapping each node to its ancestor nodes.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
@@ -210,9 +202,10 @@ def build_ancestor_dict(flow: Any) -> Dict[str, Set[str]]:
|
||||
-------
|
||||
Dict[str, Set[str]]
|
||||
Dictionary mapping each node to a set of its ancestor nodes.
|
||||
|
||||
"""
|
||||
ancestors: Dict[str, Set[str]] = {node: set() for node in flow._methods}
|
||||
visited: Set[str] = set()
|
||||
ancestors: dict[str, set[str]] = {node: set() for node in flow._methods}
|
||||
visited: set[str] = set()
|
||||
for node in flow._methods:
|
||||
if node not in visited:
|
||||
dfs_ancestors(node, ancestors, visited, flow)
|
||||
@@ -220,10 +213,9 @@ def build_ancestor_dict(flow: Any) -> Dict[str, Set[str]]:
|
||||
|
||||
|
||||
def dfs_ancestors(
|
||||
node: str, ancestors: Dict[str, Set[str]], visited: Set[str], flow: Any
|
||||
node: str, ancestors: dict[str, set[str]], visited: set[str], flow: Any,
|
||||
) -> None:
|
||||
"""
|
||||
Perform depth-first search to build ancestor relationships.
|
||||
"""Perform depth-first search to build ancestor relationships.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
@@ -240,6 +232,7 @@ def dfs_ancestors(
|
||||
-----
|
||||
This function modifies the ancestors dictionary in-place to build
|
||||
the complete ancestor graph.
|
||||
|
||||
"""
|
||||
if node in visited:
|
||||
return
|
||||
@@ -265,10 +258,9 @@ def dfs_ancestors(
|
||||
|
||||
|
||||
def is_ancestor(
|
||||
node: str, ancestor_candidate: str, ancestors: Dict[str, Set[str]]
|
||||
node: str, ancestor_candidate: str, ancestors: dict[str, set[str]],
|
||||
) -> bool:
|
||||
"""
|
||||
Check if one node is an ancestor of another.
|
||||
"""Check if one node is an ancestor of another.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
@@ -283,13 +275,13 @@ def is_ancestor(
|
||||
-------
|
||||
bool
|
||||
True if ancestor_candidate is an ancestor of node, False otherwise.
|
||||
|
||||
"""
|
||||
return ancestor_candidate in ancestors.get(node, set())
|
||||
|
||||
|
||||
def build_parent_children_dict(flow: Any) -> Dict[str, List[str]]:
|
||||
"""
|
||||
Build a dictionary mapping parent nodes to their children.
|
||||
def build_parent_children_dict(flow: Any) -> dict[str, list[str]]:
|
||||
"""Build a dictionary mapping parent nodes to their children.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
@@ -306,8 +298,9 @@ def build_parent_children_dict(flow: Any) -> Dict[str, List[str]]:
|
||||
- Maps listeners to their trigger methods
|
||||
- Maps router methods to their paths and listeners
|
||||
- Children lists are sorted for consistent ordering
|
||||
|
||||
"""
|
||||
parent_children: Dict[str, List[str]] = {}
|
||||
parent_children: dict[str, list[str]] = {}
|
||||
|
||||
# Map listeners to their trigger methods
|
||||
for listener_name, (_, trigger_methods) in flow._listeners.items():
|
||||
@@ -332,10 +325,9 @@ def build_parent_children_dict(flow: Any) -> Dict[str, List[str]]:
|
||||
|
||||
|
||||
def get_child_index(
|
||||
parent: str, child: str, parent_children: Dict[str, List[str]]
|
||||
parent: str, child: str, parent_children: dict[str, list[str]],
|
||||
) -> int:
|
||||
"""
|
||||
Get the index of a child node in its parent's sorted children list.
|
||||
"""Get the index of a child node in its parent's sorted children list.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
@@ -350,27 +342,25 @@ def get_child_index(
|
||||
-------
|
||||
int
|
||||
Zero-based index of the child in its parent's sorted children list.
|
||||
|
||||
"""
|
||||
children = parent_children.get(parent, [])
|
||||
children.sort()
|
||||
return children.index(child)
|
||||
|
||||
|
||||
def process_router_paths(flow, current, current_level, levels, queue):
|
||||
"""
|
||||
Handle the router connections for the current node.
|
||||
"""
|
||||
def process_router_paths(flow, current, current_level, levels, queue) -> None:
|
||||
"""Handle the router connections for the current node."""
|
||||
if current in flow._routers:
|
||||
paths = flow._router_paths.get(current, [])
|
||||
for path in paths:
|
||||
for listener_name, (
|
||||
condition_type,
|
||||
_condition_type,
|
||||
trigger_methods,
|
||||
) in flow._listeners.items():
|
||||
if path in trigger_methods:
|
||||
if (
|
||||
listener_name not in levels
|
||||
or levels[listener_name] > current_level + 1
|
||||
):
|
||||
levels[listener_name] = current_level + 1
|
||||
queue.append(listener_name)
|
||||
if path in trigger_methods and (
|
||||
listener_name not in levels
|
||||
or levels[listener_name] > current_level + 1
|
||||
):
|
||||
levels[listener_name] = current_level + 1
|
||||
queue.append(listener_name)
|
||||
|
||||
@@ -1,23 +1,23 @@
|
||||
"""
|
||||
Utilities for creating visual representations of flow structures.
|
||||
"""Utilities for creating visual representations of flow structures.
|
||||
|
||||
This module provides functions for generating network visualizations of flows,
|
||||
including node placement, edge creation, and visual styling. It handles the
|
||||
conversion of flow structures into visual network graphs with appropriate
|
||||
styling and layout.
|
||||
|
||||
Example
|
||||
Example:
|
||||
-------
|
||||
>>> flow = Flow()
|
||||
>>> net = Network(directed=True)
|
||||
>>> node_positions = compute_positions(flow, node_levels)
|
||||
>>> add_nodes_to_network(net, flow, node_positions, node_styles)
|
||||
>>> add_edges(net, flow, node_positions, colors)
|
||||
|
||||
"""
|
||||
|
||||
import ast
|
||||
import inspect
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
from typing import Any
|
||||
|
||||
from .utils import (
|
||||
build_ancestor_dict,
|
||||
@@ -28,8 +28,7 @@ from .utils import (
|
||||
|
||||
|
||||
def method_calls_crew(method: Any) -> bool:
|
||||
"""
|
||||
Check if the method contains a call to `.crew()`.
|
||||
"""Check if the method contains a call to `.crew()`.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
@@ -45,21 +44,22 @@ def method_calls_crew(method: Any) -> bool:
|
||||
-----
|
||||
Uses AST analysis to detect method calls, specifically looking for
|
||||
attribute access of 'crew'.
|
||||
|
||||
"""
|
||||
try:
|
||||
source = inspect.getsource(method)
|
||||
source = inspect.cleandoc(source)
|
||||
tree = ast.parse(source)
|
||||
except Exception as e:
|
||||
print(f"Could not parse method {method.__name__}: {e}")
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
class CrewCallVisitor(ast.NodeVisitor):
|
||||
"""AST visitor to detect .crew() method calls."""
|
||||
def __init__(self):
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.found = False
|
||||
|
||||
def visit_Call(self, node):
|
||||
def visit_Call(self, node) -> None:
|
||||
if isinstance(node.func, ast.Attribute):
|
||||
if node.func.attr == "crew":
|
||||
self.found = True
|
||||
@@ -73,11 +73,10 @@ def method_calls_crew(method: Any) -> bool:
|
||||
def add_nodes_to_network(
|
||||
net: Any,
|
||||
flow: Any,
|
||||
node_positions: Dict[str, Tuple[float, float]],
|
||||
node_styles: Dict[str, Dict[str, Any]]
|
||||
node_positions: dict[str, tuple[float, float]],
|
||||
node_styles: dict[str, dict[str, Any]],
|
||||
) -> None:
|
||||
"""
|
||||
Add nodes to the network visualization with appropriate styling.
|
||||
"""Add nodes to the network visualization with appropriate styling.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
@@ -97,6 +96,7 @@ def add_nodes_to_network(
|
||||
- Router methods
|
||||
- Crew methods
|
||||
- Regular methods
|
||||
|
||||
"""
|
||||
def human_friendly_label(method_name):
|
||||
return method_name.replace("_", " ").title()
|
||||
@@ -123,7 +123,7 @@ def add_nodes_to_network(
|
||||
"multi": "html",
|
||||
"color": node_style.get("font", {}).get("color", "#FFFFFF"),
|
||||
},
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
net.add_node(
|
||||
@@ -138,12 +138,11 @@ def add_nodes_to_network(
|
||||
|
||||
def compute_positions(
|
||||
flow: Any,
|
||||
node_levels: Dict[str, int],
|
||||
node_levels: dict[str, int],
|
||||
y_spacing: float = 150,
|
||||
x_spacing: float = 150
|
||||
) -> Dict[str, Tuple[float, float]]:
|
||||
"""
|
||||
Compute the (x, y) positions for each node in the flow graph.
|
||||
x_spacing: float = 150,
|
||||
) -> dict[str, tuple[float, float]]:
|
||||
"""Compute the (x, y) positions for each node in the flow graph.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
@@ -160,9 +159,10 @@ def compute_positions(
|
||||
-------
|
||||
Dict[str, Tuple[float, float]]
|
||||
Dictionary mapping node names to their (x, y) coordinates.
|
||||
|
||||
"""
|
||||
level_nodes: Dict[int, List[str]] = {}
|
||||
node_positions: Dict[str, Tuple[float, float]] = {}
|
||||
level_nodes: dict[int, list[str]] = {}
|
||||
node_positions: dict[str, tuple[float, float]] = {}
|
||||
|
||||
for method_name, level in node_levels.items():
|
||||
level_nodes.setdefault(level, []).append(method_name)
|
||||
@@ -180,10 +180,10 @@ def compute_positions(
|
||||
def add_edges(
|
||||
net: Any,
|
||||
flow: Any,
|
||||
node_positions: Dict[str, Tuple[float, float]],
|
||||
colors: Dict[str, str]
|
||||
node_positions: dict[str, tuple[float, float]],
|
||||
colors: dict[str, str],
|
||||
) -> None:
|
||||
edge_smooth: Dict[str, Union[str, float]] = {"type": "continuous"} # Default value
|
||||
edge_smooth: dict[str, str | float] = {"type": "continuous"} # Default value
|
||||
"""
|
||||
Add edges to the network visualization with appropriate styling.
|
||||
|
||||
@@ -245,7 +245,7 @@ def add_edges(
|
||||
"color": edge_color,
|
||||
"width": 2,
|
||||
"arrows": "to",
|
||||
"dashes": True if is_router_edge or is_and_condition else False,
|
||||
"dashes": bool(is_router_edge or is_and_condition),
|
||||
"smooth": edge_smooth,
|
||||
}
|
||||
|
||||
@@ -261,9 +261,7 @@ def add_edges(
|
||||
# If it's a known router edge and the method is known, don't warn.
|
||||
# This means the path is legitimate, just not reflected as nodes here.
|
||||
if not (is_router_edge and method_known):
|
||||
print(
|
||||
f"Warning: No node found for '{trigger}' or '{method_name}'. Skipping edge."
|
||||
)
|
||||
pass
|
||||
|
||||
# Edges for router return paths
|
||||
for router_method_name, paths in flow._router_paths.items():
|
||||
@@ -278,7 +276,7 @@ def add_edges(
|
||||
and listener_name in node_positions
|
||||
):
|
||||
is_cycle_edge = is_ancestor(
|
||||
router_method_name, listener_name, ancestors
|
||||
router_method_name, listener_name, ancestors,
|
||||
)
|
||||
parent_has_multiple_children = (
|
||||
len(parent_children.get(router_method_name, [])) > 1
|
||||
@@ -293,7 +291,7 @@ def add_edges(
|
||||
dx = target_pos[0] - source_pos[0]
|
||||
smooth_type = "curvedCCW" if dx <= 0 else "curvedCW"
|
||||
index = get_child_index(
|
||||
router_method_name, listener_name, parent_children
|
||||
router_method_name, listener_name, parent_children,
|
||||
)
|
||||
edge_smooth = {
|
||||
"type": smooth_type,
|
||||
@@ -316,6 +314,4 @@ def add_edges(
|
||||
# Same check here: known router edge and known method?
|
||||
method_known = listener_name in flow._methods
|
||||
if not method_known:
|
||||
print(
|
||||
f"Warning: No node found for '{router_method_name}' or '{listener_name}'. Skipping edge."
|
||||
)
|
||||
pass
|
||||
|
||||
@@ -1,55 +1,48 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
class BaseEmbedder(ABC):
|
||||
"""
|
||||
Abstract base class for text embedding models
|
||||
"""
|
||||
"""Abstract base class for text embedding models."""
|
||||
|
||||
@abstractmethod
|
||||
def embed_chunks(self, chunks: List[str]) -> np.ndarray:
|
||||
"""
|
||||
Generate embeddings for a list of text chunks
|
||||
def embed_chunks(self, chunks: list[str]) -> np.ndarray:
|
||||
"""Generate embeddings for a list of text chunks.
|
||||
|
||||
Args:
|
||||
chunks: List of text chunks to embed
|
||||
|
||||
Returns:
|
||||
Array of embeddings
|
||||
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def embed_texts(self, texts: List[str]) -> np.ndarray:
|
||||
"""
|
||||
Generate embeddings for a list of texts
|
||||
def embed_texts(self, texts: list[str]) -> np.ndarray:
|
||||
"""Generate embeddings for a list of texts.
|
||||
|
||||
Args:
|
||||
texts: List of texts to embed
|
||||
|
||||
Returns:
|
||||
Array of embeddings
|
||||
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def embed_text(self, text: str) -> np.ndarray:
|
||||
"""
|
||||
Generate embedding for a single text
|
||||
"""Generate embedding for a single text.
|
||||
|
||||
Args:
|
||||
text: Text to embed
|
||||
|
||||
Returns:
|
||||
Embedding array
|
||||
|
||||
"""
|
||||
pass
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def dimension(self) -> int:
|
||||
"""Get the dimension of the embeddings"""
|
||||
pass
|
||||
"""Get the dimension of the embeddings."""
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
from pathlib import Path
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
@@ -19,75 +18,74 @@ except ImportError:
|
||||
|
||||
|
||||
class FastEmbed(BaseEmbedder):
|
||||
"""
|
||||
A wrapper class for text embedding models using FastEmbed
|
||||
"""
|
||||
"""A wrapper class for text embedding models using FastEmbed."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_name: str = "BAAI/bge-small-en-v1.5",
|
||||
cache_dir: Optional[Union[str, Path]] = None,
|
||||
):
|
||||
"""
|
||||
Initialize the embedding model
|
||||
cache_dir: str | Path | None = None,
|
||||
) -> None:
|
||||
"""Initialize the embedding model.
|
||||
|
||||
Args:
|
||||
model_name: Name of the model to use
|
||||
cache_dir: Directory to cache the model
|
||||
gpu: Whether to use GPU acceleration
|
||||
|
||||
"""
|
||||
if not FASTEMBED_AVAILABLE:
|
||||
raise ImportError(
|
||||
msg = (
|
||||
"FastEmbed is not installed. Please install it with: "
|
||||
"uv pip install fastembed or uv pip install fastembed-gpu for GPU support"
|
||||
)
|
||||
raise ImportError(
|
||||
msg,
|
||||
)
|
||||
|
||||
self.model = TextEmbedding(
|
||||
model_name=model_name,
|
||||
cache_dir=str(cache_dir) if cache_dir else None,
|
||||
)
|
||||
|
||||
def embed_chunks(self, chunks: List[str]) -> List[np.ndarray]:
|
||||
"""
|
||||
Generate embeddings for a list of text chunks
|
||||
def embed_chunks(self, chunks: list[str]) -> list[np.ndarray]:
|
||||
"""Generate embeddings for a list of text chunks.
|
||||
|
||||
Args:
|
||||
chunks: List of text chunks to embed
|
||||
|
||||
Returns:
|
||||
List of embeddings
|
||||
"""
|
||||
embeddings = list(self.model.embed(chunks))
|
||||
return embeddings
|
||||
|
||||
def embed_texts(self, texts: List[str]) -> List[np.ndarray]:
|
||||
"""
|
||||
Generate embeddings for a list of texts
|
||||
return list(self.model.embed(chunks))
|
||||
|
||||
def embed_texts(self, texts: list[str]) -> list[np.ndarray]:
|
||||
"""Generate embeddings for a list of texts.
|
||||
|
||||
Args:
|
||||
texts: List of texts to embed
|
||||
|
||||
Returns:
|
||||
List of embeddings
|
||||
|
||||
"""
|
||||
embeddings = list(self.model.embed(texts))
|
||||
return embeddings
|
||||
return list(self.model.embed(texts))
|
||||
|
||||
def embed_text(self, text: str) -> np.ndarray:
|
||||
"""
|
||||
Generate embedding for a single text
|
||||
"""Generate embedding for a single text.
|
||||
|
||||
Args:
|
||||
text: Text to embed
|
||||
|
||||
Returns:
|
||||
Embedding array
|
||||
|
||||
"""
|
||||
return self.embed_texts([text])[0]
|
||||
|
||||
@property
|
||||
def dimension(self) -> int:
|
||||
"""Get the dimension of the embeddings"""
|
||||
"""Get the dimension of the embeddings."""
|
||||
# Generate a test embedding to get dimensions
|
||||
test_embed = self.embed_text("test")
|
||||
return len(test_embed)
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import os
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
@@ -10,68 +10,70 @@ os.environ["TOKENIZERS_PARALLELISM"] = "false" # removes logging from fastembed
|
||||
|
||||
|
||||
class Knowledge(BaseModel):
|
||||
"""
|
||||
Knowledge is a collection of sources and setup for the vector store to save and query relevant context.
|
||||
"""Knowledge is a collection of sources and setup for the vector store to save and query relevant context.
|
||||
|
||||
Args:
|
||||
sources: List[BaseKnowledgeSource] = Field(default_factory=list)
|
||||
storage: Optional[KnowledgeStorage] = Field(default=None)
|
||||
embedder: Optional[Dict[str, Any]] = None
|
||||
embedder: Optional[Dict[str, Any]] = None.
|
||||
|
||||
"""
|
||||
|
||||
sources: List[BaseKnowledgeSource] = Field(default_factory=list)
|
||||
sources: list[BaseKnowledgeSource] = Field(default_factory=list)
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
storage: Optional[KnowledgeStorage] = Field(default=None)
|
||||
embedder: Optional[Dict[str, Any]] = None
|
||||
collection_name: Optional[str] = None
|
||||
storage: KnowledgeStorage | None = Field(default=None)
|
||||
embedder: dict[str, Any] | None = None
|
||||
collection_name: str | None = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
collection_name: str,
|
||||
sources: List[BaseKnowledgeSource],
|
||||
embedder: Optional[Dict[str, Any]] = None,
|
||||
storage: Optional[KnowledgeStorage] = None,
|
||||
sources: list[BaseKnowledgeSource],
|
||||
embedder: dict[str, Any] | None = None,
|
||||
storage: KnowledgeStorage | None = None,
|
||||
**data,
|
||||
):
|
||||
) -> None:
|
||||
super().__init__(**data)
|
||||
if storage:
|
||||
self.storage = storage
|
||||
else:
|
||||
self.storage = KnowledgeStorage(
|
||||
embedder=embedder, collection_name=collection_name
|
||||
embedder=embedder, collection_name=collection_name,
|
||||
)
|
||||
self.sources = sources
|
||||
self.storage.initialize_knowledge_storage()
|
||||
|
||||
def query(
|
||||
self, query: List[str], results_limit: int = 3, score_threshold: float = 0.35
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Query across all knowledge sources to find the most relevant information.
|
||||
self, query: list[str], results_limit: int = 3, score_threshold: float = 0.35,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Query across all knowledge sources to find the most relevant information.
|
||||
Returns the top_k most relevant chunks.
|
||||
|
||||
Raises:
|
||||
ValueError: If storage is not initialized.
|
||||
|
||||
"""
|
||||
if self.storage is None:
|
||||
raise ValueError("Storage is not initialized.")
|
||||
msg = "Storage is not initialized."
|
||||
raise ValueError(msg)
|
||||
|
||||
results = self.storage.search(
|
||||
return self.storage.search(
|
||||
query,
|
||||
limit=results_limit,
|
||||
score_threshold=score_threshold,
|
||||
)
|
||||
return results
|
||||
|
||||
def add_sources(self):
|
||||
def add_sources(self) -> None:
|
||||
try:
|
||||
for source in self.sources:
|
||||
source.storage = self.storage
|
||||
source.add()
|
||||
except Exception as e:
|
||||
raise e
|
||||
except Exception:
|
||||
raise
|
||||
|
||||
def reset(self) -> None:
|
||||
if self.storage:
|
||||
self.storage.reset()
|
||||
else:
|
||||
raise ValueError("Storage is not initialized.")
|
||||
msg = "Storage is not initialized."
|
||||
raise ValueError(msg)
|
||||
|
||||
@@ -7,6 +7,7 @@ class KnowledgeConfig(BaseModel):
|
||||
Args:
|
||||
results_limit (int): The number of relevant documents to return.
|
||||
score_threshold (float): The minimum score for a document to be considered relevant.
|
||||
|
||||
"""
|
||||
|
||||
results_limit: int = Field(default=3, description="The number of results to return")
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Union
|
||||
|
||||
from pydantic import Field, field_validator
|
||||
|
||||
@@ -14,43 +13,43 @@ class BaseFileKnowledgeSource(BaseKnowledgeSource, ABC):
|
||||
"""Base class for knowledge sources that load content from files."""
|
||||
|
||||
_logger: Logger = Logger(verbose=True)
|
||||
file_path: Optional[Union[Path, List[Path], str, List[str]]] = Field(
|
||||
file_path: Path | list[Path] | str | list[str] | None = Field(
|
||||
default=None,
|
||||
description="[Deprecated] The path to the file. Use file_paths instead.",
|
||||
)
|
||||
file_paths: Optional[Union[Path, List[Path], str, List[str]]] = Field(
|
||||
default_factory=list, description="The path to the file"
|
||||
file_paths: Path | list[Path] | str | list[str] | None = Field(
|
||||
default_factory=list, description="The path to the file",
|
||||
)
|
||||
content: Dict[Path, str] = Field(init=False, default_factory=dict)
|
||||
storage: Optional[KnowledgeStorage] = Field(default=None)
|
||||
safe_file_paths: List[Path] = Field(default_factory=list)
|
||||
content: dict[Path, str] = Field(init=False, default_factory=dict)
|
||||
storage: KnowledgeStorage | None = Field(default=None)
|
||||
safe_file_paths: list[Path] = Field(default_factory=list)
|
||||
|
||||
@field_validator("file_path", "file_paths", mode="before")
|
||||
def validate_file_path(cls, v, info):
|
||||
def validate_file_path(self, v, info):
|
||||
"""Validate that at least one of file_path or file_paths is provided."""
|
||||
# Single check if both are None, O(1) instead of nested conditions
|
||||
if (
|
||||
v is None
|
||||
and info.data.get(
|
||||
"file_path" if info.field_name == "file_paths" else "file_paths"
|
||||
"file_path" if info.field_name == "file_paths" else "file_paths",
|
||||
)
|
||||
is None
|
||||
):
|
||||
raise ValueError("Either file_path or file_paths must be provided")
|
||||
msg = "Either file_path or file_paths must be provided"
|
||||
raise ValueError(msg)
|
||||
return v
|
||||
|
||||
def model_post_init(self, _):
|
||||
def model_post_init(self, _) -> None:
|
||||
"""Post-initialization method to load content."""
|
||||
self.safe_file_paths = self._process_file_paths()
|
||||
self.validate_content()
|
||||
self.content = self.load_content()
|
||||
|
||||
@abstractmethod
|
||||
def load_content(self) -> Dict[Path, str]:
|
||||
def load_content(self) -> dict[Path, str]:
|
||||
"""Load and preprocess file content. Should be overridden by subclasses. Assume that the file path is relative to the project root in the knowledge directory."""
|
||||
pass
|
||||
|
||||
def validate_content(self):
|
||||
def validate_content(self) -> None:
|
||||
"""Validate the paths."""
|
||||
for path in self.safe_file_paths:
|
||||
if not path.exists():
|
||||
@@ -59,7 +58,8 @@ class BaseFileKnowledgeSource(BaseKnowledgeSource, ABC):
|
||||
f"File not found: {path}. Try adding sources to the knowledge directory. If it's inside the knowledge directory, use the relative path.",
|
||||
color="red",
|
||||
)
|
||||
raise FileNotFoundError(f"File not found: {path}")
|
||||
msg = f"File not found: {path}"
|
||||
raise FileNotFoundError(msg)
|
||||
if not path.is_file():
|
||||
self._logger.log(
|
||||
"error",
|
||||
@@ -67,20 +67,20 @@ class BaseFileKnowledgeSource(BaseKnowledgeSource, ABC):
|
||||
color="red",
|
||||
)
|
||||
|
||||
def _save_documents(self):
|
||||
def _save_documents(self) -> None:
|
||||
"""Save the documents to the storage."""
|
||||
if self.storage:
|
||||
self.storage.save(self.chunks)
|
||||
else:
|
||||
raise ValueError("No storage found to save documents.")
|
||||
msg = "No storage found to save documents."
|
||||
raise ValueError(msg)
|
||||
|
||||
def convert_to_path(self, path: Union[Path, str]) -> Path:
|
||||
def convert_to_path(self, path: Path | str) -> Path:
|
||||
"""Convert a path to a Path object."""
|
||||
return Path(KNOWLEDGE_DIRECTORY + "/" + path) if isinstance(path, str) else path
|
||||
|
||||
def _process_file_paths(self) -> List[Path]:
|
||||
def _process_file_paths(self) -> list[Path]:
|
||||
"""Convert file_path to a list of Path objects."""
|
||||
|
||||
if hasattr(self, "file_path") and self.file_path is not None:
|
||||
self._logger.log(
|
||||
"warning",
|
||||
@@ -90,10 +90,11 @@ class BaseFileKnowledgeSource(BaseKnowledgeSource, ABC):
|
||||
self.file_paths = self.file_path
|
||||
|
||||
if self.file_paths is None:
|
||||
raise ValueError("Your source must be provided with a file_paths: []")
|
||||
msg = "Your source must be provided with a file_paths: []"
|
||||
raise ValueError(msg)
|
||||
|
||||
# Convert single path to list
|
||||
path_list: List[Union[Path, str]] = (
|
||||
path_list: list[Path | str] = (
|
||||
[self.file_paths]
|
||||
if isinstance(self.file_paths, (str, Path))
|
||||
else list(self.file_paths)
|
||||
@@ -102,8 +103,9 @@ class BaseFileKnowledgeSource(BaseKnowledgeSource, ABC):
|
||||
)
|
||||
|
||||
if not path_list:
|
||||
msg = "file_path/file_paths must be a Path, str, or a list of these types"
|
||||
raise ValueError(
|
||||
"file_path/file_paths must be a Path, str, or a list of these types"
|
||||
msg,
|
||||
)
|
||||
|
||||
return [self.convert_to_path(path) for path in path_list]
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
@@ -12,41 +12,39 @@ class BaseKnowledgeSource(BaseModel, ABC):
|
||||
|
||||
chunk_size: int = 4000
|
||||
chunk_overlap: int = 200
|
||||
chunks: List[str] = Field(default_factory=list)
|
||||
chunk_embeddings: List[np.ndarray] = Field(default_factory=list)
|
||||
chunks: list[str] = Field(default_factory=list)
|
||||
chunk_embeddings: list[np.ndarray] = Field(default_factory=list)
|
||||
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
storage: Optional[KnowledgeStorage] = Field(default=None)
|
||||
metadata: Dict[str, Any] = Field(default_factory=dict) # Currently unused
|
||||
collection_name: Optional[str] = Field(default=None)
|
||||
storage: KnowledgeStorage | None = Field(default=None)
|
||||
metadata: dict[str, Any] = Field(default_factory=dict) # Currently unused
|
||||
collection_name: str | None = Field(default=None)
|
||||
|
||||
@abstractmethod
|
||||
def validate_content(self) -> Any:
|
||||
"""Load and preprocess content from the source."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def add(self) -> None:
|
||||
"""Process content, chunk it, compute embeddings, and save them."""
|
||||
pass
|
||||
|
||||
def get_embeddings(self) -> List[np.ndarray]:
|
||||
def get_embeddings(self) -> list[np.ndarray]:
|
||||
"""Return the list of embeddings for the chunks."""
|
||||
return self.chunk_embeddings
|
||||
|
||||
def _chunk_text(self, text: str) -> List[str]:
|
||||
def _chunk_text(self, text: str) -> list[str]:
|
||||
"""Utility method to split text into chunks."""
|
||||
return [
|
||||
text[i : i + self.chunk_size]
|
||||
for i in range(0, len(text), self.chunk_size - self.chunk_overlap)
|
||||
]
|
||||
|
||||
def _save_documents(self):
|
||||
"""
|
||||
Save the documents to the storage.
|
||||
def _save_documents(self) -> None:
|
||||
"""Save the documents to the storage.
|
||||
This method should be called after the chunks and embeddings are generated.
|
||||
"""
|
||||
if self.storage:
|
||||
self.storage.save(self.chunks)
|
||||
else:
|
||||
raise ValueError("No storage found to save documents.")
|
||||
msg = "No storage found to save documents."
|
||||
raise ValueError(msg)
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from collections.abc import Iterator
|
||||
from pathlib import Path
|
||||
from typing import Iterator, List, Optional, Union
|
||||
from typing import TYPE_CHECKING
|
||||
from urllib.parse import urlparse
|
||||
|
||||
try:
|
||||
@@ -7,7 +8,6 @@ try:
|
||||
from docling.document_converter import DocumentConverter
|
||||
from docling.exceptions import ConversionError
|
||||
from docling_core.transforms.chunker.hierarchical_chunker import HierarchicalChunker
|
||||
from docling_core.types.doc.document import DoclingDocument
|
||||
|
||||
DOCLING_AVAILABLE = True
|
||||
except ImportError:
|
||||
@@ -19,27 +19,33 @@ from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource
|
||||
from crewai.utilities.constants import KNOWLEDGE_DIRECTORY
|
||||
from crewai.utilities.logger import Logger
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from docling_core.types.doc.document import DoclingDocument
|
||||
|
||||
|
||||
class CrewDoclingSource(BaseKnowledgeSource):
|
||||
"""Default Source class for converting documents to markdown or json
|
||||
This will auto support PDF, DOCX, and TXT, XLSX, Images, and HTML files without any additional dependencies and follows the docling package as the source of truth.
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
def __init__(self, *args, **kwargs) -> None:
|
||||
if not DOCLING_AVAILABLE:
|
||||
raise ImportError(
|
||||
msg = (
|
||||
"The docling package is required to use CrewDoclingSource. "
|
||||
"Please install it using: uv add docling"
|
||||
)
|
||||
raise ImportError(
|
||||
msg,
|
||||
)
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
_logger: Logger = Logger(verbose=True)
|
||||
|
||||
file_path: Optional[List[Union[Path, str]]] = Field(default=None)
|
||||
file_paths: List[Union[Path, str]] = Field(default_factory=list)
|
||||
chunks: List[str] = Field(default_factory=list)
|
||||
safe_file_paths: List[Union[Path, str]] = Field(default_factory=list)
|
||||
content: List["DoclingDocument"] = Field(default_factory=list)
|
||||
file_path: list[Path | str] | None = Field(default=None)
|
||||
file_paths: list[Path | str] = Field(default_factory=list)
|
||||
chunks: list[str] = Field(default_factory=list)
|
||||
safe_file_paths: list[Path | str] = Field(default_factory=list)
|
||||
content: list["DoclingDocument"] = Field(default_factory=list)
|
||||
document_converter: "DocumentConverter" = Field(
|
||||
default_factory=lambda: DocumentConverter(
|
||||
allowed_formats=[
|
||||
@@ -51,8 +57,8 @@ class CrewDoclingSource(BaseKnowledgeSource):
|
||||
InputFormat.IMAGE,
|
||||
InputFormat.XLSX,
|
||||
InputFormat.PPTX,
|
||||
]
|
||||
)
|
||||
],
|
||||
),
|
||||
)
|
||||
|
||||
def model_post_init(self, _) -> None:
|
||||
@@ -66,7 +72,7 @@ class CrewDoclingSource(BaseKnowledgeSource):
|
||||
self.safe_file_paths = self.validate_content()
|
||||
self.content = self._load_content()
|
||||
|
||||
def _load_content(self) -> List["DoclingDocument"]:
|
||||
def _load_content(self) -> list["DoclingDocument"]:
|
||||
try:
|
||||
return self._convert_source_to_docling_documents()
|
||||
except ConversionError as e:
|
||||
@@ -75,10 +81,10 @@ class CrewDoclingSource(BaseKnowledgeSource):
|
||||
f"Error loading content: {e}. Supported formats: {self.document_converter.allowed_formats}",
|
||||
"red",
|
||||
)
|
||||
raise e
|
||||
raise
|
||||
except Exception as e:
|
||||
self._logger.log("error", f"Error loading content: {e}")
|
||||
raise e
|
||||
raise
|
||||
|
||||
def add(self) -> None:
|
||||
if self.content is None:
|
||||
@@ -88,7 +94,7 @@ class CrewDoclingSource(BaseKnowledgeSource):
|
||||
self.chunks.extend(list(new_chunks_iterable))
|
||||
self._save_documents()
|
||||
|
||||
def _convert_source_to_docling_documents(self) -> List["DoclingDocument"]:
|
||||
def _convert_source_to_docling_documents(self) -> list["DoclingDocument"]:
|
||||
conv_results_iter = self.document_converter.convert_all(self.safe_file_paths)
|
||||
return [result.document for result in conv_results_iter]
|
||||
|
||||
@@ -97,8 +103,8 @@ class CrewDoclingSource(BaseKnowledgeSource):
|
||||
for chunk in chunker.chunk(doc):
|
||||
yield chunk.text
|
||||
|
||||
def validate_content(self) -> List[Union[Path, str]]:
|
||||
processed_paths: List[Union[Path, str]] = []
|
||||
def validate_content(self) -> list[Path | str]:
|
||||
processed_paths: list[Path | str] = []
|
||||
for path in self.file_paths:
|
||||
if isinstance(path, str):
|
||||
if path.startswith(("http://", "https://")):
|
||||
@@ -106,15 +112,18 @@ class CrewDoclingSource(BaseKnowledgeSource):
|
||||
if self._validate_url(path):
|
||||
processed_paths.append(path)
|
||||
else:
|
||||
raise ValueError(f"Invalid URL format: {path}")
|
||||
msg = f"Invalid URL format: {path}"
|
||||
raise ValueError(msg)
|
||||
except Exception as e:
|
||||
raise ValueError(f"Invalid URL: {path}. Error: {str(e)}")
|
||||
msg = f"Invalid URL: {path}. Error: {e!s}"
|
||||
raise ValueError(msg)
|
||||
else:
|
||||
local_path = Path(KNOWLEDGE_DIRECTORY + "/" + path)
|
||||
if local_path.exists():
|
||||
processed_paths.append(local_path)
|
||||
else:
|
||||
raise FileNotFoundError(f"File not found: {local_path}")
|
||||
msg = f"File not found: {local_path}"
|
||||
raise FileNotFoundError(msg)
|
||||
else:
|
||||
# this is an instance of Path
|
||||
processed_paths.append(path)
|
||||
@@ -128,7 +137,7 @@ class CrewDoclingSource(BaseKnowledgeSource):
|
||||
result.scheme in ("http", "https"),
|
||||
result.netloc,
|
||||
len(result.netloc.split(".")) >= 2, # Ensure domain has TLD
|
||||
]
|
||||
],
|
||||
)
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
import csv
|
||||
from pathlib import Path
|
||||
from typing import Dict, List
|
||||
|
||||
from crewai.knowledge.source.base_file_knowledge_source import BaseFileKnowledgeSource
|
||||
|
||||
@@ -8,11 +7,11 @@ from crewai.knowledge.source.base_file_knowledge_source import BaseFileKnowledge
|
||||
class CSVKnowledgeSource(BaseFileKnowledgeSource):
|
||||
"""A knowledge source that stores and queries CSV file content using embeddings."""
|
||||
|
||||
def load_content(self) -> Dict[Path, str]:
|
||||
def load_content(self) -> dict[Path, str]:
|
||||
"""Load and preprocess CSV file content."""
|
||||
content_dict = {}
|
||||
for file_path in self.safe_file_paths:
|
||||
with open(file_path, "r", encoding="utf-8") as csvfile:
|
||||
with open(file_path, encoding="utf-8") as csvfile:
|
||||
reader = csv.reader(csvfile)
|
||||
content = ""
|
||||
for row in reader:
|
||||
@@ -21,8 +20,7 @@ class CSVKnowledgeSource(BaseFileKnowledgeSource):
|
||||
return content_dict
|
||||
|
||||
def add(self) -> None:
|
||||
"""
|
||||
Add CSV file content to the knowledge source, chunk it, compute embeddings,
|
||||
"""Add CSV file content to the knowledge source, chunk it, compute embeddings,
|
||||
and save the embeddings.
|
||||
"""
|
||||
content_str = (
|
||||
@@ -32,7 +30,7 @@ class CSVKnowledgeSource(BaseFileKnowledgeSource):
|
||||
self.chunks.extend(new_chunks)
|
||||
self._save_documents()
|
||||
|
||||
def _chunk_text(self, text: str) -> List[str]:
|
||||
def _chunk_text(self, text: str) -> list[str]:
|
||||
"""Utility method to split text into chunks."""
|
||||
return [
|
||||
text[i : i + self.chunk_size]
|
||||
|
||||
@@ -1,6 +1,4 @@
|
||||
from pathlib import Path
|
||||
from typing import Dict, Iterator, List, Optional, Union
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from pydantic import Field, field_validator
|
||||
|
||||
@@ -16,34 +14,34 @@ class ExcelKnowledgeSource(BaseKnowledgeSource):
|
||||
|
||||
_logger: Logger = Logger(verbose=True)
|
||||
|
||||
file_path: Optional[Union[Path, List[Path], str, List[str]]] = Field(
|
||||
file_path: Path | list[Path] | str | list[str] | None = Field(
|
||||
default=None,
|
||||
description="[Deprecated] The path to the file. Use file_paths instead.",
|
||||
)
|
||||
file_paths: Optional[Union[Path, List[Path], str, List[str]]] = Field(
|
||||
default_factory=list, description="The path to the file"
|
||||
file_paths: Path | list[Path] | str | list[str] | None = Field(
|
||||
default_factory=list, description="The path to the file",
|
||||
)
|
||||
chunks: List[str] = Field(default_factory=list)
|
||||
content: Dict[Path, Dict[str, str]] = Field(default_factory=dict)
|
||||
safe_file_paths: List[Path] = Field(default_factory=list)
|
||||
chunks: list[str] = Field(default_factory=list)
|
||||
content: dict[Path, dict[str, str]] = Field(default_factory=dict)
|
||||
safe_file_paths: list[Path] = Field(default_factory=list)
|
||||
|
||||
@field_validator("file_path", "file_paths", mode="before")
|
||||
def validate_file_path(cls, v, info):
|
||||
def validate_file_path(self, v, info):
|
||||
"""Validate that at least one of file_path or file_paths is provided."""
|
||||
# Single check if both are None, O(1) instead of nested conditions
|
||||
if (
|
||||
v is None
|
||||
and info.data.get(
|
||||
"file_path" if info.field_name == "file_paths" else "file_paths"
|
||||
"file_path" if info.field_name == "file_paths" else "file_paths",
|
||||
)
|
||||
is None
|
||||
):
|
||||
raise ValueError("Either file_path or file_paths must be provided")
|
||||
msg = "Either file_path or file_paths must be provided"
|
||||
raise ValueError(msg)
|
||||
return v
|
||||
|
||||
def _process_file_paths(self) -> List[Path]:
|
||||
def _process_file_paths(self) -> list[Path]:
|
||||
"""Convert file_path to a list of Path objects."""
|
||||
|
||||
if hasattr(self, "file_path") and self.file_path is not None:
|
||||
self._logger.log(
|
||||
"warning",
|
||||
@@ -53,10 +51,11 @@ class ExcelKnowledgeSource(BaseKnowledgeSource):
|
||||
self.file_paths = self.file_path
|
||||
|
||||
if self.file_paths is None:
|
||||
raise ValueError("Your source must be provided with a file_paths: []")
|
||||
msg = "Your source must be provided with a file_paths: []"
|
||||
raise ValueError(msg)
|
||||
|
||||
# Convert single path to list
|
||||
path_list: List[Union[Path, str]] = (
|
||||
path_list: list[Path | str] = (
|
||||
[self.file_paths]
|
||||
if isinstance(self.file_paths, (str, Path))
|
||||
else list(self.file_paths)
|
||||
@@ -65,13 +64,14 @@ class ExcelKnowledgeSource(BaseKnowledgeSource):
|
||||
)
|
||||
|
||||
if not path_list:
|
||||
msg = "file_path/file_paths must be a Path, str, or a list of these types"
|
||||
raise ValueError(
|
||||
"file_path/file_paths must be a Path, str, or a list of these types"
|
||||
msg,
|
||||
)
|
||||
|
||||
return [self.convert_to_path(path) for path in path_list]
|
||||
|
||||
def validate_content(self):
|
||||
def validate_content(self) -> None:
|
||||
"""Validate the paths."""
|
||||
for path in self.safe_file_paths:
|
||||
if not path.exists():
|
||||
@@ -80,7 +80,8 @@ class ExcelKnowledgeSource(BaseKnowledgeSource):
|
||||
f"File not found: {path}. Try adding sources to the knowledge directory. If it's inside the knowledge directory, use the relative path.",
|
||||
color="red",
|
||||
)
|
||||
raise FileNotFoundError(f"File not found: {path}")
|
||||
msg = f"File not found: {path}"
|
||||
raise FileNotFoundError(msg)
|
||||
if not path.is_file():
|
||||
self._logger.log(
|
||||
"error",
|
||||
@@ -100,7 +101,7 @@ class ExcelKnowledgeSource(BaseKnowledgeSource):
|
||||
self.validate_content()
|
||||
self.content = self._load_content()
|
||||
|
||||
def _load_content(self) -> Dict[Path, Dict[str, str]]:
|
||||
def _load_content(self) -> dict[Path, dict[str, str]]:
|
||||
"""Load and preprocess Excel file content from multiple sheets.
|
||||
|
||||
Each sheet's content is converted to CSV format and stored.
|
||||
@@ -111,6 +112,7 @@ class ExcelKnowledgeSource(BaseKnowledgeSource):
|
||||
Raises:
|
||||
ImportError: If required dependencies are missing.
|
||||
FileNotFoundError: If the specified Excel file cannot be opened.
|
||||
|
||||
"""
|
||||
pd = self._import_dependencies()
|
||||
content_dict = {}
|
||||
@@ -119,14 +121,14 @@ class ExcelKnowledgeSource(BaseKnowledgeSource):
|
||||
with pd.ExcelFile(file_path) as xl:
|
||||
sheet_dict = {
|
||||
str(sheet_name): str(
|
||||
pd.read_excel(xl, sheet_name).to_csv(index=False)
|
||||
pd.read_excel(xl, sheet_name).to_csv(index=False),
|
||||
)
|
||||
for sheet_name in xl.sheet_names
|
||||
}
|
||||
content_dict[file_path] = sheet_dict
|
||||
return content_dict
|
||||
|
||||
def convert_to_path(self, path: Union[Path, str]) -> Path:
|
||||
def convert_to_path(self, path: Path | str) -> Path:
|
||||
"""Convert a path to a Path object."""
|
||||
return Path(KNOWLEDGE_DIRECTORY + "/" + path) if isinstance(path, str) else path
|
||||
|
||||
@@ -138,13 +140,13 @@ class ExcelKnowledgeSource(BaseKnowledgeSource):
|
||||
return pd
|
||||
except ImportError as e:
|
||||
missing_package = str(e).split()[-1]
|
||||
msg = f"{missing_package} is not installed. Please install it with: pip install {missing_package}"
|
||||
raise ImportError(
|
||||
f"{missing_package} is not installed. Please install it with: pip install {missing_package}"
|
||||
msg,
|
||||
)
|
||||
|
||||
def add(self) -> None:
|
||||
"""
|
||||
Add Excel file content to the knowledge source, chunk it, compute embeddings,
|
||||
"""Add Excel file content to the knowledge source, chunk it, compute embeddings,
|
||||
and save the embeddings.
|
||||
"""
|
||||
# Convert dictionary values to a single string if content is a dictionary
|
||||
@@ -161,7 +163,7 @@ class ExcelKnowledgeSource(BaseKnowledgeSource):
|
||||
self.chunks.extend(new_chunks)
|
||||
self._save_documents()
|
||||
|
||||
def _chunk_text(self, text: str) -> List[str]:
|
||||
def _chunk_text(self, text: str) -> list[str]:
|
||||
"""Utility method to split text into chunks."""
|
||||
return [
|
||||
text[i : i + self.chunk_size]
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List
|
||||
from typing import Any
|
||||
|
||||
from crewai.knowledge.source.base_file_knowledge_source import BaseFileKnowledgeSource
|
||||
|
||||
@@ -8,12 +8,12 @@ from crewai.knowledge.source.base_file_knowledge_source import BaseFileKnowledge
|
||||
class JSONKnowledgeSource(BaseFileKnowledgeSource):
|
||||
"""A knowledge source that stores and queries JSON file content using embeddings."""
|
||||
|
||||
def load_content(self) -> Dict[Path, str]:
|
||||
def load_content(self) -> dict[Path, str]:
|
||||
"""Load and preprocess JSON file content."""
|
||||
content: Dict[Path, str] = {}
|
||||
content: dict[Path, str] = {}
|
||||
for path in self.safe_file_paths:
|
||||
path = self.convert_to_path(path)
|
||||
with open(path, "r", encoding="utf-8") as json_file:
|
||||
with open(path, encoding="utf-8") as json_file:
|
||||
data = json.load(json_file)
|
||||
content[path] = self._json_to_text(data)
|
||||
return content
|
||||
@@ -29,12 +29,11 @@ class JSONKnowledgeSource(BaseFileKnowledgeSource):
|
||||
for item in data:
|
||||
text += f"{indent}- {self._json_to_text(item, level + 1)}\n"
|
||||
else:
|
||||
text += f"{str(data)}"
|
||||
text += f"{data!s}"
|
||||
return text
|
||||
|
||||
def add(self) -> None:
|
||||
"""
|
||||
Add JSON file content to the knowledge source, chunk it, compute embeddings,
|
||||
"""Add JSON file content to the knowledge source, chunk it, compute embeddings,
|
||||
and save the embeddings.
|
||||
"""
|
||||
content_str = (
|
||||
@@ -44,7 +43,7 @@ class JSONKnowledgeSource(BaseFileKnowledgeSource):
|
||||
self.chunks.extend(new_chunks)
|
||||
self._save_documents()
|
||||
|
||||
def _chunk_text(self, text: str) -> List[str]:
|
||||
def _chunk_text(self, text: str) -> list[str]:
|
||||
"""Utility method to split text into chunks."""
|
||||
return [
|
||||
text[i : i + self.chunk_size]
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
from pathlib import Path
|
||||
from typing import Dict, List
|
||||
|
||||
from crewai.knowledge.source.base_file_knowledge_source import BaseFileKnowledgeSource
|
||||
|
||||
@@ -7,7 +6,7 @@ from crewai.knowledge.source.base_file_knowledge_source import BaseFileKnowledge
|
||||
class PDFKnowledgeSource(BaseFileKnowledgeSource):
|
||||
"""A knowledge source that stores and queries PDF file content using embeddings."""
|
||||
|
||||
def load_content(self) -> Dict[Path, str]:
|
||||
def load_content(self) -> dict[Path, str]:
|
||||
"""Load and preprocess PDF file content."""
|
||||
pdfplumber = self._import_pdfplumber()
|
||||
|
||||
@@ -31,21 +30,21 @@ class PDFKnowledgeSource(BaseFileKnowledgeSource):
|
||||
|
||||
return pdfplumber
|
||||
except ImportError:
|
||||
msg = "pdfplumber is not installed. Please install it with: pip install pdfplumber"
|
||||
raise ImportError(
|
||||
"pdfplumber is not installed. Please install it with: pip install pdfplumber"
|
||||
msg,
|
||||
)
|
||||
|
||||
def add(self) -> None:
|
||||
"""
|
||||
Add PDF file content to the knowledge source, chunk it, compute embeddings,
|
||||
"""Add PDF file content to the knowledge source, chunk it, compute embeddings,
|
||||
and save the embeddings.
|
||||
"""
|
||||
for _, text in self.content.items():
|
||||
for text in self.content.values():
|
||||
new_chunks = self._chunk_text(text)
|
||||
self.chunks.extend(new_chunks)
|
||||
self._save_documents()
|
||||
|
||||
def _chunk_text(self, text: str) -> List[str]:
|
||||
def _chunk_text(self, text: str) -> list[str]:
|
||||
"""Utility method to split text into chunks."""
|
||||
return [
|
||||
text[i : i + self.chunk_size]
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
from typing import List, Optional
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
@@ -9,16 +8,17 @@ class StringKnowledgeSource(BaseKnowledgeSource):
|
||||
"""A knowledge source that stores and queries plain text content using embeddings."""
|
||||
|
||||
content: str = Field(...)
|
||||
collection_name: Optional[str] = Field(default=None)
|
||||
collection_name: str | None = Field(default=None)
|
||||
|
||||
def model_post_init(self, _):
|
||||
def model_post_init(self, _) -> None:
|
||||
"""Post-initialization method to validate content."""
|
||||
self.validate_content()
|
||||
|
||||
def validate_content(self):
|
||||
def validate_content(self) -> None:
|
||||
"""Validate string content."""
|
||||
if not isinstance(self.content, str):
|
||||
raise ValueError("StringKnowledgeSource only accepts string content")
|
||||
msg = "StringKnowledgeSource only accepts string content"
|
||||
raise ValueError(msg)
|
||||
|
||||
def add(self) -> None:
|
||||
"""Add string content to the knowledge source, chunk it, compute embeddings, and save them."""
|
||||
@@ -26,7 +26,7 @@ class StringKnowledgeSource(BaseKnowledgeSource):
|
||||
self.chunks.extend(new_chunks)
|
||||
self._save_documents()
|
||||
|
||||
def _chunk_text(self, text: str) -> List[str]:
|
||||
def _chunk_text(self, text: str) -> list[str]:
|
||||
"""Utility method to split text into chunks."""
|
||||
return [
|
||||
text[i : i + self.chunk_size]
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
from pathlib import Path
|
||||
from typing import Dict, List
|
||||
|
||||
from crewai.knowledge.source.base_file_knowledge_source import BaseFileKnowledgeSource
|
||||
|
||||
@@ -7,26 +6,25 @@ from crewai.knowledge.source.base_file_knowledge_source import BaseFileKnowledge
|
||||
class TextFileKnowledgeSource(BaseFileKnowledgeSource):
|
||||
"""A knowledge source that stores and queries text file content using embeddings."""
|
||||
|
||||
def load_content(self) -> Dict[Path, str]:
|
||||
def load_content(self) -> dict[Path, str]:
|
||||
"""Load and preprocess text file content."""
|
||||
content = {}
|
||||
for path in self.safe_file_paths:
|
||||
path = self.convert_to_path(path)
|
||||
with open(path, "r", encoding="utf-8") as f:
|
||||
with open(path, encoding="utf-8") as f:
|
||||
content[path] = f.read()
|
||||
return content
|
||||
|
||||
def add(self) -> None:
|
||||
"""
|
||||
Add text file content to the knowledge source, chunk it, compute embeddings,
|
||||
"""Add text file content to the knowledge source, chunk it, compute embeddings,
|
||||
and save the embeddings.
|
||||
"""
|
||||
for _, text in self.content.items():
|
||||
for text in self.content.values():
|
||||
new_chunks = self._chunk_text(text)
|
||||
self.chunks.extend(new_chunks)
|
||||
self._save_documents()
|
||||
|
||||
def _chunk_text(self, text: str) -> List[str]:
|
||||
def _chunk_text(self, text: str) -> list[str]:
|
||||
"""Utility method to split text into chunks."""
|
||||
return [
|
||||
text[i : i + self.chunk_size]
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any
|
||||
|
||||
|
||||
class BaseKnowledgeStorage(ABC):
|
||||
@@ -8,22 +8,19 @@ class BaseKnowledgeStorage(ABC):
|
||||
@abstractmethod
|
||||
def search(
|
||||
self,
|
||||
query: List[str],
|
||||
query: list[str],
|
||||
limit: int = 3,
|
||||
filter: Optional[dict] = None,
|
||||
filter: dict | None = None,
|
||||
score_threshold: float = 0.35,
|
||||
) -> List[Dict[str, Any]]:
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Search for documents in the knowledge base."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def save(
|
||||
self, documents: List[str], metadata: Dict[str, Any] | List[Dict[str, Any]]
|
||||
self, documents: list[str], metadata: dict[str, Any] | list[dict[str, Any]],
|
||||
) -> None:
|
||||
"""Save documents to the knowledge base."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def reset(self) -> None:
|
||||
"""Reset the knowledge base."""
|
||||
pass
|
||||
|
||||
@@ -4,12 +4,11 @@ import io
|
||||
import logging
|
||||
import os
|
||||
import shutil
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import chromadb
|
||||
import chromadb.errors
|
||||
from chromadb.api import ClientAPI
|
||||
from chromadb.api.types import OneOrMany
|
||||
from chromadb.config import Settings
|
||||
|
||||
from crewai.knowledge.storage.base_knowledge_storage import BaseKnowledgeStorage
|
||||
@@ -19,6 +18,9 @@ from crewai.utilities.constants import KNOWLEDGE_DIRECTORY
|
||||
from crewai.utilities.logger import Logger
|
||||
from crewai.utilities.paths import db_storage_path
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from chromadb.api.types import OneOrMany
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def suppress_logging(
|
||||
@@ -38,30 +40,29 @@ def suppress_logging(
|
||||
|
||||
|
||||
class KnowledgeStorage(BaseKnowledgeStorage):
|
||||
"""
|
||||
Extends Storage to handle embeddings for memory entries, improving
|
||||
"""Extends Storage to handle embeddings for memory entries, improving
|
||||
search efficiency.
|
||||
"""
|
||||
|
||||
collection: Optional[chromadb.Collection] = None
|
||||
collection_name: Optional[str] = "knowledge"
|
||||
app: Optional[ClientAPI] = None
|
||||
collection: chromadb.Collection | None = None
|
||||
collection_name: str | None = "knowledge"
|
||||
app: ClientAPI | None = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
embedder: Optional[Dict[str, Any]] = None,
|
||||
collection_name: Optional[str] = None,
|
||||
):
|
||||
embedder: dict[str, Any] | None = None,
|
||||
collection_name: str | None = None,
|
||||
) -> None:
|
||||
self.collection_name = collection_name
|
||||
self._set_embedder_config(embedder)
|
||||
|
||||
def search(
|
||||
self,
|
||||
query: List[str],
|
||||
query: list[str],
|
||||
limit: int = 3,
|
||||
filter: Optional[dict] = None,
|
||||
filter: dict | None = None,
|
||||
score_threshold: float = 0.35,
|
||||
) -> List[Dict[str, Any]]:
|
||||
) -> list[dict[str, Any]]:
|
||||
with suppress_logging():
|
||||
if self.collection:
|
||||
fetched = self.collection.query(
|
||||
@@ -80,10 +81,10 @@ class KnowledgeStorage(BaseKnowledgeStorage):
|
||||
if result["score"] >= score_threshold:
|
||||
results.append(result)
|
||||
return results
|
||||
else:
|
||||
raise Exception("Collection not initialized")
|
||||
msg = "Collection not initialized"
|
||||
raise Exception(msg)
|
||||
|
||||
def initialize_knowledge_storage(self):
|
||||
def initialize_knowledge_storage(self) -> None:
|
||||
base_path = os.path.join(db_storage_path(), "knowledge")
|
||||
chroma_client = chromadb.PersistentClient(
|
||||
path=base_path,
|
||||
@@ -104,11 +105,13 @@ class KnowledgeStorage(BaseKnowledgeStorage):
|
||||
embedding_function=self.embedder,
|
||||
)
|
||||
else:
|
||||
raise Exception("Vector Database Client not initialized")
|
||||
msg = "Vector Database Client not initialized"
|
||||
raise Exception(msg)
|
||||
except Exception:
|
||||
raise Exception("Failed to create or get collection")
|
||||
msg = "Failed to create or get collection"
|
||||
raise Exception(msg)
|
||||
|
||||
def reset(self):
|
||||
def reset(self) -> None:
|
||||
base_path = os.path.join(db_storage_path(), KNOWLEDGE_DIRECTORY)
|
||||
if not self.app:
|
||||
self.app = chromadb.PersistentClient(
|
||||
@@ -123,11 +126,12 @@ class KnowledgeStorage(BaseKnowledgeStorage):
|
||||
|
||||
def save(
|
||||
self,
|
||||
documents: List[str],
|
||||
metadata: Optional[Union[Dict[str, Any], List[Dict[str, Any]]]] = None,
|
||||
):
|
||||
documents: list[str],
|
||||
metadata: dict[str, Any] | list[dict[str, Any]] | None = None,
|
||||
) -> None:
|
||||
if not self.collection:
|
||||
raise Exception("Collection not initialized")
|
||||
msg = "Collection not initialized"
|
||||
raise Exception(msg)
|
||||
|
||||
try:
|
||||
# Create a dictionary to store unique documents
|
||||
@@ -156,7 +160,7 @@ class KnowledgeStorage(BaseKnowledgeStorage):
|
||||
filtered_ids.append(doc_id)
|
||||
|
||||
# If we have no metadata at all, set it to None
|
||||
final_metadata: Optional[OneOrMany[chromadb.Metadata]] = (
|
||||
final_metadata: OneOrMany[chromadb.Metadata] | None = (
|
||||
None if all(m is None for m in filtered_metadata) else filtered_metadata
|
||||
)
|
||||
|
||||
@@ -171,10 +175,13 @@ class KnowledgeStorage(BaseKnowledgeStorage):
|
||||
"Embedding dimension mismatch. This usually happens when mixing different embedding models. Try resetting the collection using `crewai reset-memories -a`",
|
||||
"red",
|
||||
)
|
||||
raise ValueError(
|
||||
msg = (
|
||||
"Embedding dimension mismatch. Make sure you're using the same embedding model "
|
||||
"across all operations with this collection."
|
||||
"Try resetting the collection using `crewai reset-memories -a`"
|
||||
)
|
||||
raise ValueError(
|
||||
msg,
|
||||
) from e
|
||||
except Exception as e:
|
||||
Logger(verbose=True).log("error", f"Failed to upsert documents: {e}", "red")
|
||||
@@ -186,15 +193,16 @@ class KnowledgeStorage(BaseKnowledgeStorage):
|
||||
)
|
||||
|
||||
return OpenAIEmbeddingFunction(
|
||||
api_key=os.getenv("OPENAI_API_KEY"), model_name="text-embedding-3-small"
|
||||
api_key=os.getenv("OPENAI_API_KEY"), model_name="text-embedding-3-small",
|
||||
)
|
||||
|
||||
def _set_embedder_config(self, embedder: Optional[Dict[str, Any]] = None) -> None:
|
||||
def _set_embedder_config(self, embedder: dict[str, Any] | None = None) -> None:
|
||||
"""Set the embedding configuration for the knowledge storage.
|
||||
|
||||
Args:
|
||||
embedder_config (Optional[Dict[str, Any]]): Configuration dictionary for the embedder.
|
||||
If None or empty, defaults to the default embedding function.
|
||||
|
||||
"""
|
||||
self.embedder = (
|
||||
EmbeddingConfigurator().configure_embedder(embedder)
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from typing import Any, Dict, List
|
||||
from typing import Any
|
||||
|
||||
|
||||
def extract_knowledge_context(knowledge_snippets: List[Dict[str, Any]]) -> str:
|
||||
def extract_knowledge_context(knowledge_snippets: list[dict[str, Any]]) -> str:
|
||||
"""Extract knowledge from the task prompt."""
|
||||
valid_snippets = [
|
||||
result["context"]
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import asyncio
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import Any, Callable, Dict, List, Optional, Type, Union, cast
|
||||
from collections.abc import Callable
|
||||
from typing import Any, cast
|
||||
|
||||
from pydantic import BaseModel, Field, InstanceOf, PrivateAttr, model_validator
|
||||
|
||||
@@ -35,7 +35,7 @@ from crewai.utilities.agent_utils import (
|
||||
render_text_description_and_args,
|
||||
show_agent_logs,
|
||||
)
|
||||
from crewai.utilities.converter import convert_to_model, generate_model_description
|
||||
from crewai.utilities.converter import generate_model_description
|
||||
from crewai.utilities.events.agent_events import (
|
||||
LiteAgentExecutionCompletedEvent,
|
||||
LiteAgentExecutionErrorEvent,
|
||||
@@ -60,15 +60,15 @@ class LiteAgentOutput(BaseModel):
|
||||
model_config = {"arbitrary_types_allowed": True}
|
||||
|
||||
raw: str = Field(description="Raw output of the agent", default="")
|
||||
pydantic: Optional[BaseModel] = Field(
|
||||
description="Pydantic output of the agent", default=None
|
||||
pydantic: BaseModel | None = Field(
|
||||
description="Pydantic output of the agent", default=None,
|
||||
)
|
||||
agent_role: str = Field(description="Role of the agent that produced this output")
|
||||
usage_metrics: Optional[Dict[str, Any]] = Field(
|
||||
description="Token usage metrics for this execution", default=None
|
||||
usage_metrics: dict[str, Any] | None = Field(
|
||||
description="Token usage metrics for this execution", default=None,
|
||||
)
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
"""Convert pydantic_output to a dictionary."""
|
||||
if self.pydantic:
|
||||
return self.pydantic.model_dump()
|
||||
@@ -82,8 +82,7 @@ class LiteAgentOutput(BaseModel):
|
||||
|
||||
|
||||
class LiteAgent(FlowTrackable, BaseModel):
|
||||
"""
|
||||
A lightweight agent that can process messages and use tools.
|
||||
"""A lightweight agent that can process messages and use tools.
|
||||
|
||||
This agent is simpler than the full Agent class, focusing on direct execution
|
||||
rather than task delegation. It's designed to be used for simple interactions
|
||||
@@ -99,6 +98,7 @@ class LiteAgent(FlowTrackable, BaseModel):
|
||||
max_iterations: Maximum number of iterations for tool usage.
|
||||
max_execution_time: Maximum execution time in seconds.
|
||||
response_format: Optional Pydantic model for structured output.
|
||||
|
||||
"""
|
||||
|
||||
model_config = {"arbitrary_types_allowed": True}
|
||||
@@ -107,19 +107,19 @@ class LiteAgent(FlowTrackable, BaseModel):
|
||||
role: str = Field(description="Role of the agent")
|
||||
goal: str = Field(description="Goal of the agent")
|
||||
backstory: str = Field(description="Backstory of the agent")
|
||||
llm: Optional[Union[str, InstanceOf[LLM], Any]] = Field(
|
||||
default=None, description="Language model that will run the agent"
|
||||
llm: str | InstanceOf[LLM] | Any | None = Field(
|
||||
default=None, description="Language model that will run the agent",
|
||||
)
|
||||
tools: List[BaseTool] = Field(
|
||||
default_factory=list, description="Tools at agent's disposal"
|
||||
tools: list[BaseTool] = Field(
|
||||
default_factory=list, description="Tools at agent's disposal",
|
||||
)
|
||||
|
||||
# Execution Control Properties
|
||||
max_iterations: int = Field(
|
||||
default=15, description="Maximum number of iterations for tool usage"
|
||||
default=15, description="Maximum number of iterations for tool usage",
|
||||
)
|
||||
max_execution_time: Optional[int] = Field(
|
||||
default=None, description="Maximum execution time in seconds"
|
||||
max_execution_time: int | None = Field(
|
||||
default=None, description="Maximum execution time in seconds",
|
||||
)
|
||||
respect_context_window: bool = Field(
|
||||
default=True,
|
||||
@@ -129,38 +129,38 @@ class LiteAgent(FlowTrackable, BaseModel):
|
||||
default=True,
|
||||
description="Whether to use stop words to prevent the LLM from using tools",
|
||||
)
|
||||
request_within_rpm_limit: Optional[Callable[[], bool]] = Field(
|
||||
request_within_rpm_limit: Callable[[], bool] | None = Field(
|
||||
default=None,
|
||||
description="Callback to check if the request is within the RPM limit",
|
||||
)
|
||||
i18n: I18N = Field(default=I18N(), description="Internationalization settings.")
|
||||
|
||||
# Output and Formatting Properties
|
||||
response_format: Optional[Type[BaseModel]] = Field(
|
||||
default=None, description="Pydantic model for structured output"
|
||||
response_format: type[BaseModel] | None = Field(
|
||||
default=None, description="Pydantic model for structured output",
|
||||
)
|
||||
verbose: bool = Field(
|
||||
default=False, description="Whether to print execution details"
|
||||
default=False, description="Whether to print execution details",
|
||||
)
|
||||
callbacks: List[Callable] = Field(
|
||||
default=[], description="Callbacks to be used for the agent"
|
||||
callbacks: list[Callable] = Field(
|
||||
default=[], description="Callbacks to be used for the agent",
|
||||
)
|
||||
|
||||
# State and Results
|
||||
tools_results: List[Dict[str, Any]] = Field(
|
||||
default=[], description="Results of the tools used by the agent."
|
||||
tools_results: list[dict[str, Any]] = Field(
|
||||
default=[], description="Results of the tools used by the agent.",
|
||||
)
|
||||
|
||||
# Reference of Agent
|
||||
original_agent: Optional[BaseAgent] = Field(
|
||||
default=None, description="Reference to the agent that created this LiteAgent"
|
||||
original_agent: BaseAgent | None = Field(
|
||||
default=None, description="Reference to the agent that created this LiteAgent",
|
||||
)
|
||||
# Private Attributes
|
||||
_parsed_tools: List[CrewStructuredTool] = PrivateAttr(default_factory=list)
|
||||
_parsed_tools: list[CrewStructuredTool] = PrivateAttr(default_factory=list)
|
||||
_token_process: TokenProcess = PrivateAttr(default_factory=TokenProcess)
|
||||
_cache_handler: CacheHandler = PrivateAttr(default_factory=CacheHandler)
|
||||
_key: str = PrivateAttr(default_factory=lambda: str(uuid.uuid4()))
|
||||
_messages: List[Dict[str, str]] = PrivateAttr(default_factory=list)
|
||||
_messages: list[dict[str, str]] = PrivateAttr(default_factory=list)
|
||||
_iterations: int = PrivateAttr(default=0)
|
||||
_printer: Printer = PrivateAttr(default_factory=Printer)
|
||||
|
||||
@@ -169,7 +169,8 @@ class LiteAgent(FlowTrackable, BaseModel):
|
||||
"""Set up the LLM and other components after initialization."""
|
||||
self.llm = create_llm(self.llm)
|
||||
if not isinstance(self.llm, LLM):
|
||||
raise ValueError("Unable to create LLM instance")
|
||||
msg = "Unable to create LLM instance"
|
||||
raise ValueError(msg)
|
||||
|
||||
# Initialize callbacks
|
||||
token_callback = TokenCalcHandler(token_cost_process=self._token_process)
|
||||
@@ -194,9 +195,8 @@ class LiteAgent(FlowTrackable, BaseModel):
|
||||
"""Return the original role for compatibility with tool interfaces."""
|
||||
return self.role
|
||||
|
||||
def kickoff(self, messages: Union[str, List[Dict[str, str]]]) -> LiteAgentOutput:
|
||||
"""
|
||||
Execute the agent with the given messages.
|
||||
def kickoff(self, messages: str | list[dict[str, str]]) -> LiteAgentOutput:
|
||||
"""Execute the agent with the given messages.
|
||||
|
||||
Args:
|
||||
messages: Either a string query or a list of message dictionaries.
|
||||
@@ -205,6 +205,7 @@ class LiteAgent(FlowTrackable, BaseModel):
|
||||
|
||||
Returns:
|
||||
LiteAgentOutput: The result of the agent execution.
|
||||
|
||||
"""
|
||||
# Create agent info for event emission
|
||||
agent_info = {
|
||||
@@ -235,18 +236,18 @@ class LiteAgent(FlowTrackable, BaseModel):
|
||||
|
||||
# Execute the agent using invoke loop
|
||||
agent_finish = self._invoke_loop()
|
||||
formatted_result: Optional[BaseModel] = None
|
||||
formatted_result: BaseModel | None = None
|
||||
if self.response_format:
|
||||
try:
|
||||
# Cast to BaseModel to ensure type safety
|
||||
result = self.response_format.model_validate_json(
|
||||
agent_finish.output
|
||||
agent_finish.output,
|
||||
)
|
||||
if isinstance(result, BaseModel):
|
||||
formatted_result = result
|
||||
except Exception as e:
|
||||
self._printer.print(
|
||||
content=f"Failed to parse output into response format: {str(e)}",
|
||||
content=f"Failed to parse output into response format: {e!s}",
|
||||
color="yellow",
|
||||
)
|
||||
|
||||
@@ -286,13 +287,12 @@ class LiteAgent(FlowTrackable, BaseModel):
|
||||
error=str(e),
|
||||
),
|
||||
)
|
||||
raise e
|
||||
raise
|
||||
|
||||
async def kickoff_async(
|
||||
self, messages: Union[str, List[Dict[str, str]]]
|
||||
self, messages: str | list[dict[str, str]],
|
||||
) -> LiteAgentOutput:
|
||||
"""
|
||||
Execute the agent asynchronously with the given messages.
|
||||
"""Execute the agent asynchronously with the given messages.
|
||||
|
||||
Args:
|
||||
messages: Either a string query or a list of message dictionaries.
|
||||
@@ -301,6 +301,7 @@ class LiteAgent(FlowTrackable, BaseModel):
|
||||
|
||||
Returns:
|
||||
LiteAgentOutput: The result of the agent execution.
|
||||
|
||||
"""
|
||||
return await asyncio.to_thread(self.kickoff, messages)
|
||||
|
||||
@@ -319,7 +320,7 @@ class LiteAgent(FlowTrackable, BaseModel):
|
||||
else:
|
||||
# Use the prompt template for agents without tools
|
||||
base_prompt = self.i18n.slice(
|
||||
"lite_agent_system_prompt_without_tools"
|
||||
"lite_agent_system_prompt_without_tools",
|
||||
).format(
|
||||
role=self.role,
|
||||
backstory=self.backstory,
|
||||
@@ -330,14 +331,14 @@ class LiteAgent(FlowTrackable, BaseModel):
|
||||
if self.response_format:
|
||||
schema = generate_model_description(self.response_format)
|
||||
base_prompt += self.i18n.slice("lite_agent_response_format").format(
|
||||
response_format=schema
|
||||
response_format=schema,
|
||||
)
|
||||
|
||||
return base_prompt
|
||||
|
||||
def _format_messages(
|
||||
self, messages: Union[str, List[Dict[str, str]]]
|
||||
) -> List[Dict[str, str]]:
|
||||
self, messages: str | list[dict[str, str]],
|
||||
) -> list[dict[str, str]]:
|
||||
"""Format messages for the LLM."""
|
||||
if isinstance(messages, str):
|
||||
messages = [{"role": "user", "content": messages}]
|
||||
@@ -353,11 +354,11 @@ class LiteAgent(FlowTrackable, BaseModel):
|
||||
return formatted_messages
|
||||
|
||||
def _invoke_loop(self) -> AgentFinish:
|
||||
"""
|
||||
Run the agent's thought process until it reaches a conclusion or max iterations.
|
||||
"""Run the agent's thought process until it reaches a conclusion or max iterations.
|
||||
|
||||
Returns:
|
||||
AgentFinish: The final result of the agent execution.
|
||||
|
||||
"""
|
||||
# Execute the agent loop
|
||||
formatted_answer = None
|
||||
@@ -369,7 +370,7 @@ class LiteAgent(FlowTrackable, BaseModel):
|
||||
printer=self._printer,
|
||||
i18n=self.i18n,
|
||||
messages=self._messages,
|
||||
llm=cast(LLM, self.llm),
|
||||
llm=cast("LLM", self.llm),
|
||||
callbacks=self._callbacks,
|
||||
)
|
||||
|
||||
@@ -387,7 +388,7 @@ class LiteAgent(FlowTrackable, BaseModel):
|
||||
|
||||
try:
|
||||
answer = get_llm_response(
|
||||
llm=cast(LLM, self.llm),
|
||||
llm=cast("LLM", self.llm),
|
||||
messages=self._messages,
|
||||
callbacks=self._callbacks,
|
||||
printer=self._printer,
|
||||
@@ -407,7 +408,7 @@ class LiteAgent(FlowTrackable, BaseModel):
|
||||
self,
|
||||
event=LLMCallFailedEvent(error=str(e)),
|
||||
)
|
||||
raise e
|
||||
raise
|
||||
|
||||
formatted_answer = process_llm_response(answer, self.use_stop_words)
|
||||
|
||||
@@ -421,8 +422,8 @@ class LiteAgent(FlowTrackable, BaseModel):
|
||||
agent_role=self.role,
|
||||
agent=self.original_agent,
|
||||
)
|
||||
except Exception as e:
|
||||
raise e
|
||||
except Exception:
|
||||
raise
|
||||
|
||||
formatted_answer = handle_agent_action_core(
|
||||
formatted_answer=formatted_answer,
|
||||
@@ -443,20 +444,19 @@ class LiteAgent(FlowTrackable, BaseModel):
|
||||
except Exception as e:
|
||||
if e.__class__.__module__.startswith("litellm"):
|
||||
# Do not retry on litellm errors
|
||||
raise e
|
||||
raise
|
||||
if is_context_length_exceeded(e):
|
||||
handle_context_length(
|
||||
respect_context_window=self.respect_context_window,
|
||||
printer=self._printer,
|
||||
messages=self._messages,
|
||||
llm=cast(LLM, self.llm),
|
||||
llm=cast("LLM", self.llm),
|
||||
callbacks=self._callbacks,
|
||||
i18n=self.i18n,
|
||||
)
|
||||
continue
|
||||
else:
|
||||
handle_unknown_error(self._printer, e)
|
||||
raise e
|
||||
handle_unknown_error(self._printer, e)
|
||||
raise
|
||||
|
||||
finally:
|
||||
self._iterations += 1
|
||||
@@ -465,7 +465,7 @@ class LiteAgent(FlowTrackable, BaseModel):
|
||||
self._show_logs(formatted_answer)
|
||||
return formatted_answer
|
||||
|
||||
def _show_logs(self, formatted_answer: Union[AgentAction, AgentFinish]):
|
||||
def _show_logs(self, formatted_answer: AgentAction | AgentFinish) -> None:
|
||||
"""Show logs for the agent's execution."""
|
||||
show_agent_logs(
|
||||
printer=self._printer,
|
||||
|
||||
@@ -5,17 +5,11 @@ import sys
|
||||
import threading
|
||||
import warnings
|
||||
from collections import defaultdict
|
||||
from contextlib import contextmanager, redirect_stderr, redirect_stdout
|
||||
from contextlib import contextmanager
|
||||
from typing import (
|
||||
Any,
|
||||
DefaultDict,
|
||||
Dict,
|
||||
List,
|
||||
Literal,
|
||||
Optional,
|
||||
Type,
|
||||
TypedDict,
|
||||
Union,
|
||||
cast,
|
||||
)
|
||||
|
||||
@@ -43,9 +37,6 @@ with warnings.catch_warnings():
|
||||
from litellm.utils import supports_response_schema
|
||||
|
||||
|
||||
import io
|
||||
from typing import TextIO
|
||||
|
||||
from crewai.llms.base_llm import BaseLLM
|
||||
from crewai.utilities.events import crewai_event_bus
|
||||
from crewai.utilities.exceptions.context_window_exceeding_exception import (
|
||||
@@ -55,17 +46,12 @@ from crewai.utilities.exceptions.context_window_exceeding_exception import (
|
||||
load_dotenv()
|
||||
|
||||
|
||||
class FilteredStream(io.TextIOBase):
|
||||
_lock = None
|
||||
|
||||
def __init__(self, original_stream: TextIO):
|
||||
class FilteredStream:
|
||||
def __init__(self, original_stream) -> None:
|
||||
self._original_stream = original_stream
|
||||
self._lock = threading.Lock()
|
||||
|
||||
def write(self, s: str) -> int:
|
||||
if not self._lock:
|
||||
self._lock = threading.Lock()
|
||||
|
||||
def write(self, s) -> int:
|
||||
with self._lock:
|
||||
# Filter out extraneous messages from LiteLLM
|
||||
if (
|
||||
@@ -216,26 +202,30 @@ def suppress_warnings():
|
||||
with warnings.catch_warnings():
|
||||
warnings.filterwarnings("ignore")
|
||||
warnings.filterwarnings(
|
||||
"ignore", message="open_text is deprecated*", category=DeprecationWarning
|
||||
"ignore", message="open_text is deprecated*", category=DeprecationWarning,
|
||||
)
|
||||
|
||||
# Redirect stdout and stderr
|
||||
with (
|
||||
redirect_stdout(FilteredStream(sys.stdout)),
|
||||
redirect_stderr(FilteredStream(sys.stderr)),
|
||||
):
|
||||
old_stdout = sys.stdout
|
||||
old_stderr = sys.stderr
|
||||
sys.stdout = FilteredStream(old_stdout)
|
||||
sys.stderr = FilteredStream(old_stderr)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
sys.stdout = old_stdout
|
||||
sys.stderr = old_stderr
|
||||
|
||||
|
||||
class Delta(TypedDict):
|
||||
content: Optional[str]
|
||||
role: Optional[str]
|
||||
content: str | None
|
||||
role: str | None
|
||||
|
||||
|
||||
class StreamingChoices(TypedDict):
|
||||
delta: Delta
|
||||
index: int
|
||||
finish_reason: Optional[str]
|
||||
finish_reason: str | None
|
||||
|
||||
|
||||
class FunctionArgs(BaseModel):
|
||||
@@ -251,29 +241,31 @@ class LLM(BaseLLM):
|
||||
def __init__(
|
||||
self,
|
||||
model: str,
|
||||
timeout: Optional[Union[float, int]] = None,
|
||||
temperature: Optional[float] = None,
|
||||
top_p: Optional[float] = None,
|
||||
n: Optional[int] = None,
|
||||
stop: Optional[Union[str, List[str]]] = None,
|
||||
max_completion_tokens: Optional[int] = None,
|
||||
max_tokens: Optional[int] = None,
|
||||
presence_penalty: Optional[float] = None,
|
||||
frequency_penalty: Optional[float] = None,
|
||||
logit_bias: Optional[Dict[int, float]] = None,
|
||||
response_format: Optional[Type[BaseModel]] = None,
|
||||
seed: Optional[int] = None,
|
||||
logprobs: Optional[int] = None,
|
||||
top_logprobs: Optional[int] = None,
|
||||
base_url: Optional[str] = None,
|
||||
api_base: Optional[str] = None,
|
||||
api_version: Optional[str] = None,
|
||||
api_key: Optional[str] = None,
|
||||
callbacks: List[Any] = [],
|
||||
reasoning_effort: Optional[Literal["none", "low", "medium", "high"]] = None,
|
||||
timeout: float | None = None,
|
||||
temperature: float | None = None,
|
||||
top_p: float | None = None,
|
||||
n: int | None = None,
|
||||
stop: str | list[str] | None = None,
|
||||
max_completion_tokens: int | None = None,
|
||||
max_tokens: int | None = None,
|
||||
presence_penalty: float | None = None,
|
||||
frequency_penalty: float | None = None,
|
||||
logit_bias: dict[int, float] | None = None,
|
||||
response_format: type[BaseModel] | None = None,
|
||||
seed: int | None = None,
|
||||
logprobs: int | None = None,
|
||||
top_logprobs: int | None = None,
|
||||
base_url: str | None = None,
|
||||
api_base: str | None = None,
|
||||
api_version: str | None = None,
|
||||
api_key: str | None = None,
|
||||
callbacks: list[Any] | None = None,
|
||||
reasoning_effort: Literal["none", "low", "medium", "high"] | None = None,
|
||||
stream: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
) -> None:
|
||||
if callbacks is None:
|
||||
callbacks = []
|
||||
self.model = model
|
||||
self.timeout = timeout
|
||||
self.temperature = temperature
|
||||
@@ -303,7 +295,7 @@ class LLM(BaseLLM):
|
||||
|
||||
# Normalize self.stop to always be a List[str]
|
||||
if stop is None:
|
||||
self.stop: List[str] = []
|
||||
self.stop: list[str] = []
|
||||
elif isinstance(stop, str):
|
||||
self.stop = [stop]
|
||||
else:
|
||||
@@ -320,15 +312,16 @@ class LLM(BaseLLM):
|
||||
|
||||
Returns:
|
||||
bool: True if the model is from Anthropic, False otherwise.
|
||||
|
||||
"""
|
||||
ANTHROPIC_PREFIXES = ("anthropic/", "claude-", "claude/")
|
||||
return any(prefix in model.lower() for prefix in ANTHROPIC_PREFIXES)
|
||||
|
||||
def _prepare_completion_params(
|
||||
self,
|
||||
messages: Union[str, List[Dict[str, str]]],
|
||||
tools: Optional[List[dict]] = None,
|
||||
) -> Dict[str, Any]:
|
||||
messages: str | list[dict[str, str]],
|
||||
tools: list[dict] | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Prepare parameters for the completion call.
|
||||
|
||||
Args:
|
||||
@@ -339,6 +332,7 @@ class LLM(BaseLLM):
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: Parameters for the completion call
|
||||
|
||||
"""
|
||||
# --- 1) Format messages according to provider requirements
|
||||
if isinstance(messages, str):
|
||||
@@ -377,9 +371,9 @@ class LLM(BaseLLM):
|
||||
|
||||
def _handle_streaming_response(
|
||||
self,
|
||||
params: Dict[str, Any],
|
||||
callbacks: Optional[List[Any]] = None,
|
||||
available_functions: Optional[Dict[str, Any]] = None,
|
||||
params: dict[str, Any],
|
||||
callbacks: list[Any] | None = None,
|
||||
available_functions: dict[str, Any] | None = None,
|
||||
) -> str:
|
||||
"""Handle a streaming response from the LLM.
|
||||
|
||||
@@ -393,6 +387,7 @@ class LLM(BaseLLM):
|
||||
|
||||
Raises:
|
||||
Exception: If no content is received from the streaming response
|
||||
|
||||
"""
|
||||
# --- 1) Initialize response tracking
|
||||
full_response = ""
|
||||
@@ -401,8 +396,8 @@ class LLM(BaseLLM):
|
||||
usage_info = None
|
||||
tool_calls = None
|
||||
|
||||
accumulated_tool_args: DefaultDict[int, AccumulatedToolArgs] = defaultdict(
|
||||
AccumulatedToolArgs
|
||||
accumulated_tool_args: defaultdict[int, AccumulatedToolArgs] = defaultdict(
|
||||
AccumulatedToolArgs,
|
||||
)
|
||||
|
||||
# --- 2) Make sure stream is set to True and include usage metrics
|
||||
@@ -426,16 +421,16 @@ class LLM(BaseLLM):
|
||||
choices = chunk["choices"]
|
||||
elif hasattr(chunk, "choices"):
|
||||
# Check if choices is not a type but an actual attribute with value
|
||||
if not isinstance(getattr(chunk, "choices"), type):
|
||||
choices = getattr(chunk, "choices")
|
||||
if not isinstance(chunk.choices, type):
|
||||
choices = chunk.choices
|
||||
|
||||
# Try to extract usage information if available
|
||||
if isinstance(chunk, dict) and "usage" in chunk:
|
||||
usage_info = chunk["usage"]
|
||||
elif hasattr(chunk, "usage"):
|
||||
# Check if usage is not a type but an actual attribute with value
|
||||
if not isinstance(getattr(chunk, "usage"), type):
|
||||
usage_info = getattr(chunk, "usage")
|
||||
if not isinstance(chunk.usage, type):
|
||||
usage_info = chunk.usage
|
||||
|
||||
if choices and len(choices) > 0:
|
||||
choice = choices[0]
|
||||
@@ -445,7 +440,7 @@ class LLM(BaseLLM):
|
||||
if isinstance(choice, dict) and "delta" in choice:
|
||||
delta = choice["delta"]
|
||||
elif hasattr(choice, "delta"):
|
||||
delta = getattr(choice, "delta")
|
||||
delta = choice.delta
|
||||
|
||||
# Extract content from delta
|
||||
if delta:
|
||||
@@ -455,7 +450,7 @@ class LLM(BaseLLM):
|
||||
chunk_content = delta["content"]
|
||||
# Handle object format
|
||||
elif hasattr(delta, "content"):
|
||||
chunk_content = getattr(delta, "content")
|
||||
chunk_content = delta.content
|
||||
|
||||
# Handle case where content might be None or empty
|
||||
if chunk_content is None and isinstance(delta, dict):
|
||||
@@ -493,21 +488,21 @@ class LLM(BaseLLM):
|
||||
# --- 4) Fallback to non-streaming if no content received
|
||||
if not full_response.strip() and chunk_count == 0:
|
||||
logging.warning(
|
||||
"No chunks received in streaming response, falling back to non-streaming"
|
||||
"No chunks received in streaming response, falling back to non-streaming",
|
||||
)
|
||||
non_streaming_params = params.copy()
|
||||
non_streaming_params["stream"] = False
|
||||
non_streaming_params.pop(
|
||||
"stream_options", None
|
||||
"stream_options", None,
|
||||
) # Remove stream_options for non-streaming call
|
||||
return self._handle_non_streaming_response(
|
||||
non_streaming_params, callbacks, available_functions
|
||||
non_streaming_params, callbacks, available_functions,
|
||||
)
|
||||
|
||||
# --- 5) Handle empty response with chunks
|
||||
if not full_response.strip() and chunk_count > 0:
|
||||
logging.warning(
|
||||
f"Received {chunk_count} chunks but no content was extracted"
|
||||
f"Received {chunk_count} chunks but no content was extracted",
|
||||
)
|
||||
if last_chunk is not None:
|
||||
try:
|
||||
@@ -516,8 +511,8 @@ class LLM(BaseLLM):
|
||||
if isinstance(last_chunk, dict) and "choices" in last_chunk:
|
||||
choices = last_chunk["choices"]
|
||||
elif hasattr(last_chunk, "choices"):
|
||||
if not isinstance(getattr(last_chunk, "choices"), type):
|
||||
choices = getattr(last_chunk, "choices")
|
||||
if not isinstance(last_chunk.choices, type):
|
||||
choices = last_chunk.choices
|
||||
|
||||
if choices and len(choices) > 0:
|
||||
choice = choices[0]
|
||||
@@ -527,30 +522,31 @@ class LLM(BaseLLM):
|
||||
if isinstance(choice, dict) and "message" in choice:
|
||||
message = choice["message"]
|
||||
elif hasattr(choice, "message"):
|
||||
message = getattr(choice, "message")
|
||||
message = choice.message
|
||||
|
||||
if message:
|
||||
content = None
|
||||
if isinstance(message, dict) and "content" in message:
|
||||
content = message["content"]
|
||||
elif hasattr(message, "content"):
|
||||
content = getattr(message, "content")
|
||||
content = message.content
|
||||
|
||||
if content:
|
||||
full_response = content
|
||||
logging.info(
|
||||
f"Extracted content from last chunk message: {full_response}"
|
||||
f"Extracted content from last chunk message: {full_response}",
|
||||
)
|
||||
except Exception as e:
|
||||
logging.debug(f"Error extracting content from last chunk: {e}")
|
||||
logging.debug(
|
||||
f"Last chunk format: {type(last_chunk)}, content: {last_chunk}"
|
||||
f"Last chunk format: {type(last_chunk)}, content: {last_chunk}",
|
||||
)
|
||||
|
||||
# --- 6) If still empty, raise an error instead of using a default response
|
||||
if not full_response.strip() and len(accumulated_tool_args) == 0:
|
||||
msg = "No content received from streaming response. Received empty chunks or failed to extract content."
|
||||
raise Exception(
|
||||
"No content received from streaming response. Received empty chunks or failed to extract content."
|
||||
msg,
|
||||
)
|
||||
|
||||
# --- 7) Check for tool calls in the final response
|
||||
@@ -561,8 +557,8 @@ class LLM(BaseLLM):
|
||||
if isinstance(last_chunk, dict) and "choices" in last_chunk:
|
||||
choices = last_chunk["choices"]
|
||||
elif hasattr(last_chunk, "choices"):
|
||||
if not isinstance(getattr(last_chunk, "choices"), type):
|
||||
choices = getattr(last_chunk, "choices")
|
||||
if not isinstance(last_chunk.choices, type):
|
||||
choices = last_chunk.choices
|
||||
|
||||
if choices and len(choices) > 0:
|
||||
choice = choices[0]
|
||||
@@ -571,13 +567,13 @@ class LLM(BaseLLM):
|
||||
if isinstance(choice, dict) and "message" in choice:
|
||||
message = choice["message"]
|
||||
elif hasattr(choice, "message"):
|
||||
message = getattr(choice, "message")
|
||||
message = choice.message
|
||||
|
||||
if message:
|
||||
if isinstance(message, dict) and "tool_calls" in message:
|
||||
tool_calls = message["tool_calls"]
|
||||
elif hasattr(message, "tool_calls"):
|
||||
tool_calls = getattr(message, "tool_calls")
|
||||
tool_calls = message.tool_calls
|
||||
except Exception as e:
|
||||
logging.debug(f"Error checking for tool calls: {e}")
|
||||
# --- 8) If no tool calls or no available functions, return the text response directly
|
||||
@@ -607,9 +603,9 @@ class LLM(BaseLLM):
|
||||
# decide whether to summarize the content or abort based on the respect_context_window flag.
|
||||
raise LLMContextLengthExceededException(str(e))
|
||||
except Exception as e:
|
||||
logging.error(f"Error in streaming response: {str(e)}")
|
||||
logging.exception(f"Error in streaming response: {e!s}")
|
||||
if full_response.strip():
|
||||
logging.warning(f"Returning partial response despite error: {str(e)}")
|
||||
logging.warning(f"Returning partial response despite error: {e!s}")
|
||||
self._handle_emit_call_events(full_response, LLMCallType.LLM_CALL)
|
||||
return full_response
|
||||
|
||||
@@ -619,13 +615,14 @@ class LLM(BaseLLM):
|
||||
self,
|
||||
event=LLMCallFailedEvent(error=str(e)),
|
||||
)
|
||||
raise Exception(f"Failed to get streaming response: {str(e)}")
|
||||
msg = f"Failed to get streaming response: {e!s}"
|
||||
raise Exception(msg)
|
||||
|
||||
def _handle_streaming_tool_calls(
|
||||
self,
|
||||
tool_calls: List[ChatCompletionDeltaToolCall],
|
||||
accumulated_tool_args: DefaultDict[int, AccumulatedToolArgs],
|
||||
available_functions: Optional[Dict[str, Any]] = None,
|
||||
tool_calls: list[ChatCompletionDeltaToolCall],
|
||||
accumulated_tool_args: defaultdict[int, AccumulatedToolArgs],
|
||||
available_functions: dict[str, Any] | None = None,
|
||||
) -> None | str:
|
||||
for tool_call in tool_calls:
|
||||
current_tool_accumulator = accumulated_tool_args[tool_call.index]
|
||||
@@ -664,9 +661,9 @@ class LLM(BaseLLM):
|
||||
|
||||
def _handle_streaming_callbacks(
|
||||
self,
|
||||
callbacks: Optional[List[Any]],
|
||||
usage_info: Optional[Dict[str, Any]],
|
||||
last_chunk: Optional[Any],
|
||||
callbacks: list[Any] | None,
|
||||
usage_info: dict[str, Any] | None,
|
||||
last_chunk: Any | None,
|
||||
) -> None:
|
||||
"""Handle callbacks with usage info for streaming responses.
|
||||
|
||||
@@ -674,6 +671,7 @@ class LLM(BaseLLM):
|
||||
callbacks: Optional list of callback functions
|
||||
usage_info: Usage information collected during streaming
|
||||
last_chunk: The last chunk received from the streaming response
|
||||
|
||||
"""
|
||||
if callbacks and len(callbacks) > 0:
|
||||
for callback in callbacks:
|
||||
@@ -690,9 +688,9 @@ class LLM(BaseLLM):
|
||||
usage_info = last_chunk["usage"]
|
||||
elif hasattr(last_chunk, "usage"):
|
||||
if not isinstance(
|
||||
getattr(last_chunk, "usage"), type
|
||||
last_chunk.usage, type,
|
||||
):
|
||||
usage_info = getattr(last_chunk, "usage")
|
||||
usage_info = last_chunk.usage
|
||||
except Exception as e:
|
||||
logging.debug(f"Error extracting usage info: {e}")
|
||||
|
||||
@@ -706,9 +704,9 @@ class LLM(BaseLLM):
|
||||
|
||||
def _handle_non_streaming_response(
|
||||
self,
|
||||
params: Dict[str, Any],
|
||||
callbacks: Optional[List[Any]] = None,
|
||||
available_functions: Optional[Dict[str, Any]] = None,
|
||||
params: dict[str, Any],
|
||||
callbacks: list[Any] | None = None,
|
||||
available_functions: dict[str, Any] | None = None,
|
||||
) -> str:
|
||||
"""Handle a non-streaming response from the LLM.
|
||||
|
||||
@@ -719,6 +717,7 @@ class LLM(BaseLLM):
|
||||
|
||||
Returns:
|
||||
str: The response text
|
||||
|
||||
"""
|
||||
# --- 1) Make the completion call
|
||||
try:
|
||||
@@ -733,7 +732,7 @@ class LLM(BaseLLM):
|
||||
raise LLMContextLengthExceededException(str(e))
|
||||
|
||||
# --- 2) Extract response message and content
|
||||
response_message = cast(Choices, cast(ModelResponse, response).choices)[
|
||||
response_message = cast("Choices", cast("ModelResponse", response).choices)[
|
||||
0
|
||||
].message
|
||||
text_response = response_message.content or ""
|
||||
@@ -770,9 +769,9 @@ class LLM(BaseLLM):
|
||||
|
||||
def _handle_tool_call(
|
||||
self,
|
||||
tool_calls: List[Any],
|
||||
available_functions: Optional[Dict[str, Any]] = None,
|
||||
) -> Optional[str]:
|
||||
tool_calls: list[Any],
|
||||
available_functions: dict[str, Any] | None = None,
|
||||
) -> str | None:
|
||||
"""Handle a tool call from the LLM.
|
||||
|
||||
Args:
|
||||
@@ -781,6 +780,7 @@ class LLM(BaseLLM):
|
||||
|
||||
Returns:
|
||||
Optional[str]: The result of the tool call, or None if no tool call was made
|
||||
|
||||
"""
|
||||
# --- 1) Validate tool calls and available functions
|
||||
if not tool_calls or not available_functions:
|
||||
@@ -807,23 +807,23 @@ class LLM(BaseLLM):
|
||||
except Exception as e:
|
||||
# --- 3.4) Handle execution errors
|
||||
fn = available_functions.get(
|
||||
function_name, lambda: None
|
||||
function_name, lambda: None,
|
||||
) # Ensure fn is always a callable
|
||||
logging.error(f"Error executing function '{function_name}': {e}")
|
||||
logging.exception(f"Error executing function '{function_name}': {e}")
|
||||
assert hasattr(crewai_event_bus, "emit")
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=LLMCallFailedEvent(error=f"Tool execution error: {str(e)}"),
|
||||
event=LLMCallFailedEvent(error=f"Tool execution error: {e!s}"),
|
||||
)
|
||||
return None
|
||||
|
||||
def call(
|
||||
self,
|
||||
messages: Union[str, List[Dict[str, str]]],
|
||||
tools: Optional[List[dict]] = None,
|
||||
callbacks: Optional[List[Any]] = None,
|
||||
available_functions: Optional[Dict[str, Any]] = None,
|
||||
) -> Union[str, Any]:
|
||||
messages: str | list[dict[str, str]],
|
||||
tools: list[dict] | None = None,
|
||||
callbacks: list[Any] | None = None,
|
||||
available_functions: dict[str, Any] | None = None,
|
||||
) -> str | Any:
|
||||
"""High-level LLM call method.
|
||||
|
||||
Args:
|
||||
@@ -846,6 +846,7 @@ class LLM(BaseLLM):
|
||||
TypeError: If messages format is invalid
|
||||
ValueError: If response format is not supported
|
||||
LLMContextLengthExceededException: If input exceeds model's context limit
|
||||
|
||||
"""
|
||||
# --- 1) Emit call started event
|
||||
assert hasattr(crewai_event_bus, "emit")
|
||||
@@ -884,12 +885,11 @@ class LLM(BaseLLM):
|
||||
# --- 7) Make the completion call and handle response
|
||||
if self.stream:
|
||||
return self._handle_streaming_response(
|
||||
params, callbacks, available_functions
|
||||
)
|
||||
else:
|
||||
return self._handle_non_streaming_response(
|
||||
params, callbacks, available_functions
|
||||
params, callbacks, available_functions,
|
||||
)
|
||||
return self._handle_non_streaming_response(
|
||||
params, callbacks, available_functions,
|
||||
)
|
||||
|
||||
except LLMContextLengthExceededException:
|
||||
# Re-raise LLMContextLengthExceededException as it should be handled
|
||||
@@ -902,15 +902,16 @@ class LLM(BaseLLM):
|
||||
self,
|
||||
event=LLMCallFailedEvent(error=str(e)),
|
||||
)
|
||||
logging.error(f"LiteLLM call failed: {str(e)}")
|
||||
logging.exception(f"LiteLLM call failed: {e!s}")
|
||||
raise
|
||||
|
||||
def _handle_emit_call_events(self, response: Any, call_type: LLMCallType):
|
||||
def _handle_emit_call_events(self, response: Any, call_type: LLMCallType) -> None:
|
||||
"""Handle the events for the LLM call.
|
||||
|
||||
Args:
|
||||
response (str): The response from the LLM call.
|
||||
call_type (str): The type of call, either "tool_call" or "llm_call".
|
||||
|
||||
"""
|
||||
assert hasattr(crewai_event_bus, "emit")
|
||||
crewai_event_bus.emit(
|
||||
@@ -919,8 +920,8 @@ class LLM(BaseLLM):
|
||||
)
|
||||
|
||||
def _format_messages_for_provider(
|
||||
self, messages: List[Dict[str, str]]
|
||||
) -> List[Dict[str, str]]:
|
||||
self, messages: list[dict[str, str]],
|
||||
) -> list[dict[str, str]]:
|
||||
"""Format messages according to provider requirements.
|
||||
|
||||
Args:
|
||||
@@ -933,15 +934,18 @@ class LLM(BaseLLM):
|
||||
|
||||
Raises:
|
||||
TypeError: If messages is None or contains invalid message format.
|
||||
|
||||
"""
|
||||
if messages is None:
|
||||
raise TypeError("Messages cannot be None")
|
||||
msg = "Messages cannot be None"
|
||||
raise TypeError(msg)
|
||||
|
||||
# Validate message format first
|
||||
for msg in messages:
|
||||
if not isinstance(msg, dict) or "role" not in msg or "content" not in msg:
|
||||
msg = "Invalid message format. Each message must be a dict with 'role' and 'content' keys"
|
||||
raise TypeError(
|
||||
"Invalid message format. Each message must be a dict with 'role' and 'content' keys"
|
||||
msg,
|
||||
)
|
||||
|
||||
# Handle O1 models specially
|
||||
@@ -951,7 +955,7 @@ class LLM(BaseLLM):
|
||||
# Convert system messages to assistant messages
|
||||
if msg["role"] == "system":
|
||||
formatted_messages.append(
|
||||
{"role": "assistant", "content": msg["content"]}
|
||||
{"role": "assistant", "content": msg["content"]},
|
||||
)
|
||||
else:
|
||||
formatted_messages.append(msg)
|
||||
@@ -979,9 +983,8 @@ class LLM(BaseLLM):
|
||||
|
||||
return messages
|
||||
|
||||
def _get_custom_llm_provider(self) -> Optional[str]:
|
||||
"""
|
||||
Derives the custom_llm_provider from the model string.
|
||||
def _get_custom_llm_provider(self) -> str | None:
|
||||
"""Derives the custom_llm_provider from the model string.
|
||||
- For example, if the model is "openrouter/deepseek/deepseek-chat", returns "openrouter".
|
||||
- If the model is "gemini/gemini-1.5-pro", returns "gemini".
|
||||
- If there is no '/', defaults to "openai".
|
||||
@@ -991,8 +994,7 @@ class LLM(BaseLLM):
|
||||
return None
|
||||
|
||||
def _validate_call_params(self) -> None:
|
||||
"""
|
||||
Validate parameters before making a call. Currently this only checks if
|
||||
"""Validate parameters before making a call. Currently this only checks if
|
||||
a response_format is provided and whether the model supports it.
|
||||
The custom_llm_provider is dynamically determined from the model:
|
||||
- E.g., "openrouter/deepseek/deepseek-chat" yields "openrouter"
|
||||
@@ -1004,19 +1006,22 @@ class LLM(BaseLLM):
|
||||
model=self.model,
|
||||
custom_llm_provider=provider,
|
||||
):
|
||||
raise ValueError(
|
||||
msg = (
|
||||
f"The model {self.model} does not support response_format for provider '{provider}'. "
|
||||
"Please remove response_format or use a supported model."
|
||||
)
|
||||
raise ValueError(
|
||||
msg,
|
||||
)
|
||||
|
||||
def supports_function_calling(self) -> bool:
|
||||
try:
|
||||
provider = self._get_custom_llm_provider()
|
||||
return litellm.utils.supports_function_calling(
|
||||
self.model, custom_llm_provider=provider
|
||||
self.model, custom_llm_provider=provider,
|
||||
)
|
||||
except Exception as e:
|
||||
logging.error(f"Failed to check function calling support: {str(e)}")
|
||||
logging.exception(f"Failed to check function calling support: {e!s}")
|
||||
return False
|
||||
|
||||
def supports_stop_words(self) -> bool:
|
||||
@@ -1024,16 +1029,16 @@ class LLM(BaseLLM):
|
||||
params = get_supported_openai_params(model=self.model)
|
||||
return params is not None and "stop" in params
|
||||
except Exception as e:
|
||||
logging.error(f"Failed to get supported params: {str(e)}")
|
||||
logging.exception(f"Failed to get supported params: {e!s}")
|
||||
return False
|
||||
|
||||
def get_context_window_size(self) -> int:
|
||||
"""
|
||||
Returns the context window size, using 75% of the maximum to avoid
|
||||
"""Returns the context window size, using 75% of the maximum to avoid
|
||||
cutting off messages mid-thread.
|
||||
|
||||
Raises:
|
||||
ValueError: If a model's context window size is outside valid bounds (1024-2097152)
|
||||
|
||||
"""
|
||||
if self.context_window_size != 0:
|
||||
return self.context_window_size
|
||||
@@ -1044,21 +1049,21 @@ class LLM(BaseLLM):
|
||||
# Validate all context window sizes
|
||||
for key, value in LLM_CONTEXT_WINDOW_SIZES.items():
|
||||
if value < MIN_CONTEXT or value > MAX_CONTEXT:
|
||||
msg = f"Context window for {key} must be between {MIN_CONTEXT} and {MAX_CONTEXT}"
|
||||
raise ValueError(
|
||||
f"Context window for {key} must be between {MIN_CONTEXT} and {MAX_CONTEXT}"
|
||||
msg,
|
||||
)
|
||||
|
||||
self.context_window_size = int(
|
||||
DEFAULT_CONTEXT_WINDOW_SIZE * CONTEXT_WINDOW_USAGE_RATIO
|
||||
DEFAULT_CONTEXT_WINDOW_SIZE * CONTEXT_WINDOW_USAGE_RATIO,
|
||||
)
|
||||
for key, value in LLM_CONTEXT_WINDOW_SIZES.items():
|
||||
if self.model.startswith(key):
|
||||
self.context_window_size = int(value * CONTEXT_WINDOW_USAGE_RATIO)
|
||||
return self.context_window_size
|
||||
|
||||
def set_callbacks(self, callbacks: List[Any]):
|
||||
"""
|
||||
Attempt to keep a single set of callbacks in litellm by removing old
|
||||
def set_callbacks(self, callbacks: list[Any]) -> None:
|
||||
"""Attempt to keep a single set of callbacks in litellm by removing old
|
||||
duplicates and adding new ones.
|
||||
"""
|
||||
with suppress_warnings():
|
||||
@@ -1073,9 +1078,8 @@ class LLM(BaseLLM):
|
||||
|
||||
litellm.callbacks = callbacks
|
||||
|
||||
def set_env_callbacks(self):
|
||||
"""
|
||||
Sets the success and failure callbacks for the LiteLLM library from environment variables.
|
||||
def set_env_callbacks(self) -> None:
|
||||
"""Sets the success and failure callbacks for the LiteLLM library from environment variables.
|
||||
|
||||
This method reads the `LITELLM_SUCCESS_CALLBACKS` and `LITELLM_FAILURE_CALLBACKS`
|
||||
environment variables, which should contain comma-separated lists of callback names.
|
||||
@@ -1091,6 +1095,7 @@ class LLM(BaseLLM):
|
||||
|
||||
This will set `litellm.success_callback` to ["langfuse", "langsmith"] and
|
||||
`litellm.failure_callback` to ["langfuse"].
|
||||
|
||||
"""
|
||||
with suppress_warnings():
|
||||
success_callbacks_str = os.environ.get("LITELLM_SUCCESS_CALLBACKS", "")
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Callable, Dict, List, Optional, Union
|
||||
from typing import Any
|
||||
|
||||
|
||||
class BaseLLM(ABC):
|
||||
@@ -17,17 +17,18 @@ class BaseLLM(ABC):
|
||||
Attributes:
|
||||
stop (list): A list of stop sequences that the LLM should use to stop generation.
|
||||
This is used by the CrewAgentExecutor and other components.
|
||||
|
||||
"""
|
||||
|
||||
model: str
|
||||
temperature: Optional[float] = None
|
||||
stop: Optional[List[str]] = None
|
||||
temperature: float | None = None
|
||||
stop: list[str] | None = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: str,
|
||||
temperature: Optional[float] = None,
|
||||
):
|
||||
temperature: float | None = None,
|
||||
) -> None:
|
||||
"""Initialize the BaseLLM with default attributes.
|
||||
|
||||
This constructor sets default values for attributes that are expected
|
||||
@@ -43,11 +44,11 @@ class BaseLLM(ABC):
|
||||
@abstractmethod
|
||||
def call(
|
||||
self,
|
||||
messages: Union[str, List[Dict[str, str]]],
|
||||
tools: Optional[List[dict]] = None,
|
||||
callbacks: Optional[List[Any]] = None,
|
||||
available_functions: Optional[Dict[str, Any]] = None,
|
||||
) -> Union[str, Any]:
|
||||
messages: str | list[dict[str, str]],
|
||||
tools: list[dict] | None = None,
|
||||
callbacks: list[Any] | None = None,
|
||||
available_functions: dict[str, Any] | None = None,
|
||||
) -> str | Any:
|
||||
"""Call the LLM with the given messages.
|
||||
|
||||
Args:
|
||||
@@ -70,14 +71,15 @@ class BaseLLM(ABC):
|
||||
ValueError: If the messages format is invalid.
|
||||
TimeoutError: If the LLM request times out.
|
||||
RuntimeError: If the LLM request fails for other reasons.
|
||||
|
||||
"""
|
||||
pass
|
||||
|
||||
def supports_stop_words(self) -> bool:
|
||||
"""Check if the LLM supports stop words.
|
||||
|
||||
Returns:
|
||||
bool: True if the LLM supports stop words, False otherwise.
|
||||
|
||||
"""
|
||||
return True # Default implementation assumes support for stop words
|
||||
|
||||
@@ -86,6 +88,7 @@ class BaseLLM(ABC):
|
||||
|
||||
Returns:
|
||||
int: The number of tokens/characters the model can handle.
|
||||
|
||||
"""
|
||||
# Default implementation - subclasses should override with model-specific values
|
||||
return 4096
|
||||
|
||||
20
src/crewai/llms/third_party/ai_suite.py
vendored
20
src/crewai/llms/third_party/ai_suite.py
vendored
@@ -1,4 +1,4 @@
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
from typing import Any
|
||||
|
||||
import aisuite as ai
|
||||
|
||||
@@ -6,17 +6,17 @@ from crewai.llms.base_llm import BaseLLM
|
||||
|
||||
|
||||
class AISuiteLLM(BaseLLM):
|
||||
def __init__(self, model: str, temperature: Optional[float] = None, **kwargs):
|
||||
def __init__(self, model: str, temperature: float | None = None, **kwargs) -> None:
|
||||
super().__init__(model, temperature, **kwargs)
|
||||
self.client = ai.Client()
|
||||
|
||||
def call(
|
||||
self,
|
||||
messages: Union[str, List[Dict[str, str]]],
|
||||
tools: Optional[List[dict]] = None,
|
||||
callbacks: Optional[List[Any]] = None,
|
||||
available_functions: Optional[Dict[str, Any]] = None,
|
||||
) -> Union[str, Any]:
|
||||
messages: str | list[dict[str, str]],
|
||||
tools: list[dict] | None = None,
|
||||
callbacks: list[Any] | None = None,
|
||||
available_functions: dict[str, Any] | None = None,
|
||||
) -> str | Any:
|
||||
completion_params = self._prepare_completion_params(messages, tools)
|
||||
response = self.client.chat.completions.create(**completion_params)
|
||||
|
||||
@@ -24,9 +24,9 @@ class AISuiteLLM(BaseLLM):
|
||||
|
||||
def _prepare_completion_params(
|
||||
self,
|
||||
messages: Union[str, List[Dict[str, str]]],
|
||||
tools: Optional[List[dict]] = None,
|
||||
) -> Dict[str, Any]:
|
||||
messages: str | list[dict[str, str]],
|
||||
tools: list[dict] | None = None,
|
||||
) -> dict[str, Any]:
|
||||
return {
|
||||
"model": self.model,
|
||||
"messages": messages,
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import Any, Dict, Optional
|
||||
from typing import Any
|
||||
|
||||
from crewai.memory import (
|
||||
EntityMemory,
|
||||
@@ -12,13 +12,13 @@ from crewai.memory import (
|
||||
class ContextualMemory:
|
||||
def __init__(
|
||||
self,
|
||||
memory_config: Optional[Dict[str, Any]],
|
||||
memory_config: dict[str, Any] | None,
|
||||
stm: ShortTermMemory,
|
||||
ltm: LongTermMemory,
|
||||
em: EntityMemory,
|
||||
um: UserMemory,
|
||||
exm: ExternalMemory,
|
||||
):
|
||||
) -> None:
|
||||
if memory_config is not None:
|
||||
self.memory_provider = memory_config.get("provider")
|
||||
else:
|
||||
@@ -30,8 +30,7 @@ class ContextualMemory:
|
||||
self.exm = exm
|
||||
|
||||
def build_context_for_task(self, task, context) -> str:
|
||||
"""
|
||||
Automatically builds a minimal, highly relevant set of contextual information
|
||||
"""Automatically builds a minimal, highly relevant set of contextual information
|
||||
for a given task.
|
||||
"""
|
||||
query = f"{task.description} {context}".strip()
|
||||
@@ -49,11 +48,9 @@ class ContextualMemory:
|
||||
return "\n".join(filter(None, context))
|
||||
|
||||
def _fetch_stm_context(self, query) -> str:
|
||||
"""
|
||||
Fetches recent relevant insights from STM related to the task's description and expected_output,
|
||||
"""Fetches recent relevant insights from STM related to the task's description and expected_output,
|
||||
formatted as bullet points.
|
||||
"""
|
||||
|
||||
if self.stm is None:
|
||||
return ""
|
||||
|
||||
@@ -62,16 +59,14 @@ class ContextualMemory:
|
||||
[
|
||||
f"- {result['memory'] if self.memory_provider == 'mem0' else result['context']}"
|
||||
for result in stm_results
|
||||
]
|
||||
],
|
||||
)
|
||||
return f"Recent Insights:\n{formatted_results}" if stm_results else ""
|
||||
|
||||
def _fetch_ltm_context(self, task) -> Optional[str]:
|
||||
"""
|
||||
Fetches historical data or insights from LTM that are relevant to the task's description and expected_output,
|
||||
def _fetch_ltm_context(self, task) -> str | None:
|
||||
"""Fetches historical data or insights from LTM that are relevant to the task's description and expected_output,
|
||||
formatted as bullet points.
|
||||
"""
|
||||
|
||||
if self.ltm is None:
|
||||
return ""
|
||||
|
||||
@@ -90,8 +85,7 @@ class ContextualMemory:
|
||||
return f"Historical Data:\n{formatted_results}" if ltm_results else ""
|
||||
|
||||
def _fetch_entity_context(self, query) -> str:
|
||||
"""
|
||||
Fetches relevant entity information from Entity Memory related to the task's description and expected_output,
|
||||
"""Fetches relevant entity information from Entity Memory related to the task's description and expected_output,
|
||||
formatted as bullet points.
|
||||
"""
|
||||
if self.em is None:
|
||||
@@ -102,19 +96,20 @@ class ContextualMemory:
|
||||
[
|
||||
f"- {result['memory'] if self.memory_provider == 'mem0' else result['context']}"
|
||||
for result in em_results
|
||||
] # type: ignore # Invalid index type "str" for "str"; expected type "SupportsIndex | slice"
|
||||
], # type: ignore # Invalid index type "str" for "str"; expected type "SupportsIndex | slice"
|
||||
)
|
||||
return f"Entities:\n{formatted_results}" if em_results else ""
|
||||
|
||||
def _fetch_user_context(self, query: str) -> str:
|
||||
"""
|
||||
Fetches and formats relevant user information from User Memory.
|
||||
"""Fetches and formats relevant user information from User Memory.
|
||||
|
||||
Args:
|
||||
query (str): The search query to find relevant user memories.
|
||||
|
||||
Returns:
|
||||
str: Formatted user memories as bullet points, or an empty string if none found.
|
||||
"""
|
||||
|
||||
"""
|
||||
if self.um is None:
|
||||
return ""
|
||||
|
||||
@@ -128,12 +123,14 @@ class ContextualMemory:
|
||||
return f"User memories/preferences:\n{formatted_memories}"
|
||||
|
||||
def _fetch_external_context(self, query: str) -> str:
|
||||
"""
|
||||
Fetches and formats relevant information from External Memory.
|
||||
"""Fetches and formats relevant information from External Memory.
|
||||
|
||||
Args:
|
||||
query (str): The search query to find relevant information.
|
||||
|
||||
Returns:
|
||||
str: Formatted information as bullet points, or an empty string if none found.
|
||||
|
||||
"""
|
||||
if self.exm is None:
|
||||
return ""
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import PrivateAttr
|
||||
|
||||
@@ -8,15 +7,14 @@ from crewai.memory.storage.rag_storage import RAGStorage
|
||||
|
||||
|
||||
class EntityMemory(Memory):
|
||||
"""
|
||||
EntityMemory class for managing structured information about entities
|
||||
"""EntityMemory class for managing structured information about entities
|
||||
and their relationships using SQLite storage.
|
||||
Inherits from the Memory class.
|
||||
"""
|
||||
|
||||
_memory_provider: Optional[str] = PrivateAttr()
|
||||
_memory_provider: str | None = PrivateAttr()
|
||||
|
||||
def __init__(self, crew=None, embedder_config=None, storage=None, path=None):
|
||||
def __init__(self, crew=None, embedder_config=None, storage=None, path=None) -> None:
|
||||
if crew and hasattr(crew, "memory_config") and crew.memory_config is not None:
|
||||
memory_provider = crew.memory_config.get("provider")
|
||||
else:
|
||||
@@ -26,8 +24,9 @@ class EntityMemory(Memory):
|
||||
try:
|
||||
from crewai.memory.storage.mem0_storage import Mem0Storage
|
||||
except ImportError:
|
||||
msg = "Mem0 is not installed. Please install it with `pip install mem0ai`."
|
||||
raise ImportError(
|
||||
"Mem0 is not installed. Please install it with `pip install mem0ai`."
|
||||
msg,
|
||||
)
|
||||
storage = Mem0Storage(type="entities", crew=crew)
|
||||
else:
|
||||
@@ -63,4 +62,5 @@ class EntityMemory(Memory):
|
||||
try:
|
||||
self.storage.reset()
|
||||
except Exception as e:
|
||||
raise Exception(f"An error occurred while resetting the entity memory: {e}")
|
||||
msg = f"An error occurred while resetting the entity memory: {e}"
|
||||
raise Exception(msg)
|
||||
|
||||
@@ -5,7 +5,7 @@ class EntityMemoryItem:
|
||||
type: str,
|
||||
description: str,
|
||||
relationships: str,
|
||||
):
|
||||
) -> None:
|
||||
self.name = name
|
||||
self.type = type
|
||||
self.description = description
|
||||
|
||||
23
src/crewai/memory/external/external_memory.py
vendored
23
src/crewai/memory/external/external_memory.py
vendored
@@ -1,4 +1,4 @@
|
||||
from typing import TYPE_CHECKING, Any, Dict, Optional
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from crewai.memory.external.external_memory_item import ExternalMemoryItem
|
||||
from crewai.memory.memory import Memory
|
||||
@@ -9,41 +9,44 @@ if TYPE_CHECKING:
|
||||
|
||||
|
||||
class ExternalMemory(Memory):
|
||||
def __init__(self, storage: Optional[Storage] = None, **data: Any):
|
||||
def __init__(self, storage: Storage | None = None, **data: Any) -> None:
|
||||
super().__init__(storage=storage, **data)
|
||||
|
||||
@staticmethod
|
||||
def _configure_mem0(crew: Any, config: Dict[str, Any]) -> "Mem0Storage":
|
||||
def _configure_mem0(crew: Any, config: dict[str, Any]) -> "Mem0Storage":
|
||||
from crewai.memory.storage.mem0_storage import Mem0Storage
|
||||
|
||||
return Mem0Storage(type="external", crew=crew, config=config)
|
||||
|
||||
@staticmethod
|
||||
def external_supported_storages() -> Dict[str, Any]:
|
||||
def external_supported_storages() -> dict[str, Any]:
|
||||
return {
|
||||
"mem0": ExternalMemory._configure_mem0,
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def create_storage(crew: Any, embedder_config: Optional[Dict[str, Any]]) -> Storage:
|
||||
def create_storage(crew: Any, embedder_config: dict[str, Any] | None) -> Storage:
|
||||
if not embedder_config:
|
||||
raise ValueError("embedder_config is required")
|
||||
msg = "embedder_config is required"
|
||||
raise ValueError(msg)
|
||||
|
||||
if "provider" not in embedder_config:
|
||||
raise ValueError("embedder_config must include a 'provider' key")
|
||||
msg = "embedder_config must include a 'provider' key"
|
||||
raise ValueError(msg)
|
||||
|
||||
provider = embedder_config["provider"]
|
||||
supported_storages = ExternalMemory.external_supported_storages()
|
||||
if provider not in supported_storages:
|
||||
raise ValueError(f"Provider {provider} not supported")
|
||||
msg = f"Provider {provider} not supported"
|
||||
raise ValueError(msg)
|
||||
|
||||
return supported_storages[provider](crew, embedder_config.get("config", {}))
|
||||
|
||||
def save(
|
||||
self,
|
||||
value: Any,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
agent: Optional[str] = None,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
agent: str | None = None,
|
||||
) -> None:
|
||||
"""Saves a value into the external storage."""
|
||||
item = ExternalMemoryItem(value=value, metadata=metadata, agent=agent)
|
||||
|
||||
@@ -1,13 +1,13 @@
|
||||
from typing import Any, Dict, Optional
|
||||
from typing import Any
|
||||
|
||||
|
||||
class ExternalMemoryItem:
|
||||
def __init__(
|
||||
self,
|
||||
value: Any,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
agent: Optional[str] = None,
|
||||
):
|
||||
metadata: dict[str, Any] | None = None,
|
||||
agent: str | None = None,
|
||||
) -> None:
|
||||
self.value = value
|
||||
self.metadata = metadata
|
||||
self.agent = agent
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import Any, Dict, List
|
||||
from typing import Any
|
||||
|
||||
from crewai.memory.long_term.long_term_memory_item import LongTermMemoryItem
|
||||
from crewai.memory.memory import Memory
|
||||
@@ -6,15 +6,14 @@ from crewai.memory.storage.ltm_sqlite_storage import LTMSQLiteStorage
|
||||
|
||||
|
||||
class LongTermMemory(Memory):
|
||||
"""
|
||||
LongTermMemory class for managing cross runs data related to overall crew's
|
||||
"""LongTermMemory class for managing cross runs data related to overall crew's
|
||||
execution and performance.
|
||||
Inherits from the Memory class and utilizes an instance of a class that
|
||||
adheres to the Storage for data storage, specifically working with
|
||||
LongTermMemoryItem instances.
|
||||
"""
|
||||
|
||||
def __init__(self, storage=None, path=None):
|
||||
def __init__(self, storage=None, path=None) -> None:
|
||||
if not storage:
|
||||
storage = LTMSQLiteStorage(db_path=path) if path else LTMSQLiteStorage()
|
||||
super().__init__(storage=storage)
|
||||
@@ -29,7 +28,7 @@ class LongTermMemory(Memory):
|
||||
datetime=item.datetime,
|
||||
)
|
||||
|
||||
def search(self, task: str, latest_n: int = 3) -> List[Dict[str, Any]]: # type: ignore # signature of "search" incompatible with supertype "Memory"
|
||||
def search(self, task: str, latest_n: int = 3) -> list[dict[str, Any]]: # type: ignore # signature of "search" incompatible with supertype "Memory"
|
||||
return self.storage.load(task, latest_n) # type: ignore # BUG?: "Storage" has no attribute "load"
|
||||
|
||||
def reset(self) -> None:
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import Any, Dict, Optional, Union
|
||||
from typing import Any
|
||||
|
||||
|
||||
class LongTermMemoryItem:
|
||||
@@ -8,9 +8,9 @@ class LongTermMemoryItem:
|
||||
task: str,
|
||||
expected_output: str,
|
||||
datetime: str,
|
||||
quality: Optional[Union[int, float]] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
):
|
||||
quality: float | None = None,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
) -> None:
|
||||
self.task = task
|
||||
self.agent = agent
|
||||
self.quality = quality
|
||||
|
||||
@@ -1,26 +1,24 @@
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class Memory(BaseModel):
|
||||
"""
|
||||
Base class for memory, now supporting agent tags and generic metadata.
|
||||
"""
|
||||
"""Base class for memory, now supporting agent tags and generic metadata."""
|
||||
|
||||
embedder_config: Optional[Dict[str, Any]] = None
|
||||
crew: Optional[Any] = None
|
||||
embedder_config: dict[str, Any] | None = None
|
||||
crew: Any | None = None
|
||||
|
||||
storage: Any
|
||||
|
||||
def __init__(self, storage: Any, **data: Any):
|
||||
def __init__(self, storage: Any, **data: Any) -> None:
|
||||
super().__init__(storage=storage, **data)
|
||||
|
||||
def save(
|
||||
self,
|
||||
value: Any,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
agent: Optional[str] = None,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
agent: str | None = None,
|
||||
) -> None:
|
||||
metadata = metadata or {}
|
||||
if agent:
|
||||
@@ -33,9 +31,9 @@ class Memory(BaseModel):
|
||||
query: str,
|
||||
limit: int = 3,
|
||||
score_threshold: float = 0.35,
|
||||
) -> List[Any]:
|
||||
) -> list[Any]:
|
||||
return self.storage.search(
|
||||
query=query, limit=limit, score_threshold=score_threshold
|
||||
query=query, limit=limit, score_threshold=score_threshold,
|
||||
)
|
||||
|
||||
def set_crew(self, crew: Any) -> "Memory":
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import Any, Dict, Optional
|
||||
from typing import Any
|
||||
|
||||
from pydantic import PrivateAttr
|
||||
|
||||
@@ -8,17 +8,16 @@ from crewai.memory.storage.rag_storage import RAGStorage
|
||||
|
||||
|
||||
class ShortTermMemory(Memory):
|
||||
"""
|
||||
ShortTermMemory class for managing transient data related to immediate tasks
|
||||
"""ShortTermMemory class for managing transient data related to immediate tasks
|
||||
and interactions.
|
||||
Inherits from the Memory class and utilizes an instance of a class that
|
||||
adheres to the Storage for data storage, specifically working with
|
||||
MemoryItem instances.
|
||||
"""
|
||||
|
||||
_memory_provider: Optional[str] = PrivateAttr()
|
||||
_memory_provider: str | None = PrivateAttr()
|
||||
|
||||
def __init__(self, crew=None, embedder_config=None, storage=None, path=None):
|
||||
def __init__(self, crew=None, embedder_config=None, storage=None, path=None) -> None:
|
||||
if crew and hasattr(crew, "memory_config") and crew.memory_config is not None:
|
||||
memory_provider = crew.memory_config.get("provider")
|
||||
else:
|
||||
@@ -28,8 +27,9 @@ class ShortTermMemory(Memory):
|
||||
try:
|
||||
from crewai.memory.storage.mem0_storage import Mem0Storage
|
||||
except ImportError:
|
||||
msg = "Mem0 is not installed. Please install it with `pip install mem0ai`."
|
||||
raise ImportError(
|
||||
"Mem0 is not installed. Please install it with `pip install mem0ai`."
|
||||
msg,
|
||||
)
|
||||
storage = Mem0Storage(type="short_term", crew=crew)
|
||||
else:
|
||||
@@ -49,8 +49,8 @@ class ShortTermMemory(Memory):
|
||||
def save(
|
||||
self,
|
||||
value: Any,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
agent: Optional[str] = None,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
agent: str | None = None,
|
||||
) -> None:
|
||||
item = ShortTermMemoryItem(data=value, metadata=metadata, agent=agent)
|
||||
if self._memory_provider == "mem0":
|
||||
@@ -65,13 +65,14 @@ class ShortTermMemory(Memory):
|
||||
score_threshold: float = 0.35,
|
||||
):
|
||||
return self.storage.search(
|
||||
query=query, limit=limit, score_threshold=score_threshold
|
||||
query=query, limit=limit, score_threshold=score_threshold,
|
||||
) # type: ignore # BUG? The reference is to the parent class, but the parent class does not have this parameters
|
||||
|
||||
def reset(self) -> None:
|
||||
try:
|
||||
self.storage.reset()
|
||||
except Exception as e:
|
||||
msg = f"An error occurred while resetting the short-term memory: {e}"
|
||||
raise Exception(
|
||||
f"An error occurred while resetting the short-term memory: {e}"
|
||||
msg,
|
||||
)
|
||||
|
||||
@@ -1,13 +1,13 @@
|
||||
from typing import Any, Dict, Optional
|
||||
from typing import Any
|
||||
|
||||
|
||||
class ShortTermMemoryItem:
|
||||
def __init__(
|
||||
self,
|
||||
data: Any,
|
||||
agent: Optional[str] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
):
|
||||
agent: str | None = None,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
) -> None:
|
||||
self.data = data
|
||||
self.agent = agent
|
||||
self.metadata = metadata if metadata is not None else {}
|
||||
|
||||
@@ -1,11 +1,9 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any
|
||||
|
||||
|
||||
class BaseRAGStorage(ABC):
|
||||
"""
|
||||
Base class for RAG-based Storage implementations.
|
||||
"""
|
||||
"""Base class for RAG-based Storage implementations."""
|
||||
|
||||
app: Any | None = None
|
||||
|
||||
@@ -13,9 +11,9 @@ class BaseRAGStorage(ABC):
|
||||
self,
|
||||
type: str,
|
||||
allow_reset: bool = True,
|
||||
embedder_config: Optional[Dict[str, Any]] = None,
|
||||
embedder_config: dict[str, Any] | None = None,
|
||||
crew: Any = None,
|
||||
):
|
||||
) -> None:
|
||||
self.type = type
|
||||
self.allow_reset = allow_reset
|
||||
self.embedder_config = embedder_config
|
||||
@@ -25,52 +23,44 @@ class BaseRAGStorage(ABC):
|
||||
def _initialize_agents(self) -> str:
|
||||
if self.crew:
|
||||
return "_".join(
|
||||
[self._sanitize_role(agent.role) for agent in self.crew.agents]
|
||||
[self._sanitize_role(agent.role) for agent in self.crew.agents],
|
||||
)
|
||||
return ""
|
||||
|
||||
@abstractmethod
|
||||
def _sanitize_role(self, role: str) -> str:
|
||||
"""Sanitizes agent roles to ensure valid directory names."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def save(self, value: Any, metadata: Dict[str, Any]) -> None:
|
||||
def save(self, value: Any, metadata: dict[str, Any]) -> None:
|
||||
"""Save a value with metadata to the storage."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def search(
|
||||
self,
|
||||
query: str,
|
||||
limit: int = 3,
|
||||
filter: Optional[dict] = None,
|
||||
filter: dict | None = None,
|
||||
score_threshold: float = 0.35,
|
||||
) -> List[Any]:
|
||||
) -> list[Any]:
|
||||
"""Search for entries in the storage."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def reset(self) -> None:
|
||||
"""Reset the storage."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def _generate_embedding(
|
||||
self, text: str, metadata: Optional[Dict[str, Any]] = None
|
||||
self, text: str, metadata: dict[str, Any] | None = None,
|
||||
) -> Any:
|
||||
"""Generate an embedding for the given text and metadata."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def _initialize_app(self):
|
||||
"""Initialize the vector db."""
|
||||
pass
|
||||
|
||||
def setup_config(self, config: Dict[str, Any]):
|
||||
def setup_config(self, config: dict[str, Any]) -> None:
|
||||
"""Setup the config of the storage."""
|
||||
pass
|
||||
|
||||
def initialize_client(self):
|
||||
"""Initialize the client of the storage. This should setup the app and the db collection"""
|
||||
pass
|
||||
def initialize_client(self) -> None:
|
||||
"""Initialize the client of the storage. This should setup the app and the db collection."""
|
||||
|
||||
@@ -1,15 +1,15 @@
|
||||
from typing import Any, Dict, List
|
||||
from typing import Any
|
||||
|
||||
|
||||
class Storage:
|
||||
"""Abstract base class defining the storage interface"""
|
||||
"""Abstract base class defining the storage interface."""
|
||||
|
||||
def save(self, value: Any, metadata: Dict[str, Any]) -> None:
|
||||
def save(self, value: Any, metadata: dict[str, Any]) -> None:
|
||||
pass
|
||||
|
||||
def search(
|
||||
self, query: str, limit: int, score_threshold: float
|
||||
) -> Dict[str, Any] | List[Any]:
|
||||
self, query: str, limit: int, score_threshold: float,
|
||||
) -> dict[str, Any] | list[Any]:
|
||||
return {}
|
||||
|
||||
def reset(self) -> None:
|
||||
|
||||
@@ -2,7 +2,7 @@ import json
|
||||
import logging
|
||||
import sqlite3
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any
|
||||
|
||||
from crewai.task import Task
|
||||
from crewai.utilities import Printer
|
||||
@@ -14,12 +14,10 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class KickoffTaskOutputsSQLiteStorage:
|
||||
"""
|
||||
An updated SQLite storage class for kickoff task outputs storage.
|
||||
"""
|
||||
"""An updated SQLite storage class for kickoff task outputs storage."""
|
||||
|
||||
def __init__(
|
||||
self, db_path: Optional[str] = None
|
||||
self, db_path: str | None = None,
|
||||
) -> None:
|
||||
if db_path is None:
|
||||
# Get the parent directory of the default db path and create our db file there
|
||||
@@ -37,6 +35,7 @@ class KickoffTaskOutputsSQLiteStorage:
|
||||
|
||||
Raises:
|
||||
DatabaseOperationError: If database initialization fails due to SQLite errors.
|
||||
|
||||
"""
|
||||
try:
|
||||
with sqlite3.connect(self.db_path) as conn:
|
||||
@@ -52,22 +51,22 @@ class KickoffTaskOutputsSQLiteStorage:
|
||||
was_replayed BOOLEAN,
|
||||
timestamp DATETIME DEFAULT CURRENT_TIMESTAMP
|
||||
)
|
||||
"""
|
||||
""",
|
||||
)
|
||||
|
||||
conn.commit()
|
||||
except sqlite3.Error as e:
|
||||
error_msg = DatabaseError.format_error(DatabaseError.INIT_ERROR, e)
|
||||
logger.error(error_msg)
|
||||
logger.exception(error_msg)
|
||||
raise DatabaseOperationError(error_msg, e)
|
||||
|
||||
def add(
|
||||
self,
|
||||
task: Task,
|
||||
output: Dict[str, Any],
|
||||
output: dict[str, Any],
|
||||
task_index: int,
|
||||
was_replayed: bool = False,
|
||||
inputs: Dict[str, Any] = {},
|
||||
inputs: dict[str, Any] | None = None,
|
||||
) -> None:
|
||||
"""Add a new task output record to the database.
|
||||
|
||||
@@ -80,7 +79,10 @@ class KickoffTaskOutputsSQLiteStorage:
|
||||
|
||||
Raises:
|
||||
DatabaseOperationError: If saving the task output fails due to SQLite errors.
|
||||
|
||||
"""
|
||||
if inputs is None:
|
||||
inputs = {}
|
||||
try:
|
||||
with sqlite3.connect(self.db_path) as conn:
|
||||
conn.execute("BEGIN TRANSACTION")
|
||||
@@ -103,7 +105,7 @@ class KickoffTaskOutputsSQLiteStorage:
|
||||
conn.commit()
|
||||
except sqlite3.Error as e:
|
||||
error_msg = DatabaseError.format_error(DatabaseError.SAVE_ERROR, e)
|
||||
logger.error(error_msg)
|
||||
logger.exception(error_msg)
|
||||
raise DatabaseOperationError(error_msg, e)
|
||||
|
||||
def update(
|
||||
@@ -123,6 +125,7 @@ class KickoffTaskOutputsSQLiteStorage:
|
||||
|
||||
Raises:
|
||||
DatabaseOperationError: If updating the task output fails due to SQLite errors.
|
||||
|
||||
"""
|
||||
try:
|
||||
with sqlite3.connect(self.db_path) as conn:
|
||||
@@ -136,7 +139,7 @@ class KickoffTaskOutputsSQLiteStorage:
|
||||
values.append(
|
||||
json.dumps(value, cls=CrewJSONEncoder)
|
||||
if isinstance(value, dict)
|
||||
else value
|
||||
else value,
|
||||
)
|
||||
|
||||
query = f"UPDATE latest_kickoff_task_outputs SET {', '.join(fields)} WHERE task_index = ?" # nosec
|
||||
@@ -149,10 +152,10 @@ class KickoffTaskOutputsSQLiteStorage:
|
||||
logger.warning(f"No row found with task_index {task_index}. No update performed.")
|
||||
except sqlite3.Error as e:
|
||||
error_msg = DatabaseError.format_error(DatabaseError.UPDATE_ERROR, e)
|
||||
logger.error(error_msg)
|
||||
logger.exception(error_msg)
|
||||
raise DatabaseOperationError(error_msg, e)
|
||||
|
||||
def load(self) -> List[Dict[str, Any]]:
|
||||
def load(self) -> list[dict[str, Any]]:
|
||||
"""Load all task output records from the database.
|
||||
|
||||
Returns:
|
||||
@@ -162,6 +165,7 @@ class KickoffTaskOutputsSQLiteStorage:
|
||||
|
||||
Raises:
|
||||
DatabaseOperationError: If loading task outputs fails due to SQLite errors.
|
||||
|
||||
"""
|
||||
try:
|
||||
with sqlite3.connect(self.db_path) as conn:
|
||||
@@ -190,7 +194,7 @@ class KickoffTaskOutputsSQLiteStorage:
|
||||
|
||||
except sqlite3.Error as e:
|
||||
error_msg = DatabaseError.format_error(DatabaseError.LOAD_ERROR, e)
|
||||
logger.error(error_msg)
|
||||
logger.exception(error_msg)
|
||||
raise DatabaseOperationError(error_msg, e)
|
||||
|
||||
def delete_all(self) -> None:
|
||||
@@ -201,6 +205,7 @@ class KickoffTaskOutputsSQLiteStorage:
|
||||
|
||||
Raises:
|
||||
DatabaseOperationError: If deleting task outputs fails due to SQLite errors.
|
||||
|
||||
"""
|
||||
try:
|
||||
with sqlite3.connect(self.db_path) as conn:
|
||||
@@ -210,5 +215,5 @@ class KickoffTaskOutputsSQLiteStorage:
|
||||
conn.commit()
|
||||
except sqlite3.Error as e:
|
||||
error_msg = DatabaseError.format_error(DatabaseError.DELETE_ERROR, e)
|
||||
logger.error(error_msg)
|
||||
logger.exception(error_msg)
|
||||
raise DatabaseOperationError(error_msg, e)
|
||||
|
||||
@@ -1,19 +1,17 @@
|
||||
import json
|
||||
import sqlite3
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
from typing import Any
|
||||
|
||||
from crewai.utilities import Printer
|
||||
from crewai.utilities.paths import db_storage_path
|
||||
|
||||
|
||||
class LTMSQLiteStorage:
|
||||
"""
|
||||
An updated SQLite storage class for LTM data storage.
|
||||
"""
|
||||
"""An updated SQLite storage class for LTM data storage."""
|
||||
|
||||
def __init__(
|
||||
self, db_path: Optional[str] = None
|
||||
self, db_path: str | None = None,
|
||||
) -> None:
|
||||
if db_path is None:
|
||||
# Get the parent directory of the default db path and create our db file there
|
||||
@@ -24,10 +22,8 @@ class LTMSQLiteStorage:
|
||||
Path(self.db_path).parent.mkdir(parents=True, exist_ok=True)
|
||||
self._initialize_db()
|
||||
|
||||
def _initialize_db(self):
|
||||
"""
|
||||
Initializes the SQLite database and creates LTM table
|
||||
"""
|
||||
def _initialize_db(self) -> None:
|
||||
"""Initializes the SQLite database and creates LTM table."""
|
||||
try:
|
||||
with sqlite3.connect(self.db_path) as conn:
|
||||
cursor = conn.cursor()
|
||||
@@ -40,7 +36,7 @@ class LTMSQLiteStorage:
|
||||
datetime TEXT,
|
||||
score REAL
|
||||
)
|
||||
"""
|
||||
""",
|
||||
)
|
||||
|
||||
conn.commit()
|
||||
@@ -53,9 +49,9 @@ class LTMSQLiteStorage:
|
||||
def save(
|
||||
self,
|
||||
task_description: str,
|
||||
metadata: Dict[str, Any],
|
||||
metadata: dict[str, Any],
|
||||
datetime: str,
|
||||
score: Union[int, float],
|
||||
score: float,
|
||||
) -> None:
|
||||
"""Saves data to the LTM table with error handling."""
|
||||
try:
|
||||
@@ -76,8 +72,8 @@ class LTMSQLiteStorage:
|
||||
)
|
||||
|
||||
def load(
|
||||
self, task_description: str, latest_n: int
|
||||
) -> Optional[List[Dict[str, Any]]]:
|
||||
self, task_description: str, latest_n: int,
|
||||
) -> list[dict[str, Any]] | None:
|
||||
"""Queries the LTM table by task description with error handling."""
|
||||
try:
|
||||
with sqlite3.connect(self.db_path) as conn:
|
||||
@@ -125,4 +121,3 @@ class LTMSQLiteStorage:
|
||||
content=f"MEMORY ERROR: An error occurred while deleting all rows in LTM: {e}",
|
||||
color="red",
|
||||
)
|
||||
return None
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import os
|
||||
from typing import Any, Dict, List
|
||||
from typing import Any
|
||||
|
||||
from mem0 import Memory, MemoryClient
|
||||
|
||||
@@ -7,17 +7,15 @@ from crewai.memory.storage.interface import Storage
|
||||
|
||||
|
||||
class Mem0Storage(Storage):
|
||||
"""
|
||||
Extends Storage to handle embedding and searching across entities using Mem0.
|
||||
"""
|
||||
"""Extends Storage to handle embedding and searching across entities using Mem0."""
|
||||
|
||||
def __init__(self, type, crew=None, config=None):
|
||||
def __init__(self, type, crew=None, config=None) -> None:
|
||||
super().__init__()
|
||||
supported_types = ["user", "short_term", "long_term", "entities", "external"]
|
||||
if type not in supported_types:
|
||||
raise ValueError(
|
||||
f"Invalid type '{type}' for Mem0Storage. Must be one of: "
|
||||
+ ", ".join(supported_types)
|
||||
+ ", ".join(supported_types),
|
||||
)
|
||||
|
||||
self.memory_type = type
|
||||
@@ -29,7 +27,8 @@ class Mem0Storage(Storage):
|
||||
# User ID is required for user memory type "user" since it's used as a unique identifier for the user.
|
||||
user_id = self._get_user_id()
|
||||
if type == "user" and not user_id:
|
||||
raise ValueError("User ID is required for user memory type")
|
||||
msg = "User ID is required for user memory type"
|
||||
raise ValueError(msg)
|
||||
|
||||
# API key in memory config overrides the environment variable
|
||||
config = self._get_config()
|
||||
@@ -42,23 +41,20 @@ class Mem0Storage(Storage):
|
||||
if mem0_api_key:
|
||||
if mem0_org_id and mem0_project_id:
|
||||
self.memory = MemoryClient(
|
||||
api_key=mem0_api_key, org_id=mem0_org_id, project_id=mem0_project_id
|
||||
api_key=mem0_api_key, org_id=mem0_org_id, project_id=mem0_project_id,
|
||||
)
|
||||
else:
|
||||
self.memory = MemoryClient(api_key=mem0_api_key)
|
||||
elif mem0_local_config and len(mem0_local_config):
|
||||
self.memory = Memory.from_config(mem0_local_config)
|
||||
else:
|
||||
if mem0_local_config and len(mem0_local_config):
|
||||
self.memory = Memory.from_config(mem0_local_config)
|
||||
else:
|
||||
self.memory = Memory()
|
||||
self.memory = Memory()
|
||||
|
||||
def _sanitize_role(self, role: str) -> str:
|
||||
"""
|
||||
Sanitizes agent roles to ensure valid directory names.
|
||||
"""
|
||||
"""Sanitizes agent roles to ensure valid directory names."""
|
||||
return role.replace("\n", "").replace(" ", "_").replace("/", "_")
|
||||
|
||||
def save(self, value: Any, metadata: Dict[str, Any]) -> None:
|
||||
def save(self, value: Any, metadata: dict[str, Any]) -> None:
|
||||
user_id = self._get_user_id()
|
||||
agent_name = self._get_agent_name()
|
||||
params = None
|
||||
@@ -97,7 +93,7 @@ class Mem0Storage(Storage):
|
||||
query: str,
|
||||
limit: int = 3,
|
||||
score_threshold: float = 0.35,
|
||||
) -> List[Any]:
|
||||
) -> list[Any]:
|
||||
params = {"query": query, "limit": limit, "output_format": "v1.1"}
|
||||
if user_id := self._get_user_id():
|
||||
params["user_id"] = user_id
|
||||
@@ -120,7 +116,7 @@ class Mem0Storage(Storage):
|
||||
# automatically when the crew is created.
|
||||
if isinstance(self.memory, Memory):
|
||||
del params["metadata"], params["output_format"]
|
||||
|
||||
|
||||
results = self.memory.search(**params)
|
||||
return [r for r in results["results"] if r["score"] >= score_threshold]
|
||||
|
||||
@@ -133,12 +129,11 @@ class Mem0Storage(Storage):
|
||||
|
||||
agents = self.crew.agents
|
||||
agents = [self._sanitize_role(agent.role) for agent in agents]
|
||||
agents = "_".join(agents)
|
||||
return agents
|
||||
return "_".join(agents)
|
||||
|
||||
def _get_config(self) -> Dict[str, Any]:
|
||||
def _get_config(self) -> dict[str, Any]:
|
||||
return self.config or getattr(self, "memory_config", {}).get("config", {}) or {}
|
||||
|
||||
def reset(self):
|
||||
def reset(self) -> None:
|
||||
if self.memory:
|
||||
self.memory.reset()
|
||||
|
||||
@@ -4,7 +4,7 @@ import logging
|
||||
import os
|
||||
import shutil
|
||||
import uuid
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any
|
||||
|
||||
from chromadb.api import ClientAPI
|
||||
|
||||
@@ -32,16 +32,15 @@ def suppress_logging(
|
||||
|
||||
|
||||
class RAGStorage(BaseRAGStorage):
|
||||
"""
|
||||
Extends Storage to handle embeddings for memory entries, improving
|
||||
"""Extends Storage to handle embeddings for memory entries, improving
|
||||
search efficiency.
|
||||
"""
|
||||
|
||||
app: ClientAPI | None = None
|
||||
|
||||
def __init__(
|
||||
self, type, allow_reset=True, embedder_config=None, crew=None, path=None
|
||||
):
|
||||
self, type, allow_reset=True, embedder_config=None, crew=None, path=None,
|
||||
) -> None:
|
||||
super().__init__(type, allow_reset, embedder_config, crew)
|
||||
agents = crew.agents if crew else []
|
||||
agents = [self._sanitize_role(agent.role) for agent in agents]
|
||||
@@ -55,11 +54,11 @@ class RAGStorage(BaseRAGStorage):
|
||||
self.path = path
|
||||
self._initialize_app()
|
||||
|
||||
def _set_embedder_config(self):
|
||||
def _set_embedder_config(self) -> None:
|
||||
configurator = EmbeddingConfigurator()
|
||||
self.embedder_config = configurator.configure_embedder(self.embedder_config)
|
||||
|
||||
def _initialize_app(self):
|
||||
def _initialize_app(self) -> None:
|
||||
import chromadb
|
||||
from chromadb.config import Settings
|
||||
|
||||
@@ -73,48 +72,44 @@ class RAGStorage(BaseRAGStorage):
|
||||
|
||||
try:
|
||||
self.collection = self.app.get_collection(
|
||||
name=self.type, embedding_function=self.embedder_config
|
||||
name=self.type, embedding_function=self.embedder_config,
|
||||
)
|
||||
except Exception:
|
||||
self.collection = self.app.create_collection(
|
||||
name=self.type, embedding_function=self.embedder_config
|
||||
name=self.type, embedding_function=self.embedder_config,
|
||||
)
|
||||
|
||||
def _sanitize_role(self, role: str) -> str:
|
||||
"""
|
||||
Sanitizes agent roles to ensure valid directory names.
|
||||
"""
|
||||
"""Sanitizes agent roles to ensure valid directory names."""
|
||||
return role.replace("\n", "").replace(" ", "_").replace("/", "_")
|
||||
|
||||
def _build_storage_file_name(self, type: str, file_name: str) -> str:
|
||||
"""
|
||||
Ensures file name does not exceed max allowed by OS
|
||||
"""
|
||||
"""Ensures file name does not exceed max allowed by OS."""
|
||||
base_path = f"{db_storage_path()}/{type}"
|
||||
|
||||
if len(file_name) > MAX_FILE_NAME_LENGTH:
|
||||
logging.warning(
|
||||
f"Trimming file name from {len(file_name)} to {MAX_FILE_NAME_LENGTH} characters."
|
||||
f"Trimming file name from {len(file_name)} to {MAX_FILE_NAME_LENGTH} characters.",
|
||||
)
|
||||
file_name = file_name[:MAX_FILE_NAME_LENGTH]
|
||||
|
||||
return f"{base_path}/{file_name}"
|
||||
|
||||
def save(self, value: Any, metadata: Dict[str, Any]) -> None:
|
||||
def save(self, value: Any, metadata: dict[str, Any]) -> None:
|
||||
if not hasattr(self, "app") or not hasattr(self, "collection"):
|
||||
self._initialize_app()
|
||||
try:
|
||||
self._generate_embedding(value, metadata)
|
||||
except Exception as e:
|
||||
logging.error(f"Error during {self.type} save: {str(e)}")
|
||||
logging.exception(f"Error during {self.type} save: {e!s}")
|
||||
|
||||
def search(
|
||||
self,
|
||||
query: str,
|
||||
limit: int = 3,
|
||||
filter: Optional[dict] = None,
|
||||
filter: dict | None = None,
|
||||
score_threshold: float = 0.35,
|
||||
) -> List[Any]:
|
||||
) -> list[Any]:
|
||||
if not hasattr(self, "app"):
|
||||
self._initialize_app()
|
||||
|
||||
@@ -135,10 +130,10 @@ class RAGStorage(BaseRAGStorage):
|
||||
|
||||
return results
|
||||
except Exception as e:
|
||||
logging.error(f"Error during {self.type} search: {str(e)}")
|
||||
logging.exception(f"Error during {self.type} search: {e!s}")
|
||||
return []
|
||||
|
||||
def _generate_embedding(self, text: str, metadata: Dict[str, Any]) -> None: # type: ignore
|
||||
def _generate_embedding(self, text: str, metadata: dict[str, Any]) -> None: # type: ignore
|
||||
if not hasattr(self, "app") or not hasattr(self, "collection"):
|
||||
self._initialize_app()
|
||||
|
||||
@@ -160,8 +155,9 @@ class RAGStorage(BaseRAGStorage):
|
||||
# Ignore this specific error
|
||||
pass
|
||||
else:
|
||||
msg = f"An error occurred while resetting the {self.type} memory: {e}"
|
||||
raise Exception(
|
||||
f"An error occurred while resetting the {self.type} memory: {e}"
|
||||
msg,
|
||||
)
|
||||
|
||||
def _create_default_embedding_function(self):
|
||||
@@ -170,5 +166,5 @@ class RAGStorage(BaseRAGStorage):
|
||||
)
|
||||
|
||||
return OpenAIEmbeddingFunction(
|
||||
api_key=os.getenv("OPENAI_API_KEY"), model_name="text-embedding-3-small"
|
||||
api_key=os.getenv("OPENAI_API_KEY"), model_name="text-embedding-3-small",
|
||||
)
|
||||
|
||||
@@ -1,18 +1,17 @@
|
||||
import warnings
|
||||
from typing import Any, Dict, Optional
|
||||
from typing import Any
|
||||
|
||||
from crewai.memory.memory import Memory
|
||||
|
||||
|
||||
class UserMemory(Memory):
|
||||
"""
|
||||
UserMemory class for handling user memory storage and retrieval.
|
||||
"""UserMemory class for handling user memory storage and retrieval.
|
||||
Inherits from the Memory class and utilizes an instance of a class that
|
||||
adheres to the Storage for data storage, specifically working with
|
||||
MemoryItem instances.
|
||||
"""
|
||||
|
||||
def __init__(self, crew=None):
|
||||
def __init__(self, crew=None) -> None:
|
||||
warnings.warn(
|
||||
"UserMemory is deprecated and will be removed in a future version. "
|
||||
"Please use ExternalMemory instead.",
|
||||
@@ -22,8 +21,9 @@ class UserMemory(Memory):
|
||||
try:
|
||||
from crewai.memory.storage.mem0_storage import Mem0Storage
|
||||
except ImportError:
|
||||
msg = "Mem0 is not installed. Please install it with `pip install mem0ai`."
|
||||
raise ImportError(
|
||||
"Mem0 is not installed. Please install it with `pip install mem0ai`."
|
||||
msg,
|
||||
)
|
||||
storage = Mem0Storage(type="user", crew=crew)
|
||||
super().__init__(storage)
|
||||
@@ -31,8 +31,8 @@ class UserMemory(Memory):
|
||||
def save(
|
||||
self,
|
||||
value,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
agent: Optional[str] = None,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
agent: str | None = None,
|
||||
) -> None:
|
||||
# TODO: Change this function since we want to take care of the case where we save memories for the usr
|
||||
data = f"Remember the details about the user: {value}"
|
||||
@@ -44,15 +44,15 @@ class UserMemory(Memory):
|
||||
limit: int = 3,
|
||||
score_threshold: float = 0.35,
|
||||
):
|
||||
results = self.storage.search(
|
||||
return self.storage.search(
|
||||
query=query,
|
||||
limit=limit,
|
||||
score_threshold=score_threshold,
|
||||
)
|
||||
return results
|
||||
|
||||
def reset(self) -> None:
|
||||
try:
|
||||
self.storage.reset()
|
||||
except Exception as e:
|
||||
raise Exception(f"An error occurred while resetting the user memory: {e}")
|
||||
msg = f"An error occurred while resetting the user memory: {e}"
|
||||
raise Exception(msg)
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
from typing import Any, Dict, Optional
|
||||
from typing import Any
|
||||
|
||||
|
||||
class UserMemoryItem:
|
||||
def __init__(self, data: Any, user: str, metadata: Optional[Dict[str, Any]] = None):
|
||||
def __init__(self, data: Any, user: str, metadata: dict[str, Any] | None = None) -> None:
|
||||
self.data = data
|
||||
self.user = user
|
||||
self.metadata = metadata if metadata is not None else {}
|
||||
|
||||
@@ -2,9 +2,7 @@ from enum import Enum
|
||||
|
||||
|
||||
class Process(str, Enum):
|
||||
"""
|
||||
Class representing the different processes that can be used to tackle tasks
|
||||
"""
|
||||
"""Class representing the different processes that can be used to tackle tasks."""
|
||||
|
||||
sequential = "sequential"
|
||||
hierarchical = "hierarchical"
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from collections.abc import Callable
|
||||
from functools import wraps
|
||||
from typing import Callable
|
||||
|
||||
from crewai import Crew
|
||||
from crewai.project.utils import memoize
|
||||
@@ -36,15 +36,13 @@ def task(func):
|
||||
def agent(func):
|
||||
"""Marks a method as a crew agent."""
|
||||
func.is_agent = True
|
||||
func = memoize(func)
|
||||
return func
|
||||
return memoize(func)
|
||||
|
||||
|
||||
def llm(func):
|
||||
"""Marks a method as an LLM provider."""
|
||||
func.is_llm = True
|
||||
func = memoize(func)
|
||||
return func
|
||||
return memoize(func)
|
||||
|
||||
|
||||
def output_json(cls):
|
||||
@@ -91,7 +89,7 @@ def crew(func) -> Callable[..., Crew]:
|
||||
agents = self._original_agents.items()
|
||||
|
||||
# Instantiate tasks in order
|
||||
for task_name, task_method in tasks:
|
||||
for _task_name, task_method in tasks:
|
||||
task_instance = task_method(self)
|
||||
instantiated_tasks.append(task_instance)
|
||||
agent_instance = getattr(task_instance, "agent", None)
|
||||
@@ -100,7 +98,7 @@ def crew(func) -> Callable[..., Crew]:
|
||||
agent_roles.add(agent_instance.role)
|
||||
|
||||
# Instantiate agents not included by tasks
|
||||
for agent_name, agent_method in agents:
|
||||
for _agent_name, agent_method in agents:
|
||||
agent_instance = agent_method(self)
|
||||
if agent_instance.role not in agent_roles:
|
||||
instantiated_agents.append(agent_instance)
|
||||
@@ -117,9 +115,9 @@ def crew(func) -> Callable[..., Crew]:
|
||||
|
||||
return wrapper
|
||||
|
||||
for _, callback in self._before_kickoff.items():
|
||||
for callback in self._before_kickoff.values():
|
||||
crew.before_kickoff_callbacks.append(callback_wrapper(callback, self))
|
||||
for _, callback in self._after_kickoff.items():
|
||||
for callback in self._after_kickoff.values():
|
||||
crew.after_kickoff_callbacks.append(callback_wrapper(callback, self))
|
||||
|
||||
return crew
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
import inspect
|
||||
import logging
|
||||
from collections.abc import Callable
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Dict, TypeVar, cast
|
||||
from typing import Any, TypeVar, cast
|
||||
|
||||
import yaml
|
||||
from dotenv import load_dotenv
|
||||
@@ -23,11 +24,11 @@ def CrewBase(cls: T) -> T:
|
||||
base_directory = Path(inspect.getfile(cls)).parent
|
||||
|
||||
original_agents_config_path = getattr(
|
||||
cls, "agents_config", "config/agents.yaml"
|
||||
cls, "agents_config", "config/agents.yaml",
|
||||
)
|
||||
original_tasks_config_path = getattr(cls, "tasks_config", "config/tasks.yaml")
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
def __init__(self, *args, **kwargs) -> None:
|
||||
super().__init__(*args, **kwargs)
|
||||
self.load_configurations()
|
||||
self.map_all_agent_variables()
|
||||
@@ -49,22 +50,22 @@ def CrewBase(cls: T) -> T:
|
||||
}
|
||||
# Store specific function types
|
||||
self._original_tasks = self._filter_functions(
|
||||
self._original_functions, "is_task"
|
||||
self._original_functions, "is_task",
|
||||
)
|
||||
self._original_agents = self._filter_functions(
|
||||
self._original_functions, "is_agent"
|
||||
self._original_functions, "is_agent",
|
||||
)
|
||||
self._before_kickoff = self._filter_functions(
|
||||
self._original_functions, "is_before_kickoff"
|
||||
self._original_functions, "is_before_kickoff",
|
||||
)
|
||||
self._after_kickoff = self._filter_functions(
|
||||
self._original_functions, "is_after_kickoff"
|
||||
self._original_functions, "is_after_kickoff",
|
||||
)
|
||||
self._kickoff = self._filter_functions(
|
||||
self._original_functions, "is_kickoff"
|
||||
self._original_functions, "is_kickoff",
|
||||
)
|
||||
|
||||
def load_configurations(self):
|
||||
def load_configurations(self) -> None:
|
||||
"""Load agent and task configurations from YAML files."""
|
||||
if isinstance(self.original_agents_config_path, str):
|
||||
agents_config_path = (
|
||||
@@ -75,12 +76,12 @@ def CrewBase(cls: T) -> T:
|
||||
except FileNotFoundError:
|
||||
logging.warning(
|
||||
f"Agent config file not found at {agents_config_path}. "
|
||||
"Proceeding with empty agent configurations."
|
||||
"Proceeding with empty agent configurations.",
|
||||
)
|
||||
self.agents_config = {}
|
||||
else:
|
||||
logging.warning(
|
||||
"No agent configuration path provided. Proceeding with empty agent configurations."
|
||||
"No agent configuration path provided. Proceeding with empty agent configurations.",
|
||||
)
|
||||
self.agents_config = {}
|
||||
|
||||
@@ -93,22 +94,21 @@ def CrewBase(cls: T) -> T:
|
||||
except FileNotFoundError:
|
||||
logging.warning(
|
||||
f"Task config file not found at {tasks_config_path}. "
|
||||
"Proceeding with empty task configurations."
|
||||
"Proceeding with empty task configurations.",
|
||||
)
|
||||
self.tasks_config = {}
|
||||
else:
|
||||
logging.warning(
|
||||
"No task configuration path provided. Proceeding with empty task configurations."
|
||||
"No task configuration path provided. Proceeding with empty task configurations.",
|
||||
)
|
||||
self.tasks_config = {}
|
||||
|
||||
@staticmethod
|
||||
def load_yaml(config_path: Path):
|
||||
try:
|
||||
with open(config_path, "r", encoding="utf-8") as file:
|
||||
with open(config_path, encoding="utf-8") as file:
|
||||
return yaml.safe_load(file)
|
||||
except FileNotFoundError:
|
||||
print(f"File not found: {config_path}")
|
||||
raise
|
||||
|
||||
def _get_all_functions(self):
|
||||
@@ -119,8 +119,8 @@ def CrewBase(cls: T) -> T:
|
||||
}
|
||||
|
||||
def _filter_functions(
|
||||
self, functions: Dict[str, Callable], attribute: str
|
||||
) -> Dict[str, Callable]:
|
||||
self, functions: dict[str, Callable], attribute: str,
|
||||
) -> dict[str, Callable]:
|
||||
return {
|
||||
name: func
|
||||
for name, func in functions.items()
|
||||
@@ -132,7 +132,7 @@ def CrewBase(cls: T) -> T:
|
||||
llms = self._filter_functions(all_functions, "is_llm")
|
||||
tool_functions = self._filter_functions(all_functions, "is_tool")
|
||||
cache_handler_functions = self._filter_functions(
|
||||
all_functions, "is_cache_handler"
|
||||
all_functions, "is_cache_handler",
|
||||
)
|
||||
callbacks = self._filter_functions(all_functions, "is_callback")
|
||||
|
||||
@@ -149,11 +149,11 @@ def CrewBase(cls: T) -> T:
|
||||
def _map_agent_variables(
|
||||
self,
|
||||
agent_name: str,
|
||||
agent_info: Dict[str, Any],
|
||||
llms: Dict[str, Callable],
|
||||
tool_functions: Dict[str, Callable],
|
||||
cache_handler_functions: Dict[str, Callable],
|
||||
callbacks: Dict[str, Callable],
|
||||
agent_info: dict[str, Any],
|
||||
llms: dict[str, Callable],
|
||||
tool_functions: dict[str, Callable],
|
||||
cache_handler_functions: dict[str, Callable],
|
||||
callbacks: dict[str, Callable],
|
||||
) -> None:
|
||||
if llm := agent_info.get("llm"):
|
||||
try:
|
||||
@@ -187,12 +187,12 @@ def CrewBase(cls: T) -> T:
|
||||
agents = self._filter_functions(all_functions, "is_agent")
|
||||
tasks = self._filter_functions(all_functions, "is_task")
|
||||
output_json_functions = self._filter_functions(
|
||||
all_functions, "is_output_json"
|
||||
all_functions, "is_output_json",
|
||||
)
|
||||
tool_functions = self._filter_functions(all_functions, "is_tool")
|
||||
callback_functions = self._filter_functions(all_functions, "is_callback")
|
||||
output_pydantic_functions = self._filter_functions(
|
||||
all_functions, "is_output_pydantic"
|
||||
all_functions, "is_output_pydantic",
|
||||
)
|
||||
|
||||
for task_name, task_info in self.tasks_config.items():
|
||||
@@ -210,13 +210,13 @@ def CrewBase(cls: T) -> T:
|
||||
def _map_task_variables(
|
||||
self,
|
||||
task_name: str,
|
||||
task_info: Dict[str, Any],
|
||||
agents: Dict[str, Callable],
|
||||
tasks: Dict[str, Callable],
|
||||
output_json_functions: Dict[str, Callable],
|
||||
tool_functions: Dict[str, Callable],
|
||||
callback_functions: Dict[str, Callable],
|
||||
output_pydantic_functions: Dict[str, Callable],
|
||||
task_info: dict[str, Any],
|
||||
agents: dict[str, Callable],
|
||||
tasks: dict[str, Callable],
|
||||
output_json_functions: dict[str, Callable],
|
||||
tool_functions: dict[str, Callable],
|
||||
callback_functions: dict[str, Callable],
|
||||
output_pydantic_functions: dict[str, Callable],
|
||||
) -> None:
|
||||
if context_list := task_info.get("context"):
|
||||
self.tasks_config[task_name]["context"] = [
|
||||
@@ -253,4 +253,4 @@ def CrewBase(cls: T) -> T:
|
||||
WrappedClass.__name__ = CrewBase.__name__ + "(" + cls.__name__ + ")"
|
||||
WrappedClass.__qualname__ = CrewBase.__qualname__ + "(" + cls.__name__ + ")"
|
||||
|
||||
return cast(T, WrappedClass)
|
||||
return cast("T", WrappedClass)
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user