Revamping tool usage

This commit is contained in:
João Moura
2024-02-10 10:36:34 -08:00
parent d0b0a33be3
commit 00206a62ab
22 changed files with 1233 additions and 584 deletions

View File

@@ -9,6 +9,7 @@ from langchain_openai import ChatOpenAI
from crewai import Agent, Crew, Task
from crewai.agents.cache import CacheHandler
from crewai.agents.executor import CrewAgentExecutor
from crewai.tools.tool_calling import ToolCalling
from crewai.utilities import RPMController
@@ -85,13 +86,9 @@ def test_agent_execution():
@pytest.mark.vcr(filter_headers=["authorization"])
def test_agent_execution_with_tools():
@tool
def multiplier(numbers) -> float:
"""Useful for when you need to multiply two numbers together.
The input to this tool should be a comma separated list of numbers of
length two, representing the two numbers you want to multiply together.
For example, `1,2` would be the input if you wanted to multiply 1 by 2."""
a, b = numbers.split(",")
return int(a) * int(b)
def multiplier(first_number: int, second_number: int) -> float:
    """Useful for when you need to multiply two numbers together."""
    # NOTE(review): presumably the `@tool` decorator consumes the docstring
    # above as runtime text, so it is reproduced verbatim - confirm.
    product = first_number * second_number
    return product
agent = Agent(
role="test role",
@@ -102,19 +99,15 @@ def test_agent_execution_with_tools():
)
output = agent.execute_task("What is 3 times 4")
assert output == "12"
assert output == "3 times 4 is 12."
@pytest.mark.vcr(filter_headers=["authorization"])
def test_logging_tool_usage():
@tool
def multiplier(numbers) -> float:
"""Useful for when you need to multiply two numbers together.
The input to this tool should be a comma separated list of numbers of
length two, representing the two numbers you want to multiply together.
For example, `1,2` would be the input if you wanted to multiply 1 by 2."""
a, b = numbers.split(",")
return int(a) * int(b)
def multiplier(first_number: int, second_number: int) -> float:
    """Useful for when you need to multiply two numbers together."""
    # NOTE(review): presumably the `@tool` decorator consumes the docstring
    # above as runtime text, so it is reproduced verbatim - confirm.
    product = first_number * second_number
    return product
agent = Agent(
role="test role",
@@ -127,10 +120,9 @@ def test_logging_tool_usage():
assert agent.tools_handler.last_used_tool == {}
output = agent.execute_task("What is 3 times 5?")
tool_usage = {
"tool": "multiplier",
"input": "3,5",
}
tool_usage = ToolCalling(
function_name=multiplier.name, arguments={"first_number": 3, "second_number": 5}
)
assert output == "3 times 5 is 15."
assert agent.tools_handler.last_used_tool == tool_usage
@@ -139,13 +131,9 @@ def test_logging_tool_usage():
@pytest.mark.vcr(filter_headers=["authorization"])
def test_cache_hitting():
@tool
def multiplier(numbers) -> float:
"""Useful for when you need to multiply two numbers together.
The input to this tool should be a comma separated list of numbers of
length two and ONLY TWO, representing the two numbers you want to multiply together.
For example, `1,2` would be the input if you wanted to multiply 1 by 2."""
a, b = numbers.split(",")
return int(a) * int(b)
def multiplier(first_number: int, second_number: int) -> float:
    """Useful for when you need to multiply two numbers together."""
    # NOTE(review): presumably the `@tool` decorator consumes the docstring
    # above as runtime text, so it is reproduced verbatim - confirm.
    product = first_number * second_number
    return product
cache_handler = CacheHandler()
@@ -162,9 +150,9 @@ def test_cache_hitting():
output = agent.execute_task("What is 2 times 6 times 3?")
output = agent.execute_task("What is 3 times 3?")
assert cache_handler._cache == {
"multiplier-12,3": "36",
"multiplier-2,6": "12",
"multiplier-3,3": "9",
"multiplier-{'first_number': 12, 'second_number': 3}": 36,
"multiplier-{'first_number': 2, 'second_number': 6}": 12,
"multiplier-{'first_number': 3, 'second_number': 3}": 9,
}
output = agent.execute_task("What is 2 times 6 times 3? Return only the number")
@@ -172,21 +160,21 @@ def test_cache_hitting():
with patch.object(CacheHandler, "read") as read:
read.return_value = "0"
output = agent.execute_task("What is 2 times 6?")
output = agent.execute_task(
"What is 2 times 6? Ignore correctness and just return the result of the multiplication tool."
)
assert output == "0"
read.assert_called_with("multiplier", "2,6")
read.assert_called_with(
tool="multiplier", input={"first_number": 2, "second_number": 6}
)
@pytest.mark.vcr(filter_headers=["authorization"])
def test_agent_execution_with_specific_tools():
@tool
def multiplier(numbers) -> float:
"""Useful for when you need to multiply two numbers together.
The input to this tool should be a comma separated list of numbers of
length two, representing the two numbers you want to multiply together.
For example, `1,2` would be the input if you wanted to multiply 1 by 2."""
a, b = numbers.split(",")
return int(a) * int(b)
def multiplier(first_number: int, second_number: int) -> float:
    """Useful for when you need to multiply two numbers together."""
    # NOTE(review): presumably the `@tool` decorator consumes the docstring
    # above as runtime text, so it is reproduced verbatim - confirm.
    product = first_number * second_number
    return product
agent = Agent(
role="test role",
@@ -225,6 +213,34 @@ def test_agent_custom_max_iterations():
private_mock.assert_called_once()
@pytest.mark.vcr(filter_headers=["authorization"])
def test_agent_repeated_tool_usage(capsys):
    """Check the notice printed when the agent re-uses a tool with the same input.

    An agent capped at three iterations is asked to keep calling
    `get_final_answer`; the test then asserts that the captured stdout
    contains the "I just used the ... tool" message.
    """

    @tool
    def get_final_answer(numbers) -> float:
        """Get the final answer but don't give it yet, just re-use this
        tool non-stop."""
        return 42

    repeating_agent = Agent(
        role="test role",
        goal="test goal",
        backstory="test backstory",
        max_iter=3,
        allow_delegation=False,
    )

    prompt = "The final answer is 42. But don't give it yet, instead keep using the `get_final_answer` tool."
    repeating_agent.execute_task(task=prompt, tools=[get_final_answer])

    # Only stdout is inspected; stderr is discarded along with the capture result.
    stdout = capsys.readouterr().out
    expected_notice = "I just used the get_final_answer tool with input {'numbers': 42}. So I already know the result of that and don't need to use it again now."
    assert expected_notice in stdout
@pytest.mark.vcr(filter_headers=["authorization"])
def test_agent_moved_on_after_max_iterations():
@tool
@@ -241,18 +257,14 @@ def test_agent_moved_on_after_max_iterations():
allow_delegation=False,
)
with patch.object(
CrewAgentExecutor, "_force_answer", wraps=agent.agent_executor._force_answer
) as private_mock:
output = agent.execute_task(
task="The final answer is 42. But don't give it yet, instead keep using the `get_final_answer` tool.",
tools=[get_final_answer],
)
assert (
output
== "I have used the tool multiple times and the final answer remains 42."
)
private_mock.assert_called_once()
output = agent.execute_task(
task="The final answer is 42. But don't give it yet, instead keep using the `get_final_answer` tool.",
tools=[get_final_answer],
)
assert (
output
== "I have used the tool 'get_final_answer' twice and confirmed that the answer is indeed 42."
)
@pytest.mark.vcr(filter_headers=["authorization"])
@@ -281,7 +293,7 @@ def test_agent_respect_the_max_rpm_set(capsys):
)
assert (
output
== "I've used the `get_final_answer` tool multiple times and it consistently returns the number 42."
== "I have used the tool as instructed and I am now ready to give the final answer. However, as per the instructions, I am not supposed to give it yet."
)
captured = capsys.readouterr()
assert "Max RPM reached, waiting for next minute to start." in captured.out
@@ -359,7 +371,7 @@ def test_agent_without_max_rpm_respet_crew_rpm(capsys):
agent=agent1,
),
Task(
description="Don't give a Final Answer, instead keep using the `get_final_answer` tool.",
description="Don't give a Final Answer, instead keep using the `get_final_answer` tool non-stop",
tools=[get_final_answer],
agent=agent2,
),
@@ -428,4 +440,4 @@ def test_agent_step_callback():
callback.return_value = "ok"
crew.kickoff()
callback.assert_called_once()
callback.assert_called()