mirror of
https://github.com/crewAIInc/crewAI.git
synced 2025-12-16 20:38:29 +00:00
Compare commits
5 Commits
devin/1750
...
feat/ibm-m
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
5cd79b7345 | ||
|
|
377793af42 | ||
|
|
56cea8fb93 | ||
|
|
9933d8f880 | ||
|
|
12b0cf6100 |
@@ -254,6 +254,31 @@ my_crew = Crew(
|
||||
)
|
||||
```
|
||||
|
||||
### Using Watson embeddings
|
||||
|
||||
```python Code
|
||||
from crewai import Crew, Agent, Task, Process
|
||||
|
||||
# Note: Ensure you have installed and imported `ibm_watsonx_ai` for Watson embeddings to work.
|
||||
|
||||
my_crew = Crew(
|
||||
agents=[...],
|
||||
tasks=[...],
|
||||
process=Process.sequential,
|
||||
memory=True,
|
||||
verbose=True,
|
||||
embedder={
|
||||
"provider": "watson",
|
||||
"config": {
|
||||
"model": "<model_name>",
|
||||
"api_url": "<api_url>",
|
||||
"api_key": "<YOUR_API_KEY>",
|
||||
"project_id": "<YOUR_PROJECT_ID>",
|
||||
}
|
||||
}
|
||||
)
|
||||
```
|
||||
|
||||
### Resetting Memory
|
||||
|
||||
```shell
|
||||
|
||||
@@ -34,6 +34,7 @@ class ContextualMemory:
|
||||
formatted_results = "\n".join(
|
||||
[f"- {result['context']}" for result in stm_results]
|
||||
)
|
||||
print("formatted_results stm", formatted_results)
|
||||
return f"Recent Insights:\n{formatted_results}" if stm_results else ""
|
||||
|
||||
def _fetch_ltm_context(self, task) -> Optional[str]:
|
||||
@@ -53,6 +54,8 @@ class ContextualMemory:
|
||||
formatted_results = list(dict.fromkeys(formatted_results))
|
||||
formatted_results = "\n".join([f"- {result}" for result in formatted_results]) # type: ignore # Incompatible types in assignment (expression has type "str", variable has type "list[str]")
|
||||
|
||||
print("formatted_results ltm", formatted_results)
|
||||
|
||||
return f"Historical Data:\n{formatted_results}" if ltm_results else ""
|
||||
|
||||
def _fetch_entity_context(self, query) -> str:
|
||||
@@ -64,4 +67,5 @@ class ContextualMemory:
|
||||
formatted_results = "\n".join(
|
||||
[f"- {result['context']}" for result in em_results] # type: ignore # Invalid index type "str" for "str"; expected type "SupportsIndex | slice"
|
||||
)
|
||||
print("formatted_results em", formatted_results)
|
||||
return f"Entities:\n{formatted_results}" if em_results else ""
|
||||
|
||||
@@ -4,13 +4,13 @@ import logging
|
||||
import os
|
||||
import shutil
|
||||
import uuid
|
||||
from typing import Any, Dict, List, Optional
|
||||
from crewai.memory.storage.base_rag_storage import BaseRAGStorage
|
||||
from crewai.utilities.paths import db_storage_path
|
||||
from typing import Any, Dict, List, Optional, cast
|
||||
|
||||
from chromadb import Documents, EmbeddingFunction, Embeddings
|
||||
from chromadb.api import ClientAPI
|
||||
from chromadb.api.types import validate_embedding_function
|
||||
from chromadb import Documents, EmbeddingFunction, Embeddings
|
||||
from typing import cast
|
||||
from crewai.memory.storage.base_rag_storage import BaseRAGStorage
|
||||
from crewai.utilities.paths import db_storage_path
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
@@ -21,9 +21,11 @@ def suppress_logging(
|
||||
logger = logging.getLogger(logger_name)
|
||||
original_level = logger.getEffectiveLevel()
|
||||
logger.setLevel(level)
|
||||
with contextlib.redirect_stdout(io.StringIO()), contextlib.redirect_stderr(
|
||||
io.StringIO()
|
||||
), contextlib.suppress(UserWarning):
|
||||
with (
|
||||
contextlib.redirect_stdout(io.StringIO()),
|
||||
contextlib.redirect_stderr(io.StringIO()),
|
||||
contextlib.suppress(UserWarning),
|
||||
):
|
||||
yield
|
||||
logger.setLevel(original_level)
|
||||
|
||||
@@ -113,12 +115,52 @@ class RAGStorage(BaseRAGStorage):
|
||||
self.embedder_config = embedding_functions.HuggingFaceEmbeddingServer(
|
||||
url=config.get("api_url"),
|
||||
)
|
||||
elif provider == "watson":
|
||||
try:
|
||||
import ibm_watsonx_ai.foundation_models as watson_models
|
||||
from ibm_watsonx_ai import Credentials
|
||||
from ibm_watsonx_ai.metanames import (
|
||||
EmbedTextParamsMetaNames as EmbedParams,
|
||||
)
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"IBM Watson dependencies are not installed. Please install them to use Watson embedding."
|
||||
) from e
|
||||
|
||||
class WatsonEmbeddingFunction(EmbeddingFunction):
|
||||
def __call__(self, input: Documents) -> Embeddings:
|
||||
if isinstance(input, str):
|
||||
input = [input]
|
||||
|
||||
embed_params = {
|
||||
EmbedParams.TRUNCATE_INPUT_TOKENS: 3,
|
||||
EmbedParams.RETURN_OPTIONS: {"input_text": True},
|
||||
}
|
||||
|
||||
embedding = watson_models.Embeddings(
|
||||
model_id=config.get("model"),
|
||||
params=embed_params,
|
||||
credentials=Credentials(
|
||||
api_key=config.get("api_key"), url=config.get("api_url")
|
||||
),
|
||||
project_id=config.get("project_id"),
|
||||
)
|
||||
|
||||
try:
|
||||
embeddings = embedding.embed_documents(input)
|
||||
return cast(Embeddings, embeddings)
|
||||
|
||||
except Exception as e:
|
||||
print("Error during Watson embedding:", e)
|
||||
raise e
|
||||
|
||||
self.embedder_config = WatsonEmbeddingFunction()
|
||||
else:
|
||||
raise Exception(
|
||||
f"Unsupported embedding provider: {provider}, supported providers: [openai, azure, ollama, vertexai, google, cohere, huggingface]"
|
||||
f"Unsupported embedding provider: {provider}, supported providers: [openai, azure, ollama, vertexai, google, cohere, huggingface, watson]"
|
||||
)
|
||||
else:
|
||||
validate_embedding_function(self.embedder_config) # type: ignore # used for validating embedder_config if defined a embedding function/class
|
||||
validate_embedding_function(self.embedder_config)
|
||||
self.embedder_config = self.embedder_config
|
||||
|
||||
def _initialize_app(self):
|
||||
|
||||
Reference in New Issue
Block a user