mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-05-01 23:32:39 +00:00
feat: implement before and after LLM call hooks in CrewAgentExecutor (#3893)
- Added support for before and after LLM call hooks to allow modification of messages and responses during LLM interactions. - Introduced LLMCallHookContext to provide hooks with access to the executor state, enabling in-place modifications of messages. - Updated get_llm_response function to utilize the new hooks, ensuring that modifications persist across iterations. - Enhanced tests to verify the functionality of the hooks and their error handling capabilities, ensuring robust execution flow.
This commit is contained in:
@@ -38,6 +38,10 @@ from crewai.utilities.agent_utils import (
|
||||
)
|
||||
from crewai.utilities.constants import TRAINING_DATA_FILE
|
||||
from crewai.utilities.i18n import I18N, get_i18n
|
||||
from crewai.utilities.llm_call_hooks import (
|
||||
get_after_llm_call_hooks,
|
||||
get_before_llm_call_hooks,
|
||||
)
|
||||
from crewai.utilities.printer import Printer
|
||||
from crewai.utilities.tool_utils import execute_tool_and_check_finality
|
||||
from crewai.utilities.training_handler import CrewTrainingHandler
|
||||
@@ -130,6 +134,10 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
|
||||
self.messages: list[LLMMessage] = []
|
||||
self.iterations = 0
|
||||
self.log_error_after = 3
|
||||
self.before_llm_call_hooks: list[Callable] = []
|
||||
self.after_llm_call_hooks: list[Callable] = []
|
||||
self.before_llm_call_hooks.extend(get_before_llm_call_hooks())
|
||||
self.after_llm_call_hooks.extend(get_after_llm_call_hooks())
|
||||
if self.llm:
|
||||
# This may be mutating the shared llm object and needs further evaluation
|
||||
existing_stop = getattr(self.llm, "stop", [])
|
||||
@@ -226,6 +234,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
|
||||
from_task=self.task,
|
||||
from_agent=self.agent,
|
||||
response_model=self.response_model,
|
||||
executor_context=self,
|
||||
)
|
||||
formatted_answer = process_llm_response(answer, self.use_stop_words) # type: ignore[assignment]
|
||||
|
||||
|
||||
@@ -33,6 +33,7 @@ from crewai.utilities.types import LLMMessage
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from crewai.agent import Agent
|
||||
from crewai.agents.crew_agent_executor import CrewAgentExecutor
|
||||
from crewai.lite_agent import LiteAgent
|
||||
from crewai.llm import LLM
|
||||
from crewai.task import Task
|
||||
@@ -236,6 +237,7 @@ def get_llm_response(
|
||||
from_task: Task | None = None,
|
||||
from_agent: Agent | LiteAgent | None = None,
|
||||
response_model: type[BaseModel] | None = None,
|
||||
executor_context: CrewAgentExecutor | None = None,
|
||||
) -> str:
|
||||
"""Call the LLM and return the response, handling any invalid responses.
|
||||
|
||||
@@ -247,6 +249,7 @@ def get_llm_response(
|
||||
from_task: Optional task context for the LLM call
|
||||
from_agent: Optional agent context for the LLM call
|
||||
response_model: Optional Pydantic model for structured outputs
|
||||
executor_context: Optional executor context for hook invocation
|
||||
|
||||
Returns:
|
||||
The response from the LLM as a string
|
||||
@@ -255,6 +258,11 @@ def get_llm_response(
|
||||
Exception: If an error occurs.
|
||||
ValueError: If the response is None or empty.
|
||||
"""
|
||||
|
||||
if executor_context is not None:
|
||||
_setup_before_llm_call_hooks(executor_context, printer)
|
||||
messages = executor_context.messages
|
||||
|
||||
try:
|
||||
answer = llm.call(
|
||||
messages,
|
||||
@@ -272,7 +280,7 @@ def get_llm_response(
|
||||
)
|
||||
raise ValueError("Invalid response from LLM call - None or empty.")
|
||||
|
||||
return answer
|
||||
return _setup_after_llm_call_hooks(executor_context, answer, printer)
|
||||
|
||||
|
||||
def process_llm_response(
|
||||
@@ -661,3 +669,92 @@ def load_agent_from_repository(from_repository: str) -> dict[str, Any]:
|
||||
else:
|
||||
attributes[key] = value
|
||||
return attributes
|
||||
|
||||
|
||||
def _setup_before_llm_call_hooks(
|
||||
executor_context: CrewAgentExecutor | None, printer: Printer
|
||||
) -> None:
|
||||
"""Setup and invoke before_llm_call hooks for the executor context.
|
||||
|
||||
Args:
|
||||
executor_context: The executor context to setup the hooks for.
|
||||
printer: Printer instance for error logging.
|
||||
"""
|
||||
if executor_context and executor_context.before_llm_call_hooks:
|
||||
from crewai.utilities.llm_call_hooks import LLMCallHookContext
|
||||
|
||||
original_messages = executor_context.messages
|
||||
|
||||
hook_context = LLMCallHookContext(executor_context)
|
||||
try:
|
||||
for hook in executor_context.before_llm_call_hooks:
|
||||
hook(hook_context)
|
||||
except Exception as e:
|
||||
printer.print(
|
||||
content=f"Error in before_llm_call hook: {e}",
|
||||
color="yellow",
|
||||
)
|
||||
|
||||
if not isinstance(executor_context.messages, list):
|
||||
printer.print(
|
||||
content=(
|
||||
"Warning: before_llm_call hook replaced messages with non-list. "
|
||||
"Restoring original messages list. Hooks should modify messages in-place, "
|
||||
"not replace the list (e.g., use context.messages.append() not context.messages = [])."
|
||||
),
|
||||
color="yellow",
|
||||
)
|
||||
if isinstance(original_messages, list):
|
||||
executor_context.messages = original_messages
|
||||
else:
|
||||
executor_context.messages = []
|
||||
|
||||
|
||||
def _setup_after_llm_call_hooks(
|
||||
executor_context: CrewAgentExecutor | None,
|
||||
answer: str,
|
||||
printer: Printer,
|
||||
) -> str:
|
||||
"""Setup and invoke after_llm_call hooks for the executor context.
|
||||
|
||||
Args:
|
||||
executor_context: The executor context to setup the hooks for.
|
||||
answer: The LLM response string.
|
||||
printer: Printer instance for error logging.
|
||||
|
||||
Returns:
|
||||
The potentially modified response string.
|
||||
"""
|
||||
if executor_context and executor_context.after_llm_call_hooks:
|
||||
from crewai.utilities.llm_call_hooks import LLMCallHookContext
|
||||
|
||||
original_messages = executor_context.messages
|
||||
|
||||
hook_context = LLMCallHookContext(executor_context, response=answer)
|
||||
try:
|
||||
for hook in executor_context.after_llm_call_hooks:
|
||||
modified_response = hook(hook_context)
|
||||
if modified_response is not None and isinstance(modified_response, str):
|
||||
answer = modified_response
|
||||
|
||||
except Exception as e:
|
||||
printer.print(
|
||||
content=f"Error in after_llm_call hook: {e}",
|
||||
color="yellow",
|
||||
)
|
||||
|
||||
if not isinstance(executor_context.messages, list):
|
||||
printer.print(
|
||||
content=(
|
||||
"Warning: after_llm_call hook replaced messages with non-list. "
|
||||
"Restoring original messages list. Hooks should modify messages in-place, "
|
||||
"not replace the list (e.g., use context.messages.append() not context.messages = [])."
|
||||
),
|
||||
color="yellow",
|
||||
)
|
||||
if isinstance(original_messages, list):
|
||||
executor_context.messages = original_messages
|
||||
else:
|
||||
executor_context.messages = []
|
||||
|
||||
return answer
|
||||
|
||||
115
lib/crewai/src/crewai/utilities/llm_call_hooks.py
Normal file
115
lib/crewai/src/crewai/utilities/llm_call_hooks.py
Normal file
@@ -0,0 +1,115 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from crewai.agents.crew_agent_executor import CrewAgentExecutor
|
||||
|
||||
|
||||
class LLMCallHookContext:
|
||||
"""Context object passed to LLM call hooks with full executor access.
|
||||
|
||||
Provides hooks with complete access to the executor state, allowing
|
||||
modification of messages, responses, and executor attributes.
|
||||
|
||||
Attributes:
|
||||
executor: Full reference to the CrewAgentExecutor instance
|
||||
messages: Direct reference to executor.messages (mutable list).
|
||||
Can be modified in both before_llm_call and after_llm_call hooks.
|
||||
Modifications in after_llm_call hooks persist to the next iteration,
|
||||
allowing hooks to modify conversation history for subsequent LLM calls.
|
||||
IMPORTANT: Modify messages in-place (e.g., append, extend, remove items).
|
||||
Do NOT replace the list (e.g., context.messages = []), as this will break
|
||||
the executor. Use context.messages.append() or context.messages.extend()
|
||||
instead of assignment.
|
||||
agent: Reference to the agent executing the task
|
||||
task: Reference to the task being executed
|
||||
crew: Reference to the crew instance
|
||||
llm: Reference to the LLM instance
|
||||
iterations: Current iteration count
|
||||
response: LLM response string (only set for after_llm_call hooks).
|
||||
Can be modified by returning a new string from after_llm_call hook.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
executor: CrewAgentExecutor,
|
||||
response: str | None = None,
|
||||
) -> None:
|
||||
"""Initialize hook context with executor reference.
|
||||
|
||||
Args:
|
||||
executor: The CrewAgentExecutor instance
|
||||
response: Optional response string (for after_llm_call hooks)
|
||||
"""
|
||||
self.executor = executor
|
||||
self.messages = executor.messages
|
||||
self.agent = executor.agent
|
||||
self.task = executor.task
|
||||
self.crew = executor.crew
|
||||
self.llm = executor.llm
|
||||
self.iterations = executor.iterations
|
||||
self.response = response
|
||||
|
||||
|
||||
# Global hook registries (optional convenience feature)
|
||||
_before_llm_call_hooks: list[Callable[[LLMCallHookContext], None]] = []
|
||||
_after_llm_call_hooks: list[Callable[[LLMCallHookContext], str | None]] = []
|
||||
|
||||
|
||||
def register_before_llm_call_hook(
|
||||
hook: Callable[[LLMCallHookContext], None],
|
||||
) -> None:
|
||||
"""Register a global before_llm_call hook.
|
||||
|
||||
Global hooks are added to all executors automatically.
|
||||
This is a convenience function for registering hooks that should
|
||||
apply to all LLM calls across all executors.
|
||||
|
||||
Args:
|
||||
hook: Function that receives LLMCallHookContext and can modify
|
||||
context.messages directly. Should return None.
|
||||
IMPORTANT: Modify messages in-place (append, extend, remove items).
|
||||
Do NOT replace the list (context.messages = []), as this will break execution.
|
||||
"""
|
||||
_before_llm_call_hooks.append(hook)
|
||||
|
||||
|
||||
def register_after_llm_call_hook(
|
||||
hook: Callable[[LLMCallHookContext], str | None],
|
||||
) -> None:
|
||||
"""Register a global after_llm_call hook.
|
||||
|
||||
Global hooks are added to all executors automatically.
|
||||
This is a convenience function for registering hooks that should
|
||||
apply to all LLM calls across all executors.
|
||||
|
||||
Args:
|
||||
hook: Function that receives LLMCallHookContext and can modify:
|
||||
- The response: Return modified response string or None to keep original
|
||||
- The messages: Modify context.messages directly (mutable reference)
|
||||
Both modifications are supported and can be used together.
|
||||
IMPORTANT: Modify messages in-place (append, extend, remove items).
|
||||
Do NOT replace the list (context.messages = []), as this will break execution.
|
||||
"""
|
||||
_after_llm_call_hooks.append(hook)
|
||||
|
||||
|
||||
def get_before_llm_call_hooks() -> list[Callable[[LLMCallHookContext], None]]:
|
||||
"""Get all registered global before_llm_call hooks.
|
||||
|
||||
Returns:
|
||||
List of registered before hooks
|
||||
"""
|
||||
return _before_llm_call_hooks.copy()
|
||||
|
||||
|
||||
def get_after_llm_call_hooks() -> list[Callable[[LLMCallHookContext], str | None]]:
|
||||
"""Get all registered global after_llm_call hooks.
|
||||
|
||||
Returns:
|
||||
List of registered after hooks
|
||||
"""
|
||||
return _after_llm_call_hooks.copy()
|
||||
Reference in New Issue
Block a user