mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-09 16:18:30 +00:00
Lorenze/fix tool call twice (#3495)
* test: add test to ensure tool is called only once during crew execution - Introduced a new test case to validate that the counting_tool is executed exactly once during crew execution. - Created a CountingTool class to track execution counts and log call history. - Enhanced the test suite with a YAML cassette for consistent tool behavior verification. * ensure tool function called only once * refactor: simplify error handling in CrewStructuredTool - Removed unnecessary try-except block around the tool function call to streamline execution flow. - Ensured that the tool function is called directly, improving readability and maintainability. * linted * need to ignore for now as we cant infer the complex generic type within pydantic create_model_func * fix tests
This commit is contained in:
@@ -1,11 +1,11 @@
|
||||
"""Test Agent creation and execution basic functionality."""
|
||||
|
||||
import hashlib
|
||||
import ast
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
from functools import partial
|
||||
from typing import Tuple, Union
|
||||
from hashlib import md5
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
@@ -248,7 +248,7 @@ def test_guardrail_type_error():
|
||||
return (True, x)
|
||||
|
||||
@staticmethod
|
||||
def guardrail_static_fn(x: TaskOutput) -> tuple[bool, Union[str, TaskOutput]]:
|
||||
def guardrail_static_fn(x: TaskOutput) -> tuple[bool, str | TaskOutput]:
|
||||
return (True, x)
|
||||
|
||||
obj = Object()
|
||||
@@ -271,7 +271,7 @@ def test_guardrail_type_error():
|
||||
guardrail=Object.guardrail_static_fn,
|
||||
)
|
||||
|
||||
def error_fn(x: TaskOutput, y: bool) -> Tuple[bool, TaskOutput]:
|
||||
def error_fn(x: TaskOutput, y: bool) -> tuple[bool, TaskOutput]:
|
||||
return (y, x)
|
||||
|
||||
Task(
|
||||
@@ -340,7 +340,7 @@ def test_output_pydantic_hierarchical():
|
||||
)
|
||||
result = crew.kickoff()
|
||||
assert isinstance(result.pydantic, ScoreOutput)
|
||||
assert result.to_dict() == {"score": 5}
|
||||
assert result.to_dict() == {"score": 4}
|
||||
|
||||
|
||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||
@@ -401,8 +401,8 @@ def test_output_json_hierarchical():
|
||||
manager_llm="gpt-4o",
|
||||
)
|
||||
result = crew.kickoff()
|
||||
assert result.json == '{"score": 5}'
|
||||
assert result.to_dict() == {"score": 5}
|
||||
assert result.json == '{"score": 4}'
|
||||
assert result.to_dict() == {"score": 4}
|
||||
|
||||
|
||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||
@@ -560,8 +560,8 @@ def test_output_json_dict_hierarchical():
|
||||
manager_llm="gpt-4o",
|
||||
)
|
||||
result = crew.kickoff()
|
||||
assert {"score": 5} == result.json_dict
|
||||
assert result.to_dict() == {"score": 5}
|
||||
assert {"score": 4} == result.json_dict
|
||||
assert result.to_dict() == {"score": 4}
|
||||
|
||||
|
||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||
@@ -900,11 +900,11 @@ def test_conditional_task_copy_preserves_type():
|
||||
assert isinstance(copied_conditional_task, ConditionalTask)
|
||||
|
||||
|
||||
def test_interpolate_inputs():
|
||||
def test_interpolate_inputs(tmp_path):
|
||||
task = Task(
|
||||
description="Give me a list of 5 interesting ideas about {topic} to explore for an article, what makes them unique and interesting.",
|
||||
expected_output="Bullet point list of 5 interesting ideas about {topic}.",
|
||||
output_file="/tmp/{topic}/output_{date}.txt",
|
||||
output_file=str(tmp_path / "{topic}" / "output_{date}.txt"),
|
||||
)
|
||||
|
||||
task.interpolate_inputs_and_add_conversation_history(
|
||||
@@ -915,7 +915,7 @@ def test_interpolate_inputs():
|
||||
== "Give me a list of 5 interesting ideas about AI to explore for an article, what makes them unique and interesting."
|
||||
)
|
||||
assert task.expected_output == "Bullet point list of 5 interesting ideas about AI."
|
||||
assert task.output_file == "/tmp/AI/output_2025.txt"
|
||||
assert task.output_file == str(tmp_path / "AI" / "output_2025.txt")
|
||||
|
||||
task.interpolate_inputs_and_add_conversation_history(
|
||||
inputs={"topic": "ML", "date": "2025"}
|
||||
@@ -925,7 +925,7 @@ def test_interpolate_inputs():
|
||||
== "Give me a list of 5 interesting ideas about ML to explore for an article, what makes them unique and interesting."
|
||||
)
|
||||
assert task.expected_output == "Bullet point list of 5 interesting ideas about ML."
|
||||
assert task.output_file == "/tmp/ML/output_2025.txt"
|
||||
assert task.output_file == str(tmp_path / "ML" / "output_2025.txt")
|
||||
|
||||
|
||||
def test_interpolate_only():
|
||||
@@ -1074,8 +1074,9 @@ def test_key():
|
||||
description=original_description,
|
||||
expected_output=original_expected_output,
|
||||
)
|
||||
hash = hashlib.md5(
|
||||
f"{original_description}|{original_expected_output}".encode()
|
||||
hash = md5(
|
||||
f"{original_description}|{original_expected_output}".encode(),
|
||||
usedforsecurity=False,
|
||||
).hexdigest()
|
||||
|
||||
assert task.key == hash, "The key should be the hash of the description."
|
||||
@@ -1086,7 +1087,7 @@ def test_key():
|
||||
)
|
||||
|
||||
|
||||
def test_output_file_validation():
|
||||
def test_output_file_validation(tmp_path):
|
||||
"""Test output file path validation."""
|
||||
# Valid paths
|
||||
assert (
|
||||
@@ -1097,13 +1098,15 @@ def test_output_file_validation():
|
||||
).output_file
|
||||
== "output.txt"
|
||||
)
|
||||
# Use secure temporary path instead of /tmp
|
||||
temp_file = tmp_path / "output.txt"
|
||||
assert (
|
||||
Task(
|
||||
description="Test task",
|
||||
expected_output="Test output",
|
||||
output_file="/tmp/output.txt",
|
||||
output_file=str(temp_file),
|
||||
).output_file
|
||||
== "tmp/output.txt"
|
||||
== str(temp_file).lstrip("/") # Remove leading slash to match expected behavior
|
||||
)
|
||||
assert (
|
||||
Task(
|
||||
@@ -1320,7 +1323,7 @@ def test_interpolate_with_list_of_dicts():
|
||||
}
|
||||
result = interpolate_only("{people}", input_data)
|
||||
|
||||
parsed_result = eval(result)
|
||||
parsed_result = ast.literal_eval(result)
|
||||
assert isinstance(parsed_result, list)
|
||||
assert len(parsed_result) == 2
|
||||
assert parsed_result[0]["name"] == "Alice"
|
||||
@@ -1346,7 +1349,7 @@ def test_interpolate_with_nested_structures():
|
||||
}
|
||||
}
|
||||
result = interpolate_only("{company}", input_data)
|
||||
parsed = eval(result)
|
||||
parsed = ast.literal_eval(result)
|
||||
|
||||
assert parsed["name"] == "TechCorp"
|
||||
assert len(parsed["departments"]) == 2
|
||||
@@ -1364,7 +1367,7 @@ def test_interpolate_with_special_characters():
|
||||
}
|
||||
}
|
||||
result = interpolate_only("{special_data}", input_data)
|
||||
parsed = eval(result)
|
||||
parsed = ast.literal_eval(result)
|
||||
|
||||
assert parsed["quotes"] == """This has "double" and 'single' quotes"""
|
||||
assert parsed["unicode"] == "文字化けテスト"
|
||||
@@ -1386,7 +1389,7 @@ def test_interpolate_mixed_types():
|
||||
}
|
||||
}
|
||||
result = interpolate_only("{data}", input_data)
|
||||
parsed = eval(result)
|
||||
parsed = ast.literal_eval(result)
|
||||
|
||||
assert parsed["name"] == "Test Dataset"
|
||||
assert parsed["samples"] == 1000
|
||||
@@ -1409,7 +1412,7 @@ def test_interpolate_complex_combination():
|
||||
]
|
||||
}
|
||||
result = interpolate_only("{report}", input_data)
|
||||
parsed = eval(result)
|
||||
parsed = ast.literal_eval(result)
|
||||
|
||||
assert len(parsed) == 2
|
||||
assert parsed[0]["month"] == "January"
|
||||
@@ -1482,7 +1485,7 @@ def test_interpolate_valid_complex_types():
|
||||
|
||||
# Should not raise any errors
|
||||
result = interpolate_only("{data}", {"data": valid_data})
|
||||
parsed = eval(result)
|
||||
parsed = ast.literal_eval(result)
|
||||
assert parsed["name"] == "Valid Dataset"
|
||||
assert parsed["stats"]["nested"]["deeper"]["b"] == 2.5
|
||||
|
||||
@@ -1512,7 +1515,7 @@ def test_interpolate_valid_types():
|
||||
}
|
||||
|
||||
result = interpolate_only("{data}", {"data": valid_data})
|
||||
parsed = eval(result)
|
||||
parsed = ast.literal_eval(result)
|
||||
|
||||
assert parsed["active"] is True
|
||||
assert parsed["deleted"] is False
|
||||
|
||||
Reference in New Issue
Block a user