Compare commits

...

5 Commits

Author SHA1 Message Date
Brandon Hancock
5cd79b7345 Merge branch 'feat/ibm-memory' of https://github.com/joaomdmoura/crewAI into feat/ibm-memory 2024-11-01 13:02:38 -04:00
Brandon Hancock
377793af42 clean up for PR 2024-11-01 13:02:32 -04:00
Brandon Hancock (bhancock_ai)
56cea8fb93 Merge branch 'main' into feat/ibm-memory 2024-11-01 12:07:45 -04:00
Brandon Hancock
9933d8f880 Update docs as well. 2024-11-01 10:43:47 -04:00
Brandon Hancock
12b0cf6100 Everything looks like its working. Waiting for lorenze review. 2024-11-01 10:36:23 -04:00
3 changed files with 81 additions and 10 deletions

View File

@@ -254,6 +254,31 @@ my_crew = Crew(
) )
``` ```
### Using Watson embeddings
```python Code
from crewai import Crew, Agent, Task, Process
# Note: Ensure you have installed and imported `ibm_watsonx_ai` for Watson embeddings to work.
my_crew = Crew(
agents=[...],
tasks=[...],
process=Process.sequential,
memory=True,
verbose=True,
embedder={
"provider": "watson",
"config": {
"model": "<model_name>",
"api_url": "<api_url>",
"api_key": "<YOUR_API_KEY>",
"project_id": "<YOUR_PROJECT_ID>",
}
}
)
```
### Resetting Memory ### Resetting Memory
```shell ```shell

View File

@@ -34,6 +34,7 @@ class ContextualMemory:
formatted_results = "\n".join( formatted_results = "\n".join(
[f"- {result['context']}" for result in stm_results] [f"- {result['context']}" for result in stm_results]
) )
print("formatted_results stm", formatted_results)
return f"Recent Insights:\n{formatted_results}" if stm_results else "" return f"Recent Insights:\n{formatted_results}" if stm_results else ""
def _fetch_ltm_context(self, task) -> Optional[str]: def _fetch_ltm_context(self, task) -> Optional[str]:
@@ -53,6 +54,8 @@ class ContextualMemory:
formatted_results = list(dict.fromkeys(formatted_results)) formatted_results = list(dict.fromkeys(formatted_results))
formatted_results = "\n".join([f"- {result}" for result in formatted_results]) # type: ignore # Incompatible types in assignment (expression has type "str", variable has type "list[str]") formatted_results = "\n".join([f"- {result}" for result in formatted_results]) # type: ignore # Incompatible types in assignment (expression has type "str", variable has type "list[str]")
print("formatted_results ltm", formatted_results)
return f"Historical Data:\n{formatted_results}" if ltm_results else "" return f"Historical Data:\n{formatted_results}" if ltm_results else ""
def _fetch_entity_context(self, query) -> str: def _fetch_entity_context(self, query) -> str:
@@ -64,4 +67,5 @@ class ContextualMemory:
formatted_results = "\n".join( formatted_results = "\n".join(
[f"- {result['context']}" for result in em_results] # type: ignore # Invalid index type "str" for "str"; expected type "SupportsIndex | slice" [f"- {result['context']}" for result in em_results] # type: ignore # Invalid index type "str" for "str"; expected type "SupportsIndex | slice"
) )
print("formatted_results em", formatted_results)
return f"Entities:\n{formatted_results}" if em_results else "" return f"Entities:\n{formatted_results}" if em_results else ""

View File

@@ -4,13 +4,13 @@ import logging
import os import os
import shutil import shutil
import uuid import uuid
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional, cast
from crewai.memory.storage.base_rag_storage import BaseRAGStorage
from crewai.utilities.paths import db_storage_path from chromadb import Documents, EmbeddingFunction, Embeddings
from chromadb.api import ClientAPI from chromadb.api import ClientAPI
from chromadb.api.types import validate_embedding_function from chromadb.api.types import validate_embedding_function
from chromadb import Documents, EmbeddingFunction, Embeddings from crewai.memory.storage.base_rag_storage import BaseRAGStorage
from typing import cast from crewai.utilities.paths import db_storage_path
@contextlib.contextmanager @contextlib.contextmanager
@@ -21,9 +21,11 @@ def suppress_logging(
logger = logging.getLogger(logger_name) logger = logging.getLogger(logger_name)
original_level = logger.getEffectiveLevel() original_level = logger.getEffectiveLevel()
logger.setLevel(level) logger.setLevel(level)
with contextlib.redirect_stdout(io.StringIO()), contextlib.redirect_stderr( with (
io.StringIO() contextlib.redirect_stdout(io.StringIO()),
), contextlib.suppress(UserWarning): contextlib.redirect_stderr(io.StringIO()),
contextlib.suppress(UserWarning),
):
yield yield
logger.setLevel(original_level) logger.setLevel(original_level)
@@ -113,12 +115,52 @@ class RAGStorage(BaseRAGStorage):
self.embedder_config = embedding_functions.HuggingFaceEmbeddingServer( self.embedder_config = embedding_functions.HuggingFaceEmbeddingServer(
url=config.get("api_url"), url=config.get("api_url"),
) )
elif provider == "watson":
try:
import ibm_watsonx_ai.foundation_models as watson_models
from ibm_watsonx_ai import Credentials
from ibm_watsonx_ai.metanames import (
EmbedTextParamsMetaNames as EmbedParams,
)
except ImportError as e:
raise ImportError(
"IBM Watson dependencies are not installed. Please install them to use Watson embedding."
) from e
class WatsonEmbeddingFunction(EmbeddingFunction):
def __call__(self, input: Documents) -> Embeddings:
if isinstance(input, str):
input = [input]
embed_params = {
EmbedParams.TRUNCATE_INPUT_TOKENS: 3,
EmbedParams.RETURN_OPTIONS: {"input_text": True},
}
embedding = watson_models.Embeddings(
model_id=config.get("model"),
params=embed_params,
credentials=Credentials(
api_key=config.get("api_key"), url=config.get("api_url")
),
project_id=config.get("project_id"),
)
try:
embeddings = embedding.embed_documents(input)
return cast(Embeddings, embeddings)
except Exception as e:
print("Error during Watson embedding:", e)
raise e
self.embedder_config = WatsonEmbeddingFunction()
else: else:
raise Exception( raise Exception(
f"Unsupported embedding provider: {provider}, supported providers: [openai, azure, ollama, vertexai, google, cohere, huggingface]" f"Unsupported embedding provider: {provider}, supported providers: [openai, azure, ollama, vertexai, google, cohere, huggingface, watson]"
) )
else: else:
validate_embedding_function(self.embedder_config) # type: ignore # used for validating embedder_config if defined a embedding function/class validate_embedding_function(self.embedder_config)
self.embedder_config = self.embedder_config self.embedder_config = self.embedder_config
def _initialize_app(self): def _initialize_app(self):