From 128ce919514bcac76062f36fe14b49eb283eb7dc Mon Sep 17 00:00:00 2001 From: Gui Vieira Date: Fri, 22 Mar 2024 03:08:54 -0300 Subject: [PATCH] Fix input interpolation bug (#369) --- src/crewai/agent.py | 17 ++++++++++++++--- src/crewai/task.py | 12 ++++++++++-- tests/agent_test.py | 18 ++++++++++++++++++ tests/task_test.py | 21 +++++++++++++++++++++ 4 files changed, 63 insertions(+), 5 deletions(-) diff --git a/src/crewai/agent.py b/src/crewai/agent.py index 1fc189b09..3a2ec9c97 100644 --- a/src/crewai/agent.py +++ b/src/crewai/agent.py @@ -116,6 +116,10 @@ class Agent(BaseModel): default=None, description="Callback to be executed" ) + _original_role: str | None = None + _original_goal: str | None = None + _original_backstory: str | None = None + def __init__(__pydantic_self__, **data): config = data.pop("config", {}) super().__init__(**config, **data) @@ -282,10 +286,17 @@ class Agent(BaseModel): def interpolate_inputs(self, inputs: Dict[str, Any]) -> None: """Interpolate inputs into the agent description and backstory.""" + if self._original_role is None: + self._original_role = self.role + if self._original_goal is None: + self._original_goal = self.goal + if self._original_backstory is None: + self._original_backstory = self.backstory + if inputs: - self.role = self.role.format(**inputs) - self.goal = self.goal.format(**inputs) - self.backstory = self.backstory.format(**inputs) + self.role = self._original_role.format(**inputs) + self.goal = self._original_goal.format(**inputs) + self.backstory = self._original_backstory.format(**inputs) def increment_formatting_errors(self) -> None: """Count the formatting errors of the agent.""" diff --git a/src/crewai/task.py b/src/crewai/task.py index ff7af1b89..779e21e5b 100644 --- a/src/crewai/task.py +++ b/src/crewai/task.py @@ -71,6 +71,9 @@ class Task(BaseModel): description="Unique identifier for the object, not set by user.", ) + _original_description: str | None = None + _original_expected_output: str | None = None + def __init__(__pydantic_self__, **data): config = data.pop("config", {}) super().__init__(**config, **data) @@ -189,9 +192,14 @@ class Task(BaseModel): def interpolate_inputs(self, inputs: Dict[str, Any]) -> None: """Interpolate inputs into the task description and expected output.""" + if self._original_description is None: + self._original_description = self.description + if self._original_expected_output is None: + self._original_expected_output = self.expected_output + if inputs: - self.description = self.description.format(**inputs) - self.expected_output = self.expected_output.format(**inputs) + self.description = self._original_description.format(**inputs) + self.expected_output = self._original_expected_output.format(**inputs) def increment_tools_errors(self) -> None: """Increment the tools errors counter.""" diff --git a/tests/agent_test.py b/tests/agent_test.py index 3d837efb6..505d3a51d 100644 --- a/tests/agent_test.py +++ b/tests/agent_test.py @@ -680,3 +680,21 @@ def test_agent_definition_based_on_dict(): assert agent.backstory == "test backstory" assert agent.verbose == True assert agent.tools == [] + + +def test_interpolate_inputs(): + agent = Agent( + role="{topic} specialist", + goal="Figure {goal} out", + backstory="I am the master of {role}", + ) + + agent.interpolate_inputs({"topic": "AI", "goal": "life", "role": "all things"}) + assert agent.role == "AI specialist" + assert agent.goal == "Figure life out" + assert agent.backstory == "I am the master of all things" + + agent.interpolate_inputs({"topic": "Sales", "goal": "stuff", "role": "nothing"}) + assert agent.role == "Sales specialist" + assert agent.goal == "Figure stuff out" + assert agent.backstory == "I am the master of nothing" diff --git a/tests/task_test.py b/tests/task_test.py index 3ef520cdf..ffb401a09 100644 --- a/tests/task_test.py +++ b/tests/task_test.py @@ -462,3 +462,24 @@ def test_task_definition_based_on_dict(): assert task.description == config["description"] assert task.expected_output == config["expected_output"] assert task.agent is None + + +def test_interpolate_inputs(): + task = Task( + description="Give me a list of 5 interesting ideas about {topic} to explore for an article, what makes them unique and interesting.", + expected_output="Bullet point list of 5 interesting ideas about {topic}.", + ) + + task.interpolate_inputs(inputs={"topic": "AI"}) + assert ( + task.description + == "Give me a list of 5 interesting ideas about AI to explore for an article, what makes them unique and interesting." + ) + assert task.expected_output == "Bullet point list of 5 interesting ideas about AI." + + task.interpolate_inputs(inputs={"topic": "ML"}) + assert ( + task.description + == "Give me a list of 5 interesting ideas about ML to explore for an article, what makes them unique and interesting." + ) + assert task.expected_output == "Bullet point list of 5 interesting ideas about ML."