mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-04-30 23:02:50 +00:00
feat: add guardrail_type and name to distinguish traces (#5303)
* feat: add guardrail_type to distinguish between hallucination, function, and LLM * feat: introduce guardrail_name into guardrail events * feat: propagate guardrail type and name on guardrail completed event * feat: remove unused LLMGuardrailFailedEvent * fix: handle running event loop in LLMGuardrail._validate_output When agent.kickoff() returns a coroutine inside an already-running event loop, asyncio.run() fails
This commit is contained in:
@@ -12,6 +12,8 @@ class LLMGuardrailBaseEvent(BaseEvent):
|
||||
from_agent: Any | None = None
|
||||
agent_role: str | None = None
|
||||
agent_id: str | None = None
|
||||
guardrail_type: str | None = None
|
||||
guardrail_name: str | None = None
|
||||
|
||||
def __init__(self, **data: Any) -> None:
|
||||
super().__init__(**data)
|
||||
@@ -37,9 +39,17 @@ class LLMGuardrailStartedEvent(LLMGuardrailBaseEvent):
|
||||
|
||||
super().__init__(**data)
|
||||
|
||||
if isinstance(self.guardrail, (LLMGuardrail, HallucinationGuardrail)):
|
||||
if isinstance(self.guardrail, HallucinationGuardrail):
|
||||
self.guardrail_type = "hallucination"
|
||||
self.guardrail_name = self.guardrail.description.strip()
|
||||
self.guardrail = self.guardrail.description.strip()
|
||||
elif isinstance(self.guardrail, LLMGuardrail):
|
||||
self.guardrail_type = "llm"
|
||||
self.guardrail_name = self.guardrail.description.strip()
|
||||
self.guardrail = self.guardrail.description.strip()
|
||||
elif callable(self.guardrail):
|
||||
self.guardrail_type = "function"
|
||||
self.guardrail_name = getattr(self.guardrail, "__name__", None)
|
||||
self.guardrail = getsource(self.guardrail).strip()
|
||||
|
||||
|
||||
@@ -58,16 +68,3 @@ class LLMGuardrailCompletedEvent(LLMGuardrailBaseEvent):
|
||||
result: Any
|
||||
error: str | None = None
|
||||
retry_count: int
|
||||
|
||||
|
||||
class LLMGuardrailFailedEvent(LLMGuardrailBaseEvent):
|
||||
"""Event emitted when a guardrail task fails
|
||||
|
||||
Attributes:
|
||||
error: The error message
|
||||
retry_count: The number of times the guardrail has been retried
|
||||
"""
|
||||
|
||||
type: Literal["llm_guardrail_failed"] = "llm_guardrail_failed"
|
||||
error: str
|
||||
retry_count: int
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
import asyncio
|
||||
import concurrent.futures
|
||||
from collections.abc import Coroutine
|
||||
import contextvars
|
||||
import inspect
|
||||
from typing import Any
|
||||
|
||||
@@ -19,6 +21,21 @@ def _is_coroutine(
|
||||
return inspect.iscoroutine(obj)
|
||||
|
||||
|
||||
def _run_coroutine_sync(coro: Coroutine[Any, Any, LiteAgentOutput]) -> LiteAgentOutput:
|
||||
"""Run a coroutine synchronously, handling an already-running event loop."""
|
||||
try:
|
||||
asyncio.get_running_loop()
|
||||
has_running_loop = True
|
||||
except RuntimeError:
|
||||
has_running_loop = False
|
||||
|
||||
if has_running_loop:
|
||||
ctx = contextvars.copy_context()
|
||||
with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
|
||||
return pool.submit(ctx.run, asyncio.run, coro).result()
|
||||
return asyncio.run(coro)
|
||||
|
||||
|
||||
class LLMGuardrailResult(BaseModel):
|
||||
valid: bool = Field(
|
||||
description="Whether the task output complies with the guardrail"
|
||||
@@ -75,7 +92,7 @@ class LLMGuardrail:
|
||||
|
||||
kickoff_result = agent.kickoff(query, response_format=LLMGuardrailResult)
|
||||
if _is_coroutine(kickoff_result):
|
||||
return asyncio.run(kickoff_result)
|
||||
return _run_coroutine_sync(kickoff_result)
|
||||
return kickoff_result
|
||||
|
||||
def __call__(self, task_output: TaskOutput) -> tuple[bool, Any]:
|
||||
|
||||
@@ -118,15 +118,13 @@ def process_guardrail(
|
||||
LLMGuardrailStartedEvent,
|
||||
)
|
||||
|
||||
crewai_event_bus.emit(
|
||||
event_source,
|
||||
LLMGuardrailStartedEvent(
|
||||
guardrail=guardrail,
|
||||
retry_count=retry_count,
|
||||
from_agent=from_agent,
|
||||
from_task=from_task,
|
||||
),
|
||||
started_event = LLMGuardrailStartedEvent(
|
||||
guardrail=guardrail,
|
||||
retry_count=retry_count,
|
||||
from_agent=from_agent,
|
||||
from_task=from_task,
|
||||
)
|
||||
crewai_event_bus.emit(event_source, started_event)
|
||||
|
||||
result = guardrail(output)
|
||||
guardrail_result = GuardrailResult.from_tuple(result)
|
||||
@@ -138,6 +136,8 @@ def process_guardrail(
|
||||
result=guardrail_result.result,
|
||||
error=guardrail_result.error,
|
||||
retry_count=retry_count,
|
||||
guardrail_type=started_event.guardrail_type,
|
||||
guardrail_name=started_event.guardrail_name,
|
||||
from_agent=from_agent,
|
||||
from_task=from_task,
|
||||
),
|
||||
|
||||
Reference in New Issue
Block a user