chore: modernize project module typing

This commit is contained in:
Greyson Lalonde
2025-09-07 01:16:55 -04:00
parent 1a96ed7b00
commit d0641a8084
2 changed files with 73 additions and 47 deletions

View File

@@ -1,87 +1,95 @@
"""Decorators for defining crew components and their behaviors."""
from collections.abc import Callable
from functools import wraps from functools import wraps
from typing import Callable from typing import Any, Concatenate, ParamSpec, TypeVar
from crewai import Crew from crewai import Crew
from crewai.project.utils import memoize from crewai.project.utils import memoize
"""Decorators for defining crew components and their behaviors.""" P = ParamSpec("P")
R = TypeVar("R")
def before_kickoff(func): def before_kickoff(func: Callable[P, R]) -> Callable[P, R]:
"""Marks a method to execute before crew kickoff.""" """Marks a method to execute before crew kickoff."""
func.is_before_kickoff = True func.is_before_kickoff = True # type: ignore
return func return func
def after_kickoff(func): def after_kickoff(func: Callable[P, R]) -> Callable[P, R]:
"""Marks a method to execute after crew kickoff.""" """Marks a method to execute after crew kickoff."""
func.is_after_kickoff = True func.is_after_kickoff = True # type: ignore
return func return func
def task(func): def task(func: Callable[Concatenate[Any, P], R]) -> Callable[Concatenate[Any, P], R]:
"""Marks a method as a crew task.""" """Marks a method as a crew task."""
func.is_task = True func.is_task = True # type: ignore
@wraps(func) @wraps(func)
def wrapper(*args, **kwargs): def wrapper(self: Any, *args: P.args, **kwargs: P.kwargs) -> R:
result = func(*args, **kwargs) result = func(self, *args, **kwargs)
if not result.name: if not result.name: # type: ignore
result.name = func.__name__ result.name = func.__name__ # type: ignore
return result return result
return memoize(wrapper) return memoize(wrapper)
def agent(func): def agent(func: Callable[Concatenate[Any, P], R]) -> Callable[Concatenate[Any, P], R]:
"""Marks a method as a crew agent.""" """Marks a method as a crew agent."""
func.is_agent = True func.is_agent = True # type: ignore
func = memoize(func) return memoize(func)
return func
def llm(func): def llm(func: Callable[Concatenate[Any, P], R]) -> Callable[Concatenate[Any, P], R]:
"""Marks a method as an LLM provider.""" """Marks a method as an LLM provider."""
func.is_llm = True func.is_llm = True # type: ignore
func = memoize(func) return memoize(func)
return func
def output_json(cls): def output_json(cls: type[R]) -> type[R]:
"""Marks a class as JSON output format.""" """Marks a class as JSON output format."""
cls.is_output_json = True cls.is_output_json = True # type: ignore
return cls return cls
def output_pydantic(cls): def output_pydantic(cls: type[R]) -> type[R]:
"""Marks a class as Pydantic output format.""" """Marks a class as Pydantic output format."""
cls.is_output_pydantic = True cls.is_output_pydantic = True # type: ignore
return cls return cls
def tool(func): def tool(func: Callable[Concatenate[Any, P], R]) -> Callable[Concatenate[Any, P], R]:
"""Marks a method as a crew tool.""" """Marks a method as a crew tool."""
func.is_tool = True func.is_tool = True # type: ignore
return memoize(func) return memoize(func)
def callback(func): def callback(
func: Callable[Concatenate[Any, P], R],
) -> Callable[Concatenate[Any, P], R]:
"""Marks a method as a crew callback.""" """Marks a method as a crew callback."""
func.is_callback = True func.is_callback = True # type: ignore
return memoize(func) return memoize(func)
def cache_handler(func): def cache_handler(
func: Callable[Concatenate[Any, P], R],
) -> Callable[Concatenate[Any, P], R]:
"""Marks a method as a cache handler.""" """Marks a method as a cache handler."""
func.is_cache_handler = True func.is_cache_handler = True # type: ignore
return memoize(func) return memoize(func)
def crew(func) -> Callable[..., Crew]: def crew(
func: Callable[Concatenate[Any, P], Crew],
) -> Callable[Concatenate[Any, P], Crew]:
"""Marks a method as the main crew execution point.""" """Marks a method as the main crew execution point."""
@wraps(func) @wraps(func)
def wrapper(self, *args, **kwargs) -> Crew: def wrapper(self: Any, *args: P.args, **kwargs: P.kwargs) -> Crew:
instantiated_tasks = [] instantiated_tasks = []
instantiated_agents = [] instantiated_agents = []
agent_roles = set() agent_roles = set()
@@ -91,7 +99,7 @@ def crew(func) -> Callable[..., Crew]:
agents = self._original_agents.items() agents = self._original_agents.items()
# Instantiate tasks in order # Instantiate tasks in order
for task_name, task_method in tasks: for _task_name, task_method in tasks:
task_instance = task_method(self) task_instance = task_method(self)
instantiated_tasks.append(task_instance) instantiated_tasks.append(task_instance)
agent_instance = getattr(task_instance, "agent", None) agent_instance = getattr(task_instance, "agent", None)
@@ -100,7 +108,7 @@ def crew(func) -> Callable[..., Crew]:
agent_roles.add(agent_instance.role) agent_roles.add(agent_instance.role)
# Instantiate agents not included by tasks # Instantiate agents not included by tasks
for agent_name, agent_method in agents: for _agent_name, agent_method in agents:
agent_instance = agent_method(self) agent_instance = agent_method(self)
if agent_instance.role not in agent_roles: if agent_instance.role not in agent_roles:
instantiated_agents.append(agent_instance) instantiated_agents.append(agent_instance)
@@ -109,19 +117,23 @@ def crew(func) -> Callable[..., Crew]:
self.agents = instantiated_agents self.agents = instantiated_agents
self.tasks = instantiated_tasks self.tasks = instantiated_tasks
crew = func(self, *args, **kwargs) crew_result = func(self, *args, **kwargs)
def callback_wrapper(callback, instance): def callback_wrapper(callback_func: Any, instance: Any) -> Callable[..., Any]:
def wrapper(*args, **kwargs): def inner_wrapper(*cb_args: Any, **cb_kwargs: Any) -> Any:
return callback(instance, *args, **kwargs) return callback_func(instance, *cb_args, **cb_kwargs)
return wrapper return inner_wrapper
for _, callback in self._before_kickoff.items(): for callback_func in self._before_kickoff.values():
crew.before_kickoff_callbacks.append(callback_wrapper(callback, self)) crew_result.before_kickoff_callbacks.append(
for _, callback in self._after_kickoff.items(): callback_wrapper(callback_func, self)
crew.after_kickoff_callbacks.append(callback_wrapper(callback, self)) )
for callback_func in self._after_kickoff.values():
crew_result.after_kickoff_callbacks.append(
callback_wrapper(callback_func, self)
)
return crew return crew_result
return memoize(wrapper) return memoize(wrapper)

View File

@@ -1,11 +1,25 @@
from collections.abc import Callable
from functools import wraps from functools import wraps
from typing import ParamSpec, TypeVar
P = ParamSpec("P")
R = TypeVar("R")
def memoize(func): def memoize(func: Callable[P, R]) -> Callable[P, R]:
cache = {} """Decorator that caches function results based on arguments.
Args:
func: The function to memoize.
Returns:
The memoized function.
"""
cache: dict[tuple, R] = {}
@wraps(func) @wraps(func)
def memoized_func(*args, **kwargs): def memoized_func(*args: P.args, **kwargs: P.kwargs) -> R:
"""Memoized wrapper function."""
key = (args, tuple(kwargs.items())) key = (args, tuple(kwargs.items()))
if key not in cache: if key not in cache:
cache[key] = func(*args, **kwargs) cache[key] = func(*args, **kwargs)