diff --git a/lib/crewai/src/crewai/hooks/llm_hooks.py b/lib/crewai/src/crewai/hooks/llm_hooks.py index 2f5462fe0..3a6abbedf 100644 --- a/lib/crewai/src/crewai/hooks/llm_hooks.py +++ b/lib/crewai/src/crewai/hooks/llm_hooks.py @@ -3,7 +3,12 @@ from __future__ import annotations from typing import TYPE_CHECKING, Any, cast from crewai.events.event_listener import event_listener -from crewai.hooks.types import AfterLLMCallHookType, BeforeLLMCallHookType +from crewai.hooks.types import ( + AfterLLMCallHookCallable, + AfterLLMCallHookType, + BeforeLLMCallHookCallable, + BeforeLLMCallHookType, +) from crewai.utilities.printer import Printer @@ -149,12 +154,12 @@ class LLMCallHookContext: event_listener.formatter.resume_live_updates() -_before_llm_call_hooks: list[BeforeLLMCallHookType] = [] -_after_llm_call_hooks: list[AfterLLMCallHookType] = [] +_before_llm_call_hooks: list[BeforeLLMCallHookType | BeforeLLMCallHookCallable] = [] +_after_llm_call_hooks: list[AfterLLMCallHookType | AfterLLMCallHookCallable] = [] def register_before_llm_call_hook( - hook: BeforeLLMCallHookType, + hook: BeforeLLMCallHookType | BeforeLLMCallHookCallable, ) -> None: """Register a global before_llm_call hook. @@ -190,7 +195,7 @@ def register_before_llm_call_hook( def register_after_llm_call_hook( - hook: AfterLLMCallHookType, + hook: AfterLLMCallHookType | AfterLLMCallHookCallable, ) -> None: """Register a global after_llm_call hook. @@ -217,7 +222,9 @@ def register_after_llm_call_hook( _after_llm_call_hooks.append(hook) -def get_before_llm_call_hooks() -> list[BeforeLLMCallHookType]: +def get_before_llm_call_hooks() -> list[ + BeforeLLMCallHookType | BeforeLLMCallHookCallable +]: """Get all registered global before_llm_call hooks. Returns: @@ -226,7 +233,7 @@ def get_before_llm_call_hooks() -> list[BeforeLLMCallHookType]: return _before_llm_call_hooks.copy() -def get_after_llm_call_hooks() -> list[AfterLLMCallHookType]: +def get_after_llm_call_hooks() -> list[AfterLLMCallHookType | AfterLLMCallHookCallable]: """Get all registered global after_llm_call hooks. Returns: @@ -236,7 +243,7 @@ def get_after_llm_call_hooks() -> list[AfterLLMCallHookType]: def unregister_before_llm_call_hook( - hook: BeforeLLMCallHookType, + hook: BeforeLLMCallHookType | BeforeLLMCallHookCallable, ) -> bool: """Unregister a specific global before_llm_call hook. @@ -262,7 +269,7 @@ def unregister_before_llm_call_hook( def unregister_after_llm_call_hook( - hook: AfterLLMCallHookType, + hook: AfterLLMCallHookType | AfterLLMCallHookCallable, ) -> bool: """Unregister a specific global after_llm_call hook. diff --git a/lib/crewai/src/crewai/hooks/tool_hooks.py b/lib/crewai/src/crewai/hooks/tool_hooks.py index 6ee0ab033..ac7f5c362 100644 --- a/lib/crewai/src/crewai/hooks/tool_hooks.py +++ b/lib/crewai/src/crewai/hooks/tool_hooks.py @@ -3,7 +3,12 @@ from __future__ import annotations from typing import TYPE_CHECKING, Any from crewai.events.event_listener import event_listener -from crewai.hooks.types import AfterToolCallHookType, BeforeToolCallHookType +from crewai.hooks.types import ( + AfterToolCallHookCallable, + AfterToolCallHookType, + BeforeToolCallHookCallable, + BeforeToolCallHookType, +) from crewai.utilities.printer import Printer @@ -112,12 +117,12 @@ class ToolCallHookContext: # Global hook registries -_before_tool_call_hooks: list[BeforeToolCallHookType] = [] -_after_tool_call_hooks: list[AfterToolCallHookType] = [] +_before_tool_call_hooks: list[BeforeToolCallHookType | BeforeToolCallHookCallable] = [] +_after_tool_call_hooks: list[AfterToolCallHookType | AfterToolCallHookCallable] = [] def register_before_tool_call_hook( - hook: BeforeToolCallHookType, + hook: BeforeToolCallHookType | BeforeToolCallHookCallable, ) -> None: """Register a global before_tool_call hook. @@ -154,7 +159,7 @@ def register_before_tool_call_hook( def register_after_tool_call_hook( - hook: AfterToolCallHookType, + hook: AfterToolCallHookType | AfterToolCallHookCallable, ) -> None: """Register a global after_tool_call hook. @@ -184,7 +189,9 @@ def register_after_tool_call_hook( _after_tool_call_hooks.append(hook) -def get_before_tool_call_hooks() -> list[BeforeToolCallHookType]: +def get_before_tool_call_hooks() -> list[ + BeforeToolCallHookType | BeforeToolCallHookCallable +]: """Get all registered global before_tool_call hooks. Returns: @@ -193,7 +200,9 @@ def get_before_tool_call_hooks() -> list[BeforeToolCallHookType]: return _before_tool_call_hooks.copy() -def get_after_tool_call_hooks() -> list[AfterToolCallHookType]: +def get_after_tool_call_hooks() -> list[ + AfterToolCallHookType | AfterToolCallHookCallable +]: """Get all registered global after_tool_call hooks. Returns: @@ -203,7 +212,7 @@ def get_after_tool_call_hooks() -> list[AfterToolCallHookType]: def unregister_before_tool_call_hook( - hook: BeforeToolCallHookType, + hook: BeforeToolCallHookType | BeforeToolCallHookCallable, ) -> bool: """Unregister a specific global before_tool_call hook. @@ -229,7 +238,7 @@ def unregister_before_tool_call_hook( def unregister_after_tool_call_hook( - hook: AfterToolCallHookType, + hook: AfterToolCallHookType | AfterToolCallHookCallable, ) -> bool: """Unregister a specific global after_tool_call hook. diff --git a/lib/crewai/src/crewai/project/crew_base.py b/lib/crewai/src/crewai/project/crew_base.py index 202d98898..323450b13 100644 --- a/lib/crewai/src/crewai/project/crew_base.py +++ b/lib/crewai/src/crewai/project/crew_base.py @@ -27,6 +27,8 @@ if TYPE_CHECKING: from crewai import Agent, Task from crewai.agents.cache.cache_handler import CacheHandler from crewai.crews.crew_output import CrewOutput + from crewai.hooks.llm_hooks import LLMCallHookContext + from crewai.hooks.tool_hooks import ToolCallHookContext from crewai.project.wrappers import ( CrewInstance, OutputJsonClass, @@ -34,6 +36,8 @@ if TYPE_CHECKING: ) from crewai.tasks.task_output import TaskOutput +_post_initialize_crew_hooks: list[Callable[[Any], None]] = [] + class AgentConfig(TypedDict, total=False): """Type definition for agent configuration dictionary. @@ -266,6 +270,9 @@ class CrewBaseMeta(type): instance.map_all_agent_variables() instance.map_all_task_variables() + for hook in _post_initialize_crew_hooks: + hook(instance) + original_methods = { name: method for name, method in cls.__dict__.items() @@ -485,47 +492,61 @@ def _register_crew_hooks(instance: CrewInstance, cls: type) -> None: if has_agent_filter: agents_filter = hook_method._filter_agents - def make_filtered_before_llm(bound_fn, agents_list): - def filtered(context): + def make_filtered_before_llm( + bound_fn: Callable[[LLMCallHookContext], bool | None], + agents_list: list[str], + ) -> Callable[[LLMCallHookContext], bool | None]: + def filtered(context: LLMCallHookContext) -> bool | None: if context.agent and context.agent.role not in agents_list: return None return bound_fn(context) return filtered - final_hook = make_filtered_before_llm(bound_hook, agents_filter) + before_llm_hook = make_filtered_before_llm(bound_hook, agents_filter) else: - final_hook = bound_hook + before_llm_hook = bound_hook - register_before_llm_call_hook(final_hook) - instance._registered_hook_functions.append(("before_llm_call", final_hook)) + register_before_llm_call_hook(before_llm_hook) + instance._registered_hook_functions.append( + ("before_llm_call", before_llm_hook) + ) if hasattr(hook_method, "is_after_llm_call_hook"): if has_agent_filter: agents_filter = hook_method._filter_agents - def make_filtered_after_llm(bound_fn, agents_list): - def filtered(context): + def make_filtered_after_llm( + bound_fn: Callable[[LLMCallHookContext], str | None], + agents_list: list[str], + ) -> Callable[[LLMCallHookContext], str | None]: + def filtered(context: LLMCallHookContext) -> str | None: if context.agent and context.agent.role not in agents_list: return None return bound_fn(context) return filtered - final_hook = make_filtered_after_llm(bound_hook, agents_filter) + after_llm_hook = make_filtered_after_llm(bound_hook, agents_filter) else: - final_hook = bound_hook + after_llm_hook = bound_hook - register_after_llm_call_hook(final_hook) - instance._registered_hook_functions.append(("after_llm_call", final_hook)) + register_after_llm_call_hook(after_llm_hook) + instance._registered_hook_functions.append( + ("after_llm_call", after_llm_hook) + ) if hasattr(hook_method, "is_before_tool_call_hook"): if has_tool_filter or has_agent_filter: tools_filter = getattr(hook_method, "_filter_tools", None) agents_filter = getattr(hook_method, "_filter_agents", None) - def make_filtered_before_tool(bound_fn, tools_list, agents_list): - def filtered(context): + def make_filtered_before_tool( + bound_fn: Callable[[ToolCallHookContext], bool | None], + tools_list: list[str] | None, + agents_list: list[str] | None, + ) -> Callable[[ToolCallHookContext], bool | None]: + def filtered(context: ToolCallHookContext) -> bool | None: if tools_list and context.tool_name not in tools_list: return None if ( @@ -538,22 +559,28 @@ def _register_crew_hooks(instance: CrewInstance, cls: type) -> None: return filtered - final_hook = make_filtered_before_tool( + before_tool_hook = make_filtered_before_tool( bound_hook, tools_filter, agents_filter ) else: - final_hook = bound_hook + before_tool_hook = bound_hook - register_before_tool_call_hook(final_hook) - instance._registered_hook_functions.append(("before_tool_call", final_hook)) + register_before_tool_call_hook(before_tool_hook) + instance._registered_hook_functions.append( + ("before_tool_call", before_tool_hook) + ) if hasattr(hook_method, "is_after_tool_call_hook"): if has_tool_filter or has_agent_filter: tools_filter = getattr(hook_method, "_filter_tools", None) agents_filter = getattr(hook_method, "_filter_agents", None) - def make_filtered_after_tool(bound_fn, tools_list, agents_list): - def filtered(context): + def make_filtered_after_tool( + bound_fn: Callable[[ToolCallHookContext], str | None], + tools_list: list[str] | None, + agents_list: list[str] | None, + ) -> Callable[[ToolCallHookContext], str | None]: + def filtered(context: ToolCallHookContext) -> str | None: if tools_list and context.tool_name not in tools_list: return None if ( @@ -566,14 +593,16 @@ def _register_crew_hooks(instance: CrewInstance, cls: type) -> None: return filtered - final_hook = make_filtered_after_tool( + after_tool_hook = make_filtered_after_tool( bound_hook, tools_filter, agents_filter ) else: - final_hook = bound_hook + after_tool_hook = bound_hook - register_after_tool_call_hook(final_hook) - instance._registered_hook_functions.append(("after_tool_call", final_hook)) + register_after_tool_call_hook(after_tool_hook) + instance._registered_hook_functions.append( + ("after_tool_call", after_tool_hook) + ) instance._hooks_being_registered = False diff --git a/lib/crewai/src/crewai/project/wrappers.py b/lib/crewai/src/crewai/project/wrappers.py index 28cd39525..3d570b6f0 100644 --- a/lib/crewai/src/crewai/project/wrappers.py +++ b/lib/crewai/src/crewai/project/wrappers.py @@ -72,6 +72,8 @@ class CrewInstance(Protocol): __crew_metadata__: CrewMetadata _mcp_server_adapter: Any _all_methods: dict[str, Callable[..., Any]] + _registered_hook_functions: list[tuple[str, Callable[..., Any]]] + _hooks_being_registered: bool agents: list[Agent] tasks: list[Task] base_directory: Path