diff --git a/src/crewai/project/annotations.py b/src/crewai/project/annotations.py index d7c636ccf..a060a2a84 100644 --- a/src/crewai/project/annotations.py +++ b/src/crewai/project/annotations.py @@ -1,87 +1,95 @@ +"""Decorators for defining crew components and their behaviors.""" + +from collections.abc import Callable from functools import wraps -from typing import Callable +from typing import Any, Concatenate, ParamSpec, TypeVar from crewai import Crew from crewai.project.utils import memoize -"""Decorators for defining crew components and their behaviors.""" +P = ParamSpec("P") +R = TypeVar("R") -def before_kickoff(func): +def before_kickoff(func: Callable[P, R]) -> Callable[P, R]: """Marks a method to execute before crew kickoff.""" - func.is_before_kickoff = True + func.is_before_kickoff = True # type: ignore return func -def after_kickoff(func): +def after_kickoff(func: Callable[P, R]) -> Callable[P, R]: """Marks a method to execute after crew kickoff.""" - func.is_after_kickoff = True + func.is_after_kickoff = True # type: ignore return func -def task(func): +def task(func: Callable[Concatenate[Any, P], R]) -> Callable[Concatenate[Any, P], R]: """Marks a method as a crew task.""" - func.is_task = True + func.is_task = True # type: ignore @wraps(func) - def wrapper(*args, **kwargs): - result = func(*args, **kwargs) - if not result.name: - result.name = func.__name__ + def wrapper(self: Any, *args: P.args, **kwargs: P.kwargs) -> R: + result = func(self, *args, **kwargs) + if not result.name: # type: ignore + result.name = func.__name__ # type: ignore return result return memoize(wrapper) -def agent(func): +def agent(func: Callable[Concatenate[Any, P], R]) -> Callable[Concatenate[Any, P], R]: """Marks a method as a crew agent.""" - func.is_agent = True - func = memoize(func) - return func + func.is_agent = True # type: ignore + return memoize(func) -def llm(func): +def llm(func: Callable[Concatenate[Any, P], R]) -> Callable[Concatenate[Any, P], R]: """Marks a method as an LLM provider.""" - func.is_llm = True - func = memoize(func) - return func + func.is_llm = True # type: ignore + return memoize(func) -def output_json(cls): +def output_json(cls: type[R]) -> type[R]: """Marks a class as JSON output format.""" - cls.is_output_json = True + cls.is_output_json = True # type: ignore return cls -def output_pydantic(cls): +def output_pydantic(cls: type[R]) -> type[R]: """Marks a class as Pydantic output format.""" - cls.is_output_pydantic = True + cls.is_output_pydantic = True # type: ignore return cls -def tool(func): +def tool(func: Callable[Concatenate[Any, P], R]) -> Callable[Concatenate[Any, P], R]: """Marks a method as a crew tool.""" - func.is_tool = True + func.is_tool = True # type: ignore return memoize(func) -def callback(func): +def callback( + func: Callable[Concatenate[Any, P], R], +) -> Callable[Concatenate[Any, P], R]: """Marks a method as a crew callback.""" - func.is_callback = True + func.is_callback = True # type: ignore return memoize(func) -def cache_handler(func): +def cache_handler( + func: Callable[Concatenate[Any, P], R], +) -> Callable[Concatenate[Any, P], R]: """Marks a method as a cache handler.""" - func.is_cache_handler = True + func.is_cache_handler = True # type: ignore return memoize(func) -def crew(func) -> Callable[..., Crew]: +def crew( + func: Callable[Concatenate[Any, P], Crew], +) -> Callable[Concatenate[Any, P], Crew]: """Marks a method as the main crew execution point.""" @wraps(func) - def wrapper(self, *args, **kwargs) -> Crew: + def wrapper(self: Any, *args: P.args, **kwargs: P.kwargs) -> Crew: instantiated_tasks = [] instantiated_agents = [] agent_roles = set() @@ -91,7 +99,7 @@ def crew(func) -> Callable[..., Crew]: agents = self._original_agents.items() # Instantiate tasks in order - for task_name, task_method in tasks: + for _task_name, task_method in tasks: task_instance = task_method(self) instantiated_tasks.append(task_instance) agent_instance = getattr(task_instance, "agent", None) @@ -100,7 +108,7 @@ def crew(func) -> Callable[..., Crew]: agent_roles.add(agent_instance.role) # Instantiate agents not included by tasks - for agent_name, agent_method in agents: + for _agent_name, agent_method in agents: agent_instance = agent_method(self) if agent_instance.role not in agent_roles: instantiated_agents.append(agent_instance) @@ -109,19 +117,23 @@ def crew(func) -> Callable[..., Crew]: self.agents = instantiated_agents self.tasks = instantiated_tasks - crew = func(self, *args, **kwargs) + crew_result = func(self, *args, **kwargs) - def callback_wrapper(callback, instance): - def wrapper(*args, **kwargs): - return callback(instance, *args, **kwargs) + def callback_wrapper(callback_func: Any, instance: Any) -> Callable[..., Any]: + def inner_wrapper(*cb_args: Any, **cb_kwargs: Any) -> Any: + return callback_func(instance, *cb_args, **cb_kwargs) - return wrapper + return inner_wrapper - for _, callback in self._before_kickoff.items(): - crew.before_kickoff_callbacks.append(callback_wrapper(callback, self)) - for _, callback in self._after_kickoff.items(): - crew.after_kickoff_callbacks.append(callback_wrapper(callback, self)) + for callback_func in self._before_kickoff.values(): + crew_result.before_kickoff_callbacks.append( + callback_wrapper(callback_func, self) + ) + for callback_func in self._after_kickoff.values(): + crew_result.after_kickoff_callbacks.append( + callback_wrapper(callback_func, self) + ) - return crew + return crew_result return memoize(wrapper) diff --git a/src/crewai/project/utils.py b/src/crewai/project/utils.py index e8876d941..3cb7c6ca8 100644 --- a/src/crewai/project/utils.py +++ b/src/crewai/project/utils.py @@ -1,11 +1,25 @@ +from collections.abc import Callable from functools import wraps +from typing import ParamSpec, TypeVar + +P = ParamSpec("P") +R = TypeVar("R") -def memoize(func): - cache = {} +def memoize(func: Callable[P, R]) -> Callable[P, R]: + """Decorator that caches function results based on arguments. + + Args: + func: The function to memoize. + + Returns: + The memoized function. + """ + cache: dict[tuple, R] = {} @wraps(func) - def memoized_func(*args, **kwargs): + def memoized_func(*args: P.args, **kwargs: P.kwargs) -> R: + """Memoized wrapper function.""" key = (args, tuple(kwargs.items())) if key not in cache: cache[key] = func(*args, **kwargs)