From 377793af423465e8a02a0d208c3c5744cd644254 Mon Sep 17 00:00:00 2001 From: Brandon Hancock Date: Fri, 1 Nov 2024 13:02:32 -0400 Subject: [PATCH] clean up for PR --- src/crewai/memory/contextual/contextual_memory.py | 4 ++++ src/crewai/memory/storage/rag_storage.py | 8 ++------ 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/crewai/memory/contextual/contextual_memory.py b/src/crewai/memory/contextual/contextual_memory.py index 5d91cf47d..3d3a9c6c1 100644 --- a/src/crewai/memory/contextual/contextual_memory.py +++ b/src/crewai/memory/contextual/contextual_memory.py @@ -34,6 +34,7 @@ class ContextualMemory: formatted_results = "\n".join( [f"- {result['context']}" for result in stm_results] ) + print("formatted_results stm", formatted_results) return f"Recent Insights:\n{formatted_results}" if stm_results else "" def _fetch_ltm_context(self, task) -> Optional[str]: @@ -53,6 +54,8 @@ class ContextualMemory: formatted_results = list(dict.fromkeys(formatted_results)) formatted_results = "\n".join([f"- {result}" for result in formatted_results]) # type: ignore # Incompatible types in assignment (expression has type "str", variable has type "list[str]") + print("formatted_results ltm", formatted_results) + return f"Historical Data:\n{formatted_results}" if ltm_results else "" def _fetch_entity_context(self, query) -> str: @@ -64,4 +67,5 @@ class ContextualMemory: formatted_results = "\n".join( [f"- {result['context']}" for result in em_results] # type: ignore # Invalid index type "str" for "str"; expected type "SupportsIndex | slice" ) + print("formatted_results em", formatted_results) return f"Entities:\n{formatted_results}" if em_results else "" diff --git a/src/crewai/memory/storage/rag_storage.py b/src/crewai/memory/storage/rag_storage.py index f07c7e4f7..d0f1cfc64 100644 --- a/src/crewai/memory/storage/rag_storage.py +++ b/src/crewai/memory/storage/rag_storage.py @@ -128,7 +128,7 @@ class RAGStorage(BaseRAGStorage): ) from e class WatsonEmbeddingFunction(EmbeddingFunction): - def __call__(self, input: Documents) -> watson_models.Embeddings: + def __call__(self, input: Documents) -> Embeddings: if isinstance(input, str): input = [input] @@ -147,12 +147,8 @@ class RAGStorage(BaseRAGStorage): ) try: - print("Embedding input:", input) embeddings = embedding.embed_documents(input) - print("Embedding output:", embeddings) - casted = cast(Embeddings, embeddings) - print("Casted:", casted) - return casted + return cast(Embeddings, embeddings) except Exception as e: print("Error during Watson embedding:", e)