mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-02-28 08:48:15 +00:00
Compare commits
9 Commits
bugfix-240
...
fix/cli-cr
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
1b7c5d1821 | ||
|
|
63ef3918dd | ||
|
|
3c24350306 | ||
|
|
fcaf0d264f | ||
|
|
356d4d9729 | ||
|
|
e290064ecc | ||
|
|
77fa1b18c7 | ||
|
|
08a6a82071 | ||
|
|
625748e462 |
@@ -164,7 +164,10 @@ crew = Crew(
|
||||
|
||||
[Mem0](https://mem0.ai/) is a self-improving memory layer for LLM applications, enabling personalized AI experiences.
|
||||
|
||||
To include user-specific memory you can get your API key [here](https://app.mem0.ai/dashboard/api-keys) and refer the [docs](https://docs.mem0.ai/platform/quickstart#4-1-create-memories) for adding user preferences.
|
||||
|
||||
### Using Mem0 API platform
|
||||
|
||||
To include user-specific memory you can get your API key [here](https://app.mem0.ai/dashboard/api-keys) and refer the [docs](https://docs.mem0.ai/platform/quickstart#4-1-create-memories) for adding user preferences. In this case `user_memory` is set to `MemoryClient` from mem0.
|
||||
|
||||
|
||||
```python Code
|
||||
@@ -175,18 +178,7 @@ from mem0 import MemoryClient
|
||||
# Set environment variables for Mem0
|
||||
os.environ["MEM0_API_KEY"] = "m0-xx"
|
||||
|
||||
# Step 1: Record preferences based on past conversation or user input
|
||||
client = MemoryClient()
|
||||
messages = [
|
||||
{"role": "user", "content": "Hi there! I'm planning a vacation and could use some advice."},
|
||||
{"role": "assistant", "content": "Hello! I'd be happy to help with your vacation planning. What kind of destination do you prefer?"},
|
||||
{"role": "user", "content": "I am more of a beach person than a mountain person."},
|
||||
{"role": "assistant", "content": "That's interesting. Do you like hotels or Airbnb?"},
|
||||
{"role": "user", "content": "I like Airbnb more."},
|
||||
]
|
||||
client.add(messages, user_id="john")
|
||||
|
||||
# Step 2: Create a Crew with User Memory
|
||||
# Step 1: Create a Crew with User Memory
|
||||
|
||||
crew = Crew(
|
||||
agents=[...],
|
||||
@@ -197,11 +189,12 @@ crew = Crew(
|
||||
memory_config={
|
||||
"provider": "mem0",
|
||||
"config": {"user_id": "john"},
|
||||
"user_memory" : {} #Set user_memory explicitly to a dictionary, we are working on this issue.
|
||||
},
|
||||
)
|
||||
```
|
||||
|
||||
## Memory Configuration Options
|
||||
#### Additional Memory Configuration Options
|
||||
If you want to access a specific organization and project, you can set the `org_id` and `project_id` parameters in the memory configuration.
|
||||
|
||||
```python Code
|
||||
@@ -215,10 +208,74 @@ crew = Crew(
|
||||
memory_config={
|
||||
"provider": "mem0",
|
||||
"config": {"user_id": "john", "org_id": "my_org_id", "project_id": "my_project_id"},
|
||||
"user_memory" : {} #Set user_memory explicitly to a dictionary, we are working on this issue.
|
||||
},
|
||||
)
|
||||
```
|
||||
|
||||
### Using Local Mem0 memory
|
||||
If you want to use local mem0 memory, with a custom configuration, you can set a parameter `local_mem0_config` in the config itself.
|
||||
If both os environment key is set and local_mem0_config is given, the API platform takes higher priority over the local configuration.
|
||||
Check [this](https://docs.mem0.ai/open-source/python-quickstart#run-mem0-locally) mem0 local configuration docs for more understanding.
|
||||
In this case `user_memory` is set to `Memory` from mem0.
|
||||
|
||||
|
||||
```python Code
|
||||
from crewai import Crew
|
||||
|
||||
|
||||
#local mem0 config
|
||||
config = {
|
||||
"vector_store": {
|
||||
"provider": "qdrant",
|
||||
"config": {
|
||||
"host": "localhost",
|
||||
"port": 6333
|
||||
}
|
||||
},
|
||||
"llm": {
|
||||
"provider": "openai",
|
||||
"config": {
|
||||
"api_key": "your-api-key",
|
||||
"model": "gpt-4"
|
||||
}
|
||||
},
|
||||
"embedder": {
|
||||
"provider": "openai",
|
||||
"config": {
|
||||
"api_key": "your-api-key",
|
||||
"model": "text-embedding-3-small"
|
||||
}
|
||||
},
|
||||
"graph_store": {
|
||||
"provider": "neo4j",
|
||||
"config": {
|
||||
"url": "neo4j+s://your-instance",
|
||||
"username": "neo4j",
|
||||
"password": "password"
|
||||
}
|
||||
},
|
||||
"history_db_path": "/path/to/history.db",
|
||||
"version": "v1.1",
|
||||
"custom_fact_extraction_prompt": "Optional custom prompt for fact extraction for memory",
|
||||
"custom_update_memory_prompt": "Optional custom prompt for update memory"
|
||||
}
|
||||
|
||||
crew = Crew(
|
||||
agents=[...],
|
||||
tasks=[...],
|
||||
verbose=True,
|
||||
memory=True,
|
||||
memory_config={
|
||||
"provider": "mem0",
|
||||
"config": {"user_id": "john", 'local_mem0_config': config},
|
||||
"user_memory" : {} #Set user_memory explicitly to a dictionary, we are working on this issue.
|
||||
},
|
||||
)
|
||||
```
|
||||
|
||||
|
||||
|
||||
## Additional Embedding Providers
|
||||
|
||||
### Using OpenAI embeddings (already default)
|
||||
|
||||
@@ -93,50 +93,66 @@ def create_crew(name, provider=None, skip_provider=False, parent_folder=None):
|
||||
folder_path, folder_name, class_name = create_folder_structure(name, parent_folder)
|
||||
env_vars = load_env_vars(folder_path)
|
||||
if not skip_provider:
|
||||
if not provider:
|
||||
provider_models = get_provider_data()
|
||||
if not provider_models:
|
||||
return
|
||||
|
||||
existing_provider = None
|
||||
for provider, env_keys in ENV_VARS.items():
|
||||
if any(
|
||||
"key_name" in details and details["key_name"] in env_vars
|
||||
for details in env_keys
|
||||
):
|
||||
existing_provider = provider
|
||||
break
|
||||
|
||||
if existing_provider:
|
||||
if not click.confirm(
|
||||
f"Found existing environment variable configuration for {existing_provider.capitalize()}. Do you want to override it?"
|
||||
):
|
||||
click.secho("Keeping existing provider configuration.", fg="yellow")
|
||||
return
|
||||
|
||||
provider_models = get_provider_data()
|
||||
if not provider_models:
|
||||
click.secho("Could not retrieve provider data.", fg="red")
|
||||
return
|
||||
|
||||
while True:
|
||||
selected_provider = select_provider(provider_models)
|
||||
if selected_provider is None: # User typed 'q'
|
||||
click.secho("Exiting...", fg="yellow")
|
||||
sys.exit(0)
|
||||
if selected_provider: # Valid selection
|
||||
break
|
||||
click.secho(
|
||||
"No provider selected. Please try again or press 'q' to exit.", fg="red"
|
||||
)
|
||||
selected_provider = None
|
||||
|
||||
if provider:
|
||||
provider = provider.lower()
|
||||
if provider in provider_models:
|
||||
selected_provider = provider
|
||||
click.secho(f"Using specified provider: {selected_provider.capitalize()}", fg="green")
|
||||
else:
|
||||
click.secho(f"Warning: Specified provider '{provider}' is not recognized. Please select one.", fg="yellow")
|
||||
|
||||
if not selected_provider:
|
||||
existing_provider = None
|
||||
for p, env_keys in ENV_VARS.items():
|
||||
if any(
|
||||
"key_name" in details and details["key_name"] in env_vars
|
||||
for details in env_keys
|
||||
):
|
||||
existing_provider = p
|
||||
break
|
||||
|
||||
if existing_provider:
|
||||
if not click.confirm(
|
||||
f"Found existing environment variable configuration for {existing_provider.capitalize()}. Do you want to override it?"
|
||||
):
|
||||
click.secho("Keeping existing provider configuration. Exiting provider setup.", fg="yellow")
|
||||
copy_template_files(folder_path, name, class_name, parent_folder)
|
||||
click.secho(f"Crew '{name}' created successfully!", fg="green")
|
||||
click.secho(f"To run your crew, cd into '{folder_name}' and run 'crewai run'", fg="cyan")
|
||||
return
|
||||
else:
|
||||
pass
|
||||
|
||||
while True:
|
||||
selected_provider = select_provider(provider_models)
|
||||
if selected_provider is None:
|
||||
click.secho("Exiting...", fg="yellow")
|
||||
sys.exit(0)
|
||||
if selected_provider:
|
||||
break
|
||||
click.secho(
|
||||
"No provider selected. Please try again or press 'q' to exit.", fg="red"
|
||||
)
|
||||
|
||||
if not selected_provider:
|
||||
click.secho("Provider selection failed. Exiting.", fg="red")
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
# Check if the selected provider has predefined models
|
||||
if selected_provider in MODELS and MODELS[selected_provider]:
|
||||
while True:
|
||||
selected_model = select_model(selected_provider, provider_models)
|
||||
if selected_model is None: # User typed 'q'
|
||||
if selected_model is None:
|
||||
click.secho("Exiting...", fg="yellow")
|
||||
sys.exit(0)
|
||||
if selected_model: # Valid selection
|
||||
if selected_model:
|
||||
break
|
||||
click.secho(
|
||||
"No model selected. Please try again or press 'q' to exit.",
|
||||
@@ -144,17 +160,14 @@ def create_crew(name, provider=None, skip_provider=False, parent_folder=None):
|
||||
)
|
||||
env_vars["MODEL"] = selected_model
|
||||
|
||||
# Check if the selected provider requires API keys
|
||||
if selected_provider in ENV_VARS:
|
||||
provider_env_vars = ENV_VARS[selected_provider]
|
||||
for details in provider_env_vars:
|
||||
if details.get("default", False):
|
||||
# Automatically add default key-value pairs
|
||||
for key, value in details.items():
|
||||
if key not in ["prompt", "key_name", "default"]:
|
||||
env_vars[key] = value
|
||||
elif "key_name" in details:
|
||||
# Prompt for non-default key-value pairs
|
||||
prompt = details["prompt"]
|
||||
key_name = details["key_name"]
|
||||
api_key_value = click.prompt(prompt, default="", show_default=False)
|
||||
@@ -167,41 +180,12 @@ def create_crew(name, provider=None, skip_provider=False, parent_folder=None):
|
||||
click.secho("API keys and model saved to .env file", fg="green")
|
||||
else:
|
||||
click.secho(
|
||||
"No API keys provided. Skipping .env file creation.", fg="yellow"
|
||||
"No API keys provided or required by provider. Skipping .env file creation.", fg="yellow"
|
||||
)
|
||||
|
||||
click.secho(f"Selected model: {env_vars.get('MODEL', 'N/A')}", fg="green")
|
||||
|
||||
package_dir = Path(__file__).parent
|
||||
templates_dir = package_dir / "templates" / "crew"
|
||||
copy_template_files(folder_path, name, class_name, parent_folder)
|
||||
|
||||
root_template_files = (
|
||||
[".gitignore", "pyproject.toml", "README.md", "knowledge/user_preference.txt"]
|
||||
if not parent_folder
|
||||
else []
|
||||
)
|
||||
tools_template_files = ["tools/custom_tool.py", "tools/__init__.py"]
|
||||
config_template_files = ["config/agents.yaml", "config/tasks.yaml"]
|
||||
src_template_files = (
|
||||
["__init__.py", "main.py", "crew.py"] if not parent_folder else ["crew.py"]
|
||||
)
|
||||
|
||||
for file_name in root_template_files:
|
||||
src_file = templates_dir / file_name
|
||||
dst_file = folder_path / file_name
|
||||
copy_template(src_file, dst_file, name, class_name, folder_name)
|
||||
|
||||
src_folder = folder_path / "src" / folder_name if not parent_folder else folder_path
|
||||
|
||||
for file_name in src_template_files:
|
||||
src_file = templates_dir / file_name
|
||||
dst_file = src_folder / file_name
|
||||
copy_template(src_file, dst_file, name, class_name, folder_name)
|
||||
|
||||
if not parent_folder:
|
||||
for file_name in tools_template_files + config_template_files:
|
||||
src_file = templates_dir / file_name
|
||||
dst_file = src_folder / file_name
|
||||
copy_template(src_file, dst_file, name, class_name, folder_name)
|
||||
|
||||
click.secho(f"Crew {name} created successfully!", fg="green", bold=True)
|
||||
click.secho(f"Crew '{name}' created successfully!", fg="green")
|
||||
click.secho(f"To run your crew, cd into '{folder_name}' and run 'crewai run'", fg="cyan")
|
||||
|
||||
@@ -290,23 +290,17 @@ class Crew(BaseModel):
|
||||
else EntityMemory(crew=self, embedder_config=self.embedder)
|
||||
)
|
||||
if (
|
||||
self.memory_config and "user_memory" in self.memory_config
|
||||
self.memory_config
|
||||
and "user_memory" in self.memory_config
|
||||
and self.memory_config.get("provider") == "mem0"
|
||||
): # Check for user_memory in config
|
||||
user_memory_config = self.memory_config["user_memory"]
|
||||
if isinstance(
|
||||
user_memory_config, UserMemory
|
||||
): # Check if it is already an instance
|
||||
self._user_memory = user_memory_config
|
||||
elif isinstance(
|
||||
user_memory_config, dict
|
||||
): # Check if it's a configuration dict
|
||||
self._user_memory = UserMemory(
|
||||
crew=self, **user_memory_config
|
||||
) # Initialize with config
|
||||
self._user_memory = UserMemory(crew=self)
|
||||
else:
|
||||
raise TypeError(
|
||||
"user_memory must be a UserMemory instance or a configuration dictionary"
|
||||
)
|
||||
raise TypeError("user_memory must be a configuration dictionary")
|
||||
else:
|
||||
self._user_memory = None # No user memory if not in config
|
||||
return self
|
||||
@@ -1158,7 +1152,7 @@ class Crew(BaseModel):
|
||||
def copy(self):
|
||||
"""
|
||||
Creates a deep copy of the Crew instance.
|
||||
|
||||
|
||||
Returns:
|
||||
Crew: A new instance with copied components
|
||||
"""
|
||||
@@ -1180,7 +1174,6 @@ class Crew(BaseModel):
|
||||
"knowledge",
|
||||
"manager_agent",
|
||||
"manager_llm",
|
||||
|
||||
}
|
||||
|
||||
cloned_agents = [agent.copy() for agent in self.agents]
|
||||
|
||||
@@ -94,6 +94,10 @@ class ContextualMemory:
|
||||
Returns:
|
||||
str: Formatted user memories as bullet points, or an empty string if none found.
|
||||
"""
|
||||
|
||||
if self.um is None:
|
||||
return ""
|
||||
|
||||
user_memories = self.um.search(query)
|
||||
if not user_memories:
|
||||
return ""
|
||||
|
||||
@@ -31,6 +31,7 @@ class Mem0Storage(Storage):
|
||||
mem0_api_key = config.get("api_key") or os.getenv("MEM0_API_KEY")
|
||||
mem0_org_id = config.get("org_id")
|
||||
mem0_project_id = config.get("project_id")
|
||||
mem0_local_config = config.get("local_mem0_config")
|
||||
|
||||
# Initialize MemoryClient or Memory based on the presence of the mem0_api_key
|
||||
if mem0_api_key:
|
||||
@@ -41,7 +42,10 @@ class Mem0Storage(Storage):
|
||||
else:
|
||||
self.memory = MemoryClient(api_key=mem0_api_key)
|
||||
else:
|
||||
self.memory = Memory() # Fallback to Memory if no Mem0 API key is provided
|
||||
if mem0_local_config and len(mem0_local_config):
|
||||
self.memory = Memory.from_config(config)
|
||||
else:
|
||||
self.memory = Memory()
|
||||
|
||||
def _sanitize_role(self, role: str) -> str:
|
||||
"""
|
||||
@@ -114,3 +118,7 @@ class Mem0Storage(Storage):
|
||||
agents = [self._sanitize_role(agent.role) for agent in agents]
|
||||
agents = "_".join(agents)
|
||||
return agents
|
||||
|
||||
def reset(self):
|
||||
if self.memory:
|
||||
self.memory.reset()
|
||||
|
||||
@@ -43,3 +43,11 @@ class UserMemory(Memory):
|
||||
score_threshold=score_threshold,
|
||||
)
|
||||
return results
|
||||
|
||||
def reset(self) -> None:
|
||||
try:
|
||||
self.storage.reset()
|
||||
except Exception as e:
|
||||
raise Exception(
|
||||
f"An error occurred while resetting the user memory: {e}"
|
||||
)
|
||||
|
||||
@@ -7,29 +7,27 @@ from pydantic import (
|
||||
BaseModel,
|
||||
ConfigDict,
|
||||
Field,
|
||||
PydanticDeprecatedSince20,
|
||||
create_model,
|
||||
validator,
|
||||
field_validator,
|
||||
)
|
||||
from pydantic import BaseModel as PydanticBaseModel
|
||||
|
||||
from crewai.tools.structured_tool import CrewStructuredTool
|
||||
|
||||
# Ignore all "PydanticDeprecatedSince20" warnings globally
|
||||
warnings.filterwarnings("ignore", category=PydanticDeprecatedSince20)
|
||||
|
||||
|
||||
class BaseTool(BaseModel, ABC):
|
||||
class _ArgsSchemaPlaceholder(PydanticBaseModel):
|
||||
pass
|
||||
|
||||
model_config = ConfigDict()
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
name: str
|
||||
"""The unique name of the tool that clearly communicates its purpose."""
|
||||
description: str
|
||||
"""Used to tell the model how/when/why to use the tool."""
|
||||
args_schema: Type[PydanticBaseModel] = Field(default_factory=_ArgsSchemaPlaceholder)
|
||||
args_schema: Type[PydanticBaseModel] = Field(
|
||||
default_factory=_ArgsSchemaPlaceholder, validate_default=True
|
||||
)
|
||||
"""The schema for the arguments that the tool accepts."""
|
||||
description_updated: bool = False
|
||||
"""Flag to check if the description has been updated."""
|
||||
@@ -38,7 +36,8 @@ class BaseTool(BaseModel, ABC):
|
||||
result_as_answer: bool = False
|
||||
"""Flag to check if the tool should be the final agent answer."""
|
||||
|
||||
@validator("args_schema", always=True, pre=True)
|
||||
@field_validator("args_schema", mode="before")
|
||||
@classmethod
|
||||
def _default_args_schema(
|
||||
cls, v: Type[PydanticBaseModel]
|
||||
) -> Type[PydanticBaseModel]:
|
||||
|
||||
@@ -287,8 +287,9 @@ def generate_model_description(model: Type[BaseModel]) -> str:
|
||||
else:
|
||||
return str(field_type)
|
||||
|
||||
fields = model.__annotations__
|
||||
fields = model.model_fields
|
||||
field_descriptions = [
|
||||
f'"{name}": {describe_field(type_)}' for name, type_ in fields.items()
|
||||
f'"{name}": {describe_field(field.annotation)}'
|
||||
for name, field in fields.items()
|
||||
]
|
||||
return "{\n " + ",\n ".join(field_descriptions) + "\n}"
|
||||
|
||||
142
tests/cli/test_create_crew.py
Normal file
142
tests/cli/test_create_crew.py
Normal file
@@ -0,0 +1,142 @@
|
||||
import pytest
|
||||
from click.testing import CliRunner
|
||||
from unittest.mock import patch, MagicMock
|
||||
from pathlib import Path
|
||||
import sys
|
||||
|
||||
# Ensure the src directory is in the Python path for imports
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent.parent / 'src'))
|
||||
|
||||
from crewai.cli.cli import crewai
|
||||
from crewai.cli import create_crew
|
||||
from crewai.cli.constants import MODELS, ENV_VARS
|
||||
|
||||
# Mock provider data for testing
|
||||
MOCK_PROVIDER_DATA = {
|
||||
'openai': {'models': ['gpt-4', 'gpt-3.5-turbo']},
|
||||
'google': {'models': ['gemini-pro']},
|
||||
'anthropic': {'models': ['claude-3-opus']}
|
||||
}
|
||||
|
||||
MOCK_VALID_PROVIDERS = list(MOCK_PROVIDER_DATA.keys())
|
||||
|
||||
@pytest.fixture
|
||||
def runner():
|
||||
return CliRunner()
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def isolate_fs(monkeypatch):
|
||||
# Prevent tests from interacting with the actual filesystem or real env vars
|
||||
monkeypatch.setattr(Path, 'mkdir', lambda *args, **kwargs: None)
|
||||
monkeypatch.setattr(Path, 'exists', lambda *args: False) # Assume folders don't exist initially
|
||||
monkeypatch.setattr(create_crew, 'load_env_vars', lambda *args: {}) # Start with empty env vars
|
||||
monkeypatch.setattr(create_crew, 'write_env_file', lambda *args, **kwargs: None)
|
||||
monkeypatch.setattr(create_crew, 'copy_template_files', lambda *args, **kwargs: None)
|
||||
|
||||
@patch('crewai.cli.create_crew.get_provider_data', return_value=MOCK_PROVIDER_DATA)
|
||||
@patch('crewai.cli.create_crew.select_provider')
|
||||
@patch('crewai.cli.create_crew.select_model')
|
||||
@patch('click.prompt')
|
||||
@patch('click.confirm', return_value=True) # Default to confirming prompts
|
||||
def test_create_crew_with_valid_provider(mock_confirm, mock_prompt, mock_select_model, mock_select_provider, mock_get_data, runner):
|
||||
"""Test `crewai create crew <name> --provider <valid_provider>`"""
|
||||
result = runner.invoke(crewai, ['create', 'crew', 'testcrew', '--provider', 'openai'])
|
||||
|
||||
print(f"CLI Output:\n{result.output}") # Debug output
|
||||
assert result.exit_code == 0, f"CLI exited with code {result.exit_code}\nOutput: {result.output}"
|
||||
assert "Using specified provider: Openai" in result.output
|
||||
mock_select_provider.assert_not_called() # Should not ask interactively
|
||||
# Depending on whether openai needs models/keys, check select_model/prompt calls
|
||||
assert "Crew 'testcrew' created successfully!" in result.output
|
||||
|
||||
@patch('crewai.cli.create_crew.get_provider_data', return_value=MOCK_PROVIDER_DATA)
|
||||
@patch('crewai.cli.create_crew.select_provider', return_value='google') # Simulate user selecting google
|
||||
@patch('crewai.cli.create_crew.select_model', return_value='gemini-pro')
|
||||
@patch('click.prompt')
|
||||
@patch('click.confirm', return_value=True)
|
||||
def test_create_crew_with_invalid_provider(mock_confirm, mock_prompt, mock_select_model, mock_select_provider, mock_get_data, runner):
|
||||
"""Test `crewai create crew <name> --provider <invalid_provider>`"""
|
||||
result = runner.invoke(crewai, ['create', 'crew', 'testcrew', '--provider', 'invalidprovider'])
|
||||
|
||||
print(f"CLI Output:\n{result.output}") # Debug output
|
||||
assert result.exit_code == 0, f"CLI exited with code {result.exit_code}\nOutput: {result.output}"
|
||||
assert "Warning: Specified provider 'invalidprovider' is not recognized." in result.output
|
||||
mock_select_provider.assert_called_once() # Should ask interactively
|
||||
# Check if subsequent steps for the selected provider (google) ran
|
||||
mock_select_model.assert_called_once()
|
||||
assert "Crew 'testcrew' created successfully!" in result.output
|
||||
|
||||
@patch('crewai.cli.create_crew.get_provider_data', return_value=MOCK_PROVIDER_DATA)
|
||||
@patch('crewai.cli.create_crew.select_provider', return_value='anthropic') # Simulate user selecting anthropic
|
||||
@patch('crewai.cli.create_crew.select_model', return_value='claude-3-opus')
|
||||
@patch('click.prompt', return_value='sk-abc') # Simulate API key entry
|
||||
@patch('click.confirm', return_value=True)
|
||||
def test_create_crew_no_provider(mock_confirm, mock_prompt, mock_select_model, mock_select_provider, mock_get_data, runner):
|
||||
"""Test `crewai create crew <name>`"""
|
||||
result = runner.invoke(crewai, ['create', 'crew', 'testcrew'])
|
||||
|
||||
print(f"CLI Output:\n{result.output}") # Debug output
|
||||
assert result.exit_code == 0, f"CLI exited with code {result.exit_code}\nOutput: {result.output}"
|
||||
assert "Using specified provider:" not in result.output # Should not mention specified provider
|
||||
mock_select_provider.assert_called_once() # Should ask interactively
|
||||
mock_select_model.assert_called_once()
|
||||
# Check if prompt for API key was called (assuming anthropic needs one)
|
||||
if 'anthropic' in ENV_VARS and any('key_name' in d for d in ENV_VARS['anthropic']):
|
||||
mock_prompt.assert_called()
|
||||
assert "Crew 'testcrew' created successfully!" in result.output
|
||||
|
||||
@patch('crewai.cli.create_crew.get_provider_data')
|
||||
@patch('crewai.cli.create_crew.select_provider')
|
||||
@patch('crewai.cli.create_crew.select_model')
|
||||
@patch('click.prompt')
|
||||
@patch('click.confirm')
|
||||
def test_create_crew_skip_provider(mock_confirm, mock_prompt, mock_select_model, mock_select_provider, mock_get_data, runner):
|
||||
"""Test `crewai create crew <name> --skip_provider`"""
|
||||
result = runner.invoke(crewai, ['create', 'crew', 'testcrew', '--skip_provider'])
|
||||
|
||||
print(f"CLI Output:\n{result.output}") # Debug output
|
||||
assert result.exit_code == 0, f"CLI exited with code {result.exit_code}\nOutput: {result.output}"
|
||||
mock_get_data.assert_not_called()
|
||||
mock_select_provider.assert_not_called()
|
||||
mock_select_model.assert_not_called()
|
||||
mock_prompt.assert_not_called()
|
||||
mock_confirm.assert_not_called()
|
||||
assert "Crew 'testcrew' created successfully!" in result.output
|
||||
|
||||
@patch('crewai.cli.create_crew.load_env_vars', return_value={'OPENAI_API_KEY': 'existing_key'}) # Simulate existing env
|
||||
@patch('crewai.cli.create_crew.get_provider_data', return_value=MOCK_PROVIDER_DATA)
|
||||
@patch('crewai.cli.create_crew.select_provider', return_value='google') # Simulate selecting new provider
|
||||
@patch('crewai.cli.create_crew.select_model', return_value='gemini-pro')
|
||||
@patch('click.prompt')
|
||||
@patch('click.confirm', return_value=True) # User confirms override
|
||||
def test_create_crew_existing_override(mock_confirm, mock_prompt, mock_select_model, mock_select_provider, mock_get_data, mock_load_env, runner):
|
||||
"""Test `crewai create crew <name>` with existing config and user overrides."""
|
||||
result = runner.invoke(crewai, ['create', 'crew', 'testcrew'])
|
||||
|
||||
print(f"CLI Output:\n{result.output}") # Debug output
|
||||
assert result.exit_code == 0, f"CLI exited with code {result.exit_code}\nOutput: {result.output}"
|
||||
mock_confirm.assert_called_once_with(
|
||||
'Found existing environment variable configuration for Openai. Do you want to override it?'
|
||||
)
|
||||
mock_select_provider.assert_called_once() # Should ask for new provider after confirming override
|
||||
assert "Crew 'testcrew' created successfully!" in result.output
|
||||
|
||||
@patch('crewai.cli.create_crew.load_env_vars', return_value={'OPENAI_API_KEY': 'existing_key'}) # Simulate existing env
|
||||
@patch('crewai.cli.create_crew.get_provider_data', return_value=MOCK_PROVIDER_DATA)
|
||||
@patch('crewai.cli.create_crew.select_provider')
|
||||
@patch('crewai.cli.create_crew.select_model')
|
||||
@patch('click.prompt')
|
||||
@patch('click.confirm', return_value=False) # User denies override
|
||||
def test_create_crew_existing_keep(mock_confirm, mock_prompt, mock_select_model, mock_select_provider, mock_get_data, mock_load_env, runner):
|
||||
"""Test `crewai create crew <name>` with existing config and user keeps it."""
|
||||
result = runner.invoke(crewai, ['create', 'crew', 'testcrew'])
|
||||
|
||||
print(f"CLI Output:\n{result.output}") # Debug output
|
||||
assert result.exit_code == 0, f"CLI exited with code {result.exit_code}\nOutput: {result.output}"
|
||||
mock_confirm.assert_called_once_with(
|
||||
'Found existing environment variable configuration for Openai. Do you want to override it?'
|
||||
)
|
||||
assert "Keeping existing provider configuration. Exiting provider setup." in result.output
|
||||
mock_select_provider.assert_not_called() # Should NOT ask for new provider
|
||||
assert "Crew 'testcrew' created successfully!" in result.output
|
||||
|
||||
68
tests/memory/user_memory_test.py
Normal file
68
tests/memory/user_memory_test.py
Normal file
@@ -0,0 +1,68 @@
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from mem0.memory.main import Memory
|
||||
|
||||
from crewai.memory.user.user_memory import UserMemory
|
||||
from crewai.memory.user.user_memory_item import UserMemoryItem
|
||||
|
||||
|
||||
class MockCrew:
|
||||
def __init__(self, memory_config):
|
||||
self.memory_config = memory_config
|
||||
|
||||
@pytest.fixture
|
||||
def user_memory():
|
||||
"""Fixture to create a UserMemory instance"""
|
||||
crew = MockCrew(
|
||||
memory_config={
|
||||
"provider": "mem0",
|
||||
"config": {"user_id": "john"},
|
||||
"user_memory" : {}
|
||||
}
|
||||
)
|
||||
|
||||
user_memory = MagicMock(spec=UserMemory)
|
||||
|
||||
with patch.object(Memory,'__new__',return_value=user_memory):
|
||||
user_memory_instance = UserMemory(crew=crew)
|
||||
|
||||
return user_memory_instance
|
||||
|
||||
def test_save_and_search(user_memory):
|
||||
memory = UserMemoryItem(
|
||||
data="""test value test value test value test value test value test value
|
||||
test value test value test value test value test value test value
|
||||
test value test value test value test value test value test value""",
|
||||
user="test_user",
|
||||
metadata={"task": "test_task"},
|
||||
)
|
||||
|
||||
with patch.object(UserMemory, "save") as mock_save:
|
||||
user_memory.save(
|
||||
value=memory.data,
|
||||
metadata=memory.metadata,
|
||||
user=memory.user
|
||||
)
|
||||
|
||||
mock_save.assert_called_once_with(
|
||||
value=memory.data,
|
||||
metadata=memory.metadata,
|
||||
user=memory.user
|
||||
)
|
||||
|
||||
expected_result = [
|
||||
{
|
||||
"context": memory.data,
|
||||
"metadata": {"agent": "test_agent"},
|
||||
"score": 0.95,
|
||||
}
|
||||
]
|
||||
expected_result = ["mocked_result"]
|
||||
|
||||
# Use patch.object to mock UserMemory's search method
|
||||
with patch.object(UserMemory, 'search', return_value=expected_result) as mock_search:
|
||||
find = UserMemory.search("test value", score_threshold=0.01)[0]
|
||||
mock_search.assert_called_once_with("test value", score_threshold=0.01)
|
||||
assert find == expected_result[0]
|
||||
114
tests/storage/test_mem0_storage.py
Normal file
114
tests/storage/test_mem0_storage.py
Normal file
@@ -0,0 +1,114 @@
|
||||
import os
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from mem0.client.main import MemoryClient
|
||||
from mem0.memory.main import Memory
|
||||
|
||||
from crewai.agent import Agent
|
||||
from crewai.crew import Crew
|
||||
from crewai.memory.storage.mem0_storage import Mem0Storage
|
||||
from crewai.task import Task
|
||||
|
||||
|
||||
# Define the class (if not already defined)
|
||||
class MockCrew:
|
||||
def __init__(self, memory_config):
|
||||
self.memory_config = memory_config
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_mem0_memory():
|
||||
"""Fixture to create a mock Memory instance"""
|
||||
mock_memory = MagicMock(spec=Memory)
|
||||
return mock_memory
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mem0_storage_with_mocked_config(mock_mem0_memory):
|
||||
"""Fixture to create a Mem0Storage instance with mocked dependencies"""
|
||||
|
||||
# Patch the Memory class to return our mock
|
||||
with patch('mem0.memory.main.Memory.from_config', return_value=mock_mem0_memory):
|
||||
config = {
|
||||
"vector_store": {
|
||||
"provider": "mock_vector_store",
|
||||
"config": {
|
||||
"host": "localhost",
|
||||
"port": 6333
|
||||
}
|
||||
},
|
||||
"llm": {
|
||||
"provider": "mock_llm",
|
||||
"config": {
|
||||
"api_key": "mock-api-key",
|
||||
"model": "mock-model"
|
||||
}
|
||||
},
|
||||
"embedder": {
|
||||
"provider": "mock_embedder",
|
||||
"config": {
|
||||
"api_key": "mock-api-key",
|
||||
"model": "mock-model"
|
||||
}
|
||||
},
|
||||
"graph_store": {
|
||||
"provider": "mock_graph_store",
|
||||
"config": {
|
||||
"url": "mock-url",
|
||||
"username": "mock-user",
|
||||
"password": "mock-password"
|
||||
}
|
||||
},
|
||||
"history_db_path": "/mock/path",
|
||||
"version": "test-version",
|
||||
"custom_fact_extraction_prompt": "mock prompt 1",
|
||||
"custom_update_memory_prompt": "mock prompt 2"
|
||||
}
|
||||
|
||||
# Instantiate the class with memory_config
|
||||
crew = MockCrew(
|
||||
memory_config={
|
||||
"provider": "mem0",
|
||||
"config": {"user_id": "test_user", "local_mem0_config": config},
|
||||
}
|
||||
)
|
||||
|
||||
mem0_storage = Mem0Storage(type="short_term", crew=crew)
|
||||
return mem0_storage
|
||||
|
||||
|
||||
def test_mem0_storage_initialization(mem0_storage_with_mocked_config, mock_mem0_memory):
|
||||
"""Test that Mem0Storage initializes correctly with the mocked config"""
|
||||
assert mem0_storage_with_mocked_config.memory_type == "short_term"
|
||||
assert mem0_storage_with_mocked_config.memory is mock_mem0_memory
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_mem0_memory_client():
|
||||
"""Fixture to create a mock MemoryClient instance"""
|
||||
mock_memory = MagicMock(spec=MemoryClient)
|
||||
return mock_memory
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mem0_storage_with_memory_client(mock_mem0_memory_client):
|
||||
"""Fixture to create a Mem0Storage instance with mocked dependencies"""
|
||||
|
||||
# We need to patch the MemoryClient before it's instantiated
|
||||
with patch.object(MemoryClient, '__new__', return_value=mock_mem0_memory_client):
|
||||
crew = MockCrew(
|
||||
memory_config={
|
||||
"provider": "mem0",
|
||||
"config": {"user_id": "test_user", "api_key": "ABCDEFGH", "org_id": "my_org_id", "project_id": "my_project_id"},
|
||||
}
|
||||
)
|
||||
|
||||
mem0_storage = Mem0Storage(type="short_term", crew=crew)
|
||||
return mem0_storage
|
||||
|
||||
|
||||
def test_mem0_storage_with_memory_client_initialization(mem0_storage_with_memory_client, mock_mem0_memory_client):
|
||||
"""Test Mem0Storage initialization with MemoryClient"""
|
||||
assert mem0_storage_with_memory_client.memory_type == "short_term"
|
||||
assert mem0_storage_with_memory_client.memory is mock_mem0_memory_client
|
||||
Reference in New Issue
Block a user