mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-05-01 07:13:00 +00:00
feat: centralize embedding types and create base client (#3246)
feat: add RAG system foundation with generic vector store support - Add BaseClient protocol for vector stores - Move BaseRAGStorage to rag/core - Centralize embedding types in embeddings/types.py - Remove unused storage models
This commit is contained in:
1
src/crewai/rag/core/__init__.py
Normal file
1
src/crewai/rag/core/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Core abstract base classes and protocols for RAG systems."""
|
||||
433
src/crewai/rag/core/base_client.py
Normal file
433
src/crewai/rag/core/base_client.py
Normal file
@@ -0,0 +1,433 @@
|
||||
"""Protocol for vector database client implementations."""
|
||||
|
||||
from abc import abstractmethod
|
||||
from typing import Any, Protocol, runtime_checkable, TypedDict, Annotated
|
||||
from typing_extensions import Unpack, Required
|
||||
|
||||
|
||||
from crewai.rag.types import (
|
||||
EmbeddingFunction,
|
||||
BaseRecord,
|
||||
SearchResult,
|
||||
)
|
||||
|
||||
|
||||
class BaseCollectionParams(TypedDict):
|
||||
"""Base parameters for collection operations.
|
||||
|
||||
Attributes:
|
||||
collection_name: The name of the collection/index to operate on.
|
||||
"""
|
||||
|
||||
collection_name: Required[
|
||||
Annotated[
|
||||
str,
|
||||
"Name of the collection/index. Implementations may have specific constraints (e.g., character limits, allowed characters, case sensitivity).",
|
||||
]
|
||||
]
|
||||
|
||||
|
||||
class BaseCollectionAddParams(BaseCollectionParams):
|
||||
"""Parameters for adding documents to a collection.
|
||||
|
||||
Extends BaseCollectionParams with document-specific fields.
|
||||
|
||||
Attributes:
|
||||
collection_name: The name of the collection to add documents to.
|
||||
documents: List of BaseRecord dictionaries containing document data.
|
||||
"""
|
||||
|
||||
documents: list[BaseRecord]
|
||||
|
||||
|
||||
class BaseCollectionSearchParams(BaseCollectionParams, total=False):
|
||||
"""Parameters for searching within a collection.
|
||||
|
||||
Extends BaseCollectionParams with search-specific optional fields.
|
||||
All fields except collection_name and query are optional.
|
||||
|
||||
Attributes:
|
||||
query: The text query to search for (required).
|
||||
limit: Maximum number of results to return.
|
||||
metadata_filter: Filter results by metadata fields.
|
||||
score_threshold: Minimum similarity score for results (0-1).
|
||||
"""
|
||||
|
||||
query: Required[str]
|
||||
limit: int
|
||||
metadata_filter: dict[str, Any]
|
||||
score_threshold: float
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class BaseClient(Protocol):
|
||||
"""Protocol for vector store client implementations.
|
||||
|
||||
This protocol defines the interface that all vector store client implementations
|
||||
must follow. It provides a consistent API for storing and retrieving
|
||||
documents with their vector embeddings across different vector database
|
||||
backends (e.g., Qdrant, ChromaDB, Weaviate). Implementing classes should
|
||||
handle connection management, data persistence, and vector similarity
|
||||
search operations specific to their backend.
|
||||
|
||||
Implementation Guidelines:
|
||||
Implementations should accept BaseClientParams in their constructor to allow
|
||||
passing pre-configured client instances:
|
||||
|
||||
class MyVectorClient:
|
||||
def __init__(self, client: Any | None = None, **kwargs):
|
||||
if client:
|
||||
self.client = client
|
||||
else:
|
||||
self.client = self._create_default_client(**kwargs)
|
||||
|
||||
Notes:
|
||||
This protocol replaces the former BaseRAGStorage abstraction,
|
||||
providing a cleaner interface for vector store operations.
|
||||
|
||||
Attributes:
|
||||
embedding_function: Callable that takes a list of text strings
|
||||
and returns a list of embedding vectors. Implementations
|
||||
should always provide a default embedding function.
|
||||
client: The underlying vector database client instance. This could be
|
||||
passed via BaseClientParams during initialization or created internally.
|
||||
"""
|
||||
|
||||
client: Any
|
||||
embedding_function: EmbeddingFunction
|
||||
|
||||
@abstractmethod
|
||||
def create_collection(self, **kwargs: Unpack[BaseCollectionParams]) -> None:
|
||||
"""Create a new collection/index in the vector database.
|
||||
|
||||
Keyword Args:
|
||||
collection_name: The name of the collection to create. Must be unique within
|
||||
the vector database instance.
|
||||
|
||||
Raises:
|
||||
ValueError: If collection name already exists.
|
||||
ConnectionError: If unable to connect to the vector database backend.
|
||||
"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def acreate_collection(self, **kwargs: Unpack[BaseCollectionParams]) -> None:
|
||||
"""Create a new collection/index in the vector database asynchronously.
|
||||
|
||||
Keyword Args:
|
||||
collection_name: The name of the collection to create. Must be unique within
|
||||
the vector database instance.
|
||||
|
||||
Raises:
|
||||
ValueError: If collection name already exists.
|
||||
ConnectionError: If unable to connect to the vector database backend.
|
||||
"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def get_or_create_collection(self, **kwargs: Unpack[BaseCollectionParams]) -> Any:
|
||||
"""Get an existing collection or create it if it doesn't exist.
|
||||
|
||||
This method provides a convenient way to ensure a collection exists
|
||||
without having to check for its existence first.
|
||||
|
||||
Keyword Args:
|
||||
collection_name: The name of the collection to get or create.
|
||||
|
||||
Returns:
|
||||
A collection object whose type depends on the backend implementation.
|
||||
This could be a collection reference, ID, or client object.
|
||||
|
||||
Raises:
|
||||
ValueError: If unable to create the collection.
|
||||
ConnectionError: If unable to connect to the vector database backend.
|
||||
"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def aget_or_create_collection(
|
||||
self, **kwargs: Unpack[BaseCollectionParams]
|
||||
) -> Any:
|
||||
"""Get an existing collection or create it if it doesn't exist asynchronously.
|
||||
|
||||
Keyword Args:
|
||||
collection_name: The name of the collection to get or create.
|
||||
|
||||
Returns:
|
||||
A collection object whose type depends on the backend implementation.
|
||||
|
||||
Raises:
|
||||
ValueError: If unable to create the collection.
|
||||
ConnectionError: If unable to connect to the vector database backend.
|
||||
"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def add_documents(self, **kwargs: Unpack[BaseCollectionAddParams]) -> None:
|
||||
"""Add documents with their embeddings to a collection.
|
||||
|
||||
This method performs an upsert operation - if a document with the same ID
|
||||
already exists, it will be updated with the new content and metadata.
|
||||
|
||||
Implementations should handle embedding generation internally based on
|
||||
the configured embedding function.
|
||||
|
||||
Keyword Args:
|
||||
collection_name: The name of the collection to add documents to.
|
||||
documents: List of BaseRecord dicts containing:
|
||||
- content: The text content (required)
|
||||
- doc_id: Optional unique identifier (auto-generated from content hash if missing)
|
||||
- metadata: Optional metadata dictionary
|
||||
Embeddings will be generated automatically.
|
||||
|
||||
Raises:
|
||||
ValueError: If collection doesn't exist or documents list is empty.
|
||||
TypeError: If documents are not BaseRecord dict instances.
|
||||
ConnectionError: If unable to connect to the vector database backend.
|
||||
|
||||
Example:
|
||||
>>> from crewai.rag.chromadb.client import ChromaDBClient
|
||||
>>> from crewai.rag.types import BaseRecord
|
||||
>>> client = ChromaDBClient()
|
||||
>>>
|
||||
>>> records: list[BaseRecord] = [
|
||||
... {
|
||||
... "content": "Machine learning basics",
|
||||
... "metadata": {"source": "file3", "topic": "ML"}
|
||||
... },
|
||||
... {
|
||||
... "doc_id": "custom_id",
|
||||
... "content": "Deep learning fundamentals",
|
||||
... "metadata": {"source": "file4", "topic": "DL"}
|
||||
... }
|
||||
... ]
|
||||
>>> client.add_documents(collection_name="my_docs", documents=records)
|
||||
>>>
|
||||
>>> records_with_id: list[BaseRecord] = [
|
||||
... {
|
||||
... "doc_id": "nlp_001",
|
||||
... "content": "Advanced NLP techniques",
|
||||
... "metadata": {"source": "file5", "topic": "NLP"}
|
||||
... }
|
||||
... ]
|
||||
>>> client.add_documents(collection_name="my_docs", documents=records_with_id)
|
||||
"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def aadd_documents(self, **kwargs: Unpack[BaseCollectionAddParams]) -> None:
|
||||
"""Add documents with their embeddings to a collection asynchronously.
|
||||
|
||||
Implementations should handle embedding generation internally based on
|
||||
the configured embedding function.
|
||||
|
||||
Keyword Args:
|
||||
collection_name: The name of the collection to add documents to.
|
||||
documents: List of BaseRecord dicts containing:
|
||||
- content: The text content (required)
|
||||
- doc_id: Optional unique identifier (auto-generated from content hash if missing)
|
||||
- metadata: Optional metadata dictionary
|
||||
Embeddings will be generated automatically.
|
||||
|
||||
Raises:
|
||||
ValueError: If collection doesn't exist or documents list is empty.
|
||||
TypeError: If documents are not BaseRecord dict instances.
|
||||
ConnectionError: If unable to connect to the vector database backend.
|
||||
|
||||
Example:
|
||||
>>> import asyncio
|
||||
>>> from crewai.rag.chromadb.client import ChromaDBClient
|
||||
>>> from crewai.rag.types import BaseRecord
|
||||
>>>
|
||||
>>> async def add_documents():
|
||||
... client = ChromaDBClient()
|
||||
...
|
||||
... records: list[BaseRecord] = [
|
||||
... {
|
||||
... "doc_id": "doc2",
|
||||
... "content": "Async operations in Python",
|
||||
... "metadata": {"source": "file2", "topic": "async"}
|
||||
... }
|
||||
... ]
|
||||
... await client.aadd_documents(collection_name="my_docs", documents=records)
|
||||
...
|
||||
>>> asyncio.run(add_documents())
|
||||
"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def search(
|
||||
self, **kwargs: Unpack[BaseCollectionSearchParams]
|
||||
) -> list[SearchResult]:
|
||||
"""Search for similar documents using a query.
|
||||
|
||||
Performs a vector similarity search to find the most similar documents
|
||||
to the provided query.
|
||||
|
||||
Keyword Args:
|
||||
collection_name: The name of the collection to search in.
|
||||
query: The text query to search for. The implementation handles
|
||||
embedding generation internally.
|
||||
limit: Maximum number of results to return. Defaults to 10.
|
||||
metadata_filter: Optional metadata filter to apply to the search. The exact
|
||||
format depends on the backend, but typically supports equality
|
||||
and range queries on metadata fields.
|
||||
score_threshold: Optional minimum similarity score threshold. Only
|
||||
results with scores >= this threshold will be returned. The
|
||||
score interpretation depends on the distance metric used.
|
||||
|
||||
Returns:
|
||||
A list of SearchResult dictionaries ordered by similarity score in
|
||||
descending order. Each result contains:
|
||||
- id: Document ID
|
||||
- content: Document text content
|
||||
- metadata: Document metadata
|
||||
- score: Similarity score (0-1, higher is better)
|
||||
|
||||
Raises:
|
||||
ValueError: If collection doesn't exist.
|
||||
ConnectionError: If unable to connect to the vector database backend.
|
||||
|
||||
Example:
|
||||
>>> from crewai.rag.chromadb.client import ChromaDBClient
|
||||
>>> client = ChromaDBClient()
|
||||
>>>
|
||||
>>> results = client.search(
|
||||
... collection_name="my_docs",
|
||||
... query="What is machine learning?",
|
||||
... limit=5,
|
||||
... metadata_filter={"source": "file1"},
|
||||
... score_threshold=0.7
|
||||
... )
|
||||
>>> for result in results:
|
||||
... print(f"{result['id']}: {result['score']:.2f}")
|
||||
"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def asearch(
|
||||
self, **kwargs: Unpack[BaseCollectionSearchParams]
|
||||
) -> list[SearchResult]:
|
||||
"""Search for similar documents using a query asynchronously.
|
||||
|
||||
Keyword Args:
|
||||
collection_name: The name of the collection to search in.
|
||||
query: The text query to search for. The implementation handles
|
||||
embedding generation internally.
|
||||
limit: Maximum number of results to return. Defaults to 10.
|
||||
metadata_filter: Optional metadata filter to apply to the search.
|
||||
score_threshold: Optional minimum similarity score threshold.
|
||||
|
||||
Returns:
|
||||
A list of SearchResult dictionaries ordered by similarity score.
|
||||
|
||||
Raises:
|
||||
ValueError: If collection doesn't exist.
|
||||
ConnectionError: If unable to connect to the vector database backend.
|
||||
|
||||
Example:
|
||||
>>> import asyncio
|
||||
>>> from crewai.rag.chromadb.client import ChromaDBClient
|
||||
>>>
|
||||
>>> async def search_documents():
|
||||
... client = ChromaDBClient()
|
||||
... results = await client.asearch(
|
||||
... collection_name="my_docs",
|
||||
... query="Python programming best practices",
|
||||
... limit=5,
|
||||
... metadata_filter={"source": "file1"},
|
||||
... score_threshold=0.7
|
||||
... )
|
||||
... for result in results:
|
||||
... print(f"{result['id']}: {result['score']:.2f}")
|
||||
...
|
||||
>>> asyncio.run(search_documents())
|
||||
"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def delete_collection(self, **kwargs: Unpack[BaseCollectionParams]) -> None:
|
||||
"""Delete a collection and all its data.
|
||||
|
||||
This operation is irreversible and will permanently remove all documents,
|
||||
embeddings, and metadata associated with the collection.
|
||||
|
||||
Keyword Args:
|
||||
collection_name: The name of the collection to delete.
|
||||
|
||||
Raises:
|
||||
ValueError: If the collection doesn't exist.
|
||||
ConnectionError: If unable to connect to the vector database backend.
|
||||
|
||||
Example:
|
||||
>>> from crewai.rag.chromadb.client import ChromaDBClient
|
||||
>>> client = ChromaDBClient()
|
||||
>>> client.delete_collection(collection_name="old_docs")
|
||||
>>> print("Collection 'old_docs' deleted successfully")
|
||||
"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def adelete_collection(self, **kwargs: Unpack[BaseCollectionParams]) -> None:
|
||||
"""Delete a collection and all its data asynchronously.
|
||||
|
||||
Keyword Args:
|
||||
collection_name: The name of the collection to delete.
|
||||
|
||||
Raises:
|
||||
ValueError: If the collection doesn't exist.
|
||||
ConnectionError: If unable to connect to the vector database backend.
|
||||
|
||||
Example:
|
||||
>>> import asyncio
|
||||
>>> from crewai.rag.chromadb.client import ChromaDBClient
|
||||
>>>
|
||||
>>> async def delete_old_collection():
|
||||
... client = ChromaDBClient()
|
||||
... await client.adelete_collection(collection_name="old_docs")
|
||||
... print("Collection 'old_docs' deleted successfully")
|
||||
...
|
||||
>>> asyncio.run(delete_old_collection())
|
||||
"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def reset(self) -> None:
|
||||
"""Reset the vector database by deleting all collections and data.
|
||||
|
||||
This method provides a way to completely clear the vector database,
|
||||
removing all collections and their contents. Use with caution as
|
||||
this operation is irreversible.
|
||||
|
||||
Raises:
|
||||
ConnectionError: If unable to connect to the vector database backend.
|
||||
PermissionError: If the operation is not allowed by the backend.
|
||||
|
||||
Example:
|
||||
>>> from crewai.rag.chromadb.client import ChromaDBClient
|
||||
>>> client = ChromaDBClient()
|
||||
>>> client.reset()
|
||||
>>> print("Vector database completely reset - all data deleted")
|
||||
"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def areset(self) -> None:
|
||||
"""Reset the vector database by deleting all collections and data asynchronously.
|
||||
|
||||
Raises:
|
||||
ConnectionError: If unable to connect to the vector database backend.
|
||||
PermissionError: If the operation is not allowed by the backend.
|
||||
|
||||
Example:
|
||||
>>> import asyncio
|
||||
>>> from crewai.rag.chromadb.client import ChromaDBClient
|
||||
>>>
|
||||
>>> async def reset_database():
|
||||
... client = ChromaDBClient()
|
||||
... await client.areset()
|
||||
... print("Vector database completely reset - all data deleted")
|
||||
...
|
||||
>>> asyncio.run(reset_database())
|
||||
"""
|
||||
...
|
||||
30
src/crewai/rag/core/base_provider.py
Normal file
30
src/crewai/rag/core/base_provider.py
Normal file
@@ -0,0 +1,30 @@
|
||||
"""Base provider protocol for vector database client creation."""
|
||||
|
||||
from abc import ABC
|
||||
from typing import Any, Protocol, runtime_checkable, Union
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from crewai.rag.types import EmbeddingFunction
|
||||
from crewai.rag.embeddings.types import EmbeddingOptions
|
||||
|
||||
|
||||
class BaseProviderOptions(BaseModel, ABC):
|
||||
"""Base configuration for all provider options."""
|
||||
|
||||
client_type: str = Field(..., description="Type of client to create")
|
||||
embedding_config: Union[EmbeddingOptions, EmbeddingFunction, None] = Field(
|
||||
default=None,
|
||||
description="Embedding configuration - either options for built-in providers or a custom function",
|
||||
)
|
||||
options: Any = Field(
|
||||
default=None, description="Additional provider-specific options"
|
||||
)
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class BaseProvider(Protocol):
|
||||
"""Protocol for vector database client providers."""
|
||||
|
||||
def __call__(self, options: BaseProviderOptions) -> Any:
|
||||
"""Create and return a configured client instance."""
|
||||
...
|
||||
148
src/crewai/rag/embeddings/factory.py
Normal file
148
src/crewai/rag/embeddings/factory.py
Normal file
@@ -0,0 +1,148 @@
|
||||
"""Minimal embedding function factory for CrewAI."""
|
||||
|
||||
import os
|
||||
|
||||
from chromadb import EmbeddingFunction
|
||||
from chromadb.utils.embedding_functions.amazon_bedrock_embedding_function import (
|
||||
AmazonBedrockEmbeddingFunction,
|
||||
)
|
||||
from chromadb.utils.embedding_functions.cohere_embedding_function import (
|
||||
CohereEmbeddingFunction,
|
||||
)
|
||||
from chromadb.utils.embedding_functions.google_embedding_function import (
|
||||
GooglePalmEmbeddingFunction,
|
||||
GoogleGenerativeAiEmbeddingFunction,
|
||||
GoogleVertexEmbeddingFunction,
|
||||
)
|
||||
from chromadb.utils.embedding_functions.huggingface_embedding_function import (
|
||||
HuggingFaceEmbeddingFunction,
|
||||
)
|
||||
from chromadb.utils.embedding_functions.instructor_embedding_function import (
|
||||
InstructorEmbeddingFunction,
|
||||
)
|
||||
from chromadb.utils.embedding_functions.jina_embedding_function import (
|
||||
JinaEmbeddingFunction,
|
||||
)
|
||||
from chromadb.utils.embedding_functions.ollama_embedding_function import (
|
||||
OllamaEmbeddingFunction,
|
||||
)
|
||||
from chromadb.utils.embedding_functions.onnx_mini_lm_l6_v2 import ONNXMiniLM_L6_V2
|
||||
from chromadb.utils.embedding_functions.open_clip_embedding_function import (
|
||||
OpenCLIPEmbeddingFunction,
|
||||
)
|
||||
from chromadb.utils.embedding_functions.openai_embedding_function import (
|
||||
OpenAIEmbeddingFunction,
|
||||
)
|
||||
from chromadb.utils.embedding_functions.roboflow_embedding_function import (
|
||||
RoboflowEmbeddingFunction,
|
||||
)
|
||||
from chromadb.utils.embedding_functions.sentence_transformer_embedding_function import (
|
||||
SentenceTransformerEmbeddingFunction,
|
||||
)
|
||||
from chromadb.utils.embedding_functions.text2vec_embedding_function import (
|
||||
Text2VecEmbeddingFunction,
|
||||
)
|
||||
|
||||
from crewai.rag.embeddings.types import EmbeddingOptions
|
||||
|
||||
|
||||
def get_embedding_function(
|
||||
config: EmbeddingOptions | dict | None = None,
|
||||
) -> EmbeddingFunction:
|
||||
"""Get embedding function - delegates to ChromaDB.
|
||||
|
||||
Args:
|
||||
config: Optional configuration - either an EmbeddingOptions object or a dict with:
|
||||
- provider: The embedding provider to use (default: "openai")
|
||||
- Any other provider-specific parameters
|
||||
|
||||
Returns:
|
||||
EmbeddingFunction instance ready for use with ChromaDB
|
||||
|
||||
Supported providers:
|
||||
- openai: OpenAI embeddings (default)
|
||||
- cohere: Cohere embeddings
|
||||
- ollama: Ollama local embeddings
|
||||
- huggingface: HuggingFace embeddings
|
||||
- sentence-transformer: Local sentence transformers
|
||||
- instructor: Instructor embeddings for specialized tasks
|
||||
- google-palm: Google PaLM embeddings
|
||||
- google-generativeai: Google Generative AI embeddings
|
||||
- google-vertex: Google Vertex AI embeddings
|
||||
- amazon-bedrock: AWS Bedrock embeddings
|
||||
- jina: Jina AI embeddings
|
||||
- roboflow: Roboflow embeddings for vision tasks
|
||||
- openclip: OpenCLIP embeddings for multimodal tasks
|
||||
- text2vec: Text2Vec embeddings
|
||||
- onnx: ONNX MiniLM-L6-v2 (no API key needed, included with ChromaDB)
|
||||
|
||||
Examples:
|
||||
# Use default OpenAI with retry logic
|
||||
>>> embedder = get_embedding_function()
|
||||
|
||||
# Use Cohere with dict
|
||||
>>> embedder = get_embedding_function({
|
||||
... "provider": "cohere",
|
||||
... "api_key": "your-key",
|
||||
... "model_name": "embed-english-v3.0"
|
||||
... })
|
||||
|
||||
# Use with EmbeddingOptions
|
||||
>>> embedder = get_embedding_function(
|
||||
... EmbeddingOptions(provider="sentence-transformer", model_name="all-MiniLM-L6-v2")
|
||||
... )
|
||||
|
||||
# Use local sentence transformers (no API key needed)
|
||||
>>> embedder = get_embedding_function({
|
||||
... "provider": "sentence-transformer",
|
||||
... "model_name": "all-MiniLM-L6-v2"
|
||||
... })
|
||||
|
||||
# Use Ollama for local embeddings
|
||||
>>> embedder = get_embedding_function({
|
||||
... "provider": "ollama",
|
||||
... "model_name": "nomic-embed-text"
|
||||
... })
|
||||
|
||||
# Use ONNX (no API key needed)
|
||||
>>> embedder = get_embedding_function({
|
||||
... "provider": "onnx"
|
||||
... })
|
||||
"""
|
||||
if config is None:
|
||||
return OpenAIEmbeddingFunction(
|
||||
api_key=os.getenv("OPENAI_API_KEY"), model_name="text-embedding-3-small"
|
||||
)
|
||||
|
||||
# Handle EmbeddingOptions object
|
||||
if isinstance(config, EmbeddingOptions):
|
||||
config_dict = config.model_dump(exclude_none=True)
|
||||
else:
|
||||
config_dict = config.copy()
|
||||
|
||||
provider = config_dict.pop("provider", "openai")
|
||||
|
||||
embedding_functions = {
|
||||
"openai": OpenAIEmbeddingFunction,
|
||||
"cohere": CohereEmbeddingFunction,
|
||||
"ollama": OllamaEmbeddingFunction,
|
||||
"huggingface": HuggingFaceEmbeddingFunction,
|
||||
"sentence-transformer": SentenceTransformerEmbeddingFunction,
|
||||
"instructor": InstructorEmbeddingFunction,
|
||||
"google-palm": GooglePalmEmbeddingFunction,
|
||||
"google-generativeai": GoogleGenerativeAiEmbeddingFunction,
|
||||
"google-vertex": GoogleVertexEmbeddingFunction,
|
||||
"amazon-bedrock": AmazonBedrockEmbeddingFunction,
|
||||
"jina": JinaEmbeddingFunction,
|
||||
"roboflow": RoboflowEmbeddingFunction,
|
||||
"openclip": OpenCLIPEmbeddingFunction,
|
||||
"text2vec": Text2VecEmbeddingFunction,
|
||||
"onnx": ONNXMiniLM_L6_V2,
|
||||
}
|
||||
|
||||
if provider not in embedding_functions:
|
||||
raise ValueError(
|
||||
f"Unsupported provider: {provider}. "
|
||||
f"Available providers: {list(embedding_functions.keys())}"
|
||||
)
|
||||
return embedding_functions[provider](**config_dict)
|
||||
62
src/crewai/rag/embeddings/types.py
Normal file
62
src/crewai/rag/embeddings/types.py
Normal file
@@ -0,0 +1,62 @@
|
||||
"""Type definitions for the embeddings module."""
|
||||
|
||||
from typing import Literal
|
||||
from pydantic import BaseModel, Field, SecretStr
|
||||
|
||||
from crewai.rag.types import EmbeddingFunction
|
||||
|
||||
|
||||
EmbeddingProvider = Literal[
|
||||
"openai",
|
||||
"cohere",
|
||||
"ollama",
|
||||
"huggingface",
|
||||
"sentence-transformer",
|
||||
"instructor",
|
||||
"google-palm",
|
||||
"google-generativeai",
|
||||
"google-vertex",
|
||||
"amazon-bedrock",
|
||||
"jina",
|
||||
"roboflow",
|
||||
"openclip",
|
||||
"text2vec",
|
||||
"onnx",
|
||||
]
|
||||
"""Supported embedding providers.
|
||||
|
||||
These correspond to the embedding functions available in ChromaDB's
|
||||
embedding_functions module. Each provider has specific requirements
|
||||
and configuration options.
|
||||
"""
|
||||
|
||||
|
||||
class EmbeddingOptions(BaseModel):
|
||||
"""Configuration options for embedding providers.
|
||||
|
||||
Generic attributes that can be passed to get_embedding_function
|
||||
to configure various embedding providers.
|
||||
"""
|
||||
|
||||
provider: EmbeddingProvider = Field(
|
||||
..., description="Embedding provider name (e.g., 'openai', 'cohere', 'onnx')"
|
||||
)
|
||||
model_name: str | None = Field(
|
||||
default=None, description="Model name for the embedding provider"
|
||||
)
|
||||
api_key: SecretStr | None = Field(
|
||||
default=None, description="API key for the embedding provider"
|
||||
)
|
||||
|
||||
|
||||
class EmbeddingConfig(BaseModel):
|
||||
"""Configuration wrapper for embedding functions.
|
||||
|
||||
Accepts either a pre-configured EmbeddingFunction or EmbeddingOptions
|
||||
to create one. This provides flexibility in how embeddings are configured.
|
||||
|
||||
Attributes:
|
||||
function: Either a callable EmbeddingFunction or EmbeddingOptions to create one
|
||||
"""
|
||||
|
||||
function: EmbeddingFunction | EmbeddingOptions
|
||||
50
src/crewai/rag/types.py
Normal file
50
src/crewai/rag/types.py
Normal file
@@ -0,0 +1,50 @@
|
||||
"""Type definitions for RAG (Retrieval-Augmented Generation) systems."""
|
||||
|
||||
from collections.abc import Callable, Mapping
|
||||
from typing import TypeAlias, TypedDict, Any
|
||||
|
||||
from typing_extensions import Required
|
||||
|
||||
|
||||
class BaseRecord(TypedDict, total=False):
|
||||
"""A typed dictionary representing a document record.
|
||||
|
||||
Attributes:
|
||||
doc_id: Optional unique identifier for the document. If not provided,
|
||||
a content-based ID will be generated using SHA256 hash.
|
||||
content: The text content of the document (required)
|
||||
metadata: Optional metadata associated with the document
|
||||
"""
|
||||
|
||||
doc_id: str
|
||||
content: Required[str]
|
||||
metadata: (
|
||||
Mapping[str, str | int | float | bool]
|
||||
| list[Mapping[str, str | int | float | bool]]
|
||||
)
|
||||
|
||||
|
||||
DenseVector: TypeAlias = list[float]
|
||||
IntVector: TypeAlias = list[int]
|
||||
|
||||
EmbeddingFunction: TypeAlias = Callable[..., Any]
|
||||
|
||||
|
||||
class SearchResult(TypedDict):
|
||||
"""Standard search result format for vector store queries.
|
||||
|
||||
This provides a consistent interface for search results across different
|
||||
vector store implementations. Each implementation should convert their
|
||||
native result format to this standard format.
|
||||
|
||||
Attributes:
|
||||
id: Unique identifier of the document
|
||||
content: The text content of the document
|
||||
metadata: Optional metadata associated with the document
|
||||
score: Similarity score (higher is better, typically between 0 and 1)
|
||||
"""
|
||||
|
||||
id: str
|
||||
content: str
|
||||
metadata: dict[str, Any]
|
||||
score: float
|
||||
Reference in New Issue
Block a user