From 27952cfb7ad5fce94337c7e0da214ccb15a1b677 Mon Sep 17 00:00:00 2001 From: Lucas Gomide Date: Tue, 29 Apr 2025 18:42:24 -0300 Subject: [PATCH] refactor: drop task paramenter from TaskGuardrail This parameter was used to get the model from the `task.agent` which is a quite bit redudant since we could propagate the llm directly --- src/crewai/task.py | 5 ++++- src/crewai/tasks/task_guardrail.py | 13 ++----------- tests/test_task_guardrails.py | 8 ++++++-- 3 files changed, 12 insertions(+), 14 deletions(-) diff --git a/src/crewai/task.py b/src/crewai/task.py index 0214ff8da..8a1091935 100644 --- a/src/crewai/task.py +++ b/src/crewai/task.py @@ -243,7 +243,10 @@ class Task(BaseModel): elif isinstance(self.guardrail, str): from crewai.tasks.task_guardrail import TaskGuardrail - self._guardrail = TaskGuardrail(description=self.guardrail, task=self) + assert self.agent is not None + self._guardrail = TaskGuardrail( + description=self.guardrail, llm=self.agent.llm + ) return self diff --git a/src/crewai/tasks/task_guardrail.py b/src/crewai/tasks/task_guardrail.py index 19879167c..dcaa47f5c 100644 --- a/src/crewai/tasks/task_guardrail.py +++ b/src/crewai/tasks/task_guardrail.py @@ -33,20 +33,11 @@ class TaskGuardrail: def __init__( self, description: str, - task: Task | None = None, - llm: LLM | None = None, + llm: LLM, ): self.description = description - fallback_llm: LLM | None = ( - task.agent.llm - if task is not None - and hasattr(task, "agent") - and task.agent is not None - and hasattr(task.agent, "llm") - else None - ) - self.llm: LLM | None = llm or fallback_llm + self.llm: LLM = llm def _validate_output(self, task_output: TaskOutput) -> LiteAgentOutput: agent = Agent( diff --git a/tests/test_task_guardrails.py b/tests/test_task_guardrails.py index df52e9a93..c4224f9cc 100644 --- a/tests/test_task_guardrails.py +++ b/tests/test_task_guardrails.py @@ -153,14 +153,18 @@ def task_output(): @pytest.mark.vcr(filter_headers=["authorization"]) def test_task_guardrail_process_output(task_output): - guardrail = TaskGuardrail(description="Ensure the result has less than 10 words") + guardrail = TaskGuardrail( + description="Ensure the result has less than 10 words", llm=LLM(model="gpt-4o") + ) result = guardrail(task_output) assert result[0] is False assert "exceeding the guardrail limit of fewer than" in result[1].lower() - guardrail = TaskGuardrail(description="Ensure the result has less than 500 words") + guardrail = TaskGuardrail( + description="Ensure the result has less than 500 words", llm=LLM(model="gpt-4o") + ) result = guardrail(task_output) assert result[0] is True