mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-29 18:18:13 +00:00
Compare commits
5 Commits
devin/1768
...
devin/1745
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
2b7b2993ca | ||
|
|
86155b4590 | ||
|
|
6435a419d7 | ||
|
|
9f57e266b8 | ||
|
|
8dd779371d |
@@ -21,7 +21,6 @@ dependencies = [
|
|||||||
"opentelemetry-sdk>=1.30.0",
|
"opentelemetry-sdk>=1.30.0",
|
||||||
"opentelemetry-exporter-otlp-proto-http>=1.30.0",
|
"opentelemetry-exporter-otlp-proto-http>=1.30.0",
|
||||||
# Data Handling
|
# Data Handling
|
||||||
"chromadb>=0.5.23",
|
|
||||||
"openpyxl>=3.1.5",
|
"openpyxl>=3.1.5",
|
||||||
"pyvis>=0.3.2",
|
"pyvis>=0.3.2",
|
||||||
# Authentication and Security
|
# Authentication and Security
|
||||||
@@ -67,6 +66,9 @@ docling = [
|
|||||||
aisuite = [
|
aisuite = [
|
||||||
"aisuite>=0.1.10",
|
"aisuite>=0.1.10",
|
||||||
]
|
]
|
||||||
|
chromadb = [
|
||||||
|
"chromadb>=0.5.23",
|
||||||
|
]
|
||||||
|
|
||||||
[tool.uv]
|
[tool.uv]
|
||||||
dev-dependencies = [
|
dev-dependencies = [
|
||||||
|
|||||||
@@ -4,13 +4,29 @@ import io
|
|||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import shutil
|
import shutil
|
||||||
from typing import Any, Dict, List, Optional, Union
|
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
|
||||||
|
|
||||||
import chromadb
|
if TYPE_CHECKING:
|
||||||
import chromadb.errors
|
import chromadb
|
||||||
from chromadb.api import ClientAPI
|
import chromadb.errors
|
||||||
from chromadb.api.types import OneOrMany
|
from chromadb.api import ClientAPI
|
||||||
from chromadb.config import Settings
|
from chromadb.api.types import OneOrMany
|
||||||
|
from chromadb.config import Settings
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
import chromadb
|
||||||
|
import chromadb.errors
|
||||||
|
from chromadb.api import ClientAPI
|
||||||
|
from chromadb.api.types import OneOrMany
|
||||||
|
from chromadb.config import Settings
|
||||||
|
except ImportError:
|
||||||
|
chromadb = None
|
||||||
|
ClientAPI = None
|
||||||
|
OneOrMany = None
|
||||||
|
Settings = None
|
||||||
|
raise ImportError(
|
||||||
|
"ChromaDB is not installed. Please install it with `pip install crewai[chromadb]`."
|
||||||
|
)
|
||||||
|
|
||||||
from crewai.knowledge.storage.base_knowledge_storage import BaseKnowledgeStorage
|
from crewai.knowledge.storage.base_knowledge_storage import BaseKnowledgeStorage
|
||||||
from crewai.utilities import EmbeddingConfigurator
|
from crewai.utilities import EmbeddingConfigurator
|
||||||
@@ -43,9 +59,9 @@ class KnowledgeStorage(BaseKnowledgeStorage):
|
|||||||
search efficiency.
|
search efficiency.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
collection: Optional[chromadb.Collection] = None
|
collection: Optional[Any] = None
|
||||||
collection_name: Optional[str] = "knowledge"
|
collection_name: Optional[str] = "knowledge"
|
||||||
app: Optional[ClientAPI] = None
|
app: Optional[Any] = None
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@@ -62,6 +78,11 @@ class KnowledgeStorage(BaseKnowledgeStorage):
|
|||||||
filter: Optional[dict] = None,
|
filter: Optional[dict] = None,
|
||||||
score_threshold: float = 0.35,
|
score_threshold: float = 0.35,
|
||||||
) -> List[Dict[str, Any]]:
|
) -> List[Dict[str, Any]]:
|
||||||
|
if not chromadb:
|
||||||
|
raise ImportError(
|
||||||
|
"ChromaDB is not installed. Please install it with `pip install crewai[chromadb]`."
|
||||||
|
)
|
||||||
|
|
||||||
with suppress_logging():
|
with suppress_logging():
|
||||||
if self.collection:
|
if self.collection:
|
||||||
fetched = self.collection.query(
|
fetched = self.collection.query(
|
||||||
@@ -84,6 +105,11 @@ class KnowledgeStorage(BaseKnowledgeStorage):
|
|||||||
raise Exception("Collection not initialized")
|
raise Exception("Collection not initialized")
|
||||||
|
|
||||||
def initialize_knowledge_storage(self):
|
def initialize_knowledge_storage(self):
|
||||||
|
if not chromadb:
|
||||||
|
raise ImportError(
|
||||||
|
"ChromaDB is not installed. Please install it with `pip install crewai[chromadb]`."
|
||||||
|
)
|
||||||
|
|
||||||
base_path = os.path.join(db_storage_path(), "knowledge")
|
base_path = os.path.join(db_storage_path(), "knowledge")
|
||||||
chroma_client = chromadb.PersistentClient(
|
chroma_client = chromadb.PersistentClient(
|
||||||
path=base_path,
|
path=base_path,
|
||||||
@@ -109,6 +135,11 @@ class KnowledgeStorage(BaseKnowledgeStorage):
|
|||||||
raise Exception("Failed to create or get collection")
|
raise Exception("Failed to create or get collection")
|
||||||
|
|
||||||
def reset(self):
|
def reset(self):
|
||||||
|
if not chromadb:
|
||||||
|
raise ImportError(
|
||||||
|
"ChromaDB is not installed. Please install it with `pip install crewai[chromadb]`."
|
||||||
|
)
|
||||||
|
|
||||||
base_path = os.path.join(db_storage_path(), KNOWLEDGE_DIRECTORY)
|
base_path = os.path.join(db_storage_path(), KNOWLEDGE_DIRECTORY)
|
||||||
if not self.app:
|
if not self.app:
|
||||||
self.app = chromadb.PersistentClient(
|
self.app = chromadb.PersistentClient(
|
||||||
@@ -126,6 +157,11 @@ class KnowledgeStorage(BaseKnowledgeStorage):
|
|||||||
documents: List[str],
|
documents: List[str],
|
||||||
metadata: Optional[Union[Dict[str, Any], List[Dict[str, Any]]]] = None,
|
metadata: Optional[Union[Dict[str, Any], List[Dict[str, Any]]]] = None,
|
||||||
):
|
):
|
||||||
|
if not chromadb:
|
||||||
|
raise ImportError(
|
||||||
|
"ChromaDB is not installed. Please install it with `pip install crewai[chromadb]`."
|
||||||
|
)
|
||||||
|
|
||||||
if not self.collection:
|
if not self.collection:
|
||||||
raise Exception("Collection not initialized")
|
raise Exception("Collection not initialized")
|
||||||
|
|
||||||
@@ -181,13 +217,23 @@ class KnowledgeStorage(BaseKnowledgeStorage):
|
|||||||
raise
|
raise
|
||||||
|
|
||||||
def _create_default_embedding_function(self):
|
def _create_default_embedding_function(self):
|
||||||
from chromadb.utils.embedding_functions.openai_embedding_function import (
|
if not chromadb:
|
||||||
OpenAIEmbeddingFunction,
|
raise ImportError(
|
||||||
)
|
"ChromaDB is not installed. Please install it with `pip install crewai[chromadb]`."
|
||||||
|
)
|
||||||
|
|
||||||
return OpenAIEmbeddingFunction(
|
try:
|
||||||
api_key=os.getenv("OPENAI_API_KEY"), model_name="text-embedding-3-small"
|
from chromadb.utils.embedding_functions.openai_embedding_function import (
|
||||||
)
|
OpenAIEmbeddingFunction,
|
||||||
|
)
|
||||||
|
|
||||||
|
return OpenAIEmbeddingFunction(
|
||||||
|
api_key=os.getenv("OPENAI_API_KEY"), model_name="text-embedding-3-small"
|
||||||
|
)
|
||||||
|
except ImportError:
|
||||||
|
raise ImportError(
|
||||||
|
"ChromaDB is not installed. Please install it with `pip install crewai[chromadb]`."
|
||||||
|
)
|
||||||
|
|
||||||
def _set_embedder_config(self, embedder: Optional[Dict[str, Any]] = None) -> None:
|
def _set_embedder_config(self, embedder: Optional[Dict[str, Any]] = None) -> None:
|
||||||
"""Set the embedding configuration for the knowledge storage.
|
"""Set the embedding configuration for the knowledge storage.
|
||||||
|
|||||||
@@ -4,9 +4,24 @@ import logging
|
|||||||
import os
|
import os
|
||||||
import shutil
|
import shutil
|
||||||
import uuid
|
import uuid
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import TYPE_CHECKING, Any, Dict, List, Optional
|
||||||
|
|
||||||
from chromadb.api import ClientAPI
|
if TYPE_CHECKING:
|
||||||
|
import chromadb
|
||||||
|
from chromadb.api import ClientAPI
|
||||||
|
from chromadb.config import Settings
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
import chromadb
|
||||||
|
from chromadb.api import ClientAPI
|
||||||
|
from chromadb.config import Settings
|
||||||
|
except ImportError:
|
||||||
|
chromadb = None
|
||||||
|
ClientAPI = None
|
||||||
|
Settings = None
|
||||||
|
raise ImportError(
|
||||||
|
"ChromaDB is not installed. Please install it with `pip install crewai[chromadb]`."
|
||||||
|
)
|
||||||
|
|
||||||
from crewai.memory.storage.base_rag_storage import BaseRAGStorage
|
from crewai.memory.storage.base_rag_storage import BaseRAGStorage
|
||||||
from crewai.utilities import EmbeddingConfigurator
|
from crewai.utilities import EmbeddingConfigurator
|
||||||
@@ -37,7 +52,8 @@ class RAGStorage(BaseRAGStorage):
|
|||||||
search efficiency.
|
search efficiency.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
app: ClientAPI | None = None
|
app: Optional[Any] = None
|
||||||
|
collection: Optional[Any] = None
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self, type, allow_reset=True, embedder_config=None, crew=None, path=None
|
self, type, allow_reset=True, embedder_config=None, crew=None, path=None
|
||||||
@@ -60,8 +76,13 @@ class RAGStorage(BaseRAGStorage):
|
|||||||
self.embedder_config = configurator.configure_embedder(self.embedder_config)
|
self.embedder_config = configurator.configure_embedder(self.embedder_config)
|
||||||
|
|
||||||
def _initialize_app(self):
|
def _initialize_app(self):
|
||||||
import chromadb
|
try:
|
||||||
from chromadb.config import Settings
|
import chromadb
|
||||||
|
from chromadb.config import Settings
|
||||||
|
except ImportError:
|
||||||
|
raise ImportError(
|
||||||
|
"ChromaDB is not installed. Please install it with `pip install crewai[chromadb]`."
|
||||||
|
)
|
||||||
|
|
||||||
self._set_embedder_config()
|
self._set_embedder_config()
|
||||||
chroma_client = chromadb.PersistentClient(
|
chroma_client = chromadb.PersistentClient(
|
||||||
@@ -118,6 +139,9 @@ class RAGStorage(BaseRAGStorage):
|
|||||||
if not hasattr(self, "app"):
|
if not hasattr(self, "app"):
|
||||||
self._initialize_app()
|
self._initialize_app()
|
||||||
|
|
||||||
|
if not self.collection:
|
||||||
|
raise ValueError("Collection not initialized")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
with suppress_logging():
|
with suppress_logging():
|
||||||
response = self.collection.query(query_texts=query, n_results=limit)
|
response = self.collection.query(query_texts=query, n_results=limit)
|
||||||
@@ -142,6 +166,9 @@ class RAGStorage(BaseRAGStorage):
|
|||||||
if not hasattr(self, "app") or not hasattr(self, "collection"):
|
if not hasattr(self, "app") or not hasattr(self, "collection"):
|
||||||
self._initialize_app()
|
self._initialize_app()
|
||||||
|
|
||||||
|
if not self.collection:
|
||||||
|
raise ValueError("Collection not initialized")
|
||||||
|
|
||||||
self.collection.add(
|
self.collection.add(
|
||||||
documents=[text],
|
documents=[text],
|
||||||
metadatas=[metadata or {}],
|
metadatas=[metadata or {}],
|
||||||
|
|||||||
@@ -1,8 +1,30 @@
|
|||||||
import os
|
import os
|
||||||
from typing import Any, Dict, Optional, cast
|
from typing import Any, Dict, List, Optional, Protocol, TypeVar, Union, cast
|
||||||
|
|
||||||
from chromadb import Documents, EmbeddingFunction, Embeddings
|
Documents = Union[str, List[str]]
|
||||||
from chromadb.api.types import validate_embedding_function
|
Embeddings = List[List[float]]
|
||||||
|
|
||||||
|
class EmbeddingFunctionProtocol(Protocol):
|
||||||
|
"""Protocol for EmbeddingFunction when chromadb is not installed."""
|
||||||
|
def __call__(self, input: Documents) -> Embeddings: ...
|
||||||
|
|
||||||
|
from typing import TYPE_CHECKING, Any, Protocol, TypeVar, cast
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from chromadb import EmbeddingFunction
|
||||||
|
from chromadb.api.types import validate_embedding_function
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
from chromadb import EmbeddingFunction
|
||||||
|
from chromadb.api.types import validate_embedding_function
|
||||||
|
except ImportError:
|
||||||
|
class EmbeddingFunction(Protocol):
|
||||||
|
"""Protocol for EmbeddingFunction when chromadb is not installed."""
|
||||||
|
def __call__(self, input: Any) -> Any: ...
|
||||||
|
|
||||||
|
def validate_embedding_function(func: Any) -> None:
|
||||||
|
"""Stub for validate_embedding_function when chromadb is not installed."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
class EmbeddingConfigurator:
|
class EmbeddingConfigurator:
|
||||||
@@ -47,190 +69,252 @@ class EmbeddingConfigurator:
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _create_default_embedding_function():
|
def _create_default_embedding_function():
|
||||||
from chromadb.utils.embedding_functions.openai_embedding_function import (
|
try:
|
||||||
OpenAIEmbeddingFunction,
|
from chromadb.utils.embedding_functions.openai_embedding_function import (
|
||||||
)
|
OpenAIEmbeddingFunction,
|
||||||
|
)
|
||||||
|
|
||||||
return OpenAIEmbeddingFunction(
|
return OpenAIEmbeddingFunction(
|
||||||
api_key=os.getenv("OPENAI_API_KEY"), model_name="text-embedding-3-small"
|
api_key=os.getenv("OPENAI_API_KEY"), model_name="text-embedding-3-small"
|
||||||
)
|
)
|
||||||
|
except ImportError:
|
||||||
|
raise ImportError(
|
||||||
|
"ChromaDB is not installed. Please install it with `pip install crewai[chromadb]`."
|
||||||
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _configure_openai(config, model_name):
|
def _configure_openai(config, model_name):
|
||||||
from chromadb.utils.embedding_functions.openai_embedding_function import (
|
try:
|
||||||
OpenAIEmbeddingFunction,
|
from chromadb.utils.embedding_functions.openai_embedding_function import (
|
||||||
)
|
OpenAIEmbeddingFunction,
|
||||||
|
)
|
||||||
|
|
||||||
return OpenAIEmbeddingFunction(
|
return OpenAIEmbeddingFunction(
|
||||||
api_key=config.get("api_key") or os.getenv("OPENAI_API_KEY"),
|
api_key=config.get("api_key") or os.getenv("OPENAI_API_KEY"),
|
||||||
model_name=model_name,
|
model_name=model_name,
|
||||||
api_base=config.get("api_base", None),
|
api_base=config.get("api_base", None),
|
||||||
api_type=config.get("api_type", None),
|
api_type=config.get("api_type", None),
|
||||||
api_version=config.get("api_version", None),
|
api_version=config.get("api_version", None),
|
||||||
default_headers=config.get("default_headers", None),
|
default_headers=config.get("default_headers", None),
|
||||||
dimensions=config.get("dimensions", None),
|
dimensions=config.get("dimensions", None),
|
||||||
deployment_id=config.get("deployment_id", None),
|
deployment_id=config.get("deployment_id", None),
|
||||||
organization_id=config.get("organization_id", None),
|
organization_id=config.get("organization_id", None),
|
||||||
)
|
)
|
||||||
|
except ImportError:
|
||||||
|
raise ImportError(
|
||||||
|
"ChromaDB is not installed. Please install it with `pip install crewai[chromadb]`."
|
||||||
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _configure_azure(config, model_name):
|
def _configure_azure(config, model_name):
|
||||||
from chromadb.utils.embedding_functions.openai_embedding_function import (
|
try:
|
||||||
OpenAIEmbeddingFunction,
|
from chromadb.utils.embedding_functions.openai_embedding_function import (
|
||||||
)
|
OpenAIEmbeddingFunction,
|
||||||
|
)
|
||||||
|
|
||||||
return OpenAIEmbeddingFunction(
|
return OpenAIEmbeddingFunction(
|
||||||
api_key=config.get("api_key"),
|
api_key=config.get("api_key"),
|
||||||
api_base=config.get("api_base"),
|
api_base=config.get("api_base"),
|
||||||
api_type=config.get("api_type", "azure"),
|
api_type=config.get("api_type", "azure"),
|
||||||
api_version=config.get("api_version"),
|
api_version=config.get("api_version"),
|
||||||
model_name=model_name,
|
model_name=model_name,
|
||||||
default_headers=config.get("default_headers"),
|
default_headers=config.get("default_headers"),
|
||||||
dimensions=config.get("dimensions"),
|
dimensions=config.get("dimensions"),
|
||||||
deployment_id=config.get("deployment_id"),
|
deployment_id=config.get("deployment_id"),
|
||||||
organization_id=config.get("organization_id"),
|
organization_id=config.get("organization_id"),
|
||||||
)
|
)
|
||||||
|
except ImportError:
|
||||||
|
raise ImportError(
|
||||||
|
"ChromaDB is not installed. Please install it with `pip install crewai[chromadb]`."
|
||||||
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _configure_ollama(config, model_name):
|
def _configure_ollama(config, model_name):
|
||||||
from chromadb.utils.embedding_functions.ollama_embedding_function import (
|
try:
|
||||||
OllamaEmbeddingFunction,
|
from chromadb.utils.embedding_functions.ollama_embedding_function import (
|
||||||
)
|
OllamaEmbeddingFunction,
|
||||||
|
)
|
||||||
|
|
||||||
return OllamaEmbeddingFunction(
|
return OllamaEmbeddingFunction(
|
||||||
url=config.get("url", "http://localhost:11434/api/embeddings"),
|
url=config.get("url", "http://localhost:11434/api/embeddings"),
|
||||||
model_name=model_name,
|
model_name=model_name,
|
||||||
)
|
)
|
||||||
|
except ImportError:
|
||||||
|
raise ImportError(
|
||||||
|
"ChromaDB is not installed. Please install it with `pip install crewai[chromadb]`."
|
||||||
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _configure_vertexai(config, model_name):
|
def _configure_vertexai(config, model_name):
|
||||||
from chromadb.utils.embedding_functions.google_embedding_function import (
|
try:
|
||||||
GoogleVertexEmbeddingFunction,
|
from chromadb.utils.embedding_functions.google_embedding_function import (
|
||||||
)
|
GoogleVertexEmbeddingFunction,
|
||||||
|
)
|
||||||
|
|
||||||
return GoogleVertexEmbeddingFunction(
|
return GoogleVertexEmbeddingFunction(
|
||||||
model_name=model_name,
|
model_name=model_name,
|
||||||
api_key=config.get("api_key"),
|
api_key=config.get("api_key"),
|
||||||
project_id=config.get("project_id"),
|
project_id=config.get("project_id"),
|
||||||
region=config.get("region"),
|
region=config.get("region"),
|
||||||
)
|
)
|
||||||
|
except ImportError:
|
||||||
|
raise ImportError(
|
||||||
|
"ChromaDB is not installed. Please install it with `pip install crewai[chromadb]`."
|
||||||
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _configure_google(config, model_name):
|
def _configure_google(config, model_name):
|
||||||
from chromadb.utils.embedding_functions.google_embedding_function import (
|
try:
|
||||||
GoogleGenerativeAiEmbeddingFunction,
|
from chromadb.utils.embedding_functions.google_embedding_function import (
|
||||||
)
|
GoogleGenerativeAiEmbeddingFunction,
|
||||||
|
)
|
||||||
|
|
||||||
return GoogleGenerativeAiEmbeddingFunction(
|
return GoogleGenerativeAiEmbeddingFunction(
|
||||||
model_name=model_name,
|
model_name=model_name,
|
||||||
api_key=config.get("api_key"),
|
api_key=config.get("api_key"),
|
||||||
task_type=config.get("task_type"),
|
task_type=config.get("task_type"),
|
||||||
)
|
)
|
||||||
|
except ImportError:
|
||||||
|
raise ImportError(
|
||||||
|
"ChromaDB is not installed. Please install it with `pip install crewai[chromadb]`."
|
||||||
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _configure_cohere(config, model_name):
|
def _configure_cohere(config, model_name):
|
||||||
from chromadb.utils.embedding_functions.cohere_embedding_function import (
|
try:
|
||||||
CohereEmbeddingFunction,
|
from chromadb.utils.embedding_functions.cohere_embedding_function import (
|
||||||
)
|
CohereEmbeddingFunction,
|
||||||
|
)
|
||||||
|
|
||||||
return CohereEmbeddingFunction(
|
return CohereEmbeddingFunction(
|
||||||
model_name=model_name,
|
model_name=model_name,
|
||||||
api_key=config.get("api_key"),
|
api_key=config.get("api_key"),
|
||||||
)
|
)
|
||||||
|
except ImportError:
|
||||||
|
raise ImportError(
|
||||||
|
"ChromaDB is not installed. Please install it with `pip install crewai[chromadb]`."
|
||||||
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _configure_voyageai(config, model_name):
|
def _configure_voyageai(config, model_name):
|
||||||
from chromadb.utils.embedding_functions.voyageai_embedding_function import (
|
try:
|
||||||
VoyageAIEmbeddingFunction,
|
from chromadb.utils.embedding_functions.voyageai_embedding_function import (
|
||||||
)
|
VoyageAIEmbeddingFunction,
|
||||||
|
)
|
||||||
|
|
||||||
return VoyageAIEmbeddingFunction(
|
return VoyageAIEmbeddingFunction(
|
||||||
model_name=model_name,
|
model_name=model_name,
|
||||||
api_key=config.get("api_key"),
|
api_key=config.get("api_key"),
|
||||||
)
|
)
|
||||||
|
except ImportError:
|
||||||
|
raise ImportError(
|
||||||
|
"ChromaDB is not installed. Please install it with `pip install crewai[chromadb]`."
|
||||||
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _configure_bedrock(config, model_name):
|
def _configure_bedrock(config, model_name):
|
||||||
from chromadb.utils.embedding_functions.amazon_bedrock_embedding_function import (
|
try:
|
||||||
AmazonBedrockEmbeddingFunction,
|
from chromadb.utils.embedding_functions.amazon_bedrock_embedding_function import (
|
||||||
)
|
AmazonBedrockEmbeddingFunction,
|
||||||
|
)
|
||||||
|
|
||||||
# Allow custom model_name override with backwards compatibility
|
# Allow custom model_name override with backwards compatibility
|
||||||
kwargs = {"session": config.get("session")}
|
kwargs = {"session": config.get("session")}
|
||||||
if model_name is not None:
|
if model_name is not None:
|
||||||
kwargs["model_name"] = model_name
|
kwargs["model_name"] = model_name
|
||||||
return AmazonBedrockEmbeddingFunction(**kwargs)
|
return AmazonBedrockEmbeddingFunction(**kwargs)
|
||||||
|
except ImportError:
|
||||||
|
raise ImportError(
|
||||||
|
"ChromaDB is not installed. Please install it with `pip install crewai[chromadb]`."
|
||||||
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _configure_huggingface(config, model_name):
|
def _configure_huggingface(config, model_name):
|
||||||
from chromadb.utils.embedding_functions.huggingface_embedding_function import (
|
try:
|
||||||
HuggingFaceEmbeddingServer,
|
from chromadb.utils.embedding_functions.huggingface_embedding_function import (
|
||||||
)
|
HuggingFaceEmbeddingServer,
|
||||||
|
)
|
||||||
|
|
||||||
return HuggingFaceEmbeddingServer(
|
return HuggingFaceEmbeddingServer(
|
||||||
url=config.get("api_url"),
|
url=config.get("api_url"),
|
||||||
)
|
)
|
||||||
|
except ImportError:
|
||||||
|
raise ImportError(
|
||||||
|
"ChromaDB is not installed. Please install it with `pip install crewai[chromadb]`."
|
||||||
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _configure_watson(config, model_name):
|
def _configure_watson(config, model_name):
|
||||||
try:
|
try:
|
||||||
import ibm_watsonx_ai.foundation_models as watson_models
|
try:
|
||||||
from ibm_watsonx_ai import Credentials
|
import ibm_watsonx_ai.foundation_models as watson_models
|
||||||
from ibm_watsonx_ai.metanames import EmbedTextParamsMetaNames as EmbedParams
|
from ibm_watsonx_ai import Credentials
|
||||||
except ImportError as e:
|
from ibm_watsonx_ai.metanames import (
|
||||||
raise ImportError(
|
EmbedTextParamsMetaNames as EmbedParams,
|
||||||
"IBM Watson dependencies are not installed. Please install them to use Watson embedding."
|
|
||||||
) from e
|
|
||||||
|
|
||||||
class WatsonEmbeddingFunction(EmbeddingFunction):
|
|
||||||
def __call__(self, input: Documents) -> Embeddings:
|
|
||||||
if isinstance(input, str):
|
|
||||||
input = [input]
|
|
||||||
|
|
||||||
embed_params = {
|
|
||||||
EmbedParams.TRUNCATE_INPUT_TOKENS: 3,
|
|
||||||
EmbedParams.RETURN_OPTIONS: {"input_text": True},
|
|
||||||
}
|
|
||||||
|
|
||||||
embedding = watson_models.Embeddings(
|
|
||||||
model_id=config.get("model"),
|
|
||||||
params=embed_params,
|
|
||||||
credentials=Credentials(
|
|
||||||
api_key=config.get("api_key"), url=config.get("api_url")
|
|
||||||
),
|
|
||||||
project_id=config.get("project_id"),
|
|
||||||
)
|
)
|
||||||
|
except ImportError as e:
|
||||||
|
raise ImportError(
|
||||||
|
"IBM Watson dependencies are not installed. Please install them to use Watson embedding."
|
||||||
|
) from e
|
||||||
|
|
||||||
try:
|
class WatsonEmbeddingFunction(EmbeddingFunction):
|
||||||
embeddings = embedding.embed_documents(input)
|
def __call__(self, input: Documents) -> Embeddings:
|
||||||
return cast(Embeddings, embeddings)
|
if isinstance(input, str):
|
||||||
except Exception as e:
|
input = [input]
|
||||||
print("Error during Watson embedding:", e)
|
|
||||||
raise e
|
|
||||||
|
|
||||||
return WatsonEmbeddingFunction()
|
embed_params = {
|
||||||
|
EmbedParams.TRUNCATE_INPUT_TOKENS: 3,
|
||||||
|
EmbedParams.RETURN_OPTIONS: {"input_text": True},
|
||||||
|
}
|
||||||
|
|
||||||
|
embedding = watson_models.Embeddings(
|
||||||
|
model_id=config.get("model"),
|
||||||
|
params=embed_params,
|
||||||
|
credentials=Credentials(
|
||||||
|
api_key=config.get("api_key"), url=config.get("api_url")
|
||||||
|
),
|
||||||
|
project_id=config.get("project_id"),
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
embeddings = embedding.embed_documents(input)
|
||||||
|
return cast(Embeddings, embeddings)
|
||||||
|
except Exception as e:
|
||||||
|
print("Error during Watson embedding:", e)
|
||||||
|
raise e
|
||||||
|
|
||||||
|
return WatsonEmbeddingFunction()
|
||||||
|
except ImportError:
|
||||||
|
raise ImportError(
|
||||||
|
"ChromaDB is not installed. Please install it with `pip install crewai[chromadb]`."
|
||||||
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _configure_custom(config):
|
def _configure_custom(config):
|
||||||
custom_embedder = config.get("embedder")
|
try:
|
||||||
if isinstance(custom_embedder, EmbeddingFunction):
|
custom_embedder = config.get("embedder")
|
||||||
try:
|
if isinstance(custom_embedder, EmbeddingFunction):
|
||||||
validate_embedding_function(custom_embedder)
|
try:
|
||||||
return custom_embedder
|
validate_embedding_function(custom_embedder)
|
||||||
except Exception as e:
|
return custom_embedder
|
||||||
raise ValueError(f"Invalid custom embedding function: {str(e)}")
|
except Exception as e:
|
||||||
elif callable(custom_embedder):
|
raise ValueError(f"Invalid custom embedding function: {str(e)}")
|
||||||
try:
|
elif callable(custom_embedder):
|
||||||
instance = custom_embedder()
|
try:
|
||||||
if isinstance(instance, EmbeddingFunction):
|
instance = custom_embedder()
|
||||||
validate_embedding_function(instance)
|
if isinstance(instance, EmbeddingFunction):
|
||||||
return instance
|
validate_embedding_function(instance)
|
||||||
|
return instance
|
||||||
|
raise ValueError(
|
||||||
|
"Custom embedder does not create an EmbeddingFunction instance"
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
raise ValueError(f"Error instantiating custom embedder: {str(e)}")
|
||||||
|
else:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Custom embedder does not create an EmbeddingFunction instance"
|
"Custom embedder must be an instance of `EmbeddingFunction` or a callable that creates one"
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except ImportError:
|
||||||
raise ValueError(f"Error instantiating custom embedder: {str(e)}")
|
raise ImportError(
|
||||||
else:
|
"ChromaDB is not installed. Please install it with `pip install crewai[chromadb]`."
|
||||||
raise ValueError(
|
|
||||||
"Custom embedder must be an instance of `EmbeddingFunction` or a callable that creates one"
|
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -1,5 +1,11 @@
|
|||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
chromadb_not_installed = False
|
||||||
|
try:
|
||||||
|
import chromadb
|
||||||
|
except ImportError:
|
||||||
|
chromadb_not_installed = True
|
||||||
|
|
||||||
from crewai.memory.long_term.long_term_memory import LongTermMemory
|
from crewai.memory.long_term.long_term_memory import LongTermMemory
|
||||||
from crewai.memory.long_term.long_term_memory_item import LongTermMemoryItem
|
from crewai.memory.long_term.long_term_memory_item import LongTermMemoryItem
|
||||||
|
|
||||||
@@ -10,6 +16,7 @@ def long_term_memory():
|
|||||||
return LongTermMemory()
|
return LongTermMemory()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(chromadb_not_installed, reason="ChromaDB is not installed")
|
||||||
def test_save_and_search(long_term_memory):
|
def test_save_and_search(long_term_memory):
|
||||||
memory = LongTermMemoryItem(
|
memory = LongTermMemoryItem(
|
||||||
agent="test_agent",
|
agent="test_agent",
|
||||||
|
|||||||
@@ -2,6 +2,12 @@ from unittest.mock import patch
|
|||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
chromadb_not_installed = False
|
||||||
|
try:
|
||||||
|
import chromadb
|
||||||
|
except ImportError:
|
||||||
|
chromadb_not_installed = True
|
||||||
|
|
||||||
from crewai.agent import Agent
|
from crewai.agent import Agent
|
||||||
from crewai.crew import Crew
|
from crewai.crew import Crew
|
||||||
from crewai.memory.short_term.short_term_memory import ShortTermMemory
|
from crewai.memory.short_term.short_term_memory import ShortTermMemory
|
||||||
@@ -28,6 +34,7 @@ def short_term_memory():
|
|||||||
return ShortTermMemory(crew=Crew(agents=[agent], tasks=[task]))
|
return ShortTermMemory(crew=Crew(agents=[agent], tasks=[task]))
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(chromadb_not_installed, reason="ChromaDB is not installed")
|
||||||
def test_save_and_search(short_term_memory):
|
def test_save_and_search(short_term_memory):
|
||||||
memory = ShortTermMemoryItem(
|
memory = ShortTermMemoryItem(
|
||||||
data="""test value test value test value test value test value test value
|
data="""test value test value test value test value test value test value
|
||||||
|
|||||||
24
tests/storage/test_optional_chromadb.py
Normal file
24
tests/storage/test_optional_chromadb.py
Normal file
@@ -0,0 +1,24 @@
|
|||||||
|
import sys
|
||||||
|
import unittest
|
||||||
|
from typing import Any, Dict, List, Optional
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
|
||||||
|
class TestOptionalChromadb(unittest.TestCase):
|
||||||
|
def test_rag_storage_import_error(self):
|
||||||
|
"""Test that RAGStorage raises an ImportError when chromadb is not installed."""
|
||||||
|
with patch.dict(sys.modules, {"chromadb": None}):
|
||||||
|
with pytest.raises(ImportError) as excinfo:
|
||||||
|
from crewai.memory.storage.rag_storage import RAGStorage
|
||||||
|
|
||||||
|
assert "ChromaDB is not installed" in str(excinfo.value)
|
||||||
|
|
||||||
|
def test_knowledge_storage_import_error(self):
|
||||||
|
"""Test that KnowledgeStorage raises an ImportError when chromadb is not installed."""
|
||||||
|
with patch.dict(sys.modules, {"chromadb": None}):
|
||||||
|
with pytest.raises(ImportError) as excinfo:
|
||||||
|
from crewai.knowledge.storage.knowledge_storage import KnowledgeStorage
|
||||||
|
|
||||||
|
assert "ChromaDB is not installed" in str(excinfo.value)
|
||||||
Reference in New Issue
Block a user