Compare commits

...

3 Commits

Author SHA1 Message Date
Lorenze Jay
7351e4b0ef feat: bump versions to 1.0.0b2 (#3713) 2025-10-16 08:27:57 -07:00
Lorenze Jay
d9b68ddd85 moved stagehand as optional dep (#3712) 2025-10-15 15:50:59 -07:00
Lorenze Jay
2d5ad7a187 Lorenze/tools drop litellm (#3710)
* completely drop litellm and correctly pass config for qdrant

* feat: add support for additional embedding models in EmbeddingService

- Expanded the list of supported embedding models to include Google Vertex, Hugging Face, Jina, Ollama, OpenAI, Roboflow, Watson X, custom embeddings, Sentence Transformers, Text2Vec, OpenClip, and Instructor.
- This enhancement improves the versatility of the EmbeddingService by allowing integration with a wider range of embedding providers.

* fix: update collection parameter handling in CrewAIRagAdapter

- Changed the condition for setting vectors_config in the CrewAIRagAdapter to check for QdrantConfig instance instead of using hasattr. This improves type safety and ensures proper configuration handling for Qdrant integration.
2025-10-15 15:34:44 -07:00
12 changed files with 880 additions and 40 deletions

View File

@@ -12,10 +12,9 @@ dependencies = [
"pytube>=15.0.0",
"requests>=2.32.5",
"docker>=7.1.0",
"crewai==1.0.0b1",
"crewai==1.0.0b2",
"lancedb>=0.5.4",
"tiktoken>=0.8.0",
"stagehand>=0.4.1",
"beautifulsoup4>=4.13.4",
"pypdf>=5.9.0",
"python-docx>=1.2.0",

View File

@@ -291,4 +291,4 @@ __all__ = [
"ZapierActionTools",
]
__version__ = "1.0.0b1"
__version__ = "1.0.0b2"

View File

@@ -3,13 +3,16 @@
import hashlib
from pathlib import Path
from typing import Any, TypeAlias, TypedDict
import uuid
from crewai.rag.config.types import RagConfigType
from crewai.rag.config.utils import get_rag_client
from crewai.rag.core.base_client import BaseClient
from crewai.rag.factory import create_client
from crewai.rag.qdrant.config import QdrantConfig
from crewai.rag.types import BaseRecord, SearchResult
from pydantic import PrivateAttr
from qdrant_client.models import VectorParams
from typing_extensions import Unpack
from crewai_tools.rag.data_types import DataType
@@ -52,7 +55,11 @@ class CrewAIRagAdapter(Adapter):
self._client = create_client(self.config)
else:
self._client = get_rag_client()
self._client.get_or_create_collection(collection_name=self.collection_name)
collection_params: dict[str, Any] = {"collection_name": self.collection_name}
if isinstance(self.config, QdrantConfig) and self.config.vectors_config:
if isinstance(self.config.vectors_config, VectorParams):
collection_params["vectors_config"] = self.config.vectors_config
self._client.get_or_create_collection(**collection_params)
def query(
self,
@@ -76,6 +83,8 @@ class CrewAIRagAdapter(Adapter):
if similarity_threshold is not None
else self.similarity_threshold
)
if self._client is None:
raise ValueError("Client is not initialized")
results: list[SearchResult] = self._client.search(
collection_name=self.collection_name,
@@ -201,9 +210,10 @@ class CrewAIRagAdapter(Adapter):
if isinstance(arg, dict):
file_metadata.update(arg.get("metadata", {}))
chunk_id = hashlib.sha256(
chunk_hash = hashlib.sha256(
f"{file_result.doc_id}_{chunk_idx}_{file_chunk}".encode()
).hexdigest()
chunk_id = str(uuid.UUID(chunk_hash[:32]))
documents.append(
{
@@ -251,9 +261,10 @@ class CrewAIRagAdapter(Adapter):
if isinstance(arg, dict):
chunk_metadata.update(arg.get("metadata", {}))
chunk_id = hashlib.sha256(
chunk_hash = hashlib.sha256(
f"{loader_result.doc_id}_{i}_{chunk}".encode()
).hexdigest()
chunk_id = str(uuid.UUID(chunk_hash[:32]))
documents.append(
{
@@ -264,6 +275,8 @@ class CrewAIRagAdapter(Adapter):
)
if documents:
if self._client is None:
raise ValueError("Client is not initialized")
self._client.add_documents(
collection_name=self.collection_name, documents=documents
)

View File

@@ -4,12 +4,12 @@ from typing import Any
from uuid import uuid4
import chromadb
import litellm
from pydantic import BaseModel, Field, PrivateAttr
from crewai_tools.rag.base_loader import BaseLoader
from crewai_tools.rag.chunkers.base_chunker import BaseChunker
from crewai_tools.rag.data_types import DataType
from crewai_tools.rag.embedding_service import EmbeddingService
from crewai_tools.rag.misc import compute_sha256
from crewai_tools.rag.source_content import SourceContent
from crewai_tools.tools.rag.rag_tool import Adapter
@@ -18,31 +18,6 @@ from crewai_tools.tools.rag.rag_tool import Adapter
logger = logging.getLogger(__name__)
class EmbeddingService:
def __init__(self, model: str = "text-embedding-3-small", **kwargs):
self.model = model
self.kwargs = kwargs
def embed_text(self, text: str) -> list[float]:
try:
response = litellm.embedding(model=self.model, input=[text], **self.kwargs)
return response.data[0]["embedding"]
except Exception as e:
logger.error(f"Error generating embedding: {e}")
raise
def embed_batch(self, texts: list[str]) -> list[list[float]]:
if not texts:
return []
try:
response = litellm.embedding(model=self.model, input=texts, **self.kwargs)
return [data["embedding"] for data in response.data]
except Exception as e:
logger.error(f"Error generating batch embeddings: {e}")
raise
class Document(BaseModel):
id: str = Field(default_factory=lambda: str(uuid4()))
content: str
@@ -54,6 +29,7 @@ class Document(BaseModel):
class RAG(Adapter):
collection_name: str = "crewai_knowledge_base"
persist_directory: str | None = None
embedding_provider: str = "openai"
embedding_model: str = "text-embedding-3-large"
summarize: bool = False
top_k: int = 5
@@ -79,7 +55,9 @@ class RAG(Adapter):
)
self._embedding_service = EmbeddingService(
model=self.embedding_model, **self.embedding_config
provider=self.embedding_provider,
model=self.embedding_model,
**self.embedding_config,
)
except Exception as e:
logger.error(f"Failed to initialize ChromaDB: {e}")
@@ -181,7 +159,7 @@ class RAG(Adapter):
except Exception as e:
logger.error(f"Failed to add documents to ChromaDB: {e}")
def query(self, question: str, where: dict[str, Any] | None = None) -> str:
def query(self, question: str, where: dict[str, Any] | None = None) -> str: # type: ignore
try:
question_embedding = self._embedding_service.embed_text(question)

View File

@@ -0,0 +1,508 @@
"""
Enhanced embedding service that leverages CrewAI's existing embedding providers.
This replaces the litellm-based EmbeddingService with a more flexible architecture.
"""
import logging
import os
from typing import Any
from pydantic import BaseModel, Field
logger = logging.getLogger(__name__)
class EmbeddingConfig(BaseModel):
"""Configuration for embedding providers."""
provider: str = Field(description="Embedding provider name")
model: str = Field(description="Model name to use")
api_key: str | None = Field(default=None, description="API key for the provider")
timeout: float | None = Field(
default=30.0, description="Request timeout in seconds"
)
max_retries: int = Field(default=3, description="Maximum number of retries")
batch_size: int = Field(
default=100, description="Batch size for processing multiple texts"
)
extra_config: dict[str, Any] = Field(
default_factory=dict, description="Additional provider-specific configuration"
)
class EmbeddingService:
"""
Enhanced embedding service that uses CrewAI's existing embedding providers.
Supports multiple providers:
- openai: OpenAI embeddings (text-embedding-3-small, text-embedding-3-large, etc.)
- voyageai: Voyage AI embeddings (voyage-2, voyage-large-2, etc.)
- cohere: Cohere embeddings (embed-english-v3.0, embed-multilingual-v3.0, etc.)
- google-generativeai: Google Gemini embeddings (models/embedding-001, etc.)
- google-vertex: Google Vertex embeddings (models/embedding-001, etc.)
- huggingface: Hugging Face embeddings (sentence-transformers/all-MiniLM-L6-v2, etc.)
- jina: Jina embeddings (jina-embeddings-v2-base-en, etc.)
- ollama: Ollama embeddings (nomic-embed-text, etc.)
- openai: OpenAI embeddings (text-embedding-3-small, text-embedding-3-large, etc.)
- roboflow: Roboflow embeddings (roboflow-embeddings-v2-base-en, etc.)
- voyageai: Voyage AI embeddings (voyage-2, voyage-large-2, etc.)
- watsonx: Watson X embeddings (ibm/slate-125m-english-rtrvr, etc.)
- custom: Custom embeddings (embedding_callable, etc.)
- sentence-transformer: Sentence Transformers embeddings (all-MiniLM-L6-v2, etc.)
- text2vec: Text2Vec embeddings (text2vec-base-en, etc.)
- openclip: OpenClip embeddings (openclip-large-v2, etc.)
- instructor: Instructor embeddings (hkunlp/instructor-large, etc.)
- onnx: ONNX embeddings (onnx-large-v2, etc.)
"""
def __init__(
self,
provider: str = "openai",
model: str = "text-embedding-3-small",
api_key: str | None = None,
**kwargs: Any,
):
"""
Initialize the embedding service.
Args:
provider: The embedding provider to use
model: The model name
api_key: API key (if not provided, will look for environment variables)
**kwargs: Additional configuration options
"""
self.config = EmbeddingConfig(
provider=provider,
model=model,
api_key=api_key or self._get_default_api_key(provider),
**kwargs,
)
self._embedding_function = None
self._initialize_embedding_function()
def _get_default_api_key(self, provider: str) -> str | None:
"""Get default API key from environment variables."""
env_key_map = {
"azure": "AZURE_OPENAI_API_KEY",
"amazon-bedrock": "AWS_ACCESS_KEY_ID", # or AWS_PROFILE
"cohere": "COHERE_API_KEY",
"google-generativeai": "GOOGLE_API_KEY",
"google-vertex": "GOOGLE_APPLICATION_CREDENTIALS",
"huggingface": "HUGGINGFACE_API_KEY",
"jina": "JINA_API_KEY",
"ollama": None, # Ollama typically runs locally without API key
"openai": "OPENAI_API_KEY",
"roboflow": "ROBOFLOW_API_KEY",
"voyageai": "VOYAGE_API_KEY",
"watsonx": "WATSONX_API_KEY",
}
env_key = env_key_map.get(provider)
if env_key:
return os.getenv(env_key)
return None
def _initialize_embedding_function(self):
"""Initialize the embedding function using CrewAI's factory."""
try:
from crewai.rag.embeddings.factory import build_embedder
# Build the configuration for CrewAI's factory
config = self._build_provider_config()
# Create the embedding function
self._embedding_function = build_embedder(config)
logger.info(
f"Initialized {self.config.provider} embedding service with model "
f"{self.config.model}"
)
except ImportError as e:
raise ImportError(
f"CrewAI embedding providers not available. "
f"Make sure crewai is installed: {e}"
) from e
except Exception as e:
logger.error(f"Failed to initialize embedding function: {e}")
raise RuntimeError(
f"Failed to initialize {self.config.provider} embedding service: {e}"
) from e
def _build_provider_config(self) -> dict[str, Any]:
"""Build configuration dictionary for CrewAI's embedding factory."""
base_config = {"provider": self.config.provider, "config": {}}
# Provider-specific configuration mapping
if self.config.provider == "openai":
base_config["config"] = {
"api_key": self.config.api_key,
"model_name": self.config.model,
**self.config.extra_config,
}
elif self.config.provider == "azure":
base_config["config"] = {
"api_key": self.config.api_key,
"model_name": self.config.model,
**self.config.extra_config,
}
elif self.config.provider == "voyageai":
base_config["config"] = {
"api_key": self.config.api_key,
"model": self.config.model,
"max_retries": self.config.max_retries,
"timeout": self.config.timeout,
**self.config.extra_config,
}
elif self.config.provider == "cohere":
base_config["config"] = {
"api_key": self.config.api_key,
"model_name": self.config.model,
**self.config.extra_config,
}
elif self.config.provider in ["google-generativeai", "google-vertex"]:
base_config["config"] = {
"api_key": self.config.api_key,
"model_name": self.config.model,
**self.config.extra_config,
}
elif self.config.provider == "amazon-bedrock":
base_config["config"] = {
"aws_access_key_id": self.config.api_key,
"model_name": self.config.model,
**self.config.extra_config,
}
elif self.config.provider == "huggingface":
base_config["config"] = {
"api_key": self.config.api_key,
"model_name": self.config.model,
**self.config.extra_config,
}
elif self.config.provider == "jina":
base_config["config"] = {
"api_key": self.config.api_key,
"model_name": self.config.model,
**self.config.extra_config,
}
elif self.config.provider == "ollama":
base_config["config"] = {
"model": self.config.model,
**self.config.extra_config,
}
elif self.config.provider == "sentence-transformer":
base_config["config"] = {
"model_name": self.config.model,
**self.config.extra_config,
}
elif self.config.provider == "instructor":
base_config["config"] = {
"model_name": self.config.model,
**self.config.extra_config,
}
elif self.config.provider == "onnx":
base_config["config"] = {
**self.config.extra_config,
}
elif self.config.provider == "roboflow":
base_config["config"] = {
"api_key": self.config.api_key,
**self.config.extra_config,
}
elif self.config.provider == "openclip":
base_config["config"] = {
"model_name": self.config.model,
**self.config.extra_config,
}
elif self.config.provider == "text2vec":
base_config["config"] = {
"model_name": self.config.model,
**self.config.extra_config,
}
elif self.config.provider == "watsonx":
base_config["config"] = {
"api_key": self.config.api_key,
"model_name": self.config.model,
**self.config.extra_config,
}
elif self.config.provider == "custom":
# Custom provider requires embedding_callable in extra_config
base_config["config"] = {
**self.config.extra_config,
}
else:
# Generic configuration for any unlisted providers
base_config["config"] = {
"api_key": self.config.api_key,
"model": self.config.model,
**self.config.extra_config,
}
return base_config
def embed_text(self, text: str) -> list[float]:
"""
Generate embedding for a single text.
Args:
text: Text to embed
Returns:
List of floats representing the embedding
Raises:
RuntimeError: If embedding generation fails
"""
if not text or not text.strip():
logger.warning("Empty text provided for embedding")
return []
try:
# Use ChromaDB's embedding function interface
embeddings = self._embedding_function([text]) # type: ignore
return embeddings[0] if embeddings else []
except Exception as e:
logger.error(f"Error generating embedding for text: {e}")
raise RuntimeError(f"Failed to generate embedding: {e}") from e
def embed_batch(self, texts: list[str]) -> list[list[float]]:
"""
Generate embeddings for multiple texts.
Args:
texts: List of texts to embed
Returns:
List of embedding vectors
Raises:
RuntimeError: If embedding generation fails
"""
if not texts:
return []
# Filter out empty texts
valid_texts = [text for text in texts if text and text.strip()]
if not valid_texts:
logger.warning("No valid texts provided for batch embedding")
return []
try:
# Process in batches to avoid API limits
all_embeddings = []
for i in range(0, len(valid_texts), self.config.batch_size):
batch = valid_texts[i : i + self.config.batch_size]
batch_embeddings = self._embedding_function(batch) # type: ignore
all_embeddings.extend(batch_embeddings)
return all_embeddings
except Exception as e:
logger.error(f"Error generating batch embeddings: {e}")
raise RuntimeError(f"Failed to generate batch embeddings: {e}") from e
def get_embedding_dimension(self) -> int | None:
"""
Get the dimension of embeddings produced by this service.
Returns:
Embedding dimension or None if unknown
"""
# Try to get dimension by generating a test embedding
try:
test_embedding = self.embed_text("test")
return len(test_embedding) if test_embedding else None
except Exception:
logger.warning("Could not determine embedding dimension")
return None
def validate_connection(self) -> bool:
"""
Validate that the embedding service is working correctly.
Returns:
True if the service is working, False otherwise
"""
try:
test_embedding = self.embed_text("test connection")
return len(test_embedding) > 0
except Exception as e:
logger.error(f"Connection validation failed: {e}")
return False
def get_service_info(self) -> dict[str, Any]:
"""
Get information about the current embedding service.
Returns:
Dictionary with service information
"""
return {
"provider": self.config.provider,
"model": self.config.model,
"embedding_dimension": self.get_embedding_dimension(),
"batch_size": self.config.batch_size,
"is_connected": self.validate_connection(),
}
@classmethod
def list_supported_providers(cls) -> list[str]:
"""
List all supported embedding providers.
Returns:
List of supported provider names
"""
return [
"azure",
"amazon-bedrock",
"cohere",
"custom",
"google-generativeai",
"google-vertex",
"huggingface",
"instructor",
"jina",
"ollama",
"onnx",
"openai",
"openclip",
"roboflow",
"sentence-transformer",
"text2vec",
"voyageai",
"watsonx",
]
@classmethod
def create_openai_service(
cls,
model: str = "text-embedding-3-small",
api_key: str | None = None,
**kwargs: Any,
) -> "EmbeddingService":
"""Create an OpenAI embedding service."""
return cls(provider="openai", model=model, api_key=api_key, **kwargs)
@classmethod
def create_voyage_service(
cls, model: str = "voyage-2", api_key: str | None = None, **kwargs: Any
) -> "EmbeddingService":
"""Create a Voyage AI embedding service."""
return cls(provider="voyageai", model=model, api_key=api_key, **kwargs)
@classmethod
def create_cohere_service(
cls,
model: str = "embed-english-v3.0",
api_key: str | None = None,
**kwargs: Any,
) -> "EmbeddingService":
"""Create a Cohere embedding service."""
return cls(provider="cohere", model=model, api_key=api_key, **kwargs)
@classmethod
def create_gemini_service(
cls,
model: str = "models/embedding-001",
api_key: str | None = None,
**kwargs: Any,
) -> "EmbeddingService":
"""Create a Google Gemini embedding service."""
return cls(
provider="google-generativeai", model=model, api_key=api_key, **kwargs
)
@classmethod
def create_azure_service(
cls,
model: str = "text-embedding-ada-002",
api_key: str | None = None,
**kwargs: Any,
) -> "EmbeddingService":
"""Create an Azure OpenAI embedding service."""
return cls(provider="azure", model=model, api_key=api_key, **kwargs)
@classmethod
def create_bedrock_service(
cls,
model: str = "amazon.titan-embed-text-v1",
api_key: str | None = None,
**kwargs: Any,
) -> "EmbeddingService":
"""Create an Amazon Bedrock embedding service."""
return cls(provider="amazon-bedrock", model=model, api_key=api_key, **kwargs)
@classmethod
def create_huggingface_service(
cls,
model: str = "sentence-transformers/all-MiniLM-L6-v2",
api_key: str | None = None,
**kwargs: Any,
) -> "EmbeddingService":
"""Create a Hugging Face embedding service."""
return cls(provider="huggingface", model=model, api_key=api_key, **kwargs)
@classmethod
def create_sentence_transformer_service(
cls,
model: str = "all-MiniLM-L6-v2",
**kwargs: Any,
) -> "EmbeddingService":
"""Create a Sentence Transformers embedding service (local)."""
return cls(provider="sentence-transformer", model=model, **kwargs)
@classmethod
def create_ollama_service(
cls,
model: str = "nomic-embed-text",
**kwargs: Any,
) -> "EmbeddingService":
"""Create an Ollama embedding service (local)."""
return cls(provider="ollama", model=model, **kwargs)
@classmethod
def create_jina_service(
cls,
model: str = "jina-embeddings-v2-base-en",
api_key: str | None = None,
**kwargs: Any,
) -> "EmbeddingService":
"""Create a Jina AI embedding service."""
return cls(provider="jina", model=model, api_key=api_key, **kwargs)
@classmethod
def create_instructor_service(
cls,
model: str = "hkunlp/instructor-large",
**kwargs: Any,
) -> "EmbeddingService":
"""Create an Instructor embedding service."""
return cls(provider="instructor", model=model, **kwargs)
@classmethod
def create_watsonx_service(
cls,
model: str = "ibm/slate-125m-english-rtrvr",
api_key: str | None = None,
**kwargs: Any,
) -> "EmbeddingService":
"""Create a Watson X embedding service."""
return cls(provider="watsonx", model=model, api_key=api_key, **kwargs)
@classmethod
def create_custom_service(
cls,
embedding_callable: Any,
**kwargs: Any,
) -> "EmbeddingService":
"""Create a custom embedding service with your own embedding function."""
return cls(
provider="custom",
model="custom",
extra_config={"embedding_callable": embedding_callable},
**kwargs,
)

View File

@@ -0,0 +1,342 @@
"""
Tests for the enhanced embedding service.
"""
import os
import pytest
from unittest.mock import Mock, patch
from crewai_tools.rag.embedding_service import EmbeddingService, EmbeddingConfig
class TestEmbeddingConfig:
"""Test the EmbeddingConfig model."""
def test_default_config(self):
"""Test default configuration values."""
config = EmbeddingConfig(provider="openai", model="text-embedding-3-small")
assert config.provider == "openai"
assert config.model == "text-embedding-3-small"
assert config.api_key is None
assert config.timeout == 30.0
assert config.max_retries == 3
assert config.batch_size == 100
assert config.extra_config == {}
def test_custom_config(self):
"""Test custom configuration values."""
config = EmbeddingConfig(
provider="voyageai",
model="voyage-2",
api_key="test-key",
timeout=60.0,
max_retries=5,
batch_size=50,
extra_config={"input_type": "document"}
)
assert config.provider == "voyageai"
assert config.model == "voyage-2"
assert config.api_key == "test-key"
assert config.timeout == 60.0
assert config.max_retries == 5
assert config.batch_size == 50
assert config.extra_config == {"input_type": "document"}
class TestEmbeddingService:
"""Test the EmbeddingService class."""
def test_list_supported_providers(self):
"""Test listing supported providers."""
providers = EmbeddingService.list_supported_providers()
expected_providers = [
"openai", "azure", "voyageai", "cohere", "google-generativeai",
"amazon-bedrock", "huggingface", "jina", "ollama", "sentence-transformer",
"instructor", "watsonx", "custom"
]
assert isinstance(providers, list)
assert len(providers) >= 15 # Should have at least 15 providers
assert all(provider in providers for provider in expected_providers)
def test_get_default_api_key(self):
"""Test getting default API keys from environment."""
service = EmbeddingService.__new__(EmbeddingService) # Create without __init__
# Test with environment variable set
with patch.dict(os.environ, {"OPENAI_API_KEY": "test-openai-key"}):
api_key = service._get_default_api_key("openai")
assert api_key == "test-openai-key"
# Test with no environment variable
with patch.dict(os.environ, {}, clear=True):
api_key = service._get_default_api_key("openai")
assert api_key is None
# Test unknown provider
api_key = service._get_default_api_key("unknown-provider")
assert api_key is None
@patch('crewai.rag.embeddings.factory.build_embedder')
def test_initialization_success(self, mock_build_embedder):
"""Test successful initialization."""
# Mock the embedding function
mock_embedding_function = Mock()
mock_build_embedder.return_value = mock_embedding_function
service = EmbeddingService(
provider="openai",
model="text-embedding-3-small",
api_key="test-key"
)
assert service.config.provider == "openai"
assert service.config.model == "text-embedding-3-small"
assert service.config.api_key == "test-key"
assert service._embedding_function == mock_embedding_function
# Verify build_embedder was called with correct config
mock_build_embedder.assert_called_once()
call_args = mock_build_embedder.call_args[0][0]
assert call_args["provider"] == "openai"
assert call_args["config"]["api_key"] == "test-key"
assert call_args["config"]["model_name"] == "text-embedding-3-small"
@patch('crewai.rag.embeddings.factory.build_embedder')
def test_initialization_import_error(self, mock_build_embedder):
"""Test initialization with import error."""
mock_build_embedder.side_effect = ImportError("CrewAI not installed")
with pytest.raises(ImportError, match="CrewAI embedding providers not available"):
EmbeddingService(provider="openai", model="test-model", api_key="test-key")
@patch('crewai.rag.embeddings.factory.build_embedder')
def test_embed_text_success(self, mock_build_embedder):
"""Test successful text embedding."""
# Mock the embedding function
mock_embedding_function = Mock()
mock_embedding_function.return_value = [[0.1, 0.2, 0.3]]
mock_build_embedder.return_value = mock_embedding_function
service = EmbeddingService(provider="openai", model="test-model", api_key="test-key")
result = service.embed_text("test text")
assert result == [0.1, 0.2, 0.3]
mock_embedding_function.assert_called_once_with(["test text"])
@patch('crewai.rag.embeddings.factory.build_embedder')
def test_embed_text_empty_input(self, mock_build_embedder):
"""Test embedding empty text."""
mock_embedding_function = Mock()
mock_build_embedder.return_value = mock_embedding_function
service = EmbeddingService(provider="openai", model="test-model", api_key="test-key")
result = service.embed_text("")
assert result == []
result = service.embed_text(" ")
assert result == []
# Embedding function should not be called for empty text
mock_embedding_function.assert_not_called()
@patch('crewai.rag.embeddings.factory.build_embedder')
def test_embed_batch_success(self, mock_build_embedder):
"""Test successful batch embedding."""
# Mock the embedding function
mock_embedding_function = Mock()
mock_embedding_function.return_value = [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]
mock_build_embedder.return_value = mock_embedding_function
service = EmbeddingService(provider="openai", model="test-model", api_key="test-key")
texts = ["text1", "text2", "text3"]
result = service.embed_batch(texts)
assert result == [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]
mock_embedding_function.assert_called_once_with(texts)
@patch('crewai.rag.embeddings.factory.build_embedder')
def test_embed_batch_empty_input(self, mock_build_embedder):
"""Test batch embedding with empty input."""
mock_embedding_function = Mock()
mock_build_embedder.return_value = mock_embedding_function
service = EmbeddingService(provider="openai", model="test-model", api_key="test-key")
# Empty list
result = service.embed_batch([])
assert result == []
# List with empty strings
result = service.embed_batch(["", " ", ""])
assert result == []
# Embedding function should not be called for empty input
mock_embedding_function.assert_not_called()
@patch('crewai.rag.embeddings.factory.build_embedder')
def test_validate_connection(self, mock_build_embedder):
"""Test connection validation."""
# Mock successful embedding
mock_embedding_function = Mock()
mock_embedding_function.return_value = [[0.1, 0.2, 0.3]]
mock_build_embedder.return_value = mock_embedding_function
service = EmbeddingService(provider="openai", model="test-model", api_key="test-key")
assert service.validate_connection() is True
# Mock failed embedding
mock_embedding_function.side_effect = Exception("Connection failed")
assert service.validate_connection() is False
@patch('crewai.rag.embeddings.factory.build_embedder')
def test_get_service_info(self, mock_build_embedder):
"""Test getting service information."""
# Mock the embedding function
mock_embedding_function = Mock()
mock_embedding_function.return_value = [[0.1, 0.2, 0.3]]
mock_build_embedder.return_value = mock_embedding_function
service = EmbeddingService(provider="openai", model="test-model", api_key="test-key")
info = service.get_service_info()
assert info["provider"] == "openai"
assert info["model"] == "test-model"
assert info["embedding_dimension"] == 3
assert info["batch_size"] == 100
assert info["is_connected"] is True
def test_create_openai_service(self):
"""Test OpenAI service creation."""
with patch('crewai.rag.embeddings.factory.build_embedder'):
service = EmbeddingService.create_openai_service(
model="text-embedding-3-large",
api_key="test-key"
)
assert service.config.provider == "openai"
assert service.config.model == "text-embedding-3-large"
assert service.config.api_key == "test-key"
def test_create_voyage_service(self):
"""Test Voyage AI service creation."""
with patch('crewai.rag.embeddings.factory.build_embedder'):
service = EmbeddingService.create_voyage_service(
model="voyage-large-2",
api_key="test-key"
)
assert service.config.provider == "voyageai"
assert service.config.model == "voyage-large-2"
assert service.config.api_key == "test-key"
def test_create_cohere_service(self):
"""Test Cohere service creation."""
with patch('crewai.rag.embeddings.factory.build_embedder'):
service = EmbeddingService.create_cohere_service(
model="embed-multilingual-v3.0",
api_key="test-key"
)
assert service.config.provider == "cohere"
assert service.config.model == "embed-multilingual-v3.0"
assert service.config.api_key == "test-key"
def test_create_gemini_service(self):
"""Test Gemini service creation."""
with patch('crewai.rag.embeddings.factory.build_embedder'):
service = EmbeddingService.create_gemini_service(
model="models/embedding-001",
api_key="test-key"
)
assert service.config.provider == "google-generativeai"
assert service.config.model == "models/embedding-001"
assert service.config.api_key == "test-key"
class TestProviderConfigurations:
"""Test provider-specific configurations."""
@patch('crewai.rag.embeddings.factory.build_embedder')
def test_openai_config(self, mock_build_embedder):
"""Test OpenAI configuration mapping."""
mock_build_embedder.return_value = Mock()
service = EmbeddingService(
provider="openai",
model="text-embedding-3-small",
api_key="test-key",
extra_config={"dimensions": 1024}
)
# Check the configuration passed to build_embedder
call_args = mock_build_embedder.call_args[0][0]
assert call_args["provider"] == "openai"
assert call_args["config"]["api_key"] == "test-key"
assert call_args["config"]["model_name"] == "text-embedding-3-small"
assert call_args["config"]["dimensions"] == 1024
@patch('crewai.rag.embeddings.factory.build_embedder')
def test_voyageai_config(self, mock_build_embedder):
"""Test Voyage AI configuration mapping."""
mock_build_embedder.return_value = Mock()
service = EmbeddingService(
provider="voyageai",
model="voyage-2",
api_key="test-key",
timeout=60.0,
max_retries=5,
extra_config={"input_type": "document"}
)
# Check the configuration passed to build_embedder
call_args = mock_build_embedder.call_args[0][0]
assert call_args["provider"] == "voyageai"
assert call_args["config"]["api_key"] == "test-key"
assert call_args["config"]["model"] == "voyage-2"
assert call_args["config"]["timeout"] == 60.0
assert call_args["config"]["max_retries"] == 5
assert call_args["config"]["input_type"] == "document"
@patch('crewai.rag.embeddings.factory.build_embedder')
def test_cohere_config(self, mock_build_embedder):
"""Test Cohere configuration mapping."""
mock_build_embedder.return_value = Mock()
service = EmbeddingService(
provider="cohere",
model="embed-english-v3.0",
api_key="test-key"
)
# Check the configuration passed to build_embedder
call_args = mock_build_embedder.call_args[0][0]
assert call_args["provider"] == "cohere"
assert call_args["config"]["api_key"] == "test-key"
assert call_args["config"]["model_name"] == "embed-english-v3.0"
@patch('crewai.rag.embeddings.factory.build_embedder')
def test_gemini_config(self, mock_build_embedder):
"""Test Gemini configuration mapping."""
mock_build_embedder.return_value = Mock()
service = EmbeddingService(
provider="google-generativeai",
model="models/embedding-001",
api_key="test-key"
)
# Check the configuration passed to build_embedder
call_args = mock_build_embedder.call_args[0][0]
assert call_args["provider"] == "google-generativeai"
assert call_args["config"]["api_key"] == "test-key"
assert call_args["config"]["model_name"] == "models/embedding-001"

View File

@@ -48,7 +48,7 @@ Repository = "https://github.com/crewAIInc/crewAI"
[project.optional-dependencies]
tools = [
"crewai-tools==1.0.0b1",
"crewai-tools==1.0.0b2",
]
embeddings = [
"tiktoken~=0.8.0"

View File

@@ -40,7 +40,7 @@ def _suppress_pydantic_deprecation_warnings() -> None:
_suppress_pydantic_deprecation_warnings()
__version__ = "1.0.0b1"
__version__ = "1.0.0b2"
_telemetry_submitted = False

View File

@@ -4,6 +4,7 @@ from dataclasses import field
from typing import Literal, cast
from pydantic.dataclasses import dataclass as pyd_dataclass
from qdrant_client.models import VectorParams
from crewai.rag.config.base import BaseRagConfig
from crewai.rag.qdrant.constants import DEFAULT_EMBEDDING_MODEL, DEFAULT_STORAGE_PATH
@@ -53,3 +54,4 @@ class QdrantConfig(BaseRagConfig):
embedding_function: QdrantEmbeddingFunctionWrapper = field(
default_factory=_default_embedding_function
)
vectors_config: VectorParams | None = field(default=None)

View File

@@ -4,8 +4,8 @@ import asyncio
from typing import TypeGuard
from uuid import uuid4
from qdrant_client import AsyncQdrantClient # type: ignore[import-not-found]
from qdrant_client import (
AsyncQdrantClient, # type: ignore[import-not-found]
QdrantClient as SyncQdrantClient, # type: ignore[import-not-found]
)
from qdrant_client.models import ( # type: ignore[import-not-found]

View File

@@ -1,3 +1,3 @@
"""CrewAI development tools."""
__version__ = "1.0.0b1"
__version__ = "1.0.0b2"

2
uv.lock generated
View File

@@ -1124,7 +1124,6 @@ dependencies = [
{ name = "python-docx" },
{ name = "pytube" },
{ name = "requests" },
{ name = "stagehand" },
{ name = "tiktoken" },
{ name = "youtube-transcript-api" },
]
@@ -1295,7 +1294,6 @@ requires-dist = [
{ name = "spider-client", marker = "extra == 'spider-client'", specifier = ">=0.1.25" },
{ name = "sqlalchemy", marker = "extra == 'singlestore'", specifier = ">=2.0.40" },
{ name = "sqlalchemy", marker = "extra == 'sqlalchemy'", specifier = ">=2.0.35" },
{ name = "stagehand", specifier = ">=0.4.1" },
{ name = "stagehand", marker = "extra == 'stagehand'", specifier = ">=0.4.1" },
{ name = "tavily-python", marker = "extra == 'tavily-python'", specifier = ">=0.5.4" },
{ name = "tiktoken", specifier = ">=0.8.0" },