chore: update typing in task and guardrails

This commit is contained in:
Greyson LaLonde
2025-11-30 22:27:53 -05:00
parent 3ce019b07b
commit 008f906f60
2 changed files with 41 additions and 56 deletions

View File

@@ -1,5 +1,6 @@
from __future__ import annotations from __future__ import annotations
from collections.abc import Callable
from concurrent.futures import Future from concurrent.futures import Future
from copy import copy as shallow_copy from copy import copy as shallow_copy
import datetime import datetime
@@ -10,6 +11,7 @@ import logging
from pathlib import Path from pathlib import Path
import threading import threading
from typing import ( from typing import (
TYPE_CHECKING,
Any, Any,
ClassVar, ClassVar,
cast, cast,
@@ -17,11 +19,11 @@ from typing import (
get_origin, get_origin,
) )
import uuid import uuid
import warnings
from pydantic import ( from pydantic import (
UUID4, UUID4,
BaseModel, BaseModel,
ConfigDict,
Field, Field,
PrivateAttr, PrivateAttr,
field_validator, field_validator,
@@ -37,6 +39,7 @@ from crewai.events.types.task_events import (
TaskFailedEvent, TaskFailedEvent,
TaskStartedEvent, TaskStartedEvent,
) )
from crewai.llms.base_llm import BaseLLM
from crewai.security import Fingerprint, SecurityConfig from crewai.security import Fingerprint, SecurityConfig
from crewai.tasks.output_format import OutputFormat from crewai.tasks.output_format import OutputFormat
from crewai.tasks.task_output import TaskOutput from crewai.tasks.task_output import TaskOutput
@@ -57,6 +60,9 @@ from crewai.utilities.printer import Printer
from crewai.utilities.string_utils import interpolate_only from crewai.utilities.string_utils import interpolate_only
if TYPE_CHECKING:
from crewai.agent.core import Agent
_printer = Printer() _printer = Printer()
@@ -101,17 +107,17 @@ class Task(BaseModel):
description="Configuration for the agent", description="Configuration for the agent",
default=None, default=None,
) )
callback: Any | None = Field( callback: Callable[[TaskOutput], None] | None = Field(
description="Callback to be executed after the task is completed.", default=None description="Callback to be executed after the task is completed.", default=None
) )
agent: BaseAgent | None = Field( agent: Agent | None = Field(
description="Agent responsible for execution the task.", default=None description="Agent responsible for execution the task.", default=None
) )
context: list[Task] | None | _NotSpecified = Field( context: list[Task] | None | _NotSpecified = Field(
description="Other tasks that will have their output used as context for this task.", description="Other tasks that will have their output used as context for this task.",
default=NOT_SPECIFIED, default=NOT_SPECIFIED,
) )
async_execution: bool | None = Field( async_execution: bool = Field(
description="Whether the task should be executed asynchronously or not.", description="Whether the task should be executed asynchronously or not.",
default=False, default=False,
) )
@@ -151,11 +157,11 @@ class Task(BaseModel):
frozen=True, frozen=True,
description="Unique identifier for the object, not set by user.", description="Unique identifier for the object, not set by user.",
) )
human_input: bool | None = Field( human_input: bool = Field(
description="Whether the task should have a human review the final answer of the agent", description="Whether the task should have a human review the final answer of the agent",
default=False, default=False,
) )
markdown: bool | None = Field( markdown: bool = Field(
description="Whether the task should instruct the agent to return the final answer formatted in Markdown", description="Whether the task should instruct the agent to return the final answer formatted in Markdown",
default=False, default=False,
) )
@@ -172,11 +178,6 @@ class Task(BaseModel):
default=None, default=None,
description="List of guardrails to validate task output before proceeding to next task. Also supports a single guardrail function or string description of a guardrail to validate task output before proceeding to next task", description="List of guardrails to validate task output before proceeding to next task. Also supports a single guardrail function or string description of a guardrail to validate task output before proceeding to next task",
) )
max_retries: int | None = Field(
default=None,
description="[DEPRECATED] Maximum number of retries when guardrail fails. Use guardrail_max_retries instead. Will be removed in v1.0.0",
)
guardrail_max_retries: int = Field( guardrail_max_retries: int = Field(
default=3, description="Maximum number of retries when guardrail fails" default=3, description="Maximum number of retries when guardrail fails"
) )
@@ -187,8 +188,8 @@ class Task(BaseModel):
end_time: datetime.datetime | None = Field( end_time: datetime.datetime | None = Field(
default=None, description="End time of the task execution" default=None, description="End time of the task execution"
) )
allow_crewai_trigger_context: bool | None = Field( allow_crewai_trigger_context: bool = Field(
default=None, default=False,
description="Whether this task should append 'Trigger Payload: {crewai_trigger_payload}' to the task description when crewai_trigger_payload exists in crew inputs.", description="Whether this task should append 'Trigger Payload: {crewai_trigger_payload}' to the task description when crewai_trigger_payload exists in crew inputs.",
) )
_guardrail: GuardrailCallable | None = PrivateAttr(default=None) _guardrail: GuardrailCallable | None = PrivateAttr(default=None)
@@ -202,7 +203,9 @@ class Task(BaseModel):
_original_expected_output: str | None = PrivateAttr(default=None) _original_expected_output: str | None = PrivateAttr(default=None)
_original_output_file: str | None = PrivateAttr(default=None) _original_output_file: str | None = PrivateAttr(default=None)
_thread: threading.Thread | None = PrivateAttr(default=None) _thread: threading.Thread | None = PrivateAttr(default=None)
model_config = {"arbitrary_types_allowed": True} model_config = ConfigDict(
arbitrary_types_allowed=True,
)
@field_validator("guardrail") @field_validator("guardrail")
@classmethod @classmethod
@@ -288,15 +291,16 @@ class Task(BaseModel):
if self.agent is None: if self.agent is None:
raise ValueError("Agent is required to use LLMGuardrail") raise ValueError("Agent is required to use LLMGuardrail")
self._guardrail = cast( self._guardrail = LLMGuardrail(
GuardrailCallable, description=self.guardrail, llm=cast(BaseLLM, self.agent.llm)
LLMGuardrail(description=self.guardrail, llm=self.agent.llm),
) )
return self return self
@model_validator(mode="after") @model_validator(mode="after")
def ensure_guardrails_is_list_of_callables(self) -> Task: def ensure_guardrails_is_list_of_callables(self) -> Task:
from crewai.tasks.llm_guardrail import LLMGuardrail
guardrails = [] guardrails = []
if self.guardrails is not None: if self.guardrails is not None:
if isinstance(self.guardrails, (list, tuple)): if isinstance(self.guardrails, (list, tuple)):
@@ -309,14 +313,11 @@ class Task(BaseModel):
raise ValueError( raise ValueError(
"Agent is required to use non-programmatic guardrails" "Agent is required to use non-programmatic guardrails"
) )
from crewai.tasks.llm_guardrail import LLMGuardrail
guardrails.append( guardrails.append(
cast( LLMGuardrail(
GuardrailCallable, description=guardrail,
LLMGuardrail( llm=cast(BaseLLM, self.agent.llm),
description=guardrail, llm=self.agent.llm
),
) )
) )
else: else:
@@ -329,14 +330,11 @@ class Task(BaseModel):
raise ValueError( raise ValueError(
"Agent is required to use non-programmatic guardrails" "Agent is required to use non-programmatic guardrails"
) )
from crewai.tasks.llm_guardrail import LLMGuardrail
guardrails.append( guardrails.append(
cast( LLMGuardrail(
GuardrailCallable, description=self.guardrails,
LLMGuardrail( llm=cast(BaseLLM, self.agent.llm),
description=self.guardrails, llm=self.agent.llm
),
) )
) )
else: else:
@@ -436,21 +434,9 @@ class Task(BaseModel):
) )
return self return self
@model_validator(mode="after")
def handle_max_retries_deprecation(self) -> Self:
if self.max_retries is not None:
warnings.warn(
"The 'max_retries' parameter is deprecated and will be removed in CrewAI v1.0.0. "
"Please use 'guardrail_max_retries' instead.",
DeprecationWarning,
stacklevel=2,
)
self.guardrail_max_retries = self.max_retries
return self
def execute_sync( def execute_sync(
self, self,
agent: BaseAgent | None = None, agent: Agent | None = None,
context: str | None = None, context: str | None = None,
tools: list[BaseTool] | None = None, tools: list[BaseTool] | None = None,
) -> TaskOutput: ) -> TaskOutput:
@@ -488,9 +474,9 @@ class Task(BaseModel):
def _execute_task_async( def _execute_task_async(
self, self,
agent: BaseAgent | None, agent: Agent | None,
context: str | None, context: str | None,
tools: list[Any] | None, tools: list[BaseTool] | None,
future: Future[TaskOutput], future: Future[TaskOutput],
) -> None: ) -> None:
"""Execute the task asynchronously with context handling.""" """Execute the task asynchronously with context handling."""
@@ -499,9 +485,9 @@ class Task(BaseModel):
def _execute_core( def _execute_core(
self, self,
agent: BaseAgent | None, agent: Agent | None,
context: str | None, context: str | None,
tools: list[Any] | None, tools: list[BaseTool] | None,
) -> TaskOutput: ) -> TaskOutput:
"""Run the core execution logic of the task.""" """Run the core execution logic of the task."""
try: try:
@@ -611,8 +597,6 @@ class Task(BaseModel):
if trigger_payload is not None: if trigger_payload is not None:
description += f"\n\nTrigger Payload: {trigger_payload}" description += f"\n\nTrigger Payload: {trigger_payload}"
tasks_slices = [description]
output = self.i18n.slice("expected_output").format( output = self.i18n.slice("expected_output").format(
expected_output=self.expected_output expected_output=self.expected_output
) )
@@ -715,7 +699,7 @@ Follow these guidelines:
self.processed_by_agents.add(agent_name) self.processed_by_agents.add(agent_name)
self.delegations += 1 self.delegations += 1
def copy( # type: ignore def copy( # type: ignore[override]
self, agents: list[BaseAgent], task_mapping: dict[str, Task] self, agents: list[BaseAgent], task_mapping: dict[str, Task]
) -> Task: ) -> Task:
"""Creates a deep copy of the Task while preserving its original class type. """Creates a deep copy of the Task while preserving its original class type.
@@ -859,7 +843,7 @@ Follow these guidelines:
def _invoke_guardrail_function( def _invoke_guardrail_function(
self, self,
task_output: TaskOutput, task_output: TaskOutput,
agent: BaseAgent, agent: Agent,
tools: list[BaseTool], tools: list[BaseTool],
guardrail: GuardrailCallable | None, guardrail: GuardrailCallable | None,
guardrail_index: int | None = None, guardrail_index: int | None = None,

View File

@@ -2,7 +2,7 @@ from typing import Any
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from crewai.agent import Agent from crewai.agent.core import Agent
from crewai.lite_agent_output import LiteAgentOutput from crewai.lite_agent_output import LiteAgentOutput
from crewai.llms.base_llm import BaseLLM from crewai.llms.base_llm import BaseLLM
from crewai.tasks.task_output import TaskOutput from crewai.tasks.task_output import TaskOutput
@@ -38,7 +38,9 @@ class LLMGuardrail:
self.llm: BaseLLM = llm self.llm: BaseLLM = llm
def _validate_output(self, task_output: TaskOutput) -> LiteAgentOutput: def _validate_output(
self, task_output: TaskOutput | LiteAgentOutput
) -> LiteAgentOutput:
agent = Agent( agent = Agent(
role="Guardrail Agent", role="Guardrail Agent",
goal="Validate the output of the task", goal="Validate the output of the task",
@@ -64,18 +66,17 @@ class LLMGuardrail:
return agent.kickoff(query, response_format=LLMGuardrailResult) return agent.kickoff(query, response_format=LLMGuardrailResult)
def __call__(self, task_output: TaskOutput) -> tuple[bool, Any]: def __call__(self, task_output: TaskOutput | LiteAgentOutput) -> tuple[bool, Any]:
"""Validates the output of a task based on specified criteria. """Validates the output of a task based on specified criteria.
Args: Args:
task_output (TaskOutput): The output to be validated. task_output: The output to be validated.
Returns: Returns:
Tuple[bool, Any]: A tuple containing: A tuple containing:
- bool: True if validation passed, False otherwise - bool: True if validation passed, False otherwise
- Any: The validation result or error message - Any: The validation result or error message
""" """
try: try:
result = self._validate_output(task_output) result = self._validate_output(task_output)
if not isinstance(result.pydantic, LLMGuardrailResult): if not isinstance(result.pydantic, LLMGuardrailResult):