fix: add ConfigDict for Pydantic model_config and ClassVar annotations

This commit is contained in:
Greyson LaLonde
2025-09-19 00:44:33 -04:00
parent eca9077590
commit 82cb72ea41
221 changed files with 2365 additions and 2202 deletions

View File

@@ -1,17 +1,10 @@
import shutil
import subprocess
import time
from collections.abc import Callable, Sequence
from typing import (
Any,
Callable,
Dict,
List,
Literal,
Optional,
Sequence,
Tuple,
Type,
Union,
)
from pydantic import Field, InstanceOf, PrivateAttr, model_validator
@@ -19,6 +12,24 @@ from pydantic import Field, InstanceOf, PrivateAttr, model_validator
from crewai.agents import CacheHandler
from crewai.agents.agent_builder.base_agent import BaseAgent
from crewai.agents.crew_agent_executor import CrewAgentExecutor
from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.agent_events import (
AgentExecutionCompletedEvent,
AgentExecutionErrorEvent,
AgentExecutionStartedEvent,
)
from crewai.events.types.knowledge_events import (
KnowledgeQueryCompletedEvent,
KnowledgeQueryFailedEvent,
KnowledgeQueryStartedEvent,
KnowledgeRetrievalCompletedEvent,
KnowledgeRetrievalStartedEvent,
KnowledgeSearchQueryFailedEvent,
)
from crewai.events.types.memory_events import (
MemoryRetrievalCompletedEvent,
MemoryRetrievalStartedEvent,
)
from crewai.knowledge.knowledge import Knowledge
from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource
from crewai.knowledge.utils.knowledge_utils import extract_knowledge_context
@@ -38,24 +49,6 @@ from crewai.utilities.agent_utils import (
)
from crewai.utilities.constants import TRAINED_AGENTS_DATA_FILE, TRAINING_DATA_FILE
from crewai.utilities.converter import generate_model_description
from crewai.events.types.agent_events import (
AgentExecutionCompletedEvent,
AgentExecutionErrorEvent,
AgentExecutionStartedEvent,
)
from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.memory_events import (
MemoryRetrievalStartedEvent,
MemoryRetrievalCompletedEvent,
)
from crewai.events.types.knowledge_events import (
KnowledgeQueryCompletedEvent,
KnowledgeQueryFailedEvent,
KnowledgeQueryStartedEvent,
KnowledgeRetrievalCompletedEvent,
KnowledgeRetrievalStartedEvent,
KnowledgeSearchQueryFailedEvent,
)
from crewai.utilities.llm_utils import create_llm
from crewai.utilities.token_counter_callback import TokenCalcHandler
from crewai.utilities.training_handler import CrewTrainingHandler
@@ -87,36 +80,36 @@ class Agent(BaseAgent):
"""
_times_executed: int = PrivateAttr(default=0)
max_execution_time: Optional[int] = Field(
max_execution_time: int | None = Field(
default=None,
description="Maximum execution time for an agent to execute a task",
)
agent_ops_agent_name: str = None # type: ignore # Incompatible types in assignment (expression has type "None", variable has type "str")
agent_ops_agent_id: str = None # type: ignore # Incompatible types in assignment (expression has type "None", variable has type "str")
step_callback: Optional[Any] = Field(
step_callback: Any | None = Field(
default=None,
description="Callback to be executed after each step of the agent execution.",
)
use_system_prompt: Optional[bool] = Field(
use_system_prompt: bool | None = Field(
default=True,
description="Use system prompt for the agent.",
)
llm: Union[str, InstanceOf[BaseLLM], Any] = Field(
llm: str | InstanceOf[BaseLLM] | Any = Field(
description="Language model that will run the agent.", default=None
)
function_calling_llm: Optional[Union[str, InstanceOf[BaseLLM], Any]] = Field(
function_calling_llm: str | InstanceOf[BaseLLM] | Any | None = Field(
description="Language model that will run the agent.", default=None
)
system_template: Optional[str] = Field(
system_template: str | None = Field(
default=None, description="System format for the agent."
)
prompt_template: Optional[str] = Field(
prompt_template: str | None = Field(
default=None, description="Prompt format for the agent."
)
response_template: Optional[str] = Field(
response_template: str | None = Field(
default=None, description="Response format for the agent."
)
allow_code_execution: Optional[bool] = Field(
allow_code_execution: bool | None = Field(
default=False, description="Enable code execution for the agent."
)
respect_context_window: bool = Field(
@@ -147,31 +140,31 @@ class Agent(BaseAgent):
default=False,
description="Whether the agent should reflect and create a plan before executing a task.",
)
max_reasoning_attempts: Optional[int] = Field(
max_reasoning_attempts: int | None = Field(
default=None,
description="Maximum number of reasoning attempts before executing the task. If None, will try until ready.",
)
embedder: Optional[Dict[str, Any]] = Field(
embedder: dict[str, Any] | None = Field(
default=None,
description="Embedder configuration for the agent.",
)
agent_knowledge_context: Optional[str] = Field(
agent_knowledge_context: str | None = Field(
default=None,
description="Knowledge context for the agent.",
)
crew_knowledge_context: Optional[str] = Field(
crew_knowledge_context: str | None = Field(
default=None,
description="Knowledge context for the crew.",
)
knowledge_search_query: Optional[str] = Field(
knowledge_search_query: str | None = Field(
default=None,
description="Knowledge search query for the agent dynamically generated by the agent.",
)
from_repository: Optional[str] = Field(
from_repository: str | None = Field(
default=None,
description="The Agent's role to be used from your repository.",
)
guardrail: Optional[Union[Callable[[Any], Tuple[bool, Any]], str]] = Field(
guardrail: Callable[[Any], tuple[bool, Any]] | str | None = Field(
default=None,
description="Function or string description of a guardrail to validate agent output",
)
@@ -180,7 +173,7 @@ class Agent(BaseAgent):
)
@model_validator(mode="before")
def validate_from_repository(cls, v):
def validate_from_repository(self, v):
if v is not None and (from_repository := v.get("from_repository")):
return load_agent_from_repository(from_repository) | v
return v
@@ -208,7 +201,7 @@ class Agent(BaseAgent):
self.cache_handler = CacheHandler()
self.set_cache_handler(self.cache_handler)
def set_knowledge(self, crew_embedder: Optional[Dict[str, Any]] = None):
def set_knowledge(self, crew_embedder: dict[str, Any] | None = None):
try:
if self.embedder is None and crew_embedder:
self.embedder = crew_embedder
@@ -224,7 +217,7 @@ class Agent(BaseAgent):
)
self.knowledge.add_sources()
except (TypeError, ValueError) as e:
raise ValueError(f"Invalid Knowledge Configuration: {str(e)}")
raise ValueError(f"Invalid Knowledge Configuration: {e!s}") from e
def _is_any_available_memory(self) -> bool:
"""Check if any memory is available."""
@@ -244,8 +237,8 @@ class Agent(BaseAgent):
def execute_task(
self,
task: Task,
context: Optional[str] = None,
tools: Optional[List[BaseTool]] = None,
context: str | None = None,
tools: list[BaseTool] | None = None,
) -> str:
"""Execute a task with the agent.
@@ -278,11 +271,9 @@ class Agent(BaseAgent):
task.description += f"\n\nReasoning Plan:\n{reasoning_output.plan.plan}"
except Exception as e:
if hasattr(self, "_logger"):
self._logger.log(
"error", f"Error during reasoning process: {str(e)}"
)
self._logger.log("error", f"Error during reasoning process: {e!s}")
else:
print(f"Error during reasoning process: {str(e)}")
print(f"Error during reasoning process: {e!s}")
self._inject_date_to_task(task)
@@ -525,14 +516,14 @@ class Agent(BaseAgent):
try:
return future.result(timeout=timeout)
except concurrent.futures.TimeoutError:
except concurrent.futures.TimeoutError as e:
future.cancel()
raise TimeoutError(
f"Task '{task.description}' execution timed out after {timeout} seconds. Consider increasing max_execution_time or optimizing the task."
)
) from e
except Exception as e:
future.cancel()
raise RuntimeError(f"Task execution failed: {str(e)}")
raise RuntimeError(f"Task execution failed: {e!s}") from e
def _execute_without_timeout(self, task_prompt: str, task: Task) -> str:
"""Execute a task without a timeout.
@@ -554,14 +545,14 @@ class Agent(BaseAgent):
)["output"]
def create_agent_executor(
self, tools: Optional[List[BaseTool]] = None, task=None
self, tools: list[BaseTool] | None = None, task=None
) -> None:
"""Create an agent executor for the agent.
Returns:
An instance of the CrewAgentExecutor class.
"""
raw_tools: List[BaseTool] = tools or self.tools or []
raw_tools: list[BaseTool] = tools or self.tools or []
parsed_tools = parse_tools(raw_tools)
prompt = Prompts(
@@ -603,10 +594,9 @@ class Agent(BaseAgent):
callbacks=[TokenCalcHandler(self._token_process)],
)
def get_delegation_tools(self, agents: List[BaseAgent]):
def get_delegation_tools(self, agents: list[BaseAgent]):
agent_tools = AgentTools(agents=agents)
tools = agent_tools.tools()
return tools
return agent_tools.tools()
def get_multimodal_tools(self) -> Sequence[BaseTool]:
from crewai.tools.agent_tools.add_image_tool import AddImageTool
@@ -654,7 +644,7 @@ class Agent(BaseAgent):
)
return task_prompt
def _render_text_description(self, tools: List[Any]) -> str:
def _render_text_description(self, tools: list[Any]) -> str:
"""Render the tool name and description in plain text.
Output will be in the format of:
@@ -664,15 +654,13 @@ class Agent(BaseAgent):
search: This tool is used for search
calculator: This tool is used for math
"""
description = "\n".join(
return "\n".join(
[
f"Tool name: {tool.name}\nTool description:\n{tool.description}"
for tool in tools
]
)
return description
def _inject_date_to_task(self, task):
"""Inject the current date into the task description if inject_date is enabled."""
if self.inject_date:
@@ -700,9 +688,9 @@ class Agent(BaseAgent):
task.description += f"\n\nCurrent Date: {current_date}"
except Exception as e:
if hasattr(self, "_logger"):
self._logger.log("warning", f"Failed to inject date: {str(e)}")
self._logger.log("warning", f"Failed to inject date: {e!s}")
else:
print(f"Warning: Failed to inject date: {str(e)}")
print(f"Warning: Failed to inject date: {e!s}")
def _validate_docker_installation(self) -> None:
"""Check if Docker is installed and running."""
@@ -718,10 +706,10 @@ class Agent(BaseAgent):
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
)
except subprocess.CalledProcessError:
except subprocess.CalledProcessError as e:
raise RuntimeError(
f"Docker is not running. Please start Docker to use code execution with agent: {self.role}"
)
) from e
def __repr__(self):
return f"Agent(role={self.role}, goal={self.goal}, backstory={self.backstory})"
@@ -796,8 +784,8 @@ class Agent(BaseAgent):
def kickoff(
self,
messages: Union[str, List[Dict[str, str]]],
response_format: Optional[Type[Any]] = None,
messages: str | list[dict[str, str]],
response_format: type[Any] | None = None,
) -> LiteAgentOutput:
"""
Execute the agent with the given messages using a LiteAgent instance.
@@ -836,8 +824,8 @@ class Agent(BaseAgent):
async def kickoff_async(
self,
messages: Union[str, List[Dict[str, str]]],
response_format: Optional[Type[Any]] = None,
messages: str | list[dict[str, str]],
response_format: type[Any] | None = None,
) -> LiteAgentOutput:
"""
Execute the agent asynchronously with the given messages using a LiteAgent instance.

View File

@@ -1,5 +1,12 @@
from crewai.agents.cache.cache_handler import CacheHandler
from crewai.agents.parser import parse, AgentAction, AgentFinish, OutputParserException
from crewai.agents.parser import AgentAction, AgentFinish, OutputParserException, parse
from crewai.agents.tools_handler import ToolsHandler
__all__ = ["CacheHandler", "parse", "AgentAction", "AgentFinish", "OutputParserException", "ToolsHandler"]
__all__ = [
"AgentAction",
"AgentFinish",
"CacheHandler",
"OutputParserException",
"ToolsHandler",
"parse",
]

View File

@@ -1,7 +1,7 @@
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional
from typing import Any
from pydantic import PrivateAttr
from pydantic import ConfigDict, PrivateAttr
from crewai.agent import BaseAgent
from crewai.tools import BaseTool
@@ -16,22 +16,21 @@ class BaseAgentAdapter(BaseAgent, ABC):
"""
adapted_structured_output: bool = False
_agent_config: Optional[Dict[str, Any]] = PrivateAttr(default=None)
_agent_config: dict[str, Any] | None = PrivateAttr(default=None)
model_config = {"arbitrary_types_allowed": True}
model_config = ConfigDict(arbitrary_types_allowed=True)
def __init__(self, agent_config: Optional[Dict[str, Any]] = None, **kwargs: Any):
def __init__(self, agent_config: dict[str, Any] | None = None, **kwargs: Any):
super().__init__(adapted_agent=True, **kwargs)
self._agent_config = agent_config
@abstractmethod
def configure_tools(self, tools: Optional[List[BaseTool]] = None) -> None:
def configure_tools(self, tools: list[BaseTool] | None = None) -> None:
"""Configure and adapt tools for the specific agent implementation.
Args:
tools: Optional list of BaseTool instances to be configured
"""
pass
def configure_structured_output(self, structured_output: Any) -> None:
"""Configure the structured output for the specific agent implementation.
@@ -39,4 +38,3 @@ class BaseAgentAdapter(BaseAgent, ABC):
Args:
structured_output: The structured output to be configured
"""
pass

View File

@@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
from typing import Any, List, Optional
from typing import Any
from crewai.tools.base_tool import BaseTool
@@ -12,23 +12,22 @@ class BaseToolAdapter(ABC):
different frameworks and platforms.
"""
original_tools: List[BaseTool]
converted_tools: List[Any]
original_tools: list[BaseTool]
converted_tools: list[Any]
def __init__(self, tools: Optional[List[BaseTool]] = None):
def __init__(self, tools: list[BaseTool] | None = None):
self.original_tools = tools or []
self.converted_tools = []
@abstractmethod
def configure_tools(self, tools: List[BaseTool]) -> None:
def configure_tools(self, tools: list[BaseTool]) -> None:
"""Configure and convert tools for the specific implementation.
Args:
tools: List of BaseTool instances to be configured and converted
"""
pass
def tools(self) -> List[Any]:
def tools(self) -> list[Any]:
"""Return all converted tools."""
return self.converted_tools

View File

@@ -1,8 +1,9 @@
import uuid
from abc import ABC, abstractmethod
from collections.abc import Callable
from copy import copy as shallow_copy
from hashlib import md5
from typing import Any, Callable, Dict, List, Optional, TypeVar
from typing import Any, TypeVar
from pydantic import (
UUID4,
@@ -25,7 +26,6 @@ from crewai.security.security_config import SecurityConfig
from crewai.tools.base_tool import BaseTool, Tool
from crewai.utilities import I18N, Logger, RPMController
from crewai.utilities.config import process_config
from crewai.utilities.converter import Converter
from crewai.utilities.string_utils import interpolate_only
T = TypeVar("T", bound="BaseAgent")
@@ -81,17 +81,17 @@ class BaseAgent(ABC, BaseModel):
__hash__ = object.__hash__ # type: ignore
_logger: Logger = PrivateAttr(default_factory=lambda: Logger(verbose=False))
_rpm_controller: Optional[RPMController] = PrivateAttr(default=None)
_rpm_controller: RPMController | None = PrivateAttr(default=None)
_request_within_rpm_limit: Any = PrivateAttr(default=None)
_original_role: Optional[str] = PrivateAttr(default=None)
_original_goal: Optional[str] = PrivateAttr(default=None)
_original_backstory: Optional[str] = PrivateAttr(default=None)
_original_role: str | None = PrivateAttr(default=None)
_original_goal: str | None = PrivateAttr(default=None)
_original_backstory: str | None = PrivateAttr(default=None)
_token_process: TokenProcess = PrivateAttr(default_factory=TokenProcess)
id: UUID4 = Field(default_factory=uuid.uuid4, frozen=True)
role: str = Field(description="Role of the agent")
goal: str = Field(description="Objective of the agent")
backstory: str = Field(description="Backstory of the agent")
config: Optional[Dict[str, Any]] = Field(
config: dict[str, Any] | None = Field(
description="Configuration for the agent", default=None, exclude=True
)
cache: bool = Field(
@@ -100,7 +100,7 @@ class BaseAgent(ABC, BaseModel):
verbose: bool = Field(
default=False, description="Verbose mode for the Agent Execution"
)
max_rpm: Optional[int] = Field(
max_rpm: int | None = Field(
default=None,
description="Maximum number of requests per minute for the agent execution to be respected.",
)
@@ -108,7 +108,7 @@ class BaseAgent(ABC, BaseModel):
default=False,
description="Enable agent to delegate and ask questions among each other.",
)
tools: Optional[List[BaseTool]] = Field(
tools: list[BaseTool] | None = Field(
default_factory=list, description="Tools at agents' disposal"
)
max_iter: int = Field(
@@ -122,27 +122,27 @@ class BaseAgent(ABC, BaseModel):
)
crew: Any = Field(default=None, description="Crew to which the agent belongs.")
i18n: I18N = Field(default=I18N(), description="Internationalization settings.")
cache_handler: Optional[InstanceOf[CacheHandler]] = Field(
cache_handler: InstanceOf[CacheHandler] | None = Field(
default=None, description="An instance of the CacheHandler class."
)
tools_handler: InstanceOf[ToolsHandler] = Field(
default_factory=ToolsHandler,
description="An instance of the ToolsHandler class.",
)
tools_results: List[Dict[str, Any]] = Field(
tools_results: list[dict[str, Any]] = Field(
default=[], description="Results of the tools used by the agent."
)
max_tokens: Optional[int] = Field(
max_tokens: int | None = Field(
default=None, description="Maximum number of tokens for the agent's execution."
)
knowledge: Optional[Knowledge] = Field(
knowledge: Knowledge | None = Field(
default=None, description="Knowledge for the agent."
)
knowledge_sources: Optional[List[BaseKnowledgeSource]] = Field(
knowledge_sources: list[BaseKnowledgeSource] | None = Field(
default=None,
description="Knowledge sources for the agent.",
)
knowledge_storage: Optional[Any] = Field(
knowledge_storage: Any | None = Field(
default=None,
description="Custom knowledge storage for the agent.",
)
@@ -150,13 +150,13 @@ class BaseAgent(ABC, BaseModel):
default_factory=SecurityConfig,
description="Security configuration for the agent, including fingerprinting.",
)
callbacks: List[Callable] = Field(
callbacks: list[Callable] = Field(
default=[], description="Callbacks to be used for the agent"
)
adapted_agent: bool = Field(
default=False, description="Whether the agent is adapted"
)
knowledge_config: Optional[KnowledgeConfig] = Field(
knowledge_config: KnowledgeConfig | None = Field(
default=None,
description="Knowledge configuration for the agent such as limits and threshold",
)
@@ -168,7 +168,7 @@ class BaseAgent(ABC, BaseModel):
@field_validator("tools")
@classmethod
def validate_tools(cls, tools: List[Any]) -> List[BaseTool]:
def validate_tools(cls, tools: list[Any]) -> list[BaseTool]:
"""Validate and process the tools provided to the agent.
This method ensures that each tool is either an instance of BaseTool
@@ -221,7 +221,7 @@ class BaseAgent(ABC, BaseModel):
@field_validator("id", mode="before")
@classmethod
def _deny_user_set_id(cls, v: Optional[UUID4]) -> None:
def _deny_user_set_id(cls, v: UUID4 | None) -> None:
if v:
raise PydanticCustomError(
"may_not_set_field", "This field is not to be set by the user.", {}
@@ -252,8 +252,8 @@ class BaseAgent(ABC, BaseModel):
def execute_task(
self,
task: Any,
context: Optional[str] = None,
tools: Optional[List[BaseTool]] = None,
context: str | None = None,
tools: list[BaseTool] | None = None,
) -> str:
pass
@@ -262,9 +262,8 @@ class BaseAgent(ABC, BaseModel):
pass
@abstractmethod
def get_delegation_tools(self, agents: List["BaseAgent"]) -> List[BaseTool]:
def get_delegation_tools(self, agents: list["BaseAgent"]) -> list[BaseTool]:
"""Set the task tools that init BaseAgenTools class."""
pass
def copy(self: T) -> T: # type: ignore # Signature of "copy" incompatible with supertype "BaseModel"
"""Create a deep copy of the Agent."""
@@ -309,7 +308,7 @@ class BaseAgent(ABC, BaseModel):
copied_data = self.model_dump(exclude=exclude)
copied_data = {k: v for k, v in copied_data.items() if v is not None}
copied_agent = type(self)(
return type(self)(
**copied_data,
llm=existing_llm,
tools=self.tools,
@@ -318,9 +317,7 @@ class BaseAgent(ABC, BaseModel):
knowledge_storage=copied_knowledge_storage,
)
return copied_agent
def interpolate_inputs(self, inputs: Dict[str, Any]) -> None:
def interpolate_inputs(self, inputs: dict[str, Any]) -> None:
"""Interpolate inputs into the agent description and backstory."""
if self._original_role is None:
self._original_role = self.role
@@ -362,5 +359,5 @@ class BaseAgent(ABC, BaseModel):
self._rpm_controller = rpm_controller
self.create_agent_executor()
def set_knowledge(self, crew_embedder: Optional[Dict[str, Any]] = None):
def set_knowledge(self, crew_embedder: dict[str, Any] | None = None):
pass

View File

@@ -1,13 +1,13 @@
import time
from typing import TYPE_CHECKING, Dict, List
from typing import TYPE_CHECKING
from crewai.events.event_listener import event_listener
from crewai.memory.entity.entity_memory_item import EntityMemoryItem
from crewai.memory.long_term.long_term_memory_item import LongTermMemoryItem
from crewai.utilities import I18N
from crewai.utilities.converter import ConverterError
from crewai.utilities.evaluators.task_evaluator import TaskEvaluator
from crewai.utilities.printer import Printer
from crewai.events.event_listener import event_listener
if TYPE_CHECKING:
from crewai.agents.agent_builder.base_agent import BaseAgent
@@ -21,7 +21,7 @@ class CrewAgentExecutorMixin:
task: "Task"
iterations: int
max_iter: int
messages: List[Dict[str, str]]
messages: list[dict[str, str]]
_i18n: I18N
_printer: Printer = Printer()
@@ -46,7 +46,6 @@ class CrewAgentExecutorMixin:
)
except Exception as e:
print(f"Failed to add to short term memory: {e}")
pass
def _create_external_memory(self, output) -> None:
"""Create and save a external-term memory item if conditions are met."""
@@ -67,7 +66,6 @@ class CrewAgentExecutorMixin:
)
except Exception as e:
print(f"Failed to add to external memory: {e}")
pass
def _create_long_term_memory(self, output) -> None:
"""Create and save long-term and entity memory items based on evaluation."""
@@ -113,10 +111,8 @@ class CrewAgentExecutorMixin:
self.crew._entity_memory.save(entity_memories)
except AttributeError as e:
print(f"Missing attributes for long term memory: {e}")
pass
except Exception as e:
print(f"Failed to add to long term memory: {e}")
pass
elif (
self.crew
and self.crew._long_term_memory

View File

@@ -251,9 +251,8 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
i18n=self._i18n,
)
continue
else:
handle_unknown_error(self._printer, e)
raise e
handle_unknown_error(self._printer, e)
raise e
finally:
self.iterations += 1
@@ -324,9 +323,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
self.agent,
AgentLogsStartedEvent(
agent_role=self.agent.role,
task_description=(
getattr(self.task, "description") if self.task else "Not Found"
),
task_description=(self.task.description if self.task else "Not Found"),
verbose=self.agent.verbose
or (hasattr(self, "crew") and getattr(self.crew, "verbose", False)),
),
@@ -415,8 +412,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
"""
prompt = prompt.replace("{input}", inputs["input"])
prompt = prompt.replace("{tool_names}", inputs["tool_names"])
prompt = prompt.replace("{tools}", inputs["tools"])
return prompt
return prompt.replace("{tools}", inputs["tools"])
def _handle_human_feedback(self, formatted_answer: AgentFinish) -> AgentFinish:
"""Process human feedback.

View File

@@ -10,9 +10,9 @@ from dataclasses import dataclass
from json_repair import repair_json
from crewai.agents.constants import (
ACTION_INPUT_ONLY_REGEX,
ACTION_INPUT_REGEX,
ACTION_REGEX,
ACTION_INPUT_ONLY_REGEX,
FINAL_ANSWER_ACTION,
MISSING_ACTION_AFTER_THOUGHT_ERROR_MESSAGE,
MISSING_ACTION_INPUT_AFTER_ACTION_ERROR_MESSAGE,
@@ -104,7 +104,7 @@ def parse(text: str) -> AgentAction | AgentFinish:
final_answer = final_answer[:-3].rstrip()
return AgentFinish(thought=thought, output=final_answer, text=text)
elif action_match:
if action_match:
action = action_match.group(1)
clean_action = _clean_action(action)
@@ -121,16 +121,15 @@ def parse(text: str) -> AgentAction | AgentFinish:
raise OutputParserException(
f"{MISSING_ACTION_AFTER_THOUGHT_ERROR_MESSAGE}\n{_I18N.slice('final_answer_format')}",
)
elif not ACTION_INPUT_ONLY_REGEX.search(text):
if not ACTION_INPUT_ONLY_REGEX.search(text):
raise OutputParserException(
MISSING_ACTION_INPUT_AFTER_ACTION_ERROR_MESSAGE,
)
else:
err_format = _I18N.slice("format_without_tools")
error = f"{err_format}"
raise OutputParserException(
error,
)
err_format = _I18N.slice("format_without_tools")
error = f"{err_format}"
raise OutputParserException(
error,
)
def _extract_thought(text: str) -> str:
@@ -149,8 +148,7 @@ def _extract_thought(text: str) -> str:
return ""
thought = text[:thought_index].strip()
# Remove any triple backticks from the thought string
thought = thought.replace("```", "").strip()
return thought
return thought.replace("```", "").strip()
def _clean_action(text: str) -> str:

View File

@@ -1,8 +1,8 @@
"""Tools handler for managing tool execution and caching."""
from crewai.agents.cache.cache_handler import CacheHandler
from crewai.tools.cache_tools.cache_tools import CacheTools
from crewai.tools.tool_calling import InstructorToolCalling, ToolCalling
from crewai.agents.cache.cache_handler import CacheHandler
class ToolsHandler:

View File

@@ -1,5 +1,6 @@
from crewai.cli.authentication.providers.base_provider import BaseProvider
class Auth0Provider(BaseProvider):
def get_authorize_url(self) -> str:
return f"https://{self._get_domain()}/oauth/device/code"

View File

@@ -1,30 +1,26 @@
from abc import ABC, abstractmethod
from crewai.cli.authentication.main import Oauth2Settings
class BaseProvider(ABC):
def __init__(self, settings: Oauth2Settings):
self.settings = settings
@abstractmethod
def get_authorize_url(self) -> str:
...
def get_authorize_url(self) -> str: ...
@abstractmethod
def get_token_url(self) -> str:
...
def get_token_url(self) -> str: ...
@abstractmethod
def get_jwks_url(self) -> str:
...
def get_jwks_url(self) -> str: ...
@abstractmethod
def get_issuer(self) -> str:
...
def get_issuer(self) -> str: ...
@abstractmethod
def get_audience(self) -> str:
...
def get_audience(self) -> str: ...
@abstractmethod
def get_client_id(self) -> str:
...
def get_client_id(self) -> str: ...

View File

@@ -1,5 +1,6 @@
from crewai.cli.authentication.providers.base_provider import BaseProvider
class OktaProvider(BaseProvider):
def get_authorize_url(self) -> str:
return f"https://{self.settings.domain}/oauth2/default/v1/device/authorize"

View File

@@ -1,5 +1,6 @@
from crewai.cli.authentication.providers.base_provider import BaseProvider
class WorkosProvider(BaseProvider):
def get_authorize_url(self) -> str:
return f"https://{self._get_domain()}/oauth2/device_authorization"
@@ -17,9 +18,11 @@ class WorkosProvider(BaseProvider):
return self.settings.audience or ""
def get_client_id(self) -> str:
assert self.settings.client_id is not None, "Client ID is required"
if self.settings.client_id is None:
raise RuntimeError("Client ID is required")
return self.settings.client_id
def _get_domain(self) -> str:
assert self.settings.domain is not None, "Domain is required"
if self.settings.domain is None:
raise RuntimeError("Domain is required")
return self.settings.domain

View File

@@ -17,8 +17,6 @@ def validate_jwt_token(
missing required claims).
"""
decoded_token = None
try:
jwk_client = PyJWKClient(jwks_url)
signing_key = jwk_client.get_signing_key_from_jwt(jwt_token)
@@ -26,7 +24,7 @@ def validate_jwt_token(
_unverified_decoded_token = jwt.decode(
jwt_token, options={"verify_signature": False}
)
decoded_token = jwt.decode(
return jwt.decode(
jwt_token,
signing_key.key,
algorithms=["RS256"],
@@ -40,7 +38,6 @@ def validate_jwt_token(
"require": ["exp", "iat", "iss", "aud", "sub"],
},
)
return decoded_token
except jwt.ExpiredSignatureError:
raise Exception("Token has expired.")
@@ -55,8 +52,8 @@ def validate_jwt_token(
f"Invalid token issuer. Got: '{actual_issuer}'. Expected: '{issuer}'"
)
except jwt.MissingRequiredClaimError as e:
raise Exception(f"Token is missing required claims: {str(e)}")
raise Exception(f"Token is missing required claims: {e!s}")
except jwt.exceptions.PyJWKClientError as e:
raise Exception(f"JWKS or key processing error: {str(e)}")
raise Exception(f"JWKS or key processing error: {e!s}")
except jwt.InvalidTokenError as e:
raise Exception(f"Invalid token: {str(e)}")
raise Exception(f"Invalid token: {e!s}")

View File

@@ -1,13 +1,13 @@
from importlib.metadata import version as get_version
from typing import Optional
import click
from crewai.cli.config import Settings
from crewai.cli.settings.main import SettingsCommand
from crewai.cli.add_crew_to_flow import add_crew_to_flow
from crewai.cli.config import Settings
from crewai.cli.create_crew import create_crew
from crewai.cli.create_flow import create_flow
from crewai.cli.crew_chat import run_chat
from crewai.cli.settings.main import SettingsCommand
from crewai.memory.storage.kickoff_task_outputs_storage import (
KickoffTaskOutputsSQLiteStorage,
)
@@ -237,13 +237,11 @@ def login():
@crewai.group()
def deploy():
"""Deploy the Crew CLI group."""
pass
@crewai.group()
def tool():
"""Tool Repository related commands."""
pass
@deploy.command(name="create")
@@ -263,7 +261,7 @@ def deploy_list():
@deploy.command(name="push")
@click.option("-u", "--uuid", type=str, help="Crew UUID parameter")
def deploy_push(uuid: Optional[str]):
def deploy_push(uuid: str | None):
"""Deploy the Crew."""
deploy_cmd = DeployCommand()
deploy_cmd.deploy(uuid=uuid)
@@ -271,7 +269,7 @@ def deploy_push(uuid: Optional[str]):
@deploy.command(name="status")
@click.option("-u", "--uuid", type=str, help="Crew UUID parameter")
def deply_status(uuid: Optional[str]):
def deply_status(uuid: str | None):
"""Get the status of a deployment."""
deploy_cmd = DeployCommand()
deploy_cmd.get_crew_status(uuid=uuid)
@@ -279,7 +277,7 @@ def deply_status(uuid: Optional[str]):
@deploy.command(name="logs")
@click.option("-u", "--uuid", type=str, help="Crew UUID parameter")
def deploy_logs(uuid: Optional[str]):
def deploy_logs(uuid: str | None):
"""Get the logs of a deployment."""
deploy_cmd = DeployCommand()
deploy_cmd.get_crew_logs(uuid=uuid)
@@ -287,7 +285,7 @@ def deploy_logs(uuid: Optional[str]):
@deploy.command(name="remove")
@click.option("-u", "--uuid", type=str, help="Crew UUID parameter")
def deploy_remove(uuid: Optional[str]):
def deploy_remove(uuid: str | None):
"""Remove a deployment."""
deploy_cmd = DeployCommand()
deploy_cmd.remove_crew(uuid=uuid)
@@ -327,7 +325,6 @@ def tool_publish(is_public: bool, force: bool):
@crewai.group()
def flow():
"""Flow related commands."""
pass
@flow.command(name="kickoff")
@@ -359,7 +356,7 @@ def chat():
and using the Chat LLM to generate responses.
"""
click.secho(
"\nStarting a conversation with the Crew\n" "Type 'exit' or Ctrl+C to quit.\n",
"\nStarting a conversation with the Crew\nType 'exit' or Ctrl+C to quit.\n",
)
run_chat()
@@ -368,7 +365,6 @@ def chat():
@crewai.group(invoke_without_command=True)
def org():
"""Organization management commands."""
pass
@org.command("list")
@@ -396,7 +392,6 @@ def current():
@crewai.group()
def enterprise():
"""Enterprise Configuration commands."""
pass
@enterprise.command("configure")
@@ -410,7 +405,6 @@ def enterprise_configure(enterprise_url: str):
@crewai.group()
def config():
"""CLI Configuration commands."""
pass
@config.command("list")

View File

@@ -1,15 +1,14 @@
import json
from pathlib import Path
from typing import Optional
from pydantic import BaseModel, Field
from crewai.cli.constants import (
DEFAULT_CREWAI_ENTERPRISE_URL,
CREWAI_ENTERPRISE_DEFAULT_OAUTH2_PROVIDER,
CREWAI_ENTERPRISE_DEFAULT_OAUTH2_AUDIENCE,
CREWAI_ENTERPRISE_DEFAULT_OAUTH2_CLIENT_ID,
CREWAI_ENTERPRISE_DEFAULT_OAUTH2_DOMAIN,
CREWAI_ENTERPRISE_DEFAULT_OAUTH2_PROVIDER,
DEFAULT_CREWAI_ENTERPRISE_URL,
)
from crewai.cli.shared.token_manager import TokenManager
@@ -56,20 +55,20 @@ HIDDEN_SETTINGS_KEYS = [
class Settings(BaseModel):
enterprise_base_url: Optional[str] = Field(
enterprise_base_url: str | None = Field(
default=DEFAULT_CLI_SETTINGS["enterprise_base_url"],
description="Base URL of the CrewAI Enterprise instance",
)
tool_repository_username: Optional[str] = Field(
tool_repository_username: str | None = Field(
None, description="Username for interacting with the Tool Repository"
)
tool_repository_password: Optional[str] = Field(
tool_repository_password: str | None = Field(
None, description="Password for interacting with the Tool Repository"
)
org_name: Optional[str] = Field(
org_name: str | None = Field(
None, description="Name of the currently active organization"
)
org_uuid: Optional[str] = Field(
org_uuid: str | None = Field(
None, description="UUID of the currently active organization"
)
config_path: Path = Field(default=DEFAULT_CONFIG_PATH, frozen=True, exclude=True)
@@ -79,7 +78,7 @@ class Settings(BaseModel):
default=DEFAULT_CLI_SETTINGS["oauth2_provider"],
)
oauth2_audience: Optional[str] = Field(
oauth2_audience: str | None = Field(
description="OAuth2 audience value, typically used to identify the target API or resource.",
default=DEFAULT_CLI_SETTINGS["oauth2_audience"],
)

View File

@@ -16,48 +16,72 @@ from crewai.cli.utils import copy_template, load_env_vars, write_env_file
def create_folder_structure(name, parent_folder=None):
import keyword
import re
name = name.rstrip('/')
name = name.rstrip("/")
if not name.strip():
raise ValueError("Project name cannot be empty or contain only whitespace")
folder_name = name.replace(" ", "_").replace("-", "_").lower()
folder_name = re.sub(r'[^a-zA-Z0-9_]', '', folder_name)
folder_name = re.sub(r"[^a-zA-Z0-9_]", "", folder_name)
# Check if the name starts with invalid characters or is primarily invalid
if re.match(r'^[^a-zA-Z0-9_-]+', name):
raise ValueError(f"Project name '{name}' contains no valid characters for a Python module name")
if re.match(r"^[^a-zA-Z0-9_-]+", name):
raise ValueError(
f"Project name '{name}' contains no valid characters for a Python module name"
)
if not folder_name:
raise ValueError(f"Project name '{name}' contains no valid characters for a Python module name")
raise ValueError(
f"Project name '{name}' contains no valid characters for a Python module name"
)
if folder_name[0].isdigit():
raise ValueError(f"Project name '{name}' would generate folder name '{folder_name}' which cannot start with a digit (invalid Python module name)")
raise ValueError(
f"Project name '{name}' would generate folder name '{folder_name}' which cannot start with a digit (invalid Python module name)"
)
if keyword.iskeyword(folder_name):
raise ValueError(f"Project name '{name}' would generate folder name '{folder_name}' which is a reserved Python keyword")
raise ValueError(
f"Project name '{name}' would generate folder name '{folder_name}' which is a reserved Python keyword"
)
if not folder_name.isidentifier():
raise ValueError(f"Project name '{name}' would generate invalid Python module name '{folder_name}'")
raise ValueError(
f"Project name '{name}' would generate invalid Python module name '{folder_name}'"
)
class_name = name.replace("_", " ").replace("-", " ").title().replace(" ", "")
class_name = re.sub(r'[^a-zA-Z0-9_]', '', class_name)
class_name = re.sub(r"[^a-zA-Z0-9_]", "", class_name)
if not class_name:
raise ValueError(f"Project name '{name}' contains no valid characters for a Python class name")
raise ValueError(
f"Project name '{name}' contains no valid characters for a Python class name"
)
if class_name[0].isdigit():
raise ValueError(f"Project name '{name}' would generate class name '{class_name}' which cannot start with a digit")
raise ValueError(
f"Project name '{name}' would generate class name '{class_name}' which cannot start with a digit"
)
# Check if the original name (before title casing) is a keyword
original_name_clean = re.sub(r'[^a-zA-Z0-9_]', '', name.replace("_", "").replace("-", "").lower())
if keyword.iskeyword(original_name_clean) or keyword.iskeyword(class_name) or class_name in ('True', 'False', 'None'):
raise ValueError(f"Project name '{name}' would generate class name '{class_name}' which is a reserved Python keyword")
original_name_clean = re.sub(
r"[^a-zA-Z0-9_]", "", name.replace("_", "").replace("-", "").lower()
)
if (
keyword.iskeyword(original_name_clean)
or keyword.iskeyword(class_name)
or class_name in ("True", "False", "None")
):
raise ValueError(
f"Project name '{name}' would generate class name '{class_name}' which is a reserved Python keyword"
)
if not class_name.isidentifier():
raise ValueError(f"Project name '{name}' would generate invalid Python class name '{class_name}'")
raise ValueError(
f"Project name '{name}' would generate invalid Python class name '{class_name}'"
)
if parent_folder:
folder_path = Path(parent_folder) / folder_name
@@ -172,7 +196,7 @@ def create_crew(name, provider=None, skip_provider=False, parent_folder=None):
)
# Check if the selected provider has predefined models
if selected_provider in MODELS and MODELS[selected_provider]:
if MODELS.get(selected_provider):
while True:
selected_model = select_model(selected_provider, provider_models)
if selected_model is None: # User typed 'q'

View File

@@ -5,7 +5,7 @@ import sys
import threading
import time
from pathlib import Path
from typing import Any, Dict, List, Optional, Set, Tuple
from typing import Any
import click
import tomli
@@ -116,7 +116,7 @@ def show_loading(event: threading.Event):
print()
def initialize_chat_llm(crew: Crew) -> Optional[LLM | BaseLLM]:
def initialize_chat_llm(crew: Crew) -> LLM | BaseLLM | None:
"""Initializes the chat LLM and handles exceptions."""
try:
return create_llm(crew.chat_llm)
@@ -157,7 +157,7 @@ def build_system_message(crew_chat_inputs: ChatInputs) -> str:
)
def create_tool_function(crew: Crew, messages: List[Dict[str, str]]) -> Any:
def create_tool_function(crew: Crew, messages: list[dict[str, str]]) -> Any:
"""Creates a wrapper function for running the crew tool with messages."""
def run_crew_tool_with_messages(**kwargs):
@@ -221,9 +221,9 @@ def get_user_input() -> str:
def handle_user_input(
user_input: str,
chat_llm: LLM,
messages: List[Dict[str, str]],
crew_tool_schema: Dict[str, Any],
available_functions: Dict[str, Any],
messages: list[dict[str, str]],
crew_tool_schema: dict[str, Any],
available_functions: dict[str, Any],
) -> None:
if user_input.strip().lower() == "exit":
click.echo("Exiting chat. Goodbye!")
@@ -281,7 +281,7 @@ def generate_crew_tool_schema(crew_inputs: ChatInputs) -> dict:
}
def run_crew_tool(crew: Crew, messages: List[Dict[str, str]], **kwargs):
def run_crew_tool(crew: Crew, messages: list[dict[str, str]], **kwargs):
"""
Runs the crew using crew.kickoff(inputs=kwargs) and returns the output.
@@ -304,9 +304,8 @@ def run_crew_tool(crew: Crew, messages: List[Dict[str, str]], **kwargs):
crew_output = crew.kickoff(inputs=kwargs)
# Convert CrewOutput to a string to send back to the user
result = str(crew_output)
return str(crew_output)
return result
except Exception as e:
# Exit the chat and show the error message
click.secho("An error occurred while running the crew:", fg="red")
@@ -314,7 +313,7 @@ def run_crew_tool(crew: Crew, messages: List[Dict[str, str]], **kwargs):
sys.exit(1)
def load_crew_and_name() -> Tuple[Crew, str]:
def load_crew_and_name() -> tuple[Crew, str]:
"""
Loads the crew by importing the crew class from the user's project.
@@ -395,7 +394,7 @@ def generate_crew_chat_inputs(crew: Crew, crew_name: str, chat_llm) -> ChatInput
)
def fetch_required_inputs(crew: Crew) -> Set[str]:
def fetch_required_inputs(crew: Crew) -> set[str]:
"""
Extracts placeholders from the crew's tasks and agents.
@@ -406,7 +405,7 @@ def fetch_required_inputs(crew: Crew) -> Set[str]:
Set[str]: A set of placeholder names.
"""
placeholder_pattern = re.compile(r"\{(.+?)\}")
required_inputs: Set[str] = set()
required_inputs: set[str] = set()
# Scan tasks
for task in crew.tasks:
@@ -479,9 +478,7 @@ def generate_input_description_with_ai(input_name: str, crew: Crew, chat_llm) ->
f"{context}"
)
response = chat_llm.call(messages=[{"role": "user", "content": prompt}])
description = response.strip()
return description
return response.strip()
def generate_crew_description_with_ai(crew: Crew, chat_llm) -> str:
@@ -531,6 +528,4 @@ def generate_crew_description_with_ai(crew: Crew, chat_llm) -> str:
f"{context}"
)
response = chat_llm.call(messages=[{"role": "user", "content": prompt}])
crew_description = response.strip()
return crew_description
return response.strip()

View File

@@ -64,8 +64,7 @@ class Repository:
"""Return True if the Git repository is fully synced with the remote, False otherwise."""
if self.has_uncommitted_changes() or self.is_ahead_or_behind():
return False
else:
return True
return True
def origin_url(self) -> str | None:
"""Get the Git repository's remote URL."""

View File

@@ -12,7 +12,7 @@ def install_crew(proxy_options: list[str]) -> None:
Install the crew by running the UV command to lock and install.
"""
try:
command = ["uv", "sync"] + proxy_options
command = ["uv", "sync", *proxy_options]
subprocess.run(command, check=True, capture_output=False, text=True)
except subprocess.CalledProcessError as e:

View File

@@ -1,11 +1,10 @@
from typing import List, Optional
from urllib.parse import urljoin
import requests
from crewai.cli.config import Settings
from crewai.cli.version import get_crewai_version
from crewai.cli.constants import DEFAULT_CREWAI_ENTERPRISE_URL
from crewai.cli.version import get_crewai_version
class PlusAPI:
@@ -56,9 +55,9 @@ class PlusAPI:
handle: str,
is_public: bool,
version: str,
description: Optional[str],
description: str | None,
encoded_file: str,
available_exports: Optional[List[str]] = None,
available_exports: list[str] | None = None,
):
params = {
"handle": handle,

View File

@@ -1,10 +1,10 @@
import os
import certifi
import json
import os
import time
from collections import defaultdict
from pathlib import Path
import certifi
import click
import requests
@@ -25,7 +25,7 @@ def select_choice(prompt_message, choices):
provider_models = get_provider_data()
if not provider_models:
return
return None
click.secho(prompt_message, fg="cyan")
for idx, choice in enumerate(choices, start=1):
click.secho(f"{idx}. {choice}", fg="cyan")
@@ -67,7 +67,7 @@ def select_provider(provider_models):
all_providers = sorted(set(predefined_providers + list(provider_models.keys())))
provider = select_choice(
"Select a provider to set up:", predefined_providers + ["other"]
"Select a provider to set up:", [*predefined_providers, "other"]
)
if provider is None: # User typed 'q'
return None
@@ -102,10 +102,9 @@ def select_model(provider, provider_models):
click.secho(f"No models available for provider '{provider}'.", fg="red")
return None
selected_model = select_choice(
return select_choice(
f"Select a model to use for {provider.capitalize()}:", available_models
)
return selected_model
def load_provider_data(cache_file, cache_expiry):
@@ -165,7 +164,7 @@ def fetch_provider_data(cache_file):
Returns:
- dict or None: The fetched provider data or None if the operation fails.
"""
ssl_config = os.environ['SSL_CERT_FILE'] = certifi.where()
ssl_config = os.environ["SSL_CERT_FILE"] = certifi.where()
try:
response = requests.get(JSON_URL, stream=True, timeout=60, verify=ssl_config)

View File

@@ -1,6 +1,5 @@
import subprocess
from enum import Enum
from typing import List, Optional
import click
from packaging import version

View File

@@ -3,7 +3,7 @@ import os
import sys
from datetime import datetime
from pathlib import Path
from typing import Optional
from cryptography.fernet import Fernet
@@ -49,7 +49,7 @@ class TokenManager:
encrypted_data = self.fernet.encrypt(json.dumps(data).encode())
self.save_secure_file(self.file_path, encrypted_data)
def get_token(self) -> Optional[str]:
def get_token(self) -> str | None:
"""
Get the access token if it is valid and not expired.
@@ -113,7 +113,7 @@ class TokenManager:
# Set appropriate permissions (read/write for owner only)
os.chmod(file_path, 0o600)
def read_secure_file(self, filename: str) -> Optional[bytes]:
def read_secure_file(self, filename: str) -> bytes | None:
"""
Read the content of a secure file.

View File

@@ -5,7 +5,7 @@ import sys
from functools import reduce
from inspect import getmro, isclass, isfunction, ismethod
from pathlib import Path
from typing import Any, Dict, List, get_type_hints
from typing import Any, get_type_hints
import click
import tomli
@@ -41,8 +41,7 @@ def copy_template(src, dst, name, class_name, folder_name):
def read_toml(file_path: str = "pyproject.toml"):
"""Read the content of a TOML file and return it as a dictionary."""
with open(file_path, "rb") as f:
toml_dict = tomli.load(f)
return toml_dict
return tomli.load(f)
def parse_toml(content):
@@ -77,7 +76,7 @@ def get_project_description(
def _get_project_attribute(
pyproject_path: str, keys: List[str], require: bool
pyproject_path: str, keys: list[str], require: bool
) -> Any | None:
"""Get an attribute from the pyproject.toml file."""
attribute = None
@@ -96,7 +95,10 @@ def _get_project_attribute(
except FileNotFoundError:
console.print(f"Error: {pyproject_path} not found.", style="bold red")
except KeyError:
console.print(f"Error: {pyproject_path} is not a valid pyproject.toml file.", style="bold red")
console.print(
f"Error: {pyproject_path} is not a valid pyproject.toml file.",
style="bold red",
)
except tomllib.TOMLDecodeError if sys.version_info >= (3, 11) else Exception as e: # type: ignore
console.print(
f"Error: {pyproject_path} is not a valid TOML file."
@@ -117,7 +119,7 @@ def _get_project_attribute(
return attribute
def _get_nested_value(data: Dict[str, Any], keys: List[str]) -> Any:
def _get_nested_value(data: dict[str, Any], keys: list[str]) -> Any:
return reduce(dict.__getitem__, keys, data)
@@ -296,7 +298,10 @@ def get_crews(crew_path: str = "crew.py", require: bool = False) -> list[Crew]:
try:
crew_instances.extend(fetch_crews(module_attr))
except Exception as e:
console.print(f"Error processing attribute {attr_name}: {e}", style="bold red")
console.print(
f"Error processing attribute {attr_name}: {e}",
style="bold red",
)
continue
# If we found crew instances, break out of the loop
@@ -304,12 +309,15 @@ def get_crews(crew_path: str = "crew.py", require: bool = False) -> list[Crew]:
break
except Exception as exec_error:
console.print(f"Error executing module: {exec_error}", style="bold red")
console.print(
f"Error executing module: {exec_error}",
style="bold red",
)
except (ImportError, AttributeError) as e:
if require:
console.print(
f"Error importing crew from {crew_path}: {str(e)}",
f"Error importing crew from {crew_path}: {e!s}",
style="bold red",
)
continue
@@ -325,7 +333,7 @@ def get_crews(crew_path: str = "crew.py", require: bool = False) -> list[Crew]:
except Exception as e:
if require:
console.print(
f"Unexpected error while loading crew: {str(e)}", style="bold red"
f"Unexpected error while loading crew: {e!s}", style="bold red"
)
raise SystemExit
return crew_instances
@@ -348,8 +356,7 @@ def get_crew_instance(module_attr) -> Crew | None:
if isinstance(module_attr, Crew):
return module_attr
else:
return None
return None
def fetch_crews(module_attr) -> list[Crew]:
@@ -402,7 +409,7 @@ def extract_available_exports(dir_path: str = "src"):
return available_exports
except Exception as e:
console.print(f"[red]Error: Could not extract tool classes: {str(e)}[/red]")
console.print(f"[red]Error: Could not extract tool classes: {e!s}[/red]")
console.print(
"Please ensure your project contains valid tools (classes inheriting from BaseTool or functions with @tool decorator)."
)
@@ -440,7 +447,7 @@ def _load_tools_from_init(init_file: Path) -> list[dict[str, Any]]:
]
except Exception as e:
console.print(f"[red]Warning: Could not load {init_file}: {str(e)}[/red]")
console.print(f"[red]Warning: Could not load {init_file}: {e!s}[/red]")
raise SystemExit(1)
finally:

View File

@@ -1,21 +1,23 @@
import os
import contextvars
from typing import Optional
import os
from contextlib import contextmanager
_platform_integration_token: contextvars.ContextVar[Optional[str]] = contextvars.ContextVar(
"platform_integration_token", default=None
_platform_integration_token: contextvars.ContextVar[str | None] = (
contextvars.ContextVar("platform_integration_token", default=None)
)
def set_platform_integration_token(integration_token: str) -> None:
_platform_integration_token.set(integration_token)
def get_platform_integration_token() -> Optional[str]:
def get_platform_integration_token() -> str | None:
token = _platform_integration_token.get()
if token is None:
token = os.getenv("CREWAI_PLATFORM_INTEGRATION_TOKEN")
return token
@contextmanager
def platform_context(integration_token: str):
token = _platform_integration_token.set(integration_token)

View File

@@ -1,5 +1,5 @@
import json
from typing import Any, Dict, Optional
from typing import Any
from pydantic import BaseModel, Field
@@ -12,10 +12,10 @@ class CrewOutput(BaseModel):
"""Class that represents the result of a crew."""
raw: str = Field(description="Raw output of crew", default="")
pydantic: Optional[BaseModel] = Field(
pydantic: BaseModel | None = Field(
description="Pydantic output of Crew", default=None
)
json_dict: Optional[Dict[str, Any]] = Field(
json_dict: dict[str, Any] | None = Field(
description="JSON dict output of Crew", default=None
)
tasks_output: list[TaskOutput] = Field(
@@ -24,7 +24,7 @@ class CrewOutput(BaseModel):
token_usage: UsageMetrics = Field(description="Processed token summary", default={})
@property
def json(self) -> Optional[str]:
def json(self) -> str | None:
if self.tasks_output[-1].output_format != OutputFormat.JSON:
raise ValueError(
"No JSON output found in the final task. Please make sure to set the output_json property in the final task in your crew."
@@ -32,7 +32,7 @@ class CrewOutput(BaseModel):
return json.dumps(self.json_dict)
def to_dict(self) -> Dict[str, Any]:
def to_dict(self) -> dict[str, Any]:
"""Convert json_output and pydantic_output to a dictionary."""
output_dict = {}
if self.json_dict:
@@ -44,10 +44,9 @@ class CrewOutput(BaseModel):
def __getitem__(self, key):
if self.pydantic and hasattr(self.pydantic, key):
return getattr(self.pydantic, key)
elif self.json_dict and key in self.json_dict:
if self.json_dict and key in self.json_dict:
return self.json_dict[key]
else:
raise KeyError(f"Key '{key}' not found in CrewOutput.")
raise KeyError(f"Key '{key}' not found in CrewOutput.")
def __str__(self):
if self.pydantic:

View File

@@ -1,5 +1,6 @@
from datetime import datetime, timezone
from typing import Any, Dict, Optional
from typing import Any
from pydantic import BaseModel, Field
from crewai.utilities.serialization import to_serializable
@@ -10,11 +11,11 @@ class BaseEvent(BaseModel):
timestamp: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
type: str
source_fingerprint: Optional[str] = None # UUID string of the source entity
source_type: Optional[str] = (
source_fingerprint: str | None = None # UUID string of the source entity
source_type: str | None = (
None # "agent", "task", "crew", "memory", "entity_memory", "short_term_memory", "long_term_memory", "external_memory"
)
fingerprint_metadata: Optional[Dict[str, Any]] = None # Any relevant metadata
fingerprint_metadata: dict[str, Any] | None = None # Any relevant metadata
def to_json(self, exclude: set[str] | None = None):
"""
@@ -28,13 +29,13 @@ class BaseEvent(BaseModel):
"""
return to_serializable(self, exclude=exclude)
def _set_task_params(self, data: Dict[str, Any]):
def _set_task_params(self, data: dict[str, Any]):
if "from_task" in data and (task := data["from_task"]):
self.task_id = task.id
self.task_name = task.name or task.description
self.from_task = None
def _set_agent_params(self, data: Dict[str, Any]):
def _set_agent_params(self, data: dict[str, Any]):
task = data.get("from_task", None)
agent = task.agent if task else data.get("from_agent", None)

View File

@@ -1,8 +1,9 @@
from __future__ import annotations
import threading
from collections.abc import Callable
from contextlib import contextmanager
from typing import Any, Callable, Dict, List, Type, TypeVar, cast
from typing import Any, TypeVar, cast
from blinker import Signal
@@ -25,17 +26,17 @@ class CrewAIEventsBus:
if cls._instance is None:
with cls._lock:
if cls._instance is None: # prevent race condition
cls._instance = super(CrewAIEventsBus, cls).__new__(cls)
cls._instance = super().__new__(cls)
cls._instance._initialize()
return cls._instance
def _initialize(self) -> None:
"""Initialize the event bus internal state"""
self._signal = Signal("crewai_event_bus")
self._handlers: Dict[Type[BaseEvent], List[Callable]] = {}
self._handlers: dict[type[BaseEvent], list[Callable]] = {}
def on(
self, event_type: Type[EventT]
self, event_type: type[EventT]
) -> Callable[[Callable[[Any, EventT], None]], Callable[[Any, EventT], None]]:
"""
Decorator to register an event handler for a specific event type.
@@ -82,7 +83,7 @@ class CrewAIEventsBus:
self._signal.send(source, event=event)
def register_handler(
self, event_type: Type[EventTypes], handler: Callable[[Any, EventTypes], None]
self, event_type: type[EventTypes], handler: Callable[[Any, EventTypes], None]
) -> None:
"""Register an event handler for a specific event type"""
if event_type not in self._handlers:

View File

@@ -1,15 +1,30 @@
from __future__ import annotations
from io import StringIO
from typing import Any, Dict
from typing import Any
from pydantic import Field, PrivateAttr
from crewai.llm import LLM
from crewai.task import Task
from crewai.telemetry.telemetry import Telemetry
from crewai.utilities import Logger
from crewai.utilities.constants import EMITTER_COLOR
from crewai.events.base_event_listener import BaseEventListener
from crewai.events.types.agent_events import (
AgentExecutionCompletedEvent,
AgentExecutionStartedEvent,
LiteAgentExecutionCompletedEvent,
LiteAgentExecutionErrorEvent,
LiteAgentExecutionStartedEvent,
)
from crewai.events.types.crew_events import (
CrewKickoffCompletedEvent,
CrewKickoffFailedEvent,
CrewKickoffStartedEvent,
CrewTestCompletedEvent,
CrewTestFailedEvent,
CrewTestResultEvent,
CrewTestStartedEvent,
CrewTrainCompletedEvent,
CrewTrainFailedEvent,
CrewTrainStartedEvent,
)
from crewai.events.types.knowledge_events import (
KnowledgeQueryCompletedEvent,
KnowledgeQueryFailedEvent,
@@ -25,34 +40,21 @@ from crewai.events.types.llm_events import (
LLMStreamChunkEvent,
)
from crewai.events.types.llm_guardrail_events import (
LLMGuardrailStartedEvent,
LLMGuardrailCompletedEvent,
)
from crewai.events.utils.console_formatter import ConsoleFormatter
from crewai.events.types.agent_events import (
AgentExecutionCompletedEvent,
AgentExecutionStartedEvent,
LiteAgentExecutionCompletedEvent,
LiteAgentExecutionErrorEvent,
LiteAgentExecutionStartedEvent,
LLMGuardrailStartedEvent,
)
from crewai.events.types.logging_events import (
AgentLogsStartedEvent,
AgentLogsExecutionEvent,
AgentLogsStartedEvent,
)
from crewai.events.types.crew_events import (
CrewKickoffCompletedEvent,
CrewKickoffFailedEvent,
CrewKickoffStartedEvent,
CrewTestCompletedEvent,
CrewTestFailedEvent,
CrewTestResultEvent,
CrewTestStartedEvent,
CrewTrainCompletedEvent,
CrewTrainFailedEvent,
CrewTrainStartedEvent,
)
from crewai.events.utils.console_formatter import ConsoleFormatter
from crewai.llm import LLM
from crewai.task import Task
from crewai.telemetry.telemetry import Telemetry
from crewai.utilities import Logger
from crewai.utilities.constants import EMITTER_COLOR
from .listeners.memory_listener import MemoryListener
from .types.flow_events import (
FlowCreatedEvent,
FlowFinishedEvent,
@@ -61,26 +63,24 @@ from .types.flow_events import (
MethodExecutionFinishedEvent,
MethodExecutionStartedEvent,
)
from .types.reasoning_events import (
AgentReasoningCompletedEvent,
AgentReasoningFailedEvent,
AgentReasoningStartedEvent,
)
from .types.task_events import TaskCompletedEvent, TaskFailedEvent, TaskStartedEvent
from .types.tool_usage_events import (
ToolUsageErrorEvent,
ToolUsageFinishedEvent,
ToolUsageStartedEvent,
)
from .types.reasoning_events import (
AgentReasoningStartedEvent,
AgentReasoningCompletedEvent,
AgentReasoningFailedEvent,
)
from .listeners.memory_listener import MemoryListener
class EventListener(BaseEventListener):
_instance = None
_telemetry: Telemetry = PrivateAttr(default_factory=lambda: Telemetry())
logger = Logger(verbose=True, default_color=EMITTER_COLOR)
execution_spans: Dict[Task, Any] = Field(default_factory=dict)
execution_spans: dict[Task, Any] = Field(default_factory=dict)
next_chunk = 0
text_stream = StringIO()
knowledge_retrieval_in_progress = False

View File

@@ -6,6 +6,7 @@ from crewai.events.types.agent_events import (
AgentExecutionStartedEvent,
LiteAgentExecutionCompletedEvent,
)
from .types.crew_events import (
CrewKickoffCompletedEvent,
CrewKickoffFailedEvent,
@@ -24,6 +25,14 @@ from .types.flow_events import (
MethodExecutionFinishedEvent,
MethodExecutionStartedEvent,
)
from .types.knowledge_events import (
KnowledgeQueryCompletedEvent,
KnowledgeQueryFailedEvent,
KnowledgeQueryStartedEvent,
KnowledgeRetrievalCompletedEvent,
KnowledgeRetrievalStartedEvent,
KnowledgeSearchQueryFailedEvent,
)
from .types.llm_events import (
LLMCallCompletedEvent,
LLMCallFailedEvent,
@@ -34,6 +43,21 @@ from .types.llm_guardrail_events import (
LLMGuardrailCompletedEvent,
LLMGuardrailStartedEvent,
)
from .types.memory_events import (
MemoryQueryCompletedEvent,
MemoryQueryFailedEvent,
MemoryQueryStartedEvent,
MemoryRetrievalCompletedEvent,
MemoryRetrievalStartedEvent,
MemorySaveCompletedEvent,
MemorySaveFailedEvent,
MemorySaveStartedEvent,
)
from .types.reasoning_events import (
AgentReasoningCompletedEvent,
AgentReasoningFailedEvent,
AgentReasoningStartedEvent,
)
from .types.task_events import (
TaskCompletedEvent,
TaskFailedEvent,
@@ -44,30 +68,6 @@ from .types.tool_usage_events import (
ToolUsageFinishedEvent,
ToolUsageStartedEvent,
)
from .types.reasoning_events import (
AgentReasoningStartedEvent,
AgentReasoningCompletedEvent,
AgentReasoningFailedEvent,
)
from .types.knowledge_events import (
KnowledgeRetrievalStartedEvent,
KnowledgeRetrievalCompletedEvent,
KnowledgeQueryStartedEvent,
KnowledgeQueryCompletedEvent,
KnowledgeQueryFailedEvent,
KnowledgeSearchQueryFailedEvent,
)
from .types.memory_events import (
MemorySaveStartedEvent,
MemorySaveCompletedEvent,
MemorySaveFailedEvent,
MemoryQueryStartedEvent,
MemoryQueryCompletedEvent,
MemoryQueryFailedEvent,
MemoryRetrievalStartedEvent,
MemoryRetrievalCompletedEvent,
)
EventTypes = Union[
CrewKickoffStartedEvent,

View File

@@ -2,4 +2,4 @@
This module contains various event listener implementations
for handling memory, tracing, and other event-driven functionality.
"""
"""

View File

@@ -1,12 +1,12 @@
from crewai.events.base_event_listener import BaseEventListener
from crewai.events.types.memory_events import (
MemoryQueryCompletedEvent,
MemoryQueryFailedEvent,
MemoryRetrievalCompletedEvent,
MemoryRetrievalStartedEvent,
MemoryQueryFailedEvent,
MemoryQueryCompletedEvent,
MemorySaveStartedEvent,
MemorySaveCompletedEvent,
MemorySaveFailedEvent,
MemorySaveStartedEvent,
)

View File

@@ -1,7 +1,7 @@
from dataclasses import dataclass, field, asdict
from datetime import datetime, timezone
from typing import Dict, Any
import uuid
from dataclasses import asdict, dataclass, field
from datetime import datetime, timezone
from typing import Any
@dataclass
@@ -13,7 +13,7 @@ class TraceEvent:
default_factory=lambda: datetime.now(timezone.utc).isoformat()
)
type: str = ""
event_data: Dict[str, Any] = field(default_factory=dict)
event_data: dict[str, Any] = field(default_factory=dict)
def to_dict(self) -> Dict[str, Any]:
def to_dict(self) -> dict[str, Any]:
return asdict(self)

View File

@@ -2,4 +2,4 @@
This module contains all event types used throughout the CrewAI system
for monitoring and extending agent, crew, task, and tool execution.
"""
"""

View File

@@ -2,14 +2,15 @@
from __future__ import annotations
from typing import Any, Dict, List, Optional, Sequence, Union
from collections.abc import Sequence
from typing import Any
from pydantic import model_validator
from pydantic import ConfigDict, model_validator
from crewai.agents.agent_builder.base_agent import BaseAgent
from crewai.events.base_events import BaseEvent
from crewai.tools.base_tool import BaseTool
from crewai.tools.structured_tool import CrewStructuredTool
from crewai.events.base_events import BaseEvent
class AgentExecutionStartedEvent(BaseEvent):
@@ -17,11 +18,11 @@ class AgentExecutionStartedEvent(BaseEvent):
agent: BaseAgent
task: Any
tools: Optional[Sequence[Union[BaseTool, CrewStructuredTool]]]
tools: Sequence[BaseTool | CrewStructuredTool] | None
task_prompt: str
type: str = "agent_execution_started"
model_config = {"arbitrary_types_allowed": True}
model_config = ConfigDict(arbitrary_types_allowed=True)
@model_validator(mode="after")
def set_fingerprint_data(self):
@@ -45,7 +46,7 @@ class AgentExecutionCompletedEvent(BaseEvent):
output: str
type: str = "agent_execution_completed"
model_config = {"arbitrary_types_allowed": True}
model_config = ConfigDict(arbitrary_types_allowed=True)
@model_validator(mode="after")
def set_fingerprint_data(self):
@@ -69,7 +70,7 @@ class AgentExecutionErrorEvent(BaseEvent):
error: str
type: str = "agent_execution_error"
model_config = {"arbitrary_types_allowed": True}
model_config = ConfigDict(arbitrary_types_allowed=True)
@model_validator(mode="after")
def set_fingerprint_data(self):
@@ -89,18 +90,18 @@ class AgentExecutionErrorEvent(BaseEvent):
class LiteAgentExecutionStartedEvent(BaseEvent):
"""Event emitted when a LiteAgent starts executing"""
agent_info: Dict[str, Any]
tools: Optional[Sequence[Union[BaseTool, CrewStructuredTool]]]
messages: Union[str, List[Dict[str, str]]]
agent_info: dict[str, Any]
tools: Sequence[BaseTool | CrewStructuredTool] | None
messages: str | list[dict[str, str]]
type: str = "lite_agent_execution_started"
model_config = {"arbitrary_types_allowed": True}
model_config = ConfigDict(arbitrary_types_allowed=True)
class LiteAgentExecutionCompletedEvent(BaseEvent):
"""Event emitted when a LiteAgent completes execution"""
agent_info: Dict[str, Any]
agent_info: dict[str, Any]
output: str
type: str = "lite_agent_execution_completed"
@@ -108,7 +109,7 @@ class LiteAgentExecutionCompletedEvent(BaseEvent):
class LiteAgentExecutionErrorEvent(BaseEvent):
"""Event emitted when a LiteAgent encounters an error during execution"""
agent_info: Dict[str, Any]
agent_info: dict[str, Any]
error: str
type: str = "lite_agent_execution_error"

View File

@@ -1,4 +1,4 @@
from typing import TYPE_CHECKING, Any, Dict, Optional, Union
from typing import TYPE_CHECKING, Any
from crewai.events.base_events import BaseEvent
@@ -11,8 +11,8 @@ else:
class CrewBaseEvent(BaseEvent):
"""Base class for crew events with fingerprint handling"""
crew_name: Optional[str]
crew: Optional[Crew] = None
crew_name: str | None
crew: Crew | None = None
def __init__(self, **data):
super().__init__(**data)
@@ -38,7 +38,7 @@ class CrewBaseEvent(BaseEvent):
class CrewKickoffStartedEvent(CrewBaseEvent):
"""Event emitted when a crew starts execution"""
inputs: Optional[Dict[str, Any]]
inputs: dict[str, Any] | None
type: str = "crew_kickoff_started"
@@ -62,7 +62,7 @@ class CrewTrainStartedEvent(CrewBaseEvent):
n_iterations: int
filename: str
inputs: Optional[Dict[str, Any]]
inputs: dict[str, Any] | None
type: str = "crew_train_started"
@@ -85,8 +85,8 @@ class CrewTestStartedEvent(CrewBaseEvent):
"""Event emitted when a crew starts testing"""
n_iterations: int
eval_llm: Optional[Union[str, Any]]
inputs: Optional[Dict[str, Any]]
eval_llm: str | Any | None
inputs: dict[str, Any] | None
type: str = "crew_test_started"

View File

@@ -1,4 +1,4 @@
from typing import Any, Dict, Optional, Union
from typing import Any
from pydantic import BaseModel, ConfigDict
@@ -16,7 +16,7 @@ class FlowStartedEvent(FlowEvent):
"""Event emitted when a flow starts execution"""
flow_name: str
inputs: Optional[Dict[str, Any]] = None
inputs: dict[str, Any] | None = None
type: str = "flow_started"
@@ -32,8 +32,8 @@ class MethodExecutionStartedEvent(FlowEvent):
flow_name: str
method_name: str
state: Union[Dict[str, Any], BaseModel]
params: Optional[Dict[str, Any]] = None
state: dict[str, Any] | BaseModel
params: dict[str, Any] | None = None
type: str = "method_execution_started"
@@ -43,7 +43,7 @@ class MethodExecutionFinishedEvent(FlowEvent):
flow_name: str
method_name: str
result: Any = None
state: Union[Dict[str, Any], BaseModel]
state: dict[str, Any] | BaseModel
type: str = "method_execution_finished"
@@ -62,7 +62,7 @@ class FlowFinishedEvent(FlowEvent):
"""Event emitted when a flow completes execution"""
flow_name: str
result: Optional[Any] = None
result: Any | None = None
type: str = "flow_finished"

View File

@@ -1,6 +1,5 @@
from crewai.events.base_events import BaseEvent
from crewai.agents.agent_builder.base_agent import BaseAgent
from crewai.events.base_events import BaseEvent
class KnowledgeRetrievalStartedEvent(BaseEvent):

View File

@@ -1,5 +1,5 @@
from enum import Enum
from typing import Any, Dict, List, Optional, Union
from typing import Any
from pydantic import BaseModel
@@ -7,14 +7,14 @@ from crewai.events.base_events import BaseEvent
class LLMEventBase(BaseEvent):
task_name: Optional[str] = None
task_id: Optional[str] = None
task_name: str | None = None
task_id: str | None = None
agent_id: Optional[str] = None
agent_role: Optional[str] = None
agent_id: str | None = None
agent_role: str | None = None
from_task: Optional[Any] = None
from_agent: Optional[Any] = None
from_task: Any | None = None
from_agent: Any | None = None
def __init__(self, **data):
super().__init__(**data)
@@ -38,11 +38,11 @@ class LLMCallStartedEvent(LLMEventBase):
"""
type: str = "llm_call_started"
model: Optional[str] = None
messages: Optional[Union[str, List[Dict[str, Any]]]] = None
tools: Optional[List[dict[str, Any]]] = None
callbacks: Optional[List[Any]] = None
available_functions: Optional[Dict[str, Any]] = None
model: str | None = None
messages: str | list[dict[str, Any]] | None = None
tools: list[dict[str, Any]] | None = None
callbacks: list[Any] | None = None
available_functions: dict[str, Any] | None = None
class LLMCallCompletedEvent(LLMEventBase):
@@ -52,7 +52,7 @@ class LLMCallCompletedEvent(LLMEventBase):
messages: str | list[dict[str, Any]] | None = None
response: Any
call_type: LLMCallType
model: Optional[str] = None
model: str | None = None
class LLMCallFailedEvent(LLMEventBase):
@@ -64,13 +64,13 @@ class LLMCallFailedEvent(LLMEventBase):
class FunctionCall(BaseModel):
arguments: str
name: Optional[str] = None
name: str | None = None
class ToolCall(BaseModel):
id: Optional[str] = None
id: str | None = None
function: FunctionCall
type: Optional[str] = None
type: str | None = None
index: int
@@ -79,4 +79,4 @@ class LLMStreamChunkEvent(LLMEventBase):
type: str = "llm_stream_chunk"
chunk: str
tool_call: Optional[ToolCall] = None
tool_call: ToolCall | None = None

View File

@@ -1,5 +1,6 @@
from collections.abc import Callable
from inspect import getsource
from typing import Any, Callable, Optional, Union
from typing import Any
from crewai.events.base_events import BaseEvent
@@ -13,12 +14,12 @@ class LLMGuardrailStartedEvent(BaseEvent):
"""
type: str = "llm_guardrail_started"
guardrail: Union[str, Callable]
guardrail: str | Callable
retry_count: int
def __init__(self, **data):
from crewai.tasks.llm_guardrail import LLMGuardrail
from crewai.tasks.hallucination_guardrail import HallucinationGuardrail
from crewai.tasks.llm_guardrail import LLMGuardrail
super().__init__(**data)
@@ -41,5 +42,5 @@ class LLMGuardrailCompletedEvent(BaseEvent):
type: str = "llm_guardrail_completed"
success: bool
result: Any
error: Optional[str] = None
error: str | None = None
retry_count: int

View File

@@ -1,6 +1,8 @@
"""Agent logging events that don't reference BaseAgent to avoid circular imports."""
from typing import Any, Optional
from typing import Any
from pydantic import ConfigDict
from crewai.events.base_events import BaseEvent
@@ -9,7 +11,7 @@ class AgentLogsStartedEvent(BaseEvent):
"""Event emitted when agent logs should be shown at start"""
agent_role: str
task_description: Optional[str] = None
task_description: str | None = None
verbose: bool = False
type: str = "agent_logs_started"
@@ -22,4 +24,4 @@ class AgentLogsExecutionEvent(BaseEvent):
verbose: bool = False
type: str = "agent_logs_execution"
model_config = {"arbitrary_types_allowed": True}
model_config = ConfigDict(arbitrary_types_allowed=True)

View File

@@ -1,4 +1,4 @@
from typing import Any, Dict, Optional
from typing import Any
from crewai.events.base_events import BaseEvent
@@ -7,12 +7,12 @@ class MemoryBaseEvent(BaseEvent):
"""Base event for memory operations"""
type: str
task_id: Optional[str] = None
task_name: Optional[str] = None
from_task: Optional[Any] = None
from_agent: Optional[Any] = None
agent_role: Optional[str] = None
agent_id: Optional[str] = None
task_id: str | None = None
task_name: str | None = None
from_task: Any | None = None
from_agent: Any | None = None
agent_role: str | None = None
agent_id: str | None = None
def __init__(self, **data):
super().__init__(**data)
@@ -26,7 +26,7 @@ class MemoryQueryStartedEvent(MemoryBaseEvent):
type: str = "memory_query_started"
query: str
limit: int
score_threshold: Optional[float] = None
score_threshold: float | None = None
class MemoryQueryCompletedEvent(MemoryBaseEvent):
@@ -36,7 +36,7 @@ class MemoryQueryCompletedEvent(MemoryBaseEvent):
query: str
results: Any
limit: int
score_threshold: Optional[float] = None
score_threshold: float | None = None
query_time_ms: float
@@ -46,7 +46,7 @@ class MemoryQueryFailedEvent(MemoryBaseEvent):
type: str = "memory_query_failed"
query: str
limit: int
score_threshold: Optional[float] = None
score_threshold: float | None = None
error: str
@@ -54,9 +54,9 @@ class MemorySaveStartedEvent(MemoryBaseEvent):
"""Event emitted when a memory save operation is started"""
type: str = "memory_save_started"
value: Optional[str] = None
metadata: Optional[Dict[str, Any]] = None
agent_role: Optional[str] = None
value: str | None = None
metadata: dict[str, Any] | None = None
agent_role: str | None = None
class MemorySaveCompletedEvent(MemoryBaseEvent):
@@ -64,8 +64,8 @@ class MemorySaveCompletedEvent(MemoryBaseEvent):
type: str = "memory_save_completed"
value: str
metadata: Optional[Dict[str, Any]] = None
agent_role: Optional[str] = None
metadata: dict[str, Any] | None = None
agent_role: str | None = None
save_time_ms: float
@@ -73,9 +73,9 @@ class MemorySaveFailedEvent(MemoryBaseEvent):
"""Event emitted when a memory save operation fails"""
type: str = "memory_save_failed"
value: Optional[str] = None
metadata: Optional[Dict[str, Any]] = None
agent_role: Optional[str] = None
value: str | None = None
metadata: dict[str, Any] | None = None
agent_role: str | None = None
error: str
@@ -83,13 +83,13 @@ class MemoryRetrievalStartedEvent(MemoryBaseEvent):
"""Event emitted when memory retrieval for a task prompt starts"""
type: str = "memory_retrieval_started"
task_id: Optional[str] = None
task_id: str | None = None
class MemoryRetrievalCompletedEvent(MemoryBaseEvent):
"""Event emitted when memory retrieval for a task prompt completes successfully"""
type: str = "memory_retrieval_completed"
task_id: Optional[str] = None
task_id: str | None = None
memory_content: str
retrieval_time_ms: float

View File

@@ -1,5 +1,6 @@
from typing import Any
from crewai.events.base_events import BaseEvent
from typing import Any, Optional
class ReasoningEvent(BaseEvent):
@@ -9,10 +10,10 @@ class ReasoningEvent(BaseEvent):
attempt: int = 1
agent_role: str
task_id: str
task_name: Optional[str] = None
from_task: Optional[Any] = None
agent_id: Optional[str] = None
from_agent: Optional[Any] = None
task_name: str | None = None
from_task: Any | None = None
agent_id: str | None = None
from_agent: Any | None = None
def __init__(self, **data):
super().__init__(**data)

View File

@@ -1,15 +1,15 @@
from typing import Any, Optional
from typing import Any
from crewai.tasks.task_output import TaskOutput
from crewai.events.base_events import BaseEvent
from crewai.tasks.task_output import TaskOutput
class TaskStartedEvent(BaseEvent):
"""Event emitted when a task starts"""
type: str = "task_started"
context: Optional[str]
task: Optional[Any] = None
context: str | None
task: Any | None = None
def __init__(self, **data):
super().__init__(**data)
@@ -29,7 +29,7 @@ class TaskCompletedEvent(BaseEvent):
output: TaskOutput
type: str = "task_completed"
task: Optional[Any] = None
task: Any | None = None
def __init__(self, **data):
super().__init__(**data)
@@ -49,7 +49,7 @@ class TaskFailedEvent(BaseEvent):
error: str
type: str = "task_failed"
task: Optional[Any] = None
task: Any | None = None
def __init__(self, **data):
super().__init__(**data)
@@ -69,7 +69,7 @@ class TaskEvaluationEvent(BaseEvent):
type: str = "task_evaluation"
evaluation_type: str
task: Optional[Any] = None
task: Any | None = None
def __init__(self, **data):
super().__init__(**data)

View File

@@ -1,5 +1,8 @@
from collections.abc import Callable
from datetime import datetime
from typing import Any, Callable, Dict, Optional
from typing import Any
from pydantic import ConfigDict
from crewai.events.base_events import BaseEvent
@@ -7,21 +10,21 @@ from crewai.events.base_events import BaseEvent
class ToolUsageEvent(BaseEvent):
"""Base event for tool usage tracking"""
agent_key: Optional[str] = None
agent_role: Optional[str] = None
agent_id: Optional[str] = None
agent_key: str | None = None
agent_role: str | None = None
agent_id: str | None = None
tool_name: str
tool_args: Dict[str, Any] | str
tool_class: Optional[str] = None
tool_args: dict[str, Any] | str
tool_class: str | None = None
run_attempts: int | None = None
delegations: int | None = None
agent: Optional[Any] = None
task_name: Optional[str] = None
task_id: Optional[str] = None
from_task: Optional[Any] = None
from_agent: Optional[Any] = None
agent: Any | None = None
task_name: str | None = None
task_id: str | None = None
from_task: Any | None = None
from_agent: Any | None = None
model_config = {"arbitrary_types_allowed": True}
model_config = ConfigDict(arbitrary_types_allowed=True)
def __init__(self, **data):
super().__init__(**data)
@@ -81,9 +84,9 @@ class ToolExecutionErrorEvent(BaseEvent):
error: Any
type: str = "tool_execution_error"
tool_name: str
tool_args: Dict[str, Any]
tool_args: dict[str, Any]
tool_class: Callable
agent: Optional[Any] = None
agent: Any | None = None
def __init__(self, **data):
super().__init__(**data)

View File

@@ -1,25 +1,25 @@
from typing import Any, Dict, Optional
from typing import Any, ClassVar
from rich.console import Console
from rich.live import Live
from rich.panel import Panel
from rich.syntax import Syntax
from rich.text import Text
from rich.tree import Tree
from rich.live import Live
from rich.syntax import Syntax
class ConsoleFormatter:
current_crew_tree: Optional[Tree] = None
current_task_branch: Optional[Tree] = None
current_agent_branch: Optional[Tree] = None
current_tool_branch: Optional[Tree] = None
current_flow_tree: Optional[Tree] = None
current_method_branch: Optional[Tree] = None
current_lite_agent_branch: Optional[Tree] = None
tool_usage_counts: Dict[str, int] = {}
current_reasoning_branch: Optional[Tree] = None # Track reasoning status
_live_paused: bool = False
current_llm_tool_tree: Optional[Tree] = None
current_crew_tree: ClassVar[Tree | None] = None
current_task_branch: ClassVar[Tree | None] = None
current_agent_branch: ClassVar[Tree | None] = None
current_tool_branch: ClassVar[Tree | None] = None
current_flow_tree: ClassVar[Tree | None] = None
current_method_branch: ClassVar[Tree | None] = None
current_lite_agent_branch: ClassVar[Tree | None] = None
tool_usage_counts: ClassVar[dict[str, int]] = {}
current_reasoning_branch: ClassVar[Tree | None] = None # Track reasoning status
_live_paused: ClassVar[bool] = False
current_llm_tool_tree: ClassVar[Tree | None] = None
def __init__(self, verbose: bool = False):
self.console = Console(width=None)
@@ -29,7 +29,7 @@ class ConsoleFormatter:
# instance so the previous render is replaced instead of writing a new one.
# Once any non-Tree renderable is printed we stop the Live session so the
# final Tree persists on the terminal.
self._live: Optional[Live] = None
self._live: Live | None = None
def create_panel(self, content: Text, title: str, style: str = "blue") -> Panel:
"""Create a standardized panel with consistent styling."""
@@ -45,7 +45,7 @@ class ConsoleFormatter:
title: str,
name: str,
status_style: str = "blue",
tool_args: Dict[str, Any] | str = "",
tool_args: dict[str, Any] | str = "",
**fields,
) -> Text:
"""Create standardized status content with consistent formatting."""
@@ -70,7 +70,7 @@ class ConsoleFormatter:
prefix: str,
name: str,
style: str = "blue",
status: Optional[str] = None,
status: str | None = None,
) -> None:
"""Update tree label with consistent formatting."""
label = Text()
@@ -156,7 +156,7 @@ class ConsoleFormatter:
def update_crew_tree(
self,
tree: Optional[Tree],
tree: Tree | None,
crew_name: str,
source_id: str,
status: str = "completed",
@@ -196,7 +196,7 @@ class ConsoleFormatter:
self.print_panel(content, title, style)
def create_crew_tree(self, crew_name: str, source_id: str) -> Optional[Tree]:
def create_crew_tree(self, crew_name: str, source_id: str) -> Tree | None:
"""Create and initialize a new crew tree with initial status."""
if not self.verbose:
return None
@@ -220,8 +220,8 @@ class ConsoleFormatter:
return tree
def create_task_branch(
self, crew_tree: Optional[Tree], task_id: str, task_name: Optional[str] = None
) -> Optional[Tree]:
self, crew_tree: Tree | None, task_id: str, task_name: str | None = None
) -> Tree | None:
"""Create and initialize a task branch."""
if not self.verbose:
return None
@@ -255,11 +255,11 @@ class ConsoleFormatter:
def update_task_status(
self,
crew_tree: Optional[Tree],
crew_tree: Tree | None,
task_id: str,
agent_role: str,
status: str = "completed",
task_name: Optional[str] = None,
task_name: str | None = None,
) -> None:
"""Update task status in the tree."""
if not self.verbose or crew_tree is None:
@@ -306,8 +306,8 @@ class ConsoleFormatter:
self.print_panel(content, panel_title, style)
def create_agent_branch(
self, task_branch: Optional[Tree], agent_role: str, crew_tree: Optional[Tree]
) -> Optional[Tree]:
self, task_branch: Tree | None, agent_role: str, crew_tree: Tree | None
) -> Tree | None:
"""Create and initialize an agent branch."""
if not self.verbose or not task_branch or not crew_tree:
return None
@@ -325,9 +325,9 @@ class ConsoleFormatter:
def update_agent_status(
self,
agent_branch: Optional[Tree],
agent_branch: Tree | None,
agent_role: str,
crew_tree: Optional[Tree],
crew_tree: Tree | None,
status: str = "completed",
) -> None:
"""Update agent status in the tree."""
@@ -336,7 +336,7 @@ class ConsoleFormatter:
# altering the tree. Keeping it a no-op avoids duplicate status lines.
return
def create_flow_tree(self, flow_name: str, flow_id: str) -> Optional[Tree]:
def create_flow_tree(self, flow_name: str, flow_id: str) -> Tree | None:
"""Create and initialize a flow tree."""
content = self.create_status_content(
"Starting Flow Execution", flow_name, "blue", ID=flow_id
@@ -356,7 +356,7 @@ class ConsoleFormatter:
return flow_tree
def start_flow(self, flow_name: str, flow_id: str) -> Optional[Tree]:
def start_flow(self, flow_name: str, flow_id: str) -> Tree | None:
"""Initialize a flow execution tree."""
flow_tree = Tree("")
flow_label = Text()
@@ -376,7 +376,7 @@ class ConsoleFormatter:
def update_flow_status(
self,
flow_tree: Optional[Tree],
flow_tree: Tree | None,
flow_name: str,
flow_id: str,
status: str = "completed",
@@ -423,11 +423,11 @@ class ConsoleFormatter:
def update_method_status(
self,
method_branch: Optional[Tree],
flow_tree: Optional[Tree],
method_branch: Tree | None,
flow_tree: Tree | None,
method_name: str,
status: str = "running",
) -> Optional[Tree]:
) -> Tree | None:
"""Update method status in the flow tree."""
if not flow_tree:
return None
@@ -480,7 +480,7 @@ class ConsoleFormatter:
def handle_llm_tool_usage_started(
self,
tool_name: str,
tool_args: Dict[str, Any] | str,
tool_args: dict[str, Any] | str,
):
# Create status content for the tool usage
content = self.create_status_content(
@@ -520,11 +520,11 @@ class ConsoleFormatter:
def handle_tool_usage_started(
self,
agent_branch: Optional[Tree],
agent_branch: Tree | None,
tool_name: str,
crew_tree: Optional[Tree],
tool_args: Dict[str, Any] | str = "",
) -> Optional[Tree]:
crew_tree: Tree | None,
tool_args: dict[str, Any] | str = "",
) -> Tree | None:
"""Handle tool usage started event."""
if not self.verbose:
return None
@@ -569,9 +569,9 @@ class ConsoleFormatter:
def handle_tool_usage_finished(
self,
tool_branch: Optional[Tree],
tool_branch: Tree | None,
tool_name: str,
crew_tree: Optional[Tree],
crew_tree: Tree | None,
) -> None:
"""Handle tool usage finished event."""
if not self.verbose or tool_branch is None:
@@ -600,10 +600,10 @@ class ConsoleFormatter:
def handle_tool_usage_error(
self,
tool_branch: Optional[Tree],
tool_branch: Tree | None,
tool_name: str,
error: str,
crew_tree: Optional[Tree],
crew_tree: Tree | None,
) -> None:
"""Handle tool usage error event."""
if not self.verbose:
@@ -631,9 +631,9 @@ class ConsoleFormatter:
def handle_llm_call_started(
self,
agent_branch: Optional[Tree],
crew_tree: Optional[Tree],
) -> Optional[Tree]:
agent_branch: Tree | None,
crew_tree: Tree | None,
) -> Tree | None:
"""Handle LLM call started event."""
if not self.verbose:
return None
@@ -672,9 +672,9 @@ class ConsoleFormatter:
def handle_llm_call_completed(
self,
tool_branch: Optional[Tree],
agent_branch: Optional[Tree],
crew_tree: Optional[Tree],
tool_branch: Tree | None,
agent_branch: Tree | None,
crew_tree: Tree | None,
) -> None:
"""Handle LLM call completed event."""
if not self.verbose:
@@ -736,7 +736,7 @@ class ConsoleFormatter:
self.print()
def handle_llm_call_failed(
self, tool_branch: Optional[Tree], error: str, crew_tree: Optional[Tree]
self, tool_branch: Tree | None, error: str, crew_tree: Tree | None
) -> None:
"""Handle LLM call failed event."""
if not self.verbose:
@@ -789,7 +789,7 @@ class ConsoleFormatter:
def handle_crew_test_started(
self, crew_name: str, source_id: str, n_iterations: int
) -> Optional[Tree]:
) -> Tree | None:
"""Handle crew test started event."""
if not self.verbose:
return None
@@ -823,7 +823,7 @@ class ConsoleFormatter:
return test_tree
def handle_crew_test_completed(
self, flow_tree: Optional[Tree], crew_name: str
self, flow_tree: Tree | None, crew_name: str
) -> None:
"""Handle crew test completed event."""
if not self.verbose:
@@ -913,7 +913,7 @@ class ConsoleFormatter:
self.print_panel(failure_content, "Test Failure", "red")
self.print()
def create_lite_agent_branch(self, lite_agent_role: str) -> Optional[Tree]:
def create_lite_agent_branch(self, lite_agent_role: str) -> Tree | None:
"""Create and initialize a lite agent branch."""
if not self.verbose:
return None
@@ -935,10 +935,10 @@ class ConsoleFormatter:
def update_lite_agent_status(
self,
lite_agent_branch: Optional[Tree],
lite_agent_branch: Tree | None,
lite_agent_role: str,
status: str = "completed",
**fields: Dict[str, Any],
**fields: dict[str, Any],
) -> None:
"""Update lite agent status in the tree."""
if not self.verbose or lite_agent_branch is None:
@@ -981,7 +981,7 @@ class ConsoleFormatter:
lite_agent_role: str,
status: str = "started",
error: Any = None,
**fields: Dict[str, Any],
**fields: dict[str, Any],
) -> None:
"""Handle lite agent execution events with consistent formatting."""
if not self.verbose:
@@ -1006,9 +1006,9 @@ class ConsoleFormatter:
def handle_knowledge_retrieval_started(
self,
agent_branch: Optional[Tree],
crew_tree: Optional[Tree],
) -> Optional[Tree]:
agent_branch: Tree | None,
crew_tree: Tree | None,
) -> Tree | None:
"""Handle knowledge retrieval started event."""
if not self.verbose:
return None
@@ -1034,13 +1034,13 @@ class ConsoleFormatter:
def handle_knowledge_retrieval_completed(
self,
agent_branch: Optional[Tree],
crew_tree: Optional[Tree],
agent_branch: Tree | None,
crew_tree: Tree | None,
retrieved_knowledge: Any,
) -> None:
"""Handle knowledge retrieval completed event."""
if not self.verbose:
return None
return
branch_to_use = self.current_lite_agent_branch or agent_branch
tree_to_use = branch_to_use or crew_tree
@@ -1062,7 +1062,7 @@ class ConsoleFormatter:
)
self.print(knowledge_panel)
self.print()
return None
return
knowledge_branch_found = False
for child in branch_to_use.children:
@@ -1111,18 +1111,18 @@ class ConsoleFormatter:
def handle_knowledge_query_started(
self,
agent_branch: Optional[Tree],
agent_branch: Tree | None,
task_prompt: str,
crew_tree: Optional[Tree],
crew_tree: Tree | None,
) -> None:
"""Handle knowledge query generated event."""
if not self.verbose:
return None
return
branch_to_use = self.current_lite_agent_branch or agent_branch
tree_to_use = branch_to_use or crew_tree
if branch_to_use is None or tree_to_use is None:
return None
return
query_branch = branch_to_use.add("")
self.update_tree_label(
@@ -1134,9 +1134,9 @@ class ConsoleFormatter:
def handle_knowledge_query_failed(
self,
agent_branch: Optional[Tree],
agent_branch: Tree | None,
error: str,
crew_tree: Optional[Tree],
crew_tree: Tree | None,
) -> None:
"""Handle knowledge query failed event."""
if not self.verbose:
@@ -1159,18 +1159,18 @@ class ConsoleFormatter:
def handle_knowledge_query_completed(
self,
agent_branch: Optional[Tree],
crew_tree: Optional[Tree],
agent_branch: Tree | None,
crew_tree: Tree | None,
) -> None:
"""Handle knowledge query completed event."""
if not self.verbose:
return None
return
branch_to_use = self.current_lite_agent_branch or agent_branch
tree_to_use = branch_to_use or crew_tree
if branch_to_use is None or tree_to_use is None:
return None
return
query_branch = branch_to_use.add("")
self.update_tree_label(query_branch, "", "Knowledge Query Completed", "green")
@@ -1180,9 +1180,9 @@ class ConsoleFormatter:
def handle_knowledge_search_query_failed(
self,
agent_branch: Optional[Tree],
agent_branch: Tree | None,
error: str,
crew_tree: Optional[Tree],
crew_tree: Tree | None,
) -> None:
"""Handle knowledge search query failed event."""
if not self.verbose:
@@ -1207,10 +1207,10 @@ class ConsoleFormatter:
def handle_reasoning_started(
self,
agent_branch: Optional[Tree],
agent_branch: Tree | None,
attempt: int,
crew_tree: Optional[Tree],
) -> Optional[Tree]:
crew_tree: Tree | None,
) -> Tree | None:
"""Handle agent reasoning started (or refinement) event."""
if not self.verbose:
return None
@@ -1249,7 +1249,7 @@ class ConsoleFormatter:
self,
plan: str,
ready: bool,
crew_tree: Optional[Tree],
crew_tree: Tree | None,
) -> None:
"""Handle agent reasoning completed event."""
if not self.verbose:
@@ -1292,7 +1292,7 @@ class ConsoleFormatter:
def handle_reasoning_failed(
self,
error: str,
crew_tree: Optional[Tree],
crew_tree: Tree | None,
) -> None:
"""Handle agent reasoning failure event."""
if not self.verbose:
@@ -1329,7 +1329,7 @@ class ConsoleFormatter:
def handle_agent_logs_started(
self,
agent_role: str,
task_description: Optional[str] = None,
task_description: str | None = None,
verbose: bool = False,
) -> None:
"""Handle agent logs started event."""
@@ -1367,10 +1367,11 @@ class ConsoleFormatter:
if not verbose:
return
from crewai.agents.parser import AgentAction, AgentFinish
import json
import re
from crewai.agents.parser import AgentAction, AgentFinish
agent_role = agent_role.partition("\n")[0]
if isinstance(formatted_answer, AgentAction):
@@ -1473,9 +1474,9 @@ class ConsoleFormatter:
def handle_memory_retrieval_started(
self,
agent_branch: Optional[Tree],
crew_tree: Optional[Tree],
) -> Optional[Tree]:
agent_branch: Tree | None,
crew_tree: Tree | None,
) -> Tree | None:
if not self.verbose:
return None
@@ -1497,13 +1498,13 @@ class ConsoleFormatter:
def handle_memory_retrieval_completed(
self,
agent_branch: Optional[Tree],
crew_tree: Optional[Tree],
agent_branch: Tree | None,
crew_tree: Tree | None,
memory_content: str,
retrieval_time_ms: float,
) -> None:
if not self.verbose:
return None
return
branch_to_use = self.current_lite_agent_branch or agent_branch
tree_to_use = branch_to_use or crew_tree
@@ -1528,7 +1529,7 @@ class ConsoleFormatter:
if branch_to_use is None or tree_to_use is None:
add_panel()
return None
return
memory_branch_found = False
for child in branch_to_use.children:
@@ -1565,13 +1566,13 @@ class ConsoleFormatter:
def handle_memory_query_completed(
self,
agent_branch: Optional[Tree],
agent_branch: Tree | None,
source_type: str,
query_time_ms: float,
crew_tree: Optional[Tree],
crew_tree: Tree | None,
) -> None:
if not self.verbose:
return None
return
branch_to_use = self.current_lite_agent_branch or agent_branch
tree_to_use = branch_to_use or crew_tree
@@ -1580,7 +1581,7 @@ class ConsoleFormatter:
branch_to_use = tree_to_use
if branch_to_use is None:
return None
return
memory_type = source_type.replace("_", " ").title()
@@ -1598,13 +1599,13 @@ class ConsoleFormatter:
def handle_memory_query_failed(
self,
agent_branch: Optional[Tree],
crew_tree: Optional[Tree],
agent_branch: Tree | None,
crew_tree: Tree | None,
error: str,
source_type: str,
) -> None:
if not self.verbose:
return None
return
branch_to_use = self.current_lite_agent_branch or agent_branch
tree_to_use = branch_to_use or crew_tree
@@ -1613,7 +1614,7 @@ class ConsoleFormatter:
branch_to_use = tree_to_use
if branch_to_use is None:
return None
return
memory_type = source_type.replace("_", " ").title()
@@ -1630,16 +1631,16 @@ class ConsoleFormatter:
break
def handle_memory_save_started(
self, agent_branch: Optional[Tree], crew_tree: Optional[Tree]
self, agent_branch: Tree | None, crew_tree: Tree | None
) -> None:
if not self.verbose:
return None
return
branch_to_use = agent_branch or self.current_lite_agent_branch
tree_to_use = branch_to_use or crew_tree
if tree_to_use is None:
return None
return
for child in tree_to_use.children:
if "Memory Update" in str(child.label):
@@ -1655,19 +1656,19 @@ class ConsoleFormatter:
def handle_memory_save_completed(
self,
agent_branch: Optional[Tree],
crew_tree: Optional[Tree],
agent_branch: Tree | None,
crew_tree: Tree | None,
save_time_ms: float,
source_type: str,
) -> None:
if not self.verbose:
return None
return
branch_to_use = agent_branch or self.current_lite_agent_branch
tree_to_use = branch_to_use or crew_tree
if tree_to_use is None:
return None
return
memory_type = source_type.replace("_", " ").title()
content = f"{memory_type} Memory Saved ({save_time_ms:.2f}ms)"
@@ -1685,19 +1686,19 @@ class ConsoleFormatter:
def handle_memory_save_failed(
self,
agent_branch: Optional[Tree],
agent_branch: Tree | None,
error: str,
source_type: str,
crew_tree: Optional[Tree],
crew_tree: Tree | None,
) -> None:
if not self.verbose:
return None
return
branch_to_use = agent_branch or self.current_lite_agent_branch
tree_to_use = branch_to_use or crew_tree
if branch_to_use is None or tree_to_use is None:
return None
return
memory_type = source_type.replace("_", " ").title()
content = f"{memory_type} Memory Save Failed"
@@ -1738,7 +1739,7 @@ class ConsoleFormatter:
def handle_guardrail_completed(
self,
success: bool,
error: Optional[str],
error: str | None,
retry_count: int,
) -> None:
"""Display guardrail evaluation result.

View File

@@ -1,40 +1,39 @@
from crewai.experimental.evaluation import (
AgentEvaluationResult,
AgentEvaluator,
BaseEvaluator,
EvaluationScore,
MetricCategory,
AgentEvaluationResult,
SemanticQualityEvaluator,
GoalAlignmentEvaluator,
ReasoningEfficiencyEvaluator,
ToolSelectionEvaluator,
ParameterExtractionEvaluator,
ToolInvocationEvaluator,
EvaluationTraceCallback,
create_evaluation_callbacks,
AgentEvaluator,
create_default_evaluator,
ExperimentRunner,
ExperimentResults,
ExperimentResult,
ExperimentResults,
ExperimentRunner,
GoalAlignmentEvaluator,
MetricCategory,
ParameterExtractionEvaluator,
ReasoningEfficiencyEvaluator,
SemanticQualityEvaluator,
ToolInvocationEvaluator,
ToolSelectionEvaluator,
create_default_evaluator,
create_evaluation_callbacks,
)
__all__ = [
"AgentEvaluationResult",
"AgentEvaluator",
"BaseEvaluator",
"EvaluationScore",
"MetricCategory",
"AgentEvaluationResult",
"SemanticQualityEvaluator",
"GoalAlignmentEvaluator",
"ReasoningEfficiencyEvaluator",
"ToolSelectionEvaluator",
"ParameterExtractionEvaluator",
"ToolInvocationEvaluator",
"EvaluationTraceCallback",
"create_evaluation_callbacks",
"AgentEvaluator",
"create_default_evaluator",
"ExperimentRunner",
"ExperimentResult",
"ExperimentResults",
"ExperimentResult"
]
"ExperimentRunner",
"GoalAlignmentEvaluator",
"MetricCategory",
"ParameterExtractionEvaluator",
"ReasoningEfficiencyEvaluator",
"SemanticQualityEvaluator",
"ToolInvocationEvaluator",
"ToolSelectionEvaluator",
"create_default_evaluator",
"create_evaluation_callbacks",
]

View File

@@ -1,51 +1,47 @@
from crewai.experimental.evaluation.agent_evaluator import (
AgentEvaluator,
create_default_evaluator,
)
from crewai.experimental.evaluation.base_evaluator import (
AgentEvaluationResult,
BaseEvaluator,
EvaluationScore,
MetricCategory,
AgentEvaluationResult
)
from crewai.experimental.evaluation.metrics import (
SemanticQualityEvaluator,
GoalAlignmentEvaluator,
ReasoningEfficiencyEvaluator,
ToolSelectionEvaluator,
ParameterExtractionEvaluator,
ToolInvocationEvaluator
)
from crewai.experimental.evaluation.evaluation_listener import (
EvaluationTraceCallback,
create_evaluation_callbacks
create_evaluation_callbacks,
)
from crewai.experimental.evaluation.agent_evaluator import (
AgentEvaluator,
create_default_evaluator
)
from crewai.experimental.evaluation.experiment import (
ExperimentRunner,
ExperimentResult,
ExperimentResults,
ExperimentResult
ExperimentRunner,
)
from crewai.experimental.evaluation.metrics import (
GoalAlignmentEvaluator,
ParameterExtractionEvaluator,
ReasoningEfficiencyEvaluator,
SemanticQualityEvaluator,
ToolInvocationEvaluator,
ToolSelectionEvaluator,
)
__all__ = [
"AgentEvaluationResult",
"AgentEvaluator",
"BaseEvaluator",
"EvaluationScore",
"MetricCategory",
"AgentEvaluationResult",
"SemanticQualityEvaluator",
"GoalAlignmentEvaluator",
"ReasoningEfficiencyEvaluator",
"ToolSelectionEvaluator",
"ParameterExtractionEvaluator",
"ToolInvocationEvaluator",
"EvaluationTraceCallback",
"create_evaluation_callbacks",
"AgentEvaluator",
"create_default_evaluator",
"ExperimentRunner",
"ExperimentResult",
"ExperimentResults",
"ExperimentResult"
"ExperimentRunner",
"GoalAlignmentEvaluator",
"MetricCategory",
"ParameterExtractionEvaluator",
"ReasoningEfficiencyEvaluator",
"SemanticQualityEvaluator",
"ToolInvocationEvaluator",
"ToolSelectionEvaluator",
"create_default_evaluator",
"create_evaluation_callbacks",
]

View File

@@ -1,34 +1,32 @@
import threading
from typing import Any, Optional
from collections.abc import Sequence
from typing import Any
from crewai.experimental.evaluation.base_evaluator import (
AgentEvaluationResult,
AggregationStrategy,
)
from crewai.agent import Agent
from crewai.task import Task
from crewai.experimental.evaluation.evaluation_display import EvaluationDisplayFormatter
from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.agent_events import (
AgentEvaluationStartedEvent,
AgentEvaluationCompletedEvent,
AgentEvaluationFailedEvent,
AgentEvaluationStartedEvent,
LiteAgentExecutionCompletedEvent,
)
from crewai.experimental.evaluation import BaseEvaluator, create_evaluation_callbacks
from collections.abc import Sequence
from crewai.events.event_bus import crewai_event_bus
from crewai.events.utils.console_formatter import ConsoleFormatter
from crewai.events.types.task_events import TaskCompletedEvent
from crewai.events.types.agent_events import LiteAgentExecutionCompletedEvent
from crewai.events.utils.console_formatter import ConsoleFormatter
from crewai.experimental.evaluation import BaseEvaluator, create_evaluation_callbacks
from crewai.experimental.evaluation.base_evaluator import (
AgentAggregatedEvaluationResult,
AgentEvaluationResult,
AggregationStrategy,
EvaluationScore,
MetricCategory,
)
from crewai.experimental.evaluation.evaluation_display import EvaluationDisplayFormatter
from crewai.task import Task
class ExecutionState:
current_agent_id: Optional[str] = None
current_task_id: Optional[str] = None
current_agent_id: str | None = None
current_task_id: str | None = None
def __init__(self):
self.traces = {}
@@ -284,7 +282,7 @@ class AgentEvaluator:
error=str(e),
)
self.console_formatter.print(
f"Error in {evaluator.metric_category.value} evaluator: {str(e)}"
f"Error in {evaluator.metric_category.value} evaluator: {e!s}"
)
return result
@@ -340,11 +338,11 @@ class AgentEvaluator:
def create_default_evaluator(agents: list[Agent], llm: None = None):
from crewai.experimental.evaluation import (
GoalAlignmentEvaluator,
SemanticQualityEvaluator,
ToolSelectionEvaluator,
ParameterExtractionEvaluator,
ToolInvocationEvaluator,
ReasoningEfficiencyEvaluator,
SemanticQualityEvaluator,
ToolInvocationEvaluator,
ToolSelectionEvaluator,
)
evaluators = [

View File

@@ -1,15 +1,16 @@
import abc
import enum
from enum import Enum
from typing import Any, Dict, List, Optional
from typing import Any
from pydantic import BaseModel, Field
from crewai.agent import Agent
from crewai.task import Task
from crewai.llm import BaseLLM
from crewai.task import Task
from crewai.utilities.llm_utils import create_llm
class MetricCategory(enum.Enum):
GOAL_ALIGNMENT = "goal_alignment"
SEMANTIC_QUALITY = "semantic_quality"
@@ -19,7 +20,7 @@ class MetricCategory(enum.Enum):
TOOL_INVOCATION = "tool_invocation"
def title(self):
return self.value.replace('_', ' ').title()
return self.value.replace("_", " ").title()
class EvaluationScore(BaseModel):
@@ -27,15 +28,13 @@ class EvaluationScore(BaseModel):
default=5.0,
description="Numeric score from 0-10 where 0 is worst and 10 is best, None if not applicable",
ge=0.0,
le=10.0
le=10.0,
)
feedback: str = Field(
default="",
description="Detailed feedback explaining the evaluation score"
default="", description="Detailed feedback explaining the evaluation score"
)
raw_response: str | None = Field(
default=None,
description="Raw response from the evaluator (e.g., LLM)"
default=None, description="Raw response from the evaluator (e.g., LLM)"
)
def __str__(self) -> str:
@@ -57,7 +56,7 @@ class BaseEvaluator(abc.ABC):
def evaluate(
self,
agent: Agent,
execution_trace: Dict[str, Any],
execution_trace: dict[str, Any],
final_output: Any,
task: Task | None = None,
) -> EvaluationScore:
@@ -67,9 +66,8 @@ class BaseEvaluator(abc.ABC):
class AgentEvaluationResult(BaseModel):
agent_id: str = Field(description="ID of the evaluated agent")
task_id: str = Field(description="ID of the task that was executed")
metrics: Dict[MetricCategory, EvaluationScore] = Field(
default_factory=dict,
description="Evaluation scores for each metric category"
metrics: dict[MetricCategory, EvaluationScore] = Field(
default_factory=dict, description="Evaluation scores for each metric category"
)
@@ -81,33 +79,23 @@ class AggregationStrategy(Enum):
class AgentAggregatedEvaluationResult(BaseModel):
agent_id: str = Field(
default="",
description="ID of the agent"
)
agent_role: str = Field(
default="",
description="Role of the agent"
)
agent_id: str = Field(default="", description="ID of the agent")
agent_role: str = Field(default="", description="Role of the agent")
task_count: int = Field(
default=0,
description="Number of tasks included in this aggregation"
default=0, description="Number of tasks included in this aggregation"
)
aggregation_strategy: AggregationStrategy = Field(
default=AggregationStrategy.SIMPLE_AVERAGE,
description="Strategy used for aggregation"
description="Strategy used for aggregation",
)
metrics: Dict[MetricCategory, EvaluationScore] = Field(
default_factory=dict,
description="Aggregated metrics across all tasks"
metrics: dict[MetricCategory, EvaluationScore] = Field(
default_factory=dict, description="Aggregated metrics across all tasks"
)
task_results: List[str] = Field(
default_factory=list,
description="IDs of tasks included in this aggregation"
task_results: list[str] = Field(
default_factory=list, description="IDs of tasks included in this aggregation"
)
overall_score: Optional[float] = Field(
default=None,
description="Overall score for this agent"
overall_score: float | None = Field(
default=None, description="Overall score for this agent"
)
def __str__(self) -> str:
@@ -119,7 +107,7 @@ class AgentAggregatedEvaluationResult(BaseModel):
result += f"\n\n- {category.value.upper()}: {score.score}/10\n"
if score.feedback:
detailed_feedback = "\n ".join(score.feedback.split('\n'))
detailed_feedback = "\n ".join(score.feedback.split("\n"))
result += f" {detailed_feedback}\n"
return result
return result

View File

@@ -1,16 +1,18 @@
from collections import defaultdict
from typing import Dict, Any, List
from rich.table import Table
from rich.box import HEAVY_EDGE, ROUNDED
from collections.abc import Sequence
from typing import Any
from rich.box import HEAVY_EDGE, ROUNDED
from rich.table import Table
from crewai.events.utils.console_formatter import ConsoleFormatter
from crewai.experimental.evaluation import EvaluationScore
from crewai.experimental.evaluation.base_evaluator import (
AgentAggregatedEvaluationResult,
AggregationStrategy,
AgentEvaluationResult,
AggregationStrategy,
MetricCategory,
)
from crewai.experimental.evaluation import EvaluationScore
from crewai.events.utils.console_formatter import ConsoleFormatter
from crewai.utilities.llm_utils import create_llm
@@ -19,7 +21,7 @@ class EvaluationDisplayFormatter:
self.console_formatter = ConsoleFormatter()
def display_evaluation_with_feedback(
self, iterations_results: Dict[int, Dict[str, List[Any]]]
self, iterations_results: dict[int, dict[str, list[Any]]]
):
if not iterations_results:
self.console_formatter.print(
@@ -99,7 +101,7 @@ class EvaluationDisplayFormatter:
def display_summary_results(
self,
iterations_results: Dict[int, Dict[str, List[AgentAggregatedEvaluationResult]]],
iterations_results: dict[int, dict[str, list[AgentAggregatedEvaluationResult]]],
):
if not iterations_results:
self.console_formatter.print(
@@ -304,25 +306,25 @@ class EvaluationDisplayFormatter:
self,
agent_role: str,
metric: str,
feedbacks: List[str],
scores: List[float | None],
feedbacks: list[str],
scores: list[float | None],
strategy: AggregationStrategy,
) -> str:
if len(feedbacks) <= 2 and all(len(fb) < 200 for fb in feedbacks):
return "\n\n".join(
[f"Feedback {i+1}: {fb}" for i, fb in enumerate(feedbacks)]
[f"Feedback {i + 1}: {fb}" for i, fb in enumerate(feedbacks)]
)
try:
llm = create_llm()
formatted_feedbacks = []
for i, (feedback, score) in enumerate(zip(feedbacks, scores)):
for i, (feedback, score) in enumerate(zip(feedbacks, scores, strict=False)):
if len(feedback) > 500:
feedback = feedback[:500] + "..."
score_text = f"{score:.1f}" if score is not None else "N/A"
formatted_feedbacks.append(
f"Feedback #{i+1} (Score: {score_text}):\n{feedback}"
f"Feedback #{i + 1} (Score: {score_text}):\n{feedback}"
)
all_feedbacks = "\n\n" + "\n\n---\n\n".join(formatted_feedbacks)
@@ -366,9 +368,7 @@ class EvaluationDisplayFormatter:
},
]
assert llm is not None
response = llm.call(prompt)
return response
return llm.call(prompt)
except Exception:
return "Synthesized from multiple tasks: " + "\n\n".join(

View File

@@ -1,26 +1,25 @@
from datetime import datetime
from typing import Any, Dict, Optional
from collections.abc import Sequence
from datetime import datetime
from typing import Any
from crewai.agent import Agent
from crewai.task import Task
from crewai.events.base_event_listener import BaseEventListener
from crewai.events.event_bus import CrewAIEventsBus
from crewai.events.types.agent_events import (
AgentExecutionStartedEvent,
AgentExecutionCompletedEvent,
LiteAgentExecutionStartedEvent,
AgentExecutionStartedEvent,
LiteAgentExecutionCompletedEvent,
LiteAgentExecutionStartedEvent,
)
from crewai.events.types.llm_events import LLMCallCompletedEvent, LLMCallStartedEvent
from crewai.events.types.tool_usage_events import (
ToolUsageFinishedEvent,
ToolUsageErrorEvent,
ToolExecutionErrorEvent,
ToolSelectionErrorEvent,
ToolUsageErrorEvent,
ToolUsageFinishedEvent,
ToolValidateInputErrorEvent,
)
from crewai.events.types.llm_events import LLMCallStartedEvent, LLMCallCompletedEvent
from crewai.task import Task
class EvaluationTraceCallback(BaseEventListener):
@@ -253,7 +252,7 @@ class EvaluationTraceCallback(BaseEventListener):
if hasattr(self, "current_llm_call"):
self.current_llm_call = {}
def get_trace(self, agent_id: str, task_id: str) -> Optional[Dict[str, Any]]:
def get_trace(self, agent_id: str, task_id: str) -> dict[str, Any] | None:
trace_key = f"{agent_id}_{task_id}"
return self.traces.get(trace_key)

View File

@@ -1,8 +1,7 @@
from crewai.experimental.evaluation.experiment.result import (
ExperimentResult,
ExperimentResults,
)
from crewai.experimental.evaluation.experiment.runner import ExperimentRunner
from crewai.experimental.evaluation.experiment.result import ExperimentResults, ExperimentResult
__all__ = [
"ExperimentRunner",
"ExperimentResults",
"ExperimentResult"
]
__all__ = ["ExperimentResult", "ExperimentResults", "ExperimentRunner"]

View File

@@ -2,8 +2,10 @@ import json
import os
from datetime import datetime, timezone
from typing import Any
from pydantic import BaseModel
class ExperimentResult(BaseModel):
identifier: str
inputs: dict[str, Any]
@@ -12,35 +14,48 @@ class ExperimentResult(BaseModel):
passed: bool
agent_evaluations: dict[str, Any] | None = None
class ExperimentResults:
def __init__(self, results: list[ExperimentResult], metadata: dict[str, Any] | None = None):
def __init__(
self, results: list[ExperimentResult], metadata: dict[str, Any] | None = None
):
self.results = results
self.metadata = metadata or {}
self.timestamp = datetime.now(timezone.utc)
from crewai.experimental.evaluation.experiment.result_display import ExperimentResultsDisplay
from crewai.experimental.evaluation.experiment.result_display import (
ExperimentResultsDisplay,
)
self.display = ExperimentResultsDisplay()
def to_json(self, filepath: str | None = None) -> dict[str, Any]:
data = {
"timestamp": self.timestamp.isoformat(),
"metadata": self.metadata,
"results": [r.model_dump(exclude={"agent_evaluations"}) for r in self.results]
"results": [
r.model_dump(exclude={"agent_evaluations"}) for r in self.results
],
}
if filepath:
with open(filepath, 'w') as f:
with open(filepath, "w") as f:
json.dump(data, f, indent=2)
self.display.console.print(f"[green]Results saved to {filepath}[/green]")
return data
def compare_with_baseline(self, baseline_filepath: str, save_current: bool = True, print_summary: bool = False) -> dict[str, Any]:
def compare_with_baseline(
self,
baseline_filepath: str,
save_current: bool = True,
print_summary: bool = False,
) -> dict[str, Any]:
baseline_runs = []
if os.path.exists(baseline_filepath) and os.path.getsize(baseline_filepath) > 0:
try:
with open(baseline_filepath, 'r') as f:
with open(baseline_filepath, "r") as f:
baseline_data = json.load(f)
if isinstance(baseline_data, dict) and "timestamp" in baseline_data:
@@ -48,14 +63,18 @@ class ExperimentResults:
elif isinstance(baseline_data, list):
baseline_runs = baseline_data
except (json.JSONDecodeError, FileNotFoundError) as e:
self.display.console.print(f"[yellow]Warning: Could not load baseline file: {str(e)}[/yellow]")
self.display.console.print(
f"[yellow]Warning: Could not load baseline file: {e!s}[/yellow]"
)
if not baseline_runs:
if save_current:
current_data = self.to_json()
with open(baseline_filepath, 'w') as f:
with open(baseline_filepath, "w") as f:
json.dump([current_data], f, indent=2)
self.display.console.print(f"[green]Saved current results as new baseline to {baseline_filepath}[/green]")
self.display.console.print(
f"[green]Saved current results as new baseline to {baseline_filepath}[/green]"
)
return {"is_baseline": True, "changes": {}}
baseline_runs.sort(key=lambda x: x.get("timestamp", ""), reverse=True)
@@ -69,9 +88,11 @@ class ExperimentResults:
if save_current:
current_data = self.to_json()
baseline_runs.append(current_data)
with open(baseline_filepath, 'w') as f:
with open(baseline_filepath, "w") as f:
json.dump(baseline_runs, f, indent=2)
self.display.console.print(f"[green]Added current results to baseline file {baseline_filepath}[/green]")
self.display.console.print(
f"[green]Added current results to baseline file {baseline_filepath}[/green]"
)
return comparison
@@ -118,5 +139,5 @@ class ExperimentResults:
"new_tests": new_tests,
"missing_tests": missing_tests,
"total_compared": len(improved) + len(regressed) + len(unchanged),
"baseline_timestamp": baseline_run.get("timestamp", "unknown")
"baseline_timestamp": baseline_run.get("timestamp", "unknown"),
}

View File

@@ -1,9 +1,12 @@
from typing import Dict, Any
from typing import Any
from rich.console import Console
from rich.table import Table
from rich.panel import Panel
from rich.table import Table
from crewai.experimental.evaluation.experiment.result import ExperimentResults
class ExperimentResultsDisplay:
def __init__(self):
self.console = Console()
@@ -19,13 +22,19 @@ class ExperimentResultsDisplay:
table.add_row("Total Test Cases", str(total))
table.add_row("Passed", str(passed))
table.add_row("Failed", str(total - passed))
table.add_row("Success Rate", f"{(passed / total * 100):.1f}%" if total > 0 else "N/A")
table.add_row(
"Success Rate", f"{(passed / total * 100):.1f}%" if total > 0 else "N/A"
)
self.console.print(table)
def comparison_summary(self, comparison: Dict[str, Any], baseline_timestamp: str):
self.console.print(Panel(f"[bold]Comparison with baseline run from {baseline_timestamp}[/bold]",
expand=False))
def comparison_summary(self, comparison: dict[str, Any], baseline_timestamp: str):
self.console.print(
Panel(
f"[bold]Comparison with baseline run from {baseline_timestamp}[/bold]",
expand=False,
)
)
table = Table(title="Results Comparison")
table.add_column("Metric", style="cyan")
@@ -34,7 +43,9 @@ class ExperimentResultsDisplay:
improved = comparison.get("improved", [])
if improved:
details = ", ".join([f"{test_identifier}" for test_identifier in improved[:3]])
details = ", ".join(
[f"{test_identifier}" for test_identifier in improved[:3]]
)
if len(improved) > 3:
details += f" and {len(improved) - 3} more"
table.add_row("✅ Improved", str(len(improved)), details)
@@ -43,7 +54,9 @@ class ExperimentResultsDisplay:
regressed = comparison.get("regressed", [])
if regressed:
details = ", ".join([f"{test_identifier}" for test_identifier in regressed[:3]])
details = ", ".join(
[f"{test_identifier}" for test_identifier in regressed[:3]]
)
if len(regressed) > 3:
details += f" and {len(regressed) - 3} more"
table.add_row("❌ Regressed", str(len(regressed)), details, style="red")

View File

@@ -2,11 +2,19 @@ from collections import defaultdict
from hashlib import md5
from typing import Any
from crewai import Crew, Agent
from crewai import Agent, Crew
from crewai.experimental.evaluation import AgentEvaluator, create_default_evaluator
from crewai.experimental.evaluation.experiment.result_display import ExperimentResultsDisplay
from crewai.experimental.evaluation.experiment.result import ExperimentResults, ExperimentResult
from crewai.experimental.evaluation.evaluation_display import AgentAggregatedEvaluationResult
from crewai.experimental.evaluation.evaluation_display import (
AgentAggregatedEvaluationResult,
)
from crewai.experimental.evaluation.experiment.result import (
ExperimentResult,
ExperimentResults,
)
from crewai.experimental.evaluation.experiment.result_display import (
ExperimentResultsDisplay,
)
class ExperimentRunner:
def __init__(self, dataset: list[dict[str, Any]]):
@@ -14,7 +22,12 @@ class ExperimentRunner:
self.evaluator: AgentEvaluator | None = None
self.display = ExperimentResultsDisplay()
def run(self, crew: Crew | None = None, agents: list[Agent] | None = None, print_summary: bool = False) -> ExperimentResults:
def run(
self,
crew: Crew | None = None,
agents: list[Agent] | None = None,
print_summary: bool = False,
) -> ExperimentResults:
if crew and not agents:
agents = crew.agents
@@ -35,13 +48,20 @@ class ExperimentRunner:
return experiment_results
def _run_test_case(self, test_case: dict[str, Any], agents: list[Agent], crew: Crew | None = None) -> ExperimentResult:
def _run_test_case(
self, test_case: dict[str, Any], agents: list[Agent], crew: Crew | None = None
) -> ExperimentResult:
inputs = test_case["inputs"]
expected_score = test_case["expected_score"]
identifier = test_case.get("identifier") or md5(str(test_case).encode(), usedforsecurity=False).hexdigest()
identifier = (
test_case.get("identifier")
or md5(str(test_case).encode(), usedforsecurity=False).hexdigest()
)
try:
self.display.console.print(f"[dim]Running crew with input: {str(inputs)[:50]}...[/dim]")
self.display.console.print(
f"[dim]Running crew with input: {str(inputs)[:50]}...[/dim]"
)
self.display.console.print("\n")
if crew:
crew.kickoff(inputs=inputs)
@@ -61,35 +81,38 @@ class ExperimentRunner:
score=actual_score,
expected_score=expected_score,
passed=passed,
agent_evaluations=agent_evaluations
agent_evaluations=agent_evaluations,
)
except Exception as e:
self.display.console.print(f"[red]Error running test case: {str(e)}[/red]")
self.display.console.print(f"[red]Error running test case: {e!s}[/red]")
return ExperimentResult(
identifier=identifier,
inputs=inputs,
score=0,
expected_score=expected_score,
passed=False
passed=False,
)
def _extract_scores(self, agent_evaluations: dict[str, AgentAggregatedEvaluationResult]) -> float | dict[str, float]:
def _extract_scores(
self, agent_evaluations: dict[str, AgentAggregatedEvaluationResult]
) -> float | dict[str, float]:
all_scores: dict[str, list[float]] = defaultdict(list)
for evaluation in agent_evaluations.values():
for metric_name, score in evaluation.metrics.items():
if score.score is not None:
all_scores[metric_name.value].append(score.score)
avg_scores = {m: sum(s)/len(s) for m, s in all_scores.items()}
avg_scores = {m: sum(s) / len(s) for m, s in all_scores.items()}
if len(avg_scores) == 1:
return list(avg_scores.values())[0]
return next(iter(avg_scores.values()))
return avg_scores
def _assert_scores(self, expected: float | dict[str, float],
actual: float | dict[str, float]) -> bool:
def _assert_scores(
self, expected: float | dict[str, float], actual: float | dict[str, float]
) -> bool:
"""
Compare expected and actual scores, and return whether the test case passed.
@@ -122,4 +145,4 @@ class ExperimentRunner:
# All matching keys must have actual >= expected
return all(actual[key] >= expected[key] for key in matching_keys)
return False
return False

View File

@@ -1,26 +1,21 @@
from crewai.experimental.evaluation.metrics.goal_metrics import GoalAlignmentEvaluator
from crewai.experimental.evaluation.metrics.reasoning_metrics import (
ReasoningEfficiencyEvaluator
ReasoningEfficiencyEvaluator,
)
from crewai.experimental.evaluation.metrics.tools_metrics import (
ToolSelectionEvaluator,
ParameterExtractionEvaluator,
ToolInvocationEvaluator
)
from crewai.experimental.evaluation.metrics.goal_metrics import (
GoalAlignmentEvaluator
)
from crewai.experimental.evaluation.metrics.semantic_quality_metrics import (
SemanticQualityEvaluator
SemanticQualityEvaluator,
)
from crewai.experimental.evaluation.metrics.tools_metrics import (
ParameterExtractionEvaluator,
ToolInvocationEvaluator,
ToolSelectionEvaluator,
)
__all__ = [
"ReasoningEfficiencyEvaluator",
"ToolSelectionEvaluator",
"ParameterExtractionEvaluator",
"ToolInvocationEvaluator",
"GoalAlignmentEvaluator",
"SemanticQualityEvaluator"
]
"ParameterExtractionEvaluator",
"ReasoningEfficiencyEvaluator",
"SemanticQualityEvaluator",
"ToolInvocationEvaluator",
"ToolSelectionEvaluator",
]

View File

@@ -1,10 +1,14 @@
from typing import Any, Dict
from typing import Any
from crewai.agent import Agent
from crewai.experimental.evaluation.base_evaluator import (
BaseEvaluator,
EvaluationScore,
MetricCategory,
)
from crewai.experimental.evaluation.json_parser import extract_json_from_llm_response
from crewai.task import Task
from crewai.experimental.evaluation.base_evaluator import BaseEvaluator, EvaluationScore, MetricCategory
from crewai.experimental.evaluation.json_parser import extract_json_from_llm_response
class GoalAlignmentEvaluator(BaseEvaluator):
@property
@@ -14,7 +18,7 @@ class GoalAlignmentEvaluator(BaseEvaluator):
def evaluate(
self,
agent: Agent,
execution_trace: Dict[str, Any],
execution_trace: dict[str, Any],
final_output: Any,
task: Task | None = None,
) -> EvaluationScore:
@@ -23,7 +27,9 @@ class GoalAlignmentEvaluator(BaseEvaluator):
task_context = f"Task description: {task.description}\nExpected output: {task.expected_output}\n"
prompt = [
{"role": "system", "content": """You are an expert evaluator assessing how well an AI agent's output aligns with its assigned task goal.
{
"role": "system",
"content": """You are an expert evaluator assessing how well an AI agent's output aligns with its assigned task goal.
Score the agent's goal alignment on a scale from 0-10 where:
- 0: Complete misalignment, agent did not understand or attempt the task goal
@@ -37,8 +43,11 @@ Consider:
4. Did the agent provide all requested information or deliverables?
Return your evaluation as JSON with fields 'score' (number) and 'feedback' (string).
"""},
{"role": "user", "content": f"""
""",
},
{
"role": "user",
"content": f"""
Agent role: {agent.role}
Agent goal: {agent.goal}
{task_context}
@@ -47,7 +56,8 @@ Agent's final output:
{final_output}
Evaluate how well the agent's output aligns with the assigned task goal.
"""}
""",
},
]
assert self.llm is not None
response = self.llm.call(prompt)
@@ -59,11 +69,11 @@ Evaluate how well the agent's output aligns with the assigned task goal.
return EvaluationScore(
score=evaluation_data.get("score", 0),
feedback=evaluation_data.get("feedback", response),
raw_response=response
raw_response=response,
)
except Exception:
return EvaluationScore(
score=None,
feedback=f"Failed to parse evaluation. Raw response: {response}",
raw_response=response
raw_response=response,
)

View File

@@ -8,18 +8,23 @@ This module provides evaluator implementations for:
import logging
import re
from enum import Enum
from typing import Any, Dict, List, Tuple
import numpy as np
from collections.abc import Sequence
from enum import Enum
from typing import Any
import numpy as np
from crewai.agent import Agent
from crewai.task import Task
from crewai.experimental.evaluation.base_evaluator import BaseEvaluator, EvaluationScore, MetricCategory
from crewai.experimental.evaluation.base_evaluator import (
BaseEvaluator,
EvaluationScore,
MetricCategory,
)
from crewai.experimental.evaluation.json_parser import extract_json_from_llm_response
from crewai.task import Task
from crewai.tasks.task_output import TaskOutput
class ReasoningPatternType(Enum):
EFFICIENT = "efficient" # Good reasoning flow
LOOP = "loop" # Agent is stuck in a loop
@@ -36,7 +41,7 @@ class ReasoningEfficiencyEvaluator(BaseEvaluator):
def evaluate(
self,
agent: Agent,
execution_trace: Dict[str, Any],
execution_trace: dict[str, Any],
final_output: TaskOutput | str,
task: Task | None = None,
) -> EvaluationScore:
@@ -49,7 +54,7 @@ class ReasoningEfficiencyEvaluator(BaseEvaluator):
if not llm_calls or len(llm_calls) < 2:
return EvaluationScore(
score=None,
feedback="Insufficient LLM calls to evaluate reasoning efficiency."
feedback="Insufficient LLM calls to evaluate reasoning efficiency.",
)
total_calls = len(llm_calls)
@@ -58,12 +63,16 @@ class ReasoningEfficiencyEvaluator(BaseEvaluator):
time_intervals = []
has_reliable_timing = True
for i in range(1, len(llm_calls)):
start_time = llm_calls[i-1].get("end_time")
start_time = llm_calls[i - 1].get("end_time")
end_time = llm_calls[i].get("start_time")
if start_time and end_time and start_time != end_time:
try:
interval = end_time - start_time
time_intervals.append(interval.total_seconds() if hasattr(interval, 'total_seconds') else 0)
time_intervals.append(
interval.total_seconds()
if hasattr(interval, "total_seconds")
else 0
)
except Exception:
has_reliable_timing = False
else:
@@ -83,14 +92,22 @@ class ReasoningEfficiencyEvaluator(BaseEvaluator):
if has_reliable_timing and time_intervals:
efficiency_metrics["avg_time_between_calls"] = np.mean(time_intervals)
loop_info = f"Detected {len(loop_details)} potential reasoning loops." if loop_detected else "No significant reasoning loops detected."
loop_info = (
f"Detected {len(loop_details)} potential reasoning loops."
if loop_detected
else "No significant reasoning loops detected."
)
call_samples = self._get_call_samples(llm_calls)
final_output = final_output.raw if isinstance(final_output, TaskOutput) else final_output
final_output = (
final_output.raw if isinstance(final_output, TaskOutput) else final_output
)
prompt = [
{"role": "system", "content": """You are an expert evaluator assessing the reasoning efficiency of an AI agent's thought process.
{
"role": "system",
"content": """You are an expert evaluator assessing the reasoning efficiency of an AI agent's thought process.
Evaluate the agent's reasoning efficiency across these five key subcategories:
@@ -120,8 +137,11 @@ Return your evaluation as JSON with the following structure:
"feedback": string (general feedback about overall reasoning efficiency),
"optimization_suggestions": string (concrete suggestions for improving reasoning efficiency),
"detected_patterns": string (describe any inefficient reasoning patterns you observe)
}"""},
{"role": "user", "content": f"""
}""",
},
{
"role": "user",
"content": f"""
Agent role: {agent.role}
{task_context}
@@ -140,7 +160,8 @@ Agent's final output:
Evaluate the reasoning efficiency of this agent based on these interaction patterns.
Identify any inefficient reasoning patterns and provide specific suggestions for optimization.
"""}
""",
},
]
assert self.llm is not None
@@ -156,34 +177,46 @@ Identify any inefficient reasoning patterns and provide specific suggestions for
conciseness = scores.get("conciseness", 5.0)
loop_avoidance = scores.get("loop_avoidance", 5.0)
overall_score = evaluation_data.get("overall_score", evaluation_data.get("score", 5.0))
overall_score = evaluation_data.get(
"overall_score", evaluation_data.get("score", 5.0)
)
feedback = evaluation_data.get("feedback", "No detailed feedback provided.")
optimization_suggestions = evaluation_data.get("optimization_suggestions", "No specific suggestions provided.")
optimization_suggestions = evaluation_data.get(
"optimization_suggestions", "No specific suggestions provided."
)
detailed_feedback = "Reasoning Efficiency Evaluation:\n"
detailed_feedback += f"• Focus: {focus}/10 - Staying on topic without tangents\n"
detailed_feedback += f"• Progression: {progression}/10 - Building on previous thinking\n"
detailed_feedback += (
f"• Focus: {focus}/10 - Staying on topic without tangents\n"
)
detailed_feedback += (
f"• Progression: {progression}/10 - Building on previous thinking\n"
)
detailed_feedback += f"• Decision Quality: {decision_quality}/10 - Making appropriate decisions\n"
detailed_feedback += f"• Conciseness: {conciseness}/10 - Communicating efficiently\n"
detailed_feedback += (
f"• Conciseness: {conciseness}/10 - Communicating efficiently\n"
)
detailed_feedback += f"• Loop Avoidance: {loop_avoidance}/10 - Avoiding repetitive patterns\n\n"
detailed_feedback += f"Feedback:\n{feedback}\n\n"
detailed_feedback += f"Optimization Suggestions:\n{optimization_suggestions}"
detailed_feedback += (
f"Optimization Suggestions:\n{optimization_suggestions}"
)
return EvaluationScore(
score=float(overall_score),
feedback=detailed_feedback,
raw_response=response
raw_response=response,
)
except Exception as e:
logging.warning(f"Failed to parse reasoning efficiency evaluation: {e}")
return EvaluationScore(
score=None,
feedback=f"Failed to parse reasoning efficiency evaluation. Raw response: {response[:200]}...",
raw_response=response
raw_response=response,
)
def _detect_loops(self, llm_calls: List[Dict]) -> Tuple[bool, List[Dict]]:
def _detect_loops(self, llm_calls: list[dict]) -> tuple[bool, list[dict]]:
loop_details = []
messages = []
@@ -205,18 +238,20 @@ Identify any inefficient reasoning patterns and provide specific suggestions for
# A more sophisticated approach would use semantic similarity
similarity = self._calculate_text_similarity(messages[i], messages[j])
if similarity > 0.7: # Arbitrary threshold
loop_details.append({
"first_occurrence": i,
"second_occurrence": j,
"similarity": similarity,
"snippet": messages[i][:100] + "..."
})
loop_details.append(
{
"first_occurrence": i,
"second_occurrence": j,
"similarity": similarity,
"snippet": messages[i][:100] + "...",
}
)
return len(loop_details) > 0, loop_details
def _calculate_text_similarity(self, text1: str, text2: str) -> float:
text1 = re.sub(r'\s+', ' ', text1.lower()).strip()
text2 = re.sub(r'\s+', ' ', text2.lower()).strip()
text1 = re.sub(r"\s+", " ", text1.lower()).strip()
text2 = re.sub(r"\s+", " ", text2.lower()).strip()
# Simple Jaccard similarity on word sets
words1 = set(text1.split())
@@ -227,7 +262,7 @@ Identify any inefficient reasoning patterns and provide specific suggestions for
return intersection / union if union > 0 else 0.0
def _analyze_reasoning_patterns(self, llm_calls: List[Dict]) -> Dict[str, Any]:
def _analyze_reasoning_patterns(self, llm_calls: list[dict]) -> dict[str, Any]:
call_lengths = []
response_times = []
@@ -267,7 +302,9 @@ Identify any inefficient reasoning patterns and provide specific suggestions for
details = "Agent is consistently verbose across interactions."
elif len(llm_calls) > 10 and length_trend > 0.5:
primary_pattern = ReasoningPatternType.INDECISIVE
details = "Agent shows signs of indecisiveness with increasing message lengths."
details = (
"Agent shows signs of indecisiveness with increasing message lengths."
)
elif std_length / avg_length > 0.8:
primary_pattern = ReasoningPatternType.SCATTERED
details = "Agent shows inconsistent reasoning flow with highly variable responses."
@@ -279,8 +316,8 @@ Identify any inefficient reasoning patterns and provide specific suggestions for
"avg_length": avg_length,
"std_length": std_length,
"length_trend": length_trend,
"loop_score": loop_score
}
"loop_score": loop_score,
},
}
def _calculate_trend(self, values: Sequence[float | int]) -> float:
@@ -303,7 +340,9 @@ Identify any inefficient reasoning patterns and provide specific suggestions for
except Exception:
return 0.0
def _calculate_loop_likelihood(self, call_lengths: Sequence[float], response_times: Sequence[float]) -> float:
def _calculate_loop_likelihood(
self, call_lengths: Sequence[float], response_times: Sequence[float]
) -> float:
if not call_lengths or len(call_lengths) < 3:
return 0.0
@@ -312,7 +351,11 @@ Identify any inefficient reasoning patterns and provide specific suggestions for
if len(call_lengths) >= 4:
repeated_lengths = 0
for i in range(len(call_lengths) - 2):
ratio = call_lengths[i] / call_lengths[i + 2] if call_lengths[i + 2] > 0 else 0
ratio = (
call_lengths[i] / call_lengths[i + 2]
if call_lengths[i + 2] > 0
else 0
)
if 0.85 <= ratio <= 1.15:
repeated_lengths += 1
@@ -331,14 +374,20 @@ Identify any inefficient reasoning patterns and provide specific suggestions for
return np.mean(indicators) if indicators else 0.0
def _get_call_samples(self, llm_calls: List[Dict]) -> str:
def _get_call_samples(self, llm_calls: list[dict]) -> str:
samples = []
if len(llm_calls) <= 6:
sample_indices = list(range(len(llm_calls)))
else:
sample_indices = [0, 1, len(llm_calls) // 2 - 1, len(llm_calls) // 2,
len(llm_calls) - 2, len(llm_calls) - 1]
sample_indices = [
0,
1,
len(llm_calls) // 2 - 1,
len(llm_calls) // 2,
len(llm_calls) - 2,
len(llm_calls) - 1,
]
for idx in sample_indices:
call = llm_calls[idx]

View File

@@ -1,10 +1,14 @@
from typing import Any, Dict
from typing import Any
from crewai.agent import Agent
from crewai.experimental.evaluation.base_evaluator import (
BaseEvaluator,
EvaluationScore,
MetricCategory,
)
from crewai.experimental.evaluation.json_parser import extract_json_from_llm_response
from crewai.task import Task
from crewai.experimental.evaluation.base_evaluator import BaseEvaluator, EvaluationScore, MetricCategory
from crewai.experimental.evaluation.json_parser import extract_json_from_llm_response
class SemanticQualityEvaluator(BaseEvaluator):
@property
@@ -14,7 +18,7 @@ class SemanticQualityEvaluator(BaseEvaluator):
def evaluate(
self,
agent: Agent,
execution_trace: Dict[str, Any],
execution_trace: dict[str, Any],
final_output: Any,
task: Task | None = None,
) -> EvaluationScore:
@@ -22,7 +26,9 @@ class SemanticQualityEvaluator(BaseEvaluator):
if task is not None:
task_context = f"Task description: {task.description}"
prompt = [
{"role": "system", "content": """You are an expert evaluator assessing the semantic quality of an AI agent's output.
{
"role": "system",
"content": """You are an expert evaluator assessing the semantic quality of an AI agent's output.
Score the semantic quality on a scale from 0-10 where:
- 0: Completely incoherent, confusing, or logically flawed output
@@ -37,8 +43,11 @@ Consider:
5. Is the output free from contradictions and logical fallacies?
Return your evaluation as JSON with fields 'score' (number) and 'feedback' (string).
"""},
{"role": "user", "content": f"""
""",
},
{
"role": "user",
"content": f"""
Agent role: {agent.role}
{task_context}
@@ -46,7 +55,8 @@ Agent's final output:
{final_output}
Evaluate the semantic quality and reasoning of this output.
"""}
""",
},
]
assert self.llm is not None
@@ -56,13 +66,15 @@ Evaluate the semantic quality and reasoning of this output.
evaluation_data: dict[str, Any] = extract_json_from_llm_response(response)
assert evaluation_data is not None
return EvaluationScore(
score=float(evaluation_data["score"]) if evaluation_data.get("score") is not None else None,
score=float(evaluation_data["score"])
if evaluation_data.get("score") is not None
else None,
feedback=evaluation_data.get("feedback", response),
raw_response=response
raw_response=response,
)
except Exception:
return EvaluationScore(
score=None,
feedback=f"Failed to parse evaluation. Raw response: {response}",
raw_response=response
)
raw_response=response,
)

View File

@@ -1,14 +1,17 @@
import json
from typing import Dict, Any
from typing import Any
from crewai.experimental.evaluation.base_evaluator import BaseEvaluator, EvaluationScore, MetricCategory
from crewai.experimental.evaluation.json_parser import extract_json_from_llm_response
from crewai.agent import Agent
from crewai.experimental.evaluation.base_evaluator import (
BaseEvaluator,
EvaluationScore,
MetricCategory,
)
from crewai.experimental.evaluation.json_parser import extract_json_from_llm_response
from crewai.task import Task
class ToolSelectionEvaluator(BaseEvaluator):
@property
def metric_category(self) -> MetricCategory:
return MetricCategory.TOOL_SELECTION
@@ -16,7 +19,7 @@ class ToolSelectionEvaluator(BaseEvaluator):
def evaluate(
self,
agent: Agent,
execution_trace: Dict[str, Any],
execution_trace: dict[str, Any],
final_output: str,
task: Task | None = None,
) -> EvaluationScore:
@@ -26,19 +29,18 @@ class ToolSelectionEvaluator(BaseEvaluator):
tool_uses = execution_trace.get("tool_uses", [])
tool_count = len(tool_uses)
unique_tool_types = set([tool.get("tool", "Unknown tool") for tool in tool_uses])
unique_tool_types = set(
[tool.get("tool", "Unknown tool") for tool in tool_uses]
)
if tool_count == 0:
if not agent.tools:
return EvaluationScore(
score=None,
feedback="Agent had no tools available to use."
)
else:
return EvaluationScore(
score=None,
feedback="Agent had tools available but didn't use any."
score=None, feedback="Agent had no tools available to use."
)
return EvaluationScore(
score=None, feedback="Agent had tools available but didn't use any."
)
available_tools_info = ""
if agent.tools:
@@ -52,7 +54,9 @@ class ToolSelectionEvaluator(BaseEvaluator):
tool_types_summary += f"- {tool_type}\n"
prompt = [
{"role": "system", "content": """You are an expert evaluator assessing if an AI agent selected the most appropriate tools for a given task.
{
"role": "system",
"content": """You are an expert evaluator assessing if an AI agent selected the most appropriate tools for a given task.
You must evaluate based on these 2 criteria:
1. Relevance (0-10): Were the tools chosen directly aligned with the task's goals?
@@ -73,8 +77,11 @@ Return your evaluation as JSON with these fields:
- overall_score: number (average of all scores, 0-10)
- feedback: string (focused ONLY on tool selection decisions from available tools)
- improvement_suggestions: string (ONLY suggest better selection from the AVAILABLE tools list, NOT new tools)
"""},
{"role": "user", "content": f"""
""",
},
{
"role": "user",
"content": f"""
Agent role: {agent.role}
{task_context}
@@ -89,7 +96,8 @@ IMPORTANT:
- ONLY evaluate selection from tools listed as available
- DO NOT suggest new tools that aren't in the available tools list
- DO NOT evaluate tool usage or results
"""}
""",
},
]
assert self.llm is not None
response = self.llm.call(prompt)
@@ -105,22 +113,24 @@ IMPORTANT:
feedback = "Tool Selection Evaluation:\n"
feedback += f"• Relevance: {relevance}/10 - Selection of appropriate tool types for the task\n"
feedback += f"• Coverage: {coverage}/10 - Selection of all necessary tool types\n"
feedback += (
f"• Coverage: {coverage}/10 - Selection of all necessary tool types\n"
)
if "improvement_suggestions" in evaluation_data:
feedback += f"Improvement Suggestions:\n{evaluation_data['improvement_suggestions']}"
else:
feedback += evaluation_data.get("feedback", "No detailed feedback available.")
feedback += evaluation_data.get(
"feedback", "No detailed feedback available."
)
return EvaluationScore(
score=overall_score,
feedback=feedback,
raw_response=response
score=overall_score, feedback=feedback, raw_response=response
)
except Exception as e:
return EvaluationScore(
score=None,
feedback=f"Error evaluating tool selection: {e}",
raw_response=response
raw_response=response,
)
@@ -132,7 +142,7 @@ class ParameterExtractionEvaluator(BaseEvaluator):
def evaluate(
self,
agent: Agent,
execution_trace: Dict[str, Any],
execution_trace: dict[str, Any],
final_output: str,
task: Task | None = None,
) -> EvaluationScore:
@@ -145,19 +155,26 @@ class ParameterExtractionEvaluator(BaseEvaluator):
if tool_count == 0:
return EvaluationScore(
score=None,
feedback="No tool usage detected. Cannot evaluate parameter extraction."
feedback="No tool usage detected. Cannot evaluate parameter extraction.",
)
validation_errors = []
for tool_use in tool_uses:
if not tool_use.get("success", True) and tool_use.get("error_type") == "validation_error":
validation_errors.append({
"tool": tool_use.get("tool", "Unknown tool"),
"error": tool_use.get("result"),
"args": tool_use.get("args", {})
})
if (
not tool_use.get("success", True)
and tool_use.get("error_type") == "validation_error"
):
validation_errors.append(
{
"tool": tool_use.get("tool", "Unknown tool"),
"error": tool_use.get("result"),
"args": tool_use.get("args", {}),
}
)
validation_error_rate = len(validation_errors) / tool_count if tool_count > 0 else 0
validation_error_rate = (
len(validation_errors) / tool_count if tool_count > 0 else 0
)
param_samples = []
for i, tool_use in enumerate(tool_uses[:5]):
@@ -168,7 +185,7 @@ class ParameterExtractionEvaluator(BaseEvaluator):
is_validation_error = error_type == "validation_error"
sample = f"Tool use #{i+1} - {tool_name}:\n"
sample = f"Tool use #{i + 1} - {tool_name}:\n"
sample += f"- Parameters: {json.dumps(tool_args, indent=2)}\n"
sample += f"- Success: {'No' if not success else 'Yes'}"
@@ -187,13 +204,17 @@ class ParameterExtractionEvaluator(BaseEvaluator):
tool_name = err.get("tool", "Unknown tool")
error_msg = err.get("error", "Unknown error")
args = err.get("args", {})
validation_errors_info += f"\nValidation Error #{i+1}:\n- Tool: {tool_name}\n- Args: {json.dumps(args, indent=2)}\n- Error: {error_msg}"
validation_errors_info += f"\nValidation Error #{i + 1}:\n- Tool: {tool_name}\n- Args: {json.dumps(args, indent=2)}\n- Error: {error_msg}"
if len(validation_errors) > 3:
validation_errors_info += f"\n...and {len(validation_errors) - 3} more validation errors."
validation_errors_info += (
f"\n...and {len(validation_errors) - 3} more validation errors."
)
param_samples_text = "\n\n".join(param_samples)
prompt = [
{"role": "system", "content": """You are an expert evaluator assessing how well an AI agent extracts and formats PARAMETER VALUES for tool calls.
{
"role": "system",
"content": """You are an expert evaluator assessing how well an AI agent extracts and formats PARAMETER VALUES for tool calls.
Your job is to evaluate ONLY whether the agent used the correct parameter VALUES, not whether the right tools were selected or how the tools were invoked.
@@ -216,8 +237,11 @@ Return your evaluation as JSON with these fields:
- overall_score: number (average of all scores, 0-10)
- feedback: string (focused ONLY on parameter value extraction quality)
- improvement_suggestions: string (concrete suggestions for better parameter VALUE extraction)
"""},
{"role": "user", "content": f"""
""",
},
{
"role": "user",
"content": f"""
Agent role: {agent.role}
{task_context}
@@ -226,7 +250,8 @@ Parameter extraction examples:
{validation_errors_info}
Evaluate the quality of the agent's parameter extraction for this task.
"""}
""",
},
]
assert self.llm is not None
@@ -251,18 +276,18 @@ Evaluate the quality of the agent's parameter extraction for this task.
if "improvement_suggestions" in evaluation_data:
feedback += f"Improvement Suggestions:\n{evaluation_data['improvement_suggestions']}"
else:
feedback += evaluation_data.get("feedback", "No detailed feedback available.")
feedback += evaluation_data.get(
"feedback", "No detailed feedback available."
)
return EvaluationScore(
score=overall_score,
feedback=feedback,
raw_response=response
score=overall_score, feedback=feedback, raw_response=response
)
except Exception as e:
return EvaluationScore(
score=None,
feedback=f"Error evaluating parameter extraction: {e}",
raw_response=response
raw_response=response,
)
@@ -274,7 +299,7 @@ class ToolInvocationEvaluator(BaseEvaluator):
def evaluate(
self,
agent: Agent,
execution_trace: Dict[str, Any],
execution_trace: dict[str, Any],
final_output: str,
task: Task | None = None,
) -> EvaluationScore:
@@ -288,7 +313,7 @@ class ToolInvocationEvaluator(BaseEvaluator):
if tool_count == 0:
return EvaluationScore(
score=None,
feedback="No tool usage detected. Cannot evaluate tool invocation."
feedback="No tool usage detected. Cannot evaluate tool invocation.",
)
for tool_use in tool_uses:
@@ -296,7 +321,7 @@ class ToolInvocationEvaluator(BaseEvaluator):
error_info = {
"tool": tool_use.get("tool", "Unknown tool"),
"error": tool_use.get("result"),
"error_type": tool_use.get("error_type", "unknown_error")
"error_type": tool_use.get("error_type", "unknown_error"),
}
tool_errors.append(error_info)
@@ -315,9 +340,11 @@ class ToolInvocationEvaluator(BaseEvaluator):
tool_args = tool_use.get("args", {})
success = tool_use.get("success", True) and not tool_use.get("error", False)
error_type = tool_use.get("error_type", "") if not success else ""
error_msg = tool_use.get("result", "No error") if not success else "No error"
error_msg = (
tool_use.get("result", "No error") if not success else "No error"
)
sample = f"Tool invocation #{i+1}:\n"
sample = f"Tool invocation #{i + 1}:\n"
sample += f"- Tool: {tool_name}\n"
sample += f"- Parameters: {json.dumps(tool_args, indent=2)}\n"
sample += f"- Success: {'No' if not success else 'Yes'}\n"
@@ -330,11 +357,13 @@ class ToolInvocationEvaluator(BaseEvaluator):
if error_types:
error_type_summary = "Error type breakdown:\n"
for error_type, count in error_types.items():
error_type_summary += f"- {error_type}: {count} occurrences ({(count/tool_count):.1%})\n"
error_type_summary += f"- {error_type}: {count} occurrences ({(count / tool_count):.1%})\n"
invocation_samples_text = "\n\n".join(invocation_samples)
prompt = [
{"role": "system", "content": """You are an expert evaluator assessing how correctly an AI agent's tool invocations are STRUCTURED.
{
"role": "system",
"content": """You are an expert evaluator assessing how correctly an AI agent's tool invocations are STRUCTURED.
Your job is to evaluate ONLY the structural and syntactical aspects of how the agent called tools, NOT which tools were selected or what parameter values were used.
@@ -359,8 +388,11 @@ Return your evaluation as JSON with these fields:
- overall_score: number (average of all scores, 0-10)
- feedback: string (focused ONLY on structural aspects of tool invocation)
- improvement_suggestions: string (concrete suggestions for better structuring of tool calls)
"""},
{"role": "user", "content": f"""
""",
},
{
"role": "user",
"content": f"""
Agent role: {agent.role}
{task_context}
@@ -371,7 +403,8 @@ Tool error rate: {error_rate:.2%} ({len(tool_errors)} errors out of {tool_count}
{error_type_summary}
Evaluate the quality of the agent's tool invocation structure during this task.
"""}
""",
},
]
assert self.llm is not None
@@ -388,23 +421,25 @@ Evaluate the quality of the agent's tool invocation structure during this task.
overall_score = float(evaluation_data.get("overall_score", 5.0))
feedback = "Tool Invocation Evaluation:\n"
feedback += f"• Structure: {structure}/10 - Following proper syntax and format\n"
feedback += (
f"• Structure: {structure}/10 - Following proper syntax and format\n"
)
feedback += f"• Error Handling: {error_handling}/10 - Appropriately handling tool errors\n"
feedback += f"• Invocation Patterns: {invocation_patterns}/10 - Proper sequencing and management of calls\n\n"
if "improvement_suggestions" in evaluation_data:
feedback += f"Improvement Suggestions:\n{evaluation_data['improvement_suggestions']}"
else:
feedback += evaluation_data.get("feedback", "No detailed feedback available.")
feedback += evaluation_data.get(
"feedback", "No detailed feedback available."
)
return EvaluationScore(
score=overall_score,
feedback=feedback,
raw_response=response
score=overall_score, feedback=feedback, raw_response=response
)
except Exception as e:
return EvaluationScore(
score=None,
feedback=f"Error evaluating tool invocation: {e}",
raw_response=response
raw_response=response,
)

View File

@@ -1,12 +1,21 @@
import inspect
import warnings
from typing_extensions import Any
import warnings
from crewai.experimental.evaluation.experiment import ExperimentResults, ExperimentRunner
from crewai import Crew, Agent
def assert_experiment_successfully(experiment_results: ExperimentResults, baseline_filepath: str | None = None) -> None:
failed_tests = [result for result in experiment_results.results if not result.passed]
from crewai import Agent, Crew
from crewai.experimental.evaluation.experiment import (
ExperimentResults,
ExperimentRunner,
)
def assert_experiment_successfully(
experiment_results: ExperimentResults, baseline_filepath: str | None = None
) -> None:
failed_tests = [
result for result in experiment_results.results if not result.passed
]
if failed_tests:
detailed_failures: list[str] = []
@@ -14,39 +23,54 @@ def assert_experiment_successfully(experiment_results: ExperimentResults, baseli
for result in failed_tests:
expected = result.expected_score
actual = result.score
detailed_failures.append(f"- {result.identifier}: expected {expected}, got {actual}")
detailed_failures.append(
f"- {result.identifier}: expected {expected}, got {actual}"
)
failure_details = "\n".join(detailed_failures)
raise AssertionError(f"The following test cases failed:\n{failure_details}")
baseline_filepath = baseline_filepath or _get_baseline_filepath_fallback()
comparison = experiment_results.compare_with_baseline(baseline_filepath=baseline_filepath)
comparison = experiment_results.compare_with_baseline(
baseline_filepath=baseline_filepath
)
assert_experiment_no_regression(comparison)
def assert_experiment_no_regression(comparison_result: dict[str, list[str]]) -> None:
regressed = comparison_result.get("regressed", [])
if regressed:
raise AssertionError(f"Regression detected! The following tests that previously passed now fail: {regressed}")
raise AssertionError(
f"Regression detected! The following tests that previously passed now fail: {regressed}"
)
missing_tests = comparison_result.get("missing_tests", [])
if missing_tests:
warnings.warn(
f"Warning: {len(missing_tests)} tests from the baseline are missing in the current run: {missing_tests}",
UserWarning
UserWarning,
stacklevel=2,
)
def run_experiment(dataset: list[dict[str, Any]], crew: Crew | None = None, agents: list[Agent] | None = None, verbose: bool = False) -> ExperimentResults:
def run_experiment(
dataset: list[dict[str, Any]],
crew: Crew | None = None,
agents: list[Agent] | None = None,
verbose: bool = False,
) -> ExperimentResults:
runner = ExperimentRunner(dataset=dataset)
return runner.run(agents=agents, crew=crew, print_summary=verbose)
def _get_baseline_filepath_fallback() -> str:
test_func_name = "experiment_fallback"
try:
current_frame = inspect.currentframe()
if current_frame is not None:
test_func_name = current_frame.f_back.f_back.f_code.co_name # type: ignore[union-attr]
test_func_name = current_frame.f_back.f_back.f_code.co_name # type: ignore[union-attr]
except Exception:
...
return f"{test_func_name}_results.json"
return f"{test_func_name}_results.json"

View File

@@ -1,5 +1,4 @@
from crewai.flow.flow import Flow, start, listen, or_, and_, router
from crewai.flow.flow import Flow, and_, listen, or_, router, start
from crewai.flow.persistence import persist
__all__ = ["Flow", "start", "listen", "or_", "and_", "router", "persist"]
__all__ = ["Flow", "and_", "listen", "or_", "persist", "router", "start"]

View File

@@ -1086,7 +1086,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
for method_name in self._start_methods:
# Check if this start method is triggered by the current trigger
if method_name in self._listeners:
condition_type, trigger_methods = self._listeners[
_condition_type, trigger_methods = self._listeners[
method_name
]
if current_trigger in trigger_methods:

View File

@@ -1,5 +1,4 @@
import inspect
from typing import Optional
from pydantic import BaseModel, Field, InstanceOf, model_validator
@@ -14,7 +13,7 @@ class FlowTrackable(BaseModel):
inspecting the call stack.
"""
parent_flow: Optional[InstanceOf[Flow]] = Field(
parent_flow: InstanceOf[Flow] | None = Field(
default=None,
description="The parent flow of the instance, if it was created inside a flow.",
)

View File

@@ -1,14 +1,13 @@
# flow_visualizer.py
import os
from pathlib import Path
from pyvis.network import Network
from crewai.flow.config import COLORS, NODE_STYLES
from crewai.flow.html_template_handler import HTMLTemplateHandler
from crewai.flow.legend_generator import generate_legend_items_html, get_legend_items
from crewai.flow.path_utils import safe_path_join, validate_path_exists
from crewai.flow.path_utils import safe_path_join
from crewai.flow.utils import calculate_node_levels
from crewai.flow.visualization_utils import (
add_edges,
@@ -34,13 +33,13 @@ class FlowPlot:
ValueError
If flow object is invalid or missing required attributes.
"""
if not hasattr(flow, '_methods'):
if not hasattr(flow, "_methods"):
raise ValueError("Invalid flow object: missing '_methods' attribute")
if not hasattr(flow, '_listeners'):
if not hasattr(flow, "_listeners"):
raise ValueError("Invalid flow object: missing '_listeners' attribute")
if not hasattr(flow, '_start_methods'):
if not hasattr(flow, "_start_methods"):
raise ValueError("Invalid flow object: missing '_start_methods' attribute")
self.flow = flow
self.colors = COLORS
self.node_styles = NODE_STYLES
@@ -65,7 +64,7 @@ class FlowPlot:
"""
if not filename or not isinstance(filename, str):
raise ValueError("Filename must be a non-empty string")
try:
# Initialize network
net = Network(
@@ -96,32 +95,32 @@ class FlowPlot:
try:
node_levels = calculate_node_levels(self.flow)
except Exception as e:
raise ValueError(f"Failed to calculate node levels: {str(e)}")
raise ValueError(f"Failed to calculate node levels: {e!s}")
# Compute positions
try:
node_positions = compute_positions(self.flow, node_levels)
except Exception as e:
raise ValueError(f"Failed to compute node positions: {str(e)}")
raise ValueError(f"Failed to compute node positions: {e!s}")
# Add nodes to the network
try:
add_nodes_to_network(net, self.flow, node_positions, self.node_styles)
except Exception as e:
raise RuntimeError(f"Failed to add nodes to network: {str(e)}")
raise RuntimeError(f"Failed to add nodes to network: {e!s}")
# Add edges to the network
try:
add_edges(net, self.flow, node_positions, self.colors)
except Exception as e:
raise RuntimeError(f"Failed to add edges to network: {str(e)}")
raise RuntimeError(f"Failed to add edges to network: {e!s}")
# Generate HTML
try:
network_html = net.generate_html()
final_html_content = self._generate_final_html(network_html)
except Exception as e:
raise RuntimeError(f"Failed to generate network visualization: {str(e)}")
raise RuntimeError(f"Failed to generate network visualization: {e!s}")
# Save the final HTML content to the file
try:
@@ -129,12 +128,14 @@ class FlowPlot:
f.write(final_html_content)
print(f"Plot saved as {filename}.html")
except IOError as e:
raise IOError(f"Failed to save flow visualization to {filename}.html: {str(e)}")
raise IOError(
f"Failed to save flow visualization to {filename}.html: {e!s}"
)
except (ValueError, RuntimeError, IOError) as e:
raise e
except Exception as e:
raise RuntimeError(f"Unexpected error during flow visualization: {str(e)}")
raise RuntimeError(f"Unexpected error during flow visualization: {e!s}")
finally:
self._cleanup_pyvis_lib()
@@ -165,7 +166,9 @@ class FlowPlot:
try:
# Extract just the body content from the generated HTML
current_dir = os.path.dirname(__file__)
template_path = safe_path_join("assets", "crewai_flow_visual_template.html", root=current_dir)
template_path = safe_path_join(
"assets", "crewai_flow_visual_template.html", root=current_dir
)
logo_path = safe_path_join("assets", "crewai_logo.svg", root=current_dir)
if not os.path.exists(template_path):
@@ -179,12 +182,9 @@ class FlowPlot:
# Generate the legend items HTML
legend_items = get_legend_items(self.colors)
legend_items_html = generate_legend_items_html(legend_items)
final_html_content = html_handler.generate_final_html(
network_body, legend_items_html
)
return final_html_content
return html_handler.generate_final_html(network_body, legend_items_html)
except Exception as e:
raise IOError(f"Failed to generate visualization HTML: {str(e)}")
raise IOError(f"Failed to generate visualization HTML: {e!s}")
def _cleanup_pyvis_lib(self):
"""
@@ -197,6 +197,7 @@ class FlowPlot:
lib_folder = safe_path_join("lib", root=os.getcwd())
if os.path.exists(lib_folder) and os.path.isdir(lib_folder):
import shutil
shutil.rmtree(lib_folder)
except ValueError as e:
print(f"Error validating lib folder path: {e}")

View File

@@ -1,8 +1,7 @@
import base64
import re
from pathlib import Path
from crewai.flow.path_utils import safe_path_join, validate_path_exists
from crewai.flow.path_utils import validate_path_exists
class HTMLTemplateHandler:
@@ -53,23 +52,23 @@ class HTMLTemplateHandler:
if "border" in item:
legend_items_html += f"""
<div class="legend-item">
<div class="legend-color-box" style="background-color: {item['color']}; border: 2px dashed {item['border']};"></div>
<div>{item['label']}</div>
<div class="legend-color-box" style="background-color: {item["color"]}; border: 2px dashed {item["border"]};"></div>
<div>{item["label"]}</div>
</div>
"""
elif item.get("dashed") is not None:
style = "dashed" if item["dashed"] else "solid"
legend_items_html += f"""
<div class="legend-item">
<div class="legend-{style}" style="border-bottom: 2px {style} {item['color']};"></div>
<div>{item['label']}</div>
<div class="legend-{style}" style="border-bottom: 2px {style} {item["color"]};"></div>
<div>{item["label"]}</div>
</div>
"""
else:
legend_items_html += f"""
<div class="legend-item">
<div class="legend-color-box" style="background-color: {item['color']};"></div>
<div>{item['label']}</div>
<div class="legend-color-box" style="background-color: {item["color"]};"></div>
<div>{item["label"]}</div>
</div>
"""
return legend_items_html
@@ -86,8 +85,6 @@ class HTMLTemplateHandler:
final_html_content = final_html_content.replace(
"{{ logo_svg_base64 }}", logo_svg_base64
)
final_html_content = final_html_content.replace(
return final_html_content.replace(
"<!-- LEGEND_ITEMS_PLACEHOLDER -->", legend_items_html
)
return final_html_content

View File

@@ -5,12 +5,10 @@ This module provides utilities for secure path handling to prevent directory
traversal attacks and ensure paths remain within allowed boundaries.
"""
import os
from pathlib import Path
from typing import List, Union
def safe_path_join(*parts: str, root: Union[str, Path, None] = None) -> str:
def safe_path_join(*parts: str, root: str | Path | None = None) -> str:
"""
Safely join path components and ensure the result is within allowed boundaries.
@@ -43,25 +41,25 @@ def safe_path_join(*parts: str, root: Union[str, Path, None] = None) -> str:
# Establish root directory
root_path = Path(root).resolve() if root else Path.cwd()
# Join and resolve the full path
full_path = Path(root_path, *clean_parts).resolve()
# Check if the resolved path is within root
if not str(full_path).startswith(str(root_path)):
raise ValueError(
f"Invalid path: Potential directory traversal. Path must be within {root_path}"
)
return str(full_path)
except Exception as e:
if isinstance(e, ValueError):
raise
raise ValueError(f"Invalid path components: {str(e)}")
raise ValueError(f"Invalid path components: {e!s}")
def validate_path_exists(path: Union[str, Path], file_type: str = "file") -> str:
def validate_path_exists(path: str | Path, file_type: str = "file") -> str:
"""
Validate that a path exists and is of the expected type.
@@ -84,24 +82,24 @@ def validate_path_exists(path: Union[str, Path], file_type: str = "file") -> str
"""
try:
path_obj = Path(path).resolve()
if not path_obj.exists():
raise ValueError(f"Path does not exist: {path}")
if file_type == "file" and not path_obj.is_file():
raise ValueError(f"Path is not a file: {path}")
elif file_type == "directory" and not path_obj.is_dir():
if file_type == "directory" and not path_obj.is_dir():
raise ValueError(f"Path is not a directory: {path}")
return str(path_obj)
except Exception as e:
if isinstance(e, ValueError):
raise
raise ValueError(f"Invalid path: {str(e)}")
raise ValueError(f"Invalid path: {e!s}")
def list_files(directory: Union[str, Path], pattern: str = "*") -> List[str]:
def list_files(directory: str | Path, pattern: str = "*") -> list[str]:
"""
Safely list files in a directory matching a pattern.
@@ -126,10 +124,10 @@ def list_files(directory: Union[str, Path], pattern: str = "*") -> List[str]:
dir_path = Path(directory).resolve()
if not dir_path.is_dir():
raise ValueError(f"Not a directory: {directory}")
return [str(p) for p in dir_path.glob(pattern) if p.is_file()]
except Exception as e:
if isinstance(e, ValueError):
raise
raise ValueError(f"Error listing files: {str(e)}")
raise ValueError(f"Error listing files: {e!s}")

View File

@@ -12,7 +12,7 @@ from crewai.flow.persistence.base import FlowPersistence
from crewai.flow.persistence.decorators import persist
from crewai.flow.persistence.sqlite import SQLiteFlowPersistence
__all__ = ["FlowPersistence", "persist", "SQLiteFlowPersistence"]
__all__ = ["FlowPersistence", "SQLiteFlowPersistence", "persist"]
StateType = TypeVar('StateType', bound=Union[Dict[str, Any], BaseModel])
DictStateType = Dict[str, Any]
StateType = TypeVar("StateType", bound=dict[str, Any] | BaseModel)
DictStateType = dict[str, Any]

View File

@@ -1,53 +1,47 @@
"""Base class for flow state persistence."""
import abc
from typing import Any, Dict, Optional, Union
from typing import Any
from pydantic import BaseModel
class FlowPersistence(abc.ABC):
"""Abstract base class for flow state persistence.
This class defines the interface that all persistence implementations must follow.
It supports both structured (Pydantic BaseModel) and unstructured (dict) states.
"""
@abc.abstractmethod
def init_db(self) -> None:
"""Initialize the persistence backend.
This method should handle any necessary setup, such as:
- Creating tables
- Establishing connections
- Setting up indexes
"""
pass
@abc.abstractmethod
def save_state(
self,
flow_uuid: str,
method_name: str,
state_data: Union[Dict[str, Any], BaseModel]
self, flow_uuid: str, method_name: str, state_data: dict[str, Any] | BaseModel
) -> None:
"""Persist the flow state after method completion.
Args:
flow_uuid: Unique identifier for the flow instance
method_name: Name of the method that just completed
state_data: Current state data (either dict or Pydantic model)
"""
pass
@abc.abstractmethod
def load_state(self, flow_uuid: str) -> Optional[Dict[str, Any]]:
def load_state(self, flow_uuid: str) -> dict[str, Any] | None:
"""Load the most recent state for a given flow UUID.
Args:
flow_uuid: Unique identifier for the flow instance
Returns:
The most recent state as a dictionary, or None if no state exists
"""
pass

View File

@@ -24,13 +24,10 @@ Example:
import asyncio
import functools
import logging
from collections.abc import Callable
from typing import (
Any,
Callable,
Optional,
Type,
TypeVar,
Union,
cast,
)
@@ -48,7 +45,7 @@ LOG_MESSAGES = {
"save_state": "Saving flow state to memory for ID: {}",
"save_error": "Failed to persist state for method {}: {}",
"state_missing": "Flow instance has no state",
"id_missing": "Flow state must have an 'id' field for persistence"
"id_missing": "Flow state must have an 'id' field for persistence",
}
@@ -58,7 +55,13 @@ class PersistenceDecorator:
_printer = Printer() # Class-level printer instance
@classmethod
def persist_state(cls, flow_instance: Any, method_name: str, persistence_instance: FlowPersistence, verbose: bool = False) -> None:
def persist_state(
cls,
flow_instance: Any,
method_name: str,
persistence_instance: FlowPersistence,
verbose: bool = False,
) -> None:
"""Persist flow state with proper error handling and logging.
This method handles the persistence of flow state data, including proper
@@ -76,22 +79,24 @@ class PersistenceDecorator:
AttributeError: If flow instance lacks required state attributes
"""
try:
state = getattr(flow_instance, 'state', None)
state = getattr(flow_instance, "state", None)
if state is None:
raise ValueError("Flow instance has no state")
flow_uuid: Optional[str] = None
flow_uuid: str | None = None
if isinstance(state, dict):
flow_uuid = state.get('id')
flow_uuid = state.get("id")
elif isinstance(state, BaseModel):
flow_uuid = getattr(state, 'id', None)
flow_uuid = getattr(state, "id", None)
if not flow_uuid:
raise ValueError("Flow state must have an 'id' field for persistence")
# Log state saving only if verbose is True
if verbose:
cls._printer.print(LOG_MESSAGES["save_state"].format(flow_uuid), color="cyan")
cls._printer.print(
LOG_MESSAGES["save_state"].format(flow_uuid), color="cyan"
)
logger.info(LOG_MESSAGES["save_state"].format(flow_uuid))
try:
@@ -104,7 +109,7 @@ class PersistenceDecorator:
error_msg = LOG_MESSAGES["save_error"].format(method_name, str(e))
cls._printer.print(error_msg, color="red")
logger.error(error_msg)
raise RuntimeError(f"State persistence failed: {str(e)}") from e
raise RuntimeError(f"State persistence failed: {e!s}") from e
except AttributeError:
error_msg = LOG_MESSAGES["state_missing"]
cls._printer.print(error_msg, color="red")
@@ -117,7 +122,7 @@ class PersistenceDecorator:
raise ValueError(error_msg) from e
def persist(persistence: Optional[FlowPersistence] = None, verbose: bool = False):
def persist(persistence: FlowPersistence | None = None, verbose: bool = False):
"""Decorator to persist flow state.
This decorator can be applied at either the class level or method level.
@@ -144,32 +149,33 @@ def persist(persistence: Optional[FlowPersistence] = None, verbose: bool = False
def begin(self):
pass
"""
def decorator(target: Union[Type, Callable[..., T]]) -> Union[Type, Callable[..., T]]:
def decorator(target: type | Callable[..., T]) -> type | Callable[..., T]:
"""Decorator that handles both class and method decoration."""
actual_persistence = persistence or SQLiteFlowPersistence()
if isinstance(target, type):
# Class decoration
original_init = getattr(target, "__init__")
original_init = target.__init__
@functools.wraps(original_init)
def new_init(self: Any, *args: Any, **kwargs: Any) -> None:
if 'persistence' not in kwargs:
kwargs['persistence'] = actual_persistence
if "persistence" not in kwargs:
kwargs["persistence"] = actual_persistence
original_init(self, *args, **kwargs)
setattr(target, "__init__", new_init)
target.__init__ = new_init
# Store original methods to preserve their decorators
original_methods = {}
for name, method in target.__dict__.items():
if callable(method) and (
hasattr(method, "__is_start_method__") or
hasattr(method, "__trigger_methods__") or
hasattr(method, "__condition_type__") or
hasattr(method, "__is_flow_method__") or
hasattr(method, "__is_router__")
hasattr(method, "__is_start_method__")
or hasattr(method, "__trigger_methods__")
or hasattr(method, "__condition_type__")
or hasattr(method, "__is_flow_method__")
or hasattr(method, "__is_router__")
):
original_methods[name] = method
@@ -177,78 +183,116 @@ def persist(persistence: Optional[FlowPersistence] = None, verbose: bool = False
for name, method in original_methods.items():
if asyncio.iscoroutinefunction(method):
# Create a closure to capture the current name and method
def create_async_wrapper(method_name: str, original_method: Callable):
def create_async_wrapper(
method_name: str, original_method: Callable
):
@functools.wraps(original_method)
async def method_wrapper(self: Any, *args: Any, **kwargs: Any) -> Any:
async def method_wrapper(
self: Any, *args: Any, **kwargs: Any
) -> Any:
result = await original_method(self, *args, **kwargs)
PersistenceDecorator.persist_state(self, method_name, actual_persistence, verbose)
PersistenceDecorator.persist_state(
self, method_name, actual_persistence, verbose
)
return result
return method_wrapper
wrapped = create_async_wrapper(name, method)
# Preserve all original decorators and attributes
for attr in ["__is_start_method__", "__trigger_methods__", "__condition_type__", "__is_router__"]:
for attr in [
"__is_start_method__",
"__trigger_methods__",
"__condition_type__",
"__is_router__",
]:
if hasattr(method, attr):
setattr(wrapped, attr, getattr(method, attr))
setattr(wrapped, "__is_flow_method__", True)
wrapped.__is_flow_method__ = True
# Update the class with the wrapped method
setattr(target, name, wrapped)
else:
# Create a closure to capture the current name and method
def create_sync_wrapper(method_name: str, original_method: Callable):
def create_sync_wrapper(
method_name: str, original_method: Callable
):
@functools.wraps(original_method)
def method_wrapper(self: Any, *args: Any, **kwargs: Any) -> Any:
result = original_method(self, *args, **kwargs)
PersistenceDecorator.persist_state(self, method_name, actual_persistence, verbose)
PersistenceDecorator.persist_state(
self, method_name, actual_persistence, verbose
)
return result
return method_wrapper
wrapped = create_sync_wrapper(name, method)
# Preserve all original decorators and attributes
for attr in ["__is_start_method__", "__trigger_methods__", "__condition_type__", "__is_router__"]:
for attr in [
"__is_start_method__",
"__trigger_methods__",
"__condition_type__",
"__is_router__",
]:
if hasattr(method, attr):
setattr(wrapped, attr, getattr(method, attr))
setattr(wrapped, "__is_flow_method__", True)
wrapped.__is_flow_method__ = True
# Update the class with the wrapped method
setattr(target, name, wrapped)
return target
else:
# Method decoration
method = target
setattr(method, "__is_flow_method__", True)
# Method decoration
method = target
method.__is_flow_method__ = True
if asyncio.iscoroutinefunction(method):
@functools.wraps(method)
async def method_async_wrapper(flow_instance: Any, *args: Any, **kwargs: Any) -> T:
method_coro = method(flow_instance, *args, **kwargs)
if asyncio.iscoroutine(method_coro):
result = await method_coro
else:
result = method_coro
PersistenceDecorator.persist_state(flow_instance, method.__name__, actual_persistence, verbose)
return result
if asyncio.iscoroutinefunction(method):
for attr in ["__is_start_method__", "__trigger_methods__", "__condition_type__", "__is_router__"]:
if hasattr(method, attr):
setattr(method_async_wrapper, attr, getattr(method, attr))
setattr(method_async_wrapper, "__is_flow_method__", True)
return cast(Callable[..., T], method_async_wrapper)
else:
@functools.wraps(method)
def method_sync_wrapper(flow_instance: Any, *args: Any, **kwargs: Any) -> T:
result = method(flow_instance, *args, **kwargs)
PersistenceDecorator.persist_state(flow_instance, method.__name__, actual_persistence, verbose)
return result
@functools.wraps(method)
async def method_async_wrapper(
flow_instance: Any, *args: Any, **kwargs: Any
) -> T:
method_coro = method(flow_instance, *args, **kwargs)
if asyncio.iscoroutine(method_coro):
result = await method_coro
else:
result = method_coro
PersistenceDecorator.persist_state(
flow_instance, method.__name__, actual_persistence, verbose
)
return result
for attr in ["__is_start_method__", "__trigger_methods__", "__condition_type__", "__is_router__"]:
if hasattr(method, attr):
setattr(method_sync_wrapper, attr, getattr(method, attr))
setattr(method_sync_wrapper, "__is_flow_method__", True)
return cast(Callable[..., T], method_sync_wrapper)
for attr in [
"__is_start_method__",
"__trigger_methods__",
"__condition_type__",
"__is_router__",
]:
if hasattr(method, attr):
setattr(method_async_wrapper, attr, getattr(method, attr))
method_async_wrapper.__is_flow_method__ = True
return cast(Callable[..., T], method_async_wrapper)
@functools.wraps(method)
def method_sync_wrapper(flow_instance: Any, *args: Any, **kwargs: Any) -> T:
result = method(flow_instance, *args, **kwargs)
PersistenceDecorator.persist_state(
flow_instance, method.__name__, actual_persistence, verbose
)
return result
for attr in [
"__is_start_method__",
"__trigger_methods__",
"__condition_type__",
"__is_router__",
]:
if hasattr(method, attr):
setattr(method_sync_wrapper, attr, getattr(method, attr))
method_sync_wrapper.__is_flow_method__ = True
return cast(Callable[..., T], method_sync_wrapper)
return decorator

View File

@@ -6,7 +6,7 @@ import json
import sqlite3
from datetime import datetime, timezone
from pathlib import Path
from typing import Any, Dict, Optional, Union
from typing import Any
from pydantic import BaseModel
@@ -23,7 +23,7 @@ class SQLiteFlowPersistence(FlowPersistence):
db_path: str
def __init__(self, db_path: Optional[str] = None):
def __init__(self, db_path: str | None = None):
"""Initialize SQLite persistence.
Args:
@@ -70,7 +70,7 @@ class SQLiteFlowPersistence(FlowPersistence):
self,
flow_uuid: str,
method_name: str,
state_data: Union[Dict[str, Any], BaseModel],
state_data: dict[str, Any] | BaseModel,
) -> None:
"""Save the current flow state to SQLite.
@@ -107,7 +107,7 @@ class SQLiteFlowPersistence(FlowPersistence):
),
)
def load_state(self, flow_uuid: str) -> Optional[Dict[str, Any]]:
def load_state(self, flow_uuid: str) -> dict[str, Any] | None:
"""Load the most recent state for a given flow UUID.
Args:

View File

@@ -5,6 +5,7 @@ the Flow system.
"""
from typing import Any, TypedDict
from typing_extensions import NotRequired, Required

View File

@@ -17,10 +17,10 @@ import ast
import inspect
import textwrap
from collections import defaultdict, deque
from typing import Any, Deque, Dict, List, Optional, Set, Union
from typing import Any
def get_possible_return_constants(function: Any) -> Optional[List[str]]:
def get_possible_return_constants(function: Any) -> list[str] | None:
try:
source = inspect.getsource(function)
except OSError:
@@ -94,7 +94,7 @@ def get_possible_return_constants(function: Any) -> Optional[List[str]]:
return list(return_values) if return_values else None
def calculate_node_levels(flow: Any) -> Dict[str, int]:
def calculate_node_levels(flow: Any) -> dict[str, int]:
"""
Calculate the hierarchical level of each node in the flow.
@@ -118,10 +118,10 @@ def calculate_node_levels(flow: Any) -> Dict[str, int]:
- Handles both OR and AND conditions for listeners
- Processes router paths separately
"""
levels: Dict[str, int] = {}
queue: Deque[str] = deque()
visited: Set[str] = set()
pending_and_listeners: Dict[str, Set[str]] = {}
levels: dict[str, int] = {}
queue: deque[str] = deque()
visited: set[str] = set()
pending_and_listeners: dict[str, set[str]] = {}
# Make all start methods at level 0
for method_name, method in flow._methods.items():
@@ -172,7 +172,7 @@ def calculate_node_levels(flow: Any) -> Dict[str, int]:
return levels
def count_outgoing_edges(flow: Any) -> Dict[str, int]:
def count_outgoing_edges(flow: Any) -> dict[str, int]:
"""
Count the number of outgoing edges for each method in the flow.
@@ -197,7 +197,7 @@ def count_outgoing_edges(flow: Any) -> Dict[str, int]:
return counts
def build_ancestor_dict(flow: Any) -> Dict[str, Set[str]]:
def build_ancestor_dict(flow: Any) -> dict[str, set[str]]:
"""
Build a dictionary mapping each node to its ancestor nodes.
@@ -211,8 +211,8 @@ def build_ancestor_dict(flow: Any) -> Dict[str, Set[str]]:
Dict[str, Set[str]]
Dictionary mapping each node to a set of its ancestor nodes.
"""
ancestors: Dict[str, Set[str]] = {node: set() for node in flow._methods}
visited: Set[str] = set()
ancestors: dict[str, set[str]] = {node: set() for node in flow._methods}
visited: set[str] = set()
for node in flow._methods:
if node not in visited:
dfs_ancestors(node, ancestors, visited, flow)
@@ -220,7 +220,7 @@ def build_ancestor_dict(flow: Any) -> Dict[str, Set[str]]:
def dfs_ancestors(
node: str, ancestors: Dict[str, Set[str]], visited: Set[str], flow: Any
node: str, ancestors: dict[str, set[str]], visited: set[str], flow: Any
) -> None:
"""
Perform depth-first search to build ancestor relationships.
@@ -265,7 +265,7 @@ def dfs_ancestors(
def is_ancestor(
node: str, ancestor_candidate: str, ancestors: Dict[str, Set[str]]
node: str, ancestor_candidate: str, ancestors: dict[str, set[str]]
) -> bool:
"""
Check if one node is an ancestor of another.
@@ -287,7 +287,7 @@ def is_ancestor(
return ancestor_candidate in ancestors.get(node, set())
def build_parent_children_dict(flow: Any) -> Dict[str, List[str]]:
def build_parent_children_dict(flow: Any) -> dict[str, list[str]]:
"""
Build a dictionary mapping parent nodes to their children.
@@ -307,7 +307,7 @@ def build_parent_children_dict(flow: Any) -> Dict[str, List[str]]:
- Maps router methods to their paths and listeners
- Children lists are sorted for consistent ordering
"""
parent_children: Dict[str, List[str]] = {}
parent_children: dict[str, list[str]] = {}
# Map listeners to their trigger methods
for listener_name, (_, trigger_methods) in flow._listeners.items():
@@ -332,7 +332,7 @@ def build_parent_children_dict(flow: Any) -> Dict[str, List[str]]:
def get_child_index(
parent: str, child: str, parent_children: Dict[str, List[str]]
parent: str, child: str, parent_children: dict[str, list[str]]
) -> int:
"""
Get the index of a child node in its parent's sorted children list.
@@ -364,7 +364,7 @@ def process_router_paths(flow, current, current_level, levels, queue):
paths = flow._router_paths.get(current, [])
for path in paths:
for listener_name, (
condition_type,
_condition_type,
trigger_methods,
) in flow._listeners.items():
if path in trigger_methods:

View File

@@ -17,7 +17,7 @@ Example
import ast
import inspect
from typing import Any, Dict, List, Tuple, Union
from typing import Any
from .utils import (
build_ancestor_dict,
@@ -56,6 +56,7 @@ def method_calls_crew(method: Any) -> bool:
class CrewCallVisitor(ast.NodeVisitor):
"""AST visitor to detect .crew() method calls."""
def __init__(self):
self.found = False
@@ -73,8 +74,8 @@ def method_calls_crew(method: Any) -> bool:
def add_nodes_to_network(
net: Any,
flow: Any,
node_positions: Dict[str, Tuple[float, float]],
node_styles: Dict[str, Dict[str, Any]]
node_positions: dict[str, tuple[float, float]],
node_styles: dict[str, dict[str, Any]],
) -> None:
"""
Add nodes to the network visualization with appropriate styling.
@@ -98,6 +99,7 @@ def add_nodes_to_network(
- Crew methods
- Regular methods
"""
def human_friendly_label(method_name):
return method_name.replace("_", " ").title()
@@ -138,10 +140,10 @@ def add_nodes_to_network(
def compute_positions(
flow: Any,
node_levels: Dict[str, int],
node_levels: dict[str, int],
y_spacing: float = 150,
x_spacing: float = 300
) -> Dict[str, Tuple[float, float]]:
x_spacing: float = 300,
) -> dict[str, tuple[float, float]]:
"""
Compute the (x, y) positions for each node in the flow graph.
@@ -161,8 +163,8 @@ def compute_positions(
Dict[str, Tuple[float, float]]
Dictionary mapping node names to their (x, y) coordinates.
"""
level_nodes: Dict[int, List[str]] = {}
node_positions: Dict[str, Tuple[float, float]] = {}
level_nodes: dict[int, list[str]] = {}
node_positions: dict[str, tuple[float, float]] = {}
for method_name, level in node_levels.items():
level_nodes.setdefault(level, []).append(method_name)
@@ -180,10 +182,10 @@ def compute_positions(
def add_edges(
net: Any,
flow: Any,
node_positions: Dict[str, Tuple[float, float]],
colors: Dict[str, str]
node_positions: dict[str, tuple[float, float]],
colors: dict[str, str],
) -> None:
edge_smooth: Dict[str, Union[str, float]] = {"type": "continuous"} # Default value
edge_smooth: dict[str, str | float] = {"type": "continuous"} # Default value
"""
Add edges to the network visualization with appropriate styling.
@@ -269,7 +271,7 @@ def add_edges(
for router_method_name, paths in flow._router_paths.items():
for path in paths:
for listener_name, (
condition_type,
_condition_type,
trigger_methods,
) in flow._listeners.items():
if path in trigger_methods:

View File

@@ -1,6 +1,5 @@
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Dict, List, Optional, Union
from pydantic import Field, field_validator
@@ -14,19 +13,19 @@ class BaseFileKnowledgeSource(BaseKnowledgeSource, ABC):
"""Base class for knowledge sources that load content from files."""
_logger: Logger = Logger(verbose=True)
file_path: Optional[Union[Path, List[Path], str, List[str]]] = Field(
file_path: Path | list[Path] | str | list[str] | None = Field(
default=None,
description="[Deprecated] The path to the file. Use file_paths instead.",
)
file_paths: Optional[Union[Path, List[Path], str, List[str]]] = Field(
file_paths: Path | list[Path] | str | list[str] | None = Field(
default_factory=list, description="The path to the file"
)
content: Dict[Path, str] = Field(init=False, default_factory=dict)
storage: Optional[KnowledgeStorage] = Field(default=None)
safe_file_paths: List[Path] = Field(default_factory=list)
content: dict[Path, str] = Field(init=False, default_factory=dict)
storage: KnowledgeStorage | None = Field(default=None)
safe_file_paths: list[Path] = Field(default_factory=list)
@field_validator("file_path", "file_paths", mode="before")
def validate_file_path(cls, v, info):
def validate_file_path(self, v, info):
"""Validate that at least one of file_path or file_paths is provided."""
# Single check if both are None, O(1) instead of nested conditions
if (
@@ -46,9 +45,8 @@ class BaseFileKnowledgeSource(BaseKnowledgeSource, ABC):
self.content = self.load_content()
@abstractmethod
def load_content(self) -> Dict[Path, str]:
def load_content(self) -> dict[Path, str]:
"""Load and preprocess file content. Should be overridden by subclasses. Assume that the file path is relative to the project root in the knowledge directory."""
pass
def validate_content(self):
"""Validate the paths."""
@@ -74,11 +72,11 @@ class BaseFileKnowledgeSource(BaseKnowledgeSource, ABC):
else:
raise ValueError("No storage found to save documents.")
def convert_to_path(self, path: Union[Path, str]) -> Path:
def convert_to_path(self, path: Path | str) -> Path:
"""Convert a path to a Path object."""
return Path(KNOWLEDGE_DIRECTORY + "/" + path) if isinstance(path, str) else path
def _process_file_paths(self) -> List[Path]:
def _process_file_paths(self) -> list[Path]:
"""Convert file_path to a list of Path objects."""
if hasattr(self, "file_path") and self.file_path is not None:
@@ -93,7 +91,7 @@ class BaseFileKnowledgeSource(BaseKnowledgeSource, ABC):
raise ValueError("Your source must be provided with a file_paths: []")
# Convert single path to list
path_list: List[Union[Path, str]] = (
path_list: list[Path | str] = (
[self.file_paths]
if isinstance(self.file_paths, (str, Path))
else list(self.file_paths)

View File

@@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional
from typing import Any
import numpy as np
from pydantic import BaseModel, ConfigDict, Field
@@ -12,29 +12,27 @@ class BaseKnowledgeSource(BaseModel, ABC):
chunk_size: int = 4000
chunk_overlap: int = 200
chunks: List[str] = Field(default_factory=list)
chunk_embeddings: List[np.ndarray] = Field(default_factory=list)
chunks: list[str] = Field(default_factory=list)
chunk_embeddings: list[np.ndarray] = Field(default_factory=list)
model_config = ConfigDict(arbitrary_types_allowed=True)
storage: Optional[KnowledgeStorage] = Field(default=None)
metadata: Dict[str, Any] = Field(default_factory=dict) # Currently unused
collection_name: Optional[str] = Field(default=None)
storage: KnowledgeStorage | None = Field(default=None)
metadata: dict[str, Any] = Field(default_factory=dict) # Currently unused
collection_name: str | None = Field(default=None)
@abstractmethod
def validate_content(self) -> Any:
"""Load and preprocess content from the source."""
pass
@abstractmethod
def add(self) -> None:
"""Process content, chunk it, compute embeddings, and save them."""
pass
def get_embeddings(self) -> List[np.ndarray]:
def get_embeddings(self) -> list[np.ndarray]:
"""Return the list of embeddings for the chunks."""
return self.chunk_embeddings
def _chunk_text(self, text: str) -> List[str]:
def _chunk_text(self, text: str) -> list[str]:
"""Utility method to split text into chunks."""
return [
text[i : i + self.chunk_size]

View File

@@ -1,5 +1,5 @@
from collections.abc import Iterator
from pathlib import Path
from typing import Iterator, List, Optional, Union
from urllib.parse import urlparse
try:
@@ -35,11 +35,11 @@ class CrewDoclingSource(BaseKnowledgeSource):
_logger: Logger = Logger(verbose=True)
file_path: Optional[List[Union[Path, str]]] = Field(default=None)
file_paths: List[Union[Path, str]] = Field(default_factory=list)
chunks: List[str] = Field(default_factory=list)
safe_file_paths: List[Union[Path, str]] = Field(default_factory=list)
content: List["DoclingDocument"] = Field(default_factory=list)
file_path: list[Path | str] | None = Field(default=None)
file_paths: list[Path | str] = Field(default_factory=list)
chunks: list[str] = Field(default_factory=list)
safe_file_paths: list[Path | str] = Field(default_factory=list)
content: list["DoclingDocument"] = Field(default_factory=list)
document_converter: "DocumentConverter" = Field(
default_factory=lambda: DocumentConverter(
allowed_formats=[
@@ -66,7 +66,7 @@ class CrewDoclingSource(BaseKnowledgeSource):
self.safe_file_paths = self.validate_content()
self.content = self._load_content()
def _load_content(self) -> List["DoclingDocument"]:
def _load_content(self) -> list["DoclingDocument"]:
try:
return self._convert_source_to_docling_documents()
except ConversionError as e:
@@ -88,7 +88,7 @@ class CrewDoclingSource(BaseKnowledgeSource):
self.chunks.extend(list(new_chunks_iterable))
self._save_documents()
def _convert_source_to_docling_documents(self) -> List["DoclingDocument"]:
def _convert_source_to_docling_documents(self) -> list["DoclingDocument"]:
conv_results_iter = self.document_converter.convert_all(self.safe_file_paths)
return [result.document for result in conv_results_iter]
@@ -97,8 +97,8 @@ class CrewDoclingSource(BaseKnowledgeSource):
for chunk in chunker.chunk(doc):
yield chunk.text
def validate_content(self) -> List[Union[Path, str]]:
processed_paths: List[Union[Path, str]] = []
def validate_content(self) -> list[Path | str]:
processed_paths: list[Path | str] = []
for path in self.file_paths:
if isinstance(path, str):
if path.startswith(("http://", "https://")):
@@ -108,7 +108,7 @@ class CrewDoclingSource(BaseKnowledgeSource):
else:
raise ValueError(f"Invalid URL format: {path}")
except Exception as e:
raise ValueError(f"Invalid URL: {path}. Error: {str(e)}")
raise ValueError(f"Invalid URL: {path}. Error: {e!s}") from e
else:
local_path = Path(KNOWLEDGE_DIRECTORY + "/" + path)
if local_path.exists():

View File

@@ -1,6 +1,5 @@
import csv
from pathlib import Path
from typing import Dict, List
from crewai.knowledge.source.base_file_knowledge_source import BaseFileKnowledgeSource
@@ -8,7 +7,7 @@ from crewai.knowledge.source.base_file_knowledge_source import BaseFileKnowledge
class CSVKnowledgeSource(BaseFileKnowledgeSource):
"""A knowledge source that stores and queries CSV file content using embeddings."""
def load_content(self) -> Dict[Path, str]:
def load_content(self) -> dict[Path, str]:
"""Load and preprocess CSV file content."""
content_dict = {}
for file_path in self.safe_file_paths:
@@ -32,7 +31,7 @@ class CSVKnowledgeSource(BaseFileKnowledgeSource):
self.chunks.extend(new_chunks)
self._save_documents()
def _chunk_text(self, text: str) -> List[str]:
def _chunk_text(self, text: str) -> list[str]:
"""Utility method to split text into chunks."""
return [
text[i : i + self.chunk_size]

View File

@@ -1,6 +1,4 @@
from pathlib import Path
from typing import Dict, Iterator, List, Optional, Union
from urllib.parse import urlparse
from pydantic import Field, field_validator
@@ -16,19 +14,19 @@ class ExcelKnowledgeSource(BaseKnowledgeSource):
_logger: Logger = Logger(verbose=True)
file_path: Optional[Union[Path, List[Path], str, List[str]]] = Field(
file_path: Path | list[Path] | str | list[str] | None = Field(
default=None,
description="[Deprecated] The path to the file. Use file_paths instead.",
)
file_paths: Optional[Union[Path, List[Path], str, List[str]]] = Field(
file_paths: Path | list[Path] | str | list[str] | None = Field(
default_factory=list, description="The path to the file"
)
chunks: List[str] = Field(default_factory=list)
content: Dict[Path, Dict[str, str]] = Field(default_factory=dict)
safe_file_paths: List[Path] = Field(default_factory=list)
chunks: list[str] = Field(default_factory=list)
content: dict[Path, dict[str, str]] = Field(default_factory=dict)
safe_file_paths: list[Path] = Field(default_factory=list)
@field_validator("file_path", "file_paths", mode="before")
def validate_file_path(cls, v, info):
def validate_file_path(self, v, info):
"""Validate that at least one of file_path or file_paths is provided."""
# Single check if both are None, O(1) instead of nested conditions
if (
@@ -41,7 +39,7 @@ class ExcelKnowledgeSource(BaseKnowledgeSource):
raise ValueError("Either file_path or file_paths must be provided")
return v
def _process_file_paths(self) -> List[Path]:
def _process_file_paths(self) -> list[Path]:
"""Convert file_path to a list of Path objects."""
if hasattr(self, "file_path") and self.file_path is not None:
@@ -56,7 +54,7 @@ class ExcelKnowledgeSource(BaseKnowledgeSource):
raise ValueError("Your source must be provided with a file_paths: []")
# Convert single path to list
path_list: List[Union[Path, str]] = (
path_list: list[Path | str] = (
[self.file_paths]
if isinstance(self.file_paths, (str, Path))
else list(self.file_paths)
@@ -100,7 +98,7 @@ class ExcelKnowledgeSource(BaseKnowledgeSource):
self.validate_content()
self.content = self._load_content()
def _load_content(self) -> Dict[Path, Dict[str, str]]:
def _load_content(self) -> dict[Path, dict[str, str]]:
"""Load and preprocess Excel file content from multiple sheets.
Each sheet's content is converted to CSV format and stored.
@@ -126,7 +124,7 @@ class ExcelKnowledgeSource(BaseKnowledgeSource):
content_dict[file_path] = sheet_dict
return content_dict
def convert_to_path(self, path: Union[Path, str]) -> Path:
def convert_to_path(self, path: Path | str) -> Path:
"""Convert a path to a Path object."""
return Path(KNOWLEDGE_DIRECTORY + "/" + path) if isinstance(path, str) else path
@@ -161,7 +159,7 @@ class ExcelKnowledgeSource(BaseKnowledgeSource):
self.chunks.extend(new_chunks)
self._save_documents()
def _chunk_text(self, text: str) -> List[str]:
def _chunk_text(self, text: str) -> list[str]:
"""Utility method to split text into chunks."""
return [
text[i : i + self.chunk_size]

View File

@@ -1,6 +1,6 @@
import json
from pathlib import Path
from typing import Any, Dict, List
from typing import Any
from crewai.knowledge.source.base_file_knowledge_source import BaseFileKnowledgeSource
@@ -8,9 +8,9 @@ from crewai.knowledge.source.base_file_knowledge_source import BaseFileKnowledge
class JSONKnowledgeSource(BaseFileKnowledgeSource):
"""A knowledge source that stores and queries JSON file content using embeddings."""
def load_content(self) -> Dict[Path, str]:
def load_content(self) -> dict[Path, str]:
"""Load and preprocess JSON file content."""
content: Dict[Path, str] = {}
content: dict[Path, str] = {}
for path in self.safe_file_paths:
path = self.convert_to_path(path)
with open(path, "r", encoding="utf-8") as json_file:
@@ -29,7 +29,7 @@ class JSONKnowledgeSource(BaseFileKnowledgeSource):
for item in data:
text += f"{indent}- {self._json_to_text(item, level + 1)}\n"
else:
text += f"{str(data)}"
text += f"{data!s}"
return text
def add(self) -> None:
@@ -44,7 +44,7 @@ class JSONKnowledgeSource(BaseFileKnowledgeSource):
self.chunks.extend(new_chunks)
self._save_documents()
def _chunk_text(self, text: str) -> List[str]:
def _chunk_text(self, text: str) -> list[str]:
"""Utility method to split text into chunks."""
return [
text[i : i + self.chunk_size]

View File

@@ -1,5 +1,4 @@
from pathlib import Path
from typing import Dict, List
from crewai.knowledge.source.base_file_knowledge_source import BaseFileKnowledgeSource
@@ -7,7 +6,7 @@ from crewai.knowledge.source.base_file_knowledge_source import BaseFileKnowledge
class PDFKnowledgeSource(BaseFileKnowledgeSource):
"""A knowledge source that stores and queries PDF file content using embeddings."""
def load_content(self) -> Dict[Path, str]:
def load_content(self) -> dict[Path, str]:
"""Load and preprocess PDF file content."""
pdfplumber = self._import_pdfplumber()
@@ -40,12 +39,12 @@ class PDFKnowledgeSource(BaseFileKnowledgeSource):
Add PDF file content to the knowledge source, chunk it, compute embeddings,
and save the embeddings.
"""
for _, text in self.content.items():
for text in self.content.values():
new_chunks = self._chunk_text(text)
self.chunks.extend(new_chunks)
self._save_documents()
def _chunk_text(self, text: str) -> List[str]:
def _chunk_text(self, text: str) -> list[str]:
"""Utility method to split text into chunks."""
return [
text[i : i + self.chunk_size]

View File

@@ -1,5 +1,3 @@
from typing import List, Optional
from pydantic import Field
from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource
@@ -9,7 +7,7 @@ class StringKnowledgeSource(BaseKnowledgeSource):
"""A knowledge source that stores and queries plain text content using embeddings."""
content: str = Field(...)
collection_name: Optional[str] = Field(default=None)
collection_name: str | None = Field(default=None)
def model_post_init(self, _):
"""Post-initialization method to validate content."""
@@ -26,7 +24,7 @@ class StringKnowledgeSource(BaseKnowledgeSource):
self.chunks.extend(new_chunks)
self._save_documents()
def _chunk_text(self, text: str) -> List[str]:
def _chunk_text(self, text: str) -> list[str]:
"""Utility method to split text into chunks."""
return [
text[i : i + self.chunk_size]

View File

@@ -1,5 +1,4 @@
from pathlib import Path
from typing import Dict, List
from crewai.knowledge.source.base_file_knowledge_source import BaseFileKnowledgeSource
@@ -7,7 +6,7 @@ from crewai.knowledge.source.base_file_knowledge_source import BaseFileKnowledge
class TextFileKnowledgeSource(BaseFileKnowledgeSource):
"""A knowledge source that stores and queries text file content using embeddings."""
def load_content(self) -> Dict[Path, str]:
def load_content(self) -> dict[Path, str]:
"""Load and preprocess text file content."""
content = {}
for path in self.safe_file_paths:
@@ -21,12 +20,12 @@ class TextFileKnowledgeSource(BaseFileKnowledgeSource):
Add text file content to the knowledge source, chunk it, compute embeddings,
and save the embeddings.
"""
for _, text in self.content.items():
for text in self.content.values():
new_chunks = self._chunk_text(text)
self.chunks.extend(new_chunks)
self._save_documents()
def _chunk_text(self, text: str) -> List[str]:
def _chunk_text(self, text: str) -> list[str]:
"""Utility method to split text into chunks."""
return [
text[i : i + self.chunk_size]

View File

@@ -1,21 +1,14 @@
import asyncio
import inspect
import uuid
from collections.abc import Callable
from typing import (
Any,
Callable,
Dict,
List,
Optional,
Tuple,
Type,
Union,
cast,
get_args,
get_origin,
)
try:
from typing import Self
except ImportError:
@@ -24,11 +17,12 @@ except ImportError:
from pydantic import (
UUID4,
BaseModel,
ConfigDict,
Field,
InstanceOf,
PrivateAttr,
model_validator,
field_validator,
model_validator,
)
from crewai.agents.agent_builder.base_agent import BaseAgent
@@ -39,12 +33,18 @@ from crewai.agents.parser import (
AgentFinish,
OutputParserException,
)
from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.agent_events import (
LiteAgentExecutionCompletedEvent,
LiteAgentExecutionErrorEvent,
LiteAgentExecutionStartedEvent,
)
from crewai.events.types.logging_events import AgentLogsExecutionEvent
from crewai.flow.flow_trackable import FlowTrackable
from crewai.llm import LLM, BaseLLM
from crewai.tools.base_tool import BaseTool
from crewai.tools.structured_tool import CrewStructuredTool
from crewai.utilities import I18N
from crewai.utilities.guardrail import process_guardrail
from crewai.utilities.agent_utils import (
enforce_rpm_limit,
format_message_for_llm,
@@ -62,14 +62,7 @@ from crewai.utilities.agent_utils import (
render_text_description_and_args,
)
from crewai.utilities.converter import generate_model_description
from crewai.events.types.logging_events import AgentLogsExecutionEvent
from crewai.events.types.agent_events import (
LiteAgentExecutionCompletedEvent,
LiteAgentExecutionErrorEvent,
LiteAgentExecutionStartedEvent,
)
from crewai.events.event_bus import crewai_event_bus
from crewai.utilities.guardrail import process_guardrail
from crewai.utilities.llm_utils import create_llm
from crewai.utilities.printer import Printer
from crewai.utilities.token_counter_callback import TokenCalcHandler
@@ -79,18 +72,18 @@ from crewai.utilities.tool_utils import execute_tool_and_check_finality
class LiteAgentOutput(BaseModel):
"""Class that represents the result of a LiteAgent execution."""
model_config = {"arbitrary_types_allowed": True}
model_config = ConfigDict(arbitrary_types_allowed=True)
raw: str = Field(description="Raw output of the agent", default="")
pydantic: Optional[BaseModel] = Field(
pydantic: BaseModel | None = Field(
description="Pydantic output of the agent", default=None
)
agent_role: str = Field(description="Role of the agent that produced this output")
usage_metrics: Optional[Dict[str, Any]] = Field(
usage_metrics: dict[str, Any] | None = Field(
description="Token usage metrics for this execution", default=None
)
def to_dict(self) -> Dict[str, Any]:
def to_dict(self) -> dict[str, Any]:
"""Convert pydantic_output to a dictionary."""
if self.pydantic:
return self.pydantic.model_dump()
@@ -123,17 +116,17 @@ class LiteAgent(FlowTrackable, BaseModel):
response_format: Optional Pydantic model for structured output.
"""
model_config = {"arbitrary_types_allowed": True}
model_config = ConfigDict(arbitrary_types_allowed=True)
# Core Agent Properties
id: UUID4 = Field(default_factory=uuid.uuid4, frozen=True)
role: str = Field(description="Role of the agent")
goal: str = Field(description="Goal of the agent")
backstory: str = Field(description="Backstory of the agent")
llm: Optional[Union[str, InstanceOf[BaseLLM], Any]] = Field(
llm: str | InstanceOf[BaseLLM] | Any | None = Field(
default=None, description="Language model that will run the agent"
)
tools: List[BaseTool] = Field(
tools: list[BaseTool] = Field(
default_factory=list, description="Tools at agent's disposal"
)
@@ -141,7 +134,7 @@ class LiteAgent(FlowTrackable, BaseModel):
max_iterations: int = Field(
default=15, description="Maximum number of iterations for tool usage"
)
max_execution_time: Optional[int] = Field(
max_execution_time: int | None = Field(
default=None, description=". Maximum execution time in seconds"
)
respect_context_window: bool = Field(
@@ -152,52 +145,50 @@ class LiteAgent(FlowTrackable, BaseModel):
default=True,
description="Whether to use stop words to prevent the LLM from using tools",
)
request_within_rpm_limit: Optional[Callable[[], bool]] = Field(
request_within_rpm_limit: Callable[[], bool] | None = Field(
default=None,
description="Callback to check if the request is within the RPM limit",
)
i18n: I18N = Field(default=I18N(), description="Internationalization settings.")
# Output and Formatting Properties
response_format: Optional[Type[BaseModel]] = Field(
response_format: type[BaseModel] | None = Field(
default=None, description="Pydantic model for structured output"
)
verbose: bool = Field(
default=False, description="Whether to print execution details"
)
callbacks: List[Callable] = Field(
callbacks: list[Callable] = Field(
default=[], description="Callbacks to be used for the agent"
)
# Guardrail Properties
guardrail: Optional[Union[Callable[[LiteAgentOutput], Tuple[bool, Any]], str]] = (
Field(
default=None,
description="Function or string description of a guardrail to validate agent output",
)
guardrail: Callable[[LiteAgentOutput], tuple[bool, Any]] | str | None = Field(
default=None,
description="Function or string description of a guardrail to validate agent output",
)
guardrail_max_retries: int = Field(
default=3, description="Maximum number of retries when guardrail fails"
)
# State and Results
tools_results: List[Dict[str, Any]] = Field(
tools_results: list[dict[str, Any]] = Field(
default=[], description="Results of the tools used by the agent."
)
# Reference of Agent
original_agent: Optional[BaseAgent] = Field(
original_agent: BaseAgent | None = Field(
default=None, description="Reference to the agent that created this LiteAgent"
)
# Private Attributes
_parsed_tools: List[CrewStructuredTool] = PrivateAttr(default_factory=list)
_parsed_tools: list[CrewStructuredTool] = PrivateAttr(default_factory=list)
_token_process: TokenProcess = PrivateAttr(default_factory=TokenProcess)
_cache_handler: CacheHandler = PrivateAttr(default_factory=CacheHandler)
_key: str = PrivateAttr(default_factory=lambda: str(uuid.uuid4()))
_messages: List[Dict[str, str]] = PrivateAttr(default_factory=list)
_messages: list[dict[str, str]] = PrivateAttr(default_factory=list)
_iterations: int = PrivateAttr(default=0)
_printer: Printer = PrivateAttr(default_factory=Printer)
_guardrail: Optional[Callable] = PrivateAttr(default=None)
_guardrail: Callable | None = PrivateAttr(default=None)
_guardrail_retry_count: int = PrivateAttr(default=0)
@model_validator(mode="after")
@@ -241,8 +232,8 @@ class LiteAgent(FlowTrackable, BaseModel):
@field_validator("guardrail", mode="before")
@classmethod
def validate_guardrail_function(
cls, v: Optional[Union[Callable, str]]
) -> Optional[Union[Callable, str]]:
cls, v: Callable | str | None
) -> Callable | str | None:
"""Validate that the guardrail function has the correct signature.
If v is a callable, validate that it has the correct signature.
@@ -267,7 +258,7 @@ class LiteAgent(FlowTrackable, BaseModel):
# Check return annotation if present
if sig.return_annotation is not sig.empty:
if sig.return_annotation == Tuple[bool, Any]:
if sig.return_annotation == tuple[bool, Any]:
return v
origin = get_origin(sig.return_annotation)
@@ -290,7 +281,7 @@ class LiteAgent(FlowTrackable, BaseModel):
"""Return the original role for compatibility with tool interfaces."""
return self.role
def kickoff(self, messages: Union[str, List[Dict[str, str]]]) -> LiteAgentOutput:
def kickoff(self, messages: str | list[dict[str, str]]) -> LiteAgentOutput:
"""
Execute the agent with the given messages.
@@ -338,7 +329,7 @@ class LiteAgent(FlowTrackable, BaseModel):
)
raise e
def _execute_core(self, agent_info: Dict[str, Any]) -> LiteAgentOutput:
def _execute_core(self, agent_info: dict[str, Any]) -> LiteAgentOutput:
# Emit event for agent execution start
crewai_event_bus.emit(
self,
@@ -351,7 +342,7 @@ class LiteAgent(FlowTrackable, BaseModel):
# Execute the agent using invoke loop
agent_finish = self._invoke_loop()
formatted_result: Optional[BaseModel] = None
formatted_result: BaseModel | None = None
if self.response_format:
try:
# Cast to BaseModel to ensure type safety
@@ -360,7 +351,7 @@ class LiteAgent(FlowTrackable, BaseModel):
formatted_result = result
except Exception as e:
self._printer.print(
content=f"Failed to parse output into response format: {str(e)}",
content=f"Failed to parse output into response format: {e!s}",
color="yellow",
)
@@ -428,7 +419,7 @@ class LiteAgent(FlowTrackable, BaseModel):
return output
async def kickoff_async(
self, messages: Union[str, List[Dict[str, str]]]
self, messages: str | list[dict[str, str]]
) -> LiteAgentOutput:
"""
Execute the agent asynchronously with the given messages.
@@ -475,8 +466,8 @@ class LiteAgent(FlowTrackable, BaseModel):
return base_prompt
def _format_messages(
self, messages: Union[str, List[Dict[str, str]]]
) -> List[Dict[str, str]]:
self, messages: str | list[dict[str, str]]
) -> list[dict[str, str]]:
"""Format messages for the LLM."""
if isinstance(messages, str):
messages = [{"role": "user", "content": messages}]
@@ -571,9 +562,8 @@ class LiteAgent(FlowTrackable, BaseModel):
i18n=self.i18n,
)
continue
else:
handle_unknown_error(self._printer, e)
raise e
handle_unknown_error(self._printer, e)
raise e
finally:
self._iterations += 1
@@ -582,7 +572,7 @@ class LiteAgent(FlowTrackable, BaseModel):
self._show_logs(formatted_answer)
return formatted_answer
def _show_logs(self, formatted_answer: Union[AgentAction, AgentFinish]):
def _show_logs(self, formatted_answer: AgentAction | AgentFinish):
"""Show logs for the agent's execution."""
crewai_event_bus.emit(
self,

View File

@@ -6,19 +6,14 @@ import threading
import warnings
from collections import defaultdict
from contextlib import contextmanager
from datetime import datetime
from typing import (
Any,
DefaultDict,
Dict,
List,
Literal,
Optional,
Type,
TypedDict,
Union,
cast,
)
from datetime import datetime
from dotenv import load_dotenv
from litellm.types.utils import ChatCompletionDeltaToolCall
from pydantic import BaseModel, Field
@@ -31,9 +26,9 @@ from crewai.events.types.llm_events import (
LLMStreamChunkEvent,
)
from crewai.events.types.tool_usage_events import (
ToolUsageStartedEvent,
ToolUsageFinishedEvent,
ToolUsageErrorEvent,
ToolUsageFinishedEvent,
ToolUsageStartedEvent,
)
with warnings.catch_warnings():
@@ -51,8 +46,8 @@ with warnings.catch_warnings():
import io
from typing import TextIO
from crewai.llms.base_llm import BaseLLM
from crewai.events.event_bus import crewai_event_bus
from crewai.llms.base_llm import BaseLLM
from crewai.utilities.exceptions.context_window_exceeding_exception import (
LLMContextLengthExceededException,
)
@@ -268,14 +263,14 @@ def suppress_warnings():
class Delta(TypedDict):
content: Optional[str]
role: Optional[str]
content: str | None
role: str | None
class StreamingChoices(TypedDict):
delta: Delta
index: int
finish_reason: Optional[str]
finish_reason: str | None
class FunctionArgs(BaseModel):
@@ -288,31 +283,31 @@ class AccumulatedToolArgs(BaseModel):
class LLM(BaseLLM):
completion_cost: Optional[float] = None
completion_cost: float | None = None
def __init__(
self,
model: str,
timeout: Optional[Union[float, int]] = None,
temperature: Optional[float] = None,
top_p: Optional[float] = None,
n: Optional[int] = None,
stop: Optional[Union[str, List[str]]] = None,
max_completion_tokens: Optional[int] = None,
max_tokens: Optional[int] = None,
presence_penalty: Optional[float] = None,
frequency_penalty: Optional[float] = None,
logit_bias: Optional[Dict[int, float]] = None,
response_format: Optional[Type[BaseModel]] = None,
seed: Optional[int] = None,
logprobs: Optional[int] = None,
top_logprobs: Optional[int] = None,
base_url: Optional[str] = None,
api_base: Optional[str] = None,
api_version: Optional[str] = None,
api_key: Optional[str] = None,
callbacks: List[Any] | None = None,
reasoning_effort: Optional[Literal["none", "low", "medium", "high"]] = None,
timeout: float | int | None = None,
temperature: float | None = None,
top_p: float | None = None,
n: int | None = None,
stop: str | list[str] | None = None,
max_completion_tokens: int | None = None,
max_tokens: int | None = None,
presence_penalty: float | None = None,
frequency_penalty: float | None = None,
logit_bias: dict[int, float] | None = None,
response_format: type[BaseModel] | None = None,
seed: int | None = None,
logprobs: int | None = None,
top_logprobs: int | None = None,
base_url: str | None = None,
api_base: str | None = None,
api_version: str | None = None,
api_key: str | None = None,
callbacks: list[Any] | None = None,
reasoning_effort: Literal["none", "low", "medium", "high"] | None = None,
stream: bool = False,
**kwargs,
):
@@ -345,7 +340,7 @@ class LLM(BaseLLM):
# Normalize self.stop to always be a List[str]
if stop is None:
self.stop: List[str] = []
self.stop: list[str] = []
elif isinstance(stop, str):
self.stop = [stop]
else:
@@ -368,9 +363,9 @@ class LLM(BaseLLM):
def _prepare_completion_params(
self,
messages: Union[str, List[Dict[str, str]]],
tools: Optional[List[dict]] = None,
) -> Dict[str, Any]:
messages: str | list[dict[str, str]],
tools: list[dict] | None = None,
) -> dict[str, Any]:
"""Prepare parameters for the completion call.
Args:
@@ -419,11 +414,11 @@ class LLM(BaseLLM):
def _handle_streaming_response(
self,
params: Dict[str, Any],
callbacks: Optional[List[Any]] = None,
available_functions: Optional[Dict[str, Any]] = None,
from_task: Optional[Any] = None,
from_agent: Optional[Any] = None,
params: dict[str, Any],
callbacks: list[Any] | None = None,
available_functions: dict[str, Any] | None = None,
from_task: Any | None = None,
from_agent: Any | None = None,
) -> str:
"""Handle a streaming response from the LLM.
@@ -447,7 +442,7 @@ class LLM(BaseLLM):
usage_info = None
tool_calls = None
accumulated_tool_args: DefaultDict[int, AccumulatedToolArgs] = defaultdict(
accumulated_tool_args: defaultdict[int, AccumulatedToolArgs] = defaultdict(
AccumulatedToolArgs
)
@@ -472,16 +467,16 @@ class LLM(BaseLLM):
choices = chunk["choices"]
elif hasattr(chunk, "choices"):
# Check if choices is not a type but an actual attribute with value
if not isinstance(getattr(chunk, "choices"), type):
choices = getattr(chunk, "choices")
if not isinstance(chunk.choices, type):
choices = chunk.choices
# Try to extract usage information if available
if isinstance(chunk, dict) and "usage" in chunk:
usage_info = chunk["usage"]
elif hasattr(chunk, "usage"):
# Check if usage is not a type but an actual attribute with value
if not isinstance(getattr(chunk, "usage"), type):
usage_info = getattr(chunk, "usage")
if not isinstance(chunk.usage, type):
usage_info = chunk.usage
if choices and len(choices) > 0:
choice = choices[0]
@@ -491,7 +486,7 @@ class LLM(BaseLLM):
if isinstance(choice, dict) and "delta" in choice:
delta = choice["delta"]
elif hasattr(choice, "delta"):
delta = getattr(choice, "delta")
delta = choice.delta
# Extract content from delta
if delta:
@@ -501,7 +496,7 @@ class LLM(BaseLLM):
chunk_content = delta["content"]
# Handle object format
elif hasattr(delta, "content"):
chunk_content = getattr(delta, "content")
chunk_content = delta.content
# Handle case where content might be None or empty
if chunk_content is None and isinstance(delta, dict):
@@ -572,8 +567,8 @@ class LLM(BaseLLM):
if isinstance(last_chunk, dict) and "choices" in last_chunk:
choices = last_chunk["choices"]
elif hasattr(last_chunk, "choices"):
if not isinstance(getattr(last_chunk, "choices"), type):
choices = getattr(last_chunk, "choices")
if not isinstance(last_chunk.choices, type):
choices = last_chunk.choices
if choices and len(choices) > 0:
choice = choices[0]
@@ -583,14 +578,14 @@ class LLM(BaseLLM):
if isinstance(choice, dict) and "message" in choice:
message = choice["message"]
elif hasattr(choice, "message"):
message = getattr(choice, "message")
message = choice.message
if message:
content = None
if isinstance(message, dict) and "content" in message:
content = message["content"]
elif hasattr(message, "content"):
content = getattr(message, "content")
content = message.content
if content:
full_response = content
@@ -617,8 +612,8 @@ class LLM(BaseLLM):
if isinstance(last_chunk, dict) and "choices" in last_chunk:
choices = last_chunk["choices"]
elif hasattr(last_chunk, "choices"):
if not isinstance(getattr(last_chunk, "choices"), type):
choices = getattr(last_chunk, "choices")
if not isinstance(last_chunk.choices, type):
choices = last_chunk.choices
if choices and len(choices) > 0:
choice = choices[0]
@@ -627,13 +622,13 @@ class LLM(BaseLLM):
if isinstance(choice, dict) and "message" in choice:
message = choice["message"]
elif hasattr(choice, "message"):
message = getattr(choice, "message")
message = choice.message
if message:
if isinstance(message, dict) and "tool_calls" in message:
tool_calls = message["tool_calls"]
elif hasattr(message, "tool_calls"):
tool_calls = getattr(message, "tool_calls")
tool_calls = message.tool_calls
except Exception as e:
logging.debug(f"Error checking for tool calls: {e}")
# --- 8) If no tool calls or no available functions, return the text response directly
@@ -675,9 +670,9 @@ class LLM(BaseLLM):
# decide whether to summarize the content or abort based on the respect_context_window flag.
raise LLMContextLengthExceededException(str(e))
except Exception as e:
logging.error(f"Error in streaming response: {str(e)}")
logging.error(f"Error in streaming response: {e!s}")
if full_response.strip():
logging.warning(f"Returning partial response despite error: {str(e)}")
logging.warning(f"Returning partial response despite error: {e!s}")
self._handle_emit_call_events(
response=full_response,
call_type=LLMCallType.LLM_CALL,
@@ -695,15 +690,15 @@ class LLM(BaseLLM):
error=str(e), from_task=from_task, from_agent=from_agent
),
)
raise Exception(f"Failed to get streaming response: {str(e)}")
raise Exception(f"Failed to get streaming response: {e!s}")
def _handle_streaming_tool_calls(
self,
tool_calls: List[ChatCompletionDeltaToolCall],
accumulated_tool_args: DefaultDict[int, AccumulatedToolArgs],
available_functions: Optional[Dict[str, Any]] = None,
from_task: Optional[Any] = None,
from_agent: Optional[Any] = None,
tool_calls: list[ChatCompletionDeltaToolCall],
accumulated_tool_args: defaultdict[int, AccumulatedToolArgs],
available_functions: dict[str, Any] | None = None,
from_task: Any | None = None,
from_agent: Any | None = None,
) -> None | str:
for tool_call in tool_calls:
current_tool_accumulator = accumulated_tool_args[tool_call.index]
@@ -744,9 +739,9 @@ class LLM(BaseLLM):
def _handle_streaming_callbacks(
self,
callbacks: Optional[List[Any]],
usage_info: Optional[Dict[str, Any]],
last_chunk: Optional[Any],
callbacks: list[Any] | None,
usage_info: dict[str, Any] | None,
last_chunk: Any | None,
) -> None:
"""Handle callbacks with usage info for streaming responses.
@@ -769,10 +764,8 @@ class LLM(BaseLLM):
):
usage_info = last_chunk["usage"]
elif hasattr(last_chunk, "usage"):
if not isinstance(
getattr(last_chunk, "usage"), type
):
usage_info = getattr(last_chunk, "usage")
if not isinstance(last_chunk.usage, type):
usage_info = last_chunk.usage
except Exception as e:
logging.debug(f"Error extracting usage info: {e}")
@@ -786,11 +779,11 @@ class LLM(BaseLLM):
def _handle_non_streaming_response(
self,
params: Dict[str, Any],
callbacks: Optional[List[Any]] = None,
available_functions: Optional[Dict[str, Any]] = None,
from_task: Optional[Any] = None,
from_agent: Optional[Any] = None,
params: dict[str, Any],
callbacks: list[Any] | None = None,
available_functions: dict[str, Any] | None = None,
from_task: Any | None = None,
from_agent: Any | None = None,
) -> str | Any:
"""Handle a non-streaming response from the LLM.
@@ -847,7 +840,7 @@ class LLM(BaseLLM):
)
return text_response
# --- 6) If there is no text response, no available functions, but there are tool calls, return the tool calls
elif tool_calls and not available_functions and not text_response:
if tool_calls and not available_functions and not text_response:
return tool_calls
# --- 7) Handle tool calls if present
@@ -868,11 +861,11 @@ class LLM(BaseLLM):
def _handle_tool_call(
self,
tool_calls: List[Any],
available_functions: Optional[Dict[str, Any]] = None,
from_task: Optional[Any] = None,
from_agent: Optional[Any] = None,
) -> Optional[str]:
tool_calls: list[Any],
available_functions: dict[str, Any] | None = None,
from_task: Any | None = None,
from_agent: Any | None = None,
) -> str | None:
"""Handle a tool call from the LLM.
Args:
@@ -942,14 +935,14 @@ class LLM(BaseLLM):
assert hasattr(crewai_event_bus, "emit")
crewai_event_bus.emit(
self,
event=LLMCallFailedEvent(error=f"Tool execution error: {str(e)}"),
event=LLMCallFailedEvent(error=f"Tool execution error: {e!s}"),
)
crewai_event_bus.emit(
self,
event=ToolUsageErrorEvent(
tool_name=function_name,
tool_args=function_args,
error=f"Tool execution error: {str(e)}",
error=f"Tool execution error: {e!s}",
from_task=from_task,
from_agent=from_agent,
),
@@ -958,13 +951,13 @@ class LLM(BaseLLM):
def call(
self,
messages: Union[str, List[Dict[str, str]]],
tools: Optional[List[dict]] = None,
callbacks: Optional[List[Any]] = None,
available_functions: Optional[Dict[str, Any]] = None,
from_task: Optional[Any] = None,
from_agent: Optional[Any] = None,
) -> Union[str, Any]:
messages: str | list[dict[str, str]],
tools: list[dict] | None = None,
callbacks: list[Any] | None = None,
available_functions: dict[str, Any] | None = None,
from_task: Any | None = None,
from_agent: Any | None = None,
) -> str | Any:
"""High-level LLM call method.
Args:
@@ -1028,10 +1021,9 @@ class LLM(BaseLLM):
return self._handle_streaming_response(
params, callbacks, available_functions, from_task, from_agent
)
else:
return self._handle_non_streaming_response(
params, callbacks, available_functions, from_task, from_agent
)
return self._handle_non_streaming_response(
params, callbacks, available_functions, from_task, from_agent
)
except LLMContextLengthExceededException:
# Re-raise LLMContextLengthExceededException as it should be handled
@@ -1078,8 +1070,8 @@ class LLM(BaseLLM):
self,
response: Any,
call_type: LLMCallType,
from_task: Optional[Any] = None,
from_agent: Optional[Any] = None,
from_task: Any | None = None,
from_agent: Any | None = None,
messages: str | list[dict[str, Any]] | None = None,
):
"""Handle the events for the LLM call.
@@ -1105,8 +1097,8 @@ class LLM(BaseLLM):
)
def _format_messages_for_provider(
self, messages: List[Dict[str, str]]
) -> List[Dict[str, str]]:
self, messages: list[dict[str, str]]
) -> list[dict[str, str]]:
"""Format messages according to provider requirements.
Args:
@@ -1147,7 +1139,7 @@ class LLM(BaseLLM):
if "mistral" in self.model.lower():
# Check if the last message has a role of 'assistant'
if messages and messages[-1]["role"] == "assistant":
return messages + [{"role": "user", "content": "Please continue."}]
return [*messages, {"role": "user", "content": "Please continue."}]
return messages
# TODO: Remove this code after merging PR https://github.com/BerriAI/litellm/pull/10917
@@ -1157,7 +1149,7 @@ class LLM(BaseLLM):
and messages
and messages[-1]["role"] == "assistant"
):
return messages + [{"role": "user", "content": ""}]
return [*messages, {"role": "user", "content": ""}]
# Handle Anthropic models
if not self.is_anthropic:
@@ -1170,7 +1162,7 @@ class LLM(BaseLLM):
return messages
def _get_custom_llm_provider(self) -> Optional[str]:
def _get_custom_llm_provider(self) -> str | None:
"""
Derives the custom_llm_provider from the model string.
- For example, if the model is "openrouter/deepseek/deepseek-chat", returns "openrouter".
@@ -1207,7 +1199,7 @@ class LLM(BaseLLM):
self.model, custom_llm_provider=provider
)
except Exception as e:
logging.error(f"Failed to check function calling support: {str(e)}")
logging.error(f"Failed to check function calling support: {e!s}")
return False
def supports_stop_words(self) -> bool:
@@ -1215,7 +1207,7 @@ class LLM(BaseLLM):
params = get_supported_openai_params(model=self.model)
return params is not None and "stop" in params
except Exception as e:
logging.error(f"Failed to get supported params: {str(e)}")
logging.error(f"Failed to get supported params: {e!s}")
return False
def get_context_window_size(self) -> int:
@@ -1247,7 +1239,7 @@ class LLM(BaseLLM):
self.context_window_size = int(value * CONTEXT_WINDOW_USAGE_RATIO)
return self.context_window_size
def set_callbacks(self, callbacks: List[Any]):
def set_callbacks(self, callbacks: list[Any]):
"""
Attempt to keep a single set of callbacks in litellm by removing old
duplicates and adding new ones.

View File

@@ -1,11 +1,11 @@
from .entity.entity_memory import EntityMemory
from .external.external_memory import ExternalMemory
from .long_term.long_term_memory import LongTermMemory
from .short_term.short_term_memory import ShortTermMemory
from .external.external_memory import ExternalMemory
__all__ = [
"EntityMemory",
"ExternalMemory",
"LongTermMemory",
"ShortTermMemory",
"ExternalMemory",
]

View File

@@ -1,12 +1,12 @@
from typing import Any, Dict, Optional
from typing import Any
class ExternalMemoryItem:
def __init__(
self,
value: Any,
metadata: Optional[Dict[str, Any]] = None,
agent: Optional[str] = None,
metadata: dict[str, Any] | None = None,
agent: str | None = None,
):
self.value = value
self.metadata = metadata

View File

@@ -1,17 +1,17 @@
from typing import Any, Dict, List
import time
from typing import Any
from crewai.memory.long_term.long_term_memory_item import LongTermMemoryItem
from crewai.memory.memory import Memory
from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.memory_events import (
MemoryQueryStartedEvent,
MemoryQueryCompletedEvent,
MemoryQueryFailedEvent,
MemorySaveStartedEvent,
MemoryQueryStartedEvent,
MemorySaveCompletedEvent,
MemorySaveFailedEvent,
MemorySaveStartedEvent,
)
from crewai.memory.long_term.long_term_memory_item import LongTermMemoryItem
from crewai.memory.memory import Memory
from crewai.memory.storage.ltm_sqlite_storage import LTMSQLiteStorage
@@ -84,7 +84,7 @@ class LongTermMemory(Memory):
self,
task: str,
latest_n: int = 3,
) -> List[Dict[str, Any]]: # type: ignore # signature of "search" incompatible with supertype "Memory"
) -> list[dict[str, Any]]: # type: ignore # signature of "search" incompatible with supertype "Memory"
crewai_event_bus.emit(
self,
event=MemoryQueryStartedEvent(

View File

@@ -1,4 +1,4 @@
from typing import Any, Dict, Optional, Union
from typing import Any
class LongTermMemoryItem:
@@ -8,8 +8,8 @@ class LongTermMemoryItem:
task: str,
expected_output: str,
datetime: str,
quality: Optional[Union[int, float]] = None,
metadata: Optional[Dict[str, Any]] = None,
quality: int | float | None = None,
metadata: dict[str, Any] | None = None,
):
self.task = task
self.agent = agent

View File

@@ -1,12 +1,12 @@
from typing import Any, Dict, Optional
from typing import Any
class ShortTermMemoryItem:
def __init__(
self,
data: Any,
agent: Optional[str] = None,
metadata: Optional[Dict[str, Any]] = None,
agent: str | None = None,
metadata: dict[str, Any] | None = None,
):
self.data = data
self.agent = agent

View File

@@ -1,15 +1,15 @@
from typing import Any, Dict, List
from typing import Any
class Storage:
"""Abstract base class defining the storage interface"""
def save(self, value: Any, metadata: Dict[str, Any]) -> None:
def save(self, value: Any, metadata: dict[str, Any]) -> None:
pass
def search(
self, query: str, limit: int, score_threshold: float
) -> Dict[str, Any] | List[Any]:
) -> dict[str, Any] | list[Any]:
return {}
def reset(self) -> None:

View File

@@ -2,7 +2,7 @@ import json
import logging
import sqlite3
from pathlib import Path
from typing import Any, Dict, List, Optional
from typing import Any
from crewai.task import Task
from crewai.utilities import Printer
@@ -18,7 +18,7 @@ class KickoffTaskOutputsSQLiteStorage:
An updated SQLite storage class for kickoff task outputs storage.
"""
def __init__(self, db_path: Optional[str] = None) -> None:
def __init__(self, db_path: str | None = None) -> None:
if db_path is None:
# Get the parent directory of the default db path and create our db file there
db_path = str(Path(db_storage_path()) / "latest_kickoff_task_outputs.db")
@@ -62,10 +62,10 @@ class KickoffTaskOutputsSQLiteStorage:
def add(
self,
task: Task,
output: Dict[str, Any],
output: dict[str, Any],
task_index: int,
was_replayed: bool = False,
inputs: Dict[str, Any] | None = None,
inputs: dict[str, Any] | None = None,
) -> None:
"""Add a new task output record to the database.
@@ -153,7 +153,7 @@ class KickoffTaskOutputsSQLiteStorage:
logger.error(error_msg)
raise DatabaseOperationError(error_msg, e)
def load(self) -> List[Dict[str, Any]]:
def load(self) -> list[dict[str, Any]]:
"""Load all task output records from the database.
Returns:

View File

@@ -1,7 +1,7 @@
import json
import sqlite3
from pathlib import Path
from typing import Any, Dict, List, Optional, Union
from typing import Any
from crewai.utilities import Printer
from crewai.utilities.paths import db_storage_path
@@ -12,9 +12,7 @@ class LTMSQLiteStorage:
An updated SQLite storage class for LTM data storage.
"""
def __init__(
self, db_path: Optional[str] = None
) -> None:
def __init__(self, db_path: str | None = None) -> None:
if db_path is None:
# Get the parent directory of the default db path and create our db file there
db_path = str(Path(db_storage_path()) / "long_term_memory_storage.db")
@@ -53,9 +51,9 @@ class LTMSQLiteStorage:
def save(
self,
task_description: str,
metadata: Dict[str, Any],
metadata: dict[str, Any],
datetime: str,
score: Union[int, float],
score: int | float,
) -> None:
"""Saves data to the LTM table with error handling."""
try:
@@ -75,9 +73,7 @@ class LTMSQLiteStorage:
color="red",
)
def load(
self, task_description: str, latest_n: int
) -> Optional[List[Dict[str, Any]]]:
def load(self, task_description: str, latest_n: int) -> list[dict[str, Any]] | None:
"""Queries the LTM table by task description with error handling."""
try:
with sqlite3.connect(self.db_path) as conn:
@@ -125,4 +121,4 @@ class LTMSQLiteStorage:
content=f"MEMORY ERROR: An error occurred while deleting all rows in LTM: {e}",
color="red",
)
return None
return

View File

@@ -14,16 +14,16 @@ from .annotations import (
from .crew_base import CrewBase
__all__ = [
"CrewBase",
"after_kickoff",
"agent",
"before_kickoff",
"cache_handler",
"callback",
"crew",
"task",
"llm",
"output_json",
"output_pydantic",
"task",
"tool",
"callback",
"CrewBase",
"llm",
"cache_handler",
"before_kickoff",
"after_kickoff",
]

View File

@@ -1,5 +1,5 @@
from collections.abc import Callable
from functools import wraps
from typing import Callable
from crewai import Crew
from crewai.project.utils import memoize
@@ -36,15 +36,13 @@ def task(func):
def agent(func):
"""Marks a method as a crew agent."""
func.is_agent = True
func = memoize(func)
return func
return memoize(func)
def llm(func):
"""Marks a method as an LLM provider."""
func.is_llm = True
func = memoize(func)
return func
return memoize(func)
def output_json(cls):
@@ -91,7 +89,7 @@ def crew(func) -> Callable[..., Crew]:
agents = self._original_agents.items()
# Instantiate tasks in order
for task_name, task_method in tasks:
for _task_name, task_method in tasks:
task_instance = task_method(self)
instantiated_tasks.append(task_instance)
agent_instance = getattr(task_instance, "agent", None)
@@ -100,7 +98,7 @@ def crew(func) -> Callable[..., Crew]:
agent_roles.add(agent_instance.role)
# Instantiate agents not included by tasks
for agent_name, agent_method in agents:
for _agent_name, agent_method in agents:
agent_instance = agent_method(self)
if agent_instance.role not in agent_roles:
instantiated_agents.append(agent_instance)
@@ -117,9 +115,9 @@ def crew(func) -> Callable[..., Crew]:
return wrapper
for _, callback in self._before_kickoff.items():
for callback in self._before_kickoff.values():
crew.before_kickoff_callbacks.append(callback_wrapper(callback, self))
for _, callback in self._after_kickoff.items():
for callback in self._after_kickoff.values():
crew.after_kickoff_callbacks.append(callback_wrapper(callback, self))
return crew

View File

@@ -1,17 +1,17 @@
"""RAG (Retrieval-Augmented Generation) infrastructure for CrewAI."""
import sys
import importlib
import sys
from types import ModuleType
from typing import Any
from crewai.rag.config.types import RagConfigType
from crewai.rag.config.utils import set_rag_config
_module_path = __path__
_module_file = __file__
class _RagModule(ModuleType):
"""Module wrapper to intercept attribute setting for config."""
@@ -51,8 +51,10 @@ class _RagModule(ModuleType):
"""
try:
return importlib.import_module(f"{self.__name__}.{name}")
except ImportError:
raise AttributeError(f"module '{self.__name__}' has no attribute '{name}'")
except ImportError as e:
raise AttributeError(
f"module '{self.__name__}' has no attribute '{name}'"
) from e
sys.modules[__name__] = _RagModule(__name__)

View File

@@ -1 +1 @@
"""Optional imports for RAG configuration providers."""
"""Optional imports for RAG configuration providers."""

View File

@@ -1,7 +1,7 @@
"""Base classes for missing provider configurations."""
from typing import Literal
from dataclasses import field
from typing import Literal
from pydantic import ConfigDict
from pydantic.dataclasses import dataclass as pyd_dataclass

Some files were not shown because too many files have changed in this diff Show More