Compare commits

...

2 Commits

Author SHA1 Message Date
Greyson LaLonde
948c90c52e Merge branch 'main' into chore/mypy-strict-cleanup 2026-05-22 03:45:34 +08:00
Greyson LaLonde
b7cf1f0148 chore: tighten mypy strict mode and remove dead code
Enable warn_unreachable, extra_checks, local_partial_types in pyproject.
Remove dead defensive branches and AI-slop union members; replace narrow
band-aid type:ignore with proper signature widening or targeted ignores
for genuine runtime-defensive paths (double-checked locking, hook misuse,
unfollowed-import boundaries).
2026-05-22 03:38:39 +08:00
39 changed files with 242 additions and 484 deletions

View File

@@ -232,9 +232,6 @@ class A2UIClientExtension:
continue
data = root.data
if not isinstance(data, dict):
continue
surface_id = _get_surface_id(data)
if not surface_id:
continue

View File

@@ -258,8 +258,6 @@ def validate_catalog_components_v09(message: A2UIMessageV09) -> None:
errors: list[Any] = []
for entry in message.update_components.components:
if not isinstance(entry, dict):
continue
type_name = entry.get("component")
if not isinstance(type_name, str):
continue

View File

@@ -178,7 +178,7 @@ class StreamingHandler:
is_final=is_final_update,
)
elif isinstance(event, Message):
elif isinstance(event, Message): # type: ignore[unreachable]
new_messages.append(event)
result_parts.extend(
part.root.text

View File

@@ -317,9 +317,7 @@ def get_part_content_type(part: Part) -> str:
if mime == APPLICATION_A2UI_JSON:
return APPLICATION_A2UI_JSON
return APPLICATION_JSON
if root.kind == "file":
return root.file.mime_type or APPLICATION_OCTET_STREAM
return APPLICATION_OCTET_STREAM
return root.file.mime_type or APPLICATION_OCTET_STREAM
def validate_message_parts(

View File

@@ -112,9 +112,6 @@ class BaseConverterAdapter(ABC):
Returns:
Extracted JSON string if found and valid, otherwise the original result.
"""
if not isinstance(result, str):
return str(result)
if valid := BaseConverterAdapter._validate_json(result):
return valid

View File

@@ -46,8 +46,8 @@ class LangGraphToolAdapter(BaseToolAdapter):
else:
all_tools = tools
for tool in all_tools:
if isinstance(tool, LangChainBaseTool):
converted_tools.append(tool)
if isinstance(tool, LangChainBaseTool): # type: ignore[unreachable]
converted_tools.append(tool) # type: ignore[unreachable]
continue
sanitized_name: str = self.sanitize_tool_name(tool.name)

View File

@@ -4,7 +4,6 @@ This module contains the OpenAIAgentToolAdapter class that converts CrewAI tools
to OpenAI Assistant-compatible format using the agents library.
"""
from collections.abc import Awaitable
import inspect
import json
from typing import Any, cast
@@ -114,12 +113,8 @@ class OpenAIAgentToolAdapter(BaseToolAdapter):
else:
args_dict = {param_name: str(arguments)}
output: Any | Awaitable[Any] = tool._run(**args_dict)
if inspect.isawaitable(output):
result: Any = await output
else:
result = output
output: Any = tool._run(**args_dict)
result: Any = await output if inspect.isawaitable(output) else output
if isinstance(result, (dict, list, str, int, float, bool, type(None))):
return result

View File

@@ -569,9 +569,6 @@ class BaseAgent(BaseModel, ABC, metaclass=AgentMeta):
if not self._token_process:
self._token_process = TokenProcess()
if self.security_config is None:
self.security_config = SecurityConfig()
return self
@field_validator("id", mode="before")
@@ -621,7 +618,7 @@ class BaseAgent(BaseModel, ABC, metaclass=AgentMeta):
task: Any,
context: str | None = None,
tools: list[BaseTool] | None = None,
) -> str:
) -> str | BaseModel:
pass
@abstractmethod
@@ -630,7 +627,7 @@ class BaseAgent(BaseModel, ABC, metaclass=AgentMeta):
task: Any,
context: str | None = None,
tools: list[BaseTool] | None = None,
) -> str:
) -> str | BaseModel:
"""Execute a task asynchronously."""
@abstractmethod

View File

@@ -334,7 +334,7 @@ class CrewAgentExecutor(BaseAgentExecutor):
Returns:
Final answer from the agent.
"""
formatted_answer = None
formatted_answer: AgentAction | AgentFinish | None = None
while not isinstance(formatted_answer, AgentFinish):
try:
if has_reached_max_iterations(self.iterations, self.max_iter):
@@ -385,12 +385,12 @@ class CrewAgentExecutor(BaseAgentExecutor):
)
formatted_answer = process_llm_response(
answer_str, self.use_stop_words
) # type: ignore[assignment]
)
else:
answer_str = str(answer) if not isinstance(answer, str) else answer
formatted_answer = process_llm_response(
answer_str, self.use_stop_words
) # type: ignore[assignment]
)
if isinstance(formatted_answer, AgentAction):
fingerprint_context = {}
@@ -425,7 +425,7 @@ class CrewAgentExecutor(BaseAgentExecutor):
self._append_message(formatted_answer.text)
except OutputParserError as e:
formatted_answer = handle_output_parser_exception( # type: ignore[assignment]
formatted_answer = handle_output_parser_exception(
e=e,
messages=self.messages,
iterations=self.iterations,
@@ -1145,7 +1145,7 @@ class CrewAgentExecutor(BaseAgentExecutor):
Returns:
Final answer from the agent.
"""
formatted_answer = None
formatted_answer: AgentAction | AgentFinish | None = None
while not isinstance(formatted_answer, AgentFinish):
try:
if has_reached_max_iterations(self.iterations, self.max_iter):
@@ -1197,12 +1197,12 @@ class CrewAgentExecutor(BaseAgentExecutor):
)
formatted_answer = process_llm_response(
answer_str, self.use_stop_words
) # type: ignore[assignment]
)
else:
answer_str = str(answer) if not isinstance(answer, str) else answer
formatted_answer = process_llm_response(
answer_str, self.use_stop_words
) # type: ignore[assignment]
)
if isinstance(formatted_answer, AgentAction):
fingerprint_context = {}
@@ -1237,7 +1237,7 @@ class CrewAgentExecutor(BaseAgentExecutor):
self._append_message(formatted_answer.text)
except OutputParserError as e:
formatted_answer = handle_output_parser_exception( # type: ignore[assignment]
formatted_answer = handle_output_parser_exception(
e=e,
messages=self.messages,
iterations=self.iterations,

View File

@@ -308,15 +308,11 @@ class StepExecutor:
if isinstance(formatted, AgentFinish):
return str(formatted.output)
if isinstance(formatted, AgentAction):
tool_calls_made.append(formatted.tool)
tool_result = self._execute_text_tool_with_events(formatted)
last_tool_result = tool_result
messages.append({"role": "assistant", "content": answer_str})
messages.append(self._build_observation_message(tool_result))
continue
return answer_str
tool_calls_made.append(formatted.tool)
tool_result = self._execute_text_tool_with_events(formatted)
last_tool_result = tool_result
messages.append({"role": "assistant", "content": answer_str})
messages.append(self._build_observation_message(tool_result))
return last_tool_result

View File

@@ -39,10 +39,7 @@ class ToolsHandler(BaseModel):
if self.cache and should_cache and calling.tool_name != CacheTools().name:
input_str = ""
if calling.arguments:
if isinstance(calling.arguments, dict):
input_str = json.dumps(calling.arguments)
else:
input_str = str(calling.arguments)
input_str = json.dumps(calling.arguments)
self.cache.add(
tool=calling.tool_name,

View File

@@ -166,7 +166,7 @@ class CrewAIEventsBus:
with self._instance_lock:
if self._executor_initialized:
return
return # type: ignore[unreachable]
self._sync_executor = ThreadPoolExecutor(
max_workers=10,
@@ -304,7 +304,7 @@ class CrewAIEventsBus:
from crewai import RuntimeState
if RuntimeState is None:
logger.warning(
logger.warning( # type: ignore[unreachable]
"RuntimeState unavailable; skipping entity registration."
)
return
@@ -428,7 +428,7 @@ class CrewAIEventsBus:
if cached_plan is None:
with self._rwlock.w_locked():
if self._shutting_down:
return
return # type: ignore[unreachable]
cached_plan = self._execution_plan_cache.get(event_type)
if cached_plan is None:
sync_handlers = self._sync_handlers.get(event_type, frozenset())

View File

@@ -291,7 +291,9 @@ class TraceBatchManager:
)
if response is None:
logger.warning("Failed to send trace events. Events will be lost.")
logger.warning( # type: ignore[unreachable]
"Failed to send trace events. Events will be lost."
)
return 500
if response.status_code in [200, 201]:

View File

@@ -2616,7 +2616,9 @@ class AgentExecutor(Flow[AgentExecutorState], BaseAgentExecutor):
with _llm_stop_words_applied(self.llm, self):
self.kickoff()
formatted_answer = self.state.current_answer
formatted_answer: AgentAction | AgentFinish | None = (
self.state.current_answer
)
if not isinstance(formatted_answer, AgentFinish):
raise RuntimeError(
@@ -2717,7 +2719,9 @@ class AgentExecutor(Flow[AgentExecutorState], BaseAgentExecutor):
with _llm_stop_words_applied(self.llm, self):
await self.kickoff_async()
formatted_answer = self.state.current_answer
formatted_answer: AgentAction | AgentFinish | None = (
self.state.current_answer
)
if not isinstance(formatted_answer, AgentFinish):
raise RuntimeError(

View File

@@ -962,7 +962,7 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta):
model_config = ConfigDict(
arbitrary_types_allowed=True,
ignored_types=(StartMethod, ListenMethod, RouterMethod),
ignored_types=(FlowMethod,),
revalidate_instances="never",
)
__hash__ = object.__hash__
@@ -3009,8 +3009,6 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta):
self._pending_and_listeners.pop(pending_key, None)
return True
return False
return False
def _find_triggered_methods(

View File

@@ -28,7 +28,7 @@ import inspect
import logging
import re
import textwrap
from typing import Any, TypedDict, get_args, get_origin
from typing import Any, Literal, TypeAlias, TypedDict, get_args, get_origin
from pydantic import BaseModel
from pydantic_core import PydanticUndefined
@@ -44,6 +44,8 @@ from crewai.flow.flow_wrappers import (
logger = logging.getLogger(__name__)
MethodType: TypeAlias = Literal["start", "listen", "router", "start_router"]
class MethodInfo(TypedDict, total=False):
"""Information about a single flow method.
@@ -59,7 +61,7 @@ class MethodInfo(TypedDict, total=False):
"""
name: str
type: str
type: MethodType
trigger_methods: list[str]
condition_type: str | None
router_paths: list[str]
@@ -132,7 +134,7 @@ def _get_method_type(
method: Any,
start_methods: list[str],
routers: set[str],
) -> str:
) -> MethodType:
"""Determine the type of a flow method.
Args:
@@ -191,7 +193,6 @@ def _detect_crew_reference(method: Any) -> bool:
True if crew reference detected, False otherwise.
"""
try:
# Get the underlying function from wrapper
func = method
if hasattr(method, "_meth"):
func = method._meth
@@ -201,12 +202,11 @@ def _detect_crew_reference(method: Any) -> bool:
source = inspect.getsource(func)
source = textwrap.dedent(source)
# Patterns that indicate Crew usage
crew_patterns = [
r"\.crew\(\)", # .crew() method call
r"Crew\s*\(", # Crew( instantiation
r":\s*Crew\b", # Type hint with Crew
r"->.*Crew", # Return type hint with Crew
r"\.crew\(\)",
r"Crew\s*\(",
r":\s*Crew\b",
r"->.*Crew",
]
for pattern in crew_patterns:
@@ -215,7 +215,6 @@ def _detect_crew_reference(method: Any) -> bool:
return False
except (OSError, TypeError):
# Can't get source code - assume no crew reference
return False
@@ -231,11 +230,9 @@ def _extract_trigger_methods(method: Any) -> tuple[list[str], str | None]:
trigger_methods: list[str] = []
condition_type: str | None = None
# First try __trigger_methods__ (populated for simple conditions)
if hasattr(method, "__trigger_methods__") and method.__trigger_methods__:
trigger_methods = [str(m) for m in method.__trigger_methods__]
# For complex conditions (or_/and_ combinators), extract from __trigger_condition__
if (
not trigger_methods
and hasattr(method, "__trigger_condition__")
@@ -264,11 +261,9 @@ def _extract_router_paths(
"""
method_name = getattr(method, "__name__", "")
# First check if there are __router_paths__ on the method itself
if hasattr(method, "__router_paths__") and method.__router_paths__:
return [str(p) for p in method.__router_paths__]
# Then check the class-level registry
if method_name in router_paths_registry:
return [str(p) for p in router_paths_registry[method_name]]
@@ -276,39 +271,15 @@ def _extract_router_paths(
def _extract_all_methods_from_condition(
condition: str | FlowCondition | dict[str, Any] | list[Any],
condition: str | FlowCondition,
) -> list[str]:
"""Extract all method names from a condition tree recursively.
Args:
condition: Can be a string, FlowCondition tuple, dict, or list.
Returns:
List of all method names found in the condition.
"""
"""Extract all method names from a condition tree recursively."""
if isinstance(condition, str):
return [condition]
if isinstance(condition, tuple) and len(condition) == 2:
# FlowCondition: (condition_type, methods_list)
_, methods = condition
if isinstance(methods, list):
result: list[str] = []
for m in methods:
result.extend(_extract_all_methods_from_condition(m))
return result
return []
if isinstance(condition, dict):
conditions_list = condition.get("conditions", [])
dict_methods: list[str] = []
for sub_cond in conditions_list:
dict_methods.extend(_extract_all_methods_from_condition(sub_cond))
return dict_methods
if isinstance(condition, list):
list_methods: list[str] = []
for item in condition:
list_methods.extend(_extract_all_methods_from_condition(item))
return list_methods
return []
methods: list[str] = []
for sub_cond in condition.get("conditions", []):
methods.extend(_extract_all_methods_from_condition(sub_cond))
return methods
def _generate_edges(
@@ -330,7 +301,6 @@ def _generate_edges(
"""
edges: list[EdgeInfo] = []
# Generate edges from listeners (listen edges)
for listener_name, condition_data in listeners.items():
trigger_methods: list[str] = []
@@ -340,7 +310,6 @@ def _generate_edges(
elif isinstance(condition_data, dict):
trigger_methods = _extract_all_methods_from_condition(condition_data)
# Create edges from each trigger to the listener
edges.extend(
EdgeInfo(
from_method=trigger,
@@ -352,10 +321,8 @@ def _generate_edges(
if trigger in all_methods
)
# Generate edges from routers (route edges)
for router_name, paths in router_paths.items():
for path in paths:
# Find listeners that listen to this path
for listener_name, condition_data in listeners.items():
path_triggers: list[str] = []
@@ -393,11 +360,9 @@ def _extract_state_schema(flow_class: type) -> StateSchemaInfo | None:
"""
state_type: type | None = None
# Check for _initial_state_t set by __class_getitem__
if hasattr(flow_class, "_initial_state_t"):
state_type = flow_class._initial_state_t
# Check initial_state class attribute
if state_type is None and hasattr(flow_class, "initial_state"):
initial_state = flow_class.initial_state
if isinstance(initial_state, type) and issubclass(initial_state, BaseModel):
@@ -405,7 +370,6 @@ def _extract_state_schema(flow_class: type) -> StateSchemaInfo | None:
elif isinstance(initial_state, BaseModel):
state_type = type(initial_state)
# Check __orig_bases__ for generic parameters
if state_type is None and hasattr(flow_class, "__orig_bases__"):
for base in flow_class.__orig_bases__:
origin = get_origin(base)
@@ -420,7 +384,6 @@ def _extract_state_schema(flow_class: type) -> StateSchemaInfo | None:
if state_type is None or not issubclass(state_type, BaseModel):
return None
# Extract fields from the Pydantic model
fields: list[StateFieldInfo] = []
try:
model_fields = state_type.model_fields
@@ -428,7 +391,6 @@ def _extract_state_schema(flow_class: type) -> StateSchemaInfo | None:
field_type_str = "Any"
if field_info.annotation is not None:
field_type_str = str(field_info.annotation)
# Clean up the type string
field_type_str = field_type_str.replace("typing.", "")
field_type_str = field_type_str.replace("<class '", "").replace(
"'>", ""
@@ -441,7 +403,6 @@ def _extract_state_schema(flow_class: type) -> StateSchemaInfo | None:
and not callable(field_info.default)
):
try:
# Try to serialize the default value
default_value = field_info.default
except Exception:
default_value = str(field_info.default)
@@ -474,7 +435,6 @@ def _detect_flow_inputs(flow_class: type) -> list[str]:
"""
inputs: list[str] = []
# Check for inputs in __init__ signature beyond standard Flow params
try:
init_method = flow_class.__init__ # type: ignore[misc]
init_sig = inspect.signature(init_method)
@@ -533,7 +493,6 @@ def flow_structure(flow_class: type) -> FlowStructureInfo:
f"Got {type(flow_class).__name__}"
)
# Get class-level metadata set by FlowMeta
start_methods: list[str] = getattr(flow_class, "_start_methods", [])
listeners: dict[str, Any] = getattr(flow_class, "_listeners", {})
routers: set[str] = getattr(flow_class, "_routers", set())
@@ -541,7 +500,6 @@ def flow_structure(flow_class: type) -> FlowStructureInfo:
flow_class, "_router_paths", {}
)
# Collect all flow methods
methods: list[MethodInfo] = []
all_method_names: set[str] = set()
@@ -554,7 +512,6 @@ def flow_structure(flow_class: type) -> FlowStructureInfo:
except AttributeError:
continue
# Check if it's a flow method
is_flow_method = (
isinstance(attr, (FlowMethod, StartMethod, ListenMethod, RouterMethod))
or hasattr(attr, "__is_flow_method__")
@@ -568,21 +525,16 @@ def flow_structure(flow_class: type) -> FlowStructureInfo:
all_method_names.add(attr_name)
# Get method type
method_type = _get_method_type(attr_name, attr, start_methods, routers)
# Get trigger methods and condition type
trigger_methods, condition_type = _extract_trigger_methods(attr)
# Get router paths if applicable
router_paths_list: list[str] = []
if method_type in ("router", "start_router"):
router_paths_list = _extract_router_paths(attr, router_paths_registry)
# Check for human feedback
has_hf = _has_human_feedback(attr)
# Check for crew reference
has_crew = _detect_crew_reference(attr)
method_info = MethodInfo(
@@ -596,16 +548,10 @@ def flow_structure(flow_class: type) -> FlowStructureInfo:
)
methods.append(method_info)
# Generate edges
edges = _generate_edges(listeners, routers, router_paths_registry, all_method_names)
# Extract state schema
state_schema = _extract_state_schema(flow_class)
# Detect inputs
inputs = _detect_flow_inputs(flow_class)
# Get flow description from docstring
description: str | None = None
if flow_class.__doc__:
description = flow_class.__doc__.strip()

View File

@@ -46,6 +46,8 @@ class FlowMethod(Generic[P, R]):
both bound (instance) and unbound (class) method states.
"""
__is_flow_method__: bool = True
def __init__(self, meth: Callable[P, R], instance: Any = None) -> None:
"""Initialize the flow method wrapper.
@@ -53,9 +55,9 @@ class FlowMethod(Generic[P, R]):
meth: The method to wrap.
instance: The instance to bind to (None for unbound).
"""
functools.update_wrapper(self, meth)
self._meth = meth
self._instance = instance
functools.update_wrapper(self, meth, updated=[])
self.__name__: FlowMethodName = FlowMethodName(self.__name__)
self.__signature__ = inspect.signature(meth)
@@ -70,16 +72,6 @@ class FlowMethod(Generic[P, R]):
self._is_coroutine = asyncio.coroutines._is_coroutine # type: ignore[attr-defined]
# Preserve flow-related attributes from wrapped method (e.g., from @human_feedback)
for attr in [
"__is_router__",
"__router_paths__",
"__human_feedback_config__",
"_hf_llm", # Live LLM object for HITL resume
]:
if hasattr(meth, attr):
setattr(self, attr, getattr(meth, attr))
def __call__(self, *args: P.args, **kwargs: P.kwargs) -> R:
"""Call the wrapped method.
@@ -102,6 +94,19 @@ class FlowMethod(Generic[P, R]):
"""
return self._meth
def __getattr__(self, name: str) -> Any:
"""Delegate missing attributes to the wrapped method.
Lets flow flags like ``__is_start_method__`` defined on the wrapped
method's class flow through transparently when this wrapper itself
wraps another FlowMethod.
"""
try:
meth = object.__getattribute__(self, "_meth")
except AttributeError:
raise AttributeError(name) from None
return getattr(meth, name)
def __get__(self, instance: Any, owner: type | None = None) -> Self:
"""Support the descriptor protocol for method binding.

View File

@@ -118,13 +118,11 @@ def _deserialize_llm_from_context(
if isinstance(llm_data, str):
return LLM(model=llm_data)
if isinstance(llm_data, dict):
data = dict(llm_data)
model = data.pop("model", None)
if not model:
return None
return LLM(model=model, **data)
return None
data = dict(llm_data)
model = data.pop("model", None)
if not model:
return None
return LLM(model=model, **data)
@dataclass
@@ -706,6 +704,6 @@ def human_feedback(
# instead of creating a bare LLM from just the model string.
wrapper._hf_llm = llm
return wrapper # type: ignore[no-any-return]
return HumanFeedbackMethod(wrapper) # type: ignore[return-value]
return decorator

View File

@@ -24,15 +24,16 @@ Example:
from __future__ import annotations
import asyncio
from collections.abc import Callable
from collections.abc import Awaitable, Callable
import functools
import inspect
import logging
from typing import TYPE_CHECKING, Any, Final, TypeVar, cast
from typing import TYPE_CHECKING, Any, Final, ParamSpec, TypeVar, cast
from crewai_core.printer import PRINTER
from pydantic import BaseModel
from crewai.flow.flow_wrappers import FlowMethod
from crewai.flow.persistence.base import FlowPersistence
from crewai.flow.persistence.sqlite import SQLiteFlowPersistence
@@ -42,6 +43,8 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
P = ParamSpec("P")
R = TypeVar("R")
T = TypeVar("T")
# Constants for log messages
@@ -134,9 +137,71 @@ class PersistenceDecorator:
raise ValueError(error_msg) from e
class PersistedFlowMethod(FlowMethod[P, R]):
"""FlowMethod variant that persists state after each invocation.
Wrapping the original method directly (rather than copying its attributes
onto a closure) lets ``FlowMethod.__getattr__`` delegate flow flags like
``__is_start_method__`` to the wrapped object transparently.
For async wrapped methods, ``R`` is the ``Coroutine`` returned by calling
them, so ``__call__``'s declared return type stays accurate in both cases.
"""
def __init__(
self,
meth: Callable[P, R],
instance: Any = None,
*,
persistence: FlowPersistence | None = None,
verbose: bool = False,
) -> None:
super().__init__(meth, instance)
self._persistence = persistence
self._verbose = verbose
def _resolve_flow_instance(self, args: tuple[Any, ...]) -> Any:
return (
self._instance
if self._instance is not None
else (args[0] if args else None)
)
def __call__(self, *args: P.args, **kwargs: P.kwargs) -> R:
if inspect.iscoroutinefunction(self._meth):
return cast(R, self._call_async(*args, **kwargs))
flow_instance = self._resolve_flow_instance(args)
result = super().__call__(*args, **kwargs)
PersistenceDecorator.persist_state(
flow_instance,
self.__name__,
cast(FlowPersistence, self._persistence),
self._verbose,
)
return result
async def _call_async(self, *args: Any, **kwargs: Any) -> Any:
flow_instance = self._resolve_flow_instance(args)
meth = cast(Callable[..., Awaitable[Any]], self._meth)
if self._instance is not None:
result = await meth(self._instance, *args, **kwargs)
else:
result = await meth(*args, **kwargs)
PersistenceDecorator.persist_state(
flow_instance,
self.__name__,
cast(FlowPersistence, self._persistence),
self._verbose,
)
return result
def persist(
persistence: FlowPersistence | None = None, verbose: bool = False
) -> Callable[[type | Callable[..., T]], type | Callable[..., T]]:
) -> Callable[
[type[Flow[Any]] | Callable[..., T]],
type[Flow[Any]] | Callable[..., T],
]:
"""Decorator to persist flow state.
This decorator can be applied at either the class level or method level.
@@ -164,150 +229,41 @@ def persist(
pass
"""
def decorator(target: type | Callable[..., T]) -> type | Callable[..., T]:
def decorator(
target: type[Flow[Any]] | Callable[..., T],
) -> type[Flow[Any]] | Callable[..., T]:
"""Decorator that handles both class and method decoration."""
actual_persistence = persistence or SQLiteFlowPersistence()
if isinstance(target, type):
# Class decoration
original_init = target.__init__ # type: ignore[misc]
@functools.wraps(original_init)
def new_init(self: Any, *args: Any, **kwargs: Any) -> None:
def new_init(self: Flow[Any], *args: Any, **kwargs: Any) -> None:
if "persistence" not in kwargs:
kwargs["persistence"] = actual_persistence
original_init(self, *args, **kwargs)
target.__init__ = new_init # type: ignore[misc]
# Store original methods to preserve their decorators
original_methods = {
name: method
for name, method in target.__dict__.items()
if callable(method)
and (
hasattr(method, "__is_start_method__")
or hasattr(method, "__trigger_methods__")
or hasattr(method, "__condition_type__")
or hasattr(method, "__is_flow_method__")
or hasattr(method, "__is_router__")
for name, method in list(target.__dict__.items()):
if not isinstance(method, FlowMethod):
continue
setattr(
target,
name,
PersistedFlowMethod(
method, persistence=actual_persistence, verbose=verbose
),
)
}
# Create wrapped versions of the methods that include persistence
for name, method in original_methods.items():
if asyncio.iscoroutinefunction(method):
# Create a closure to capture the current name and method
def create_async_wrapper(
method_name: str, original_method: Callable[..., Any]
) -> Callable[..., Any]:
@functools.wraps(original_method)
async def method_wrapper(
self: Any, *args: Any, **kwargs: Any
) -> Any:
result = await original_method(self, *args, **kwargs)
PersistenceDecorator.persist_state(
self, method_name, actual_persistence, verbose
)
return result
return method_wrapper
wrapped = create_async_wrapper(name, method)
# Preserve all original decorators and attributes
for attr in [
"__is_start_method__",
"__trigger_methods__",
"__condition_type__",
"__is_router__",
]:
if hasattr(method, attr):
setattr(wrapped, attr, getattr(method, attr))
wrapped.__is_flow_method__ = True # type: ignore[attr-defined]
# Update the class with the wrapped method
setattr(target, name, wrapped)
else:
# Create a closure to capture the current name and method
def create_sync_wrapper(
method_name: str, original_method: Callable[..., Any]
) -> Callable[..., Any]:
@functools.wraps(original_method)
def method_wrapper(self: Any, *args: Any, **kwargs: Any) -> Any:
result = original_method(self, *args, **kwargs)
PersistenceDecorator.persist_state(
self, method_name, actual_persistence, verbose
)
return result
return method_wrapper
wrapped = create_sync_wrapper(name, method)
# Preserve all original decorators and attributes
for attr in [
"__is_start_method__",
"__trigger_methods__",
"__condition_type__",
"__is_router__",
]:
if hasattr(method, attr):
setattr(wrapped, attr, getattr(method, attr))
wrapped.__is_flow_method__ = True # type: ignore[attr-defined]
# Update the class with the wrapped method
setattr(target, name, wrapped)
return target
# Method decoration
method = target
method.__is_flow_method__ = True # type: ignore[attr-defined]
if asyncio.iscoroutinefunction(method):
@functools.wraps(method)
async def method_async_wrapper(
flow_instance: Any, *args: Any, **kwargs: Any
) -> T:
method_coro = method(flow_instance, *args, **kwargs)
if asyncio.iscoroutine(method_coro):
result = await method_coro
else:
result = method_coro
PersistenceDecorator.persist_state(
flow_instance, method.__name__, actual_persistence, verbose
)
return cast(T, result)
for attr in [
"__is_start_method__",
"__trigger_methods__",
"__condition_type__",
"__is_router__",
]:
if hasattr(method, attr):
setattr(method_async_wrapper, attr, getattr(method, attr))
method_async_wrapper.__is_flow_method__ = True # type: ignore[attr-defined]
return cast(Callable[..., T], method_async_wrapper)
@functools.wraps(method)
def method_sync_wrapper(flow_instance: Any, *args: Any, **kwargs: Any) -> T:
result = method(flow_instance, *args, **kwargs)
PersistenceDecorator.persist_state(
flow_instance, method.__name__, actual_persistence, verbose
)
return result
for attr in [
"__is_start_method__",
"__trigger_methods__",
"__condition_type__",
"__is_router__",
]:
if hasattr(method, attr):
setattr(method_sync_wrapper, attr, getattr(method, attr))
method_sync_wrapper.__is_flow_method__ = True # type: ignore[attr-defined]
return cast(Callable[..., T], method_sync_wrapper)
return cast(
Callable[..., T],
PersistedFlowMethod(
target, persistence=actual_persistence, verbose=verbose
),
)
return decorator

View File

@@ -26,7 +26,7 @@ if TYPE_CHECKING:
def _extract_direct_or_triggers(
condition: str | dict[str, Any] | list[Any] | FlowCondition,
condition: str | FlowCondition,
) -> list[str]:
"""Extract direct OR-level trigger strings from a condition.
@@ -39,36 +39,22 @@ def _extract_direct_or_triggers(
- and_("a", "b") -> [] (neither are direct triggers, both required)
- or_(and_("a", "b"), "c") -> ["c"] (only "c" is a direct trigger)
Args:
condition: Can be a string, dict, or list.
Returns:
List of direct OR-level trigger strings.
"""
if isinstance(condition, str):
return [condition]
if isinstance(condition, dict):
cond_type = condition.get("type", OR_CONDITION)
conditions_list = condition.get("conditions", [])
if cond_type == OR_CONDITION:
strings = []
for sub_cond in conditions_list:
strings.extend(_extract_direct_or_triggers(sub_cond))
return strings
cond_type = condition.get("type", OR_CONDITION)
if cond_type != OR_CONDITION:
return []
if isinstance(condition, list):
strings = []
for item in condition:
strings.extend(_extract_direct_or_triggers(item))
return strings
if callable(condition) and hasattr(condition, "__name__"):
return [condition.__name__]
return []
strings: list[str] = []
for sub_cond in condition.get("conditions", []):
strings.extend(_extract_direct_or_triggers(sub_cond))
return strings
def _extract_all_trigger_names(
condition: str | dict[str, Any] | list[Any] | FlowCondition,
condition: str | FlowCondition,
) -> list[str]:
"""Extract ALL trigger names from a condition for display purposes.
@@ -81,50 +67,26 @@ def _extract_all_trigger_names(
- and_("a", "b") -> ["a", "b"]
- or_(and_("a", method_6), method_4) -> ["a", "method_6", "method_4"]
Args:
condition: Can be a string, dict, or list.
Returns:
List of all trigger names found in the condition.
"""
if isinstance(condition, str):
return [condition]
if isinstance(condition, dict):
conditions_list = condition.get("conditions", [])
strings = []
for sub_cond in conditions_list:
strings.extend(_extract_all_trigger_names(sub_cond))
return strings
if isinstance(condition, list):
strings = []
for item in condition:
strings.extend(_extract_all_trigger_names(item))
return strings
if callable(condition) and hasattr(condition, "__name__"):
return [condition.__name__]
return []
strings: list[str] = []
for sub_cond in condition.get("conditions", []):
strings.extend(_extract_all_trigger_names(sub_cond))
return strings
def _create_edges_from_condition(
condition: str | dict[str, Any] | list[Any] | FlowCondition,
condition: str | FlowCondition,
target: str,
nodes: dict[str, NodeMetadata],
) -> list[StructureEdge]:
"""Create edges from a condition tree, preserving AND/OR semantics.
This function recursively processes the condition tree and creates edges
with the appropriate condition_type for each trigger.
For AND conditions, all triggers get edges with condition_type="AND".
For OR conditions, triggers get edges with condition_type="OR".
Args:
condition: The condition tree (string, dict, or list).
target: The target node name.
nodes: Dictionary of all nodes for validation.
Returns:
List of StructureEdge objects representing the condition.
"""
edges: list[StructureEdge] = []
@@ -138,39 +100,24 @@ def _create_edges_from_condition(
is_router_path=False,
)
)
elif callable(condition) and hasattr(condition, "__name__"):
method_name = condition.__name__
if method_name in nodes:
edges.append(
StructureEdge(
source=method_name,
target=target,
condition_type=OR_CONDITION,
is_router_path=False,
)
)
elif isinstance(condition, dict):
cond_type = condition.get("type", OR_CONDITION)
conditions_list = condition.get("conditions", [])
return edges
if cond_type == AND_CONDITION:
triggers = _extract_all_trigger_names(condition)
edges.extend(
StructureEdge(
source=trigger,
target=target,
condition_type=AND_CONDITION,
is_router_path=False,
)
for trigger in triggers
if trigger in nodes
cond_type = condition.get("type", OR_CONDITION)
if cond_type == AND_CONDITION:
triggers = _extract_all_trigger_names(condition)
edges.extend(
StructureEdge(
source=trigger,
target=target,
condition_type=AND_CONDITION,
is_router_path=False,
)
else:
for sub_cond in conditions_list:
edges.extend(_create_edges_from_condition(sub_cond, target, nodes))
elif isinstance(condition, list):
for item in condition:
edges.extend(_create_edges_from_condition(item, target, nodes))
for trigger in triggers
if trigger in nodes
)
else:
for sub_cond in condition.get("conditions", []):
edges.extend(_create_edges_from_condition(sub_cond, target, nodes))
return edges

View File

@@ -92,8 +92,6 @@ class CrewDoclingSource(BaseKnowledgeSource):
raise e
def add(self) -> None:
if self.content is None:
return
for doc in self.content:
new_chunks_iterable = self._chunk_doc(doc)
self.chunks.extend(list(new_chunks_iterable))
@@ -101,8 +99,6 @@ class CrewDoclingSource(BaseKnowledgeSource):
async def aadd(self) -> None:
"""Add docling content asynchronously."""
if self.content is None:
return
for doc in self.content:
new_chunks_iterable = self._chunk_doc(doc)
self.chunks.extend(list(new_chunks_iterable))

View File

@@ -155,11 +155,8 @@ class ExcelKnowledgeSource(BaseKnowledgeSource):
# Updated to account for .xlsx workbooks with multiple tabs/sheets
content_str = ""
for value in self.content.values():
if isinstance(value, dict):
for sheet_value in value.values():
content_str += str(sheet_value) + "\n"
else:
content_str += str(value) + "\n"
for sheet_value in value.values():
content_str += str(sheet_value) + "\n"
new_chunks = self._chunk_text(content_str)
self.chunks.extend(new_chunks)
@@ -169,11 +166,8 @@ class ExcelKnowledgeSource(BaseKnowledgeSource):
"""Add Excel file content asynchronously."""
content_str = ""
for value in self.content.values():
if isinstance(value, dict):
for sheet_value in value.values():
content_str += str(sheet_value) + "\n"
else:
content_str += str(value) + "\n"
for sheet_value in value.values():
content_str += str(sheet_value) + "\n"
new_chunks = self._chunk_text(content_str)
self.chunks.extend(new_chunks)

View File

@@ -484,10 +484,7 @@ class AnthropicCompletion(BaseLLM):
params["tool_choice"] = {"type": "tool", "name": tool_name}
if self.thinking:
if isinstance(self.thinking, AnthropicThinkingConfig):
params["thinking"] = self.thinking.model_dump()
else:
params["thinking"] = self.thinking
params["thinking"] = self.thinking.model_dump()
return params

View File

@@ -582,19 +582,16 @@ class GeminiCompletion(BaseLLM):
parts: list[types.Part] = []
if isinstance(content, list):
for item in content:
if isinstance(item, dict):
if "text" in item:
parts.append(types.Part.from_text(text=str(item["text"])))
elif "inlineData" in item:
inline = item["inlineData"]
parts.append(
types.Part.from_bytes(
data=base64.b64decode(inline["data"]),
mime_type=inline["mimeType"],
)
if "text" in item:
parts.append(types.Part.from_text(text=str(item["text"])))
elif "inlineData" in item:
inline = item["inlineData"]
parts.append(
types.Part.from_bytes(
data=base64.b64decode(inline["data"]),
mime_type=inline["mimeType"],
)
else:
parts.append(types.Part.from_text(text=str(item)))
)
else:
parts.append(types.Part.from_text(text=str(content) if content else ""))

View File

@@ -798,10 +798,7 @@ class OpenAICompletion(BaseLLM):
}
if parameters:
if isinstance(parameters, dict):
responses_tool["parameters"] = parameters
else:
responses_tool["parameters"] = dict(parameters)
responses_tool["parameters"] = parameters
responses_tools.append(responses_tool)

View File

@@ -383,22 +383,19 @@ class MCPToolResolver:
if mcp_config.tool_filter:
filtered_tools = []
for tool in tools_list:
if callable(mcp_config.tool_filter):
try:
from crewai.mcp.filters import ToolFilterContext
try:
from crewai.mcp.filters import ToolFilterContext
context = ToolFilterContext(
agent=self._agent,
server_name=server_name,
run_context=None,
)
if mcp_config.tool_filter(context, tool): # type: ignore[call-arg, arg-type]
filtered_tools.append(tool)
except (TypeError, AttributeError):
if mcp_config.tool_filter(tool): # type: ignore[call-arg, arg-type]
filtered_tools.append(tool)
else:
filtered_tools.append(tool)
context = ToolFilterContext(
agent=self._agent,
server_name=server_name,
run_context=None,
)
if mcp_config.tool_filter(context, tool): # type: ignore[call-arg, arg-type]
filtered_tools.append(tool)
except (TypeError, AttributeError): # noqa: PERF203
if mcp_config.tool_filter(tool): # type: ignore[call-arg, arg-type]
filtered_tools.append(tool)
tools_list = filtered_tools
if not tools_list:

View File

@@ -194,9 +194,6 @@ class GoogleGenAIVertexEmbeddingFunction(EmbeddingFunction[Documents]):
Returns:
List of embedding vectors.
"""
if isinstance(input, str):
input = [input]
if self._use_legacy:
return self._call_legacy(input)
return self._call_genai(input)

View File

@@ -54,9 +54,6 @@ class WatsonXEmbeddingFunction(EmbeddingFunction[Documents]):
"Install it with: uv add ibm-watsonx-ai"
) from e
if isinstance(input, str):
input = [input]
embeddings_config: dict[str, Any] = {
"model_id": self._config["model_id"],
}

View File

@@ -47,9 +47,6 @@ class VoyageAIEmbeddingFunction(EmbeddingFunction[Documents]):
List of embedding vectors.
"""
if isinstance(input, str):
input = [input]
result = self._client.embed(
texts=input,
model=self._config.get("model", "voyage-2"),

View File

@@ -47,7 +47,7 @@ def _ensure_handlers_registered() -> None:
return
with _register_lock:
if _handlers_registered:
return
return # type: ignore[unreachable]
_register_all_handlers(crewai_event_bus)
_handlers_registered = True

View File

@@ -1159,12 +1159,10 @@ Follow these guidelines:
return model_output, None
if isinstance(model_output, dict):
return None, model_output
if isinstance(model_output, str):
try:
return None, json.loads(model_output)
except json.JSONDecodeError:
return None, None
return None, None
try:
return None, json.loads(model_output)
except json.JSONDecodeError:
return None, None
def _get_output_format(self) -> OutputFormat:
if self.output_json:

View File

@@ -97,7 +97,7 @@ class Telemetry:
provider: OpenTelemetry tracer provider.
"""
_instance = None
_instance: Self | None = None
_lock = threading.Lock()
def __new__(cls) -> Self:
@@ -937,9 +937,6 @@ class Telemetry:
value: The attribute value.
"""
if span is None:
return
def _operation() -> None:
return span.set_attribute(key, value)

View File

@@ -122,7 +122,8 @@ class BaseAgentTool(BaseTool):
logger.debug(
f"Created task for agent '{self.sanitize_agent_name(selected_agent.role)}': {task}"
)
return selected_agent.execute_task(task_with_assigned_agent, context)
result = selected_agent.execute_task(task_with_assigned_agent, context)
return result if isinstance(result, str) else result.model_dump_json()
except Exception as e:
# Handle task creation or execution errors
return I18N_DEFAULT.errors("agent_tool_execution_error").format(

View File

@@ -421,28 +421,10 @@ class BaseTool(BaseModel, ABC):
)
def _set_args_schema(self) -> None:
if self.args_schema is None:
run_sig = signature(self._run)
fields: dict[str, Any] = {}
"""No-op retained for backward compatibility.
for param_name, param in run_sig.parameters.items():
if param_name in ("self", "return"):
continue
if param.kind in (Parameter.VAR_POSITIONAL, Parameter.VAR_KEYWORD):
continue
annotation = (
param.annotation if param.annotation != param.empty else Any
)
if param.default is param.empty:
fields[param_name] = (annotation, ...)
else:
fields[param_name] = (annotation, param.default)
self.args_schema = create_model(
f"{self.__class__.__name__}Schema", **fields
)
Schema generation is performed by the ``args_schema`` field validator.
"""
def _generate_description(self) -> None:
"""Generate the tool description with a JSON schema for arguments."""

View File

@@ -274,10 +274,7 @@ class ToolUsage:
if self.tools_handler and self.tools_handler.cache:
input_str = ""
if calling.arguments:
if isinstance(calling.arguments, dict):
input_str = json.dumps(calling.arguments)
else:
input_str = str(calling.arguments)
input_str = json.dumps(calling.arguments)
result = self.tools_handler.cache.read(
tool=sanitize_tool_name(calling.tool_name), input=input_str
@@ -303,7 +300,7 @@ class ToolUsage:
result = self._format_result(result=result)
# Don't return early - fall through to finally block
elif result is None:
try:
try: # type: ignore[unreachable]
if sanitize_tool_name(calling.tool_name) in [
sanitize_tool_name("Delegate work to coworker"),
sanitize_tool_name("Ask question to coworker"),
@@ -507,10 +504,7 @@ class ToolUsage:
if self.tools_handler and self.tools_handler.cache:
input_str = ""
if calling.arguments:
if isinstance(calling.arguments, dict):
input_str = json.dumps(calling.arguments)
else:
input_str = str(calling.arguments)
input_str = json.dumps(calling.arguments)
result = self.tools_handler.cache.read(
tool=sanitize_tool_name(calling.tool_name), input=input_str
@@ -536,7 +530,7 @@ class ToolUsage:
result = self._format_result(result=result)
# Don't return early - fall through to finally block
elif result is None:
try:
try: # type: ignore[unreachable]
if sanitize_tool_name(calling.tool_name) in [
sanitize_tool_name("Delegate work to coworker"),
sanitize_tool_name("Ask question to coworker"),
@@ -826,11 +820,6 @@ class ToolUsage:
raise
return ToolUsageError(f"{I18N_DEFAULT.errors('tool_arguments_error')}")
if not isinstance(arguments, dict):
if raise_error:
raise
return ToolUsageError(f"{I18N_DEFAULT.errors('tool_arguments_error')}")
return ToolCalling(
tool_name=sanitize_tool_name(tool.name),
arguments=arguments,

View File

@@ -1679,7 +1679,7 @@ def _setup_before_llm_call_hooks(
)
if not isinstance(executor_context.messages, list):
if verbose:
if verbose: # type: ignore[unreachable]
printer.print(
content=(
"Warning: before_llm_call hook replaced messages with non-list. "
@@ -1742,7 +1742,7 @@ def _setup_after_llm_call_hooks(
)
if not isinstance(executor_context.messages, list):
if verbose:
if verbose: # type: ignore[unreachable]
printer.print(
content=(
"Warning: after_llm_call hook replaced messages with non-list. "

View File

@@ -161,15 +161,10 @@ def _llm_via_environment_or_fallback() -> LLM | None:
# Map environment variable names to recognized parameters
param_key = _normalize_key_name(key_name.lower())
llm_params[param_key] = env_value
elif isinstance(env_var, dict):
if env_var.get("default", False):
for key, value in env_var.items():
if key not in ["prompt", "key_name", "default"]:
llm_params[key.lower()] = value
else:
logger.debug(
f"Expected env_var to be a dictionary, but got {type(env_var)}"
)
elif env_var.get("default", False):
for key, value in env_var.items():
if key not in ["prompt", "key_name", "default"]:
llm_params[key.lower()] = value
llm_params = {k: v for k, v in llm_params.items() if v is not None}

View File

@@ -141,12 +141,12 @@ def resolve_refs(schema: dict[str, Any]) -> dict[str, Any]:
def add_key_in_dict_recursively(
d: dict[str, Any],
d: Any,
key: str,
value: Any,
criteria: Callable[[dict[str, Any]], bool],
_seen: set[int] | None = None,
) -> dict[str, Any]:
) -> Any:
"""Recursively adds a key/value pair to all nested dicts matching `criteria`.
Args:
@@ -338,9 +338,6 @@ def add_const_to_oneof_variants(schema: dict[str, Any]) -> dict[str, Any]:
def _process_oneof(node: dict[str, Any]) -> dict[str, Any]:
"""Process a single node that might contain a oneOf with discriminator."""
if not isinstance(node, dict):
return node
if "oneOf" in node and "discriminator" in node:
discriminator = node["discriminator"]
property_name = discriminator.get("propertyName")
@@ -606,8 +603,6 @@ def sanitize_tool_params_for_openai_strict(
params: dict[str, Any],
) -> dict[str, Any]:
"""Sanitize a JSON schema for OpenAI strict function calling."""
if not isinstance(params, dict):
return params
return cast(
dict[str, Any], strip_unsupported_formats(_common_strict_pipeline(params))
)
@@ -617,8 +612,6 @@ def sanitize_tool_params_for_anthropic_strict(
params: dict[str, Any],
) -> dict[str, Any]:
"""Sanitize a JSON schema for Anthropic strict tool use."""
if not isinstance(params, dict):
return params
sanitized = lift_top_level_anyof(_common_strict_pipeline(params))
sanitized = _strip_keys_recursive(sanitized, _CLAUDE_STRICT_UNSUPPORTED)
return cast(dict[str, Any], strip_unsupported_formats(sanitized))

View File

@@ -124,8 +124,11 @@ disallow_any_unimported = true
no_implicit_optional = true
check_untyped_defs = true
warn_return_any = true
warn_unreachable = true
show_error_codes = true
warn_unused_ignores = true
local_partial_types = true
extra_checks = true
python_version = "3.12"
exclude = "(?x)(^lib/crewai/src/crewai/cli/templates/|^lib/cli/src/crewai_cli/templates/|^lib/crewai/tests/|^lib/crewai-tools/tests/|^lib/crewai-files/tests/|^lib/cli/tests/|^lib/devtools/tests/)"
plugins = ["pydantic.mypy"]