mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-04-11 05:22:41 +00:00
chore: improve hook typing and registration
Allow hook registration to accept both typed hook types and plain callables by importing and using After*/Before*CallHookCallable types; add explicit LLMCallHookContext and ToolCallHookContext typing in crew_base. Introduce a post-initialize crew hook list and invoke hooks after Crew instance initialization. Refactor filtered hook factory functions to include precise typing and clearer local names (before_llm_hook/after_llm_hook/before_tool_hook/after_tool_hook) and register those with the instance. Update CrewInstance protocol to include _registered_hook_functions and _hooks_being_registered fields.
This commit is contained in:
@@ -3,7 +3,12 @@ from __future__ import annotations
|
||||
from typing import TYPE_CHECKING, Any, cast
|
||||
|
||||
from crewai.events.event_listener import event_listener
|
||||
from crewai.hooks.types import AfterLLMCallHookType, BeforeLLMCallHookType
|
||||
from crewai.hooks.types import (
|
||||
AfterLLMCallHookCallable,
|
||||
AfterLLMCallHookType,
|
||||
BeforeLLMCallHookCallable,
|
||||
BeforeLLMCallHookType,
|
||||
)
|
||||
from crewai.utilities.printer import Printer
|
||||
|
||||
|
||||
@@ -149,12 +154,12 @@ class LLMCallHookContext:
|
||||
event_listener.formatter.resume_live_updates()
|
||||
|
||||
|
||||
_before_llm_call_hooks: list[BeforeLLMCallHookType] = []
|
||||
_after_llm_call_hooks: list[AfterLLMCallHookType] = []
|
||||
_before_llm_call_hooks: list[BeforeLLMCallHookType | BeforeLLMCallHookCallable] = []
|
||||
_after_llm_call_hooks: list[AfterLLMCallHookType | AfterLLMCallHookCallable] = []
|
||||
|
||||
|
||||
def register_before_llm_call_hook(
|
||||
hook: BeforeLLMCallHookType,
|
||||
hook: BeforeLLMCallHookType | BeforeLLMCallHookCallable,
|
||||
) -> None:
|
||||
"""Register a global before_llm_call hook.
|
||||
|
||||
@@ -190,7 +195,7 @@ def register_before_llm_call_hook(
|
||||
|
||||
|
||||
def register_after_llm_call_hook(
|
||||
hook: AfterLLMCallHookType,
|
||||
hook: AfterLLMCallHookType | AfterLLMCallHookCallable,
|
||||
) -> None:
|
||||
"""Register a global after_llm_call hook.
|
||||
|
||||
@@ -217,7 +222,9 @@ def register_after_llm_call_hook(
|
||||
_after_llm_call_hooks.append(hook)
|
||||
|
||||
|
||||
def get_before_llm_call_hooks() -> list[BeforeLLMCallHookType]:
|
||||
def get_before_llm_call_hooks() -> list[
|
||||
BeforeLLMCallHookType | BeforeLLMCallHookCallable
|
||||
]:
|
||||
"""Get all registered global before_llm_call hooks.
|
||||
|
||||
Returns:
|
||||
@@ -226,7 +233,7 @@ def get_before_llm_call_hooks() -> list[BeforeLLMCallHookType]:
|
||||
return _before_llm_call_hooks.copy()
|
||||
|
||||
|
||||
def get_after_llm_call_hooks() -> list[AfterLLMCallHookType]:
|
||||
def get_after_llm_call_hooks() -> list[AfterLLMCallHookType | AfterLLMCallHookCallable]:
|
||||
"""Get all registered global after_llm_call hooks.
|
||||
|
||||
Returns:
|
||||
@@ -236,7 +243,7 @@ def get_after_llm_call_hooks() -> list[AfterLLMCallHookType]:
|
||||
|
||||
|
||||
def unregister_before_llm_call_hook(
|
||||
hook: BeforeLLMCallHookType,
|
||||
hook: BeforeLLMCallHookType | BeforeLLMCallHookCallable,
|
||||
) -> bool:
|
||||
"""Unregister a specific global before_llm_call hook.
|
||||
|
||||
@@ -262,7 +269,7 @@ def unregister_before_llm_call_hook(
|
||||
|
||||
|
||||
def unregister_after_llm_call_hook(
|
||||
hook: AfterLLMCallHookType,
|
||||
hook: AfterLLMCallHookType | AfterLLMCallHookCallable,
|
||||
) -> bool:
|
||||
"""Unregister a specific global after_llm_call hook.
|
||||
|
||||
|
||||
@@ -3,7 +3,12 @@ from __future__ import annotations
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from crewai.events.event_listener import event_listener
|
||||
from crewai.hooks.types import AfterToolCallHookType, BeforeToolCallHookType
|
||||
from crewai.hooks.types import (
|
||||
AfterToolCallHookCallable,
|
||||
AfterToolCallHookType,
|
||||
BeforeToolCallHookCallable,
|
||||
BeforeToolCallHookType,
|
||||
)
|
||||
from crewai.utilities.printer import Printer
|
||||
|
||||
|
||||
@@ -112,12 +117,12 @@ class ToolCallHookContext:
|
||||
|
||||
|
||||
# Global hook registries
|
||||
_before_tool_call_hooks: list[BeforeToolCallHookType] = []
|
||||
_after_tool_call_hooks: list[AfterToolCallHookType] = []
|
||||
_before_tool_call_hooks: list[BeforeToolCallHookType | BeforeToolCallHookCallable] = []
|
||||
_after_tool_call_hooks: list[AfterToolCallHookType | AfterToolCallHookCallable] = []
|
||||
|
||||
|
||||
def register_before_tool_call_hook(
|
||||
hook: BeforeToolCallHookType,
|
||||
hook: BeforeToolCallHookType | BeforeToolCallHookCallable,
|
||||
) -> None:
|
||||
"""Register a global before_tool_call hook.
|
||||
|
||||
@@ -154,7 +159,7 @@ def register_before_tool_call_hook(
|
||||
|
||||
|
||||
def register_after_tool_call_hook(
|
||||
hook: AfterToolCallHookType,
|
||||
hook: AfterToolCallHookType | AfterToolCallHookCallable,
|
||||
) -> None:
|
||||
"""Register a global after_tool_call hook.
|
||||
|
||||
@@ -184,7 +189,9 @@ def register_after_tool_call_hook(
|
||||
_after_tool_call_hooks.append(hook)
|
||||
|
||||
|
||||
def get_before_tool_call_hooks() -> list[BeforeToolCallHookType]:
|
||||
def get_before_tool_call_hooks() -> list[
|
||||
BeforeToolCallHookType | BeforeToolCallHookCallable
|
||||
]:
|
||||
"""Get all registered global before_tool_call hooks.
|
||||
|
||||
Returns:
|
||||
@@ -193,7 +200,9 @@ def get_before_tool_call_hooks() -> list[BeforeToolCallHookType]:
|
||||
return _before_tool_call_hooks.copy()
|
||||
|
||||
|
||||
def get_after_tool_call_hooks() -> list[AfterToolCallHookType]:
|
||||
def get_after_tool_call_hooks() -> list[
|
||||
AfterToolCallHookType | AfterToolCallHookCallable
|
||||
]:
|
||||
"""Get all registered global after_tool_call hooks.
|
||||
|
||||
Returns:
|
||||
@@ -203,7 +212,7 @@ def get_after_tool_call_hooks() -> list[AfterToolCallHookType]:
|
||||
|
||||
|
||||
def unregister_before_tool_call_hook(
|
||||
hook: BeforeToolCallHookType,
|
||||
hook: BeforeToolCallHookType | BeforeToolCallHookCallable,
|
||||
) -> bool:
|
||||
"""Unregister a specific global before_tool_call hook.
|
||||
|
||||
@@ -229,7 +238,7 @@ def unregister_before_tool_call_hook(
|
||||
|
||||
|
||||
def unregister_after_tool_call_hook(
|
||||
hook: AfterToolCallHookType,
|
||||
hook: AfterToolCallHookType | AfterToolCallHookCallable,
|
||||
) -> bool:
|
||||
"""Unregister a specific global after_tool_call hook.
|
||||
|
||||
|
||||
@@ -27,6 +27,8 @@ if TYPE_CHECKING:
|
||||
from crewai import Agent, Task
|
||||
from crewai.agents.cache.cache_handler import CacheHandler
|
||||
from crewai.crews.crew_output import CrewOutput
|
||||
from crewai.hooks.llm_hooks import LLMCallHookContext
|
||||
from crewai.hooks.tool_hooks import ToolCallHookContext
|
||||
from crewai.project.wrappers import (
|
||||
CrewInstance,
|
||||
OutputJsonClass,
|
||||
@@ -34,6 +36,8 @@ if TYPE_CHECKING:
|
||||
)
|
||||
from crewai.tasks.task_output import TaskOutput
|
||||
|
||||
_post_initialize_crew_hooks: list[Callable[[Any], None]] = []
|
||||
|
||||
|
||||
class AgentConfig(TypedDict, total=False):
|
||||
"""Type definition for agent configuration dictionary.
|
||||
@@ -266,6 +270,9 @@ class CrewBaseMeta(type):
|
||||
instance.map_all_agent_variables()
|
||||
instance.map_all_task_variables()
|
||||
|
||||
for hook in _post_initialize_crew_hooks:
|
||||
hook(instance)
|
||||
|
||||
original_methods = {
|
||||
name: method
|
||||
for name, method in cls.__dict__.items()
|
||||
@@ -485,47 +492,61 @@ def _register_crew_hooks(instance: CrewInstance, cls: type) -> None:
|
||||
if has_agent_filter:
|
||||
agents_filter = hook_method._filter_agents
|
||||
|
||||
def make_filtered_before_llm(bound_fn, agents_list):
|
||||
def filtered(context):
|
||||
def make_filtered_before_llm(
|
||||
bound_fn: Callable[[LLMCallHookContext], bool | None],
|
||||
agents_list: list[str],
|
||||
) -> Callable[[LLMCallHookContext], bool | None]:
|
||||
def filtered(context: LLMCallHookContext) -> bool | None:
|
||||
if context.agent and context.agent.role not in agents_list:
|
||||
return None
|
||||
return bound_fn(context)
|
||||
|
||||
return filtered
|
||||
|
||||
final_hook = make_filtered_before_llm(bound_hook, agents_filter)
|
||||
before_llm_hook = make_filtered_before_llm(bound_hook, agents_filter)
|
||||
else:
|
||||
final_hook = bound_hook
|
||||
before_llm_hook = bound_hook
|
||||
|
||||
register_before_llm_call_hook(final_hook)
|
||||
instance._registered_hook_functions.append(("before_llm_call", final_hook))
|
||||
register_before_llm_call_hook(before_llm_hook)
|
||||
instance._registered_hook_functions.append(
|
||||
("before_llm_call", before_llm_hook)
|
||||
)
|
||||
|
||||
if hasattr(hook_method, "is_after_llm_call_hook"):
|
||||
if has_agent_filter:
|
||||
agents_filter = hook_method._filter_agents
|
||||
|
||||
def make_filtered_after_llm(bound_fn, agents_list):
|
||||
def filtered(context):
|
||||
def make_filtered_after_llm(
|
||||
bound_fn: Callable[[LLMCallHookContext], str | None],
|
||||
agents_list: list[str],
|
||||
) -> Callable[[LLMCallHookContext], str | None]:
|
||||
def filtered(context: LLMCallHookContext) -> str | None:
|
||||
if context.agent and context.agent.role not in agents_list:
|
||||
return None
|
||||
return bound_fn(context)
|
||||
|
||||
return filtered
|
||||
|
||||
final_hook = make_filtered_after_llm(bound_hook, agents_filter)
|
||||
after_llm_hook = make_filtered_after_llm(bound_hook, agents_filter)
|
||||
else:
|
||||
final_hook = bound_hook
|
||||
after_llm_hook = bound_hook
|
||||
|
||||
register_after_llm_call_hook(final_hook)
|
||||
instance._registered_hook_functions.append(("after_llm_call", final_hook))
|
||||
register_after_llm_call_hook(after_llm_hook)
|
||||
instance._registered_hook_functions.append(
|
||||
("after_llm_call", after_llm_hook)
|
||||
)
|
||||
|
||||
if hasattr(hook_method, "is_before_tool_call_hook"):
|
||||
if has_tool_filter or has_agent_filter:
|
||||
tools_filter = getattr(hook_method, "_filter_tools", None)
|
||||
agents_filter = getattr(hook_method, "_filter_agents", None)
|
||||
|
||||
def make_filtered_before_tool(bound_fn, tools_list, agents_list):
|
||||
def filtered(context):
|
||||
def make_filtered_before_tool(
|
||||
bound_fn: Callable[[ToolCallHookContext], bool | None],
|
||||
tools_list: list[str] | None,
|
||||
agents_list: list[str] | None,
|
||||
) -> Callable[[ToolCallHookContext], bool | None]:
|
||||
def filtered(context: ToolCallHookContext) -> bool | None:
|
||||
if tools_list and context.tool_name not in tools_list:
|
||||
return None
|
||||
if (
|
||||
@@ -538,22 +559,28 @@ def _register_crew_hooks(instance: CrewInstance, cls: type) -> None:
|
||||
|
||||
return filtered
|
||||
|
||||
final_hook = make_filtered_before_tool(
|
||||
before_tool_hook = make_filtered_before_tool(
|
||||
bound_hook, tools_filter, agents_filter
|
||||
)
|
||||
else:
|
||||
final_hook = bound_hook
|
||||
before_tool_hook = bound_hook
|
||||
|
||||
register_before_tool_call_hook(final_hook)
|
||||
instance._registered_hook_functions.append(("before_tool_call", final_hook))
|
||||
register_before_tool_call_hook(before_tool_hook)
|
||||
instance._registered_hook_functions.append(
|
||||
("before_tool_call", before_tool_hook)
|
||||
)
|
||||
|
||||
if hasattr(hook_method, "is_after_tool_call_hook"):
|
||||
if has_tool_filter or has_agent_filter:
|
||||
tools_filter = getattr(hook_method, "_filter_tools", None)
|
||||
agents_filter = getattr(hook_method, "_filter_agents", None)
|
||||
|
||||
def make_filtered_after_tool(bound_fn, tools_list, agents_list):
|
||||
def filtered(context):
|
||||
def make_filtered_after_tool(
|
||||
bound_fn: Callable[[ToolCallHookContext], str | None],
|
||||
tools_list: list[str] | None,
|
||||
agents_list: list[str] | None,
|
||||
) -> Callable[[ToolCallHookContext], str | None]:
|
||||
def filtered(context: ToolCallHookContext) -> str | None:
|
||||
if tools_list and context.tool_name not in tools_list:
|
||||
return None
|
||||
if (
|
||||
@@ -566,14 +593,16 @@ def _register_crew_hooks(instance: CrewInstance, cls: type) -> None:
|
||||
|
||||
return filtered
|
||||
|
||||
final_hook = make_filtered_after_tool(
|
||||
after_tool_hook = make_filtered_after_tool(
|
||||
bound_hook, tools_filter, agents_filter
|
||||
)
|
||||
else:
|
||||
final_hook = bound_hook
|
||||
after_tool_hook = bound_hook
|
||||
|
||||
register_after_tool_call_hook(final_hook)
|
||||
instance._registered_hook_functions.append(("after_tool_call", final_hook))
|
||||
register_after_tool_call_hook(after_tool_hook)
|
||||
instance._registered_hook_functions.append(
|
||||
("after_tool_call", after_tool_hook)
|
||||
)
|
||||
|
||||
instance._hooks_being_registered = False
|
||||
|
||||
|
||||
@@ -72,6 +72,8 @@ class CrewInstance(Protocol):
|
||||
__crew_metadata__: CrewMetadata
|
||||
_mcp_server_adapter: Any
|
||||
_all_methods: dict[str, Callable[..., Any]]
|
||||
_registered_hook_functions: list[tuple[str, Callable[..., Any]]]
|
||||
_hooks_being_registered: bool
|
||||
agents: list[Agent]
|
||||
tasks: list[Task]
|
||||
base_directory: Path
|
||||
|
||||
Reference in New Issue
Block a user