Ability to disable cache at agent and crew level

This commit is contained in:
João Moura
2024-03-19 12:47:34 -03:00
parent 6d36f66a00
commit 81d5fe6fc6
6 changed files with 1980 additions and 13 deletions

View File

@@ -66,6 +66,10 @@ class Agent(BaseModel):
role: str = Field(description="Role of the agent")
goal: str = Field(description="Objective of the agent")
backstory: str = Field(description="Backstory of the agent")
cache: bool = Field(
default=True,
description="Whether the agent should use a cache for tool usage.",
)
config: Optional[Dict[str, Any]] = Field(
description="Configuration for the agent",
default=None,
@@ -96,7 +100,7 @@ class Agent(BaseModel):
default=None, description="An instance of the ToolsHandler class."
)
cache_handler: InstanceOf[CacheHandler] = Field(
default=CacheHandler(), description="An instance of the CacheHandler class."
default=None, description="An instance of the CacheHandler class."
)
step_callback: Optional[Any] = Field(
default=None,
@@ -158,6 +162,8 @@ class Agent(BaseModel):
TokenCalcHandler(self.llm.model_name, self._token_process)
]
if not self.agent_executor:
if not self.cache_handler:
self.cache_handler = CacheHandler()
self.set_cache_handler(self.cache_handler)
return self
@@ -177,7 +183,8 @@ class Agent(BaseModel):
Returns:
Output of the agent
"""
self.tools_handler.last_used_tool = {}
if self.tools_handler:
self.tools_handler.last_used_tool = {}
task_prompt = task.prompt()
@@ -213,9 +220,10 @@ class Agent(BaseModel):
Args:
cache_handler: An instance of the CacheHandler class.
"""
self.cache_handler = cache_handler
self.tools_handler = ToolsHandler(cache=self.cache_handler)
self.create_agent_executor()
if self.cache:
self.cache_handler = cache_handler
self.tools_handler = ToolsHandler(cache=self.cache_handler)
self.create_agent_executor()
def set_rpm_controller(self, rpm_controller: RPMController) -> None:
"""Set the rpm controller for the agent.

View File

@@ -198,7 +198,7 @@ class CrewAgentExecutor(AgentExecutor):
for agent_action in actions:
if run_manager:
run_manager.on_agent_action(agent_action, color="green")
# Otherwise we lookup the tool
tool_usage = ToolUsage(
tools_handler=self.tools_handler,
tools=self.tools,

View File

@@ -34,8 +34,9 @@ class Crew(BaseModel):
agents: List of agents part of this crew.
manager_llm: The language model that will run manager agent.
manager_callbacks: The callback handlers to be executed by the manager agent when hierarchical process is used
cache: Whether the crew should use a cache to store the results of the tools execution.
function_calling_llm: The language model that will run the tool calling for all the agents.
process: The process flow that the crew will follow (e.g., sequential).
process: The process flow that the crew will follow (e.g., sequential, hierarchical).
verbose: Indicates the verbosity level for logging during execution.
config: Configuration settings for the crew.
max_rpm: Maximum number of requests per minute for the crew execution to be respected.
@@ -50,6 +51,7 @@ class Crew(BaseModel):
_rpm_controller: RPMController = PrivateAttr()
_logger: Logger = PrivateAttr()
_cache_handler: InstanceOf[CacheHandler] = PrivateAttr(default=CacheHandler())
cache: bool = Field(default=True)
model_config = ConfigDict(arbitrary_types_allowed=True)
tasks: List[Task] = Field(default_factory=list)
agents: List[Agent] = Field(default_factory=list)
@@ -150,7 +152,8 @@ class Crew(BaseModel):
if self.agents:
for agent in self.agents:
agent.set_cache_handler(self._cache_handler)
if self.cache:
agent.set_cache_handler(self._cache_handler)
if self.max_rpm:
agent.set_rpm_controller(self._rpm_controller)
return self

View File

@@ -109,9 +109,12 @@ class ToolUsage:
except Exception:
self.task.increment_tools_errors()
result = self.tools_handler.cache.read(
tool=calling.tool_name, input=calling.arguments
)
result = None
if self.tools_handler:
result = self.tools_handler.cache.read(
tool=calling.tool_name, input=calling.arguments
)
if not result:
try:
@@ -155,7 +158,8 @@ class ToolUsage:
self.task.increment_tools_errors()
return self.use(calling=calling, tool_string=tool_string)
self.tools_handler.on_tool_use(calling=calling, output=result)
if self.tools_handler:
self.tools_handler.on_tool_use(calling=calling, output=result)
self._printer.print(content=f"\n\n{result}\n", color="yellow")
self._telemetry.tool_usage(
@@ -185,6 +189,8 @@ class ToolUsage:
def _check_tool_repeated_usage(
self, calling: Union[ToolCalling, InstructorToolCalling]
) -> None:
if not self.tools_handler:
return False
if last_tool_usage := self.tools_handler.last_used_tool:
return (calling.tool_name == last_tool_usage.tool_name) and (
calling.arguments == last_tool_usage.arguments

View File

@@ -208,7 +208,7 @@ def test_cache_hitting():
with patch.object(CacheHandler, "read") as read:
read.return_value = "0"
task = Task(
description="What is 2 times 6? Ignore correctness and just return the result of the multiplication tool.",
description="What is 2 times 6? Ignore correctness and just return the result of the multiplication tool, you must use the tool.",
agent=agent,
expected_output="The result of the multiplication.",
)
@@ -219,6 +219,70 @@ def test_cache_hitting():
)
@pytest.mark.vcr(filter_headers=["authorization"])
def test_disabling_cache_for_agent():
@tool
def multiplier(first_number: int, second_number: int) -> float:
"""Useful for when you need to multiply two numbers together."""
return first_number * second_number
cache_handler = CacheHandler()
agent = Agent(
role="test role",
goal="test goal",
backstory="test backstory",
tools=[multiplier],
allow_delegation=False,
cache_handler=cache_handler,
cache=False,
verbose=True,
)
task1 = Task(
description="What is 2 times 6?",
agent=agent,
expected_output="The result of the multiplication.",
)
task2 = Task(
description="What is 3 times 3?",
agent=agent,
expected_output="The result of the multiplication.",
)
output = agent.execute_task(task1)
output = agent.execute_task(task2)
assert cache_handler._cache != {
"multiplier-{'first_number': 2, 'second_number': 6}": 12,
"multiplier-{'first_number': 3, 'second_number': 3}": 9,
}
task = Task(
description="What is 2 times 6 times 3? Return only the number",
agent=agent,
expected_output="The result of the multiplication.",
)
output = agent.execute_task(task)
assert output == "36"
assert cache_handler._cache != {
"multiplier-{'first_number': 2, 'second_number': 6}": 12,
"multiplier-{'first_number': 3, 'second_number': 3}": 9,
"multiplier-{'first_number': 12, 'second_number': 3}": 36,
}
with patch.object(CacheHandler, "read") as read:
read.return_value = "0"
task = Task(
description="What is 2 times 6? Ignore correctness and just return the result of the multiplication tool.",
agent=agent,
expected_output="The result of the multiplication.",
)
output = agent.execute_task(task)
assert output == "12"
read.assert_not_called()
@pytest.mark.vcr(filter_headers=["authorization"])
def test_agent_execution_with_specific_tools():
@tool

File diff suppressed because it is too large Load Diff