fixes interpolation issues when inputs are type dict,list specifically when defined on expected_output

This commit is contained in:
Lorenze Jay
2025-01-28 09:56:33 -08:00
parent bcb7fb27d0
commit 99d9151196
2 changed files with 55 additions and 12 deletions

View File

@@ -431,7 +431,9 @@ class Task(BaseModel):
content = (
json_output
if json_output
else pydantic_output.model_dump_json() if pydantic_output else result
else pydantic_output.model_dump_json()
if pydantic_output
else result
)
self._save_file(content)
@@ -452,7 +454,7 @@ class Task(BaseModel):
return "\n".join(tasks_slices)
def interpolate_inputs_and_add_conversation_history(
self, inputs: Dict[str, Union[str, int, float]]
self, inputs: Dict[str, Union[str, int, float, dict, list]]
) -> None:
"""Interpolate inputs into the task description, expected output, and output file path.
Add conversation history if present.
@@ -524,7 +526,9 @@ class Task(BaseModel):
)
def interpolate_only(
self, input_string: Optional[str], inputs: Dict[str, Union[str, int, float]]
self,
input_string: Optional[str],
inputs: Dict[str, Union[str, int, float, dict, list]],
) -> str:
"""Interpolate placeholders (e.g., {key}) in a string while leaving JSON untouched.
@@ -551,14 +555,16 @@ class Task(BaseModel):
raise ValueError(
"Inputs dictionary cannot be empty when interpolating variables"
)
try:
# Validate input types
for key, value in inputs.items():
if not isinstance(value, (str, int, float)):
print("key", key, value)
if not isinstance(value, (str, int, float, dict, list)):
raise ValueError(
f"Value for key '{key}' must be a string, integer, or float, got {type(value).__name__}"
f"Value for key '{key}' must be a string, integer, float, dict, or list, got {type(value).__name__}"
)
if isinstance(value, (dict, list)):
inputs[key] = json.dumps(value, ensure_ascii=False)
escaped_string = input_string.replace("{", "{{").replace("}", "}}")

View File

@@ -441,9 +441,9 @@ def test_output_pydantic_to_another_task():
crew = Crew(agents=[scorer], tasks=[task1, task2], verbose=True)
result = crew.kickoff()
pydantic_result = result.pydantic
assert isinstance(
pydantic_result, ScoreOutput
), "Expected pydantic result to be of type ScoreOutput"
assert isinstance(pydantic_result, ScoreOutput), (
"Expected pydantic result to be of type ScoreOutput"
)
assert pydantic_result.score == 5
@@ -779,6 +779,43 @@ def test_interpolate_only():
assert result == no_placeholders
def test_interpolate_only_with_dict_inside_expected_output():
"""Test the interpolate_only method for various scenarios including JSON structure preservation."""
task = Task(
description="Unused in this test",
expected_output="Unused in this test: {questions}",
)
json_string = '{"questions": {"main_question": "What is the user\'s name?", "secondary_question": "What is the user\'s age?"}}'
result = task.interpolate_only(
input_string=json_string,
inputs={
"questions": {
"main_question": "What is the user's name?",
"secondary_question": "What is the user's age?",
}
},
)
assert '"main_question": "What is the user\'s name?"' in result
assert '"secondary_question": "What is the user\'s age?"' in result
assert result == json_string
normal_string = "Hello {name}, welcome to {place}!"
result = task.interpolate_only(
input_string=normal_string, inputs={"name": "John", "place": "CrewAI"}
)
assert result == "Hello John, welcome to CrewAI!"
result = task.interpolate_only(input_string="", inputs={"unused": "value"})
assert result == ""
no_placeholders = "Hello, this is a test"
result = task.interpolate_only(
input_string=no_placeholders, inputs={"unused": "value"}
)
assert result == no_placeholders
def test_task_output_str_with_pydantic():
from crewai.tasks.output_format import OutputFormat
@@ -870,9 +907,9 @@ def test_key():
assert task.key == hash, "The key should be the hash of the description."
task.interpolate_inputs_and_add_conversation_history(inputs={"topic": "AI"})
assert (
task.key == hash
), "The key should be the hash of the non-interpolated description."
assert task.key == hash, (
"The key should be the hash of the non-interpolated description."
)
def test_output_file_validation():