mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-14 18:48:29 +00:00
feat: upgrade chromadb to v1.1.0, improve types
- update imports and include handling for chromadb v1.1.0 - fix mypy and typing_compat issues (required, typeddict, voyageai) - refine embedderconfig typing and allow base provider instances - handle mem0 as special case for external memory storage - bump tools and clean up redundant deps
This commit is contained in:
@@ -1,17 +1,10 @@
|
||||
import shutil
|
||||
import subprocess
|
||||
import time
|
||||
from collections.abc import Callable, Sequence
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
Dict,
|
||||
List,
|
||||
Literal,
|
||||
Optional,
|
||||
Sequence,
|
||||
Tuple,
|
||||
Type,
|
||||
Union,
|
||||
)
|
||||
|
||||
from pydantic import Field, InstanceOf, PrivateAttr, model_validator
|
||||
@@ -19,12 +12,31 @@ from pydantic import Field, InstanceOf, PrivateAttr, model_validator
|
||||
from crewai.agents import CacheHandler
|
||||
from crewai.agents.agent_builder.base_agent import BaseAgent
|
||||
from crewai.agents.crew_agent_executor import CrewAgentExecutor
|
||||
from crewai.events.event_bus import crewai_event_bus
|
||||
from crewai.events.types.agent_events import (
|
||||
AgentExecutionCompletedEvent,
|
||||
AgentExecutionErrorEvent,
|
||||
AgentExecutionStartedEvent,
|
||||
)
|
||||
from crewai.events.types.knowledge_events import (
|
||||
KnowledgeQueryCompletedEvent,
|
||||
KnowledgeQueryFailedEvent,
|
||||
KnowledgeQueryStartedEvent,
|
||||
KnowledgeRetrievalCompletedEvent,
|
||||
KnowledgeRetrievalStartedEvent,
|
||||
KnowledgeSearchQueryFailedEvent,
|
||||
)
|
||||
from crewai.events.types.memory_events import (
|
||||
MemoryRetrievalCompletedEvent,
|
||||
MemoryRetrievalStartedEvent,
|
||||
)
|
||||
from crewai.knowledge.knowledge import Knowledge
|
||||
from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource
|
||||
from crewai.knowledge.utils.knowledge_utils import extract_knowledge_context
|
||||
from crewai.lite_agent import LiteAgent, LiteAgentOutput
|
||||
from crewai.llm import BaseLLM
|
||||
from crewai.memory.contextual.contextual_memory import ContextualMemory
|
||||
from crewai.rag.embeddings.types import EmbedderConfig
|
||||
from crewai.security import Fingerprint
|
||||
from crewai.task import Task
|
||||
from crewai.tools import BaseTool
|
||||
@@ -38,24 +50,6 @@ from crewai.utilities.agent_utils import (
|
||||
)
|
||||
from crewai.utilities.constants import TRAINED_AGENTS_DATA_FILE, TRAINING_DATA_FILE
|
||||
from crewai.utilities.converter import generate_model_description
|
||||
from crewai.events.types.agent_events import (
|
||||
AgentExecutionCompletedEvent,
|
||||
AgentExecutionErrorEvent,
|
||||
AgentExecutionStartedEvent,
|
||||
)
|
||||
from crewai.events.event_bus import crewai_event_bus
|
||||
from crewai.events.types.memory_events import (
|
||||
MemoryRetrievalStartedEvent,
|
||||
MemoryRetrievalCompletedEvent,
|
||||
)
|
||||
from crewai.events.types.knowledge_events import (
|
||||
KnowledgeQueryCompletedEvent,
|
||||
KnowledgeQueryFailedEvent,
|
||||
KnowledgeQueryStartedEvent,
|
||||
KnowledgeRetrievalCompletedEvent,
|
||||
KnowledgeRetrievalStartedEvent,
|
||||
KnowledgeSearchQueryFailedEvent,
|
||||
)
|
||||
from crewai.utilities.llm_utils import create_llm
|
||||
from crewai.utilities.token_counter_callback import TokenCalcHandler
|
||||
from crewai.utilities.training_handler import CrewTrainingHandler
|
||||
@@ -87,36 +81,36 @@ class Agent(BaseAgent):
|
||||
"""
|
||||
|
||||
_times_executed: int = PrivateAttr(default=0)
|
||||
max_execution_time: Optional[int] = Field(
|
||||
max_execution_time: int | None = Field(
|
||||
default=None,
|
||||
description="Maximum execution time for an agent to execute a task",
|
||||
)
|
||||
agent_ops_agent_name: str = None # type: ignore # Incompatible types in assignment (expression has type "None", variable has type "str")
|
||||
agent_ops_agent_id: str = None # type: ignore # Incompatible types in assignment (expression has type "None", variable has type "str")
|
||||
step_callback: Optional[Any] = Field(
|
||||
step_callback: Any | None = Field(
|
||||
default=None,
|
||||
description="Callback to be executed after each step of the agent execution.",
|
||||
)
|
||||
use_system_prompt: Optional[bool] = Field(
|
||||
use_system_prompt: bool | None = Field(
|
||||
default=True,
|
||||
description="Use system prompt for the agent.",
|
||||
)
|
||||
llm: Union[str, InstanceOf[BaseLLM], Any] = Field(
|
||||
llm: str | InstanceOf[BaseLLM] | Any = Field(
|
||||
description="Language model that will run the agent.", default=None
|
||||
)
|
||||
function_calling_llm: Optional[Union[str, InstanceOf[BaseLLM], Any]] = Field(
|
||||
function_calling_llm: str | InstanceOf[BaseLLM] | Any | None = Field(
|
||||
description="Language model that will run the agent.", default=None
|
||||
)
|
||||
system_template: Optional[str] = Field(
|
||||
system_template: str | None = Field(
|
||||
default=None, description="System format for the agent."
|
||||
)
|
||||
prompt_template: Optional[str] = Field(
|
||||
prompt_template: str | None = Field(
|
||||
default=None, description="Prompt format for the agent."
|
||||
)
|
||||
response_template: Optional[str] = Field(
|
||||
response_template: str | None = Field(
|
||||
default=None, description="Response format for the agent."
|
||||
)
|
||||
allow_code_execution: Optional[bool] = Field(
|
||||
allow_code_execution: bool | None = Field(
|
||||
default=False, description="Enable code execution for the agent."
|
||||
)
|
||||
respect_context_window: bool = Field(
|
||||
@@ -147,31 +141,31 @@ class Agent(BaseAgent):
|
||||
default=False,
|
||||
description="Whether the agent should reflect and create a plan before executing a task.",
|
||||
)
|
||||
max_reasoning_attempts: Optional[int] = Field(
|
||||
max_reasoning_attempts: int | None = Field(
|
||||
default=None,
|
||||
description="Maximum number of reasoning attempts before executing the task. If None, will try until ready.",
|
||||
)
|
||||
embedder: Optional[Dict[str, Any]] = Field(
|
||||
embedder: EmbedderConfig | None = Field(
|
||||
default=None,
|
||||
description="Embedder configuration for the agent.",
|
||||
)
|
||||
agent_knowledge_context: Optional[str] = Field(
|
||||
agent_knowledge_context: str | None = Field(
|
||||
default=None,
|
||||
description="Knowledge context for the agent.",
|
||||
)
|
||||
crew_knowledge_context: Optional[str] = Field(
|
||||
crew_knowledge_context: str | None = Field(
|
||||
default=None,
|
||||
description="Knowledge context for the crew.",
|
||||
)
|
||||
knowledge_search_query: Optional[str] = Field(
|
||||
knowledge_search_query: str | None = Field(
|
||||
default=None,
|
||||
description="Knowledge search query for the agent dynamically generated by the agent.",
|
||||
)
|
||||
from_repository: Optional[str] = Field(
|
||||
from_repository: str | None = Field(
|
||||
default=None,
|
||||
description="The Agent's role to be used from your repository.",
|
||||
)
|
||||
guardrail: Optional[Union[Callable[[Any], Tuple[bool, Any]], str]] = Field(
|
||||
guardrail: Callable[[Any], tuple[bool, Any]] | str | None = Field(
|
||||
default=None,
|
||||
description="Function or string description of a guardrail to validate agent output",
|
||||
)
|
||||
@@ -180,7 +174,7 @@ class Agent(BaseAgent):
|
||||
)
|
||||
|
||||
@model_validator(mode="before")
|
||||
def validate_from_repository(cls, v):
|
||||
def validate_from_repository(cls, v): # noqa: N805
|
||||
if v is not None and (from_repository := v.get("from_repository")):
|
||||
return load_agent_from_repository(from_repository) | v
|
||||
return v
|
||||
@@ -208,7 +202,7 @@ class Agent(BaseAgent):
|
||||
self.cache_handler = CacheHandler()
|
||||
self.set_cache_handler(self.cache_handler)
|
||||
|
||||
def set_knowledge(self, crew_embedder: Optional[Dict[str, Any]] = None):
|
||||
def set_knowledge(self, crew_embedder: EmbedderConfig | None = None):
|
||||
try:
|
||||
if self.embedder is None and crew_embedder:
|
||||
self.embedder = crew_embedder
|
||||
@@ -224,7 +218,7 @@ class Agent(BaseAgent):
|
||||
)
|
||||
self.knowledge.add_sources()
|
||||
except (TypeError, ValueError) as e:
|
||||
raise ValueError(f"Invalid Knowledge Configuration: {str(e)}")
|
||||
raise ValueError(f"Invalid Knowledge Configuration: {e!s}") from e
|
||||
|
||||
def _is_any_available_memory(self) -> bool:
|
||||
"""Check if any memory is available."""
|
||||
@@ -244,8 +238,8 @@ class Agent(BaseAgent):
|
||||
def execute_task(
|
||||
self,
|
||||
task: Task,
|
||||
context: Optional[str] = None,
|
||||
tools: Optional[List[BaseTool]] = None,
|
||||
context: str | None = None,
|
||||
tools: list[BaseTool] | None = None,
|
||||
) -> str:
|
||||
"""Execute a task with the agent.
|
||||
|
||||
@@ -278,11 +272,9 @@ class Agent(BaseAgent):
|
||||
task.description += f"\n\nReasoning Plan:\n{reasoning_output.plan.plan}"
|
||||
except Exception as e:
|
||||
if hasattr(self, "_logger"):
|
||||
self._logger.log(
|
||||
"error", f"Error during reasoning process: {str(e)}"
|
||||
)
|
||||
self._logger.log("error", f"Error during reasoning process: {e!s}")
|
||||
else:
|
||||
print(f"Error during reasoning process: {str(e)}")
|
||||
print(f"Error during reasoning process: {e!s}")
|
||||
|
||||
self._inject_date_to_task(task)
|
||||
|
||||
@@ -335,7 +327,7 @@ class Agent(BaseAgent):
|
||||
agent=self,
|
||||
task=task,
|
||||
)
|
||||
memory = contextual_memory.build_context_for_task(task, context)
|
||||
memory = contextual_memory.build_context_for_task(task, context) # type: ignore[arg-type]
|
||||
if memory.strip() != "":
|
||||
task_prompt += self.i18n.slice("memory").format(memory=memory)
|
||||
|
||||
@@ -525,14 +517,14 @@ class Agent(BaseAgent):
|
||||
|
||||
try:
|
||||
return future.result(timeout=timeout)
|
||||
except concurrent.futures.TimeoutError:
|
||||
except concurrent.futures.TimeoutError as e:
|
||||
future.cancel()
|
||||
raise TimeoutError(
|
||||
f"Task '{task.description}' execution timed out after {timeout} seconds. Consider increasing max_execution_time or optimizing the task."
|
||||
)
|
||||
) from e
|
||||
except Exception as e:
|
||||
future.cancel()
|
||||
raise RuntimeError(f"Task execution failed: {str(e)}")
|
||||
raise RuntimeError(f"Task execution failed: {e!s}") from e
|
||||
|
||||
def _execute_without_timeout(self, task_prompt: str, task: Task) -> str:
|
||||
"""Execute a task without a timeout.
|
||||
@@ -554,14 +546,14 @@ class Agent(BaseAgent):
|
||||
)["output"]
|
||||
|
||||
def create_agent_executor(
|
||||
self, tools: Optional[List[BaseTool]] = None, task=None
|
||||
self, tools: list[BaseTool] | None = None, task=None
|
||||
) -> None:
|
||||
"""Create an agent executor for the agent.
|
||||
|
||||
Returns:
|
||||
An instance of the CrewAgentExecutor class.
|
||||
"""
|
||||
raw_tools: List[BaseTool] = tools or self.tools or []
|
||||
raw_tools: list[BaseTool] = tools or self.tools or []
|
||||
parsed_tools = parse_tools(raw_tools)
|
||||
|
||||
prompt = Prompts(
|
||||
@@ -587,7 +579,7 @@ class Agent(BaseAgent):
|
||||
agent=self,
|
||||
crew=self.crew,
|
||||
tools=parsed_tools,
|
||||
prompt=prompt,
|
||||
prompt=prompt, # type: ignore[arg-type]
|
||||
original_tools=raw_tools,
|
||||
stop_words=stop_words,
|
||||
max_iter=self.max_iter,
|
||||
@@ -603,10 +595,9 @@ class Agent(BaseAgent):
|
||||
callbacks=[TokenCalcHandler(self._token_process)],
|
||||
)
|
||||
|
||||
def get_delegation_tools(self, agents: List[BaseAgent]):
|
||||
def get_delegation_tools(self, agents: list[BaseAgent]):
|
||||
agent_tools = AgentTools(agents=agents)
|
||||
tools = agent_tools.tools()
|
||||
return tools
|
||||
return agent_tools.tools()
|
||||
|
||||
def get_multimodal_tools(self) -> Sequence[BaseTool]:
|
||||
from crewai.tools.agent_tools.add_image_tool import AddImageTool
|
||||
@@ -654,7 +645,7 @@ class Agent(BaseAgent):
|
||||
)
|
||||
return task_prompt
|
||||
|
||||
def _render_text_description(self, tools: List[Any]) -> str:
|
||||
def _render_text_description(self, tools: list[Any]) -> str:
|
||||
"""Render the tool name and description in plain text.
|
||||
|
||||
Output will be in the format of:
|
||||
@@ -664,15 +655,13 @@ class Agent(BaseAgent):
|
||||
search: This tool is used for search
|
||||
calculator: This tool is used for math
|
||||
"""
|
||||
description = "\n".join(
|
||||
return "\n".join(
|
||||
[
|
||||
f"Tool name: {tool.name}\nTool description:\n{tool.description}"
|
||||
for tool in tools
|
||||
]
|
||||
)
|
||||
|
||||
return description
|
||||
|
||||
def _inject_date_to_task(self, task):
|
||||
"""Inject the current date into the task description if inject_date is enabled."""
|
||||
if self.inject_date:
|
||||
@@ -696,13 +685,13 @@ class Agent(BaseAgent):
|
||||
if not is_valid:
|
||||
raise ValueError(f"Invalid date format: {self.date_format}")
|
||||
|
||||
current_date: str = datetime.now().strftime(self.date_format)
|
||||
current_date = datetime.now().strftime(self.date_format)
|
||||
task.description += f"\n\nCurrent Date: {current_date}"
|
||||
except Exception as e:
|
||||
if hasattr(self, "_logger"):
|
||||
self._logger.log("warning", f"Failed to inject date: {str(e)}")
|
||||
self._logger.log("warning", f"Failed to inject date: {e!s}")
|
||||
else:
|
||||
print(f"Warning: Failed to inject date: {str(e)}")
|
||||
print(f"Warning: Failed to inject date: {e!s}")
|
||||
|
||||
def _validate_docker_installation(self) -> None:
|
||||
"""Check if Docker is installed and running."""
|
||||
@@ -713,15 +702,15 @@ class Agent(BaseAgent):
|
||||
|
||||
try:
|
||||
subprocess.run(
|
||||
["docker", "info"],
|
||||
["/usr/bin/docker", "info"],
|
||||
check=True,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE,
|
||||
)
|
||||
except subprocess.CalledProcessError:
|
||||
except subprocess.CalledProcessError as e:
|
||||
raise RuntimeError(
|
||||
f"Docker is not running. Please start Docker to use code execution with agent: {self.role}"
|
||||
)
|
||||
) from e
|
||||
|
||||
def __repr__(self):
|
||||
return f"Agent(role={self.role}, goal={self.goal}, backstory={self.backstory})"
|
||||
@@ -796,8 +785,8 @@ class Agent(BaseAgent):
|
||||
|
||||
def kickoff(
|
||||
self,
|
||||
messages: Union[str, List[Dict[str, str]]],
|
||||
response_format: Optional[Type[Any]] = None,
|
||||
messages: str | list[dict[str, str]],
|
||||
response_format: type[Any] | None = None,
|
||||
) -> LiteAgentOutput:
|
||||
"""
|
||||
Execute the agent with the given messages using a LiteAgent instance.
|
||||
@@ -836,8 +825,8 @@ class Agent(BaseAgent):
|
||||
|
||||
async def kickoff_async(
|
||||
self,
|
||||
messages: Union[str, List[Dict[str, str]]],
|
||||
response_format: Optional[Type[Any]] = None,
|
||||
messages: str | list[dict[str, str]],
|
||||
response_format: type[Any] | None = None,
|
||||
) -> LiteAgentOutput:
|
||||
"""
|
||||
Execute the agent asynchronously with the given messages using a LiteAgent instance.
|
||||
|
||||
@@ -22,6 +22,7 @@ from crewai.agents.tools_handler import ToolsHandler
|
||||
from crewai.knowledge.knowledge import Knowledge
|
||||
from crewai.knowledge.knowledge_config import KnowledgeConfig
|
||||
from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource
|
||||
from crewai.rag.embeddings.types import EmbedderConfig
|
||||
from crewai.security.security_config import SecurityConfig
|
||||
from crewai.tools.base_tool import BaseTool, Tool
|
||||
from crewai.utilities import I18N, Logger, RPMController
|
||||
@@ -359,5 +360,5 @@ class BaseAgent(ABC, BaseModel):
|
||||
self._rpm_controller = rpm_controller
|
||||
self.create_agent_executor()
|
||||
|
||||
def set_knowledge(self, crew_embedder: dict[str, Any] | None = None):
|
||||
def set_knowledge(self, crew_embedder: EmbedderConfig | None = None):
|
||||
pass
|
||||
|
||||
@@ -59,6 +59,7 @@ from crewai.memory.external.external_memory import ExternalMemory
|
||||
from crewai.memory.long_term.long_term_memory import LongTermMemory
|
||||
from crewai.memory.short_term.short_term_memory import ShortTermMemory
|
||||
from crewai.process import Process
|
||||
from crewai.rag.embeddings.types import EmbedderConfig
|
||||
from crewai.rag.types import SearchResult
|
||||
from crewai.security import Fingerprint, SecurityConfig
|
||||
from crewai.task import Task
|
||||
@@ -168,7 +169,7 @@ class Crew(FlowTrackable, BaseModel):
|
||||
default=None,
|
||||
description="An Instance of the ExternalMemory to be used by the Crew",
|
||||
)
|
||||
embedder: dict | None = Field(
|
||||
embedder: EmbedderConfig | None = Field(
|
||||
default=None,
|
||||
description="Configuration for the embedder to be used for the crew.",
|
||||
)
|
||||
@@ -622,7 +623,8 @@ class Crew(FlowTrackable, BaseModel):
|
||||
training_data=training_data, agent_id=str(agent.id)
|
||||
)
|
||||
CrewTrainingHandler(filename).save_trained_data(
|
||||
agent_id=str(agent.role), trained_data=result.model_dump()
|
||||
agent_id=str(agent.role),
|
||||
trained_data=result.model_dump(), # type: ignore[arg-type]
|
||||
)
|
||||
|
||||
crewai_event_bus.emit(
|
||||
@@ -1057,7 +1059,10 @@ class Crew(FlowTrackable, BaseModel):
|
||||
def _log_task_start(self, task: Task, role: str = "None"):
|
||||
if self.output_log_file:
|
||||
self._file_handler.log(
|
||||
task_name=task.name, task=task.description, agent=role, status="started"
|
||||
task_name=task.name, # type: ignore[arg-type]
|
||||
task=task.description,
|
||||
agent=role,
|
||||
status="started",
|
||||
)
|
||||
|
||||
def _update_manager_tools(
|
||||
@@ -1086,7 +1091,7 @@ class Crew(FlowTrackable, BaseModel):
|
||||
role = task.agent.role if task.agent is not None else "None"
|
||||
if self.output_log_file:
|
||||
self._file_handler.log(
|
||||
task_name=task.name,
|
||||
task_name=task.name, # type: ignore[arg-type]
|
||||
task=task.description,
|
||||
agent=role,
|
||||
status="completed",
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
import os
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource
|
||||
from crewai.knowledge.storage.knowledge_storage import KnowledgeStorage
|
||||
from crewai.rag.embeddings.types import EmbedderConfig
|
||||
from crewai.rag.types import SearchResult
|
||||
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "false" # removes logging from fastembed
|
||||
@@ -16,20 +16,20 @@ class Knowledge(BaseModel):
|
||||
Args:
|
||||
sources: list[BaseKnowledgeSource] = Field(default_factory=list)
|
||||
storage: KnowledgeStorage | None = Field(default=None)
|
||||
embedder: dict[str, Any] | None = None
|
||||
embedder: EmbedderConfig | None = None
|
||||
"""
|
||||
|
||||
sources: list[BaseKnowledgeSource] = Field(default_factory=list)
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
storage: KnowledgeStorage | None = Field(default=None)
|
||||
embedder: dict[str, Any] | None = None
|
||||
embedder: EmbedderConfig | None = None
|
||||
collection_name: str | None = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
collection_name: str,
|
||||
sources: list[BaseKnowledgeSource],
|
||||
embedder: dict[str, Any] | None = None,
|
||||
embedder: EmbedderConfig | None = None,
|
||||
storage: KnowledgeStorage | None = None,
|
||||
**data,
|
||||
):
|
||||
|
||||
@@ -24,7 +24,10 @@ class KnowledgeStorage(BaseKnowledgeStorage):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
embedder: ProviderSpec | BaseEmbeddingsProvider | None = None,
|
||||
embedder: ProviderSpec
|
||||
| BaseEmbeddingsProvider
|
||||
| type[BaseEmbeddingsProvider]
|
||||
| None = None,
|
||||
collection_name: str | None = None,
|
||||
) -> None:
|
||||
self.collection_name = collection_name
|
||||
@@ -37,7 +40,7 @@ class KnowledgeStorage(BaseKnowledgeStorage):
|
||||
)
|
||||
|
||||
if embedder:
|
||||
embedding_function = build_embedder(embedder)
|
||||
embedding_function = build_embedder(embedder) # type: ignore[arg-type]
|
||||
config = ChromaDBConfig(
|
||||
embedding_function=cast(
|
||||
ChromaEmbeddingFunctionWrapper, embedding_function
|
||||
|
||||
@@ -27,7 +27,10 @@ class EntityMemory(Memory):
|
||||
_memory_provider: str | None = PrivateAttr()
|
||||
|
||||
def __init__(self, crew=None, embedder_config=None, storage=None, path=None):
|
||||
memory_provider = embedder_config.get("provider") if embedder_config else None
|
||||
memory_provider = None
|
||||
if embedder_config and isinstance(embedder_config, dict):
|
||||
memory_provider = embedder_config.get("provider")
|
||||
|
||||
if memory_provider == "mem0":
|
||||
try:
|
||||
from crewai.memory.storage.mem0_storage import Mem0Storage
|
||||
@@ -35,7 +38,11 @@ class EntityMemory(Memory):
|
||||
raise ImportError(
|
||||
"Mem0 is not installed. Please install it with `pip install mem0ai`."
|
||||
) from e
|
||||
config = embedder_config.get("config") if embedder_config else None
|
||||
config = (
|
||||
embedder_config.get("config")
|
||||
if embedder_config and isinstance(embedder_config, dict)
|
||||
else None
|
||||
)
|
||||
storage = Mem0Storage(type="short_term", crew=crew, config=config)
|
||||
else:
|
||||
storage = (
|
||||
|
||||
@@ -13,6 +13,7 @@ from crewai.events.types.memory_events import (
|
||||
from crewai.memory.external.external_memory_item import ExternalMemoryItem
|
||||
from crewai.memory.memory import Memory
|
||||
from crewai.memory.storage.interface import Storage
|
||||
from crewai.rag.embeddings.types import ProviderSpec
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from crewai.memory.storage.mem0_storage import Mem0Storage
|
||||
@@ -35,7 +36,9 @@ class ExternalMemory(Memory):
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def create_storage(crew: Any, embedder_config: dict[str, Any] | None) -> Storage:
|
||||
def create_storage(
|
||||
crew: Any, embedder_config: dict[str, Any] | ProviderSpec | None
|
||||
) -> Storage:
|
||||
if not embedder_config:
|
||||
raise ValueError("embedder_config is required")
|
||||
|
||||
@@ -159,6 +162,6 @@ class ExternalMemory(Memory):
|
||||
super().set_crew(crew)
|
||||
|
||||
if not self.storage:
|
||||
self.storage = self.create_storage(crew, self.embedder_config)
|
||||
self.storage = self.create_storage(crew, self.embedder_config) # type: ignore[arg-type]
|
||||
|
||||
return self
|
||||
|
||||
@@ -2,6 +2,8 @@ from typing import TYPE_CHECKING, Any, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from crewai.rag.embeddings.types import EmbedderConfig
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from crewai.agent import Agent
|
||||
from crewai.task import Task
|
||||
@@ -12,7 +14,7 @@ class Memory(BaseModel):
|
||||
Base class for memory, now supporting agent tags and generic metadata.
|
||||
"""
|
||||
|
||||
embedder_config: dict[str, Any] | None = None
|
||||
embedder_config: EmbedderConfig | dict[str, Any] | None = None
|
||||
crew: Any | None = None
|
||||
|
||||
storage: Any
|
||||
|
||||
@@ -29,7 +29,10 @@ class ShortTermMemory(Memory):
|
||||
_memory_provider: str | None = PrivateAttr()
|
||||
|
||||
def __init__(self, crew=None, embedder_config=None, storage=None, path=None):
|
||||
memory_provider = embedder_config.get("provider") if embedder_config else None
|
||||
memory_provider = None
|
||||
if embedder_config and isinstance(embedder_config, dict):
|
||||
memory_provider = embedder_config.get("provider")
|
||||
|
||||
if memory_provider == "mem0":
|
||||
try:
|
||||
from crewai.memory.storage.mem0_storage import Mem0Storage
|
||||
@@ -37,7 +40,11 @@ class ShortTermMemory(Memory):
|
||||
raise ImportError(
|
||||
"Mem0 is not installed. Please install it with `pip install mem0ai`."
|
||||
) from e
|
||||
config = embedder_config.get("config") if embedder_config else None
|
||||
config = (
|
||||
embedder_config.get("config")
|
||||
if embedder_config and isinstance(embedder_config, dict)
|
||||
else None
|
||||
)
|
||||
storage = Mem0Storage(type="short_term", crew=crew, config=config)
|
||||
else:
|
||||
storage = (
|
||||
|
||||
@@ -10,7 +10,6 @@ from chromadb.api.models.AsyncCollection import AsyncCollection
|
||||
from chromadb.api.models.Collection import Collection
|
||||
from chromadb.api.types import (
|
||||
Include,
|
||||
IncludeEnum,
|
||||
QueryResult,
|
||||
)
|
||||
|
||||
@@ -142,9 +141,12 @@ def _extract_search_params(
|
||||
score_threshold=kwargs.get("score_threshold"),
|
||||
where=kwargs.get("where"),
|
||||
where_document=kwargs.get("where_document"),
|
||||
include=kwargs.get(
|
||||
"include",
|
||||
[IncludeEnum.metadatas, IncludeEnum.documents, IncludeEnum.distances],
|
||||
include=cast(
|
||||
Include,
|
||||
kwargs.get(
|
||||
"include",
|
||||
["metadatas", "documents", "distances"],
|
||||
),
|
||||
),
|
||||
)
|
||||
|
||||
@@ -193,7 +195,7 @@ def _convert_chromadb_results_to_search_results(
|
||||
"""
|
||||
search_results: list[SearchResult] = []
|
||||
|
||||
include_strings = [item.value for item in include] if include else []
|
||||
include_strings = list(include) if include else []
|
||||
|
||||
ids = results["ids"][0] if results.get("ids") else []
|
||||
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
"""Amazon Bedrock embeddings provider."""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from chromadb.utils.embedding_functions.amazon_bedrock_embedding_function import (
|
||||
AmazonBedrockEmbeddingFunction,
|
||||
)
|
||||
@@ -7,15 +9,8 @@ from pydantic import Field
|
||||
|
||||
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
|
||||
|
||||
try:
|
||||
from boto3.session import Session # type: ignore[import-untyped]
|
||||
except ImportError as exc:
|
||||
raise ImportError(
|
||||
"boto3 is required for amazon-bedrock embeddings. Install it with: uv add boto3"
|
||||
) from exc
|
||||
|
||||
|
||||
def create_aws_session() -> Session:
|
||||
def create_aws_session() -> Any:
|
||||
"""Create an AWS session for Bedrock.
|
||||
|
||||
Returns:
|
||||
@@ -53,6 +48,6 @@ class BedrockProvider(BaseEmbeddingsProvider[AmazonBedrockEmbeddingFunction]):
|
||||
description="Model name to use for embeddings",
|
||||
validation_alias="BEDROCK_MODEL_NAME",
|
||||
)
|
||||
session: Session = Field(
|
||||
session: Any = Field(
|
||||
default_factory=create_aws_session, description="AWS session object"
|
||||
)
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
"""Type definitions for AWS embedding providers."""
|
||||
|
||||
from typing import Annotated, Any, Literal, TypedDict
|
||||
from typing import Annotated, Any, Literal
|
||||
|
||||
from typing_extensions import Required, TypedDict
|
||||
|
||||
|
||||
class BedrockProviderConfig(TypedDict, total=False):
|
||||
@@ -10,8 +12,8 @@ class BedrockProviderConfig(TypedDict, total=False):
|
||||
session: Any
|
||||
|
||||
|
||||
class BedrockProviderSpec(TypedDict):
|
||||
class BedrockProviderSpec(TypedDict, total=False):
|
||||
"""Bedrock provider specification."""
|
||||
|
||||
provider: Literal["amazon-bedrock"]
|
||||
provider: Required[Literal["amazon-bedrock"]]
|
||||
config: BedrockProviderConfig
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
"""Type definitions for Cohere embedding providers."""
|
||||
|
||||
from typing import Annotated, Literal, TypedDict
|
||||
from typing import Annotated, Literal
|
||||
|
||||
from typing_extensions import Required, TypedDict
|
||||
|
||||
|
||||
class CohereProviderConfig(TypedDict, total=False):
|
||||
@@ -10,8 +12,8 @@ class CohereProviderConfig(TypedDict, total=False):
|
||||
model_name: Annotated[str, "large"]
|
||||
|
||||
|
||||
class CohereProviderSpec(TypedDict):
|
||||
class CohereProviderSpec(TypedDict, total=False):
|
||||
"""Cohere provider specification."""
|
||||
|
||||
provider: Literal["cohere"]
|
||||
provider: Required[Literal["cohere"]]
|
||||
config: CohereProviderConfig
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
"""Type definitions for custom embedding providers."""
|
||||
|
||||
from typing import Literal, TypedDict
|
||||
from typing import Literal
|
||||
|
||||
from chromadb.api.types import EmbeddingFunction
|
||||
from typing_extensions import Required, TypedDict
|
||||
|
||||
|
||||
class CustomProviderConfig(TypedDict, total=False):
|
||||
@@ -11,8 +12,8 @@ class CustomProviderConfig(TypedDict, total=False):
|
||||
embedding_callable: type[EmbeddingFunction]
|
||||
|
||||
|
||||
class CustomProviderSpec(TypedDict):
|
||||
class CustomProviderSpec(TypedDict, total=False):
|
||||
"""Custom provider specification."""
|
||||
|
||||
provider: Literal["custom"]
|
||||
provider: Required[Literal["custom"]]
|
||||
config: CustomProviderConfig
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
"""Type definitions for Google embedding providers."""
|
||||
|
||||
from typing import Annotated, Literal, TypedDict
|
||||
from typing import Annotated, Literal
|
||||
|
||||
from typing_extensions import Required, TypedDict
|
||||
|
||||
|
||||
class GenerativeAiProviderConfig(TypedDict, total=False):
|
||||
@@ -27,8 +29,8 @@ class VertexAIProviderConfig(TypedDict, total=False):
|
||||
region: Annotated[str, "us-central1"]
|
||||
|
||||
|
||||
class VertexAIProviderSpec(TypedDict):
|
||||
class VertexAIProviderSpec(TypedDict, total=False):
|
||||
"""Vertex AI provider specification."""
|
||||
|
||||
provider: Literal["google-vertex"]
|
||||
provider: Required[Literal["google-vertex"]]
|
||||
config: VertexAIProviderConfig
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
"""Type definitions for HuggingFace embedding providers."""
|
||||
|
||||
from typing import Literal, TypedDict
|
||||
from typing import Literal
|
||||
|
||||
from typing_extensions import Required, TypedDict
|
||||
|
||||
|
||||
class HuggingFaceProviderConfig(TypedDict, total=False):
|
||||
@@ -9,8 +11,8 @@ class HuggingFaceProviderConfig(TypedDict, total=False):
|
||||
url: str
|
||||
|
||||
|
||||
class HuggingFaceProviderSpec(TypedDict):
|
||||
class HuggingFaceProviderSpec(TypedDict, total=False):
|
||||
"""HuggingFace provider specification."""
|
||||
|
||||
provider: Literal["huggingface"]
|
||||
provider: Required[Literal["huggingface"]]
|
||||
config: HuggingFaceProviderConfig
|
||||
|
||||
@@ -2,11 +2,6 @@
|
||||
|
||||
from typing import cast
|
||||
|
||||
import ibm_watsonx_ai.foundation_models as watson_models # type: ignore[import-not-found, import-untyped]
|
||||
from ibm_watsonx_ai import Credentials # type: ignore[import-not-found, import-untyped]
|
||||
from ibm_watsonx_ai.metanames import ( # type: ignore[import-not-found, import-untyped]
|
||||
EmbedTextParamsMetaNames as EmbedParams,
|
||||
)
|
||||
from typing_extensions import Unpack
|
||||
|
||||
from crewai.rag.core.base_embeddings_callable import EmbeddingFunction
|
||||
@@ -34,6 +29,21 @@ class WatsonEmbeddingFunction(EmbeddingFunction[Documents]):
|
||||
Returns:
|
||||
List of embedding vectors.
|
||||
"""
|
||||
try:
|
||||
import ibm_watsonx_ai.foundation_models as watson_models # type: ignore[import-not-found, import-untyped]
|
||||
from ibm_watsonx_ai import (
|
||||
Credentials, # type: ignore[import-not-found, import-untyped]
|
||||
)
|
||||
from ibm_watsonx_ai.metanames import ( # type: ignore[import-not-found, import-untyped]
|
||||
EmbedTextParamsMetaNames as EmbedParams,
|
||||
)
|
||||
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"ibm-watsonx-ai is required for watson embeddings. "
|
||||
"Install it with: uv add ibm-watsonx-ai"
|
||||
) from e
|
||||
|
||||
if isinstance(input, str):
|
||||
input = [input]
|
||||
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
"""Type definitions for IBM Watson embedding providers."""
|
||||
|
||||
from typing import Annotated, Any, Literal, TypedDict
|
||||
from typing import Annotated, Any, Literal
|
||||
|
||||
from typing_extensions import Required, TypedDict
|
||||
|
||||
|
||||
class WatsonProviderConfig(TypedDict, total=False):
|
||||
@@ -35,8 +37,8 @@ class WatsonProviderConfig(TypedDict, total=False):
|
||||
proxies: dict
|
||||
|
||||
|
||||
class WatsonProviderSpec(TypedDict):
|
||||
class WatsonProviderSpec(TypedDict, total=False):
|
||||
"""Watson provider specification."""
|
||||
|
||||
provider: Literal["watson"]
|
||||
provider: Required[Literal["watson"]]
|
||||
config: WatsonProviderConfig
|
||||
|
||||
@@ -1,9 +1,7 @@
|
||||
"""IBM Watson embeddings provider."""
|
||||
|
||||
from ibm_watsonx_ai import ( # type: ignore[import-not-found,import-untyped]
|
||||
APIClient,
|
||||
Credentials,
|
||||
)
|
||||
from typing import Any
|
||||
|
||||
from pydantic import Field, model_validator
|
||||
from typing_extensions import Self
|
||||
|
||||
@@ -28,9 +26,7 @@ class WatsonProvider(BaseEmbeddingsProvider[WatsonEmbeddingFunction]):
|
||||
params: dict[str, str | dict[str, str]] | None = Field(
|
||||
default=None, description="Additional parameters"
|
||||
)
|
||||
credentials: Credentials | None = Field(
|
||||
default=None, description="Watson credentials"
|
||||
)
|
||||
credentials: Any | None = Field(default=None, description="Watson credentials")
|
||||
project_id: str | None = Field(
|
||||
default=None,
|
||||
description="Watson project ID",
|
||||
@@ -39,7 +35,7 @@ class WatsonProvider(BaseEmbeddingsProvider[WatsonEmbeddingFunction]):
|
||||
space_id: str | None = Field(
|
||||
default=None, description="Watson space ID", validation_alias="WATSON_SPACE_ID"
|
||||
)
|
||||
api_client: APIClient | None = Field(default=None, description="Watson API client")
|
||||
api_client: Any | None = Field(default=None, description="Watson API client")
|
||||
verify: bool | str | None = Field(
|
||||
default=None, description="SSL verification", validation_alias="WATSON_VERIFY"
|
||||
)
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
"""Type definitions for Instructor embedding providers."""
|
||||
|
||||
from typing import Annotated, Literal, TypedDict
|
||||
from typing import Annotated, Literal
|
||||
|
||||
from typing_extensions import Required, TypedDict
|
||||
|
||||
|
||||
class InstructorProviderConfig(TypedDict, total=False):
|
||||
@@ -11,8 +13,8 @@ class InstructorProviderConfig(TypedDict, total=False):
|
||||
instruction: str
|
||||
|
||||
|
||||
class InstructorProviderSpec(TypedDict):
|
||||
class InstructorProviderSpec(TypedDict, total=False):
|
||||
"""Instructor provider specification."""
|
||||
|
||||
provider: Literal["instructor"]
|
||||
provider: Required[Literal["instructor"]]
|
||||
config: InstructorProviderConfig
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
"""Type definitions for Jina embedding providers."""
|
||||
|
||||
from typing import Annotated, Literal, TypedDict
|
||||
from typing import Annotated, Literal
|
||||
|
||||
from typing_extensions import Required, TypedDict
|
||||
|
||||
|
||||
class JinaProviderConfig(TypedDict, total=False):
|
||||
@@ -10,8 +12,8 @@ class JinaProviderConfig(TypedDict, total=False):
|
||||
model_name: Annotated[str, "jina-embeddings-v2-base-en"]
|
||||
|
||||
|
||||
class JinaProviderSpec(TypedDict):
|
||||
class JinaProviderSpec(TypedDict, total=False):
|
||||
"""Jina provider specification."""
|
||||
|
||||
provider: Literal["jina"]
|
||||
provider: Required[Literal["jina"]]
|
||||
config: JinaProviderConfig
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
"""Type definitions for Microsoft Azure embedding providers."""
|
||||
|
||||
from typing import Annotated, Any, Literal, TypedDict
|
||||
from typing import Annotated, Any, Literal
|
||||
|
||||
from typing_extensions import Required, TypedDict
|
||||
|
||||
|
||||
class AzureProviderConfig(TypedDict, total=False):
|
||||
@@ -17,8 +19,8 @@ class AzureProviderConfig(TypedDict, total=False):
|
||||
organization_id: str
|
||||
|
||||
|
||||
class AzureProviderSpec(TypedDict):
|
||||
class AzureProviderSpec(TypedDict, total=False):
|
||||
"""Azure provider specification."""
|
||||
|
||||
provider: Literal["azure"]
|
||||
provider: Required[Literal["azure"]]
|
||||
config: AzureProviderConfig
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
"""Type definitions for Ollama embedding providers."""
|
||||
|
||||
from typing import Annotated, Literal, TypedDict
|
||||
from typing import Annotated, Literal
|
||||
|
||||
from typing_extensions import Required, TypedDict
|
||||
|
||||
|
||||
class OllamaProviderConfig(TypedDict, total=False):
|
||||
@@ -10,8 +12,8 @@ class OllamaProviderConfig(TypedDict, total=False):
|
||||
model_name: str
|
||||
|
||||
|
||||
class OllamaProviderSpec(TypedDict):
|
||||
class OllamaProviderSpec(TypedDict, total=False):
|
||||
"""Ollama provider specification."""
|
||||
|
||||
provider: Literal["ollama"]
|
||||
provider: Required[Literal["ollama"]]
|
||||
config: OllamaProviderConfig
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
"""Type definitions for ONNX embedding providers."""
|
||||
|
||||
from typing import Literal, TypedDict
|
||||
from typing import Literal
|
||||
|
||||
from typing_extensions import Required, TypedDict
|
||||
|
||||
|
||||
class ONNXProviderConfig(TypedDict, total=False):
|
||||
@@ -9,8 +11,8 @@ class ONNXProviderConfig(TypedDict, total=False):
|
||||
preferred_providers: list[str]
|
||||
|
||||
|
||||
class ONNXProviderSpec(TypedDict):
|
||||
class ONNXProviderSpec(TypedDict, total=False):
|
||||
"""ONNX provider specification."""
|
||||
|
||||
provider: Literal["onnx"]
|
||||
provider: Required[Literal["onnx"]]
|
||||
config: ONNXProviderConfig
|
||||
|
||||
@@ -17,8 +17,8 @@ class OpenAIProvider(BaseEmbeddingsProvider[OpenAIEmbeddingFunction]):
|
||||
default=OpenAIEmbeddingFunction,
|
||||
description="OpenAI embedding function class",
|
||||
)
|
||||
api_key: str = Field(
|
||||
description="OpenAI API key", validation_alias="OPENAI_API_KEY"
|
||||
api_key: str | None = Field(
|
||||
default=None, description="OpenAI API key", validation_alias="OPENAI_API_KEY"
|
||||
)
|
||||
model_name: str = Field(
|
||||
default="text-embedding-ada-002",
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
"""Type definitions for OpenAI embedding providers."""
|
||||
|
||||
from typing import Annotated, Any, Literal, TypedDict
|
||||
from typing import Annotated, Any, Literal
|
||||
|
||||
from typing_extensions import Required, TypedDict
|
||||
|
||||
|
||||
class OpenAIProviderConfig(TypedDict, total=False):
|
||||
@@ -17,8 +19,8 @@ class OpenAIProviderConfig(TypedDict, total=False):
|
||||
organization_id: str
|
||||
|
||||
|
||||
class OpenAIProviderSpec(TypedDict):
|
||||
class OpenAIProviderSpec(TypedDict, total=False):
|
||||
"""OpenAI provider specification."""
|
||||
|
||||
provider: Literal["openai"]
|
||||
provider: Required[Literal["openai"]]
|
||||
config: OpenAIProviderConfig
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
"""Type definitions for OpenCLIP embedding providers."""
|
||||
|
||||
from typing import Annotated, Literal, TypedDict
|
||||
from typing import Annotated, Literal
|
||||
|
||||
from typing_extensions import Required, TypedDict
|
||||
|
||||
|
||||
class OpenCLIPProviderConfig(TypedDict, total=False):
|
||||
@@ -14,5 +16,5 @@ class OpenCLIPProviderConfig(TypedDict, total=False):
|
||||
class OpenCLIPProviderSpec(TypedDict):
|
||||
"""OpenCLIP provider specification."""
|
||||
|
||||
provider: Literal["openclip"]
|
||||
provider: Required[Literal["openclip"]]
|
||||
config: OpenCLIPProviderConfig
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
"""Type definitions for Roboflow embedding providers."""
|
||||
|
||||
from typing import Annotated, Literal, TypedDict
|
||||
from typing import Annotated, Literal
|
||||
|
||||
from typing_extensions import Required, TypedDict
|
||||
|
||||
|
||||
class RoboflowProviderConfig(TypedDict, total=False):
|
||||
@@ -13,5 +15,5 @@ class RoboflowProviderConfig(TypedDict, total=False):
|
||||
class RoboflowProviderSpec(TypedDict):
|
||||
"""Roboflow provider specification."""
|
||||
|
||||
provider: Literal["roboflow"]
|
||||
provider: Required[Literal["roboflow"]]
|
||||
config: RoboflowProviderConfig
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
"""Type definitions for SentenceTransformer embedding providers."""
|
||||
|
||||
from typing import Annotated, Literal, TypedDict
|
||||
from typing import Annotated, Literal
|
||||
|
||||
from typing_extensions import Required, TypedDict
|
||||
|
||||
|
||||
class SentenceTransformerProviderConfig(TypedDict, total=False):
|
||||
@@ -14,5 +16,5 @@ class SentenceTransformerProviderConfig(TypedDict, total=False):
|
||||
class SentenceTransformerProviderSpec(TypedDict):
|
||||
"""SentenceTransformer provider specification."""
|
||||
|
||||
provider: Literal["sentence-transformer"]
|
||||
provider: Required[Literal["sentence-transformer"]]
|
||||
config: SentenceTransformerProviderConfig
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
"""Type definitions for Text2Vec embedding providers."""
|
||||
|
||||
from typing import Annotated, Literal, TypedDict
|
||||
from typing import Annotated, Literal
|
||||
|
||||
from typing_extensions import Required, TypedDict
|
||||
|
||||
|
||||
class Text2VecProviderConfig(TypedDict, total=False):
|
||||
@@ -12,5 +14,5 @@ class Text2VecProviderConfig(TypedDict, total=False):
|
||||
class Text2VecProviderSpec(TypedDict):
|
||||
"""Text2Vec provider specification."""
|
||||
|
||||
provider: Literal["text2vec"]
|
||||
provider: Required[Literal["text2vec"]]
|
||||
config: Text2VecProviderConfig
|
||||
|
||||
@@ -2,7 +2,6 @@
|
||||
|
||||
from typing import cast
|
||||
|
||||
import voyageai
|
||||
from typing_extensions import Unpack
|
||||
|
||||
from crewai.rag.core.base_embeddings_callable import EmbeddingFunction
|
||||
@@ -19,6 +18,14 @@ class VoyageAIEmbeddingFunction(EmbeddingFunction[Documents]):
|
||||
Args:
|
||||
**kwargs: Configuration parameters for VoyageAI.
|
||||
"""
|
||||
try:
|
||||
import voyageai # type: ignore[import-not-found]
|
||||
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"voyageai is required for voyageai embeddings. "
|
||||
"Install it with: uv add voyageai"
|
||||
) from e
|
||||
self._config = kwargs
|
||||
self._client = voyageai.Client(
|
||||
api_key=kwargs["api_key"],
|
||||
@@ -35,6 +42,7 @@ class VoyageAIEmbeddingFunction(EmbeddingFunction[Documents]):
|
||||
Returns:
|
||||
List of embedding vectors.
|
||||
"""
|
||||
|
||||
if isinstance(input, str):
|
||||
input = [input]
|
||||
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
"""Type definitions for VoyageAI embedding providers."""
|
||||
|
||||
from typing import Annotated, Literal, TypedDict
|
||||
from typing import Annotated, Literal
|
||||
|
||||
from typing_extensions import Required, TypedDict
|
||||
|
||||
|
||||
class VoyageAIProviderConfig(TypedDict, total=False):
|
||||
@@ -19,5 +21,5 @@ class VoyageAIProviderConfig(TypedDict, total=False):
|
||||
class VoyageAIProviderSpec(TypedDict):
|
||||
"""VoyageAI provider specification."""
|
||||
|
||||
provider: Literal["voyageai"]
|
||||
provider: Required[Literal["voyageai"]]
|
||||
config: VoyageAIProviderConfig
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
"""Type definitions for the embeddings module."""
|
||||
|
||||
from typing import Literal
|
||||
from typing import Literal, TypeAlias
|
||||
|
||||
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
|
||||
from crewai.rag.embeddings.providers.aws.types import BedrockProviderSpec
|
||||
from crewai.rag.embeddings.providers.cohere.types import CohereProviderSpec
|
||||
from crewai.rag.embeddings.providers.custom.types import CustomProviderSpec
|
||||
@@ -66,3 +67,7 @@ AllowedEmbeddingProviders = Literal[
|
||||
"voyageai",
|
||||
"watson",
|
||||
]
|
||||
|
||||
EmbedderConfig: TypeAlias = (
|
||||
ProviderSpec | BaseEmbeddingsProvider | type[BaseEmbeddingsProvider]
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user