From a853935cc81ebcef31b7e8f1b85f222f0548ea0d Mon Sep 17 00:00:00 2001 From: Brandon Hancock Date: Wed, 29 Jan 2025 10:08:42 -0500 Subject: [PATCH] more tests --- src/crewai/task.py | 48 ++++---- tests/task_test.py | 271 ++++++++++++++++++++++++++++++++++++++++++++- 2 files changed, 291 insertions(+), 28 deletions(-) diff --git a/src/crewai/task.py b/src/crewai/task.py index 234133aec..d3ffb8a0a 100644 --- a/src/crewai/task.py +++ b/src/crewai/task.py @@ -431,9 +431,7 @@ class Task(BaseModel): content = ( json_output if json_output - else pydantic_output.model_dump_json() - if pydantic_output - else result + else pydantic_output.model_dump_json() if pydantic_output else result ) self._save_file(content) @@ -528,7 +526,7 @@ class Task(BaseModel): def interpolate_only( self, input_string: Optional[str], - inputs: Dict[str, Union[str, int, float, dict, list]], + inputs: Dict[str, Any], ) -> str: """Interpolate placeholders (e.g., {key}) in a string while leaving JSON untouched. @@ -536,17 +534,37 @@ class Task(BaseModel): input_string: The string containing template variables to interpolate. Can be None or empty, in which case an empty string is returned. inputs: Dictionary mapping template variables to their values. - Supported value types are strings, integers, floats, dicts, and lists. - If input_string is empty or has no placeholders, inputs can be empty. + Supported value types are strings, integers, floats, and dicts/lists + containing only these types and other nested dicts/lists. Returns: The interpolated string with all template variables replaced with their values. Empty string if input_string is None or empty. Raises: - ValueError: If a required template variable is missing from inputs. - KeyError: If a template variable is not found in the inputs dictionary. + ValueError: If a value contains unsupported types """ + + # Validation function for recursive type checking + def validate_type(value: Any) -> None: + if isinstance(value, (str, int, float)): + return + if isinstance(value, (dict, list)): + for item in value.values() if isinstance(value, dict) else value: + validate_type(item) + return + raise ValueError( + f"Unsupported type {type(value).__name__} in inputs. " + "Only str, int, float, dict, and list are allowed." + ) + + # Validate all input values + for key, value in inputs.items(): + try: + validate_type(value) + except ValueError as e: + raise ValueError(f"Invalid value for key '{key}': {str(e)}") from e + if input_string is None or not input_string: return "" if "{" not in input_string and "}" not in input_string: @@ -556,20 +574,6 @@ class Task(BaseModel): "Inputs dictionary cannot be empty when interpolating variables" ) try: - # Validate input types - for key, value in inputs.items(): - if not isinstance(value, (str, int, float, dict, list)): - raise ValueError( - f"Value for key '{key}' must be a string, integer, float, dict, or list, got {type(value).__name__}" - ) - if isinstance(value, (dict, list)): - try: - inputs[key] = json.dumps(value, ensure_ascii=False) - except Exception as e: - raise ValueError( - f"Failed to serialize value for key: {key} with value: {value} due to error: {str(e)}" - ) from e - escaped_string = input_string.replace("{", "{{").replace("}", "}}") for key in inputs.keys(): diff --git a/tests/task_test.py b/tests/task_test.py index f31895e24..4e2c04ea3 100644 --- a/tests/task_test.py +++ b/tests/task_test.py @@ -441,9 +441,9 @@ def test_output_pydantic_to_another_task(): crew = Crew(agents=[scorer], tasks=[task1, task2], verbose=True) result = crew.kickoff() pydantic_result = result.pydantic - assert isinstance(pydantic_result, ScoreOutput), ( - "Expected pydantic result to be of type ScoreOutput" - ) + assert isinstance( + pydantic_result, ScoreOutput + ), "Expected pydantic result to be of type ScoreOutput" assert pydantic_result.score == 5 @@ -907,9 +907,9 @@ def test_key(): assert task.key == hash, "The key should be the hash of the description." task.interpolate_inputs_and_add_conversation_history(inputs={"topic": "AI"}) - assert task.key == hash, ( - "The key should be the hash of the non-interpolated description." - ) + assert ( + task.key == hash + ), "The key should be the hash of the non-interpolated description." def test_output_file_validation(): @@ -1003,3 +1003,262 @@ def test_task_execution_times(): assert task.start_time is not None assert task.end_time is not None assert task.execution_duration == (task.end_time - task.start_time).total_seconds() + + +def test_interpolate_with_list_of_strings(): + task = Task( + description="Test list interpolation", + expected_output="List: {items}", + ) + + # Test simple list of strings + input_str = "Available items: {items}" + inputs = {"items": ["apple", "banana", "cherry"]} + result = task.interpolate_only(input_str, inputs) + assert result == 'Available items: ["apple", "banana", "cherry"]' + + # Test empty list + empty_list_input = {"items": []} + result = task.interpolate_only(input_str, empty_list_input) + assert result == "Available items: []" + + +def test_interpolate_with_list_of_dicts(): + task = Task( + description="Test list of dicts interpolation", + expected_output="People: {people}", + ) + + input_data = { + "people": [ + {"name": "Alice", "age": 30, "skills": ["Python", "AI"]}, + {"name": "Bob", "age": 25, "skills": ["Java", "Cloud"]}, + ] + } + result = task.interpolate_only("{people}", input_data) + + assert '"name": "Alice"' in result + assert '"age": 30' in result + assert '"skills": ["Python", "AI"]' in result + assert isinstance(json.loads(result), list) + + +def test_interpolate_with_nested_structures(): + task = Task( + description="Test nested structures", + expected_output="Company: {company}", + ) + + input_data = { + "company": { + "name": "TechCorp", + "departments": [ + { + "name": "Engineering", + "employees": 50, + "tools": ["Git", "Docker", "Kubernetes"], + }, + {"name": "Sales", "employees": 20, "regions": {"north": 5, "south": 3}}, + ], + } + } + result = task.interpolate_only("{company}", input_data) + parsed = json.loads(result) + + assert parsed["name"] == "TechCorp" + assert len(parsed["departments"]) == 2 + assert parsed["departments"][0]["tools"] == ["Git", "Docker", "Kubernetes"] + assert parsed["departments"][1]["regions"]["north"] == 5 + + +def test_interpolate_with_special_characters(): + task = Task( + description="Test special characters in dicts", + expected_output="Data: {special_data}", + ) + + input_data = { + "special_data": { + "quotes": """This has "double" and 'single' quotes""", + "unicode": "文字化けテスト", + "symbols": "!@#$%^&*()", + "empty": "", + } + } + result = task.interpolate_only("{special_data}", input_data) + parsed = json.loads(result) + + assert parsed["quotes"] == """This has "double" and 'single' quotes""" + assert parsed["unicode"] == "文字化けテスト" + assert parsed["symbols"] == "!@#$%^&*()" + assert parsed["empty"] == "" + + +def test_interpolate_mixed_types(): + task = Task( + description="Test mixed type interpolation", + expected_output="Mixed: {data}", + ) + + input_data = { + "data": { + "name": "Test Dataset", + "samples": 1000, + "features": ["age", "income", "location"], + "metadata": { + "source": "public", + "validated": True, + "tags": ["demo", "test", "temp"], + }, + "null_value": None, + } + } + result = task.interpolate_only("{data}", input_data) + parsed = json.loads(result) + + assert parsed["name"] == "Test Dataset" + assert parsed["samples"] == 1000 + assert parsed["metadata"]["tags"] == ["demo", "test", "temp"] + assert "null_value" in parsed + + +def test_interpolate_complex_combination(): + task = Task( + description="Test complex combination", + expected_output="Report: {report}", + ) + + input_data = { + "report": [ + { + "month": "January", + "metrics": {"sales": 15000, "expenses": 8000, "profit": 7000}, + "top_products": ["Product A", "Product B"], + }, + { + "month": "February", + "metrics": {"sales": 18000, "expenses": 8500, "profit": 9500}, + "top_products": ["Product C", "Product D"], + }, + ] + } + result = task.interpolate_only("{report}", input_data) + parsed = json.loads(result) + + assert len(parsed) == 2 + assert parsed[0]["month"] == "January" + assert parsed[1]["metrics"]["profit"] == 9500 + assert "Product D" in parsed[1]["top_products"] + + +def test_interpolate_invalid_type_validation(): + task = Task( + description="Test invalid type validation", + expected_output="Should never reach here", + ) + + # Test with invalid top-level type + with pytest.raises(ValueError) as excinfo: + task.interpolate_only("{data}", {"data": True}) + assert "Unsupported type bool" in str(excinfo.value) + assert "key 'data'" in str(excinfo.value) + + # Test with invalid nested type + invalid_nested = { + "profile": { + "name": "John", + "age": 30, + "tags": {"a", "b", "c"}, # Set is invalid + "preferences": [None, True], # None and bool are invalid + } + } + with pytest.raises(ValueError) as excinfo: + task.interpolate_only("{data}", {"data": invalid_nested}) + assert "Unsupported type set" in str(excinfo.value) + assert "key 'tags'" in str(excinfo.value) + assert "Unsupported type NoneType" in str(excinfo.value) + + +def test_interpolate_custom_object_validation(): + task = Task( + description="Test custom object rejection", + expected_output="Should never reach here", + ) + + class CustomObject: + def __init__(self, value): + self.value = value + + # Test with custom object + with pytest.raises(ValueError) as excinfo: + task.interpolate_only("{obj}", {"obj": CustomObject(5)}) + assert "Unsupported type CustomObject" in str(excinfo.value) + + # Test with nested custom object + with pytest.raises(ValueError) as excinfo: + task.interpolate_only( + "{data}", {"data": [{"valid": 1, "invalid": CustomObject(5)}]} + ) + assert "Unsupported type CustomObject" in str(excinfo.value) + assert "key 'invalid'" in str(excinfo.value) + + +def test_interpolate_valid_complex_types(): + task = Task( + description="Test valid complex types", + expected_output="Validation should pass", + ) + + # Valid complex structure + valid_data = { + "name": "Valid Dataset", + "stats": { + "count": 1000, + "distribution": [0.2, 0.3, 0.5], + "features": ["age", "income"], + "nested": {"deep": [1, 2, 3], "deeper": {"a": 1, "b": 2.5}}, + }, + } + + # Should not raise any errors + result = task.interpolate_only("{data}", {"data": valid_data}) + parsed = json.loads(result) + assert parsed["name"] == "Valid Dataset" + assert parsed["stats"]["nested"]["deeper"]["b"] == 2.5 + + +def test_interpolate_edge_cases(): + task = Task( + description="Test edge cases", + expected_output="Edge case handling", + ) + + # Test empty dict and list + assert task.interpolate_only("{}", {"data": {}}) == "{}" + assert task.interpolate_only("[]", {"data": []}) == "[]" + + # Test numeric types + assert task.interpolate_only("{num}", {"num": 42}) == "42" + assert task.interpolate_only("{num}", {"num": 3.14}) == "3.14" + + # Test boolean rejection + with pytest.raises(ValueError) as excinfo: + task.interpolate_only("{flag}", {"flag": True}) + assert "Unsupported type bool" in str(excinfo.value) + + +def test_interpolate_null_handling(): + task = Task( + description="Test null handling", + expected_output="Null validation", + ) + + # Test null rejection + with pytest.raises(ValueError) as excinfo: + task.interpolate_only("{data}", {"data": None}) + assert "Unsupported type NoneType" in str(excinfo.value) + + # Test null in nested structure + with pytest.raises(ValueError) as excinfo: + task.interpolate_only("{data}", {"data": {"valid": 1, "invalid": None}}) + assert "Unsupported type NoneType" in str(excinfo.value)