mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-05 06:08:29 +00:00
Compare commits
3 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
7351e4b0ef | ||
|
|
d9b68ddd85 | ||
|
|
2d5ad7a187 |
@@ -12,10 +12,9 @@ dependencies = [
|
||||
"pytube>=15.0.0",
|
||||
"requests>=2.32.5",
|
||||
"docker>=7.1.0",
|
||||
"crewai==1.0.0b1",
|
||||
"crewai==1.0.0b2",
|
||||
"lancedb>=0.5.4",
|
||||
"tiktoken>=0.8.0",
|
||||
"stagehand>=0.4.1",
|
||||
"beautifulsoup4>=4.13.4",
|
||||
"pypdf>=5.9.0",
|
||||
"python-docx>=1.2.0",
|
||||
|
||||
@@ -291,4 +291,4 @@ __all__ = [
|
||||
"ZapierActionTools",
|
||||
]
|
||||
|
||||
__version__ = "1.0.0b1"
|
||||
__version__ = "1.0.0b2"
|
||||
|
||||
@@ -3,13 +3,16 @@
|
||||
import hashlib
|
||||
from pathlib import Path
|
||||
from typing import Any, TypeAlias, TypedDict
|
||||
import uuid
|
||||
|
||||
from crewai.rag.config.types import RagConfigType
|
||||
from crewai.rag.config.utils import get_rag_client
|
||||
from crewai.rag.core.base_client import BaseClient
|
||||
from crewai.rag.factory import create_client
|
||||
from crewai.rag.qdrant.config import QdrantConfig
|
||||
from crewai.rag.types import BaseRecord, SearchResult
|
||||
from pydantic import PrivateAttr
|
||||
from qdrant_client.models import VectorParams
|
||||
from typing_extensions import Unpack
|
||||
|
||||
from crewai_tools.rag.data_types import DataType
|
||||
@@ -52,7 +55,11 @@ class CrewAIRagAdapter(Adapter):
|
||||
self._client = create_client(self.config)
|
||||
else:
|
||||
self._client = get_rag_client()
|
||||
self._client.get_or_create_collection(collection_name=self.collection_name)
|
||||
collection_params: dict[str, Any] = {"collection_name": self.collection_name}
|
||||
if isinstance(self.config, QdrantConfig) and self.config.vectors_config:
|
||||
if isinstance(self.config.vectors_config, VectorParams):
|
||||
collection_params["vectors_config"] = self.config.vectors_config
|
||||
self._client.get_or_create_collection(**collection_params)
|
||||
|
||||
def query(
|
||||
self,
|
||||
@@ -76,6 +83,8 @@ class CrewAIRagAdapter(Adapter):
|
||||
if similarity_threshold is not None
|
||||
else self.similarity_threshold
|
||||
)
|
||||
if self._client is None:
|
||||
raise ValueError("Client is not initialized")
|
||||
|
||||
results: list[SearchResult] = self._client.search(
|
||||
collection_name=self.collection_name,
|
||||
@@ -201,9 +210,10 @@ class CrewAIRagAdapter(Adapter):
|
||||
if isinstance(arg, dict):
|
||||
file_metadata.update(arg.get("metadata", {}))
|
||||
|
||||
chunk_id = hashlib.sha256(
|
||||
chunk_hash = hashlib.sha256(
|
||||
f"{file_result.doc_id}_{chunk_idx}_{file_chunk}".encode()
|
||||
).hexdigest()
|
||||
chunk_id = str(uuid.UUID(chunk_hash[:32]))
|
||||
|
||||
documents.append(
|
||||
{
|
||||
@@ -251,9 +261,10 @@ class CrewAIRagAdapter(Adapter):
|
||||
if isinstance(arg, dict):
|
||||
chunk_metadata.update(arg.get("metadata", {}))
|
||||
|
||||
chunk_id = hashlib.sha256(
|
||||
chunk_hash = hashlib.sha256(
|
||||
f"{loader_result.doc_id}_{i}_{chunk}".encode()
|
||||
).hexdigest()
|
||||
chunk_id = str(uuid.UUID(chunk_hash[:32]))
|
||||
|
||||
documents.append(
|
||||
{
|
||||
@@ -264,6 +275,8 @@ class CrewAIRagAdapter(Adapter):
|
||||
)
|
||||
|
||||
if documents:
|
||||
if self._client is None:
|
||||
raise ValueError("Client is not initialized")
|
||||
self._client.add_documents(
|
||||
collection_name=self.collection_name, documents=documents
|
||||
)
|
||||
|
||||
@@ -4,12 +4,12 @@ from typing import Any
|
||||
from uuid import uuid4
|
||||
|
||||
import chromadb
|
||||
import litellm
|
||||
from pydantic import BaseModel, Field, PrivateAttr
|
||||
|
||||
from crewai_tools.rag.base_loader import BaseLoader
|
||||
from crewai_tools.rag.chunkers.base_chunker import BaseChunker
|
||||
from crewai_tools.rag.data_types import DataType
|
||||
from crewai_tools.rag.embedding_service import EmbeddingService
|
||||
from crewai_tools.rag.misc import compute_sha256
|
||||
from crewai_tools.rag.source_content import SourceContent
|
||||
from crewai_tools.tools.rag.rag_tool import Adapter
|
||||
@@ -18,31 +18,6 @@ from crewai_tools.tools.rag.rag_tool import Adapter
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class EmbeddingService:
|
||||
def __init__(self, model: str = "text-embedding-3-small", **kwargs):
|
||||
self.model = model
|
||||
self.kwargs = kwargs
|
||||
|
||||
def embed_text(self, text: str) -> list[float]:
|
||||
try:
|
||||
response = litellm.embedding(model=self.model, input=[text], **self.kwargs)
|
||||
return response.data[0]["embedding"]
|
||||
except Exception as e:
|
||||
logger.error(f"Error generating embedding: {e}")
|
||||
raise
|
||||
|
||||
def embed_batch(self, texts: list[str]) -> list[list[float]]:
|
||||
if not texts:
|
||||
return []
|
||||
|
||||
try:
|
||||
response = litellm.embedding(model=self.model, input=texts, **self.kwargs)
|
||||
return [data["embedding"] for data in response.data]
|
||||
except Exception as e:
|
||||
logger.error(f"Error generating batch embeddings: {e}")
|
||||
raise
|
||||
|
||||
|
||||
class Document(BaseModel):
|
||||
id: str = Field(default_factory=lambda: str(uuid4()))
|
||||
content: str
|
||||
@@ -54,6 +29,7 @@ class Document(BaseModel):
|
||||
class RAG(Adapter):
|
||||
collection_name: str = "crewai_knowledge_base"
|
||||
persist_directory: str | None = None
|
||||
embedding_provider: str = "openai"
|
||||
embedding_model: str = "text-embedding-3-large"
|
||||
summarize: bool = False
|
||||
top_k: int = 5
|
||||
@@ -79,7 +55,9 @@ class RAG(Adapter):
|
||||
)
|
||||
|
||||
self._embedding_service = EmbeddingService(
|
||||
model=self.embedding_model, **self.embedding_config
|
||||
provider=self.embedding_provider,
|
||||
model=self.embedding_model,
|
||||
**self.embedding_config,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize ChromaDB: {e}")
|
||||
@@ -181,7 +159,7 @@ class RAG(Adapter):
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to add documents to ChromaDB: {e}")
|
||||
|
||||
def query(self, question: str, where: dict[str, Any] | None = None) -> str:
|
||||
def query(self, question: str, where: dict[str, Any] | None = None) -> str: # type: ignore
|
||||
try:
|
||||
question_embedding = self._embedding_service.embed_text(question)
|
||||
|
||||
|
||||
508
lib/crewai-tools/src/crewai_tools/rag/embedding_service.py
Normal file
508
lib/crewai-tools/src/crewai_tools/rag/embedding_service.py
Normal file
@@ -0,0 +1,508 @@
|
||||
"""
|
||||
Enhanced embedding service that leverages CrewAI's existing embedding providers.
|
||||
This replaces the litellm-based EmbeddingService with a more flexible architecture.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class EmbeddingConfig(BaseModel):
|
||||
"""Configuration for embedding providers."""
|
||||
|
||||
provider: str = Field(description="Embedding provider name")
|
||||
model: str = Field(description="Model name to use")
|
||||
api_key: str | None = Field(default=None, description="API key for the provider")
|
||||
timeout: float | None = Field(
|
||||
default=30.0, description="Request timeout in seconds"
|
||||
)
|
||||
max_retries: int = Field(default=3, description="Maximum number of retries")
|
||||
batch_size: int = Field(
|
||||
default=100, description="Batch size for processing multiple texts"
|
||||
)
|
||||
extra_config: dict[str, Any] = Field(
|
||||
default_factory=dict, description="Additional provider-specific configuration"
|
||||
)
|
||||
|
||||
|
||||
class EmbeddingService:
|
||||
"""
|
||||
Enhanced embedding service that uses CrewAI's existing embedding providers.
|
||||
|
||||
Supports multiple providers:
|
||||
- openai: OpenAI embeddings (text-embedding-3-small, text-embedding-3-large, etc.)
|
||||
- voyageai: Voyage AI embeddings (voyage-2, voyage-large-2, etc.)
|
||||
- cohere: Cohere embeddings (embed-english-v3.0, embed-multilingual-v3.0, etc.)
|
||||
- google-generativeai: Google Gemini embeddings (models/embedding-001, etc.)
|
||||
- google-vertex: Google Vertex embeddings (models/embedding-001, etc.)
|
||||
- huggingface: Hugging Face embeddings (sentence-transformers/all-MiniLM-L6-v2, etc.)
|
||||
- jina: Jina embeddings (jina-embeddings-v2-base-en, etc.)
|
||||
- ollama: Ollama embeddings (nomic-embed-text, etc.)
|
||||
- openai: OpenAI embeddings (text-embedding-3-small, text-embedding-3-large, etc.)
|
||||
- roboflow: Roboflow embeddings (roboflow-embeddings-v2-base-en, etc.)
|
||||
- voyageai: Voyage AI embeddings (voyage-2, voyage-large-2, etc.)
|
||||
- watsonx: Watson X embeddings (ibm/slate-125m-english-rtrvr, etc.)
|
||||
- custom: Custom embeddings (embedding_callable, etc.)
|
||||
- sentence-transformer: Sentence Transformers embeddings (all-MiniLM-L6-v2, etc.)
|
||||
- text2vec: Text2Vec embeddings (text2vec-base-en, etc.)
|
||||
- openclip: OpenClip embeddings (openclip-large-v2, etc.)
|
||||
- instructor: Instructor embeddings (hkunlp/instructor-large, etc.)
|
||||
- onnx: ONNX embeddings (onnx-large-v2, etc.)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
provider: str = "openai",
|
||||
model: str = "text-embedding-3-small",
|
||||
api_key: str | None = None,
|
||||
**kwargs: Any,
|
||||
):
|
||||
"""
|
||||
Initialize the embedding service.
|
||||
|
||||
Args:
|
||||
provider: The embedding provider to use
|
||||
model: The model name
|
||||
api_key: API key (if not provided, will look for environment variables)
|
||||
**kwargs: Additional configuration options
|
||||
"""
|
||||
self.config = EmbeddingConfig(
|
||||
provider=provider,
|
||||
model=model,
|
||||
api_key=api_key or self._get_default_api_key(provider),
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
self._embedding_function = None
|
||||
self._initialize_embedding_function()
|
||||
|
||||
def _get_default_api_key(self, provider: str) -> str | None:
|
||||
"""Get default API key from environment variables."""
|
||||
env_key_map = {
|
||||
"azure": "AZURE_OPENAI_API_KEY",
|
||||
"amazon-bedrock": "AWS_ACCESS_KEY_ID", # or AWS_PROFILE
|
||||
"cohere": "COHERE_API_KEY",
|
||||
"google-generativeai": "GOOGLE_API_KEY",
|
||||
"google-vertex": "GOOGLE_APPLICATION_CREDENTIALS",
|
||||
"huggingface": "HUGGINGFACE_API_KEY",
|
||||
"jina": "JINA_API_KEY",
|
||||
"ollama": None, # Ollama typically runs locally without API key
|
||||
"openai": "OPENAI_API_KEY",
|
||||
"roboflow": "ROBOFLOW_API_KEY",
|
||||
"voyageai": "VOYAGE_API_KEY",
|
||||
"watsonx": "WATSONX_API_KEY",
|
||||
}
|
||||
|
||||
env_key = env_key_map.get(provider)
|
||||
if env_key:
|
||||
return os.getenv(env_key)
|
||||
return None
|
||||
|
||||
def _initialize_embedding_function(self):
|
||||
"""Initialize the embedding function using CrewAI's factory."""
|
||||
try:
|
||||
from crewai.rag.embeddings.factory import build_embedder
|
||||
|
||||
# Build the configuration for CrewAI's factory
|
||||
config = self._build_provider_config()
|
||||
|
||||
# Create the embedding function
|
||||
self._embedding_function = build_embedder(config)
|
||||
|
||||
logger.info(
|
||||
f"Initialized {self.config.provider} embedding service with model "
|
||||
f"{self.config.model}"
|
||||
)
|
||||
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
f"CrewAI embedding providers not available. "
|
||||
f"Make sure crewai is installed: {e}"
|
||||
) from e
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize embedding function: {e}")
|
||||
raise RuntimeError(
|
||||
f"Failed to initialize {self.config.provider} embedding service: {e}"
|
||||
) from e
|
||||
|
||||
def _build_provider_config(self) -> dict[str, Any]:
|
||||
"""Build configuration dictionary for CrewAI's embedding factory."""
|
||||
base_config = {"provider": self.config.provider, "config": {}}
|
||||
|
||||
# Provider-specific configuration mapping
|
||||
if self.config.provider == "openai":
|
||||
base_config["config"] = {
|
||||
"api_key": self.config.api_key,
|
||||
"model_name": self.config.model,
|
||||
**self.config.extra_config,
|
||||
}
|
||||
elif self.config.provider == "azure":
|
||||
base_config["config"] = {
|
||||
"api_key": self.config.api_key,
|
||||
"model_name": self.config.model,
|
||||
**self.config.extra_config,
|
||||
}
|
||||
elif self.config.provider == "voyageai":
|
||||
base_config["config"] = {
|
||||
"api_key": self.config.api_key,
|
||||
"model": self.config.model,
|
||||
"max_retries": self.config.max_retries,
|
||||
"timeout": self.config.timeout,
|
||||
**self.config.extra_config,
|
||||
}
|
||||
elif self.config.provider == "cohere":
|
||||
base_config["config"] = {
|
||||
"api_key": self.config.api_key,
|
||||
"model_name": self.config.model,
|
||||
**self.config.extra_config,
|
||||
}
|
||||
elif self.config.provider in ["google-generativeai", "google-vertex"]:
|
||||
base_config["config"] = {
|
||||
"api_key": self.config.api_key,
|
||||
"model_name": self.config.model,
|
||||
**self.config.extra_config,
|
||||
}
|
||||
elif self.config.provider == "amazon-bedrock":
|
||||
base_config["config"] = {
|
||||
"aws_access_key_id": self.config.api_key,
|
||||
"model_name": self.config.model,
|
||||
**self.config.extra_config,
|
||||
}
|
||||
elif self.config.provider == "huggingface":
|
||||
base_config["config"] = {
|
||||
"api_key": self.config.api_key,
|
||||
"model_name": self.config.model,
|
||||
**self.config.extra_config,
|
||||
}
|
||||
elif self.config.provider == "jina":
|
||||
base_config["config"] = {
|
||||
"api_key": self.config.api_key,
|
||||
"model_name": self.config.model,
|
||||
**self.config.extra_config,
|
||||
}
|
||||
elif self.config.provider == "ollama":
|
||||
base_config["config"] = {
|
||||
"model": self.config.model,
|
||||
**self.config.extra_config,
|
||||
}
|
||||
elif self.config.provider == "sentence-transformer":
|
||||
base_config["config"] = {
|
||||
"model_name": self.config.model,
|
||||
**self.config.extra_config,
|
||||
}
|
||||
elif self.config.provider == "instructor":
|
||||
base_config["config"] = {
|
||||
"model_name": self.config.model,
|
||||
**self.config.extra_config,
|
||||
}
|
||||
elif self.config.provider == "onnx":
|
||||
base_config["config"] = {
|
||||
**self.config.extra_config,
|
||||
}
|
||||
elif self.config.provider == "roboflow":
|
||||
base_config["config"] = {
|
||||
"api_key": self.config.api_key,
|
||||
**self.config.extra_config,
|
||||
}
|
||||
elif self.config.provider == "openclip":
|
||||
base_config["config"] = {
|
||||
"model_name": self.config.model,
|
||||
**self.config.extra_config,
|
||||
}
|
||||
elif self.config.provider == "text2vec":
|
||||
base_config["config"] = {
|
||||
"model_name": self.config.model,
|
||||
**self.config.extra_config,
|
||||
}
|
||||
elif self.config.provider == "watsonx":
|
||||
base_config["config"] = {
|
||||
"api_key": self.config.api_key,
|
||||
"model_name": self.config.model,
|
||||
**self.config.extra_config,
|
||||
}
|
||||
elif self.config.provider == "custom":
|
||||
# Custom provider requires embedding_callable in extra_config
|
||||
base_config["config"] = {
|
||||
**self.config.extra_config,
|
||||
}
|
||||
else:
|
||||
# Generic configuration for any unlisted providers
|
||||
base_config["config"] = {
|
||||
"api_key": self.config.api_key,
|
||||
"model": self.config.model,
|
||||
**self.config.extra_config,
|
||||
}
|
||||
|
||||
return base_config
|
||||
|
||||
def embed_text(self, text: str) -> list[float]:
|
||||
"""
|
||||
Generate embedding for a single text.
|
||||
|
||||
Args:
|
||||
text: Text to embed
|
||||
|
||||
Returns:
|
||||
List of floats representing the embedding
|
||||
|
||||
Raises:
|
||||
RuntimeError: If embedding generation fails
|
||||
"""
|
||||
if not text or not text.strip():
|
||||
logger.warning("Empty text provided for embedding")
|
||||
return []
|
||||
|
||||
try:
|
||||
# Use ChromaDB's embedding function interface
|
||||
embeddings = self._embedding_function([text]) # type: ignore
|
||||
return embeddings[0] if embeddings else []
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error generating embedding for text: {e}")
|
||||
raise RuntimeError(f"Failed to generate embedding: {e}") from e
|
||||
|
||||
def embed_batch(self, texts: list[str]) -> list[list[float]]:
|
||||
"""
|
||||
Generate embeddings for multiple texts.
|
||||
|
||||
Args:
|
||||
texts: List of texts to embed
|
||||
|
||||
Returns:
|
||||
List of embedding vectors
|
||||
|
||||
Raises:
|
||||
RuntimeError: If embedding generation fails
|
||||
"""
|
||||
if not texts:
|
||||
return []
|
||||
|
||||
# Filter out empty texts
|
||||
valid_texts = [text for text in texts if text and text.strip()]
|
||||
if not valid_texts:
|
||||
logger.warning("No valid texts provided for batch embedding")
|
||||
return []
|
||||
|
||||
try:
|
||||
# Process in batches to avoid API limits
|
||||
all_embeddings = []
|
||||
|
||||
for i in range(0, len(valid_texts), self.config.batch_size):
|
||||
batch = valid_texts[i : i + self.config.batch_size]
|
||||
batch_embeddings = self._embedding_function(batch) # type: ignore
|
||||
all_embeddings.extend(batch_embeddings)
|
||||
|
||||
return all_embeddings
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error generating batch embeddings: {e}")
|
||||
raise RuntimeError(f"Failed to generate batch embeddings: {e}") from e
|
||||
|
||||
def get_embedding_dimension(self) -> int | None:
|
||||
"""
|
||||
Get the dimension of embeddings produced by this service.
|
||||
|
||||
Returns:
|
||||
Embedding dimension or None if unknown
|
||||
"""
|
||||
# Try to get dimension by generating a test embedding
|
||||
try:
|
||||
test_embedding = self.embed_text("test")
|
||||
return len(test_embedding) if test_embedding else None
|
||||
except Exception:
|
||||
logger.warning("Could not determine embedding dimension")
|
||||
return None
|
||||
|
||||
def validate_connection(self) -> bool:
|
||||
"""
|
||||
Validate that the embedding service is working correctly.
|
||||
|
||||
Returns:
|
||||
True if the service is working, False otherwise
|
||||
"""
|
||||
try:
|
||||
test_embedding = self.embed_text("test connection")
|
||||
return len(test_embedding) > 0
|
||||
except Exception as e:
|
||||
logger.error(f"Connection validation failed: {e}")
|
||||
return False
|
||||
|
||||
def get_service_info(self) -> dict[str, Any]:
|
||||
"""
|
||||
Get information about the current embedding service.
|
||||
|
||||
Returns:
|
||||
Dictionary with service information
|
||||
"""
|
||||
return {
|
||||
"provider": self.config.provider,
|
||||
"model": self.config.model,
|
||||
"embedding_dimension": self.get_embedding_dimension(),
|
||||
"batch_size": self.config.batch_size,
|
||||
"is_connected": self.validate_connection(),
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def list_supported_providers(cls) -> list[str]:
|
||||
"""
|
||||
List all supported embedding providers.
|
||||
|
||||
Returns:
|
||||
List of supported provider names
|
||||
"""
|
||||
return [
|
||||
"azure",
|
||||
"amazon-bedrock",
|
||||
"cohere",
|
||||
"custom",
|
||||
"google-generativeai",
|
||||
"google-vertex",
|
||||
"huggingface",
|
||||
"instructor",
|
||||
"jina",
|
||||
"ollama",
|
||||
"onnx",
|
||||
"openai",
|
||||
"openclip",
|
||||
"roboflow",
|
||||
"sentence-transformer",
|
||||
"text2vec",
|
||||
"voyageai",
|
||||
"watsonx",
|
||||
]
|
||||
|
||||
@classmethod
|
||||
def create_openai_service(
|
||||
cls,
|
||||
model: str = "text-embedding-3-small",
|
||||
api_key: str | None = None,
|
||||
**kwargs: Any,
|
||||
) -> "EmbeddingService":
|
||||
"""Create an OpenAI embedding service."""
|
||||
return cls(provider="openai", model=model, api_key=api_key, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def create_voyage_service(
|
||||
cls, model: str = "voyage-2", api_key: str | None = None, **kwargs: Any
|
||||
) -> "EmbeddingService":
|
||||
"""Create a Voyage AI embedding service."""
|
||||
return cls(provider="voyageai", model=model, api_key=api_key, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def create_cohere_service(
|
||||
cls,
|
||||
model: str = "embed-english-v3.0",
|
||||
api_key: str | None = None,
|
||||
**kwargs: Any,
|
||||
) -> "EmbeddingService":
|
||||
"""Create a Cohere embedding service."""
|
||||
return cls(provider="cohere", model=model, api_key=api_key, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def create_gemini_service(
|
||||
cls,
|
||||
model: str = "models/embedding-001",
|
||||
api_key: str | None = None,
|
||||
**kwargs: Any,
|
||||
) -> "EmbeddingService":
|
||||
"""Create a Google Gemini embedding service."""
|
||||
return cls(
|
||||
provider="google-generativeai", model=model, api_key=api_key, **kwargs
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def create_azure_service(
|
||||
cls,
|
||||
model: str = "text-embedding-ada-002",
|
||||
api_key: str | None = None,
|
||||
**kwargs: Any,
|
||||
) -> "EmbeddingService":
|
||||
"""Create an Azure OpenAI embedding service."""
|
||||
return cls(provider="azure", model=model, api_key=api_key, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def create_bedrock_service(
|
||||
cls,
|
||||
model: str = "amazon.titan-embed-text-v1",
|
||||
api_key: str | None = None,
|
||||
**kwargs: Any,
|
||||
) -> "EmbeddingService":
|
||||
"""Create an Amazon Bedrock embedding service."""
|
||||
return cls(provider="amazon-bedrock", model=model, api_key=api_key, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def create_huggingface_service(
|
||||
cls,
|
||||
model: str = "sentence-transformers/all-MiniLM-L6-v2",
|
||||
api_key: str | None = None,
|
||||
**kwargs: Any,
|
||||
) -> "EmbeddingService":
|
||||
"""Create a Hugging Face embedding service."""
|
||||
return cls(provider="huggingface", model=model, api_key=api_key, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def create_sentence_transformer_service(
|
||||
cls,
|
||||
model: str = "all-MiniLM-L6-v2",
|
||||
**kwargs: Any,
|
||||
) -> "EmbeddingService":
|
||||
"""Create a Sentence Transformers embedding service (local)."""
|
||||
return cls(provider="sentence-transformer", model=model, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def create_ollama_service(
|
||||
cls,
|
||||
model: str = "nomic-embed-text",
|
||||
**kwargs: Any,
|
||||
) -> "EmbeddingService":
|
||||
"""Create an Ollama embedding service (local)."""
|
||||
return cls(provider="ollama", model=model, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def create_jina_service(
|
||||
cls,
|
||||
model: str = "jina-embeddings-v2-base-en",
|
||||
api_key: str | None = None,
|
||||
**kwargs: Any,
|
||||
) -> "EmbeddingService":
|
||||
"""Create a Jina AI embedding service."""
|
||||
return cls(provider="jina", model=model, api_key=api_key, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def create_instructor_service(
|
||||
cls,
|
||||
model: str = "hkunlp/instructor-large",
|
||||
**kwargs: Any,
|
||||
) -> "EmbeddingService":
|
||||
"""Create an Instructor embedding service."""
|
||||
return cls(provider="instructor", model=model, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def create_watsonx_service(
|
||||
cls,
|
||||
model: str = "ibm/slate-125m-english-rtrvr",
|
||||
api_key: str | None = None,
|
||||
**kwargs: Any,
|
||||
) -> "EmbeddingService":
|
||||
"""Create a Watson X embedding service."""
|
||||
return cls(provider="watsonx", model=model, api_key=api_key, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def create_custom_service(
|
||||
cls,
|
||||
embedding_callable: Any,
|
||||
**kwargs: Any,
|
||||
) -> "EmbeddingService":
|
||||
"""Create a custom embedding service with your own embedding function."""
|
||||
return cls(
|
||||
provider="custom",
|
||||
model="custom",
|
||||
extra_config={"embedding_callable": embedding_callable},
|
||||
**kwargs,
|
||||
)
|
||||
342
lib/crewai-tools/tests/rag/test_embedding_service.py
Normal file
342
lib/crewai-tools/tests/rag/test_embedding_service.py
Normal file
@@ -0,0 +1,342 @@
|
||||
"""
|
||||
Tests for the enhanced embedding service.
|
||||
"""
|
||||
|
||||
import os
|
||||
import pytest
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
from crewai_tools.rag.embedding_service import EmbeddingService, EmbeddingConfig
|
||||
|
||||
|
||||
class TestEmbeddingConfig:
|
||||
"""Test the EmbeddingConfig model."""
|
||||
|
||||
def test_default_config(self):
|
||||
"""Test default configuration values."""
|
||||
config = EmbeddingConfig(provider="openai", model="text-embedding-3-small")
|
||||
|
||||
assert config.provider == "openai"
|
||||
assert config.model == "text-embedding-3-small"
|
||||
assert config.api_key is None
|
||||
assert config.timeout == 30.0
|
||||
assert config.max_retries == 3
|
||||
assert config.batch_size == 100
|
||||
assert config.extra_config == {}
|
||||
|
||||
def test_custom_config(self):
|
||||
"""Test custom configuration values."""
|
||||
config = EmbeddingConfig(
|
||||
provider="voyageai",
|
||||
model="voyage-2",
|
||||
api_key="test-key",
|
||||
timeout=60.0,
|
||||
max_retries=5,
|
||||
batch_size=50,
|
||||
extra_config={"input_type": "document"}
|
||||
)
|
||||
|
||||
assert config.provider == "voyageai"
|
||||
assert config.model == "voyage-2"
|
||||
assert config.api_key == "test-key"
|
||||
assert config.timeout == 60.0
|
||||
assert config.max_retries == 5
|
||||
assert config.batch_size == 50
|
||||
assert config.extra_config == {"input_type": "document"}
|
||||
|
||||
|
||||
class TestEmbeddingService:
|
||||
"""Test the EmbeddingService class."""
|
||||
|
||||
def test_list_supported_providers(self):
|
||||
"""Test listing supported providers."""
|
||||
providers = EmbeddingService.list_supported_providers()
|
||||
expected_providers = [
|
||||
"openai", "azure", "voyageai", "cohere", "google-generativeai",
|
||||
"amazon-bedrock", "huggingface", "jina", "ollama", "sentence-transformer",
|
||||
"instructor", "watsonx", "custom"
|
||||
]
|
||||
|
||||
assert isinstance(providers, list)
|
||||
assert len(providers) >= 15 # Should have at least 15 providers
|
||||
assert all(provider in providers for provider in expected_providers)
|
||||
|
||||
def test_get_default_api_key(self):
|
||||
"""Test getting default API keys from environment."""
|
||||
service = EmbeddingService.__new__(EmbeddingService) # Create without __init__
|
||||
|
||||
# Test with environment variable set
|
||||
with patch.dict(os.environ, {"OPENAI_API_KEY": "test-openai-key"}):
|
||||
api_key = service._get_default_api_key("openai")
|
||||
assert api_key == "test-openai-key"
|
||||
|
||||
# Test with no environment variable
|
||||
with patch.dict(os.environ, {}, clear=True):
|
||||
api_key = service._get_default_api_key("openai")
|
||||
assert api_key is None
|
||||
|
||||
# Test unknown provider
|
||||
api_key = service._get_default_api_key("unknown-provider")
|
||||
assert api_key is None
|
||||
|
||||
@patch('crewai.rag.embeddings.factory.build_embedder')
|
||||
def test_initialization_success(self, mock_build_embedder):
|
||||
"""Test successful initialization."""
|
||||
# Mock the embedding function
|
||||
mock_embedding_function = Mock()
|
||||
mock_build_embedder.return_value = mock_embedding_function
|
||||
|
||||
service = EmbeddingService(
|
||||
provider="openai",
|
||||
model="text-embedding-3-small",
|
||||
api_key="test-key"
|
||||
)
|
||||
|
||||
assert service.config.provider == "openai"
|
||||
assert service.config.model == "text-embedding-3-small"
|
||||
assert service.config.api_key == "test-key"
|
||||
assert service._embedding_function == mock_embedding_function
|
||||
|
||||
# Verify build_embedder was called with correct config
|
||||
mock_build_embedder.assert_called_once()
|
||||
call_args = mock_build_embedder.call_args[0][0]
|
||||
assert call_args["provider"] == "openai"
|
||||
assert call_args["config"]["api_key"] == "test-key"
|
||||
assert call_args["config"]["model_name"] == "text-embedding-3-small"
|
||||
|
||||
@patch('crewai.rag.embeddings.factory.build_embedder')
|
||||
def test_initialization_import_error(self, mock_build_embedder):
|
||||
"""Test initialization with import error."""
|
||||
mock_build_embedder.side_effect = ImportError("CrewAI not installed")
|
||||
|
||||
with pytest.raises(ImportError, match="CrewAI embedding providers not available"):
|
||||
EmbeddingService(provider="openai", model="test-model", api_key="test-key")
|
||||
|
||||
@patch('crewai.rag.embeddings.factory.build_embedder')
|
||||
def test_embed_text_success(self, mock_build_embedder):
|
||||
"""Test successful text embedding."""
|
||||
# Mock the embedding function
|
||||
mock_embedding_function = Mock()
|
||||
mock_embedding_function.return_value = [[0.1, 0.2, 0.3]]
|
||||
mock_build_embedder.return_value = mock_embedding_function
|
||||
|
||||
service = EmbeddingService(provider="openai", model="test-model", api_key="test-key")
|
||||
|
||||
result = service.embed_text("test text")
|
||||
|
||||
assert result == [0.1, 0.2, 0.3]
|
||||
mock_embedding_function.assert_called_once_with(["test text"])
|
||||
|
||||
@patch('crewai.rag.embeddings.factory.build_embedder')
|
||||
def test_embed_text_empty_input(self, mock_build_embedder):
|
||||
"""Test embedding empty text."""
|
||||
mock_embedding_function = Mock()
|
||||
mock_build_embedder.return_value = mock_embedding_function
|
||||
|
||||
service = EmbeddingService(provider="openai", model="test-model", api_key="test-key")
|
||||
|
||||
result = service.embed_text("")
|
||||
assert result == []
|
||||
|
||||
result = service.embed_text(" ")
|
||||
assert result == []
|
||||
|
||||
# Embedding function should not be called for empty text
|
||||
mock_embedding_function.assert_not_called()
|
||||
|
||||
@patch('crewai.rag.embeddings.factory.build_embedder')
|
||||
def test_embed_batch_success(self, mock_build_embedder):
|
||||
"""Test successful batch embedding."""
|
||||
# Mock the embedding function
|
||||
mock_embedding_function = Mock()
|
||||
mock_embedding_function.return_value = [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]
|
||||
mock_build_embedder.return_value = mock_embedding_function
|
||||
|
||||
service = EmbeddingService(provider="openai", model="test-model", api_key="test-key")
|
||||
|
||||
texts = ["text1", "text2", "text3"]
|
||||
result = service.embed_batch(texts)
|
||||
|
||||
assert result == [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]
|
||||
mock_embedding_function.assert_called_once_with(texts)
|
||||
|
||||
@patch('crewai.rag.embeddings.factory.build_embedder')
|
||||
def test_embed_batch_empty_input(self, mock_build_embedder):
|
||||
"""Test batch embedding with empty input."""
|
||||
mock_embedding_function = Mock()
|
||||
mock_build_embedder.return_value = mock_embedding_function
|
||||
|
||||
service = EmbeddingService(provider="openai", model="test-model", api_key="test-key")
|
||||
|
||||
# Empty list
|
||||
result = service.embed_batch([])
|
||||
assert result == []
|
||||
|
||||
# List with empty strings
|
||||
result = service.embed_batch(["", " ", ""])
|
||||
assert result == []
|
||||
|
||||
# Embedding function should not be called for empty input
|
||||
mock_embedding_function.assert_not_called()
|
||||
|
||||
@patch('crewai.rag.embeddings.factory.build_embedder')
|
||||
def test_validate_connection(self, mock_build_embedder):
|
||||
"""Test connection validation."""
|
||||
# Mock successful embedding
|
||||
mock_embedding_function = Mock()
|
||||
mock_embedding_function.return_value = [[0.1, 0.2, 0.3]]
|
||||
mock_build_embedder.return_value = mock_embedding_function
|
||||
|
||||
service = EmbeddingService(provider="openai", model="test-model", api_key="test-key")
|
||||
|
||||
assert service.validate_connection() is True
|
||||
|
||||
# Mock failed embedding
|
||||
mock_embedding_function.side_effect = Exception("Connection failed")
|
||||
assert service.validate_connection() is False
|
||||
|
||||
@patch('crewai.rag.embeddings.factory.build_embedder')
def test_get_service_info(self, mock_build_embedder):
    """get_service_info reports provider, model, dimension, batch size and status."""
    mock_build_embedder.return_value = Mock(return_value=[[0.1, 0.2, 0.3]])

    service = EmbeddingService(provider="openai", model="test-model", api_key="test-key")
    info = service.get_service_info()

    # Static fields mirror the construction arguments / service defaults.
    for key, expected in {
        "provider": "openai",
        "model": "test-model",
        "embedding_dimension": 3,
        "batch_size": 100,
    }.items():
        assert info[key] == expected

    # Connectivity flag is a real boolean, not merely truthy.
    assert info["is_connected"] is True
def test_create_openai_service(self):
    """create_openai_service wires provider, model and api_key into the config."""
    with patch('crewai.rag.embeddings.factory.build_embedder'):
        service = EmbeddingService.create_openai_service(
            model="text-embedding-3-large",
            api_key="test-key",
        )

    cfg = service.config
    assert cfg.provider == "openai"
    assert cfg.model == "text-embedding-3-large"
    assert cfg.api_key == "test-key"
def test_create_voyage_service(self):
    """create_voyage_service wires provider, model and api_key into the config."""
    with patch('crewai.rag.embeddings.factory.build_embedder'):
        service = EmbeddingService.create_voyage_service(
            model="voyage-large-2",
            api_key="test-key",
        )

    cfg = service.config
    assert cfg.provider == "voyageai"
    assert cfg.model == "voyage-large-2"
    assert cfg.api_key == "test-key"
def test_create_cohere_service(self):
    """create_cohere_service wires provider, model and api_key into the config."""
    with patch('crewai.rag.embeddings.factory.build_embedder'):
        service = EmbeddingService.create_cohere_service(
            model="embed-multilingual-v3.0",
            api_key="test-key",
        )

    cfg = service.config
    assert cfg.provider == "cohere"
    assert cfg.model == "embed-multilingual-v3.0"
    assert cfg.api_key == "test-key"
def test_create_gemini_service(self):
    """create_gemini_service wires provider, model and api_key into the config."""
    with patch('crewai.rag.embeddings.factory.build_embedder'):
        service = EmbeddingService.create_gemini_service(
            model="models/embedding-001",
            api_key="test-key",
        )

    cfg = service.config
    assert cfg.provider == "google-generativeai"
    assert cfg.model == "models/embedding-001"
    assert cfg.api_key == "test-key"
class TestProviderConfigurations:
    """Verify how each provider's settings are mapped onto the embedder-factory payload."""

    @patch('crewai.rag.embeddings.factory.build_embedder')
    def test_openai_config(self, mock_build_embedder):
        """OpenAI: api_key, model_name and extra dimensions reach build_embedder."""
        mock_build_embedder.return_value = Mock()

        service = EmbeddingService(
            provider="openai",
            model="text-embedding-3-small",
            api_key="test-key",
            extra_config={"dimensions": 1024},
        )

        # Inspect the spec dict handed to build_embedder.
        spec = mock_build_embedder.call_args[0][0]
        assert spec["provider"] == "openai"
        cfg = spec["config"]
        assert cfg["api_key"] == "test-key"
        assert cfg["model_name"] == "text-embedding-3-small"
        assert cfg["dimensions"] == 1024

    @patch('crewai.rag.embeddings.factory.build_embedder')
    def test_voyageai_config(self, mock_build_embedder):
        """Voyage AI: model, timeout, retries and input_type reach build_embedder."""
        mock_build_embedder.return_value = Mock()

        service = EmbeddingService(
            provider="voyageai",
            model="voyage-2",
            api_key="test-key",
            timeout=60.0,
            max_retries=5,
            extra_config={"input_type": "document"},
        )

        # Inspect the spec dict handed to build_embedder; note the key is
        # "model" here, unlike the other providers' "model_name".
        spec = mock_build_embedder.call_args[0][0]
        assert spec["provider"] == "voyageai"
        cfg = spec["config"]
        assert cfg["api_key"] == "test-key"
        assert cfg["model"] == "voyage-2"
        assert cfg["timeout"] == 60.0
        assert cfg["max_retries"] == 5
        assert cfg["input_type"] == "document"

    @patch('crewai.rag.embeddings.factory.build_embedder')
    def test_cohere_config(self, mock_build_embedder):
        """Cohere: api_key and model_name reach build_embedder."""
        mock_build_embedder.return_value = Mock()

        service = EmbeddingService(
            provider="cohere",
            model="embed-english-v3.0",
            api_key="test-key",
        )

        spec = mock_build_embedder.call_args[0][0]
        assert spec["provider"] == "cohere"
        cfg = spec["config"]
        assert cfg["api_key"] == "test-key"
        assert cfg["model_name"] == "embed-english-v3.0"

    @patch('crewai.rag.embeddings.factory.build_embedder')
    def test_gemini_config(self, mock_build_embedder):
        """Gemini: api_key and model_name reach build_embedder."""
        mock_build_embedder.return_value = Mock()

        service = EmbeddingService(
            provider="google-generativeai",
            model="models/embedding-001",
            api_key="test-key",
        )

        spec = mock_build_embedder.call_args[0][0]
        assert spec["provider"] == "google-generativeai"
        cfg = spec["config"]
        assert cfg["api_key"] == "test-key"
        assert cfg["model_name"] == "models/embedding-001"
@@ -48,7 +48,7 @@ Repository = "https://github.com/crewAIInc/crewAI"
|
||||
|
||||
[project.optional-dependencies]
|
||||
tools = [
|
||||
"crewai-tools==1.0.0b1",
|
||||
"crewai-tools==1.0.0b2",
|
||||
]
|
||||
embeddings = [
|
||||
"tiktoken~=0.8.0"
|
||||
|
||||
@@ -40,7 +40,7 @@ def _suppress_pydantic_deprecation_warnings() -> None:
|
||||
|
||||
_suppress_pydantic_deprecation_warnings()
|
||||
|
||||
-__version__ = "1.0.0b1"
+__version__ = "1.0.0b2"
|
||||
_telemetry_submitted = False
|
||||
|
||||
|
||||
|
||||
@@ -4,6 +4,7 @@ from dataclasses import field
|
||||
from typing import Literal, cast
|
||||
|
||||
from pydantic.dataclasses import dataclass as pyd_dataclass
|
||||
from qdrant_client.models import VectorParams
|
||||
|
||||
from crewai.rag.config.base import BaseRagConfig
|
||||
from crewai.rag.qdrant.constants import DEFAULT_EMBEDDING_MODEL, DEFAULT_STORAGE_PATH
|
||||
@@ -53,3 +54,4 @@ class QdrantConfig(BaseRagConfig):
|
||||
embedding_function: QdrantEmbeddingFunctionWrapper = field(
|
||||
default_factory=_default_embedding_function
|
||||
)
|
||||
vectors_config: VectorParams | None = field(default=None)
|
||||
|
||||
@@ -4,8 +4,8 @@ import asyncio
|
||||
from typing import TypeGuard
|
||||
from uuid import uuid4
|
||||
|
||||
-from qdrant_client import AsyncQdrantClient  # type: ignore[import-not-found]
+from qdrant_client import (
+    AsyncQdrantClient,  # type: ignore[import-not-found]
+    QdrantClient as SyncQdrantClient,  # type: ignore[import-not-found]
+)
|
||||
from qdrant_client.models import ( # type: ignore[import-not-found]
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
"""CrewAI development tools."""
|
||||
|
||||
-__version__ = "1.0.0b1"
+__version__ = "1.0.0b2"
|
||||
|
||||
2
uv.lock
generated
2
uv.lock
generated
@@ -1124,7 +1124,6 @@ dependencies = [
|
||||
{ name = "python-docx" },
|
||||
{ name = "pytube" },
|
||||
{ name = "requests" },
|
||||
{ name = "stagehand" },
|
||||
{ name = "tiktoken" },
|
||||
{ name = "youtube-transcript-api" },
|
||||
]
|
||||
@@ -1295,7 +1294,6 @@ requires-dist = [
|
||||
{ name = "spider-client", marker = "extra == 'spider-client'", specifier = ">=0.1.25" },
|
||||
{ name = "sqlalchemy", marker = "extra == 'singlestore'", specifier = ">=2.0.40" },
|
||||
{ name = "sqlalchemy", marker = "extra == 'sqlalchemy'", specifier = ">=2.0.35" },
|
||||
{ name = "stagehand", specifier = ">=0.4.1" },
|
||||
{ name = "stagehand", marker = "extra == 'stagehand'", specifier = ">=0.4.1" },
|
||||
{ name = "tavily-python", marker = "extra == 'tavily-python'", specifier = ">=0.5.4" },
|
||||
{ name = "tiktoken", specifier = ">=0.8.0" },
|
||||
|
||||
Reference in New Issue
Block a user