Compare commits

...

32 Commits

Author SHA1 Message Date
João Moura
20fc2f9878 Merge branch 'main' into devin/1738752192-fix-memory-reset-openai-dependency 2025-02-09 20:10:50 -03:00
Devin AI
c149b75874 fix: Update metadata type handling in KnowledgeStorage
Co-Authored-By: Joe Moura <joao@crewai.com>
2025-02-09 19:44:22 +00:00
Devin AI
86844ff3df fix: Update _generate_embedding signature to match BaseRAGStorage
Co-Authored-By: Joe Moura <joao@crewai.com>
2025-02-09 19:43:24 +00:00
Devin AI
b442fe20a2 fix: Add to_structured_tool method to BaseTool
Co-Authored-By: Joe Moura <joao@crewai.com>
2025-02-09 19:42:17 +00:00
Devin AI
9b1b1d33ba fix: Handle exclude parameter type conversion in Task.copy
Co-Authored-By: Joe Moura <joao@crewai.com>
2025-02-09 19:40:55 +00:00
Devin AI
3c350e8933 fix: Update copy method in Task to match BaseModel signature
Co-Authored-By: Joe Moura <joao@crewai.com>
2025-02-09 19:39:59 +00:00
Devin AI
a3a5507f9a fix: Update default_factory in BaseAgent to use lambda functions
Co-Authored-By: Joe Moura <joao@crewai.com>
2025-02-09 19:38:14 +00:00
Devin AI
a175167aaf fix: Update default_factory in Task to use lambda functions
Co-Authored-By: Joe Moura <joao@crewai.com>
2025-02-09 19:38:13 +00:00
Devin AI
1dc62b0d0a fix: Update args_schema type in BaseTool
Co-Authored-By: Joe Moura <joao@crewai.com>
2025-02-09 19:36:55 +00:00
Devin AI
75b376ebac fix: Use UsageMetrics as default_factory for token_usage
Co-Authored-By: Joe Moura <joao@crewai.com>
2025-02-09 19:35:50 +00:00
Devin AI
29106068b7 fix: Use pydantic.main.IncEx and fix default_factory types
Co-Authored-By: Joe Moura <joao@crewai.com>
2025-02-09 19:34:53 +00:00
Devin AI
3bf531189f fix: Update type hints in TaskOutput and CrewOutput to match BaseModel
Co-Authored-By: Joe Moura <joao@crewai.com>
2025-02-09 19:33:54 +00:00
Devin AI
47919a60a0 fix: Replace json with model_json to avoid overriding BaseModel methods
Co-Authored-By: Joe Moura <joao@crewai.com>
2025-02-09 19:32:42 +00:00
Devin AI
6b9ed90510 fix: Update base_agent_tools and base_tool to fix type errors
Co-Authored-By: Joe Moura <joao@crewai.com>
2025-02-09 19:30:44 +00:00
Devin AI
f6a65486f1 fix: Update json and model_dump_json signatures to match BaseModel
Co-Authored-By: Joe Moura <joao@crewai.com>
2025-02-09 19:27:57 +00:00
Devin AI
bf6db93bdf fix: Add proper type hints to EmbeddingConfigurator
Co-Authored-By: Joe Moura <joao@crewai.com>
2025-02-09 19:23:04 +00:00
Devin AI
25e68bc459 fix: Restore agent.py and fix merge conflicts
Co-Authored-By: Joe Moura <joao@crewai.com>
2025-02-09 19:21:31 +00:00
Devin AI
6f6010db1c fix: Resolve merge conflicts properly
Co-Authored-By: Joe Moura <joao@crewai.com>
2025-02-09 19:19:18 +00:00
Devin AI
a95227deef fix: Resolve merge conflicts, keeping our changes
Co-Authored-By: Joe Moura <joao@crewai.com>
2025-02-09 19:18:12 +00:00
Devin AI
636dac6efb fix: Update embedding configuration and fix type errors
- Add configurable embedding providers (OpenAI, Ollama)
- Fix type hints in base_tool and structured_tool
- Add proper json property implementations
- Update documentation for memory configuration
- Add environment variables for embedding configuration
- Fix type errors in task and crew output classes

Co-Authored-By: Joe Moura <joao@crewai.com>
2025-02-09 19:16:51 +00:00
João Moura
a4e2b17bae Merge branch 'main' into devin/1738752192-fix-memory-reset-openai-dependency 2025-02-09 15:43:20 -03:00
Devin AI
823f22a601 fix: Remove OpenAI dependency for memory reset when using alternative LLMs
- Add environment variables for default embedding provider
- Support Ollama as default embedding provider
- Add tests for memory reset with different providers
- Update documentation

Fixes #2023

Co-Authored-By: Joe Moura <joao@crewai.com>
2025-02-05 11:09:36 +00:00
Devin AI
649414805d fix: Remove OpenAI dependency for memory reset when using alternative LLMs
- Add environment variables for default embedding provider
- Support Ollama as default embedding provider
- Add tests for memory reset with different providers
- Update documentation

Fixes #2023

Co-Authored-By: Joe Moura <joao@crewai.com>
2025-02-05 10:56:01 +00:00
Devin AI
8017ab2dfd docs: Add memory configuration documentation
Co-Authored-By: Joe Moura <joao@crewai.com>
2025-02-05 10:46:43 +00:00
Devin AI
6445cda35a fix: Use custom path for memory reset
Co-Authored-By: Joe Moura <joao@crewai.com>
2025-02-05 10:46:21 +00:00
Devin AI
6116c73721 test: Use temporary directory for memory reset tests
Co-Authored-By: Joe Moura <joao@crewai.com>
2025-02-05 10:45:53 +00:00
Devin AI
a038b751ef test: Add memory reset tests for different embedding providers
Co-Authored-By: Joe Moura <joao@crewai.com>
2025-02-05 10:45:21 +00:00
Devin AI
5006161d31 fix: Add type safety to RAGStorage initialization
Co-Authored-By: Joe Moura <joao@crewai.com>
2025-02-05 10:45:07 +00:00
Devin AI
85a13751ba fix: Add type safety to embedding configurator
Co-Authored-By: Joe Moura <joao@crewai.com>
2025-02-05 10:44:59 +00:00
Devin AI
1c7c4cb828 refactor: Remove duplicate embedding function from RAGStorage
Co-Authored-By: Joe Moura <joao@crewai.com>
2025-02-05 10:44:40 +00:00
Devin AI
509fb375ca feat: Update embedding configurator to support configurable default providers
Co-Authored-By: Joe Moura <joao@crewai.com>
2025-02-05 10:43:58 +00:00
Devin AI
d01d44b29c feat: Add default embedding provider and model constants
Co-Authored-By: Joe Moura <joao@crewai.com>
2025-02-05 10:43:38 +00:00
18 changed files with 645 additions and 410 deletions

docs/memory.md (new file, +45 lines)

@@ -0,0 +1,45 @@
# Memory in CrewAI

CrewAI provides a robust memory system that allows agents to retain and recall information from previous interactions.

## Configuring Embedding Providers

CrewAI supports multiple embedding providers for memory functionality:

- OpenAI (default) - Requires `OPENAI_API_KEY`
- Ollama - Uses `CREWAI_OLLAMA_URL` (defaults to `http://localhost:11434/api/embeddings`)

### Environment Variables

Configure the embedding provider using these environment variables:

- `CREWAI_EMBEDDING_PROVIDER`: Provider name (default: "openai")
- `CREWAI_EMBEDDING_MODEL`: Model name (default: "text-embedding-3-small")
- `CREWAI_OLLAMA_URL`: URL for the Ollama API (when using the Ollama provider)

### Example Configuration

```python
import os

# Using OpenAI (default)
os.environ["OPENAI_API_KEY"] = "your-api-key"

# Using Ollama
os.environ["CREWAI_EMBEDDING_PROVIDER"] = "ollama"
os.environ["CREWAI_EMBEDDING_MODEL"] = "llama2"  # or any other model supported by your Ollama instance
os.environ["CREWAI_OLLAMA_URL"] = "http://localhost:11434/api/embeddings"  # optional, this is the default
```

## Memory Usage

When an agent has memory enabled, it can access and store information from previous interactions:

```python
from crewai import Agent

agent = Agent(
    role="Researcher",
    goal="Research AI topics",
    backstory="You're an AI researcher",
    memory=True  # Enable memory for this agent
)
```

The memory system uses embeddings to store and retrieve relevant information, allowing agents to maintain context across multiple interactions and tasks.
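The same providers can also be configured per crew, without environment variables. A minimal sketch, assuming the `Crew` constructor accepts `memory` and `embedder` arguments (the `embedder` dict shape follows the configurator in this changeset; treat the exact values as illustrative):

```python
from crewai import Agent, Crew, Task

researcher = Agent(
    role="Researcher",
    goal="Research AI topics",
    backstory="You're an AI researcher",
    memory=True,
)

task = Task(
    description="Summarize recent AI research",
    expected_output="A short summary",
    agent=researcher,
)

# Explicit embedder configuration instead of environment variables
crew = Crew(
    agents=[researcher],
    tasks=[task],
    memory=True,
    embedder={
        "provider": "ollama",
        "config": {"model": "llama2"},
    },
)
```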


@@ -1,13 +1,14 @@
import re
import os
import shutil
import subprocess
from typing import Any, Dict, List, Literal, Optional, Sequence, Union
from typing import Any, Dict, List, Literal, Optional, Union
from pydantic import Field, InstanceOf, PrivateAttr, model_validator
from crewai.agents import CacheHandler
from crewai.agents.agent_builder.base_agent import BaseAgent
from crewai.agents.crew_agent_executor import CrewAgentExecutor
from crewai.cli.constants import ENV_VARS, LITELLM_PARAMS
from crewai.knowledge.knowledge import Knowledge
from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource
from crewai.knowledge.utils.knowledge_utils import extract_knowledge_context
@@ -16,10 +17,10 @@ from crewai.memory.contextual.contextual_memory import ContextualMemory
from crewai.task import Task
from crewai.tools import BaseTool
from crewai.tools.agent_tools.agent_tools import AgentTools
from crewai.tools.base_tool import Tool
from crewai.utilities import Converter, Prompts
from crewai.utilities.constants import TRAINED_AGENTS_DATA_FILE, TRAINING_DATA_FILE
from crewai.utilities.converter import generate_model_description
from crewai.utilities.llm_utils import create_llm
from crewai.utilities.token_counter_callback import TokenCalcHandler
from crewai.utilities.training_handler import CrewTrainingHandler
@@ -54,13 +55,13 @@ class Agent(BaseAgent):
llm: The language model that will run the agent.
function_calling_llm: The language model that will handle the tool calling for this agent; it overrides the crew's function_calling_llm.
max_iter: Maximum number of iterations for an agent to execute a task.
memory: Whether the agent should have memory or not.
max_rpm: Maximum number of requests per minute for the agent execution to be respected.
verbose: Whether the agent execution should be in verbose mode.
allow_delegation: Whether the agent is allowed to delegate tasks to other agents.
tools: Tools at the agent's disposal.
step_callback: Callback to be executed after each step of the agent execution.
knowledge_sources: Knowledge sources for the agent.
embedder: Embedder configuration for the agent.
"""
_times_executed: int = PrivateAttr(default=0)
@@ -70,6 +71,9 @@ class Agent(BaseAgent):
)
agent_ops_agent_name: str = None # type: ignore # Incompatible types in assignment (expression has type "None", variable has type "str")
agent_ops_agent_id: str = None # type: ignore # Incompatible types in assignment (expression has type "None", variable has type "str")
cache_handler: InstanceOf[CacheHandler] = Field(
default=None, description="An instance of the CacheHandler class."
)
step_callback: Optional[Any] = Field(
default=None,
description="Callback to be executed after each step of the agent execution.",
@@ -81,7 +85,7 @@ class Agent(BaseAgent):
llm: Union[str, InstanceOf[LLM], Any] = Field(
description="Language model that will run the agent.", default=None
)
function_calling_llm: Optional[Union[str, InstanceOf[LLM], Any]] = Field(
function_calling_llm: Optional[Any] = Field(
description="Language model that will run the agent.", default=None
)
system_template: Optional[str] = Field(
@@ -103,6 +107,10 @@ class Agent(BaseAgent):
default=True,
description="Keep messages under the context window size by summarizing content.",
)
max_iter: int = Field(
default=20,
description="Maximum number of iterations for an agent to execute a task before giving it's best answer",
)
max_retry_limit: int = Field(
default=2,
description="Maximum number of retries for an agent to execute a task when an error occurs.",
@@ -115,19 +123,105 @@ class Agent(BaseAgent):
default="safe",
description="Mode for code execution: 'safe' (using Docker) or 'unsafe' (direct execution).",
)
embedder: Optional[Dict[str, Any]] = Field(
embedder_config: Optional[Dict[str, Any]] = Field(
default=None,
description="Embedder configuration for the agent.",
)
knowledge_sources: Optional[List[BaseKnowledgeSource]] = Field(
default=None,
description="Knowledge sources for the agent.",
)
_knowledge: Optional[Knowledge] = PrivateAttr(
default=None,
)
@model_validator(mode="after")
def post_init_setup(self):
self._set_knowledge()
self.agent_ops_agent_name = self.role
unaccepted_attributes = [
"AWS_ACCESS_KEY_ID",
"AWS_SECRET_ACCESS_KEY",
"AWS_REGION_NAME",
]
self.llm = create_llm(self.llm)
if self.function_calling_llm and not isinstance(self.function_calling_llm, LLM):
self.function_calling_llm = create_llm(self.function_calling_llm)
# Handle different cases for self.llm
if isinstance(self.llm, str):
# If it's a string, create an LLM instance
self.llm = LLM(model=self.llm)
elif isinstance(self.llm, LLM):
# If it's already an LLM instance, keep it as is
pass
elif self.llm is None:
# Determine the model name from environment variables or use default
model_name = (
os.environ.get("OPENAI_MODEL_NAME")
or os.environ.get("MODEL")
or "gpt-4o-mini"
)
llm_params = {"model": model_name}
api_base = os.environ.get("OPENAI_API_BASE") or os.environ.get(
"OPENAI_BASE_URL"
)
if api_base:
llm_params["base_url"] = api_base
set_provider = model_name.split("/")[0] if "/" in model_name else "openai"
# Iterate over all environment variables to find matching API keys or use defaults
for provider, env_vars in ENV_VARS.items():
if provider == set_provider:
for env_var in env_vars:
# Check if the environment variable is set
key_name = env_var.get("key_name")
if key_name and key_name not in unaccepted_attributes:
env_value = os.environ.get(key_name)
if env_value:
key_name = key_name.lower()
for pattern in LITELLM_PARAMS:
if pattern in key_name:
key_name = pattern
break
llm_params[key_name] = env_value
# Check for default values if the environment variable is not set
elif env_var.get("default", False):
for key, value in env_var.items():
if key not in ["prompt", "key_name", "default"]:
# Only add default if the key is already set in os.environ
if key in os.environ:
llm_params[key] = value
self.llm = LLM(**llm_params)
else:
# For any other type, attempt to extract relevant attributes
llm_params = {
"model": getattr(self.llm, "model_name", None)
or getattr(self.llm, "deployment_name", None)
or str(self.llm),
"temperature": getattr(self.llm, "temperature", None),
"max_tokens": getattr(self.llm, "max_tokens", None),
"logprobs": getattr(self.llm, "logprobs", None),
"timeout": getattr(self.llm, "timeout", None),
"max_retries": getattr(self.llm, "max_retries", None),
"api_key": getattr(self.llm, "api_key", None),
"base_url": getattr(self.llm, "base_url", None),
"organization": getattr(self.llm, "organization", None),
}
# Remove None values to avoid passing unnecessary parameters
llm_params = {k: v for k, v in llm_params.items() if v is not None}
self.llm = LLM(**llm_params)
# Similar handling for function_calling_llm
if self.function_calling_llm:
if isinstance(self.function_calling_llm, str):
self.function_calling_llm = LLM(model=self.function_calling_llm)
elif not isinstance(self.function_calling_llm, LLM):
self.function_calling_llm = LLM(
model=getattr(self.function_calling_llm, "model_name", None)
or getattr(self.function_calling_llm, "deployment_name", None)
or str(self.function_calling_llm)
)
if not self.agent_executor:
self._setup_agent_executor()
@@ -145,16 +239,23 @@ class Agent(BaseAgent):
def _set_knowledge(self):
try:
if self.knowledge_sources:
full_pattern = re.compile(r"[^a-zA-Z0-9\-_\r\n]|(\.\.)")
knowledge_agent_name = f"{re.sub(full_pattern, '_', self.role)}"
knowledge_agent_name = f"{self.role.replace(' ', '_')}"
if isinstance(self.knowledge_sources, list) and all(
isinstance(k, BaseKnowledgeSource) for k in self.knowledge_sources
):
self.knowledge = Knowledge(
# Validate embedding configuration based on provider
from crewai.utilities.constants import DEFAULT_EMBEDDING_PROVIDER
provider = os.getenv("CREWAI_EMBEDDING_PROVIDER", DEFAULT_EMBEDDING_PROVIDER)
if provider == "openai" and not os.getenv("OPENAI_API_KEY"):
raise ValueError("Please provide an OpenAI API key via OPENAI_API_KEY environment variable")
elif provider == "ollama" and not os.getenv("CREWAI_OLLAMA_URL", "http://localhost:11434/api/embeddings"):
raise ValueError("Please provide Ollama URL via CREWAI_OLLAMA_URL environment variable")
self._knowledge = Knowledge(
sources=self.knowledge_sources,
embedder=self.embedder,
embedder_config=self.embedder_config,
collection_name=knowledge_agent_name,
storage=self.knowledge_storage or None,
)
except (TypeError, ValueError) as e:
raise ValueError(f"Invalid Knowledge Configuration: {str(e)}")
@@ -188,15 +289,13 @@ class Agent(BaseAgent):
if task.output_json:
# schema = json.dumps(task.output_json, indent=2)
schema = generate_model_description(task.output_json)
task_prompt += "\n" + self.i18n.slice(
"formatted_task_instructions"
).format(output_format=schema)
elif task.output_pydantic:
schema = generate_model_description(task.output_pydantic)
task_prompt += "\n" + self.i18n.slice(
"formatted_task_instructions"
).format(output_format=schema)
task_prompt += "\n" + self.i18n.slice("formatted_task_instructions").format(
output_format=schema
)
if context:
task_prompt = self.i18n.slice("task_with_context").format(
@@ -215,8 +314,8 @@ class Agent(BaseAgent):
if memory.strip() != "":
task_prompt += self.i18n.slice("memory").format(memory=memory)
if self.knowledge:
agent_knowledge_snippets = self.knowledge.query([task.prompt()])
if self._knowledge:
agent_knowledge_snippets = self._knowledge.query([task.prompt()])
if agent_knowledge_snippets:
agent_knowledge_context = extract_knowledge_context(
agent_knowledge_snippets
@@ -249,9 +348,6 @@ class Agent(BaseAgent):
}
)["output"]
except Exception as e:
if e.__class__.__module__.startswith("litellm"):
# Do not retry on litellm errors
raise e
self._times_executed += 1
if self._times_executed > self.max_retry_limit:
raise e
@@ -324,14 +420,13 @@ class Agent(BaseAgent):
tools = agent_tools.tools()
return tools
def get_multimodal_tools(self) -> Sequence[BaseTool]:
def get_multimodal_tools(self) -> List[Tool]:
from crewai.tools.agent_tools.add_image_tool import AddImageTool
return [AddImageTool()]
def get_code_execution_tools(self):
try:
from crewai_tools import CodeInterpreterTool # type: ignore
from crewai_tools import CodeInterpreterTool
# Set the unsafe_mode based on the code_execution_mode attribute
unsafe_mode = self.code_execution_mode == "unsafe"
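One practical consequence of the new `post_init_setup` above: an `Agent` can be built from a plain model string, or with no `llm` at all when a model name is set in the environment. A hedged sketch of that behavior (illustrative, not part of the diff; a valid API key is still needed to actually run the agent):

```python
import os

from crewai import Agent

# A plain string is wrapped in an LLM instance
agent = Agent(
    role="Researcher",
    goal="Research AI topics",
    backstory="You're an AI researcher",
    llm="gpt-4o-mini",
)

# With llm omitted, the model name comes from OPENAI_MODEL_NAME or MODEL,
# falling back to "gpt-4o-mini" when neither is set
os.environ["OPENAI_MODEL_NAME"] = "gpt-4o"
agent_from_env = Agent(
    role="Writer",
    goal="Write summaries",
    backstory="You're a technical writer",
)
```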


@@ -113,7 +113,7 @@ class BaseAgent(ABC, BaseModel):
description="Enable agent to delegate and ask questions among each other.",
)
tools: Optional[List[Any]] = Field(
default_factory=list, description="Tools at agents' disposal"
default_factory=lambda: [], description="Tools at agents' disposal"
)
max_iter: int = Field(
default=25, description="Maximum iterations for an agent to execute a task"


@@ -1,12 +1,10 @@
import asyncio
import json
import re
import uuid
import warnings
from concurrent.futures import Future
from copy import copy as shallow_copy
from hashlib import md5
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
from pydantic import (
UUID4,
@@ -18,7 +16,6 @@ from pydantic import (
field_validator,
model_validator,
)
from pydantic_core import PydanticCustomError
from crewai.agent import Agent
from crewai.agents.agent_builder.base_agent import BaseAgent


@@ -1,7 +1,9 @@
import json
from typing import Any, Dict, Optional
from typing import Any, Callable, Dict, Optional
from pydantic import BaseModel, Field
from pydantic.main import IncEx
from typing_extensions import Literal
from crewai.tasks.output_format import OutputFormat
from crewai.tasks.task_output import TaskOutput
@@ -21,16 +23,45 @@ class CrewOutput(BaseModel):
tasks_output: list[TaskOutput] = Field(
description="Output of each task", default=[]
)
token_usage: UsageMetrics = Field(description="Processed token summary", default={})
token_usage: UsageMetrics = Field(description="Processed token summary", default_factory=UsageMetrics)
@property
def json(self) -> Optional[str]:
if self.tasks_output[-1].output_format != OutputFormat.JSON:
def model_json(self) -> str:
"""Get the JSON representation of the output."""
if self.tasks_output and self.tasks_output[-1].output_format != OutputFormat.JSON:
raise ValueError(
"No JSON output found in the final task. Please make sure to set the output_json property in the final task in your crew."
)
return json.dumps(self.json_dict) if self.json_dict else "{}"
return json.dumps(self.json_dict)
def model_dump_json(
self,
*,
indent: Optional[int] = None,
include: Optional[IncEx] = None,
exclude: Optional[IncEx] = None,
context: Optional[Any] = None,
by_alias: bool = False,
exclude_unset: bool = False,
exclude_defaults: bool = False,
exclude_none: bool = False,
round_trip: bool = False,
warnings: bool | Literal["none", "warn", "error"] = False,
serialize_as_any: bool = False,
) -> str:
"""Override model_dump_json to handle custom JSON output."""
return super().model_dump_json(
indent=indent,
include=include,
exclude=exclude,
context=context,
by_alias=by_alias,
exclude_unset=exclude_unset,
exclude_defaults=exclude_defaults,
exclude_none=exclude_none,
round_trip=round_trip,
warnings=warnings,
serialize_as_any=serialize_as_any,
)
def to_dict(self) -> Dict[str, Any]:
"""Convert json_output and pydantic_output to a dictionary."""


@@ -47,7 +47,7 @@ class FastEmbed(BaseEmbedder):
cache_dir=str(cache_dir) if cache_dir else None,
)
def embed_chunks(self, chunks: List[str]) -> List[np.ndarray]:
def embed_chunks(self, chunks: List[str]) -> np.ndarray:
"""
Generate embeddings for a list of text chunks
@@ -55,12 +55,12 @@ class FastEmbed(BaseEmbedder):
chunks: List of text chunks to embed
Returns:
List of embeddings
Array of embeddings
"""
embeddings = list(self.model.embed(chunks))
return embeddings
return np.stack(embeddings)
def embed_texts(self, texts: List[str]) -> List[np.ndarray]:
def embed_texts(self, texts: List[str]) -> np.ndarray:
"""
Generate embeddings for a list of texts
@@ -68,10 +68,10 @@ class FastEmbed(BaseEmbedder):
texts: List of texts to embed
Returns:
List of embeddings
Array of embeddings
"""
embeddings = list(self.model.embed(texts))
return embeddings
return np.stack(embeddings)
def embed_text(self, text: str) -> np.ndarray:
"""


@@ -154,9 +154,15 @@ class KnowledgeStorage(BaseKnowledgeStorage):
filtered_ids.append(doc_id)
# If we have no metadata at all, set it to None
final_metadata: Optional[OneOrMany[chromadb.Metadata]] = (
None if all(m is None for m in filtered_metadata) else filtered_metadata
)
final_metadata: Optional[List[Dict[str, Union[str, int, float, bool]]]] = None
if not all(m is None for m in filtered_metadata):
final_metadata = []
for m in filtered_metadata:
if m is not None:
filtered_m = {k: v for k, v in m.items() if isinstance(v, (str, int, float, bool))}
final_metadata.append(filtered_m)
else:
final_metadata.append({"empty": True})
self.collection.upsert(
documents=filtered_docs,
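The loop above keeps only Chroma-compatible primitive metadata values and substitutes a placeholder for missing entries. The same logic as a standalone helper, shown here purely for illustration (not code from the diff):

```python
from typing import Any, Dict, List, Optional, Union

Primitive = Union[str, int, float, bool]

def sanitize_metadata(
    metadata_list: List[Optional[Dict[str, Any]]],
) -> Optional[List[Dict[str, Primitive]]]:
    """Drop non-primitive values; mark missing metadata with a placeholder."""
    if all(m is None for m in metadata_list):
        return None
    result: List[Dict[str, Primitive]] = []
    for m in metadata_list:
        if m is None:
            result.append({"empty": True})
        else:
            result.append(
                {k: v for k, v in m.items() if isinstance(v, (str, int, float, bool))}
            )
    return result

print(sanitize_metadata([{"a": 1, "b": object()}, None]))
# [{'a': 1}, {'empty': True}]
```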


@@ -6,12 +6,17 @@ import shutil
import uuid
from typing import Any, Dict, List, Optional
from chromadb.api import ClientAPI
from chromadb.api import ClientAPI, Collection
from chromadb.api.types import Documents, Embeddings, Metadatas
from crewai.memory.storage.base_rag_storage import BaseRAGStorage
from crewai.utilities import EmbeddingConfigurator
from crewai.utilities.constants import MAX_FILE_NAME_LENGTH
from crewai.utilities.paths import db_storage_path
from crewai.utilities.exceptions.embedding_exceptions import (
EmbeddingConfigurationError,
EmbeddingInitializationError
)
@contextlib.contextmanager
@@ -32,15 +37,24 @@ def suppress_logging(
class RAGStorage(BaseRAGStorage):
"""
Extends Storage to handle embeddings for memory entries, improving
search efficiency.
"""RAG-based Storage implementation using ChromaDB for vector storage and retrieval.
This class extends BaseRAGStorage to handle embeddings for memory entries,
improving search efficiency through vector similarity.
Attributes:
app: ChromaDB client instance
collection: ChromaDB collection for storing embeddings
type: Type of memory storage
allow_reset: Whether memory reset is allowed
path: Custom storage path for the database
"""
app: ClientAPI | None = None
collection: Any = None
def __init__(
self, type, allow_reset=True, embedder_config=None, crew=None, path=None
self, type: str, allow_reset: bool = True, embedder_config: Dict[str, Any] | None = None, crew: Any = None, path: str | None = None
):
super().__init__(type, allow_reset, embedder_config, crew)
agents = crew.agents if crew else []
@@ -50,7 +64,6 @@ class RAGStorage(BaseRAGStorage):
self.storage_file_name = self._build_storage_file_name(type, agents)
self.type = type
self.allow_reset = allow_reset
self.path = path
self._initialize_app()
@@ -59,26 +72,36 @@ class RAGStorage(BaseRAGStorage):
configurator = EmbeddingConfigurator()
self.embedder_config = configurator.configure_embedder(self.embedder_config)
def _initialize_app(self):
def _initialize_app(self) -> None:
"""Initialize the ChromaDB client and collection.
Raises:
RuntimeError: If ChromaDB client initialization fails
EmbeddingConfigurationError: If embedding configuration is invalid
EmbeddingInitializationError: If embedding function fails to initialize
"""
import chromadb
from chromadb.config import Settings
self._set_embedder_config()
chroma_client = chromadb.PersistentClient(
path=self.path if self.path else self.storage_file_name,
settings=Settings(allow_reset=self.allow_reset),
)
self.app = chroma_client
try:
self.collection = self.app.get_collection(
name=self.type, embedding_function=self.embedder_config
)
except Exception:
self.collection = self.app.create_collection(
name=self.type, embedding_function=self.embedder_config
self.app = chromadb.PersistentClient(
path=self.path if self.path else self.storage_file_name,
settings=Settings(allow_reset=self.allow_reset),
)
if not self.app:
raise RuntimeError("Failed to initialize ChromaDB client")
try:
self.collection = self.app.get_collection(
name=self.type, embedding_function=self.embedder_config
)
except Exception:
self.collection = self.app.create_collection(
name=self.type, embedding_function=self.embedder_config
)
except Exception as e:
raise RuntimeError(f"Failed to initialize ChromaDB: {str(e)}")
def _sanitize_role(self, role: str) -> str:
"""
@@ -101,12 +124,21 @@ class RAGStorage(BaseRAGStorage):
return f"{base_path}/{file_name}"
def save(self, value: Any, metadata: Dict[str, Any]) -> None:
"""Save a value with metadata to the memory storage.
Args:
value: The text content to store
metadata: Additional metadata for the stored content
Raises:
EmbeddingInitializationError: If embedding generation fails
"""
if not hasattr(self, "app") or not hasattr(self, "collection"):
self._initialize_app()
try:
self._generate_embedding(value, metadata)
except Exception as e:
logging.error(f"Error during {self.type} save: {str(e)}")
raise EmbeddingInitializationError(self.type, str(e))
def search(
self,
@@ -114,7 +146,18 @@ class RAGStorage(BaseRAGStorage):
limit: int = 3,
filter: Optional[dict] = None,
score_threshold: float = 0.35,
) -> List[Any]:
) -> List[Dict[str, Any]]:
"""Search for similar content in memory.
Args:
query: The search query text
limit: Maximum number of results to return
filter: Optional filter criteria
score_threshold: Minimum similarity score threshold
Returns:
List of matching results with metadata and scores
"""
if not hasattr(self, "app"):
self._initialize_app()
@@ -138,37 +181,50 @@ class RAGStorage(BaseRAGStorage):
logging.error(f"Error during {self.type} search: {str(e)}")
return []
def _generate_embedding(self, text: str, metadata: Dict[str, Any]) -> None: # type: ignore
def _generate_embedding(self, text: str, metadata: Optional[Dict[str, Any]] = None) -> Any:
"""Generate and store embeddings for the given text.
Args:
text: The text to generate embeddings for
metadata: Optional additional metadata to store with the embeddings
Returns:
Any: The generated embedding or None if only storing
"""
if not hasattr(self, "app") or not hasattr(self, "collection"):
self._initialize_app()
self.collection.add(
documents=[text],
metadatas=[metadata or {}],
ids=[str(uuid.uuid4())],
)
try:
self.collection.add(
documents=[text],
metadatas=[metadata or {}],
ids=[str(uuid.uuid4())],
)
return None
except Exception as e:
raise EmbeddingInitializationError(self.type, f"Failed to generate embedding: {str(e)}")
def reset(self) -> None:
"""Reset the memory storage by clearing the database and removing files.
Raises:
RuntimeError: If memory reset fails and allow_reset is False
EmbeddingConfigurationError: If embedding configuration is invalid during reinitialization
"""
try:
if self.app:
self.app.reset()
shutil.rmtree(f"{db_storage_path()}/{self.type}")
storage_path = self.path if self.path else db_storage_path()
db_dir = os.path.join(storage_path, self.type)
if os.path.exists(db_dir):
shutil.rmtree(db_dir)
self.app = None
self.collection = None
except Exception as e:
if "attempt to write a readonly database" in str(e):
# Ignore this specific error
# Ignore this specific error as it's expected in some environments
pass
else:
raise Exception(
f"An error occurred while resetting the {self.type} memory: {e}"
)
def _create_default_embedding_function(self):
from chromadb.utils.embedding_functions.openai_embedding_function import (
OpenAIEmbeddingFunction,
)
return OpenAIEmbeddingFunction(
api_key=os.getenv("OPENAI_API_KEY"), model_name="text-embedding-3-small"
)
if not self.allow_reset:
raise RuntimeError(f"Failed to reset {self.type} memory: {str(e)}")
logging.error(f"Error during {self.type} memory reset: {str(e)}")


@@ -9,11 +9,13 @@ from copy import copy
from hashlib import md5
from pathlib import Path
from typing import (
AbstractSet,
Any,
Callable,
ClassVar,
Dict,
List,
Mapping,
Optional,
Set,
Tuple,
@@ -109,7 +111,7 @@ class Task(BaseModel):
description="Task output, it's final result after being executed", default=None
)
tools: Optional[List[BaseTool]] = Field(
default_factory=list,
default_factory=lambda: [],
description="Tools the agent is limited to use for this task.",
)
id: UUID4 = Field(
@@ -125,7 +127,7 @@ class Task(BaseModel):
description="A converter class used to export structured output",
default=None,
)
processed_by_agents: Set[str] = Field(default_factory=set)
processed_by_agents: Set[str] = Field(default_factory=lambda: set())
guardrail: Optional[Callable[[TaskOutput], Tuple[bool, Any]]] = Field(
default=None,
description="Function to validate task output before proceeding to next task",
@@ -606,37 +608,56 @@ class Task(BaseModel):
self.delegations += 1
def copy(
self,
*,
include: Optional[AbstractSet[int] | AbstractSet[str] | Mapping[int, Any] | Mapping[str, Any]] = None,
exclude: Optional[AbstractSet[int] | AbstractSet[str] | Mapping[int, Any] | Mapping[str, Any]] = None,
update: Optional[Dict[str, Any]] = None,
deep: bool = False,
) -> "Task":
"""Create a copy of the Task."""
exclude_set = {"id", "agent", "context", "tools"}
if exclude:
if isinstance(exclude, (AbstractSet, set)):
exclude_set.update(str(x) for x in exclude)
elif isinstance(exclude, Mapping):
exclude_set.update(str(x) for x in exclude.keys())
copied_task = super().copy(
include=include,
exclude=exclude_set,
update=update,
deep=deep,
)
copied_task.id = uuid.uuid4()
copied_task.agent = None
copied_task.context = None
copied_task.tools = []
return copied_task
def copy_with_agents(
self, agents: List["BaseAgent"], task_mapping: Dict[str, "Task"]
) -> "Task":
"""Create a deep copy of the Task."""
exclude = {
"id",
"agent",
"context",
"tools",
}
copied_data = self.model_dump(exclude=exclude)
copied_data = {k: v for k, v in copied_data.items() if v is not None}
cloned_context = (
[task_mapping[context_task.key] for context_task in self.context]
if self.context
else None
)
"""Create a copy of the Task with agent references."""
copied_task = self.copy()
def get_agent_by_role(role: str) -> Union["BaseAgent", None]:
return next((agent for agent in agents if agent.role == role), None)
cloned_agent = get_agent_by_role(self.agent.role) if self.agent else None
cloned_tools = copy(self.tools) if self.tools else []
if self.agent:
copied_task.agent = get_agent_by_role(self.agent.role)
copied_task = Task(
**copied_data,
context=cloned_context,
agent=cloned_agent,
tools=cloned_tools,
)
if self.context:
copied_task.context = [
task_mapping[context_task.key]
for context_task in self.context
if context_task.key in task_mapping
]
if self.tools:
copied_task.tools = copy(self.tools)
return copied_task
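The net effect: `copy` now matches pydantic's keyword-only signature while still resetting identity fields, and `copy_with_agents` rebuilds agent, context, and tool references on top of it. A short sketch of the expected behavior (field values are illustrative):

```python
from crewai import Task

original = Task(
    description="Summarize the findings",
    expected_output="A concise summary",
)

clone = original.copy()

assert clone.id != original.id  # fresh UUID
assert clone.agent is None      # agent reference cleared
assert clone.context is None    # context cleared
assert clone.tools == []        # tools reset
```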


@@ -1,7 +1,9 @@
import json
from typing import Any, Dict, Optional
from typing import Any, Callable, Dict, Optional
from pydantic import BaseModel, Field, model_validator
from pydantic.main import IncEx
from typing_extensions import Literal
from crewai.tasks.output_format import OutputFormat
@@ -34,8 +36,8 @@ class TaskOutput(BaseModel):
self.summary = f"{excerpt}..."
return self
@property
def json(self) -> Optional[str]:
def model_json(self) -> str:
"""Get the JSON representation of the output."""
if self.output_format != OutputFormat.JSON:
raise ValueError(
"""
@@ -44,8 +46,37 @@ class TaskOutput(BaseModel):
please make sure to set the output_json property for the task
"""
)
return json.dumps(self.json_dict) if self.json_dict else "{}"
return json.dumps(self.json_dict)
def model_dump_json(
self,
*,
indent: Optional[int] = None,
include: Optional[IncEx] = None,
exclude: Optional[IncEx] = None,
context: Optional[Any] = None,
by_alias: bool = False,
exclude_unset: bool = False,
exclude_defaults: bool = False,
exclude_none: bool = False,
round_trip: bool = False,
warnings: bool | Literal["none", "warn", "error"] = False,
serialize_as_any: bool = False,
) -> str:
"""Override model_dump_json to handle custom JSON output."""
return super().model_dump_json(
indent=indent,
include=include,
exclude=exclude,
context=context,
by_alias=by_alias,
exclude_unset=exclude_unset,
exclude_defaults=exclude_defaults,
exclude_none=exclude_none,
round_trip=round_trip,
warnings=warnings,
serialize_as_any=serialize_as_any,
)
def to_dict(self) -> Dict[str, Any]:
"""Convert json_output and pydantic_output to a dictionary."""


@@ -82,12 +82,12 @@ class BaseAgentTool(BaseTool):
available_agents = [agent.role for agent in self.agents]
logger.debug(f"Available agents: {available_agents}")
agent = [ # type: ignore # Incompatible types in assignment (expression has type "list[BaseAgent]", variable has type "str | None")
matching_agents = [
available_agent
for available_agent in self.agents
if self.sanitize_agent_name(available_agent.role) == sanitized_name
]
logger.debug(f"Found {len(agent)} matching agents for role '{sanitized_name}'")
logger.debug(f"Found {len(matching_agents)} matching agents for role '{sanitized_name}'")
except (AttributeError, ValueError) as e:
# Handle specific exceptions that might occur during role name processing
return self.i18n.errors("agent_tool_unexisting_coworker").format(
@@ -97,7 +97,7 @@ class BaseAgentTool(BaseTool):
error=str(e)
)
if not agent:
if not matching_agents:
# No matching agent found after sanitization
return self.i18n.errors("agent_tool_unexisting_coworker").format(
coworkers="\n".join(
@@ -106,19 +106,19 @@ class BaseAgentTool(BaseTool):
error=f"No agent found with role '{sanitized_name}'"
)
agent = agent[0]
selected_agent = matching_agents[0]
try:
task_with_assigned_agent = Task(
description=task,
agent=agent,
expected_output=agent.i18n.slice("manager_request"),
i18n=agent.i18n,
agent=selected_agent,
expected_output=selected_agent.i18n.slice("manager_request"),
i18n=selected_agent.i18n,
)
logger.debug(f"Created task for agent '{self.sanitize_agent_name(agent.role)}': {task}")
return agent.execute_task(task_with_assigned_agent, context)
logger.debug(f"Created task for agent '{self.sanitize_agent_name(selected_agent.role)}': {task}")
return selected_agent.execute_task(task_with_assigned_agent, context)
except Exception as e:
# Handle task creation or execution errors
return self.i18n.errors("agent_tool_execution_error").format(
agent_role=self.sanitize_agent_name(agent.role),
agent_role=self.sanitize_agent_name(selected_agent.role),
error=str(e)
)


@@ -1,40 +1,36 @@
import warnings
from abc import ABC, abstractmethod
from inspect import signature
from typing import Any, Callable, Type, get_args, get_origin
from typing import Any, Callable, Dict, Optional, Type, Tuple, get_args, get_origin
from pydantic import (
BaseModel,
ConfigDict,
Field,
PydanticDeprecatedSince20,
create_model,
validator,
)
from pydantic import BaseModel, ConfigDict, Field, create_model, validator
from pydantic.fields import FieldInfo
from pydantic import BaseModel as PydanticBaseModel
from crewai.tools.structured_tool import CrewStructuredTool
# Ignore all "PydanticDeprecatedSince20" warnings globally
warnings.filterwarnings("ignore", category=PydanticDeprecatedSince20)
def _create_model_fields(fields: Dict[str, Tuple[Any, FieldInfo]]) -> Dict[str, Any]:
"""Helper function to create model fields with proper type hints."""
return {name: (annotation, field) for name, (annotation, field) in fields.items()}
class BaseTool(BaseModel, ABC):
"""Base class for all tools."""
class _ArgsSchemaPlaceholder(PydanticBaseModel):
pass
model_config = ConfigDict()
model_config = ConfigDict(arbitrary_types_allowed=True)
func: Optional[Callable] = None
name: str
"""The unique name of the tool that clearly communicates its purpose."""
description: str
"""Used to tell the model how/when/why to use the tool."""
args_schema: Type[PydanticBaseModel] = Field(default_factory=_ArgsSchemaPlaceholder)
args_schema: Type[PydanticBaseModel] = Field(default=_ArgsSchemaPlaceholder)
"""The schema for the arguments that the tool accepts."""
description_updated: bool = False
"""Flag to check if the description has been updated."""
cache_function: Callable = lambda _args=None, _result=None: True
"""Function that will be used to determine if the tool should be cached, should return a boolean. If None, the tool will be cached."""
"""Function that will be used to determine if the tool should be cached."""
result_as_answer: bool = False
"""Flag to check if the tool should be the final agent answer."""
@@ -57,7 +53,6 @@ class BaseTool(BaseModel, ABC):
def model_post_init(self, __context: Any) -> None:
self._generate_description()
super().model_post_init(__context)
def run(
@@ -87,50 +82,7 @@ class BaseTool(BaseModel, ABC):
result_as_answer=self.result_as_answer,
)
@classmethod
def from_langchain(cls, tool: Any) -> "BaseTool":
"""Create a Tool instance from a CrewStructuredTool.
This method takes a CrewStructuredTool object and converts it into a
Tool instance. It ensures that the provided tool has a callable 'func'
attribute and infers the argument schema if not explicitly provided.
"""
if not hasattr(tool, "func") or not callable(tool.func):
raise ValueError("The provided tool must have a callable 'func' attribute.")
args_schema = getattr(tool, "args_schema", None)
if args_schema is None:
# Infer args_schema from the function signature if not provided
func_signature = signature(tool.func)
annotations = func_signature.parameters
args_fields = {}
for name, param in annotations.items():
if name != "self":
param_annotation = (
param.annotation if param.annotation != param.empty else Any
)
field_info = Field(
default=...,
description="",
)
args_fields[name] = (param_annotation, field_info)
if args_fields:
args_schema = create_model(f"{tool.name}Input", **args_fields)
else:
# Create a default schema with no fields if no parameters are found
args_schema = create_model(
f"{tool.name}Input", __base__=PydanticBaseModel
)
return cls(
name=getattr(tool, "name", "Unnamed Tool"),
description=getattr(tool, "description", ""),
func=tool.func,
args_schema=args_schema,
)
def _set_args_schema(self):
def _set_args_schema(self) -> None:
if self.args_schema is None:
class_name = f"{self.__class__.__name__}Schema"
self.args_schema = type(
@@ -145,7 +97,7 @@ class BaseTool(BaseModel, ABC):
},
)
def _generate_description(self):
def _generate_description(self) -> None:
args_schema = {
name: {
"description": field.description,
@@ -179,79 +131,25 @@ class BaseTool(BaseModel, ABC):
class Tool(BaseTool):
"""The function that will be executed when the tool is called."""
"""Tool class that wraps a function."""
func: Callable
model_config = ConfigDict(arbitrary_types_allowed=True)
def __init__(self, **kwargs):
if "func" not in kwargs:
raise ValueError("Tool requires a 'func' argument")
super().__init__(**kwargs)
def _run(self, *args: Any, **kwargs: Any) -> Any:
return self.func(*args, **kwargs)
@classmethod
def from_langchain(cls, tool: Any) -> "Tool":
"""Create a Tool instance from a CrewStructuredTool.
This method takes a CrewStructuredTool object and converts it into a
Tool instance. It ensures that the provided tool has a callable 'func'
attribute and infers the argument schema if not explicitly provided.
Args:
tool (Any): The CrewStructuredTool object to be converted.
Returns:
Tool: A new Tool instance created from the provided CrewStructuredTool.
Raises:
ValueError: If the provided tool does not have a callable 'func' attribute.
"""
if not hasattr(tool, "func") or not callable(tool.func):
raise ValueError("The provided tool must have a callable 'func' attribute.")
args_schema = getattr(tool, "args_schema", None)
if args_schema is None:
# Infer args_schema from the function signature if not provided
func_signature = signature(tool.func)
annotations = func_signature.parameters
args_fields = {}
for name, param in annotations.items():
if name != "self":
param_annotation = (
param.annotation if param.annotation != param.empty else Any
)
field_info = Field(
default=...,
description="",
)
args_fields[name] = (param_annotation, field_info)
if args_fields:
args_schema = create_model(f"{tool.name}Input", **args_fields)
else:
# Create a default schema with no fields if no parameters are found
args_schema = create_model(
f"{tool.name}Input", __base__=PydanticBaseModel
)
return cls(
name=getattr(tool, "name", "Unnamed Tool"),
description=getattr(tool, "description", ""),
func=tool.func,
args_schema=args_schema,
)
def to_langchain(
tools: list[BaseTool | CrewStructuredTool],
) -> list[CrewStructuredTool]:
return [t.to_structured_tool() if isinstance(t, BaseTool) else t for t in tools]
def tool(*args):
"""
Decorator to create a tool from a function.
"""
def tool(*args: Any) -> Any:
"""Decorator to create a tool from a function."""
def _make_with_name(tool_name: str) -> Callable:
def _make_tool(f: Callable) -> BaseTool:
def _make_tool(f: Callable) -> Tool:
if f.__doc__ is None:
raise ValueError("Function must have a docstring")
if f.__annotations__ is None:


@@ -2,9 +2,14 @@ from __future__ import annotations
import inspect
import textwrap
from typing import Any, Callable, Optional, Union, get_type_hints
from typing import Any, Callable, Dict, Optional, Tuple, Union, get_type_hints
from pydantic import BaseModel, Field, create_model
from pydantic import BaseModel, ConfigDict, Field, create_model
from pydantic.fields import FieldInfo
def _create_model_fields(fields: Dict[str, Tuple[Any, FieldInfo]]) -> Dict[str, Any]:
"""Helper function to create model fields with proper type hints."""
return {name: (annotation, field) for name, (annotation, field) in fields.items()}
from crewai.utilities.logger import Logger
@@ -142,7 +147,8 @@ class CrewStructuredTool:
# Create model
schema_name = f"{name.title()}Schema"
return create_model(schema_name, **fields)
model_fields = _create_model_fields(fields)
return create_model(schema_name, __base__=BaseModel, **model_fields)
def _validate_function_signature(self) -> None:
"""Validate that the function signature matches the args schema."""


@@ -4,3 +4,7 @@ DEFAULT_SCORE_THRESHOLD = 0.35
KNOWLEDGE_DIRECTORY = "knowledge"
MAX_LLM_RETRY = 3
MAX_FILE_NAME_LENGTH = 255
# Default embedding configuration
DEFAULT_EMBEDDING_PROVIDER = "openai"
DEFAULT_EMBEDDING_MODEL = "text-embedding-3-small"


@@ -1,9 +1,15 @@
import os
from typing import Any, Dict, Optional, cast
from typing import Any, Dict, List, Optional, cast
from chromadb import Documents, EmbeddingFunction, Embeddings
from chromadb.api.types import validate_embedding_function
from crewai.utilities.exceptions.embedding_exceptions import (
EmbeddingConfigurationError,
EmbeddingProviderError,
EmbeddingInitializationError
)
class EmbeddingConfigurator:
def __init__(self):
@@ -14,11 +20,9 @@ class EmbeddingConfigurator:
"vertexai": self._configure_vertexai,
"google": self._configure_google,
"cohere": self._configure_cohere,
"voyageai": self._configure_voyageai,
"bedrock": self._configure_bedrock,
"huggingface": self._configure_huggingface,
"watson": self._configure_watson,
"custom": self._configure_custom,
}
def configure_embedder(
@@ -31,156 +35,119 @@ class EmbeddingConfigurator:
provider = embedder_config.get("provider")
config = embedder_config.get("config", {})
model_name = config.get("model") if provider != "custom" else None
model_name = config.get("model")
if provider not in self.embedding_functions:
raise Exception(
f"Unsupported embedding provider: {provider}, supported providers: {list(self.embedding_functions.keys())}"
)
if isinstance(provider, EmbeddingFunction):
try:
validate_embedding_function(provider)
return provider
except Exception as e:
raise EmbeddingConfigurationError(f"Invalid custom embedding function: {str(e)}")
embedding_function = self.embedding_functions[provider]
return (
embedding_function(config)
if provider == "custom"
else embedding_function(config, model_name)
)
if not provider or provider not in self.embedding_functions:
raise EmbeddingProviderError(str(provider), list(self.embedding_functions.keys()))
try:
return self.embedding_functions[str(provider)](config, model_name)
except Exception as e:
raise EmbeddingInitializationError(str(provider), str(e))
@staticmethod
def _create_default_embedding_function():
from chromadb.utils.embedding_functions.openai_embedding_function import (
OpenAIEmbeddingFunction,
)
return OpenAIEmbeddingFunction(
api_key=os.getenv("OPENAI_API_KEY"), model_name="text-embedding-3-small"
)
def _create_default_embedding_function() -> EmbeddingFunction:
from crewai.utilities.constants import DEFAULT_EMBEDDING_PROVIDER, DEFAULT_EMBEDDING_MODEL
provider = os.getenv("CREWAI_EMBEDDING_PROVIDER", DEFAULT_EMBEDDING_PROVIDER)
model = os.getenv("CREWAI_EMBEDDING_MODEL", DEFAULT_EMBEDDING_MODEL)
if provider == "openai":
api_key = os.getenv("OPENAI_API_KEY")
if not api_key:
raise EmbeddingConfigurationError("OpenAI API key is required but not provided")
from chromadb.utils.embedding_functions.openai_embedding_function import OpenAIEmbeddingFunction
return OpenAIEmbeddingFunction(api_key=api_key, model_name=model)
elif provider == "ollama":
from chromadb.utils.embedding_functions.ollama_embedding_function import OllamaEmbeddingFunction
url = os.getenv("CREWAI_OLLAMA_URL", "http://localhost:11434/api/embeddings")
return OllamaEmbeddingFunction(url=url, model_name=model)
else:
raise EmbeddingProviderError(provider, ["openai", "ollama"])
@staticmethod
def _configure_openai(config, model_name):
from chromadb.utils.embedding_functions.openai_embedding_function import (
OpenAIEmbeddingFunction,
)
def _configure_openai(config: Dict[str, Any], model_name: str) -> EmbeddingFunction:
from chromadb.utils.embedding_functions.openai_embedding_function import OpenAIEmbeddingFunction
return OpenAIEmbeddingFunction(
api_key=config.get("api_key") or os.getenv("OPENAI_API_KEY"),
model_name=model_name,
api_base=config.get("api_base", None),
api_type=config.get("api_type", None),
api_version=config.get("api_version", None),
default_headers=config.get("default_headers", None),
dimensions=config.get("dimensions", None),
deployment_id=config.get("deployment_id", None),
organization_id=config.get("organization_id", None),
)
@staticmethod
def _configure_azure(config, model_name):
from chromadb.utils.embedding_functions.openai_embedding_function import (
OpenAIEmbeddingFunction,
)
def _configure_azure(config: Dict[str, Any], model_name: str) -> EmbeddingFunction:
from chromadb.utils.embedding_functions.openai_embedding_function import OpenAIEmbeddingFunction
return OpenAIEmbeddingFunction(
api_key=config.get("api_key"),
api_base=config.get("api_base"),
api_type=config.get("api_type", "azure"),
api_version=config.get("api_version"),
model_name=model_name,
default_headers=config.get("default_headers"),
dimensions=config.get("dimensions"),
deployment_id=config.get("deployment_id"),
organization_id=config.get("organization_id"),
)
@staticmethod
def _configure_ollama(config, model_name):
from chromadb.utils.embedding_functions.ollama_embedding_function import (
OllamaEmbeddingFunction,
)
def _configure_ollama(config: Dict[str, Any], model_name: str) -> EmbeddingFunction:
from chromadb.utils.embedding_functions.ollama_embedding_function import OllamaEmbeddingFunction
return OllamaEmbeddingFunction(
url=config.get("url", "http://localhost:11434/api/embeddings"),
model_name=model_name,
)
@staticmethod
def _configure_vertexai(config, model_name):
from chromadb.utils.embedding_functions.google_embedding_function import (
GoogleVertexEmbeddingFunction,
)
def _configure_vertexai(config: Dict[str, Any], model_name: str) -> EmbeddingFunction:
from chromadb.utils.embedding_functions.google_embedding_function import GoogleVertexEmbeddingFunction
return GoogleVertexEmbeddingFunction(
model_name=model_name,
api_key=config.get("api_key"),
project_id=config.get("project_id"),
region=config.get("region"),
)
@staticmethod
def _configure_google(config, model_name):
from chromadb.utils.embedding_functions.google_embedding_function import (
GoogleGenerativeAiEmbeddingFunction,
)
def _configure_google(config: Dict[str, Any], model_name: str) -> EmbeddingFunction:
from chromadb.utils.embedding_functions.google_embedding_function import GoogleGenerativeAiEmbeddingFunction
return GoogleGenerativeAiEmbeddingFunction(
model_name=model_name,
api_key=config.get("api_key"),
task_type=config.get("task_type"),
)
@staticmethod
def _configure_cohere(config, model_name):
from chromadb.utils.embedding_functions.cohere_embedding_function import (
CohereEmbeddingFunction,
)
def _configure_cohere(config: Dict[str, Any], model_name: str) -> EmbeddingFunction:
from chromadb.utils.embedding_functions.cohere_embedding_function import CohereEmbeddingFunction
return CohereEmbeddingFunction(
model_name=model_name,
api_key=config.get("api_key"),
)
@staticmethod
def _configure_voyageai(config, model_name):
from chromadb.utils.embedding_functions.voyageai_embedding_function import (
VoyageAIEmbeddingFunction,
)
return VoyageAIEmbeddingFunction(
model_name=model_name,
api_key=config.get("api_key"),
def _configure_bedrock(config: Dict[str, Any], model_name: str) -> EmbeddingFunction:
from chromadb.utils.embedding_functions.amazon_bedrock_embedding_function import AmazonBedrockEmbeddingFunction
return AmazonBedrockEmbeddingFunction(
session=config.get("session"),
)
@staticmethod
def _configure_bedrock(config, model_name):
from chromadb.utils.embedding_functions.amazon_bedrock_embedding_function import (
AmazonBedrockEmbeddingFunction,
)
# Allow custom model_name override with backwards compatibility
kwargs = {"session": config.get("session")}
if model_name is not None:
kwargs["model_name"] = model_name
return AmazonBedrockEmbeddingFunction(**kwargs)
@staticmethod
def _configure_huggingface(config, model_name):
from chromadb.utils.embedding_functions.huggingface_embedding_function import (
HuggingFaceEmbeddingServer,
)
def _configure_huggingface(config: Dict[str, Any], model_name: str) -> EmbeddingFunction:
from chromadb.utils.embedding_functions.huggingface_embedding_function import HuggingFaceEmbeddingServer
return HuggingFaceEmbeddingServer(
url=config.get("api_url"),
)
@staticmethod
def _configure_watson(config, model_name):
def _configure_watson(config: Dict[str, Any], model_name: str) -> EmbeddingFunction:
try:
import ibm_watsonx_ai.foundation_models as watson_models
from ibm_watsonx_ai import Credentials
from ibm_watsonx_ai.metanames import EmbedTextParamsMetaNames as EmbedParams
except ImportError as e:
raise ImportError(
"IBM Watson dependencies are not installed. Please install them to use Watson embedding."
) from e
raise EmbeddingConfigurationError(
"IBM Watson dependencies are not installed. Please install them to use Watson embedding.",
provider="watson"
)
class WatsonEmbeddingFunction(EmbeddingFunction):
def __call__(self, input: Documents) -> Embeddings:
@@ -205,32 +172,6 @@ class EmbeddingConfigurator:
embeddings = embedding.embed_documents(input)
return cast(Embeddings, embeddings)
except Exception as e:
print("Error during Watson embedding:", e)
raise e
raise EmbeddingInitializationError("watson", str(e))
return WatsonEmbeddingFunction()
@staticmethod
def _configure_custom(config):
custom_embedder = config.get("embedder")
if isinstance(custom_embedder, EmbeddingFunction):
try:
validate_embedding_function(custom_embedder)
return custom_embedder
except Exception as e:
raise ValueError(f"Invalid custom embedding function: {str(e)}")
elif callable(custom_embedder):
try:
instance = custom_embedder()
if isinstance(instance, EmbeddingFunction):
validate_embedding_function(instance)
return instance
raise ValueError(
"Custom embedder does not create an EmbeddingFunction instance"
)
except Exception as e:
raise ValueError(f"Error instantiating custom embedder: {str(e)}")
else:
raise ValueError(
"Custom embedder must be an instance of `EmbeddingFunction` or a callable that creates one"
)
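Under the new defaults, the provider is resolved from `CREWAI_EMBEDDING_PROVIDER` when no explicit configuration is supplied. A hedged sketch of that selection (it assumes `configure_embedder` falls back to the environment-driven default for a missing config, as the tests suggest, and that an Ollama instance is reachable):

```python
import os

from crewai.utilities import EmbeddingConfigurator

os.environ["CREWAI_EMBEDDING_PROVIDER"] = "ollama"
os.environ["CREWAI_EMBEDDING_MODEL"] = "llama2"

# Assumption: with no explicit config, the configurator builds the
# default embedding function from the environment variables above
embedder = EmbeddingConfigurator().configure_embedder(None)
```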


@@ -0,0 +1,20 @@
from typing import List, Optional
class EmbeddingConfigurationError(Exception):
def __init__(self, message: str, provider: Optional[str] = None):
self.message = message
self.provider = provider
super().__init__(self.message)
class EmbeddingProviderError(EmbeddingConfigurationError):
def __init__(self, provider: str, supported_providers: List[str]):
message = f"Unsupported embedding provider: {provider}, supported providers: {supported_providers}"
super().__init__(message, provider)
class EmbeddingInitializationError(EmbeddingConfigurationError):
def __init__(self, provider: str, error: str):
message = f"Failed to initialize embedding function for provider {provider}: {error}"
super().__init__(message, provider)
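These classes let callers tell a bad provider name apart from a failed initialization. A hedged usage sketch:

```python
from crewai.utilities import EmbeddingConfigurator
from crewai.utilities.exceptions.embedding_exceptions import (
    EmbeddingInitializationError,
    EmbeddingProviderError,
)

try:
    embedder = EmbeddingConfigurator().configure_embedder(
        {"provider": "not-a-provider", "config": {}}
    )
except EmbeddingProviderError as e:
    print(f"Unknown provider: {e.provider}")
except EmbeddingInitializationError as e:
    print(f"Provider {e.provider} failed to initialize: {e.message}")
```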


@@ -1,37 +1,30 @@
# conftest.py
import os
import tempfile
from pathlib import Path
import pytest
from dotenv import load_dotenv
load_result = load_dotenv(override=True)
@pytest.fixture(autouse=True)
def setup_test_environment():
"""Set up test environment with a temporary directory for SQLite storage."""
with tempfile.TemporaryDirectory() as temp_dir:
# Create the directory with proper permissions
storage_dir = Path(temp_dir) / "crewai_test_storage"
storage_dir.mkdir(parents=True, exist_ok=True)
# Validate that the directory was created successfully
if not storage_dir.exists() or not storage_dir.is_dir():
raise RuntimeError(f"Failed to create test storage directory: {storage_dir}")
# Verify directory permissions
try:
# Try to create a test file to verify write permissions
test_file = storage_dir / ".permissions_test"
test_file.touch()
test_file.unlink()
except (OSError, IOError) as e:
raise RuntimeError(f"Test storage directory {storage_dir} is not writable: {e}")
# Set environment variable to point to the test storage directory
os.environ["CREWAI_STORAGE_DIR"] = str(storage_dir)
yield
# Cleanup is handled automatically when tempfile context exits
def setup_test_env():
"""Configure test environment to use Ollama as the default embedding provider."""
# Store original environment variables
original_env = {
"CREWAI_EMBEDDING_PROVIDER": os.environ.get("CREWAI_EMBEDDING_PROVIDER"),
"CREWAI_EMBEDDING_MODEL": os.environ.get("CREWAI_EMBEDDING_MODEL"),
"CREWAI_OLLAMA_URL": os.environ.get("CREWAI_OLLAMA_URL"),
}
# Set test environment
os.environ["CREWAI_EMBEDDING_PROVIDER"] = "ollama"
os.environ["CREWAI_EMBEDDING_MODEL"] = "llama2"
os.environ["CREWAI_OLLAMA_URL"] = "http://localhost:11434/api/embeddings"
yield
# Restore original environment
for key, value in original_env.items():
if value is None:
os.environ.pop(key, None)
else:
os.environ[key] = value


@@ -0,0 +1,91 @@
import os
import tempfile
import pytest
from crewai.memory import ShortTermMemory, LongTermMemory, EntityMemory
from crewai.utilities.exceptions.embedding_exceptions import (
EmbeddingConfigurationError,
EmbeddingProviderError
)
from crewai.utilities import EmbeddingConfigurator
@pytest.fixture
def temp_db_dir():
with tempfile.TemporaryDirectory() as tmpdir:
yield tmpdir
def test_memory_reset_with_ollama(temp_db_dir):
os.environ["CREWAI_EMBEDDING_PROVIDER"] = "ollama"
os.environ["CREWAI_EMBEDDING_MODEL"] = "llama2"
memories = [
ShortTermMemory(path=temp_db_dir),
LongTermMemory(path=temp_db_dir),
EntityMemory(path=temp_db_dir)
]
for memory in memories:
memory.reset()
def test_memory_reset_with_openai(temp_db_dir):
os.environ["CREWAI_EMBEDDING_PROVIDER"] = "openai"
os.environ["CREWAI_EMBEDDING_MODEL"] = "text-embedding-3-small"
memories = [
ShortTermMemory(path=temp_db_dir),
LongTermMemory(path=temp_db_dir),
EntityMemory(path=temp_db_dir)
]
for memory in memories:
memory.reset()
def test_memory_reset_with_invalid_provider(temp_db_dir):
os.environ["CREWAI_EMBEDDING_PROVIDER"] = "invalid_provider"
with pytest.raises(EmbeddingProviderError):
memories = [
ShortTermMemory(path=temp_db_dir),
LongTermMemory(path=temp_db_dir),
EntityMemory(path=temp_db_dir)
]
for memory in memories:
memory.reset()
def test_memory_reset_with_invalid_configuration(temp_db_dir):
os.environ["CREWAI_EMBEDDING_PROVIDER"] = "openai"
os.environ.pop("OPENAI_API_KEY", None)
with pytest.raises(EmbeddingConfigurationError):
memories = [
ShortTermMemory(path=temp_db_dir),
LongTermMemory(path=temp_db_dir),
EntityMemory(path=temp_db_dir)
]
for memory in memories:
memory.reset()
def test_memory_reset_with_missing_ollama_url(temp_db_dir):
os.environ["CREWAI_EMBEDDING_PROVIDER"] = "ollama"
os.environ.pop("CREWAI_OLLAMA_URL", None)
# Should use default URL when CREWAI_OLLAMA_URL is not set
memories = [
ShortTermMemory(path=temp_db_dir),
LongTermMemory(path=temp_db_dir),
EntityMemory(path=temp_db_dir)
]
for memory in memories:
memory.reset()
def test_memory_reset_with_custom_path(temp_db_dir):
os.environ["CREWAI_EMBEDDING_PROVIDER"] = "ollama"
custom_path = os.path.join(temp_db_dir, "custom")
os.makedirs(custom_path, exist_ok=True)
memories = [
ShortTermMemory(path=custom_path),
LongTermMemory(path=custom_path),
EntityMemory(path=custom_path)
]
for memory in memories:
memory.reset()
assert not os.path.exists(os.path.join(custom_path, "short_term"))
assert not os.path.exists(os.path.join(custom_path, "long_term"))
assert not os.path.exists(os.path.join(custom_path, "entity"))