Fix types and add the ability to not use system prompts

This commit is contained in:
João Moura
2024-09-13 08:56:07 -03:00
parent 91ddab207f
commit 5f46a6de79
7 changed files with 61 additions and 15 deletions

View File

@@ -73,6 +73,10 @@ class Agent(BaseAgent):
default=None,
description="Callback to be executed after each step of the agent execution.",
)
use_system_prompt: Optional[bool] = Field(
default=True,
description="Use system prompt for the agent.",
)
llm: Any = Field(
description="Language model that will run the agent.", default="gpt-4o"
)
@@ -206,6 +210,7 @@ class Agent(BaseAgent):
agent=self,
tools=tools,
i18n=self.i18n,
use_system_prompt=self.use_system_prompt,
system_template=self.system_template,
prompt_template=self.prompt_template,
response_template=self.response_template,

View File

@@ -20,6 +20,7 @@ class CrewAgentExecutorMixin:
task: Optional["Task"]
iterations: int
have_forced_answer: bool
max_iter: int
_i18n: I18N
def _should_force_answer(self) -> bool:

View File

@@ -57,22 +57,22 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
self.respect_context_window = respect_context_window
self.request_within_rpm_limit = request_within_rpm_limit
self.ask_for_human_input = False
self.messages = []
self.messages: List[Dict[str, str]] = []
self.iterations = 0
self.have_forced_answer = False
self.name_to_tool_map = {tool.name: tool for tool in self.tools}
def invoke(self, inputs: Dict[str, str]) -> Dict[str, Any]:
if "system" in self.prompt:
system_prompt = self._format_prompt(self.prompt["system"], inputs)
user_prompt = self._format_prompt(self.prompt["user"], inputs)
system_prompt = self._format_prompt(self.prompt.get("system", ""), inputs)
user_prompt = self._format_prompt(self.prompt.get("user", ""), inputs)
self.messages.append(self._format_msg(system_prompt, role="system"))
self.messages.append(self._format_msg(user_prompt))
else:
user_prompt = self._format_prompt(self.prompt["prompt"], inputs)
user_prompt = self._format_prompt(self.prompt.get("prompt", ""), inputs)
self.messages.append(self._format_msg(user_prompt))
self.ask_for_human_input = inputs.get("ask_for_human_input", False)
self.ask_for_human_input = bool(inputs.get("ask_for_human_input", False))
formatted_answer = self._invoke_loop()
@@ -135,7 +135,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
return formatted_answer
def _use_tool(self, agent_action: AgentAction) -> None:
def _use_tool(self, agent_action: AgentAction) -> Any:
tool_usage = ToolUsage(
tools_handler=self.tools_handler,
tools=self.tools,
@@ -143,7 +143,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
tools_description=self.tools_description,
tools_names=self.tools_names,
function_calling_llm=self.function_calling_llm,
task=self.task,
task=self.task, # type: ignore[arg-type]
agent=self.agent,
action=agent_action,
)
@@ -188,7 +188,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
)
summarized_contents.append(summary)
merged_summary = " ".join(summarized_contents)
merged_summary = " ".join(str(content) for content in summarized_contents)
self.messages = [
self._format_msg(
@@ -226,6 +226,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
):
training_data = CrewTrainingHandler(TRAINING_DATA_FILE).load()
if training_data.get(agent_id):
# type: ignore[union-attr]
training_data[agent_id][self.crew._train_iteration][
"improved_output"
] = result.output
@@ -238,9 +239,17 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
"agent": agent_id,
"agent_role": self.agent.role,
}
CrewTrainingHandler(TRAINING_DATA_FILE).append(
self.crew._train_iteration, agent_id, training_data
)
# type: ignore[union-attr]
if isinstance(self.crew._train_iteration, int):
CrewTrainingHandler(TRAINING_DATA_FILE).append(
self.crew._train_iteration, agent_id, training_data
)
else:
self._logger.log(
"error",
"Invalid train iteration type. Expected int.",
color="red",
)
def _format_prompt(self, prompt: str, inputs: Dict[str, str]) -> str:
prompt = prompt.replace("{input}", inputs["input"])

View File

@@ -201,7 +201,8 @@ class Crew(BaseModel):
self._rpm_controller = RPMController(max_rpm=self.max_rpm, logger=self._logger)
self.function_calling_llm = (
self.function_calling_llm.model_name
if hasattr(self.function_calling_llm, "model_name")
if self.function_calling_llm is not None
and hasattr(self.function_calling_llm, "model_name")
else self.function_calling_llm
)
self._telemetry = Telemetry()

View File

@@ -83,7 +83,7 @@ class TaskEvaluator:
instructions = f"{instructions}\n\nReturn only valid JSON with the following schema:\n```json\n{model_schema}\n```"
converter = Converter(
llm=self.function_calling_llm,
llm=self.llm,
text=evaluation_query,
model=TaskEvaluation,
instructions=instructions,

View File

@@ -11,9 +11,10 @@ class Prompts(BaseModel):
system_template: Optional[str] = None
prompt_template: Optional[str] = None
response_template: Optional[str] = None
use_system_prompt: Optional[bool] = False
agent: Any
def task_execution(self) -> str:
def task_execution(self) -> dict[str, str]:
"""Generate a standard prompt for task execution."""
slices = ["role_playing"]
if len(self.tools) > 0:
@@ -23,7 +24,11 @@ class Prompts(BaseModel):
system = self._build_prompt(slices)
slices.append("task")
if not self.system_template and not self.prompt_template:
if (
not self.system_template
and not self.prompt_template
and self.use_system_prompt
):
return {
"system": system,
"user": self._build_prompt(["task"]),

View File

@@ -856,6 +856,31 @@ def test_interpolate_inputs():
assert agent.backstory == "I am the master of nothing"
def test_not_using_system_prompt():
    """With use_system_prompt=False the executor prompt has no system/user split."""
    agent = Agent(
        role="{topic} specialist",
        goal="Figure {goal} out",
        backstory="I am the master of {role}",
        use_system_prompt=False,
    )
    agent.create_agent_executor()
    # When the system prompt is disabled, the prompt presumably collapses to a
    # single "prompt" entry, so neither "user" nor "system" keys are populated.
    assert not agent.agent_executor.prompt.get("user")
    assert not agent.agent_executor.prompt.get("system")
def test_using_system_prompt():
    """By default (use_system_prompt=True) the prompt splits into system and user parts."""
    agent = Agent(
        role="{topic} specialist",
        goal="Figure {goal} out",
        backstory="I am the master of {role}",
    )
    agent.create_agent_executor()
    # Default behavior: both the "user" and "system" prompt entries are present.
    assert agent.agent_executor.prompt.get("user")
    assert agent.agent_executor.prompt.get("system")
def test_system_and_prompt_template():
agent = Agent(
role="{topic} specialist",