mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-05-22 17:38:10 +00:00
Compare commits
2 Commits
docs/check
...
chore/mypy
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
948c90c52e | ||
|
|
b7cf1f0148 |
@@ -232,9 +232,6 @@ class A2UIClientExtension:
|
||||
continue
|
||||
|
||||
data = root.data
|
||||
if not isinstance(data, dict):
|
||||
continue
|
||||
|
||||
surface_id = _get_surface_id(data)
|
||||
if not surface_id:
|
||||
continue
|
||||
|
||||
@@ -258,8 +258,6 @@ def validate_catalog_components_v09(message: A2UIMessageV09) -> None:
|
||||
|
||||
errors: list[Any] = []
|
||||
for entry in message.update_components.components:
|
||||
if not isinstance(entry, dict):
|
||||
continue
|
||||
type_name = entry.get("component")
|
||||
if not isinstance(type_name, str):
|
||||
continue
|
||||
|
||||
@@ -178,7 +178,7 @@ class StreamingHandler:
|
||||
is_final=is_final_update,
|
||||
)
|
||||
|
||||
elif isinstance(event, Message):
|
||||
elif isinstance(event, Message): # type: ignore[unreachable]
|
||||
new_messages.append(event)
|
||||
result_parts.extend(
|
||||
part.root.text
|
||||
|
||||
@@ -317,9 +317,7 @@ def get_part_content_type(part: Part) -> str:
|
||||
if mime == APPLICATION_A2UI_JSON:
|
||||
return APPLICATION_A2UI_JSON
|
||||
return APPLICATION_JSON
|
||||
if root.kind == "file":
|
||||
return root.file.mime_type or APPLICATION_OCTET_STREAM
|
||||
return APPLICATION_OCTET_STREAM
|
||||
return root.file.mime_type or APPLICATION_OCTET_STREAM
|
||||
|
||||
|
||||
def validate_message_parts(
|
||||
|
||||
@@ -112,9 +112,6 @@ class BaseConverterAdapter(ABC):
|
||||
Returns:
|
||||
Extracted JSON string if found and valid, otherwise the original result.
|
||||
"""
|
||||
if not isinstance(result, str):
|
||||
return str(result)
|
||||
|
||||
if valid := BaseConverterAdapter._validate_json(result):
|
||||
return valid
|
||||
|
||||
|
||||
@@ -46,8 +46,8 @@ class LangGraphToolAdapter(BaseToolAdapter):
|
||||
else:
|
||||
all_tools = tools
|
||||
for tool in all_tools:
|
||||
if isinstance(tool, LangChainBaseTool):
|
||||
converted_tools.append(tool)
|
||||
if isinstance(tool, LangChainBaseTool): # type: ignore[unreachable]
|
||||
converted_tools.append(tool) # type: ignore[unreachable]
|
||||
continue
|
||||
|
||||
sanitized_name: str = self.sanitize_tool_name(tool.name)
|
||||
|
||||
@@ -4,7 +4,6 @@ This module contains the OpenAIAgentToolAdapter class that converts CrewAI tools
|
||||
to OpenAI Assistant-compatible format using the agents library.
|
||||
"""
|
||||
|
||||
from collections.abc import Awaitable
|
||||
import inspect
|
||||
import json
|
||||
from typing import Any, cast
|
||||
@@ -114,12 +113,8 @@ class OpenAIAgentToolAdapter(BaseToolAdapter):
|
||||
else:
|
||||
args_dict = {param_name: str(arguments)}
|
||||
|
||||
output: Any | Awaitable[Any] = tool._run(**args_dict)
|
||||
|
||||
if inspect.isawaitable(output):
|
||||
result: Any = await output
|
||||
else:
|
||||
result = output
|
||||
output: Any = tool._run(**args_dict)
|
||||
result: Any = await output if inspect.isawaitable(output) else output
|
||||
|
||||
if isinstance(result, (dict, list, str, int, float, bool, type(None))):
|
||||
return result
|
||||
|
||||
@@ -569,9 +569,6 @@ class BaseAgent(BaseModel, ABC, metaclass=AgentMeta):
|
||||
if not self._token_process:
|
||||
self._token_process = TokenProcess()
|
||||
|
||||
if self.security_config is None:
|
||||
self.security_config = SecurityConfig()
|
||||
|
||||
return self
|
||||
|
||||
@field_validator("id", mode="before")
|
||||
@@ -621,7 +618,7 @@ class BaseAgent(BaseModel, ABC, metaclass=AgentMeta):
|
||||
task: Any,
|
||||
context: str | None = None,
|
||||
tools: list[BaseTool] | None = None,
|
||||
) -> str:
|
||||
) -> str | BaseModel:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
@@ -630,7 +627,7 @@ class BaseAgent(BaseModel, ABC, metaclass=AgentMeta):
|
||||
task: Any,
|
||||
context: str | None = None,
|
||||
tools: list[BaseTool] | None = None,
|
||||
) -> str:
|
||||
) -> str | BaseModel:
|
||||
"""Execute a task asynchronously."""
|
||||
|
||||
@abstractmethod
|
||||
|
||||
@@ -334,7 +334,7 @@ class CrewAgentExecutor(BaseAgentExecutor):
|
||||
Returns:
|
||||
Final answer from the agent.
|
||||
"""
|
||||
formatted_answer = None
|
||||
formatted_answer: AgentAction | AgentFinish | None = None
|
||||
while not isinstance(formatted_answer, AgentFinish):
|
||||
try:
|
||||
if has_reached_max_iterations(self.iterations, self.max_iter):
|
||||
@@ -385,12 +385,12 @@ class CrewAgentExecutor(BaseAgentExecutor):
|
||||
)
|
||||
formatted_answer = process_llm_response(
|
||||
answer_str, self.use_stop_words
|
||||
) # type: ignore[assignment]
|
||||
)
|
||||
else:
|
||||
answer_str = str(answer) if not isinstance(answer, str) else answer
|
||||
formatted_answer = process_llm_response(
|
||||
answer_str, self.use_stop_words
|
||||
) # type: ignore[assignment]
|
||||
)
|
||||
|
||||
if isinstance(formatted_answer, AgentAction):
|
||||
fingerprint_context = {}
|
||||
@@ -425,7 +425,7 @@ class CrewAgentExecutor(BaseAgentExecutor):
|
||||
self._append_message(formatted_answer.text)
|
||||
|
||||
except OutputParserError as e:
|
||||
formatted_answer = handle_output_parser_exception( # type: ignore[assignment]
|
||||
formatted_answer = handle_output_parser_exception(
|
||||
e=e,
|
||||
messages=self.messages,
|
||||
iterations=self.iterations,
|
||||
@@ -1145,7 +1145,7 @@ class CrewAgentExecutor(BaseAgentExecutor):
|
||||
Returns:
|
||||
Final answer from the agent.
|
||||
"""
|
||||
formatted_answer = None
|
||||
formatted_answer: AgentAction | AgentFinish | None = None
|
||||
while not isinstance(formatted_answer, AgentFinish):
|
||||
try:
|
||||
if has_reached_max_iterations(self.iterations, self.max_iter):
|
||||
@@ -1197,12 +1197,12 @@ class CrewAgentExecutor(BaseAgentExecutor):
|
||||
)
|
||||
formatted_answer = process_llm_response(
|
||||
answer_str, self.use_stop_words
|
||||
) # type: ignore[assignment]
|
||||
)
|
||||
else:
|
||||
answer_str = str(answer) if not isinstance(answer, str) else answer
|
||||
formatted_answer = process_llm_response(
|
||||
answer_str, self.use_stop_words
|
||||
) # type: ignore[assignment]
|
||||
)
|
||||
|
||||
if isinstance(formatted_answer, AgentAction):
|
||||
fingerprint_context = {}
|
||||
@@ -1237,7 +1237,7 @@ class CrewAgentExecutor(BaseAgentExecutor):
|
||||
self._append_message(formatted_answer.text)
|
||||
|
||||
except OutputParserError as e:
|
||||
formatted_answer = handle_output_parser_exception( # type: ignore[assignment]
|
||||
formatted_answer = handle_output_parser_exception(
|
||||
e=e,
|
||||
messages=self.messages,
|
||||
iterations=self.iterations,
|
||||
|
||||
@@ -308,15 +308,11 @@ class StepExecutor:
|
||||
if isinstance(formatted, AgentFinish):
|
||||
return str(formatted.output)
|
||||
|
||||
if isinstance(formatted, AgentAction):
|
||||
tool_calls_made.append(formatted.tool)
|
||||
tool_result = self._execute_text_tool_with_events(formatted)
|
||||
last_tool_result = tool_result
|
||||
messages.append({"role": "assistant", "content": answer_str})
|
||||
messages.append(self._build_observation_message(tool_result))
|
||||
continue
|
||||
|
||||
return answer_str
|
||||
tool_calls_made.append(formatted.tool)
|
||||
tool_result = self._execute_text_tool_with_events(formatted)
|
||||
last_tool_result = tool_result
|
||||
messages.append({"role": "assistant", "content": answer_str})
|
||||
messages.append(self._build_observation_message(tool_result))
|
||||
|
||||
return last_tool_result
|
||||
|
||||
|
||||
@@ -39,10 +39,7 @@ class ToolsHandler(BaseModel):
|
||||
if self.cache and should_cache and calling.tool_name != CacheTools().name:
|
||||
input_str = ""
|
||||
if calling.arguments:
|
||||
if isinstance(calling.arguments, dict):
|
||||
input_str = json.dumps(calling.arguments)
|
||||
else:
|
||||
input_str = str(calling.arguments)
|
||||
input_str = json.dumps(calling.arguments)
|
||||
|
||||
self.cache.add(
|
||||
tool=calling.tool_name,
|
||||
|
||||
@@ -166,7 +166,7 @@ class CrewAIEventsBus:
|
||||
|
||||
with self._instance_lock:
|
||||
if self._executor_initialized:
|
||||
return
|
||||
return # type: ignore[unreachable]
|
||||
|
||||
self._sync_executor = ThreadPoolExecutor(
|
||||
max_workers=10,
|
||||
@@ -304,7 +304,7 @@ class CrewAIEventsBus:
|
||||
from crewai import RuntimeState
|
||||
|
||||
if RuntimeState is None:
|
||||
logger.warning(
|
||||
logger.warning( # type: ignore[unreachable]
|
||||
"RuntimeState unavailable; skipping entity registration."
|
||||
)
|
||||
return
|
||||
@@ -428,7 +428,7 @@ class CrewAIEventsBus:
|
||||
if cached_plan is None:
|
||||
with self._rwlock.w_locked():
|
||||
if self._shutting_down:
|
||||
return
|
||||
return # type: ignore[unreachable]
|
||||
cached_plan = self._execution_plan_cache.get(event_type)
|
||||
if cached_plan is None:
|
||||
sync_handlers = self._sync_handlers.get(event_type, frozenset())
|
||||
|
||||
@@ -291,7 +291,9 @@ class TraceBatchManager:
|
||||
)
|
||||
|
||||
if response is None:
|
||||
logger.warning("Failed to send trace events. Events will be lost.")
|
||||
logger.warning( # type: ignore[unreachable]
|
||||
"Failed to send trace events. Events will be lost."
|
||||
)
|
||||
return 500
|
||||
|
||||
if response.status_code in [200, 201]:
|
||||
|
||||
@@ -2616,7 +2616,9 @@ class AgentExecutor(Flow[AgentExecutorState], BaseAgentExecutor):
|
||||
with _llm_stop_words_applied(self.llm, self):
|
||||
self.kickoff()
|
||||
|
||||
formatted_answer = self.state.current_answer
|
||||
formatted_answer: AgentAction | AgentFinish | None = (
|
||||
self.state.current_answer
|
||||
)
|
||||
|
||||
if not isinstance(formatted_answer, AgentFinish):
|
||||
raise RuntimeError(
|
||||
@@ -2717,7 +2719,9 @@ class AgentExecutor(Flow[AgentExecutorState], BaseAgentExecutor):
|
||||
with _llm_stop_words_applied(self.llm, self):
|
||||
await self.kickoff_async()
|
||||
|
||||
formatted_answer = self.state.current_answer
|
||||
formatted_answer: AgentAction | AgentFinish | None = (
|
||||
self.state.current_answer
|
||||
)
|
||||
|
||||
if not isinstance(formatted_answer, AgentFinish):
|
||||
raise RuntimeError(
|
||||
|
||||
@@ -962,7 +962,7 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta):
|
||||
|
||||
model_config = ConfigDict(
|
||||
arbitrary_types_allowed=True,
|
||||
ignored_types=(StartMethod, ListenMethod, RouterMethod),
|
||||
ignored_types=(FlowMethod,),
|
||||
revalidate_instances="never",
|
||||
)
|
||||
__hash__ = object.__hash__
|
||||
@@ -3009,8 +3009,6 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta):
|
||||
self._pending_and_listeners.pop(pending_key, None)
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
return False
|
||||
|
||||
def _find_triggered_methods(
|
||||
|
||||
@@ -28,7 +28,7 @@ import inspect
|
||||
import logging
|
||||
import re
|
||||
import textwrap
|
||||
from typing import Any, TypedDict, get_args, get_origin
|
||||
from typing import Any, Literal, TypeAlias, TypedDict, get_args, get_origin
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic_core import PydanticUndefined
|
||||
@@ -44,6 +44,8 @@ from crewai.flow.flow_wrappers import (
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
MethodType: TypeAlias = Literal["start", "listen", "router", "start_router"]
|
||||
|
||||
|
||||
class MethodInfo(TypedDict, total=False):
|
||||
"""Information about a single flow method.
|
||||
@@ -59,7 +61,7 @@ class MethodInfo(TypedDict, total=False):
|
||||
"""
|
||||
|
||||
name: str
|
||||
type: str
|
||||
type: MethodType
|
||||
trigger_methods: list[str]
|
||||
condition_type: str | None
|
||||
router_paths: list[str]
|
||||
@@ -132,7 +134,7 @@ def _get_method_type(
|
||||
method: Any,
|
||||
start_methods: list[str],
|
||||
routers: set[str],
|
||||
) -> str:
|
||||
) -> MethodType:
|
||||
"""Determine the type of a flow method.
|
||||
|
||||
Args:
|
||||
@@ -191,7 +193,6 @@ def _detect_crew_reference(method: Any) -> bool:
|
||||
True if crew reference detected, False otherwise.
|
||||
"""
|
||||
try:
|
||||
# Get the underlying function from wrapper
|
||||
func = method
|
||||
if hasattr(method, "_meth"):
|
||||
func = method._meth
|
||||
@@ -201,12 +202,11 @@ def _detect_crew_reference(method: Any) -> bool:
|
||||
source = inspect.getsource(func)
|
||||
source = textwrap.dedent(source)
|
||||
|
||||
# Patterns that indicate Crew usage
|
||||
crew_patterns = [
|
||||
r"\.crew\(\)", # .crew() method call
|
||||
r"Crew\s*\(", # Crew( instantiation
|
||||
r":\s*Crew\b", # Type hint with Crew
|
||||
r"->.*Crew", # Return type hint with Crew
|
||||
r"\.crew\(\)",
|
||||
r"Crew\s*\(",
|
||||
r":\s*Crew\b",
|
||||
r"->.*Crew",
|
||||
]
|
||||
|
||||
for pattern in crew_patterns:
|
||||
@@ -215,7 +215,6 @@ def _detect_crew_reference(method: Any) -> bool:
|
||||
|
||||
return False
|
||||
except (OSError, TypeError):
|
||||
# Can't get source code - assume no crew reference
|
||||
return False
|
||||
|
||||
|
||||
@@ -231,11 +230,9 @@ def _extract_trigger_methods(method: Any) -> tuple[list[str], str | None]:
|
||||
trigger_methods: list[str] = []
|
||||
condition_type: str | None = None
|
||||
|
||||
# First try __trigger_methods__ (populated for simple conditions)
|
||||
if hasattr(method, "__trigger_methods__") and method.__trigger_methods__:
|
||||
trigger_methods = [str(m) for m in method.__trigger_methods__]
|
||||
|
||||
# For complex conditions (or_/and_ combinators), extract from __trigger_condition__
|
||||
if (
|
||||
not trigger_methods
|
||||
and hasattr(method, "__trigger_condition__")
|
||||
@@ -264,11 +261,9 @@ def _extract_router_paths(
|
||||
"""
|
||||
method_name = getattr(method, "__name__", "")
|
||||
|
||||
# First check if there are __router_paths__ on the method itself
|
||||
if hasattr(method, "__router_paths__") and method.__router_paths__:
|
||||
return [str(p) for p in method.__router_paths__]
|
||||
|
||||
# Then check the class-level registry
|
||||
if method_name in router_paths_registry:
|
||||
return [str(p) for p in router_paths_registry[method_name]]
|
||||
|
||||
@@ -276,39 +271,15 @@ def _extract_router_paths(
|
||||
|
||||
|
||||
def _extract_all_methods_from_condition(
|
||||
condition: str | FlowCondition | dict[str, Any] | list[Any],
|
||||
condition: str | FlowCondition,
|
||||
) -> list[str]:
|
||||
"""Extract all method names from a condition tree recursively.
|
||||
|
||||
Args:
|
||||
condition: Can be a string, FlowCondition tuple, dict, or list.
|
||||
|
||||
Returns:
|
||||
List of all method names found in the condition.
|
||||
"""
|
||||
"""Extract all method names from a condition tree recursively."""
|
||||
if isinstance(condition, str):
|
||||
return [condition]
|
||||
if isinstance(condition, tuple) and len(condition) == 2:
|
||||
# FlowCondition: (condition_type, methods_list)
|
||||
_, methods = condition
|
||||
if isinstance(methods, list):
|
||||
result: list[str] = []
|
||||
for m in methods:
|
||||
result.extend(_extract_all_methods_from_condition(m))
|
||||
return result
|
||||
return []
|
||||
if isinstance(condition, dict):
|
||||
conditions_list = condition.get("conditions", [])
|
||||
dict_methods: list[str] = []
|
||||
for sub_cond in conditions_list:
|
||||
dict_methods.extend(_extract_all_methods_from_condition(sub_cond))
|
||||
return dict_methods
|
||||
if isinstance(condition, list):
|
||||
list_methods: list[str] = []
|
||||
for item in condition:
|
||||
list_methods.extend(_extract_all_methods_from_condition(item))
|
||||
return list_methods
|
||||
return []
|
||||
methods: list[str] = []
|
||||
for sub_cond in condition.get("conditions", []):
|
||||
methods.extend(_extract_all_methods_from_condition(sub_cond))
|
||||
return methods
|
||||
|
||||
|
||||
def _generate_edges(
|
||||
@@ -330,7 +301,6 @@ def _generate_edges(
|
||||
"""
|
||||
edges: list[EdgeInfo] = []
|
||||
|
||||
# Generate edges from listeners (listen edges)
|
||||
for listener_name, condition_data in listeners.items():
|
||||
trigger_methods: list[str] = []
|
||||
|
||||
@@ -340,7 +310,6 @@ def _generate_edges(
|
||||
elif isinstance(condition_data, dict):
|
||||
trigger_methods = _extract_all_methods_from_condition(condition_data)
|
||||
|
||||
# Create edges from each trigger to the listener
|
||||
edges.extend(
|
||||
EdgeInfo(
|
||||
from_method=trigger,
|
||||
@@ -352,10 +321,8 @@ def _generate_edges(
|
||||
if trigger in all_methods
|
||||
)
|
||||
|
||||
# Generate edges from routers (route edges)
|
||||
for router_name, paths in router_paths.items():
|
||||
for path in paths:
|
||||
# Find listeners that listen to this path
|
||||
for listener_name, condition_data in listeners.items():
|
||||
path_triggers: list[str] = []
|
||||
|
||||
@@ -393,11 +360,9 @@ def _extract_state_schema(flow_class: type) -> StateSchemaInfo | None:
|
||||
"""
|
||||
state_type: type | None = None
|
||||
|
||||
# Check for _initial_state_t set by __class_getitem__
|
||||
if hasattr(flow_class, "_initial_state_t"):
|
||||
state_type = flow_class._initial_state_t
|
||||
|
||||
# Check initial_state class attribute
|
||||
if state_type is None and hasattr(flow_class, "initial_state"):
|
||||
initial_state = flow_class.initial_state
|
||||
if isinstance(initial_state, type) and issubclass(initial_state, BaseModel):
|
||||
@@ -405,7 +370,6 @@ def _extract_state_schema(flow_class: type) -> StateSchemaInfo | None:
|
||||
elif isinstance(initial_state, BaseModel):
|
||||
state_type = type(initial_state)
|
||||
|
||||
# Check __orig_bases__ for generic parameters
|
||||
if state_type is None and hasattr(flow_class, "__orig_bases__"):
|
||||
for base in flow_class.__orig_bases__:
|
||||
origin = get_origin(base)
|
||||
@@ -420,7 +384,6 @@ def _extract_state_schema(flow_class: type) -> StateSchemaInfo | None:
|
||||
if state_type is None or not issubclass(state_type, BaseModel):
|
||||
return None
|
||||
|
||||
# Extract fields from the Pydantic model
|
||||
fields: list[StateFieldInfo] = []
|
||||
try:
|
||||
model_fields = state_type.model_fields
|
||||
@@ -428,7 +391,6 @@ def _extract_state_schema(flow_class: type) -> StateSchemaInfo | None:
|
||||
field_type_str = "Any"
|
||||
if field_info.annotation is not None:
|
||||
field_type_str = str(field_info.annotation)
|
||||
# Clean up the type string
|
||||
field_type_str = field_type_str.replace("typing.", "")
|
||||
field_type_str = field_type_str.replace("<class '", "").replace(
|
||||
"'>", ""
|
||||
@@ -441,7 +403,6 @@ def _extract_state_schema(flow_class: type) -> StateSchemaInfo | None:
|
||||
and not callable(field_info.default)
|
||||
):
|
||||
try:
|
||||
# Try to serialize the default value
|
||||
default_value = field_info.default
|
||||
except Exception:
|
||||
default_value = str(field_info.default)
|
||||
@@ -474,7 +435,6 @@ def _detect_flow_inputs(flow_class: type) -> list[str]:
|
||||
"""
|
||||
inputs: list[str] = []
|
||||
|
||||
# Check for inputs in __init__ signature beyond standard Flow params
|
||||
try:
|
||||
init_method = flow_class.__init__ # type: ignore[misc]
|
||||
init_sig = inspect.signature(init_method)
|
||||
@@ -533,7 +493,6 @@ def flow_structure(flow_class: type) -> FlowStructureInfo:
|
||||
f"Got {type(flow_class).__name__}"
|
||||
)
|
||||
|
||||
# Get class-level metadata set by FlowMeta
|
||||
start_methods: list[str] = getattr(flow_class, "_start_methods", [])
|
||||
listeners: dict[str, Any] = getattr(flow_class, "_listeners", {})
|
||||
routers: set[str] = getattr(flow_class, "_routers", set())
|
||||
@@ -541,7 +500,6 @@ def flow_structure(flow_class: type) -> FlowStructureInfo:
|
||||
flow_class, "_router_paths", {}
|
||||
)
|
||||
|
||||
# Collect all flow methods
|
||||
methods: list[MethodInfo] = []
|
||||
all_method_names: set[str] = set()
|
||||
|
||||
@@ -554,7 +512,6 @@ def flow_structure(flow_class: type) -> FlowStructureInfo:
|
||||
except AttributeError:
|
||||
continue
|
||||
|
||||
# Check if it's a flow method
|
||||
is_flow_method = (
|
||||
isinstance(attr, (FlowMethod, StartMethod, ListenMethod, RouterMethod))
|
||||
or hasattr(attr, "__is_flow_method__")
|
||||
@@ -568,21 +525,16 @@ def flow_structure(flow_class: type) -> FlowStructureInfo:
|
||||
|
||||
all_method_names.add(attr_name)
|
||||
|
||||
# Get method type
|
||||
method_type = _get_method_type(attr_name, attr, start_methods, routers)
|
||||
|
||||
# Get trigger methods and condition type
|
||||
trigger_methods, condition_type = _extract_trigger_methods(attr)
|
||||
|
||||
# Get router paths if applicable
|
||||
router_paths_list: list[str] = []
|
||||
if method_type in ("router", "start_router"):
|
||||
router_paths_list = _extract_router_paths(attr, router_paths_registry)
|
||||
|
||||
# Check for human feedback
|
||||
has_hf = _has_human_feedback(attr)
|
||||
|
||||
# Check for crew reference
|
||||
has_crew = _detect_crew_reference(attr)
|
||||
|
||||
method_info = MethodInfo(
|
||||
@@ -596,16 +548,10 @@ def flow_structure(flow_class: type) -> FlowStructureInfo:
|
||||
)
|
||||
methods.append(method_info)
|
||||
|
||||
# Generate edges
|
||||
edges = _generate_edges(listeners, routers, router_paths_registry, all_method_names)
|
||||
|
||||
# Extract state schema
|
||||
state_schema = _extract_state_schema(flow_class)
|
||||
|
||||
# Detect inputs
|
||||
inputs = _detect_flow_inputs(flow_class)
|
||||
|
||||
# Get flow description from docstring
|
||||
description: str | None = None
|
||||
if flow_class.__doc__:
|
||||
description = flow_class.__doc__.strip()
|
||||
|
||||
@@ -46,6 +46,8 @@ class FlowMethod(Generic[P, R]):
|
||||
both bound (instance) and unbound (class) method states.
|
||||
"""
|
||||
|
||||
__is_flow_method__: bool = True
|
||||
|
||||
def __init__(self, meth: Callable[P, R], instance: Any = None) -> None:
|
||||
"""Initialize the flow method wrapper.
|
||||
|
||||
@@ -53,9 +55,9 @@ class FlowMethod(Generic[P, R]):
|
||||
meth: The method to wrap.
|
||||
instance: The instance to bind to (None for unbound).
|
||||
"""
|
||||
functools.update_wrapper(self, meth)
|
||||
self._meth = meth
|
||||
self._instance = instance
|
||||
functools.update_wrapper(self, meth, updated=[])
|
||||
self.__name__: FlowMethodName = FlowMethodName(self.__name__)
|
||||
self.__signature__ = inspect.signature(meth)
|
||||
|
||||
@@ -70,16 +72,6 @@ class FlowMethod(Generic[P, R]):
|
||||
|
||||
self._is_coroutine = asyncio.coroutines._is_coroutine # type: ignore[attr-defined]
|
||||
|
||||
# Preserve flow-related attributes from wrapped method (e.g., from @human_feedback)
|
||||
for attr in [
|
||||
"__is_router__",
|
||||
"__router_paths__",
|
||||
"__human_feedback_config__",
|
||||
"_hf_llm", # Live LLM object for HITL resume
|
||||
]:
|
||||
if hasattr(meth, attr):
|
||||
setattr(self, attr, getattr(meth, attr))
|
||||
|
||||
def __call__(self, *args: P.args, **kwargs: P.kwargs) -> R:
|
||||
"""Call the wrapped method.
|
||||
|
||||
@@ -102,6 +94,19 @@ class FlowMethod(Generic[P, R]):
|
||||
"""
|
||||
return self._meth
|
||||
|
||||
def __getattr__(self, name: str) -> Any:
|
||||
"""Delegate missing attributes to the wrapped method.
|
||||
|
||||
Lets flow flags like ``__is_start_method__`` defined on the wrapped
|
||||
method's class flow through transparently when this wrapper itself
|
||||
wraps another FlowMethod.
|
||||
"""
|
||||
try:
|
||||
meth = object.__getattribute__(self, "_meth")
|
||||
except AttributeError:
|
||||
raise AttributeError(name) from None
|
||||
return getattr(meth, name)
|
||||
|
||||
def __get__(self, instance: Any, owner: type | None = None) -> Self:
|
||||
"""Support the descriptor protocol for method binding.
|
||||
|
||||
|
||||
@@ -118,13 +118,11 @@ def _deserialize_llm_from_context(
|
||||
if isinstance(llm_data, str):
|
||||
return LLM(model=llm_data)
|
||||
|
||||
if isinstance(llm_data, dict):
|
||||
data = dict(llm_data)
|
||||
model = data.pop("model", None)
|
||||
if not model:
|
||||
return None
|
||||
return LLM(model=model, **data)
|
||||
return None
|
||||
data = dict(llm_data)
|
||||
model = data.pop("model", None)
|
||||
if not model:
|
||||
return None
|
||||
return LLM(model=model, **data)
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -706,6 +704,6 @@ def human_feedback(
|
||||
# instead of creating a bare LLM from just the model string.
|
||||
wrapper._hf_llm = llm
|
||||
|
||||
return wrapper # type: ignore[no-any-return]
|
||||
return HumanFeedbackMethod(wrapper) # type: ignore[return-value]
|
||||
|
||||
return decorator
|
||||
|
||||
@@ -24,15 +24,16 @@ Example:
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Awaitable, Callable
|
||||
import functools
|
||||
import inspect
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Any, Final, TypeVar, cast
|
||||
from typing import TYPE_CHECKING, Any, Final, ParamSpec, TypeVar, cast
|
||||
|
||||
from crewai_core.printer import PRINTER
|
||||
from pydantic import BaseModel
|
||||
|
||||
from crewai.flow.flow_wrappers import FlowMethod
|
||||
from crewai.flow.persistence.base import FlowPersistence
|
||||
from crewai.flow.persistence.sqlite import SQLiteFlowPersistence
|
||||
|
||||
@@ -42,6 +43,8 @@ if TYPE_CHECKING:
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
P = ParamSpec("P")
|
||||
R = TypeVar("R")
|
||||
T = TypeVar("T")
|
||||
|
||||
# Constants for log messages
|
||||
@@ -134,9 +137,71 @@ class PersistenceDecorator:
|
||||
raise ValueError(error_msg) from e
|
||||
|
||||
|
||||
class PersistedFlowMethod(FlowMethod[P, R]):
|
||||
"""FlowMethod variant that persists state after each invocation.
|
||||
|
||||
Wrapping the original method directly (rather than copying its attributes
|
||||
onto a closure) lets ``FlowMethod.__getattr__`` delegate flow flags like
|
||||
``__is_start_method__`` to the wrapped object transparently.
|
||||
|
||||
For async wrapped methods, ``R`` is the ``Coroutine`` returned by calling
|
||||
them, so ``__call__``'s declared return type stays accurate in both cases.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
meth: Callable[P, R],
|
||||
instance: Any = None,
|
||||
*,
|
||||
persistence: FlowPersistence | None = None,
|
||||
verbose: bool = False,
|
||||
) -> None:
|
||||
super().__init__(meth, instance)
|
||||
self._persistence = persistence
|
||||
self._verbose = verbose
|
||||
|
||||
def _resolve_flow_instance(self, args: tuple[Any, ...]) -> Any:
|
||||
return (
|
||||
self._instance
|
||||
if self._instance is not None
|
||||
else (args[0] if args else None)
|
||||
)
|
||||
|
||||
def __call__(self, *args: P.args, **kwargs: P.kwargs) -> R:
|
||||
if inspect.iscoroutinefunction(self._meth):
|
||||
return cast(R, self._call_async(*args, **kwargs))
|
||||
flow_instance = self._resolve_flow_instance(args)
|
||||
result = super().__call__(*args, **kwargs)
|
||||
PersistenceDecorator.persist_state(
|
||||
flow_instance,
|
||||
self.__name__,
|
||||
cast(FlowPersistence, self._persistence),
|
||||
self._verbose,
|
||||
)
|
||||
return result
|
||||
|
||||
async def _call_async(self, *args: Any, **kwargs: Any) -> Any:
|
||||
flow_instance = self._resolve_flow_instance(args)
|
||||
meth = cast(Callable[..., Awaitable[Any]], self._meth)
|
||||
if self._instance is not None:
|
||||
result = await meth(self._instance, *args, **kwargs)
|
||||
else:
|
||||
result = await meth(*args, **kwargs)
|
||||
PersistenceDecorator.persist_state(
|
||||
flow_instance,
|
||||
self.__name__,
|
||||
cast(FlowPersistence, self._persistence),
|
||||
self._verbose,
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
def persist(
|
||||
persistence: FlowPersistence | None = None, verbose: bool = False
|
||||
) -> Callable[[type | Callable[..., T]], type | Callable[..., T]]:
|
||||
) -> Callable[
|
||||
[type[Flow[Any]] | Callable[..., T]],
|
||||
type[Flow[Any]] | Callable[..., T],
|
||||
]:
|
||||
"""Decorator to persist flow state.
|
||||
|
||||
This decorator can be applied at either the class level or method level.
|
||||
@@ -164,150 +229,41 @@ def persist(
|
||||
pass
|
||||
"""
|
||||
|
||||
def decorator(target: type | Callable[..., T]) -> type | Callable[..., T]:
|
||||
def decorator(
|
||||
target: type[Flow[Any]] | Callable[..., T],
|
||||
) -> type[Flow[Any]] | Callable[..., T]:
|
||||
"""Decorator that handles both class and method decoration."""
|
||||
actual_persistence = persistence or SQLiteFlowPersistence()
|
||||
|
||||
if isinstance(target, type):
|
||||
# Class decoration
|
||||
original_init = target.__init__ # type: ignore[misc]
|
||||
|
||||
@functools.wraps(original_init)
|
||||
def new_init(self: Any, *args: Any, **kwargs: Any) -> None:
|
||||
def new_init(self: Flow[Any], *args: Any, **kwargs: Any) -> None:
|
||||
if "persistence" not in kwargs:
|
||||
kwargs["persistence"] = actual_persistence
|
||||
original_init(self, *args, **kwargs)
|
||||
|
||||
target.__init__ = new_init # type: ignore[misc]
|
||||
|
||||
# Store original methods to preserve their decorators
|
||||
original_methods = {
|
||||
name: method
|
||||
for name, method in target.__dict__.items()
|
||||
if callable(method)
|
||||
and (
|
||||
hasattr(method, "__is_start_method__")
|
||||
or hasattr(method, "__trigger_methods__")
|
||||
or hasattr(method, "__condition_type__")
|
||||
or hasattr(method, "__is_flow_method__")
|
||||
or hasattr(method, "__is_router__")
|
||||
for name, method in list(target.__dict__.items()):
|
||||
if not isinstance(method, FlowMethod):
|
||||
continue
|
||||
setattr(
|
||||
target,
|
||||
name,
|
||||
PersistedFlowMethod(
|
||||
method, persistence=actual_persistence, verbose=verbose
|
||||
),
|
||||
)
|
||||
}
|
||||
|
||||
# Create wrapped versions of the methods that include persistence
|
||||
for name, method in original_methods.items():
|
||||
if asyncio.iscoroutinefunction(method):
|
||||
# Create a closure to capture the current name and method
|
||||
def create_async_wrapper(
|
||||
method_name: str, original_method: Callable[..., Any]
|
||||
) -> Callable[..., Any]:
|
||||
@functools.wraps(original_method)
|
||||
async def method_wrapper(
|
||||
self: Any, *args: Any, **kwargs: Any
|
||||
) -> Any:
|
||||
result = await original_method(self, *args, **kwargs)
|
||||
PersistenceDecorator.persist_state(
|
||||
self, method_name, actual_persistence, verbose
|
||||
)
|
||||
return result
|
||||
|
||||
return method_wrapper
|
||||
|
||||
wrapped = create_async_wrapper(name, method)
|
||||
|
||||
# Preserve all original decorators and attributes
|
||||
for attr in [
|
||||
"__is_start_method__",
|
||||
"__trigger_methods__",
|
||||
"__condition_type__",
|
||||
"__is_router__",
|
||||
]:
|
||||
if hasattr(method, attr):
|
||||
setattr(wrapped, attr, getattr(method, attr))
|
||||
wrapped.__is_flow_method__ = True # type: ignore[attr-defined]
|
||||
|
||||
# Update the class with the wrapped method
|
||||
setattr(target, name, wrapped)
|
||||
else:
|
||||
# Create a closure to capture the current name and method
|
||||
def create_sync_wrapper(
|
||||
method_name: str, original_method: Callable[..., Any]
|
||||
) -> Callable[..., Any]:
|
||||
@functools.wraps(original_method)
|
||||
def method_wrapper(self: Any, *args: Any, **kwargs: Any) -> Any:
|
||||
result = original_method(self, *args, **kwargs)
|
||||
PersistenceDecorator.persist_state(
|
||||
self, method_name, actual_persistence, verbose
|
||||
)
|
||||
return result
|
||||
|
||||
return method_wrapper
|
||||
|
||||
wrapped = create_sync_wrapper(name, method)
|
||||
|
||||
# Preserve all original decorators and attributes
|
||||
for attr in [
|
||||
"__is_start_method__",
|
||||
"__trigger_methods__",
|
||||
"__condition_type__",
|
||||
"__is_router__",
|
||||
]:
|
||||
if hasattr(method, attr):
|
||||
setattr(wrapped, attr, getattr(method, attr))
|
||||
wrapped.__is_flow_method__ = True # type: ignore[attr-defined]
|
||||
|
||||
# Update the class with the wrapped method
|
||||
setattr(target, name, wrapped)
|
||||
|
||||
return target
|
||||
# Method decoration
|
||||
method = target
|
||||
method.__is_flow_method__ = True # type: ignore[attr-defined]
|
||||
|
||||
if asyncio.iscoroutinefunction(method):
|
||||
|
||||
@functools.wraps(method)
|
||||
async def method_async_wrapper(
|
||||
flow_instance: Any, *args: Any, **kwargs: Any
|
||||
) -> T:
|
||||
method_coro = method(flow_instance, *args, **kwargs)
|
||||
if asyncio.iscoroutine(method_coro):
|
||||
result = await method_coro
|
||||
else:
|
||||
result = method_coro
|
||||
PersistenceDecorator.persist_state(
|
||||
flow_instance, method.__name__, actual_persistence, verbose
|
||||
)
|
||||
return cast(T, result)
|
||||
|
||||
for attr in [
|
||||
"__is_start_method__",
|
||||
"__trigger_methods__",
|
||||
"__condition_type__",
|
||||
"__is_router__",
|
||||
]:
|
||||
if hasattr(method, attr):
|
||||
setattr(method_async_wrapper, attr, getattr(method, attr))
|
||||
method_async_wrapper.__is_flow_method__ = True # type: ignore[attr-defined]
|
||||
return cast(Callable[..., T], method_async_wrapper)
|
||||
|
||||
@functools.wraps(method)
|
||||
def method_sync_wrapper(flow_instance: Any, *args: Any, **kwargs: Any) -> T:
|
||||
result = method(flow_instance, *args, **kwargs)
|
||||
PersistenceDecorator.persist_state(
|
||||
flow_instance, method.__name__, actual_persistence, verbose
|
||||
)
|
||||
return result
|
||||
|
||||
for attr in [
|
||||
"__is_start_method__",
|
||||
"__trigger_methods__",
|
||||
"__condition_type__",
|
||||
"__is_router__",
|
||||
]:
|
||||
if hasattr(method, attr):
|
||||
setattr(method_sync_wrapper, attr, getattr(method, attr))
|
||||
method_sync_wrapper.__is_flow_method__ = True # type: ignore[attr-defined]
|
||||
return cast(Callable[..., T], method_sync_wrapper)
|
||||
return cast(
|
||||
Callable[..., T],
|
||||
PersistedFlowMethod(
|
||||
target, persistence=actual_persistence, verbose=verbose
|
||||
),
|
||||
)
|
||||
|
||||
return decorator
|
||||
|
||||
@@ -26,7 +26,7 @@ if TYPE_CHECKING:
|
||||
|
||||
|
||||
def _extract_direct_or_triggers(
|
||||
condition: str | dict[str, Any] | list[Any] | FlowCondition,
|
||||
condition: str | FlowCondition,
|
||||
) -> list[str]:
|
||||
"""Extract direct OR-level trigger strings from a condition.
|
||||
|
||||
@@ -39,36 +39,22 @@ def _extract_direct_or_triggers(
|
||||
- and_("a", "b") -> [] (neither are direct triggers, both required)
|
||||
- or_(and_("a", "b"), "c") -> ["c"] (only "c" is a direct trigger)
|
||||
|
||||
Args:
|
||||
condition: Can be a string, dict, or list.
|
||||
|
||||
Returns:
|
||||
List of direct OR-level trigger strings.
|
||||
"""
|
||||
if isinstance(condition, str):
|
||||
return [condition]
|
||||
if isinstance(condition, dict):
|
||||
cond_type = condition.get("type", OR_CONDITION)
|
||||
conditions_list = condition.get("conditions", [])
|
||||
|
||||
if cond_type == OR_CONDITION:
|
||||
strings = []
|
||||
for sub_cond in conditions_list:
|
||||
strings.extend(_extract_direct_or_triggers(sub_cond))
|
||||
return strings
|
||||
cond_type = condition.get("type", OR_CONDITION)
|
||||
if cond_type != OR_CONDITION:
|
||||
return []
|
||||
if isinstance(condition, list):
|
||||
strings = []
|
||||
for item in condition:
|
||||
strings.extend(_extract_direct_or_triggers(item))
|
||||
return strings
|
||||
if callable(condition) and hasattr(condition, "__name__"):
|
||||
return [condition.__name__]
|
||||
return []
|
||||
strings: list[str] = []
|
||||
for sub_cond in condition.get("conditions", []):
|
||||
strings.extend(_extract_direct_or_triggers(sub_cond))
|
||||
return strings
|
||||
|
||||
|
||||
def _extract_all_trigger_names(
|
||||
condition: str | dict[str, Any] | list[Any] | FlowCondition,
|
||||
condition: str | FlowCondition,
|
||||
) -> list[str]:
|
||||
"""Extract ALL trigger names from a condition for display purposes.
|
||||
|
||||
@@ -81,50 +67,26 @@ def _extract_all_trigger_names(
|
||||
- and_("a", "b") -> ["a", "b"]
|
||||
- or_(and_("a", method_6), method_4) -> ["a", "method_6", "method_4"]
|
||||
|
||||
Args:
|
||||
condition: Can be a string, dict, or list.
|
||||
|
||||
Returns:
|
||||
List of all trigger names found in the condition.
|
||||
"""
|
||||
if isinstance(condition, str):
|
||||
return [condition]
|
||||
if isinstance(condition, dict):
|
||||
conditions_list = condition.get("conditions", [])
|
||||
strings = []
|
||||
for sub_cond in conditions_list:
|
||||
strings.extend(_extract_all_trigger_names(sub_cond))
|
||||
return strings
|
||||
if isinstance(condition, list):
|
||||
strings = []
|
||||
for item in condition:
|
||||
strings.extend(_extract_all_trigger_names(item))
|
||||
return strings
|
||||
if callable(condition) and hasattr(condition, "__name__"):
|
||||
return [condition.__name__]
|
||||
return []
|
||||
strings: list[str] = []
|
||||
for sub_cond in condition.get("conditions", []):
|
||||
strings.extend(_extract_all_trigger_names(sub_cond))
|
||||
return strings
|
||||
|
||||
|
||||
def _create_edges_from_condition(
|
||||
condition: str | dict[str, Any] | list[Any] | FlowCondition,
|
||||
condition: str | FlowCondition,
|
||||
target: str,
|
||||
nodes: dict[str, NodeMetadata],
|
||||
) -> list[StructureEdge]:
|
||||
"""Create edges from a condition tree, preserving AND/OR semantics.
|
||||
|
||||
This function recursively processes the condition tree and creates edges
|
||||
with the appropriate condition_type for each trigger.
|
||||
|
||||
For AND conditions, all triggers get edges with condition_type="AND".
|
||||
For OR conditions, triggers get edges with condition_type="OR".
|
||||
|
||||
Args:
|
||||
condition: The condition tree (string, dict, or list).
|
||||
target: The target node name.
|
||||
nodes: Dictionary of all nodes for validation.
|
||||
|
||||
Returns:
|
||||
List of StructureEdge objects representing the condition.
|
||||
"""
|
||||
edges: list[StructureEdge] = []
|
||||
|
||||
@@ -138,39 +100,24 @@ def _create_edges_from_condition(
|
||||
is_router_path=False,
|
||||
)
|
||||
)
|
||||
elif callable(condition) and hasattr(condition, "__name__"):
|
||||
method_name = condition.__name__
|
||||
if method_name in nodes:
|
||||
edges.append(
|
||||
StructureEdge(
|
||||
source=method_name,
|
||||
target=target,
|
||||
condition_type=OR_CONDITION,
|
||||
is_router_path=False,
|
||||
)
|
||||
)
|
||||
elif isinstance(condition, dict):
|
||||
cond_type = condition.get("type", OR_CONDITION)
|
||||
conditions_list = condition.get("conditions", [])
|
||||
return edges
|
||||
|
||||
if cond_type == AND_CONDITION:
|
||||
triggers = _extract_all_trigger_names(condition)
|
||||
edges.extend(
|
||||
StructureEdge(
|
||||
source=trigger,
|
||||
target=target,
|
||||
condition_type=AND_CONDITION,
|
||||
is_router_path=False,
|
||||
)
|
||||
for trigger in triggers
|
||||
if trigger in nodes
|
||||
cond_type = condition.get("type", OR_CONDITION)
|
||||
if cond_type == AND_CONDITION:
|
||||
triggers = _extract_all_trigger_names(condition)
|
||||
edges.extend(
|
||||
StructureEdge(
|
||||
source=trigger,
|
||||
target=target,
|
||||
condition_type=AND_CONDITION,
|
||||
is_router_path=False,
|
||||
)
|
||||
else:
|
||||
for sub_cond in conditions_list:
|
||||
edges.extend(_create_edges_from_condition(sub_cond, target, nodes))
|
||||
elif isinstance(condition, list):
|
||||
for item in condition:
|
||||
edges.extend(_create_edges_from_condition(item, target, nodes))
|
||||
for trigger in triggers
|
||||
if trigger in nodes
|
||||
)
|
||||
else:
|
||||
for sub_cond in condition.get("conditions", []):
|
||||
edges.extend(_create_edges_from_condition(sub_cond, target, nodes))
|
||||
|
||||
return edges
|
||||
|
||||
|
||||
@@ -92,8 +92,6 @@ class CrewDoclingSource(BaseKnowledgeSource):
|
||||
raise e
|
||||
|
||||
def add(self) -> None:
|
||||
if self.content is None:
|
||||
return
|
||||
for doc in self.content:
|
||||
new_chunks_iterable = self._chunk_doc(doc)
|
||||
self.chunks.extend(list(new_chunks_iterable))
|
||||
@@ -101,8 +99,6 @@ class CrewDoclingSource(BaseKnowledgeSource):
|
||||
|
||||
async def aadd(self) -> None:
|
||||
"""Add docling content asynchronously."""
|
||||
if self.content is None:
|
||||
return
|
||||
for doc in self.content:
|
||||
new_chunks_iterable = self._chunk_doc(doc)
|
||||
self.chunks.extend(list(new_chunks_iterable))
|
||||
|
||||
@@ -155,11 +155,8 @@ class ExcelKnowledgeSource(BaseKnowledgeSource):
|
||||
# Updated to account for .xlsx workbooks with multiple tabs/sheets
|
||||
content_str = ""
|
||||
for value in self.content.values():
|
||||
if isinstance(value, dict):
|
||||
for sheet_value in value.values():
|
||||
content_str += str(sheet_value) + "\n"
|
||||
else:
|
||||
content_str += str(value) + "\n"
|
||||
for sheet_value in value.values():
|
||||
content_str += str(sheet_value) + "\n"
|
||||
|
||||
new_chunks = self._chunk_text(content_str)
|
||||
self.chunks.extend(new_chunks)
|
||||
@@ -169,11 +166,8 @@ class ExcelKnowledgeSource(BaseKnowledgeSource):
|
||||
"""Add Excel file content asynchronously."""
|
||||
content_str = ""
|
||||
for value in self.content.values():
|
||||
if isinstance(value, dict):
|
||||
for sheet_value in value.values():
|
||||
content_str += str(sheet_value) + "\n"
|
||||
else:
|
||||
content_str += str(value) + "\n"
|
||||
for sheet_value in value.values():
|
||||
content_str += str(sheet_value) + "\n"
|
||||
|
||||
new_chunks = self._chunk_text(content_str)
|
||||
self.chunks.extend(new_chunks)
|
||||
|
||||
@@ -484,10 +484,7 @@ class AnthropicCompletion(BaseLLM):
|
||||
params["tool_choice"] = {"type": "tool", "name": tool_name}
|
||||
|
||||
if self.thinking:
|
||||
if isinstance(self.thinking, AnthropicThinkingConfig):
|
||||
params["thinking"] = self.thinking.model_dump()
|
||||
else:
|
||||
params["thinking"] = self.thinking
|
||||
params["thinking"] = self.thinking.model_dump()
|
||||
|
||||
return params
|
||||
|
||||
|
||||
@@ -582,19 +582,16 @@ class GeminiCompletion(BaseLLM):
|
||||
parts: list[types.Part] = []
|
||||
if isinstance(content, list):
|
||||
for item in content:
|
||||
if isinstance(item, dict):
|
||||
if "text" in item:
|
||||
parts.append(types.Part.from_text(text=str(item["text"])))
|
||||
elif "inlineData" in item:
|
||||
inline = item["inlineData"]
|
||||
parts.append(
|
||||
types.Part.from_bytes(
|
||||
data=base64.b64decode(inline["data"]),
|
||||
mime_type=inline["mimeType"],
|
||||
)
|
||||
if "text" in item:
|
||||
parts.append(types.Part.from_text(text=str(item["text"])))
|
||||
elif "inlineData" in item:
|
||||
inline = item["inlineData"]
|
||||
parts.append(
|
||||
types.Part.from_bytes(
|
||||
data=base64.b64decode(inline["data"]),
|
||||
mime_type=inline["mimeType"],
|
||||
)
|
||||
else:
|
||||
parts.append(types.Part.from_text(text=str(item)))
|
||||
)
|
||||
else:
|
||||
parts.append(types.Part.from_text(text=str(content) if content else ""))
|
||||
|
||||
|
||||
@@ -798,10 +798,7 @@ class OpenAICompletion(BaseLLM):
|
||||
}
|
||||
|
||||
if parameters:
|
||||
if isinstance(parameters, dict):
|
||||
responses_tool["parameters"] = parameters
|
||||
else:
|
||||
responses_tool["parameters"] = dict(parameters)
|
||||
responses_tool["parameters"] = parameters
|
||||
|
||||
responses_tools.append(responses_tool)
|
||||
|
||||
|
||||
@@ -383,22 +383,19 @@ class MCPToolResolver:
|
||||
if mcp_config.tool_filter:
|
||||
filtered_tools = []
|
||||
for tool in tools_list:
|
||||
if callable(mcp_config.tool_filter):
|
||||
try:
|
||||
from crewai.mcp.filters import ToolFilterContext
|
||||
try:
|
||||
from crewai.mcp.filters import ToolFilterContext
|
||||
|
||||
context = ToolFilterContext(
|
||||
agent=self._agent,
|
||||
server_name=server_name,
|
||||
run_context=None,
|
||||
)
|
||||
if mcp_config.tool_filter(context, tool): # type: ignore[call-arg, arg-type]
|
||||
filtered_tools.append(tool)
|
||||
except (TypeError, AttributeError):
|
||||
if mcp_config.tool_filter(tool): # type: ignore[call-arg, arg-type]
|
||||
filtered_tools.append(tool)
|
||||
else:
|
||||
filtered_tools.append(tool)
|
||||
context = ToolFilterContext(
|
||||
agent=self._agent,
|
||||
server_name=server_name,
|
||||
run_context=None,
|
||||
)
|
||||
if mcp_config.tool_filter(context, tool): # type: ignore[call-arg, arg-type]
|
||||
filtered_tools.append(tool)
|
||||
except (TypeError, AttributeError): # noqa: PERF203
|
||||
if mcp_config.tool_filter(tool): # type: ignore[call-arg, arg-type]
|
||||
filtered_tools.append(tool)
|
||||
tools_list = filtered_tools
|
||||
|
||||
if not tools_list:
|
||||
|
||||
@@ -194,9 +194,6 @@ class GoogleGenAIVertexEmbeddingFunction(EmbeddingFunction[Documents]):
|
||||
Returns:
|
||||
List of embedding vectors.
|
||||
"""
|
||||
if isinstance(input, str):
|
||||
input = [input]
|
||||
|
||||
if self._use_legacy:
|
||||
return self._call_legacy(input)
|
||||
return self._call_genai(input)
|
||||
|
||||
@@ -54,9 +54,6 @@ class WatsonXEmbeddingFunction(EmbeddingFunction[Documents]):
|
||||
"Install it with: uv add ibm-watsonx-ai"
|
||||
) from e
|
||||
|
||||
if isinstance(input, str):
|
||||
input = [input]
|
||||
|
||||
embeddings_config: dict[str, Any] = {
|
||||
"model_id": self._config["model_id"],
|
||||
}
|
||||
|
||||
@@ -47,9 +47,6 @@ class VoyageAIEmbeddingFunction(EmbeddingFunction[Documents]):
|
||||
List of embedding vectors.
|
||||
"""
|
||||
|
||||
if isinstance(input, str):
|
||||
input = [input]
|
||||
|
||||
result = self._client.embed(
|
||||
texts=input,
|
||||
model=self._config.get("model", "voyage-2"),
|
||||
|
||||
@@ -47,7 +47,7 @@ def _ensure_handlers_registered() -> None:
|
||||
return
|
||||
with _register_lock:
|
||||
if _handlers_registered:
|
||||
return
|
||||
return # type: ignore[unreachable]
|
||||
_register_all_handlers(crewai_event_bus)
|
||||
_handlers_registered = True
|
||||
|
||||
|
||||
@@ -1159,12 +1159,10 @@ Follow these guidelines:
|
||||
return model_output, None
|
||||
if isinstance(model_output, dict):
|
||||
return None, model_output
|
||||
if isinstance(model_output, str):
|
||||
try:
|
||||
return None, json.loads(model_output)
|
||||
except json.JSONDecodeError:
|
||||
return None, None
|
||||
return None, None
|
||||
try:
|
||||
return None, json.loads(model_output)
|
||||
except json.JSONDecodeError:
|
||||
return None, None
|
||||
|
||||
def _get_output_format(self) -> OutputFormat:
|
||||
if self.output_json:
|
||||
|
||||
@@ -97,7 +97,7 @@ class Telemetry:
|
||||
provider: OpenTelemetry tracer provider.
|
||||
"""
|
||||
|
||||
_instance = None
|
||||
_instance: Self | None = None
|
||||
_lock = threading.Lock()
|
||||
|
||||
def __new__(cls) -> Self:
|
||||
@@ -937,9 +937,6 @@ class Telemetry:
|
||||
value: The attribute value.
|
||||
"""
|
||||
|
||||
if span is None:
|
||||
return
|
||||
|
||||
def _operation() -> None:
|
||||
return span.set_attribute(key, value)
|
||||
|
||||
|
||||
@@ -122,7 +122,8 @@ class BaseAgentTool(BaseTool):
|
||||
logger.debug(
|
||||
f"Created task for agent '{self.sanitize_agent_name(selected_agent.role)}': {task}"
|
||||
)
|
||||
return selected_agent.execute_task(task_with_assigned_agent, context)
|
||||
result = selected_agent.execute_task(task_with_assigned_agent, context)
|
||||
return result if isinstance(result, str) else result.model_dump_json()
|
||||
except Exception as e:
|
||||
# Handle task creation or execution errors
|
||||
return I18N_DEFAULT.errors("agent_tool_execution_error").format(
|
||||
|
||||
@@ -421,28 +421,10 @@ class BaseTool(BaseModel, ABC):
|
||||
)
|
||||
|
||||
def _set_args_schema(self) -> None:
|
||||
if self.args_schema is None:
|
||||
run_sig = signature(self._run)
|
||||
fields: dict[str, Any] = {}
|
||||
"""No-op retained for backward compatibility.
|
||||
|
||||
for param_name, param in run_sig.parameters.items():
|
||||
if param_name in ("self", "return"):
|
||||
continue
|
||||
if param.kind in (Parameter.VAR_POSITIONAL, Parameter.VAR_KEYWORD):
|
||||
continue
|
||||
|
||||
annotation = (
|
||||
param.annotation if param.annotation != param.empty else Any
|
||||
)
|
||||
|
||||
if param.default is param.empty:
|
||||
fields[param_name] = (annotation, ...)
|
||||
else:
|
||||
fields[param_name] = (annotation, param.default)
|
||||
|
||||
self.args_schema = create_model(
|
||||
f"{self.__class__.__name__}Schema", **fields
|
||||
)
|
||||
Schema generation is performed by the ``args_schema`` field validator.
|
||||
"""
|
||||
|
||||
def _generate_description(self) -> None:
|
||||
"""Generate the tool description with a JSON schema for arguments."""
|
||||
|
||||
@@ -274,10 +274,7 @@ class ToolUsage:
|
||||
if self.tools_handler and self.tools_handler.cache:
|
||||
input_str = ""
|
||||
if calling.arguments:
|
||||
if isinstance(calling.arguments, dict):
|
||||
input_str = json.dumps(calling.arguments)
|
||||
else:
|
||||
input_str = str(calling.arguments)
|
||||
input_str = json.dumps(calling.arguments)
|
||||
|
||||
result = self.tools_handler.cache.read(
|
||||
tool=sanitize_tool_name(calling.tool_name), input=input_str
|
||||
@@ -303,7 +300,7 @@ class ToolUsage:
|
||||
result = self._format_result(result=result)
|
||||
# Don't return early - fall through to finally block
|
||||
elif result is None:
|
||||
try:
|
||||
try: # type: ignore[unreachable]
|
||||
if sanitize_tool_name(calling.tool_name) in [
|
||||
sanitize_tool_name("Delegate work to coworker"),
|
||||
sanitize_tool_name("Ask question to coworker"),
|
||||
@@ -507,10 +504,7 @@ class ToolUsage:
|
||||
if self.tools_handler and self.tools_handler.cache:
|
||||
input_str = ""
|
||||
if calling.arguments:
|
||||
if isinstance(calling.arguments, dict):
|
||||
input_str = json.dumps(calling.arguments)
|
||||
else:
|
||||
input_str = str(calling.arguments)
|
||||
input_str = json.dumps(calling.arguments)
|
||||
|
||||
result = self.tools_handler.cache.read(
|
||||
tool=sanitize_tool_name(calling.tool_name), input=input_str
|
||||
@@ -536,7 +530,7 @@ class ToolUsage:
|
||||
result = self._format_result(result=result)
|
||||
# Don't return early - fall through to finally block
|
||||
elif result is None:
|
||||
try:
|
||||
try: # type: ignore[unreachable]
|
||||
if sanitize_tool_name(calling.tool_name) in [
|
||||
sanitize_tool_name("Delegate work to coworker"),
|
||||
sanitize_tool_name("Ask question to coworker"),
|
||||
@@ -826,11 +820,6 @@ class ToolUsage:
|
||||
raise
|
||||
return ToolUsageError(f"{I18N_DEFAULT.errors('tool_arguments_error')}")
|
||||
|
||||
if not isinstance(arguments, dict):
|
||||
if raise_error:
|
||||
raise
|
||||
return ToolUsageError(f"{I18N_DEFAULT.errors('tool_arguments_error')}")
|
||||
|
||||
return ToolCalling(
|
||||
tool_name=sanitize_tool_name(tool.name),
|
||||
arguments=arguments,
|
||||
|
||||
@@ -1679,7 +1679,7 @@ def _setup_before_llm_call_hooks(
|
||||
)
|
||||
|
||||
if not isinstance(executor_context.messages, list):
|
||||
if verbose:
|
||||
if verbose: # type: ignore[unreachable]
|
||||
printer.print(
|
||||
content=(
|
||||
"Warning: before_llm_call hook replaced messages with non-list. "
|
||||
@@ -1742,7 +1742,7 @@ def _setup_after_llm_call_hooks(
|
||||
)
|
||||
|
||||
if not isinstance(executor_context.messages, list):
|
||||
if verbose:
|
||||
if verbose: # type: ignore[unreachable]
|
||||
printer.print(
|
||||
content=(
|
||||
"Warning: after_llm_call hook replaced messages with non-list. "
|
||||
|
||||
@@ -161,15 +161,10 @@ def _llm_via_environment_or_fallback() -> LLM | None:
|
||||
# Map environment variable names to recognized parameters
|
||||
param_key = _normalize_key_name(key_name.lower())
|
||||
llm_params[param_key] = env_value
|
||||
elif isinstance(env_var, dict):
|
||||
if env_var.get("default", False):
|
||||
for key, value in env_var.items():
|
||||
if key not in ["prompt", "key_name", "default"]:
|
||||
llm_params[key.lower()] = value
|
||||
else:
|
||||
logger.debug(
|
||||
f"Expected env_var to be a dictionary, but got {type(env_var)}"
|
||||
)
|
||||
elif env_var.get("default", False):
|
||||
for key, value in env_var.items():
|
||||
if key not in ["prompt", "key_name", "default"]:
|
||||
llm_params[key.lower()] = value
|
||||
|
||||
llm_params = {k: v for k, v in llm_params.items() if v is not None}
|
||||
|
||||
|
||||
@@ -141,12 +141,12 @@ def resolve_refs(schema: dict[str, Any]) -> dict[str, Any]:
|
||||
|
||||
|
||||
def add_key_in_dict_recursively(
|
||||
d: dict[str, Any],
|
||||
d: Any,
|
||||
key: str,
|
||||
value: Any,
|
||||
criteria: Callable[[dict[str, Any]], bool],
|
||||
_seen: set[int] | None = None,
|
||||
) -> dict[str, Any]:
|
||||
) -> Any:
|
||||
"""Recursively adds a key/value pair to all nested dicts matching `criteria`.
|
||||
|
||||
Args:
|
||||
@@ -338,9 +338,6 @@ def add_const_to_oneof_variants(schema: dict[str, Any]) -> dict[str, Any]:
|
||||
|
||||
def _process_oneof(node: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Process a single node that might contain a oneOf with discriminator."""
|
||||
if not isinstance(node, dict):
|
||||
return node
|
||||
|
||||
if "oneOf" in node and "discriminator" in node:
|
||||
discriminator = node["discriminator"]
|
||||
property_name = discriminator.get("propertyName")
|
||||
@@ -606,8 +603,6 @@ def sanitize_tool_params_for_openai_strict(
|
||||
params: dict[str, Any],
|
||||
) -> dict[str, Any]:
|
||||
"""Sanitize a JSON schema for OpenAI strict function calling."""
|
||||
if not isinstance(params, dict):
|
||||
return params
|
||||
return cast(
|
||||
dict[str, Any], strip_unsupported_formats(_common_strict_pipeline(params))
|
||||
)
|
||||
@@ -617,8 +612,6 @@ def sanitize_tool_params_for_anthropic_strict(
|
||||
params: dict[str, Any],
|
||||
) -> dict[str, Any]:
|
||||
"""Sanitize a JSON schema for Anthropic strict tool use."""
|
||||
if not isinstance(params, dict):
|
||||
return params
|
||||
sanitized = lift_top_level_anyof(_common_strict_pipeline(params))
|
||||
sanitized = _strip_keys_recursive(sanitized, _CLAUDE_STRICT_UNSUPPORTED)
|
||||
return cast(dict[str, Any], strip_unsupported_formats(sanitized))
|
||||
|
||||
@@ -124,8 +124,11 @@ disallow_any_unimported = true
|
||||
no_implicit_optional = true
|
||||
check_untyped_defs = true
|
||||
warn_return_any = true
|
||||
warn_unreachable = true
|
||||
show_error_codes = true
|
||||
warn_unused_ignores = true
|
||||
local_partial_types = true
|
||||
extra_checks = true
|
||||
python_version = "3.12"
|
||||
exclude = "(?x)(^lib/crewai/src/crewai/cli/templates/|^lib/cli/src/crewai_cli/templates/|^lib/crewai/tests/|^lib/crewai-tools/tests/|^lib/crewai-files/tests/|^lib/cli/tests/|^lib/devtools/tests/)"
|
||||
plugins = ["pydantic.mypy"]
|
||||
|
||||
Reference in New Issue
Block a user