chore: add comprehensive typing to AgentConfig and TaskConfig

This commit is contained in:
Greyson Lalonde
2025-10-16 19:53:19 -04:00
parent 30a4b712a3
commit 6081809c76
2 changed files with 85 additions and 10 deletions

View File

@@ -33,12 +33,62 @@ class AgentConfig(TypedDict, total=False):
Fields can be either string references (from YAML) or actual instances (after processing).
"""
# Core agent attributes (from BaseAgent)
role: str
goal: str
backstory: str
cache: bool
verbose: bool
max_rpm: int
allow_delegation: bool
max_iter: int
max_tokens: int
callbacks: list[str]
# LLM configuration
llm: str
tools: list[str] | list[BaseTool]
function_calling_llm: str
use_system_prompt: bool
# Template configuration
system_template: str
prompt_template: str
response_template: str
# Tools and handlers (can be string references or instances)
tools: list[str] | list[BaseTool]
step_callback: str
cache_handler: str | CacheHandler
# Code execution
allow_code_execution: bool
code_execution_mode: Literal["safe", "unsafe"]
# Context and performance
respect_context_window: bool
max_retry_limit: int
# Multimodal and reasoning
multimodal: bool
reasoning: bool
max_reasoning_attempts: int
# Knowledge configuration
knowledge_sources: list[str] | list[Any]
knowledge_storage: str | Any
knowledge_config: dict[str, Any]
embedder: dict[str, Any]
agent_knowledge_context: str
crew_knowledge_context: str
knowledge_search_query: str
# Misc configuration
inject_date: bool
date_format: str
from_repository: str
guardrail: Callable[[Any], tuple[bool, Any]] | str
guardrail_max_retries: int
class TaskConfig(TypedDict, total=False):
"""Type definition for task configuration dictionary.
@@ -47,13 +97,37 @@ class TaskConfig(TypedDict, total=False):
Fields can be either string references (from YAML) or actual instances (after processing).
"""
context: list[str]
tools: list[str] | list[BaseTool]
# Core task attributes
name: str
description: str
expected_output: str
# Agent and context
agent: str
context: list[str]
# Tools and callbacks (can be string references or instances)
tools: list[str] | list[BaseTool]
callback: str
callbacks: list[str]
# Output configuration
output_json: str
output_pydantic: str
callbacks: list[str]
output_file: str
create_directory: bool
# Execution configuration
async_execution: bool
human_input: bool
markdown: bool
# Guardrail configuration
guardrail: Callable[[TaskOutput], tuple[bool, Any]] | str
guardrail_max_retries: int
# Misc configuration
allow_crewai_trigger_context: bool
load_dotenv()
@@ -199,7 +273,7 @@ class CrewBaseMeta(type):
}
after_kickoff_callbacks = _filter_methods(original_methods, "is_after_kickoff")
after_kickoff_callbacks["_close_mcp_server"] = instance._close_mcp_server
after_kickoff_callbacks["close_mcp_server"] = instance.close_mcp_server
instance.__crew_metadata__ = CrewMetadata(
original_methods=original_methods,
@@ -211,7 +285,7 @@ class CrewBaseMeta(type):
)
def _close_mcp_server(
def close_mcp_server(
self: CrewInstance, _instance: CrewInstance, outputs: CrewOutput
) -> CrewOutput:
"""Stop MCP server adapter and return outputs.
@@ -508,7 +582,7 @@ _CLASS_SETUP_FUNCTIONS: tuple[Callable[[type[CrewClass]], None], ...] = (
)
_METHODS_TO_INJECT = (
_close_mcp_server,
close_mcp_server,
get_mcp_tools,
_load_config,
load_configurations,

View File

@@ -12,11 +12,12 @@ from typing import (
Literal,
ParamSpec,
Protocol,
Self,
TypedDict,
TypeVar,
)
from typing_extensions import Self
if TYPE_CHECKING:
from crewai import Agent, Task
from crewai.crews.crew_output import CrewOutput
@@ -81,7 +82,7 @@ class CrewInstance(Protocol):
def load_configurations(self) -> None: ...
def map_all_agent_variables(self) -> None: ...
def map_all_task_variables(self) -> None: ...
def _close_mcp_server(self, instance: Self, outputs: CrewOutput) -> CrewOutput: ...
def close_mcp_server(self, instance: Self, outputs: CrewOutput) -> CrewOutput: ...
def _load_config(
self, config_path: str | None, config_type: Literal["agent", "task"]
) -> dict[str, Any]: ...
@@ -118,7 +119,7 @@ class CrewClass(Protocol):
original_tasks_config_path: str
mcp_server_params: Any
mcp_connect_timeout: int
_close_mcp_server: Callable[..., Any]
close_mcp_server: Callable[..., Any]
get_mcp_tools: Callable[..., list[BaseTool]]
_load_config: Callable[..., dict[str, Any]]
load_configurations: Callable[..., None]