mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-09 08:08:32 +00:00
fix: renaming TaskGuardrail to LLMGuardrail
This commit is contained in:
@@ -322,9 +322,9 @@ blog_task = Task(
|
|||||||
- On success: it returns a tuple of `(bool, Any)`. For example: `(True, validated_result)`
|
- On success: it returns a tuple of `(bool, Any)`. For example: `(True, validated_result)`
|
||||||
- On Failure: it returns a tuple of `(bool, str)`. For example: `(False, "Error message explain the failure")`
|
- On Failure: it returns a tuple of `(bool, str)`. For example: `(False, "Error message explain the failure")`
|
||||||
|
|
||||||
### TaskGuardrail
|
### LLMGuardrail
|
||||||
|
|
||||||
The `TaskGuardrail` class offers a robust mechanism for validating task outputs
|
The `LLMGuardrail` class offers a robust mechanism for validating task outputs.
|
||||||
|
|
||||||
### Error Handling Best Practices
|
### Error Handling Best Practices
|
||||||
|
|
||||||
@@ -819,7 +819,7 @@ from crewai.llm import LLM
|
|||||||
task = Task(
|
task = Task(
|
||||||
description="Generate JSON data",
|
description="Generate JSON data",
|
||||||
expected_output="Valid JSON object",
|
expected_output="Valid JSON object",
|
||||||
guardrail=TaskGuardrail(
|
guardrail=LLMGuardrail(
|
||||||
description="Ensure the response is a valid JSON object",
|
description="Ensure the response is a valid JSON object",
|
||||||
llm=LLM(model="gpt-4o-mini"),
|
llm=LLM(model="gpt-4o-mini"),
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -241,10 +241,10 @@ class Task(BaseModel):
|
|||||||
if callable(self.guardrail):
|
if callable(self.guardrail):
|
||||||
self._guardrail = self.guardrail
|
self._guardrail = self.guardrail
|
||||||
elif isinstance(self.guardrail, str):
|
elif isinstance(self.guardrail, str):
|
||||||
from crewai.tasks.task_guardrail import TaskGuardrail
|
from crewai.tasks.llm_guardrail import LLMGuardrail
|
||||||
|
|
||||||
assert self.agent is not None
|
assert self.agent is not None
|
||||||
self._guardrail = TaskGuardrail(
|
self._guardrail = LLMGuardrail(
|
||||||
description=self.guardrail, llm=self.agent.llm
|
description=self.guardrail, llm=self.agent.llm
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -494,8 +494,8 @@ class Task(BaseModel):
|
|||||||
assert self._guardrail is not None
|
assert self._guardrail is not None
|
||||||
|
|
||||||
from crewai.utilities.events import (
|
from crewai.utilities.events import (
|
||||||
TaskGuardrailCompletedEvent,
|
LLMGuardrailCompletedEvent,
|
||||||
TaskGuardrailStartedEvent,
|
LLMGuardrailStartedEvent,
|
||||||
)
|
)
|
||||||
from crewai.utilities.events.crewai_event_bus import crewai_event_bus
|
from crewai.utilities.events.crewai_event_bus import crewai_event_bus
|
||||||
|
|
||||||
@@ -503,7 +503,7 @@ class Task(BaseModel):
|
|||||||
|
|
||||||
crewai_event_bus.emit(
|
crewai_event_bus.emit(
|
||||||
self,
|
self,
|
||||||
TaskGuardrailStartedEvent(
|
LLMGuardrailStartedEvent(
|
||||||
guardrail=self._guardrail, retry_count=self.retry_count
|
guardrail=self._guardrail, retry_count=self.retry_count
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
@@ -512,7 +512,7 @@ class Task(BaseModel):
|
|||||||
|
|
||||||
crewai_event_bus.emit(
|
crewai_event_bus.emit(
|
||||||
self,
|
self,
|
||||||
TaskGuardrailCompletedEvent(
|
LLMGuardrailCompletedEvent(
|
||||||
success=guardrail_result.success,
|
success=guardrail_result.success,
|
||||||
result=guardrail_result.result,
|
result=guardrail_result.result,
|
||||||
error=guardrail_result.error,
|
error=guardrail_result.error,
|
||||||
|
|||||||
@@ -8,7 +8,7 @@ from crewai.task import Task
|
|||||||
from crewai.tasks.task_output import TaskOutput
|
from crewai.tasks.task_output import TaskOutput
|
||||||
|
|
||||||
|
|
||||||
class TaskGuardrailResult(BaseModel):
|
class LLMGuardrailResult(BaseModel):
|
||||||
valid: bool = Field(
|
valid: bool = Field(
|
||||||
description="Whether the task output complies with the guardrail"
|
description="Whether the task output complies with the guardrail"
|
||||||
)
|
)
|
||||||
@@ -18,7 +18,7 @@ class TaskGuardrailResult(BaseModel):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class TaskGuardrail:
|
class LLMGuardrail:
|
||||||
"""It validates the output of another task using an LLM.
|
"""It validates the output of another task using an LLM.
|
||||||
|
|
||||||
This class is used to validate the output from a Task based on specified criteria.
|
This class is used to validate the output from a Task based on specified criteria.
|
||||||
@@ -62,7 +62,7 @@ class TaskGuardrail:
|
|||||||
- If the Task result complies with the guardrail, saying that is valid
|
- If the Task result complies with the guardrail, saying that is valid
|
||||||
"""
|
"""
|
||||||
|
|
||||||
result = agent.kickoff(query, response_format=TaskGuardrailResult)
|
result = agent.kickoff(query, response_format=LLMGuardrailResult)
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
@@ -81,7 +81,7 @@ class TaskGuardrail:
|
|||||||
try:
|
try:
|
||||||
result = self._validate_output(task_output)
|
result = self._validate_output(task_output)
|
||||||
assert isinstance(
|
assert isinstance(
|
||||||
result.pydantic, TaskGuardrailResult
|
result.pydantic, LLMGuardrailResult
|
||||||
), "The guardrail result is not a valid pydantic model"
|
), "The guardrail result is not a valid pydantic model"
|
||||||
|
|
||||||
if result.pydantic.valid:
|
if result.pydantic.valid:
|
||||||
@@ -9,9 +9,9 @@ from .crew_events import (
|
|||||||
CrewTestCompletedEvent,
|
CrewTestCompletedEvent,
|
||||||
CrewTestFailedEvent,
|
CrewTestFailedEvent,
|
||||||
)
|
)
|
||||||
from .task_guardrail_events import (
|
from .llm_guardrail_events import (
|
||||||
TaskGuardrailCompletedEvent,
|
LLMGuardrailCompletedEvent,
|
||||||
TaskGuardrailStartedEvent,
|
LLMGuardrailStartedEvent,
|
||||||
)
|
)
|
||||||
from .agent_events import (
|
from .agent_events import (
|
||||||
AgentExecutionStartedEvent,
|
AgentExecutionStartedEvent,
|
||||||
|
|||||||
@@ -29,15 +29,15 @@ from .llm_events import (
|
|||||||
LLMCallStartedEvent,
|
LLMCallStartedEvent,
|
||||||
LLMStreamChunkEvent,
|
LLMStreamChunkEvent,
|
||||||
)
|
)
|
||||||
|
from .llm_guardrail_events import (
|
||||||
|
LLMGuardrailCompletedEvent,
|
||||||
|
LLMGuardrailStartedEvent,
|
||||||
|
)
|
||||||
from .task_events import (
|
from .task_events import (
|
||||||
TaskCompletedEvent,
|
TaskCompletedEvent,
|
||||||
TaskFailedEvent,
|
TaskFailedEvent,
|
||||||
TaskStartedEvent,
|
TaskStartedEvent,
|
||||||
)
|
)
|
||||||
from .task_guardrail_events import (
|
|
||||||
TaskGuardrailCompletedEvent,
|
|
||||||
TaskGuardrailStartedEvent,
|
|
||||||
)
|
|
||||||
from .tool_usage_events import (
|
from .tool_usage_events import (
|
||||||
ToolUsageErrorEvent,
|
ToolUsageErrorEvent,
|
||||||
ToolUsageFinishedEvent,
|
ToolUsageFinishedEvent,
|
||||||
@@ -72,6 +72,6 @@ EventTypes = Union[
|
|||||||
LLMCallCompletedEvent,
|
LLMCallCompletedEvent,
|
||||||
LLMCallFailedEvent,
|
LLMCallFailedEvent,
|
||||||
LLMStreamChunkEvent,
|
LLMStreamChunkEvent,
|
||||||
TaskGuardrailStartedEvent,
|
LLMGuardrailStartedEvent,
|
||||||
TaskGuardrailCompletedEvent,
|
LLMGuardrailCompletedEvent,
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -3,35 +3,35 @@ from typing import Any, Callable, Optional, Union
|
|||||||
from crewai.utilities.events.base_events import BaseEvent
|
from crewai.utilities.events.base_events import BaseEvent
|
||||||
|
|
||||||
|
|
||||||
class TaskGuardrailStartedEvent(BaseEvent):
|
class LLMGuardrailStartedEvent(BaseEvent):
|
||||||
"""Event emitted when a guardrail task starts
|
"""Event emitted when a guardrail task starts
|
||||||
|
|
||||||
Attributes:
|
Attributes:
|
||||||
guardrail: The guardrail callable or TaskGuardrail instance
|
guardrail: The guardrail callable or LLMGuardrail instance
|
||||||
retry_count: The number of times the guardrail has been retried
|
retry_count: The number of times the guardrail has been retried
|
||||||
"""
|
"""
|
||||||
|
|
||||||
type: str = "task_guardrail_started"
|
type: str = "llm_guardrail_started"
|
||||||
guardrail: Union[str, Callable]
|
guardrail: Union[str, Callable]
|
||||||
retry_count: int
|
retry_count: int
|
||||||
|
|
||||||
def __init__(self, **data):
|
def __init__(self, **data):
|
||||||
from inspect import getsource
|
from inspect import getsource
|
||||||
|
|
||||||
from crewai.tasks.task_guardrail import TaskGuardrail
|
from crewai.tasks.llm_guardrail import LLMGuardrail
|
||||||
|
|
||||||
super().__init__(**data)
|
super().__init__(**data)
|
||||||
|
|
||||||
if isinstance(self.guardrail, TaskGuardrail):
|
if isinstance(self.guardrail, LLMGuardrail):
|
||||||
self.guardrail = self.guardrail.description.strip()
|
self.guardrail = self.guardrail.description.strip()
|
||||||
elif isinstance(self.guardrail, Callable):
|
elif isinstance(self.guardrail, Callable):
|
||||||
self.guardrail = getsource(self.guardrail).strip()
|
self.guardrail = getsource(self.guardrail).strip()
|
||||||
|
|
||||||
|
|
||||||
class TaskGuardrailCompletedEvent(BaseEvent):
|
class LLMGuardrailCompletedEvent(BaseEvent):
|
||||||
"""Event emitted when a guardrail task completes"""
|
"""Event emitted when a guardrail task completes"""
|
||||||
|
|
||||||
type: str = "task_guardrail_completed"
|
type: str = "llm_guardrail_completed"
|
||||||
success: bool
|
success: bool
|
||||||
result: Any
|
result: Any
|
||||||
error: Optional[str] = None
|
error: Optional[str] = None
|
||||||
@@ -4,11 +4,11 @@ import pytest
|
|||||||
|
|
||||||
from crewai import Agent, Task
|
from crewai import Agent, Task
|
||||||
from crewai.llm import LLM
|
from crewai.llm import LLM
|
||||||
from crewai.tasks.task_guardrail import TaskGuardrail
|
from crewai.tasks.llm_guardrail import LLMGuardrail
|
||||||
from crewai.tasks.task_output import TaskOutput
|
from crewai.tasks.task_output import TaskOutput
|
||||||
from crewai.utilities.events import (
|
from crewai.utilities.events import (
|
||||||
TaskGuardrailCompletedEvent,
|
LLMGuardrailCompletedEvent,
|
||||||
TaskGuardrailStartedEvent,
|
LLMGuardrailStartedEvent,
|
||||||
)
|
)
|
||||||
from crewai.utilities.events.crewai_event_bus import crewai_event_bus
|
from crewai.utilities.events.crewai_event_bus import crewai_event_bus
|
||||||
|
|
||||||
@@ -153,7 +153,7 @@ def task_output():
|
|||||||
|
|
||||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||||
def test_task_guardrail_process_output(task_output):
|
def test_task_guardrail_process_output(task_output):
|
||||||
guardrail = TaskGuardrail(
|
guardrail = LLMGuardrail(
|
||||||
description="Ensure the result has less than 10 words", llm=LLM(model="gpt-4o")
|
description="Ensure the result has less than 10 words", llm=LLM(model="gpt-4o")
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -162,7 +162,7 @@ def test_task_guardrail_process_output(task_output):
|
|||||||
|
|
||||||
assert "exceeding the guardrail limit of fewer than" in result[1].lower()
|
assert "exceeding the guardrail limit of fewer than" in result[1].lower()
|
||||||
|
|
||||||
guardrail = TaskGuardrail(
|
guardrail = LLMGuardrail(
|
||||||
description="Ensure the result has less than 500 words", llm=LLM(model="gpt-4o")
|
description="Ensure the result has less than 500 words", llm=LLM(model="gpt-4o")
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -178,13 +178,13 @@ def test_guardrail_emits_events(sample_agent):
|
|||||||
|
|
||||||
with crewai_event_bus.scoped_handlers():
|
with crewai_event_bus.scoped_handlers():
|
||||||
|
|
||||||
@crewai_event_bus.on(TaskGuardrailStartedEvent)
|
@crewai_event_bus.on(LLMGuardrailStartedEvent)
|
||||||
def handle_guardrail_started(source, event):
|
def handle_guardrail_started(source, event):
|
||||||
started_guardrail.append(
|
started_guardrail.append(
|
||||||
{"guardrail": event.guardrail, "retry_count": event.retry_count}
|
{"guardrail": event.guardrail, "retry_count": event.retry_count}
|
||||||
)
|
)
|
||||||
|
|
||||||
@crewai_event_bus.on(TaskGuardrailCompletedEvent)
|
@crewai_event_bus.on(LLMGuardrailCompletedEvent)
|
||||||
def handle_guardrail_completed(source, event):
|
def handle_guardrail_completed(source, event):
|
||||||
completed_guardrail.append(
|
completed_guardrail.append(
|
||||||
{
|
{
|
||||||
|
|||||||
Reference in New Issue
Block a user