Lorenze/ensure hooks work with lite agents flows (#3981)

* liteagent support hooks

* wip llm.call hooks work - needs tests for this

* fix tests

* fixed more

* more tool hooks test cassettes
This commit is contained in:
Lorenze Jay
2025-12-04 09:38:39 -08:00
committed by GitHub
parent 633e279b51
commit c456e5c5fa
17 changed files with 1640 additions and 53 deletions

View File

@@ -1,6 +1,6 @@
from __future__ import annotations
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Any, cast
from crewai.events.event_listener import event_listener
from crewai.hooks.types import AfterLLMCallHookType, BeforeLLMCallHookType
@@ -9,17 +9,22 @@ from crewai.utilities.printer import Printer
if TYPE_CHECKING:
from crewai.agents.crew_agent_executor import CrewAgentExecutor
from crewai.lite_agent import LiteAgent
from crewai.llms.base_llm import BaseLLM
from crewai.utilities.types import LLMMessage
class LLMCallHookContext:
"""Context object passed to LLM call hooks with full executor access.
"""Context object passed to LLM call hooks.
Provides hooks with complete access to the executor state, allowing
Provides hooks with complete access to the execution state, allowing
modification of messages, responses, and executor attributes.
Supports both executor-based calls (agents in crews/flows) and direct LLM calls.
Attributes:
executor: Full reference to the CrewAgentExecutor instance
messages: Direct reference to executor.messages (mutable list).
executor: Reference to the executor (CrewAgentExecutor/LiteAgent) or None for direct calls
messages: Direct reference to messages (mutable list).
Can be modified in both before_llm_call and after_llm_call hooks.
Modifications in after_llm_call hooks persist to the next iteration,
allowing hooks to modify conversation history for subsequent LLM calls.
@@ -27,33 +32,75 @@ class LLMCallHookContext:
Do NOT replace the list (e.g., context.messages = []), as this will break
the executor. Use context.messages.append() or context.messages.extend()
instead of assignment.
agent: Reference to the agent executing the task
task: Reference to the task being executed
crew: Reference to the crew instance
agent: Reference to the agent executing the task (None for direct LLM calls)
task: Reference to the task being executed (None for direct LLM calls or LiteAgent)
crew: Reference to the crew instance (None for direct LLM calls or LiteAgent)
llm: Reference to the LLM instance
iterations: Current iteration count
iterations: Current iteration count (0 for direct LLM calls)
response: LLM response string (only set for after_llm_call hooks).
Can be modified by returning a new string from after_llm_call hook.
"""
executor: CrewAgentExecutor | LiteAgent | None
messages: list[LLMMessage]
agent: Any
task: Any
crew: Any
llm: BaseLLM | None | str | Any
iterations: int
response: str | None
def __init__(
self,
executor: CrewAgentExecutor,
executor: CrewAgentExecutor | LiteAgent | None = None,
response: str | None = None,
messages: list[LLMMessage] | None = None,
llm: BaseLLM | str | Any | None = None, # TODO: look into
agent: Any | None = None,
task: Any | None = None,
crew: Any | None = None,
) -> None:
"""Initialize hook context with executor reference.
"""Initialize hook context with executor reference or direct parameters.
Args:
executor: The CrewAgentExecutor instance
executor: The CrewAgentExecutor or LiteAgent instance (None for direct LLM calls)
response: Optional response string (for after_llm_call hooks)
messages: Optional messages list (for direct LLM calls when executor is None)
llm: Optional LLM instance (for direct LLM calls when executor is None)
agent: Optional agent reference (for direct LLM calls when executor is None)
task: Optional task reference (for direct LLM calls when executor is None)
crew: Optional crew reference (for direct LLM calls when executor is None)
"""
self.executor = executor
self.messages = executor.messages
self.agent = executor.agent
self.task = executor.task
self.crew = executor.crew
self.llm = executor.llm
self.iterations = executor.iterations
if executor is not None:
# Existing path: extract from executor
self.executor = executor
self.messages = executor.messages
self.llm = executor.llm
self.iterations = executor.iterations
# Handle CrewAgentExecutor vs LiteAgent differences
if hasattr(executor, "agent"):
self.agent = executor.agent
self.task = cast("CrewAgentExecutor", executor).task
self.crew = cast("CrewAgentExecutor", executor).crew
else:
# LiteAgent case - is the agent itself, doesn't have task/crew
self.agent = (
executor.original_agent
if hasattr(executor, "original_agent")
else executor
)
self.task = None
self.crew = None
else:
# New path: direct LLM call with explicit parameters
self.executor = None
self.messages = messages or []
self.llm = llm
self.agent = agent
self.task = task
self.crew = crew
self.iterations = 0
self.response = response
def request_human_input(

View File

@@ -38,6 +38,8 @@ from crewai.events.types.agent_events import (
)
from crewai.events.types.logging_events import AgentLogsExecutionEvent
from crewai.flow.flow_trackable import FlowTrackable
from crewai.hooks.llm_hooks import get_after_llm_call_hooks, get_before_llm_call_hooks
from crewai.hooks.types import AfterLLMCallHookType, BeforeLLMCallHookType
from crewai.lite_agent_output import LiteAgentOutput
from crewai.llm import LLM
from crewai.llms.base_llm import BaseLLM
@@ -155,6 +157,12 @@ class LiteAgent(FlowTrackable, BaseModel):
_guardrail: GuardrailCallable | None = PrivateAttr(default=None)
_guardrail_retry_count: int = PrivateAttr(default=0)
_callbacks: list[TokenCalcHandler] = PrivateAttr(default_factory=list)
_before_llm_call_hooks: list[BeforeLLMCallHookType] = PrivateAttr(
default_factory=get_before_llm_call_hooks
)
_after_llm_call_hooks: list[AfterLLMCallHookType] = PrivateAttr(
default_factory=get_after_llm_call_hooks
)
@model_validator(mode="after")
def setup_llm(self) -> Self:
@@ -246,6 +254,26 @@ class LiteAgent(FlowTrackable, BaseModel):
"""Return the original role for compatibility with tool interfaces."""
return self.role
@property
def before_llm_call_hooks(self) -> list[BeforeLLMCallHookType]:
"""Get the before_llm_call hooks for this agent."""
return self._before_llm_call_hooks
@property
def after_llm_call_hooks(self) -> list[AfterLLMCallHookType]:
"""Get the after_llm_call hooks for this agent."""
return self._after_llm_call_hooks
@property
def messages(self) -> list[LLMMessage]:
"""Get the messages list for hook context compatibility."""
return self._messages
@property
def iterations(self) -> int:
"""Get the current iteration count for hook context compatibility."""
return self._iterations
def kickoff(
self,
messages: str | list[LLMMessage],
@@ -504,7 +532,7 @@ class LiteAgent(FlowTrackable, BaseModel):
AgentFinish: The final result of the agent execution.
"""
# Execute the agent loop
formatted_answer = None
formatted_answer: AgentAction | AgentFinish | None = None
while not isinstance(formatted_answer, AgentFinish):
try:
if has_reached_max_iterations(self._iterations, self.max_iterations):
@@ -526,6 +554,7 @@ class LiteAgent(FlowTrackable, BaseModel):
callbacks=self._callbacks,
printer=self._printer,
from_agent=self,
executor_context=self,
)
except Exception as e:

View File

@@ -1642,6 +1642,10 @@ class LLM(BaseLLM):
if message.get("role") == "system":
msg_role: Literal["assistant"] = "assistant"
message["role"] = msg_role
if not self._invoke_before_llm_call_hooks(messages, from_agent):
raise ValueError("LLM call blocked by before_llm_call hook")
# --- 5) Set up callbacks if provided
with suppress_warnings():
if callbacks and len(callbacks) > 0:
@@ -1651,7 +1655,16 @@ class LLM(BaseLLM):
params = self._prepare_completion_params(messages, tools)
# --- 7) Make the completion call and handle response
if self.stream:
return self._handle_streaming_response(
result = self._handle_streaming_response(
params=params,
callbacks=callbacks,
available_functions=available_functions,
from_task=from_task,
from_agent=from_agent,
response_model=response_model,
)
else:
result = self._handle_non_streaming_response(
params=params,
callbacks=callbacks,
available_functions=available_functions,
@@ -1660,14 +1673,12 @@ class LLM(BaseLLM):
response_model=response_model,
)
return self._handle_non_streaming_response(
params=params,
callbacks=callbacks,
available_functions=available_functions,
from_task=from_task,
from_agent=from_agent,
response_model=response_model,
)
if isinstance(result, str):
result = self._invoke_after_llm_call_hooks(
messages, result, from_agent
)
return result
except LLMContextLengthExceededError:
# Re-raise LLMContextLengthExceededError as it should be handled
# by the CrewAgentExecutor._invoke_loop method, which can then decide

View File

@@ -314,7 +314,7 @@ class BaseLLM(ABC):
call_type: LLMCallType,
from_task: Task | None = None,
from_agent: Agent | None = None,
messages: str | list[dict[str, Any]] | None = None,
messages: str | list[LLMMessage] | None = None,
) -> None:
"""Emit LLM call completed event."""
crewai_event_bus.emit(
@@ -586,3 +586,134 @@ class BaseLLM(ABC):
Dictionary with token usage totals
"""
return UsageMetrics(**self._token_usage)
def _invoke_before_llm_call_hooks(
self,
messages: list[LLMMessage],
from_agent: Agent | None = None,
) -> bool:
"""Invoke before_llm_call hooks for direct LLM calls (no agent context).
This method should be called by native provider implementations before
making the actual LLM call when from_agent is None (direct calls).
Args:
messages: The messages being sent to the LLM
from_agent: The agent making the call (None for direct calls)
Returns:
True if LLM call should proceed, False if blocked by hook
Example:
>>> # In a native provider's call() method:
>>> if from_agent is None and not self._invoke_before_llm_call_hooks(
... messages, from_agent
... ):
... raise ValueError("LLM call blocked by hook")
"""
# Only invoke hooks for direct calls (no agent context)
if from_agent is not None:
return True
from crewai.hooks.llm_hooks import (
LLMCallHookContext,
get_before_llm_call_hooks,
)
from crewai.utilities.printer import Printer
before_hooks = get_before_llm_call_hooks()
if not before_hooks:
return True
hook_context = LLMCallHookContext(
executor=None,
messages=messages,
llm=self,
agent=None,
task=None,
crew=None,
)
printer = Printer()
try:
for hook in before_hooks:
result = hook(hook_context)
if result is False:
printer.print(
content="LLM call blocked by before_llm_call hook",
color="yellow",
)
return False
except Exception as e:
printer.print(
content=f"Error in before_llm_call hook: {e}",
color="yellow",
)
return True
def _invoke_after_llm_call_hooks(
self,
messages: list[LLMMessage],
response: str,
from_agent: Agent | None = None,
) -> str:
"""Invoke after_llm_call hooks for direct LLM calls (no agent context).
This method should be called by native provider implementations after
receiving the LLM response when from_agent is None (direct calls).
Args:
messages: The messages that were sent to the LLM
response: The response from the LLM
from_agent: The agent that made the call (None for direct calls)
Returns:
The potentially modified response string
Example:
>>> # In a native provider's call() method:
>>> if from_agent is None and isinstance(result, str):
... result = self._invoke_after_llm_call_hooks(
... messages, result, from_agent
... )
"""
# Only invoke hooks for direct calls (no agent context)
if from_agent is not None or not isinstance(response, str):
return response
from crewai.hooks.llm_hooks import (
LLMCallHookContext,
get_after_llm_call_hooks,
)
from crewai.utilities.printer import Printer
after_hooks = get_after_llm_call_hooks()
if not after_hooks:
return response
hook_context = LLMCallHookContext(
executor=None,
messages=messages,
llm=self,
agent=None,
task=None,
crew=None,
response=response,
)
printer = Printer()
modified_response = response
try:
for hook in after_hooks:
result = hook(hook_context)
if result is not None and isinstance(result, str):
modified_response = result
hook_context.response = modified_response
except Exception as e:
printer.print(
content=f"Error in after_llm_call hook: {e}",
color="yellow",
)
return modified_response

View File

@@ -187,6 +187,9 @@ class AnthropicCompletion(BaseLLM):
messages
)
if not self._invoke_before_llm_call_hooks(formatted_messages, from_agent):
raise ValueError("LLM call blocked by before_llm_call hook")
# Prepare completion parameters
completion_params = self._prepare_completion_params(
formatted_messages, system_message, tools
@@ -494,7 +497,9 @@ class AnthropicCompletion(BaseLLM):
if usage.get("total_tokens", 0) > 0:
logging.info(f"Anthropic API usage: {usage}")
return content
return self._invoke_after_llm_call_hooks(
params["messages"], content, from_agent
)
def _handle_streaming_completion(
self,
@@ -588,7 +593,9 @@ class AnthropicCompletion(BaseLLM):
messages=params["messages"],
)
return full_response
return self._invoke_after_llm_call_hooks(
params["messages"], full_response, from_agent
)
def _handle_tool_use_conversation(
self,

View File

@@ -216,6 +216,9 @@ class AzureCompletion(BaseLLM):
# Format messages for Azure
formatted_messages = self._format_messages_for_azure(messages)
if not self._invoke_before_llm_call_hooks(formatted_messages, from_agent):
raise ValueError("LLM call blocked by before_llm_call hook")
# Prepare completion parameters
completion_params = self._prepare_completion_params(
formatted_messages, tools, response_model
@@ -550,6 +553,10 @@ class AzureCompletion(BaseLLM):
messages=params["messages"],
)
content = self._invoke_after_llm_call_hooks(
params["messages"], content, from_agent
)
except Exception as e:
if is_context_length_exceeded(e):
logging.error(f"Context window exceeded: {e}")
@@ -642,7 +649,9 @@ class AzureCompletion(BaseLLM):
messages=params["messages"],
)
return full_response
return self._invoke_after_llm_call_hooks(
params["messages"], full_response, from_agent
)
async def _ahandle_completion(
self,

View File

@@ -312,9 +312,14 @@ class BedrockCompletion(BaseLLM):
# Format messages for Converse API
formatted_messages, system_message = self._format_messages_for_converse(
messages # type: ignore[arg-type]
messages
)
if not self._invoke_before_llm_call_hooks(
cast(list[LLMMessage], formatted_messages), from_agent
):
raise ValueError("LLM call blocked by before_llm_call hook")
# Prepare request body
body: BedrockConverseRequestBody = {
"inferenceConfig": self._get_inference_config(),
@@ -356,11 +361,19 @@ class BedrockCompletion(BaseLLM):
if self.stream:
return self._handle_streaming_converse(
formatted_messages, body, available_functions, from_task, from_agent
cast(list[LLMMessage], formatted_messages),
body,
available_functions,
from_task,
from_agent,
)
return self._handle_converse(
formatted_messages, body, available_functions, from_task, from_agent
cast(list[LLMMessage], formatted_messages),
body,
available_functions,
from_task,
from_agent,
)
except Exception as e:
@@ -481,7 +494,7 @@ class BedrockCompletion(BaseLLM):
def _handle_converse(
self,
messages: list[dict[str, Any]],
messages: list[LLMMessage],
body: BedrockConverseRequestBody,
available_functions: Mapping[str, Any] | None = None,
from_task: Any | None = None,
@@ -605,7 +618,11 @@ class BedrockCompletion(BaseLLM):
messages=messages,
)
return text_content
return self._invoke_after_llm_call_hooks(
messages,
text_content,
from_agent,
)
except ClientError as e:
# Handle all AWS ClientError exceptions as per documentation
@@ -662,7 +679,7 @@ class BedrockCompletion(BaseLLM):
def _handle_streaming_converse(
self,
messages: list[dict[str, Any]],
messages: list[LLMMessage],
body: BedrockConverseRequestBody,
available_functions: dict[str, Any] | None = None,
from_task: Any | None = None,
@@ -1149,16 +1166,25 @@ class BedrockCompletion(BaseLLM):
messages=messages,
)
return full_response
return self._invoke_after_llm_call_hooks(
messages,
full_response,
from_agent,
)
def _format_messages_for_converse(
self, messages: str | list[dict[str, str]]
self, messages: str | list[LLMMessage]
) -> tuple[list[dict[str, Any]], str | None]:
"""Format messages for Converse API following AWS documentation."""
# Use base class formatting first
formatted_messages = self._format_messages(messages) # type: ignore[arg-type]
"""Format messages for Converse API following AWS documentation.
converse_messages = []
Note: Returns dict[str, Any] instead of LLMMessage because Bedrock uses
a different content structure: {"role": str, "content": [{"text": str}]}
rather than the standard {"role": str, "content": str}.
"""
# Use base class formatting first
formatted_messages = self._format_messages(messages)
converse_messages: list[dict[str, Any]] = []
system_message: str | None = None
for message in formatted_messages:

View File

@@ -246,6 +246,11 @@ class GeminiCompletion(BaseLLM):
messages
)
messages_for_hooks = self._convert_contents_to_dict(formatted_content)
if not self._invoke_before_llm_call_hooks(messages_for_hooks, from_agent):
raise ValueError("LLM call blocked by before_llm_call hook")
config = self._prepare_generation_config(
system_instruction, tools, response_model
)
@@ -559,7 +564,9 @@ class GeminiCompletion(BaseLLM):
messages=messages_for_event,
)
return content
return self._invoke_after_llm_call_hooks(
messages_for_event, content, from_agent
)
def _handle_streaming_completion(
self,
@@ -639,7 +646,9 @@ class GeminiCompletion(BaseLLM):
messages=messages_for_event,
)
return full_response
return self._invoke_after_llm_call_hooks(
messages_for_event, full_response, from_agent
)
async def _ahandle_completion(
self,
@@ -787,7 +796,159 @@ class GeminiCompletion(BaseLLM):
messages=messages_for_event,
)
return full_response
return self._invoke_after_llm_call_hooks(
messages_for_event, full_response, from_agent
)
async def _ahandle_completion(
self,
contents: list[types.Content],
system_instruction: str | None,
config: types.GenerateContentConfig,
available_functions: dict[str, Any] | None = None,
from_task: Any | None = None,
from_agent: Any | None = None,
response_model: type[BaseModel] | None = None,
) -> str | Any:
"""Handle async non-streaming content generation."""
try:
# The API accepts list[Content] but mypy is overly strict about variance
contents_for_api: Any = contents
response = await self.client.aio.models.generate_content(
model=self.model,
contents=contents_for_api,
config=config,
)
usage = self._extract_token_usage(response)
except Exception as e:
if is_context_length_exceeded(e):
logging.error(f"Context window exceeded: {e}")
raise LLMContextLengthExceededError(str(e)) from e
raise e from e
self._track_token_usage_internal(usage)
if response.candidates and (self.tools or available_functions):
candidate = response.candidates[0]
if candidate.content and candidate.content.parts:
for part in candidate.content.parts:
if hasattr(part, "function_call") and part.function_call:
function_name = part.function_call.name
if function_name is None:
continue
function_args = (
dict(part.function_call.args)
if part.function_call.args
else {}
)
result = self._handle_tool_execution(
function_name=function_name,
function_args=function_args,
available_functions=available_functions or {},
from_task=from_task,
from_agent=from_agent,
)
if result is not None:
return result
content = response.text or ""
content = self._apply_stop_words(content)
messages_for_event = self._convert_contents_to_dict(contents)
self._emit_call_completed_event(
response=content,
call_type=LLMCallType.LLM_CALL,
from_task=from_task,
from_agent=from_agent,
messages=messages_for_event,
)
return content
async def _ahandle_streaming_completion(
self,
contents: list[types.Content],
config: types.GenerateContentConfig,
available_functions: dict[str, Any] | None = None,
from_task: Any | None = None,
from_agent: Any | None = None,
response_model: type[BaseModel] | None = None,
) -> str:
"""Handle async streaming content generation."""
full_response = ""
function_calls: dict[str, dict[str, Any]] = {}
# The API accepts list[Content] but mypy is overly strict about variance
contents_for_api: Any = contents
stream = await self.client.aio.models.generate_content_stream(
model=self.model,
contents=contents_for_api,
config=config,
)
async for chunk in stream:
if chunk.text:
full_response += chunk.text
self._emit_stream_chunk_event(
chunk=chunk.text,
from_task=from_task,
from_agent=from_agent,
)
if chunk.candidates:
candidate = chunk.candidates[0]
if candidate.content and candidate.content.parts:
for part in candidate.content.parts:
if hasattr(part, "function_call") and part.function_call:
call_id = part.function_call.name or "default"
if call_id not in function_calls:
function_calls[call_id] = {
"name": part.function_call.name,
"args": dict(part.function_call.args)
if part.function_call.args
else {},
}
if function_calls and available_functions:
for call_data in function_calls.values():
function_name = call_data["name"]
function_args = call_data["args"]
# Skip if function_name is None
if not isinstance(function_name, str):
continue
# Ensure function_args is a dict
if not isinstance(function_args, dict):
function_args = {}
result = self._handle_tool_execution(
function_name=function_name,
function_args=function_args,
available_functions=available_functions,
from_task=from_task,
from_agent=from_agent,
)
if result is not None:
return result
messages_for_event = self._convert_contents_to_dict(contents)
self._emit_call_completed_event(
response=full_response,
call_type=LLMCallType.LLM_CALL,
from_task=from_task,
from_agent=from_agent,
messages=messages_for_event,
)
return self._invoke_after_llm_call_hooks(
messages_for_event, full_response, from_agent
)
def supports_function_calling(self) -> bool:
"""Check if the model supports function calling."""
@@ -851,7 +1012,7 @@ class GeminiCompletion(BaseLLM):
def _convert_contents_to_dict(
self,
contents: list[types.Content],
) -> list[dict[str, str]]:
) -> list[LLMMessage]:
"""Convert contents to dict format."""
result: list[dict[str, str]] = []
for content_obj in contents:

View File

@@ -190,6 +190,9 @@ class OpenAICompletion(BaseLLM):
formatted_messages = self._format_messages(messages)
if not self._invoke_before_llm_call_hooks(formatted_messages, from_agent):
raise ValueError("LLM call blocked by before_llm_call hook")
completion_params = self._prepare_completion_params(
messages=formatted_messages, tools=tools
)
@@ -474,6 +477,10 @@ class OpenAICompletion(BaseLLM):
if usage.get("total_tokens", 0) > 0:
logging.info(f"OpenAI API usage: {usage}")
content = self._invoke_after_llm_call_hooks(
params["messages"], content, from_agent
)
except NotFoundError as e:
error_msg = f"Model {self.model} not found: {e}"
logging.error(error_msg)
@@ -629,7 +636,9 @@ class OpenAICompletion(BaseLLM):
messages=params["messages"],
)
return full_response
return self._invoke_after_llm_call_hooks(
params["messages"], full_response, from_agent
)
async def _ahandle_completion(
self,

View File

@@ -237,7 +237,7 @@ def get_llm_response(
from_task: Task | None = None,
from_agent: Agent | LiteAgent | None = None,
response_model: type[BaseModel] | None = None,
executor_context: CrewAgentExecutor | None = None,
executor_context: CrewAgentExecutor | LiteAgent | None = None,
) -> str:
"""Call the LLM and return the response, handling any invalid responses.
@@ -727,7 +727,7 @@ def load_agent_from_repository(from_repository: str) -> dict[str, Any]:
def _setup_before_llm_call_hooks(
executor_context: CrewAgentExecutor | None, printer: Printer
executor_context: CrewAgentExecutor | LiteAgent | None, printer: Printer
) -> bool:
"""Setup and invoke before_llm_call hooks for the executor context.
@@ -777,7 +777,7 @@ def _setup_before_llm_call_hooks(
def _setup_after_llm_call_hooks(
executor_context: CrewAgentExecutor | None,
executor_context: CrewAgentExecutor | LiteAgent | None,
answer: str,
printer: Printer,
) -> str: