Compare commits

..

2 Commits

Author SHA1 Message Date
Devin AI
4489baa149 fix: resolve lint and type-checker issues
- Fix RET504 lint error by removing unnecessary assignment before return
- Add proper type annotations for embedding_functions dictionary
- Import Callable and Any from typing to resolve mypy errors

Co-Authored-By: João <joao@crewai.com>
2025-09-23 15:47:29 +00:00
Devin AI
1442f3e4b6 fix: add Watson embedding support to factory
- Add Watson to EmbeddingProvider type definition
- Implement _create_watson_embedding_function in factory.py
- Add Watson to embedding_functions dictionary
- Add comprehensive tests for Watson embedding functionality
- Ensure proper error handling for missing IBM Watson dependencies

Fixes #3582

Co-Authored-By: João <joao@crewai.com>
2025-09-23 15:41:56 +00:00
105 changed files with 4195 additions and 7364 deletions

View File

@@ -27,7 +27,7 @@ Follow the steps below to get Crewing! 🚣‍♂️
<Step title="Navigate to your new crew project">
<CodeGroup>
```shell Terminal
cd latest_ai_development
cd latest-ai-development
```
</CodeGroup>
</Step>

View File

@@ -27,7 +27,7 @@ mode: "wide"
<Step title="새로운 crew 프로젝트로 이동하기">
<CodeGroup>
```shell Terminal
cd latest_ai_development
cd latest-ai-development
```
</CodeGroup>
</Step>

View File

@@ -27,7 +27,7 @@ Siga os passos abaixo para começar a tripular! 🚣‍♂️
<Step title="Navegue até o novo projeto da sua tripulação">
<CodeGroup>
```shell Terminal
cd latest_ai_development
cd latest-ai-development
```
</CodeGroup>
</Step>

View File

@@ -9,7 +9,7 @@ authors = [
]
dependencies = [
# Core Dependencies
"pydantic>=2.11.9",
"pydantic>=2.4.2",
"openai>=1.13.3",
"litellm==1.74.9",
"instructor>=1.3.3",
@@ -21,12 +21,13 @@ dependencies = [
"opentelemetry-sdk>=1.30.0",
"opentelemetry-exporter-otlp-proto-http>=1.30.0",
# Data Handling
"chromadb~=1.1.0",
"chromadb>=0.5.23",
"tokenizers>=0.20.3",
"onnxruntime==1.22.0",
"openpyxl>=3.1.5",
"pyvis>=0.3.2",
# Authentication and Security
"python-dotenv>=1.1.1",
"python-dotenv>=1.0.0",
"pyjwt>=2.9.0",
# Configuration and Utils
"click>=8.1.7",
@@ -39,7 +40,6 @@ dependencies = [
"blinker>=1.9.0",
"json5>=0.10.0",
"portalocker==2.7.0",
"pydantic-settings>=2.10.1",
]
[project.urls]
@@ -48,9 +48,7 @@ Documentation = "https://docs.crewai.com"
Repository = "https://github.com/crewAIInc/crewAI"
[project.optional-dependencies]
tools = [
"crewai-tools>=0.74.0",
]
tools = ["crewai-tools~=0.73.0"]
embeddings = [
"tiktoken~=0.8.0"
]
@@ -73,30 +71,24 @@ aisuite = [
qdrant = [
"qdrant-client[fastembed]>=1.14.3",
]
aws = [
"boto3>=1.40.38",
]
watson = [
"ibm-watsonx-ai>=1.3.39",
]
voyageai = [
"voyageai>=0.3.5",
]
[dependency-groups]
dev = [
"ruff>=0.13.1",
"mypy>=1.18.2",
[tool.uv]
dev-dependencies = [
"ruff>=0.12.11",
"mypy>=1.17.1",
"pre-commit>=4.3.0",
"bandit>=1.8.6",
"pytest>=8.4.2",
"pytest-asyncio>=1.2.0",
"pytest-subprocess>=1.5.3",
"pytest-recording>=0.13.4",
"pytest-randomly>=4.0.1",
"pytest-timeout>=2.4.0",
"pytest-xdist>=3.8.0",
"pytest-split>=0.10.0",
"pillow>=10.2.0",
"cairosvg>=2.7.1",
"pytest>=8.0.0",
"python-dotenv>=1.0.0",
"pytest-asyncio>=0.23.7",
"pytest-subprocess>=1.5.2",
"pytest-recording>=0.13.2",
"pytest-randomly>=3.16.0",
"pytest-timeout>=2.3.1",
"pytest-xdist>=3.6.1",
"pytest-split>=0.9.0",
"types-requests==2.32.*",
"types-pyyaml==6.0.*",
"types-regex==2024.11.6.*",

View File

@@ -40,7 +40,7 @@ def _suppress_pydantic_deprecation_warnings() -> None:
_suppress_pydantic_deprecation_warnings()
__version__ = "0.201.0"
__version__ = "0.193.2"
_telemetry_submitted = False

View File

@@ -1,10 +1,17 @@
import shutil
import subprocess
import time
from collections.abc import Callable, Sequence
from typing import (
Any,
Callable,
Dict,
List,
Literal,
Optional,
Sequence,
Tuple,
Type,
Union,
)
from pydantic import Field, InstanceOf, PrivateAttr, model_validator
@@ -12,31 +19,12 @@ from pydantic import Field, InstanceOf, PrivateAttr, model_validator
from crewai.agents import CacheHandler
from crewai.agents.agent_builder.base_agent import BaseAgent
from crewai.agents.crew_agent_executor import CrewAgentExecutor
from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.agent_events import (
AgentExecutionCompletedEvent,
AgentExecutionErrorEvent,
AgentExecutionStartedEvent,
)
from crewai.events.types.knowledge_events import (
KnowledgeQueryCompletedEvent,
KnowledgeQueryFailedEvent,
KnowledgeQueryStartedEvent,
KnowledgeRetrievalCompletedEvent,
KnowledgeRetrievalStartedEvent,
KnowledgeSearchQueryFailedEvent,
)
from crewai.events.types.memory_events import (
MemoryRetrievalCompletedEvent,
MemoryRetrievalStartedEvent,
)
from crewai.knowledge.knowledge import Knowledge
from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource
from crewai.knowledge.utils.knowledge_utils import extract_knowledge_context
from crewai.lite_agent import LiteAgent, LiteAgentOutput
from crewai.llm import BaseLLM
from crewai.memory.contextual.contextual_memory import ContextualMemory
from crewai.rag.embeddings.types import EmbedderConfig
from crewai.security import Fingerprint
from crewai.task import Task
from crewai.tools import BaseTool
@@ -50,6 +38,24 @@ from crewai.utilities.agent_utils import (
)
from crewai.utilities.constants import TRAINED_AGENTS_DATA_FILE, TRAINING_DATA_FILE
from crewai.utilities.converter import generate_model_description
from crewai.events.types.agent_events import (
AgentExecutionCompletedEvent,
AgentExecutionErrorEvent,
AgentExecutionStartedEvent,
)
from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.memory_events import (
MemoryRetrievalStartedEvent,
MemoryRetrievalCompletedEvent,
)
from crewai.events.types.knowledge_events import (
KnowledgeQueryCompletedEvent,
KnowledgeQueryFailedEvent,
KnowledgeQueryStartedEvent,
KnowledgeRetrievalCompletedEvent,
KnowledgeRetrievalStartedEvent,
KnowledgeSearchQueryFailedEvent,
)
from crewai.utilities.llm_utils import create_llm
from crewai.utilities.token_counter_callback import TokenCalcHandler
from crewai.utilities.training_handler import CrewTrainingHandler
@@ -81,36 +87,36 @@ class Agent(BaseAgent):
"""
_times_executed: int = PrivateAttr(default=0)
max_execution_time: int | None = Field(
max_execution_time: Optional[int] = Field(
default=None,
description="Maximum execution time for an agent to execute a task",
)
agent_ops_agent_name: str = None # type: ignore # Incompatible types in assignment (expression has type "None", variable has type "str")
agent_ops_agent_id: str = None # type: ignore # Incompatible types in assignment (expression has type "None", variable has type "str")
step_callback: Any | None = Field(
step_callback: Optional[Any] = Field(
default=None,
description="Callback to be executed after each step of the agent execution.",
)
use_system_prompt: bool | None = Field(
use_system_prompt: Optional[bool] = Field(
default=True,
description="Use system prompt for the agent.",
)
llm: str | InstanceOf[BaseLLM] | Any = Field(
llm: Union[str, InstanceOf[BaseLLM], Any] = Field(
description="Language model that will run the agent.", default=None
)
function_calling_llm: str | InstanceOf[BaseLLM] | Any | None = Field(
function_calling_llm: Optional[Union[str, InstanceOf[BaseLLM], Any]] = Field(
description="Language model that will run the agent.", default=None
)
system_template: str | None = Field(
system_template: Optional[str] = Field(
default=None, description="System format for the agent."
)
prompt_template: str | None = Field(
prompt_template: Optional[str] = Field(
default=None, description="Prompt format for the agent."
)
response_template: str | None = Field(
response_template: Optional[str] = Field(
default=None, description="Response format for the agent."
)
allow_code_execution: bool | None = Field(
allow_code_execution: Optional[bool] = Field(
default=False, description="Enable code execution for the agent."
)
respect_context_window: bool = Field(
@@ -141,31 +147,31 @@ class Agent(BaseAgent):
default=False,
description="Whether the agent should reflect and create a plan before executing a task.",
)
max_reasoning_attempts: int | None = Field(
max_reasoning_attempts: Optional[int] = Field(
default=None,
description="Maximum number of reasoning attempts before executing the task. If None, will try until ready.",
)
embedder: EmbedderConfig | None = Field(
embedder: Optional[Dict[str, Any]] = Field(
default=None,
description="Embedder configuration for the agent.",
)
agent_knowledge_context: str | None = Field(
agent_knowledge_context: Optional[str] = Field(
default=None,
description="Knowledge context for the agent.",
)
crew_knowledge_context: str | None = Field(
crew_knowledge_context: Optional[str] = Field(
default=None,
description="Knowledge context for the crew.",
)
knowledge_search_query: str | None = Field(
knowledge_search_query: Optional[str] = Field(
default=None,
description="Knowledge search query for the agent dynamically generated by the agent.",
)
from_repository: str | None = Field(
from_repository: Optional[str] = Field(
default=None,
description="The Agent's role to be used from your repository.",
)
guardrail: Callable[[Any], tuple[bool, Any]] | str | None = Field(
guardrail: Optional[Union[Callable[[Any], Tuple[bool, Any]], str]] = Field(
default=None,
description="Function or string description of a guardrail to validate agent output",
)
@@ -174,7 +180,7 @@ class Agent(BaseAgent):
)
@model_validator(mode="before")
def validate_from_repository(cls, v): # noqa: N805
def validate_from_repository(cls, v):
if v is not None and (from_repository := v.get("from_repository")):
return load_agent_from_repository(from_repository) | v
return v
@@ -202,7 +208,7 @@ class Agent(BaseAgent):
self.cache_handler = CacheHandler()
self.set_cache_handler(self.cache_handler)
def set_knowledge(self, crew_embedder: EmbedderConfig | None = None):
def set_knowledge(self, crew_embedder: Optional[Dict[str, Any]] = None):
try:
if self.embedder is None and crew_embedder:
self.embedder = crew_embedder
@@ -218,7 +224,7 @@ class Agent(BaseAgent):
)
self.knowledge.add_sources()
except (TypeError, ValueError) as e:
raise ValueError(f"Invalid Knowledge Configuration: {e!s}") from e
raise ValueError(f"Invalid Knowledge Configuration: {str(e)}")
def _is_any_available_memory(self) -> bool:
"""Check if any memory is available."""
@@ -238,8 +244,8 @@ class Agent(BaseAgent):
def execute_task(
self,
task: Task,
context: str | None = None,
tools: list[BaseTool] | None = None,
context: Optional[str] = None,
tools: Optional[List[BaseTool]] = None,
) -> str:
"""Execute a task with the agent.
@@ -272,9 +278,11 @@ class Agent(BaseAgent):
task.description += f"\n\nReasoning Plan:\n{reasoning_output.plan.plan}"
except Exception as e:
if hasattr(self, "_logger"):
self._logger.log("error", f"Error during reasoning process: {e!s}")
self._logger.log(
"error", f"Error during reasoning process: {str(e)}"
)
else:
print(f"Error during reasoning process: {e!s}")
print(f"Error during reasoning process: {str(e)}")
self._inject_date_to_task(task)
@@ -327,7 +335,7 @@ class Agent(BaseAgent):
agent=self,
task=task,
)
memory = contextual_memory.build_context_for_task(task, context) # type: ignore[arg-type]
memory = contextual_memory.build_context_for_task(task, context)
if memory.strip() != "":
task_prompt += self.i18n.slice("memory").format(memory=memory)
@@ -517,14 +525,14 @@ class Agent(BaseAgent):
try:
return future.result(timeout=timeout)
except concurrent.futures.TimeoutError as e:
except concurrent.futures.TimeoutError:
future.cancel()
raise TimeoutError(
f"Task '{task.description}' execution timed out after {timeout} seconds. Consider increasing max_execution_time or optimizing the task."
) from e
)
except Exception as e:
future.cancel()
raise RuntimeError(f"Task execution failed: {e!s}") from e
raise RuntimeError(f"Task execution failed: {str(e)}")
def _execute_without_timeout(self, task_prompt: str, task: Task) -> str:
"""Execute a task without a timeout.
@@ -546,14 +554,14 @@ class Agent(BaseAgent):
)["output"]
def create_agent_executor(
self, tools: list[BaseTool] | None = None, task=None
self, tools: Optional[List[BaseTool]] = None, task=None
) -> None:
"""Create an agent executor for the agent.
Returns:
An instance of the CrewAgentExecutor class.
"""
raw_tools: list[BaseTool] = tools or self.tools or []
raw_tools: List[BaseTool] = tools or self.tools or []
parsed_tools = parse_tools(raw_tools)
prompt = Prompts(
@@ -579,7 +587,7 @@ class Agent(BaseAgent):
agent=self,
crew=self.crew,
tools=parsed_tools,
prompt=prompt, # type: ignore[arg-type]
prompt=prompt,
original_tools=raw_tools,
stop_words=stop_words,
max_iter=self.max_iter,
@@ -595,9 +603,10 @@ class Agent(BaseAgent):
callbacks=[TokenCalcHandler(self._token_process)],
)
def get_delegation_tools(self, agents: list[BaseAgent]):
def get_delegation_tools(self, agents: List[BaseAgent]):
agent_tools = AgentTools(agents=agents)
return agent_tools.tools()
tools = agent_tools.tools()
return tools
def get_multimodal_tools(self) -> Sequence[BaseTool]:
from crewai.tools.agent_tools.add_image_tool import AddImageTool
@@ -645,7 +654,7 @@ class Agent(BaseAgent):
)
return task_prompt
def _render_text_description(self, tools: list[Any]) -> str:
def _render_text_description(self, tools: List[Any]) -> str:
"""Render the tool name and description in plain text.
Output will be in the format of:
@@ -655,13 +664,15 @@ class Agent(BaseAgent):
search: This tool is used for search
calculator: This tool is used for math
"""
return "\n".join(
description = "\n".join(
[
f"Tool name: {tool.name}\nTool description:\n{tool.description}"
for tool in tools
]
)
return description
def _inject_date_to_task(self, task):
"""Inject the current date into the task description if inject_date is enabled."""
if self.inject_date:
@@ -685,13 +696,13 @@ class Agent(BaseAgent):
if not is_valid:
raise ValueError(f"Invalid date format: {self.date_format}")
current_date = datetime.now().strftime(self.date_format)
current_date: str = datetime.now().strftime(self.date_format)
task.description += f"\n\nCurrent Date: {current_date}"
except Exception as e:
if hasattr(self, "_logger"):
self._logger.log("warning", f"Failed to inject date: {e!s}")
self._logger.log("warning", f"Failed to inject date: {str(e)}")
else:
print(f"Warning: Failed to inject date: {e!s}")
print(f"Warning: Failed to inject date: {str(e)}")
def _validate_docker_installation(self) -> None:
"""Check if Docker is installed and running."""
@@ -702,15 +713,15 @@ class Agent(BaseAgent):
try:
subprocess.run(
["/usr/bin/docker", "info"],
["docker", "info"],
check=True,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
)
except subprocess.CalledProcessError as e:
except subprocess.CalledProcessError:
raise RuntimeError(
f"Docker is not running. Please start Docker to use code execution with agent: {self.role}"
) from e
)
def __repr__(self):
return f"Agent(role={self.role}, goal={self.goal}, backstory={self.backstory})"
@@ -785,8 +796,8 @@ class Agent(BaseAgent):
def kickoff(
self,
messages: str | list[dict[str, str]],
response_format: type[Any] | None = None,
messages: Union[str, List[Dict[str, str]]],
response_format: Optional[Type[Any]] = None,
) -> LiteAgentOutput:
"""
Execute the agent with the given messages using a LiteAgent instance.
@@ -825,8 +836,8 @@ class Agent(BaseAgent):
async def kickoff_async(
self,
messages: str | list[dict[str, str]],
response_format: type[Any] | None = None,
messages: Union[str, List[Dict[str, str]]],
response_format: Optional[Type[Any]] = None,
) -> LiteAgentOutput:
"""
Execute the agent asynchronously with the given messages using a LiteAgent instance.

View File

@@ -22,7 +22,6 @@ from crewai.agents.tools_handler import ToolsHandler
from crewai.knowledge.knowledge import Knowledge
from crewai.knowledge.knowledge_config import KnowledgeConfig
from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource
from crewai.rag.embeddings.types import EmbedderConfig
from crewai.security.security_config import SecurityConfig
from crewai.tools.base_tool import BaseTool, Tool
from crewai.utilities import I18N, Logger, RPMController
@@ -360,5 +359,5 @@ class BaseAgent(ABC, BaseModel):
self._rpm_controller = rpm_controller
self.create_agent_executor()
def set_knowledge(self, crew_embedder: EmbedderConfig | None = None):
def set_knowledge(self, crew_embedder: dict[str, Any] | None = None):
pass

View File

@@ -166,13 +166,3 @@ class PlusAPI:
json=payload,
timeout=30,
)
def mark_trace_batch_as_failed(
self, trace_batch_id: str, error_message: str
) -> requests.Response:
return self._make_request(
"PATCH",
f"{self.TRACING_RESOURCE}/batches/{trace_batch_id}",
json={"status": "failed", "failure_reason": error_message},
timeout=30,
)

View File

@@ -5,7 +5,7 @@ description = "{{name}} using crewAI"
authors = [{ name = "Your Name", email = "you@example.com" }]
requires-python = ">=3.10,<3.14"
dependencies = [
"crewai[tools]>=0.201.0,<1.0.0"
"crewai[tools]>=0.193.2,<1.0.0"
]
[project.scripts]

View File

@@ -5,7 +5,7 @@ description = "{{name}} using crewAI"
authors = [{ name = "Your Name", email = "you@example.com" }]
requires-python = ">=3.10,<3.14"
dependencies = [
"crewai[tools]>=0.201.0,<1.0.0",
"crewai[tools]>=0.193.2,<1.0.0",
]
[project.scripts]

View File

@@ -5,7 +5,7 @@ description = "Power up your crews with {{folder_name}}"
readme = "README.md"
requires-python = ">=3.10,<3.14"
dependencies = [
"crewai[tools]>=0.201.0"
"crewai[tools]>=0.193.2"
]
[tool.crewai]

View File

@@ -59,7 +59,6 @@ from crewai.memory.external.external_memory import ExternalMemory
from crewai.memory.long_term.long_term_memory import LongTermMemory
from crewai.memory.short_term.short_term_memory import ShortTermMemory
from crewai.process import Process
from crewai.rag.embeddings.types import EmbedderConfig
from crewai.rag.types import SearchResult
from crewai.security import Fingerprint, SecurityConfig
from crewai.task import Task
@@ -169,7 +168,7 @@ class Crew(FlowTrackable, BaseModel):
default=None,
description="An Instance of the ExternalMemory to be used by the Crew",
)
embedder: EmbedderConfig | None = Field(
embedder: dict | None = Field(
default=None,
description="Configuration for the embedder to be used for the crew.",
)
@@ -623,8 +622,7 @@ class Crew(FlowTrackable, BaseModel):
training_data=training_data, agent_id=str(agent.id)
)
CrewTrainingHandler(filename).save_trained_data(
agent_id=str(agent.role),
trained_data=result.model_dump(), # type: ignore[arg-type]
agent_id=str(agent.role), trained_data=result.model_dump()
)
crewai_event_bus.emit(
@@ -1059,10 +1057,7 @@ class Crew(FlowTrackable, BaseModel):
def _log_task_start(self, task: Task, role: str = "None"):
if self.output_log_file:
self._file_handler.log(
task_name=task.name, # type: ignore[arg-type]
task=task.description,
agent=role,
status="started",
task_name=task.name, task=task.description, agent=role, status="started"
)
def _update_manager_tools(
@@ -1091,7 +1086,7 @@ class Crew(FlowTrackable, BaseModel):
role = task.agent.role if task.agent is not None else "None"
if self.output_log_file:
self._file_handler.log(
task_name=task.name, # type: ignore[arg-type]
task_name=task.name,
task=task.description,
agent=role,
status="completed",

View File

@@ -200,9 +200,6 @@ class TraceBatchManager:
if self.event_buffer:
events_sent_to_backend_status = self._send_events_to_backend()
if events_sent_to_backend_status == 500:
self.plus_api.mark_trace_batch_as_failed(
self.trace_batch_id, "Error sending events to backend"
)
return None
self._finalize_backend_batch()
@@ -276,13 +273,10 @@ class TraceBatchManager:
logger.error(
f"❌ Failed to finalize trace batch: {response.status_code} - {response.text}"
)
self.plus_api.mark_trace_batch_as_failed(
self.trace_batch_id, response.text
)
except Exception as e:
logger.error(f"❌ Error finalizing trace batch: {e}")
self.plus_api.mark_trace_batch_as_failed(self.trace_batch_id, str(e))
# TODO: send error to app marking as failed
def _cleanup_batch_data(self):
"""Clean up batch data after successful finalization to free memory"""

View File

@@ -1,10 +1,10 @@
import os
from typing import Any
from pydantic import BaseModel, ConfigDict, Field
from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource
from crewai.knowledge.storage.knowledge_storage import KnowledgeStorage
from crewai.rag.embeddings.types import EmbedderConfig
from crewai.rag.types import SearchResult
os.environ["TOKENIZERS_PARALLELISM"] = "false" # removes logging from fastembed
@@ -16,20 +16,20 @@ class Knowledge(BaseModel):
Args:
sources: list[BaseKnowledgeSource] = Field(default_factory=list)
storage: KnowledgeStorage | None = Field(default=None)
embedder: EmbedderConfig | None = None
embedder: dict[str, Any] | None = None
"""
sources: list[BaseKnowledgeSource] = Field(default_factory=list)
model_config = ConfigDict(arbitrary_types_allowed=True)
storage: KnowledgeStorage | None = Field(default=None)
embedder: EmbedderConfig | None = None
embedder: dict[str, Any] | None = None
collection_name: str | None = None
def __init__(
self,
collection_name: str,
sources: list[BaseKnowledgeSource],
embedder: EmbedderConfig | None = None,
embedder: dict[str, Any] | None = None,
storage: KnowledgeStorage | None = None,
**data,
):

View File

@@ -8,9 +8,7 @@ from crewai.rag.chromadb.config import ChromaDBConfig
from crewai.rag.chromadb.types import ChromaEmbeddingFunctionWrapper
from crewai.rag.config.utils import get_rag_client
from crewai.rag.core.base_client import BaseClient
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
from crewai.rag.embeddings.factory import build_embedder
from crewai.rag.embeddings.types import ProviderSpec
from crewai.rag.embeddings.factory import get_embedding_function
from crewai.rag.factory import create_client
from crewai.rag.types import BaseRecord, SearchResult
from crewai.utilities.logger import Logger
@@ -24,10 +22,7 @@ class KnowledgeStorage(BaseKnowledgeStorage):
def __init__(
self,
embedder: ProviderSpec
| BaseEmbeddingsProvider
| type[BaseEmbeddingsProvider]
| None = None,
embedder: dict[str, Any] | None = None,
collection_name: str | None = None,
) -> None:
self.collection_name = collection_name
@@ -40,7 +35,7 @@ class KnowledgeStorage(BaseKnowledgeStorage):
)
if embedder:
embedding_function = build_embedder(embedder) # type: ignore[arg-type]
embedding_function = get_embedding_function(embedder)
config = ChromaDBConfig(
embedding_function=cast(
ChromaEmbeddingFunctionWrapper, embedding_function

View File

@@ -27,10 +27,7 @@ class EntityMemory(Memory):
_memory_provider: str | None = PrivateAttr()
def __init__(self, crew=None, embedder_config=None, storage=None, path=None):
memory_provider = None
if embedder_config and isinstance(embedder_config, dict):
memory_provider = embedder_config.get("provider")
memory_provider = embedder_config.get("provider") if embedder_config else None
if memory_provider == "mem0":
try:
from crewai.memory.storage.mem0_storage import Mem0Storage
@@ -38,11 +35,7 @@ class EntityMemory(Memory):
raise ImportError(
"Mem0 is not installed. Please install it with `pip install mem0ai`."
) from e
config = (
embedder_config.get("config")
if embedder_config and isinstance(embedder_config, dict)
else None
)
config = embedder_config.get("config") if embedder_config else None
storage = Mem0Storage(type="short_term", crew=crew, config=config)
else:
storage = (

View File

@@ -13,7 +13,6 @@ from crewai.events.types.memory_events import (
from crewai.memory.external.external_memory_item import ExternalMemoryItem
from crewai.memory.memory import Memory
from crewai.memory.storage.interface import Storage
from crewai.rag.embeddings.types import ProviderSpec
if TYPE_CHECKING:
from crewai.memory.storage.mem0_storage import Mem0Storage
@@ -36,9 +35,7 @@ class ExternalMemory(Memory):
}
@staticmethod
def create_storage(
crew: Any, embedder_config: dict[str, Any] | ProviderSpec | None
) -> Storage:
def create_storage(crew: Any, embedder_config: dict[str, Any] | None) -> Storage:
if not embedder_config:
raise ValueError("embedder_config is required")
@@ -162,6 +159,6 @@ class ExternalMemory(Memory):
super().set_crew(crew)
if not self.storage:
self.storage = self.create_storage(crew, self.embedder_config) # type: ignore[arg-type]
self.storage = self.create_storage(crew, self.embedder_config)
return self

View File

@@ -2,8 +2,6 @@ from typing import TYPE_CHECKING, Any, Optional
from pydantic import BaseModel
from crewai.rag.embeddings.types import EmbedderConfig
if TYPE_CHECKING:
from crewai.agent import Agent
from crewai.task import Task
@@ -14,7 +12,7 @@ class Memory(BaseModel):
Base class for memory, now supporting agent tags and generic metadata.
"""
embedder_config: EmbedderConfig | dict[str, Any] | None = None
embedder_config: dict[str, Any] | None = None
crew: Any | None = None
storage: Any

View File

@@ -29,10 +29,7 @@ class ShortTermMemory(Memory):
_memory_provider: str | None = PrivateAttr()
def __init__(self, crew=None, embedder_config=None, storage=None, path=None):
memory_provider = None
if embedder_config and isinstance(embedder_config, dict):
memory_provider = embedder_config.get("provider")
memory_provider = embedder_config.get("provider") if embedder_config else None
if memory_provider == "mem0":
try:
from crewai.memory.storage.mem0_storage import Mem0Storage
@@ -40,11 +37,7 @@ class ShortTermMemory(Memory):
raise ImportError(
"Mem0 is not installed. Please install it with `pip install mem0ai`."
) from e
config = (
embedder_config.get("config")
if embedder_config and isinstance(embedder_config, dict)
else None
)
config = embedder_config.get("config") if embedder_config else None
storage = Mem0Storage(type="short_term", crew=crew, config=config)
else:
storage = (

View File

@@ -7,9 +7,7 @@ from crewai.rag.chromadb.config import ChromaDBConfig
from crewai.rag.chromadb.types import ChromaEmbeddingFunctionWrapper
from crewai.rag.config.utils import get_rag_client
from crewai.rag.core.base_client import BaseClient
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
from crewai.rag.embeddings.factory import build_embedder
from crewai.rag.embeddings.types import ProviderSpec
from crewai.rag.embeddings.factory import get_embedding_function
from crewai.rag.factory import create_client
from crewai.rag.storage.base_rag_storage import BaseRAGStorage
from crewai.rag.types import BaseRecord
@@ -27,7 +25,7 @@ class RAGStorage(BaseRAGStorage):
self,
type: str,
allow_reset: bool = True,
embedder_config: ProviderSpec | BaseEmbeddingsProvider | None = None,
embedder_config: dict[str, Any] | None = None,
crew: Any = None,
path: str | None = None,
) -> None:
@@ -51,46 +49,12 @@ class RAGStorage(BaseRAGStorage):
)
if self.embedder_config:
embedding_function = build_embedder(self.embedder_config)
try:
_ = embedding_function(["test"])
except Exception as e:
provider = (
self.embedder_config["provider"]
if isinstance(self.embedder_config, dict)
else self.embedder_config.__class__.__name__.replace(
"Provider", ""
).lower()
)
raise ValueError(
f"Failed to initialize embedder. Please check your configuration or connection.\n"
f"Provider: {provider}\n"
f"Error: {e}"
) from e
batch_size = None
if (
isinstance(self.embedder_config, dict)
and "config" in self.embedder_config
):
nested_config = self.embedder_config["config"]
if isinstance(nested_config, dict):
batch_size = nested_config.get("batch_size")
if batch_size is not None:
config = ChromaDBConfig(
embedding_function=cast(
ChromaEmbeddingFunctionWrapper, embedding_function
),
batch_size=cast(int, batch_size),
)
else:
config = ChromaDBConfig(
embedding_function=cast(
ChromaEmbeddingFunctionWrapper, embedding_function
)
embedding_function = get_embedding_function(self.embedder_config)
config = ChromaDBConfig(
embedding_function=cast(
ChromaEmbeddingFunctionWrapper, embedding_function
)
)
self._client = create_client(config)
def _get_client(self) -> BaseClient:
@@ -131,26 +95,7 @@ class RAGStorage(BaseRAGStorage):
if metadata:
document["metadata"] = metadata
batch_size = None
if (
self.embedder_config
and isinstance(self.embedder_config, dict)
and "config" in self.embedder_config
):
nested_config = self.embedder_config["config"]
if isinstance(nested_config, dict):
batch_size = nested_config.get("batch_size")
if batch_size is not None:
client.add_documents(
collection_name=collection_name,
documents=[document],
batch_size=cast(int, batch_size),
)
else:
client.add_documents(
collection_name=collection_name, documents=[document]
)
client.add_documents(collection_name=collection_name, documents=[document])
except Exception as e:
logging.error(
f"Error during {self.type} save: {e!s}\n{traceback.format_exc()}"

View File

@@ -17,7 +17,6 @@ from crewai.rag.chromadb.types import (
ChromaDBCollectionSearchParams,
)
from crewai.rag.chromadb.utils import (
_create_batch_slice,
_extract_search_params,
_is_async_client,
_is_sync_client,
@@ -53,7 +52,6 @@ class ChromaDBClient(BaseClient):
embedding_function: ChromaEmbeddingFunction,
default_limit: int = 5,
default_score_threshold: float = 0.6,
default_batch_size: int = 100,
) -> None:
"""Initialize ChromaDBClient with client and embedding function.
@@ -62,13 +60,11 @@ class ChromaDBClient(BaseClient):
embedding_function: Embedding function for text to vector conversion.
default_limit: Default number of results to return in searches.
default_score_threshold: Default minimum score for search results.
default_batch_size: Default batch size for adding documents.
"""
self.client = client
self.embedding_function = embedding_function
self.default_limit = default_limit
self.default_score_threshold = default_score_threshold
self.default_batch_size = default_batch_size
def create_collection(
self, **kwargs: Unpack[ChromaDBCollectionCreateParams]
@@ -295,7 +291,6 @@ class ChromaDBClient(BaseClient):
- content: The text content (required)
- doc_id: Optional unique identifier (auto-generated if missing)
- metadata: Optional metadata dictionary
batch_size: Optional batch size for processing documents (default: 100)
Raises:
TypeError: If AsyncClientAPI is used instead of ClientAPI for sync operations.
@@ -310,7 +305,6 @@ class ChromaDBClient(BaseClient):
collection_name = kwargs["collection_name"]
documents = kwargs["documents"]
batch_size = kwargs.get("batch_size", self.default_batch_size)
if not documents:
raise ValueError("Documents list cannot be empty")
@@ -321,17 +315,13 @@ class ChromaDBClient(BaseClient):
)
prepared = _prepare_documents_for_chromadb(documents)
for i in range(0, len(prepared.ids), batch_size):
batch_ids, batch_texts, batch_metadatas = _create_batch_slice(
prepared=prepared, start_index=i, batch_size=batch_size
)
collection.upsert(
ids=batch_ids,
documents=batch_texts,
metadatas=batch_metadatas,
)
# ChromaDB doesn't accept empty metadata dicts, so pass None if all are empty
metadatas = prepared.metadatas if any(m for m in prepared.metadatas) else None
collection.upsert(
ids=prepared.ids,
documents=prepared.texts,
metadatas=metadatas,
)
async def aadd_documents(self, **kwargs: Unpack[BaseCollectionAddParams]) -> None:
"""Add documents with their embeddings to a collection asynchronously.
@@ -345,7 +335,6 @@ class ChromaDBClient(BaseClient):
- content: The text content (required)
- doc_id: Optional unique identifier (auto-generated if missing)
- metadata: Optional metadata dictionary
batch_size: Optional batch size for processing documents (default: 100)
Raises:
TypeError: If ClientAPI is used instead of AsyncClientAPI for async operations.
@@ -360,7 +349,6 @@ class ChromaDBClient(BaseClient):
collection_name = kwargs["collection_name"]
documents = kwargs["documents"]
batch_size = kwargs.get("batch_size", self.default_batch_size)
if not documents:
raise ValueError("Documents list cannot be empty")
@@ -370,17 +358,13 @@ class ChromaDBClient(BaseClient):
embedding_function=self.embedding_function,
)
prepared = _prepare_documents_for_chromadb(documents)
for i in range(0, len(prepared.ids), batch_size):
batch_ids, batch_texts, batch_metadatas = _create_batch_slice(
prepared=prepared, start_index=i, batch_size=batch_size
)
await collection.upsert(
ids=batch_ids,
documents=batch_texts,
metadatas=batch_metadatas,
)
# ChromaDB doesn't accept empty metadata dicts, so pass None if all are empty
metadatas = prepared.metadatas if any(m for m in prepared.metadatas) else None
await collection.upsert(
ids=prepared.ids,
documents=prepared.texts,
metadatas=metadatas,
)
def search(
self, **kwargs: Unpack[ChromaDBCollectionSearchParams]

View File

@@ -41,5 +41,4 @@ def create_client(config: ChromaDBConfig) -> ChromaDBClient:
embedding_function=config.embedding_function,
default_limit=config.limit,
default_score_threshold=config.score_threshold,
default_batch_size=config.batch_size,
)

View File

@@ -1,7 +1,6 @@
"""Utility functions for ChromaDB client implementation."""
import hashlib
import json
from collections.abc import Mapping
from typing import Literal, TypeGuard, cast
@@ -10,6 +9,7 @@ from chromadb.api.models.AsyncCollection import AsyncCollection
from chromadb.api.models.Collection import Collection
from chromadb.api.types import (
Include,
IncludeEnum,
QueryResult,
)
@@ -72,15 +72,7 @@ def _prepare_documents_for_chromadb(
if "doc_id" in doc:
ids.append(doc["doc_id"])
else:
content_for_hash = doc["content"]
metadata = doc.get("metadata")
if metadata:
metadata_str = json.dumps(metadata, sort_keys=True)
content_for_hash = f"{content_for_hash}|{metadata_str}"
content_hash = hashlib.blake2b(
content_for_hash.encode(), digest_size=32
).hexdigest()
content_hash = hashlib.sha256(doc["content"].encode()).hexdigest()[:16]
ids.append(content_hash)
texts.append(doc["content"])
@@ -96,32 +88,6 @@ def _prepare_documents_for_chromadb(
return PreparedDocuments(ids, texts, metadatas)
def _create_batch_slice(
prepared: PreparedDocuments, start_index: int, batch_size: int
) -> tuple[list[str], list[str], list[Mapping[str, str | int | float | bool]] | None]:
"""Create a batch slice from prepared documents.
Args:
prepared: PreparedDocuments containing ids, texts, and metadatas.
start_index: Starting index for the batch.
batch_size: Size of the batch.
Returns:
Tuple of (batch_ids, batch_texts, batch_metadatas).
"""
batch_end = min(start_index + batch_size, len(prepared.ids))
batch_ids = prepared.ids[start_index:batch_end]
batch_texts = prepared.texts[start_index:batch_end]
batch_metadatas = (
prepared.metadatas[start_index:batch_end] if prepared.metadatas else None
)
if batch_metadatas and not any(m for m in batch_metadatas):
batch_metadatas = None
return batch_ids, batch_texts, batch_metadatas
def _extract_search_params(
kwargs: ChromaDBCollectionSearchParams,
) -> ExtractedSearchParams:
@@ -141,12 +107,9 @@ def _extract_search_params(
score_threshold=kwargs.get("score_threshold"),
where=kwargs.get("where"),
where_document=kwargs.get("where_document"),
include=cast(
Include,
kwargs.get(
"include",
["metadatas", "documents", "distances"],
),
include=kwargs.get(
"include",
[IncludeEnum.metadatas, IncludeEnum.documents, IncludeEnum.distances],
),
)
@@ -195,7 +158,7 @@ def _convert_chromadb_results_to_search_results(
"""
search_results: list[SearchResult] = []
include_strings = list(include) if include else []
include_strings = [item.value for item in include] if include else []
ids = results["ids"][0] if results.get("ids") else []

View File

@@ -16,4 +16,3 @@ class BaseRagConfig:
embedding_function: Any | None = field(default=None)
limit: int = field(default=5)
score_threshold: float = field(default=0.6)
batch_size: int = field(default=100)

View File

@@ -29,7 +29,7 @@ class BaseCollectionParams(TypedDict):
]
class BaseCollectionAddParams(BaseCollectionParams, total=False):
class BaseCollectionAddParams(BaseCollectionParams):
"""Parameters for adding documents to a collection.
Extends BaseCollectionParams with document-specific fields.
@@ -37,11 +37,9 @@ class BaseCollectionAddParams(BaseCollectionParams, total=False):
Attributes:
collection_name: The name of the collection to add documents to.
documents: List of BaseRecord dictionaries containing document data.
batch_size: Optional batch size for processing documents to avoid token limits.
"""
documents: Required[list[BaseRecord]]
batch_size: int
documents: list[BaseRecord]
class BaseCollectionSearchParams(BaseCollectionParams, total=False):

View File

@@ -1,142 +0,0 @@
"""Base embeddings callable utilities for RAG systems."""
from typing import Protocol, TypeVar, runtime_checkable
import numpy as np
from crewai.rag.core.types import (
Embeddable,
Embedding,
Embeddings,
PyEmbedding,
)
T = TypeVar("T")
D = TypeVar("D", bound=Embeddable, contravariant=True)
def normalize_embeddings(
target: Embedding | list[Embedding] | PyEmbedding | list[PyEmbedding],
) -> Embeddings | None:
"""Normalize various embedding formats to a standard list of numpy arrays.
Args:
target: Input embeddings in various formats (list of floats, list of lists,
numpy array, or list of numpy arrays).
Returns:
Normalized embeddings as a list of numpy arrays, or None if input is None.
Raises:
ValueError: If embeddings are empty or in an unsupported format.
"""
if isinstance(target, np.ndarray):
if target.ndim == 1:
return [target.astype(np.float32)]
if target.ndim == 2:
return [row.astype(np.float32) for row in target]
raise ValueError(f"Unsupported numpy array shape: {target.shape}")
first = target[0]
if isinstance(first, (int, float)) and not isinstance(first, bool):
return [np.array(target, dtype=np.float32)]
if isinstance(first, list):
return [np.array(emb, dtype=np.float32) for emb in target]
if isinstance(first, np.ndarray):
return [emb.astype(np.float32) for emb in target] # type: ignore[union-attr]
raise ValueError(f"Unsupported embeddings format: {type(first)}")
def maybe_cast_one_to_many(target: T | list[T] | None) -> list[T] | None:
"""Cast a single item to a list if needed.
Args:
target: A single item or list of items.
Returns:
A list of items or None if input is None.
"""
if target is None:
return None
return target if isinstance(target, list) else [target]
def validate_embeddings(embeddings: Embeddings) -> Embeddings:
"""Validate embeddings format and content.
Args:
embeddings: List of numpy arrays to validate.
Returns:
Validated embeddings.
Raises:
ValueError: If embeddings format or content is invalid.
"""
if not isinstance(embeddings, list):
raise ValueError(
f"Expected embeddings to be a list, got {type(embeddings).__name__}"
)
if len(embeddings) == 0:
raise ValueError(
f"Expected embeddings to be a list with at least one item, got {len(embeddings)} embeddings"
)
if not all(isinstance(e, np.ndarray) for e in embeddings):
raise ValueError(
"Expected each embedding in the embeddings to be a numpy array"
)
for i, embedding in enumerate(embeddings):
if embedding.ndim == 0:
raise ValueError(
f"Expected a 1-dimensional array, got a 0-dimensional array {embedding}"
)
if embedding.size == 0:
raise ValueError(
f"Expected each embedding to be a 1-dimensional numpy array with at least 1 value. "
f"Got an array with no values at position {i}"
)
if not all(
isinstance(value, (np.integer, float, np.floating))
and not isinstance(value, bool)
for value in embedding
):
raise ValueError(
f"Expected embedding to contain numeric values, got non-numeric values at position {i}"
)
return embeddings
@runtime_checkable
class EmbeddingFunction(Protocol[D]):
"""Protocol for embedding functions.
Embedding functions convert input data (documents or images) into vector embeddings.
"""
def __call__(self, input: D) -> Embeddings:
"""Convert input data to embeddings.
Args:
input: Input data to embed (documents or images).
Returns:
List of numpy arrays representing the embeddings.
"""
...
def __init_subclass__(cls) -> None:
"""Wrap __call__ method to normalize and validate embeddings."""
super().__init_subclass__()
original_call = cls.__call__
def wrapped_call(self: EmbeddingFunction[D], input: D) -> Embeddings:
result = original_call(self, input)
if result is None:
raise ValueError("Embedding function returned None")
normalized = normalize_embeddings(result)
if normalized is None:
raise ValueError("Normalization returned None for non-None input")
return validate_embeddings(normalized)
cls.__call__ = wrapped_call # type: ignore[method-assign]

View File

@@ -1,23 +0,0 @@
"""Base class for embedding providers."""
from typing import Generic, TypeVar
from pydantic import Field
from pydantic_settings import BaseSettings, SettingsConfigDict
from crewai.rag.core.base_embeddings_callable import EmbeddingFunction
T = TypeVar("T", bound=EmbeddingFunction)
class BaseEmbeddingsProvider(BaseSettings, Generic[T]):
"""Abstract base class for embedding providers.
This class provides a common interface for dynamically loading and building
embedding functions from various providers.
"""
model_config = SettingsConfigDict(extra="allow", populate_by_name=True)
embedding_callable: type[T] = Field(
..., description="The embedding function class to use"
)

View File

@@ -1,28 +0,0 @@
"""Core type definitions for RAG systems."""
from collections.abc import Sequence
from typing import TypeVar
import numpy as np
from numpy import floating, integer, number
from numpy.typing import NDArray
T = TypeVar("T")
PyEmbedding = Sequence[float] | Sequence[int]
PyEmbeddings = list[PyEmbedding]
Embedding = NDArray[np.int32 | np.float32]
Embeddings = list[Embedding]
Documents = list[str]
Images = list[np.ndarray]
Embeddable = Documents | Images
ScalarType = TypeVar("ScalarType", bound=np.generic)
IntegerType = TypeVar("IntegerType", bound=integer)
FloatingType = TypeVar("FloatingType", bound=floating)
NumberType = TypeVar("NumberType", bound=number)
DType32 = TypeVar("DType32", np.int32, np.float32)
DType64 = TypeVar("DType64", np.int64, np.float64)
DTypeCommon = TypeVar("DTypeCommon", np.int32, np.int64, np.float32, np.float64)

View File

@@ -0,0 +1,245 @@
import os
from typing import Any, cast
from chromadb import Documents, EmbeddingFunction, Embeddings
from chromadb.api.types import validate_embedding_function
class EmbeddingConfigurator:
def __init__(self):
self.embedding_functions = {
"openai": self._configure_openai,
"azure": self._configure_azure,
"ollama": self._configure_ollama,
"vertexai": self._configure_vertexai,
"google": self._configure_google,
"cohere": self._configure_cohere,
"voyageai": self._configure_voyageai,
"bedrock": self._configure_bedrock,
"huggingface": self._configure_huggingface,
"watson": self._configure_watson,
"custom": self._configure_custom,
}
def configure_embedder(
self,
embedder_config: dict[str, Any] | None = None,
) -> EmbeddingFunction:
"""Configures and returns an embedding function based on the provided config."""
if embedder_config is None:
return self._create_default_embedding_function()
provider = embedder_config.get("provider")
config = embedder_config.get("config", {})
model_name = config.get("model") if provider != "custom" else None
if provider not in self.embedding_functions:
raise Exception(
f"Unsupported embedding provider: {provider}, supported providers: {list(self.embedding_functions.keys())}"
)
try:
embedding_function = self.embedding_functions[provider]
except ImportError as e:
missing_package = str(e).split()[-1]
raise ImportError(
f"{missing_package} is not installed. Please install it with: pip install {missing_package}"
) from e
return (
embedding_function(config)
if provider == "custom"
else embedding_function(config, model_name)
)
@staticmethod
def _create_default_embedding_function():
from chromadb.utils.embedding_functions.openai_embedding_function import (
OpenAIEmbeddingFunction,
)
return OpenAIEmbeddingFunction(
api_key=os.getenv("OPENAI_API_KEY"), model_name="text-embedding-3-small"
)
@staticmethod
def _configure_openai(config, model_name):
from chromadb.utils.embedding_functions.openai_embedding_function import (
OpenAIEmbeddingFunction,
)
return OpenAIEmbeddingFunction(
api_key=config.get("api_key") or os.getenv("OPENAI_API_KEY"),
model_name=model_name,
api_base=config.get("api_base", None),
api_type=config.get("api_type", None),
api_version=config.get("api_version", None),
default_headers=config.get("default_headers", None),
dimensions=config.get("dimensions", None),
deployment_id=config.get("deployment_id", None),
organization_id=config.get("organization_id", None),
)
@staticmethod
def _configure_azure(config, model_name):
from chromadb.utils.embedding_functions.openai_embedding_function import (
OpenAIEmbeddingFunction,
)
return OpenAIEmbeddingFunction(
api_key=config.get("api_key"),
api_base=config.get("api_base"),
api_type=config.get("api_type", "azure"),
api_version=config.get("api_version"),
model_name=model_name,
default_headers=config.get("default_headers"),
dimensions=config.get("dimensions"),
deployment_id=config.get("deployment_id"),
organization_id=config.get("organization_id"),
)
@staticmethod
def _configure_ollama(config, model_name):
from chromadb.utils.embedding_functions.ollama_embedding_function import (
OllamaEmbeddingFunction,
)
return OllamaEmbeddingFunction(
url=config.get("url", "http://localhost:11434/api/embeddings"),
model_name=model_name,
)
@staticmethod
def _configure_vertexai(config, model_name):
from chromadb.utils.embedding_functions.google_embedding_function import (
GoogleVertexEmbeddingFunction,
)
return GoogleVertexEmbeddingFunction(
model_name=model_name,
api_key=config.get("api_key"),
project_id=config.get("project_id"),
region=config.get("region"),
)
@staticmethod
def _configure_google(config, model_name):
from chromadb.utils.embedding_functions.google_embedding_function import (
GoogleGenerativeAiEmbeddingFunction,
)
return GoogleGenerativeAiEmbeddingFunction(
model_name=model_name,
api_key=config.get("api_key"),
task_type=config.get("task_type"),
)
@staticmethod
def _configure_cohere(config, model_name):
from chromadb.utils.embedding_functions.cohere_embedding_function import (
CohereEmbeddingFunction,
)
return CohereEmbeddingFunction(
model_name=model_name,
api_key=config.get("api_key"),
)
@staticmethod
def _configure_voyageai(config, model_name):
from chromadb.utils.embedding_functions.voyageai_embedding_function import ( # type: ignore[import-not-found]
VoyageAIEmbeddingFunction,
)
return VoyageAIEmbeddingFunction(
model_name=model_name,
api_key=config.get("api_key"),
)
@staticmethod
def _configure_bedrock(config, model_name):
from chromadb.utils.embedding_functions.amazon_bedrock_embedding_function import (
AmazonBedrockEmbeddingFunction,
)
# Allow custom model_name override with backwards compatibility
kwargs = {"session": config.get("session")}
if model_name is not None:
kwargs["model_name"] = model_name
return AmazonBedrockEmbeddingFunction(**kwargs)
@staticmethod
def _configure_huggingface(config, model_name):
from chromadb.utils.embedding_functions.huggingface_embedding_function import (
HuggingFaceEmbeddingServer,
)
return HuggingFaceEmbeddingServer(
url=config.get("api_url"),
)
@staticmethod
def _configure_watson(config, model_name):
try:
import ibm_watsonx_ai.foundation_models as watson_models # type: ignore[import-not-found]
from ibm_watsonx_ai import Credentials # type: ignore[import-not-found]
from ibm_watsonx_ai.metanames import ( # type: ignore[import-not-found]
EmbedTextParamsMetaNames as EmbedParams,
)
except ImportError as e:
raise ImportError(
"IBM Watson dependencies are not installed. Please install them to use Watson embedding."
) from e
class WatsonEmbeddingFunction(EmbeddingFunction):
def __call__(self, input: Documents) -> Embeddings:
if isinstance(input, str):
input = [input]
embed_params = {
EmbedParams.TRUNCATE_INPUT_TOKENS: 3,
EmbedParams.RETURN_OPTIONS: {"input_text": True},
}
embedding = watson_models.Embeddings(
model_id=config.get("model"),
params=embed_params,
credentials=Credentials(
api_key=config.get("api_key"), url=config.get("api_url")
),
project_id=config.get("project_id"),
)
try:
embeddings = embedding.embed_documents(input)
return cast(Embeddings, embeddings)
except Exception as e:
print("Error during Watson embedding:", e)
raise e
return WatsonEmbeddingFunction()
@staticmethod
def _configure_custom(config):
custom_embedder = config.get("embedder")
if isinstance(custom_embedder, EmbeddingFunction):
try:
validate_embedding_function(custom_embedder)
return custom_embedder
except Exception as e:
raise ValueError(f"Invalid custom embedding function: {e!s}") from e
elif callable(custom_embedder):
try:
instance = custom_embedder()
if isinstance(instance, EmbeddingFunction):
validate_embedding_function(instance)
return instance
raise ValueError(
"Custom embedder does not create an EmbeddingFunction instance"
)
except Exception as e:
raise ValueError(f"Error instantiating custom embedder: {e!s}") from e
else:
raise ValueError(
"Custom embedder must be an instance of `EmbeddingFunction` or a callable that creates one"
)

View File

@@ -1,363 +1,204 @@
"""Factory functions for creating embedding providers and functions."""
"""Minimal embedding function factory for CrewAI."""
from __future__ import annotations
import os
from typing import Any, Callable
from typing import TYPE_CHECKING, TypeVar, overload
from chromadb import EmbeddingFunction
from chromadb.utils.embedding_functions.amazon_bedrock_embedding_function import (
AmazonBedrockEmbeddingFunction,
)
from chromadb.utils.embedding_functions.cohere_embedding_function import (
CohereEmbeddingFunction,
)
from chromadb.utils.embedding_functions.google_embedding_function import (
GoogleGenerativeAiEmbeddingFunction,
GooglePalmEmbeddingFunction,
GoogleVertexEmbeddingFunction,
)
from chromadb.utils.embedding_functions.huggingface_embedding_function import (
HuggingFaceEmbeddingFunction,
)
from chromadb.utils.embedding_functions.instructor_embedding_function import (
InstructorEmbeddingFunction,
)
from chromadb.utils.embedding_functions.jina_embedding_function import (
JinaEmbeddingFunction,
)
from chromadb.utils.embedding_functions.ollama_embedding_function import (
OllamaEmbeddingFunction,
)
from chromadb.utils.embedding_functions.onnx_mini_lm_l6_v2 import ONNXMiniLM_L6_V2
from chromadb.utils.embedding_functions.open_clip_embedding_function import (
OpenCLIPEmbeddingFunction,
)
from chromadb.utils.embedding_functions.openai_embedding_function import (
OpenAIEmbeddingFunction,
)
from chromadb.utils.embedding_functions.roboflow_embedding_function import (
RoboflowEmbeddingFunction,
)
from chromadb.utils.embedding_functions.sentence_transformer_embedding_function import (
SentenceTransformerEmbeddingFunction,
)
from chromadb.utils.embedding_functions.text2vec_embedding_function import (
Text2VecEmbeddingFunction,
)
from crewai.rag.core.base_embeddings_callable import EmbeddingFunction
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
from crewai.utilities.import_utils import import_and_validate_definition
if TYPE_CHECKING:
from chromadb.utils.embedding_functions.amazon_bedrock_embedding_function import (
AmazonBedrockEmbeddingFunction,
)
from chromadb.utils.embedding_functions.cohere_embedding_function import (
CohereEmbeddingFunction,
)
from chromadb.utils.embedding_functions.google_embedding_function import (
GoogleGenerativeAiEmbeddingFunction,
GoogleVertexEmbeddingFunction,
)
from chromadb.utils.embedding_functions.huggingface_embedding_function import (
HuggingFaceEmbeddingFunction,
)
from chromadb.utils.embedding_functions.instructor_embedding_function import (
InstructorEmbeddingFunction,
)
from chromadb.utils.embedding_functions.jina_embedding_function import (
JinaEmbeddingFunction,
)
from chromadb.utils.embedding_functions.ollama_embedding_function import (
OllamaEmbeddingFunction,
)
from chromadb.utils.embedding_functions.onnx_mini_lm_l6_v2 import ONNXMiniLM_L6_V2
from chromadb.utils.embedding_functions.open_clip_embedding_function import (
OpenCLIPEmbeddingFunction,
)
from chromadb.utils.embedding_functions.openai_embedding_function import (
OpenAIEmbeddingFunction,
)
from chromadb.utils.embedding_functions.roboflow_embedding_function import (
RoboflowEmbeddingFunction,
)
from chromadb.utils.embedding_functions.sentence_transformer_embedding_function import (
SentenceTransformerEmbeddingFunction,
)
from chromadb.utils.embedding_functions.text2vec_embedding_function import (
Text2VecEmbeddingFunction,
)
from crewai.rag.embeddings.providers.aws.types import BedrockProviderSpec
from crewai.rag.embeddings.providers.cohere.types import CohereProviderSpec
from crewai.rag.embeddings.providers.custom.types import CustomProviderSpec
from crewai.rag.embeddings.providers.google.types import (
GenerativeAiProviderSpec,
VertexAIProviderSpec,
)
from crewai.rag.embeddings.providers.huggingface.types import (
HuggingFaceProviderSpec,
)
from crewai.rag.embeddings.providers.ibm.embedding_callable import (
WatsonEmbeddingFunction,
)
from crewai.rag.embeddings.providers.ibm.types import WatsonProviderSpec
from crewai.rag.embeddings.providers.instructor.types import InstructorProviderSpec
from crewai.rag.embeddings.providers.jina.types import JinaProviderSpec
from crewai.rag.embeddings.providers.microsoft.types import AzureProviderSpec
from crewai.rag.embeddings.providers.ollama.types import OllamaProviderSpec
from crewai.rag.embeddings.providers.onnx.types import ONNXProviderSpec
from crewai.rag.embeddings.providers.openai.types import OpenAIProviderSpec
from crewai.rag.embeddings.providers.openclip.types import OpenCLIPProviderSpec
from crewai.rag.embeddings.providers.roboflow.types import RoboflowProviderSpec
from crewai.rag.embeddings.providers.sentence_transformer.types import (
SentenceTransformerProviderSpec,
)
from crewai.rag.embeddings.providers.text2vec.types import Text2VecProviderSpec
from crewai.rag.embeddings.providers.voyageai.embedding_callable import (
VoyageAIEmbeddingFunction,
)
from crewai.rag.embeddings.providers.voyageai.types import VoyageAIProviderSpec
T = TypeVar("T", bound=EmbeddingFunction)
from crewai.rag.embeddings.types import EmbeddingOptions
PROVIDER_PATHS = {
"azure": "crewai.rag.embeddings.providers.microsoft.azure.AzureProvider",
"amazon-bedrock": "crewai.rag.embeddings.providers.aws.bedrock.BedrockProvider",
"cohere": "crewai.rag.embeddings.providers.cohere.cohere_provider.CohereProvider",
"custom": "crewai.rag.embeddings.providers.custom.custom_provider.CustomProvider",
"google-generativeai": "crewai.rag.embeddings.providers.google.generative_ai.GenerativeAiProvider",
"google-vertex": "crewai.rag.embeddings.providers.google.vertex.VertexAIProvider",
"huggingface": "crewai.rag.embeddings.providers.huggingface.huggingface_provider.HuggingFaceProvider",
"instructor": "crewai.rag.embeddings.providers.instructor.instructor_provider.InstructorProvider",
"jina": "crewai.rag.embeddings.providers.jina.jina_provider.JinaProvider",
"ollama": "crewai.rag.embeddings.providers.ollama.ollama_provider.OllamaProvider",
"onnx": "crewai.rag.embeddings.providers.onnx.onnx_provider.ONNXProvider",
"openai": "crewai.rag.embeddings.providers.openai.openai_provider.OpenAIProvider",
"openclip": "crewai.rag.embeddings.providers.openclip.openclip_provider.OpenCLIPProvider",
"roboflow": "crewai.rag.embeddings.providers.roboflow.roboflow_provider.RoboflowProvider",
"sentence-transformer": "crewai.rag.embeddings.providers.sentence_transformer.sentence_transformer_provider.SentenceTransformerProvider",
"text2vec": "crewai.rag.embeddings.providers.text2vec.text2vec_provider.Text2VecProvider",
"voyageai": "crewai.rag.embeddings.providers.voyageai.voyageai_provider.VoyageAIProvider",
"watson": "crewai.rag.embeddings.providers.ibm.watson.WatsonProvider",
}
def build_embedder_from_provider(provider: BaseEmbeddingsProvider[T]) -> T:
"""Build an embedding function instance from a provider.
Args:
provider: The embedding provider configuration.
Returns:
An instance of the specified embedding function type.
"""
return provider.embedding_callable(
**provider.model_dump(exclude={"embedding_callable"})
)
@overload
def build_embedder_from_dict(spec: AzureProviderSpec) -> OpenAIEmbeddingFunction: ...
@overload
def build_embedder_from_dict(
spec: BedrockProviderSpec,
) -> AmazonBedrockEmbeddingFunction: ...
@overload
def build_embedder_from_dict(spec: CohereProviderSpec) -> CohereEmbeddingFunction: ...
@overload
def build_embedder_from_dict(spec: CustomProviderSpec) -> EmbeddingFunction: ...
@overload
def build_embedder_from_dict(
spec: GenerativeAiProviderSpec,
) -> GoogleGenerativeAiEmbeddingFunction: ...
@overload
def build_embedder_from_dict(
spec: HuggingFaceProviderSpec,
) -> HuggingFaceEmbeddingFunction: ...
@overload
def build_embedder_from_dict(spec: OllamaProviderSpec) -> OllamaEmbeddingFunction: ...
@overload
def build_embedder_from_dict(spec: OpenAIProviderSpec) -> OpenAIEmbeddingFunction: ...
@overload
def build_embedder_from_dict(
spec: VertexAIProviderSpec,
) -> GoogleVertexEmbeddingFunction: ...
@overload
def build_embedder_from_dict(
spec: VoyageAIProviderSpec,
) -> VoyageAIEmbeddingFunction: ...
@overload
def build_embedder_from_dict(spec: WatsonProviderSpec) -> WatsonEmbeddingFunction: ...
@overload
def build_embedder_from_dict(
spec: SentenceTransformerProviderSpec,
) -> SentenceTransformerEmbeddingFunction: ...
@overload
def build_embedder_from_dict(
spec: InstructorProviderSpec,
) -> InstructorEmbeddingFunction: ...
@overload
def build_embedder_from_dict(spec: JinaProviderSpec) -> JinaEmbeddingFunction: ...
@overload
def build_embedder_from_dict(
spec: RoboflowProviderSpec,
) -> RoboflowEmbeddingFunction: ...
@overload
def build_embedder_from_dict(
spec: OpenCLIPProviderSpec,
) -> OpenCLIPEmbeddingFunction: ...
@overload
def build_embedder_from_dict(
spec: Text2VecProviderSpec,
) -> Text2VecEmbeddingFunction: ...
@overload
def build_embedder_from_dict(spec: ONNXProviderSpec) -> ONNXMiniLM_L6_V2: ...
def build_embedder_from_dict(spec):
"""Build an embedding function instance from a dictionary specification.
Args:
spec: A dictionary with 'provider' and 'config' keys.
Example: {
"provider": "openai",
"config": {
"api_key": "sk-...",
"model_name": "text-embedding-3-small"
}
}
Returns:
An instance of the appropriate embedding function.
Raises:
ValueError: If the provider is not recognized.
"""
provider_name = spec["provider"]
if not provider_name:
raise ValueError("Missing 'provider' key in specification")
if provider_name not in PROVIDER_PATHS:
raise ValueError(
f"Unknown provider: {provider_name}. Available providers: {list(PROVIDER_PATHS.keys())}"
)
provider_path = PROVIDER_PATHS[provider_name]
def _create_watson_embedding_function(**config_dict) -> EmbeddingFunction:
"""Create Watson embedding function with proper error handling."""
try:
provider_class = import_and_validate_definition(provider_path)
except (ImportError, AttributeError, ValueError) as e:
raise ImportError(f"Failed to import provider {provider_name}: {e}") from e
import ibm_watsonx_ai.foundation_models as watson_models # type: ignore[import-not-found]
from ibm_watsonx_ai import Credentials # type: ignore[import-not-found]
from ibm_watsonx_ai.metanames import ( # type: ignore[import-not-found]
EmbedTextParamsMetaNames as EmbedParams,
)
except ImportError as e:
raise ImportError(
"IBM Watson dependencies are not installed. Please install them to use Watson embedding."
) from e
provider_config = spec.get("config", {})
class WatsonEmbeddingFunction(EmbeddingFunction):
def __init__(self, **kwargs):
self.config = kwargs
if provider_name == "custom" and "embedding_callable" not in provider_config:
raise ValueError("Custom provider requires 'embedding_callable' in config")
def __call__(self, input):
if isinstance(input, str):
input = [input]
provider = provider_class(**provider_config)
return build_embedder_from_provider(provider)
embed_params = {
EmbedParams.TRUNCATE_INPUT_TOKENS: 3,
EmbedParams.RETURN_OPTIONS: {"input_text": True},
}
embedding = watson_models.Embeddings(
model_id=self.config.get("model_name") or self.config.get("model"),
params=embed_params,
credentials=Credentials(
api_key=self.config.get("api_key"),
url=self.config.get("api_url") or self.config.get("url")
),
project_id=self.config.get("project_id"),
)
try:
return embedding.embed_documents(input)
except Exception as e:
raise RuntimeError(f"Error during Watson embedding: {e}") from e
return WatsonEmbeddingFunction(**config_dict)
@overload
def build_embedder(spec: BaseEmbeddingsProvider[T]) -> T: ...
@overload
def build_embedder(spec: AzureProviderSpec) -> OpenAIEmbeddingFunction: ...
@overload
def build_embedder(spec: BedrockProviderSpec) -> AmazonBedrockEmbeddingFunction: ...
@overload
def build_embedder(spec: CohereProviderSpec) -> CohereEmbeddingFunction: ...
@overload
def build_embedder(spec: CustomProviderSpec) -> EmbeddingFunction: ...
@overload
def build_embedder(
spec: GenerativeAiProviderSpec,
) -> GoogleGenerativeAiEmbeddingFunction: ...
@overload
def build_embedder(spec: HuggingFaceProviderSpec) -> HuggingFaceEmbeddingFunction: ...
@overload
def build_embedder(spec: OllamaProviderSpec) -> OllamaEmbeddingFunction: ...
@overload
def build_embedder(spec: OpenAIProviderSpec) -> OpenAIEmbeddingFunction: ...
@overload
def build_embedder(spec: VertexAIProviderSpec) -> GoogleVertexEmbeddingFunction: ...
@overload
def build_embedder(spec: VoyageAIProviderSpec) -> VoyageAIEmbeddingFunction: ...
@overload
def build_embedder(spec: WatsonProviderSpec) -> WatsonEmbeddingFunction: ...
@overload
def build_embedder(
spec: SentenceTransformerProviderSpec,
) -> SentenceTransformerEmbeddingFunction: ...
@overload
def build_embedder(spec: InstructorProviderSpec) -> InstructorEmbeddingFunction: ...
@overload
def build_embedder(spec: JinaProviderSpec) -> JinaEmbeddingFunction: ...
@overload
def build_embedder(spec: RoboflowProviderSpec) -> RoboflowEmbeddingFunction: ...
@overload
def build_embedder(spec: OpenCLIPProviderSpec) -> OpenCLIPEmbeddingFunction: ...
@overload
def build_embedder(spec: Text2VecProviderSpec) -> Text2VecEmbeddingFunction: ...
@overload
def build_embedder(spec: ONNXProviderSpec) -> ONNXMiniLM_L6_V2: ...
def build_embedder(spec):
"""Build an embedding function from either a provider spec or a provider instance.
def get_embedding_function(
config: EmbeddingOptions | dict | None = None,
) -> EmbeddingFunction:
"""Get embedding function - delegates to ChromaDB.
Args:
spec: Either a provider specification dictionary or a provider instance.
config: Optional configuration - either an EmbeddingOptions object or a dict with:
- provider: The embedding provider to use (default: "openai")
- Any other provider-specific parameters
Returns:
An embedding function instance. If a typed provider is passed, returns
the specific embedding function type.
EmbeddingFunction instance ready for use with ChromaDB
Supported providers:
- openai: OpenAI embeddings
- cohere: Cohere embeddings
- ollama: Ollama local embeddings
- huggingface: HuggingFace embeddings
- sentence-transformer: Local sentence transformers
- instructor: Instructor embeddings for specialized tasks
- google-palm: Google PaLM embeddings
- google-generativeai: Google Generative AI embeddings
- google-vertex: Google Vertex AI embeddings
- amazon-bedrock: AWS Bedrock embeddings
- jina: Jina AI embeddings
- roboflow: Roboflow embeddings for vision tasks
- openclip: OpenCLIP embeddings for multimodal tasks
- text2vec: Text2Vec embeddings
- onnx: ONNX MiniLM-L6-v2 (no API key needed, included with ChromaDB)
- watson: IBM Watson embeddings
Examples:
# From dictionary specification
embedder = build_embedder({
"provider": "openai",
"config": {"api_key": "sk-..."}
})
# Use default OpenAI embedding
>>> embedder = get_embedding_function()
# From provider instance
provider = OpenAIProvider(api_key="sk-...")
embedder = build_embedder(provider)
# Use Cohere with dict
>>> embedder = get_embedding_function({
... "provider": "cohere",
... "api_key": "your-key",
... "model_name": "embed-english-v3.0"
... })
# Use with EmbeddingOptions
>>> embedder = get_embedding_function(
... EmbeddingOptions(provider="sentence-transformer", model_name="all-MiniLM-L6-v2")
... )
# Use local sentence transformers (no API key needed)
>>> embedder = get_embedding_function({
... "provider": "sentence-transformer",
... "model_name": "all-MiniLM-L6-v2"
... })
# Use Ollama for local embeddings
>>> embedder = get_embedding_function({
... "provider": "ollama",
... "model_name": "nomic-embed-text"
... })
# Use ONNX (no API key needed)
>>> embedder = get_embedding_function({
... "provider": "onnx"
... })
# Use Watson embeddings
>>> embedder = get_embedding_function({
... "provider": "watson",
... "api_key": "your-watson-api-key",
... "api_url": "your-watson-url",
... "project_id": "your-project-id",
... "model_name": "ibm/slate-125m-english-rtrvr"
... })
"""
if isinstance(spec, BaseEmbeddingsProvider):
return build_embedder_from_provider(spec)
return build_embedder_from_dict(spec)
if config is None:
return OpenAIEmbeddingFunction(
api_key=os.getenv("OPENAI_API_KEY"), model_name="text-embedding-3-small"
)
# Handle EmbeddingOptions object
if isinstance(config, EmbeddingOptions):
config_dict = config.model_dump(exclude_none=True)
else:
config_dict = config.copy()
# Backward compatibility alias
get_embedding_function = build_embedder
provider = config_dict.pop("provider", "openai")
embedding_functions: dict[str, Callable[..., EmbeddingFunction]] = {
"openai": OpenAIEmbeddingFunction,
"cohere": CohereEmbeddingFunction,
"ollama": OllamaEmbeddingFunction,
"huggingface": HuggingFaceEmbeddingFunction,
"sentence-transformer": SentenceTransformerEmbeddingFunction,
"instructor": InstructorEmbeddingFunction,
"google-palm": GooglePalmEmbeddingFunction,
"google-generativeai": GoogleGenerativeAiEmbeddingFunction,
"google-vertex": GoogleVertexEmbeddingFunction,
"amazon-bedrock": AmazonBedrockEmbeddingFunction,
"jina": JinaEmbeddingFunction,
"roboflow": RoboflowEmbeddingFunction,
"openclip": OpenCLIPEmbeddingFunction,
"text2vec": Text2VecEmbeddingFunction,
"onnx": ONNXMiniLM_L6_V2,
"watson": _create_watson_embedding_function,
}
if provider not in embedding_functions:
raise ValueError(
f"Unsupported provider: {provider}. "
f"Available providers: {list(embedding_functions.keys())}"
)
return embedding_functions[provider](**config_dict)

View File

@@ -1 +0,0 @@
"""Embedding provider implementations."""

View File

@@ -1,13 +0,0 @@
"""AWS embedding providers."""
from crewai.rag.embeddings.providers.aws.bedrock import BedrockProvider
from crewai.rag.embeddings.providers.aws.types import (
BedrockProviderConfig,
BedrockProviderSpec,
)
__all__ = [
"BedrockProvider",
"BedrockProviderConfig",
"BedrockProviderSpec",
]

View File

@@ -1,53 +0,0 @@
"""Amazon Bedrock embeddings provider."""
from typing import Any
from chromadb.utils.embedding_functions.amazon_bedrock_embedding_function import (
AmazonBedrockEmbeddingFunction,
)
from pydantic import Field
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
def create_aws_session() -> Any:
"""Create an AWS session for Bedrock.
Returns:
boto3.Session: AWS session object
Raises:
ImportError: If boto3 is not installed
ValueError: If AWS session creation fails
"""
try:
import boto3 # type: ignore[import]
return boto3.Session()
except ImportError as e:
raise ImportError(
"boto3 is required for amazon-bedrock embeddings. "
"Install it with: uv add boto3"
) from e
except Exception as e:
raise ValueError(
f"Failed to create AWS session for amazon-bedrock. "
f"Ensure AWS credentials are configured. Error: {e}"
) from e
class BedrockProvider(BaseEmbeddingsProvider[AmazonBedrockEmbeddingFunction]):
"""Amazon Bedrock embeddings provider."""
embedding_callable: type[AmazonBedrockEmbeddingFunction] = Field(
default=AmazonBedrockEmbeddingFunction,
description="Amazon Bedrock embedding function class",
)
model_name: str = Field(
default="amazon.titan-embed-text-v1",
description="Model name to use for embeddings",
validation_alias="BEDROCK_MODEL_NAME",
)
session: Any = Field(
default_factory=create_aws_session, description="AWS session object"
)

View File

@@ -1,19 +0,0 @@
"""Type definitions for AWS embedding providers."""
from typing import Annotated, Any, Literal
from typing_extensions import Required, TypedDict
class BedrockProviderConfig(TypedDict, total=False):
"""Configuration for Bedrock provider."""
model_name: Annotated[str, "amazon.titan-embed-text-v1"]
session: Any
class BedrockProviderSpec(TypedDict, total=False):
"""Bedrock provider specification."""
provider: Required[Literal["amazon-bedrock"]]
config: BedrockProviderConfig

View File

@@ -1,13 +0,0 @@
"""Cohere embedding providers."""
from crewai.rag.embeddings.providers.cohere.cohere_provider import CohereProvider
from crewai.rag.embeddings.providers.cohere.types import (
CohereProviderConfig,
CohereProviderSpec,
)
__all__ = [
"CohereProvider",
"CohereProviderConfig",
"CohereProviderSpec",
]

View File

@@ -1,24 +0,0 @@
"""Cohere embeddings provider."""
from chromadb.utils.embedding_functions.cohere_embedding_function import (
CohereEmbeddingFunction,
)
from pydantic import Field
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
class CohereProvider(BaseEmbeddingsProvider[CohereEmbeddingFunction]):
"""Cohere embeddings provider."""
embedding_callable: type[CohereEmbeddingFunction] = Field(
default=CohereEmbeddingFunction, description="Cohere embedding function class"
)
api_key: str = Field(
description="Cohere API key", validation_alias="COHERE_API_KEY"
)
model_name: str = Field(
default="large",
description="Model name to use for embeddings",
validation_alias="COHERE_MODEL_NAME",
)

View File

@@ -1,19 +0,0 @@
"""Type definitions for Cohere embedding providers."""
from typing import Annotated, Literal
from typing_extensions import Required, TypedDict
class CohereProviderConfig(TypedDict, total=False):
"""Configuration for Cohere provider."""
api_key: str
model_name: Annotated[str, "large"]
class CohereProviderSpec(TypedDict, total=False):
"""Cohere provider specification."""
provider: Required[Literal["cohere"]]
config: CohereProviderConfig

View File

@@ -1,13 +0,0 @@
"""Custom embedding providers."""
from crewai.rag.embeddings.providers.custom.custom_provider import CustomProvider
from crewai.rag.embeddings.providers.custom.types import (
CustomProviderConfig,
CustomProviderSpec,
)
__all__ = [
"CustomProvider",
"CustomProviderConfig",
"CustomProviderSpec",
]

View File

@@ -1,19 +0,0 @@
"""Custom embeddings provider for user-defined embedding functions."""
from pydantic import Field
from pydantic_settings import SettingsConfigDict
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
from crewai.rag.embeddings.providers.custom.embedding_callable import (
CustomEmbeddingFunction,
)
class CustomProvider(BaseEmbeddingsProvider[CustomEmbeddingFunction]):
"""Custom embeddings provider for user-defined embedding functions."""
embedding_callable: type[CustomEmbeddingFunction] = Field(
..., description="Custom embedding function class"
)
model_config = SettingsConfigDict(extra="allow")

View File

@@ -1,22 +0,0 @@
"""Custom embedding function base implementation."""
from crewai.rag.core.base_embeddings_callable import EmbeddingFunction
from crewai.rag.core.types import Documents, Embeddings
class CustomEmbeddingFunction(EmbeddingFunction[Documents]):
"""Base class for custom embedding functions.
This provides a concrete implementation that can be subclassed for custom embeddings.
"""
def __call__(self, input: Documents) -> Embeddings:
"""Convert input documents to embeddings.
Args:
input: List of documents to embed.
Returns:
List of numpy arrays representing the embeddings.
"""
raise NotImplementedError("Subclasses must implement __call__ method")

View File

@@ -1,19 +0,0 @@
"""Type definitions for custom embedding providers."""
from typing import Literal
from chromadb.api.types import EmbeddingFunction
from typing_extensions import Required, TypedDict
class CustomProviderConfig(TypedDict, total=False):
"""Configuration for Custom provider."""
embedding_callable: type[EmbeddingFunction]
class CustomProviderSpec(TypedDict, total=False):
"""Custom provider specification."""
provider: Required[Literal["custom"]]
config: CustomProviderConfig

View File

@@ -1,23 +0,0 @@
"""Google embedding providers."""
from crewai.rag.embeddings.providers.google.generative_ai import (
GenerativeAiProvider,
)
from crewai.rag.embeddings.providers.google.types import (
GenerativeAiProviderConfig,
GenerativeAiProviderSpec,
VertexAIProviderConfig,
VertexAIProviderSpec,
)
from crewai.rag.embeddings.providers.google.vertex import (
VertexAIProvider,
)
__all__ = [
"GenerativeAiProvider",
"GenerativeAiProviderConfig",
"GenerativeAiProviderSpec",
"VertexAIProvider",
"VertexAIProviderConfig",
"VertexAIProviderSpec",
]

View File

@@ -1,30 +0,0 @@
"""Google Generative AI embeddings provider."""
from chromadb.utils.embedding_functions.google_embedding_function import (
GoogleGenerativeAiEmbeddingFunction,
)
from pydantic import Field
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
class GenerativeAiProvider(BaseEmbeddingsProvider[GoogleGenerativeAiEmbeddingFunction]):
"""Google Generative AI embeddings provider."""
embedding_callable: type[GoogleGenerativeAiEmbeddingFunction] = Field(
default=GoogleGenerativeAiEmbeddingFunction,
description="Google Generative AI embedding function class",
)
model_name: str = Field(
default="models/embedding-001",
description="Model name to use for embeddings",
validation_alias="GOOGLE_GENERATIVE_AI_MODEL_NAME",
)
api_key: str = Field(
description="Google API key", validation_alias="GOOGLE_API_KEY"
)
task_type: str = Field(
default="RETRIEVAL_DOCUMENT",
description="Task type for embeddings",
validation_alias="GOOGLE_GENERATIVE_AI_TASK_TYPE",
)

View File

@@ -1,36 +0,0 @@
"""Type definitions for Google embedding providers."""
from typing import Annotated, Literal
from typing_extensions import Required, TypedDict
class GenerativeAiProviderConfig(TypedDict, total=False):
"""Configuration for Google Generative AI provider."""
api_key: str
model_name: Annotated[str, "models/embedding-001"]
task_type: Annotated[str, "RETRIEVAL_DOCUMENT"]
class GenerativeAiProviderSpec(TypedDict):
"""Google Generative AI provider specification."""
provider: Literal["google-generativeai"]
config: GenerativeAiProviderConfig
class VertexAIProviderConfig(TypedDict, total=False):
"""Configuration for Vertex AI provider."""
api_key: str
model_name: Annotated[str, "textembedding-gecko"]
project_id: Annotated[str, "cloud-large-language-models"]
region: Annotated[str, "us-central1"]
class VertexAIProviderSpec(TypedDict, total=False):
"""Vertex AI provider specification."""
provider: Required[Literal["google-vertex"]]
config: VertexAIProviderConfig

View File

@@ -1,35 +0,0 @@
"""Google Vertex AI embeddings provider."""
from chromadb.utils.embedding_functions.google_embedding_function import (
GoogleVertexEmbeddingFunction,
)
from pydantic import Field
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
class VertexAIProvider(BaseEmbeddingsProvider[GoogleVertexEmbeddingFunction]):
"""Google Vertex AI embeddings provider."""
embedding_callable: type[GoogleVertexEmbeddingFunction] = Field(
default=GoogleVertexEmbeddingFunction,
description="Vertex AI embedding function class",
)
model_name: str = Field(
default="textembedding-gecko",
description="Model name to use for embeddings",
validation_alias="GOOGLE_VERTEX_MODEL_NAME",
)
api_key: str = Field(
description="Google API key", validation_alias="GOOGLE_CLOUD_API_KEY"
)
project_id: str = Field(
default="cloud-large-language-models",
description="GCP project ID",
validation_alias="GOOGLE_CLOUD_PROJECT",
)
region: str = Field(
default="us-central1",
description="GCP region",
validation_alias="GOOGLE_CLOUD_REGION",
)

View File

@@ -1,15 +0,0 @@
"""HuggingFace embedding providers."""
from crewai.rag.embeddings.providers.huggingface.huggingface_provider import (
HuggingFaceProvider,
)
from crewai.rag.embeddings.providers.huggingface.types import (
HuggingFaceProviderConfig,
HuggingFaceProviderSpec,
)
__all__ = [
"HuggingFaceProvider",
"HuggingFaceProviderConfig",
"HuggingFaceProviderSpec",
]

View File

@@ -1,20 +0,0 @@
"""HuggingFace embeddings provider."""
from chromadb.utils.embedding_functions.huggingface_embedding_function import (
HuggingFaceEmbeddingServer,
)
from pydantic import Field
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
class HuggingFaceProvider(BaseEmbeddingsProvider[HuggingFaceEmbeddingServer]):
"""HuggingFace embeddings provider."""
embedding_callable: type[HuggingFaceEmbeddingServer] = Field(
default=HuggingFaceEmbeddingServer,
description="HuggingFace embedding function class",
)
url: str = Field(
description="HuggingFace API URL", validation_alias="HUGGINGFACE_URL"
)

View File

@@ -1,18 +0,0 @@
"""Type definitions for HuggingFace embedding providers."""
from typing import Literal
from typing_extensions import Required, TypedDict
class HuggingFaceProviderConfig(TypedDict, total=False):
"""Configuration for HuggingFace provider."""
url: str
class HuggingFaceProviderSpec(TypedDict, total=False):
"""HuggingFace provider specification."""
provider: Required[Literal["huggingface"]]
config: HuggingFaceProviderConfig

View File

@@ -1,15 +0,0 @@
"""IBM embedding providers."""
from crewai.rag.embeddings.providers.ibm.types import (
WatsonProviderConfig,
WatsonProviderSpec,
)
from crewai.rag.embeddings.providers.ibm.watson import (
WatsonProvider,
)
__all__ = [
"WatsonProvider",
"WatsonProviderConfig",
"WatsonProviderSpec",
]

View File

@@ -1,159 +0,0 @@
"""IBM Watson embedding function implementation."""
from typing import cast
from typing_extensions import Unpack
from crewai.rag.core.base_embeddings_callable import EmbeddingFunction
from crewai.rag.core.types import Documents, Embeddings
from crewai.rag.embeddings.providers.ibm.types import WatsonProviderConfig
class WatsonEmbeddingFunction(EmbeddingFunction[Documents]):
"""Embedding function for IBM Watson models."""
def __init__(self, **kwargs: Unpack[WatsonProviderConfig]) -> None:
"""Initialize Watson embedding function.
Args:
**kwargs: Configuration parameters for Watson Embeddings and Credentials.
"""
self._config = kwargs
def __call__(self, input: Documents) -> Embeddings:
"""Generate embeddings for input documents.
Args:
input: List of documents to embed.
Returns:
List of embedding vectors.
"""
try:
import ibm_watsonx_ai.foundation_models as watson_models # type: ignore[import-not-found, import-untyped]
from ibm_watsonx_ai import (
Credentials, # type: ignore[import-not-found, import-untyped]
)
from ibm_watsonx_ai.metanames import ( # type: ignore[import-not-found, import-untyped]
EmbedTextParamsMetaNames as EmbedParams,
)
except ImportError as e:
raise ImportError(
"ibm-watsonx-ai is required for watson embeddings. "
"Install it with: uv add ibm-watsonx-ai"
) from e
if isinstance(input, str):
input = [input]
embeddings_config: dict = {
"model_id": self._config["model_id"],
}
if "params" in self._config and self._config["params"] is not None:
embeddings_config["params"] = self._config["params"]
if "project_id" in self._config and self._config["project_id"] is not None:
embeddings_config["project_id"] = self._config["project_id"]
if "space_id" in self._config and self._config["space_id"] is not None:
embeddings_config["space_id"] = self._config["space_id"]
if "api_client" in self._config and self._config["api_client"] is not None:
embeddings_config["api_client"] = self._config["api_client"]
if "verify" in self._config and self._config["verify"] is not None:
embeddings_config["verify"] = self._config["verify"]
if "persistent_connection" in self._config:
embeddings_config["persistent_connection"] = self._config[
"persistent_connection"
]
if "batch_size" in self._config:
embeddings_config["batch_size"] = self._config["batch_size"]
if "concurrency_limit" in self._config:
embeddings_config["concurrency_limit"] = self._config["concurrency_limit"]
if "max_retries" in self._config and self._config["max_retries"] is not None:
embeddings_config["max_retries"] = self._config["max_retries"]
if "delay_time" in self._config and self._config["delay_time"] is not None:
embeddings_config["delay_time"] = self._config["delay_time"]
if (
"retry_status_codes" in self._config
and self._config["retry_status_codes"] is not None
):
embeddings_config["retry_status_codes"] = self._config["retry_status_codes"]
if "credentials" in self._config and self._config["credentials"] is not None:
embeddings_config["credentials"] = self._config["credentials"]
else:
cred_config: dict = {}
if "url" in self._config and self._config["url"] is not None:
cred_config["url"] = self._config["url"]
if "api_key" in self._config and self._config["api_key"] is not None:
cred_config["api_key"] = self._config["api_key"]
if "name" in self._config and self._config["name"] is not None:
cred_config["name"] = self._config["name"]
if (
"iam_serviceid_crn" in self._config
and self._config["iam_serviceid_crn"] is not None
):
cred_config["iam_serviceid_crn"] = self._config["iam_serviceid_crn"]
if (
"trusted_profile_id" in self._config
and self._config["trusted_profile_id"] is not None
):
cred_config["trusted_profile_id"] = self._config["trusted_profile_id"]
if "token" in self._config and self._config["token"] is not None:
cred_config["token"] = self._config["token"]
if (
"projects_token" in self._config
and self._config["projects_token"] is not None
):
cred_config["projects_token"] = self._config["projects_token"]
if "username" in self._config and self._config["username"] is not None:
cred_config["username"] = self._config["username"]
if "password" in self._config and self._config["password"] is not None:
cred_config["password"] = self._config["password"]
if (
"instance_id" in self._config
and self._config["instance_id"] is not None
):
cred_config["instance_id"] = self._config["instance_id"]
if "version" in self._config and self._config["version"] is not None:
cred_config["version"] = self._config["version"]
if (
"bedrock_url" in self._config
and self._config["bedrock_url"] is not None
):
cred_config["bedrock_url"] = self._config["bedrock_url"]
if (
"platform_url" in self._config
and self._config["platform_url"] is not None
):
cred_config["platform_url"] = self._config["platform_url"]
if "proxies" in self._config and self._config["proxies"] is not None:
cred_config["proxies"] = self._config["proxies"]
if (
"verify" not in embeddings_config
and "verify" in self._config
and self._config["verify"] is not None
):
cred_config["verify"] = self._config["verify"]
if cred_config:
embeddings_config["credentials"] = Credentials(**cred_config)
if "params" not in embeddings_config:
embeddings_config["params"] = {
EmbedParams.TRUNCATE_INPUT_TOKENS: 3,
EmbedParams.RETURN_OPTIONS: {"input_text": True},
}
embedding = watson_models.Embeddings(**embeddings_config)
try:
embeddings = embedding.embed_documents(input)
return cast(Embeddings, embeddings)
except Exception as e:
print(f"Error during Watson embedding: {e}")
raise
@staticmethod
def name() -> str:
"""Return the name identifier for this embedding function."""
return "watson"

View File

@@ -1,44 +0,0 @@
"""Type definitions for IBM Watson embedding providers."""
from typing import Annotated, Any, Literal
from typing_extensions import Required, TypedDict
class WatsonProviderConfig(TypedDict, total=False):
"""Configuration for Watson provider."""
model_id: str
url: str
params: dict[str, str | dict[str, str]]
credentials: Any
project_id: str
space_id: str
api_client: Any
verify: bool | str
persistent_connection: Annotated[bool, True]
batch_size: Annotated[int, 100]
concurrency_limit: Annotated[int, 10]
max_retries: int
delay_time: float
retry_status_codes: list[int]
api_key: str
name: str
iam_serviceid_crn: str
trusted_profile_id: str
token: str
projects_token: str
username: str
password: str
instance_id: str
version: str
bedrock_url: str
platform_url: str
proxies: dict
class WatsonProviderSpec(TypedDict, total=False):
"""Watson provider specification."""
provider: Required[Literal["watson"]]
config: WatsonProviderConfig

View File

@@ -1,122 +0,0 @@
"""IBM Watson embeddings provider."""
from typing import Any
from pydantic import Field, model_validator
from typing_extensions import Self
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
from crewai.rag.embeddings.providers.ibm.embedding_callable import (
WatsonEmbeddingFunction,
)
class WatsonProvider(BaseEmbeddingsProvider[WatsonEmbeddingFunction]):
"""IBM Watson embeddings provider.
Note: Requires custom implementation as Watson uses a different interface.
"""
embedding_callable: type[WatsonEmbeddingFunction] = Field(
default=WatsonEmbeddingFunction, description="Watson embedding function class"
)
model_id: str = Field(
description="Watson model ID", validation_alias="WATSON_MODEL_ID"
)
params: dict[str, str | dict[str, str]] | None = Field(
default=None, description="Additional parameters"
)
credentials: Any | None = Field(default=None, description="Watson credentials")
project_id: str | None = Field(
default=None,
description="Watson project ID",
validation_alias="WATSON_PROJECT_ID",
)
space_id: str | None = Field(
default=None, description="Watson space ID", validation_alias="WATSON_SPACE_ID"
)
api_client: Any | None = Field(default=None, description="Watson API client")
verify: bool | str | None = Field(
default=None, description="SSL verification", validation_alias="WATSON_VERIFY"
)
persistent_connection: bool = Field(
default=True,
description="Use persistent connection",
validation_alias="WATSON_PERSISTENT_CONNECTION",
)
batch_size: int = Field(
default=100,
description="Batch size for processing",
validation_alias="WATSON_BATCH_SIZE",
)
concurrency_limit: int = Field(
default=10,
description="Concurrency limit",
validation_alias="WATSON_CONCURRENCY_LIMIT",
)
max_retries: int | None = Field(
default=None,
description="Maximum retries",
validation_alias="WATSON_MAX_RETRIES",
)
delay_time: float | None = Field(
default=None,
description="Delay time between retries",
validation_alias="WATSON_DELAY_TIME",
)
retry_status_codes: list[int] | None = Field(
default=None, description="HTTP status codes to retry on"
)
url: str = Field(description="Watson API URL", validation_alias="WATSON_URL")
api_key: str = Field(
description="Watson API key", validation_alias="WATSON_API_KEY"
)
name: str | None = Field(
default=None, description="Service name", validation_alias="WATSON_NAME"
)
iam_serviceid_crn: str | None = Field(
default=None,
description="IAM service ID CRN",
validation_alias="WATSON_IAM_SERVICEID_CRN",
)
trusted_profile_id: str | None = Field(
default=None,
description="Trusted profile ID",
validation_alias="WATSON_TRUSTED_PROFILE_ID",
)
token: str | None = Field(
default=None, description="Bearer token", validation_alias="WATSON_TOKEN"
)
projects_token: str | None = Field(
default=None,
description="Projects token",
validation_alias="WATSON_PROJECTS_TOKEN",
)
username: str | None = Field(
default=None, description="Username", validation_alias="WATSON_USERNAME"
)
password: str | None = Field(
default=None, description="Password", validation_alias="WATSON_PASSWORD"
)
instance_id: str | None = Field(
default=None,
description="Service instance ID",
validation_alias="WATSON_INSTANCE_ID",
)
version: str | None = Field(
default=None, description="API version", validation_alias="WATSON_VERSION"
)
bedrock_url: str | None = Field(
default=None, description="Bedrock URL", validation_alias="WATSON_BEDROCK_URL"
)
platform_url: str | None = Field(
default=None, description="Platform URL", validation_alias="WATSON_PLATFORM_URL"
)
proxies: dict | None = Field(default=None, description="Proxy configuration")
@model_validator(mode="after")
def validate_space_or_project(self) -> Self:
"""Validate that either space_id or project_id is provided."""
if not self.space_id and not self.project_id:
raise ValueError("One of 'space_id' or 'project_id' must be provided")
return self

View File

@@ -1,15 +0,0 @@
"""Instructor embedding providers."""
from crewai.rag.embeddings.providers.instructor.instructor_provider import (
InstructorProvider,
)
from crewai.rag.embeddings.providers.instructor.types import (
InstructorProviderConfig,
InstructorProviderSpec,
)
__all__ = [
"InstructorProvider",
"InstructorProviderConfig",
"InstructorProviderSpec",
]

View File

@@ -1,32 +0,0 @@
"""Instructor embeddings provider."""
from chromadb.utils.embedding_functions.instructor_embedding_function import (
InstructorEmbeddingFunction,
)
from pydantic import Field
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
class InstructorProvider(BaseEmbeddingsProvider[InstructorEmbeddingFunction]):
"""Instructor embeddings provider."""
embedding_callable: type[InstructorEmbeddingFunction] = Field(
default=InstructorEmbeddingFunction,
description="Instructor embedding function class",
)
model_name: str = Field(
default="hkunlp/instructor-base",
description="Model name to use",
validation_alias="INSTRUCTOR_MODEL_NAME",
)
device: str = Field(
default="cpu",
description="Device to run model on (cpu or cuda)",
validation_alias="INSTRUCTOR_DEVICE",
)
instruction: str | None = Field(
default=None,
description="Instruction for embeddings",
validation_alias="INSTRUCTOR_INSTRUCTION",
)

View File

@@ -1,20 +0,0 @@
"""Type definitions for Instructor embedding providers."""
from typing import Annotated, Literal
from typing_extensions import Required, TypedDict
class InstructorProviderConfig(TypedDict, total=False):
"""Configuration for Instructor provider."""
model_name: Annotated[str, "hkunlp/instructor-base"]
device: Annotated[str, "cpu"]
instruction: str
class InstructorProviderSpec(TypedDict, total=False):
"""Instructor provider specification."""
provider: Required[Literal["instructor"]]
config: InstructorProviderConfig

View File

@@ -1,13 +0,0 @@
"""Jina embedding providers."""
from crewai.rag.embeddings.providers.jina.jina_provider import JinaProvider
from crewai.rag.embeddings.providers.jina.types import (
JinaProviderConfig,
JinaProviderSpec,
)
__all__ = [
"JinaProvider",
"JinaProviderConfig",
"JinaProviderSpec",
]

View File

@@ -1,22 +0,0 @@
"""Jina embeddings provider."""
from chromadb.utils.embedding_functions.jina_embedding_function import (
JinaEmbeddingFunction,
)
from pydantic import Field
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
class JinaProvider(BaseEmbeddingsProvider[JinaEmbeddingFunction]):
"""Jina embeddings provider."""
embedding_callable: type[JinaEmbeddingFunction] = Field(
default=JinaEmbeddingFunction, description="Jina embedding function class"
)
api_key: str = Field(description="Jina API key", validation_alias="JINA_API_KEY")
model_name: str = Field(
default="jina-embeddings-v2-base-en",
description="Model name to use for embeddings",
validation_alias="JINA_MODEL_NAME",
)

View File

@@ -1,19 +0,0 @@
"""Type definitions for Jina embedding providers."""
from typing import Annotated, Literal
from typing_extensions import Required, TypedDict
class JinaProviderConfig(TypedDict, total=False):
"""Configuration for Jina provider."""
api_key: str
model_name: Annotated[str, "jina-embeddings-v2-base-en"]
class JinaProviderSpec(TypedDict, total=False):
"""Jina provider specification."""
provider: Required[Literal["jina"]]
config: JinaProviderConfig

View File

@@ -1,15 +0,0 @@
"""Microsoft embedding providers."""
from crewai.rag.embeddings.providers.microsoft.azure import (
AzureProvider,
)
from crewai.rag.embeddings.providers.microsoft.types import (
AzureProviderConfig,
AzureProviderSpec,
)
__all__ = [
"AzureProvider",
"AzureProviderConfig",
"AzureProviderSpec",
]

View File

@@ -1,58 +0,0 @@
"""Azure OpenAI embeddings provider."""
from typing import Any
from chromadb.utils.embedding_functions.openai_embedding_function import (
OpenAIEmbeddingFunction,
)
from pydantic import Field
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
class AzureProvider(BaseEmbeddingsProvider[OpenAIEmbeddingFunction]):
"""Azure OpenAI embeddings provider."""
embedding_callable: type[OpenAIEmbeddingFunction] = Field(
default=OpenAIEmbeddingFunction,
description="Azure OpenAI embedding function class",
)
api_key: str = Field(description="Azure API key", validation_alias="OPENAI_API_KEY")
api_base: str | None = Field(
default=None,
description="Azure endpoint URL",
validation_alias="OPENAI_API_BASE",
)
api_type: str = Field(
default="azure",
description="API type for Azure",
validation_alias="OPENAI_API_TYPE",
)
api_version: str | None = Field(
default=None,
description="Azure API version",
validation_alias="OPENAI_API_VERSION",
)
model_name: str = Field(
default="text-embedding-ada-002",
description="Model name to use for embeddings",
validation_alias="OPENAI_MODEL_NAME",
)
default_headers: dict[str, Any] | None = Field(
default=None, description="Default headers for API requests"
)
dimensions: int | None = Field(
default=None,
description="Embedding dimensions",
validation_alias="OPENAI_DIMENSIONS",
)
deployment_id: str | None = Field(
default=None,
description="Azure deployment ID",
validation_alias="OPENAI_DEPLOYMENT_ID",
)
organization_id: str | None = Field(
default=None,
description="Organization ID",
validation_alias="OPENAI_ORGANIZATION_ID",
)

View File

@@ -1,26 +0,0 @@
"""Type definitions for Microsoft Azure embedding providers."""
from typing import Annotated, Any, Literal
from typing_extensions import Required, TypedDict
class AzureProviderConfig(TypedDict, total=False):
"""Configuration for Azure provider."""
api_key: str
api_base: str
api_type: Annotated[str, "azure"]
api_version: str
model_name: Annotated[str, "text-embedding-ada-002"]
default_headers: dict[str, Any]
dimensions: int
deployment_id: str
organization_id: str
class AzureProviderSpec(TypedDict, total=False):
"""Azure provider specification."""
provider: Required[Literal["azure"]]
config: AzureProviderConfig

View File

@@ -1,15 +0,0 @@
"""Ollama embedding providers."""
from crewai.rag.embeddings.providers.ollama.ollama_provider import (
OllamaProvider,
)
from crewai.rag.embeddings.providers.ollama.types import (
OllamaProviderConfig,
OllamaProviderSpec,
)
__all__ = [
"OllamaProvider",
"OllamaProviderConfig",
"OllamaProviderSpec",
]

View File

@@ -1,25 +0,0 @@
"""Ollama embeddings provider."""
from chromadb.utils.embedding_functions.ollama_embedding_function import (
OllamaEmbeddingFunction,
)
from pydantic import Field
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
class OllamaProvider(BaseEmbeddingsProvider[OllamaEmbeddingFunction]):
"""Ollama embeddings provider."""
embedding_callable: type[OllamaEmbeddingFunction] = Field(
default=OllamaEmbeddingFunction, description="Ollama embedding function class"
)
url: str = Field(
default="http://localhost:11434/api/embeddings",
description="Ollama API endpoint URL",
validation_alias="OLLAMA_URL",
)
model_name: str = Field(
description="Model name to use for embeddings",
validation_alias="OLLAMA_MODEL_NAME",
)

View File

@@ -1,19 +0,0 @@
"""Type definitions for Ollama embedding providers."""
from typing import Annotated, Literal
from typing_extensions import Required, TypedDict
class OllamaProviderConfig(TypedDict, total=False):
"""Configuration for Ollama provider."""
url: Annotated[str, "http://localhost:11434/api/embeddings"]
model_name: str
class OllamaProviderSpec(TypedDict, total=False):
"""Ollama provider specification."""
provider: Required[Literal["ollama"]]
config: OllamaProviderConfig

View File

@@ -1,13 +0,0 @@
"""ONNX embedding providers."""
from crewai.rag.embeddings.providers.onnx.onnx_provider import ONNXProvider
from crewai.rag.embeddings.providers.onnx.types import (
ONNXProviderConfig,
ONNXProviderSpec,
)
__all__ = [
"ONNXProvider",
"ONNXProviderConfig",
"ONNXProviderSpec",
]

View File

@@ -1,19 +0,0 @@
"""ONNX embeddings provider."""
from chromadb.utils.embedding_functions.onnx_mini_lm_l6_v2 import ONNXMiniLM_L6_V2
from pydantic import Field
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
class ONNXProvider(BaseEmbeddingsProvider[ONNXMiniLM_L6_V2]):
"""ONNX embeddings provider."""
embedding_callable: type[ONNXMiniLM_L6_V2] = Field(
default=ONNXMiniLM_L6_V2, description="ONNX MiniLM embedding function class"
)
preferred_providers: list[str] | None = Field(
default=None,
description="Preferred ONNX execution providers",
validation_alias="ONNX_PREFERRED_PROVIDERS",
)

View File

@@ -1,18 +0,0 @@
"""Type definitions for ONNX embedding providers."""
from typing import Literal
from typing_extensions import Required, TypedDict
class ONNXProviderConfig(TypedDict, total=False):
"""Configuration for ONNX provider."""
preferred_providers: list[str]
class ONNXProviderSpec(TypedDict, total=False):
"""ONNX provider specification."""
provider: Required[Literal["onnx"]]
config: ONNXProviderConfig

View File

@@ -1,15 +0,0 @@
"""OpenAI embedding providers."""
from crewai.rag.embeddings.providers.openai.openai_provider import (
OpenAIProvider,
)
from crewai.rag.embeddings.providers.openai.types import (
OpenAIProviderConfig,
OpenAIProviderSpec,
)
__all__ = [
"OpenAIProvider",
"OpenAIProviderConfig",
"OpenAIProviderSpec",
]

View File

@@ -1,58 +0,0 @@
"""OpenAI embeddings provider."""
from typing import Any
from chromadb.utils.embedding_functions.openai_embedding_function import (
OpenAIEmbeddingFunction,
)
from pydantic import Field
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
class OpenAIProvider(BaseEmbeddingsProvider[OpenAIEmbeddingFunction]):
"""OpenAI embeddings provider."""
embedding_callable: type[OpenAIEmbeddingFunction] = Field(
default=OpenAIEmbeddingFunction,
description="OpenAI embedding function class",
)
api_key: str | None = Field(
default=None, description="OpenAI API key", validation_alias="OPENAI_API_KEY"
)
model_name: str = Field(
default="text-embedding-ada-002",
description="Model name to use for embeddings",
validation_alias="OPENAI_MODEL_NAME",
)
api_base: str | None = Field(
default=None,
description="Base URL for API requests",
validation_alias="OPENAI_API_BASE",
)
api_type: str | None = Field(
default=None,
description="API type (e.g., 'azure')",
validation_alias="OPENAI_API_TYPE",
)
api_version: str | None = Field(
default=None, description="API version", validation_alias="OPENAI_API_VERSION"
)
default_headers: dict[str, Any] | None = Field(
default=None, description="Default headers for API requests"
)
dimensions: int | None = Field(
default=None,
description="Embedding dimensions",
validation_alias="OPENAI_DIMENSIONS",
)
deployment_id: str | None = Field(
default=None,
description="Azure deployment ID",
validation_alias="OPENAI_DEPLOYMENT_ID",
)
organization_id: str | None = Field(
default=None,
description="OpenAI organization ID",
validation_alias="OPENAI_ORGANIZATION_ID",
)

View File

@@ -1,26 +0,0 @@
"""Type definitions for OpenAI embedding providers."""
from typing import Annotated, Any, Literal
from typing_extensions import Required, TypedDict
class OpenAIProviderConfig(TypedDict, total=False):
"""Configuration for OpenAI provider."""
api_key: str
model_name: Annotated[str, "text-embedding-ada-002"]
api_base: str
api_type: str
api_version: str
default_headers: dict[str, Any]
dimensions: int
deployment_id: str
organization_id: str
class OpenAIProviderSpec(TypedDict, total=False):
"""OpenAI provider specification."""
provider: Required[Literal["openai"]]
config: OpenAIProviderConfig

View File

@@ -1,15 +0,0 @@
"""OpenCLIP embedding providers."""
from crewai.rag.embeddings.providers.openclip.openclip_provider import (
OpenCLIPProvider,
)
from crewai.rag.embeddings.providers.openclip.types import (
OpenCLIPProviderConfig,
OpenCLIPProviderSpec,
)
__all__ = [
"OpenCLIPProvider",
"OpenCLIPProviderConfig",
"OpenCLIPProviderSpec",
]

View File

@@ -1,32 +0,0 @@
"""OpenCLIP embeddings provider."""
from chromadb.utils.embedding_functions.open_clip_embedding_function import (
OpenCLIPEmbeddingFunction,
)
from pydantic import Field
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
class OpenCLIPProvider(BaseEmbeddingsProvider[OpenCLIPEmbeddingFunction]):
"""OpenCLIP embeddings provider."""
embedding_callable: type[OpenCLIPEmbeddingFunction] = Field(
default=OpenCLIPEmbeddingFunction,
description="OpenCLIP embedding function class",
)
model_name: str = Field(
default="ViT-B-32",
description="Model name to use",
validation_alias="OPENCLIP_MODEL_NAME",
)
checkpoint: str = Field(
default="laion2b_s34b_b79k",
description="Model checkpoint",
validation_alias="OPENCLIP_CHECKPOINT",
)
device: str | None = Field(
default="cpu",
description="Device to run model on",
validation_alias="OPENCLIP_DEVICE",
)

View File

@@ -1,20 +0,0 @@
"""Type definitions for OpenCLIP embedding providers."""
from typing import Annotated, Literal
from typing_extensions import Required, TypedDict
class OpenCLIPProviderConfig(TypedDict, total=False):
"""Configuration for OpenCLIP provider."""
model_name: Annotated[str, "ViT-B-32"]
checkpoint: Annotated[str, "laion2b_s34b_b79k"]
device: Annotated[str, "cpu"]
class OpenCLIPProviderSpec(TypedDict):
"""OpenCLIP provider specification."""
provider: Required[Literal["openclip"]]
config: OpenCLIPProviderConfig

View File

@@ -1,15 +0,0 @@
"""Roboflow embedding providers."""
from crewai.rag.embeddings.providers.roboflow.roboflow_provider import (
RoboflowProvider,
)
from crewai.rag.embeddings.providers.roboflow.types import (
RoboflowProviderConfig,
RoboflowProviderSpec,
)
__all__ = [
"RoboflowProvider",
"RoboflowProviderConfig",
"RoboflowProviderSpec",
]

View File

@@ -1,25 +0,0 @@
"""Roboflow embeddings provider."""
from chromadb.utils.embedding_functions.roboflow_embedding_function import (
RoboflowEmbeddingFunction,
)
from pydantic import Field
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
class RoboflowProvider(BaseEmbeddingsProvider[RoboflowEmbeddingFunction]):
"""Roboflow embeddings provider."""
embedding_callable: type[RoboflowEmbeddingFunction] = Field(
default=RoboflowEmbeddingFunction,
description="Roboflow embedding function class",
)
api_key: str = Field(
default="", description="Roboflow API key", validation_alias="ROBOFLOW_API_KEY"
)
api_url: str = Field(
default="https://infer.roboflow.com",
description="Roboflow API URL",
validation_alias="ROBOFLOW_API_URL",
)

View File

@@ -1,19 +0,0 @@
"""Type definitions for Roboflow embedding providers."""
from typing import Annotated, Literal
from typing_extensions import Required, TypedDict
class RoboflowProviderConfig(TypedDict, total=False):
"""Configuration for Roboflow provider."""
api_key: Annotated[str, ""]
api_url: Annotated[str, "https://infer.roboflow.com"]
class RoboflowProviderSpec(TypedDict):
"""Roboflow provider specification."""
provider: Required[Literal["roboflow"]]
config: RoboflowProviderConfig

View File

@@ -1,15 +0,0 @@
"""SentenceTransformer embedding providers."""
from crewai.rag.embeddings.providers.sentence_transformer.sentence_transformer_provider import (
SentenceTransformerProvider,
)
from crewai.rag.embeddings.providers.sentence_transformer.types import (
SentenceTransformerProviderConfig,
SentenceTransformerProviderSpec,
)
__all__ = [
"SentenceTransformerProvider",
"SentenceTransformerProviderConfig",
"SentenceTransformerProviderSpec",
]

View File

@@ -1,34 +0,0 @@
"""SentenceTransformer embeddings provider."""
from chromadb.utils.embedding_functions.sentence_transformer_embedding_function import (
SentenceTransformerEmbeddingFunction,
)
from pydantic import Field
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
class SentenceTransformerProvider(
BaseEmbeddingsProvider[SentenceTransformerEmbeddingFunction]
):
"""SentenceTransformer embeddings provider."""
embedding_callable: type[SentenceTransformerEmbeddingFunction] = Field(
default=SentenceTransformerEmbeddingFunction,
description="SentenceTransformer embedding function class",
)
model_name: str = Field(
default="all-MiniLM-L6-v2",
description="Model name to use",
validation_alias="SENTENCE_TRANSFORMER_MODEL_NAME",
)
device: str = Field(
default="cpu",
description="Device to run model on (cpu or cuda)",
validation_alias="SENTENCE_TRANSFORMER_DEVICE",
)
normalize_embeddings: bool = Field(
default=False,
description="Whether to normalize embeddings",
validation_alias="SENTENCE_TRANSFORMER_NORMALIZE_EMBEDDINGS",
)

View File

@@ -1,20 +0,0 @@
"""Type definitions for SentenceTransformer embedding providers."""
from typing import Annotated, Literal
from typing_extensions import Required, TypedDict
class SentenceTransformerProviderConfig(TypedDict, total=False):
"""Configuration for SentenceTransformer provider."""
model_name: Annotated[str, "all-MiniLM-L6-v2"]
device: Annotated[str, "cpu"]
normalize_embeddings: Annotated[bool, False]
class SentenceTransformerProviderSpec(TypedDict):
"""SentenceTransformer provider specification."""
provider: Required[Literal["sentence-transformer"]]
config: SentenceTransformerProviderConfig

View File

@@ -1,15 +0,0 @@
"""Text2Vec embedding providers."""
from crewai.rag.embeddings.providers.text2vec.text2vec_provider import (
Text2VecProvider,
)
from crewai.rag.embeddings.providers.text2vec.types import (
Text2VecProviderConfig,
Text2VecProviderSpec,
)
__all__ = [
"Text2VecProvider",
"Text2VecProviderConfig",
"Text2VecProviderSpec",
]

View File

@@ -1,22 +0,0 @@
"""Text2Vec embeddings provider."""
from chromadb.utils.embedding_functions.text2vec_embedding_function import (
Text2VecEmbeddingFunction,
)
from pydantic import Field
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
class Text2VecProvider(BaseEmbeddingsProvider[Text2VecEmbeddingFunction]):
"""Text2Vec embeddings provider."""
embedding_callable: type[Text2VecEmbeddingFunction] = Field(
default=Text2VecEmbeddingFunction,
description="Text2Vec embedding function class",
)
model_name: str = Field(
default="shibing624/text2vec-base-chinese",
description="Model name to use",
validation_alias="TEXT2VEC_MODEL_NAME",
)

View File

@@ -1,18 +0,0 @@
"""Type definitions for Text2Vec embedding providers."""
from typing import Annotated, Literal
from typing_extensions import Required, TypedDict
class Text2VecProviderConfig(TypedDict, total=False):
"""Configuration for Text2Vec provider."""
model_name: Annotated[str, "shibing624/text2vec-base-chinese"]
class Text2VecProviderSpec(TypedDict):
"""Text2Vec provider specification."""
provider: Required[Literal["text2vec"]]
config: Text2VecProviderConfig

View File

@@ -1,15 +0,0 @@
"""VoyageAI embedding providers."""
from crewai.rag.embeddings.providers.voyageai.types import (
VoyageAIProviderConfig,
VoyageAIProviderSpec,
)
from crewai.rag.embeddings.providers.voyageai.voyageai_provider import (
VoyageAIProvider,
)
__all__ = [
"VoyageAIProvider",
"VoyageAIProviderConfig",
"VoyageAIProviderSpec",
]

View File

@@ -1,58 +0,0 @@
"""VoyageAI embedding function implementation."""
from typing import cast
from typing_extensions import Unpack
from crewai.rag.core.base_embeddings_callable import EmbeddingFunction
from crewai.rag.core.types import Documents, Embeddings
from crewai.rag.embeddings.providers.voyageai.types import VoyageAIProviderConfig
class VoyageAIEmbeddingFunction(EmbeddingFunction[Documents]):
"""Embedding function for VoyageAI models."""
def __init__(self, **kwargs: Unpack[VoyageAIProviderConfig]) -> None:
"""Initialize VoyageAI embedding function.
Args:
**kwargs: Configuration parameters for VoyageAI.
"""
try:
import voyageai # type: ignore[import-not-found]
except ImportError as e:
raise ImportError(
"voyageai is required for voyageai embeddings. "
"Install it with: uv add voyageai"
) from e
self._config = kwargs
self._client = voyageai.Client(
api_key=kwargs["api_key"],
max_retries=kwargs.get("max_retries", 0),
timeout=kwargs.get("timeout"),
)
def __call__(self, input: Documents) -> Embeddings:
"""Generate embeddings for input documents.
Args:
input: List of documents to embed.
Returns:
List of embedding vectors.
"""
if isinstance(input, str):
input = [input]
result = self._client.embed(
texts=input,
model=self._config.get("model", "voyage-2"),
input_type=self._config.get("input_type"),
truncation=self._config.get("truncation", True),
output_dtype=self._config.get("output_dtype"),
output_dimension=self._config.get("output_dimension"),
)
return cast(Embeddings, result.embeddings)

View File

@@ -1,25 +0,0 @@
"""Type definitions for VoyageAI embedding providers."""
from typing import Annotated, Literal
from typing_extensions import Required, TypedDict
class VoyageAIProviderConfig(TypedDict, total=False):
"""Configuration for VoyageAI provider."""
api_key: str
model: Annotated[str, "voyage-2"]
input_type: str
truncation: Annotated[bool, True]
output_dtype: str
output_dimension: int
max_retries: Annotated[int, 0]
timeout: float
class VoyageAIProviderSpec(TypedDict):
"""VoyageAI provider specification."""
provider: Required[Literal["voyageai"]]
config: VoyageAIProviderConfig

View File

@@ -1,55 +0,0 @@
"""Voyage AI embeddings provider."""
from pydantic import Field
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
from crewai.rag.embeddings.providers.voyageai.embedding_callable import (
VoyageAIEmbeddingFunction,
)
class VoyageAIProvider(BaseEmbeddingsProvider[VoyageAIEmbeddingFunction]):
"""Voyage AI embeddings provider."""
embedding_callable: type[VoyageAIEmbeddingFunction] = Field(
default=VoyageAIEmbeddingFunction,
description="Voyage AI embedding function class",
)
model: str = Field(
default="voyage-2",
description="Model to use for embeddings",
validation_alias="VOYAGEAI_MODEL",
)
api_key: str = Field(
description="Voyage AI API key", validation_alias="VOYAGEAI_API_KEY"
)
input_type: str | None = Field(
default=None,
description="Input type for embeddings",
validation_alias="VOYAGEAI_INPUT_TYPE",
)
truncation: bool = Field(
default=True,
description="Whether to truncate inputs",
validation_alias="VOYAGEAI_TRUNCATION",
)
output_dtype: str | None = Field(
default=None,
description="Output data type",
validation_alias="VOYAGEAI_OUTPUT_DTYPE",
)
output_dimension: int | None = Field(
default=None,
description="Output dimension",
validation_alias="VOYAGEAI_OUTPUT_DIMENSION",
)
max_retries: int = Field(
default=0,
description="Maximum retries for API calls",
validation_alias="VOYAGEAI_MAX_RETRIES",
)
timeout: float | None = Field(
default=None,
description="Timeout for API calls",
validation_alias="VOYAGEAI_TIMEOUT",
)

View File

@@ -1,73 +1,63 @@
"""Type definitions for the embeddings module."""
from typing import Literal, TypeAlias
from typing import Literal
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
from crewai.rag.embeddings.providers.aws.types import BedrockProviderSpec
from crewai.rag.embeddings.providers.cohere.types import CohereProviderSpec
from crewai.rag.embeddings.providers.custom.types import CustomProviderSpec
from crewai.rag.embeddings.providers.google.types import (
GenerativeAiProviderSpec,
VertexAIProviderSpec,
)
from crewai.rag.embeddings.providers.huggingface.types import HuggingFaceProviderSpec
from crewai.rag.embeddings.providers.ibm.types import WatsonProviderSpec
from crewai.rag.embeddings.providers.instructor.types import InstructorProviderSpec
from crewai.rag.embeddings.providers.jina.types import JinaProviderSpec
from crewai.rag.embeddings.providers.microsoft.types import AzureProviderSpec
from crewai.rag.embeddings.providers.ollama.types import OllamaProviderSpec
from crewai.rag.embeddings.providers.onnx.types import ONNXProviderSpec
from crewai.rag.embeddings.providers.openai.types import OpenAIProviderSpec
from crewai.rag.embeddings.providers.openclip.types import OpenCLIPProviderSpec
from crewai.rag.embeddings.providers.roboflow.types import RoboflowProviderSpec
from crewai.rag.embeddings.providers.sentence_transformer.types import (
SentenceTransformerProviderSpec,
)
from crewai.rag.embeddings.providers.text2vec.types import Text2VecProviderSpec
from crewai.rag.embeddings.providers.voyageai.types import VoyageAIProviderSpec
from pydantic import BaseModel, Field, SecretStr
ProviderSpec = (
AzureProviderSpec
| BedrockProviderSpec
| CohereProviderSpec
| CustomProviderSpec
| GenerativeAiProviderSpec
| HuggingFaceProviderSpec
| InstructorProviderSpec
| JinaProviderSpec
| OllamaProviderSpec
| ONNXProviderSpec
| OpenAIProviderSpec
| OpenCLIPProviderSpec
| RoboflowProviderSpec
| SentenceTransformerProviderSpec
| Text2VecProviderSpec
| VertexAIProviderSpec
| VoyageAIProviderSpec
| WatsonProviderSpec
)
from crewai.rag.types import EmbeddingFunction
AllowedEmbeddingProviders = Literal[
"azure",
"amazon-bedrock",
EmbeddingProvider = Literal[
"openai",
"cohere",
"custom",
"ollama",
"huggingface",
"sentence-transformer",
"instructor",
"google-palm",
"google-generativeai",
"google-vertex",
"huggingface",
"instructor",
"amazon-bedrock",
"jina",
"ollama",
"onnx",
"openai",
"openclip",
"roboflow",
"sentence-transformer",
"openclip",
"text2vec",
"voyageai",
"onnx",
"watson",
]
"""Supported embedding providers.
EmbedderConfig: TypeAlias = (
ProviderSpec | BaseEmbeddingsProvider | type[BaseEmbeddingsProvider]
)
These correspond to the embedding functions available in ChromaDB's
embedding_functions module. Each provider has specific requirements
and configuration options.
"""
class EmbeddingOptions(BaseModel):
"""Configuration options for embedding providers.
Generic attributes that can be passed to get_embedding_function
to configure various embedding providers.
"""
provider: EmbeddingProvider = Field(
..., description="Embedding provider name (e.g., 'openai', 'cohere', 'onnx')"
)
model_name: str | None = Field(
default=None, description="Model name for the embedding provider"
)
api_key: SecretStr | None = Field(
default=None, description="API key for the embedding provider"
)
class EmbeddingConfig(BaseModel):
"""Configuration wrapper for embedding functions.
Accepts either a pre-configured EmbeddingFunction or EmbeddingOptions
to create one. This provides flexibility in how embeddings are configured.
Attributes:
function: Either a callable EmbeddingFunction or EmbeddingOptions to create one
"""
function: EmbeddingFunction | EmbeddingOptions

View File

@@ -48,7 +48,6 @@ class QdrantClient(BaseClient):
embedding_function: EmbeddingFunction | AsyncEmbeddingFunction,
default_limit: int = 5,
default_score_threshold: float = 0.6,
default_batch_size: int = 100,
) -> None:
"""Initialize QdrantClient with client and embedding function.
@@ -57,13 +56,11 @@ class QdrantClient(BaseClient):
embedding_function: Embedding function for text to vector conversion.
default_limit: Default number of results to return in searches.
default_score_threshold: Default minimum score for search results.
default_batch_size: Default batch size for adding documents.
"""
self.client = client
self.embedding_function = embedding_function
self.default_limit = default_limit
self.default_score_threshold = default_score_threshold
self.default_batch_size = default_batch_size
def create_collection(self, **kwargs: Unpack[QdrantCollectionCreateParams]) -> None:
"""Create a new collection in Qdrant.
@@ -237,7 +234,6 @@ class QdrantClient(BaseClient):
Keyword Args:
collection_name: The name of the collection to add documents to.
documents: List of BaseRecord dicts containing document data.
batch_size: Optional batch size for processing documents (default: 100)
Raises:
ValueError: If collection doesn't exist or documents list is empty.
@@ -253,7 +249,6 @@ class QdrantClient(BaseClient):
collection_name = kwargs["collection_name"]
documents = kwargs["documents"]
batch_size = kwargs.get("batch_size", self.default_batch_size)
if not documents:
raise ValueError("Documents list cannot be empty")
@@ -261,20 +256,19 @@ class QdrantClient(BaseClient):
if not self.client.collection_exists(collection_name):
raise ValueError(f"Collection '{collection_name}' does not exist")
for i in range(0, len(documents), batch_size):
batch_docs = documents[i : min(i + batch_size, len(documents))]
points = []
for doc in batch_docs:
if _is_async_embedding_function(self.embedding_function):
raise TypeError(
"Async embedding function cannot be used with sync add_documents. "
"Use aadd_documents instead."
)
sync_fn = cast(EmbeddingFunction, self.embedding_function)
embedding = sync_fn(doc["content"])
point = _create_point_from_document(doc, embedding)
points.append(point)
self.client.upsert(collection_name=collection_name, points=points)
points = []
for doc in documents:
if _is_async_embedding_function(self.embedding_function):
raise TypeError(
"Async embedding function cannot be used with sync add_documents. "
"Use aadd_documents instead."
)
sync_fn = cast(EmbeddingFunction, self.embedding_function)
embedding = sync_fn(doc["content"])
point = _create_point_from_document(doc, embedding)
points.append(point)
self.client.upsert(collection_name=collection_name, points=points)
async def aadd_documents(self, **kwargs: Unpack[BaseCollectionAddParams]) -> None:
"""Add documents with their embeddings to a collection asynchronously.
@@ -282,7 +276,6 @@ class QdrantClient(BaseClient):
Keyword Args:
collection_name: The name of the collection to add documents to.
documents: List of BaseRecord dicts containing document data.
batch_size: Optional batch size for processing documents (default: 100)
Raises:
ValueError: If collection doesn't exist or documents list is empty.
@@ -298,7 +291,6 @@ class QdrantClient(BaseClient):
collection_name = kwargs["collection_name"]
documents = kwargs["documents"]
batch_size = kwargs.get("batch_size", self.default_batch_size)
if not documents:
raise ValueError("Documents list cannot be empty")
@@ -306,19 +298,18 @@ class QdrantClient(BaseClient):
if not await self.client.collection_exists(collection_name):
raise ValueError(f"Collection '{collection_name}' does not exist")
for i in range(0, len(documents), batch_size):
batch_docs = documents[i : min(i + batch_size, len(documents))]
points = []
for doc in batch_docs:
if _is_async_embedding_function(self.embedding_function):
async_fn = cast(AsyncEmbeddingFunction, self.embedding_function)
embedding = await async_fn(doc["content"])
else:
sync_fn = cast(EmbeddingFunction, self.embedding_function)
embedding = sync_fn(doc["content"])
point = _create_point_from_document(doc, embedding)
points.append(point)
await self.client.upsert(collection_name=collection_name, points=points)
points = []
for doc in documents:
if _is_async_embedding_function(self.embedding_function):
async_fn = cast(AsyncEmbeddingFunction, self.embedding_function)
embedding = await async_fn(doc["content"])
else:
sync_fn = cast(EmbeddingFunction, self.embedding_function)
embedding = sync_fn(doc["content"])
point = _create_point_from_document(doc, embedding)
points.append(point)
await self.client.upsert(collection_name=collection_name, points=points)
def search(
self, **kwargs: Unpack[BaseCollectionSearchParams]

View File

@@ -22,5 +22,4 @@ def create_client(config: QdrantConfig) -> QdrantClient:
embedding_function=config.embedding_function,
default_limit=config.limit,
default_score_threshold=config.score_threshold,
default_batch_size=config.batch_size,
)

View File

@@ -1,9 +1,6 @@
from abc import ABC, abstractmethod
from typing import Any
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
from crewai.rag.embeddings.types import ProviderSpec
class BaseRAGStorage(ABC):
"""
@@ -16,7 +13,7 @@ class BaseRAGStorage(ABC):
self,
type: str,
allow_reset: bool = True,
embedder_config: ProviderSpec | BaseEmbeddingsProvider | None = None,
embedder_config: dict[str, Any] | None = None,
crew: Any = None,
):
self.type = type

View File

@@ -24,7 +24,8 @@ class BaseRecord(TypedDict, total=False):
)
Embeddings: TypeAlias = list[list[float]]
DenseVector: TypeAlias = list[float]
IntVector: TypeAlias = list[int]
EmbeddingFunction: TypeAlias = Callable[..., Any]

View File

@@ -2,94 +2,29 @@
import importlib
from types import ModuleType
from typing import Annotated, Any, TypeAlias
from pydantic import AfterValidator, TypeAdapter
from typing_extensions import deprecated
@deprecated(
"Not needed when using `crewai.utilities.import_utils.import_and_validate_definition`"
)
class OptionalDependencyError(ImportError):
"""Exception raised when an optional dependency is not installed."""
@deprecated(
"Use `crewai.utilities.import_utils.import_and_validate_definition` instead."
)
def require(name: str, *, purpose: str, attr: str | None = None) -> ModuleType | Any:
"""Import a module, optionally returning a specific attribute.
def require(name: str, *, purpose: str) -> ModuleType:
"""Import a module, raising a helpful error if it's not installed.
Args:
name: The module name to import.
purpose: Description of what requires this dependency.
attr: Optional attribute name to get from the module.
Returns:
The imported module or the specified attribute.
The imported module.
Raises:
OptionalDependencyError: If the module is not installed.
AttributeError: If the specified attribute doesn't exist.
"""
try:
module = importlib.import_module(name)
if attr is not None:
return getattr(module, attr)
return module
return importlib.import_module(name)
except ImportError as exc:
package_name = name.split(".")[0]
raise OptionalDependencyError(
f"{purpose} requires the optional dependency '{name}'.\n"
f"Install it with: uv add {package_name}"
f"Install it with: uv add {name}"
) from exc
except AttributeError as exc:
raise AttributeError(f"Module '{name}' has no attribute '{attr}'") from exc
def validate_import_path(v: str) -> Any:
"""Import and return the class/function from the import path.
Args:
v: Import path string in the format 'module.path.ClassName'.
Returns:
The imported class or function.
Raises:
ValueError: If the import path is malformed or the module cannot be imported.
"""
module_path, _, attr = v.rpartition(".")
if not module_path or not attr:
raise ValueError(f"import_path '{v}' must be of the form 'module.ClassName'")
try:
mod = importlib.import_module(module_path)
except ImportError as exc:
parts = module_path.split(".")
if not parts:
raise ValueError(f"Malformed import path: '{v}'") from exc
package = parts[0]
raise ValueError(
f"Package '{package}' could not be imported. Install it with: uv add {package}"
) from exc
if not hasattr(mod, attr):
raise ValueError(f"Attribute '{attr}' not found in module '{module_path}'")
return getattr(mod, attr)
ImportedDefinition: TypeAlias = Annotated[Any, AfterValidator(validate_import_path)]
adapter = TypeAdapter(ImportedDefinition)
def import_and_validate_definition(v: str) -> Any:
"""Pydantic-compatible function to import a class/function from a string path.
Args:
v: Import path string in the format 'module.path.ClassName'.
Returns:
The imported class or function
"""
return adapter.validate_python(v)

View File

@@ -1,298 +0,0 @@
interactions:
- request:
body: '{"messages": [{"role": "system", "content": "You are Test Agent. Test backstory\nYour
personal goal is: Test goal\nTo give my best complete final answer to the task
respond using the exact following format:\n\nThought: I now can give a great
answer\nFinal Answer: Your final answer must be the great and the most complete
as possible, it must be outcome described.\n\nI MUST use these formats, my job
depends on it!"}, {"role": "user", "content": "\nCurrent Task: Say hello to
the world\n\nThis is the expected criteria for your final answer: hello world\nyou
MUST return the actual complete content as the final answer, not a summary.\n\nBegin!
This is VERY important to you, use the tools available and give your best Final
Answer, your job depends on it!\n\nThought:"}], "model": "gpt-4o-mini", "stop":
["\nObservation:"]}'
headers:
accept:
- application/json
accept-encoding:
- gzip, deflate
connection:
- keep-alive
content-length:
- '825'
content-type:
- application/json
host:
- api.openai.com
user-agent:
- OpenAI/Python 1.93.0
x-stainless-arch:
- arm64
x-stainless-async:
- 'false'
x-stainless-lang:
- python
x-stainless-os:
- MacOS
x-stainless-package-version:
- 1.93.0
x-stainless-raw-response:
- 'true'
x-stainless-read-timeout:
- '600.0'
x-stainless-retry-count:
- '0'
x-stainless-runtime:
- CPython
x-stainless-runtime-version:
- 3.12.9
method: POST
uri: https://api.openai.com/v1/chat/completions
response:
body:
string: !!binary |
H4sIAAAAAAAAAwAAAP//jFJNj9MwEL3nV4x8blC/0nRzA8QKuCBxQFrBKnLtSeLF8Vi2s11Y9b8j
O90m5UPiEinz5j2/NzPPGQBTklXARMeD6K3O33588/nm/d27T5sv6wdp1OGu/9l1T9tbWzQlW0QG
HR5QhBfWK0G91RgUmREWDnnAqLoqi/1uv19vVgnoSaKOtNaGfEt5r4zK18v1Nl+W+Wp/ZnekBHpW
wdcMAOA5faNPI/GJVbBcvFR69J63yKpLEwBzpGOFce+VD9wEtphAQSagSdY/gKEjCG6gVY8IHNpo
G7jxR3QA38ytMlzD6/RfQYdaExzJaTkXdNgMnsdQZtB6BnBjKPA4lBTl/oycLuY1tdbRwf9GZY0y
yne1Q+7JRKM+kGUJPWUA92lIw1VuZh31NtSBvmN6blWUox6bdjNDN2cwUOB6Vi/Po73WqyUGrrSf
jZkJLjqUE3XaCR+kohmQzVL/6eZv2mNyZdr/kZ8AIdAGlLV1KJW4Tjy1OYyn+6+2y5STYebRPSqB
dVDo4iYkNnzQ40Ex/8MH7OtGmRaddWq8qsbWxW7Jmx0WxQ3LTtkvAAAA//8DAIkIBqtjAwAA
headers:
CF-RAY:
- 983f8c061b6ec487-SJC
Connection:
- keep-alive
Content-Encoding:
- gzip
Content-Type:
- application/json
Date:
- Wed, 24 Sep 2025 04:30:32 GMT
Server:
- cloudflare
Set-Cookie:
- __cf_bm=JDjpnzx5y8PJaJDQcCeX6MeBt8BOGuL79pd.ca5mqvE-1758688232-1.0.1.1-5VN5hj5LzEZFfkotBaZ_dbUITo_YB7RLsFOlQc.0MdSZOsz7WhNkH.s7H700L12Yi8nHGW44BgIwCF3uWx1w4PRBqrb1IVH3FkeV.QwCTaA;
path=/; expires=Wed, 24-Sep-25 05:00:32 GMT; domain=.api.openai.com; HttpOnly;
Secure; SameSite=None
- _cfuvid=b5n8BZZDRtHA4TrxQ1RDeEdtQBzhstjP6u21LYM8L94-1758688232142-0.0.1.1-604800000;
path=/; domain=.api.openai.com; HttpOnly; Secure; SameSite=None
Strict-Transport-Security:
- max-age=31536000; includeSubDomains; preload
Transfer-Encoding:
- chunked
X-Content-Type-Options:
- nosniff
access-control-expose-headers:
- X-Request-ID
alt-svc:
- h3=":443"; ma=86400
cf-cache-status:
- DYNAMIC
openai-organization:
- crewai-iuxna1
openai-processing-ms:
- '535'
openai-project:
- proj_xitITlrFeen7zjNSzML82h9x
openai-version:
- '2020-10-01'
x-envoy-upstream-service-time:
- '562'
x-openai-proxy-wasm:
- v0.1
x-ratelimit-limit-project-tokens:
- '150000000'
x-ratelimit-limit-requests:
- '30000'
x-ratelimit-limit-tokens:
- '150000000'
x-ratelimit-remaining-project-tokens:
- '149999827'
x-ratelimit-remaining-requests:
- '29999'
x-ratelimit-remaining-tokens:
- '149999830'
x-ratelimit-reset-project-tokens:
- 0s
x-ratelimit-reset-requests:
- 2ms
x-ratelimit-reset-tokens:
- 0s
x-request-id:
- req_af61ab9d53bf400baf30c5bc5a7e2102
status:
code: 200
message: OK
- request:
body: null
headers:
Connection:
- close
Host:
- api.scarf.sh
User-Agent:
- CrewAI-Python/0.193.2
method: GET
uri: https://api.scarf.sh/v2/packages/CrewAI/crewai/docs/00f2dad1-8334-4a39-934e-003b2e1146db
response:
body:
string: ''
headers:
Connection:
- close
Date:
- Wed, 24 Sep 2025 04:47:59 GMT
Strict-Transport-Security:
- max-age=15724800; includeSubDomains
Transfer-Encoding:
- chunked
x-scarf-request-id:
- 4158376f-cb1c-46fe-a14c-dee366b955e2
status:
code: 401
message: Unauthorized
- request:
body: '{"trace_id": "06e1250e-6d88-4c64-abe5-deabde573ae1", "execution_type":
"crew", "user_identifier": null, "execution_context": {"crew_fingerprint": null,
"crew_name": "crew", "flow_name": null, "crewai_version": "0.193.2", "privacy_level":
"standard"}, "execution_metadata": {"expected_duration_estimate": 300, "agent_count":
0, "task_count": 0, "flow_method_count": 0, "execution_started_at": "2025-09-24T04:50:23.219835+00:00"}}'
headers:
Accept:
- '*/*'
Accept-Encoding:
- gzip, deflate
Connection:
- keep-alive
Content-Length:
- '428'
Content-Type:
- application/json
User-Agent:
- CrewAI-CLI/0.193.2
X-Crewai-Organization-Id:
- d3a3d10c-35db-423f-a7a4-c026030ba64d
X-Crewai-Version:
- 0.193.2
method: POST
uri: http://localhost:3000/crewai_plus/api/v1/tracing/batches
response:
body:
string: '{"error":"bad_credentials","message":"Bad credentials"}'
headers:
Content-Length:
- '55'
cache-control:
- no-cache
content-security-policy:
- 'default-src ''self'' *.crewai.com crewai.com; script-src ''self'' ''unsafe-inline''
*.crewai.com crewai.com https://cdn.jsdelivr.net/npm/apexcharts https://www.gstatic.com
https://run.pstmn.io https://share.descript.com/; style-src ''self'' ''unsafe-inline''
*.crewai.com crewai.com https://cdn.jsdelivr.net/npm/apexcharts; img-src ''self''
data: *.crewai.com crewai.com https://zeus.tools.crewai.com https://dashboard.tools.crewai.com
https://cdn.jsdelivr.net; font-src ''self'' data: *.crewai.com crewai.com;
connect-src ''self'' *.crewai.com crewai.com https://zeus.tools.crewai.com
https://connect.useparagon.com/ https://zeus.useparagon.com/* https://*.useparagon.com/*
https://run.pstmn.io https://connect.tools.crewai.com/ ws://localhost:3036
wss://localhost:3036; frame-src ''self'' *.crewai.com crewai.com https://connect.useparagon.com/
https://zeus.tools.crewai.com https://zeus.useparagon.com/* https://connect.tools.crewai.com/
https://www.youtube.com https://share.descript.com'
content-type:
- application/json; charset=utf-8
permissions-policy:
- camera=(), microphone=(self), geolocation=()
referrer-policy:
- strict-origin-when-cross-origin
server-timing:
- cache_read.active_support;dur=0.37, sql.active_record;dur=30.81, cache_generate.active_support;dur=29.14,
cache_write.active_support;dur=0.14, cache_read_multi.active_support;dur=0.19,
start_processing.action_controller;dur=0.00, process_action.action_controller;dur=2.74
vary:
- Accept
x-content-type-options:
- nosniff
x-frame-options:
- SAMEORIGIN
x-permitted-cross-domain-policies:
- none
x-request-id:
- 2420790e-9669-4235-851c-468185b6ef40
x-runtime:
- '0.102516'
x-xss-protection:
- 1; mode=block
status:
code: 401
message: Unauthorized
- request:
body: '{"status": "failed", "failure_reason": "Error sending events to backend"}'
headers:
Accept:
- '*/*'
Accept-Encoding:
- gzip, deflate
Connection:
- keep-alive
Content-Length:
- '73'
Content-Type:
- application/json
User-Agent:
- CrewAI-CLI/0.193.2
X-Crewai-Organization-Id:
- d3a3d10c-35db-423f-a7a4-c026030ba64d
X-Crewai-Version:
- 0.193.2
method: PATCH
uri: http://localhost:3000/crewai_plus/api/v1/tracing/batches/None
response:
body:
string: '{"error":"bad_credentials","message":"Bad credentials"}'
headers:
Content-Length:
- '55'
cache-control:
- no-cache
content-security-policy:
- 'default-src ''self'' *.crewai.com crewai.com; script-src ''self'' ''unsafe-inline''
*.crewai.com crewai.com https://cdn.jsdelivr.net/npm/apexcharts https://www.gstatic.com
https://run.pstmn.io https://share.descript.com/; style-src ''self'' ''unsafe-inline''
*.crewai.com crewai.com https://cdn.jsdelivr.net/npm/apexcharts; img-src ''self''
data: *.crewai.com crewai.com https://zeus.tools.crewai.com https://dashboard.tools.crewai.com
https://cdn.jsdelivr.net; font-src ''self'' data: *.crewai.com crewai.com;
connect-src ''self'' *.crewai.com crewai.com https://zeus.tools.crewai.com
https://connect.useparagon.com/ https://zeus.useparagon.com/* https://*.useparagon.com/*
https://run.pstmn.io https://connect.tools.crewai.com/ ws://localhost:3036
wss://localhost:3036; frame-src ''self'' *.crewai.com crewai.com https://connect.useparagon.com/
https://zeus.tools.crewai.com https://zeus.useparagon.com/* https://connect.tools.crewai.com/
https://www.youtube.com https://share.descript.com'
content-type:
- application/json; charset=utf-8
permissions-policy:
- camera=(), microphone=(self), geolocation=()
referrer-policy:
- strict-origin-when-cross-origin
server-timing:
- cache_read.active_support;dur=0.06, sql.active_record;dur=3.86, cache_generate.active_support;dur=4.28,
cache_write.active_support;dur=0.15, cache_read_multi.active_support;dur=0.12,
start_processing.action_controller;dur=0.00, process_action.action_controller;dur=1.70
vary:
- Accept
x-content-type-options:
- nosniff
x-frame-options:
- SAMEORIGIN
x-permitted-cross-domain-policies:
- none
x-request-id:
- 1750d141-c48f-47f1-b8b4-130195437d22
x-runtime:
- '0.043849'
x-xss-protection:
- 1; mode=block
status:
code: 401
message: Unauthorized
version: 1

View File

@@ -11,7 +11,7 @@ from crewai.knowledge.storage.knowledge_storage import ( # type: ignore[import-
@patch("crewai.knowledge.storage.knowledge_storage.get_rag_client")
@patch("crewai.knowledge.storage.knowledge_storage.create_client")
@patch("crewai.knowledge.storage.knowledge_storage.build_embedder")
@patch("crewai.knowledge.storage.knowledge_storage.get_embedding_function")
def test_knowledge_storage_uses_rag_client(
mock_get_embedding: MagicMock,
mock_create_client: MagicMock,
@@ -122,7 +122,7 @@ def test_search_error_handling(mock_get_client: MagicMock) -> None:
@patch("crewai.knowledge.storage.knowledge_storage.get_rag_client")
@patch("crewai.knowledge.storage.knowledge_storage.build_embedder")
@patch("crewai.knowledge.storage.knowledge_storage.get_embedding_function")
def test_embedding_configuration_flow(
mock_get_embedding: MagicMock, mock_get_client: MagicMock
) -> None:

View File

@@ -34,30 +34,6 @@ def client(mock_chromadb_client) -> ChromaDBClient:
return client
@pytest.fixture
def client_with_batch_size(mock_chromadb_client) -> ChromaDBClient:
"""Create a ChromaDBClient instance with custom batch size for testing."""
mock_embedding = Mock()
client = ChromaDBClient(
client=mock_chromadb_client,
embedding_function=mock_embedding,
default_batch_size=2,
)
return client
@pytest.fixture
def async_client_with_batch_size(mock_async_chromadb_client) -> ChromaDBClient:
"""Create a ChromaDBClient instance with async client and custom batch size for testing."""
mock_embedding = Mock()
client = ChromaDBClient(
client=mock_async_chromadb_client,
embedding_function=mock_embedding,
default_batch_size=2,
)
return client
@pytest.fixture
def async_client(mock_async_chromadb_client) -> ChromaDBClient:
"""Create a ChromaDBClient instance with async client for testing."""
@@ -636,139 +612,3 @@ class TestChromaDBClient:
await async_client.areset()
mock_async_chromadb_client.reset.assert_called_once_with()
def test_add_documents_with_batch_size(
self, client_with_batch_size, mock_chromadb_client
) -> None:
"""Test add_documents with batch size splits documents into batches."""
mock_collection = Mock()
mock_chromadb_client.get_or_create_collection.return_value = mock_collection
documents: list[BaseRecord] = [
{"doc_id": "id1", "content": "Document 1", "metadata": {"source": "test1"}},
{"doc_id": "id2", "content": "Document 2", "metadata": {"source": "test2"}},
{"doc_id": "id3", "content": "Document 3", "metadata": {"source": "test3"}},
{"doc_id": "id4", "content": "Document 4", "metadata": {"source": "test4"}},
{"doc_id": "id5", "content": "Document 5", "metadata": {"source": "test5"}},
]
client_with_batch_size.add_documents(
collection_name="test_collection", documents=documents
)
assert mock_collection.upsert.call_count == 3
first_call = mock_collection.upsert.call_args_list[0]
assert first_call.kwargs["ids"] == ["id1", "id2"]
assert first_call.kwargs["documents"] == ["Document 1", "Document 2"]
assert first_call.kwargs["metadatas"] == [
{"source": "test1"},
{"source": "test2"},
]
second_call = mock_collection.upsert.call_args_list[1]
assert second_call.kwargs["ids"] == ["id3", "id4"]
assert second_call.kwargs["documents"] == ["Document 3", "Document 4"]
assert second_call.kwargs["metadatas"] == [
{"source": "test3"},
{"source": "test4"},
]
third_call = mock_collection.upsert.call_args_list[2]
assert third_call.kwargs["ids"] == ["id5"]
assert third_call.kwargs["documents"] == ["Document 5"]
assert third_call.kwargs["metadatas"] == [{"source": "test5"}]
def test_add_documents_with_explicit_batch_size(
self, client, mock_chromadb_client
) -> None:
"""Test add_documents with explicitly provided batch size."""
mock_collection = Mock()
mock_chromadb_client.get_or_create_collection.return_value = mock_collection
documents: list[BaseRecord] = [
{"doc_id": "id1", "content": "Document 1"},
{"doc_id": "id2", "content": "Document 2"},
{"doc_id": "id3", "content": "Document 3"},
]
client.add_documents(
collection_name="test_collection", documents=documents, batch_size=1
)
assert mock_collection.upsert.call_count == 3
for i, call in enumerate(mock_collection.upsert.call_args_list):
assert len(call.kwargs["ids"]) == 1
assert call.kwargs["ids"] == [f"id{i + 1}"]
@pytest.mark.asyncio
async def test_aadd_documents_with_batch_size(
self, async_client_with_batch_size, mock_async_chromadb_client
) -> None:
"""Test aadd_documents with batch size splits documents into batches."""
mock_collection = AsyncMock()
mock_async_chromadb_client.get_or_create_collection = AsyncMock(
return_value=mock_collection
)
documents: list[BaseRecord] = [
{"doc_id": "id1", "content": "Document 1", "metadata": {"source": "test1"}},
{"doc_id": "id2", "content": "Document 2", "metadata": {"source": "test2"}},
{"doc_id": "id3", "content": "Document 3", "metadata": {"source": "test3"}},
]
await async_client_with_batch_size.aadd_documents(
collection_name="test_collection", documents=documents
)
assert mock_collection.upsert.call_count == 2
first_call = mock_collection.upsert.call_args_list[0]
assert first_call.kwargs["ids"] == ["id1", "id2"]
assert first_call.kwargs["documents"] == ["Document 1", "Document 2"]
second_call = mock_collection.upsert.call_args_list[1]
assert second_call.kwargs["ids"] == ["id3"]
assert second_call.kwargs["documents"] == ["Document 3"]
@pytest.mark.asyncio
async def test_aadd_documents_with_explicit_batch_size(
self, async_client, mock_async_chromadb_client
) -> None:
"""Test aadd_documents with explicitly provided batch size."""
mock_collection = AsyncMock()
mock_async_chromadb_client.get_or_create_collection = AsyncMock(
return_value=mock_collection
)
documents: list[BaseRecord] = [
{"doc_id": "id1", "content": "Document 1"},
{"doc_id": "id2", "content": "Document 2"},
{"doc_id": "id3", "content": "Document 3"},
{"doc_id": "id4", "content": "Document 4"},
]
await async_client.aadd_documents(
collection_name="test_collection", documents=documents, batch_size=3
)
assert mock_collection.upsert.call_count == 2
first_call = mock_collection.upsert.call_args_list[0]
assert len(first_call.kwargs["ids"]) == 3
second_call = mock_collection.upsert.call_args_list[1]
assert len(second_call.kwargs["ids"]) == 1
def test_client_default_batch_size_initialization(self) -> None:
"""Test that client initializes with correct default batch size."""
mock_client = Mock()
mock_embedding = Mock()
client = ChromaDBClient(client=mock_client, embedding_function=mock_embedding)
assert client.default_batch_size == 100
custom_client = ChromaDBClient(
client=mock_client, embedding_function=mock_embedding, default_batch_size=50
)
assert custom_client.default_batch_size == 50

View File

@@ -1,15 +1,11 @@
"""Tests for ChromaDB utility functions."""
from crewai.rag.chromadb.types import PreparedDocuments
from crewai.rag.chromadb.utils import (
MAX_COLLECTION_LENGTH,
MIN_COLLECTION_LENGTH,
_create_batch_slice,
_is_ipv4_pattern,
_prepare_documents_for_chromadb,
_sanitize_collection_name,
)
from crewai.rag.types import BaseRecord
class TestChromaDBUtils:
@@ -97,206 +93,3 @@ class TestChromaDBUtils:
assert len(sanitized) >= MIN_COLLECTION_LENGTH
assert sanitized[0].isalnum()
assert sanitized[-1].isalnum()
class TestPrepareDocumentsForChromaDB:
"""Test suite for _prepare_documents_for_chromadb function."""
def test_prepare_documents_with_doc_ids(self) -> None:
"""Test preparing documents that already have doc_ids."""
documents: list[BaseRecord] = [
{
"doc_id": "id1",
"content": "First document",
"metadata": {"source": "test1"},
},
{
"doc_id": "id2",
"content": "Second document",
"metadata": {"source": "test2"},
},
]
result = _prepare_documents_for_chromadb(documents)
assert result.ids == ["id1", "id2"]
assert result.texts == ["First document", "Second document"]
assert result.metadatas == [{"source": "test1"}, {"source": "test2"}]
def test_prepare_documents_generate_ids(self) -> None:
"""Test preparing documents without doc_ids (should generate hashes)."""
documents: list[BaseRecord] = [
{"content": "Test content", "metadata": {"key": "value"}},
{"content": "Another test"},
]
result = _prepare_documents_for_chromadb(documents)
assert len(result.ids) == 2
assert all(len(doc_id) == 64 for doc_id in result.ids)
assert result.texts == ["Test content", "Another test"]
assert result.metadatas == [{"key": "value"}, {}]
def test_prepare_documents_with_list_metadata(self) -> None:
"""Test preparing documents with list metadata (should take first item)."""
documents: list[BaseRecord] = [
{"content": "Test", "metadata": [{"first": "item"}, {"second": "item"}]},
{"content": "Test2", "metadata": []},
]
result = _prepare_documents_for_chromadb(documents)
assert result.metadatas == [{"first": "item"}, {}]
def test_prepare_documents_no_metadata(self) -> None:
"""Test preparing documents without metadata."""
documents: list[BaseRecord] = [
{"content": "Document 1"},
{"content": "Document 2", "metadata": None},
]
result = _prepare_documents_for_chromadb(documents)
assert result.metadatas == [{}, {}]
def test_prepare_documents_hash_consistency(self) -> None:
"""Test that identical content produces identical hashes."""
documents1: list[BaseRecord] = [
{"content": "Same content", "metadata": {"key": "value"}}
]
documents2: list[BaseRecord] = [
{"content": "Same content", "metadata": {"key": "value"}}
]
result1 = _prepare_documents_for_chromadb(documents1)
result2 = _prepare_documents_for_chromadb(documents2)
assert result1.ids == result2.ids
class TestCreateBatchSlice:
"""Test suite for _create_batch_slice function."""
def test_create_batch_slice_normal(self) -> None:
"""Test creating a normal batch slice."""
prepared = PreparedDocuments(
ids=["id1", "id2", "id3", "id4", "id5"],
texts=["doc1", "doc2", "doc3", "doc4", "doc5"],
metadatas=[{"a": 1}, {"b": 2}, {"c": 3}, {"d": 4}, {"e": 5}],
)
batch_ids, batch_texts, batch_metadatas = _create_batch_slice(
prepared, start_index=1, batch_size=3
)
assert batch_ids == ["id2", "id3", "id4"]
assert batch_texts == ["doc2", "doc3", "doc4"]
assert batch_metadatas == [{"b": 2}, {"c": 3}, {"d": 4}]
def test_create_batch_slice_at_end(self) -> None:
"""Test creating a batch slice that goes beyond the end."""
prepared = PreparedDocuments(
ids=["id1", "id2", "id3"],
texts=["doc1", "doc2", "doc3"],
metadatas=[{"a": 1}, {"b": 2}, {"c": 3}],
)
batch_ids, batch_texts, batch_metadatas = _create_batch_slice(
prepared, start_index=2, batch_size=5
)
assert batch_ids == ["id3"]
assert batch_texts == ["doc3"]
assert batch_metadatas == [{"c": 3}]
def test_create_batch_slice_empty_batch(self) -> None:
"""Test creating a batch slice starting beyond the data."""
prepared = PreparedDocuments(
ids=["id1", "id2"], texts=["doc1", "doc2"], metadatas=[{"a": 1}, {"b": 2}]
)
batch_ids, batch_texts, batch_metadatas = _create_batch_slice(
prepared, start_index=5, batch_size=3
)
assert batch_ids == []
assert batch_texts == []
assert batch_metadatas == []
def test_create_batch_slice_no_metadatas(self) -> None:
"""Test creating a batch slice with no metadatas."""
prepared = PreparedDocuments(
ids=["id1", "id2", "id3"], texts=["doc1", "doc2", "doc3"], metadatas=[]
)
batch_ids, batch_texts, batch_metadatas = _create_batch_slice(
prepared, start_index=0, batch_size=2
)
assert batch_ids == ["id1", "id2"]
assert batch_texts == ["doc1", "doc2"]
assert batch_metadatas is None
def test_create_batch_slice_all_empty_metadatas(self) -> None:
"""Test creating a batch slice where all metadatas are empty."""
prepared = PreparedDocuments(
ids=["id1", "id2", "id3"],
texts=["doc1", "doc2", "doc3"],
metadatas=[{}, {}, {}],
)
batch_ids, batch_texts, batch_metadatas = _create_batch_slice(
prepared, start_index=0, batch_size=3
)
assert batch_ids == ["id1", "id2", "id3"]
assert batch_texts == ["doc1", "doc2", "doc3"]
assert batch_metadatas is None
def test_create_batch_slice_some_empty_metadatas(self) -> None:
"""Test creating a batch slice where some metadatas are empty."""
prepared = PreparedDocuments(
ids=["id1", "id2", "id3"],
texts=["doc1", "doc2", "doc3"],
metadatas=[{"a": 1}, {}, {"c": 3}],
)
batch_ids, batch_texts, batch_metadatas = _create_batch_slice(
prepared, start_index=0, batch_size=3
)
assert batch_ids == ["id1", "id2", "id3"]
assert batch_texts == ["doc1", "doc2", "doc3"]
assert batch_metadatas == [{"a": 1}, {}, {"c": 3}]
def test_create_batch_slice_zero_start_index(self) -> None:
"""Test creating a batch slice starting from index 0."""
prepared = PreparedDocuments(
ids=["id1", "id2", "id3", "id4"],
texts=["doc1", "doc2", "doc3", "doc4"],
metadatas=[{"a": 1}, {"b": 2}, {"c": 3}, {"d": 4}],
)
batch_ids, batch_texts, batch_metadatas = _create_batch_slice(
prepared, start_index=0, batch_size=2
)
assert batch_ids == ["id1", "id2"]
assert batch_texts == ["doc1", "doc2"]
assert batch_metadatas == [{"a": 1}, {"b": 2}]
def test_create_batch_slice_single_item(self) -> None:
"""Test creating a batch slice with batch size 1."""
prepared = PreparedDocuments(
ids=["id1", "id2", "id3"],
texts=["doc1", "doc2", "doc3"],
metadatas=[{"a": 1}, {"b": 2}, {"c": 3}],
)
batch_ids, batch_texts, batch_metadatas = _create_batch_slice(
prepared, start_index=1, batch_size=1
)
assert batch_ids == ["id2"]
assert batch_texts == ["doc2"]
assert batch_metadatas == [{"b": 2}]

View File

@@ -1,244 +0,0 @@
"""Tests for embedding function factory."""
from unittest.mock import MagicMock, patch
import pytest
from crewai.rag.embeddings.factory import build_embedder
class TestEmbeddingFactory:
"""Test embedding factory functions."""
@patch("crewai.rag.embeddings.factory.import_and_validate_definition")
def test_build_embedder_openai(self, mock_import):
"""Test building OpenAI embedder."""
mock_provider_class = MagicMock()
mock_provider_instance = MagicMock()
mock_embedding_function = MagicMock()
mock_import.return_value = mock_provider_class
mock_provider_class.return_value = mock_provider_instance
mock_provider_instance.embedding_callable.return_value = mock_embedding_function
config = {
"provider": "openai",
"config": {
"api_key": "test-key",
"model_name": "text-embedding-3-small",
},
}
build_embedder(config)
mock_import.assert_called_once_with(
"crewai.rag.embeddings.providers.openai.openai_provider.OpenAIProvider"
)
mock_provider_class.assert_called_once()
call_kwargs = mock_provider_class.call_args.kwargs
assert call_kwargs["api_key"] == "test-key"
assert call_kwargs["model_name"] == "text-embedding-3-small"
@patch("crewai.rag.embeddings.factory.import_and_validate_definition")
def test_build_embedder_azure(self, mock_import):
"""Test building Azure embedder."""
mock_provider_class = MagicMock()
mock_provider_instance = MagicMock()
mock_embedding_function = MagicMock()
mock_import.return_value = mock_provider_class
mock_provider_class.return_value = mock_provider_instance
mock_provider_instance.embedding_callable.return_value = mock_embedding_function
config = {
"provider": "azure",
"config": {
"api_key": "test-azure-key",
"api_base": "https://test.openai.azure.com/",
"api_type": "azure",
"api_version": "2023-05-15",
"model_name": "text-embedding-3-small",
"deployment_id": "test-deployment",
},
}
build_embedder(config)
mock_import.assert_called_once_with(
"crewai.rag.embeddings.providers.microsoft.azure.AzureProvider"
)
call_kwargs = mock_provider_class.call_args.kwargs
assert call_kwargs["api_key"] == "test-azure-key"
assert call_kwargs["api_base"] == "https://test.openai.azure.com/"
assert call_kwargs["api_type"] == "azure"
@patch("crewai.rag.embeddings.factory.import_and_validate_definition")
def test_build_embedder_ollama(self, mock_import):
"""Test building Ollama embedder."""
mock_provider_class = MagicMock()
mock_provider_instance = MagicMock()
mock_embedding_function = MagicMock()
mock_import.return_value = mock_provider_class
mock_provider_class.return_value = mock_provider_instance
mock_provider_instance.embedding_callable.return_value = mock_embedding_function
config = {
"provider": "ollama",
"config": {
"model_name": "nomic-embed-text",
"url": "http://localhost:11434",
},
}
build_embedder(config)
mock_import.assert_called_once_with(
"crewai.rag.embeddings.providers.ollama.ollama_provider.OllamaProvider"
)
@patch("crewai.rag.embeddings.factory.import_and_validate_definition")
def test_build_embedder_cohere(self, mock_import):
"""Test building Cohere embedder."""
mock_provider_class = MagicMock()
mock_provider_instance = MagicMock()
mock_embedding_function = MagicMock()
mock_import.return_value = mock_provider_class
mock_provider_class.return_value = mock_provider_instance
mock_provider_instance.embedding_callable.return_value = mock_embedding_function
config = {
"provider": "cohere",
"config": {
"api_key": "cohere-key",
"model_name": "embed-english-v3.0",
},
}
build_embedder(config)
mock_import.assert_called_once_with(
"crewai.rag.embeddings.providers.cohere.cohere_provider.CohereProvider"
)
@patch("crewai.rag.embeddings.factory.import_and_validate_definition")
def test_build_embedder_voyageai(self, mock_import):
"""Test building VoyageAI embedder."""
mock_provider_class = MagicMock()
mock_provider_instance = MagicMock()
mock_embedding_function = MagicMock()
mock_import.return_value = mock_provider_class
mock_provider_class.return_value = mock_provider_instance
mock_provider_instance.embedding_callable.return_value = mock_embedding_function
config = {
"provider": "voyageai",
"config": {
"api_key": "voyage-key",
"model": "voyage-2",
},
}
build_embedder(config)
mock_import.assert_called_once_with(
"crewai.rag.embeddings.providers.voyageai.voyageai_provider.VoyageAIProvider"
)
@patch("crewai.rag.embeddings.factory.import_and_validate_definition")
def test_build_embedder_watson(self, mock_import):
"""Test building Watson embedder."""
mock_provider_class = MagicMock()
mock_provider_instance = MagicMock()
mock_embedding_function = MagicMock()
mock_import.return_value = mock_provider_class
mock_provider_class.return_value = mock_provider_instance
mock_provider_instance.embedding_callable.return_value = mock_embedding_function
config = {
"provider": "watson",
"config": {
"model_id": "ibm/slate-125m-english-rtrvr",
"api_key": "watson-key",
"url": "https://us-south.ml.cloud.ibm.com",
"project_id": "test-project",
},
}
build_embedder(config)
mock_import.assert_called_once_with(
"crewai.rag.embeddings.providers.ibm.watson.WatsonProvider"
)
def test_build_embedder_unknown_provider(self):
"""Test error handling for unknown provider."""
config = {"provider": "unknown-provider", "config": {}}
with pytest.raises(ValueError, match="Unknown provider: unknown-provider"):
build_embedder(config)
def test_build_embedder_missing_provider(self):
"""Test error handling for missing provider key."""
config = {"config": {"api_key": "test-key"}}
with pytest.raises(KeyError):
build_embedder(config)
@patch("crewai.rag.embeddings.factory.import_and_validate_definition")
def test_build_embedder_import_error(self, mock_import):
"""Test error handling when provider import fails."""
mock_import.side_effect = ImportError("Module not found")
config = {"provider": "openai", "config": {"api_key": "test-key"}}
with pytest.raises(ImportError, match="Failed to import provider openai"):
build_embedder(config)
@patch("crewai.rag.embeddings.factory.import_and_validate_definition")
def test_build_embedder_custom_provider(self, mock_import):
"""Test building custom embedder."""
mock_provider_class = MagicMock()
mock_provider_instance = MagicMock()
mock_embedding_callable = MagicMock()
mock_import.return_value = mock_provider_class
mock_provider_class.return_value = mock_provider_instance
mock_provider_instance.embedding_callable = mock_embedding_callable
config = {
"provider": "custom",
"config": {"embedding_callable": mock_embedding_callable},
}
build_embedder(config)
mock_import.assert_called_once_with(
"crewai.rag.embeddings.providers.custom.custom_provider.CustomProvider"
)
call_kwargs = mock_provider_class.call_args.kwargs
assert call_kwargs["embedding_callable"] == mock_embedding_callable
@patch("crewai.rag.embeddings.factory.import_and_validate_definition")
@patch("crewai.rag.embeddings.factory.build_embedder_from_provider")
def test_build_embedder_with_provider_instance(
self, mock_build_from_provider, mock_import
):
"""Test building embedder from provider instance."""
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
mock_provider = MagicMock(spec=BaseEmbeddingsProvider)
mock_embedding_function = MagicMock()
mock_build_from_provider.return_value = mock_embedding_function
result = build_embedder(mock_provider)
mock_build_from_provider.assert_called_once_with(mock_provider)
assert result == mock_embedding_function
mock_import.assert_not_called()

View File

@@ -1,122 +0,0 @@
"""Test Azure embedder configuration with factory."""
from unittest.mock import MagicMock, patch
import pytest
from crewai.rag.embeddings.factory import build_embedder
class TestAzureEmbedderFactory:
"""Test Azure embedder configuration with factory function."""
@patch("crewai.rag.embeddings.factory.import_and_validate_definition")
def test_azure_with_nested_config(self, mock_import):
"""Test Azure configuration with nested config key."""
mock_provider_class = MagicMock()
mock_provider_instance = MagicMock()
mock_embedding_function = MagicMock()
mock_import.return_value = mock_provider_class
mock_provider_class.return_value = mock_provider_instance
mock_provider_instance.embedding_callable.return_value = mock_embedding_function
embedder_config = {
"provider": "azure",
"config": {
"api_key": "test-azure-key",
"api_base": "https://test.openai.azure.com/",
"api_type": "azure",
"api_version": "2023-05-15",
"model_name": "text-embedding-3-small",
"deployment_id": "test-deployment",
},
}
result = build_embedder(embedder_config)
mock_import.assert_called_once_with(
"crewai.rag.embeddings.providers.microsoft.azure.AzureProvider"
)
call_kwargs = mock_provider_class.call_args.kwargs
assert call_kwargs["api_key"] == "test-azure-key"
assert call_kwargs["api_base"] == "https://test.openai.azure.com/"
assert call_kwargs["api_type"] == "azure"
assert call_kwargs["api_version"] == "2023-05-15"
assert call_kwargs["model_name"] == "text-embedding-3-small"
assert call_kwargs["deployment_id"] == "test-deployment"
assert result == mock_embedding_function
@patch("crewai.rag.embeddings.factory.import_and_validate_definition")
def test_regular_openai_with_nested_config(self, mock_import):
"""Test regular OpenAI configuration with nested config."""
mock_provider_class = MagicMock()
mock_provider_instance = MagicMock()
mock_embedding_function = MagicMock()
mock_import.return_value = mock_provider_class
mock_provider_class.return_value = mock_provider_instance
mock_provider_instance.embedding_callable.return_value = mock_embedding_function
embedder_config = {
"provider": "openai",
"config": {"api_key": "test-openai-key", "model": "text-embedding-3-large"},
}
result = build_embedder(embedder_config)
mock_import.assert_called_once_with(
"crewai.rag.embeddings.providers.openai.openai_provider.OpenAIProvider"
)
call_kwargs = mock_provider_class.call_args.kwargs
assert call_kwargs["api_key"] == "test-openai-key"
assert call_kwargs["model"] == "text-embedding-3-large"
assert result == mock_embedding_function
@patch("crewai.rag.embeddings.factory.import_and_validate_definition")
def test_azure_provider_with_minimal_config(self, mock_import):
"""Test Azure provider with minimal required configuration."""
mock_provider_class = MagicMock()
mock_provider_instance = MagicMock()
mock_embedding_function = MagicMock()
mock_import.return_value = mock_provider_class
mock_provider_class.return_value = mock_provider_instance
mock_provider_instance.embedding_callable.return_value = mock_embedding_function
embedder_config = {
"provider": "azure",
"config": {
"api_key": "test-key",
"api_base": "https://test.openai.azure.com/",
},
}
build_embedder(embedder_config)
mock_import.assert_called_once_with(
"crewai.rag.embeddings.providers.microsoft.azure.AzureProvider"
)
call_kwargs = mock_provider_class.call_args.kwargs
assert call_kwargs["api_key"] == "test-key"
assert call_kwargs["api_base"] == "https://test.openai.azure.com/"
@patch("crewai.rag.embeddings.factory.import_and_validate_definition")
def test_azure_import_error(self, mock_import):
"""Test handling of import errors for Azure provider."""
mock_import.side_effect = ImportError("Failed to import Azure provider")
embedder_config = {
"provider": "azure",
"config": {"api_key": "test-key"},
}
with pytest.raises(ImportError) as exc_info:
build_embedder(embedder_config)
assert "Failed to import provider azure" in str(exc_info.value)

View File

@@ -0,0 +1,315 @@
"""Enhanced tests for embedding function factory."""
from unittest.mock import MagicMock, patch
import pytest
from crewai.rag.embeddings.factory import ( # type: ignore[import-untyped]
get_embedding_function,
)
from crewai.rag.embeddings.types import EmbeddingOptions # type: ignore[import-untyped]
def test_get_embedding_function_default() -> None:
"""Test default embedding function when no config provided."""
with patch("crewai.rag.embeddings.factory.OpenAIEmbeddingFunction") as mock_openai:
mock_instance = MagicMock()
mock_openai.return_value = mock_instance
with patch(
"crewai.rag.embeddings.factory.os.getenv", return_value="test-api-key"
):
result = get_embedding_function()
mock_openai.assert_called_once_with(
api_key="test-api-key", model_name="text-embedding-3-small"
)
assert result == mock_instance
def test_get_embedding_function_with_embedding_options() -> None:
"""Test embedding function creation with EmbeddingOptions object."""
with patch("crewai.rag.embeddings.factory.OpenAIEmbeddingFunction") as mock_openai:
mock_instance = MagicMock()
mock_openai.return_value = mock_instance
options = EmbeddingOptions(
provider="openai", api_key="test-key", model="text-embedding-3-large"
)
result = get_embedding_function(options)
call_kwargs = mock_openai.call_args.kwargs
assert "api_key" in call_kwargs
assert call_kwargs["api_key"].get_secret_value() == "test-key"
# OpenAI uses model_name parameter, not model
assert result == mock_instance
def test_get_embedding_function_sentence_transformer() -> None:
"""Test sentence transformer embedding function."""
with patch(
"crewai.rag.embeddings.factory.SentenceTransformerEmbeddingFunction"
) as mock_st:
mock_instance = MagicMock()
mock_st.return_value = mock_instance
config = {"provider": "sentence-transformer", "model_name": "all-MiniLM-L6-v2"}
result = get_embedding_function(config)
mock_st.assert_called_once_with(model_name="all-MiniLM-L6-v2")
assert result == mock_instance
def test_get_embedding_function_ollama() -> None:
"""Test Ollama embedding function."""
with patch("crewai.rag.embeddings.factory.OllamaEmbeddingFunction") as mock_ollama:
mock_instance = MagicMock()
mock_ollama.return_value = mock_instance
config = {
"provider": "ollama",
"model_name": "nomic-embed-text",
"url": "http://localhost:11434",
}
result = get_embedding_function(config)
mock_ollama.assert_called_once_with(
model_name="nomic-embed-text", url="http://localhost:11434"
)
assert result == mock_instance
def test_get_embedding_function_cohere() -> None:
"""Test Cohere embedding function."""
with patch("crewai.rag.embeddings.factory.CohereEmbeddingFunction") as mock_cohere:
mock_instance = MagicMock()
mock_cohere.return_value = mock_instance
config = {
"provider": "cohere",
"api_key": "cohere-key",
"model_name": "embed-english-v3.0",
}
result = get_embedding_function(config)
mock_cohere.assert_called_once_with(
api_key="cohere-key", model_name="embed-english-v3.0"
)
assert result == mock_instance
def test_get_embedding_function_huggingface() -> None:
"""Test HuggingFace embedding function."""
with patch("crewai.rag.embeddings.factory.HuggingFaceEmbeddingFunction") as mock_hf:
mock_instance = MagicMock()
mock_hf.return_value = mock_instance
config = {
"provider": "huggingface",
"api_key": "hf-token",
"model_name": "sentence-transformers/all-MiniLM-L6-v2",
}
result = get_embedding_function(config)
mock_hf.assert_called_once_with(
api_key="hf-token", model_name="sentence-transformers/all-MiniLM-L6-v2"
)
assert result == mock_instance
def test_get_embedding_function_onnx() -> None:
"""Test ONNX embedding function."""
with patch("crewai.rag.embeddings.factory.ONNXMiniLM_L6_V2") as mock_onnx:
mock_instance = MagicMock()
mock_onnx.return_value = mock_instance
config = {"provider": "onnx"}
result = get_embedding_function(config)
mock_onnx.assert_called_once()
assert result == mock_instance
def test_get_embedding_function_google_palm() -> None:
"""Test Google PaLM embedding function."""
with patch(
"crewai.rag.embeddings.factory.GooglePalmEmbeddingFunction"
) as mock_palm:
mock_instance = MagicMock()
mock_palm.return_value = mock_instance
config = {"provider": "google-palm", "api_key": "palm-key"}
result = get_embedding_function(config)
mock_palm.assert_called_once_with(api_key="palm-key")
assert result == mock_instance
def test_get_embedding_function_amazon_bedrock() -> None:
"""Test Amazon Bedrock embedding function."""
with patch(
"crewai.rag.embeddings.factory.AmazonBedrockEmbeddingFunction"
) as mock_bedrock:
mock_instance = MagicMock()
mock_bedrock.return_value = mock_instance
config = {
"provider": "amazon-bedrock",
"region_name": "us-west-2",
"model_name": "amazon.titan-embed-text-v1",
}
result = get_embedding_function(config)
mock_bedrock.assert_called_once_with(
region_name="us-west-2", model_name="amazon.titan-embed-text-v1"
)
assert result == mock_instance
def test_get_embedding_function_jina() -> None:
"""Test Jina embedding function."""
with patch("crewai.rag.embeddings.factory.JinaEmbeddingFunction") as mock_jina:
mock_instance = MagicMock()
mock_jina.return_value = mock_instance
config = {
"provider": "jina",
"api_key": "jina-key",
"model_name": "jina-embeddings-v2-base-en",
}
result = get_embedding_function(config)
mock_jina.assert_called_once_with(
api_key="jina-key", model_name="jina-embeddings-v2-base-en"
)
assert result == mock_instance
def test_get_embedding_function_unsupported_provider() -> None:
"""Test handling of unsupported provider."""
config = {"provider": "unsupported-provider"}
with pytest.raises(ValueError, match="Unsupported provider: unsupported-provider"):
get_embedding_function(config)
def test_get_embedding_function_config_modification() -> None:
"""Test that original config dict is not modified."""
original_config = {
"provider": "openai",
"api_key": "test-key",
"model": "text-embedding-3-small",
}
config_copy = original_config.copy()
with patch("crewai.rag.embeddings.factory.OpenAIEmbeddingFunction"):
get_embedding_function(config_copy)
assert config_copy == original_config
def test_get_embedding_function_exclude_none_values() -> None:
"""Test that None values are excluded from embedding function calls."""
with patch("crewai.rag.embeddings.factory.OpenAIEmbeddingFunction") as mock_openai:
mock_instance = MagicMock()
mock_openai.return_value = mock_instance
options = EmbeddingOptions(provider="openai", api_key="test-key", model=None)
result = get_embedding_function(options)
call_kwargs = mock_openai.call_args.kwargs
assert "api_key" in call_kwargs
assert call_kwargs["api_key"].get_secret_value() == "test-key"
assert "model" not in call_kwargs
assert result == mock_instance
def test_get_embedding_function_instructor() -> None:
"""Test Instructor embedding function."""
with patch(
"crewai.rag.embeddings.factory.InstructorEmbeddingFunction"
) as mock_instructor:
mock_instance = MagicMock()
mock_instructor.return_value = mock_instance
config = {"provider": "instructor", "model_name": "hkunlp/instructor-large"}
result = get_embedding_function(config)
mock_instructor.assert_called_once_with(model_name="hkunlp/instructor-large")
assert result == mock_instance
def test_get_embedding_function_watson() -> None:
"""Test Watson embedding function."""
with patch("crewai.rag.embeddings.factory._create_watson_embedding_function") as mock_watson:
mock_instance = MagicMock()
mock_watson.return_value = mock_instance
config = {
"provider": "watson",
"api_key": "watson-api-key",
"api_url": "https://watson-url.com",
"project_id": "watson-project-id",
"model_name": "ibm/slate-125m-english-rtrvr",
}
result = get_embedding_function(config)
mock_watson.assert_called_once_with(
api_key="watson-api-key",
api_url="https://watson-url.com",
project_id="watson-project-id",
model_name="ibm/slate-125m-english-rtrvr",
)
assert result == mock_instance
def test_get_embedding_function_watson_missing_dependencies() -> None:
"""Test Watson embedding function with missing dependencies."""
with patch("crewai.rag.embeddings.factory._create_watson_embedding_function") as mock_watson:
mock_watson.side_effect = ImportError(
"IBM Watson dependencies are not installed. Please install them to use Watson embedding."
)
config = {
"provider": "watson",
"api_key": "watson-api-key",
"api_url": "https://watson-url.com",
"project_id": "watson-project-id",
"model_name": "ibm/slate-125m-english-rtrvr",
}
with pytest.raises(ImportError, match="IBM Watson dependencies are not installed"):
get_embedding_function(config)
def test_get_embedding_function_watson_with_embedding_options() -> None:
"""Test Watson embedding function with EmbeddingOptions object."""
with patch("crewai.rag.embeddings.factory._create_watson_embedding_function") as mock_watson:
mock_instance = MagicMock()
mock_watson.return_value = mock_instance
options = EmbeddingOptions(
provider="watson",
api_key="watson-key",
model_name="ibm/slate-125m-english-rtrvr"
)
result = get_embedding_function(options)
call_kwargs = mock_watson.call_args.kwargs
assert "api_key" in call_kwargs
assert call_kwargs["api_key"].get_secret_value() == "watson-key"
assert call_kwargs["model_name"] == "ibm/slate-125m-english-rtrvr"
assert result == mock_instance

View File

@@ -1,77 +0,0 @@
"""Tests for Watson embedding function name method."""
import pytest
from crewai.rag.embeddings.providers.ibm.embedding_callable import (
WatsonEmbeddingFunction,
)
class TestWatsonEmbeddingName:
"""Test Watson embedding function name method."""
def test_watson_embedding_function_has_name_method(self):
"""Test that WatsonEmbeddingFunction has a name method."""
assert hasattr(WatsonEmbeddingFunction, 'name')
assert callable(WatsonEmbeddingFunction.name)
def test_watson_embedding_function_name_returns_watson(self):
"""Test that the name method returns 'watson'."""
assert WatsonEmbeddingFunction.name() == "watson"
def test_watson_embedding_function_name_is_static(self):
"""Test that the name method can be called without instantiation."""
name = WatsonEmbeddingFunction.name()
assert name == "watson"
assert isinstance(name, str)
def test_watson_embedding_function_name_with_chromadb_validation(self):
"""Test that the name method works in ChromaDB validation scenario."""
config = {
"model_id": "test-model",
"api_key": "test-key",
"url": "https://test.com"
}
watson_func = WatsonEmbeddingFunction(**config)
try:
name = watson_func.name()
assert name == "watson"
except AttributeError as e:
pytest.fail(f"ChromaDB validation failed with AttributeError: {e}")
def test_watson_embedding_function_name_method_signature(self):
"""Test that the name method has the correct signature."""
import inspect
name_method = WatsonEmbeddingFunction.name
assert isinstance(inspect.getattr_static(WatsonEmbeddingFunction, 'name'), staticmethod)
sig = inspect.signature(name_method)
if sig.return_annotation != inspect.Signature.empty:
assert sig.return_annotation is str
def test_watson_embedding_function_reproduces_original_issue(self):
"""Test that reproduces the original issue scenario from #3597."""
config = {
"model_id": "ibm/slate-125m-english-rtrvr",
"api_key": "test-key",
"url": "https://us-south.ml.cloud.ibm.com",
"project_id": "test-project"
}
watson_func = WatsonEmbeddingFunction(**config)
name = watson_func.name()
assert name == "watson"
assert isinstance(name, str)
class_name = WatsonEmbeddingFunction.name()
assert class_name == "watson"
assert class_name == name

Some files were not shown because too many files have changed in this diff Show More