Adding custom caching

This commit is contained in:
João Moura
2024-03-22 23:26:53 -03:00
parent 2c359ca926
commit d73dd08ef4
10 changed files with 3273 additions and 439 deletions

View File

@@ -733,3 +733,78 @@ def test_crew_inputs_interpolate_both_agents_and_tasks():
crew.kickoff(inputs={"topic": "AI", "points": 5})
interpolate_agent_inputs.assert_called()
interpolate_task_inputs.assert_called()
@pytest.mark.vcr(filter_headers=["authorization"])
def test_tools_with_custom_caching():
from unittest.mock import patch
from crewai_tools import tool
@tool
def multiplcation_tool(first_number: int, second_number: int) -> str:
"""Useful for when you need to multiply two numbers together."""
return first_number * second_number
def cache_func(args, result):
cache = result % 2 == 0
print(f"cache?: {cache}")
return cache
multiplcation_tool.cache_function = cache_func
writer1 = Agent(
role="Writer",
goal="You write lesssons of math for kids.",
backstory="You're an expert in writting and you love to teach kids but you know nothing of math.",
tools=[multiplcation_tool],
allow_delegation=False,
)
writer2 = Agent(
role="Writer",
goal="You write lesssons of math for kids.",
backstory="You're an expert in writting and you love to teach kids but you know nothing of math.",
tools=[multiplcation_tool],
allow_delegation=False,
)
task1 = Task(
description="What is 2 times 6? Return only the number after using the multiplication tool.",
expected_output="the result of multiplication",
agent=writer1,
)
task2 = Task(
description="What is 3 times 1? Return only the number after using the multiplication tool.",
expected_output="the result of multiplication",
agent=writer1,
)
task3 = Task(
description="What is 2 times 6? Return only the number after using the multiplication tool.",
expected_output="the result of multiplication",
agent=writer2,
)
task4 = Task(
description="What is 3 times 1? Return only the number after using the multiplication tool.",
expected_output="the result of multiplication",
agent=writer2,
)
crew = Crew(agents=[writer1, writer2], tasks=[task1, task2, task3, task4])
with patch.object(
CacheHandler, "add", wraps=crew._cache_handler.add
) as add_to_cache:
with patch.object(
CacheHandler, "read", wraps=crew._cache_handler.read
) as read_from_cache:
result = crew.kickoff()
add_to_cache.assert_called_once_with(
tool="multiplcation_tool",
input={"first_number": 2, "second_number": 6},
output=12,
)
assert result == "3"