mirror of
https://github.com/crewAIInc/crewAI.git
synced 2025-12-16 12:28:30 +00:00
Compare commits
32 Commits
devin/1756
...
devin/1738
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
20fc2f9878 | ||
|
|
c149b75874 | ||
|
|
86844ff3df | ||
|
|
b442fe20a2 | ||
|
|
9b1b1d33ba | ||
|
|
3c350e8933 | ||
|
|
a3a5507f9a | ||
|
|
a175167aaf | ||
|
|
1dc62b0d0a | ||
|
|
75b376ebac | ||
|
|
29106068b7 | ||
|
|
3bf531189f | ||
|
|
47919a60a0 | ||
|
|
6b9ed90510 | ||
|
|
f6a65486f1 | ||
|
|
bf6db93bdf | ||
|
|
25e68bc459 | ||
|
|
6f6010db1c | ||
|
|
a95227deef | ||
|
|
636dac6efb | ||
|
|
a4e2b17bae | ||
|
|
823f22a601 | ||
|
|
649414805d | ||
|
|
8017ab2dfd | ||
|
|
6445cda35a | ||
|
|
6116c73721 | ||
|
|
a038b751ef | ||
|
|
5006161d31 | ||
|
|
85a13751ba | ||
|
|
1c7c4cb828 | ||
|
|
509fb375ca | ||
|
|
d01d44b29c |
45
docs/memory.md
Normal file
45
docs/memory.md
Normal file
@@ -0,0 +1,45 @@
|
||||
# Memory in CrewAI
|
||||
|
||||
CrewAI provides a robust memory system that allows agents to retain and recall information from previous interactions.
|
||||
|
||||
## Configuring Embedding Providers
|
||||
|
||||
CrewAI supports multiple embedding providers for memory functionality:
|
||||
|
||||
- OpenAI (default) - Requires `OPENAI_API_KEY`
|
||||
- Ollama - Requires `CREWAI_OLLAMA_URL` (defaults to "http://localhost:11434/api/embeddings")
|
||||
|
||||
### Environment Variables
|
||||
|
||||
Configure the embedding provider using these environment variables:
|
||||
|
||||
- `CREWAI_EMBEDDING_PROVIDER`: Provider name (default: "openai")
|
||||
- `CREWAI_EMBEDDING_MODEL`: Model name (default: "text-embedding-3-small")
|
||||
- `CREWAI_OLLAMA_URL`: URL for Ollama API (when using Ollama provider)
|
||||
|
||||
### Example Configuration
|
||||
|
||||
```python
|
||||
# Using OpenAI (default)
|
||||
os.environ["OPENAI_API_KEY"] = "your-api-key"
|
||||
|
||||
# Using Ollama
|
||||
os.environ["CREWAI_EMBEDDING_PROVIDER"] = "ollama"
|
||||
os.environ["CREWAI_EMBEDDING_MODEL"] = "llama2" # or any other model supported by your Ollama instance
|
||||
os.environ["CREWAI_OLLAMA_URL"] = "http://localhost:11434/api/embeddings" # optional, this is the default
|
||||
```
|
||||
|
||||
## Memory Usage
|
||||
|
||||
When an agent has memory enabled, it can access and store information from previous interactions:
|
||||
|
||||
```python
|
||||
agent = Agent(
|
||||
role="Researcher",
|
||||
goal="Research AI topics",
|
||||
backstory="You're an AI researcher",
|
||||
memory=True # Enable memory for this agent
|
||||
)
|
||||
```
|
||||
|
||||
The memory system uses embeddings to store and retrieve relevant information, allowing agents to maintain context across multiple interactions and tasks.
|
||||
@@ -1,13 +1,14 @@
|
||||
import re
|
||||
import os
|
||||
import shutil
|
||||
import subprocess
|
||||
from typing import Any, Dict, List, Literal, Optional, Sequence, Union
|
||||
from typing import Any, Dict, List, Literal, Optional, Union
|
||||
|
||||
from pydantic import Field, InstanceOf, PrivateAttr, model_validator
|
||||
|
||||
from crewai.agents import CacheHandler
|
||||
from crewai.agents.agent_builder.base_agent import BaseAgent
|
||||
from crewai.agents.crew_agent_executor import CrewAgentExecutor
|
||||
from crewai.cli.constants import ENV_VARS, LITELLM_PARAMS
|
||||
from crewai.knowledge.knowledge import Knowledge
|
||||
from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource
|
||||
from crewai.knowledge.utils.knowledge_utils import extract_knowledge_context
|
||||
@@ -16,10 +17,10 @@ from crewai.memory.contextual.contextual_memory import ContextualMemory
|
||||
from crewai.task import Task
|
||||
from crewai.tools import BaseTool
|
||||
from crewai.tools.agent_tools.agent_tools import AgentTools
|
||||
from crewai.tools.base_tool import Tool
|
||||
from crewai.utilities import Converter, Prompts
|
||||
from crewai.utilities.constants import TRAINED_AGENTS_DATA_FILE, TRAINING_DATA_FILE
|
||||
from crewai.utilities.converter import generate_model_description
|
||||
from crewai.utilities.llm_utils import create_llm
|
||||
from crewai.utilities.token_counter_callback import TokenCalcHandler
|
||||
from crewai.utilities.training_handler import CrewTrainingHandler
|
||||
|
||||
@@ -54,13 +55,13 @@ class Agent(BaseAgent):
|
||||
llm: The language model that will run the agent.
|
||||
function_calling_llm: The language model that will handle the tool calling for this agent, it overrides the crew function_calling_llm.
|
||||
max_iter: Maximum number of iterations for an agent to execute a task.
|
||||
memory: Whether the agent should have memory or not.
|
||||
max_rpm: Maximum number of requests per minute for the agent execution to be respected.
|
||||
verbose: Whether the agent execution should be in verbose mode.
|
||||
allow_delegation: Whether the agent is allowed to delegate tasks to other agents.
|
||||
tools: Tools at agents disposal
|
||||
step_callback: Callback to be executed after each step of the agent execution.
|
||||
knowledge_sources: Knowledge sources for the agent.
|
||||
embedder: Embedder configuration for the agent.
|
||||
"""
|
||||
|
||||
_times_executed: int = PrivateAttr(default=0)
|
||||
@@ -70,6 +71,9 @@ class Agent(BaseAgent):
|
||||
)
|
||||
agent_ops_agent_name: str = None # type: ignore # Incompatible types in assignment (expression has type "None", variable has type "str")
|
||||
agent_ops_agent_id: str = None # type: ignore # Incompatible types in assignment (expression has type "None", variable has type "str")
|
||||
cache_handler: InstanceOf[CacheHandler] = Field(
|
||||
default=None, description="An instance of the CacheHandler class."
|
||||
)
|
||||
step_callback: Optional[Any] = Field(
|
||||
default=None,
|
||||
description="Callback to be executed after each step of the agent execution.",
|
||||
@@ -81,7 +85,7 @@ class Agent(BaseAgent):
|
||||
llm: Union[str, InstanceOf[LLM], Any] = Field(
|
||||
description="Language model that will run the agent.", default=None
|
||||
)
|
||||
function_calling_llm: Optional[Union[str, InstanceOf[LLM], Any]] = Field(
|
||||
function_calling_llm: Optional[Any] = Field(
|
||||
description="Language model that will run the agent.", default=None
|
||||
)
|
||||
system_template: Optional[str] = Field(
|
||||
@@ -103,6 +107,10 @@ class Agent(BaseAgent):
|
||||
default=True,
|
||||
description="Keep messages under the context window size by summarizing content.",
|
||||
)
|
||||
max_iter: int = Field(
|
||||
default=20,
|
||||
description="Maximum number of iterations for an agent to execute a task before giving it's best answer",
|
||||
)
|
||||
max_retry_limit: int = Field(
|
||||
default=2,
|
||||
description="Maximum number of retries for an agent to execute a task when an error occurs.",
|
||||
@@ -115,19 +123,105 @@ class Agent(BaseAgent):
|
||||
default="safe",
|
||||
description="Mode for code execution: 'safe' (using Docker) or 'unsafe' (direct execution).",
|
||||
)
|
||||
embedder: Optional[Dict[str, Any]] = Field(
|
||||
embedder_config: Optional[Dict[str, Any]] = Field(
|
||||
default=None,
|
||||
description="Embedder configuration for the agent.",
|
||||
)
|
||||
knowledge_sources: Optional[List[BaseKnowledgeSource]] = Field(
|
||||
default=None,
|
||||
description="Knowledge sources for the agent.",
|
||||
)
|
||||
_knowledge: Optional[Knowledge] = PrivateAttr(
|
||||
default=None,
|
||||
)
|
||||
|
||||
@model_validator(mode="after")
|
||||
def post_init_setup(self):
|
||||
self._set_knowledge()
|
||||
self.agent_ops_agent_name = self.role
|
||||
unaccepted_attributes = [
|
||||
"AWS_ACCESS_KEY_ID",
|
||||
"AWS_SECRET_ACCESS_KEY",
|
||||
"AWS_REGION_NAME",
|
||||
]
|
||||
|
||||
self.llm = create_llm(self.llm)
|
||||
if self.function_calling_llm and not isinstance(self.function_calling_llm, LLM):
|
||||
self.function_calling_llm = create_llm(self.function_calling_llm)
|
||||
# Handle different cases for self.llm
|
||||
if isinstance(self.llm, str):
|
||||
# If it's a string, create an LLM instance
|
||||
self.llm = LLM(model=self.llm)
|
||||
elif isinstance(self.llm, LLM):
|
||||
# If it's already an LLM instance, keep it as is
|
||||
pass
|
||||
elif self.llm is None:
|
||||
# Determine the model name from environment variables or use default
|
||||
model_name = (
|
||||
os.environ.get("OPENAI_MODEL_NAME")
|
||||
or os.environ.get("MODEL")
|
||||
or "gpt-4o-mini"
|
||||
)
|
||||
llm_params = {"model": model_name}
|
||||
|
||||
api_base = os.environ.get("OPENAI_API_BASE") or os.environ.get(
|
||||
"OPENAI_BASE_URL"
|
||||
)
|
||||
if api_base:
|
||||
llm_params["base_url"] = api_base
|
||||
|
||||
set_provider = model_name.split("/")[0] if "/" in model_name else "openai"
|
||||
|
||||
# Iterate over all environment variables to find matching API keys or use defaults
|
||||
for provider, env_vars in ENV_VARS.items():
|
||||
if provider == set_provider:
|
||||
for env_var in env_vars:
|
||||
# Check if the environment variable is set
|
||||
key_name = env_var.get("key_name")
|
||||
if key_name and key_name not in unaccepted_attributes:
|
||||
env_value = os.environ.get(key_name)
|
||||
if env_value:
|
||||
key_name = key_name.lower()
|
||||
for pattern in LITELLM_PARAMS:
|
||||
if pattern in key_name:
|
||||
key_name = pattern
|
||||
break
|
||||
llm_params[key_name] = env_value
|
||||
# Check for default values if the environment variable is not set
|
||||
elif env_var.get("default", False):
|
||||
for key, value in env_var.items():
|
||||
if key not in ["prompt", "key_name", "default"]:
|
||||
# Only add default if the key is already set in os.environ
|
||||
if key in os.environ:
|
||||
llm_params[key] = value
|
||||
|
||||
self.llm = LLM(**llm_params)
|
||||
else:
|
||||
# For any other type, attempt to extract relevant attributes
|
||||
llm_params = {
|
||||
"model": getattr(self.llm, "model_name", None)
|
||||
or getattr(self.llm, "deployment_name", None)
|
||||
or str(self.llm),
|
||||
"temperature": getattr(self.llm, "temperature", None),
|
||||
"max_tokens": getattr(self.llm, "max_tokens", None),
|
||||
"logprobs": getattr(self.llm, "logprobs", None),
|
||||
"timeout": getattr(self.llm, "timeout", None),
|
||||
"max_retries": getattr(self.llm, "max_retries", None),
|
||||
"api_key": getattr(self.llm, "api_key", None),
|
||||
"base_url": getattr(self.llm, "base_url", None),
|
||||
"organization": getattr(self.llm, "organization", None),
|
||||
}
|
||||
# Remove None values to avoid passing unnecessary parameters
|
||||
llm_params = {k: v for k, v in llm_params.items() if v is not None}
|
||||
self.llm = LLM(**llm_params)
|
||||
|
||||
# Similar handling for function_calling_llm
|
||||
if self.function_calling_llm:
|
||||
if isinstance(self.function_calling_llm, str):
|
||||
self.function_calling_llm = LLM(model=self.function_calling_llm)
|
||||
elif not isinstance(self.function_calling_llm, LLM):
|
||||
self.function_calling_llm = LLM(
|
||||
model=getattr(self.function_calling_llm, "model_name", None)
|
||||
or getattr(self.function_calling_llm, "deployment_name", None)
|
||||
or str(self.function_calling_llm)
|
||||
)
|
||||
|
||||
if not self.agent_executor:
|
||||
self._setup_agent_executor()
|
||||
@@ -145,16 +239,23 @@ class Agent(BaseAgent):
|
||||
def _set_knowledge(self):
|
||||
try:
|
||||
if self.knowledge_sources:
|
||||
full_pattern = re.compile(r"[^a-zA-Z0-9\-_\r\n]|(\.\.)")
|
||||
knowledge_agent_name = f"{re.sub(full_pattern, '_', self.role)}"
|
||||
knowledge_agent_name = f"{self.role.replace(' ', '_')}"
|
||||
if isinstance(self.knowledge_sources, list) and all(
|
||||
isinstance(k, BaseKnowledgeSource) for k in self.knowledge_sources
|
||||
):
|
||||
self.knowledge = Knowledge(
|
||||
# Validate embedding configuration based on provider
|
||||
from crewai.utilities.constants import DEFAULT_EMBEDDING_PROVIDER
|
||||
provider = os.getenv("CREWAI_EMBEDDING_PROVIDER", DEFAULT_EMBEDDING_PROVIDER)
|
||||
|
||||
if provider == "openai" and not os.getenv("OPENAI_API_KEY"):
|
||||
raise ValueError("Please provide an OpenAI API key via OPENAI_API_KEY environment variable")
|
||||
elif provider == "ollama" and not os.getenv("CREWAI_OLLAMA_URL", "http://localhost:11434/api/embeddings"):
|
||||
raise ValueError("Please provide Ollama URL via CREWAI_OLLAMA_URL environment variable")
|
||||
|
||||
self._knowledge = Knowledge(
|
||||
sources=self.knowledge_sources,
|
||||
embedder=self.embedder,
|
||||
embedder_config=self.embedder_config,
|
||||
collection_name=knowledge_agent_name,
|
||||
storage=self.knowledge_storage or None,
|
||||
)
|
||||
except (TypeError, ValueError) as e:
|
||||
raise ValueError(f"Invalid Knowledge Configuration: {str(e)}")
|
||||
@@ -188,15 +289,13 @@ class Agent(BaseAgent):
|
||||
if task.output_json:
|
||||
# schema = json.dumps(task.output_json, indent=2)
|
||||
schema = generate_model_description(task.output_json)
|
||||
task_prompt += "\n" + self.i18n.slice(
|
||||
"formatted_task_instructions"
|
||||
).format(output_format=schema)
|
||||
|
||||
elif task.output_pydantic:
|
||||
schema = generate_model_description(task.output_pydantic)
|
||||
task_prompt += "\n" + self.i18n.slice(
|
||||
"formatted_task_instructions"
|
||||
).format(output_format=schema)
|
||||
|
||||
task_prompt += "\n" + self.i18n.slice("formatted_task_instructions").format(
|
||||
output_format=schema
|
||||
)
|
||||
|
||||
if context:
|
||||
task_prompt = self.i18n.slice("task_with_context").format(
|
||||
@@ -215,8 +314,8 @@ class Agent(BaseAgent):
|
||||
if memory.strip() != "":
|
||||
task_prompt += self.i18n.slice("memory").format(memory=memory)
|
||||
|
||||
if self.knowledge:
|
||||
agent_knowledge_snippets = self.knowledge.query([task.prompt()])
|
||||
if self._knowledge:
|
||||
agent_knowledge_snippets = self._knowledge.query([task.prompt()])
|
||||
if agent_knowledge_snippets:
|
||||
agent_knowledge_context = extract_knowledge_context(
|
||||
agent_knowledge_snippets
|
||||
@@ -249,9 +348,6 @@ class Agent(BaseAgent):
|
||||
}
|
||||
)["output"]
|
||||
except Exception as e:
|
||||
if e.__class__.__module__.startswith("litellm"):
|
||||
# Do not retry on litellm errors
|
||||
raise e
|
||||
self._times_executed += 1
|
||||
if self._times_executed > self.max_retry_limit:
|
||||
raise e
|
||||
@@ -324,14 +420,13 @@ class Agent(BaseAgent):
|
||||
tools = agent_tools.tools()
|
||||
return tools
|
||||
|
||||
def get_multimodal_tools(self) -> Sequence[BaseTool]:
|
||||
def get_multimodal_tools(self) -> List[Tool]:
|
||||
from crewai.tools.agent_tools.add_image_tool import AddImageTool
|
||||
|
||||
return [AddImageTool()]
|
||||
|
||||
def get_code_execution_tools(self):
|
||||
try:
|
||||
from crewai_tools import CodeInterpreterTool # type: ignore
|
||||
from crewai_tools import CodeInterpreterTool
|
||||
|
||||
# Set the unsafe_mode based on the code_execution_mode attribute
|
||||
unsafe_mode = self.code_execution_mode == "unsafe"
|
||||
|
||||
@@ -113,7 +113,7 @@ class BaseAgent(ABC, BaseModel):
|
||||
description="Enable agent to delegate and ask questions among each other.",
|
||||
)
|
||||
tools: Optional[List[Any]] = Field(
|
||||
default_factory=list, description="Tools at agents' disposal"
|
||||
default_factory=lambda: [], description="Tools at agents' disposal"
|
||||
)
|
||||
max_iter: int = Field(
|
||||
default=25, description="Maximum iterations for an agent to execute a task"
|
||||
|
||||
@@ -1,12 +1,10 @@
|
||||
import asyncio
|
||||
import json
|
||||
import re
|
||||
import uuid
|
||||
import warnings
|
||||
from concurrent.futures import Future
|
||||
from copy import copy as shallow_copy
|
||||
from hashlib import md5
|
||||
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
|
||||
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
|
||||
|
||||
from pydantic import (
|
||||
UUID4,
|
||||
@@ -18,7 +16,6 @@ from pydantic import (
|
||||
field_validator,
|
||||
model_validator,
|
||||
)
|
||||
from pydantic_core import PydanticCustomError
|
||||
|
||||
from crewai.agent import Agent
|
||||
from crewai.agents.agent_builder.base_agent import BaseAgent
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
import json
|
||||
from typing import Any, Dict, Optional
|
||||
from typing import Any, Callable, Dict, Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic.main import IncEx
|
||||
from typing_extensions import Literal
|
||||
|
||||
from crewai.tasks.output_format import OutputFormat
|
||||
from crewai.tasks.task_output import TaskOutput
|
||||
@@ -21,16 +23,45 @@ class CrewOutput(BaseModel):
|
||||
tasks_output: list[TaskOutput] = Field(
|
||||
description="Output of each task", default=[]
|
||||
)
|
||||
token_usage: UsageMetrics = Field(description="Processed token summary", default={})
|
||||
token_usage: UsageMetrics = Field(description="Processed token summary", default_factory=UsageMetrics)
|
||||
|
||||
@property
|
||||
def json(self) -> Optional[str]:
|
||||
if self.tasks_output[-1].output_format != OutputFormat.JSON:
|
||||
def model_json(self) -> str:
|
||||
"""Get the JSON representation of the output."""
|
||||
if self.tasks_output and self.tasks_output[-1].output_format != OutputFormat.JSON:
|
||||
raise ValueError(
|
||||
"No JSON output found in the final task. Please make sure to set the output_json property in the final task in your crew."
|
||||
)
|
||||
return json.dumps(self.json_dict) if self.json_dict else "{}"
|
||||
|
||||
return json.dumps(self.json_dict)
|
||||
def model_dump_json(
|
||||
self,
|
||||
*,
|
||||
indent: Optional[int] = None,
|
||||
include: Optional[IncEx] = None,
|
||||
exclude: Optional[IncEx] = None,
|
||||
context: Optional[Any] = None,
|
||||
by_alias: bool = False,
|
||||
exclude_unset: bool = False,
|
||||
exclude_defaults: bool = False,
|
||||
exclude_none: bool = False,
|
||||
round_trip: bool = False,
|
||||
warnings: bool | Literal["none", "warn", "error"] = False,
|
||||
serialize_as_any: bool = False,
|
||||
) -> str:
|
||||
"""Override model_dump_json to handle custom JSON output."""
|
||||
return super().model_dump_json(
|
||||
indent=indent,
|
||||
include=include,
|
||||
exclude=exclude,
|
||||
context=context,
|
||||
by_alias=by_alias,
|
||||
exclude_unset=exclude_unset,
|
||||
exclude_defaults=exclude_defaults,
|
||||
exclude_none=exclude_none,
|
||||
round_trip=round_trip,
|
||||
warnings=warnings,
|
||||
serialize_as_any=serialize_as_any,
|
||||
)
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""Convert json_output and pydantic_output to a dictionary."""
|
||||
|
||||
@@ -47,7 +47,7 @@ class FastEmbed(BaseEmbedder):
|
||||
cache_dir=str(cache_dir) if cache_dir else None,
|
||||
)
|
||||
|
||||
def embed_chunks(self, chunks: List[str]) -> List[np.ndarray]:
|
||||
def embed_chunks(self, chunks: List[str]) -> np.ndarray:
|
||||
"""
|
||||
Generate embeddings for a list of text chunks
|
||||
|
||||
@@ -55,12 +55,12 @@ class FastEmbed(BaseEmbedder):
|
||||
chunks: List of text chunks to embed
|
||||
|
||||
Returns:
|
||||
List of embeddings
|
||||
Array of embeddings
|
||||
"""
|
||||
embeddings = list(self.model.embed(chunks))
|
||||
return embeddings
|
||||
return np.stack(embeddings)
|
||||
|
||||
def embed_texts(self, texts: List[str]) -> List[np.ndarray]:
|
||||
def embed_texts(self, texts: List[str]) -> np.ndarray:
|
||||
"""
|
||||
Generate embeddings for a list of texts
|
||||
|
||||
@@ -68,10 +68,10 @@ class FastEmbed(BaseEmbedder):
|
||||
texts: List of texts to embed
|
||||
|
||||
Returns:
|
||||
List of embeddings
|
||||
Array of embeddings
|
||||
"""
|
||||
embeddings = list(self.model.embed(texts))
|
||||
return embeddings
|
||||
return np.stack(embeddings)
|
||||
|
||||
def embed_text(self, text: str) -> np.ndarray:
|
||||
"""
|
||||
|
||||
@@ -154,9 +154,15 @@ class KnowledgeStorage(BaseKnowledgeStorage):
|
||||
filtered_ids.append(doc_id)
|
||||
|
||||
# If we have no metadata at all, set it to None
|
||||
final_metadata: Optional[OneOrMany[chromadb.Metadata]] = (
|
||||
None if all(m is None for m in filtered_metadata) else filtered_metadata
|
||||
)
|
||||
final_metadata: Optional[List[Dict[str, Union[str, int, float, bool]]]] = None
|
||||
if not all(m is None for m in filtered_metadata):
|
||||
final_metadata = []
|
||||
for m in filtered_metadata:
|
||||
if m is not None:
|
||||
filtered_m = {k: v for k, v in m.items() if isinstance(v, (str, int, float, bool))}
|
||||
final_metadata.append(filtered_m)
|
||||
else:
|
||||
final_metadata.append({"empty": True})
|
||||
|
||||
self.collection.upsert(
|
||||
documents=filtered_docs,
|
||||
|
||||
@@ -6,12 +6,17 @@ import shutil
|
||||
import uuid
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from chromadb.api import ClientAPI
|
||||
from chromadb.api import ClientAPI, Collection
|
||||
from chromadb.api.types import Documents, Embeddings, Metadatas
|
||||
|
||||
from crewai.memory.storage.base_rag_storage import BaseRAGStorage
|
||||
from crewai.utilities import EmbeddingConfigurator
|
||||
from crewai.utilities.constants import MAX_FILE_NAME_LENGTH
|
||||
from crewai.utilities.paths import db_storage_path
|
||||
from crewai.utilities.exceptions.embedding_exceptions import (
|
||||
EmbeddingConfigurationError,
|
||||
EmbeddingInitializationError
|
||||
)
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
@@ -32,15 +37,24 @@ def suppress_logging(
|
||||
|
||||
|
||||
class RAGStorage(BaseRAGStorage):
|
||||
"""
|
||||
Extends Storage to handle embeddings for memory entries, improving
|
||||
search efficiency.
|
||||
"""RAG-based Storage implementation using ChromaDB for vector storage and retrieval.
|
||||
|
||||
This class extends BaseRAGStorage to handle embeddings for memory entries,
|
||||
improving search efficiency through vector similarity.
|
||||
|
||||
Attributes:
|
||||
app: ChromaDB client instance
|
||||
collection: ChromaDB collection for storing embeddings
|
||||
type: Type of memory storage
|
||||
allow_reset: Whether memory reset is allowed
|
||||
path: Custom storage path for the database
|
||||
"""
|
||||
|
||||
app: ClientAPI | None = None
|
||||
collection: Any = None
|
||||
|
||||
def __init__(
|
||||
self, type, allow_reset=True, embedder_config=None, crew=None, path=None
|
||||
self, type: str, allow_reset: bool = True, embedder_config: Dict[str, Any] | None = None, crew: Any = None, path: str | None = None
|
||||
):
|
||||
super().__init__(type, allow_reset, embedder_config, crew)
|
||||
agents = crew.agents if crew else []
|
||||
@@ -50,7 +64,6 @@ class RAGStorage(BaseRAGStorage):
|
||||
self.storage_file_name = self._build_storage_file_name(type, agents)
|
||||
|
||||
self.type = type
|
||||
|
||||
self.allow_reset = allow_reset
|
||||
self.path = path
|
||||
self._initialize_app()
|
||||
@@ -59,26 +72,36 @@ class RAGStorage(BaseRAGStorage):
|
||||
configurator = EmbeddingConfigurator()
|
||||
self.embedder_config = configurator.configure_embedder(self.embedder_config)
|
||||
|
||||
def _initialize_app(self):
|
||||
def _initialize_app(self) -> None:
|
||||
"""Initialize the ChromaDB client and collection.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If ChromaDB client initialization fails
|
||||
EmbeddingConfigurationError: If embedding configuration is invalid
|
||||
EmbeddingInitializationError: If embedding function fails to initialize
|
||||
"""
|
||||
import chromadb
|
||||
from chromadb.config import Settings
|
||||
|
||||
self._set_embedder_config()
|
||||
chroma_client = chromadb.PersistentClient(
|
||||
path=self.path if self.path else self.storage_file_name,
|
||||
settings=Settings(allow_reset=self.allow_reset),
|
||||
)
|
||||
|
||||
self.app = chroma_client
|
||||
|
||||
try:
|
||||
self.collection = self.app.get_collection(
|
||||
name=self.type, embedding_function=self.embedder_config
|
||||
)
|
||||
except Exception:
|
||||
self.collection = self.app.create_collection(
|
||||
name=self.type, embedding_function=self.embedder_config
|
||||
self.app = chromadb.PersistentClient(
|
||||
path=self.path if self.path else self.storage_file_name,
|
||||
settings=Settings(allow_reset=self.allow_reset),
|
||||
)
|
||||
if not self.app:
|
||||
raise RuntimeError("Failed to initialize ChromaDB client")
|
||||
|
||||
try:
|
||||
self.collection = self.app.get_collection(
|
||||
name=self.type, embedding_function=self.embedder_config
|
||||
)
|
||||
except Exception:
|
||||
self.collection = self.app.create_collection(
|
||||
name=self.type, embedding_function=self.embedder_config
|
||||
)
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Failed to initialize ChromaDB: {str(e)}")
|
||||
|
||||
def _sanitize_role(self, role: str) -> str:
|
||||
"""
|
||||
@@ -101,12 +124,21 @@ class RAGStorage(BaseRAGStorage):
|
||||
return f"{base_path}/{file_name}"
|
||||
|
||||
def save(self, value: Any, metadata: Dict[str, Any]) -> None:
|
||||
"""Save a value with metadata to the memory storage.
|
||||
|
||||
Args:
|
||||
value: The text content to store
|
||||
metadata: Additional metadata for the stored content
|
||||
|
||||
Raises:
|
||||
EmbeddingInitializationError: If embedding generation fails
|
||||
"""
|
||||
if not hasattr(self, "app") or not hasattr(self, "collection"):
|
||||
self._initialize_app()
|
||||
try:
|
||||
self._generate_embedding(value, metadata)
|
||||
except Exception as e:
|
||||
logging.error(f"Error during {self.type} save: {str(e)}")
|
||||
raise EmbeddingInitializationError(self.type, str(e))
|
||||
|
||||
def search(
|
||||
self,
|
||||
@@ -114,7 +146,18 @@ class RAGStorage(BaseRAGStorage):
|
||||
limit: int = 3,
|
||||
filter: Optional[dict] = None,
|
||||
score_threshold: float = 0.35,
|
||||
) -> List[Any]:
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Search for similar content in memory.
|
||||
|
||||
Args:
|
||||
query: The search query text
|
||||
limit: Maximum number of results to return
|
||||
filter: Optional filter criteria
|
||||
score_threshold: Minimum similarity score threshold
|
||||
|
||||
Returns:
|
||||
List of matching results with metadata and scores
|
||||
"""
|
||||
if not hasattr(self, "app"):
|
||||
self._initialize_app()
|
||||
|
||||
@@ -138,37 +181,50 @@ class RAGStorage(BaseRAGStorage):
|
||||
logging.error(f"Error during {self.type} search: {str(e)}")
|
||||
return []
|
||||
|
||||
def _generate_embedding(self, text: str, metadata: Dict[str, Any]) -> None: # type: ignore
|
||||
def _generate_embedding(self, text: str, metadata: Optional[Dict[str, Any]] = None) -> Any:
|
||||
"""Generate and store embeddings for the given text.
|
||||
|
||||
Args:
|
||||
text: The text to generate embeddings for
|
||||
metadata: Optional additional metadata to store with the embeddings
|
||||
|
||||
Returns:
|
||||
Any: The generated embedding or None if only storing
|
||||
"""
|
||||
if not hasattr(self, "app") or not hasattr(self, "collection"):
|
||||
self._initialize_app()
|
||||
|
||||
self.collection.add(
|
||||
documents=[text],
|
||||
metadatas=[metadata or {}],
|
||||
ids=[str(uuid.uuid4())],
|
||||
)
|
||||
try:
|
||||
self.collection.add(
|
||||
documents=[text],
|
||||
metadatas=[metadata or {}],
|
||||
ids=[str(uuid.uuid4())],
|
||||
)
|
||||
return None
|
||||
except Exception as e:
|
||||
raise EmbeddingInitializationError(self.type, f"Failed to generate embedding: {str(e)}")
|
||||
|
||||
def reset(self) -> None:
|
||||
"""Reset the memory storage by clearing the database and removing files.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If memory reset fails and allow_reset is False
|
||||
EmbeddingConfigurationError: If embedding configuration is invalid during reinitialization
|
||||
"""
|
||||
try:
|
||||
if self.app:
|
||||
self.app.reset()
|
||||
shutil.rmtree(f"{db_storage_path()}/{self.type}")
|
||||
storage_path = self.path if self.path else db_storage_path()
|
||||
db_dir = os.path.join(storage_path, self.type)
|
||||
if os.path.exists(db_dir):
|
||||
shutil.rmtree(db_dir)
|
||||
self.app = None
|
||||
self.collection = None
|
||||
except Exception as e:
|
||||
if "attempt to write a readonly database" in str(e):
|
||||
# Ignore this specific error
|
||||
# Ignore this specific error as it's expected in some environments
|
||||
pass
|
||||
else:
|
||||
raise Exception(
|
||||
f"An error occurred while resetting the {self.type} memory: {e}"
|
||||
)
|
||||
|
||||
def _create_default_embedding_function(self):
|
||||
from chromadb.utils.embedding_functions.openai_embedding_function import (
|
||||
OpenAIEmbeddingFunction,
|
||||
)
|
||||
|
||||
return OpenAIEmbeddingFunction(
|
||||
api_key=os.getenv("OPENAI_API_KEY"), model_name="text-embedding-3-small"
|
||||
)
|
||||
if not self.allow_reset:
|
||||
raise RuntimeError(f"Failed to reset {self.type} memory: {str(e)}")
|
||||
logging.error(f"Error during {self.type} memory reset: {str(e)}")
|
||||
|
||||
@@ -9,11 +9,13 @@ from copy import copy
|
||||
from hashlib import md5
|
||||
from pathlib import Path
|
||||
from typing import (
|
||||
AbstractSet,
|
||||
Any,
|
||||
Callable,
|
||||
ClassVar,
|
||||
Dict,
|
||||
List,
|
||||
Mapping,
|
||||
Optional,
|
||||
Set,
|
||||
Tuple,
|
||||
@@ -109,7 +111,7 @@ class Task(BaseModel):
|
||||
description="Task output, it's final result after being executed", default=None
|
||||
)
|
||||
tools: Optional[List[BaseTool]] = Field(
|
||||
default_factory=list,
|
||||
default_factory=lambda: [],
|
||||
description="Tools the agent is limited to use for this task.",
|
||||
)
|
||||
id: UUID4 = Field(
|
||||
@@ -125,7 +127,7 @@ class Task(BaseModel):
|
||||
description="A converter class used to export structured output",
|
||||
default=None,
|
||||
)
|
||||
processed_by_agents: Set[str] = Field(default_factory=set)
|
||||
processed_by_agents: Set[str] = Field(default_factory=lambda: set())
|
||||
guardrail: Optional[Callable[[TaskOutput], Tuple[bool, Any]]] = Field(
|
||||
default=None,
|
||||
description="Function to validate task output before proceeding to next task",
|
||||
@@ -606,37 +608,56 @@ class Task(BaseModel):
|
||||
self.delegations += 1
|
||||
|
||||
def copy(
|
||||
self,
|
||||
*,
|
||||
include: Optional[AbstractSet[int] | AbstractSet[str] | Mapping[int, Any] | Mapping[str, Any]] = None,
|
||||
exclude: Optional[AbstractSet[int] | AbstractSet[str] | Mapping[int, Any] | Mapping[str, Any]] = None,
|
||||
update: Optional[Dict[str, Any]] = None,
|
||||
deep: bool = False,
|
||||
) -> "Task":
|
||||
"""Create a copy of the Task."""
|
||||
exclude_set = {"id", "agent", "context", "tools"}
|
||||
if exclude:
|
||||
if isinstance(exclude, (AbstractSet, set)):
|
||||
exclude_set.update(str(x) for x in exclude)
|
||||
elif isinstance(exclude, Mapping):
|
||||
exclude_set.update(str(x) for x in exclude.keys())
|
||||
|
||||
copied_task = super().copy(
|
||||
include=include,
|
||||
exclude=exclude_set,
|
||||
update=update,
|
||||
deep=deep,
|
||||
)
|
||||
|
||||
copied_task.id = uuid.uuid4()
|
||||
copied_task.agent = None
|
||||
copied_task.context = None
|
||||
copied_task.tools = []
|
||||
|
||||
return copied_task
|
||||
|
||||
def copy_with_agents(
|
||||
self, agents: List["BaseAgent"], task_mapping: Dict[str, "Task"]
|
||||
) -> "Task":
|
||||
"""Create a deep copy of the Task."""
|
||||
exclude = {
|
||||
"id",
|
||||
"agent",
|
||||
"context",
|
||||
"tools",
|
||||
}
|
||||
|
||||
copied_data = self.model_dump(exclude=exclude)
|
||||
copied_data = {k: v for k, v in copied_data.items() if v is not None}
|
||||
|
||||
cloned_context = (
|
||||
[task_mapping[context_task.key] for context_task in self.context]
|
||||
if self.context
|
||||
else None
|
||||
)
|
||||
"""Create a copy of the Task with agent references."""
|
||||
copied_task = self.copy()
|
||||
|
||||
def get_agent_by_role(role: str) -> Union["BaseAgent", None]:
|
||||
return next((agent for agent in agents if agent.role == role), None)
|
||||
|
||||
cloned_agent = get_agent_by_role(self.agent.role) if self.agent else None
|
||||
cloned_tools = copy(self.tools) if self.tools else []
|
||||
if self.agent:
|
||||
copied_task.agent = get_agent_by_role(self.agent.role)
|
||||
|
||||
copied_task = Task(
|
||||
**copied_data,
|
||||
context=cloned_context,
|
||||
agent=cloned_agent,
|
||||
tools=cloned_tools,
|
||||
)
|
||||
if self.context:
|
||||
copied_task.context = [
|
||||
task_mapping[context_task.key]
|
||||
for context_task in self.context
|
||||
if context_task.key in task_mapping
|
||||
]
|
||||
|
||||
if self.tools:
|
||||
copied_task.tools = copy(self.tools)
|
||||
|
||||
return copied_task
|
||||
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
import json
|
||||
from typing import Any, Dict, Optional
|
||||
from typing import Any, Callable, Dict, Optional
|
||||
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
from pydantic.main import IncEx
|
||||
from typing_extensions import Literal
|
||||
|
||||
from crewai.tasks.output_format import OutputFormat
|
||||
|
||||
@@ -34,8 +36,8 @@ class TaskOutput(BaseModel):
|
||||
self.summary = f"{excerpt}..."
|
||||
return self
|
||||
|
||||
@property
|
||||
def json(self) -> Optional[str]:
|
||||
def model_json(self) -> str:
|
||||
"""Get the JSON representation of the output."""
|
||||
if self.output_format != OutputFormat.JSON:
|
||||
raise ValueError(
|
||||
"""
|
||||
@@ -44,8 +46,37 @@ class TaskOutput(BaseModel):
|
||||
please make sure to set the output_json property for the task
|
||||
"""
|
||||
)
|
||||
return json.dumps(self.json_dict) if self.json_dict else "{}"
|
||||
|
||||
return json.dumps(self.json_dict)
|
||||
def model_dump_json(
|
||||
self,
|
||||
*,
|
||||
indent: Optional[int] = None,
|
||||
include: Optional[IncEx] = None,
|
||||
exclude: Optional[IncEx] = None,
|
||||
context: Optional[Any] = None,
|
||||
by_alias: bool = False,
|
||||
exclude_unset: bool = False,
|
||||
exclude_defaults: bool = False,
|
||||
exclude_none: bool = False,
|
||||
round_trip: bool = False,
|
||||
warnings: bool | Literal["none", "warn", "error"] = False,
|
||||
serialize_as_any: bool = False,
|
||||
) -> str:
|
||||
"""Override model_dump_json to handle custom JSON output."""
|
||||
return super().model_dump_json(
|
||||
indent=indent,
|
||||
include=include,
|
||||
exclude=exclude,
|
||||
context=context,
|
||||
by_alias=by_alias,
|
||||
exclude_unset=exclude_unset,
|
||||
exclude_defaults=exclude_defaults,
|
||||
exclude_none=exclude_none,
|
||||
round_trip=round_trip,
|
||||
warnings=warnings,
|
||||
serialize_as_any=serialize_as_any,
|
||||
)
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""Convert json_output and pydantic_output to a dictionary."""
|
||||
|
||||
@@ -82,12 +82,12 @@ class BaseAgentTool(BaseTool):
|
||||
available_agents = [agent.role for agent in self.agents]
|
||||
logger.debug(f"Available agents: {available_agents}")
|
||||
|
||||
agent = [ # type: ignore # Incompatible types in assignment (expression has type "list[BaseAgent]", variable has type "str | None")
|
||||
matching_agents = [
|
||||
available_agent
|
||||
for available_agent in self.agents
|
||||
if self.sanitize_agent_name(available_agent.role) == sanitized_name
|
||||
]
|
||||
logger.debug(f"Found {len(agent)} matching agents for role '{sanitized_name}'")
|
||||
logger.debug(f"Found {len(matching_agents)} matching agents for role '{sanitized_name}'")
|
||||
except (AttributeError, ValueError) as e:
|
||||
# Handle specific exceptions that might occur during role name processing
|
||||
return self.i18n.errors("agent_tool_unexisting_coworker").format(
|
||||
@@ -97,7 +97,7 @@ class BaseAgentTool(BaseTool):
|
||||
error=str(e)
|
||||
)
|
||||
|
||||
if not agent:
|
||||
if not matching_agents:
|
||||
# No matching agent found after sanitization
|
||||
return self.i18n.errors("agent_tool_unexisting_coworker").format(
|
||||
coworkers="\n".join(
|
||||
@@ -106,19 +106,19 @@ class BaseAgentTool(BaseTool):
|
||||
error=f"No agent found with role '{sanitized_name}'"
|
||||
)
|
||||
|
||||
agent = agent[0]
|
||||
selected_agent = matching_agents[0]
|
||||
try:
|
||||
task_with_assigned_agent = Task(
|
||||
description=task,
|
||||
agent=agent,
|
||||
expected_output=agent.i18n.slice("manager_request"),
|
||||
i18n=agent.i18n,
|
||||
agent=selected_agent,
|
||||
expected_output=selected_agent.i18n.slice("manager_request"),
|
||||
i18n=selected_agent.i18n,
|
||||
)
|
||||
logger.debug(f"Created task for agent '{self.sanitize_agent_name(agent.role)}': {task}")
|
||||
return agent.execute_task(task_with_assigned_agent, context)
|
||||
logger.debug(f"Created task for agent '{self.sanitize_agent_name(selected_agent.role)}': {task}")
|
||||
return selected_agent.execute_task(task_with_assigned_agent, context)
|
||||
except Exception as e:
|
||||
# Handle task creation or execution errors
|
||||
return self.i18n.errors("agent_tool_execution_error").format(
|
||||
agent_role=self.sanitize_agent_name(agent.role),
|
||||
agent_role=self.sanitize_agent_name(selected_agent.role),
|
||||
error=str(e)
|
||||
)
|
||||
|
||||
@@ -1,40 +1,36 @@
|
||||
import warnings
|
||||
from abc import ABC, abstractmethod
|
||||
from inspect import signature
|
||||
from typing import Any, Callable, Type, get_args, get_origin
|
||||
from typing import Any, Callable, Dict, Optional, Type, Tuple, get_args, get_origin
|
||||
|
||||
from pydantic import (
|
||||
BaseModel,
|
||||
ConfigDict,
|
||||
Field,
|
||||
PydanticDeprecatedSince20,
|
||||
create_model,
|
||||
validator,
|
||||
)
|
||||
from pydantic import BaseModel, ConfigDict, Field, create_model, validator
|
||||
from pydantic.fields import FieldInfo
|
||||
from pydantic import BaseModel as PydanticBaseModel
|
||||
|
||||
from crewai.tools.structured_tool import CrewStructuredTool
|
||||
|
||||
# Ignore all "PydanticDeprecatedSince20" warnings globally
|
||||
warnings.filterwarnings("ignore", category=PydanticDeprecatedSince20)
|
||||
|
||||
def _create_model_fields(fields: Dict[str, Tuple[Any, FieldInfo]]) -> Dict[str, Any]:
|
||||
"""Helper function to create model fields with proper type hints."""
|
||||
return {name: (annotation, field) for name, (annotation, field) in fields.items()}
|
||||
|
||||
class BaseTool(BaseModel, ABC):
|
||||
"""Base class for all tools."""
|
||||
|
||||
class _ArgsSchemaPlaceholder(PydanticBaseModel):
|
||||
pass
|
||||
|
||||
model_config = ConfigDict()
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
func: Optional[Callable] = None
|
||||
|
||||
name: str
|
||||
"""The unique name of the tool that clearly communicates its purpose."""
|
||||
description: str
|
||||
"""Used to tell the model how/when/why to use the tool."""
|
||||
args_schema: Type[PydanticBaseModel] = Field(default_factory=_ArgsSchemaPlaceholder)
|
||||
args_schema: Type[PydanticBaseModel] = Field(default=_ArgsSchemaPlaceholder)
|
||||
"""The schema for the arguments that the tool accepts."""
|
||||
description_updated: bool = False
|
||||
"""Flag to check if the description has been updated."""
|
||||
cache_function: Callable = lambda _args=None, _result=None: True
|
||||
"""Function that will be used to determine if the tool should be cached, should return a boolean. If None, the tool will be cached."""
|
||||
"""Function that will be used to determine if the tool should be cached."""
|
||||
result_as_answer: bool = False
|
||||
"""Flag to check if the tool should be the final agent answer."""
|
||||
|
||||
@@ -57,7 +53,6 @@ class BaseTool(BaseModel, ABC):
|
||||
|
||||
def model_post_init(self, __context: Any) -> None:
|
||||
self._generate_description()
|
||||
|
||||
super().model_post_init(__context)
|
||||
|
||||
def run(
|
||||
@@ -87,50 +82,7 @@ class BaseTool(BaseModel, ABC):
|
||||
result_as_answer=self.result_as_answer,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_langchain(cls, tool: Any) -> "BaseTool":
|
||||
"""Create a Tool instance from a CrewStructuredTool.
|
||||
|
||||
This method takes a CrewStructuredTool object and converts it into a
|
||||
Tool instance. It ensures that the provided tool has a callable 'func'
|
||||
attribute and infers the argument schema if not explicitly provided.
|
||||
"""
|
||||
if not hasattr(tool, "func") or not callable(tool.func):
|
||||
raise ValueError("The provided tool must have a callable 'func' attribute.")
|
||||
|
||||
args_schema = getattr(tool, "args_schema", None)
|
||||
|
||||
if args_schema is None:
|
||||
# Infer args_schema from the function signature if not provided
|
||||
func_signature = signature(tool.func)
|
||||
annotations = func_signature.parameters
|
||||
args_fields = {}
|
||||
for name, param in annotations.items():
|
||||
if name != "self":
|
||||
param_annotation = (
|
||||
param.annotation if param.annotation != param.empty else Any
|
||||
)
|
||||
field_info = Field(
|
||||
default=...,
|
||||
description="",
|
||||
)
|
||||
args_fields[name] = (param_annotation, field_info)
|
||||
if args_fields:
|
||||
args_schema = create_model(f"{tool.name}Input", **args_fields)
|
||||
else:
|
||||
# Create a default schema with no fields if no parameters are found
|
||||
args_schema = create_model(
|
||||
f"{tool.name}Input", __base__=PydanticBaseModel
|
||||
)
|
||||
|
||||
return cls(
|
||||
name=getattr(tool, "name", "Unnamed Tool"),
|
||||
description=getattr(tool, "description", ""),
|
||||
func=tool.func,
|
||||
args_schema=args_schema,
|
||||
)
|
||||
|
||||
def _set_args_schema(self):
|
||||
def _set_args_schema(self) -> None:
|
||||
if self.args_schema is None:
|
||||
class_name = f"{self.__class__.__name__}Schema"
|
||||
self.args_schema = type(
|
||||
@@ -145,7 +97,7 @@ class BaseTool(BaseModel, ABC):
|
||||
},
|
||||
)
|
||||
|
||||
def _generate_description(self):
|
||||
def _generate_description(self) -> None:
|
||||
args_schema = {
|
||||
name: {
|
||||
"description": field.description,
|
||||
@@ -179,79 +131,25 @@ class BaseTool(BaseModel, ABC):
|
||||
|
||||
|
||||
class Tool(BaseTool):
|
||||
"""The function that will be executed when the tool is called."""
|
||||
"""Tool class that wraps a function."""
|
||||
|
||||
func: Callable
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
if "func" not in kwargs:
|
||||
raise ValueError("Tool requires a 'func' argument")
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def _run(self, *args: Any, **kwargs: Any) -> Any:
|
||||
return self.func(*args, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def from_langchain(cls, tool: Any) -> "Tool":
|
||||
"""Create a Tool instance from a CrewStructuredTool.
|
||||
|
||||
This method takes a CrewStructuredTool object and converts it into a
|
||||
Tool instance. It ensures that the provided tool has a callable 'func'
|
||||
attribute and infers the argument schema if not explicitly provided.
|
||||
|
||||
Args:
|
||||
tool (Any): The CrewStructuredTool object to be converted.
|
||||
|
||||
Returns:
|
||||
Tool: A new Tool instance created from the provided CrewStructuredTool.
|
||||
|
||||
Raises:
|
||||
ValueError: If the provided tool does not have a callable 'func' attribute.
|
||||
"""
|
||||
if not hasattr(tool, "func") or not callable(tool.func):
|
||||
raise ValueError("The provided tool must have a callable 'func' attribute.")
|
||||
|
||||
args_schema = getattr(tool, "args_schema", None)
|
||||
|
||||
if args_schema is None:
|
||||
# Infer args_schema from the function signature if not provided
|
||||
func_signature = signature(tool.func)
|
||||
annotations = func_signature.parameters
|
||||
args_fields = {}
|
||||
for name, param in annotations.items():
|
||||
if name != "self":
|
||||
param_annotation = (
|
||||
param.annotation if param.annotation != param.empty else Any
|
||||
)
|
||||
field_info = Field(
|
||||
default=...,
|
||||
description="",
|
||||
)
|
||||
args_fields[name] = (param_annotation, field_info)
|
||||
if args_fields:
|
||||
args_schema = create_model(f"{tool.name}Input", **args_fields)
|
||||
else:
|
||||
# Create a default schema with no fields if no parameters are found
|
||||
args_schema = create_model(
|
||||
f"{tool.name}Input", __base__=PydanticBaseModel
|
||||
)
|
||||
|
||||
return cls(
|
||||
name=getattr(tool, "name", "Unnamed Tool"),
|
||||
description=getattr(tool, "description", ""),
|
||||
func=tool.func,
|
||||
args_schema=args_schema,
|
||||
)
|
||||
|
||||
|
||||
def to_langchain(
|
||||
tools: list[BaseTool | CrewStructuredTool],
|
||||
) -> list[CrewStructuredTool]:
|
||||
return [t.to_structured_tool() if isinstance(t, BaseTool) else t for t in tools]
|
||||
|
||||
|
||||
def tool(*args):
|
||||
"""
|
||||
Decorator to create a tool from a function.
|
||||
"""
|
||||
def tool(*args: Any) -> Any:
|
||||
"""Decorator to create a tool from a function."""
|
||||
|
||||
def _make_with_name(tool_name: str) -> Callable:
|
||||
def _make_tool(f: Callable) -> BaseTool:
|
||||
def _make_tool(f: Callable) -> Tool:
|
||||
if f.__doc__ is None:
|
||||
raise ValueError("Function must have a docstring")
|
||||
if f.__annotations__ is None:
|
||||
|
||||
@@ -2,9 +2,14 @@ from __future__ import annotations
|
||||
|
||||
import inspect
|
||||
import textwrap
|
||||
from typing import Any, Callable, Optional, Union, get_type_hints
|
||||
from typing import Any, Callable, Dict, Optional, Tuple, Union, get_type_hints
|
||||
|
||||
from pydantic import BaseModel, Field, create_model
|
||||
from pydantic import BaseModel, ConfigDict, Field, create_model
|
||||
from pydantic.fields import FieldInfo
|
||||
|
||||
def _create_model_fields(fields: Dict[str, Tuple[Any, FieldInfo]]) -> Dict[str, Any]:
|
||||
"""Helper function to create model fields with proper type hints."""
|
||||
return {name: (annotation, field) for name, (annotation, field) in fields.items()}
|
||||
|
||||
from crewai.utilities.logger import Logger
|
||||
|
||||
@@ -142,7 +147,8 @@ class CrewStructuredTool:
|
||||
|
||||
# Create model
|
||||
schema_name = f"{name.title()}Schema"
|
||||
return create_model(schema_name, **fields)
|
||||
model_fields = _create_model_fields(fields)
|
||||
return create_model(schema_name, __base__=BaseModel, **model_fields)
|
||||
|
||||
def _validate_function_signature(self) -> None:
|
||||
"""Validate that the function signature matches the args schema."""
|
||||
|
||||
@@ -4,3 +4,7 @@ DEFAULT_SCORE_THRESHOLD = 0.35
|
||||
KNOWLEDGE_DIRECTORY = "knowledge"
|
||||
MAX_LLM_RETRY = 3
|
||||
MAX_FILE_NAME_LENGTH = 255
|
||||
|
||||
# Default embedding configuration
|
||||
DEFAULT_EMBEDDING_PROVIDER = "openai"
|
||||
DEFAULT_EMBEDDING_MODEL = "text-embedding-3-small"
|
||||
|
||||
@@ -1,9 +1,15 @@
|
||||
import os
|
||||
from typing import Any, Dict, Optional, cast
|
||||
from typing import Any, Dict, List, Optional, cast
|
||||
|
||||
from chromadb import Documents, EmbeddingFunction, Embeddings
|
||||
from chromadb.api.types import validate_embedding_function
|
||||
|
||||
from crewai.utilities.exceptions.embedding_exceptions import (
|
||||
EmbeddingConfigurationError,
|
||||
EmbeddingProviderError,
|
||||
EmbeddingInitializationError
|
||||
)
|
||||
|
||||
|
||||
class EmbeddingConfigurator:
|
||||
def __init__(self):
|
||||
@@ -14,11 +20,9 @@ class EmbeddingConfigurator:
|
||||
"vertexai": self._configure_vertexai,
|
||||
"google": self._configure_google,
|
||||
"cohere": self._configure_cohere,
|
||||
"voyageai": self._configure_voyageai,
|
||||
"bedrock": self._configure_bedrock,
|
||||
"huggingface": self._configure_huggingface,
|
||||
"watson": self._configure_watson,
|
||||
"custom": self._configure_custom,
|
||||
}
|
||||
|
||||
def configure_embedder(
|
||||
@@ -31,156 +35,119 @@ class EmbeddingConfigurator:
|
||||
|
||||
provider = embedder_config.get("provider")
|
||||
config = embedder_config.get("config", {})
|
||||
model_name = config.get("model") if provider != "custom" else None
|
||||
model_name = config.get("model")
|
||||
|
||||
if provider not in self.embedding_functions:
|
||||
raise Exception(
|
||||
f"Unsupported embedding provider: {provider}, supported providers: {list(self.embedding_functions.keys())}"
|
||||
)
|
||||
if isinstance(provider, EmbeddingFunction):
|
||||
try:
|
||||
validate_embedding_function(provider)
|
||||
return provider
|
||||
except Exception as e:
|
||||
raise EmbeddingConfigurationError(f"Invalid custom embedding function: {str(e)}")
|
||||
|
||||
embedding_function = self.embedding_functions[provider]
|
||||
return (
|
||||
embedding_function(config)
|
||||
if provider == "custom"
|
||||
else embedding_function(config, model_name)
|
||||
)
|
||||
if not provider or provider not in self.embedding_functions:
|
||||
raise EmbeddingProviderError(str(provider), list(self.embedding_functions.keys()))
|
||||
|
||||
try:
|
||||
return self.embedding_functions[str(provider)](config, model_name)
|
||||
except Exception as e:
|
||||
raise EmbeddingInitializationError(str(provider), str(e))
|
||||
|
||||
@staticmethod
|
||||
def _create_default_embedding_function():
|
||||
from chromadb.utils.embedding_functions.openai_embedding_function import (
|
||||
OpenAIEmbeddingFunction,
|
||||
)
|
||||
|
||||
return OpenAIEmbeddingFunction(
|
||||
api_key=os.getenv("OPENAI_API_KEY"), model_name="text-embedding-3-small"
|
||||
)
|
||||
def _create_default_embedding_function() -> EmbeddingFunction:
|
||||
from crewai.utilities.constants import DEFAULT_EMBEDDING_PROVIDER, DEFAULT_EMBEDDING_MODEL
|
||||
|
||||
provider = os.getenv("CREWAI_EMBEDDING_PROVIDER", DEFAULT_EMBEDDING_PROVIDER)
|
||||
model = os.getenv("CREWAI_EMBEDDING_MODEL", DEFAULT_EMBEDDING_MODEL)
|
||||
|
||||
if provider == "openai":
|
||||
api_key = os.getenv("OPENAI_API_KEY")
|
||||
if not api_key:
|
||||
raise EmbeddingConfigurationError("OpenAI API key is required but not provided")
|
||||
from chromadb.utils.embedding_functions.openai_embedding_function import OpenAIEmbeddingFunction
|
||||
return OpenAIEmbeddingFunction(api_key=api_key, model_name=model)
|
||||
elif provider == "ollama":
|
||||
from chromadb.utils.embedding_functions.ollama_embedding_function import OllamaEmbeddingFunction
|
||||
url = os.getenv("CREWAI_OLLAMA_URL", "http://localhost:11434/api/embeddings")
|
||||
return OllamaEmbeddingFunction(url=url, model_name=model)
|
||||
else:
|
||||
raise EmbeddingProviderError(provider, ["openai", "ollama"])
|
||||
|
||||
@staticmethod
|
||||
def _configure_openai(config, model_name):
|
||||
from chromadb.utils.embedding_functions.openai_embedding_function import (
|
||||
OpenAIEmbeddingFunction,
|
||||
)
|
||||
|
||||
def _configure_openai(config: Dict[str, Any], model_name: str) -> EmbeddingFunction:
|
||||
from chromadb.utils.embedding_functions.openai_embedding_function import OpenAIEmbeddingFunction
|
||||
return OpenAIEmbeddingFunction(
|
||||
api_key=config.get("api_key") or os.getenv("OPENAI_API_KEY"),
|
||||
model_name=model_name,
|
||||
api_base=config.get("api_base", None),
|
||||
api_type=config.get("api_type", None),
|
||||
api_version=config.get("api_version", None),
|
||||
default_headers=config.get("default_headers", None),
|
||||
dimensions=config.get("dimensions", None),
|
||||
deployment_id=config.get("deployment_id", None),
|
||||
organization_id=config.get("organization_id", None),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _configure_azure(config, model_name):
|
||||
from chromadb.utils.embedding_functions.openai_embedding_function import (
|
||||
OpenAIEmbeddingFunction,
|
||||
)
|
||||
|
||||
def _configure_azure(config: Dict[str, Any], model_name: str) -> EmbeddingFunction:
|
||||
from chromadb.utils.embedding_functions.openai_embedding_function import OpenAIEmbeddingFunction
|
||||
return OpenAIEmbeddingFunction(
|
||||
api_key=config.get("api_key"),
|
||||
api_base=config.get("api_base"),
|
||||
api_type=config.get("api_type", "azure"),
|
||||
api_version=config.get("api_version"),
|
||||
model_name=model_name,
|
||||
default_headers=config.get("default_headers"),
|
||||
dimensions=config.get("dimensions"),
|
||||
deployment_id=config.get("deployment_id"),
|
||||
organization_id=config.get("organization_id"),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _configure_ollama(config, model_name):
|
||||
from chromadb.utils.embedding_functions.ollama_embedding_function import (
|
||||
OllamaEmbeddingFunction,
|
||||
)
|
||||
|
||||
def _configure_ollama(config: Dict[str, Any], model_name: str) -> EmbeddingFunction:
|
||||
from chromadb.utils.embedding_functions.ollama_embedding_function import OllamaEmbeddingFunction
|
||||
return OllamaEmbeddingFunction(
|
||||
url=config.get("url", "http://localhost:11434/api/embeddings"),
|
||||
model_name=model_name,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _configure_vertexai(config, model_name):
|
||||
from chromadb.utils.embedding_functions.google_embedding_function import (
|
||||
GoogleVertexEmbeddingFunction,
|
||||
)
|
||||
|
||||
def _configure_vertexai(config: Dict[str, Any], model_name: str) -> EmbeddingFunction:
|
||||
from chromadb.utils.embedding_functions.google_embedding_function import GoogleVertexEmbeddingFunction
|
||||
return GoogleVertexEmbeddingFunction(
|
||||
model_name=model_name,
|
||||
api_key=config.get("api_key"),
|
||||
project_id=config.get("project_id"),
|
||||
region=config.get("region"),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _configure_google(config, model_name):
|
||||
from chromadb.utils.embedding_functions.google_embedding_function import (
|
||||
GoogleGenerativeAiEmbeddingFunction,
|
||||
)
|
||||
|
||||
def _configure_google(config: Dict[str, Any], model_name: str) -> EmbeddingFunction:
|
||||
from chromadb.utils.embedding_functions.google_embedding_function import GoogleGenerativeAiEmbeddingFunction
|
||||
return GoogleGenerativeAiEmbeddingFunction(
|
||||
model_name=model_name,
|
||||
api_key=config.get("api_key"),
|
||||
task_type=config.get("task_type"),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _configure_cohere(config, model_name):
|
||||
from chromadb.utils.embedding_functions.cohere_embedding_function import (
|
||||
CohereEmbeddingFunction,
|
||||
)
|
||||
|
||||
def _configure_cohere(config: Dict[str, Any], model_name: str) -> EmbeddingFunction:
|
||||
from chromadb.utils.embedding_functions.cohere_embedding_function import CohereEmbeddingFunction
|
||||
return CohereEmbeddingFunction(
|
||||
model_name=model_name,
|
||||
api_key=config.get("api_key"),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _configure_voyageai(config, model_name):
|
||||
from chromadb.utils.embedding_functions.voyageai_embedding_function import (
|
||||
VoyageAIEmbeddingFunction,
|
||||
)
|
||||
|
||||
return VoyageAIEmbeddingFunction(
|
||||
model_name=model_name,
|
||||
api_key=config.get("api_key"),
|
||||
def _configure_bedrock(config: Dict[str, Any], model_name: str) -> EmbeddingFunction:
|
||||
from chromadb.utils.embedding_functions.amazon_bedrock_embedding_function import AmazonBedrockEmbeddingFunction
|
||||
return AmazonBedrockEmbeddingFunction(
|
||||
session=config.get("session"),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _configure_bedrock(config, model_name):
|
||||
from chromadb.utils.embedding_functions.amazon_bedrock_embedding_function import (
|
||||
AmazonBedrockEmbeddingFunction,
|
||||
)
|
||||
|
||||
# Allow custom model_name override with backwards compatibility
|
||||
kwargs = {"session": config.get("session")}
|
||||
if model_name is not None:
|
||||
kwargs["model_name"] = model_name
|
||||
return AmazonBedrockEmbeddingFunction(**kwargs)
|
||||
|
||||
@staticmethod
|
||||
def _configure_huggingface(config, model_name):
|
||||
from chromadb.utils.embedding_functions.huggingface_embedding_function import (
|
||||
HuggingFaceEmbeddingServer,
|
||||
)
|
||||
|
||||
def _configure_huggingface(config: Dict[str, Any], model_name: str) -> EmbeddingFunction:
|
||||
from chromadb.utils.embedding_functions.huggingface_embedding_function import HuggingFaceEmbeddingServer
|
||||
return HuggingFaceEmbeddingServer(
|
||||
url=config.get("api_url"),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _configure_watson(config, model_name):
|
||||
def _configure_watson(config: Dict[str, Any], model_name: str) -> EmbeddingFunction:
|
||||
try:
|
||||
import ibm_watsonx_ai.foundation_models as watson_models
|
||||
from ibm_watsonx_ai import Credentials
|
||||
from ibm_watsonx_ai.metanames import EmbedTextParamsMetaNames as EmbedParams
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"IBM Watson dependencies are not installed. Please install them to use Watson embedding."
|
||||
) from e
|
||||
raise EmbeddingConfigurationError(
|
||||
"IBM Watson dependencies are not installed. Please install them to use Watson embedding.",
|
||||
provider="watson"
|
||||
)
|
||||
|
||||
class WatsonEmbeddingFunction(EmbeddingFunction):
|
||||
def __call__(self, input: Documents) -> Embeddings:
|
||||
@@ -205,32 +172,6 @@ class EmbeddingConfigurator:
|
||||
embeddings = embedding.embed_documents(input)
|
||||
return cast(Embeddings, embeddings)
|
||||
except Exception as e:
|
||||
print("Error during Watson embedding:", e)
|
||||
raise e
|
||||
raise EmbeddingInitializationError("watson", str(e))
|
||||
|
||||
return WatsonEmbeddingFunction()
|
||||
|
||||
@staticmethod
|
||||
def _configure_custom(config):
|
||||
custom_embedder = config.get("embedder")
|
||||
if isinstance(custom_embedder, EmbeddingFunction):
|
||||
try:
|
||||
validate_embedding_function(custom_embedder)
|
||||
return custom_embedder
|
||||
except Exception as e:
|
||||
raise ValueError(f"Invalid custom embedding function: {str(e)}")
|
||||
elif callable(custom_embedder):
|
||||
try:
|
||||
instance = custom_embedder()
|
||||
if isinstance(instance, EmbeddingFunction):
|
||||
validate_embedding_function(instance)
|
||||
return instance
|
||||
raise ValueError(
|
||||
"Custom embedder does not create an EmbeddingFunction instance"
|
||||
)
|
||||
except Exception as e:
|
||||
raise ValueError(f"Error instantiating custom embedder: {str(e)}")
|
||||
else:
|
||||
raise ValueError(
|
||||
"Custom embedder must be an instance of `EmbeddingFunction` or a callable that creates one"
|
||||
)
|
||||
|
||||
20
src/crewai/utilities/exceptions/embedding_exceptions.py
Normal file
20
src/crewai/utilities/exceptions/embedding_exceptions.py
Normal file
@@ -0,0 +1,20 @@
|
||||
from typing import List, Optional
|
||||
|
||||
|
||||
class EmbeddingConfigurationError(Exception):
|
||||
def __init__(self, message: str, provider: Optional[str] = None):
|
||||
self.message = message
|
||||
self.provider = provider
|
||||
super().__init__(self.message)
|
||||
|
||||
|
||||
class EmbeddingProviderError(EmbeddingConfigurationError):
|
||||
def __init__(self, provider: str, supported_providers: List[str]):
|
||||
message = f"Unsupported embedding provider: {provider}, supported providers: {supported_providers}"
|
||||
super().__init__(message, provider)
|
||||
|
||||
|
||||
class EmbeddingInitializationError(EmbeddingConfigurationError):
|
||||
def __init__(self, provider: str, error: str):
|
||||
message = f"Failed to initialize embedding function for provider {provider}: {error}"
|
||||
super().__init__(message, provider)
|
||||
@@ -1,37 +1,30 @@
|
||||
# conftest.py
|
||||
import os
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
from dotenv import load_dotenv
|
||||
|
||||
load_result = load_dotenv(override=True)
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup_test_environment():
|
||||
"""Set up test environment with a temporary directory for SQLite storage."""
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
# Create the directory with proper permissions
|
||||
storage_dir = Path(temp_dir) / "crewai_test_storage"
|
||||
storage_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Validate that the directory was created successfully
|
||||
if not storage_dir.exists() or not storage_dir.is_dir():
|
||||
raise RuntimeError(f"Failed to create test storage directory: {storage_dir}")
|
||||
|
||||
# Verify directory permissions
|
||||
try:
|
||||
# Try to create a test file to verify write permissions
|
||||
test_file = storage_dir / ".permissions_test"
|
||||
test_file.touch()
|
||||
test_file.unlink()
|
||||
except (OSError, IOError) as e:
|
||||
raise RuntimeError(f"Test storage directory {storage_dir} is not writable: {e}")
|
||||
|
||||
# Set environment variable to point to the test storage directory
|
||||
os.environ["CREWAI_STORAGE_DIR"] = str(storage_dir)
|
||||
|
||||
yield
|
||||
|
||||
# Cleanup is handled automatically when tempfile context exits
|
||||
def setup_test_env():
|
||||
"""Configure test environment to use Ollama as the default embedding provider."""
|
||||
# Store original environment variables
|
||||
original_env = {
|
||||
"CREWAI_EMBEDDING_PROVIDER": os.environ.get("CREWAI_EMBEDDING_PROVIDER"),
|
||||
"CREWAI_EMBEDDING_MODEL": os.environ.get("CREWAI_EMBEDDING_MODEL"),
|
||||
"CREWAI_OLLAMA_URL": os.environ.get("CREWAI_OLLAMA_URL"),
|
||||
}
|
||||
|
||||
# Set test environment
|
||||
os.environ["CREWAI_EMBEDDING_PROVIDER"] = "ollama"
|
||||
os.environ["CREWAI_EMBEDDING_MODEL"] = "llama2"
|
||||
os.environ["CREWAI_OLLAMA_URL"] = "http://localhost:11434/api/embeddings"
|
||||
|
||||
yield
|
||||
|
||||
# Restore original environment
|
||||
for key, value in original_env.items():
|
||||
if value is None:
|
||||
os.environ.pop(key, None)
|
||||
else:
|
||||
os.environ[key] = value
|
||||
|
||||
91
tests/memory/test_memory_reset.py
Normal file
91
tests/memory/test_memory_reset.py
Normal file
@@ -0,0 +1,91 @@
|
||||
import os
|
||||
import tempfile
|
||||
import pytest
|
||||
from crewai.memory import ShortTermMemory, LongTermMemory, EntityMemory
|
||||
from crewai.utilities.exceptions.embedding_exceptions import (
|
||||
EmbeddingConfigurationError,
|
||||
EmbeddingProviderError
|
||||
)
|
||||
from crewai.utilities import EmbeddingConfigurator
|
||||
|
||||
@pytest.fixture
|
||||
def temp_db_dir():
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
yield tmpdir
|
||||
|
||||
def test_memory_reset_with_ollama(temp_db_dir):
|
||||
os.environ["CREWAI_EMBEDDING_PROVIDER"] = "ollama"
|
||||
os.environ["CREWAI_EMBEDDING_MODEL"] = "llama2"
|
||||
|
||||
memories = [
|
||||
ShortTermMemory(path=temp_db_dir),
|
||||
LongTermMemory(path=temp_db_dir),
|
||||
EntityMemory(path=temp_db_dir)
|
||||
]
|
||||
for memory in memories:
|
||||
memory.reset()
|
||||
|
||||
def test_memory_reset_with_openai(temp_db_dir):
|
||||
os.environ["CREWAI_EMBEDDING_PROVIDER"] = "openai"
|
||||
os.environ["CREWAI_EMBEDDING_MODEL"] = "text-embedding-3-small"
|
||||
|
||||
memories = [
|
||||
ShortTermMemory(path=temp_db_dir),
|
||||
LongTermMemory(path=temp_db_dir),
|
||||
EntityMemory(path=temp_db_dir)
|
||||
]
|
||||
for memory in memories:
|
||||
memory.reset()
|
||||
|
||||
def test_memory_reset_with_invalid_provider(temp_db_dir):
|
||||
os.environ["CREWAI_EMBEDDING_PROVIDER"] = "invalid_provider"
|
||||
with pytest.raises(EmbeddingProviderError):
|
||||
memories = [
|
||||
ShortTermMemory(path=temp_db_dir),
|
||||
LongTermMemory(path=temp_db_dir),
|
||||
EntityMemory(path=temp_db_dir)
|
||||
]
|
||||
for memory in memories:
|
||||
memory.reset()
|
||||
|
||||
def test_memory_reset_with_invalid_configuration(temp_db_dir):
|
||||
os.environ["CREWAI_EMBEDDING_PROVIDER"] = "openai"
|
||||
os.environ.pop("OPENAI_API_KEY", None)
|
||||
|
||||
with pytest.raises(EmbeddingConfigurationError):
|
||||
memories = [
|
||||
ShortTermMemory(path=temp_db_dir),
|
||||
LongTermMemory(path=temp_db_dir),
|
||||
EntityMemory(path=temp_db_dir)
|
||||
]
|
||||
for memory in memories:
|
||||
memory.reset()
|
||||
|
||||
def test_memory_reset_with_missing_ollama_url(temp_db_dir):
|
||||
os.environ["CREWAI_EMBEDDING_PROVIDER"] = "ollama"
|
||||
os.environ.pop("CREWAI_OLLAMA_URL", None)
|
||||
# Should use default URL when CREWAI_OLLAMA_URL is not set
|
||||
memories = [
|
||||
ShortTermMemory(path=temp_db_dir),
|
||||
LongTermMemory(path=temp_db_dir),
|
||||
EntityMemory(path=temp_db_dir)
|
||||
]
|
||||
for memory in memories:
|
||||
memory.reset()
|
||||
|
||||
def test_memory_reset_with_custom_path(temp_db_dir):
|
||||
os.environ["CREWAI_EMBEDDING_PROVIDER"] = "ollama"
|
||||
custom_path = os.path.join(temp_db_dir, "custom")
|
||||
os.makedirs(custom_path, exist_ok=True)
|
||||
|
||||
memories = [
|
||||
ShortTermMemory(path=custom_path),
|
||||
LongTermMemory(path=custom_path),
|
||||
EntityMemory(path=custom_path)
|
||||
]
|
||||
for memory in memories:
|
||||
memory.reset()
|
||||
|
||||
assert not os.path.exists(os.path.join(custom_path, "short_term"))
|
||||
assert not os.path.exists(os.path.join(custom_path, "long_term"))
|
||||
assert not os.path.exists(os.path.join(custom_path, "entity"))
|
||||
Reference in New Issue
Block a user