mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-05-01 07:13:00 +00:00
fix: resolve all mypy errors across crewai package
This commit is contained in:
@@ -5,9 +5,12 @@ from __future__ import annotations
|
|||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
import json
|
import json
|
||||||
import re
|
import re
|
||||||
from typing import TYPE_CHECKING, Any, Final, Literal
|
from typing import TYPE_CHECKING, Final, Literal
|
||||||
|
|
||||||
from crewai.utilities.pydantic_schema_utils import generate_model_description
|
from crewai.utilities.pydantic_schema_utils import (
|
||||||
|
ModelDescription,
|
||||||
|
generate_model_description,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
@@ -41,7 +44,7 @@ class BaseConverterAdapter(ABC):
|
|||||||
"""
|
"""
|
||||||
self.agent_adapter = agent_adapter
|
self.agent_adapter = agent_adapter
|
||||||
self._output_format: Literal["json", "pydantic"] | None = None
|
self._output_format: Literal["json", "pydantic"] | None = None
|
||||||
self._schema: dict[str, Any] | None = None
|
self._schema: ModelDescription | None = None
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def configure_structured_output(self, task: Task) -> None:
|
def configure_structured_output(self, task: Task) -> None:
|
||||||
@@ -128,7 +131,7 @@ class BaseConverterAdapter(ABC):
|
|||||||
@staticmethod
|
@staticmethod
|
||||||
def _configure_format_from_task(
|
def _configure_format_from_task(
|
||||||
task: Task,
|
task: Task,
|
||||||
) -> tuple[Literal["json", "pydantic"] | None, dict[str, Any] | None]:
|
) -> tuple[Literal["json", "pydantic"] | None, ModelDescription | None]:
|
||||||
"""Determine output format and schema from task requirements.
|
"""Determine output format and schema from task requirements.
|
||||||
|
|
||||||
This is a helper method that examines the task's output requirements
|
This is a helper method that examines the task's output requirements
|
||||||
|
|||||||
@@ -64,7 +64,7 @@ class LangGraphAgentAdapter(BaseAgentAdapter):
|
|||||||
llm: Any = None,
|
llm: Any = None,
|
||||||
max_iterations: int = 10,
|
max_iterations: int = 10,
|
||||||
agent_config: dict[str, Any] | None = None,
|
agent_config: dict[str, Any] | None = None,
|
||||||
**kwargs,
|
**kwargs: Any,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Initialize the LangGraph agent adapter.
|
"""Initialize the LangGraph agent adapter.
|
||||||
|
|
||||||
|
|||||||
@@ -948,7 +948,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
|
|||||||
)
|
)
|
||||||
error_event_emitted = False
|
error_event_emitted = False
|
||||||
|
|
||||||
track_delegation_if_needed(func_name, args_dict, self.task)
|
track_delegation_if_needed(func_name, args_dict or {}, self.task)
|
||||||
|
|
||||||
structured_tool: CrewStructuredTool | None = None
|
structured_tool: CrewStructuredTool | None = None
|
||||||
if original_tool is not None:
|
if original_tool is not None:
|
||||||
@@ -965,7 +965,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
|
|||||||
hook_blocked = False
|
hook_blocked = False
|
||||||
before_hook_context = ToolCallHookContext(
|
before_hook_context = ToolCallHookContext(
|
||||||
tool_name=func_name,
|
tool_name=func_name,
|
||||||
tool_input=args_dict,
|
tool_input=args_dict or {},
|
||||||
tool=structured_tool, # type: ignore[arg-type]
|
tool=structured_tool, # type: ignore[arg-type]
|
||||||
agent=self.agent,
|
agent=self.agent,
|
||||||
task=self.task,
|
task=self.task,
|
||||||
@@ -991,7 +991,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
|
|||||||
result = f"Tool '{func_name}' has reached its usage limit of {original_tool.max_usage_count} times and cannot be used anymore."
|
result = f"Tool '{func_name}' has reached its usage limit of {original_tool.max_usage_count} times and cannot be used anymore."
|
||||||
elif not from_cache and func_name in available_functions:
|
elif not from_cache and func_name in available_functions:
|
||||||
try:
|
try:
|
||||||
raw_result = available_functions[func_name](**args_dict)
|
raw_result = available_functions[func_name](**(args_dict or {}))
|
||||||
|
|
||||||
if self.tools_handler and self.tools_handler.cache:
|
if self.tools_handler and self.tools_handler.cache:
|
||||||
should_cache = True
|
should_cache = True
|
||||||
@@ -1001,7 +1001,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
|
|||||||
and callable(original_tool.cache_function)
|
and callable(original_tool.cache_function)
|
||||||
):
|
):
|
||||||
should_cache = original_tool.cache_function(
|
should_cache = original_tool.cache_function(
|
||||||
args_dict, raw_result
|
args_dict or {}, raw_result
|
||||||
)
|
)
|
||||||
if should_cache:
|
if should_cache:
|
||||||
self.tools_handler.cache.add(
|
self.tools_handler.cache.add(
|
||||||
@@ -1030,7 +1030,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
|
|||||||
|
|
||||||
after_hook_context = ToolCallHookContext(
|
after_hook_context = ToolCallHookContext(
|
||||||
tool_name=func_name,
|
tool_name=func_name,
|
||||||
tool_input=args_dict,
|
tool_input=args_dict or {},
|
||||||
tool=structured_tool, # type: ignore[arg-type]
|
tool=structured_tool, # type: ignore[arg-type]
|
||||||
agent=self.agent,
|
agent=self.agent,
|
||||||
task=self.task,
|
task=self.task,
|
||||||
|
|||||||
@@ -77,7 +77,7 @@ CLI_SETTINGS_KEYS = [
|
|||||||
]
|
]
|
||||||
|
|
||||||
# Default values for CLI settings
|
# Default values for CLI settings
|
||||||
DEFAULT_CLI_SETTINGS = {
|
DEFAULT_CLI_SETTINGS: dict[str, Any] = {
|
||||||
"enterprise_base_url": DEFAULT_CREWAI_ENTERPRISE_URL,
|
"enterprise_base_url": DEFAULT_CREWAI_ENTERPRISE_URL,
|
||||||
"oauth2_provider": CREWAI_ENTERPRISE_DEFAULT_OAUTH2_PROVIDER,
|
"oauth2_provider": CREWAI_ENTERPRISE_DEFAULT_OAUTH2_PROVIDER,
|
||||||
"oauth2_audience": CREWAI_ENTERPRISE_DEFAULT_OAUTH2_AUDIENCE,
|
"oauth2_audience": CREWAI_ENTERPRISE_DEFAULT_OAUTH2_AUDIENCE,
|
||||||
|
|||||||
@@ -173,13 +173,13 @@ class MemoryTUI(App[None]):
|
|||||||
info = self._memory.info("/")
|
info = self._memory.info("/")
|
||||||
tree.root.label = f"/ ({info.record_count} records)"
|
tree.root.label = f"/ ({info.record_count} records)"
|
||||||
tree.root.data = "/"
|
tree.root.data = "/"
|
||||||
self._add_children(tree.root, "/", depth=0, max_depth=3)
|
self._add_scope_children(tree.root, "/", depth=0, max_depth=3)
|
||||||
tree.root.expand()
|
tree.root.expand()
|
||||||
return tree
|
return tree
|
||||||
|
|
||||||
def _add_children(
|
def _add_scope_children(
|
||||||
self,
|
self,
|
||||||
parent_node: Tree.Node[str],
|
parent_node: Any,
|
||||||
path: str,
|
path: str,
|
||||||
depth: int,
|
depth: int,
|
||||||
max_depth: int,
|
max_depth: int,
|
||||||
@@ -191,7 +191,7 @@ class MemoryTUI(App[None]):
|
|||||||
child_info = self._memory.info(child)
|
child_info = self._memory.info(child)
|
||||||
label = f"{child} ({child_info.record_count})"
|
label = f"{child} ({child_info.record_count})"
|
||||||
node = parent_node.add(label, data=child)
|
node = parent_node.add(label, data=child)
|
||||||
self._add_children(node, child, depth + 1, max_depth)
|
self._add_scope_children(node, child, depth + 1, max_depth)
|
||||||
|
|
||||||
# -- Populating the OptionList -------------------------------------------
|
# -- Populating the OptionList -------------------------------------------
|
||||||
|
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
import subprocess
|
import subprocess
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
import click
|
import click
|
||||||
|
|
||||||
@@ -6,7 +7,7 @@ from crewai.cli.utils import get_crews, get_flows
|
|||||||
from crewai.flow import Flow
|
from crewai.flow import Flow
|
||||||
|
|
||||||
|
|
||||||
def _reset_flow_memory(flow: Flow) -> None:
|
def _reset_flow_memory(flow: Flow[Any]) -> None:
|
||||||
"""Reset memory for a single flow instance.
|
"""Reset memory for a single flow instance.
|
||||||
|
|
||||||
Handles Memory, MemoryScope (both have .reset()), and MemorySlice
|
Handles Memory, MemoryScope (both have .reset()), and MemorySlice
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
import os
|
import os
|
||||||
import shutil
|
import shutil
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
import tomli_w
|
import tomli_w
|
||||||
|
|
||||||
@@ -11,7 +12,7 @@ def update_crew() -> None:
|
|||||||
migrate_pyproject("pyproject.toml", "pyproject.toml")
|
migrate_pyproject("pyproject.toml", "pyproject.toml")
|
||||||
|
|
||||||
|
|
||||||
def migrate_pyproject(input_file, output_file):
|
def migrate_pyproject(input_file: str, output_file: str) -> None:
|
||||||
"""
|
"""
|
||||||
Migrate the pyproject.toml to the new format.
|
Migrate the pyproject.toml to the new format.
|
||||||
|
|
||||||
@@ -23,8 +24,7 @@ def migrate_pyproject(input_file, output_file):
|
|||||||
# Read the input pyproject.toml
|
# Read the input pyproject.toml
|
||||||
pyproject_data = read_toml()
|
pyproject_data = read_toml()
|
||||||
|
|
||||||
# Initialize the new project structure
|
new_pyproject: dict[str, Any] = {
|
||||||
new_pyproject = {
|
|
||||||
"project": {},
|
"project": {},
|
||||||
"build-system": {"requires": ["hatchling"], "build-backend": "hatchling.build"},
|
"build-system": {"requires": ["hatchling"], "build-backend": "hatchling.build"},
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -386,7 +386,7 @@ def fetch_crews(module_attr: Any) -> list[Crew]:
|
|||||||
return crew_instances
|
return crew_instances
|
||||||
|
|
||||||
|
|
||||||
def get_flow_instance(module_attr: Any) -> Flow | None:
|
def get_flow_instance(module_attr: Any) -> Flow[Any] | None:
|
||||||
"""Check if a module attribute is a user-defined Flow subclass and return an instance.
|
"""Check if a module attribute is a user-defined Flow subclass and return an instance.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -413,7 +413,7 @@ _SKIP_DIRS = frozenset(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def get_flows(flow_path: str = "main.py") -> list[Flow]:
|
def get_flows(flow_path: str = "main.py") -> list[Flow[Any]]:
|
||||||
"""Get the flow instances from project files.
|
"""Get the flow instances from project files.
|
||||||
|
|
||||||
Walks the project directory looking for files matching ``flow_path``
|
Walks the project directory looking for files matching ``flow_path``
|
||||||
@@ -427,7 +427,7 @@ def get_flows(flow_path: str = "main.py") -> list[Flow]:
|
|||||||
Returns:
|
Returns:
|
||||||
A list of discovered Flow instances.
|
A list of discovered Flow instances.
|
||||||
"""
|
"""
|
||||||
flow_instances: list[Flow] = []
|
flow_instances: list[Flow[Any]] = []
|
||||||
try:
|
try:
|
||||||
current_dir = os.getcwd()
|
current_dir = os.getcwd()
|
||||||
if current_dir not in sys.path:
|
if current_dir not in sys.path:
|
||||||
|
|||||||
@@ -45,14 +45,14 @@ class CrewOutput(BaseModel):
|
|||||||
output_dict.update(self.pydantic.model_dump())
|
output_dict.update(self.pydantic.model_dump())
|
||||||
return output_dict
|
return output_dict
|
||||||
|
|
||||||
def __getitem__(self, key):
|
def __getitem__(self, key: str) -> Any:
|
||||||
if self.pydantic and hasattr(self.pydantic, key):
|
if self.pydantic and hasattr(self.pydantic, key):
|
||||||
return getattr(self.pydantic, key)
|
return getattr(self.pydantic, key)
|
||||||
if self.json_dict and key in self.json_dict:
|
if self.json_dict and key in self.json_dict:
|
||||||
return self.json_dict[key]
|
return self.json_dict[key]
|
||||||
raise KeyError(f"Key '{key}' not found in CrewOutput.")
|
raise KeyError(f"Key '{key}' not found in CrewOutput.")
|
||||||
|
|
||||||
def __str__(self):
|
def __str__(self) -> str:
|
||||||
if self.pydantic:
|
if self.pydantic:
|
||||||
return str(self.pydantic)
|
return str(self.pydantic)
|
||||||
if self.json_dict:
|
if self.json_dict:
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ handlers execute in correct order while maximizing parallelism.
|
|||||||
|
|
||||||
from collections import defaultdict, deque
|
from collections import defaultdict, deque
|
||||||
from collections.abc import Sequence
|
from collections.abc import Sequence
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
from crewai.events.depends import Depends
|
from crewai.events.depends import Depends
|
||||||
from crewai.events.types.event_bus_types import ExecutionPlan, Handler
|
from crewai.events.types.event_bus_types import ExecutionPlan, Handler
|
||||||
@@ -45,7 +46,7 @@ class HandlerGraph:
|
|||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
handlers: dict[Handler, list[Depends]],
|
handlers: dict[Handler, list[Depends[Any]]],
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Initialize the dependency graph.
|
"""Initialize the dependency graph.
|
||||||
|
|
||||||
@@ -103,7 +104,7 @@ class HandlerGraph:
|
|||||||
|
|
||||||
def build_execution_plan(
|
def build_execution_plan(
|
||||||
handlers: Sequence[Handler],
|
handlers: Sequence[Handler],
|
||||||
dependencies: dict[Handler, list[Depends]],
|
dependencies: dict[Handler, list[Depends[Any]]],
|
||||||
) -> ExecutionPlan:
|
) -> ExecutionPlan:
|
||||||
"""Build an execution plan from handlers and their dependencies.
|
"""Build an execution plan from handlers and their dependencies.
|
||||||
|
|
||||||
@@ -118,7 +119,7 @@ def build_execution_plan(
|
|||||||
Raises:
|
Raises:
|
||||||
CircularDependencyError: If circular dependencies are detected
|
CircularDependencyError: If circular dependencies are detected
|
||||||
"""
|
"""
|
||||||
handler_dict: dict[Handler, list[Depends]] = {
|
handler_dict: dict[Handler, list[Depends[Any]]] = {
|
||||||
h: dependencies.get(h, []) for h in handlers
|
h: dependencies.get(h, []) for h in handlers
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -65,9 +65,9 @@ class FirstTimeTraceHandler:
|
|||||||
self._gracefully_fail(f"Error in trace handling: {e}")
|
self._gracefully_fail(f"Error in trace handling: {e}")
|
||||||
mark_first_execution_completed(user_consented=False)
|
mark_first_execution_completed(user_consented=False)
|
||||||
|
|
||||||
def _initialize_backend_and_send_events(self):
|
def _initialize_backend_and_send_events(self) -> None:
|
||||||
"""Initialize backend batch and send collected events."""
|
"""Initialize backend batch and send collected events."""
|
||||||
if not self.batch_manager:
|
if not self.batch_manager or not self.batch_manager.trace_batch_id:
|
||||||
return
|
return
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -115,12 +115,13 @@ class FirstTimeTraceHandler:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
self._gracefully_fail(f"Backend initialization failed: {e}")
|
self._gracefully_fail(f"Backend initialization failed: {e}")
|
||||||
|
|
||||||
def _display_ephemeral_trace_link(self):
|
def _display_ephemeral_trace_link(self) -> None:
|
||||||
"""Display the ephemeral trace link to the user and automatically open browser."""
|
"""Display the ephemeral trace link to the user and automatically open browser."""
|
||||||
console = Console()
|
console = Console()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
webbrowser.open(self.ephemeral_url)
|
if self.ephemeral_url:
|
||||||
|
webbrowser.open(self.ephemeral_url)
|
||||||
except Exception: # noqa: S110
|
except Exception: # noqa: S110
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@@ -158,7 +159,7 @@ To disable tracing later, do any one of these:
|
|||||||
console.print(panel)
|
console.print(panel)
|
||||||
console.print()
|
console.print()
|
||||||
|
|
||||||
def _show_tracing_declined_message(self):
|
def _show_tracing_declined_message(self) -> None:
|
||||||
"""Show message when user declines tracing."""
|
"""Show message when user declines tracing."""
|
||||||
console = Console()
|
console = Console()
|
||||||
|
|
||||||
@@ -184,15 +185,18 @@ To enable tracing later, do any one of these:
|
|||||||
console.print(panel)
|
console.print(panel)
|
||||||
console.print()
|
console.print()
|
||||||
|
|
||||||
def _gracefully_fail(self, error_message: str):
|
def _gracefully_fail(self, error_message: str) -> None:
|
||||||
"""Handle errors gracefully without disrupting user experience."""
|
"""Handle errors gracefully without disrupting user experience."""
|
||||||
console = Console()
|
console = Console()
|
||||||
console.print(f"[yellow]Note: {error_message}[/yellow]")
|
console.print(f"[yellow]Note: {error_message}[/yellow]")
|
||||||
|
|
||||||
logger.debug(f"First-time trace error: {error_message}")
|
logger.debug(f"First-time trace error: {error_message}")
|
||||||
|
|
||||||
def _show_local_trace_message(self):
|
def _show_local_trace_message(self) -> None:
|
||||||
"""Show message when traces were collected locally but couldn't be uploaded."""
|
"""Show message when traces were collected locally but couldn't be uploaded."""
|
||||||
|
if self.batch_manager is None:
|
||||||
|
return
|
||||||
|
|
||||||
console = Console()
|
console = Console()
|
||||||
|
|
||||||
panel_content = f"""
|
panel_content = f"""
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ from collections.abc import Sequence
|
|||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from pydantic import ConfigDict, model_validator
|
from pydantic import ConfigDict, model_validator
|
||||||
|
from typing_extensions import Self
|
||||||
|
|
||||||
from crewai.agents.agent_builder.base_agent import BaseAgent
|
from crewai.agents.agent_builder.base_agent import BaseAgent
|
||||||
from crewai.events.base_events import BaseEvent
|
from crewai.events.base_events import BaseEvent
|
||||||
@@ -25,16 +26,9 @@ class AgentExecutionStartedEvent(BaseEvent):
|
|||||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||||
|
|
||||||
@model_validator(mode="after")
|
@model_validator(mode="after")
|
||||||
def set_fingerprint_data(self):
|
def set_fingerprint_data(self) -> Self:
|
||||||
"""Set fingerprint data from the agent if available."""
|
"""Set fingerprint data from the agent if available."""
|
||||||
if hasattr(self.agent, "fingerprint") and self.agent.fingerprint:
|
_set_agent_fingerprint(self, self.agent)
|
||||||
self.source_fingerprint = self.agent.fingerprint.uuid_str
|
|
||||||
self.source_type = "agent"
|
|
||||||
if (
|
|
||||||
hasattr(self.agent.fingerprint, "metadata")
|
|
||||||
and self.agent.fingerprint.metadata
|
|
||||||
):
|
|
||||||
self.fingerprint_metadata = self.agent.fingerprint.metadata
|
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
|
||||||
@@ -49,16 +43,9 @@ class AgentExecutionCompletedEvent(BaseEvent):
|
|||||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||||
|
|
||||||
@model_validator(mode="after")
|
@model_validator(mode="after")
|
||||||
def set_fingerprint_data(self):
|
def set_fingerprint_data(self) -> Self:
|
||||||
"""Set fingerprint data from the agent if available."""
|
"""Set fingerprint data from the agent if available."""
|
||||||
if hasattr(self.agent, "fingerprint") and self.agent.fingerprint:
|
_set_agent_fingerprint(self, self.agent)
|
||||||
self.source_fingerprint = self.agent.fingerprint.uuid_str
|
|
||||||
self.source_type = "agent"
|
|
||||||
if (
|
|
||||||
hasattr(self.agent.fingerprint, "metadata")
|
|
||||||
and self.agent.fingerprint.metadata
|
|
||||||
):
|
|
||||||
self.fingerprint_metadata = self.agent.fingerprint.metadata
|
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
|
||||||
@@ -73,16 +60,9 @@ class AgentExecutionErrorEvent(BaseEvent):
|
|||||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||||
|
|
||||||
@model_validator(mode="after")
|
@model_validator(mode="after")
|
||||||
def set_fingerprint_data(self):
|
def set_fingerprint_data(self) -> Self:
|
||||||
"""Set fingerprint data from the agent if available."""
|
"""Set fingerprint data from the agent if available."""
|
||||||
if hasattr(self.agent, "fingerprint") and self.agent.fingerprint:
|
_set_agent_fingerprint(self, self.agent)
|
||||||
self.source_fingerprint = self.agent.fingerprint.uuid_str
|
|
||||||
self.source_type = "agent"
|
|
||||||
if (
|
|
||||||
hasattr(self.agent.fingerprint, "metadata")
|
|
||||||
and self.agent.fingerprint.metadata
|
|
||||||
):
|
|
||||||
self.fingerprint_metadata = self.agent.fingerprint.metadata
|
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
|
||||||
@@ -140,3 +120,13 @@ class AgentEvaluationFailedEvent(BaseEvent):
|
|||||||
iteration: int
|
iteration: int
|
||||||
error: str
|
error: str
|
||||||
type: str = "agent_evaluation_failed"
|
type: str = "agent_evaluation_failed"
|
||||||
|
|
||||||
|
|
||||||
|
def _set_agent_fingerprint(event: BaseEvent, agent: BaseAgent) -> None:
|
||||||
|
"""Set fingerprint data on an event from an agent object."""
|
||||||
|
fp = agent.security_config.fingerprint
|
||||||
|
if fp is not None:
|
||||||
|
event.source_fingerprint = fp.uuid_str
|
||||||
|
event.source_type = "agent"
|
||||||
|
if fp.metadata:
|
||||||
|
event.fingerprint_metadata = fp.metadata
|
||||||
|
|||||||
@@ -15,21 +15,18 @@ class CrewBaseEvent(BaseEvent):
|
|||||||
crew_name: str | None
|
crew_name: str | None
|
||||||
crew: Crew | None = None
|
crew: Crew | None = None
|
||||||
|
|
||||||
def __init__(self, **data):
|
def __init__(self, **data: Any) -> None:
|
||||||
super().__init__(**data)
|
super().__init__(**data)
|
||||||
self.set_crew_fingerprint()
|
self._set_crew_fingerprint()
|
||||||
|
|
||||||
def set_crew_fingerprint(self) -> None:
|
def _set_crew_fingerprint(self) -> None:
|
||||||
if self.crew and hasattr(self.crew, "fingerprint") and self.crew.fingerprint:
|
if self.crew is not None and self.crew.fingerprint:
|
||||||
self.source_fingerprint = self.crew.fingerprint.uuid_str
|
self.source_fingerprint = self.crew.fingerprint.uuid_str
|
||||||
self.source_type = "crew"
|
self.source_type = "crew"
|
||||||
if (
|
if self.crew.fingerprint.metadata:
|
||||||
hasattr(self.crew.fingerprint, "metadata")
|
|
||||||
and self.crew.fingerprint.metadata
|
|
||||||
):
|
|
||||||
self.fingerprint_metadata = self.crew.fingerprint.metadata
|
self.fingerprint_metadata = self.crew.fingerprint.metadata
|
||||||
|
|
||||||
def to_json(self, exclude: set[str] | None = None):
|
def to_json(self, exclude: set[str] | None = None) -> Any:
|
||||||
if exclude is None:
|
if exclude is None:
|
||||||
exclude = set()
|
exclude = set()
|
||||||
exclude.add("crew")
|
exclude.add("crew")
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ class KnowledgeEventBase(BaseEvent):
|
|||||||
agent_role: str | None = None
|
agent_role: str | None = None
|
||||||
agent_id: str | None = None
|
agent_id: str | None = None
|
||||||
|
|
||||||
def __init__(self, **data):
|
def __init__(self, **data: Any) -> None:
|
||||||
super().__init__(**data)
|
super().__init__(**data)
|
||||||
self._set_agent_params(data)
|
self._set_agent_params(data)
|
||||||
self._set_task_params(data)
|
self._set_task_params(data)
|
||||||
|
|||||||
@@ -13,7 +13,7 @@ class LLMGuardrailBaseEvent(BaseEvent):
|
|||||||
agent_role: str | None = None
|
agent_role: str | None = None
|
||||||
agent_id: str | None = None
|
agent_id: str | None = None
|
||||||
|
|
||||||
def __init__(self, **data):
|
def __init__(self, **data: Any) -> None:
|
||||||
super().__init__(**data)
|
super().__init__(**data)
|
||||||
self._set_agent_params(data)
|
self._set_agent_params(data)
|
||||||
self._set_task_params(data)
|
self._set_task_params(data)
|
||||||
@@ -28,10 +28,10 @@ class LLMGuardrailStartedEvent(LLMGuardrailBaseEvent):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
type: str = "llm_guardrail_started"
|
type: str = "llm_guardrail_started"
|
||||||
guardrail: str | Callable
|
guardrail: str | Callable[..., Any]
|
||||||
retry_count: int
|
retry_count: int
|
||||||
|
|
||||||
def __init__(self, **data):
|
def __init__(self, **data: Any) -> None:
|
||||||
from crewai.tasks.hallucination_guardrail import HallucinationGuardrail
|
from crewai.tasks.hallucination_guardrail import HallucinationGuardrail
|
||||||
from crewai.tasks.llm_guardrail import LLMGuardrail
|
from crewai.tasks.llm_guardrail import LLMGuardrail
|
||||||
|
|
||||||
@@ -39,7 +39,7 @@ class LLMGuardrailStartedEvent(LLMGuardrailBaseEvent):
|
|||||||
|
|
||||||
if isinstance(self.guardrail, (LLMGuardrail, HallucinationGuardrail)):
|
if isinstance(self.guardrail, (LLMGuardrail, HallucinationGuardrail)):
|
||||||
self.guardrail = self.guardrail.description.strip()
|
self.guardrail = self.guardrail.description.strip()
|
||||||
elif isinstance(self.guardrail, Callable):
|
elif callable(self.guardrail):
|
||||||
self.guardrail = getsource(self.guardrail).strip()
|
self.guardrail = getsource(self.guardrail).strip()
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -15,7 +15,7 @@ class MCPEvent(BaseEvent):
|
|||||||
from_agent: Any | None = None
|
from_agent: Any | None = None
|
||||||
from_task: Any | None = None
|
from_task: Any | None = None
|
||||||
|
|
||||||
def __init__(self, **data):
|
def __init__(self, **data: Any) -> None:
|
||||||
super().__init__(**data)
|
super().__init__(**data)
|
||||||
self._set_agent_params(data)
|
self._set_agent_params(data)
|
||||||
self._set_task_params(data)
|
self._set_task_params(data)
|
||||||
|
|||||||
@@ -15,7 +15,7 @@ class ReasoningEvent(BaseEvent):
|
|||||||
agent_id: str | None = None
|
agent_id: str | None = None
|
||||||
from_agent: Any | None = None
|
from_agent: Any | None = None
|
||||||
|
|
||||||
def __init__(self, **data):
|
def __init__(self, **data: Any) -> None:
|
||||||
super().__init__(**data)
|
super().__init__(**data)
|
||||||
self._set_task_params(data)
|
self._set_task_params(data)
|
||||||
self._set_agent_params(data)
|
self._set_agent_params(data)
|
||||||
|
|||||||
@@ -4,6 +4,15 @@ from crewai.events.base_events import BaseEvent
|
|||||||
from crewai.tasks.task_output import TaskOutput
|
from crewai.tasks.task_output import TaskOutput
|
||||||
|
|
||||||
|
|
||||||
|
def _set_task_fingerprint(event: BaseEvent, task: Any) -> None:
|
||||||
|
"""Set fingerprint data on an event from a task object."""
|
||||||
|
if task is not None and task.fingerprint:
|
||||||
|
event.source_fingerprint = task.fingerprint.uuid_str
|
||||||
|
event.source_type = "task"
|
||||||
|
if task.fingerprint.metadata:
|
||||||
|
event.fingerprint_metadata = task.fingerprint.metadata
|
||||||
|
|
||||||
|
|
||||||
class TaskStartedEvent(BaseEvent):
|
class TaskStartedEvent(BaseEvent):
|
||||||
"""Event emitted when a task starts"""
|
"""Event emitted when a task starts"""
|
||||||
|
|
||||||
@@ -11,17 +20,9 @@ class TaskStartedEvent(BaseEvent):
|
|||||||
context: str | None
|
context: str | None
|
||||||
task: Any | None = None
|
task: Any | None = None
|
||||||
|
|
||||||
def __init__(self, **data):
|
def __init__(self, **data: Any) -> None:
|
||||||
super().__init__(**data)
|
super().__init__(**data)
|
||||||
# Set fingerprint data from the task
|
_set_task_fingerprint(self, self.task)
|
||||||
if hasattr(self.task, "fingerprint") and self.task.fingerprint:
|
|
||||||
self.source_fingerprint = self.task.fingerprint.uuid_str
|
|
||||||
self.source_type = "task"
|
|
||||||
if (
|
|
||||||
hasattr(self.task.fingerprint, "metadata")
|
|
||||||
and self.task.fingerprint.metadata
|
|
||||||
):
|
|
||||||
self.fingerprint_metadata = self.task.fingerprint.metadata
|
|
||||||
|
|
||||||
|
|
||||||
class TaskCompletedEvent(BaseEvent):
|
class TaskCompletedEvent(BaseEvent):
|
||||||
@@ -31,17 +32,9 @@ class TaskCompletedEvent(BaseEvent):
|
|||||||
type: str = "task_completed"
|
type: str = "task_completed"
|
||||||
task: Any | None = None
|
task: Any | None = None
|
||||||
|
|
||||||
def __init__(self, **data):
|
def __init__(self, **data: Any) -> None:
|
||||||
super().__init__(**data)
|
super().__init__(**data)
|
||||||
# Set fingerprint data from the task
|
_set_task_fingerprint(self, self.task)
|
||||||
if hasattr(self.task, "fingerprint") and self.task.fingerprint:
|
|
||||||
self.source_fingerprint = self.task.fingerprint.uuid_str
|
|
||||||
self.source_type = "task"
|
|
||||||
if (
|
|
||||||
hasattr(self.task.fingerprint, "metadata")
|
|
||||||
and self.task.fingerprint.metadata
|
|
||||||
):
|
|
||||||
self.fingerprint_metadata = self.task.fingerprint.metadata
|
|
||||||
|
|
||||||
|
|
||||||
class TaskFailedEvent(BaseEvent):
|
class TaskFailedEvent(BaseEvent):
|
||||||
@@ -51,17 +44,9 @@ class TaskFailedEvent(BaseEvent):
|
|||||||
type: str = "task_failed"
|
type: str = "task_failed"
|
||||||
task: Any | None = None
|
task: Any | None = None
|
||||||
|
|
||||||
def __init__(self, **data):
|
def __init__(self, **data: Any) -> None:
|
||||||
super().__init__(**data)
|
super().__init__(**data)
|
||||||
# Set fingerprint data from the task
|
_set_task_fingerprint(self, self.task)
|
||||||
if hasattr(self.task, "fingerprint") and self.task.fingerprint:
|
|
||||||
self.source_fingerprint = self.task.fingerprint.uuid_str
|
|
||||||
self.source_type = "task"
|
|
||||||
if (
|
|
||||||
hasattr(self.task.fingerprint, "metadata")
|
|
||||||
and self.task.fingerprint.metadata
|
|
||||||
):
|
|
||||||
self.fingerprint_metadata = self.task.fingerprint.metadata
|
|
||||||
|
|
||||||
|
|
||||||
class TaskEvaluationEvent(BaseEvent):
|
class TaskEvaluationEvent(BaseEvent):
|
||||||
@@ -71,14 +56,6 @@ class TaskEvaluationEvent(BaseEvent):
|
|||||||
evaluation_type: str
|
evaluation_type: str
|
||||||
task: Any | None = None
|
task: Any | None = None
|
||||||
|
|
||||||
def __init__(self, **data):
|
def __init__(self, **data: Any) -> None:
|
||||||
super().__init__(**data)
|
super().__init__(**data)
|
||||||
# Set fingerprint data from the task
|
_set_task_fingerprint(self, self.task)
|
||||||
if hasattr(self.task, "fingerprint") and self.task.fingerprint:
|
|
||||||
self.source_fingerprint = self.task.fingerprint.uuid_str
|
|
||||||
self.source_type = "task"
|
|
||||||
if (
|
|
||||||
hasattr(self.task.fingerprint, "metadata")
|
|
||||||
and self.task.fingerprint.metadata
|
|
||||||
):
|
|
||||||
self.fingerprint_metadata = self.task.fingerprint.metadata
|
|
||||||
|
|||||||
@@ -8,7 +8,7 @@ from datetime import datetime
|
|||||||
import inspect
|
import inspect
|
||||||
import json
|
import json
|
||||||
import threading
|
import threading
|
||||||
from typing import TYPE_CHECKING, Any, Literal, cast
|
from typing import TYPE_CHECKING, Any, Literal, TypeVar, cast
|
||||||
from uuid import uuid4
|
from uuid import uuid4
|
||||||
|
|
||||||
from pydantic import BaseModel, Field, GetCoreSchemaHandler
|
from pydantic import BaseModel, Field, GetCoreSchemaHandler
|
||||||
@@ -22,7 +22,11 @@ from crewai.agents.parser import (
|
|||||||
AgentFinish,
|
AgentFinish,
|
||||||
OutputParserError,
|
OutputParserError,
|
||||||
)
|
)
|
||||||
from crewai.core.providers.human_input import get_provider
|
from crewai.core.providers.human_input import (
|
||||||
|
AsyncExecutorContext,
|
||||||
|
ExecutorContext,
|
||||||
|
get_provider,
|
||||||
|
)
|
||||||
from crewai.events.event_bus import crewai_event_bus
|
from crewai.events.event_bus import crewai_event_bus
|
||||||
from crewai.events.listeners.tracing.utils import (
|
from crewai.events.listeners.tracing.utils import (
|
||||||
is_tracing_enabled_in_context,
|
is_tracing_enabled_in_context,
|
||||||
@@ -89,7 +93,7 @@ from crewai.utilities.planning_types import (
|
|||||||
TodoList,
|
TodoList,
|
||||||
)
|
)
|
||||||
from crewai.utilities.printer import Printer
|
from crewai.utilities.printer import Printer
|
||||||
from crewai.utilities.step_execution_context import StepExecutionContext
|
from crewai.utilities.step_execution_context import StepExecutionContext, StepResult
|
||||||
from crewai.utilities.string_utils import sanitize_tool_name
|
from crewai.utilities.string_utils import sanitize_tool_name
|
||||||
from crewai.utilities.tool_utils import execute_tool_and_check_finality
|
from crewai.utilities.tool_utils import execute_tool_and_check_finality
|
||||||
from crewai.utilities.training_handler import CrewTrainingHandler
|
from crewai.utilities.training_handler import CrewTrainingHandler
|
||||||
@@ -105,6 +109,8 @@ if TYPE_CHECKING:
|
|||||||
from crewai.tools.tool_types import ToolResult
|
from crewai.tools.tool_types import ToolResult
|
||||||
from crewai.utilities.prompts import StandardPromptResult, SystemPromptResult
|
from crewai.utilities.prompts import StandardPromptResult, SystemPromptResult
|
||||||
|
|
||||||
|
_RouteT = TypeVar("_RouteT", bound=str)
|
||||||
|
|
||||||
|
|
||||||
class AgentExecutorState(BaseModel):
|
class AgentExecutorState(BaseModel):
|
||||||
"""Structured state for agent executor flow.
|
"""Structured state for agent executor flow.
|
||||||
@@ -446,29 +452,29 @@ class AgentExecutor(Flow[AgentExecutorState], CrewAgentExecutorMixin):
|
|||||||
step failures reliably trigger replanning rather than being
|
step failures reliably trigger replanning rather than being
|
||||||
silently ignored.
|
silently ignored.
|
||||||
"""
|
"""
|
||||||
config = getattr(self.agent, "planning_config", None)
|
config = self.agent.planning_config
|
||||||
if config is not None and hasattr(config, "reasoning_effort"):
|
if config is not None:
|
||||||
return config.reasoning_effort
|
return config.reasoning_effort
|
||||||
return "medium"
|
return "medium"
|
||||||
|
|
||||||
def _get_max_replans(self) -> int:
|
def _get_max_replans(self) -> int:
|
||||||
"""Get max replans from planning config or default to 3."""
|
"""Get max replans from planning config or default to 3."""
|
||||||
config = getattr(self.agent, "planning_config", None)
|
config = self.agent.planning_config
|
||||||
if config is not None and hasattr(config, "max_replans"):
|
if config is not None:
|
||||||
return config.max_replans
|
return config.max_replans
|
||||||
return 3
|
return 3
|
||||||
|
|
||||||
def _get_max_step_iterations(self) -> int:
|
def _get_max_step_iterations(self) -> int:
|
||||||
"""Get max step iterations from planning config or default to 15."""
|
"""Get max step iterations from planning config or default to 15."""
|
||||||
config = getattr(self.agent, "planning_config", None)
|
config = self.agent.planning_config
|
||||||
if config is not None and hasattr(config, "max_step_iterations"):
|
if config is not None:
|
||||||
return config.max_step_iterations
|
return config.max_step_iterations
|
||||||
return 15
|
return 15
|
||||||
|
|
||||||
def _get_step_timeout(self) -> int | None:
|
def _get_step_timeout(self) -> int | None:
|
||||||
"""Get per-step timeout from planning config or default to None."""
|
"""Get per-step timeout from planning config or default to None."""
|
||||||
config = getattr(self.agent, "planning_config", None)
|
config = self.agent.planning_config
|
||||||
if config is not None and hasattr(config, "step_timeout"):
|
if config is not None:
|
||||||
return config.step_timeout
|
return config.step_timeout
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@@ -1130,9 +1136,9 @@ class AgentExecutor(Flow[AgentExecutorState], CrewAgentExecutorMixin):
|
|||||||
# Process results: store on todos and log, then observe each.
|
# Process results: store on todos and log, then observe each.
|
||||||
# asyncio.gather preserves input order, so zip gives us the exact
|
# asyncio.gather preserves input order, so zip gives us the exact
|
||||||
# todo ↔ result (or exception) mapping.
|
# todo ↔ result (or exception) mapping.
|
||||||
step_results: list[tuple[TodoItem, object]] = []
|
step_results: list[tuple[TodoItem, StepResult]] = []
|
||||||
for todo, item in zip(ready, gathered, strict=True):
|
for todo, item in zip(ready, gathered, strict=True):
|
||||||
if isinstance(item, Exception):
|
if isinstance(item, BaseException):
|
||||||
error_msg = f"Error: {item!s}"
|
error_msg = f"Error: {item!s}"
|
||||||
todo.result = error_msg
|
todo.result = error_msg
|
||||||
self.state.todos.mark_failed(todo.step_number, result=error_msg)
|
self.state.todos.mark_failed(todo.step_number, result=error_msg)
|
||||||
@@ -1143,31 +1149,34 @@ class AgentExecutor(Flow[AgentExecutorState], CrewAgentExecutorMixin):
|
|||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
_returned_todo, result = item
|
_returned_todo, result = item
|
||||||
todo.result = result.result
|
step_result = cast(StepResult, result)
|
||||||
|
todo.result = step_result.result
|
||||||
|
|
||||||
self.state.execution_log.append(
|
self.state.execution_log.append(
|
||||||
{
|
{
|
||||||
"type": "step_execution",
|
"type": "step_execution",
|
||||||
"step_number": todo.step_number,
|
"step_number": todo.step_number,
|
||||||
"success": result.success,
|
"success": step_result.success,
|
||||||
"result_preview": result.result[:200] if result.result else "",
|
"result_preview": step_result.result[:200]
|
||||||
"error": result.error,
|
if step_result.result
|
||||||
"tool_calls": result.tool_calls_made,
|
else "",
|
||||||
"execution_time": result.execution_time,
|
"error": step_result.error,
|
||||||
|
"tool_calls": step_result.tool_calls_made,
|
||||||
|
"execution_time": step_result.execution_time,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.agent.verbose:
|
if self.agent.verbose:
|
||||||
status = "success" if result.success else "failed"
|
status = "success" if step_result.success else "failed"
|
||||||
self._printer.print(
|
self._printer.print(
|
||||||
content=(
|
content=(
|
||||||
f"[Execute] Step {todo.step_number} {status} "
|
f"[Execute] Step {todo.step_number} {status} "
|
||||||
f"({result.execution_time:.1f}s, "
|
f"({step_result.execution_time:.1f}s, "
|
||||||
f"{len(result.tool_calls_made)} tool calls)"
|
f"{len(step_result.tool_calls_made)} tool calls)"
|
||||||
),
|
),
|
||||||
color="green" if result.success else "red",
|
color="green" if step_result.success else "red",
|
||||||
)
|
)
|
||||||
step_results.append((todo, result))
|
step_results.append((todo, step_result))
|
||||||
|
|
||||||
# Observe each completed step sequentially (observation updates shared state)
|
# Observe each completed step sequentially (observation updates shared state)
|
||||||
effort = self._get_reasoning_effort()
|
effort = self._get_reasoning_effort()
|
||||||
@@ -1431,8 +1440,8 @@ class AgentExecutor(Flow[AgentExecutorState], CrewAgentExecutorMixin):
|
|||||||
raise
|
raise
|
||||||
|
|
||||||
def _route_finish_with_todos(
|
def _route_finish_with_todos(
|
||||||
self, default_route: str
|
self, default_route: _RouteT
|
||||||
) -> Literal["native_finished", "agent_finished", "todo_satisfied"]:
|
) -> _RouteT | Literal["todo_satisfied"]:
|
||||||
"""Helper to route finish events, checking for pending todos first.
|
"""Helper to route finish events, checking for pending todos first.
|
||||||
|
|
||||||
If there are pending todos, route to todo_satisfied instead of the
|
If there are pending todos, route to todo_satisfied instead of the
|
||||||
@@ -1448,7 +1457,7 @@ class AgentExecutor(Flow[AgentExecutorState], CrewAgentExecutorMixin):
|
|||||||
current_todo = self.state.todos.current_todo
|
current_todo = self.state.todos.current_todo
|
||||||
if current_todo:
|
if current_todo:
|
||||||
return "todo_satisfied"
|
return "todo_satisfied"
|
||||||
return default_route # type: ignore[return-value]
|
return default_route
|
||||||
|
|
||||||
@router(call_llm_and_parse)
|
@router(call_llm_and_parse)
|
||||||
def route_by_answer_type(
|
def route_by_answer_type(
|
||||||
@@ -2063,7 +2072,7 @@ class AgentExecutor(Flow[AgentExecutorState], CrewAgentExecutorMixin):
|
|||||||
elif not self.state.current_answer and self.state.messages:
|
elif not self.state.current_answer and self.state.messages:
|
||||||
# For native tools, results are in the message history as 'tool' roles
|
# For native tools, results are in the message history as 'tool' roles
|
||||||
# We take the content of the most recent tool results
|
# We take the content of the most recent tool results
|
||||||
tool_results = []
|
tool_results: list[str] = []
|
||||||
for msg in reversed(self.state.messages):
|
for msg in reversed(self.state.messages):
|
||||||
if msg.get("role") == "tool":
|
if msg.get("role") == "tool":
|
||||||
tool_results.insert(0, str(msg.get("content", "")))
|
tool_results.insert(0, str(msg.get("content", "")))
|
||||||
@@ -3003,7 +3012,7 @@ class AgentExecutor(Flow[AgentExecutorState], CrewAgentExecutorMixin):
|
|||||||
Final answer after feedback.
|
Final answer after feedback.
|
||||||
"""
|
"""
|
||||||
provider = get_provider()
|
provider = get_provider()
|
||||||
return provider.handle_feedback(formatted_answer, self)
|
return provider.handle_feedback(formatted_answer, cast("ExecutorContext", self))
|
||||||
|
|
||||||
async def _ahandle_human_feedback(
|
async def _ahandle_human_feedback(
|
||||||
self, formatted_answer: AgentFinish
|
self, formatted_answer: AgentFinish
|
||||||
@@ -3017,7 +3026,9 @@ class AgentExecutor(Flow[AgentExecutorState], CrewAgentExecutorMixin):
|
|||||||
Final answer after feedback.
|
Final answer after feedback.
|
||||||
"""
|
"""
|
||||||
provider = get_provider()
|
provider = get_provider()
|
||||||
return await provider.handle_feedback_async(formatted_answer, self)
|
return await provider.handle_feedback_async(
|
||||||
|
formatted_answer, cast("AsyncExecutorContext", self)
|
||||||
|
)
|
||||||
|
|
||||||
def _is_training_mode(self) -> bool:
|
def _is_training_mode(self) -> bool:
|
||||||
"""Check if training mode is active.
|
"""Check if training mode is active.
|
||||||
|
|||||||
@@ -37,11 +37,11 @@ class ExecutionState:
|
|||||||
current_agent_id: str | None = None
|
current_agent_id: str | None = None
|
||||||
current_task_id: str | None = None
|
current_task_id: str | None = None
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self) -> None:
|
||||||
self.traces = {}
|
self.traces: dict[str, Any] = {}
|
||||||
self.iteration = 1
|
self.iteration: int = 1
|
||||||
self.iterations_results = {}
|
self.iterations_results: dict[int, dict[str, list[AgentEvaluationResult]]] = {}
|
||||||
self.agent_evaluators = {}
|
self.agent_evaluators: dict[str, Sequence[BaseEvaluator] | None] = {}
|
||||||
|
|
||||||
|
|
||||||
class AgentEvaluator:
|
class AgentEvaluator:
|
||||||
@@ -295,7 +295,7 @@ class AgentEvaluator:
|
|||||||
|
|
||||||
def emit_evaluation_started_event(
|
def emit_evaluation_started_event(
|
||||||
self, agent_role: str, agent_id: str, task_id: str | None = None
|
self, agent_role: str, agent_id: str, task_id: str | None = None
|
||||||
):
|
) -> None:
|
||||||
crewai_event_bus.emit(
|
crewai_event_bus.emit(
|
||||||
self,
|
self,
|
||||||
AgentEvaluationStartedEvent(
|
AgentEvaluationStartedEvent(
|
||||||
@@ -313,7 +313,7 @@ class AgentEvaluator:
|
|||||||
task_id: str | None = None,
|
task_id: str | None = None,
|
||||||
metric_category: MetricCategory | None = None,
|
metric_category: MetricCategory | None = None,
|
||||||
score: EvaluationScore | None = None,
|
score: EvaluationScore | None = None,
|
||||||
):
|
) -> None:
|
||||||
crewai_event_bus.emit(
|
crewai_event_bus.emit(
|
||||||
self,
|
self,
|
||||||
AgentEvaluationCompletedEvent(
|
AgentEvaluationCompletedEvent(
|
||||||
@@ -328,7 +328,7 @@ class AgentEvaluator:
|
|||||||
|
|
||||||
def emit_evaluation_failed_event(
|
def emit_evaluation_failed_event(
|
||||||
self, agent_role: str, agent_id: str, error: str, task_id: str | None = None
|
self, agent_role: str, agent_id: str, error: str, task_id: str | None = None
|
||||||
):
|
) -> None:
|
||||||
crewai_event_bus.emit(
|
crewai_event_bus.emit(
|
||||||
self,
|
self,
|
||||||
AgentEvaluationFailedEvent(
|
AgentEvaluationFailedEvent(
|
||||||
@@ -341,7 +341,9 @@ class AgentEvaluator:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def create_default_evaluator(agents: list[Agent] | list[BaseAgent], llm: None = None):
|
def create_default_evaluator(
|
||||||
|
agents: list[Agent] | list[BaseAgent], llm: None = None
|
||||||
|
) -> AgentEvaluator:
|
||||||
from crewai.experimental.evaluation import (
|
from crewai.experimental.evaluation import (
|
||||||
GoalAlignmentEvaluator,
|
GoalAlignmentEvaluator,
|
||||||
ParameterExtractionEvaluator,
|
ParameterExtractionEvaluator,
|
||||||
|
|||||||
@@ -8,7 +8,7 @@ from typing import TYPE_CHECKING, Any
|
|||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
from crewai.agents.agent_builder.base_agent import BaseAgent
|
from crewai.agents.agent_builder.base_agent import BaseAgent
|
||||||
from crewai.llm import BaseLLM
|
from crewai.llms.base_llm import BaseLLM
|
||||||
from crewai.task import Task
|
from crewai.task import Task
|
||||||
from crewai.utilities.llm_utils import create_llm
|
from crewai.utilities.llm_utils import create_llm
|
||||||
|
|
||||||
@@ -25,7 +25,7 @@ class MetricCategory(enum.Enum):
|
|||||||
PARAMETER_EXTRACTION = "parameter_extraction"
|
PARAMETER_EXTRACTION = "parameter_extraction"
|
||||||
TOOL_INVOCATION = "tool_invocation"
|
TOOL_INVOCATION = "tool_invocation"
|
||||||
|
|
||||||
def title(self):
|
def title(self) -> str:
|
||||||
return self.value.replace("_", " ").title()
|
return self.value.replace("_", " ").title()
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -18,12 +18,12 @@ from crewai.utilities.types import LLMMessage
|
|||||||
|
|
||||||
|
|
||||||
class EvaluationDisplayFormatter:
|
class EvaluationDisplayFormatter:
|
||||||
def __init__(self):
|
def __init__(self) -> None:
|
||||||
self.console_formatter = ConsoleFormatter()
|
self.console_formatter = ConsoleFormatter()
|
||||||
|
|
||||||
def display_evaluation_with_feedback(
|
def display_evaluation_with_feedback(
|
||||||
self, iterations_results: dict[int, dict[str, list[Any]]]
|
self, iterations_results: dict[int, dict[str, list[Any]]]
|
||||||
):
|
) -> None:
|
||||||
if not iterations_results:
|
if not iterations_results:
|
||||||
self.console_formatter.print(
|
self.console_formatter.print(
|
||||||
"[yellow]No evaluation results to display[/yellow]"
|
"[yellow]No evaluation results to display[/yellow]"
|
||||||
@@ -103,7 +103,7 @@ class EvaluationDisplayFormatter:
|
|||||||
def display_summary_results(
|
def display_summary_results(
|
||||||
self,
|
self,
|
||||||
iterations_results: dict[int, dict[str, list[AgentEvaluationResult]]],
|
iterations_results: dict[int, dict[str, list[AgentEvaluationResult]]],
|
||||||
):
|
) -> None:
|
||||||
if not iterations_results:
|
if not iterations_results:
|
||||||
self.console_formatter.print(
|
self.console_formatter.print(
|
||||||
"[yellow]No evaluation results to display[/yellow]"
|
"[yellow]No evaluation results to display[/yellow]"
|
||||||
|
|||||||
@@ -1,6 +1,11 @@
|
|||||||
|
"""Event listener for collecting execution traces for evaluation."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
from collections.abc import Sequence
|
from collections.abc import Sequence
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
from crewai.agents.agent_builder.base_agent import BaseAgent
|
from crewai.agents.agent_builder.base_agent import BaseAgent
|
||||||
from crewai.events.base_event_listener import BaseEventListener
|
from crewai.events.base_event_listener import BaseEventListener
|
||||||
@@ -30,47 +35,63 @@ class EvaluationTraceCallback(BaseEventListener):
|
|||||||
retrievals, and final output - all for use in agent evaluation.
|
retrievals, and final output - all for use in agent evaluation.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
_instance = None
|
_instance: EvaluationTraceCallback | None = None
|
||||||
|
_initialized: bool = False
|
||||||
|
|
||||||
def __new__(cls):
|
def __new__(cls) -> EvaluationTraceCallback:
|
||||||
|
"""Create or return the singleton instance."""
|
||||||
if cls._instance is None:
|
if cls._instance is None:
|
||||||
cls._instance = super().__new__(cls)
|
cls._instance = super().__new__(cls)
|
||||||
cls._instance._initialized = False
|
cls._instance._initialized = False
|
||||||
return cls._instance
|
return cls._instance
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self) -> None:
|
||||||
if not hasattr(self, "_initialized") or not self._initialized:
|
"""Initialize the evaluation trace callback."""
|
||||||
|
if not self._initialized:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.traces = {}
|
self.traces: dict[str, Any] = {}
|
||||||
self.current_agent_id = None
|
self.current_agent_id: UUID | str | None = None
|
||||||
self.current_task_id = None
|
self.current_task_id: UUID | str | None = None
|
||||||
|
self.current_llm_call: dict[str, Any] = {}
|
||||||
self._initialized = True
|
self._initialized = True
|
||||||
|
|
||||||
def setup_listeners(self, event_bus: CrewAIEventsBus):
|
def setup_listeners(self, event_bus: CrewAIEventsBus) -> None:
|
||||||
|
"""Set up event listeners on the event bus.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
event_bus: The event bus to register listeners on.
|
||||||
|
"""
|
||||||
|
|
||||||
@event_bus.on(AgentExecutionStartedEvent)
|
@event_bus.on(AgentExecutionStartedEvent)
|
||||||
def on_agent_started(source, event: AgentExecutionStartedEvent):
|
def on_agent_started(source: Any, event: AgentExecutionStartedEvent) -> None:
|
||||||
self.on_agent_start(event.agent, event.task)
|
self.on_agent_start(event.agent, event.task)
|
||||||
|
|
||||||
@event_bus.on(LiteAgentExecutionStartedEvent)
|
@event_bus.on(LiteAgentExecutionStartedEvent)
|
||||||
def on_lite_agent_started(source, event: LiteAgentExecutionStartedEvent):
|
def on_lite_agent_started(
|
||||||
|
source: Any, event: LiteAgentExecutionStartedEvent
|
||||||
|
) -> None:
|
||||||
self.on_lite_agent_start(event.agent_info)
|
self.on_lite_agent_start(event.agent_info)
|
||||||
|
|
||||||
@event_bus.on(AgentExecutionCompletedEvent)
|
@event_bus.on(AgentExecutionCompletedEvent)
|
||||||
def on_agent_completed(source, event: AgentExecutionCompletedEvent):
|
def on_agent_completed(
|
||||||
|
source: Any, event: AgentExecutionCompletedEvent
|
||||||
|
) -> None:
|
||||||
self.on_agent_finish(event.agent, event.task, event.output)
|
self.on_agent_finish(event.agent, event.task, event.output)
|
||||||
|
|
||||||
@event_bus.on(LiteAgentExecutionCompletedEvent)
|
@event_bus.on(LiteAgentExecutionCompletedEvent)
|
||||||
def on_lite_agent_completed(source, event: LiteAgentExecutionCompletedEvent):
|
def on_lite_agent_completed(
|
||||||
|
source: Any, event: LiteAgentExecutionCompletedEvent
|
||||||
|
) -> None:
|
||||||
self.on_lite_agent_finish(event.output)
|
self.on_lite_agent_finish(event.output)
|
||||||
|
|
||||||
@event_bus.on(ToolUsageFinishedEvent)
|
@event_bus.on(ToolUsageFinishedEvent)
|
||||||
def on_tool_completed(source, event: ToolUsageFinishedEvent):
|
def on_tool_completed(source: Any, event: ToolUsageFinishedEvent) -> None:
|
||||||
self.on_tool_use(
|
self.on_tool_use(
|
||||||
event.tool_name, event.tool_args, event.output, success=True
|
event.tool_name, event.tool_args, event.output, success=True
|
||||||
)
|
)
|
||||||
|
|
||||||
@event_bus.on(ToolUsageErrorEvent)
|
@event_bus.on(ToolUsageErrorEvent)
|
||||||
def on_tool_usage_error(source, event: ToolUsageErrorEvent):
|
def on_tool_usage_error(source: Any, event: ToolUsageErrorEvent) -> None:
|
||||||
self.on_tool_use(
|
self.on_tool_use(
|
||||||
event.tool_name,
|
event.tool_name,
|
||||||
event.tool_args,
|
event.tool_args,
|
||||||
@@ -80,7 +101,9 @@ class EvaluationTraceCallback(BaseEventListener):
|
|||||||
)
|
)
|
||||||
|
|
||||||
@event_bus.on(ToolExecutionErrorEvent)
|
@event_bus.on(ToolExecutionErrorEvent)
|
||||||
def on_tool_execution_error(source, event: ToolExecutionErrorEvent):
|
def on_tool_execution_error(
|
||||||
|
source: Any, event: ToolExecutionErrorEvent
|
||||||
|
) -> None:
|
||||||
self.on_tool_use(
|
self.on_tool_use(
|
||||||
event.tool_name,
|
event.tool_name,
|
||||||
event.tool_args,
|
event.tool_args,
|
||||||
@@ -90,7 +113,9 @@ class EvaluationTraceCallback(BaseEventListener):
|
|||||||
)
|
)
|
||||||
|
|
||||||
@event_bus.on(ToolSelectionErrorEvent)
|
@event_bus.on(ToolSelectionErrorEvent)
|
||||||
def on_tool_selection_error(source, event: ToolSelectionErrorEvent):
|
def on_tool_selection_error(
|
||||||
|
source: Any, event: ToolSelectionErrorEvent
|
||||||
|
) -> None:
|
||||||
self.on_tool_use(
|
self.on_tool_use(
|
||||||
event.tool_name,
|
event.tool_name,
|
||||||
event.tool_args,
|
event.tool_args,
|
||||||
@@ -100,7 +125,9 @@ class EvaluationTraceCallback(BaseEventListener):
|
|||||||
)
|
)
|
||||||
|
|
||||||
@event_bus.on(ToolValidateInputErrorEvent)
|
@event_bus.on(ToolValidateInputErrorEvent)
|
||||||
def on_tool_validate_input_error(source, event: ToolValidateInputErrorEvent):
|
def on_tool_validate_input_error(
|
||||||
|
source: Any, event: ToolValidateInputErrorEvent
|
||||||
|
) -> None:
|
||||||
self.on_tool_use(
|
self.on_tool_use(
|
||||||
event.tool_name,
|
event.tool_name,
|
||||||
event.tool_args,
|
event.tool_args,
|
||||||
@@ -110,14 +137,19 @@ class EvaluationTraceCallback(BaseEventListener):
|
|||||||
)
|
)
|
||||||
|
|
||||||
@event_bus.on(LLMCallStartedEvent)
|
@event_bus.on(LLMCallStartedEvent)
|
||||||
def on_llm_call_started(source, event: LLMCallStartedEvent):
|
def on_llm_call_started(source: Any, event: LLMCallStartedEvent) -> None:
|
||||||
self.on_llm_call_start(event.messages, event.tools)
|
self.on_llm_call_start(event.messages, event.tools)
|
||||||
|
|
||||||
@event_bus.on(LLMCallCompletedEvent)
|
@event_bus.on(LLMCallCompletedEvent)
|
||||||
def on_llm_call_completed(source, event: LLMCallCompletedEvent):
|
def on_llm_call_completed(source: Any, event: LLMCallCompletedEvent) -> None:
|
||||||
self.on_llm_call_end(event.messages, event.response)
|
self.on_llm_call_end(event.messages, event.response)
|
||||||
|
|
||||||
def on_lite_agent_start(self, agent_info: dict[str, Any]):
|
def on_lite_agent_start(self, agent_info: dict[str, Any]) -> None:
|
||||||
|
"""Handle a lite agent execution start event.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
agent_info: Dictionary containing agent information.
|
||||||
|
"""
|
||||||
self.current_agent_id = agent_info["id"]
|
self.current_agent_id = agent_info["id"]
|
||||||
self.current_task_id = "lite_task"
|
self.current_task_id = "lite_task"
|
||||||
|
|
||||||
@@ -132,10 +164,22 @@ class EvaluationTraceCallback(BaseEventListener):
|
|||||||
final_output=None,
|
final_output=None,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _init_trace(self, trace_key: str, **kwargs: Any):
|
def _init_trace(self, trace_key: str, **kwargs: Any) -> None:
|
||||||
|
"""Initialize a trace entry.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
trace_key: The key to store the trace under.
|
||||||
|
**kwargs: Trace metadata to store.
|
||||||
|
"""
|
||||||
self.traces[trace_key] = kwargs
|
self.traces[trace_key] = kwargs
|
||||||
|
|
||||||
def on_agent_start(self, agent: BaseAgent, task: Task):
|
def on_agent_start(self, agent: BaseAgent, task: Task) -> None:
|
||||||
|
"""Handle an agent execution start event.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
agent: The agent that started execution.
|
||||||
|
task: The task being executed.
|
||||||
|
"""
|
||||||
self.current_agent_id = agent.id
|
self.current_agent_id = agent.id
|
||||||
self.current_task_id = task.id
|
self.current_task_id = task.id
|
||||||
|
|
||||||
@@ -150,7 +194,14 @@ class EvaluationTraceCallback(BaseEventListener):
|
|||||||
final_output=None,
|
final_output=None,
|
||||||
)
|
)
|
||||||
|
|
||||||
def on_agent_finish(self, agent: BaseAgent, task: Task, output: Any):
|
def on_agent_finish(self, agent: BaseAgent, task: Task, output: Any) -> None:
|
||||||
|
"""Handle an agent execution completion event.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
agent: The agent that finished execution.
|
||||||
|
task: The task that was executed.
|
||||||
|
output: The agent's output.
|
||||||
|
"""
|
||||||
trace_key = f"{agent.id}_{task.id}"
|
trace_key = f"{agent.id}_{task.id}"
|
||||||
if trace_key in self.traces:
|
if trace_key in self.traces:
|
||||||
self.traces[trace_key]["final_output"] = output
|
self.traces[trace_key]["final_output"] = output
|
||||||
@@ -158,11 +209,17 @@ class EvaluationTraceCallback(BaseEventListener):
|
|||||||
|
|
||||||
self._reset_current()
|
self._reset_current()
|
||||||
|
|
||||||
def _reset_current(self):
|
def _reset_current(self) -> None:
|
||||||
|
"""Reset the current agent and task tracking state."""
|
||||||
self.current_agent_id = None
|
self.current_agent_id = None
|
||||||
self.current_task_id = None
|
self.current_task_id = None
|
||||||
|
|
||||||
def on_lite_agent_finish(self, output: Any):
|
def on_lite_agent_finish(self, output: Any) -> None:
|
||||||
|
"""Handle a lite agent execution completion event.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
output: The agent's output.
|
||||||
|
"""
|
||||||
trace_key = f"{self.current_agent_id}_lite_task"
|
trace_key = f"{self.current_agent_id}_lite_task"
|
||||||
if trace_key in self.traces:
|
if trace_key in self.traces:
|
||||||
self.traces[trace_key]["final_output"] = output
|
self.traces[trace_key]["final_output"] = output
|
||||||
@@ -177,13 +234,22 @@ class EvaluationTraceCallback(BaseEventListener):
|
|||||||
result: Any,
|
result: Any,
|
||||||
success: bool = True,
|
success: bool = True,
|
||||||
error_type: str | None = None,
|
error_type: str | None = None,
|
||||||
):
|
) -> None:
|
||||||
|
"""Record a tool usage event in the current trace.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tool_name: Name of the tool used.
|
||||||
|
tool_args: Arguments passed to the tool.
|
||||||
|
result: The tool's output or error message.
|
||||||
|
success: Whether the tool call succeeded.
|
||||||
|
error_type: Type of error if the call failed.
|
||||||
|
"""
|
||||||
if not self.current_agent_id or not self.current_task_id:
|
if not self.current_agent_id or not self.current_task_id:
|
||||||
return
|
return
|
||||||
|
|
||||||
trace_key = f"{self.current_agent_id}_{self.current_task_id}"
|
trace_key = f"{self.current_agent_id}_{self.current_task_id}"
|
||||||
if trace_key in self.traces:
|
if trace_key in self.traces:
|
||||||
tool_use = {
|
tool_use: dict[str, Any] = {
|
||||||
"tool": tool_name,
|
"tool": tool_name,
|
||||||
"args": tool_args,
|
"args": tool_args,
|
||||||
"result": result,
|
"result": result,
|
||||||
@@ -191,7 +257,6 @@ class EvaluationTraceCallback(BaseEventListener):
|
|||||||
"timestamp": datetime.now(),
|
"timestamp": datetime.now(),
|
||||||
}
|
}
|
||||||
|
|
||||||
# Add error information if applicable
|
|
||||||
if not success and error_type:
|
if not success and error_type:
|
||||||
tool_use["error"] = True
|
tool_use["error"] = True
|
||||||
tool_use["error_type"] = error_type
|
tool_use["error_type"] = error_type
|
||||||
@@ -202,7 +267,13 @@ class EvaluationTraceCallback(BaseEventListener):
|
|||||||
self,
|
self,
|
||||||
messages: str | Sequence[dict[str, Any]] | None,
|
messages: str | Sequence[dict[str, Any]] | None,
|
||||||
tools: Sequence[dict[str, Any]] | None = None,
|
tools: Sequence[dict[str, Any]] | None = None,
|
||||||
):
|
) -> None:
|
||||||
|
"""Record an LLM call start event.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
messages: The messages sent to the LLM.
|
||||||
|
tools: Tool definitions provided to the LLM.
|
||||||
|
"""
|
||||||
if not self.current_agent_id or not self.current_task_id:
|
if not self.current_agent_id or not self.current_task_id:
|
||||||
return
|
return
|
||||||
|
|
||||||
@@ -220,7 +291,13 @@ class EvaluationTraceCallback(BaseEventListener):
|
|||||||
|
|
||||||
def on_llm_call_end(
|
def on_llm_call_end(
|
||||||
self, messages: str | list[dict[str, Any]] | None, response: Any
|
self, messages: str | list[dict[str, Any]] | None, response: Any
|
||||||
):
|
) -> None:
|
||||||
|
"""Record an LLM call completion event.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
messages: The messages from the LLM call.
|
||||||
|
response: The LLM response object.
|
||||||
|
"""
|
||||||
if not self.current_agent_id or not self.current_task_id:
|
if not self.current_agent_id or not self.current_task_id:
|
||||||
return
|
return
|
||||||
|
|
||||||
@@ -229,17 +306,18 @@ class EvaluationTraceCallback(BaseEventListener):
|
|||||||
return
|
return
|
||||||
|
|
||||||
total_tokens = 0
|
total_tokens = 0
|
||||||
if hasattr(response, "usage") and hasattr(response.usage, "total_tokens"):
|
usage = getattr(response, "usage", None)
|
||||||
total_tokens = response.usage.total_tokens
|
if usage is not None:
|
||||||
|
total_tokens = getattr(usage, "total_tokens", 0)
|
||||||
|
|
||||||
current_time = datetime.now()
|
current_time = datetime.now()
|
||||||
start_time = None
|
start_time = (
|
||||||
if hasattr(self, "current_llm_call") and self.current_llm_call:
|
self.current_llm_call.get("start_time") if self.current_llm_call else None
|
||||||
start_time = self.current_llm_call.get("start_time")
|
)
|
||||||
|
|
||||||
if not start_time:
|
if not start_time:
|
||||||
start_time = current_time
|
start_time = current_time
|
||||||
llm_call = {
|
llm_call: dict[str, Any] = {
|
||||||
"messages": messages,
|
"messages": messages,
|
||||||
"response": response,
|
"response": response,
|
||||||
"start_time": start_time,
|
"start_time": start_time,
|
||||||
@@ -248,16 +326,28 @@ class EvaluationTraceCallback(BaseEventListener):
|
|||||||
}
|
}
|
||||||
|
|
||||||
self.traces[trace_key]["llm_calls"].append(llm_call)
|
self.traces[trace_key]["llm_calls"].append(llm_call)
|
||||||
|
self.current_llm_call = {}
|
||||||
if hasattr(self, "current_llm_call"):
|
|
||||||
self.current_llm_call = {}
|
|
||||||
|
|
||||||
def get_trace(self, agent_id: str, task_id: str) -> dict[str, Any] | None:
|
def get_trace(self, agent_id: str, task_id: str) -> dict[str, Any] | None:
|
||||||
|
"""Retrieve a trace by agent and task ID.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
agent_id: The agent's identifier.
|
||||||
|
task_id: The task's identifier.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The trace dictionary, or None if not found.
|
||||||
|
"""
|
||||||
trace_key = f"{agent_id}_{task_id}"
|
trace_key = f"{agent_id}_{task_id}"
|
||||||
return self.traces.get(trace_key)
|
return self.traces.get(trace_key)
|
||||||
|
|
||||||
|
|
||||||
def create_evaluation_callbacks() -> EvaluationTraceCallback:
|
def create_evaluation_callbacks() -> EvaluationTraceCallback:
|
||||||
|
"""Create and register an evaluation trace callback on the event bus.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The configured EvaluationTraceCallback instance.
|
||||||
|
"""
|
||||||
from crewai.events.event_bus import crewai_event_bus
|
from crewai.events.event_bus import crewai_event_bus
|
||||||
|
|
||||||
callback = EvaluationTraceCallback()
|
callback = EvaluationTraceCallback()
|
||||||
|
|||||||
@@ -8,10 +8,10 @@ from crewai.experimental.evaluation.experiment.result import ExperimentResults
|
|||||||
|
|
||||||
|
|
||||||
class ExperimentResultsDisplay:
|
class ExperimentResultsDisplay:
|
||||||
def __init__(self):
|
def __init__(self) -> None:
|
||||||
self.console = Console()
|
self.console = Console()
|
||||||
|
|
||||||
def summary(self, experiment_results: ExperimentResults):
|
def summary(self, experiment_results: ExperimentResults) -> None:
|
||||||
total = len(experiment_results.results)
|
total = len(experiment_results.results)
|
||||||
passed = sum(1 for r in experiment_results.results if r.passed)
|
passed = sum(1 for r in experiment_results.results if r.passed)
|
||||||
|
|
||||||
@@ -28,7 +28,9 @@ class ExperimentResultsDisplay:
|
|||||||
|
|
||||||
self.console.print(table)
|
self.console.print(table)
|
||||||
|
|
||||||
def comparison_summary(self, comparison: dict[str, Any], baseline_timestamp: str):
|
def comparison_summary(
|
||||||
|
self, comparison: dict[str, Any], baseline_timestamp: str
|
||||||
|
) -> None:
|
||||||
self.console.print(
|
self.console.print(
|
||||||
Panel(
|
Panel(
|
||||||
f"[bold]Comparison with baseline run from {baseline_timestamp}[/bold]",
|
f"[bold]Comparison with baseline run from {baseline_timestamp}[/bold]",
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ from typing import TYPE_CHECKING, Any
|
|||||||
|
|
||||||
from crewai.agents.agent_builder.base_agent import BaseAgent
|
from crewai.agents.agent_builder.base_agent import BaseAgent
|
||||||
from crewai.experimental.evaluation import AgentEvaluator, create_default_evaluator
|
from crewai.experimental.evaluation import AgentEvaluator, create_default_evaluator
|
||||||
from crewai.experimental.evaluation.evaluation_display import (
|
from crewai.experimental.evaluation.base_evaluator import (
|
||||||
AgentAggregatedEvaluationResult,
|
AgentAggregatedEvaluationResult,
|
||||||
)
|
)
|
||||||
from crewai.experimental.evaluation.experiment.result import (
|
from crewai.experimental.evaluation.experiment.result import (
|
||||||
|
|||||||
@@ -7,7 +7,8 @@ from typing import Any
|
|||||||
|
|
||||||
def extract_json_from_llm_response(text: str) -> dict[str, Any]:
|
def extract_json_from_llm_response(text: str) -> dict[str, Any]:
|
||||||
try:
|
try:
|
||||||
return json.loads(text)
|
result: dict[str, Any] = json.loads(text)
|
||||||
|
return result
|
||||||
except json.JSONDecodeError:
|
except json.JSONDecodeError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@@ -24,7 +25,8 @@ def extract_json_from_llm_response(text: str) -> dict[str, Any]:
|
|||||||
matches = re.findall(pattern, text, re.IGNORECASE | re.DOTALL)
|
matches = re.findall(pattern, text, re.IGNORECASE | re.DOTALL)
|
||||||
for match in matches:
|
for match in matches:
|
||||||
try:
|
try:
|
||||||
return json.loads(match.strip())
|
parsed: dict[str, Any] = json.loads(match.strip())
|
||||||
|
return parsed
|
||||||
except json.JSONDecodeError: # noqa: PERF203
|
except json.JSONDecodeError: # noqa: PERF203
|
||||||
continue
|
continue
|
||||||
raise ValueError("No valid JSON found in the response")
|
raise ValueError("No valid JSON found in the response")
|
||||||
|
|||||||
@@ -68,7 +68,7 @@ Evaluate how well the agent's output aligns with the assigned task goal.
|
|||||||
]
|
]
|
||||||
if self.llm is None:
|
if self.llm is None:
|
||||||
raise ValueError("LLM must be initialized")
|
raise ValueError("LLM must be initialized")
|
||||||
response = self.llm.call(prompt) # type: ignore[arg-type]
|
response = self.llm.call(prompt)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
evaluation_data: dict[str, Any] = extract_json_from_llm_response(response)
|
evaluation_data: dict[str, Any] = extract_json_from_llm_response(response)
|
||||||
|
|||||||
@@ -224,7 +224,9 @@ Identify any inefficient reasoning patterns and provide specific suggestions for
|
|||||||
raw_response=response,
|
raw_response=response,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _detect_loops(self, llm_calls: list[dict]) -> tuple[bool, list[dict]]:
|
def _detect_loops(
|
||||||
|
self, llm_calls: list[dict[str, Any]]
|
||||||
|
) -> tuple[bool, list[dict[str, Any]]]:
|
||||||
loop_details = []
|
loop_details = []
|
||||||
|
|
||||||
messages = []
|
messages = []
|
||||||
@@ -272,7 +274,9 @@ Identify any inefficient reasoning patterns and provide specific suggestions for
|
|||||||
|
|
||||||
return intersection / union if union > 0 else 0.0
|
return intersection / union if union > 0 else 0.0
|
||||||
|
|
||||||
def _analyze_reasoning_patterns(self, llm_calls: list[dict]) -> dict[str, Any]:
|
def _analyze_reasoning_patterns(
|
||||||
|
self, llm_calls: list[dict[str, Any]]
|
||||||
|
) -> dict[str, Any]:
|
||||||
call_lengths = []
|
call_lengths = []
|
||||||
response_times = []
|
response_times = []
|
||||||
|
|
||||||
@@ -345,7 +349,7 @@ Identify any inefficient reasoning patterns and provide specific suggestions for
|
|||||||
max_possible_slope = max(values) - min(values)
|
max_possible_slope = max(values) - min(values)
|
||||||
if max_possible_slope > 0:
|
if max_possible_slope > 0:
|
||||||
normalized_slope = slope / max_possible_slope
|
normalized_slope = slope / max_possible_slope
|
||||||
return max(min(normalized_slope, 1.0), -1.0)
|
return float(max(min(normalized_slope, 1.0), -1.0))
|
||||||
return 0.0
|
return 0.0
|
||||||
except Exception:
|
except Exception:
|
||||||
return 0.0
|
return 0.0
|
||||||
@@ -384,7 +388,7 @@ Identify any inefficient reasoning patterns and provide specific suggestions for
|
|||||||
|
|
||||||
return float(np.mean(indicators)) if indicators else 0.0
|
return float(np.mean(indicators)) if indicators else 0.0
|
||||||
|
|
||||||
def _get_call_samples(self, llm_calls: list[dict]) -> str:
|
def _get_call_samples(self, llm_calls: list[dict[str, Any]]) -> str:
|
||||||
samples = []
|
samples = []
|
||||||
|
|
||||||
if len(llm_calls) <= 6:
|
if len(llm_calls) <= 6:
|
||||||
|
|||||||
@@ -299,15 +299,15 @@ def _extract_all_methods_from_condition(
|
|||||||
return []
|
return []
|
||||||
if isinstance(condition, dict):
|
if isinstance(condition, dict):
|
||||||
conditions_list = condition.get("conditions", [])
|
conditions_list = condition.get("conditions", [])
|
||||||
methods: list[str] = []
|
dict_methods: list[str] = []
|
||||||
for sub_cond in conditions_list:
|
for sub_cond in conditions_list:
|
||||||
methods.extend(_extract_all_methods_from_condition(sub_cond))
|
dict_methods.extend(_extract_all_methods_from_condition(sub_cond))
|
||||||
return methods
|
return dict_methods
|
||||||
if isinstance(condition, list):
|
if isinstance(condition, list):
|
||||||
methods = []
|
list_methods: list[str] = []
|
||||||
for item in condition:
|
for item in condition:
|
||||||
methods.extend(_extract_all_methods_from_condition(item))
|
list_methods.extend(_extract_all_methods_from_condition(item))
|
||||||
return methods
|
return list_methods
|
||||||
return []
|
return []
|
||||||
|
|
||||||
|
|
||||||
@@ -476,7 +476,8 @@ def _detect_flow_inputs(flow_class: type) -> list[str]:
|
|||||||
|
|
||||||
# Check for inputs in __init__ signature beyond standard Flow params
|
# Check for inputs in __init__ signature beyond standard Flow params
|
||||||
try:
|
try:
|
||||||
init_sig = inspect.signature(flow_class.__init__)
|
init_method = flow_class.__init__ # type: ignore[misc]
|
||||||
|
init_sig = inspect.signature(init_method)
|
||||||
standard_params = {
|
standard_params = {
|
||||||
"self",
|
"self",
|
||||||
"persistence",
|
"persistence",
|
||||||
|
|||||||
@@ -83,8 +83,11 @@ def _serialize_llm_for_context(llm: Any) -> dict[str, Any] | str | None:
|
|||||||
subclasses). Falls back to extracting the model string with provider
|
subclasses). Falls back to extracting the model string with provider
|
||||||
prefix for unknown LLM types.
|
prefix for unknown LLM types.
|
||||||
"""
|
"""
|
||||||
if hasattr(llm, "to_config_dict"):
|
to_config: Callable[[], dict[str, Any]] | None = getattr(
|
||||||
return llm.to_config_dict()
|
llm, "to_config_dict", None
|
||||||
|
)
|
||||||
|
if to_config is not None:
|
||||||
|
return to_config()
|
||||||
|
|
||||||
# Fallback for non-BaseLLM objects: just extract model + provider prefix
|
# Fallback for non-BaseLLM objects: just extract model + provider prefix
|
||||||
model = getattr(llm, "model", None)
|
model = getattr(llm, "model", None)
|
||||||
@@ -371,8 +374,13 @@ def human_feedback(
|
|||||||
) -> Any:
|
) -> Any:
|
||||||
"""Recall past HITL lessons and use LLM to pre-review the output."""
|
"""Recall past HITL lessons and use LLM to pre-review the output."""
|
||||||
try:
|
try:
|
||||||
|
from crewai.memory.unified_memory import Memory
|
||||||
|
|
||||||
|
mem = flow_instance.memory
|
||||||
|
if not isinstance(mem, Memory):
|
||||||
|
return method_output
|
||||||
query = f"human feedback lessons for {func.__name__}: {method_output!s}"
|
query = f"human feedback lessons for {func.__name__}: {method_output!s}"
|
||||||
matches = flow_instance.memory.recall(query, source=learn_source)
|
matches = mem.recall(query, source=learn_source)
|
||||||
if not matches:
|
if not matches:
|
||||||
return method_output
|
return method_output
|
||||||
|
|
||||||
@@ -404,6 +412,11 @@ def human_feedback(
|
|||||||
) -> None:
|
) -> None:
|
||||||
"""Extract generalizable lessons from output + feedback, store in memory."""
|
"""Extract generalizable lessons from output + feedback, store in memory."""
|
||||||
try:
|
try:
|
||||||
|
from crewai.memory.unified_memory import Memory
|
||||||
|
|
||||||
|
mem = flow_instance.memory
|
||||||
|
if not isinstance(mem, Memory):
|
||||||
|
return
|
||||||
llm_inst = _resolve_llm_instance()
|
llm_inst = _resolve_llm_instance()
|
||||||
prompt = _get_hitl_prompt("hitl_distill_user").format(
|
prompt = _get_hitl_prompt("hitl_distill_user").format(
|
||||||
method_name=func.__name__,
|
method_name=func.__name__,
|
||||||
@@ -435,7 +448,7 @@ def human_feedback(
|
|||||||
]
|
]
|
||||||
|
|
||||||
if lessons:
|
if lessons:
|
||||||
flow_instance.memory.remember_many(lessons, source=learn_source)
|
mem.remember_many(lessons, source=learn_source)
|
||||||
except Exception: # noqa: S110
|
except Exception: # noqa: S110
|
||||||
pass # non-critical: don't fail the flow because lesson storage failed
|
pass # non-critical: don't fail the flow because lesson storage failed
|
||||||
|
|
||||||
|
|||||||
@@ -122,7 +122,7 @@ def before_llm_call(
|
|||||||
"""
|
"""
|
||||||
from crewai.hooks.llm_hooks import register_before_llm_call_hook
|
from crewai.hooks.llm_hooks import register_before_llm_call_hook
|
||||||
|
|
||||||
return _create_hook_decorator( # type: ignore[return-value]
|
return _create_hook_decorator( # type: ignore[no-any-return]
|
||||||
hook_type="llm",
|
hook_type="llm",
|
||||||
register_function=register_before_llm_call_hook,
|
register_function=register_before_llm_call_hook,
|
||||||
marker_attribute="is_before_llm_call_hook",
|
marker_attribute="is_before_llm_call_hook",
|
||||||
@@ -176,7 +176,7 @@ def after_llm_call(
|
|||||||
"""
|
"""
|
||||||
from crewai.hooks.llm_hooks import register_after_llm_call_hook
|
from crewai.hooks.llm_hooks import register_after_llm_call_hook
|
||||||
|
|
||||||
return _create_hook_decorator( # type: ignore[return-value]
|
return _create_hook_decorator( # type: ignore[no-any-return]
|
||||||
hook_type="llm",
|
hook_type="llm",
|
||||||
register_function=register_after_llm_call_hook,
|
register_function=register_after_llm_call_hook,
|
||||||
marker_attribute="is_after_llm_call_hook",
|
marker_attribute="is_after_llm_call_hook",
|
||||||
@@ -237,7 +237,7 @@ def before_tool_call(
|
|||||||
"""
|
"""
|
||||||
from crewai.hooks.tool_hooks import register_before_tool_call_hook
|
from crewai.hooks.tool_hooks import register_before_tool_call_hook
|
||||||
|
|
||||||
return _create_hook_decorator( # type: ignore[return-value]
|
return _create_hook_decorator( # type: ignore[no-any-return]
|
||||||
hook_type="tool",
|
hook_type="tool",
|
||||||
register_function=register_before_tool_call_hook,
|
register_function=register_before_tool_call_hook,
|
||||||
marker_attribute="is_before_tool_call_hook",
|
marker_attribute="is_before_tool_call_hook",
|
||||||
@@ -293,7 +293,7 @@ def after_tool_call(
|
|||||||
"""
|
"""
|
||||||
from crewai.hooks.tool_hooks import register_after_tool_call_hook
|
from crewai.hooks.tool_hooks import register_after_tool_call_hook
|
||||||
|
|
||||||
return _create_hook_decorator( # type: ignore[return-value]
|
return _create_hook_decorator( # type: ignore[no-any-return]
|
||||||
hook_type="tool",
|
hook_type="tool",
|
||||||
register_function=register_after_tool_call_hook,
|
register_function=register_after_tool_call_hook,
|
||||||
marker_attribute="is_after_tool_call_hook",
|
marker_attribute="is_after_tool_call_hook",
|
||||||
|
|||||||
@@ -13,7 +13,7 @@ class BaseKnowledgeSource(BaseModel, ABC):
|
|||||||
chunk_size: int = 4000
|
chunk_size: int = 4000
|
||||||
chunk_overlap: int = 200
|
chunk_overlap: int = 200
|
||||||
chunks: list[str] = Field(default_factory=list)
|
chunks: list[str] = Field(default_factory=list)
|
||||||
chunk_embeddings: list[np.ndarray] = Field(default_factory=list)
|
chunk_embeddings: list[np.ndarray[Any, np.dtype[Any]]] = Field(default_factory=list)
|
||||||
|
|
||||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||||
storage: KnowledgeStorage | None = Field(default=None)
|
storage: KnowledgeStorage | None = Field(default=None)
|
||||||
@@ -28,7 +28,7 @@ class BaseKnowledgeSource(BaseModel, ABC):
|
|||||||
def add(self) -> None:
|
def add(self) -> None:
|
||||||
"""Process content, chunk it, compute embeddings, and save them."""
|
"""Process content, chunk it, compute embeddings, and save them."""
|
||||||
|
|
||||||
def get_embeddings(self) -> list[np.ndarray]:
|
def get_embeddings(self) -> list[np.ndarray[Any, np.dtype[Any]]]:
|
||||||
"""Return the list of embeddings for the chunks."""
|
"""Return the list of embeddings for the chunks."""
|
||||||
return self.chunk_embeddings
|
return self.chunk_embeddings
|
||||||
|
|
||||||
|
|||||||
@@ -2369,8 +2369,8 @@ class LLM(BaseLLM):
|
|||||||
cb.strip() for cb in failure_callbacks_str.split(",") if cb.strip()
|
cb.strip() for cb in failure_callbacks_str.split(",") if cb.strip()
|
||||||
]
|
]
|
||||||
|
|
||||||
litellm.success_callback = success_callbacks
|
litellm.success_callback = success_callbacks # type: ignore[assignment]
|
||||||
litellm.failure_callback = failure_callbacks
|
litellm.failure_callback = failure_callbacks # type: ignore[assignment]
|
||||||
|
|
||||||
def __copy__(self) -> LLM:
|
def __copy__(self) -> LLM:
|
||||||
"""Create a shallow copy of the LLM instance."""
|
"""Create a shallow copy of the LLM instance."""
|
||||||
|
|||||||
@@ -222,6 +222,7 @@ class AnthropicCompletion(BaseLLM):
|
|||||||
self.previous_thinking_blocks: list[ThinkingBlock] = []
|
self.previous_thinking_blocks: list[ThinkingBlock] = []
|
||||||
self.response_format = response_format
|
self.response_format = response_format
|
||||||
# Tool search config
|
# Tool search config
|
||||||
|
self.tool_search: AnthropicToolSearchConfig | None
|
||||||
if tool_search is True:
|
if tool_search is True:
|
||||||
self.tool_search = AnthropicToolSearchConfig()
|
self.tool_search = AnthropicToolSearchConfig()
|
||||||
elif isinstance(tool_search, AnthropicToolSearchConfig):
|
elif isinstance(tool_search, AnthropicToolSearchConfig):
|
||||||
|
|||||||
@@ -1,22 +1,21 @@
|
|||||||
"""MCP client with session management for CrewAI agents."""
|
"""MCP client with session management for CrewAI agents."""
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
from collections.abc import Callable
|
from collections.abc import Callable, Coroutine
|
||||||
from contextlib import AsyncExitStack
|
from contextlib import AsyncExitStack
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
import logging
|
import logging
|
||||||
|
import sys
|
||||||
import time
|
import time
|
||||||
from typing import Any, NamedTuple
|
from typing import Any, NamedTuple, TypeVar
|
||||||
|
|
||||||
from typing_extensions import Self
|
from typing_extensions import Self
|
||||||
|
|
||||||
|
|
||||||
# BaseExceptionGroup is available in Python 3.11+
|
if sys.version_info >= (3, 11):
|
||||||
try:
|
|
||||||
from builtins import BaseExceptionGroup
|
from builtins import BaseExceptionGroup
|
||||||
except ImportError:
|
else:
|
||||||
# Fallback for Python < 3.11 (shouldn't happen in practice)
|
from exceptiongroup import BaseExceptionGroup
|
||||||
BaseExceptionGroup = Exception
|
|
||||||
|
|
||||||
from crewai.events.event_bus import crewai_event_bus
|
from crewai.events.event_bus import crewai_event_bus
|
||||||
from crewai.events.types.mcp_events import (
|
from crewai.events.types.mcp_events import (
|
||||||
@@ -47,8 +46,10 @@ MCP_TOOL_EXECUTION_TIMEOUT = 30
|
|||||||
MCP_DISCOVERY_TIMEOUT = 30 # Increased for slow servers
|
MCP_DISCOVERY_TIMEOUT = 30 # Increased for slow servers
|
||||||
MCP_MAX_RETRIES = 3
|
MCP_MAX_RETRIES = 3
|
||||||
|
|
||||||
|
_T = TypeVar("_T")
|
||||||
|
|
||||||
# Simple in-memory cache for MCP tool schemas (duration: 5 minutes)
|
# Simple in-memory cache for MCP tool schemas (duration: 5 minutes)
|
||||||
_mcp_schema_cache: dict[str, tuple[dict[str, Any], float]] = {}
|
_mcp_schema_cache: dict[str, tuple[list[dict[str, Any]], float]] = {}
|
||||||
_cache_ttl = 300 # 5 minutes
|
_cache_ttl = 300 # 5 minutes
|
||||||
|
|
||||||
|
|
||||||
@@ -134,11 +135,7 @@ class MCPClient:
|
|||||||
else:
|
else:
|
||||||
server_name = "Unknown MCP Server"
|
server_name = "Unknown MCP Server"
|
||||||
server_url = None
|
server_url = None
|
||||||
transport_type = (
|
transport_type = self.transport.transport_type.value
|
||||||
self.transport.transport_type.value
|
|
||||||
if hasattr(self.transport, "transport_type")
|
|
||||||
else None
|
|
||||||
)
|
|
||||||
|
|
||||||
return server_name, server_url, transport_type
|
return server_name, server_url, transport_type
|
||||||
|
|
||||||
@@ -542,7 +539,7 @@ class MCPClient:
|
|||||||
Returns:
|
Returns:
|
||||||
Cleaned arguments ready for MCP server.
|
Cleaned arguments ready for MCP server.
|
||||||
"""
|
"""
|
||||||
cleaned = {}
|
cleaned: dict[str, Any] = {}
|
||||||
|
|
||||||
for key, value in arguments.items():
|
for key, value in arguments.items():
|
||||||
# Skip None values
|
# Skip None values
|
||||||
@@ -686,9 +683,9 @@ class MCPClient:
|
|||||||
|
|
||||||
async def _retry_operation(
|
async def _retry_operation(
|
||||||
self,
|
self,
|
||||||
operation: Callable[[], Any],
|
operation: Callable[[], Coroutine[Any, Any, _T]],
|
||||||
timeout: int | None = None,
|
timeout: int | None = None,
|
||||||
) -> Any:
|
) -> _T:
|
||||||
"""Retry an operation with exponential backoff.
|
"""Retry an operation with exponential backoff.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
|||||||
@@ -23,6 +23,7 @@ from crewai.mcp.config import (
|
|||||||
MCPServerSSE,
|
MCPServerSSE,
|
||||||
MCPServerStdio,
|
MCPServerStdio,
|
||||||
)
|
)
|
||||||
|
from crewai.mcp.transports.base import BaseTransport
|
||||||
from crewai.mcp.transports.http import HTTPTransport
|
from crewai.mcp.transports.http import HTTPTransport
|
||||||
from crewai.mcp.transports.sse import SSETransport
|
from crewai.mcp.transports.sse import SSETransport
|
||||||
from crewai.mcp.transports.stdio import StdioTransport
|
from crewai.mcp.transports.stdio import StdioTransport
|
||||||
@@ -285,6 +286,7 @@ class MCPToolResolver:
|
|||||||
independent transport so that parallel tool executions never share
|
independent transport so that parallel tool executions never share
|
||||||
state.
|
state.
|
||||||
"""
|
"""
|
||||||
|
transport: BaseTransport
|
||||||
if isinstance(mcp_config, MCPServerStdio):
|
if isinstance(mcp_config, MCPServerStdio):
|
||||||
transport = StdioTransport(
|
transport = StdioTransport(
|
||||||
command=mcp_config.command,
|
command=mcp_config.command,
|
||||||
|
|||||||
@@ -2,11 +2,17 @@
|
|||||||
|
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Any, Protocol
|
from typing import Any
|
||||||
|
|
||||||
|
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
|
||||||
|
from mcp.shared.message import SessionMessage
|
||||||
from typing_extensions import Self
|
from typing_extensions import Self
|
||||||
|
|
||||||
|
|
||||||
|
MCPReadStream = MemoryObjectReceiveStream[SessionMessage | Exception]
|
||||||
|
MCPWriteStream = MemoryObjectSendStream[SessionMessage]
|
||||||
|
|
||||||
|
|
||||||
class TransportType(str, Enum):
|
class TransportType(str, Enum):
|
||||||
"""MCP transport types."""
|
"""MCP transport types."""
|
||||||
|
|
||||||
@@ -16,22 +22,6 @@ class TransportType(str, Enum):
|
|||||||
SSE = "sse"
|
SSE = "sse"
|
||||||
|
|
||||||
|
|
||||||
class ReadStream(Protocol):
|
|
||||||
"""Protocol for read streams."""
|
|
||||||
|
|
||||||
async def read(self, n: int = -1) -> bytes:
|
|
||||||
"""Read bytes from stream."""
|
|
||||||
...
|
|
||||||
|
|
||||||
|
|
||||||
class WriteStream(Protocol):
|
|
||||||
"""Protocol for write streams."""
|
|
||||||
|
|
||||||
async def write(self, data: bytes) -> None:
|
|
||||||
"""Write bytes to stream."""
|
|
||||||
...
|
|
||||||
|
|
||||||
|
|
||||||
class BaseTransport(ABC):
|
class BaseTransport(ABC):
|
||||||
"""Base class for MCP transport implementations.
|
"""Base class for MCP transport implementations.
|
||||||
|
|
||||||
@@ -46,8 +36,8 @@ class BaseTransport(ABC):
|
|||||||
Args:
|
Args:
|
||||||
**kwargs: Transport-specific configuration options.
|
**kwargs: Transport-specific configuration options.
|
||||||
"""
|
"""
|
||||||
self._read_stream: ReadStream | None = None
|
self._read_stream: MCPReadStream | None = None
|
||||||
self._write_stream: WriteStream | None = None
|
self._write_stream: MCPWriteStream | None = None
|
||||||
self._connected = False
|
self._connected = False
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@@ -62,14 +52,14 @@ class BaseTransport(ABC):
|
|||||||
return self._connected
|
return self._connected
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def read_stream(self) -> ReadStream:
|
def read_stream(self) -> MCPReadStream:
|
||||||
"""Get the read stream."""
|
"""Get the read stream."""
|
||||||
if self._read_stream is None:
|
if self._read_stream is None:
|
||||||
raise RuntimeError("Transport not connected. Call connect() first.")
|
raise RuntimeError("Transport not connected. Call connect() first.")
|
||||||
return self._read_stream
|
return self._read_stream
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def write_stream(self) -> WriteStream:
|
def write_stream(self) -> MCPWriteStream:
|
||||||
"""Get the write stream."""
|
"""Get the write stream."""
|
||||||
if self._write_stream is None:
|
if self._write_stream is None:
|
||||||
raise RuntimeError("Transport not connected. Call connect() first.")
|
raise RuntimeError("Transport not connected. Call connect() first.")
|
||||||
@@ -107,7 +97,7 @@ class BaseTransport(ABC):
|
|||||||
"""Async context manager exit."""
|
"""Async context manager exit."""
|
||||||
...
|
...
|
||||||
|
|
||||||
def _set_streams(self, read: ReadStream, write: WriteStream) -> None:
|
def _set_streams(self, read: MCPReadStream, write: MCPWriteStream) -> None:
|
||||||
"""Set the read and write streams.
|
"""Set the read and write streams.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
|||||||
@@ -1,17 +1,16 @@
|
|||||||
"""HTTP and Streamable HTTP transport for MCP servers."""
|
"""HTTP and Streamable HTTP transport for MCP servers."""
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import sys
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from typing_extensions import Self
|
from typing_extensions import Self
|
||||||
|
|
||||||
|
|
||||||
# BaseExceptionGroup is available in Python 3.11+
|
if sys.version_info >= (3, 11):
|
||||||
try:
|
|
||||||
from builtins import BaseExceptionGroup
|
from builtins import BaseExceptionGroup
|
||||||
except ImportError:
|
else:
|
||||||
# Fallback for Python < 3.11 (shouldn't happen in practice)
|
from exceptiongroup import BaseExceptionGroup
|
||||||
BaseExceptionGroup = Exception
|
|
||||||
|
|
||||||
from crewai.mcp.transports.base import BaseTransport, TransportType
|
from crewai.mcp.transports.base import BaseTransport, TransportType
|
||||||
|
|
||||||
|
|||||||
@@ -122,11 +122,14 @@ class StdioTransport(BaseTransport):
|
|||||||
if self._process is not None:
|
if self._process is not None:
|
||||||
try:
|
try:
|
||||||
self._process.terminate()
|
self._process.terminate()
|
||||||
|
loop = asyncio.get_running_loop()
|
||||||
try:
|
try:
|
||||||
await asyncio.wait_for(self._process.wait(), timeout=5.0)
|
await asyncio.wait_for(
|
||||||
|
loop.run_in_executor(None, self._process.wait), timeout=5.0
|
||||||
|
)
|
||||||
except asyncio.TimeoutError:
|
except asyncio.TimeoutError:
|
||||||
self._process.kill()
|
self._process.kill()
|
||||||
await self._process.wait()
|
await loop.run_in_executor(None, self._process.wait)
|
||||||
# except ProcessLookupError:
|
# except ProcessLookupError:
|
||||||
# pass
|
# pass
|
||||||
finally:
|
finally:
|
||||||
|
|||||||
@@ -52,7 +52,7 @@ class ChromaDBClient(BaseClient):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
client: ChromaDBClientType,
|
client: ChromaDBClientType,
|
||||||
embedding_function: ChromaEmbeddingFunction,
|
embedding_function: ChromaEmbeddingFunction, # type: ignore[type-arg]
|
||||||
default_limit: int = 5,
|
default_limit: int = 5,
|
||||||
default_score_threshold: float = 0.6,
|
default_score_threshold: float = 0.6,
|
||||||
default_batch_size: int = 100,
|
default_batch_size: int = 100,
|
||||||
|
|||||||
@@ -23,7 +23,7 @@ from crewai.rag.core.base_client import BaseCollectionParams, BaseCollectionSear
|
|||||||
ChromaDBClientType = ClientAPI | AsyncClientAPI
|
ChromaDBClientType = ClientAPI | AsyncClientAPI
|
||||||
|
|
||||||
|
|
||||||
class ChromaEmbeddingFunctionWrapper(ChromaEmbeddingFunction):
|
class ChromaEmbeddingFunctionWrapper(ChromaEmbeddingFunction): # type: ignore[type-arg]
|
||||||
"""Base class for ChromaDB EmbeddingFunction to work with Pydantic validation."""
|
"""Base class for ChromaDB EmbeddingFunction to work with Pydantic validation."""
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@@ -85,7 +85,7 @@ class ChromaDBCollectionCreateParams(BaseCollectionParams, total=False):
|
|||||||
|
|
||||||
configuration: CollectionConfigurationInterface
|
configuration: CollectionConfigurationInterface
|
||||||
metadata: CollectionMetadata
|
metadata: CollectionMetadata
|
||||||
embedding_function: ChromaEmbeddingFunction
|
embedding_function: ChromaEmbeddingFunction # type: ignore[type-arg]
|
||||||
data_loader: DataLoader[Loadable]
|
data_loader: DataLoader[Loadable]
|
||||||
get_or_create: bool
|
get_or_create: bool
|
||||||
|
|
||||||
|
|||||||
@@ -8,7 +8,7 @@ from pydantic_settings import BaseSettings, SettingsConfigDict
|
|||||||
from crewai.rag.core.base_embeddings_callable import EmbeddingFunction
|
from crewai.rag.core.base_embeddings_callable import EmbeddingFunction
|
||||||
|
|
||||||
|
|
||||||
T = TypeVar("T", bound=EmbeddingFunction)
|
T = TypeVar("T", bound=EmbeddingFunction) # type: ignore[type-arg]
|
||||||
|
|
||||||
|
|
||||||
class BaseEmbeddingsProvider(BaseSettings, Generic[T]):
|
class BaseEmbeddingsProvider(BaseSettings, Generic[T]):
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
"""Core type definitions for RAG systems."""
|
"""Core type definitions for RAG systems."""
|
||||||
|
|
||||||
from collections.abc import Sequence
|
from collections.abc import Sequence
|
||||||
from typing import TypeVar
|
from typing import Any, TypeVar
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from numpy import floating, integer, number
|
from numpy import floating, integer, number
|
||||||
@@ -16,7 +16,7 @@ Embedding = NDArray[np.int32 | np.float32]
|
|||||||
Embeddings = list[Embedding]
|
Embeddings = list[Embedding]
|
||||||
|
|
||||||
Documents = list[str]
|
Documents = list[str]
|
||||||
Images = list[np.ndarray]
|
Images = list[np.ndarray[Any, np.dtype[np.generic]]]
|
||||||
Embeddable = Documents | Images
|
Embeddable = Documents | Images
|
||||||
|
|
||||||
ScalarType = TypeVar("ScalarType", bound=np.generic)
|
ScalarType = TypeVar("ScalarType", bound=np.generic)
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ from typing_extensions import Required, TypedDict
|
|||||||
class CustomProviderConfig(TypedDict, total=False):
|
class CustomProviderConfig(TypedDict, total=False):
|
||||||
"""Configuration for Custom provider."""
|
"""Configuration for Custom provider."""
|
||||||
|
|
||||||
embedding_callable: type[EmbeddingFunction]
|
embedding_callable: type[EmbeddingFunction] # type: ignore[type-arg]
|
||||||
|
|
||||||
|
|
||||||
class CustomProviderSpec(TypedDict, total=False):
|
class CustomProviderSpec(TypedDict, total=False):
|
||||||
|
|||||||
@@ -85,7 +85,7 @@ class GoogleGenAIVertexEmbeddingFunction(EmbeddingFunction[Documents]):
|
|||||||
- output_dimensionality: Optional output embedding dimension (new SDK only)
|
- output_dimensionality: Optional output embedding dimension (new SDK only)
|
||||||
"""
|
"""
|
||||||
# Handle deprecated 'region' parameter (only if it has a value)
|
# Handle deprecated 'region' parameter (only if it has a value)
|
||||||
region_value = kwargs.pop("region", None) # type: ignore[typeddict-item]
|
region_value = kwargs.pop("region", None) # type: ignore[typeddict-item,unused-ignore]
|
||||||
if region_value is not None:
|
if region_value is not None:
|
||||||
warnings.warn(
|
warnings.warn(
|
||||||
"The 'region' parameter is deprecated, use 'location' instead. "
|
"The 'region' parameter is deprecated, use 'location' instead. "
|
||||||
@@ -94,7 +94,7 @@ class GoogleGenAIVertexEmbeddingFunction(EmbeddingFunction[Documents]):
|
|||||||
stacklevel=2,
|
stacklevel=2,
|
||||||
)
|
)
|
||||||
if "location" not in kwargs or kwargs.get("location") is None:
|
if "location" not in kwargs or kwargs.get("location") is None:
|
||||||
kwargs["location"] = region_value # type: ignore[typeddict-unknown-key]
|
kwargs["location"] = region_value # type: ignore[typeddict-unknown-key,unused-ignore]
|
||||||
|
|
||||||
self._config = kwargs
|
self._config = kwargs
|
||||||
self._model_name = str(kwargs.get("model_name", "textembedding-gecko"))
|
self._model_name = str(kwargs.get("model_name", "textembedding-gecko"))
|
||||||
@@ -123,8 +123,10 @@ class GoogleGenAIVertexEmbeddingFunction(EmbeddingFunction[Documents]):
|
|||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import vertexai
|
import vertexai # type: ignore[import-not-found]
|
||||||
from vertexai.language_models import TextEmbeddingModel
|
from vertexai.language_models import ( # type: ignore[import-not-found]
|
||||||
|
TextEmbeddingModel,
|
||||||
|
)
|
||||||
except ImportError as e:
|
except ImportError as e:
|
||||||
raise ImportError(
|
raise ImportError(
|
||||||
"vertexai is required for legacy embedding models (textembedding-gecko*). "
|
"vertexai is required for legacy embedding models (textembedding-gecko*). "
|
||||||
|
|||||||
@@ -18,7 +18,7 @@ class VoyageAIEmbeddingFunction(EmbeddingFunction[Documents]):
|
|||||||
**kwargs: Configuration parameters for VoyageAI.
|
**kwargs: Configuration parameters for VoyageAI.
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
import voyageai # type: ignore[import-not-found]
|
import voyageai
|
||||||
|
|
||||||
except ImportError as e:
|
except ImportError as e:
|
||||||
raise ImportError(
|
raise ImportError(
|
||||||
@@ -26,7 +26,7 @@ class VoyageAIEmbeddingFunction(EmbeddingFunction[Documents]):
|
|||||||
"Install it with: uv add voyageai"
|
"Install it with: uv add voyageai"
|
||||||
) from e
|
) from e
|
||||||
self._config = kwargs
|
self._config = kwargs
|
||||||
self._client = voyageai.Client(
|
self._client = voyageai.Client( # type: ignore[attr-defined]
|
||||||
api_key=kwargs["api_key"],
|
api_key=kwargs["api_key"],
|
||||||
max_retries=kwargs.get("max_retries", 0),
|
max_retries=kwargs.get("max_retries", 0),
|
||||||
timeout=kwargs.get("timeout"),
|
timeout=kwargs.get("timeout"),
|
||||||
|
|||||||
@@ -311,8 +311,7 @@ class QdrantClient(BaseClient):
|
|||||||
points = []
|
points = []
|
||||||
for doc in batch_docs:
|
for doc in batch_docs:
|
||||||
if _is_async_embedding_function(self.embedding_function):
|
if _is_async_embedding_function(self.embedding_function):
|
||||||
async_fn = cast(AsyncEmbeddingFunction, self.embedding_function)
|
embedding = await self.embedding_function(doc["content"])
|
||||||
embedding = await async_fn(doc["content"])
|
|
||||||
else:
|
else:
|
||||||
sync_fn = cast(EmbeddingFunction, self.embedding_function)
|
sync_fn = cast(EmbeddingFunction, self.embedding_function)
|
||||||
embedding = sync_fn(doc["content"])
|
embedding = sync_fn(doc["content"])
|
||||||
@@ -412,8 +411,7 @@ class QdrantClient(BaseClient):
|
|||||||
raise ValueError(f"Collection '{collection_name}' does not exist")
|
raise ValueError(f"Collection '{collection_name}' does not exist")
|
||||||
|
|
||||||
if _is_async_embedding_function(self.embedding_function):
|
if _is_async_embedding_function(self.embedding_function):
|
||||||
async_fn = cast(AsyncEmbeddingFunction, self.embedding_function)
|
query_embedding = await self.embedding_function(query)
|
||||||
query_embedding = await async_fn(query)
|
|
||||||
else:
|
else:
|
||||||
sync_fn = cast(EmbeddingFunction, self.embedding_function)
|
sync_fn = cast(EmbeddingFunction, self.embedding_function)
|
||||||
query_embedding = sync_fn(query)
|
query_embedding = sync_fn(query)
|
||||||
|
|||||||
@@ -7,10 +7,10 @@ import numpy as np
|
|||||||
from pydantic import GetCoreSchemaHandler
|
from pydantic import GetCoreSchemaHandler
|
||||||
from pydantic_core import CoreSchema, core_schema
|
from pydantic_core import CoreSchema, core_schema
|
||||||
from qdrant_client import (
|
from qdrant_client import (
|
||||||
AsyncQdrantClient, # type: ignore[import-not-found]
|
AsyncQdrantClient,
|
||||||
QdrantClient as SyncQdrantClient, # type: ignore[import-not-found]
|
QdrantClient as SyncQdrantClient,
|
||||||
)
|
)
|
||||||
from qdrant_client.models import ( # type: ignore[import-not-found]
|
from qdrant_client.models import (
|
||||||
FieldCondition,
|
FieldCondition,
|
||||||
Filter,
|
Filter,
|
||||||
HasIdCondition,
|
HasIdCondition,
|
||||||
|
|||||||
@@ -5,10 +5,10 @@ from typing import TypeGuard
|
|||||||
from uuid import uuid4
|
from uuid import uuid4
|
||||||
|
|
||||||
from qdrant_client import (
|
from qdrant_client import (
|
||||||
AsyncQdrantClient, # type: ignore[import-not-found]
|
AsyncQdrantClient,
|
||||||
QdrantClient as SyncQdrantClient, # type: ignore[import-not-found]
|
QdrantClient as SyncQdrantClient,
|
||||||
)
|
)
|
||||||
from qdrant_client.models import ( # type: ignore[import-not-found]
|
from qdrant_client.models import (
|
||||||
FieldCondition,
|
FieldCondition,
|
||||||
Filter,
|
Filter,
|
||||||
MatchValue,
|
MatchValue,
|
||||||
|
|||||||
@@ -16,7 +16,7 @@ class BaseRAGStorage(ABC):
|
|||||||
self,
|
self,
|
||||||
type: str,
|
type: str,
|
||||||
allow_reset: bool = True,
|
allow_reset: bool = True,
|
||||||
embedder_config: ProviderSpec | BaseEmbeddingsProvider | None = None,
|
embedder_config: ProviderSpec | BaseEmbeddingsProvider[Any] | None = None,
|
||||||
crew: Any = None,
|
crew: Any = None,
|
||||||
):
|
):
|
||||||
self.type = type
|
self.type = type
|
||||||
|
|||||||
@@ -580,7 +580,7 @@ class Task(BaseModel):
|
|||||||
tools = tools or self.tools or []
|
tools = tools or self.tools or []
|
||||||
|
|
||||||
self.processed_by_agents.add(agent.role)
|
self.processed_by_agents.add(agent.role)
|
||||||
crewai_event_bus.emit(self, TaskStartedEvent(context=context, task=self)) # type: ignore[no-untyped-call]
|
crewai_event_bus.emit(self, TaskStartedEvent(context=context, task=self))
|
||||||
result = await agent.aexecute_task(
|
result = await agent.aexecute_task(
|
||||||
task=self,
|
task=self,
|
||||||
context=context,
|
context=context,
|
||||||
@@ -662,12 +662,12 @@ class Task(BaseModel):
|
|||||||
self._save_file(content)
|
self._save_file(content)
|
||||||
crewai_event_bus.emit(
|
crewai_event_bus.emit(
|
||||||
self,
|
self,
|
||||||
TaskCompletedEvent(output=task_output, task=self), # type: ignore[no-untyped-call]
|
TaskCompletedEvent(output=task_output, task=self),
|
||||||
)
|
)
|
||||||
return task_output
|
return task_output
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.end_time = datetime.datetime.now()
|
self.end_time = datetime.datetime.now()
|
||||||
crewai_event_bus.emit(self, TaskFailedEvent(error=str(e), task=self)) # type: ignore[no-untyped-call]
|
crewai_event_bus.emit(self, TaskFailedEvent(error=str(e), task=self))
|
||||||
raise e # Re-raise the exception after emitting the event
|
raise e # Re-raise the exception after emitting the event
|
||||||
finally:
|
finally:
|
||||||
clear_task_files(self.id)
|
clear_task_files(self.id)
|
||||||
@@ -694,7 +694,7 @@ class Task(BaseModel):
|
|||||||
tools = tools or self.tools or []
|
tools = tools or self.tools or []
|
||||||
|
|
||||||
self.processed_by_agents.add(agent.role)
|
self.processed_by_agents.add(agent.role)
|
||||||
crewai_event_bus.emit(self, TaskStartedEvent(context=context, task=self)) # type: ignore[no-untyped-call]
|
crewai_event_bus.emit(self, TaskStartedEvent(context=context, task=self))
|
||||||
result = agent.execute_task(
|
result = agent.execute_task(
|
||||||
task=self,
|
task=self,
|
||||||
context=context,
|
context=context,
|
||||||
@@ -777,12 +777,12 @@ class Task(BaseModel):
|
|||||||
self._save_file(content)
|
self._save_file(content)
|
||||||
crewai_event_bus.emit(
|
crewai_event_bus.emit(
|
||||||
self,
|
self,
|
||||||
TaskCompletedEvent(output=task_output, task=self), # type: ignore[no-untyped-call]
|
TaskCompletedEvent(output=task_output, task=self),
|
||||||
)
|
)
|
||||||
return task_output
|
return task_output
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.end_time = datetime.datetime.now()
|
self.end_time = datetime.datetime.now()
|
||||||
crewai_event_bus.emit(self, TaskFailedEvent(error=str(e), task=self)) # type: ignore[no-untyped-call]
|
crewai_event_bus.emit(self, TaskFailedEvent(error=str(e), task=self))
|
||||||
raise e # Re-raise the exception after emitting the event
|
raise e # Re-raise the exception after emitting the event
|
||||||
finally:
|
finally:
|
||||||
clear_task_files(self.id)
|
clear_task_files(self.id)
|
||||||
|
|||||||
@@ -32,8 +32,8 @@ class ConditionalTask(Task):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
condition: Callable[[Any], bool] | None = None,
|
condition: Callable[[Any], bool] | None = None,
|
||||||
**kwargs,
|
**kwargs: Any,
|
||||||
):
|
) -> None:
|
||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
self.condition = condition
|
self.condition = condition
|
||||||
|
|
||||||
|
|||||||
@@ -1,5 +1,7 @@
|
|||||||
"""Task output representation and formatting."""
|
"""Task output representation and formatting."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import json
|
import json
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
@@ -44,7 +46,7 @@ class TaskOutput(BaseModel):
|
|||||||
messages: list[LLMMessage] = Field(description="Messages of the task", default=[])
|
messages: list[LLMMessage] = Field(description="Messages of the task", default=[])
|
||||||
|
|
||||||
@model_validator(mode="after")
|
@model_validator(mode="after")
|
||||||
def set_summary(self):
|
def set_summary(self) -> TaskOutput:
|
||||||
"""Set the summary field based on the description.
|
"""Set the summary field based on the description.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
|
|||||||
@@ -1,3 +1,5 @@
|
|||||||
|
from typing import Any
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
from crewai.tools.base_tool import BaseTool
|
from crewai.tools.base_tool import BaseTool
|
||||||
@@ -27,8 +29,8 @@ class AddImageTool(BaseTool):
|
|||||||
self,
|
self,
|
||||||
image_url: str,
|
image_url: str,
|
||||||
action: str | None = None,
|
action: str | None = None,
|
||||||
**kwargs,
|
**kwargs: Any,
|
||||||
) -> dict:
|
) -> dict[str, Any]:
|
||||||
action = action or i18n.tools("add_image")["default_action"] # type: ignore
|
action = action or i18n.tools("add_image")["default_action"] # type: ignore
|
||||||
content = [
|
content = [
|
||||||
{"type": "text", "text": action},
|
{"type": "text", "text": action},
|
||||||
|
|||||||
@@ -1,3 +1,5 @@
|
|||||||
|
from typing import Any
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
from crewai.tools.agent_tools.base_agent_tools import BaseAgentTool
|
from crewai.tools.agent_tools.base_agent_tools import BaseAgentTool
|
||||||
@@ -20,7 +22,7 @@ class AskQuestionTool(BaseAgentTool):
|
|||||||
question: str,
|
question: str,
|
||||||
context: str,
|
context: str,
|
||||||
coworker: str | None = None,
|
coworker: str | None = None,
|
||||||
**kwargs,
|
**kwargs: Any,
|
||||||
) -> str:
|
) -> str:
|
||||||
coworker = self._get_coworker(coworker, **kwargs)
|
coworker = self._get_coworker(coworker, **kwargs)
|
||||||
return self._execute(coworker, question, context)
|
return self._execute(coworker, question, context)
|
||||||
|
|||||||
@@ -1,3 +1,5 @@
|
|||||||
|
from typing import Any
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
from crewai.tools.agent_tools.base_agent_tools import BaseAgentTool
|
from crewai.tools.agent_tools.base_agent_tools import BaseAgentTool
|
||||||
@@ -22,7 +24,7 @@ class DelegateWorkTool(BaseAgentTool):
|
|||||||
task: str,
|
task: str,
|
||||||
context: str,
|
context: str,
|
||||||
coworker: str | None = None,
|
coworker: str | None = None,
|
||||||
**kwargs,
|
**kwargs: Any,
|
||||||
) -> str:
|
) -> str:
|
||||||
coworker = self._get_coworker(coworker, **kwargs)
|
coworker = self._get_coworker(coworker, **kwargs)
|
||||||
return self._execute(coworker, task, context)
|
return self._execute(coworker, task, context)
|
||||||
|
|||||||
@@ -70,7 +70,7 @@ class MCPNativeTool(BaseTool):
|
|||||||
"""Get the server name."""
|
"""Get the server name."""
|
||||||
return self._server_name
|
return self._server_name
|
||||||
|
|
||||||
def _run(self, **kwargs) -> str:
|
def _run(self, **kwargs: Any) -> str:
|
||||||
"""Execute tool using the MCP client session.
|
"""Execute tool using the MCP client session.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -98,7 +98,7 @@ class MCPNativeTool(BaseTool):
|
|||||||
f"Error executing MCP tool {self.original_tool_name}: {e!s}"
|
f"Error executing MCP tool {self.original_tool_name}: {e!s}"
|
||||||
) from e
|
) from e
|
||||||
|
|
||||||
async def _run_async(self, **kwargs) -> str:
|
async def _run_async(self, **kwargs: Any) -> str:
|
||||||
"""Async implementation of tool execution.
|
"""Async implementation of tool execution.
|
||||||
|
|
||||||
A fresh ``MCPClient`` is created for every invocation so that
|
A fresh ``MCPClient`` is created for every invocation so that
|
||||||
|
|||||||
@@ -1,6 +1,8 @@
|
|||||||
"""MCP Tool Wrapper for on-demand MCP server connections."""
|
"""MCP Tool Wrapper for on-demand MCP server connections."""
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
from collections.abc import Callable, Coroutine
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
from crewai.tools import BaseTool
|
from crewai.tools import BaseTool
|
||||||
|
|
||||||
@@ -16,9 +18,9 @@ class MCPToolWrapper(BaseTool):
|
|||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
mcp_server_params: dict,
|
mcp_server_params: dict[str, Any],
|
||||||
tool_name: str,
|
tool_name: str,
|
||||||
tool_schema: dict,
|
tool_schema: dict[str, Any],
|
||||||
server_name: str,
|
server_name: str,
|
||||||
):
|
):
|
||||||
"""Initialize the MCP tool wrapper.
|
"""Initialize the MCP tool wrapper.
|
||||||
@@ -54,7 +56,7 @@ class MCPToolWrapper(BaseTool):
|
|||||||
self._server_name = server_name
|
self._server_name = server_name
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def mcp_server_params(self) -> dict:
|
def mcp_server_params(self) -> dict[str, Any]:
|
||||||
"""Get the MCP server parameters."""
|
"""Get the MCP server parameters."""
|
||||||
return self._mcp_server_params
|
return self._mcp_server_params
|
||||||
|
|
||||||
@@ -68,7 +70,7 @@ class MCPToolWrapper(BaseTool):
|
|||||||
"""Get the server name."""
|
"""Get the server name."""
|
||||||
return self._server_name
|
return self._server_name
|
||||||
|
|
||||||
def _run(self, **kwargs) -> str:
|
def _run(self, **kwargs: Any) -> str:
|
||||||
"""Connect to MCP server and execute tool.
|
"""Connect to MCP server and execute tool.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -84,13 +86,15 @@ class MCPToolWrapper(BaseTool):
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
return f"Error executing MCP tool {self.original_tool_name}: {e!s}"
|
return f"Error executing MCP tool {self.original_tool_name}: {e!s}"
|
||||||
|
|
||||||
async def _run_async(self, **kwargs) -> str:
|
async def _run_async(self, **kwargs: Any) -> str:
|
||||||
"""Async implementation of MCP tool execution with timeouts and retry logic."""
|
"""Async implementation of MCP tool execution with timeouts and retry logic."""
|
||||||
return await self._retry_with_exponential_backoff(
|
return await self._retry_with_exponential_backoff(
|
||||||
self._execute_tool_with_timeout, **kwargs
|
self._execute_tool_with_timeout, **kwargs
|
||||||
)
|
)
|
||||||
|
|
||||||
async def _retry_with_exponential_backoff(self, operation_func, **kwargs) -> str:
|
async def _retry_with_exponential_backoff(
|
||||||
|
self, operation_func: Callable[..., Coroutine[Any, Any, str]], **kwargs: Any
|
||||||
|
) -> str:
|
||||||
"""Retry operation with exponential backoff, avoiding try-except in loop for performance."""
|
"""Retry operation with exponential backoff, avoiding try-except in loop for performance."""
|
||||||
last_error = None
|
last_error = None
|
||||||
|
|
||||||
@@ -119,7 +123,7 @@ class MCPToolWrapper(BaseTool):
|
|||||||
)
|
)
|
||||||
|
|
||||||
async def _execute_single_attempt(
|
async def _execute_single_attempt(
|
||||||
self, operation_func, **kwargs
|
self, operation_func: Callable[..., Coroutine[Any, Any, str]], **kwargs: Any
|
||||||
) -> tuple[str | None, str, bool]:
|
) -> tuple[str | None, str, bool]:
|
||||||
"""Execute single operation attempt and return (result, error_message, should_retry)."""
|
"""Execute single operation attempt and return (result, error_message, should_retry)."""
|
||||||
try:
|
try:
|
||||||
@@ -158,22 +162,23 @@ class MCPToolWrapper(BaseTool):
|
|||||||
return None, f"Server response parsing error: {e!s}", True
|
return None, f"Server response parsing error: {e!s}", True
|
||||||
return None, f"MCP execution error: {e!s}", False
|
return None, f"MCP execution error: {e!s}", False
|
||||||
|
|
||||||
async def _execute_tool_with_timeout(self, **kwargs) -> str:
|
async def _execute_tool_with_timeout(self, **kwargs: Any) -> str:
|
||||||
"""Execute tool with timeout wrapper."""
|
"""Execute tool with timeout wrapper."""
|
||||||
return await asyncio.wait_for(
|
return await asyncio.wait_for(
|
||||||
self._execute_tool(**kwargs), timeout=MCP_TOOL_EXECUTION_TIMEOUT
|
self._execute_tool(**kwargs), timeout=MCP_TOOL_EXECUTION_TIMEOUT
|
||||||
)
|
)
|
||||||
|
|
||||||
async def _execute_tool(self, **kwargs) -> str:
|
async def _execute_tool(self, **kwargs: Any) -> str:
|
||||||
"""Execute the actual MCP tool call."""
|
"""Execute the actual MCP tool call."""
|
||||||
from mcp import ClientSession
|
from mcp import ClientSession
|
||||||
from mcp.client.streamable_http import streamablehttp_client
|
from mcp.client.streamable_http import streamablehttp_client
|
||||||
|
from mcp.types import TextContent
|
||||||
|
|
||||||
server_url = self.mcp_server_params["url"]
|
server_url = self.mcp_server_params["url"]
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Wrap entire operation with single timeout
|
|
||||||
async def _do_mcp_call():
|
async def _do_mcp_call() -> str:
|
||||||
async with streamablehttp_client(
|
async with streamablehttp_client(
|
||||||
server_url, terminate_on_close=True
|
server_url, terminate_on_close=True
|
||||||
) as (read, write, _):
|
) as (read, write, _):
|
||||||
@@ -183,17 +188,11 @@ class MCPToolWrapper(BaseTool):
|
|||||||
self.original_tool_name, kwargs
|
self.original_tool_name, kwargs
|
||||||
)
|
)
|
||||||
|
|
||||||
# Extract the result content
|
if result.content:
|
||||||
if hasattr(result, "content") and result.content:
|
content_item = result.content[0]
|
||||||
if (
|
if isinstance(content_item, TextContent):
|
||||||
isinstance(result.content, list)
|
return content_item.text
|
||||||
and len(result.content) > 0
|
return str(content_item)
|
||||||
):
|
|
||||||
content_item = result.content[0]
|
|
||||||
if hasattr(content_item, "text"):
|
|
||||||
return str(content_item.text)
|
|
||||||
return str(content_item)
|
|
||||||
return str(result.content)
|
|
||||||
return str(result)
|
return str(result)
|
||||||
|
|
||||||
return await asyncio.wait_for(
|
return await asyncio.wait_for(
|
||||||
@@ -203,7 +202,7 @@ class MCPToolWrapper(BaseTool):
|
|||||||
except asyncio.CancelledError as e:
|
except asyncio.CancelledError as e:
|
||||||
raise asyncio.TimeoutError("MCP operation was cancelled") from e
|
raise asyncio.TimeoutError("MCP operation was cancelled") from e
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
if hasattr(e, "__cause__") and e.__cause__:
|
if e.__cause__ is not None:
|
||||||
raise asyncio.TimeoutError(
|
raise asyncio.TimeoutError(
|
||||||
f"MCP connection error: {e.__cause__}"
|
f"MCP connection error: {e.__cause__}"
|
||||||
) from e.__cause__
|
) from e.__cause__
|
||||||
|
|||||||
@@ -81,7 +81,7 @@ class TaskEvaluator:
|
|||||||
"""
|
"""
|
||||||
crewai_event_bus.emit(
|
crewai_event_bus.emit(
|
||||||
self,
|
self,
|
||||||
TaskEvaluationEvent(evaluation_type="task_evaluation", task=task), # type: ignore[no-untyped-call]
|
TaskEvaluationEvent(evaluation_type="task_evaluation", task=task),
|
||||||
)
|
)
|
||||||
evaluation_query = (
|
evaluation_query = (
|
||||||
f"Assess the quality of the task completed based on the description, expected output, and actual results.\n\n"
|
f"Assess the quality of the task completed based on the description, expected output, and actual results.\n\n"
|
||||||
@@ -129,7 +129,7 @@ class TaskEvaluator:
|
|||||||
"""
|
"""
|
||||||
crewai_event_bus.emit(
|
crewai_event_bus.emit(
|
||||||
self,
|
self,
|
||||||
TaskEvaluationEvent(evaluation_type="training_data_evaluation"), # type: ignore[no-untyped-call]
|
TaskEvaluationEvent(evaluation_type="training_data_evaluation"),
|
||||||
)
|
)
|
||||||
|
|
||||||
output_training_data = training_data[agent_id]
|
output_training_data = training_data[agent_id]
|
||||||
|
|||||||
@@ -12,16 +12,16 @@ from uuid import UUID
|
|||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from aiocache import Cache
|
from aiocache import Cache # type: ignore[import-untyped]
|
||||||
from crewai_files import FileInput
|
from crewai_files import FileInput
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
_file_store: Cache | None = None
|
_file_store: Cache | None = None # type: ignore[no-any-unimported]
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from aiocache import Cache
|
from aiocache import Cache
|
||||||
from aiocache.serializers import PickleSerializer
|
from aiocache.serializers import PickleSerializer # type: ignore[import-untyped]
|
||||||
|
|
||||||
_file_store = Cache(Cache.MEMORY, serializer=PickleSerializer())
|
_file_store = Cache(Cache.MEMORY, serializer=PickleSerializer())
|
||||||
except ImportError:
|
except ImportError:
|
||||||
|
|||||||
@@ -39,7 +39,7 @@ class GuardrailResult(BaseModel):
|
|||||||
|
|
||||||
@field_validator("result", "error")
|
@field_validator("result", "error")
|
||||||
@classmethod
|
@classmethod
|
||||||
def validate_result_error_exclusivity(cls, v: Any, info) -> Any:
|
def validate_result_error_exclusivity(cls, v: Any, info: Any) -> Any:
|
||||||
"""Ensure that result and error are mutually exclusive based on success.
|
"""Ensure that result and error are mutually exclusive based on success.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
|||||||
@@ -2,7 +2,7 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from typing import Literal
|
from typing import Any, Literal
|
||||||
from uuid import uuid4
|
from uuid import uuid4
|
||||||
|
|
||||||
from pydantic import BaseModel, Field, field_validator
|
from pydantic import BaseModel, Field, field_validator
|
||||||
@@ -259,7 +259,7 @@ class StepObservation(BaseModel):
|
|||||||
|
|
||||||
@field_validator("suggested_refinements", mode="before")
|
@field_validator("suggested_refinements", mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
def coerce_single_refinement_to_list(cls, v):
|
def coerce_single_refinement_to_list(cls, v: Any) -> Any:
|
||||||
"""Coerce a single dict refinement into a list to handle LLM returning a single object."""
|
"""Coerce a single dict refinement into a list to handle LLM returning a single object."""
|
||||||
if isinstance(v, dict):
|
if isinstance(v, dict):
|
||||||
return [v]
|
return [v]
|
||||||
|
|||||||
@@ -182,7 +182,7 @@ class AgentReasoning:
|
|||||||
if self.config.llm is not None:
|
if self.config.llm is not None:
|
||||||
if isinstance(self.config.llm, LLM):
|
if isinstance(self.config.llm, LLM):
|
||||||
return self.config.llm
|
return self.config.llm
|
||||||
return create_llm(self.config.llm)
|
return cast(LLM, create_llm(self.config.llm))
|
||||||
return cast(LLM, self.agent.llm)
|
return cast(LLM, self.agent.llm)
|
||||||
|
|
||||||
def handle_agent_reasoning(self) -> AgentReasoningOutput:
|
def handle_agent_reasoning(self) -> AgentReasoningOutput:
|
||||||
|
|||||||
@@ -75,7 +75,7 @@ class RPMController(BaseModel):
|
|||||||
self._current_rpm = 0
|
self._current_rpm = 0
|
||||||
|
|
||||||
def _reset_request_count(self) -> None:
|
def _reset_request_count(self) -> None:
|
||||||
def _reset():
|
def _reset() -> None:
|
||||||
self._current_rpm = 0
|
self._current_rpm = 0
|
||||||
if not self._shutdown_flag:
|
if not self._shutdown_flag:
|
||||||
self._timer = threading.Timer(60.0, self._reset_request_count)
|
self._timer = threading.Timer(60.0, self._reset_request_count)
|
||||||
|
|||||||
@@ -60,7 +60,9 @@ def _extract_tool_call_info(
|
|||||||
StreamChunkType.TOOL_CALL,
|
StreamChunkType.TOOL_CALL,
|
||||||
ToolCallChunk(
|
ToolCallChunk(
|
||||||
tool_id=event.tool_call.id,
|
tool_id=event.tool_call.id,
|
||||||
tool_name=sanitize_tool_name(event.tool_call.function.name),
|
tool_name=sanitize_tool_name(event.tool_call.function.name)
|
||||||
|
if event.tool_call.function.name
|
||||||
|
else None,
|
||||||
arguments=event.tool_call.function.arguments,
|
arguments=event.tool_call.function.arguments,
|
||||||
index=event.tool_call.index,
|
index=event.tool_call.index,
|
||||||
),
|
),
|
||||||
|
|||||||
@@ -76,7 +76,7 @@ Please provide ONLY the {field_name} field value as described:
|
|||||||
|
|
||||||
Respond with ONLY the requested information, nothing else.
|
Respond with ONLY the requested information, nothing else.
|
||||||
"""
|
"""
|
||||||
return self.llm.call(
|
result: str = self.llm.call(
|
||||||
[
|
[
|
||||||
{
|
{
|
||||||
"role": "system",
|
"role": "system",
|
||||||
@@ -85,6 +85,7 @@ Respond with ONLY the requested information, nothing else.
|
|||||||
{"role": "user", "content": prompt},
|
{"role": "user", "content": prompt},
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
return result
|
||||||
|
|
||||||
def _process_field_value(self, response: str, field_type: type | None) -> Any:
|
def _process_field_value(self, response: str, field_type: type | None) -> Any:
|
||||||
response = response.strip()
|
response = response.strip()
|
||||||
@@ -104,7 +105,8 @@ Respond with ONLY the requested information, nothing else.
|
|||||||
def _parse_list(self, response: str) -> list[Any]:
|
def _parse_list(self, response: str) -> list[Any]:
|
||||||
try:
|
try:
|
||||||
if response.startswith("["):
|
if response.startswith("["):
|
||||||
return json.loads(response)
|
parsed: list[Any] = json.loads(response)
|
||||||
|
return parsed
|
||||||
|
|
||||||
items: list[str] = [
|
items: list[str] = [
|
||||||
item.strip() for item in response.split("\n") if item.strip()
|
item.strip() for item in response.split("\n") if item.strip()
|
||||||
|
|||||||
@@ -1571,8 +1571,9 @@ class TestReasoningEffort:
|
|||||||
executor.agent.planning_config = None
|
executor.agent.planning_config = None
|
||||||
assert executor._get_reasoning_effort() == "medium"
|
assert executor._get_reasoning_effort() == "medium"
|
||||||
|
|
||||||
# Case 3: planning_config without reasoning_effort attr → defaults to "medium"
|
# Case 3: planning_config with default reasoning_effort
|
||||||
executor.agent.planning_config = Mock(spec=[])
|
executor.agent.planning_config = Mock()
|
||||||
|
executor.agent.planning_config.reasoning_effort = "medium"
|
||||||
assert executor._get_reasoning_effort() == "medium"
|
assert executor._get_reasoning_effort() == "medium"
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
2
uv.lock
generated
2
uv.lock
generated
@@ -1241,7 +1241,7 @@ requires-dist = [
|
|||||||
{ name = "json5", specifier = "~=0.10.0" },
|
{ name = "json5", specifier = "~=0.10.0" },
|
||||||
{ name = "jsonref", specifier = "~=1.1.0" },
|
{ name = "jsonref", specifier = "~=1.1.0" },
|
||||||
{ name = "lancedb", specifier = ">=0.29.2" },
|
{ name = "lancedb", specifier = ">=0.29.2" },
|
||||||
{ name = "litellm", marker = "extra == 'litellm'", specifier = ">=1.74.9,<3" },
|
{ name = "litellm", marker = "extra == 'litellm'", specifier = ">=1.74.9,<=1.82.6" },
|
||||||
{ name = "mcp", specifier = "~=1.26.0" },
|
{ name = "mcp", specifier = "~=1.26.0" },
|
||||||
{ name = "mem0ai", marker = "extra == 'mem0'", specifier = "~=0.1.94" },
|
{ name = "mem0ai", marker = "extra == 'mem0'", specifier = "~=0.1.94" },
|
||||||
{ name = "openai", specifier = ">=1.83.0,<3" },
|
{ name = "openai", specifier = ">=1.83.0,<3" },
|
||||||
|
|||||||
Reference in New Issue
Block a user