From 38ceb9d409d882c0cd0f339077c69837788607d1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jo=C3=A3o=20Moura?= Date: Wed, 28 Feb 2024 02:59:58 -0300 Subject: [PATCH] adding support for input interpolation for tasks and agents --- src/crewai/agent.py | 8 +++++++- src/crewai/crew.py | 14 ++++++++++++++ src/crewai/task.py | 8 +++++++- tests/crew_test.py | 21 +++++++++++++++++++++ 4 files changed, 49 insertions(+), 2 deletions(-) diff --git a/src/crewai/agent.py b/src/crewai/agent.py index 1deaa1582..e204d639b 100644 --- a/src/crewai/agent.py +++ b/src/crewai/agent.py @@ -1,6 +1,6 @@ import os import uuid -from typing import Any, List, Optional, Tuple +from typing import Any, Dict, List, Optional, Tuple from crewai_tools import BaseTool as CrewAITool from langchain.agents.agent import RunnableAgent @@ -255,6 +255,12 @@ class Agent(BaseModel): agent=RunnableAgent(runnable=inner_agent), **executor_args ) + def interpolate_inputs(self, inputs: Dict[str, Any]) -> None: + """Interpolate inputs into the agent description and backstory.""" + self.role = self.role.format(**inputs) + self.goal = self.goal.format(**inputs) + self.backstory = self.backstory.format(**inputs) + def increment_formatting_errors(self) -> None: """Count the formatting errors of the agent.""" self.formatting_errors += 1 diff --git a/src/crewai/crew.py b/src/crewai/crew.py index 1ed28b308..79a93176a 100644 --- a/src/crewai/crew.py +++ b/src/crewai/crew.py @@ -41,6 +41,7 @@ class Crew(BaseModel): full_output: Whether the crew should return the full output with all tasks outputs or just the final output. step_callback: Callback to be executed after each step for every agents execution. share_crew: Whether you want to share the complete crew infromation and execution with crewAI to make the library better, and allow us to train models. + inputs: Any inputs that the crew will use in tasks or agents, it will be interpolated in promtps. """ __hash__ = object.__hash__ # type: ignore @@ -67,6 +68,10 @@ class Crew(BaseModel): function_calling_llm: Optional[Any] = Field( description="Language model that will run the agent.", default=None ) + inputs: Optional[Dict[str, Any]] = Field( + description="Any inputs that the crew will use in tasks or agents, it will be interpolated in promtps.", + default={}, + ) config: Optional[Union[Json, Dict[str, Any]]] = Field(default=None) id: UUID4 = Field(default_factory=uuid.uuid4, frozen=True) share_crew: Optional[bool] = Field(default=False) @@ -129,6 +134,15 @@ class Crew(BaseModel): ) return self + @model_validator(mode="after") + def interpolate_inputs(self): + """Interpolates the inputs in the tasks and agents.""" + for task in self.tasks: + task.interpolate_inputs(self.inputs) + for agent in self.agents: + agent.interpolate_inputs(self.inputs) + return self + @model_validator(mode="after") def check_config(self): """Validates that the crew is properly configured with agents and tasks.""" diff --git a/src/crewai/task.py b/src/crewai/task.py index 98325e81f..bdb552f4d 100644 --- a/src/crewai/task.py +++ b/src/crewai/task.py @@ -1,6 +1,6 @@ import threading import uuid -from typing import Any, List, Optional, Type +from typing import Any, Dict, List, Optional, Type from langchain_openai import ChatOpenAI from pydantic import UUID4, BaseModel, Field, field_validator, model_validator @@ -173,6 +173,12 @@ class Task(BaseModel): tasks_slices = [self.description, output] return "\n".join(tasks_slices) + def interpolate_inputs(self, inputs: Dict[str, Any]) -> None: + """Interpolate inputs into the task description and expected output.""" + self.description = self.description.format(**inputs) + if self.expected_output: + self.expected_output = self.expected_output.format(**inputs) + def increment_tools_errors(self) -> None: """Increment the tools errors counter.""" self.tools_errors += 1 diff --git a/tests/crew_test.py b/tests/crew_test.py index 1adf3b4e4..c440a43cb 100644 --- a/tests/crew_test.py +++ b/tests/crew_test.py @@ -650,3 +650,24 @@ def test_agent_usage_metrics_are_captured_for_hierarchical_process(): "completion_tokens": 109, "successful_requests": 3, } + + +def test_crew_inputs_interpolate_both_agents_and_tasks(): + agent = Agent( + role="{topic} Researcher", + goal="Express hot takes on {topic}.", + backstory="You have a lot of experience with {topic}.", + ) + + task = Task( + description="Give me an analysis around {topic}.", + expected_output="{points} bullet points about {topic}.", + ) + + crew = Crew(agents=[agent], tasks=[task], inputs={"topic": "AI", "points": 5}) + + assert crew.tasks[0].description == "Give me an analysis around AI." + assert crew.tasks[0].expected_output == "5 bullet points about AI." + assert crew.agents[0].role == "AI Researcher" + assert crew.agents[0].goal == "Express hot takes on AI." + assert crew.agents[0].backstory == "You have a lot of experience with AI."