Mirror of https://github.com/crewAIInc/crewAI.git, synced 2025-12-18 21:38:29 +00:00

Compare commits: bugfix-pyt...pr-2174 (6 commits)
Commits:

- 21f4b60754
- 216ff4aa6f
- 6f849c0e6d
- 276f661e6c
- 8f99caf61b
- f4642f11cc
@@ -20,6 +20,7 @@ from crewai.tools.agent_tools.agent_tools import AgentTools
 from crewai.utilities import Converter, Prompts
 from crewai.utilities.constants import TRAINED_AGENTS_DATA_FILE, TRAINING_DATA_FILE
 from crewai.utilities.converter import generate_model_description
+from crewai.utilities.embedding_configurator import EmbeddingConfig
 from crewai.utilities.events.agent_events import (
     AgentExecutionCompletedEvent,
     AgentExecutionErrorEvent,
@@ -108,7 +109,7 @@ class Agent(BaseAgent):
         default="safe",
         description="Mode for code execution: 'safe' (using Docker) or 'unsafe' (direct execution).",
     )
-    embedder: Optional[Dict[str, Any]] = Field(
+    embedder: Optional[EmbeddingConfig] = Field(
         default=None,
         description="Embedder configuration for the agent.",
     )
@@ -134,7 +135,7 @@ class Agent(BaseAgent):
             self.cache_handler = CacheHandler()
             self.set_cache_handler(self.cache_handler)

-    def set_knowledge(self, crew_embedder: Optional[Dict[str, Any]] = None):
+    def set_knowledge(self, crew_embedder: Optional[EmbeddingConfig] = None):
         try:
             if self.embedder is None and crew_embedder:
                 self.embedder = crew_embedder
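With this change, `Agent.embedder` and `Agent.set_knowledge` take a typed `EmbeddingConfig` instead of a raw dict. A minimal sketch of constructing an agent against the new field; the role/goal/backstory strings are placeholders, and the import path assumes the models live in `crewai.utilities.embedding_configurator` as shown in this diff:

```python
from crewai import Agent
from crewai.utilities.embedding_configurator import (
    EmbeddingConfig,
    EmbeddingProviderConfig,
)

# Placeholder agent; only `embedder` exercises the field changed above.
agent = Agent(
    role="Research Analyst",
    goal="Answer questions from the indexed knowledge base",
    backstory="Used here only to demonstrate the typed embedder config.",
    embedder=EmbeddingConfig(
        provider="openai",
        config=EmbeddingProviderConfig(model="text-embedding-3-small"),
    ),
)
```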
@@ -25,6 +25,7 @@ from crewai.tools.base_tool import BaseTool, Tool
 from crewai.utilities import I18N, Logger, RPMController
 from crewai.utilities.config import process_config
 from crewai.utilities.converter import Converter
+from crewai.utilities.embedding_configurator import EmbeddingConfig

 T = TypeVar("T", bound="BaseAgent")

@@ -362,5 +363,5 @@ class BaseAgent(ABC, BaseModel):
         self._rpm_controller = rpm_controller
         self.create_agent_executor()

-    def set_knowledge(self, crew_embedder: Optional[Dict[str, Any]] = None):
+    def set_knowledge(self, crew_embedder: Optional[EmbeddingConfig] = None):
         pass
@@ -41,6 +41,7 @@ from crewai.tools.base_tool import Tool
 from crewai.types.usage_metrics import UsageMetrics
 from crewai.utilities import I18N, FileHandler, Logger, RPMController
 from crewai.utilities.constants import TRAINING_DATA_FILE
+from crewai.utilities.embedding_configurator import EmbeddingConfig
 from crewai.utilities.evaluators.crew_evaluator_handler import CrewEvaluator
 from crewai.utilities.evaluators.task_evaluator import TaskEvaluator
 from crewai.utilities.events.crew_events import (
@@ -145,7 +146,7 @@ class Crew(BaseModel):
         default=None,
         description="An instance of the UserMemory to be used by the Crew to store/fetch memories of a specific user.",
     )
-    embedder: Optional[dict] = Field(
+    embedder: Optional[EmbeddingConfig] = Field(
         default=None,
         description="Configuration for the embedder to be used for the crew.",
     )
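`Crew.embedder` follows the same pattern, and the crew-level config is what gets forwarded to agents through `set_knowledge(crew_embedder=...)`. A rough sketch; the agent and task values are placeholders and only the typed `embedder` argument reflects this diff:

```python
from crewai import Agent, Crew, Task
from crewai.utilities.embedding_configurator import (
    EmbeddingConfig,
    EmbeddingProviderConfig,
)

# Placeholder agent/task; the point is the typed crew-level `embedder`.
researcher = Agent(
    role="Researcher",
    goal="Collect notes on embedding providers",
    backstory="Demo agent for the typed embedder config.",
)
task = Task(
    description="Summarize the supported embedding providers.",
    expected_output="A short bullet list.",
    agent=researcher,
)

crew = Crew(
    agents=[researcher],
    tasks=[task],
    embedder=EmbeddingConfig(
        provider="ollama",
        config=EmbeddingProviderConfig(
            model="nomic-embed-text",
            url="http://localhost:11434/api/embeddings",
        ),
    ),
)
```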
@@ -5,6 +5,7 @@ from pydantic import BaseModel, ConfigDict, Field

 from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource
 from crewai.knowledge.storage.knowledge_storage import KnowledgeStorage
+from crewai.utilities.embedding_configurator import EmbeddingConfig

 os.environ["TOKENIZERS_PARALLELISM"] = "false"  # removes logging from fastembed

@@ -21,14 +22,14 @@ class Knowledge(BaseModel):
     sources: List[BaseKnowledgeSource] = Field(default_factory=list)
     model_config = ConfigDict(arbitrary_types_allowed=True)
     storage: Optional[KnowledgeStorage] = Field(default=None)
-    embedder: Optional[Dict[str, Any]] = None
+    embedder: Optional[EmbeddingConfig] = None
     collection_name: Optional[str] = None

     def __init__(
         self,
         collection_name: str,
         sources: List[BaseKnowledgeSource],
-        embedder: Optional[Dict[str, Any]] = None,
+        embedder: Optional[EmbeddingConfig] = None,
         storage: Optional[KnowledgeStorage] = None,
         **data,
     ):
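`Knowledge` (and, below, `KnowledgeStorage`) accepts the same typed config directly. A small sketch; `StringKnowledgeSource` and its import path are assumptions based on the knowledge sources shipped with crewAI, not part of this diff:

```python
from crewai.knowledge.knowledge import Knowledge
from crewai.knowledge.source.string_knowledge_source import StringKnowledgeSource
from crewai.utilities.embedding_configurator import (
    EmbeddingConfig,
    EmbeddingProviderConfig,
)

# Placeholder collection name and content; `embedder` is the typed config.
knowledge = Knowledge(
    collection_name="product_docs",
    sources=[StringKnowledgeSource(content="Internal notes about the product.")],
    embedder=EmbeddingConfig(
        provider="openai",
        config=EmbeddingProviderConfig(model="text-embedding-3-small"),
    ),
)
```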
@@ -15,6 +15,7 @@ from chromadb.config import Settings
 from crewai.knowledge.storage.base_knowledge_storage import BaseKnowledgeStorage
 from crewai.utilities import EmbeddingConfigurator
 from crewai.utilities.constants import KNOWLEDGE_DIRECTORY
+from crewai.utilities.embedding_configurator import EmbeddingConfig
 from crewai.utilities.logger import Logger
 from crewai.utilities.paths import db_storage_path

@@ -48,7 +49,7 @@ class KnowledgeStorage(BaseKnowledgeStorage):

     def __init__(
         self,
-        embedder: Optional[Dict[str, Any]] = None,
+        embedder: Optional[EmbeddingConfig] = None,
         collection_name: Optional[str] = None,
     ):
         self.collection_name = collection_name
@@ -187,7 +188,7 @@ class KnowledgeStorage(BaseKnowledgeStorage):
             api_key=os.getenv("OPENAI_API_KEY"), model_name="text-embedding-3-small"
         )

-    def _set_embedder_config(self, embedder: Optional[Dict[str, Any]] = None) -> None:
+    def _set_embedder_config(self, embedder: Optional[EmbeddingConfig] = None) -> None:
         """Set the embedding configuration for the knowledge storage.

         Args:
@@ -1,8 +1,84 @@
 import os
-from typing import Any, Dict, Optional, cast
+from typing import Any, Callable, Literal, cast

 from chromadb import Documents, EmbeddingFunction, Embeddings
 from chromadb.api.types import validate_embedding_function
+from pydantic import BaseModel
+
+
+class EmbeddingProviderConfig(BaseModel):
+    """Configuration model for embedding providers.
+
+    Attributes:
+        # Core Model Configuration
+        model (str | None): The model identifier for embeddings, used across multiple providers
+            like OpenAI, Azure, Watson, etc.
+        embedder (str | Callable | None): Custom embedding function or callable for custom
+            embedding implementations.
+
+        # API Authentication & Configuration
+        api_key (str | None): Authentication key for various providers (OpenAI, VertexAI,
+            Google, Cohere, VoyageAI, Watson).
+        api_base (str | None): Base API URL override for OpenAI and Azure services.
+        api_type (str | None): API type specification, particularly used for Azure configuration.
+        api_version (str | None): API version for OpenAI and Azure services.
+        api_url (str | None): API endpoint URL, used by HuggingFace and Watson services.
+        url (str | None): Base URL for the embedding service, primarily used for Ollama and
+            HuggingFace endpoints.
+
+        # Service-Specific Configuration
+        project_id (str | None): Project identifier used by VertexAI and Watson services.
+        organization_id (str | None): Organization identifier for OpenAI and Azure services.
+        deployment_id (str | None): Deployment identifier for OpenAI and Azure services.
+        region (str | None): Geographic region for VertexAI services.
+        session (str | None): Session configuration for Amazon Bedrock embeddings.
+
+        # Request Configuration
+        task_type (str | None): Specifies the task type for Google Generative AI embeddings.
+        default_headers (str | None): Custom headers for OpenAI and Azure API requests.
+        dimensions (str | None): Output dimensions specification for OpenAI and Azure embeddings.
+    """
+
+    # Core Model Configuration
+    model: str | None = None
+    embedder: str | Callable | None = None
+
+    # API Authentication & Configuration
+    api_key: str | None = None
+    api_base: str | None = None
+    api_type: str | None = None
+    api_version: str | None = None
+    api_url: str | None = None
+    url: str | None = None
+
+    # Service-Specific Configuration
+    project_id: str | None = None
+    organization_id: str | None = None
+    deployment_id: str | None = None
+    region: str | None = None
+    session: str | None = None
+
+    # Request Configuration
+    task_type: str | None = None
+    default_headers: str | None = None
+    dimensions: str | None = None
+
+
+class EmbeddingConfig(BaseModel):
+    provider: Literal[
+        "openai",
+        "azure",
+        "ollama",
+        "vertexai",
+        "google",
+        "cohere",
+        "voyageai",
+        "bedrock",
+        "huggingface",
+        "watson",
+        "custom",
+    ]
+    config: EmbeddingProviderConfig | None = None


 class EmbeddingConfigurator:
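The net effect of the two new models is that embedder configuration, previously a free-form dict such as `{"provider": "openai", "config": {"model": "text-embedding-3-small"}}`, is now validated up front: an unsupported provider is rejected by the `Literal[...]` field when the object is built. A minimal sketch of that behaviour, assuming the models land as shown above:

```python
from pydantic import ValidationError

from crewai.utilities.embedding_configurator import (
    EmbeddingConfig,
    EmbeddingProviderConfig,
)

# Valid: provider is one of the Literal values, all config fields are optional.
valid = EmbeddingConfig(
    provider="openai",
    config=EmbeddingProviderConfig(model="text-embedding-3-small"),
)

# Invalid: an unsupported provider now fails at construction time.
try:
    EmbeddingConfig(provider="not-a-provider")
except ValidationError as exc:
    print(f"rejected with {exc.error_count()} validation error(s)")
```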
@@ -23,15 +99,19 @@ class EmbeddingConfigurator:

     def configure_embedder(
         self,
-        embedder_config: Optional[Dict[str, Any]] = None,
+        embedder_config: EmbeddingConfig | None = None,
     ) -> EmbeddingFunction:
         """Configures and returns an embedding function based on the provided config."""
         if embedder_config is None:
             return self._create_default_embedding_function()

-        provider = embedder_config.get("provider")
-        config = embedder_config.get("config", {})
-        model_name = config.get("model") if provider != "custom" else None
+        provider = embedder_config.provider
+        config = (
+            embedder_config.config
+            if embedder_config.config
+            else EmbeddingProviderConfig()
+        )
+        model_name = config.model if provider != "custom" else None

         if provider not in self.embedding_functions:
             raise Exception(
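`configure_embedder` keeps its previous fallback: no config means the default OpenAI embedding function, otherwise it dispatches on `provider` and hands the (possibly empty) `EmbeddingProviderConfig` to the matching `_configure_*` helper. A rough sketch of calling it directly, assuming `OPENAI_API_KEY` is set in the environment:

```python
from crewai.utilities.embedding_configurator import (
    EmbeddingConfig,
    EmbeddingConfigurator,
    EmbeddingProviderConfig,
)

configurator = EmbeddingConfigurator()

# No config: falls back to the default embedding function.
default_fn = configurator.configure_embedder()

# Typed config: dispatched to _configure_openai with config.model as model_name.
openai_fn = configurator.configure_embedder(
    EmbeddingConfig(
        provider="openai",
        config=EmbeddingProviderConfig(model="text-embedding-3-small"),
    )
)
vectors = openai_fn(["typed embedder configuration for crewAI"])
```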
@@ -56,123 +136,123 @@ class EmbeddingConfigurator:
         )

     @staticmethod
-    def _configure_openai(config, model_name):
+    def _configure_openai(config: EmbeddingProviderConfig, model_name: str):
         from chromadb.utils.embedding_functions.openai_embedding_function import (
             OpenAIEmbeddingFunction,
         )

         return OpenAIEmbeddingFunction(
-            api_key=config.get("api_key") or os.getenv("OPENAI_API_KEY"),
+            api_key=config.api_key or os.getenv("OPENAI_API_KEY"),
             model_name=model_name,
-            api_base=config.get("api_base", None),
-            api_type=config.get("api_type", None),
-            api_version=config.get("api_version", None),
-            default_headers=config.get("default_headers", None),
-            dimensions=config.get("dimensions", None),
-            deployment_id=config.get("deployment_id", None),
-            organization_id=config.get("organization_id", None),
+            api_base=config.api_base,
+            api_type=config.api_type,
+            api_version=config.api_version,
+            default_headers=config.default_headers,
+            dimensions=config.dimensions,
+            deployment_id=config.deployment_id,
+            organization_id=config.organization_id,
         )

     @staticmethod
-    def _configure_azure(config, model_name):
+    def _configure_azure(config: EmbeddingProviderConfig, model_name: str):
         from chromadb.utils.embedding_functions.openai_embedding_function import (
             OpenAIEmbeddingFunction,
         )

         return OpenAIEmbeddingFunction(
-            api_key=config.get("api_key"),
-            api_base=config.get("api_base"),
-            api_type=config.get("api_type", "azure"),
-            api_version=config.get("api_version"),
+            api_key=config.api_key,
+            api_base=config.api_base,
+            api_type=config.api_type if config.api_type else "azure",
+            api_version=config.api_version,
             model_name=model_name,
-            default_headers=config.get("default_headers"),
-            dimensions=config.get("dimensions"),
-            deployment_id=config.get("deployment_id"),
-            organization_id=config.get("organization_id"),
+            default_headers=config.default_headers,
+            dimensions=config.dimensions,
+            deployment_id=config.deployment_id,
+            organization_id=config.organization_id,
         )

     @staticmethod
-    def _configure_ollama(config, model_name):
+    def _configure_ollama(config: EmbeddingProviderConfig, model_name: str):
         from chromadb.utils.embedding_functions.ollama_embedding_function import (
             OllamaEmbeddingFunction,
         )

         return OllamaEmbeddingFunction(
-            url=config.get("url", "http://localhost:11434/api/embeddings"),
+            url=config.url if config.url else "http://localhost:11434/api/embeddings",
             model_name=model_name,
         )

     @staticmethod
-    def _configure_vertexai(config, model_name):
+    def _configure_vertexai(config: EmbeddingProviderConfig, model_name: str):
         from chromadb.utils.embedding_functions.google_embedding_function import (
             GoogleVertexEmbeddingFunction,
         )

         return GoogleVertexEmbeddingFunction(
             model_name=model_name,
-            api_key=config.get("api_key"),
-            project_id=config.get("project_id"),
-            region=config.get("region"),
+            api_key=config.api_key,
+            project_id=config.project_id,
+            region=config.region,
         )

     @staticmethod
-    def _configure_google(config, model_name):
+    def _configure_google(config: EmbeddingProviderConfig, model_name: str):
         from chromadb.utils.embedding_functions.google_embedding_function import (
             GoogleGenerativeAiEmbeddingFunction,
         )

         return GoogleGenerativeAiEmbeddingFunction(
             model_name=model_name,
-            api_key=config.get("api_key"),
-            task_type=config.get("task_type"),
+            api_key=config.api_key,
+            task_type=config.task_type,
         )

     @staticmethod
-    def _configure_cohere(config, model_name):
+    def _configure_cohere(config: EmbeddingProviderConfig, model_name: str):
         from chromadb.utils.embedding_functions.cohere_embedding_function import (
             CohereEmbeddingFunction,
         )

         return CohereEmbeddingFunction(
             model_name=model_name,
-            api_key=config.get("api_key"),
+            api_key=config.api_key,
         )

     @staticmethod
-    def _configure_voyageai(config, model_name):
+    def _configure_voyageai(config: EmbeddingProviderConfig, model_name: str):
         from chromadb.utils.embedding_functions.voyageai_embedding_function import (
             VoyageAIEmbeddingFunction,
         )

         return VoyageAIEmbeddingFunction(
             model_name=model_name,
-            api_key=config.get("api_key"),
+            api_key=config.api_key,
         )

     @staticmethod
-    def _configure_bedrock(config, model_name):
+    def _configure_bedrock(config: EmbeddingProviderConfig, model_name: str):
         from chromadb.utils.embedding_functions.amazon_bedrock_embedding_function import (
             AmazonBedrockEmbeddingFunction,
         )

         # Allow custom model_name override with backwards compatibility
-        kwargs = {"session": config.get("session")}
+        kwargs = {"session": config.session}
         if model_name is not None:
             kwargs["model_name"] = model_name
         return AmazonBedrockEmbeddingFunction(**kwargs)

     @staticmethod
-    def _configure_huggingface(config, model_name):
+    def _configure_huggingface(config: EmbeddingProviderConfig, model_name: str):
         from chromadb.utils.embedding_functions.huggingface_embedding_function import (
             HuggingFaceEmbeddingServer,
         )

         return HuggingFaceEmbeddingServer(
-            url=config.get("api_url"),
+            url=config.api_url,
         )

     @staticmethod
-    def _configure_watson(config, model_name):
+    def _configure_watson(config: EmbeddingProviderConfig, model_name: str):
         try:
             import ibm_watsonx_ai.foundation_models as watson_models
             from ibm_watsonx_ai import Credentials
@@ -193,12 +273,10 @@ class EmbeddingConfigurator:
                 }

                 embedding = watson_models.Embeddings(
-                    model_id=config.get("model"),
+                    model_id=config.model,
                     params=embed_params,
-                    credentials=Credentials(
-                        api_key=config.get("api_key"), url=config.get("api_url")
-                    ),
-                    project_id=config.get("project_id"),
+                    credentials=Credentials(api_key=config.api_key, url=config.api_url),
+                    project_id=config.project_id,
                 )

                 try:
@@ -211,8 +289,8 @@ class EmbeddingConfigurator:
         return WatsonEmbeddingFunction()

     @staticmethod
-    def _configure_custom(config):
-        custom_embedder = config.get("embedder")
+    def _configure_custom(config: EmbeddingProviderConfig):
+        custom_embedder = config.embedder
         if isinstance(custom_embedder, EmbeddingFunction):
             try:
                 validate_embedding_function(custom_embedder)
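For `provider="custom"`, the embedding function itself travels in the `embedder` field and is checked with chromadb's `validate_embedding_function`. A sketch of wiring one in; `FakeEmbeddingFunction` is purely illustrative and not part of this diff:

```python
from chromadb import Documents, EmbeddingFunction, Embeddings

from crewai.utilities.embedding_configurator import (
    EmbeddingConfig,
    EmbeddingConfigurator,
    EmbeddingProviderConfig,
)


class FakeEmbeddingFunction(EmbeddingFunction):
    """Illustrative stand-in that returns fixed-size zero vectors."""

    def __call__(self, input: Documents) -> Embeddings:
        return [[0.0] * 8 for _ in input]


custom_fn = EmbeddingConfigurator().configure_embedder(
    EmbeddingConfig(
        provider="custom",
        config=EmbeddingProviderConfig(embedder=FakeEmbeddingFunction()),
    )
)
```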