chore: update typing in task and guardrails

This commit is contained in:
Greyson LaLonde
2025-11-30 22:27:53 -05:00
parent 3ce019b07b
commit 008f906f60
2 changed files with 41 additions and 56 deletions

View File

@@ -1,5 +1,6 @@
from __future__ import annotations
from collections.abc import Callable
from concurrent.futures import Future
from copy import copy as shallow_copy
import datetime
@@ -10,6 +11,7 @@ import logging
from pathlib import Path
import threading
from typing import (
TYPE_CHECKING,
Any,
ClassVar,
cast,
@@ -17,11 +19,11 @@ from typing import (
get_origin,
)
import uuid
import warnings
from pydantic import (
UUID4,
BaseModel,
ConfigDict,
Field,
PrivateAttr,
field_validator,
@@ -37,6 +39,7 @@ from crewai.events.types.task_events import (
TaskFailedEvent,
TaskStartedEvent,
)
from crewai.llms.base_llm import BaseLLM
from crewai.security import Fingerprint, SecurityConfig
from crewai.tasks.output_format import OutputFormat
from crewai.tasks.task_output import TaskOutput
@@ -57,6 +60,9 @@ from crewai.utilities.printer import Printer
from crewai.utilities.string_utils import interpolate_only
if TYPE_CHECKING:
from crewai.agent.core import Agent
_printer = Printer()
@@ -101,17 +107,17 @@ class Task(BaseModel):
description="Configuration for the agent",
default=None,
)
callback: Any | None = Field(
callback: Callable[[TaskOutput], None] | None = Field(
description="Callback to be executed after the task is completed.", default=None
)
agent: BaseAgent | None = Field(
agent: Agent | None = Field(
description="Agent responsible for execution the task.", default=None
)
context: list[Task] | None | _NotSpecified = Field(
description="Other tasks that will have their output used as context for this task.",
default=NOT_SPECIFIED,
)
async_execution: bool | None = Field(
async_execution: bool = Field(
description="Whether the task should be executed asynchronously or not.",
default=False,
)
@@ -151,11 +157,11 @@ class Task(BaseModel):
frozen=True,
description="Unique identifier for the object, not set by user.",
)
human_input: bool | None = Field(
human_input: bool = Field(
description="Whether the task should have a human review the final answer of the agent",
default=False,
)
markdown: bool | None = Field(
markdown: bool = Field(
description="Whether the task should instruct the agent to return the final answer formatted in Markdown",
default=False,
)
@@ -172,11 +178,6 @@ class Task(BaseModel):
default=None,
description="List of guardrails to validate task output before proceeding to next task. Also supports a single guardrail function or string description of a guardrail to validate task output before proceeding to next task",
)
max_retries: int | None = Field(
default=None,
description="[DEPRECATED] Maximum number of retries when guardrail fails. Use guardrail_max_retries instead. Will be removed in v1.0.0",
)
guardrail_max_retries: int = Field(
default=3, description="Maximum number of retries when guardrail fails"
)
@@ -187,8 +188,8 @@ class Task(BaseModel):
end_time: datetime.datetime | None = Field(
default=None, description="End time of the task execution"
)
allow_crewai_trigger_context: bool | None = Field(
default=None,
allow_crewai_trigger_context: bool = Field(
default=False,
description="Whether this task should append 'Trigger Payload: {crewai_trigger_payload}' to the task description when crewai_trigger_payload exists in crew inputs.",
)
_guardrail: GuardrailCallable | None = PrivateAttr(default=None)
@@ -202,7 +203,9 @@ class Task(BaseModel):
_original_expected_output: str | None = PrivateAttr(default=None)
_original_output_file: str | None = PrivateAttr(default=None)
_thread: threading.Thread | None = PrivateAttr(default=None)
model_config = {"arbitrary_types_allowed": True}
model_config = ConfigDict(
arbitrary_types_allowed=True,
)
@field_validator("guardrail")
@classmethod
@@ -288,15 +291,16 @@ class Task(BaseModel):
if self.agent is None:
raise ValueError("Agent is required to use LLMGuardrail")
self._guardrail = cast(
GuardrailCallable,
LLMGuardrail(description=self.guardrail, llm=self.agent.llm),
self._guardrail = LLMGuardrail(
description=self.guardrail, llm=cast(BaseLLM, self.agent.llm)
)
return self
@model_validator(mode="after")
def ensure_guardrails_is_list_of_callables(self) -> Task:
from crewai.tasks.llm_guardrail import LLMGuardrail
guardrails = []
if self.guardrails is not None:
if isinstance(self.guardrails, (list, tuple)):
@@ -309,14 +313,11 @@ class Task(BaseModel):
raise ValueError(
"Agent is required to use non-programmatic guardrails"
)
from crewai.tasks.llm_guardrail import LLMGuardrail
guardrails.append(
cast(
GuardrailCallable,
LLMGuardrail(
description=guardrail, llm=self.agent.llm
),
LLMGuardrail(
description=guardrail,
llm=cast(BaseLLM, self.agent.llm),
)
)
else:
@@ -329,14 +330,11 @@ class Task(BaseModel):
raise ValueError(
"Agent is required to use non-programmatic guardrails"
)
from crewai.tasks.llm_guardrail import LLMGuardrail
guardrails.append(
cast(
GuardrailCallable,
LLMGuardrail(
description=self.guardrails, llm=self.agent.llm
),
LLMGuardrail(
description=self.guardrails,
llm=cast(BaseLLM, self.agent.llm),
)
)
else:
@@ -436,21 +434,9 @@ class Task(BaseModel):
)
return self
@model_validator(mode="after")
def handle_max_retries_deprecation(self) -> Self:
if self.max_retries is not None:
warnings.warn(
"The 'max_retries' parameter is deprecated and will be removed in CrewAI v1.0.0. "
"Please use 'guardrail_max_retries' instead.",
DeprecationWarning,
stacklevel=2,
)
self.guardrail_max_retries = self.max_retries
return self
def execute_sync(
self,
agent: BaseAgent | None = None,
agent: Agent | None = None,
context: str | None = None,
tools: list[BaseTool] | None = None,
) -> TaskOutput:
@@ -488,9 +474,9 @@ class Task(BaseModel):
def _execute_task_async(
self,
agent: BaseAgent | None,
agent: Agent | None,
context: str | None,
tools: list[Any] | None,
tools: list[BaseTool] | None,
future: Future[TaskOutput],
) -> None:
"""Execute the task asynchronously with context handling."""
@@ -499,9 +485,9 @@ class Task(BaseModel):
def _execute_core(
self,
agent: BaseAgent | None,
agent: Agent | None,
context: str | None,
tools: list[Any] | None,
tools: list[BaseTool] | None,
) -> TaskOutput:
"""Run the core execution logic of the task."""
try:
@@ -611,8 +597,6 @@ class Task(BaseModel):
if trigger_payload is not None:
description += f"\n\nTrigger Payload: {trigger_payload}"
tasks_slices = [description]
output = self.i18n.slice("expected_output").format(
expected_output=self.expected_output
)
@@ -715,7 +699,7 @@ Follow these guidelines:
self.processed_by_agents.add(agent_name)
self.delegations += 1
def copy( # type: ignore
def copy( # type: ignore[override]
self, agents: list[BaseAgent], task_mapping: dict[str, Task]
) -> Task:
"""Creates a deep copy of the Task while preserving its original class type.
@@ -859,7 +843,7 @@ Follow these guidelines:
def _invoke_guardrail_function(
self,
task_output: TaskOutput,
agent: BaseAgent,
agent: Agent,
tools: list[BaseTool],
guardrail: GuardrailCallable | None,
guardrail_index: int | None = None,

View File

@@ -2,7 +2,7 @@ from typing import Any
from pydantic import BaseModel, Field
from crewai.agent import Agent
from crewai.agent.core import Agent
from crewai.lite_agent_output import LiteAgentOutput
from crewai.llms.base_llm import BaseLLM
from crewai.tasks.task_output import TaskOutput
@@ -38,7 +38,9 @@ class LLMGuardrail:
self.llm: BaseLLM = llm
def _validate_output(self, task_output: TaskOutput) -> LiteAgentOutput:
def _validate_output(
self, task_output: TaskOutput | LiteAgentOutput
) -> LiteAgentOutput:
agent = Agent(
role="Guardrail Agent",
goal="Validate the output of the task",
@@ -64,18 +66,17 @@ class LLMGuardrail:
return agent.kickoff(query, response_format=LLMGuardrailResult)
def __call__(self, task_output: TaskOutput) -> tuple[bool, Any]:
def __call__(self, task_output: TaskOutput | LiteAgentOutput) -> tuple[bool, Any]:
"""Validates the output of a task based on specified criteria.
Args:
task_output (TaskOutput): The output to be validated.
task_output: The output to be validated.
Returns:
Tuple[bool, Any]: A tuple containing:
A tuple containing:
- bool: True if validation passed, False otherwise
- Any: The validation result or error message
"""
try:
result = self._validate_output(task_output)
if not isinstance(result.pydantic, LLMGuardrailResult):