mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-11 00:58:30 +00:00
consolodation and improvements
This commit is contained in:
@@ -121,7 +121,6 @@ class Agent(BaseAgent):
|
|||||||
default="safe",
|
default="safe",
|
||||||
description="Mode for code execution: 'safe' (using Docker) or 'unsafe' (direct execution).",
|
description="Mode for code execution: 'safe' (using Docker) or 'unsafe' (direct execution).",
|
||||||
)
|
)
|
||||||
# TODO: Lorenze add knowledge_embedder. Support direct class or config dict.
|
|
||||||
_knowledge: Optional[Knowledge] = PrivateAttr(default=None)
|
_knowledge: Optional[Knowledge] = PrivateAttr(default=None)
|
||||||
|
|
||||||
@model_validator(mode="after")
|
@model_validator(mode="after")
|
||||||
|
|||||||
@@ -1,82 +0,0 @@
|
|||||||
import os
|
|
||||||
from typing import List, Optional
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
from openai import OpenAI
|
|
||||||
|
|
||||||
from .base_embedder import BaseEmbedder
|
|
||||||
|
|
||||||
|
|
||||||
class OllamaEmbedder(BaseEmbedder):
|
|
||||||
"""
|
|
||||||
A wrapper class for text embedding models using Ollama's API
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
model_name: str,
|
|
||||||
api_key: Optional[str] = None,
|
|
||||||
base_url: str = "http://localhost:11434/v1",
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Initialize the embedding model
|
|
||||||
|
|
||||||
Args:
|
|
||||||
model_name: Name of the model to use
|
|
||||||
api_key: API key (defaults to 'ollama' or environment variable 'OLLAMA_API_KEY')
|
|
||||||
base_url: Base URL for the Ollama API (default is 'http://localhost:11434/v1')
|
|
||||||
"""
|
|
||||||
self.model_name = model_name
|
|
||||||
self.api_key = api_key or os.getenv("OLLAMA_API_KEY") or "ollama"
|
|
||||||
self.base_url = base_url
|
|
||||||
self.client = OpenAI(base_url=self.base_url, api_key=self.api_key)
|
|
||||||
|
|
||||||
def embed_chunks(self, chunks: List[str]) -> List[np.ndarray]:
|
|
||||||
"""
|
|
||||||
Generate embeddings for a list of text chunks
|
|
||||||
|
|
||||||
Args:
|
|
||||||
chunks: List of text chunks to embed
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List of embeddings
|
|
||||||
"""
|
|
||||||
return self.embed_texts(chunks)
|
|
||||||
|
|
||||||
def embed_texts(self, texts: List[str]) -> List[np.ndarray]:
|
|
||||||
"""
|
|
||||||
Generate embeddings for a list of texts
|
|
||||||
|
|
||||||
Args:
|
|
||||||
texts: List of texts to embed
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List of embeddings
|
|
||||||
"""
|
|
||||||
embeddings = []
|
|
||||||
max_batch_size = 2048 # Adjust batch size if necessary
|
|
||||||
for i in range(0, len(texts), max_batch_size):
|
|
||||||
batch = texts[i : i + max_batch_size]
|
|
||||||
response = self.client.embeddings.create(input=batch, model=self.model_name)
|
|
||||||
batch_embeddings = [np.array(item.embedding) for item in response.data]
|
|
||||||
embeddings.extend(batch_embeddings)
|
|
||||||
return embeddings
|
|
||||||
|
|
||||||
def embed_text(self, text: str) -> np.ndarray:
|
|
||||||
"""
|
|
||||||
Generate embedding for a single text
|
|
||||||
|
|
||||||
Args:
|
|
||||||
text: Text to embed
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Embedding array
|
|
||||||
"""
|
|
||||||
return self.embed_texts([text])[0]
|
|
||||||
|
|
||||||
@property
|
|
||||||
def dimension(self) -> int:
|
|
||||||
"""Get the dimension of the embeddings"""
|
|
||||||
# Embedding dimensions may vary; we'll determine it dynamically
|
|
||||||
test_embed = self.embed_text("test")
|
|
||||||
return len(test_embed)
|
|
||||||
@@ -1,85 +0,0 @@
|
|||||||
import os
|
|
||||||
from typing import List, Optional
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
from openai import OpenAI
|
|
||||||
|
|
||||||
from .base_embedder import BaseEmbedder
|
|
||||||
|
|
||||||
|
|
||||||
class OpenAIEmbedder(BaseEmbedder):
|
|
||||||
"""
|
|
||||||
A wrapper class for text embedding models using OpenAI's Embedding API
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
model_name: str = "text-embedding-ada-002",
|
|
||||||
api_key: Optional[str] = None,
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Initialize the embedding model
|
|
||||||
|
|
||||||
Args:
|
|
||||||
model_name: Name of the model to use
|
|
||||||
api_key: OpenAI API key
|
|
||||||
"""
|
|
||||||
self.model_name = model_name
|
|
||||||
self.api_key = api_key or os.getenv("OPENAI_API_KEY")
|
|
||||||
if not self.api_key:
|
|
||||||
raise ValueError(
|
|
||||||
"OpenAI API key must be provided or set in the environment variable 'OPENAI_API_KEY'"
|
|
||||||
)
|
|
||||||
self.client = OpenAI(
|
|
||||||
api_key=self.api_key,
|
|
||||||
base_url="http://localhost:11434/v1",
|
|
||||||
)
|
|
||||||
|
|
||||||
def embed_chunks(self, chunks: List[str]) -> List[np.ndarray]:
|
|
||||||
"""
|
|
||||||
Generate embeddings for a list of text chunks
|
|
||||||
|
|
||||||
Args:
|
|
||||||
chunks: List of text chunks to embed
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List of embeddings
|
|
||||||
"""
|
|
||||||
return self.embed_texts(chunks)
|
|
||||||
|
|
||||||
def embed_texts(self, texts: List[str]) -> List[np.ndarray]:
|
|
||||||
"""
|
|
||||||
Generate embeddings for a list of texts
|
|
||||||
|
|
||||||
Args:
|
|
||||||
texts: List of texts to embed
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List of embeddings
|
|
||||||
"""
|
|
||||||
embeddings = []
|
|
||||||
max_batch_size = 2048 # OpenAI recommends smaller batch sizes
|
|
||||||
for i in range(0, len(texts), max_batch_size):
|
|
||||||
batch = texts[i : i + max_batch_size]
|
|
||||||
response = self.client.embeddings.create(input=batch, model=self.model_name)
|
|
||||||
batch_embeddings = [np.array(data.embedding) for data in response.data]
|
|
||||||
embeddings.extend(batch_embeddings)
|
|
||||||
return embeddings
|
|
||||||
|
|
||||||
def embed_text(self, text: str) -> np.ndarray:
|
|
||||||
"""
|
|
||||||
Generate embedding fors a single text
|
|
||||||
|
|
||||||
Args:
|
|
||||||
text: Text to embed
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Embedding array
|
|
||||||
"""
|
|
||||||
return self.embed_texts([text])[0]
|
|
||||||
|
|
||||||
@property
|
|
||||||
def dimension(self) -> int:
|
|
||||||
"""Get the dimension of the embeddings"""
|
|
||||||
# For OpenAI's text-embedding-ada-002, the dimension is 1536
|
|
||||||
return 1536
|
|
||||||
@@ -15,12 +15,13 @@ class Knowledge(BaseModel):
|
|||||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||||
agents: List[str] = Field(default_factory=list)
|
agents: List[str] = Field(default_factory=list)
|
||||||
storage: KnowledgeStorage = Field(default_factory=KnowledgeStorage)
|
storage: KnowledgeStorage = Field(default_factory=KnowledgeStorage)
|
||||||
|
embedder_config: Optional[Dict[str, Any]] = Field(default_factory=None)
|
||||||
|
|
||||||
def __init__(self, **data):
|
def __init__(self, **data):
|
||||||
super().__init__(**data)
|
super().__init__(**data)
|
||||||
# Call add on all sources during initialization
|
embedder_config = data.get("embedder_config", None)
|
||||||
for source in self.sources:
|
if embedder_config:
|
||||||
source.add(self.embedder)
|
self.storage = KnowledgeStorage(embedder_config=embedder_config)
|
||||||
|
|
||||||
def query(
|
def query(
|
||||||
self, query: List[str], limit: int = 3, preference: Optional[str] = None
|
self, query: List[str], limit: int = 3, preference: Optional[str] = None
|
||||||
|
|||||||
@@ -1,8 +1,7 @@
|
|||||||
from typing import List, Dict, Any
|
from typing import List
|
||||||
|
|
||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
|
|
||||||
from crewai.knowledge.embedder.base_embedder import BaseEmbedder
|
|
||||||
from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource
|
from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource
|
||||||
|
|
||||||
|
|
||||||
@@ -20,14 +19,10 @@ class StringKnowledgeSource(BaseKnowledgeSource):
|
|||||||
if not isinstance(self.content, str):
|
if not isinstance(self.content, str):
|
||||||
raise ValueError("StringKnowledgeSource only accepts string content")
|
raise ValueError("StringKnowledgeSource only accepts string content")
|
||||||
|
|
||||||
def add(self, embedder: BaseEmbedder) -> None:
|
def add(self) -> None:
|
||||||
"""Add string content to the knowledge source, chunk it, compute embeddings, and save them."""
|
"""Add string content to the knowledge source, chunk it, compute embeddings, and save them."""
|
||||||
new_chunks = self._chunk_text(self.content)
|
new_chunks = self._chunk_text(self.content)
|
||||||
self.chunks.extend(new_chunks)
|
self.chunks.extend(new_chunks)
|
||||||
# Compute embeddings for the new chunks
|
|
||||||
new_embeddings = embedder.embed_chunks(new_chunks)
|
|
||||||
# Save the embeddings
|
|
||||||
self.chunk_embeddings.extend(new_embeddings)
|
|
||||||
self.save_documents(metadata=self.metadata)
|
self.save_documents(metadata=self.metadata)
|
||||||
|
|
||||||
def _chunk_text(self, text: str) -> List[str]:
|
def _chunk_text(self, text: str) -> List[str]:
|
||||||
|
|||||||
@@ -13,17 +13,13 @@ class TextFileKnowledgeSource(BaseFileKnowledgeSource):
|
|||||||
with self.file_path.open("r", encoding="utf-8") as f:
|
with self.file_path.open("r", encoding="utf-8") as f:
|
||||||
return f.read()
|
return f.read()
|
||||||
|
|
||||||
def add(self, embedder: BaseEmbedder) -> None:
|
def add(self) -> None:
|
||||||
"""
|
"""
|
||||||
Add text file content to the knowledge source, chunk it, compute embeddings,
|
Add text file content to the knowledge source, chunk it, compute embeddings,
|
||||||
and save the embeddings.
|
and save the embeddings.
|
||||||
"""
|
"""
|
||||||
new_chunks = self._chunk_text(self.content)
|
new_chunks = self._chunk_text(self.content)
|
||||||
self.chunks.extend(new_chunks)
|
self.chunks.extend(new_chunks)
|
||||||
# Compute embeddings for the new chunks
|
|
||||||
new_embeddings = embedder.embed_chunks(new_chunks)
|
|
||||||
# Save the embeddings
|
|
||||||
self.chunk_embeddings.extend(new_embeddings)
|
|
||||||
self.save_documents(metadata=self.metadata)
|
self.save_documents(metadata=self.metadata)
|
||||||
|
|
||||||
def _chunk_text(self, text: str) -> List[str]:
|
def _chunk_text(self, text: str) -> List[str]:
|
||||||
|
|||||||
29
src/crewai/knowledge/storage/base_knowledge_storage.py
Normal file
29
src/crewai/knowledge/storage/base_knowledge_storage.py
Normal file
@@ -0,0 +1,29 @@
|
|||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from typing import Dict, Any, List, Optional
|
||||||
|
|
||||||
|
|
||||||
|
class BaseKnowledgeStorage(ABC):
|
||||||
|
"""Abstract base class for knowledge storage implementations."""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def search(
|
||||||
|
self,
|
||||||
|
query: List[str],
|
||||||
|
limit: int = 3,
|
||||||
|
filter: Optional[dict] = None,
|
||||||
|
score_threshold: float = 0.35,
|
||||||
|
) -> List[Dict[str, Any]]:
|
||||||
|
"""Search for documents in the knowledge base."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def save(
|
||||||
|
self, documents: List[str], metadata: Dict[str, Any] | List[Dict[str, Any]]
|
||||||
|
) -> None:
|
||||||
|
"""Save documents to the knowledge base."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def reset(self) -> None:
|
||||||
|
"""Reset the knowledge base."""
|
||||||
|
pass
|
||||||
@@ -7,6 +7,8 @@ import chromadb
|
|||||||
from crewai.utilities.paths import db_storage_path
|
from crewai.utilities.paths import db_storage_path
|
||||||
from typing import Optional, List
|
from typing import Optional, List
|
||||||
from typing import Dict, Any
|
from typing import Dict, Any
|
||||||
|
from crewai.utilities import EmbeddingConfigurator
|
||||||
|
from crewai.knowledge.storage.base_knowledge_storage import BaseKnowledgeStorage
|
||||||
|
|
||||||
|
|
||||||
@contextlib.contextmanager
|
@contextlib.contextmanager
|
||||||
@@ -26,7 +28,7 @@ def suppress_logging(
|
|||||||
logger.setLevel(original_level)
|
logger.setLevel(original_level)
|
||||||
|
|
||||||
|
|
||||||
class KnowledgeStorage:
|
class KnowledgeStorage(BaseKnowledgeStorage):
|
||||||
"""
|
"""
|
||||||
Extends Storage to handle embeddings for memory entries, improving
|
Extends Storage to handle embeddings for memory entries, improving
|
||||||
search efficiency.
|
search efficiency.
|
||||||
@@ -35,10 +37,7 @@ class KnowledgeStorage:
|
|||||||
collection: Optional[chromadb.Collection] = None
|
collection: Optional[chromadb.Collection] = None
|
||||||
|
|
||||||
def __init__(self, embedder_config=None):
|
def __init__(self, embedder_config=None):
|
||||||
self.embedder_config = (
|
self._initialize_app(embedder_config or {})
|
||||||
embedder_config or self._create_default_embedding_function()
|
|
||||||
)
|
|
||||||
self._initialize_app()
|
|
||||||
|
|
||||||
def search(
|
def search(
|
||||||
self,
|
self,
|
||||||
@@ -54,7 +53,6 @@ class KnowledgeStorage:
|
|||||||
n_results=limit,
|
n_results=limit,
|
||||||
where=filter,
|
where=filter,
|
||||||
)
|
)
|
||||||
print("Fetched", fetched)
|
|
||||||
results = []
|
results = []
|
||||||
for i in range(len(fetched["ids"][0])):
|
for i in range(len(fetched["ids"][0])):
|
||||||
result = {
|
result = {
|
||||||
@@ -69,10 +67,12 @@ class KnowledgeStorage:
|
|||||||
else:
|
else:
|
||||||
raise Exception("Collection not initialized")
|
raise Exception("Collection not initialized")
|
||||||
|
|
||||||
def _initialize_app(self):
|
def _initialize_app(self, embedder_config: Optional[Dict[str, Any]] = None):
|
||||||
import chromadb
|
import chromadb
|
||||||
from chromadb.config import Settings
|
from chromadb.config import Settings
|
||||||
|
|
||||||
|
self._set_embedder_config(embedder_config)
|
||||||
|
|
||||||
chroma_client = chromadb.PersistentClient(
|
chroma_client = chromadb.PersistentClient(
|
||||||
path=f"{db_storage_path()}/knowledge",
|
path=f"{db_storage_path()}/knowledge",
|
||||||
settings=Settings(allow_reset=True),
|
settings=Settings(allow_reset=True),
|
||||||
@@ -107,3 +107,18 @@ class KnowledgeStorage:
|
|||||||
from crewai.knowledge.embedder.fastembed import FastEmbed
|
from crewai.knowledge.embedder.fastembed import FastEmbed
|
||||||
|
|
||||||
return FastEmbed().embed_texts
|
return FastEmbed().embed_texts
|
||||||
|
|
||||||
|
def _set_embedder_config(
|
||||||
|
self, embedder_config: Optional[Dict[str, Any]] = None
|
||||||
|
) -> None:
|
||||||
|
"""Set the embedding configuration for the knowledge storage.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
embedder_config (Optional[Dict[str, Any]]): Configuration dictionary for the embedder.
|
||||||
|
If None or empty, defaults to the default embedding function.
|
||||||
|
"""
|
||||||
|
self.embedder_config = (
|
||||||
|
EmbeddingConfigurator().configure_embedder(embedder_config)
|
||||||
|
if embedder_config
|
||||||
|
else self._create_default_embedding_function()
|
||||||
|
)
|
||||||
|
|||||||
@@ -4,13 +4,12 @@ import logging
|
|||||||
import os
|
import os
|
||||||
import shutil
|
import shutil
|
||||||
import uuid
|
import uuid
|
||||||
from typing import Any, Dict, List, Optional, cast
|
|
||||||
|
|
||||||
from chromadb import Documents, EmbeddingFunction, Embeddings
|
from typing import Any, Dict, List, Optional
|
||||||
from chromadb.api import ClientAPI
|
from chromadb.api import ClientAPI
|
||||||
from chromadb.api.types import validate_embedding_function
|
|
||||||
from crewai.memory.storage.base_rag_storage import BaseRAGStorage
|
from crewai.memory.storage.base_rag_storage import BaseRAGStorage
|
||||||
from crewai.utilities.paths import db_storage_path
|
from crewai.utilities.paths import db_storage_path
|
||||||
|
from crewai.utilities import EmbeddingConfigurator
|
||||||
|
|
||||||
|
|
||||||
@contextlib.contextmanager
|
@contextlib.contextmanager
|
||||||
@@ -51,133 +50,8 @@ class RAGStorage(BaseRAGStorage):
|
|||||||
self._initialize_app()
|
self._initialize_app()
|
||||||
|
|
||||||
def _set_embedder_config(self):
|
def _set_embedder_config(self):
|
||||||
if self.embedder_config is None:
|
configurator = EmbeddingConfigurator()
|
||||||
self.embedder_config = self._create_default_embedding_function()
|
self.embedder_config = configurator.configure_embedder(self.embedder_config)
|
||||||
|
|
||||||
if isinstance(self.embedder_config, dict):
|
|
||||||
provider = self.embedder_config.get("provider")
|
|
||||||
config = self.embedder_config.get("config", {})
|
|
||||||
model_name = config.get("model")
|
|
||||||
if provider == "openai":
|
|
||||||
from chromadb.utils.embedding_functions.openai_embedding_function import (
|
|
||||||
OpenAIEmbeddingFunction,
|
|
||||||
)
|
|
||||||
|
|
||||||
self.embedder_config = OpenAIEmbeddingFunction(
|
|
||||||
api_key=config.get("api_key") or os.getenv("OPENAI_API_KEY"),
|
|
||||||
model_name=model_name,
|
|
||||||
)
|
|
||||||
elif provider == "azure":
|
|
||||||
from chromadb.utils.embedding_functions.openai_embedding_function import (
|
|
||||||
OpenAIEmbeddingFunction,
|
|
||||||
)
|
|
||||||
|
|
||||||
self.embedder_config = OpenAIEmbeddingFunction(
|
|
||||||
api_key=config.get("api_key"),
|
|
||||||
api_base=config.get("api_base"),
|
|
||||||
api_type=config.get("api_type", "azure"),
|
|
||||||
api_version=config.get("api_version"),
|
|
||||||
model_name=model_name,
|
|
||||||
)
|
|
||||||
elif provider == "ollama":
|
|
||||||
from chromadb.utils.embedding_functions.ollama_embedding_function import (
|
|
||||||
OllamaEmbeddingFunction,
|
|
||||||
)
|
|
||||||
|
|
||||||
self.embedder_config = OllamaEmbeddingFunction(
|
|
||||||
url=config.get("url", "http://localhost:11434/api/embeddings"),
|
|
||||||
model_name=model_name,
|
|
||||||
)
|
|
||||||
elif provider == "vertexai":
|
|
||||||
from chromadb.utils.embedding_functions.google_embedding_function import (
|
|
||||||
GoogleVertexEmbeddingFunction,
|
|
||||||
)
|
|
||||||
|
|
||||||
self.embedder_config = GoogleVertexEmbeddingFunction(
|
|
||||||
model_name=model_name,
|
|
||||||
api_key=config.get("api_key"),
|
|
||||||
)
|
|
||||||
elif provider == "google":
|
|
||||||
from chromadb.utils.embedding_functions.google_embedding_function import (
|
|
||||||
GoogleGenerativeAiEmbeddingFunction,
|
|
||||||
)
|
|
||||||
|
|
||||||
self.embedder_config = GoogleGenerativeAiEmbeddingFunction(
|
|
||||||
model_name=model_name,
|
|
||||||
api_key=config.get("api_key"),
|
|
||||||
)
|
|
||||||
elif provider == "cohere":
|
|
||||||
from chromadb.utils.embedding_functions.cohere_embedding_function import (
|
|
||||||
CohereEmbeddingFunction,
|
|
||||||
)
|
|
||||||
|
|
||||||
self.embedder_config = CohereEmbeddingFunction(
|
|
||||||
model_name=model_name,
|
|
||||||
api_key=config.get("api_key"),
|
|
||||||
)
|
|
||||||
elif provider == "bedrock":
|
|
||||||
from chromadb.utils.embedding_functions.amazon_bedrock_embedding_function import (
|
|
||||||
AmazonBedrockEmbeddingFunction,
|
|
||||||
)
|
|
||||||
|
|
||||||
self.embedder_config = AmazonBedrockEmbeddingFunction(
|
|
||||||
session=config.get("session"),
|
|
||||||
)
|
|
||||||
elif provider == "huggingface":
|
|
||||||
from chromadb.utils.embedding_functions.huggingface_embedding_function import (
|
|
||||||
HuggingFaceEmbeddingServer,
|
|
||||||
)
|
|
||||||
|
|
||||||
self.embedder_config = HuggingFaceEmbeddingServer(
|
|
||||||
url=config.get("api_url"),
|
|
||||||
)
|
|
||||||
elif provider == "watson":
|
|
||||||
try:
|
|
||||||
import ibm_watsonx_ai.foundation_models as watson_models
|
|
||||||
from ibm_watsonx_ai import Credentials
|
|
||||||
from ibm_watsonx_ai.metanames import (
|
|
||||||
EmbedTextParamsMetaNames as EmbedParams,
|
|
||||||
)
|
|
||||||
except ImportError as e:
|
|
||||||
raise ImportError(
|
|
||||||
"IBM Watson dependencies are not installed. Please install them to use Watson embedding."
|
|
||||||
) from e
|
|
||||||
|
|
||||||
class WatsonEmbeddingFunction(EmbeddingFunction):
|
|
||||||
def __call__(self, input: Documents) -> Embeddings:
|
|
||||||
if isinstance(input, str):
|
|
||||||
input = [input]
|
|
||||||
|
|
||||||
embed_params = {
|
|
||||||
EmbedParams.TRUNCATE_INPUT_TOKENS: 3,
|
|
||||||
EmbedParams.RETURN_OPTIONS: {"input_text": True},
|
|
||||||
}
|
|
||||||
|
|
||||||
embedding = watson_models.Embeddings(
|
|
||||||
model_id=config.get("model"),
|
|
||||||
params=embed_params,
|
|
||||||
credentials=Credentials(
|
|
||||||
api_key=config.get("api_key"), url=config.get("api_url")
|
|
||||||
),
|
|
||||||
project_id=config.get("project_id"),
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
|
||||||
embeddings = embedding.embed_documents(input)
|
|
||||||
return cast(Embeddings, embeddings)
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
print("Error during Watson embedding:", e)
|
|
||||||
raise e
|
|
||||||
|
|
||||||
self.embedder_config = WatsonEmbeddingFunction()
|
|
||||||
else:
|
|
||||||
raise Exception(
|
|
||||||
f"Unsupported embedding provider: {provider}, supported providers: [openai, azure, ollama, vertexai, google, cohere, huggingface, watson]"
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
validate_embedding_function(self.embedder_config)
|
|
||||||
self.embedder_config = self.embedder_config
|
|
||||||
|
|
||||||
def _initialize_app(self):
|
def _initialize_app(self):
|
||||||
import chromadb
|
import chromadb
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ from .rpm_controller import RPMController
|
|||||||
from .exceptions.context_window_exceeding_exception import (
|
from .exceptions.context_window_exceeding_exception import (
|
||||||
LLMContextLengthExceededException,
|
LLMContextLengthExceededException,
|
||||||
)
|
)
|
||||||
|
from .embedding_configurator import EmbeddingConfigurator
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"Converter",
|
"Converter",
|
||||||
@@ -23,4 +24,5 @@ __all__ = [
|
|||||||
"RPMController",
|
"RPMController",
|
||||||
"YamlParser",
|
"YamlParser",
|
||||||
"LLMContextLengthExceededException",
|
"LLMContextLengthExceededException",
|
||||||
|
"EmbeddingConfigurator",
|
||||||
]
|
]
|
||||||
|
|||||||
184
src/crewai/utilities/embedding_configurator.py
Normal file
184
src/crewai/utilities/embedding_configurator.py
Normal file
@@ -0,0 +1,184 @@
|
|||||||
|
import os
|
||||||
|
from typing import Any, Dict, cast
|
||||||
|
from chromadb import EmbeddingFunction, Documents, Embeddings
|
||||||
|
from chromadb.api.types import validate_embedding_function
|
||||||
|
|
||||||
|
|
||||||
|
class EmbeddingConfigurator:
|
||||||
|
def __init__(self):
|
||||||
|
self.embedding_functions = {
|
||||||
|
"openai": self._configure_openai,
|
||||||
|
"azure": self._configure_azure,
|
||||||
|
"ollama": self._configure_ollama,
|
||||||
|
"vertexai": self._configure_vertexai,
|
||||||
|
"google": self._configure_google,
|
||||||
|
"cohere": self._configure_cohere,
|
||||||
|
"bedrock": self._configure_bedrock,
|
||||||
|
"huggingface": self._configure_huggingface,
|
||||||
|
"watson": self._configure_watson,
|
||||||
|
}
|
||||||
|
|
||||||
|
def configure_embedder(
|
||||||
|
self,
|
||||||
|
embedder_config: Dict[str, Any] | None = None,
|
||||||
|
) -> EmbeddingFunction:
|
||||||
|
"""Configures and returns an embedding function based on the provided config."""
|
||||||
|
if embedder_config is None:
|
||||||
|
return self._create_default_embedding_function()
|
||||||
|
|
||||||
|
provider = embedder_config.get("provider")
|
||||||
|
config = embedder_config.get("config", {})
|
||||||
|
model_name = config.get("model")
|
||||||
|
|
||||||
|
if isinstance(provider, EmbeddingFunction):
|
||||||
|
try:
|
||||||
|
validate_embedding_function(provider)
|
||||||
|
print("Valid custom embedding function", provider, config)
|
||||||
|
return provider
|
||||||
|
except Exception as e:
|
||||||
|
raise ValueError(f"Invalid custom embedding function: {str(e)}")
|
||||||
|
|
||||||
|
if provider not in self.embedding_functions:
|
||||||
|
raise Exception(
|
||||||
|
f"Unsupported embedding provider: {provider}, supported providers: {list(self.embedding_functions.keys())}"
|
||||||
|
)
|
||||||
|
|
||||||
|
return self.embedding_functions[provider](config, model_name)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _create_default_embedding_function():
|
||||||
|
from chromadb.utils.embedding_functions.openai_embedding_function import (
|
||||||
|
OpenAIEmbeddingFunction,
|
||||||
|
)
|
||||||
|
|
||||||
|
return OpenAIEmbeddingFunction(
|
||||||
|
api_key=os.getenv("OPENAI_API_KEY"), model_name="text-embedding-3-small"
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _configure_openai(config, model_name):
|
||||||
|
from chromadb.utils.embedding_functions.openai_embedding_function import (
|
||||||
|
OpenAIEmbeddingFunction,
|
||||||
|
)
|
||||||
|
|
||||||
|
return OpenAIEmbeddingFunction(
|
||||||
|
api_key=config.get("api_key") or os.getenv("OPENAI_API_KEY"),
|
||||||
|
model_name=model_name,
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _configure_azure(config, model_name):
|
||||||
|
from chromadb.utils.embedding_functions.openai_embedding_function import (
|
||||||
|
OpenAIEmbeddingFunction,
|
||||||
|
)
|
||||||
|
|
||||||
|
return OpenAIEmbeddingFunction(
|
||||||
|
api_key=config.get("api_key"),
|
||||||
|
api_base=config.get("api_base"),
|
||||||
|
api_type=config.get("api_type", "azure"),
|
||||||
|
api_version=config.get("api_version"),
|
||||||
|
model_name=model_name,
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _configure_ollama(config, model_name):
|
||||||
|
from chromadb.utils.embedding_functions.ollama_embedding_function import (
|
||||||
|
OllamaEmbeddingFunction,
|
||||||
|
)
|
||||||
|
|
||||||
|
return OllamaEmbeddingFunction(
|
||||||
|
url=config.get("url", "http://localhost:11434/api/embeddings"),
|
||||||
|
model_name=model_name,
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _configure_vertexai(config, model_name):
|
||||||
|
from chromadb.utils.embedding_functions.google_embedding_function import (
|
||||||
|
GoogleVertexEmbeddingFunction,
|
||||||
|
)
|
||||||
|
|
||||||
|
return GoogleVertexEmbeddingFunction(
|
||||||
|
model_name=model_name,
|
||||||
|
api_key=config.get("api_key"),
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _configure_google(config, model_name):
|
||||||
|
from chromadb.utils.embedding_functions.google_embedding_function import (
|
||||||
|
GoogleGenerativeAiEmbeddingFunction,
|
||||||
|
)
|
||||||
|
|
||||||
|
return GoogleGenerativeAiEmbeddingFunction(
|
||||||
|
model_name=model_name,
|
||||||
|
api_key=config.get("api_key"),
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _configure_cohere(config, model_name):
|
||||||
|
from chromadb.utils.embedding_functions.cohere_embedding_function import (
|
||||||
|
CohereEmbeddingFunction,
|
||||||
|
)
|
||||||
|
|
||||||
|
return CohereEmbeddingFunction(
|
||||||
|
model_name=model_name,
|
||||||
|
api_key=config.get("api_key"),
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _configure_bedrock(config, model_name):
|
||||||
|
from chromadb.utils.embedding_functions.amazon_bedrock_embedding_function import (
|
||||||
|
AmazonBedrockEmbeddingFunction,
|
||||||
|
)
|
||||||
|
|
||||||
|
return AmazonBedrockEmbeddingFunction(
|
||||||
|
session=config.get("session"),
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _configure_huggingface(config, model_name):
|
||||||
|
from chromadb.utils.embedding_functions.huggingface_embedding_function import (
|
||||||
|
HuggingFaceEmbeddingServer,
|
||||||
|
)
|
||||||
|
|
||||||
|
return HuggingFaceEmbeddingServer(
|
||||||
|
url=config.get("api_url"),
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _configure_watson(config, model_name):
|
||||||
|
try:
|
||||||
|
import ibm_watsonx_ai.foundation_models as watson_models
|
||||||
|
from ibm_watsonx_ai import Credentials
|
||||||
|
from ibm_watsonx_ai.metanames import EmbedTextParamsMetaNames as EmbedParams
|
||||||
|
except ImportError as e:
|
||||||
|
raise ImportError(
|
||||||
|
"IBM Watson dependencies are not installed. Please install them to use Watson embedding."
|
||||||
|
) from e
|
||||||
|
|
||||||
|
class WatsonEmbeddingFunction(EmbeddingFunction):
|
||||||
|
def __call__(self, input: Documents) -> Embeddings:
|
||||||
|
if isinstance(input, str):
|
||||||
|
input = [input]
|
||||||
|
|
||||||
|
embed_params = {
|
||||||
|
EmbedParams.TRUNCATE_INPUT_TOKENS: 3,
|
||||||
|
EmbedParams.RETURN_OPTIONS: {"input_text": True},
|
||||||
|
}
|
||||||
|
|
||||||
|
embedding = watson_models.Embeddings(
|
||||||
|
model_id=config.get("model"),
|
||||||
|
params=embed_params,
|
||||||
|
credentials=Credentials(
|
||||||
|
api_key=config.get("api_key"), url=config.get("api_url")
|
||||||
|
),
|
||||||
|
project_id=config.get("project_id"),
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
embeddings = embedding.embed_documents(input)
|
||||||
|
return cast(Embeddings, embeddings)
|
||||||
|
except Exception as e:
|
||||||
|
print("Error during Watson embedding:", e)
|
||||||
|
raise e
|
||||||
|
|
||||||
|
return WatsonEmbeddingFunction()
|
||||||
Reference in New Issue
Block a user