mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-02 04:38:29 +00:00
Compare commits
7 Commits
brandon/fi
...
brandon/fi
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
8ae073c681 | ||
|
|
e529766391 | ||
|
|
9a34d9c01e | ||
|
|
a7f5d574dc | ||
|
|
0cc02d9492 | ||
|
|
fa26f6ebae | ||
|
|
f6c2982619 |
@@ -23,14 +23,14 @@ A crew in crewAI represents a collaborative group of agents working together to
|
||||
| **Language** _(optional)_ | `language` | Language used for the crew, defaults to English. |
|
||||
| **Language File** _(optional)_ | `language_file` | Path to the language file to be used for the crew. |
|
||||
| **Memory** _(optional)_ | `memory` | Utilized for storing execution memories (short-term, long-term, entity memory). |
|
||||
| **Memory Config** _(optional)_ | `memory_config` | Configuration for the memory provider to be used by the crew. |
|
||||
| **Cache** _(optional)_ | `cache` | Specifies whether to use a cache for storing the results of tools' execution. Defaults to `True`. |
|
||||
| **Embedder** _(optional)_ | `embedder` | Configuration for the embedder to be used by the crew. Mostly used by memory for now. Default is `{"provider": "openai"}`. |
|
||||
| **Full Output** _(optional)_ | `full_output` | Whether the crew should return the full output with all tasks outputs or just the final output. Defaults to `False`. |
|
||||
| **Memory Config** _(optional)_ | `memory_config` | Configuration for the memory provider to be used by the crew. |
|
||||
| **Cache** _(optional)_ | `cache` | Specifies whether to use a cache for storing the results of tools' execution. Defaults to `True`. |
|
||||
| **Embedder** _(optional)_ | `embedder` | Configuration for the embedder to be used by the crew. Mostly used by memory for now. Default is `{"provider": "openai"}`. |
|
||||
| **Full Output** _(optional)_ | `full_output` | Whether the crew should return the full output with all tasks outputs or just the final output. Defaults to `False`. |
|
||||
| **Step Callback** _(optional)_ | `step_callback` | A function that is called after each step of every agent. This can be used to log the agent's actions or to perform other operations; it won't override the agent-specific `step_callback`. |
|
||||
| **Task Callback** _(optional)_ | `task_callback` | A function that is called after the completion of each task. Useful for monitoring or additional operations post-task execution. |
|
||||
| **Share Crew** _(optional)_ | `share_crew` | Whether you want to share the complete crew information and execution with the crewAI team to make the library better, and allow us to train models. |
|
||||
| **Output Log File** _(optional)_ | `output_log_file` | Whether you want to have a file with the complete crew output and execution. You can set it using True and it will default to the folder you are currently in and it will be called logs.txt or passing a string with the full path and name of the file. |
|
||||
| **Output Log File** _(optional)_ | `output_log_file` | Set to True to save logs as logs.txt in the current directory or provide a file path. Logs will be in JSON format if the filename ends in .json, otherwise .txt. Defautls to `None`. |
|
||||
| **Manager Agent** _(optional)_ | `manager_agent` | `manager` sets a custom agent that will be used as a manager. |
|
||||
| **Prompt File** _(optional)_ | `prompt_file` | Path to the prompt JSON file to be used for the crew. |
|
||||
| **Planning** *(optional)* | `planning` | Adds planning ability to the Crew. When activated before each Crew iteration, all Crew data is sent to an AgentPlanner that will plan the tasks and this plan will be added to each task description. |
|
||||
@@ -240,6 +240,23 @@ print(f"Tasks Output: {crew_output.tasks_output}")
|
||||
print(f"Token Usage: {crew_output.token_usage}")
|
||||
```
|
||||
|
||||
## Accessing Crew Logs
|
||||
|
||||
You can see real time log of the crew execution, by setting `output_log_file` as a `True(Boolean)` or a `file_name(str)`. Supports logging of events as both `file_name.txt` and `file_name.json`.
|
||||
In case of `True(Boolean)` will save as `logs.txt`.
|
||||
|
||||
In case of `output_log_file` is set as `False(Booelan)` or `None`, the logs will not be populated.
|
||||
|
||||
```python Code
|
||||
# Save crew logs
|
||||
crew = Crew(output_log_file = True) # Logs will be saved as logs.txt
|
||||
crew = Crew(output_log_file = file_name) # Logs will be saved as file_name.txt
|
||||
crew = Crew(output_log_file = file_name.txt) # Logs will be saved as file_name.txt
|
||||
crew = Crew(output_log_file = file_name.json) # Logs will be saved as file_name.json
|
||||
```
|
||||
|
||||
|
||||
|
||||
## Memory Utilization
|
||||
|
||||
Crews can utilize memory (short-term, long-term, and entity memory) to enhance their execution and learning over time. This feature allows crews to store and recall execution memories, aiding in decision-making and task execution strategies.
|
||||
|
||||
@@ -463,26 +463,32 @@ Learn how to get the most out of your LLM configuration:
|
||||
|
||||
<Accordion title="Google">
|
||||
```python Code
|
||||
# Option 1. Gemini accessed with an API key.
|
||||
# Option 1: Gemini accessed with an API key.
|
||||
# https://ai.google.dev/gemini-api/docs/api-key
|
||||
GEMINI_API_KEY=<your-api-key>
|
||||
|
||||
# Option 2. Vertex AI IAM credentials for Gemini, Anthropic, and anything in the Model Garden.
|
||||
# Option 2: Vertex AI IAM credentials for Gemini, Anthropic, and Model Garden.
|
||||
# https://cloud.google.com/vertex-ai/generative-ai/docs/overview
|
||||
```
|
||||
|
||||
## GET CREDENTIALS
|
||||
Get credentials:
|
||||
```python Code
|
||||
import json
|
||||
|
||||
file_path = 'path/to/vertex_ai_service_account.json'
|
||||
|
||||
# Load the JSON file
|
||||
with open(file_path, 'r') as file:
|
||||
vertex_credentials = json.load(file)
|
||||
|
||||
# Convert to JSON string
|
||||
# Convert the credentials to a JSON string
|
||||
vertex_credentials_json = json.dumps(vertex_credentials)
|
||||
```
|
||||
|
||||
Example usage:
|
||||
```python Code
|
||||
from crewai import LLM
|
||||
|
||||
llm = LLM(
|
||||
model="gemini/gemini-1.5-pro-latest",
|
||||
temperature=0.7,
|
||||
|
||||
@@ -185,7 +185,12 @@ my_crew = Crew(
|
||||
process=Process.sequential,
|
||||
memory=True,
|
||||
verbose=True,
|
||||
embedder=OpenAIEmbeddingFunction(api_key=os.getenv("OPENAI_API_KEY"), model="text-embedding-3-small"),
|
||||
embedder={
|
||||
"provider": "openai",
|
||||
"config": {
|
||||
"model": 'text-embedding-3-small'
|
||||
}
|
||||
}
|
||||
)
|
||||
```
|
||||
|
||||
@@ -242,13 +247,15 @@ my_crew = Crew(
|
||||
process=Process.sequential,
|
||||
memory=True,
|
||||
verbose=True,
|
||||
embedder=OpenAIEmbeddingFunction(
|
||||
api_key="YOUR_API_KEY",
|
||||
api_base="YOUR_API_BASE_PATH",
|
||||
api_type="azure",
|
||||
api_version="YOUR_API_VERSION",
|
||||
model="text-embedding-3-small"
|
||||
)
|
||||
embedder={
|
||||
"provider": "openai",
|
||||
"config": {
|
||||
"api_key": "YOUR_API_KEY",
|
||||
"api_base": "YOUR_API_BASE_PATH",
|
||||
"api_version": "YOUR_API_VERSION",
|
||||
"model_name": 'text-embedding-3-small'
|
||||
}
|
||||
}
|
||||
)
|
||||
```
|
||||
|
||||
@@ -264,12 +271,15 @@ my_crew = Crew(
|
||||
process=Process.sequential,
|
||||
memory=True,
|
||||
verbose=True,
|
||||
embedder=GoogleVertexEmbeddingFunction(
|
||||
project_id="YOUR_PROJECT_ID",
|
||||
region="YOUR_REGION",
|
||||
api_key="YOUR_API_KEY",
|
||||
model="textembedding-gecko"
|
||||
)
|
||||
embedder={
|
||||
"provider": "vertexai",
|
||||
"config": {
|
||||
"project_id"="YOUR_PROJECT_ID",
|
||||
"region"="YOUR_REGION",
|
||||
"api_key"="YOUR_API_KEY",
|
||||
"model_name"="textembedding-gecko"
|
||||
}
|
||||
}
|
||||
)
|
||||
```
|
||||
|
||||
@@ -358,6 +368,33 @@ my_crew = Crew(
|
||||
)
|
||||
```
|
||||
|
||||
### Adding Custom Embedding Function
|
||||
|
||||
```python Code
|
||||
from crewai import Crew, Agent, Task, Process
|
||||
from chromadb import Documents, EmbeddingFunction, Embeddings
|
||||
|
||||
# Create a custom embedding function
|
||||
class CustomEmbedder(EmbeddingFunction):
|
||||
def __call__(self, input: Documents) -> Embeddings:
|
||||
# generate embeddings
|
||||
return [1, 2, 3] # this is a dummy embedding
|
||||
|
||||
my_crew = Crew(
|
||||
agents=[...],
|
||||
tasks=[...],
|
||||
process=Process.sequential,
|
||||
memory=True,
|
||||
verbose=True,
|
||||
embedder={
|
||||
"provider": "custom",
|
||||
"config": {
|
||||
"embedder": CustomEmbedder()
|
||||
}
|
||||
}
|
||||
)
|
||||
```
|
||||
|
||||
### Resetting Memory
|
||||
|
||||
```shell
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import re
|
||||
import shutil
|
||||
import subprocess
|
||||
from typing import Any, Dict, List, Literal, Optional, Union
|
||||
from typing import Any, Dict, List, Literal, Optional, Sequence, Union
|
||||
|
||||
from pydantic import Field, InstanceOf, PrivateAttr, model_validator
|
||||
|
||||
@@ -55,7 +55,6 @@ class Agent(BaseAgent):
|
||||
llm: The language model that will run the agent.
|
||||
function_calling_llm: The language model that will handle the tool calling for this agent, it overrides the crew function_calling_llm.
|
||||
max_iter: Maximum number of iterations for an agent to execute a task.
|
||||
memory: Whether the agent should have memory or not.
|
||||
max_rpm: Maximum number of requests per minute for the agent execution to be respected.
|
||||
verbose: Whether the agent execution should be in verbose mode.
|
||||
allow_delegation: Whether the agent is allowed to delegate tasks to other agents.
|
||||
@@ -72,9 +71,6 @@ class Agent(BaseAgent):
|
||||
)
|
||||
agent_ops_agent_name: str = None # type: ignore # Incompatible types in assignment (expression has type "None", variable has type "str")
|
||||
agent_ops_agent_id: str = None # type: ignore # Incompatible types in assignment (expression has type "None", variable has type "str")
|
||||
cache_handler: InstanceOf[CacheHandler] = Field(
|
||||
default=None, description="An instance of the CacheHandler class."
|
||||
)
|
||||
step_callback: Optional[Any] = Field(
|
||||
default=None,
|
||||
description="Callback to be executed after each step of the agent execution.",
|
||||
@@ -108,10 +104,6 @@ class Agent(BaseAgent):
|
||||
default=True,
|
||||
description="Keep messages under the context window size by summarizing content.",
|
||||
)
|
||||
max_iter: int = Field(
|
||||
default=20,
|
||||
description="Maximum number of iterations for an agent to execute a task before giving it's best answer",
|
||||
)
|
||||
max_retry_limit: int = Field(
|
||||
default=2,
|
||||
description="Maximum number of retries for an agent to execute a task when an error occurs.",
|
||||
@@ -197,13 +189,15 @@ class Agent(BaseAgent):
|
||||
if task.output_json:
|
||||
# schema = json.dumps(task.output_json, indent=2)
|
||||
schema = generate_model_description(task.output_json)
|
||||
task_prompt += "\n" + self.i18n.slice(
|
||||
"formatted_task_instructions"
|
||||
).format(output_format=schema)
|
||||
|
||||
elif task.output_pydantic:
|
||||
schema = generate_model_description(task.output_pydantic)
|
||||
|
||||
task_prompt += "\n" + self.i18n.slice("formatted_task_instructions").format(
|
||||
output_format=schema
|
||||
)
|
||||
task_prompt += "\n" + self.i18n.slice(
|
||||
"formatted_task_instructions"
|
||||
).format(output_format=schema)
|
||||
|
||||
if context:
|
||||
task_prompt = self.i18n.slice("task_with_context").format(
|
||||
@@ -331,14 +325,14 @@ class Agent(BaseAgent):
|
||||
tools = agent_tools.tools()
|
||||
return tools
|
||||
|
||||
def get_multimodal_tools(self) -> List[Tool]:
|
||||
def get_multimodal_tools(self) -> Sequence[BaseTool]:
|
||||
from crewai.tools.agent_tools.add_image_tool import AddImageTool
|
||||
|
||||
return [AddImageTool()]
|
||||
|
||||
def get_code_execution_tools(self):
|
||||
try:
|
||||
from crewai_tools import CodeInterpreterTool
|
||||
from crewai_tools import CodeInterpreterTool # type: ignore
|
||||
|
||||
# Set the unsafe_mode based on the code_execution_mode attribute
|
||||
unsafe_mode = self.code_execution_mode == "unsafe"
|
||||
|
||||
@@ -24,6 +24,7 @@ from crewai.tools import BaseTool
|
||||
from crewai.tools.base_tool import Tool
|
||||
from crewai.utilities import I18N, Logger, RPMController
|
||||
from crewai.utilities.config import process_config
|
||||
from crewai.utilities.converter import Converter
|
||||
|
||||
T = TypeVar("T", bound="BaseAgent")
|
||||
|
||||
@@ -42,7 +43,7 @@ class BaseAgent(ABC, BaseModel):
|
||||
max_rpm (Optional[int]): Maximum number of requests per minute for the agent execution.
|
||||
allow_delegation (bool): Allow delegation of tasks to agents.
|
||||
tools (Optional[List[Any]]): Tools at the agent's disposal.
|
||||
max_iter (Optional[int]): Maximum iterations for an agent to execute a task.
|
||||
max_iter (int): Maximum iterations for an agent to execute a task.
|
||||
agent_executor (InstanceOf): An instance of the CrewAgentExecutor class.
|
||||
llm (Any): Language model that will run the agent.
|
||||
crew (Any): Crew to which the agent belongs.
|
||||
@@ -114,7 +115,7 @@ class BaseAgent(ABC, BaseModel):
|
||||
tools: Optional[List[Any]] = Field(
|
||||
default_factory=list, description="Tools at agents' disposal"
|
||||
)
|
||||
max_iter: Optional[int] = Field(
|
||||
max_iter: int = Field(
|
||||
default=25, description="Maximum iterations for an agent to execute a task"
|
||||
)
|
||||
agent_executor: InstanceOf = Field(
|
||||
@@ -125,11 +126,12 @@ class BaseAgent(ABC, BaseModel):
|
||||
)
|
||||
crew: Any = Field(default=None, description="Crew to which the agent belongs.")
|
||||
i18n: I18N = Field(default=I18N(), description="Internationalization settings.")
|
||||
cache_handler: InstanceOf[CacheHandler] = Field(
|
||||
cache_handler: Optional[InstanceOf[CacheHandler]] = Field(
|
||||
default=None, description="An instance of the CacheHandler class."
|
||||
)
|
||||
tools_handler: InstanceOf[ToolsHandler] = Field(
|
||||
default=None, description="An instance of the ToolsHandler class."
|
||||
default_factory=ToolsHandler,
|
||||
description="An instance of the ToolsHandler class.",
|
||||
)
|
||||
max_tokens: Optional[int] = Field(
|
||||
default=None, description="Maximum number of tokens for the agent's execution."
|
||||
@@ -254,7 +256,7 @@ class BaseAgent(ABC, BaseModel):
|
||||
@abstractmethod
|
||||
def get_output_converter(
|
||||
self, llm: Any, text: str, model: type[BaseModel] | None, instructions: str
|
||||
):
|
||||
) -> Converter:
|
||||
"""Get the converter class for the agent to create json/pydantic outputs."""
|
||||
pass
|
||||
|
||||
|
||||
@@ -2,6 +2,7 @@ import subprocess
|
||||
|
||||
import click
|
||||
|
||||
from crewai.cli.utils import get_crew
|
||||
from crewai.knowledge.storage.knowledge_storage import KnowledgeStorage
|
||||
from crewai.memory.entity.entity_memory import EntityMemory
|
||||
from crewai.memory.long_term.long_term_memory import LongTermMemory
|
||||
@@ -30,30 +31,35 @@ def reset_memories_command(
|
||||
"""
|
||||
|
||||
try:
|
||||
crew = get_crew()
|
||||
if not crew:
|
||||
raise ValueError("No crew found.")
|
||||
if all:
|
||||
ShortTermMemory().reset()
|
||||
EntityMemory().reset()
|
||||
LongTermMemory().reset()
|
||||
TaskOutputStorageHandler().reset()
|
||||
KnowledgeStorage().reset()
|
||||
crew.reset_memories(command_type="all")
|
||||
click.echo("All memories have been reset.")
|
||||
else:
|
||||
if long:
|
||||
LongTermMemory().reset()
|
||||
click.echo("Long term memory has been reset.")
|
||||
return
|
||||
|
||||
if short:
|
||||
ShortTermMemory().reset()
|
||||
click.echo("Short term memory has been reset.")
|
||||
if entity:
|
||||
EntityMemory().reset()
|
||||
click.echo("Entity memory has been reset.")
|
||||
if kickoff_outputs:
|
||||
TaskOutputStorageHandler().reset()
|
||||
click.echo("Latest Kickoff outputs stored has been reset.")
|
||||
if knowledge:
|
||||
KnowledgeStorage().reset()
|
||||
click.echo("Knowledge has been reset.")
|
||||
if not any([long, short, entity, kickoff_outputs, knowledge]):
|
||||
click.echo(
|
||||
"No memory type specified. Please specify at least one type to reset."
|
||||
)
|
||||
return
|
||||
|
||||
if long:
|
||||
crew.reset_memories(command_type="long")
|
||||
click.echo("Long term memory has been reset.")
|
||||
if short:
|
||||
crew.reset_memories(command_type="short")
|
||||
click.echo("Short term memory has been reset.")
|
||||
if entity:
|
||||
crew.reset_memories(command_type="entity")
|
||||
click.echo("Entity memory has been reset.")
|
||||
if kickoff_outputs:
|
||||
crew.reset_memories(command_type="kickoff_outputs")
|
||||
click.echo("Latest Kickoff outputs stored has been reset.")
|
||||
if knowledge:
|
||||
crew.reset_memories(command_type="knowledge")
|
||||
click.echo("Knowledge has been reset.")
|
||||
|
||||
except subprocess.CalledProcessError as e:
|
||||
click.echo(f"An error occurred while resetting the memories: {e}", err=True)
|
||||
|
||||
@@ -9,6 +9,7 @@ import tomli
|
||||
from rich.console import Console
|
||||
|
||||
from crewai.cli.constants import ENV_VARS
|
||||
from crewai.crew import Crew
|
||||
|
||||
if sys.version_info >= (3, 11):
|
||||
import tomllib
|
||||
@@ -247,3 +248,64 @@ def write_env_file(folder_path, env_vars):
|
||||
with open(env_file_path, "w") as file:
|
||||
for key, value in env_vars.items():
|
||||
file.write(f"{key}={value}\n")
|
||||
|
||||
|
||||
def get_crew(crew_path: str = "crew.py", require: bool = False) -> Crew | None:
|
||||
"""Get the crew instance from the crew.py file."""
|
||||
try:
|
||||
import importlib.util
|
||||
import os
|
||||
|
||||
for root, _, files in os.walk("."):
|
||||
if "crew.py" in files:
|
||||
crew_path = os.path.join(root, "crew.py")
|
||||
try:
|
||||
spec = importlib.util.spec_from_file_location(
|
||||
"crew_module", crew_path
|
||||
)
|
||||
if not spec or not spec.loader:
|
||||
continue
|
||||
module = importlib.util.module_from_spec(spec)
|
||||
try:
|
||||
sys.modules[spec.name] = module
|
||||
spec.loader.exec_module(module)
|
||||
|
||||
for attr_name in dir(module):
|
||||
attr = getattr(module, attr_name)
|
||||
try:
|
||||
if callable(attr) and hasattr(attr, "crew"):
|
||||
crew_instance = attr().crew()
|
||||
return crew_instance
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error processing attribute {attr_name}: {e}")
|
||||
continue
|
||||
|
||||
except Exception as exec_error:
|
||||
print(f"Error executing module: {exec_error}")
|
||||
import traceback
|
||||
|
||||
print(f"Traceback: {traceback.format_exc()}")
|
||||
|
||||
except (ImportError, AttributeError) as e:
|
||||
if require:
|
||||
console.print(
|
||||
f"Error importing crew from {crew_path}: {str(e)}",
|
||||
style="bold red",
|
||||
)
|
||||
continue
|
||||
|
||||
break
|
||||
|
||||
if require:
|
||||
console.print("No valid Crew instance found in crew.py", style="bold red")
|
||||
raise SystemExit
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
if require:
|
||||
console.print(
|
||||
f"Unexpected error while loading crew: {str(e)}", style="bold red"
|
||||
)
|
||||
raise SystemExit
|
||||
return None
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import asyncio
|
||||
import json
|
||||
import re
|
||||
import sys
|
||||
import uuid
|
||||
import warnings
|
||||
from concurrent.futures import Future
|
||||
@@ -183,9 +184,9 @@ class Crew(BaseModel):
|
||||
default=None,
|
||||
description="Path to the prompt json file to be used for the crew.",
|
||||
)
|
||||
output_log_file: Optional[str] = Field(
|
||||
output_log_file: Optional[Union[bool, str]] = Field(
|
||||
default=None,
|
||||
description="output_log_file",
|
||||
description="Path to the log file to be saved",
|
||||
)
|
||||
planning: Optional[bool] = Field(
|
||||
default=False,
|
||||
@@ -439,6 +440,7 @@ class Crew(BaseModel):
|
||||
)
|
||||
return self
|
||||
|
||||
|
||||
@property
|
||||
def key(self) -> str:
|
||||
source = [agent.key for agent in self.agents] + [
|
||||
@@ -1147,3 +1149,80 @@ class Crew(BaseModel):
|
||||
|
||||
def __repr__(self):
|
||||
return f"Crew(id={self.id}, process={self.process}, number_of_agents={len(self.agents)}, number_of_tasks={len(self.tasks)})"
|
||||
|
||||
def reset_memories(self, command_type: str) -> None:
|
||||
"""Reset specific or all memories for the crew.
|
||||
|
||||
Args:
|
||||
command_type: Type of memory to reset.
|
||||
Valid options: 'long', 'short', 'entity', 'knowledge',
|
||||
'kickoff_outputs', or 'all'
|
||||
|
||||
Raises:
|
||||
ValueError: If an invalid command type is provided.
|
||||
RuntimeError: If memory reset operation fails.
|
||||
"""
|
||||
VALID_TYPES = frozenset(
|
||||
["long", "short", "entity", "knowledge", "kickoff_outputs", "all"]
|
||||
)
|
||||
|
||||
if command_type not in VALID_TYPES:
|
||||
raise ValueError(
|
||||
f"Invalid command type. Must be one of: {', '.join(sorted(VALID_TYPES))}"
|
||||
)
|
||||
|
||||
try:
|
||||
if command_type == "all":
|
||||
self._reset_all_memories()
|
||||
else:
|
||||
self._reset_specific_memory(command_type)
|
||||
|
||||
self._logger.log("info", f"{command_type} memory has been reset")
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"Failed to reset {command_type} memory: {str(e)}"
|
||||
self._logger.log("error", error_msg)
|
||||
raise RuntimeError(error_msg) from e
|
||||
|
||||
def _reset_all_memories(self) -> None:
|
||||
"""Reset all available memory systems."""
|
||||
memory_systems = [
|
||||
("short term", self._short_term_memory),
|
||||
("entity", self._entity_memory),
|
||||
("long term", self._long_term_memory),
|
||||
("task output", self._task_output_handler),
|
||||
("knowledge", self.knowledge),
|
||||
]
|
||||
|
||||
for name, system in memory_systems:
|
||||
if system is not None:
|
||||
try:
|
||||
system.reset()
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Failed to reset {name} memory") from e
|
||||
|
||||
def _reset_specific_memory(self, memory_type: str) -> None:
|
||||
"""Reset a specific memory system.
|
||||
|
||||
Args:
|
||||
memory_type: Type of memory to reset
|
||||
|
||||
Raises:
|
||||
RuntimeError: If the specified memory system fails to reset
|
||||
"""
|
||||
reset_functions = {
|
||||
"long": (self._long_term_memory, "long term"),
|
||||
"short": (self._short_term_memory, "short term"),
|
||||
"entity": (self._entity_memory, "entity"),
|
||||
"knowledge": (self.knowledge, "knowledge"),
|
||||
"kickoff_outputs": (self._task_output_handler, "task output"),
|
||||
}
|
||||
|
||||
memory_system, name = reset_functions[memory_type]
|
||||
if memory_system is None:
|
||||
raise RuntimeError(f"{name} memory system is not initialized")
|
||||
|
||||
try:
|
||||
memory_system.reset()
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Failed to reset {name} memory") from e
|
||||
|
||||
@@ -600,7 +600,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
||||
```
|
||||
"""
|
||||
try:
|
||||
if not hasattr(self, '_state'):
|
||||
if not hasattr(self, "_state"):
|
||||
return ""
|
||||
|
||||
if isinstance(self._state, dict):
|
||||
@@ -706,26 +706,31 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
||||
inputs: Optional dictionary containing input values and potentially a state ID to restore
|
||||
"""
|
||||
# Handle state restoration if ID is provided in inputs
|
||||
if inputs and 'id' in inputs and self._persistence is not None:
|
||||
restore_uuid = inputs['id']
|
||||
if inputs and "id" in inputs and self._persistence is not None:
|
||||
restore_uuid = inputs["id"]
|
||||
stored_state = self._persistence.load_state(restore_uuid)
|
||||
|
||||
# Override the id in the state if it exists in inputs
|
||||
if 'id' in inputs:
|
||||
if "id" in inputs:
|
||||
if isinstance(self._state, dict):
|
||||
self._state['id'] = inputs['id']
|
||||
self._state["id"] = inputs["id"]
|
||||
elif isinstance(self._state, BaseModel):
|
||||
setattr(self._state, 'id', inputs['id'])
|
||||
setattr(self._state, "id", inputs["id"])
|
||||
|
||||
if stored_state:
|
||||
self._log_flow_event(f"Loading flow state from memory for UUID: {restore_uuid}", color="yellow")
|
||||
self._log_flow_event(
|
||||
f"Loading flow state from memory for UUID: {restore_uuid}",
|
||||
color="yellow",
|
||||
)
|
||||
# Restore the state
|
||||
self._restore_state(stored_state)
|
||||
else:
|
||||
self._log_flow_event(f"No flow state found for UUID: {restore_uuid}", color="red")
|
||||
self._log_flow_event(
|
||||
f"No flow state found for UUID: {restore_uuid}", color="red"
|
||||
)
|
||||
|
||||
# Apply any additional inputs after restoration
|
||||
filtered_inputs = {k: v for k, v in inputs.items() if k != 'id'}
|
||||
filtered_inputs = {k: v for k, v in inputs.items() if k != "id"}
|
||||
if filtered_inputs:
|
||||
self._initialize_state(filtered_inputs)
|
||||
|
||||
@@ -737,9 +742,11 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
||||
flow_name=self.__class__.__name__,
|
||||
),
|
||||
)
|
||||
self._log_flow_event(f"Flow started with ID: {self.flow_id}", color="bold_magenta")
|
||||
self._log_flow_event(
|
||||
f"Flow started with ID: {self.flow_id}", color="bold_magenta"
|
||||
)
|
||||
|
||||
if inputs is not None and 'id' not in inputs:
|
||||
if inputs is not None and "id" not in inputs:
|
||||
self._initialize_state(inputs)
|
||||
|
||||
return asyncio.run(self.kickoff_async())
|
||||
@@ -984,7 +991,9 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
||||
|
||||
traceback.print_exc()
|
||||
|
||||
def _log_flow_event(self, message: str, color: str = "yellow", level: str = "info") -> None:
|
||||
def _log_flow_event(
|
||||
self, message: str, color: str = "yellow", level: str = "info"
|
||||
) -> None:
|
||||
"""Centralized logging method for flow events.
|
||||
|
||||
This method provides a consistent interface for logging flow-related events,
|
||||
|
||||
@@ -67,3 +67,9 @@ class Knowledge(BaseModel):
|
||||
source.add()
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
def reset(self) -> None:
|
||||
if self.storage:
|
||||
self.storage.reset()
|
||||
else:
|
||||
raise ValueError("Storage is not initialized.")
|
||||
|
||||
@@ -221,6 +221,13 @@ class LLM:
|
||||
if isinstance(messages, str):
|
||||
messages = [{"role": "user", "content": messages}]
|
||||
|
||||
# For O1 models, system messages are not supported.
|
||||
# Convert any system messages into assistant messages.
|
||||
if "o1" in self.model.lower():
|
||||
for message in messages:
|
||||
if message.get("role") == "system":
|
||||
message["role"] = "assistant"
|
||||
|
||||
with suppress_warnings():
|
||||
if callbacks and len(callbacks) > 0:
|
||||
self.set_callbacks(callbacks)
|
||||
|
||||
@@ -1,3 +1,7 @@
|
||||
from typing import Any, Optional
|
||||
|
||||
from pydantic import PrivateAttr
|
||||
|
||||
from crewai.memory.entity.entity_memory_item import EntityMemoryItem
|
||||
from crewai.memory.memory import Memory
|
||||
from crewai.memory.storage.rag_storage import RAGStorage
|
||||
@@ -10,13 +14,15 @@ class EntityMemory(Memory):
|
||||
Inherits from the Memory class.
|
||||
"""
|
||||
|
||||
def __init__(self, crew=None, embedder_config=None, storage=None, path=None):
|
||||
if hasattr(crew, "memory_config") and crew.memory_config is not None:
|
||||
self.memory_provider = crew.memory_config.get("provider")
|
||||
else:
|
||||
self.memory_provider = None
|
||||
_memory_provider: Optional[str] = PrivateAttr()
|
||||
|
||||
if self.memory_provider == "mem0":
|
||||
def __init__(self, crew=None, embedder_config=None, storage=None, path=None):
|
||||
if crew and hasattr(crew, "memory_config") and crew.memory_config is not None:
|
||||
memory_provider = crew.memory_config.get("provider")
|
||||
else:
|
||||
memory_provider = None
|
||||
|
||||
if memory_provider == "mem0":
|
||||
try:
|
||||
from crewai.memory.storage.mem0_storage import Mem0Storage
|
||||
except ImportError:
|
||||
@@ -36,11 +42,13 @@ class EntityMemory(Memory):
|
||||
path=path,
|
||||
)
|
||||
)
|
||||
super().__init__(storage)
|
||||
|
||||
super().__init__(storage=storage)
|
||||
self._memory_provider = memory_provider
|
||||
|
||||
def save(self, item: EntityMemoryItem) -> None: # type: ignore # BUG?: Signature of "save" incompatible with supertype "Memory"
|
||||
"""Saves an entity item into the SQLite storage."""
|
||||
if self.memory_provider == "mem0":
|
||||
if self._memory_provider == "mem0":
|
||||
data = f"""
|
||||
Remember details about the following entity:
|
||||
Name: {item.name}
|
||||
|
||||
@@ -17,7 +17,7 @@ class LongTermMemory(Memory):
|
||||
def __init__(self, storage=None, path=None):
|
||||
if not storage:
|
||||
storage = LTMSQLiteStorage(db_path=path) if path else LTMSQLiteStorage()
|
||||
super().__init__(storage)
|
||||
super().__init__(storage=storage)
|
||||
|
||||
def save(self, item: LongTermMemoryItem) -> None: # type: ignore # BUG?: Signature of "save" incompatible with supertype "Memory"
|
||||
metadata = item.metadata
|
||||
|
||||
@@ -1,15 +1,21 @@
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from crewai.memory.storage.rag_storage import RAGStorage
|
||||
|
||||
|
||||
class Memory:
|
||||
class Memory(BaseModel):
|
||||
"""
|
||||
Base class for memory, now supporting agent tags and generic metadata.
|
||||
"""
|
||||
|
||||
def __init__(self, storage: RAGStorage):
|
||||
self.storage = storage
|
||||
embedder_config: Optional[Dict[str, Any]] = None
|
||||
|
||||
storage: Any
|
||||
|
||||
def __init__(self, storage: Any, **data: Any):
|
||||
super().__init__(storage=storage, **data)
|
||||
|
||||
def save(
|
||||
self,
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from pydantic import PrivateAttr
|
||||
|
||||
from crewai.memory.memory import Memory
|
||||
from crewai.memory.short_term.short_term_memory_item import ShortTermMemoryItem
|
||||
from crewai.memory.storage.rag_storage import RAGStorage
|
||||
@@ -14,13 +16,15 @@ class ShortTermMemory(Memory):
|
||||
MemoryItem instances.
|
||||
"""
|
||||
|
||||
def __init__(self, crew=None, embedder_config=None, storage=None, path=None):
|
||||
if hasattr(crew, "memory_config") and crew.memory_config is not None:
|
||||
self.memory_provider = crew.memory_config.get("provider")
|
||||
else:
|
||||
self.memory_provider = None
|
||||
_memory_provider: Optional[str] = PrivateAttr()
|
||||
|
||||
if self.memory_provider == "mem0":
|
||||
def __init__(self, crew=None, embedder_config=None, storage=None, path=None):
|
||||
if crew and hasattr(crew, "memory_config") and crew.memory_config is not None:
|
||||
memory_provider = crew.memory_config.get("provider")
|
||||
else:
|
||||
memory_provider = None
|
||||
|
||||
if memory_provider == "mem0":
|
||||
try:
|
||||
from crewai.memory.storage.mem0_storage import Mem0Storage
|
||||
except ImportError:
|
||||
@@ -39,7 +43,8 @@ class ShortTermMemory(Memory):
|
||||
path=path,
|
||||
)
|
||||
)
|
||||
super().__init__(storage)
|
||||
super().__init__(storage=storage)
|
||||
self._memory_provider = memory_provider
|
||||
|
||||
def save(
|
||||
self,
|
||||
@@ -48,7 +53,7 @@ class ShortTermMemory(Memory):
|
||||
agent: Optional[str] = None,
|
||||
) -> None:
|
||||
item = ShortTermMemoryItem(data=value, metadata=metadata, agent=agent)
|
||||
if self.memory_provider == "mem0":
|
||||
if self._memory_provider == "mem0":
|
||||
item.data = f"Remember the following insights from Agent run: {item.data}"
|
||||
|
||||
super().save(value=item.data, metadata=item.metadata, agent=item.agent)
|
||||
|
||||
@@ -13,7 +13,7 @@ class BaseRAGStorage(ABC):
|
||||
self,
|
||||
type: str,
|
||||
allow_reset: bool = True,
|
||||
embedder_config: Optional[Any] = None,
|
||||
embedder_config: Optional[Dict[str, Any]] = None,
|
||||
crew: Any = None,
|
||||
):
|
||||
self.type = type
|
||||
|
||||
@@ -7,11 +7,11 @@ from crewai.utilities import I18N
|
||||
|
||||
i18n = I18N()
|
||||
|
||||
|
||||
class AddImageToolSchema(BaseModel):
|
||||
image_url: str = Field(..., description="The URL or path of the image to add")
|
||||
action: Optional[str] = Field(
|
||||
default=None,
|
||||
description="Optional context or question about the image"
|
||||
default=None, description="Optional context or question about the image"
|
||||
)
|
||||
|
||||
|
||||
@@ -36,10 +36,7 @@ class AddImageTool(BaseTool):
|
||||
"image_url": {
|
||||
"url": image_url,
|
||||
},
|
||||
}
|
||||
},
|
||||
]
|
||||
|
||||
return {
|
||||
"role": "user",
|
||||
"content": content
|
||||
}
|
||||
return {"role": "user", "content": content}
|
||||
|
||||
@@ -15,7 +15,7 @@
|
||||
"final_answer_format": "If you don't need to use any more tools, you must give your best complete final answer, make sure it satisfies the expected criteria, use the EXACT format below:\n\n```\nThought: I now can give a great answer\nFinal Answer: my best complete final answer to the task.\n\n```",
|
||||
"format_without_tools": "\nSorry, I didn't use the right format. I MUST either use a tool (among the available ones), OR give my best final answer.\nHere is the expected format I must follow:\n\n```\nQuestion: the input question you must answer\nThought: you should always think about what to do\nAction: the action to take, should be one of [{tool_names}]\nAction Input: the input to the action\nObservation: the result of the action\n```\n This Thought/Action/Action Input/Result process can repeat N times. Once I know the final answer, I must return the following format:\n\n```\nThought: I now can give a great answer\nFinal Answer: Your final answer must be the great and the most complete as possible, it must be outcome described\n\n```",
|
||||
"task_with_context": "{task}\n\nThis is the context you're working with:\n{context}",
|
||||
"expected_output": "\nThis is the expect criteria for your final answer: {expected_output}\nyou MUST return the actual complete content as the final answer, not a summary.",
|
||||
"expected_output": "\nThis is the expected criteria for your final answer: {expected_output}\nyou MUST return the actual complete content as the final answer, not a summary.",
|
||||
"human_feedback": "You got human feedback on your work, re-evaluate it and give a new Final Answer when ready.\n {human_feedback}",
|
||||
"getting_input": "This is the agent's final answer: {final_answer}\n\n",
|
||||
"summarizer_system_message": "You are a helpful assistant that summarizes text.",
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import os
|
||||
from typing import Any, Dict, cast
|
||||
from typing import Any, Dict, Optional, cast
|
||||
|
||||
from chromadb import Documents, EmbeddingFunction, Embeddings
|
||||
from chromadb.api.types import validate_embedding_function
|
||||
@@ -18,11 +18,12 @@ class EmbeddingConfigurator:
|
||||
"bedrock": self._configure_bedrock,
|
||||
"huggingface": self._configure_huggingface,
|
||||
"watson": self._configure_watson,
|
||||
"custom": self._configure_custom,
|
||||
}
|
||||
|
||||
def configure_embedder(
|
||||
self,
|
||||
embedder_config: Dict[str, Any] | None = None,
|
||||
embedder_config: Optional[Dict[str, Any]] = None,
|
||||
) -> EmbeddingFunction:
|
||||
"""Configures and returns an embedding function based on the provided config."""
|
||||
if embedder_config is None:
|
||||
@@ -30,20 +31,19 @@ class EmbeddingConfigurator:
|
||||
|
||||
provider = embedder_config.get("provider")
|
||||
config = embedder_config.get("config", {})
|
||||
model_name = config.get("model")
|
||||
|
||||
if isinstance(provider, EmbeddingFunction):
|
||||
try:
|
||||
validate_embedding_function(provider)
|
||||
return provider
|
||||
except Exception as e:
|
||||
raise ValueError(f"Invalid custom embedding function: {str(e)}")
|
||||
model_name = config.get("model") if provider != "custom" else None
|
||||
|
||||
if provider not in self.embedding_functions:
|
||||
raise Exception(
|
||||
f"Unsupported embedding provider: {provider}, supported providers: {list(self.embedding_functions.keys())}"
|
||||
)
|
||||
return self.embedding_functions[provider](config, model_name)
|
||||
|
||||
embedding_function = self.embedding_functions[provider]
|
||||
return (
|
||||
embedding_function(config)
|
||||
if provider == "custom"
|
||||
else embedding_function(config, model_name)
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _create_default_embedding_function():
|
||||
@@ -64,6 +64,13 @@ class EmbeddingConfigurator:
|
||||
return OpenAIEmbeddingFunction(
|
||||
api_key=config.get("api_key") or os.getenv("OPENAI_API_KEY"),
|
||||
model_name=model_name,
|
||||
api_base=config.get("api_base", None),
|
||||
api_type=config.get("api_type", None),
|
||||
api_version=config.get("api_version", None),
|
||||
default_headers=config.get("default_headers", None),
|
||||
dimensions=config.get("dimensions", None),
|
||||
deployment_id=config.get("deployment_id", None),
|
||||
organization_id=config.get("organization_id", None),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
@@ -78,6 +85,10 @@ class EmbeddingConfigurator:
|
||||
api_type=config.get("api_type", "azure"),
|
||||
api_version=config.get("api_version"),
|
||||
model_name=model_name,
|
||||
default_headers=config.get("default_headers"),
|
||||
dimensions=config.get("dimensions"),
|
||||
deployment_id=config.get("deployment_id"),
|
||||
organization_id=config.get("organization_id"),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
@@ -100,6 +111,8 @@ class EmbeddingConfigurator:
|
||||
return GoogleVertexEmbeddingFunction(
|
||||
model_name=model_name,
|
||||
api_key=config.get("api_key"),
|
||||
project_id=config.get("project_id"),
|
||||
region=config.get("region"),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
@@ -111,6 +124,7 @@ class EmbeddingConfigurator:
|
||||
return GoogleGenerativeAiEmbeddingFunction(
|
||||
model_name=model_name,
|
||||
api_key=config.get("api_key"),
|
||||
task_type=config.get("task_type"),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
@@ -195,3 +209,28 @@ class EmbeddingConfigurator:
|
||||
raise e
|
||||
|
||||
return WatsonEmbeddingFunction()
|
||||
|
||||
@staticmethod
|
||||
def _configure_custom(config):
|
||||
custom_embedder = config.get("embedder")
|
||||
if isinstance(custom_embedder, EmbeddingFunction):
|
||||
try:
|
||||
validate_embedding_function(custom_embedder)
|
||||
return custom_embedder
|
||||
except Exception as e:
|
||||
raise ValueError(f"Invalid custom embedding function: {str(e)}")
|
||||
elif callable(custom_embedder):
|
||||
try:
|
||||
instance = custom_embedder()
|
||||
if isinstance(instance, EmbeddingFunction):
|
||||
validate_embedding_function(instance)
|
||||
return instance
|
||||
raise ValueError(
|
||||
"Custom embedder does not create an EmbeddingFunction instance"
|
||||
)
|
||||
except Exception as e:
|
||||
raise ValueError(f"Error instantiating custom embedder: {str(e)}")
|
||||
else:
|
||||
raise ValueError(
|
||||
"Custom embedder must be an instance of `EmbeddingFunction` or a callable that creates one"
|
||||
)
|
||||
|
||||
@@ -1,30 +1,64 @@
|
||||
import json
|
||||
import os
|
||||
import pickle
|
||||
from datetime import datetime
|
||||
from typing import Union
|
||||
|
||||
|
||||
class FileHandler:
|
||||
"""take care of file operations, currently it only logs messages to a file"""
|
||||
"""Handler for file operations supporting both JSON and text-based logging.
|
||||
|
||||
Args:
|
||||
file_path (Union[bool, str]): Path to the log file or boolean flag
|
||||
"""
|
||||
|
||||
def __init__(self, file_path):
|
||||
if isinstance(file_path, bool):
|
||||
def __init__(self, file_path: Union[bool, str]):
|
||||
self._initialize_path(file_path)
|
||||
|
||||
def _initialize_path(self, file_path: Union[bool, str]):
|
||||
if file_path is True: # File path is boolean True
|
||||
self._path = os.path.join(os.curdir, "logs.txt")
|
||||
elif isinstance(file_path, str):
|
||||
self._path = file_path
|
||||
|
||||
elif isinstance(file_path, str): # File path is a string
|
||||
if file_path.endswith((".json", ".txt")):
|
||||
self._path = file_path # No modification if the file ends with .json or .txt
|
||||
else:
|
||||
self._path = file_path + ".txt" # Append .txt if the file doesn't end with .json or .txt
|
||||
|
||||
else:
|
||||
raise ValueError("file_path must be either a boolean or a string.")
|
||||
|
||||
raise ValueError("file_path must be a string or boolean.") # Handle the case where file_path isn't valid
|
||||
|
||||
def log(self, **kwargs):
|
||||
now = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
message = (
|
||||
f"{now}: "
|
||||
+ ", ".join([f'{key}="{value}"' for key, value in kwargs.items()])
|
||||
+ "\n"
|
||||
)
|
||||
with open(self._path, "a", encoding="utf-8") as file:
|
||||
file.write(message + "\n")
|
||||
try:
|
||||
now = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
log_entry = {"timestamp": now, **kwargs}
|
||||
|
||||
if self._path.endswith(".json"):
|
||||
# Append log in JSON format
|
||||
with open(self._path, "a", encoding="utf-8") as file:
|
||||
# If the file is empty, start with a list; else, append to it
|
||||
try:
|
||||
# Try reading existing content to avoid overwriting
|
||||
with open(self._path, "r", encoding="utf-8") as read_file:
|
||||
existing_data = json.load(read_file)
|
||||
existing_data.append(log_entry)
|
||||
except (json.JSONDecodeError, FileNotFoundError):
|
||||
# If no valid JSON or file doesn't exist, start with an empty list
|
||||
existing_data = [log_entry]
|
||||
|
||||
with open(self._path, "w", encoding="utf-8") as write_file:
|
||||
json.dump(existing_data, write_file, indent=4)
|
||||
write_file.write("\n")
|
||||
|
||||
else:
|
||||
# Append log in plain text format
|
||||
message = f"{now}: " + ", ".join([f"{key}=\"{value}\"" for key, value in kwargs.items()]) + "\n"
|
||||
with open(self._path, "a", encoding="utf-8") as file:
|
||||
file.write(message)
|
||||
|
||||
except Exception as e:
|
||||
raise ValueError(f"Failed to log message: {str(e)}")
|
||||
|
||||
class PickleHandler:
|
||||
def __init__(self, file_name: str) -> None:
|
||||
"""
|
||||
|
||||
@@ -1183,7 +1183,7 @@ def test_agent_max_retry_limit():
|
||||
[
|
||||
mock.call(
|
||||
{
|
||||
"input": "Say the word: Hi\n\nThis is the expect criteria for your final answer: The word: Hi\nyou MUST return the actual complete content as the final answer, not a summary.",
|
||||
"input": "Say the word: Hi\n\nThis is the expected criteria for your final answer: The word: Hi\nyou MUST return the actual complete content as the final answer, not a summary.",
|
||||
"tool_names": "",
|
||||
"tools": "",
|
||||
"ask_for_human_input": True,
|
||||
@@ -1191,7 +1191,7 @@ def test_agent_max_retry_limit():
|
||||
),
|
||||
mock.call(
|
||||
{
|
||||
"input": "Say the word: Hi\n\nThis is the expect criteria for your final answer: The word: Hi\nyou MUST return the actual complete content as the final answer, not a summary.",
|
||||
"input": "Say the word: Hi\n\nThis is the expected criteria for your final answer: The word: Hi\nyou MUST return the actual complete content as the final answer, not a summary.",
|
||||
"tool_names": "",
|
||||
"tools": "",
|
||||
"ask_for_human_input": True,
|
||||
|
||||
@@ -55,72 +55,83 @@ def test_train_invalid_string_iterations(train_crew, runner):
|
||||
)
|
||||
|
||||
|
||||
@mock.patch("crewai.cli.reset_memories_command.ShortTermMemory")
|
||||
@mock.patch("crewai.cli.reset_memories_command.EntityMemory")
|
||||
@mock.patch("crewai.cli.reset_memories_command.LongTermMemory")
|
||||
@mock.patch("crewai.cli.reset_memories_command.TaskOutputStorageHandler")
|
||||
def test_reset_all_memories(
|
||||
MockTaskOutputStorageHandler,
|
||||
MockLongTermMemory,
|
||||
MockEntityMemory,
|
||||
MockShortTermMemory,
|
||||
runner,
|
||||
):
|
||||
result = runner.invoke(reset_memories, ["--all"])
|
||||
MockShortTermMemory().reset.assert_called_once()
|
||||
MockEntityMemory().reset.assert_called_once()
|
||||
MockLongTermMemory().reset.assert_called_once()
|
||||
MockTaskOutputStorageHandler().reset.assert_called_once()
|
||||
@mock.patch("crewai.cli.reset_memories_command.get_crew")
|
||||
def test_reset_all_memories(mock_get_crew, runner):
|
||||
mock_crew = mock.Mock()
|
||||
mock_get_crew.return_value = mock_crew
|
||||
result = runner.invoke(reset_memories, ["-a"])
|
||||
|
||||
mock_crew.reset_memories.assert_called_once_with(command_type="all")
|
||||
assert result.output == "All memories have been reset.\n"
|
||||
|
||||
|
||||
@mock.patch("crewai.cli.reset_memories_command.ShortTermMemory")
|
||||
def test_reset_short_term_memories(MockShortTermMemory, runner):
|
||||
@mock.patch("crewai.cli.reset_memories_command.get_crew")
|
||||
def test_reset_short_term_memories(mock_get_crew, runner):
|
||||
mock_crew = mock.Mock()
|
||||
mock_get_crew.return_value = mock_crew
|
||||
result = runner.invoke(reset_memories, ["-s"])
|
||||
MockShortTermMemory().reset.assert_called_once()
|
||||
|
||||
mock_crew.reset_memories.assert_called_once_with(command_type="short")
|
||||
assert result.output == "Short term memory has been reset.\n"
|
||||
|
||||
|
||||
@mock.patch("crewai.cli.reset_memories_command.EntityMemory")
|
||||
def test_reset_entity_memories(MockEntityMemory, runner):
|
||||
@mock.patch("crewai.cli.reset_memories_command.get_crew")
|
||||
def test_reset_entity_memories(mock_get_crew, runner):
|
||||
mock_crew = mock.Mock()
|
||||
mock_get_crew.return_value = mock_crew
|
||||
result = runner.invoke(reset_memories, ["-e"])
|
||||
MockEntityMemory().reset.assert_called_once()
|
||||
|
||||
mock_crew.reset_memories.assert_called_once_with(command_type="entity")
|
||||
assert result.output == "Entity memory has been reset.\n"
|
||||
|
||||
|
||||
@mock.patch("crewai.cli.reset_memories_command.LongTermMemory")
|
||||
def test_reset_long_term_memories(MockLongTermMemory, runner):
|
||||
@mock.patch("crewai.cli.reset_memories_command.get_crew")
|
||||
def test_reset_long_term_memories(mock_get_crew, runner):
|
||||
mock_crew = mock.Mock()
|
||||
mock_get_crew.return_value = mock_crew
|
||||
result = runner.invoke(reset_memories, ["-l"])
|
||||
MockLongTermMemory().reset.assert_called_once()
|
||||
|
||||
mock_crew.reset_memories.assert_called_once_with(command_type="long")
|
||||
assert result.output == "Long term memory has been reset.\n"
|
||||
|
||||
|
||||
@mock.patch("crewai.cli.reset_memories_command.TaskOutputStorageHandler")
|
||||
def test_reset_kickoff_outputs(MockTaskOutputStorageHandler, runner):
|
||||
@mock.patch("crewai.cli.reset_memories_command.get_crew")
|
||||
def test_reset_kickoff_outputs(mock_get_crew, runner):
|
||||
mock_crew = mock.Mock()
|
||||
mock_get_crew.return_value = mock_crew
|
||||
result = runner.invoke(reset_memories, ["-k"])
|
||||
MockTaskOutputStorageHandler().reset.assert_called_once()
|
||||
|
||||
mock_crew.reset_memories.assert_called_once_with(command_type="kickoff_outputs")
|
||||
assert result.output == "Latest Kickoff outputs stored has been reset.\n"
|
||||
|
||||
|
||||
@mock.patch("crewai.cli.reset_memories_command.ShortTermMemory")
|
||||
@mock.patch("crewai.cli.reset_memories_command.LongTermMemory")
|
||||
def test_reset_multiple_memory_flags(MockShortTermMemory, MockLongTermMemory, runner):
|
||||
result = runner.invoke(
|
||||
reset_memories,
|
||||
[
|
||||
"-s",
|
||||
"-l",
|
||||
],
|
||||
@mock.patch("crewai.cli.reset_memories_command.get_crew")
|
||||
def test_reset_multiple_memory_flags(mock_get_crew, runner):
|
||||
mock_crew = mock.Mock()
|
||||
mock_get_crew.return_value = mock_crew
|
||||
result = runner.invoke(reset_memories, ["-s", "-l"])
|
||||
|
||||
# Check that reset_memories was called twice with the correct arguments
|
||||
assert mock_crew.reset_memories.call_count == 2
|
||||
mock_crew.reset_memories.assert_has_calls(
|
||||
[mock.call(command_type="long"), mock.call(command_type="short")]
|
||||
)
|
||||
MockShortTermMemory().reset.assert_called_once()
|
||||
MockLongTermMemory().reset.assert_called_once()
|
||||
assert (
|
||||
result.output
|
||||
== "Long term memory has been reset.\nShort term memory has been reset.\n"
|
||||
)
|
||||
|
||||
|
||||
@mock.patch("crewai.cli.reset_memories_command.get_crew")
|
||||
def test_reset_knowledge(mock_get_crew, runner):
|
||||
mock_crew = mock.Mock()
|
||||
mock_get_crew.return_value = mock_crew
|
||||
result = runner.invoke(reset_memories, ["--knowledge"])
|
||||
|
||||
mock_crew.reset_memories.assert_called_once_with(command_type="knowledge")
|
||||
assert result.output == "Knowledge has been reset.\n"
|
||||
|
||||
|
||||
def test_reset_no_memory_flags(runner):
|
||||
result = runner.invoke(
|
||||
reset_memories,
|
||||
|
||||
Reference in New Issue
Block a user