feat: enhance knowledge and guardrail event handling in Agent class (#3672)
Some checks failed
CodeQL Advanced / Analyze (actions) (push) Has been cancelled
CodeQL Advanced / Analyze (python) (push) Has been cancelled
Notify Downstream / notify-downstream (push) Has been cancelled
Update Test Durations / update-durations (3.10) (push) Has been cancelled
Update Test Durations / update-durations (3.11) (push) Has been cancelled
Update Test Durations / update-durations (3.12) (push) Has been cancelled
Update Test Durations / update-durations (3.13) (push) Has been cancelled

* feat: enhance knowledge event handling in Agent class

- Updated the Agent class to include task context in knowledge retrieval events.
- Emitted new events for knowledge retrieval and query processes, capturing task and agent details.
- Refactored knowledge event classes to inherit from a base class for better structure and maintainability.
- Added tracing for knowledge events in the TraceCollectionListener to improve observability.

This change improves the tracking and management of knowledge queries and retrievals, facilitating better debugging and performance monitoring.

* refactor: remove task_id from knowledge event emissions in Agent class

- Removed the task_id parameter from various knowledge event emissions in the Agent class to streamline event handling.
- This change simplifies the event structure and focuses on the essential context of knowledge retrieval and query processes.

This refactor enhances the clarity of knowledge events and aligns with the recent improvements in event handling.

* surface association for guardrail events

* fix: improve LLM selection logic in converter

- Updated the logic for selecting the LLM in the convert_with_instructions function to handle cases where the agent may not have a function_calling_llm attribute.
- This change ensures that the converter can still function correctly by falling back to the standard LLM if necessary, enhancing robustness and preventing potential errors.

This fix improves the reliability of the conversion process when working with different agent configurations.

* fix test

* fix: enforce valid LLM instance requirement in converter

- Updated the convert_with_instructions function to ensure that a valid LLM instance is provided by the agent.
- If neither function_calling_llm nor the standard llm is available, a ValueError is raised, enhancing error handling and robustness.
- Improved error messaging for conversion failures to provide clearer feedback on issues encountered during the conversion process.

This change strengthens the reliability of the conversion process by ensuring that agents are properly configured with a valid LLM.
This commit is contained in:
Lorenze Jay
2025-10-08 11:53:13 -07:00
committed by GitHub
parent 8d93361cb3
commit 6f2e39c0dd
26 changed files with 6547 additions and 64 deletions

View File

@@ -53,6 +53,7 @@ from crewai.utilities.converter import generate_model_description
from crewai.utilities.llm_utils import create_llm
from crewai.utilities.token_counter_callback import TokenCalcHandler
from crewai.utilities.training_handler import CrewTrainingHandler
from crewai.utilities.types import LLMMessage
class Agent(BaseAgent):
@@ -347,15 +348,16 @@ class Agent(BaseAgent):
)
if self.knowledge or (self.crew and self.crew.knowledge):
crewai_event_bus.emit(
self,
event=KnowledgeRetrievalStartedEvent(
agent=self,
),
)
try:
self.knowledge_search_query = self._get_knowledge_search_query(
task_prompt
task_prompt, task
)
crewai_event_bus.emit(
self,
event=KnowledgeRetrievalStartedEvent(
from_task=task,
from_agent=self,
),
)
if self.knowledge_search_query:
# Quering agent specific knowledge
@@ -385,7 +387,8 @@ class Agent(BaseAgent):
self,
event=KnowledgeRetrievalCompletedEvent(
query=self.knowledge_search_query,
agent=self,
from_task=task,
from_agent=self,
retrieved_knowledge=(
(self.agent_knowledge_context or "")
+ (
@@ -403,8 +406,9 @@ class Agent(BaseAgent):
self,
event=KnowledgeSearchQueryFailedEvent(
query=self.knowledge_search_query or "",
agent=self,
error=str(e),
from_task=task,
from_agent=self,
),
)
@@ -728,13 +732,14 @@ class Agent(BaseAgent):
def set_fingerprint(self, fingerprint: Fingerprint):
self.security_config.fingerprint = fingerprint
def _get_knowledge_search_query(self, task_prompt: str) -> str | None:
def _get_knowledge_search_query(self, task_prompt: str, task: Task) -> str | None:
"""Generate a search query for the knowledge base based on the task description."""
crewai_event_bus.emit(
self,
event=KnowledgeQueryStartedEvent(
task_prompt=task_prompt,
agent=self,
from_task=task,
from_agent=self,
),
)
query = self.i18n.slice("knowledge_search_query").format(
@@ -749,8 +754,9 @@ class Agent(BaseAgent):
crewai_event_bus.emit(
self,
event=KnowledgeQueryFailedEvent(
agent=self,
error="LLM is not compatible with knowledge search queries",
from_task=task,
from_agent=self,
),
)
return None
@@ -769,7 +775,8 @@ class Agent(BaseAgent):
self,
event=KnowledgeQueryCompletedEvent(
query=query,
agent=self,
from_task=task,
from_agent=self,
),
)
return rewritten_query
@@ -777,15 +784,16 @@ class Agent(BaseAgent):
crewai_event_bus.emit(
self,
event=KnowledgeQueryFailedEvent(
agent=self,
error=str(e),
from_task=task,
from_agent=self,
),
)
return None
def kickoff(
self,
messages: str | list[dict[str, str]],
messages: str | list[LLMMessage],
response_format: type[Any] | None = None,
) -> LiteAgentOutput:
"""
@@ -825,7 +833,7 @@ class Agent(BaseAgent):
async def kickoff_async(
self,
messages: str | list[dict[str, str]],
messages: str | list[LLMMessage],
response_format: type[Any] | None = None,
) -> LiteAgentOutput:
"""
@@ -855,6 +863,7 @@ class Agent(BaseAgent):
response_format=response_format,
i18n=self.i18n,
original_agent=self,
guardrail=self.guardrail,
)
return await lite_agent.kickoff_async(messages)

View File

@@ -32,6 +32,13 @@ from crewai.events.types.flow_events import (
MethodExecutionFinishedEvent,
MethodExecutionStartedEvent,
)
from crewai.events.types.knowledge_events import (
KnowledgeQueryCompletedEvent,
KnowledgeQueryFailedEvent,
KnowledgeQueryStartedEvent,
KnowledgeRetrievalCompletedEvent,
KnowledgeRetrievalStartedEvent,
)
from crewai.events.types.llm_events import (
LLMCallCompletedEvent,
LLMCallFailedEvent,
@@ -310,6 +317,26 @@ class TraceCollectionListener(BaseEventListener):
def on_agent_reasoning_failed(source, event):
self._handle_action_event("agent_reasoning_failed", source, event)
@event_bus.on(KnowledgeRetrievalStartedEvent)
def on_knowledge_retrieval_started(source, event):
self._handle_action_event("knowledge_retrieval_started", source, event)
@event_bus.on(KnowledgeRetrievalCompletedEvent)
def on_knowledge_retrieval_completed(source, event):
self._handle_action_event("knowledge_retrieval_completed", source, event)
@event_bus.on(KnowledgeQueryStartedEvent)
def on_knowledge_query_started(source, event):
self._handle_action_event("knowledge_query_started", source, event)
@event_bus.on(KnowledgeQueryCompletedEvent)
def on_knowledge_query_completed(source, event):
self._handle_action_event("knowledge_query_completed", source, event)
@event_bus.on(KnowledgeQueryFailedEvent)
def on_knowledge_query_failed(source, event):
self._handle_action_event("knowledge_query_failed", source, event)
def _initialize_crew_batch(self, source: Any, event: Any):
"""Initialize trace batch"""
user_context = self._get_user_context()

View File

@@ -1,51 +1,60 @@
from crewai.agents.agent_builder.base_agent import BaseAgent
from typing import Any
from crewai.events.base_events import BaseEvent
class KnowledgeRetrievalStartedEvent(BaseEvent):
class KnowledgeEventBase(BaseEvent):
task_id: str | None = None
task_name: str | None = None
from_task: Any | None = None
from_agent: Any | None = None
agent_role: str | None = None
agent_id: str | None = None
def __init__(self, **data):
super().__init__(**data)
self._set_agent_params(data)
self._set_task_params(data)
class KnowledgeRetrievalStartedEvent(KnowledgeEventBase):
"""Event emitted when a knowledge retrieval is started."""
type: str = "knowledge_search_query_started"
agent: BaseAgent
class KnowledgeRetrievalCompletedEvent(BaseEvent):
class KnowledgeRetrievalCompletedEvent(KnowledgeEventBase):
"""Event emitted when a knowledge retrieval is completed."""
query: str
type: str = "knowledge_search_query_completed"
agent: BaseAgent
retrieved_knowledge: str
class KnowledgeQueryStartedEvent(BaseEvent):
class KnowledgeQueryStartedEvent(KnowledgeEventBase):
"""Event emitted when a knowledge query is started."""
task_prompt: str
type: str = "knowledge_query_started"
agent: BaseAgent
class KnowledgeQueryFailedEvent(BaseEvent):
class KnowledgeQueryFailedEvent(KnowledgeEventBase):
"""Event emitted when a knowledge query fails."""
type: str = "knowledge_query_failed"
agent: BaseAgent
error: str
class KnowledgeQueryCompletedEvent(BaseEvent):
class KnowledgeQueryCompletedEvent(KnowledgeEventBase):
"""Event emitted when a knowledge query is completed."""
query: str
type: str = "knowledge_query_completed"
agent: BaseAgent
class KnowledgeSearchQueryFailedEvent(BaseEvent):
class KnowledgeSearchQueryFailedEvent(KnowledgeEventBase):
"""Event emitted when a knowledge search query fails."""
query: str
type: str = "knowledge_search_query_failed"
agent: BaseAgent
error: str

View File

@@ -5,7 +5,21 @@ from typing import Any
from crewai.events.base_events import BaseEvent
class LLMGuardrailStartedEvent(BaseEvent):
class LLMGuardrailBaseEvent(BaseEvent):
task_id: str | None = None
task_name: str | None = None
from_task: Any | None = None
from_agent: Any | None = None
agent_role: str | None = None
agent_id: str | None = None
def __init__(self, **data):
super().__init__(**data)
self._set_agent_params(data)
self._set_task_params(data)
class LLMGuardrailStartedEvent(LLMGuardrailBaseEvent):
"""Event emitted when a guardrail task starts
Attributes:
@@ -29,7 +43,7 @@ class LLMGuardrailStartedEvent(BaseEvent):
self.guardrail = getsource(self.guardrail).strip()
class LLMGuardrailCompletedEvent(BaseEvent):
class LLMGuardrailCompletedEvent(LLMGuardrailBaseEvent):
"""Event emitted when a guardrail task completes
Attributes:
@@ -44,3 +58,16 @@ class LLMGuardrailCompletedEvent(BaseEvent):
result: Any
error: str | None = None
retry_count: int
class LLMGuardrailFailedEvent(LLMGuardrailBaseEvent):
"""Event emitted when a guardrail task fails
Attributes:
error: The error message
retry_count: The number of times the guardrail has been retried
"""
type: str = "llm_guardrail_failed"
error: str
retry_count: int

View File

@@ -4,6 +4,7 @@ import uuid
from collections.abc import Callable
from typing import (
Any,
Literal,
cast,
get_args,
get_origin,
@@ -62,6 +63,7 @@ from crewai.utilities.llm_utils import create_llm
from crewai.utilities.printer import Printer
from crewai.utilities.token_counter_callback import TokenCalcHandler
from crewai.utilities.tool_utils import execute_tool_and_check_finality
from crewai.utilities.types import LLMMessage
class LiteAgentOutput(BaseModel):
@@ -180,7 +182,7 @@ class LiteAgent(FlowTrackable, BaseModel):
_token_process: TokenProcess = PrivateAttr(default_factory=TokenProcess)
_cache_handler: CacheHandler = PrivateAttr(default_factory=CacheHandler)
_key: str = PrivateAttr(default_factory=lambda: str(uuid.uuid4()))
_messages: list[dict[str, str]] = PrivateAttr(default_factory=list)
_messages: list[LLMMessage] = PrivateAttr(default_factory=list)
_iterations: int = PrivateAttr(default=0)
_printer: Printer = PrivateAttr(default_factory=Printer)
_guardrail: Callable | None = PrivateAttr(default=None)
@@ -219,7 +221,6 @@ class LiteAgent(FlowTrackable, BaseModel):
raise TypeError(
f"Guardrail requires LLM instance of type BaseLLM, got {type(self.llm).__name__}"
)
self._guardrail = LLMGuardrail(description=self.guardrail, llm=self.llm)
return self
@@ -276,7 +277,7 @@ class LiteAgent(FlowTrackable, BaseModel):
"""Return the original role for compatibility with tool interfaces."""
return self.role
def kickoff(self, messages: str | list[dict[str, str]]) -> LiteAgentOutput:
def kickoff(self, messages: str | list[LLMMessage]) -> LiteAgentOutput:
"""
Execute the agent with the given messages.
@@ -368,6 +369,7 @@ class LiteAgent(FlowTrackable, BaseModel):
guardrail=self._guardrail,
retry_count=self._guardrail_retry_count,
event_source=self,
from_agent=self,
)
if not guardrail_result.success:
@@ -414,9 +416,7 @@ class LiteAgent(FlowTrackable, BaseModel):
return output
async def kickoff_async(
self, messages: str | list[dict[str, str]]
) -> LiteAgentOutput:
async def kickoff_async(self, messages: str | list[LLMMessage]) -> LiteAgentOutput:
"""
Execute the agent asynchronously with the given messages.
@@ -461,9 +461,7 @@ class LiteAgent(FlowTrackable, BaseModel):
return base_prompt
def _format_messages(
self, messages: str | list[dict[str, str]]
) -> list[dict[str, str]]:
def _format_messages(self, messages: str | list[LLMMessage]) -> list[LLMMessage]:
"""Format messages for the LLM."""
if isinstance(messages, str):
messages = [{"role": "user", "content": messages}]
@@ -471,7 +469,9 @@ class LiteAgent(FlowTrackable, BaseModel):
system_prompt = self._get_default_system_prompt()
# Add system message at the beginning
formatted_messages = [{"role": "system", "content": system_prompt}]
formatted_messages: list[LLMMessage] = [
{"role": "system", "content": system_prompt}
]
# Add the rest of the messages
formatted_messages.extend(messages)
@@ -583,6 +583,8 @@ class LiteAgent(FlowTrackable, BaseModel):
),
)
def _append_message(self, text: str, role: str = "assistant") -> None:
def _append_message(
self, text: str, role: Literal["user", "assistant", "system"] = "assistant"
) -> None:
"""Append a message to the message list with the given role."""
self._messages.append(format_message_for_llm(text, role=role))
self._messages.append(cast(LLMMessage, format_message_for_llm(text, role=role)))

View File

@@ -462,6 +462,8 @@ class Task(BaseModel):
guardrail=self._guardrail,
retry_count=self.retry_count,
event_source=self,
from_task=self,
from_agent=agent,
)
if not guardrail_result.success:
if self.retry_count >= self.guardrail_max_retries:

View File

@@ -31,6 +31,7 @@ from crewai.utilities.types import LLMMessage
if TYPE_CHECKING:
from crewai.agent import Agent
from crewai.lite_agent import LiteAgent
from crewai.task import Task
@@ -222,7 +223,7 @@ def get_llm_response(
callbacks: list[Callable[..., Any]],
printer: Printer,
from_task: Task | None = None,
from_agent: Agent | None = None,
from_agent: Agent | LiteAgent | None = None,
) -> str:
"""Call the LLM and return the response, handling any invalid responses.

View File

@@ -14,6 +14,7 @@ from crewai.utilities.pydantic_schema_parser import PydanticSchemaParser
if TYPE_CHECKING:
from crewai.agent import Agent
from crewai.agents.agent_builder.base_agent import BaseAgent
from crewai.llm import LLM
from crewai.llms.base_llm import BaseLLM
@@ -143,7 +144,7 @@ def convert_to_model(
result: str,
output_pydantic: type[BaseModel] | None,
output_json: type[BaseModel] | None,
agent: Agent | None = None,
agent: Agent | BaseAgent | None = None,
converter_cls: type[Converter] | None = None,
) -> dict[str, Any] | BaseModel | str:
"""Convert a result string to a Pydantic model or JSON.
@@ -215,7 +216,7 @@ def handle_partial_json(
result: str,
model: type[BaseModel],
is_json_output: bool,
agent: Agent | None,
agent: Agent | BaseAgent | None,
converter_cls: type[Converter] | None = None,
) -> dict[str, Any] | BaseModel | str:
"""Handle partial JSON in a result string and convert to Pydantic model or dict.
@@ -260,7 +261,7 @@ def convert_with_instructions(
result: str,
model: type[BaseModel],
is_json_output: bool,
agent: Agent | None,
agent: Agent | BaseAgent | None,
converter_cls: type[Converter] | None = None,
) -> dict | BaseModel | str:
"""Convert a result string to a Pydantic model or JSON using instructions.
@@ -283,7 +284,12 @@ def convert_with_instructions(
"""
if agent is None:
raise TypeError("Agent must be provided if converter_cls is not specified.")
llm = agent.function_calling_llm or agent.llm
llm = getattr(agent, "function_calling_llm", None) or agent.llm
if llm is None:
raise ValueError("Agent must have a valid LLM instance for conversion")
instructions = get_conversion_instructions(model=model, llm=llm)
converter = create_converter(
agent=agent,
@@ -299,7 +305,7 @@ def convert_with_instructions(
if isinstance(exported_result, ConverterError):
Printer().print(
content=f"{exported_result.message} Using raw output instead.",
content=f"Failed to convert result to model: {exported_result}",
color="red",
)
return result
@@ -308,7 +314,7 @@ def convert_with_instructions(
def get_conversion_instructions(
model: type[BaseModel], llm: BaseLLM | LLM | str
model: type[BaseModel], llm: BaseLLM | LLM | str | Any
) -> str:
"""Generate conversion instructions based on the model and LLM capabilities.
@@ -357,7 +363,7 @@ class CreateConverterKwargs(TypedDict, total=False):
def create_converter(
agent: Agent | None = None,
agent: Agent | BaseAgent | None = None,
converter_cls: type[Converter] | None = None,
*args: Any,
**kwargs: Unpack[CreateConverterKwargs],

View File

@@ -7,7 +7,9 @@ from pydantic import BaseModel, Field, field_validator
from typing_extensions import Self
if TYPE_CHECKING:
from crewai.lite_agent import LiteAgentOutput
from crewai.agents.agent_builder.base_agent import BaseAgent
from crewai.lite_agent import LiteAgent, LiteAgentOutput
from crewai.task import Task
from crewai.tasks.task_output import TaskOutput
@@ -79,6 +81,8 @@ def process_guardrail(
guardrail: Callable[[Any], tuple[bool, Any | str]],
retry_count: int,
event_source: Any | None = None,
from_agent: BaseAgent | LiteAgent | None = None,
from_task: Task | None = None,
) -> GuardrailResult:
"""Process the guardrail for the agent output.
@@ -95,14 +99,6 @@ def process_guardrail(
TypeError: If output is not a TaskOutput or LiteAgentOutput
ValueError: If guardrail is None
"""
from crewai.lite_agent import LiteAgentOutput
from crewai.tasks.task_output import TaskOutput
if not isinstance(output, (TaskOutput, LiteAgentOutput)):
raise TypeError("Output must be a TaskOutput or LiteAgentOutput")
if guardrail is None:
raise ValueError("Guardrail must not be None")
from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.llm_guardrail_events import (
LLMGuardrailCompletedEvent,
@@ -111,7 +107,12 @@ def process_guardrail(
crewai_event_bus.emit(
event_source,
LLMGuardrailStartedEvent(guardrail=guardrail, retry_count=retry_count),
LLMGuardrailStartedEvent(
guardrail=guardrail,
retry_count=retry_count,
from_agent=from_agent,
from_task=from_task,
),
)
result = guardrail(output)
@@ -124,6 +125,8 @@ def process_guardrail(
result=guardrail_result.result,
error=guardrail_result.error,
retry_count=retry_count,
from_agent=from_agent,
from_task=from_task,
),
)

View File

@@ -12,6 +12,7 @@ from crewai.utilities.i18n import I18N
if TYPE_CHECKING:
from crewai.agent import Agent
from crewai.agents.agent_builder.base_agent import BaseAgent
from crewai.llm import LLM
from crewai.llms.base_llm import BaseLLM
from crewai.task import Task
@@ -25,7 +26,7 @@ def execute_tool_and_check_finality(
agent_role: str | None = None,
tools_handler: ToolsHandler | None = None,
task: Task | None = None,
agent: Agent | None = None,
agent: Agent | BaseAgent | None = None,
function_calling_llm: BaseLLM | LLM | None = None,
fingerprint_context: dict[str, str] | None = None,
) -> ToolResult: