* initial knowledge

* WIP

* Adding core knowledge sources

* Improve types and better support for file paths

* added additional sources

* fix linting

* update yaml to include optional deps

* adding in lorenze feedback

* ensure embeddings are persisted

* improvements all around Knowledge class

* return this

* properly reset memory

* properly reset memory+knowledge

* consolodation and improvements

* linted

* cleanup rm unused embedder

* fix test

* fix duplicate

* generating cassettes for knowledge test

* updated default embedder

* None embedder to use default on pipeline cloning

* improvements

* fixed text_file_knowledge

* mypysrc fixes

* type check fixes

* added extra cassette

* just mocks

* linted

* mock knowledge query to not spin up db

* linted

* verbose run

* put a flag

* fix

* adding docs

* better docs

* improvements from review

* more docs

* linted

* rm print

* more fixes

* clearer docs

* added docstrings and type hints for cli

---------

Co-authored-by: João Moura <joaomdmoura@gmail.com>
Co-authored-by: Lorenze Jay <lorenzejaytech@gmail.com>
This commit is contained in:
Brandon Hancock (bhancock_ai)
2024-11-20 18:40:08 -05:00
committed by GitHub
parent fde1ee45f9
commit 14a36d3f5e
37 changed files with 2302 additions and 266 deletions

View File

@@ -4,13 +4,12 @@ import logging
import os
import shutil
import uuid
from typing import Any, Dict, List, Optional, cast
from chromadb import Documents, EmbeddingFunction, Embeddings
from typing import Any, Dict, List, Optional
from chromadb.api import ClientAPI
from chromadb.api.types import validate_embedding_function
from crewai.memory.storage.base_rag_storage import BaseRAGStorage
from crewai.utilities.paths import db_storage_path
from crewai.utilities import EmbeddingConfigurator
@contextlib.contextmanager
@@ -51,133 +50,8 @@ class RAGStorage(BaseRAGStorage):
self._initialize_app()
def _set_embedder_config(self):
if self.embedder_config is None:
self.embedder_config = self._create_default_embedding_function()
if isinstance(self.embedder_config, dict):
provider = self.embedder_config.get("provider")
config = self.embedder_config.get("config", {})
model_name = config.get("model")
if provider == "openai":
from chromadb.utils.embedding_functions.openai_embedding_function import (
OpenAIEmbeddingFunction,
)
self.embedder_config = OpenAIEmbeddingFunction(
api_key=config.get("api_key") or os.getenv("OPENAI_API_KEY"),
model_name=model_name,
)
elif provider == "azure":
from chromadb.utils.embedding_functions.openai_embedding_function import (
OpenAIEmbeddingFunction,
)
self.embedder_config = OpenAIEmbeddingFunction(
api_key=config.get("api_key"),
api_base=config.get("api_base"),
api_type=config.get("api_type", "azure"),
api_version=config.get("api_version"),
model_name=model_name,
)
elif provider == "ollama":
from chromadb.utils.embedding_functions.ollama_embedding_function import (
OllamaEmbeddingFunction,
)
self.embedder_config = OllamaEmbeddingFunction(
url=config.get("url", "http://localhost:11434/api/embeddings"),
model_name=model_name,
)
elif provider == "vertexai":
from chromadb.utils.embedding_functions.google_embedding_function import (
GoogleVertexEmbeddingFunction,
)
self.embedder_config = GoogleVertexEmbeddingFunction(
model_name=model_name,
api_key=config.get("api_key"),
)
elif provider == "google":
from chromadb.utils.embedding_functions.google_embedding_function import (
GoogleGenerativeAiEmbeddingFunction,
)
self.embedder_config = GoogleGenerativeAiEmbeddingFunction(
model_name=model_name,
api_key=config.get("api_key"),
)
elif provider == "cohere":
from chromadb.utils.embedding_functions.cohere_embedding_function import (
CohereEmbeddingFunction,
)
self.embedder_config = CohereEmbeddingFunction(
model_name=model_name,
api_key=config.get("api_key"),
)
elif provider == "bedrock":
from chromadb.utils.embedding_functions.amazon_bedrock_embedding_function import (
AmazonBedrockEmbeddingFunction,
)
self.embedder_config = AmazonBedrockEmbeddingFunction(
session=config.get("session"),
)
elif provider == "huggingface":
from chromadb.utils.embedding_functions.huggingface_embedding_function import (
HuggingFaceEmbeddingServer,
)
self.embedder_config = HuggingFaceEmbeddingServer(
url=config.get("api_url"),
)
elif provider == "watson":
try:
import ibm_watsonx_ai.foundation_models as watson_models
from ibm_watsonx_ai import Credentials
from ibm_watsonx_ai.metanames import (
EmbedTextParamsMetaNames as EmbedParams,
)
except ImportError as e:
raise ImportError(
"IBM Watson dependencies are not installed. Please install them to use Watson embedding."
) from e
class WatsonEmbeddingFunction(EmbeddingFunction):
def __call__(self, input: Documents) -> Embeddings:
if isinstance(input, str):
input = [input]
embed_params = {
EmbedParams.TRUNCATE_INPUT_TOKENS: 3,
EmbedParams.RETURN_OPTIONS: {"input_text": True},
}
embedding = watson_models.Embeddings(
model_id=config.get("model"),
params=embed_params,
credentials=Credentials(
api_key=config.get("api_key"), url=config.get("api_url")
),
project_id=config.get("project_id"),
)
try:
embeddings = embedding.embed_documents(input)
return cast(Embeddings, embeddings)
except Exception as e:
print("Error during Watson embedding:", e)
raise e
self.embedder_config = WatsonEmbeddingFunction()
else:
raise Exception(
f"Unsupported embedding provider: {provider}, supported providers: [openai, azure, ollama, vertexai, google, cohere, huggingface, watson]"
)
else:
validate_embedding_function(self.embedder_config)
self.embedder_config = self.embedder_config
configurator = EmbeddingConfigurator()
self.embedder_config = configurator.configure_embedder(self.embedder_config)
def _initialize_app(self):
import chromadb