Compare commits

...

6 Commits

Author SHA1 Message Date
Brandon Hancock
21f4b60754 Include embedding type fix 2025-03-20 08:55:30 -04:00
Brandon Hancock (bhancock_ai)
216ff4aa6f Merge branch 'main' into embedding-config-typing 2025-03-20 08:47:33 -04:00
Nick Fujita
6f849c0e6d 'added docs for config based on agent review' 2025-02-20 18:11:16 +09:00
Nick Fujita
276f661e6c 'add specific providers to provider type' 2025-02-20 18:02:36 +09:00
Nick Fujita
8f99caf61b 'type cleanup' 2025-02-20 17:58:46 +09:00
Nick Fujita
f4642f11cc 'add typings to embedding configurator input arg' 2025-02-20 17:52:13 +09:00
6 changed files with 139 additions and 56 deletions

View File

@@ -20,6 +20,7 @@ from crewai.tools.agent_tools.agent_tools import AgentTools
from crewai.utilities import Converter, Prompts from crewai.utilities import Converter, Prompts
from crewai.utilities.constants import TRAINED_AGENTS_DATA_FILE, TRAINING_DATA_FILE from crewai.utilities.constants import TRAINED_AGENTS_DATA_FILE, TRAINING_DATA_FILE
from crewai.utilities.converter import generate_model_description from crewai.utilities.converter import generate_model_description
from crewai.utilities.embedding_configurator import EmbeddingConfig
from crewai.utilities.events.agent_events import ( from crewai.utilities.events.agent_events import (
AgentExecutionCompletedEvent, AgentExecutionCompletedEvent,
AgentExecutionErrorEvent, AgentExecutionErrorEvent,
@@ -108,7 +109,7 @@ class Agent(BaseAgent):
default="safe", default="safe",
description="Mode for code execution: 'safe' (using Docker) or 'unsafe' (direct execution).", description="Mode for code execution: 'safe' (using Docker) or 'unsafe' (direct execution).",
) )
embedder: Optional[Dict[str, Any]] = Field( embedder: Optional[EmbeddingConfig] = Field(
default=None, default=None,
description="Embedder configuration for the agent.", description="Embedder configuration for the agent.",
) )
@@ -134,7 +135,7 @@ class Agent(BaseAgent):
self.cache_handler = CacheHandler() self.cache_handler = CacheHandler()
self.set_cache_handler(self.cache_handler) self.set_cache_handler(self.cache_handler)
def set_knowledge(self, crew_embedder: Optional[Dict[str, Any]] = None): def set_knowledge(self, crew_embedder: Optional[EmbeddingConfig] = None):
try: try:
if self.embedder is None and crew_embedder: if self.embedder is None and crew_embedder:
self.embedder = crew_embedder self.embedder = crew_embedder

View File

@@ -25,6 +25,7 @@ from crewai.tools.base_tool import BaseTool, Tool
from crewai.utilities import I18N, Logger, RPMController from crewai.utilities import I18N, Logger, RPMController
from crewai.utilities.config import process_config from crewai.utilities.config import process_config
from crewai.utilities.converter import Converter from crewai.utilities.converter import Converter
from crewai.utilities.embedding_configurator import EmbeddingConfig
T = TypeVar("T", bound="BaseAgent") T = TypeVar("T", bound="BaseAgent")
@@ -362,5 +363,5 @@ class BaseAgent(ABC, BaseModel):
self._rpm_controller = rpm_controller self._rpm_controller = rpm_controller
self.create_agent_executor() self.create_agent_executor()
def set_knowledge(self, crew_embedder: Optional[Dict[str, Any]] = None): def set_knowledge(self, crew_embedder: Optional[EmbeddingConfig] = None):
pass pass

View File

@@ -41,6 +41,7 @@ from crewai.tools.base_tool import Tool
from crewai.types.usage_metrics import UsageMetrics from crewai.types.usage_metrics import UsageMetrics
from crewai.utilities import I18N, FileHandler, Logger, RPMController from crewai.utilities import I18N, FileHandler, Logger, RPMController
from crewai.utilities.constants import TRAINING_DATA_FILE from crewai.utilities.constants import TRAINING_DATA_FILE
from crewai.utilities.embedding_configurator import EmbeddingConfig
from crewai.utilities.evaluators.crew_evaluator_handler import CrewEvaluator from crewai.utilities.evaluators.crew_evaluator_handler import CrewEvaluator
from crewai.utilities.evaluators.task_evaluator import TaskEvaluator from crewai.utilities.evaluators.task_evaluator import TaskEvaluator
from crewai.utilities.events.crew_events import ( from crewai.utilities.events.crew_events import (
@@ -145,7 +146,7 @@ class Crew(BaseModel):
default=None, default=None,
description="An instance of the UserMemory to be used by the Crew to store/fetch memories of a specific user.", description="An instance of the UserMemory to be used by the Crew to store/fetch memories of a specific user.",
) )
embedder: Optional[dict] = Field( embedder: Optional[EmbeddingConfig] = Field(
default=None, default=None,
description="Configuration for the embedder to be used for the crew.", description="Configuration for the embedder to be used for the crew.",
) )

View File

@@ -5,6 +5,7 @@ from pydantic import BaseModel, ConfigDict, Field
from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource
from crewai.knowledge.storage.knowledge_storage import KnowledgeStorage from crewai.knowledge.storage.knowledge_storage import KnowledgeStorage
from crewai.utilities.embedding_configurator import EmbeddingConfig
os.environ["TOKENIZERS_PARALLELISM"] = "false" # removes logging from fastembed os.environ["TOKENIZERS_PARALLELISM"] = "false" # removes logging from fastembed
@@ -21,14 +22,14 @@ class Knowledge(BaseModel):
sources: List[BaseKnowledgeSource] = Field(default_factory=list) sources: List[BaseKnowledgeSource] = Field(default_factory=list)
model_config = ConfigDict(arbitrary_types_allowed=True) model_config = ConfigDict(arbitrary_types_allowed=True)
storage: Optional[KnowledgeStorage] = Field(default=None) storage: Optional[KnowledgeStorage] = Field(default=None)
embedder: Optional[Dict[str, Any]] = None embedder: Optional[EmbeddingConfig] = None
collection_name: Optional[str] = None collection_name: Optional[str] = None
def __init__( def __init__(
self, self,
collection_name: str, collection_name: str,
sources: List[BaseKnowledgeSource], sources: List[BaseKnowledgeSource],
embedder: Optional[Dict[str, Any]] = None, embedder: Optional[EmbeddingConfig] = None,
storage: Optional[KnowledgeStorage] = None, storage: Optional[KnowledgeStorage] = None,
**data, **data,
): ):

View File

@@ -15,6 +15,7 @@ from chromadb.config import Settings
from crewai.knowledge.storage.base_knowledge_storage import BaseKnowledgeStorage from crewai.knowledge.storage.base_knowledge_storage import BaseKnowledgeStorage
from crewai.utilities import EmbeddingConfigurator from crewai.utilities import EmbeddingConfigurator
from crewai.utilities.constants import KNOWLEDGE_DIRECTORY from crewai.utilities.constants import KNOWLEDGE_DIRECTORY
from crewai.utilities.embedding_configurator import EmbeddingConfig
from crewai.utilities.logger import Logger from crewai.utilities.logger import Logger
from crewai.utilities.paths import db_storage_path from crewai.utilities.paths import db_storage_path
@@ -48,7 +49,7 @@ class KnowledgeStorage(BaseKnowledgeStorage):
def __init__( def __init__(
self, self,
embedder: Optional[Dict[str, Any]] = None, embedder: Optional[EmbeddingConfig] = None,
collection_name: Optional[str] = None, collection_name: Optional[str] = None,
): ):
self.collection_name = collection_name self.collection_name = collection_name
@@ -187,7 +188,7 @@ class KnowledgeStorage(BaseKnowledgeStorage):
api_key=os.getenv("OPENAI_API_KEY"), model_name="text-embedding-3-small" api_key=os.getenv("OPENAI_API_KEY"), model_name="text-embedding-3-small"
) )
def _set_embedder_config(self, embedder: Optional[Dict[str, Any]] = None) -> None: def _set_embedder_config(self, embedder: Optional[EmbeddingConfig] = None) -> None:
"""Set the embedding configuration for the knowledge storage. """Set the embedding configuration for the knowledge storage.
Args: Args:

View File

@@ -1,8 +1,84 @@
import os import os
from typing import Any, Dict, Optional, cast from typing import Any, Callable, Literal, cast
from chromadb import Documents, EmbeddingFunction, Embeddings from chromadb import Documents, EmbeddingFunction, Embeddings
from chromadb.api.types import validate_embedding_function from chromadb.api.types import validate_embedding_function
from pydantic import BaseModel
class EmbeddingProviderConfig(BaseModel):
"""Configuration model for embedding providers.
Attributes:
# Core Model Configuration
model (str | None): The model identifier for embeddings, used across multiple providers
like OpenAI, Azure, Watson, etc.
embedder (str | Callable | None): Custom embedding function or callable for custom
embedding implementations.
# API Authentication & Configuration
api_key (str | None): Authentication key for various providers (OpenAI, VertexAI,
Google, Cohere, VoyageAI, Watson).
api_base (str | None): Base API URL override for OpenAI and Azure services.
api_type (str | None): API type specification, particularly used for Azure configuration.
api_version (str | None): API version for OpenAI and Azure services.
api_url (str | None): API endpoint URL, used by HuggingFace and Watson services.
url (str | None): Base URL for the embedding service, primarily used for Ollama and
HuggingFace endpoints.
# Service-Specific Configuration
project_id (str | None): Project identifier used by VertexAI and Watson services.
organization_id (str | None): Organization identifier for OpenAI and Azure services.
deployment_id (str | None): Deployment identifier for OpenAI and Azure services.
region (str | None): Geographic region for VertexAI services.
session (str | None): Session configuration for Amazon Bedrock embeddings.
# Request Configuration
task_type (str | None): Specifies the task type for Google Generative AI embeddings.
default_headers (str | None): Custom headers for OpenAI and Azure API requests.
dimensions (str | None): Output dimensions specification for OpenAI and Azure embeddings.
"""
# Core Model Configuration
model: str | None = None
embedder: str | Callable | None = None
# API Authentication & Configuration
api_key: str | None = None
api_base: str | None = None
api_type: str | None = None
api_version: str | None = None
api_url: str | None = None
url: str | None = None
# Service-Specific Configuration
project_id: str | None = None
organization_id: str | None = None
deployment_id: str | None = None
region: str | None = None
session: str | None = None
# Request Configuration
task_type: str | None = None
default_headers: str | None = None
dimensions: str | None = None
class EmbeddingConfig(BaseModel):
provider: Literal[
"openai",
"azure",
"ollama",
"vertexai",
"google",
"cohere",
"voyageai",
"bedrock",
"huggingface",
"watson",
"custom",
]
config: EmbeddingProviderConfig | None = None
class EmbeddingConfigurator: class EmbeddingConfigurator:
@@ -23,15 +99,19 @@ class EmbeddingConfigurator:
def configure_embedder( def configure_embedder(
self, self,
embedder_config: Optional[Dict[str, Any]] = None, embedder_config: EmbeddingConfig | None = None,
) -> EmbeddingFunction: ) -> EmbeddingFunction:
"""Configures and returns an embedding function based on the provided config.""" """Configures and returns an embedding function based on the provided config."""
if embedder_config is None: if embedder_config is None:
return self._create_default_embedding_function() return self._create_default_embedding_function()
provider = embedder_config.get("provider") provider = embedder_config.provider
config = embedder_config.get("config", {}) config = (
model_name = config.get("model") if provider != "custom" else None embedder_config.config
if embedder_config.config
else EmbeddingProviderConfig()
)
model_name = config.model if provider != "custom" else None
if provider not in self.embedding_functions: if provider not in self.embedding_functions:
raise Exception( raise Exception(
@@ -56,123 +136,123 @@ class EmbeddingConfigurator:
) )
@staticmethod @staticmethod
def _configure_openai(config, model_name): def _configure_openai(config: EmbeddingProviderConfig, model_name: str):
from chromadb.utils.embedding_functions.openai_embedding_function import ( from chromadb.utils.embedding_functions.openai_embedding_function import (
OpenAIEmbeddingFunction, OpenAIEmbeddingFunction,
) )
return OpenAIEmbeddingFunction( return OpenAIEmbeddingFunction(
api_key=config.get("api_key") or os.getenv("OPENAI_API_KEY"), api_key=config.api_key or os.getenv("OPENAI_API_KEY"),
model_name=model_name, model_name=model_name,
api_base=config.get("api_base", None), api_base=config.api_base,
api_type=config.get("api_type", None), api_type=config.api_type,
api_version=config.get("api_version", None), api_version=config.api_version,
default_headers=config.get("default_headers", None), default_headers=config.default_headers,
dimensions=config.get("dimensions", None), dimensions=config.dimensions,
deployment_id=config.get("deployment_id", None), deployment_id=config.deployment_id,
organization_id=config.get("organization_id", None), organization_id=config.organization_id,
) )
@staticmethod @staticmethod
def _configure_azure(config, model_name): def _configure_azure(config: EmbeddingProviderConfig, model_name: str):
from chromadb.utils.embedding_functions.openai_embedding_function import ( from chromadb.utils.embedding_functions.openai_embedding_function import (
OpenAIEmbeddingFunction, OpenAIEmbeddingFunction,
) )
return OpenAIEmbeddingFunction( return OpenAIEmbeddingFunction(
api_key=config.get("api_key"), api_key=config.api_key,
api_base=config.get("api_base"), api_base=config.api_base,
api_type=config.get("api_type", "azure"), api_type=config.api_type if config.api_type else "azure",
api_version=config.get("api_version"), api_version=config.api_version,
model_name=model_name, model_name=model_name,
default_headers=config.get("default_headers"), default_headers=config.default_headers,
dimensions=config.get("dimensions"), dimensions=config.dimensions,
deployment_id=config.get("deployment_id"), deployment_id=config.deployment_id,
organization_id=config.get("organization_id"), organization_id=config.organization_id,
) )
@staticmethod @staticmethod
def _configure_ollama(config, model_name): def _configure_ollama(config: EmbeddingProviderConfig, model_name: str):
from chromadb.utils.embedding_functions.ollama_embedding_function import ( from chromadb.utils.embedding_functions.ollama_embedding_function import (
OllamaEmbeddingFunction, OllamaEmbeddingFunction,
) )
return OllamaEmbeddingFunction( return OllamaEmbeddingFunction(
url=config.get("url", "http://localhost:11434/api/embeddings"), url=config.url if config.url else "http://localhost:11434/api/embeddings",
model_name=model_name, model_name=model_name,
) )
@staticmethod @staticmethod
def _configure_vertexai(config, model_name): def _configure_vertexai(config: EmbeddingProviderConfig, model_name: str):
from chromadb.utils.embedding_functions.google_embedding_function import ( from chromadb.utils.embedding_functions.google_embedding_function import (
GoogleVertexEmbeddingFunction, GoogleVertexEmbeddingFunction,
) )
return GoogleVertexEmbeddingFunction( return GoogleVertexEmbeddingFunction(
model_name=model_name, model_name=model_name,
api_key=config.get("api_key"), api_key=config.api_key,
project_id=config.get("project_id"), project_id=config.project_id,
region=config.get("region"), region=config.region,
) )
@staticmethod @staticmethod
def _configure_google(config, model_name): def _configure_google(config: EmbeddingProviderConfig, model_name: str):
from chromadb.utils.embedding_functions.google_embedding_function import ( from chromadb.utils.embedding_functions.google_embedding_function import (
GoogleGenerativeAiEmbeddingFunction, GoogleGenerativeAiEmbeddingFunction,
) )
return GoogleGenerativeAiEmbeddingFunction( return GoogleGenerativeAiEmbeddingFunction(
model_name=model_name, model_name=model_name,
api_key=config.get("api_key"), api_key=config.api_key,
task_type=config.get("task_type"), task_type=config.task_type,
) )
@staticmethod @staticmethod
def _configure_cohere(config, model_name): def _configure_cohere(config: EmbeddingProviderConfig, model_name: str):
from chromadb.utils.embedding_functions.cohere_embedding_function import ( from chromadb.utils.embedding_functions.cohere_embedding_function import (
CohereEmbeddingFunction, CohereEmbeddingFunction,
) )
return CohereEmbeddingFunction( return CohereEmbeddingFunction(
model_name=model_name, model_name=model_name,
api_key=config.get("api_key"), api_key=config.api_key,
) )
@staticmethod @staticmethod
def _configure_voyageai(config, model_name): def _configure_voyageai(config: EmbeddingProviderConfig, model_name: str):
from chromadb.utils.embedding_functions.voyageai_embedding_function import ( from chromadb.utils.embedding_functions.voyageai_embedding_function import (
VoyageAIEmbeddingFunction, VoyageAIEmbeddingFunction,
) )
return VoyageAIEmbeddingFunction( return VoyageAIEmbeddingFunction(
model_name=model_name, model_name=model_name,
api_key=config.get("api_key"), api_key=config.api_key,
) )
@staticmethod @staticmethod
def _configure_bedrock(config, model_name): def _configure_bedrock(config: EmbeddingProviderConfig, model_name: str):
from chromadb.utils.embedding_functions.amazon_bedrock_embedding_function import ( from chromadb.utils.embedding_functions.amazon_bedrock_embedding_function import (
AmazonBedrockEmbeddingFunction, AmazonBedrockEmbeddingFunction,
) )
# Allow custom model_name override with backwards compatibility # Allow custom model_name override with backwards compatibility
kwargs = {"session": config.get("session")} kwargs = {"session": config.session}
if model_name is not None: if model_name is not None:
kwargs["model_name"] = model_name kwargs["model_name"] = model_name
return AmazonBedrockEmbeddingFunction(**kwargs) return AmazonBedrockEmbeddingFunction(**kwargs)
@staticmethod @staticmethod
def _configure_huggingface(config, model_name): def _configure_huggingface(config: EmbeddingProviderConfig, model_name: str):
from chromadb.utils.embedding_functions.huggingface_embedding_function import ( from chromadb.utils.embedding_functions.huggingface_embedding_function import (
HuggingFaceEmbeddingServer, HuggingFaceEmbeddingServer,
) )
return HuggingFaceEmbeddingServer( return HuggingFaceEmbeddingServer(
url=config.get("api_url"), url=config.api_url,
) )
@staticmethod @staticmethod
def _configure_watson(config, model_name): def _configure_watson(config: EmbeddingProviderConfig, model_name: str):
try: try:
import ibm_watsonx_ai.foundation_models as watson_models import ibm_watsonx_ai.foundation_models as watson_models
from ibm_watsonx_ai import Credentials from ibm_watsonx_ai import Credentials
@@ -193,12 +273,10 @@ class EmbeddingConfigurator:
} }
embedding = watson_models.Embeddings( embedding = watson_models.Embeddings(
model_id=config.get("model"), model_id=config.model,
params=embed_params, params=embed_params,
credentials=Credentials( credentials=Credentials(api_key=config.api_key, url=config.api_url),
api_key=config.get("api_key"), url=config.get("api_url") project_id=config.project_id,
),
project_id=config.get("project_id"),
) )
try: try:
@@ -211,8 +289,8 @@ class EmbeddingConfigurator:
return WatsonEmbeddingFunction() return WatsonEmbeddingFunction()
@staticmethod @staticmethod
def _configure_custom(config): def _configure_custom(config: EmbeddingProviderConfig):
custom_embedder = config.get("embedder") custom_embedder = config.embedder
if isinstance(custom_embedder, EmbeddingFunction): if isinstance(custom_embedder, EmbeddingFunction):
try: try:
validate_embedding_function(custom_embedder) validate_embedding_function(custom_embedder)