mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-08 15:48:29 +00:00
Ability to disable cache at agent and crew level
This commit is contained in:
@@ -66,6 +66,10 @@ class Agent(BaseModel):
|
||||
role: str = Field(description="Role of the agent")
|
||||
goal: str = Field(description="Objective of the agent")
|
||||
backstory: str = Field(description="Backstory of the agent")
|
||||
cache: bool = Field(
|
||||
default=True,
|
||||
description="Whether the agent should use a cache for tool usage.",
|
||||
)
|
||||
config: Optional[Dict[str, Any]] = Field(
|
||||
description="Configuration for the agent",
|
||||
default=None,
|
||||
@@ -96,7 +100,7 @@ class Agent(BaseModel):
|
||||
default=None, description="An instance of the ToolsHandler class."
|
||||
)
|
||||
cache_handler: InstanceOf[CacheHandler] = Field(
|
||||
default=CacheHandler(), description="An instance of the CacheHandler class."
|
||||
default=None, description="An instance of the CacheHandler class."
|
||||
)
|
||||
step_callback: Optional[Any] = Field(
|
||||
default=None,
|
||||
@@ -158,6 +162,8 @@ class Agent(BaseModel):
|
||||
TokenCalcHandler(self.llm.model_name, self._token_process)
|
||||
]
|
||||
if not self.agent_executor:
|
||||
if not self.cache_handler:
|
||||
self.cache_handler = CacheHandler()
|
||||
self.set_cache_handler(self.cache_handler)
|
||||
return self
|
||||
|
||||
@@ -177,7 +183,8 @@ class Agent(BaseModel):
|
||||
Returns:
|
||||
Output of the agent
|
||||
"""
|
||||
self.tools_handler.last_used_tool = {}
|
||||
if self.tools_handler:
|
||||
self.tools_handler.last_used_tool = {}
|
||||
|
||||
task_prompt = task.prompt()
|
||||
|
||||
@@ -213,9 +220,10 @@ class Agent(BaseModel):
|
||||
Args:
|
||||
cache_handler: An instance of the CacheHandler class.
|
||||
"""
|
||||
self.cache_handler = cache_handler
|
||||
self.tools_handler = ToolsHandler(cache=self.cache_handler)
|
||||
self.create_agent_executor()
|
||||
if self.cache:
|
||||
self.cache_handler = cache_handler
|
||||
self.tools_handler = ToolsHandler(cache=self.cache_handler)
|
||||
self.create_agent_executor()
|
||||
|
||||
def set_rpm_controller(self, rpm_controller: RPMController) -> None:
|
||||
"""Set the rpm controller for the agent.
|
||||
|
||||
@@ -198,7 +198,7 @@ class CrewAgentExecutor(AgentExecutor):
|
||||
for agent_action in actions:
|
||||
if run_manager:
|
||||
run_manager.on_agent_action(agent_action, color="green")
|
||||
# Otherwise we lookup the tool
|
||||
|
||||
tool_usage = ToolUsage(
|
||||
tools_handler=self.tools_handler,
|
||||
tools=self.tools,
|
||||
|
||||
@@ -34,8 +34,9 @@ class Crew(BaseModel):
|
||||
agents: List of agents part of this crew.
|
||||
manager_llm: The language model that will run manager agent.
|
||||
manager_callbacks: The callback handlers to be executed by the manager agent when hierarchical process is used
|
||||
cache: Whether the crew should use a cache to store the results of the tools execution.
|
||||
function_calling_llm: The language model that will run the tool calling for all the agents.
|
||||
process: The process flow that the crew will follow (e.g., sequential).
|
||||
process: The process flow that the crew will follow (e.g., sequential, hierarchical).
|
||||
verbose: Indicates the verbosity level for logging during execution.
|
||||
config: Configuration settings for the crew.
|
||||
max_rpm: Maximum number of requests per minute for the crew execution to be respected.
|
||||
@@ -50,6 +51,7 @@ class Crew(BaseModel):
|
||||
_rpm_controller: RPMController = PrivateAttr()
|
||||
_logger: Logger = PrivateAttr()
|
||||
_cache_handler: InstanceOf[CacheHandler] = PrivateAttr(default=CacheHandler())
|
||||
cache: bool = Field(default=True)
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
tasks: List[Task] = Field(default_factory=list)
|
||||
agents: List[Agent] = Field(default_factory=list)
|
||||
@@ -150,7 +152,8 @@ class Crew(BaseModel):
|
||||
|
||||
if self.agents:
|
||||
for agent in self.agents:
|
||||
agent.set_cache_handler(self._cache_handler)
|
||||
if self.cache:
|
||||
agent.set_cache_handler(self._cache_handler)
|
||||
if self.max_rpm:
|
||||
agent.set_rpm_controller(self._rpm_controller)
|
||||
return self
|
||||
|
||||
@@ -109,9 +109,12 @@ class ToolUsage:
|
||||
except Exception:
|
||||
self.task.increment_tools_errors()
|
||||
|
||||
result = self.tools_handler.cache.read(
|
||||
tool=calling.tool_name, input=calling.arguments
|
||||
)
|
||||
result = None
|
||||
|
||||
if self.tools_handler:
|
||||
result = self.tools_handler.cache.read(
|
||||
tool=calling.tool_name, input=calling.arguments
|
||||
)
|
||||
|
||||
if not result:
|
||||
try:
|
||||
@@ -155,7 +158,8 @@ class ToolUsage:
|
||||
self.task.increment_tools_errors()
|
||||
return self.use(calling=calling, tool_string=tool_string)
|
||||
|
||||
self.tools_handler.on_tool_use(calling=calling, output=result)
|
||||
if self.tools_handler:
|
||||
self.tools_handler.on_tool_use(calling=calling, output=result)
|
||||
|
||||
self._printer.print(content=f"\n\n{result}\n", color="yellow")
|
||||
self._telemetry.tool_usage(
|
||||
@@ -185,6 +189,8 @@ class ToolUsage:
|
||||
def _check_tool_repeated_usage(
|
||||
self, calling: Union[ToolCalling, InstructorToolCalling]
|
||||
) -> None:
|
||||
if not self.tools_handler:
|
||||
return False
|
||||
if last_tool_usage := self.tools_handler.last_used_tool:
|
||||
return (calling.tool_name == last_tool_usage.tool_name) and (
|
||||
calling.arguments == last_tool_usage.arguments
|
||||
|
||||
@@ -208,7 +208,7 @@ def test_cache_hitting():
|
||||
with patch.object(CacheHandler, "read") as read:
|
||||
read.return_value = "0"
|
||||
task = Task(
|
||||
description="What is 2 times 6? Ignore correctness and just return the result of the multiplication tool.",
|
||||
description="What is 2 times 6? Ignore correctness and just return the result of the multiplication tool, you must use the tool.",
|
||||
agent=agent,
|
||||
expected_output="The result of the multiplication.",
|
||||
)
|
||||
@@ -219,6 +219,70 @@ def test_cache_hitting():
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||
def test_disabling_cache_for_agent():
|
||||
@tool
|
||||
def multiplier(first_number: int, second_number: int) -> float:
|
||||
"""Useful for when you need to multiply two numbers together."""
|
||||
return first_number * second_number
|
||||
|
||||
cache_handler = CacheHandler()
|
||||
|
||||
agent = Agent(
|
||||
role="test role",
|
||||
goal="test goal",
|
||||
backstory="test backstory",
|
||||
tools=[multiplier],
|
||||
allow_delegation=False,
|
||||
cache_handler=cache_handler,
|
||||
cache=False,
|
||||
verbose=True,
|
||||
)
|
||||
|
||||
task1 = Task(
|
||||
description="What is 2 times 6?",
|
||||
agent=agent,
|
||||
expected_output="The result of the multiplication.",
|
||||
)
|
||||
task2 = Task(
|
||||
description="What is 3 times 3?",
|
||||
agent=agent,
|
||||
expected_output="The result of the multiplication.",
|
||||
)
|
||||
|
||||
output = agent.execute_task(task1)
|
||||
output = agent.execute_task(task2)
|
||||
assert cache_handler._cache != {
|
||||
"multiplier-{'first_number': 2, 'second_number': 6}": 12,
|
||||
"multiplier-{'first_number': 3, 'second_number': 3}": 9,
|
||||
}
|
||||
|
||||
task = Task(
|
||||
description="What is 2 times 6 times 3? Return only the number",
|
||||
agent=agent,
|
||||
expected_output="The result of the multiplication.",
|
||||
)
|
||||
output = agent.execute_task(task)
|
||||
assert output == "36"
|
||||
|
||||
assert cache_handler._cache != {
|
||||
"multiplier-{'first_number': 2, 'second_number': 6}": 12,
|
||||
"multiplier-{'first_number': 3, 'second_number': 3}": 9,
|
||||
"multiplier-{'first_number': 12, 'second_number': 3}": 36,
|
||||
}
|
||||
|
||||
with patch.object(CacheHandler, "read") as read:
|
||||
read.return_value = "0"
|
||||
task = Task(
|
||||
description="What is 2 times 6? Ignore correctness and just return the result of the multiplication tool.",
|
||||
agent=agent,
|
||||
expected_output="The result of the multiplication.",
|
||||
)
|
||||
output = agent.execute_task(task)
|
||||
assert output == "12"
|
||||
read.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||
def test_agent_execution_with_specific_tools():
|
||||
@tool
|
||||
|
||||
1886
tests/cassettes/test_disabling_cache_for_agent.yaml
Normal file
1886
tests/cassettes/test_disabling_cache_for_agent.yaml
Normal file
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user