Mirror of https://github.com/crewAIInc/crewAI.git (synced 2025-12-16 04:18:35 +00:00)
fix: add ConfigDict for Pydantic model_config and ClassVar annotations
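The diff below migrates Pydantic model configuration to ConfigDict and modernizes type annotations. A minimal sketch of the pattern the commit title names (hypothetical model, not taken from the diff):

    from typing import ClassVar
    from pydantic import BaseModel, ConfigDict

    class ExampleModel(BaseModel):
        # ConfigDict is the Pydantic v2 way to declare model_config
        model_config = ConfigDict(arbitrary_types_allowed=True)
        # ClassVar marks class-level constants so Pydantic does not treat them as fields
        DEFAULT_LABEL: ClassVar[str] = "example"
        name: str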
@@ -1,17 +1,10 @@
import shutil
import subprocess
import time
from collections.abc import Callable, Sequence
from typing import (
Any,
Callable,
Dict,
List,
Literal,
Optional,
Sequence,
Tuple,
Type,
Union,
)

from pydantic import Field, InstanceOf, PrivateAttr, model_validator
@@ -19,6 +12,24 @@ from pydantic import Field, InstanceOf, PrivateAttr, model_validator
from crewai.agents import CacheHandler
from crewai.agents.agent_builder.base_agent import BaseAgent
from crewai.agents.crew_agent_executor import CrewAgentExecutor
from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.agent_events import (
AgentExecutionCompletedEvent,
AgentExecutionErrorEvent,
AgentExecutionStartedEvent,
)
from crewai.events.types.knowledge_events import (
KnowledgeQueryCompletedEvent,
KnowledgeQueryFailedEvent,
KnowledgeQueryStartedEvent,
KnowledgeRetrievalCompletedEvent,
KnowledgeRetrievalStartedEvent,
KnowledgeSearchQueryFailedEvent,
)
from crewai.events.types.memory_events import (
MemoryRetrievalCompletedEvent,
MemoryRetrievalStartedEvent,
)
from crewai.knowledge.knowledge import Knowledge
from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource
from crewai.knowledge.utils.knowledge_utils import extract_knowledge_context
@@ -38,24 +49,6 @@ from crewai.utilities.agent_utils import (
)
from crewai.utilities.constants import TRAINED_AGENTS_DATA_FILE, TRAINING_DATA_FILE
from crewai.utilities.converter import generate_model_description
from crewai.events.types.agent_events import (
AgentExecutionCompletedEvent,
AgentExecutionErrorEvent,
AgentExecutionStartedEvent,
)
from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.memory_events import (
MemoryRetrievalStartedEvent,
MemoryRetrievalCompletedEvent,
)
from crewai.events.types.knowledge_events import (
KnowledgeQueryCompletedEvent,
KnowledgeQueryFailedEvent,
KnowledgeQueryStartedEvent,
KnowledgeRetrievalCompletedEvent,
KnowledgeRetrievalStartedEvent,
KnowledgeSearchQueryFailedEvent,
)
from crewai.utilities.llm_utils import create_llm
from crewai.utilities.token_counter_callback import TokenCalcHandler
from crewai.utilities.training_handler import CrewTrainingHandler
@@ -87,36 +80,36 @@ class Agent(BaseAgent):
"""

_times_executed: int = PrivateAttr(default=0)
max_execution_time: Optional[int] = Field(
max_execution_time: int | None = Field(
default=None,
description="Maximum execution time for an agent to execute a task",
)
agent_ops_agent_name: str = None # type: ignore # Incompatible types in assignment (expression has type "None", variable has type "str")
agent_ops_agent_id: str = None # type: ignore # Incompatible types in assignment (expression has type "None", variable has type "str")
step_callback: Optional[Any] = Field(
step_callback: Any | None = Field(
default=None,
description="Callback to be executed after each step of the agent execution.",
)
use_system_prompt: Optional[bool] = Field(
use_system_prompt: bool | None = Field(
default=True,
description="Use system prompt for the agent.",
)
llm: Union[str, InstanceOf[BaseLLM], Any] = Field(
llm: str | InstanceOf[BaseLLM] | Any = Field(
description="Language model that will run the agent.", default=None
)
function_calling_llm: Optional[Union[str, InstanceOf[BaseLLM], Any]] = Field(
function_calling_llm: str | InstanceOf[BaseLLM] | Any | None = Field(
description="Language model that will run the agent.", default=None
)
system_template: Optional[str] = Field(
system_template: str | None = Field(
default=None, description="System format for the agent."
)
prompt_template: Optional[str] = Field(
prompt_template: str | None = Field(
default=None, description="Prompt format for the agent."
)
response_template: Optional[str] = Field(
response_template: str | None = Field(
default=None, description="Response format for the agent."
)
allow_code_execution: Optional[bool] = Field(
allow_code_execution: bool | None = Field(
default=False, description="Enable code execution for the agent."
)
respect_context_window: bool = Field(
@@ -147,31 +140,31 @@ class Agent(BaseAgent):
default=False,
description="Whether the agent should reflect and create a plan before executing a task.",
)
max_reasoning_attempts: Optional[int] = Field(
max_reasoning_attempts: int | None = Field(
default=None,
description="Maximum number of reasoning attempts before executing the task. If None, will try until ready.",
)
embedder: Optional[Dict[str, Any]] = Field(
embedder: dict[str, Any] | None = Field(
default=None,
description="Embedder configuration for the agent.",
)
agent_knowledge_context: Optional[str] = Field(
agent_knowledge_context: str | None = Field(
default=None,
description="Knowledge context for the agent.",
)
crew_knowledge_context: Optional[str] = Field(
crew_knowledge_context: str | None = Field(
default=None,
description="Knowledge context for the crew.",
)
knowledge_search_query: Optional[str] = Field(
knowledge_search_query: str | None = Field(
default=None,
description="Knowledge search query for the agent dynamically generated by the agent.",
)
from_repository: Optional[str] = Field(
from_repository: str | None = Field(
default=None,
description="The Agent's role to be used from your repository.",
)
guardrail: Optional[Union[Callable[[Any], Tuple[bool, Any]], str]] = Field(
guardrail: Callable[[Any], tuple[bool, Any]] | str | None = Field(
default=None,
description="Function or string description of a guardrail to validate agent output",
)
@@ -180,7 +173,7 @@ class Agent(BaseAgent):
)

@model_validator(mode="before")
def validate_from_repository(cls, v):
def validate_from_repository(self, v):
if v is not None and (from_repository := v.get("from_repository")):
return load_agent_from_repository(from_repository) | v
return v
@@ -208,7 +201,7 @@ class Agent(BaseAgent):
self.cache_handler = CacheHandler()
self.set_cache_handler(self.cache_handler)

def set_knowledge(self, crew_embedder: Optional[Dict[str, Any]] = None):
def set_knowledge(self, crew_embedder: dict[str, Any] | None = None):
try:
if self.embedder is None and crew_embedder:
self.embedder = crew_embedder
@@ -224,7 +217,7 @@ class Agent(BaseAgent):
)
self.knowledge.add_sources()
except (TypeError, ValueError) as e:
raise ValueError(f"Invalid Knowledge Configuration: {str(e)}")
raise ValueError(f"Invalid Knowledge Configuration: {e!s}") from e

def _is_any_available_memory(self) -> bool:
"""Check if any memory is available."""
@@ -244,8 +237,8 @@ class Agent(BaseAgent):
def execute_task(
self,
task: Task,
context: Optional[str] = None,
tools: Optional[List[BaseTool]] = None,
context: str | None = None,
tools: list[BaseTool] | None = None,
) -> str:
"""Execute a task with the agent.

@@ -278,11 +271,9 @@ class Agent(BaseAgent):
task.description += f"\n\nReasoning Plan:\n{reasoning_output.plan.plan}"
except Exception as e:
if hasattr(self, "_logger"):
self._logger.log(
"error", f"Error during reasoning process: {str(e)}"
)
self._logger.log("error", f"Error during reasoning process: {e!s}")
else:
print(f"Error during reasoning process: {str(e)}")
print(f"Error during reasoning process: {e!s}")

self._inject_date_to_task(task)

@@ -525,14 +516,14 @@ class Agent(BaseAgent):

try:
return future.result(timeout=timeout)
except concurrent.futures.TimeoutError:
except concurrent.futures.TimeoutError as e:
future.cancel()
raise TimeoutError(
f"Task '{task.description}' execution timed out after {timeout} seconds. Consider increasing max_execution_time or optimizing the task."
)
) from e
except Exception as e:
future.cancel()
raise RuntimeError(f"Task execution failed: {str(e)}")
raise RuntimeError(f"Task execution failed: {e!s}") from e

def _execute_without_timeout(self, task_prompt: str, task: Task) -> str:
"""Execute a task without a timeout.
@@ -554,14 +545,14 @@ class Agent(BaseAgent):
)["output"]

def create_agent_executor(
self, tools: Optional[List[BaseTool]] = None, task=None
self, tools: list[BaseTool] | None = None, task=None
) -> None:
"""Create an agent executor for the agent.

Returns:
An instance of the CrewAgentExecutor class.
"""
raw_tools: List[BaseTool] = tools or self.tools or []
raw_tools: list[BaseTool] = tools or self.tools or []
parsed_tools = parse_tools(raw_tools)

prompt = Prompts(
@@ -603,10 +594,9 @@ class Agent(BaseAgent):
callbacks=[TokenCalcHandler(self._token_process)],
)

def get_delegation_tools(self, agents: List[BaseAgent]):
def get_delegation_tools(self, agents: list[BaseAgent]):
agent_tools = AgentTools(agents=agents)
tools = agent_tools.tools()
return tools
return agent_tools.tools()

def get_multimodal_tools(self) -> Sequence[BaseTool]:
from crewai.tools.agent_tools.add_image_tool import AddImageTool
@@ -654,7 +644,7 @@ class Agent(BaseAgent):
)
return task_prompt

def _render_text_description(self, tools: List[Any]) -> str:
def _render_text_description(self, tools: list[Any]) -> str:
"""Render the tool name and description in plain text.

Output will be in the format of:
@@ -664,15 +654,13 @@ class Agent(BaseAgent):
search: This tool is used for search
calculator: This tool is used for math
"""
description = "\n".join(
return "\n".join(
[
f"Tool name: {tool.name}\nTool description:\n{tool.description}"
for tool in tools
]
)

return description

def _inject_date_to_task(self, task):
"""Inject the current date into the task description if inject_date is enabled."""
if self.inject_date:
@@ -700,9 +688,9 @@ class Agent(BaseAgent):
task.description += f"\n\nCurrent Date: {current_date}"
except Exception as e:
if hasattr(self, "_logger"):
self._logger.log("warning", f"Failed to inject date: {str(e)}")
self._logger.log("warning", f"Failed to inject date: {e!s}")
else:
print(f"Warning: Failed to inject date: {str(e)}")
print(f"Warning: Failed to inject date: {e!s}")

def _validate_docker_installation(self) -> None:
"""Check if Docker is installed and running."""
@@ -718,10 +706,10 @@ class Agent(BaseAgent):
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
)
except subprocess.CalledProcessError:
except subprocess.CalledProcessError as e:
raise RuntimeError(
f"Docker is not running. Please start Docker to use code execution with agent: {self.role}"
)
) from e

def __repr__(self):
return f"Agent(role={self.role}, goal={self.goal}, backstory={self.backstory})"
@@ -796,8 +784,8 @@ class Agent(BaseAgent):

def kickoff(
self,
messages: Union[str, List[Dict[str, str]]],
response_format: Optional[Type[Any]] = None,
messages: str | list[dict[str, str]],
response_format: type[Any] | None = None,
) -> LiteAgentOutput:
"""
Execute the agent with the given messages using a LiteAgent instance.
@@ -836,8 +824,8 @@ class Agent(BaseAgent):

async def kickoff_async(
self,
messages: Union[str, List[Dict[str, str]]],
response_format: Optional[Type[Any]] = None,
messages: str | list[dict[str, str]],
response_format: type[Any] | None = None,
) -> LiteAgentOutput:
"""
Execute the agent asynchronously with the given messages using a LiteAgent instance.
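Throughout these hunks, typing.Optional/Union/List/Dict/Tuple/Type annotations are rewritten with PEP 604 unions and builtin generics. A minimal illustration of the two equivalent styles (hypothetical signature, not from the diff):

    # Old style (typing module)
    from typing import Dict, List, Optional
    def run(tools: Optional[List[str]] = None, config: Optional[Dict[str, int]] = None) -> Optional[str]: ...

    # New style (PEP 585/604, Python 3.10+)
    def run(tools: list[str] | None = None, config: dict[str, int] | None = None) -> str | None: ...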
@@ -1,5 +1,12 @@
from crewai.agents.cache.cache_handler import CacheHandler
from crewai.agents.parser import parse, AgentAction, AgentFinish, OutputParserException
from crewai.agents.parser import AgentAction, AgentFinish, OutputParserException, parse
from crewai.agents.tools_handler import ToolsHandler

__all__ = ["CacheHandler", "parse", "AgentAction", "AgentFinish", "OutputParserException", "ToolsHandler"]
__all__ = [
"AgentAction",
"AgentFinish",
"CacheHandler",
"OutputParserException",
"ToolsHandler",
"parse",
]

@@ -1,7 +1,7 @@
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional
from typing import Any

from pydantic import PrivateAttr
from pydantic import ConfigDict, PrivateAttr

from crewai.agent import BaseAgent
from crewai.tools import BaseTool
@@ -16,22 +16,21 @@ class BaseAgentAdapter(BaseAgent, ABC):
"""

adapted_structured_output: bool = False
_agent_config: Optional[Dict[str, Any]] = PrivateAttr(default=None)
_agent_config: dict[str, Any] | None = PrivateAttr(default=None)

model_config = {"arbitrary_types_allowed": True}
model_config = ConfigDict(arbitrary_types_allowed=True)

def __init__(self, agent_config: Optional[Dict[str, Any]] = None, **kwargs: Any):
def __init__(self, agent_config: dict[str, Any] | None = None, **kwargs: Any):
super().__init__(adapted_agent=True, **kwargs)
self._agent_config = agent_config

@abstractmethod
def configure_tools(self, tools: Optional[List[BaseTool]] = None) -> None:
def configure_tools(self, tools: list[BaseTool] | None = None) -> None:
"""Configure and adapt tools for the specific agent implementation.

Args:
tools: Optional list of BaseTool instances to be configured
"""
pass

def configure_structured_output(self, structured_output: Any) -> None:
"""Configure the structured output for the specific agent implementation.
@@ -39,4 +38,3 @@ class BaseAgentAdapter(BaseAgent, ABC):
Args:
structured_output: The structured output to be configured
"""
pass

@@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
from typing import Any, List, Optional
from typing import Any

from crewai.tools.base_tool import BaseTool

@@ -12,23 +12,22 @@ class BaseToolAdapter(ABC):
different frameworks and platforms.
"""

original_tools: List[BaseTool]
converted_tools: List[Any]
original_tools: list[BaseTool]
converted_tools: list[Any]

def __init__(self, tools: Optional[List[BaseTool]] = None):
def __init__(self, tools: list[BaseTool] | None = None):
self.original_tools = tools or []
self.converted_tools = []

@abstractmethod
def configure_tools(self, tools: List[BaseTool]) -> None:
def configure_tools(self, tools: list[BaseTool]) -> None:
"""Configure and convert tools for the specific implementation.

Args:
tools: List of BaseTool instances to be configured and converted
"""
pass

def tools(self) -> List[Any]:
def tools(self) -> list[Any]:
"""Return all converted tools."""
return self.converted_tools
@@ -1,8 +1,9 @@
import uuid
from abc import ABC, abstractmethod
from collections.abc import Callable
from copy import copy as shallow_copy
from hashlib import md5
from typing import Any, Callable, Dict, List, Optional, TypeVar
from typing import Any, TypeVar

from pydantic import (
UUID4,
@@ -25,7 +26,6 @@ from crewai.security.security_config import SecurityConfig
from crewai.tools.base_tool import BaseTool, Tool
from crewai.utilities import I18N, Logger, RPMController
from crewai.utilities.config import process_config
from crewai.utilities.converter import Converter
from crewai.utilities.string_utils import interpolate_only

T = TypeVar("T", bound="BaseAgent")
@@ -81,17 +81,17 @@ class BaseAgent(ABC, BaseModel):

__hash__ = object.__hash__ # type: ignore
_logger: Logger = PrivateAttr(default_factory=lambda: Logger(verbose=False))
_rpm_controller: Optional[RPMController] = PrivateAttr(default=None)
_rpm_controller: RPMController | None = PrivateAttr(default=None)
_request_within_rpm_limit: Any = PrivateAttr(default=None)
_original_role: Optional[str] = PrivateAttr(default=None)
_original_goal: Optional[str] = PrivateAttr(default=None)
_original_backstory: Optional[str] = PrivateAttr(default=None)
_original_role: str | None = PrivateAttr(default=None)
_original_goal: str | None = PrivateAttr(default=None)
_original_backstory: str | None = PrivateAttr(default=None)
_token_process: TokenProcess = PrivateAttr(default_factory=TokenProcess)
id: UUID4 = Field(default_factory=uuid.uuid4, frozen=True)
role: str = Field(description="Role of the agent")
goal: str = Field(description="Objective of the agent")
backstory: str = Field(description="Backstory of the agent")
config: Optional[Dict[str, Any]] = Field(
config: dict[str, Any] | None = Field(
description="Configuration for the agent", default=None, exclude=True
)
cache: bool = Field(
@@ -100,7 +100,7 @@ class BaseAgent(ABC, BaseModel):
verbose: bool = Field(
default=False, description="Verbose mode for the Agent Execution"
)
max_rpm: Optional[int] = Field(
max_rpm: int | None = Field(
default=None,
description="Maximum number of requests per minute for the agent execution to be respected.",
)
@@ -108,7 +108,7 @@ class BaseAgent(ABC, BaseModel):
default=False,
description="Enable agent to delegate and ask questions among each other.",
)
tools: Optional[List[BaseTool]] = Field(
tools: list[BaseTool] | None = Field(
default_factory=list, description="Tools at agents' disposal"
)
max_iter: int = Field(
@@ -122,27 +122,27 @@ class BaseAgent(ABC, BaseModel):
)
crew: Any = Field(default=None, description="Crew to which the agent belongs.")
i18n: I18N = Field(default=I18N(), description="Internationalization settings.")
cache_handler: Optional[InstanceOf[CacheHandler]] = Field(
cache_handler: InstanceOf[CacheHandler] | None = Field(
default=None, description="An instance of the CacheHandler class."
)
tools_handler: InstanceOf[ToolsHandler] = Field(
default_factory=ToolsHandler,
description="An instance of the ToolsHandler class.",
)
tools_results: List[Dict[str, Any]] = Field(
tools_results: list[dict[str, Any]] = Field(
default=[], description="Results of the tools used by the agent."
)
max_tokens: Optional[int] = Field(
max_tokens: int | None = Field(
default=None, description="Maximum number of tokens for the agent's execution."
)
knowledge: Optional[Knowledge] = Field(
knowledge: Knowledge | None = Field(
default=None, description="Knowledge for the agent."
)
knowledge_sources: Optional[List[BaseKnowledgeSource]] = Field(
knowledge_sources: list[BaseKnowledgeSource] | None = Field(
default=None,
description="Knowledge sources for the agent.",
)
knowledge_storage: Optional[Any] = Field(
knowledge_storage: Any | None = Field(
default=None,
description="Custom knowledge storage for the agent.",
)
@@ -150,13 +150,13 @@ class BaseAgent(ABC, BaseModel):
default_factory=SecurityConfig,
description="Security configuration for the agent, including fingerprinting.",
)
callbacks: List[Callable] = Field(
callbacks: list[Callable] = Field(
default=[], description="Callbacks to be used for the agent"
)
adapted_agent: bool = Field(
default=False, description="Whether the agent is adapted"
)
knowledge_config: Optional[KnowledgeConfig] = Field(
knowledge_config: KnowledgeConfig | None = Field(
default=None,
description="Knowledge configuration for the agent such as limits and threshold",
)
@@ -168,7 +168,7 @@ class BaseAgent(ABC, BaseModel):

@field_validator("tools")
@classmethod
def validate_tools(cls, tools: List[Any]) -> List[BaseTool]:
def validate_tools(cls, tools: list[Any]) -> list[BaseTool]:
"""Validate and process the tools provided to the agent.

This method ensures that each tool is either an instance of BaseTool
@@ -221,7 +221,7 @@ class BaseAgent(ABC, BaseModel):

@field_validator("id", mode="before")
@classmethod
def _deny_user_set_id(cls, v: Optional[UUID4]) -> None:
def _deny_user_set_id(cls, v: UUID4 | None) -> None:
if v:
raise PydanticCustomError(
"may_not_set_field", "This field is not to be set by the user.", {}
@@ -252,8 +252,8 @@ class BaseAgent(ABC, BaseModel):
def execute_task(
self,
task: Any,
context: Optional[str] = None,
tools: Optional[List[BaseTool]] = None,
context: str | None = None,
tools: list[BaseTool] | None = None,
) -> str:
pass

@@ -262,9 +262,8 @@ class BaseAgent(ABC, BaseModel):
pass

@abstractmethod
def get_delegation_tools(self, agents: List["BaseAgent"]) -> List[BaseTool]:
def get_delegation_tools(self, agents: list["BaseAgent"]) -> list[BaseTool]:
"""Set the task tools that init BaseAgenTools class."""
pass

def copy(self: T) -> T: # type: ignore # Signature of "copy" incompatible with supertype "BaseModel"
"""Create a deep copy of the Agent."""
@@ -309,7 +308,7 @@ class BaseAgent(ABC, BaseModel):

copied_data = self.model_dump(exclude=exclude)
copied_data = {k: v for k, v in copied_data.items() if v is not None}
copied_agent = type(self)(
return type(self)(
**copied_data,
llm=existing_llm,
tools=self.tools,
@@ -318,9 +317,7 @@ class BaseAgent(ABC, BaseModel):
knowledge_storage=copied_knowledge_storage,
)

return copied_agent

def interpolate_inputs(self, inputs: Dict[str, Any]) -> None:
def interpolate_inputs(self, inputs: dict[str, Any]) -> None:
"""Interpolate inputs into the agent description and backstory."""
if self._original_role is None:
self._original_role = self.role
@@ -362,5 +359,5 @@ class BaseAgent(ABC, BaseModel):
self._rpm_controller = rpm_controller
self.create_agent_executor()

def set_knowledge(self, crew_embedder: Optional[Dict[str, Any]] = None):
def set_knowledge(self, crew_embedder: dict[str, Any] | None = None):
pass
@@ -1,13 +1,13 @@
import time
from typing import TYPE_CHECKING, Dict, List
from typing import TYPE_CHECKING

from crewai.events.event_listener import event_listener
from crewai.memory.entity.entity_memory_item import EntityMemoryItem
from crewai.memory.long_term.long_term_memory_item import LongTermMemoryItem
from crewai.utilities import I18N
from crewai.utilities.converter import ConverterError
from crewai.utilities.evaluators.task_evaluator import TaskEvaluator
from crewai.utilities.printer import Printer
from crewai.events.event_listener import event_listener

if TYPE_CHECKING:
from crewai.agents.agent_builder.base_agent import BaseAgent
@@ -21,7 +21,7 @@ class CrewAgentExecutorMixin:
task: "Task"
iterations: int
max_iter: int
messages: List[Dict[str, str]]
messages: list[dict[str, str]]
_i18n: I18N
_printer: Printer = Printer()

@@ -46,7 +46,6 @@ class CrewAgentExecutorMixin:
)
except Exception as e:
print(f"Failed to add to short term memory: {e}")
pass

def _create_external_memory(self, output) -> None:
"""Create and save a external-term memory item if conditions are met."""
@@ -67,7 +66,6 @@ class CrewAgentExecutorMixin:
)
except Exception as e:
print(f"Failed to add to external memory: {e}")
pass

def _create_long_term_memory(self, output) -> None:
"""Create and save long-term and entity memory items based on evaluation."""
@@ -113,10 +111,8 @@ class CrewAgentExecutorMixin:
self.crew._entity_memory.save(entity_memories)
except AttributeError as e:
print(f"Missing attributes for long term memory: {e}")
pass
except Exception as e:
print(f"Failed to add to long term memory: {e}")
pass
elif (
self.crew
and self.crew._long_term_memory

@@ -251,9 +251,8 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
i18n=self._i18n,
)
continue
else:
handle_unknown_error(self._printer, e)
raise e
handle_unknown_error(self._printer, e)
raise e
finally:
self.iterations += 1

@@ -324,9 +323,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
self.agent,
AgentLogsStartedEvent(
agent_role=self.agent.role,
task_description=(
getattr(self.task, "description") if self.task else "Not Found"
),
task_description=(self.task.description if self.task else "Not Found"),
verbose=self.agent.verbose
or (hasattr(self, "crew") and getattr(self.crew, "verbose", False)),
),
@@ -415,8 +412,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
"""
prompt = prompt.replace("{input}", inputs["input"])
prompt = prompt.replace("{tool_names}", inputs["tool_names"])
prompt = prompt.replace("{tools}", inputs["tools"])
return prompt
return prompt.replace("{tools}", inputs["tools"])

def _handle_human_feedback(self, formatted_answer: AgentFinish) -> AgentFinish:
"""Process human feedback.
@@ -10,9 +10,9 @@ from dataclasses import dataclass
from json_repair import repair_json

from crewai.agents.constants import (
ACTION_INPUT_ONLY_REGEX,
ACTION_INPUT_REGEX,
ACTION_REGEX,
ACTION_INPUT_ONLY_REGEX,
FINAL_ANSWER_ACTION,
MISSING_ACTION_AFTER_THOUGHT_ERROR_MESSAGE,
MISSING_ACTION_INPUT_AFTER_ACTION_ERROR_MESSAGE,
@@ -104,7 +104,7 @@ def parse(text: str) -> AgentAction | AgentFinish:
final_answer = final_answer[:-3].rstrip()
return AgentFinish(thought=thought, output=final_answer, text=text)

elif action_match:
if action_match:
action = action_match.group(1)
clean_action = _clean_action(action)

@@ -121,16 +121,15 @@ def parse(text: str) -> AgentAction | AgentFinish:
raise OutputParserException(
f"{MISSING_ACTION_AFTER_THOUGHT_ERROR_MESSAGE}\n{_I18N.slice('final_answer_format')}",
)
elif not ACTION_INPUT_ONLY_REGEX.search(text):
if not ACTION_INPUT_ONLY_REGEX.search(text):
raise OutputParserException(
MISSING_ACTION_INPUT_AFTER_ACTION_ERROR_MESSAGE,
)
else:
err_format = _I18N.slice("format_without_tools")
error = f"{err_format}"
raise OutputParserException(
error,
)
err_format = _I18N.slice("format_without_tools")
error = f"{err_format}"
raise OutputParserException(
error,
)


def _extract_thought(text: str) -> str:
@@ -149,8 +148,7 @@ def _extract_thought(text: str) -> str:
return ""
thought = text[:thought_index].strip()
# Remove any triple backticks from the thought string
thought = thought.replace("```", "").strip()
return thought
return thought.replace("```", "").strip()


def _clean_action(text: str) -> str:

@@ -1,8 +1,8 @@
"""Tools handler for managing tool execution and caching."""

from crewai.agents.cache.cache_handler import CacheHandler
from crewai.tools.cache_tools.cache_tools import CacheTools
from crewai.tools.tool_calling import InstructorToolCalling, ToolCalling
from crewai.agents.cache.cache_handler import CacheHandler


class ToolsHandler:

@@ -1,5 +1,6 @@
from crewai.cli.authentication.providers.base_provider import BaseProvider


class Auth0Provider(BaseProvider):
def get_authorize_url(self) -> str:
return f"https://{self._get_domain()}/oauth/device/code"
@@ -1,30 +1,26 @@
from abc import ABC, abstractmethod

from crewai.cli.authentication.main import Oauth2Settings


class BaseProvider(ABC):
def __init__(self, settings: Oauth2Settings):
self.settings = settings

@abstractmethod
def get_authorize_url(self) -> str:
...
def get_authorize_url(self) -> str: ...

@abstractmethod
def get_token_url(self) -> str:
...
def get_token_url(self) -> str: ...

@abstractmethod
def get_jwks_url(self) -> str:
...
def get_jwks_url(self) -> str: ...

@abstractmethod
def get_issuer(self) -> str:
...
def get_issuer(self) -> str: ...

@abstractmethod
def get_audience(self) -> str:
...
def get_audience(self) -> str: ...

@abstractmethod
def get_client_id(self) -> str:
...
def get_client_id(self) -> str: ...

@@ -1,5 +1,6 @@
from crewai.cli.authentication.providers.base_provider import BaseProvider


class OktaProvider(BaseProvider):
def get_authorize_url(self) -> str:
return f"https://{self.settings.domain}/oauth2/default/v1/device/authorize"

@@ -1,5 +1,6 @@
from crewai.cli.authentication.providers.base_provider import BaseProvider


class WorkosProvider(BaseProvider):
def get_authorize_url(self) -> str:
return f"https://{self._get_domain()}/oauth2/device_authorization"
@@ -17,9 +18,11 @@ class WorkosProvider(BaseProvider):
return self.settings.audience or ""

def get_client_id(self) -> str:
assert self.settings.client_id is not None, "Client ID is required"
if self.settings.client_id is None:
raise RuntimeError("Client ID is required")
return self.settings.client_id

def _get_domain(self) -> str:
assert self.settings.domain is not None, "Domain is required"
if self.settings.domain is None:
raise RuntimeError("Domain is required")
return self.settings.domain

@@ -17,8 +17,6 @@ def validate_jwt_token(
missing required claims).
"""

decoded_token = None

try:
jwk_client = PyJWKClient(jwks_url)
signing_key = jwk_client.get_signing_key_from_jwt(jwt_token)
@@ -26,7 +24,7 @@ def validate_jwt_token(
_unverified_decoded_token = jwt.decode(
jwt_token, options={"verify_signature": False}
)
decoded_token = jwt.decode(
return jwt.decode(
jwt_token,
signing_key.key,
algorithms=["RS256"],
@@ -40,7 +38,6 @@ def validate_jwt_token(
"require": ["exp", "iat", "iss", "aud", "sub"],
},
)
return decoded_token

except jwt.ExpiredSignatureError:
raise Exception("Token has expired.")
@@ -55,8 +52,8 @@ def validate_jwt_token(
f"Invalid token issuer. Got: '{actual_issuer}'. Expected: '{issuer}'"
)
except jwt.MissingRequiredClaimError as e:
raise Exception(f"Token is missing required claims: {str(e)}")
raise Exception(f"Token is missing required claims: {e!s}")
except jwt.exceptions.PyJWKClientError as e:
raise Exception(f"JWKS or key processing error: {str(e)}")
raise Exception(f"JWKS or key processing error: {e!s}")
except jwt.InvalidTokenError as e:
raise Exception(f"Invalid token: {str(e)}")
raise Exception(f"Invalid token: {e!s}")
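The error-handling hunks above consistently switch to explicit exception chaining (raise ... from e) and the {e!s} conversion inside f-strings. A small sketch of the idiom, using a hypothetical loader function not taken from the diff:

    import json

    def load_settings(path: str) -> dict:
        try:
            with open(path) as f:
                return json.load(f)
        except (OSError, json.JSONDecodeError) as e:
            # "from e" keeps the original exception as __cause__ in the traceback;
            # {e!s} is the f-string equivalent of str(e).
            raise RuntimeError(f"Invalid settings file: {e!s}") from e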
@@ -1,13 +1,13 @@
from importlib.metadata import version as get_version
from typing import Optional

import click
from crewai.cli.config import Settings
from crewai.cli.settings.main import SettingsCommand

from crewai.cli.add_crew_to_flow import add_crew_to_flow
from crewai.cli.config import Settings
from crewai.cli.create_crew import create_crew
from crewai.cli.create_flow import create_flow
from crewai.cli.crew_chat import run_chat
from crewai.cli.settings.main import SettingsCommand
from crewai.memory.storage.kickoff_task_outputs_storage import (
KickoffTaskOutputsSQLiteStorage,
)
@@ -237,13 +237,11 @@ def login():
@crewai.group()
def deploy():
"""Deploy the Crew CLI group."""
pass


@crewai.group()
def tool():
"""Tool Repository related commands."""
pass


@deploy.command(name="create")
@@ -263,7 +261,7 @@ def deploy_list():

@deploy.command(name="push")
@click.option("-u", "--uuid", type=str, help="Crew UUID parameter")
def deploy_push(uuid: Optional[str]):
def deploy_push(uuid: str | None):
"""Deploy the Crew."""
deploy_cmd = DeployCommand()
deploy_cmd.deploy(uuid=uuid)
@@ -271,7 +269,7 @@ def deploy_push(uuid: Optional[str]):

@deploy.command(name="status")
@click.option("-u", "--uuid", type=str, help="Crew UUID parameter")
def deply_status(uuid: Optional[str]):
def deply_status(uuid: str | None):
"""Get the status of a deployment."""
deploy_cmd = DeployCommand()
deploy_cmd.get_crew_status(uuid=uuid)
@@ -279,7 +277,7 @@ def deply_status(uuid: Optional[str]):

@deploy.command(name="logs")
@click.option("-u", "--uuid", type=str, help="Crew UUID parameter")
def deploy_logs(uuid: Optional[str]):
def deploy_logs(uuid: str | None):
"""Get the logs of a deployment."""
deploy_cmd = DeployCommand()
deploy_cmd.get_crew_logs(uuid=uuid)
@@ -287,7 +285,7 @@ def deploy_logs(uuid: Optional[str]):

@deploy.command(name="remove")
@click.option("-u", "--uuid", type=str, help="Crew UUID parameter")
def deploy_remove(uuid: Optional[str]):
def deploy_remove(uuid: str | None):
"""Remove a deployment."""
deploy_cmd = DeployCommand()
deploy_cmd.remove_crew(uuid=uuid)
@@ -327,7 +325,6 @@ def tool_publish(is_public: bool, force: bool):
@crewai.group()
def flow():
"""Flow related commands."""
pass


@flow.command(name="kickoff")
@@ -359,7 +356,7 @@ def chat():
and using the Chat LLM to generate responses.
"""
click.secho(
"\nStarting a conversation with the Crew\n" "Type 'exit' or Ctrl+C to quit.\n",
"\nStarting a conversation with the Crew\nType 'exit' or Ctrl+C to quit.\n",
)

run_chat()
@@ -368,7 +365,6 @@ def chat():

@crewai.group(invoke_without_command=True)
def org():
"""Organization management commands."""
pass


@org.command("list")
@@ -396,7 +392,6 @@ def current():
@crewai.group()
def enterprise():
"""Enterprise Configuration commands."""
pass


@enterprise.command("configure")
@@ -410,7 +405,6 @@ def enterprise_configure(enterprise_url: str):
@crewai.group()
def config():
"""CLI Configuration commands."""
pass


@config.command("list")
@@ -1,15 +1,14 @@
import json
from pathlib import Path
from typing import Optional

from pydantic import BaseModel, Field

from crewai.cli.constants import (
DEFAULT_CREWAI_ENTERPRISE_URL,
CREWAI_ENTERPRISE_DEFAULT_OAUTH2_PROVIDER,
CREWAI_ENTERPRISE_DEFAULT_OAUTH2_AUDIENCE,
CREWAI_ENTERPRISE_DEFAULT_OAUTH2_CLIENT_ID,
CREWAI_ENTERPRISE_DEFAULT_OAUTH2_DOMAIN,
CREWAI_ENTERPRISE_DEFAULT_OAUTH2_PROVIDER,
DEFAULT_CREWAI_ENTERPRISE_URL,
)
from crewai.cli.shared.token_manager import TokenManager

@@ -56,20 +55,20 @@ HIDDEN_SETTINGS_KEYS = [


class Settings(BaseModel):
enterprise_base_url: Optional[str] = Field(
enterprise_base_url: str | None = Field(
default=DEFAULT_CLI_SETTINGS["enterprise_base_url"],
description="Base URL of the CrewAI Enterprise instance",
)
tool_repository_username: Optional[str] = Field(
tool_repository_username: str | None = Field(
None, description="Username for interacting with the Tool Repository"
)
tool_repository_password: Optional[str] = Field(
tool_repository_password: str | None = Field(
None, description="Password for interacting with the Tool Repository"
)
org_name: Optional[str] = Field(
org_name: str | None = Field(
None, description="Name of the currently active organization"
)
org_uuid: Optional[str] = Field(
org_uuid: str | None = Field(
None, description="UUID of the currently active organization"
)
config_path: Path = Field(default=DEFAULT_CONFIG_PATH, frozen=True, exclude=True)
@@ -79,7 +78,7 @@ class Settings(BaseModel):
default=DEFAULT_CLI_SETTINGS["oauth2_provider"],
)

oauth2_audience: Optional[str] = Field(
oauth2_audience: str | None = Field(
description="OAuth2 audience value, typically used to identify the target API or resource.",
default=DEFAULT_CLI_SETTINGS["oauth2_audience"],
)
@@ -16,48 +16,72 @@ from crewai.cli.utils import copy_template, load_env_vars, write_env_file
def create_folder_structure(name, parent_folder=None):
import keyword
import re

name = name.rstrip('/')

name = name.rstrip("/")

if not name.strip():
raise ValueError("Project name cannot be empty or contain only whitespace")

folder_name = name.replace(" ", "_").replace("-", "_").lower()
folder_name = re.sub(r'[^a-zA-Z0-9_]', '', folder_name)

folder_name = re.sub(r"[^a-zA-Z0-9_]", "", folder_name)

# Check if the name starts with invalid characters or is primarily invalid
if re.match(r'^[^a-zA-Z0-9_-]+', name):
raise ValueError(f"Project name '{name}' contains no valid characters for a Python module name")

if re.match(r"^[^a-zA-Z0-9_-]+", name):
raise ValueError(
f"Project name '{name}' contains no valid characters for a Python module name"
)

if not folder_name:
raise ValueError(f"Project name '{name}' contains no valid characters for a Python module name")

raise ValueError(
f"Project name '{name}' contains no valid characters for a Python module name"
)

if folder_name[0].isdigit():
raise ValueError(f"Project name '{name}' would generate folder name '{folder_name}' which cannot start with a digit (invalid Python module name)")

raise ValueError(
f"Project name '{name}' would generate folder name '{folder_name}' which cannot start with a digit (invalid Python module name)"
)

if keyword.iskeyword(folder_name):
raise ValueError(f"Project name '{name}' would generate folder name '{folder_name}' which is a reserved Python keyword")

raise ValueError(
f"Project name '{name}' would generate folder name '{folder_name}' which is a reserved Python keyword"
)

if not folder_name.isidentifier():
raise ValueError(f"Project name '{name}' would generate invalid Python module name '{folder_name}'")

raise ValueError(
f"Project name '{name}' would generate invalid Python module name '{folder_name}'"
)

class_name = name.replace("_", " ").replace("-", " ").title().replace(" ", "")

class_name = re.sub(r'[^a-zA-Z0-9_]', '', class_name)

class_name = re.sub(r"[^a-zA-Z0-9_]", "", class_name)

if not class_name:
raise ValueError(f"Project name '{name}' contains no valid characters for a Python class name")

raise ValueError(
f"Project name '{name}' contains no valid characters for a Python class name"
)

if class_name[0].isdigit():
raise ValueError(f"Project name '{name}' would generate class name '{class_name}' which cannot start with a digit")

raise ValueError(
f"Project name '{name}' would generate class name '{class_name}' which cannot start with a digit"
)

# Check if the original name (before title casing) is a keyword
original_name_clean = re.sub(r'[^a-zA-Z0-9_]', '', name.replace("_", "").replace("-", "").lower())
if keyword.iskeyword(original_name_clean) or keyword.iskeyword(class_name) or class_name in ('True', 'False', 'None'):
raise ValueError(f"Project name '{name}' would generate class name '{class_name}' which is a reserved Python keyword")

original_name_clean = re.sub(
r"[^a-zA-Z0-9_]", "", name.replace("_", "").replace("-", "").lower()
)
if (
keyword.iskeyword(original_name_clean)
or keyword.iskeyword(class_name)
or class_name in ("True", "False", "None")
):
raise ValueError(
f"Project name '{name}' would generate class name '{class_name}' which is a reserved Python keyword"
)

if not class_name.isidentifier():
raise ValueError(f"Project name '{name}' would generate invalid Python class name '{class_name}'")
raise ValueError(
f"Project name '{name}' would generate invalid Python class name '{class_name}'"
)

if parent_folder:
folder_path = Path(parent_folder) / folder_name
@@ -172,7 +196,7 @@ def create_crew(name, provider=None, skip_provider=False, parent_folder=None):
)

# Check if the selected provider has predefined models
if selected_provider in MODELS and MODELS[selected_provider]:
if MODELS.get(selected_provider):
while True:
selected_model = select_model(selected_provider, provider_models)
if selected_model is None: # User typed 'q'
@@ -5,7 +5,7 @@ import sys
import threading
import time
from pathlib import Path
from typing import Any, Dict, List, Optional, Set, Tuple
from typing import Any

import click
import tomli
@@ -116,7 +116,7 @@ def show_loading(event: threading.Event):
print()


def initialize_chat_llm(crew: Crew) -> Optional[LLM | BaseLLM]:
def initialize_chat_llm(crew: Crew) -> LLM | BaseLLM | None:
"""Initializes the chat LLM and handles exceptions."""
try:
return create_llm(crew.chat_llm)
@@ -157,7 +157,7 @@ def build_system_message(crew_chat_inputs: ChatInputs) -> str:
)


def create_tool_function(crew: Crew, messages: List[Dict[str, str]]) -> Any:
def create_tool_function(crew: Crew, messages: list[dict[str, str]]) -> Any:
"""Creates a wrapper function for running the crew tool with messages."""

def run_crew_tool_with_messages(**kwargs):
@@ -221,9 +221,9 @@ def get_user_input() -> str:
def handle_user_input(
user_input: str,
chat_llm: LLM,
messages: List[Dict[str, str]],
crew_tool_schema: Dict[str, Any],
available_functions: Dict[str, Any],
messages: list[dict[str, str]],
crew_tool_schema: dict[str, Any],
available_functions: dict[str, Any],
) -> None:
if user_input.strip().lower() == "exit":
click.echo("Exiting chat. Goodbye!")
@@ -281,7 +281,7 @@ def generate_crew_tool_schema(crew_inputs: ChatInputs) -> dict:
}


def run_crew_tool(crew: Crew, messages: List[Dict[str, str]], **kwargs):
def run_crew_tool(crew: Crew, messages: list[dict[str, str]], **kwargs):
"""
Runs the crew using crew.kickoff(inputs=kwargs) and returns the output.

@@ -304,9 +304,8 @@ def run_crew_tool(crew: Crew, messages: List[Dict[str, str]], **kwargs):
crew_output = crew.kickoff(inputs=kwargs)

# Convert CrewOutput to a string to send back to the user
result = str(crew_output)
return str(crew_output)

return result
except Exception as e:
# Exit the chat and show the error message
click.secho("An error occurred while running the crew:", fg="red")
@@ -314,7 +313,7 @@ def run_crew_tool(crew: Crew, messages: List[Dict[str, str]], **kwargs):
sys.exit(1)


def load_crew_and_name() -> Tuple[Crew, str]:
def load_crew_and_name() -> tuple[Crew, str]:
"""
Loads the crew by importing the crew class from the user's project.

@@ -395,7 +394,7 @@ def generate_crew_chat_inputs(crew: Crew, crew_name: str, chat_llm) -> ChatInput
)


def fetch_required_inputs(crew: Crew) -> Set[str]:
def fetch_required_inputs(crew: Crew) -> set[str]:
"""
Extracts placeholders from the crew's tasks and agents.

@@ -406,7 +405,7 @@ def fetch_required_inputs(crew: Crew) -> Set[str]:
Set[str]: A set of placeholder names.
"""
placeholder_pattern = re.compile(r"\{(.+?)\}")
required_inputs: Set[str] = set()
required_inputs: set[str] = set()

# Scan tasks
for task in crew.tasks:
@@ -479,9 +478,7 @@ def generate_input_description_with_ai(input_name: str, crew: Crew, chat_llm) ->
f"{context}"
)
response = chat_llm.call(messages=[{"role": "user", "content": prompt}])
description = response.strip()

return description
return response.strip()


def generate_crew_description_with_ai(crew: Crew, chat_llm) -> str:
@@ -531,6 +528,4 @@ def generate_crew_description_with_ai(crew: Crew, chat_llm) -> str:
f"{context}"
)
response = chat_llm.call(messages=[{"role": "user", "content": prompt}])
crew_description = response.strip()

return crew_description
return response.strip()

@@ -64,8 +64,7 @@ class Repository:
"""Return True if the Git repository is fully synced with the remote, False otherwise."""
if self.has_uncommitted_changes() or self.is_ahead_or_behind():
return False
else:
return True
return True

def origin_url(self) -> str | None:
"""Get the Git repository's remote URL."""

@@ -12,7 +12,7 @@ def install_crew(proxy_options: list[str]) -> None:
Install the crew by running the UV command to lock and install.
"""
try:
command = ["uv", "sync"] + proxy_options
command = ["uv", "sync", *proxy_options]
subprocess.run(command, check=True, capture_output=False, text=True)

except subprocess.CalledProcessError as e:
@@ -1,11 +1,10 @@
from typing import List, Optional
from urllib.parse import urljoin

import requests

from crewai.cli.config import Settings
from crewai.cli.version import get_crewai_version
from crewai.cli.constants import DEFAULT_CREWAI_ENTERPRISE_URL
from crewai.cli.version import get_crewai_version


class PlusAPI:
@@ -56,9 +55,9 @@ class PlusAPI:
handle: str,
is_public: bool,
version: str,
description: Optional[str],
description: str | None,
encoded_file: str,
available_exports: Optional[List[str]] = None,
available_exports: list[str] | None = None,
):
params = {
"handle": handle,

@@ -1,10 +1,10 @@
import os
import certifi
import json
import os
import time
from collections import defaultdict
from pathlib import Path

import certifi
import click
import requests

@@ -25,7 +25,7 @@ def select_choice(prompt_message, choices):

provider_models = get_provider_data()
if not provider_models:
return
return None
click.secho(prompt_message, fg="cyan")
for idx, choice in enumerate(choices, start=1):
click.secho(f"{idx}. {choice}", fg="cyan")
@@ -67,7 +67,7 @@ def select_provider(provider_models):
all_providers = sorted(set(predefined_providers + list(provider_models.keys())))

provider = select_choice(
"Select a provider to set up:", predefined_providers + ["other"]
"Select a provider to set up:", [*predefined_providers, "other"]
)
if provider is None: # User typed 'q'
return None
@@ -102,10 +102,9 @@ def select_model(provider, provider_models):
click.secho(f"No models available for provider '{provider}'.", fg="red")
return None

selected_model = select_choice(
return select_choice(
f"Select a model to use for {provider.capitalize()}:", available_models
)
return selected_model


def load_provider_data(cache_file, cache_expiry):
@@ -165,7 +164,7 @@ def fetch_provider_data(cache_file):
Returns:
- dict or None: The fetched provider data or None if the operation fails.
"""
ssl_config = os.environ['SSL_CERT_FILE'] = certifi.where()
ssl_config = os.environ["SSL_CERT_FILE"] = certifi.where()

try:
response = requests.get(JSON_URL, stream=True, timeout=60, verify=ssl_config)

@@ -1,6 +1,5 @@
import subprocess
from enum import Enum
from typing import List, Optional

import click
from packaging import version
@@ -3,7 +3,7 @@ import os
import sys
from datetime import datetime
from pathlib import Path
from typing import Optional

from cryptography.fernet import Fernet


@@ -49,7 +49,7 @@ class TokenManager:
encrypted_data = self.fernet.encrypt(json.dumps(data).encode())
self.save_secure_file(self.file_path, encrypted_data)

def get_token(self) -> Optional[str]:
def get_token(self) -> str | None:
"""
Get the access token if it is valid and not expired.

@@ -113,7 +113,7 @@ class TokenManager:
# Set appropriate permissions (read/write for owner only)
os.chmod(file_path, 0o600)

def read_secure_file(self, filename: str) -> Optional[bytes]:
def read_secure_file(self, filename: str) -> bytes | None:
"""
Read the content of a secure file.


@@ -5,7 +5,7 @@ import sys
from functools import reduce
from inspect import getmro, isclass, isfunction, ismethod
from pathlib import Path
from typing import Any, Dict, List, get_type_hints
from typing import Any, get_type_hints

import click
import tomli
@@ -41,8 +41,7 @@ def copy_template(src, dst, name, class_name, folder_name):
def read_toml(file_path: str = "pyproject.toml"):
"""Read the content of a TOML file and return it as a dictionary."""
with open(file_path, "rb") as f:
toml_dict = tomli.load(f)
return toml_dict
return tomli.load(f)


def parse_toml(content):
@@ -77,7 +76,7 @@ def get_project_description(


def _get_project_attribute(
pyproject_path: str, keys: List[str], require: bool
pyproject_path: str, keys: list[str], require: bool
) -> Any | None:
"""Get an attribute from the pyproject.toml file."""
attribute = None
@@ -96,7 +95,10 @@ def _get_project_attribute(
except FileNotFoundError:
console.print(f"Error: {pyproject_path} not found.", style="bold red")
except KeyError:
console.print(f"Error: {pyproject_path} is not a valid pyproject.toml file.", style="bold red")
console.print(
f"Error: {pyproject_path} is not a valid pyproject.toml file.",
style="bold red",
)
except tomllib.TOMLDecodeError if sys.version_info >= (3, 11) else Exception as e: # type: ignore
console.print(
f"Error: {pyproject_path} is not a valid TOML file."
@@ -117,7 +119,7 @@ def _get_project_attribute(
return attribute


def _get_nested_value(data: Dict[str, Any], keys: List[str]) -> Any:
def _get_nested_value(data: dict[str, Any], keys: list[str]) -> Any:
return reduce(dict.__getitem__, keys, data)


@@ -296,7 +298,10 @@ def get_crews(crew_path: str = "crew.py", require: bool = False) -> list[Crew]:
try:
crew_instances.extend(fetch_crews(module_attr))
except Exception as e:
console.print(f"Error processing attribute {attr_name}: {e}", style="bold red")
console.print(
f"Error processing attribute {attr_name}: {e}",
style="bold red",
)
continue

# If we found crew instances, break out of the loop
@@ -304,12 +309,15 @@ def get_crews(crew_path: str = "crew.py", require: bool = False) -> list[Crew]:
break

except Exception as exec_error:
console.print(f"Error executing module: {exec_error}", style="bold red")
console.print(
f"Error executing module: {exec_error}",
style="bold red",
)

except (ImportError, AttributeError) as e:
if require:
console.print(
f"Error importing crew from {crew_path}: {str(e)}",
f"Error importing crew from {crew_path}: {e!s}",
style="bold red",
)
continue
@@ -325,7 +333,7 @@ def get_crews(crew_path: str = "crew.py", require: bool = False) -> list[Crew]:
except Exception as e:
if require:
console.print(
f"Unexpected error while loading crew: {str(e)}", style="bold red"
f"Unexpected error while loading crew: {e!s}", style="bold red"
)
raise SystemExit
return crew_instances
@@ -348,8 +356,7 @@ def get_crew_instance(module_attr) -> Crew | None:

if isinstance(module_attr, Crew):
return module_attr
else:
return None
return None


def fetch_crews(module_attr) -> list[Crew]:
@@ -402,7 +409,7 @@ def extract_available_exports(dir_path: str = "src"):
return available_exports

except Exception as e:
console.print(f"[red]Error: Could not extract tool classes: {str(e)}[/red]")
console.print(f"[red]Error: Could not extract tool classes: {e!s}[/red]")
console.print(
"Please ensure your project contains valid tools (classes inheriting from BaseTool or functions with @tool decorator)."
)
@@ -440,7 +447,7 @@ def _load_tools_from_init(init_file: Path) -> list[dict[str, Any]]:
]

except Exception as e:
console.print(f"[red]Warning: Could not load {init_file}: {str(e)}[/red]")
console.print(f"[red]Warning: Could not load {init_file}: {e!s}[/red]")
raise SystemExit(1)

finally:
@@ -1,21 +1,23 @@
import os
import contextvars
from typing import Optional
import os
from contextlib import contextmanager

_platform_integration_token: contextvars.ContextVar[Optional[str]] = contextvars.ContextVar(
"platform_integration_token", default=None
_platform_integration_token: contextvars.ContextVar[str | None] = (
contextvars.ContextVar("platform_integration_token", default=None)
)


def set_platform_integration_token(integration_token: str) -> None:
_platform_integration_token.set(integration_token)

def get_platform_integration_token() -> Optional[str]:

def get_platform_integration_token() -> str | None:
token = _platform_integration_token.get()
if token is None:
token = os.getenv("CREWAI_PLATFORM_INTEGRATION_TOKEN")
return token


@contextmanager
def platform_context(integration_token: str):
token = _platform_integration_token.set(integration_token)

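The rewritten module above keeps the integration token in a ContextVar, with the CREWAI_PLATFORM_INTEGRATION_TOKEN environment variable as a fallback when nothing has been set. A minimal usage sketch, assuming platform_context resets the token when the block exits (that reset logic sits outside this hunk):

    with platform_context("token-123"):
        assert get_platform_integration_token() == "token-123"
    # after the block, the value comes from the environment variable (or is None)
    fallback = get_platform_integration_token()
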
@@ -1,5 +1,5 @@
import json
from typing import Any, Dict, Optional
from typing import Any

from pydantic import BaseModel, Field

@@ -12,10 +12,10 @@ class CrewOutput(BaseModel):
"""Class that represents the result of a crew."""

raw: str = Field(description="Raw output of crew", default="")
pydantic: Optional[BaseModel] = Field(
pydantic: BaseModel | None = Field(
description="Pydantic output of Crew", default=None
)
json_dict: Optional[Dict[str, Any]] = Field(
json_dict: dict[str, Any] | None = Field(
description="JSON dict output of Crew", default=None
)
tasks_output: list[TaskOutput] = Field(
@@ -24,7 +24,7 @@ class CrewOutput(BaseModel):
token_usage: UsageMetrics = Field(description="Processed token summary", default={})

@property
def json(self) -> Optional[str]:
def json(self) -> str | None:
if self.tasks_output[-1].output_format != OutputFormat.JSON:
raise ValueError(
"No JSON output found in the final task. Please make sure to set the output_json property in the final task in your crew."
@@ -32,7 +32,7 @@ class CrewOutput(BaseModel):

return json.dumps(self.json_dict)

def to_dict(self) -> Dict[str, Any]:
def to_dict(self) -> dict[str, Any]:
"""Convert json_output and pydantic_output to a dictionary."""
output_dict = {}
if self.json_dict:
@@ -44,10 +44,9 @@ class CrewOutput(BaseModel):
def __getitem__(self, key):
if self.pydantic and hasattr(self.pydantic, key):
return getattr(self.pydantic, key)
elif self.json_dict and key in self.json_dict:
if self.json_dict and key in self.json_dict:
return self.json_dict[key]
else:
raise KeyError(f"Key '{key}' not found in CrewOutput.")
raise KeyError(f"Key '{key}' not found in CrewOutput.")

def __str__(self):
if self.pydantic:

@@ -1,5 +1,6 @@
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, Dict, Optional
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from crewai.utilities.serialization import to_serializable
|
||||
@@ -10,11 +11,11 @@ class BaseEvent(BaseModel):
|
||||
|
||||
timestamp: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
|
||||
type: str
|
||||
source_fingerprint: Optional[str] = None # UUID string of the source entity
|
||||
source_type: Optional[str] = (
|
||||
source_fingerprint: str | None = None # UUID string of the source entity
|
||||
source_type: str | None = (
|
||||
None # "agent", "task", "crew", "memory", "entity_memory", "short_term_memory", "long_term_memory", "external_memory"
|
||||
)
|
||||
fingerprint_metadata: Optional[Dict[str, Any]] = None # Any relevant metadata
|
||||
fingerprint_metadata: dict[str, Any] | None = None # Any relevant metadata
|
||||
|
||||
def to_json(self, exclude: set[str] | None = None):
|
||||
"""
|
||||
@@ -28,13 +29,13 @@ class BaseEvent(BaseModel):
|
||||
"""
|
||||
return to_serializable(self, exclude=exclude)
|
||||
|
||||
def _set_task_params(self, data: Dict[str, Any]):
|
||||
def _set_task_params(self, data: dict[str, Any]):
|
||||
if "from_task" in data and (task := data["from_task"]):
|
||||
self.task_id = task.id
|
||||
self.task_name = task.name or task.description
|
||||
self.from_task = None
|
||||
|
||||
def _set_agent_params(self, data: Dict[str, Any]):
|
||||
def _set_agent_params(self, data: dict[str, Any]):
|
||||
task = data.get("from_task", None)
|
||||
agent = task.agent if task else data.get("from_agent", None)
|
||||
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import threading
|
||||
from collections.abc import Callable
|
||||
from contextlib import contextmanager
|
||||
from typing import Any, Callable, Dict, List, Type, TypeVar, cast
|
||||
from typing import Any, TypeVar, cast
|
||||
|
||||
from blinker import Signal
|
||||
|
||||
@@ -25,17 +26,17 @@ class CrewAIEventsBus:
|
||||
if cls._instance is None:
|
||||
with cls._lock:
|
||||
if cls._instance is None: # prevent race condition
|
||||
cls._instance = super(CrewAIEventsBus, cls).__new__(cls)
|
||||
cls._instance = super().__new__(cls)
|
||||
cls._instance._initialize()
|
||||
return cls._instance
|
||||
|
||||
def _initialize(self) -> None:
|
||||
"""Initialize the event bus internal state"""
|
||||
self._signal = Signal("crewai_event_bus")
|
||||
self._handlers: Dict[Type[BaseEvent], List[Callable]] = {}
|
||||
self._handlers: dict[type[BaseEvent], list[Callable]] = {}
|
||||
|
||||
def on(
|
||||
self, event_type: Type[EventT]
|
||||
self, event_type: type[EventT]
|
||||
) -> Callable[[Callable[[Any, EventT], None]], Callable[[Any, EventT], None]]:
|
||||
"""
|
||||
Decorator to register an event handler for a specific event type.
|
||||
@@ -82,7 +83,7 @@ class CrewAIEventsBus:
|
||||
self._signal.send(source, event=event)
|
||||
|
||||
def register_handler(
|
||||
self, event_type: Type[EventTypes], handler: Callable[[Any, EventTypes], None]
|
||||
self, event_type: type[EventTypes], handler: Callable[[Any, EventTypes], None]
|
||||
) -> None:
|
||||
"""Register an event handler for a specific event type"""
|
||||
if event_type not in self._handlers:
|
||||
|
||||
@@ -1,15 +1,30 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from io import StringIO
|
||||
from typing import Any, Dict
|
||||
from typing import Any
|
||||
|
||||
from pydantic import Field, PrivateAttr
|
||||
from crewai.llm import LLM
|
||||
from crewai.task import Task
|
||||
from crewai.telemetry.telemetry import Telemetry
|
||||
from crewai.utilities import Logger
|
||||
from crewai.utilities.constants import EMITTER_COLOR
|
||||
|
||||
from crewai.events.base_event_listener import BaseEventListener
|
||||
from crewai.events.types.agent_events import (
|
||||
AgentExecutionCompletedEvent,
|
||||
AgentExecutionStartedEvent,
|
||||
LiteAgentExecutionCompletedEvent,
|
||||
LiteAgentExecutionErrorEvent,
|
||||
LiteAgentExecutionStartedEvent,
|
||||
)
|
||||
from crewai.events.types.crew_events import (
|
||||
CrewKickoffCompletedEvent,
|
||||
CrewKickoffFailedEvent,
|
||||
CrewKickoffStartedEvent,
|
||||
CrewTestCompletedEvent,
|
||||
CrewTestFailedEvent,
|
||||
CrewTestResultEvent,
|
||||
CrewTestStartedEvent,
|
||||
CrewTrainCompletedEvent,
|
||||
CrewTrainFailedEvent,
|
||||
CrewTrainStartedEvent,
|
||||
)
|
||||
from crewai.events.types.knowledge_events import (
|
||||
KnowledgeQueryCompletedEvent,
|
||||
KnowledgeQueryFailedEvent,
|
||||
@@ -25,34 +40,21 @@ from crewai.events.types.llm_events import (
|
||||
LLMStreamChunkEvent,
|
||||
)
|
||||
from crewai.events.types.llm_guardrail_events import (
|
||||
LLMGuardrailStartedEvent,
|
||||
LLMGuardrailCompletedEvent,
|
||||
)
|
||||
from crewai.events.utils.console_formatter import ConsoleFormatter
|
||||
|
||||
from crewai.events.types.agent_events import (
|
||||
AgentExecutionCompletedEvent,
|
||||
AgentExecutionStartedEvent,
|
||||
LiteAgentExecutionCompletedEvent,
|
||||
LiteAgentExecutionErrorEvent,
|
||||
LiteAgentExecutionStartedEvent,
|
||||
LLMGuardrailStartedEvent,
|
||||
)
|
||||
from crewai.events.types.logging_events import (
|
||||
AgentLogsStartedEvent,
|
||||
AgentLogsExecutionEvent,
|
||||
AgentLogsStartedEvent,
|
||||
)
|
||||
from crewai.events.types.crew_events import (
|
||||
CrewKickoffCompletedEvent,
|
||||
CrewKickoffFailedEvent,
|
||||
CrewKickoffStartedEvent,
|
||||
CrewTestCompletedEvent,
|
||||
CrewTestFailedEvent,
|
||||
CrewTestResultEvent,
|
||||
CrewTestStartedEvent,
|
||||
CrewTrainCompletedEvent,
|
||||
CrewTrainFailedEvent,
|
||||
CrewTrainStartedEvent,
|
||||
)
|
||||
from crewai.events.utils.console_formatter import ConsoleFormatter
|
||||
from crewai.llm import LLM
|
||||
from crewai.task import Task
|
||||
from crewai.telemetry.telemetry import Telemetry
|
||||
from crewai.utilities import Logger
|
||||
from crewai.utilities.constants import EMITTER_COLOR
|
||||
|
||||
from .listeners.memory_listener import MemoryListener
|
||||
from .types.flow_events import (
|
||||
FlowCreatedEvent,
|
||||
FlowFinishedEvent,
|
||||
@@ -61,26 +63,24 @@ from .types.flow_events import (
|
||||
MethodExecutionFinishedEvent,
|
||||
MethodExecutionStartedEvent,
|
||||
)
|
||||
from .types.reasoning_events import (
|
||||
AgentReasoningCompletedEvent,
|
||||
AgentReasoningFailedEvent,
|
||||
AgentReasoningStartedEvent,
|
||||
)
|
||||
from .types.task_events import TaskCompletedEvent, TaskFailedEvent, TaskStartedEvent
|
||||
from .types.tool_usage_events import (
|
||||
ToolUsageErrorEvent,
|
||||
ToolUsageFinishedEvent,
|
||||
ToolUsageStartedEvent,
|
||||
)
|
||||
from .types.reasoning_events import (
|
||||
AgentReasoningStartedEvent,
|
||||
AgentReasoningCompletedEvent,
|
||||
AgentReasoningFailedEvent,
|
||||
)
|
||||
|
||||
from .listeners.memory_listener import MemoryListener
|
||||
|
||||
|
||||
class EventListener(BaseEventListener):
|
||||
_instance = None
|
||||
_telemetry: Telemetry = PrivateAttr(default_factory=lambda: Telemetry())
|
||||
logger = Logger(verbose=True, default_color=EMITTER_COLOR)
|
||||
execution_spans: Dict[Task, Any] = Field(default_factory=dict)
|
||||
execution_spans: dict[Task, Any] = Field(default_factory=dict)
|
||||
next_chunk = 0
|
||||
text_stream = StringIO()
|
||||
knowledge_retrieval_in_progress = False
|
||||
|
||||
@@ -6,6 +6,7 @@ from crewai.events.types.agent_events import (
|
||||
AgentExecutionStartedEvent,
|
||||
LiteAgentExecutionCompletedEvent,
|
||||
)
|
||||
|
||||
from .types.crew_events import (
|
||||
CrewKickoffCompletedEvent,
|
||||
CrewKickoffFailedEvent,
|
||||
@@ -24,6 +25,14 @@ from .types.flow_events import (
|
||||
MethodExecutionFinishedEvent,
|
||||
MethodExecutionStartedEvent,
|
||||
)
|
||||
from .types.knowledge_events import (
|
||||
KnowledgeQueryCompletedEvent,
|
||||
KnowledgeQueryFailedEvent,
|
||||
KnowledgeQueryStartedEvent,
|
||||
KnowledgeRetrievalCompletedEvent,
|
||||
KnowledgeRetrievalStartedEvent,
|
||||
KnowledgeSearchQueryFailedEvent,
|
||||
)
|
||||
from .types.llm_events import (
|
||||
LLMCallCompletedEvent,
|
||||
LLMCallFailedEvent,
|
||||
@@ -34,6 +43,21 @@ from .types.llm_guardrail_events import (
|
||||
LLMGuardrailCompletedEvent,
|
||||
LLMGuardrailStartedEvent,
|
||||
)
|
||||
from .types.memory_events import (
|
||||
MemoryQueryCompletedEvent,
|
||||
MemoryQueryFailedEvent,
|
||||
MemoryQueryStartedEvent,
|
||||
MemoryRetrievalCompletedEvent,
|
||||
MemoryRetrievalStartedEvent,
|
||||
MemorySaveCompletedEvent,
|
||||
MemorySaveFailedEvent,
|
||||
MemorySaveStartedEvent,
|
||||
)
|
||||
from .types.reasoning_events import (
|
||||
AgentReasoningCompletedEvent,
|
||||
AgentReasoningFailedEvent,
|
||||
AgentReasoningStartedEvent,
|
||||
)
|
||||
from .types.task_events import (
|
||||
TaskCompletedEvent,
|
||||
TaskFailedEvent,
|
||||
@@ -44,30 +68,6 @@ from .types.tool_usage_events import (
|
||||
ToolUsageFinishedEvent,
|
||||
ToolUsageStartedEvent,
|
||||
)
|
||||
from .types.reasoning_events import (
|
||||
AgentReasoningStartedEvent,
|
||||
AgentReasoningCompletedEvent,
|
||||
AgentReasoningFailedEvent,
|
||||
)
|
||||
from .types.knowledge_events import (
|
||||
KnowledgeRetrievalStartedEvent,
|
||||
KnowledgeRetrievalCompletedEvent,
|
||||
KnowledgeQueryStartedEvent,
|
||||
KnowledgeQueryCompletedEvent,
|
||||
KnowledgeQueryFailedEvent,
|
||||
KnowledgeSearchQueryFailedEvent,
|
||||
)
|
||||
|
||||
from .types.memory_events import (
|
||||
MemorySaveStartedEvent,
|
||||
MemorySaveCompletedEvent,
|
||||
MemorySaveFailedEvent,
|
||||
MemoryQueryStartedEvent,
|
||||
MemoryQueryCompletedEvent,
|
||||
MemoryQueryFailedEvent,
|
||||
MemoryRetrievalStartedEvent,
|
||||
MemoryRetrievalCompletedEvent,
|
||||
)
|
||||
|
||||
EventTypes = Union[
|
||||
CrewKickoffStartedEvent,
|
||||
|
||||
@@ -2,4 +2,4 @@
|
||||
|
||||
This module contains various event listener implementations
|
||||
for handling memory, tracing, and other event-driven functionality.
|
||||
"""
|
||||
"""
|
||||
|
||||
@@ -1,12 +1,12 @@
|
||||
from crewai.events.base_event_listener import BaseEventListener
|
||||
from crewai.events.types.memory_events import (
|
||||
MemoryQueryCompletedEvent,
|
||||
MemoryQueryFailedEvent,
|
||||
MemoryRetrievalCompletedEvent,
|
||||
MemoryRetrievalStartedEvent,
|
||||
MemoryQueryFailedEvent,
|
||||
MemoryQueryCompletedEvent,
|
||||
MemorySaveStartedEvent,
|
||||
MemorySaveCompletedEvent,
|
||||
MemorySaveFailedEvent,
|
||||
MemorySaveStartedEvent,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from dataclasses import dataclass, field, asdict
|
||||
from datetime import datetime, timezone
|
||||
from typing import Dict, Any
|
||||
import uuid
|
||||
from dataclasses import asdict, dataclass, field
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -13,7 +13,7 @@ class TraceEvent:
|
||||
default_factory=lambda: datetime.now(timezone.utc).isoformat()
|
||||
)
|
||||
type: str = ""
|
||||
event_data: Dict[str, Any] = field(default_factory=dict)
|
||||
event_data: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
return asdict(self)
|
||||
|
||||
@@ -2,4 +2,4 @@
|
||||
|
||||
This module contains all event types used throughout the CrewAI system
|
||||
for monitoring and extending agent, crew, task, and tool execution.
|
||||
"""
|
||||
"""
|
||||
|
||||
@@ -2,14 +2,15 @@

from __future__ import annotations

from typing import Any, Dict, List, Optional, Sequence, Union
from collections.abc import Sequence
from typing import Any

from pydantic import model_validator
from pydantic import ConfigDict, model_validator

from crewai.agents.agent_builder.base_agent import BaseAgent
from crewai.events.base_events import BaseEvent
from crewai.tools.base_tool import BaseTool
from crewai.tools.structured_tool import CrewStructuredTool
from crewai.events.base_events import BaseEvent


class AgentExecutionStartedEvent(BaseEvent):
@@ -17,11 +18,11 @@ class AgentExecutionStartedEvent(BaseEvent):

agent: BaseAgent
task: Any
tools: Optional[Sequence[Union[BaseTool, CrewStructuredTool]]]
tools: Sequence[BaseTool | CrewStructuredTool] | None
task_prompt: str
type: str = "agent_execution_started"

model_config = {"arbitrary_types_allowed": True}
model_config = ConfigDict(arbitrary_types_allowed=True)

@model_validator(mode="after")
def set_fingerprint_data(self):
@@ -45,7 +46,7 @@ class AgentExecutionCompletedEvent(BaseEvent):
output: str
type: str = "agent_execution_completed"

model_config = {"arbitrary_types_allowed": True}
model_config = ConfigDict(arbitrary_types_allowed=True)

@model_validator(mode="after")
def set_fingerprint_data(self):
@@ -69,7 +70,7 @@ class AgentExecutionErrorEvent(BaseEvent):
error: str
type: str = "agent_execution_error"

model_config = {"arbitrary_types_allowed": True}
model_config = ConfigDict(arbitrary_types_allowed=True)

@model_validator(mode="after")
def set_fingerprint_data(self):
@@ -89,18 +90,18 @@ class AgentExecutionErrorEvent(BaseEvent):
class LiteAgentExecutionStartedEvent(BaseEvent):
"""Event emitted when a LiteAgent starts executing"""

agent_info: Dict[str, Any]
tools: Optional[Sequence[Union[BaseTool, CrewStructuredTool]]]
messages: Union[str, List[Dict[str, str]]]
agent_info: dict[str, Any]
tools: Sequence[BaseTool | CrewStructuredTool] | None
messages: str | list[dict[str, str]]
type: str = "lite_agent_execution_started"

model_config = {"arbitrary_types_allowed": True}
model_config = ConfigDict(arbitrary_types_allowed=True)


class LiteAgentExecutionCompletedEvent(BaseEvent):
"""Event emitted when a LiteAgent completes execution"""

agent_info: Dict[str, Any]
agent_info: dict[str, Any]
output: str
type: str = "lite_agent_execution_completed"

@@ -108,7 +109,7 @@ class LiteAgentExecutionCompletedEvent(BaseEvent):
class LiteAgentExecutionErrorEvent(BaseEvent):
"""Event emitted when a LiteAgent encounters an error during execution"""

agent_info: Dict[str, Any]
agent_info: dict[str, Any]
error: str
type: str = "lite_agent_execution_error"

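The recurring change in this file swaps the plain-dict model_config for pydantic.ConfigDict, the TypedDict that Pydantic v2 defines for model configuration, so type checkers can validate the config keys while runtime behavior stays the same. A minimal sketch of the pattern (the event class and field below are invented for illustration):

    from pydantic import BaseModel, ConfigDict

    class ExampleEvent(BaseModel):
        # arbitrary_types_allowed lets fields hold non-Pydantic types such as agent or tool instances
        model_config = ConfigDict(arbitrary_types_allowed=True)
        payload: object
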
@@ -1,4 +1,4 @@
|
||||
from typing import TYPE_CHECKING, Any, Dict, Optional, Union
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from crewai.events.base_events import BaseEvent
|
||||
|
||||
@@ -11,8 +11,8 @@ else:
|
||||
class CrewBaseEvent(BaseEvent):
|
||||
"""Base class for crew events with fingerprint handling"""
|
||||
|
||||
crew_name: Optional[str]
|
||||
crew: Optional[Crew] = None
|
||||
crew_name: str | None
|
||||
crew: Crew | None = None
|
||||
|
||||
def __init__(self, **data):
|
||||
super().__init__(**data)
|
||||
@@ -38,7 +38,7 @@ class CrewBaseEvent(BaseEvent):
|
||||
class CrewKickoffStartedEvent(CrewBaseEvent):
|
||||
"""Event emitted when a crew starts execution"""
|
||||
|
||||
inputs: Optional[Dict[str, Any]]
|
||||
inputs: dict[str, Any] | None
|
||||
type: str = "crew_kickoff_started"
|
||||
|
||||
|
||||
@@ -62,7 +62,7 @@ class CrewTrainStartedEvent(CrewBaseEvent):
|
||||
|
||||
n_iterations: int
|
||||
filename: str
|
||||
inputs: Optional[Dict[str, Any]]
|
||||
inputs: dict[str, Any] | None
|
||||
type: str = "crew_train_started"
|
||||
|
||||
|
||||
@@ -85,8 +85,8 @@ class CrewTestStartedEvent(CrewBaseEvent):
|
||||
"""Event emitted when a crew starts testing"""
|
||||
|
||||
n_iterations: int
|
||||
eval_llm: Optional[Union[str, Any]]
|
||||
inputs: Optional[Dict[str, Any]]
|
||||
eval_llm: str | Any | None
|
||||
inputs: dict[str, Any] | None
|
||||
type: str = "crew_test_started"
|
||||
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import Any, Dict, Optional, Union
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
@@ -16,7 +16,7 @@ class FlowStartedEvent(FlowEvent):
|
||||
"""Event emitted when a flow starts execution"""
|
||||
|
||||
flow_name: str
|
||||
inputs: Optional[Dict[str, Any]] = None
|
||||
inputs: dict[str, Any] | None = None
|
||||
type: str = "flow_started"
|
||||
|
||||
|
||||
@@ -32,8 +32,8 @@ class MethodExecutionStartedEvent(FlowEvent):
|
||||
|
||||
flow_name: str
|
||||
method_name: str
|
||||
state: Union[Dict[str, Any], BaseModel]
|
||||
params: Optional[Dict[str, Any]] = None
|
||||
state: dict[str, Any] | BaseModel
|
||||
params: dict[str, Any] | None = None
|
||||
type: str = "method_execution_started"
|
||||
|
||||
|
||||
@@ -43,7 +43,7 @@ class MethodExecutionFinishedEvent(FlowEvent):
|
||||
flow_name: str
|
||||
method_name: str
|
||||
result: Any = None
|
||||
state: Union[Dict[str, Any], BaseModel]
|
||||
state: dict[str, Any] | BaseModel
|
||||
type: str = "method_execution_finished"
|
||||
|
||||
|
||||
@@ -62,7 +62,7 @@ class FlowFinishedEvent(FlowEvent):
|
||||
"""Event emitted when a flow completes execution"""
|
||||
|
||||
flow_name: str
|
||||
result: Optional[Any] = None
|
||||
result: Any | None = None
|
||||
type: str = "flow_finished"
|
||||
|
||||
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
from crewai.events.base_events import BaseEvent
|
||||
|
||||
from crewai.agents.agent_builder.base_agent import BaseAgent
|
||||
from crewai.events.base_events import BaseEvent
|
||||
|
||||
|
||||
class KnowledgeRetrievalStartedEvent(BaseEvent):
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
@@ -7,14 +7,14 @@ from crewai.events.base_events import BaseEvent
|
||||
|
||||
|
||||
class LLMEventBase(BaseEvent):
|
||||
task_name: Optional[str] = None
|
||||
task_id: Optional[str] = None
|
||||
task_name: str | None = None
|
||||
task_id: str | None = None
|
||||
|
||||
agent_id: Optional[str] = None
|
||||
agent_role: Optional[str] = None
|
||||
agent_id: str | None = None
|
||||
agent_role: str | None = None
|
||||
|
||||
from_task: Optional[Any] = None
|
||||
from_agent: Optional[Any] = None
|
||||
from_task: Any | None = None
|
||||
from_agent: Any | None = None
|
||||
|
||||
def __init__(self, **data):
|
||||
super().__init__(**data)
|
||||
@@ -38,11 +38,11 @@ class LLMCallStartedEvent(LLMEventBase):
|
||||
"""
|
||||
|
||||
type: str = "llm_call_started"
|
||||
model: Optional[str] = None
|
||||
messages: Optional[Union[str, List[Dict[str, Any]]]] = None
|
||||
tools: Optional[List[dict[str, Any]]] = None
|
||||
callbacks: Optional[List[Any]] = None
|
||||
available_functions: Optional[Dict[str, Any]] = None
|
||||
model: str | None = None
|
||||
messages: str | list[dict[str, Any]] | None = None
|
||||
tools: list[dict[str, Any]] | None = None
|
||||
callbacks: list[Any] | None = None
|
||||
available_functions: dict[str, Any] | None = None
|
||||
|
||||
|
||||
class LLMCallCompletedEvent(LLMEventBase):
|
||||
@@ -52,7 +52,7 @@ class LLMCallCompletedEvent(LLMEventBase):
|
||||
messages: str | list[dict[str, Any]] | None = None
|
||||
response: Any
|
||||
call_type: LLMCallType
|
||||
model: Optional[str] = None
|
||||
model: str | None = None
|
||||
|
||||
|
||||
class LLMCallFailedEvent(LLMEventBase):
|
||||
@@ -64,13 +64,13 @@ class LLMCallFailedEvent(LLMEventBase):
|
||||
|
||||
class FunctionCall(BaseModel):
|
||||
arguments: str
|
||||
name: Optional[str] = None
|
||||
name: str | None = None
|
||||
|
||||
|
||||
class ToolCall(BaseModel):
|
||||
id: Optional[str] = None
|
||||
id: str | None = None
|
||||
function: FunctionCall
|
||||
type: Optional[str] = None
|
||||
type: str | None = None
|
||||
index: int
|
||||
|
||||
|
||||
@@ -79,4 +79,4 @@ class LLMStreamChunkEvent(LLMEventBase):
|
||||
|
||||
type: str = "llm_stream_chunk"
|
||||
chunk: str
|
||||
tool_call: Optional[ToolCall] = None
|
||||
tool_call: ToolCall | None = None
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from collections.abc import Callable
|
||||
from inspect import getsource
|
||||
from typing import Any, Callable, Optional, Union
|
||||
from typing import Any
|
||||
|
||||
from crewai.events.base_events import BaseEvent
|
||||
|
||||
@@ -13,12 +14,12 @@ class LLMGuardrailStartedEvent(BaseEvent):
|
||||
"""
|
||||
|
||||
type: str = "llm_guardrail_started"
|
||||
guardrail: Union[str, Callable]
|
||||
guardrail: str | Callable
|
||||
retry_count: int
|
||||
|
||||
def __init__(self, **data):
|
||||
from crewai.tasks.llm_guardrail import LLMGuardrail
|
||||
from crewai.tasks.hallucination_guardrail import HallucinationGuardrail
|
||||
from crewai.tasks.llm_guardrail import LLMGuardrail
|
||||
|
||||
super().__init__(**data)
|
||||
|
||||
@@ -41,5 +42,5 @@ class LLMGuardrailCompletedEvent(BaseEvent):
|
||||
type: str = "llm_guardrail_completed"
|
||||
success: bool
|
||||
result: Any
|
||||
error: Optional[str] = None
|
||||
error: str | None = None
|
||||
retry_count: int
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
"""Agent logging events that don't reference BaseAgent to avoid circular imports."""
|
||||
|
||||
from typing import Any, Optional
|
||||
from typing import Any
|
||||
|
||||
from pydantic import ConfigDict
|
||||
|
||||
from crewai.events.base_events import BaseEvent
|
||||
|
||||
@@ -9,7 +11,7 @@ class AgentLogsStartedEvent(BaseEvent):
|
||||
"""Event emitted when agent logs should be shown at start"""
|
||||
|
||||
agent_role: str
|
||||
task_description: Optional[str] = None
|
||||
task_description: str | None = None
|
||||
verbose: bool = False
|
||||
type: str = "agent_logs_started"
|
||||
|
||||
@@ -22,4 +24,4 @@ class AgentLogsExecutionEvent(BaseEvent):
|
||||
verbose: bool = False
|
||||
type: str = "agent_logs_execution"
|
||||
|
||||
model_config = {"arbitrary_types_allowed": True}
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import Any, Dict, Optional
|
||||
from typing import Any
|
||||
|
||||
from crewai.events.base_events import BaseEvent
|
||||
|
||||
@@ -7,12 +7,12 @@ class MemoryBaseEvent(BaseEvent):
|
||||
"""Base event for memory operations"""
|
||||
|
||||
type: str
|
||||
task_id: Optional[str] = None
|
||||
task_name: Optional[str] = None
|
||||
from_task: Optional[Any] = None
|
||||
from_agent: Optional[Any] = None
|
||||
agent_role: Optional[str] = None
|
||||
agent_id: Optional[str] = None
|
||||
task_id: str | None = None
|
||||
task_name: str | None = None
|
||||
from_task: Any | None = None
|
||||
from_agent: Any | None = None
|
||||
agent_role: str | None = None
|
||||
agent_id: str | None = None
|
||||
|
||||
def __init__(self, **data):
|
||||
super().__init__(**data)
|
||||
@@ -26,7 +26,7 @@ class MemoryQueryStartedEvent(MemoryBaseEvent):
|
||||
type: str = "memory_query_started"
|
||||
query: str
|
||||
limit: int
|
||||
score_threshold: Optional[float] = None
|
||||
score_threshold: float | None = None
|
||||
|
||||
|
||||
class MemoryQueryCompletedEvent(MemoryBaseEvent):
|
||||
@@ -36,7 +36,7 @@ class MemoryQueryCompletedEvent(MemoryBaseEvent):
|
||||
query: str
|
||||
results: Any
|
||||
limit: int
|
||||
score_threshold: Optional[float] = None
|
||||
score_threshold: float | None = None
|
||||
query_time_ms: float
|
||||
|
||||
|
||||
@@ -46,7 +46,7 @@ class MemoryQueryFailedEvent(MemoryBaseEvent):
|
||||
type: str = "memory_query_failed"
|
||||
query: str
|
||||
limit: int
|
||||
score_threshold: Optional[float] = None
|
||||
score_threshold: float | None = None
|
||||
error: str
|
||||
|
||||
|
||||
@@ -54,9 +54,9 @@ class MemorySaveStartedEvent(MemoryBaseEvent):
|
||||
"""Event emitted when a memory save operation is started"""
|
||||
|
||||
type: str = "memory_save_started"
|
||||
value: Optional[str] = None
|
||||
metadata: Optional[Dict[str, Any]] = None
|
||||
agent_role: Optional[str] = None
|
||||
value: str | None = None
|
||||
metadata: dict[str, Any] | None = None
|
||||
agent_role: str | None = None
|
||||
|
||||
|
||||
class MemorySaveCompletedEvent(MemoryBaseEvent):
|
||||
@@ -64,8 +64,8 @@ class MemorySaveCompletedEvent(MemoryBaseEvent):
|
||||
|
||||
type: str = "memory_save_completed"
|
||||
value: str
|
||||
metadata: Optional[Dict[str, Any]] = None
|
||||
agent_role: Optional[str] = None
|
||||
metadata: dict[str, Any] | None = None
|
||||
agent_role: str | None = None
|
||||
save_time_ms: float
|
||||
|
||||
|
||||
@@ -73,9 +73,9 @@ class MemorySaveFailedEvent(MemoryBaseEvent):
|
||||
"""Event emitted when a memory save operation fails"""
|
||||
|
||||
type: str = "memory_save_failed"
|
||||
value: Optional[str] = None
|
||||
metadata: Optional[Dict[str, Any]] = None
|
||||
agent_role: Optional[str] = None
|
||||
value: str | None = None
|
||||
metadata: dict[str, Any] | None = None
|
||||
agent_role: str | None = None
|
||||
error: str
|
||||
|
||||
|
||||
@@ -83,13 +83,13 @@ class MemoryRetrievalStartedEvent(MemoryBaseEvent):
|
||||
"""Event emitted when memory retrieval for a task prompt starts"""
|
||||
|
||||
type: str = "memory_retrieval_started"
|
||||
task_id: Optional[str] = None
|
||||
task_id: str | None = None
|
||||
|
||||
|
||||
class MemoryRetrievalCompletedEvent(MemoryBaseEvent):
|
||||
"""Event emitted when memory retrieval for a task prompt completes successfully"""
|
||||
|
||||
type: str = "memory_retrieval_completed"
|
||||
task_id: Optional[str] = None
|
||||
task_id: str | None = None
|
||||
memory_content: str
|
||||
retrieval_time_ms: float
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from typing import Any
|
||||
|
||||
from crewai.events.base_events import BaseEvent
|
||||
from typing import Any, Optional
|
||||
|
||||
|
||||
class ReasoningEvent(BaseEvent):
|
||||
@@ -9,10 +10,10 @@ class ReasoningEvent(BaseEvent):
|
||||
attempt: int = 1
|
||||
agent_role: str
|
||||
task_id: str
|
||||
task_name: Optional[str] = None
|
||||
from_task: Optional[Any] = None
|
||||
agent_id: Optional[str] = None
|
||||
from_agent: Optional[Any] = None
|
||||
task_name: str | None = None
|
||||
from_task: Any | None = None
|
||||
agent_id: str | None = None
|
||||
from_agent: Any | None = None
|
||||
|
||||
def __init__(self, **data):
|
||||
super().__init__(**data)
|
||||
|
||||
@@ -1,15 +1,15 @@
|
||||
from typing import Any, Optional
|
||||
from typing import Any
|
||||
|
||||
from crewai.tasks.task_output import TaskOutput
|
||||
from crewai.events.base_events import BaseEvent
|
||||
from crewai.tasks.task_output import TaskOutput
|
||||
|
||||
|
||||
class TaskStartedEvent(BaseEvent):
|
||||
"""Event emitted when a task starts"""
|
||||
|
||||
type: str = "task_started"
|
||||
context: Optional[str]
|
||||
task: Optional[Any] = None
|
||||
context: str | None
|
||||
task: Any | None = None
|
||||
|
||||
def __init__(self, **data):
|
||||
super().__init__(**data)
|
||||
@@ -29,7 +29,7 @@ class TaskCompletedEvent(BaseEvent):
|
||||
|
||||
output: TaskOutput
|
||||
type: str = "task_completed"
|
||||
task: Optional[Any] = None
|
||||
task: Any | None = None
|
||||
|
||||
def __init__(self, **data):
|
||||
super().__init__(**data)
|
||||
@@ -49,7 +49,7 @@ class TaskFailedEvent(BaseEvent):
|
||||
|
||||
error: str
|
||||
type: str = "task_failed"
|
||||
task: Optional[Any] = None
|
||||
task: Any | None = None
|
||||
|
||||
def __init__(self, **data):
|
||||
super().__init__(**data)
|
||||
@@ -69,7 +69,7 @@ class TaskEvaluationEvent(BaseEvent):
|
||||
|
||||
type: str = "task_evaluation"
|
||||
evaluation_type: str
|
||||
task: Optional[Any] = None
|
||||
task: Any | None = None
|
||||
|
||||
def __init__(self, **data):
|
||||
super().__init__(**data)
|
||||
|
||||
@@ -1,5 +1,8 @@
|
||||
from collections.abc import Callable
|
||||
from datetime import datetime
|
||||
from typing import Any, Callable, Dict, Optional
|
||||
from typing import Any
|
||||
|
||||
from pydantic import ConfigDict
|
||||
|
||||
from crewai.events.base_events import BaseEvent
|
||||
|
||||
@@ -7,21 +10,21 @@ from crewai.events.base_events import BaseEvent
|
||||
class ToolUsageEvent(BaseEvent):
|
||||
"""Base event for tool usage tracking"""
|
||||
|
||||
agent_key: Optional[str] = None
|
||||
agent_role: Optional[str] = None
|
||||
agent_id: Optional[str] = None
|
||||
agent_key: str | None = None
|
||||
agent_role: str | None = None
|
||||
agent_id: str | None = None
|
||||
tool_name: str
|
||||
tool_args: Dict[str, Any] | str
|
||||
tool_class: Optional[str] = None
|
||||
tool_args: dict[str, Any] | str
|
||||
tool_class: str | None = None
|
||||
run_attempts: int | None = None
|
||||
delegations: int | None = None
|
||||
agent: Optional[Any] = None
|
||||
task_name: Optional[str] = None
|
||||
task_id: Optional[str] = None
|
||||
from_task: Optional[Any] = None
|
||||
from_agent: Optional[Any] = None
|
||||
agent: Any | None = None
|
||||
task_name: str | None = None
|
||||
task_id: str | None = None
|
||||
from_task: Any | None = None
|
||||
from_agent: Any | None = None
|
||||
|
||||
model_config = {"arbitrary_types_allowed": True}
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
def __init__(self, **data):
|
||||
super().__init__(**data)
|
||||
@@ -81,9 +84,9 @@ class ToolExecutionErrorEvent(BaseEvent):
|
||||
error: Any
|
||||
type: str = "tool_execution_error"
|
||||
tool_name: str
|
||||
tool_args: Dict[str, Any]
|
||||
tool_args: dict[str, Any]
|
||||
tool_class: Callable
|
||||
agent: Optional[Any] = None
|
||||
agent: Any | None = None
|
||||
|
||||
def __init__(self, **data):
|
||||
super().__init__(**data)
|
||||
|
||||
@@ -1,25 +1,25 @@
from typing import Any, Dict, Optional
from typing import Any, ClassVar

from rich.console import Console
from rich.live import Live
from rich.panel import Panel
from rich.syntax import Syntax
from rich.text import Text
from rich.tree import Tree
from rich.live import Live
from rich.syntax import Syntax


class ConsoleFormatter:
current_crew_tree: Optional[Tree] = None
current_task_branch: Optional[Tree] = None
current_agent_branch: Optional[Tree] = None
current_tool_branch: Optional[Tree] = None
current_flow_tree: Optional[Tree] = None
current_method_branch: Optional[Tree] = None
current_lite_agent_branch: Optional[Tree] = None
tool_usage_counts: Dict[str, int] = {}
current_reasoning_branch: Optional[Tree] = None # Track reasoning status
_live_paused: bool = False
current_llm_tool_tree: Optional[Tree] = None
current_crew_tree: ClassVar[Tree | None] = None
current_task_branch: ClassVar[Tree | None] = None
current_agent_branch: ClassVar[Tree | None] = None
current_tool_branch: ClassVar[Tree | None] = None
current_flow_tree: ClassVar[Tree | None] = None
current_method_branch: ClassVar[Tree | None] = None
current_lite_agent_branch: ClassVar[Tree | None] = None
tool_usage_counts: ClassVar[dict[str, int]] = {}
current_reasoning_branch: ClassVar[Tree | None] = None # Track reasoning status
_live_paused: ClassVar[bool] = False
current_llm_tool_tree: ClassVar[Tree | None] = None

def __init__(self, verbose: bool = False):
self.console = Console(width=None)
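The hunk above annotates ConsoleFormatter's class-level attributes with ClassVar, making explicit that they are shared across instances rather than per-instance fields; linters such as Ruff flag mutable class defaults that lack this annotation (rule RUF012). A minimal sketch of the annotation style (the class name below is invented for illustration):

    from typing import ClassVar

    class Formatter:
        # one counter dict shared by every Formatter instance, not a per-instance field
        tool_usage_counts: ClassVar[dict[str, int]] = {}
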
@@ -29,7 +29,7 @@ class ConsoleFormatter:
|
||||
# instance so the previous render is replaced instead of writing a new one.
|
||||
# Once any non-Tree renderable is printed we stop the Live session so the
|
||||
# final Tree persists on the terminal.
|
||||
self._live: Optional[Live] = None
|
||||
self._live: Live | None = None
|
||||
|
||||
def create_panel(self, content: Text, title: str, style: str = "blue") -> Panel:
|
||||
"""Create a standardized panel with consistent styling."""
|
||||
@@ -45,7 +45,7 @@ class ConsoleFormatter:
|
||||
title: str,
|
||||
name: str,
|
||||
status_style: str = "blue",
|
||||
tool_args: Dict[str, Any] | str = "",
|
||||
tool_args: dict[str, Any] | str = "",
|
||||
**fields,
|
||||
) -> Text:
|
||||
"""Create standardized status content with consistent formatting."""
|
||||
@@ -70,7 +70,7 @@ class ConsoleFormatter:
|
||||
prefix: str,
|
||||
name: str,
|
||||
style: str = "blue",
|
||||
status: Optional[str] = None,
|
||||
status: str | None = None,
|
||||
) -> None:
|
||||
"""Update tree label with consistent formatting."""
|
||||
label = Text()
|
||||
@@ -156,7 +156,7 @@ class ConsoleFormatter:
|
||||
|
||||
def update_crew_tree(
|
||||
self,
|
||||
tree: Optional[Tree],
|
||||
tree: Tree | None,
|
||||
crew_name: str,
|
||||
source_id: str,
|
||||
status: str = "completed",
|
||||
@@ -196,7 +196,7 @@ class ConsoleFormatter:
|
||||
|
||||
self.print_panel(content, title, style)
|
||||
|
||||
def create_crew_tree(self, crew_name: str, source_id: str) -> Optional[Tree]:
|
||||
def create_crew_tree(self, crew_name: str, source_id: str) -> Tree | None:
|
||||
"""Create and initialize a new crew tree with initial status."""
|
||||
if not self.verbose:
|
||||
return None
|
||||
@@ -220,8 +220,8 @@ class ConsoleFormatter:
|
||||
return tree
|
||||
|
||||
def create_task_branch(
|
||||
self, crew_tree: Optional[Tree], task_id: str, task_name: Optional[str] = None
|
||||
) -> Optional[Tree]:
|
||||
self, crew_tree: Tree | None, task_id: str, task_name: str | None = None
|
||||
) -> Tree | None:
|
||||
"""Create and initialize a task branch."""
|
||||
if not self.verbose:
|
||||
return None
|
||||
@@ -255,11 +255,11 @@ class ConsoleFormatter:
|
||||
|
||||
def update_task_status(
|
||||
self,
|
||||
crew_tree: Optional[Tree],
|
||||
crew_tree: Tree | None,
|
||||
task_id: str,
|
||||
agent_role: str,
|
||||
status: str = "completed",
|
||||
task_name: Optional[str] = None,
|
||||
task_name: str | None = None,
|
||||
) -> None:
|
||||
"""Update task status in the tree."""
|
||||
if not self.verbose or crew_tree is None:
|
||||
@@ -306,8 +306,8 @@ class ConsoleFormatter:
|
||||
self.print_panel(content, panel_title, style)
|
||||
|
||||
def create_agent_branch(
|
||||
self, task_branch: Optional[Tree], agent_role: str, crew_tree: Optional[Tree]
|
||||
) -> Optional[Tree]:
|
||||
self, task_branch: Tree | None, agent_role: str, crew_tree: Tree | None
|
||||
) -> Tree | None:
|
||||
"""Create and initialize an agent branch."""
|
||||
if not self.verbose or not task_branch or not crew_tree:
|
||||
return None
|
||||
@@ -325,9 +325,9 @@ class ConsoleFormatter:
|
||||
|
||||
def update_agent_status(
|
||||
self,
|
||||
agent_branch: Optional[Tree],
|
||||
agent_branch: Tree | None,
|
||||
agent_role: str,
|
||||
crew_tree: Optional[Tree],
|
||||
crew_tree: Tree | None,
|
||||
status: str = "completed",
|
||||
) -> None:
|
||||
"""Update agent status in the tree."""
|
||||
@@ -336,7 +336,7 @@ class ConsoleFormatter:
|
||||
# altering the tree. Keeping it a no-op avoids duplicate status lines.
|
||||
return
|
||||
|
||||
def create_flow_tree(self, flow_name: str, flow_id: str) -> Optional[Tree]:
|
||||
def create_flow_tree(self, flow_name: str, flow_id: str) -> Tree | None:
|
||||
"""Create and initialize a flow tree."""
|
||||
content = self.create_status_content(
|
||||
"Starting Flow Execution", flow_name, "blue", ID=flow_id
|
||||
@@ -356,7 +356,7 @@ class ConsoleFormatter:
|
||||
|
||||
return flow_tree
|
||||
|
||||
def start_flow(self, flow_name: str, flow_id: str) -> Optional[Tree]:
|
||||
def start_flow(self, flow_name: str, flow_id: str) -> Tree | None:
|
||||
"""Initialize a flow execution tree."""
|
||||
flow_tree = Tree("")
|
||||
flow_label = Text()
|
||||
@@ -376,7 +376,7 @@ class ConsoleFormatter:
|
||||
|
||||
def update_flow_status(
|
||||
self,
|
||||
flow_tree: Optional[Tree],
|
||||
flow_tree: Tree | None,
|
||||
flow_name: str,
|
||||
flow_id: str,
|
||||
status: str = "completed",
|
||||
@@ -423,11 +423,11 @@ class ConsoleFormatter:
|
||||
|
||||
def update_method_status(
|
||||
self,
|
||||
method_branch: Optional[Tree],
|
||||
flow_tree: Optional[Tree],
|
||||
method_branch: Tree | None,
|
||||
flow_tree: Tree | None,
|
||||
method_name: str,
|
||||
status: str = "running",
|
||||
) -> Optional[Tree]:
|
||||
) -> Tree | None:
|
||||
"""Update method status in the flow tree."""
|
||||
if not flow_tree:
|
||||
return None
|
||||
@@ -480,7 +480,7 @@ class ConsoleFormatter:
|
||||
def handle_llm_tool_usage_started(
|
||||
self,
|
||||
tool_name: str,
|
||||
tool_args: Dict[str, Any] | str,
|
||||
tool_args: dict[str, Any] | str,
|
||||
):
|
||||
# Create status content for the tool usage
|
||||
content = self.create_status_content(
|
||||
@@ -520,11 +520,11 @@ class ConsoleFormatter:
|
||||
|
||||
def handle_tool_usage_started(
|
||||
self,
|
||||
agent_branch: Optional[Tree],
|
||||
agent_branch: Tree | None,
|
||||
tool_name: str,
|
||||
crew_tree: Optional[Tree],
|
||||
tool_args: Dict[str, Any] | str = "",
|
||||
) -> Optional[Tree]:
|
||||
crew_tree: Tree | None,
|
||||
tool_args: dict[str, Any] | str = "",
|
||||
) -> Tree | None:
|
||||
"""Handle tool usage started event."""
|
||||
if not self.verbose:
|
||||
return None
|
||||
@@ -569,9 +569,9 @@ class ConsoleFormatter:
|
||||
|
||||
def handle_tool_usage_finished(
|
||||
self,
|
||||
tool_branch: Optional[Tree],
|
||||
tool_branch: Tree | None,
|
||||
tool_name: str,
|
||||
crew_tree: Optional[Tree],
|
||||
crew_tree: Tree | None,
|
||||
) -> None:
|
||||
"""Handle tool usage finished event."""
|
||||
if not self.verbose or tool_branch is None:
|
||||
@@ -600,10 +600,10 @@ class ConsoleFormatter:
|
||||
|
||||
def handle_tool_usage_error(
|
||||
self,
|
||||
tool_branch: Optional[Tree],
|
||||
tool_branch: Tree | None,
|
||||
tool_name: str,
|
||||
error: str,
|
||||
crew_tree: Optional[Tree],
|
||||
crew_tree: Tree | None,
|
||||
) -> None:
|
||||
"""Handle tool usage error event."""
|
||||
if not self.verbose:
|
||||
@@ -631,9 +631,9 @@ class ConsoleFormatter:
|
||||
|
||||
def handle_llm_call_started(
|
||||
self,
|
||||
agent_branch: Optional[Tree],
|
||||
crew_tree: Optional[Tree],
|
||||
) -> Optional[Tree]:
|
||||
agent_branch: Tree | None,
|
||||
crew_tree: Tree | None,
|
||||
) -> Tree | None:
|
||||
"""Handle LLM call started event."""
|
||||
if not self.verbose:
|
||||
return None
|
||||
@@ -672,9 +672,9 @@ class ConsoleFormatter:
|
||||
|
||||
def handle_llm_call_completed(
|
||||
self,
|
||||
tool_branch: Optional[Tree],
|
||||
agent_branch: Optional[Tree],
|
||||
crew_tree: Optional[Tree],
|
||||
tool_branch: Tree | None,
|
||||
agent_branch: Tree | None,
|
||||
crew_tree: Tree | None,
|
||||
) -> None:
|
||||
"""Handle LLM call completed event."""
|
||||
if not self.verbose:
|
||||
@@ -736,7 +736,7 @@ class ConsoleFormatter:
|
||||
self.print()
|
||||
|
||||
def handle_llm_call_failed(
|
||||
self, tool_branch: Optional[Tree], error: str, crew_tree: Optional[Tree]
|
||||
self, tool_branch: Tree | None, error: str, crew_tree: Tree | None
|
||||
) -> None:
|
||||
"""Handle LLM call failed event."""
|
||||
if not self.verbose:
|
||||
@@ -789,7 +789,7 @@ class ConsoleFormatter:
|
||||
|
||||
def handle_crew_test_started(
|
||||
self, crew_name: str, source_id: str, n_iterations: int
|
||||
) -> Optional[Tree]:
|
||||
) -> Tree | None:
|
||||
"""Handle crew test started event."""
|
||||
if not self.verbose:
|
||||
return None
|
||||
@@ -823,7 +823,7 @@ class ConsoleFormatter:
|
||||
return test_tree
|
||||
|
||||
def handle_crew_test_completed(
|
||||
self, flow_tree: Optional[Tree], crew_name: str
|
||||
self, flow_tree: Tree | None, crew_name: str
|
||||
) -> None:
|
||||
"""Handle crew test completed event."""
|
||||
if not self.verbose:
|
||||
@@ -913,7 +913,7 @@ class ConsoleFormatter:
|
||||
self.print_panel(failure_content, "Test Failure", "red")
|
||||
self.print()
|
||||
|
||||
def create_lite_agent_branch(self, lite_agent_role: str) -> Optional[Tree]:
|
||||
def create_lite_agent_branch(self, lite_agent_role: str) -> Tree | None:
|
||||
"""Create and initialize a lite agent branch."""
|
||||
if not self.verbose:
|
||||
return None
|
||||
@@ -935,10 +935,10 @@ class ConsoleFormatter:
|
||||
|
||||
def update_lite_agent_status(
|
||||
self,
|
||||
lite_agent_branch: Optional[Tree],
|
||||
lite_agent_branch: Tree | None,
|
||||
lite_agent_role: str,
|
||||
status: str = "completed",
|
||||
**fields: Dict[str, Any],
|
||||
**fields: dict[str, Any],
|
||||
) -> None:
|
||||
"""Update lite agent status in the tree."""
|
||||
if not self.verbose or lite_agent_branch is None:
|
||||
@@ -981,7 +981,7 @@ class ConsoleFormatter:
|
||||
lite_agent_role: str,
|
||||
status: str = "started",
|
||||
error: Any = None,
|
||||
**fields: Dict[str, Any],
|
||||
**fields: dict[str, Any],
|
||||
) -> None:
|
||||
"""Handle lite agent execution events with consistent formatting."""
|
||||
if not self.verbose:
|
||||
@@ -1006,9 +1006,9 @@ class ConsoleFormatter:
|
||||
|
||||
def handle_knowledge_retrieval_started(
|
||||
self,
|
||||
agent_branch: Optional[Tree],
|
||||
crew_tree: Optional[Tree],
|
||||
) -> Optional[Tree]:
|
||||
agent_branch: Tree | None,
|
||||
crew_tree: Tree | None,
|
||||
) -> Tree | None:
|
||||
"""Handle knowledge retrieval started event."""
|
||||
if not self.verbose:
|
||||
return None
|
||||
@@ -1034,13 +1034,13 @@ class ConsoleFormatter:
|
||||
|
||||
def handle_knowledge_retrieval_completed(
|
||||
self,
|
||||
agent_branch: Optional[Tree],
|
||||
crew_tree: Optional[Tree],
|
||||
agent_branch: Tree | None,
|
||||
crew_tree: Tree | None,
|
||||
retrieved_knowledge: Any,
|
||||
) -> None:
|
||||
"""Handle knowledge retrieval completed event."""
|
||||
if not self.verbose:
|
||||
return None
|
||||
return
|
||||
|
||||
branch_to_use = self.current_lite_agent_branch or agent_branch
|
||||
tree_to_use = branch_to_use or crew_tree
|
||||
@@ -1062,7 +1062,7 @@ class ConsoleFormatter:
|
||||
)
|
||||
self.print(knowledge_panel)
|
||||
self.print()
|
||||
return None
|
||||
return
|
||||
|
||||
knowledge_branch_found = False
|
||||
for child in branch_to_use.children:
|
||||
@@ -1111,18 +1111,18 @@ class ConsoleFormatter:
|
||||
|
||||
def handle_knowledge_query_started(
|
||||
self,
|
||||
agent_branch: Optional[Tree],
|
||||
agent_branch: Tree | None,
|
||||
task_prompt: str,
|
||||
crew_tree: Optional[Tree],
|
||||
crew_tree: Tree | None,
|
||||
) -> None:
|
||||
"""Handle knowledge query generated event."""
|
||||
if not self.verbose:
|
||||
return None
|
||||
return
|
||||
|
||||
branch_to_use = self.current_lite_agent_branch or agent_branch
|
||||
tree_to_use = branch_to_use or crew_tree
|
||||
if branch_to_use is None or tree_to_use is None:
|
||||
return None
|
||||
return
|
||||
|
||||
query_branch = branch_to_use.add("")
|
||||
self.update_tree_label(
|
||||
@@ -1134,9 +1134,9 @@ class ConsoleFormatter:
|
||||
|
||||
def handle_knowledge_query_failed(
|
||||
self,
|
||||
agent_branch: Optional[Tree],
|
||||
agent_branch: Tree | None,
|
||||
error: str,
|
||||
crew_tree: Optional[Tree],
|
||||
crew_tree: Tree | None,
|
||||
) -> None:
|
||||
"""Handle knowledge query failed event."""
|
||||
if not self.verbose:
|
||||
@@ -1159,18 +1159,18 @@ class ConsoleFormatter:
|
||||
|
||||
def handle_knowledge_query_completed(
|
||||
self,
|
||||
agent_branch: Optional[Tree],
|
||||
crew_tree: Optional[Tree],
|
||||
agent_branch: Tree | None,
|
||||
crew_tree: Tree | None,
|
||||
) -> None:
|
||||
"""Handle knowledge query completed event."""
|
||||
if not self.verbose:
|
||||
return None
|
||||
return
|
||||
|
||||
branch_to_use = self.current_lite_agent_branch or agent_branch
|
||||
tree_to_use = branch_to_use or crew_tree
|
||||
|
||||
if branch_to_use is None or tree_to_use is None:
|
||||
return None
|
||||
return
|
||||
|
||||
query_branch = branch_to_use.add("")
|
||||
self.update_tree_label(query_branch, "✅", "Knowledge Query Completed", "green")
|
||||
@@ -1180,9 +1180,9 @@ class ConsoleFormatter:
|
||||
|
||||
def handle_knowledge_search_query_failed(
|
||||
self,
|
||||
agent_branch: Optional[Tree],
|
||||
agent_branch: Tree | None,
|
||||
error: str,
|
||||
crew_tree: Optional[Tree],
|
||||
crew_tree: Tree | None,
|
||||
) -> None:
|
||||
"""Handle knowledge search query failed event."""
|
||||
if not self.verbose:
|
||||
@@ -1207,10 +1207,10 @@ class ConsoleFormatter:
|
||||
|
||||
def handle_reasoning_started(
|
||||
self,
|
||||
agent_branch: Optional[Tree],
|
||||
agent_branch: Tree | None,
|
||||
attempt: int,
|
||||
crew_tree: Optional[Tree],
|
||||
) -> Optional[Tree]:
|
||||
crew_tree: Tree | None,
|
||||
) -> Tree | None:
|
||||
"""Handle agent reasoning started (or refinement) event."""
|
||||
if not self.verbose:
|
||||
return None
|
||||
@@ -1249,7 +1249,7 @@ class ConsoleFormatter:
|
||||
self,
|
||||
plan: str,
|
||||
ready: bool,
|
||||
crew_tree: Optional[Tree],
|
||||
crew_tree: Tree | None,
|
||||
) -> None:
|
||||
"""Handle agent reasoning completed event."""
|
||||
if not self.verbose:
|
||||
@@ -1292,7 +1292,7 @@ class ConsoleFormatter:
|
||||
def handle_reasoning_failed(
|
||||
self,
|
||||
error: str,
|
||||
crew_tree: Optional[Tree],
|
||||
crew_tree: Tree | None,
|
||||
) -> None:
|
||||
"""Handle agent reasoning failure event."""
|
||||
if not self.verbose:
|
||||
@@ -1329,7 +1329,7 @@ class ConsoleFormatter:
|
||||
def handle_agent_logs_started(
|
||||
self,
|
||||
agent_role: str,
|
||||
task_description: Optional[str] = None,
|
||||
task_description: str | None = None,
|
||||
verbose: bool = False,
|
||||
) -> None:
|
||||
"""Handle agent logs started event."""
|
||||
@@ -1367,10 +1367,11 @@ class ConsoleFormatter:
|
||||
if not verbose:
|
||||
return
|
||||
|
||||
from crewai.agents.parser import AgentAction, AgentFinish
|
||||
import json
|
||||
import re
|
||||
|
||||
from crewai.agents.parser import AgentAction, AgentFinish
|
||||
|
||||
agent_role = agent_role.partition("\n")[0]
|
||||
|
||||
if isinstance(formatted_answer, AgentAction):
|
||||
@@ -1473,9 +1474,9 @@ class ConsoleFormatter:
|
||||
|
||||
def handle_memory_retrieval_started(
|
||||
self,
|
||||
agent_branch: Optional[Tree],
|
||||
crew_tree: Optional[Tree],
|
||||
) -> Optional[Tree]:
|
||||
agent_branch: Tree | None,
|
||||
crew_tree: Tree | None,
|
||||
) -> Tree | None:
|
||||
if not self.verbose:
|
||||
return None
|
||||
|
||||
@@ -1497,13 +1498,13 @@ class ConsoleFormatter:
|
||||
|
||||
def handle_memory_retrieval_completed(
|
||||
self,
|
||||
agent_branch: Optional[Tree],
|
||||
crew_tree: Optional[Tree],
|
||||
agent_branch: Tree | None,
|
||||
crew_tree: Tree | None,
|
||||
memory_content: str,
|
||||
retrieval_time_ms: float,
|
||||
) -> None:
|
||||
if not self.verbose:
|
||||
return None
|
||||
return
|
||||
|
||||
branch_to_use = self.current_lite_agent_branch or agent_branch
|
||||
tree_to_use = branch_to_use or crew_tree
|
||||
@@ -1528,7 +1529,7 @@ class ConsoleFormatter:
|
||||
|
||||
if branch_to_use is None or tree_to_use is None:
|
||||
add_panel()
|
||||
return None
|
||||
return
|
||||
|
||||
memory_branch_found = False
|
||||
for child in branch_to_use.children:
|
||||
@@ -1565,13 +1566,13 @@ class ConsoleFormatter:
|
||||
|
||||
def handle_memory_query_completed(
|
||||
self,
|
||||
agent_branch: Optional[Tree],
|
||||
agent_branch: Tree | None,
|
||||
source_type: str,
|
||||
query_time_ms: float,
|
||||
crew_tree: Optional[Tree],
|
||||
crew_tree: Tree | None,
|
||||
) -> None:
|
||||
if not self.verbose:
|
||||
return None
|
||||
return
|
||||
|
||||
branch_to_use = self.current_lite_agent_branch or agent_branch
|
||||
tree_to_use = branch_to_use or crew_tree
|
||||
@@ -1580,7 +1581,7 @@ class ConsoleFormatter:
|
||||
branch_to_use = tree_to_use
|
||||
|
||||
if branch_to_use is None:
|
||||
return None
|
||||
return
|
||||
|
||||
memory_type = source_type.replace("_", " ").title()
|
||||
|
||||
@@ -1598,13 +1599,13 @@ class ConsoleFormatter:
|
||||
|
||||
def handle_memory_query_failed(
|
||||
self,
|
||||
agent_branch: Optional[Tree],
|
||||
crew_tree: Optional[Tree],
|
||||
agent_branch: Tree | None,
|
||||
crew_tree: Tree | None,
|
||||
error: str,
|
||||
source_type: str,
|
||||
) -> None:
|
||||
if not self.verbose:
|
||||
return None
|
||||
return
|
||||
|
||||
branch_to_use = self.current_lite_agent_branch or agent_branch
|
||||
tree_to_use = branch_to_use or crew_tree
|
||||
@@ -1613,7 +1614,7 @@ class ConsoleFormatter:
|
||||
branch_to_use = tree_to_use
|
||||
|
||||
if branch_to_use is None:
|
||||
return None
|
||||
return
|
||||
|
||||
memory_type = source_type.replace("_", " ").title()
|
||||
|
||||
@@ -1630,16 +1631,16 @@ class ConsoleFormatter:
|
||||
break
|
||||
|
||||
def handle_memory_save_started(
|
||||
self, agent_branch: Optional[Tree], crew_tree: Optional[Tree]
|
||||
self, agent_branch: Tree | None, crew_tree: Tree | None
|
||||
) -> None:
|
||||
if not self.verbose:
|
||||
return None
|
||||
return
|
||||
|
||||
branch_to_use = agent_branch or self.current_lite_agent_branch
|
||||
tree_to_use = branch_to_use or crew_tree
|
||||
|
||||
if tree_to_use is None:
|
||||
return None
|
||||
return
|
||||
|
||||
for child in tree_to_use.children:
|
||||
if "Memory Update" in str(child.label):
|
||||
@@ -1655,19 +1656,19 @@ class ConsoleFormatter:
|
||||
|
||||
def handle_memory_save_completed(
|
||||
self,
|
||||
agent_branch: Optional[Tree],
|
||||
crew_tree: Optional[Tree],
|
||||
agent_branch: Tree | None,
|
||||
crew_tree: Tree | None,
|
||||
save_time_ms: float,
|
||||
source_type: str,
|
||||
) -> None:
|
||||
if not self.verbose:
|
||||
return None
|
||||
return
|
||||
|
||||
branch_to_use = agent_branch or self.current_lite_agent_branch
|
||||
tree_to_use = branch_to_use or crew_tree
|
||||
|
||||
if tree_to_use is None:
|
||||
return None
|
||||
return
|
||||
|
||||
memory_type = source_type.replace("_", " ").title()
|
||||
content = f"✅ {memory_type} Memory Saved ({save_time_ms:.2f}ms)"
|
||||
@@ -1685,19 +1686,19 @@ class ConsoleFormatter:
|
||||
|
||||
def handle_memory_save_failed(
|
||||
self,
|
||||
agent_branch: Optional[Tree],
|
||||
agent_branch: Tree | None,
|
||||
error: str,
|
||||
source_type: str,
|
||||
crew_tree: Optional[Tree],
|
||||
crew_tree: Tree | None,
|
||||
) -> None:
|
||||
if not self.verbose:
|
||||
return None
|
||||
return
|
||||
|
||||
branch_to_use = agent_branch or self.current_lite_agent_branch
|
||||
tree_to_use = branch_to_use or crew_tree
|
||||
|
||||
if branch_to_use is None or tree_to_use is None:
|
||||
return None
|
||||
return
|
||||
|
||||
memory_type = source_type.replace("_", " ").title()
|
||||
content = f"❌ {memory_type} Memory Save Failed"
|
||||
@@ -1738,7 +1739,7 @@ class ConsoleFormatter:
|
||||
def handle_guardrail_completed(
|
||||
self,
|
||||
success: bool,
|
||||
error: Optional[str],
|
||||
error: str | None,
|
||||
retry_count: int,
|
||||
) -> None:
|
||||
"""Display guardrail evaluation result.
|
||||
|
||||
@@ -1,40 +1,39 @@
|
||||
from crewai.experimental.evaluation import (
|
||||
AgentEvaluationResult,
|
||||
AgentEvaluator,
|
||||
BaseEvaluator,
|
||||
EvaluationScore,
|
||||
MetricCategory,
|
||||
AgentEvaluationResult,
|
||||
SemanticQualityEvaluator,
|
||||
GoalAlignmentEvaluator,
|
||||
ReasoningEfficiencyEvaluator,
|
||||
ToolSelectionEvaluator,
|
||||
ParameterExtractionEvaluator,
|
||||
ToolInvocationEvaluator,
|
||||
EvaluationTraceCallback,
|
||||
create_evaluation_callbacks,
|
||||
AgentEvaluator,
|
||||
create_default_evaluator,
|
||||
ExperimentRunner,
|
||||
ExperimentResults,
|
||||
ExperimentResult,
|
||||
ExperimentResults,
|
||||
ExperimentRunner,
|
||||
GoalAlignmentEvaluator,
|
||||
MetricCategory,
|
||||
ParameterExtractionEvaluator,
|
||||
ReasoningEfficiencyEvaluator,
|
||||
SemanticQualityEvaluator,
|
||||
ToolInvocationEvaluator,
|
||||
ToolSelectionEvaluator,
|
||||
create_default_evaluator,
|
||||
create_evaluation_callbacks,
|
||||
)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"AgentEvaluationResult",
|
||||
"AgentEvaluator",
|
||||
"BaseEvaluator",
|
||||
"EvaluationScore",
|
||||
"MetricCategory",
|
||||
"AgentEvaluationResult",
|
||||
"SemanticQualityEvaluator",
|
||||
"GoalAlignmentEvaluator",
|
||||
"ReasoningEfficiencyEvaluator",
|
||||
"ToolSelectionEvaluator",
|
||||
"ParameterExtractionEvaluator",
|
||||
"ToolInvocationEvaluator",
|
||||
"EvaluationTraceCallback",
|
||||
"create_evaluation_callbacks",
|
||||
"AgentEvaluator",
|
||||
"create_default_evaluator",
|
||||
"ExperimentRunner",
|
||||
"ExperimentResult",
|
||||
"ExperimentResults",
|
||||
"ExperimentResult"
|
||||
]
|
||||
"ExperimentRunner",
|
||||
"GoalAlignmentEvaluator",
|
||||
"MetricCategory",
|
||||
"ParameterExtractionEvaluator",
|
||||
"ReasoningEfficiencyEvaluator",
|
||||
"SemanticQualityEvaluator",
|
||||
"ToolInvocationEvaluator",
|
||||
"ToolSelectionEvaluator",
|
||||
"create_default_evaluator",
|
||||
"create_evaluation_callbacks",
|
||||
]
|
||||
|
||||
@@ -1,51 +1,47 @@
|
||||
from crewai.experimental.evaluation.agent_evaluator import (
|
||||
AgentEvaluator,
|
||||
create_default_evaluator,
|
||||
)
|
||||
from crewai.experimental.evaluation.base_evaluator import (
|
||||
AgentEvaluationResult,
|
||||
BaseEvaluator,
|
||||
EvaluationScore,
|
||||
MetricCategory,
|
||||
AgentEvaluationResult
|
||||
)
|
||||
|
||||
from crewai.experimental.evaluation.metrics import (
|
||||
SemanticQualityEvaluator,
|
||||
GoalAlignmentEvaluator,
|
||||
ReasoningEfficiencyEvaluator,
|
||||
ToolSelectionEvaluator,
|
||||
ParameterExtractionEvaluator,
|
||||
ToolInvocationEvaluator
|
||||
)
|
||||
|
||||
from crewai.experimental.evaluation.evaluation_listener import (
|
||||
EvaluationTraceCallback,
|
||||
create_evaluation_callbacks
|
||||
create_evaluation_callbacks,
|
||||
)
|
||||
|
||||
from crewai.experimental.evaluation.agent_evaluator import (
|
||||
AgentEvaluator,
|
||||
create_default_evaluator
|
||||
)
|
||||
|
||||
from crewai.experimental.evaluation.experiment import (
|
||||
ExperimentRunner,
|
||||
ExperimentResult,
|
||||
ExperimentResults,
|
||||
ExperimentResult
|
||||
ExperimentRunner,
|
||||
)
|
||||
from crewai.experimental.evaluation.metrics import (
|
||||
GoalAlignmentEvaluator,
|
||||
ParameterExtractionEvaluator,
|
||||
ReasoningEfficiencyEvaluator,
|
||||
SemanticQualityEvaluator,
|
||||
ToolInvocationEvaluator,
|
||||
ToolSelectionEvaluator,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"AgentEvaluationResult",
|
||||
"AgentEvaluator",
|
||||
"BaseEvaluator",
|
||||
"EvaluationScore",
|
||||
"MetricCategory",
|
||||
"AgentEvaluationResult",
|
||||
"SemanticQualityEvaluator",
|
||||
"GoalAlignmentEvaluator",
|
||||
"ReasoningEfficiencyEvaluator",
|
||||
"ToolSelectionEvaluator",
|
||||
"ParameterExtractionEvaluator",
|
||||
"ToolInvocationEvaluator",
|
||||
"EvaluationTraceCallback",
|
||||
"create_evaluation_callbacks",
|
||||
"AgentEvaluator",
|
||||
"create_default_evaluator",
|
||||
"ExperimentRunner",
|
||||
"ExperimentResult",
|
||||
"ExperimentResults",
|
||||
"ExperimentResult"
|
||||
"ExperimentRunner",
|
||||
"GoalAlignmentEvaluator",
|
||||
"MetricCategory",
|
||||
"ParameterExtractionEvaluator",
|
||||
"ReasoningEfficiencyEvaluator",
|
||||
"SemanticQualityEvaluator",
|
||||
"ToolInvocationEvaluator",
|
||||
"ToolSelectionEvaluator",
|
||||
"create_default_evaluator",
|
||||
"create_evaluation_callbacks",
|
||||
]
|
||||
|
||||
@@ -1,34 +1,32 @@
|
||||
import threading
|
||||
from typing import Any, Optional
|
||||
from collections.abc import Sequence
|
||||
from typing import Any
|
||||
|
||||
from crewai.experimental.evaluation.base_evaluator import (
|
||||
AgentEvaluationResult,
|
||||
AggregationStrategy,
|
||||
)
|
||||
from crewai.agent import Agent
|
||||
from crewai.task import Task
|
||||
from crewai.experimental.evaluation.evaluation_display import EvaluationDisplayFormatter
|
||||
from crewai.events.event_bus import crewai_event_bus
|
||||
from crewai.events.types.agent_events import (
|
||||
AgentEvaluationStartedEvent,
|
||||
AgentEvaluationCompletedEvent,
|
||||
AgentEvaluationFailedEvent,
|
||||
AgentEvaluationStartedEvent,
|
||||
LiteAgentExecutionCompletedEvent,
|
||||
)
|
||||
from crewai.experimental.evaluation import BaseEvaluator, create_evaluation_callbacks
|
||||
from collections.abc import Sequence
|
||||
from crewai.events.event_bus import crewai_event_bus
|
||||
from crewai.events.utils.console_formatter import ConsoleFormatter
|
||||
from crewai.events.types.task_events import TaskCompletedEvent
|
||||
from crewai.events.types.agent_events import LiteAgentExecutionCompletedEvent
|
||||
from crewai.events.utils.console_formatter import ConsoleFormatter
|
||||
from crewai.experimental.evaluation import BaseEvaluator, create_evaluation_callbacks
|
||||
from crewai.experimental.evaluation.base_evaluator import (
|
||||
AgentAggregatedEvaluationResult,
|
||||
AgentEvaluationResult,
|
||||
AggregationStrategy,
|
||||
EvaluationScore,
|
||||
MetricCategory,
|
||||
)
|
||||
from crewai.experimental.evaluation.evaluation_display import EvaluationDisplayFormatter
|
||||
from crewai.task import Task
|
||||
|
||||
|
||||
class ExecutionState:
|
||||
current_agent_id: Optional[str] = None
|
||||
current_task_id: Optional[str] = None
|
||||
current_agent_id: str | None = None
|
||||
current_task_id: str | None = None
|
||||
|
||||
def __init__(self):
|
||||
self.traces = {}
|
||||
@@ -284,7 +282,7 @@ class AgentEvaluator:
|
||||
error=str(e),
|
||||
)
|
||||
self.console_formatter.print(
|
||||
f"Error in {evaluator.metric_category.value} evaluator: {str(e)}"
|
||||
f"Error in {evaluator.metric_category.value} evaluator: {e!s}"
|
||||
)
|
||||
|
||||
return result
|
||||
@@ -340,11 +338,11 @@ class AgentEvaluator:
|
||||
def create_default_evaluator(agents: list[Agent], llm: None = None):
|
||||
from crewai.experimental.evaluation import (
|
||||
GoalAlignmentEvaluator,
|
||||
SemanticQualityEvaluator,
|
||||
ToolSelectionEvaluator,
|
||||
ParameterExtractionEvaluator,
|
||||
ToolInvocationEvaluator,
|
||||
ReasoningEfficiencyEvaluator,
|
||||
SemanticQualityEvaluator,
|
||||
ToolInvocationEvaluator,
|
||||
ToolSelectionEvaluator,
|
||||
)
|
||||
|
||||
evaluators = [
|
||||
|
||||
@@ -1,15 +1,16 @@
|
||||
import abc
|
||||
import enum
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from crewai.agent import Agent
|
||||
from crewai.task import Task
|
||||
from crewai.llm import BaseLLM
|
||||
from crewai.task import Task
|
||||
from crewai.utilities.llm_utils import create_llm
|
||||
|
||||
|
||||
class MetricCategory(enum.Enum):
|
||||
GOAL_ALIGNMENT = "goal_alignment"
|
||||
SEMANTIC_QUALITY = "semantic_quality"
|
||||
@@ -19,7 +20,7 @@ class MetricCategory(enum.Enum):
|
||||
TOOL_INVOCATION = "tool_invocation"
|
||||
|
||||
def title(self):
|
||||
return self.value.replace('_', ' ').title()
|
||||
return self.value.replace("_", " ").title()
|
||||
|
||||
|
||||
class EvaluationScore(BaseModel):
|
||||
@@ -27,15 +28,13 @@ class EvaluationScore(BaseModel):
|
||||
default=5.0,
|
||||
description="Numeric score from 0-10 where 0 is worst and 10 is best, None if not applicable",
|
||||
ge=0.0,
|
||||
le=10.0
|
||||
le=10.0,
|
||||
)
|
||||
feedback: str = Field(
|
||||
default="",
|
||||
description="Detailed feedback explaining the evaluation score"
|
||||
default="", description="Detailed feedback explaining the evaluation score"
|
||||
)
|
||||
raw_response: str | None = Field(
|
||||
default=None,
|
||||
description="Raw response from the evaluator (e.g., LLM)"
|
||||
default=None, description="Raw response from the evaluator (e.g., LLM)"
|
||||
)
|
||||
|
||||
def __str__(self) -> str:
|
||||
@@ -57,7 +56,7 @@ class BaseEvaluator(abc.ABC):
|
||||
def evaluate(
|
||||
self,
|
||||
agent: Agent,
|
||||
execution_trace: Dict[str, Any],
|
||||
execution_trace: dict[str, Any],
|
||||
final_output: Any,
|
||||
task: Task | None = None,
|
||||
) -> EvaluationScore:
|
||||
@@ -67,9 +66,8 @@ class BaseEvaluator(abc.ABC):
|
||||
class AgentEvaluationResult(BaseModel):
|
||||
agent_id: str = Field(description="ID of the evaluated agent")
|
||||
task_id: str = Field(description="ID of the task that was executed")
|
||||
metrics: Dict[MetricCategory, EvaluationScore] = Field(
|
||||
default_factory=dict,
|
||||
description="Evaluation scores for each metric category"
|
||||
metrics: dict[MetricCategory, EvaluationScore] = Field(
|
||||
default_factory=dict, description="Evaluation scores for each metric category"
|
||||
)
|
||||
|
||||
|
||||
@@ -81,33 +79,23 @@ class AggregationStrategy(Enum):
|
||||
|
||||
|
||||
class AgentAggregatedEvaluationResult(BaseModel):
|
||||
agent_id: str = Field(
|
||||
default="",
|
||||
description="ID of the agent"
|
||||
)
|
||||
agent_role: str = Field(
|
||||
default="",
|
||||
description="Role of the agent"
|
||||
)
|
||||
agent_id: str = Field(default="", description="ID of the agent")
|
||||
agent_role: str = Field(default="", description="Role of the agent")
|
||||
task_count: int = Field(
|
||||
default=0,
|
||||
description="Number of tasks included in this aggregation"
|
||||
default=0, description="Number of tasks included in this aggregation"
|
||||
)
|
||||
aggregation_strategy: AggregationStrategy = Field(
|
||||
default=AggregationStrategy.SIMPLE_AVERAGE,
|
||||
description="Strategy used for aggregation"
|
||||
description="Strategy used for aggregation",
|
||||
)
|
||||
metrics: Dict[MetricCategory, EvaluationScore] = Field(
|
||||
default_factory=dict,
|
||||
description="Aggregated metrics across all tasks"
|
||||
metrics: dict[MetricCategory, EvaluationScore] = Field(
|
||||
default_factory=dict, description="Aggregated metrics across all tasks"
|
||||
)
|
||||
task_results: List[str] = Field(
|
||||
default_factory=list,
|
||||
description="IDs of tasks included in this aggregation"
|
||||
task_results: list[str] = Field(
|
||||
default_factory=list, description="IDs of tasks included in this aggregation"
|
||||
)
|
||||
overall_score: Optional[float] = Field(
|
||||
default=None,
|
||||
description="Overall score for this agent"
|
||||
overall_score: float | None = Field(
|
||||
default=None, description="Overall score for this agent"
|
||||
)
|
||||
|
||||
def __str__(self) -> str:
|
||||
@@ -119,7 +107,7 @@ class AgentAggregatedEvaluationResult(BaseModel):
result += f"\n\n- {category.value.upper()}: {score.score}/10\n"

if score.feedback:
detailed_feedback = "\n  ".join(score.feedback.split('\n'))
detailed_feedback = "\n  ".join(score.feedback.split("\n"))
result += f"  {detailed_feedback}\n"

return result
return result
||||
|
||||
@@ -1,16 +1,18 @@
|
||||
from collections import defaultdict
|
||||
from typing import Dict, Any, List
|
||||
from rich.table import Table
|
||||
from rich.box import HEAVY_EDGE, ROUNDED
|
||||
from collections.abc import Sequence
|
||||
from typing import Any
|
||||
|
||||
from rich.box import HEAVY_EDGE, ROUNDED
|
||||
from rich.table import Table
|
||||
|
||||
from crewai.events.utils.console_formatter import ConsoleFormatter
|
||||
from crewai.experimental.evaluation import EvaluationScore
|
||||
from crewai.experimental.evaluation.base_evaluator import (
|
||||
AgentAggregatedEvaluationResult,
|
||||
AggregationStrategy,
|
||||
AgentEvaluationResult,
|
||||
AggregationStrategy,
|
||||
MetricCategory,
|
||||
)
|
||||
from crewai.experimental.evaluation import EvaluationScore
|
||||
from crewai.events.utils.console_formatter import ConsoleFormatter
|
||||
from crewai.utilities.llm_utils import create_llm
|
||||
|
||||
|
||||
@@ -19,7 +21,7 @@ class EvaluationDisplayFormatter:
|
||||
self.console_formatter = ConsoleFormatter()
|
||||
|
||||
def display_evaluation_with_feedback(
|
||||
self, iterations_results: Dict[int, Dict[str, List[Any]]]
|
||||
self, iterations_results: dict[int, dict[str, list[Any]]]
|
||||
):
|
||||
if not iterations_results:
|
||||
self.console_formatter.print(
|
||||
@@ -99,7 +101,7 @@ class EvaluationDisplayFormatter:
|
||||
|
||||
def display_summary_results(
|
||||
self,
|
||||
iterations_results: Dict[int, Dict[str, List[AgentAggregatedEvaluationResult]]],
|
||||
iterations_results: dict[int, dict[str, list[AgentAggregatedEvaluationResult]]],
|
||||
):
|
||||
if not iterations_results:
|
||||
self.console_formatter.print(
|
||||
@@ -304,25 +306,25 @@ class EvaluationDisplayFormatter:
|
||||
self,
|
||||
agent_role: str,
|
||||
metric: str,
|
||||
feedbacks: List[str],
|
||||
scores: List[float | None],
|
||||
feedbacks: list[str],
|
||||
scores: list[float | None],
|
||||
strategy: AggregationStrategy,
|
||||
) -> str:
|
||||
if len(feedbacks) <= 2 and all(len(fb) < 200 for fb in feedbacks):
|
||||
return "\n\n".join(
|
||||
[f"Feedback {i+1}: {fb}" for i, fb in enumerate(feedbacks)]
|
||||
[f"Feedback {i + 1}: {fb}" for i, fb in enumerate(feedbacks)]
|
||||
)
|
||||
|
||||
try:
|
||||
llm = create_llm()
|
||||
|
||||
formatted_feedbacks = []
|
||||
for i, (feedback, score) in enumerate(zip(feedbacks, scores)):
|
||||
for i, (feedback, score) in enumerate(zip(feedbacks, scores, strict=False)):
|
||||
if len(feedback) > 500:
|
||||
feedback = feedback[:500] + "..."
|
||||
score_text = f"{score:.1f}" if score is not None else "N/A"
|
||||
formatted_feedbacks.append(
|
||||
f"Feedback #{i+1} (Score: {score_text}):\n{feedback}"
|
||||
f"Feedback #{i + 1} (Score: {score_text}):\n{feedback}"
|
||||
)
|
||||
|
||||
all_feedbacks = "\n\n" + "\n\n---\n\n".join(formatted_feedbacks)
|
||||
@@ -366,9 +368,7 @@ class EvaluationDisplayFormatter:
|
||||
},
|
||||
]
|
||||
assert llm is not None
|
||||
response = llm.call(prompt)
|
||||
|
||||
return response
|
||||
return llm.call(prompt)
|
||||
|
||||
except Exception:
|
||||
return "Synthesized from multiple tasks: " + "\n\n".join(
|
||||
|
||||
@@ -1,26 +1,25 @@
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from collections.abc import Sequence
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from crewai.agent import Agent
|
||||
from crewai.task import Task
|
||||
from crewai.events.base_event_listener import BaseEventListener
|
||||
from crewai.events.event_bus import CrewAIEventsBus
|
||||
from crewai.events.types.agent_events import (
|
||||
AgentExecutionStartedEvent,
|
||||
AgentExecutionCompletedEvent,
|
||||
LiteAgentExecutionStartedEvent,
|
||||
AgentExecutionStartedEvent,
|
||||
LiteAgentExecutionCompletedEvent,
|
||||
LiteAgentExecutionStartedEvent,
|
||||
)
|
||||
from crewai.events.types.llm_events import LLMCallCompletedEvent, LLMCallStartedEvent
|
||||
from crewai.events.types.tool_usage_events import (
|
||||
ToolUsageFinishedEvent,
|
||||
ToolUsageErrorEvent,
|
||||
ToolExecutionErrorEvent,
|
||||
ToolSelectionErrorEvent,
|
||||
ToolUsageErrorEvent,
|
||||
ToolUsageFinishedEvent,
|
||||
ToolValidateInputErrorEvent,
|
||||
)
|
||||
from crewai.events.types.llm_events import LLMCallStartedEvent, LLMCallCompletedEvent
|
||||
from crewai.task import Task
|
||||
|
||||
|
||||
class EvaluationTraceCallback(BaseEventListener):
|
||||
@@ -253,7 +252,7 @@ class EvaluationTraceCallback(BaseEventListener):
|
||||
if hasattr(self, "current_llm_call"):
|
||||
self.current_llm_call = {}
|
||||
|
||||
def get_trace(self, agent_id: str, task_id: str) -> Optional[Dict[str, Any]]:
|
||||
def get_trace(self, agent_id: str, task_id: str) -> dict[str, Any] | None:
|
||||
trace_key = f"{agent_id}_{task_id}"
|
||||
return self.traces.get(trace_key)
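Traces are stored under a composite agent/task key. A minimal sketch of that lookup convention, with a hypothetical stand-in class instead of the full EvaluationTraceCallback listener:

from typing import Any


class TraceStore:
    # Hypothetical stand-in for the listener's trace bookkeeping.
    def __init__(self) -> None:
        self.traces: dict[str, dict[str, Any]] = {}

    def get_trace(self, agent_id: str, task_id: str) -> dict[str, Any] | None:
        trace_key = f"{agent_id}_{task_id}"
        return self.traces.get(trace_key)


store = TraceStore()
store.traces["agent-1_task-9"] = {"tool_uses": [], "llm_calls": []}
print(store.get_trace("agent-1", "task-9"))
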
|
||||
|
||||
|
||||
@@ -1,8 +1,7 @@
from crewai.experimental.evaluation.experiment.result import (
ExperimentResult,
ExperimentResults,
)
from crewai.experimental.evaluation.experiment.runner import ExperimentRunner
from crewai.experimental.evaluation.experiment.result import ExperimentResults, ExperimentResult

__all__ = [
"ExperimentRunner",
"ExperimentResults",
"ExperimentResult"
]
__all__ = ["ExperimentResult", "ExperimentResults", "ExperimentRunner"]
||||
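For reference, the re-exported names are typically consumed roughly as below; the dataset contents and crew variable are hypothetical, and the keyword arguments follow the runner code later in this diff:

from crewai.experimental.evaluation.experiment import ExperimentRunner

# Hypothetical dataset: each test case carries inputs for crew.kickoff()
# and an expected score (a single float or a per-metric dict).
dataset = [
    {
        "identifier": "pricing-faq",
        "inputs": {"question": "What does the Pro plan cost?"},
        "expected_score": 7.0,
    },
]

runner = ExperimentRunner(dataset=dataset)
# results = runner.run(crew=my_crew, print_summary=True)
# results.to_json("experiment_results.json")
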
|
||||
@@ -2,8 +2,10 @@ import json
|
||||
import os
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class ExperimentResult(BaseModel):
|
||||
identifier: str
|
||||
inputs: dict[str, Any]
|
||||
@@ -12,35 +14,48 @@ class ExperimentResult(BaseModel):
|
||||
passed: bool
|
||||
agent_evaluations: dict[str, Any] | None = None
|
||||
|
||||
|
||||
class ExperimentResults:
|
||||
def __init__(self, results: list[ExperimentResult], metadata: dict[str, Any] | None = None):
|
||||
def __init__(
|
||||
self, results: list[ExperimentResult], metadata: dict[str, Any] | None = None
|
||||
):
|
||||
self.results = results
|
||||
self.metadata = metadata or {}
|
||||
self.timestamp = datetime.now(timezone.utc)
|
||||
|
||||
from crewai.experimental.evaluation.experiment.result_display import ExperimentResultsDisplay
|
||||
from crewai.experimental.evaluation.experiment.result_display import (
|
||||
ExperimentResultsDisplay,
|
||||
)
|
||||
|
||||
self.display = ExperimentResultsDisplay()
|
||||
|
||||
def to_json(self, filepath: str | None = None) -> dict[str, Any]:
|
||||
data = {
|
||||
"timestamp": self.timestamp.isoformat(),
|
||||
"metadata": self.metadata,
|
||||
"results": [r.model_dump(exclude={"agent_evaluations"}) for r in self.results]
|
||||
"results": [
|
||||
r.model_dump(exclude={"agent_evaluations"}) for r in self.results
|
||||
],
|
||||
}
|
||||
|
||||
if filepath:
|
||||
with open(filepath, 'w') as f:
|
||||
with open(filepath, "w") as f:
|
||||
json.dump(data, f, indent=2)
|
||||
self.display.console.print(f"[green]Results saved to {filepath}[/green]")
|
||||
|
||||
return data
|
||||
|
||||
def compare_with_baseline(self, baseline_filepath: str, save_current: bool = True, print_summary: bool = False) -> dict[str, Any]:
|
||||
def compare_with_baseline(
|
||||
self,
|
||||
baseline_filepath: str,
|
||||
save_current: bool = True,
|
||||
print_summary: bool = False,
|
||||
) -> dict[str, Any]:
|
||||
baseline_runs = []
|
||||
|
||||
if os.path.exists(baseline_filepath) and os.path.getsize(baseline_filepath) > 0:
|
||||
try:
|
||||
with open(baseline_filepath, 'r') as f:
|
||||
with open(baseline_filepath, "r") as f:
|
||||
baseline_data = json.load(f)
|
||||
|
||||
if isinstance(baseline_data, dict) and "timestamp" in baseline_data:
|
||||
@@ -48,14 +63,18 @@ class ExperimentResults:
|
||||
elif isinstance(baseline_data, list):
|
||||
baseline_runs = baseline_data
|
||||
except (json.JSONDecodeError, FileNotFoundError) as e:
|
||||
self.display.console.print(f"[yellow]Warning: Could not load baseline file: {str(e)}[/yellow]")
|
||||
self.display.console.print(
|
||||
f"[yellow]Warning: Could not load baseline file: {e!s}[/yellow]"
|
||||
)
|
||||
|
||||
if not baseline_runs:
|
||||
if save_current:
|
||||
current_data = self.to_json()
|
||||
with open(baseline_filepath, 'w') as f:
|
||||
with open(baseline_filepath, "w") as f:
|
||||
json.dump([current_data], f, indent=2)
|
||||
self.display.console.print(f"[green]Saved current results as new baseline to {baseline_filepath}[/green]")
|
||||
self.display.console.print(
|
||||
f"[green]Saved current results as new baseline to {baseline_filepath}[/green]"
|
||||
)
|
||||
return {"is_baseline": True, "changes": {}}
|
||||
|
||||
baseline_runs.sort(key=lambda x: x.get("timestamp", ""), reverse=True)
|
||||
@@ -69,9 +88,11 @@ class ExperimentResults:
|
||||
if save_current:
|
||||
current_data = self.to_json()
|
||||
baseline_runs.append(current_data)
|
||||
with open(baseline_filepath, 'w') as f:
|
||||
with open(baseline_filepath, "w") as f:
|
||||
json.dump(baseline_runs, f, indent=2)
|
||||
self.display.console.print(f"[green]Added current results to baseline file {baseline_filepath}[/green]")
|
||||
self.display.console.print(
|
||||
f"[green]Added current results to baseline file {baseline_filepath}[/green]"
|
||||
)
|
||||
|
||||
return comparison
|
||||
|
||||
@@ -118,5 +139,5 @@ class ExperimentResults:
|
||||
"new_tests": new_tests,
|
||||
"missing_tests": missing_tests,
|
||||
"total_compared": len(improved) + len(regressed) + len(unchanged),
|
||||
"baseline_timestamp": baseline_run.get("timestamp", "unknown")
|
||||
"baseline_timestamp": baseline_run.get("timestamp", "unknown"),
|
||||
}
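A sketch of how compare_with_baseline is meant to be consumed, assuming an ExperimentResults instance from a previous run() and a hypothetical baseline path; the dict keys used here (is_baseline, improved, regressed, baseline_timestamp) are the ones produced above:

from crewai.experimental.evaluation.experiment import ExperimentResults


def check_against_baseline(results: ExperimentResults, baseline_path: str) -> None:
    # `results` would come from ExperimentRunner.run(); baseline_path points at a
    # JSON file written by earlier runs via to_json()/compare_with_baseline().
    comparison = results.compare_with_baseline(baseline_path, save_current=True)
    if comparison.get("is_baseline"):
        print("No usable baseline found; current results stored as the new baseline.")
    else:
        print(
            f"{len(comparison.get('improved', []))} improved, "
            f"{len(comparison.get('regressed', []))} regressed "
            f"(baseline from {comparison.get('baseline_timestamp', 'unknown')})"
        )
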
|
||||
|
||||
@@ -1,9 +1,12 @@
|
||||
from typing import Dict, Any
|
||||
from typing import Any
|
||||
|
||||
from rich.console import Console
|
||||
from rich.table import Table
|
||||
from rich.panel import Panel
|
||||
from rich.table import Table
|
||||
|
||||
from crewai.experimental.evaluation.experiment.result import ExperimentResults
|
||||
|
||||
|
||||
class ExperimentResultsDisplay:
|
||||
def __init__(self):
|
||||
self.console = Console()
|
||||
@@ -19,13 +22,19 @@ class ExperimentResultsDisplay:
|
||||
table.add_row("Total Test Cases", str(total))
|
||||
table.add_row("Passed", str(passed))
|
||||
table.add_row("Failed", str(total - passed))
|
||||
table.add_row("Success Rate", f"{(passed / total * 100):.1f}%" if total > 0 else "N/A")
|
||||
table.add_row(
|
||||
"Success Rate", f"{(passed / total * 100):.1f}%" if total > 0 else "N/A"
|
||||
)
|
||||
|
||||
self.console.print(table)
|
||||
|
||||
def comparison_summary(self, comparison: Dict[str, Any], baseline_timestamp: str):
|
||||
self.console.print(Panel(f"[bold]Comparison with baseline run from {baseline_timestamp}[/bold]",
|
||||
expand=False))
|
||||
def comparison_summary(self, comparison: dict[str, Any], baseline_timestamp: str):
|
||||
self.console.print(
|
||||
Panel(
|
||||
f"[bold]Comparison with baseline run from {baseline_timestamp}[/bold]",
|
||||
expand=False,
|
||||
)
|
||||
)
|
||||
|
||||
table = Table(title="Results Comparison")
|
||||
table.add_column("Metric", style="cyan")
|
||||
@@ -34,7 +43,9 @@ class ExperimentResultsDisplay:
|
||||
|
||||
improved = comparison.get("improved", [])
|
||||
if improved:
|
||||
details = ", ".join([f"{test_identifier}" for test_identifier in improved[:3]])
|
||||
details = ", ".join(
|
||||
[f"{test_identifier}" for test_identifier in improved[:3]]
|
||||
)
|
||||
if len(improved) > 3:
|
||||
details += f" and {len(improved) - 3} more"
|
||||
table.add_row("✅ Improved", str(len(improved)), details)
|
||||
@@ -43,7 +54,9 @@ class ExperimentResultsDisplay:
|
||||
|
||||
regressed = comparison.get("regressed", [])
|
||||
if regressed:
|
||||
details = ", ".join([f"{test_identifier}" for test_identifier in regressed[:3]])
|
||||
details = ", ".join(
|
||||
[f"{test_identifier}" for test_identifier in regressed[:3]]
|
||||
)
|
||||
if len(regressed) > 3:
|
||||
details += f" and {len(regressed) - 3} more"
|
||||
table.add_row("❌ Regressed", str(len(regressed)), details, style="red")
|
||||
|
||||
@@ -2,11 +2,19 @@ from collections import defaultdict
|
||||
from hashlib import md5
|
||||
from typing import Any
|
||||
|
||||
from crewai import Crew, Agent
|
||||
from crewai import Agent, Crew
|
||||
from crewai.experimental.evaluation import AgentEvaluator, create_default_evaluator
|
||||
from crewai.experimental.evaluation.experiment.result_display import ExperimentResultsDisplay
|
||||
from crewai.experimental.evaluation.experiment.result import ExperimentResults, ExperimentResult
|
||||
from crewai.experimental.evaluation.evaluation_display import AgentAggregatedEvaluationResult
|
||||
from crewai.experimental.evaluation.evaluation_display import (
|
||||
AgentAggregatedEvaluationResult,
|
||||
)
|
||||
from crewai.experimental.evaluation.experiment.result import (
|
||||
ExperimentResult,
|
||||
ExperimentResults,
|
||||
)
|
||||
from crewai.experimental.evaluation.experiment.result_display import (
|
||||
ExperimentResultsDisplay,
|
||||
)
|
||||
|
||||
|
||||
class ExperimentRunner:
|
||||
def __init__(self, dataset: list[dict[str, Any]]):
|
||||
@@ -14,7 +22,12 @@ class ExperimentRunner:
|
||||
self.evaluator: AgentEvaluator | None = None
|
||||
self.display = ExperimentResultsDisplay()
|
||||
|
||||
def run(self, crew: Crew | None = None, agents: list[Agent] | None = None, print_summary: bool = False) -> ExperimentResults:
|
||||
def run(
|
||||
self,
|
||||
crew: Crew | None = None,
|
||||
agents: list[Agent] | None = None,
|
||||
print_summary: bool = False,
|
||||
) -> ExperimentResults:
|
||||
if crew and not agents:
|
||||
agents = crew.agents
|
||||
|
||||
@@ -35,13 +48,20 @@ class ExperimentRunner:
|
||||
|
||||
return experiment_results
|
||||
|
||||
def _run_test_case(self, test_case: dict[str, Any], agents: list[Agent], crew: Crew | None = None) -> ExperimentResult:
|
||||
def _run_test_case(
|
||||
self, test_case: dict[str, Any], agents: list[Agent], crew: Crew | None = None
|
||||
) -> ExperimentResult:
|
||||
inputs = test_case["inputs"]
|
||||
expected_score = test_case["expected_score"]
|
||||
identifier = test_case.get("identifier") or md5(str(test_case).encode(), usedforsecurity=False).hexdigest()
|
||||
identifier = (
|
||||
test_case.get("identifier")
|
||||
or md5(str(test_case).encode(), usedforsecurity=False).hexdigest()
|
||||
)
|
||||
|
||||
try:
|
||||
self.display.console.print(f"[dim]Running crew with input: {str(inputs)[:50]}...[/dim]")
|
||||
self.display.console.print(
|
||||
f"[dim]Running crew with input: {str(inputs)[:50]}...[/dim]"
|
||||
)
|
||||
self.display.console.print("\n")
|
||||
if crew:
|
||||
crew.kickoff(inputs=inputs)
|
||||
@@ -61,35 +81,38 @@ class ExperimentRunner:
|
||||
score=actual_score,
|
||||
expected_score=expected_score,
|
||||
passed=passed,
|
||||
agent_evaluations=agent_evaluations
|
||||
agent_evaluations=agent_evaluations,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
self.display.console.print(f"[red]Error running test case: {str(e)}[/red]")
|
||||
self.display.console.print(f"[red]Error running test case: {e!s}[/red]")
|
||||
return ExperimentResult(
|
||||
identifier=identifier,
|
||||
inputs=inputs,
|
||||
score=0,
|
||||
expected_score=expected_score,
|
||||
passed=False
|
||||
passed=False,
|
||||
)
|
||||
|
||||
def _extract_scores(self, agent_evaluations: dict[str, AgentAggregatedEvaluationResult]) -> float | dict[str, float]:
|
||||
def _extract_scores(
|
||||
self, agent_evaluations: dict[str, AgentAggregatedEvaluationResult]
|
||||
) -> float | dict[str, float]:
|
||||
all_scores: dict[str, list[float]] = defaultdict(list)
|
||||
for evaluation in agent_evaluations.values():
|
||||
for metric_name, score in evaluation.metrics.items():
|
||||
if score.score is not None:
|
||||
all_scores[metric_name.value].append(score.score)
|
||||
|
||||
avg_scores = {m: sum(s)/len(s) for m, s in all_scores.items()}
|
||||
avg_scores = {m: sum(s) / len(s) for m, s in all_scores.items()}
|
||||
|
||||
if len(avg_scores) == 1:
|
||||
return list(avg_scores.values())[0]
|
||||
return next(iter(avg_scores.values()))
|
||||
|
||||
return avg_scores
|
||||
|
||||
def _assert_scores(self, expected: float | dict[str, float],
|
||||
actual: float | dict[str, float]) -> bool:
|
||||
def _assert_scores(
|
||||
self, expected: float | dict[str, float], actual: float | dict[str, float]
|
||||
) -> bool:
|
||||
"""
|
||||
Compare expected and actual scores, and return whether the test case passed.
|
||||
|
||||
@@ -122,4 +145,4 @@ class ExperimentRunner:
|
||||
# All matching keys must have actual >= expected
|
||||
return all(actual[key] >= expected[key] for key in matching_keys)
|
||||
|
||||
return False
|
||||
return False
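The pass/fail rule in _assert_scores can be read as a small standalone helper. This sketch reproduces the semantics visible in the hunk above (floats compare directly, per-metric dicts must meet or beat every expected key); the handling of dicts with no overlapping keys is an assumption here:

def scores_pass(expected: float | dict[str, float], actual: float | dict[str, float]) -> bool:
    # Single expected number vs. single actual number.
    if isinstance(expected, (int, float)) and isinstance(actual, (int, float)):
        return actual >= expected
    # Per-metric dicts: every expected metric that was actually measured
    # must reach at least the expected value.
    if isinstance(expected, dict) and isinstance(actual, dict):
        matching_keys = set(expected) & set(actual)
        if not matching_keys:
            return False  # assumption: no overlap counts as a failure
        return all(actual[key] >= expected[key] for key in matching_keys)
    # Mixed types are treated as a failure.
    return False


assert scores_pass(7.0, 8.2) is True
assert scores_pass({"goal_alignment": 7.0}, {"goal_alignment": 6.5}) is False
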
|
||||
|
||||
@@ -1,26 +1,21 @@
|
||||
from crewai.experimental.evaluation.metrics.goal_metrics import GoalAlignmentEvaluator
|
||||
from crewai.experimental.evaluation.metrics.reasoning_metrics import (
|
||||
ReasoningEfficiencyEvaluator
|
||||
ReasoningEfficiencyEvaluator,
|
||||
)
|
||||
|
||||
from crewai.experimental.evaluation.metrics.tools_metrics import (
|
||||
ToolSelectionEvaluator,
|
||||
ParameterExtractionEvaluator,
|
||||
ToolInvocationEvaluator
|
||||
)
|
||||
|
||||
from crewai.experimental.evaluation.metrics.goal_metrics import (
|
||||
GoalAlignmentEvaluator
|
||||
)
|
||||
|
||||
from crewai.experimental.evaluation.metrics.semantic_quality_metrics import (
|
||||
SemanticQualityEvaluator
|
||||
SemanticQualityEvaluator,
|
||||
)
|
||||
from crewai.experimental.evaluation.metrics.tools_metrics import (
|
||||
ParameterExtractionEvaluator,
|
||||
ToolInvocationEvaluator,
|
||||
ToolSelectionEvaluator,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"ReasoningEfficiencyEvaluator",
|
||||
"ToolSelectionEvaluator",
|
||||
"ParameterExtractionEvaluator",
|
||||
"ToolInvocationEvaluator",
|
||||
"GoalAlignmentEvaluator",
|
||||
"SemanticQualityEvaluator"
|
||||
]
|
||||
"ParameterExtractionEvaluator",
|
||||
"ReasoningEfficiencyEvaluator",
|
||||
"SemanticQualityEvaluator",
|
||||
"ToolInvocationEvaluator",
|
||||
"ToolSelectionEvaluator",
|
||||
]
|
||||
|
||||
@@ -1,10 +1,14 @@
|
||||
from typing import Any, Dict
|
||||
from typing import Any
|
||||
|
||||
from crewai.agent import Agent
|
||||
from crewai.experimental.evaluation.base_evaluator import (
|
||||
BaseEvaluator,
|
||||
EvaluationScore,
|
||||
MetricCategory,
|
||||
)
|
||||
from crewai.experimental.evaluation.json_parser import extract_json_from_llm_response
|
||||
from crewai.task import Task
|
||||
|
||||
from crewai.experimental.evaluation.base_evaluator import BaseEvaluator, EvaluationScore, MetricCategory
|
||||
from crewai.experimental.evaluation.json_parser import extract_json_from_llm_response
|
||||
|
||||
class GoalAlignmentEvaluator(BaseEvaluator):
|
||||
@property
|
||||
@@ -14,7 +18,7 @@ class GoalAlignmentEvaluator(BaseEvaluator):
|
||||
def evaluate(
|
||||
self,
|
||||
agent: Agent,
|
||||
execution_trace: Dict[str, Any],
|
||||
execution_trace: dict[str, Any],
|
||||
final_output: Any,
|
||||
task: Task | None = None,
|
||||
) -> EvaluationScore:
|
||||
@@ -23,7 +27,9 @@ class GoalAlignmentEvaluator(BaseEvaluator):
|
||||
task_context = f"Task description: {task.description}\nExpected output: {task.expected_output}\n"
|
||||
|
||||
prompt = [
|
||||
{"role": "system", "content": """You are an expert evaluator assessing how well an AI agent's output aligns with its assigned task goal.
|
||||
{
|
||||
"role": "system",
|
||||
"content": """You are an expert evaluator assessing how well an AI agent's output aligns with its assigned task goal.
|
||||
|
||||
Score the agent's goal alignment on a scale from 0-10 where:
|
||||
- 0: Complete misalignment, agent did not understand or attempt the task goal
|
||||
@@ -37,8 +43,11 @@ Consider:
|
||||
4. Did the agent provide all requested information or deliverables?
|
||||
|
||||
Return your evaluation as JSON with fields 'score' (number) and 'feedback' (string).
|
||||
"""},
|
||||
{"role": "user", "content": f"""
|
||||
""",
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": f"""
|
||||
Agent role: {agent.role}
|
||||
Agent goal: {agent.goal}
|
||||
{task_context}
|
||||
@@ -47,7 +56,8 @@ Agent's final output:
|
||||
{final_output}
|
||||
|
||||
Evaluate how well the agent's output aligns with the assigned task goal.
|
||||
"""}
|
||||
""",
|
||||
},
|
||||
]
|
||||
assert self.llm is not None
|
||||
response = self.llm.call(prompt)
|
||||
@@ -59,11 +69,11 @@ Evaluate how well the agent's output aligns with the assigned task goal.
|
||||
return EvaluationScore(
|
||||
score=evaluation_data.get("score", 0),
|
||||
feedback=evaluation_data.get("feedback", response),
|
||||
raw_response=response
|
||||
raw_response=response,
|
||||
)
|
||||
except Exception:
|
||||
return EvaluationScore(
|
||||
score=None,
|
||||
feedback=f"Failed to parse evaluation. Raw response: {response}",
|
||||
raw_response=response
|
||||
raw_response=response,
|
||||
)
|
||||
|
||||
@@ -8,18 +8,23 @@ This module provides evaluator implementations for:
|
||||
|
||||
import logging
|
||||
import re
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, Tuple
|
||||
import numpy as np
|
||||
from collections.abc import Sequence
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
|
||||
from crewai.agent import Agent
|
||||
from crewai.task import Task
|
||||
|
||||
from crewai.experimental.evaluation.base_evaluator import BaseEvaluator, EvaluationScore, MetricCategory
|
||||
from crewai.experimental.evaluation.base_evaluator import (
|
||||
BaseEvaluator,
|
||||
EvaluationScore,
|
||||
MetricCategory,
|
||||
)
|
||||
from crewai.experimental.evaluation.json_parser import extract_json_from_llm_response
|
||||
from crewai.task import Task
|
||||
from crewai.tasks.task_output import TaskOutput
|
||||
|
||||
|
||||
class ReasoningPatternType(Enum):
|
||||
EFFICIENT = "efficient" # Good reasoning flow
|
||||
LOOP = "loop" # Agent is stuck in a loop
|
||||
@@ -36,7 +41,7 @@ class ReasoningEfficiencyEvaluator(BaseEvaluator):
|
||||
def evaluate(
|
||||
self,
|
||||
agent: Agent,
|
||||
execution_trace: Dict[str, Any],
|
||||
execution_trace: dict[str, Any],
|
||||
final_output: TaskOutput | str,
|
||||
task: Task | None = None,
|
||||
) -> EvaluationScore:
|
||||
@@ -49,7 +54,7 @@ class ReasoningEfficiencyEvaluator(BaseEvaluator):
|
||||
if not llm_calls or len(llm_calls) < 2:
|
||||
return EvaluationScore(
|
||||
score=None,
|
||||
feedback="Insufficient LLM calls to evaluate reasoning efficiency."
|
||||
feedback="Insufficient LLM calls to evaluate reasoning efficiency.",
|
||||
)
|
||||
|
||||
total_calls = len(llm_calls)
|
||||
@@ -58,12 +63,16 @@ class ReasoningEfficiencyEvaluator(BaseEvaluator):
|
||||
time_intervals = []
|
||||
has_reliable_timing = True
|
||||
for i in range(1, len(llm_calls)):
|
||||
start_time = llm_calls[i-1].get("end_time")
|
||||
start_time = llm_calls[i - 1].get("end_time")
|
||||
end_time = llm_calls[i].get("start_time")
|
||||
if start_time and end_time and start_time != end_time:
|
||||
try:
|
||||
interval = end_time - start_time
|
||||
time_intervals.append(interval.total_seconds() if hasattr(interval, 'total_seconds') else 0)
|
||||
time_intervals.append(
|
||||
interval.total_seconds()
|
||||
if hasattr(interval, "total_seconds")
|
||||
else 0
|
||||
)
|
||||
except Exception:
|
||||
has_reliable_timing = False
|
||||
else:
|
||||
@@ -83,14 +92,22 @@ class ReasoningEfficiencyEvaluator(BaseEvaluator):
|
||||
if has_reliable_timing and time_intervals:
|
||||
efficiency_metrics["avg_time_between_calls"] = np.mean(time_intervals)
|
||||
|
||||
loop_info = f"Detected {len(loop_details)} potential reasoning loops." if loop_detected else "No significant reasoning loops detected."
|
||||
loop_info = (
|
||||
f"Detected {len(loop_details)} potential reasoning loops."
|
||||
if loop_detected
|
||||
else "No significant reasoning loops detected."
|
||||
)
|
||||
|
||||
call_samples = self._get_call_samples(llm_calls)
|
||||
|
||||
final_output = final_output.raw if isinstance(final_output, TaskOutput) else final_output
|
||||
final_output = (
|
||||
final_output.raw if isinstance(final_output, TaskOutput) else final_output
|
||||
)
|
||||
|
||||
prompt = [
|
||||
{"role": "system", "content": """You are an expert evaluator assessing the reasoning efficiency of an AI agent's thought process.
|
||||
{
|
||||
"role": "system",
|
||||
"content": """You are an expert evaluator assessing the reasoning efficiency of an AI agent's thought process.
|
||||
|
||||
Evaluate the agent's reasoning efficiency across these five key subcategories:
|
||||
|
||||
@@ -120,8 +137,11 @@ Return your evaluation as JSON with the following structure:
|
||||
"feedback": string (general feedback about overall reasoning efficiency),
|
||||
"optimization_suggestions": string (concrete suggestions for improving reasoning efficiency),
|
||||
"detected_patterns": string (describe any inefficient reasoning patterns you observe)
|
||||
}"""},
|
||||
{"role": "user", "content": f"""
|
||||
}""",
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": f"""
|
||||
Agent role: {agent.role}
|
||||
{task_context}
|
||||
|
||||
@@ -140,7 +160,8 @@ Agent's final output:
|
||||
|
||||
Evaluate the reasoning efficiency of this agent based on these interaction patterns.
|
||||
Identify any inefficient reasoning patterns and provide specific suggestions for optimization.
|
||||
"""}
|
||||
""",
|
||||
},
|
||||
]
|
||||
|
||||
assert self.llm is not None
|
||||
@@ -156,34 +177,46 @@ Identify any inefficient reasoning patterns and provide specific suggestions for
|
||||
conciseness = scores.get("conciseness", 5.0)
|
||||
loop_avoidance = scores.get("loop_avoidance", 5.0)
|
||||
|
||||
overall_score = evaluation_data.get("overall_score", evaluation_data.get("score", 5.0))
|
||||
overall_score = evaluation_data.get(
|
||||
"overall_score", evaluation_data.get("score", 5.0)
|
||||
)
|
||||
feedback = evaluation_data.get("feedback", "No detailed feedback provided.")
|
||||
optimization_suggestions = evaluation_data.get("optimization_suggestions", "No specific suggestions provided.")
|
||||
optimization_suggestions = evaluation_data.get(
|
||||
"optimization_suggestions", "No specific suggestions provided."
|
||||
)
|
||||
|
||||
detailed_feedback = "Reasoning Efficiency Evaluation:\n"
|
||||
detailed_feedback += f"• Focus: {focus}/10 - Staying on topic without tangents\n"
|
||||
detailed_feedback += f"• Progression: {progression}/10 - Building on previous thinking\n"
|
||||
detailed_feedback += (
|
||||
f"• Focus: {focus}/10 - Staying on topic without tangents\n"
|
||||
)
|
||||
detailed_feedback += (
|
||||
f"• Progression: {progression}/10 - Building on previous thinking\n"
|
||||
)
|
||||
detailed_feedback += f"• Decision Quality: {decision_quality}/10 - Making appropriate decisions\n"
|
||||
detailed_feedback += f"• Conciseness: {conciseness}/10 - Communicating efficiently\n"
|
||||
detailed_feedback += (
|
||||
f"• Conciseness: {conciseness}/10 - Communicating efficiently\n"
|
||||
)
|
||||
detailed_feedback += f"• Loop Avoidance: {loop_avoidance}/10 - Avoiding repetitive patterns\n\n"
|
||||
|
||||
detailed_feedback += f"Feedback:\n{feedback}\n\n"
|
||||
detailed_feedback += f"Optimization Suggestions:\n{optimization_suggestions}"
|
||||
detailed_feedback += (
|
||||
f"Optimization Suggestions:\n{optimization_suggestions}"
|
||||
)
|
||||
|
||||
return EvaluationScore(
|
||||
score=float(overall_score),
|
||||
feedback=detailed_feedback,
|
||||
raw_response=response
|
||||
raw_response=response,
|
||||
)
|
||||
except Exception as e:
|
||||
logging.warning(f"Failed to parse reasoning efficiency evaluation: {e}")
|
||||
return EvaluationScore(
|
||||
score=None,
|
||||
feedback=f"Failed to parse reasoning efficiency evaluation. Raw response: {response[:200]}...",
|
||||
raw_response=response
|
||||
raw_response=response,
|
||||
)
|
||||
|
||||
def _detect_loops(self, llm_calls: List[Dict]) -> Tuple[bool, List[Dict]]:
|
||||
def _detect_loops(self, llm_calls: list[dict]) -> tuple[bool, list[dict]]:
|
||||
loop_details = []
|
||||
|
||||
messages = []
|
||||
@@ -205,18 +238,20 @@ Identify any inefficient reasoning patterns and provide specific suggestions for
|
||||
# A more sophisticated approach would use semantic similarity
|
||||
similarity = self._calculate_text_similarity(messages[i], messages[j])
|
||||
if similarity > 0.7: # Arbitrary threshold
|
||||
loop_details.append({
|
||||
"first_occurrence": i,
|
||||
"second_occurrence": j,
|
||||
"similarity": similarity,
|
||||
"snippet": messages[i][:100] + "..."
|
||||
})
|
||||
loop_details.append(
|
||||
{
|
||||
"first_occurrence": i,
|
||||
"second_occurrence": j,
|
||||
"similarity": similarity,
|
||||
"snippet": messages[i][:100] + "...",
|
||||
}
|
||||
)
|
||||
|
||||
return len(loop_details) > 0, loop_details
|
||||
|
||||
def _calculate_text_similarity(self, text1: str, text2: str) -> float:
|
||||
text1 = re.sub(r'\s+', ' ', text1.lower()).strip()
|
||||
text2 = re.sub(r'\s+', ' ', text2.lower()).strip()
|
||||
text1 = re.sub(r"\s+", " ", text1.lower()).strip()
|
||||
text2 = re.sub(r"\s+", " ", text2.lower()).strip()
|
||||
|
||||
# Simple Jaccard similarity on word sets
|
||||
words1 = set(text1.split())
|
||||
@@ -227,7 +262,7 @@ Identify any inefficient reasoning patterns and provide specific suggestions for
|
||||
|
||||
return intersection / union if union > 0 else 0.0
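The loop detector above compares pairs of messages with a whitespace-normalized Jaccard word overlap and flags pairs above the 0.7 threshold. A self-contained sketch of that similarity measure:

import re


def jaccard_similarity(text1: str, text2: str) -> float:
    # Normalize case and collapse whitespace before comparing word sets.
    text1 = re.sub(r"\s+", " ", text1.lower()).strip()
    text2 = re.sub(r"\s+", " ", text2.lower()).strip()
    words1, words2 = set(text1.split()), set(text2.split())
    intersection = len(words1 & words2)
    union = len(words1 | words2)
    return intersection / union if union > 0 else 0.0


# Two near-identical "thoughts" score well above the 0.7 loop threshold.
print(jaccard_similarity("I should search the web again", "I should search the web again now"))
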
|
||||
|
||||
def _analyze_reasoning_patterns(self, llm_calls: List[Dict]) -> Dict[str, Any]:
|
||||
def _analyze_reasoning_patterns(self, llm_calls: list[dict]) -> dict[str, Any]:
|
||||
call_lengths = []
|
||||
response_times = []
|
||||
|
||||
@@ -267,7 +302,9 @@ Identify any inefficient reasoning patterns and provide specific suggestions for
|
||||
details = "Agent is consistently verbose across interactions."
|
||||
elif len(llm_calls) > 10 and length_trend > 0.5:
|
||||
primary_pattern = ReasoningPatternType.INDECISIVE
|
||||
details = "Agent shows signs of indecisiveness with increasing message lengths."
|
||||
details = (
|
||||
"Agent shows signs of indecisiveness with increasing message lengths."
|
||||
)
|
||||
elif std_length / avg_length > 0.8:
|
||||
primary_pattern = ReasoningPatternType.SCATTERED
|
||||
details = "Agent shows inconsistent reasoning flow with highly variable responses."
|
||||
@@ -279,8 +316,8 @@ Identify any inefficient reasoning patterns and provide specific suggestions for
|
||||
"avg_length": avg_length,
|
||||
"std_length": std_length,
|
||||
"length_trend": length_trend,
|
||||
"loop_score": loop_score
|
||||
}
|
||||
"loop_score": loop_score,
|
||||
},
|
||||
}
|
||||
|
||||
def _calculate_trend(self, values: Sequence[float | int]) -> float:
|
||||
@@ -303,7 +340,9 @@ Identify any inefficient reasoning patterns and provide specific suggestions for
|
||||
except Exception:
|
||||
return 0.0
|
||||
|
||||
def _calculate_loop_likelihood(self, call_lengths: Sequence[float], response_times: Sequence[float]) -> float:
|
||||
def _calculate_loop_likelihood(
|
||||
self, call_lengths: Sequence[float], response_times: Sequence[float]
|
||||
) -> float:
|
||||
if not call_lengths or len(call_lengths) < 3:
|
||||
return 0.0
|
||||
|
||||
@@ -312,7 +351,11 @@ Identify any inefficient reasoning patterns and provide specific suggestions for
|
||||
if len(call_lengths) >= 4:
|
||||
repeated_lengths = 0
|
||||
for i in range(len(call_lengths) - 2):
|
||||
ratio = call_lengths[i] / call_lengths[i + 2] if call_lengths[i + 2] > 0 else 0
|
||||
ratio = (
|
||||
call_lengths[i] / call_lengths[i + 2]
|
||||
if call_lengths[i + 2] > 0
|
||||
else 0
|
||||
)
|
||||
if 0.85 <= ratio <= 1.15:
|
||||
repeated_lengths += 1
|
||||
|
||||
@@ -331,14 +374,20 @@ Identify any inefficient reasoning patterns and provide specific suggestions for
|
||||
|
||||
return np.mean(indicators) if indicators else 0.0
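The loop-likelihood score averages several repetition indicators; one of them, visible above, checks whether call lengths roughly reappear two calls later. A hedged sketch of that single indicator (how the count is normalized into a 0-1 value is an assumption here):

def repeated_length_indicator(call_lengths: list[float]) -> float:
    # Fraction of positions whose length recurs two calls later with a ratio
    # in the 0.85-1.15 band, as in the hunk above; needs at least four calls.
    if len(call_lengths) < 4:
        return 0.0
    repeated_lengths = 0
    for i in range(len(call_lengths) - 2):
        ratio = call_lengths[i] / call_lengths[i + 2] if call_lengths[i + 2] > 0 else 0
        if 0.85 <= ratio <= 1.15:
            repeated_lengths += 1
    return repeated_lengths / (len(call_lengths) - 2)  # normalization is illustrative


print(repeated_length_indicator([900, 420, 910, 430, 905, 415]))  # close to 1.0
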
|
||||
|
||||
def _get_call_samples(self, llm_calls: List[Dict]) -> str:
|
||||
def _get_call_samples(self, llm_calls: list[dict]) -> str:
|
||||
samples = []
|
||||
|
||||
if len(llm_calls) <= 6:
|
||||
sample_indices = list(range(len(llm_calls)))
|
||||
else:
|
||||
sample_indices = [0, 1, len(llm_calls) // 2 - 1, len(llm_calls) // 2,
|
||||
len(llm_calls) - 2, len(llm_calls) - 1]
|
||||
sample_indices = [
|
||||
0,
|
||||
1,
|
||||
len(llm_calls) // 2 - 1,
|
||||
len(llm_calls) // 2,
|
||||
len(llm_calls) - 2,
|
||||
len(llm_calls) - 1,
|
||||
]
|
||||
|
||||
for idx in sample_indices:
|
||||
call = llm_calls[idx]
|
||||
|
||||
@@ -1,10 +1,14 @@
|
||||
from typing import Any, Dict
|
||||
from typing import Any
|
||||
|
||||
from crewai.agent import Agent
|
||||
from crewai.experimental.evaluation.base_evaluator import (
|
||||
BaseEvaluator,
|
||||
EvaluationScore,
|
||||
MetricCategory,
|
||||
)
|
||||
from crewai.experimental.evaluation.json_parser import extract_json_from_llm_response
|
||||
from crewai.task import Task
|
||||
|
||||
from crewai.experimental.evaluation.base_evaluator import BaseEvaluator, EvaluationScore, MetricCategory
|
||||
from crewai.experimental.evaluation.json_parser import extract_json_from_llm_response
|
||||
|
||||
class SemanticQualityEvaluator(BaseEvaluator):
|
||||
@property
|
||||
@@ -14,7 +18,7 @@ class SemanticQualityEvaluator(BaseEvaluator):
|
||||
def evaluate(
|
||||
self,
|
||||
agent: Agent,
|
||||
execution_trace: Dict[str, Any],
|
||||
execution_trace: dict[str, Any],
|
||||
final_output: Any,
|
||||
task: Task | None = None,
|
||||
) -> EvaluationScore:
|
||||
@@ -22,7 +26,9 @@ class SemanticQualityEvaluator(BaseEvaluator):
|
||||
if task is not None:
|
||||
task_context = f"Task description: {task.description}"
|
||||
prompt = [
|
||||
{"role": "system", "content": """You are an expert evaluator assessing the semantic quality of an AI agent's output.
|
||||
{
|
||||
"role": "system",
|
||||
"content": """You are an expert evaluator assessing the semantic quality of an AI agent's output.
|
||||
|
||||
Score the semantic quality on a scale from 0-10 where:
|
||||
- 0: Completely incoherent, confusing, or logically flawed output
|
||||
@@ -37,8 +43,11 @@ Consider:
|
||||
5. Is the output free from contradictions and logical fallacies?
|
||||
|
||||
Return your evaluation as JSON with fields 'score' (number) and 'feedback' (string).
|
||||
"""},
|
||||
{"role": "user", "content": f"""
|
||||
""",
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": f"""
|
||||
Agent role: {agent.role}
|
||||
{task_context}
|
||||
|
||||
@@ -46,7 +55,8 @@ Agent's final output:
|
||||
{final_output}
|
||||
|
||||
Evaluate the semantic quality and reasoning of this output.
|
||||
"""}
|
||||
""",
|
||||
},
|
||||
]
|
||||
|
||||
assert self.llm is not None
|
||||
@@ -56,13 +66,15 @@ Evaluate the semantic quality and reasoning of this output.
|
||||
evaluation_data: dict[str, Any] = extract_json_from_llm_response(response)
|
||||
assert evaluation_data is not None
|
||||
return EvaluationScore(
|
||||
score=float(evaluation_data["score"]) if evaluation_data.get("score") is not None else None,
|
||||
score=float(evaluation_data["score"])
|
||||
if evaluation_data.get("score") is not None
|
||||
else None,
|
||||
feedback=evaluation_data.get("feedback", response),
|
||||
raw_response=response
|
||||
raw_response=response,
|
||||
)
|
||||
except Exception:
|
||||
return EvaluationScore(
|
||||
score=None,
|
||||
feedback=f"Failed to parse evaluation. Raw response: {response}",
|
||||
raw_response=response
|
||||
)
|
||||
raw_response=response,
|
||||
)
|
||||
|
||||
@@ -1,14 +1,17 @@
|
||||
import json
|
||||
from typing import Dict, Any
|
||||
from typing import Any
|
||||
|
||||
from crewai.experimental.evaluation.base_evaluator import BaseEvaluator, EvaluationScore, MetricCategory
|
||||
from crewai.experimental.evaluation.json_parser import extract_json_from_llm_response
|
||||
from crewai.agent import Agent
|
||||
from crewai.experimental.evaluation.base_evaluator import (
|
||||
BaseEvaluator,
|
||||
EvaluationScore,
|
||||
MetricCategory,
|
||||
)
|
||||
from crewai.experimental.evaluation.json_parser import extract_json_from_llm_response
|
||||
from crewai.task import Task
|
||||
|
||||
|
||||
class ToolSelectionEvaluator(BaseEvaluator):
|
||||
|
||||
@property
|
||||
def metric_category(self) -> MetricCategory:
|
||||
return MetricCategory.TOOL_SELECTION
|
||||
@@ -16,7 +19,7 @@ class ToolSelectionEvaluator(BaseEvaluator):
|
||||
def evaluate(
|
||||
self,
|
||||
agent: Agent,
|
||||
execution_trace: Dict[str, Any],
|
||||
execution_trace: dict[str, Any],
|
||||
final_output: str,
|
||||
task: Task | None = None,
|
||||
) -> EvaluationScore:
|
||||
@@ -26,19 +29,18 @@ class ToolSelectionEvaluator(BaseEvaluator):
|
||||
|
||||
tool_uses = execution_trace.get("tool_uses", [])
|
||||
tool_count = len(tool_uses)
|
||||
unique_tool_types = set([tool.get("tool", "Unknown tool") for tool in tool_uses])
|
||||
unique_tool_types = set(
|
||||
[tool.get("tool", "Unknown tool") for tool in tool_uses]
|
||||
)
|
||||
|
||||
if tool_count == 0:
|
||||
if not agent.tools:
|
||||
return EvaluationScore(
|
||||
score=None,
|
||||
feedback="Agent had no tools available to use."
|
||||
)
|
||||
else:
|
||||
return EvaluationScore(
|
||||
score=None,
|
||||
feedback="Agent had tools available but didn't use any."
|
||||
score=None, feedback="Agent had no tools available to use."
|
||||
)
|
||||
return EvaluationScore(
|
||||
score=None, feedback="Agent had tools available but didn't use any."
|
||||
)
|
||||
|
||||
available_tools_info = ""
|
||||
if agent.tools:
|
||||
@@ -52,7 +54,9 @@ class ToolSelectionEvaluator(BaseEvaluator):
|
||||
tool_types_summary += f"- {tool_type}\n"
|
||||
|
||||
prompt = [
|
||||
{"role": "system", "content": """You are an expert evaluator assessing if an AI agent selected the most appropriate tools for a given task.
|
||||
{
|
||||
"role": "system",
|
||||
"content": """You are an expert evaluator assessing if an AI agent selected the most appropriate tools for a given task.
|
||||
|
||||
You must evaluate based on these 2 criteria:
|
||||
1. Relevance (0-10): Were the tools chosen directly aligned with the task's goals?
|
||||
@@ -73,8 +77,11 @@ Return your evaluation as JSON with these fields:
|
||||
- overall_score: number (average of all scores, 0-10)
|
||||
- feedback: string (focused ONLY on tool selection decisions from available tools)
|
||||
- improvement_suggestions: string (ONLY suggest better selection from the AVAILABLE tools list, NOT new tools)
|
||||
"""},
|
||||
{"role": "user", "content": f"""
|
||||
""",
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": f"""
|
||||
Agent role: {agent.role}
|
||||
{task_context}
|
||||
|
||||
@@ -89,7 +96,8 @@ IMPORTANT:
|
||||
- ONLY evaluate selection from tools listed as available
|
||||
- DO NOT suggest new tools that aren't in the available tools list
|
||||
- DO NOT evaluate tool usage or results
|
||||
"""}
|
||||
""",
|
||||
},
|
||||
]
|
||||
assert self.llm is not None
|
||||
response = self.llm.call(prompt)
|
||||
@@ -105,22 +113,24 @@ IMPORTANT:
|
||||
|
||||
feedback = "Tool Selection Evaluation:\n"
|
||||
feedback += f"• Relevance: {relevance}/10 - Selection of appropriate tool types for the task\n"
|
||||
feedback += f"• Coverage: {coverage}/10 - Selection of all necessary tool types\n"
|
||||
feedback += (
|
||||
f"• Coverage: {coverage}/10 - Selection of all necessary tool types\n"
|
||||
)
|
||||
if "improvement_suggestions" in evaluation_data:
|
||||
feedback += f"Improvement Suggestions:\n{evaluation_data['improvement_suggestions']}"
|
||||
else:
|
||||
feedback += evaluation_data.get("feedback", "No detailed feedback available.")
|
||||
feedback += evaluation_data.get(
|
||||
"feedback", "No detailed feedback available."
|
||||
)
|
||||
|
||||
return EvaluationScore(
|
||||
score=overall_score,
|
||||
feedback=feedback,
|
||||
raw_response=response
|
||||
score=overall_score, feedback=feedback, raw_response=response
|
||||
)
|
||||
except Exception as e:
|
||||
return EvaluationScore(
|
||||
score=None,
|
||||
feedback=f"Error evaluating tool selection: {e}",
|
||||
raw_response=response
|
||||
raw_response=response,
|
||||
)
|
||||
|
||||
|
||||
@@ -132,7 +142,7 @@ class ParameterExtractionEvaluator(BaseEvaluator):
|
||||
def evaluate(
|
||||
self,
|
||||
agent: Agent,
|
||||
execution_trace: Dict[str, Any],
|
||||
execution_trace: dict[str, Any],
|
||||
final_output: str,
|
||||
task: Task | None = None,
|
||||
) -> EvaluationScore:
|
||||
@@ -145,19 +155,26 @@ class ParameterExtractionEvaluator(BaseEvaluator):
|
||||
if tool_count == 0:
|
||||
return EvaluationScore(
|
||||
score=None,
|
||||
feedback="No tool usage detected. Cannot evaluate parameter extraction."
|
||||
feedback="No tool usage detected. Cannot evaluate parameter extraction.",
|
||||
)
|
||||
|
||||
validation_errors = []
|
||||
for tool_use in tool_uses:
|
||||
if not tool_use.get("success", True) and tool_use.get("error_type") == "validation_error":
|
||||
validation_errors.append({
|
||||
"tool": tool_use.get("tool", "Unknown tool"),
|
||||
"error": tool_use.get("result"),
|
||||
"args": tool_use.get("args", {})
|
||||
})
|
||||
if (
|
||||
not tool_use.get("success", True)
|
||||
and tool_use.get("error_type") == "validation_error"
|
||||
):
|
||||
validation_errors.append(
|
||||
{
|
||||
"tool": tool_use.get("tool", "Unknown tool"),
|
||||
"error": tool_use.get("result"),
|
||||
"args": tool_use.get("args", {}),
|
||||
}
|
||||
)
|
||||
|
||||
validation_error_rate = len(validation_errors) / tool_count if tool_count > 0 else 0
|
||||
validation_error_rate = (
|
||||
len(validation_errors) / tool_count if tool_count > 0 else 0
|
||||
)
|
||||
|
||||
param_samples = []
|
||||
for i, tool_use in enumerate(tool_uses[:5]):
|
||||
@@ -168,7 +185,7 @@ class ParameterExtractionEvaluator(BaseEvaluator):
|
||||
|
||||
is_validation_error = error_type == "validation_error"
|
||||
|
||||
sample = f"Tool use #{i+1} - {tool_name}:\n"
|
||||
sample = f"Tool use #{i + 1} - {tool_name}:\n"
|
||||
sample += f"- Parameters: {json.dumps(tool_args, indent=2)}\n"
|
||||
sample += f"- Success: {'No' if not success else 'Yes'}"
|
||||
|
||||
@@ -187,13 +204,17 @@ class ParameterExtractionEvaluator(BaseEvaluator):
|
||||
tool_name = err.get("tool", "Unknown tool")
|
||||
error_msg = err.get("error", "Unknown error")
args = err.get("args", {})
validation_errors_info += f"\nValidation Error #{i+1}:\n- Tool: {tool_name}\n- Args: {json.dumps(args, indent=2)}\n- Error: {error_msg}"
validation_errors_info += f"\nValidation Error #{i + 1}:\n- Tool: {tool_name}\n- Args: {json.dumps(args, indent=2)}\n- Error: {error_msg}"

if len(validation_errors) > 3:
validation_errors_info += f"\n...and {len(validation_errors) - 3} more validation errors."
validation_errors_info += (
f"\n...and {len(validation_errors) - 3} more validation errors."
)
param_samples_text = "\n\n".join(param_samples)
prompt = [
{"role": "system", "content": """You are an expert evaluator assessing how well an AI agent extracts and formats PARAMETER VALUES for tool calls.
{
"role": "system",
"content": """You are an expert evaluator assessing how well an AI agent extracts and formats PARAMETER VALUES for tool calls.

Your job is to evaluate ONLY whether the agent used the correct parameter VALUES, not whether the right tools were selected or how the tools were invoked.

@@ -216,8 +237,11 @@ Return your evaluation as JSON with these fields:
- overall_score: number (average of all scores, 0-10)
- feedback: string (focused ONLY on parameter value extraction quality)
- improvement_suggestions: string (concrete suggestions for better parameter VALUE extraction)
"""},
{"role": "user", "content": f"""
""",
},
{
"role": "user",
"content": f"""
Agent role: {agent.role}
{task_context}

@@ -226,7 +250,8 @@ Parameter extraction examples:
{validation_errors_info}

Evaluate the quality of the agent's parameter extraction for this task.
"""}
""",
},
]

assert self.llm is not None
@@ -251,18 +276,18 @@ Evaluate the quality of the agent's parameter extraction for this task.
if "improvement_suggestions" in evaluation_data:
feedback += f"Improvement Suggestions:\n{evaluation_data['improvement_suggestions']}"
else:
feedback += evaluation_data.get("feedback", "No detailed feedback available.")
feedback += evaluation_data.get(
"feedback", "No detailed feedback available."
)

return EvaluationScore(
score=overall_score,
feedback=feedback,
raw_response=response
score=overall_score, feedback=feedback, raw_response=response
)
except Exception as e:
return EvaluationScore(
score=None,
feedback=f"Error evaluating parameter extraction: {e}",
raw_response=response
raw_response=response,
)


@@ -274,7 +299,7 @@ class ToolInvocationEvaluator(BaseEvaluator):
def evaluate(
self,
agent: Agent,
execution_trace: Dict[str, Any],
execution_trace: dict[str, Any],
final_output: str,
task: Task | None = None,
) -> EvaluationScore:
@@ -288,7 +313,7 @@ class ToolInvocationEvaluator(BaseEvaluator):
if tool_count == 0:
return EvaluationScore(
score=None,
feedback="No tool usage detected. Cannot evaluate tool invocation."
feedback="No tool usage detected. Cannot evaluate tool invocation.",
)

for tool_use in tool_uses:
@@ -296,7 +321,7 @@ class ToolInvocationEvaluator(BaseEvaluator):
error_info = {
"tool": tool_use.get("tool", "Unknown tool"),
"error": tool_use.get("result"),
"error_type": tool_use.get("error_type", "unknown_error")
"error_type": tool_use.get("error_type", "unknown_error"),
}
tool_errors.append(error_info)

@@ -315,9 +340,11 @@ class ToolInvocationEvaluator(BaseEvaluator):
tool_args = tool_use.get("args", {})
success = tool_use.get("success", True) and not tool_use.get("error", False)
error_type = tool_use.get("error_type", "") if not success else ""
error_msg = tool_use.get("result", "No error") if not success else "No error"
error_msg = (
tool_use.get("result", "No error") if not success else "No error"
)

sample = f"Tool invocation #{i+1}:\n"
sample = f"Tool invocation #{i + 1}:\n"
sample += f"- Tool: {tool_name}\n"
sample += f"- Parameters: {json.dumps(tool_args, indent=2)}\n"
sample += f"- Success: {'No' if not success else 'Yes'}\n"
@@ -330,11 +357,13 @@ class ToolInvocationEvaluator(BaseEvaluator):
if error_types:
error_type_summary = "Error type breakdown:\n"
for error_type, count in error_types.items():
error_type_summary += f"- {error_type}: {count} occurrences ({(count/tool_count):.1%})\n"
error_type_summary += f"- {error_type}: {count} occurrences ({(count / tool_count):.1%})\n"

invocation_samples_text = "\n\n".join(invocation_samples)
prompt = [
{"role": "system", "content": """You are an expert evaluator assessing how correctly an AI agent's tool invocations are STRUCTURED.
{
"role": "system",
"content": """You are an expert evaluator assessing how correctly an AI agent's tool invocations are STRUCTURED.

Your job is to evaluate ONLY the structural and syntactical aspects of how the agent called tools, NOT which tools were selected or what parameter values were used.

@@ -359,8 +388,11 @@ Return your evaluation as JSON with these fields:
- overall_score: number (average of all scores, 0-10)
- feedback: string (focused ONLY on structural aspects of tool invocation)
- improvement_suggestions: string (concrete suggestions for better structuring of tool calls)
"""},
{"role": "user", "content": f"""
""",
},
{
"role": "user",
"content": f"""
Agent role: {agent.role}
{task_context}

@@ -371,7 +403,8 @@ Tool error rate: {error_rate:.2%} ({len(tool_errors)} errors out of {tool_count}
{error_type_summary}

Evaluate the quality of the agent's tool invocation structure during this task.
"""}
""",
},
]

assert self.llm is not None
@@ -388,23 +421,25 @@ Evaluate the quality of the agent's tool invocation structure during this task.
overall_score = float(evaluation_data.get("overall_score", 5.0))

feedback = "Tool Invocation Evaluation:\n"
feedback += f"• Structure: {structure}/10 - Following proper syntax and format\n"
feedback += (
f"• Structure: {structure}/10 - Following proper syntax and format\n"
)
feedback += f"• Error Handling: {error_handling}/10 - Appropriately handling tool errors\n"
feedback += f"• Invocation Patterns: {invocation_patterns}/10 - Proper sequencing and management of calls\n\n"

if "improvement_suggestions" in evaluation_data:
feedback += f"Improvement Suggestions:\n{evaluation_data['improvement_suggestions']}"
else:
feedback += evaluation_data.get("feedback", "No detailed feedback available.")
feedback += evaluation_data.get(
"feedback", "No detailed feedback available."
)

return EvaluationScore(
score=overall_score,
feedback=feedback,
raw_response=response
score=overall_score, feedback=feedback, raw_response=response
)
except Exception as e:
return EvaluationScore(
score=None,
feedback=f"Error evaluating tool invocation: {e}",
raw_response=response
raw_response=response,
)

@@ -1,12 +1,21 @@
import inspect
import warnings

from typing_extensions import Any
import warnings
from crewai.experimental.evaluation.experiment import ExperimentResults, ExperimentRunner
from crewai import Crew, Agent

def assert_experiment_successfully(experiment_results: ExperimentResults, baseline_filepath: str | None = None) -> None:
failed_tests = [result for result in experiment_results.results if not result.passed]
from crewai import Agent, Crew
from crewai.experimental.evaluation.experiment import (
ExperimentResults,
ExperimentRunner,
)


def assert_experiment_successfully(
experiment_results: ExperimentResults, baseline_filepath: str | None = None
) -> None:
failed_tests = [
result for result in experiment_results.results if not result.passed
]

if failed_tests:
detailed_failures: list[str] = []
@@ -14,39 +23,54 @@ def assert_experiment_successfully(experiment_results: ExperimentResults, baseli
for result in failed_tests:
expected = result.expected_score
actual = result.score
detailed_failures.append(f"- {result.identifier}: expected {expected}, got {actual}")
detailed_failures.append(
f"- {result.identifier}: expected {expected}, got {actual}"
)

failure_details = "\n".join(detailed_failures)
raise AssertionError(f"The following test cases failed:\n{failure_details}")

baseline_filepath = baseline_filepath or _get_baseline_filepath_fallback()
comparison = experiment_results.compare_with_baseline(baseline_filepath=baseline_filepath)
comparison = experiment_results.compare_with_baseline(
baseline_filepath=baseline_filepath
)
assert_experiment_no_regression(comparison)


def assert_experiment_no_regression(comparison_result: dict[str, list[str]]) -> None:
regressed = comparison_result.get("regressed", [])
if regressed:
raise AssertionError(f"Regression detected! The following tests that previously passed now fail: {regressed}")
raise AssertionError(
f"Regression detected! The following tests that previously passed now fail: {regressed}"
)

missing_tests = comparison_result.get("missing_tests", [])
if missing_tests:
warnings.warn(
f"Warning: {len(missing_tests)} tests from the baseline are missing in the current run: {missing_tests}",
UserWarning
UserWarning,
stacklevel=2,
)

def run_experiment(dataset: list[dict[str, Any]], crew: Crew | None = None, agents: list[Agent] | None = None, verbose: bool = False) -> ExperimentResults:

def run_experiment(
dataset: list[dict[str, Any]],
crew: Crew | None = None,
agents: list[Agent] | None = None,
verbose: bool = False,
) -> ExperimentResults:
runner = ExperimentRunner(dataset=dataset)

return runner.run(agents=agents, crew=crew, print_summary=verbose)


def _get_baseline_filepath_fallback() -> str:
test_func_name = "experiment_fallback"

try:
current_frame = inspect.currentframe()
if current_frame is not None:
test_func_name = current_frame.f_back.f_back.f_code.co_name # type: ignore[union-attr]
test_func_name = current_frame.f_back.f_back.f_code.co_name  # type: ignore[union-attr]
except Exception:
...
return f"{test_func_name}_results.json"
return f"{test_func_name}_results.json"

@@ -1,5 +1,4 @@
from crewai.flow.flow import Flow, start, listen, or_, and_, router
from crewai.flow.flow import Flow, and_, listen, or_, router, start
from crewai.flow.persistence import persist

__all__ = ["Flow", "start", "listen", "or_", "and_", "router", "persist"]

__all__ = ["Flow", "and_", "listen", "or_", "persist", "router", "start"]

@@ -1086,7 +1086,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
for method_name in self._start_methods:
# Check if this start method is triggered by the current trigger
if method_name in self._listeners:
condition_type, trigger_methods = self._listeners[
_condition_type, trigger_methods = self._listeners[
method_name
]
if current_trigger in trigger_methods:

@@ -1,5 +1,4 @@
import inspect
from typing import Optional

from pydantic import BaseModel, Field, InstanceOf, model_validator

@@ -14,7 +13,7 @@ class FlowTrackable(BaseModel):
inspecting the call stack.
"""

parent_flow: Optional[InstanceOf[Flow]] = Field(
parent_flow: InstanceOf[Flow] | None = Field(
default=None,
description="The parent flow of the instance, if it was created inside a flow.",
)

@@ -1,14 +1,13 @@
# flow_visualizer.py

import os
from pathlib import Path

from pyvis.network import Network

from crewai.flow.config import COLORS, NODE_STYLES
from crewai.flow.html_template_handler import HTMLTemplateHandler
from crewai.flow.legend_generator import generate_legend_items_html, get_legend_items
from crewai.flow.path_utils import safe_path_join, validate_path_exists
from crewai.flow.path_utils import safe_path_join
from crewai.flow.utils import calculate_node_levels
from crewai.flow.visualization_utils import (
add_edges,
@@ -34,13 +33,13 @@ class FlowPlot:
ValueError
If flow object is invalid or missing required attributes.
"""
if not hasattr(flow, '_methods'):
if not hasattr(flow, "_methods"):
raise ValueError("Invalid flow object: missing '_methods' attribute")
if not hasattr(flow, '_listeners'):
if not hasattr(flow, "_listeners"):
raise ValueError("Invalid flow object: missing '_listeners' attribute")
if not hasattr(flow, '_start_methods'):
if not hasattr(flow, "_start_methods"):
raise ValueError("Invalid flow object: missing '_start_methods' attribute")


self.flow = flow
self.colors = COLORS
self.node_styles = NODE_STYLES
@@ -65,7 +64,7 @@ class FlowPlot:
"""
if not filename or not isinstance(filename, str):
raise ValueError("Filename must be a non-empty string")


try:
# Initialize network
net = Network(
@@ -96,32 +95,32 @@ class FlowPlot:
try:
node_levels = calculate_node_levels(self.flow)
except Exception as e:
raise ValueError(f"Failed to calculate node levels: {str(e)}")
raise ValueError(f"Failed to calculate node levels: {e!s}")

# Compute positions
try:
node_positions = compute_positions(self.flow, node_levels)
except Exception as e:
raise ValueError(f"Failed to compute node positions: {str(e)}")
raise ValueError(f"Failed to compute node positions: {e!s}")

# Add nodes to the network
try:
add_nodes_to_network(net, self.flow, node_positions, self.node_styles)
except Exception as e:
raise RuntimeError(f"Failed to add nodes to network: {str(e)}")
raise RuntimeError(f"Failed to add nodes to network: {e!s}")

# Add edges to the network
try:
add_edges(net, self.flow, node_positions, self.colors)
except Exception as e:
raise RuntimeError(f"Failed to add edges to network: {str(e)}")
raise RuntimeError(f"Failed to add edges to network: {e!s}")

# Generate HTML
try:
network_html = net.generate_html()
final_html_content = self._generate_final_html(network_html)
except Exception as e:
raise RuntimeError(f"Failed to generate network visualization: {str(e)}")
raise RuntimeError(f"Failed to generate network visualization: {e!s}")

# Save the final HTML content to the file
try:
@@ -129,12 +128,14 @@ class FlowPlot:
f.write(final_html_content)
print(f"Plot saved as {filename}.html")
except IOError as e:
raise IOError(f"Failed to save flow visualization to {filename}.html: {str(e)}")
raise IOError(
f"Failed to save flow visualization to {filename}.html: {e!s}"
)

except (ValueError, RuntimeError, IOError) as e:
raise e
except Exception as e:
raise RuntimeError(f"Unexpected error during flow visualization: {str(e)}")
raise RuntimeError(f"Unexpected error during flow visualization: {e!s}")
finally:
self._cleanup_pyvis_lib()

@@ -165,7 +166,9 @@ class FlowPlot:
try:
# Extract just the body content from the generated HTML
current_dir = os.path.dirname(__file__)
template_path = safe_path_join("assets", "crewai_flow_visual_template.html", root=current_dir)
template_path = safe_path_join(
"assets", "crewai_flow_visual_template.html", root=current_dir
)
logo_path = safe_path_join("assets", "crewai_logo.svg", root=current_dir)

if not os.path.exists(template_path):
@@ -179,12 +182,9 @@ class FlowPlot:
# Generate the legend items HTML
legend_items = get_legend_items(self.colors)
legend_items_html = generate_legend_items_html(legend_items)
final_html_content = html_handler.generate_final_html(
network_body, legend_items_html
)
return final_html_content
return html_handler.generate_final_html(network_body, legend_items_html)
except Exception as e:
raise IOError(f"Failed to generate visualization HTML: {str(e)}")
raise IOError(f"Failed to generate visualization HTML: {e!s}")

def _cleanup_pyvis_lib(self):
"""
@@ -197,6 +197,7 @@ class FlowPlot:
lib_folder = safe_path_join("lib", root=os.getcwd())
if os.path.exists(lib_folder) and os.path.isdir(lib_folder):
import shutil

shutil.rmtree(lib_folder)
except ValueError as e:
print(f"Error validating lib folder path: {e}")

@@ -1,8 +1,7 @@
import base64
import re
from pathlib import Path

from crewai.flow.path_utils import safe_path_join, validate_path_exists
from crewai.flow.path_utils import validate_path_exists


class HTMLTemplateHandler:
@@ -53,23 +52,23 @@ class HTMLTemplateHandler:
if "border" in item:
legend_items_html += f"""
<div class="legend-item">
<div class="legend-color-box" style="background-color: {item['color']}; border: 2px dashed {item['border']};"></div>
<div>{item['label']}</div>
<div class="legend-color-box" style="background-color: {item["color"]}; border: 2px dashed {item["border"]};"></div>
<div>{item["label"]}</div>
</div>
"""
elif item.get("dashed") is not None:
style = "dashed" if item["dashed"] else "solid"
legend_items_html += f"""
<div class="legend-item">
<div class="legend-{style}" style="border-bottom: 2px {style} {item['color']};"></div>
<div>{item['label']}</div>
<div class="legend-{style}" style="border-bottom: 2px {style} {item["color"]};"></div>
<div>{item["label"]}</div>
</div>
"""
else:
legend_items_html += f"""
<div class="legend-item">
<div class="legend-color-box" style="background-color: {item['color']};"></div>
<div>{item['label']}</div>
<div class="legend-color-box" style="background-color: {item["color"]};"></div>
<div>{item["label"]}</div>
</div>
"""
return legend_items_html
@@ -86,8 +85,6 @@ class HTMLTemplateHandler:
final_html_content = final_html_content.replace(
"{{ logo_svg_base64 }}", logo_svg_base64
)
final_html_content = final_html_content.replace(
return final_html_content.replace(
"<!-- LEGEND_ITEMS_PLACEHOLDER -->", legend_items_html
)

return final_html_content

@@ -5,12 +5,10 @@ This module provides utilities for secure path handling to prevent directory
traversal attacks and ensure paths remain within allowed boundaries.
"""

import os
from pathlib import Path
from typing import List, Union


def safe_path_join(*parts: str, root: Union[str, Path, None] = None) -> str:
def safe_path_join(*parts: str, root: str | Path | None = None) -> str:
"""
Safely join path components and ensure the result is within allowed boundaries.

@@ -43,25 +41,25 @@ def safe_path_join(*parts: str, root: Union[str, Path, None] = None) -> str:

# Establish root directory
root_path = Path(root).resolve() if root else Path.cwd()


# Join and resolve the full path
full_path = Path(root_path, *clean_parts).resolve()


# Check if the resolved path is within root
if not str(full_path).startswith(str(root_path)):
raise ValueError(
f"Invalid path: Potential directory traversal. Path must be within {root_path}"
)


return str(full_path)


except Exception as e:
if isinstance(e, ValueError):
raise
raise ValueError(f"Invalid path components: {str(e)}")
raise ValueError(f"Invalid path components: {e!s}")


def validate_path_exists(path: Union[str, Path], file_type: str = "file") -> str:
def validate_path_exists(path: str | Path, file_type: str = "file") -> str:
"""
Validate that a path exists and is of the expected type.

@@ -84,24 +82,24 @@ def validate_path_exists(path: Union[str, Path], file_type: str = "file") -> str
"""
try:
path_obj = Path(path).resolve()


if not path_obj.exists():
raise ValueError(f"Path does not exist: {path}")


if file_type == "file" and not path_obj.is_file():
raise ValueError(f"Path is not a file: {path}")
elif file_type == "directory" and not path_obj.is_dir():
if file_type == "directory" and not path_obj.is_dir():
raise ValueError(f"Path is not a directory: {path}")


return str(path_obj)


except Exception as e:
if isinstance(e, ValueError):
raise
raise ValueError(f"Invalid path: {str(e)}")
raise ValueError(f"Invalid path: {e!s}")


def list_files(directory: Union[str, Path], pattern: str = "*") -> List[str]:
def list_files(directory: str | Path, pattern: str = "*") -> list[str]:
"""
Safely list files in a directory matching a pattern.

@@ -126,10 +124,10 @@ def list_files(directory: Union[str, Path], pattern: str = "*") -> List[str]:
dir_path = Path(directory).resolve()
if not dir_path.is_dir():
raise ValueError(f"Not a directory: {directory}")


return [str(p) for p in dir_path.glob(pattern) if p.is_file()]


except Exception as e:
if isinstance(e, ValueError):
raise
raise ValueError(f"Error listing files: {str(e)}")
raise ValueError(f"Error listing files: {e!s}")

@@ -12,7 +12,7 @@ from crewai.flow.persistence.base import FlowPersistence
from crewai.flow.persistence.decorators import persist
from crewai.flow.persistence.sqlite import SQLiteFlowPersistence

__all__ = ["FlowPersistence", "persist", "SQLiteFlowPersistence"]
__all__ = ["FlowPersistence", "SQLiteFlowPersistence", "persist"]

StateType = TypeVar('StateType', bound=Union[Dict[str, Any], BaseModel])
DictStateType = Dict[str, Any]
StateType = TypeVar("StateType", bound=dict[str, Any] | BaseModel)
DictStateType = dict[str, Any]

@@ -1,53 +1,47 @@
"""Base class for flow state persistence."""

import abc
from typing import Any, Dict, Optional, Union
from typing import Any

from pydantic import BaseModel


class FlowPersistence(abc.ABC):
"""Abstract base class for flow state persistence.


This class defines the interface that all persistence implementations must follow.
It supports both structured (Pydantic BaseModel) and unstructured (dict) states.
"""


@abc.abstractmethod
def init_db(self) -> None:
"""Initialize the persistence backend.


This method should handle any necessary setup, such as:
- Creating tables
- Establishing connections
- Setting up indexes
"""
pass


@abc.abstractmethod
def save_state(
self,
flow_uuid: str,
method_name: str,
state_data: Union[Dict[str, Any], BaseModel]
self, flow_uuid: str, method_name: str, state_data: dict[str, Any] | BaseModel
) -> None:
"""Persist the flow state after method completion.


Args:
flow_uuid: Unique identifier for the flow instance
method_name: Name of the method that just completed
state_data: Current state data (either dict or Pydantic model)
"""
pass


@abc.abstractmethod
def load_state(self, flow_uuid: str) -> Optional[Dict[str, Any]]:
def load_state(self, flow_uuid: str) -> dict[str, Any] | None:
"""Load the most recent state for a given flow UUID.


Args:
flow_uuid: Unique identifier for the flow instance


Returns:
The most recent state as a dictionary, or None if no state exists
"""
pass

@@ -24,13 +24,10 @@ Example:
import asyncio
import functools
import logging
from collections.abc import Callable
from typing import (
Any,
Callable,
Optional,
Type,
TypeVar,
Union,
cast,
)

@@ -48,7 +45,7 @@ LOG_MESSAGES = {
"save_state": "Saving flow state to memory for ID: {}",
"save_error": "Failed to persist state for method {}: {}",
"state_missing": "Flow instance has no state",
"id_missing": "Flow state must have an 'id' field for persistence"
"id_missing": "Flow state must have an 'id' field for persistence",
}


@@ -58,7 +55,13 @@ class PersistenceDecorator:
_printer = Printer() # Class-level printer instance

@classmethod
def persist_state(cls, flow_instance: Any, method_name: str, persistence_instance: FlowPersistence, verbose: bool = False) -> None:
def persist_state(
cls,
flow_instance: Any,
method_name: str,
persistence_instance: FlowPersistence,
verbose: bool = False,
) -> None:
"""Persist flow state with proper error handling and logging.

This method handles the persistence of flow state data, including proper
@@ -76,22 +79,24 @@ class PersistenceDecorator:
AttributeError: If flow instance lacks required state attributes
"""
try:
state = getattr(flow_instance, 'state', None)
state = getattr(flow_instance, "state", None)
if state is None:
raise ValueError("Flow instance has no state")

flow_uuid: Optional[str] = None
flow_uuid: str | None = None
if isinstance(state, dict):
flow_uuid = state.get('id')
flow_uuid = state.get("id")
elif isinstance(state, BaseModel):
flow_uuid = getattr(state, 'id', None)
flow_uuid = getattr(state, "id", None)

if not flow_uuid:
raise ValueError("Flow state must have an 'id' field for persistence")

# Log state saving only if verbose is True
if verbose:
cls._printer.print(LOG_MESSAGES["save_state"].format(flow_uuid), color="cyan")
cls._printer.print(
LOG_MESSAGES["save_state"].format(flow_uuid), color="cyan"
)
logger.info(LOG_MESSAGES["save_state"].format(flow_uuid))

try:
@@ -104,7 +109,7 @@ class PersistenceDecorator:
error_msg = LOG_MESSAGES["save_error"].format(method_name, str(e))
cls._printer.print(error_msg, color="red")
logger.error(error_msg)
raise RuntimeError(f"State persistence failed: {str(e)}") from e
raise RuntimeError(f"State persistence failed: {e!s}") from e
except AttributeError:
error_msg = LOG_MESSAGES["state_missing"]
cls._printer.print(error_msg, color="red")
@@ -117,7 +122,7 @@ class PersistenceDecorator:
raise ValueError(error_msg) from e


def persist(persistence: Optional[FlowPersistence] = None, verbose: bool = False):
def persist(persistence: FlowPersistence | None = None, verbose: bool = False):
"""Decorator to persist flow state.

This decorator can be applied at either the class level or method level.
@@ -144,32 +149,33 @@ def persist(persistence: Optional[FlowPersistence] = None, verbose: bool = False
def begin(self):
pass
"""
def decorator(target: Union[Type, Callable[..., T]]) -> Union[Type, Callable[..., T]]:

def decorator(target: type | Callable[..., T]) -> type | Callable[..., T]:
"""Decorator that handles both class and method decoration."""
actual_persistence = persistence or SQLiteFlowPersistence()

if isinstance(target, type):
# Class decoration
original_init = getattr(target, "__init__")
original_init = target.__init__

@functools.wraps(original_init)
def new_init(self: Any, *args: Any, **kwargs: Any) -> None:
if 'persistence' not in kwargs:
kwargs['persistence'] = actual_persistence
if "persistence" not in kwargs:
kwargs["persistence"] = actual_persistence
original_init(self, *args, **kwargs)

setattr(target, "__init__", new_init)
target.__init__ = new_init

# Store original methods to preserve their decorators
original_methods = {}

for name, method in target.__dict__.items():
if callable(method) and (
hasattr(method, "__is_start_method__") or
hasattr(method, "__trigger_methods__") or
hasattr(method, "__condition_type__") or
hasattr(method, "__is_flow_method__") or
hasattr(method, "__is_router__")
hasattr(method, "__is_start_method__")
or hasattr(method, "__trigger_methods__")
or hasattr(method, "__condition_type__")
or hasattr(method, "__is_flow_method__")
or hasattr(method, "__is_router__")
):
original_methods[name] = method

@@ -177,78 +183,116 @@ def persist(persistence: Optional[FlowPersistence] = None, verbose: bool = False
for name, method in original_methods.items():
if asyncio.iscoroutinefunction(method):
# Create a closure to capture the current name and method
def create_async_wrapper(method_name: str, original_method: Callable):
def create_async_wrapper(
method_name: str, original_method: Callable
):
@functools.wraps(original_method)
async def method_wrapper(self: Any, *args: Any, **kwargs: Any) -> Any:
async def method_wrapper(
self: Any, *args: Any, **kwargs: Any
) -> Any:
result = await original_method(self, *args, **kwargs)
PersistenceDecorator.persist_state(self, method_name, actual_persistence, verbose)
PersistenceDecorator.persist_state(
self, method_name, actual_persistence, verbose
)
return result

return method_wrapper

wrapped = create_async_wrapper(name, method)

# Preserve all original decorators and attributes
for attr in ["__is_start_method__", "__trigger_methods__", "__condition_type__", "__is_router__"]:
for attr in [
"__is_start_method__",
"__trigger_methods__",
"__condition_type__",
"__is_router__",
]:
if hasattr(method, attr):
setattr(wrapped, attr, getattr(method, attr))
setattr(wrapped, "__is_flow_method__", True)
wrapped.__is_flow_method__ = True

# Update the class with the wrapped method
setattr(target, name, wrapped)
else:
# Create a closure to capture the current name and method
def create_sync_wrapper(method_name: str, original_method: Callable):
def create_sync_wrapper(
method_name: str, original_method: Callable
):
@functools.wraps(original_method)
def method_wrapper(self: Any, *args: Any, **kwargs: Any) -> Any:
result = original_method(self, *args, **kwargs)
PersistenceDecorator.persist_state(self, method_name, actual_persistence, verbose)
PersistenceDecorator.persist_state(
self, method_name, actual_persistence, verbose
)
return result

return method_wrapper

wrapped = create_sync_wrapper(name, method)

# Preserve all original decorators and attributes
for attr in ["__is_start_method__", "__trigger_methods__", "__condition_type__", "__is_router__"]:
for attr in [
"__is_start_method__",
"__trigger_methods__",
"__condition_type__",
"__is_router__",
]:
if hasattr(method, attr):
setattr(wrapped, attr, getattr(method, attr))
setattr(wrapped, "__is_flow_method__", True)
wrapped.__is_flow_method__ = True

# Update the class with the wrapped method
setattr(target, name, wrapped)

return target
else:
# Method decoration
method = target
setattr(method, "__is_flow_method__", True)
# Method decoration
method = target
method.__is_flow_method__ = True

if asyncio.iscoroutinefunction(method):
@functools.wraps(method)
async def method_async_wrapper(flow_instance: Any, *args: Any, **kwargs: Any) -> T:
method_coro = method(flow_instance, *args, **kwargs)
if asyncio.iscoroutine(method_coro):
result = await method_coro
else:
result = method_coro
PersistenceDecorator.persist_state(flow_instance, method.__name__, actual_persistence, verbose)
return result
if asyncio.iscoroutinefunction(method):

for attr in ["__is_start_method__", "__trigger_methods__", "__condition_type__", "__is_router__"]:
if hasattr(method, attr):
setattr(method_async_wrapper, attr, getattr(method, attr))
setattr(method_async_wrapper, "__is_flow_method__", True)
return cast(Callable[..., T], method_async_wrapper)
else:
@functools.wraps(method)
def method_sync_wrapper(flow_instance: Any, *args: Any, **kwargs: Any) -> T:
result = method(flow_instance, *args, **kwargs)
PersistenceDecorator.persist_state(flow_instance, method.__name__, actual_persistence, verbose)
return result
@functools.wraps(method)
async def method_async_wrapper(
flow_instance: Any, *args: Any, **kwargs: Any
) -> T:
method_coro = method(flow_instance, *args, **kwargs)
if asyncio.iscoroutine(method_coro):
result = await method_coro
else:
result = method_coro
PersistenceDecorator.persist_state(
flow_instance, method.__name__, actual_persistence, verbose
)
return result

for attr in ["__is_start_method__", "__trigger_methods__", "__condition_type__", "__is_router__"]:
if hasattr(method, attr):
setattr(method_sync_wrapper, attr, getattr(method, attr))
setattr(method_sync_wrapper, "__is_flow_method__", True)
return cast(Callable[..., T], method_sync_wrapper)
for attr in [
"__is_start_method__",
"__trigger_methods__",
"__condition_type__",
"__is_router__",
]:
if hasattr(method, attr):
setattr(method_async_wrapper, attr, getattr(method, attr))
method_async_wrapper.__is_flow_method__ = True
return cast(Callable[..., T], method_async_wrapper)

@functools.wraps(method)
def method_sync_wrapper(flow_instance: Any, *args: Any, **kwargs: Any) -> T:
result = method(flow_instance, *args, **kwargs)
PersistenceDecorator.persist_state(
flow_instance, method.__name__, actual_persistence, verbose
)
return result

for attr in [
"__is_start_method__",
"__trigger_methods__",
"__condition_type__",
"__is_router__",
]:
if hasattr(method, attr):
setattr(method_sync_wrapper, attr, getattr(method, attr))
method_sync_wrapper.__is_flow_method__ = True
return cast(Callable[..., T], method_sync_wrapper)

return decorator

@@ -6,7 +6,7 @@ import json
import sqlite3
from datetime import datetime, timezone
from pathlib import Path
from typing import Any, Dict, Optional, Union
from typing import Any

from pydantic import BaseModel

@@ -23,7 +23,7 @@ class SQLiteFlowPersistence(FlowPersistence):

db_path: str

def __init__(self, db_path: Optional[str] = None):
def __init__(self, db_path: str | None = None):
"""Initialize SQLite persistence.

Args:
@@ -70,7 +70,7 @@ class SQLiteFlowPersistence(FlowPersistence):
self,
flow_uuid: str,
method_name: str,
state_data: Union[Dict[str, Any], BaseModel],
state_data: dict[str, Any] | BaseModel,
) -> None:
"""Save the current flow state to SQLite.

@@ -107,7 +107,7 @@ class SQLiteFlowPersistence(FlowPersistence):
),
)

def load_state(self, flow_uuid: str) -> Optional[Dict[str, Any]]:
def load_state(self, flow_uuid: str) -> dict[str, Any] | None:
"""Load the most recent state for a given flow UUID.

Args:

@@ -5,6 +5,7 @@ the Flow system.
"""

from typing import Any, TypedDict

from typing_extensions import NotRequired, Required


@@ -17,10 +17,10 @@ import ast
import inspect
import textwrap
from collections import defaultdict, deque
from typing import Any, Deque, Dict, List, Optional, Set, Union
from typing import Any


def get_possible_return_constants(function: Any) -> Optional[List[str]]:
def get_possible_return_constants(function: Any) -> list[str] | None:
try:
source = inspect.getsource(function)
except OSError:
@@ -94,7 +94,7 @@ def get_possible_return_constants(function: Any) -> Optional[List[str]]:
return list(return_values) if return_values else None


def calculate_node_levels(flow: Any) -> Dict[str, int]:
def calculate_node_levels(flow: Any) -> dict[str, int]:
"""
Calculate the hierarchical level of each node in the flow.

@@ -118,10 +118,10 @@ def calculate_node_levels(flow: Any) -> Dict[str, int]:
- Handles both OR and AND conditions for listeners
- Processes router paths separately
"""
levels: Dict[str, int] = {}
queue: Deque[str] = deque()
visited: Set[str] = set()
pending_and_listeners: Dict[str, Set[str]] = {}
levels: dict[str, int] = {}
queue: deque[str] = deque()
visited: set[str] = set()
pending_and_listeners: dict[str, set[str]] = {}

# Make all start methods at level 0
for method_name, method in flow._methods.items():
@@ -172,7 +172,7 @@ def calculate_node_levels(flow: Any) -> Dict[str, int]:
return levels


def count_outgoing_edges(flow: Any) -> Dict[str, int]:
def count_outgoing_edges(flow: Any) -> dict[str, int]:
"""
Count the number of outgoing edges for each method in the flow.

@@ -197,7 +197,7 @@ def count_outgoing_edges(flow: Any) -> Dict[str, int]:
return counts


def build_ancestor_dict(flow: Any) -> Dict[str, Set[str]]:
def build_ancestor_dict(flow: Any) -> dict[str, set[str]]:
"""
Build a dictionary mapping each node to its ancestor nodes.

@@ -211,8 +211,8 @@ def build_ancestor_dict(flow: Any) -> Dict[str, Set[str]]:
Dict[str, Set[str]]
Dictionary mapping each node to a set of its ancestor nodes.
"""
ancestors: Dict[str, Set[str]] = {node: set() for node in flow._methods}
visited: Set[str] = set()
ancestors: dict[str, set[str]] = {node: set() for node in flow._methods}
visited: set[str] = set()
for node in flow._methods:
if node not in visited:
dfs_ancestors(node, ancestors, visited, flow)
@@ -220,7 +220,7 @@ def build_ancestor_dict(flow: Any) -> Dict[str, Set[str]]:


def dfs_ancestors(
node: str, ancestors: Dict[str, Set[str]], visited: Set[str], flow: Any
node: str, ancestors: dict[str, set[str]], visited: set[str], flow: Any
) -> None:
"""
Perform depth-first search to build ancestor relationships.
@@ -265,7 +265,7 @@ def dfs_ancestors(


def is_ancestor(
node: str, ancestor_candidate: str, ancestors: Dict[str, Set[str]]
node: str, ancestor_candidate: str, ancestors: dict[str, set[str]]
) -> bool:
"""
Check if one node is an ancestor of another.
@@ -287,7 +287,7 @@ def is_ancestor(
return ancestor_candidate in ancestors.get(node, set())


def build_parent_children_dict(flow: Any) -> Dict[str, List[str]]:
def build_parent_children_dict(flow: Any) -> dict[str, list[str]]:
"""
Build a dictionary mapping parent nodes to their children.

@@ -307,7 +307,7 @@ def build_parent_children_dict(flow: Any) -> Dict[str, List[str]]:
- Maps router methods to their paths and listeners
- Children lists are sorted for consistent ordering
"""
parent_children: Dict[str, List[str]] = {}
parent_children: dict[str, list[str]] = {}

# Map listeners to their trigger methods
for listener_name, (_, trigger_methods) in flow._listeners.items():
@@ -332,7 +332,7 @@ def build_parent_children_dict(flow: Any) -> Dict[str, List[str]]:


def get_child_index(
parent: str, child: str, parent_children: Dict[str, List[str]]
parent: str, child: str, parent_children: dict[str, list[str]]
) -> int:
"""
Get the index of a child node in its parent's sorted children list.
@@ -364,7 +364,7 @@ def process_router_paths(flow, current, current_level, levels, queue):
paths = flow._router_paths.get(current, [])
for path in paths:
for listener_name, (
condition_type,
_condition_type,
trigger_methods,
) in flow._listeners.items():
if path in trigger_methods:

@@ -17,7 +17,7 @@ Example

import ast
import inspect
from typing import Any, Dict, List, Tuple, Union
from typing import Any

from .utils import (
build_ancestor_dict,
@@ -56,6 +56,7 @@ def method_calls_crew(method: Any) -> bool:

class CrewCallVisitor(ast.NodeVisitor):
"""AST visitor to detect .crew() method calls."""

def __init__(self):
self.found = False

@@ -73,8 +74,8 @@ def method_calls_crew(method: Any) -> bool:
def add_nodes_to_network(
net: Any,
flow: Any,
node_positions: Dict[str, Tuple[float, float]],
node_styles: Dict[str, Dict[str, Any]]
node_positions: dict[str, tuple[float, float]],
node_styles: dict[str, dict[str, Any]],
) -> None:
"""
Add nodes to the network visualization with appropriate styling.
@@ -98,6 +99,7 @@ def add_nodes_to_network(
- Crew methods
- Regular methods
"""

def human_friendly_label(method_name):
return method_name.replace("_", " ").title()

@@ -138,10 +140,10 @@ def add_nodes_to_network(

def compute_positions(
flow: Any,
node_levels: Dict[str, int],
node_levels: dict[str, int],
y_spacing: float = 150,
x_spacing: float = 300
) -> Dict[str, Tuple[float, float]]:
x_spacing: float = 300,
) -> dict[str, tuple[float, float]]:
"""
Compute the (x, y) positions for each node in the flow graph.

@@ -161,8 +163,8 @@ def compute_positions(
Dict[str, Tuple[float, float]]
Dictionary mapping node names to their (x, y) coordinates.
"""
level_nodes: Dict[int, List[str]] = {}
node_positions: Dict[str, Tuple[float, float]] = {}
level_nodes: dict[int, list[str]] = {}
node_positions: dict[str, tuple[float, float]] = {}

for method_name, level in node_levels.items():
level_nodes.setdefault(level, []).append(method_name)
@@ -180,10 +182,10 @@ def compute_positions(
def add_edges(
net: Any,
flow: Any,
node_positions: Dict[str, Tuple[float, float]],
colors: Dict[str, str]
node_positions: dict[str, tuple[float, float]],
colors: dict[str, str],
) -> None:
edge_smooth: Dict[str, Union[str, float]] = {"type": "continuous"} # Default value
edge_smooth: dict[str, str | float] = {"type": "continuous"} # Default value
"""
Add edges to the network visualization with appropriate styling.

@@ -269,7 +271,7 @@ def add_edges(
for router_method_name, paths in flow._router_paths.items():
for path in paths:
for listener_name, (
condition_type,
_condition_type,
trigger_methods,
) in flow._listeners.items():
if path in trigger_methods:

@@ -1,6 +1,5 @@
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Dict, List, Optional, Union

from pydantic import Field, field_validator

@@ -14,19 +13,19 @@ class BaseFileKnowledgeSource(BaseKnowledgeSource, ABC):
"""Base class for knowledge sources that load content from files."""

_logger: Logger = Logger(verbose=True)
file_path: Optional[Union[Path, List[Path], str, List[str]]] = Field(
file_path: Path | list[Path] | str | list[str] | None = Field(
default=None,
description="[Deprecated] The path to the file. Use file_paths instead.",
)
file_paths: Optional[Union[Path, List[Path], str, List[str]]] = Field(
file_paths: Path | list[Path] | str | list[str] | None = Field(
default_factory=list, description="The path to the file"
)
content: Dict[Path, str] = Field(init=False, default_factory=dict)
storage: Optional[KnowledgeStorage] = Field(default=None)
safe_file_paths: List[Path] = Field(default_factory=list)
content: dict[Path, str] = Field(init=False, default_factory=dict)
storage: KnowledgeStorage | None = Field(default=None)
safe_file_paths: list[Path] = Field(default_factory=list)

@field_validator("file_path", "file_paths", mode="before")
def validate_file_path(cls, v, info):
def validate_file_path(self, v, info):
"""Validate that at least one of file_path or file_paths is provided."""
# Single check if both are None, O(1) instead of nested conditions
if (
@@ -46,9 +45,8 @@ class BaseFileKnowledgeSource(BaseKnowledgeSource, ABC):
self.content = self.load_content()

@abstractmethod
def load_content(self) -> Dict[Path, str]:
def load_content(self) -> dict[Path, str]:
"""Load and preprocess file content. Should be overridden by subclasses. Assume that the file path is relative to the project root in the knowledge directory."""
pass

def validate_content(self):
"""Validate the paths."""
@@ -74,11 +72,11 @@ class BaseFileKnowledgeSource(BaseKnowledgeSource, ABC):
else:
raise ValueError("No storage found to save documents.")

def convert_to_path(self, path: Union[Path, str]) -> Path:
def convert_to_path(self, path: Path | str) -> Path:
"""Convert a path to a Path object."""
return Path(KNOWLEDGE_DIRECTORY + "/" + path) if isinstance(path, str) else path

def _process_file_paths(self) -> List[Path]:
def _process_file_paths(self) -> list[Path]:
"""Convert file_path to a list of Path objects."""

if hasattr(self, "file_path") and self.file_path is not None:
@@ -93,7 +91,7 @@ class BaseFileKnowledgeSource(BaseKnowledgeSource, ABC):
raise ValueError("Your source must be provided with a file_paths: []")

# Convert single path to list
path_list: List[Union[Path, str]] = (
path_list: list[Path | str] = (
[self.file_paths]
if isinstance(self.file_paths, (str, Path))
else list(self.file_paths)

@@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional
from typing import Any

import numpy as np
from pydantic import BaseModel, ConfigDict, Field
@@ -12,29 +12,27 @@ class BaseKnowledgeSource(BaseModel, ABC):

chunk_size: int = 4000
chunk_overlap: int = 200
chunks: List[str] = Field(default_factory=list)
chunk_embeddings: List[np.ndarray] = Field(default_factory=list)
chunks: list[str] = Field(default_factory=list)
chunk_embeddings: list[np.ndarray] = Field(default_factory=list)

model_config = ConfigDict(arbitrary_types_allowed=True)
storage: Optional[KnowledgeStorage] = Field(default=None)
metadata: Dict[str, Any] = Field(default_factory=dict) # Currently unused
collection_name: Optional[str] = Field(default=None)
storage: KnowledgeStorage | None = Field(default=None)
metadata: dict[str, Any] = Field(default_factory=dict) # Currently unused
collection_name: str | None = Field(default=None)

@abstractmethod
def validate_content(self) -> Any:
"""Load and preprocess content from the source."""
pass

@abstractmethod
def add(self) -> None:
"""Process content, chunk it, compute embeddings, and save them."""
pass

def get_embeddings(self) -> List[np.ndarray]:
def get_embeddings(self) -> list[np.ndarray]:
"""Return the list of embeddings for the chunks."""
return self.chunk_embeddings

def _chunk_text(self, text: str) -> List[str]:
def _chunk_text(self, text: str) -> list[str]:
"""Utility method to split text into chunks."""
return [
text[i : i + self.chunk_size]

@@ -1,5 +1,5 @@
from collections.abc import Iterator
from pathlib import Path
from typing import Iterator, List, Optional, Union
from urllib.parse import urlparse

try:
@@ -35,11 +35,11 @@ class CrewDoclingSource(BaseKnowledgeSource):

_logger: Logger = Logger(verbose=True)

file_path: Optional[List[Union[Path, str]]] = Field(default=None)
file_paths: List[Union[Path, str]] = Field(default_factory=list)
chunks: List[str] = Field(default_factory=list)
safe_file_paths: List[Union[Path, str]] = Field(default_factory=list)
content: List["DoclingDocument"] = Field(default_factory=list)
file_path: list[Path | str] | None = Field(default=None)
file_paths: list[Path | str] = Field(default_factory=list)
chunks: list[str] = Field(default_factory=list)
safe_file_paths: list[Path | str] = Field(default_factory=list)
content: list["DoclingDocument"] = Field(default_factory=list)
document_converter: "DocumentConverter" = Field(
default_factory=lambda: DocumentConverter(
allowed_formats=[
@@ -66,7 +66,7 @@ class CrewDoclingSource(BaseKnowledgeSource):
self.safe_file_paths = self.validate_content()
self.content = self._load_content()

def _load_content(self) -> List["DoclingDocument"]:
def _load_content(self) -> list["DoclingDocument"]:
try:
return self._convert_source_to_docling_documents()
except ConversionError as e:
@@ -88,7 +88,7 @@ class CrewDoclingSource(BaseKnowledgeSource):
self.chunks.extend(list(new_chunks_iterable))
self._save_documents()

def _convert_source_to_docling_documents(self) -> List["DoclingDocument"]:
def _convert_source_to_docling_documents(self) -> list["DoclingDocument"]:
conv_results_iter = self.document_converter.convert_all(self.safe_file_paths)
return [result.document for result in conv_results_iter]

@@ -97,8 +97,8 @@ class CrewDoclingSource(BaseKnowledgeSource):
for chunk in chunker.chunk(doc):
yield chunk.text

def validate_content(self) -> List[Union[Path, str]]:
processed_paths: List[Union[Path, str]] = []
def validate_content(self) -> list[Path | str]:
processed_paths: list[Path | str] = []
for path in self.file_paths:
if isinstance(path, str):
if path.startswith(("http://", "https://")):
@@ -108,7 +108,7 @@ class CrewDoclingSource(BaseKnowledgeSource):
else:
raise ValueError(f"Invalid URL format: {path}")
except Exception as e:
raise ValueError(f"Invalid URL: {path}. Error: {str(e)}")
raise ValueError(f"Invalid URL: {path}. Error: {e!s}") from e
else:
local_path = Path(KNOWLEDGE_DIRECTORY + "/" + path)
if local_path.exists():

@@ -1,6 +1,5 @@
import csv
from pathlib import Path
from typing import Dict, List

from crewai.knowledge.source.base_file_knowledge_source import BaseFileKnowledgeSource

@@ -8,7 +7,7 @@ from crewai.knowledge.source.base_file_knowledge_source import BaseFileKnowledge
class CSVKnowledgeSource(BaseFileKnowledgeSource):
"""A knowledge source that stores and queries CSV file content using embeddings."""

def load_content(self) -> Dict[Path, str]:
def load_content(self) -> dict[Path, str]:
"""Load and preprocess CSV file content."""
content_dict = {}
for file_path in self.safe_file_paths:
@@ -32,7 +31,7 @@ class CSVKnowledgeSource(BaseFileKnowledgeSource):
self.chunks.extend(new_chunks)
self._save_documents()

def _chunk_text(self, text: str) -> List[str]:
def _chunk_text(self, text: str) -> list[str]:
"""Utility method to split text into chunks."""
return [
text[i : i + self.chunk_size]

@@ -1,6 +1,4 @@
from pathlib import Path
from typing import Dict, Iterator, List, Optional, Union
from urllib.parse import urlparse

from pydantic import Field, field_validator

@@ -16,19 +14,19 @@ class ExcelKnowledgeSource(BaseKnowledgeSource):

_logger: Logger = Logger(verbose=True)

file_path: Optional[Union[Path, List[Path], str, List[str]]] = Field(
file_path: Path | list[Path] | str | list[str] | None = Field(
default=None,
description="[Deprecated] The path to the file. Use file_paths instead.",
)
file_paths: Optional[Union[Path, List[Path], str, List[str]]] = Field(
file_paths: Path | list[Path] | str | list[str] | None = Field(
default_factory=list, description="The path to the file"
)
chunks: List[str] = Field(default_factory=list)
content: Dict[Path, Dict[str, str]] = Field(default_factory=dict)
safe_file_paths: List[Path] = Field(default_factory=list)
chunks: list[str] = Field(default_factory=list)
content: dict[Path, dict[str, str]] = Field(default_factory=dict)
safe_file_paths: list[Path] = Field(default_factory=list)

@field_validator("file_path", "file_paths", mode="before")
def validate_file_path(cls, v, info):
def validate_file_path(self, v, info):
"""Validate that at least one of file_path or file_paths is provided."""
# Single check if both are None, O(1) instead of nested conditions
if (
@@ -41,7 +39,7 @@ class ExcelKnowledgeSource(BaseKnowledgeSource):
raise ValueError("Either file_path or file_paths must be provided")
return v

def _process_file_paths(self) -> List[Path]:
def _process_file_paths(self) -> list[Path]:
"""Convert file_path to a list of Path objects."""

if hasattr(self, "file_path") and self.file_path is not None:
@@ -56,7 +54,7 @@ class ExcelKnowledgeSource(BaseKnowledgeSource):
raise ValueError("Your source must be provided with a file_paths: []")

# Convert single path to list
path_list: List[Union[Path, str]] = (
path_list: list[Path | str] = (
[self.file_paths]
if isinstance(self.file_paths, (str, Path))
else list(self.file_paths)
@@ -100,7 +98,7 @@ class ExcelKnowledgeSource(BaseKnowledgeSource):
self.validate_content()
self.content = self._load_content()

def _load_content(self) -> Dict[Path, Dict[str, str]]:
def _load_content(self) -> dict[Path, dict[str, str]]:
"""Load and preprocess Excel file content from multiple sheets.

Each sheet's content is converted to CSV format and stored.
@@ -126,7 +124,7 @@ class ExcelKnowledgeSource(BaseKnowledgeSource):
content_dict[file_path] = sheet_dict
return content_dict

def convert_to_path(self, path: Union[Path, str]) -> Path:
def convert_to_path(self, path: Path | str) -> Path:
"""Convert a path to a Path object."""
return Path(KNOWLEDGE_DIRECTORY + "/" + path) if isinstance(path, str) else path

@@ -161,7 +159,7 @@ class ExcelKnowledgeSource(BaseKnowledgeSource):
self.chunks.extend(new_chunks)
self._save_documents()

def _chunk_text(self, text: str) -> List[str]:
def _chunk_text(self, text: str) -> list[str]:
"""Utility method to split text into chunks."""
return [
text[i : i + self.chunk_size]

@@ -1,6 +1,6 @@
import json
from pathlib import Path
from typing import Any, Dict, List
from typing import Any

from crewai.knowledge.source.base_file_knowledge_source import BaseFileKnowledgeSource

@@ -8,9 +8,9 @@ from crewai.knowledge.source.base_file_knowledge_source import BaseFileKnowledge
class JSONKnowledgeSource(BaseFileKnowledgeSource):
"""A knowledge source that stores and queries JSON file content using embeddings."""

def load_content(self) -> Dict[Path, str]:
def load_content(self) -> dict[Path, str]:
"""Load and preprocess JSON file content."""
content: Dict[Path, str] = {}
content: dict[Path, str] = {}
for path in self.safe_file_paths:
path = self.convert_to_path(path)
with open(path, "r", encoding="utf-8") as json_file:
@@ -29,7 +29,7 @@ class JSONKnowledgeSource(BaseFileKnowledgeSource):
for item in data:
text += f"{indent}- {self._json_to_text(item, level + 1)}\n"
else:
text += f"{str(data)}"
text += f"{data!s}"
return text

def add(self) -> None:
@@ -44,7 +44,7 @@ class JSONKnowledgeSource(BaseFileKnowledgeSource):
self.chunks.extend(new_chunks)
self._save_documents()

def _chunk_text(self, text: str) -> List[str]:
def _chunk_text(self, text: str) -> list[str]:
"""Utility method to split text into chunks."""
return [
text[i : i + self.chunk_size]

@@ -1,5 +1,4 @@
from pathlib import Path
from typing import Dict, List

from crewai.knowledge.source.base_file_knowledge_source import BaseFileKnowledgeSource

@@ -7,7 +6,7 @@ from crewai.knowledge.source.base_file_knowledge_source import BaseFileKnowledge
class PDFKnowledgeSource(BaseFileKnowledgeSource):
"""A knowledge source that stores and queries PDF file content using embeddings."""

def load_content(self) -> Dict[Path, str]:
def load_content(self) -> dict[Path, str]:
"""Load and preprocess PDF file content."""
pdfplumber = self._import_pdfplumber()

@@ -40,12 +39,12 @@ class PDFKnowledgeSource(BaseFileKnowledgeSource):
Add PDF file content to the knowledge source, chunk it, compute embeddings,
and save the embeddings.
"""
for _, text in self.content.items():
for text in self.content.values():
new_chunks = self._chunk_text(text)
self.chunks.extend(new_chunks)
self._save_documents()

def _chunk_text(self, text: str) -> List[str]:
def _chunk_text(self, text: str) -> list[str]:
"""Utility method to split text into chunks."""
return [
text[i : i + self.chunk_size]

@@ -1,5 +1,3 @@
from typing import List, Optional

from pydantic import Field

from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource
@@ -9,7 +7,7 @@ class StringKnowledgeSource(BaseKnowledgeSource):
"""A knowledge source that stores and queries plain text content using embeddings."""

content: str = Field(...)
collection_name: Optional[str] = Field(default=None)
collection_name: str | None = Field(default=None)

def model_post_init(self, _):
"""Post-initialization method to validate content."""
@@ -26,7 +24,7 @@ class StringKnowledgeSource(BaseKnowledgeSource):
self.chunks.extend(new_chunks)
self._save_documents()

def _chunk_text(self, text: str) -> List[str]:
def _chunk_text(self, text: str) -> list[str]:
"""Utility method to split text into chunks."""
return [
text[i : i + self.chunk_size]

@@ -1,5 +1,4 @@
from pathlib import Path
from typing import Dict, List

from crewai.knowledge.source.base_file_knowledge_source import BaseFileKnowledgeSource

@@ -7,7 +6,7 @@ from crewai.knowledge.source.base_file_knowledge_source import BaseFileKnowledge
class TextFileKnowledgeSource(BaseFileKnowledgeSource):
"""A knowledge source that stores and queries text file content using embeddings."""

def load_content(self) -> Dict[Path, str]:
def load_content(self) -> dict[Path, str]:
"""Load and preprocess text file content."""
content = {}
for path in self.safe_file_paths:
@@ -21,12 +20,12 @@ class TextFileKnowledgeSource(BaseFileKnowledgeSource):
Add text file content to the knowledge source, chunk it, compute embeddings,
and save the embeddings.
"""
for _, text in self.content.items():
for text in self.content.values():
new_chunks = self._chunk_text(text)
self.chunks.extend(new_chunks)
self._save_documents()

def _chunk_text(self, text: str) -> List[str]:
def _chunk_text(self, text: str) -> list[str]:
"""Utility method to split text into chunks."""
return [
text[i : i + self.chunk_size]

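The Excel, JSON, PDF, string, and text-file sources above all share the same fixed-size _chunk_text slicing. A minimal, self-contained sketch of that pattern follows; the default chunk_size/chunk_overlap values and the overlap-based stride are illustrative assumptions, not the exact attributes defined on BaseKnowledgeSource elsewhere in the codebase.

def chunk_text(text: str, chunk_size: int = 4000, chunk_overlap: int = 200) -> list[str]:
    # Each window starts chunk_size - chunk_overlap characters after the previous
    # one, so consecutive chunks share a small overlap of context.
    step = chunk_size - chunk_overlap
    return [text[i : i + chunk_size] for i in range(0, len(text), step)]

# Example: chunk_text("abcdefghij", chunk_size=4, chunk_overlap=2)
# -> ["abcd", "cdef", "efgh", "ghij", "ij"]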
@@ -1,21 +1,14 @@
import asyncio
import inspect
import uuid
from collections.abc import Callable
from typing import (
Any,
Callable,
Dict,
List,
Optional,
Tuple,
Type,
Union,
cast,
get_args,
get_origin,
)


try:
from typing import Self
except ImportError:
@@ -24,11 +17,12 @@ except ImportError:
from pydantic import (
UUID4,
BaseModel,
ConfigDict,
Field,
InstanceOf,
PrivateAttr,
model_validator,
field_validator,
model_validator,
)

from crewai.agents.agent_builder.base_agent import BaseAgent
@@ -39,12 +33,18 @@ from crewai.agents.parser import (
AgentFinish,
OutputParserException,
)
from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.agent_events import (
LiteAgentExecutionCompletedEvent,
LiteAgentExecutionErrorEvent,
LiteAgentExecutionStartedEvent,
)
from crewai.events.types.logging_events import AgentLogsExecutionEvent
from crewai.flow.flow_trackable import FlowTrackable
from crewai.llm import LLM, BaseLLM
from crewai.tools.base_tool import BaseTool
from crewai.tools.structured_tool import CrewStructuredTool
from crewai.utilities import I18N
from crewai.utilities.guardrail import process_guardrail
from crewai.utilities.agent_utils import (
enforce_rpm_limit,
format_message_for_llm,
@@ -62,14 +62,7 @@ from crewai.utilities.agent_utils import (
render_text_description_and_args,
)
from crewai.utilities.converter import generate_model_description
from crewai.events.types.logging_events import AgentLogsExecutionEvent
from crewai.events.types.agent_events import (
LiteAgentExecutionCompletedEvent,
LiteAgentExecutionErrorEvent,
LiteAgentExecutionStartedEvent,
)
from crewai.events.event_bus import crewai_event_bus

from crewai.utilities.guardrail import process_guardrail
from crewai.utilities.llm_utils import create_llm
from crewai.utilities.printer import Printer
from crewai.utilities.token_counter_callback import TokenCalcHandler
@@ -79,18 +72,18 @@ from crewai.utilities.tool_utils import execute_tool_and_check_finality
class LiteAgentOutput(BaseModel):
"""Class that represents the result of a LiteAgent execution."""

model_config = {"arbitrary_types_allowed": True}
model_config = ConfigDict(arbitrary_types_allowed=True)

raw: str = Field(description="Raw output of the agent", default="")
pydantic: Optional[BaseModel] = Field(
pydantic: BaseModel | None = Field(
description="Pydantic output of the agent", default=None
)
agent_role: str = Field(description="Role of the agent that produced this output")
usage_metrics: Optional[Dict[str, Any]] = Field(
usage_metrics: dict[str, Any] | None = Field(
description="Token usage metrics for this execution", default=None
)

def to_dict(self) -> Dict[str, Any]:
def to_dict(self) -> dict[str, Any]:
"""Convert pydantic_output to a dictionary."""
if self.pydantic:
return self.pydantic.model_dump()
@@ -123,17 +116,17 @@ class LiteAgent(FlowTrackable, BaseModel):
response_format: Optional Pydantic model for structured output.
"""

model_config = {"arbitrary_types_allowed": True}
model_config = ConfigDict(arbitrary_types_allowed=True)

# Core Agent Properties
id: UUID4 = Field(default_factory=uuid.uuid4, frozen=True)
role: str = Field(description="Role of the agent")
goal: str = Field(description="Goal of the agent")
backstory: str = Field(description="Backstory of the agent")
llm: Optional[Union[str, InstanceOf[BaseLLM], Any]] = Field(
llm: str | InstanceOf[BaseLLM] | Any | None = Field(
default=None, description="Language model that will run the agent"
)
tools: List[BaseTool] = Field(
tools: list[BaseTool] = Field(
default_factory=list, description="Tools at agent's disposal"
)

@@ -141,7 +134,7 @@ class LiteAgent(FlowTrackable, BaseModel):
max_iterations: int = Field(
default=15, description="Maximum number of iterations for tool usage"
)
max_execution_time: Optional[int] = Field(
max_execution_time: int | None = Field(
default=None, description=". Maximum execution time in seconds"
)
respect_context_window: bool = Field(
@@ -152,52 +145,50 @@ class LiteAgent(FlowTrackable, BaseModel):
default=True,
description="Whether to use stop words to prevent the LLM from using tools",
)
request_within_rpm_limit: Optional[Callable[[], bool]] = Field(
request_within_rpm_limit: Callable[[], bool] | None = Field(
default=None,
description="Callback to check if the request is within the RPM limit",
)
i18n: I18N = Field(default=I18N(), description="Internationalization settings.")

# Output and Formatting Properties
response_format: Optional[Type[BaseModel]] = Field(
response_format: type[BaseModel] | None = Field(
default=None, description="Pydantic model for structured output"
)
verbose: bool = Field(
default=False, description="Whether to print execution details"
)
callbacks: List[Callable] = Field(
callbacks: list[Callable] = Field(
default=[], description="Callbacks to be used for the agent"
)

# Guardrail Properties
guardrail: Optional[Union[Callable[[LiteAgentOutput], Tuple[bool, Any]], str]] = (
Field(
default=None,
description="Function or string description of a guardrail to validate agent output",
)
guardrail: Callable[[LiteAgentOutput], tuple[bool, Any]] | str | None = Field(
default=None,
description="Function or string description of a guardrail to validate agent output",
)
guardrail_max_retries: int = Field(
default=3, description="Maximum number of retries when guardrail fails"
)

# State and Results
tools_results: List[Dict[str, Any]] = Field(
tools_results: list[dict[str, Any]] = Field(
default=[], description="Results of the tools used by the agent."
)

# Reference of Agent
original_agent: Optional[BaseAgent] = Field(
original_agent: BaseAgent | None = Field(
default=None, description="Reference to the agent that created this LiteAgent"
)
# Private Attributes
_parsed_tools: List[CrewStructuredTool] = PrivateAttr(default_factory=list)
_parsed_tools: list[CrewStructuredTool] = PrivateAttr(default_factory=list)
_token_process: TokenProcess = PrivateAttr(default_factory=TokenProcess)
_cache_handler: CacheHandler = PrivateAttr(default_factory=CacheHandler)
_key: str = PrivateAttr(default_factory=lambda: str(uuid.uuid4()))
_messages: List[Dict[str, str]] = PrivateAttr(default_factory=list)
_messages: list[dict[str, str]] = PrivateAttr(default_factory=list)
_iterations: int = PrivateAttr(default=0)
_printer: Printer = PrivateAttr(default_factory=Printer)
_guardrail: Optional[Callable] = PrivateAttr(default=None)
_guardrail: Callable | None = PrivateAttr(default=None)
_guardrail_retry_count: int = PrivateAttr(default=0)

@model_validator(mode="after")
@@ -241,8 +232,8 @@ class LiteAgent(FlowTrackable, BaseModel):
@field_validator("guardrail", mode="before")
@classmethod
def validate_guardrail_function(
cls, v: Optional[Union[Callable, str]]
) -> Optional[Union[Callable, str]]:
cls, v: Callable | str | None
) -> Callable | str | None:
"""Validate that the guardrail function has the correct signature.

If v is a callable, validate that it has the correct signature.
@@ -267,7 +258,7 @@ class LiteAgent(FlowTrackable, BaseModel):

# Check return annotation if present
if sig.return_annotation is not sig.empty:
if sig.return_annotation == Tuple[bool, Any]:
if sig.return_annotation == tuple[bool, Any]:
return v

origin = get_origin(sig.return_annotation)
@@ -290,7 +281,7 @@ class LiteAgent(FlowTrackable, BaseModel):
"""Return the original role for compatibility with tool interfaces."""
return self.role

def kickoff(self, messages: Union[str, List[Dict[str, str]]]) -> LiteAgentOutput:
def kickoff(self, messages: str | list[dict[str, str]]) -> LiteAgentOutput:
"""
Execute the agent with the given messages.

@@ -338,7 +329,7 @@ class LiteAgent(FlowTrackable, BaseModel):
)
raise e

def _execute_core(self, agent_info: Dict[str, Any]) -> LiteAgentOutput:
def _execute_core(self, agent_info: dict[str, Any]) -> LiteAgentOutput:
# Emit event for agent execution start
crewai_event_bus.emit(
self,
@@ -351,7 +342,7 @@ class LiteAgent(FlowTrackable, BaseModel):

# Execute the agent using invoke loop
agent_finish = self._invoke_loop()
formatted_result: Optional[BaseModel] = None
formatted_result: BaseModel | None = None
if self.response_format:
try:
# Cast to BaseModel to ensure type safety
@@ -360,7 +351,7 @@ class LiteAgent(FlowTrackable, BaseModel):
formatted_result = result
except Exception as e:
self._printer.print(
content=f"Failed to parse output into response format: {str(e)}",
content=f"Failed to parse output into response format: {e!s}",
color="yellow",
)

@@ -428,7 +419,7 @@ class LiteAgent(FlowTrackable, BaseModel):
return output

async def kickoff_async(
self, messages: Union[str, List[Dict[str, str]]]
self, messages: str | list[dict[str, str]]
) -> LiteAgentOutput:
"""
Execute the agent asynchronously with the given messages.
@@ -475,8 +466,8 @@ class LiteAgent(FlowTrackable, BaseModel):
return base_prompt

def _format_messages(
self, messages: Union[str, List[Dict[str, str]]]
) -> List[Dict[str, str]]:
self, messages: str | list[dict[str, str]]
) -> list[dict[str, str]]:
"""Format messages for the LLM."""
if isinstance(messages, str):
messages = [{"role": "user", "content": messages}]
@@ -571,9 +562,8 @@ class LiteAgent(FlowTrackable, BaseModel):
i18n=self.i18n,
)
continue
else:
handle_unknown_error(self._printer, e)
raise e
handle_unknown_error(self._printer, e)
raise e

finally:
self._iterations += 1
@@ -582,7 +572,7 @@ class LiteAgent(FlowTrackable, BaseModel):
self._show_logs(formatted_answer)
return formatted_answer

def _show_logs(self, formatted_answer: Union[AgentAction, AgentFinish]):
def _show_logs(self, formatted_answer: AgentAction | AgentFinish):
"""Show logs for the agent's execution."""
crewai_event_bus.emit(
self,

@@ -6,19 +6,14 @@ import threading
import warnings
from collections import defaultdict
from contextlib import contextmanager
from datetime import datetime
from typing import (
Any,
DefaultDict,
Dict,
List,
Literal,
Optional,
Type,
TypedDict,
Union,
cast,
)
from datetime import datetime

from dotenv import load_dotenv
from litellm.types.utils import ChatCompletionDeltaToolCall
from pydantic import BaseModel, Field
@@ -31,9 +26,9 @@ from crewai.events.types.llm_events import (
LLMStreamChunkEvent,
)
from crewai.events.types.tool_usage_events import (
ToolUsageStartedEvent,
ToolUsageFinishedEvent,
ToolUsageErrorEvent,
ToolUsageFinishedEvent,
ToolUsageStartedEvent,
)

with warnings.catch_warnings():
@@ -51,8 +46,8 @@ with warnings.catch_warnings():
import io
from typing import TextIO

from crewai.llms.base_llm import BaseLLM
from crewai.events.event_bus import crewai_event_bus
from crewai.llms.base_llm import BaseLLM
from crewai.utilities.exceptions.context_window_exceeding_exception import (
LLMContextLengthExceededException,
)
@@ -268,14 +263,14 @@ def suppress_warnings():


class Delta(TypedDict):
content: Optional[str]
role: Optional[str]
content: str | None
role: str | None


class StreamingChoices(TypedDict):
delta: Delta
index: int
finish_reason: Optional[str]
finish_reason: str | None


class FunctionArgs(BaseModel):
@@ -288,31 +283,31 @@ class AccumulatedToolArgs(BaseModel):


class LLM(BaseLLM):
completion_cost: Optional[float] = None
completion_cost: float | None = None

def __init__(
self,
model: str,
timeout: Optional[Union[float, int]] = None,
temperature: Optional[float] = None,
top_p: Optional[float] = None,
n: Optional[int] = None,
stop: Optional[Union[str, List[str]]] = None,
max_completion_tokens: Optional[int] = None,
max_tokens: Optional[int] = None,
presence_penalty: Optional[float] = None,
frequency_penalty: Optional[float] = None,
logit_bias: Optional[Dict[int, float]] = None,
response_format: Optional[Type[BaseModel]] = None,
seed: Optional[int] = None,
logprobs: Optional[int] = None,
top_logprobs: Optional[int] = None,
base_url: Optional[str] = None,
api_base: Optional[str] = None,
api_version: Optional[str] = None,
api_key: Optional[str] = None,
callbacks: List[Any] | None = None,
reasoning_effort: Optional[Literal["none", "low", "medium", "high"]] = None,
timeout: float | int | None = None,
temperature: float | None = None,
top_p: float | None = None,
n: int | None = None,
stop: str | list[str] | None = None,
max_completion_tokens: int | None = None,
max_tokens: int | None = None,
presence_penalty: float | None = None,
frequency_penalty: float | None = None,
logit_bias: dict[int, float] | None = None,
response_format: type[BaseModel] | None = None,
seed: int | None = None,
logprobs: int | None = None,
top_logprobs: int | None = None,
base_url: str | None = None,
api_base: str | None = None,
api_version: str | None = None,
api_key: str | None = None,
callbacks: list[Any] | None = None,
reasoning_effort: Literal["none", "low", "medium", "high"] | None = None,
stream: bool = False,
**kwargs,
):
@@ -345,7 +340,7 @@ class LLM(BaseLLM):

# Normalize self.stop to always be a List[str]
if stop is None:
self.stop: List[str] = []
self.stop: list[str] = []
elif isinstance(stop, str):
self.stop = [stop]
else:
@@ -368,9 +363,9 @@ class LLM(BaseLLM):

def _prepare_completion_params(
self,
messages: Union[str, List[Dict[str, str]]],
tools: Optional[List[dict]] = None,
) -> Dict[str, Any]:
messages: str | list[dict[str, str]],
tools: list[dict] | None = None,
) -> dict[str, Any]:
"""Prepare parameters for the completion call.

Args:
@@ -419,11 +414,11 @@ class LLM(BaseLLM):

def _handle_streaming_response(
self,
params: Dict[str, Any],
callbacks: Optional[List[Any]] = None,
available_functions: Optional[Dict[str, Any]] = None,
from_task: Optional[Any] = None,
from_agent: Optional[Any] = None,
params: dict[str, Any],
callbacks: list[Any] | None = None,
available_functions: dict[str, Any] | None = None,
from_task: Any | None = None,
from_agent: Any | None = None,
) -> str:
"""Handle a streaming response from the LLM.

@@ -447,7 +442,7 @@ class LLM(BaseLLM):
usage_info = None
tool_calls = None

accumulated_tool_args: DefaultDict[int, AccumulatedToolArgs] = defaultdict(
accumulated_tool_args: defaultdict[int, AccumulatedToolArgs] = defaultdict(
AccumulatedToolArgs
)

@@ -472,16 +467,16 @@ class LLM(BaseLLM):
choices = chunk["choices"]
elif hasattr(chunk, "choices"):
# Check if choices is not a type but an actual attribute with value
if not isinstance(getattr(chunk, "choices"), type):
choices = getattr(chunk, "choices")
if not isinstance(chunk.choices, type):
choices = chunk.choices

# Try to extract usage information if available
if isinstance(chunk, dict) and "usage" in chunk:
usage_info = chunk["usage"]
elif hasattr(chunk, "usage"):
# Check if usage is not a type but an actual attribute with value
if not isinstance(getattr(chunk, "usage"), type):
usage_info = getattr(chunk, "usage")
if not isinstance(chunk.usage, type):
usage_info = chunk.usage

if choices and len(choices) > 0:
choice = choices[0]
@@ -491,7 +486,7 @@ class LLM(BaseLLM):
if isinstance(choice, dict) and "delta" in choice:
delta = choice["delta"]
elif hasattr(choice, "delta"):
delta = getattr(choice, "delta")
delta = choice.delta

# Extract content from delta
if delta:
@@ -501,7 +496,7 @@ class LLM(BaseLLM):
chunk_content = delta["content"]
# Handle object format
elif hasattr(delta, "content"):
chunk_content = getattr(delta, "content")
chunk_content = delta.content

# Handle case where content might be None or empty
if chunk_content is None and isinstance(delta, dict):
@@ -572,8 +567,8 @@ class LLM(BaseLLM):
if isinstance(last_chunk, dict) and "choices" in last_chunk:
choices = last_chunk["choices"]
elif hasattr(last_chunk, "choices"):
if not isinstance(getattr(last_chunk, "choices"), type):
choices = getattr(last_chunk, "choices")
if not isinstance(last_chunk.choices, type):
choices = last_chunk.choices

if choices and len(choices) > 0:
choice = choices[0]
@@ -583,14 +578,14 @@ class LLM(BaseLLM):
if isinstance(choice, dict) and "message" in choice:
message = choice["message"]
elif hasattr(choice, "message"):
message = getattr(choice, "message")
message = choice.message

if message:
content = None
if isinstance(message, dict) and "content" in message:
content = message["content"]
elif hasattr(message, "content"):
content = getattr(message, "content")
content = message.content

if content:
full_response = content
@@ -617,8 +612,8 @@ class LLM(BaseLLM):
if isinstance(last_chunk, dict) and "choices" in last_chunk:
choices = last_chunk["choices"]
elif hasattr(last_chunk, "choices"):
if not isinstance(getattr(last_chunk, "choices"), type):
choices = getattr(last_chunk, "choices")
if not isinstance(last_chunk.choices, type):
choices = last_chunk.choices

if choices and len(choices) > 0:
choice = choices[0]
@@ -627,13 +622,13 @@ class LLM(BaseLLM):
if isinstance(choice, dict) and "message" in choice:
message = choice["message"]
elif hasattr(choice, "message"):
message = getattr(choice, "message")
message = choice.message

if message:
if isinstance(message, dict) and "tool_calls" in message:
tool_calls = message["tool_calls"]
elif hasattr(message, "tool_calls"):
tool_calls = getattr(message, "tool_calls")
tool_calls = message.tool_calls
except Exception as e:
logging.debug(f"Error checking for tool calls: {e}")
# --- 8) If no tool calls or no available functions, return the text response directly
@@ -675,9 +670,9 @@ class LLM(BaseLLM):
# decide whether to summarize the content or abort based on the respect_context_window flag.
raise LLMContextLengthExceededException(str(e))
except Exception as e:
logging.error(f"Error in streaming response: {str(e)}")
logging.error(f"Error in streaming response: {e!s}")
if full_response.strip():
logging.warning(f"Returning partial response despite error: {str(e)}")
logging.warning(f"Returning partial response despite error: {e!s}")
self._handle_emit_call_events(
response=full_response,
call_type=LLMCallType.LLM_CALL,
@@ -695,15 +690,15 @@ class LLM(BaseLLM):
error=str(e), from_task=from_task, from_agent=from_agent
),
)
raise Exception(f"Failed to get streaming response: {str(e)}")
raise Exception(f"Failed to get streaming response: {e!s}")

def _handle_streaming_tool_calls(
self,
tool_calls: List[ChatCompletionDeltaToolCall],
accumulated_tool_args: DefaultDict[int, AccumulatedToolArgs],
available_functions: Optional[Dict[str, Any]] = None,
from_task: Optional[Any] = None,
from_agent: Optional[Any] = None,
tool_calls: list[ChatCompletionDeltaToolCall],
accumulated_tool_args: defaultdict[int, AccumulatedToolArgs],
available_functions: dict[str, Any] | None = None,
from_task: Any | None = None,
from_agent: Any | None = None,
) -> None | str:
for tool_call in tool_calls:
current_tool_accumulator = accumulated_tool_args[tool_call.index]
@@ -744,9 +739,9 @@ class LLM(BaseLLM):

def _handle_streaming_callbacks(
self,
callbacks: Optional[List[Any]],
usage_info: Optional[Dict[str, Any]],
last_chunk: Optional[Any],
callbacks: list[Any] | None,
usage_info: dict[str, Any] | None,
last_chunk: Any | None,
) -> None:
"""Handle callbacks with usage info for streaming responses.

@@ -769,10 +764,8 @@ class LLM(BaseLLM):
):
usage_info = last_chunk["usage"]
elif hasattr(last_chunk, "usage"):
if not isinstance(
getattr(last_chunk, "usage"), type
):
usage_info = getattr(last_chunk, "usage")
if not isinstance(last_chunk.usage, type):
usage_info = last_chunk.usage
except Exception as e:
logging.debug(f"Error extracting usage info: {e}")

@@ -786,11 +779,11 @@ class LLM(BaseLLM):

def _handle_non_streaming_response(
self,
params: Dict[str, Any],
callbacks: Optional[List[Any]] = None,
available_functions: Optional[Dict[str, Any]] = None,
from_task: Optional[Any] = None,
from_agent: Optional[Any] = None,
params: dict[str, Any],
callbacks: list[Any] | None = None,
available_functions: dict[str, Any] | None = None,
from_task: Any | None = None,
from_agent: Any | None = None,
) -> str | Any:
"""Handle a non-streaming response from the LLM.

@@ -847,7 +840,7 @@ class LLM(BaseLLM):
)
return text_response
# --- 6) If there is no text response, no available functions, but there are tool calls, return the tool calls
elif tool_calls and not available_functions and not text_response:
if tool_calls and not available_functions and not text_response:
return tool_calls

# --- 7) Handle tool calls if present
@@ -868,11 +861,11 @@ class LLM(BaseLLM):

def _handle_tool_call(
self,
tool_calls: List[Any],
available_functions: Optional[Dict[str, Any]] = None,
from_task: Optional[Any] = None,
from_agent: Optional[Any] = None,
) -> Optional[str]:
tool_calls: list[Any],
available_functions: dict[str, Any] | None = None,
from_task: Any | None = None,
from_agent: Any | None = None,
) -> str | None:
"""Handle a tool call from the LLM.

Args:
@@ -942,14 +935,14 @@ class LLM(BaseLLM):
assert hasattr(crewai_event_bus, "emit")
crewai_event_bus.emit(
self,
event=LLMCallFailedEvent(error=f"Tool execution error: {str(e)}"),
event=LLMCallFailedEvent(error=f"Tool execution error: {e!s}"),
)
crewai_event_bus.emit(
self,
event=ToolUsageErrorEvent(
tool_name=function_name,
tool_args=function_args,
error=f"Tool execution error: {str(e)}",
error=f"Tool execution error: {e!s}",
from_task=from_task,
from_agent=from_agent,
),
@@ -958,13 +951,13 @@ class LLM(BaseLLM):

def call(
self,
messages: Union[str, List[Dict[str, str]]],
tools: Optional[List[dict]] = None,
callbacks: Optional[List[Any]] = None,
available_functions: Optional[Dict[str, Any]] = None,
from_task: Optional[Any] = None,
from_agent: Optional[Any] = None,
) -> Union[str, Any]:
messages: str | list[dict[str, str]],
tools: list[dict] | None = None,
callbacks: list[Any] | None = None,
available_functions: dict[str, Any] | None = None,
from_task: Any | None = None,
from_agent: Any | None = None,
) -> str | Any:
"""High-level LLM call method.

Args:
@@ -1028,10 +1021,9 @@ class LLM(BaseLLM):
return self._handle_streaming_response(
params, callbacks, available_functions, from_task, from_agent
)
else:
return self._handle_non_streaming_response(
params, callbacks, available_functions, from_task, from_agent
)
return self._handle_non_streaming_response(
params, callbacks, available_functions, from_task, from_agent
)

except LLMContextLengthExceededException:
# Re-raise LLMContextLengthExceededException as it should be handled
@@ -1078,8 +1070,8 @@ class LLM(BaseLLM):
self,
response: Any,
call_type: LLMCallType,
from_task: Optional[Any] = None,
from_agent: Optional[Any] = None,
from_task: Any | None = None,
from_agent: Any | None = None,
messages: str | list[dict[str, Any]] | None = None,
):
"""Handle the events for the LLM call.
@@ -1105,8 +1097,8 @@ class LLM(BaseLLM):
)

def _format_messages_for_provider(
self, messages: List[Dict[str, str]]
) -> List[Dict[str, str]]:
self, messages: list[dict[str, str]]
) -> list[dict[str, str]]:
"""Format messages according to provider requirements.

Args:
@@ -1147,7 +1139,7 @@ class LLM(BaseLLM):
if "mistral" in self.model.lower():
# Check if the last message has a role of 'assistant'
if messages and messages[-1]["role"] == "assistant":
return messages + [{"role": "user", "content": "Please continue."}]
return [*messages, {"role": "user", "content": "Please continue."}]
return messages

# TODO: Remove this code after merging PR https://github.com/BerriAI/litellm/pull/10917
@@ -1157,7 +1149,7 @@ class LLM(BaseLLM):
and messages
and messages[-1]["role"] == "assistant"
):
return messages + [{"role": "user", "content": ""}]
return [*messages, {"role": "user", "content": ""}]

# Handle Anthropic models
if not self.is_anthropic:
@@ -1170,7 +1162,7 @@ class LLM(BaseLLM):

return messages

def _get_custom_llm_provider(self) -> Optional[str]:
def _get_custom_llm_provider(self) -> str | None:
"""
Derives the custom_llm_provider from the model string.
- For example, if the model is "openrouter/deepseek/deepseek-chat", returns "openrouter".
@@ -1207,7 +1199,7 @@ class LLM(BaseLLM):
self.model, custom_llm_provider=provider
)
except Exception as e:
logging.error(f"Failed to check function calling support: {str(e)}")
logging.error(f"Failed to check function calling support: {e!s}")
return False

def supports_stop_words(self) -> bool:
@@ -1215,7 +1207,7 @@ class LLM(BaseLLM):
params = get_supported_openai_params(model=self.model)
return params is not None and "stop" in params
except Exception as e:
logging.error(f"Failed to get supported params: {str(e)}")
logging.error(f"Failed to get supported params: {e!s}")
return False

def get_context_window_size(self) -> int:
@@ -1247,7 +1239,7 @@ class LLM(BaseLLM):
self.context_window_size = int(value * CONTEXT_WINDOW_USAGE_RATIO)
return self.context_window_size

def set_callbacks(self, callbacks: List[Any]):
def set_callbacks(self, callbacks: list[Any]):
"""
Attempt to keep a single set of callbacks in litellm by removing old
duplicates and adding new ones.

@@ -1,11 +1,11 @@
from .entity.entity_memory import EntityMemory
from .external.external_memory import ExternalMemory
from .long_term.long_term_memory import LongTermMemory
from .short_term.short_term_memory import ShortTermMemory
from .external.external_memory import ExternalMemory

__all__ = [
"EntityMemory",
"ExternalMemory",
"LongTermMemory",
"ShortTermMemory",
"ExternalMemory",
]

@@ -1,12 +1,12 @@
from typing import Any, Dict, Optional
from typing import Any


class ExternalMemoryItem:
def __init__(
self,
value: Any,
metadata: Optional[Dict[str, Any]] = None,
agent: Optional[str] = None,
metadata: dict[str, Any] | None = None,
agent: str | None = None,
):
self.value = value
self.metadata = metadata

@@ -1,17 +1,17 @@
from typing import Any, Dict, List
import time
from typing import Any

from crewai.memory.long_term.long_term_memory_item import LongTermMemoryItem
from crewai.memory.memory import Memory
from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.memory_events import (
MemoryQueryStartedEvent,
MemoryQueryCompletedEvent,
MemoryQueryFailedEvent,
MemorySaveStartedEvent,
MemoryQueryStartedEvent,
MemorySaveCompletedEvent,
MemorySaveFailedEvent,
MemorySaveStartedEvent,
)
from crewai.memory.long_term.long_term_memory_item import LongTermMemoryItem
from crewai.memory.memory import Memory
from crewai.memory.storage.ltm_sqlite_storage import LTMSQLiteStorage


@@ -84,7 +84,7 @@ class LongTermMemory(Memory):
self,
task: str,
latest_n: int = 3,
) -> List[Dict[str, Any]]: # type: ignore # signature of "search" incompatible with supertype "Memory"
) -> list[dict[str, Any]]: # type: ignore # signature of "search" incompatible with supertype "Memory"
crewai_event_bus.emit(
self,
event=MemoryQueryStartedEvent(

@@ -1,4 +1,4 @@
from typing import Any, Dict, Optional, Union
from typing import Any


class LongTermMemoryItem:
@@ -8,8 +8,8 @@ class LongTermMemoryItem:
task: str,
expected_output: str,
datetime: str,
quality: Optional[Union[int, float]] = None,
metadata: Optional[Dict[str, Any]] = None,
quality: int | float | None = None,
metadata: dict[str, Any] | None = None,
):
self.task = task
self.agent = agent

@@ -1,12 +1,12 @@
from typing import Any, Dict, Optional
from typing import Any


class ShortTermMemoryItem:
def __init__(
self,
data: Any,
agent: Optional[str] = None,
metadata: Optional[Dict[str, Any]] = None,
agent: str | None = None,
metadata: dict[str, Any] | None = None,
):
self.data = data
self.agent = agent

@@ -1,15 +1,15 @@
from typing import Any, Dict, List
from typing import Any


class Storage:
"""Abstract base class defining the storage interface"""

def save(self, value: Any, metadata: Dict[str, Any]) -> None:
def save(self, value: Any, metadata: dict[str, Any]) -> None:
pass

def search(
self, query: str, limit: int, score_threshold: float
) -> Dict[str, Any] | List[Any]:
) -> dict[str, Any] | list[Any]:
return {}

def reset(self) -> None:

@@ -2,7 +2,7 @@ import json
import logging
import sqlite3
from pathlib import Path
from typing import Any, Dict, List, Optional
from typing import Any

from crewai.task import Task
from crewai.utilities import Printer
@@ -18,7 +18,7 @@ class KickoffTaskOutputsSQLiteStorage:
An updated SQLite storage class for kickoff task outputs storage.
"""

def __init__(self, db_path: Optional[str] = None) -> None:
def __init__(self, db_path: str | None = None) -> None:
if db_path is None:
# Get the parent directory of the default db path and create our db file there
db_path = str(Path(db_storage_path()) / "latest_kickoff_task_outputs.db")
@@ -62,10 +62,10 @@ class KickoffTaskOutputsSQLiteStorage:
def add(
self,
task: Task,
output: Dict[str, Any],
output: dict[str, Any],
task_index: int,
was_replayed: bool = False,
inputs: Dict[str, Any] | None = None,
inputs: dict[str, Any] | None = None,
) -> None:
"""Add a new task output record to the database.

@@ -153,7 +153,7 @@ class KickoffTaskOutputsSQLiteStorage:
logger.error(error_msg)
raise DatabaseOperationError(error_msg, e)

def load(self) -> List[Dict[str, Any]]:
def load(self) -> list[dict[str, Any]]:
"""Load all task output records from the database.

Returns:

@@ -1,7 +1,7 @@
import json
import sqlite3
from pathlib import Path
from typing import Any, Dict, List, Optional, Union
from typing import Any

from crewai.utilities import Printer
from crewai.utilities.paths import db_storage_path
@@ -12,9 +12,7 @@ class LTMSQLiteStorage:
An updated SQLite storage class for LTM data storage.
"""

def __init__(
self, db_path: Optional[str] = None
) -> None:
def __init__(self, db_path: str | None = None) -> None:
if db_path is None:
# Get the parent directory of the default db path and create our db file there
db_path = str(Path(db_storage_path()) / "long_term_memory_storage.db")
@@ -53,9 +51,9 @@ class LTMSQLiteStorage:
def save(
self,
task_description: str,
metadata: Dict[str, Any],
metadata: dict[str, Any],
datetime: str,
score: Union[int, float],
score: int | float,
) -> None:
"""Saves data to the LTM table with error handling."""
try:
@@ -75,9 +73,7 @@ class LTMSQLiteStorage:
color="red",
)

def load(
self, task_description: str, latest_n: int
) -> Optional[List[Dict[str, Any]]]:
def load(self, task_description: str, latest_n: int) -> list[dict[str, Any]] | None:
"""Queries the LTM table by task description with error handling."""
try:
with sqlite3.connect(self.db_path) as conn:
@@ -125,4 +121,4 @@ class LTMSQLiteStorage:
content=f"MEMORY ERROR: An error occurred while deleting all rows in LTM: {e}",
color="red",
)
return None
return

@@ -14,16 +14,16 @@ from .annotations import (
from .crew_base import CrewBase

__all__ = [
"CrewBase",
"after_kickoff",
"agent",
"before_kickoff",
"cache_handler",
"callback",
"crew",
"task",
"llm",
"output_json",
"output_pydantic",
"task",
"tool",
"callback",
"CrewBase",
"llm",
"cache_handler",
"before_kickoff",
"after_kickoff",
]

@@ -1,5 +1,5 @@
from collections.abc import Callable
from functools import wraps
from typing import Callable

from crewai import Crew
from crewai.project.utils import memoize
@@ -36,15 +36,13 @@ def task(func):
def agent(func):
"""Marks a method as a crew agent."""
func.is_agent = True
func = memoize(func)
return func
return memoize(func)


def llm(func):
"""Marks a method as an LLM provider."""
func.is_llm = True
func = memoize(func)
return func
return memoize(func)


def output_json(cls):
@@ -91,7 +89,7 @@ def crew(func) -> Callable[..., Crew]:
agents = self._original_agents.items()

# Instantiate tasks in order
for task_name, task_method in tasks:
for _task_name, task_method in tasks:
task_instance = task_method(self)
instantiated_tasks.append(task_instance)
agent_instance = getattr(task_instance, "agent", None)
@@ -100,7 +98,7 @@ def crew(func) -> Callable[..., Crew]:
agent_roles.add(agent_instance.role)

# Instantiate agents not included by tasks
for agent_name, agent_method in agents:
for _agent_name, agent_method in agents:
agent_instance = agent_method(self)
if agent_instance.role not in agent_roles:
instantiated_agents.append(agent_instance)
@@ -117,9 +115,9 @@ def crew(func) -> Callable[..., Crew]:

return wrapper

for _, callback in self._before_kickoff.items():
for callback in self._before_kickoff.values():
crew.before_kickoff_callbacks.append(callback_wrapper(callback, self))
for _, callback in self._after_kickoff.items():
for callback in self._after_kickoff.values():
crew.after_kickoff_callbacks.append(callback_wrapper(callback, self))

return crew

@@ -1,17 +1,17 @@
"""RAG (Retrieval-Augmented Generation) infrastructure for CrewAI."""

import sys
import importlib
import sys
from types import ModuleType
from typing import Any

from crewai.rag.config.types import RagConfigType
from crewai.rag.config.utils import set_rag_config


_module_path = __path__
_module_file = __file__


class _RagModule(ModuleType):
"""Module wrapper to intercept attribute setting for config."""

@@ -51,8 +51,10 @@ class _RagModule(ModuleType):
"""
try:
return importlib.import_module(f"{self.__name__}.{name}")
except ImportError:
raise AttributeError(f"module '{self.__name__}' has no attribute '{name}'")
except ImportError as e:
raise AttributeError(
f"module '{self.__name__}' has no attribute '{name}'"
) from e


sys.modules[__name__] = _RagModule(__name__)

@@ -1 +1 @@
"""Optional imports for RAG configuration providers."""
"""Optional imports for RAG configuration providers."""

@@ -1,7 +1,7 @@
"""Base classes for missing provider configurations."""

from typing import Literal
from dataclasses import field
from typing import Literal

from pydantic import ConfigDict
from pydantic.dataclasses import dataclass as pyd_dataclass

Some files were not shown because too many files have changed in this diff
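Two patterns recur throughout the hunks above: model_config declared with Pydantic v2's ConfigDict instead of a bare dict, and typing.Optional/Union/List/Dict annotations rewritten as PEP 604 unions and builtin generics. A minimal sketch of both on a throwaway model follows; the ExampleOutput class and its fields are hypothetical, and it assumes Pydantic v2 on Python 3.10+.

from typing import Any

from pydantic import BaseModel, ConfigDict, Field


class ExampleOutput(BaseModel):
    # ConfigDict is the typed replacement for model_config = {"arbitrary_types_allowed": True}.
    model_config = ConfigDict(arbitrary_types_allowed=True)

    raw: str = Field(default="", description="Raw output")
    # PEP 604 union and builtin dict replace Optional[Dict[str, Any]].
    usage_metrics: dict[str, Any] | None = Field(default=None)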