mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-14 18:48:29 +00:00
feat: enhance knowledge and guardrail event handling in Agent class (#3672)
Some checks failed
CodeQL Advanced / Analyze (actions) (push) Has been cancelled
CodeQL Advanced / Analyze (python) (push) Has been cancelled
Notify Downstream / notify-downstream (push) Has been cancelled
Update Test Durations / update-durations (3.10) (push) Has been cancelled
Update Test Durations / update-durations (3.11) (push) Has been cancelled
Update Test Durations / update-durations (3.12) (push) Has been cancelled
Update Test Durations / update-durations (3.13) (push) Has been cancelled
Some checks failed
CodeQL Advanced / Analyze (actions) (push) Has been cancelled
CodeQL Advanced / Analyze (python) (push) Has been cancelled
Notify Downstream / notify-downstream (push) Has been cancelled
Update Test Durations / update-durations (3.10) (push) Has been cancelled
Update Test Durations / update-durations (3.11) (push) Has been cancelled
Update Test Durations / update-durations (3.12) (push) Has been cancelled
Update Test Durations / update-durations (3.13) (push) Has been cancelled
* feat: enhance knowledge event handling in Agent class - Updated the Agent class to include task context in knowledge retrieval events. - Emitted new events for knowledge retrieval and query processes, capturing task and agent details. - Refactored knowledge event classes to inherit from a base class for better structure and maintainability. - Added tracing for knowledge events in the TraceCollectionListener to improve observability. This change improves the tracking and management of knowledge queries and retrievals, facilitating better debugging and performance monitoring. * refactor: remove task_id from knowledge event emissions in Agent class - Removed the task_id parameter from various knowledge event emissions in the Agent class to streamline event handling. - This change simplifies the event structure and focuses on the essential context of knowledge retrieval and query processes. This refactor enhances the clarity of knowledge events and aligns with the recent improvements in event handling. * surface association for guardrail events * fix: improve LLM selection logic in converter - Updated the logic for selecting the LLM in the convert_with_instructions function to handle cases where the agent may not have a function_calling_llm attribute. - This change ensures that the converter can still function correctly by falling back to the standard LLM if necessary, enhancing robustness and preventing potential errors. This fix improves the reliability of the conversion process when working with different agent configurations. * fix test * fix: enforce valid LLM instance requirement in converter - Updated the convert_with_instructions function to ensure that a valid LLM instance is provided by the agent. - If neither function_calling_llm nor the standard llm is available, a ValueError is raised, enhancing error handling and robustness. - Improved error messaging for conversion failures to provide clearer feedback on issues encountered during the conversion process. This change strengthens the reliability of the conversion process by ensuring that agents are properly configured with a valid LLM.
This commit is contained in:
@@ -53,6 +53,7 @@ from crewai.utilities.converter import generate_model_description
|
||||
from crewai.utilities.llm_utils import create_llm
|
||||
from crewai.utilities.token_counter_callback import TokenCalcHandler
|
||||
from crewai.utilities.training_handler import CrewTrainingHandler
|
||||
from crewai.utilities.types import LLMMessage
|
||||
|
||||
|
||||
class Agent(BaseAgent):
|
||||
@@ -347,15 +348,16 @@ class Agent(BaseAgent):
|
||||
)
|
||||
|
||||
if self.knowledge or (self.crew and self.crew.knowledge):
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=KnowledgeRetrievalStartedEvent(
|
||||
agent=self,
|
||||
),
|
||||
)
|
||||
try:
|
||||
self.knowledge_search_query = self._get_knowledge_search_query(
|
||||
task_prompt
|
||||
task_prompt, task
|
||||
)
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=KnowledgeRetrievalStartedEvent(
|
||||
from_task=task,
|
||||
from_agent=self,
|
||||
),
|
||||
)
|
||||
if self.knowledge_search_query:
|
||||
# Quering agent specific knowledge
|
||||
@@ -385,7 +387,8 @@ class Agent(BaseAgent):
|
||||
self,
|
||||
event=KnowledgeRetrievalCompletedEvent(
|
||||
query=self.knowledge_search_query,
|
||||
agent=self,
|
||||
from_task=task,
|
||||
from_agent=self,
|
||||
retrieved_knowledge=(
|
||||
(self.agent_knowledge_context or "")
|
||||
+ (
|
||||
@@ -403,8 +406,9 @@ class Agent(BaseAgent):
|
||||
self,
|
||||
event=KnowledgeSearchQueryFailedEvent(
|
||||
query=self.knowledge_search_query or "",
|
||||
agent=self,
|
||||
error=str(e),
|
||||
from_task=task,
|
||||
from_agent=self,
|
||||
),
|
||||
)
|
||||
|
||||
@@ -728,13 +732,14 @@ class Agent(BaseAgent):
|
||||
def set_fingerprint(self, fingerprint: Fingerprint):
|
||||
self.security_config.fingerprint = fingerprint
|
||||
|
||||
def _get_knowledge_search_query(self, task_prompt: str) -> str | None:
|
||||
def _get_knowledge_search_query(self, task_prompt: str, task: Task) -> str | None:
|
||||
"""Generate a search query for the knowledge base based on the task description."""
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=KnowledgeQueryStartedEvent(
|
||||
task_prompt=task_prompt,
|
||||
agent=self,
|
||||
from_task=task,
|
||||
from_agent=self,
|
||||
),
|
||||
)
|
||||
query = self.i18n.slice("knowledge_search_query").format(
|
||||
@@ -749,8 +754,9 @@ class Agent(BaseAgent):
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=KnowledgeQueryFailedEvent(
|
||||
agent=self,
|
||||
error="LLM is not compatible with knowledge search queries",
|
||||
from_task=task,
|
||||
from_agent=self,
|
||||
),
|
||||
)
|
||||
return None
|
||||
@@ -769,7 +775,8 @@ class Agent(BaseAgent):
|
||||
self,
|
||||
event=KnowledgeQueryCompletedEvent(
|
||||
query=query,
|
||||
agent=self,
|
||||
from_task=task,
|
||||
from_agent=self,
|
||||
),
|
||||
)
|
||||
return rewritten_query
|
||||
@@ -777,15 +784,16 @@ class Agent(BaseAgent):
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=KnowledgeQueryFailedEvent(
|
||||
agent=self,
|
||||
error=str(e),
|
||||
from_task=task,
|
||||
from_agent=self,
|
||||
),
|
||||
)
|
||||
return None
|
||||
|
||||
def kickoff(
|
||||
self,
|
||||
messages: str | list[dict[str, str]],
|
||||
messages: str | list[LLMMessage],
|
||||
response_format: type[Any] | None = None,
|
||||
) -> LiteAgentOutput:
|
||||
"""
|
||||
@@ -825,7 +833,7 @@ class Agent(BaseAgent):
|
||||
|
||||
async def kickoff_async(
|
||||
self,
|
||||
messages: str | list[dict[str, str]],
|
||||
messages: str | list[LLMMessage],
|
||||
response_format: type[Any] | None = None,
|
||||
) -> LiteAgentOutput:
|
||||
"""
|
||||
@@ -855,6 +863,7 @@ class Agent(BaseAgent):
|
||||
response_format=response_format,
|
||||
i18n=self.i18n,
|
||||
original_agent=self,
|
||||
guardrail=self.guardrail,
|
||||
)
|
||||
|
||||
return await lite_agent.kickoff_async(messages)
|
||||
|
||||
@@ -32,6 +32,13 @@ from crewai.events.types.flow_events import (
|
||||
MethodExecutionFinishedEvent,
|
||||
MethodExecutionStartedEvent,
|
||||
)
|
||||
from crewai.events.types.knowledge_events import (
|
||||
KnowledgeQueryCompletedEvent,
|
||||
KnowledgeQueryFailedEvent,
|
||||
KnowledgeQueryStartedEvent,
|
||||
KnowledgeRetrievalCompletedEvent,
|
||||
KnowledgeRetrievalStartedEvent,
|
||||
)
|
||||
from crewai.events.types.llm_events import (
|
||||
LLMCallCompletedEvent,
|
||||
LLMCallFailedEvent,
|
||||
@@ -310,6 +317,26 @@ class TraceCollectionListener(BaseEventListener):
|
||||
def on_agent_reasoning_failed(source, event):
|
||||
self._handle_action_event("agent_reasoning_failed", source, event)
|
||||
|
||||
@event_bus.on(KnowledgeRetrievalStartedEvent)
|
||||
def on_knowledge_retrieval_started(source, event):
|
||||
self._handle_action_event("knowledge_retrieval_started", source, event)
|
||||
|
||||
@event_bus.on(KnowledgeRetrievalCompletedEvent)
|
||||
def on_knowledge_retrieval_completed(source, event):
|
||||
self._handle_action_event("knowledge_retrieval_completed", source, event)
|
||||
|
||||
@event_bus.on(KnowledgeQueryStartedEvent)
|
||||
def on_knowledge_query_started(source, event):
|
||||
self._handle_action_event("knowledge_query_started", source, event)
|
||||
|
||||
@event_bus.on(KnowledgeQueryCompletedEvent)
|
||||
def on_knowledge_query_completed(source, event):
|
||||
self._handle_action_event("knowledge_query_completed", source, event)
|
||||
|
||||
@event_bus.on(KnowledgeQueryFailedEvent)
|
||||
def on_knowledge_query_failed(source, event):
|
||||
self._handle_action_event("knowledge_query_failed", source, event)
|
||||
|
||||
def _initialize_crew_batch(self, source: Any, event: Any):
|
||||
"""Initialize trace batch"""
|
||||
user_context = self._get_user_context()
|
||||
|
||||
@@ -1,51 +1,60 @@
|
||||
from crewai.agents.agent_builder.base_agent import BaseAgent
|
||||
from typing import Any
|
||||
|
||||
from crewai.events.base_events import BaseEvent
|
||||
|
||||
|
||||
class KnowledgeRetrievalStartedEvent(BaseEvent):
|
||||
class KnowledgeEventBase(BaseEvent):
|
||||
task_id: str | None = None
|
||||
task_name: str | None = None
|
||||
from_task: Any | None = None
|
||||
from_agent: Any | None = None
|
||||
agent_role: str | None = None
|
||||
agent_id: str | None = None
|
||||
|
||||
def __init__(self, **data):
|
||||
super().__init__(**data)
|
||||
self._set_agent_params(data)
|
||||
self._set_task_params(data)
|
||||
|
||||
|
||||
class KnowledgeRetrievalStartedEvent(KnowledgeEventBase):
|
||||
"""Event emitted when a knowledge retrieval is started."""
|
||||
|
||||
type: str = "knowledge_search_query_started"
|
||||
agent: BaseAgent
|
||||
|
||||
|
||||
class KnowledgeRetrievalCompletedEvent(BaseEvent):
|
||||
class KnowledgeRetrievalCompletedEvent(KnowledgeEventBase):
|
||||
"""Event emitted when a knowledge retrieval is completed."""
|
||||
|
||||
query: str
|
||||
type: str = "knowledge_search_query_completed"
|
||||
agent: BaseAgent
|
||||
retrieved_knowledge: str
|
||||
|
||||
|
||||
class KnowledgeQueryStartedEvent(BaseEvent):
|
||||
class KnowledgeQueryStartedEvent(KnowledgeEventBase):
|
||||
"""Event emitted when a knowledge query is started."""
|
||||
|
||||
task_prompt: str
|
||||
type: str = "knowledge_query_started"
|
||||
agent: BaseAgent
|
||||
|
||||
|
||||
class KnowledgeQueryFailedEvent(BaseEvent):
|
||||
class KnowledgeQueryFailedEvent(KnowledgeEventBase):
|
||||
"""Event emitted when a knowledge query fails."""
|
||||
|
||||
type: str = "knowledge_query_failed"
|
||||
agent: BaseAgent
|
||||
error: str
|
||||
|
||||
|
||||
class KnowledgeQueryCompletedEvent(BaseEvent):
|
||||
class KnowledgeQueryCompletedEvent(KnowledgeEventBase):
|
||||
"""Event emitted when a knowledge query is completed."""
|
||||
|
||||
query: str
|
||||
type: str = "knowledge_query_completed"
|
||||
agent: BaseAgent
|
||||
|
||||
|
||||
class KnowledgeSearchQueryFailedEvent(BaseEvent):
|
||||
class KnowledgeSearchQueryFailedEvent(KnowledgeEventBase):
|
||||
"""Event emitted when a knowledge search query fails."""
|
||||
|
||||
query: str
|
||||
type: str = "knowledge_search_query_failed"
|
||||
agent: BaseAgent
|
||||
error: str
|
||||
|
||||
@@ -5,7 +5,21 @@ from typing import Any
|
||||
from crewai.events.base_events import BaseEvent
|
||||
|
||||
|
||||
class LLMGuardrailStartedEvent(BaseEvent):
|
||||
class LLMGuardrailBaseEvent(BaseEvent):
|
||||
task_id: str | None = None
|
||||
task_name: str | None = None
|
||||
from_task: Any | None = None
|
||||
from_agent: Any | None = None
|
||||
agent_role: str | None = None
|
||||
agent_id: str | None = None
|
||||
|
||||
def __init__(self, **data):
|
||||
super().__init__(**data)
|
||||
self._set_agent_params(data)
|
||||
self._set_task_params(data)
|
||||
|
||||
|
||||
class LLMGuardrailStartedEvent(LLMGuardrailBaseEvent):
|
||||
"""Event emitted when a guardrail task starts
|
||||
|
||||
Attributes:
|
||||
@@ -29,7 +43,7 @@ class LLMGuardrailStartedEvent(BaseEvent):
|
||||
self.guardrail = getsource(self.guardrail).strip()
|
||||
|
||||
|
||||
class LLMGuardrailCompletedEvent(BaseEvent):
|
||||
class LLMGuardrailCompletedEvent(LLMGuardrailBaseEvent):
|
||||
"""Event emitted when a guardrail task completes
|
||||
|
||||
Attributes:
|
||||
@@ -44,3 +58,16 @@ class LLMGuardrailCompletedEvent(BaseEvent):
|
||||
result: Any
|
||||
error: str | None = None
|
||||
retry_count: int
|
||||
|
||||
|
||||
class LLMGuardrailFailedEvent(LLMGuardrailBaseEvent):
|
||||
"""Event emitted when a guardrail task fails
|
||||
|
||||
Attributes:
|
||||
error: The error message
|
||||
retry_count: The number of times the guardrail has been retried
|
||||
"""
|
||||
|
||||
type: str = "llm_guardrail_failed"
|
||||
error: str
|
||||
retry_count: int
|
||||
|
||||
@@ -4,6 +4,7 @@ import uuid
|
||||
from collections.abc import Callable
|
||||
from typing import (
|
||||
Any,
|
||||
Literal,
|
||||
cast,
|
||||
get_args,
|
||||
get_origin,
|
||||
@@ -62,6 +63,7 @@ from crewai.utilities.llm_utils import create_llm
|
||||
from crewai.utilities.printer import Printer
|
||||
from crewai.utilities.token_counter_callback import TokenCalcHandler
|
||||
from crewai.utilities.tool_utils import execute_tool_and_check_finality
|
||||
from crewai.utilities.types import LLMMessage
|
||||
|
||||
|
||||
class LiteAgentOutput(BaseModel):
|
||||
@@ -180,7 +182,7 @@ class LiteAgent(FlowTrackable, BaseModel):
|
||||
_token_process: TokenProcess = PrivateAttr(default_factory=TokenProcess)
|
||||
_cache_handler: CacheHandler = PrivateAttr(default_factory=CacheHandler)
|
||||
_key: str = PrivateAttr(default_factory=lambda: str(uuid.uuid4()))
|
||||
_messages: list[dict[str, str]] = PrivateAttr(default_factory=list)
|
||||
_messages: list[LLMMessage] = PrivateAttr(default_factory=list)
|
||||
_iterations: int = PrivateAttr(default=0)
|
||||
_printer: Printer = PrivateAttr(default_factory=Printer)
|
||||
_guardrail: Callable | None = PrivateAttr(default=None)
|
||||
@@ -219,7 +221,6 @@ class LiteAgent(FlowTrackable, BaseModel):
|
||||
raise TypeError(
|
||||
f"Guardrail requires LLM instance of type BaseLLM, got {type(self.llm).__name__}"
|
||||
)
|
||||
|
||||
self._guardrail = LLMGuardrail(description=self.guardrail, llm=self.llm)
|
||||
|
||||
return self
|
||||
@@ -276,7 +277,7 @@ class LiteAgent(FlowTrackable, BaseModel):
|
||||
"""Return the original role for compatibility with tool interfaces."""
|
||||
return self.role
|
||||
|
||||
def kickoff(self, messages: str | list[dict[str, str]]) -> LiteAgentOutput:
|
||||
def kickoff(self, messages: str | list[LLMMessage]) -> LiteAgentOutput:
|
||||
"""
|
||||
Execute the agent with the given messages.
|
||||
|
||||
@@ -368,6 +369,7 @@ class LiteAgent(FlowTrackable, BaseModel):
|
||||
guardrail=self._guardrail,
|
||||
retry_count=self._guardrail_retry_count,
|
||||
event_source=self,
|
||||
from_agent=self,
|
||||
)
|
||||
|
||||
if not guardrail_result.success:
|
||||
@@ -414,9 +416,7 @@ class LiteAgent(FlowTrackable, BaseModel):
|
||||
|
||||
return output
|
||||
|
||||
async def kickoff_async(
|
||||
self, messages: str | list[dict[str, str]]
|
||||
) -> LiteAgentOutput:
|
||||
async def kickoff_async(self, messages: str | list[LLMMessage]) -> LiteAgentOutput:
|
||||
"""
|
||||
Execute the agent asynchronously with the given messages.
|
||||
|
||||
@@ -461,9 +461,7 @@ class LiteAgent(FlowTrackable, BaseModel):
|
||||
|
||||
return base_prompt
|
||||
|
||||
def _format_messages(
|
||||
self, messages: str | list[dict[str, str]]
|
||||
) -> list[dict[str, str]]:
|
||||
def _format_messages(self, messages: str | list[LLMMessage]) -> list[LLMMessage]:
|
||||
"""Format messages for the LLM."""
|
||||
if isinstance(messages, str):
|
||||
messages = [{"role": "user", "content": messages}]
|
||||
@@ -471,7 +469,9 @@ class LiteAgent(FlowTrackable, BaseModel):
|
||||
system_prompt = self._get_default_system_prompt()
|
||||
|
||||
# Add system message at the beginning
|
||||
formatted_messages = [{"role": "system", "content": system_prompt}]
|
||||
formatted_messages: list[LLMMessage] = [
|
||||
{"role": "system", "content": system_prompt}
|
||||
]
|
||||
|
||||
# Add the rest of the messages
|
||||
formatted_messages.extend(messages)
|
||||
@@ -583,6 +583,8 @@ class LiteAgent(FlowTrackable, BaseModel):
|
||||
),
|
||||
)
|
||||
|
||||
def _append_message(self, text: str, role: str = "assistant") -> None:
|
||||
def _append_message(
|
||||
self, text: str, role: Literal["user", "assistant", "system"] = "assistant"
|
||||
) -> None:
|
||||
"""Append a message to the message list with the given role."""
|
||||
self._messages.append(format_message_for_llm(text, role=role))
|
||||
self._messages.append(cast(LLMMessage, format_message_for_llm(text, role=role)))
|
||||
|
||||
@@ -462,6 +462,8 @@ class Task(BaseModel):
|
||||
guardrail=self._guardrail,
|
||||
retry_count=self.retry_count,
|
||||
event_source=self,
|
||||
from_task=self,
|
||||
from_agent=agent,
|
||||
)
|
||||
if not guardrail_result.success:
|
||||
if self.retry_count >= self.guardrail_max_retries:
|
||||
|
||||
@@ -31,6 +31,7 @@ from crewai.utilities.types import LLMMessage
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from crewai.agent import Agent
|
||||
from crewai.lite_agent import LiteAgent
|
||||
from crewai.task import Task
|
||||
|
||||
|
||||
@@ -222,7 +223,7 @@ def get_llm_response(
|
||||
callbacks: list[Callable[..., Any]],
|
||||
printer: Printer,
|
||||
from_task: Task | None = None,
|
||||
from_agent: Agent | None = None,
|
||||
from_agent: Agent | LiteAgent | None = None,
|
||||
) -> str:
|
||||
"""Call the LLM and return the response, handling any invalid responses.
|
||||
|
||||
|
||||
@@ -14,6 +14,7 @@ from crewai.utilities.pydantic_schema_parser import PydanticSchemaParser
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from crewai.agent import Agent
|
||||
from crewai.agents.agent_builder.base_agent import BaseAgent
|
||||
from crewai.llm import LLM
|
||||
from crewai.llms.base_llm import BaseLLM
|
||||
|
||||
@@ -143,7 +144,7 @@ def convert_to_model(
|
||||
result: str,
|
||||
output_pydantic: type[BaseModel] | None,
|
||||
output_json: type[BaseModel] | None,
|
||||
agent: Agent | None = None,
|
||||
agent: Agent | BaseAgent | None = None,
|
||||
converter_cls: type[Converter] | None = None,
|
||||
) -> dict[str, Any] | BaseModel | str:
|
||||
"""Convert a result string to a Pydantic model or JSON.
|
||||
@@ -215,7 +216,7 @@ def handle_partial_json(
|
||||
result: str,
|
||||
model: type[BaseModel],
|
||||
is_json_output: bool,
|
||||
agent: Agent | None,
|
||||
agent: Agent | BaseAgent | None,
|
||||
converter_cls: type[Converter] | None = None,
|
||||
) -> dict[str, Any] | BaseModel | str:
|
||||
"""Handle partial JSON in a result string and convert to Pydantic model or dict.
|
||||
@@ -260,7 +261,7 @@ def convert_with_instructions(
|
||||
result: str,
|
||||
model: type[BaseModel],
|
||||
is_json_output: bool,
|
||||
agent: Agent | None,
|
||||
agent: Agent | BaseAgent | None,
|
||||
converter_cls: type[Converter] | None = None,
|
||||
) -> dict | BaseModel | str:
|
||||
"""Convert a result string to a Pydantic model or JSON using instructions.
|
||||
@@ -283,7 +284,12 @@ def convert_with_instructions(
|
||||
"""
|
||||
if agent is None:
|
||||
raise TypeError("Agent must be provided if converter_cls is not specified.")
|
||||
llm = agent.function_calling_llm or agent.llm
|
||||
|
||||
llm = getattr(agent, "function_calling_llm", None) or agent.llm
|
||||
|
||||
if llm is None:
|
||||
raise ValueError("Agent must have a valid LLM instance for conversion")
|
||||
|
||||
instructions = get_conversion_instructions(model=model, llm=llm)
|
||||
converter = create_converter(
|
||||
agent=agent,
|
||||
@@ -299,7 +305,7 @@ def convert_with_instructions(
|
||||
|
||||
if isinstance(exported_result, ConverterError):
|
||||
Printer().print(
|
||||
content=f"{exported_result.message} Using raw output instead.",
|
||||
content=f"Failed to convert result to model: {exported_result}",
|
||||
color="red",
|
||||
)
|
||||
return result
|
||||
@@ -308,7 +314,7 @@ def convert_with_instructions(
|
||||
|
||||
|
||||
def get_conversion_instructions(
|
||||
model: type[BaseModel], llm: BaseLLM | LLM | str
|
||||
model: type[BaseModel], llm: BaseLLM | LLM | str | Any
|
||||
) -> str:
|
||||
"""Generate conversion instructions based on the model and LLM capabilities.
|
||||
|
||||
@@ -357,7 +363,7 @@ class CreateConverterKwargs(TypedDict, total=False):
|
||||
|
||||
|
||||
def create_converter(
|
||||
agent: Agent | None = None,
|
||||
agent: Agent | BaseAgent | None = None,
|
||||
converter_cls: type[Converter] | None = None,
|
||||
*args: Any,
|
||||
**kwargs: Unpack[CreateConverterKwargs],
|
||||
|
||||
@@ -7,7 +7,9 @@ from pydantic import BaseModel, Field, field_validator
|
||||
from typing_extensions import Self
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from crewai.lite_agent import LiteAgentOutput
|
||||
from crewai.agents.agent_builder.base_agent import BaseAgent
|
||||
from crewai.lite_agent import LiteAgent, LiteAgentOutput
|
||||
from crewai.task import Task
|
||||
from crewai.tasks.task_output import TaskOutput
|
||||
|
||||
|
||||
@@ -79,6 +81,8 @@ def process_guardrail(
|
||||
guardrail: Callable[[Any], tuple[bool, Any | str]],
|
||||
retry_count: int,
|
||||
event_source: Any | None = None,
|
||||
from_agent: BaseAgent | LiteAgent | None = None,
|
||||
from_task: Task | None = None,
|
||||
) -> GuardrailResult:
|
||||
"""Process the guardrail for the agent output.
|
||||
|
||||
@@ -95,14 +99,6 @@ def process_guardrail(
|
||||
TypeError: If output is not a TaskOutput or LiteAgentOutput
|
||||
ValueError: If guardrail is None
|
||||
"""
|
||||
from crewai.lite_agent import LiteAgentOutput
|
||||
from crewai.tasks.task_output import TaskOutput
|
||||
|
||||
if not isinstance(output, (TaskOutput, LiteAgentOutput)):
|
||||
raise TypeError("Output must be a TaskOutput or LiteAgentOutput")
|
||||
if guardrail is None:
|
||||
raise ValueError("Guardrail must not be None")
|
||||
|
||||
from crewai.events.event_bus import crewai_event_bus
|
||||
from crewai.events.types.llm_guardrail_events import (
|
||||
LLMGuardrailCompletedEvent,
|
||||
@@ -111,7 +107,12 @@ def process_guardrail(
|
||||
|
||||
crewai_event_bus.emit(
|
||||
event_source,
|
||||
LLMGuardrailStartedEvent(guardrail=guardrail, retry_count=retry_count),
|
||||
LLMGuardrailStartedEvent(
|
||||
guardrail=guardrail,
|
||||
retry_count=retry_count,
|
||||
from_agent=from_agent,
|
||||
from_task=from_task,
|
||||
),
|
||||
)
|
||||
|
||||
result = guardrail(output)
|
||||
@@ -124,6 +125,8 @@ def process_guardrail(
|
||||
result=guardrail_result.result,
|
||||
error=guardrail_result.error,
|
||||
retry_count=retry_count,
|
||||
from_agent=from_agent,
|
||||
from_task=from_task,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@@ -12,6 +12,7 @@ from crewai.utilities.i18n import I18N
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from crewai.agent import Agent
|
||||
from crewai.agents.agent_builder.base_agent import BaseAgent
|
||||
from crewai.llm import LLM
|
||||
from crewai.llms.base_llm import BaseLLM
|
||||
from crewai.task import Task
|
||||
@@ -25,7 +26,7 @@ def execute_tool_and_check_finality(
|
||||
agent_role: str | None = None,
|
||||
tools_handler: ToolsHandler | None = None,
|
||||
task: Task | None = None,
|
||||
agent: Agent | None = None,
|
||||
agent: Agent | BaseAgent | None = None,
|
||||
function_calling_llm: BaseLLM | LLM | None = None,
|
||||
fingerprint_context: dict[str, str] | None = None,
|
||||
) -> ToolResult:
|
||||
|
||||
Reference in New Issue
Block a user