Merge branch 'release/v1.0.0' into lorenze/native-anthropic-test

This commit is contained in:
Lorenze Jay
2025-10-15 15:57:23 -07:00
committed by GitHub
8 changed files with 875 additions and 35 deletions

View File

@@ -15,7 +15,6 @@ dependencies = [
"crewai==1.0.0b1", "crewai==1.0.0b1",
"lancedb>=0.5.4", "lancedb>=0.5.4",
"tiktoken>=0.8.0", "tiktoken>=0.8.0",
"stagehand>=0.4.1",
"beautifulsoup4>=4.13.4", "beautifulsoup4>=4.13.4",
"pypdf>=5.9.0", "pypdf>=5.9.0",
"python-docx>=1.2.0", "python-docx>=1.2.0",

View File

@@ -3,13 +3,16 @@
import hashlib import hashlib
from pathlib import Path from pathlib import Path
from typing import Any, TypeAlias, TypedDict from typing import Any, TypeAlias, TypedDict
import uuid
from crewai.rag.config.types import RagConfigType from crewai.rag.config.types import RagConfigType
from crewai.rag.config.utils import get_rag_client from crewai.rag.config.utils import get_rag_client
from crewai.rag.core.base_client import BaseClient from crewai.rag.core.base_client import BaseClient
from crewai.rag.factory import create_client from crewai.rag.factory import create_client
from crewai.rag.qdrant.config import QdrantConfig
from crewai.rag.types import BaseRecord, SearchResult from crewai.rag.types import BaseRecord, SearchResult
from pydantic import PrivateAttr from pydantic import PrivateAttr
from qdrant_client.models import VectorParams
from typing_extensions import Unpack from typing_extensions import Unpack
from crewai_tools.rag.data_types import DataType from crewai_tools.rag.data_types import DataType
@@ -52,7 +55,11 @@ class CrewAIRagAdapter(Adapter):
self._client = create_client(self.config) self._client = create_client(self.config)
else: else:
self._client = get_rag_client() self._client = get_rag_client()
self._client.get_or_create_collection(collection_name=self.collection_name) collection_params: dict[str, Any] = {"collection_name": self.collection_name}
if isinstance(self.config, QdrantConfig) and self.config.vectors_config:
if isinstance(self.config.vectors_config, VectorParams):
collection_params["vectors_config"] = self.config.vectors_config
self._client.get_or_create_collection(**collection_params)
def query( def query(
self, self,
@@ -76,6 +83,8 @@ class CrewAIRagAdapter(Adapter):
if similarity_threshold is not None if similarity_threshold is not None
else self.similarity_threshold else self.similarity_threshold
) )
if self._client is None:
raise ValueError("Client is not initialized")
results: list[SearchResult] = self._client.search( results: list[SearchResult] = self._client.search(
collection_name=self.collection_name, collection_name=self.collection_name,
@@ -201,9 +210,10 @@ class CrewAIRagAdapter(Adapter):
if isinstance(arg, dict): if isinstance(arg, dict):
file_metadata.update(arg.get("metadata", {})) file_metadata.update(arg.get("metadata", {}))
chunk_id = hashlib.sha256( chunk_hash = hashlib.sha256(
f"{file_result.doc_id}_{chunk_idx}_{file_chunk}".encode() f"{file_result.doc_id}_{chunk_idx}_{file_chunk}".encode()
).hexdigest() ).hexdigest()
chunk_id = str(uuid.UUID(chunk_hash[:32]))
documents.append( documents.append(
{ {
@@ -251,9 +261,10 @@ class CrewAIRagAdapter(Adapter):
if isinstance(arg, dict): if isinstance(arg, dict):
chunk_metadata.update(arg.get("metadata", {})) chunk_metadata.update(arg.get("metadata", {}))
chunk_id = hashlib.sha256( chunk_hash = hashlib.sha256(
f"{loader_result.doc_id}_{i}_{chunk}".encode() f"{loader_result.doc_id}_{i}_{chunk}".encode()
).hexdigest() ).hexdigest()
chunk_id = str(uuid.UUID(chunk_hash[:32]))
documents.append( documents.append(
{ {
@@ -264,6 +275,8 @@ class CrewAIRagAdapter(Adapter):
) )
if documents: if documents:
if self._client is None:
raise ValueError("Client is not initialized")
self._client.add_documents( self._client.add_documents(
collection_name=self.collection_name, documents=documents collection_name=self.collection_name, documents=documents
) )

View File

@@ -4,12 +4,12 @@ from typing import Any
from uuid import uuid4 from uuid import uuid4
import chromadb import chromadb
import litellm
from pydantic import BaseModel, Field, PrivateAttr from pydantic import BaseModel, Field, PrivateAttr
from crewai_tools.rag.base_loader import BaseLoader from crewai_tools.rag.base_loader import BaseLoader
from crewai_tools.rag.chunkers.base_chunker import BaseChunker from crewai_tools.rag.chunkers.base_chunker import BaseChunker
from crewai_tools.rag.data_types import DataType from crewai_tools.rag.data_types import DataType
from crewai_tools.rag.embedding_service import EmbeddingService
from crewai_tools.rag.misc import compute_sha256 from crewai_tools.rag.misc import compute_sha256
from crewai_tools.rag.source_content import SourceContent from crewai_tools.rag.source_content import SourceContent
from crewai_tools.tools.rag.rag_tool import Adapter from crewai_tools.tools.rag.rag_tool import Adapter
@@ -18,31 +18,6 @@ from crewai_tools.tools.rag.rag_tool import Adapter
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class EmbeddingService:
def __init__(self, model: str = "text-embedding-3-small", **kwargs):
self.model = model
self.kwargs = kwargs
def embed_text(self, text: str) -> list[float]:
try:
response = litellm.embedding(model=self.model, input=[text], **self.kwargs)
return response.data[0]["embedding"]
except Exception as e:
logger.error(f"Error generating embedding: {e}")
raise
def embed_batch(self, texts: list[str]) -> list[list[float]]:
if not texts:
return []
try:
response = litellm.embedding(model=self.model, input=texts, **self.kwargs)
return [data["embedding"] for data in response.data]
except Exception as e:
logger.error(f"Error generating batch embeddings: {e}")
raise
class Document(BaseModel): class Document(BaseModel):
id: str = Field(default_factory=lambda: str(uuid4())) id: str = Field(default_factory=lambda: str(uuid4()))
content: str content: str
@@ -54,6 +29,7 @@ class Document(BaseModel):
class RAG(Adapter): class RAG(Adapter):
collection_name: str = "crewai_knowledge_base" collection_name: str = "crewai_knowledge_base"
persist_directory: str | None = None persist_directory: str | None = None
embedding_provider: str = "openai"
embedding_model: str = "text-embedding-3-large" embedding_model: str = "text-embedding-3-large"
summarize: bool = False summarize: bool = False
top_k: int = 5 top_k: int = 5
@@ -79,7 +55,9 @@ class RAG(Adapter):
) )
self._embedding_service = EmbeddingService( self._embedding_service = EmbeddingService(
model=self.embedding_model, **self.embedding_config provider=self.embedding_provider,
model=self.embedding_model,
**self.embedding_config,
) )
except Exception as e: except Exception as e:
logger.error(f"Failed to initialize ChromaDB: {e}") logger.error(f"Failed to initialize ChromaDB: {e}")
@@ -181,7 +159,7 @@ class RAG(Adapter):
except Exception as e: except Exception as e:
logger.error(f"Failed to add documents to ChromaDB: {e}") logger.error(f"Failed to add documents to ChromaDB: {e}")
def query(self, question: str, where: dict[str, Any] | None = None) -> str: def query(self, question: str, where: dict[str, Any] | None = None) -> str: # type: ignore
try: try:
question_embedding = self._embedding_service.embed_text(question) question_embedding = self._embedding_service.embed_text(question)

View File

@@ -0,0 +1,508 @@
"""
Enhanced embedding service that leverages CrewAI's existing embedding providers.
This replaces the litellm-based EmbeddingService with a more flexible architecture.
"""
import logging
import os
from typing import Any
from pydantic import BaseModel, Field
logger = logging.getLogger(__name__)
class EmbeddingConfig(BaseModel):
    """Configuration for embedding providers.

    Validated by pydantic on construction; ``EmbeddingService`` builds one of
    these from its constructor arguments and reads it when assembling the
    provider-specific factory config.
    """

    provider: str = Field(description="Embedding provider name")
    model: str = Field(description="Model name to use")
    api_key: str | None = Field(default=None, description="API key for the provider")
    # NOTE(review): timeout/max_retries are only forwarded to some providers
    # (e.g. voyageai in _build_provider_config); for others they are inert.
    timeout: float | None = Field(
        default=30.0, description="Request timeout in seconds"
    )
    max_retries: int = Field(default=3, description="Maximum number of retries")
    # Used by EmbeddingService.embed_batch to chunk large inputs.
    batch_size: int = Field(
        default=100, description="Batch size for processing multiple texts"
    )
    # Merged last into the provider config, so it can override base keys.
    extra_config: dict[str, Any] = Field(
        default_factory=dict, description="Additional provider-specific configuration"
    )
class EmbeddingService:
    """
    Enhanced embedding service that uses CrewAI's existing embedding providers.

    Supports multiple providers:
    - amazon-bedrock: Amazon Bedrock embeddings (amazon.titan-embed-text-v1, etc.)
    - azure: Azure OpenAI embeddings (text-embedding-ada-002, etc.)
    - cohere: Cohere embeddings (embed-english-v3.0, embed-multilingual-v3.0, etc.)
    - custom: Custom embeddings (embedding_callable, etc.)
    - google-generativeai: Google Gemini embeddings (models/embedding-001, etc.)
    - google-vertex: Google Vertex embeddings (models/embedding-001, etc.)
    - huggingface: Hugging Face embeddings (sentence-transformers/all-MiniLM-L6-v2, etc.)
    - instructor: Instructor embeddings (hkunlp/instructor-large, etc.)
    - jina: Jina embeddings (jina-embeddings-v2-base-en, etc.)
    - ollama: Ollama embeddings (nomic-embed-text, etc.)
    - onnx: ONNX embeddings (onnx-large-v2, etc.)
    - openai: OpenAI embeddings (text-embedding-3-small, text-embedding-3-large, etc.)
    - openclip: OpenClip embeddings (openclip-large-v2, etc.)
    - roboflow: Roboflow embeddings (roboflow-embeddings-v2-base-en, etc.)
    - sentence-transformer: Sentence Transformers embeddings (all-MiniLM-L6-v2, etc.)
    - text2vec: Text2Vec embeddings (text2vec-base-en, etc.)
    - voyageai: Voyage AI embeddings (voyage-2, voyage-large-2, etc.)
    - watsonx: Watson X embeddings (ibm/slate-125m-english-rtrvr, etc.)
    """

    # Providers whose factory config is {"api_key": ..., "model_name": ...}.
    _API_KEY_MODEL_NAME_PROVIDERS = frozenset(
        {
            "azure",
            "cohere",
            "google-generativeai",
            "google-vertex",
            "huggingface",
            "jina",
            "openai",
            "watsonx",
        }
    )
    # Local providers whose factory config is just {"model_name": ...}.
    _MODEL_NAME_ONLY_PROVIDERS = frozenset(
        {"instructor", "openclip", "sentence-transformer", "text2vec"}
    )
    # Providers configured entirely through ``extra_config``
    # (custom requires ``embedding_callable`` in extra_config).
    _EXTRA_CONFIG_ONLY_PROVIDERS = frozenset({"custom", "onnx"})

    def __init__(
        self,
        provider: str = "openai",
        model: str = "text-embedding-3-small",
        api_key: str | None = None,
        **kwargs: Any,
    ):
        """
        Initialize the embedding service.

        Args:
            provider: The embedding provider to use
            model: The model name
            api_key: API key (if not provided, will look for environment variables)
            **kwargs: Additional configuration options (see EmbeddingConfig)

        Raises:
            ImportError: If CrewAI's embedding factory is not importable.
            RuntimeError: If the embedding function cannot be initialized.
        """
        self.config = EmbeddingConfig(
            provider=provider,
            model=model,
            api_key=api_key or self._get_default_api_key(provider),
            **kwargs,
        )
        # Populated by _initialize_embedding_function; a ChromaDB-style
        # callable mapping list[str] -> list of embedding vectors.
        self._embedding_function: Any = None
        self._initialize_embedding_function()

    def _get_default_api_key(self, provider: str) -> str | None:
        """Get default API key from environment variables."""
        env_key_map = {
            "azure": "AZURE_OPENAI_API_KEY",
            "amazon-bedrock": "AWS_ACCESS_KEY_ID",  # or AWS_PROFILE
            "cohere": "COHERE_API_KEY",
            "google-generativeai": "GOOGLE_API_KEY",
            "google-vertex": "GOOGLE_APPLICATION_CREDENTIALS",
            "huggingface": "HUGGINGFACE_API_KEY",
            "jina": "JINA_API_KEY",
            "ollama": None,  # Ollama typically runs locally without API key
            "openai": "OPENAI_API_KEY",
            "roboflow": "ROBOFLOW_API_KEY",
            "voyageai": "VOYAGE_API_KEY",
            "watsonx": "WATSONX_API_KEY",
        }
        env_key = env_key_map.get(provider)
        if env_key:
            return os.getenv(env_key)
        return None

    def _initialize_embedding_function(self) -> None:
        """Initialize the embedding function using CrewAI's factory.

        Raises:
            ImportError: If crewai's embedding factory is unavailable.
            RuntimeError: If the factory fails to build an embedder.
        """
        try:
            # Imported lazily so this module can load without crewai installed.
            from crewai.rag.embeddings.factory import build_embedder

            # Build the configuration for CrewAI's factory
            config = self._build_provider_config()

            # Create the embedding function
            self._embedding_function = build_embedder(config)

            logger.info(
                f"Initialized {self.config.provider} embedding service with model "
                f"{self.config.model}"
            )
        except ImportError as e:
            raise ImportError(
                f"CrewAI embedding providers not available. "
                f"Make sure crewai is installed: {e}"
            ) from e
        except Exception as e:
            logger.error(f"Failed to initialize embedding function: {e}")
            raise RuntimeError(
                f"Failed to initialize {self.config.provider} embedding service: {e}"
            ) from e

    def _build_provider_config(self) -> dict[str, Any]:
        """Build configuration dictionary for CrewAI's embedding factory.

        Returns:
            ``{"provider": <name>, "config": <provider-specific dict>}``;
            ``extra_config`` is merged last so it can override base keys.
        """
        provider = self.config.provider
        config: dict[str, Any]

        if provider in self._API_KEY_MODEL_NAME_PROVIDERS:
            config = {
                "api_key": self.config.api_key,
                "model_name": self.config.model,
            }
        elif provider == "voyageai":
            # Voyage is the only provider that takes retry/timeout settings.
            config = {
                "api_key": self.config.api_key,
                "model": self.config.model,
                "max_retries": self.config.max_retries,
                "timeout": self.config.timeout,
            }
        elif provider == "amazon-bedrock":
            # Bedrock authenticates with AWS credentials, not a bearer key.
            config = {
                "aws_access_key_id": self.config.api_key,
                "model_name": self.config.model,
            }
        elif provider == "ollama":
            config = {"model": self.config.model}
        elif provider in self._MODEL_NAME_ONLY_PROVIDERS:
            config = {"model_name": self.config.model}
        elif provider in self._EXTRA_CONFIG_ONLY_PROVIDERS:
            config = {}
        elif provider == "roboflow":
            config = {"api_key": self.config.api_key}
        else:
            # Generic configuration for any unlisted providers
            config = {
                "api_key": self.config.api_key,
                "model": self.config.model,
            }

        config.update(self.config.extra_config)
        return {"provider": provider, "config": config}

    def embed_text(self, text: str) -> list[float]:
        """
        Generate embedding for a single text.

        Args:
            text: Text to embed

        Returns:
            List of floats representing the embedding; empty list for
            empty/whitespace-only input.

        Raises:
            RuntimeError: If embedding generation fails
        """
        if not text or not text.strip():
            logger.warning("Empty text provided for embedding")
            return []

        try:
            # Use ChromaDB's embedding function interface
            embeddings = self._embedding_function([text])
            return embeddings[0] if embeddings else []
        except Exception as e:
            logger.error(f"Error generating embedding for text: {e}")
            raise RuntimeError(f"Failed to generate embedding: {e}") from e

    def embed_batch(self, texts: list[str]) -> list[list[float]]:
        """
        Generate embeddings for multiple texts.

        Args:
            texts: List of texts to embed

        Returns:
            List of embedding vectors.

            NOTE: empty/whitespace-only texts are silently dropped, so the
            result may be shorter than the input and indices may not align —
            callers zipping texts with embeddings should pre-filter.

        Raises:
            RuntimeError: If embedding generation fails
        """
        if not texts:
            return []

        # Filter out empty texts
        valid_texts = [text for text in texts if text and text.strip()]
        if not valid_texts:
            logger.warning("No valid texts provided for batch embedding")
            return []

        try:
            # Process in batches to avoid API limits
            all_embeddings: list[list[float]] = []
            for i in range(0, len(valid_texts), self.config.batch_size):
                batch = valid_texts[i : i + self.config.batch_size]
                batch_embeddings = self._embedding_function(batch)
                all_embeddings.extend(batch_embeddings)
            return all_embeddings
        except Exception as e:
            logger.error(f"Error generating batch embeddings: {e}")
            raise RuntimeError(f"Failed to generate batch embeddings: {e}") from e

    def get_embedding_dimension(self) -> int | None:
        """
        Get the dimension of embeddings produced by this service.

        Returns:
            Embedding dimension or None if unknown
        """
        # Try to get dimension by generating a test embedding
        try:
            test_embedding = self.embed_text("test")
            return len(test_embedding) if test_embedding else None
        except Exception:
            logger.warning("Could not determine embedding dimension")
            return None

    def validate_connection(self) -> bool:
        """
        Validate that the embedding service is working correctly.

        Returns:
            True if the service is working, False otherwise
        """
        try:
            test_embedding = self.embed_text("test connection")
            return len(test_embedding) > 0
        except Exception as e:
            logger.error(f"Connection validation failed: {e}")
            return False

    def get_service_info(self) -> dict[str, Any]:
        """
        Get information about the current embedding service.

        Returns:
            Dictionary with service information. Note: computing
            ``embedding_dimension``/``is_connected`` issues live embedding
            calls.
        """
        return {
            "provider": self.config.provider,
            "model": self.config.model,
            "embedding_dimension": self.get_embedding_dimension(),
            "batch_size": self.config.batch_size,
            "is_connected": self.validate_connection(),
        }

    @classmethod
    def list_supported_providers(cls) -> list[str]:
        """
        List all supported embedding providers.

        Returns:
            List of supported provider names
        """
        return [
            "azure",
            "amazon-bedrock",
            "cohere",
            "custom",
            "google-generativeai",
            "google-vertex",
            "huggingface",
            "instructor",
            "jina",
            "ollama",
            "onnx",
            "openai",
            "openclip",
            "roboflow",
            "sentence-transformer",
            "text2vec",
            "voyageai",
            "watsonx",
        ]

    @classmethod
    def create_openai_service(
        cls,
        model: str = "text-embedding-3-small",
        api_key: str | None = None,
        **kwargs: Any,
    ) -> "EmbeddingService":
        """Create an OpenAI embedding service."""
        return cls(provider="openai", model=model, api_key=api_key, **kwargs)

    @classmethod
    def create_voyage_service(
        cls, model: str = "voyage-2", api_key: str | None = None, **kwargs: Any
    ) -> "EmbeddingService":
        """Create a Voyage AI embedding service."""
        return cls(provider="voyageai", model=model, api_key=api_key, **kwargs)

    @classmethod
    def create_cohere_service(
        cls,
        model: str = "embed-english-v3.0",
        api_key: str | None = None,
        **kwargs: Any,
    ) -> "EmbeddingService":
        """Create a Cohere embedding service."""
        return cls(provider="cohere", model=model, api_key=api_key, **kwargs)

    @classmethod
    def create_gemini_service(
        cls,
        model: str = "models/embedding-001",
        api_key: str | None = None,
        **kwargs: Any,
    ) -> "EmbeddingService":
        """Create a Google Gemini embedding service."""
        return cls(
            provider="google-generativeai", model=model, api_key=api_key, **kwargs
        )

    @classmethod
    def create_azure_service(
        cls,
        model: str = "text-embedding-ada-002",
        api_key: str | None = None,
        **kwargs: Any,
    ) -> "EmbeddingService":
        """Create an Azure OpenAI embedding service."""
        return cls(provider="azure", model=model, api_key=api_key, **kwargs)

    @classmethod
    def create_bedrock_service(
        cls,
        model: str = "amazon.titan-embed-text-v1",
        api_key: str | None = None,
        **kwargs: Any,
    ) -> "EmbeddingService":
        """Create an Amazon Bedrock embedding service."""
        return cls(provider="amazon-bedrock", model=model, api_key=api_key, **kwargs)

    @classmethod
    def create_huggingface_service(
        cls,
        model: str = "sentence-transformers/all-MiniLM-L6-v2",
        api_key: str | None = None,
        **kwargs: Any,
    ) -> "EmbeddingService":
        """Create a Hugging Face embedding service."""
        return cls(provider="huggingface", model=model, api_key=api_key, **kwargs)

    @classmethod
    def create_sentence_transformer_service(
        cls,
        model: str = "all-MiniLM-L6-v2",
        **kwargs: Any,
    ) -> "EmbeddingService":
        """Create a Sentence Transformers embedding service (local)."""
        return cls(provider="sentence-transformer", model=model, **kwargs)

    @classmethod
    def create_ollama_service(
        cls,
        model: str = "nomic-embed-text",
        **kwargs: Any,
    ) -> "EmbeddingService":
        """Create an Ollama embedding service (local)."""
        return cls(provider="ollama", model=model, **kwargs)

    @classmethod
    def create_jina_service(
        cls,
        model: str = "jina-embeddings-v2-base-en",
        api_key: str | None = None,
        **kwargs: Any,
    ) -> "EmbeddingService":
        """Create a Jina AI embedding service."""
        return cls(provider="jina", model=model, api_key=api_key, **kwargs)

    @classmethod
    def create_instructor_service(
        cls,
        model: str = "hkunlp/instructor-large",
        **kwargs: Any,
    ) -> "EmbeddingService":
        """Create an Instructor embedding service."""
        return cls(provider="instructor", model=model, **kwargs)

    @classmethod
    def create_watsonx_service(
        cls,
        model: str = "ibm/slate-125m-english-rtrvr",
        api_key: str | None = None,
        **kwargs: Any,
    ) -> "EmbeddingService":
        """Create a Watson X embedding service."""
        return cls(provider="watsonx", model=model, api_key=api_key, **kwargs)

    @classmethod
    def create_custom_service(
        cls,
        embedding_callable: Any,
        **kwargs: Any,
    ) -> "EmbeddingService":
        """Create a custom embedding service with your own embedding function."""
        return cls(
            provider="custom",
            model="custom",
            extra_config={"embedding_callable": embedding_callable},
            **kwargs,
        )

View File

@@ -0,0 +1,342 @@
"""
Tests for the enhanced embedding service.
"""
import os
import pytest
from unittest.mock import Mock, patch
from crewai_tools.rag.embedding_service import EmbeddingService, EmbeddingConfig
class TestEmbeddingConfig:
    """Test the EmbeddingConfig model."""

    def test_default_config(self):
        """Test default configuration values."""
        cfg = EmbeddingConfig(provider="openai", model="text-embedding-3-small")

        assert cfg.provider == "openai"
        assert cfg.model == "text-embedding-3-small"
        assert cfg.api_key is None
        assert cfg.timeout == 30.0
        assert cfg.max_retries == 3
        assert cfg.batch_size == 100
        assert cfg.extra_config == {}

    def test_custom_config(self):
        """Test custom configuration values."""
        # Every field overridden; checked field-by-field below.
        overrides = {
            "provider": "voyageai",
            "model": "voyage-2",
            "api_key": "test-key",
            "timeout": 60.0,
            "max_retries": 5,
            "batch_size": 50,
            "extra_config": {"input_type": "document"},
        }
        cfg = EmbeddingConfig(**overrides)

        for field_name, expected in overrides.items():
            assert getattr(cfg, field_name) == expected
class TestEmbeddingService:
    """Test the EmbeddingService class.

    All tests patch ``crewai.rag.embeddings.factory.build_embedder`` so no
    real provider SDK or network access is needed.
    """

    def test_list_supported_providers(self):
        """Test listing supported providers."""
        providers = EmbeddingService.list_supported_providers()

        # Spot-check a representative subset rather than the exact list,
        # so adding a provider does not break this test.
        expected_providers = [
            "openai", "azure", "voyageai", "cohere", "google-generativeai",
            "amazon-bedrock", "huggingface", "jina", "ollama", "sentence-transformer",
            "instructor", "watsonx", "custom"
        ]

        assert isinstance(providers, list)
        assert len(providers) >= 15  # Should have at least 15 providers
        assert all(provider in providers for provider in expected_providers)

    def test_get_default_api_key(self):
        """Test getting default API keys from environment."""
        # __new__ skips __init__, which would otherwise try to build a
        # real embedder; _get_default_api_key does not touch instance state.
        service = EmbeddingService.__new__(EmbeddingService)  # Create without __init__

        # Test with environment variable set
        with patch.dict(os.environ, {"OPENAI_API_KEY": "test-openai-key"}):
            api_key = service._get_default_api_key("openai")
            assert api_key == "test-openai-key"

        # Test with no environment variable
        with patch.dict(os.environ, {}, clear=True):
            api_key = service._get_default_api_key("openai")
            assert api_key is None

        # Test unknown provider
        api_key = service._get_default_api_key("unknown-provider")
        assert api_key is None

    @patch('crewai.rag.embeddings.factory.build_embedder')
    def test_initialization_success(self, mock_build_embedder):
        """Test successful initialization."""
        # Mock the embedding function
        mock_embedding_function = Mock()
        mock_build_embedder.return_value = mock_embedding_function

        service = EmbeddingService(
            provider="openai",
            model="text-embedding-3-small",
            api_key="test-key"
        )

        assert service.config.provider == "openai"
        assert service.config.model == "text-embedding-3-small"
        assert service.config.api_key == "test-key"
        assert service._embedding_function == mock_embedding_function

        # Verify build_embedder was called with correct config
        mock_build_embedder.assert_called_once()
        call_args = mock_build_embedder.call_args[0][0]
        assert call_args["provider"] == "openai"
        assert call_args["config"]["api_key"] == "test-key"
        assert call_args["config"]["model_name"] == "text-embedding-3-small"

    @patch('crewai.rag.embeddings.factory.build_embedder')
    def test_initialization_import_error(self, mock_build_embedder):
        """Test initialization with import error."""
        # The service should re-raise ImportError with a friendlier message.
        mock_build_embedder.side_effect = ImportError("CrewAI not installed")

        with pytest.raises(ImportError, match="CrewAI embedding providers not available"):
            EmbeddingService(provider="openai", model="test-model", api_key="test-key")

    @patch('crewai.rag.embeddings.factory.build_embedder')
    def test_embed_text_success(self, mock_build_embedder):
        """Test successful text embedding."""
        # Mock the embedding function
        mock_embedding_function = Mock()
        mock_embedding_function.return_value = [[0.1, 0.2, 0.3]]
        mock_build_embedder.return_value = mock_embedding_function

        service = EmbeddingService(provider="openai", model="test-model", api_key="test-key")
        result = service.embed_text("test text")

        # embed_text wraps the text in a one-element batch and unwraps the result.
        assert result == [0.1, 0.2, 0.3]
        mock_embedding_function.assert_called_once_with(["test text"])

    @patch('crewai.rag.embeddings.factory.build_embedder')
    def test_embed_text_empty_input(self, mock_build_embedder):
        """Test embedding empty text."""
        mock_embedding_function = Mock()
        mock_build_embedder.return_value = mock_embedding_function

        service = EmbeddingService(provider="openai", model="test-model", api_key="test-key")

        result = service.embed_text("")
        assert result == []

        result = service.embed_text("   ")
        assert result == []

        # Embedding function should not be called for empty text
        mock_embedding_function.assert_not_called()

    @patch('crewai.rag.embeddings.factory.build_embedder')
    def test_embed_batch_success(self, mock_build_embedder):
        """Test successful batch embedding."""
        # Mock the embedding function
        mock_embedding_function = Mock()
        mock_embedding_function.return_value = [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]
        mock_build_embedder.return_value = mock_embedding_function

        service = EmbeddingService(provider="openai", model="test-model", api_key="test-key")
        texts = ["text1", "text2", "text3"]
        result = service.embed_batch(texts)

        # Three texts fit in one default-size batch, so exactly one call.
        assert result == [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]
        mock_embedding_function.assert_called_once_with(texts)

    @patch('crewai.rag.embeddings.factory.build_embedder')
    def test_embed_batch_empty_input(self, mock_build_embedder):
        """Test batch embedding with empty input."""
        mock_embedding_function = Mock()
        mock_build_embedder.return_value = mock_embedding_function

        service = EmbeddingService(provider="openai", model="test-model", api_key="test-key")

        # Empty list
        result = service.embed_batch([])
        assert result == []

        # List with empty strings
        result = service.embed_batch(["", "  ", ""])
        assert result == []

        # Embedding function should not be called for empty input
        mock_embedding_function.assert_not_called()

    @patch('crewai.rag.embeddings.factory.build_embedder')
    def test_validate_connection(self, mock_build_embedder):
        """Test connection validation."""
        # Mock successful embedding
        mock_embedding_function = Mock()
        mock_embedding_function.return_value = [[0.1, 0.2, 0.3]]
        mock_build_embedder.return_value = mock_embedding_function

        service = EmbeddingService(provider="openai", model="test-model", api_key="test-key")
        assert service.validate_connection() is True

        # Mock failed embedding: validate_connection swallows the error
        # and reports False instead of raising.
        mock_embedding_function.side_effect = Exception("Connection failed")
        assert service.validate_connection() is False

    @patch('crewai.rag.embeddings.factory.build_embedder')
    def test_get_service_info(self, mock_build_embedder):
        """Test getting service information."""
        # Mock the embedding function
        mock_embedding_function = Mock()
        mock_embedding_function.return_value = [[0.1, 0.2, 0.3]]
        mock_build_embedder.return_value = mock_embedding_function

        service = EmbeddingService(provider="openai", model="test-model", api_key="test-key")
        info = service.get_service_info()

        assert info["provider"] == "openai"
        assert info["model"] == "test-model"
        # Dimension is probed via a live embed call; the mock returns 3 floats.
        assert info["embedding_dimension"] == 3
        assert info["batch_size"] == 100
        assert info["is_connected"] is True

    def test_create_openai_service(self):
        """Test OpenAI service creation."""
        with patch('crewai.rag.embeddings.factory.build_embedder'):
            service = EmbeddingService.create_openai_service(
                model="text-embedding-3-large",
                api_key="test-key"
            )

            assert service.config.provider == "openai"
            assert service.config.model == "text-embedding-3-large"
            assert service.config.api_key == "test-key"

    def test_create_voyage_service(self):
        """Test Voyage AI service creation."""
        with patch('crewai.rag.embeddings.factory.build_embedder'):
            service = EmbeddingService.create_voyage_service(
                model="voyage-large-2",
                api_key="test-key"
            )

            assert service.config.provider == "voyageai"
            assert service.config.model == "voyage-large-2"
            assert service.config.api_key == "test-key"

    def test_create_cohere_service(self):
        """Test Cohere service creation."""
        with patch('crewai.rag.embeddings.factory.build_embedder'):
            service = EmbeddingService.create_cohere_service(
                model="embed-multilingual-v3.0",
                api_key="test-key"
            )

            assert service.config.provider == "cohere"
            assert service.config.model == "embed-multilingual-v3.0"
            assert service.config.api_key == "test-key"

    def test_create_gemini_service(self):
        """Test Gemini service creation."""
        with patch('crewai.rag.embeddings.factory.build_embedder'):
            service = EmbeddingService.create_gemini_service(
                model="models/embedding-001",
                api_key="test-key"
            )

            assert service.config.provider == "google-generativeai"
            assert service.config.model == "models/embedding-001"
            assert service.config.api_key == "test-key"
class TestProviderConfigurations:
    """Test provider-specific configurations.

    Each test constructs an EmbeddingService with a patched factory and
    inspects the config dict passed to ``build_embedder``. The constructed
    service itself is irrelevant, so it is intentionally not bound to a name
    (fixes unused-variable warnings).
    """

    @patch('crewai.rag.embeddings.factory.build_embedder')
    def test_openai_config(self, mock_build_embedder):
        """Test OpenAI configuration mapping."""
        mock_build_embedder.return_value = Mock()

        EmbeddingService(
            provider="openai",
            model="text-embedding-3-small",
            api_key="test-key",
            extra_config={"dimensions": 1024}
        )

        # Check the configuration passed to build_embedder
        call_args = mock_build_embedder.call_args[0][0]
        assert call_args["provider"] == "openai"
        assert call_args["config"]["api_key"] == "test-key"
        assert call_args["config"]["model_name"] == "text-embedding-3-small"
        # extra_config entries are merged into the provider config.
        assert call_args["config"]["dimensions"] == 1024

    @patch('crewai.rag.embeddings.factory.build_embedder')
    def test_voyageai_config(self, mock_build_embedder):
        """Test Voyage AI configuration mapping."""
        mock_build_embedder.return_value = Mock()

        EmbeddingService(
            provider="voyageai",
            model="voyage-2",
            api_key="test-key",
            timeout=60.0,
            max_retries=5,
            extra_config={"input_type": "document"}
        )

        # Check the configuration passed to build_embedder; voyageai also
        # forwards timeout and max_retries.
        call_args = mock_build_embedder.call_args[0][0]
        assert call_args["provider"] == "voyageai"
        assert call_args["config"]["api_key"] == "test-key"
        assert call_args["config"]["model"] == "voyage-2"
        assert call_args["config"]["timeout"] == 60.0
        assert call_args["config"]["max_retries"] == 5
        assert call_args["config"]["input_type"] == "document"

    @patch('crewai.rag.embeddings.factory.build_embedder')
    def test_cohere_config(self, mock_build_embedder):
        """Test Cohere configuration mapping."""
        mock_build_embedder.return_value = Mock()

        EmbeddingService(
            provider="cohere",
            model="embed-english-v3.0",
            api_key="test-key"
        )

        # Check the configuration passed to build_embedder
        call_args = mock_build_embedder.call_args[0][0]
        assert call_args["provider"] == "cohere"
        assert call_args["config"]["api_key"] == "test-key"
        assert call_args["config"]["model_name"] == "embed-english-v3.0"

    @patch('crewai.rag.embeddings.factory.build_embedder')
    def test_gemini_config(self, mock_build_embedder):
        """Test Gemini configuration mapping."""
        mock_build_embedder.return_value = Mock()

        EmbeddingService(
            provider="google-generativeai",
            model="models/embedding-001",
            api_key="test-key"
        )

        # Check the configuration passed to build_embedder
        call_args = mock_build_embedder.call_args[0][0]
        assert call_args["provider"] == "google-generativeai"
        assert call_args["config"]["api_key"] == "test-key"
        assert call_args["config"]["model_name"] == "models/embedding-001"

View File

@@ -4,6 +4,7 @@ from dataclasses import field
from typing import Literal, cast from typing import Literal, cast
from pydantic.dataclasses import dataclass as pyd_dataclass from pydantic.dataclasses import dataclass as pyd_dataclass
from qdrant_client.models import VectorParams
from crewai.rag.config.base import BaseRagConfig from crewai.rag.config.base import BaseRagConfig
from crewai.rag.qdrant.constants import DEFAULT_EMBEDDING_MODEL, DEFAULT_STORAGE_PATH from crewai.rag.qdrant.constants import DEFAULT_EMBEDDING_MODEL, DEFAULT_STORAGE_PATH
@@ -53,3 +54,4 @@ class QdrantConfig(BaseRagConfig):
embedding_function: QdrantEmbeddingFunctionWrapper = field( embedding_function: QdrantEmbeddingFunctionWrapper = field(
default_factory=_default_embedding_function default_factory=_default_embedding_function
) )
vectors_config: VectorParams | None = field(default=None)

View File

@@ -4,8 +4,8 @@ import asyncio
from typing import TypeGuard from typing import TypeGuard
from uuid import uuid4 from uuid import uuid4
from qdrant_client import AsyncQdrantClient # type: ignore[import-not-found]
from qdrant_client import ( from qdrant_client import (
AsyncQdrantClient, # type: ignore[import-not-found]
QdrantClient as SyncQdrantClient, # type: ignore[import-not-found] QdrantClient as SyncQdrantClient, # type: ignore[import-not-found]
) )
from qdrant_client.models import ( # type: ignore[import-not-found] from qdrant_client.models import ( # type: ignore[import-not-found]

2
uv.lock generated
View File

@@ -1124,7 +1124,6 @@ dependencies = [
{ name = "python-docx" }, { name = "python-docx" },
{ name = "pytube" }, { name = "pytube" },
{ name = "requests" }, { name = "requests" },
{ name = "stagehand" },
{ name = "tiktoken" }, { name = "tiktoken" },
{ name = "youtube-transcript-api" }, { name = "youtube-transcript-api" },
] ]
@@ -1295,7 +1294,6 @@ requires-dist = [
{ name = "spider-client", marker = "extra == 'spider-client'", specifier = ">=0.1.25" }, { name = "spider-client", marker = "extra == 'spider-client'", specifier = ">=0.1.25" },
{ name = "sqlalchemy", marker = "extra == 'singlestore'", specifier = ">=2.0.40" }, { name = "sqlalchemy", marker = "extra == 'singlestore'", specifier = ">=2.0.40" },
{ name = "sqlalchemy", marker = "extra == 'sqlalchemy'", specifier = ">=2.0.35" }, { name = "sqlalchemy", marker = "extra == 'sqlalchemy'", specifier = ">=2.0.35" },
{ name = "stagehand", specifier = ">=0.4.1" },
{ name = "stagehand", marker = "extra == 'stagehand'", specifier = ">=0.4.1" }, { name = "stagehand", marker = "extra == 'stagehand'", specifier = ">=0.4.1" },
{ name = "tavily-python", marker = "extra == 'tavily-python'", specifier = ">=0.5.4" }, { name = "tavily-python", marker = "extra == 'tavily-python'", specifier = ">=0.5.4" },
{ name = "tiktoken", specifier = ">=0.8.0" }, { name = "tiktoken", specifier = ">=0.8.0" },