feat: upgrade chromadb to v1.1.0, improve types

- update imports and include handling for chromadb v1.1.0  
- fix mypy and typing_compat issues (required, typeddict, voyageai)  
- refine embedderconfig typing and allow base provider instances  
- handle mem0 as special case for external memory storage  
- bump tools and clean up redundant deps
This commit is contained in:
Greyson LaLonde
2025-09-25 20:48:37 -04:00
committed by GitHub
parent ce5ea9be6f
commit 2485ed93d6
35 changed files with 383 additions and 316 deletions

View File

@@ -1,17 +1,10 @@
import shutil
import subprocess
import time
from collections.abc import Callable, Sequence
from typing import (
Any,
Callable,
Dict,
List,
Literal,
Optional,
Sequence,
Tuple,
Type,
Union,
)
from pydantic import Field, InstanceOf, PrivateAttr, model_validator
@@ -19,12 +12,31 @@ from pydantic import Field, InstanceOf, PrivateAttr, model_validator
from crewai.agents import CacheHandler
from crewai.agents.agent_builder.base_agent import BaseAgent
from crewai.agents.crew_agent_executor import CrewAgentExecutor
from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.agent_events import (
AgentExecutionCompletedEvent,
AgentExecutionErrorEvent,
AgentExecutionStartedEvent,
)
from crewai.events.types.knowledge_events import (
KnowledgeQueryCompletedEvent,
KnowledgeQueryFailedEvent,
KnowledgeQueryStartedEvent,
KnowledgeRetrievalCompletedEvent,
KnowledgeRetrievalStartedEvent,
KnowledgeSearchQueryFailedEvent,
)
from crewai.events.types.memory_events import (
MemoryRetrievalCompletedEvent,
MemoryRetrievalStartedEvent,
)
from crewai.knowledge.knowledge import Knowledge
from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource
from crewai.knowledge.utils.knowledge_utils import extract_knowledge_context
from crewai.lite_agent import LiteAgent, LiteAgentOutput
from crewai.llm import BaseLLM
from crewai.memory.contextual.contextual_memory import ContextualMemory
from crewai.rag.embeddings.types import EmbedderConfig
from crewai.security import Fingerprint
from crewai.task import Task
from crewai.tools import BaseTool
@@ -38,24 +50,6 @@ from crewai.utilities.agent_utils import (
)
from crewai.utilities.constants import TRAINED_AGENTS_DATA_FILE, TRAINING_DATA_FILE
from crewai.utilities.converter import generate_model_description
from crewai.events.types.agent_events import (
AgentExecutionCompletedEvent,
AgentExecutionErrorEvent,
AgentExecutionStartedEvent,
)
from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.memory_events import (
MemoryRetrievalStartedEvent,
MemoryRetrievalCompletedEvent,
)
from crewai.events.types.knowledge_events import (
KnowledgeQueryCompletedEvent,
KnowledgeQueryFailedEvent,
KnowledgeQueryStartedEvent,
KnowledgeRetrievalCompletedEvent,
KnowledgeRetrievalStartedEvent,
KnowledgeSearchQueryFailedEvent,
)
from crewai.utilities.llm_utils import create_llm
from crewai.utilities.token_counter_callback import TokenCalcHandler
from crewai.utilities.training_handler import CrewTrainingHandler
@@ -87,36 +81,36 @@ class Agent(BaseAgent):
"""
_times_executed: int = PrivateAttr(default=0)
max_execution_time: Optional[int] = Field(
max_execution_time: int | None = Field(
default=None,
description="Maximum execution time for an agent to execute a task",
)
agent_ops_agent_name: str = None # type: ignore # Incompatible types in assignment (expression has type "None", variable has type "str")
agent_ops_agent_id: str = None # type: ignore # Incompatible types in assignment (expression has type "None", variable has type "str")
step_callback: Optional[Any] = Field(
step_callback: Any | None = Field(
default=None,
description="Callback to be executed after each step of the agent execution.",
)
use_system_prompt: Optional[bool] = Field(
use_system_prompt: bool | None = Field(
default=True,
description="Use system prompt for the agent.",
)
llm: Union[str, InstanceOf[BaseLLM], Any] = Field(
llm: str | InstanceOf[BaseLLM] | Any = Field(
description="Language model that will run the agent.", default=None
)
function_calling_llm: Optional[Union[str, InstanceOf[BaseLLM], Any]] = Field(
function_calling_llm: str | InstanceOf[BaseLLM] | Any | None = Field(
description="Language model that will run the agent.", default=None
)
system_template: Optional[str] = Field(
system_template: str | None = Field(
default=None, description="System format for the agent."
)
prompt_template: Optional[str] = Field(
prompt_template: str | None = Field(
default=None, description="Prompt format for the agent."
)
response_template: Optional[str] = Field(
response_template: str | None = Field(
default=None, description="Response format for the agent."
)
allow_code_execution: Optional[bool] = Field(
allow_code_execution: bool | None = Field(
default=False, description="Enable code execution for the agent."
)
respect_context_window: bool = Field(
@@ -147,31 +141,31 @@ class Agent(BaseAgent):
default=False,
description="Whether the agent should reflect and create a plan before executing a task.",
)
max_reasoning_attempts: Optional[int] = Field(
max_reasoning_attempts: int | None = Field(
default=None,
description="Maximum number of reasoning attempts before executing the task. If None, will try until ready.",
)
embedder: Optional[Dict[str, Any]] = Field(
embedder: EmbedderConfig | None = Field(
default=None,
description="Embedder configuration for the agent.",
)
agent_knowledge_context: Optional[str] = Field(
agent_knowledge_context: str | None = Field(
default=None,
description="Knowledge context for the agent.",
)
crew_knowledge_context: Optional[str] = Field(
crew_knowledge_context: str | None = Field(
default=None,
description="Knowledge context for the crew.",
)
knowledge_search_query: Optional[str] = Field(
knowledge_search_query: str | None = Field(
default=None,
description="Knowledge search query for the agent dynamically generated by the agent.",
)
from_repository: Optional[str] = Field(
from_repository: str | None = Field(
default=None,
description="The Agent's role to be used from your repository.",
)
guardrail: Optional[Union[Callable[[Any], Tuple[bool, Any]], str]] = Field(
guardrail: Callable[[Any], tuple[bool, Any]] | str | None = Field(
default=None,
description="Function or string description of a guardrail to validate agent output",
)
@@ -180,7 +174,7 @@ class Agent(BaseAgent):
)
@model_validator(mode="before")
def validate_from_repository(cls, v):
def validate_from_repository(cls, v): # noqa: N805
if v is not None and (from_repository := v.get("from_repository")):
return load_agent_from_repository(from_repository) | v
return v
@@ -208,7 +202,7 @@ class Agent(BaseAgent):
self.cache_handler = CacheHandler()
self.set_cache_handler(self.cache_handler)
def set_knowledge(self, crew_embedder: Optional[Dict[str, Any]] = None):
def set_knowledge(self, crew_embedder: EmbedderConfig | None = None):
try:
if self.embedder is None and crew_embedder:
self.embedder = crew_embedder
@@ -224,7 +218,7 @@ class Agent(BaseAgent):
)
self.knowledge.add_sources()
except (TypeError, ValueError) as e:
raise ValueError(f"Invalid Knowledge Configuration: {str(e)}")
raise ValueError(f"Invalid Knowledge Configuration: {e!s}") from e
def _is_any_available_memory(self) -> bool:
"""Check if any memory is available."""
@@ -244,8 +238,8 @@ class Agent(BaseAgent):
def execute_task(
self,
task: Task,
context: Optional[str] = None,
tools: Optional[List[BaseTool]] = None,
context: str | None = None,
tools: list[BaseTool] | None = None,
) -> str:
"""Execute a task with the agent.
@@ -278,11 +272,9 @@ class Agent(BaseAgent):
task.description += f"\n\nReasoning Plan:\n{reasoning_output.plan.plan}"
except Exception as e:
if hasattr(self, "_logger"):
self._logger.log(
"error", f"Error during reasoning process: {str(e)}"
)
self._logger.log("error", f"Error during reasoning process: {e!s}")
else:
print(f"Error during reasoning process: {str(e)}")
print(f"Error during reasoning process: {e!s}")
self._inject_date_to_task(task)
@@ -335,7 +327,7 @@ class Agent(BaseAgent):
agent=self,
task=task,
)
memory = contextual_memory.build_context_for_task(task, context)
memory = contextual_memory.build_context_for_task(task, context) # type: ignore[arg-type]
if memory.strip() != "":
task_prompt += self.i18n.slice("memory").format(memory=memory)
@@ -525,14 +517,14 @@ class Agent(BaseAgent):
try:
return future.result(timeout=timeout)
except concurrent.futures.TimeoutError:
except concurrent.futures.TimeoutError as e:
future.cancel()
raise TimeoutError(
f"Task '{task.description}' execution timed out after {timeout} seconds. Consider increasing max_execution_time or optimizing the task."
)
) from e
except Exception as e:
future.cancel()
raise RuntimeError(f"Task execution failed: {str(e)}")
raise RuntimeError(f"Task execution failed: {e!s}") from e
def _execute_without_timeout(self, task_prompt: str, task: Task) -> str:
"""Execute a task without a timeout.
@@ -554,14 +546,14 @@ class Agent(BaseAgent):
)["output"]
def create_agent_executor(
self, tools: Optional[List[BaseTool]] = None, task=None
self, tools: list[BaseTool] | None = None, task=None
) -> None:
"""Create an agent executor for the agent.
Returns:
An instance of the CrewAgentExecutor class.
"""
raw_tools: List[BaseTool] = tools or self.tools or []
raw_tools: list[BaseTool] = tools or self.tools or []
parsed_tools = parse_tools(raw_tools)
prompt = Prompts(
@@ -587,7 +579,7 @@ class Agent(BaseAgent):
agent=self,
crew=self.crew,
tools=parsed_tools,
prompt=prompt,
prompt=prompt, # type: ignore[arg-type]
original_tools=raw_tools,
stop_words=stop_words,
max_iter=self.max_iter,
@@ -603,10 +595,9 @@ class Agent(BaseAgent):
callbacks=[TokenCalcHandler(self._token_process)],
)
def get_delegation_tools(self, agents: List[BaseAgent]):
def get_delegation_tools(self, agents: list[BaseAgent]):
agent_tools = AgentTools(agents=agents)
tools = agent_tools.tools()
return tools
return agent_tools.tools()
def get_multimodal_tools(self) -> Sequence[BaseTool]:
from crewai.tools.agent_tools.add_image_tool import AddImageTool
@@ -654,7 +645,7 @@ class Agent(BaseAgent):
)
return task_prompt
def _render_text_description(self, tools: List[Any]) -> str:
def _render_text_description(self, tools: list[Any]) -> str:
"""Render the tool name and description in plain text.
Output will be in the format of:
@@ -664,15 +655,13 @@ class Agent(BaseAgent):
search: This tool is used for search
calculator: This tool is used for math
"""
description = "\n".join(
return "\n".join(
[
f"Tool name: {tool.name}\nTool description:\n{tool.description}"
for tool in tools
]
)
return description
def _inject_date_to_task(self, task):
"""Inject the current date into the task description if inject_date is enabled."""
if self.inject_date:
@@ -696,13 +685,13 @@ class Agent(BaseAgent):
if not is_valid:
raise ValueError(f"Invalid date format: {self.date_format}")
current_date: str = datetime.now().strftime(self.date_format)
current_date = datetime.now().strftime(self.date_format)
task.description += f"\n\nCurrent Date: {current_date}"
except Exception as e:
if hasattr(self, "_logger"):
self._logger.log("warning", f"Failed to inject date: {str(e)}")
self._logger.log("warning", f"Failed to inject date: {e!s}")
else:
print(f"Warning: Failed to inject date: {str(e)}")
print(f"Warning: Failed to inject date: {e!s}")
def _validate_docker_installation(self) -> None:
"""Check if Docker is installed and running."""
@@ -713,15 +702,15 @@ class Agent(BaseAgent):
try:
subprocess.run(
["docker", "info"],
["/usr/bin/docker", "info"],
check=True,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
)
except subprocess.CalledProcessError:
except subprocess.CalledProcessError as e:
raise RuntimeError(
f"Docker is not running. Please start Docker to use code execution with agent: {self.role}"
)
) from e
def __repr__(self):
return f"Agent(role={self.role}, goal={self.goal}, backstory={self.backstory})"
@@ -796,8 +785,8 @@ class Agent(BaseAgent):
def kickoff(
self,
messages: Union[str, List[Dict[str, str]]],
response_format: Optional[Type[Any]] = None,
messages: str | list[dict[str, str]],
response_format: type[Any] | None = None,
) -> LiteAgentOutput:
"""
Execute the agent with the given messages using a LiteAgent instance.
@@ -836,8 +825,8 @@ class Agent(BaseAgent):
async def kickoff_async(
self,
messages: Union[str, List[Dict[str, str]]],
response_format: Optional[Type[Any]] = None,
messages: str | list[dict[str, str]],
response_format: type[Any] | None = None,
) -> LiteAgentOutput:
"""
Execute the agent asynchronously with the given messages using a LiteAgent instance.

View File

@@ -22,6 +22,7 @@ from crewai.agents.tools_handler import ToolsHandler
from crewai.knowledge.knowledge import Knowledge
from crewai.knowledge.knowledge_config import KnowledgeConfig
from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource
from crewai.rag.embeddings.types import EmbedderConfig
from crewai.security.security_config import SecurityConfig
from crewai.tools.base_tool import BaseTool, Tool
from crewai.utilities import I18N, Logger, RPMController
@@ -359,5 +360,5 @@ class BaseAgent(ABC, BaseModel):
self._rpm_controller = rpm_controller
self.create_agent_executor()
def set_knowledge(self, crew_embedder: dict[str, Any] | None = None):
def set_knowledge(self, crew_embedder: EmbedderConfig | None = None):
pass

View File

@@ -59,6 +59,7 @@ from crewai.memory.external.external_memory import ExternalMemory
from crewai.memory.long_term.long_term_memory import LongTermMemory
from crewai.memory.short_term.short_term_memory import ShortTermMemory
from crewai.process import Process
from crewai.rag.embeddings.types import EmbedderConfig
from crewai.rag.types import SearchResult
from crewai.security import Fingerprint, SecurityConfig
from crewai.task import Task
@@ -168,7 +169,7 @@ class Crew(FlowTrackable, BaseModel):
default=None,
description="An Instance of the ExternalMemory to be used by the Crew",
)
embedder: dict | None = Field(
embedder: EmbedderConfig | None = Field(
default=None,
description="Configuration for the embedder to be used for the crew.",
)
@@ -622,7 +623,8 @@ class Crew(FlowTrackable, BaseModel):
training_data=training_data, agent_id=str(agent.id)
)
CrewTrainingHandler(filename).save_trained_data(
agent_id=str(agent.role), trained_data=result.model_dump()
agent_id=str(agent.role),
trained_data=result.model_dump(), # type: ignore[arg-type]
)
crewai_event_bus.emit(
@@ -1057,7 +1059,10 @@ class Crew(FlowTrackable, BaseModel):
def _log_task_start(self, task: Task, role: str = "None"):
if self.output_log_file:
self._file_handler.log(
task_name=task.name, task=task.description, agent=role, status="started"
task_name=task.name, # type: ignore[arg-type]
task=task.description,
agent=role,
status="started",
)
def _update_manager_tools(
@@ -1086,7 +1091,7 @@ class Crew(FlowTrackable, BaseModel):
role = task.agent.role if task.agent is not None else "None"
if self.output_log_file:
self._file_handler.log(
task_name=task.name,
task_name=task.name, # type: ignore[arg-type]
task=task.description,
agent=role,
status="completed",

View File

@@ -1,10 +1,10 @@
import os
from typing import Any
from pydantic import BaseModel, ConfigDict, Field
from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource
from crewai.knowledge.storage.knowledge_storage import KnowledgeStorage
from crewai.rag.embeddings.types import EmbedderConfig
from crewai.rag.types import SearchResult
os.environ["TOKENIZERS_PARALLELISM"] = "false" # removes logging from fastembed
@@ -16,20 +16,20 @@ class Knowledge(BaseModel):
Args:
sources: list[BaseKnowledgeSource] = Field(default_factory=list)
storage: KnowledgeStorage | None = Field(default=None)
embedder: dict[str, Any] | None = None
embedder: EmbedderConfig | None = None
"""
sources: list[BaseKnowledgeSource] = Field(default_factory=list)
model_config = ConfigDict(arbitrary_types_allowed=True)
storage: KnowledgeStorage | None = Field(default=None)
embedder: dict[str, Any] | None = None
embedder: EmbedderConfig | None = None
collection_name: str | None = None
def __init__(
self,
collection_name: str,
sources: list[BaseKnowledgeSource],
embedder: dict[str, Any] | None = None,
embedder: EmbedderConfig | None = None,
storage: KnowledgeStorage | None = None,
**data,
):

View File

@@ -24,7 +24,10 @@ class KnowledgeStorage(BaseKnowledgeStorage):
def __init__(
self,
embedder: ProviderSpec | BaseEmbeddingsProvider | None = None,
embedder: ProviderSpec
| BaseEmbeddingsProvider
| type[BaseEmbeddingsProvider]
| None = None,
collection_name: str | None = None,
) -> None:
self.collection_name = collection_name
@@ -37,7 +40,7 @@ class KnowledgeStorage(BaseKnowledgeStorage):
)
if embedder:
embedding_function = build_embedder(embedder)
embedding_function = build_embedder(embedder) # type: ignore[arg-type]
config = ChromaDBConfig(
embedding_function=cast(
ChromaEmbeddingFunctionWrapper, embedding_function

View File

@@ -27,7 +27,10 @@ class EntityMemory(Memory):
_memory_provider: str | None = PrivateAttr()
def __init__(self, crew=None, embedder_config=None, storage=None, path=None):
memory_provider = embedder_config.get("provider") if embedder_config else None
memory_provider = None
if embedder_config and isinstance(embedder_config, dict):
memory_provider = embedder_config.get("provider")
if memory_provider == "mem0":
try:
from crewai.memory.storage.mem0_storage import Mem0Storage
@@ -35,7 +38,11 @@ class EntityMemory(Memory):
raise ImportError(
"Mem0 is not installed. Please install it with `pip install mem0ai`."
) from e
config = embedder_config.get("config") if embedder_config else None
config = (
embedder_config.get("config")
if embedder_config and isinstance(embedder_config, dict)
else None
)
storage = Mem0Storage(type="short_term", crew=crew, config=config)
else:
storage = (

View File

@@ -13,6 +13,7 @@ from crewai.events.types.memory_events import (
from crewai.memory.external.external_memory_item import ExternalMemoryItem
from crewai.memory.memory import Memory
from crewai.memory.storage.interface import Storage
from crewai.rag.embeddings.types import ProviderSpec
if TYPE_CHECKING:
from crewai.memory.storage.mem0_storage import Mem0Storage
@@ -35,7 +36,9 @@ class ExternalMemory(Memory):
}
@staticmethod
def create_storage(crew: Any, embedder_config: dict[str, Any] | None) -> Storage:
def create_storage(
crew: Any, embedder_config: dict[str, Any] | ProviderSpec | None
) -> Storage:
if not embedder_config:
raise ValueError("embedder_config is required")
@@ -159,6 +162,6 @@ class ExternalMemory(Memory):
super().set_crew(crew)
if not self.storage:
self.storage = self.create_storage(crew, self.embedder_config)
self.storage = self.create_storage(crew, self.embedder_config) # type: ignore[arg-type]
return self

View File

@@ -2,6 +2,8 @@ from typing import TYPE_CHECKING, Any, Optional
from pydantic import BaseModel
from crewai.rag.embeddings.types import EmbedderConfig
if TYPE_CHECKING:
from crewai.agent import Agent
from crewai.task import Task
@@ -12,7 +14,7 @@ class Memory(BaseModel):
Base class for memory, now supporting agent tags and generic metadata.
"""
embedder_config: dict[str, Any] | None = None
embedder_config: EmbedderConfig | dict[str, Any] | None = None
crew: Any | None = None
storage: Any

View File

@@ -29,7 +29,10 @@ class ShortTermMemory(Memory):
_memory_provider: str | None = PrivateAttr()
def __init__(self, crew=None, embedder_config=None, storage=None, path=None):
memory_provider = embedder_config.get("provider") if embedder_config else None
memory_provider = None
if embedder_config and isinstance(embedder_config, dict):
memory_provider = embedder_config.get("provider")
if memory_provider == "mem0":
try:
from crewai.memory.storage.mem0_storage import Mem0Storage
@@ -37,7 +40,11 @@ class ShortTermMemory(Memory):
raise ImportError(
"Mem0 is not installed. Please install it with `pip install mem0ai`."
) from e
config = embedder_config.get("config") if embedder_config else None
config = (
embedder_config.get("config")
if embedder_config and isinstance(embedder_config, dict)
else None
)
storage = Mem0Storage(type="short_term", crew=crew, config=config)
else:
storage = (

View File

@@ -10,7 +10,6 @@ from chromadb.api.models.AsyncCollection import AsyncCollection
from chromadb.api.models.Collection import Collection
from chromadb.api.types import (
Include,
IncludeEnum,
QueryResult,
)
@@ -142,9 +141,12 @@ def _extract_search_params(
score_threshold=kwargs.get("score_threshold"),
where=kwargs.get("where"),
where_document=kwargs.get("where_document"),
include=kwargs.get(
"include",
[IncludeEnum.metadatas, IncludeEnum.documents, IncludeEnum.distances],
include=cast(
Include,
kwargs.get(
"include",
["metadatas", "documents", "distances"],
),
),
)
@@ -193,7 +195,7 @@ def _convert_chromadb_results_to_search_results(
"""
search_results: list[SearchResult] = []
include_strings = [item.value for item in include] if include else []
include_strings = list(include) if include else []
ids = results["ids"][0] if results.get("ids") else []

View File

@@ -1,5 +1,7 @@
"""Amazon Bedrock embeddings provider."""
from typing import Any
from chromadb.utils.embedding_functions.amazon_bedrock_embedding_function import (
AmazonBedrockEmbeddingFunction,
)
@@ -7,15 +9,8 @@ from pydantic import Field
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
try:
from boto3.session import Session # type: ignore[import-untyped]
except ImportError as exc:
raise ImportError(
"boto3 is required for amazon-bedrock embeddings. Install it with: uv add boto3"
) from exc
def create_aws_session() -> Session:
def create_aws_session() -> Any:
"""Create an AWS session for Bedrock.
Returns:
@@ -53,6 +48,6 @@ class BedrockProvider(BaseEmbeddingsProvider[AmazonBedrockEmbeddingFunction]):
description="Model name to use for embeddings",
validation_alias="BEDROCK_MODEL_NAME",
)
session: Session = Field(
session: Any = Field(
default_factory=create_aws_session, description="AWS session object"
)

View File

@@ -1,6 +1,8 @@
"""Type definitions for AWS embedding providers."""
from typing import Annotated, Any, Literal, TypedDict
from typing import Annotated, Any, Literal
from typing_extensions import Required, TypedDict
class BedrockProviderConfig(TypedDict, total=False):
@@ -10,8 +12,8 @@ class BedrockProviderConfig(TypedDict, total=False):
session: Any
class BedrockProviderSpec(TypedDict):
class BedrockProviderSpec(TypedDict, total=False):
"""Bedrock provider specification."""
provider: Literal["amazon-bedrock"]
provider: Required[Literal["amazon-bedrock"]]
config: BedrockProviderConfig

View File

@@ -1,6 +1,8 @@
"""Type definitions for Cohere embedding providers."""
from typing import Annotated, Literal, TypedDict
from typing import Annotated, Literal
from typing_extensions import Required, TypedDict
class CohereProviderConfig(TypedDict, total=False):
@@ -10,8 +12,8 @@ class CohereProviderConfig(TypedDict, total=False):
model_name: Annotated[str, "large"]
class CohereProviderSpec(TypedDict):
class CohereProviderSpec(TypedDict, total=False):
"""Cohere provider specification."""
provider: Literal["cohere"]
provider: Required[Literal["cohere"]]
config: CohereProviderConfig

View File

@@ -1,8 +1,9 @@
"""Type definitions for custom embedding providers."""
from typing import Literal, TypedDict
from typing import Literal
from chromadb.api.types import EmbeddingFunction
from typing_extensions import Required, TypedDict
class CustomProviderConfig(TypedDict, total=False):
@@ -11,8 +12,8 @@ class CustomProviderConfig(TypedDict, total=False):
embedding_callable: type[EmbeddingFunction]
class CustomProviderSpec(TypedDict):
class CustomProviderSpec(TypedDict, total=False):
"""Custom provider specification."""
provider: Literal["custom"]
provider: Required[Literal["custom"]]
config: CustomProviderConfig

View File

@@ -1,6 +1,8 @@
"""Type definitions for Google embedding providers."""
from typing import Annotated, Literal, TypedDict
from typing import Annotated, Literal
from typing_extensions import Required, TypedDict
class GenerativeAiProviderConfig(TypedDict, total=False):
@@ -27,8 +29,8 @@ class VertexAIProviderConfig(TypedDict, total=False):
region: Annotated[str, "us-central1"]
class VertexAIProviderSpec(TypedDict):
class VertexAIProviderSpec(TypedDict, total=False):
"""Vertex AI provider specification."""
provider: Literal["google-vertex"]
provider: Required[Literal["google-vertex"]]
config: VertexAIProviderConfig

View File

@@ -1,6 +1,8 @@
"""Type definitions for HuggingFace embedding providers."""
from typing import Literal, TypedDict
from typing import Literal
from typing_extensions import Required, TypedDict
class HuggingFaceProviderConfig(TypedDict, total=False):
@@ -9,8 +11,8 @@ class HuggingFaceProviderConfig(TypedDict, total=False):
url: str
class HuggingFaceProviderSpec(TypedDict):
class HuggingFaceProviderSpec(TypedDict, total=False):
"""HuggingFace provider specification."""
provider: Literal["huggingface"]
provider: Required[Literal["huggingface"]]
config: HuggingFaceProviderConfig

View File

@@ -2,11 +2,6 @@
from typing import cast
import ibm_watsonx_ai.foundation_models as watson_models # type: ignore[import-not-found, import-untyped]
from ibm_watsonx_ai import Credentials # type: ignore[import-not-found, import-untyped]
from ibm_watsonx_ai.metanames import ( # type: ignore[import-not-found, import-untyped]
EmbedTextParamsMetaNames as EmbedParams,
)
from typing_extensions import Unpack
from crewai.rag.core.base_embeddings_callable import EmbeddingFunction
@@ -34,6 +29,21 @@ class WatsonEmbeddingFunction(EmbeddingFunction[Documents]):
Returns:
List of embedding vectors.
"""
try:
import ibm_watsonx_ai.foundation_models as watson_models # type: ignore[import-not-found, import-untyped]
from ibm_watsonx_ai import (
Credentials, # type: ignore[import-not-found, import-untyped]
)
from ibm_watsonx_ai.metanames import ( # type: ignore[import-not-found, import-untyped]
EmbedTextParamsMetaNames as EmbedParams,
)
except ImportError as e:
raise ImportError(
"ibm-watsonx-ai is required for watson embeddings. "
"Install it with: uv add ibm-watsonx-ai"
) from e
if isinstance(input, str):
input = [input]

View File

@@ -1,6 +1,8 @@
"""Type definitions for IBM Watson embedding providers."""
from typing import Annotated, Any, Literal, TypedDict
from typing import Annotated, Any, Literal
from typing_extensions import Required, TypedDict
class WatsonProviderConfig(TypedDict, total=False):
@@ -35,8 +37,8 @@ class WatsonProviderConfig(TypedDict, total=False):
proxies: dict
class WatsonProviderSpec(TypedDict):
class WatsonProviderSpec(TypedDict, total=False):
"""Watson provider specification."""
provider: Literal["watson"]
provider: Required[Literal["watson"]]
config: WatsonProviderConfig

View File

@@ -1,9 +1,7 @@
"""IBM Watson embeddings provider."""
from ibm_watsonx_ai import ( # type: ignore[import-not-found,import-untyped]
APIClient,
Credentials,
)
from typing import Any
from pydantic import Field, model_validator
from typing_extensions import Self
@@ -28,9 +26,7 @@ class WatsonProvider(BaseEmbeddingsProvider[WatsonEmbeddingFunction]):
params: dict[str, str | dict[str, str]] | None = Field(
default=None, description="Additional parameters"
)
credentials: Credentials | None = Field(
default=None, description="Watson credentials"
)
credentials: Any | None = Field(default=None, description="Watson credentials")
project_id: str | None = Field(
default=None,
description="Watson project ID",
@@ -39,7 +35,7 @@ class WatsonProvider(BaseEmbeddingsProvider[WatsonEmbeddingFunction]):
space_id: str | None = Field(
default=None, description="Watson space ID", validation_alias="WATSON_SPACE_ID"
)
api_client: APIClient | None = Field(default=None, description="Watson API client")
api_client: Any | None = Field(default=None, description="Watson API client")
verify: bool | str | None = Field(
default=None, description="SSL verification", validation_alias="WATSON_VERIFY"
)

View File

@@ -1,6 +1,8 @@
"""Type definitions for Instructor embedding providers."""
from typing import Annotated, Literal, TypedDict
from typing import Annotated, Literal
from typing_extensions import Required, TypedDict
class InstructorProviderConfig(TypedDict, total=False):
@@ -11,8 +13,8 @@ class InstructorProviderConfig(TypedDict, total=False):
instruction: str
class InstructorProviderSpec(TypedDict):
class InstructorProviderSpec(TypedDict, total=False):
"""Instructor provider specification."""
provider: Literal["instructor"]
provider: Required[Literal["instructor"]]
config: InstructorProviderConfig

View File

@@ -1,6 +1,8 @@
"""Type definitions for Jina embedding providers."""
from typing import Annotated, Literal, TypedDict
from typing import Annotated, Literal
from typing_extensions import Required, TypedDict
class JinaProviderConfig(TypedDict, total=False):
@@ -10,8 +12,8 @@ class JinaProviderConfig(TypedDict, total=False):
model_name: Annotated[str, "jina-embeddings-v2-base-en"]
class JinaProviderSpec(TypedDict):
class JinaProviderSpec(TypedDict, total=False):
"""Jina provider specification."""
provider: Literal["jina"]
provider: Required[Literal["jina"]]
config: JinaProviderConfig

View File

@@ -1,6 +1,8 @@
"""Type definitions for Microsoft Azure embedding providers."""
from typing import Annotated, Any, Literal, TypedDict
from typing import Annotated, Any, Literal
from typing_extensions import Required, TypedDict
class AzureProviderConfig(TypedDict, total=False):
@@ -17,8 +19,8 @@ class AzureProviderConfig(TypedDict, total=False):
organization_id: str
class AzureProviderSpec(TypedDict):
class AzureProviderSpec(TypedDict, total=False):
"""Azure provider specification."""
provider: Literal["azure"]
provider: Required[Literal["azure"]]
config: AzureProviderConfig

View File

@@ -1,6 +1,8 @@
"""Type definitions for Ollama embedding providers."""
from typing import Annotated, Literal, TypedDict
from typing import Annotated, Literal
from typing_extensions import Required, TypedDict
class OllamaProviderConfig(TypedDict, total=False):
@@ -10,8 +12,8 @@ class OllamaProviderConfig(TypedDict, total=False):
model_name: str
class OllamaProviderSpec(TypedDict):
class OllamaProviderSpec(TypedDict, total=False):
"""Ollama provider specification."""
provider: Literal["ollama"]
provider: Required[Literal["ollama"]]
config: OllamaProviderConfig

View File

@@ -1,6 +1,8 @@
"""Type definitions for ONNX embedding providers."""
from typing import Literal, TypedDict
from typing import Literal
from typing_extensions import Required, TypedDict
class ONNXProviderConfig(TypedDict, total=False):
@@ -9,8 +11,8 @@ class ONNXProviderConfig(TypedDict, total=False):
preferred_providers: list[str]
class ONNXProviderSpec(TypedDict):
class ONNXProviderSpec(TypedDict, total=False):
"""ONNX provider specification."""
provider: Literal["onnx"]
provider: Required[Literal["onnx"]]
config: ONNXProviderConfig

View File

@@ -17,8 +17,8 @@ class OpenAIProvider(BaseEmbeddingsProvider[OpenAIEmbeddingFunction]):
default=OpenAIEmbeddingFunction,
description="OpenAI embedding function class",
)
api_key: str = Field(
description="OpenAI API key", validation_alias="OPENAI_API_KEY"
api_key: str | None = Field(
default=None, description="OpenAI API key", validation_alias="OPENAI_API_KEY"
)
model_name: str = Field(
default="text-embedding-ada-002",

View File

@@ -1,6 +1,8 @@
"""Type definitions for OpenAI embedding providers."""
from typing import Annotated, Any, Literal, TypedDict
from typing import Annotated, Any, Literal
from typing_extensions import Required, TypedDict
class OpenAIProviderConfig(TypedDict, total=False):
@@ -17,8 +19,8 @@ class OpenAIProviderConfig(TypedDict, total=False):
organization_id: str
class OpenAIProviderSpec(TypedDict):
class OpenAIProviderSpec(TypedDict, total=False):
"""OpenAI provider specification."""
provider: Literal["openai"]
provider: Required[Literal["openai"]]
config: OpenAIProviderConfig

View File

@@ -1,6 +1,8 @@
"""Type definitions for OpenCLIP embedding providers."""
from typing import Annotated, Literal, TypedDict
from typing import Annotated, Literal
from typing_extensions import Required, TypedDict
class OpenCLIPProviderConfig(TypedDict, total=False):
@@ -14,5 +16,5 @@ class OpenCLIPProviderConfig(TypedDict, total=False):
class OpenCLIPProviderSpec(TypedDict):
"""OpenCLIP provider specification."""
provider: Literal["openclip"]
provider: Required[Literal["openclip"]]
config: OpenCLIPProviderConfig

View File

@@ -1,6 +1,8 @@
"""Type definitions for Roboflow embedding providers."""
from typing import Annotated, Literal, TypedDict
from typing import Annotated, Literal
from typing_extensions import Required, TypedDict
class RoboflowProviderConfig(TypedDict, total=False):
@@ -13,5 +15,5 @@ class RoboflowProviderConfig(TypedDict, total=False):
class RoboflowProviderSpec(TypedDict):
"""Roboflow provider specification."""
provider: Literal["roboflow"]
provider: Required[Literal["roboflow"]]
config: RoboflowProviderConfig

View File

@@ -1,6 +1,8 @@
"""Type definitions for SentenceTransformer embedding providers."""
from typing import Annotated, Literal, TypedDict
from typing import Annotated, Literal
from typing_extensions import Required, TypedDict
class SentenceTransformerProviderConfig(TypedDict, total=False):
@@ -14,5 +16,5 @@ class SentenceTransformerProviderConfig(TypedDict, total=False):
class SentenceTransformerProviderSpec(TypedDict):
"""SentenceTransformer provider specification."""
provider: Literal["sentence-transformer"]
provider: Required[Literal["sentence-transformer"]]
config: SentenceTransformerProviderConfig

View File

@@ -1,6 +1,8 @@
"""Type definitions for Text2Vec embedding providers."""
from typing import Annotated, Literal, TypedDict
from typing import Annotated, Literal
from typing_extensions import Required, TypedDict
class Text2VecProviderConfig(TypedDict, total=False):
@@ -12,5 +14,5 @@ class Text2VecProviderConfig(TypedDict, total=False):
class Text2VecProviderSpec(TypedDict):
"""Text2Vec provider specification."""
provider: Literal["text2vec"]
provider: Required[Literal["text2vec"]]
config: Text2VecProviderConfig

View File

@@ -2,7 +2,6 @@
from typing import cast
import voyageai
from typing_extensions import Unpack
from crewai.rag.core.base_embeddings_callable import EmbeddingFunction
@@ -19,6 +18,14 @@ class VoyageAIEmbeddingFunction(EmbeddingFunction[Documents]):
Args:
**kwargs: Configuration parameters for VoyageAI.
"""
try:
import voyageai # type: ignore[import-not-found]
except ImportError as e:
raise ImportError(
"voyageai is required for voyageai embeddings. "
"Install it with: uv add voyageai"
) from e
self._config = kwargs
self._client = voyageai.Client(
api_key=kwargs["api_key"],
@@ -35,6 +42,7 @@ class VoyageAIEmbeddingFunction(EmbeddingFunction[Documents]):
Returns:
List of embedding vectors.
"""
if isinstance(input, str):
input = [input]

View File

@@ -1,6 +1,8 @@
"""Type definitions for VoyageAI embedding providers."""
from typing import Annotated, Literal, TypedDict
from typing import Annotated, Literal
from typing_extensions import Required, TypedDict
class VoyageAIProviderConfig(TypedDict, total=False):
@@ -19,5 +21,5 @@ class VoyageAIProviderConfig(TypedDict, total=False):
class VoyageAIProviderSpec(TypedDict):
"""VoyageAI provider specification."""
provider: Literal["voyageai"]
provider: Required[Literal["voyageai"]]
config: VoyageAIProviderConfig

View File

@@ -1,7 +1,8 @@
"""Type definitions for the embeddings module."""
from typing import Literal
from typing import Literal, TypeAlias
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
from crewai.rag.embeddings.providers.aws.types import BedrockProviderSpec
from crewai.rag.embeddings.providers.cohere.types import CohereProviderSpec
from crewai.rag.embeddings.providers.custom.types import CustomProviderSpec
@@ -66,3 +67,7 @@ AllowedEmbeddingProviders = Literal[
"voyageai",
"watson",
]
EmbedderConfig: TypeAlias = (
ProviderSpec | BaseEmbeddingsProvider | type[BaseEmbeddingsProvider]
)