mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-03-25 05:08:22 +00:00
Compare commits
3 Commits
lorenze/fi
...
main
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
8a1424534e | ||
|
|
b53c08812d | ||
|
|
ec8d444cfc |
35
.github/workflows/type-checker.yml
vendored
35
.github/workflows/type-checker.yml
vendored
@@ -17,8 +17,6 @@ jobs:
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0 # Fetch all history for proper diff
|
||||
|
||||
- name: Restore global uv cache
|
||||
id: cache-restore
|
||||
@@ -42,37 +40,8 @@ jobs:
|
||||
- name: Install dependencies
|
||||
run: uv sync --all-groups --all-extras
|
||||
|
||||
- name: Get changed Python files
|
||||
id: changed-files
|
||||
run: |
|
||||
# Get the list of changed Python files compared to the base branch
|
||||
echo "Fetching changed files..."
|
||||
git diff --name-only --diff-filter=ACMRT origin/${{ github.base_ref }}...HEAD -- '*.py' > changed_files.txt
|
||||
|
||||
# Filter for files in src/ directory only (excluding tests/)
|
||||
grep -E "^src/" changed_files.txt > filtered_changed_files.txt || true
|
||||
|
||||
# Check if there are any changed files
|
||||
if [ -s filtered_changed_files.txt ]; then
|
||||
echo "Changed Python files in src/:"
|
||||
cat filtered_changed_files.txt
|
||||
echo "has_changes=true" >> $GITHUB_OUTPUT
|
||||
# Convert newlines to spaces for mypy command
|
||||
echo "files=$(cat filtered_changed_files.txt | tr '\n' ' ')" >> $GITHUB_OUTPUT
|
||||
else
|
||||
echo "No Python files changed in src/"
|
||||
echo "has_changes=false" >> $GITHUB_OUTPUT
|
||||
fi
|
||||
|
||||
- name: Run type checks on changed files
|
||||
if: steps.changed-files.outputs.has_changes == 'true'
|
||||
run: |
|
||||
echo "Running mypy on changed files with Python ${{ matrix.python-version }}..."
|
||||
uv run mypy ${{ steps.changed-files.outputs.files }}
|
||||
|
||||
- name: No files to check
|
||||
if: steps.changed-files.outputs.has_changes == 'false'
|
||||
run: echo "No Python files in src/ were modified - skipping type checks"
|
||||
- name: Run type checks
|
||||
run: uv run mypy lib/crewai/src/crewai/
|
||||
|
||||
- name: Save uv caches
|
||||
if: steps.cache-restore.outputs.cache-hit != 'true'
|
||||
|
||||
@@ -5,9 +5,12 @@ from __future__ import annotations
|
||||
from abc import ABC, abstractmethod
|
||||
import json
|
||||
import re
|
||||
from typing import TYPE_CHECKING, Any, Final, Literal
|
||||
from typing import TYPE_CHECKING, Final, Literal
|
||||
|
||||
from crewai.utilities.pydantic_schema_utils import generate_model_description
|
||||
from crewai.utilities.pydantic_schema_utils import (
|
||||
ModelDescription,
|
||||
generate_model_description,
|
||||
)
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -41,7 +44,7 @@ class BaseConverterAdapter(ABC):
|
||||
"""
|
||||
self.agent_adapter = agent_adapter
|
||||
self._output_format: Literal["json", "pydantic"] | None = None
|
||||
self._schema: dict[str, Any] | None = None
|
||||
self._schema: ModelDescription | None = None
|
||||
|
||||
@abstractmethod
|
||||
def configure_structured_output(self, task: Task) -> None:
|
||||
@@ -128,7 +131,7 @@ class BaseConverterAdapter(ABC):
|
||||
@staticmethod
|
||||
def _configure_format_from_task(
|
||||
task: Task,
|
||||
) -> tuple[Literal["json", "pydantic"] | None, dict[str, Any] | None]:
|
||||
) -> tuple[Literal["json", "pydantic"] | None, ModelDescription | None]:
|
||||
"""Determine output format and schema from task requirements.
|
||||
|
||||
This is a helper method that examines the task's output requirements
|
||||
|
||||
@@ -64,7 +64,7 @@ class LangGraphAgentAdapter(BaseAgentAdapter):
|
||||
llm: Any = None,
|
||||
max_iterations: int = 10,
|
||||
agent_config: dict[str, Any] | None = None,
|
||||
**kwargs,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Initialize the LangGraph agent adapter.
|
||||
|
||||
|
||||
@@ -948,7 +948,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
|
||||
)
|
||||
error_event_emitted = False
|
||||
|
||||
track_delegation_if_needed(func_name, args_dict, self.task)
|
||||
track_delegation_if_needed(func_name, args_dict or {}, self.task)
|
||||
|
||||
structured_tool: CrewStructuredTool | None = None
|
||||
if original_tool is not None:
|
||||
@@ -965,7 +965,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
|
||||
hook_blocked = False
|
||||
before_hook_context = ToolCallHookContext(
|
||||
tool_name=func_name,
|
||||
tool_input=args_dict,
|
||||
tool_input=args_dict or {},
|
||||
tool=structured_tool, # type: ignore[arg-type]
|
||||
agent=self.agent,
|
||||
task=self.task,
|
||||
@@ -991,7 +991,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
|
||||
result = f"Tool '{func_name}' has reached its usage limit of {original_tool.max_usage_count} times and cannot be used anymore."
|
||||
elif not from_cache and func_name in available_functions:
|
||||
try:
|
||||
raw_result = available_functions[func_name](**args_dict)
|
||||
raw_result = available_functions[func_name](**(args_dict or {}))
|
||||
|
||||
if self.tools_handler and self.tools_handler.cache:
|
||||
should_cache = True
|
||||
@@ -1001,7 +1001,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
|
||||
and callable(original_tool.cache_function)
|
||||
):
|
||||
should_cache = original_tool.cache_function(
|
||||
args_dict, raw_result
|
||||
args_dict or {}, raw_result
|
||||
)
|
||||
if should_cache:
|
||||
self.tools_handler.cache.add(
|
||||
@@ -1030,7 +1030,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
|
||||
|
||||
after_hook_context = ToolCallHookContext(
|
||||
tool_name=func_name,
|
||||
tool_input=args_dict,
|
||||
tool_input=args_dict or {},
|
||||
tool=structured_tool, # type: ignore[arg-type]
|
||||
agent=self.agent,
|
||||
task=self.task,
|
||||
|
||||
@@ -77,7 +77,7 @@ CLI_SETTINGS_KEYS = [
|
||||
]
|
||||
|
||||
# Default values for CLI settings
|
||||
DEFAULT_CLI_SETTINGS = {
|
||||
DEFAULT_CLI_SETTINGS: dict[str, Any] = {
|
||||
"enterprise_base_url": DEFAULT_CREWAI_ENTERPRISE_URL,
|
||||
"oauth2_provider": CREWAI_ENTERPRISE_DEFAULT_OAUTH2_PROVIDER,
|
||||
"oauth2_audience": CREWAI_ENTERPRISE_DEFAULT_OAUTH2_AUDIENCE,
|
||||
|
||||
@@ -173,13 +173,13 @@ class MemoryTUI(App[None]):
|
||||
info = self._memory.info("/")
|
||||
tree.root.label = f"/ ({info.record_count} records)"
|
||||
tree.root.data = "/"
|
||||
self._add_children(tree.root, "/", depth=0, max_depth=3)
|
||||
self._add_scope_children(tree.root, "/", depth=0, max_depth=3)
|
||||
tree.root.expand()
|
||||
return tree
|
||||
|
||||
def _add_children(
|
||||
def _add_scope_children(
|
||||
self,
|
||||
parent_node: Tree.Node[str],
|
||||
parent_node: Any,
|
||||
path: str,
|
||||
depth: int,
|
||||
max_depth: int,
|
||||
@@ -191,7 +191,7 @@ class MemoryTUI(App[None]):
|
||||
child_info = self._memory.info(child)
|
||||
label = f"{child} ({child_info.record_count})"
|
||||
node = parent_node.add(label, data=child)
|
||||
self._add_children(node, child, depth + 1, max_depth)
|
||||
self._add_scope_children(node, child, depth + 1, max_depth)
|
||||
|
||||
# -- Populating the OptionList -------------------------------------------
|
||||
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import subprocess
|
||||
from typing import Any
|
||||
|
||||
import click
|
||||
|
||||
@@ -6,7 +7,7 @@ from crewai.cli.utils import get_crews, get_flows
|
||||
from crewai.flow import Flow
|
||||
|
||||
|
||||
def _reset_flow_memory(flow: Flow) -> None:
|
||||
def _reset_flow_memory(flow: Flow[Any]) -> None:
|
||||
"""Reset memory for a single flow instance.
|
||||
|
||||
Handles Memory, MemoryScope (both have .reset()), and MemorySlice
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import os
|
||||
import shutil
|
||||
from typing import Any
|
||||
|
||||
import tomli_w
|
||||
|
||||
@@ -11,7 +12,7 @@ def update_crew() -> None:
|
||||
migrate_pyproject("pyproject.toml", "pyproject.toml")
|
||||
|
||||
|
||||
def migrate_pyproject(input_file, output_file):
|
||||
def migrate_pyproject(input_file: str, output_file: str) -> None:
|
||||
"""
|
||||
Migrate the pyproject.toml to the new format.
|
||||
|
||||
@@ -23,8 +24,7 @@ def migrate_pyproject(input_file, output_file):
|
||||
# Read the input pyproject.toml
|
||||
pyproject_data = read_toml()
|
||||
|
||||
# Initialize the new project structure
|
||||
new_pyproject = {
|
||||
new_pyproject: dict[str, Any] = {
|
||||
"project": {},
|
||||
"build-system": {"requires": ["hatchling"], "build-backend": "hatchling.build"},
|
||||
}
|
||||
|
||||
@@ -386,7 +386,7 @@ def fetch_crews(module_attr: Any) -> list[Crew]:
|
||||
return crew_instances
|
||||
|
||||
|
||||
def get_flow_instance(module_attr: Any) -> Flow | None:
|
||||
def get_flow_instance(module_attr: Any) -> Flow[Any] | None:
|
||||
"""Check if a module attribute is a user-defined Flow subclass and return an instance.
|
||||
|
||||
Args:
|
||||
@@ -413,7 +413,7 @@ _SKIP_DIRS = frozenset(
|
||||
)
|
||||
|
||||
|
||||
def get_flows(flow_path: str = "main.py") -> list[Flow]:
|
||||
def get_flows(flow_path: str = "main.py") -> list[Flow[Any]]:
|
||||
"""Get the flow instances from project files.
|
||||
|
||||
Walks the project directory looking for files matching ``flow_path``
|
||||
@@ -427,7 +427,7 @@ def get_flows(flow_path: str = "main.py") -> list[Flow]:
|
||||
Returns:
|
||||
A list of discovered Flow instances.
|
||||
"""
|
||||
flow_instances: list[Flow] = []
|
||||
flow_instances: list[Flow[Any]] = []
|
||||
try:
|
||||
current_dir = os.getcwd()
|
||||
if current_dir not in sys.path:
|
||||
|
||||
@@ -45,14 +45,14 @@ class CrewOutput(BaseModel):
|
||||
output_dict.update(self.pydantic.model_dump())
|
||||
return output_dict
|
||||
|
||||
def __getitem__(self, key):
|
||||
def __getitem__(self, key: str) -> Any:
|
||||
if self.pydantic and hasattr(self.pydantic, key):
|
||||
return getattr(self.pydantic, key)
|
||||
if self.json_dict and key in self.json_dict:
|
||||
return self.json_dict[key]
|
||||
raise KeyError(f"Key '{key}' not found in CrewOutput.")
|
||||
|
||||
def __str__(self):
|
||||
def __str__(self) -> str:
|
||||
if self.pydantic:
|
||||
return str(self.pydantic)
|
||||
if self.json_dict:
|
||||
|
||||
@@ -6,6 +6,7 @@ handlers execute in correct order while maximizing parallelism.
|
||||
|
||||
from collections import defaultdict, deque
|
||||
from collections.abc import Sequence
|
||||
from typing import Any
|
||||
|
||||
from crewai.events.depends import Depends
|
||||
from crewai.events.types.event_bus_types import ExecutionPlan, Handler
|
||||
@@ -45,7 +46,7 @@ class HandlerGraph:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
handlers: dict[Handler, list[Depends]],
|
||||
handlers: dict[Handler, list[Depends[Any]]],
|
||||
) -> None:
|
||||
"""Initialize the dependency graph.
|
||||
|
||||
@@ -103,7 +104,7 @@ class HandlerGraph:
|
||||
|
||||
def build_execution_plan(
|
||||
handlers: Sequence[Handler],
|
||||
dependencies: dict[Handler, list[Depends]],
|
||||
dependencies: dict[Handler, list[Depends[Any]]],
|
||||
) -> ExecutionPlan:
|
||||
"""Build an execution plan from handlers and their dependencies.
|
||||
|
||||
@@ -118,7 +119,7 @@ def build_execution_plan(
|
||||
Raises:
|
||||
CircularDependencyError: If circular dependencies are detected
|
||||
"""
|
||||
handler_dict: dict[Handler, list[Depends]] = {
|
||||
handler_dict: dict[Handler, list[Depends[Any]]] = {
|
||||
h: dependencies.get(h, []) for h in handlers
|
||||
}
|
||||
|
||||
|
||||
@@ -65,9 +65,9 @@ class FirstTimeTraceHandler:
|
||||
self._gracefully_fail(f"Error in trace handling: {e}")
|
||||
mark_first_execution_completed(user_consented=False)
|
||||
|
||||
def _initialize_backend_and_send_events(self):
|
||||
def _initialize_backend_and_send_events(self) -> None:
|
||||
"""Initialize backend batch and send collected events."""
|
||||
if not self.batch_manager:
|
||||
if not self.batch_manager or not self.batch_manager.trace_batch_id:
|
||||
return
|
||||
|
||||
try:
|
||||
@@ -115,12 +115,13 @@ class FirstTimeTraceHandler:
|
||||
except Exception as e:
|
||||
self._gracefully_fail(f"Backend initialization failed: {e}")
|
||||
|
||||
def _display_ephemeral_trace_link(self):
|
||||
def _display_ephemeral_trace_link(self) -> None:
|
||||
"""Display the ephemeral trace link to the user and automatically open browser."""
|
||||
console = Console()
|
||||
|
||||
try:
|
||||
webbrowser.open(self.ephemeral_url)
|
||||
if self.ephemeral_url:
|
||||
webbrowser.open(self.ephemeral_url)
|
||||
except Exception: # noqa: S110
|
||||
pass
|
||||
|
||||
@@ -158,7 +159,7 @@ To disable tracing later, do any one of these:
|
||||
console.print(panel)
|
||||
console.print()
|
||||
|
||||
def _show_tracing_declined_message(self):
|
||||
def _show_tracing_declined_message(self) -> None:
|
||||
"""Show message when user declines tracing."""
|
||||
console = Console()
|
||||
|
||||
@@ -184,15 +185,18 @@ To enable tracing later, do any one of these:
|
||||
console.print(panel)
|
||||
console.print()
|
||||
|
||||
def _gracefully_fail(self, error_message: str):
|
||||
def _gracefully_fail(self, error_message: str) -> None:
|
||||
"""Handle errors gracefully without disrupting user experience."""
|
||||
console = Console()
|
||||
console.print(f"[yellow]Note: {error_message}[/yellow]")
|
||||
|
||||
logger.debug(f"First-time trace error: {error_message}")
|
||||
|
||||
def _show_local_trace_message(self):
|
||||
def _show_local_trace_message(self) -> None:
|
||||
"""Show message when traces were collected locally but couldn't be uploaded."""
|
||||
if self.batch_manager is None:
|
||||
return
|
||||
|
||||
console = Console()
|
||||
|
||||
panel_content = f"""
|
||||
|
||||
@@ -6,6 +6,7 @@ from collections.abc import Sequence
|
||||
from typing import Any
|
||||
|
||||
from pydantic import ConfigDict, model_validator
|
||||
from typing_extensions import Self
|
||||
|
||||
from crewai.agents.agent_builder.base_agent import BaseAgent
|
||||
from crewai.events.base_events import BaseEvent
|
||||
@@ -25,16 +26,9 @@ class AgentExecutionStartedEvent(BaseEvent):
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
@model_validator(mode="after")
|
||||
def set_fingerprint_data(self):
|
||||
def set_fingerprint_data(self) -> Self:
|
||||
"""Set fingerprint data from the agent if available."""
|
||||
if hasattr(self.agent, "fingerprint") and self.agent.fingerprint:
|
||||
self.source_fingerprint = self.agent.fingerprint.uuid_str
|
||||
self.source_type = "agent"
|
||||
if (
|
||||
hasattr(self.agent.fingerprint, "metadata")
|
||||
and self.agent.fingerprint.metadata
|
||||
):
|
||||
self.fingerprint_metadata = self.agent.fingerprint.metadata
|
||||
_set_agent_fingerprint(self, self.agent)
|
||||
return self
|
||||
|
||||
|
||||
@@ -49,16 +43,9 @@ class AgentExecutionCompletedEvent(BaseEvent):
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
@model_validator(mode="after")
|
||||
def set_fingerprint_data(self):
|
||||
def set_fingerprint_data(self) -> Self:
|
||||
"""Set fingerprint data from the agent if available."""
|
||||
if hasattr(self.agent, "fingerprint") and self.agent.fingerprint:
|
||||
self.source_fingerprint = self.agent.fingerprint.uuid_str
|
||||
self.source_type = "agent"
|
||||
if (
|
||||
hasattr(self.agent.fingerprint, "metadata")
|
||||
and self.agent.fingerprint.metadata
|
||||
):
|
||||
self.fingerprint_metadata = self.agent.fingerprint.metadata
|
||||
_set_agent_fingerprint(self, self.agent)
|
||||
return self
|
||||
|
||||
|
||||
@@ -73,16 +60,9 @@ class AgentExecutionErrorEvent(BaseEvent):
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
@model_validator(mode="after")
|
||||
def set_fingerprint_data(self):
|
||||
def set_fingerprint_data(self) -> Self:
|
||||
"""Set fingerprint data from the agent if available."""
|
||||
if hasattr(self.agent, "fingerprint") and self.agent.fingerprint:
|
||||
self.source_fingerprint = self.agent.fingerprint.uuid_str
|
||||
self.source_type = "agent"
|
||||
if (
|
||||
hasattr(self.agent.fingerprint, "metadata")
|
||||
and self.agent.fingerprint.metadata
|
||||
):
|
||||
self.fingerprint_metadata = self.agent.fingerprint.metadata
|
||||
_set_agent_fingerprint(self, self.agent)
|
||||
return self
|
||||
|
||||
|
||||
@@ -140,3 +120,13 @@ class AgentEvaluationFailedEvent(BaseEvent):
|
||||
iteration: int
|
||||
error: str
|
||||
type: str = "agent_evaluation_failed"
|
||||
|
||||
|
||||
def _set_agent_fingerprint(event: BaseEvent, agent: BaseAgent) -> None:
|
||||
"""Set fingerprint data on an event from an agent object."""
|
||||
fp = agent.security_config.fingerprint
|
||||
if fp is not None:
|
||||
event.source_fingerprint = fp.uuid_str
|
||||
event.source_type = "agent"
|
||||
if fp.metadata:
|
||||
event.fingerprint_metadata = fp.metadata
|
||||
|
||||
@@ -15,21 +15,18 @@ class CrewBaseEvent(BaseEvent):
|
||||
crew_name: str | None
|
||||
crew: Crew | None = None
|
||||
|
||||
def __init__(self, **data):
|
||||
def __init__(self, **data: Any) -> None:
|
||||
super().__init__(**data)
|
||||
self.set_crew_fingerprint()
|
||||
self._set_crew_fingerprint()
|
||||
|
||||
def set_crew_fingerprint(self) -> None:
|
||||
if self.crew and hasattr(self.crew, "fingerprint") and self.crew.fingerprint:
|
||||
def _set_crew_fingerprint(self) -> None:
|
||||
if self.crew is not None and self.crew.fingerprint:
|
||||
self.source_fingerprint = self.crew.fingerprint.uuid_str
|
||||
self.source_type = "crew"
|
||||
if (
|
||||
hasattr(self.crew.fingerprint, "metadata")
|
||||
and self.crew.fingerprint.metadata
|
||||
):
|
||||
if self.crew.fingerprint.metadata:
|
||||
self.fingerprint_metadata = self.crew.fingerprint.metadata
|
||||
|
||||
def to_json(self, exclude: set[str] | None = None):
|
||||
def to_json(self, exclude: set[str] | None = None) -> Any:
|
||||
if exclude is None:
|
||||
exclude = set()
|
||||
exclude.add("crew")
|
||||
|
||||
@@ -11,7 +11,7 @@ class KnowledgeEventBase(BaseEvent):
|
||||
agent_role: str | None = None
|
||||
agent_id: str | None = None
|
||||
|
||||
def __init__(self, **data):
|
||||
def __init__(self, **data: Any) -> None:
|
||||
super().__init__(**data)
|
||||
self._set_agent_params(data)
|
||||
self._set_task_params(data)
|
||||
|
||||
@@ -13,7 +13,7 @@ class LLMGuardrailBaseEvent(BaseEvent):
|
||||
agent_role: str | None = None
|
||||
agent_id: str | None = None
|
||||
|
||||
def __init__(self, **data):
|
||||
def __init__(self, **data: Any) -> None:
|
||||
super().__init__(**data)
|
||||
self._set_agent_params(data)
|
||||
self._set_task_params(data)
|
||||
@@ -28,10 +28,10 @@ class LLMGuardrailStartedEvent(LLMGuardrailBaseEvent):
|
||||
"""
|
||||
|
||||
type: str = "llm_guardrail_started"
|
||||
guardrail: str | Callable
|
||||
guardrail: str | Callable[..., Any]
|
||||
retry_count: int
|
||||
|
||||
def __init__(self, **data):
|
||||
def __init__(self, **data: Any) -> None:
|
||||
from crewai.tasks.hallucination_guardrail import HallucinationGuardrail
|
||||
from crewai.tasks.llm_guardrail import LLMGuardrail
|
||||
|
||||
@@ -39,7 +39,7 @@ class LLMGuardrailStartedEvent(LLMGuardrailBaseEvent):
|
||||
|
||||
if isinstance(self.guardrail, (LLMGuardrail, HallucinationGuardrail)):
|
||||
self.guardrail = self.guardrail.description.strip()
|
||||
elif isinstance(self.guardrail, Callable):
|
||||
elif callable(self.guardrail):
|
||||
self.guardrail = getsource(self.guardrail).strip()
|
||||
|
||||
|
||||
|
||||
@@ -15,7 +15,7 @@ class MCPEvent(BaseEvent):
|
||||
from_agent: Any | None = None
|
||||
from_task: Any | None = None
|
||||
|
||||
def __init__(self, **data):
|
||||
def __init__(self, **data: Any) -> None:
|
||||
super().__init__(**data)
|
||||
self._set_agent_params(data)
|
||||
self._set_task_params(data)
|
||||
|
||||
@@ -15,7 +15,7 @@ class ReasoningEvent(BaseEvent):
|
||||
agent_id: str | None = None
|
||||
from_agent: Any | None = None
|
||||
|
||||
def __init__(self, **data):
|
||||
def __init__(self, **data: Any) -> None:
|
||||
super().__init__(**data)
|
||||
self._set_task_params(data)
|
||||
self._set_agent_params(data)
|
||||
|
||||
@@ -4,6 +4,15 @@ from crewai.events.base_events import BaseEvent
|
||||
from crewai.tasks.task_output import TaskOutput
|
||||
|
||||
|
||||
def _set_task_fingerprint(event: BaseEvent, task: Any) -> None:
|
||||
"""Set fingerprint data on an event from a task object."""
|
||||
if task is not None and task.fingerprint:
|
||||
event.source_fingerprint = task.fingerprint.uuid_str
|
||||
event.source_type = "task"
|
||||
if task.fingerprint.metadata:
|
||||
event.fingerprint_metadata = task.fingerprint.metadata
|
||||
|
||||
|
||||
class TaskStartedEvent(BaseEvent):
|
||||
"""Event emitted when a task starts"""
|
||||
|
||||
@@ -11,17 +20,9 @@ class TaskStartedEvent(BaseEvent):
|
||||
context: str | None
|
||||
task: Any | None = None
|
||||
|
||||
def __init__(self, **data):
|
||||
def __init__(self, **data: Any) -> None:
|
||||
super().__init__(**data)
|
||||
# Set fingerprint data from the task
|
||||
if hasattr(self.task, "fingerprint") and self.task.fingerprint:
|
||||
self.source_fingerprint = self.task.fingerprint.uuid_str
|
||||
self.source_type = "task"
|
||||
if (
|
||||
hasattr(self.task.fingerprint, "metadata")
|
||||
and self.task.fingerprint.metadata
|
||||
):
|
||||
self.fingerprint_metadata = self.task.fingerprint.metadata
|
||||
_set_task_fingerprint(self, self.task)
|
||||
|
||||
|
||||
class TaskCompletedEvent(BaseEvent):
|
||||
@@ -31,17 +32,9 @@ class TaskCompletedEvent(BaseEvent):
|
||||
type: str = "task_completed"
|
||||
task: Any | None = None
|
||||
|
||||
def __init__(self, **data):
|
||||
def __init__(self, **data: Any) -> None:
|
||||
super().__init__(**data)
|
||||
# Set fingerprint data from the task
|
||||
if hasattr(self.task, "fingerprint") and self.task.fingerprint:
|
||||
self.source_fingerprint = self.task.fingerprint.uuid_str
|
||||
self.source_type = "task"
|
||||
if (
|
||||
hasattr(self.task.fingerprint, "metadata")
|
||||
and self.task.fingerprint.metadata
|
||||
):
|
||||
self.fingerprint_metadata = self.task.fingerprint.metadata
|
||||
_set_task_fingerprint(self, self.task)
|
||||
|
||||
|
||||
class TaskFailedEvent(BaseEvent):
|
||||
@@ -51,17 +44,9 @@ class TaskFailedEvent(BaseEvent):
|
||||
type: str = "task_failed"
|
||||
task: Any | None = None
|
||||
|
||||
def __init__(self, **data):
|
||||
def __init__(self, **data: Any) -> None:
|
||||
super().__init__(**data)
|
||||
# Set fingerprint data from the task
|
||||
if hasattr(self.task, "fingerprint") and self.task.fingerprint:
|
||||
self.source_fingerprint = self.task.fingerprint.uuid_str
|
||||
self.source_type = "task"
|
||||
if (
|
||||
hasattr(self.task.fingerprint, "metadata")
|
||||
and self.task.fingerprint.metadata
|
||||
):
|
||||
self.fingerprint_metadata = self.task.fingerprint.metadata
|
||||
_set_task_fingerprint(self, self.task)
|
||||
|
||||
|
||||
class TaskEvaluationEvent(BaseEvent):
|
||||
@@ -71,14 +56,6 @@ class TaskEvaluationEvent(BaseEvent):
|
||||
evaluation_type: str
|
||||
task: Any | None = None
|
||||
|
||||
def __init__(self, **data):
|
||||
def __init__(self, **data: Any) -> None:
|
||||
super().__init__(**data)
|
||||
# Set fingerprint data from the task
|
||||
if hasattr(self.task, "fingerprint") and self.task.fingerprint:
|
||||
self.source_fingerprint = self.task.fingerprint.uuid_str
|
||||
self.source_type = "task"
|
||||
if (
|
||||
hasattr(self.task.fingerprint, "metadata")
|
||||
and self.task.fingerprint.metadata
|
||||
):
|
||||
self.fingerprint_metadata = self.task.fingerprint.metadata
|
||||
_set_task_fingerprint(self, self.task)
|
||||
|
||||
@@ -8,7 +8,7 @@ from datetime import datetime
|
||||
import inspect
|
||||
import json
|
||||
import threading
|
||||
from typing import TYPE_CHECKING, Any, Literal, cast
|
||||
from typing import TYPE_CHECKING, Any, Literal, TypeVar, cast
|
||||
from uuid import uuid4
|
||||
|
||||
from pydantic import BaseModel, Field, GetCoreSchemaHandler
|
||||
@@ -22,7 +22,11 @@ from crewai.agents.parser import (
|
||||
AgentFinish,
|
||||
OutputParserError,
|
||||
)
|
||||
from crewai.core.providers.human_input import get_provider
|
||||
from crewai.core.providers.human_input import (
|
||||
AsyncExecutorContext,
|
||||
ExecutorContext,
|
||||
get_provider,
|
||||
)
|
||||
from crewai.events.event_bus import crewai_event_bus
|
||||
from crewai.events.listeners.tracing.utils import (
|
||||
is_tracing_enabled_in_context,
|
||||
@@ -89,7 +93,7 @@ from crewai.utilities.planning_types import (
|
||||
TodoList,
|
||||
)
|
||||
from crewai.utilities.printer import Printer
|
||||
from crewai.utilities.step_execution_context import StepExecutionContext
|
||||
from crewai.utilities.step_execution_context import StepExecutionContext, StepResult
|
||||
from crewai.utilities.string_utils import sanitize_tool_name
|
||||
from crewai.utilities.tool_utils import execute_tool_and_check_finality
|
||||
from crewai.utilities.training_handler import CrewTrainingHandler
|
||||
@@ -105,6 +109,8 @@ if TYPE_CHECKING:
|
||||
from crewai.tools.tool_types import ToolResult
|
||||
from crewai.utilities.prompts import StandardPromptResult, SystemPromptResult
|
||||
|
||||
_RouteT = TypeVar("_RouteT", bound=str)
|
||||
|
||||
|
||||
class AgentExecutorState(BaseModel):
|
||||
"""Structured state for agent executor flow.
|
||||
@@ -446,29 +452,29 @@ class AgentExecutor(Flow[AgentExecutorState], CrewAgentExecutorMixin):
|
||||
step failures reliably trigger replanning rather than being
|
||||
silently ignored.
|
||||
"""
|
||||
config = getattr(self.agent, "planning_config", None)
|
||||
if config is not None and hasattr(config, "reasoning_effort"):
|
||||
config = self.agent.planning_config
|
||||
if config is not None:
|
||||
return config.reasoning_effort
|
||||
return "medium"
|
||||
|
||||
def _get_max_replans(self) -> int:
|
||||
"""Get max replans from planning config or default to 3."""
|
||||
config = getattr(self.agent, "planning_config", None)
|
||||
if config is not None and hasattr(config, "max_replans"):
|
||||
config = self.agent.planning_config
|
||||
if config is not None:
|
||||
return config.max_replans
|
||||
return 3
|
||||
|
||||
def _get_max_step_iterations(self) -> int:
|
||||
"""Get max step iterations from planning config or default to 15."""
|
||||
config = getattr(self.agent, "planning_config", None)
|
||||
if config is not None and hasattr(config, "max_step_iterations"):
|
||||
config = self.agent.planning_config
|
||||
if config is not None:
|
||||
return config.max_step_iterations
|
||||
return 15
|
||||
|
||||
def _get_step_timeout(self) -> int | None:
|
||||
"""Get per-step timeout from planning config or default to None."""
|
||||
config = getattr(self.agent, "planning_config", None)
|
||||
if config is not None and hasattr(config, "step_timeout"):
|
||||
config = self.agent.planning_config
|
||||
if config is not None:
|
||||
return config.step_timeout
|
||||
return None
|
||||
|
||||
@@ -1130,9 +1136,9 @@ class AgentExecutor(Flow[AgentExecutorState], CrewAgentExecutorMixin):
|
||||
# Process results: store on todos and log, then observe each.
|
||||
# asyncio.gather preserves input order, so zip gives us the exact
|
||||
# todo ↔ result (or exception) mapping.
|
||||
step_results: list[tuple[TodoItem, object]] = []
|
||||
step_results: list[tuple[TodoItem, StepResult]] = []
|
||||
for todo, item in zip(ready, gathered, strict=True):
|
||||
if isinstance(item, Exception):
|
||||
if isinstance(item, BaseException):
|
||||
error_msg = f"Error: {item!s}"
|
||||
todo.result = error_msg
|
||||
self.state.todos.mark_failed(todo.step_number, result=error_msg)
|
||||
@@ -1143,31 +1149,34 @@ class AgentExecutor(Flow[AgentExecutorState], CrewAgentExecutorMixin):
|
||||
)
|
||||
else:
|
||||
_returned_todo, result = item
|
||||
todo.result = result.result
|
||||
step_result = cast(StepResult, result)
|
||||
todo.result = step_result.result
|
||||
|
||||
self.state.execution_log.append(
|
||||
{
|
||||
"type": "step_execution",
|
||||
"step_number": todo.step_number,
|
||||
"success": result.success,
|
||||
"result_preview": result.result[:200] if result.result else "",
|
||||
"error": result.error,
|
||||
"tool_calls": result.tool_calls_made,
|
||||
"execution_time": result.execution_time,
|
||||
"success": step_result.success,
|
||||
"result_preview": step_result.result[:200]
|
||||
if step_result.result
|
||||
else "",
|
||||
"error": step_result.error,
|
||||
"tool_calls": step_result.tool_calls_made,
|
||||
"execution_time": step_result.execution_time,
|
||||
}
|
||||
)
|
||||
|
||||
if self.agent.verbose:
|
||||
status = "success" if result.success else "failed"
|
||||
status = "success" if step_result.success else "failed"
|
||||
self._printer.print(
|
||||
content=(
|
||||
f"[Execute] Step {todo.step_number} {status} "
|
||||
f"({result.execution_time:.1f}s, "
|
||||
f"{len(result.tool_calls_made)} tool calls)"
|
||||
f"({step_result.execution_time:.1f}s, "
|
||||
f"{len(step_result.tool_calls_made)} tool calls)"
|
||||
),
|
||||
color="green" if result.success else "red",
|
||||
color="green" if step_result.success else "red",
|
||||
)
|
||||
step_results.append((todo, result))
|
||||
step_results.append((todo, step_result))
|
||||
|
||||
# Observe each completed step sequentially (observation updates shared state)
|
||||
effort = self._get_reasoning_effort()
|
||||
@@ -1431,8 +1440,8 @@ class AgentExecutor(Flow[AgentExecutorState], CrewAgentExecutorMixin):
|
||||
raise
|
||||
|
||||
def _route_finish_with_todos(
|
||||
self, default_route: str
|
||||
) -> Literal["native_finished", "agent_finished", "todo_satisfied"]:
|
||||
self, default_route: _RouteT
|
||||
) -> _RouteT | Literal["todo_satisfied"]:
|
||||
"""Helper to route finish events, checking for pending todos first.
|
||||
|
||||
If there are pending todos, route to todo_satisfied instead of the
|
||||
@@ -1448,7 +1457,7 @@ class AgentExecutor(Flow[AgentExecutorState], CrewAgentExecutorMixin):
|
||||
current_todo = self.state.todos.current_todo
|
||||
if current_todo:
|
||||
return "todo_satisfied"
|
||||
return default_route # type: ignore[return-value]
|
||||
return default_route
|
||||
|
||||
@router(call_llm_and_parse)
|
||||
def route_by_answer_type(
|
||||
@@ -2063,7 +2072,7 @@ class AgentExecutor(Flow[AgentExecutorState], CrewAgentExecutorMixin):
|
||||
elif not self.state.current_answer and self.state.messages:
|
||||
# For native tools, results are in the message history as 'tool' roles
|
||||
# We take the content of the most recent tool results
|
||||
tool_results = []
|
||||
tool_results: list[str] = []
|
||||
for msg in reversed(self.state.messages):
|
||||
if msg.get("role") == "tool":
|
||||
tool_results.insert(0, str(msg.get("content", "")))
|
||||
@@ -3003,7 +3012,7 @@ class AgentExecutor(Flow[AgentExecutorState], CrewAgentExecutorMixin):
|
||||
Final answer after feedback.
|
||||
"""
|
||||
provider = get_provider()
|
||||
return provider.handle_feedback(formatted_answer, self)
|
||||
return provider.handle_feedback(formatted_answer, cast("ExecutorContext", self))
|
||||
|
||||
async def _ahandle_human_feedback(
|
||||
self, formatted_answer: AgentFinish
|
||||
@@ -3017,7 +3026,9 @@ class AgentExecutor(Flow[AgentExecutorState], CrewAgentExecutorMixin):
|
||||
Final answer after feedback.
|
||||
"""
|
||||
provider = get_provider()
|
||||
return await provider.handle_feedback_async(formatted_answer, self)
|
||||
return await provider.handle_feedback_async(
|
||||
formatted_answer, cast("AsyncExecutorContext", self)
|
||||
)
|
||||
|
||||
def _is_training_mode(self) -> bool:
|
||||
"""Check if training mode is active.
|
||||
|
||||
@@ -37,11 +37,11 @@ class ExecutionState:
|
||||
current_agent_id: str | None = None
|
||||
current_task_id: str | None = None
|
||||
|
||||
def __init__(self):
|
||||
self.traces = {}
|
||||
self.iteration = 1
|
||||
self.iterations_results = {}
|
||||
self.agent_evaluators = {}
|
||||
def __init__(self) -> None:
|
||||
self.traces: dict[str, Any] = {}
|
||||
self.iteration: int = 1
|
||||
self.iterations_results: dict[int, dict[str, list[AgentEvaluationResult]]] = {}
|
||||
self.agent_evaluators: dict[str, Sequence[BaseEvaluator] | None] = {}
|
||||
|
||||
|
||||
class AgentEvaluator:
|
||||
@@ -295,7 +295,7 @@ class AgentEvaluator:
|
||||
|
||||
def emit_evaluation_started_event(
|
||||
self, agent_role: str, agent_id: str, task_id: str | None = None
|
||||
):
|
||||
) -> None:
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
AgentEvaluationStartedEvent(
|
||||
@@ -313,7 +313,7 @@ class AgentEvaluator:
|
||||
task_id: str | None = None,
|
||||
metric_category: MetricCategory | None = None,
|
||||
score: EvaluationScore | None = None,
|
||||
):
|
||||
) -> None:
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
AgentEvaluationCompletedEvent(
|
||||
@@ -328,7 +328,7 @@ class AgentEvaluator:
|
||||
|
||||
def emit_evaluation_failed_event(
|
||||
self, agent_role: str, agent_id: str, error: str, task_id: str | None = None
|
||||
):
|
||||
) -> None:
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
AgentEvaluationFailedEvent(
|
||||
@@ -341,7 +341,9 @@ class AgentEvaluator:
|
||||
)
|
||||
|
||||
|
||||
def create_default_evaluator(agents: list[Agent] | list[BaseAgent], llm: None = None):
|
||||
def create_default_evaluator(
|
||||
agents: list[Agent] | list[BaseAgent], llm: None = None
|
||||
) -> AgentEvaluator:
|
||||
from crewai.experimental.evaluation import (
|
||||
GoalAlignmentEvaluator,
|
||||
ParameterExtractionEvaluator,
|
||||
|
||||
@@ -8,7 +8,7 @@ from typing import TYPE_CHECKING, Any
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from crewai.agents.agent_builder.base_agent import BaseAgent
|
||||
from crewai.llm import BaseLLM
|
||||
from crewai.llms.base_llm import BaseLLM
|
||||
from crewai.task import Task
|
||||
from crewai.utilities.llm_utils import create_llm
|
||||
|
||||
@@ -25,7 +25,7 @@ class MetricCategory(enum.Enum):
|
||||
PARAMETER_EXTRACTION = "parameter_extraction"
|
||||
TOOL_INVOCATION = "tool_invocation"
|
||||
|
||||
def title(self):
|
||||
def title(self) -> str:
|
||||
return self.value.replace("_", " ").title()
|
||||
|
||||
|
||||
|
||||
@@ -18,12 +18,12 @@ from crewai.utilities.types import LLMMessage
|
||||
|
||||
|
||||
class EvaluationDisplayFormatter:
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
self.console_formatter = ConsoleFormatter()
|
||||
|
||||
def display_evaluation_with_feedback(
|
||||
self, iterations_results: dict[int, dict[str, list[Any]]]
|
||||
):
|
||||
) -> None:
|
||||
if not iterations_results:
|
||||
self.console_formatter.print(
|
||||
"[yellow]No evaluation results to display[/yellow]"
|
||||
@@ -103,7 +103,7 @@ class EvaluationDisplayFormatter:
|
||||
def display_summary_results(
|
||||
self,
|
||||
iterations_results: dict[int, dict[str, list[AgentEvaluationResult]]],
|
||||
):
|
||||
) -> None:
|
||||
if not iterations_results:
|
||||
self.console_formatter.print(
|
||||
"[yellow]No evaluation results to display[/yellow]"
|
||||
|
||||
@@ -1,6 +1,11 @@
|
||||
"""Event listener for collecting execution traces for evaluation."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Sequence
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from crewai.agents.agent_builder.base_agent import BaseAgent
|
||||
from crewai.events.base_event_listener import BaseEventListener
|
||||
@@ -30,47 +35,63 @@ class EvaluationTraceCallback(BaseEventListener):
|
||||
retrievals, and final output - all for use in agent evaluation.
|
||||
"""
|
||||
|
||||
_instance = None
|
||||
_instance: EvaluationTraceCallback | None = None
|
||||
_initialized: bool = False
|
||||
|
||||
def __new__(cls):
|
||||
def __new__(cls) -> EvaluationTraceCallback:
|
||||
"""Create or return the singleton instance."""
|
||||
if cls._instance is None:
|
||||
cls._instance = super().__new__(cls)
|
||||
cls._instance._initialized = False
|
||||
return cls._instance
|
||||
|
||||
def __init__(self):
|
||||
if not hasattr(self, "_initialized") or not self._initialized:
|
||||
def __init__(self) -> None:
|
||||
"""Initialize the evaluation trace callback."""
|
||||
if not self._initialized:
|
||||
super().__init__()
|
||||
self.traces = {}
|
||||
self.current_agent_id = None
|
||||
self.current_task_id = None
|
||||
self.traces: dict[str, Any] = {}
|
||||
self.current_agent_id: UUID | str | None = None
|
||||
self.current_task_id: UUID | str | None = None
|
||||
self.current_llm_call: dict[str, Any] = {}
|
||||
self._initialized = True
|
||||
|
||||
def setup_listeners(self, event_bus: CrewAIEventsBus):
|
||||
def setup_listeners(self, event_bus: CrewAIEventsBus) -> None:
|
||||
"""Set up event listeners on the event bus.
|
||||
|
||||
Args:
|
||||
event_bus: The event bus to register listeners on.
|
||||
"""
|
||||
|
||||
@event_bus.on(AgentExecutionStartedEvent)
|
||||
def on_agent_started(source, event: AgentExecutionStartedEvent):
|
||||
def on_agent_started(source: Any, event: AgentExecutionStartedEvent) -> None:
|
||||
self.on_agent_start(event.agent, event.task)
|
||||
|
||||
@event_bus.on(LiteAgentExecutionStartedEvent)
|
||||
def on_lite_agent_started(source, event: LiteAgentExecutionStartedEvent):
|
||||
def on_lite_agent_started(
|
||||
source: Any, event: LiteAgentExecutionStartedEvent
|
||||
) -> None:
|
||||
self.on_lite_agent_start(event.agent_info)
|
||||
|
||||
@event_bus.on(AgentExecutionCompletedEvent)
|
||||
def on_agent_completed(source, event: AgentExecutionCompletedEvent):
|
||||
def on_agent_completed(
|
||||
source: Any, event: AgentExecutionCompletedEvent
|
||||
) -> None:
|
||||
self.on_agent_finish(event.agent, event.task, event.output)
|
||||
|
||||
@event_bus.on(LiteAgentExecutionCompletedEvent)
|
||||
def on_lite_agent_completed(source, event: LiteAgentExecutionCompletedEvent):
|
||||
def on_lite_agent_completed(
|
||||
source: Any, event: LiteAgentExecutionCompletedEvent
|
||||
) -> None:
|
||||
self.on_lite_agent_finish(event.output)
|
||||
|
||||
@event_bus.on(ToolUsageFinishedEvent)
|
||||
def on_tool_completed(source, event: ToolUsageFinishedEvent):
|
||||
def on_tool_completed(source: Any, event: ToolUsageFinishedEvent) -> None:
|
||||
self.on_tool_use(
|
||||
event.tool_name, event.tool_args, event.output, success=True
|
||||
)
|
||||
|
||||
@event_bus.on(ToolUsageErrorEvent)
|
||||
def on_tool_usage_error(source, event: ToolUsageErrorEvent):
|
||||
def on_tool_usage_error(source: Any, event: ToolUsageErrorEvent) -> None:
|
||||
self.on_tool_use(
|
||||
event.tool_name,
|
||||
event.tool_args,
|
||||
@@ -80,7 +101,9 @@ class EvaluationTraceCallback(BaseEventListener):
|
||||
)
|
||||
|
||||
@event_bus.on(ToolExecutionErrorEvent)
|
||||
def on_tool_execution_error(source, event: ToolExecutionErrorEvent):
|
||||
def on_tool_execution_error(
|
||||
source: Any, event: ToolExecutionErrorEvent
|
||||
) -> None:
|
||||
self.on_tool_use(
|
||||
event.tool_name,
|
||||
event.tool_args,
|
||||
@@ -90,7 +113,9 @@ class EvaluationTraceCallback(BaseEventListener):
|
||||
)
|
||||
|
||||
@event_bus.on(ToolSelectionErrorEvent)
|
||||
def on_tool_selection_error(source, event: ToolSelectionErrorEvent):
|
||||
def on_tool_selection_error(
|
||||
source: Any, event: ToolSelectionErrorEvent
|
||||
) -> None:
|
||||
self.on_tool_use(
|
||||
event.tool_name,
|
||||
event.tool_args,
|
||||
@@ -100,7 +125,9 @@ class EvaluationTraceCallback(BaseEventListener):
|
||||
)
|
||||
|
||||
@event_bus.on(ToolValidateInputErrorEvent)
|
||||
def on_tool_validate_input_error(source, event: ToolValidateInputErrorEvent):
|
||||
def on_tool_validate_input_error(
|
||||
source: Any, event: ToolValidateInputErrorEvent
|
||||
) -> None:
|
||||
self.on_tool_use(
|
||||
event.tool_name,
|
||||
event.tool_args,
|
||||
@@ -110,14 +137,19 @@ class EvaluationTraceCallback(BaseEventListener):
|
||||
)
|
||||
|
||||
@event_bus.on(LLMCallStartedEvent)
|
||||
def on_llm_call_started(source, event: LLMCallStartedEvent):
|
||||
def on_llm_call_started(source: Any, event: LLMCallStartedEvent) -> None:
|
||||
self.on_llm_call_start(event.messages, event.tools)
|
||||
|
||||
@event_bus.on(LLMCallCompletedEvent)
|
||||
def on_llm_call_completed(source, event: LLMCallCompletedEvent):
|
||||
def on_llm_call_completed(source: Any, event: LLMCallCompletedEvent) -> None:
|
||||
self.on_llm_call_end(event.messages, event.response)
|
||||
|
||||
def on_lite_agent_start(self, agent_info: dict[str, Any]):
|
||||
def on_lite_agent_start(self, agent_info: dict[str, Any]) -> None:
|
||||
"""Handle a lite agent execution start event.
|
||||
|
||||
Args:
|
||||
agent_info: Dictionary containing agent information.
|
||||
"""
|
||||
self.current_agent_id = agent_info["id"]
|
||||
self.current_task_id = "lite_task"
|
||||
|
||||
@@ -132,10 +164,22 @@ class EvaluationTraceCallback(BaseEventListener):
|
||||
final_output=None,
|
||||
)
|
||||
|
||||
def _init_trace(self, trace_key: str, **kwargs: Any):
|
||||
def _init_trace(self, trace_key: str, **kwargs: Any) -> None:
|
||||
"""Initialize a trace entry.
|
||||
|
||||
Args:
|
||||
trace_key: The key to store the trace under.
|
||||
**kwargs: Trace metadata to store.
|
||||
"""
|
||||
self.traces[trace_key] = kwargs
|
||||
|
||||
def on_agent_start(self, agent: BaseAgent, task: Task):
|
||||
def on_agent_start(self, agent: BaseAgent, task: Task) -> None:
|
||||
"""Handle an agent execution start event.
|
||||
|
||||
Args:
|
||||
agent: The agent that started execution.
|
||||
task: The task being executed.
|
||||
"""
|
||||
self.current_agent_id = agent.id
|
||||
self.current_task_id = task.id
|
||||
|
||||
@@ -150,7 +194,14 @@ class EvaluationTraceCallback(BaseEventListener):
|
||||
final_output=None,
|
||||
)
|
||||
|
||||
def on_agent_finish(self, agent: BaseAgent, task: Task, output: Any):
|
||||
def on_agent_finish(self, agent: BaseAgent, task: Task, output: Any) -> None:
|
||||
"""Handle an agent execution completion event.
|
||||
|
||||
Args:
|
||||
agent: The agent that finished execution.
|
||||
task: The task that was executed.
|
||||
output: The agent's output.
|
||||
"""
|
||||
trace_key = f"{agent.id}_{task.id}"
|
||||
if trace_key in self.traces:
|
||||
self.traces[trace_key]["final_output"] = output
|
||||
@@ -158,11 +209,17 @@ class EvaluationTraceCallback(BaseEventListener):
|
||||
|
||||
self._reset_current()
|
||||
|
||||
def _reset_current(self):
|
||||
def _reset_current(self) -> None:
|
||||
"""Reset the current agent and task tracking state."""
|
||||
self.current_agent_id = None
|
||||
self.current_task_id = None
|
||||
|
||||
def on_lite_agent_finish(self, output: Any):
|
||||
def on_lite_agent_finish(self, output: Any) -> None:
|
||||
"""Handle a lite agent execution completion event.
|
||||
|
||||
Args:
|
||||
output: The agent's output.
|
||||
"""
|
||||
trace_key = f"{self.current_agent_id}_lite_task"
|
||||
if trace_key in self.traces:
|
||||
self.traces[trace_key]["final_output"] = output
|
||||
@@ -177,13 +234,22 @@ class EvaluationTraceCallback(BaseEventListener):
|
||||
result: Any,
|
||||
success: bool = True,
|
||||
error_type: str | None = None,
|
||||
):
|
||||
) -> None:
|
||||
"""Record a tool usage event in the current trace.
|
||||
|
||||
Args:
|
||||
tool_name: Name of the tool used.
|
||||
tool_args: Arguments passed to the tool.
|
||||
result: The tool's output or error message.
|
||||
success: Whether the tool call succeeded.
|
||||
error_type: Type of error if the call failed.
|
||||
"""
|
||||
if not self.current_agent_id or not self.current_task_id:
|
||||
return
|
||||
|
||||
trace_key = f"{self.current_agent_id}_{self.current_task_id}"
|
||||
if trace_key in self.traces:
|
||||
tool_use = {
|
||||
tool_use: dict[str, Any] = {
|
||||
"tool": tool_name,
|
||||
"args": tool_args,
|
||||
"result": result,
|
||||
@@ -191,7 +257,6 @@ class EvaluationTraceCallback(BaseEventListener):
|
||||
"timestamp": datetime.now(),
|
||||
}
|
||||
|
||||
# Add error information if applicable
|
||||
if not success and error_type:
|
||||
tool_use["error"] = True
|
||||
tool_use["error_type"] = error_type
|
||||
@@ -202,7 +267,13 @@ class EvaluationTraceCallback(BaseEventListener):
|
||||
self,
|
||||
messages: str | Sequence[dict[str, Any]] | None,
|
||||
tools: Sequence[dict[str, Any]] | None = None,
|
||||
):
|
||||
) -> None:
|
||||
"""Record an LLM call start event.
|
||||
|
||||
Args:
|
||||
messages: The messages sent to the LLM.
|
||||
tools: Tool definitions provided to the LLM.
|
||||
"""
|
||||
if not self.current_agent_id or not self.current_task_id:
|
||||
return
|
||||
|
||||
@@ -220,7 +291,13 @@ class EvaluationTraceCallback(BaseEventListener):
|
||||
|
||||
def on_llm_call_end(
|
||||
self, messages: str | list[dict[str, Any]] | None, response: Any
|
||||
):
|
||||
) -> None:
|
||||
"""Record an LLM call completion event.
|
||||
|
||||
Args:
|
||||
messages: The messages from the LLM call.
|
||||
response: The LLM response object.
|
||||
"""
|
||||
if not self.current_agent_id or not self.current_task_id:
|
||||
return
|
||||
|
||||
@@ -229,17 +306,18 @@ class EvaluationTraceCallback(BaseEventListener):
|
||||
return
|
||||
|
||||
total_tokens = 0
|
||||
if hasattr(response, "usage") and hasattr(response.usage, "total_tokens"):
|
||||
total_tokens = response.usage.total_tokens
|
||||
usage = getattr(response, "usage", None)
|
||||
if usage is not None:
|
||||
total_tokens = getattr(usage, "total_tokens", 0)
|
||||
|
||||
current_time = datetime.now()
|
||||
start_time = None
|
||||
if hasattr(self, "current_llm_call") and self.current_llm_call:
|
||||
start_time = self.current_llm_call.get("start_time")
|
||||
start_time = (
|
||||
self.current_llm_call.get("start_time") if self.current_llm_call else None
|
||||
)
|
||||
|
||||
if not start_time:
|
||||
start_time = current_time
|
||||
llm_call = {
|
||||
llm_call: dict[str, Any] = {
|
||||
"messages": messages,
|
||||
"response": response,
|
||||
"start_time": start_time,
|
||||
@@ -248,16 +326,28 @@ class EvaluationTraceCallback(BaseEventListener):
|
||||
}
|
||||
|
||||
self.traces[trace_key]["llm_calls"].append(llm_call)
|
||||
|
||||
if hasattr(self, "current_llm_call"):
|
||||
self.current_llm_call = {}
|
||||
self.current_llm_call = {}
|
||||
|
||||
def get_trace(self, agent_id: str, task_id: str) -> dict[str, Any] | None:
|
||||
"""Retrieve a trace by agent and task ID.
|
||||
|
||||
Args:
|
||||
agent_id: The agent's identifier.
|
||||
task_id: The task's identifier.
|
||||
|
||||
Returns:
|
||||
The trace dictionary, or None if not found.
|
||||
"""
|
||||
trace_key = f"{agent_id}_{task_id}"
|
||||
return self.traces.get(trace_key)
|
||||
|
||||
|
||||
def create_evaluation_callbacks() -> EvaluationTraceCallback:
|
||||
"""Create and register an evaluation trace callback on the event bus.
|
||||
|
||||
Returns:
|
||||
The configured EvaluationTraceCallback instance.
|
||||
"""
|
||||
from crewai.events.event_bus import crewai_event_bus
|
||||
|
||||
callback = EvaluationTraceCallback()
|
||||
|
||||
@@ -8,10 +8,10 @@ from crewai.experimental.evaluation.experiment.result import ExperimentResults
|
||||
|
||||
|
||||
class ExperimentResultsDisplay:
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
self.console = Console()
|
||||
|
||||
def summary(self, experiment_results: ExperimentResults):
|
||||
def summary(self, experiment_results: ExperimentResults) -> None:
|
||||
total = len(experiment_results.results)
|
||||
passed = sum(1 for r in experiment_results.results if r.passed)
|
||||
|
||||
@@ -28,7 +28,9 @@ class ExperimentResultsDisplay:
|
||||
|
||||
self.console.print(table)
|
||||
|
||||
def comparison_summary(self, comparison: dict[str, Any], baseline_timestamp: str):
|
||||
def comparison_summary(
|
||||
self, comparison: dict[str, Any], baseline_timestamp: str
|
||||
) -> None:
|
||||
self.console.print(
|
||||
Panel(
|
||||
f"[bold]Comparison with baseline run from {baseline_timestamp}[/bold]",
|
||||
|
||||
@@ -6,7 +6,7 @@ from typing import TYPE_CHECKING, Any
|
||||
|
||||
from crewai.agents.agent_builder.base_agent import BaseAgent
|
||||
from crewai.experimental.evaluation import AgentEvaluator, create_default_evaluator
|
||||
from crewai.experimental.evaluation.evaluation_display import (
|
||||
from crewai.experimental.evaluation.base_evaluator import (
|
||||
AgentAggregatedEvaluationResult,
|
||||
)
|
||||
from crewai.experimental.evaluation.experiment.result import (
|
||||
|
||||
@@ -7,7 +7,8 @@ from typing import Any
|
||||
|
||||
def extract_json_from_llm_response(text: str) -> dict[str, Any]:
|
||||
try:
|
||||
return json.loads(text)
|
||||
result: dict[str, Any] = json.loads(text)
|
||||
return result
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
@@ -24,7 +25,8 @@ def extract_json_from_llm_response(text: str) -> dict[str, Any]:
|
||||
matches = re.findall(pattern, text, re.IGNORECASE | re.DOTALL)
|
||||
for match in matches:
|
||||
try:
|
||||
return json.loads(match.strip())
|
||||
parsed: dict[str, Any] = json.loads(match.strip())
|
||||
return parsed
|
||||
except json.JSONDecodeError: # noqa: PERF203
|
||||
continue
|
||||
raise ValueError("No valid JSON found in the response")
|
||||
|
||||
@@ -68,7 +68,7 @@ Evaluate how well the agent's output aligns with the assigned task goal.
|
||||
]
|
||||
if self.llm is None:
|
||||
raise ValueError("LLM must be initialized")
|
||||
response = self.llm.call(prompt) # type: ignore[arg-type]
|
||||
response = self.llm.call(prompt)
|
||||
|
||||
try:
|
||||
evaluation_data: dict[str, Any] = extract_json_from_llm_response(response)
|
||||
|
||||
@@ -224,7 +224,9 @@ Identify any inefficient reasoning patterns and provide specific suggestions for
|
||||
raw_response=response,
|
||||
)
|
||||
|
||||
def _detect_loops(self, llm_calls: list[dict]) -> tuple[bool, list[dict]]:
|
||||
def _detect_loops(
|
||||
self, llm_calls: list[dict[str, Any]]
|
||||
) -> tuple[bool, list[dict[str, Any]]]:
|
||||
loop_details = []
|
||||
|
||||
messages = []
|
||||
@@ -272,7 +274,9 @@ Identify any inefficient reasoning patterns and provide specific suggestions for
|
||||
|
||||
return intersection / union if union > 0 else 0.0
|
||||
|
||||
def _analyze_reasoning_patterns(self, llm_calls: list[dict]) -> dict[str, Any]:
|
||||
def _analyze_reasoning_patterns(
|
||||
self, llm_calls: list[dict[str, Any]]
|
||||
) -> dict[str, Any]:
|
||||
call_lengths = []
|
||||
response_times = []
|
||||
|
||||
@@ -345,7 +349,7 @@ Identify any inefficient reasoning patterns and provide specific suggestions for
|
||||
max_possible_slope = max(values) - min(values)
|
||||
if max_possible_slope > 0:
|
||||
normalized_slope = slope / max_possible_slope
|
||||
return max(min(normalized_slope, 1.0), -1.0)
|
||||
return float(max(min(normalized_slope, 1.0), -1.0))
|
||||
return 0.0
|
||||
except Exception:
|
||||
return 0.0
|
||||
@@ -384,7 +388,7 @@ Identify any inefficient reasoning patterns and provide specific suggestions for
|
||||
|
||||
return float(np.mean(indicators)) if indicators else 0.0
|
||||
|
||||
def _get_call_samples(self, llm_calls: list[dict]) -> str:
|
||||
def _get_call_samples(self, llm_calls: list[dict[str, Any]]) -> str:
|
||||
samples = []
|
||||
|
||||
if len(llm_calls) <= 6:
|
||||
|
||||
@@ -299,15 +299,15 @@ def _extract_all_methods_from_condition(
|
||||
return []
|
||||
if isinstance(condition, dict):
|
||||
conditions_list = condition.get("conditions", [])
|
||||
methods: list[str] = []
|
||||
dict_methods: list[str] = []
|
||||
for sub_cond in conditions_list:
|
||||
methods.extend(_extract_all_methods_from_condition(sub_cond))
|
||||
return methods
|
||||
dict_methods.extend(_extract_all_methods_from_condition(sub_cond))
|
||||
return dict_methods
|
||||
if isinstance(condition, list):
|
||||
methods = []
|
||||
list_methods: list[str] = []
|
||||
for item in condition:
|
||||
methods.extend(_extract_all_methods_from_condition(item))
|
||||
return methods
|
||||
list_methods.extend(_extract_all_methods_from_condition(item))
|
||||
return list_methods
|
||||
return []
|
||||
|
||||
|
||||
@@ -476,7 +476,8 @@ def _detect_flow_inputs(flow_class: type) -> list[str]:
|
||||
|
||||
# Check for inputs in __init__ signature beyond standard Flow params
|
||||
try:
|
||||
init_sig = inspect.signature(flow_class.__init__)
|
||||
init_method = flow_class.__init__ # type: ignore[misc]
|
||||
init_sig = inspect.signature(init_method)
|
||||
standard_params = {
|
||||
"self",
|
||||
"persistence",
|
||||
|
||||
@@ -83,8 +83,11 @@ def _serialize_llm_for_context(llm: Any) -> dict[str, Any] | str | None:
|
||||
subclasses). Falls back to extracting the model string with provider
|
||||
prefix for unknown LLM types.
|
||||
"""
|
||||
if hasattr(llm, "to_config_dict"):
|
||||
return llm.to_config_dict()
|
||||
to_config: Callable[[], dict[str, Any]] | None = getattr(
|
||||
llm, "to_config_dict", None
|
||||
)
|
||||
if to_config is not None:
|
||||
return to_config()
|
||||
|
||||
# Fallback for non-BaseLLM objects: just extract model + provider prefix
|
||||
model = getattr(llm, "model", None)
|
||||
@@ -371,8 +374,11 @@ def human_feedback(
|
||||
) -> Any:
|
||||
"""Recall past HITL lessons and use LLM to pre-review the output."""
|
||||
try:
|
||||
mem = flow_instance.memory
|
||||
if mem is None:
|
||||
return method_output
|
||||
query = f"human feedback lessons for {func.__name__}: {method_output!s}"
|
||||
matches = flow_instance.memory.recall(query, source=learn_source)
|
||||
matches = mem.recall(query, source=learn_source)
|
||||
if not matches:
|
||||
return method_output
|
||||
|
||||
@@ -404,6 +410,9 @@ def human_feedback(
|
||||
) -> None:
|
||||
"""Extract generalizable lessons from output + feedback, store in memory."""
|
||||
try:
|
||||
mem = flow_instance.memory
|
||||
if mem is None:
|
||||
return
|
||||
llm_inst = _resolve_llm_instance()
|
||||
prompt = _get_hitl_prompt("hitl_distill_user").format(
|
||||
method_name=func.__name__,
|
||||
@@ -435,7 +444,7 @@ def human_feedback(
|
||||
]
|
||||
|
||||
if lessons:
|
||||
flow_instance.memory.remember_many(lessons, source=learn_source)
|
||||
mem.remember_many(lessons, source=learn_source) # type: ignore[union-attr]
|
||||
except Exception: # noqa: S110
|
||||
pass # non-critical: don't fail the flow because lesson storage failed
|
||||
|
||||
|
||||
@@ -122,7 +122,7 @@ def before_llm_call(
|
||||
"""
|
||||
from crewai.hooks.llm_hooks import register_before_llm_call_hook
|
||||
|
||||
return _create_hook_decorator( # type: ignore[return-value]
|
||||
return _create_hook_decorator( # type: ignore[no-any-return]
|
||||
hook_type="llm",
|
||||
register_function=register_before_llm_call_hook,
|
||||
marker_attribute="is_before_llm_call_hook",
|
||||
@@ -176,7 +176,7 @@ def after_llm_call(
|
||||
"""
|
||||
from crewai.hooks.llm_hooks import register_after_llm_call_hook
|
||||
|
||||
return _create_hook_decorator( # type: ignore[return-value]
|
||||
return _create_hook_decorator( # type: ignore[no-any-return]
|
||||
hook_type="llm",
|
||||
register_function=register_after_llm_call_hook,
|
||||
marker_attribute="is_after_llm_call_hook",
|
||||
@@ -237,7 +237,7 @@ def before_tool_call(
|
||||
"""
|
||||
from crewai.hooks.tool_hooks import register_before_tool_call_hook
|
||||
|
||||
return _create_hook_decorator( # type: ignore[return-value]
|
||||
return _create_hook_decorator( # type: ignore[no-any-return]
|
||||
hook_type="tool",
|
||||
register_function=register_before_tool_call_hook,
|
||||
marker_attribute="is_before_tool_call_hook",
|
||||
@@ -293,7 +293,7 @@ def after_tool_call(
|
||||
"""
|
||||
from crewai.hooks.tool_hooks import register_after_tool_call_hook
|
||||
|
||||
return _create_hook_decorator( # type: ignore[return-value]
|
||||
return _create_hook_decorator( # type: ignore[no-any-return]
|
||||
hook_type="tool",
|
||||
register_function=register_after_tool_call_hook,
|
||||
marker_attribute="is_after_tool_call_hook",
|
||||
|
||||
@@ -13,7 +13,7 @@ class BaseKnowledgeSource(BaseModel, ABC):
|
||||
chunk_size: int = 4000
|
||||
chunk_overlap: int = 200
|
||||
chunks: list[str] = Field(default_factory=list)
|
||||
chunk_embeddings: list[np.ndarray] = Field(default_factory=list)
|
||||
chunk_embeddings: list[np.ndarray[Any, np.dtype[Any]]] = Field(default_factory=list)
|
||||
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
storage: KnowledgeStorage | None = Field(default=None)
|
||||
@@ -28,7 +28,7 @@ class BaseKnowledgeSource(BaseModel, ABC):
|
||||
def add(self) -> None:
|
||||
"""Process content, chunk it, compute embeddings, and save them."""
|
||||
|
||||
def get_embeddings(self) -> list[np.ndarray]:
|
||||
def get_embeddings(self) -> list[np.ndarray[Any, np.dtype[Any]]]:
|
||||
"""Return the list of embeddings for the chunks."""
|
||||
return self.chunk_embeddings
|
||||
|
||||
|
||||
@@ -2369,8 +2369,8 @@ class LLM(BaseLLM):
|
||||
cb.strip() for cb in failure_callbacks_str.split(",") if cb.strip()
|
||||
]
|
||||
|
||||
litellm.success_callback = success_callbacks
|
||||
litellm.failure_callback = failure_callbacks
|
||||
litellm.success_callback = success_callbacks # type: ignore[assignment]
|
||||
litellm.failure_callback = failure_callbacks # type: ignore[assignment]
|
||||
|
||||
def __copy__(self) -> LLM:
|
||||
"""Create a shallow copy of the LLM instance."""
|
||||
|
||||
@@ -222,6 +222,7 @@ class AnthropicCompletion(BaseLLM):
|
||||
self.previous_thinking_blocks: list[ThinkingBlock] = []
|
||||
self.response_format = response_format
|
||||
# Tool search config
|
||||
self.tool_search: AnthropicToolSearchConfig | None
|
||||
if tool_search is True:
|
||||
self.tool_search = AnthropicToolSearchConfig()
|
||||
elif isinstance(tool_search, AnthropicToolSearchConfig):
|
||||
|
||||
@@ -1,22 +1,21 @@
|
||||
"""MCP client with session management for CrewAI agents."""
|
||||
|
||||
import asyncio
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Callable, Coroutine
|
||||
from contextlib import AsyncExitStack
|
||||
from datetime import datetime
|
||||
import logging
|
||||
import sys
|
||||
import time
|
||||
from typing import Any, NamedTuple
|
||||
from typing import Any, NamedTuple, TypeVar
|
||||
|
||||
from typing_extensions import Self
|
||||
|
||||
|
||||
# BaseExceptionGroup is available in Python 3.11+
|
||||
try:
|
||||
if sys.version_info >= (3, 11):
|
||||
from builtins import BaseExceptionGroup
|
||||
except ImportError:
|
||||
# Fallback for Python < 3.11 (shouldn't happen in practice)
|
||||
BaseExceptionGroup = Exception
|
||||
else:
|
||||
from exceptiongroup import BaseExceptionGroup
|
||||
|
||||
from crewai.events.event_bus import crewai_event_bus
|
||||
from crewai.events.types.mcp_events import (
|
||||
@@ -47,8 +46,10 @@ MCP_TOOL_EXECUTION_TIMEOUT = 30
|
||||
MCP_DISCOVERY_TIMEOUT = 30 # Increased for slow servers
|
||||
MCP_MAX_RETRIES = 3
|
||||
|
||||
_T = TypeVar("_T")
|
||||
|
||||
# Simple in-memory cache for MCP tool schemas (duration: 5 minutes)
|
||||
_mcp_schema_cache: dict[str, tuple[dict[str, Any], float]] = {}
|
||||
_mcp_schema_cache: dict[str, tuple[list[dict[str, Any]], float]] = {}
|
||||
_cache_ttl = 300 # 5 minutes
|
||||
|
||||
|
||||
@@ -134,11 +135,7 @@ class MCPClient:
|
||||
else:
|
||||
server_name = "Unknown MCP Server"
|
||||
server_url = None
|
||||
transport_type = (
|
||||
self.transport.transport_type.value
|
||||
if hasattr(self.transport, "transport_type")
|
||||
else None
|
||||
)
|
||||
transport_type = self.transport.transport_type.value
|
||||
|
||||
return server_name, server_url, transport_type
|
||||
|
||||
@@ -542,7 +539,7 @@ class MCPClient:
|
||||
Returns:
|
||||
Cleaned arguments ready for MCP server.
|
||||
"""
|
||||
cleaned = {}
|
||||
cleaned: dict[str, Any] = {}
|
||||
|
||||
for key, value in arguments.items():
|
||||
# Skip None values
|
||||
@@ -686,9 +683,9 @@ class MCPClient:
|
||||
|
||||
async def _retry_operation(
|
||||
self,
|
||||
operation: Callable[[], Any],
|
||||
operation: Callable[[], Coroutine[Any, Any, _T]],
|
||||
timeout: int | None = None,
|
||||
) -> Any:
|
||||
) -> _T:
|
||||
"""Retry an operation with exponential backoff.
|
||||
|
||||
Args:
|
||||
|
||||
@@ -23,6 +23,7 @@ from crewai.mcp.config import (
|
||||
MCPServerSSE,
|
||||
MCPServerStdio,
|
||||
)
|
||||
from crewai.mcp.transports.base import BaseTransport
|
||||
from crewai.mcp.transports.http import HTTPTransport
|
||||
from crewai.mcp.transports.sse import SSETransport
|
||||
from crewai.mcp.transports.stdio import StdioTransport
|
||||
@@ -285,6 +286,7 @@ class MCPToolResolver:
|
||||
independent transport so that parallel tool executions never share
|
||||
state.
|
||||
"""
|
||||
transport: BaseTransport
|
||||
if isinstance(mcp_config, MCPServerStdio):
|
||||
transport = StdioTransport(
|
||||
command=mcp_config.command,
|
||||
|
||||
@@ -2,11 +2,17 @@
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from enum import Enum
|
||||
from typing import Any, Protocol
|
||||
from typing import Any
|
||||
|
||||
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
|
||||
from mcp.shared.message import SessionMessage
|
||||
from typing_extensions import Self
|
||||
|
||||
|
||||
MCPReadStream = MemoryObjectReceiveStream[SessionMessage | Exception]
|
||||
MCPWriteStream = MemoryObjectSendStream[SessionMessage]
|
||||
|
||||
|
||||
class TransportType(str, Enum):
|
||||
"""MCP transport types."""
|
||||
|
||||
@@ -16,22 +22,6 @@ class TransportType(str, Enum):
|
||||
SSE = "sse"
|
||||
|
||||
|
||||
class ReadStream(Protocol):
|
||||
"""Protocol for read streams."""
|
||||
|
||||
async def read(self, n: int = -1) -> bytes:
|
||||
"""Read bytes from stream."""
|
||||
...
|
||||
|
||||
|
||||
class WriteStream(Protocol):
|
||||
"""Protocol for write streams."""
|
||||
|
||||
async def write(self, data: bytes) -> None:
|
||||
"""Write bytes to stream."""
|
||||
...
|
||||
|
||||
|
||||
class BaseTransport(ABC):
|
||||
"""Base class for MCP transport implementations.
|
||||
|
||||
@@ -46,8 +36,8 @@ class BaseTransport(ABC):
|
||||
Args:
|
||||
**kwargs: Transport-specific configuration options.
|
||||
"""
|
||||
self._read_stream: ReadStream | None = None
|
||||
self._write_stream: WriteStream | None = None
|
||||
self._read_stream: MCPReadStream | None = None
|
||||
self._write_stream: MCPWriteStream | None = None
|
||||
self._connected = False
|
||||
|
||||
@property
|
||||
@@ -62,14 +52,14 @@ class BaseTransport(ABC):
|
||||
return self._connected
|
||||
|
||||
@property
|
||||
def read_stream(self) -> ReadStream:
|
||||
def read_stream(self) -> MCPReadStream:
|
||||
"""Get the read stream."""
|
||||
if self._read_stream is None:
|
||||
raise RuntimeError("Transport not connected. Call connect() first.")
|
||||
return self._read_stream
|
||||
|
||||
@property
|
||||
def write_stream(self) -> WriteStream:
|
||||
def write_stream(self) -> MCPWriteStream:
|
||||
"""Get the write stream."""
|
||||
if self._write_stream is None:
|
||||
raise RuntimeError("Transport not connected. Call connect() first.")
|
||||
@@ -107,7 +97,7 @@ class BaseTransport(ABC):
|
||||
"""Async context manager exit."""
|
||||
...
|
||||
|
||||
def _set_streams(self, read: ReadStream, write: WriteStream) -> None:
|
||||
def _set_streams(self, read: MCPReadStream, write: MCPWriteStream) -> None:
|
||||
"""Set the read and write streams.
|
||||
|
||||
Args:
|
||||
|
||||
@@ -1,17 +1,16 @@
|
||||
"""HTTP and Streamable HTTP transport for MCP servers."""
|
||||
|
||||
import asyncio
|
||||
import sys
|
||||
from typing import Any
|
||||
|
||||
from typing_extensions import Self
|
||||
|
||||
|
||||
# BaseExceptionGroup is available in Python 3.11+
|
||||
try:
|
||||
if sys.version_info >= (3, 11):
|
||||
from builtins import BaseExceptionGroup
|
||||
except ImportError:
|
||||
# Fallback for Python < 3.11 (shouldn't happen in practice)
|
||||
BaseExceptionGroup = Exception
|
||||
else:
|
||||
from exceptiongroup import BaseExceptionGroup
|
||||
|
||||
from crewai.mcp.transports.base import BaseTransport, TransportType
|
||||
|
||||
|
||||
@@ -122,11 +122,14 @@ class StdioTransport(BaseTransport):
|
||||
if self._process is not None:
|
||||
try:
|
||||
self._process.terminate()
|
||||
loop = asyncio.get_running_loop()
|
||||
try:
|
||||
await asyncio.wait_for(self._process.wait(), timeout=5.0)
|
||||
await asyncio.wait_for(
|
||||
loop.run_in_executor(None, self._process.wait), timeout=5.0
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
self._process.kill()
|
||||
await self._process.wait()
|
||||
await loop.run_in_executor(None, self._process.wait)
|
||||
# except ProcessLookupError:
|
||||
# pass
|
||||
finally:
|
||||
|
||||
@@ -52,7 +52,7 @@ class ChromaDBClient(BaseClient):
|
||||
def __init__(
|
||||
self,
|
||||
client: ChromaDBClientType,
|
||||
embedding_function: ChromaEmbeddingFunction,
|
||||
embedding_function: ChromaEmbeddingFunction, # type: ignore[type-arg]
|
||||
default_limit: int = 5,
|
||||
default_score_threshold: float = 0.6,
|
||||
default_batch_size: int = 100,
|
||||
|
||||
@@ -23,7 +23,7 @@ from crewai.rag.core.base_client import BaseCollectionParams, BaseCollectionSear
|
||||
ChromaDBClientType = ClientAPI | AsyncClientAPI
|
||||
|
||||
|
||||
class ChromaEmbeddingFunctionWrapper(ChromaEmbeddingFunction):
|
||||
class ChromaEmbeddingFunctionWrapper(ChromaEmbeddingFunction): # type: ignore[type-arg]
|
||||
"""Base class for ChromaDB EmbeddingFunction to work with Pydantic validation."""
|
||||
|
||||
@classmethod
|
||||
@@ -85,7 +85,7 @@ class ChromaDBCollectionCreateParams(BaseCollectionParams, total=False):
|
||||
|
||||
configuration: CollectionConfigurationInterface
|
||||
metadata: CollectionMetadata
|
||||
embedding_function: ChromaEmbeddingFunction
|
||||
embedding_function: ChromaEmbeddingFunction # type: ignore[type-arg]
|
||||
data_loader: DataLoader[Loadable]
|
||||
get_or_create: bool
|
||||
|
||||
|
||||
@@ -8,7 +8,7 @@ from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
from crewai.rag.core.base_embeddings_callable import EmbeddingFunction
|
||||
|
||||
|
||||
T = TypeVar("T", bound=EmbeddingFunction)
|
||||
T = TypeVar("T", bound=EmbeddingFunction) # type: ignore[type-arg]
|
||||
|
||||
|
||||
class BaseEmbeddingsProvider(BaseSettings, Generic[T]):
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
"""Core type definitions for RAG systems."""
|
||||
|
||||
from collections.abc import Sequence
|
||||
from typing import TypeVar
|
||||
from typing import Any, TypeVar
|
||||
|
||||
import numpy as np
|
||||
from numpy import floating, integer, number
|
||||
@@ -16,7 +16,7 @@ Embedding = NDArray[np.int32 | np.float32]
|
||||
Embeddings = list[Embedding]
|
||||
|
||||
Documents = list[str]
|
||||
Images = list[np.ndarray]
|
||||
Images = list[np.ndarray[Any, np.dtype[np.generic]]]
|
||||
Embeddable = Documents | Images
|
||||
|
||||
ScalarType = TypeVar("ScalarType", bound=np.generic)
|
||||
|
||||
@@ -9,7 +9,7 @@ from typing_extensions import Required, TypedDict
|
||||
class CustomProviderConfig(TypedDict, total=False):
|
||||
"""Configuration for Custom provider."""
|
||||
|
||||
embedding_callable: type[EmbeddingFunction]
|
||||
embedding_callable: type[EmbeddingFunction] # type: ignore[type-arg]
|
||||
|
||||
|
||||
class CustomProviderSpec(TypedDict, total=False):
|
||||
|
||||
@@ -85,7 +85,7 @@ class GoogleGenAIVertexEmbeddingFunction(EmbeddingFunction[Documents]):
|
||||
- output_dimensionality: Optional output embedding dimension (new SDK only)
|
||||
"""
|
||||
# Handle deprecated 'region' parameter (only if it has a value)
|
||||
region_value = kwargs.pop("region", None) # type: ignore[typeddict-item]
|
||||
region_value = kwargs.pop("region", None) # type: ignore[typeddict-item,unused-ignore]
|
||||
if region_value is not None:
|
||||
warnings.warn(
|
||||
"The 'region' parameter is deprecated, use 'location' instead. "
|
||||
@@ -94,7 +94,7 @@ class GoogleGenAIVertexEmbeddingFunction(EmbeddingFunction[Documents]):
|
||||
stacklevel=2,
|
||||
)
|
||||
if "location" not in kwargs or kwargs.get("location") is None:
|
||||
kwargs["location"] = region_value # type: ignore[typeddict-unknown-key]
|
||||
kwargs["location"] = region_value # type: ignore[typeddict-unknown-key,unused-ignore]
|
||||
|
||||
self._config = kwargs
|
||||
self._model_name = str(kwargs.get("model_name", "textembedding-gecko"))
|
||||
@@ -123,8 +123,10 @@ class GoogleGenAIVertexEmbeddingFunction(EmbeddingFunction[Documents]):
|
||||
)
|
||||
|
||||
try:
|
||||
import vertexai
|
||||
from vertexai.language_models import TextEmbeddingModel
|
||||
import vertexai # type: ignore[import-not-found]
|
||||
from vertexai.language_models import ( # type: ignore[import-not-found]
|
||||
TextEmbeddingModel,
|
||||
)
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"vertexai is required for legacy embedding models (textembedding-gecko*). "
|
||||
|
||||
@@ -18,7 +18,7 @@ class VoyageAIEmbeddingFunction(EmbeddingFunction[Documents]):
|
||||
**kwargs: Configuration parameters for VoyageAI.
|
||||
"""
|
||||
try:
|
||||
import voyageai # type: ignore[import-not-found]
|
||||
import voyageai
|
||||
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
@@ -26,7 +26,7 @@ class VoyageAIEmbeddingFunction(EmbeddingFunction[Documents]):
|
||||
"Install it with: uv add voyageai"
|
||||
) from e
|
||||
self._config = kwargs
|
||||
self._client = voyageai.Client(
|
||||
self._client = voyageai.Client( # type: ignore[attr-defined]
|
||||
api_key=kwargs["api_key"],
|
||||
max_retries=kwargs.get("max_retries", 0),
|
||||
timeout=kwargs.get("timeout"),
|
||||
|
||||
@@ -311,8 +311,7 @@ class QdrantClient(BaseClient):
|
||||
points = []
|
||||
for doc in batch_docs:
|
||||
if _is_async_embedding_function(self.embedding_function):
|
||||
async_fn = cast(AsyncEmbeddingFunction, self.embedding_function)
|
||||
embedding = await async_fn(doc["content"])
|
||||
embedding = await self.embedding_function(doc["content"])
|
||||
else:
|
||||
sync_fn = cast(EmbeddingFunction, self.embedding_function)
|
||||
embedding = sync_fn(doc["content"])
|
||||
@@ -412,8 +411,7 @@ class QdrantClient(BaseClient):
|
||||
raise ValueError(f"Collection '{collection_name}' does not exist")
|
||||
|
||||
if _is_async_embedding_function(self.embedding_function):
|
||||
async_fn = cast(AsyncEmbeddingFunction, self.embedding_function)
|
||||
query_embedding = await async_fn(query)
|
||||
query_embedding = await self.embedding_function(query)
|
||||
else:
|
||||
sync_fn = cast(EmbeddingFunction, self.embedding_function)
|
||||
query_embedding = sync_fn(query)
|
||||
|
||||
@@ -7,10 +7,10 @@ import numpy as np
|
||||
from pydantic import GetCoreSchemaHandler
|
||||
from pydantic_core import CoreSchema, core_schema
|
||||
from qdrant_client import (
|
||||
AsyncQdrantClient, # type: ignore[import-not-found]
|
||||
QdrantClient as SyncQdrantClient, # type: ignore[import-not-found]
|
||||
AsyncQdrantClient,
|
||||
QdrantClient as SyncQdrantClient,
|
||||
)
|
||||
from qdrant_client.models import ( # type: ignore[import-not-found]
|
||||
from qdrant_client.models import (
|
||||
FieldCondition,
|
||||
Filter,
|
||||
HasIdCondition,
|
||||
|
||||
@@ -5,10 +5,10 @@ from typing import TypeGuard
|
||||
from uuid import uuid4
|
||||
|
||||
from qdrant_client import (
|
||||
AsyncQdrantClient, # type: ignore[import-not-found]
|
||||
QdrantClient as SyncQdrantClient, # type: ignore[import-not-found]
|
||||
AsyncQdrantClient,
|
||||
QdrantClient as SyncQdrantClient,
|
||||
)
|
||||
from qdrant_client.models import ( # type: ignore[import-not-found]
|
||||
from qdrant_client.models import (
|
||||
FieldCondition,
|
||||
Filter,
|
||||
MatchValue,
|
||||
|
||||
@@ -16,7 +16,7 @@ class BaseRAGStorage(ABC):
|
||||
self,
|
||||
type: str,
|
||||
allow_reset: bool = True,
|
||||
embedder_config: ProviderSpec | BaseEmbeddingsProvider | None = None,
|
||||
embedder_config: ProviderSpec | BaseEmbeddingsProvider[Any] | None = None,
|
||||
crew: Any = None,
|
||||
):
|
||||
self.type = type
|
||||
|
||||
@@ -580,7 +580,7 @@ class Task(BaseModel):
|
||||
tools = tools or self.tools or []
|
||||
|
||||
self.processed_by_agents.add(agent.role)
|
||||
crewai_event_bus.emit(self, TaskStartedEvent(context=context, task=self)) # type: ignore[no-untyped-call]
|
||||
crewai_event_bus.emit(self, TaskStartedEvent(context=context, task=self))
|
||||
result = await agent.aexecute_task(
|
||||
task=self,
|
||||
context=context,
|
||||
@@ -662,12 +662,12 @@ class Task(BaseModel):
|
||||
self._save_file(content)
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
TaskCompletedEvent(output=task_output, task=self), # type: ignore[no-untyped-call]
|
||||
TaskCompletedEvent(output=task_output, task=self),
|
||||
)
|
||||
return task_output
|
||||
except Exception as e:
|
||||
self.end_time = datetime.datetime.now()
|
||||
crewai_event_bus.emit(self, TaskFailedEvent(error=str(e), task=self)) # type: ignore[no-untyped-call]
|
||||
crewai_event_bus.emit(self, TaskFailedEvent(error=str(e), task=self))
|
||||
raise e # Re-raise the exception after emitting the event
|
||||
finally:
|
||||
clear_task_files(self.id)
|
||||
@@ -694,7 +694,7 @@ class Task(BaseModel):
|
||||
tools = tools or self.tools or []
|
||||
|
||||
self.processed_by_agents.add(agent.role)
|
||||
crewai_event_bus.emit(self, TaskStartedEvent(context=context, task=self)) # type: ignore[no-untyped-call]
|
||||
crewai_event_bus.emit(self, TaskStartedEvent(context=context, task=self))
|
||||
result = agent.execute_task(
|
||||
task=self,
|
||||
context=context,
|
||||
@@ -777,12 +777,12 @@ class Task(BaseModel):
|
||||
self._save_file(content)
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
TaskCompletedEvent(output=task_output, task=self), # type: ignore[no-untyped-call]
|
||||
TaskCompletedEvent(output=task_output, task=self),
|
||||
)
|
||||
return task_output
|
||||
except Exception as e:
|
||||
self.end_time = datetime.datetime.now()
|
||||
crewai_event_bus.emit(self, TaskFailedEvent(error=str(e), task=self)) # type: ignore[no-untyped-call]
|
||||
crewai_event_bus.emit(self, TaskFailedEvent(error=str(e), task=self))
|
||||
raise e # Re-raise the exception after emitting the event
|
||||
finally:
|
||||
clear_task_files(self.id)
|
||||
|
||||
@@ -32,8 +32,8 @@ class ConditionalTask(Task):
|
||||
def __init__(
|
||||
self,
|
||||
condition: Callable[[Any], bool] | None = None,
|
||||
**kwargs,
|
||||
):
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
super().__init__(**kwargs)
|
||||
self.condition = condition
|
||||
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
"""Task output representation and formatting."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from typing import Any
|
||||
|
||||
@@ -44,7 +46,7 @@ class TaskOutput(BaseModel):
|
||||
messages: list[LLMMessage] = Field(description="Messages of the task", default=[])
|
||||
|
||||
@model_validator(mode="after")
|
||||
def set_summary(self):
|
||||
def set_summary(self) -> TaskOutput:
|
||||
"""Set the summary field based on the description.
|
||||
|
||||
Returns:
|
||||
|
||||
@@ -825,65 +825,64 @@ class Telemetry:
|
||||
span, crew, self._add_attribute, include_fingerprint=False
|
||||
)
|
||||
self._add_attribute(span, "crew_inputs", json.dumps(inputs or {}))
|
||||
|
||||
if crew.share_crew:
|
||||
self._add_attribute(
|
||||
span,
|
||||
"crew_agents",
|
||||
json.dumps(
|
||||
[
|
||||
{
|
||||
"key": agent.key,
|
||||
"id": str(agent.id),
|
||||
"role": agent.role,
|
||||
"goal": agent.goal,
|
||||
"backstory": agent.backstory,
|
||||
"verbose?": agent.verbose,
|
||||
"max_iter": agent.max_iter,
|
||||
"max_rpm": agent.max_rpm,
|
||||
"i18n": agent.i18n.prompt_file,
|
||||
"llm": agent.llm.model,
|
||||
"delegation_enabled?": agent.allow_delegation,
|
||||
"tools_names": [
|
||||
sanitize_tool_name(tool.name)
|
||||
for tool in agent.tools or []
|
||||
],
|
||||
}
|
||||
for agent in crew.agents
|
||||
]
|
||||
),
|
||||
)
|
||||
self._add_attribute(
|
||||
span,
|
||||
"crew_tasks",
|
||||
json.dumps(
|
||||
[
|
||||
{
|
||||
"id": str(task.id),
|
||||
"description": task.description,
|
||||
"expected_output": task.expected_output,
|
||||
"async_execution?": task.async_execution,
|
||||
"human_input?": task.human_input,
|
||||
"agent_role": task.agent.role if task.agent else "None",
|
||||
"agent_key": task.agent.key if task.agent else None,
|
||||
"context": (
|
||||
[task.description for task in task.context]
|
||||
if isinstance(task.context, list)
|
||||
else None
|
||||
),
|
||||
"tools_names": [
|
||||
sanitize_tool_name(tool.name)
|
||||
for tool in task.tools or []
|
||||
],
|
||||
}
|
||||
for task in crew.tasks
|
||||
]
|
||||
),
|
||||
)
|
||||
|
||||
self._add_attribute(
|
||||
span,
|
||||
"crew_agents",
|
||||
json.dumps(
|
||||
[
|
||||
{
|
||||
"key": agent.key,
|
||||
"id": str(agent.id),
|
||||
"role": agent.role,
|
||||
"goal": agent.goal,
|
||||
"backstory": agent.backstory,
|
||||
"verbose?": agent.verbose,
|
||||
"max_iter": agent.max_iter,
|
||||
"max_rpm": agent.max_rpm,
|
||||
"i18n": agent.i18n.prompt_file,
|
||||
"llm": agent.llm.model,
|
||||
"delegation_enabled?": agent.allow_delegation,
|
||||
"tools_names": [
|
||||
sanitize_tool_name(tool.name)
|
||||
for tool in agent.tools or []
|
||||
],
|
||||
}
|
||||
for agent in crew.agents
|
||||
]
|
||||
),
|
||||
)
|
||||
self._add_attribute(
|
||||
span,
|
||||
"crew_tasks",
|
||||
json.dumps(
|
||||
[
|
||||
{
|
||||
"id": str(task.id),
|
||||
"description": task.description,
|
||||
"expected_output": task.expected_output,
|
||||
"async_execution?": task.async_execution,
|
||||
"human_input?": task.human_input,
|
||||
"agent_role": task.agent.role if task.agent else "None",
|
||||
"agent_key": task.agent.key if task.agent else None,
|
||||
"context": (
|
||||
[task.description for task in task.context]
|
||||
if isinstance(task.context, list)
|
||||
else None
|
||||
),
|
||||
"tools_names": [
|
||||
sanitize_tool_name(tool.name)
|
||||
for tool in task.tools or []
|
||||
],
|
||||
}
|
||||
for task in crew.tasks
|
||||
]
|
||||
),
|
||||
)
|
||||
return span
|
||||
|
||||
return self._safe_telemetry_operation(_operation)
|
||||
if crew.share_crew:
|
||||
return self._safe_telemetry_operation(_operation)
|
||||
return None
|
||||
|
||||
def end_crew(self, crew: Any, final_string_output: str) -> None:
|
||||
"""Records the end of crew execution.
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from crewai.tools.base_tool import BaseTool
|
||||
@@ -27,8 +29,8 @@ class AddImageTool(BaseTool):
|
||||
self,
|
||||
image_url: str,
|
||||
action: str | None = None,
|
||||
**kwargs,
|
||||
) -> dict:
|
||||
**kwargs: Any,
|
||||
) -> dict[str, Any]:
|
||||
action = action or i18n.tools("add_image")["default_action"] # type: ignore
|
||||
content = [
|
||||
{"type": "text", "text": action},
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from crewai.tools.agent_tools.base_agent_tools import BaseAgentTool
|
||||
@@ -20,7 +22,7 @@ class AskQuestionTool(BaseAgentTool):
|
||||
question: str,
|
||||
context: str,
|
||||
coworker: str | None = None,
|
||||
**kwargs,
|
||||
**kwargs: Any,
|
||||
) -> str:
|
||||
coworker = self._get_coworker(coworker, **kwargs)
|
||||
return self._execute(coworker, question, context)
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from crewai.tools.agent_tools.base_agent_tools import BaseAgentTool
|
||||
@@ -22,7 +24,7 @@ class DelegateWorkTool(BaseAgentTool):
|
||||
task: str,
|
||||
context: str,
|
||||
coworker: str | None = None,
|
||||
**kwargs,
|
||||
**kwargs: Any,
|
||||
) -> str:
|
||||
coworker = self._get_coworker(coworker, **kwargs)
|
||||
return self._execute(coworker, task, context)
|
||||
|
||||
@@ -70,7 +70,7 @@ class MCPNativeTool(BaseTool):
|
||||
"""Get the server name."""
|
||||
return self._server_name
|
||||
|
||||
def _run(self, **kwargs) -> str:
|
||||
def _run(self, **kwargs: Any) -> str:
|
||||
"""Execute tool using the MCP client session.
|
||||
|
||||
Args:
|
||||
@@ -98,7 +98,7 @@ class MCPNativeTool(BaseTool):
|
||||
f"Error executing MCP tool {self.original_tool_name}: {e!s}"
|
||||
) from e
|
||||
|
||||
async def _run_async(self, **kwargs) -> str:
|
||||
async def _run_async(self, **kwargs: Any) -> str:
|
||||
"""Async implementation of tool execution.
|
||||
|
||||
A fresh ``MCPClient`` is created for every invocation so that
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
"""MCP Tool Wrapper for on-demand MCP server connections."""
|
||||
|
||||
import asyncio
|
||||
from collections.abc import Callable, Coroutine
|
||||
from typing import Any
|
||||
|
||||
from crewai.tools import BaseTool
|
||||
|
||||
@@ -16,9 +18,9 @@ class MCPToolWrapper(BaseTool):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
mcp_server_params: dict,
|
||||
mcp_server_params: dict[str, Any],
|
||||
tool_name: str,
|
||||
tool_schema: dict,
|
||||
tool_schema: dict[str, Any],
|
||||
server_name: str,
|
||||
):
|
||||
"""Initialize the MCP tool wrapper.
|
||||
@@ -54,7 +56,7 @@ class MCPToolWrapper(BaseTool):
|
||||
self._server_name = server_name
|
||||
|
||||
@property
|
||||
def mcp_server_params(self) -> dict:
|
||||
def mcp_server_params(self) -> dict[str, Any]:
|
||||
"""Get the MCP server parameters."""
|
||||
return self._mcp_server_params
|
||||
|
||||
@@ -68,7 +70,7 @@ class MCPToolWrapper(BaseTool):
|
||||
"""Get the server name."""
|
||||
return self._server_name
|
||||
|
||||
def _run(self, **kwargs) -> str:
|
||||
def _run(self, **kwargs: Any) -> str:
|
||||
"""Connect to MCP server and execute tool.
|
||||
|
||||
Args:
|
||||
@@ -84,13 +86,15 @@ class MCPToolWrapper(BaseTool):
|
||||
except Exception as e:
|
||||
return f"Error executing MCP tool {self.original_tool_name}: {e!s}"
|
||||
|
||||
async def _run_async(self, **kwargs) -> str:
|
||||
async def _run_async(self, **kwargs: Any) -> str:
|
||||
"""Async implementation of MCP tool execution with timeouts and retry logic."""
|
||||
return await self._retry_with_exponential_backoff(
|
||||
self._execute_tool_with_timeout, **kwargs
|
||||
)
|
||||
|
||||
async def _retry_with_exponential_backoff(self, operation_func, **kwargs) -> str:
|
||||
async def _retry_with_exponential_backoff(
|
||||
self, operation_func: Callable[..., Coroutine[Any, Any, str]], **kwargs: Any
|
||||
) -> str:
|
||||
"""Retry operation with exponential backoff, avoiding try-except in loop for performance."""
|
||||
last_error = None
|
||||
|
||||
@@ -119,7 +123,7 @@ class MCPToolWrapper(BaseTool):
|
||||
)
|
||||
|
||||
async def _execute_single_attempt(
|
||||
self, operation_func, **kwargs
|
||||
self, operation_func: Callable[..., Coroutine[Any, Any, str]], **kwargs: Any
|
||||
) -> tuple[str | None, str, bool]:
|
||||
"""Execute single operation attempt and return (result, error_message, should_retry)."""
|
||||
try:
|
||||
@@ -158,22 +162,23 @@ class MCPToolWrapper(BaseTool):
|
||||
return None, f"Server response parsing error: {e!s}", True
|
||||
return None, f"MCP execution error: {e!s}", False
|
||||
|
||||
async def _execute_tool_with_timeout(self, **kwargs) -> str:
|
||||
async def _execute_tool_with_timeout(self, **kwargs: Any) -> str:
|
||||
"""Execute tool with timeout wrapper."""
|
||||
return await asyncio.wait_for(
|
||||
self._execute_tool(**kwargs), timeout=MCP_TOOL_EXECUTION_TIMEOUT
|
||||
)
|
||||
|
||||
async def _execute_tool(self, **kwargs) -> str:
|
||||
async def _execute_tool(self, **kwargs: Any) -> str:
|
||||
"""Execute the actual MCP tool call."""
|
||||
from mcp import ClientSession
|
||||
from mcp.client.streamable_http import streamablehttp_client
|
||||
from mcp.types import TextContent
|
||||
|
||||
server_url = self.mcp_server_params["url"]
|
||||
|
||||
try:
|
||||
# Wrap entire operation with single timeout
|
||||
async def _do_mcp_call():
|
||||
|
||||
async def _do_mcp_call() -> str:
|
||||
async with streamablehttp_client(
|
||||
server_url, terminate_on_close=True
|
||||
) as (read, write, _):
|
||||
@@ -183,17 +188,11 @@ class MCPToolWrapper(BaseTool):
|
||||
self.original_tool_name, kwargs
|
||||
)
|
||||
|
||||
# Extract the result content
|
||||
if hasattr(result, "content") and result.content:
|
||||
if (
|
||||
isinstance(result.content, list)
|
||||
and len(result.content) > 0
|
||||
):
|
||||
content_item = result.content[0]
|
||||
if hasattr(content_item, "text"):
|
||||
return str(content_item.text)
|
||||
return str(content_item)
|
||||
return str(result.content)
|
||||
if result.content:
|
||||
content_item = result.content[0]
|
||||
if isinstance(content_item, TextContent):
|
||||
return content_item.text
|
||||
return str(content_item)
|
||||
return str(result)
|
||||
|
||||
return await asyncio.wait_for(
|
||||
@@ -203,7 +202,7 @@ class MCPToolWrapper(BaseTool):
|
||||
except asyncio.CancelledError as e:
|
||||
raise asyncio.TimeoutError("MCP operation was cancelled") from e
|
||||
except Exception as e:
|
||||
if hasattr(e, "__cause__") and e.__cause__:
|
||||
if e.__cause__ is not None:
|
||||
raise asyncio.TimeoutError(
|
||||
f"MCP connection error: {e.__cause__}"
|
||||
) from e.__cause__
|
||||
|
||||
@@ -81,7 +81,7 @@ class TaskEvaluator:
|
||||
"""
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
TaskEvaluationEvent(evaluation_type="task_evaluation", task=task), # type: ignore[no-untyped-call]
|
||||
TaskEvaluationEvent(evaluation_type="task_evaluation", task=task),
|
||||
)
|
||||
evaluation_query = (
|
||||
f"Assess the quality of the task completed based on the description, expected output, and actual results.\n\n"
|
||||
@@ -129,7 +129,7 @@ class TaskEvaluator:
|
||||
"""
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
TaskEvaluationEvent(evaluation_type="training_data_evaluation"), # type: ignore[no-untyped-call]
|
||||
TaskEvaluationEvent(evaluation_type="training_data_evaluation"),
|
||||
)
|
||||
|
||||
output_training_data = training_data[agent_id]
|
||||
|
||||
@@ -12,16 +12,16 @@ from uuid import UUID
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from aiocache import Cache
|
||||
from aiocache import Cache # type: ignore[import-untyped]
|
||||
from crewai_files import FileInput
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_file_store: Cache | None = None
|
||||
_file_store: Cache | None = None # type: ignore[no-any-unimported]
|
||||
|
||||
try:
|
||||
from aiocache import Cache
|
||||
from aiocache.serializers import PickleSerializer
|
||||
from aiocache.serializers import PickleSerializer # type: ignore[import-untyped]
|
||||
|
||||
_file_store = Cache(Cache.MEMORY, serializer=PickleSerializer())
|
||||
except ImportError:
|
||||
|
||||
@@ -39,7 +39,7 @@ class GuardrailResult(BaseModel):
|
||||
|
||||
@field_validator("result", "error")
|
||||
@classmethod
|
||||
def validate_result_error_exclusivity(cls, v: Any, info) -> Any:
|
||||
def validate_result_error_exclusivity(cls, v: Any, info: Any) -> Any:
|
||||
"""Ensure that result and error are mutually exclusive based on success.
|
||||
|
||||
Args:
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Literal
|
||||
from typing import Any, Literal
|
||||
from uuid import uuid4
|
||||
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
@@ -259,7 +259,7 @@ class StepObservation(BaseModel):
|
||||
|
||||
@field_validator("suggested_refinements", mode="before")
|
||||
@classmethod
|
||||
def coerce_single_refinement_to_list(cls, v):
|
||||
def coerce_single_refinement_to_list(cls, v: Any) -> Any:
|
||||
"""Coerce a single dict refinement into a list to handle LLM returning a single object."""
|
||||
if isinstance(v, dict):
|
||||
return [v]
|
||||
|
||||
@@ -182,7 +182,7 @@ class AgentReasoning:
|
||||
if self.config.llm is not None:
|
||||
if isinstance(self.config.llm, LLM):
|
||||
return self.config.llm
|
||||
return create_llm(self.config.llm)
|
||||
return cast(LLM, create_llm(self.config.llm))
|
||||
return cast(LLM, self.agent.llm)
|
||||
|
||||
def handle_agent_reasoning(self) -> AgentReasoningOutput:
|
||||
|
||||
@@ -75,7 +75,7 @@ class RPMController(BaseModel):
|
||||
self._current_rpm = 0
|
||||
|
||||
def _reset_request_count(self) -> None:
|
||||
def _reset():
|
||||
def _reset() -> None:
|
||||
self._current_rpm = 0
|
||||
if not self._shutdown_flag:
|
||||
self._timer = threading.Timer(60.0, self._reset_request_count)
|
||||
|
||||
@@ -60,7 +60,9 @@ def _extract_tool_call_info(
|
||||
StreamChunkType.TOOL_CALL,
|
||||
ToolCallChunk(
|
||||
tool_id=event.tool_call.id,
|
||||
tool_name=sanitize_tool_name(event.tool_call.function.name),
|
||||
tool_name=sanitize_tool_name(event.tool_call.function.name)
|
||||
if event.tool_call.function.name
|
||||
else None,
|
||||
arguments=event.tool_call.function.arguments,
|
||||
index=event.tool_call.index,
|
||||
),
|
||||
|
||||
@@ -76,7 +76,7 @@ Please provide ONLY the {field_name} field value as described:
|
||||
|
||||
Respond with ONLY the requested information, nothing else.
|
||||
"""
|
||||
return self.llm.call(
|
||||
result: str = self.llm.call(
|
||||
[
|
||||
{
|
||||
"role": "system",
|
||||
@@ -85,6 +85,7 @@ Respond with ONLY the requested information, nothing else.
|
||||
{"role": "user", "content": prompt},
|
||||
]
|
||||
)
|
||||
return result
|
||||
|
||||
def _process_field_value(self, response: str, field_type: type | None) -> Any:
|
||||
response = response.strip()
|
||||
@@ -104,7 +105,8 @@ Respond with ONLY the requested information, nothing else.
|
||||
def _parse_list(self, response: str) -> list[Any]:
|
||||
try:
|
||||
if response.startswith("["):
|
||||
return json.loads(response)
|
||||
parsed: list[Any] = json.loads(response)
|
||||
return parsed
|
||||
|
||||
items: list[str] = [
|
||||
item.strip() for item in response.split("\n") if item.strip()
|
||||
|
||||
@@ -1571,8 +1571,9 @@ class TestReasoningEffort:
|
||||
executor.agent.planning_config = None
|
||||
assert executor._get_reasoning_effort() == "medium"
|
||||
|
||||
# Case 3: planning_config without reasoning_effort attr → defaults to "medium"
|
||||
executor.agent.planning_config = Mock(spec=[])
|
||||
# Case 3: planning_config with default reasoning_effort
|
||||
executor.agent.planning_config = Mock()
|
||||
executor.agent.planning_config.reasoning_effort = "medium"
|
||||
assert executor._get_reasoning_effort() == "medium"
|
||||
|
||||
|
||||
|
||||
2
uv.lock
generated
2
uv.lock
generated
@@ -1241,7 +1241,7 @@ requires-dist = [
|
||||
{ name = "json5", specifier = "~=0.10.0" },
|
||||
{ name = "jsonref", specifier = "~=1.1.0" },
|
||||
{ name = "lancedb", specifier = ">=0.29.2" },
|
||||
{ name = "litellm", marker = "extra == 'litellm'", specifier = ">=1.74.9,<3" },
|
||||
{ name = "litellm", marker = "extra == 'litellm'", specifier = ">=1.74.9,<=1.82.6" },
|
||||
{ name = "mcp", specifier = "~=1.26.0" },
|
||||
{ name = "mem0ai", marker = "extra == 'mem0'", specifier = "~=0.1.94" },
|
||||
{ name = "openai", specifier = ">=1.83.0,<3" },
|
||||
|
||||
Reference in New Issue
Block a user