Feat/memory base (#1444)

* byom - short/entity memory

* better

* rm unneeded

* fix text

* use context

* rm dep and sync

* type check fix

* fixed test using new cassette

* fixing types

* fixed types

* fix types

* fixed types

* fixing types

* fix type

* cassette update

* just mock the return of short term mem

* remove print

* try catch block

* added docs

* adding error handling here
Lorenze Jay
2024-10-17 09:19:33 -07:00
committed by GitHub
parent 67f55bae2c
commit 6d20ba70a1
14 changed files with 241 additions and 558 deletions

View File

@@ -17,7 +17,7 @@ if TYPE_CHECKING:
class CrewAgentExecutorMixin:
crew: Optional["Crew"]
crew_agent: Optional["BaseAgent"]
agent: Optional["BaseAgent"]
task: Optional["Task"]
iterations: int
have_forced_answer: bool
@@ -33,9 +33,9 @@ class CrewAgentExecutorMixin:
"""Create and save a short-term memory item if conditions are met."""
if (
self.crew
and self.crew_agent
and self.agent
and self.task
and "Action: Delegate work to coworker" not in output.log
and "Action: Delegate work to coworker" not in output.text
):
try:
if (
@@ -43,11 +43,11 @@ class CrewAgentExecutorMixin:
and self.crew._short_term_memory
):
self.crew._short_term_memory.save(
value=output.log,
value=output.text,
metadata={
"observation": self.task.description,
},
agent=self.crew_agent.role,
agent=self.agent.role,
)
except Exception as e:
print(f"Failed to add to short term memory: {e}")
@@ -61,18 +61,18 @@ class CrewAgentExecutorMixin:
and self.crew._long_term_memory
and self.crew._entity_memory
and self.task
and self.crew_agent
and self.agent
):
try:
ltm_agent = TaskEvaluator(self.crew_agent)
evaluation = ltm_agent.evaluate(self.task, output.log)
ltm_agent = TaskEvaluator(self.agent)
evaluation = ltm_agent.evaluate(self.task, output.text)
if isinstance(evaluation, ConverterError):
return
long_term_memory = LongTermMemoryItem(
task=self.task.description,
agent=self.crew_agent.role,
agent=self.agent.role,
quality=evaluation.quality,
datetime=str(time.time()),
expected_output=self.task.expected_output,
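For readability, here is a consolidated sketch of _create_short_term_memory as it stands after the hunks above. It is reconstructed only from the lines shown; the opening clause of the inner guard is elided in the diff, so it appears below as a comment.

def _create_short_term_memory(self, output) -> None:
    """Create and save a short-term memory item if conditions are met."""
    if (
        self.crew
        and self.agent
        and self.task
        and "Action: Delegate work to coworker" not in output.text
    ):
        try:
            # ...opening clause of this condition elided in the hunk above...
            if self.crew._short_term_memory:
                self.crew._short_term_memory.save(
                    value=output.text,
                    metadata={
                        "observation": self.task.description,
                    },
                    agent=self.agent.role,
                )
        except Exception as e:
            print(f"Failed to add to short term memory: {e}")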

View File

@@ -19,6 +19,7 @@ from crewai.utilities.exceptions.context_window_exceeding_exception import (
)
from crewai.utilities.logger import Logger
from crewai.utilities.training_handler import CrewTrainingHandler
from crewai.agents.agent_builder.base_agent import BaseAgent
class CrewAgentExecutor(CrewAgentExecutorMixin):
@@ -29,7 +30,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
llm: Any,
task: Any,
crew: Any,
agent: Any,
agent: BaseAgent,
prompt: dict[str, str],
max_iter: int,
tools: List[Any],
@@ -103,7 +104,8 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
if self.crew and self.crew._train:
self._handle_crew_training_output(formatted_answer)
self._create_short_term_memory(formatted_answer)
self._create_long_term_memory(formatted_answer)
return {"output": formatted_answer.output}
def _invoke_loop(self, formatted_answer=None):
@@ -176,6 +178,8 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
return formatted_answer
def _show_start_logs(self):
if self.agent is None:
raise ValueError("Agent cannot be None")
if self.agent.verbose or (
hasattr(self, "crew") and getattr(self.crew, "verbose", False)
):
@@ -188,6 +192,8 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
)
def _show_logs(self, formatted_answer: Union[AgentAction, AgentFinish]):
if self.agent is None:
raise ValueError("Agent cannot be None")
if self.agent.verbose or (
hasattr(self, "crew") and getattr(self.crew, "verbose", False)
):
@@ -306,7 +312,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
self, result: AgentFinish, human_feedback: str | None = None
) -> None:
"""Function to handle the process of the training data."""
agent_id = str(self.agent.id)
agent_id = str(self.agent.id) # type: ignore
# Load training data
training_handler = CrewTrainingHandler(TRAINING_DATA_FILE)
@@ -339,7 +345,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
"initial_output": result.output,
"human_feedback": human_feedback,
"agent": agent_id,
"agent_role": self.agent.role,
"agent_role": self.agent.role, # type: ignore
}
if self.crew is not None and hasattr(self.crew, "_train_iteration"):
train_iteration = self.crew._train_iteration

View File

@@ -126,8 +126,8 @@ class Crew(BaseModel):
default=None,
description="An Instance of the EntityMemory to be used by the Crew",
)
embedder: Optional[dict] = Field(
default={"provider": "openai"},
embedder: Optional[Any] = Field(
default=None,
description="Configuration for the embedder to be used for the crew.",
)
usage_metrics: Optional[UsageMetrics] = Field(
@@ -774,7 +774,9 @@ class Crew(BaseModel):
def _log_task_start(self, task: Task, role: str = "None"):
if self.output_log_file:
self._file_handler.log(task_name=task.name, task=task.description, agent=role, status="started")
self._file_handler.log(
task_name=task.name, task=task.description, agent=role, status="started"
)
def _update_manager_tools(self, task: Task):
if self.manager_agent:
@@ -796,7 +798,13 @@ class Crew(BaseModel):
def _process_task_result(self, task: Task, output: TaskOutput) -> None:
role = task.agent.role if task.agent is not None else "None"
if self.output_log_file:
self._file_handler.log(task_name=task.name, task=task.description, agent=role, status="completed", output=output.raw)
self._file_handler.log(
task_name=task.name,
task=task.description,
agent=role,
status="completed",
output=output.raw,
)
def _create_crew_output(self, task_outputs: List[TaskOutput]) -> CrewOutput:
if len(task_outputs) != 1:

View File

@@ -31,7 +31,9 @@ class ContextualMemory:
formatted as bullet points.
"""
stm_results = self.stm.search(query)
formatted_results = "\n".join([f"- {result}" for result in stm_results])
formatted_results = "\n".join(
[f"- {result['context']}" for result in stm_results]
)
return f"Recent Insights:\n{formatted_results}" if stm_results else ""
def _fetch_ltm_context(self, task) -> Optional[str]:
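After this change each item returned by self.stm.search(query) is a dict rather than a bare string, which is why the formatting now reads result['context']. A small hypothetical illustration of that shape (the keys mirror the dicts assembled in RAGStorage.search further below) and of how the bullet list is built:

stm_results = [
    {
        "id": "a-generated-uuid",  # placeholder value
        "metadata": {"observation": "Summarize the latest findings"},
        "context": "The agent identified three relevant sources.",
        "score": 0.12,
    }
]
formatted_results = "\n".join(
    [f"- {result['context']}" for result in stm_results]
)
print(f"Recent Insights:\n{formatted_results}" if stm_results else "")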

View File

@@ -1,4 +1,4 @@
from typing import Any, Dict
from typing import Any, Dict, List
from crewai.memory.long_term.long_term_memory_item import LongTermMemoryItem
from crewai.memory.memory import Memory
@@ -28,7 +28,7 @@ class LongTermMemory(Memory):
datetime=item.datetime,
)
def search(self, task: str, latest_n: int = 3) -> Dict[str, Any]:
def search(self, task: str, latest_n: int = 3) -> List[Dict[str, Any]]: # type: ignore # signature of "search" incompatible with supertype "Memory"
return self.storage.load(task, latest_n) # type: ignore # BUG?: "Storage" has no attribute "load"
def reset(self) -> None:

View File

@@ -1,6 +1,6 @@
from typing import Any, Dict, Optional
from typing import Any, Dict, Optional, List
from crewai.memory.storage.interface import Storage
from crewai.memory.storage.rag_storage import RAGStorage
class Memory:
@@ -8,7 +8,7 @@ class Memory:
Base class for memory, now supporting agent tags and generic metadata.
"""
def __init__(self, storage: Storage):
def __init__(self, storage: RAGStorage):
self.storage = storage
def save(
@@ -23,5 +23,5 @@ class Memory:
self.storage.save(value, metadata)
def search(self, query: str) -> Dict[str, Any]:
def search(self, query: str) -> List[Dict[str, Any]]:
return self.storage.search(query)

View File

@@ -0,0 +1,76 @@
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional
class BaseRAGStorage(ABC):
"""
Base class for RAG-based Storage implementations.
"""
app: Any | None = None
def __init__(
self,
type: str,
allow_reset: bool = True,
embedder_config: Optional[Any] = None,
crew: Any = None,
):
self.type = type
self.allow_reset = allow_reset
self.embedder_config = embedder_config
self.crew = crew
self.agents = self._initialize_agents()
def _initialize_agents(self) -> str:
if self.crew:
return "_".join(
[self._sanitize_role(agent.role) for agent in self.crew.agents]
)
return ""
@abstractmethod
def _sanitize_role(self, role: str) -> str:
"""Sanitizes agent roles to ensure valid directory names."""
pass
@abstractmethod
def save(self, value: Any, metadata: Dict[str, Any]) -> None:
"""Save a value with metadata to the storage."""
pass
@abstractmethod
def search(
self,
query: str,
limit: int = 3,
filter: Optional[dict] = None,
score_threshold: float = 0.35,
) -> List[Any]:
"""Search for entries in the storage."""
pass
@abstractmethod
def reset(self) -> None:
"""Reset the storage."""
pass
@abstractmethod
def _generate_embedding(
self, text: str, metadata: Optional[Dict[str, Any]] = None
) -> Any:
"""Generate an embedding for the given text and metadata."""
pass
@abstractmethod
def _initialize_app(self):
"""Initialize the vector db."""
pass
def setup_config(self, config: Dict[str, Any]):
"""Setup the config of the storage."""
pass
def initialize_client(self):
"""Initialize the client of the storage. This should setup the app and the db collection"""
pass
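To make the contract above concrete, here is a minimal, hypothetical in-memory subclass that satisfies every abstract method. It is an illustration only, not part of the commit; the real implementation is the Chroma-backed RAGStorage refactored further below.

from typing import Any, Dict, List, Optional

from crewai.memory.storage.base_rag_storage import BaseRAGStorage


class InMemoryRAGStorage(BaseRAGStorage):
    """Toy storage: keeps entries in a list and matches by substring."""

    def _initialize_app(self):
        self.app = []  # stands in for a vector-db client

    def _sanitize_role(self, role: str) -> str:
        return role.replace("\n", "").replace(" ", "_").replace("/", "_")

    def _generate_embedding(
        self, text: str, metadata: Optional[Dict[str, Any]] = None
    ) -> Any:
        return text  # no real embedding; search below uses substring matching

    def save(self, value: Any, metadata: Dict[str, Any]) -> None:
        if self.app is None:
            self._initialize_app()
        self.app.append({"context": str(value), "metadata": metadata or {}})

    def search(
        self,
        query: str,
        limit: int = 3,
        filter: Optional[dict] = None,
        score_threshold: float = 0.35,
    ) -> List[Any]:
        hits = [e for e in (self.app or []) if query.lower() in e["context"].lower()]
        return hits[:limit]

    def reset(self) -> None:
        self.app = []


store = InMemoryRAGStorage(type="short_term")
store.save("Vector databases were compared on recall.", {"task": "benchmark"})
print(store.search("vector"))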

View File

@@ -1,4 +1,4 @@
from typing import Any, Dict
from typing import Any, Dict, List
class Storage:
@@ -7,7 +7,7 @@ class Storage:
def save(self, value: Any, metadata: Dict[str, Any]) -> None:
pass
def search(self, key: str) -> Dict[str, Any]: # type: ignore
def search(self, key: str) -> List[Dict[str, Any]]: # type: ignore
pass
def reset(self) -> None:

View File

@@ -3,10 +3,11 @@ import io
import logging
import os
import shutil
import uuid
from typing import Any, Dict, List, Optional
from crewai.memory.storage.interface import Storage
from crewai.memory.storage.base_rag_storage import BaseRAGStorage
from crewai.utilities.paths import db_storage_path
from chromadb.api import ClientAPI
@contextlib.contextmanager
@@ -24,61 +25,42 @@ def suppress_logging(
logger.setLevel(original_level)
class RAGStorage(Storage):
class RAGStorage(BaseRAGStorage):
"""
Extends Storage to handle embeddings for memory entries, improving
search efficiency.
"""
def __init__(self, type, allow_reset=True, embedder_config=None, crew=None):
super().__init__()
if (
not os.getenv("OPENAI_API_KEY")
and not os.getenv("OPENAI_BASE_URL") == "https://api.openai.com/v1"
):
os.environ["OPENAI_API_KEY"] = "fake"
app: ClientAPI | None = None
def __init__(self, type, allow_reset=True, embedder_config=None, crew=None):
super().__init__(type, allow_reset, embedder_config, crew)
agents = crew.agents if crew else []
agents = [self._sanitize_role(agent.role) for agent in agents]
agents = "_".join(agents)
self.agents = agents
config = {
"app": {
"config": {"name": type, "collect_metrics": False, "log_level": "ERROR"}
},
"chunker": {
"chunk_size": 5000,
"chunk_overlap": 100,
"length_function": "len",
"min_chunk_size": 150,
},
"vectordb": {
"provider": "chroma",
"config": {
"collection_name": type,
"dir": f"{db_storage_path()}/{type}/{agents}",
"allow_reset": allow_reset,
},
},
}
if embedder_config:
config["embedder"] = embedder_config
self.type = type
self.config = config
self.embedder_config = embedder_config or self._create_embedding_function()
self.allow_reset = allow_reset
self._initialize_app()
def _initialize_app(self):
from embedchain import App
from embedchain.llm.base import BaseLlm
import chromadb
class FakeLLM(BaseLlm):
pass
chroma_client = chromadb.PersistentClient(
path=f"{db_storage_path()}/{self.type}/{self.agents}"
)
self.app = chroma_client
self.app = App.from_config(config=self.config)
self.app.llm = FakeLLM()
if self.allow_reset:
self.app.reset()
try:
self.collection = self.app.get_collection(
name=self.type, embedding_function=self.embedder_config
)
except Exception:
self.collection = self.app.create_collection(
name=self.type, embedding_function=self.embedder_config
)
def _sanitize_role(self, role: str) -> str:
"""
@@ -87,11 +69,14 @@ class RAGStorage(Storage):
return role.replace("\n", "").replace(" ", "_").replace("/", "_")
def save(self, value: Any, metadata: Dict[str, Any]) -> None:
if not hasattr(self, "app"):
if not hasattr(self, "app") or not hasattr(self, "collection"):
self._initialize_app()
self._generate_embedding(value, metadata)
try:
self._generate_embedding(value, metadata)
except Exception as e:
logging.error(f"Error during {self.type} save: {str(e)}")
def search( # type: ignore # BUG?: Signature of "search" incompatible with supertype "Storage"
def search(
self,
query: str,
limit: int = 3,
@@ -100,31 +85,50 @@ class RAGStorage(Storage):
) -> List[Any]:
if not hasattr(self, "app"):
self._initialize_app()
from embedchain.vectordb.chroma import InvalidDimensionException
with suppress_logging():
try:
results = (
self.app.search(query, limit, where=filter)
if filter
else self.app.search(query, limit)
)
except InvalidDimensionException:
self.app.reset()
return []
return [r for r in results if r["metadata"]["score"] >= score_threshold]
try:
with suppress_logging():
response = self.collection.query(query_texts=query, n_results=limit)
def _generate_embedding(self, text: str, metadata: Dict[str, Any]) -> Any:
if not hasattr(self, "app"):
results = []
for i in range(len(response["ids"][0])):
result = {
"id": response["ids"][0][i],
"metadata": response["metadatas"][0][i],
"context": response["documents"][0][i],
"score": response["distances"][0][i],
}
if result["score"] >= score_threshold:
results.append(result)
return results
except Exception as e:
logging.error(f"Error during {self.type} search: {str(e)}")
return []
def _generate_embedding(self, text: str, metadata: Dict[str, Any]) -> None: # type: ignore
if not hasattr(self, "app") or not hasattr(self, "collection"):
self._initialize_app()
from embedchain.models.data_type import DataType
self.app.add(text, data_type=DataType.TEXT, metadata=metadata)
self.collection.add(
documents=[text],
metadatas=[metadata or {}],
ids=[str(uuid.uuid4())],
)
def reset(self) -> None:
try:
shutil.rmtree(f"{db_storage_path()}/{self.type}")
if self.app:
self.app.reset()
except Exception as e:
raise Exception(
f"An error occurred while resetting the {self.type} memory: {e}"
)
def _create_embedding_function(self):
import chromadb.utils.embedding_functions as embedding_functions
return embedding_functions.OpenAIEmbeddingFunction(
api_key=os.getenv("OPENAI_API_KEY"), model_name="text-embedding-3-small"
)
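To close, a hedged usage sketch of the refactored storage. It assumes a valid OPENAI_API_KEY in the environment, since the default embedding function above is OpenAI's text-embedding-3-small; the constructor arguments, result keys, and import path are taken from this diff, while the saved text and metadata values are made up for illustration.

from crewai.memory.storage.rag_storage import RAGStorage

storage = RAGStorage(type="short_term", allow_reset=True)

storage.save(
    "The agent drafted an outline covering three candidate architectures.",
    metadata={"observation": "architecture review task"},
)

# Each hit is a dict with "id", "metadata", "context" and "score" keys,
# as assembled in search() above; entries below score_threshold are dropped.
for result in storage.search("candidate architectures", limit=3):
    print(result["score"], result["context"])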