diff --git a/src/crewai/rag/core/__init__.py b/src/crewai/rag/core/__init__.py new file mode 100644 index 000000000..34c997dfb --- /dev/null +++ b/src/crewai/rag/core/__init__.py @@ -0,0 +1 @@ +"""Core abstract base classes and protocols for RAG systems.""" diff --git a/src/crewai/rag/core/base_client.py b/src/crewai/rag/core/base_client.py new file mode 100644 index 000000000..c3bdcd3b0 --- /dev/null +++ b/src/crewai/rag/core/base_client.py @@ -0,0 +1,433 @@ +"""Protocol for vector database client implementations.""" + +from abc import abstractmethod +from typing import Any, Protocol, runtime_checkable, TypedDict, Annotated +from typing_extensions import Unpack, Required + + +from crewai.rag.types import ( + EmbeddingFunction, + BaseRecord, + SearchResult, +) + + +class BaseCollectionParams(TypedDict): + """Base parameters for collection operations. + + Attributes: + collection_name: The name of the collection/index to operate on. + """ + + collection_name: Required[ + Annotated[ + str, + "Name of the collection/index. Implementations may have specific constraints (e.g., character limits, allowed characters, case sensitivity).", + ] + ] + + +class BaseCollectionAddParams(BaseCollectionParams): + """Parameters for adding documents to a collection. + + Extends BaseCollectionParams with document-specific fields. + + Attributes: + collection_name: The name of the collection to add documents to. + documents: List of BaseRecord dictionaries containing document data. + """ + + documents: list[BaseRecord] + + +class BaseCollectionSearchParams(BaseCollectionParams, total=False): + """Parameters for searching within a collection. + + Extends BaseCollectionParams with search-specific optional fields. + All fields except collection_name and query are optional. + + Attributes: + query: The text query to search for (required). + limit: Maximum number of results to return. + metadata_filter: Filter results by metadata fields. + score_threshold: Minimum similarity score for results (0-1). + """ + + query: Required[str] + limit: int + metadata_filter: dict[str, Any] + score_threshold: float + + +@runtime_checkable +class BaseClient(Protocol): + """Protocol for vector store client implementations. + + This protocol defines the interface that all vector store client implementations + must follow. It provides a consistent API for storing and retrieving + documents with their vector embeddings across different vector database + backends (e.g., Qdrant, ChromaDB, Weaviate). Implementing classes should + handle connection management, data persistence, and vector similarity + search operations specific to their backend. + + Implementation Guidelines: + Implementations should accept BaseClientParams in their constructor to allow + passing pre-configured client instances: + + class MyVectorClient: + def __init__(self, client: Any | None = None, **kwargs): + if client: + self.client = client + else: + self.client = self._create_default_client(**kwargs) + + Notes: + This protocol replaces the former BaseRAGStorage abstraction, + providing a cleaner interface for vector store operations. + + Attributes: + embedding_function: Callable that takes a list of text strings + and returns a list of embedding vectors. Implementations + should always provide a default embedding function. + client: The underlying vector database client instance. This could be + passed via BaseClientParams during initialization or created internally. + """ + + client: Any + embedding_function: EmbeddingFunction + + @abstractmethod + def create_collection(self, **kwargs: Unpack[BaseCollectionParams]) -> None: + """Create a new collection/index in the vector database. + + Keyword Args: + collection_name: The name of the collection to create. Must be unique within + the vector database instance. + + Raises: + ValueError: If collection name already exists. + ConnectionError: If unable to connect to the vector database backend. + """ + ... + + @abstractmethod + async def acreate_collection(self, **kwargs: Unpack[BaseCollectionParams]) -> None: + """Create a new collection/index in the vector database asynchronously. + + Keyword Args: + collection_name: The name of the collection to create. Must be unique within + the vector database instance. + + Raises: + ValueError: If collection name already exists. + ConnectionError: If unable to connect to the vector database backend. + """ + ... + + @abstractmethod + def get_or_create_collection(self, **kwargs: Unpack[BaseCollectionParams]) -> Any: + """Get an existing collection or create it if it doesn't exist. + + This method provides a convenient way to ensure a collection exists + without having to check for its existence first. + + Keyword Args: + collection_name: The name of the collection to get or create. + + Returns: + A collection object whose type depends on the backend implementation. + This could be a collection reference, ID, or client object. + + Raises: + ValueError: If unable to create the collection. + ConnectionError: If unable to connect to the vector database backend. + """ + ... + + @abstractmethod + async def aget_or_create_collection( + self, **kwargs: Unpack[BaseCollectionParams] + ) -> Any: + """Get an existing collection or create it if it doesn't exist asynchronously. + + Keyword Args: + collection_name: The name of the collection to get or create. + + Returns: + A collection object whose type depends on the backend implementation. + + Raises: + ValueError: If unable to create the collection. + ConnectionError: If unable to connect to the vector database backend. + """ + ... + + @abstractmethod + def add_documents(self, **kwargs: Unpack[BaseCollectionAddParams]) -> None: + """Add documents with their embeddings to a collection. + + This method performs an upsert operation - if a document with the same ID + already exists, it will be updated with the new content and metadata. + + Implementations should handle embedding generation internally based on + the configured embedding function. + + Keyword Args: + collection_name: The name of the collection to add documents to. + documents: List of BaseRecord dicts containing: + - content: The text content (required) + - doc_id: Optional unique identifier (auto-generated from content hash if missing) + - metadata: Optional metadata dictionary + Embeddings will be generated automatically. + + Raises: + ValueError: If collection doesn't exist or documents list is empty. + TypeError: If documents are not BaseRecord dict instances. + ConnectionError: If unable to connect to the vector database backend. + + Example: + >>> from crewai.rag.chromadb.client import ChromaDBClient + >>> from crewai.rag.types import BaseRecord + >>> client = ChromaDBClient() + >>> + >>> records: list[BaseRecord] = [ + ... { + ... "content": "Machine learning basics", + ... "metadata": {"source": "file3", "topic": "ML"} + ... }, + ... { + ... "doc_id": "custom_id", + ... "content": "Deep learning fundamentals", + ... "metadata": {"source": "file4", "topic": "DL"} + ... } + ... ] + >>> client.add_documents(collection_name="my_docs", documents=records) + >>> + >>> records_with_id: list[BaseRecord] = [ + ... { + ... "doc_id": "nlp_001", + ... "content": "Advanced NLP techniques", + ... "metadata": {"source": "file5", "topic": "NLP"} + ... } + ... ] + >>> client.add_documents(collection_name="my_docs", documents=records_with_id) + """ + ... + + @abstractmethod + async def aadd_documents(self, **kwargs: Unpack[BaseCollectionAddParams]) -> None: + """Add documents with their embeddings to a collection asynchronously. + + Implementations should handle embedding generation internally based on + the configured embedding function. + + Keyword Args: + collection_name: The name of the collection to add documents to. + documents: List of BaseRecord dicts containing: + - content: The text content (required) + - doc_id: Optional unique identifier (auto-generated from content hash if missing) + - metadata: Optional metadata dictionary + Embeddings will be generated automatically. + + Raises: + ValueError: If collection doesn't exist or documents list is empty. + TypeError: If documents are not BaseRecord dict instances. + ConnectionError: If unable to connect to the vector database backend. + + Example: + >>> import asyncio + >>> from crewai.rag.chromadb.client import ChromaDBClient + >>> from crewai.rag.types import BaseRecord + >>> + >>> async def add_documents(): + ... client = ChromaDBClient() + ... + ... records: list[BaseRecord] = [ + ... { + ... "doc_id": "doc2", + ... "content": "Async operations in Python", + ... "metadata": {"source": "file2", "topic": "async"} + ... } + ... ] + ... await client.aadd_documents(collection_name="my_docs", documents=records) + ... + >>> asyncio.run(add_documents()) + """ + ... + + @abstractmethod + def search( + self, **kwargs: Unpack[BaseCollectionSearchParams] + ) -> list[SearchResult]: + """Search for similar documents using a query. + + Performs a vector similarity search to find the most similar documents + to the provided query. + + Keyword Args: + collection_name: The name of the collection to search in. + query: The text query to search for. The implementation handles + embedding generation internally. + limit: Maximum number of results to return. Defaults to 10. + metadata_filter: Optional metadata filter to apply to the search. The exact + format depends on the backend, but typically supports equality + and range queries on metadata fields. + score_threshold: Optional minimum similarity score threshold. Only + results with scores >= this threshold will be returned. The + score interpretation depends on the distance metric used. + + Returns: + A list of SearchResult dictionaries ordered by similarity score in + descending order. Each result contains: + - id: Document ID + - content: Document text content + - metadata: Document metadata + - score: Similarity score (0-1, higher is better) + + Raises: + ValueError: If collection doesn't exist. + ConnectionError: If unable to connect to the vector database backend. + + Example: + >>> from crewai.rag.chromadb.client import ChromaDBClient + >>> client = ChromaDBClient() + >>> + >>> results = client.search( + ... collection_name="my_docs", + ... query="What is machine learning?", + ... limit=5, + ... metadata_filter={"source": "file1"}, + ... score_threshold=0.7 + ... ) + >>> for result in results: + ... print(f"{result['id']}: {result['score']:.2f}") + """ + ... + + @abstractmethod + async def asearch( + self, **kwargs: Unpack[BaseCollectionSearchParams] + ) -> list[SearchResult]: + """Search for similar documents using a query asynchronously. + + Keyword Args: + collection_name: The name of the collection to search in. + query: The text query to search for. The implementation handles + embedding generation internally. + limit: Maximum number of results to return. Defaults to 10. + metadata_filter: Optional metadata filter to apply to the search. + score_threshold: Optional minimum similarity score threshold. + + Returns: + A list of SearchResult dictionaries ordered by similarity score. + + Raises: + ValueError: If collection doesn't exist. + ConnectionError: If unable to connect to the vector database backend. + + Example: + >>> import asyncio + >>> from crewai.rag.chromadb.client import ChromaDBClient + >>> + >>> async def search_documents(): + ... client = ChromaDBClient() + ... results = await client.asearch( + ... collection_name="my_docs", + ... query="Python programming best practices", + ... limit=5, + ... metadata_filter={"source": "file1"}, + ... score_threshold=0.7 + ... ) + ... for result in results: + ... print(f"{result['id']}: {result['score']:.2f}") + ... + >>> asyncio.run(search_documents()) + """ + ... + + @abstractmethod + def delete_collection(self, **kwargs: Unpack[BaseCollectionParams]) -> None: + """Delete a collection and all its data. + + This operation is irreversible and will permanently remove all documents, + embeddings, and metadata associated with the collection. + + Keyword Args: + collection_name: The name of the collection to delete. + + Raises: + ValueError: If the collection doesn't exist. + ConnectionError: If unable to connect to the vector database backend. + + Example: + >>> from crewai.rag.chromadb.client import ChromaDBClient + >>> client = ChromaDBClient() + >>> client.delete_collection(collection_name="old_docs") + >>> print("Collection 'old_docs' deleted successfully") + """ + ... + + @abstractmethod + async def adelete_collection(self, **kwargs: Unpack[BaseCollectionParams]) -> None: + """Delete a collection and all its data asynchronously. + + Keyword Args: + collection_name: The name of the collection to delete. + + Raises: + ValueError: If the collection doesn't exist. + ConnectionError: If unable to connect to the vector database backend. + + Example: + >>> import asyncio + >>> from crewai.rag.chromadb.client import ChromaDBClient + >>> + >>> async def delete_old_collection(): + ... client = ChromaDBClient() + ... await client.adelete_collection(collection_name="old_docs") + ... print("Collection 'old_docs' deleted successfully") + ... + >>> asyncio.run(delete_old_collection()) + """ + ... + + @abstractmethod + def reset(self) -> None: + """Reset the vector database by deleting all collections and data. + + This method provides a way to completely clear the vector database, + removing all collections and their contents. Use with caution as + this operation is irreversible. + + Raises: + ConnectionError: If unable to connect to the vector database backend. + PermissionError: If the operation is not allowed by the backend. + + Example: + >>> from crewai.rag.chromadb.client import ChromaDBClient + >>> client = ChromaDBClient() + >>> client.reset() + >>> print("Vector database completely reset - all data deleted") + """ + ... + + @abstractmethod + async def areset(self) -> None: + """Reset the vector database by deleting all collections and data asynchronously. + + Raises: + ConnectionError: If unable to connect to the vector database backend. + PermissionError: If the operation is not allowed by the backend. + + Example: + >>> import asyncio + >>> from crewai.rag.chromadb.client import ChromaDBClient + >>> + >>> async def reset_database(): + ... client = ChromaDBClient() + ... await client.areset() + ... print("Vector database completely reset - all data deleted") + ... + >>> asyncio.run(reset_database()) + """ + ... diff --git a/src/crewai/rag/core/base_provider.py b/src/crewai/rag/core/base_provider.py new file mode 100644 index 000000000..0651ce540 --- /dev/null +++ b/src/crewai/rag/core/base_provider.py @@ -0,0 +1,30 @@ +"""Base provider protocol for vector database client creation.""" + +from abc import ABC +from typing import Any, Protocol, runtime_checkable, Union +from pydantic import BaseModel, Field + +from crewai.rag.types import EmbeddingFunction +from crewai.rag.embeddings.types import EmbeddingOptions + + +class BaseProviderOptions(BaseModel, ABC): + """Base configuration for all provider options.""" + + client_type: str = Field(..., description="Type of client to create") + embedding_config: Union[EmbeddingOptions, EmbeddingFunction, None] = Field( + default=None, + description="Embedding configuration - either options for built-in providers or a custom function", + ) + options: Any = Field( + default=None, description="Additional provider-specific options" + ) + + +@runtime_checkable +class BaseProvider(Protocol): + """Protocol for vector database client providers.""" + + def __call__(self, options: BaseProviderOptions) -> Any: + """Create and return a configured client instance.""" + ... diff --git a/src/crewai/rag/embeddings/factory.py b/src/crewai/rag/embeddings/factory.py new file mode 100644 index 000000000..ff3a78c17 --- /dev/null +++ b/src/crewai/rag/embeddings/factory.py @@ -0,0 +1,148 @@ +"""Minimal embedding function factory for CrewAI.""" + +import os + +from chromadb import EmbeddingFunction +from chromadb.utils.embedding_functions.amazon_bedrock_embedding_function import ( + AmazonBedrockEmbeddingFunction, +) +from chromadb.utils.embedding_functions.cohere_embedding_function import ( + CohereEmbeddingFunction, +) +from chromadb.utils.embedding_functions.google_embedding_function import ( + GooglePalmEmbeddingFunction, + GoogleGenerativeAiEmbeddingFunction, + GoogleVertexEmbeddingFunction, +) +from chromadb.utils.embedding_functions.huggingface_embedding_function import ( + HuggingFaceEmbeddingFunction, +) +from chromadb.utils.embedding_functions.instructor_embedding_function import ( + InstructorEmbeddingFunction, +) +from chromadb.utils.embedding_functions.jina_embedding_function import ( + JinaEmbeddingFunction, +) +from chromadb.utils.embedding_functions.ollama_embedding_function import ( + OllamaEmbeddingFunction, +) +from chromadb.utils.embedding_functions.onnx_mini_lm_l6_v2 import ONNXMiniLM_L6_V2 +from chromadb.utils.embedding_functions.open_clip_embedding_function import ( + OpenCLIPEmbeddingFunction, +) +from chromadb.utils.embedding_functions.openai_embedding_function import ( + OpenAIEmbeddingFunction, +) +from chromadb.utils.embedding_functions.roboflow_embedding_function import ( + RoboflowEmbeddingFunction, +) +from chromadb.utils.embedding_functions.sentence_transformer_embedding_function import ( + SentenceTransformerEmbeddingFunction, +) +from chromadb.utils.embedding_functions.text2vec_embedding_function import ( + Text2VecEmbeddingFunction, +) + +from crewai.rag.embeddings.types import EmbeddingOptions + + +def get_embedding_function( + config: EmbeddingOptions | dict | None = None, +) -> EmbeddingFunction: + """Get embedding function - delegates to ChromaDB. + + Args: + config: Optional configuration - either an EmbeddingOptions object or a dict with: + - provider: The embedding provider to use (default: "openai") + - Any other provider-specific parameters + + Returns: + EmbeddingFunction instance ready for use with ChromaDB + + Supported providers: + - openai: OpenAI embeddings (default) + - cohere: Cohere embeddings + - ollama: Ollama local embeddings + - huggingface: HuggingFace embeddings + - sentence-transformer: Local sentence transformers + - instructor: Instructor embeddings for specialized tasks + - google-palm: Google PaLM embeddings + - google-generativeai: Google Generative AI embeddings + - google-vertex: Google Vertex AI embeddings + - amazon-bedrock: AWS Bedrock embeddings + - jina: Jina AI embeddings + - roboflow: Roboflow embeddings for vision tasks + - openclip: OpenCLIP embeddings for multimodal tasks + - text2vec: Text2Vec embeddings + - onnx: ONNX MiniLM-L6-v2 (no API key needed, included with ChromaDB) + + Examples: + # Use default OpenAI with retry logic + >>> embedder = get_embedding_function() + + # Use Cohere with dict + >>> embedder = get_embedding_function({ + ... "provider": "cohere", + ... "api_key": "your-key", + ... "model_name": "embed-english-v3.0" + ... }) + + # Use with EmbeddingOptions + >>> embedder = get_embedding_function( + ... EmbeddingOptions(provider="sentence-transformer", model_name="all-MiniLM-L6-v2") + ... ) + + # Use local sentence transformers (no API key needed) + >>> embedder = get_embedding_function({ + ... "provider": "sentence-transformer", + ... "model_name": "all-MiniLM-L6-v2" + ... }) + + # Use Ollama for local embeddings + >>> embedder = get_embedding_function({ + ... "provider": "ollama", + ... "model_name": "nomic-embed-text" + ... }) + + # Use ONNX (no API key needed) + >>> embedder = get_embedding_function({ + ... "provider": "onnx" + ... }) + """ + if config is None: + return OpenAIEmbeddingFunction( + api_key=os.getenv("OPENAI_API_KEY"), model_name="text-embedding-3-small" + ) + + # Handle EmbeddingOptions object + if isinstance(config, EmbeddingOptions): + config_dict = config.model_dump(exclude_none=True) + else: + config_dict = config.copy() + + provider = config_dict.pop("provider", "openai") + + embedding_functions = { + "openai": OpenAIEmbeddingFunction, + "cohere": CohereEmbeddingFunction, + "ollama": OllamaEmbeddingFunction, + "huggingface": HuggingFaceEmbeddingFunction, + "sentence-transformer": SentenceTransformerEmbeddingFunction, + "instructor": InstructorEmbeddingFunction, + "google-palm": GooglePalmEmbeddingFunction, + "google-generativeai": GoogleGenerativeAiEmbeddingFunction, + "google-vertex": GoogleVertexEmbeddingFunction, + "amazon-bedrock": AmazonBedrockEmbeddingFunction, + "jina": JinaEmbeddingFunction, + "roboflow": RoboflowEmbeddingFunction, + "openclip": OpenCLIPEmbeddingFunction, + "text2vec": Text2VecEmbeddingFunction, + "onnx": ONNXMiniLM_L6_V2, + } + + if provider not in embedding_functions: + raise ValueError( + f"Unsupported provider: {provider}. " + f"Available providers: {list(embedding_functions.keys())}" + ) + return embedding_functions[provider](**config_dict) diff --git a/src/crewai/rag/embeddings/types.py b/src/crewai/rag/embeddings/types.py new file mode 100644 index 000000000..a799bc45a --- /dev/null +++ b/src/crewai/rag/embeddings/types.py @@ -0,0 +1,62 @@ +"""Type definitions for the embeddings module.""" + +from typing import Literal +from pydantic import BaseModel, Field, SecretStr + +from crewai.rag.types import EmbeddingFunction + + +EmbeddingProvider = Literal[ + "openai", + "cohere", + "ollama", + "huggingface", + "sentence-transformer", + "instructor", + "google-palm", + "google-generativeai", + "google-vertex", + "amazon-bedrock", + "jina", + "roboflow", + "openclip", + "text2vec", + "onnx", +] +"""Supported embedding providers. + +These correspond to the embedding functions available in ChromaDB's +embedding_functions module. Each provider has specific requirements +and configuration options. +""" + + +class EmbeddingOptions(BaseModel): + """Configuration options for embedding providers. + + Generic attributes that can be passed to get_embedding_function + to configure various embedding providers. + """ + + provider: EmbeddingProvider = Field( + ..., description="Embedding provider name (e.g., 'openai', 'cohere', 'onnx')" + ) + model_name: str | None = Field( + default=None, description="Model name for the embedding provider" + ) + api_key: SecretStr | None = Field( + default=None, description="API key for the embedding provider" + ) + + +class EmbeddingConfig(BaseModel): + """Configuration wrapper for embedding functions. + + Accepts either a pre-configured EmbeddingFunction or EmbeddingOptions + to create one. This provides flexibility in how embeddings are configured. + + Attributes: + function: Either a callable EmbeddingFunction or EmbeddingOptions to create one + """ + + function: EmbeddingFunction | EmbeddingOptions diff --git a/src/crewai/rag/types.py b/src/crewai/rag/types.py new file mode 100644 index 000000000..0f44422a8 --- /dev/null +++ b/src/crewai/rag/types.py @@ -0,0 +1,50 @@ +"""Type definitions for RAG (Retrieval-Augmented Generation) systems.""" + +from collections.abc import Callable, Mapping +from typing import TypeAlias, TypedDict, Any + +from typing_extensions import Required + + +class BaseRecord(TypedDict, total=False): + """A typed dictionary representing a document record. + + Attributes: + doc_id: Optional unique identifier for the document. If not provided, + a content-based ID will be generated using SHA256 hash. + content: The text content of the document (required) + metadata: Optional metadata associated with the document + """ + + doc_id: str + content: Required[str] + metadata: ( + Mapping[str, str | int | float | bool] + | list[Mapping[str, str | int | float | bool]] + ) + + +DenseVector: TypeAlias = list[float] +IntVector: TypeAlias = list[int] + +EmbeddingFunction: TypeAlias = Callable[..., Any] + + +class SearchResult(TypedDict): + """Standard search result format for vector store queries. + + This provides a consistent interface for search results across different + vector store implementations. Each implementation should convert their + native result format to this standard format. + + Attributes: + id: Unique identifier of the document + content: The text content of the document + metadata: Optional metadata associated with the document + score: Similarity score (higher is better, typically between 0 and 1) + """ + + id: str + content: str + metadata: dict[str, Any] + score: float