fix: support to define MPC connection timeout on CrewBase instance (#3465)

* fix: support to define MPC connection timeout on CrewBase instance

* fix: resolve linter issues

* chore: ignore specific rule N802 on CrewBase class

* fix: ignore untyped import
This commit is contained in:
Lucas Gomide
2025-09-10 10:58:46 -03:00
committed by GitHub
parent 1dc4f2e897
commit 260b49c10a
5 changed files with 188 additions and 34 deletions

View File

@@ -1,12 +1,14 @@
import inspect
import logging
from collections.abc import Callable
from pathlib import Path
from typing import Any, Callable, Dict, TypeVar, cast, List
from crewai.tools import BaseTool
from typing import Any, TypeVar, cast
import yaml
from dotenv import load_dotenv
from crewai.tools import BaseTool
load_dotenv()
T = TypeVar("T", bound=type)
@@ -14,7 +16,7 @@ T = TypeVar("T", bound=type)
"""Base decorator for creating crew classes with configuration and function management."""
def CrewBase(cls: T) -> T:
def CrewBase(cls: T) -> T: # noqa: N802
"""Wraps a class with crew functionality and configuration management."""
class WrappedClass(cls): # type: ignore
@@ -29,6 +31,7 @@ def CrewBase(cls: T) -> T:
original_tasks_config_path = getattr(cls, "tasks_config", "config/tasks.yaml")
mcp_server_params: Any = getattr(cls, "mcp_server_params", None)
mcp_connect_timeout: int = getattr(cls, "mcp_connect_timeout", 30)
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
@@ -86,15 +89,18 @@ def CrewBase(cls: T) -> T:
import types
return types.MethodType(_close_mcp_server, self)
def get_mcp_tools(self, *tool_names: list[str]) -> List[BaseTool]:
def get_mcp_tools(self, *tool_names: list[str]) -> list[BaseTool]:
if not self.mcp_server_params:
return []
from crewai_tools import MCPServerAdapter
from crewai_tools import MCPServerAdapter # type: ignore[import-untyped]
adapter = getattr(self, '_mcp_server_adapter', None)
if not adapter:
self._mcp_server_adapter = MCPServerAdapter(self.mcp_server_params)
self._mcp_server_adapter = MCPServerAdapter(
self.mcp_server_params,
connect_timeout=self.mcp_connect_timeout
)
return self._mcp_server_adapter.tools.filter_by_names(tool_names or None)
@@ -154,8 +160,8 @@ def CrewBase(cls: T) -> T:
}
def _filter_functions(
self, functions: Dict[str, Callable], attribute: str
) -> Dict[str, Callable]:
self, functions: dict[str, Callable], attribute: str
) -> dict[str, Callable]:
return {
name: func
for name, func in functions.items()
@@ -184,11 +190,11 @@ def CrewBase(cls: T) -> T:
def _map_agent_variables(
self,
agent_name: str,
agent_info: Dict[str, Any],
llms: Dict[str, Callable],
tool_functions: Dict[str, Callable],
cache_handler_functions: Dict[str, Callable],
callbacks: Dict[str, Callable],
agent_info: dict[str, Any],
llms: dict[str, Callable],
tool_functions: dict[str, Callable],
cache_handler_functions: dict[str, Callable],
callbacks: dict[str, Callable],
) -> None:
if llm := agent_info.get("llm"):
try:
@@ -245,13 +251,13 @@ def CrewBase(cls: T) -> T:
def _map_task_variables(
self,
task_name: str,
task_info: Dict[str, Any],
agents: Dict[str, Callable],
tasks: Dict[str, Callable],
output_json_functions: Dict[str, Callable],
tool_functions: Dict[str, Callable],
callback_functions: Dict[str, Callable],
output_pydantic_functions: Dict[str, Callable],
task_info: dict[str, Any],
agents: dict[str, Callable],
tasks: dict[str, Callable],
output_json_functions: dict[str, Callable],
tool_functions: dict[str, Callable],
callback_functions: dict[str, Callable],
output_pydantic_functions: dict[str, Callable],
) -> None:
if context_list := task_info.get("context"):
self.tasks_config[task_name]["context"] = [