fix: renaming TaskGuardrail to LLMGuardrail (#2731)

This commit is contained in:
Lucas Gomide
2025-04-30 14:11:35 -03:00
committed by GitHub
parent bc24bc64cd
commit d348d5f20e
7 changed files with 36 additions and 36 deletions

View File

@@ -241,10 +241,10 @@ class Task(BaseModel):
if callable(self.guardrail):
self._guardrail = self.guardrail
elif isinstance(self.guardrail, str):
from crewai.tasks.task_guardrail import TaskGuardrail
from crewai.tasks.llm_guardrail import LLMGuardrail
assert self.agent is not None
self._guardrail = TaskGuardrail(
self._guardrail = LLMGuardrail(
description=self.guardrail, llm=self.agent.llm
)
@@ -494,8 +494,8 @@ class Task(BaseModel):
assert self._guardrail is not None
from crewai.utilities.events import (
TaskGuardrailCompletedEvent,
TaskGuardrailStartedEvent,
LLMGuardrailCompletedEvent,
LLMGuardrailStartedEvent,
)
from crewai.utilities.events.crewai_event_bus import crewai_event_bus
@@ -503,7 +503,7 @@ class Task(BaseModel):
crewai_event_bus.emit(
self,
TaskGuardrailStartedEvent(
LLMGuardrailStartedEvent(
guardrail=self._guardrail, retry_count=self.retry_count
),
)
@@ -512,7 +512,7 @@ class Task(BaseModel):
crewai_event_bus.emit(
self,
TaskGuardrailCompletedEvent(
LLMGuardrailCompletedEvent(
success=guardrail_result.success,
result=guardrail_result.result,
error=guardrail_result.error,

View File

@@ -8,7 +8,7 @@ from crewai.task import Task
from crewai.tasks.task_output import TaskOutput
class TaskGuardrailResult(BaseModel):
class LLMGuardrailResult(BaseModel):
valid: bool = Field(
description="Whether the task output complies with the guardrail"
)
@@ -18,7 +18,7 @@ class TaskGuardrailResult(BaseModel):
)
class TaskGuardrail:
class LLMGuardrail:
"""It validates the output of another task using an LLM.
This class is used to validate the output from a Task based on specified criteria.
@@ -62,7 +62,7 @@ class TaskGuardrail:
- If the Task result complies with the guardrail, saying that is valid
"""
result = agent.kickoff(query, response_format=TaskGuardrailResult)
result = agent.kickoff(query, response_format=LLMGuardrailResult)
return result
@@ -81,7 +81,7 @@ class TaskGuardrail:
try:
result = self._validate_output(task_output)
assert isinstance(
result.pydantic, TaskGuardrailResult
result.pydantic, LLMGuardrailResult
), "The guardrail result is not a valid pydantic model"
if result.pydantic.valid:

View File

@@ -9,9 +9,9 @@ from .crew_events import (
CrewTestCompletedEvent,
CrewTestFailedEvent,
)
from .task_guardrail_events import (
TaskGuardrailCompletedEvent,
TaskGuardrailStartedEvent,
from .llm_guardrail_events import (
LLMGuardrailCompletedEvent,
LLMGuardrailStartedEvent,
)
from .agent_events import (
AgentExecutionStartedEvent,

View File

@@ -29,15 +29,15 @@ from .llm_events import (
LLMCallStartedEvent,
LLMStreamChunkEvent,
)
from .llm_guardrail_events import (
LLMGuardrailCompletedEvent,
LLMGuardrailStartedEvent,
)
from .task_events import (
TaskCompletedEvent,
TaskFailedEvent,
TaskStartedEvent,
)
from .task_guardrail_events import (
TaskGuardrailCompletedEvent,
TaskGuardrailStartedEvent,
)
from .tool_usage_events import (
ToolUsageErrorEvent,
ToolUsageFinishedEvent,
@@ -72,6 +72,6 @@ EventTypes = Union[
LLMCallCompletedEvent,
LLMCallFailedEvent,
LLMStreamChunkEvent,
TaskGuardrailStartedEvent,
TaskGuardrailCompletedEvent,
LLMGuardrailStartedEvent,
LLMGuardrailCompletedEvent,
]

View File

@@ -3,35 +3,35 @@ from typing import Any, Callable, Optional, Union
from crewai.utilities.events.base_events import BaseEvent
class TaskGuardrailStartedEvent(BaseEvent):
class LLMGuardrailStartedEvent(BaseEvent):
"""Event emitted when a guardrail task starts
Attributes:
guardrail: The guardrail callable or TaskGuardrail instance
guardrail: The guardrail callable or LLMGuardrail instance
retry_count: The number of times the guardrail has been retried
"""
type: str = "task_guardrail_started"
type: str = "llm_guardrail_started"
guardrail: Union[str, Callable]
retry_count: int
def __init__(self, **data):
from inspect import getsource
from crewai.tasks.task_guardrail import TaskGuardrail
from crewai.tasks.llm_guardrail import LLMGuardrail
super().__init__(**data)
if isinstance(self.guardrail, TaskGuardrail):
if isinstance(self.guardrail, LLMGuardrail):
self.guardrail = self.guardrail.description.strip()
elif isinstance(self.guardrail, Callable):
self.guardrail = getsource(self.guardrail).strip()
class TaskGuardrailCompletedEvent(BaseEvent):
class LLMGuardrailCompletedEvent(BaseEvent):
"""Event emitted when a guardrail task completes"""
type: str = "task_guardrail_completed"
type: str = "llm_guardrail_completed"
success: bool
result: Any
error: Optional[str] = None