chore: update typing in task and guardrails
@@ -1,5 +1,6 @@
 from __future__ import annotations
 
+from collections.abc import Callable
 from concurrent.futures import Future
 from copy import copy as shallow_copy
 import datetime
@@ -10,6 +11,7 @@ import logging
 from pathlib import Path
 import threading
 from typing import (
+    TYPE_CHECKING,
     Any,
     ClassVar,
     cast,
@@ -17,11 +19,11 @@ from typing import (
     get_origin,
 )
 import uuid
 import warnings
 
 from pydantic import (
     UUID4,
     BaseModel,
+    ConfigDict,
     Field,
     PrivateAttr,
     field_validator,
@@ -37,6 +39,7 @@ from crewai.events.types.task_events import (
     TaskFailedEvent,
     TaskStartedEvent,
 )
+from crewai.llms.base_llm import BaseLLM
 from crewai.security import Fingerprint, SecurityConfig
 from crewai.tasks.output_format import OutputFormat
 from crewai.tasks.task_output import TaskOutput
@@ -57,6 +60,9 @@ from crewai.utilities.printer import Printer
 from crewai.utilities.string_utils import interpolate_only
 
 
+if TYPE_CHECKING:
+    from crewai.agent.core import Agent
+
 _printer = Printer()
 
 
@@ -101,17 +107,17 @@ class Task(BaseModel):
         description="Configuration for the agent",
         default=None,
     )
-    callback: Any | None = Field(
+    callback: Callable[[TaskOutput], None] | None = Field(
         description="Callback to be executed after the task is completed.", default=None
     )
-    agent: BaseAgent | None = Field(
+    agent: Agent | None = Field(
         description="Agent responsible for execution the task.", default=None
     )
     context: list[Task] | None | _NotSpecified = Field(
         description="Other tasks that will have their output used as context for this task.",
         default=NOT_SPECIFIED,
     )
-    async_execution: bool | None = Field(
+    async_execution: bool = Field(
         description="Whether the task should be executed asynchronously or not.",
         default=False,
     )
@@ -151,11 +157,11 @@ class Task(BaseModel):
         frozen=True,
         description="Unique identifier for the object, not set by user.",
     )
-    human_input: bool | None = Field(
+    human_input: bool = Field(
         description="Whether the task should have a human review the final answer of the agent",
         default=False,
     )
-    markdown: bool | None = Field(
+    markdown: bool = Field(
         description="Whether the task should instruct the agent to return the final answer formatted in Markdown",
         default=False,
     )
@@ -172,11 +178,6 @@ class Task(BaseModel):
         default=None,
         description="List of guardrails to validate task output before proceeding to next task. Also supports a single guardrail function or string description of a guardrail to validate task output before proceeding to next task",
     )
-
-    max_retries: int | None = Field(
-        default=None,
-        description="[DEPRECATED] Maximum number of retries when guardrail fails. Use guardrail_max_retries instead. Will be removed in v1.0.0",
-    )
     guardrail_max_retries: int = Field(
         default=3, description="Maximum number of retries when guardrail fails"
     )
@@ -187,8 +188,8 @@ class Task(BaseModel):
     end_time: datetime.datetime | None = Field(
         default=None, description="End time of the task execution"
     )
-    allow_crewai_trigger_context: bool | None = Field(
-        default=None,
+    allow_crewai_trigger_context: bool = Field(
+        default=False,
         description="Whether this task should append 'Trigger Payload: {crewai_trigger_payload}' to the task description when crewai_trigger_payload exists in crew inputs.",
     )
     _guardrail: GuardrailCallable | None = PrivateAttr(default=None)
@@ -202,7 +203,9 @@ class Task(BaseModel):
     _original_expected_output: str | None = PrivateAttr(default=None)
     _original_output_file: str | None = PrivateAttr(default=None)
     _thread: threading.Thread | None = PrivateAttr(default=None)
-    model_config = {"arbitrary_types_allowed": True}
+    model_config = ConfigDict(
+        arbitrary_types_allowed=True,
+    )
 
     @field_validator("guardrail")
     @classmethod
@@ -288,15 +291,16 @@ class Task(BaseModel):
             if self.agent is None:
                 raise ValueError("Agent is required to use LLMGuardrail")
 
-            self._guardrail = cast(
-                GuardrailCallable,
-                LLMGuardrail(description=self.guardrail, llm=self.agent.llm),
+            self._guardrail = LLMGuardrail(
+                description=self.guardrail, llm=cast(BaseLLM, self.agent.llm)
             )
 
         return self
 
     @model_validator(mode="after")
     def ensure_guardrails_is_list_of_callables(self) -> Task:
+        from crewai.tasks.llm_guardrail import LLMGuardrail
+
         guardrails = []
         if self.guardrails is not None:
             if isinstance(self.guardrails, (list, tuple)):
@@ -309,14 +313,11 @@ class Task(BaseModel):
                             raise ValueError(
                                 "Agent is required to use non-programmatic guardrails"
                             )
-                        from crewai.tasks.llm_guardrail import LLMGuardrail
 
                         guardrails.append(
-                            cast(
-                                GuardrailCallable,
-                                LLMGuardrail(
-                                    description=guardrail, llm=self.agent.llm
-                                ),
+                            LLMGuardrail(
+                                description=guardrail,
+                                llm=cast(BaseLLM, self.agent.llm),
                             )
                         )
                     else:
@@ -329,14 +330,11 @@ class Task(BaseModel):
                         raise ValueError(
                             "Agent is required to use non-programmatic guardrails"
                         )
-                    from crewai.tasks.llm_guardrail import LLMGuardrail
 
                     guardrails.append(
-                        cast(
-                            GuardrailCallable,
-                            LLMGuardrail(
-                                description=self.guardrails, llm=self.agent.llm
-                            ),
+                        LLMGuardrail(
+                            description=self.guardrails,
+                            llm=cast(BaseLLM, self.agent.llm),
                         )
                     )
                 else:
@@ -436,21 +434,9 @@ class Task(BaseModel):
             )
         return self
 
-    @model_validator(mode="after")
-    def handle_max_retries_deprecation(self) -> Self:
-        if self.max_retries is not None:
-            warnings.warn(
-                "The 'max_retries' parameter is deprecated and will be removed in CrewAI v1.0.0. "
-                "Please use 'guardrail_max_retries' instead.",
-                DeprecationWarning,
-                stacklevel=2,
-            )
-            self.guardrail_max_retries = self.max_retries
-        return self
-
     def execute_sync(
         self,
-        agent: BaseAgent | None = None,
+        agent: Agent | None = None,
         context: str | None = None,
         tools: list[BaseTool] | None = None,
     ) -> TaskOutput:
@@ -488,9 +474,9 @@ class Task(BaseModel):
 
     def _execute_task_async(
         self,
-        agent: BaseAgent | None,
+        agent: Agent | None,
         context: str | None,
-        tools: list[Any] | None,
+        tools: list[BaseTool] | None,
         future: Future[TaskOutput],
     ) -> None:
         """Execute the task asynchronously with context handling."""
@@ -499,9 +485,9 @@ class Task(BaseModel):
 
     def _execute_core(
         self,
-        agent: BaseAgent | None,
+        agent: Agent | None,
         context: str | None,
-        tools: list[Any] | None,
+        tools: list[BaseTool] | None,
     ) -> TaskOutput:
         """Run the core execution logic of the task."""
         try:
@@ -611,8 +597,6 @@ class Task(BaseModel):
             if trigger_payload is not None:
                 description += f"\n\nTrigger Payload: {trigger_payload}"
 
         tasks_slices = [description]
 
         output = self.i18n.slice("expected_output").format(
             expected_output=self.expected_output
         )
@@ -715,7 +699,7 @@ Follow these guidelines:
         self.processed_by_agents.add(agent_name)
         self.delegations += 1
 
-    def copy(  # type: ignore
+    def copy(  # type: ignore[override]
         self, agents: list[BaseAgent], task_mapping: dict[str, Task]
     ) -> Task:
         """Creates a deep copy of the Task while preserving its original class type.
@@ -859,7 +843,7 @@ Follow these guidelines:
     def _invoke_guardrail_function(
         self,
         task_output: TaskOutput,
-        agent: BaseAgent,
+        agent: Agent,
         tools: list[BaseTool],
         guardrail: GuardrailCallable | None,
         guardrail_index: int | None = None,

@@ -2,7 +2,7 @@ from typing import Any
 
 from pydantic import BaseModel, Field
 
-from crewai.agent import Agent
+from crewai.agent.core import Agent
 from crewai.lite_agent_output import LiteAgentOutput
 from crewai.llms.base_llm import BaseLLM
 from crewai.tasks.task_output import TaskOutput
@@ -38,7 +38,9 @@ class LLMGuardrail:
 
         self.llm: BaseLLM = llm
 
-    def _validate_output(self, task_output: TaskOutput) -> LiteAgentOutput:
+    def _validate_output(
+        self, task_output: TaskOutput | LiteAgentOutput
+    ) -> LiteAgentOutput:
         agent = Agent(
             role="Guardrail Agent",
             goal="Validate the output of the task",
@@ -64,18 +66,17 @@ class LLMGuardrail:
 
         return agent.kickoff(query, response_format=LLMGuardrailResult)
 
-    def __call__(self, task_output: TaskOutput) -> tuple[bool, Any]:
+    def __call__(self, task_output: TaskOutput | LiteAgentOutput) -> tuple[bool, Any]:
         """Validates the output of a task based on specified criteria.
 
         Args:
-            task_output (TaskOutput): The output to be validated.
+            task_output: The output to be validated.
 
         Returns:
-            Tuple[bool, Any]: A tuple containing:
+            A tuple containing:
                 - bool: True if validation passed, False otherwise
                 - Any: The validation result or error message
         """
 
         try:
             result = self._validate_output(task_output)
             if not isinstance(result.pydantic, LLMGuardrailResult):
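
For orientation, here is a minimal usage sketch of the retyped surface: callback is now declared as Callable[[TaskOutput], None] | None, human_input/markdown/async_execution are plain bool fields, and a string passed through guardrails is wrapped in an LLMGuardrail bound to the agent's LLM. Only the field names below come from the diff; the agent definition, task text, and callback body are illustrative placeholders, and constructor details may vary between crewAI versions.

# Illustrative sketch only; field names follow the diff above, all other values are made up.
from crewai import Agent, Task
from crewai.tasks.task_output import TaskOutput


def log_output(output: TaskOutput) -> None:
    """Callback matching the new Callable[[TaskOutput], None] annotation."""
    print(f"Task finished: {output.raw}")


writer = Agent(  # hypothetical agent, not taken from the repository
    role="Writer",
    goal="Summarize the provided notes",
    backstory="Placeholder agent used to exercise the typed Task fields.",
)

task = Task(
    description="Summarize the notes in two paragraphs.",
    expected_output="A two-paragraph summary.",
    agent=writer,              # annotated as Agent | None
    callback=log_output,       # Callable[[TaskOutput], None] | None
    markdown=True,             # bool, no longer bool | None
    guardrails="Keep the summary under 200 words.",  # str is wrapped in an LLMGuardrail
    guardrail_max_retries=2,
)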