mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-08 23:58:34 +00:00
chore: update typing in task and guardrails
This commit is contained in:
@@ -1,5 +1,6 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from collections.abc import Callable
|
||||||
from concurrent.futures import Future
|
from concurrent.futures import Future
|
||||||
from copy import copy as shallow_copy
|
from copy import copy as shallow_copy
|
||||||
import datetime
|
import datetime
|
||||||
@@ -10,6 +11,7 @@ import logging
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
import threading
|
import threading
|
||||||
from typing import (
|
from typing import (
|
||||||
|
TYPE_CHECKING,
|
||||||
Any,
|
Any,
|
||||||
ClassVar,
|
ClassVar,
|
||||||
cast,
|
cast,
|
||||||
@@ -17,11 +19,11 @@ from typing import (
|
|||||||
get_origin,
|
get_origin,
|
||||||
)
|
)
|
||||||
import uuid
|
import uuid
|
||||||
import warnings
|
|
||||||
|
|
||||||
from pydantic import (
|
from pydantic import (
|
||||||
UUID4,
|
UUID4,
|
||||||
BaseModel,
|
BaseModel,
|
||||||
|
ConfigDict,
|
||||||
Field,
|
Field,
|
||||||
PrivateAttr,
|
PrivateAttr,
|
||||||
field_validator,
|
field_validator,
|
||||||
@@ -37,6 +39,7 @@ from crewai.events.types.task_events import (
|
|||||||
TaskFailedEvent,
|
TaskFailedEvent,
|
||||||
TaskStartedEvent,
|
TaskStartedEvent,
|
||||||
)
|
)
|
||||||
|
from crewai.llms.base_llm import BaseLLM
|
||||||
from crewai.security import Fingerprint, SecurityConfig
|
from crewai.security import Fingerprint, SecurityConfig
|
||||||
from crewai.tasks.output_format import OutputFormat
|
from crewai.tasks.output_format import OutputFormat
|
||||||
from crewai.tasks.task_output import TaskOutput
|
from crewai.tasks.task_output import TaskOutput
|
||||||
@@ -57,6 +60,9 @@ from crewai.utilities.printer import Printer
|
|||||||
from crewai.utilities.string_utils import interpolate_only
|
from crewai.utilities.string_utils import interpolate_only
|
||||||
|
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from crewai.agent.core import Agent
|
||||||
|
|
||||||
_printer = Printer()
|
_printer = Printer()
|
||||||
|
|
||||||
|
|
||||||
@@ -101,17 +107,17 @@ class Task(BaseModel):
|
|||||||
description="Configuration for the agent",
|
description="Configuration for the agent",
|
||||||
default=None,
|
default=None,
|
||||||
)
|
)
|
||||||
callback: Any | None = Field(
|
callback: Callable[[TaskOutput], None] | None = Field(
|
||||||
description="Callback to be executed after the task is completed.", default=None
|
description="Callback to be executed after the task is completed.", default=None
|
||||||
)
|
)
|
||||||
agent: BaseAgent | None = Field(
|
agent: Agent | None = Field(
|
||||||
description="Agent responsible for execution the task.", default=None
|
description="Agent responsible for execution the task.", default=None
|
||||||
)
|
)
|
||||||
context: list[Task] | None | _NotSpecified = Field(
|
context: list[Task] | None | _NotSpecified = Field(
|
||||||
description="Other tasks that will have their output used as context for this task.",
|
description="Other tasks that will have their output used as context for this task.",
|
||||||
default=NOT_SPECIFIED,
|
default=NOT_SPECIFIED,
|
||||||
)
|
)
|
||||||
async_execution: bool | None = Field(
|
async_execution: bool = Field(
|
||||||
description="Whether the task should be executed asynchronously or not.",
|
description="Whether the task should be executed asynchronously or not.",
|
||||||
default=False,
|
default=False,
|
||||||
)
|
)
|
||||||
@@ -151,11 +157,11 @@ class Task(BaseModel):
|
|||||||
frozen=True,
|
frozen=True,
|
||||||
description="Unique identifier for the object, not set by user.",
|
description="Unique identifier for the object, not set by user.",
|
||||||
)
|
)
|
||||||
human_input: bool | None = Field(
|
human_input: bool = Field(
|
||||||
description="Whether the task should have a human review the final answer of the agent",
|
description="Whether the task should have a human review the final answer of the agent",
|
||||||
default=False,
|
default=False,
|
||||||
)
|
)
|
||||||
markdown: bool | None = Field(
|
markdown: bool = Field(
|
||||||
description="Whether the task should instruct the agent to return the final answer formatted in Markdown",
|
description="Whether the task should instruct the agent to return the final answer formatted in Markdown",
|
||||||
default=False,
|
default=False,
|
||||||
)
|
)
|
||||||
@@ -172,11 +178,6 @@ class Task(BaseModel):
|
|||||||
default=None,
|
default=None,
|
||||||
description="List of guardrails to validate task output before proceeding to next task. Also supports a single guardrail function or string description of a guardrail to validate task output before proceeding to next task",
|
description="List of guardrails to validate task output before proceeding to next task. Also supports a single guardrail function or string description of a guardrail to validate task output before proceeding to next task",
|
||||||
)
|
)
|
||||||
|
|
||||||
max_retries: int | None = Field(
|
|
||||||
default=None,
|
|
||||||
description="[DEPRECATED] Maximum number of retries when guardrail fails. Use guardrail_max_retries instead. Will be removed in v1.0.0",
|
|
||||||
)
|
|
||||||
guardrail_max_retries: int = Field(
|
guardrail_max_retries: int = Field(
|
||||||
default=3, description="Maximum number of retries when guardrail fails"
|
default=3, description="Maximum number of retries when guardrail fails"
|
||||||
)
|
)
|
||||||
@@ -187,8 +188,8 @@ class Task(BaseModel):
|
|||||||
end_time: datetime.datetime | None = Field(
|
end_time: datetime.datetime | None = Field(
|
||||||
default=None, description="End time of the task execution"
|
default=None, description="End time of the task execution"
|
||||||
)
|
)
|
||||||
allow_crewai_trigger_context: bool | None = Field(
|
allow_crewai_trigger_context: bool = Field(
|
||||||
default=None,
|
default=False,
|
||||||
description="Whether this task should append 'Trigger Payload: {crewai_trigger_payload}' to the task description when crewai_trigger_payload exists in crew inputs.",
|
description="Whether this task should append 'Trigger Payload: {crewai_trigger_payload}' to the task description when crewai_trigger_payload exists in crew inputs.",
|
||||||
)
|
)
|
||||||
_guardrail: GuardrailCallable | None = PrivateAttr(default=None)
|
_guardrail: GuardrailCallable | None = PrivateAttr(default=None)
|
||||||
@@ -202,7 +203,9 @@ class Task(BaseModel):
|
|||||||
_original_expected_output: str | None = PrivateAttr(default=None)
|
_original_expected_output: str | None = PrivateAttr(default=None)
|
||||||
_original_output_file: str | None = PrivateAttr(default=None)
|
_original_output_file: str | None = PrivateAttr(default=None)
|
||||||
_thread: threading.Thread | None = PrivateAttr(default=None)
|
_thread: threading.Thread | None = PrivateAttr(default=None)
|
||||||
model_config = {"arbitrary_types_allowed": True}
|
model_config = ConfigDict(
|
||||||
|
arbitrary_types_allowed=True,
|
||||||
|
)
|
||||||
|
|
||||||
@field_validator("guardrail")
|
@field_validator("guardrail")
|
||||||
@classmethod
|
@classmethod
|
||||||
@@ -288,15 +291,16 @@ class Task(BaseModel):
|
|||||||
if self.agent is None:
|
if self.agent is None:
|
||||||
raise ValueError("Agent is required to use LLMGuardrail")
|
raise ValueError("Agent is required to use LLMGuardrail")
|
||||||
|
|
||||||
self._guardrail = cast(
|
self._guardrail = LLMGuardrail(
|
||||||
GuardrailCallable,
|
description=self.guardrail, llm=cast(BaseLLM, self.agent.llm)
|
||||||
LLMGuardrail(description=self.guardrail, llm=self.agent.llm),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
return self
|
return self
|
||||||
|
|
||||||
@model_validator(mode="after")
|
@model_validator(mode="after")
|
||||||
def ensure_guardrails_is_list_of_callables(self) -> Task:
|
def ensure_guardrails_is_list_of_callables(self) -> Task:
|
||||||
|
from crewai.tasks.llm_guardrail import LLMGuardrail
|
||||||
|
|
||||||
guardrails = []
|
guardrails = []
|
||||||
if self.guardrails is not None:
|
if self.guardrails is not None:
|
||||||
if isinstance(self.guardrails, (list, tuple)):
|
if isinstance(self.guardrails, (list, tuple)):
|
||||||
@@ -309,14 +313,11 @@ class Task(BaseModel):
|
|||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Agent is required to use non-programmatic guardrails"
|
"Agent is required to use non-programmatic guardrails"
|
||||||
)
|
)
|
||||||
from crewai.tasks.llm_guardrail import LLMGuardrail
|
|
||||||
|
|
||||||
guardrails.append(
|
guardrails.append(
|
||||||
cast(
|
LLMGuardrail(
|
||||||
GuardrailCallable,
|
description=guardrail,
|
||||||
LLMGuardrail(
|
llm=cast(BaseLLM, self.agent.llm),
|
||||||
description=guardrail, llm=self.agent.llm
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
@@ -329,14 +330,11 @@ class Task(BaseModel):
|
|||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Agent is required to use non-programmatic guardrails"
|
"Agent is required to use non-programmatic guardrails"
|
||||||
)
|
)
|
||||||
from crewai.tasks.llm_guardrail import LLMGuardrail
|
|
||||||
|
|
||||||
guardrails.append(
|
guardrails.append(
|
||||||
cast(
|
LLMGuardrail(
|
||||||
GuardrailCallable,
|
description=self.guardrails,
|
||||||
LLMGuardrail(
|
llm=cast(BaseLLM, self.agent.llm),
|
||||||
description=self.guardrails, llm=self.agent.llm
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
@@ -436,21 +434,9 @@ class Task(BaseModel):
|
|||||||
)
|
)
|
||||||
return self
|
return self
|
||||||
|
|
||||||
@model_validator(mode="after")
|
|
||||||
def handle_max_retries_deprecation(self) -> Self:
|
|
||||||
if self.max_retries is not None:
|
|
||||||
warnings.warn(
|
|
||||||
"The 'max_retries' parameter is deprecated and will be removed in CrewAI v1.0.0. "
|
|
||||||
"Please use 'guardrail_max_retries' instead.",
|
|
||||||
DeprecationWarning,
|
|
||||||
stacklevel=2,
|
|
||||||
)
|
|
||||||
self.guardrail_max_retries = self.max_retries
|
|
||||||
return self
|
|
||||||
|
|
||||||
def execute_sync(
|
def execute_sync(
|
||||||
self,
|
self,
|
||||||
agent: BaseAgent | None = None,
|
agent: Agent | None = None,
|
||||||
context: str | None = None,
|
context: str | None = None,
|
||||||
tools: list[BaseTool] | None = None,
|
tools: list[BaseTool] | None = None,
|
||||||
) -> TaskOutput:
|
) -> TaskOutput:
|
||||||
@@ -488,9 +474,9 @@ class Task(BaseModel):
|
|||||||
|
|
||||||
def _execute_task_async(
|
def _execute_task_async(
|
||||||
self,
|
self,
|
||||||
agent: BaseAgent | None,
|
agent: Agent | None,
|
||||||
context: str | None,
|
context: str | None,
|
||||||
tools: list[Any] | None,
|
tools: list[BaseTool] | None,
|
||||||
future: Future[TaskOutput],
|
future: Future[TaskOutput],
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Execute the task asynchronously with context handling."""
|
"""Execute the task asynchronously with context handling."""
|
||||||
@@ -499,9 +485,9 @@ class Task(BaseModel):
|
|||||||
|
|
||||||
def _execute_core(
|
def _execute_core(
|
||||||
self,
|
self,
|
||||||
agent: BaseAgent | None,
|
agent: Agent | None,
|
||||||
context: str | None,
|
context: str | None,
|
||||||
tools: list[Any] | None,
|
tools: list[BaseTool] | None,
|
||||||
) -> TaskOutput:
|
) -> TaskOutput:
|
||||||
"""Run the core execution logic of the task."""
|
"""Run the core execution logic of the task."""
|
||||||
try:
|
try:
|
||||||
@@ -611,8 +597,6 @@ class Task(BaseModel):
|
|||||||
if trigger_payload is not None:
|
if trigger_payload is not None:
|
||||||
description += f"\n\nTrigger Payload: {trigger_payload}"
|
description += f"\n\nTrigger Payload: {trigger_payload}"
|
||||||
|
|
||||||
tasks_slices = [description]
|
|
||||||
|
|
||||||
output = self.i18n.slice("expected_output").format(
|
output = self.i18n.slice("expected_output").format(
|
||||||
expected_output=self.expected_output
|
expected_output=self.expected_output
|
||||||
)
|
)
|
||||||
@@ -715,7 +699,7 @@ Follow these guidelines:
|
|||||||
self.processed_by_agents.add(agent_name)
|
self.processed_by_agents.add(agent_name)
|
||||||
self.delegations += 1
|
self.delegations += 1
|
||||||
|
|
||||||
def copy( # type: ignore
|
def copy( # type: ignore[override]
|
||||||
self, agents: list[BaseAgent], task_mapping: dict[str, Task]
|
self, agents: list[BaseAgent], task_mapping: dict[str, Task]
|
||||||
) -> Task:
|
) -> Task:
|
||||||
"""Creates a deep copy of the Task while preserving its original class type.
|
"""Creates a deep copy of the Task while preserving its original class type.
|
||||||
@@ -859,7 +843,7 @@ Follow these guidelines:
|
|||||||
def _invoke_guardrail_function(
|
def _invoke_guardrail_function(
|
||||||
self,
|
self,
|
||||||
task_output: TaskOutput,
|
task_output: TaskOutput,
|
||||||
agent: BaseAgent,
|
agent: Agent,
|
||||||
tools: list[BaseTool],
|
tools: list[BaseTool],
|
||||||
guardrail: GuardrailCallable | None,
|
guardrail: GuardrailCallable | None,
|
||||||
guardrail_index: int | None = None,
|
guardrail_index: int | None = None,
|
||||||
|
|||||||
@@ -2,7 +2,7 @@ from typing import Any
|
|||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
from crewai.agent import Agent
|
from crewai.agent.core import Agent
|
||||||
from crewai.lite_agent_output import LiteAgentOutput
|
from crewai.lite_agent_output import LiteAgentOutput
|
||||||
from crewai.llms.base_llm import BaseLLM
|
from crewai.llms.base_llm import BaseLLM
|
||||||
from crewai.tasks.task_output import TaskOutput
|
from crewai.tasks.task_output import TaskOutput
|
||||||
@@ -38,7 +38,9 @@ class LLMGuardrail:
|
|||||||
|
|
||||||
self.llm: BaseLLM = llm
|
self.llm: BaseLLM = llm
|
||||||
|
|
||||||
def _validate_output(self, task_output: TaskOutput) -> LiteAgentOutput:
|
def _validate_output(
|
||||||
|
self, task_output: TaskOutput | LiteAgentOutput
|
||||||
|
) -> LiteAgentOutput:
|
||||||
agent = Agent(
|
agent = Agent(
|
||||||
role="Guardrail Agent",
|
role="Guardrail Agent",
|
||||||
goal="Validate the output of the task",
|
goal="Validate the output of the task",
|
||||||
@@ -64,18 +66,17 @@ class LLMGuardrail:
|
|||||||
|
|
||||||
return agent.kickoff(query, response_format=LLMGuardrailResult)
|
return agent.kickoff(query, response_format=LLMGuardrailResult)
|
||||||
|
|
||||||
def __call__(self, task_output: TaskOutput) -> tuple[bool, Any]:
|
def __call__(self, task_output: TaskOutput | LiteAgentOutput) -> tuple[bool, Any]:
|
||||||
"""Validates the output of a task based on specified criteria.
|
"""Validates the output of a task based on specified criteria.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
task_output (TaskOutput): The output to be validated.
|
task_output: The output to be validated.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Tuple[bool, Any]: A tuple containing:
|
A tuple containing:
|
||||||
- bool: True if validation passed, False otherwise
|
- bool: True if validation passed, False otherwise
|
||||||
- Any: The validation result or error message
|
- Any: The validation result or error message
|
||||||
"""
|
"""
|
||||||
|
|
||||||
try:
|
try:
|
||||||
result = self._validate_output(task_output)
|
result = self._validate_output(task_output)
|
||||||
if not isinstance(result.pydantic, LLMGuardrailResult):
|
if not isinstance(result.pydantic, LLMGuardrailResult):
|
||||||
|
|||||||
Reference in New Issue
Block a user