fixes
@@ -16,7 +16,7 @@ class EntityMemory(Memory):
             if storage
             else RAGStorage(
                 type="entities",
-                allow_reset=False,
+                allow_reset=True,
                 embedder_config=embedder_config,
                 crew=crew,
             )
@@ -9,6 +9,8 @@ from crewai.memory.storage.base_rag_storage import BaseRAGStorage
 from crewai.utilities.paths import db_storage_path
 from chromadb.api import ClientAPI
 from chromadb.api.types import validate_embedding_function
+from chromadb import Documents, EmbeddingFunction, Embeddings
+from typing import cast
 
 
 @contextlib.contextmanager
@@ -47,64 +49,64 @@ class RAGStorage(BaseRAGStorage):
         self._initialize_app()
 
     def _set_embedder_config(self):
+        import chromadb.utils.embedding_functions as embedding_functions
+
         if self.embedder_config is None:
             self.embedder_config = self._create_default_embedding_function()
 
         if isinstance(self.embedder_config, dict):
             provider = self.embedder_config.get("provider")
             config = self.embedder_config.get("config", {})
             model_name = config.get("model")
             if provider == "openai":
-                import chromadb.utils.embedding_functions as embedding_functions
-
                 self.embedder_config = embedding_functions.OpenAIEmbeddingFunction(
                     api_key=config.get("api_key") or os.getenv("OPENAI_API_KEY"),
                     model_name=model_name,
                 )
             elif provider == "azure":
-                from chromadb.utils.embedding_functions.openai_embedding_function import (
-                    OpenAIEmbeddingFunction,
-                )
-
-                self.embedder_config = OpenAIEmbeddingFunction(
+                self.embedder_config = embedding_functions.OpenAIEmbeddingFunction(
                     api_key=config.get("api_key"),
                     api_base=config.get("api_base"),
-                    api_type=config.get("api_type"),
+                    api_type=config.get("api_type", "azure"),
                     api_version=config.get("api_version"),
                     model_name=model_name,
                 )
             elif provider == "ollama":
-                from chromadb.utils.embedding_functions.ollama_embedding_function import (
-                    OllamaEmbeddingFunction,
-                )
-
-                self.embedder_config = OllamaEmbeddingFunction(
-                    model_name=config.get("model"),
-                    url=config.get("url") or "http://localhost:11434",
-                )
+                print("using this ollama")
+                from openai import OpenAI
+
+                class OllamaEmbeddingFunction(EmbeddingFunction):
+                    def __call__(self, input: Documents) -> Embeddings:
+                        client = OpenAI(
+                            base_url="http://localhost:11434/v1",
+                            api_key=config.get("api_key", "ollama"),
+                        )
+                        try:
+                            response = client.embeddings.create(
+                                input=input, model=model_name
+                            )
+                            embeddings = [item.embedding for item in response.data]
+                            return cast(Embeddings, embeddings)
+                        except Exception as e:
+                            raise e
+
+                self.embedder_config = OllamaEmbeddingFunction()
             elif provider == "vertexai":
-                from chromadb.utils.embedding_functions.google_embedding_function import (
-                    GoogleVertexEmbeddingFunction,
-                )
-
-                self.embedder_config = GoogleVertexEmbeddingFunction(
-                    model_name=model_name,
-                    api_key=config.get("api_key"),
+                self.embedder_config = (
+                    embedding_functions.GoogleVertexEmbeddingFunction(
+                        model_name=model_name,
+                        api_key=config.get("api_key"),
+                    )
                 )
             elif provider == "google":
-                from chromadb.utils.embedding_functions.google_embedding_function import (
-                    GoogleGenerativeAiEmbeddingFunction,
-                )
-
-                self.embedder_config = GoogleGenerativeAiEmbeddingFunction(
-                    model_name=model_name,
-                    api_key=config.get("api_key"),
+                self.embedder_config = (
+                    embedding_functions.GoogleGenerativeAiEmbeddingFunction(
+                        model_name=model_name,
+                        api_key=config.get("api_key"),
+                    )
                 )
             elif provider == "cohere":
-                from chromadb.utils.embedding_functions.cohere_embedding_function import (
-                    CohereEmbeddingFunction,
-                )
-
-                self.embedder_config = CohereEmbeddingFunction(
+                self.embedder_config = embedding_functions.CohereEmbeddingFunction(
                     model_name=model_name,
                     api_key=config.get("api_key"),
                 )
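
For context, below is a minimal sketch (not part of the commit) of the embedder_config shape that the rewritten _set_embedder_config branches read: a "provider" key plus a "config" dict with "model" and an optional "api_key". The RAGStorage call mirrors the kwargs visible in the EntityMemory hunk above; the import path, the example model name, and crew=None are assumptions for illustration only.

# Hypothetical usage sketch -- not part of the commit.
# The "ollama" branch above reads config["model"] and config.get("api_key", "ollama").
from crewai.memory.storage.rag_storage import RAGStorage  # module path assumed from the import hunk

embedder_config = {
    "provider": "ollama",
    "config": {
        "model": "nomic-embed-text",  # example model name served by a local Ollama instance
        "api_key": "ollama",          # placeholder; the OpenAI-compatible endpoint does not check it
    },
}

storage = RAGStorage(
    type="entities",
    allow_reset=True,
    embedder_config=embedder_config,
    crew=None,  # assumption: a standalone storage instance without a crew
)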