mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-07 23:28:30 +00:00
Compare commits
2 Commits
devin/1758
...
devin/1758
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
4489baa149 | ||
|
|
1442f3e4b6 |
@@ -27,7 +27,7 @@ Follow the steps below to get Crewing! 🚣♂️
|
||||
<Step title="Navigate to your new crew project">
|
||||
<CodeGroup>
|
||||
```shell Terminal
|
||||
cd latest_ai_development
|
||||
cd latest-ai-development
|
||||
```
|
||||
</CodeGroup>
|
||||
</Step>
|
||||
|
||||
@@ -27,7 +27,7 @@ mode: "wide"
|
||||
<Step title="새로운 crew 프로젝트로 이동하기">
|
||||
<CodeGroup>
|
||||
```shell Terminal
|
||||
cd latest_ai_development
|
||||
cd latest-ai-development
|
||||
```
|
||||
</CodeGroup>
|
||||
</Step>
|
||||
|
||||
@@ -27,7 +27,7 @@ Siga os passos abaixo para começar a tripular! 🚣♂️
|
||||
<Step title="Navegue até o novo projeto da sua tripulação">
|
||||
<CodeGroup>
|
||||
```shell Terminal
|
||||
cd latest_ai_development
|
||||
cd latest-ai-development
|
||||
```
|
||||
</CodeGroup>
|
||||
</Step>
|
||||
|
||||
@@ -9,7 +9,7 @@ authors = [
|
||||
]
|
||||
dependencies = [
|
||||
# Core Dependencies
|
||||
"pydantic>=2.11.9",
|
||||
"pydantic>=2.4.2",
|
||||
"openai>=1.13.3",
|
||||
"litellm==1.74.9",
|
||||
"instructor>=1.3.3",
|
||||
@@ -21,12 +21,13 @@ dependencies = [
|
||||
"opentelemetry-sdk>=1.30.0",
|
||||
"opentelemetry-exporter-otlp-proto-http>=1.30.0",
|
||||
# Data Handling
|
||||
"chromadb~=1.1.0",
|
||||
"chromadb>=0.5.23",
|
||||
"tokenizers>=0.20.3",
|
||||
"onnxruntime==1.22.0",
|
||||
"openpyxl>=3.1.5",
|
||||
"pyvis>=0.3.2",
|
||||
# Authentication and Security
|
||||
"python-dotenv>=1.1.1",
|
||||
"python-dotenv>=1.0.0",
|
||||
"pyjwt>=2.9.0",
|
||||
# Configuration and Utils
|
||||
"click>=8.1.7",
|
||||
@@ -39,7 +40,6 @@ dependencies = [
|
||||
"blinker>=1.9.0",
|
||||
"json5>=0.10.0",
|
||||
"portalocker==2.7.0",
|
||||
"pydantic-settings>=2.10.1",
|
||||
]
|
||||
|
||||
[project.urls]
|
||||
@@ -48,9 +48,7 @@ Documentation = "https://docs.crewai.com"
|
||||
Repository = "https://github.com/crewAIInc/crewAI"
|
||||
|
||||
[project.optional-dependencies]
|
||||
tools = [
|
||||
"crewai-tools>=0.74.0",
|
||||
]
|
||||
tools = ["crewai-tools~=0.73.0"]
|
||||
embeddings = [
|
||||
"tiktoken~=0.8.0"
|
||||
]
|
||||
@@ -73,30 +71,24 @@ aisuite = [
|
||||
qdrant = [
|
||||
"qdrant-client[fastembed]>=1.14.3",
|
||||
]
|
||||
aws = [
|
||||
"boto3>=1.40.38",
|
||||
]
|
||||
watson = [
|
||||
"ibm-watsonx-ai>=1.3.39",
|
||||
]
|
||||
voyageai = [
|
||||
"voyageai>=0.3.5",
|
||||
]
|
||||
|
||||
[dependency-groups]
|
||||
dev = [
|
||||
"ruff>=0.13.1",
|
||||
"mypy>=1.18.2",
|
||||
[tool.uv]
|
||||
dev-dependencies = [
|
||||
"ruff>=0.12.11",
|
||||
"mypy>=1.17.1",
|
||||
"pre-commit>=4.3.0",
|
||||
"bandit>=1.8.6",
|
||||
"pytest>=8.4.2",
|
||||
"pytest-asyncio>=1.2.0",
|
||||
"pytest-subprocess>=1.5.3",
|
||||
"pytest-recording>=0.13.4",
|
||||
"pytest-randomly>=4.0.1",
|
||||
"pytest-timeout>=2.4.0",
|
||||
"pytest-xdist>=3.8.0",
|
||||
"pytest-split>=0.10.0",
|
||||
"pillow>=10.2.0",
|
||||
"cairosvg>=2.7.1",
|
||||
"pytest>=8.0.0",
|
||||
"python-dotenv>=1.0.0",
|
||||
"pytest-asyncio>=0.23.7",
|
||||
"pytest-subprocess>=1.5.2",
|
||||
"pytest-recording>=0.13.2",
|
||||
"pytest-randomly>=3.16.0",
|
||||
"pytest-timeout>=2.3.1",
|
||||
"pytest-xdist>=3.6.1",
|
||||
"pytest-split>=0.9.0",
|
||||
"types-requests==2.32.*",
|
||||
"types-pyyaml==6.0.*",
|
||||
"types-regex==2024.11.6.*",
|
||||
|
||||
@@ -40,7 +40,7 @@ def _suppress_pydantic_deprecation_warnings() -> None:
|
||||
|
||||
_suppress_pydantic_deprecation_warnings()
|
||||
|
||||
__version__ = "0.201.0"
|
||||
__version__ = "0.193.2"
|
||||
_telemetry_submitted = False
|
||||
|
||||
|
||||
|
||||
@@ -1,10 +1,17 @@
|
||||
import shutil
|
||||
import subprocess
|
||||
import time
|
||||
from collections.abc import Callable, Sequence
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
Dict,
|
||||
List,
|
||||
Literal,
|
||||
Optional,
|
||||
Sequence,
|
||||
Tuple,
|
||||
Type,
|
||||
Union,
|
||||
)
|
||||
|
||||
from pydantic import Field, InstanceOf, PrivateAttr, model_validator
|
||||
@@ -12,31 +19,12 @@ from pydantic import Field, InstanceOf, PrivateAttr, model_validator
|
||||
from crewai.agents import CacheHandler
|
||||
from crewai.agents.agent_builder.base_agent import BaseAgent
|
||||
from crewai.agents.crew_agent_executor import CrewAgentExecutor
|
||||
from crewai.events.event_bus import crewai_event_bus
|
||||
from crewai.events.types.agent_events import (
|
||||
AgentExecutionCompletedEvent,
|
||||
AgentExecutionErrorEvent,
|
||||
AgentExecutionStartedEvent,
|
||||
)
|
||||
from crewai.events.types.knowledge_events import (
|
||||
KnowledgeQueryCompletedEvent,
|
||||
KnowledgeQueryFailedEvent,
|
||||
KnowledgeQueryStartedEvent,
|
||||
KnowledgeRetrievalCompletedEvent,
|
||||
KnowledgeRetrievalStartedEvent,
|
||||
KnowledgeSearchQueryFailedEvent,
|
||||
)
|
||||
from crewai.events.types.memory_events import (
|
||||
MemoryRetrievalCompletedEvent,
|
||||
MemoryRetrievalStartedEvent,
|
||||
)
|
||||
from crewai.knowledge.knowledge import Knowledge
|
||||
from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource
|
||||
from crewai.knowledge.utils.knowledge_utils import extract_knowledge_context
|
||||
from crewai.lite_agent import LiteAgent, LiteAgentOutput
|
||||
from crewai.llm import BaseLLM
|
||||
from crewai.memory.contextual.contextual_memory import ContextualMemory
|
||||
from crewai.rag.embeddings.types import EmbedderConfig
|
||||
from crewai.security import Fingerprint
|
||||
from crewai.task import Task
|
||||
from crewai.tools import BaseTool
|
||||
@@ -50,6 +38,24 @@ from crewai.utilities.agent_utils import (
|
||||
)
|
||||
from crewai.utilities.constants import TRAINED_AGENTS_DATA_FILE, TRAINING_DATA_FILE
|
||||
from crewai.utilities.converter import generate_model_description
|
||||
from crewai.events.types.agent_events import (
|
||||
AgentExecutionCompletedEvent,
|
||||
AgentExecutionErrorEvent,
|
||||
AgentExecutionStartedEvent,
|
||||
)
|
||||
from crewai.events.event_bus import crewai_event_bus
|
||||
from crewai.events.types.memory_events import (
|
||||
MemoryRetrievalStartedEvent,
|
||||
MemoryRetrievalCompletedEvent,
|
||||
)
|
||||
from crewai.events.types.knowledge_events import (
|
||||
KnowledgeQueryCompletedEvent,
|
||||
KnowledgeQueryFailedEvent,
|
||||
KnowledgeQueryStartedEvent,
|
||||
KnowledgeRetrievalCompletedEvent,
|
||||
KnowledgeRetrievalStartedEvent,
|
||||
KnowledgeSearchQueryFailedEvent,
|
||||
)
|
||||
from crewai.utilities.llm_utils import create_llm
|
||||
from crewai.utilities.token_counter_callback import TokenCalcHandler
|
||||
from crewai.utilities.training_handler import CrewTrainingHandler
|
||||
@@ -81,36 +87,36 @@ class Agent(BaseAgent):
|
||||
"""
|
||||
|
||||
_times_executed: int = PrivateAttr(default=0)
|
||||
max_execution_time: int | None = Field(
|
||||
max_execution_time: Optional[int] = Field(
|
||||
default=None,
|
||||
description="Maximum execution time for an agent to execute a task",
|
||||
)
|
||||
agent_ops_agent_name: str = None # type: ignore # Incompatible types in assignment (expression has type "None", variable has type "str")
|
||||
agent_ops_agent_id: str = None # type: ignore # Incompatible types in assignment (expression has type "None", variable has type "str")
|
||||
step_callback: Any | None = Field(
|
||||
step_callback: Optional[Any] = Field(
|
||||
default=None,
|
||||
description="Callback to be executed after each step of the agent execution.",
|
||||
)
|
||||
use_system_prompt: bool | None = Field(
|
||||
use_system_prompt: Optional[bool] = Field(
|
||||
default=True,
|
||||
description="Use system prompt for the agent.",
|
||||
)
|
||||
llm: str | InstanceOf[BaseLLM] | Any = Field(
|
||||
llm: Union[str, InstanceOf[BaseLLM], Any] = Field(
|
||||
description="Language model that will run the agent.", default=None
|
||||
)
|
||||
function_calling_llm: str | InstanceOf[BaseLLM] | Any | None = Field(
|
||||
function_calling_llm: Optional[Union[str, InstanceOf[BaseLLM], Any]] = Field(
|
||||
description="Language model that will run the agent.", default=None
|
||||
)
|
||||
system_template: str | None = Field(
|
||||
system_template: Optional[str] = Field(
|
||||
default=None, description="System format for the agent."
|
||||
)
|
||||
prompt_template: str | None = Field(
|
||||
prompt_template: Optional[str] = Field(
|
||||
default=None, description="Prompt format for the agent."
|
||||
)
|
||||
response_template: str | None = Field(
|
||||
response_template: Optional[str] = Field(
|
||||
default=None, description="Response format for the agent."
|
||||
)
|
||||
allow_code_execution: bool | None = Field(
|
||||
allow_code_execution: Optional[bool] = Field(
|
||||
default=False, description="Enable code execution for the agent."
|
||||
)
|
||||
respect_context_window: bool = Field(
|
||||
@@ -141,31 +147,31 @@ class Agent(BaseAgent):
|
||||
default=False,
|
||||
description="Whether the agent should reflect and create a plan before executing a task.",
|
||||
)
|
||||
max_reasoning_attempts: int | None = Field(
|
||||
max_reasoning_attempts: Optional[int] = Field(
|
||||
default=None,
|
||||
description="Maximum number of reasoning attempts before executing the task. If None, will try until ready.",
|
||||
)
|
||||
embedder: EmbedderConfig | None = Field(
|
||||
embedder: Optional[Dict[str, Any]] = Field(
|
||||
default=None,
|
||||
description="Embedder configuration for the agent.",
|
||||
)
|
||||
agent_knowledge_context: str | None = Field(
|
||||
agent_knowledge_context: Optional[str] = Field(
|
||||
default=None,
|
||||
description="Knowledge context for the agent.",
|
||||
)
|
||||
crew_knowledge_context: str | None = Field(
|
||||
crew_knowledge_context: Optional[str] = Field(
|
||||
default=None,
|
||||
description="Knowledge context for the crew.",
|
||||
)
|
||||
knowledge_search_query: str | None = Field(
|
||||
knowledge_search_query: Optional[str] = Field(
|
||||
default=None,
|
||||
description="Knowledge search query for the agent dynamically generated by the agent.",
|
||||
)
|
||||
from_repository: str | None = Field(
|
||||
from_repository: Optional[str] = Field(
|
||||
default=None,
|
||||
description="The Agent's role to be used from your repository.",
|
||||
)
|
||||
guardrail: Callable[[Any], tuple[bool, Any]] | str | None = Field(
|
||||
guardrail: Optional[Union[Callable[[Any], Tuple[bool, Any]], str]] = Field(
|
||||
default=None,
|
||||
description="Function or string description of a guardrail to validate agent output",
|
||||
)
|
||||
@@ -174,7 +180,7 @@ class Agent(BaseAgent):
|
||||
)
|
||||
|
||||
@model_validator(mode="before")
|
||||
def validate_from_repository(cls, v): # noqa: N805
|
||||
def validate_from_repository(cls, v):
|
||||
if v is not None and (from_repository := v.get("from_repository")):
|
||||
return load_agent_from_repository(from_repository) | v
|
||||
return v
|
||||
@@ -202,7 +208,7 @@ class Agent(BaseAgent):
|
||||
self.cache_handler = CacheHandler()
|
||||
self.set_cache_handler(self.cache_handler)
|
||||
|
||||
def set_knowledge(self, crew_embedder: EmbedderConfig | None = None):
|
||||
def set_knowledge(self, crew_embedder: Optional[Dict[str, Any]] = None):
|
||||
try:
|
||||
if self.embedder is None and crew_embedder:
|
||||
self.embedder = crew_embedder
|
||||
@@ -218,7 +224,7 @@ class Agent(BaseAgent):
|
||||
)
|
||||
self.knowledge.add_sources()
|
||||
except (TypeError, ValueError) as e:
|
||||
raise ValueError(f"Invalid Knowledge Configuration: {e!s}") from e
|
||||
raise ValueError(f"Invalid Knowledge Configuration: {str(e)}")
|
||||
|
||||
def _is_any_available_memory(self) -> bool:
|
||||
"""Check if any memory is available."""
|
||||
@@ -238,8 +244,8 @@ class Agent(BaseAgent):
|
||||
def execute_task(
|
||||
self,
|
||||
task: Task,
|
||||
context: str | None = None,
|
||||
tools: list[BaseTool] | None = None,
|
||||
context: Optional[str] = None,
|
||||
tools: Optional[List[BaseTool]] = None,
|
||||
) -> str:
|
||||
"""Execute a task with the agent.
|
||||
|
||||
@@ -272,9 +278,11 @@ class Agent(BaseAgent):
|
||||
task.description += f"\n\nReasoning Plan:\n{reasoning_output.plan.plan}"
|
||||
except Exception as e:
|
||||
if hasattr(self, "_logger"):
|
||||
self._logger.log("error", f"Error during reasoning process: {e!s}")
|
||||
self._logger.log(
|
||||
"error", f"Error during reasoning process: {str(e)}"
|
||||
)
|
||||
else:
|
||||
print(f"Error during reasoning process: {e!s}")
|
||||
print(f"Error during reasoning process: {str(e)}")
|
||||
|
||||
self._inject_date_to_task(task)
|
||||
|
||||
@@ -327,7 +335,7 @@ class Agent(BaseAgent):
|
||||
agent=self,
|
||||
task=task,
|
||||
)
|
||||
memory = contextual_memory.build_context_for_task(task, context) # type: ignore[arg-type]
|
||||
memory = contextual_memory.build_context_for_task(task, context)
|
||||
if memory.strip() != "":
|
||||
task_prompt += self.i18n.slice("memory").format(memory=memory)
|
||||
|
||||
@@ -517,14 +525,14 @@ class Agent(BaseAgent):
|
||||
|
||||
try:
|
||||
return future.result(timeout=timeout)
|
||||
except concurrent.futures.TimeoutError as e:
|
||||
except concurrent.futures.TimeoutError:
|
||||
future.cancel()
|
||||
raise TimeoutError(
|
||||
f"Task '{task.description}' execution timed out after {timeout} seconds. Consider increasing max_execution_time or optimizing the task."
|
||||
) from e
|
||||
)
|
||||
except Exception as e:
|
||||
future.cancel()
|
||||
raise RuntimeError(f"Task execution failed: {e!s}") from e
|
||||
raise RuntimeError(f"Task execution failed: {str(e)}")
|
||||
|
||||
def _execute_without_timeout(self, task_prompt: str, task: Task) -> str:
|
||||
"""Execute a task without a timeout.
|
||||
@@ -546,14 +554,14 @@ class Agent(BaseAgent):
|
||||
)["output"]
|
||||
|
||||
def create_agent_executor(
|
||||
self, tools: list[BaseTool] | None = None, task=None
|
||||
self, tools: Optional[List[BaseTool]] = None, task=None
|
||||
) -> None:
|
||||
"""Create an agent executor for the agent.
|
||||
|
||||
Returns:
|
||||
An instance of the CrewAgentExecutor class.
|
||||
"""
|
||||
raw_tools: list[BaseTool] = tools or self.tools or []
|
||||
raw_tools: List[BaseTool] = tools or self.tools or []
|
||||
parsed_tools = parse_tools(raw_tools)
|
||||
|
||||
prompt = Prompts(
|
||||
@@ -579,7 +587,7 @@ class Agent(BaseAgent):
|
||||
agent=self,
|
||||
crew=self.crew,
|
||||
tools=parsed_tools,
|
||||
prompt=prompt, # type: ignore[arg-type]
|
||||
prompt=prompt,
|
||||
original_tools=raw_tools,
|
||||
stop_words=stop_words,
|
||||
max_iter=self.max_iter,
|
||||
@@ -595,9 +603,10 @@ class Agent(BaseAgent):
|
||||
callbacks=[TokenCalcHandler(self._token_process)],
|
||||
)
|
||||
|
||||
def get_delegation_tools(self, agents: list[BaseAgent]):
|
||||
def get_delegation_tools(self, agents: List[BaseAgent]):
|
||||
agent_tools = AgentTools(agents=agents)
|
||||
return agent_tools.tools()
|
||||
tools = agent_tools.tools()
|
||||
return tools
|
||||
|
||||
def get_multimodal_tools(self) -> Sequence[BaseTool]:
|
||||
from crewai.tools.agent_tools.add_image_tool import AddImageTool
|
||||
@@ -645,7 +654,7 @@ class Agent(BaseAgent):
|
||||
)
|
||||
return task_prompt
|
||||
|
||||
def _render_text_description(self, tools: list[Any]) -> str:
|
||||
def _render_text_description(self, tools: List[Any]) -> str:
|
||||
"""Render the tool name and description in plain text.
|
||||
|
||||
Output will be in the format of:
|
||||
@@ -655,13 +664,15 @@ class Agent(BaseAgent):
|
||||
search: This tool is used for search
|
||||
calculator: This tool is used for math
|
||||
"""
|
||||
return "\n".join(
|
||||
description = "\n".join(
|
||||
[
|
||||
f"Tool name: {tool.name}\nTool description:\n{tool.description}"
|
||||
for tool in tools
|
||||
]
|
||||
)
|
||||
|
||||
return description
|
||||
|
||||
def _inject_date_to_task(self, task):
|
||||
"""Inject the current date into the task description if inject_date is enabled."""
|
||||
if self.inject_date:
|
||||
@@ -685,13 +696,13 @@ class Agent(BaseAgent):
|
||||
if not is_valid:
|
||||
raise ValueError(f"Invalid date format: {self.date_format}")
|
||||
|
||||
current_date = datetime.now().strftime(self.date_format)
|
||||
current_date: str = datetime.now().strftime(self.date_format)
|
||||
task.description += f"\n\nCurrent Date: {current_date}"
|
||||
except Exception as e:
|
||||
if hasattr(self, "_logger"):
|
||||
self._logger.log("warning", f"Failed to inject date: {e!s}")
|
||||
self._logger.log("warning", f"Failed to inject date: {str(e)}")
|
||||
else:
|
||||
print(f"Warning: Failed to inject date: {e!s}")
|
||||
print(f"Warning: Failed to inject date: {str(e)}")
|
||||
|
||||
def _validate_docker_installation(self) -> None:
|
||||
"""Check if Docker is installed and running."""
|
||||
@@ -702,15 +713,15 @@ class Agent(BaseAgent):
|
||||
|
||||
try:
|
||||
subprocess.run(
|
||||
["/usr/bin/docker", "info"],
|
||||
["docker", "info"],
|
||||
check=True,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE,
|
||||
)
|
||||
except subprocess.CalledProcessError as e:
|
||||
except subprocess.CalledProcessError:
|
||||
raise RuntimeError(
|
||||
f"Docker is not running. Please start Docker to use code execution with agent: {self.role}"
|
||||
) from e
|
||||
)
|
||||
|
||||
def __repr__(self):
|
||||
return f"Agent(role={self.role}, goal={self.goal}, backstory={self.backstory})"
|
||||
@@ -785,8 +796,8 @@ class Agent(BaseAgent):
|
||||
|
||||
def kickoff(
|
||||
self,
|
||||
messages: str | list[dict[str, str]],
|
||||
response_format: type[Any] | None = None,
|
||||
messages: Union[str, List[Dict[str, str]]],
|
||||
response_format: Optional[Type[Any]] = None,
|
||||
) -> LiteAgentOutput:
|
||||
"""
|
||||
Execute the agent with the given messages using a LiteAgent instance.
|
||||
@@ -825,8 +836,8 @@ class Agent(BaseAgent):
|
||||
|
||||
async def kickoff_async(
|
||||
self,
|
||||
messages: str | list[dict[str, str]],
|
||||
response_format: type[Any] | None = None,
|
||||
messages: Union[str, List[Dict[str, str]]],
|
||||
response_format: Optional[Type[Any]] = None,
|
||||
) -> LiteAgentOutput:
|
||||
"""
|
||||
Execute the agent asynchronously with the given messages using a LiteAgent instance.
|
||||
|
||||
@@ -22,7 +22,6 @@ from crewai.agents.tools_handler import ToolsHandler
|
||||
from crewai.knowledge.knowledge import Knowledge
|
||||
from crewai.knowledge.knowledge_config import KnowledgeConfig
|
||||
from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource
|
||||
from crewai.rag.embeddings.types import EmbedderConfig
|
||||
from crewai.security.security_config import SecurityConfig
|
||||
from crewai.tools.base_tool import BaseTool, Tool
|
||||
from crewai.utilities import I18N, Logger, RPMController
|
||||
@@ -360,5 +359,5 @@ class BaseAgent(ABC, BaseModel):
|
||||
self._rpm_controller = rpm_controller
|
||||
self.create_agent_executor()
|
||||
|
||||
def set_knowledge(self, crew_embedder: EmbedderConfig | None = None):
|
||||
def set_knowledge(self, crew_embedder: dict[str, Any] | None = None):
|
||||
pass
|
||||
|
||||
@@ -166,13 +166,3 @@ class PlusAPI:
|
||||
json=payload,
|
||||
timeout=30,
|
||||
)
|
||||
|
||||
def mark_trace_batch_as_failed(
|
||||
self, trace_batch_id: str, error_message: str
|
||||
) -> requests.Response:
|
||||
return self._make_request(
|
||||
"PATCH",
|
||||
f"{self.TRACING_RESOURCE}/batches/{trace_batch_id}",
|
||||
json={"status": "failed", "failure_reason": error_message},
|
||||
timeout=30,
|
||||
)
|
||||
|
||||
@@ -5,7 +5,7 @@ description = "{{name}} using crewAI"
|
||||
authors = [{ name = "Your Name", email = "you@example.com" }]
|
||||
requires-python = ">=3.10,<3.14"
|
||||
dependencies = [
|
||||
"crewai[tools]>=0.201.0,<1.0.0"
|
||||
"crewai[tools]>=0.193.2,<1.0.0"
|
||||
]
|
||||
|
||||
[project.scripts]
|
||||
|
||||
@@ -5,7 +5,7 @@ description = "{{name}} using crewAI"
|
||||
authors = [{ name = "Your Name", email = "you@example.com" }]
|
||||
requires-python = ">=3.10,<3.14"
|
||||
dependencies = [
|
||||
"crewai[tools]>=0.201.0,<1.0.0",
|
||||
"crewai[tools]>=0.193.2,<1.0.0",
|
||||
]
|
||||
|
||||
[project.scripts]
|
||||
|
||||
@@ -5,7 +5,7 @@ description = "Power up your crews with {{folder_name}}"
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.10,<3.14"
|
||||
dependencies = [
|
||||
"crewai[tools]>=0.201.0"
|
||||
"crewai[tools]>=0.193.2"
|
||||
]
|
||||
|
||||
[tool.crewai]
|
||||
|
||||
@@ -59,7 +59,6 @@ from crewai.memory.external.external_memory import ExternalMemory
|
||||
from crewai.memory.long_term.long_term_memory import LongTermMemory
|
||||
from crewai.memory.short_term.short_term_memory import ShortTermMemory
|
||||
from crewai.process import Process
|
||||
from crewai.rag.embeddings.types import EmbedderConfig
|
||||
from crewai.rag.types import SearchResult
|
||||
from crewai.security import Fingerprint, SecurityConfig
|
||||
from crewai.task import Task
|
||||
@@ -169,7 +168,7 @@ class Crew(FlowTrackable, BaseModel):
|
||||
default=None,
|
||||
description="An Instance of the ExternalMemory to be used by the Crew",
|
||||
)
|
||||
embedder: EmbedderConfig | None = Field(
|
||||
embedder: dict | None = Field(
|
||||
default=None,
|
||||
description="Configuration for the embedder to be used for the crew.",
|
||||
)
|
||||
@@ -623,8 +622,7 @@ class Crew(FlowTrackable, BaseModel):
|
||||
training_data=training_data, agent_id=str(agent.id)
|
||||
)
|
||||
CrewTrainingHandler(filename).save_trained_data(
|
||||
agent_id=str(agent.role),
|
||||
trained_data=result.model_dump(), # type: ignore[arg-type]
|
||||
agent_id=str(agent.role), trained_data=result.model_dump()
|
||||
)
|
||||
|
||||
crewai_event_bus.emit(
|
||||
@@ -1059,10 +1057,7 @@ class Crew(FlowTrackable, BaseModel):
|
||||
def _log_task_start(self, task: Task, role: str = "None"):
|
||||
if self.output_log_file:
|
||||
self._file_handler.log(
|
||||
task_name=task.name, # type: ignore[arg-type]
|
||||
task=task.description,
|
||||
agent=role,
|
||||
status="started",
|
||||
task_name=task.name, task=task.description, agent=role, status="started"
|
||||
)
|
||||
|
||||
def _update_manager_tools(
|
||||
@@ -1091,7 +1086,7 @@ class Crew(FlowTrackable, BaseModel):
|
||||
role = task.agent.role if task.agent is not None else "None"
|
||||
if self.output_log_file:
|
||||
self._file_handler.log(
|
||||
task_name=task.name, # type: ignore[arg-type]
|
||||
task_name=task.name,
|
||||
task=task.description,
|
||||
agent=role,
|
||||
status="completed",
|
||||
|
||||
@@ -200,9 +200,6 @@ class TraceBatchManager:
|
||||
if self.event_buffer:
|
||||
events_sent_to_backend_status = self._send_events_to_backend()
|
||||
if events_sent_to_backend_status == 500:
|
||||
self.plus_api.mark_trace_batch_as_failed(
|
||||
self.trace_batch_id, "Error sending events to backend"
|
||||
)
|
||||
return None
|
||||
self._finalize_backend_batch()
|
||||
|
||||
@@ -276,13 +273,10 @@ class TraceBatchManager:
|
||||
logger.error(
|
||||
f"❌ Failed to finalize trace batch: {response.status_code} - {response.text}"
|
||||
)
|
||||
self.plus_api.mark_trace_batch_as_failed(
|
||||
self.trace_batch_id, response.text
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Error finalizing trace batch: {e}")
|
||||
self.plus_api.mark_trace_batch_as_failed(self.trace_batch_id, str(e))
|
||||
# TODO: send error to app marking as failed
|
||||
|
||||
def _cleanup_batch_data(self):
|
||||
"""Clean up batch data after successful finalization to free memory"""
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
import os
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource
|
||||
from crewai.knowledge.storage.knowledge_storage import KnowledgeStorage
|
||||
from crewai.rag.embeddings.types import EmbedderConfig
|
||||
from crewai.rag.types import SearchResult
|
||||
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "false" # removes logging from fastembed
|
||||
@@ -16,20 +16,20 @@ class Knowledge(BaseModel):
|
||||
Args:
|
||||
sources: list[BaseKnowledgeSource] = Field(default_factory=list)
|
||||
storage: KnowledgeStorage | None = Field(default=None)
|
||||
embedder: EmbedderConfig | None = None
|
||||
embedder: dict[str, Any] | None = None
|
||||
"""
|
||||
|
||||
sources: list[BaseKnowledgeSource] = Field(default_factory=list)
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
storage: KnowledgeStorage | None = Field(default=None)
|
||||
embedder: EmbedderConfig | None = None
|
||||
embedder: dict[str, Any] | None = None
|
||||
collection_name: str | None = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
collection_name: str,
|
||||
sources: list[BaseKnowledgeSource],
|
||||
embedder: EmbedderConfig | None = None,
|
||||
embedder: dict[str, Any] | None = None,
|
||||
storage: KnowledgeStorage | None = None,
|
||||
**data,
|
||||
):
|
||||
|
||||
@@ -8,9 +8,7 @@ from crewai.rag.chromadb.config import ChromaDBConfig
|
||||
from crewai.rag.chromadb.types import ChromaEmbeddingFunctionWrapper
|
||||
from crewai.rag.config.utils import get_rag_client
|
||||
from crewai.rag.core.base_client import BaseClient
|
||||
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
|
||||
from crewai.rag.embeddings.factory import build_embedder
|
||||
from crewai.rag.embeddings.types import ProviderSpec
|
||||
from crewai.rag.embeddings.factory import get_embedding_function
|
||||
from crewai.rag.factory import create_client
|
||||
from crewai.rag.types import BaseRecord, SearchResult
|
||||
from crewai.utilities.logger import Logger
|
||||
@@ -24,10 +22,7 @@ class KnowledgeStorage(BaseKnowledgeStorage):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
embedder: ProviderSpec
|
||||
| BaseEmbeddingsProvider
|
||||
| type[BaseEmbeddingsProvider]
|
||||
| None = None,
|
||||
embedder: dict[str, Any] | None = None,
|
||||
collection_name: str | None = None,
|
||||
) -> None:
|
||||
self.collection_name = collection_name
|
||||
@@ -40,7 +35,7 @@ class KnowledgeStorage(BaseKnowledgeStorage):
|
||||
)
|
||||
|
||||
if embedder:
|
||||
embedding_function = build_embedder(embedder) # type: ignore[arg-type]
|
||||
embedding_function = get_embedding_function(embedder)
|
||||
config = ChromaDBConfig(
|
||||
embedding_function=cast(
|
||||
ChromaEmbeddingFunctionWrapper, embedding_function
|
||||
|
||||
@@ -27,10 +27,7 @@ class EntityMemory(Memory):
|
||||
_memory_provider: str | None = PrivateAttr()
|
||||
|
||||
def __init__(self, crew=None, embedder_config=None, storage=None, path=None):
|
||||
memory_provider = None
|
||||
if embedder_config and isinstance(embedder_config, dict):
|
||||
memory_provider = embedder_config.get("provider")
|
||||
|
||||
memory_provider = embedder_config.get("provider") if embedder_config else None
|
||||
if memory_provider == "mem0":
|
||||
try:
|
||||
from crewai.memory.storage.mem0_storage import Mem0Storage
|
||||
@@ -38,11 +35,7 @@ class EntityMemory(Memory):
|
||||
raise ImportError(
|
||||
"Mem0 is not installed. Please install it with `pip install mem0ai`."
|
||||
) from e
|
||||
config = (
|
||||
embedder_config.get("config")
|
||||
if embedder_config and isinstance(embedder_config, dict)
|
||||
else None
|
||||
)
|
||||
config = embedder_config.get("config") if embedder_config else None
|
||||
storage = Mem0Storage(type="short_term", crew=crew, config=config)
|
||||
else:
|
||||
storage = (
|
||||
|
||||
@@ -13,7 +13,6 @@ from crewai.events.types.memory_events import (
|
||||
from crewai.memory.external.external_memory_item import ExternalMemoryItem
|
||||
from crewai.memory.memory import Memory
|
||||
from crewai.memory.storage.interface import Storage
|
||||
from crewai.rag.embeddings.types import ProviderSpec
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from crewai.memory.storage.mem0_storage import Mem0Storage
|
||||
@@ -36,9 +35,7 @@ class ExternalMemory(Memory):
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def create_storage(
|
||||
crew: Any, embedder_config: dict[str, Any] | ProviderSpec | None
|
||||
) -> Storage:
|
||||
def create_storage(crew: Any, embedder_config: dict[str, Any] | None) -> Storage:
|
||||
if not embedder_config:
|
||||
raise ValueError("embedder_config is required")
|
||||
|
||||
@@ -162,6 +159,6 @@ class ExternalMemory(Memory):
|
||||
super().set_crew(crew)
|
||||
|
||||
if not self.storage:
|
||||
self.storage = self.create_storage(crew, self.embedder_config) # type: ignore[arg-type]
|
||||
self.storage = self.create_storage(crew, self.embedder_config)
|
||||
|
||||
return self
|
||||
|
||||
@@ -2,8 +2,6 @@ from typing import TYPE_CHECKING, Any, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from crewai.rag.embeddings.types import EmbedderConfig
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from crewai.agent import Agent
|
||||
from crewai.task import Task
|
||||
@@ -14,7 +12,7 @@ class Memory(BaseModel):
|
||||
Base class for memory, now supporting agent tags and generic metadata.
|
||||
"""
|
||||
|
||||
embedder_config: EmbedderConfig | dict[str, Any] | None = None
|
||||
embedder_config: dict[str, Any] | None = None
|
||||
crew: Any | None = None
|
||||
|
||||
storage: Any
|
||||
|
||||
@@ -29,10 +29,7 @@ class ShortTermMemory(Memory):
|
||||
_memory_provider: str | None = PrivateAttr()
|
||||
|
||||
def __init__(self, crew=None, embedder_config=None, storage=None, path=None):
|
||||
memory_provider = None
|
||||
if embedder_config and isinstance(embedder_config, dict):
|
||||
memory_provider = embedder_config.get("provider")
|
||||
|
||||
memory_provider = embedder_config.get("provider") if embedder_config else None
|
||||
if memory_provider == "mem0":
|
||||
try:
|
||||
from crewai.memory.storage.mem0_storage import Mem0Storage
|
||||
@@ -40,11 +37,7 @@ class ShortTermMemory(Memory):
|
||||
raise ImportError(
|
||||
"Mem0 is not installed. Please install it with `pip install mem0ai`."
|
||||
) from e
|
||||
config = (
|
||||
embedder_config.get("config")
|
||||
if embedder_config and isinstance(embedder_config, dict)
|
||||
else None
|
||||
)
|
||||
config = embedder_config.get("config") if embedder_config else None
|
||||
storage = Mem0Storage(type="short_term", crew=crew, config=config)
|
||||
else:
|
||||
storage = (
|
||||
|
||||
@@ -7,9 +7,7 @@ from crewai.rag.chromadb.config import ChromaDBConfig
|
||||
from crewai.rag.chromadb.types import ChromaEmbeddingFunctionWrapper
|
||||
from crewai.rag.config.utils import get_rag_client
|
||||
from crewai.rag.core.base_client import BaseClient
|
||||
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
|
||||
from crewai.rag.embeddings.factory import build_embedder
|
||||
from crewai.rag.embeddings.types import ProviderSpec
|
||||
from crewai.rag.embeddings.factory import get_embedding_function
|
||||
from crewai.rag.factory import create_client
|
||||
from crewai.rag.storage.base_rag_storage import BaseRAGStorage
|
||||
from crewai.rag.types import BaseRecord
|
||||
@@ -27,7 +25,7 @@ class RAGStorage(BaseRAGStorage):
|
||||
self,
|
||||
type: str,
|
||||
allow_reset: bool = True,
|
||||
embedder_config: ProviderSpec | BaseEmbeddingsProvider | None = None,
|
||||
embedder_config: dict[str, Any] | None = None,
|
||||
crew: Any = None,
|
||||
path: str | None = None,
|
||||
) -> None:
|
||||
@@ -51,46 +49,12 @@ class RAGStorage(BaseRAGStorage):
|
||||
)
|
||||
|
||||
if self.embedder_config:
|
||||
embedding_function = build_embedder(self.embedder_config)
|
||||
|
||||
try:
|
||||
_ = embedding_function(["test"])
|
||||
except Exception as e:
|
||||
provider = (
|
||||
self.embedder_config["provider"]
|
||||
if isinstance(self.embedder_config, dict)
|
||||
else self.embedder_config.__class__.__name__.replace(
|
||||
"Provider", ""
|
||||
).lower()
|
||||
)
|
||||
raise ValueError(
|
||||
f"Failed to initialize embedder. Please check your configuration or connection.\n"
|
||||
f"Provider: {provider}\n"
|
||||
f"Error: {e}"
|
||||
) from e
|
||||
|
||||
batch_size = None
|
||||
if (
|
||||
isinstance(self.embedder_config, dict)
|
||||
and "config" in self.embedder_config
|
||||
):
|
||||
nested_config = self.embedder_config["config"]
|
||||
if isinstance(nested_config, dict):
|
||||
batch_size = nested_config.get("batch_size")
|
||||
|
||||
if batch_size is not None:
|
||||
config = ChromaDBConfig(
|
||||
embedding_function=cast(
|
||||
ChromaEmbeddingFunctionWrapper, embedding_function
|
||||
),
|
||||
batch_size=cast(int, batch_size),
|
||||
)
|
||||
else:
|
||||
config = ChromaDBConfig(
|
||||
embedding_function=cast(
|
||||
ChromaEmbeddingFunctionWrapper, embedding_function
|
||||
)
|
||||
embedding_function = get_embedding_function(self.embedder_config)
|
||||
config = ChromaDBConfig(
|
||||
embedding_function=cast(
|
||||
ChromaEmbeddingFunctionWrapper, embedding_function
|
||||
)
|
||||
)
|
||||
self._client = create_client(config)
|
||||
|
||||
def _get_client(self) -> BaseClient:
|
||||
@@ -131,26 +95,7 @@ class RAGStorage(BaseRAGStorage):
|
||||
if metadata:
|
||||
document["metadata"] = metadata
|
||||
|
||||
batch_size = None
|
||||
if (
|
||||
self.embedder_config
|
||||
and isinstance(self.embedder_config, dict)
|
||||
and "config" in self.embedder_config
|
||||
):
|
||||
nested_config = self.embedder_config["config"]
|
||||
if isinstance(nested_config, dict):
|
||||
batch_size = nested_config.get("batch_size")
|
||||
|
||||
if batch_size is not None:
|
||||
client.add_documents(
|
||||
collection_name=collection_name,
|
||||
documents=[document],
|
||||
batch_size=cast(int, batch_size),
|
||||
)
|
||||
else:
|
||||
client.add_documents(
|
||||
collection_name=collection_name, documents=[document]
|
||||
)
|
||||
client.add_documents(collection_name=collection_name, documents=[document])
|
||||
except Exception as e:
|
||||
logging.error(
|
||||
f"Error during {self.type} save: {e!s}\n{traceback.format_exc()}"
|
||||
|
||||
@@ -17,7 +17,6 @@ from crewai.rag.chromadb.types import (
|
||||
ChromaDBCollectionSearchParams,
|
||||
)
|
||||
from crewai.rag.chromadb.utils import (
|
||||
_create_batch_slice,
|
||||
_extract_search_params,
|
||||
_is_async_client,
|
||||
_is_sync_client,
|
||||
@@ -53,7 +52,6 @@ class ChromaDBClient(BaseClient):
|
||||
embedding_function: ChromaEmbeddingFunction,
|
||||
default_limit: int = 5,
|
||||
default_score_threshold: float = 0.6,
|
||||
default_batch_size: int = 100,
|
||||
) -> None:
|
||||
"""Initialize ChromaDBClient with client and embedding function.
|
||||
|
||||
@@ -62,13 +60,11 @@ class ChromaDBClient(BaseClient):
|
||||
embedding_function: Embedding function for text to vector conversion.
|
||||
default_limit: Default number of results to return in searches.
|
||||
default_score_threshold: Default minimum score for search results.
|
||||
default_batch_size: Default batch size for adding documents.
|
||||
"""
|
||||
self.client = client
|
||||
self.embedding_function = embedding_function
|
||||
self.default_limit = default_limit
|
||||
self.default_score_threshold = default_score_threshold
|
||||
self.default_batch_size = default_batch_size
|
||||
|
||||
def create_collection(
|
||||
self, **kwargs: Unpack[ChromaDBCollectionCreateParams]
|
||||
@@ -295,7 +291,6 @@ class ChromaDBClient(BaseClient):
|
||||
- content: The text content (required)
|
||||
- doc_id: Optional unique identifier (auto-generated if missing)
|
||||
- metadata: Optional metadata dictionary
|
||||
batch_size: Optional batch size for processing documents (default: 100)
|
||||
|
||||
Raises:
|
||||
TypeError: If AsyncClientAPI is used instead of ClientAPI for sync operations.
|
||||
@@ -310,7 +305,6 @@ class ChromaDBClient(BaseClient):
|
||||
|
||||
collection_name = kwargs["collection_name"]
|
||||
documents = kwargs["documents"]
|
||||
batch_size = kwargs.get("batch_size", self.default_batch_size)
|
||||
|
||||
if not documents:
|
||||
raise ValueError("Documents list cannot be empty")
|
||||
@@ -321,17 +315,13 @@ class ChromaDBClient(BaseClient):
|
||||
)
|
||||
|
||||
prepared = _prepare_documents_for_chromadb(documents)
|
||||
|
||||
for i in range(0, len(prepared.ids), batch_size):
|
||||
batch_ids, batch_texts, batch_metadatas = _create_batch_slice(
|
||||
prepared=prepared, start_index=i, batch_size=batch_size
|
||||
)
|
||||
|
||||
collection.upsert(
|
||||
ids=batch_ids,
|
||||
documents=batch_texts,
|
||||
metadatas=batch_metadatas,
|
||||
)
|
||||
# ChromaDB doesn't accept empty metadata dicts, so pass None if all are empty
|
||||
metadatas = prepared.metadatas if any(m for m in prepared.metadatas) else None
|
||||
collection.upsert(
|
||||
ids=prepared.ids,
|
||||
documents=prepared.texts,
|
||||
metadatas=metadatas,
|
||||
)
|
||||
|
||||
async def aadd_documents(self, **kwargs: Unpack[BaseCollectionAddParams]) -> None:
|
||||
"""Add documents with their embeddings to a collection asynchronously.
|
||||
@@ -345,7 +335,6 @@ class ChromaDBClient(BaseClient):
|
||||
- content: The text content (required)
|
||||
- doc_id: Optional unique identifier (auto-generated if missing)
|
||||
- metadata: Optional metadata dictionary
|
||||
batch_size: Optional batch size for processing documents (default: 100)
|
||||
|
||||
Raises:
|
||||
TypeError: If ClientAPI is used instead of AsyncClientAPI for async operations.
|
||||
@@ -360,7 +349,6 @@ class ChromaDBClient(BaseClient):
|
||||
|
||||
collection_name = kwargs["collection_name"]
|
||||
documents = kwargs["documents"]
|
||||
batch_size = kwargs.get("batch_size", self.default_batch_size)
|
||||
|
||||
if not documents:
|
||||
raise ValueError("Documents list cannot be empty")
|
||||
@@ -370,17 +358,13 @@ class ChromaDBClient(BaseClient):
|
||||
embedding_function=self.embedding_function,
|
||||
)
|
||||
prepared = _prepare_documents_for_chromadb(documents)
|
||||
|
||||
for i in range(0, len(prepared.ids), batch_size):
|
||||
batch_ids, batch_texts, batch_metadatas = _create_batch_slice(
|
||||
prepared=prepared, start_index=i, batch_size=batch_size
|
||||
)
|
||||
|
||||
await collection.upsert(
|
||||
ids=batch_ids,
|
||||
documents=batch_texts,
|
||||
metadatas=batch_metadatas,
|
||||
)
|
||||
# ChromaDB doesn't accept empty metadata dicts, so pass None if all are empty
|
||||
metadatas = prepared.metadatas if any(m for m in prepared.metadatas) else None
|
||||
await collection.upsert(
|
||||
ids=prepared.ids,
|
||||
documents=prepared.texts,
|
||||
metadatas=metadatas,
|
||||
)
|
||||
|
||||
def search(
|
||||
self, **kwargs: Unpack[ChromaDBCollectionSearchParams]
|
||||
|
||||
@@ -41,5 +41,4 @@ def create_client(config: ChromaDBConfig) -> ChromaDBClient:
|
||||
embedding_function=config.embedding_function,
|
||||
default_limit=config.limit,
|
||||
default_score_threshold=config.score_threshold,
|
||||
default_batch_size=config.batch_size,
|
||||
)
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
"""Utility functions for ChromaDB client implementation."""
|
||||
|
||||
import hashlib
|
||||
import json
|
||||
from collections.abc import Mapping
|
||||
from typing import Literal, TypeGuard, cast
|
||||
|
||||
@@ -10,6 +9,7 @@ from chromadb.api.models.AsyncCollection import AsyncCollection
|
||||
from chromadb.api.models.Collection import Collection
|
||||
from chromadb.api.types import (
|
||||
Include,
|
||||
IncludeEnum,
|
||||
QueryResult,
|
||||
)
|
||||
|
||||
@@ -72,15 +72,7 @@ def _prepare_documents_for_chromadb(
|
||||
if "doc_id" in doc:
|
||||
ids.append(doc["doc_id"])
|
||||
else:
|
||||
content_for_hash = doc["content"]
|
||||
metadata = doc.get("metadata")
|
||||
if metadata:
|
||||
metadata_str = json.dumps(metadata, sort_keys=True)
|
||||
content_for_hash = f"{content_for_hash}|{metadata_str}"
|
||||
|
||||
content_hash = hashlib.blake2b(
|
||||
content_for_hash.encode(), digest_size=32
|
||||
).hexdigest()
|
||||
content_hash = hashlib.sha256(doc["content"].encode()).hexdigest()[:16]
|
||||
ids.append(content_hash)
|
||||
|
||||
texts.append(doc["content"])
|
||||
@@ -96,32 +88,6 @@ def _prepare_documents_for_chromadb(
|
||||
return PreparedDocuments(ids, texts, metadatas)
|
||||
|
||||
|
||||
def _create_batch_slice(
    prepared: PreparedDocuments, start_index: int, batch_size: int
) -> tuple[list[str], list[str], list[Mapping[str, str | int | float | bool]] | None]:
    """Cut one batch out of the prepared ids, texts and metadatas.

    Args:
        prepared: PreparedDocuments holding the parallel ``ids``, ``texts``
            and ``metadatas`` sequences.
        start_index: Index of the first document in the batch.
        batch_size: Maximum number of documents in the batch.

    Returns:
        Tuple of (batch_ids, batch_texts, batch_metadatas). The metadata
        element is ``None`` when ``prepared.metadatas`` is empty or when every
        metadata entry in the batch is empty.
    """
    stop_index = min(start_index + batch_size, len(prepared.ids))
    window = slice(start_index, stop_index)

    if not prepared.metadatas:
        metadata_window = None
    else:
        metadata_window = prepared.metadatas[window]
        # A non-empty batch whose entries are all empty is reported as None.
        if metadata_window and not any(metadata_window):
            metadata_window = None

    return prepared.ids[window], prepared.texts[window], metadata_window
|
||||
|
||||
|
||||
def _extract_search_params(
|
||||
kwargs: ChromaDBCollectionSearchParams,
|
||||
) -> ExtractedSearchParams:
|
||||
@@ -141,12 +107,9 @@ def _extract_search_params(
|
||||
score_threshold=kwargs.get("score_threshold"),
|
||||
where=kwargs.get("where"),
|
||||
where_document=kwargs.get("where_document"),
|
||||
include=cast(
|
||||
Include,
|
||||
kwargs.get(
|
||||
"include",
|
||||
["metadatas", "documents", "distances"],
|
||||
),
|
||||
include=kwargs.get(
|
||||
"include",
|
||||
[IncludeEnum.metadatas, IncludeEnum.documents, IncludeEnum.distances],
|
||||
),
|
||||
)
|
||||
|
||||
@@ -195,7 +158,7 @@ def _convert_chromadb_results_to_search_results(
|
||||
"""
|
||||
search_results: list[SearchResult] = []
|
||||
|
||||
include_strings = list(include) if include else []
|
||||
include_strings = [item.value for item in include] if include else []
|
||||
|
||||
ids = results["ids"][0] if results.get("ids") else []
|
||||
|
||||
|
||||
@@ -16,4 +16,3 @@ class BaseRagConfig:
|
||||
embedding_function: Any | None = field(default=None)
|
||||
limit: int = field(default=5)
|
||||
score_threshold: float = field(default=0.6)
|
||||
batch_size: int = field(default=100)
|
||||
|
||||
@@ -29,7 +29,7 @@ class BaseCollectionParams(TypedDict):
|
||||
]
|
||||
|
||||
|
||||
class BaseCollectionAddParams(BaseCollectionParams, total=False):
|
||||
class BaseCollectionAddParams(BaseCollectionParams):
|
||||
"""Parameters for adding documents to a collection.
|
||||
|
||||
Extends BaseCollectionParams with document-specific fields.
|
||||
@@ -37,11 +37,9 @@ class BaseCollectionAddParams(BaseCollectionParams, total=False):
|
||||
Attributes:
|
||||
collection_name: The name of the collection to add documents to.
|
||||
documents: List of BaseRecord dictionaries containing document data.
|
||||
batch_size: Optional batch size for processing documents to avoid token limits.
|
||||
"""
|
||||
|
||||
documents: Required[list[BaseRecord]]
|
||||
batch_size: int
|
||||
documents: list[BaseRecord]
|
||||
|
||||
|
||||
class BaseCollectionSearchParams(BaseCollectionParams, total=False):
|
||||
|
||||
@@ -1,142 +0,0 @@
|
||||
"""Base embeddings callable utilities for RAG systems."""
|
||||
|
||||
from typing import Protocol, TypeVar, runtime_checkable
|
||||
|
||||
import numpy as np
|
||||
|
||||
from crewai.rag.core.types import (
|
||||
Embeddable,
|
||||
Embedding,
|
||||
Embeddings,
|
||||
PyEmbedding,
|
||||
)
|
||||
|
||||
T = TypeVar("T")
|
||||
D = TypeVar("D", bound=Embeddable, contravariant=True)
|
||||
|
||||
|
||||
def normalize_embeddings(
    target: Embedding | list[Embedding] | PyEmbedding | list[PyEmbedding],
) -> Embeddings | None:
    """Normalize various embedding formats to a list of float32 numpy arrays.

    Args:
        target: Input embeddings: a flat sequence of numbers, a list of
            lists/tuples of numbers, a 1-D or 2-D numpy array, or a list of
            numpy arrays.

    Returns:
        Normalized embeddings as a list of float32 numpy arrays, or None if
        the input is None.

    Raises:
        ValueError: If the input is an empty sequence, a numpy array that is
            neither 1-D nor 2-D, or a sequence whose first element has an
            unsupported type.
    """
    # Honour the documented contract: None in, None out.
    if target is None:
        return None

    if isinstance(target, np.ndarray):
        if target.ndim == 1:
            return [target.astype(np.float32)]
        if target.ndim == 2:
            return [row.astype(np.float32) for row in target]
        raise ValueError(f"Unsupported numpy array shape: {target.shape}")

    # Reject empty input explicitly instead of failing with IndexError below.
    if len(target) == 0:
        raise ValueError("Expected at least one embedding, got an empty sequence")

    # The type of the first element decides how the whole input is read.
    first = target[0]
    # bool is a subclass of int, so it has to be excluded by hand.
    if isinstance(first, (int, float)) and not isinstance(first, bool):
        return [np.array(target, dtype=np.float32)]
    if isinstance(first, (list, tuple)):
        return [np.array(emb, dtype=np.float32) for emb in target]
    if isinstance(first, np.ndarray):
        return [emb.astype(np.float32) for emb in target]  # type: ignore[union-attr]

    raise ValueError(f"Unsupported embeddings format: {type(first)}")
|
||||
|
||||
|
||||
def maybe_cast_one_to_many(target: T | list[T] | None) -> list[T] | None:
    """Wrap a lone item in a list; pass lists and ``None`` through untouched.

    Args:
        target: A single item, a list of items, or None.

    Returns:
        ``target`` itself when it is already a list or is None, otherwise a
        one-element list containing ``target``.
    """
    if target is None or isinstance(target, list):
        return target
    return [target]
|
||||
|
||||
|
||||
def validate_embeddings(embeddings: Embeddings) -> Embeddings:
    """Check that ``embeddings`` is a non-empty list of numeric numpy arrays.

    Args:
        embeddings: List of numpy arrays to validate.

    Returns:
        The same ``embeddings`` object, unchanged, once every check passed.

    Raises:
        ValueError: If the container is not a list, is empty, holds anything
            other than numpy arrays, or holds an array that is 0-dimensional,
            empty, or contains non-numeric values.
    """
    if not isinstance(embeddings, list):
        raise ValueError(
            f"Expected embeddings to be a list, got {type(embeddings).__name__}"
        )

    count = len(embeddings)
    if count == 0:
        raise ValueError(
            f"Expected embeddings to be a list with at least one item, got {count} embeddings"
        )

    if any(not isinstance(candidate, np.ndarray) for candidate in embeddings):
        raise ValueError(
            "Expected each embedding in the embeddings to be a numpy array"
        )

    numeric_types = (np.integer, float, np.floating)
    for position, vector in enumerate(embeddings):
        # Dimensionality is checked first: a 0-d array cannot be iterated.
        if vector.ndim == 0:
            raise ValueError(
                f"Expected a 1-dimensional array, got a 0-dimensional array {vector}"
            )
        if vector.size == 0:
            raise ValueError(
                f"Expected each embedding to be a 1-dimensional numpy array with at least 1 value. "
                f"Got an array with no values at position {position}"
            )
        for component in vector:
            if isinstance(component, bool) or not isinstance(
                component, numeric_types
            ):
                raise ValueError(
                    f"Expected embedding to contain numeric values, got non-numeric values at position {position}"
                )

    return embeddings
|
||||
|
||||
|
||||
@runtime_checkable
class EmbeddingFunction(Protocol[D]):
    """Protocol for embedding functions.

    Embedding functions convert input data (documents or images) into vector
    embeddings. Every subclass has its ``__call__`` wrapped so that the raw
    result is passed through ``normalize_embeddings`` and
    ``validate_embeddings`` before it reaches the caller.
    """

    def __call__(self, input: D) -> Embeddings:
        """Convert input data to embeddings.

        Args:
            input: Input data to embed (documents or images).

        Returns:
            List of numpy arrays representing the embeddings.
        """
        ...

    def __init_subclass__(cls) -> None:
        """Wrap ``__call__`` so its result is normalized and validated.

        A ``__call__`` that was already wrapped by a parent class is left
        alone; wrapping it again would repeat normalization and validation
        once per inheritance level.
        """
        super().__init_subclass__()
        original_call = cls.__call__

        # Inherited and already wrapped: nothing to do for this subclass.
        if getattr(original_call, "_validates_embeddings", False):
            return

        def wrapped_call(self: EmbeddingFunction[D], input: D) -> Embeddings:
            result = original_call(self, input)
            if result is None:
                raise ValueError("Embedding function returned None")
            normalized = normalize_embeddings(result)
            if normalized is None:
                raise ValueError("Normalization returned None for non-None input")
            return validate_embeddings(normalized)

        # Marker read by the guard above when a further subclass is created.
        wrapped_call._validates_embeddings = True  # type: ignore[attr-defined]
        cls.__call__ = wrapped_call  # type: ignore[method-assign]
|
||||
@@ -1,23 +0,0 @@
|
||||
"""Base class for embedding providers."""
|
||||
|
||||
from typing import Generic, TypeVar
|
||||
|
||||
from pydantic import Field
|
||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
|
||||
from crewai.rag.core.base_embeddings_callable import EmbeddingFunction
|
||||
|
||||
T = TypeVar("T", bound=EmbeddingFunction)
|
||||
|
||||
|
||||
class BaseEmbeddingsProvider(BaseSettings, Generic[T]):
    """Common settings base for embedding providers.

    Concrete providers subclass this and name, through
    ``embedding_callable``, the embedding function class they produce.
    """

    # Unknown fields are kept, and fields may be populated by name.
    model_config = SettingsConfigDict(populate_by_name=True, extra="allow")

    # Required field: there is no default value.
    embedding_callable: type[T] = Field(
        default=...,
        description="The embedding function class to use",
    )
|
||||
@@ -1,28 +0,0 @@
|
||||
"""Core type definitions for RAG systems."""
|
||||
|
||||
from collections.abc import Sequence
|
||||
from typing import TypeVar
|
||||
|
||||
import numpy as np
|
||||
from numpy import floating, integer, number
|
||||
from numpy.typing import NDArray
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
PyEmbedding = Sequence[float] | Sequence[int]
|
||||
PyEmbeddings = list[PyEmbedding]
|
||||
Embedding = NDArray[np.int32 | np.float32]
|
||||
Embeddings = list[Embedding]
|
||||
|
||||
Documents = list[str]
|
||||
Images = list[np.ndarray]
|
||||
Embeddable = Documents | Images
|
||||
|
||||
ScalarType = TypeVar("ScalarType", bound=np.generic)
|
||||
IntegerType = TypeVar("IntegerType", bound=integer)
|
||||
FloatingType = TypeVar("FloatingType", bound=floating)
|
||||
NumberType = TypeVar("NumberType", bound=number)
|
||||
|
||||
DType32 = TypeVar("DType32", np.int32, np.float32)
|
||||
DType64 = TypeVar("DType64", np.int64, np.float64)
|
||||
DTypeCommon = TypeVar("DTypeCommon", np.int32, np.int64, np.float32, np.float64)
|
||||
245
src/crewai/rag/embeddings/configurator.py
Normal file
245
src/crewai/rag/embeddings/configurator.py
Normal file
@@ -0,0 +1,245 @@
|
||||
import os
|
||||
from typing import Any, cast
|
||||
|
||||
from chromadb import Documents, EmbeddingFunction, Embeddings
|
||||
from chromadb.api.types import validate_embedding_function
|
||||
|
||||
|
||||
class EmbeddingConfigurator:
    """Build ChromaDB embedding functions from a provider/config dictionary.

    Each supported provider name maps to a private ``_configure_*`` builder.
    Provider-specific imports happen inside the builders, so only the
    provider that is actually requested has to be importable.
    """

    def __init__(self):
        """Register the builder for every supported provider name."""
        self.embedding_functions = {
            "openai": self._configure_openai,
            "azure": self._configure_azure,
            "ollama": self._configure_ollama,
            "vertexai": self._configure_vertexai,
            "google": self._configure_google,
            "cohere": self._configure_cohere,
            "voyageai": self._configure_voyageai,
            "bedrock": self._configure_bedrock,
            "huggingface": self._configure_huggingface,
            "watson": self._configure_watson,
            "custom": self._configure_custom,
        }

    def configure_embedder(
        self,
        embedder_config: dict[str, Any] | None = None,
    ) -> EmbeddingFunction:
        """Configure and return an embedding function for the given config.

        Args:
            embedder_config: Mapping with a ``"provider"`` name and an
                optional ``"config"`` mapping of provider settings. When
                None, the default OpenAI embedding function is returned.

        Returns:
            The embedding function produced by the provider's builder.

        Raises:
            Exception: If the provider is not one of the registered names.
            ImportError: If a module needed by the provider cannot be
                imported.
        """
        if embedder_config is None:
            return self._create_default_embedding_function()

        provider = embedder_config.get("provider")
        # An explicit ``"config": None`` is treated like a missing config.
        config = embedder_config.get("config") or {}

        if provider not in self.embedding_functions:
            raise Exception(
                f"Unsupported embedding provider: {provider}, supported providers: {list(self.embedding_functions.keys())}"
            )

        builder = self.embedding_functions[provider]

        # The provider imports run when the builder is called, so the call
        # itself has to sit inside the try block.
        try:
            if provider == "custom":
                return builder(config)
            return builder(config, config.get("model"))
        except ImportError as e:
            # ``name`` is only set by the import machinery. An ImportError
            # raised by hand (e.g. by the Watson builder) already carries a
            # meaningful message and is passed on unchanged.
            if e.name is None:
                raise
            missing_package = e.name
            raise ImportError(
                f"{missing_package} is not installed. Please install it with: pip install {missing_package}"
            ) from e

    @staticmethod
    def _create_default_embedding_function():
        """Return an OpenAI embedder using ``OPENAI_API_KEY`` and ``text-embedding-3-small``."""
        from chromadb.utils.embedding_functions.openai_embedding_function import (
            OpenAIEmbeddingFunction,
        )

        return OpenAIEmbeddingFunction(
            api_key=os.getenv("OPENAI_API_KEY"), model_name="text-embedding-3-small"
        )

    @staticmethod
    def _configure_openai(config, model_name):
        """Build an OpenAI embedder; the API key falls back to ``OPENAI_API_KEY``."""
        from chromadb.utils.embedding_functions.openai_embedding_function import (
            OpenAIEmbeddingFunction,
        )

        return OpenAIEmbeddingFunction(
            api_key=config.get("api_key") or os.getenv("OPENAI_API_KEY"),
            model_name=model_name,
            api_base=config.get("api_base", None),
            api_type=config.get("api_type", None),
            api_version=config.get("api_version", None),
            default_headers=config.get("default_headers", None),
            dimensions=config.get("dimensions", None),
            deployment_id=config.get("deployment_id", None),
            organization_id=config.get("organization_id", None),
        )

    @staticmethod
    def _configure_azure(config, model_name):
        """Build an OpenAI embedder whose ``api_type`` defaults to ``"azure"``."""
        from chromadb.utils.embedding_functions.openai_embedding_function import (
            OpenAIEmbeddingFunction,
        )

        return OpenAIEmbeddingFunction(
            api_key=config.get("api_key"),
            api_base=config.get("api_base"),
            api_type=config.get("api_type", "azure"),
            api_version=config.get("api_version"),
            model_name=model_name,
            default_headers=config.get("default_headers"),
            dimensions=config.get("dimensions"),
            deployment_id=config.get("deployment_id"),
            organization_id=config.get("organization_id"),
        )

    @staticmethod
    def _configure_ollama(config, model_name):
        """Build an Ollama embedder; the URL defaults to the local embeddings endpoint."""
        from chromadb.utils.embedding_functions.ollama_embedding_function import (
            OllamaEmbeddingFunction,
        )

        return OllamaEmbeddingFunction(
            url=config.get("url", "http://localhost:11434/api/embeddings"),
            model_name=model_name,
        )

    @staticmethod
    def _configure_vertexai(config, model_name):
        """Build a Google Vertex embedder from api_key, project_id and region."""
        from chromadb.utils.embedding_functions.google_embedding_function import (
            GoogleVertexEmbeddingFunction,
        )

        return GoogleVertexEmbeddingFunction(
            model_name=model_name,
            api_key=config.get("api_key"),
            project_id=config.get("project_id"),
            region=config.get("region"),
        )

    @staticmethod
    def _configure_google(config, model_name):
        """Build a Google Generative AI embedder from api_key and task_type."""
        from chromadb.utils.embedding_functions.google_embedding_function import (
            GoogleGenerativeAiEmbeddingFunction,
        )

        return GoogleGenerativeAiEmbeddingFunction(
            model_name=model_name,
            api_key=config.get("api_key"),
            task_type=config.get("task_type"),
        )

    @staticmethod
    def _configure_cohere(config, model_name):
        """Build a Cohere embedder from api_key."""
        from chromadb.utils.embedding_functions.cohere_embedding_function import (
            CohereEmbeddingFunction,
        )

        return CohereEmbeddingFunction(
            model_name=model_name,
            api_key=config.get("api_key"),
        )

    @staticmethod
    def _configure_voyageai(config, model_name):
        """Build a VoyageAI embedder from api_key."""
        from chromadb.utils.embedding_functions.voyageai_embedding_function import (  # type: ignore[import-not-found]
            VoyageAIEmbeddingFunction,
        )

        return VoyageAIEmbeddingFunction(
            model_name=model_name,
            api_key=config.get("api_key"),
        )

    @staticmethod
    def _configure_bedrock(config, model_name):
        """Build an Amazon Bedrock embedder from the configured session."""
        from chromadb.utils.embedding_functions.amazon_bedrock_embedding_function import (
            AmazonBedrockEmbeddingFunction,
        )

        # Allow custom model_name override with backwards compatibility:
        # model_name is only passed on when one was configured.
        kwargs = {"session": config.get("session")}
        if model_name is not None:
            kwargs["model_name"] = model_name
        return AmazonBedrockEmbeddingFunction(**kwargs)

    @staticmethod
    def _configure_huggingface(config, model_name):
        """Build a HuggingFace embedding-server client from ``api_url``.

        ``model_name`` is accepted for a uniform builder signature but is
        not used.
        """
        from chromadb.utils.embedding_functions.huggingface_embedding_function import (
            HuggingFaceEmbeddingServer,
        )

        return HuggingFaceEmbeddingServer(
            url=config.get("api_url"),
        )

    @staticmethod
    def _configure_watson(config, model_name):
        """Build an IBM watsonx embedder from model, api_key, api_url and project_id.

        The model id is read from ``config["model"]`` inside the embedder;
        ``model_name`` is accepted for a uniform builder signature only.

        Raises:
            ImportError: If the ``ibm_watsonx_ai`` package is not importable.
        """
        try:
            import ibm_watsonx_ai.foundation_models as watson_models  # type: ignore[import-not-found]
            from ibm_watsonx_ai import Credentials  # type: ignore[import-not-found]
            from ibm_watsonx_ai.metanames import (  # type: ignore[import-not-found]
                EmbedTextParamsMetaNames as EmbedParams,
            )
        except ImportError as e:
            raise ImportError(
                "IBM Watson dependencies are not installed. Please install them to use Watson embedding."
            ) from e

        class WatsonEmbeddingFunction(EmbeddingFunction):
            """Embedding function that builds a watsonx client on each call."""

            def __call__(self, input: Documents) -> Embeddings:
                # A single string is treated as a one-document batch.
                if isinstance(input, str):
                    input = [input]

                # NOTE(review): TRUNCATE_INPUT_TOKENS is set to 3, which looks
                # very small for document embedding; confirm this is intended.
                embed_params = {
                    EmbedParams.TRUNCATE_INPUT_TOKENS: 3,
                    EmbedParams.RETURN_OPTIONS: {"input_text": True},
                }

                embedding = watson_models.Embeddings(
                    model_id=config.get("model"),
                    params=embed_params,
                    credentials=Credentials(
                        api_key=config.get("api_key"), url=config.get("api_url")
                    ),
                    project_id=config.get("project_id"),
                )

                try:
                    embeddings = embedding.embed_documents(input)
                    return cast(Embeddings, embeddings)
                except Exception as e:
                    print("Error during Watson embedding:", e)
                    raise

        return WatsonEmbeddingFunction()

    @staticmethod
    def _configure_custom(config):
        """Return the user-supplied embedder from ``config["embedder"]``.

        The value may be an ``EmbeddingFunction`` instance, or a callable
        that takes no arguments and returns one.

        Raises:
            ValueError: If the embedder fails validation, cannot be
                instantiated, or is neither of the accepted kinds.
        """
        custom_embedder = config.get("embedder")
        if isinstance(custom_embedder, EmbeddingFunction):
            try:
                validate_embedding_function(custom_embedder)
                return custom_embedder
            except Exception as e:
                raise ValueError(f"Invalid custom embedding function: {e!s}") from e
        elif callable(custom_embedder):
            try:
                instance = custom_embedder()
                if isinstance(instance, EmbeddingFunction):
                    validate_embedding_function(instance)
                    return instance
                raise ValueError(
                    "Custom embedder does not create an EmbeddingFunction instance"
                )
            except Exception as e:
                raise ValueError(f"Error instantiating custom embedder: {e!s}") from e
        else:
            raise ValueError(
                "Custom embedder must be an instance of `EmbeddingFunction` or a callable that creates one"
            )
|
||||
@@ -1,363 +1,204 @@
|
||||
"""Factory functions for creating embedding providers and functions."""
|
||||
"""Minimal embedding function factory for CrewAI."""
|
||||
|
||||
from __future__ import annotations
|
||||
import os
|
||||
from typing import Any, Callable
|
||||
|
||||
from typing import TYPE_CHECKING, TypeVar, overload
|
||||
from chromadb import EmbeddingFunction
|
||||
from chromadb.utils.embedding_functions.amazon_bedrock_embedding_function import (
|
||||
AmazonBedrockEmbeddingFunction,
|
||||
)
|
||||
from chromadb.utils.embedding_functions.cohere_embedding_function import (
|
||||
CohereEmbeddingFunction,
|
||||
)
|
||||
from chromadb.utils.embedding_functions.google_embedding_function import (
|
||||
GoogleGenerativeAiEmbeddingFunction,
|
||||
GooglePalmEmbeddingFunction,
|
||||
GoogleVertexEmbeddingFunction,
|
||||
)
|
||||
from chromadb.utils.embedding_functions.huggingface_embedding_function import (
|
||||
HuggingFaceEmbeddingFunction,
|
||||
)
|
||||
from chromadb.utils.embedding_functions.instructor_embedding_function import (
|
||||
InstructorEmbeddingFunction,
|
||||
)
|
||||
from chromadb.utils.embedding_functions.jina_embedding_function import (
|
||||
JinaEmbeddingFunction,
|
||||
)
|
||||
from chromadb.utils.embedding_functions.ollama_embedding_function import (
|
||||
OllamaEmbeddingFunction,
|
||||
)
|
||||
from chromadb.utils.embedding_functions.onnx_mini_lm_l6_v2 import ONNXMiniLM_L6_V2
|
||||
from chromadb.utils.embedding_functions.open_clip_embedding_function import (
|
||||
OpenCLIPEmbeddingFunction,
|
||||
)
|
||||
from chromadb.utils.embedding_functions.openai_embedding_function import (
|
||||
OpenAIEmbeddingFunction,
|
||||
)
|
||||
from chromadb.utils.embedding_functions.roboflow_embedding_function import (
|
||||
RoboflowEmbeddingFunction,
|
||||
)
|
||||
from chromadb.utils.embedding_functions.sentence_transformer_embedding_function import (
|
||||
SentenceTransformerEmbeddingFunction,
|
||||
)
|
||||
from chromadb.utils.embedding_functions.text2vec_embedding_function import (
|
||||
Text2VecEmbeddingFunction,
|
||||
)
|
||||
|
||||
from crewai.rag.core.base_embeddings_callable import EmbeddingFunction
|
||||
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
|
||||
from crewai.utilities.import_utils import import_and_validate_definition
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from chromadb.utils.embedding_functions.amazon_bedrock_embedding_function import (
|
||||
AmazonBedrockEmbeddingFunction,
|
||||
)
|
||||
from chromadb.utils.embedding_functions.cohere_embedding_function import (
|
||||
CohereEmbeddingFunction,
|
||||
)
|
||||
from chromadb.utils.embedding_functions.google_embedding_function import (
|
||||
GoogleGenerativeAiEmbeddingFunction,
|
||||
GoogleVertexEmbeddingFunction,
|
||||
)
|
||||
from chromadb.utils.embedding_functions.huggingface_embedding_function import (
|
||||
HuggingFaceEmbeddingFunction,
|
||||
)
|
||||
from chromadb.utils.embedding_functions.instructor_embedding_function import (
|
||||
InstructorEmbeddingFunction,
|
||||
)
|
||||
from chromadb.utils.embedding_functions.jina_embedding_function import (
|
||||
JinaEmbeddingFunction,
|
||||
)
|
||||
from chromadb.utils.embedding_functions.ollama_embedding_function import (
|
||||
OllamaEmbeddingFunction,
|
||||
)
|
||||
from chromadb.utils.embedding_functions.onnx_mini_lm_l6_v2 import ONNXMiniLM_L6_V2
|
||||
from chromadb.utils.embedding_functions.open_clip_embedding_function import (
|
||||
OpenCLIPEmbeddingFunction,
|
||||
)
|
||||
from chromadb.utils.embedding_functions.openai_embedding_function import (
|
||||
OpenAIEmbeddingFunction,
|
||||
)
|
||||
from chromadb.utils.embedding_functions.roboflow_embedding_function import (
|
||||
RoboflowEmbeddingFunction,
|
||||
)
|
||||
from chromadb.utils.embedding_functions.sentence_transformer_embedding_function import (
|
||||
SentenceTransformerEmbeddingFunction,
|
||||
)
|
||||
from chromadb.utils.embedding_functions.text2vec_embedding_function import (
|
||||
Text2VecEmbeddingFunction,
|
||||
)
|
||||
|
||||
from crewai.rag.embeddings.providers.aws.types import BedrockProviderSpec
|
||||
from crewai.rag.embeddings.providers.cohere.types import CohereProviderSpec
|
||||
from crewai.rag.embeddings.providers.custom.types import CustomProviderSpec
|
||||
from crewai.rag.embeddings.providers.google.types import (
|
||||
GenerativeAiProviderSpec,
|
||||
VertexAIProviderSpec,
|
||||
)
|
||||
from crewai.rag.embeddings.providers.huggingface.types import (
|
||||
HuggingFaceProviderSpec,
|
||||
)
|
||||
from crewai.rag.embeddings.providers.ibm.embedding_callable import (
|
||||
WatsonEmbeddingFunction,
|
||||
)
|
||||
from crewai.rag.embeddings.providers.ibm.types import WatsonProviderSpec
|
||||
from crewai.rag.embeddings.providers.instructor.types import InstructorProviderSpec
|
||||
from crewai.rag.embeddings.providers.jina.types import JinaProviderSpec
|
||||
from crewai.rag.embeddings.providers.microsoft.types import AzureProviderSpec
|
||||
from crewai.rag.embeddings.providers.ollama.types import OllamaProviderSpec
|
||||
from crewai.rag.embeddings.providers.onnx.types import ONNXProviderSpec
|
||||
from crewai.rag.embeddings.providers.openai.types import OpenAIProviderSpec
|
||||
from crewai.rag.embeddings.providers.openclip.types import OpenCLIPProviderSpec
|
||||
from crewai.rag.embeddings.providers.roboflow.types import RoboflowProviderSpec
|
||||
from crewai.rag.embeddings.providers.sentence_transformer.types import (
|
||||
SentenceTransformerProviderSpec,
|
||||
)
|
||||
from crewai.rag.embeddings.providers.text2vec.types import Text2VecProviderSpec
|
||||
from crewai.rag.embeddings.providers.voyageai.embedding_callable import (
|
||||
VoyageAIEmbeddingFunction,
|
||||
)
|
||||
from crewai.rag.embeddings.providers.voyageai.types import VoyageAIProviderSpec
|
||||
|
||||
T = TypeVar("T", bound=EmbeddingFunction)
|
||||
from crewai.rag.embeddings.types import EmbeddingOptions
|
||||
|
||||
|
||||
# Maps the "provider" key of an embedder specification to the dotted import
# path of the provider class that handles it. build_embedder_from_dict looks
# the path up here and imports the class on demand.
PROVIDER_PATHS: dict[str, str] = {
    "azure": "crewai.rag.embeddings.providers.microsoft.azure.AzureProvider",
    "amazon-bedrock": "crewai.rag.embeddings.providers.aws.bedrock.BedrockProvider",
    "cohere": "crewai.rag.embeddings.providers.cohere.cohere_provider.CohereProvider",
    "custom": "crewai.rag.embeddings.providers.custom.custom_provider.CustomProvider",
    "google-generativeai": "crewai.rag.embeddings.providers.google.generative_ai.GenerativeAiProvider",
    "google-vertex": "crewai.rag.embeddings.providers.google.vertex.VertexAIProvider",
    "huggingface": "crewai.rag.embeddings.providers.huggingface.huggingface_provider.HuggingFaceProvider",
    "instructor": "crewai.rag.embeddings.providers.instructor.instructor_provider.InstructorProvider",
    "jina": "crewai.rag.embeddings.providers.jina.jina_provider.JinaProvider",
    "ollama": "crewai.rag.embeddings.providers.ollama.ollama_provider.OllamaProvider",
    "onnx": "crewai.rag.embeddings.providers.onnx.onnx_provider.ONNXProvider",
    "openai": "crewai.rag.embeddings.providers.openai.openai_provider.OpenAIProvider",
    "openclip": "crewai.rag.embeddings.providers.openclip.openclip_provider.OpenCLIPProvider",
    "roboflow": "crewai.rag.embeddings.providers.roboflow.roboflow_provider.RoboflowProvider",
    "sentence-transformer": "crewai.rag.embeddings.providers.sentence_transformer.sentence_transformer_provider.SentenceTransformerProvider",
    "text2vec": "crewai.rag.embeddings.providers.text2vec.text2vec_provider.Text2VecProvider",
    "voyageai": "crewai.rag.embeddings.providers.voyageai.voyageai_provider.VoyageAIProvider",
    "watson": "crewai.rag.embeddings.providers.ibm.watson.WatsonProvider",
}
|
||||
|
||||
|
||||
def build_embedder_from_provider(provider: BaseEmbeddingsProvider[T]) -> T:
    """Build an embedding function instance from a provider.

    The provider's fields are dumped with ``model_dump`` and forwarded as
    keyword arguments to ``provider.embedding_callable``.

    Args:
        provider: The embedding provider configuration.

    Returns:
        An instance of the specified embedding function type.
    """
    factory = provider.embedding_callable
    # The callable is the factory itself, so it is excluded from its own kwargs.
    init_kwargs = provider.model_dump(exclude={"embedding_callable"})
    return factory(**init_kwargs)
|
||||
|
||||
|
||||
@overload
|
||||
def build_embedder_from_dict(spec: AzureProviderSpec) -> OpenAIEmbeddingFunction: ...
|
||||
|
||||
|
||||
@overload
|
||||
def build_embedder_from_dict(
|
||||
spec: BedrockProviderSpec,
|
||||
) -> AmazonBedrockEmbeddingFunction: ...
|
||||
|
||||
|
||||
@overload
|
||||
def build_embedder_from_dict(spec: CohereProviderSpec) -> CohereEmbeddingFunction: ...
|
||||
|
||||
|
||||
@overload
|
||||
def build_embedder_from_dict(spec: CustomProviderSpec) -> EmbeddingFunction: ...
|
||||
|
||||
|
||||
@overload
|
||||
def build_embedder_from_dict(
|
||||
spec: GenerativeAiProviderSpec,
|
||||
) -> GoogleGenerativeAiEmbeddingFunction: ...
|
||||
|
||||
|
||||
@overload
|
||||
def build_embedder_from_dict(
|
||||
spec: HuggingFaceProviderSpec,
|
||||
) -> HuggingFaceEmbeddingFunction: ...
|
||||
|
||||
|
||||
@overload
|
||||
def build_embedder_from_dict(spec: OllamaProviderSpec) -> OllamaEmbeddingFunction: ...
|
||||
|
||||
|
||||
@overload
|
||||
def build_embedder_from_dict(spec: OpenAIProviderSpec) -> OpenAIEmbeddingFunction: ...
|
||||
|
||||
|
||||
@overload
|
||||
def build_embedder_from_dict(
|
||||
spec: VertexAIProviderSpec,
|
||||
) -> GoogleVertexEmbeddingFunction: ...
|
||||
|
||||
|
||||
@overload
|
||||
def build_embedder_from_dict(
|
||||
spec: VoyageAIProviderSpec,
|
||||
) -> VoyageAIEmbeddingFunction: ...
|
||||
|
||||
|
||||
@overload
|
||||
def build_embedder_from_dict(spec: WatsonProviderSpec) -> WatsonEmbeddingFunction: ...
|
||||
|
||||
|
||||
@overload
|
||||
def build_embedder_from_dict(
|
||||
spec: SentenceTransformerProviderSpec,
|
||||
) -> SentenceTransformerEmbeddingFunction: ...
|
||||
|
||||
|
||||
@overload
|
||||
def build_embedder_from_dict(
|
||||
spec: InstructorProviderSpec,
|
||||
) -> InstructorEmbeddingFunction: ...
|
||||
|
||||
|
||||
@overload
|
||||
def build_embedder_from_dict(spec: JinaProviderSpec) -> JinaEmbeddingFunction: ...
|
||||
|
||||
|
||||
@overload
|
||||
def build_embedder_from_dict(
|
||||
spec: RoboflowProviderSpec,
|
||||
) -> RoboflowEmbeddingFunction: ...
|
||||
|
||||
|
||||
@overload
|
||||
def build_embedder_from_dict(
|
||||
spec: OpenCLIPProviderSpec,
|
||||
) -> OpenCLIPEmbeddingFunction: ...
|
||||
|
||||
|
||||
@overload
|
||||
def build_embedder_from_dict(
|
||||
spec: Text2VecProviderSpec,
|
||||
) -> Text2VecEmbeddingFunction: ...
|
||||
|
||||
|
||||
@overload
|
||||
def build_embedder_from_dict(spec: ONNXProviderSpec) -> ONNXMiniLM_L6_V2: ...
|
||||
|
||||
|
||||
def build_embedder_from_dict(spec):
|
||||
"""Build an embedding function instance from a dictionary specification.
|
||||
|
||||
Args:
|
||||
spec: A dictionary with 'provider' and 'config' keys.
|
||||
Example: {
|
||||
"provider": "openai",
|
||||
"config": {
|
||||
"api_key": "sk-...",
|
||||
"model_name": "text-embedding-3-small"
|
||||
}
|
||||
}
|
||||
|
||||
Returns:
|
||||
An instance of the appropriate embedding function.
|
||||
|
||||
Raises:
|
||||
ValueError: If the provider is not recognized.
|
||||
"""
|
||||
provider_name = spec["provider"]
|
||||
if not provider_name:
|
||||
raise ValueError("Missing 'provider' key in specification")
|
||||
|
||||
if provider_name not in PROVIDER_PATHS:
|
||||
raise ValueError(
|
||||
f"Unknown provider: {provider_name}. Available providers: {list(PROVIDER_PATHS.keys())}"
|
||||
)
|
||||
|
||||
provider_path = PROVIDER_PATHS[provider_name]
|
||||
def _create_watson_embedding_function(**config_dict) -> EmbeddingFunction:
|
||||
"""Create Watson embedding function with proper error handling."""
|
||||
try:
|
||||
provider_class = import_and_validate_definition(provider_path)
|
||||
except (ImportError, AttributeError, ValueError) as e:
|
||||
raise ImportError(f"Failed to import provider {provider_name}: {e}") from e
|
||||
import ibm_watsonx_ai.foundation_models as watson_models # type: ignore[import-not-found]
|
||||
from ibm_watsonx_ai import Credentials # type: ignore[import-not-found]
|
||||
from ibm_watsonx_ai.metanames import ( # type: ignore[import-not-found]
|
||||
EmbedTextParamsMetaNames as EmbedParams,
|
||||
)
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"IBM Watson dependencies are not installed. Please install them to use Watson embedding."
|
||||
) from e
|
||||
|
||||
provider_config = spec.get("config", {})
|
||||
class WatsonEmbeddingFunction(EmbeddingFunction):
|
||||
def __init__(self, **kwargs):
|
||||
self.config = kwargs
|
||||
|
||||
if provider_name == "custom" and "embedding_callable" not in provider_config:
|
||||
raise ValueError("Custom provider requires 'embedding_callable' in config")
|
||||
def __call__(self, input):
|
||||
if isinstance(input, str):
|
||||
input = [input]
|
||||
|
||||
provider = provider_class(**provider_config)
|
||||
return build_embedder_from_provider(provider)
|
||||
embed_params = {
|
||||
EmbedParams.TRUNCATE_INPUT_TOKENS: 3,
|
||||
EmbedParams.RETURN_OPTIONS: {"input_text": True},
|
||||
}
|
||||
|
||||
embedding = watson_models.Embeddings(
|
||||
model_id=self.config.get("model_name") or self.config.get("model"),
|
||||
params=embed_params,
|
||||
credentials=Credentials(
|
||||
api_key=self.config.get("api_key"),
|
||||
url=self.config.get("api_url") or self.config.get("url")
|
||||
),
|
||||
project_id=self.config.get("project_id"),
|
||||
)
|
||||
|
||||
try:
|
||||
return embedding.embed_documents(input)
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Error during Watson embedding: {e}") from e
|
||||
|
||||
return WatsonEmbeddingFunction(**config_dict)
|
||||
|
||||
|
||||
@overload
|
||||
def build_embedder(spec: BaseEmbeddingsProvider[T]) -> T: ...
|
||||
|
||||
|
||||
@overload
|
||||
def build_embedder(spec: AzureProviderSpec) -> OpenAIEmbeddingFunction: ...
|
||||
|
||||
|
||||
@overload
|
||||
def build_embedder(spec: BedrockProviderSpec) -> AmazonBedrockEmbeddingFunction: ...
|
||||
|
||||
|
||||
@overload
|
||||
def build_embedder(spec: CohereProviderSpec) -> CohereEmbeddingFunction: ...
|
||||
|
||||
|
||||
@overload
|
||||
def build_embedder(spec: CustomProviderSpec) -> EmbeddingFunction: ...
|
||||
|
||||
|
||||
@overload
|
||||
def build_embedder(
|
||||
spec: GenerativeAiProviderSpec,
|
||||
) -> GoogleGenerativeAiEmbeddingFunction: ...
|
||||
|
||||
|
||||
@overload
|
||||
def build_embedder(spec: HuggingFaceProviderSpec) -> HuggingFaceEmbeddingFunction: ...
|
||||
|
||||
|
||||
@overload
|
||||
def build_embedder(spec: OllamaProviderSpec) -> OllamaEmbeddingFunction: ...
|
||||
|
||||
|
||||
@overload
|
||||
def build_embedder(spec: OpenAIProviderSpec) -> OpenAIEmbeddingFunction: ...
|
||||
|
||||
|
||||
@overload
|
||||
def build_embedder(spec: VertexAIProviderSpec) -> GoogleVertexEmbeddingFunction: ...
|
||||
|
||||
|
||||
@overload
|
||||
def build_embedder(spec: VoyageAIProviderSpec) -> VoyageAIEmbeddingFunction: ...
|
||||
|
||||
|
||||
@overload
|
||||
def build_embedder(spec: WatsonProviderSpec) -> WatsonEmbeddingFunction: ...
|
||||
|
||||
|
||||
@overload
|
||||
def build_embedder(
|
||||
spec: SentenceTransformerProviderSpec,
|
||||
) -> SentenceTransformerEmbeddingFunction: ...
|
||||
|
||||
|
||||
@overload
|
||||
def build_embedder(spec: InstructorProviderSpec) -> InstructorEmbeddingFunction: ...
|
||||
|
||||
|
||||
@overload
|
||||
def build_embedder(spec: JinaProviderSpec) -> JinaEmbeddingFunction: ...
|
||||
|
||||
|
||||
@overload
|
||||
def build_embedder(spec: RoboflowProviderSpec) -> RoboflowEmbeddingFunction: ...
|
||||
|
||||
|
||||
@overload
|
||||
def build_embedder(spec: OpenCLIPProviderSpec) -> OpenCLIPEmbeddingFunction: ...
|
||||
|
||||
|
||||
@overload
|
||||
def build_embedder(spec: Text2VecProviderSpec) -> Text2VecEmbeddingFunction: ...
|
||||
|
||||
|
||||
@overload
|
||||
def build_embedder(spec: ONNXProviderSpec) -> ONNXMiniLM_L6_V2: ...
|
||||
|
||||
|
||||
def build_embedder(spec):
|
||||
"""Build an embedding function from either a provider spec or a provider instance.
|
||||
def get_embedding_function(
|
||||
config: EmbeddingOptions | dict | None = None,
|
||||
) -> EmbeddingFunction:
|
||||
"""Get embedding function - delegates to ChromaDB.
|
||||
|
||||
Args:
|
||||
spec: Either a provider specification dictionary or a provider instance.
|
||||
config: Optional configuration - either an EmbeddingOptions object or a dict with:
|
||||
- provider: The embedding provider to use (default: "openai")
|
||||
- Any other provider-specific parameters
|
||||
|
||||
Returns:
|
||||
An embedding function instance. If a typed provider is passed, returns
|
||||
the specific embedding function type.
|
||||
EmbeddingFunction instance ready for use with ChromaDB
|
||||
|
||||
Supported providers:
|
||||
- openai: OpenAI embeddings
|
||||
- cohere: Cohere embeddings
|
||||
- ollama: Ollama local embeddings
|
||||
- huggingface: HuggingFace embeddings
|
||||
- sentence-transformer: Local sentence transformers
|
||||
- instructor: Instructor embeddings for specialized tasks
|
||||
- google-palm: Google PaLM embeddings
|
||||
- google-generativeai: Google Generative AI embeddings
|
||||
- google-vertex: Google Vertex AI embeddings
|
||||
- amazon-bedrock: AWS Bedrock embeddings
|
||||
- jina: Jina AI embeddings
|
||||
- roboflow: Roboflow embeddings for vision tasks
|
||||
- openclip: OpenCLIP embeddings for multimodal tasks
|
||||
- text2vec: Text2Vec embeddings
|
||||
- onnx: ONNX MiniLM-L6-v2 (no API key needed, included with ChromaDB)
|
||||
- watson: IBM Watson embeddings
|
||||
|
||||
Examples:
|
||||
# From dictionary specification
|
||||
embedder = build_embedder({
|
||||
"provider": "openai",
|
||||
"config": {"api_key": "sk-..."}
|
||||
})
|
||||
# Use default OpenAI embedding
|
||||
>>> embedder = get_embedding_function()
|
||||
|
||||
# From provider instance
|
||||
provider = OpenAIProvider(api_key="sk-...")
|
||||
embedder = build_embedder(provider)
|
||||
# Use Cohere with dict
|
||||
>>> embedder = get_embedding_function({
|
||||
... "provider": "cohere",
|
||||
... "api_key": "your-key",
|
||||
... "model_name": "embed-english-v3.0"
|
||||
... })
|
||||
|
||||
# Use with EmbeddingOptions
|
||||
>>> embedder = get_embedding_function(
|
||||
... EmbeddingOptions(provider="sentence-transformer", model_name="all-MiniLM-L6-v2")
|
||||
... )
|
||||
|
||||
# Use local sentence transformers (no API key needed)
|
||||
>>> embedder = get_embedding_function({
|
||||
... "provider": "sentence-transformer",
|
||||
... "model_name": "all-MiniLM-L6-v2"
|
||||
... })
|
||||
|
||||
# Use Ollama for local embeddings
|
||||
>>> embedder = get_embedding_function({
|
||||
... "provider": "ollama",
|
||||
... "model_name": "nomic-embed-text"
|
||||
... })
|
||||
|
||||
# Use ONNX (no API key needed)
|
||||
>>> embedder = get_embedding_function({
|
||||
... "provider": "onnx"
|
||||
... })
|
||||
|
||||
# Use Watson embeddings
|
||||
>>> embedder = get_embedding_function({
|
||||
... "provider": "watson",
|
||||
... "api_key": "your-watson-api-key",
|
||||
... "api_url": "your-watson-url",
|
||||
... "project_id": "your-project-id",
|
||||
... "model_name": "ibm/slate-125m-english-rtrvr"
|
||||
... })
|
||||
"""
|
||||
if isinstance(spec, BaseEmbeddingsProvider):
|
||||
return build_embedder_from_provider(spec)
|
||||
return build_embedder_from_dict(spec)
|
||||
if config is None:
|
||||
return OpenAIEmbeddingFunction(
|
||||
api_key=os.getenv("OPENAI_API_KEY"), model_name="text-embedding-3-small"
|
||||
)
|
||||
|
||||
# Handle EmbeddingOptions object
|
||||
if isinstance(config, EmbeddingOptions):
|
||||
config_dict = config.model_dump(exclude_none=True)
|
||||
else:
|
||||
config_dict = config.copy()
|
||||
|
||||
# Backward compatibility alias
|
||||
get_embedding_function = build_embedder
|
||||
provider = config_dict.pop("provider", "openai")
|
||||
|
||||
embedding_functions: dict[str, Callable[..., EmbeddingFunction]] = {
|
||||
"openai": OpenAIEmbeddingFunction,
|
||||
"cohere": CohereEmbeddingFunction,
|
||||
"ollama": OllamaEmbeddingFunction,
|
||||
"huggingface": HuggingFaceEmbeddingFunction,
|
||||
"sentence-transformer": SentenceTransformerEmbeddingFunction,
|
||||
"instructor": InstructorEmbeddingFunction,
|
||||
"google-palm": GooglePalmEmbeddingFunction,
|
||||
"google-generativeai": GoogleGenerativeAiEmbeddingFunction,
|
||||
"google-vertex": GoogleVertexEmbeddingFunction,
|
||||
"amazon-bedrock": AmazonBedrockEmbeddingFunction,
|
||||
"jina": JinaEmbeddingFunction,
|
||||
"roboflow": RoboflowEmbeddingFunction,
|
||||
"openclip": OpenCLIPEmbeddingFunction,
|
||||
"text2vec": Text2VecEmbeddingFunction,
|
||||
"onnx": ONNXMiniLM_L6_V2,
|
||||
"watson": _create_watson_embedding_function,
|
||||
}
|
||||
|
||||
if provider not in embedding_functions:
|
||||
raise ValueError(
|
||||
f"Unsupported provider: {provider}. "
|
||||
f"Available providers: {list(embedding_functions.keys())}"
|
||||
)
|
||||
return embedding_functions[provider](**config_dict)
|
||||
|
||||
@@ -1 +0,0 @@
|
||||
"""Embedding provider implementations."""
|
||||
@@ -1,13 +0,0 @@
|
||||
"""AWS embedding providers."""
|
||||
|
||||
from crewai.rag.embeddings.providers.aws.bedrock import BedrockProvider
|
||||
from crewai.rag.embeddings.providers.aws.types import (
|
||||
BedrockProviderConfig,
|
||||
BedrockProviderSpec,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"BedrockProvider",
|
||||
"BedrockProviderConfig",
|
||||
"BedrockProviderSpec",
|
||||
]
|
||||
@@ -1,53 +0,0 @@
|
||||
"""Amazon Bedrock embeddings provider."""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from chromadb.utils.embedding_functions.amazon_bedrock_embedding_function import (
|
||||
AmazonBedrockEmbeddingFunction,
|
||||
)
|
||||
from pydantic import Field
|
||||
|
||||
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
|
||||
|
||||
|
||||
def create_aws_session() -> Any:
    """Create an AWS session for Bedrock.

    boto3 is imported inside the function, so it is only needed when a
    session is actually requested.

    Returns:
        boto3.Session: AWS session object

    Raises:
        ImportError: If boto3 is not installed
        ValueError: If AWS session creation fails
    """
    # Only the import is guarded here, so the "install boto3" hint is never
    # shown for a failure that happens while the session is being created.
    try:
        import boto3  # type: ignore[import]
    except ImportError as e:
        raise ImportError(
            "boto3 is required for amazon-bedrock embeddings. "
            "Install it with: uv add boto3"
        ) from e

    try:
        return boto3.Session()
    except Exception as e:
        raise ValueError(
            "Failed to create AWS session for amazon-bedrock. "
            f"Ensure AWS credentials are configured. Error: {e}"
        ) from e
|
||||
|
||||
|
||||
class BedrockProvider(BaseEmbeddingsProvider[AmazonBedrockEmbeddingFunction]):
    """Amazon Bedrock embeddings provider."""

    # Class that is instantiated to produce the embedding function.
    embedding_callable: type[AmazonBedrockEmbeddingFunction] = Field(
        default=AmazonBedrockEmbeddingFunction,
        description="Amazon Bedrock embedding function class",
    )
    model_name: str = Field(
        default="amazon.titan-embed-text-v1",
        description="Model name to use for embeddings",
        validation_alias="BEDROCK_MODEL_NAME",
    )
    # default_factory: a session is created via create_aws_session only when
    # the caller does not supply one.
    session: Any = Field(
        default_factory=create_aws_session, description="AWS session object"
    )
|
||||
@@ -1,19 +0,0 @@
|
||||
"""Type definitions for AWS embedding providers."""
|
||||
|
||||
from typing import Annotated, Any, Literal
|
||||
|
||||
from typing_extensions import Required, TypedDict
|
||||
|
||||
|
||||
class BedrockProviderConfig(TypedDict, total=False):
|
||||
"""Configuration for Bedrock provider."""
|
||||
|
||||
model_name: Annotated[str, "amazon.titan-embed-text-v1"]
|
||||
session: Any
|
||||
|
||||
|
||||
class BedrockProviderSpec(TypedDict, total=False):
|
||||
"""Bedrock provider specification."""
|
||||
|
||||
provider: Required[Literal["amazon-bedrock"]]
|
||||
config: BedrockProviderConfig
|
||||
@@ -1,13 +0,0 @@
|
||||
"""Cohere embedding providers."""
|
||||
|
||||
from crewai.rag.embeddings.providers.cohere.cohere_provider import CohereProvider
|
||||
from crewai.rag.embeddings.providers.cohere.types import (
|
||||
CohereProviderConfig,
|
||||
CohereProviderSpec,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"CohereProvider",
|
||||
"CohereProviderConfig",
|
||||
"CohereProviderSpec",
|
||||
]
|
||||
@@ -1,24 +0,0 @@
|
||||
"""Cohere embeddings provider."""
|
||||
|
||||
from chromadb.utils.embedding_functions.cohere_embedding_function import (
|
||||
CohereEmbeddingFunction,
|
||||
)
|
||||
from pydantic import Field
|
||||
|
||||
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
|
||||
|
||||
|
||||
class CohereProvider(BaseEmbeddingsProvider[CohereEmbeddingFunction]):
    """Cohere embeddings provider."""

    # Class that is instantiated to produce the embedding function.
    embedding_callable: type[CohereEmbeddingFunction] = Field(
        default=CohereEmbeddingFunction, description="Cohere embedding function class"
    )
    # No default is declared, so an API key must be provided.
    api_key: str = Field(
        description="Cohere API key", validation_alias="COHERE_API_KEY"
    )
    model_name: str = Field(
        default="large",
        description="Model name to use for embeddings",
        validation_alias="COHERE_MODEL_NAME",
    )
|
||||
@@ -1,19 +0,0 @@
|
||||
"""Type definitions for Cohere embedding providers."""
|
||||
|
||||
from typing import Annotated, Literal
|
||||
|
||||
from typing_extensions import Required, TypedDict
|
||||
|
||||
|
||||
class CohereProviderConfig(TypedDict, total=False):
|
||||
"""Configuration for Cohere provider."""
|
||||
|
||||
api_key: str
|
||||
model_name: Annotated[str, "large"]
|
||||
|
||||
|
||||
class CohereProviderSpec(TypedDict, total=False):
|
||||
"""Cohere provider specification."""
|
||||
|
||||
provider: Required[Literal["cohere"]]
|
||||
config: CohereProviderConfig
|
||||
@@ -1,13 +0,0 @@
|
||||
"""Custom embedding providers."""
|
||||
|
||||
from crewai.rag.embeddings.providers.custom.custom_provider import CustomProvider
|
||||
from crewai.rag.embeddings.providers.custom.types import (
|
||||
CustomProviderConfig,
|
||||
CustomProviderSpec,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"CustomProvider",
|
||||
"CustomProviderConfig",
|
||||
"CustomProviderSpec",
|
||||
]
|
||||
@@ -1,19 +0,0 @@
|
||||
"""Custom embeddings provider for user-defined embedding functions."""
|
||||
|
||||
from pydantic import Field
|
||||
from pydantic_settings import SettingsConfigDict
|
||||
|
||||
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
|
||||
from crewai.rag.embeddings.providers.custom.embedding_callable import (
|
||||
CustomEmbeddingFunction,
|
||||
)
|
||||
|
||||
|
||||
class CustomProvider(BaseEmbeddingsProvider[CustomEmbeddingFunction]):
    """Custom embeddings provider for user-defined embedding functions."""

    # Required field (``...``): the caller supplies the embedding function class.
    embedding_callable: type[CustomEmbeddingFunction] = Field(
        ..., description="Custom embedding function class"
    )

    # extra="allow" lets callers pass settings that are not declared as fields
    # here; presumably they are forwarded to the custom callable - confirm
    # against BaseEmbeddingsProvider.
    model_config = SettingsConfigDict(extra="allow")
|
||||
@@ -1,22 +0,0 @@
|
||||
"""Custom embedding function base implementation."""
|
||||
|
||||
from crewai.rag.core.base_embeddings_callable import EmbeddingFunction
|
||||
from crewai.rag.core.types import Documents, Embeddings
|
||||
|
||||
|
||||
class CustomEmbeddingFunction(EmbeddingFunction[Documents]):
    """Base class for custom embedding functions.

    This provides a concrete implementation that can be subclassed for custom embeddings.
    Calling an instance of this base class directly raises NotImplementedError.
    """

    # The parameter is named ``input`` (shadowing the builtin) to keep the
    # call signature unchanged for existing callers and subclasses.
    def __call__(self, input: Documents) -> Embeddings:
        """Convert input documents to embeddings.

        Args:
            input: List of documents to embed.

        Returns:
            List of numpy arrays representing the embeddings.

        Raises:
            NotImplementedError: Always, unless overridden by a subclass.
        """
        raise NotImplementedError("Subclasses must implement __call__ method")
|
||||
@@ -1,19 +0,0 @@
|
||||
"""Type definitions for custom embedding providers."""
|
||||
|
||||
from typing import Literal
|
||||
|
||||
from chromadb.api.types import EmbeddingFunction
|
||||
from typing_extensions import Required, TypedDict
|
||||
|
||||
|
||||
class CustomProviderConfig(TypedDict, total=False):
    """Configuration for Custom provider."""

    # The embedding function class itself, not an instance.
    embedding_callable: type[EmbeddingFunction]


class CustomProviderSpec(TypedDict, total=False):
    """Custom provider specification."""

    # Only the provider tag is mandatory at the type level.
    provider: Required[Literal["custom"]]
    config: CustomProviderConfig
|
||||
@@ -1,23 +0,0 @@
|
||||
"""Google embedding providers."""
|
||||
|
||||
from crewai.rag.embeddings.providers.google.generative_ai import (
|
||||
GenerativeAiProvider,
|
||||
)
|
||||
from crewai.rag.embeddings.providers.google.types import (
|
||||
GenerativeAiProviderConfig,
|
||||
GenerativeAiProviderSpec,
|
||||
VertexAIProviderConfig,
|
||||
VertexAIProviderSpec,
|
||||
)
|
||||
from crewai.rag.embeddings.providers.google.vertex import (
|
||||
VertexAIProvider,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"GenerativeAiProvider",
|
||||
"GenerativeAiProviderConfig",
|
||||
"GenerativeAiProviderSpec",
|
||||
"VertexAIProvider",
|
||||
"VertexAIProviderConfig",
|
||||
"VertexAIProviderSpec",
|
||||
]
|
||||
@@ -1,30 +0,0 @@
|
||||
"""Google Generative AI embeddings provider."""
|
||||
|
||||
from chromadb.utils.embedding_functions.google_embedding_function import (
|
||||
GoogleGenerativeAiEmbeddingFunction,
|
||||
)
|
||||
from pydantic import Field
|
||||
|
||||
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
|
||||
|
||||
|
||||
class GenerativeAiProvider(BaseEmbeddingsProvider[GoogleGenerativeAiEmbeddingFunction]):
    """Google Generative AI embeddings provider."""

    # Class that is instantiated to produce the embedding function.
    embedding_callable: type[GoogleGenerativeAiEmbeddingFunction] = Field(
        default=GoogleGenerativeAiEmbeddingFunction,
        description="Google Generative AI embedding function class",
    )
    model_name: str = Field(
        default="models/embedding-001",
        description="Model name to use for embeddings",
        validation_alias="GOOGLE_GENERATIVE_AI_MODEL_NAME",
    )
    # No default is declared, so an API key must be provided.
    api_key: str = Field(
        description="Google API key", validation_alias="GOOGLE_API_KEY"
    )
    task_type: str = Field(
        default="RETRIEVAL_DOCUMENT",
        description="Task type for embeddings",
        validation_alias="GOOGLE_GENERATIVE_AI_TASK_TYPE",
    )
|
||||
@@ -1,36 +0,0 @@
|
||||
"""Type definitions for Google embedding providers."""
|
||||
|
||||
from typing import Annotated, Literal
|
||||
|
||||
from typing_extensions import Required, TypedDict
|
||||
|
||||
|
||||
class GenerativeAiProviderConfig(TypedDict, total=False):
|
||||
"""Configuration for Google Generative AI provider."""
|
||||
|
||||
api_key: str
|
||||
model_name: Annotated[str, "models/embedding-001"]
|
||||
task_type: Annotated[str, "RETRIEVAL_DOCUMENT"]
|
||||
|
||||
|
||||
class GenerativeAiProviderSpec(TypedDict):
|
||||
"""Google Generative AI provider specification."""
|
||||
|
||||
provider: Literal["google-generativeai"]
|
||||
config: GenerativeAiProviderConfig
|
||||
|
||||
|
||||
class VertexAIProviderConfig(TypedDict, total=False):
|
||||
"""Configuration for Vertex AI provider."""
|
||||
|
||||
api_key: str
|
||||
model_name: Annotated[str, "textembedding-gecko"]
|
||||
project_id: Annotated[str, "cloud-large-language-models"]
|
||||
region: Annotated[str, "us-central1"]
|
||||
|
||||
|
||||
class VertexAIProviderSpec(TypedDict, total=False):
|
||||
"""Vertex AI provider specification."""
|
||||
|
||||
provider: Required[Literal["google-vertex"]]
|
||||
config: VertexAIProviderConfig
|
||||
@@ -1,35 +0,0 @@
|
||||
"""Google Vertex AI embeddings provider."""
|
||||
|
||||
from chromadb.utils.embedding_functions.google_embedding_function import (
|
||||
GoogleVertexEmbeddingFunction,
|
||||
)
|
||||
from pydantic import Field
|
||||
|
||||
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
|
||||
|
||||
|
||||
class VertexAIProvider(BaseEmbeddingsProvider[GoogleVertexEmbeddingFunction]):
    """Google Vertex AI embeddings provider."""

    # Class that is instantiated to produce the embedding function.
    embedding_callable: type[GoogleVertexEmbeddingFunction] = Field(
        default=GoogleVertexEmbeddingFunction,
        description="Vertex AI embedding function class",
    )
    model_name: str = Field(
        default="textembedding-gecko",
        description="Model name to use for embeddings",
        validation_alias="GOOGLE_VERTEX_MODEL_NAME",
    )
    # No default is declared, so an API key must be provided.
    api_key: str = Field(
        description="Google API key", validation_alias="GOOGLE_CLOUD_API_KEY"
    )
    project_id: str = Field(
        default="cloud-large-language-models",
        description="GCP project ID",
        validation_alias="GOOGLE_CLOUD_PROJECT",
    )
    region: str = Field(
        default="us-central1",
        description="GCP region",
        validation_alias="GOOGLE_CLOUD_REGION",
    )
|
||||
@@ -1,15 +0,0 @@
|
||||
"""HuggingFace embedding providers."""
|
||||
|
||||
from crewai.rag.embeddings.providers.huggingface.huggingface_provider import (
|
||||
HuggingFaceProvider,
|
||||
)
|
||||
from crewai.rag.embeddings.providers.huggingface.types import (
|
||||
HuggingFaceProviderConfig,
|
||||
HuggingFaceProviderSpec,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"HuggingFaceProvider",
|
||||
"HuggingFaceProviderConfig",
|
||||
"HuggingFaceProviderSpec",
|
||||
]
|
||||
@@ -1,20 +0,0 @@
|
||||
"""HuggingFace embeddings provider."""
|
||||
|
||||
from chromadb.utils.embedding_functions.huggingface_embedding_function import (
|
||||
HuggingFaceEmbeddingServer,
|
||||
)
|
||||
from pydantic import Field
|
||||
|
||||
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
|
||||
|
||||
|
||||
class HuggingFaceProvider(BaseEmbeddingsProvider[HuggingFaceEmbeddingServer]):
    """HuggingFace embeddings provider."""

    # Class that is instantiated to produce the embedding function.
    embedding_callable: type[HuggingFaceEmbeddingServer] = Field(
        default=HuggingFaceEmbeddingServer,
        description="HuggingFace embedding function class",
    )
    # No default is declared, so the URL must be provided.
    url: str = Field(
        description="HuggingFace API URL", validation_alias="HUGGINGFACE_URL"
    )
|
||||
@@ -1,18 +0,0 @@
|
||||
"""Type definitions for HuggingFace embedding providers."""
|
||||
|
||||
from typing import Literal
|
||||
|
||||
from typing_extensions import Required, TypedDict
|
||||
|
||||
|
||||
class HuggingFaceProviderConfig(TypedDict, total=False):
|
||||
"""Configuration for HuggingFace provider."""
|
||||
|
||||
url: str
|
||||
|
||||
|
||||
class HuggingFaceProviderSpec(TypedDict, total=False):
|
||||
"""HuggingFace provider specification."""
|
||||
|
||||
provider: Required[Literal["huggingface"]]
|
||||
config: HuggingFaceProviderConfig
|
||||
@@ -1,15 +0,0 @@
|
||||
"""IBM embedding providers."""
|
||||
|
||||
from crewai.rag.embeddings.providers.ibm.types import (
|
||||
WatsonProviderConfig,
|
||||
WatsonProviderSpec,
|
||||
)
|
||||
from crewai.rag.embeddings.providers.ibm.watson import (
|
||||
WatsonProvider,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"WatsonProvider",
|
||||
"WatsonProviderConfig",
|
||||
"WatsonProviderSpec",
|
||||
]
|
||||
@@ -1,159 +0,0 @@
|
||||
"""IBM Watson embedding function implementation."""
|
||||
|
||||
from typing import cast
|
||||
|
||||
from typing_extensions import Unpack
|
||||
|
||||
from crewai.rag.core.base_embeddings_callable import EmbeddingFunction
|
||||
from crewai.rag.core.types import Documents, Embeddings
|
||||
from crewai.rag.embeddings.providers.ibm.types import WatsonProviderConfig
|
||||
|
||||
|
||||
class WatsonEmbeddingFunction(EmbeddingFunction[Documents]):
    """Embedding function for IBM Watson models.

    The configuration is only stored at construction time. The
    ``ibm_watsonx_ai`` package is imported, and the client objects are
    built, each time the instance is called.
    """

    def __init__(self, **kwargs: Unpack[WatsonProviderConfig]) -> None:
        """Initialize Watson embedding function.

        Args:
            **kwargs: Configuration parameters for Watson Embeddings and Credentials.
        """
        self._config = kwargs

    def __call__(self, input: Documents) -> Embeddings:
        """Generate embeddings for input documents.

        Args:
            input: List of documents to embed. A bare string is wrapped
                into a one-element list.

        Returns:
            List of embedding vectors.

        Raises:
            ImportError: If ``ibm-watsonx-ai`` is not installed.
            KeyError: If ``model_id`` is missing from the configuration.
        """
        try:
            import ibm_watsonx_ai.foundation_models as watson_models  # type: ignore[import-not-found, import-untyped]
            from ibm_watsonx_ai import (
                Credentials,  # type: ignore[import-not-found, import-untyped]
            )
            from ibm_watsonx_ai.metanames import (  # type: ignore[import-not-found, import-untyped]
                EmbedTextParamsMetaNames as EmbedParams,
            )
        except ImportError as e:
            raise ImportError(
                "ibm-watsonx-ai is required for watson embeddings. "
                "Install it with: uv add ibm-watsonx-ai"
            ) from e

        if isinstance(input, str):
            input = [input]

        # Plain-dict view so keys can be looked up by variable name.
        config: dict = dict(self._config)

        # Options forwarded to ``Embeddings`` only when set to a non-None value.
        optional_embedding_keys = (
            "params",
            "project_id",
            "space_id",
            "api_client",
            "verify",
            "max_retries",
            "delay_time",
            "retry_status_codes",
        )
        # Options forwarded to ``Embeddings`` whenever present, even if None.
        passthrough_embedding_keys = (
            "persistent_connection",
            "batch_size",
            "concurrency_limit",
        )
        # Options used to build a ``Credentials`` object when none is supplied.
        # ``verify`` is deliberately absent: a non-None ``verify`` is always
        # passed to ``Embeddings`` above, so the original credentials-side
        # branch for it could never execute.
        credential_keys = (
            "url",
            "api_key",
            "name",
            "iam_serviceid_crn",
            "trusted_profile_id",
            "token",
            "projects_token",
            "username",
            "password",
            "instance_id",
            "version",
            "bedrock_url",
            "platform_url",
            "proxies",
        )

        embeddings_config: dict = {"model_id": config["model_id"]}
        for key in optional_embedding_keys:
            if config.get(key) is not None:
                embeddings_config[key] = config[key]
        for key in passthrough_embedding_keys:
            if key in config:
                embeddings_config[key] = config[key]

        if config.get("credentials") is not None:
            # A ready-made credentials object wins over individual fields.
            embeddings_config["credentials"] = config["credentials"]
        else:
            cred_config: dict = {
                key: config[key]
                for key in credential_keys
                if config.get(key) is not None
            }
            if cred_config:
                embeddings_config["credentials"] = Credentials(**cred_config)

        if "params" not in embeddings_config:
            # NOTE(review): TRUNCATE_INPUT_TOKENS of 3 looks very small for a
            # default; kept unchanged to preserve behavior - confirm intent.
            embeddings_config["params"] = {
                EmbedParams.TRUNCATE_INPUT_TOKENS: 3,
                EmbedParams.RETURN_OPTIONS: {"input_text": True},
            }

        embedding = watson_models.Embeddings(**embeddings_config)

        try:
            embeddings = embedding.embed_documents(input)
            return cast(Embeddings, embeddings)
        except Exception as e:
            # Report the failure, then let the original exception propagate.
            print(f"Error during Watson embedding: {e}")
            raise

    @staticmethod
    def name() -> str:
        """Return the name identifier for this embedding function."""
        return "watson"
|
||||
@@ -1,44 +0,0 @@
|
||||
"""Type definitions for IBM Watson embedding providers."""
|
||||
|
||||
from typing import Annotated, Any, Literal
|
||||
|
||||
from typing_extensions import Required, TypedDict
|
||||
|
||||
|
||||
class WatsonProviderConfig(TypedDict, total=False):
|
||||
"""Configuration for Watson provider."""
|
||||
|
||||
model_id: str
|
||||
url: str
|
||||
params: dict[str, str | dict[str, str]]
|
||||
credentials: Any
|
||||
project_id: str
|
||||
space_id: str
|
||||
api_client: Any
|
||||
verify: bool | str
|
||||
persistent_connection: Annotated[bool, True]
|
||||
batch_size: Annotated[int, 100]
|
||||
concurrency_limit: Annotated[int, 10]
|
||||
max_retries: int
|
||||
delay_time: float
|
||||
retry_status_codes: list[int]
|
||||
api_key: str
|
||||
name: str
|
||||
iam_serviceid_crn: str
|
||||
trusted_profile_id: str
|
||||
token: str
|
||||
projects_token: str
|
||||
username: str
|
||||
password: str
|
||||
instance_id: str
|
||||
version: str
|
||||
bedrock_url: str
|
||||
platform_url: str
|
||||
proxies: dict
|
||||
|
||||
|
||||
class WatsonProviderSpec(TypedDict, total=False):
    """Provider selector plus optional settings for Watson embeddings.

    Only ``provider`` is mandatory; ``config`` may be left out.
    """

    # Discriminator value identifying this provider.
    provider: Required[Literal["watson"]]
    # Provider-specific options.
    config: WatsonProviderConfig
|
||||
@@ -1,122 +0,0 @@
|
||||
"""IBM Watson embeddings provider."""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from pydantic import Field, model_validator
|
||||
from typing_extensions import Self
|
||||
|
||||
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
|
||||
from crewai.rag.embeddings.providers.ibm.embedding_callable import (
|
||||
WatsonEmbeddingFunction,
|
||||
)
|
||||
|
||||
|
||||
class WatsonProvider(BaseEmbeddingsProvider[WatsonEmbeddingFunction]):
    """Settings model for IBM Watson embeddings.

    Uses the project's own ``WatsonEmbeddingFunction`` as the callable,
    since Watson exposes a different interface. ``model_id``, ``url`` and
    ``api_key`` have no default and must be supplied; at least one of
    ``space_id`` / ``project_id`` is enforced by a validator.
    """

    embedding_callable: type[WatsonEmbeddingFunction] = Field(
        default=WatsonEmbeddingFunction,
        description="Watson embedding function class",
    )

    # --- embedding client options ---
    model_id: str = Field(
        validation_alias="WATSON_MODEL_ID",
        description="Watson model ID",
    )
    params: dict[str, str | dict[str, str]] | None = Field(
        default=None,
        description="Additional parameters",
    )
    credentials: Any | None = Field(
        default=None,
        description="Watson credentials",
    )
    project_id: str | None = Field(
        default=None,
        validation_alias="WATSON_PROJECT_ID",
        description="Watson project ID",
    )
    space_id: str | None = Field(
        default=None,
        validation_alias="WATSON_SPACE_ID",
        description="Watson space ID",
    )
    api_client: Any | None = Field(
        default=None,
        description="Watson API client",
    )
    verify: bool | str | None = Field(
        default=None,
        validation_alias="WATSON_VERIFY",
        description="SSL verification",
    )
    persistent_connection: bool = Field(
        default=True,
        validation_alias="WATSON_PERSISTENT_CONNECTION",
        description="Use persistent connection",
    )
    batch_size: int = Field(
        default=100,
        validation_alias="WATSON_BATCH_SIZE",
        description="Batch size for processing",
    )
    concurrency_limit: int = Field(
        default=10,
        validation_alias="WATSON_CONCURRENCY_LIMIT",
        description="Concurrency limit",
    )
    max_retries: int | None = Field(
        default=None,
        validation_alias="WATSON_MAX_RETRIES",
        description="Maximum retries",
    )
    delay_time: float | None = Field(
        default=None,
        validation_alias="WATSON_DELAY_TIME",
        description="Delay time between retries",
    )
    retry_status_codes: list[int] | None = Field(
        default=None,
        description="HTTP status codes to retry on",
    )

    # --- credential fields ---
    url: str = Field(
        validation_alias="WATSON_URL",
        description="Watson API URL",
    )
    api_key: str = Field(
        validation_alias="WATSON_API_KEY",
        description="Watson API key",
    )
    name: str | None = Field(
        default=None,
        validation_alias="WATSON_NAME",
        description="Service name",
    )
    iam_serviceid_crn: str | None = Field(
        default=None,
        validation_alias="WATSON_IAM_SERVICEID_CRN",
        description="IAM service ID CRN",
    )
    trusted_profile_id: str | None = Field(
        default=None,
        validation_alias="WATSON_TRUSTED_PROFILE_ID",
        description="Trusted profile ID",
    )
    token: str | None = Field(
        default=None,
        validation_alias="WATSON_TOKEN",
        description="Bearer token",
    )
    projects_token: str | None = Field(
        default=None,
        validation_alias="WATSON_PROJECTS_TOKEN",
        description="Projects token",
    )
    username: str | None = Field(
        default=None,
        validation_alias="WATSON_USERNAME",
        description="Username",
    )
    password: str | None = Field(
        default=None,
        validation_alias="WATSON_PASSWORD",
        description="Password",
    )
    instance_id: str | None = Field(
        default=None,
        validation_alias="WATSON_INSTANCE_ID",
        description="Service instance ID",
    )
    version: str | None = Field(
        default=None,
        validation_alias="WATSON_VERSION",
        description="API version",
    )
    bedrock_url: str | None = Field(
        default=None,
        validation_alias="WATSON_BEDROCK_URL",
        description="Bedrock URL",
    )
    platform_url: str | None = Field(
        default=None,
        validation_alias="WATSON_PLATFORM_URL",
        description="Platform URL",
    )
    proxies: dict | None = Field(
        default=None,
        description="Proxy configuration",
    )

    @model_validator(mode="after")
    def validate_space_or_project(self) -> Self:
        """Reject configurations that name neither a space nor a project.

        Returns:
            The validated instance, unchanged.

        Raises:
            ValueError: If both ``space_id`` and ``project_id`` are falsy.
        """
        if not (self.space_id or self.project_id):
            raise ValueError("One of 'space_id' or 'project_id' must be provided")
        return self
|
||||
@@ -1,15 +0,0 @@
|
||||
"""Instructor embedding providers."""
|
||||
|
||||
from crewai.rag.embeddings.providers.instructor.instructor_provider import (
|
||||
InstructorProvider,
|
||||
)
|
||||
from crewai.rag.embeddings.providers.instructor.types import (
|
||||
InstructorProviderConfig,
|
||||
InstructorProviderSpec,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"InstructorProvider",
|
||||
"InstructorProviderConfig",
|
||||
"InstructorProviderSpec",
|
||||
]
|
||||
@@ -1,32 +0,0 @@
|
||||
"""Instructor embeddings provider."""
|
||||
|
||||
from chromadb.utils.embedding_functions.instructor_embedding_function import (
|
||||
InstructorEmbeddingFunction,
|
||||
)
|
||||
from pydantic import Field
|
||||
|
||||
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
|
||||
|
||||
|
||||
class InstructorProvider(BaseEmbeddingsProvider[InstructorEmbeddingFunction]):
    """Settings model for Instructor embeddings.

    Pairs chromadb's ``InstructorEmbeddingFunction`` with the options below;
    every field has a default.
    """

    embedding_callable: type[InstructorEmbeddingFunction] = Field(
        default=InstructorEmbeddingFunction,
        description="Instructor embedding function class",
    )
    model_name: str = Field(
        default="hkunlp/instructor-base",
        validation_alias="INSTRUCTOR_MODEL_NAME",
        description="Model name to use",
    )
    device: str = Field(
        default="cpu",
        validation_alias="INSTRUCTOR_DEVICE",
        description="Device to run model on (cpu or cuda)",
    )
    instruction: str | None = Field(
        default=None,
        validation_alias="INSTRUCTOR_INSTRUCTION",
        description="Instruction for embeddings",
    )
|
||||
@@ -1,20 +0,0 @@
|
||||
"""Type definitions for Instructor embedding providers."""
|
||||
|
||||
from typing import Annotated, Literal
|
||||
|
||||
from typing_extensions import Required, TypedDict
|
||||
|
||||
|
||||
class InstructorProviderConfig(TypedDict, total=False):
    """Keyword options accepted by the Instructor embedding provider.

    Declared with ``total=False``, so every key may be omitted. The
    ``Annotated`` metadata carries a literal value (presumably the
    provider default - confirm).
    """

    model_name: Annotated[str, "hkunlp/instructor-base"]
    device: Annotated[str, "cpu"]
    instruction: str
|
||||
|
||||
|
||||
class InstructorProviderSpec(TypedDict, total=False):
    """Provider selector plus optional settings for Instructor embeddings.

    Only ``provider`` is mandatory; ``config`` may be left out.
    """

    # Discriminator value identifying this provider.
    provider: Required[Literal["instructor"]]
    # Provider-specific options.
    config: InstructorProviderConfig
|
||||
@@ -1,13 +0,0 @@
|
||||
"""Jina embedding providers."""
|
||||
|
||||
from crewai.rag.embeddings.providers.jina.jina_provider import JinaProvider
|
||||
from crewai.rag.embeddings.providers.jina.types import (
|
||||
JinaProviderConfig,
|
||||
JinaProviderSpec,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"JinaProvider",
|
||||
"JinaProviderConfig",
|
||||
"JinaProviderSpec",
|
||||
]
|
||||
@@ -1,22 +0,0 @@
|
||||
"""Jina embeddings provider."""
|
||||
|
||||
from chromadb.utils.embedding_functions.jina_embedding_function import (
|
||||
JinaEmbeddingFunction,
|
||||
)
|
||||
from pydantic import Field
|
||||
|
||||
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
|
||||
|
||||
|
||||
class JinaProvider(BaseEmbeddingsProvider[JinaEmbeddingFunction]):
    """Settings model for Jina embeddings.

    Pairs chromadb's ``JinaEmbeddingFunction`` with an API key (no default,
    so it must be supplied) and a model name.
    """

    embedding_callable: type[JinaEmbeddingFunction] = Field(
        default=JinaEmbeddingFunction,
        description="Jina embedding function class",
    )
    api_key: str = Field(
        validation_alias="JINA_API_KEY",
        description="Jina API key",
    )
    model_name: str = Field(
        default="jina-embeddings-v2-base-en",
        validation_alias="JINA_MODEL_NAME",
        description="Model name to use for embeddings",
    )
|
||||
@@ -1,19 +0,0 @@
|
||||
"""Type definitions for Jina embedding providers."""
|
||||
|
||||
from typing import Annotated, Literal
|
||||
|
||||
from typing_extensions import Required, TypedDict
|
||||
|
||||
|
||||
class JinaProviderConfig(TypedDict, total=False):
    """Keyword options accepted by the Jina embedding provider.

    Declared with ``total=False``, so every key may be omitted at the type
    level. The ``Annotated`` metadata carries a literal value (presumably
    the provider default - confirm).
    """

    api_key: str
    model_name: Annotated[str, "jina-embeddings-v2-base-en"]
|
||||
|
||||
|
||||
class JinaProviderSpec(TypedDict, total=False):
    """Provider selector plus optional settings for Jina embeddings.

    Only ``provider`` is mandatory; ``config`` may be left out.
    """

    # Discriminator value identifying this provider.
    provider: Required[Literal["jina"]]
    # Provider-specific options.
    config: JinaProviderConfig
|
||||
@@ -1,15 +0,0 @@
|
||||
"""Microsoft embedding providers."""
|
||||
|
||||
from crewai.rag.embeddings.providers.microsoft.azure import (
|
||||
AzureProvider,
|
||||
)
|
||||
from crewai.rag.embeddings.providers.microsoft.types import (
|
||||
AzureProviderConfig,
|
||||
AzureProviderSpec,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"AzureProvider",
|
||||
"AzureProviderConfig",
|
||||
"AzureProviderSpec",
|
||||
]
|
||||
@@ -1,58 +0,0 @@
|
||||
"""Azure OpenAI embeddings provider."""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from chromadb.utils.embedding_functions.openai_embedding_function import (
|
||||
OpenAIEmbeddingFunction,
|
||||
)
|
||||
from pydantic import Field
|
||||
|
||||
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
|
||||
|
||||
|
||||
class AzureProvider(BaseEmbeddingsProvider[OpenAIEmbeddingFunction]):
    """Settings model for Azure OpenAI embeddings.

    Reuses chromadb's ``OpenAIEmbeddingFunction`` with ``api_type``
    defaulting to ``"azure"``. ``api_key`` has no default and must be
    supplied. Note that the validation aliases use ``OPENAI_*`` names.
    """

    embedding_callable: type[OpenAIEmbeddingFunction] = Field(
        default=OpenAIEmbeddingFunction,
        description="Azure OpenAI embedding function class",
    )
    api_key: str = Field(
        validation_alias="OPENAI_API_KEY",
        description="Azure API key",
    )
    api_base: str | None = Field(
        default=None,
        validation_alias="OPENAI_API_BASE",
        description="Azure endpoint URL",
    )
    api_type: str = Field(
        default="azure",
        validation_alias="OPENAI_API_TYPE",
        description="API type for Azure",
    )
    api_version: str | None = Field(
        default=None,
        validation_alias="OPENAI_API_VERSION",
        description="Azure API version",
    )
    model_name: str = Field(
        default="text-embedding-ada-002",
        validation_alias="OPENAI_MODEL_NAME",
        description="Model name to use for embeddings",
    )
    default_headers: dict[str, Any] | None = Field(
        default=None,
        description="Default headers for API requests",
    )
    dimensions: int | None = Field(
        default=None,
        validation_alias="OPENAI_DIMENSIONS",
        description="Embedding dimensions",
    )
    deployment_id: str | None = Field(
        default=None,
        validation_alias="OPENAI_DEPLOYMENT_ID",
        description="Azure deployment ID",
    )
    organization_id: str | None = Field(
        default=None,
        validation_alias="OPENAI_ORGANIZATION_ID",
        description="Organization ID",
    )
|
||||
@@ -1,26 +0,0 @@
|
||||
"""Type definitions for Microsoft Azure embedding providers."""
|
||||
|
||||
from typing import Annotated, Any, Literal
|
||||
|
||||
from typing_extensions import Required, TypedDict
|
||||
|
||||
|
||||
class AzureProviderConfig(TypedDict, total=False):
    """Keyword options accepted by the Azure OpenAI embedding provider.

    Declared with ``total=False``, so every key may be omitted at the type
    level. The ``Annotated`` metadata carries a literal value (presumably
    the provider default - confirm).
    """

    api_key: str
    api_base: str
    api_type: Annotated[str, "azure"]
    api_version: str
    model_name: Annotated[str, "text-embedding-ada-002"]
    default_headers: dict[str, Any]
    dimensions: int
    deployment_id: str
    organization_id: str
|
||||
|
||||
|
||||
class AzureProviderSpec(TypedDict, total=False):
    """Provider selector plus optional settings for Azure embeddings.

    Only ``provider`` is mandatory; ``config`` may be left out.
    """

    # Discriminator value identifying this provider.
    provider: Required[Literal["azure"]]
    # Provider-specific options.
    config: AzureProviderConfig
|
||||
@@ -1,15 +0,0 @@
|
||||
"""Ollama embedding providers."""
|
||||
|
||||
from crewai.rag.embeddings.providers.ollama.ollama_provider import (
|
||||
OllamaProvider,
|
||||
)
|
||||
from crewai.rag.embeddings.providers.ollama.types import (
|
||||
OllamaProviderConfig,
|
||||
OllamaProviderSpec,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"OllamaProvider",
|
||||
"OllamaProviderConfig",
|
||||
"OllamaProviderSpec",
|
||||
]
|
||||
@@ -1,25 +0,0 @@
|
||||
"""Ollama embeddings provider."""
|
||||
|
||||
from chromadb.utils.embedding_functions.ollama_embedding_function import (
|
||||
OllamaEmbeddingFunction,
|
||||
)
|
||||
from pydantic import Field
|
||||
|
||||
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
|
||||
|
||||
|
||||
class OllamaProvider(BaseEmbeddingsProvider[OllamaEmbeddingFunction]):
    """Settings model for Ollama embeddings.

    Pairs chromadb's ``OllamaEmbeddingFunction`` with an endpoint URL and a
    model name. ``model_name`` has no default and must be supplied.
    """

    embedding_callable: type[OllamaEmbeddingFunction] = Field(
        default=OllamaEmbeddingFunction,
        description="Ollama embedding function class",
    )
    url: str = Field(
        default="http://localhost:11434/api/embeddings",
        validation_alias="OLLAMA_URL",
        description="Ollama API endpoint URL",
    )
    model_name: str = Field(
        validation_alias="OLLAMA_MODEL_NAME",
        description="Model name to use for embeddings",
    )
|
||||
@@ -1,19 +0,0 @@
|
||||
"""Type definitions for Ollama embedding providers."""
|
||||
|
||||
from typing import Annotated, Literal
|
||||
|
||||
from typing_extensions import Required, TypedDict
|
||||
|
||||
|
||||
class OllamaProviderConfig(TypedDict, total=False):
    """Keyword options accepted by the Ollama embedding provider.

    Declared with ``total=False``, so every key may be omitted at the type
    level. The ``Annotated`` metadata carries a literal value (presumably
    the provider default - confirm).
    """

    url: Annotated[str, "http://localhost:11434/api/embeddings"]
    model_name: str
|
||||
|
||||
|
||||
class OllamaProviderSpec(TypedDict, total=False):
    """Provider selector plus optional settings for Ollama embeddings.

    Only ``provider`` is mandatory; ``config`` may be left out.
    """

    # Discriminator value identifying this provider.
    provider: Required[Literal["ollama"]]
    # Provider-specific options.
    config: OllamaProviderConfig
|
||||
@@ -1,13 +0,0 @@
|
||||
"""ONNX embedding providers."""
|
||||
|
||||
from crewai.rag.embeddings.providers.onnx.onnx_provider import ONNXProvider
|
||||
from crewai.rag.embeddings.providers.onnx.types import (
|
||||
ONNXProviderConfig,
|
||||
ONNXProviderSpec,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"ONNXProvider",
|
||||
"ONNXProviderConfig",
|
||||
"ONNXProviderSpec",
|
||||
]
|
||||
@@ -1,19 +0,0 @@
|
||||
"""ONNX embeddings provider."""
|
||||
|
||||
from chromadb.utils.embedding_functions.onnx_mini_lm_l6_v2 import ONNXMiniLM_L6_V2
|
||||
from pydantic import Field
|
||||
|
||||
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
|
||||
|
||||
|
||||
class ONNXProvider(BaseEmbeddingsProvider[ONNXMiniLM_L6_V2]):
    """Settings model for ONNX embeddings.

    Pairs chromadb's ``ONNXMiniLM_L6_V2`` embedding function with an
    optional list of preferred execution providers.
    """

    embedding_callable: type[ONNXMiniLM_L6_V2] = Field(
        default=ONNXMiniLM_L6_V2,
        description="ONNX MiniLM embedding function class",
    )
    preferred_providers: list[str] | None = Field(
        default=None,
        validation_alias="ONNX_PREFERRED_PROVIDERS",
        description="Preferred ONNX execution providers",
    )
|
||||
@@ -1,18 +0,0 @@
|
||||
"""Type definitions for ONNX embedding providers."""
|
||||
|
||||
from typing import Literal
|
||||
|
||||
from typing_extensions import Required, TypedDict
|
||||
|
||||
|
||||
class ONNXProviderConfig(TypedDict, total=False):
    """Keyword options accepted by the ONNX embedding provider.

    Declared with ``total=False``, so the single key may be omitted.
    """

    # Execution providers to prefer, as a list of names.
    preferred_providers: list[str]
|
||||
|
||||
|
||||
class ONNXProviderSpec(TypedDict, total=False):
    """Provider selector plus optional settings for ONNX embeddings.

    Only ``provider`` is mandatory; ``config`` may be left out.
    """

    # Discriminator value identifying this provider.
    provider: Required[Literal["onnx"]]
    # Provider-specific options.
    config: ONNXProviderConfig
|
||||
@@ -1,15 +0,0 @@
|
||||
"""OpenAI embedding providers."""
|
||||
|
||||
from crewai.rag.embeddings.providers.openai.openai_provider import (
|
||||
OpenAIProvider,
|
||||
)
|
||||
from crewai.rag.embeddings.providers.openai.types import (
|
||||
OpenAIProviderConfig,
|
||||
OpenAIProviderSpec,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"OpenAIProvider",
|
||||
"OpenAIProviderConfig",
|
||||
"OpenAIProviderSpec",
|
||||
]
|
||||
@@ -1,58 +0,0 @@
|
||||
"""OpenAI embeddings provider."""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from chromadb.utils.embedding_functions.openai_embedding_function import (
|
||||
OpenAIEmbeddingFunction,
|
||||
)
|
||||
from pydantic import Field
|
||||
|
||||
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
|
||||
|
||||
|
||||
class OpenAIProvider(BaseEmbeddingsProvider[OpenAIEmbeddingFunction]):
    """Settings model for OpenAI embeddings.

    Pairs chromadb's ``OpenAIEmbeddingFunction`` with the options below;
    every field has a default.
    """

    embedding_callable: type[OpenAIEmbeddingFunction] = Field(
        default=OpenAIEmbeddingFunction,
        description="OpenAI embedding function class",
    )
    api_key: str | None = Field(
        default=None,
        validation_alias="OPENAI_API_KEY",
        description="OpenAI API key",
    )
    model_name: str = Field(
        default="text-embedding-ada-002",
        validation_alias="OPENAI_MODEL_NAME",
        description="Model name to use for embeddings",
    )
    api_base: str | None = Field(
        default=None,
        validation_alias="OPENAI_API_BASE",
        description="Base URL for API requests",
    )
    api_type: str | None = Field(
        default=None,
        validation_alias="OPENAI_API_TYPE",
        description="API type (e.g., 'azure')",
    )
    api_version: str | None = Field(
        default=None,
        validation_alias="OPENAI_API_VERSION",
        description="API version",
    )
    default_headers: dict[str, Any] | None = Field(
        default=None,
        description="Default headers for API requests",
    )
    dimensions: int | None = Field(
        default=None,
        validation_alias="OPENAI_DIMENSIONS",
        description="Embedding dimensions",
    )
    deployment_id: str | None = Field(
        default=None,
        validation_alias="OPENAI_DEPLOYMENT_ID",
        description="Azure deployment ID",
    )
    organization_id: str | None = Field(
        default=None,
        validation_alias="OPENAI_ORGANIZATION_ID",
        description="OpenAI organization ID",
    )
|
||||
@@ -1,26 +0,0 @@
|
||||
"""Type definitions for OpenAI embedding providers."""
|
||||
|
||||
from typing import Annotated, Any, Literal
|
||||
|
||||
from typing_extensions import Required, TypedDict
|
||||
|
||||
|
||||
class OpenAIProviderConfig(TypedDict, total=False):
    """Keyword options accepted by the OpenAI embedding provider.

    Declared with ``total=False``, so every key may be omitted. The
    ``Annotated`` metadata carries a literal value (presumably the
    provider default - confirm).
    """

    api_key: str
    model_name: Annotated[str, "text-embedding-ada-002"]
    api_base: str
    api_type: str
    api_version: str
    default_headers: dict[str, Any]
    dimensions: int
    deployment_id: str
    organization_id: str
|
||||
|
||||
|
||||
class OpenAIProviderSpec(TypedDict, total=False):
    """Provider selector plus optional settings for OpenAI embeddings.

    Only ``provider`` is mandatory; ``config`` may be left out.
    """

    # Discriminator value identifying this provider.
    provider: Required[Literal["openai"]]
    # Provider-specific options.
    config: OpenAIProviderConfig
|
||||
@@ -1,15 +0,0 @@
|
||||
"""OpenCLIP embedding providers."""
|
||||
|
||||
from crewai.rag.embeddings.providers.openclip.openclip_provider import (
|
||||
OpenCLIPProvider,
|
||||
)
|
||||
from crewai.rag.embeddings.providers.openclip.types import (
|
||||
OpenCLIPProviderConfig,
|
||||
OpenCLIPProviderSpec,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"OpenCLIPProvider",
|
||||
"OpenCLIPProviderConfig",
|
||||
"OpenCLIPProviderSpec",
|
||||
]
|
||||
@@ -1,32 +0,0 @@
|
||||
"""OpenCLIP embeddings provider."""
|
||||
|
||||
from chromadb.utils.embedding_functions.open_clip_embedding_function import (
|
||||
OpenCLIPEmbeddingFunction,
|
||||
)
|
||||
from pydantic import Field
|
||||
|
||||
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
|
||||
|
||||
|
||||
class OpenCLIPProvider(BaseEmbeddingsProvider[OpenCLIPEmbeddingFunction]):
    """Settings model for OpenCLIP embeddings.

    Pairs chromadb's ``OpenCLIPEmbeddingFunction`` with a model name,
    checkpoint and device; every field has a default.
    """

    embedding_callable: type[OpenCLIPEmbeddingFunction] = Field(
        default=OpenCLIPEmbeddingFunction,
        description="OpenCLIP embedding function class",
    )
    model_name: str = Field(
        default="ViT-B-32",
        validation_alias="OPENCLIP_MODEL_NAME",
        description="Model name to use",
    )
    checkpoint: str = Field(
        default="laion2b_s34b_b79k",
        validation_alias="OPENCLIP_CHECKPOINT",
        description="Model checkpoint",
    )
    device: str | None = Field(
        default="cpu",
        validation_alias="OPENCLIP_DEVICE",
        description="Device to run model on",
    )
|
||||
@@ -1,20 +0,0 @@
|
||||
"""Type definitions for OpenCLIP embedding providers."""
|
||||
|
||||
from typing import Annotated, Literal
|
||||
|
||||
from typing_extensions import Required, TypedDict
|
||||
|
||||
|
||||
class OpenCLIPProviderConfig(TypedDict, total=False):
    """Keyword options accepted by the OpenCLIP embedding provider.

    Declared with ``total=False``, so every key may be omitted. The
    ``Annotated`` metadata carries a literal value (presumably the
    provider default - confirm).
    """

    model_name: Annotated[str, "ViT-B-32"]
    checkpoint: Annotated[str, "laion2b_s34b_b79k"]
    device: Annotated[str, "cpu"]
|
||||
|
||||
|
||||
class OpenCLIPProviderSpec(TypedDict, total=False):
    """OpenCLIP provider specification.

    Declared with ``total=False`` so that the ``Required`` marker on
    ``provider`` is meaningful: ``provider`` is the only mandatory key and
    ``config`` may be omitted.
    """

    # Discriminator value identifying this provider.
    provider: Required[Literal["openclip"]]
    # Provider-specific options.
    config: OpenCLIPProviderConfig
|
||||
@@ -1,15 +0,0 @@
|
||||
"""Roboflow embedding providers."""
|
||||
|
||||
from crewai.rag.embeddings.providers.roboflow.roboflow_provider import (
|
||||
RoboflowProvider,
|
||||
)
|
||||
from crewai.rag.embeddings.providers.roboflow.types import (
|
||||
RoboflowProviderConfig,
|
||||
RoboflowProviderSpec,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"RoboflowProvider",
|
||||
"RoboflowProviderConfig",
|
||||
"RoboflowProviderSpec",
|
||||
]
|
||||
@@ -1,25 +0,0 @@
|
||||
"""Roboflow embeddings provider."""
|
||||
|
||||
from chromadb.utils.embedding_functions.roboflow_embedding_function import (
|
||||
RoboflowEmbeddingFunction,
|
||||
)
|
||||
from pydantic import Field
|
||||
|
||||
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
|
||||
|
||||
|
||||
class RoboflowProvider(BaseEmbeddingsProvider[RoboflowEmbeddingFunction]):
    """Settings model for Roboflow embeddings.

    Pairs chromadb's ``RoboflowEmbeddingFunction`` with an API key (empty
    string by default) and an API URL.
    """

    embedding_callable: type[RoboflowEmbeddingFunction] = Field(
        default=RoboflowEmbeddingFunction,
        description="Roboflow embedding function class",
    )
    api_key: str = Field(
        default="",
        validation_alias="ROBOFLOW_API_KEY",
        description="Roboflow API key",
    )
    api_url: str = Field(
        default="https://infer.roboflow.com",
        validation_alias="ROBOFLOW_API_URL",
        description="Roboflow API URL",
    )
|
||||
@@ -1,19 +0,0 @@
|
||||
"""Type definitions for Roboflow embedding providers."""
|
||||
|
||||
from typing import Annotated, Literal
|
||||
|
||||
from typing_extensions import Required, TypedDict
|
||||
|
||||
|
||||
class RoboflowProviderConfig(TypedDict, total=False):
    """Keyword options accepted by the Roboflow embedding provider.

    Declared with ``total=False``, so every key may be omitted. The
    ``Annotated`` metadata carries a literal value (presumably the
    provider default - confirm).
    """

    api_key: Annotated[str, ""]
    api_url: Annotated[str, "https://infer.roboflow.com"]
|
||||
|
||||
|
||||
class RoboflowProviderSpec(TypedDict, total=False):
    """Roboflow provider specification.

    Declared with ``total=False`` so that the ``Required`` marker on
    ``provider`` is meaningful: ``provider`` is the only mandatory key and
    ``config`` may be omitted.
    """

    # Discriminator value identifying this provider.
    provider: Required[Literal["roboflow"]]
    # Provider-specific options.
    config: RoboflowProviderConfig
|
||||
@@ -1,15 +0,0 @@
|
||||
"""SentenceTransformer embedding providers."""
|
||||
|
||||
from crewai.rag.embeddings.providers.sentence_transformer.sentence_transformer_provider import (
|
||||
SentenceTransformerProvider,
|
||||
)
|
||||
from crewai.rag.embeddings.providers.sentence_transformer.types import (
|
||||
SentenceTransformerProviderConfig,
|
||||
SentenceTransformerProviderSpec,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"SentenceTransformerProvider",
|
||||
"SentenceTransformerProviderConfig",
|
||||
"SentenceTransformerProviderSpec",
|
||||
]
|
||||
@@ -1,34 +0,0 @@
|
||||
"""SentenceTransformer embeddings provider."""
|
||||
|
||||
from chromadb.utils.embedding_functions.sentence_transformer_embedding_function import (
|
||||
SentenceTransformerEmbeddingFunction,
|
||||
)
|
||||
from pydantic import Field
|
||||
|
||||
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
|
||||
|
||||
|
||||
class SentenceTransformerProvider(
    BaseEmbeddingsProvider[SentenceTransformerEmbeddingFunction]
):
    """Settings model for SentenceTransformer embeddings.

    Pairs chromadb's ``SentenceTransformerEmbeddingFunction`` with a model
    name, device and normalization flag; every field has a default.
    """

    embedding_callable: type[SentenceTransformerEmbeddingFunction] = Field(
        default=SentenceTransformerEmbeddingFunction,
        description="SentenceTransformer embedding function class",
    )
    model_name: str = Field(
        default="all-MiniLM-L6-v2",
        validation_alias="SENTENCE_TRANSFORMER_MODEL_NAME",
        description="Model name to use",
    )
    device: str = Field(
        default="cpu",
        validation_alias="SENTENCE_TRANSFORMER_DEVICE",
        description="Device to run model on (cpu or cuda)",
    )
    normalize_embeddings: bool = Field(
        default=False,
        validation_alias="SENTENCE_TRANSFORMER_NORMALIZE_EMBEDDINGS",
        description="Whether to normalize embeddings",
    )
|
||||
@@ -1,20 +0,0 @@
|
||||
"""Type definitions for SentenceTransformer embedding providers."""
|
||||
|
||||
from typing import Annotated, Literal
|
||||
|
||||
from typing_extensions import Required, TypedDict
|
||||
|
||||
|
||||
class SentenceTransformerProviderConfig(TypedDict, total=False):
|
||||
"""Configuration for SentenceTransformer provider."""
|
||||
|
||||
model_name: Annotated[str, "all-MiniLM-L6-v2"]
|
||||
device: Annotated[str, "cpu"]
|
||||
normalize_embeddings: Annotated[bool, False]
|
||||
|
||||
|
||||
class SentenceTransformerProviderSpec(TypedDict):
    """Dictionary shape that selects and configures the SentenceTransformer provider."""

    # Discriminator: must be the literal string "sentence-transformer".
    provider: Required[Literal["sentence-transformer"]]
    # Provider-specific options.
    config: SentenceTransformerProviderConfig
|
||||
@@ -1,15 +0,0 @@
|
||||
"""Text2Vec embedding providers."""
|
||||
|
||||
from crewai.rag.embeddings.providers.text2vec.text2vec_provider import (
|
||||
Text2VecProvider,
|
||||
)
|
||||
from crewai.rag.embeddings.providers.text2vec.types import (
|
||||
Text2VecProviderConfig,
|
||||
Text2VecProviderSpec,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"Text2VecProvider",
|
||||
"Text2VecProviderConfig",
|
||||
"Text2VecProviderSpec",
|
||||
]
|
||||
@@ -1,22 +0,0 @@
|
||||
"""Text2Vec embeddings provider."""
|
||||
|
||||
from chromadb.utils.embedding_functions.text2vec_embedding_function import (
|
||||
Text2VecEmbeddingFunction,
|
||||
)
|
||||
from pydantic import Field
|
||||
|
||||
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
|
||||
|
||||
|
||||
class Text2VecProvider(BaseEmbeddingsProvider[Text2VecEmbeddingFunction]):
    """Settings model for the Text2Vec embedding function.

    Declares the embedding function class to use and the ``model_name``
    option, which also carries the ``TEXT2VEC_MODEL_NAME`` validation alias.
    """

    # The embedding function *class* (not an instance).
    embedding_callable: type[Text2VecEmbeddingFunction] = Field(
        description="Text2Vec embedding function class",
        default=Text2VecEmbeddingFunction,
    )

    # NOTE(review): the alias looks like an environment-variable name;
    # confirm how BaseEmbeddingsProvider resolves it.
    model_name: str = Field(
        validation_alias="TEXT2VEC_MODEL_NAME",
        description="Model name to use",
        default="shibing624/text2vec-base-chinese",
    )
|
||||
@@ -1,18 +0,0 @@
|
||||
"""Type definitions for Text2Vec embedding providers."""
|
||||
|
||||
from typing import Annotated, Literal
|
||||
|
||||
from typing_extensions import Required, TypedDict
|
||||
|
||||
|
||||
class Text2VecProviderConfig(TypedDict, total=False):
|
||||
"""Configuration for Text2Vec provider."""
|
||||
|
||||
model_name: Annotated[str, "shibing624/text2vec-base-chinese"]
|
||||
|
||||
|
||||
class Text2VecProviderSpec(TypedDict):
    """Dictionary shape that selects and configures the Text2Vec provider."""

    # Discriminator: must be the literal string "text2vec".
    provider: Required[Literal["text2vec"]]
    # Provider-specific options.
    config: Text2VecProviderConfig
|
||||
@@ -1,15 +0,0 @@
|
||||
"""VoyageAI embedding providers."""
|
||||
|
||||
from crewai.rag.embeddings.providers.voyageai.types import (
|
||||
VoyageAIProviderConfig,
|
||||
VoyageAIProviderSpec,
|
||||
)
|
||||
from crewai.rag.embeddings.providers.voyageai.voyageai_provider import (
|
||||
VoyageAIProvider,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"VoyageAIProvider",
|
||||
"VoyageAIProviderConfig",
|
||||
"VoyageAIProviderSpec",
|
||||
]
|
||||
@@ -1,58 +0,0 @@
|
||||
"""VoyageAI embedding function implementation."""
|
||||
|
||||
from typing import cast
|
||||
|
||||
from typing_extensions import Unpack
|
||||
|
||||
from crewai.rag.core.base_embeddings_callable import EmbeddingFunction
|
||||
from crewai.rag.core.types import Documents, Embeddings
|
||||
from crewai.rag.embeddings.providers.voyageai.types import VoyageAIProviderConfig
|
||||
|
||||
|
||||
class VoyageAIEmbeddingFunction(EmbeddingFunction[Documents]):
    """Callable that embeds documents through a ``voyageai.Client``."""

    def __init__(self, **kwargs: Unpack[VoyageAIProviderConfig]) -> None:
        """Store the configuration and build the VoyageAI client.

        Args:
            **kwargs: VoyageAI configuration. ``api_key`` is read with a
                plain subscript, so it must be present; ``max_retries``
                falls back to 0 and ``timeout`` to ``None``.

        Raises:
            ImportError: If the ``voyageai`` package cannot be imported.
        """
        # Imported here rather than at module level, so a missing package
        # surfaces as the install hint below when the class is instantiated.
        try:
            import voyageai  # type: ignore[import-not-found]
        except ImportError as e:
            raise ImportError(
                "voyageai is required for voyageai embeddings. "
                "Install it with: uv add voyageai"
            ) from e

        self._config = kwargs
        client_options = {
            "api_key": kwargs["api_key"],
            "max_retries": kwargs.get("max_retries", 0),
            "timeout": kwargs.get("timeout"),
        }
        self._client = voyageai.Client(**client_options)

    def __call__(self, input: Documents) -> Embeddings:
        """Embed the given documents.

        Args:
            input: Documents to embed. A single string is wrapped in a
                one-element list before being sent.

        Returns:
            The ``embeddings`` attribute of the client's response.
        """
        texts = [input] if isinstance(input, str) else input
        settings = self._config

        response = self._client.embed(
            texts=texts,
            model=settings.get("model", "voyage-2"),
            input_type=settings.get("input_type"),
            truncation=settings.get("truncation", True),
            output_dtype=settings.get("output_dtype"),
            output_dimension=settings.get("output_dimension"),
        )
        return cast(Embeddings, response.embeddings)
|
||||
@@ -1,25 +0,0 @@
|
||||
"""Type definitions for VoyageAI embedding providers."""
|
||||
|
||||
from typing import Annotated, Literal
|
||||
|
||||
from typing_extensions import Required, TypedDict
|
||||
|
||||
|
||||
class VoyageAIProviderConfig(TypedDict, total=False):
|
||||
"""Configuration for VoyageAI provider."""
|
||||
|
||||
api_key: str
|
||||
model: Annotated[str, "voyage-2"]
|
||||
input_type: str
|
||||
truncation: Annotated[bool, True]
|
||||
output_dtype: str
|
||||
output_dimension: int
|
||||
max_retries: Annotated[int, 0]
|
||||
timeout: float
|
||||
|
||||
|
||||
class VoyageAIProviderSpec(TypedDict):
    """Dictionary shape that selects and configures the VoyageAI provider."""

    # Discriminator: must be the literal string "voyageai".
    provider: Required[Literal["voyageai"]]
    # Provider-specific options.
    config: VoyageAIProviderConfig
|
||||
@@ -1,55 +0,0 @@
|
||||
"""Voyage AI embeddings provider."""
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
|
||||
from crewai.rag.embeddings.providers.voyageai.embedding_callable import (
|
||||
VoyageAIEmbeddingFunction,
|
||||
)
|
||||
|
||||
|
||||
class VoyageAIProvider(BaseEmbeddingsProvider[VoyageAIEmbeddingFunction]):
    """Settings model for the VoyageAI embedding function.

    Declares the embedding function class plus the options it accepts.
    ``api_key`` is the only option without a default. Every option also
    declares an upper-case ``VOYAGEAI_*`` validation alias.
    """

    # The embedding function *class* (not an instance).
    embedding_callable: type[VoyageAIEmbeddingFunction] = Field(
        description="Voyage AI embedding function class",
        default=VoyageAIEmbeddingFunction,
    )

    # NOTE(review): the aliases below look like environment-variable names;
    # confirm how BaseEmbeddingsProvider resolves them.
    model: str = Field(
        validation_alias="VOYAGEAI_MODEL",
        description="Model to use for embeddings",
        default="voyage-2",
    )
    # No default: a value must be supplied.
    api_key: str = Field(
        validation_alias="VOYAGEAI_API_KEY", description="Voyage AI API key"
    )
    input_type: str | None = Field(
        validation_alias="VOYAGEAI_INPUT_TYPE",
        description="Input type for embeddings",
        default=None,
    )
    truncation: bool = Field(
        validation_alias="VOYAGEAI_TRUNCATION",
        description="Whether to truncate inputs",
        default=True,
    )
    output_dtype: str | None = Field(
        validation_alias="VOYAGEAI_OUTPUT_DTYPE",
        description="Output data type",
        default=None,
    )
    output_dimension: int | None = Field(
        validation_alias="VOYAGEAI_OUTPUT_DIMENSION",
        description="Output dimension",
        default=None,
    )
    max_retries: int = Field(
        validation_alias="VOYAGEAI_MAX_RETRIES",
        description="Maximum retries for API calls",
        default=0,
    )
    timeout: float | None = Field(
        validation_alias="VOYAGEAI_TIMEOUT",
        description="Timeout for API calls",
        default=None,
    )
|
||||
@@ -1,73 +1,63 @@
|
||||
"""Type definitions for the embeddings module."""
|
||||
|
||||
from typing import Literal, TypeAlias
|
||||
from typing import Literal
|
||||
|
||||
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
|
||||
from crewai.rag.embeddings.providers.aws.types import BedrockProviderSpec
|
||||
from crewai.rag.embeddings.providers.cohere.types import CohereProviderSpec
|
||||
from crewai.rag.embeddings.providers.custom.types import CustomProviderSpec
|
||||
from crewai.rag.embeddings.providers.google.types import (
|
||||
GenerativeAiProviderSpec,
|
||||
VertexAIProviderSpec,
|
||||
)
|
||||
from crewai.rag.embeddings.providers.huggingface.types import HuggingFaceProviderSpec
|
||||
from crewai.rag.embeddings.providers.ibm.types import WatsonProviderSpec
|
||||
from crewai.rag.embeddings.providers.instructor.types import InstructorProviderSpec
|
||||
from crewai.rag.embeddings.providers.jina.types import JinaProviderSpec
|
||||
from crewai.rag.embeddings.providers.microsoft.types import AzureProviderSpec
|
||||
from crewai.rag.embeddings.providers.ollama.types import OllamaProviderSpec
|
||||
from crewai.rag.embeddings.providers.onnx.types import ONNXProviderSpec
|
||||
from crewai.rag.embeddings.providers.openai.types import OpenAIProviderSpec
|
||||
from crewai.rag.embeddings.providers.openclip.types import OpenCLIPProviderSpec
|
||||
from crewai.rag.embeddings.providers.roboflow.types import RoboflowProviderSpec
|
||||
from crewai.rag.embeddings.providers.sentence_transformer.types import (
|
||||
SentenceTransformerProviderSpec,
|
||||
)
|
||||
from crewai.rag.embeddings.providers.text2vec.types import Text2VecProviderSpec
|
||||
from crewai.rag.embeddings.providers.voyageai.types import VoyageAIProviderSpec
|
||||
from pydantic import BaseModel, Field, SecretStr
|
||||
|
||||
ProviderSpec = (
|
||||
AzureProviderSpec
|
||||
| BedrockProviderSpec
|
||||
| CohereProviderSpec
|
||||
| CustomProviderSpec
|
||||
| GenerativeAiProviderSpec
|
||||
| HuggingFaceProviderSpec
|
||||
| InstructorProviderSpec
|
||||
| JinaProviderSpec
|
||||
| OllamaProviderSpec
|
||||
| ONNXProviderSpec
|
||||
| OpenAIProviderSpec
|
||||
| OpenCLIPProviderSpec
|
||||
| RoboflowProviderSpec
|
||||
| SentenceTransformerProviderSpec
|
||||
| Text2VecProviderSpec
|
||||
| VertexAIProviderSpec
|
||||
| VoyageAIProviderSpec
|
||||
| WatsonProviderSpec
|
||||
)
|
||||
from crewai.rag.types import EmbeddingFunction
|
||||
|
||||
AllowedEmbeddingProviders = Literal[
|
||||
"azure",
|
||||
"amazon-bedrock",
|
||||
EmbeddingProvider = Literal[
|
||||
"openai",
|
||||
"cohere",
|
||||
"custom",
|
||||
"ollama",
|
||||
"huggingface",
|
||||
"sentence-transformer",
|
||||
"instructor",
|
||||
"google-palm",
|
||||
"google-generativeai",
|
||||
"google-vertex",
|
||||
"huggingface",
|
||||
"instructor",
|
||||
"amazon-bedrock",
|
||||
"jina",
|
||||
"ollama",
|
||||
"onnx",
|
||||
"openai",
|
||||
"openclip",
|
||||
"roboflow",
|
||||
"sentence-transformer",
|
||||
"openclip",
|
||||
"text2vec",
|
||||
"voyageai",
|
||||
"onnx",
|
||||
"watson",
|
||||
]
|
||||
"""Supported embedding providers.
|
||||
|
||||
EmbedderConfig: TypeAlias = (
|
||||
ProviderSpec | BaseEmbeddingsProvider | type[BaseEmbeddingsProvider]
|
||||
)
|
||||
These correspond to the embedding functions available in ChromaDB's
|
||||
embedding_functions module. Each provider has specific requirements
|
||||
and configuration options.
|
||||
"""
|
||||
|
||||
|
||||
class EmbeddingOptions(BaseModel):
    """Configuration options for embedding providers.

    Generic attributes that can be passed to get_embedding_function
    to configure various embedding providers.
    """

    # `model_name` begins with "model_", which pydantic releases before 2.10
    # treat as a protected namespace and report with a UserWarning when the
    # class is created. Clearing the protected namespaces keeps the public
    # field name unchanged while silencing that warning on every supported
    # pydantic version.
    model_config = {"protected_namespaces": ()}

    # Required: which embedding provider to use.
    provider: EmbeddingProvider = Field(
        ..., description="Embedding provider name (e.g., 'openai', 'cohere', 'onnx')"
    )
    # Optional model identifier; None when not given.
    model_name: str | None = Field(
        default=None, description="Model name for the embedding provider"
    )
    # Optional credential, held as SecretStr.
    api_key: SecretStr | None = Field(
        default=None, description="API key for the embedding provider"
    )
|
||||
|
||||
|
||||
class EmbeddingConfig(BaseModel):
    """Wrapper model holding the embedding setup in a single field.

    The ``function`` field accepts either an ``EmbeddingFunction`` or an
    ``EmbeddingOptions`` instance.

    Attributes:
        function: Either an EmbeddingFunction or EmbeddingOptions.
    """

    # Required field; no default is declared.
    function: EmbeddingFunction | EmbeddingOptions
|
||||
|
||||
@@ -48,7 +48,6 @@ class QdrantClient(BaseClient):
|
||||
embedding_function: EmbeddingFunction | AsyncEmbeddingFunction,
|
||||
default_limit: int = 5,
|
||||
default_score_threshold: float = 0.6,
|
||||
default_batch_size: int = 100,
|
||||
) -> None:
|
||||
"""Initialize QdrantClient with client and embedding function.
|
||||
|
||||
@@ -57,13 +56,11 @@ class QdrantClient(BaseClient):
|
||||
embedding_function: Embedding function for text to vector conversion.
|
||||
default_limit: Default number of results to return in searches.
|
||||
default_score_threshold: Default minimum score for search results.
|
||||
default_batch_size: Default batch size for adding documents.
|
||||
"""
|
||||
self.client = client
|
||||
self.embedding_function = embedding_function
|
||||
self.default_limit = default_limit
|
||||
self.default_score_threshold = default_score_threshold
|
||||
self.default_batch_size = default_batch_size
|
||||
|
||||
def create_collection(self, **kwargs: Unpack[QdrantCollectionCreateParams]) -> None:
|
||||
"""Create a new collection in Qdrant.
|
||||
@@ -237,7 +234,6 @@ class QdrantClient(BaseClient):
|
||||
Keyword Args:
|
||||
collection_name: The name of the collection to add documents to.
|
||||
documents: List of BaseRecord dicts containing document data.
|
||||
batch_size: Optional batch size for processing documents (default: 100)
|
||||
|
||||
Raises:
|
||||
ValueError: If collection doesn't exist or documents list is empty.
|
||||
@@ -253,7 +249,6 @@ class QdrantClient(BaseClient):
|
||||
|
||||
collection_name = kwargs["collection_name"]
|
||||
documents = kwargs["documents"]
|
||||
batch_size = kwargs.get("batch_size", self.default_batch_size)
|
||||
|
||||
if not documents:
|
||||
raise ValueError("Documents list cannot be empty")
|
||||
@@ -261,20 +256,19 @@ class QdrantClient(BaseClient):
|
||||
if not self.client.collection_exists(collection_name):
|
||||
raise ValueError(f"Collection '{collection_name}' does not exist")
|
||||
|
||||
for i in range(0, len(documents), batch_size):
|
||||
batch_docs = documents[i : min(i + batch_size, len(documents))]
|
||||
points = []
|
||||
for doc in batch_docs:
|
||||
if _is_async_embedding_function(self.embedding_function):
|
||||
raise TypeError(
|
||||
"Async embedding function cannot be used with sync add_documents. "
|
||||
"Use aadd_documents instead."
|
||||
)
|
||||
sync_fn = cast(EmbeddingFunction, self.embedding_function)
|
||||
embedding = sync_fn(doc["content"])
|
||||
point = _create_point_from_document(doc, embedding)
|
||||
points.append(point)
|
||||
self.client.upsert(collection_name=collection_name, points=points)
|
||||
points = []
|
||||
for doc in documents:
|
||||
if _is_async_embedding_function(self.embedding_function):
|
||||
raise TypeError(
|
||||
"Async embedding function cannot be used with sync add_documents. "
|
||||
"Use aadd_documents instead."
|
||||
)
|
||||
sync_fn = cast(EmbeddingFunction, self.embedding_function)
|
||||
embedding = sync_fn(doc["content"])
|
||||
point = _create_point_from_document(doc, embedding)
|
||||
points.append(point)
|
||||
|
||||
self.client.upsert(collection_name=collection_name, points=points)
|
||||
|
||||
async def aadd_documents(self, **kwargs: Unpack[BaseCollectionAddParams]) -> None:
|
||||
"""Add documents with their embeddings to a collection asynchronously.
|
||||
@@ -282,7 +276,6 @@ class QdrantClient(BaseClient):
|
||||
Keyword Args:
|
||||
collection_name: The name of the collection to add documents to.
|
||||
documents: List of BaseRecord dicts containing document data.
|
||||
batch_size: Optional batch size for processing documents (default: 100)
|
||||
|
||||
Raises:
|
||||
ValueError: If collection doesn't exist or documents list is empty.
|
||||
@@ -298,7 +291,6 @@ class QdrantClient(BaseClient):
|
||||
|
||||
collection_name = kwargs["collection_name"]
|
||||
documents = kwargs["documents"]
|
||||
batch_size = kwargs.get("batch_size", self.default_batch_size)
|
||||
|
||||
if not documents:
|
||||
raise ValueError("Documents list cannot be empty")
|
||||
@@ -306,19 +298,18 @@ class QdrantClient(BaseClient):
|
||||
if not await self.client.collection_exists(collection_name):
|
||||
raise ValueError(f"Collection '{collection_name}' does not exist")
|
||||
|
||||
for i in range(0, len(documents), batch_size):
|
||||
batch_docs = documents[i : min(i + batch_size, len(documents))]
|
||||
points = []
|
||||
for doc in batch_docs:
|
||||
if _is_async_embedding_function(self.embedding_function):
|
||||
async_fn = cast(AsyncEmbeddingFunction, self.embedding_function)
|
||||
embedding = await async_fn(doc["content"])
|
||||
else:
|
||||
sync_fn = cast(EmbeddingFunction, self.embedding_function)
|
||||
embedding = sync_fn(doc["content"])
|
||||
point = _create_point_from_document(doc, embedding)
|
||||
points.append(point)
|
||||
await self.client.upsert(collection_name=collection_name, points=points)
|
||||
points = []
|
||||
for doc in documents:
|
||||
if _is_async_embedding_function(self.embedding_function):
|
||||
async_fn = cast(AsyncEmbeddingFunction, self.embedding_function)
|
||||
embedding = await async_fn(doc["content"])
|
||||
else:
|
||||
sync_fn = cast(EmbeddingFunction, self.embedding_function)
|
||||
embedding = sync_fn(doc["content"])
|
||||
point = _create_point_from_document(doc, embedding)
|
||||
points.append(point)
|
||||
|
||||
await self.client.upsert(collection_name=collection_name, points=points)
|
||||
|
||||
def search(
|
||||
self, **kwargs: Unpack[BaseCollectionSearchParams]
|
||||
|
||||
@@ -22,5 +22,4 @@ def create_client(config: QdrantConfig) -> QdrantClient:
|
||||
embedding_function=config.embedding_function,
|
||||
default_limit=config.limit,
|
||||
default_score_threshold=config.score_threshold,
|
||||
default_batch_size=config.batch_size,
|
||||
)
|
||||
|
||||
@@ -1,9 +1,6 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any
|
||||
|
||||
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
|
||||
from crewai.rag.embeddings.types import ProviderSpec
|
||||
|
||||
|
||||
class BaseRAGStorage(ABC):
|
||||
"""
|
||||
@@ -16,7 +13,7 @@ class BaseRAGStorage(ABC):
|
||||
self,
|
||||
type: str,
|
||||
allow_reset: bool = True,
|
||||
embedder_config: ProviderSpec | BaseEmbeddingsProvider | None = None,
|
||||
embedder_config: dict[str, Any] | None = None,
|
||||
crew: Any = None,
|
||||
):
|
||||
self.type = type
|
||||
|
||||
@@ -24,7 +24,8 @@ class BaseRecord(TypedDict, total=False):
|
||||
)
|
||||
|
||||
|
||||
Embeddings: TypeAlias = list[list[float]]
|
||||
DenseVector: TypeAlias = list[float]
|
||||
IntVector: TypeAlias = list[int]
|
||||
|
||||
EmbeddingFunction: TypeAlias = Callable[..., Any]
|
||||
|
||||
|
||||
@@ -2,94 +2,29 @@
|
||||
|
||||
import importlib
|
||||
from types import ModuleType
|
||||
from typing import Annotated, Any, TypeAlias
|
||||
|
||||
from pydantic import AfterValidator, TypeAdapter
|
||||
from typing_extensions import deprecated
|
||||
|
||||
|
||||
@deprecated(
|
||||
"Not needed when using `crewai.utilities.import_utils.import_and_validate_definition`"
|
||||
)
|
||||
class OptionalDependencyError(ImportError):
|
||||
"""Exception raised when an optional dependency is not installed."""
|
||||
|
||||
|
||||
@deprecated(
|
||||
"Use `crewai.utilities.import_utils.import_and_validate_definition` instead."
|
||||
)
|
||||
def require(name: str, *, purpose: str, attr: str | None = None) -> ModuleType | Any:
|
||||
"""Import a module, optionally returning a specific attribute.
|
||||
def require(name: str, *, purpose: str) -> ModuleType:
|
||||
"""Import a module, raising a helpful error if it's not installed.
|
||||
|
||||
Args:
|
||||
name: The module name to import.
|
||||
purpose: Description of what requires this dependency.
|
||||
attr: Optional attribute name to get from the module.
|
||||
|
||||
Returns:
|
||||
The imported module or the specified attribute.
|
||||
The imported module.
|
||||
|
||||
Raises:
|
||||
OptionalDependencyError: If the module is not installed.
|
||||
AttributeError: If the specified attribute doesn't exist.
|
||||
"""
|
||||
try:
|
||||
module = importlib.import_module(name)
|
||||
if attr is not None:
|
||||
return getattr(module, attr)
|
||||
return module
|
||||
return importlib.import_module(name)
|
||||
except ImportError as exc:
|
||||
package_name = name.split(".")[0]
|
||||
raise OptionalDependencyError(
|
||||
f"{purpose} requires the optional dependency '{name}'.\n"
|
||||
f"Install it with: uv add {package_name}"
|
||||
f"Install it with: uv add {name}"
|
||||
) from exc
|
||||
except AttributeError as exc:
|
||||
raise AttributeError(f"Module '{name}' has no attribute '{attr}'") from exc
|
||||
|
||||
|
||||
def validate_import_path(v: str) -> Any:
    """Import and return the class/function named by a dotted import path.

    The path is split at its last dot: everything before it is imported as a
    module, and the final component is looked up as an attribute of that
    module.

    Args:
        v: Import path string in the format 'module.path.ClassName'.

    Returns:
        The imported class or function.

    Raises:
        ValueError: If the path has no module or attribute part, starts with
            a dot (relative import), names a module that cannot be imported,
            or names an attribute the module does not have.
    """
    module_path, _, attr = v.rpartition(".")
    if not module_path or not attr:
        raise ValueError(f"import_path '{v}' must be of the form 'module.ClassName'")

    # importlib.import_module raises TypeError for a leading-dot name when no
    # package is given; report it as the documented ValueError instead.
    if module_path.startswith("."):
        raise ValueError(
            f"import_path '{v}' must be an absolute path of the form 'module.ClassName'"
        )

    try:
        mod = importlib.import_module(module_path)
    except ImportError as exc:
        # The install hint names the top-level package, not the submodule.
        package = module_path.split(".", 1)[0]
        raise ValueError(
            f"Package '{package}' could not be imported. Install it with: uv add {package}"
        ) from exc

    # Single lookup instead of hasattr() followed by getattr().
    try:
        return getattr(mod, attr)
    except AttributeError:
        raise ValueError(
            f"Attribute '{attr}' not found in module '{module_path}'"
        ) from None
|
||||
|
||||
|
||||
# Any value, post-processed by validate_import_path after validation.
ImportedDefinition: TypeAlias = Annotated[Any, AfterValidator(validate_import_path)]

# Module-level adapter, built once and reused by import_and_validate_definition.
adapter = TypeAdapter(ImportedDefinition)


def import_and_validate_definition(v: str) -> Any:
    """Resolve a dotted import path through the module-level pydantic adapter.

    Args:
        v: Import path string in the format 'module.path.ClassName'.

    Returns:
        Whatever the adapter's validation yields for the path, i.e. the
        imported class or function.
    """
    resolved = adapter.validate_python(v)
    return resolved
|
||||
|
||||
@@ -1,298 +0,0 @@
|
||||
interactions:
|
||||
- request:
|
||||
body: '{"messages": [{"role": "system", "content": "You are Test Agent. Test backstory\nYour
|
||||
personal goal is: Test goal\nTo give my best complete final answer to the task
|
||||
respond using the exact following format:\n\nThought: I now can give a great
|
||||
answer\nFinal Answer: Your final answer must be the great and the most complete
|
||||
as possible, it must be outcome described.\n\nI MUST use these formats, my job
|
||||
depends on it!"}, {"role": "user", "content": "\nCurrent Task: Say hello to
|
||||
the world\n\nThis is the expected criteria for your final answer: hello world\nyou
|
||||
MUST return the actual complete content as the final answer, not a summary.\n\nBegin!
|
||||
This is VERY important to you, use the tools available and give your best Final
|
||||
Answer, your job depends on it!\n\nThought:"}], "model": "gpt-4o-mini", "stop":
|
||||
["\nObservation:"]}'
|
||||
headers:
|
||||
accept:
|
||||
- application/json
|
||||
accept-encoding:
|
||||
- gzip, deflate
|
||||
connection:
|
||||
- keep-alive
|
||||
content-length:
|
||||
- '825'
|
||||
content-type:
|
||||
- application/json
|
||||
host:
|
||||
- api.openai.com
|
||||
user-agent:
|
||||
- OpenAI/Python 1.93.0
|
||||
x-stainless-arch:
|
||||
- arm64
|
||||
x-stainless-async:
|
||||
- 'false'
|
||||
x-stainless-lang:
|
||||
- python
|
||||
x-stainless-os:
|
||||
- MacOS
|
||||
x-stainless-package-version:
|
||||
- 1.93.0
|
||||
x-stainless-raw-response:
|
||||
- 'true'
|
||||
x-stainless-read-timeout:
|
||||
- '600.0'
|
||||
x-stainless-retry-count:
|
||||
- '0'
|
||||
x-stainless-runtime:
|
||||
- CPython
|
||||
x-stainless-runtime-version:
|
||||
- 3.12.9
|
||||
method: POST
|
||||
uri: https://api.openai.com/v1/chat/completions
|
||||
response:
|
||||
body:
|
||||
string: !!binary |
|
||||
H4sIAAAAAAAAAwAAAP//jFJNj9MwEL3nV4x8blC/0nRzA8QKuCBxQFrBKnLtSeLF8Vi2s11Y9b8j
|
||||
O90m5UPiEinz5j2/NzPPGQBTklXARMeD6K3O33588/nm/d27T5sv6wdp1OGu/9l1T9tbWzQlW0QG
|
||||
HR5QhBfWK0G91RgUmREWDnnAqLoqi/1uv19vVgnoSaKOtNaGfEt5r4zK18v1Nl+W+Wp/ZnekBHpW
|
||||
wdcMAOA5faNPI/GJVbBcvFR69J63yKpLEwBzpGOFce+VD9wEtphAQSagSdY/gKEjCG6gVY8IHNpo
|
||||
G7jxR3QA38ytMlzD6/RfQYdaExzJaTkXdNgMnsdQZtB6BnBjKPA4lBTl/oycLuY1tdbRwf9GZY0y
|
||||
yne1Q+7JRKM+kGUJPWUA92lIw1VuZh31NtSBvmN6blWUox6bdjNDN2cwUOB6Vi/Po73WqyUGrrSf
|
||||
jZkJLjqUE3XaCR+kohmQzVL/6eZv2mNyZdr/kZ8AIdAGlLV1KJW4Tjy1OYyn+6+2y5STYebRPSqB
|
||||
dVDo4iYkNnzQ40Ex/8MH7OtGmRaddWq8qsbWxW7Jmx0WxQ3LTtkvAAAA//8DAIkIBqtjAwAA
|
||||
headers:
|
||||
CF-RAY:
|
||||
- 983f8c061b6ec487-SJC
|
||||
Connection:
|
||||
- keep-alive
|
||||
Content-Encoding:
|
||||
- gzip
|
||||
Content-Type:
|
||||
- application/json
|
||||
Date:
|
||||
- Wed, 24 Sep 2025 04:30:32 GMT
|
||||
Server:
|
||||
- cloudflare
|
||||
Set-Cookie:
|
||||
- __cf_bm=JDjpnzx5y8PJaJDQcCeX6MeBt8BOGuL79pd.ca5mqvE-1758688232-1.0.1.1-5VN5hj5LzEZFfkotBaZ_dbUITo_YB7RLsFOlQc.0MdSZOsz7WhNkH.s7H700L12Yi8nHGW44BgIwCF3uWx1w4PRBqrb1IVH3FkeV.QwCTaA;
|
||||
path=/; expires=Wed, 24-Sep-25 05:00:32 GMT; domain=.api.openai.com; HttpOnly;
|
||||
Secure; SameSite=None
|
||||
- _cfuvid=b5n8BZZDRtHA4TrxQ1RDeEdtQBzhstjP6u21LYM8L94-1758688232142-0.0.1.1-604800000;
|
||||
path=/; domain=.api.openai.com; HttpOnly; Secure; SameSite=None
|
||||
Strict-Transport-Security:
|
||||
- max-age=31536000; includeSubDomains; preload
|
||||
Transfer-Encoding:
|
||||
- chunked
|
||||
X-Content-Type-Options:
|
||||
- nosniff
|
||||
access-control-expose-headers:
|
||||
- X-Request-ID
|
||||
alt-svc:
|
||||
- h3=":443"; ma=86400
|
||||
cf-cache-status:
|
||||
- DYNAMIC
|
||||
openai-organization:
|
||||
- crewai-iuxna1
|
||||
openai-processing-ms:
|
||||
- '535'
|
||||
openai-project:
|
||||
- proj_xitITlrFeen7zjNSzML82h9x
|
||||
openai-version:
|
||||
- '2020-10-01'
|
||||
x-envoy-upstream-service-time:
|
||||
- '562'
|
||||
x-openai-proxy-wasm:
|
||||
- v0.1
|
||||
x-ratelimit-limit-project-tokens:
|
||||
- '150000000'
|
||||
x-ratelimit-limit-requests:
|
||||
- '30000'
|
||||
x-ratelimit-limit-tokens:
|
||||
- '150000000'
|
||||
x-ratelimit-remaining-project-tokens:
|
||||
- '149999827'
|
||||
x-ratelimit-remaining-requests:
|
||||
- '29999'
|
||||
x-ratelimit-remaining-tokens:
|
||||
- '149999830'
|
||||
x-ratelimit-reset-project-tokens:
|
||||
- 0s
|
||||
x-ratelimit-reset-requests:
|
||||
- 2ms
|
||||
x-ratelimit-reset-tokens:
|
||||
- 0s
|
||||
x-request-id:
|
||||
- req_af61ab9d53bf400baf30c5bc5a7e2102
|
||||
status:
|
||||
code: 200
|
||||
message: OK
|
||||
- request:
|
||||
body: null
|
||||
headers:
|
||||
Connection:
|
||||
- close
|
||||
Host:
|
||||
- api.scarf.sh
|
||||
User-Agent:
|
||||
- CrewAI-Python/0.193.2
|
||||
method: GET
|
||||
uri: https://api.scarf.sh/v2/packages/CrewAI/crewai/docs/00f2dad1-8334-4a39-934e-003b2e1146db
|
||||
response:
|
||||
body:
|
||||
string: ''
|
||||
headers:
|
||||
Connection:
|
||||
- close
|
||||
Date:
|
||||
- Wed, 24 Sep 2025 04:47:59 GMT
|
||||
Strict-Transport-Security:
|
||||
- max-age=15724800; includeSubDomains
|
||||
Transfer-Encoding:
|
||||
- chunked
|
||||
x-scarf-request-id:
|
||||
- 4158376f-cb1c-46fe-a14c-dee366b955e2
|
||||
status:
|
||||
code: 401
|
||||
message: Unauthorized
|
||||
- request:
|
||||
body: '{"trace_id": "06e1250e-6d88-4c64-abe5-deabde573ae1", "execution_type":
|
||||
"crew", "user_identifier": null, "execution_context": {"crew_fingerprint": null,
|
||||
"crew_name": "crew", "flow_name": null, "crewai_version": "0.193.2", "privacy_level":
|
||||
"standard"}, "execution_metadata": {"expected_duration_estimate": 300, "agent_count":
|
||||
0, "task_count": 0, "flow_method_count": 0, "execution_started_at": "2025-09-24T04:50:23.219835+00:00"}}'
|
||||
headers:
|
||||
Accept:
|
||||
- '*/*'
|
||||
Accept-Encoding:
|
||||
- gzip, deflate
|
||||
Connection:
|
||||
- keep-alive
|
||||
Content-Length:
|
||||
- '428'
|
||||
Content-Type:
|
||||
- application/json
|
||||
User-Agent:
|
||||
- CrewAI-CLI/0.193.2
|
||||
X-Crewai-Organization-Id:
|
||||
- d3a3d10c-35db-423f-a7a4-c026030ba64d
|
||||
X-Crewai-Version:
|
||||
- 0.193.2
|
||||
method: POST
|
||||
uri: http://localhost:3000/crewai_plus/api/v1/tracing/batches
|
||||
response:
|
||||
body:
|
||||
string: '{"error":"bad_credentials","message":"Bad credentials"}'
|
||||
headers:
|
||||
Content-Length:
|
||||
- '55'
|
||||
cache-control:
|
||||
- no-cache
|
||||
content-security-policy:
|
||||
- 'default-src ''self'' *.crewai.com crewai.com; script-src ''self'' ''unsafe-inline''
|
||||
*.crewai.com crewai.com https://cdn.jsdelivr.net/npm/apexcharts https://www.gstatic.com
|
||||
https://run.pstmn.io https://share.descript.com/; style-src ''self'' ''unsafe-inline''
|
||||
*.crewai.com crewai.com https://cdn.jsdelivr.net/npm/apexcharts; img-src ''self''
|
||||
data: *.crewai.com crewai.com https://zeus.tools.crewai.com https://dashboard.tools.crewai.com
|
||||
https://cdn.jsdelivr.net; font-src ''self'' data: *.crewai.com crewai.com;
|
||||
connect-src ''self'' *.crewai.com crewai.com https://zeus.tools.crewai.com
|
||||
https://connect.useparagon.com/ https://zeus.useparagon.com/* https://*.useparagon.com/*
|
||||
https://run.pstmn.io https://connect.tools.crewai.com/ ws://localhost:3036
|
||||
wss://localhost:3036; frame-src ''self'' *.crewai.com crewai.com https://connect.useparagon.com/
|
||||
https://zeus.tools.crewai.com https://zeus.useparagon.com/* https://connect.tools.crewai.com/
|
||||
https://www.youtube.com https://share.descript.com'
|
||||
content-type:
|
||||
- application/json; charset=utf-8
|
||||
permissions-policy:
|
||||
- camera=(), microphone=(self), geolocation=()
|
||||
referrer-policy:
|
||||
- strict-origin-when-cross-origin
|
||||
server-timing:
|
||||
- cache_read.active_support;dur=0.37, sql.active_record;dur=30.81, cache_generate.active_support;dur=29.14,
|
||||
cache_write.active_support;dur=0.14, cache_read_multi.active_support;dur=0.19,
|
||||
start_processing.action_controller;dur=0.00, process_action.action_controller;dur=2.74
|
||||
vary:
|
||||
- Accept
|
||||
x-content-type-options:
|
||||
- nosniff
|
||||
x-frame-options:
|
||||
- SAMEORIGIN
|
||||
x-permitted-cross-domain-policies:
|
||||
- none
|
||||
x-request-id:
|
||||
- 2420790e-9669-4235-851c-468185b6ef40
|
||||
x-runtime:
|
||||
- '0.102516'
|
||||
x-xss-protection:
|
||||
- 1; mode=block
|
||||
status:
|
||||
code: 401
|
||||
message: Unauthorized
|
||||
- request:
|
||||
body: '{"status": "failed", "failure_reason": "Error sending events to backend"}'
|
||||
headers:
|
||||
Accept:
|
||||
- '*/*'
|
||||
Accept-Encoding:
|
||||
- gzip, deflate
|
||||
Connection:
|
||||
- keep-alive
|
||||
Content-Length:
|
||||
- '73'
|
||||
Content-Type:
|
||||
- application/json
|
||||
User-Agent:
|
||||
- CrewAI-CLI/0.193.2
|
||||
X-Crewai-Organization-Id:
|
||||
- d3a3d10c-35db-423f-a7a4-c026030ba64d
|
||||
X-Crewai-Version:
|
||||
- 0.193.2
|
||||
method: PATCH
|
||||
uri: http://localhost:3000/crewai_plus/api/v1/tracing/batches/None
|
||||
response:
|
||||
body:
|
||||
string: '{"error":"bad_credentials","message":"Bad credentials"}'
|
||||
headers:
|
||||
Content-Length:
|
||||
- '55'
|
||||
cache-control:
|
||||
- no-cache
|
||||
content-security-policy:
|
||||
- 'default-src ''self'' *.crewai.com crewai.com; script-src ''self'' ''unsafe-inline''
|
||||
*.crewai.com crewai.com https://cdn.jsdelivr.net/npm/apexcharts https://www.gstatic.com
|
||||
https://run.pstmn.io https://share.descript.com/; style-src ''self'' ''unsafe-inline''
|
||||
*.crewai.com crewai.com https://cdn.jsdelivr.net/npm/apexcharts; img-src ''self''
|
||||
data: *.crewai.com crewai.com https://zeus.tools.crewai.com https://dashboard.tools.crewai.com
|
||||
https://cdn.jsdelivr.net; font-src ''self'' data: *.crewai.com crewai.com;
|
||||
connect-src ''self'' *.crewai.com crewai.com https://zeus.tools.crewai.com
|
||||
https://connect.useparagon.com/ https://zeus.useparagon.com/* https://*.useparagon.com/*
|
||||
https://run.pstmn.io https://connect.tools.crewai.com/ ws://localhost:3036
|
||||
wss://localhost:3036; frame-src ''self'' *.crewai.com crewai.com https://connect.useparagon.com/
|
||||
https://zeus.tools.crewai.com https://zeus.useparagon.com/* https://connect.tools.crewai.com/
|
||||
https://www.youtube.com https://share.descript.com'
|
||||
content-type:
|
||||
- application/json; charset=utf-8
|
||||
permissions-policy:
|
||||
- camera=(), microphone=(self), geolocation=()
|
||||
referrer-policy:
|
||||
- strict-origin-when-cross-origin
|
||||
server-timing:
|
||||
- cache_read.active_support;dur=0.06, sql.active_record;dur=3.86, cache_generate.active_support;dur=4.28,
|
||||
cache_write.active_support;dur=0.15, cache_read_multi.active_support;dur=0.12,
|
||||
start_processing.action_controller;dur=0.00, process_action.action_controller;dur=1.70
|
||||
vary:
|
||||
- Accept
|
||||
x-content-type-options:
|
||||
- nosniff
|
||||
x-frame-options:
|
||||
- SAMEORIGIN
|
||||
x-permitted-cross-domain-policies:
|
||||
- none
|
||||
x-request-id:
|
||||
- 1750d141-c48f-47f1-b8b4-130195437d22
|
||||
x-runtime:
|
||||
- '0.043849'
|
||||
x-xss-protection:
|
||||
- 1; mode=block
|
||||
status:
|
||||
code: 401
|
||||
message: Unauthorized
|
||||
version: 1
|
||||
@@ -11,7 +11,7 @@ from crewai.knowledge.storage.knowledge_storage import ( # type: ignore[import-
|
||||
|
||||
@patch("crewai.knowledge.storage.knowledge_storage.get_rag_client")
|
||||
@patch("crewai.knowledge.storage.knowledge_storage.create_client")
|
||||
@patch("crewai.knowledge.storage.knowledge_storage.build_embedder")
|
||||
@patch("crewai.knowledge.storage.knowledge_storage.get_embedding_function")
|
||||
def test_knowledge_storage_uses_rag_client(
|
||||
mock_get_embedding: MagicMock,
|
||||
mock_create_client: MagicMock,
|
||||
@@ -122,7 +122,7 @@ def test_search_error_handling(mock_get_client: MagicMock) -> None:
|
||||
|
||||
|
||||
@patch("crewai.knowledge.storage.knowledge_storage.get_rag_client")
|
||||
@patch("crewai.knowledge.storage.knowledge_storage.build_embedder")
|
||||
@patch("crewai.knowledge.storage.knowledge_storage.get_embedding_function")
|
||||
def test_embedding_configuration_flow(
|
||||
mock_get_embedding: MagicMock, mock_get_client: MagicMock
|
||||
) -> None:
|
||||
|
||||
@@ -34,30 +34,6 @@ def client(mock_chromadb_client) -> ChromaDBClient:
|
||||
return client
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def client_with_batch_size(mock_chromadb_client) -> ChromaDBClient:
|
||||
"""Create a ChromaDBClient instance with custom batch size for testing."""
|
||||
mock_embedding = Mock()
|
||||
client = ChromaDBClient(
|
||||
client=mock_chromadb_client,
|
||||
embedding_function=mock_embedding,
|
||||
default_batch_size=2,
|
||||
)
|
||||
return client
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def async_client_with_batch_size(mock_async_chromadb_client) -> ChromaDBClient:
|
||||
"""Create a ChromaDBClient instance with async client and custom batch size for testing."""
|
||||
mock_embedding = Mock()
|
||||
client = ChromaDBClient(
|
||||
client=mock_async_chromadb_client,
|
||||
embedding_function=mock_embedding,
|
||||
default_batch_size=2,
|
||||
)
|
||||
return client
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def async_client(mock_async_chromadb_client) -> ChromaDBClient:
|
||||
"""Create a ChromaDBClient instance with async client for testing."""
|
||||
@@ -636,139 +612,3 @@ class TestChromaDBClient:
|
||||
await async_client.areset()
|
||||
|
||||
mock_async_chromadb_client.reset.assert_called_once_with()
|
||||
|
||||
def test_add_documents_with_batch_size(
|
||||
self, client_with_batch_size, mock_chromadb_client
|
||||
) -> None:
|
||||
"""Test add_documents with batch size splits documents into batches."""
|
||||
mock_collection = Mock()
|
||||
mock_chromadb_client.get_or_create_collection.return_value = mock_collection
|
||||
|
||||
documents: list[BaseRecord] = [
|
||||
{"doc_id": "id1", "content": "Document 1", "metadata": {"source": "test1"}},
|
||||
{"doc_id": "id2", "content": "Document 2", "metadata": {"source": "test2"}},
|
||||
{"doc_id": "id3", "content": "Document 3", "metadata": {"source": "test3"}},
|
||||
{"doc_id": "id4", "content": "Document 4", "metadata": {"source": "test4"}},
|
||||
{"doc_id": "id5", "content": "Document 5", "metadata": {"source": "test5"}},
|
||||
]
|
||||
|
||||
client_with_batch_size.add_documents(
|
||||
collection_name="test_collection", documents=documents
|
||||
)
|
||||
|
||||
assert mock_collection.upsert.call_count == 3
|
||||
|
||||
first_call = mock_collection.upsert.call_args_list[0]
|
||||
assert first_call.kwargs["ids"] == ["id1", "id2"]
|
||||
assert first_call.kwargs["documents"] == ["Document 1", "Document 2"]
|
||||
assert first_call.kwargs["metadatas"] == [
|
||||
{"source": "test1"},
|
||||
{"source": "test2"},
|
||||
]
|
||||
|
||||
second_call = mock_collection.upsert.call_args_list[1]
|
||||
assert second_call.kwargs["ids"] == ["id3", "id4"]
|
||||
assert second_call.kwargs["documents"] == ["Document 3", "Document 4"]
|
||||
assert second_call.kwargs["metadatas"] == [
|
||||
{"source": "test3"},
|
||||
{"source": "test4"},
|
||||
]
|
||||
|
||||
third_call = mock_collection.upsert.call_args_list[2]
|
||||
assert third_call.kwargs["ids"] == ["id5"]
|
||||
assert third_call.kwargs["documents"] == ["Document 5"]
|
||||
assert third_call.kwargs["metadatas"] == [{"source": "test5"}]
|
||||
|
||||
def test_add_documents_with_explicit_batch_size(
|
||||
self, client, mock_chromadb_client
|
||||
) -> None:
|
||||
"""Test add_documents with explicitly provided batch size."""
|
||||
mock_collection = Mock()
|
||||
mock_chromadb_client.get_or_create_collection.return_value = mock_collection
|
||||
|
||||
documents: list[BaseRecord] = [
|
||||
{"doc_id": "id1", "content": "Document 1"},
|
||||
{"doc_id": "id2", "content": "Document 2"},
|
||||
{"doc_id": "id3", "content": "Document 3"},
|
||||
]
|
||||
|
||||
client.add_documents(
|
||||
collection_name="test_collection", documents=documents, batch_size=1
|
||||
)
|
||||
|
||||
assert mock_collection.upsert.call_count == 3
|
||||
for i, call in enumerate(mock_collection.upsert.call_args_list):
|
||||
assert len(call.kwargs["ids"]) == 1
|
||||
assert call.kwargs["ids"] == [f"id{i + 1}"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_aadd_documents_with_batch_size(
|
||||
self, async_client_with_batch_size, mock_async_chromadb_client
|
||||
) -> None:
|
||||
"""Test aadd_documents with batch size splits documents into batches."""
|
||||
mock_collection = AsyncMock()
|
||||
mock_async_chromadb_client.get_or_create_collection = AsyncMock(
|
||||
return_value=mock_collection
|
||||
)
|
||||
|
||||
documents: list[BaseRecord] = [
|
||||
{"doc_id": "id1", "content": "Document 1", "metadata": {"source": "test1"}},
|
||||
{"doc_id": "id2", "content": "Document 2", "metadata": {"source": "test2"}},
|
||||
{"doc_id": "id3", "content": "Document 3", "metadata": {"source": "test3"}},
|
||||
]
|
||||
|
||||
await async_client_with_batch_size.aadd_documents(
|
||||
collection_name="test_collection", documents=documents
|
||||
)
|
||||
|
||||
assert mock_collection.upsert.call_count == 2
|
||||
|
||||
first_call = mock_collection.upsert.call_args_list[0]
|
||||
assert first_call.kwargs["ids"] == ["id1", "id2"]
|
||||
assert first_call.kwargs["documents"] == ["Document 1", "Document 2"]
|
||||
|
||||
second_call = mock_collection.upsert.call_args_list[1]
|
||||
assert second_call.kwargs["ids"] == ["id3"]
|
||||
assert second_call.kwargs["documents"] == ["Document 3"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_aadd_documents_with_explicit_batch_size(
|
||||
self, async_client, mock_async_chromadb_client
|
||||
) -> None:
|
||||
"""Test aadd_documents with explicitly provided batch size."""
|
||||
mock_collection = AsyncMock()
|
||||
mock_async_chromadb_client.get_or_create_collection = AsyncMock(
|
||||
return_value=mock_collection
|
||||
)
|
||||
|
||||
documents: list[BaseRecord] = [
|
||||
{"doc_id": "id1", "content": "Document 1"},
|
||||
{"doc_id": "id2", "content": "Document 2"},
|
||||
{"doc_id": "id3", "content": "Document 3"},
|
||||
{"doc_id": "id4", "content": "Document 4"},
|
||||
]
|
||||
|
||||
await async_client.aadd_documents(
|
||||
collection_name="test_collection", documents=documents, batch_size=3
|
||||
)
|
||||
|
||||
assert mock_collection.upsert.call_count == 2
|
||||
|
||||
first_call = mock_collection.upsert.call_args_list[0]
|
||||
assert len(first_call.kwargs["ids"]) == 3
|
||||
|
||||
second_call = mock_collection.upsert.call_args_list[1]
|
||||
assert len(second_call.kwargs["ids"]) == 1
|
||||
|
||||
def test_client_default_batch_size_initialization(self) -> None:
|
||||
"""Test that client initializes with correct default batch size."""
|
||||
mock_client = Mock()
|
||||
mock_embedding = Mock()
|
||||
|
||||
client = ChromaDBClient(client=mock_client, embedding_function=mock_embedding)
|
||||
assert client.default_batch_size == 100
|
||||
|
||||
custom_client = ChromaDBClient(
|
||||
client=mock_client, embedding_function=mock_embedding, default_batch_size=50
|
||||
)
|
||||
assert custom_client.default_batch_size == 50
|
||||
|
||||
@@ -1,15 +1,11 @@
|
||||
"""Tests for ChromaDB utility functions."""
|
||||
|
||||
from crewai.rag.chromadb.types import PreparedDocuments
|
||||
from crewai.rag.chromadb.utils import (
|
||||
MAX_COLLECTION_LENGTH,
|
||||
MIN_COLLECTION_LENGTH,
|
||||
_create_batch_slice,
|
||||
_is_ipv4_pattern,
|
||||
_prepare_documents_for_chromadb,
|
||||
_sanitize_collection_name,
|
||||
)
|
||||
from crewai.rag.types import BaseRecord
|
||||
|
||||
|
||||
class TestChromaDBUtils:
|
||||
@@ -97,206 +93,3 @@ class TestChromaDBUtils:
|
||||
assert len(sanitized) >= MIN_COLLECTION_LENGTH
|
||||
assert sanitized[0].isalnum()
|
||||
assert sanitized[-1].isalnum()
|
||||
|
||||
|
||||
class TestPrepareDocumentsForChromaDB:
|
||||
"""Test suite for _prepare_documents_for_chromadb function."""
|
||||
|
||||
def test_prepare_documents_with_doc_ids(self) -> None:
|
||||
"""Test preparing documents that already have doc_ids."""
|
||||
documents: list[BaseRecord] = [
|
||||
{
|
||||
"doc_id": "id1",
|
||||
"content": "First document",
|
||||
"metadata": {"source": "test1"},
|
||||
},
|
||||
{
|
||||
"doc_id": "id2",
|
||||
"content": "Second document",
|
||||
"metadata": {"source": "test2"},
|
||||
},
|
||||
]
|
||||
|
||||
result = _prepare_documents_for_chromadb(documents)
|
||||
|
||||
assert result.ids == ["id1", "id2"]
|
||||
assert result.texts == ["First document", "Second document"]
|
||||
assert result.metadatas == [{"source": "test1"}, {"source": "test2"}]
|
||||
|
||||
def test_prepare_documents_generate_ids(self) -> None:
|
||||
"""Test preparing documents without doc_ids (should generate hashes)."""
|
||||
documents: list[BaseRecord] = [
|
||||
{"content": "Test content", "metadata": {"key": "value"}},
|
||||
{"content": "Another test"},
|
||||
]
|
||||
|
||||
result = _prepare_documents_for_chromadb(documents)
|
||||
|
||||
assert len(result.ids) == 2
|
||||
assert all(len(doc_id) == 64 for doc_id in result.ids)
|
||||
assert result.texts == ["Test content", "Another test"]
|
||||
assert result.metadatas == [{"key": "value"}, {}]
|
||||
|
||||
def test_prepare_documents_with_list_metadata(self) -> None:
|
||||
"""Test preparing documents with list metadata (should take first item)."""
|
||||
documents: list[BaseRecord] = [
|
||||
{"content": "Test", "metadata": [{"first": "item"}, {"second": "item"}]},
|
||||
{"content": "Test2", "metadata": []},
|
||||
]
|
||||
|
||||
result = _prepare_documents_for_chromadb(documents)
|
||||
|
||||
assert result.metadatas == [{"first": "item"}, {}]
|
||||
|
||||
def test_prepare_documents_no_metadata(self) -> None:
|
||||
"""Test preparing documents without metadata."""
|
||||
documents: list[BaseRecord] = [
|
||||
{"content": "Document 1"},
|
||||
{"content": "Document 2", "metadata": None},
|
||||
]
|
||||
|
||||
result = _prepare_documents_for_chromadb(documents)
|
||||
|
||||
assert result.metadatas == [{}, {}]
|
||||
|
||||
def test_prepare_documents_hash_consistency(self) -> None:
|
||||
"""Test that identical content produces identical hashes."""
|
||||
documents1: list[BaseRecord] = [
|
||||
{"content": "Same content", "metadata": {"key": "value"}}
|
||||
]
|
||||
documents2: list[BaseRecord] = [
|
||||
{"content": "Same content", "metadata": {"key": "value"}}
|
||||
]
|
||||
|
||||
result1 = _prepare_documents_for_chromadb(documents1)
|
||||
result2 = _prepare_documents_for_chromadb(documents2)
|
||||
|
||||
assert result1.ids == result2.ids
|
||||
|
||||
|
||||
class TestCreateBatchSlice:
|
||||
"""Test suite for _create_batch_slice function."""
|
||||
|
||||
def test_create_batch_slice_normal(self) -> None:
|
||||
"""Test creating a normal batch slice."""
|
||||
prepared = PreparedDocuments(
|
||||
ids=["id1", "id2", "id3", "id4", "id5"],
|
||||
texts=["doc1", "doc2", "doc3", "doc4", "doc5"],
|
||||
metadatas=[{"a": 1}, {"b": 2}, {"c": 3}, {"d": 4}, {"e": 5}],
|
||||
)
|
||||
|
||||
batch_ids, batch_texts, batch_metadatas = _create_batch_slice(
|
||||
prepared, start_index=1, batch_size=3
|
||||
)
|
||||
|
||||
assert batch_ids == ["id2", "id3", "id4"]
|
||||
assert batch_texts == ["doc2", "doc3", "doc4"]
|
||||
assert batch_metadatas == [{"b": 2}, {"c": 3}, {"d": 4}]
|
||||
|
||||
def test_create_batch_slice_at_end(self) -> None:
|
||||
"""Test creating a batch slice that goes beyond the end."""
|
||||
prepared = PreparedDocuments(
|
||||
ids=["id1", "id2", "id3"],
|
||||
texts=["doc1", "doc2", "doc3"],
|
||||
metadatas=[{"a": 1}, {"b": 2}, {"c": 3}],
|
||||
)
|
||||
|
||||
batch_ids, batch_texts, batch_metadatas = _create_batch_slice(
|
||||
prepared, start_index=2, batch_size=5
|
||||
)
|
||||
|
||||
assert batch_ids == ["id3"]
|
||||
assert batch_texts == ["doc3"]
|
||||
assert batch_metadatas == [{"c": 3}]
|
||||
|
||||
def test_create_batch_slice_empty_batch(self) -> None:
|
||||
"""Test creating a batch slice starting beyond the data."""
|
||||
prepared = PreparedDocuments(
|
||||
ids=["id1", "id2"], texts=["doc1", "doc2"], metadatas=[{"a": 1}, {"b": 2}]
|
||||
)
|
||||
|
||||
batch_ids, batch_texts, batch_metadatas = _create_batch_slice(
|
||||
prepared, start_index=5, batch_size=3
|
||||
)
|
||||
|
||||
assert batch_ids == []
|
||||
assert batch_texts == []
|
||||
assert batch_metadatas == []
|
||||
|
||||
def test_create_batch_slice_no_metadatas(self) -> None:
|
||||
"""Test creating a batch slice with no metadatas."""
|
||||
prepared = PreparedDocuments(
|
||||
ids=["id1", "id2", "id3"], texts=["doc1", "doc2", "doc3"], metadatas=[]
|
||||
)
|
||||
|
||||
batch_ids, batch_texts, batch_metadatas = _create_batch_slice(
|
||||
prepared, start_index=0, batch_size=2
|
||||
)
|
||||
|
||||
assert batch_ids == ["id1", "id2"]
|
||||
assert batch_texts == ["doc1", "doc2"]
|
||||
assert batch_metadatas is None
|
||||
|
||||
def test_create_batch_slice_all_empty_metadatas(self) -> None:
|
||||
"""Test creating a batch slice where all metadatas are empty."""
|
||||
prepared = PreparedDocuments(
|
||||
ids=["id1", "id2", "id3"],
|
||||
texts=["doc1", "doc2", "doc3"],
|
||||
metadatas=[{}, {}, {}],
|
||||
)
|
||||
|
||||
batch_ids, batch_texts, batch_metadatas = _create_batch_slice(
|
||||
prepared, start_index=0, batch_size=3
|
||||
)
|
||||
|
||||
assert batch_ids == ["id1", "id2", "id3"]
|
||||
assert batch_texts == ["doc1", "doc2", "doc3"]
|
||||
assert batch_metadatas is None
|
||||
|
||||
def test_create_batch_slice_some_empty_metadatas(self) -> None:
|
||||
"""Test creating a batch slice where some metadatas are empty."""
|
||||
prepared = PreparedDocuments(
|
||||
ids=["id1", "id2", "id3"],
|
||||
texts=["doc1", "doc2", "doc3"],
|
||||
metadatas=[{"a": 1}, {}, {"c": 3}],
|
||||
)
|
||||
|
||||
batch_ids, batch_texts, batch_metadatas = _create_batch_slice(
|
||||
prepared, start_index=0, batch_size=3
|
||||
)
|
||||
|
||||
assert batch_ids == ["id1", "id2", "id3"]
|
||||
assert batch_texts == ["doc1", "doc2", "doc3"]
|
||||
assert batch_metadatas == [{"a": 1}, {}, {"c": 3}]
|
||||
|
||||
def test_create_batch_slice_zero_start_index(self) -> None:
|
||||
"""Test creating a batch slice starting from index 0."""
|
||||
prepared = PreparedDocuments(
|
||||
ids=["id1", "id2", "id3", "id4"],
|
||||
texts=["doc1", "doc2", "doc3", "doc4"],
|
||||
metadatas=[{"a": 1}, {"b": 2}, {"c": 3}, {"d": 4}],
|
||||
)
|
||||
|
||||
batch_ids, batch_texts, batch_metadatas = _create_batch_slice(
|
||||
prepared, start_index=0, batch_size=2
|
||||
)
|
||||
|
||||
assert batch_ids == ["id1", "id2"]
|
||||
assert batch_texts == ["doc1", "doc2"]
|
||||
assert batch_metadatas == [{"a": 1}, {"b": 2}]
|
||||
|
||||
def test_create_batch_slice_single_item(self) -> None:
|
||||
"""Test creating a batch slice with batch size 1."""
|
||||
prepared = PreparedDocuments(
|
||||
ids=["id1", "id2", "id3"],
|
||||
texts=["doc1", "doc2", "doc3"],
|
||||
metadatas=[{"a": 1}, {"b": 2}, {"c": 3}],
|
||||
)
|
||||
|
||||
batch_ids, batch_texts, batch_metadatas = _create_batch_slice(
|
||||
prepared, start_index=1, batch_size=1
|
||||
)
|
||||
|
||||
assert batch_ids == ["id2"]
|
||||
assert batch_texts == ["doc2"]
|
||||
assert batch_metadatas == [{"b": 2}]
|
||||
|
||||
@@ -1,244 +0,0 @@
|
||||
"""Tests for embedding function factory."""
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from crewai.rag.embeddings.factory import build_embedder
|
||||
|
||||
|
||||
class TestEmbeddingFactory:
|
||||
"""Test embedding factory functions."""
|
||||
|
||||
@patch("crewai.rag.embeddings.factory.import_and_validate_definition")
|
||||
def test_build_embedder_openai(self, mock_import):
|
||||
"""Test building OpenAI embedder."""
|
||||
mock_provider_class = MagicMock()
|
||||
mock_provider_instance = MagicMock()
|
||||
mock_embedding_function = MagicMock()
|
||||
|
||||
mock_import.return_value = mock_provider_class
|
||||
mock_provider_class.return_value = mock_provider_instance
|
||||
mock_provider_instance.embedding_callable.return_value = mock_embedding_function
|
||||
|
||||
config = {
|
||||
"provider": "openai",
|
||||
"config": {
|
||||
"api_key": "test-key",
|
||||
"model_name": "text-embedding-3-small",
|
||||
},
|
||||
}
|
||||
|
||||
build_embedder(config)
|
||||
|
||||
mock_import.assert_called_once_with(
|
||||
"crewai.rag.embeddings.providers.openai.openai_provider.OpenAIProvider"
|
||||
)
|
||||
mock_provider_class.assert_called_once()
|
||||
|
||||
call_kwargs = mock_provider_class.call_args.kwargs
|
||||
assert call_kwargs["api_key"] == "test-key"
|
||||
assert call_kwargs["model_name"] == "text-embedding-3-small"
|
||||
|
||||
@patch("crewai.rag.embeddings.factory.import_and_validate_definition")
|
||||
def test_build_embedder_azure(self, mock_import):
|
||||
"""Test building Azure embedder."""
|
||||
mock_provider_class = MagicMock()
|
||||
mock_provider_instance = MagicMock()
|
||||
mock_embedding_function = MagicMock()
|
||||
|
||||
mock_import.return_value = mock_provider_class
|
||||
mock_provider_class.return_value = mock_provider_instance
|
||||
mock_provider_instance.embedding_callable.return_value = mock_embedding_function
|
||||
|
||||
config = {
|
||||
"provider": "azure",
|
||||
"config": {
|
||||
"api_key": "test-azure-key",
|
||||
"api_base": "https://test.openai.azure.com/",
|
||||
"api_type": "azure",
|
||||
"api_version": "2023-05-15",
|
||||
"model_name": "text-embedding-3-small",
|
||||
"deployment_id": "test-deployment",
|
||||
},
|
||||
}
|
||||
|
||||
build_embedder(config)
|
||||
|
||||
mock_import.assert_called_once_with(
|
||||
"crewai.rag.embeddings.providers.microsoft.azure.AzureProvider"
|
||||
)
|
||||
|
||||
call_kwargs = mock_provider_class.call_args.kwargs
|
||||
assert call_kwargs["api_key"] == "test-azure-key"
|
||||
assert call_kwargs["api_base"] == "https://test.openai.azure.com/"
|
||||
assert call_kwargs["api_type"] == "azure"
|
||||
|
||||
@patch("crewai.rag.embeddings.factory.import_and_validate_definition")
|
||||
def test_build_embedder_ollama(self, mock_import):
|
||||
"""Test building Ollama embedder."""
|
||||
mock_provider_class = MagicMock()
|
||||
mock_provider_instance = MagicMock()
|
||||
mock_embedding_function = MagicMock()
|
||||
|
||||
mock_import.return_value = mock_provider_class
|
||||
mock_provider_class.return_value = mock_provider_instance
|
||||
mock_provider_instance.embedding_callable.return_value = mock_embedding_function
|
||||
|
||||
config = {
|
||||
"provider": "ollama",
|
||||
"config": {
|
||||
"model_name": "nomic-embed-text",
|
||||
"url": "http://localhost:11434",
|
||||
},
|
||||
}
|
||||
|
||||
build_embedder(config)
|
||||
|
||||
mock_import.assert_called_once_with(
|
||||
"crewai.rag.embeddings.providers.ollama.ollama_provider.OllamaProvider"
|
||||
)
|
||||
|
||||
@patch("crewai.rag.embeddings.factory.import_and_validate_definition")
|
||||
def test_build_embedder_cohere(self, mock_import):
|
||||
"""Test building Cohere embedder."""
|
||||
mock_provider_class = MagicMock()
|
||||
mock_provider_instance = MagicMock()
|
||||
mock_embedding_function = MagicMock()
|
||||
|
||||
mock_import.return_value = mock_provider_class
|
||||
mock_provider_class.return_value = mock_provider_instance
|
||||
mock_provider_instance.embedding_callable.return_value = mock_embedding_function
|
||||
|
||||
config = {
|
||||
"provider": "cohere",
|
||||
"config": {
|
||||
"api_key": "cohere-key",
|
||||
"model_name": "embed-english-v3.0",
|
||||
},
|
||||
}
|
||||
|
||||
build_embedder(config)
|
||||
|
||||
mock_import.assert_called_once_with(
|
||||
"crewai.rag.embeddings.providers.cohere.cohere_provider.CohereProvider"
|
||||
)
|
||||
|
||||
@patch("crewai.rag.embeddings.factory.import_and_validate_definition")
|
||||
def test_build_embedder_voyageai(self, mock_import):
|
||||
"""Test building VoyageAI embedder."""
|
||||
mock_provider_class = MagicMock()
|
||||
mock_provider_instance = MagicMock()
|
||||
mock_embedding_function = MagicMock()
|
||||
|
||||
mock_import.return_value = mock_provider_class
|
||||
mock_provider_class.return_value = mock_provider_instance
|
||||
mock_provider_instance.embedding_callable.return_value = mock_embedding_function
|
||||
|
||||
config = {
|
||||
"provider": "voyageai",
|
||||
"config": {
|
||||
"api_key": "voyage-key",
|
||||
"model": "voyage-2",
|
||||
},
|
||||
}
|
||||
|
||||
build_embedder(config)
|
||||
|
||||
mock_import.assert_called_once_with(
|
||||
"crewai.rag.embeddings.providers.voyageai.voyageai_provider.VoyageAIProvider"
|
||||
)
|
||||
|
||||
@patch("crewai.rag.embeddings.factory.import_and_validate_definition")
|
||||
def test_build_embedder_watson(self, mock_import):
|
||||
"""Test building Watson embedder."""
|
||||
mock_provider_class = MagicMock()
|
||||
mock_provider_instance = MagicMock()
|
||||
mock_embedding_function = MagicMock()
|
||||
|
||||
mock_import.return_value = mock_provider_class
|
||||
mock_provider_class.return_value = mock_provider_instance
|
||||
mock_provider_instance.embedding_callable.return_value = mock_embedding_function
|
||||
|
||||
config = {
|
||||
"provider": "watson",
|
||||
"config": {
|
||||
"model_id": "ibm/slate-125m-english-rtrvr",
|
||||
"api_key": "watson-key",
|
||||
"url": "https://us-south.ml.cloud.ibm.com",
|
||||
"project_id": "test-project",
|
||||
},
|
||||
}
|
||||
|
||||
build_embedder(config)
|
||||
|
||||
mock_import.assert_called_once_with(
|
||||
"crewai.rag.embeddings.providers.ibm.watson.WatsonProvider"
|
||||
)
|
||||
|
||||
def test_build_embedder_unknown_provider(self):
|
||||
"""Test error handling for unknown provider."""
|
||||
config = {"provider": "unknown-provider", "config": {}}
|
||||
|
||||
with pytest.raises(ValueError, match="Unknown provider: unknown-provider"):
|
||||
build_embedder(config)
|
||||
|
||||
def test_build_embedder_missing_provider(self):
|
||||
"""Test error handling for missing provider key."""
|
||||
config = {"config": {"api_key": "test-key"}}
|
||||
|
||||
with pytest.raises(KeyError):
|
||||
build_embedder(config)
|
||||
|
||||
@patch("crewai.rag.embeddings.factory.import_and_validate_definition")
|
||||
def test_build_embedder_import_error(self, mock_import):
|
||||
"""Test error handling when provider import fails."""
|
||||
mock_import.side_effect = ImportError("Module not found")
|
||||
|
||||
config = {"provider": "openai", "config": {"api_key": "test-key"}}
|
||||
|
||||
with pytest.raises(ImportError, match="Failed to import provider openai"):
|
||||
build_embedder(config)
|
||||
|
||||
@patch("crewai.rag.embeddings.factory.import_and_validate_definition")
|
||||
def test_build_embedder_custom_provider(self, mock_import):
|
||||
"""Test building custom embedder."""
|
||||
mock_provider_class = MagicMock()
|
||||
mock_provider_instance = MagicMock()
|
||||
mock_embedding_callable = MagicMock()
|
||||
|
||||
mock_import.return_value = mock_provider_class
|
||||
mock_provider_class.return_value = mock_provider_instance
|
||||
mock_provider_instance.embedding_callable = mock_embedding_callable
|
||||
|
||||
config = {
|
||||
"provider": "custom",
|
||||
"config": {"embedding_callable": mock_embedding_callable},
|
||||
}
|
||||
|
||||
build_embedder(config)
|
||||
|
||||
mock_import.assert_called_once_with(
|
||||
"crewai.rag.embeddings.providers.custom.custom_provider.CustomProvider"
|
||||
)
|
||||
|
||||
call_kwargs = mock_provider_class.call_args.kwargs
|
||||
assert call_kwargs["embedding_callable"] == mock_embedding_callable
|
||||
|
||||
@patch("crewai.rag.embeddings.factory.import_and_validate_definition")
|
||||
@patch("crewai.rag.embeddings.factory.build_embedder_from_provider")
|
||||
def test_build_embedder_with_provider_instance(
|
||||
self, mock_build_from_provider, mock_import
|
||||
):
|
||||
"""Test building embedder from provider instance."""
|
||||
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
|
||||
|
||||
mock_provider = MagicMock(spec=BaseEmbeddingsProvider)
|
||||
mock_embedding_function = MagicMock()
|
||||
mock_build_from_provider.return_value = mock_embedding_function
|
||||
|
||||
result = build_embedder(mock_provider)
|
||||
|
||||
mock_build_from_provider.assert_called_once_with(mock_provider)
|
||||
assert result == mock_embedding_function
|
||||
mock_import.assert_not_called()
|
||||
@@ -1,122 +0,0 @@
|
||||
"""Test Azure embedder configuration with factory."""
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from crewai.rag.embeddings.factory import build_embedder
|
||||
|
||||
|
||||
class TestAzureEmbedderFactory:
|
||||
"""Test Azure embedder configuration with factory function."""
|
||||
|
||||
@patch("crewai.rag.embeddings.factory.import_and_validate_definition")
|
||||
def test_azure_with_nested_config(self, mock_import):
|
||||
"""Test Azure configuration with nested config key."""
|
||||
mock_provider_class = MagicMock()
|
||||
mock_provider_instance = MagicMock()
|
||||
mock_embedding_function = MagicMock()
|
||||
|
||||
mock_import.return_value = mock_provider_class
|
||||
mock_provider_class.return_value = mock_provider_instance
|
||||
mock_provider_instance.embedding_callable.return_value = mock_embedding_function
|
||||
|
||||
embedder_config = {
|
||||
"provider": "azure",
|
||||
"config": {
|
||||
"api_key": "test-azure-key",
|
||||
"api_base": "https://test.openai.azure.com/",
|
||||
"api_type": "azure",
|
||||
"api_version": "2023-05-15",
|
||||
"model_name": "text-embedding-3-small",
|
||||
"deployment_id": "test-deployment",
|
||||
},
|
||||
}
|
||||
|
||||
result = build_embedder(embedder_config)
|
||||
|
||||
mock_import.assert_called_once_with(
|
||||
"crewai.rag.embeddings.providers.microsoft.azure.AzureProvider"
|
||||
)
|
||||
|
||||
call_kwargs = mock_provider_class.call_args.kwargs
|
||||
assert call_kwargs["api_key"] == "test-azure-key"
|
||||
assert call_kwargs["api_base"] == "https://test.openai.azure.com/"
|
||||
assert call_kwargs["api_type"] == "azure"
|
||||
assert call_kwargs["api_version"] == "2023-05-15"
|
||||
assert call_kwargs["model_name"] == "text-embedding-3-small"
|
||||
assert call_kwargs["deployment_id"] == "test-deployment"
|
||||
|
||||
assert result == mock_embedding_function
|
||||
|
||||
@patch("crewai.rag.embeddings.factory.import_and_validate_definition")
|
||||
def test_regular_openai_with_nested_config(self, mock_import):
|
||||
"""Test regular OpenAI configuration with nested config."""
|
||||
mock_provider_class = MagicMock()
|
||||
mock_provider_instance = MagicMock()
|
||||
mock_embedding_function = MagicMock()
|
||||
|
||||
mock_import.return_value = mock_provider_class
|
||||
mock_provider_class.return_value = mock_provider_instance
|
||||
mock_provider_instance.embedding_callable.return_value = mock_embedding_function
|
||||
|
||||
embedder_config = {
|
||||
"provider": "openai",
|
||||
"config": {"api_key": "test-openai-key", "model": "text-embedding-3-large"},
|
||||
}
|
||||
|
||||
result = build_embedder(embedder_config)
|
||||
|
||||
mock_import.assert_called_once_with(
|
||||
"crewai.rag.embeddings.providers.openai.openai_provider.OpenAIProvider"
|
||||
)
|
||||
|
||||
call_kwargs = mock_provider_class.call_args.kwargs
|
||||
assert call_kwargs["api_key"] == "test-openai-key"
|
||||
assert call_kwargs["model"] == "text-embedding-3-large"
|
||||
|
||||
assert result == mock_embedding_function
|
||||
|
||||
@patch("crewai.rag.embeddings.factory.import_and_validate_definition")
|
||||
def test_azure_provider_with_minimal_config(self, mock_import):
|
||||
"""Test Azure provider with minimal required configuration."""
|
||||
mock_provider_class = MagicMock()
|
||||
mock_provider_instance = MagicMock()
|
||||
mock_embedding_function = MagicMock()
|
||||
|
||||
mock_import.return_value = mock_provider_class
|
||||
mock_provider_class.return_value = mock_provider_instance
|
||||
mock_provider_instance.embedding_callable.return_value = mock_embedding_function
|
||||
|
||||
embedder_config = {
|
||||
"provider": "azure",
|
||||
"config": {
|
||||
"api_key": "test-key",
|
||||
"api_base": "https://test.openai.azure.com/",
|
||||
},
|
||||
}
|
||||
|
||||
build_embedder(embedder_config)
|
||||
|
||||
mock_import.assert_called_once_with(
|
||||
"crewai.rag.embeddings.providers.microsoft.azure.AzureProvider"
|
||||
)
|
||||
|
||||
call_kwargs = mock_provider_class.call_args.kwargs
|
||||
assert call_kwargs["api_key"] == "test-key"
|
||||
assert call_kwargs["api_base"] == "https://test.openai.azure.com/"
|
||||
|
||||
@patch("crewai.rag.embeddings.factory.import_and_validate_definition")
|
||||
def test_azure_import_error(self, mock_import):
|
||||
"""Test handling of import errors for Azure provider."""
|
||||
mock_import.side_effect = ImportError("Failed to import Azure provider")
|
||||
|
||||
embedder_config = {
|
||||
"provider": "azure",
|
||||
"config": {"api_key": "test-key"},
|
||||
}
|
||||
|
||||
with pytest.raises(ImportError) as exc_info:
|
||||
build_embedder(embedder_config)
|
||||
|
||||
assert "Failed to import provider azure" in str(exc_info.value)
|
||||
315
tests/rag/embeddings/test_factory_enhanced.py
Normal file
315
tests/rag/embeddings/test_factory_enhanced.py
Normal file
@@ -0,0 +1,315 @@
|
||||
"""Enhanced tests for embedding function factory."""
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from crewai.rag.embeddings.factory import ( # type: ignore[import-untyped]
|
||||
get_embedding_function,
|
||||
)
|
||||
from crewai.rag.embeddings.types import EmbeddingOptions # type: ignore[import-untyped]
|
||||
|
||||
|
||||
def test_get_embedding_function_default() -> None:
|
||||
"""Test default embedding function when no config provided."""
|
||||
with patch("crewai.rag.embeddings.factory.OpenAIEmbeddingFunction") as mock_openai:
|
||||
mock_instance = MagicMock()
|
||||
mock_openai.return_value = mock_instance
|
||||
|
||||
with patch(
|
||||
"crewai.rag.embeddings.factory.os.getenv", return_value="test-api-key"
|
||||
):
|
||||
result = get_embedding_function()
|
||||
|
||||
mock_openai.assert_called_once_with(
|
||||
api_key="test-api-key", model_name="text-embedding-3-small"
|
||||
)
|
||||
assert result == mock_instance
|
||||
|
||||
|
||||
def test_get_embedding_function_with_embedding_options() -> None:
|
||||
"""Test embedding function creation with EmbeddingOptions object."""
|
||||
with patch("crewai.rag.embeddings.factory.OpenAIEmbeddingFunction") as mock_openai:
|
||||
mock_instance = MagicMock()
|
||||
mock_openai.return_value = mock_instance
|
||||
|
||||
options = EmbeddingOptions(
|
||||
provider="openai", api_key="test-key", model="text-embedding-3-large"
|
||||
)
|
||||
|
||||
result = get_embedding_function(options)
|
||||
|
||||
call_kwargs = mock_openai.call_args.kwargs
|
||||
assert "api_key" in call_kwargs
|
||||
assert call_kwargs["api_key"].get_secret_value() == "test-key"
|
||||
# OpenAI uses model_name parameter, not model
|
||||
assert result == mock_instance
|
||||
|
||||
|
||||
def test_get_embedding_function_sentence_transformer() -> None:
    """Check dispatch for the ``sentence-transformer`` provider.

    The patched embedding class must be constructed once with only the
    ``model_name`` from the config, and its result returned unchanged.
    """
    fake_embedder = MagicMock()
    st_target = "crewai.rag.embeddings.factory.SentenceTransformerEmbeddingFunction"
    settings = {"provider": "sentence-transformer", "model_name": "all-MiniLM-L6-v2"}

    with patch(st_target, return_value=fake_embedder) as st_cls:
        produced = get_embedding_function(settings)

        st_cls.assert_called_once_with(model_name="all-MiniLM-L6-v2")
        assert produced == fake_embedder
|
||||
|
||||
|
||||
def test_get_embedding_function_ollama() -> None:
    """Check dispatch for the ``ollama`` provider.

    The patched Ollama embedding class must be constructed once with the
    ``model_name`` and ``url`` from the config, and its result returned.
    """
    fake_embedder = MagicMock()
    ollama_target = "crewai.rag.embeddings.factory.OllamaEmbeddingFunction"
    settings = {
        "provider": "ollama",
        "model_name": "nomic-embed-text",
        "url": "http://localhost:11434",
    }

    with patch(ollama_target, return_value=fake_embedder) as ollama_cls:
        produced = get_embedding_function(settings)

        expected_kwargs = {
            "model_name": "nomic-embed-text",
            "url": "http://localhost:11434",
        }
        ollama_cls.assert_called_once_with(**expected_kwargs)
        assert produced == fake_embedder
|
||||
|
||||
|
||||
def test_get_embedding_function_cohere() -> None:
    """Check dispatch for the ``cohere`` provider.

    The patched Cohere embedding class must be constructed once with the
    ``api_key`` and ``model_name`` from the config, and its result returned.
    """
    fake_embedder = MagicMock()
    cohere_target = "crewai.rag.embeddings.factory.CohereEmbeddingFunction"
    settings = {
        "provider": "cohere",
        "api_key": "cohere-key",
        "model_name": "embed-english-v3.0",
    }

    with patch(cohere_target, return_value=fake_embedder) as cohere_cls:
        produced = get_embedding_function(settings)

        expected_kwargs = {
            "api_key": "cohere-key",
            "model_name": "embed-english-v3.0",
        }
        cohere_cls.assert_called_once_with(**expected_kwargs)
        assert produced == fake_embedder
|
||||
|
||||
|
||||
def test_get_embedding_function_huggingface() -> None:
    """Check dispatch for the ``huggingface`` provider.

    The patched HuggingFace embedding class must be constructed once with the
    ``api_key`` and ``model_name`` from the config, and its result returned.
    """
    fake_embedder = MagicMock()
    hf_target = "crewai.rag.embeddings.factory.HuggingFaceEmbeddingFunction"
    settings = {
        "provider": "huggingface",
        "api_key": "hf-token",
        "model_name": "sentence-transformers/all-MiniLM-L6-v2",
    }

    with patch(hf_target, return_value=fake_embedder) as hf_cls:
        produced = get_embedding_function(settings)

        expected_kwargs = {
            "api_key": "hf-token",
            "model_name": "sentence-transformers/all-MiniLM-L6-v2",
        }
        hf_cls.assert_called_once_with(**expected_kwargs)
        assert produced == fake_embedder
|
||||
|
||||
|
||||
def test_get_embedding_function_onnx() -> None:
    """Check dispatch for the ``onnx`` provider.

    Only the provider name is supplied; the patched ``ONNXMiniLM_L6_V2`` class
    must be constructed exactly once and its result returned.
    """
    fake_embedder = MagicMock()
    onnx_target = "crewai.rag.embeddings.factory.ONNXMiniLM_L6_V2"
    settings = {"provider": "onnx"}

    with patch(onnx_target, return_value=fake_embedder) as onnx_cls:
        produced = get_embedding_function(settings)

        onnx_cls.assert_called_once()
        assert produced == fake_embedder
|
||||
|
||||
|
||||
def test_get_embedding_function_google_palm() -> None:
    """Check dispatch for the ``google-palm`` provider.

    The patched Google PaLM embedding class must be constructed once with the
    ``api_key`` from the config, and its result returned.
    """
    fake_embedder = MagicMock()
    palm_target = "crewai.rag.embeddings.factory.GooglePalmEmbeddingFunction"
    settings = {"provider": "google-palm", "api_key": "palm-key"}

    with patch(palm_target, return_value=fake_embedder) as palm_cls:
        produced = get_embedding_function(settings)

        palm_cls.assert_called_once_with(api_key="palm-key")
        assert produced == fake_embedder
|
||||
|
||||
|
||||
def test_get_embedding_function_amazon_bedrock() -> None:
    """Check dispatch for the ``amazon-bedrock`` provider.

    The patched Bedrock embedding class must be constructed once with the
    ``region_name`` and ``model_name`` from the config, and its result
    returned.
    """
    fake_embedder = MagicMock()
    bedrock_target = "crewai.rag.embeddings.factory.AmazonBedrockEmbeddingFunction"
    settings = {
        "provider": "amazon-bedrock",
        "region_name": "us-west-2",
        "model_name": "amazon.titan-embed-text-v1",
    }

    with patch(bedrock_target, return_value=fake_embedder) as bedrock_cls:
        produced = get_embedding_function(settings)

        expected_kwargs = {
            "region_name": "us-west-2",
            "model_name": "amazon.titan-embed-text-v1",
        }
        bedrock_cls.assert_called_once_with(**expected_kwargs)
        assert produced == fake_embedder
|
||||
|
||||
|
||||
def test_get_embedding_function_jina() -> None:
    """Check dispatch for the ``jina`` provider.

    The patched Jina embedding class must be constructed once with the
    ``api_key`` and ``model_name`` from the config, and its result returned.
    """
    fake_embedder = MagicMock()
    jina_target = "crewai.rag.embeddings.factory.JinaEmbeddingFunction"
    settings = {
        "provider": "jina",
        "api_key": "jina-key",
        "model_name": "jina-embeddings-v2-base-en",
    }

    with patch(jina_target, return_value=fake_embedder) as jina_cls:
        produced = get_embedding_function(settings)

        expected_kwargs = {
            "api_key": "jina-key",
            "model_name": "jina-embeddings-v2-base-en",
        }
        jina_cls.assert_called_once_with(**expected_kwargs)
        assert produced == fake_embedder
|
||||
|
||||
|
||||
def test_get_embedding_function_unsupported_provider() -> None:
    """Check that an unknown provider name is rejected.

    The factory is expected to raise ``ValueError`` with a message that names
    the offending provider.
    """
    expected_message = "Unsupported provider: unsupported-provider"
    settings = {"provider": "unsupported-provider"}

    with pytest.raises(ValueError, match=expected_message):
        get_embedding_function(settings)
|
||||
|
||||
|
||||
def test_get_embedding_function_config_modification() -> None:
    """Check that the factory leaves the caller's config dict untouched.

    A shallow copy is handed to the factory and then compared against the
    pristine dict it was copied from; any key added, removed or changed by the
    factory would make the two differ.
    """
    pristine = {
        "provider": "openai",
        "api_key": "test-key",
        "model": "text-embedding-3-small",
    }
    handed_over = dict(pristine)

    # The embedding class is patched only so that nothing real is constructed.
    with patch("crewai.rag.embeddings.factory.OpenAIEmbeddingFunction"):
        get_embedding_function(handed_over)

    assert handed_over == pristine
|
||||
|
||||
|
||||
def test_get_embedding_function_exclude_none_values() -> None:
    """Check that options set to ``None`` are not forwarded.

    ``model=None`` is supplied in the options; the patched OpenAI embedding
    class must still receive the ``api_key`` (readable through
    ``get_secret_value()``) but no ``model`` keyword at all.
    """
    fake_embedder = MagicMock()
    openai_target = "crewai.rag.embeddings.factory.OpenAIEmbeddingFunction"

    with patch(openai_target, return_value=fake_embedder) as openai_cls:
        opts = EmbeddingOptions(provider="openai", api_key="test-key", model=None)

        produced = get_embedding_function(opts)

        forwarded = openai_cls.call_args.kwargs
        assert "api_key" in forwarded
        assert forwarded["api_key"].get_secret_value() == "test-key"
        assert "model" not in forwarded
        assert produced == fake_embedder
|
||||
|
||||
|
||||
def test_get_embedding_function_instructor() -> None:
    """Check dispatch for the ``instructor`` provider.

    The patched Instructor embedding class must be constructed once with the
    ``model_name`` from the config, and its result returned.
    """
    fake_embedder = MagicMock()
    instructor_target = "crewai.rag.embeddings.factory.InstructorEmbeddingFunction"
    settings = {"provider": "instructor", "model_name": "hkunlp/instructor-large"}

    with patch(instructor_target, return_value=fake_embedder) as instructor_cls:
        produced = get_embedding_function(settings)

        instructor_cls.assert_called_once_with(model_name="hkunlp/instructor-large")
        assert produced == fake_embedder
|
||||
|
||||
|
||||
def test_get_embedding_function_watson() -> None:
    """Check dispatch for the ``watson`` provider.

    The factory's ``_create_watson_embedding_function`` helper is patched; it
    must be called once with every config entry except ``provider``, and its
    result returned.
    """
    fake_embedder = MagicMock()
    watson_target = "crewai.rag.embeddings.factory._create_watson_embedding_function"
    settings = {
        "provider": "watson",
        "api_key": "watson-api-key",
        "api_url": "https://watson-url.com",
        "project_id": "watson-project-id",
        "model_name": "ibm/slate-125m-english-rtrvr",
    }

    with patch(watson_target, return_value=fake_embedder) as watson_factory:
        produced = get_embedding_function(settings)

        expected_kwargs = {
            "api_key": "watson-api-key",
            "api_url": "https://watson-url.com",
            "project_id": "watson-project-id",
            "model_name": "ibm/slate-125m-english-rtrvr",
        }
        watson_factory.assert_called_once_with(**expected_kwargs)
        assert produced == fake_embedder
|
||||
|
||||
|
||||
def test_get_embedding_function_watson_missing_dependencies() -> None:
    """Check that a Watson ``ImportError`` reaches the caller.

    The patched ``_create_watson_embedding_function`` helper is made to raise
    ``ImportError``; the factory is expected to let it propagate with its
    message intact.
    """
    watson_target = "crewai.rag.embeddings.factory._create_watson_embedding_function"
    missing_deps = ImportError(
        "IBM Watson dependencies are not installed. Please install them to use Watson embedding."
    )
    settings = {
        "provider": "watson",
        "api_key": "watson-api-key",
        "api_url": "https://watson-url.com",
        "project_id": "watson-project-id",
        "model_name": "ibm/slate-125m-english-rtrvr",
    }

    with patch(watson_target, side_effect=missing_deps):
        with pytest.raises(ImportError, match="IBM Watson dependencies are not installed"):
            get_embedding_function(settings)
|
||||
|
||||
|
||||
def test_get_embedding_function_watson_with_embedding_options() -> None:
    """Check the ``watson`` provider when configured via ``EmbeddingOptions``.

    The patched ``_create_watson_embedding_function`` helper must receive an
    ``api_key`` readable through ``get_secret_value()`` and the plain
    ``model_name`` from the options, and its result must be returned.
    """
    fake_embedder = MagicMock()
    watson_target = "crewai.rag.embeddings.factory._create_watson_embedding_function"

    with patch(watson_target, return_value=fake_embedder) as watson_factory:
        opts = EmbeddingOptions(
            provider="watson",
            api_key="watson-key",
            model_name="ibm/slate-125m-english-rtrvr",
        )

        produced = get_embedding_function(opts)

        forwarded = watson_factory.call_args.kwargs
        assert "api_key" in forwarded
        assert forwarded["api_key"].get_secret_value() == "watson-key"
        assert forwarded["model_name"] == "ibm/slate-125m-english-rtrvr"
        assert produced == fake_embedder
|
||||
@@ -1,77 +0,0 @@
|
||||
"""Tests for Watson embedding function name method."""
|
||||
|
||||
|
||||
import pytest
|
||||
|
||||
from crewai.rag.embeddings.providers.ibm.embedding_callable import (
|
||||
WatsonEmbeddingFunction,
|
||||
)
|
||||
|
||||
|
||||
class TestWatsonEmbeddingName:
    """Tests covering the ``name`` method of ``WatsonEmbeddingFunction``."""

    def test_watson_embedding_function_has_name_method(self):
        """The class exposes a callable attribute called ``name``."""
        assert hasattr(WatsonEmbeddingFunction, 'name')
        name_attr = WatsonEmbeddingFunction.name
        assert callable(name_attr)

    def test_watson_embedding_function_name_returns_watson(self):
        """Calling ``name`` on the class yields the string ``"watson"``."""
        reported = WatsonEmbeddingFunction.name()
        assert reported == "watson"

    def test_watson_embedding_function_name_is_static(self):
        """``name`` works without an instance and returns a ``str``."""
        reported = WatsonEmbeddingFunction.name()
        assert reported == "watson"
        assert isinstance(reported, str)

    def test_watson_embedding_function_name_with_chromadb_validation(self):
        """``name`` is callable on an instance, as in the ChromaDB validation scenario."""
        embedder = WatsonEmbeddingFunction(
            model_id="test-model",
            api_key="test-key",
            url="https://test.com",
        )

        try:
            reported = embedder.name()
            assert reported == "watson"
        except AttributeError as err:
            # Convert the error into an explicit test failure with context.
            pytest.fail(f"ChromaDB validation failed with AttributeError: {err}")

    def test_watson_embedding_function_name_method_signature(self):
        """``name`` is declared as a ``staticmethod``; if annotated, it returns ``str``."""
        import inspect

        bound_name = WatsonEmbeddingFunction.name
        raw_descriptor = inspect.getattr_static(WatsonEmbeddingFunction, 'name')

        assert isinstance(raw_descriptor, staticmethod)

        returns = inspect.signature(bound_name).return_annotation
        # An absent annotation is accepted; a present one must be ``str``.
        if returns != inspect.Signature.empty:
            assert returns is str

    def test_watson_embedding_function_reproduces_original_issue(self):
        """Reproduce the scenario from issue #3597: ``name`` on instance and class."""
        embedder = WatsonEmbeddingFunction(
            model_id="ibm/slate-125m-english-rtrvr",
            api_key="test-key",
            url="https://us-south.ml.cloud.ibm.com",
            project_id="test-project",
        )

        from_instance = embedder.name()

        assert from_instance == "watson"
        assert isinstance(from_instance, str)

        from_class = WatsonEmbeddingFunction.name()
        assert from_class == "watson"
        assert from_class == from_instance
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user