mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-10 00:28:31 +00:00
Ability to disable cache at agent and crew level
This commit is contained in:
@@ -66,6 +66,10 @@ class Agent(BaseModel):
|
|||||||
role: str = Field(description="Role of the agent")
|
role: str = Field(description="Role of the agent")
|
||||||
goal: str = Field(description="Objective of the agent")
|
goal: str = Field(description="Objective of the agent")
|
||||||
backstory: str = Field(description="Backstory of the agent")
|
backstory: str = Field(description="Backstory of the agent")
|
||||||
|
cache: bool = Field(
|
||||||
|
default=True,
|
||||||
|
description="Whether the agent should use a cache for tool usage.",
|
||||||
|
)
|
||||||
config: Optional[Dict[str, Any]] = Field(
|
config: Optional[Dict[str, Any]] = Field(
|
||||||
description="Configuration for the agent",
|
description="Configuration for the agent",
|
||||||
default=None,
|
default=None,
|
||||||
@@ -96,7 +100,7 @@ class Agent(BaseModel):
|
|||||||
default=None, description="An instance of the ToolsHandler class."
|
default=None, description="An instance of the ToolsHandler class."
|
||||||
)
|
)
|
||||||
cache_handler: InstanceOf[CacheHandler] = Field(
|
cache_handler: InstanceOf[CacheHandler] = Field(
|
||||||
default=CacheHandler(), description="An instance of the CacheHandler class."
|
default=None, description="An instance of the CacheHandler class."
|
||||||
)
|
)
|
||||||
step_callback: Optional[Any] = Field(
|
step_callback: Optional[Any] = Field(
|
||||||
default=None,
|
default=None,
|
||||||
@@ -158,6 +162,8 @@ class Agent(BaseModel):
|
|||||||
TokenCalcHandler(self.llm.model_name, self._token_process)
|
TokenCalcHandler(self.llm.model_name, self._token_process)
|
||||||
]
|
]
|
||||||
if not self.agent_executor:
|
if not self.agent_executor:
|
||||||
|
if not self.cache_handler:
|
||||||
|
self.cache_handler = CacheHandler()
|
||||||
self.set_cache_handler(self.cache_handler)
|
self.set_cache_handler(self.cache_handler)
|
||||||
return self
|
return self
|
||||||
|
|
||||||
@@ -177,6 +183,7 @@ class Agent(BaseModel):
|
|||||||
Returns:
|
Returns:
|
||||||
Output of the agent
|
Output of the agent
|
||||||
"""
|
"""
|
||||||
|
if self.tools_handler:
|
||||||
self.tools_handler.last_used_tool = {}
|
self.tools_handler.last_used_tool = {}
|
||||||
|
|
||||||
task_prompt = task.prompt()
|
task_prompt = task.prompt()
|
||||||
@@ -213,6 +220,7 @@ class Agent(BaseModel):
|
|||||||
Args:
|
Args:
|
||||||
cache_handler: An instance of the CacheHandler class.
|
cache_handler: An instance of the CacheHandler class.
|
||||||
"""
|
"""
|
||||||
|
if self.cache:
|
||||||
self.cache_handler = cache_handler
|
self.cache_handler = cache_handler
|
||||||
self.tools_handler = ToolsHandler(cache=self.cache_handler)
|
self.tools_handler = ToolsHandler(cache=self.cache_handler)
|
||||||
self.create_agent_executor()
|
self.create_agent_executor()
|
||||||
|
|||||||
@@ -198,7 +198,7 @@ class CrewAgentExecutor(AgentExecutor):
|
|||||||
for agent_action in actions:
|
for agent_action in actions:
|
||||||
if run_manager:
|
if run_manager:
|
||||||
run_manager.on_agent_action(agent_action, color="green")
|
run_manager.on_agent_action(agent_action, color="green")
|
||||||
# Otherwise we lookup the tool
|
|
||||||
tool_usage = ToolUsage(
|
tool_usage = ToolUsage(
|
||||||
tools_handler=self.tools_handler,
|
tools_handler=self.tools_handler,
|
||||||
tools=self.tools,
|
tools=self.tools,
|
||||||
|
|||||||
@@ -34,8 +34,9 @@ class Crew(BaseModel):
|
|||||||
agents: List of agents part of this crew.
|
agents: List of agents part of this crew.
|
||||||
manager_llm: The language model that will run manager agent.
|
manager_llm: The language model that will run manager agent.
|
||||||
manager_callbacks: The callback handlers to be executed by the manager agent when hierarchical process is used
|
manager_callbacks: The callback handlers to be executed by the manager agent when hierarchical process is used
|
||||||
|
cache: Whether the crew should use a cache to store the results of the tools execution.
|
||||||
function_calling_llm: The language model that will run the tool calling for all the agents.
|
function_calling_llm: The language model that will run the tool calling for all the agents.
|
||||||
process: The process flow that the crew will follow (e.g., sequential).
|
process: The process flow that the crew will follow (e.g., sequential, hierarchical).
|
||||||
verbose: Indicates the verbosity level for logging during execution.
|
verbose: Indicates the verbosity level for logging during execution.
|
||||||
config: Configuration settings for the crew.
|
config: Configuration settings for the crew.
|
||||||
max_rpm: Maximum number of requests per minute for the crew execution to be respected.
|
max_rpm: Maximum number of requests per minute for the crew execution to be respected.
|
||||||
@@ -50,6 +51,7 @@ class Crew(BaseModel):
|
|||||||
_rpm_controller: RPMController = PrivateAttr()
|
_rpm_controller: RPMController = PrivateAttr()
|
||||||
_logger: Logger = PrivateAttr()
|
_logger: Logger = PrivateAttr()
|
||||||
_cache_handler: InstanceOf[CacheHandler] = PrivateAttr(default=CacheHandler())
|
_cache_handler: InstanceOf[CacheHandler] = PrivateAttr(default=CacheHandler())
|
||||||
|
cache: bool = Field(default=True)
|
||||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||||
tasks: List[Task] = Field(default_factory=list)
|
tasks: List[Task] = Field(default_factory=list)
|
||||||
agents: List[Agent] = Field(default_factory=list)
|
agents: List[Agent] = Field(default_factory=list)
|
||||||
@@ -150,6 +152,7 @@ class Crew(BaseModel):
|
|||||||
|
|
||||||
if self.agents:
|
if self.agents:
|
||||||
for agent in self.agents:
|
for agent in self.agents:
|
||||||
|
if self.cache:
|
||||||
agent.set_cache_handler(self._cache_handler)
|
agent.set_cache_handler(self._cache_handler)
|
||||||
if self.max_rpm:
|
if self.max_rpm:
|
||||||
agent.set_rpm_controller(self._rpm_controller)
|
agent.set_rpm_controller(self._rpm_controller)
|
||||||
|
|||||||
@@ -109,6 +109,9 @@ class ToolUsage:
|
|||||||
except Exception:
|
except Exception:
|
||||||
self.task.increment_tools_errors()
|
self.task.increment_tools_errors()
|
||||||
|
|
||||||
|
result = None
|
||||||
|
|
||||||
|
if self.tools_handler:
|
||||||
result = self.tools_handler.cache.read(
|
result = self.tools_handler.cache.read(
|
||||||
tool=calling.tool_name, input=calling.arguments
|
tool=calling.tool_name, input=calling.arguments
|
||||||
)
|
)
|
||||||
@@ -155,6 +158,7 @@ class ToolUsage:
|
|||||||
self.task.increment_tools_errors()
|
self.task.increment_tools_errors()
|
||||||
return self.use(calling=calling, tool_string=tool_string)
|
return self.use(calling=calling, tool_string=tool_string)
|
||||||
|
|
||||||
|
if self.tools_handler:
|
||||||
self.tools_handler.on_tool_use(calling=calling, output=result)
|
self.tools_handler.on_tool_use(calling=calling, output=result)
|
||||||
|
|
||||||
self._printer.print(content=f"\n\n{result}\n", color="yellow")
|
self._printer.print(content=f"\n\n{result}\n", color="yellow")
|
||||||
@@ -185,6 +189,8 @@ class ToolUsage:
|
|||||||
def _check_tool_repeated_usage(
|
def _check_tool_repeated_usage(
|
||||||
self, calling: Union[ToolCalling, InstructorToolCalling]
|
self, calling: Union[ToolCalling, InstructorToolCalling]
|
||||||
) -> None:
|
) -> None:
|
||||||
|
if not self.tools_handler:
|
||||||
|
return False
|
||||||
if last_tool_usage := self.tools_handler.last_used_tool:
|
if last_tool_usage := self.tools_handler.last_used_tool:
|
||||||
return (calling.tool_name == last_tool_usage.tool_name) and (
|
return (calling.tool_name == last_tool_usage.tool_name) and (
|
||||||
calling.arguments == last_tool_usage.arguments
|
calling.arguments == last_tool_usage.arguments
|
||||||
|
|||||||
@@ -208,7 +208,7 @@ def test_cache_hitting():
|
|||||||
with patch.object(CacheHandler, "read") as read:
|
with patch.object(CacheHandler, "read") as read:
|
||||||
read.return_value = "0"
|
read.return_value = "0"
|
||||||
task = Task(
|
task = Task(
|
||||||
description="What is 2 times 6? Ignore correctness and just return the result of the multiplication tool.",
|
description="What is 2 times 6? Ignore correctness and just return the result of the multiplication tool, you must use the tool.",
|
||||||
agent=agent,
|
agent=agent,
|
||||||
expected_output="The result of the multiplication.",
|
expected_output="The result of the multiplication.",
|
||||||
)
|
)
|
||||||
@@ -219,6 +219,70 @@ def test_cache_hitting():
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||||
|
def test_disabling_cache_for_agent():
|
||||||
|
@tool
|
||||||
|
def multiplier(first_number: int, second_number: int) -> float:
|
||||||
|
"""Useful for when you need to multiply two numbers together."""
|
||||||
|
return first_number * second_number
|
||||||
|
|
||||||
|
cache_handler = CacheHandler()
|
||||||
|
|
||||||
|
agent = Agent(
|
||||||
|
role="test role",
|
||||||
|
goal="test goal",
|
||||||
|
backstory="test backstory",
|
||||||
|
tools=[multiplier],
|
||||||
|
allow_delegation=False,
|
||||||
|
cache_handler=cache_handler,
|
||||||
|
cache=False,
|
||||||
|
verbose=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
task1 = Task(
|
||||||
|
description="What is 2 times 6?",
|
||||||
|
agent=agent,
|
||||||
|
expected_output="The result of the multiplication.",
|
||||||
|
)
|
||||||
|
task2 = Task(
|
||||||
|
description="What is 3 times 3?",
|
||||||
|
agent=agent,
|
||||||
|
expected_output="The result of the multiplication.",
|
||||||
|
)
|
||||||
|
|
||||||
|
output = agent.execute_task(task1)
|
||||||
|
output = agent.execute_task(task2)
|
||||||
|
assert cache_handler._cache != {
|
||||||
|
"multiplier-{'first_number': 2, 'second_number': 6}": 12,
|
||||||
|
"multiplier-{'first_number': 3, 'second_number': 3}": 9,
|
||||||
|
}
|
||||||
|
|
||||||
|
task = Task(
|
||||||
|
description="What is 2 times 6 times 3? Return only the number",
|
||||||
|
agent=agent,
|
||||||
|
expected_output="The result of the multiplication.",
|
||||||
|
)
|
||||||
|
output = agent.execute_task(task)
|
||||||
|
assert output == "36"
|
||||||
|
|
||||||
|
assert cache_handler._cache != {
|
||||||
|
"multiplier-{'first_number': 2, 'second_number': 6}": 12,
|
||||||
|
"multiplier-{'first_number': 3, 'second_number': 3}": 9,
|
||||||
|
"multiplier-{'first_number': 12, 'second_number': 3}": 36,
|
||||||
|
}
|
||||||
|
|
||||||
|
with patch.object(CacheHandler, "read") as read:
|
||||||
|
read.return_value = "0"
|
||||||
|
task = Task(
|
||||||
|
description="What is 2 times 6? Ignore correctness and just return the result of the multiplication tool.",
|
||||||
|
agent=agent,
|
||||||
|
expected_output="The result of the multiplication.",
|
||||||
|
)
|
||||||
|
output = agent.execute_task(task)
|
||||||
|
assert output == "12"
|
||||||
|
read.assert_not_called()
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||||
def test_agent_execution_with_specific_tools():
|
def test_agent_execution_with_specific_tools():
|
||||||
@tool
|
@tool
|
||||||
|
|||||||
1886
tests/cassettes/test_disabling_cache_for_agent.yaml
Normal file
1886
tests/cassettes/test_disabling_cache_for_agent.yaml
Normal file
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user