From 3878daffd68d14ee0379d11c19bef63c4dc02edc Mon Sep 17 00:00:00 2001 From: "Brandon Hancock (bhancock_ai)" <109994880+bhancockio@users.noreply.github.com> Date: Fri, 1 Nov 2024 16:42:46 -0400 Subject: [PATCH] Feat/ibm memory (#1549) * Everything looks like its working. Waiting for lorenze review. * Update docs as well. * clean up for PR --- docs/concepts/memory.mdx | 25 ++++++++ .../memory/contextual/contextual_memory.py | 4 ++ src/crewai/memory/storage/rag_storage.py | 62 ++++++++++++++++--- 3 files changed, 81 insertions(+), 10 deletions(-) diff --git a/docs/concepts/memory.mdx b/docs/concepts/memory.mdx index 735ed861e..bda9f3401 100644 --- a/docs/concepts/memory.mdx +++ b/docs/concepts/memory.mdx @@ -254,6 +254,31 @@ my_crew = Crew( ) ``` +### Using Watson embeddings + +```python Code +from crewai import Crew, Agent, Task, Process + +# Note: Ensure you have installed and imported `ibm_watsonx_ai` for Watson embeddings to work. + +my_crew = Crew( + agents=[...], + tasks=[...], + process=Process.sequential, + memory=True, + verbose=True, + embedder={ + "provider": "watson", + "config": { + "model": "", + "api_url": "", + "api_key": "", + "project_id": "", + } + } +) +``` + ### Resetting Memory ```shell diff --git a/src/crewai/memory/contextual/contextual_memory.py b/src/crewai/memory/contextual/contextual_memory.py index 5d91cf47d..3d3a9c6c1 100644 --- a/src/crewai/memory/contextual/contextual_memory.py +++ b/src/crewai/memory/contextual/contextual_memory.py @@ -34,6 +34,7 @@ class ContextualMemory: formatted_results = "\n".join( [f"- {result['context']}" for result in stm_results] ) + print("formatted_results stm", formatted_results) return f"Recent Insights:\n{formatted_results}" if stm_results else "" def _fetch_ltm_context(self, task) -> Optional[str]: @@ -53,6 +54,8 @@ class ContextualMemory: formatted_results = list(dict.fromkeys(formatted_results)) formatted_results = "\n".join([f"- {result}" for result in formatted_results]) # type: ignore # 
Incompatible types in assignment (expression has type "str", variable has type "list[str]") + print("formatted_results ltm", formatted_results) + return f"Historical Data:\n{formatted_results}" if ltm_results else "" def _fetch_entity_context(self, query) -> str: @@ -64,4 +67,5 @@ class ContextualMemory: formatted_results = "\n".join( [f"- {result['context']}" for result in em_results] # type: ignore # Invalid index type "str" for "str"; expected type "SupportsIndex | slice" ) + print("formatted_results em", formatted_results) return f"Entities:\n{formatted_results}" if em_results else "" diff --git a/src/crewai/memory/storage/rag_storage.py b/src/crewai/memory/storage/rag_storage.py index db98c0036..d0f1cfc64 100644 --- a/src/crewai/memory/storage/rag_storage.py +++ b/src/crewai/memory/storage/rag_storage.py @@ -4,13 +4,13 @@ import logging import os import shutil import uuid -from typing import Any, Dict, List, Optional -from crewai.memory.storage.base_rag_storage import BaseRAGStorage -from crewai.utilities.paths import db_storage_path +from typing import Any, Dict, List, Optional, cast + +from chromadb import Documents, EmbeddingFunction, Embeddings from chromadb.api import ClientAPI from chromadb.api.types import validate_embedding_function -from chromadb import Documents, EmbeddingFunction, Embeddings -from typing import cast +from crewai.memory.storage.base_rag_storage import BaseRAGStorage +from crewai.utilities.paths import db_storage_path @contextlib.contextmanager @@ -21,9 +21,11 @@ def suppress_logging( logger = logging.getLogger(logger_name) original_level = logger.getEffectiveLevel() logger.setLevel(level) - with contextlib.redirect_stdout(io.StringIO()), contextlib.redirect_stderr( - io.StringIO() - ), contextlib.suppress(UserWarning): + with ( + contextlib.redirect_stdout(io.StringIO()), + contextlib.redirect_stderr(io.StringIO()), + contextlib.suppress(UserWarning), + ): yield logger.setLevel(original_level) @@ -113,12 +115,52 @@ class 
RAGStorage(BaseRAGStorage): self.embedder_config = embedding_functions.HuggingFaceEmbeddingServer( url=config.get("api_url"), ) + elif provider == "watson": + try: + import ibm_watsonx_ai.foundation_models as watson_models + from ibm_watsonx_ai import Credentials + from ibm_watsonx_ai.metanames import ( + EmbedTextParamsMetaNames as EmbedParams, + ) + except ImportError as e: + raise ImportError( + "IBM Watson dependencies are not installed. Please install them to use Watson embedding." + ) from e + + class WatsonEmbeddingFunction(EmbeddingFunction): + def __call__(self, input: Documents) -> Embeddings: + if isinstance(input, str): + input = [input] + + embed_params = { + EmbedParams.TRUNCATE_INPUT_TOKENS: 512, + EmbedParams.RETURN_OPTIONS: {"input_text": True}, + } + + embedding = watson_models.Embeddings( + model_id=config.get("model"), + params=embed_params, + credentials=Credentials( + api_key=config.get("api_key"), url=config.get("api_url") + ), + project_id=config.get("project_id"), + ) + + try: + embeddings = embedding.embed_documents(input) + return cast(Embeddings, embeddings) + + except Exception as e: + # re-raise unchanged so callers see the original failure + raise + + self.embedder_config = WatsonEmbeddingFunction() else: raise Exception( - f"Unsupported embedding provider: {provider}, supported providers: [openai, azure, ollama, vertexai, google, cohere, huggingface]" + f"Unsupported embedding provider: {provider}, supported providers: [openai, azure, ollama, vertexai, google, cohere, huggingface, watson]" ) else: - validate_embedding_function(self.embedder_config) # type: ignore # used for validating embedder_config if defined a embedding function/class + validate_embedding_function(self.embedder_config) self.embedder_config = self.embedder_config def _initialize_app(self):