mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-10 00:28:31 +00:00
Merge branch 'release/v1.0.0' into lorenze/native-anthropic-test
This commit is contained in:
@@ -15,7 +15,6 @@ dependencies = [
|
|||||||
"crewai==1.0.0b1",
|
"crewai==1.0.0b1",
|
||||||
"lancedb>=0.5.4",
|
"lancedb>=0.5.4",
|
||||||
"tiktoken>=0.8.0",
|
"tiktoken>=0.8.0",
|
||||||
"stagehand>=0.4.1",
|
|
||||||
"beautifulsoup4>=4.13.4",
|
"beautifulsoup4>=4.13.4",
|
||||||
"pypdf>=5.9.0",
|
"pypdf>=5.9.0",
|
||||||
"python-docx>=1.2.0",
|
"python-docx>=1.2.0",
|
||||||
|
|||||||
@@ -3,13 +3,16 @@
|
|||||||
import hashlib
|
import hashlib
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, TypeAlias, TypedDict
|
from typing import Any, TypeAlias, TypedDict
|
||||||
|
import uuid
|
||||||
|
|
||||||
from crewai.rag.config.types import RagConfigType
|
from crewai.rag.config.types import RagConfigType
|
||||||
from crewai.rag.config.utils import get_rag_client
|
from crewai.rag.config.utils import get_rag_client
|
||||||
from crewai.rag.core.base_client import BaseClient
|
from crewai.rag.core.base_client import BaseClient
|
||||||
from crewai.rag.factory import create_client
|
from crewai.rag.factory import create_client
|
||||||
|
from crewai.rag.qdrant.config import QdrantConfig
|
||||||
from crewai.rag.types import BaseRecord, SearchResult
|
from crewai.rag.types import BaseRecord, SearchResult
|
||||||
from pydantic import PrivateAttr
|
from pydantic import PrivateAttr
|
||||||
|
from qdrant_client.models import VectorParams
|
||||||
from typing_extensions import Unpack
|
from typing_extensions import Unpack
|
||||||
|
|
||||||
from crewai_tools.rag.data_types import DataType
|
from crewai_tools.rag.data_types import DataType
|
||||||
@@ -52,7 +55,11 @@ class CrewAIRagAdapter(Adapter):
|
|||||||
self._client = create_client(self.config)
|
self._client = create_client(self.config)
|
||||||
else:
|
else:
|
||||||
self._client = get_rag_client()
|
self._client = get_rag_client()
|
||||||
self._client.get_or_create_collection(collection_name=self.collection_name)
|
collection_params: dict[str, Any] = {"collection_name": self.collection_name}
|
||||||
|
if isinstance(self.config, QdrantConfig) and self.config.vectors_config:
|
||||||
|
if isinstance(self.config.vectors_config, VectorParams):
|
||||||
|
collection_params["vectors_config"] = self.config.vectors_config
|
||||||
|
self._client.get_or_create_collection(**collection_params)
|
||||||
|
|
||||||
def query(
|
def query(
|
||||||
self,
|
self,
|
||||||
@@ -76,6 +83,8 @@ class CrewAIRagAdapter(Adapter):
|
|||||||
if similarity_threshold is not None
|
if similarity_threshold is not None
|
||||||
else self.similarity_threshold
|
else self.similarity_threshold
|
||||||
)
|
)
|
||||||
|
if self._client is None:
|
||||||
|
raise ValueError("Client is not initialized")
|
||||||
|
|
||||||
results: list[SearchResult] = self._client.search(
|
results: list[SearchResult] = self._client.search(
|
||||||
collection_name=self.collection_name,
|
collection_name=self.collection_name,
|
||||||
@@ -201,9 +210,10 @@ class CrewAIRagAdapter(Adapter):
|
|||||||
if isinstance(arg, dict):
|
if isinstance(arg, dict):
|
||||||
file_metadata.update(arg.get("metadata", {}))
|
file_metadata.update(arg.get("metadata", {}))
|
||||||
|
|
||||||
chunk_id = hashlib.sha256(
|
chunk_hash = hashlib.sha256(
|
||||||
f"{file_result.doc_id}_{chunk_idx}_{file_chunk}".encode()
|
f"{file_result.doc_id}_{chunk_idx}_{file_chunk}".encode()
|
||||||
).hexdigest()
|
).hexdigest()
|
||||||
|
chunk_id = str(uuid.UUID(chunk_hash[:32]))
|
||||||
|
|
||||||
documents.append(
|
documents.append(
|
||||||
{
|
{
|
||||||
@@ -251,9 +261,10 @@ class CrewAIRagAdapter(Adapter):
|
|||||||
if isinstance(arg, dict):
|
if isinstance(arg, dict):
|
||||||
chunk_metadata.update(arg.get("metadata", {}))
|
chunk_metadata.update(arg.get("metadata", {}))
|
||||||
|
|
||||||
chunk_id = hashlib.sha256(
|
chunk_hash = hashlib.sha256(
|
||||||
f"{loader_result.doc_id}_{i}_{chunk}".encode()
|
f"{loader_result.doc_id}_{i}_{chunk}".encode()
|
||||||
).hexdigest()
|
).hexdigest()
|
||||||
|
chunk_id = str(uuid.UUID(chunk_hash[:32]))
|
||||||
|
|
||||||
documents.append(
|
documents.append(
|
||||||
{
|
{
|
||||||
@@ -264,6 +275,8 @@ class CrewAIRagAdapter(Adapter):
|
|||||||
)
|
)
|
||||||
|
|
||||||
if documents:
|
if documents:
|
||||||
|
if self._client is None:
|
||||||
|
raise ValueError("Client is not initialized")
|
||||||
self._client.add_documents(
|
self._client.add_documents(
|
||||||
collection_name=self.collection_name, documents=documents
|
collection_name=self.collection_name, documents=documents
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -4,12 +4,12 @@ from typing import Any
|
|||||||
from uuid import uuid4
|
from uuid import uuid4
|
||||||
|
|
||||||
import chromadb
|
import chromadb
|
||||||
import litellm
|
|
||||||
from pydantic import BaseModel, Field, PrivateAttr
|
from pydantic import BaseModel, Field, PrivateAttr
|
||||||
|
|
||||||
from crewai_tools.rag.base_loader import BaseLoader
|
from crewai_tools.rag.base_loader import BaseLoader
|
||||||
from crewai_tools.rag.chunkers.base_chunker import BaseChunker
|
from crewai_tools.rag.chunkers.base_chunker import BaseChunker
|
||||||
from crewai_tools.rag.data_types import DataType
|
from crewai_tools.rag.data_types import DataType
|
||||||
|
from crewai_tools.rag.embedding_service import EmbeddingService
|
||||||
from crewai_tools.rag.misc import compute_sha256
|
from crewai_tools.rag.misc import compute_sha256
|
||||||
from crewai_tools.rag.source_content import SourceContent
|
from crewai_tools.rag.source_content import SourceContent
|
||||||
from crewai_tools.tools.rag.rag_tool import Adapter
|
from crewai_tools.tools.rag.rag_tool import Adapter
|
||||||
@@ -18,31 +18,6 @@ from crewai_tools.tools.rag.rag_tool import Adapter
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class EmbeddingService:
|
|
||||||
def __init__(self, model: str = "text-embedding-3-small", **kwargs):
|
|
||||||
self.model = model
|
|
||||||
self.kwargs = kwargs
|
|
||||||
|
|
||||||
def embed_text(self, text: str) -> list[float]:
|
|
||||||
try:
|
|
||||||
response = litellm.embedding(model=self.model, input=[text], **self.kwargs)
|
|
||||||
return response.data[0]["embedding"]
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error generating embedding: {e}")
|
|
||||||
raise
|
|
||||||
|
|
||||||
def embed_batch(self, texts: list[str]) -> list[list[float]]:
|
|
||||||
if not texts:
|
|
||||||
return []
|
|
||||||
|
|
||||||
try:
|
|
||||||
response = litellm.embedding(model=self.model, input=texts, **self.kwargs)
|
|
||||||
return [data["embedding"] for data in response.data]
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error generating batch embeddings: {e}")
|
|
||||||
raise
|
|
||||||
|
|
||||||
|
|
||||||
class Document(BaseModel):
|
class Document(BaseModel):
|
||||||
id: str = Field(default_factory=lambda: str(uuid4()))
|
id: str = Field(default_factory=lambda: str(uuid4()))
|
||||||
content: str
|
content: str
|
||||||
@@ -54,6 +29,7 @@ class Document(BaseModel):
|
|||||||
class RAG(Adapter):
|
class RAG(Adapter):
|
||||||
collection_name: str = "crewai_knowledge_base"
|
collection_name: str = "crewai_knowledge_base"
|
||||||
persist_directory: str | None = None
|
persist_directory: str | None = None
|
||||||
|
embedding_provider: str = "openai"
|
||||||
embedding_model: str = "text-embedding-3-large"
|
embedding_model: str = "text-embedding-3-large"
|
||||||
summarize: bool = False
|
summarize: bool = False
|
||||||
top_k: int = 5
|
top_k: int = 5
|
||||||
@@ -79,7 +55,9 @@ class RAG(Adapter):
|
|||||||
)
|
)
|
||||||
|
|
||||||
self._embedding_service = EmbeddingService(
|
self._embedding_service = EmbeddingService(
|
||||||
model=self.embedding_model, **self.embedding_config
|
provider=self.embedding_provider,
|
||||||
|
model=self.embedding_model,
|
||||||
|
**self.embedding_config,
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to initialize ChromaDB: {e}")
|
logger.error(f"Failed to initialize ChromaDB: {e}")
|
||||||
@@ -181,7 +159,7 @@ class RAG(Adapter):
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to add documents to ChromaDB: {e}")
|
logger.error(f"Failed to add documents to ChromaDB: {e}")
|
||||||
|
|
||||||
def query(self, question: str, where: dict[str, Any] | None = None) -> str:
|
def query(self, question: str, where: dict[str, Any] | None = None) -> str: # type: ignore
|
||||||
try:
|
try:
|
||||||
question_embedding = self._embedding_service.embed_text(question)
|
question_embedding = self._embedding_service.embed_text(question)
|
||||||
|
|
||||||
|
|||||||
508
lib/crewai-tools/src/crewai_tools/rag/embedding_service.py
Normal file
508
lib/crewai-tools/src/crewai_tools/rag/embedding_service.py
Normal file
@@ -0,0 +1,508 @@
|
|||||||
|
"""
|
||||||
|
Enhanced embedding service that leverages CrewAI's existing embedding providers.
|
||||||
|
This replaces the litellm-based EmbeddingService with a more flexible architecture.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class EmbeddingConfig(BaseModel):
|
||||||
|
"""Configuration for embedding providers."""
|
||||||
|
|
||||||
|
provider: str = Field(description="Embedding provider name")
|
||||||
|
model: str = Field(description="Model name to use")
|
||||||
|
api_key: str | None = Field(default=None, description="API key for the provider")
|
||||||
|
timeout: float | None = Field(
|
||||||
|
default=30.0, description="Request timeout in seconds"
|
||||||
|
)
|
||||||
|
max_retries: int = Field(default=3, description="Maximum number of retries")
|
||||||
|
batch_size: int = Field(
|
||||||
|
default=100, description="Batch size for processing multiple texts"
|
||||||
|
)
|
||||||
|
extra_config: dict[str, Any] = Field(
|
||||||
|
default_factory=dict, description="Additional provider-specific configuration"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class EmbeddingService:
|
||||||
|
"""
|
||||||
|
Enhanced embedding service that uses CrewAI's existing embedding providers.
|
||||||
|
|
||||||
|
Supports multiple providers:
|
||||||
|
- openai: OpenAI embeddings (text-embedding-3-small, text-embedding-3-large, etc.)
|
||||||
|
- voyageai: Voyage AI embeddings (voyage-2, voyage-large-2, etc.)
|
||||||
|
- cohere: Cohere embeddings (embed-english-v3.0, embed-multilingual-v3.0, etc.)
|
||||||
|
- google-generativeai: Google Gemini embeddings (models/embedding-001, etc.)
|
||||||
|
- google-vertex: Google Vertex embeddings (models/embedding-001, etc.)
|
||||||
|
- huggingface: Hugging Face embeddings (sentence-transformers/all-MiniLM-L6-v2, etc.)
|
||||||
|
- jina: Jina embeddings (jina-embeddings-v2-base-en, etc.)
|
||||||
|
- ollama: Ollama embeddings (nomic-embed-text, etc.)
|
||||||
|
- openai: OpenAI embeddings (text-embedding-3-small, text-embedding-3-large, etc.)
|
||||||
|
- roboflow: Roboflow embeddings (roboflow-embeddings-v2-base-en, etc.)
|
||||||
|
- voyageai: Voyage AI embeddings (voyage-2, voyage-large-2, etc.)
|
||||||
|
- watsonx: Watson X embeddings (ibm/slate-125m-english-rtrvr, etc.)
|
||||||
|
- custom: Custom embeddings (embedding_callable, etc.)
|
||||||
|
- sentence-transformer: Sentence Transformers embeddings (all-MiniLM-L6-v2, etc.)
|
||||||
|
- text2vec: Text2Vec embeddings (text2vec-base-en, etc.)
|
||||||
|
- openclip: OpenClip embeddings (openclip-large-v2, etc.)
|
||||||
|
- instructor: Instructor embeddings (hkunlp/instructor-large, etc.)
|
||||||
|
- onnx: ONNX embeddings (onnx-large-v2, etc.)
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
provider: str = "openai",
|
||||||
|
model: str = "text-embedding-3-small",
|
||||||
|
api_key: str | None = None,
|
||||||
|
**kwargs: Any,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Initialize the embedding service.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
provider: The embedding provider to use
|
||||||
|
model: The model name
|
||||||
|
api_key: API key (if not provided, will look for environment variables)
|
||||||
|
**kwargs: Additional configuration options
|
||||||
|
"""
|
||||||
|
self.config = EmbeddingConfig(
|
||||||
|
provider=provider,
|
||||||
|
model=model,
|
||||||
|
api_key=api_key or self._get_default_api_key(provider),
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
self._embedding_function = None
|
||||||
|
self._initialize_embedding_function()
|
||||||
|
|
||||||
|
def _get_default_api_key(self, provider: str) -> str | None:
|
||||||
|
"""Get default API key from environment variables."""
|
||||||
|
env_key_map = {
|
||||||
|
"azure": "AZURE_OPENAI_API_KEY",
|
||||||
|
"amazon-bedrock": "AWS_ACCESS_KEY_ID", # or AWS_PROFILE
|
||||||
|
"cohere": "COHERE_API_KEY",
|
||||||
|
"google-generativeai": "GOOGLE_API_KEY",
|
||||||
|
"google-vertex": "GOOGLE_APPLICATION_CREDENTIALS",
|
||||||
|
"huggingface": "HUGGINGFACE_API_KEY",
|
||||||
|
"jina": "JINA_API_KEY",
|
||||||
|
"ollama": None, # Ollama typically runs locally without API key
|
||||||
|
"openai": "OPENAI_API_KEY",
|
||||||
|
"roboflow": "ROBOFLOW_API_KEY",
|
||||||
|
"voyageai": "VOYAGE_API_KEY",
|
||||||
|
"watsonx": "WATSONX_API_KEY",
|
||||||
|
}
|
||||||
|
|
||||||
|
env_key = env_key_map.get(provider)
|
||||||
|
if env_key:
|
||||||
|
return os.getenv(env_key)
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _initialize_embedding_function(self):
|
||||||
|
"""Initialize the embedding function using CrewAI's factory."""
|
||||||
|
try:
|
||||||
|
from crewai.rag.embeddings.factory import build_embedder
|
||||||
|
|
||||||
|
# Build the configuration for CrewAI's factory
|
||||||
|
config = self._build_provider_config()
|
||||||
|
|
||||||
|
# Create the embedding function
|
||||||
|
self._embedding_function = build_embedder(config)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"Initialized {self.config.provider} embedding service with model "
|
||||||
|
f"{self.config.model}"
|
||||||
|
)
|
||||||
|
|
||||||
|
except ImportError as e:
|
||||||
|
raise ImportError(
|
||||||
|
f"CrewAI embedding providers not available. "
|
||||||
|
f"Make sure crewai is installed: {e}"
|
||||||
|
) from e
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to initialize embedding function: {e}")
|
||||||
|
raise RuntimeError(
|
||||||
|
f"Failed to initialize {self.config.provider} embedding service: {e}"
|
||||||
|
) from e
|
||||||
|
|
||||||
|
def _build_provider_config(self) -> dict[str, Any]:
|
||||||
|
"""Build configuration dictionary for CrewAI's embedding factory."""
|
||||||
|
base_config = {"provider": self.config.provider, "config": {}}
|
||||||
|
|
||||||
|
# Provider-specific configuration mapping
|
||||||
|
if self.config.provider == "openai":
|
||||||
|
base_config["config"] = {
|
||||||
|
"api_key": self.config.api_key,
|
||||||
|
"model_name": self.config.model,
|
||||||
|
**self.config.extra_config,
|
||||||
|
}
|
||||||
|
elif self.config.provider == "azure":
|
||||||
|
base_config["config"] = {
|
||||||
|
"api_key": self.config.api_key,
|
||||||
|
"model_name": self.config.model,
|
||||||
|
**self.config.extra_config,
|
||||||
|
}
|
||||||
|
elif self.config.provider == "voyageai":
|
||||||
|
base_config["config"] = {
|
||||||
|
"api_key": self.config.api_key,
|
||||||
|
"model": self.config.model,
|
||||||
|
"max_retries": self.config.max_retries,
|
||||||
|
"timeout": self.config.timeout,
|
||||||
|
**self.config.extra_config,
|
||||||
|
}
|
||||||
|
elif self.config.provider == "cohere":
|
||||||
|
base_config["config"] = {
|
||||||
|
"api_key": self.config.api_key,
|
||||||
|
"model_name": self.config.model,
|
||||||
|
**self.config.extra_config,
|
||||||
|
}
|
||||||
|
elif self.config.provider in ["google-generativeai", "google-vertex"]:
|
||||||
|
base_config["config"] = {
|
||||||
|
"api_key": self.config.api_key,
|
||||||
|
"model_name": self.config.model,
|
||||||
|
**self.config.extra_config,
|
||||||
|
}
|
||||||
|
elif self.config.provider == "amazon-bedrock":
|
||||||
|
base_config["config"] = {
|
||||||
|
"aws_access_key_id": self.config.api_key,
|
||||||
|
"model_name": self.config.model,
|
||||||
|
**self.config.extra_config,
|
||||||
|
}
|
||||||
|
elif self.config.provider == "huggingface":
|
||||||
|
base_config["config"] = {
|
||||||
|
"api_key": self.config.api_key,
|
||||||
|
"model_name": self.config.model,
|
||||||
|
**self.config.extra_config,
|
||||||
|
}
|
||||||
|
elif self.config.provider == "jina":
|
||||||
|
base_config["config"] = {
|
||||||
|
"api_key": self.config.api_key,
|
||||||
|
"model_name": self.config.model,
|
||||||
|
**self.config.extra_config,
|
||||||
|
}
|
||||||
|
elif self.config.provider == "ollama":
|
||||||
|
base_config["config"] = {
|
||||||
|
"model": self.config.model,
|
||||||
|
**self.config.extra_config,
|
||||||
|
}
|
||||||
|
elif self.config.provider == "sentence-transformer":
|
||||||
|
base_config["config"] = {
|
||||||
|
"model_name": self.config.model,
|
||||||
|
**self.config.extra_config,
|
||||||
|
}
|
||||||
|
elif self.config.provider == "instructor":
|
||||||
|
base_config["config"] = {
|
||||||
|
"model_name": self.config.model,
|
||||||
|
**self.config.extra_config,
|
||||||
|
}
|
||||||
|
elif self.config.provider == "onnx":
|
||||||
|
base_config["config"] = {
|
||||||
|
**self.config.extra_config,
|
||||||
|
}
|
||||||
|
elif self.config.provider == "roboflow":
|
||||||
|
base_config["config"] = {
|
||||||
|
"api_key": self.config.api_key,
|
||||||
|
**self.config.extra_config,
|
||||||
|
}
|
||||||
|
elif self.config.provider == "openclip":
|
||||||
|
base_config["config"] = {
|
||||||
|
"model_name": self.config.model,
|
||||||
|
**self.config.extra_config,
|
||||||
|
}
|
||||||
|
elif self.config.provider == "text2vec":
|
||||||
|
base_config["config"] = {
|
||||||
|
"model_name": self.config.model,
|
||||||
|
**self.config.extra_config,
|
||||||
|
}
|
||||||
|
elif self.config.provider == "watsonx":
|
||||||
|
base_config["config"] = {
|
||||||
|
"api_key": self.config.api_key,
|
||||||
|
"model_name": self.config.model,
|
||||||
|
**self.config.extra_config,
|
||||||
|
}
|
||||||
|
elif self.config.provider == "custom":
|
||||||
|
# Custom provider requires embedding_callable in extra_config
|
||||||
|
base_config["config"] = {
|
||||||
|
**self.config.extra_config,
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
# Generic configuration for any unlisted providers
|
||||||
|
base_config["config"] = {
|
||||||
|
"api_key": self.config.api_key,
|
||||||
|
"model": self.config.model,
|
||||||
|
**self.config.extra_config,
|
||||||
|
}
|
||||||
|
|
||||||
|
return base_config
|
||||||
|
|
||||||
|
def embed_text(self, text: str) -> list[float]:
|
||||||
|
"""
|
||||||
|
Generate embedding for a single text.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text: Text to embed
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of floats representing the embedding
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
RuntimeError: If embedding generation fails
|
||||||
|
"""
|
||||||
|
if not text or not text.strip():
|
||||||
|
logger.warning("Empty text provided for embedding")
|
||||||
|
return []
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Use ChromaDB's embedding function interface
|
||||||
|
embeddings = self._embedding_function([text]) # type: ignore
|
||||||
|
return embeddings[0] if embeddings else []
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error generating embedding for text: {e}")
|
||||||
|
raise RuntimeError(f"Failed to generate embedding: {e}") from e
|
||||||
|
|
||||||
|
def embed_batch(self, texts: list[str]) -> list[list[float]]:
|
||||||
|
"""
|
||||||
|
Generate embeddings for multiple texts.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
texts: List of texts to embed
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of embedding vectors
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
RuntimeError: If embedding generation fails
|
||||||
|
"""
|
||||||
|
if not texts:
|
||||||
|
return []
|
||||||
|
|
||||||
|
# Filter out empty texts
|
||||||
|
valid_texts = [text for text in texts if text and text.strip()]
|
||||||
|
if not valid_texts:
|
||||||
|
logger.warning("No valid texts provided for batch embedding")
|
||||||
|
return []
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Process in batches to avoid API limits
|
||||||
|
all_embeddings = []
|
||||||
|
|
||||||
|
for i in range(0, len(valid_texts), self.config.batch_size):
|
||||||
|
batch = valid_texts[i : i + self.config.batch_size]
|
||||||
|
batch_embeddings = self._embedding_function(batch) # type: ignore
|
||||||
|
all_embeddings.extend(batch_embeddings)
|
||||||
|
|
||||||
|
return all_embeddings
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error generating batch embeddings: {e}")
|
||||||
|
raise RuntimeError(f"Failed to generate batch embeddings: {e}") from e
|
||||||
|
|
||||||
|
def get_embedding_dimension(self) -> int | None:
|
||||||
|
"""
|
||||||
|
Get the dimension of embeddings produced by this service.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Embedding dimension or None if unknown
|
||||||
|
"""
|
||||||
|
# Try to get dimension by generating a test embedding
|
||||||
|
try:
|
||||||
|
test_embedding = self.embed_text("test")
|
||||||
|
return len(test_embedding) if test_embedding else None
|
||||||
|
except Exception:
|
||||||
|
logger.warning("Could not determine embedding dimension")
|
||||||
|
return None
|
||||||
|
|
||||||
|
def validate_connection(self) -> bool:
|
||||||
|
"""
|
||||||
|
Validate that the embedding service is working correctly.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if the service is working, False otherwise
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
test_embedding = self.embed_text("test connection")
|
||||||
|
return len(test_embedding) > 0
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Connection validation failed: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
def get_service_info(self) -> dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Get information about the current embedding service.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary with service information
|
||||||
|
"""
|
||||||
|
return {
|
||||||
|
"provider": self.config.provider,
|
||||||
|
"model": self.config.model,
|
||||||
|
"embedding_dimension": self.get_embedding_dimension(),
|
||||||
|
"batch_size": self.config.batch_size,
|
||||||
|
"is_connected": self.validate_connection(),
|
||||||
|
}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def list_supported_providers(cls) -> list[str]:
|
||||||
|
"""
|
||||||
|
List all supported embedding providers.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of supported provider names
|
||||||
|
"""
|
||||||
|
return [
|
||||||
|
"azure",
|
||||||
|
"amazon-bedrock",
|
||||||
|
"cohere",
|
||||||
|
"custom",
|
||||||
|
"google-generativeai",
|
||||||
|
"google-vertex",
|
||||||
|
"huggingface",
|
||||||
|
"instructor",
|
||||||
|
"jina",
|
||||||
|
"ollama",
|
||||||
|
"onnx",
|
||||||
|
"openai",
|
||||||
|
"openclip",
|
||||||
|
"roboflow",
|
||||||
|
"sentence-transformer",
|
||||||
|
"text2vec",
|
||||||
|
"voyageai",
|
||||||
|
"watsonx",
|
||||||
|
]
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def create_openai_service(
|
||||||
|
cls,
|
||||||
|
model: str = "text-embedding-3-small",
|
||||||
|
api_key: str | None = None,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> "EmbeddingService":
|
||||||
|
"""Create an OpenAI embedding service."""
|
||||||
|
return cls(provider="openai", model=model, api_key=api_key, **kwargs)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def create_voyage_service(
|
||||||
|
cls, model: str = "voyage-2", api_key: str | None = None, **kwargs: Any
|
||||||
|
) -> "EmbeddingService":
|
||||||
|
"""Create a Voyage AI embedding service."""
|
||||||
|
return cls(provider="voyageai", model=model, api_key=api_key, **kwargs)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def create_cohere_service(
|
||||||
|
cls,
|
||||||
|
model: str = "embed-english-v3.0",
|
||||||
|
api_key: str | None = None,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> "EmbeddingService":
|
||||||
|
"""Create a Cohere embedding service."""
|
||||||
|
return cls(provider="cohere", model=model, api_key=api_key, **kwargs)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def create_gemini_service(
|
||||||
|
cls,
|
||||||
|
model: str = "models/embedding-001",
|
||||||
|
api_key: str | None = None,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> "EmbeddingService":
|
||||||
|
"""Create a Google Gemini embedding service."""
|
||||||
|
return cls(
|
||||||
|
provider="google-generativeai", model=model, api_key=api_key, **kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def create_azure_service(
|
||||||
|
cls,
|
||||||
|
model: str = "text-embedding-ada-002",
|
||||||
|
api_key: str | None = None,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> "EmbeddingService":
|
||||||
|
"""Create an Azure OpenAI embedding service."""
|
||||||
|
return cls(provider="azure", model=model, api_key=api_key, **kwargs)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def create_bedrock_service(
|
||||||
|
cls,
|
||||||
|
model: str = "amazon.titan-embed-text-v1",
|
||||||
|
api_key: str | None = None,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> "EmbeddingService":
|
||||||
|
"""Create an Amazon Bedrock embedding service."""
|
||||||
|
return cls(provider="amazon-bedrock", model=model, api_key=api_key, **kwargs)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def create_huggingface_service(
|
||||||
|
cls,
|
||||||
|
model: str = "sentence-transformers/all-MiniLM-L6-v2",
|
||||||
|
api_key: str | None = None,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> "EmbeddingService":
|
||||||
|
"""Create a Hugging Face embedding service."""
|
||||||
|
return cls(provider="huggingface", model=model, api_key=api_key, **kwargs)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def create_sentence_transformer_service(
|
||||||
|
cls,
|
||||||
|
model: str = "all-MiniLM-L6-v2",
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> "EmbeddingService":
|
||||||
|
"""Create a Sentence Transformers embedding service (local)."""
|
||||||
|
return cls(provider="sentence-transformer", model=model, **kwargs)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def create_ollama_service(
|
||||||
|
cls,
|
||||||
|
model: str = "nomic-embed-text",
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> "EmbeddingService":
|
||||||
|
"""Create an Ollama embedding service (local)."""
|
||||||
|
return cls(provider="ollama", model=model, **kwargs)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def create_jina_service(
|
||||||
|
cls,
|
||||||
|
model: str = "jina-embeddings-v2-base-en",
|
||||||
|
api_key: str | None = None,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> "EmbeddingService":
|
||||||
|
"""Create a Jina AI embedding service."""
|
||||||
|
return cls(provider="jina", model=model, api_key=api_key, **kwargs)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def create_instructor_service(
|
||||||
|
cls,
|
||||||
|
model: str = "hkunlp/instructor-large",
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> "EmbeddingService":
|
||||||
|
"""Create an Instructor embedding service."""
|
||||||
|
return cls(provider="instructor", model=model, **kwargs)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def create_watsonx_service(
|
||||||
|
cls,
|
||||||
|
model: str = "ibm/slate-125m-english-rtrvr",
|
||||||
|
api_key: str | None = None,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> "EmbeddingService":
|
||||||
|
"""Create a Watson X embedding service."""
|
||||||
|
return cls(provider="watsonx", model=model, api_key=api_key, **kwargs)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def create_custom_service(
|
||||||
|
cls,
|
||||||
|
embedding_callable: Any,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> "EmbeddingService":
|
||||||
|
"""Create a custom embedding service with your own embedding function."""
|
||||||
|
return cls(
|
||||||
|
provider="custom",
|
||||||
|
model="custom",
|
||||||
|
extra_config={"embedding_callable": embedding_callable},
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
342
lib/crewai-tools/tests/rag/test_embedding_service.py
Normal file
342
lib/crewai-tools/tests/rag/test_embedding_service.py
Normal file
@@ -0,0 +1,342 @@
|
|||||||
|
"""
|
||||||
|
Tests for the enhanced embedding service.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import pytest
|
||||||
|
from unittest.mock import Mock, patch
|
||||||
|
|
||||||
|
from crewai_tools.rag.embedding_service import EmbeddingService, EmbeddingConfig
|
||||||
|
|
||||||
|
|
||||||
|
class TestEmbeddingConfig:
|
||||||
|
"""Test the EmbeddingConfig model."""
|
||||||
|
|
||||||
|
def test_default_config(self):
|
||||||
|
"""Test default configuration values."""
|
||||||
|
config = EmbeddingConfig(provider="openai", model="text-embedding-3-small")
|
||||||
|
|
||||||
|
assert config.provider == "openai"
|
||||||
|
assert config.model == "text-embedding-3-small"
|
||||||
|
assert config.api_key is None
|
||||||
|
assert config.timeout == 30.0
|
||||||
|
assert config.max_retries == 3
|
||||||
|
assert config.batch_size == 100
|
||||||
|
assert config.extra_config == {}
|
||||||
|
|
||||||
|
def test_custom_config(self):
|
||||||
|
"""Test custom configuration values."""
|
||||||
|
config = EmbeddingConfig(
|
||||||
|
provider="voyageai",
|
||||||
|
model="voyage-2",
|
||||||
|
api_key="test-key",
|
||||||
|
timeout=60.0,
|
||||||
|
max_retries=5,
|
||||||
|
batch_size=50,
|
||||||
|
extra_config={"input_type": "document"}
|
||||||
|
)
|
||||||
|
|
||||||
|
assert config.provider == "voyageai"
|
||||||
|
assert config.model == "voyage-2"
|
||||||
|
assert config.api_key == "test-key"
|
||||||
|
assert config.timeout == 60.0
|
||||||
|
assert config.max_retries == 5
|
||||||
|
assert config.batch_size == 50
|
||||||
|
assert config.extra_config == {"input_type": "document"}
|
||||||
|
|
||||||
|
|
||||||
|
class TestEmbeddingService:
|
||||||
|
"""Test the EmbeddingService class."""
|
||||||
|
|
||||||
|
def test_list_supported_providers(self):
|
||||||
|
"""Test listing supported providers."""
|
||||||
|
providers = EmbeddingService.list_supported_providers()
|
||||||
|
expected_providers = [
|
||||||
|
"openai", "azure", "voyageai", "cohere", "google-generativeai",
|
||||||
|
"amazon-bedrock", "huggingface", "jina", "ollama", "sentence-transformer",
|
||||||
|
"instructor", "watsonx", "custom"
|
||||||
|
]
|
||||||
|
|
||||||
|
assert isinstance(providers, list)
|
||||||
|
assert len(providers) >= 15 # Should have at least 15 providers
|
||||||
|
assert all(provider in providers for provider in expected_providers)
|
||||||
|
|
||||||
|
def test_get_default_api_key(self):
|
||||||
|
"""Test getting default API keys from environment."""
|
||||||
|
service = EmbeddingService.__new__(EmbeddingService) # Create without __init__
|
||||||
|
|
||||||
|
# Test with environment variable set
|
||||||
|
with patch.dict(os.environ, {"OPENAI_API_KEY": "test-openai-key"}):
|
||||||
|
api_key = service._get_default_api_key("openai")
|
||||||
|
assert api_key == "test-openai-key"
|
||||||
|
|
||||||
|
# Test with no environment variable
|
||||||
|
with patch.dict(os.environ, {}, clear=True):
|
||||||
|
api_key = service._get_default_api_key("openai")
|
||||||
|
assert api_key is None
|
||||||
|
|
||||||
|
# Test unknown provider
|
||||||
|
api_key = service._get_default_api_key("unknown-provider")
|
||||||
|
assert api_key is None
|
||||||
|
|
||||||
|
@patch('crewai.rag.embeddings.factory.build_embedder')
|
||||||
|
def test_initialization_success(self, mock_build_embedder):
|
||||||
|
"""Test successful initialization."""
|
||||||
|
# Mock the embedding function
|
||||||
|
mock_embedding_function = Mock()
|
||||||
|
mock_build_embedder.return_value = mock_embedding_function
|
||||||
|
|
||||||
|
service = EmbeddingService(
|
||||||
|
provider="openai",
|
||||||
|
model="text-embedding-3-small",
|
||||||
|
api_key="test-key"
|
||||||
|
)
|
||||||
|
|
||||||
|
assert service.config.provider == "openai"
|
||||||
|
assert service.config.model == "text-embedding-3-small"
|
||||||
|
assert service.config.api_key == "test-key"
|
||||||
|
assert service._embedding_function == mock_embedding_function
|
||||||
|
|
||||||
|
# Verify build_embedder was called with correct config
|
||||||
|
mock_build_embedder.assert_called_once()
|
||||||
|
call_args = mock_build_embedder.call_args[0][0]
|
||||||
|
assert call_args["provider"] == "openai"
|
||||||
|
assert call_args["config"]["api_key"] == "test-key"
|
||||||
|
assert call_args["config"]["model_name"] == "text-embedding-3-small"
|
||||||
|
|
||||||
|
@patch('crewai.rag.embeddings.factory.build_embedder')
|
||||||
|
def test_initialization_import_error(self, mock_build_embedder):
|
||||||
|
"""Test initialization with import error."""
|
||||||
|
mock_build_embedder.side_effect = ImportError("CrewAI not installed")
|
||||||
|
|
||||||
|
with pytest.raises(ImportError, match="CrewAI embedding providers not available"):
|
||||||
|
EmbeddingService(provider="openai", model="test-model", api_key="test-key")
|
||||||
|
|
||||||
|
@patch('crewai.rag.embeddings.factory.build_embedder')
|
||||||
|
def test_embed_text_success(self, mock_build_embedder):
|
||||||
|
"""Test successful text embedding."""
|
||||||
|
# Mock the embedding function
|
||||||
|
mock_embedding_function = Mock()
|
||||||
|
mock_embedding_function.return_value = [[0.1, 0.2, 0.3]]
|
||||||
|
mock_build_embedder.return_value = mock_embedding_function
|
||||||
|
|
||||||
|
service = EmbeddingService(provider="openai", model="test-model", api_key="test-key")
|
||||||
|
|
||||||
|
result = service.embed_text("test text")
|
||||||
|
|
||||||
|
assert result == [0.1, 0.2, 0.3]
|
||||||
|
mock_embedding_function.assert_called_once_with(["test text"])
|
||||||
|
|
||||||
|
@patch('crewai.rag.embeddings.factory.build_embedder')
|
||||||
|
def test_embed_text_empty_input(self, mock_build_embedder):
|
||||||
|
"""Test embedding empty text."""
|
||||||
|
mock_embedding_function = Mock()
|
||||||
|
mock_build_embedder.return_value = mock_embedding_function
|
||||||
|
|
||||||
|
service = EmbeddingService(provider="openai", model="test-model", api_key="test-key")
|
||||||
|
|
||||||
|
result = service.embed_text("")
|
||||||
|
assert result == []
|
||||||
|
|
||||||
|
result = service.embed_text(" ")
|
||||||
|
assert result == []
|
||||||
|
|
||||||
|
# Embedding function should not be called for empty text
|
||||||
|
mock_embedding_function.assert_not_called()
|
||||||
|
|
||||||
|
@patch('crewai.rag.embeddings.factory.build_embedder')
|
||||||
|
def test_embed_batch_success(self, mock_build_embedder):
|
||||||
|
"""Test successful batch embedding."""
|
||||||
|
# Mock the embedding function
|
||||||
|
mock_embedding_function = Mock()
|
||||||
|
mock_embedding_function.return_value = [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]
|
||||||
|
mock_build_embedder.return_value = mock_embedding_function
|
||||||
|
|
||||||
|
service = EmbeddingService(provider="openai", model="test-model", api_key="test-key")
|
||||||
|
|
||||||
|
texts = ["text1", "text2", "text3"]
|
||||||
|
result = service.embed_batch(texts)
|
||||||
|
|
||||||
|
assert result == [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]
|
||||||
|
mock_embedding_function.assert_called_once_with(texts)
|
||||||
|
|
||||||
|
@patch('crewai.rag.embeddings.factory.build_embedder')
|
||||||
|
def test_embed_batch_empty_input(self, mock_build_embedder):
|
||||||
|
"""Test batch embedding with empty input."""
|
||||||
|
mock_embedding_function = Mock()
|
||||||
|
mock_build_embedder.return_value = mock_embedding_function
|
||||||
|
|
||||||
|
service = EmbeddingService(provider="openai", model="test-model", api_key="test-key")
|
||||||
|
|
||||||
|
# Empty list
|
||||||
|
result = service.embed_batch([])
|
||||||
|
assert result == []
|
||||||
|
|
||||||
|
# List with empty strings
|
||||||
|
result = service.embed_batch(["", " ", ""])
|
||||||
|
assert result == []
|
||||||
|
|
||||||
|
# Embedding function should not be called for empty input
|
||||||
|
mock_embedding_function.assert_not_called()
|
||||||
|
|
||||||
|
@patch('crewai.rag.embeddings.factory.build_embedder')
|
||||||
|
def test_validate_connection(self, mock_build_embedder):
|
||||||
|
"""Test connection validation."""
|
||||||
|
# Mock successful embedding
|
||||||
|
mock_embedding_function = Mock()
|
||||||
|
mock_embedding_function.return_value = [[0.1, 0.2, 0.3]]
|
||||||
|
mock_build_embedder.return_value = mock_embedding_function
|
||||||
|
|
||||||
|
service = EmbeddingService(provider="openai", model="test-model", api_key="test-key")
|
||||||
|
|
||||||
|
assert service.validate_connection() is True
|
||||||
|
|
||||||
|
# Mock failed embedding
|
||||||
|
mock_embedding_function.side_effect = Exception("Connection failed")
|
||||||
|
assert service.validate_connection() is False
|
||||||
|
|
||||||
|
@patch('crewai.rag.embeddings.factory.build_embedder')
|
||||||
|
def test_get_service_info(self, mock_build_embedder):
|
||||||
|
"""Test getting service information."""
|
||||||
|
# Mock the embedding function
|
||||||
|
mock_embedding_function = Mock()
|
||||||
|
mock_embedding_function.return_value = [[0.1, 0.2, 0.3]]
|
||||||
|
mock_build_embedder.return_value = mock_embedding_function
|
||||||
|
|
||||||
|
service = EmbeddingService(provider="openai", model="test-model", api_key="test-key")
|
||||||
|
|
||||||
|
info = service.get_service_info()
|
||||||
|
|
||||||
|
assert info["provider"] == "openai"
|
||||||
|
assert info["model"] == "test-model"
|
||||||
|
assert info["embedding_dimension"] == 3
|
||||||
|
assert info["batch_size"] == 100
|
||||||
|
assert info["is_connected"] is True
|
||||||
|
|
||||||
|
def test_create_openai_service(self):
|
||||||
|
"""Test OpenAI service creation."""
|
||||||
|
with patch('crewai.rag.embeddings.factory.build_embedder'):
|
||||||
|
service = EmbeddingService.create_openai_service(
|
||||||
|
model="text-embedding-3-large",
|
||||||
|
api_key="test-key"
|
||||||
|
)
|
||||||
|
|
||||||
|
assert service.config.provider == "openai"
|
||||||
|
assert service.config.model == "text-embedding-3-large"
|
||||||
|
assert service.config.api_key == "test-key"
|
||||||
|
|
||||||
|
def test_create_voyage_service(self):
|
||||||
|
"""Test Voyage AI service creation."""
|
||||||
|
with patch('crewai.rag.embeddings.factory.build_embedder'):
|
||||||
|
service = EmbeddingService.create_voyage_service(
|
||||||
|
model="voyage-large-2",
|
||||||
|
api_key="test-key"
|
||||||
|
)
|
||||||
|
|
||||||
|
assert service.config.provider == "voyageai"
|
||||||
|
assert service.config.model == "voyage-large-2"
|
||||||
|
assert service.config.api_key == "test-key"
|
||||||
|
|
||||||
|
def test_create_cohere_service(self):
|
||||||
|
"""Test Cohere service creation."""
|
||||||
|
with patch('crewai.rag.embeddings.factory.build_embedder'):
|
||||||
|
service = EmbeddingService.create_cohere_service(
|
||||||
|
model="embed-multilingual-v3.0",
|
||||||
|
api_key="test-key"
|
||||||
|
)
|
||||||
|
|
||||||
|
assert service.config.provider == "cohere"
|
||||||
|
assert service.config.model == "embed-multilingual-v3.0"
|
||||||
|
assert service.config.api_key == "test-key"
|
||||||
|
|
||||||
|
def test_create_gemini_service(self):
|
||||||
|
"""Test Gemini service creation."""
|
||||||
|
with patch('crewai.rag.embeddings.factory.build_embedder'):
|
||||||
|
service = EmbeddingService.create_gemini_service(
|
||||||
|
model="models/embedding-001",
|
||||||
|
api_key="test-key"
|
||||||
|
)
|
||||||
|
|
||||||
|
assert service.config.provider == "google-generativeai"
|
||||||
|
assert service.config.model == "models/embedding-001"
|
||||||
|
assert service.config.api_key == "test-key"
|
||||||
|
|
||||||
|
|
||||||
|
class TestProviderConfigurations:
|
||||||
|
"""Test provider-specific configurations."""
|
||||||
|
|
||||||
|
@patch('crewai.rag.embeddings.factory.build_embedder')
|
||||||
|
def test_openai_config(self, mock_build_embedder):
|
||||||
|
"""Test OpenAI configuration mapping."""
|
||||||
|
mock_build_embedder.return_value = Mock()
|
||||||
|
|
||||||
|
service = EmbeddingService(
|
||||||
|
provider="openai",
|
||||||
|
model="text-embedding-3-small",
|
||||||
|
api_key="test-key",
|
||||||
|
extra_config={"dimensions": 1024}
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check the configuration passed to build_embedder
|
||||||
|
call_args = mock_build_embedder.call_args[0][0]
|
||||||
|
assert call_args["provider"] == "openai"
|
||||||
|
assert call_args["config"]["api_key"] == "test-key"
|
||||||
|
assert call_args["config"]["model_name"] == "text-embedding-3-small"
|
||||||
|
assert call_args["config"]["dimensions"] == 1024
|
||||||
|
|
||||||
|
@patch('crewai.rag.embeddings.factory.build_embedder')
|
||||||
|
def test_voyageai_config(self, mock_build_embedder):
|
||||||
|
"""Test Voyage AI configuration mapping."""
|
||||||
|
mock_build_embedder.return_value = Mock()
|
||||||
|
|
||||||
|
service = EmbeddingService(
|
||||||
|
provider="voyageai",
|
||||||
|
model="voyage-2",
|
||||||
|
api_key="test-key",
|
||||||
|
timeout=60.0,
|
||||||
|
max_retries=5,
|
||||||
|
extra_config={"input_type": "document"}
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check the configuration passed to build_embedder
|
||||||
|
call_args = mock_build_embedder.call_args[0][0]
|
||||||
|
assert call_args["provider"] == "voyageai"
|
||||||
|
assert call_args["config"]["api_key"] == "test-key"
|
||||||
|
assert call_args["config"]["model"] == "voyage-2"
|
||||||
|
assert call_args["config"]["timeout"] == 60.0
|
||||||
|
assert call_args["config"]["max_retries"] == 5
|
||||||
|
assert call_args["config"]["input_type"] == "document"
|
||||||
|
|
||||||
|
@patch('crewai.rag.embeddings.factory.build_embedder')
|
||||||
|
def test_cohere_config(self, mock_build_embedder):
|
||||||
|
"""Test Cohere configuration mapping."""
|
||||||
|
mock_build_embedder.return_value = Mock()
|
||||||
|
|
||||||
|
service = EmbeddingService(
|
||||||
|
provider="cohere",
|
||||||
|
model="embed-english-v3.0",
|
||||||
|
api_key="test-key"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check the configuration passed to build_embedder
|
||||||
|
call_args = mock_build_embedder.call_args[0][0]
|
||||||
|
assert call_args["provider"] == "cohere"
|
||||||
|
assert call_args["config"]["api_key"] == "test-key"
|
||||||
|
assert call_args["config"]["model_name"] == "embed-english-v3.0"
|
||||||
|
|
||||||
|
@patch('crewai.rag.embeddings.factory.build_embedder')
|
||||||
|
def test_gemini_config(self, mock_build_embedder):
|
||||||
|
"""Test Gemini configuration mapping."""
|
||||||
|
mock_build_embedder.return_value = Mock()
|
||||||
|
|
||||||
|
service = EmbeddingService(
|
||||||
|
provider="google-generativeai",
|
||||||
|
model="models/embedding-001",
|
||||||
|
api_key="test-key"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check the configuration passed to build_embedder
|
||||||
|
call_args = mock_build_embedder.call_args[0][0]
|
||||||
|
assert call_args["provider"] == "google-generativeai"
|
||||||
|
assert call_args["config"]["api_key"] == "test-key"
|
||||||
|
assert call_args["config"]["model_name"] == "models/embedding-001"
|
||||||
@@ -4,6 +4,7 @@ from dataclasses import field
|
|||||||
from typing import Literal, cast
|
from typing import Literal, cast
|
||||||
|
|
||||||
from pydantic.dataclasses import dataclass as pyd_dataclass
|
from pydantic.dataclasses import dataclass as pyd_dataclass
|
||||||
|
from qdrant_client.models import VectorParams
|
||||||
|
|
||||||
from crewai.rag.config.base import BaseRagConfig
|
from crewai.rag.config.base import BaseRagConfig
|
||||||
from crewai.rag.qdrant.constants import DEFAULT_EMBEDDING_MODEL, DEFAULT_STORAGE_PATH
|
from crewai.rag.qdrant.constants import DEFAULT_EMBEDDING_MODEL, DEFAULT_STORAGE_PATH
|
||||||
@@ -53,3 +54,4 @@ class QdrantConfig(BaseRagConfig):
|
|||||||
embedding_function: QdrantEmbeddingFunctionWrapper = field(
|
embedding_function: QdrantEmbeddingFunctionWrapper = field(
|
||||||
default_factory=_default_embedding_function
|
default_factory=_default_embedding_function
|
||||||
)
|
)
|
||||||
|
vectors_config: VectorParams | None = field(default=None)
|
||||||
|
|||||||
@@ -4,8 +4,8 @@ import asyncio
|
|||||||
from typing import TypeGuard
|
from typing import TypeGuard
|
||||||
from uuid import uuid4
|
from uuid import uuid4
|
||||||
|
|
||||||
from qdrant_client import AsyncQdrantClient # type: ignore[import-not-found]
|
|
||||||
from qdrant_client import (
|
from qdrant_client import (
|
||||||
|
AsyncQdrantClient, # type: ignore[import-not-found]
|
||||||
QdrantClient as SyncQdrantClient, # type: ignore[import-not-found]
|
QdrantClient as SyncQdrantClient, # type: ignore[import-not-found]
|
||||||
)
|
)
|
||||||
from qdrant_client.models import ( # type: ignore[import-not-found]
|
from qdrant_client.models import ( # type: ignore[import-not-found]
|
||||||
|
|||||||
2
uv.lock
generated
2
uv.lock
generated
@@ -1124,7 +1124,6 @@ dependencies = [
|
|||||||
{ name = "python-docx" },
|
{ name = "python-docx" },
|
||||||
{ name = "pytube" },
|
{ name = "pytube" },
|
||||||
{ name = "requests" },
|
{ name = "requests" },
|
||||||
{ name = "stagehand" },
|
|
||||||
{ name = "tiktoken" },
|
{ name = "tiktoken" },
|
||||||
{ name = "youtube-transcript-api" },
|
{ name = "youtube-transcript-api" },
|
||||||
]
|
]
|
||||||
@@ -1295,7 +1294,6 @@ requires-dist = [
|
|||||||
{ name = "spider-client", marker = "extra == 'spider-client'", specifier = ">=0.1.25" },
|
{ name = "spider-client", marker = "extra == 'spider-client'", specifier = ">=0.1.25" },
|
||||||
{ name = "sqlalchemy", marker = "extra == 'singlestore'", specifier = ">=2.0.40" },
|
{ name = "sqlalchemy", marker = "extra == 'singlestore'", specifier = ">=2.0.40" },
|
||||||
{ name = "sqlalchemy", marker = "extra == 'sqlalchemy'", specifier = ">=2.0.35" },
|
{ name = "sqlalchemy", marker = "extra == 'sqlalchemy'", specifier = ">=2.0.35" },
|
||||||
{ name = "stagehand", specifier = ">=0.4.1" },
|
|
||||||
{ name = "stagehand", marker = "extra == 'stagehand'", specifier = ">=0.4.1" },
|
{ name = "stagehand", marker = "extra == 'stagehand'", specifier = ">=0.4.1" },
|
||||||
{ name = "tavily-python", marker = "extra == 'tavily-python'", specifier = ">=0.5.4" },
|
{ name = "tavily-python", marker = "extra == 'tavily-python'", specifier = ">=0.5.4" },
|
||||||
{ name = "tiktoken", specifier = ">=0.8.0" },
|
{ name = "tiktoken", specifier = ">=0.8.0" },
|
||||||
|
|||||||
Reference in New Issue
Block a user