mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-11 00:58:30 +00:00
refactor: drop task paramenter from TaskGuardrail
This parameter was used to get the model from the `task.agent` which is a quite bit redudant since we could propagate the llm directly
This commit is contained in:
@@ -243,7 +243,10 @@ class Task(BaseModel):
|
|||||||
elif isinstance(self.guardrail, str):
|
elif isinstance(self.guardrail, str):
|
||||||
from crewai.tasks.task_guardrail import TaskGuardrail
|
from crewai.tasks.task_guardrail import TaskGuardrail
|
||||||
|
|
||||||
self._guardrail = TaskGuardrail(description=self.guardrail, task=self)
|
assert self.agent is not None
|
||||||
|
self._guardrail = TaskGuardrail(
|
||||||
|
description=self.guardrail, llm=self.agent.llm
|
||||||
|
)
|
||||||
|
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
|||||||
@@ -33,20 +33,11 @@ class TaskGuardrail:
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
description: str,
|
description: str,
|
||||||
task: Task | None = None,
|
llm: LLM,
|
||||||
llm: LLM | None = None,
|
|
||||||
):
|
):
|
||||||
self.description = description
|
self.description = description
|
||||||
|
|
||||||
fallback_llm: LLM | None = (
|
self.llm: LLM = llm
|
||||||
task.agent.llm
|
|
||||||
if task is not None
|
|
||||||
and hasattr(task, "agent")
|
|
||||||
and task.agent is not None
|
|
||||||
and hasattr(task.agent, "llm")
|
|
||||||
else None
|
|
||||||
)
|
|
||||||
self.llm: LLM | None = llm or fallback_llm
|
|
||||||
|
|
||||||
def _validate_output(self, task_output: TaskOutput) -> LiteAgentOutput:
|
def _validate_output(self, task_output: TaskOutput) -> LiteAgentOutput:
|
||||||
agent = Agent(
|
agent = Agent(
|
||||||
|
|||||||
@@ -153,14 +153,18 @@ def task_output():
|
|||||||
|
|
||||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||||
def test_task_guardrail_process_output(task_output):
|
def test_task_guardrail_process_output(task_output):
|
||||||
guardrail = TaskGuardrail(description="Ensure the result has less than 10 words")
|
guardrail = TaskGuardrail(
|
||||||
|
description="Ensure the result has less than 10 words", llm=LLM(model="gpt-4o")
|
||||||
|
)
|
||||||
|
|
||||||
result = guardrail(task_output)
|
result = guardrail(task_output)
|
||||||
assert result[0] is False
|
assert result[0] is False
|
||||||
|
|
||||||
assert "exceeding the guardrail limit of fewer than" in result[1].lower()
|
assert "exceeding the guardrail limit of fewer than" in result[1].lower()
|
||||||
|
|
||||||
guardrail = TaskGuardrail(description="Ensure the result has less than 500 words")
|
guardrail = TaskGuardrail(
|
||||||
|
description="Ensure the result has less than 500 words", llm=LLM(model="gpt-4o")
|
||||||
|
)
|
||||||
|
|
||||||
result = guardrail(task_output)
|
result = guardrail(task_output)
|
||||||
assert result[0] is True
|
assert result[0] is True
|
||||||
|
|||||||
Reference in New Issue
Block a user