Compare commits

...

5 Commits

Author SHA1 Message Date
Brandon Hancock
5cd79b7345 Merge branch 'feat/ibm-memory' of https://github.com/joaomdmoura/crewAI into feat/ibm-memory 2024-11-01 13:02:38 -04:00
Brandon Hancock
377793af42 clean up for PR 2024-11-01 13:02:32 -04:00
Brandon Hancock (bhancock_ai)
56cea8fb93 Merge branch 'main' into feat/ibm-memory 2024-11-01 12:07:45 -04:00
Brandon Hancock
9933d8f880 Update docs as well. 2024-11-01 10:43:47 -04:00
Brandon Hancock
12b0cf6100 Everything looks like its working. Waiting for lorenze review. 2024-11-01 10:36:23 -04:00
3 changed files with 81 additions and 10 deletions

View File

@@ -254,6 +254,31 @@ my_crew = Crew(
)
```
### Using Watson embeddings
```python Code
from crewai import Crew, Agent, Task, Process
# Note: Ensure you have installed and imported `ibm_watsonx_ai` for Watson embeddings to work.
my_crew = Crew(
agents=[...],
tasks=[...],
process=Process.sequential,
memory=True,
verbose=True,
embedder={
"provider": "watson",
"config": {
"model": "<model_name>",
"api_url": "<api_url>",
"api_key": "<YOUR_API_KEY>",
"project_id": "<YOUR_PROJECT_ID>",
}
}
)
```
### Resetting Memory
```shell

View File

@@ -34,6 +34,7 @@ class ContextualMemory:
formatted_results = "\n".join(
[f"- {result['context']}" for result in stm_results]
)
print("formatted_results stm", formatted_results)
return f"Recent Insights:\n{formatted_results}" if stm_results else ""
def _fetch_ltm_context(self, task) -> Optional[str]:
@@ -53,6 +54,8 @@ class ContextualMemory:
formatted_results = list(dict.fromkeys(formatted_results))
formatted_results = "\n".join([f"- {result}" for result in formatted_results]) # type: ignore # Incompatible types in assignment (expression has type "str", variable has type "list[str]")
print("formatted_results ltm", formatted_results)
return f"Historical Data:\n{formatted_results}" if ltm_results else ""
def _fetch_entity_context(self, query) -> str:
@@ -64,4 +67,5 @@ class ContextualMemory:
formatted_results = "\n".join(
[f"- {result['context']}" for result in em_results] # type: ignore # Invalid index type "str" for "str"; expected type "SupportsIndex | slice"
)
print("formatted_results em", formatted_results)
return f"Entities:\n{formatted_results}" if em_results else ""

View File

@@ -4,13 +4,13 @@ import logging
import os
import shutil
import uuid
from typing import Any, Dict, List, Optional
from crewai.memory.storage.base_rag_storage import BaseRAGStorage
from crewai.utilities.paths import db_storage_path
from typing import Any, Dict, List, Optional, cast
from chromadb import Documents, EmbeddingFunction, Embeddings
from chromadb.api import ClientAPI
from chromadb.api.types import validate_embedding_function
from chromadb import Documents, EmbeddingFunction, Embeddings
from typing import cast
from crewai.memory.storage.base_rag_storage import BaseRAGStorage
from crewai.utilities.paths import db_storage_path
@contextlib.contextmanager
@@ -21,9 +21,11 @@ def suppress_logging(
logger = logging.getLogger(logger_name)
original_level = logger.getEffectiveLevel()
logger.setLevel(level)
with contextlib.redirect_stdout(io.StringIO()), contextlib.redirect_stderr(
io.StringIO()
), contextlib.suppress(UserWarning):
with (
contextlib.redirect_stdout(io.StringIO()),
contextlib.redirect_stderr(io.StringIO()),
contextlib.suppress(UserWarning),
):
yield
logger.setLevel(original_level)
@@ -113,12 +115,52 @@ class RAGStorage(BaseRAGStorage):
self.embedder_config = embedding_functions.HuggingFaceEmbeddingServer(
url=config.get("api_url"),
)
elif provider == "watson":
try:
import ibm_watsonx_ai.foundation_models as watson_models
from ibm_watsonx_ai import Credentials
from ibm_watsonx_ai.metanames import (
EmbedTextParamsMetaNames as EmbedParams,
)
except ImportError as e:
raise ImportError(
"IBM Watson dependencies are not installed. Please install them to use Watson embedding."
) from e
class WatsonEmbeddingFunction(EmbeddingFunction):
def __call__(self, input: Documents) -> Embeddings:
if isinstance(input, str):
input = [input]
embed_params = {
EmbedParams.TRUNCATE_INPUT_TOKENS: 3,
EmbedParams.RETURN_OPTIONS: {"input_text": True},
}
embedding = watson_models.Embeddings(
model_id=config.get("model"),
params=embed_params,
credentials=Credentials(
api_key=config.get("api_key"), url=config.get("api_url")
),
project_id=config.get("project_id"),
)
try:
embeddings = embedding.embed_documents(input)
return cast(Embeddings, embeddings)
except Exception as e:
print("Error during Watson embedding:", e)
raise e
self.embedder_config = WatsonEmbeddingFunction()
else:
raise Exception(
f"Unsupported embedding provider: {provider}, supported providers: [openai, azure, ollama, vertexai, google, cohere, huggingface]"
f"Unsupported embedding provider: {provider}, supported providers: [openai, azure, ollama, vertexai, google, cohere, huggingface, watson]"
)
else:
validate_embedding_function(self.embedder_config) # type: ignore # used for validating embedder_config if defined a embedding function/class
validate_embedding_function(self.embedder_config)
self.embedder_config = self.embedder_config
def _initialize_app(self):