mirror of https://github.com/crewAIInc/crewAI.git
synced 2025-12-17 12:58:31 +00:00

Compare commits: tm-add-tas...feat/ibm-m (5 commits)
| Author | SHA1 | Date |
|---|---|---|
| | 5cd79b7345 | |
| | 377793af42 | |
| | 56cea8fb93 | |
| | 9933d8f880 | |
| | 12b0cf6100 | |
@@ -254,6 +254,31 @@ my_crew = Crew(
 )
 ```
 
+### Using Watson embeddings
+
+```python Code
+from crewai import Crew, Agent, Task, Process
+
+# Note: Ensure you have installed and imported `ibm_watsonx_ai` for Watson embeddings to work.
+
+my_crew = Crew(
+    agents=[...],
+    tasks=[...],
+    process=Process.sequential,
+    memory=True,
+    verbose=True,
+    embedder={
+        "provider": "watson",
+        "config": {
+            "model": "<model_name>",
+            "api_url": "<api_url>",
+            "api_key": "<YOUR_API_KEY>",
+            "project_id": "<YOUR_PROJECT_ID>",
+        }
+    }
+)
+```
+
 ### Resetting Memory
 
 ```shell
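The docs snippet above assumes the optional `ibm_watsonx_ai` package is installed. A minimal pre-flight check, as a sketch (the PyPI distribution name `ibm-watsonx-ai` is an assumption here, inferred from the import path):

```python
# Failing fast at startup gives a clearer error than the lazy ImportError the
# storage layer raises when the "watson" provider is first used.
try:
    import ibm_watsonx_ai  # noqa: F401  # assumed: pip install ibm-watsonx-ai
except ImportError as e:
    raise SystemExit(
        "The 'watson' embedder needs ibm-watsonx-ai: pip install ibm-watsonx-ai"
    ) from e
```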
@@ -34,6 +34,7 @@ class ContextualMemory:
         formatted_results = "\n".join(
             [f"- {result['context']}" for result in stm_results]
         )
+        print("formatted_results stm", formatted_results)
         return f"Recent Insights:\n{formatted_results}" if stm_results else ""
 
     def _fetch_ltm_context(self, task) -> Optional[str]:
@@ -53,6 +54,8 @@ class ContextualMemory:
         formatted_results = list(dict.fromkeys(formatted_results))
         formatted_results = "\n".join([f"- {result}" for result in formatted_results])  # type: ignore # Incompatible types in assignment (expression has type "str", variable has type "list[str]")
+
+        print("formatted_results ltm", formatted_results)
 
         return f"Historical Data:\n{formatted_results}" if ltm_results else ""
 
     def _fetch_entity_context(self, query) -> str:
@@ -64,4 +67,5 @@ class ContextualMemory:
         formatted_results = "\n".join(
             [f"- {result['context']}" for result in em_results]  # type: ignore # Invalid index type "str" for "str"; expected type "SupportsIndex | slice"
         )
+        print("formatted_results em", formatted_results)
         return f"Entities:\n{formatted_results}" if em_results else ""
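The three `print()` calls added above trace the formatted short-term, long-term, and entity context as it is assembled. If the tracing is meant to outlive debugging, the stdlib logger keeps the same trace points silenceable; a minimal sketch, not an existing crewAI convention:

```python
import logging

logger = logging.getLogger(__name__)

# Same trace point as the added print() calls, but gated by log level so
# library users can silence it via standard logging configuration.
formatted_results = "- example insight"  # stand-in for the real joined results
logger.debug("formatted_results stm %s", formatted_results)
```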
@@ -4,13 +4,13 @@ import logging
 import os
 import shutil
 import uuid
-from typing import Any, Dict, List, Optional
-from crewai.memory.storage.base_rag_storage import BaseRAGStorage
-from crewai.utilities.paths import db_storage_path
+from typing import Any, Dict, List, Optional, cast
+
+from chromadb import Documents, EmbeddingFunction, Embeddings
 from chromadb.api import ClientAPI
 from chromadb.api.types import validate_embedding_function
-from chromadb import Documents, EmbeddingFunction, Embeddings
-from typing import cast
+from crewai.memory.storage.base_rag_storage import BaseRAGStorage
+from crewai.utilities.paths import db_storage_path
 
 
 @contextlib.contextmanager
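`cast` joins the consolidated `typing` import because the Watson branch further down returns `cast(Embeddings, embeddings)`. `typing.cast` performs no runtime conversion or check; it only tells the type checker what to assume:

```python
from typing import List, cast

raw: object = [[0.1, 0.2], [0.3, 0.4]]
# No conversion or validation happens here; the type checker simply starts
# treating the value as List[List[float]], mirroring the Embeddings cast below.
vectors = cast(List[List[float]], raw)
assert vectors is raw  # same object, unchanged at runtime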
@@ -21,9 +21,11 @@ def suppress_logging(
     logger = logging.getLogger(logger_name)
     original_level = logger.getEffectiveLevel()
     logger.setLevel(level)
-    with contextlib.redirect_stdout(io.StringIO()), contextlib.redirect_stderr(
-        io.StringIO()
-    ), contextlib.suppress(UserWarning):
+    with (
+        contextlib.redirect_stdout(io.StringIO()),
+        contextlib.redirect_stderr(io.StringIO()),
+        contextlib.suppress(UserWarning),
+    ):
         yield
     logger.setLevel(original_level)
 
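The rewrite replaces the awkwardly wrapped multi-manager `with` with the parenthesized form, officially supported from Python 3.10 (it also parses on 3.9's PEG parser), so the change quietly assumes a modern interpreter. A self-contained illustration:

```python
import contextlib
import io

# Parenthesized context managers: functionally identical to chaining them on
# one line, but each manager gets its own row and a trailing comma.
with (
    contextlib.redirect_stdout(io.StringIO()),
    contextlib.suppress(UserWarning),
):
    print("captured")  # redirected into the StringIO buffer, not the console
```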
@@ -113,12 +115,52 @@ class RAGStorage(BaseRAGStorage):
             self.embedder_config = embedding_functions.HuggingFaceEmbeddingServer(
                 url=config.get("api_url"),
             )
+        elif provider == "watson":
+            try:
+                import ibm_watsonx_ai.foundation_models as watson_models
+                from ibm_watsonx_ai import Credentials
+                from ibm_watsonx_ai.metanames import (
+                    EmbedTextParamsMetaNames as EmbedParams,
+                )
+            except ImportError as e:
+                raise ImportError(
+                    "IBM Watson dependencies are not installed. Please install them to use Watson embedding."
+                ) from e
+
+            class WatsonEmbeddingFunction(EmbeddingFunction):
+                def __call__(self, input: Documents) -> Embeddings:
+                    if isinstance(input, str):
+                        input = [input]
+
+                    embed_params = {
+                        EmbedParams.TRUNCATE_INPUT_TOKENS: 3,
+                        EmbedParams.RETURN_OPTIONS: {"input_text": True},
+                    }
+
+                    embedding = watson_models.Embeddings(
+                        model_id=config.get("model"),
+                        params=embed_params,
+                        credentials=Credentials(
+                            api_key=config.get("api_key"), url=config.get("api_url")
+                        ),
+                        project_id=config.get("project_id"),
+                    )
+
+                    try:
+                        embeddings = embedding.embed_documents(input)
+                        return cast(Embeddings, embeddings)
+
+                    except Exception as e:
+                        print("Error during Watson embedding:", e)
+                        raise e
+
+            self.embedder_config = WatsonEmbeddingFunction()
         else:
             raise Exception(
-                f"Unsupported embedding provider: {provider}, supported providers: [openai, azure, ollama, vertexai, google, cohere, huggingface]"
+                f"Unsupported embedding provider: {provider}, supported providers: [openai, azure, ollama, vertexai, google, cohere, huggingface, watson]"
             )
     else:
-        validate_embedding_function(self.embedder_config)  # type: ignore # used for validating embedder_config if defined a embedding function/class
+        validate_embedding_function(self.embedder_config)
         self.embedder_config = self.embedder_config
 
     def _initialize_app(self):
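Since `WatsonEmbeddingFunction` subclasses chromadb's `EmbeddingFunction` and accepts an `input: Documents` argument, it passes `validate_embedding_function` and can be exercised directly against a collection. A hypothetical smoke test, assuming valid watsonx.ai credentials in the surrounding `config`:

```python
import chromadb

# Hypothetical smoke test for the embedding function defined above; it will
# call watsonx.ai for real, so it needs a valid api_key/project_id in config.
client = chromadb.Client()
collection = client.create_collection(
    name="watson_smoke_test",
    embedding_function=WatsonEmbeddingFunction(),
)
collection.add(ids=["doc-1"], documents=["hello from crewAI"])
print(collection.query(query_texts=["hello"], n_results=1))
```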