chore: improve hook typing and registration

Allow hook registration to accept both typed hook types and plain callables by importing and using After*/Before*CallHookCallable types; add explicit LLMCallHookContext and ToolCallHookContext typing in crew_base. Introduce a post-initialize crew hook list and invoke hooks after Crew instance initialization. Refactor filtered hook factory functions to include precise typing and clearer local names (before_llm_hook/after_llm_hook/before_tool_hook/after_tool_hook) and register those with the instance. Update CrewInstance protocol to include _registered_hook_functions and _hooks_being_registered fields.
This commit is contained in:
Greyson LaLonde
2026-02-04 20:12:58 -05:00
parent 76b5f72e81
commit d943567c76
4 changed files with 89 additions and 42 deletions

View File

@@ -3,7 +3,12 @@ from __future__ import annotations
from typing import TYPE_CHECKING, Any, cast
from crewai.events.event_listener import event_listener
from crewai.hooks.types import AfterLLMCallHookType, BeforeLLMCallHookType
from crewai.hooks.types import (
AfterLLMCallHookCallable,
AfterLLMCallHookType,
BeforeLLMCallHookCallable,
BeforeLLMCallHookType,
)
from crewai.utilities.printer import Printer
@@ -149,12 +154,12 @@ class LLMCallHookContext:
event_listener.formatter.resume_live_updates()
_before_llm_call_hooks: list[BeforeLLMCallHookType] = []
_after_llm_call_hooks: list[AfterLLMCallHookType] = []
_before_llm_call_hooks: list[BeforeLLMCallHookType | BeforeLLMCallHookCallable] = []
_after_llm_call_hooks: list[AfterLLMCallHookType | AfterLLMCallHookCallable] = []
def register_before_llm_call_hook(
hook: BeforeLLMCallHookType,
hook: BeforeLLMCallHookType | BeforeLLMCallHookCallable,
) -> None:
"""Register a global before_llm_call hook.
@@ -190,7 +195,7 @@ def register_before_llm_call_hook(
def register_after_llm_call_hook(
hook: AfterLLMCallHookType,
hook: AfterLLMCallHookType | AfterLLMCallHookCallable,
) -> None:
"""Register a global after_llm_call hook.
@@ -217,7 +222,9 @@ def register_after_llm_call_hook(
_after_llm_call_hooks.append(hook)
def get_before_llm_call_hooks() -> list[BeforeLLMCallHookType]:
def get_before_llm_call_hooks() -> list[
BeforeLLMCallHookType | BeforeLLMCallHookCallable
]:
"""Get all registered global before_llm_call hooks.
Returns:
@@ -226,7 +233,7 @@ def get_before_llm_call_hooks() -> list[BeforeLLMCallHookType]:
return _before_llm_call_hooks.copy()
def get_after_llm_call_hooks() -> list[AfterLLMCallHookType]:
def get_after_llm_call_hooks() -> list[AfterLLMCallHookType | AfterLLMCallHookCallable]:
"""Get all registered global after_llm_call hooks.
Returns:
@@ -236,7 +243,7 @@ def get_after_llm_call_hooks() -> list[AfterLLMCallHookType]:
def unregister_before_llm_call_hook(
hook: BeforeLLMCallHookType,
hook: BeforeLLMCallHookType | BeforeLLMCallHookCallable,
) -> bool:
"""Unregister a specific global before_llm_call hook.
@@ -262,7 +269,7 @@ def unregister_before_llm_call_hook(
def unregister_after_llm_call_hook(
hook: AfterLLMCallHookType,
hook: AfterLLMCallHookType | AfterLLMCallHookCallable,
) -> bool:
"""Unregister a specific global after_llm_call hook.

View File

@@ -3,7 +3,12 @@ from __future__ import annotations
from typing import TYPE_CHECKING, Any
from crewai.events.event_listener import event_listener
from crewai.hooks.types import AfterToolCallHookType, BeforeToolCallHookType
from crewai.hooks.types import (
AfterToolCallHookCallable,
AfterToolCallHookType,
BeforeToolCallHookCallable,
BeforeToolCallHookType,
)
from crewai.utilities.printer import Printer
@@ -112,12 +117,12 @@ class ToolCallHookContext:
# Global hook registries
_before_tool_call_hooks: list[BeforeToolCallHookType] = []
_after_tool_call_hooks: list[AfterToolCallHookType] = []
_before_tool_call_hooks: list[BeforeToolCallHookType | BeforeToolCallHookCallable] = []
_after_tool_call_hooks: list[AfterToolCallHookType | AfterToolCallHookCallable] = []
def register_before_tool_call_hook(
hook: BeforeToolCallHookType,
hook: BeforeToolCallHookType | BeforeToolCallHookCallable,
) -> None:
"""Register a global before_tool_call hook.
@@ -154,7 +159,7 @@ def register_before_tool_call_hook(
def register_after_tool_call_hook(
hook: AfterToolCallHookType,
hook: AfterToolCallHookType | AfterToolCallHookCallable,
) -> None:
"""Register a global after_tool_call hook.
@@ -184,7 +189,9 @@ def register_after_tool_call_hook(
_after_tool_call_hooks.append(hook)
def get_before_tool_call_hooks() -> list[BeforeToolCallHookType]:
def get_before_tool_call_hooks() -> list[
BeforeToolCallHookType | BeforeToolCallHookCallable
]:
"""Get all registered global before_tool_call hooks.
Returns:
@@ -193,7 +200,9 @@ def get_before_tool_call_hooks() -> list[BeforeToolCallHookType]:
return _before_tool_call_hooks.copy()
def get_after_tool_call_hooks() -> list[AfterToolCallHookType]:
def get_after_tool_call_hooks() -> list[
AfterToolCallHookType | AfterToolCallHookCallable
]:
"""Get all registered global after_tool_call hooks.
Returns:
@@ -203,7 +212,7 @@ def get_after_tool_call_hooks() -> list[AfterToolCallHookType]:
def unregister_before_tool_call_hook(
hook: BeforeToolCallHookType,
hook: BeforeToolCallHookType | BeforeToolCallHookCallable,
) -> bool:
"""Unregister a specific global before_tool_call hook.
@@ -229,7 +238,7 @@ def unregister_before_tool_call_hook(
def unregister_after_tool_call_hook(
hook: AfterToolCallHookType,
hook: AfterToolCallHookType | AfterToolCallHookCallable,
) -> bool:
"""Unregister a specific global after_tool_call hook.

View File

@@ -27,6 +27,8 @@ if TYPE_CHECKING:
from crewai import Agent, Task
from crewai.agents.cache.cache_handler import CacheHandler
from crewai.crews.crew_output import CrewOutput
from crewai.hooks.llm_hooks import LLMCallHookContext
from crewai.hooks.tool_hooks import ToolCallHookContext
from crewai.project.wrappers import (
CrewInstance,
OutputJsonClass,
@@ -34,6 +36,8 @@ if TYPE_CHECKING:
)
from crewai.tasks.task_output import TaskOutput
_post_initialize_crew_hooks: list[Callable[[Any], None]] = []
class AgentConfig(TypedDict, total=False):
"""Type definition for agent configuration dictionary.
@@ -266,6 +270,9 @@ class CrewBaseMeta(type):
instance.map_all_agent_variables()
instance.map_all_task_variables()
for hook in _post_initialize_crew_hooks:
hook(instance)
original_methods = {
name: method
for name, method in cls.__dict__.items()
@@ -485,47 +492,61 @@ def _register_crew_hooks(instance: CrewInstance, cls: type) -> None:
if has_agent_filter:
agents_filter = hook_method._filter_agents
def make_filtered_before_llm(bound_fn, agents_list):
def filtered(context):
def make_filtered_before_llm(
bound_fn: Callable[[LLMCallHookContext], bool | None],
agents_list: list[str],
) -> Callable[[LLMCallHookContext], bool | None]:
def filtered(context: LLMCallHookContext) -> bool | None:
if context.agent and context.agent.role not in agents_list:
return None
return bound_fn(context)
return filtered
final_hook = make_filtered_before_llm(bound_hook, agents_filter)
before_llm_hook = make_filtered_before_llm(bound_hook, agents_filter)
else:
final_hook = bound_hook
before_llm_hook = bound_hook
register_before_llm_call_hook(final_hook)
instance._registered_hook_functions.append(("before_llm_call", final_hook))
register_before_llm_call_hook(before_llm_hook)
instance._registered_hook_functions.append(
("before_llm_call", before_llm_hook)
)
if hasattr(hook_method, "is_after_llm_call_hook"):
if has_agent_filter:
agents_filter = hook_method._filter_agents
def make_filtered_after_llm(bound_fn, agents_list):
def filtered(context):
def make_filtered_after_llm(
bound_fn: Callable[[LLMCallHookContext], str | None],
agents_list: list[str],
) -> Callable[[LLMCallHookContext], str | None]:
def filtered(context: LLMCallHookContext) -> str | None:
if context.agent and context.agent.role not in agents_list:
return None
return bound_fn(context)
return filtered
final_hook = make_filtered_after_llm(bound_hook, agents_filter)
after_llm_hook = make_filtered_after_llm(bound_hook, agents_filter)
else:
final_hook = bound_hook
after_llm_hook = bound_hook
register_after_llm_call_hook(final_hook)
instance._registered_hook_functions.append(("after_llm_call", final_hook))
register_after_llm_call_hook(after_llm_hook)
instance._registered_hook_functions.append(
("after_llm_call", after_llm_hook)
)
if hasattr(hook_method, "is_before_tool_call_hook"):
if has_tool_filter or has_agent_filter:
tools_filter = getattr(hook_method, "_filter_tools", None)
agents_filter = getattr(hook_method, "_filter_agents", None)
def make_filtered_before_tool(bound_fn, tools_list, agents_list):
def filtered(context):
def make_filtered_before_tool(
bound_fn: Callable[[ToolCallHookContext], bool | None],
tools_list: list[str] | None,
agents_list: list[str] | None,
) -> Callable[[ToolCallHookContext], bool | None]:
def filtered(context: ToolCallHookContext) -> bool | None:
if tools_list and context.tool_name not in tools_list:
return None
if (
@@ -538,22 +559,28 @@ def _register_crew_hooks(instance: CrewInstance, cls: type) -> None:
return filtered
final_hook = make_filtered_before_tool(
before_tool_hook = make_filtered_before_tool(
bound_hook, tools_filter, agents_filter
)
else:
final_hook = bound_hook
before_tool_hook = bound_hook
register_before_tool_call_hook(final_hook)
instance._registered_hook_functions.append(("before_tool_call", final_hook))
register_before_tool_call_hook(before_tool_hook)
instance._registered_hook_functions.append(
("before_tool_call", before_tool_hook)
)
if hasattr(hook_method, "is_after_tool_call_hook"):
if has_tool_filter or has_agent_filter:
tools_filter = getattr(hook_method, "_filter_tools", None)
agents_filter = getattr(hook_method, "_filter_agents", None)
def make_filtered_after_tool(bound_fn, tools_list, agents_list):
def filtered(context):
def make_filtered_after_tool(
bound_fn: Callable[[ToolCallHookContext], str | None],
tools_list: list[str] | None,
agents_list: list[str] | None,
) -> Callable[[ToolCallHookContext], str | None]:
def filtered(context: ToolCallHookContext) -> str | None:
if tools_list and context.tool_name not in tools_list:
return None
if (
@@ -566,14 +593,16 @@ def _register_crew_hooks(instance: CrewInstance, cls: type) -> None:
return filtered
final_hook = make_filtered_after_tool(
after_tool_hook = make_filtered_after_tool(
bound_hook, tools_filter, agents_filter
)
else:
final_hook = bound_hook
after_tool_hook = bound_hook
register_after_tool_call_hook(final_hook)
instance._registered_hook_functions.append(("after_tool_call", final_hook))
register_after_tool_call_hook(after_tool_hook)
instance._registered_hook_functions.append(
("after_tool_call", after_tool_hook)
)
instance._hooks_being_registered = False

View File

@@ -72,6 +72,8 @@ class CrewInstance(Protocol):
__crew_metadata__: CrewMetadata
_mcp_server_adapter: Any
_all_methods: dict[str, Callable[..., Any]]
_registered_hook_functions: list[tuple[str, Callable[..., Any]]]
_hooks_being_registered: bool
agents: list[Agent]
tasks: list[Task]
base_directory: Path