Lorenze/fix tool call twice (#3495)

* test: add test to ensure tool is called only once during crew execution

- Introduced a new test case to validate that the counting_tool is executed exactly once during crew execution.
- Created a CountingTool class to track execution counts and log call history.
- Enhanced the test suite with a YAML cassette for consistent tool behavior verification.

* ensure tool function called only once

* refactor: simplify error handling in CrewStructuredTool

- Removed unnecessary try-except block around the tool function call to streamline execution flow.
- Ensured that the tool function is called directly, improving readability and maintainability.

* linted

* need to ignore for now as we cant infer the complex generic type within pydantic create_model_func

* fix tests
This commit is contained in:
Lorenze Jay
2025-09-10 15:20:21 -07:00
committed by GitHub
parent 01be26ce2a
commit 75b916c85a
5 changed files with 583 additions and 233 deletions

View File

@@ -1,11 +1,11 @@
"""Test Agent creation and execution basic functionality."""
import hashlib
import ast
import json
import os
import time
from functools import partial
from typing import Tuple, Union
from hashlib import md5
from unittest.mock import MagicMock, patch
import pytest
@@ -248,7 +248,7 @@ def test_guardrail_type_error():
return (True, x)
@staticmethod
def guardrail_static_fn(x: TaskOutput) -> tuple[bool, Union[str, TaskOutput]]:
def guardrail_static_fn(x: TaskOutput) -> tuple[bool, str | TaskOutput]:
return (True, x)
obj = Object()
@@ -271,7 +271,7 @@ def test_guardrail_type_error():
guardrail=Object.guardrail_static_fn,
)
def error_fn(x: TaskOutput, y: bool) -> Tuple[bool, TaskOutput]:
def error_fn(x: TaskOutput, y: bool) -> tuple[bool, TaskOutput]:
return (y, x)
Task(
@@ -340,7 +340,7 @@ def test_output_pydantic_hierarchical():
)
result = crew.kickoff()
assert isinstance(result.pydantic, ScoreOutput)
assert result.to_dict() == {"score": 5}
assert result.to_dict() == {"score": 4}
@pytest.mark.vcr(filter_headers=["authorization"])
@@ -401,8 +401,8 @@ def test_output_json_hierarchical():
manager_llm="gpt-4o",
)
result = crew.kickoff()
assert result.json == '{"score": 5}'
assert result.to_dict() == {"score": 5}
assert result.json == '{"score": 4}'
assert result.to_dict() == {"score": 4}
@pytest.mark.vcr(filter_headers=["authorization"])
@@ -560,8 +560,8 @@ def test_output_json_dict_hierarchical():
manager_llm="gpt-4o",
)
result = crew.kickoff()
assert {"score": 5} == result.json_dict
assert result.to_dict() == {"score": 5}
assert {"score": 4} == result.json_dict
assert result.to_dict() == {"score": 4}
@pytest.mark.vcr(filter_headers=["authorization"])
@@ -900,11 +900,11 @@ def test_conditional_task_copy_preserves_type():
assert isinstance(copied_conditional_task, ConditionalTask)
def test_interpolate_inputs():
def test_interpolate_inputs(tmp_path):
task = Task(
description="Give me a list of 5 interesting ideas about {topic} to explore for an article, what makes them unique and interesting.",
expected_output="Bullet point list of 5 interesting ideas about {topic}.",
output_file="/tmp/{topic}/output_{date}.txt",
output_file=str(tmp_path / "{topic}" / "output_{date}.txt"),
)
task.interpolate_inputs_and_add_conversation_history(
@@ -915,7 +915,7 @@ def test_interpolate_inputs():
== "Give me a list of 5 interesting ideas about AI to explore for an article, what makes them unique and interesting."
)
assert task.expected_output == "Bullet point list of 5 interesting ideas about AI."
assert task.output_file == "/tmp/AI/output_2025.txt"
assert task.output_file == str(tmp_path / "AI" / "output_2025.txt")
task.interpolate_inputs_and_add_conversation_history(
inputs={"topic": "ML", "date": "2025"}
@@ -925,7 +925,7 @@ def test_interpolate_inputs():
== "Give me a list of 5 interesting ideas about ML to explore for an article, what makes them unique and interesting."
)
assert task.expected_output == "Bullet point list of 5 interesting ideas about ML."
assert task.output_file == "/tmp/ML/output_2025.txt"
assert task.output_file == str(tmp_path / "ML" / "output_2025.txt")
def test_interpolate_only():
@@ -1074,8 +1074,9 @@ def test_key():
description=original_description,
expected_output=original_expected_output,
)
hash = hashlib.md5(
f"{original_description}|{original_expected_output}".encode()
hash = md5(
f"{original_description}|{original_expected_output}".encode(),
usedforsecurity=False,
).hexdigest()
assert task.key == hash, "The key should be the hash of the description."
@@ -1086,7 +1087,7 @@ def test_key():
)
def test_output_file_validation():
def test_output_file_validation(tmp_path):
"""Test output file path validation."""
# Valid paths
assert (
@@ -1097,13 +1098,15 @@ def test_output_file_validation():
).output_file
== "output.txt"
)
# Use secure temporary path instead of /tmp
temp_file = tmp_path / "output.txt"
assert (
Task(
description="Test task",
expected_output="Test output",
output_file="/tmp/output.txt",
output_file=str(temp_file),
).output_file
== "tmp/output.txt"
== str(temp_file).lstrip("/") # Remove leading slash to match expected behavior
)
assert (
Task(
@@ -1320,7 +1323,7 @@ def test_interpolate_with_list_of_dicts():
}
result = interpolate_only("{people}", input_data)
parsed_result = eval(result)
parsed_result = ast.literal_eval(result)
assert isinstance(parsed_result, list)
assert len(parsed_result) == 2
assert parsed_result[0]["name"] == "Alice"
@@ -1346,7 +1349,7 @@ def test_interpolate_with_nested_structures():
}
}
result = interpolate_only("{company}", input_data)
parsed = eval(result)
parsed = ast.literal_eval(result)
assert parsed["name"] == "TechCorp"
assert len(parsed["departments"]) == 2
@@ -1364,7 +1367,7 @@ def test_interpolate_with_special_characters():
}
}
result = interpolate_only("{special_data}", input_data)
parsed = eval(result)
parsed = ast.literal_eval(result)
assert parsed["quotes"] == """This has "double" and 'single' quotes"""
assert parsed["unicode"] == "文字化けテスト"
@@ -1386,7 +1389,7 @@ def test_interpolate_mixed_types():
}
}
result = interpolate_only("{data}", input_data)
parsed = eval(result)
parsed = ast.literal_eval(result)
assert parsed["name"] == "Test Dataset"
assert parsed["samples"] == 1000
@@ -1409,7 +1412,7 @@ def test_interpolate_complex_combination():
]
}
result = interpolate_only("{report}", input_data)
parsed = eval(result)
parsed = ast.literal_eval(result)
assert len(parsed) == 2
assert parsed[0]["month"] == "January"
@@ -1482,7 +1485,7 @@ def test_interpolate_valid_complex_types():
# Should not raise any errors
result = interpolate_only("{data}", {"data": valid_data})
parsed = eval(result)
parsed = ast.literal_eval(result)
assert parsed["name"] == "Valid Dataset"
assert parsed["stats"]["nested"]["deeper"]["b"] == 2.5
@@ -1512,7 +1515,7 @@ def test_interpolate_valid_types():
}
result = interpolate_only("{data}", {"data": valid_data})
parsed = eval(result)
parsed = ast.literal_eval(result)
assert parsed["active"] is True
assert parsed["deleted"] is False