From cbfcde73ec4156b385dead3180109d885046cbb1 Mon Sep 17 00:00:00 2001 From: Lorenze Jay Date: Mon, 18 Nov 2024 13:52:33 -0800 Subject: [PATCH] consolidation and improvements --- src/crewai/agent.py | 1 - src/crewai/knowledge/embedder/ollama.py | 82 -------- src/crewai/knowledge/embedder/openai.py | 85 -------- src/crewai/knowledge/knowledge.py | 7 +- .../source/string_knowledge_source.py | 9 +- .../source/text_file_knowledge_source.py | 6 +- .../storage/base_knowledge_storage.py | 29 +++ .../knowledge/storage/knowledge_storage.py | 29 ++- src/crewai/memory/storage/rag_storage.py | 134 +------------ src/crewai/utilities/__init__.py | 2 + .../utilities/embedding_configurator.py | 184 ++++++++++++++++++ 11 files changed, 248 insertions(+), 320 deletions(-) delete mode 100644 src/crewai/knowledge/embedder/ollama.py delete mode 100644 src/crewai/knowledge/embedder/openai.py create mode 100644 src/crewai/knowledge/storage/base_knowledge_storage.py create mode 100644 src/crewai/utilities/embedding_configurator.py diff --git a/src/crewai/agent.py b/src/crewai/agent.py index 8937ada84..ce3f0b01e 100644 --- a/src/crewai/agent.py +++ b/src/crewai/agent.py @@ -121,7 +121,6 @@ class Agent(BaseAgent): default="safe", description="Mode for code execution: 'safe' (using Docker) or 'unsafe' (direct execution).", ) - # TODO: Lorenze add knowledge_embedder. Support direct class or config dict. 
_knowledge: Optional[Knowledge] = PrivateAttr(default=None) @model_validator(mode="after") diff --git a/src/crewai/knowledge/embedder/ollama.py b/src/crewai/knowledge/embedder/ollama.py deleted file mode 100644 index 3f7521cab..000000000 --- a/src/crewai/knowledge/embedder/ollama.py +++ /dev/null @@ -1,82 +0,0 @@ -import os -from typing import List, Optional - -import numpy as np -from openai import OpenAI - -from .base_embedder import BaseEmbedder - - -class OllamaEmbedder(BaseEmbedder): - """ - A wrapper class for text embedding models using Ollama's API - """ - - def __init__( - self, - model_name: str, - api_key: Optional[str] = None, - base_url: str = "http://localhost:11434/v1", - ): - """ - Initialize the embedding model - - Args: - model_name: Name of the model to use - api_key: API key (defaults to 'ollama' or environment variable 'OLLAMA_API_KEY') - base_url: Base URL for the Ollama API (default is 'http://localhost:11434/v1') - """ - self.model_name = model_name - self.api_key = api_key or os.getenv("OLLAMA_API_KEY") or "ollama" - self.base_url = base_url - self.client = OpenAI(base_url=self.base_url, api_key=self.api_key) - - def embed_chunks(self, chunks: List[str]) -> List[np.ndarray]: - """ - Generate embeddings for a list of text chunks - - Args: - chunks: List of text chunks to embed - - Returns: - List of embeddings - """ - return self.embed_texts(chunks) - - def embed_texts(self, texts: List[str]) -> List[np.ndarray]: - """ - Generate embeddings for a list of texts - - Args: - texts: List of texts to embed - - Returns: - List of embeddings - """ - embeddings = [] - max_batch_size = 2048 # Adjust batch size if necessary - for i in range(0, len(texts), max_batch_size): - batch = texts[i : i + max_batch_size] - response = self.client.embeddings.create(input=batch, model=self.model_name) - batch_embeddings = [np.array(item.embedding) for item in response.data] - embeddings.extend(batch_embeddings) - return embeddings - - def embed_text(self, text: 
str) -> np.ndarray: - """ - Generate embedding for a single text - - Args: - text: Text to embed - - Returns: - Embedding array - """ - return self.embed_texts([text])[0] - - @property - def dimension(self) -> int: - """Get the dimension of the embeddings""" - # Embedding dimensions may vary; we'll determine it dynamically - test_embed = self.embed_text("test") - return len(test_embed) diff --git a/src/crewai/knowledge/embedder/openai.py b/src/crewai/knowledge/embedder/openai.py deleted file mode 100644 index d38376bdc..000000000 --- a/src/crewai/knowledge/embedder/openai.py +++ /dev/null @@ -1,85 +0,0 @@ -import os -from typing import List, Optional - -import numpy as np -from openai import OpenAI - -from .base_embedder import BaseEmbedder - - -class OpenAIEmbedder(BaseEmbedder): - """ - A wrapper class for text embedding models using OpenAI's Embedding API - """ - - def __init__( - self, - model_name: str = "text-embedding-ada-002", - api_key: Optional[str] = None, - ): - """ - Initialize the embedding model - - Args: - model_name: Name of the model to use - api_key: OpenAI API key - """ - self.model_name = model_name - self.api_key = api_key or os.getenv("OPENAI_API_KEY") - if not self.api_key: - raise ValueError( - "OpenAI API key must be provided or set in the environment variable 'OPENAI_API_KEY'" - ) - self.client = OpenAI( - api_key=self.api_key, - base_url="http://localhost:11434/v1", - ) - - def embed_chunks(self, chunks: List[str]) -> List[np.ndarray]: - """ - Generate embeddings for a list of text chunks - - Args: - chunks: List of text chunks to embed - - Returns: - List of embeddings - """ - return self.embed_texts(chunks) - - def embed_texts(self, texts: List[str]) -> List[np.ndarray]: - """ - Generate embeddings for a list of texts - - Args: - texts: List of texts to embed - - Returns: - List of embeddings - """ - embeddings = [] - max_batch_size = 2048 # OpenAI recommends smaller batch sizes - for i in range(0, len(texts), max_batch_size): - batch 
= texts[i : i + max_batch_size] - response = self.client.embeddings.create(input=batch, model=self.model_name) - batch_embeddings = [np.array(data.embedding) for data in response.data] - embeddings.extend(batch_embeddings) - return embeddings - - def embed_text(self, text: str) -> np.ndarray: - """ - Generate embedding fors a single text - - Args: - text: Text to embed - - Returns: - Embedding array - """ - return self.embed_texts([text])[0] - - @property - def dimension(self) -> int: - """Get the dimension of the embeddings""" - # For OpenAI's text-embedding-ada-002, the dimension is 1536 - return 1536 diff --git a/src/crewai/knowledge/knowledge.py b/src/crewai/knowledge/knowledge.py index 79ff009f7..3399956bf 100644 --- a/src/crewai/knowledge/knowledge.py +++ b/src/crewai/knowledge/knowledge.py @@ -15,12 +15,13 @@ class Knowledge(BaseModel): model_config = ConfigDict(arbitrary_types_allowed=True) agents: List[str] = Field(default_factory=list) storage: KnowledgeStorage = Field(default_factory=KnowledgeStorage) + embedder_config: Optional[Dict[str, Any]] = Field(default_factory=None) def __init__(self, **data): super().__init__(**data) - # Call add on all sources during initialization - for source in self.sources: - source.add(self.embedder) + embedder_config = data.get("embedder_config", None) + if embedder_config: + self.storage = KnowledgeStorage(embedder_config=embedder_config) def query( self, query: List[str], limit: int = 3, preference: Optional[str] = None diff --git a/src/crewai/knowledge/source/string_knowledge_source.py b/src/crewai/knowledge/source/string_knowledge_source.py index e9e72334f..7d0b8e933 100644 --- a/src/crewai/knowledge/source/string_knowledge_source.py +++ b/src/crewai/knowledge/source/string_knowledge_source.py @@ -1,8 +1,7 @@ -from typing import List, Dict, Any +from typing import List from pydantic import Field -from crewai.knowledge.embedder.base_embedder import BaseEmbedder from crewai.knowledge.source.base_knowledge_source import 
BaseKnowledgeSource @@ -20,14 +19,10 @@ class StringKnowledgeSource(BaseKnowledgeSource): if not isinstance(self.content, str): raise ValueError("StringKnowledgeSource only accepts string content") - def add(self, embedder: BaseEmbedder) -> None: + def add(self) -> None: """Add string content to the knowledge source, chunk it, compute embeddings, and save them.""" new_chunks = self._chunk_text(self.content) self.chunks.extend(new_chunks) - # Compute embeddings for the new chunks - new_embeddings = embedder.embed_chunks(new_chunks) - # Save the embeddings - self.chunk_embeddings.extend(new_embeddings) self.save_documents(metadata=self.metadata) def _chunk_text(self, text: str) -> List[str]: diff --git a/src/crewai/knowledge/source/text_file_knowledge_source.py b/src/crewai/knowledge/source/text_file_knowledge_source.py index b8195da5c..c4821832e 100644 --- a/src/crewai/knowledge/source/text_file_knowledge_source.py +++ b/src/crewai/knowledge/source/text_file_knowledge_source.py @@ -13,17 +13,13 @@ class TextFileKnowledgeSource(BaseFileKnowledgeSource): with self.file_path.open("r", encoding="utf-8") as f: return f.read() - def add(self, embedder: BaseEmbedder) -> None: + def add(self) -> None: """ Add text file content to the knowledge source, chunk it, compute embeddings, and save the embeddings. 
""" new_chunks = self._chunk_text(self.content) self.chunks.extend(new_chunks) - # Compute embeddings for the new chunks - new_embeddings = embedder.embed_chunks(new_chunks) - # Save the embeddings - self.chunk_embeddings.extend(new_embeddings) self.save_documents(metadata=self.metadata) def _chunk_text(self, text: str) -> List[str]: diff --git a/src/crewai/knowledge/storage/base_knowledge_storage.py b/src/crewai/knowledge/storage/base_knowledge_storage.py new file mode 100644 index 000000000..78d370e04 --- /dev/null +++ b/src/crewai/knowledge/storage/base_knowledge_storage.py @@ -0,0 +1,29 @@ +from abc import ABC, abstractmethod +from typing import Dict, Any, List, Optional + + +class BaseKnowledgeStorage(ABC): + """Abstract base class for knowledge storage implementations.""" + + @abstractmethod + def search( + self, + query: List[str], + limit: int = 3, + filter: Optional[dict] = None, + score_threshold: float = 0.35, + ) -> List[Dict[str, Any]]: + """Search for documents in the knowledge base.""" + pass + + @abstractmethod + def save( + self, documents: List[str], metadata: Dict[str, Any] | List[Dict[str, Any]] + ) -> None: + """Save documents to the knowledge base.""" + pass + + @abstractmethod + def reset(self) -> None: + """Reset the knowledge base.""" + pass diff --git a/src/crewai/knowledge/storage/knowledge_storage.py b/src/crewai/knowledge/storage/knowledge_storage.py index d79a192a0..7b7fa8c69 100644 --- a/src/crewai/knowledge/storage/knowledge_storage.py +++ b/src/crewai/knowledge/storage/knowledge_storage.py @@ -7,6 +7,8 @@ import chromadb from crewai.utilities.paths import db_storage_path from typing import Optional, List from typing import Dict, Any +from crewai.utilities import EmbeddingConfigurator +from crewai.knowledge.storage.base_knowledge_storage import BaseKnowledgeStorage @contextlib.contextmanager @@ -26,7 +28,7 @@ def suppress_logging( logger.setLevel(original_level) -class KnowledgeStorage: +class KnowledgeStorage(BaseKnowledgeStorage): 
""" Extends Storage to handle embeddings for memory entries, improving search efficiency. @@ -35,10 +37,7 @@ class KnowledgeStorage: collection: Optional[chromadb.Collection] = None def __init__(self, embedder_config=None): - self.embedder_config = ( - embedder_config or self._create_default_embedding_function() - ) - self._initialize_app() + self._initialize_app(embedder_config or {}) def search( self, @@ -54,7 +53,6 @@ class KnowledgeStorage: n_results=limit, where=filter, ) - print("Fetched", fetched) results = [] for i in range(len(fetched["ids"][0])): result = { @@ -69,10 +67,12 @@ class KnowledgeStorage: else: raise Exception("Collection not initialized") - def _initialize_app(self): + def _initialize_app(self, embedder_config: Optional[Dict[str, Any]] = None): import chromadb from chromadb.config import Settings + self._set_embedder_config(embedder_config) + chroma_client = chromadb.PersistentClient( path=f"{db_storage_path()}/knowledge", settings=Settings(allow_reset=True), @@ -107,3 +107,18 @@ class KnowledgeStorage: from crewai.knowledge.embedder.fastembed import FastEmbed return FastEmbed().embed_texts + + def _set_embedder_config( + self, embedder_config: Optional[Dict[str, Any]] = None + ) -> None: + """Set the embedding configuration for the knowledge storage. + + Args: + embedder_config (Optional[Dict[str, Any]]): Configuration dictionary for the embedder. + If None or empty, defaults to the default embedding function. 
+ """ + self.embedder_config = ( + EmbeddingConfigurator().configure_embedder(embedder_config) + if embedder_config + else self._create_default_embedding_function() + ) diff --git a/src/crewai/memory/storage/rag_storage.py b/src/crewai/memory/storage/rag_storage.py index 7af5fb554..4023cf558 100644 --- a/src/crewai/memory/storage/rag_storage.py +++ b/src/crewai/memory/storage/rag_storage.py @@ -4,13 +4,12 @@ import logging import os import shutil import uuid -from typing import Any, Dict, List, Optional, cast -from chromadb import Documents, EmbeddingFunction, Embeddings +from typing import Any, Dict, List, Optional from chromadb.api import ClientAPI -from chromadb.api.types import validate_embedding_function from crewai.memory.storage.base_rag_storage import BaseRAGStorage from crewai.utilities.paths import db_storage_path +from crewai.utilities import EmbeddingConfigurator @contextlib.contextmanager @@ -51,133 +50,8 @@ class RAGStorage(BaseRAGStorage): self._initialize_app() def _set_embedder_config(self): - if self.embedder_config is None: - self.embedder_config = self._create_default_embedding_function() - - if isinstance(self.embedder_config, dict): - provider = self.embedder_config.get("provider") - config = self.embedder_config.get("config", {}) - model_name = config.get("model") - if provider == "openai": - from chromadb.utils.embedding_functions.openai_embedding_function import ( - OpenAIEmbeddingFunction, - ) - - self.embedder_config = OpenAIEmbeddingFunction( - api_key=config.get("api_key") or os.getenv("OPENAI_API_KEY"), - model_name=model_name, - ) - elif provider == "azure": - from chromadb.utils.embedding_functions.openai_embedding_function import ( - OpenAIEmbeddingFunction, - ) - - self.embedder_config = OpenAIEmbeddingFunction( - api_key=config.get("api_key"), - api_base=config.get("api_base"), - api_type=config.get("api_type", "azure"), - api_version=config.get("api_version"), - model_name=model_name, - ) - elif provider == "ollama": - from 
chromadb.utils.embedding_functions.ollama_embedding_function import ( - OllamaEmbeddingFunction, - ) - - self.embedder_config = OllamaEmbeddingFunction( - url=config.get("url", "http://localhost:11434/api/embeddings"), - model_name=model_name, - ) - elif provider == "vertexai": - from chromadb.utils.embedding_functions.google_embedding_function import ( - GoogleVertexEmbeddingFunction, - ) - - self.embedder_config = GoogleVertexEmbeddingFunction( - model_name=model_name, - api_key=config.get("api_key"), - ) - elif provider == "google": - from chromadb.utils.embedding_functions.google_embedding_function import ( - GoogleGenerativeAiEmbeddingFunction, - ) - - self.embedder_config = GoogleGenerativeAiEmbeddingFunction( - model_name=model_name, - api_key=config.get("api_key"), - ) - elif provider == "cohere": - from chromadb.utils.embedding_functions.cohere_embedding_function import ( - CohereEmbeddingFunction, - ) - - self.embedder_config = CohereEmbeddingFunction( - model_name=model_name, - api_key=config.get("api_key"), - ) - elif provider == "bedrock": - from chromadb.utils.embedding_functions.amazon_bedrock_embedding_function import ( - AmazonBedrockEmbeddingFunction, - ) - - self.embedder_config = AmazonBedrockEmbeddingFunction( - session=config.get("session"), - ) - elif provider == "huggingface": - from chromadb.utils.embedding_functions.huggingface_embedding_function import ( - HuggingFaceEmbeddingServer, - ) - - self.embedder_config = HuggingFaceEmbeddingServer( - url=config.get("api_url"), - ) - elif provider == "watson": - try: - import ibm_watsonx_ai.foundation_models as watson_models - from ibm_watsonx_ai import Credentials - from ibm_watsonx_ai.metanames import ( - EmbedTextParamsMetaNames as EmbedParams, - ) - except ImportError as e: - raise ImportError( - "IBM Watson dependencies are not installed. Please install them to use Watson embedding." 
- ) from e - - class WatsonEmbeddingFunction(EmbeddingFunction): - def __call__(self, input: Documents) -> Embeddings: - if isinstance(input, str): - input = [input] - - embed_params = { - EmbedParams.TRUNCATE_INPUT_TOKENS: 3, - EmbedParams.RETURN_OPTIONS: {"input_text": True}, - } - - embedding = watson_models.Embeddings( - model_id=config.get("model"), - params=embed_params, - credentials=Credentials( - api_key=config.get("api_key"), url=config.get("api_url") - ), - project_id=config.get("project_id"), - ) - - try: - embeddings = embedding.embed_documents(input) - return cast(Embeddings, embeddings) - - except Exception as e: - print("Error during Watson embedding:", e) - raise e - - self.embedder_config = WatsonEmbeddingFunction() - else: - raise Exception( - f"Unsupported embedding provider: {provider}, supported providers: [openai, azure, ollama, vertexai, google, cohere, huggingface, watson]" - ) - else: - validate_embedding_function(self.embedder_config) - self.embedder_config = self.embedder_config + configurator = EmbeddingConfigurator() + self.embedder_config = configurator.configure_embedder(self.embedder_config) def _initialize_app(self): import chromadb diff --git a/src/crewai/utilities/__init__.py b/src/crewai/utilities/__init__.py index 26d35a6cc..dd6d9fa44 100644 --- a/src/crewai/utilities/__init__.py +++ b/src/crewai/utilities/__init__.py @@ -10,6 +10,7 @@ from .rpm_controller import RPMController from .exceptions.context_window_exceeding_exception import ( LLMContextLengthExceededException, ) +from .embedding_configurator import EmbeddingConfigurator __all__ = [ "Converter", @@ -23,4 +24,5 @@ __all__ = [ "RPMController", "YamlParser", "LLMContextLengthExceededException", + "EmbeddingConfigurator", ] diff --git a/src/crewai/utilities/embedding_configurator.py b/src/crewai/utilities/embedding_configurator.py new file mode 100644 index 000000000..591071364 --- /dev/null +++ b/src/crewai/utilities/embedding_configurator.py @@ -0,0 +1,184 @@ +import os 
+from typing import Any, Dict, cast +from chromadb import EmbeddingFunction, Documents, Embeddings +from chromadb.api.types import validate_embedding_function + + +class EmbeddingConfigurator: + def __init__(self): + self.embedding_functions = { + "openai": self._configure_openai, + "azure": self._configure_azure, + "ollama": self._configure_ollama, + "vertexai": self._configure_vertexai, + "google": self._configure_google, + "cohere": self._configure_cohere, + "bedrock": self._configure_bedrock, + "huggingface": self._configure_huggingface, + "watson": self._configure_watson, + } + + def configure_embedder( + self, + embedder_config: Dict[str, Any] | None = None, + ) -> EmbeddingFunction: + """Configures and returns an embedding function based on the provided config.""" + if embedder_config is None: + return self._create_default_embedding_function() + + provider = embedder_config.get("provider") + config = embedder_config.get("config", {}) + model_name = config.get("model") + + if isinstance(provider, EmbeddingFunction): + try: + validate_embedding_function(provider) + print("Valid custom embedding function", provider, config) + return provider + except Exception as e: + raise ValueError(f"Invalid custom embedding function: {str(e)}") + + if provider not in self.embedding_functions: + raise Exception( + f"Unsupported embedding provider: {provider}, supported providers: {list(self.embedding_functions.keys())}" + ) + + return self.embedding_functions[provider](config, model_name) + + @staticmethod + def _create_default_embedding_function(): + from chromadb.utils.embedding_functions.openai_embedding_function import ( + OpenAIEmbeddingFunction, + ) + + return OpenAIEmbeddingFunction( + api_key=os.getenv("OPENAI_API_KEY"), model_name="text-embedding-3-small" + ) + + @staticmethod + def _configure_openai(config, model_name): + from chromadb.utils.embedding_functions.openai_embedding_function import ( + OpenAIEmbeddingFunction, + ) + + return OpenAIEmbeddingFunction( + 
api_key=config.get("api_key") or os.getenv("OPENAI_API_KEY"), + model_name=model_name, + ) + + @staticmethod + def _configure_azure(config, model_name): + from chromadb.utils.embedding_functions.openai_embedding_function import ( + OpenAIEmbeddingFunction, + ) + + return OpenAIEmbeddingFunction( + api_key=config.get("api_key"), + api_base=config.get("api_base"), + api_type=config.get("api_type", "azure"), + api_version=config.get("api_version"), + model_name=model_name, + ) + + @staticmethod + def _configure_ollama(config, model_name): + from chromadb.utils.embedding_functions.ollama_embedding_function import ( + OllamaEmbeddingFunction, + ) + + return OllamaEmbeddingFunction( + url=config.get("url", "http://localhost:11434/api/embeddings"), + model_name=model_name, + ) + + @staticmethod + def _configure_vertexai(config, model_name): + from chromadb.utils.embedding_functions.google_embedding_function import ( + GoogleVertexEmbeddingFunction, + ) + + return GoogleVertexEmbeddingFunction( + model_name=model_name, + api_key=config.get("api_key"), + ) + + @staticmethod + def _configure_google(config, model_name): + from chromadb.utils.embedding_functions.google_embedding_function import ( + GoogleGenerativeAiEmbeddingFunction, + ) + + return GoogleGenerativeAiEmbeddingFunction( + model_name=model_name, + api_key=config.get("api_key"), + ) + + @staticmethod + def _configure_cohere(config, model_name): + from chromadb.utils.embedding_functions.cohere_embedding_function import ( + CohereEmbeddingFunction, + ) + + return CohereEmbeddingFunction( + model_name=model_name, + api_key=config.get("api_key"), + ) + + @staticmethod + def _configure_bedrock(config, model_name): + from chromadb.utils.embedding_functions.amazon_bedrock_embedding_function import ( + AmazonBedrockEmbeddingFunction, + ) + + return AmazonBedrockEmbeddingFunction( + session=config.get("session"), + ) + + @staticmethod + def _configure_huggingface(config, model_name): + from 
chromadb.utils.embedding_functions.huggingface_embedding_function import ( + HuggingFaceEmbeddingServer, + ) + + return HuggingFaceEmbeddingServer( + url=config.get("api_url"), + ) + + @staticmethod + def _configure_watson(config, model_name): + try: + import ibm_watsonx_ai.foundation_models as watson_models + from ibm_watsonx_ai import Credentials + from ibm_watsonx_ai.metanames import EmbedTextParamsMetaNames as EmbedParams + except ImportError as e: + raise ImportError( + "IBM Watson dependencies are not installed. Please install them to use Watson embedding." + ) from e + + class WatsonEmbeddingFunction(EmbeddingFunction): + def __call__(self, input: Documents) -> Embeddings: + if isinstance(input, str): + input = [input] + + embed_params = { + EmbedParams.TRUNCATE_INPUT_TOKENS: 3, + EmbedParams.RETURN_OPTIONS: {"input_text": True}, + } + + embedding = watson_models.Embeddings( + model_id=config.get("model"), + params=embed_params, + credentials=Credentials( + api_key=config.get("api_key"), url=config.get("api_url") + ), + project_id=config.get("project_id"), + ) + + try: + embeddings = embedding.embed_documents(input) + return cast(Embeddings, embeddings) + except Exception as e: + print("Error during Watson embedding:", e) + raise e + + return WatsonEmbeddingFunction()