mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-08 15:48:29 +00:00
Adding long term, short term, entity and contextual memory
This commit is contained in:
@@ -4,7 +4,6 @@ from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
from langchain.agents.agent import RunnableAgent
|
||||
from langchain.agents.tools import tool as LangChainTool
|
||||
from langchain.memory import ConversationSummaryMemory
|
||||
from langchain.tools.render import render_text_description
|
||||
from langchain_core.agents import AgentAction
|
||||
from langchain_core.callbacks import BaseCallbackHandler
|
||||
@@ -22,6 +21,7 @@ from pydantic import (
|
||||
from pydantic_core import PydanticCustomError
|
||||
|
||||
from crewai.agents import CacheHandler, CrewAgentExecutor, CrewAgentParser, ToolsHandler
|
||||
from crewai.memory.contextual.contextual_memory import ContextualMemory
|
||||
from crewai.utilities import I18N, Logger, Prompts, RPMController
|
||||
from crewai.utilities.token_counter_callback import TokenCalcHandler, TokenProcess
|
||||
|
||||
@@ -96,6 +96,7 @@ class Agent(BaseModel):
|
||||
agent_executor: InstanceOf[CrewAgentExecutor] = Field(
|
||||
default=None, description="An instance of the CrewAgentExecutor class."
|
||||
)
|
||||
crew: Any = Field(default=None, description="Crew to which the agent belongs.")
|
||||
tools_handler: InstanceOf[ToolsHandler] = Field(
|
||||
default=None, description="An instance of the ToolsHandler class."
|
||||
)
|
||||
@@ -193,6 +194,15 @@ class Agent(BaseModel):
|
||||
task=task_prompt, context=context
|
||||
)
|
||||
|
||||
if self.crew and self.memory:
|
||||
contextual_memory = ContextualMemory(
|
||||
self.crew._short_term_memory,
|
||||
self.crew._long_term_memory,
|
||||
self.crew._entity_memory,
|
||||
)
|
||||
memory = contextual_memory.build_context_for_task(task, context)
|
||||
task_prompt += self.i18n.slice("memory").format(memory=memory)
|
||||
|
||||
tools = tools or self.tools
|
||||
parsed_tools = self._parse_tools(tools)
|
||||
|
||||
@@ -258,6 +268,8 @@ class Agent(BaseModel):
|
||||
executor_args = {
|
||||
"llm": self.llm,
|
||||
"i18n": self.i18n,
|
||||
"crew": self.crew,
|
||||
"crew_agent": self,
|
||||
"tools": self._parse_tools(tools),
|
||||
"verbose": self.verbose,
|
||||
"original_tools": tools,
|
||||
@@ -274,15 +286,7 @@ class Agent(BaseModel):
|
||||
"request_within_rpm_limit"
|
||||
] = self._rpm_controller.check_or_wait
|
||||
|
||||
if self.memory:
|
||||
summary_memory = ConversationSummaryMemory(
|
||||
llm=self.llm, input_key="input", memory_key="chat_history"
|
||||
)
|
||||
executor_args["memory"] = summary_memory
|
||||
agent_args["chat_history"] = lambda x: x["chat_history"]
|
||||
prompt = Prompts(i18n=self.i18n, tools=tools).task_execution_with_memory()
|
||||
else:
|
||||
prompt = Prompts(i18n=self.i18n, tools=tools).task_execution()
|
||||
prompt = Prompts(i18n=self.i18n, tools=tools).task_execution()
|
||||
|
||||
execution_prompt = prompt.partial(
|
||||
goal=self.goal,
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import threading
|
||||
import time
|
||||
from typing import Any, Dict, Iterator, List, Optional, Tuple, Union
|
||||
|
||||
@@ -12,8 +13,13 @@ from langchain_core.utils.input import get_color_mapping
|
||||
from pydantic import InstanceOf
|
||||
|
||||
from crewai.agents.tools_handler import ToolsHandler
|
||||
from crewai.memory.entity.entity_memory_item import EntityMemoryItem
|
||||
from crewai.memory.long_term.long_term_memory_item import LongTermMemoryItem
|
||||
from crewai.memory.short_term.short_term_memory_item import ShortTermMemoryItem
|
||||
from crewai.tools.tool_usage import ToolUsage, ToolUsageErrorException
|
||||
from crewai.utilities import I18N
|
||||
from crewai.utilities.converter import ConverterError
|
||||
from crewai.utilities.evaluators.task_evaluator import TaskEvaluator
|
||||
|
||||
|
||||
class CrewAgentExecutor(AgentExecutor):
|
||||
@@ -25,6 +31,8 @@ class CrewAgentExecutor(AgentExecutor):
|
||||
tools_description: str = ""
|
||||
tools_names: str = ""
|
||||
original_tools: List[Any] = []
|
||||
crew_agent: Any = None
|
||||
crew: Any = None
|
||||
function_calling_llm: Any = None
|
||||
request_within_rpm_limit: Any = None
|
||||
tools_handler: InstanceOf[ToolsHandler] = None
|
||||
@@ -43,6 +51,52 @@ class CrewAgentExecutor(AgentExecutor):
|
||||
self.iterations == self.force_answer_max_iterations
|
||||
) and not self.have_forced_answer
|
||||
|
||||
def _create_short_term_memory(self, output) -> None:
|
||||
if (
|
||||
self.crew_agent.memory
|
||||
and "Action: Delegate work to co-worker" not in output.log
|
||||
):
|
||||
memory = ShortTermMemoryItem(
|
||||
data=output.log,
|
||||
agent=self.crew_agent.role,
|
||||
metadata={
|
||||
"observation": self.task.description,
|
||||
},
|
||||
)
|
||||
self.crew._short_term_memory.save(memory)
|
||||
|
||||
def _create_long_term_memory(self, output) -> None:
|
||||
if self.crew_agent.memory:
|
||||
ltm_agent = TaskEvaluator(self.crew_agent)
|
||||
evaluation = ltm_agent.evaluate(self.task, output.log)
|
||||
|
||||
if isinstance(evaluation, ConverterError):
|
||||
return
|
||||
|
||||
long_term_memory = LongTermMemoryItem(
|
||||
task=self.task.description,
|
||||
agent=self.crew_agent.role,
|
||||
quality=evaluation.quality,
|
||||
datetime=str(time.time()),
|
||||
expected_output=self.task.expected_output,
|
||||
metadata={
|
||||
"suggestions": "\n".join(
|
||||
[f"- {s}" for s in evaluation.suggestions]
|
||||
),
|
||||
"quality": evaluation.quality,
|
||||
},
|
||||
)
|
||||
self.crew._long_term_memory.save(long_term_memory)
|
||||
|
||||
for entity in evaluation.entities:
|
||||
entity_memory = EntityMemoryItem(
|
||||
name=entity.name,
|
||||
type=entity.type,
|
||||
description=entity.description,
|
||||
relationships="\n".join([f"- {r}" for r in entity.relationships]),
|
||||
)
|
||||
self.crew._entity_memory.save(entity_memory)
|
||||
|
||||
def _call(
|
||||
self,
|
||||
inputs: Dict[str, str],
|
||||
@@ -53,7 +107,8 @@ class CrewAgentExecutor(AgentExecutor):
|
||||
name_to_tool_map = {tool.name: tool for tool in self.tools}
|
||||
# We construct a mapping from each tool to a color, used for logging.
|
||||
color_mapping = get_color_mapping(
|
||||
[tool.name for tool in self.tools], excluded_colors=["green", "red"]
|
||||
[tool.name.casefold() for tool in self.tools],
|
||||
excluded_colors=["green", "red"],
|
||||
)
|
||||
intermediate_steps: List[Tuple[AgentAction, str]] = []
|
||||
# Allowing human input given task setting
|
||||
@@ -63,6 +118,7 @@ class CrewAgentExecutor(AgentExecutor):
|
||||
self.iterations = 0
|
||||
time_elapsed = 0.0
|
||||
start_time = time.time()
|
||||
|
||||
# We now enter the agent loop (until it returns something).
|
||||
while self._should_continue(self.iterations, time_elapsed):
|
||||
if not self.request_within_rpm_limit or self.request_within_rpm_limit():
|
||||
@@ -73,16 +129,21 @@ class CrewAgentExecutor(AgentExecutor):
|
||||
intermediate_steps,
|
||||
run_manager=run_manager,
|
||||
)
|
||||
|
||||
if self.step_callback:
|
||||
self.step_callback(next_step_output)
|
||||
|
||||
if isinstance(next_step_output, AgentFinish):
|
||||
# Creating long term memory
|
||||
create_long_term_memory = threading.Thread(
|
||||
target=self._create_long_term_memory, args=(next_step_output,)
|
||||
)
|
||||
create_long_term_memory.start()
|
||||
|
||||
return self._return(
|
||||
next_step_output, intermediate_steps, run_manager=run_manager
|
||||
)
|
||||
|
||||
intermediate_steps.extend(next_step_output)
|
||||
|
||||
if len(next_step_output) == 1:
|
||||
next_step_action = next_step_output[0]
|
||||
# See if tool should return directly
|
||||
@@ -91,11 +152,13 @@ class CrewAgentExecutor(AgentExecutor):
|
||||
return self._return(
|
||||
tool_return, intermediate_steps, run_manager=run_manager
|
||||
)
|
||||
|
||||
self.iterations += 1
|
||||
time_elapsed = time.time() - start_time
|
||||
output = self.agent.return_stopped_response(
|
||||
self.early_stopping_method, intermediate_steps, **inputs
|
||||
)
|
||||
|
||||
return self._return(output, intermediate_steps, run_manager=run_manager)
|
||||
|
||||
def _iter_next_step(
|
||||
@@ -119,6 +182,7 @@ class CrewAgentExecutor(AgentExecutor):
|
||||
return
|
||||
|
||||
intermediate_steps = self._prepare_intermediate_steps(intermediate_steps)
|
||||
|
||||
# Call the LLM to see what to do.
|
||||
output = self.agent.plan(
|
||||
intermediate_steps,
|
||||
@@ -152,8 +216,10 @@ class CrewAgentExecutor(AgentExecutor):
|
||||
else:
|
||||
raise ValueError("Got unexpected type of `handle_parsing_errors`")
|
||||
output = AgentAction("_Exception", observation, "")
|
||||
|
||||
if run_manager:
|
||||
run_manager.on_agent_action(output, color="green")
|
||||
|
||||
tool_run_kwargs = self.agent.tool_run_logging_kwargs()
|
||||
observation = ExceptionTool().run(
|
||||
output.tool_input,
|
||||
@@ -193,9 +259,12 @@ class CrewAgentExecutor(AgentExecutor):
|
||||
yield output
|
||||
return
|
||||
|
||||
self._create_short_term_memory(output)
|
||||
|
||||
actions: List[AgentAction]
|
||||
actions = [output] if isinstance(output, AgentAction) else output
|
||||
yield from actions
|
||||
|
||||
for agent_action in actions:
|
||||
if run_manager:
|
||||
run_manager.on_agent_action(agent_action, color="green")
|
||||
@@ -215,15 +284,16 @@ class CrewAgentExecutor(AgentExecutor):
|
||||
if isinstance(tool_calling, ToolUsageErrorException):
|
||||
observation = tool_calling.message
|
||||
else:
|
||||
if tool_calling.tool_name.lower().strip() in [
|
||||
name.lower().strip() for name in name_to_tool_map
|
||||
if tool_calling.tool_name.casefold().strip() in [
|
||||
name.casefold().strip() for name in name_to_tool_map
|
||||
]:
|
||||
observation = tool_usage.use(tool_calling, agent_action.log)
|
||||
else:
|
||||
observation = self._i18n.errors("wrong_tool_name").format(
|
||||
tool=tool_calling.tool_name,
|
||||
tools=", ".join([tool.name for tool in self.tools]),
|
||||
tools=", ".join([tool.name.casefold() for tool in self.tools]),
|
||||
)
|
||||
|
||||
yield AgentStep(action=agent_action, observation=observation)
|
||||
|
||||
def _ask_human_input(self, final_answer: dict) -> str:
|
||||
|
||||
1
src/crewai/cli/templates/.gitignore
vendored
1
src/crewai/cli/templates/.gitignore
vendored
@@ -1,2 +1,3 @@
|
||||
.env
|
||||
.db
|
||||
__pycache__/
|
||||
|
||||
@@ -1,5 +1,8 @@
|
||||
import json
|
||||
import subprocess
|
||||
import sys
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
from langchain_core.callbacks import BaseCallbackHandler
|
||||
@@ -18,6 +21,9 @@ from pydantic_core import PydanticCustomError
|
||||
|
||||
from crewai.agent import Agent
|
||||
from crewai.agents.cache import CacheHandler
|
||||
from crewai.memory.entity.entity_memory import EntityMemory
|
||||
from crewai.memory.long_term.long_term_memory import LongTermMemory
|
||||
from crewai.memory.short_term.short_term_memory import ShortTermMemory
|
||||
from crewai.process import Process
|
||||
from crewai.task import Task
|
||||
from crewai.telemetry import Telemetry
|
||||
@@ -33,6 +39,7 @@ class Crew(BaseModel):
|
||||
tasks: List of tasks assigned to the crew.
|
||||
agents: List of agents part of this crew.
|
||||
manager_llm: The language model that will run manager agent.
|
||||
memory: Whether the crew should use memory to store memories of it's execution.
|
||||
manager_callbacks: The callback handlers to be executed by the manager agent when hierarchical process is used
|
||||
cache: Whether the crew should use a cache to store the results of the tools execution.
|
||||
function_calling_llm: The language model that will run the tool calling for all the agents.
|
||||
@@ -42,6 +49,7 @@ class Crew(BaseModel):
|
||||
max_rpm: Maximum number of requests per minute for the crew execution to be respected.
|
||||
id: A unique identifier for the crew instance.
|
||||
full_output: Whether the crew should return the full output with all tasks outputs or just the final output.
|
||||
task_callback: Callback to be executed after each task for every agents execution.
|
||||
step_callback: Callback to be executed after each step for every agents execution.
|
||||
share_crew: Whether you want to share the complete crew infromation and execution with crewAI to make the library better, and allow us to train models.
|
||||
"""
|
||||
@@ -51,12 +59,24 @@ class Crew(BaseModel):
|
||||
_rpm_controller: RPMController = PrivateAttr()
|
||||
_logger: Logger = PrivateAttr()
|
||||
_cache_handler: InstanceOf[CacheHandler] = PrivateAttr(default=CacheHandler())
|
||||
_short_term_memory: Optional[InstanceOf[ShortTermMemory]] = PrivateAttr()
|
||||
_long_term_memory: Optional[InstanceOf[LongTermMemory]] = PrivateAttr()
|
||||
_entity_memory: Optional[InstanceOf[EntityMemory]] = PrivateAttr()
|
||||
|
||||
cache: bool = Field(default=True)
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
tasks: List[Task] = Field(default_factory=list)
|
||||
agents: List[Agent] = Field(default_factory=list)
|
||||
process: Process = Field(default=Process.sequential)
|
||||
verbose: Union[int, bool] = Field(default=0)
|
||||
memory: bool = Field(
|
||||
default=True,
|
||||
description="Whether the crew should use memory to store memories of it's execution",
|
||||
)
|
||||
embedder: Optional[dict] = Field(
|
||||
default={"provider": "openai"},
|
||||
description="Configuration for the embedder to be used for the crew.",
|
||||
)
|
||||
usage_metrics: Optional[dict] = Field(
|
||||
default=None,
|
||||
description="Metrics for the LLM usage during all tasks execution.",
|
||||
@@ -82,6 +102,10 @@ class Crew(BaseModel):
|
||||
default=None,
|
||||
description="Callback to be executed after each step for all agents execution.",
|
||||
)
|
||||
task_callback: Optional[Any] = Field(
|
||||
default=None,
|
||||
description="Callback to be executed after each task for all agents execution.",
|
||||
)
|
||||
max_rpm: Optional[int] = Field(
|
||||
default=None,
|
||||
description="Maximum number of requests per minute for the crew execution to be respected.",
|
||||
@@ -90,6 +114,10 @@ class Crew(BaseModel):
|
||||
default="en",
|
||||
description="Language used for the crew, defaults to English.",
|
||||
)
|
||||
language_file: str = Field(
|
||||
default=None,
|
||||
description="Path to the language file to be used for the crew.",
|
||||
)
|
||||
|
||||
@field_validator("id", mode="before")
|
||||
@classmethod
|
||||
@@ -126,6 +154,19 @@ class Crew(BaseModel):
|
||||
self._telemetry.crew_creation(self)
|
||||
return self
|
||||
|
||||
@model_validator(mode="after")
|
||||
def create_crew_memory(self) -> "Crew":
|
||||
"""Set private attributes."""
|
||||
if self.memory:
|
||||
storage_dir = Path(".db")
|
||||
storage_dir.mkdir(exist_ok=True)
|
||||
if sys.platform.startswith("win"):
|
||||
subprocess.call(["attrib", "+H", str(storage_dir)])
|
||||
self._long_term_memory = LongTermMemory()
|
||||
self._short_term_memory = ShortTermMemory(embedder_config=self.embedder)
|
||||
self._entity_memory = EntityMemory(embedder_config=self.embedder)
|
||||
return self
|
||||
|
||||
@model_validator(mode="after")
|
||||
def check_manager_llm(self):
|
||||
"""Validates that the language model is set when using hierarchical process."""
|
||||
@@ -190,16 +231,20 @@ class Crew(BaseModel):
|
||||
"""Starts the crew to work on its assigned tasks."""
|
||||
self._execution_span = self._telemetry.crew_execution_span(self)
|
||||
self._interpolate_inputs(inputs)
|
||||
self._set_tasks_callbacks()
|
||||
|
||||
i18n = I18N(language=self.language, language_file=self.language_file)
|
||||
|
||||
for agent in self.agents:
|
||||
agent.i18n = I18N(language=self.language)
|
||||
agent.i18n = i18n
|
||||
agent.crew = self
|
||||
|
||||
if not agent.function_calling_llm:
|
||||
agent.function_calling_llm = self.function_calling_llm
|
||||
agent.create_agent_executor()
|
||||
if not agent.step_callback:
|
||||
agent.step_callback = self.step_callback
|
||||
agent.create_agent_executor()
|
||||
|
||||
agent.create_agent_executor()
|
||||
|
||||
metrics = []
|
||||
|
||||
@@ -253,7 +298,7 @@ class Crew(BaseModel):
|
||||
def _run_hierarchical_process(self) -> str:
|
||||
"""Creates and assigns a manager agent to make sure the crew completes the tasks."""
|
||||
|
||||
i18n = I18N(language=self.language)
|
||||
i18n = I18N(language=self.language, language_file=self.language_file)
|
||||
manager = Agent(
|
||||
role=i18n.retrieve("hierarchical_manager_agent", "role"),
|
||||
goal=i18n.retrieve("hierarchical_manager_agent", "goal"),
|
||||
@@ -277,6 +322,11 @@ class Crew(BaseModel):
|
||||
self._finish_execution(task_output)
|
||||
return self._format_output(task_output), manager._token_process.get_summary()
|
||||
|
||||
def _set_tasks_callbacks(self) -> str:
|
||||
"""Sets callback for every task suing task_callback"""
|
||||
for task in self.tasks:
|
||||
task.callback = self.task_callback
|
||||
|
||||
def _interpolate_inputs(self, inputs: Dict[str, Any]) -> str:
|
||||
"""Interpolates the inputs in the tasks and agents."""
|
||||
[task.interpolate_inputs(inputs) for task in self.tasks]
|
||||
|
||||
3
src/crewai/memory/__init__.py
Normal file
3
src/crewai/memory/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from .entity.entity_memory import EntityMemory
|
||||
from .long_term.long_term_memory import LongTermMemory
|
||||
from .short_term.short_term_memory import ShortTermMemory
|
||||
0
src/crewai/memory/contextual/__init__.py
Normal file
0
src/crewai/memory/contextual/__init__.py
Normal file
58
src/crewai/memory/contextual/contextual_memory.py
Normal file
58
src/crewai/memory/contextual/contextual_memory.py
Normal file
@@ -0,0 +1,58 @@
|
||||
from crewai.memory import EntityMemory, LongTermMemory, ShortTermMemory
|
||||
|
||||
|
||||
class ContextualMemory:
|
||||
def __init__(self, stm: ShortTermMemory, ltm: LongTermMemory, em: EntityMemory):
|
||||
self.stm = stm
|
||||
self.ltm = ltm
|
||||
self.em = em
|
||||
|
||||
def build_context_for_task(self, task, context) -> str:
|
||||
"""
|
||||
Automatically builds a minimal, highly relevant set of contextual information
|
||||
for a given task.
|
||||
"""
|
||||
query = f"{task.description} {context}".strip()
|
||||
|
||||
if query == "":
|
||||
return ""
|
||||
|
||||
context = []
|
||||
context.append(self._fetch_ltm_context(task.description))
|
||||
context.append(self._fetch_stm_context(query))
|
||||
context.append(self._fetch_entity_context(query))
|
||||
return "\n".join(filter(None, context))
|
||||
|
||||
def _fetch_stm_context(self, query) -> str:
|
||||
"""
|
||||
Fetches recent relevant insights from STM related to the task's description and expected_output,
|
||||
formatted as bullet points.
|
||||
"""
|
||||
stm_results = self.stm.search(query)
|
||||
formatted_results = "\n".join([f"- {result}" for result in stm_results])
|
||||
return f"Recent Insights:\n{formatted_results}" if stm_results else ""
|
||||
|
||||
def _fetch_ltm_context(self, task) -> str:
|
||||
"""
|
||||
Fetches historical data or insights from LTM that are relevant to the task's description and expected_output,
|
||||
formatted as bullet points.
|
||||
"""
|
||||
ltm_results = self.ltm.search(task)
|
||||
if not ltm_results:
|
||||
return None
|
||||
formatted_results = "\n".join(
|
||||
[f"{result['metadata']['suggestions']}" for result in ltm_results]
|
||||
)
|
||||
formatted_results = list(set(formatted_results))
|
||||
return f"Historical Data:\n{formatted_results}" if ltm_results else ""
|
||||
|
||||
def _fetch_entity_context(self, query) -> str:
|
||||
"""
|
||||
Fetches relevant entity information from Entity Memory related to the task's description and expected_output,
|
||||
formatted as bullet points.
|
||||
"""
|
||||
em_results = self.em.search(query)
|
||||
formatted_results = "\n".join(
|
||||
[f"- {result['context']}" for result in em_results]
|
||||
)
|
||||
return f"Entities:\n{formatted_results}" if em_results else ""
|
||||
0
src/crewai/memory/entity/__init__.py
Normal file
0
src/crewai/memory/entity/__init__.py
Normal file
22
src/crewai/memory/entity/entity_memory.py
Normal file
22
src/crewai/memory/entity/entity_memory.py
Normal file
@@ -0,0 +1,22 @@
|
||||
from crewai.memory.entity.entity_memory_item import EntityMemoryItem
|
||||
from crewai.memory.memory import Memory
|
||||
from crewai.memory.storage.rag_storage import RAGStorage
|
||||
|
||||
|
||||
class EntityMemory(Memory):
|
||||
"""
|
||||
EntityMemory class for managing structured information about entities
|
||||
and their relationships using SQLite storage.
|
||||
Inherits from the Memory class.
|
||||
"""
|
||||
|
||||
def __init__(self, embedder_config=None):
|
||||
storage = RAGStorage(
|
||||
type="entities", allow_reset=False, embedder_config=embedder_config
|
||||
)
|
||||
super().__init__(storage)
|
||||
|
||||
def save(self, item: EntityMemoryItem) -> None:
|
||||
"""Saves an entity item into the SQLite storage."""
|
||||
data = f"{item.name}({item.type}): {item.description}"
|
||||
super().save(data, item.metadata)
|
||||
12
src/crewai/memory/entity/entity_memory_item.py
Normal file
12
src/crewai/memory/entity/entity_memory_item.py
Normal file
@@ -0,0 +1,12 @@
|
||||
class EntityMemoryItem:
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
type: str,
|
||||
description: str,
|
||||
relationships: str,
|
||||
):
|
||||
self.name = name
|
||||
self.type = type
|
||||
self.description = description
|
||||
self.metadata = {"relationships": relationships}
|
||||
0
src/crewai/memory/long_term/__init__.py
Normal file
0
src/crewai/memory/long_term/__init__.py
Normal file
32
src/crewai/memory/long_term/long_term_memory.py
Normal file
32
src/crewai/memory/long_term/long_term_memory.py
Normal file
@@ -0,0 +1,32 @@
|
||||
from typing import Any, Dict
|
||||
|
||||
from crewai.memory.long_term.long_term_memory_item import LongTermMemoryItem
|
||||
from crewai.memory.memory import Memory
|
||||
from crewai.memory.storage.ltm_sqlite_storage import LTMSQLiteStorage
|
||||
|
||||
|
||||
class LongTermMemory(Memory):
|
||||
"""
|
||||
LongTermMemory class for managing cross runs data related to overall crew's
|
||||
execution and performance.
|
||||
Inherits from the Memory class and utilizes an instance of a class that
|
||||
adheres to the Storage for data storage, specifically working with
|
||||
LongTermMemoryItem instances.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
storage = LTMSQLiteStorage()
|
||||
super().__init__(storage)
|
||||
|
||||
def save(self, item: LongTermMemoryItem) -> None:
|
||||
metadata = item.metadata
|
||||
metadata.update({"agent": item.agent, "expected_output": item.expected_output})
|
||||
self.storage.save(
|
||||
task_description=item.task,
|
||||
score=metadata["quality"],
|
||||
metadata=metadata,
|
||||
datetime=item.datetime,
|
||||
)
|
||||
|
||||
def search(self, task: str) -> Dict[str, Any]:
|
||||
return self.storage.load(task)
|
||||
19
src/crewai/memory/long_term/long_term_memory_item.py
Normal file
19
src/crewai/memory/long_term/long_term_memory_item.py
Normal file
@@ -0,0 +1,19 @@
|
||||
from typing import Any, Dict, Union
|
||||
|
||||
|
||||
class LongTermMemoryItem:
|
||||
def __init__(
|
||||
self,
|
||||
agent: str,
|
||||
task: str,
|
||||
expected_output: str,
|
||||
datetime: str,
|
||||
quality: Union[int, float] = None,
|
||||
metadata: Dict[str, Any] = None,
|
||||
):
|
||||
self.task = task
|
||||
self.agent = agent
|
||||
self.quality = quality
|
||||
self.datetime = datetime
|
||||
self.expected_output = expected_output
|
||||
self.metadata = metadata if metadata is not None else {}
|
||||
23
src/crewai/memory/memory.py
Normal file
23
src/crewai/memory/memory.py
Normal file
@@ -0,0 +1,23 @@
|
||||
from typing import Any, Dict
|
||||
|
||||
from crewai.memory.storage.interface import Storage
|
||||
|
||||
|
||||
class Memory:
|
||||
"""
|
||||
Base class for memory, now supporting agent tags and generic metadata.
|
||||
"""
|
||||
|
||||
def __init__(self, storage: Storage):
|
||||
self.storage = storage
|
||||
|
||||
def save(
|
||||
self, value: Any, metadata: Dict[str, Any] = None, agent: str = None
|
||||
) -> None:
|
||||
metadata = metadata or {}
|
||||
if agent:
|
||||
metadata["agent"] = agent
|
||||
self.storage.save(value, metadata)
|
||||
|
||||
def search(self, query: str) -> Dict[str, Any]:
|
||||
return self.storage.search(query)
|
||||
0
src/crewai/memory/short_term/__init__.py
Normal file
0
src/crewai/memory/short_term/__init__.py
Normal file
23
src/crewai/memory/short_term/short_term_memory.py
Normal file
23
src/crewai/memory/short_term/short_term_memory.py
Normal file
@@ -0,0 +1,23 @@
|
||||
from crewai.memory.memory import Memory
|
||||
from crewai.memory.short_term.short_term_memory_item import ShortTermMemoryItem
|
||||
from crewai.memory.storage.rag_storage import RAGStorage
|
||||
|
||||
|
||||
class ShortTermMemory(Memory):
|
||||
"""
|
||||
ShortTermMemory class for managing transient data related to immediate tasks
|
||||
and interactions.
|
||||
Inherits from the Memory class and utilizes an instance of a class that
|
||||
adheres to the Storage for data storage, specifically working with
|
||||
MemoryItem instances.
|
||||
"""
|
||||
|
||||
def __init__(self, embedder_config=None):
|
||||
storage = RAGStorage(type="short_term", embedder_config=embedder_config)
|
||||
super().__init__(storage)
|
||||
|
||||
def save(self, item: ShortTermMemoryItem) -> None:
|
||||
super().save(item.data, item.metadata, item.agent)
|
||||
|
||||
def search(self, query: str, score_threshold: float = 0.35):
|
||||
return self.storage.search(query=query, score_threshold=score_threshold)
|
||||
8
src/crewai/memory/short_term/short_term_memory_item.py
Normal file
8
src/crewai/memory/short_term/short_term_memory_item.py
Normal file
@@ -0,0 +1,8 @@
|
||||
from typing import Any, Dict
|
||||
|
||||
|
||||
class ShortTermMemoryItem:
|
||||
def __init__(self, data: Any, agent: str, metadata: Dict[str, Any] = None):
|
||||
self.data = data
|
||||
self.agent = agent
|
||||
self.metadata = metadata if metadata is not None else {}
|
||||
11
src/crewai/memory/storage/interface.py
Normal file
11
src/crewai/memory/storage/interface.py
Normal file
@@ -0,0 +1,11 @@
|
||||
from typing import Any, Dict
|
||||
|
||||
|
||||
class Storage:
|
||||
"""Abstract base class defining the storage interface"""
|
||||
|
||||
def save(self, key: str, value: Any, metadata: Dict[str, Any]) -> None:
|
||||
pass
|
||||
|
||||
def search(self, key: str) -> Dict[str, Any]:
|
||||
pass
|
||||
100
src/crewai/memory/storage/ltm_sqlite_storage.py
Normal file
100
src/crewai/memory/storage/ltm_sqlite_storage.py
Normal file
@@ -0,0 +1,100 @@
|
||||
import json
|
||||
import sqlite3
|
||||
from typing import Any, Dict, Union
|
||||
|
||||
from crewai.utilities import Printer
|
||||
|
||||
|
||||
class LTMSQLiteStorage:
|
||||
"""
|
||||
An updated SQLite storage class for LTM data storage.
|
||||
"""
|
||||
|
||||
def __init__(self, db_path=".db/long_term_memory_storage.db"):
|
||||
self.db_path = db_path
|
||||
self._printer: Printer = Printer()
|
||||
self._initialize_db()
|
||||
|
||||
def _initialize_db(self):
|
||||
"""
|
||||
Initializes the SQLite database and creates LTM table
|
||||
"""
|
||||
try:
|
||||
with sqlite3.connect(self.db_path) as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute(
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS long_term_memories (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
task_description TEXT,
|
||||
metadata TEXT,
|
||||
datetime TEXT,
|
||||
score REAL
|
||||
)
|
||||
"""
|
||||
)
|
||||
|
||||
conn.commit()
|
||||
except sqlite3.Error as e:
|
||||
self._printer.print(
|
||||
content=f"MEMORY ERROR: An error occurred during database initialization: {e}",
|
||||
color="red",
|
||||
)
|
||||
|
||||
def save(
|
||||
self,
|
||||
task_description: str,
|
||||
metadata: Dict[str, Any],
|
||||
datetime: str,
|
||||
score: Union[int, float],
|
||||
) -> None:
|
||||
"""Saves data to the LTM table with error handling."""
|
||||
try:
|
||||
with sqlite3.connect(self.db_path) as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute(
|
||||
"""
|
||||
INSERT INTO long_term_memories (task_description, metadata, datetime, score)
|
||||
VALUES (?, ?, ?, ?)
|
||||
""",
|
||||
(task_description, json.dumps(metadata), datetime, score),
|
||||
)
|
||||
conn.commit()
|
||||
except sqlite3.Error as e:
|
||||
self._printer.print(
|
||||
content=f"MEMORY ERROR: An error occurred while saving to LTM: {e}",
|
||||
color="red",
|
||||
)
|
||||
|
||||
def load(self, task_description: str) -> Dict[str, Any]:
|
||||
"""Queries the LTM table by task description with error handling."""
|
||||
try:
|
||||
with sqlite3.connect(self.db_path) as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute(
|
||||
"""
|
||||
SELECT metadata, datetime, score
|
||||
FROM long_term_memories
|
||||
WHERE task_description = ?
|
||||
ORDER BY datetime DESC, score ASC
|
||||
LIMIT 2
|
||||
""",
|
||||
(task_description,),
|
||||
)
|
||||
rows = cursor.fetchall()
|
||||
if rows:
|
||||
return [
|
||||
{
|
||||
"metadata": json.loads(row[0]),
|
||||
"datetime": row[1],
|
||||
"score": row[2],
|
||||
}
|
||||
for row in rows
|
||||
]
|
||||
|
||||
except sqlite3.Error as e:
|
||||
self._printer.print(
|
||||
content=f"MEMORY ERROR: An error occurred while querying LTM: {e}",
|
||||
color="red",
|
||||
)
|
||||
return None
|
||||
87
src/crewai/memory/storage/rag_storage.py
Normal file
87
src/crewai/memory/storage/rag_storage.py
Normal file
@@ -0,0 +1,87 @@
|
||||
import contextlib
|
||||
import io
|
||||
import logging
|
||||
from typing import Any, Dict
|
||||
|
||||
from embedchain import App
|
||||
from embedchain.llm.base import BaseLlm
|
||||
|
||||
from crewai.memory.storage.interface import Storage
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def suppress_logging(
|
||||
logger_name="chromadb.segment.impl.vector.local_persistent_hnsw",
|
||||
level=logging.ERROR,
|
||||
):
|
||||
logger = logging.getLogger(logger_name)
|
||||
original_level = logger.getEffectiveLevel()
|
||||
logger.setLevel(level)
|
||||
with contextlib.redirect_stdout(io.StringIO()), contextlib.redirect_stderr(
|
||||
io.StringIO()
|
||||
), contextlib.suppress(UserWarning):
|
||||
yield
|
||||
logger.setLevel(original_level)
|
||||
|
||||
|
||||
class FakeLLM(BaseLlm):
|
||||
pass
|
||||
|
||||
|
||||
class RAGStorage(Storage):
|
||||
"""
|
||||
Extends Storage to handle embeddings for memory entries, improving
|
||||
search efficiency.
|
||||
"""
|
||||
|
||||
def __init__(self, type, allow_reset=True, embedder_config=None):
|
||||
super().__init__()
|
||||
config = {
|
||||
"app": {
|
||||
"config": {"name": type, "collect_metrics": False, "log_level": "ERROR"}
|
||||
},
|
||||
"chunker": {
|
||||
"chunk_size": 5000,
|
||||
"chunk_overlap": 100,
|
||||
"length_function": "len",
|
||||
"min_chunk_size": 150,
|
||||
},
|
||||
"vectordb": {
|
||||
"provider": "chroma",
|
||||
"config": {
|
||||
"collection_name": type,
|
||||
"dir": f".db/{type}",
|
||||
"allow_reset": allow_reset,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
if embedder_config:
|
||||
config["embedder"] = embedder_config
|
||||
|
||||
self.app = App.from_config(config=config)
|
||||
self.app.llm = FakeLLM()
|
||||
if allow_reset:
|
||||
self.app.reset()
|
||||
|
||||
def save(self, value: Any, metadata: Dict[str, Any]) -> None:
|
||||
self._generate_embedding(value, metadata)
|
||||
|
||||
def search(
|
||||
self,
|
||||
query: str,
|
||||
limit: int = 3,
|
||||
filter: dict = None,
|
||||
score_threshold: float = 0.35,
|
||||
) -> Dict[str, Any]:
|
||||
with suppress_logging():
|
||||
results = (
|
||||
self.app.search(query, limit, where=filter)
|
||||
if filter
|
||||
else self.app.search(query, limit)
|
||||
)
|
||||
return [r for r in results if r["metadata"]["score"] >= score_threshold]
|
||||
|
||||
def _generate_embedding(self, text: str, metadata: Dict[str, Any]) -> Any:
|
||||
with suppress_logging():
|
||||
self.app.add(text, data_type="text", metadata=metadata)
|
||||
@@ -24,6 +24,7 @@ class Task(BaseModel):
|
||||
delegations: int = 0
|
||||
i18n: I18N = I18N()
|
||||
thread: threading.Thread = None
|
||||
prompt_context: Optional[str] = None
|
||||
description: str = Field(description="Description of the actual task.")
|
||||
expected_output: str = Field(
|
||||
description="Clear definition of expected output for the task."
|
||||
@@ -144,6 +145,7 @@ class Task(BaseModel):
|
||||
context.append(task.output.raw_output)
|
||||
context = "\n".join(context)
|
||||
|
||||
self.prompt_context = context
|
||||
tools = tools or self.tools
|
||||
|
||||
if self.async_execution:
|
||||
|
||||
@@ -3,7 +3,6 @@ import importlib.resources
|
||||
import json
|
||||
import os
|
||||
import platform
|
||||
import ssl
|
||||
from typing import Any
|
||||
|
||||
import pkg_resources
|
||||
@@ -48,21 +47,17 @@ class Telemetry:
|
||||
attributes={SERVICE_NAME: "crewAI-telemetry"},
|
||||
)
|
||||
self.provider = TracerProvider(resource=self.resource)
|
||||
|
||||
cert_file = importlib.resources.files("crewai.telemetry").joinpath(
|
||||
"STAR_crewai_com_bundle.pem"
|
||||
)
|
||||
ssl_context = ssl.create_default_context()
|
||||
with cert_file.open("rb") as cert:
|
||||
ssl_context.load_verify_locations(cadata=cert.read())
|
||||
|
||||
processor = BatchSpanProcessor(
|
||||
OTLPSpanExporter(
|
||||
endpoint=f"{telemetry_endpoint}/v1/traces",
|
||||
ssl_context=ssl_context,
|
||||
certificate_file=cert_file,
|
||||
timeout=30,
|
||||
)
|
||||
)
|
||||
|
||||
self.provider.add_span_processor(processor)
|
||||
self.ready = True
|
||||
except BaseException as e:
|
||||
@@ -114,7 +109,9 @@ class Telemetry:
|
||||
"i18n": agent.i18n.language,
|
||||
"llm": json.dumps(self._safe_llm_attributes(agent.llm)),
|
||||
"delegation_enabled?": agent.allow_delegation,
|
||||
"tools_names": [tool.name for tool in agent.tools],
|
||||
"tools_names": [
|
||||
tool.name.casefold() for tool in agent.tools
|
||||
],
|
||||
}
|
||||
for agent in crew.agents
|
||||
]
|
||||
@@ -129,7 +126,9 @@ class Telemetry:
|
||||
"id": str(task.id),
|
||||
"async_execution?": task.async_execution,
|
||||
"agent_role": task.agent.role if task.agent else "None",
|
||||
"tools_names": [tool.name for tool in task.tools],
|
||||
"tools_names": [
|
||||
tool.name.casefold() for tool in task.tools
|
||||
],
|
||||
}
|
||||
for task in crew.tasks
|
||||
]
|
||||
@@ -217,7 +216,9 @@ class Telemetry:
|
||||
"i18n": agent.i18n.language,
|
||||
"llm": json.dumps(self._safe_llm_attributes(agent.llm)),
|
||||
"delegation_enabled?": agent.allow_delegation,
|
||||
"tools_names": [tool.name for tool in agent.tools],
|
||||
"tools_names": [
|
||||
tool.name.casefold() for tool in agent.tools
|
||||
],
|
||||
}
|
||||
for agent in crew.agents
|
||||
]
|
||||
@@ -237,7 +238,9 @@ class Telemetry:
|
||||
"context": [task.description for task in task.context]
|
||||
if task.context
|
||||
else "None",
|
||||
"tools_names": [tool.name for tool in task.tools],
|
||||
"tools_names": [
|
||||
tool.name.casefold() for tool in task.tools
|
||||
],
|
||||
}
|
||||
for task in crew.tasks
|
||||
]
|
||||
|
||||
@@ -163,8 +163,6 @@ class ToolUsage:
|
||||
|
||||
if self.tools_handler:
|
||||
should_cache = True
|
||||
print("FORA")
|
||||
print(tool)
|
||||
original_tool = next(
|
||||
(ot for ot in self.original_tools if ot.name == tool.name), None
|
||||
)
|
||||
@@ -172,8 +170,6 @@ class ToolUsage:
|
||||
hasattr(original_tool, "cache_function")
|
||||
and original_tool.cache_function
|
||||
):
|
||||
print("CARALHOOOO")
|
||||
print(original_tool.cache_function)
|
||||
should_cache = original_tool.cache_function(
|
||||
calling.arguments, result
|
||||
)
|
||||
|
||||
@@ -6,8 +6,8 @@
|
||||
},
|
||||
"slices": {
|
||||
"observation": "\nObservation",
|
||||
"task": "Current Task: {input}\n\nBegin! This is VERY important to you, use the tools available and give your best Final Answer, your job depends on it!\n\nThought: ",
|
||||
"memory": "This is the summary of your work so far:\n{chat_history}",
|
||||
"task": "\nCurrent Task: {input}\n\nBegin! This is VERY important to you, use the tools available and give your best Final Answer, your job depends on it!\n\nThought: ",
|
||||
"memory": "\n\n# Useful context: \n{memory}",
|
||||
"role_playing": "You are {role}. {backstory}\nYour personal goal is: {goal}",
|
||||
"tools": "\nYou ONLY have access to the following tools, and should NEVER make up tools that are not listed here:\n\n{tools}\n\nUse the following format:\n\nThought: you should always think about what to do\nAction: the action to take, only one name of [{tool_names}], just the name, exactly as it's written.\nAction Input: the input to the action, just a simple a python dictionary using \" to wrap keys and values.\nObservation: the result of the action\n\nOnce all necessary information is gathered:\n\nThought: I now know the final answer\nFinal Answer: the final answer to the original input question\n",
|
||||
"no_tools": "To give my best complete final answer to the task use the exact following format:\n\nThought: I now can give a great answer\nFinal Answer: my best complete final answer to the task.\nYour final answer must be the great and the most complete as possible, it must be outcome described.\n\nI MUST use these formats, my job depends on it!",
|
||||
|
||||
61
src/crewai/utilities/evaluators/task_evaluator.py
Normal file
61
src/crewai/utilities/evaluators/task_evaluator.py
Normal file
@@ -0,0 +1,61 @@
|
||||
from typing import List
|
||||
|
||||
from langchain_openai import ChatOpenAI
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from crewai.utilities import Converter
|
||||
from crewai.utilities.pydantic_schema_parser import PydanticSchemaParser
|
||||
|
||||
|
||||
class Entity(BaseModel):
|
||||
name: str = Field(description="The name of the entity.")
|
||||
type: str = Field(description="The type of the entity.")
|
||||
description: str = Field(description="Description of the entity.")
|
||||
relationships: List[str] = Field(description="Relationships of the entity.")
|
||||
|
||||
|
||||
class TaskEvaluation(BaseModel):
|
||||
suggestions: List[str] = Field(
|
||||
description="Suggestions to improve future similar tasks."
|
||||
)
|
||||
quality: float = Field(
|
||||
description="A score from 0 to 10 evaluating on completion, quality, and overall performance, all taking into account the task description, expected output, and the result of the task."
|
||||
)
|
||||
entities: List[Entity] = Field(
|
||||
description="Entities extracted from the task output."
|
||||
)
|
||||
|
||||
|
||||
class TaskEvaluator:
|
||||
def __init__(self, original_agent):
|
||||
self.llm = original_agent.llm
|
||||
|
||||
def evaluate(self, task, ouput) -> TaskEvaluation:
|
||||
evaluation_query = (
|
||||
f"Assess the quality of the task completed based on the description, expected output, and actual results.\n\n"
|
||||
f"Task Description:\n{task.description}\n\n"
|
||||
f"Expected Output:\n{task.expected_output}\n\n"
|
||||
f"Actual Output:\n{ouput}\n\n"
|
||||
"Please provide:\n"
|
||||
"- Bullet points suggestions to improve future similar tasks\n"
|
||||
"- A score from 0 to 10 evaluating on completion, quality, and overall performance"
|
||||
"- Entities extracted from the task output, if any, their type, description, and relationships"
|
||||
)
|
||||
|
||||
instructions = "I'm gonna convert this raw text into valid JSON."
|
||||
|
||||
if not self._is_gpt(self.llm):
|
||||
model_schema = PydanticSchemaParser(model=TaskEvaluation).get_schema()
|
||||
instructions = f"{instructions}\n\nThe json should have the following structure, with the following keys:\n{model_schema}"
|
||||
|
||||
converter = Converter(
|
||||
llm=self.llm,
|
||||
text=evaluation_query,
|
||||
model=TaskEvaluation,
|
||||
instructions=instructions,
|
||||
)
|
||||
|
||||
return converter.to_pydantic()
|
||||
|
||||
def _is_gpt(self, llm) -> bool:
|
||||
return isinstance(llm, ChatOpenAI) and llm.openai_api_base == None
|
||||
@@ -13,16 +13,6 @@ class Prompts(BaseModel):
|
||||
tools: list[Any] = Field(default=[])
|
||||
SCRATCHPAD_SLICE: ClassVar[str] = "\n{agent_scratchpad}"
|
||||
|
||||
def task_execution_with_memory(self) -> BasePromptTemplate:
|
||||
"""Generate a prompt for task execution with memory components."""
|
||||
slices = ["role_playing"]
|
||||
if len(self.tools) > 0:
|
||||
slices.append("tools")
|
||||
else:
|
||||
slices.append("no_tools")
|
||||
slices.extend(["memory", "task"])
|
||||
return self._build_prompt(slices)
|
||||
|
||||
def task_execution_without_tools(self) -> BasePromptTemplate:
|
||||
"""Generate a prompt for task execution without tools components."""
|
||||
return self._build_prompt(["role_playing", "task"])
|
||||
|
||||
@@ -48,36 +48,6 @@ def test_custom_llm():
|
||||
assert agent.llm.temperature == 0
|
||||
|
||||
|
||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||
def test_agent_without_memory():
|
||||
no_memory_agent = Agent(
|
||||
role="test role",
|
||||
goal="test goal",
|
||||
backstory="test backstory",
|
||||
memory=False,
|
||||
llm=ChatOpenAI(temperature=0, model="gpt-4"),
|
||||
)
|
||||
|
||||
memory_agent = Agent(
|
||||
role="test role",
|
||||
goal="test goal",
|
||||
backstory="test backstory",
|
||||
memory=True,
|
||||
llm=ChatOpenAI(temperature=0, model="gpt-4"),
|
||||
)
|
||||
|
||||
task = Task(
|
||||
description="How much is 1 + 1?",
|
||||
agent=no_memory_agent,
|
||||
expected_output="the result of the math operation.",
|
||||
)
|
||||
result = no_memory_agent.execute_task(task)
|
||||
|
||||
assert result == "The result of the math operation 1 + 1 is 2."
|
||||
assert no_memory_agent.agent_executor.memory is None
|
||||
assert memory_agent.agent_executor.memory is not None
|
||||
|
||||
|
||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||
def test_agent_execution():
|
||||
agent = Agent(
|
||||
@@ -403,7 +373,6 @@ def test_agent_repeated_tool_usage_check_even_with_disabled_cache(capsys):
|
||||
)
|
||||
|
||||
captured = capsys.readouterr()
|
||||
print(captured.out)
|
||||
assert (
|
||||
"I tried reusing the same input, I must stop using this action input. I'll try something else instead."
|
||||
in captured.out
|
||||
|
||||
@@ -648,9 +648,9 @@ def test_agent_usage_metrics_are_captured_for_sequential_process():
|
||||
assert result == "Howdy!"
|
||||
assert crew.usage_metrics == {
|
||||
"completion_tokens": 56,
|
||||
"prompt_tokens": 164,
|
||||
"prompt_tokens": 161,
|
||||
"successful_requests": 1,
|
||||
"total_tokens": 220,
|
||||
"total_tokens": 217,
|
||||
}
|
||||
|
||||
|
||||
@@ -677,8 +677,8 @@ def test_agent_usage_metrics_are_captured_for_hierarchical_process():
|
||||
result = crew.kickoff()
|
||||
assert result == "Howdy!"
|
||||
assert crew.usage_metrics == {
|
||||
"total_tokens": 1513,
|
||||
"prompt_tokens": 1299,
|
||||
"total_tokens": 1510,
|
||||
"prompt_tokens": 1296,
|
||||
"completion_tokens": 214,
|
||||
"successful_requests": 3,
|
||||
}
|
||||
@@ -735,6 +735,36 @@ def test_crew_inputs_interpolate_both_agents_and_tasks():
|
||||
interpolate_task_inputs.assert_called()
|
||||
|
||||
|
||||
def test_task_callback_on_crew():
|
||||
from unittest.mock import patch
|
||||
|
||||
researcher_agent = Agent(
|
||||
role="Researcher",
|
||||
goal="Make the best research and analysis on content about AI and AI agents",
|
||||
backstory="You're an expert researcher, specialized in technology, software engineering, AI and startups. You work as a freelancer and is now working on doing research and analysis for a new customer.",
|
||||
allow_delegation=False,
|
||||
)
|
||||
|
||||
list_ideas = Task(
|
||||
description="Give me a list of 5 interesting ideas to explore for na article, what makes them unique and interesting.",
|
||||
expected_output="Bullet point list of 5 important events.",
|
||||
agent=researcher_agent,
|
||||
async_execution=True,
|
||||
)
|
||||
|
||||
crew = Crew(
|
||||
agents=[researcher_agent],
|
||||
process=Process.sequential,
|
||||
tasks=[list_ideas],
|
||||
task_callback=lambda: None,
|
||||
)
|
||||
|
||||
with patch.object(Agent, "execute_task") as execute:
|
||||
execute.return_value = "ok"
|
||||
crew.kickoff()
|
||||
assert list_ideas.callback is not None
|
||||
|
||||
|
||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||
def test_tools_with_custom_caching():
|
||||
from unittest.mock import patch
|
||||
@@ -748,7 +778,6 @@ def test_tools_with_custom_caching():
|
||||
|
||||
def cache_func(args, result):
|
||||
cache = result % 2 == 0
|
||||
print(f"cache?: {cache}")
|
||||
return cache
|
||||
|
||||
multiplcation_tool.cache_function = cache_func
|
||||
|
||||
29
tests/memory/long_term_memory_test.py
Normal file
29
tests/memory/long_term_memory_test.py
Normal file
@@ -0,0 +1,29 @@
|
||||
import pytest
|
||||
|
||||
from crewai.memory.long_term.long_term_memory import LongTermMemory
|
||||
from crewai.memory.long_term.long_term_memory_item import LongTermMemoryItem
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def long_term_memory():
|
||||
"""Fixture to create a LongTermMemory instance"""
|
||||
return LongTermMemory()
|
||||
|
||||
|
||||
def test_save_and_search(long_term_memory):
|
||||
memory = LongTermMemoryItem(
|
||||
agent="test_agent",
|
||||
task="test_task",
|
||||
expected_output="test_output",
|
||||
datetime="test_datetime",
|
||||
quality=0.5,
|
||||
metadata={"task": "test_task", "quality": 0.5},
|
||||
)
|
||||
long_term_memory.save(memory)
|
||||
find = long_term_memory.search("test_task")[0]
|
||||
assert find["score"] == 0.5
|
||||
assert find["datetime"] == "test_datetime"
|
||||
assert find["metadata"]["agent"] == "test_agent"
|
||||
assert find["metadata"]["quality"] == 0.5
|
||||
assert find["metadata"]["task"] == "test_task"
|
||||
assert find["metadata"]["expected_output"] == "test_output"
|
||||
24
tests/memory/short_term_memory_test.py
Normal file
24
tests/memory/short_term_memory_test.py
Normal file
@@ -0,0 +1,24 @@
|
||||
import pytest
|
||||
|
||||
from crewai.memory.short_term.short_term_memory import ShortTermMemory
|
||||
from crewai.memory.short_term.short_term_memory_item import ShortTermMemoryItem
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def short_term_memory():
|
||||
"""Fixture to create a ShortTermMemory instance"""
|
||||
return ShortTermMemory()
|
||||
|
||||
|
||||
def test_save_and_search(short_term_memory):
|
||||
memory = ShortTermMemoryItem(
|
||||
data="""test value test value test value test value test value test value
|
||||
test value test value test value test value test value test value
|
||||
test value test value test value test value test value test value""",
|
||||
agent="test_agent",
|
||||
metadata={"task": "test_task"},
|
||||
)
|
||||
short_term_memory.save(memory)
|
||||
find = short_term_memory.search("test value", score_threshold=0.01)[0]
|
||||
assert find["context"] == memory.data, "Data value mismatch."
|
||||
assert find["metadata"]["agent"] == "test_agent", "Agent value mismatch."
|
||||
Reference in New Issue
Block a user